
API Reference

This page provides documentation for the core Spyx modules.

Spyx Global Functions

integral_accuracy(time_axis=1)

Calculate the accuracy of a network's predictions based on the voltage traces. Used in combination with a Leaky-Integrate neuron model as the final layer.

:param time_axis: temporal axis of the output traces. Defaults to 1.
:return: JIT-compiled function that takes the SNN's output traces and the integer class labels, and returns the accuracy score and the predicted class indices.

Source code in spyx/fn.py
def integral_accuracy(time_axis=1):
    """Calculate the accuracy of a network's predictions based on the voltage traces. Used in combination with a Leaky-Integrate neuron model as the final layer.

    :param time_axis: temporal axis of the output traces.
    :return: function which takes SNN output traces and integer class labels, and returns the accuracy score and predictions.
    """
    def _integral_accuracy(traces, targets):
        preds = jnp.argmax(jnp.sum(traces, axis=time_axis), axis=-1)
        return jnp.mean(preds == targets), preds

    return jax.jit(_integral_accuracy)
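
The returned closure integrates each readout neuron's trace over time and takes the argmax. The sketch below re-implements that computation in plain jax.numpy on a toy batch shaped (batch, time, classes), so you can check the decoding rule by hand. The function name `integral_accuracy_sketch` and the toy values are illustrative only and are not part of Spyx.

```python
import jax
import jax.numpy as jnp


def integral_accuracy_sketch(traces, targets, time_axis=1):
    # Integrate each readout neuron's trace over time, then pick the largest integral.
    preds = jnp.argmax(jnp.sum(traces, axis=time_axis), axis=-1)
    return jnp.mean(preds == targets), preds


# Toy traces with shape (batch=2, time=3, classes=3).
traces = jnp.array([
    [[0.1, 0.0, 0.5], [0.2, 0.1, 0.4], [0.0, 0.3, 0.6]],  # integrals ~[0.3, 0.4, 1.5] -> class 2
    [[0.9, 0.1, 0.0], [0.8, 0.2, 0.1], [0.7, 0.0, 0.2]],  # integrals ~[2.4, 0.3, 0.3] -> class 0
])
targets = jnp.array([2, 1])

acc, preds = jax.jit(integral_accuracy_sketch)(traces, targets)
print(float(acc), preds.tolist())  # 0.5 [2, 0]
```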

integral_crossentropy(smoothing=0.3, time_axis=1)

Calculate the cross-entropy between the softmax of the time-integrated membrane potentials and the one-hot class labels. Label smoothing discourages the readout layer from silencing the neurons of the non-target classes.

:param smoothing: label smoothing rate. Defaults to 0.3.
:param time_axis: temporal axis of the data. Defaults to 1.
:return: cross-entropy loss function that takes SNN output traces and integer class labels.

Source code in spyx/fn.py
def integral_crossentropy(smoothing=0.3, time_axis=1):
    """Calculate the crossentropy between the integral of membrane potentials. Allows for label smoothing to discourage silencing the other neurons in the readout layer.

    :param smoothing: rate at which to smooth labels.
    :param time_axis: temporal axis of data
    :return: crossentropy loss function that takes SNN output traces and integer index labels.
    """

    def _integral_crossentropy(traces, targets):
        logits = jnp.sum(traces, axis=time_axis) # time axis.
        one_hot = jax.nn.one_hot(targets, logits.shape[-1])
        labels = optax.smooth_labels(one_hot, smoothing)
        return jnp.mean(optax.softmax_cross_entropy(logits, labels))

    return _integral_crossentropy
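
You can reproduce the loss by hand as a sanity check. This is a minimal sketch in plain JAX, not the Spyx code. The smoothing step is written out as `(1 - smoothing) * one_hot + smoothing / num_classes`, which is our understanding of what `optax.smooth_labels` computes. The helper name `smoothed_integral_xent` is hypothetical. When the readout is silent, all logits are equal, so the loss is log(num_classes) whatever the labels are.

```python
import jax
import jax.numpy as jnp


def smoothed_integral_xent(traces, targets, smoothing=0.3, time_axis=1):
    logits = jnp.sum(traces, axis=time_axis)             # integrate potentials over time
    num_classes = logits.shape[-1]
    one_hot = jax.nn.one_hot(targets, num_classes)
    # Label smoothing: (1 - smoothing) on the target plus smoothing / K on every class.
    labels = (1.0 - smoothing) * one_hot + smoothing / num_classes
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Softmax cross-entropy, averaged over the batch.
    return jnp.mean(-jnp.sum(labels * log_probs, axis=-1)), labels


traces = jnp.zeros((2, 5, 4))  # batch=2, time=5, classes=4, silent readout
loss, labels = smoothed_integral_xent(traces, jnp.array([0, 3]))
# labels[0] is about [0.775, 0.075, 0.075, 0.075]; with equal logits the loss is log(4).
print(float(loss))
```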

mse_spikerate(sparsity=0.25, smoothing=0.0, time_axis=1)

Calculate the mean squared error between the readout neurons' mean spike rates and their target rates. Label smoothing can be used to discourage silencing the other neurons in the readout layer.

:param sparsity: the fraction of time steps on which the target neuron should spike (0.25 means 25% of the time).
:param smoothing: [optional] label smoothing rate.
:param time_axis: temporal axis of the data. Defaults to 1.
:return: Mean-Squared-Error loss function on the spike rate that takes SNN output traces and integer class labels.

Source code in spyx/fn.py
def mse_spikerate(sparsity=0.25, smoothing=0.0, time_axis=1):
    """Calculate the mean squared error of the mean spike rate. Allows for label smoothing to discourage silencing the other neurons in the readout layer.

    :param sparsity: the fraction of time steps on which the target neuron should spike
    :param smoothing: [optional] rate at which to smooth labels.
    :param time_axis: temporal axis of data
    :return: Mean-Squared-Error loss function on the spike rate that takes SNN output traces and integer index labels.
    """
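    def _mse_spikerate(traces, targets):
        # Mean spike rate of each readout neuron over time, a value in [0, 1].
        rates = jnp.mean(traces, axis=time_axis)
        labels = optax.smooth_labels(jax.nn.one_hot(targets, rates.shape[-1]), smoothing)
        # Target rates are labels * sparsity; scaling them by the number of time
        # steps as well would compare a rate against a spike count.
        return jnp.mean(optax.squared_error(rates, labels * sparsity))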

    return jax.jit(_mse_spikerate)

silence_reg(min_spikes)

Per-neuron squared (L2) penalty for neurons that spike fewer than a target number of times.

:param min_spikes: neurons whose spike count, averaged over the batch, falls below this value incur a quadratic penalty.
:return: JIT compiled regularization function.

Source code in spyx/fn.py
def silence_reg(min_spikes):
    """L2-Norm per-neuron activation normalization for spiking less than a target number of times.

    :param min_spikes: neurons which spike below this value on average over the batch incur quadratic penalty.
    :return: JIT compiled regularization function.
    """
    def _loss(x):
        return (jnp.maximum(0, min_spikes-jnp.mean(x, axis=0)))**2

    def _flatten(x):
        return jnp.reshape(x, (x.shape[0], -1))

    def _call(spikes):
        flat_spikes = tree.tree_map(_flatten, spikes)
        loss_vectors = tree.tree_map(_loss, flat_spikes)
        return jnp.sum(jnp.concatenate(tree.tree_flatten(loss_vectors)[0]))

    return jax.jit(_call)
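
To illustrate the penalty, the sketch below mirrors the computation above using `jax.tree_util` directly. It assumes each leaf holds per-sample spike counts with the batch on the leading axis; counts accumulated by ActivityRegularization are presumably one such source. The helper name `silence_penalty` and the toy counts are hypothetical.

```python
import jax
import jax.numpy as jnp


def silence_penalty(spike_counts, min_spikes):
    """Sketch of the silence_reg computation on a pytree of per-sample spike counts."""
    def per_layer(x):
        flat = jnp.reshape(x, (x.shape[0], -1))  # (batch, neurons)
        deficit = jnp.maximum(0.0, min_spikes - jnp.mean(flat, axis=0))
        return deficit ** 2                       # one penalty per neuron

    penalties = jax.tree_util.tree_map(per_layer, spike_counts)
    return jnp.sum(jnp.concatenate(jax.tree_util.tree_leaves(penalties)))


counts = {
    "hidden": jnp.array([[0.0, 4.0, 2.0], [2.0, 4.0, 0.0]]),  # batch means [1, 4, 1]
    "readout": jnp.array([[3.0], [3.0]]),                       # batch mean [3]
}
# Neurons 0 and 2 of "hidden" are each one spike short of min_spikes: 1**2 + 1**2.
print(float(silence_penalty(counts, min_spikes=2.0)))  # 2.0
```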

sparsity_reg(max_spikes, norm=optax.huber_loss)

Layer-level activity regularization that discourages a layer's neurons from firing at a high average rate.

:param max_spikes: threshold on the average number of spikes per neuron in the layer, computed separately for each sample; a penalty is incurred when the average exceeds it.
:param norm: an Optax loss function applied to the excess. Default is Huber loss.
:return: JIT compiled regularization function.

Source code in spyx/fn.py
def sparsity_reg(max_spikes, norm=optax.huber_loss):
    """Layer activation normalization that seeks to discourage all neurons having a high firing rate.

    :param max_spikes: Threshold for which penalty is incurred if the average number of spikes in the layer exceeds it.
    :param norm: an Optax loss function. Default is Huber loss.
    :return: JIT compiled regularization function. 
    """
    def _loss(x):
        return norm(jnp.maximum(0, jnp.mean(x, axis=-1) - max_spikes)) # this may not work for convolution layers....

    def _flatten(x):
        return jnp.reshape(x, (x.shape[0], -1))

    def _call(spikes):
        flat_spikes = tree.tree_map(_flatten, spikes)
        loss_vectors = tree.tree_map(_loss, flat_spikes)
        return jnp.sum(jnp.concatenate(tree.tree_flatten(loss_vectors)[0]))

    return jax.jit(_call)
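
silence_reg averages each neuron over the batch. This penalty instead averages over the neurons of a layer, separately for each sample. The sketch below shows that on toy counts. It swaps `optax.huber_loss` for a hand-written Huber function with delta = 1, intended to match Optax's default, so the example depends only on JAX. The names `huber` and `sparsity_penalty` are hypothetical.

```python
import jax
import jax.numpy as jnp


def huber(x, delta=1.0):
    # Quadratic for |x| <= delta, linear beyond it.
    abs_x = jnp.abs(x)
    quadratic = jnp.minimum(abs_x, delta)
    return 0.5 * quadratic ** 2 + delta * (abs_x - quadratic)


def sparsity_penalty(spike_counts, max_spikes, norm=huber):
    def per_layer(x):
        flat = jnp.reshape(x, (x.shape[0], -1))  # (batch, neurons)
        excess = jnp.maximum(0.0, jnp.mean(flat, axis=-1) - max_spikes)
        return norm(excess)                       # one penalty per sample

    penalties = jax.tree_util.tree_map(per_layer, spike_counts)
    return jnp.sum(jnp.concatenate(jax.tree_util.tree_leaves(penalties)))


counts = {"hidden": jnp.array([[4.0, 4.0, 4.0], [0.0, 2.0, 1.0]])}  # per-sample means [4, 1]
print(float(sparsity_penalty(counts, max_spikes=2.0)))  # 1.5 = huber(2) + huber(0)
```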

Neural Networks

ALIF

Bases: RNNCore

Adaptive LIF Neuron based on the model used in LSNNs:

Bellec, G., Salaj, D., Subramoney, A., Legenstein, R. & Maass, W. Long short-term memory and learning-to-learn in networks of spiking neurons. 32nd Conference on Neural Information Processing Systems (2018).

Source code in spyx/nn.py
class ALIF(hk.RNNCore): 
    """
    Adaptive LIF Neuron based on the model used in LSNNs:

    Bellec, G., Salaj, D., Subramoney, A., Legenstein, R. & Maass, W. 
    Long short-term memory and learning-to-learn in networks of spiking neurons. 
    32nd Conference on Neural Information Processing Systems (2018).

    """


    def __init__(self, hidden_shape, beta=None, gamma=None,
                 threshold = 1,
                 activation = superspike(),
                 name="ALIF"):

        """
        :hidden_shape: Hidden layer shape.
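        :beta: Membrane decay factor, clipped to [0, 1]. None learns one value per neuron; a float initializes a single shared parameter.
        :gamma: Threshold adaptation decay factor, clipped to [0, 1]. None learns one value per neuron; a float initializes a single shared parameter.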
        :threshold: Neuron firing threshold.
        :activation: spyx.axn.Axon object determining forward function and surrogate gradient function.
        """

        super().__init__(name=name)
        self.hidden_shape = hidden_shape
        self.beta = beta
        self.gamma = gamma
        self.threshold = threshold
        self.spike = activation

    def __call__(self, x, VT):
        """
        :x: Tensor from previous layer.
        :VT: Neuron state vector.
        """

        V, T = jnp.split(VT, 2, -1)

        gamma = self.gamma
        beta = self.beta
        # threshold adaptation
        if not gamma:
            gamma = hk.get_parameter("gamma", self.hidden_shape, 
                                 init=hk.initializers.TruncatedNormal(0.25, 0.5))
            gamma = jnp.clip(gamma, 0, 1)
        else:
            gamma = hk.get_parameter("gamma", [],
                                init=hk.initializers.Constant(gamma))
            gamma = jnp.clip(gamma, 0, 1)

        if not beta:
            beta = hk.get_parameter("beta", self.hidden_shape, 
                                init=hk.initializers.TruncatedNormal(0.25, 0.5))
            beta = jnp.clip(beta, 0, 1)
        else:
            beta = hk.get_parameter("beta", [],
                                init=hk.initializers.Constant(beta))
            beta = jnp.clip(beta, 0, 1)

        # calculate whether spike is generated, and update membrane potential
        thresh = self.threshold + T
        spikes = self.spike(V - thresh) # T is the dynamic threshold adaptation
        V = beta*V + x - spikes*thresh
        T = gamma*T + (1-gamma)*spikes

        VT = jnp.concatenate([V,T], axis=-1)
        return spikes, VT

    def initial_state(self, batch_size):
        # V and T are concatenated along the last axis (see __call__), so only that axis is doubled.
        return jnp.zeros((batch_size,) + tuple(self.hidden_shape[:-1]) + (2 * self.hidden_shape[-1],))
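
You can trace the update rule by hand. The sketch below applies the same equations as `__call__` to a single neuron under constant input using `jax.lax.scan`. It uses fixed beta = gamma = 0.5 instead of learned parameters. In place of `superspike()` it uses a plain Heaviside step with a strict inequality (`V - thresh > 0`), which is an assumption about how exact ties are handled. `alif_step` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def alif_step(carry, x, beta=0.5, gamma=0.5, threshold=1.0):
    V, T = carry
    thresh = threshold + T                     # adaptive threshold
    spikes = (V - thresh > 0).astype(V.dtype)  # Heaviside forward pass only (no surrogate)
    V = beta * V + x - spikes * thresh         # leak, integrate, reset by subtraction
    T = gamma * T + (1 - gamma) * spikes       # adaptation rises after each spike
    return (V, T), spikes


inputs = jnp.ones(8)                           # constant drive for 8 steps
init = (jnp.zeros(()), jnp.zeros(()))
(V_final, T_final), spikes = jax.lax.scan(alif_step, init, inputs)
print(spikes.tolist())  # [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0]
```

After each spike, T moves toward 1 and raises the effective threshold. The gap between spikes therefore grows from two steps to three.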

__call__(x, VT)

:x: Tensor from previous layer.
:VT: Neuron state: membrane potential V and threshold adaptation T, concatenated along the last axis.

Source code in spyx/nn.py
def __call__(self, x, VT):
    """
    :x: Tensor from previous layer.
    :VT: Neuron state vector.
    """

    V, T = jnp.split(VT, 2, -1)

    gamma = self.gamma
    beta = self.beta
    # threshold adaptation
    if not gamma:
        gamma = hk.get_parameter("gamma", self.hidden_shape, 
                             init=hk.initializers.TruncatedNormal(0.25, 0.5))
        gamma = jnp.clip(gamma, 0, 1)
    else:
        gamma = hk.get_parameter("gamma", [],
                            init=hk.initializers.Constant(gamma))
        gamma = jnp.clip(gamma, 0, 1)

    if not beta:
        beta = hk.get_parameter("beta", self.hidden_shape, 
                            init=hk.initializers.TruncatedNormal(0.25, 0.5))
        beta = jnp.clip(beta, 0, 1)
    else:
        beta = hk.get_parameter("beta", [],
                            init=hk.initializers.Constant(beta))
        beta = jnp.clip(beta, 0, 1)

    # calculate whether spike is generated, and update membrane potential
    thresh = self.threshold + T
    spikes = self.spike(V - thresh) # T is the dynamic threshold adaptation
    V = beta*V + x - spikes*thresh
    T = gamma*T + (1-gamma)*spikes

    VT = jnp.concatenate([V,T], axis=-1)
    return spikes, VT

__init__(hidden_shape, beta=None, gamma=None, threshold=1, activation=superspike(), name='ALIF')

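:hidden_shape: Hidden layer shape.
:beta: Membrane decay factor, clipped to [0, 1]. None learns one value per neuron; a float initializes a single shared parameter.
:gamma: Threshold adaptation decay factor, clipped to [0, 1]. None learns one value per neuron; a float initializes a single shared parameter.
:threshold: Neuron firing threshold.
:activation: spyx.axn.Axon object determining forward function and surrogate gradient function.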

ActivityRegularization

Bases: Module

Add state to the SNN that accumulates the spikes emitted by each neuron, for each sample in the batch, in the Haiku state entry spike_count.

Because the module uses Haiku state, the network must be wrapped with hk.transform_with_state. Its init function returns an initial (all-zero) state alongside the parameters. This blank initial state can be reused for every batch and is passed as the second argument to the SNN's apply function.

Source code in spyx/nn.py
class ActivityRegularization(hk.Module):
    """
    Add state to the SNN that accumulates the spikes emitted by each neuron, for each sample in the batch.

    Adding this to a network requires using the Haiku transform_with_state transform, which will also return an initial regularization state vector.
    This blank initial vector can be reused and is provided as the second arg to the SNN's apply function. 
    """

    def __init__(self, name="ActReg"):
        super().__init__(name=name)

    def __call__(self, spikes):
        spike_count = hk.get_state("spike_count", spikes.shape, init=jnp.zeros, dtype=spikes.dtype)
        hk.set_state("spike_count", spike_count + spikes) 
        return spikes

IF

Bases: RNNCore

Integrate and Fire neuron model. Although less expressive than other neuron models, IF neurons are very easy to implement in hardware.

Source code in spyx/nn.py
class IF(hk.RNNCore): 
    """
    Integrate and Fire neuron model. Although less expressive than other neuron models, IF neurons are very easy to implement in hardware.

    """

    def __init__(self, hidden_shape,
                 threshold = 1,
                 activation = superspike(),
                 name="IF"):
        """
        :hidden_shape: Shape of the layer's neuron state, matching the preceding layer's output shape.
        :threshold: threshold for reset. Defaults to 1.
        :activation: spyx.axn.Axon object; the default is superspike().
        """
        super().__init__(name=name)
        self.hidden_shape = hidden_shape
        self.threshold = threshold
        self.spike = activation

    def __call__(self, x, V):
        """
        :x: Vector coming from previous layer.
        :V: Neuron state tensor.
        """
        # calculate whether spike is generated, and update membrane potential
        spikes = self.spike(V-self.threshold)
        V = V + x - spikes*self.threshold

        return spikes, V

    def initial_state(self, batch_size): 
        return jnp.zeros((batch_size,) + self.hidden_shape)
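
IF neurons do not leak and they reset by subtraction, so input charge is conserved. After T steps, the total input equals threshold × (number of spikes) + residual potential. The sketch below checks this for a single neuron driven by a constant 0.25 per step, using the same update as `__call__`. It uses a strict-inequality Heaviside step in place of `superspike()`, which is an assumption. `if_step` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def if_step(V, x, threshold=1.0):
    spikes = (V - threshold > 0).astype(V.dtype)  # Heaviside forward pass only
    V = V + x - spikes * threshold                 # no leak; reset by subtraction
    return V, spikes


T_steps, drive = 40, 0.25
V_final, spikes = jax.lax.scan(if_step, jnp.zeros(()),
                               jnp.full(T_steps, drive, dtype=jnp.float32))
n_spikes = int(spikes.sum())
print(n_spikes, float(V_final))  # 9 1.0
# Charge is conserved: total input = threshold * n_spikes + residual potential.
print(T_steps * drive == 1.0 * n_spikes + float(V_final))  # True
```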

__call__(x, V)

:x: Vector coming from previous layer.
:V: Neuron state tensor.

__init__(hidden_shape, threshold=1, activation=superspike(), name='IF')

:hidden_shape: Shape of the layer's neuron state, matching the preceding layer's output shape.
:threshold: threshold for reset. Defaults to 1.
:activation: spyx.axn.Axon object; the default is superspike().

LI

Bases: RNNCore

Leaky-Integrate (Non-spiking) neuron model.

Source code in spyx/nn.py
class LI(hk.RNNCore):
    """
    Leaky-Integrate (Non-spiking) neuron model.


    """

    def __init__(self, layer_shape, beta=None, name="LI"):
        """

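        :layer_shape: Shape of the layer, i.e. the output shape of the preceding linear layer.
        :beta: Decay rate on membrane potential (voltage), clipped to [0, 1]. A float initializes a single parameter shared across the layer; None learns one value per neuron, initialized to 0.8.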
        """
        super().__init__(name=name)
        self.layer_shape = layer_shape
        self.beta = beta

    def __call__(self, x, Vin):
        """
        :x: Input tensor from previous layer.
        :Vin: Neuron state tensor. 
        """
        beta = self.beta
        if not beta:
            beta = hk.get_parameter("beta", self.layer_shape,
                                init=hk.initializers.Constant(0.8))
            beta = jnp.clip(beta, 0, 1)
        else:
            beta = hk.get_parameter("beta", [],
                                init=hk.initializers.Constant(beta))
            beta = jnp.clip(beta, 0, 1)

        # leaky integration of the membrane potential (LI neurons never spike)
        Vout = beta*Vin + x
        return Vout, Vout

    def initial_state(self, batch_size):
        return jnp.zeros((batch_size,) + self.layer_shape)
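
In a readout layer, the LI trace is what integral_accuracy and integral_crossentropy sum over time. Below is a minimal sketch of the update under constant input. It fixes beta at 0.5 so the values stay exact; the module's learned per-neuron beta would instead be initialized to 0.8. `li_step` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def li_step(V, x, beta=0.5):
    V = beta * V + x  # leaky integration; LI neurons never spike
    return V, V       # the membrane potential is both the state and the output


_, trace = jax.lax.scan(li_step, jnp.zeros(()), jnp.ones(4))
print(trace.tolist())  # [1.0, 1.5, 1.75, 1.875]
```

With constant input x, the trace approaches x / (1 - beta), which is 2 here.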

__call__(x, Vin)

:x: Input tensor from previous layer.
:Vin: Neuron state tensor.

__init__(layer_shape, beta=None, name='LI')

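:layer_shape: Shape of the layer, i.e. the output shape of the preceding linear layer.
:beta: Decay rate on membrane potential (voltage), clipped to [0, 1]. A float initializes a single parameter shared across the layer; None learns one value per neuron, initialized to 0.8.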

LIF

Bases: RNNCore

Leaky Integrate and Fire neuron model inspired by the implementation in snnTorch:

https://snntorch.readthedocs.io/en/latest/snn.neurons_leaky.html

Source code in spyx/nn.py
class LIF(hk.RNNCore):
    """
    Leaky Integrate and Fire neuron model inspired by the implementation in
    snnTorch:

    https://snntorch.readthedocs.io/en/latest/snn.neurons_leaky.html

    """

    def __init__(self, 
                 hidden_shape: tuple, 
                 beta=None,
                 threshold = 1.,
                 activation = superspike(),
                 name="LIF"):

        """

        :hidden_shape: Shape of the layer's neuron state, matching the preceding layer's output shape.
        :beta: decay rate, clipped to [0, 1]. Set to a float for a single decay parameter shared across the layer; otherwise one value per neuron
                is learned, initialized from a truncated normal distribution with mean 0.5 and stddev 0.25.
        :threshold: threshold for reset. Defaults to 1.
        :activation: spyx.axn.Axon object; the default is superspike().
        """
        super().__init__(name=name)
        self.hidden_shape = hidden_shape
        self.beta = beta
        self.threshold = threshold
        self.spike = activation

    def __call__(self, x, V):
        """
        :x: input vector coming from previous layer.
        :V: neuron state tensor.

        """
        beta = self.beta # this line can probably be deleted, and the check changed to self.beta
        if not beta:
            beta = hk.get_parameter("beta", self.hidden_shape,
                                init=hk.initializers.TruncatedNormal(0.25, 0.5))
            beta = jnp.clip(beta, 0, 1)
        else:
            beta = hk.get_parameter("beta", [],
                                init=hk.initializers.Constant(beta))
            beta = jnp.clip(beta, 0, 1)

        # calculate whether spike is generated, and update membrane potential
        spikes = self.spike(V-self.threshold)
        V = beta*V + x - spikes * self.threshold

        return spikes, V

    def initial_state(self, batch_size): 
        return jnp.zeros((batch_size,) + self.hidden_shape)
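
In a Haiku model, an RNN core like this one is usually unrolled over time with hk.dynamic_unroll. To show the dynamics without Haiku, the sketch below applies the same update as `LIF.__call__` with `jax.lax.scan`. It uses a fixed beta of 0.5 and a strict-inequality Heaviside step in place of `superspike()`; both are assumptions chosen for exact, readable numbers. `lif_step` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def lif_step(V, x, beta=0.5, threshold=1.0):
    spikes = (V - threshold > 0).astype(V.dtype)  # Heaviside forward pass only
    V = beta * V + x - spikes * threshold          # leak, integrate, reset by subtraction
    return V, spikes


V_final, spikes = jax.lax.scan(lif_step, jnp.zeros(()),
                               jnp.full(6, 0.75, dtype=jnp.float32))
print(spikes.tolist(), float(V_final))  # [0.0, 0.0, 1.0, 0.0, 0.0, 1.0] 0.3515625
```

The leak pulls the potential toward zero between inputs. A drive of 0.75 per step therefore produces a spike only every third step.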

__call__(x, V)

:x: input vector coming from previous layer.
:V: neuron state tensor.

__init__(hidden_shape, beta=None, threshold=1.0, activation=superspike(), name='LIF')

:hidden_shape: Shape of the layer's neuron state, matching the preceding layer's output shape.
:beta: decay rate, clipped to [0, 1]. Set to a float for a single decay parameter shared across the layer; otherwise one value per neuron is learned, initialized from a truncated normal distribution with mean 0.5 and stddev 0.25.
:threshold: threshold for reset. Defaults to 1.
:activation: spyx.axn.Axon object; the default is superspike().

RIF

Bases: RNNCore

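Recurrent Integrate and Fire neuron model: an IF neuron whose spikes are also fed back to the layer through a learned recurrent weight matrix.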

Source code in spyx/nn.py
class RIF(hk.RNNCore): 
    """
    Recurrent Integrate and Fire neuron model.

    """

    def __init__(self, hidden_shape, 
                 threshold = 1,
                 activation = superspike(),
                 name="RIF"):
        """
        :hidden_shape: Shape of the layer's neuron state, matching the preceding layer's output shape.
        :threshold: threshold for reset. Defaults to 1.
        :activation: spyx.axn.Axon object; the default is superspike().
        """
        super().__init__(name=name)
        self.hidden_shape = hidden_shape
        self.threshold = threshold
        self.spike = activation

    def __call__(self, x, V):
        """
        :x: Vector coming from previous layer.
        :V: Neuron state tensor.
        """

        recurrent = hk.get_parameter("w", self.hidden_shape*2, init=hk.initializers.TruncatedNormal())

        # calculate whether spike is generated, and update membrane potential
        spikes = self.spike(V-self.threshold)
        feedback = spikes@recurrent
        V = V + x + feedback - spikes*self.threshold

        return spikes, V

    def initial_state(self, batch_size): 
        return jnp.zeros((batch_size,) + self.hidden_shape)
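
The recurrent weight is created with shape `self.hidden_shape*2`. For a tuple that means repetition, not multiplication: `(3,) * 2` is `(3, 3)`, a square neuron-to-neuron matrix. The sketch below runs one RIF step by hand with hypothetical weights and potentials. It uses a strict-inequality Heaviside step in place of `superspike()`. For a multi-dimensional `hidden_shape` such as `(H, W)`, the same expression yields a rank-4 weight `(H, W, H, W)`. `spikes @ recurrent` then no longer expresses neuron-to-neuron feedback, so flattening the layer first seems the safer choice.

```python
import jax.numpy as jnp

hidden_shape = (3,)
print(hidden_shape * 2)  # (3, 3): tuple repetition gives a square recurrent weight

# Hypothetical values for one RIF step (threshold = 1, no external input).
W = jnp.array([[0.0, 0.5, 0.25],
               [0.0, 0.0, 0.0],
               [0.0, 0.0, 0.0]])
V = jnp.array([1.5, 0.25, 0.0])
x = jnp.zeros(3)
threshold = 1.0

spikes = (V - threshold > 0).astype(V.dtype)  # [1, 0, 0]: only neuron 0 fires
feedback = spikes @ W                          # row 0 of W is routed to every neuron
V = V + x + feedback - spikes * threshold
print(V.tolist())  # [0.5, 0.75, 0.25]
```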

__call__(x, V)

:x: Vector coming from previous layer.
:V: Neuron state tensor.

__init__(hidden_shape, threshold=1, activation=superspike(), name='RIF')

:hidden_shape: Shape of the layer's neuron state, matching the preceding layer's output shape.
:threshold: threshold for reset. Defaults to 1.
:activation: spyx.axn.Axon object; the default is superspike().

RLIF

Bases: RNNCore

Recurrent LIF Neuron adapted from snnTorch.

https://snntorch.readthedocs.io/en/latest/snn.neurons_rleaky.html

Source code in spyx/nn.py
class RLIF(hk.RNNCore): 
    """
    Recurrent LIF Neuron adapted from snnTorch. 

    https://snntorch.readthedocs.io/en/latest/snn.neurons_rleaky.html
    """

    def __init__(self, hidden_shape, beta=None,
                 threshold = 1,
                 activation = superspike(),
                 name="RLIF"):

        """
        Initialization function.

        :hidden_shape: The tuple describing the layer's shape. Can accommodate varying shapes to directly stack on convolution layers without flattening.
        :beta: Decay constant, clipped to [0,1]. If set to a float, it initializes a single parameter shared across the layer; otherwise one value per neuron is learned.
        :threshold: Firing threshold for the layer. Does not currently support learning/trainable thresholds.
        :activation: A spyx.axn.Axon object specifying the forward and surrogate gradient functions. The default is superspike().
        """

        super().__init__(name=name)
        self.hidden_shape = hidden_shape
        self.beta = beta
        self.threshold = threshold
        self.spike = activation

    def __call__(self, x, V):
        """
        :x: The input data/latent vector from another layer.
        :V: The state tensor.
        """

        recurrent = hk.get_parameter("w", self.hidden_shape*2, init=hk.initializers.TruncatedNormal())

        beta = self.beta
        if not beta:
            beta = hk.get_parameter("beta", self.hidden_shape, 
                                init=hk.initializers.TruncatedNormal(0.25, 0.5))
            beta = jnp.clip(beta, 0, 1)
        else:
            beta = hk.get_parameter("beta", [], 
                                init=hk.initializers.Constant(beta))
            beta = jnp.clip(beta, 0, 1)

        spikes = self.spike(V-self.threshold)
        feedback = spikes@recurrent # investigate and fix this...
        V = beta*V + x + feedback - spikes*self.threshold

        return spikes, V

    def initial_state(self, batch_size):
        return jnp.zeros((batch_size,) + self.hidden_shape)

__call__(x, V)

:x: The input data/latent vector from another layer.
:V: The state tensor.

__init__(hidden_shape, beta=None, threshold=1, activation=superspike(), name='RLIF')

Initialization function.

:hidden_shape: The tuple describing the layer's shape. Can accommodate varying shapes to directly stack on convolution layers without flattening.
:beta: Decay constant, clipped to [0,1]. If set to a float, it initializes a single parameter shared across the layer; otherwise one value per neuron is learned.
:threshold: Firing threshold for the layer. Does not currently support learning/trainable thresholds.
:activation: A spyx.axn.Axon object specifying the forward and surrogate gradient functions. The default is superspike().

SumPool

Bases: Module

Sum pool.

Returns the total number of spikes emitted in a receptive field.

Source code in spyx/nn.py
class SumPool(hk.Module):
  """Sum pool.

  Returns the total number of spikes emitted in a receptive field.
  """

  def __init__(
      self,
      window_shape: Union[int, Sequence[int]],
      strides: Union[int, Sequence[int]],
      padding: str,
      channel_axis: Optional[int] = -1,
      name: Optional[str] = None,
  ):
    """Sum pool.

    Args:
      window_shape: Shape of the pooling window, same rank as value.
      strides: Strides of the pooling window, same rank as value.
      padding: Padding algorithm. Either ``VALID`` or ``SAME``.
      channel_axis: Axis of the spatial channels for which pooling is skipped.
      name: String name for the module.
    """
    super().__init__(name=name)
    self.window_shape = window_shape
    self.strides = strides
    self.padding = padding
    self.channel_axis = channel_axis

  def __call__(self, value: jax.Array) -> jax.Array:
    return sum_pool(value, self.window_shape, self.strides,
                    self.padding, self.channel_axis)

__init__(window_shape, strides, padding, channel_axis=-1, name=None)

Sum pool.

Parameters:

Name          Type                       Description                                                 Default
window_shape  Union[int, Sequence[int]]  Shape of the pooling window, same rank as value.            required
strides       Union[int, Sequence[int]]  Strides of the pooling window, same rank as value.          required
padding       str                        Padding algorithm. Either VALID or SAME.                    required
channel_axis  Optional[int]              Axis of the spatial channels for which pooling is skipped.  -1
name          Optional[str]              String name for the module.                                 None

PopulationCode(num_classes)

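Add population coding to the preceding neuron layer. The size of the preceding layer's last axis must be a multiple of num_classes. The neurons are divided into num_classes equal groups, and each group's activity is summed into one class score. Use this for rate-coded SNNs that have too few time steps to get a good spike count from a single neuron per class.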

Source code in spyx/nn.py
def PopulationCode(num_classes):
    """
    Add population coding to the preceding neuron layer. Preceding layer's output shape must be a multiple of
    the number of classes. Use this for rate coded SNNs where the time steps are too few to get a good spike count.
    """
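    def _pop_code(x):
        # Split the last axis into num_classes contiguous, equally sized groups and
        # sum within each group, giving one score per class.
        return jnp.sum(jnp.reshape(x, x.shape[:-1] + (num_classes, -1)), axis=-1)
    return jax.jit(_pop_code)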
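
Here is a sketch of population decoding as described above. It assumes each class owns a contiguous, equally sized block of neurons along the last axis. That layout is an assumption; an interleaved assignment would need a different reshape. `population_decode` is a hypothetical helper. `num_classes` is marked static because it determines a reshape.

```python
import jax
import jax.numpy as jnp


def population_decode(x, num_classes):
    # Assumed layout along the last axis: [class 0 group | class 1 group | ...].
    groups = jnp.reshape(x, x.shape[:-1] + (num_classes, -1))
    return jnp.sum(groups, axis=-1)


spikes = jnp.array([[1.0, 1.0, 0.0, 0.0, 1.0, 0.0],   # 6 neurons = 3 classes x 2
                    [0.0, 0.0, 1.0, 1.0, 0.0, 1.0]])
decoded = jax.jit(population_decode, static_argnums=1)(spikes, 3)
print(decoded.tolist())  # [[2.0, 0.0, 1.0], [0.0, 2.0, 1.0]]
```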

sum_pool(value, window_shape, strides, padding, channel_axis=-1)

Sum pool.

Parameters:

Name          Type                       Description                                                 Default
value         Array                      Value to pool.                                              required
window_shape  Union[int, Sequence[int]]  Shape of the pooling window, same rank as value.            required
strides       Union[int, Sequence[int]]  Strides of the pooling window, same rank as value.          required
padding       str                        Padding algorithm. Either VALID or SAME.                    required
channel_axis  Optional[int]              Axis of the spatial channels for which pooling is skipped.  -1

Returns:

Type   Description
Array  Pooled result. Same rank as value.

Source code in spyx/nn.py
def sum_pool(
    value: jax.Array,
    window_shape: Union[int, Sequence[int]],
    strides: Union[int, Sequence[int]],
    padding: str,
    channel_axis: Optional[int] = -1,
) -> jax.Array:
  """Sum pool.

  Args:
    value: Value to pool.
    window_shape: Shape of the pooling window, same rank as value.
    strides: Strides of the pooling window, same rank as value.
    padding: Padding algorithm. Either ``VALID`` or ``SAME``.
    channel_axis: Axis of the spatial channels for which pooling is skipped.

  Returns:
    Pooled result. Same rank as value.
  """
  if padding not in ("SAME", "VALID"):
    raise ValueError(f"Invalid padding '{padding}', must be 'SAME' or 'VALID'.")

  _warn_if_unsafe(window_shape, strides)
  window_shape = _infer_shape(value, window_shape, channel_axis)
  strides = _infer_shape(value, strides, channel_axis)

  return jax.lax.reduce_window(value, 0., jax.lax.add, window_shape, strides,
                           padding)
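
A minimal, self-contained sketch of what sum_pool computes, calling jax.lax.reduce_window directly. The private helpers _warn_if_unsafe and _infer_shape are skipped here, so the per-axis window and stride tuples are written out by hand; the NHWC layout with a batch of 1 and a single channel is an assumption for illustration.

```python
import jax
import jax.numpy as jnp

# NHWC input: batch 1, a 4x4 image, 1 channel.
value = jnp.arange(16.0).reshape(1, 4, 4, 1)

# 2x2 windows with stride 2 over the spatial axes; the batch and channel axes
# use a window and stride of 1, so they are not pooled.
window = (1, 2, 2, 1)
strides = (1, 2, 2, 1)

pooled = jax.lax.reduce_window(value, 0.0, jax.lax.add, window, strides, "VALID")
print(pooled[0, :, :, 0])
# [[10. 18.]
#  [42. 50.]]
```

Each output entry is the sum of one non-overlapping 2x2 block, for example 0 + 1 + 4 + 5 = 10 for the top-left block.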

Activation Functions

arctan(k=2)

This function implements the arctangent surrogate gradient activation function for a spiking neuron.

The arctangent function maps inputs in (-∞, ∞) to values in (-π/2, π/2). It is often used with spiking neurons because it provides a smooth approximation of the step function that is differentiable everywhere, which gradient-based optimization requires. The forward pass remains the Heaviside step; only the backward pass uses the arctangent-derived gradient 1 / ((1 + (kπx)^2) π).

:k: A scaling factor that adjusts the steepness of the arctangent function. Default is 2.
:return: JIT-compiled arctangent-derived surrogate gradient function.

Source code in spyx/axn.py
def arctan(k=2):
    """
    This function implements the Arctangent surrogate gradient activation function for a spiking neuron.

    The Arctangent function returns a value between -pi/2 and pi/2 for inputs in the range of -Infinity to Infinity.
    It is often used in the context of spiking neurons because it provides a smooth approximation to the step function 
    that is differentiable everywhere, which is a requirement for gradient-based optimization methods.

    :k: A scaling factor that can be used to adjust the steepness of the 
                      Arctangent function. Default is 2.
    :return: JIT compiled arctangent-derived surrogate gradient function.
    """
    k_pi = k*jnp.pi

    def grad_arctan(x):
        k_pi_x = k_pi * x
        return 1 / ((1+k_pi_x**2) * jnp.pi)

    return custom(grad_arctan, heaviside)
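
A quick way to see what the backward pass receives is to evaluate the same expression as grad_arctan above on a few inputs. The sketch is self-contained (it does not import Spyx), and arctan_surrogate is a hypothetical name used only for illustration.

```python
import jax.numpy as jnp


def arctan_surrogate(x, k=2.0):
    """Same expression as grad_arctan: 1 / ((1 + (k*pi*x)^2) * pi)."""
    k_pi_x = k * jnp.pi * x
    return 1.0 / ((1.0 + k_pi_x**2) * jnp.pi)


x = jnp.array([-1.0, -0.1, 0.0, 0.1, 1.0])
print(arctan_surrogate(x))        # peaks at 1/pi for x = 0, symmetric around 0
print(arctan_surrogate(x, k=10))  # larger k gives a narrower peak
```

The peak value 1/π at x = 0 follows directly from the formula; increasing k narrows the peak, so inputs far from threshold receive a smaller gradient.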

boxcar(width=2, height=0.5)

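Boxcar activation. The forward pass is the Heaviside step. The backward pass uses a constant gradient, set by height, within a window centered on 0 whose total width is set by width; outside that window the gradient is zero.

:width: Total width of the window of non-zero gradient flow, centered on 0.
:height: Value of the gradient within the specified window.
:return: JIT-compiled boxcar surrogate gradient function.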

Source code in spyx/axn.py
def boxcar(width=2, height=0.5):
    """Boxcar activation.

    :width: Total width of non-zero gradient flow, centered on 0.
    :height: Value for gradient within the specified window.
    :return: JIT compiled boxcar surrogate gradient function.
    """
    k = width / 2
    h = height

    def grad_boxcar(x):
        return h * heaviside(-(jnp.abs(x) - k))

    return custom(grad_boxcar, heaviside)
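
The shape of the boxcar gradient can be checked with a stand-alone version of grad_boxcar. This sketch is hypothetical: it uses a strict comparison, so inputs exactly at |x| = width/2 count as outside the window, whereas the source's value at that boundary depends on how heaviside treats 0.

```python
import jax.numpy as jnp


def boxcar_surrogate(x, width=2.0, height=0.5):
    """Hypothetical stand-alone version of grad_boxcar.

    Returns height where |x| < width / 2 and 0 elsewhere.
    """
    half_width = width / 2
    return jnp.where(jnp.abs(x) < half_width, height, 0.0)


x = jnp.array([-1.5, -0.5, 0.0, 0.5, 1.5])
print(boxcar_surrogate(x))
```

With the defaults, the three inputs inside (-1, 1) receive a gradient of 0.5 and the two outside receive 0.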

custom(bwd=lambda x: x, fwd=lambda x: heaviside(x))

This function serves as the activation function for SNNs, allowing custom definitions of the surrogate gradient used in the backward pass as well as substitution of the Heaviside function with relaxations such as sigmoids.

It is assumed that the threshold has already been subtracted from the input to this layer within the neuron model dynamics.

By default, the forward activation is the Heaviside function and bwd is the identity (lambda x: x). Because the backward pass computes g * bwd(x) for an incoming gradient g, this default scales the gradient by the input x; a strict straight-through estimator, which passes g through unchanged, corresponds to a bwd that returns ones.

:bwd: Function that calculates the gradient to be used in the backward pass.
:fwd: Forward activation/spiking function. Default is the Heaviside function centered at 0.
:return: A JIT-compiled activation function composed of the specified forward and backward functions.

Source code in spyx/axn.py
def custom(bwd=lambda x: x, 
           fwd=lambda x: heaviside(x)): # this is probably redundant and could just be fwd=heaviside
    """
    This function serves as the activation function for the SNNs, allowing for custom definitions of both surrogate gradients for backwards
    passes as well as substitution of the Heaviside function for relaxations such as sigmoids. 

    It is assumed that the input to this layer has already had its threshold subtracted within the neuron model dynamics.

    The default behavior is a Heaviside forward activation with a straight through estimator surrogate gradient.

    :bwd: Function that calculates the gradient to be used in the backwards pass.
    :fwd: Forward activation/spiking function. Default is the heaviside function centered at 0.
    :return: A JIT compiled activation function comprised of the specified forward and backward functions.
    """

    @jax.custom_gradient
    def f(x):
        return fwd(x), lambda g: g * bwd(x)

    return jax.jit(f)
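
The same forward/backward split can be sketched with jax.custom_vjp, which is used here in place of the jax.custom_gradient decorator in the listing. make_spike_fn is a hypothetical analogue of custom, and the heaviside below (1 where x > 0, else 0) is an assumed definition, not necessarily Spyx's.

```python
import jax
import jax.numpy as jnp


def heaviside(x):
    # Assumed forward spike function: 1 where x > 0, else 0.
    return jnp.where(x > 0, 1.0, 0.0).astype(x.dtype)


def make_spike_fn(bwd, fwd=heaviside):
    """Hypothetical analogue of spyx.axn.custom built on jax.custom_vjp."""

    @jax.custom_vjp
    def spike(x):
        return fwd(x)

    def spike_fwd(x):
        return fwd(x), x  # keep the input as the residual for the backward pass

    def spike_bwd(x, g):
        return (g * bwd(x),)  # surrogate gradient replaces d(fwd)/dx

    spike.defvjp(spike_fwd, spike_bwd)
    return jax.jit(spike)


# A straight-through estimator passes the incoming gradient unchanged.
ste = make_spike_fn(bwd=lambda v: jnp.ones_like(v))

x = jnp.array([-0.5, 0.2, 1.0])
print(ste(x))                               # forward: [0. 1. 1.]
print(jax.grad(lambda v: ste(v).sum())(x))  # backward: [1. 1. 1.]
```

Passing bwd=lambda v: v, the default in the listing above, instead yields gradients equal to x itself.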

superspike(k=25)

This function implements the SuperSpike surrogate gradient activation function for a spiking neuron.

The SuperSpike surrogate gradient is defined as 1/(1+k|x|)^2, where x is the input to the function and k is a scaling factor. For inputs in (-∞, ∞) it takes values in (0, 1], peaking at 1 when x = 0. The forward pass remains the Heaviside step.

This expression is the derivative of the fast sigmoid x/(1+k|x|), a smooth approximation of the step function that is differentiable everywhere, which gradient-based optimization requires. It is adapted from:

F. Zenke, S. Ganguli (2018) SuperSpike: Supervised Learning in Multilayer Spiking Neural Networks. Neural Computation, pp. 1514-1541.

:k: A scaling factor that adjusts the steepness of the SuperSpike function. Default is 25.
:return: JIT-compiled SuperSpike surrogate gradient function.

Source code in spyx/axn.py
def superspike(k=25):
    """
    This function implements the SuperSpike surrogate gradient activation function for a spiking neuron.

    The SuperSpike function is defined as 1/(1+k|U|)^2, where U is the input to the function and k is a scaling factor.
    It returns a value between 0 and 1 for inputs in the range of -Infinity to Infinity.

    It is often used in the context of spiking neurons because it provides a smooth approximation to the step function 
    that is differentiable everywhere, which is a requirement for gradient-based optimization methods.

    It is a fast approximation of the Sigmoid function adapted from:

    F. Zenke, S. Ganguli (2018) SuperSpike: Supervised Learning in Multilayer Spiking Neural Networks. Neural Computation, pp. 1514-1541.


    :k: A scaling factor that can be used to adjust the steepness of the 
                      SuperSpike function. Default is 25.
    :return: JIT compiled SuperSpike surrogate gradient function.
    """
    def grad_superspike(x):
        return 1 / (1 + k*jnp.abs(x))**2

    return custom(grad_superspike, heaviside)
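
The sketch below checks numerically that the expression in grad_superspike is the derivative of the fast sigmoid x/(1+k|x|), using jax.grad. The names superspike_surrogate and fast_sigmoid are hypothetical, and x = 0 is left out of the sample points to avoid the kink of |x|.

```python
import jax
import jax.numpy as jnp

K = 25.0


def superspike_surrogate(x, k=K):
    # Same expression as grad_superspike: 1 / (1 + k|x|)^2
    return 1.0 / (1.0 + k * jnp.abs(x)) ** 2


def fast_sigmoid(x, k=K):
    # Fast sigmoid whose derivative is the SuperSpike surrogate.
    return x / (1.0 + k * jnp.abs(x))


x = jnp.array([-0.2, -0.04, 0.01, 0.04, 0.2])
autodiff = jax.vmap(jax.grad(fast_sigmoid))(x)
print(jnp.allclose(autodiff, superspike_surrogate(x)))  # True
```

At x = 0.04 with k = 25, k|x| = 1, so the surrogate is 1/(1+1)^2 = 0.25: the default k confines most of the gradient to a narrow band around threshold.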

tanh(k=1)

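Hyperbolic tangent activation. The forward pass is the Heaviside step, and the surrogate gradient is

$$\frac{4}{(e^{-kx} + e^{kx})^2} = \operatorname{sech}^2(kx),$$

which equals the derivative of tanh(kx)/k.

:k: Value for scaling the slope of the surrogate gradient.
:return: JIT-compiled tanh surrogate gradient function.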

Source code in spyx/axn.py
def tanh(k=1):
    """Hyperbolic Tangent activation.

    .. math:: 4 / (e^{-kx} + e^{kx})^2

    :k: Value for scaling the slope of the surrogate gradient.
    :return: JIT compiled tanh surrogate gradient function.
    """
    def grad_tanh(x):
        kx = k * x
        return 4 / (jnp.exp(-kx) + jnp.exp(kx))**2

    return custom(grad_tanh, heaviside)
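
A short numerical check that the expression in grad_tanh equals sech^2(kx) = 1 - tanh^2(kx), which is also the derivative of tanh(kx)/k. tanh_surrogate is a hypothetical stand-alone copy of grad_tanh, and k = 3 is an arbitrary choice.

```python
import jax
import jax.numpy as jnp


def tanh_surrogate(x, k=1.0):
    # Same expression as grad_tanh: 4 / (e^{-kx} + e^{kx})^2
    kx = k * x
    return 4.0 / (jnp.exp(-kx) + jnp.exp(kx)) ** 2


k = 3.0
x = jnp.linspace(-2.0, 2.0, 9)

# sech^2(kx) = 1 - tanh^2(kx), which is also d/dx [tanh(kx) / k].
closed_form = 1.0 - jnp.tanh(k * x) ** 2
autodiff = jax.vmap(jax.grad(lambda v: jnp.tanh(k * v) / k))(x)

print(jnp.allclose(tanh_surrogate(x, k), closed_form, atol=1e-6))  # True
print(jnp.allclose(tanh_surrogate(x, k), autodiff, atol=1e-6))     # True
```

Because the surrogate peaks at 1 for every k, changing k only narrows or widens the region that receives gradient.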

triangular(k=2)

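Triangular activation inspired by Esser et al. (https://arxiv.org/abs/1603.08270). The forward pass is the Heaviside step, and the surrogate gradient is

$$\max(0, 1 - |kx|),$$

which is nonzero only for |x| < 1/k and peaks at 1 when x = 0.

:k: Scale factor.
:return: JIT-compiled triangular surrogate gradient function.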

Source code in spyx/axn.py
def triangular(k=2):
    """
    Triangular activation inspired by Esser et al. https://arxiv.org/abs/1603.08270

    .. math::
        max(0, 1-|kx|)


    :k: scale factor
    :return: JIT compiled triangular surrogate gradient function.
    """

    def grad_traingle(x):
        return jnp.maximum(0, 1-jnp.abs(k*x))

    return custom(grad_traingle, heaviside)
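
Evaluating the triangular surrogate on a few points shows its support. triangular_surrogate is a hypothetical stand-alone copy of the gradient function above, using the default k = 2, so the gradient vanishes for |x| >= 0.5.

```python
import jax.numpy as jnp


def triangular_surrogate(x, k=2.0):
    # Same expression as the listing's gradient: max(0, 1 - |kx|)
    return jnp.maximum(0.0, 1.0 - jnp.abs(k * x))


x = jnp.array([-0.75, -0.25, 0.0, 0.25, 0.75])
print(triangular_surrogate(x))
```

The values are 0, 0.5, 1, 0.5, 0: the gradient falls off linearly from the threshold and is exactly zero outside |x| < 1/k.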

Data Utilities

angle_code(neuron_count, min_val, max_val)

Higher-order function that returns an angle encoding function: given a continuous value, the encoder produces a one-hot vector indicating where the value falls between a specified minimum and maximum. For non-linear discretization, apply a function to the continuous value before feeding it into the encoder.

:neuron_count: The number of output channels for the angle encoder.
:min_val: A lower bound on the continuous input channel.
:max_val: An upper bound on the continuous input channel.

Source code in spyx/data.py
def angle_code(neuron_count, min_val, max_val):
    """
    Higher-order function which returns an angle encoding function; given a continuous value, an angle converter generates a one-hot vector corresponding to where the value falls between a specified minimum and maximum.
    To achieve non-linear discretization, apply a function to the continuous value before feeding it into the encoder.

    :neuron_count: The number of output channels for the angle encoder
    :min_val: A lower bound on the continuous input channel
    :max_val: An upper bound on the continuous input channel.
    """
    neurons = jnp.linspace(min_val, max_val, neuron_count)

    def _call(obs):
        digital = jnp.digitize(obs, neurons)
        return jax.nn.one_hot(digital, neuron_count)

    return jax.jit(_call)
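
The sketch below copies the listing above (so it runs without Spyx installed) and shows how values map to channels with neuron_count=5 over [0, 1]. Following jnp.digitize semantics, values below min_val land in channel 0, values in [min_val, max_val) land in channels 1 to neuron_count - 1, and values at or above max_val have index neuron_count, which jax.nn.one_hot turns into an all-zero vector.

```python
import jax
import jax.numpy as jnp


def angle_code(neuron_count, min_val, max_val):
    # Mirrors the listing above so the example runs without Spyx installed.
    neurons = jnp.linspace(min_val, max_val, neuron_count)  # bin edges

    def _call(obs):
        digital = jnp.digitize(obs, neurons)  # index i with edges[i-1] <= obs < edges[i]
        return jax.nn.one_hot(digital, neuron_count)

    return jax.jit(_call)


encode = angle_code(neuron_count=5, min_val=0.0, max_val=1.0)
obs = jnp.array([-0.1, 0.0, 0.3, 0.99, 1.0])
print(encode(obs))
```

Rows come out as channel 0, 1, 2 and 4 for the first four inputs and all zeros for 1.0, so inputs exactly at max_val produce no spike at all.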

rate_code(num_steps, max_r=0.75)

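Repeats the input num_steps times along axis 1 and converts it into rate-encoded spikes: at each time step, each element spikes with probability equal to its input value multiplied by max_r, sampled from a Bernoulli distribution. Input values are currently assumed to be rescaled to the range [0, 1]. Because jnp.repeat duplicates each entry along axis 1 consecutively, axis 1 of the input should be a singleton dimension for the output to have a separate time axis of length num_steps. Spikes are returned as uint8.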

Source code in spyx/data.py
def rate_code(num_steps, max_r=0.75):
    """
    Unrolls input data along axis 1 and converts to rate encoded spikes; the probability of spiking is based on the input value multiplied by a max rate, with each time step being a sample drawn from a Bernoulli distribution.
    Currently assumes input values have been rescaled to between 0 and 1.
    """

    def _call(data, key):
        data = jnp.array(data, dtype=jnp.float16)
        unrolled_data = jnp.repeat(data, num_steps, axis=1)
        return jax.random.bernoulli(key, unrolled_data*max_r).astype(jnp.uint8)

    return jax.jit(_call)
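
A usage sketch, copying the listing above so it runs without Spyx installed. It assumes inputs already scaled to [0, 1] and adds a singleton axis 1 by hand so that jnp.repeat produces a time axis; the shapes chosen are only for illustration.

```python
import jax
import jax.numpy as jnp


def rate_code(num_steps, max_r=0.75):
    # Mirrors the listing above so the example runs without Spyx installed.
    def _call(data, key):
        data = jnp.array(data, dtype=jnp.float16)
        unrolled_data = jnp.repeat(data, num_steps, axis=1)
        return jax.random.bernoulli(key, unrolled_data * max_r).astype(jnp.uint8)

    return jax.jit(_call)


# Batch of 2 samples with 3 features each; axis 1 is a singleton time axis.
images = jnp.array([[0.0, 0.5, 1.0], [1.0, 1.0, 0.0]])[:, None, :]  # (2, 1, 3)

encode = rate_code(num_steps=100)
spikes = encode(images, jax.random.PRNGKey(0))
print(spikes.shape)         # (2, 100, 3)
print(spikes.mean(axis=1))  # empirical spike rates, near input * max_r
```

An input of 0 never spikes, while an input of 1 spikes on roughly max_r = 75% of the time steps.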

shift_augment(max_shift=10, axes=(-1,))

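Shift data augmentation tool. Randomly rolls (circularly shifts) the data along the specified axes, drawing one integer shift per axis. Because the upper bound of jax.random.randint is exclusive, shifts lie in [-max_shift, max_shift - 1].

:max_shift: Bound on the magnitude of the random shift.
:axes: The data axis or axes along which the input will be randomly shifted.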

Source code in spyx/data.py
def shift_augment(max_shift=10, axes=(-1,)):
    """Shift data augmentation tool. Rolls data along specified axes randomly up to a certain amount.


    :max_shift: maximum to which values can be shifted.
    :axes: the data axis or axes along which the input will be randomly shifted.
    """

    def _shift(data, rng):
        shift = jax.random.randint(rng, (len(axes),), -max_shift, max_shift)
        return jnp.roll(data, shift, axes)

    return jax.jit(_shift)
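
A usage sketch, copying the listing above so it runs without Spyx installed. A fresh key is split off for each augmented sample; the 1-D signal and max_shift=3 are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp


def shift_augment(max_shift=10, axes=(-1,)):
    # Mirrors the listing above so the example runs without Spyx installed.
    def _shift(data, rng):
        shift = jax.random.randint(rng, (len(axes),), -max_shift, max_shift)
        return jnp.roll(data, shift, axes)

    return jax.jit(_shift)


augment = shift_augment(max_shift=3)
signal = jnp.arange(8)
key = jax.random.PRNGKey(0)

for subkey in jax.random.split(key, 3):
    print(augment(signal, subkey))  # a circular shift of signal by -3..2 steps
```

Because the shift is circular, values pushed off one end reappear at the other, so no information is lost; for event data where wrap-around is undesirable, padding before rolling may be preferable.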

shuffler(dataset, batch_size)

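Higher-order function that builds a shuffle function for a dataset. The returned function takes a dataset tuple and a PRNG key, randomly permutes the samples, drops any remainder that does not fill a complete batch, and reshapes the result into [# batches, batch_size, ...].

:dataset: Tuple (x, y), where x is a jnp.array of shape [# samples, time, channels, ...] and y holds the integer labels with shape [# samples].
:batch_size: Desired batch size.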

Source code in spyx/data.py
def shuffler(dataset, batch_size):
    """
    Higher-order function which builds a shuffle function for a dataset.

    :dataset: tuple of jnp.arrays (x, y) with shapes [# samples, time, channels...] and [# samples]
    :batch_size: desired batch size.
    """
    x, y = dataset
    cutoff = (y.shape[0] // batch_size) * batch_size
    data_shape = (-1, batch_size) + x.shape[1:]

    def _shuffle(dataset, shuffle_rng):
        """
        Given a dataset tuple, shuffle its samples and group them into batches.

        :dataset: tuple of jnp.arrays with shapes [# samples, time, ...] and [# samples]
        :shuffle_rng: jax.random.PRNGKey
        :return: tuple of jnp.arrays with shapes [# batches, batch size, time, ...] and [# batches, batch size]
        """
        x, y = dataset

        indices = jax.random.permutation(shuffle_rng, y.shape[0])[:cutoff]
        obs, labels = x[indices], y[indices]

        obs = jnp.reshape(obs, data_shape)
        labels = jnp.reshape(labels, (-1, batch_size)) # should make batch size a global

        return (obs, labels)

    return jax.jit(_shuffle)
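
A usage sketch, copying the listing above so it runs without Spyx installed. With 10 samples and batch_size=4, two full batches are produced and two samples are dropped per call; the random data and the per-epoch key splitting are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def shuffler(dataset, batch_size):
    # Mirrors the listing above so the example runs without Spyx installed.
    x, y = dataset
    cutoff = (y.shape[0] // batch_size) * batch_size
    data_shape = (-1, batch_size) + x.shape[1:]

    def _shuffle(dataset, shuffle_rng):
        x, y = dataset
        indices = jax.random.permutation(shuffle_rng, y.shape[0])[:cutoff]
        obs, labels = x[indices], y[indices]
        obs = jnp.reshape(obs, data_shape)
        labels = jnp.reshape(labels, (-1, batch_size))
        return obs, labels

    return jax.jit(_shuffle)


# 10 samples, 5 time steps, 2 channels; labels are the sample indices.
x = jnp.zeros((10, 5, 2))
y = jnp.arange(10)
shuffle = shuffler((x, y), batch_size=4)

key = jax.random.PRNGKey(0)
for epoch in range(2):
    key, epoch_key = jax.random.split(key)  # fresh permutation each epoch
    obs, labels = shuffle((x, y), epoch_key)
    print(obs.shape, labels.shape)  # (2, 4, 5, 2) (2, 4)
```

Since the dropped samples change with each permutation, every sample is still seen regularly across epochs.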