State-Space Models in Spyx
This tutorial introduces the spyx.ssm module with runnable examples for each layer type:
- LRU (Linear Recurrent Unit) on a copy task: the simplest diagonal-complex SSM.
- S5Diag with HiPPO-LegS initialisation: long-range-friendly init, learnable step size.
- MambaBlock: the selective-SSM block with input-dependent dynamics.
- Hybrid SNN + SSM stack: drop an SSM between spiking layers via spyx.nn.Sequential.
- Quantized SSM via spyx.quant.quantize (int8 and BitNet-ternary).
All SSMs here use jax.lax.associative_scan for O(log T) parallel depth on GPU / TPU — no custom CUDA kernels.
import jax
import jax.numpy as jnp
import optax
from flax import nnx
import spyx
import spyx.nn as snn
from spyx import ssm
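As a rough illustration of the associative-scan idea mentioned above, the sketch below evaluates a diagonal linear recurrence h_t = λ_t · h_{t-1} + b_t with jax.lax.associative_scan and checks it against a plain Python loop. It is not taken from spyx.ssm; the names combine, lam and b are made up for this example.

```python
import jax
import jax.numpy as jnp


def combine(left, right):
    # Compose two affine maps h -> a * h + b; `left` is the earlier step.
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r


T, N = 16, 4
lam = jnp.full((T, N), 0.9)                            # per-step decay (constant here)
b = jax.random.normal(jax.random.PRNGKey(0), (T, N))   # per-step input

# Inclusive prefix composition; with h = 0 before the first step,
# the second component is exactly h_t.
_, h_scan = jax.lax.associative_scan(combine, (lam, b))

# Reference: the same recurrence as a sequential loop.
h = jnp.zeros(N)
h_loop = []
for t in range(T):
    h = lam[t] * h + b[t]
    h_loop.append(h)
h_loop = jnp.stack(h_loop)

max_err = float(jnp.max(jnp.abs(h_scan - h_loop)))
print(f"max |scan - loop|: {max_err:.2e}")
```

Because combine is associative, the scan is free to group the T steps into a tree, which is where the O(log T) parallel depth comes from.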
1. LRU on a copy task
The Linear Recurrent Unit (Orvieto et al., 2023) is the simplest member of the family: a diagonal complex decay λ = exp(-exp(ν) + iθ) with real skip. Stability is built into the parameterisation so you don't have to clip weights.
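The stability claim can be checked numerically. This is a minimal sketch of the parameterisation as written above, not the code inside ssm.LRU; the sampling ranges for nu and theta are arbitrary choices for the example.

```python
import jax
import jax.numpy as jnp

key_nu, key_theta = jax.random.split(jax.random.PRNGKey(0))
# Arbitrary ranges for illustration; nu is an unconstrained real parameter.
nu = jax.random.uniform(key_nu, (1000,), minval=-5.0, maxval=3.0)
theta = jax.random.uniform(key_theta, (1000,), maxval=2 * jnp.pi)

lam = jnp.exp(-jnp.exp(nu) + 1j * theta)   # complex diagonal decay
mag = jnp.abs(lam)

# |lam| = exp(-exp(nu)); since exp(nu) > 0 the magnitude stays below 1,
# whatever value the optimiser pushes nu to.
print("all |lam| < 1:", bool(jnp.all(mag < 1.0)))
print("matches exp(-exp(nu)):", bool(jnp.allclose(mag, jnp.exp(-jnp.exp(nu)), atol=1e-6)))
```

The phase theta only rotates λ in the complex plane, so it has no effect on stability.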
We'll train it to reproduce its input — the canonical trivial sequence task. Loss should drop from ~1 to ~0 within a few dozen steps.
rngs = nnx.Rngs(0)
lru = ssm.LRU(d_model=4, d_state=16, rngs=rngs)
optimizer = nnx.Optimizer(lru, optax.adam(5e-3), wrt=nnx.Param)
T, B = 32, 16
u = jax.random.normal(jax.random.PRNGKey(1), (T, B, 4))
@nnx.jit
def train_step(model, optimizer, u):
    def loss_fn(m):
        return jnp.mean((m(u) - u) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    return loss
losses = [float(train_step(lru, optimizer, u)) for _ in range(150)]
print(f"initial MSE: {losses[0]:.4f}")
print(f"final MSE: {losses[-1]:.4f}")
# Confirm |λ| < 1 (stability).
lam, _, _ = lru._complex_matrices()
print(f"|λ| range: [{float(jnp.abs(lam).min()):.3f}, {float(jnp.abs(lam).max()):.3f}]")
initial MSE: 1.4408
final MSE: 0.0020
|λ| range: [0.203, 0.847]
2. S5Diag with HiPPO-LegS
S5Diag uses the HiPPO-LegS eigenvalues λ_n = -½ + i·π·n as continuous-time priors and learns a log-step log_dt that controls the effective decay. HiPPO-based initialisation is the recipe the S4 and S5 papers rely on for long-range sequence tasks.
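To make the role of the step size concrete, here is a small sketch of zero-order-hold (ZOH) discretisation applied to the eigenvalues above. It is an illustration under stated assumptions, not the S5Diag source: the per-state step sizes and their log-spaced range of 1e-3 to 1e-1 are assumed for the example.

```python
import jax.numpy as jnp

d_state = 32
n = jnp.arange(d_state)
lam_ct = -0.5 + 1j * jnp.pi * n          # continuous-time eigenvalues from the text

# Assumption: one step size per state, log-spaced between 1e-3 and 1e-1.
log_dt = jnp.linspace(jnp.log(1e-3), jnp.log(1e-1), d_state)
dt = jnp.exp(log_dt)                     # exp keeps the learned step positive

lam_bar = jnp.exp(lam_ct * dt)           # ZOH: discrete-time eigenvalues
b_scale = (lam_bar - 1.0) / lam_ct       # ZOH: factor applied to the continuous-time B

# Re(lam_ct) = -1/2, so |lam_bar| = exp(-dt / 2) < 1 for any dt > 0.
mag = jnp.abs(lam_bar)
print(f"|lam_bar| range: [{float(mag.min()):.4f}, {float(mag.max()):.4f}]")
```

A smaller dt pushes |λ| toward 1 and lengthens the memory of that state; a larger dt makes it forget faster.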
s5 = ssm.S5Diag(d_model=8, d_state=32, rngs=nnx.Rngs(2))
lam, _, _ = s5._complex_matrices()
print("HiPPO eigenvalues after ZOH discretisation:")
print(f" |λ| range: [{float(jnp.abs(lam).min()):.4f}, {float(jnp.abs(lam).max()):.4f}]")
print(f" all |λ| < 1 (stable): {bool(jnp.all(jnp.abs(lam) < 1.0))}")
# Forward pass on a random sequence.
u_test = jax.random.normal(jax.random.PRNGKey(3), (64, 2, 8))
y = s5(u_test)
print(f"\nforward shape: {y.shape} dtype={y.dtype}")
HiPPO eigenvalues after ZOH discretisation:
  |λ| range: [0.9523, 0.9995]
  all |λ| < 1 (stable): True

forward shape: (64, 2, 8) dtype=float32
3. MambaBlock on the copy task
MambaBlock is the full published Mamba block: in-projection, depthwise 1D conv, SiLU, the selective SSM, SiLU(z) gate, out-projection. The selective SSM uses input-dependent (Δ, B, C) with a learned diagonal A = -exp(A_log) — the portable JAX fallback for the selective_scan_cuda kernel in the reference PyTorch repo.
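The selective part can be sketched in a few lines. What follows is a simplified, unbatched illustration of an input-dependent (Δ, B, C) recurrence with A = -exp(A_log), written with jax.lax.scan; it is not the spyx implementation. The projection weights W_dt, W_B, W_C and the A_log initialisation are hypothetical stand-ins, and B is discretised with the simple Δ·B rule.

```python
import jax
import jax.numpy as jnp


def selective_scan(x, delta, A, B, C):
    """x, delta: (T, D); A: (D, N); B, C: (T, N). Returns y: (T, D)."""
    A_bar = jnp.exp(delta[:, :, None] * A)                   # (T, D, N), values in (0, 1)
    Bx = delta[:, :, None] * B[:, None, :] * x[:, :, None]   # (T, D, N)

    def step(h, inputs):
        a_t, bx_t, c_t = inputs
        h = a_t * h + bx_t          # state update, h: (D, N)
        return h, h @ c_t           # readout y_t: (D,)

    _, y = jax.lax.scan(step, jnp.zeros_like(A), (A_bar, Bx, C))
    return y


T, D, N = 16, 8, 4
keys = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(keys[0], (T, D))

# Hypothetical projection weights; a real block learns these.
W_dt = 0.1 * jax.random.normal(keys[1], (D, D))
W_B = 0.1 * jax.random.normal(keys[2], (D, N))
W_C = 0.1 * jax.random.normal(keys[3], (D, N))
A_log = jnp.log(jnp.broadcast_to(jnp.arange(1.0, N + 1), (D, N)))  # assumed init

delta = jax.nn.softplus(x @ W_dt)   # input-dependent step size, always > 0
A = -jnp.exp(A_log)                 # negative real diagonal, as in the text
y = selective_scan(x, delta, A, x @ W_B, x @ W_C)
print(y.shape)
```

Because Δ, B and C are all functions of x, each time step decides how much of the past to keep and how much of the current input to write, which a fixed-λ layer such as the LRU cannot do.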
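We'll train it on the same kind of copy task so you can compare its convergence against the LRU. The shapes, learning rate, and step count differ from section 1, so treat the comparison as qualitative.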
mb = ssm.MambaBlock(d_model=8, d_state=8, d_conv=4, expand=2, rngs=nnx.Rngs(4))
mb_opt = nnx.Optimizer(mb, optax.adam(3e-3), wrt=nnx.Param)
u = jax.random.normal(jax.random.PRNGKey(5), (16, 4, 8))
@nnx.jit
def mb_train_step(model, optimizer, u):
    def loss_fn(m):
        return jnp.mean((m(u) - u) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    return loss
mb_losses = [float(mb_train_step(mb, mb_opt, u)) for _ in range(100)]
print(f"MambaBlock initial MSE: {mb_losses[0]:.4f}")
print(f"MambaBlock final MSE: {mb_losses[-1]:.4f}")
MambaBlock initial MSE: 1.1178
MambaBlock final MSE: 0.4200
4. Hybrid SNN + SSM stack
spyx.nn.Sequential is happy to thread state through layers of different kinds. A common pattern for neuromorphic workloads: use a spiking front-end to generate sparse events, an SSM to do the actual long-range temporal modelling, and a linear readout.
The spyx.nn.run helper scans the spiking layers over time; the SSM then eats the full [T, B, d] spike tensor in one call.
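To show the difference between "scanned over time" and "one call on the whole tensor", here is a forward-only sketch with a hand-written leaky integrate-and-fire neuron. It does not reproduce spyx.nn.LIF or spyx.nn.run: the dynamics, beta, threshold, and the lif_run helper are assumptions for illustration, and there is no surrogate gradient.

```python
import jax
import jax.numpy as jnp


def lif_run(currents, beta=0.9, threshold=1.0):
    """currents: (T, B, d) input currents -> (T, B, d) binary spikes."""
    def step(v, i_t):
        v = beta * v + i_t                        # leaky integration
        s = (v >= threshold).astype(i_t.dtype)    # spike where the threshold is reached
        v = v - s * threshold                     # soft reset
        return v, s

    _, spikes = jax.lax.scan(step, jnp.zeros(currents.shape[1:]), currents)
    return spikes


T, B, d = 12, 4, 8
currents = jax.random.normal(jax.random.PRNGKey(0), (T, B, d))
spikes = lif_run(currents)                  # the time loop lives inside the scan

# Stand-in for the SSM: the final state of h_t = 0.9 * h_{t-1} + s_t,
# computed from the whole (T, B, d) spike tensor in a single call.
decay = 0.9 ** jnp.arange(T)[::-1]          # weight 0.9^(T-1-t) for step t
summary = jnp.einsum("t,tbd->bd", decay, spikes)
print(spikes.shape, summary.shape)
```

The spiking layer has to be stepped because each reset depends on the previous membrane potential, whereas the linear recurrence can be evaluated for all time steps at once.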
rngs = nnx.Rngs(0)
snn_front = snn.Sequential(
    nnx.Linear(4, 8, use_bias=False, rngs=rngs),
    snn.LIF((8,), activation=spyx.axn.triangular(), rngs=rngs),
)
ssm_layer = ssm.LRU(d_model=8, d_state=16, rngs=rngs)
readout = nnx.Linear(8, 3, use_bias=False, rngs=rngs)
T, B = 12, 4
u = jax.random.normal(jax.random.PRNGKey(6), (T, B, 4))
spikes, _ = snn.run(snn_front, u) # (T, B, 8) binary spikes
h = ssm_layer(spikes) # (T, B, 8) continuous SSM output
logits = readout(h.sum(axis=0)) # (B, 3) class logits
print(f"input: {u.shape}")
print(f"spikes: {spikes.shape} sparsity: {float(spikes.mean()):.3f}")
print(f"logits: {logits.shape}")
input: (12, 4, 4)
spikes: (12, 4, 8) sparsity: 0.102
logits: (4, 3)
5. Quantized SSM
spyx.quant.quantize handles SSM-containing models just like any other Flax NNX module. The default linear_only_rules() quantizes the nnx.Linear / nnx.Conv layers — including the SSM's in_proj, x_proj, dt_proj, and out_proj — and leaves the SSM's own diagonal state matrix (A_log, B_re, etc.) in fp32.
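For intuition about what the two recipes do to a weight matrix, the sketch below applies a generic symmetric int8 round-trip and an absmean ternary round-trip to random weights. These helpers are hypothetical and are not how spyx.quant or qwix implement quantization; they only show why ternary weights lose far more precision than int8.

```python
import jax
import jax.numpy as jnp


def fake_quant_int8(w):
    # Symmetric per-tensor int8: snap to 255 levels, then map back to float.
    scale = jnp.max(jnp.abs(w)) / 127.0
    return jnp.clip(jnp.round(w / scale), -127, 127) * scale


def fake_quant_ternary(w):
    # Absmean ternary (BitNet b1.58 style): weights become {-1, 0, +1} * scale.
    scale = jnp.mean(jnp.abs(w)) + 1e-8
    return jnp.clip(jnp.round(w / scale), -1, 1) * scale


w = jax.random.normal(jax.random.PRNGKey(0), (16, 16)) / 4.0
err_int8 = float(jnp.max(jnp.abs(w - fake_quant_int8(w))))
err_tern = float(jnp.max(jnp.abs(w - fake_quant_ternary(w))))
print(f"int8 max weight error:    {err_int8:.4f}")
print(f"ternary max weight error: {err_tern:.4f}")
```

The int8 error is bounded by half a quantization step, while the ternary error grows with the largest weight because everything beyond ±scale is clipped.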
Try int8 first, then the BitNet-ternary recipe.
if not spyx.quant.available():
    print("qwix not installed; install with `pip install \"spyx[quant]\"`.")
else:
    class SSMBlock(nnx.Module):
        def __init__(self, in_dim, hidden, out_dim, *, rngs):
            self.pre = nnx.Linear(in_dim, hidden, use_bias=False, rngs=rngs)
            self.ssm = ssm.LRU(d_model=hidden, d_state=8, rngs=rngs)
            self.post = nnx.Linear(hidden, out_dim, use_bias=False, rngs=rngs)

        def __call__(self, u):
            return self.post(self.ssm(self.pre(u)))

    T, B = 8, 2
    sample = jax.random.normal(jax.random.PRNGKey(7), (T, B, 4))
    for label, rules in (
        ("int8 W+A", spyx.quant.linear_only_rules("int8", "int8")),
        ("BitNet ternary", spyx.quant.bitnet_ternary_rules()),
    ):
        fp_model = SSMBlock(4, 16, 3, rngs=nnx.Rngs(0))
        qmodel = spyx.quant.quantize(fp_model, sample, rules=rules)
        out_fp = fp_model(sample)
        out_q = qmodel(sample)
        print(f"{label:16s} max |fp - q|: {float(jnp.max(jnp.abs(out_fp - out_q))):.4f}")
int8 W+A         max |fp - q|: 0.0354
BitNet ternary   max |fp - q|: 1.2347
Where to go next
- Swap LRU for ChunkedSSM(inner=MambaBlock, outer=LRU, chunk_size=8) to try the H-Net skeleton.
- Plug the hybrid SNN + SSM into an SHD loader and compare against the pure-LIF baseline.
- Use spyx.quant.bitnet_ternary_rules() with a longer fine-tune budget; the accuracy dip from int2 weights usually recovers within a few epochs.
See src/spyx/ssm.py for the full module docstrings and scripts/ssm_demo.py for a non-notebook version of the above.