NIR N-MNIST SCNN
In [1]:
Copied!
import spyx
import spyx.nn as snn
import jax
import jax.numpy as jnp
import numpy as np
import nir
import matplotlib.pyplot as plt
import spyx
import spyx.nn as snn
import jax
import jax.numpy as jnp
import numpy as np
import nir
import matplotlib.pyplot as plt
In [2]:
Copied!
input_data = jnp.transpose(jnp.array(np.load("val_numbers.npy"), dtype=jnp.float32), axes=(1,0,2,3,4))
input_data = jnp.transpose(jnp.array(np.load("val_numbers.npy"), dtype=jnp.float32), axes=(1,0,2,3,4))
In [3]:
Copied!
input_data.shape # had to transpose data because it's stored in time-major format...
input_data.shape # had to transpose data because it's stored in time-major format...
Out[3]:
(10, 300, 2, 34, 34)
In [4]:
Copied!
ng = nir.read("scnn_mnist.nir")
ng = nir.read("scnn_mnist.nir")
In [5]:
Copied!
layers_to_grab = ["input", "0", "1", "output"]
new_nodes = {k:ng.nodes[k] for k in layers_to_grab}
new_edges = [("input", "0"), ("0","1"), ("1", "output")]
first_conv = nir.NIRGraph(new_nodes, new_edges)
first_conv.nodes["output"].output_type["output"] = np.array([16,16,16])
layers_to_grab = ["input", "0", "1", "output"]
new_nodes = {k:ng.nodes[k] for k in layers_to_grab}
new_edges = [("input", "0"), ("0","1"), ("1", "output")]
first_conv = nir.NIRGraph(new_nodes, new_edges)
first_conv.nodes["output"].output_type["output"] = np.array([16,16,16])
In [6]:
Copied!
fl_SNN, fl_params = spyx.nir.from_nir(first_conv, input_data, dt=1, return_all_states=True)
fl_SNN, fl_params = spyx.nir.from_nir(first_conv, input_data, dt=1, return_all_states=True)
In [7]:
Copied!
output_spikes, membrane_potentials = fl_SNN.apply(fl_params, input_data)
output_spikes, membrane_potentials = fl_SNN.apply(fl_params, input_data)
In [8]:
Copied!
save_data = np.array(jnp.transpose(output_spikes, (1,0,2,3,4)))
save_data.shape
save_data = np.array(jnp.transpose(output_spikes, (1,0,2,3,4)))
save_data.shape
Out[8]:
(300, 10, 16, 16, 16)
In [9]:
Copied!
spyx_act = jnp.sum(output_spikes[0], axis=0)
spyx_act = jnp.sum(output_spikes[0], axis=0)
In [10]:
Copied!
spyx_act.shape
spyx_act.shape
Out[10]:
(16, 16, 16)
In [11]:
Copied!
plt.imshow(jnp.sum(spyx_act, axis=(0)))
plt.colorbar()
plt.title("Spyx SCNN Spiking, first Conv+IF Layer")
plt.show()
plt.imshow(jnp.sum(spyx_act, axis=(0)))
plt.colorbar()
plt.title("Spyx SCNN Spiking, first Conv+IF Layer")
plt.show()
In [12]:
Copied!
import numpy as np
snntorch_act = np.sum(np.load("./snnTorch_activity.npy", "r"), axis=0)
import numpy as np
snntorch_act = np.sum(np.load("./snnTorch_activity.npy", "r"), axis=0)
In [13]:
Copied!
plt.imshow(jnp.sum(snntorch_act[0], axis=(0)))
plt.colorbar()
plt.title("snnTorch SCNN First Conv+IF Layer")
plt.show()
plt.imshow(jnp.sum(snntorch_act[0], axis=(0)))
plt.colorbar()
plt.title("snnTorch SCNN First Conv+IF Layer")
plt.show()
In [14]:
Copied!
def cosine_similarity(vector1, vector2):
# Compute the dot product of the two vectors
dot_product = jnp.dot(vector1, vector2)
# Compute the magnitude (L2 norm) of each vector
magnitude1 = jnp.linalg.norm(vector1)
magnitude2 = jnp.linalg.norm(vector2)
# Compute the cosine similarity
similarity = dot_product / (magnitude1 * magnitude2)
return similarity
def cosine_similarity(vector1, vector2):
# Compute the dot product of the two vectors
dot_product = jnp.dot(vector1, vector2)
# Compute the magnitude (L2 norm) of each vector
magnitude1 = jnp.linalg.norm(vector1)
magnitude2 = jnp.linalg.norm(vector2)
# Compute the cosine similarity
similarity = dot_product / (magnitude1 * magnitude2)
return similarity
In [15]:
Copied!
cosine_similarity(jnp.sum(spyx_act, 0).flatten(), jnp.sum(snntorch_act[0], 0).flatten())
cosine_similarity(jnp.sum(spyx_act, 0).flatten(), jnp.sum(snntorch_act[0], 0).flatten())
Out[15]:
Array(0.9865636, dtype=float32)
In [16]:
Copied!
np.save("spyx_activity.npy", save_data)
np.save("spyx_activity.npy", save_data)
Inference Accuracy¶
In [17]:
Copied!
import tonic
import torch
bs = 128
collate = tonic.collation.PadTensors(batch_first=False)
to_frame = tonic.transforms.ToFrame(sensor_size=tonic.datasets.NMNIST.sensor_size, time_window=1e3)
test_ds = tonic.datasets.NMNIST("./nmnist", transform=to_frame, train=False)
test_dl = torch.utils.data.DataLoader(test_ds, shuffle=True, batch_size=bs, collate_fn=collate)
import tonic
import torch
bs = 128
collate = tonic.collation.PadTensors(batch_first=False)
to_frame = tonic.transforms.ToFrame(sensor_size=tonic.datasets.NMNIST.sensor_size, time_window=1e3)
test_ds = tonic.datasets.NMNIST("./nmnist", transform=to_frame, train=False)
test_dl = torch.utils.data.DataLoader(test_ds, shuffle=True, batch_size=bs, collate_fn=collate)
In [18]:
Copied!
SNN, params = spyx.nir.from_nir(ng, input_data, dt=1)
SNN, params = spyx.nir.from_nir(ng, input_data, dt=1)
In [19]:
Copied!
accs = []
for (x, y) in test_dl:
x = jnp.transpose(jnp.array(x), (1,0,2,3,4))
spikes, V = SNN.apply(params, x)
acc, preds = spyx.fn.integral_accuracy(spikes, jnp.array(y))
accs.append(acc)
accs = []
for (x, y) in test_dl:
x = jnp.transpose(jnp.array(x), (1,0,2,3,4))
spikes, V = SNN.apply(params, x)
acc, preds = spyx.fn.integral_accuracy(spikes, jnp.array(y))
accs.append(acc)
In [20]:
Copied!
final_acc = np.mean(np.array(accs))
final_acc
final_acc = np.mean(np.array(accs))
final_acc
Out[20]:
0.9713212
In [21]:
Copied!
np.save("spyx_accuracy.npy", final_acc)
np.save("spyx_accuracy.npy", final_acc)