spyx.nn#

Module Contents#

Classes#

ALIF

Adaptive LIF neuron based on the model used in LSNNs (Bellec et al., 2018).

LI

Leaky-Integrate (Non-spiking) neuron model.

IF

Integrate and Fire neuron model. Though not as powerful/rich as other neuron models, it is very easy to implement in hardware.

LIF

Leaky Integrate and Fire neuron model inspired by the implementation in snnTorch.

CuBaLIF

Current-based (CuBa) Leaky Integrate and Fire neuron model.

RIF

Recurrent Integrate and Fire neuron model.

RLIF

Recurrent LIF Neuron adapted from snnTorch.

RCuBaLIF

Recurrent current-based (CuBa) LIF neuron model.

ActivityRegularization

Add state to the SNN to track the average number of spikes emitted per neuron per batch.

SumPool

Sum pool.

Functions#

PopulationCode(num_classes)

Add population coding to the preceding neuron layer. Preceding layer's output shape must be a multiple of the number of classes.

sum_pool(→ jax.Array)

Sum pool.

class spyx.nn.ALIF(hidden_shape, beta=None, gamma=None, threshold=1, activation=superspike(), name='ALIF')[source]#

Bases: haiku.RNNCore

Adaptive LIF Neuron based on the model used in LSNNs:

Bellec, G., Salaj, D., Subramoney, A., Legenstein, R. & Maass, W. Long short- term memory and learning-to-learn in networks of spiking neurons. 32nd Conference on Neural Information Processing Systems (2018).

__call__(x, VT)[source]#
x:

Tensor from previous layer.

VT:

Neuron state vector.

initial_state(batch_size)[source]#

Constructs an initial state for this core.

Parameters:

batch_size – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Returns:

Arbitrarily nested initial state for this core.
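The adaptive mechanism can be sketched in plain Python. This is an illustration of the Bellec et al. (2018) formulation, not spyx's actual implementation: the mapping of beta/gamma to membrane and adaptation decay, the soft reset, and the adaptation scale are assumptions here, and the real class operates on JAX arrays with a surrogate-gradient activation.

```python
# Illustrative ALIF sketch: the effective threshold rises after each spike
# and decays back toward its baseline, so sustained input spikes sparsely.
def alif_step(x, V, a, beta=0.5, gamma=0.9, base_threshold=1.0, scale=0.5):
    V = beta * V + x                      # leaky membrane integration
    thr = base_threshold + scale * a      # threshold grows with adaptation a
    spike = 1.0 if V >= thr else 0.0
    V = V - spike * thr                   # reset (assumed soft reset)
    a = gamma * a + spike                 # adaptation decays, bumps on spike
    return spike, V, a

V, a = 0.0, 0.0
spikes = []
for _ in range(6):
    s, V, a = alif_step(1.0, V, a)
    spikes.append(s)
```

After the first spike the adaptation variable raises the effective threshold, so a constant drive produces progressively sparser spiking.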

class spyx.nn.LI(layer_shape, beta=None, name='LI')[source]#

Bases: haiku.RNNCore

Leaky-Integrate (Non-spiking) neuron model.

__call__(x, Vin)[source]#
x:

Input tensor from previous layer.

Vin:

Neuron state tensor.

initial_state(batch_size)[source]#

Constructs an initial state for this core.

Parameters:

batch_size – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Returns:

Arbitrarily nested initial state for this core.
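A leaky integrator keeps a decaying running sum of its input and never spikes; it is commonly used as a non-spiking readout layer. A minimal sketch with a fixed decay beta (illustrative only; the real class works elementwise on JAX arrays):

```python
# Leaky integration: the state decays by beta each step and adds the input.
def li_step(x, V, beta=0.5):
    return beta * V + x

V = 0.0
for _ in range(3):
    V = li_step(1.0, V)
# under constant input, V approaches the fixed point x / (1 - beta)
```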

class spyx.nn.IF(hidden_shape, threshold=1, activation=superspike(), name='IF')[source]#

Bases: haiku.RNNCore

Integrate and Fire neuron model. Though not as powerful/rich as other neuron models, it is very easy to implement in hardware.

__call__(x, V)[source]#
x:

Vector coming from previous layer.

V:

Neuron state tensor.

initial_state(batch_size)[source]#

Constructs an initial state for this core.

Parameters:

batch_size – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Returns:

Arbitrarily nested initial state for this core.
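The dynamics reduce to accumulate-and-fire. A plain-Python sketch, assuming reset-by-subtraction (the actual reset rule and array handling live in the JAX implementation):

```python
# Illustrative IF step: integrate the input with no leak, spike at threshold.
def if_step(x, V, threshold=1.0):
    V = V + x                      # pure integration, no decay
    spike = 1.0 if V >= threshold else 0.0
    V = V - spike * threshold      # assumed soft reset on spike
    return spike, V

# A constant drive of 0.4 crosses threshold=1.0 every few steps.
V = 0.0
spikes = []
for _ in range(5):
    s, V = if_step(0.4, V)
    spikes.append(s)
```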

class spyx.nn.LIF(hidden_shape: tuple, beta=None, threshold=1.0, activation=superspike(), name='LIF')[source]#

Bases: haiku.RNNCore

Leaky Integrate and Fire neuron model inspired by the implementation in snnTorch:

https://snntorch.readthedocs.io/en/latest/snn.neurons_leaky.html

Parameters:

hidden_shape (tuple) – Shape of the layer of neurons (hidden state).

__call__(x, V)[source]#
x:

input vector coming from previous layer.

V:

neuron state tensor.

initial_state(batch_size)[source]#

Constructs an initial state for this core.

Parameters:

batch_size – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Returns:

Arbitrarily nested initial state for this core.
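A minimal sketch of the leaky dynamics, assuming a scalar decay beta and reset-by-subtraction; the real class applies this elementwise to JAX arrays and uses a surrogate-gradient activation (superspike by default) so training can backpropagate through the spike:

```python
# Illustrative LIF step: leaky integration plus threshold-and-reset.
def lif_step(x, V, beta=0.5, threshold=1.0):
    V = beta * V + x               # membrane decays, then integrates input
    spike = 1.0 if V >= threshold else 0.0
    V = V - spike * threshold      # assumed soft reset
    return spike, V

V = 0.0
out = []
for _ in range(4):
    s, V = lif_step(0.8, V)
    out.append(s)
```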

class spyx.nn.CuBaLIF(hidden_shape, alpha=None, beta=None, threshold=1, activation=superspike(), name='CuBaLIF')[source]#

Bases: haiku.RNNCore

Current-based (CuBa) Leaky Integrate and Fire neuron model. In addition to the membrane potential, it maintains a synaptic current state that filters the input before integration.

As a haiku.RNNCore, it can be used with dynamic_unroll() and static_unroll() to iteratively construct an output sequence from a given input sequence.

__call__(x, VI)[source]#

Run one step of the RNN.

Parameters:
  • x – Input tensor from the previous layer.

  • VI – Previous core state (membrane potential and synaptic current).

Returns:

A tuple with two elements, output and next_state. output is an arbitrarily nested structure; next_state is the next core state and must have the same shape as VI.

initial_state(batch_size)[source]#

Constructs an initial state for this core.

Parameters:

batch_size – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Returns:

Arbitrarily nested initial state for this core.
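Under the usual current-based (CuBa) formulation, a synaptic current low-pass filters the input and the membrane integrates that current. A plain-Python sketch, with alpha/beta mirroring the constructor arguments; the state layout and reset rule are assumptions, not spyx's verified internals:

```python
# Illustrative CuBa-LIF step with two coupled leaky states.
def cuba_lif_step(x, I, V, alpha=0.5, beta=0.5, threshold=1.0):
    I = alpha * I + x              # synaptic current: low-pass filter of input
    V = beta * V + I               # membrane potential integrates the current
    spike = 1.0 if V >= threshold else 0.0
    V = V - spike * threshold      # assumed soft reset
    return spike, I, V

I, V = 0.0, 0.0
spikes = []
for _ in range(4):
    s, I, V = cuba_lif_step(0.5, I, V)
    spikes.append(s)
```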

class spyx.nn.RIF(hidden_shape, threshold=1, activation=superspike(), name='RIF')[source]#

Bases: haiku.RNNCore

Recurrent Integrate and Fire neuron model.

__call__(x, V)[source]#
x:

Vector coming from previous layer.

V:

Neuron state tensor.

initial_state(batch_size)[source]#

Constructs an initial state for this core.

Parameters:

batch_size – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Returns:

Arbitrarily nested initial state for this core.
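The recurrent variants (RIF, RLIF, RCuBaLIF) feed the layer's own previous spikes back into the update. A sketch of the idea for RIF, where the scalar w_rec stands in for the learned recurrent weight matrix that Haiku would create; the reset rule is again an assumption:

```python
# Illustrative recurrent IF step: integration of input plus spike feedback.
def rif_step(x, V, prev_spike, w_rec=0.3, threshold=1.0):
    V = V + x + w_rec * prev_spike   # feedforward input + recurrent feedback
    spike = 1.0 if V >= threshold else 0.0
    V = V - spike * threshold        # assumed soft reset
    return spike, V

V, s = 0.0, 0.0
spikes = []
for _ in range(5):
    s, V = rif_step(0.4, V, s)
    spikes.append(s)
```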

class spyx.nn.RLIF(hidden_shape, beta=None, threshold=1, activation=superspike(), name='RLIF')[source]#

Bases: haiku.RNNCore

Recurrent LIF Neuron adapted from snnTorch.

https://snntorch.readthedocs.io/en/latest/snn.neurons_rleaky.html

__call__(x, V)[source]#
x:

The input data/latent vector from another layer.

V:

The state tensor.

initial_state(batch_size)[source]#

Constructs an initial state for this core.

Parameters:

batch_size – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Returns:

Arbitrarily nested initial state for this core.

class spyx.nn.RCuBaLIF(hidden_shape, alpha=None, beta=None, activation=superspike(), name='RCuBaLIF')[source]#

Bases: haiku.RNNCore

Recurrent current-based (CuBa) LIF neuron model, which feeds the layer's spikes back through a recurrent connection in addition to maintaining a synaptic current state.

As a haiku.RNNCore, it can be used with dynamic_unroll() and static_unroll() to iteratively construct an output sequence from a given input sequence.

__call__(x, VI)[source]#

Run one step of the RNN.

Parameters:
  • x – Input tensor from the previous layer.

  • VI – Previous core state (membrane potential and synaptic current).

Returns:

A tuple with two elements, output and next_state. output is an arbitrarily nested structure; next_state is the next core state and must have the same shape as VI.

initial_state(batch_size)[source]#

Constructs an initial state for this core.

Parameters:

batch_size – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Returns:

Arbitrarily nested initial state for this core.

class spyx.nn.ActivityRegularization(name='ActReg')[source]#

Bases: haiku.Module

Add state to the SNN to track the average number of spikes emitted per neuron per batch.

Adding this to a network requires using Haiku's transform_with_state; its init function will then also return an initial regularization state vector. This blank initial vector can be reused and is provided as the second argument to the SNN's apply function.

__call__(spikes)[source]#
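The tracked statistic is just the mean spike count per neuron over a batch. A plain-Python sketch of that computation (illustrative; the real module keeps this as Haiku state on JAX arrays):

```python
# Average spikes per neuron across a batch of per-example spike vectors.
def avg_spikes_per_neuron(batch_spikes):
    n = len(batch_spikes)            # number of examples in the batch
    width = len(batch_spikes[0])     # number of neurons in the layer
    return [sum(ex[j] for ex in batch_spikes) / n for j in range(width)]

# Two examples, three neurons: neuron 2 fired in both examples.
avg = avg_spikes_per_neuron([[1, 0, 1], [0, 0, 1]])
```

A penalty on this vector can then discourage excessive firing during training.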
spyx.nn.PopulationCode(num_classes)[source]#

Add population coding to the preceding neuron layer. The preceding layer's output shape must be a multiple of the number of classes. Use this for rate-coded SNNs where the time steps are too few to get a good spike count.
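A hedged sketch of the decoding this enables: with a layer of width num_classes * k, each class is represented by a group of k neurons whose spikes are pooled, so even a short simulation yields a usable rate per class. The helper below is illustrative, not the actual implementation, which operates on JAX arrays:

```python
# Pool each class's group of neurons into a single rate.
def population_decode(outputs, num_classes):
    assert len(outputs) % num_classes == 0, "width must be a multiple of num_classes"
    group = len(outputs) // num_classes
    return [sum(outputs[i * group:(i + 1) * group]) for i in range(num_classes)]

# 6 neurons, 3 classes -> 2 neurons per class
rates = population_decode([1, 0, 1, 1, 0, 0], num_classes=3)
```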

spyx.nn.sum_pool(value: jax.Array, window_shape: int | collections.abc.Sequence[int], strides: int | collections.abc.Sequence[int], padding: str, channel_axis: int | None = -1) → jax.Array[source]#

Sum pool.

Parameters:
  • value (jax.Array) – Value to pool.

  • window_shape (Union[int, collections.abc.Sequence[int]]) – Shape of the pooling window, same rank as value.

  • strides (Union[int, collections.abc.Sequence[int]]) – Strides of the pooling window, same rank as value.

  • padding (str) – Padding algorithm. Either VALID or SAME.

  • channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped.

Returns:

Pooled result. Same rank as value.

Return type:

jax.Array
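In one dimension with VALID padding, sum pooling is just a windowed sum, i.e. a spike count per receptive field. A pure-Python sketch (the real function wraps an N-D JAX reduce-window and also supports SAME padding):

```python
# Sum each window of `window` values, stepping by `stride` (VALID padding).
def sum_pool_1d(values, window, stride):
    return [sum(values[i:i + window])
            for i in range(0, len(values) - window + 1, stride)]

# Six time steps of spikes pooled into three non-overlapping windows.
pooled = sum_pool_1d([1, 0, 1, 1, 0, 1], window=2, stride=2)
```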

class spyx.nn.SumPool(window_shape: int | collections.abc.Sequence[int], strides: int | collections.abc.Sequence[int], padding: str, channel_axis: int | None = -1, name: str | None = None)[source]#

Bases: haiku.Module

Sum pool.

Returns the total number of spikes emitted in a receptive field.

Parameters:
  • window_shape (Union[int, collections.abc.Sequence[int]]) – Shape of the pooling window, same rank as value.

  • strides (Union[int, collections.abc.Sequence[int]]) – Strides of the pooling window, same rank as value.

  • padding (str) – Padding algorithm. Either VALID or SAME.

  • channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped.

  • name (Optional[str]) – Name of the module.

__call__(value: jax.Array) → jax.Array[source]#
Parameters:

value (jax.Array) – Value to pool.

Return type:

jax.Array