spyx.bench
Benchmarking harness for Spyx neurons and models. benchmark measures median forward and forward+backward latency, throughput, peak memory, XLA-cost-model FLOPs/MFU, and the spike_rate energy proxy for any module following the spyx.nn stateful contract; compare sweeps a set of modules across sequence lengths and format_table renders the results. See the benchmarking how-to for the timing methodology and runnable examples.
This module measures both efficiency (latency, throughput, peak memory,
FLOPs, model-FLOP-utilisation) and a spiking-specific performance proxy
(spike rate) for any Spyx / Flax NNX module that follows the spyx.nn
stateful contract.
Timing methodology (this is the load-bearing part):
- Inputs are built time-major with shape (seq_len, batch, *input_shape) and the module is driven over time with :func:spyx.nn.run (a jax.lax.scan), exactly like training.
- The timed function is JIT-compiled and the first n_warmup iterations are discarded so we never time tracing/compilation.
- Because JAX dispatches asynchronously, every timed call is followed by :func:jax.block_until_ready on its outputs before the timer is stopped. Without this the numbers are meaningless (you would only be timing dispatch).
- We report the median over n_iters iterations, which is far more robust to OS jitter / GC pauses than the mean.
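The four points above can be sketched in plain JAX. This is a minimal illustration, not the harness's actual code: `step`, `forward` and `median_latency` are hypothetical names, and the leaky-integrator neuron is a stand-in for a real Spyx module.

```python
import statistics
import time

import jax
import jax.numpy as jnp


def step(v, x_t):
    # Hypothetical stand-in neuron (NOT a Spyx module): leaky integration,
    # hard threshold, and a reset wherever a spike fired.
    v = 0.9 * v + x_t
    spikes = (v > 1.0).astype(x_t.dtype)
    v = v * (1.0 - spikes)
    return v, spikes


@jax.jit
def forward(x):
    # x is time-major: (seq_len, batch, *input_shape); scan iterates over axis 0.
    v0 = jnp.zeros(x.shape[1:], x.dtype)
    _, spikes = jax.lax.scan(step, v0, x)
    return spikes


def median_latency(fn, x, n_warmup=3, n_iters=20):
    # Warmup calls absorb tracing and compilation; they are never timed.
    for _ in range(n_warmup):
        jax.block_until_ready(fn(x))
    samples = []
    for _ in range(n_iters):
        t0 = time.perf_counter()
        # Block before stopping the clock, otherwise only dispatch is timed.
        jax.block_until_ready(fn(x))
        samples.append(time.perf_counter() - t0)
    return statistics.median(samples)


x = jax.random.uniform(jax.random.PRNGKey(0), (64, 8, 32))  # (T, B, features)
latency_s = median_latency(forward, x)
print(f"median forward latency: {latency_s * 1e3:.3f} ms")
```

Removing the `jax.block_until_ready` call inside the timed loop is an easy way to see how misleadingly small the dispatch-only numbers are.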
FLOPs come from XLA's own cost model:
jax.jit(f).lower(...).compile().cost_analysis()['flops'] when the backend
exposes it (None otherwise). MFU is flops_per_second / device_peak_flops
using a small hard-coded peak-FLOPs table; when the device is unknown the peak
is None and MFU is reported as None rather than guessed. The spike rate
is the mean fraction of non-zero output activations, i.e. the standard
event-driven energy proxy for SNNs.
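As a rough sketch of how these three quantities could be computed: the helper names `xla_flops`, `mfu` and `spike_rate` are hypothetical, and the peak-FLOPs value below is an illustrative number rather than an entry from the module's table. The return type of `cost_analysis()` has differed between JAX versions (a dict, or a list of dicts), so the sketch handles both.

```python
import jax
import jax.numpy as jnp


def xla_flops(fn, *args):
    """Return XLA's FLOP estimate for fn(*args), or None if unavailable."""
    try:
        cost = jax.jit(fn).lower(*args).compile().cost_analysis()
    except Exception:
        return None  # backend does not expose a cost model
    # Some JAX versions return a list with one dict per computation.
    if isinstance(cost, (list, tuple)):
        cost = cost[0] if cost else None
    if not cost or "flops" not in cost:
        return None
    return float(cost["flops"])


def mfu(flops, latency_s, peak_flops):
    # Unknown FLOPs or unknown device peak -> None, never a guess.
    if flops is None or peak_flops is None or latency_s <= 0:
        return None
    return (flops / latency_s) / peak_flops


def spike_rate(outputs):
    # Mean fraction of non-zero output activations.
    return float(jnp.mean((outputs != 0).astype(jnp.float32)))


spikes = jnp.array([[0.0, 1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 1.0]])
print(spike_rate(spikes))            # 3 of the 8 entries are non-zero
print(mfu(2.0e9, 1.0e-3, 1.0e13))    # illustrative numbers only
print(mfu(2.0e9, 1.0e-3, None))      # unknown device peak
print(xla_flops(lambda a, b: a @ b, jnp.ones((64, 64)), jnp.ones((64, 64))))
```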
BenchResult
dataclass
Container for a single benchmark measurement.
All latency fields are the median over the timed iterations. Fields that
could not be determined on the current backend are None rather than a
fabricated value.
Source code in spyx/bench.py
benchmark(module, input_shape, *, seq_len, batch, n_warmup=3, n_iters=20, backward=True, run_fn=None, name=None, key=None, dtype=jnp.float32)
Benchmark a single Spyx module / neuron.
:param module: an nnx.Module or a zero-arg thunk returning one.
:param input_shape: per-timestep feature shape (everything after batch).
:param seq_len: number of timesteps T in the time-major input.
:param batch: batch size B.
:param n_warmup: untimed warmup iterations (compilation + first-run).
:param n_iters: timed iterations; the median is reported.
:param backward: also time a value_and_grad of mean(outputs).
:param run_fn: optional (module, x) -> outputs override; defaults to the
module's parallel method if present, else :func:spyx.nn.run.
:param name: label for the result; defaults to the module class name.
:param key: PRNG key for the random input (defaults to a fixed seed so
results are deterministic).
:param dtype: dtype of the generated input.
:return: a populated :class:BenchResult.
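To make the input_shape, key and backward parameters concrete, here is a minimal sketch of what they describe. It is an illustration under assumptions, not the harness's code: `make_input` and `forward` are hypothetical names, seed 0 stands in for whatever fixed seed the harness uses, and the sigmoid readout is a differentiable placeholder for a real module.

```python
import jax
import jax.numpy as jnp


def make_input(input_shape, seq_len, batch, key=None, dtype=jnp.float32):
    # Assumption: seed 0 stands in for the harness's fixed default seed.
    key = jax.random.PRNGKey(0) if key is None else key
    # Time-major layout: (seq_len, batch, *input_shape).
    return jax.random.uniform(key, (seq_len, batch, *input_shape), dtype=dtype)


def forward(w, x):
    # Hypothetical differentiable stand-in for a module, scanned over time.
    def step(v, x_t):
        v = 0.9 * v + x_t @ w
        return v, jax.nn.sigmoid(v)

    v0 = jnp.zeros((x.shape[1], w.shape[1]), x.dtype)
    _, outputs = jax.lax.scan(step, v0, x)
    return outputs


# The backward measurement differentiates the scalar mean(outputs).
loss_and_grad = jax.jit(jax.value_and_grad(lambda w, x: jnp.mean(forward(w, x))))

x = make_input((32,), seq_len=16, batch=4)
w = jnp.full((32, 8), 0.1, dtype=x.dtype)
loss, grad = jax.block_until_ready(loss_and_grad(w, x))
print(x.shape, grad.shape)
```

Reducing the outputs to a scalar with a mean keeps the backward pass well defined for any output shape, which is presumably why the harness uses it as the loss.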
Source code in spyx/bench.py
compare(modules, input_shape, *, seq_lens, batch, n_warmup=3, n_iters=20, backward=True, run_fn=None, key=None, dtype=jnp.float32)
Sweep seq_lens x modules and return one result per combination.
Passing thunks (zero-arg builders) as the dict values is recommended so each sweep point gets a fresh module instance. The results are ordered seq_len-outer, module-inner.
:param modules: mapping of label -> module or thunk.
:param input_shape: per-timestep feature shape.
:param seq_lens: list of sequence lengths to sweep.
:param batch: batch size shared across the sweep.
:return: flat list of :class:BenchResult.
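The sweep order and the reason for passing thunks can be sketched in plain Python. `sweep`, `bench_one` and `DummyModule` are hypothetical names, and treating any plain function or lambda as a thunk is an assumption about how the dict values are told apart.

```python
import itertools
import types

_instance_ids = itertools.count()


class DummyModule:
    """Hypothetical stand-in for a module; records which instance it is."""

    def __init__(self):
        self.instance_id = next(_instance_ids)


def sweep(modules, seq_lens, bench_one):
    """Call bench_one(label, module, seq_len) for every combination."""
    results = []
    for seq_len in seq_lens:                    # outer loop: sequence length
        for label, entry in modules.items():    # inner loop: module
            # Assumption: a plain function or lambda is a thunk; anything
            # else is an already-built module and is reused as-is.
            module = entry() if isinstance(entry, types.FunctionType) else entry
            results.append(bench_one(label, module, seq_len))
    return results


results = sweep(
    {"a": lambda: DummyModule(), "b": lambda: DummyModule()},
    seq_lens=[16, 64],
    bench_one=lambda label, module, seq_len: (seq_len, label, module.instance_id),
)
print(results)
```

Because each thunk is called once per sweep point, every entry in the result carries a distinct instance id, so state left over from one sequence length cannot leak into the next.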
Source code in spyx/bench.py
format_table(results)
Pretty-print benchmark results as an aligned plain-text table.
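For illustration only, an aligned plain-text table can be produced along the following lines. `render_table` and the column names are hypothetical and do not reflect the actual layout or fields that format_table emits; the sketch also assumes missing values are shown as a dash.

```python
def render_table(rows, columns):
    """Render a list of dict rows as a left-aligned plain-text table."""

    def cell(value):
        if value is None:
            return "-"                  # undetermined measurement
        if isinstance(value, float):
            return f"{value:.3f}"
        return str(value)

    body = [[cell(row.get(col)) for col in columns] for row in rows]
    # Each column is as wide as its longest cell, header included.
    widths = [
        max([len(col)] + [len(r[i]) for r in body])
        for i, col in enumerate(columns)
    ]

    def line(cells):
        return "  ".join(c.ljust(w) for c, w in zip(cells, widths)).rstrip()

    rule = ["-" * w for w in widths]
    return "\n".join([line(columns), line(rule)] + [line(r) for r in body])


columns = ["name", "seq_len", "fwd_ms", "mfu"]   # hypothetical column names
rows = [
    {"name": "a", "seq_len": 16, "fwd_ms": 0.5, "mfu": None},
    {"name": "b", "seq_len": 64, "fwd_ms": 12.25, "mfu": 0.125},
]
table = render_table(rows, columns)
print(table)
```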