spyx.nn
Spiking-neuron layers (IF, LIF, ALIF, CuBaLIF and recurrent variants), the stateful Sequential container, and the time-major run scan helper.
PSU_LIF (documented below) is a reset-free parallel spiking neuron: a pure linear leaky integrator V_t = clip(beta)·V_{t-1} + x_t that exposes both the standard stepwise __call__ and a parallel(x) associative-scan path with O(log T) depth. See the parallel spiking neurons explanation for the sequential-vs-parallel trade-off and the benchmarking how-to to measure it.
ALIF
Bases: Module
Adaptive LIF Neuron based on the model used in LSNNs:
Bellec, G., Salaj, D., Subramoney, A., Legenstein, R. & Maass, W. Long short-term memory and learning-to-learn in networks of spiking neurons. 32nd Conference on Neural Information Processing Systems (2018).
Source code in spyx/nn.py
__call__(x, VT)
:x: Tensor from previous layer.
:VT: Neuron state vector.
__init__(hidden_shape, beta=None, gamma=None, threshold=1, activation=None, *, rngs)
:hidden_shape: Hidden layer shape.
:beta: Membrane decay/inverse time constant.
:gamma: Threshold adaptation constant.
:threshold: Neuron firing threshold.
:activation: spyx.axn.Axon object determining forward function and surrogate gradient function.
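As a rough illustration of what threshold adaptation does, here is a minimal pure-JAX sketch of one LSNN-style step. It is not spyx's implementation: alif_step and adapt_scale are hypothetical names, the update order and the reset by the current effective threshold are assumptions, and the spike uses a hard threshold instead of a surrogate. spyx's ALIF packs its state into a single VT tensor; the sketch keeps the membrane V and the adaptation variable a separate for readability.

```python
import jax.numpy as jnp

# Hypothetical LSNN-style ALIF step; NOT spyx's implementation.
# adapt_scale (how strongly the adaptation variable raises the threshold)
# is an assumed extra parameter used only for illustration.
def alif_step(x, V, a, beta=0.9, gamma=0.95, threshold=1.0, adapt_scale=1.8):
    V = beta * V + x                                # leaky integration
    eff_threshold = threshold + adapt_scale * a     # threshold rises with recent activity
    spikes = (V >= eff_threshold).astype(V.dtype)   # hard threshold (surrogate omitted)
    V = V - spikes * eff_threshold                  # subtractive reset (assumed)
    a = gamma * a + (1.0 - gamma) * spikes          # adaptation decays, bumps on spikes
    return spikes, V, a

V = jnp.zeros(1)
a = jnp.zeros(1)
x = jnp.array([2.0])
s1, V, a = alif_step(x, V, a)   # effective threshold 1.0
s2, V, a = alif_step(x, V, a)   # effective threshold 1.0 + 1.8 * 0.05 = 1.09
```

Both steps spike under this strong drive, but every spike raises a, so the effective threshold climbs and firing gets progressively harder; this slower adaptive variable is what distinguishes ALIF from a plain LIF.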
ActivityRegularization
Bases: Module
Track the cumulative number of spikes emitted per neuron per batch.
The running spike count is threaded through :func:spyx.nn.run (and
:class:Sequential) as part of the scan carry, exactly like a neuron's
membrane state: :meth:initial_state seeds a zero buffer and each
:meth:__call__ returns the incoming spikes unchanged plus the updated
count. The final accumulated count comes back as this layer's entry in the
final_state returned by run, and can be fed to
spyx.fn.silence_reg / spyx.fn.sparsity_reg for activity penalties.
Threading the count through the carry (rather than mutating an
nnx.Variable in place) is what lets it accumulate inside the raw
jax.lax.scan used by :func:spyx.nn.run, where in-place variable
mutation raises TraceContextError.
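The carry-threading pattern can be reproduced with a bare jax.lax.scan. In this sketch (count_spikes is a hypothetical helper, not part of spyx), the step passes spikes through unchanged and adds them to a zero-initialized count carried between timesteps, mirroring the __call__ contract documented below. No nnx.Variable is mutated, so it traces cleanly inside the scan.

```python
import jax
import jax.numpy as jnp

def count_spikes(spikes_seq):
    """Accumulate per-neuron spike counts through a scan carry.

    spikes_seq: time-major array of shape [Time, Batch, Neurons].
    Returns (spikes passed through unchanged, final per-neuron count).
    """
    def step(count, spikes_t):
        # Same shape of contract as ActivityRegularization.__call__:
        # emit the spikes unchanged, carry the updated count forward.
        return count + spikes_t, spikes_t

    init = jnp.zeros(spikes_seq.shape[1:], spikes_seq.dtype)  # zero buffer seeds the carry
    final_count, out = jax.lax.scan(step, init, spikes_seq)
    return out, final_count

spikes = jnp.array([[[1.0, 0.0]],
                    [[1.0, 1.0]],
                    [[0.0, 1.0]]])   # [Time=3, Batch=1, Neurons=2]
out, count = count_spikes(spikes)   # count -> [[2., 2.]]
```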
__call__(spikes, spike_count)
:spikes: Spikes emitted by the previous layer at this timestep.
:spike_count: Running per-neuron spike count carried through the scan.
:return: (spikes, spike_count + spikes) -- the spikes pass through
unchanged while the count accumulates.
__init__(hidden_shape, batch_size=1, dtype=jnp.float32)
:hidden_shape: Per-neuron shape of the layer being regularized.
:batch_size: Leading batch dimension of the spike-count buffer.
:dtype: Storage dtype for the spike-count buffer.
Flatten
Bases: Module
Flatten every non-batch dimension of a per-timestep input.
Stateless: maps x of shape (B, ...) to (B, prod(...)). It has no
initial_state, so :class:Sequential runs it in stateless mode. Used by
:mod:spyx.nir to represent NIR Flatten nodes; flax.nnx has no
built-in flatten layer.
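The per-timestep reshape that Flatten performs amounts to a one-liner in plain jax.numpy (flatten is a hypothetical name used only for this sketch):

```python
import jax.numpy as jnp

def flatten(x):
    # Keep the leading batch axis and collapse every remaining axis into one.
    return x.reshape(x.shape[0], -1)

flat = flatten(jnp.zeros((2, 3, 4, 5)))   # (B=2, 3*4*5) -> shape (2, 60)
```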
IF
Bases: Module
Integrate and Fire neuron model.
__call__(x, V)
:x: Vector coming from previous layer.
:V: Neuron state tensor.
__init__(hidden_shape, threshold=1, activation=None)
:hidden_shape: Shape of the layer.
:threshold: threshold for reset. Defaults to 1.
:activation: spyx.activation function.
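A minimal sketch of one IF step in plain jax.numpy. if_step is hypothetical; it assumes the spike is read from the carried membrane and that the reset subtracts the threshold (mirroring the reset described for LIF in the PSU_LIF section), and it uses a hard Heaviside in place of the activation's surrogate.

```python
import jax.numpy as jnp

def if_step(x, V, threshold=1.0):
    spikes = (V > threshold).astype(V.dtype)   # fire if the carried membrane exceeds threshold
    V = V + x - spikes * threshold             # integrate without leak, reset by subtraction
    return spikes, V

V = jnp.zeros(1)
history = []
for _ in range(4):
    s, V = if_step(jnp.array([0.4]), V)
    history.append(float(s[0]))
# history -> [0.0, 0.0, 0.0, 1.0]
```

With a constant input of 0.4 the membrane climbs 0.4, 0.8, 1.2 without leaking; the spike appears on the fourth call, and the reset leaves about 0.6 behind.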
LI
Bases: Module
Leaky-Integrate (Non-spiking) neuron model.
__call__(x, Vin)
__init__(layer_shape, beta=None, *, rngs)
:layer_shape: Shape of the layer.
:beta: Decay rate on membrane potential (voltage).
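Because LI never spikes, its output is just the leaky membrane trace, which suits it to use as a readout layer. A minimal sketch with jax.lax.scan (li_step and beta=0.8 are illustrative assumptions, not spyx's code) shows the geometric decay after a single input impulse:

```python
import jax
import jax.numpy as jnp

def li_step(V, x_t, beta=0.8):
    V = beta * V + x_t     # leaky integration, no threshold, no reset
    return V, V            # (new carry, per-step output = membrane itself)

impulse = jnp.array([1.0, 0.0, 0.0, 0.0])
_, trace = jax.lax.scan(li_step, jnp.zeros(()), impulse)
# trace -> [1.0, 0.8, 0.64, 0.512]
```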
LIF
Bases: Module
Leaky Integrate and Fire neuron model.
__call__(x, V)
:x: input vector coming from previous layer.
:V: neuron state tensor.
__init__(hidden_shape, beta=None, threshold=1.0, activation=None, *, rngs)
:hidden_shape: Shape of the layer.
:beta: decay rate.
:threshold: threshold for reset. Defaults to 1.
:activation: spyx.axn.Axon object.
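A minimal sketch of the LIF update scanned over a constant input. lif_step and scan_lif are hypothetical; the spike-then-integrate order and the subtractive reset spikes * threshold are assumptions consistent with the reset described in the PSU_LIF section, and a hard threshold stands in for the Axon surrogate.

```python
import jax
import jax.numpy as jnp

def lif_step(x, V, beta=0.9, threshold=1.0):
    spikes = (V > threshold).astype(V.dtype)   # spike on the incoming membrane (Heaviside forward)
    V = beta * V + x - spikes * threshold      # leak + integrate + subtractive reset
    return spikes, V

def scan_lif(xs, V0):
    def body(V, x_t):
        s, V = lif_step(x_t, V)
        return V, s                            # lax.scan wants (carry, per-step output)
    return jax.lax.scan(body, V0, xs)

final_V, spikes = scan_lif(jnp.full((6, 1), 0.6), jnp.zeros(1))
# spikes[:, 0] -> [0., 0., 1., 0., 1., 0.]
```

The reset is what couples each step to the previous spike: after every spike the membrane is pulled back below threshold, so under this drive the neuron fires only every other step.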
PSU_LIF
Bases: Module
Parallel Spiking Unit LIF: a reset-free leaky integrate-and-fire neuron.
.. note::
Experimental. Its supported entry point is
:class:spyx.experimental.PSU_LIF; the API may change without a
deprecation cycle. It is defined here for locality with the other neurons.
A standard :class:LIF subtracts a reset term spikes * threshold from the
membrane at every step, which couples each timestep to the (nonlinear) spike
of the previous step and forces a strictly sequential O(T) scan.
Dropping the reset turns the membrane into a pure linear leaky integrator,
.. math:: V_t = \beta \, V_{t-1} + x_t ,
which is a first-order associative recurrence and can therefore be
evaluated with :func:jax.lax.associative_scan in O(\log T) parallel
depth on an accelerator. Spikes are a pointwise surrogate threshold applied
to the whole membrane trace, :math:s_t = \sigma(V_t - \text{threshold}).
Removing the reset is a deliberate accuracy/parallelism trade-off: the membrane is never pulled down after a spike, so the neuron keeps firing on consecutive steps for as long as its membrane stays above threshold, and only a well-tuned integration window (the decay beta) keeps activity bounded. In exchange, the whole sequence can be scored in logarithmic rather than linear depth.
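To make the trade-off concrete, this pure-Python sketch (simulate is a hypothetical helper) drives a neuron with a constant input of 0.6 at beta = 0.9, once reset-free as in PSU_LIF and once with a subtractive reset for contrast. Both variants integrate before spiking, as PSU_LIF's __call__ does, and use a hard threshold.

```python
def simulate(inputs, beta=0.9, threshold=1.0, reset=False):
    V, spikes = 0.0, []
    for x in inputs:
        V = beta * V + x                  # leaky integration
        s = 1 if V >= threshold else 0    # integrate first, then spike
        if reset:
            V -= s * threshold            # LIF-style subtractive reset
        spikes.append(s)
    return spikes

reset_free = simulate([0.6] * 6)              # -> [0, 1, 1, 1, 1, 1]
with_reset = simulate([0.6] * 6, reset=True)  # -> [0, 1, 0, 1, 0, 1]
```

Once the reset-free membrane crosses threshold it stays above it and the neuron fires on every step, whereas the reset variant alternates.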
Two execution modes are provided and are numerically identical:
- :meth:__call__ -- one reset-free timestep (x, V) -> (spikes, V) with V = beta * V + x; a drop-in for :func:spyx.nn.run, :class:Sequential, and NIR, exactly like :class:LIF.
- :meth:parallel -- the whole time-major sequence at once via an associative scan over the leak, with O(\log T) depth.
Because both modes use the same clipped beta and the same surrogate,
and :meth:__call__ integrates the input before spiking, scanning
:meth:__call__ over x reproduces :meth:parallel exactly.
__call__(x, V)
One reset-free timestep.
:x: input vector coming from previous layer.
:V: neuron state tensor.
Integrates the input into the membrane (V = beta * V + x, no
reset), then emits a surrogate spike on the updated membrane so that
scanning this method matches :meth:parallel exactly.
__init__(hidden_shape, beta=None, threshold=1.0, activation=None, *, rngs)
:hidden_shape: Shape of the layer.
:beta: decay rate. If provided, a scalar decay; otherwise a learnable per-unit decay is initialized.
:threshold: firing threshold. Defaults to 1.
:activation: spyx.axn.Axon object determining the surrogate spike.
parallel(x)
Score a whole time-major sequence with an associative scan.
:x: input with shape [Time, Batch, ...].
:return: spikes with shape [Time, Batch, ...].
Computes the full membrane trace V_t = beta * V_{t-1} + x_t (with
V_{-1} = 0) via :func:jax.lax.associative_scan over the time axis
in O(\log T) depth, then applies the surrogate spike pointwise.
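The associative-scan idea can be sketched in plain JAX, independent of spyx. Each timestep is treated as the affine map V -> beta * V + x_t; composing two such maps is associative, so jax.lax.associative_scan can evaluate every prefix in logarithmic depth. The helper names are hypothetical, the [0, 1] clip range for beta is an assumption, and the spike is a hard threshold in place of the surrogate. The sketch checks the result against an ordinary jax.lax.scan; in float32 the two traces agree to within rounding, since the parallel version multiplies the decays in a different order.

```python
import jax
import jax.numpy as jnp

def _combine(left, right):
    # Compose two affine maps V -> a*V + b; `left` covers the earlier timesteps.
    a_l, b_l = left
    a_r, b_r = right
    return a_l * a_r, a_r * b_l + b_r

def membrane_parallel(x, beta):
    """V_t = beta * V_{t-1} + x_t with V_{-1} = 0, via associative scan over axis 0."""
    decay = jnp.broadcast_to(beta, x.shape).astype(x.dtype)
    _, V = jax.lax.associative_scan(_combine, (decay, x), axis=0)
    return V

def membrane_sequential(x, beta):
    """Same recurrence with an O(T) jax.lax.scan, for comparison."""
    def step(V, x_t):
        V = beta * V + x_t
        return V, V
    _, V = jax.lax.scan(step, jnp.zeros_like(x[0]), x)
    return V

beta = jnp.clip(jnp.float32(0.9), 0.0, 1.0)                # clip range [0, 1] is an assumption
x = jax.random.normal(jax.random.PRNGKey(0), (32, 4, 8))   # [Time, Batch, Units]
V_par = membrane_parallel(x, beta)
V_seq = membrane_sequential(x, beta)
spikes = (V_par >= 1.0).astype(x.dtype)                    # Heaviside forward pass; surrogate omitted
```

The composition rule in _combine carries the whole argument: applying (a_l, b_l) and then (a_r, b_r) gives V -> a_r * (a_l * V + b_l) + b_r, which is the pair (a_l * a_r, a_r * b_l + b_r).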
RIF
Bases: Module
Recurrent Integrate and Fire neuron model.
__call__(x, V)
:x: Vector coming from previous layer.
:V: Neuron state tensor.
RLIF
Bases: Module
Recurrent LIF Neuron.
__call__(x, V)
:x: The input data/latent vector from another layer.
:V: The state tensor.
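One way a recurrent LIF step can be written, as a hedged sketch: the spikes read from the carried membrane are fed back through a recurrent weight matrix W_rec and added to the input current. rlif_step and W_rec are hypothetical, and the actual spyx layer may parameterize or order the feedback differently.

```python
import jax.numpy as jnp

def rlif_step(x, V, W_rec, beta=0.9, threshold=1.0):
    spikes = (V > threshold).astype(V.dtype)   # spikes from the carried membrane
    feedback = spikes @ W_rec                  # recurrent input from this layer's own spikes
    V = beta * V + x + feedback - spikes * threshold
    return spikes, V

W_rec = jnp.array([[0.0, 0.5],
                   [0.0, 0.0]])   # unit 0 excites unit 1 (hypothetical weights)
V0 = jnp.array([[1.5, 0.0]])      # [Batch=1, Units=2]
spikes, V1 = rlif_step(jnp.zeros((1, 2)), V0, W_rec)
# spikes -> [[1., 0.]], V1 -> [[0.35, 0.5]]
```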
Sequential
Bases: Sequential
A Sequential container that supports passing state through its layers.
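A minimal sketch of the state threading such a container performs. sequential_step and the stand-in layers are hypothetical, and using None to mark a stateless layer is an assumption of this sketch, loosely mirroring how Flatten runs without an initial_state.

```python
def sequential_step(layers, x, states):
    """Apply layers in order, threading one state entry per layer."""
    new_states = []
    for layer, state in zip(layers, states):
        if state is None:           # stateless layer (e.g. Flatten): x -> out
            x = layer(x)
        else:                       # stateful layer: (x, state) -> (out, new_state)
            x, state = layer(x, state)
        new_states.append(state)
    return x, tuple(new_states)

def double(x):                      # stand-in stateless layer
    return 2 * x

def accumulate(x, s):               # stand-in stateful layer
    return x + s, s + 1

out, states = sequential_step([double, accumulate], 3, (None, 10))
# out -> 16, states -> (None, 11)
```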
SumPool
Bases: Module
Sum pool.
run(model, x, state=None)
Execute a model over a sequence of inputs using jax.lax.scan.
:model: A stateful Flax NNX Module, typically :class:Sequential or a
Spyx neuron. It must either take (x_t, state) -> (out, next_state)
or expose an initial_state(batch_size) method (or both). Plain
stateless modules like nnx.Linear don't fit the contract; wrap
them in a :class:Sequential with at least one stateful layer, or
use jax.vmap if you just want to apply the module per timestep.
:x: Input data with shape [Time, Batch, ...].
:state: Initial state for the model. If None, model.initial_state is consulted; if the model has no initial_state and no state is supplied explicitly, a clear error is raised.
:return: (outputs, final_state)
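The scan contract can be sketched with jax.lax.scan directly (run_sketch is hypothetical, not spyx's code). Note the argument-order adapter: Spyx steps take (x_t, state), while the body passed to lax.scan takes (carry, x_t) and must return (new_carry, per-step output).

```python
import jax
import jax.numpy as jnp

def run_sketch(step, x, state):
    """Scan step(x_t, state) -> (out, next_state) over the leading time axis of x."""
    def body(carry, x_t):
        out, carry = step(x_t, carry)
        return carry, out                 # swap into lax.scan's (carry, output) order
    final_state, outputs = jax.lax.scan(body, state, x)
    return outputs, final_state           # same (outputs, final_state) order as run

def cumulative(x_t, V):                   # toy stateful step: running sum
    V = V + x_t
    return V, V

outputs, final = run_sketch(cumulative, jnp.ones((4, 2)), jnp.zeros(2))
# outputs[:, 0] -> [1., 2., 3., 4.], final -> [4., 4.]
```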
sum_pool(value, window_shape, strides, padding, channel_axis=-1)
Sum pool.
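A sum pool can be expressed with jax.lax.reduce_window and jax.lax.add. This sketch (sum_pool_sketch is hypothetical) assumes a channel-last (N, H, W, C) layout, i.e. channel_axis=-1, and does not pool over the batch or channel axes; it is not necessarily how spyx implements sum_pool.

```python
import jax
import jax.numpy as jnp

def sum_pool_sketch(x, window_shape, strides, padding="VALID"):
    """Sum over sliding spatial windows of a channel-last (N, H, W, C) array."""
    window = (1, *window_shape, 1)   # window size 1 on batch and channel axes
    stride = (1, *strides, 1)
    return jax.lax.reduce_window(x, 0.0, jax.lax.add, window, stride, padding)

x = jnp.arange(16.0).reshape(1, 4, 4, 1)
pooled = sum_pool_sketch(x, (2, 2), (2, 2))
# pooled[0, :, :, 0] -> [[10., 18.], [42., 50.]]
```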