How to quantize a model
To run quantization-aware training (QAT) on a Spyx SNN, use spyx.quant — a thin SNN-aware wrapper around Google's qwix library.
Prerequisite: install qwix
spyx.quant is built on qwix, which has no PyPI release. Because uv sources aren't transitive, the spyx[quant] extra only auto-resolves qwix inside the Spyx repo; in your own project, install qwix from GitHub directly. This works with both uv and pip:
uv add spyx "qwix @ git+https://github.com/google/qwix"
pip install spyx "qwix @ git+https://github.com/google/qwix"
Gate any quantization code on availability — import spyx.quant is always safe, and the helpers raise ImportError with these install instructions if you call them without qwix:
import spyx

# Stop early with install instructions if qwix is missing.
if not spyx.quant.available():
    raise SystemExit(
        "quantization needs qwix: "
        'pip install "qwix @ git+https://github.com/google/qwix"'
    )
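For context, an availability check of this kind is commonly implemented by probing for the optional package without importing it. The following is a minimal sketch of that pattern; the helper name has_optional_dependency is hypothetical and not part of spyx, and nothing here claims to mirror how spyx.quant.available() is actually written.

```python
import importlib.util


def has_optional_dependency(module_name: str) -> bool:
    """Return True if a top-level module can be found, without importing it."""
    # find_spec returns None when a top-level module is not installed.
    return importlib.util.find_spec(module_name) is not None


# Hypothetical usage: decide up front whether to take the quantized path.
use_quantization = has_optional_dependency("qwix")
print("qwix available:", use_quantization)
```

Probing with find_spec keeps the check cheap and free of import side effects, which is one reason a gate like this can stay safe to call even when the dependency is absent.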
Quantize with the int8 defaults
To quantize a model for QAT, call spyx.quant.quantize with example inputs matching your model's __call__ signature (qwix traces the module graph to find the layers):
import jax.numpy as jnp
from flax import nnx
import spyx
import spyx.nn as snn
rngs = nnx.Rngs(0)
model = snn.Sequential(
    nnx.Linear(128, 64, use_bias=False, rngs=rngs),
    snn.LIF((64,), rngs=rngs),
    nnx.Linear(64, 20, use_bias=False, rngs=rngs),
    snn.LI((20,), rngs=rngs),
)
B = 32
sample_x = jnp.zeros((B, 128)) # one timestep of input
sample_state = model.initial_state(B)
qmodel = spyx.quant.quantize(model, sample_x, sample_state)
By default this applies int8 weights + activations to nnx.Linear / nnx.Conv layers only. The spiking dynamics (LIF, CuBaLIF, ALIF, IF) and the LI readout stay in fp32 — their state recurrences (V = beta * V + x - reset) involve cancellations that integer rounding tends to collapse into silence.
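To see why the state recurrence is sensitive to rounding, here is a toy illustration in plain Python. It is a sketch under loud assumptions: a single neuron with constant input, subtractive reset, and a deliberately coarse rounding grid (step 0.1) chosen to make the effect visible. It does not reproduce how qwix quantizes anything, and the simulate helper is hypothetical.

```python
def simulate(beta, x, threshold, steps, grid_step=None):
    """Count spikes of one leaky neuron: V = beta * V + x - reset.

    If grid_step is given, V is rounded to that grid after every update,
    mimicking a low-precision membrane state (an illustrative assumption).
    """
    v = 0.0
    spikes = 0
    for _ in range(steps):
        v = beta * v + x
        if grid_step is not None:
            # Round the membrane potential to the nearest representable value.
            v = round(v / grid_step) * grid_step
        if v >= threshold:
            spikes += 1
            v -= threshold  # subtractive reset
    return spikes


float_spikes = simulate(beta=0.9, x=0.04, threshold=0.3, steps=50)
coarse_spikes = simulate(beta=0.9, x=0.04, threshold=0.3, steps=50, grid_step=0.1)

print("full-precision spikes:", float_spikes)
print("rounded-state spikes: ", coarse_spikes)
```

In the rounded run, each small input increment is rounded back to zero before it can accumulate, so the neuron never reaches threshold. The full-precision run integrates the same input and does spike.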
The returned qmodel is an ordinary NNX module: train it with spyx.optimize.fit or a hand-rolled nnx.Optimizer loop exactly as in How to train a model.
Choose a different precision with rules
To override the defaults, pass a list of qwix QuantizationRules via rules=. Spyx ships three shorthand factories:
# int4 weights + int8 activations on Linear / Conv:
rules = spyx.quant.linear_only_rules(weight_qtype="int4", act_qtype="int8")
# weights-only int8 (activations stay fp32) — for memory-bound deployment:
rules = spyx.quant.weights_only_rules("int8")
# BitNet b1.58-style ternary weights + int8 activations:
rules = spyx.quant.bitnet_ternary_rules()
qmodel = spyx.quant.quantize(model, sample_x, sample_state, rules=rules)
About the BitNet 'ternary' rules
Qwix doesn't expose a true ternary qtype today, so bitnet_ternary_rules falls back to "int2" (values in {-2, -1, 0, 1}). That gives the same memory profile and storage class as ternary; for strict {-1, 0, +1} semantics you'd need a custom qwix.QuantizationRule with a hand-rolled calibration. Pass act_qtype=None for pure weights-only ternary.
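The difference between the int2 fallback and strict ternary can be made concrete with a short sketch. The first helper lists the values of a signed two's-complement integer type. The second applies the absmean rounding described in the BitNet b1.58 paper (scale by the mean absolute weight, round, clip to [-1, 1]). Both helper names are hypothetical, and this is not the calibration that qwix or bitnet_ternary_rules performs.

```python
def signed_int_values(bits: int) -> list[int]:
    """All values representable by a signed two's-complement integer."""
    return list(range(-(2 ** (bits - 1)), 2 ** (bits - 1)))


def ternary_absmean(weights: list[float]) -> tuple[list[int], float]:
    """Quantize weights to {-1, 0, +1} using an absmean scale."""
    scale = sum(abs(w) for w in weights) / len(weights)
    scale = max(scale, 1e-8)  # guard against an all-zero tensor
    codes = [max(-1, min(1, round(w / scale))) for w in weights]
    return codes, scale


weights = [0.9, -0.05, 0.4, -1.3, 0.02, -0.6]
codes, scale = ternary_absmean(weights)

print(signed_int_values(2))  # [-2, -1, 0, 1]
print(codes)                 # [1, 0, 1, -1, 0, -1]
```

The int2 grid has one extra negative level (-2) that a strict ternary scheme never uses, which is the gap a custom rule would have to close.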
For anything the shorthands don't cover, build qwix.QuantizationRule instances directly — module_path is a regex over module names, e.g. r".*Linear.*".
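Before handing a pattern to a rule, it can help to check which modules it would select. The sketch below uses Python's re module on a list of made-up module names; the names are hypothetical, and whether qwix applies the pattern as a full match or a search is an assumption here (for a pattern wrapped in .* on both sides the two give the same result).

```python
import re

pattern = re.compile(r".*Linear.*")

# Hypothetical module names, for illustration only.
module_names = [
    "layers/0/Linear",
    "layers/1/LIF",
    "layers/2/Linear",
    "layers/3/LI",
    "encoder/Conv",
]

selected = [name for name in module_names if pattern.fullmatch(name)]
print(selected)  # ['layers/0/Linear', 'layers/2/Linear']
```

Note that the match is case-sensitive, so "LIF" and "LI" are not picked up by a pattern that looks for "Linear".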
Post-training quantization
To quantize an already-trained model without further training, pass mode="ptq" to spyx.quant.quantize.
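Conceptually, post-training quantization converts trained weights to a low-precision grid once, with no gradient updates afterwards. The following plain-Python sketch shows a symmetric, per-tensor absmax int8 round trip. The scheme and the helper names are assumptions made for illustration; qwix's actual calibration may differ.

```python
def quantize_int8(weights):
    """Map floats to int8 codes with a single symmetric absmax scale."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    codes = [max(-127, min(127, round(w / scale))) for w in weights]
    return codes, scale


def dequantize(codes, scale):
    """Recover approximate float weights from integer codes."""
    return [c * scale for c in codes]


trained = [0.82, -0.31, 0.05, -1.27, 0.44]  # stand-in for trained weights
codes, scale = quantize_int8(trained)
restored = dequantize(codes, scale)

max_error = max(abs(a - b) for a, b in zip(trained, restored))
print("codes:", codes)
print("max round-trip error:", max_error)
```

The round-trip error is bounded by half the scale, which is why PTQ often works well at int8 but tends to need fine-tuning (QAT) at lower bit widths.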
The recommended workflow
- Train the fp32 model to convergence.
- Wrap it with quantize(...) (QAT mode) and fine-tune for a few epochs.
- Compare integral_accuracy between the fp32 and quantized models before deployment.
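The comparison in the last step can be sketched as follows. This assumes that an integral accuracy metric sums each readout trace over time and takes the argmax as the prediction. The function integral_accuracy_sketch is a hypothetical stand-in written for this example, with an assumed [batch][time][classes] layout; the real integral_accuracy may use a different signature or axis order. The traces are toy numbers, not results.

```python
def integral_accuracy_sketch(traces, targets):
    """Fraction of samples whose time-summed readout peaks at the target class.

    traces: nested lists shaped [batch][time][classes] (assumed layout).
    targets: list of integer class labels.
    """
    correct = 0
    for trace, target in zip(traces, targets):
        num_classes = len(trace[0])
        totals = [sum(step[c] for step in trace) for c in range(num_classes)]
        predicted = max(range(num_classes), key=totals.__getitem__)
        correct += int(predicted == target)
    return correct / len(targets)


# Toy readouts: 2 samples, 3 timesteps, 2 classes.
fp32_traces = [
    [[0.2, 0.1], [0.6, 0.3], [0.5, 0.2]],
    [[0.1, 0.4], [0.2, 0.5], [0.3, 0.3]],
]
quant_traces = [
    [[0.2, 0.1], [0.5, 0.3], [0.5, 0.2]],
    [[0.3, 0.2], [0.3, 0.3], [0.3, 0.3]],
]
targets = [0, 1]

fp32_acc = integral_accuracy_sketch(fp32_traces, targets)
quant_acc = integral_accuracy_sketch(quant_traces, targets)
accuracy_drop = fp32_acc - quant_acc
print(f"fp32={fp32_acc:.2f} quantized={quant_acc:.2f} drop={accuracy_drop:.2f}")
```

Evaluating both models on the same held-out batches makes the drop directly attributable to quantization rather than to data variation.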
For a full worked example, see the Quantization-Aware Training notebook, and scripts/ssm_demo.py for quantizing the linear layers around an SSM.