Quantization-Aware Training for Spyx SNNs
This tutorial wires Google's qwix JAX quantization library into Spyx so you can train a spiking network with int8 weights/activations end-to-end.
Required extra: qwix has no PyPI release, and uv sources aren't transitive, so install it from GitHub directly (works with both uv and pip):
uv add spyx "qwix @ git+https://github.com/google/qwix"
pip install spyx "qwix @ git+https://github.com/google/qwix"
Why a wrapper?
qwix.quantize_model happily quantizes any Flax NNX module, but on a spiking network the right defaults are non-obvious:
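- The dense nnx.Linear / nnx.Conv layers benefit from int8 weights and activations.
- The spiking dynamics (LIF, CuBaLIF, ALIF, IF) and the leaky readout (LI) are very sensitive to integer rounding because they recurse on the membrane potential. A rounding error made at one timestep is carried into every later step, so these layers are better left in fp32.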
spyx.quant.linear_only_rules() encodes this default. qwix matches a rule's module_path against the NNX attribute path (e.g. core/layers/0), never against the class name. So instead of a .*Linear.* regex, the rules select the dense and conv ops (dot_general, conv_general_dilated). The elementwise neuron updates use neither op, so they stay in fp32 automatically.
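The path-versus-op distinction can be made concrete with a toy model of the selection logic in plain Python. This is a simplified illustration, not qwix's implementation. The (module_path, op_name) pairs are hypothetical stand-ins for what a trace of the classifier built below might record, and the neuron op names (mul, add) are assumed for illustration only.

```python
import re

# Ops that the default rules are described as selecting.
DENSE_OPS = {"dot_general", "conv_general_dilated"}


def rule_selects(module_path, op_name, path_regex=".*", op_names=DENSE_OPS):
    """Toy rule check: the path regex must fullmatch AND the op must be listed."""
    return re.fullmatch(path_regex, module_path) is not None and op_name in op_names


# Hypothetical trace records for a Sequential of
# [Linear, LIF, Linear, LIF, Linear, LI] stored under core/layers/<i>.
traced_ops = [
    ("core/layers/0", "dot_general"),  # nnx.Linear
    ("core/layers/1", "mul"),          # LIF leak (elementwise, assumed op name)
    ("core/layers/1", "add"),          # LIF integration (elementwise)
    ("core/layers/2", "dot_general"),  # nnx.Linear
    ("core/layers/3", "mul"),          # LIF leak
    ("core/layers/4", "dot_general"),  # readout nnx.Linear
    ("core/layers/5", "mul"),          # LI readout (elementwise)
]

by_op = [p for p, op in traced_ops if rule_selects(p, op)]
by_class_name = [
    p for p, op in traced_ops if rule_selects(p, op, path_regex=".*Linear.*")
]

print(by_op)          # ['core/layers/0', 'core/layers/2', 'core/layers/4']
print(by_class_name)  # [] because class names never appear in the attribute path
```

Filtering on the op name keeps the rule independent of how the layers happen to be named. A path regex is then only needed to narrow the selection further, as in the readout-only example later in this tutorial.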
import jax
import jax.numpy as jnp
import optax
from flax import nnx
import spyx
import spyx.nn as snn
assert spyx.quant.available(), (
    "qwix is not installed. Install it from GitHub with "
    "`pip install \"qwix @ git+https://github.com/google/qwix\"`."
)
Build a small SNN
We use a tiny network for demonstration; the same pattern scales to the full SHD trainer in docs/examples/surrogate_gradient/SurrogateGradientTutorial.ipynb.
import spyx.nn as snn
T, IN_DIM, HIDDEN, N_CLASSES = 16, 32, 64, 10
BATCH = 8
class SNNClassifier(nnx.Module):
    """A tiny LIF SNN driven over ``T`` timesteps.

    ``snn.run`` scans the ``Sequential`` core over the time axis so the
    neurons actually integrate input and spike; the readout is the leaky-
    integrator trace with shape ``(B, T, N_CLASSES)``.
    """

    def __init__(self, seed=0):
        rngs = nnx.Rngs(seed)
        self.core = snn.Sequential(
            nnx.Linear(IN_DIM, HIDDEN, use_bias=False, rngs=rngs),
            snn.LIF((HIDDEN,), activation=spyx.axn.triangular(), rngs=rngs),
            nnx.Linear(HIDDEN, HIDDEN, use_bias=False, rngs=rngs),
            snn.LIF((HIDDEN,), activation=spyx.axn.triangular(), rngs=rngs),
            nnx.Linear(HIDDEN, N_CLASSES, use_bias=False, rngs=rngs),
            snn.LI((N_CLASSES,), rngs=rngs),
        )

    def __call__(self, x_TBC):
        traces, _ = snn.run(self.core, x_TBC)  # (T, B, N_CLASSES)
        return jnp.transpose(traces, (1, 0, 2))  # (B, T, N_CLASSES)


def make_model(seed=0):
    return SNNClassifier(seed)
# Poisson-style spike trains (T, B, IN) so the LIF layers fire.
fp_model = make_model()
sample_x = (
    jax.random.uniform(jax.random.PRNGKey(0), (T, BATCH, IN_DIM)) < 0.2
).astype(jnp.float32)
fp_out = fp_model(sample_x)
print("fp32 output shape:", fp_out.shape) # (B, T, N_CLASSES)
fp32 output shape: (8, 16, 10)
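The docstring above says that snn.run scans the core over the time axis. The "Why a wrapper?" list notes that the neurons recurse on their membrane potential. A minimal LIF recurrence written with jax.lax.scan shows what that loop carries from one step to the next. This is a simplified sketch, not spyx's snn.LIF. The decay beta, the threshold, the hard Heaviside spike and the reset-by-subtraction rule are all assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def lif_scan(currents_TN, beta=0.9, threshold=1.0):
    """Minimal LIF neuron layer over time (illustrative, not spyx's snn.LIF)."""

    def step(v, i_t):
        v = beta * v + i_t                       # leak, then integrate input
        spike = (v > threshold).astype(v.dtype)  # fire when above threshold
        v = v - spike * threshold                # reset by subtraction
        return v, spike                          # carry v, emit the spike

    v0 = jnp.zeros(currents_TN.shape[1:])
    v_final, spikes_TN = jax.lax.scan(step, v0, currents_TN)
    return spikes_TN, v_final


# Constant input current of 0.3 for 16 timesteps into 4 neurons.
currents = jnp.full((16, 4), 0.3)
spikes, v_last = lif_scan(currents)
print(spikes.shape, spikes.sum(axis=0))  # (16, 4) [4. 4. 4. 4.]
```

The membrane potential v is the scan carry, so any rounding applied to it would feed into every later step. This is why the default rules keep the neuron updates in fp32.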
Wrap with spyx.quant.quantize
quantize traces the model with the example inputs so qwix can discover which ops run in which modules. It then returns a new nnx.Module in which the matmuls of the Linear layers run with int8 weights and activations. The default mode is qat (quantization-aware training), which simulates int8 by quantizing and immediately dequantizing in the forward pass. Pass mode="ptq" for post-training quantization.
qmodel = spyx.quant.quantize(fp_model, sample_x)
q_out = qmodel(sample_x)
print("quantized output shape:", q_out.shape)
print("max abs diff vs fp32:", float(jnp.max(jnp.abs(q_out - fp_out))))
quantized output shape: (8, 16, 10)
max abs diff vs fp32: 0.2541166841983795
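The nonzero difference comes from rounding values onto the int8 grid. A minimal symmetric, per-tensor "absmax" quantize/dequantize round trip shows where that error comes from. This is a sketch under that assumption, and qwix's actual calibration and scale granularity may differ. Each element's round-trip error is bounded by half a quantization step.

```python
import jax
import jax.numpy as jnp


def fake_quant_int8(x):
    """Symmetric per-tensor int8 quantize -> dequantize with absmax scaling."""
    scale = jnp.max(jnp.abs(x)) / 127.0        # size of one int8 step
    scale = jnp.where(scale == 0, 1.0, scale)  # guard an all-zero tensor
    q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
    return q.astype(x.dtype) * scale, q, scale  # dequantized, int8 codes, scale


w = jax.random.normal(jax.random.PRNGKey(0), (32, 64))
w_dq, w_q, scale = fake_quant_int8(w)
max_err = jnp.max(jnp.abs(w - w_dq))
print(w_q.dtype, float(max_err), float(scale) / 2)  # max_err <= scale / 2
```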
Train through the quantized model
The QAT model drops straight into the standard NNX training loop: nnx.Optimizer(qmodel, optax.lion(...), wrt=nnx.Param) plus nnx.value_and_grad. Rounding has zero gradient almost everywhere. qwix therefore uses a straight-through estimator that treats the quantize/dequantize pair as the identity in the backward pass, so the loss still yields useful gradients for the parameters.
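A straight-through estimator can be sketched in a few lines of JAX with the common stop_gradient trick. The forward pass sees the rounded value, while the backward pass sees the identity. This shows the general technique only. qwix's internal implementation (for example, how it treats values clipped at the edge of the range) is not reproduced here, and the scale of 0.01 is an arbitrary assumption.

```python
import jax
import jax.numpy as jnp


def round_ste(x):
    # Forward value: round(x). Gradient: identity, because the correction
    # term (round(x) - x) is wrapped in stop_gradient and contributes nothing.
    return x + jax.lax.stop_gradient(jnp.round(x) - x)


def fake_quant(x, scale, round_fn):
    # Quantize to the int8 range with the given rounding function, then dequantize.
    return jnp.clip(round_fn(x / scale), -127, 127) * scale


x = jnp.array([0.123, -0.5, 0.9])
scale = 0.01

grad_plain = jax.grad(lambda v: jnp.sum(fake_quant(v, scale, jnp.round)))(x)
grad_ste = jax.grad(lambda v: jnp.sum(fake_quant(v, scale, round_ste)))(x)

print(grad_plain)  # all zeros: plain rounding blocks the gradient
print(grad_ste)    # all ones: the STE passes the gradient straight through
```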
Loss = spyx.fn.integral_crossentropy()
optimizer = nnx.Optimizer(qmodel, optax.lion(1e-3), wrt=nnx.Param)
# One fixed batch to overfit, so you can watch the QAT loss drop.
x = (
jax.random.uniform(jax.random.PRNGKey(1), (T, BATCH, IN_DIM)) < 0.2
).astype(jnp.float32)
targets = jax.random.randint(jax.random.PRNGKey(2), (BATCH,), 0, N_CLASSES)
@nnx.jit
def train_step(model, optimizer, x, targets):
    def loss_fn(m):
        return Loss(m(x), targets)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    return loss


for step in range(50):
    loss = float(train_step(qmodel, optimizer, x, targets))
    if step % 10 == 0:
        print(f"step {step:2d}: loss={loss:.4f}")
step  0: loss=2.3714
step 10: loss=2.1373
step 20: loss=1.9673
step 30: loss=1.9073
step 40: loss=1.8112
Customising the quantization rules
spyx.quant.linear_only_rules and its sibling spyx.quant.weights_only_rules are thin shorthands for lists of qwix rules. Tweak their dtype arguments, or build qwix.QuantizationRule objects directly for finer control.
import qwix
# int4 weights, fp32 activations - good for memory-bound deployment.
rules = spyx.quant.weights_only_rules(weight_qtype="int4")
qmodel_int4 = spyx.quant.quantize(make_model(), sample_x, rules=rules)
print(
    "int4 weights, max diff vs fp32:",
    float(jnp.max(jnp.abs(qmodel_int4(sample_x) - fp_out))),
)
# Custom: quantize only the readout Linear (path core/layers/4), leaving
# the hidden layers in fp32. module_path is matched (re.fullmatch) against
# the NNX attribute path; op_names narrows to the matmul in that scope.
custom_rules = [
    qwix.QuantizationRule(
        module_path=r".*layers/4",
        op_names=("dot_general",),
        weight_qtype="int8",
        act_qtype="int8",
    ),
]
qmodel_readout = spyx.quant.quantize(make_model(), sample_x, rules=custom_rules)
print(
    "readout-only quant, max diff vs fp32:",
    float(jnp.max(jnp.abs(qmodel_readout(sample_x) - fp_out))),
)
int4 weights, max diff vs fp32: 0.33079519867897034
readout-only quant, max diff vs fp32: 0.0026528537273406982
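The int4 configuration above changes both the bit width and whether activations are quantized. The sketch below compares the two settings on a single random matmul using the same simplified symmetric per-tensor scheme: int4 weights with fp32 activations versus int8 weights with int8 activations. The scheme and shapes are assumptions, and qwix's real scales and granularity may differ. A symmetric int4 grid has only 7 positive levels versus 127 for int8, so its step is roughly 18 times coarser.

```python
import jax
import jax.numpy as jnp


def fake_quant(x, bits):
    """Symmetric per-tensor quantize -> dequantize on a signed `bits`-bit grid."""
    qmax = 2 ** (bits - 1) - 1  # 7 for int4, 127 for int8 (symmetric, no -8/-128)
    scale = jnp.max(jnp.abs(x)) / qmax
    return jnp.clip(jnp.round(x / scale), -qmax, qmax) * scale


kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (8, 64))   # activations
w = jax.random.normal(kw, (64, 64))  # weights
y_ref = x @ w

y_w4 = x @ fake_quant(w, 4)                   # weights-only int4
y_w8a8 = fake_quant(x, 8) @ fake_quant(w, 8)  # int8 weights + int8 activations

err_w4 = float(jnp.max(jnp.abs(y_w4 - y_ref)))
err_w8a8 = float(jnp.max(jnp.abs(y_w8a8 - y_ref)))
print(f"w4 (fp32 acts): {err_w4:.4f}   w8a8: {err_w8a8:.4f}")
```

On a plain matmul, the coarser int4 weight grid usually dominates the error even though the activations stay in fp32. Weights-only int4 is therefore mainly a memory saving rather than an accuracy-neutral choice.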
Notes
- For full QAT runs on real datasets, plug spyx.quant.quantize into any of the SHD tutorials right after model construction. The training loop is unchanged.
- Qwix supports int1 through int8, nf4, and fp8. The native fast paths are int4, int8, and fp8 on TPU/H100; everything else is emulated.
- For pure mixed precision (bf16/fp16) without integer quantization, qwix is overkill. Set jax.config.update("jax_default_matmul_precision", "bfloat16") instead, or construct the layers with a low-precision computation dtype (e.g. nnx.Linear(..., dtype=jnp.bfloat16)).
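Both low-precision routes from the last note can be sketched in plain JAX with arbitrary shapes. The scoped jax.default_matmul_precision context manager is used here in place of the global jax.config.update call. The precision setting only changes how XLA executes fp32 matmuls on accelerators and typically has no effect on CPU. An explicit cast, by contrast, changes the array dtype itself.

```python
import jax
import jax.numpy as jnp

x = jax.random.normal(jax.random.PRNGKey(0), (8, 32))
w = jax.random.normal(jax.random.PRNGKey(1), (32, 64))

# Route 1: arrays stay float32, but matmuls in this scope may use bf16 passes.
with jax.default_matmul_precision("bfloat16"):
    y_prec = x @ w

# Route 2: cast the operands, so the computation and the result are bfloat16.
y_cast = x.astype(jnp.bfloat16) @ w.astype(jnp.bfloat16)

print(y_prec.dtype, y_cast.dtype)  # float32 bfloat16
```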