
How to use SSM and phasor layers

To model long-range temporal structure beyond what LIF dynamics capture, combine your spiking stack with the sequence layers in spyx.ssm (LRU, S5Diag, Mamba, MambaBlock, ChunkedSSM) or the complex-valued layers in spyx.phasor.

The key contract difference: Spyx neurons are stepwise ((x_t, state) -> (out, state), driven by spyx.nn.run), while SSM layers are whole-sequence ((T, B, d_model) -> (T, B, d_model), internally parallelised with jax.lax.associative_scan). You compose them by running the spiking front-end first, then feeding its spike train to the SSM.
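
To make the two contracts concrete, here is a minimal pure-JAX sketch of the same diagonal linear recurrence, h_t = a_t * h_{t-1} + b_t, computed once step by step and once over the whole sequence. The function names `stepwise` and `whole_sequence` are illustrative and not part of Spyx, and the recurrence is a simplified stand-in for what an SSM layer computes internally.

```python
import jax
import jax.numpy as jnp


def stepwise(a, b):
    """Stepwise contract: carry a state through time, one update per step."""
    def step(h, inputs):
        a_t, b_t = inputs
        h = a_t * h + b_t
        return h, h

    h0 = jnp.zeros_like(b[0])
    _, hs = jax.lax.scan(step, h0, (a, b))
    return hs


def whole_sequence(a, b):
    """Whole-sequence contract: all T steps at once via a parallel scan."""
    def combine(left, right):
        # Each element (a, b) stands for the affine map h -> a * h + b.
        # Applying `left` first and then `right` gives another affine map.
        a_l, b_l = left
        a_r, b_r = right
        return a_r * a_l, a_r * b_l + b_r

    _, hs = jax.lax.associative_scan(combine, (a, b))
    return hs


T, B, D = 16, 2, 4
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(key_a, (T, B, D), minval=0.5, maxval=0.99)  # decay per step
b = jax.random.normal(key_b, (T, B, D))                            # driven input

hs_step = stepwise(a, b)        # (T, B, D), sequential in T
hs_par = whole_sequence(a, b)   # (T, B, D), parallel over T
```

Both functions return the same hidden states. The parallel form works because composing affine maps is associative, which is the property `jax.lax.associative_scan` requires of its combine function.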

Use a single SSM layer

import jax.numpy as jnp
from flax import nnx
from spyx import ssm

rngs = nnx.Rngs(0)
layer = ssm.LRU(d_model=8, d_state=64, rngs=rngs)   # Linear Recurrent Unit

u = jnp.ones((128, 32, 8))     # time-major (T, B, d_model)
y = layer(u)                   # same shape, real-valued

The available layers (all time-major, all trainable with a stock optax.adam + nnx.Optimizer loop):

| Layer | What it is |
| --- | --- |
| ssm.LRU(d_model, d_state, ...) | Linear Recurrent Unit (Orvieto et al., 2023); stability enforced by construction. |
| ssm.S5Diag(d_model, d_state, ...) | Diagonal S4D/S5 with HiPPO-LegS init and learnable log-step; best for long-range tasks. |
| ssm.Mamba(d_inner, d_state, ...) | Selective SSM core (input-dependent Δ, B, C). |
| ssm.MambaBlock(d_model, d_state, d_conv, expand, ...) | Full Mamba block: in-proj → depthwise conv → SSM → gate → out-proj. |
| ssm.ChunkedSSM(inner, outer, chunk_size=..., pool=...) | H-Net-style hierarchy: an inner SSM per timestep plus an outer SSM over chunk summaries. |
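
As an illustration of what "stability enforced by construction" means for an LRU, the sketch below uses the exponential parameterisation described by Orvieto et al. (2023): each diagonal eigenvalue is λ = exp(-exp(ν) + i·exp(θ)), so its magnitude exp(-exp(ν)) stays below 1 for any real ν. The names `nu_log`, `theta_log`, and `lru_eigenvalues` are assumptions made for this example; how ssm.LRU names and stores its parameters may differ.

```python
import jax
import jax.numpy as jnp


def lru_eigenvalues(nu_log, theta_log):
    """Map unconstrained reals to complex eigenvalues inside the unit disk.

    |lambda| = exp(-exp(nu_log)), which is below 1 for every real nu_log.
    """
    return jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))


k_nu, k_theta = jax.random.split(jax.random.PRNGKey(0))
nu_log = jax.random.normal(k_nu, (64,))        # arbitrary unconstrained values
theta_log = jax.random.normal(k_theta, (64,))  # arbitrary unconstrained values

lam = lru_eigenvalues(nu_log, theta_log)
magnitudes = jnp.abs(lam)  # every entry stays below 1, so h_t = lam * h_{t-1} + ... cannot blow up
```

Because the constraint lives in the parameterisation rather than in a projection or clipping step, an optimizer can update `nu_log` and `theta_log` freely without pushing the recurrence into an unstable regime.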

ChunkedSSM wraps any two (T, B, D) -> (T, B, D) modules, so the inner/outer pair can mix layer types:

inner = ssm.MambaBlock(d_model=8, d_state=4, rngs=rngs)
outer = ssm.LRU(d_model=8, d_state=8, rngs=rngs)
hnet = ssm.ChunkedSSM(inner, outer, chunk_size=4, pool="mean")  # T must be divisible by chunk_size
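
The sketch below shows one way such a two-level hierarchy could be wired in plain JAX. It is not the ChunkedSSM implementation: the function `chunked_apply`, the choice to pool the inner output rather than the raw input, and the additive merge of the two levels are all assumptions made for illustration. Only mean pooling over chunks and the requirement that T be divisible by chunk_size come from the description above.

```python
import jax.numpy as jnp


def chunked_apply(inner, outer, x, chunk_size):
    """Hypothetical two-level pass over a time-major (T, B, D) sequence."""
    T, B, D = x.shape
    if T % chunk_size != 0:
        raise ValueError("T must be divisible by chunk_size")

    fine = inner(x)                                         # (T, B, D) per-timestep features
    chunks = fine.reshape(T // chunk_size, chunk_size, B, D)
    summaries = chunks.mean(axis=1)                         # (T / chunk_size, B, D) mean pooling
    coarse = outer(summaries)                               # (T / chunk_size, B, D)
    upsampled = jnp.repeat(coarse, chunk_size, axis=0)      # back to (T, B, D)
    return fine + upsampled                                 # assumed merge: simple sum


# Stand-ins for the two sequence modules, chosen so the result is easy to check.
inner = lambda x: x                        # identity
outer = lambda s: jnp.cumsum(s, axis=0)    # running sum over chunk summaries

x = jnp.ones((8, 1, 2))
out = chunked_apply(inner, outer, x, chunk_size=4)  # (8, 1, 2)
```

With these stand-ins, the first chunk receives 1 + 1 = 2 and the second receives 1 + 2 = 3, which shows the outer module operating at one step per chunk while the inner module keeps full time resolution.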

Build a hybrid SNN → SSM stack

To add an SSM on top of a spiking front-end (the Linear → LIF → LRU pattern exercised in tests/test_ssm.py), run the spiking layers with spyx.nn.run, then apply the SSM to the resulting spike train:

import jax.numpy as jnp
from flax import nnx

import spyx
import spyx.nn as snn
from spyx import ssm

rngs = nnx.Rngs(0)
snn_front = snn.Sequential(
    nnx.Linear(4, 8, use_bias=False, rngs=rngs),
    snn.LIF((8,), activation=spyx.axn.triangular(), rngs=rngs),
)
ssm_layer = ssm.LRU(d_model=8, d_state=16, rngs=rngs)
readout = nnx.Linear(8, 3, use_bias=False, rngs=rngs)

T, B = 128, 32
u = jnp.ones((T, B, 4))                      # time-major input
spikes, _ = snn.run(snn_front, u)            # (T, B, 8) binary spikes
h = ssm_layer(spikes)                        # (T, B, 8) real features
logits = readout(h.sum(axis=0))              # (B, 3)

Gradients flow through the whole pipeline, so the stack trains end-to-end with the usual nnx.value_and_grad step — see How to train a model. Wrap the three stages in one nnx.Module if you want a single trainable object.

To quantize the nnx.Linear layers around an SSM (the SSM's own B/C projections are raw params and stay fp32), use spyx.quant; scripts/ssm_demo.py demonstrates both int8 and BitNet-ternary variants.
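
For readers unfamiliar with the two schemes, the following pure-JAX sketch shows what int8 and ternary weight quantization commonly look like during training: weights are rounded in the forward pass and a straight-through estimator passes gradients to the full-precision copy. This is a generic illustration, not the spyx.quant API; the function names, the per-tensor scales, and the absmean ternary rule (in the style of BitNet b1.58) are assumptions.

```python
import jax
import jax.numpy as jnp


def fake_quant_int8(w):
    """Symmetric per-tensor int8 rounding with a straight-through gradient."""
    scale = jnp.max(jnp.abs(w)) / 127.0 + 1e-8
    w_q = jnp.clip(jnp.round(w / scale), -127, 127) * scale
    # Forward pass uses w_q; backward pass treats the rounding as identity.
    return w + jax.lax.stop_gradient(w_q - w)


def fake_quant_ternary(w):
    """Absmean ternary rounding to {-1, 0, +1} * scale, straight-through gradient."""
    scale = jnp.mean(jnp.abs(w)) + 1e-8
    w_q = jnp.clip(jnp.round(w / scale), -1, 1) * scale
    return w + jax.lax.stop_gradient(w_q - w)


w = jax.random.normal(jax.random.PRNGKey(0), (4, 8))  # stand-in for a Linear kernel
w_int8 = fake_quant_int8(w)
w_tern = fake_quant_ternary(w)

# Gradients reach the full-precision weights despite the rounding.
grad = jax.grad(lambda w_: fake_quant_ternary(w_).sum())(w)
```

Applying this only to the dense kernels around a sequence layer leaves the recurrence itself in full precision, which matches the split described above.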

Use phasor layers

Phasor networks (Bybee, Frady & Sommer, 2022) represent activations as unit-magnitude complex numbers and convert to single-spike-per-cycle trains at inference. Train in the continuous complex domain:

import jax.numpy as jnp
from flax import nnx
from spyx import phasor

rngs = nnx.Rngs(0)

model = phasor.PhasorMLP(
    in_features=8, hidden_features=16, out_features=4, depth=2, rngs=rngs,
)
x = jnp.ones((32, 8))          # real inputs in [0, 1]
logits = model(x)              # real (32, 4) — trains with stock optax.adam

PhasorLinear stores its complex kernel as paired kernel_re / kernel_im float32 params, so gradients stay real and a plain optax.adam + nnx.Optimizer loop converges without Wirtinger-gradient surprises.
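
The idea of splitting a complex kernel into two real arrays can be sketched in plain JAX as follows. This is an illustration of the technique rather than the PhasorLinear code: the helper names, the initialisation scale, and the toy loss are assumptions; only the `kernel_re` / `kernel_im` naming comes from the text above.

```python
import jax
import jax.numpy as jnp


def init_params(key, in_features, out_features):
    """Two real float32 arrays that together represent one complex kernel."""
    k_re, k_im = jax.random.split(key)
    scale = 1.0 / jnp.sqrt(in_features)
    return {
        "kernel_re": scale * jax.random.normal(k_re, (in_features, out_features)),
        "kernel_im": scale * jax.random.normal(k_im, (in_features, out_features)),
    }


def complex_linear(params, theta):
    """Multiply unit phasors exp(i * theta) by the kernel using real arithmetic only."""
    x_re, x_im = jnp.cos(theta), jnp.sin(theta)
    # (x_re + i x_im) @ (W_re + i W_im), expanded into real and imaginary parts.
    y_re = x_re @ params["kernel_re"] - x_im @ params["kernel_im"]
    y_im = x_re @ params["kernel_im"] + x_im @ params["kernel_re"]
    return y_re, y_im


def loss_fn(params, theta):
    y_re, y_im = complex_linear(params, theta)
    return jnp.mean(y_re**2 + y_im**2)  # toy real-valued loss


params = init_params(jax.random.PRNGKey(0), in_features=8, out_features=4)
theta = jnp.linspace(-3.0, 3.0, 32 * 8).reshape(32, 8)  # (B, features) phases
grads = jax.grad(loss_fn)(params, theta)                # real gradients, same tree as params
```

Since every parameter and every gradient leaf is an ordinary float32 array, a standard real-valued optimizer can update them without any special handling of complex conjugates.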

To run a trained phasor layer in the spike domain, wrap it in SpikingPhasor, which decodes phases from an incoming spike train, applies the layer, and re-emits spikes:

linear = phasor.PhasorLinear(16, 16, rngs=rngs)
spiking = phasor.SpikingPhasor(linear, period_T=32)

theta = jnp.zeros((32, 16))                       # (B, features) phases
spikes_in = phasor.phase_to_spikes(theta, T=32)   # (T, B, features)
spikes_out = spiking(spikes_in)                   # (T, B, features)

The codec helpers phase_to_spikes / spikes_to_phase convert between phases in (-π, π] and one-spike-per-cycle trains; round-trip error is bounded by the bin size 2π / T. See scripts/phasor_demo.py and the Phasor Networks notebook for end-to-end examples, and the State-Space Models notebook for SSM training runs.
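
To show where the 2π / T bound comes from, here is a small stand-alone codec written in plain JAX. The functions `encode_phase` and `decode_phase` are hypothetical stand-ins, not the Spyx helpers; the binning rule and the choice to decode to the bin centre are assumptions.

```python
import jax
import jax.numpy as jnp


def encode_phase(theta, T):
    """Phases of shape (B, F) -> spike train (T, B, F) with one spike per cycle."""
    bin_width = 2 * jnp.pi / T
    idx = jnp.floor((theta + jnp.pi) / bin_width).astype(jnp.int32) % T
    # one_hot puts the time axis last; move it to the front for time-major layout.
    return jnp.moveaxis(jax.nn.one_hot(idx, T), -1, 0)


def decode_phase(spikes):
    """Spike train (T, B, F) -> phases (B, F), read from the spike's time bin."""
    T = spikes.shape[0]
    idx = jnp.argmax(spikes, axis=0)
    return -jnp.pi + (idx + 0.5) * (2 * jnp.pi / T)  # centre of the bin


T = 32
theta = jax.random.uniform(
    jax.random.PRNGKey(0), (4, 16), minval=-jnp.pi, maxval=jnp.pi
)
spikes = encode_phase(theta, T)      # (32, 4, 16)
theta_hat = decode_phase(spikes)     # (4, 16)

# Circular distance between original and recovered phases.
err = jnp.abs(jnp.angle(jnp.exp(1j * (theta_hat - theta))))
```

The phase is quantized to one of T time bins, so the recovered value can differ from the original by at most one bin width; increasing T tightens the bound at the cost of a longer spike train.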