InĀ [1]:
Copied!
import spyx
import spyx.nn as snn
import jax
import jax.numpy as jnp
import numpy as np
import nir
# for loading dataset
import torch
import matplotlib.pyplot as plt
import spyx
import spyx.nn as snn
import jax
import jax.numpy as jnp
import numpy as np
import nir
# for loading dataset
import torch
import matplotlib.pyplot as plt
InĀ [2]:
Copied!
# Load the NIR graph of the braille model variant without delay or bias, using subtract reset.
ng = nir.read("braille_noDelay_noBias_subtract_subgraph.nir")
InĀ [3]:
Copied!
# Load the pickled test set. The file holds a full dataset object (unpacked via
# `.tensors` in the next cell), not just tensors, so `weights_only=False` is required:
# recent PyTorch releases default to `weights_only=True`, which refuses such objects.
# SECURITY: this unpickles arbitrary Python objects — only load files you trust.
data = torch.load("ds_test.pt", weights_only=False)
InĀ [4]:
Copied!
# Split the dataset into inputs (x) and class labels (y).
x, y = data.tensors
InĀ [5]:
Copied!
# Convert the inputs from a torch tensor to a JAX array for spyx.
x = jnp.array(x)
InĀ [6]:
Copied!
# Convert the labels from a torch tensor to a JAX array.
y = jnp.array(y)
InĀ [7]:
Copied!
# Input layout: (samples, time steps, channels); axis 1 (256 steps) is time.
x.shape
Out[7]:
(140, 256, 12)
InĀ [8]:
Copied!
# Names of the nodes in the loaded graph.
ng.nodes.keys()
Out[8]:
dict_keys(['fc1', 'fc2', 'input', 'lif1', 'lif2', 'output'])
InĀ [9]:
Copied!
# (source, target) edges of the graph.
ng.edges
Out[9]:
[('input', 'fc1'),
('fc1', 'lif1'),
('lif1', 'fc2'),
('fc2', 'lif2'),
('lif2', 'output')]
InĀ [10]:
Copied!
# Build a partial network containing only the first layer (input -> fc1 -> lif1),
# wiring lif1 straight to the output node so its spiking activity can be inspected.
n_list = ["input", "fc1", "lif1", "output"]
# NOTE: these are the same node objects as in `ng`, not copies — mutating a node
# of `subgraph` also mutates it in `ng`.
subgraph_nodes = {k: ng.nodes[k] for k in n_list}
# Consecutive pairs of n_list form the feed-forward chain, so edges cannot drift from the node list.
subgraph_edges = list(zip(n_list[:-1], n_list[1:]))
subgraph = nir.NIRGraph(subgraph_nodes, subgraph_edges)
InĀ [11]:
Copied!
# Resize the output node to lif1's width, since the partial graph now ends at lif1.
# The width is derived from fc1's weight instead of being hard-coded, assuming NIR's
# (out_features, in_features) weight layout.
# NOTE: this also mutates ng.nodes["output"], because the node objects are shared.
subgraph.nodes["output"].output_type['output'] = np.array([subgraph.nodes["fc1"].weight.shape[0]])
InĀ [12]:
Copied!
# Convert the partial NIR graph into a spyx model plus its parameters.
# The recurrent lif1 subgraph is parsed as an RNN layer (see the printed message and the 'RCuBaLIF' params key).
# x is presumably used as a sample input for parameter initialisation; dt=1e-4 is the simulation step.
SNN, params = spyx.nir.from_nir(subgraph, x, dt=1e-4, return_all_states=True)
found subgraph, trying to parse as RNN found subgraph, trying to parse as RNN
InĀ [13]:
Copied!
# Parameter groups created by the conversion.
params.keys()
Out[13]:
dict_keys(['linear', 'RCuBaLIF'])
InĀ [14]:
Copied!
# Run the partial network on the whole test set; `a` holds lif1's activity.
a, b = SNN.apply(params, x)
found subgraph, trying to parse as RNN
InĀ [15]:
Copied!
# Raster of the first test sample: rows are input channels, columns are time steps.
fig, ax = plt.subplots(figsize=(12, 4))
ax.imshow(x[0].T, aspect=4, interpolation="none")
ax.set(title="Input, sample 0", xlabel="time step", ylabel="input channel")
plt.show()
InĀ [16]:
Copied!
# Activity layout: (samples, time steps, lif1 neurons).
a.shape
Out[16]:
(140, 256, 40)
InĀ [17]:
Copied!
# lif1 spiking activity for the first test sample: rows are lif1 neurons, columns are time steps.
fig, ax = plt.subplots(figsize=(12, 4))
ax.imshow(a[0].T, aspect=4, interpolation="none")
ax.set(title="lif1 activity, sample 0", xlabel="time step", ylabel="lif1 neuron")
plt.show()
InĀ [18]:
Copied!
# Save the lif1 activity of the first sample for cross-framework comparison.
np.save("spyx_activity_noDelay_noBias_subtract.npy", a[0])
Bias + zero-reset network: partial graph (input → fc1 → lif1)
InĀ [19]:
Copied!
# Load the NIR graph of the braille model variant without delay, with bias, using zero reset.
ng = nir.read("braille_noDelay_bias_zero_subgraph.nir")
InĀ [20]:
Copied!
# Build a partial network containing only the first layer (input -> fc1 -> lif1),
# wiring lif1 straight to the output node so its spiking activity can be inspected.
n_list = ["input", "fc1", "lif1", "output"]
# NOTE: these are the same node objects as in `ng`, not copies — mutating a node
# of `subgraph` also mutates it in `ng`.
subgraph_nodes = {k: ng.nodes[k] for k in n_list}
# Consecutive pairs of n_list form the feed-forward chain, so edges cannot drift from the node list.
subgraph_edges = list(zip(n_list[:-1], n_list[1:]))
subgraph = nir.NIRGraph(subgraph_nodes, subgraph_edges)
InĀ [21]:
Copied!
# Resize the output node to lif1's width, since the partial graph now ends at lif1.
# The width is derived from fc1's weight instead of hard-coding 40: this network's
# hidden layer has 38 neurons (see the a.shape output below), so the value copied
# from the subtract-reset network was wrong. Assumes NIR's (out_features, in_features) weight layout.
# NOTE: this also mutates ng.nodes["output"], because the node objects are shared.
subgraph.nodes["output"].output_type['output'] = np.array([subgraph.nodes["fc1"].weight.shape[0]])
InĀ [22]:
Copied!
# Convert the partial NIR graph into a spyx model plus its parameters.
# The recurrent lif1 subgraph is parsed as an RNN layer (see the printed message and the 'RCuBaLIF' params key).
# x is presumably used as a sample input for parameter initialisation; dt=1e-4 is the simulation step.
SNN, params = spyx.nir.from_nir(subgraph, x, dt=1e-4, return_all_states=True)
found subgraph, trying to parse as RNN found subgraph, trying to parse as RNN
InĀ [23]:
Copied!
# Parameter groups created by the conversion.
params.keys()
Out[23]:
dict_keys(['linear', 'RCuBaLIF'])
InĀ [24]:
Copied!
# Run the partial network on the whole test set; `a` holds lif1's activity.
a, b = SNN.apply(params, x)
found subgraph, trying to parse as RNN
InĀ [25]:
Copied!
# Raster of the first test sample: rows are input channels, columns are time steps.
fig, ax = plt.subplots(figsize=(12, 4))
ax.imshow(x[0].T, aspect=4, interpolation="none")
ax.set(title="Input, sample 0", xlabel="time step", ylabel="input channel")
plt.show()
InĀ [26]:
Copied!
# Activity layout: (samples, time steps, lif1 neurons).
a.shape
Out[26]:
(140, 256, 38)
InĀ [27]:
Copied!
# lif1 spiking activity for the first test sample: rows are lif1 neurons, columns are time steps.
fig, ax = plt.subplots(figsize=(12, 4))
ax.imshow(a[0].T, aspect=4, interpolation="none")
ax.set(title="lif1 activity, sample 0", xlabel="time step", ylabel="lif1 neuron")
plt.show()
InĀ [28]:
Copied!
# Save the lif1 activity of the first sample for cross-framework comparison.
np.save("spyx_activity_noDelay_bias_zero.npy", a[0])
InĀ [Ā ]:
Copied!
InĀ [29]:
Copied!
# Reload the full no-bias, subtract-reset graph (the partial-graph cells above mutated its output node).
ng = nir.read("braille_noDelay_noBias_subtract_subgraph.nir")
InĀ [30]:
Copied!
# Convert the full NIR graph into a spyx model plus its parameters (dt=1e-4 simulation step).
SNN, params = spyx.nir.from_nir(ng, x, dt=1e-4, return_all_states=True)
found subgraph, trying to parse as RNN found subgraph, trying to parse as RNN
InĀ [31]:
Copied!
# Flatten spyx's recurrent (RCuBaLIF) and input (linear) weights for an element-wise
# comparison with snnTorch. The transpose assumes spyx stores weights as
# (in_features, out_features), the reverse of PyTorch's (out_features, in_features).
spyx_rec = params["RCuBaLIF"]["w"].T.flatten()
spyx_inp = params["linear"]["w"].T.flatten()
InĀ [32]:
Copied!
# Cross-check: import the same NIR file with snnTorch and confirm that both frameworks
# recovered identical recurrent and fc1 weights.
from snntorch import import_nirtorch
nirgraph = nir.read("braille_noDelay_noBias_subtract_subgraph.nir")
net_snn = import_nirtorch.from_nir(nirgraph)
snn_rec = net_snn.lif1.recurrent.weight.detach().reshape(-1)
print('w_rec close?', jnp.allclose(spyx_rec, snn_rec.numpy()))
snn_inp = net_snn.fc1.weight.detach().reshape(-1)
print('fc1 close?', jnp.allclose(spyx_inp, snn_inp.numpy()))
replace rnn subgraph with nirgraph w_rec close? True fc1 close? True
InĀ [33]:
Copied!
# Run the full network on the whole test set; `a` holds the output layer's activity.
a, b = SNN.apply(params, x)
found subgraph, trying to parse as RNN
InĀ [34]:
Copied!
# Output-layer spiking for the first test sample: rows are output neurons
# (presumably one per class; labels range 0-6), columns are time steps.
fig, ax = plt.subplots(figsize=(12, 4))
ax.imshow(a[0].T, aspect=10, interpolation="none")
ax.set(title="Output activity, sample 0", xlabel="time step", ylabel="output neuron")
plt.show()
Out[34]:
<matplotlib.image.AxesImage at 0x7f47302fc190>
InĀ [35]:
Copied!
# Ground-truth labels of the test set.
y
Out[35]:
Array([1, 3, 2, 2, 6, 1, 1, 3, 4, 5, 4, 0, 5, 5, 0, 2, 4, 3, 1, 2, 5, 2,
4, 6, 2, 2, 4, 1, 4, 4, 1, 3, 2, 0, 4, 5, 1, 0, 3, 5, 1, 2, 0, 4,
5, 4, 5, 6, 6, 1, 4, 5, 0, 2, 3, 4, 5, 0, 2, 5, 5, 5, 6, 5, 6, 4,
1, 2, 6, 1, 0, 0, 6, 4, 0, 3, 3, 0, 1, 6, 2, 0, 3, 1, 0, 1, 2, 0,
3, 0, 0, 0, 4, 6, 1, 3, 2, 5, 2, 6, 0, 5, 5, 0, 3, 1, 6, 6, 3, 2,
4, 4, 6, 3, 6, 2, 2, 5, 3, 6, 2, 1, 3, 6, 5, 4, 5, 4, 1, 6, 3, 0,
3, 6, 3, 1, 6, 4, 3, 1], dtype=int32)
InĀ [36]:
Copied!
# Test accuracy from the output activity (spyx's integral_accuracy; presumably it
# predicts the class whose output neuron accumulates the most activity over time).
acc, preds = spyx.fn.integral_accuracy(a, y)
acc
Out[36]:
Array(0.92142856, dtype=float32)
InĀ [37]:
Copied!
# Persist the test accuracy of the no-bias, subtract-reset network.
np.save("spyx_accuracy_noDelay_noBias_subtract.npy", acc)
Bias + zero-reset network: full graph
InĀ [38]:
Copied!
# Reload the full bias, zero-reset graph (the partial-graph cells above mutated its output node).
ng = nir.read("braille_noDelay_bias_zero_subgraph.nir")
InĀ [39]:
Copied!
# Convert the full NIR graph into a spyx model plus its parameters (dt=1e-4 simulation step).
SNN, params = spyx.nir.from_nir(ng, x, dt=1e-4, return_all_states=True)
found subgraph, trying to parse as RNN found subgraph, trying to parse as RNN
InĀ [40]:
Copied!
# Flatten spyx's recurrent (RCuBaLIF) and input (linear) weights for an element-wise
# comparison with snnTorch. The transpose assumes spyx stores weights as
# (in_features, out_features), the reverse of PyTorch's (out_features, in_features).
spyx_rec = params["RCuBaLIF"]["w"].T.flatten()
spyx_inp = params["linear"]["w"].T.flatten()
InĀ [41]:
Copied!
# Cross-check: import the same NIR file with snnTorch and confirm that both frameworks
# recovered identical recurrent and fc1 weights.
from snntorch import import_nirtorch
nirgraph = nir.read("braille_noDelay_bias_zero_subgraph.nir")
net_snn = import_nirtorch.from_nir(nirgraph)
snn_rec = net_snn.lif1.recurrent.weight.detach().reshape(-1)
print('w_rec close?', jnp.allclose(spyx_rec, snn_rec.numpy()))
snn_inp = net_snn.fc1.weight.detach().reshape(-1)
print('fc1 close?', jnp.allclose(spyx_inp, snn_inp.numpy()))
replace rnn subgraph with nirgraph w_rec close? True fc1 close? True
InĀ [42]:
Copied!
# Run the full network on the whole test set; `a` holds the output layer's activity.
a, b = SNN.apply(params, x)
found subgraph, trying to parse as RNN
InĀ [43]:
Copied!
# Output-layer spiking for the first test sample: rows are output neurons
# (presumably one per class; labels range 0-6), columns are time steps.
fig, ax = plt.subplots(figsize=(12, 4))
ax.imshow(a[0].T, aspect=10, interpolation="none")
ax.set(title="Output activity, sample 0", xlabel="time step", ylabel="output neuron")
plt.show()
Out[43]:
<matplotlib.image.AxesImage at 0x7f4730385c30>
InĀ [44]:
Copied!
# Ground-truth labels of the test set.
y
Out[44]:
Array([1, 3, 2, 2, 6, 1, 1, 3, 4, 5, 4, 0, 5, 5, 0, 2, 4, 3, 1, 2, 5, 2,
4, 6, 2, 2, 4, 1, 4, 4, 1, 3, 2, 0, 4, 5, 1, 0, 3, 5, 1, 2, 0, 4,
5, 4, 5, 6, 6, 1, 4, 5, 0, 2, 3, 4, 5, 0, 2, 5, 5, 5, 6, 5, 6, 4,
1, 2, 6, 1, 0, 0, 6, 4, 0, 3, 3, 0, 1, 6, 2, 0, 3, 1, 0, 1, 2, 0,
3, 0, 0, 0, 4, 6, 1, 3, 2, 5, 2, 6, 0, 5, 5, 0, 3, 1, 6, 6, 3, 2,
4, 4, 6, 3, 6, 2, 2, 5, 3, 6, 2, 1, 3, 6, 5, 4, 5, 4, 1, 6, 3, 0,
3, 6, 3, 1, 6, 4, 3, 1], dtype=int32)
InĀ [45]:
Copied!
# Test accuracy from the output activity (spyx's integral_accuracy; presumably it
# predicts the class whose output neuron accumulates the most activity over time).
acc, preds = spyx.fn.integral_accuracy(a, y)
acc
Out[45]:
Array(0.8428571, dtype=float32)
InĀ [46]:
Copied!
# Persist the test accuracy of the bias, zero-reset network.
np.save("spyx_accuracy_noDelay_bias_zero.npy", acc)