API Reference
This page provides documentation for the core Spyx modules.
Spyx Global Functions
integral_accuracy(time_axis=1)
Calculate the accuracy of a network's predictions based on the voltage traces. Intended for use with a Leaky-Integrate (LI) neuron model as the final layer.
:param time_axis: temporal axis of data. :return: function that takes SNN output traces (the output of the final layer) and integer class labels, and returns the accuracy score and the predictions.
Source code in spyx/fn.py
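As a rough illustration of what the returned function computes, the NumPy sketch below sums the readout traces over the time axis, takes the arg-max as the prediction, and compares it with the integer labels. The name integral_accuracy_sketch is hypothetical, and the exact reduction spyx uses is an assumption.

```python
import numpy as np

def integral_accuracy_sketch(time_axis=1):
    # Hypothetical re-implementation for illustration only.
    def _accuracy(traces, targets):
        scores = np.sum(traces, axis=time_axis)   # integrate potentials over time
        preds = np.argmax(scores, axis=-1)         # most-integrated class wins
        return np.mean(preds == targets), preds
    return _accuracy

# traces: [batch=2, time=3, classes=2]
traces = np.array([
    [[0.1, 0.5], [0.2, 0.4], [0.0, 0.3]],  # class 1 integrates to the larger value
    [[0.9, 0.1], [0.8, 0.0], [0.7, 0.2]],  # class 0 integrates to the larger value
])
acc, preds = integral_accuracy_sketch()(traces, np.array([1, 1]))
print(acc, preds)  # 0.5 [1 0]
```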
integral_crossentropy(smoothing=0.3, time_axis=1)
Calculate the cross-entropy loss using the integral (sum over time) of the readout membrane potentials as class scores. Allows for label smoothing to discourage silencing the other neurons in the readout layer.
:param smoothing: rate at which to smooth labels. :param time_axis: temporal axis of data :return: crossentropy loss function that takes SNN output traces and integer index labels.
Source code in spyx/fn.py
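A minimal NumPy sketch of this loss, assuming the time-integrated potentials are treated as logits and that label smoothing follows the optax.smooth_labels convention, (1 - smoothing) * one_hot + smoothing / num_classes. integral_crossentropy_sketch is a hypothetical name, not part of spyx.

```python
import numpy as np

def integral_crossentropy_sketch(smoothing=0.3, time_axis=1):
    # Hypothetical re-implementation for illustration only.
    def _loss(traces, targets):
        logits = np.sum(traces, axis=time_axis)                # [batch, classes]
        num_classes = logits.shape[-1]
        one_hot = np.eye(num_classes)[targets]
        labels = (1.0 - smoothing) * one_hot + smoothing / num_classes
        shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
        log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
        return np.mean(-np.sum(labels * log_probs, axis=-1))
    return _loss

loss_fn = integral_crossentropy_sketch()
uniform = loss_fn(np.zeros((4, 5, 3)), np.array([0, 1, 2, 0]))
print(uniform)  # approximately log(3)
```

With all-zero traces every class receives the same score, so the loss equals log(num_classes) regardless of smoothing, because the smoothed labels still sum to 1.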
mse_spikerate(sparsity=0.25, smoothing=0.0, time_axis=1)
Calculate the mean squared error of the mean spike rate. Allows for label smoothing to discourage silencing the other neurons in the readout layer.
:param sparsity: the fraction of time steps in which you want the neurons to spike (e.g. 0.25). :param smoothing: [optional] rate at which to smooth labels. :param time_axis: temporal axis of data. :return: Mean-Squared-Error loss function on the spike rate that takes SNN output traces and integer index labels.
Source code in spyx/fn.py
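The sketch below shows one plausible reading of this loss: the mean spike rate of each readout neuron is regressed toward a target of sparsity for the labelled class and 0 for the others (before smoothing). The target construction is an assumption, not taken from the spyx source, and mse_spikerate_sketch is hypothetical.

```python
import numpy as np

def mse_spikerate_sketch(sparsity=0.25, smoothing=0.0, time_axis=1):
    # Hypothetical re-implementation; the target construction is an assumption.
    def _loss(traces, targets):
        rates = np.mean(traces, axis=time_axis)          # firing rate per readout neuron
        num_classes = rates.shape[-1]
        one_hot = np.eye(num_classes)[targets]
        labels = (1.0 - smoothing) * one_hot + smoothing / num_classes
        target_rates = sparsity * labels                  # labelled class should fire `sparsity` of the time
        return np.mean((rates - target_rates) ** 2)
    return _loss

# One sample, 4 time steps, 2 readout neurons; neuron 0 fires on 1 of 4 steps.
spikes = np.zeros((1, 4, 2))
spikes[0, 0, 0] = 1.0
print(mse_spikerate_sketch()(spikes, np.array([0])))  # 0.0
```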
silence_reg(min_spikes)
Per-neuron L2 activity regularization that penalizes neurons spiking fewer than a target number of times.
:param min_spikes: neurons which spike below this value on average over the batch incur quadratic penalty. :return: JIT compiled regularization function.
Source code in spyx/fn.py
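A hedged NumPy sketch of the penalty described above, assuming the regularizer receives per-neuron spike counts laid out as [batch, neurons]. The input layout and the final sum over neurons are assumptions; silence_reg_sketch is hypothetical.

```python
import numpy as np

def silence_reg_sketch(min_spikes):
    # Hypothetical sketch; expects spike counts laid out as [batch, neurons].
    def _penalty(spike_counts):
        avg = np.mean(spike_counts, axis=0)              # average over the batch, per neuron
        shortfall = np.maximum(0.0, min_spikes - avg)    # only under-active neurons are penalized
        return np.sum(shortfall ** 2)                    # quadratic (L2) penalty
    return _penalty

counts = np.array([[0.0, 5.0], [2.0, 7.0]])   # neuron averages: 1.0 and 6.0
print(silence_reg_sketch(3)(counts))  # 4.0
```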
sparsity_reg(max_spikes, norm=optax.huber_loss)
Layer-level activity regularization that discourages the neurons in a layer from firing at a high rate.
:param max_spikes: Threshold for which penalty is incurred if the average number of spikes in the layer exceeds it. :param norm: an Optax loss function. Default is Huber loss. :return: JIT compiled regularization function.
Source code in spyx/fn.py
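Similarly, a sketch of the sparsity penalty, assuming it is applied to the amount by which the layer-wide average spike count exceeds max_spikes. The huber helper reimplements the standard Huber loss (delta = 1) so the example runs without optax; sparsity_reg_sketch is hypothetical.

```python
import numpy as np

def huber(x, delta=1.0):
    # Standard Huber loss: quadratic near zero, linear beyond delta.
    abs_x = np.abs(x)
    return np.where(abs_x <= delta, 0.5 * x ** 2, delta * (abs_x - 0.5 * delta))

def sparsity_reg_sketch(max_spikes, norm=huber):
    # Hypothetical sketch; penalizes only the excess of the layer-wide average.
    def _penalty(spike_counts):
        excess = np.maximum(0.0, np.mean(spike_counts) - max_spikes)
        return float(norm(excess))
    return _penalty

print(sparsity_reg_sketch(2)(np.full((2, 3), 5.0)))  # 2.5
print(sparsity_reg_sketch(8)(np.full((2, 3), 5.0)))  # 0.0
```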
Neural Networks
ALIF
Bases: RNNCore
Adaptive LIF Neuron based on the model used in LSNNs:
Bellec, G., Salaj, D., Subramoney, A., Legenstein, R. & Maass, W. Long short-term memory and learning-to-learn in networks of spiking neurons. 32nd Conference on Neural Information Processing Systems (2018).
Source code in spyx/nn.py
__call__(x, VT)
:x: Tensor from previous layer. :VT: Neuron state vector.
Source code in spyx/nn.py
__init__(hidden_shape, beta=None, gamma=None, threshold=1, activation=superspike(), name='ALIF')
:hidden_shape: Hidden layer shape. :beta: Membrane decay/inverse time constant. :gamma: Threshold adaptation constant. :threshold: Neuron firing threshold. :activation: spyx.axn.Axon object determining forward function and surrogate gradient function.
Source code in spyx/nn.py
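To make the adaptation mechanism concrete, here is a single-step sketch in the spirit of the LSNN model: each spike raises the effective threshold, and that increase decays by gamma on every step. spyx stores the state in a single VT tensor; this sketch keeps the membrane V and an adaptation trace a separate for readability, and adapt_scale is a hypothetical parameter that does not appear in the documented signature.

```python
import numpy as np

def alif_step(x, V, a, beta=0.9, gamma=0.9, threshold=1.0, adapt_scale=1.8):
    # One ALIF time step (illustrative sketch, not spyx's implementation).
    thr = threshold + adapt_scale * a          # effective (adaptive) threshold
    V = beta * V + x                           # leaky integration
    spike = (V > thr).astype(V.dtype)          # Heaviside forward pass
    V = V - spike * thr                        # reset by subtraction
    a = gamma * a + spike                      # decaying trace of recent spikes
    return spike, V, a

def count_spikes(adapt_scale, steps=20, drive=0.6):
    # Drive a single neuron with constant input and count its spikes.
    V, a, total = np.zeros(1), np.zeros(1), 0.0
    for _ in range(steps):
        s, V, a = alif_step(np.full(1, drive), V, a, adapt_scale=adapt_scale)
        total += float(s.sum())
    return total

print(count_spikes(adapt_scale=0.0), count_spikes(adapt_scale=1.8))
```

Under constant drive the adaptive neuron fires less often than the same neuron with adaptation switched off, which is the behaviour that gives LSNNs a slower internal timescale.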
ActivityRegularization
Bases: Module
Add state to the SNN to track the average number of spikes emitted per neuron per batch.
Adding this module to a network requires Haiku's transform_with_state transform, whose init function also returns an initial regularization state. This blank initial state can be reused and is passed as the second argument to the SNN's apply function.
Source code in spyx/nn.py
IF
Bases: RNNCore
Integrate and Fire neuron model. Although less expressive than other neuron models, IF neurons are very easy to implement in hardware.
Source code in spyx/nn.py
__call__(x, V)
:x: Vector coming from previous layer. :V: Neuron state tensor.
Source code in spyx/nn.py
__init__(hidden_shape, threshold=1, activation=superspike(), name='IF')
:hidden_shape: Shape of the preceding layer's outputs. :threshold: threshold for reset. Defaults to 1. :activation: spyx.axn.Axon object; the default, superspike(), uses a Heaviside forward pass with the SuperSpike surrogate gradient.
Source code in spyx/nn.py
LI
Bases: RNNCore
Leaky-Integrate (Non-spiking) neuron model.
Source code in spyx/nn.py
__call__(x, Vin)
:x: Input tensor from previous layer. :Vin: Neuron state tensor.
Source code in spyx/nn.py
__init__(layer_shape, beta=None, name='LI')
:layer_shape: Number of output neurons from the previous linear layer. :beta: Decay rate on membrane potential (voltage). Set uniformly across the layer.
Source code in spyx/nn.py
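Because the LI neuron never spikes or resets, its state is simply a leaky running sum of its input, which is why it suits a readout layer whose potentials are integrated for classification. A minimal sketch, assuming the update V <- beta * V + x; li_step is a hypothetical helper.

```python
import numpy as np

def li_step(x, V, beta=0.5):
    # Leaky integration with no threshold and no reset; the potential itself is the output.
    V = beta * V + x
    return V, V

V = np.zeros(2)
outputs = []
for _ in range(3):
    out, V = li_step(np.array([1.0, 0.0]), V)
    outputs.append(out.copy())
print(np.stack(outputs)[:, 0].tolist())  # [1.0, 1.5, 1.75]
```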
LIF
Bases: RNNCore
Leaky Integrate and Fire neuron model inspired by the implementation in snnTorch:
https://snntorch.readthedocs.io/en/latest/snn.neurons_leaky.html
Source code in spyx/nn.py
__call__(x, V)
:x: input vector coming from previous layer. :V: neuron state tensor.
Source code in spyx/nn.py
__init__(hidden_shape, beta=None, threshold=1.0, activation=superspike(), name='LIF')
:hidden_shape: Shape of the preceding layer's outputs. :beta: decay rate. Set to a float in the range (0,1] for uniform decay across the layer; otherwise per-neuron decay rates are drawn from a normal distribution centered on 0.5 with a standard deviation of 0.25. :threshold: threshold for reset. Defaults to 1. :activation: spyx.axn.Axon object; the default, superspike(), uses a Heaviside forward pass with the SuperSpike surrogate gradient.
Source code in spyx/nn.py
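A minimal sketch of one LIF time step, assuming snnTorch-style reset by subtraction and the ordering integrate, fire, then reset. The exact update order in spyx may differ; lif_step is a hypothetical helper.

```python
import numpy as np

def lif_step(x, V, beta=0.8, threshold=1.0):
    V = beta * V + x                              # leaky integration
    spike = (V > threshold).astype(V.dtype)       # Heaviside forward pass
    V = V - spike * threshold                     # reset by subtraction
    return spike, V

V = np.zeros(1)
spikes = []
for _ in range(6):
    s, V = lif_step(np.array([0.5]), V)
    spikes.append(int(s[0]))
print(spikes)  # [0, 0, 1, 0, 1, 0]
```

Reset by subtraction keeps the charge above threshold (here 0.22 after the first spike), so strong inputs are not fully forgotten after each spike.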
RIF
Bases: RNNCore
Recurrent Integrate and Fire neuron model.
Source code in spyx/nn.py
__call__(x, V)
:x: Vector coming from previous layer. :V: Neuron state tensor.
Source code in spyx/nn.py
__init__(hidden_shape, threshold=1, activation=superspike(), name='RIF')
:hidden_shape: Shape of the preceding layer's outputs. :threshold: threshold for reset. Defaults to 1. :activation: spyx.axn.Axon object; the default, superspike(), uses a Heaviside forward pass with the SuperSpike surrogate gradient.
Source code in spyx/nn.py
RLIF
Bases: RNNCore
Recurrent LIF Neuron adapted from snnTorch.
https://snntorch.readthedocs.io/en/latest/snn.neurons_rleaky.html
Source code in spyx/nn.py
__call__(x, V)
:x: The input data/latent vector from another layer. :V: The state tensor.
Source code in spyx/nn.py
__init__(hidden_shape, beta=None, threshold=1, activation=superspike(), name='RLIF')
Initialization function.
:hidden_shape: The tuple describing the layer's shape. Can accommodate varying shapes to directly stack on convolution layers without flattening. :beta: Decay constant. Unless explicitly set to a float in the range [0,1], it is treated as a learnable parameter. :threshold: Firing threshold for the layer. Does not currently support learning/trainable thresholds. :activation: A spyx.axn.Axon object specifying the forward and reverse activation function. The default, superspike(), uses a Heaviside forward pass with the SuperSpike surrogate gradient.
Source code in spyx/nn.py
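The recurrent variant adds feedback from the previous step's spikes. The sketch below assumes a dense recurrent weight matrix W_rec (a hypothetical name) and otherwise reuses a leaky integrate, fire, and subtract-reset update; it is not the spyx implementation.

```python
import numpy as np

def rlif_step(x, prev_spikes, V, W_rec, beta=0.8, threshold=1.0):
    # Recurrent term: last step's spikes fed back through W_rec.
    V = beta * V + x + prev_spikes @ W_rec
    spikes = (V > threshold).astype(V.dtype)      # Heaviside forward pass
    V = V - spikes * threshold                    # reset by subtraction
    return spikes, V

W_rec = np.array([[0.0, 1.5], [0.0, 0.0]])   # neuron 0 excites neuron 1
x = np.array([1.2, 0.0])                      # only neuron 0 receives input
s, V = np.zeros(2), np.zeros(2)
history = []
for _ in range(2):
    s, V = rlif_step(x, s, V, W_rec)
    history.append(s.tolist())
print(history)  # [[1.0, 0.0], [1.0, 1.0]]
```

Neuron 1 receives no external input and fires only because of the recurrent connection from neuron 0.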
SumPool
Bases: Module
Sum pool.
Returns the total number of spikes emitted in a receptive field.
Source code in spyx/nn.py
__init__(window_shape, strides, padding, channel_axis=-1, name=None)
Sum pool.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| window_shape | Union[int, Sequence[int]] | Shape of the pooling window, same rank as value. | required |
| strides | Union[int, Sequence[int]] | Strides of the pooling window, same rank as value. | required |
| padding | str | Padding algorithm. Either VALID or SAME. | required |
| channel_axis | Optional[int] | Axis of the spatial channels for which pooling is skipped. | -1 |
| name | Optional[str] | String name for the module. | None |
Source code in spyx/nn.py
PopulationCode(num_classes)
Add population coding to the preceding neuron layer. The preceding layer's output size must be a multiple of the number of classes. Use this for rate-coded SNNs with too few time steps to produce a reliable spike count.
Source code in spyx/nn.py
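A sketch of the population readout, assuming each class owns a contiguous block of output neurons whose activity is summed. The grouping order used by spyx is an assumption, and population_code_sketch is hypothetical.

```python
import numpy as np

def population_code_sketch(x, num_classes):
    # ASSUMPTION: each class owns a contiguous block of neurons in the last axis.
    *lead, n = x.shape
    if n % num_classes != 0:
        raise ValueError("output size must be a multiple of num_classes")
    groups = x.reshape(*lead, num_classes, n // num_classes)
    return groups.sum(axis=-1)   # summed activity per class

x = np.arange(6).reshape(1, 6)     # 6 neurons, 2 per class
print(population_code_sketch(x, 3).tolist())  # [[1, 5, 9]]
```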
sum_pool(value, window_shape, strides, padding, channel_axis=-1)
Sum pool.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| value | Array | Value to pool. | required |
| window_shape | Union[int, Sequence[int]] | Shape of the pooling window, same rank as value. | required |
| strides | Union[int, Sequence[int]] | Strides of the pooling window, same rank as value. | required |
| padding | str | Padding algorithm. Either VALID or SAME. | required |
| channel_axis | Optional[int] | Axis of the spatial channels for which pooling is skipped. | -1 |
Returns:
| Type | Description |
|---|---|
| Array | Pooled result. Same rank as value. |
Source code in spyx/nn.py
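For intuition, here is a loop-based NumPy sketch of sum pooling over a [H, W, C] array with VALID padding, leaving the channel axis untouched. It is written for readability and is not how the spyx version is implemented; sum_pool_2d_sketch is hypothetical.

```python
import numpy as np

def sum_pool_2d_sketch(x, window, stride):
    # x: [H, W, C]; VALID padding; the channel axis (last) is not pooled.
    h, w, c = x.shape
    out_h = (h - window) // stride + 1
    out_w = (w - window) // stride + 1
    out = np.zeros((out_h, out_w, c), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            patch = x[i * stride:i * stride + window, j * stride:j * stride + window, :]
            out[i, j] = patch.sum(axis=(0, 1))   # total spikes in the receptive field
    return out

x = np.arange(16, dtype=float).reshape(4, 4, 1)
print(sum_pool_2d_sketch(x, window=2, stride=2)[..., 0].tolist())  # [[10.0, 18.0], [42.0, 50.0]]
```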
Activation Functions
arctan(k=2)
This function implements the Arctangent surrogate gradient activation function for a spiking neuron.
The Arctangent function returns a value between -pi/2 and pi/2 for inputs in the range of -Infinity to Infinity. It is often used in the context of spiking neurons because it provides a smooth approximation to the step function that is differentiable everywhere, which is a requirement for gradient-based optimization methods.
:k: A scaling factor that can be used to adjust the steepness of the Arctangent function. Default is 2. :return: JIT compiled arctangent-derived surrogate gradient function.
Source code in spyx/axn.py
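One common arctangent relaxation of the step function (used, for example, by snnTorch's ATan surrogate) is (1/pi) * arctan(pi/2 * k * x) + 1/2, whose derivative serves as the surrogate gradient. The sketch assumes this scaling, which may differ from spyx's, and checks the analytic derivative against a finite difference.

```python
import numpy as np

def arctan_relaxation(x, k=2.0):
    # Smooth stand-in for the step function, ranging over (0, 1).
    return np.arctan(np.pi / 2.0 * k * x) / np.pi + 0.5

def arctan_surrogate(x, k=2.0):
    # Analytic derivative of arctan_relaxation, used in place of the step's gradient.
    return (k / 2.0) / (1.0 + (np.pi / 2.0 * k * x) ** 2)

h = 1e-6
fd = (arctan_relaxation(0.3 + h) - arctan_relaxation(0.3 - h)) / (2 * h)
print(arctan_surrogate(0.0), fd, arctan_surrogate(0.3))
```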
boxcar(width=2, height=0.5)
Boxcar activation.
:width: Total width of non-zero gradient flow, centered on 0. :height: Value for gradient within the specified window. :return: JIT compiled boxcar surrogate gradient function.
Source code in spyx/axn.py
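A boxcar surrogate passes a constant gradient inside a window centered on the threshold and zero outside it. A sketch, assuming a strict inequality at the window edges; boxcar_surrogate is hypothetical.

```python
import numpy as np

def boxcar_surrogate(x, width=2.0, height=0.5):
    # Constant gradient `height` inside |x| < width / 2, zero elsewhere.
    return np.where(np.abs(x) < width / 2.0, height, 0.0)

print(boxcar_surrogate(np.array([-2.0, -0.5, 0.0, 0.9, 1.5])).tolist())  # [0.0, 0.5, 0.5, 0.5, 0.0]
```

With the defaults, width * height = 2 * 0.5 = 1, so the surrogate integrates to 1, just like the derivative of a unit step.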
custom(bwd=lambda x: x, fwd=lambda x: heaviside(x))
This function serves as the activation function for the SNNs, allowing for custom definitions of both surrogate gradients for backwards passes as well as substitution of the Heaviside function for relaxations such as sigmoids.
It is assumed that the input to this layer has already had its threshold subtracted within the neuron model dynamics.
The default behavior is a Heaviside forward activation with a straight-through estimator surrogate gradient.
:bwd: Function that calculates the gradient to be used in the backwards pass. :fwd: Forward activation/spiking function. Default is the Heaviside function centered at 0. :return: A JIT compiled activation function comprised of the specified forward and backward functions.
Source code in spyx/axn.py
superspike(k=25)
This function implements the SuperSpike surrogate gradient activation function for a spiking neuron.
The SuperSpike surrogate gradient is defined as 1/(1+k|U|)^2, where U is the input to the function and k is a scaling factor. It takes values in (0, 1] for inputs in the range of -Infinity to Infinity.
It is often used in the context of spiking neurons because it provides a smooth approximation to the step function that is differentiable everywhere, which is a requirement for gradient-based optimization methods.
It is the derivative of a fast sigmoid, U/(1+k|U|), and is adapted from:
F. Zenke, S. Ganguli (2018) SuperSpike: Supervised Learning in Multilayer Spiking Neural Networks. Neural Computation, pp. 1514-1541.
:k: A scaling factor that can be used to adjust the steepness of the SuperSpike function. Default is 25. :return: JIT compiled SuperSpike surrogate gradient function.
Source code in spyx/axn.py
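The relationship between the surrogate and the fast sigmoid can be checked numerically. This sketch defines both (function names are hypothetical) and compares the analytic surrogate with a finite-difference derivative of the fast sigmoid.

```python
import numpy as np

def fast_sigmoid(u, k=25.0):
    # Smooth, everywhere-differentiable stand-in for the step function.
    return u / (1.0 + k * np.abs(u))

def superspike_surrogate(u, k=25.0):
    # SuperSpike surrogate gradient: derivative of fast_sigmoid.
    return 1.0 / (1.0 + k * np.abs(u)) ** 2

h = 1e-7
u = 0.1
fd = (fast_sigmoid(u + h) - fast_sigmoid(u - h)) / (2 * h)
print(superspike_surrogate(0.0), superspike_surrogate(u), fd)
```

Larger k makes the surrogate narrower, concentrating gradient flow on membrane potentials close to the threshold.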
tanh(k=1)
Hyperbolic Tangent activation.
.. math:: 4 / (e^{-kx} + e^{kx})^2
:k: Value for scaling the slope of the surrogate gradient. :return: JIT compiled tanh surrogate gradient function.
Source code in spyx/axn.py
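Since 2 / (e^{-kx} + e^{kx}) = sech(kx), the formula above equals sech^2(kx) = 1 - tanh^2(kx), which is the derivative of tanh(kx) up to a factor of k. A quick numerical check of that identity (tanh_surrogate is a hypothetical name):

```python
import numpy as np

def tanh_surrogate(x, k=1.0):
    # 4 / (e^{-kx} + e^{kx})^2, i.e. sech^2(kx)
    return 4.0 / (np.exp(-k * x) + np.exp(k * x)) ** 2

x = np.linspace(-3.0, 3.0, 13)
print(np.allclose(tanh_surrogate(x), 1.0 - np.tanh(x) ** 2))  # True
```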
triangular(k=2)
Triangular activation inspired by Esser et al. https://arxiv.org/abs/1603.08270
.. math:: max(0, 1-|kx|)
:k: scale factor :return: JIT compiled triangular surrogate gradient function.
Source code in spyx/axn.py
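The triangular surrogate peaks at 1 at the threshold and falls linearly to zero at |x| = 1/k. A direct transcription of the formula above (triangular_surrogate is a hypothetical name):

```python
import numpy as np

def triangular_surrogate(x, k=2.0):
    # max(0, 1 - |kx|): peaks at 1 when x = 0, reaches 0 at |x| = 1/k.
    return np.maximum(0.0, 1.0 - np.abs(k * x))

print(triangular_surrogate(np.array([-1.0, -0.25, 0.0, 0.5])).tolist())  # [0.0, 0.5, 1.0, 0.0]
```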
Data Utilities
angle_code(neuron_count, min_val, max_val)
Higher-order function which returns an angle encoding function; given a continuous value, an angle converter generates a one-hot vector corresponding to where the value falls between a specified minimum and maximum. To achieve non-linear discretization, apply a function to the continuous value before feeding it into the encoder.
:neuron_count: The number of output channels for the angle encoder :min_val: A lower bound on the continuous input channel :max_val: An upper bound on the continuous input channel.
Source code in spyx/data.py
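A sketch of the encoder's behaviour, assuming neuron_count equal-width bins between min_val and max_val, with out-of-range values clipped to the end bins. The binning details are assumptions, and angle_code_sketch is hypothetical.

```python
import numpy as np

def angle_code_sketch(neuron_count, min_val, max_val):
    # Hypothetical sketch: equal-width bins over [min_val, max_val].
    edges = np.linspace(min_val, max_val, neuron_count + 1)
    def _encode(value):
        idx = np.digitize(value, edges) - 1               # bin index for each value
        idx = np.clip(idx, 0, neuron_count - 1)           # clip out-of-range values to the end bins
        return np.eye(neuron_count, dtype=np.uint8)[idx]  # one-hot over output channels
    return _encode

encode = angle_code_sketch(4, 0.0, 1.0)
print(encode(np.array([0.0, 0.3, 1.0])).argmax(axis=-1).tolist())  # [0, 1, 3]
```

Applying, say, np.sqrt to the values before calling encode would give finer resolution near min_val, which is the non-linear discretization mentioned above.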
rate_code(num_steps, max_r=0.75)
Unrolls input data along axis 1 and converts it to rate-encoded spikes; the probability of spiking is the input value multiplied by a max rate, with each time step being a sample drawn from a Bernoulli distribution. Currently assumes input values have been rescaled to between 0 and 1.
Source code in spyx/data.py
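A rough NumPy sketch of rate coding, assuming the new time axis is inserted at axis 1 and using a NumPy Generator as a stand-in for JAX-style random keys (an assumption about the real function's randomness argument). Each output element is an independent Bernoulli draw with probability value * max_r.

```python
import numpy as np

def rate_code_sketch(num_steps, max_r=0.75):
    # Hypothetical sketch; a NumPy Generator stands in for a JAX PRNG key.
    def _encode(data, rng):
        probs = np.clip(data, 0.0, 1.0) * max_r                          # spike probability per step
        probs = np.repeat(np.expand_dims(probs, 1), num_steps, axis=1)   # new time axis at 1
        return (rng.random(probs.shape) < probs).astype(np.uint8)        # Bernoulli draws
    return _encode

rng = np.random.default_rng(0)
data = np.array([[0.0, 1.0]])                 # [batch=1, channels=2]
spikes = rate_code_sketch(2000)(data, rng)    # [batch=1, time=2000, channels=2]
print(spikes.shape, spikes[0, :, 1].mean())
```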
shift_augment(max_shift=10, axes=(-1,))
Shift data augmentation tool. Rolls data along specified axes randomly up to a certain amount.
:max_shift: maximum to which values can be shifted. :axes: the data axis or axes along which the input will be randomly shifted.
Source code in spyx/data.py
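A sketch of the augmentation, assuming one shift per axis drawn uniformly from [-max_shift, max_shift] and applied to the whole batch with np.roll, which wraps values around the edge. Per-sample shifts or zero-filling would be equally plausible designs; shift_augment_sketch is hypothetical.

```python
import numpy as np

def shift_augment_sketch(max_shift=10, axes=(-1,)):
    # Hypothetical sketch: one random shift per axis, applied to the whole batch.
    def _shift(data, rng):
        shifts = rng.integers(-max_shift, max_shift + 1, size=len(axes))
        return np.roll(data, tuple(int(s) for s in shifts), axis=axes)   # wraps around the edges
    return _shift

rng = np.random.default_rng(0)
x = np.arange(20).reshape(2, 10)
y = shift_augment_sketch(max_shift=3)(x, rng)
print(y)
```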
shuffler(dataset, batch_size)
Higher-order function which builds a shuffle function for a dataset.
:dataset: jnp.array [# samples, time, channels...] :batch_size: desired batch size.
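A sketch of such a shuffle function, assuming the remainder that does not fill a full batch is dropped and the output is laid out as [num_batches, batch_size, ...]; both choices are assumptions, and shuffler_sketch is hypothetical.

```python
import numpy as np

def shuffler_sketch(dataset, batch_size):
    # Hypothetical sketch; drops the final partial batch.
    n_batches = dataset.shape[0] // batch_size
    def _shuffle(rng):
        idx = rng.permutation(dataset.shape[0])[: n_batches * batch_size]
        return dataset[idx].reshape(n_batches, batch_size, *dataset.shape[1:])
    return _shuffle

dataset = np.arange(10 * 2 * 3).reshape(10, 2, 3)   # [samples, time, channels]
batches = shuffler_sketch(dataset, batch_size=4)(np.random.default_rng(0))
print(batches.shape)  # (2, 4, 2, 3)
```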