How to benchmark neurons and models
Use spyx.bench to measure the latency, throughput,
memory, and spiking activity of any module that follows the Spyx stepwise
contract ((x_t, state) -> (out, state) plus initial_state). It is the tool
behind the parallel-spiking-neuron
crossover numbers, and the recommended way to fill in a
research study's results table so runs are comparable.
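To make the stepwise contract concrete, here is a minimal pure-Python sketch of a module that satisfies it, plus a loop that drives it over a time-major input. `ToyLIF`, `run_over_time`, their parameters, and the `initial_state(batch)` signature are hypothetical names chosen for illustration; they are not Spyx's own classes or signatures.

```python
# Toy illustration of a stepwise contract: (x_t, state) -> (out, state).
# ToyLIF, run_over_time, and initial_state(batch) are hypothetical, not Spyx API.

class ToyLIF:
    def __init__(self, beta=0.5, threshold=1.0):
        self.beta = beta            # membrane decay applied every step
        self.threshold = threshold  # potential at which the neuron fires

    def initial_state(self, batch):
        # One membrane potential per batch element, starting at rest.
        return [0.0] * batch

    def __call__(self, x_t, state):
        out, new_state = [], []
        for x, v in zip(x_t, state):
            v = self.beta * v + x                         # leaky integration
            spike = 1.0 if v >= self.threshold else 0.0   # threshold crossing
            new_state.append(v - spike * self.threshold)  # subtractive reset
            out.append(spike)
        return out, new_state


def run_over_time(module, xs, batch):
    """Drive a stepwise module over an input whose leading axis is time."""
    state = module.initial_state(batch)
    outputs = []
    for x_t in xs:
        out, state = module(x_t, state)
        outputs.append(out)
    return outputs, state


xs = [[0.6, 0.0], [0.6, 0.0], [0.6, 0.0]]  # seq_len=3, batch=2, scalar feature
spikes, final_state = run_over_time(ToyLIF(), xs, batch=2)
print(spikes)
```

Any object with this shape of interface can be stepped through time by a generic driver, which is what lets one benchmark harness serve many neuron models.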
Benchmark one module
benchmark(module, input_shape, *, seq_len, batch, ...) builds a random
time-major (seq_len, batch, *input_shape) input, drives the module over time,
and returns a BenchResult:
import spyx
from flax import nnx
neuron = spyx.nn.LIF((256,), rngs=nnx.Rngs(0))
result = spyx.bench.benchmark(neuron, (256,), seq_len=512, batch=32)
print(result.fwd_latency_ms, result.throughput_elem_ts_per_s, result.spike_rate)
input_shape is the per-timestep feature shape (everything after the batch
axis). Pass a zero-arg thunk instead of a module when you want a fresh
instance built lazily (handy for sweeps):
result = spyx.bench.benchmark(
    lambda: spyx.nn.LIF((256,), rngs=nnx.Rngs(0)),
    (256,), seq_len=512, batch=32,
)
By default the driver uses a module's own parallel method when it has one
(e.g. PSU_LIF), otherwise it
falls back to spyx.nn.run. Override it explicitly with run_fn=(module, x) ->
outputs — for example, to force PSU_LIF down the sequential scan so you can
compare it against its own parallel path:
seq_fn = lambda m, x: spyx.nn.run(m, x)[0] # ignore the returned final state
result = spyx.bench.benchmark(
    lambda: spyx.nn.PSU_LIF((256,), rngs=nnx.Rngs(0)),
    (256,), seq_len=1024, batch=8, run_fn=seq_fn,
)
What gets measured
| Field | Meaning |
|---|---|
| fwd_latency_ms | Median forward-pass latency (ms). |
| fwd_bwd_latency_ms | Median forward + backward latency in ms (value_and_grad of mean(outputs)); None if backward=False. |
| throughput_elem_ts_per_s | seq_len * batch / fwd_time, in element-timesteps per second. |
| spike_rate | Mean fraction of non-zero output activations (the SNN energy proxy). |
| peak_mem_mb | Peak device memory in MB (None if the backend does not expose it). |
| flops | FLOPs from XLA's cost model (None if unavailable). |
| mfu | Model-FLOP-utilisation = achieved FLOP/s ÷ device peak (None when the device peak is unknown; it is never guessed). |
| param_count, device, seq_len, batch | Run metadata. |
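The derived fields follow directly from the definitions in the table. The sketch below recomputes three of them in plain Python as a worked example. The function names, the nested-list output layout, and all input numbers are assumptions made for illustration, and `fwd_time_s` is assumed to be the forward latency expressed in seconds; this is not the library's implementation.

```python
# Worked example of the table's formulas. Hypothetical helpers, not spyx.bench code.

def throughput_elem_ts_per_s(seq_len, batch, fwd_time_s):
    # Element-timesteps processed per second of forward time.
    return seq_len * batch / fwd_time_s


def spike_rate(outputs):
    # outputs is nested as [time][batch][feature]; count the non-zero entries.
    flat = [v for step in outputs for row in step for v in row]
    return sum(1 for v in flat if v != 0) / len(flat)


def mfu(flops, fwd_time_s, device_peak_flops_per_s):
    # Achieved FLOP/s divided by device peak; None when either input is unknown.
    if flops is None or device_peak_flops_per_s is None:
        return None
    return (flops / fwd_time_s) / device_peak_flops_per_s


# Illustrative inputs only: a 2 ms forward pass at seq_len=512, batch=32.
print(throughput_elem_ts_per_s(512, 32, 0.002))
print(spike_rate([[[0, 1], [0, 0]], [[1, 1], [0, 0]]]))
print(mfu(1e9, 0.002, 1e12))
print(mfu(1e9, 0.002, None))
```

Returning `None` rather than a guessed value mirrors the table's rule for `mfu`: a missing device peak should surface as missing data, not as a plausible-looking number.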
spike_rate is the standard event-driven energy proxy for spiking networks:
lower means sparser spiking, which on neuromorphic hardware means less energy.
It is also why reset-free neurons, which fire more densely, pay an energy cost
for their parallelism (see the
parallel-spiking-neuron explanation).
The timing methodology (why the numbers are trustworthy)
spyx.bench is deliberate about how it times, because naïve JAX timing is
almost always wrong:
- JIT first. The timed function is jax.jit-compiled and the first n_warmup (default 3) iterations are discarded, so tracing and compilation are never counted.
- Block before stopping the clock. JAX dispatches asynchronously, so every timed call is followed by jax.block_until_ready on its outputs before the timer stops. Without this you would be timing dispatch, not computation.
- Median, not mean. The reported latency is the median over n_iters (default 20) iterations, which is far more robust to OS jitter and GC pauses than the mean.
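The three rules above can be sketched as a small standard-library harness. This is an illustration of the pattern, not the code inside spyx.bench: `time_median_ms` and its `sync` parameter are hypothetical names, and the demo times a plain Python function so the block runs without an accelerator.

```python
import statistics
import time


def time_median_ms(fn, *args, n_warmup=3, n_iters=20, sync=lambda out: out):
    """Median wall-clock latency of fn(*args) in milliseconds."""
    # Warm-up calls are run and discarded. With a jitted function this is
    # where tracing and compilation happen.
    for _ in range(n_warmup):
        sync(fn(*args))

    samples = []
    for _ in range(n_iters):
        start = time.perf_counter()
        sync(fn(*args))  # wait for the result before reading the clock again
        samples.append((time.perf_counter() - start) * 1e3)

    # The median is less sensitive to occasional slow iterations than the mean.
    return statistics.median(samples)


latency_ms = time_median_ms(lambda n: sum(range(n)), 100_000)
print(f"{latency_ms:.3f} ms")
```

With JAX, one would pass a `jax.jit`-wrapped function as `fn` and `jax.block_until_ready` as `sync`, so that each sample covers the computation itself and not just the asynchronous dispatch.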
Sweep and compare
compare(modules, input_shape, *, seq_lens, batch, ...) sweeps a dict of
labelled modules across a list of sequence lengths and returns a flat list of
BenchResult. format_table(results) renders them as an aligned plain-text
table. This is the idiom for the parallel-vs-sequential crossover:
import spyx
from flax import nnx
results = spyx.bench.compare(
    {
        "LIF (sequential)": lambda: spyx.nn.LIF((256,), rngs=nnx.Rngs(0)),
        "PSU_LIF (parallel)": lambda: spyx.nn.PSU_LIF((256,), rngs=nnx.Rngs(0)),
    },
    (256,),
    seq_lens=[128, 512, 2048],
    batch=8,
)
print(spyx.bench.format_table(results))
LIF has no parallel method so it is driven sequentially; PSU_LIF is driven
by its associative scan. Comparing the two rows at each seq_len shows where the
crossover falls and how the gap widens as the sequence grows and the sequential
critical path starts to dominate. Passing thunks (as above) gives every sweep
point a fresh module.
BenchResult.as_dict() returns a plain dict if you would rather push results into
a DataFrame or log them.
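Once the results are plain dicts, post-processing needs nothing beyond the standard library. The sketch below computes a per-seq_len speedup and serialises the rows to CSV. The rows are placeholders standing in for the output of as_dict(): the "label" key is an assumption about how a result identifies its module, every number is made up for illustration, and `speedup_by_seq_len` and `to_csv` are hypothetical helpers.

```python
import csv
import io

# Placeholder rows standing in for [r.as_dict() for r in results].
# The "label" key and all numbers are invented for illustration, not measured.
rows = [
    {"label": "seq", "seq_len": 128, "batch": 8, "fwd_latency_ms": 1.0},
    {"label": "par", "seq_len": 128, "batch": 8, "fwd_latency_ms": 2.0},
    {"label": "seq", "seq_len": 2048, "batch": 8, "fwd_latency_ms": 16.0},
    {"label": "par", "seq_len": 2048, "batch": 8, "fwd_latency_ms": 4.0},
]


def speedup_by_seq_len(rows, baseline, candidate):
    # Ratio of baseline to candidate latency; above 1 means the candidate is faster.
    latency = {(r["label"], r["seq_len"]): r["fwd_latency_ms"] for r in rows}
    seq_lens = sorted({r["seq_len"] for r in rows})
    return {n: latency[(baseline, n)] / latency[(candidate, n)] for n in seq_lens}


def to_csv(rows):
    # Serialise the rows with a header taken from the first row's keys.
    buf = io.StringIO()
    writer = csv.DictWriter(buf, fieldnames=list(rows[0]))
    writer.writeheader()
    writer.writerows(rows)
    return buf.getvalue()


speedups = speedup_by_seq_len(rows, "seq", "par")
print(speedups)
print(to_csv(rows))
```

The crossover is the sequence length at which the ratio passes 1; with real measurements, the same dict can go straight into a results table.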
Numbers are hardware-specific
Latency, throughput, and MFU depend entirely on the accelerator, driver, and JAX version. Always record all three alongside your results; the reference machine for Spyx's parallel-neuron work is an AMD Radeon 8060S (gfx1151) on ROCm.
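One lightweight way to keep that context attached to a results table is to store a small environment record next to it. The sketch below uses only the standard library and reports a package as None when it is not installed. `environment_record` and `package_version` are hypothetical helper names, not part of Spyx.

```python
import json
import platform
from importlib import metadata


def package_version(name):
    # Installed version of a distribution, or None if it is not installed.
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def environment_record():
    # Host and software details to store alongside benchmark numbers.
    return {
        "python": platform.python_version(),
        "os": platform.platform(),
        "machine": platform.machine(),
        "jax": package_version("jax"),
        "jaxlib": package_version("jaxlib"),
    }


print(json.dumps(environment_record(), indent=2))
```

This captures the software side only. The accelerator model and driver version still need to be recorded from the device itself, for example from the device field of a result.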