spyx.optimize
High-level training loop that wraps nnx.Optimizer + nnx.value_and_grad. Use fit(...) for the common case or make_train_step / make_eval_step to roll your own loop.
High-level training utilities for Spyx SNNs.
Issue #26 asked for a "quick train/eval loop" so users don't have to
re-derive the nnx.Optimizer + nnx.value_and_grad + per-epoch boilerplate
every time they build a new model. This module provides that, with a
minimum of magic:
- make_train_step: builds a JIT-compiled single-step update.
- make_eval_step: builds a JIT-compiled single-step accuracy/loss evaluation.
- fit: an end-to-end Python epoch loop over batches of (events, targets) tuples. The data source (a Spyx loader, a generator, or a plain list) is supplied as a zero-arg callable that returns a fresh iterable each epoch.
The utilities deliberately don't hide the loss, metric, or optimizer choices.
Pass your own, e.g. spyx.fn.integral_crossentropy for the loss and optax.lion for the optimizer.
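The zero-arg-callable convention for data sources matters because a generator object is single-use: after one pass it is exhausted, so a second epoch would silently see no batches. Calling a function once per epoch gives a fresh iterable every time. Below is a minimal, framework-free sketch of that behaviour; make_batches and its toy integer data are hypothetical stand-ins for something like a Spyx loader's train_epoch().

```python
def make_batches():
    """Hypothetical zero-arg callable: each call returns a fresh generator of (events, targets)."""
    data = [([0, 1], 0), ([1, 1], 1), ([1, 0], 1)]
    for events, target in data:
        yield events, target


# A bare generator object is single-use: the second pass yields nothing.
gen = make_batches()
first_pass = list(gen)
second_pass = list(gen)

# Calling the zero-arg callable once per epoch gives a fresh iterable each time.
epoch_counts = [sum(1 for _ in make_batches()) for _epoch in range(2)]

print(len(first_pass), len(second_pass), epoch_counts)  # 3 0 [3, 3]
```

Under this convention a plain list of batches would be passed as lambda: batches, since a list can be re-iterated safely.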
fit(model, tx, loss_fn, train_iter, *, epochs, eval_iter=None, eval_fn=None, on_epoch_end=None)
End-to-end training loop.
:param model: the Spyx / Flax NNX module to train.
:param tx: an Optax GradientTransformation (e.g. optax.lion(3e-4)).
:param loss_fn: (model, *batch) -> loss, where batch is one element of
the iterable returned by train_iter.
:param train_iter: zero-arg callable returning a fresh iterable of
training batches each epoch. This matches the spyx.data.*_loader
convention where loader.train_epoch() is called per epoch.
:param epochs: number of training epochs.
:param eval_iter: optional zero-arg callable returning a fresh iterable of
evaluation batches (same convention as train_iter).
:param eval_fn: optional (model, *batch) -> (accuracy, loss);
required if eval_iter is set.
:param on_epoch_end: optional callback (epoch, metrics_dict) -> None
for progress printing etc. The metrics dict carries the key
train_loss, plus eval_acc and eval_loss when eval_iter is set.
:return: list of per-epoch metric dicts.
Source code in spyx/optimize.py
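To make fit's documented contract concrete (per-epoch calls to train_iter, the eval_fn requirement, the metrics keys, and the returned history), here is a framework-free sketch of the same control flow. It is not Spyx's source: fit_sketch is hypothetical, it takes a ready-made train_step(model, *batch) -> loss in place of the model/tx/loss_fn trio (real fit builds the NNX optimizer and JIT-compiled step itself), and it assumes train_loss, eval_acc and eval_loss are means over the epoch's batches.

```python
from typing import Any, Callable, Iterable, Optional


def fit_sketch(
    model: Any,
    train_step: Callable[..., float],
    train_iter: Callable[[], Iterable[tuple]],
    *,
    epochs: int,
    eval_iter: Optional[Callable[[], Iterable[tuple]]] = None,
    eval_fn: Optional[Callable[..., tuple]] = None,
    on_epoch_end: Optional[Callable[[int, dict], None]] = None,
) -> list:
    """Hypothetical mirror of fit's documented control flow (not the Spyx implementation)."""
    if eval_iter is not None and eval_fn is None:
        raise ValueError("eval_fn is required when eval_iter is set")
    history = []
    for epoch in range(epochs):
        # train_iter() is called once per epoch, so each epoch gets a fresh iterable.
        # Assumption: train_loss is the mean loss over the epoch's batches.
        losses = [train_step(model, *batch) for batch in train_iter()]
        metrics = {"train_loss": sum(losses) / len(losses)}
        if eval_iter is not None:
            results = [eval_fn(model, *batch) for batch in eval_iter()]
            metrics["eval_acc"] = sum(acc for acc, _ in results) / len(results)
            metrics["eval_loss"] = sum(loss for _, loss in results) / len(results)
        if on_epoch_end is not None:
            on_epoch_end(epoch, metrics)
        history.append(metrics)
    return history


# Toy usage: the "model" is a dict and each training step halves its loss in place.
state = {"loss": 1.0}


def toy_step(model, events, targets):
    model["loss"] *= 0.5
    return model["loss"]


def toy_eval(model, events, targets):
    return 1.0 - model["loss"], model["loss"]


batches = [([0, 1], 0), ([1, 0], 1)]
history = fit_sketch(
    state, toy_step, lambda: batches, epochs=2,
    eval_iter=lambda: batches, eval_fn=toy_eval,
    on_epoch_end=lambda epoch, m: print(f"epoch {epoch}: {m}"),
)
```

Failing fast when eval_iter is given without eval_fn mirrors the documented requirement and avoids discovering the mistake only after a full training epoch.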
make_eval_step(metric_fn)
Build a JIT-compiled single-step evaluation callable.
:param metric_fn: closure taking (model, *metric_args) and returning
(accuracy or a similar metric, loss).
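A metric_fn with the documented (model, *metric_args) -> (accuracy, loss) shape might look like the sketch below. Everything here is a hypothetical, framework-free illustration: make_eval_step_sketch only wraps the closure (the real helper also JIT-compiles it), and metric_fn assumes model(events) returns one list of per-class scores per sample, for example output spike counts summed over time.

```python
import math
from typing import Callable


def make_eval_step_sketch(metric_fn: Callable) -> Callable:
    """Hypothetical stand-in for make_eval_step; no JIT compilation here."""
    def eval_step(model, *metric_args):
        return metric_fn(model, *metric_args)
    return eval_step


def metric_fn(model, events, targets):
    """Example closure returning (accuracy, mean softmax cross-entropy).

    Assumes model(events) yields one list of per-class scores per sample.
    """
    scores = model(events)
    correct, total_loss = 0, 0.0
    for row, target in zip(scores, targets):
        pred = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        correct += int(pred == target)
        # Softmax cross-entropy, computed stably via log-sum-exp.
        m = max(row)
        log_z = m + math.log(sum(math.exp(s - m) for s in row))
        total_loss += log_z - row[target]
    n = len(targets)
    return correct / n, total_loss / n


toy_model = lambda events: events  # hypothetical: treats the inputs as class scores
eval_step = make_eval_step_sketch(metric_fn)
acc, loss = eval_step(toy_model, [[3.0, 1.0], [0.0, 2.0]], [0, 0])
```

Here the first sample is classified correctly and the second is not, so acc is 0.5.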
make_train_step(loss_fn)
Build a JIT-compiled single-step updater.
The returned callable has signature (model, optimizer, *loss_args) ->
loss_value and mutates model / optimizer in place via NNX.
:param loss_fn: closure taking (model, *loss_args) and returning a
scalar loss. Typically wraps spyx.fn.integral_crossentropy().
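The factory pattern behind make_train_step (take a loss closure, return a step(model, optimizer, *loss_args) -> loss that updates in place) can be sketched without any framework. This is an illustration under loudly stated substitutions: central finite differences stand in for nnx.value_and_grad, the hypothetical ToySGD class stands in for nnx.Optimizer, ToyModel and mse_loss are invented for the example, and nothing is JIT-compiled.

```python
from typing import Callable, List


class ToyModel:
    """Hypothetical model y = w * x + b, with parameters stored as plain floats."""
    def __init__(self, w: float = 0.0, b: float = 0.0):
        self.params: List[float] = [w, b]

    def __call__(self, x: float) -> float:
        w, b = self.params
        return w * x + b


class ToySGD:
    """Hypothetical stand-in for nnx.Optimizer: mutates model.params in place."""
    def __init__(self, lr: float):
        self.lr = lr

    def update(self, model: ToyModel, grads: List[float]) -> None:
        for i, g in enumerate(grads):
            model.params[i] -= self.lr * g


def make_train_step_sketch(loss_fn: Callable, eps: float = 1e-6) -> Callable:
    """Builds step(model, optimizer, *loss_args) -> loss, updating model in place."""
    def value_and_grad(model, *loss_args):
        loss = loss_fn(model, *loss_args)
        grads = []
        for i in range(len(model.params)):
            orig = model.params[i]
            model.params[i] = orig + eps
            up = loss_fn(model, *loss_args)
            model.params[i] = orig - eps
            down = loss_fn(model, *loss_args)
            model.params[i] = orig  # restore before perturbing the next parameter
            grads.append((up - down) / (2 * eps))  # central difference
        return loss, grads

    def step(model, optimizer, *loss_args):
        loss, grads = value_and_grad(model, *loss_args)
        optimizer.update(model, grads)
        return loss  # loss measured before the update

    return step


def mse_loss(model, xs, ys):
    """Example loss closure with the documented (model, *loss_args) shape."""
    return sum((model(x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


model, opt = ToyModel(), ToySGD(lr=0.1)
train_step = make_train_step_sketch(mse_loss)
xs, ys = [0.0, 1.0, 2.0], [1.0, 3.0, 5.0]  # generated by y = 2x + 1
losses = [train_step(model, opt, xs, ys) for _ in range(200)]
```

After 200 steps the parameters should sit close to w = 2 and b = 1. Returning the pre-update loss matches what a value_and_grad-based step naturally produces.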