spyx.fn
Losses, metrics, and activity regularisers for spiking networks.
Every public function in this module is a factory: it returns a
JIT-compiled callable that takes network outputs (plus targets, for losses
and metrics) and returns a scalar loss or a (score, predictions) tuple.
Signatures at a glance
- Losses: (traces, targets) -> loss, where traces has shape (..., time, classes) and targets has the batch shape (everything before time, without classes).
- Metrics: (traces, targets) -> (score, predictions).
- Regularisers: (spike_pytree) -> loss.
Losses and metrics check argument shapes at trace time and raise ValueError
early if the target/prediction layout doesn't line up; see
_check_traces_vs_targets.
LossFn = Callable[[jax.Array, jax.Array], jax.Array]
Type alias for (traces, targets) -> loss.
MetricFn = Callable[[jax.Array, jax.Array], tuple[jax.Array, jax.Array]]
Type alias for (traces, targets) -> (score, predictions).
RegFn = Callable[[Any], jax.Array]
Type alias for (spike_pytree) -> scalar.
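The aliases are ordinary typing.Callable types. The sketch below re-declares them locally (spyx is not imported) and annotates a hypothetical factory, summed_trace_mse, that follows the same "factory returns a jitted closure" pattern. The loss it computes is illustrative only and is not one of spyx's functions.

```python
from typing import Any, Callable

import jax
import jax.numpy as jnp

# Local copies of the aliases documented above.
LossFn = Callable[[jax.Array, jax.Array], jax.Array]
MetricFn = Callable[[jax.Array, jax.Array], tuple[jax.Array, jax.Array]]
RegFn = Callable[[Any], jax.Array]


def summed_trace_mse(time_axis: int = 1) -> LossFn:
    """Hypothetical factory: configuration in, jitted (traces, targets) -> loss out."""

    def _loss(traces: jax.Array, targets: jax.Array) -> jax.Array:
        logits = jnp.sum(traces, axis=time_axis)             # (batch, classes)
        one_hot = jax.nn.one_hot(targets, logits.shape[-1])  # (batch, classes)
        return jnp.mean((logits - one_hot) ** 2)

    # time_axis is captured by the closure, so it is a static constant for jit.
    return jax.jit(_loss)


loss_fn: LossFn = summed_trace_mse()
print(loss_fn(jnp.zeros((2, 5, 3)), jnp.array([0, 1])))  # about 0.3333 (2 of 6 entries off by 1)
```

Binding configuration at factory time keeps the returned callable's signature fixed, which is what lets every loss share the same LossFn type.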
integral_accuracy(time_axis=1)
Calculate the accuracy of a network's predictions from the membrane-voltage traces of its readout neurons. Intended for use with a Leaky-Integrate neuron model as the final layer.
:param time_axis: temporal axis of the traces.
:return: function that takes traces (the output of the final layer of the SNN) and targets (the integer class label of each sample) and returns the accuracy score and the predictions.
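A minimal sketch of what this metric computes, assuming the "integral" is a sum over time_axis and that each prediction is the readout neuron with the largest integral. integral_accuracy_sketch is a hypothetical name; this is not spyx's source.

```python
import jax
import jax.numpy as jnp


def integral_accuracy_sketch(time_axis: int = 1):
    """Hypothetical re-implementation for illustration only."""

    @jax.jit
    def _accuracy(traces, targets):
        # Integrate (sum) each readout neuron's membrane voltage over time.
        logits = jnp.sum(traces, axis=time_axis)  # (batch, classes)
        # The predicted class is the neuron with the largest integral.
        preds = jnp.argmax(logits, axis=-1)       # (batch,)
        return jnp.mean(preds == targets), preds

    return _accuracy


acc_fn = integral_accuracy_sketch()
traces = jnp.array([
    [[0.0, 1.0], [0.0, 1.0], [0.0, 1.0]],  # sample 0: neuron 1 integrates to 3
    [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]],  # sample 1: neuron 0 integrates to 3
])
targets = jnp.array([1, 1])
acc, preds = acc_fn(traces, targets)
print(acc, preds)  # accuracy 0.5, predictions [1, 0]
```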
integral_crossentropy(smoothing=0.3, time_axis=1)
Calculate the cross-entropy loss, treating the time integral of each readout neuron's membrane potential as a logit for its class. Allows for label smoothing to discourage the network from silencing the other neurons in the readout layer.
:param smoothing: rate at which to smooth labels.
:param time_axis: temporal axis of the data.
:return: cross-entropy loss function that takes SNN output traces and integer class labels.
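A sketch of the loss under stated assumptions: the time integral is a sum over time_axis, the integrals are used as softmax logits, and label smoothing uses the same formula as optax.smooth_labels, (1 - alpha) * one_hot + alpha / num_classes, written out here so the example needs only JAX. integral_crossentropy_sketch is hypothetical and is not spyx's implementation.

```python
import jax
import jax.numpy as jnp


def integral_crossentropy_sketch(smoothing: float = 0.3, time_axis: int = 1):
    """Hypothetical illustration of an integral cross-entropy loss."""

    @jax.jit
    def _loss(traces, targets):
        logits = jnp.sum(traces, axis=time_axis)  # (batch, classes)
        num_classes = logits.shape[-1]
        one_hot = jax.nn.one_hot(targets, num_classes)
        # Move `smoothing` of the probability mass uniformly onto all classes,
        # so non-target readout neurons keep a small positive target.
        labels = one_hot * (1.0 - smoothing) + smoothing / num_classes
        log_probs = jax.nn.log_softmax(logits, axis=-1)
        return jnp.mean(-jnp.sum(labels * log_probs, axis=-1))

    return _loss


loss_fn = integral_crossentropy_sketch(smoothing=0.3)
silent = jnp.zeros((4, 20, 5))  # (batch, time, classes), all-zero voltages
labels = jnp.zeros((4,), dtype=jnp.int32)
print(loss_fn(silent, labels))  # log(5): uniform softmax over 5 classes
```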
mse_spikerate(sparsity=0.25, smoothing=0.0, time_axis=1)
Calculate the mean squared error between each readout neuron's mean spike rate and its target rate. Allows for label smoothing to discourage the network from silencing the other neurons in the readout layer.
:param sparsity: the fraction of time steps on which you want the neurons to spike (e.g. 0.25).
:param smoothing: [optional] rate at which to smooth labels.
:param time_axis: temporal axis of the data.
:return: mean-squared-error loss function on the spike rate that takes SNN output traces and integer class labels.
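One plausible reading of this loss, written as a hypothetical sketch: the mean over time_axis gives each readout neuron's spike rate, and the correct class is assumed to target a rate of sparsity while the other classes target 0 (before smoothing). The exact target construction inside spyx may differ; mse_spikerate_sketch is not its source.

```python
import jax
import jax.numpy as jnp


def mse_spikerate_sketch(sparsity: float = 0.25, smoothing: float = 0.0, time_axis: int = 1):
    """Hypothetical illustration of a spike-rate MSE loss."""

    @jax.jit
    def _loss(traces, targets):
        # Mean spike rate per readout neuron: fraction of time steps with a spike.
        rates = jnp.mean(traces, axis=time_axis)  # (batch, classes)
        num_classes = rates.shape[-1]
        one_hot = jax.nn.one_hot(targets, num_classes)
        labels = one_hot * (1.0 - smoothing) + smoothing / num_classes
        # Assumption: target rate is `sparsity` for the correct class, 0 otherwise.
        target_rates = labels * sparsity
        return jnp.mean((rates - target_rates) ** 2)

    return _loss


loss_fn = mse_spikerate_sketch(sparsity=0.25)
on_target = jnp.array([[[1.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]])  # (1, 4, 2)
silent = jnp.zeros((1, 4, 2))
labels = jnp.array([0])
print(loss_fn(on_target, labels))  # 0.0: neuron 0 spikes on 1 of 4 steps
print(loss_fn(silent, labels))     # 0.03125 = (0.25**2 + 0**2) / 2
```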
silence_reg(min_spikes)
Per-neuron L2 (quadratic) activity regulariser that penalises neurons spiking fewer than a target number of times.
:param min_spikes: neurons that spike fewer than this many times on average over the batch incur a quadratic penalty.
:return: JIT-compiled regularization function.
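A hypothetical sketch of the penalty, assuming every leaf of the spike pytree has shape (batch, time, neurons...) and that per-neuron penalties are summed across neurons and layers. silence_reg_sketch is not spyx's source; it only illustrates the "quadratic penalty below a minimum spike count" idea.

```python
import jax
import jax.numpy as jnp


def silence_reg_sketch(min_spikes: float):
    """Hypothetical illustration; leaves assumed shaped (batch, time, neurons...)."""

    def _layer_penalty(spikes):
        counts = jnp.sum(spikes, axis=1)       # spikes per neuron per sample
        avg_counts = jnp.mean(counts, axis=0)  # average over the batch
        shortfall = jax.nn.relu(min_spikes - avg_counts)
        return jnp.sum(shortfall ** 2)         # quadratic penalty, summed over neurons

    @jax.jit
    def _reg(spike_pytree):
        penalties = jax.tree_util.tree_map(_layer_penalty, spike_pytree)
        return sum(jax.tree_util.tree_leaves(penalties))

    return _reg


reg = silence_reg_sketch(min_spikes=2.0)
spikes = {
    "hidden": jnp.zeros((2, 5, 3)),  # silent: each of 3 neurons is 2 spikes short
    "readout": jnp.ones((2, 5, 4)),  # 5 spikes per neuron: no penalty
}
print(reg(spikes))  # 12.0 = 3 neurons * 2**2
```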
sparsity_reg(max_spikes, norm=optax.huber_loss)
Layer activity regulariser that penalises a high average firing rate across all neurons in a layer.
:param max_spikes: threshold on the layer's average number of spikes; a penalty is incurred when the average exceeds it.
:param norm: an Optax loss function used to compute the penalty. Default is Huber loss.
:return: JIT-compiled regularization function.
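A hypothetical sketch, assuming the same (batch, time, neurons...) leaf layout and that the layer average is taken over the batch and all neurons. The local huber function mirrors the piecewise form of optax.huber_loss with delta=1.0, so the sketch runs without optax. Any elementwise callable can be passed as norm to swap the penalty. sparsity_reg_sketch is not spyx's source.

```python
import jax
import jax.numpy as jnp


def huber(x, delta: float = 1.0):
    # Quadratic for |x| <= delta, linear beyond (same form as optax.huber_loss).
    abs_x = jnp.abs(x)
    quadratic = jnp.minimum(abs_x, delta)
    linear = abs_x - quadratic
    return 0.5 * quadratic ** 2 + delta * linear


def sparsity_reg_sketch(max_spikes: float, norm=huber):
    """Hypothetical illustration; leaves assumed shaped (batch, time, neurons...)."""

    def _layer_penalty(spikes):
        counts = jnp.sum(spikes, axis=1)  # spikes per neuron per sample
        avg = jnp.mean(counts)            # layer-wide average spike count
        excess = jax.nn.relu(avg - max_spikes)
        return norm(excess)               # zero when the layer is sparse enough

    @jax.jit
    def _reg(spike_pytree):
        penalties = jax.tree_util.tree_map(_layer_penalty, spike_pytree)
        return sum(jax.tree_util.tree_leaves(penalties))

    return _reg


spikes = {"hidden": jnp.ones((2, 10, 4))}             # every neuron spikes 10 times
print(sparsity_reg_sketch(max_spikes=7.5)(spikes))   # 2.0: huber(2.5) = 0.5 + 1.5
print(sparsity_reg_sketch(max_spikes=20.0)(spikes))  # 0.0: below the threshold
```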