spyx.fn

Module Contents

Functions

silence_reg(min_spikes)

Per-neuron L2 activation regularization that penalizes neurons for spiking fewer than a target number of times.

sparsity_reg(max_spikes[, norm])

Layer-wide activation regularization that discourages a high average firing rate across the layer's neurons.

integral_accuracy([time_axis])

Calculate the accuracy of a network's predictions from the time integral of its output voltage traces. Intended for use with a Leaky-Integrate neuron model as the final layer.

integral_crossentropy([smoothing, time_axis])

Calculate the cross-entropy between the time integral of the readout layer's membrane potentials and the target labels. Supports label smoothing to discourage the non-target neurons in the readout layer from being silenced.

mse_spikerate([sparsity, smoothing, time_axis])

Calculate the mean squared error between the mean spike rates of the readout neurons and their target rates. Supports label smoothing to discourage the non-target neurons in the readout layer from being silenced.

spyx.fn.silence_reg(min_spikes)

Per-neuron L2 activation regularization that penalizes neurons for spiking fewer than a target number of times.

Parameters:

min_spikes – neurons whose average spike count over the batch falls below this value incur a quadratic penalty.

Returns:

JIT-compiled regularization function.
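The penalty described above can be sketched in JAX as follows. This is a minimal illustration under stated assumptions, not the spyx implementation: it assumes a single array of per-sample spike counts with shape (batch, neurons), whereas spyx.fn.silence_reg may accept a different structure, and the name silence_penalty is hypothetical.

```python
import jax
import jax.numpy as jnp


def silence_penalty(min_spikes):
    """Hypothetical sketch (not the spyx implementation) of a silence penalty.

    Assumes `spike_counts` has shape (batch, neurons) and holds each
    neuron's spike count for each sample.
    """

    def _penalty(spike_counts):
        # Average spike count of each neuron over the batch.
        mean_counts = jnp.mean(spike_counts, axis=0)
        # Shortfall below the target; neurons at or above it contribute zero.
        shortfall = jnp.maximum(0.0, min_spikes - mean_counts)
        # Quadratic (squared L2) penalty summed over neurons.
        return jnp.sum(shortfall ** 2)

    return jax.jit(_penalty)


# Usage: neuron 0 averages 1 spike against a target of 2; the others meet it.
reg = silence_penalty(min_spikes=2.0)
counts = jnp.array([[0.0, 3.0, 5.0],
                    [2.0, 1.0, 5.0]])
penalty = reg(counts)
```

Only neurons whose batch-average count falls short of min_spikes contribute, and squaring the shortfall is what makes the penalty quadratic.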

spyx.fn.sparsity_reg(max_spikes, norm=optax.huber_loss)

Layer-wide activation regularization that discourages a high average firing rate across the layer's neurons.

Parameters:
  • max_spikes – threshold on the average number of spikes in the layer; a penalty is incurred when the average exceeds it.

  • norm – an Optax loss function used to compute the penalty. Defaults to optax.huber_loss.

Returns:

JIT-compiled regularization function.

spyx.fn.integral_accuracy(time_axis=1)

Calculate the accuracy of a network's predictions from the time integral of its output voltage traces. Intended for use with a Leaky-Integrate neuron model as the final layer.

Parameters:
  • time_axis – temporal axis of the data.

Returns:

function that takes the following arguments and returns the accuracy score and the predictions:
  • traces – the output of the final layer of the SNN

  • targets – the integer class label of each sample
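A sketch of how an integral-based accuracy can be computed, assuming traces has shape (batch, time, classes) so that the default time_axis=1 is the time dimension, and that the predicted class is the readout neuron with the largest time-summed potential. The name integral_accuracy_sketch is hypothetical and need not match spyx internally.

```python
import jax
import jax.numpy as jnp


def integral_accuracy_sketch(time_axis=1):
    """Hypothetical sketch: accuracy from time-integrated readout traces.

    Assumes `traces` has shape (batch, time, classes) and `targets` holds
    one integer class label per sample.
    """

    def _accuracy(traces, targets):
        # Integrate each readout neuron's membrane potential over time.
        integrals = jnp.sum(traces, axis=time_axis)
        # Predict the class whose neuron accumulated the largest integral.
        preds = jnp.argmax(integrals, axis=-1)
        # Fraction of samples predicted correctly, plus the predictions.
        return jnp.mean(preds == targets), preds

    return jax.jit(_accuracy)


acc_fn = integral_accuracy_sketch()
traces = jnp.array([
    [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]],  # integrals [3, 0] -> class 0
    [[0.0, 1.0], [0.0, 2.0], [1.0, 0.0]],  # integrals [1, 3] -> class 1
])
targets = jnp.array([0, 0])
acc, preds = acc_fn(traces, targets)
```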

spyx.fn.integral_crossentropy(smoothing=0.3, time_axis=1)

Calculate the cross-entropy between the time integral of the readout layer's membrane potentials and the target labels. Supports label smoothing to discourage the non-target neurons in the readout layer from being silenced.

Parameters:
  • smoothing – label-smoothing factor.

  • time_axis – temporal axis of the data.

Returns:

cross-entropy loss function that takes SNN output traces and integer class labels.

spyx.fn.mse_spikerate(sparsity=0.25, smoothing=0.0, time_axis=1)

Calculate the mean squared error between the mean spike rates of the readout neurons and their target rates. Supports label smoothing to discourage the non-target neurons in the readout layer from being silenced.

Parameters:
  • sparsity – the fraction of time steps on which the neurons should spike (e.g., 0.25 means 25% of the time).

  • smoothing – [optional] label-smoothing factor.

  • time_axis – temporal axis of the data.

Returns:

mean squared error loss function on the spike rate that takes SNN output traces and integer class labels.
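A sketch of a spike-rate MSE, assuming spikes has shape (batch, time, classes) with binary entries, that the target rate is sparsity times the (optionally smoothed) one-hot label, and that smoothing follows the same formula as optax.smooth_labels. The way spyx combines sparsity and smoothing may differ, and mse_spikerate_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def mse_spikerate_sketch(sparsity=0.25, smoothing=0.0, time_axis=1):
    """Hypothetical sketch (not the spyx implementation) of a spike-rate MSE."""

    def _loss(spikes, targets):
        # Fraction of time steps on which each readout neuron spiked.
        rates = jnp.mean(spikes, axis=time_axis)
        num_classes = rates.shape[-1]
        labels = jax.nn.one_hot(targets, num_classes)
        # Label smoothing, same formula as optax.smooth_labels.
        labels = (1.0 - smoothing) * labels + smoothing / num_classes
        # Assumed target rate: `sparsity` scaled by the (smoothed) label.
        target_rates = sparsity * labels
        return jnp.mean((rates - target_rates) ** 2)

    return jax.jit(_loss)


loss_fn = mse_spikerate_sketch(sparsity=0.25)
# Neuron 0 spikes on 1 of 4 steps (rate 0.25); neuron 1 never spikes.
spikes = jnp.array([[[1.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]])
matched_loss = loss_fn(spikes, jnp.array([0]))     # rates match the targets
mismatched_loss = loss_fn(spikes, jnp.array([1]))  # wrong neuron is active
```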