spyx.axn
Surrogate-gradient factories that return JIT-compiled jax.custom_gradient functions. Pass the returned callable as the activation argument to any spiking neuron.
Surrogate-gradient activations for spiking neurons.
Each public factory in this module returns a JIT-compiled
jax.custom_gradient function of signature (x: jax.Array) -> jax.Array
suitable for passing to the activation= argument of any neuron in
spyx.nn. The forward pass is always the Heaviside step (spike / no
spike); the factories differ only in the surrogate they expose to the
backward pass.
Activation = Callable[[jax.Array], jax.Array]
Type alias for a surrogate-gradient activation function.
A mapping from a pre-activation tensor to a binary spike tensor of the
same shape. Produced by :func:`custom`, :func:`superspike`,
:func:`arctan`, and the other factories in this module.
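The sketch below shows how a callable matching the `Activation` alias is typically consumed. It assumes, as stated for :func:`custom`, that the calling neuron subtracts its threshold before applying the activation. `step` and `fire` are hypothetical helpers for illustration, not part of spyx.

```python
from typing import Callable

import jax
import jax.numpy as jnp

# Same shape as the documented alias: pre-activation in, spike tensor out.
Activation = Callable[[jax.Array], jax.Array]


def step(x: jax.Array) -> jax.Array:
    # Hypothetical plain Heaviside step with no surrogate attached.
    return (x > 0).astype(x.dtype)


def fire(v: jax.Array, threshold: float, act: Activation) -> jax.Array:
    # Hypothetical neuron helper: subtract the threshold first, so the
    # activation only needs to test the sign of its input.
    return act(v - threshold)


v = jnp.array([0.2, 1.0, 1.5])
spikes = fire(v, 1.0, step)  # same shape as v, values in {0, 1}
```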
arctan(k=2)
Arctangent surrogate gradient for a spiking neuron.
The forward pass is the Heaviside step; the backward pass uses a surrogate derived from the arctangent. Because :math:`\arctan` maps the whole real line onto :math:`(-\pi/2, \pi/2)`, a shifted and scaled arctangent is a smooth approximation of the step function that is differentiable everywhere, as gradient-based optimization requires.
:k: Scaling factor that adjusts the steepness of the arctangent. Default is 2.
:return: JIT-compiled arctangent surrogate gradient function.
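The docstring does not spell out the exact surrogate formula. As a sketch, assuming the surrogate is the derivative of :math:`\arctan(kx)/k`, i.e. :math:`1/(1 + (kx)^2)`, the following compares a hand-written gradient with JAX autodiff of the smooth curve. `arctan_grad` and `smooth_arctan` are illustrative names, and spyx's actual scaling may differ.

```python
import jax
import jax.numpy as jnp


def arctan_grad(x, k=2.0):
    # Assumed surrogate: d/dx [arctan(k*x) / k] = 1 / (1 + (k*x)^2), peak 1 at x = 0.
    return 1.0 / (1.0 + (k * x) ** 2)


def smooth_arctan(x, k=2.0):
    # Smooth step-like curve whose derivative is the assumed surrogate.
    return jnp.arctan(k * x) / k


xs = jnp.linspace(-2.0, 2.0, 9)
# Autodiff of the smooth curve, evaluated pointwise.
auto = jax.vmap(jax.grad(smooth_arctan))(xs)
```

Unlike the boxcar or triangular surrogates, this gradient never reaches zero; it decays like :math:`1/x^2` far from the threshold.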
boxcar(width=2, height=0.5)
Boxcar surrogate gradient.
The forward pass is the Heaviside step; the backward pass uses a
rectangular window of half-width width/2 centred at zero:
.. math:: g(x) = \begin{cases} h & |x| \le w/2 \\ 0 & \text{otherwise} \end{cases}
The boxcar is the simplest symmetric surrogate and has been shown to train competitively on the Spiking Heidelberg Digits (SHD) dataset despite its discontinuity.
:width: Total width :math:`w` of the non-zero gradient window, centred
on zero.
:height: Value :math:`h` of the gradient inside the window.
:return: JIT-compiled boxcar surrogate gradient function.
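A minimal sketch of the boxcar surrogate :math:`g(x)` from the formula above, written directly with `jnp.where`. `boxcar_grad` is an illustrative name, not the spyx internal; the boundary convention (:math:`|x| \le w/2` counts as inside) follows the formula.

```python
import jax.numpy as jnp


def boxcar_grad(x, width=2.0, height=0.5):
    # h inside the window |x| <= w/2, zero outside.
    return jnp.where(jnp.abs(x) <= width / 2, height, 0.0)


xs = jnp.array([-1.5, -1.0, 0.0, 0.9, 1.2])
g = boxcar_grad(xs)
```

With the defaults the window has area :math:`w \cdot h = 1`, the same total mass as the Dirac delta that the true Heaviside derivative would be.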
custom(bwd=lambda x: x, fwd=lambda x: heaviside(x))
Activation with a user-supplied surrogate gradient.
Used as the spiking nonlinearity inside every Spyx neuron. The default
fwd is the Heaviside step and the default bwd is the identity,
which together give the straight-through estimator (STE):
.. math:: y = \mathrm{Heaviside}(x), \qquad \frac{\partial y}{\partial x} \approx 1.
It is assumed that the input has already had its threshold subtracted by the calling neuron model.
:bwd: Function that computes the surrogate gradient :math:`g(x)` used during
the backward pass. Should return an array of the same shape as ``x``.
:fwd: Forward activation / spiking function. Defaults to
:func:`heaviside` centred at zero.
:return: A JIT-compiled activation function composed of the specified
forward and backward functions.
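A minimal sketch of how a factory like this can be built on `jax.custom_gradient`, assuming `bwd` returns the surrogate :math:`g(x)` that multiplies the incoming gradient. `custom_sketch` and the local `heaviside` are hypothetical stand-ins, not the spyx source; the library's internals may differ. In this sketch a constant-one surrogate yields the straight-through estimator.

```python
import jax
import jax.numpy as jnp


def heaviside(x):
    # Spike (1) where the threshold-subtracted input is positive, else 0.
    return (x > 0).astype(x.dtype)


def custom_sketch(bwd, fwd=heaviside):
    # Hypothetical factory: forward returns fwd(x); backward scales the
    # incoming cotangent g by the surrogate bwd(x).
    @jax.custom_gradient
    def f(x):
        def grad_fn(g):
            return (g * bwd(x),)
        return fwd(x), grad_fn

    return jax.jit(f)


# Straight-through estimator: surrogate gradient of 1 everywhere.
ste = custom_sketch(bwd=jnp.ones_like)

x = jnp.array([-0.5, 0.0, 0.3])
y = ste(x)
dy = jax.grad(lambda z: ste(z).sum())(x)
```

Swapping `bwd` for any of the gradient shapes documented on this page (boxcar, triangular, and so on) changes only the backward pass; the forward output stays binary.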
heaviside(x)
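`heaviside` carries no docstring here. The following is a minimal sketch under the assumption that it outputs 1 for strictly positive inputs and 0 otherwise (the value at exactly :math:`x = 0` is an assumption). It also illustrates why the factories above exist: differentiating the raw step gives a zero gradient everywhere. `heaviside_sketch` is an illustrative name.

```python
import jax
import jax.numpy as jnp


def heaviside_sketch(x):
    # Assumed convention: 1 where x > 0, 0 otherwise, preserving x's dtype.
    return jnp.where(x > 0, 1.0, 0.0).astype(x.dtype)


x = jnp.array([-1.0, 0.0, 0.5])
y = heaviside_sketch(x)
# Both branches of jnp.where are constants, so autodiff returns zeros:
# no learning signal reaches x without a surrogate.
dy = jax.grad(lambda z: heaviside_sketch(z).sum())(x)
```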
superspike(k=25)
This function implements the SuperSpike surrogate gradient activation function for a spiking neuron.
The forward pass is the Heaviside step; the backward pass uses the SuperSpike surrogate :math:`1/(1+k|U|)^2`, where :math:`U` is the input to the function and :math:`k` is a scaling factor. The surrogate takes values in :math:`(0, 1]` for all real inputs and peaks at 1 when :math:`U = 0`.
It is the derivative of the fast sigmoid :math:`U/(1+k|U|)`, a cheap approximation of the logistic sigmoid that is differentiable everywhere, which makes it suitable for gradient-based optimization of spiking neurons. The surrogate is adapted from:
F. Zenke, S. Ganguli (2018) SuperSpike: Supervised Learning in Multilayer Spiking Neural Networks. Neural Computation, pp. 1514-1541.
:k: Scaling factor that adjusts the steepness of the SuperSpike surrogate. Default is 25.
:return: JIT-compiled SuperSpike surrogate gradient function.
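The relationship between the SuperSpike surrogate and the fast sigmoid can be checked numerically. This sketch writes both functions by hand (`superspike_grad` and `fast_sigmoid` are illustrative names) and compares the surrogate with JAX autodiff of the fast sigmoid.

```python
import jax
import jax.numpy as jnp


def superspike_grad(u, k=25.0):
    # SuperSpike surrogate: 1 / (1 + k|U|)^2.
    return 1.0 / (1.0 + k * jnp.abs(u)) ** 2


def fast_sigmoid(u, k=25.0):
    # Smooth step-like curve whose derivative is the SuperSpike surrogate.
    return u / (1.0 + k * jnp.abs(u))


us = jnp.linspace(-1.0, 1.0, 8)
auto = jax.vmap(jax.grad(fast_sigmoid))(us)
```

With :math:`k = 25` the surrogate has already dropped to about :math:`1/676` at :math:`|U| = 1`, so the gradient is concentrated close to the threshold.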
tanh(k=1)
Hyperbolic-tangent surrogate gradient.
The forward pass is the Heaviside step; the backward pass uses
:math:`\operatorname{sech}^2(kx)`, the derivative of :math:`\tanh(kx)` divided by :math:`k`, which peaks at 1 when :math:`x = 0`:
.. math:: g(x) = \frac{4}{(e^{-kx} + e^{kx})^2}.
:k: Slope factor. Larger values concentrate the gradient in a narrower region around the threshold.
:return: JIT-compiled tanh surrogate gradient function.
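As a worked check of the formula, the sketch below evaluates :math:`4/(e^{-kx} + e^{kx})^2` by hand (`tanh_grad` is an illustrative name) and compares it with JAX autodiff of :math:`\tanh(kx)`, which carries an extra factor :math:`k` from the chain rule.

```python
import jax
import jax.numpy as jnp


def tanh_grad(x, k=1.0):
    # Surrogate from the formula above; equals sech^2(k*x), with peak 1 at x = 0.
    return 4.0 / (jnp.exp(-k * x) + jnp.exp(k * x)) ** 2


k = 3.0
xs = jnp.linspace(-2.0, 2.0, 9)
# d/dx tanh(k*x) = k * sech^2(k*x)
auto = jax.vmap(jax.grad(lambda x: jnp.tanh(k * x)))(xs)
```

Dropping the factor :math:`k` keeps the peak at 1 for every slope, so changing `k` alters only the width of the gradient, not its height.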
triangular(k=2)
Triangular surrogate gradient inspired by Esser et al. https://arxiv.org/abs/1603.08270
The forward pass is the Heaviside step; the backward pass uses:
.. math:: g(x) = \max(0, 1 - |kx|)
:k: Scale factor; the gradient is non-zero only for :math:`|x| < 1/k`.
:return: JIT-compiled triangular surrogate gradient function.
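A minimal sketch of the triangular surrogate from the formula above, evaluated at a few points with the default :math:`k = 2`. `triangular_grad` is an illustrative name, not the spyx internal.

```python
import jax.numpy as jnp


def triangular_grad(x, k=2.0):
    # max(0, 1 - |kx|): peak 1 at x = 0, falling linearly to 0 at |x| = 1/k.
    return jnp.maximum(0.0, 1.0 - jnp.abs(k * x))


xs = jnp.array([-0.75, -0.5, -0.25, 0.0, 0.25, 0.5])
g = triangular_grad(xs)
```

The triangle has base :math:`2/k` and height 1, so its area is :math:`1/k`; increasing `k` both narrows the window and shrinks the total gradient mass.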