Surrogate Gradient Template#

import spyx
import spyx.nn as snn

# JAX imports
import os
# cap how much GPU memory JAX preallocates; set before the backend initializes
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".80"
import jax
from jax import numpy as jnp
import jmp
import numpy as np

from jax_tqdm import scan_tqdm
from tqdm import tqdm

# implement our SNN in DeepMind's Haiku
import haiku as hk

# for surrogate loss training.
import optax

# rendering tools
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
%matplotlib notebook

Set Mixed Precision Policy#

policy = jmp.get_policy('half')


hk.mixed_precision.set_policy(hk.Linear, policy)
hk.mixed_precision.set_policy(snn.LIF, policy)
hk.mixed_precision.set_policy(snn.LI, policy)
hk.mixed_precision.set_policy(spyx.nn.ActivityRegularization, policy)
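
The 'half' policy keeps parameters, compute, and outputs in float16. If you would rather keep the parameters in full precision, a policy can be spelled out explicitly; a sketch below, assuming jmp's policy-string syntax (not used in the rest of this template):

# hypothetical alternative: float32 params, float16 compute
alt_policy = jmp.get_policy('params=float32,compute=float16,output=float32')
# hk.mixed_precision.set_policy(hk.Linear, alt_policy)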

Data Loading#

For the template, the Spiking Heidelberg Digits (SHD) dataset is used. Feel free to replace this section with your own dataloading pipeline; just be sure that once you've loaded and processed your whole dataset into RAM as a single NumPy array, you call np.packbits(data, axis=<time axis>) to compress the data before converting it to a jax.numpy array.

shd_dl = spyx.loaders.SHD_loader(256,128,128)
key = jax.random.PRNGKey(0)
x, y = shd_dl.train_epoch(key)
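
If you swap in your own pipeline, a minimal sketch of that compress/decompress round-trip is shown below, with hypothetical array names and a binary spike tensor of shape (samples, time, channels) where time is axis 1:

import numpy as np
from jax import numpy as jnp

raw_spikes = np.random.rand(1000, 128, 128) < 0.1  # hypothetical binary spike raster
packed = np.packbits(raw_spikes, axis=1)           # ~8x smaller along the time axis
packed = jnp.array(packed)                         # move the compressed data to device
restored = jnp.unpackbits(packed, axis=1)          # decompress inside the training step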

SNN#

Here we define a simple feed-forward SNN using Haiku’s RNN features, incorporating our LIF neuron models where activation functions would usually go. Haiku manages all of the state for us, so once we transform the function, the resulting apply() just needs the params (and initial state)!

Since spiking neurons have a discrete, all-or-nothing activation, in order to do gradient descent we have to approximate the derivative of the Heaviside function with something smoother. Spyx provides several surrogate gradients (such as the SuperSpike function of Zenke & Ganguli 2017); here we use the simple triangular surrogate, spyx.axn.triangular(). Also note that we aren’t using bias terms on the linear layers.
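
To make the idea concrete, here is a minimal sketch of a triangular surrogate (an illustration only, not Spyx's internal implementation): the forward pass is the hard Heaviside step, while the backward pass substitutes a triangle-shaped pseudo-derivative.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def heaviside_triangular(v):
    # forward: hard threshold at zero
    return (v > 0).astype(v.dtype)

def _fwd(v):
    return heaviside_triangular(v), v

def _bwd(v, g):
    # backward: max(0, 1 - |v|) stands in for the true (zero/undefined) derivative
    return (g * jnp.maximum(0.0, 1.0 - jnp.abs(v)),)

heaviside_triangular.defvjp(_fwd, _bwd)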

Depending on computational constraints, we can use Haiku’s dynamic_unroll to iterate the SNN, or static_unroll, in which case the SNN is fully unrolled during JIT compilation to further increase speed when training on GPU. The static unroll takes longer to compile, but once running its iterations per second can be 2x-3x higher than the dynamic unroll’s (a static-unroll variant is sketched after the model definition below).

def shd_snn(x):
    
    x = hk.BatchApply(hk.Linear(64, with_bias=False))(x)
    
    core = hk.DeepRNN([
        snn.LIF((64,), activation=spyx.axn.triangular()),
        spyx.nn.ActivityRegularization(),
        hk.Linear(64, with_bias=False),
        snn.LIF((64,), activation=spyx.axn.triangular()),
        spyx.nn.ActivityRegularization(),
        hk.Linear(20, with_bias=False),
        snn.LI((20,))
    ])
    
    # dynamic unroll, partially unrolled in chunks of 32 steps for speed
    spikes, V = hk.dynamic_unroll(core, x, core.initial_state(x.shape[0]), time_major=False, unroll=32)
    
    return spikes, V
key = jax.random.PRNGKey(0)
# Since there's nothing stochastic about the network, we can avoid using an RNG as a param!
SNN = hk.without_apply_rng(hk.transform_with_state(shd_snn))
params, reg_init = SNN.init(rng=key, x=x[0])
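
If you'd rather pay the extra compile time for the fully static unroll mentioned above, the unroll call inside shd_snn can be swapped for Haiku's static_unroll (a sketch; the sequence length must be known at trace time):

    # inside shd_snn, replacing the hk.dynamic_unroll call:
    spikes, V = hk.static_unroll(core, x, core.initial_state(x.shape[0]), time_major=False)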

Gradient Descent#

def gd(SNN, params, dl, epochs=300, schedule=3e-4):
    
    aug = spyx.data.shift_augment(max_shift=16) # need to make this stateless

    # These are not finely tuned, so you'll likely want to adjust them.
    u_reg = spyx.fn.sparsity_reg(16)
    l_reg = spyx.fn.silence_reg(4)
    
    Loss = spyx.fn.integral_crossentropy()
    Acc = spyx.fn.integral_accuracy()
    
    opt = optax.chain(
        optax.centralize(),
        optax.lion(learning_rate=schedule),
    )
    # create and initialize the optimizer
    opt_state = opt.init(params)
    grad_params = params
        
    # define and compile our eval function that computes the loss for our SNN
    @jax.jit
    def net_eval(weights, events, targets):
        readout, spike_counts = SNN.apply(weights, reg_init, events)
        traces, V_f = readout
        # The 0.0 coefficient disables the activity regularization below;
        # raise it to actually apply the sparsity/silence penalties.
        return Loss(traces, targets) + \
            0.0*(u_reg(spike_counts) + l_reg(spike_counts))
        
    # Use JAX to create a function that calculates the loss and the gradient!
    surrogate_grad = jax.value_and_grad(net_eval) 
        
    rng = jax.random.PRNGKey(0)        
    
    # compile the meat of our training loop for speed
    @jax.jit
    def train_step(state, data):
        grad_params, opt_state = state
        events, targets = data # unpack one batch of (events, targets)
        events = jnp.unpackbits(events, axis=1) # decompress temporal axis
        events = aug(events, jax.random.fold_in(rng,jnp.sum(targets)))
        # compute loss and gradient
        loss, grads = surrogate_grad(grad_params, events, targets)
        # generate updates based on the gradients and optimizer
        updates, opt_state = opt.update(grads, opt_state, grad_params)
        # return the updated parameters
        new_state = [optax.apply_updates(grad_params, updates), opt_state]
        return new_state, loss
    
    # For validation epochs, do the same as before but compute the
    # accuracy, predictions and losses (no gradients needed)
    @jax.jit
    def eval_step(grad_params, data):
        events, targets = data # unpack one batch of (events, targets)
        events = jnp.unpackbits(events, axis=1)
        readout, spike_counts = SNN.apply(grad_params, reg_init, events)
        traces, V_f = readout
        acc, pred = Acc(traces, targets)
        loss = Loss(traces, targets)
        return grad_params, jnp.array([acc, loss])
        
    
    val_data = dl.val_epoch()
    
    # Here's the start of our training loop!
    @scan_tqdm(epochs)
    def epoch(epoch_state, epoch_num):
        curr_params, curr_opt_state = epoch_state
        
        shuffle_rng = jax.random.fold_in(rng, epoch_num)
        train_data = dl.train_epoch(shuffle_rng)
        
        # train epoch
        end_state, train_loss = jax.lax.scan(
            train_step,# func
            [curr_params, curr_opt_state],# init
            train_data,# xs
            train_data.obs.shape[0]# len
        )
        
        new_params, _ = end_state
            
        # val epoch
        _, val_metrics = jax.lax.scan(
            eval_step,# func
            new_params,# init
            val_data,# xs
            val_data.obs.shape[0]# len
        )

        
        return end_state, jnp.concatenate([jnp.expand_dims(jnp.mean(train_loss),0), jnp.mean(val_metrics, axis=0)])
    # end epoch
    
    # epoch loop
    final_state, metrics = jax.lax.scan(
        epoch,
        [grad_params, opt_state], # init carry: params and optimizer state
        jnp.arange(epochs), # xs: epoch indices
        epochs # len of loop
    )
    
    final_params, _ = final_state
    
                
    # return our final, optimized network.       
    return final_params, metrics
def test_gd(SNN, params, dl):

    Acc = spyx.fn.integral_accuracy()
    Loss = spyx.fn.integral_crossentropy()
    
    @jax.jit
    def test_step(params, data):
        events, targets = data
        events = jnp.unpackbits(events, axis=1)
        readout, spike_counts = SNN.apply(params, reg_init, events)
        traces, V_f = readout
        acc, pred = Acc(traces, targets)
        loss = Loss(traces, targets)
        return params, [acc, loss, pred, targets, spike_counts["ActReg"]["spike_count"]]
    
    test_data = dl.test_epoch()
    
    _, test_metrics = jax.lax.scan(
            test_step,# func
            params,# init
            test_data,# xs
            test_data.obs.shape[0]# len
    )
    
    acc = jnp.mean(test_metrics[0])
    loss = jnp.mean(test_metrics[1])
    preds = jnp.array(test_metrics[2]).flatten()
    tgts = jnp.array(test_metrics[3]).flatten()
    spike_counts = jnp.array(test_metrics[4])
    return acc, loss, preds, tgts, spike_counts

Training Time#

grad_params, metrics = gd(SNN, params, shd_dl, epochs=300)
metrics
Array([[2.71798248e+02, 4.49218750e-02, 1.14273682e+01],
       [2.20212341e+02, 4.62239608e-02, 1.33500843e+01],
       [2.02267380e+02, 6.96614608e-02, 1.11168118e+01],
       [1.92901810e+02, 7.22656250e-02, 9.26954651e+00],
       [1.80120972e+02, 4.94791679e-02, 8.30320930e+00],
       [1.74991989e+02, 6.11979179e-02, 7.78008366e+00],
       [1.69293564e+02, 9.04947966e-02, 7.75548172e+00],
       [1.63788467e+02, 8.39843750e-02, 6.87410355e+00],
       [1.60246429e+02, 8.46354216e-02, 6.24794769e+00],
       [1.57619415e+02, 8.72395858e-02, 5.47549200e+00],
       [1.56202148e+02, 9.89583358e-02, 4.73115158e+00],
       [1.54992325e+02, 1.11328125e-01, 4.41693687e+00],
       [1.54426849e+02, 1.09375000e-01, 4.20145082e+00],
       [1.53812912e+02, 1.02213547e-01, 3.86983371e+00],
       [1.52814728e+02, 1.04166672e-01, 3.85918427e+00],
       [1.52189484e+02, 1.32161468e-01, 3.59870911e+00],
       [1.50948837e+02, 1.28255218e-01, 3.48907948e+00],
       [1.50006348e+02, 1.68619797e-01, 3.26322246e+00],
       [1.49447540e+02, 1.96614593e-01, 3.23117495e+00],
       [1.49058868e+02, 1.71875000e-01, 3.14428496e+00],
       [1.48448517e+02, 2.06380218e-01, 3.00565982e+00],
       [1.48372177e+02, 2.13541672e-01, 2.95006347e+00],
       [1.47624023e+02, 2.14843750e-01, 2.95505571e+00],
       [1.47637131e+02, 2.23958343e-01, 2.92709398e+00],
       [1.47339783e+02, 1.95312500e-01, 2.97801256e+00],
       [1.47075775e+02, 2.42838547e-01, 2.95496655e+00],
       [1.46741165e+02, 2.44140625e-01, 2.84769964e+00],
       [1.46123138e+02, 2.48697922e-01, 2.79583597e+00],
       [1.45478790e+02, 2.33072922e-01, 2.81013751e+00],
       [1.44328125e+02, 2.02473968e-01, 2.90054297e+00],
       [1.42939743e+02, 2.34375000e-01, 2.85900855e+00],
       [1.41477844e+02, 1.68619797e-01, 3.05182743e+00],
       [1.39553696e+02, 1.64062500e-01, 3.12778234e+00],
       [1.38253220e+02, 1.60807297e-01, 3.20245147e+00],
       [1.35473175e+02, 1.47786468e-01, 3.26104069e+00],
       [1.31843155e+02, 1.37369797e-01, 3.24062681e+00],
       [1.28423233e+02, 1.73828125e-01, 3.22785091e+00],
       [1.27477646e+02, 2.11588547e-01, 3.12620878e+00],
       [1.26435143e+02, 2.15494797e-01, 3.05043006e+00],
       [1.25879288e+02, 2.20703125e-01, 2.91598296e+00],
       [1.25127647e+02, 2.40885422e-01, 2.88420391e+00],
       [1.25216736e+02, 2.66276062e-01, 2.85582519e+00],
       [1.24864464e+02, 2.72786468e-01, 2.79917669e+00],
       [1.24649292e+02, 3.13151062e-01, 2.69871807e+00],
       [1.24680092e+02, 3.41796875e-01, 2.65188408e+00],
       [1.24629189e+02, 3.60677093e-01, 2.58125257e+00],
       [1.24538162e+02, 3.93880218e-01, 2.53070307e+00],
       [1.24378571e+02, 3.60677093e-01, 2.57072735e+00],
       [1.24354851e+02, 3.81510437e-01, 2.53855968e+00],
       [1.24222496e+02, 4.25781250e-01, 2.45774603e+00],
       [1.24225433e+02, 4.29036468e-01, 2.44899178e+00],
       [1.24186813e+02, 4.59635437e-01, 2.42412996e+00],
       [1.24235435e+02, 4.59635437e-01, 2.42456198e+00],
       [1.24215134e+02, 4.45312500e-01, 2.45166922e+00],
       [1.24114159e+02, 4.73307312e-01, 2.42908573e+00],
       [1.24074265e+02, 4.76562500e-01, 2.39220977e+00],
       [1.24092041e+02, 5.11718750e-01, 2.34038496e+00],
       [1.24091255e+02, 4.92838562e-01, 2.37832880e+00],
       [1.24122963e+02, 4.90234375e-01, 2.38071156e+00],
       [1.24063141e+02, 4.94791687e-01, 2.36306548e+00],
       [1.24041138e+02, 5.28645873e-01, 2.34776926e+00],
       [1.23983307e+02, 5.41015625e-01, 2.33414817e+00],
       [1.23957588e+02, 5.33854187e-01, 2.31031704e+00],
       [1.24011688e+02, 5.50130248e-01, 2.29904985e+00],
       [1.23991188e+02, 5.59244812e-01, 2.29068089e+00],
       [1.24059326e+02, 5.22786498e-01, 2.31853867e+00],
       [1.23948906e+02, 5.53385437e-01, 2.29044199e+00],
       [1.23989769e+02, 5.50130248e-01, 2.30142760e+00],
       [1.23925652e+02, 5.83333373e-01, 2.25427532e+00],
       [1.23920990e+02, 5.99609375e-01, 2.23671007e+00],
       [1.23933594e+02, 5.85286498e-01, 2.25170422e+00],
       [1.23966148e+02, 5.69010437e-01, 2.24896717e+00],
       [1.23950241e+02, 6.10677123e-01, 2.24032497e+00],
       [1.23897827e+02, 6.12630248e-01, 2.21782589e+00],
       [1.23946465e+02, 6.01562500e-01, 2.22873473e+00],
       [1.23899910e+02, 6.17187500e-01, 2.22768736e+00],
       [1.23873970e+02, 6.18489623e-01, 2.20884848e+00],
       [1.23914268e+02, 6.14583373e-01, 2.21338940e+00],
       [1.23895958e+02, 6.23046875e-01, 2.20168328e+00],
       [1.23858559e+02, 6.34114623e-01, 2.19854903e+00],
       [1.23847282e+02, 6.34765625e-01, 2.18568039e+00],
       [1.23870270e+02, 6.34765625e-01, 2.18117356e+00],
       [1.23830162e+02, 6.64713562e-01, 2.14763165e+00],
       [1.23817635e+02, 6.49739623e-01, 2.17211771e+00],
       [1.23801796e+02, 6.30208373e-01, 2.19242740e+00],
       [1.23803894e+02, 6.60156250e-01, 2.16907883e+00],
       [1.23763275e+02, 6.07421875e-01, 2.21988869e+00],
       [1.23749290e+02, 6.86197937e-01, 2.11556411e+00],
       [1.23829643e+02, 6.83593750e-01, 2.11980724e+00],
       [1.23815758e+02, 6.77083373e-01, 2.11524105e+00],
       [1.23783379e+02, 6.95963562e-01, 2.09629750e+00],
       [1.23785347e+02, 6.77734375e-01, 2.11123347e+00],
       [1.23768707e+02, 6.81640625e-01, 2.10875607e+00],
       [1.23785873e+02, 7.00520873e-01, 2.08831835e+00],
       [1.23774834e+02, 6.80989623e-01, 2.09724307e+00],
       [1.23764595e+02, 6.94010437e-01, 2.10606813e+00],
       [1.23745361e+02, 6.71875000e-01, 2.11411905e+00],
       [1.23752487e+02, 6.80338562e-01, 2.11444736e+00],
       [1.23731621e+02, 6.88802123e-01, 2.10682750e+00],
       [1.23731392e+02, 7.19401062e-01, 2.06307602e+00],
       [1.23724136e+02, 7.07031250e-01, 2.08056498e+00],
       [1.23728119e+02, 7.09635437e-01, 2.07516336e+00],
       [1.23692924e+02, 6.95963562e-01, 2.07251740e+00],
       [1.23717438e+02, 7.03125000e-01, 2.07651114e+00],
       [1.23697273e+02, 7.08984375e-01, 2.07362652e+00],
       [1.23676079e+02, 7.27864623e-01, 2.05277109e+00],
       [1.23709732e+02, 7.10937500e-01, 2.08720255e+00],
       [1.23761208e+02, 7.05729187e-01, 2.07313395e+00],
       [1.23716530e+02, 7.29166687e-01, 2.05300856e+00],
       [1.23670204e+02, 7.05729187e-01, 2.05412483e+00],
       [1.23706718e+02, 6.94010437e-01, 2.06830692e+00],
       [1.23686333e+02, 6.99869812e-01, 2.07016373e+00],
       [1.23701950e+02, 7.31119812e-01, 2.05466723e+00],
       [1.23683311e+02, 7.17447937e-01, 2.05699587e+00],
       [1.23676903e+02, 7.25911498e-01, 2.04536486e+00],
       [1.23663910e+02, 7.14192748e-01, 2.04410625e+00],
       [1.23689438e+02, 7.07031250e-01, 2.05310345e+00],
       [1.23664276e+02, 7.20703125e-01, 2.04657936e+00],
       [1.23673935e+02, 7.18098998e-01, 2.05201149e+00],
       [1.23674927e+02, 7.13541687e-01, 2.04881001e+00],
       [1.23672340e+02, 7.23958373e-01, 2.03193521e+00],
       [1.23659950e+02, 7.23958373e-01, 2.03003979e+00],
       [1.23652061e+02, 7.22656250e-01, 2.02337337e+00],
       [1.23676163e+02, 7.35677123e-01, 2.02926993e+00],
       [1.23640121e+02, 7.31770873e-01, 2.03260517e+00],
       [1.23626472e+02, 7.51302123e-01, 2.00557899e+00],
       [1.23667603e+02, 7.23307312e-01, 2.03202677e+00],
       [1.23641380e+02, 7.34375000e-01, 2.02302647e+00],
       [1.23640015e+02, 7.38281250e-01, 2.02361012e+00],
       [1.23616280e+02, 7.38281250e-01, 2.01076293e+00],
       [1.23659790e+02, 7.57161498e-01, 2.00694394e+00],
       [1.23630302e+02, 7.46093750e-01, 2.00720596e+00],
       [1.23628746e+02, 7.64322937e-01, 1.99329829e+00],
       [1.23638191e+02, 7.45442748e-01, 2.01799989e+00],
       [1.23645859e+02, 7.56510437e-01, 2.01238847e+00],
       [1.23643906e+02, 7.59765625e-01, 1.99834967e+00],
       [1.23620377e+02, 7.56510437e-01, 1.99412227e+00],
       [1.23654488e+02, 7.58463562e-01, 1.99370229e+00],
       [1.23639320e+02, 7.61718750e-01, 1.98999596e+00],
       [1.23640472e+02, 7.45442748e-01, 2.00034142e+00],
       [1.23599258e+02, 7.77994812e-01, 1.97219646e+00],
       [1.23645775e+02, 7.52604187e-01, 1.99514055e+00],
       [1.23623940e+02, 7.44140625e-01, 2.00179005e+00],
       [1.23615791e+02, 7.61718750e-01, 1.99153638e+00],
       [1.23627617e+02, 7.54557312e-01, 1.99635196e+00],
       [1.23611130e+02, 7.69531250e-01, 1.98267746e+00],
       [1.23631996e+02, 7.66276062e-01, 1.98800313e+00],
       [1.23612778e+02, 7.42187500e-01, 2.01868057e+00],
       [1.23624168e+02, 7.67578125e-01, 1.98686922e+00],
       [1.23611618e+02, 7.60416687e-01, 1.99129915e+00],
       [1.23590134e+02, 7.60416687e-01, 2.00695896e+00],
       [1.23585457e+02, 7.72135437e-01, 1.98131561e+00],
       [1.23610085e+02, 7.66276062e-01, 1.99607229e+00],
       [1.23570663e+02, 7.75390625e-01, 1.97641218e+00],
       [1.23605995e+02, 7.73437500e-01, 1.97448337e+00],
       [1.23593613e+02, 7.81250000e-01, 1.96058249e+00],
       [1.23612793e+02, 7.78645873e-01, 1.97256839e+00],
       [1.23580193e+02, 7.71484375e-01, 1.97692919e+00],
       [1.23594414e+02, 7.57161498e-01, 1.99404895e+00],
       [1.23607254e+02, 7.59114623e-01, 2.00517297e+00],
       [1.23569908e+02, 7.46093750e-01, 2.00172210e+00],
       [1.23584236e+02, 7.60416687e-01, 1.99198234e+00],
       [1.23574623e+02, 7.65625000e-01, 1.98879087e+00],
       [1.23574150e+02, 7.69531250e-01, 1.98723018e+00],
       [1.23574127e+02, 7.91015625e-01, 1.96969450e+00],
       [1.23592339e+02, 7.51302123e-01, 1.99060202e+00],
       [1.23563286e+02, 7.65625000e-01, 1.98240399e+00],
       [1.23569908e+02, 7.78645873e-01, 1.96860492e+00],
       [1.23560623e+02, 7.76692748e-01, 1.96450984e+00],
       [1.23582878e+02, 7.80598998e-01, 1.96602666e+00],
       [1.23560394e+02, 7.60416687e-01, 1.97557580e+00],
       [1.23616417e+02, 7.70833373e-01, 1.96892631e+00],
       [1.23577431e+02, 7.94921875e-01, 1.96068096e+00],
       [1.23551025e+02, 8.00781250e-01, 1.94389427e+00],
       [1.23562744e+02, 7.89062500e-01, 1.94568443e+00],
       [1.23591110e+02, 7.86458373e-01, 1.95709360e+00],
       [1.23574760e+02, 7.79947937e-01, 1.95875156e+00],
       [1.23560867e+02, 7.92968750e-01, 1.94165552e+00],
       [1.23574715e+02, 7.77343750e-01, 1.96927738e+00],
       [1.23566124e+02, 7.82552123e-01, 1.96292245e+00],
       [1.23582855e+02, 7.63671875e-01, 1.96426141e+00],
       [1.23549301e+02, 7.74739623e-01, 1.96867442e+00],
       [1.23581749e+02, 7.75390625e-01, 1.95982933e+00],
       [1.23573936e+02, 7.76692748e-01, 1.95557141e+00],
       [1.23565819e+02, 7.96223998e-01, 1.95374489e+00],
       [1.23564476e+02, 7.83203125e-01, 1.96845984e+00],
       [1.23579193e+02, 7.42838562e-01, 2.00867891e+00],
       [1.23565262e+02, 7.71484375e-01, 1.96710813e+00],
       [1.23571541e+02, 7.72786498e-01, 1.96393037e+00],
       [1.23557510e+02, 7.70182312e-01, 1.97563004e+00],
       [1.23562782e+02, 7.79947937e-01, 1.95946240e+00],
       [1.23562469e+02, 7.93619812e-01, 1.94424629e+00],
       [1.23574326e+02, 7.97526062e-01, 1.94693708e+00],
       [1.23583549e+02, 7.92317748e-01, 1.95275009e+00],
       [1.23555962e+02, 7.93619812e-01, 1.94845629e+00],
       [1.23546127e+02, 8.07291687e-01, 1.93196583e+00],
       [1.23557068e+02, 7.92968750e-01, 1.94701362e+00],
       [1.23556580e+02, 7.94921875e-01, 1.93329358e+00],
       [1.23560249e+02, 7.91666687e-01, 1.94251382e+00],
       [1.23553169e+02, 7.90364623e-01, 1.94781542e+00],
       [1.23567879e+02, 7.87109375e-01, 1.94792438e+00],
       [1.23557564e+02, 7.57812500e-01, 1.95999515e+00],
       [1.23536133e+02, 7.89713562e-01, 1.93853176e+00],
       [1.23553551e+02, 7.90364623e-01, 1.94165444e+00],
       [1.23516998e+02, 7.94921875e-01, 1.93287253e+00],
       [1.23522369e+02, 8.04036498e-01, 1.92648625e+00],
       [1.23540405e+02, 7.98177123e-01, 1.92133045e+00],
       [1.23547630e+02, 8.05338562e-01, 1.92567873e+00],
       [1.23527977e+02, 7.94270873e-01, 1.92475069e+00],
       [1.23554848e+02, 7.94921875e-01, 1.93271351e+00],
       [1.23544243e+02, 7.91015625e-01, 1.94211566e+00],
       [1.23543823e+02, 7.83854187e-01, 1.94071901e+00],
       [1.23538261e+02, 7.91666687e-01, 1.93251109e+00],
       [1.23534843e+02, 8.02734375e-01, 1.92488265e+00],
       [1.23505066e+02, 8.09895873e-01, 1.91250515e+00],
       [1.23562897e+02, 7.95572937e-01, 1.93218327e+00],
       [1.23525879e+02, 8.01432312e-01, 1.92849874e+00],
       [1.23532730e+02, 7.94921875e-01, 1.93269658e+00],
       [1.23543861e+02, 7.88411498e-01, 1.92990863e+00],
       [1.23513824e+02, 7.94270873e-01, 1.92696714e+00],
       [1.23524223e+02, 8.08593750e-01, 1.91775584e+00],
       [1.23523239e+02, 7.87109375e-01, 1.93594742e+00],
       [1.23519623e+02, 7.95572937e-01, 1.92654359e+00],
       [1.23534531e+02, 8.05989623e-01, 1.92314219e+00],
       [1.23529900e+02, 7.84505248e-01, 1.93622100e+00],
       [1.23509857e+02, 8.00781250e-01, 1.91628778e+00],
       [1.23557068e+02, 8.00130248e-01, 1.92465258e+00],
       [1.23509872e+02, 7.92317748e-01, 1.93283939e+00],
       [1.23520966e+02, 8.01432312e-01, 1.93533921e+00],
       [1.23509315e+02, 7.98177123e-01, 1.93562889e+00],
       [1.23497261e+02, 8.07291687e-01, 1.92452848e+00],
       [1.23537422e+02, 8.00781250e-01, 1.92694986e+00],
       [1.23545448e+02, 7.98828125e-01, 1.92965305e+00],
       [1.23507935e+02, 8.05338562e-01, 1.91244435e+00],
       [1.23531738e+02, 8.03385437e-01, 1.91770172e+00],
       [1.23550835e+02, 8.17057312e-01, 1.91326284e+00],
       [1.23532578e+02, 8.02734375e-01, 1.92777610e+00],
       [1.23524742e+02, 8.14453125e-01, 1.91653776e+00],
       [1.23522049e+02, 8.03385437e-01, 1.91214359e+00],
       [1.23524597e+02, 8.04036498e-01, 1.91845167e+00],
       [1.23530235e+02, 7.96223998e-01, 1.92588079e+00],
       [1.23512566e+02, 8.00130248e-01, 1.92293203e+00],
       [1.23526146e+02, 7.80598998e-01, 1.93098724e+00],
       [1.23513451e+02, 7.99479187e-01, 1.91475379e+00],
       [1.23551918e+02, 7.73437500e-01, 1.93819404e+00],
       [1.23527451e+02, 7.96875000e-01, 1.93275785e+00],
       [1.23526947e+02, 7.85807312e-01, 1.92240179e+00],
       [1.23532928e+02, 8.02083373e-01, 1.92968023e+00],
       [1.23522095e+02, 7.92317748e-01, 1.93037069e+00],
       [1.23499321e+02, 7.87109375e-01, 1.92218208e+00],
       [1.23509521e+02, 8.07942748e-01, 1.91500413e+00],
       [1.23527451e+02, 8.00781250e-01, 1.93181646e+00],
       [1.23508789e+02, 8.11197937e-01, 1.91907537e+00],
       [1.23502846e+02, 7.96875000e-01, 1.91801929e+00],
       [1.23509331e+02, 7.94270873e-01, 1.92319942e+00],
       [1.23511314e+02, 7.95572937e-01, 1.92603874e+00],
       [1.23510925e+02, 7.97526062e-01, 1.92372823e+00],
       [1.23510902e+02, 7.79947937e-01, 1.93510842e+00],
       [1.23497681e+02, 7.91015625e-01, 1.92741656e+00],
       [1.23518364e+02, 7.95572937e-01, 1.92090702e+00],
       [1.23494156e+02, 8.21614623e-01, 1.90011060e+00],
       [1.23495613e+02, 8.11197937e-01, 1.90124214e+00],
       [1.23529449e+02, 7.93619812e-01, 1.91420245e+00],
       [1.23506706e+02, 8.02083373e-01, 1.90475452e+00],
       [1.23530006e+02, 8.08593750e-01, 1.91455972e+00],
       [1.23498451e+02, 7.94270873e-01, 1.91952729e+00],
       [1.23483887e+02, 7.98828125e-01, 1.91418099e+00],
       [1.23491318e+02, 8.09895873e-01, 1.90646148e+00],
       [1.23494057e+02, 8.15104187e-01, 1.89818907e+00],
       [1.23497147e+02, 8.11848998e-01, 1.89916682e+00],
       [1.23505112e+02, 8.04036498e-01, 1.91482198e+00],
       [1.23497528e+02, 7.92968750e-01, 1.91294217e+00],
       [1.23497986e+02, 8.04687500e-01, 1.91128814e+00],
       [1.23506218e+02, 8.13802123e-01, 1.90483415e+00],
       [1.23502823e+02, 7.98828125e-01, 1.90963447e+00],
       [1.23501625e+02, 7.98828125e-01, 1.91299748e+00],
       [1.23512520e+02, 7.89062500e-01, 1.91810930e+00],
       [1.23525429e+02, 7.92317748e-01, 1.92684793e+00],
       [1.23480713e+02, 8.10546875e-01, 1.91402507e+00],
       [1.23515938e+02, 7.94270873e-01, 1.92163897e+00],
       [1.23489967e+02, 7.99479187e-01, 1.91509414e+00],
       [1.23500755e+02, 8.01432312e-01, 1.92208421e+00],
       [1.23494637e+02, 8.00130248e-01, 1.91290402e+00],
       [1.23503868e+02, 8.03385437e-01, 1.91764438e+00],
       [1.23491699e+02, 7.98177123e-01, 1.90725970e+00],
       [1.23490700e+02, 8.02083373e-01, 1.90650272e+00],
       [1.23484550e+02, 8.04687500e-01, 1.90314865e+00],
       [1.23486618e+02, 8.13151062e-01, 1.90095365e+00],
       [1.23482353e+02, 8.07291687e-01, 1.90434408e+00],
       [1.23485886e+02, 8.19661498e-01, 1.90076232e+00],
       [1.23488373e+02, 8.18359375e-01, 1.90249014e+00],
       [1.23488304e+02, 8.07291687e-01, 1.91295171e+00],
       [1.23477730e+02, 8.00130248e-01, 1.91174912e+00],
       [1.23477646e+02, 8.14453125e-01, 1.90647054e+00],
       [1.23478523e+02, 8.20963562e-01, 1.89476621e+00],
       [1.23489990e+02, 8.20312500e-01, 1.89162636e+00],
       [1.23489052e+02, 8.24869812e-01, 1.89113128e+00],
       [1.23471169e+02, 8.22916687e-01, 1.89311194e+00],
       [1.23499512e+02, 8.09244812e-01, 1.91797137e+00],
       [1.23481773e+02, 8.03385437e-01, 1.91132450e+00]], dtype=float32)
print("Performance: train_loss={}, val_acc={}, val_loss={}".format(*metrics[-1]))
Performance: train_loss=123.48177337646484, val_acc=0.8033854365348816, val_loss=1.9113245010375977
plt.plot(metrics, label=["train loss", "val acc", "val loss"])
plt.title("SHD Surrogate Grad. Dual Regularization")
plt.legend()
plt.show()

Evaluation Time#

Now we’ll run the network on the test set and see what happens:

acc, loss, preds, tgts, spike_counts = test_gd(SNN, grad_params, shd_dl)
print("Accuracy:", acc, "Loss:", loss)
Accuracy: 0.77490234 Loss: 1.9661974
spike_counts.shape
(8, 256, 64)
jnp.mean(spike_counts[0], axis=0)
Array([ 11.945,   3.375,   0.   ,  19.97 ,   0.   ,   6.33 ,   0.   ,
         6.645,  14.14 ,   6.125,  28.25 ,  10.14 ,  21.78 ,   4.906,
         6.89 ,   6.48 ,   4.27 ,   5.62 ,   6.12 ,   8.91 ,  13.96 ,
         6.996,   0.   ,   7.42 ,   7.117,   0.   , 126.4  ,   7.902,
         0.   ,   6.125,  15.95 ,   7.297,  12.766,   0.   ,   0.   ,
         0.   ,  19.47 ,  25.27 ,  17.8  ,   4.633,   6.6  ,   9.33 ,
        12.625,   5.785,   4.96 ,  14.734,   4.93 ,  16.34 ,   5.5  ,
        34.53 ,   6.45 ,  11.234,   9.95 ,   8.67 ,   0.   ,   0.   ,
         0.   ,   8.25 ,   7.47 ,   0.   ,   0.   ,   0.   ,  16.4  ,
        15.42 ], dtype=float16)
cm = confusion_matrix(tgts, preds)
ConfusionMatrixDisplay(cm).plot()
plt.show()