Design and architecture
Spyx is a compact spiking neural network (SNN) library built on JAX and Flax NNX. It is designed to deliver PyTorch-class ergonomics on top of JAX's JIT and auto-differentiation so researchers can iterate quickly on GPUs and TPUs, then export trained models to neuromorphic silicon via the Neuromorphic Intermediate Representation (NIR).
Design principles
- JAX-first. Every layer composes cleanly with jax.jit, jax.vmap, jax.lax.scan, and jax.lax.associative_scan. Nothing in Spyx blocks JIT compilation or forces state out to the host.
- NNX modules, not transforms. Every stateful piece (neurons, SSMs, phasor layers) inherits from flax.nnx.Module. Parameters are plain nnx.Param nodes, so standard Optax workflows (nnx.Optimizer + nnx.value_and_grad) work out of the box.
- Surrogate gradients are jax.custom_gradient. No hand-rolled backward passes, no PyTree gymnastics. You can use any autograd tool that works on plain JAX.
- Functional core. Losses, regularisers, and data transforms in spyx.fn and spyx.data are pure functions wrapped in jax.jit, so they compose with anything.
- Optional extras, lean core. The default install is minimal. Data loaders (tonic) and quantization (qwix) live behind [loaders] and [quant] extras.
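To make the JAX-first and surrogate-gradient principles concrete, here is a minimal sketch in plain JAX. It is not Spyx's implementation: the names spike, lif_step, and run are hypothetical, and the arctan-shaped surrogate, the decay beta, the threshold, and the soft reset are assumptions chosen for illustration. It shows a Heaviside spike whose backward pass is defined with jax.custom_gradient, unrolled over a time-major input with jax.lax.scan under jax.jit.

```python
import jax
import jax.numpy as jnp


@jax.custom_gradient
def spike(x):
    # Forward pass: Heaviside step (1 where x > 0, else 0).
    out = (x > 0).astype(x.dtype)

    def backward(g):
        # Backward pass: arctan-shaped surrogate derivative (illustrative choice).
        return g / (1.0 + (jnp.pi * x) ** 2)

    return out, backward


def lif_step(v, current, beta=0.9, threshold=1.0):
    # Leaky integration followed by a soft reset (assumed dynamics).
    v = beta * v + current
    s = spike(v - threshold)
    return v - s * threshold, s


@jax.jit
def run(currents):
    # currents is time-major: (time, neurons).
    v0 = jnp.zeros(currents.shape[1:], dtype=currents.dtype)
    _, spikes = jax.lax.scan(lif_step, v0, currents)
    return spikes


currents = jnp.full((10, 3), 0.6)
spikes = run(currents)

# Ordinary jax.grad flows through the scan via the surrogate derivative.
grads = jax.grad(lambda c: run(c).sum())(currents)
```

Because the surrogate is attached to the spike function itself, nothing else in the sketch needs a custom backward pass; jax.grad, jax.vmap, and jax.jit apply to run as they would to any other JAX function.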
Module map
| Module | What it contains |
|---|---|
| spyx.nn | Spiking neuron models (IF, LIF, ALIF, CuBaLIF, recurrent variants), Sequential, and the time-major run helper. |
| spyx.axn | Surrogate-gradient factories (SuperSpike, Arctan, Tanh, Boxcar, Triangular, straight-through). |
| spyx.fn | Losses and metrics: integral_crossentropy, integral_accuracy, silence_reg, sparsity_reg, mse_spikerate. |
| spyx.data | Grain-based data pipeline with rate / angle / latency encoding and SHD_loader / NMNIST_loader. |
| spyx.optimize | Quick-training helpers: fit, make_train_step, make_eval_step. |
| spyx.nir | NIR import/export for interop with neuromorphic toolchains. |
| spyx.quant | qwix-backed int8 / int4 / BitNet-ternary quantization-aware training. |
| spyx.ssm | State-space layers: LRU, S5Diag, Mamba, MambaBlock, ChunkedSSM (H-Net skeleton). |
| spyx.phasor | Complex-valued phasor networks with spike-time conversion. |
| spyx.experimental | Research-grade neurons (SPSN, stochastic-associative LIF/CuBaLIF, PSU_LIF). |
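As an illustration of the pure-function style described for spyx.fn, the sketch below shows what a time-integrated loss and metric can look like in plain JAX. It does not call Spyx and is not the library's code: the names spike_count_crossentropy and spike_count_accuracy are hypothetical, and it assumes that "integral" means summing the output traces over a leading time axis, with traces shaped (time, batch, classes).

```python
import jax
import jax.numpy as jnp


@jax.jit
def spike_count_crossentropy(traces, targets):
    # traces: (time, batch, classes); targets: (batch,) integer labels.
    logits = traces.sum(axis=0)  # integrate over time -> (batch, classes)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    picked = jnp.take_along_axis(log_probs, targets[:, None], axis=-1)[:, 0]
    return -picked.mean()


@jax.jit
def spike_count_accuracy(traces, targets):
    # Predict the class whose output accumulated the most activity.
    preds = traces.sum(axis=0).argmax(axis=-1)
    return (preds == targets).mean()


# Toy output: sample 0 fires on class 1, sample 1 fires on class 2.
traces = jnp.zeros((5, 2, 3))
traces = traces.at[:, 0, 1].set(1.0).at[:, 1, 2].set(1.0)
targets = jnp.array([1, 2])

loss = spike_count_crossentropy(traces, targets)
acc = spike_count_accuracy(traces, targets)
```

Both functions are stateless and depend only on their arguments, which is what lets them be jitted once and reused inside any training or evaluation step.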
Where Spyx fits
| | PyTorch SNNs (snnTorch, Norse, SpikingJelly) | Spyx |
|---|---|---|
| Framework | PyTorch eager | JAX JIT + vmap |
| Module system | torch.nn.Module | flax.nnx.Module |
| Accelerator | CUDA | CUDA / ROCm / TPU |
| Speed | Depends on custom kernels | Fully JIT-compiled, competitive with custom kernels |
| Quantization | external (TorchAO, BitsAndBytes) | spyx.quant (qwix) |
| Neuromorphic export | NIR | NIR |
Spyx is the right choice if you want JAX's compilation model and would benefit from vmap across populations (e.g. neuroevolution, per-seed sweeps) without paying the price of hand-writing CUDA kernels for SNN dynamics.
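The population use case can be sketched in plain JAX as follows. This is an illustrative example under stated assumptions, not Spyx code: init_params and spike_counts are hypothetical names, the single-layer LIF dynamics are simplified, and the population size of 8 is arbitrary. It initialises one weight matrix per seed with jax.vmap, then evaluates every member on the same input in one jitted call.

```python
import jax
import jax.numpy as jnp


def init_params(key, n_in=4, n_out=3):
    # One weight matrix per population member (hypothetical initialiser).
    return jax.random.normal(key, (n_in, n_out)) * 0.5


def spike_counts(w, inputs, beta=0.9, threshold=1.0):
    # Simulate one leaky integrate-and-fire layer over time-major inputs.
    def step(v, x):
        v = beta * v + x @ w
        s = (v > threshold).astype(v.dtype)
        return v - s * threshold, s

    v0 = jnp.zeros(w.shape[1])
    _, spikes = jax.lax.scan(step, v0, inputs)
    return spikes.sum(axis=0)  # spikes per output neuron


# Eight seeds -> eight independent parameter sets, stacked on axis 0.
keys = jax.random.split(jax.random.PRNGKey(0), 8)
population = jax.vmap(init_params)(keys)  # (8, 4, 3)

# Map over the population axis only; every member sees the same inputs.
inputs = jnp.ones((20, 4))
evaluate = jax.jit(jax.vmap(spike_counts, in_axes=(0, None)))
counts = evaluate(population, inputs)  # (8, 3)
```

The same pattern extends to per-seed training sweeps: the single-member function stays unchanged, and jax.vmap adds the population axis without any custom kernel.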
Workshop talk
For a guided walkthrough, see the Spyx Workshop. Note that the video predates the Flax NNX migration; the concepts carry over but the API surface has since moved to match what's documented here.