Parallel spiking neurons
Spiking neurons are usually run one timestep at a time. A standard
spyx.nn.LIF keeps a membrane voltage \(V\) and, at every
step, leaks it, integrates the input, fires, and resets by subtracting the
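spike:
\[ V_{t+1} = \beta V_t + x_t - S_t\,\theta, \qquad S_t = \sigma(V_t - \theta), \]
with leak \(\beta\), threshold \(\theta\), and spike function \(\sigma\).
The \(-S_t\,\theta\) term is the whole story. Because \(S_t\) is a (nonlinear) function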
of \(V_t\), the update from \(V_t\) to \(V_{t+1}\) is nonlinear and every step
depends on the spike emitted by the step before it. There is no way around
walking the sequence in order: the recurrence is intrinsically sequential,
and Spyx evaluates it with a jax.lax.scan inside spyx.nn.run.
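The sequential evaluation described above can be sketched in plain JAX. This is a minimal illustration, not Spyx's implementation: lif_step and run_lif are hypothetical names, the values of beta and theta are arbitrary, and a plain Heaviside stands in for the surrogate spike function.

```python
import jax
import jax.numpy as jnp

def lif_step(V, x, beta=0.9, theta=1.0):
    """One hard-reset LIF step: fire from the current voltage, then leak, integrate, reset."""
    S = (V > theta).astype(V.dtype)      # Heaviside spike (no surrogate gradient in this sketch)
    V_next = beta * V + x - S * theta    # the -S*theta term ties step t+1 to the spike at step t
    return V_next, S

def run_lif(x_seq, beta=0.9, theta=1.0):
    """Walk a time-major (T, B, H) input in order with jax.lax.scan."""
    V0 = jnp.zeros_like(x_seq[0])
    step = lambda V, x: lif_step(V, x, beta, theta)
    V_final, spikes = jax.lax.scan(step, V0, x_seq)
    return spikes, V_final

# Constant drive of 0.6 on a single unit for six steps.
x_seq = jnp.full((6, 1, 1), 0.6)
spikes, V_final = run_lif(x_seq)
# spikes[:, 0, 0] -> [0, 0, 1, 0, 1, 0]
```

Each call to lif_step needs the voltage produced by the previous call, which is why the loop cannot be reordered.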
This page explains a different family of neurons that trade the hard reset for the ability to score a whole sequence in parallel, why that trade-off exists, and what it buys you in practice.
Sequential scan: O(T) depth
A jax.lax.scan over T timesteps is a chain of T dependent steps. On an
accelerator the wall-clock cost is dominated not by the arithmetic — each step is
tiny — but by the length of the dependency chain: T kernel launches that
cannot start until their predecessor finishes. The critical path is O(T)
deep. For short sequences on a busy GPU this is fine; for long sequences on a GPU
with spare compute it leaves the device mostly idle, waiting on latency rather
than doing work.
The reset-free trick: a linear recurrence
Drop the reset. Without the \(-S_t\,\theta\) term the membrane is a pure linear leaky
integrator:
\[ V_t = \beta V_{t-1} + x_t. \]
This is a first-order linear recurrence, and composing its per-step affine maps
\(v \mapsto \beta v + x_t\) is associative: the maps can be regrouped freely as
long as their order is preserved. That is exactly the structure
jax.lax.associative_scan
exploits — the same parallel-prefix-scan machinery behind modern diagonal
state-space models (S4D/S5, LRU, Mamba; see spyx.ssm).
The whole membrane trace \(V_0, \dots, V_{T-1}\) can be computed in
\(O(\log T)\) parallel depth instead of \(O(T)\). Spikes are then a pointwise
surrogate threshold applied to the whole trace, \(s_t = \sigma(V_t - \theta)\),
which is embarrassingly parallel.
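As a sketch of the idea, each timestep can be represented by the pair (a, b) of the affine map v -> a*v + b, and applying (a1, b1) followed by (a2, b2) gives (a2*a1, a2*b1 + b2). The function names, the leak 0.9, and the threshold 1.0 below are illustrative assumptions, and a plain comparison stands in for the surrogate threshold.

```python
import jax
import jax.numpy as jnp

def combine(left, right):
    """Compose two affine maps v -> a*v + b; `left` is applied first."""
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r

def membrane_parallel(x_seq, beta):
    """All V_t = beta * V_{t-1} + x_t (zero initial state) via a parallel prefix scan."""
    a = jnp.full_like(x_seq, beta)
    _, V = jax.lax.associative_scan(combine, (a, x_seq), axis=0)
    return V

def membrane_sequential(x_seq, beta):
    """Reference: the same recurrence walked in order."""
    def step(V, x):
        V = beta * V + x
        return V, V
    _, V = jax.lax.scan(step, jnp.zeros_like(x_seq[0]), x_seq)
    return V

x_seq = jax.random.normal(jax.random.PRNGKey(0), (128, 4, 8))  # time-major (T, B, H)
V_par = membrane_parallel(x_seq, 0.9)
V_seq = membrane_sequential(x_seq, 0.9)
spikes = (V_par > 1.0).astype(x_seq.dtype)  # pointwise threshold on the whole trace
```

The two traces agree to within floating-point tolerance; the parallel version only changes how the same compositions are grouped.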
The catch is the reset you gave up. A reset-free neuron never depresses after
firing, so it can spike on consecutive steps and its activity is less sparse.
On the reference machine below, spyx.nn.PSU_LIF fires
roughly 3× more often than a tuned LIF on the same input — you keep
activity bounded with the leak \(\beta\) and the threshold rather than with a
reset. Sparsity is an SNN's energy story (see the
benchmarking how-to on the spike_rate proxy), so
this is a real cost, not a free lunch: you buy parallel depth with density.
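The density cost can be illustrated by driving a hard-reset membrane and a reset-free membrane with the same input and comparing mean firing rates. This is a hedged sketch in plain JAX, not the spyx.bench spike_rate proxy: spike_rates is a hypothetical helper, the input statistics and parameters are arbitrary, and the ratio it produces says nothing about the roughly 3× figure measured on the reference machine.

```python
import jax
import jax.numpy as jnp

def spike_rates(x_seq, beta=0.9, theta=1.0):
    """Mean firing rate of a hard-reset and a reset-free membrane on the same input."""
    def step(carry, x):
        V_reset, V_free = carry
        S_reset = (V_reset > theta).astype(x.dtype)
        S_free = (V_free > theta).astype(x.dtype)
        V_reset = beta * V_reset + x - S_reset * theta  # subtractive reset after a spike
        V_free = beta * V_free + x                      # no reset: only the leak bounds V
        return (V_reset, V_free), (S_reset, S_free)

    V0 = jnp.zeros_like(x_seq[0])
    _, (S_reset, S_free) = jax.lax.scan(step, (V0, V0), x_seq)
    return S_reset.mean(), S_free.mean()

# Arbitrary non-negative drive, shape (T, H).
x_seq = 0.3 * jax.random.uniform(jax.random.PRNGKey(0), (200, 16))
rate_reset, rate_free = spike_rates(x_seq)
print(f"hard reset: {float(rate_reset):.3f}, reset-free: {float(rate_free):.3f}")
```

Because the reset only ever subtracts from the voltage, the reset-free membrane is never below the hard-reset one on the same input, so its firing rate is at least as high.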
Two neurons that do this
Both neurons follow the standard Spyx contract — a stepwise
__call__(x_t, state) -> (spikes, state) usable in
spyx.nn.run, Sequential, and NIR — and additionally
expose a parallel(x) method that scores a time-major (T, B, …) sequence at
once via an associative scan. The two paths use the same parameters and the same
surrogate and are numerically equivalent: scanning the step reproduces the
parallel result up to floating-point rounding.
spyx.nn.PSU_LIF — parallel real leaky integrator
spyx.nn.PSU_LIF (Parallel Spiking Unit LIF) is the
reset-free real-valued neuron described above: \(V_t = \operatorname{clip}(\beta)
\, V_{t-1} + x_t\), with a per-unit learnable leak. Use it as a drop-in for LIF
when the reset is not essential and the sequence is long. It is the template for
reset-free parallel spiking neurons in Spyx.
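To make the per-unit learnable leak concrete, here is a minimal sketch of a gradient flowing to the leak through the parallel path. It is not the PSU_LIF source: psu_parallel and loss are hypothetical names, the clip range [0, 1] is an assumption (the text above only says the leak is clipped), and the sigmoid-derivative surrogate is one common choice rather than the one Spyx necessarily uses.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def spike(u):
    """Heaviside in the forward pass."""
    return (u > 0).astype(u.dtype)

@spike.defjvp
def spike_jvp(primals, tangents):
    (u,), (du,) = primals, tangents
    sg = jax.nn.sigmoid(4.0 * u)
    # Assumed surrogate: derivative of a sigmoid with slope 4.
    return spike(u), 4.0 * sg * (1.0 - sg) * du

def psu_parallel(x_seq, beta, theta=1.0):
    """Reset-free leaky integrator over a time-major (T, B, H) input."""
    a = jnp.broadcast_to(jnp.clip(beta, 0.0, 1.0), x_seq.shape)  # per-unit leak, assumed range
    combine = lambda left, right: (right[0] * left[0], right[0] * left[1] + right[1])
    _, V = jax.lax.associative_scan(combine, (a, x_seq), axis=0)
    return spike(V - theta), V

def loss(beta, x_seq):
    spikes, _ = psu_parallel(x_seq, beta)
    return spikes.mean()  # toy objective: mean firing rate

x_seq = jax.random.uniform(jax.random.PRNGKey(0), (64, 2, 5))
beta = jnp.full((5,), 0.8)                   # one leak per hidden unit
grad_beta = jax.grad(loss)(beta, x_seq)      # shape (5,)
```

With a non-negative input, raising the leak can only raise the membrane, so the surrogate gradient of the firing rate with respect to each leak is positive in this toy setup.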
spyx.phasor.ResonateFire — parallel complex oscillator
spyx.phasor.ResonateFire is the complex/oscillatory
sibling of PSU_LIF. Its membrane is a complex number evolving as a damped
harmonic oscillator,
\[ z_t = a\,z_{t-1} + x_t, \qquad a = e^{\mathrm{dt}\,(-\lambda + i\omega)}, \]
with per-unit decay \(\lambda \ge 0\) and angular frequency \(\omega\). Real input
is injected into the real part; spikes come from a surrogate threshold on
\(\Re(z_t)\). Because it too is reset-free, the recurrence stays linear and the
same \(O(\log T)\) associative scan applies, only now over a complex pole \(a\)
instead of a real leak. Storing \(\lambda\) through a softplus keeps
\(|a| = e^{-\mathrm{dt}\,\lambda} \le 1\), so the oscillation never grows.
Following spyx.phasor.PhasorLinear, the pole
parameters are stored as real float32 params, so a stock optax +
jax.grad loop over a real loss trains them without the Wirtinger-conjugate
surprise. ResonateFire gives you tunable frequency selectivity that a
first-order leaky integrator cannot express.
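The complex recurrence can be sketched with the same associative scan. This is an illustration under stated assumptions, not the ResonateFire source: resonate_parallel is a hypothetical name, lam_raw denotes the assumed pre-softplus decay parameter, and dt, theta, and the frequency grid are arbitrary.

```python
import jax
import jax.numpy as jnp

def resonate_parallel(x_seq, lam_raw, omega, dt=1.0, theta=1.0):
    """Damped complex oscillator z_t = a * z_{t-1} + x_t over a (T, B, H) input."""
    lam = jax.nn.softplus(lam_raw)               # decay >= 0
    a = jnp.exp(dt * (-lam + 1j * omega))        # complex pole, |a| = exp(-dt * lam) <= 1
    a_seq = jnp.broadcast_to(a, x_seq.shape)
    z_in = x_seq.astype(a.dtype)                 # real input enters the real part
    combine = lambda left, right: (right[0] * left[0], right[0] * left[1] + right[1])
    _, z = jax.lax.associative_scan(combine, (a_seq, z_in), axis=0)
    spikes = (z.real > theta).astype(x_seq.dtype)  # threshold on Re(z_t)
    return spikes, z

x_seq = jax.random.normal(jax.random.PRNGKey(0), (32, 2, 4))
lam_raw = jnp.zeros((4,))              # real-valued parameters, one per unit
omega = jnp.linspace(0.1, 1.0, 4)      # one angular frequency per unit
spikes, z = resonate_parallel(x_seq, lam_raw, omega)
```

Only the pole itself is complex; lam_raw and omega stay real arrays, which matches the point above about training with a real-valued loss.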
When does parallel actually win?
The O(log T) depth is an asymptotic statement about the critical path, not a
guarantee about wall-clock time on your GPU. Whether the parallel scan beats the
sequential one depends on whether the device has slack. The numbers below
were measured with spyx.bench on an
AMD Radeon 8060S iGPU (gfx1151) on ROCm — treat them as an empirical
observation on that hardware, not a guarantee.
- GPU saturated (large batch / large hidden size, throughput-bound). When the device is already busy doing useful arithmetic every step, the sequential scan's latency is largely hidden, and the parallel scan wins only modestly, roughly 1–3×. Here you are compute-bound and both schedules keep the GPU full.
- GPU with slack (small batch, long sequences, latency-bound). When each step is too small to fill the device, the sequential scan spends most of its time waiting on its own O(T) dependency chain while the GPU sits idle. The parallel scan collapses that chain to O(log T) and fills the device instead, and the measured speed-up climbs to 100×+ in the long-sequence, small-batch regime.
The crossover is exactly what the depth argument predicts: parallelism pays when
the sequential critical path — not the arithmetic — is the bottleneck. Use the
benchmarking how-to to find the crossover on your
hardware and workload before committing to one path; spyx.bench's default
driver automatically uses a neuron's parallel method when it has one, so
compare() measures the parallel path out of the box.
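For a quick sanity check outside spyx.bench, the two schedules can be timed directly in plain JAX. This is a rough sketch, not the library's benchmark driver: sequential, parallel, and time_fn are hypothetical helpers, the shape (2048, 1, 64) is an arbitrary small-batch, long-sequence example, and the result depends entirely on your device.

```python
import time
import jax
import jax.numpy as jnp

def sequential(x_seq, beta):
    """Reset-free membrane trace via an ordered jax.lax.scan."""
    def step(V, x):
        V = beta * V + x
        return V, V
    return jax.lax.scan(step, jnp.zeros_like(x_seq[0]), x_seq)[1]

def parallel(x_seq, beta):
    """The same trace via jax.lax.associative_scan."""
    a = beta * jnp.ones_like(x_seq)
    combine = lambda left, right: (right[0] * left[0], right[0] * left[1] + right[1])
    return jax.lax.associative_scan(combine, (a, x_seq), axis=0)[1]

def time_fn(fn, *args, repeats=5):
    """Average wall-clock seconds per call, excluding compilation."""
    fn(*args).block_until_ready()            # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args).block_until_ready()        # wait for the device before stopping the clock
    return (time.perf_counter() - start) / repeats

x_seq = jnp.ones((2048, 1, 64))              # long sequence, batch of one
t_seq = time_fn(jax.jit(sequential), x_seq, 0.9)
t_par = time_fn(jax.jit(parallel), x_seq, 0.9)
print(f"sequential {t_seq * 1e3:.2f} ms, parallel {t_par * 1e3:.2f} ms")
```

Sweeping the batch size and sequence length in a loop around these calls gives a first estimate of where the crossover sits on a given device.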
Neuromorphic export
Because PSU_LIF and ResonateFire expose the standard stepwise contract, they
participate in the NIR export/import flow like any other
neuron. NIR describes stepwise dynamics, so it is the sequential step — not the
parallel scan — that is the exportable object; the parallel path is a
mathematically equivalent way to evaluate the same recurrence on a GPU. NIR
support for these reset-free neurons is being extended, so treat their interop as
provisional for now and consult the NIR how-to for the layers
with tested round-trip parity.
Further reading
These neurons build directly on prior work on parallelizable spiking units — see
the Stochastic Parallelizable Spiking Neurons study and the
parallel_spiking_neurons work in the Spyx research directory,
which benchmarks PSU_LIF and ResonateFire against a hard-reset LIF.