How to quantize a model

To run quantization-aware training (QAT) on a Spyx SNN, use spyx.quant — a thin SNN-aware wrapper around Google's qwix library.

Prerequisite: install qwix

spyx.quant is built on qwix, which has no PyPI release. Because uv sources aren't transitive, the spyx[quant] extra only auto-resolves qwix inside the Spyx repo; in your own project install qwix from GitHub directly. This works with both uv and pip:

uv add spyx "qwix @ git+https://github.com/google/qwix"
pip install spyx "qwix @ git+https://github.com/google/qwix"

Gate any quantization code on availability — import spyx.quant is always safe, and the helpers raise ImportError with these install instructions if you call them without qwix:

import spyx

if not spyx.quant.available():
    raise SystemExit(
        "quantization needs qwix: "
        'pip install "qwix @ git+https://github.com/google/qwix"'
    )
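
How spyx.quant.available() works internally is not shown on this page. As a rough sketch of the same kind of check using only the standard library, the following tests whether a package can be imported without importing it. The helper name has_module and the constant QWIX_INSTALL_HINT are hypothetical and not part of Spyx or qwix.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if `name` can be imported, without importing it.

    Hypothetical helper for illustration; not part of Spyx or qwix.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        # find_spec can raise for a missing parent package or for a module
        # already in sys.modules whose __spec__ is None.
        return False


QWIX_INSTALL_HINT = 'pip install "qwix @ git+https://github.com/google/qwix"'

if not has_module("qwix"):
    print(f"qwix not found; install it with: {QWIX_INSTALL_HINT}")
```

A check like this is useful in test suites or CLI entry points that should degrade gracefully instead of failing at import time.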

Quantize with the int8 defaults

To quantize a model for QAT, call spyx.quant.quantize with example inputs matching your model's __call__ signature (qwix traces the module graph to find the layers):

import jax.numpy as jnp
from flax import nnx
import spyx
import spyx.nn as snn

rngs = nnx.Rngs(0)
model = snn.Sequential(
    nnx.Linear(128, 64, use_bias=False, rngs=rngs),
    snn.LIF((64,), rngs=rngs),
    nnx.Linear(64, 20, use_bias=False, rngs=rngs),
    snn.LI((20,), rngs=rngs),
)

B = 32
sample_x = jnp.zeros((B, 128))               # one timestep of input
sample_state = model.initial_state(B)
qmodel = spyx.quant.quantize(model, sample_x, sample_state)

By default this applies int8 weights + activations to nnx.Linear / nnx.Conv layers only. The spiking dynamics (LIF, CuBaLIF, ALIF, IF) and the LI readout stay in fp32 — their state recurrences (V = beta * V + x - reset) involve cancellations that integer rounding tends to collapse into silence.
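
To see why rounding the membrane state can silence a neuron, here is a small pure-Python sketch of the recurrence above. It is not Spyx code: count_spikes is a hypothetical helper, the reset is assumed to be by subtraction, and the grid step is deliberately coarse so the effect shows up within a few steps.

```python
def count_spikes(beta, x, steps, threshold=1.0, grid_step=None):
    """Simulate V = beta * V + x - reset and count output spikes.

    If grid_step is given, V is rounded to the nearest multiple of it after
    every update, mimicking a state stored on a coarse integer grid.
    """
    v, spikes = 0.0, 0
    for _ in range(steps):
        v = beta * v + x
        if grid_step is not None:
            v = round(v / grid_step) * grid_step
        if v >= threshold:
            spikes += 1
            v -= threshold  # reset by subtraction (an assumption of this sketch)
    return spikes


fp_spikes = count_spikes(beta=0.9, x=0.15, steps=50)
q_spikes = count_spikes(beta=0.9, x=0.15, steps=50, grid_step=0.5)
print("float state:", fp_spikes, "spikes")
print("rounded state:", q_spikes, "spikes")
```

In the rounded run each update (0.15) is smaller than half the grid step (0.25), so V rounds back to zero every step and never reaches threshold, while the float run accumulates the same input and fires. A real int8 grid is finer, but the same stall occurs whenever the per-step input falls below half a grid step.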

The returned qmodel is an ordinary NNX module: train it with spyx.optimize.fit or a hand-rolled nnx.Optimizer loop exactly as in How to train a model.

Choose a different precision with rules

To override the defaults, pass a list of qwix QuantizationRules via rules=. Spyx ships three shorthand factories:

# int4 weights + int8 activations on Linear / Conv:
rules = spyx.quant.linear_only_rules(weight_qtype="int4", act_qtype="int8")

# weights-only int8 (activations stay fp32) — for memory-bound deployment:
rules = spyx.quant.weights_only_rules("int8")

# BitNet b1.58-style ternary weights + int8 activations:
rules = spyx.quant.bitnet_ternary_rules()

qmodel = spyx.quant.quantize(model, sample_x, sample_state, rules=rules)

About the BitNet 'ternary' rules

Qwix doesn't expose a true ternary qtype today, so bitnet_ternary_rules falls back to "int2" (values in {-2, -1, 0, 1}). That gives the same memory profile and storage class as ternary; for strict {-1, 0, +1} semantics you'd need a custom qwix.QuantizationRule with a hand-rolled calibration. Pass act_qtype=None for pure weights-only ternary.
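
For reference, the strict {-1, 0, +1} scheme described in the BitNet b1.58 paper scales each weight tensor by its mean absolute value, then rounds and clips. The sketch below shows only that arithmetic in plain Python. absmean_ternary is a hypothetical helper, not a Spyx or qwix function, and wiring it into a qwix rule is not covered here.

```python
def absmean_ternary(weights, eps=1e-8):
    """Absmean ternary quantization: codes in {-1, 0, +1} plus one scale.

    Illustrative only; this is not what bitnet_ternary_rules does today.
    """
    scale = sum(abs(w) for w in weights) / len(weights) + eps
    codes = [max(-1, min(1, round(w / scale))) for w in weights]
    return codes, scale


weights = [0.8, -0.05, 0.3, -1.2, 0.02, -0.4]
codes, scale = absmean_ternary(weights)
dequantized = [c * scale for c in codes]

print(codes)
# Every ternary code also fits the int2 range {-2, -1, 0, 1},
# which is why the storage cost is the same.
print(all(-2 <= c <= 1 for c in codes))
```

The difference from the int2 fallback is only in which grid points are allowed: ternary never uses the -2 code.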

For anything the shorthands don't cover, build qwix.QuantizationRule instances directly — module_path is a regex over module names, e.g. r".*Linear.*".
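
Before handing a pattern to a rule, it can help to check which names it selects. The sketch below uses Python's re module on a list of made-up module names. Both the names and the use of re.fullmatch are assumptions for illustration, so print your own model's module names and check how qwix applies the pattern before relying on it.

```python
import re

# Hypothetical module names, for illustration only.
names = [
    "encoder/Linear_0",
    "encoder/LIF_0",
    "readout/Linear_1",
    "readout/LI_0",
]


def select(pattern, candidates):
    """Return the candidates that the whole pattern matches."""
    return [n for n in candidates if re.fullmatch(pattern, n)]


linear_layers = select(r".*Linear.*", names)
readout_only = select(r"readout/.*", names)

print(linear_layers)
print(readout_only)
```

Note that the second pattern also picks up the LI readout, which the defaults keep in fp32, so a broad pattern can quantize more than intended.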

Post-training quantization

To quantize an already-trained model without further training, pass mode="ptq":

qmodel = spyx.quant.quantize(model, sample_x, sample_state, mode="ptq")
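
Conceptually, PTQ rounds the trained weights onto the integer grid once, with no further gradient updates to compensate. The plain-Python sketch below shows that round trip under the assumption of symmetric per-tensor absmax scaling. The helper names are hypothetical and qwix's actual calibration may differ.

```python
def quantize_int8(weights):
    """Symmetric per-tensor int8: scale so that max |w| maps to 127."""
    absmax = max(abs(w) for w in weights)
    scale = absmax / 127 if absmax > 0 else 1.0
    codes = [max(-127, min(127, round(w / scale))) for w in weights]
    return codes, scale


def dequantize(codes, scale):
    """Map integer codes back to approximate float weights."""
    return [c * scale for c in codes]


weights = [0.5, -1.27, 0.003, 0.9]
codes, scale = quantize_int8(weights)
restored = dequantize(codes, scale)
max_err = max(abs(w - r) for w, r in zip(weights, restored))

print("codes:", codes)
print("max round-trip error:", max_err, "(at most half a grid step:", scale / 2, ")")
```

Small weights such as 0.003 round to zero here. QAT lets the optimizer adapt to that kind of error, whereas PTQ has to accept it.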

A typical QAT workflow:

1. Train the fp32 model to convergence.
2. Wrap it with quantize(...) (QAT mode) and fine-tune for a few epochs.
3. Compare integral_accuracy between the fp32 and quantized models before deployment.

For a full worked example, see the Quantization-Aware Training notebook, and scripts/ssm_demo.py for quantizing the linear layers around an SSM.