spyx.nn#
Module Contents#
Classes#
- ALIF – Adaptive LIF neuron based on the model used in LSNNs.
- LI – Leaky-Integrate (non-spiking) neuron model.
- IF – Integrate-and-Fire neuron model. While not as powerful or expressive as other neuron models, it is very easy to implement in hardware.
- LIF – Leaky Integrate-and-Fire neuron model inspired by the implementation in snnTorch.
- CuBaLIF – Current-based (CuBa) LIF neuron model.
- RIF – Recurrent Integrate-and-Fire neuron model.
- RLIF – Recurrent LIF neuron adapted from snnTorch.
- RCuBaLIF – Recurrent current-based (CuBa) LIF neuron model.
- ActivityRegularization – Adds state to the SNN to track the average number of spikes emitted per neuron per batch.
- SumPool – Sum pool over spike counts.
Functions#
- PopulationCode – Adds population coding to the preceding neuron layer; the layer's output shape must be a multiple of the number of classes.
- sum_pool – Sum pool.
- class spyx.nn.ALIF(hidden_shape, beta=None, gamma=None, threshold=1, activation=superspike(), name='ALIF')[source]#
Bases:
haiku.RNNCore
Adaptive LIF Neuron based on the model used in LSNNs:
Bellec, G., Salaj, D., Subramoney, A., Legenstein, R. & Maass, W. Long short-term memory and learning-to-learn in networks of spiking neurons. 32nd Conference on Neural Information Processing Systems (2018).
- initial_state(batch_size)[source]#
Constructs an initial state for this core.
- Parameters:
batch_size – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.
- Returns:
Arbitrarily nested initial state for this core.
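The adaptive neuron's behavior can be sketched in plain Python: a leaky membrane plus an adaptation trace that raises the effective firing threshold after each spike, so a constant drive fires progressively less often. The update rule and default values below are an illustrative reading of the LSNN model, not the spyx implementation (which operates on JAX arrays with a surrogate-gradient activation).

```python
def alif_step(v, a, x, beta=0.9, gamma=0.9, threshold=1.0):
    """One illustrative ALIF update: leaky integration plus an
    adaptation trace `a` that raises the effective threshold."""
    v = beta * v + x                            # leaky integration of input
    spike = 1.0 if v > threshold + a else 0.0   # compare against adaptive threshold
    v = v - spike * threshold                   # reset by subtraction on spike
    a = gamma * a + spike                       # trace decays, bumps on each spike
    return v, a, spike

# A constant drive produces spikes, but adaptation spaces them out.
v, a = 0.0, 0.0
spikes = []
for _ in range(20):
    v, a, s = alif_step(v, a, x=0.5)
    spikes.append(s)
```

Compared with a plain LIF neuron at the same drive, the rising threshold lowers the steady-state firing rate, which is what makes ALIF useful for longer-timescale memory.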
- class spyx.nn.LI(layer_shape, beta=None, name='LI')[source]#
Bases:
haiku.RNNCore
Leaky-Integrate (Non-spiking) neuron model.
- initial_state(batch_size)[source]#
Constructs an initial state for this core.
- Parameters:
batch_size – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.
- Returns:
Arbitrarily nested initial state for this core.
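Because LI has no threshold, it simply computes an exponential moving trace of its input; it is commonly used as a non-spiking readout layer. A minimal sketch of that dynamic (the decay constant `beta` mirrors the constructor argument, but the scalar form is illustrative):

```python
def li_step(v, x, beta=0.8):
    """Leaky integrator: exponential trace of the input,
    no threshold and no spikes (often used as a readout)."""
    return beta * v + x

v = 0.0
for _ in range(100):
    v = li_step(v, x=1.0)
# With constant input the trace converges toward x / (1 - beta) = 5.0
```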
- class spyx.nn.IF(hidden_shape, threshold=1, activation=superspike(), name='IF')[source]#
Bases:
haiku.RNNCore
Integrate and Fire neuron model. While not as powerful or expressive as other neuron models, it is very easy to implement in hardware.
- initial_state(batch_size)[source]#
Constructs an initial state for this core.
- Parameters:
batch_size – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.
- Returns:
Arbitrarily nested initial state for this core.
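The hardware-friendliness comes from the dynamics: IF is a pure accumulator with no leak, so each step is one addition and one comparison. An illustrative sketch (reset-by-subtraction assumed; the spyx version works on JAX arrays with a surrogate gradient):

```python
def if_step(v, x, threshold=1.0):
    """Integrate-and-fire: accumulate input with no leak,
    spike and subtract the threshold when it is crossed."""
    v = v + x
    spike = 1.0 if v > threshold else 0.0
    return v - spike * threshold, spike

# A constant subthreshold input makes the neuron fire periodically.
v, spikes = 0.0, []
for _ in range(12):
    v, s = if_step(v, x=0.3)
    spikes.append(s)
```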
- class spyx.nn.LIF(hidden_shape: tuple, beta=None, threshold=1.0, activation=superspike(), name='LIF')[source]#
Bases:
haiku.RNNCore
Leaky Integrate and Fire neuron model inspired by the implementation in snnTorch:
https://snntorch.readthedocs.io/en/latest/snn.neurons_leaky.html
- Parameters:
hidden_shape (tuple) – Shape of the neuron layer's hidden state.
- initial_state(batch_size)[source]#
Constructs an initial state for this core.
- Parameters:
batch_size – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.
- Returns:
Arbitrarily nested initial state for this core.
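The leaky dynamics can be sketched in a few lines: the membrane decays by `beta`, integrates the input, and resets by subtraction when it crosses the threshold, in the style of snnTorch's Leaky neuron. The scalar form and defaults below are illustrative, not the spyx source:

```python
def lif_step(v, x, beta=0.9, threshold=1.0):
    """Leaky integrate-and-fire with reset-by-subtraction."""
    v = beta * v + x
    spike = 1.0 if v > threshold else 0.0
    return v - spike * threshold, spike

def rate(x, steps=100):
    """Firing rate of the sketch neuron under a constant drive x."""
    v, n = 0.0, 0.0
    for _ in range(steps):
        v, s = lif_step(v, x)
        n += s
    return n / steps
```

Because the membrane leaks, a drive below `threshold * (1 - beta)` never fires at all, and stronger drives fire faster, which is the basic rate-coding property the trainable `beta` exploits.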
- class spyx.nn.CuBaLIF(hidden_shape, alpha=None, beta=None, threshold=1, activation=superspike(), name='CuBaLIF')[source]#
Bases:
haiku.RNNCore
Base class for RNN cores.
This class defines the basic functionality that every core should implement: initial_state(), used to construct an example of the core state; and __call__(), which applies the core, parameterized by a previous state, to an input.
Cores may be used with dynamic_unroll() and static_unroll() to iteratively construct an output sequence from the given input sequence.
- __call__(x, VI)[source]#
Run one step of the RNN.
- Parameters:
inputs – An arbitrarily nested structure.
prev_state – Previous core state.
- Returns:
A tuple with two elements, output, next_state. output is an arbitrarily nested structure. next_state is the next core state; it must be the same shape as prev_state.
- initial_state(batch_size)[source]#
Constructs an initial state for this core.
- Parameters:
batch_size – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.
- Returns:
Arbitrarily nested initial state for this core.
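The __call__(x, VI) signature suggests the state carries both a membrane potential and a synaptic current. A current-based (CuBa) LIF filters the input twice: the input first charges a synaptic current (decay `alpha`), which in turn charges the membrane (decay `beta`), giving smoother post-synaptic potentials than plain LIF. The update rule below is an illustrative reading of that scheme, not the spyx source:

```python
def cuba_lif_step(v, i, x, alpha=0.8, beta=0.9, threshold=1.0):
    """Current-based LIF sketch: input -> synaptic current -> membrane."""
    i = alpha * i + x                       # synaptic current trace
    v = beta * v + i                        # membrane integrates the current
    spike = 1.0 if v > threshold else 0.0
    return v - spike * threshold, i, spike  # reset by subtraction

# A single input pulse keeps driving the membrane for several steps
# through the decaying current, unlike a plain LIF.
v = i = 0.0
out = []
for t in range(10):
    x = 1.0 if t == 0 else 0.0
    v, i, s = cuba_lif_step(v, i, x)
    out.append(round(v, 3))
```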
- class spyx.nn.RIF(hidden_shape, threshold=1, activation=superspike(), name='RIF')[source]#
Bases:
haiku.RNNCore
Recurrent Integrate and Fire neuron model.
- initial_state(batch_size)[source]#
Constructs an initial state for this core.
- Parameters:
batch_size – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.
- Returns:
Arbitrarily nested initial state for this core.
- class spyx.nn.RLIF(hidden_shape, beta=None, threshold=1, activation=superspike(), name='RLIF')[source]#
Bases:
haiku.RNNCore
Recurrent LIF Neuron adapted from snnTorch.
https://snntorch.readthedocs.io/en/latest/snn.neurons_rleaky.html
- initial_state(batch_size)[source]#
Constructs an initial state for this core.
- Parameters:
batch_size – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.
- Returns:
Arbitrarily nested initial state for this core.
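In the recurrent variant, the membrane also receives the layer's own previous spikes through a recurrent connection, so one spike raises the drive at the next step. The sketch below uses a scalar recurrent weight for illustration (in the real layer this is a learned matrix over the hidden shape), and the update rule is an illustrative reading of snnTorch's RLeaky, not the spyx source:

```python
def rlif_step(v, prev_spike, x, beta=0.9, w_rec=0.5, threshold=1.0):
    """Recurrent LIF sketch: feed the previous spike back
    through a recurrent weight (scalar here, matrix in practice)."""
    v = beta * v + x + w_rec * prev_spike
    spike = 1.0 if v > threshold else 0.0
    return v - spike * threshold, spike

# Once the neuron spikes, the recurrent feedback keeps it spiking.
v, s = 0.0, 0.0
trace = []
for _ in range(6):
    v, s = rlif_step(v, s, x=0.6)
    trace.append(s)
```

With this drive, input alone (0.6) needs two steps to cross the threshold, but input plus feedback (0.6 + 0.5 = 1.1) exceeds it every step, so the first spike locks in sustained firing.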
- class spyx.nn.RCuBaLIF(hidden_shape, alpha=None, beta=None, activation=superspike(), name='RCuBaLIF')[source]#
Bases:
haiku.RNNCore
Base class for RNN cores.
This class defines the basic functionality that every core should implement: initial_state(), used to construct an example of the core state; and __call__(), which applies the core, parameterized by a previous state, to an input.
Cores may be used with dynamic_unroll() and static_unroll() to iteratively construct an output sequence from the given input sequence.
- __call__(x, VI)[source]#
Run one step of the RNN.
- Parameters:
inputs – An arbitrarily nested structure.
prev_state – Previous core state.
- Returns:
A tuple with two elements, output, next_state. output is an arbitrarily nested structure. next_state is the next core state; it must be the same shape as prev_state.
- initial_state(batch_size)[source]#
Constructs an initial state for this core.
- Parameters:
batch_size – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.
- Returns:
Arbitrarily nested initial state for this core.
- class spyx.nn.ActivityRegularization(name='ActReg')[source]#
Bases:
haiku.Module
Add state to the SNN to track the average number of spikes emitted per neuron per batch.
Adding this to a network requires using Haiku's transform_with_state, which also returns an initial regularization state vector. This blank initial vector can be reused and is provided as the second argument to the SNN's apply function.
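The quantity being tracked is just a per-neuron average spike count over the batch, which a loss term can then penalize. A plain-Python sketch of that bookkeeping (the Haiku state plumbing is omitted; the layout `[batch][neuron]` is an assumption for illustration):

```python
def spike_rate_per_neuron(spike_batches):
    """Average number of spikes emitted per neuron, given per-sample
    spike counts laid out as [batch][neuron]."""
    n_neurons = len(spike_batches[0])
    totals = [0.0] * n_neurons
    for sample in spike_batches:
        for j, count in enumerate(sample):
            totals[j] += count
    return [t / len(spike_batches) for t in totals]

# Two samples, three neurons: per-neuron averages over the batch.
rates = spike_rate_per_neuron([[2, 0, 1], [4, 0, 3]])
```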
- spyx.nn.PopulationCode(num_classes)[source]#
Add population coding to the preceding neuron layer. Preceding layer’s output shape must be a multiple of the number of classes. Use this for rate coded SNNs where the time steps are too few to get a good spike count.
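The idea is to represent each class with a population of neurons rather than a single one, so even a short simulation accumulates a usable spike count per class. A sketch of the decoding step (contiguous grouping of neurons per class is an assumption for illustration, not necessarily spyx's layout):

```python
def population_code(spikes, num_classes):
    """Collapse a layer of size num_classes * k into num_classes
    scores by summing each population of k neurons."""
    k, rem = divmod(len(spikes), num_classes)
    assert rem == 0, "layer size must be a multiple of num_classes"
    return [sum(spikes[c * k:(c + 1) * k]) for c in range(num_classes)]

# Eight neurons, two classes: four neurons vote for each class.
scores = population_code([1, 0, 1, 1, 0, 0, 1, 0], num_classes=2)
```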
- spyx.nn.sum_pool(value: jax.Array, window_shape: int | collections.abc.Sequence[int], strides: int | collections.abc.Sequence[int], padding: str, channel_axis: int | None = -1) jax.Array [source]#
Sum pool.
- Parameters:
value (jax.Array) – Value to pool.
window_shape (Union[int, collections.abc.Sequence[int]]) – Shape of the pooling window, same rank as value.
strides (Union[int, collections.abc.Sequence[int]]) – Strides of the pooling window, same rank as value.
padding (str) – Padding algorithm, either VALID or SAME.
channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped.
- Returns:
Pooled result. Same rank as value.
- Return type:
jax.Array
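For spike tensors, sum pooling reports the total number of spikes in each receptive field, which preserves counts where average pooling would dilute them. A minimal single-channel, VALID-padding sketch of the operation on nested lists (the real function works on jax.Array with the full window/stride/padding semantics):

```python
def sum_pool_2d(x, window, stride):
    """2-D sum pooling with VALID padding over a list-of-lists:
    each output cell is the total of one pooling window."""
    h, w = len(x), len(x[0])
    wh, ww = window
    sh, sw = stride
    out = []
    for i in range(0, h - wh + 1, sh):
        row = []
        for j in range(0, w - ww + 1, sw):
            row.append(sum(x[i + di][j + dj]
                           for di in range(wh) for dj in range(ww)))
        out.append(row)
    return out

# 4x4 spike map, 2x2 windows, stride 2: each cell counts spikes
# in one quadrant.
pooled = sum_pool_2d([[1, 0, 1, 1],
                      [0, 1, 0, 0],
                      [1, 1, 0, 1],
                      [0, 0, 1, 1]], window=(2, 2), stride=(2, 2))
```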
- class spyx.nn.SumPool(window_shape: int | collections.abc.Sequence[int], strides: int | collections.abc.Sequence[int], padding: str, channel_axis: int | None = -1, name: str | None = None)[source]#
Bases:
haiku.Module
Sum pool.
Returns the total number of spikes emitted in a receptive field.
- Parameters:
window_shape (Union[int, collections.abc.Sequence[int]]) – Shape of the pooling window.
strides (Union[int, collections.abc.Sequence[int]]) – Strides of the pooling window.
padding (str) – Padding algorithm, either VALID or SAME.
channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped.
name (Optional[str]) – Name of the module.