spyx.experimental

Research-stage building blocks that are not part of the stable Spyx surface. Everything here is tested and usable, but the contract is different from the rest of the library.

Stability contract

The APIs in spyx.experimental — and in some cases their numerical behaviour — may change without a deprecation cycle as the underlying research matures. Anything you depend on for production or a long-lived experiment should come from the stable top-level modules (spyx.nn, spyx.ssm, spyx.phasor, spyx.nir, spyx.bench, spyx.quant, spyx.data, spyx.optimize).

The rule of thumb: import experimental things from spyx.experimental so the dependency is explicit; rely on the top-level modules for stable work. See Research with Spyx for how things graduate from here into the core.

What's here

| Symbol | Kind | Notes |
| --- | --- | --- |
| spyx.experimental.PSU_LIF | Neuron | Reset-free parallel LIF. Physically defined in spyx.nn, surfaced here as its supported experimental entry point. |
| spyx.experimental.ResonateFire | Neuron | Complex resonate-and-fire oscillatory neuron. Physically defined in spyx.phasor. |
| spyx.experimental.raven | Module | Routing-slot memory (RavenRSM), spiking sibling (SpikingSlotMemory), SlotRouter, and the make_recall_batch MQAR generator. |
| spyx.experimental.compress | Module | Bit-packed activation storage for memory-efficient BPTT. |
| spyx.experimental.stochastic | Module | Stochastic (Bernoulli-spiking) and parallelizable prototypes: SPSN, StochasticAssociative{LIF,CuBaLIF}, and the sigmoid_bernoulli activations. |

Related research studies live under research/new/ in the repository.

Re-exported neurons

These two are physically defined in stable modules and re-exported here so the experimental surface is discoverable in one place.

PSU_LIF

Bases: Module

Parallel Spiking Unit LIF: a reset-free leaky integrate-and-fire neuron.

.. note:: Experimental. Its supported entry point is :class:spyx.experimental.PSU_LIF; the API may change without a deprecation cycle. It is defined here for locality with the other neurons.

A standard :class:LIF subtracts a reset spikes * threshold from the membrane every step, which couples each timestep to the (nonlinear) spike of the previous step and forces a strictly sequential O(T) scan. Dropping the reset turns the membrane into a pure linear leaky integrator,

.. math:: V_t = \beta \, V_{t-1} + x_t ,

which is a first-order associative recurrence and can therefore be evaluated with :func:jax.lax.associative_scan in O(\log T) parallel depth on an accelerator. Spikes are a pointwise surrogate threshold applied to the whole membrane trace, :math:s_t = \sigma(V_t - \text{threshold}).

Removing the reset is a deliberate accuracy/parallelism trade-off: the neuron never depresses after firing, so it can fire on consecutive steps while a well-tuned integration window keeps activity bounded. In exchange the sequence can be scored in logarithmic instead of linear depth.

Two execution modes are provided and are numerically identical:

  • :meth:__call__ -- one reset-free timestep (x, V) -> (spikes, V) with V = beta * V + x; a drop-in for :func:spyx.nn.run, :class:Sequential, and NIR, exactly like :class:LIF.
  • :meth:parallel -- the whole time-major sequence at once via an associative scan over the leak, O(\log T) depth.

Because both modes use the same clipped beta and the same surrogate, and :meth:__call__ integrates the input before spiking, scanning :meth:__call__ over x reproduces :meth:parallel exactly.
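The equivalence of the two modes can be checked outside the library. The following is a minimal sketch in plain JAX, not the Spyx implementation: `leaky_op` is a hypothetical stand-in for the private `_leaky_associative_op` (whose body is not shown on this page), and the scalar `beta` and the shapes are illustrative. It compares membrane traces rather than spikes, since a hard threshold could amplify last-bit rounding differences between the two evaluation orders.

```python
import jax
import jax.numpy as jnp


def leaky_op(left, right):
    """Compose two affine maps V -> a * V + b, with `left` applied first.

    Hypothetical stand-in for the library's private combine function.
    """
    a_l, b_l = left
    a_r, b_r = right
    return a_l * a_r, a_r * b_l + b_r


T, B, H = 16, 2, 4
beta = 0.9  # illustrative scalar leak
x = jax.random.uniform(jax.random.PRNGKey(0), (T, B, H))


def step(V, x_t):
    """One reset-free step: integrate the input, emit the updated membrane."""
    V = beta * V + x_t
    return V, V


# Sequential reference: O(T) depth.
_, V_seq = jax.lax.scan(step, jnp.zeros((B, H)), x)

# Parallel form: the same recurrence in O(log T) depth.
A = jnp.full(x.shape, beta)
_, V_par = jax.lax.associative_scan(leaky_op, (A, x), axis=0)

max_gap = jnp.max(jnp.abs(V_seq - V_par))
```

In this sketch the two traces agree to floating-point tolerance; applying the same pointwise threshold to either trace then yields the spikes.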

Source code in spyx/nn.py
class PSU_LIF(nnx.Module):
    r"""Parallel Spiking Unit LIF: a reset-free leaky integrate-and-fire neuron.

    .. note::
       **Experimental.** Its supported entry point is
       :class:`spyx.experimental.PSU_LIF`; the API may change without a
       deprecation cycle. It is defined here for locality with the other neurons.

    A standard :class:`LIF` subtracts a reset ``spikes * threshold`` from the
    membrane every step, which couples each timestep to the (nonlinear) spike
    of the previous step and forces a strictly sequential ``O(T)`` scan.
    Dropping the reset turns the membrane into a pure linear leaky integrator,

    .. math::
        V_t = \beta \, V_{t-1} + x_t ,

    which is a first-order *associative* recurrence and can therefore be
    evaluated with :func:`jax.lax.associative_scan` in ``O(\log T)`` parallel
    depth on an accelerator. Spikes are a pointwise surrogate threshold applied
    to the whole membrane trace, :math:`s_t = \sigma(V_t - \text{threshold})`.

    Removing the reset is a deliberate accuracy/parallelism trade-off: the
    neuron never depresses after firing, so it can fire on consecutive steps
    while a well-tuned integration window keeps activity bounded. In exchange
    the sequence can be scored in logarithmic instead of linear depth.

    Two execution modes are provided and are numerically identical:

    * :meth:`__call__` -- one reset-free timestep ``(x, V) -> (spikes, V)``
      with ``V = beta * V + x``; a drop-in for :func:`spyx.nn.run`,
      :class:`Sequential`, and NIR, exactly like :class:`LIF`.
    * :meth:`parallel` -- the whole time-major sequence at once via an
      associative scan over the leak, ``O(\log T)`` depth.

    Because both modes use the *same* clipped ``beta`` and the *same* surrogate,
    and :meth:`__call__` integrates the input *before* spiking, scanning
    :meth:`__call__` over ``x`` reproduces :meth:`parallel` exactly.
    """

    def __init__(
        self,
        hidden_shape: tuple,
        beta=None,
        threshold=1.0,
        activation=None,
        *,
        rngs: nnx.Rngs,
    ):
        """
        :hidden_shape: Shape of the layer.
        :beta: decay rate. Scalar if provided, else learnable per-unit init.
        :threshold: firing threshold. Defaults to 1.
        :activation: spyx.axn.Axon object determining the surrogate spike.
        """
        self.hidden_shape = hidden_shape
        self.threshold = threshold
        self.spike = activation if activation is not None else _DEFAULT_ACTIVATION

        if beta is None:
            self.beta = nnx.Param(
                nnx.initializers.truncated_normal(stddev=0.25)(
                    rngs.params(), self.hidden_shape
                )
                + 0.5
            )
        else:
            self.beta = nnx.Param(jnp.full((), beta))

    def __call__(self, x, V):
        """One reset-free timestep.

        :x: input vector coming from previous layer.
        :V: neuron state tensor.

        Integrates the input into the membrane (``V = beta * V + x``, no
        reset), then emits a surrogate spike on the updated membrane so that
        scanning this method matches :meth:`parallel` exactly.
        """
        beta = jnp.clip(self.beta[...], 0, 1)
        V = beta * V + x
        spikes = self.spike(V - self.threshold)
        return spikes, V

    def parallel(self, x):
        r"""Score a whole time-major sequence with an associative scan.

        :x: input with shape ``[Time, Batch, ...]``.
        :return: spikes with shape ``[Time, Batch, ...]``.

        Computes the full membrane trace ``V_t = beta * V_{t-1} + x_t`` (with
        ``V_{-1} = 0``) via :func:`jax.lax.associative_scan` over the time axis
        in ``O(\log T)`` depth, then applies the surrogate spike pointwise.
        """
        beta = jnp.clip(self.beta[...], 0, 1)
        # Broadcast the (scalar or per-unit) leak to every (Time, Batch, ...)
        # element so the linear-recurrence coefficient A_t == beta everywhere.
        A = jnp.broadcast_to(beta, x.shape)
        _, V = jax.lax.associative_scan(_leaky_associative_op, (A, x), axis=0)
        return self.spike(V - self.threshold)

    def initial_state(self, batch_size):
        return jnp.zeros((batch_size,) + self.hidden_shape)

__call__(x, V)

One reset-free timestep.

:x: input vector coming from previous layer. :V: neuron state tensor.

Integrates the input into the membrane (V = beta * V + x, no reset), then emits a surrogate spike on the updated membrane so that scanning this method matches :meth:parallel exactly.

Source code in spyx/nn.py
def __call__(self, x, V):
    """One reset-free timestep.

    :x: input vector coming from previous layer.
    :V: neuron state tensor.

    Integrates the input into the membrane (``V = beta * V + x``, no
    reset), then emits a surrogate spike on the updated membrane so that
    scanning this method matches :meth:`parallel` exactly.
    """
    beta = jnp.clip(self.beta[...], 0, 1)
    V = beta * V + x
    spikes = self.spike(V - self.threshold)
    return spikes, V

__init__(hidden_shape, beta=None, threshold=1.0, activation=None, *, rngs)

:hidden_shape: Shape of the layer. :beta: decay rate. Scalar if provided, else learnable per-unit init. :threshold: firing threshold. Defaults to 1. :activation: spyx.axn.Axon object determining the surrogate spike.

Source code in spyx/nn.py
def __init__(
    self,
    hidden_shape: tuple,
    beta=None,
    threshold=1.0,
    activation=None,
    *,
    rngs: nnx.Rngs,
):
    """
    :hidden_shape: Shape of the layer.
    :beta: decay rate. Scalar if provided, else learnable per-unit init.
    :threshold: firing threshold. Defaults to 1.
    :activation: spyx.axn.Axon object determining the surrogate spike.
    """
    self.hidden_shape = hidden_shape
    self.threshold = threshold
    self.spike = activation if activation is not None else _DEFAULT_ACTIVATION

    if beta is None:
        self.beta = nnx.Param(
            nnx.initializers.truncated_normal(stddev=0.25)(
                rngs.params(), self.hidden_shape
            )
            + 0.5
        )
    else:
        self.beta = nnx.Param(jnp.full((), beta))

parallel(x)

Score a whole time-major sequence with an associative scan.

:x: input with shape [Time, Batch, ...]. :return: spikes with shape [Time, Batch, ...].

Computes the full membrane trace V_t = beta * V_{t-1} + x_t (with V_{-1} = 0) via :func:jax.lax.associative_scan over the time axis in O(\log T) depth, then applies the surrogate spike pointwise.
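With V_{-1} = 0, the recurrence unrolls to the closed form V_t = sum over k <= t of beta^(t - k) * x_k. The sketch below checks an associative scan against that closed form with a per-unit leak broadcast over the time and batch axes. It is an illustration in plain JAX under stated assumptions: `leaky_op` is a hypothetical combine function written for this example, and the shapes and random leak are arbitrary.

```python
import jax
import jax.numpy as jnp


def leaky_op(left, right):
    """Compose affine maps V -> a * V + b (hypothetical helper, left first)."""
    a_l, b_l = left
    a_r, b_r = right
    return a_l * a_r, a_r * b_l + b_r


T, B, H = 8, 3, 5
key_b, key_x = jax.random.split(jax.random.PRNGKey(1))
beta = jnp.clip(jax.random.uniform(key_b, (H,)), 0.0, 1.0)  # per-unit leak
x = jax.random.normal(key_x, (T, B, H))

# Broadcast the per-unit leak to every (Time, Batch, Unit) element.
A = jnp.broadcast_to(beta, x.shape)
_, V_scan = jax.lax.associative_scan(leaky_op, (A, x), axis=0)

# Closed form: V_t = sum_{k <= t} beta^(t - k) * x_k.
t = jnp.arange(T)
lag = t[:, None] - t[None, :]                       # lag[t, k] = t - k
causal = (lag >= 0)[..., None]                      # (T, T, 1)
powers = jnp.where(causal, beta ** jnp.maximum(lag, 0)[..., None], 0.0)
V_closed = jnp.einsum("tkh,kbh->tbh", powers, x)    # (T, B, H)
```

The closed form costs O(T^2) work and is only useful as a test oracle for short sequences.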

Source code in spyx/nn.py
def parallel(self, x):
    r"""Score a whole time-major sequence with an associative scan.

    :x: input with shape ``[Time, Batch, ...]``.
    :return: spikes with shape ``[Time, Batch, ...]``.

    Computes the full membrane trace ``V_t = beta * V_{t-1} + x_t`` (with
    ``V_{-1} = 0``) via :func:`jax.lax.associative_scan` over the time axis
    in ``O(\log T)`` depth, then applies the surrogate spike pointwise.
    """
    beta = jnp.clip(self.beta[...], 0, 1)
    # Broadcast the (scalar or per-unit) leak to every (Time, Batch, ...)
    # element so the linear-recurrence coefficient A_t == beta everywhere.
    A = jnp.broadcast_to(beta, x.shape)
    _, V = jax.lax.associative_scan(_leaky_associative_op, (A, x), axis=0)
    return self.spike(V - self.threshold)

ResonateFire

Bases: Module

Resonate-and-fire neuron: the complex/oscillatory sibling of PSU_LIF.

.. note:: Experimental. Its supported entry point is :class:spyx.experimental.ResonateFire; the API may change without a deprecation cycle. It is defined here for locality with the phasor layers.

A resonate-and-fire neuron carries a complex membrane that behaves as a damped harmonic oscillator. Written reset-free, its subthreshold dynamics are a complex linear recurrence

.. math:: z_t = a \, z_{t-1} + x_t , \qquad a = e^{\,\mathrm{dt}\,(-\lambda + i\,\omega)} ,

with per-unit decay :math:\lambda \ge 0 and angular frequency :math:\omega. The real input current x_t is injected into the real part of the membrane. Because there is no reset, the recurrence stays linear, so exactly like :class:spyx.nn.PSU_LIF it can be evaluated with :func:jax.lax.associative_scan in :math:O(\log T) parallel depth -- only now the scan runs over a complex pole a instead of a real leak.

Spikes are emitted by a pointwise surrogate threshold on the real part of the oscillator, :math:s_t = \sigma(\Re(z_t) - \text{threshold}). The rule is reset-free so the linear recurrence -- and therefore the parallel scan -- is preserved.

Stability: the pole magnitude is |a| = exp(-dt * lambda). Storing the decay through a softplus keeps :math:\lambda \ge 0, hence :math:|a| \le 1 and the oscillation never grows.
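The stability argument can be verified numerically. This is a small sketch in plain JAX with arbitrary raw values and an illustrative `dt`; it mirrors the formulas above rather than calling the library.

```python
import jax
import jax.numpy as jnp

dt = 0.5  # illustrative timestep
raw_lambda = jnp.array([-5.0, -1.0, 0.0, 2.0, 10.0])  # unconstrained raw values
omega = jnp.array([0.1, 0.5, 1.0, 2.0, 3.0])

lam = jax.nn.softplus(raw_lambda)        # effective decay, always >= 0
a = jnp.exp(dt * (-lam + 1j * omega))    # complex pole
magnitude = jnp.abs(a)                   # equals exp(-dt * lam)
```

Whatever sign the raw parameter takes during training, the magnitude stays at or below 1, while `omega` only rotates the pole and never affects its magnitude.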

Parameters that enter the complex pole (lambda, omega) are stored as real float32 nnx.Param tensors, mirroring :class:PhasorLinear: the complex structure appears only in the forward pass, so a stock optax + jax.grad loop over a real loss trains them without the Wirtinger-conjugate surprise.
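To make the point about real parameters concrete, the sketch below differentiates a real loss through a complex forward pass. It is a standalone illustration, not the library's training loop: the `params` dictionary, `loss_fn`, and the mean-squared objective are assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, dt=1.0):
    """Real scalar loss computed through a complex recurrence."""
    lam = jax.nn.softplus(params["raw_lambda"])
    a = jnp.exp(dt * (-lam + 1j * params["omega"]))  # complex64 pole

    def step(z, x_t):
        z = a * z + x_t.astype(z.dtype)  # inject real input into Re(z)
        return z, jnp.real(z)

    z0 = jnp.zeros(x.shape[1:], dtype=jnp.complex64)
    _, re_trace = jax.lax.scan(step, z0, x)
    return jnp.mean(re_trace ** 2)


# Real float32 parameters; the complex structure lives only in the forward pass.
params = {
    "raw_lambda": jnp.full((4,), 0.5, dtype=jnp.float32),
    "omega": jnp.full((4,), 1.0, dtype=jnp.float32),
}
x = jax.random.normal(jax.random.PRNGKey(0), (10, 2, 4))
grads = jax.grad(loss_fn)(params, x)
```

Because both the parameters and the loss are real, the gradients come back as ordinary real arrays and can be passed to a standard optimizer without any conjugation step.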

Two execution modes are provided and are numerically identical:

  • :meth:__call__ -- one reset-free timestep (x, z) -> (spikes, z) with z = a * z + x; a drop-in for :func:spyx.nn.run / :class:Sequential.
  • :meth:parallel -- the whole time-major sequence at once via an associative scan over the complex pole, :math:O(\log T) depth.

Because both modes use the same pole and surrogate and integrate the input before spiking, scanning :meth:__call__ over x reproduces :meth:parallel exactly.
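The same check as for the real leak carries over to a complex pole. The following sketch assumes a hypothetical combine function `resonate_op` (the private `_resonate_associative_op` is not shown on this page) and fixed illustrative values for `dt`, `lam`, and `omega`.

```python
import jax
import jax.numpy as jnp


def resonate_op(left, right):
    """Compose complex affine maps z -> a * z + b (hypothetical, left first)."""
    a_l, b_l = left
    a_r, b_r = right
    return a_l * a_r, a_r * b_l + b_r


T, B, H = 32, 2, 3
dt, lam, omega = 1.0, 0.1, 1.0  # illustrative constants
a = jnp.exp(dt * (-lam + 1j * omega)).astype(jnp.complex64)

x = jax.random.normal(jax.random.PRNGKey(2), (T, B, H))
xc = x.astype(jnp.complex64)  # real current enters the real part


def step(z, x_t):
    """One reset-free complex step."""
    z = a * z + x_t
    return z, z


# Sequential reference over time.
_, z_seq = jax.lax.scan(step, jnp.zeros((B, H), jnp.complex64), xc)

# Parallel evaluation over the complex pole.
A = jnp.broadcast_to(a, xc.shape)
_, z_par = jax.lax.associative_scan(resonate_op, (A, xc), axis=0)

re_gap = jnp.max(jnp.abs(jnp.real(z_seq) - jnp.real(z_par)))
```

In this sketch the real parts, which are what the threshold sees, agree to floating-point tolerance.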

Source code in spyx/phasor.py
class ResonateFire(nnx.Module):
    r"""Resonate-and-fire neuron: the complex/oscillatory sibling of ``PSU_LIF``.

    .. note::
       **Experimental.** Its supported entry point is
       :class:`spyx.experimental.ResonateFire`; the API may change without a
       deprecation cycle. It is defined here for locality with the phasor layers.


    A resonate-and-fire neuron carries a **complex** membrane that behaves as a
    damped harmonic oscillator. Written reset-free, its subthreshold dynamics
    are a *complex linear recurrence*

    .. math::
        z_t = a \, z_{t-1} + x_t , \qquad a = e^{\,\mathrm{dt}\,(-\lambda + i\,\omega)} ,

    with per-unit decay :math:`\lambda \ge 0` and angular frequency
    :math:`\omega`. The real input current ``x_t`` is injected into the *real*
    part of the membrane. Because there is no reset, the recurrence stays
    linear, so exactly like :class:`spyx.nn.PSU_LIF` it can be evaluated with
    :func:`jax.lax.associative_scan` in :math:`O(\log T)` parallel depth -- only
    now the scan runs over a *complex* pole ``a`` instead of a real leak.

    Spikes are emitted by a pointwise surrogate threshold on the real part of
    the oscillator, :math:`s_t = \sigma(\Re(z_t) - \text{threshold})`. The rule
    is reset-free so the linear recurrence -- and therefore the parallel scan --
    is preserved.

    Stability: the pole magnitude is ``|a| = exp(-dt * lambda)``. Storing the
    decay through a ``softplus`` keeps :math:`\lambda \ge 0`, hence
    :math:`|a| \le 1` and the oscillation never grows.

    Parameters that enter the complex pole (``lambda``, ``omega``) are stored as
    **real** ``float32`` ``nnx.Param`` tensors, mirroring :class:`PhasorLinear`:
    the complex structure appears only in the forward pass, so a stock
    ``optax`` + ``jax.grad`` loop over a real loss trains them without the
    Wirtinger-conjugate surprise.

    Two execution modes are provided and are numerically identical:

    * :meth:`__call__` -- one reset-free timestep ``(x, z) -> (spikes, z)`` with
      ``z = a * z + x``; a drop-in for :func:`spyx.nn.run` / :class:`Sequential`.
    * :meth:`parallel` -- the whole time-major sequence at once via an
      associative scan over the complex pole, :math:`O(\log T)` depth.

    Because both modes use the *same* pole and surrogate and integrate the input
    *before* spiking, scanning :meth:`__call__` over ``x`` reproduces
    :meth:`parallel` exactly.
    """

    def __init__(
        self,
        hidden_shape: tuple,
        lambda_init=None,
        omega_init=None,
        threshold: float = 1.0,
        dt: float = 1.0,
        activation=None,
        *,
        rngs: nnx.Rngs,
    ):
        """
        :hidden_shape: Per-unit shape of the layer.
        :lambda_init: Membrane decay ``>= 0``. Scalar constant if provided, else
            a learnable per-unit initialisation. Stored through ``softplus`` so
            the effective decay is always non-negative.
        :omega_init: Angular frequency of the oscillator. Scalar constant if
            provided, else a learnable per-unit initialisation.
        :threshold: Real firing threshold on ``Re(z)``. Defaults to 1.
        :dt: Integration timestep entering the pole ``exp(dt(-lambda+i*omega))``.
        :activation: :class:`spyx.axn.Axon` surrogate spike; defaults to
            ``superspike``.
        """
        if dt <= 0:
            raise ValueError(f"dt must be positive; got {dt}.")
        self.hidden_shape = hidden_shape
        self.threshold = threshold
        self.dt = dt
        self.spike = activation if activation is not None else _DEFAULT_ACTIVATION

        # Raw decay parameter; effective lambda = softplus(raw) >= 0 so |a| <= 1.
        if lambda_init is None:
            # Small positive decays: softplus(N(0.5, 0.25)) ~ light damping.
            raw = (
                nnx.initializers.truncated_normal(stddev=0.25)(
                    rngs.params(), self.hidden_shape
                )
                + 0.5
            )
            self.raw_lambda = nnx.Param(raw.astype(jnp.float32))
        else:
            self.raw_lambda = nnx.Param(
                _inverse_softplus(jnp.full((), float(lambda_init))).astype(jnp.float32)
            )

        if omega_init is None:
            # Spread frequencies around ~1 rad/step so units resonate distinctly.
            omega = (
                nnx.initializers.truncated_normal(stddev=0.5)(
                    rngs.params(), self.hidden_shape
                )
                + 1.0
            )
            self.omega = nnx.Param(omega.astype(jnp.float32))
        else:
            self.omega = nnx.Param(jnp.full((), float(omega_init)))

    @property
    def decay(self) -> jax.Array:
        """Effective non-negative decay ``lambda = softplus(raw_lambda)``."""
        return jax.nn.softplus(self.raw_lambda[...])

    @property
    def a(self) -> jax.Array:
        """Complex oscillator pole ``a = exp(dt(-lambda + i*omega))``.

        The magnitude ``|a| = exp(-dt * lambda) <= 1`` guarantees stability.
        """
        exponent = self.dt * (-self.decay + 1j * self.omega[...])
        return jnp.exp(exponent).astype(jnp.complex64)

    def __call__(self, x, z):
        """One reset-free timestep.

        :x: real input current from the previous layer, broadcastable to ``z``.
        :z: complex64 membrane state.

        Injects ``x`` into the real part of the membrane and advances the
        complex recurrence ``z = a * z + x`` (no reset), then emits a surrogate
        spike on ``Re(z)`` so that scanning this method matches :meth:`parallel`.
        """
        a = self.a
        z = a * z + x.astype(z.dtype)
        spikes = self.spike(jnp.real(z) - self.threshold)
        return spikes, z

    def parallel(self, x):
        r"""Score a whole time-major sequence with an associative scan.

        :x: real input with shape ``[Time, Batch, ...]``.
        :return: spikes with shape ``[Time, Batch, ...]``.

        Computes the full complex membrane trace ``z_t = a * z_{t-1} + x_t``
        (with ``z_{-1} = 0``) via :func:`jax.lax.associative_scan` over the time
        axis in :math:`O(\log T)` depth, then applies the surrogate spike
        pointwise on ``Re(z)``.
        """
        a = self.a
        xc = x.astype(jnp.complex64)
        # Broadcast the (scalar or per-unit) complex pole to every element so the
        # linear-recurrence coefficient a_t == a everywhere along the time axis.
        A = jnp.broadcast_to(a, xc.shape)
        _, z = jax.lax.associative_scan(_resonate_associative_op, (A, xc), axis=0)
        return self.spike(jnp.real(z) - self.threshold)

    def initial_state(self, batch_size):
        """Return complex64 zeros of shape ``(batch_size,) + hidden_shape``."""
        return jnp.zeros((batch_size,) + tuple(self.hidden_shape), dtype=jnp.complex64)

a property

Complex oscillator pole a = exp(dt(-lambda + i*omega)).

The magnitude |a| = exp(-dt * lambda) <= 1 guarantees stability.

decay property

Effective non-negative decay lambda = softplus(raw_lambda).

__call__(x, z)

One reset-free timestep.

:x: real input current from the previous layer, broadcastable to z. :z: complex64 membrane state.

Injects x into the real part of the membrane and advances the complex recurrence z = a * z + x (no reset), then emits a surrogate spike on Re(z) so that scanning this method matches :meth:parallel.

Source code in spyx/phasor.py
def __call__(self, x, z):
    """One reset-free timestep.

    :x: real input current from the previous layer, broadcastable to ``z``.
    :z: complex64 membrane state.

    Injects ``x`` into the real part of the membrane and advances the
    complex recurrence ``z = a * z + x`` (no reset), then emits a surrogate
    spike on ``Re(z)`` so that scanning this method matches :meth:`parallel`.
    """
    a = self.a
    z = a * z + x.astype(z.dtype)
    spikes = self.spike(jnp.real(z) - self.threshold)
    return spikes, z

__init__(hidden_shape, lambda_init=None, omega_init=None, threshold=1.0, dt=1.0, activation=None, *, rngs)

:hidden_shape: Per-unit shape of the layer. :lambda_init: Membrane decay >= 0. Scalar constant if provided, else a learnable per-unit initialisation. Stored through softplus so the effective decay is always non-negative. :omega_init: Angular frequency of the oscillator. Scalar constant if provided, else a learnable per-unit initialisation. :threshold: Real firing threshold on Re(z). Defaults to 1. :dt: Integration timestep entering the pole exp(dt(-lambda+i*omega)). :activation: :class:spyx.axn.Axon surrogate spike; defaults to superspike.
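A constant `lambda_init` has to be mapped back through the inverse of softplus so that the stored raw value reproduces it. The body of the private `_inverse_softplus` is not shown on this page; the sketch below assumes the usual closed form log(exp(y) - 1), and the helper name `inverse_softplus` is hypothetical.

```python
import jax
import jax.numpy as jnp


def inverse_softplus(y):
    """Assumed inverse of softplus, valid for y > 0: log(exp(y) - 1)."""
    return jnp.log(jnp.expm1(y))


lambda_init = jnp.array([0.05, 0.5, 2.0])  # illustrative target decays
raw = inverse_softplus(lambda_init)        # value that would be stored
recovered = jax.nn.softplus(raw)           # effective decay seen by the pole
```

Under this assumed formula a decay of exactly 0 maps to negative infinity, so a small positive value is the safer choice for a nearly undamped oscillator; whether the library guards this case is not stated here.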

Source code in spyx/phasor.py
def __init__(
    self,
    hidden_shape: tuple,
    lambda_init=None,
    omega_init=None,
    threshold: float = 1.0,
    dt: float = 1.0,
    activation=None,
    *,
    rngs: nnx.Rngs,
):
    """
    :hidden_shape: Per-unit shape of the layer.
    :lambda_init: Membrane decay ``>= 0``. Scalar constant if provided, else
        a learnable per-unit initialisation. Stored through ``softplus`` so
        the effective decay is always non-negative.
    :omega_init: Angular frequency of the oscillator. Scalar constant if
        provided, else a learnable per-unit initialisation.
    :threshold: Real firing threshold on ``Re(z)``. Defaults to 1.
    :dt: Integration timestep entering the pole ``exp(dt(-lambda+i*omega))``.
    :activation: :class:`spyx.axn.Axon` surrogate spike; defaults to
        ``superspike``.
    """
    if dt <= 0:
        raise ValueError(f"dt must be positive; got {dt}.")
    self.hidden_shape = hidden_shape
    self.threshold = threshold
    self.dt = dt
    self.spike = activation if activation is not None else _DEFAULT_ACTIVATION

    # Raw decay parameter; effective lambda = softplus(raw) >= 0 so |a| <= 1.
    if lambda_init is None:
        # Small positive decays: softplus(N(0.5, 0.25)) ~ light damping.
        raw = (
            nnx.initializers.truncated_normal(stddev=0.25)(
                rngs.params(), self.hidden_shape
            )
            + 0.5
        )
        self.raw_lambda = nnx.Param(raw.astype(jnp.float32))
    else:
        self.raw_lambda = nnx.Param(
            _inverse_softplus(jnp.full((), float(lambda_init))).astype(jnp.float32)
        )

    if omega_init is None:
        # Spread frequencies around ~1 rad/step so units resonate distinctly.
        omega = (
            nnx.initializers.truncated_normal(stddev=0.5)(
                rngs.params(), self.hidden_shape
            )
            + 1.0
        )
        self.omega = nnx.Param(omega.astype(jnp.float32))
    else:
        self.omega = nnx.Param(jnp.full((), float(omega_init)))

initial_state(batch_size)

Return complex64 zeros of shape (batch_size,) + hidden_shape.

Source code in spyx/phasor.py
def initial_state(self, batch_size):
    """Return complex64 zeros of shape ``(batch_size,) + hidden_shape``."""
    return jnp.zeros((batch_size,) + tuple(self.hidden_shape), dtype=jnp.complex64)

parallel(x)

Score a whole time-major sequence with an associative scan.

:x: real input with shape [Time, Batch, ...]. :return: spikes with shape [Time, Batch, ...].

Computes the full complex membrane trace z_t = a * z_{t-1} + x_t (with z_{-1} = 0) via :func:jax.lax.associative_scan over the time axis in :math:O(\log T) depth, then applies the surrogate spike pointwise on Re(z).

Source code in spyx/phasor.py
def parallel(self, x):
    r"""Score a whole time-major sequence with an associative scan.

    :x: real input with shape ``[Time, Batch, ...]``.
    :return: spikes with shape ``[Time, Batch, ...]``.

    Computes the full complex membrane trace ``z_t = a * z_{t-1} + x_t``
    (with ``z_{-1} = 0``) via :func:`jax.lax.associative_scan` over the time
    axis in :math:`O(\log T)` depth, then applies the surrogate spike
    pointwise on ``Re(z)``.
    """
    a = self.a
    xc = x.astype(jnp.complex64)
    # Broadcast the (scalar or per-unit) complex pole to every element so the
    # linear-recurrence coefficient a_t == a everywhere along the time axis.
    A = jnp.broadcast_to(a, xc.shape)
    _, z = jax.lax.associative_scan(_resonate_associative_op, (A, xc), axis=0)
    return self.spike(jnp.real(z) - self.threshold)

spyx.experimental.raven

Raven Routing-Slot-Memory (RSM) block for Spyx.

A Flax NNX implementation of the Routing Slot Memory recurrence introduced by Raven (Afzal, Bick, Xing, Cevher, Gu, 2026; "High-recall sequence modeling with sparse memory routing"). Compressed-state recurrent models (a single SSM state with uniform decay) struggle with exact recall: every new token perturbs the whole state, so previously written associations interfere with each other.

Raven's fix is to partition the memory into M independent slots and use a learned sparse router r_t to write only the selected slots, leaving the rest untouched (shielded from interference). Writing slot m at step t:

.. math:: S_t = (1 - r_t) \odot S_{t-1} + r_t \odot ( D_t S_{t-1} A_t + U_t )

  • S_t: slot memory, shape (B, M, d_slot).
  • r_t \in [0, 1]^M: the per-slot router (ideally sparse). Unselected slots (r_t[m] ≈ 0) pass through unchanged; selected slots decay and are written.
  • U_t: the write (a projection of the current input).
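The shielding effect of the router is easy to see on a single step. The sketch below is an illustration with hand-picked values, and it assumes a scalar decay `a` in place of the full transition D_t S_{t-1} A_t (the diagonal simplification this module describes further down).

```python
import jax.numpy as jnp

B, M, d_slot = 1, 4, 3
S_prev = jnp.arange(B * M * d_slot, dtype=jnp.float32).reshape(B, M, d_slot)
U_t = jnp.ones((B, M, d_slot))            # the write
a = 0.9                                   # scalar stand-in for the transition
r_t = jnp.array([[0.0, 1.0, 0.0, 0.0]])   # hard router: write slot 1 only

r = r_t[..., None]                        # (B, M, 1), broadcasts over d_slot
S_t = (1.0 - r) * S_prev + r * (a * S_prev + U_t)
```

Slots 0, 2, and 3 come out bit-for-bit unchanged, while slot 1 is decayed and receives the write.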

The router is "a Mixture-of-Experts for memory". Two reductions are worth remembering (and are exercised by the tests):

  • a dense router (r_t all-ones) recovers a standard gated diagonal SSM,
  • a one-hot cyclic router recovers sliding-window attention.
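The dense-router reduction can be sketched in a few lines. This is not the library's test; it assumes the diagonal form of the recurrence with a static decay `a`, and the function names `rsm_step` and `ssm_step` are hypothetical.

```python
import jax
import jax.numpy as jnp

T, B, M, d_slot = 6, 2, 4, 3
a = jnp.full((M, d_slot), 0.9)  # illustrative static decay
U = jax.random.normal(jax.random.PRNGKey(3), (T, B, M, d_slot))
r_dense = jnp.ones((T, B, M))   # dense router: every slot written every step
S0 = jnp.zeros((B, M, d_slot))


def rsm_step(S, inp):
    """Routed update: blend the old state with the decayed, written state."""
    r_t, U_t = inp
    r = r_t[..., None]
    S = (1.0 - r) * S + r * (a * S + U_t)
    return S, S


def ssm_step(S, U_t):
    """Plain diagonal recurrence with no router."""
    S = a * S + U_t
    return S, S


_, S_rsm = jax.lax.scan(rsm_step, S0, (r_dense, U))
_, S_ssm = jax.lax.scan(ssm_step, S0, U)
```

With an all-ones router the pass-through term vanishes and the two traces coincide.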

Faithful-but-tractable simplification (documented, see :class:RavenRSM): the per-slot transition is made diagonal. The full matrix sandwich D_t S_{t-1} A_t is replaced by a per-slot (per-dim) decay a ⊙ S_{t-1}, so each slot is a gated diagonal recurrence; the full matrix-sandwich form is deferred. The recurrence is likewise run with a plain :func:jax.lax.scan as an honest reference baseline. Although the per-step transition is input-dependent through the router gate (1 - r_t), the update is still a diagonal linear recurrence in S with per-timestep coefficients, so an associative or chunked associative_scan form is possible in principle (the Raven authors defer it to a "Part 2"); it is not implemented here.
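To illustrate why an associative form is possible in principle, the diagonal update can be regrouped as S_t = c_t ⊙ S_{t-1} + b_t with c_t = (1 - r_t) + r_t ⊙ a and b_t = r_t ⊙ U_t. The sketch below is not part of Spyx and is not the Raven authors' method; it is a plain-JAX check under these assumptions, with a hypothetical combine function `affine_op` and a random soft router.

```python
import jax
import jax.numpy as jnp


def affine_op(left, right):
    """Compose elementwise affine maps S -> c * S + b (left applied first)."""
    c_l, b_l = left
    c_r, b_r = right
    return c_l * c_r, c_r * b_l + b_r


T, B, M, d_slot = 12, 2, 4, 3
k_r, k_u = jax.random.split(jax.random.PRNGKey(4))
a = jnp.full((M, d_slot), 0.9)  # illustrative static decay
r = jax.nn.sigmoid(jax.random.normal(k_r, (T, B, M)))[..., None]  # (T, B, M, 1)
U = jax.random.normal(k_u, (T, B, M, d_slot))


def step(S, inp):
    """Sequential reference for the routed diagonal recurrence."""
    r_t, U_t = inp
    S = (1.0 - r_t) * S + r_t * (a * S + U_t)
    return S, S


_, S_seq = jax.lax.scan(step, jnp.zeros((B, M, d_slot)), (r, U))

# Per-timestep coefficients of the equivalent linear recurrence.
coeff = (1.0 - r) + r * a   # (T, B, M, d_slot)
drive = r * U               # (T, B, M, d_slot)
_, S_par = jax.lax.associative_scan(affine_op, (coeff, drive), axis=0)
```

The scan materialises coefficients of shape (T, B, M, d_slot), so a practical version would likely need chunking to keep memory in check.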

RavenRSM

Bases: Module

Routing-Slot-Memory recurrent block (diagonal simplification).

Sequence-in / sequence-out, matching the :mod:spyx.ssm interface: __call__(u: (T, B, d_model)) -> (T, B, d_model).

Per step t the block computes, from u_t:

  • a sparse write router r_t = SlotRouter(u_t) \in [0, 1]^{(B, M)},
  • the write U_t = reshape(W_u u_t) \in (B, M, d_slot),

and updates the slot memory with the diagonal RSM recurrence

.. math:: S_t = (1 - r_t) \odot S_{t-1} + r_t \odot (a \odot S_{t-1} + U_t)

where a = sigmoid(raw_decay) \in (0, 1)^{(M, d_slot)} is a static, learnable per-slot / per-dim decay (kept in (0, 1) for stability; an input-dependent / selective decay is a straightforward extension but is not used here so the dense reduction stays a clean gated diagonal SSM). The recurrence is evaluated with :func:jax.lax.scan over time.

Readout (y_t): a query-gated read over slots. A learned query q_t = softmax(W_q u_t) \in (B, M) mixes the slots into a single read vector read_t = \sum_m q_t[m] S_t[m] \in (B, d_slot), which a linear map projects back to (B, d_model). This mirrors the routing idea on the read side: the query key selects which slot(s) to retrieve.
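The read side can be sketched on its own. In the example below the query logits are hand-picked rather than produced by a learned W_q, and the final projection back to d_model is omitted; both are simplifications for illustration.

```python
import jax
import jax.numpy as jnp

B, M, d_slot = 2, 4, 3
S_t = jax.random.normal(jax.random.PRNGKey(5), (B, M, d_slot))

# Hand-picked logits standing in for W_q u_t.
logits = jnp.array([
    [0.0, 0.0, 0.0, 0.0],    # uniform query: reads the average of all slots
    [50.0, 0.0, 0.0, 0.0],   # sharply peaked query: reads (almost) slot 0
])
q_t = jax.nn.softmax(logits, axis=-1)          # (B, M), rows sum to 1
read_t = jnp.einsum("bm,bmd->bd", q_t, S_t)    # (B, d_slot)
```

A flat query returns the mean over slots, while a peaked query retrieves essentially a single slot, which is the read-side analogue of sparse routing.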

Simplifications (deferred, per the module docstring): (1) the full matrix-sandwich transition D_t S_{t-1} A_t is replaced by the diagonal decay a; (2) only a sequential lax.scan is provided — a chunked / associative-scan form is possible but deferred.

Source code in spyx/experimental/raven.py
class RavenRSM(nnx.Module):
    r"""Routing-Slot-Memory recurrent block (diagonal simplification).

    Sequence-in / sequence-out, matching the :mod:`spyx.ssm` interface:
    ``__call__(u: (T, B, d_model)) -> (T, B, d_model)``.

    Per step ``t`` the block computes, from ``u_t``:

    * a sparse write router ``r_t = SlotRouter(u_t) \in [0, 1]^{(B, M)}``,
    * the write ``U_t = reshape(W_u u_t) \in (B, M, d_slot)``,

    and updates the slot memory with the diagonal RSM recurrence

    .. math::
        S_t = (1 - r_t) \odot S_{t-1} + r_t \odot (a \odot S_{t-1} + U_t)

    where ``a = sigmoid(raw_decay) \in (0, 1)^{(M, d_slot)}`` is a **static,
    learnable per-slot / per-dim decay** (kept in ``(0, 1)`` for stability; an
    input-dependent / selective decay is a straightforward extension but is not
    used here so the dense reduction stays a clean gated diagonal SSM). The
    recurrence is evaluated with :func:`jax.lax.scan` over time.

    **Readout** (``y_t``): a query-gated read over slots. A learned query
    ``q_t = softmax(W_q u_t) \in (B, M)`` mixes the slots into a single read
    vector ``read_t = \sum_m q_t[m] S_t[m] \in (B, d_slot)``, which a linear map
    projects back to ``(B, d_model)``. This mirrors the routing idea on the read
    side: the query key selects which slot(s) to retrieve.

    Simplifications (deferred, per the module docstring): (1) the full
    matrix-sandwich transition ``D_t S_{t-1} A_t`` is replaced by the diagonal
    decay ``a``; (2) only a sequential ``lax.scan`` is provided — a chunked /
    associative-scan form is possible but deferred.
    """

    def __init__(
        self,
        d_model: int,
        n_slots: int = 8,
        d_slot: int | None = None,
        *,
        hard_top_k: int | None = None,
        decay_init: float = 0.9,
        rngs: nnx.Rngs,
    ):
        if d_slot is None:
            d_slot = d_model
        if n_slots < 1:
            raise ValueError(f"n_slots must be >= 1; got {n_slots}.")
        if d_slot < 1:
            raise ValueError(f"d_slot must be >= 1; got {d_slot}.")
        if not 0.0 < decay_init < 1.0:
            raise ValueError(f"decay_init must be in (0, 1); got {decay_init}.")

        self.d_model = d_model
        self.n_slots = n_slots
        self.d_slot = d_slot

        self.router = SlotRouter(d_model, n_slots, hard_top_k=hard_top_k, rngs=rngs)
        # Write projection: u_t -> (M * d_slot), reshaped to (M, d_slot).
        self.write = nnx.Linear(d_model, n_slots * d_slot, rngs=rngs)
        # Read side: query over slots + projection back to d_model.
        self.readout_query = nnx.Linear(d_model, n_slots, rngs=rngs)
        self.out_proj = nnx.Linear(d_slot, d_model, rngs=rngs)

        # Static learnable per-slot / per-dim decay, stored as a raw logit so
        # that a = sigmoid(raw_decay) stays in (0, 1). Init near ``decay_init``
        # (slow decay -> long memory) with a little jitter.
        logit = float(jnp.log(decay_init / (1.0 - decay_init)))
        noise = 0.01 * jax.random.normal(rngs.params(), (n_slots, d_slot))
        self.raw_decay = nnx.Param(jnp.full((n_slots, d_slot), logit) + noise)

    @property
    def decay(self) -> jax.Array:
        """Effective per-slot / per-dim decay ``a = sigmoid(raw_decay)`` in ``(0, 1)``."""
        return jax.nn.sigmoid(self.raw_decay[...])

    def initial_state(self, batch_size: int) -> jax.Array:
        """Return zero slot memory of shape ``(batch_size, M, d_slot)``."""
        return jnp.zeros((batch_size, self.n_slots, self.d_slot), dtype=jnp.float32)

    def _route(self, u_t: jax.Array) -> jax.Array:
        """Expose the router for reuse: ``u_t (..., d_model) -> r (..., M)``."""
        return self.router(u_t)

    def step(self, state: jax.Array, u_t: jax.Array) -> tuple[jax.Array, jax.Array]:
        """One reset-free RSM timestep.

        :state: slot memory ``S_{t-1}``, shape ``(B, M, d_slot)``.
        :u_t: input ``(B, d_model)``.
        :return: ``(S_t, y_t)`` with ``y_t`` of shape ``(B, d_model)``.
        """
        r_t = self.router(u_t)  # (B, M)
        U_t = self.write(u_t).reshape(u_t.shape[0], self.n_slots, self.d_slot)
        a = self.decay[None]  # (1, M, d_slot)
        gated = a * state + U_t
        r_exp = r_t[..., None]  # (B, M, 1)
        s_new = (1.0 - r_exp) * state + r_exp * gated
        attn = jax.nn.softmax(self.readout_query(u_t), axis=-1)  # (B, M)
        read = jnp.einsum("bm,bmd->bd", attn, s_new)  # (B, d_slot)
        y_t = self.out_proj(read)
        return s_new, y_t

    def _run(self, u: jax.Array, r: jax.Array) -> jax.Array:
        """Core recurrence with a *precomputed* router ``r`` of shape ``(T, B, M)``.

        Factored out so tests (and the dense-router reduction) can force ``r``.
        """
        T, B, _ = u.shape
        U = self.write(u).reshape(T, B, self.n_slots, self.d_slot)
        attn = jax.nn.softmax(self.readout_query(u), axis=-1)  # (T, B, M)
        a = self.decay[None]  # (1, M, d_slot)

        def scan_step(state, inp):
            r_t, U_t, attn_t = inp
            r_exp = r_t[..., None]
            gated = a * state + U_t
            s_new = (1.0 - r_exp) * state + r_exp * gated
            read = jnp.einsum("bm,bmd->bd", attn_t, s_new)
            return s_new, read

        s0 = self.initial_state(B)
        _, read_seq = jax.lax.scan(scan_step, s0, (r, U, attn))  # (T, B, d_slot)
        return self.out_proj(read_seq)

    def __call__(self, u: jax.Array) -> jax.Array:
        """Apply the RSM block to a time-major input.

        :u: real array of shape ``(T, B, d_model)``.
        :return: real array of shape ``(T, B, d_model)``.
        """
        if u.ndim != 3 or u.shape[-1] != self.d_model:
            raise ValueError(
                f"RavenRSM expects [T, B, d_model={self.d_model}]; got {u.shape}."
            )
        r = self.router(u)  # (T, B, M)
        return self._run(u, r)
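
The routed write in `step` and `_run` can be checked in isolation. The sketch below is a standalone illustration, not library code: `routed_update` is a hypothetical helper that copies the update expression from the source above, and the shapes and values are made up for the demonstration. It shows a slot with gate 0 being left untouched, a slot with gate 1 receiving the plain decayed write, and a gate of 0.5 blending the two.

```python
import jax.numpy as jnp


def routed_update(state, r, a, U):
    """S_t = (1 - r) * S_{t-1} + r * (a * S_{t-1} + U_t), gate broadcast over d_slot."""
    r_exp = r[..., None]  # (B, M, 1)
    return (1.0 - r_exp) * state + r_exp * (a[None] * state + U)


B, M, D = 1, 3, 2
state = jnp.arange(B * M * D, dtype=jnp.float32).reshape(B, M, D)  # [[0,1],[2,3],[4,5]]
U = jnp.ones((B, M, D))            # write of 1.0 into every slot dimension
a = jnp.full((M, D), 0.9)          # per-slot / per-dim decay
r = jnp.array([[0.0, 1.0, 0.5]])   # slot 0 shielded, slot 1 written, slot 2 blended

new_state = routed_update(state, r, a, U)
```

With these inputs slot 0 keeps `[0, 1]`, slot 1 becomes `0.9 * [2, 3] + 1`, and slot 2 lands halfway between its old value and its gated value.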

decay property

Effective per-slot / per-dim decay a = sigmoid(raw_decay) in (0, 1).
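
As a small illustration of why the decay is stored as a logit, the sketch below (standalone JAX; `to_logit` is a hypothetical helper name) round-trips an initial value through the same `log(p / (1 - p))` and `sigmoid` pair used in `__init__`, and shows that unconstrained raw values map back into the open unit interval.

```python
import jax
import jax.numpy as jnp


def to_logit(p):
    """Inverse of sigmoid for p in (0, 1)."""
    return jnp.log(p / (1.0 - p))


raw = to_logit(jnp.float32(0.9))        # what raw_decay is initialised around
decay = jax.nn.sigmoid(raw)             # recovers roughly 0.9
extreme = jax.nn.sigmoid(jnp.array([-5.0, 0.0, 5.0]))  # unconstrained raw values
```

An optimizer can therefore update the raw parameter freely without a projection step to keep the effective decay valid.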

__call__(u)

Apply the RSM block to a time-major input.

:u: real array of shape (T, B, d_model).
:return: real array of shape (T, B, d_model).

Source code in spyx/experimental/raven.py
def __call__(self, u: jax.Array) -> jax.Array:
    """Apply the RSM block to a time-major input.

    :u: real array of shape ``(T, B, d_model)``.
    :return: real array of shape ``(T, B, d_model)``.
    """
    if u.ndim != 3 or u.shape[-1] != self.d_model:
        raise ValueError(
            f"RavenRSM expects [T, B, d_model={self.d_model}]; got {u.shape}."
        )
    r = self.router(u)  # (T, B, M)
    return self._run(u, r)

initial_state(batch_size)

Return zero slot memory of shape (batch_size, M, d_slot).

Source code in spyx/experimental/raven.py
def initial_state(self, batch_size: int) -> jax.Array:
    """Return zero slot memory of shape ``(batch_size, M, d_slot)``."""
    return jnp.zeros((batch_size, self.n_slots, self.d_slot), dtype=jnp.float32)

step(state, u_t)

One reset-free RSM timestep.

:state: slot memory S_{t-1}, shape (B, M, d_slot).
:u_t: input (B, d_model).
:return: (S_t, y_t) with y_t of shape (B, d_model).

Source code in spyx/experimental/raven.py
def step(self, state: jax.Array, u_t: jax.Array) -> tuple[jax.Array, jax.Array]:
    """One reset-free RSM timestep.

    :state: slot memory ``S_{t-1}``, shape ``(B, M, d_slot)``.
    :u_t: input ``(B, d_model)``.
    :return: ``(S_t, y_t)`` with ``y_t`` of shape ``(B, d_model)``.
    """
    r_t = self.router(u_t)  # (B, M)
    U_t = self.write(u_t).reshape(u_t.shape[0], self.n_slots, self.d_slot)
    a = self.decay[None]  # (1, M, d_slot)
    gated = a * state + U_t
    r_exp = r_t[..., None]  # (B, M, 1)
    s_new = (1.0 - r_exp) * state + r_exp * gated
    attn = jax.nn.softmax(self.readout_query(u_t), axis=-1)  # (B, M)
    read = jnp.einsum("bm,bmd->bd", attn, s_new)  # (B, d_slot)
    y_t = self.out_proj(read)
    return s_new, y_t
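
The read side of `step` is a softmax-weighted sum over slots. The following standalone sketch (the `read_slots` helper and the toy slot contents are assumptions for illustration) reproduces the `softmax` plus `einsum` pair from the source and contrasts a sharply peaked query, which retrieves essentially one slot, with a flat query, which averages all of them.

```python
import jax
import jax.numpy as jnp


def read_slots(query_logits, slots):
    """Softmax over the slot axis, then a weighted sum of slot contents."""
    attn = jax.nn.softmax(query_logits, axis=-1)        # (B, M)
    return attn, jnp.einsum("bm,bmd->bd", attn, slots)  # (B, d_slot)


slots = jnp.array([[[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]])  # (B=1, M=3, d_slot=2)
sharp = jnp.array([[0.0, 20.0, 0.0]])   # strongly prefers slot 1
flat = jnp.zeros((1, 3))                # no preference

attn_sharp, read_sharp = read_slots(sharp, slots)
attn_flat, read_flat = read_slots(flat, slots)
```

In the library the query logits come from `readout_query(u_t)` and the read vector is then passed through `out_proj`; both projections are left out here.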

SlotRouter

Bases: Module

Learned per-slot write gate r_t = sigmoid(W_r u_t).

A small, reusable submodule (the spiking Raven variant reuses it). Maps an input of shape (..., d_model) to per-slot gates of shape (..., M) in [0, 1]. With hard_top_k set, the gate is additionally sparsified to the k most-active slots per row via a straight-through top-k (forward is sparse, gradients stay dense); the default (None) is a soft gate.

Design choice: a per-input sigmoid (independent per-slot Bernoulli logits) is used rather than a softmax so that several slots can be written at once (a multi-write MoE-for-memory), and so the dense all-ones reduction is reachable in the limit of large positive logits.

Source code in spyx/experimental/raven.py
class SlotRouter(nnx.Module):
    """Learned per-slot write gate ``r_t = sigmoid(W_r u_t)``.

    A small, reusable submodule (the spiking Raven variant reuses it). Maps an
    input of shape ``(..., d_model)`` to per-slot gates of shape ``(..., M)`` in
    ``[0, 1]``. With ``hard_top_k`` set, the gate is additionally sparsified to
    the ``k`` most-active slots per row via a straight-through top-``k`` (forward
    is sparse, gradients stay dense); the default (``None``) is a soft gate.

    Design choice: a per-input ``sigmoid`` (independent per-slot Bernoulli
    logits) is used rather than a ``softmax`` so that *several* slots can be
    written at once (a multi-write MoE-for-memory), and so the dense all-ones
    reduction is reachable in the limit of large positive logits.
    """

    def __init__(
        self,
        d_model: int,
        n_slots: int,
        *,
        hard_top_k: int | None = None,
        rngs: nnx.Rngs,
    ):
        if hard_top_k is not None and hard_top_k < 1:
            raise ValueError(f"hard_top_k must be >= 1 or None; got {hard_top_k}.")
        self.proj = nnx.Linear(d_model, n_slots, rngs=rngs)
        self.n_slots = n_slots
        self.hard_top_k = hard_top_k

    def __call__(self, u: jax.Array) -> jax.Array:
        """u: ``(..., d_model)`` -> gates ``(..., M)`` in ``[0, 1]``."""
        r = jax.nn.sigmoid(self.proj(u))
        if self.hard_top_k is not None:
            r = _straight_through_topk(r, self.hard_top_k)
        return r

__call__(u)

u: (..., d_model) -> gates (..., M) in [0, 1].

Source code in spyx/experimental/raven.py
def __call__(self, u: jax.Array) -> jax.Array:
    """u: ``(..., d_model)`` -> gates ``(..., M)`` in ``[0, 1]``."""
    r = jax.nn.sigmoid(self.proj(u))
    if self.hard_top_k is not None:
        r = _straight_through_topk(r, self.hard_top_k)
    return r
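
The private helper `_straight_through_topk` is not shown on this page, so the sketch below is only one plausible reading of "forward is sparse, gradients stay dense". It assumes the surviving gates keep their sigmoid values (rather than being binarized) and uses the common `stop_gradient` trick; `straight_through_topk` here is a hypothetical stand-in, not the library function.

```python
import jax
import jax.numpy as jnp


def straight_through_topk(r, k):
    """Keep the k largest gates per row in the forward pass; backward acts as identity."""
    kth = jax.lax.top_k(r, k)[0][..., -1:]      # k-th largest gate per row
    hard = jnp.where(r >= kth, r, 0.0)          # sparse forward value (ties may keep more than k)
    return r + jax.lax.stop_gradient(hard - r)  # value: hard, gradient: d/dr of r


logits = jnp.array([[2.0, -1.0, 0.5, 1.0]])
gates = jax.nn.sigmoid(logits)                  # independent per-slot gates
sparse = straight_through_topk(gates, 2)
grad = jax.grad(lambda r: straight_through_topk(r, 2).sum())(gates)
```

Here slots 0 and 3 survive the forward pass, while the gradient with respect to every gate, including the masked ones, is 1. Because the gates come from independent sigmoids rather than a softmax, several of them can be close to 1 at once.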

SpikingSlotMemory

Bases: Module

Spiking Routing-Slot Memory: a slot memory whose slots are spiking units.

This is the spiking sibling of :class:RavenRSM. It keeps the two ideas that make Raven a high-recall memory -- a bank of M independent slots and the same sparse write router -- but replaces each slot's linear accumulator with the reset-free spiking membrane of :class:spyx.nn.PSU_LIF: a leaky integrator V \leftarrow \beta V + x that emits a surrogate spike s = \sigma(V - \text{threshold}). The result is dual sparsity -- sparse in time (spikes) and sparse in slots (routing).

The slot membrane V_t has shape (B, M, d_slot). Per step t, from the input u_t:

  • the write router r_t = SlotRouter(u_t) \in [0, 1]^{(B, M)} (the exact router type reused from :class:RavenRSM -- self.router is a :class:SlotRouter, not a fork), and
  • the write U_t = reshape(W_u u_t) \in (B, M, d_slot).

The membrane is then advanced with the routed, reset-free spiking recurrence

.. math::
    V_t = (1 - r_t) \odot V_{t-1} + r_t \odot (\beta \odot V_{t-1} + U_t),
    \qquad s_t = \sigma(V_t - \text{threshold}),

where \beta = sigmoid(raw_beta) \in (0, 1)^{(M, d_slot)} is a static, learnable per-slot / per-dim leak. Shielding: where r_t[m] = 0 the update collapses to V_t[m] = V_{t-1}[m] -- the slot's membrane (and hence its spike) is passed through byte-for-byte unchanged, shielded from interference exactly as in :class:RavenRSM. Where r_t[m] = 1 the slot runs a plain :class:spyx.nn.PSU_LIF step V \leftarrow \beta V + U_t.

Output is the raw slot spike train of shape (T, B, M, d_slot) (no dense readout projection -- the block is a spiking memory; compose a linear head downstream if real-valued outputs are needed).

Reset-freeness is deliberate: the membrane recurrence stays a first-order linear map per slot, so -- exactly as documented for :class:spyx.nn.PSU_LIF -- a chunked / :func:jax.lax.associative_scan parallel form is possible. Because the per-step transition here is input-dependent through the router gate (1 - r_t), the associative element is the affine map V \mapsto A_t V + b_t with A_t = (1 - r_t) + r_t \beta and b_t = r_t U_t; only the sequential :func:jax.lax.scan reference is implemented here (an honest baseline), matching :class:RavenRSM.
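
The affine-map formulation described above can be sketched with `jax.lax.associative_scan`. This is not part of the library (only the sequential scan is implemented there); `combine`, `membrane_parallel` and `membrane_sequential` are hypothetical names, and the random inputs are for illustration only. The sketch composes the per-step maps `V -> A_t V + b_t` in parallel and compares the result with a sequential reference, assuming a zero initial membrane.

```python
import jax
import jax.numpy as jnp


def combine(earlier, later):
    """Compose two affine maps: apply `earlier` first, then `later`."""
    A1, b1 = earlier
    A2, b2 = later
    return A2 * A1, A2 * b1 + b2


def membrane_parallel(r, U, beta):
    g = r[..., None]                     # (T, B, M, 1)
    A = (1.0 - g) + g * beta             # A_t = (1 - r_t) + r_t * beta
    b = g * U                            # b_t = r_t * U_t
    _, V = jax.lax.associative_scan(combine, (A, b), axis=0)
    return V                             # with V_0 = 0 the accumulated offset is V_t


def membrane_sequential(r, U, beta):
    def body(V, inp):
        r_t, U_t = inp
        g = r_t[..., None]
        V = (1.0 - g) * V + g * (beta * V + U_t)
        return V, V

    V0 = jnp.zeros(U.shape[1:])
    return jax.lax.scan(body, V0, (r, U))[1]


T, B, M, D = 12, 2, 3, 4
k_r, k_u = jax.random.split(jax.random.PRNGKey(0))
r = jax.random.uniform(k_r, (T, B, M))    # soft router gates in [0, 1)
U = jax.random.normal(k_u, (T, B, M, D))  # per-slot writes
beta = jnp.full((M, D), 0.9)

V_par = membrane_parallel(r, U, beta)
V_seq = membrane_sequential(r, U, beta)
```

The two membranes agree up to floating-point error; spikes would then be obtained by applying the surrogate to `V_par - threshold` in one vectorized call.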

Reductions (exercised by the tests): a dense router (r_t all-ones) turns every slot into an independent, always-written :class:spyx.nn.PSU_LIF -- i.e. a plain bank of spiking leaky integrators driven by U_t; the routing is what makes it a memory.

Source code in spyx/experimental/raven.py
class SpikingSlotMemory(nnx.Module):
    r"""Spiking Routing-Slot Memory: a slot memory whose slots are *spiking* units.

    This is the spiking sibling of :class:`RavenRSM`. It keeps the two ideas that
    make Raven a high-recall memory -- a bank of ``M`` independent **slots** and
    the *same* sparse write **router** -- but replaces each slot's linear
    accumulator with the **reset-free spiking membrane** of
    :class:`spyx.nn.PSU_LIF`: a leaky integrator ``V \leftarrow \beta V + x`` that
    emits a surrogate spike ``s = \sigma(V - \text{threshold})``. The result is
    **dual sparsity** -- sparse in *time* (spikes) *and* sparse in *slots*
    (routing).

    The slot membrane ``V_t`` has shape ``(B, M, d_slot)``. Per step ``t``, from
    the input ``u_t``:

    * the write router ``r_t = SlotRouter(u_t) \in [0, 1]^{(B, M)}`` (the **exact**
      router type reused from :class:`RavenRSM` -- ``self.router`` is a
      :class:`SlotRouter`, not a fork), and
    * the write ``U_t = reshape(W_u u_t) \in (B, M, d_slot)``.

    The membrane is then advanced with the routed, reset-free spiking recurrence

    .. math::
        V_t = (1 - r_t) \odot V_{t-1} + r_t \odot (\beta \odot V_{t-1} + U_t),
        \qquad s_t = \sigma(V_t - \text{threshold}),

    where ``\beta = sigmoid(raw_beta) \in (0, 1)^{(M, d_slot)}`` is a static,
    learnable per-slot / per-dim leak. **Shielding:** where ``r_t[m] = 0`` the
    update collapses to ``V_t[m] = V_{t-1}[m]`` -- the slot's membrane (and hence
    its spike) is passed through byte-for-byte unchanged, shielded from
    interference exactly as in :class:`RavenRSM`. Where ``r_t[m] = 1`` the slot
    runs a plain :class:`spyx.nn.PSU_LIF` step ``V \leftarrow \beta V + U_t``.

    **Output** is the raw slot spike train of shape ``(T, B, M, d_slot)`` (no
    dense readout projection -- the block *is* a spiking memory; compose a linear
    head downstream if real-valued outputs are needed).

    Reset-freeness is deliberate: the membrane recurrence stays a first-order
    linear map per slot, so -- exactly as documented for :class:`spyx.nn.PSU_LIF`
    -- a chunked / :func:`jax.lax.associative_scan` parallel form is *possible*.
    Because the per-step transition here is *input-dependent* through the router
    gate ``(1 - r_t)``, the associative element is the affine map
    ``V \mapsto A_t V + b_t`` with ``A_t = (1 - r_t) + r_t \beta`` and
    ``b_t = r_t U_t``; only the sequential :func:`jax.lax.scan` reference is
    implemented here (an honest baseline), matching :class:`RavenRSM`.

    Reductions (exercised by the tests): a **dense** router (``r_t`` all-ones)
    turns every slot into an independent, always-written
    :class:`spyx.nn.PSU_LIF` -- i.e. a plain bank of spiking leaky integrators
    driven by ``U_t``; the routing is what makes it a *memory*.
    """

    def __init__(
        self,
        d_model: int,
        n_slots: int = 8,
        d_slot: int | None = None,
        *,
        hard_top_k: int | None = None,
        beta_init: float = 0.9,
        threshold: float = 1.0,
        activation=None,
        rngs: nnx.Rngs,
    ):
        """
        :d_model: Input feature width.
        :n_slots: Number of independent memory slots ``M``.
        :d_slot: Per-slot membrane width (defaults to ``d_model``).
        :hard_top_k: If set, the router keeps only its ``k`` most-active slots per
            step (straight-through top-``k``); the default is a soft gate.
        :beta_init: Initial per-slot leak in ``(0, 1)`` (stored as a logit).
        :threshold: Firing threshold on the membrane.
        :activation: :class:`spyx.axn.Axon` surrogate spike; defaults to
            ``superspike`` (matching :class:`spyx.nn.PSU_LIF`).
        :rngs: NNX PRNG collection.
        """
        if d_slot is None:
            d_slot = d_model
        if n_slots < 1:
            raise ValueError(f"n_slots must be >= 1; got {n_slots}.")
        if d_slot < 1:
            raise ValueError(f"d_slot must be >= 1; got {d_slot}.")
        if not 0.0 < beta_init < 1.0:
            raise ValueError(f"beta_init must be in (0, 1); got {beta_init}.")

        self.d_model = d_model
        self.n_slots = n_slots
        self.d_slot = d_slot
        self.threshold = threshold
        self.spike = activation if activation is not None else _DEFAULT_SPIKE

        # Reuse the *exact* router mechanism from RavenRSM (same SlotRouter class).
        self.router = SlotRouter(d_model, n_slots, hard_top_k=hard_top_k, rngs=rngs)
        # Write projection: u_t -> (M * d_slot), reshaped to (M, d_slot).
        self.write = nnx.Linear(d_model, n_slots * d_slot, rngs=rngs)

        # Static learnable per-slot / per-dim leak, stored as a raw logit so that
        # beta = sigmoid(raw_beta) stays in (0, 1). Init near ``beta_init`` (slow
        # leak -> long membrane memory) with a little jitter.
        logit = float(jnp.log(beta_init / (1.0 - beta_init)))
        noise = 0.01 * jax.random.normal(rngs.params(), (n_slots, d_slot))
        self.raw_beta = nnx.Param(jnp.full((n_slots, d_slot), logit) + noise)

    @property
    def beta(self) -> jax.Array:
        """Effective per-slot / per-dim leak ``beta = sigmoid(raw_beta)`` in ``(0, 1)``."""
        return jax.nn.sigmoid(self.raw_beta[...])

    def initial_state(self, batch_size: int) -> jax.Array:
        """Return zero slot membrane of shape ``(batch_size, M, d_slot)``."""
        return jnp.zeros((batch_size, self.n_slots, self.d_slot), dtype=jnp.float32)

    def _route(self, u_t: jax.Array) -> jax.Array:
        """Expose the reused router: ``u_t (..., d_model) -> r (..., M)``."""
        return self.router(u_t)

    def step(self, state: jax.Array, u_t: jax.Array) -> tuple[jax.Array, jax.Array]:
        """One reset-free spiking-slot timestep.

        :state: slot membrane ``V_{t-1}``, shape ``(B, M, d_slot)``.
        :u_t: input ``(B, d_model)``.
        :return: ``(V_t, s_t)`` -- the new membrane and the slot spikes of shape
            ``(B, M, d_slot)``.
        """
        r_t = self.router(u_t)  # (B, M)
        U_t = self.write(u_t).reshape(u_t.shape[0], self.n_slots, self.d_slot)
        beta = self.beta[None]  # (1, M, d_slot)
        gated = beta * state + U_t
        r_exp = r_t[..., None]  # (B, M, 1)
        v_new = (1.0 - r_exp) * state + r_exp * gated
        spikes = self.spike(v_new - self.threshold)
        return v_new, spikes

    def _run(self, u: jax.Array, r: jax.Array) -> jax.Array:
        """Core recurrence with a *precomputed* router ``r`` of shape ``(T, B, M)``.

        Factored out so tests (and the dense-router reduction) can force ``r``.
        """
        T, B, _ = u.shape
        U = self.write(u).reshape(T, B, self.n_slots, self.d_slot)
        beta = self.beta[None]  # (1, M, d_slot)

        def scan_step(state, inp):
            r_t, U_t = inp
            r_exp = r_t[..., None]
            gated = beta * state + U_t
            v_new = (1.0 - r_exp) * state + r_exp * gated
            spikes = self.spike(v_new - self.threshold)
            return v_new, spikes

        v0 = self.initial_state(B)
        _, spikes = jax.lax.scan(scan_step, v0, (r, U))  # (T, B, M, d_slot)
        return spikes

    def __call__(self, u: jax.Array) -> jax.Array:
        """Apply the spiking slot memory to a time-major input.

        :u: real array of shape ``(T, B, d_model)``.
        :return: spike train of shape ``(T, B, M, d_slot)``.
        """
        if u.ndim != 3 or u.shape[-1] != self.d_model:
            raise ValueError(
                f"SpikingSlotMemory expects [T, B, d_model={self.d_model}]; "
                f"got {u.shape}."
            )
        r = self.router(u)  # (T, B, M)
        return self._run(u, r)
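
The `self.spike` call above applies a surrogate-gradient spike function. The exact form used by `spyx.axn` is not shown on this page, so the sketch below is an assumption: a Heaviside step in the forward pass with a fast-sigmoid-style derivative in the backward pass, in the spirit of SuperSpike. The name `surrogate_spike` and the sharpness constant 10.0 are hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def surrogate_spike(x):
    """Forward pass: Heaviside step on the threshold-shifted membrane."""
    return (x > 0).astype(x.dtype)


@surrogate_spike.defjvp
def _surrogate_spike_jvp(primals, tangents):
    (x,), (dx,) = primals, tangents
    slope = 1.0 / (1.0 + 10.0 * jnp.abs(x)) ** 2  # smooth stand-in for the step's derivative
    return surrogate_spike(x), slope * dx


threshold = 1.0
V = jnp.array([0.2, 0.99, 1.5])                   # membrane values
spikes = surrogate_spike(V - threshold)
dV = jax.grad(lambda v: surrogate_spike(v - threshold).sum())(V)
```

Only the last membrane value crosses the threshold, yet all three receive a nonzero gradient, largest for the value closest to the threshold.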

beta property

Effective per-slot / per-dim leak beta = sigmoid(raw_beta) in (0, 1).

__call__(u)

Apply the spiking slot memory to a time-major input.

:u: real array of shape (T, B, d_model).
:return: spike train of shape (T, B, M, d_slot).

Source code in spyx/experimental/raven.py
def __call__(self, u: jax.Array) -> jax.Array:
    """Apply the spiking slot memory to a time-major input.

    :u: real array of shape ``(T, B, d_model)``.
    :return: spike train of shape ``(T, B, M, d_slot)``.
    """
    if u.ndim != 3 or u.shape[-1] != self.d_model:
        raise ValueError(
            f"SpikingSlotMemory expects [T, B, d_model={self.d_model}]; "
            f"got {u.shape}."
        )
    r = self.router(u)  # (T, B, M)
    return self._run(u, r)

__init__(d_model, n_slots=8, d_slot=None, *, hard_top_k=None, beta_init=0.9, threshold=1.0, activation=None, rngs)

:d_model: Input feature width.
:n_slots: Number of independent memory slots M.
:d_slot: Per-slot membrane width (defaults to d_model).
:hard_top_k: If set, the router keeps only its k most-active slots per step (straight-through top-k); the default is a soft gate.
:beta_init: Initial per-slot leak in (0, 1) (stored as a logit).
:threshold: Firing threshold on the membrane.
:activation: :class:spyx.axn.Axon surrogate spike; defaults to superspike (matching :class:spyx.nn.PSU_LIF).
:rngs: NNX PRNG collection.

Source code in spyx/experimental/raven.py
def __init__(
    self,
    d_model: int,
    n_slots: int = 8,
    d_slot: int | None = None,
    *,
    hard_top_k: int | None = None,
    beta_init: float = 0.9,
    threshold: float = 1.0,
    activation=None,
    rngs: nnx.Rngs,
):
    """
    :d_model: Input feature width.
    :n_slots: Number of independent memory slots ``M``.
    :d_slot: Per-slot membrane width (defaults to ``d_model``).
    :hard_top_k: If set, the router keeps only its ``k`` most-active slots per
        step (straight-through top-``k``); the default is a soft gate.
    :beta_init: Initial per-slot leak in ``(0, 1)`` (stored as a logit).
    :threshold: Firing threshold on the membrane.
    :activation: :class:`spyx.axn.Axon` surrogate spike; defaults to
        ``superspike`` (matching :class:`spyx.nn.PSU_LIF`).
    :rngs: NNX PRNG collection.
    """
    if d_slot is None:
        d_slot = d_model
    if n_slots < 1:
        raise ValueError(f"n_slots must be >= 1; got {n_slots}.")
    if d_slot < 1:
        raise ValueError(f"d_slot must be >= 1; got {d_slot}.")
    if not 0.0 < beta_init < 1.0:
        raise ValueError(f"beta_init must be in (0, 1); got {beta_init}.")

    self.d_model = d_model
    self.n_slots = n_slots
    self.d_slot = d_slot
    self.threshold = threshold
    self.spike = activation if activation is not None else _DEFAULT_SPIKE

    # Reuse the *exact* router mechanism from RavenRSM (same SlotRouter class).
    self.router = SlotRouter(d_model, n_slots, hard_top_k=hard_top_k, rngs=rngs)
    # Write projection: u_t -> (M * d_slot), reshaped to (M, d_slot).
    self.write = nnx.Linear(d_model, n_slots * d_slot, rngs=rngs)

    # Static learnable per-slot / per-dim leak, stored as a raw logit so that
    # beta = sigmoid(raw_beta) stays in (0, 1). Init near ``beta_init`` (slow
    # leak -> long membrane memory) with a little jitter.
    logit = float(jnp.log(beta_init / (1.0 - beta_init)))
    noise = 0.01 * jax.random.normal(rngs.params(), (n_slots, d_slot))
    self.raw_beta = nnx.Param(jnp.full((n_slots, d_slot), logit) + noise)

initial_state(batch_size)

Return zero slot membrane of shape (batch_size, M, d_slot).

Source code in spyx/experimental/raven.py
def initial_state(self, batch_size: int) -> jax.Array:
    """Return zero slot membrane of shape ``(batch_size, M, d_slot)``."""
    return jnp.zeros((batch_size, self.n_slots, self.d_slot), dtype=jnp.float32)

step(state, u_t)

One reset-free spiking-slot timestep.

:state: slot membrane V_{t-1}, shape (B, M, d_slot).
:u_t: input (B, d_model).
:return: (V_t, s_t) -- the new membrane and the slot spikes of shape (B, M, d_slot).

Source code in spyx/experimental/raven.py
def step(self, state: jax.Array, u_t: jax.Array) -> tuple[jax.Array, jax.Array]:
    """One reset-free spiking-slot timestep.

    :state: slot membrane ``V_{t-1}``, shape ``(B, M, d_slot)``.
    :u_t: input ``(B, d_model)``.
    :return: ``(V_t, s_t)`` -- the new membrane and the slot spikes of shape
        ``(B, M, d_slot)``.
    """
    r_t = self.router(u_t)  # (B, M)
    U_t = self.write(u_t).reshape(u_t.shape[0], self.n_slots, self.d_slot)
    beta = self.beta[None]  # (1, M, d_slot)
    gated = beta * state + U_t
    r_exp = r_t[..., None]  # (B, M, 1)
    v_new = (1.0 - r_exp) * state + r_exp * gated
    spikes = self.spike(v_new - self.threshold)
    return v_new, spikes

make_recall_batch(key, *, batch=8, n_pairs=3, n_keys=8, n_values=8)

Generate a multi-query associative-recall (MQAR-style) batch.

Each example is a sequence of n_pairs (key, value) bindings followed by a single query token equal to one of the presented keys. The target is the value bound to the queried key — a task compressed-state SSMs fail at but slot-routed memories solve, because each binding can live in its own (interference-free) slot.

Tokens are one-hot encoded into d_model = n_keys + n_values dims: key i -> e_i; value j -> e_{n_keys + j}. The query token reuses its key's encoding. Sequence length is T = 2 * n_pairs + 1.

:key: PRNG key.
:batch: number of independent examples.
:n_pairs: key/value bindings per example (distinct keys, sampled without replacement).
:n_keys: key vocabulary size (must be >= n_pairs).
:n_values: value vocabulary size.
:return: (u, target) where u is (T, B, d_model) float one-hots and target is (B,) int32 value ids for the query.

Source code in spyx/experimental/raven.py
def make_recall_batch(
    key: jax.Array,
    *,
    batch: int = 8,
    n_pairs: int = 3,
    n_keys: int = 8,
    n_values: int = 8,
) -> tuple[jax.Array, jax.Array]:
    """Generate a multi-query associative-recall (MQAR-style) batch.

    Each example is a sequence of ``n_pairs`` ``(key, value)`` bindings followed
    by a single **query** token equal to one of the presented keys. The target
    is the value bound to the queried key — a task compressed-state SSMs fail at
    but slot-routed memories solve, because each binding can live in its own
    (interference-free) slot.

    Tokens are one-hot encoded into ``d_model = n_keys + n_values`` dims: key
    ``i`` -> ``e_i``; value ``j`` -> ``e_{n_keys + j}``. The query token reuses
    its key's encoding. Sequence length is ``T = 2 * n_pairs + 1``.

    :key: PRNG key.
    :batch: number of independent examples.
    :n_pairs: key/value bindings per example (distinct keys, sampled w/o repl.).
    :n_keys: key vocabulary size (must be ``>= n_pairs``).
    :n_values: value vocabulary size.
    :return: ``(u, target)`` where ``u`` is ``(T, B, d_model)`` float one-hots
        and ``target`` is ``(B,)`` int32 value ids for the query.
    """
    if n_keys < n_pairs:
        raise ValueError(f"n_keys ({n_keys}) must be >= n_pairs ({n_pairs}).")
    d_model = n_keys + n_values
    T = 2 * n_pairs + 1

    keys_out = jnp.zeros((T, batch, d_model), dtype=jnp.float32)
    targets = jnp.zeros((batch,), dtype=jnp.int32)

    for b in range(batch):
        key, k_perm, k_val, k_q = jax.random.split(key, 4)
        # Distinct keys for this example.
        key_ids = jax.random.permutation(k_perm, n_keys)[:n_pairs]
        value_ids = jax.random.randint(k_val, (n_pairs,), 0, n_values)

        for p in range(n_pairs):
            kid = int(key_ids[p])
            vid = int(value_ids[p])
            keys_out = keys_out.at[2 * p, b, kid].set(1.0)
            keys_out = keys_out.at[2 * p + 1, b, n_keys + vid].set(1.0)

        q = int(jax.random.randint(k_q, (), 0, n_pairs))
        qid = int(key_ids[q])
        keys_out = keys_out.at[T - 1, b, qid].set(1.0)
        targets = targets.at[b].set(int(value_ids[q]))

    return keys_out, targets
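
The source above fills the batch with Python loops and `int(...)` conversions, which is fine for small batches but cannot be traced by `jit`. A vectorized variant of the same token layout might look like the sketch below. It is an illustration under assumptions, not a drop-in replacement: `recall_example` and `recall_batch` are hypothetical names, and the random stream differs from `make_recall_batch`, so the two will not produce identical batches for the same key.

```python
import jax
import jax.numpy as jnp


def recall_example(key, n_pairs, n_keys, n_values):
    k_perm, k_val, k_q = jax.random.split(key, 3)
    key_ids = jax.random.permutation(k_perm, n_keys)[:n_pairs]   # distinct keys
    value_ids = jax.random.randint(k_val, (n_pairs,), 0, n_values)
    q = jax.random.randint(k_q, (), 0, n_pairs)                  # which binding to query
    # Interleave key, value, key, value, ... then append the query key.
    pairs = jnp.stack([key_ids, n_keys + value_ids], axis=1).reshape(-1)
    tokens = jnp.concatenate([pairs, key_ids[q][None]])          # (T,)
    u = jax.nn.one_hot(tokens, n_keys + n_values, dtype=jnp.float32)
    return u, value_ids[q].astype(jnp.int32)


def recall_batch(key, batch=8, n_pairs=3, n_keys=8, n_values=8):
    keys = jax.random.split(key, batch)
    u, target = jax.vmap(lambda k: recall_example(k, n_pairs, n_keys, n_values))(keys)
    return jnp.swapaxes(u, 0, 1), target                         # time-major (T, B, d_model)


u, target = recall_batch(jax.random.PRNGKey(0), batch=4)
```

With the defaults this yields `T = 2 * 3 + 1 = 7` timesteps and `d_model = 8 + 8 = 16` features, matching the layout described in the docstring.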

spyx.experimental.compress

Bit-packed activation storage for memory-efficient BPTT.

Training spiking networks with backpropagation-through-time is dominated, memory-wise, by the activations saved for the backward pass. In an SNN the activations feeding each linear layer are the spikes, which are exactly {0, 1} valued. A dense op spikes @ weight normally stashes the full floating-point spikes tensor as its backward residual so it can later form dW = spikes^T @ g. Storing each spike as a float therefore uses 8x-32x more memory than the single bit it carries, depending on the float width (8 to 32 bits).

This module bit-packs that residual with :func:jax.numpy.packbits (8 spikes per uint8) and unpacks it lazily inside the backward pass. The forward output and both gradients (w.r.t. weight and spikes) are numerically identical to the naive spikes @ weight -- we only trade a cheap unpack-recompute for a large cut in the dominant activation residual.

Correctness relies on the input being exactly binary (values in {0, 1}); :func:packed_spike_dense is only valid for spike tensors, not arbitrary floats.

pack_spikes(x, axis=-1)

Bit-pack a binary spike tensor along axis.

Mirrors the np.packbits(..., axis=...) convention used by :mod:spyx.data (which packs along the time axis): every group of 8 consecutive {0, 1} values along axis is packed into a single uint8, big-endian bit order. If the axis length is not a multiple of 8 the final byte is zero-padded on the low bits, so the original length must be supplied to :func:unpack_spikes to recover the exact tensor.

:param x: binary tensor (values in {0, 1}); cast to uint8.
:param axis: axis along which to pack (default last).
:return: uint8 tensor with ceil(len/8) entries along axis.

Source code in spyx/experimental/compress.py
def pack_spikes(x, axis=-1):
    """Bit-pack a binary spike tensor along ``axis``.

    Mirrors the ``np.packbits(..., axis=...)`` convention used by
    :mod:`spyx.data` (which packs along the time axis): every group of 8
    consecutive ``{0, 1}`` values along ``axis`` is packed into a single
    ``uint8``, big-endian bit order. If the axis length is not a multiple of
    8 the final byte is zero-padded on the low bits, so the original length
    must be supplied to :func:`unpack_spikes` to recover the exact tensor.

    :param x: binary tensor (values in ``{0, 1}``); cast to ``uint8``.
    :param axis: axis along which to pack (default last).
    :return: ``uint8`` tensor with ``ceil(len/8)`` entries along ``axis``.
    """
    return jnp.packbits(x.astype(jnp.uint8), axis=axis)
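
Since `pack_spikes` and `unpack_spikes` are thin wrappers over `jnp.packbits` and `jnp.unpackbits` (as their source on this page shows), the round trip can be illustrated with those two calls directly. The 10-element spike vector below is made up for the example; it is deliberately not a multiple of 8 so the zero padding and the role of the original length are visible.

```python
import jax.numpy as jnp

spikes = jnp.array([1, 0, 0, 0, 0, 0, 0, 1, 1, 1], dtype=jnp.float32)  # length 10

# Pack 8 spikes per byte, big-endian: the first spike is the most significant bit.
packed = jnp.packbits(spikes.astype(jnp.uint8), axis=-1)

# Passing the original length trims the padding bits of the final byte.
restored = jnp.unpackbits(packed, axis=-1, count=spikes.shape[-1])

# Without the length, the result is padded up to a whole number of bytes.
padded = jnp.unpackbits(packed, axis=-1)
```

The first byte is `0b10000001` (129) and the second is `0b11000000` (192), where the six low bits of the second byte are padding.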

packed_spike_dense(spikes, weight)

spikes @ weight with a bit-packed backward residual.

Forward numerics are a plain matmul over the trailing feature axis of spikes (shape (..., in)) against weight (shape (in, out)), yielding (..., out). The custom VJP saves packbits(spikes) -- a uint8 tensor holding one bit per spike, i.e. 16x smaller than spikes stored as bf16 and 32x smaller than fp32 -- instead of the dense activations, unpacking it in the backward pass to form dW = spikes^T @ g and dspikes = g @ weight^T.

Both first-order gradients equal those of the naive spikes @ weight.

Limitations: valid only when spikes is exactly binary (values in {0, 1}) -- packing a general float tensor silently binarizes the saved residual, so the forward stays exact but dW becomes wrong. Only the first-order VJP is correct; second-order derivatives (grad-of-grad) are not, since the packed residual is not itself differentiated. Both are fine for ordinary first-order BPTT, the intended use.

Source code in spyx/experimental/compress.py
@jax.custom_vjp
def packed_spike_dense(spikes, weight):
    """``spikes @ weight`` with a bit-packed backward residual.

    Forward numerics are a plain matmul over the trailing feature axis of
    ``spikes`` (shape ``(..., in)``) against ``weight`` (shape ``(in, out)``),
    yielding ``(..., out)``. The custom VJP saves ``packbits(spikes)`` -- a
    ``uint8`` tensor 16x smaller than bf16 ``spikes`` (32x for fp32) -- instead
    of the dense activations, unpacking it in the backward pass to form
    ``dW = spikes^T @ g`` and ``dspikes = g @ weight^T``.

    Both first-order gradients equal those of the naive ``spikes @ weight``.

    Limitations: valid only when ``spikes`` is exactly binary (values in
    ``{0, 1}``) -- packing a general float tensor silently binarizes the saved
    residual, so the forward stays exact but ``dW`` becomes wrong. Only the
    first-order VJP is correct; second-order derivatives (grad-of-grad) are not,
    since the packed residual is not itself differentiated. Both are fine for
    ordinary first-order BPTT, the intended use.
    """
    return _dense(spikes, weight)
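
The forward and backward rules registered for `packed_spike_dense` are not shown on this page. The sketch below is one way such a custom VJP could be written, offered as an assumption rather than the library's implementation; `packed_dense_sketch`, `_fwd` and `_bwd` are hypothetical names. It saves the bit-packed spikes as the residual and checks both gradients against the naive matmul on random binary input.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def packed_dense_sketch(spikes, weight):
    return spikes @ weight


def _fwd(spikes, weight):
    # Residual: one bit per spike instead of one float per spike.
    packed = jnp.packbits(spikes.astype(jnp.uint8), axis=-1)
    return spikes @ weight, (packed, weight)


def _bwd(residual, g):
    packed, weight = residual
    n_in = weight.shape[0]
    spikes = jnp.unpackbits(packed, axis=-1, count=n_in).astype(g.dtype)
    d_spikes = g @ weight.T
    # Collapse leading (time, batch, ...) axes so dW = spikes^T @ g.
    d_weight = spikes.reshape(-1, n_in).T @ g.reshape(-1, g.shape[-1])
    return d_spikes, d_weight


packed_dense_sketch.defvjp(_fwd, _bwd)

key_s, key_w = jax.random.split(jax.random.PRNGKey(0))
spikes = jax.random.bernoulli(key_s, 0.3, (5, 4, 10)).astype(jnp.float32)  # exactly {0, 1}
weight = jax.random.normal(key_w, (10, 3))


def sq_loss(f):
    return lambda s, w: jnp.sum(f(s, w) ** 2)


g_packed = jax.grad(sq_loss(packed_dense_sketch), argnums=(0, 1))(spikes, weight)
g_naive = jax.grad(sq_loss(lambda s, w: s @ w), argnums=(0, 1))(spikes, weight)
```

The agreement holds only because the input is exactly binary; feeding general floats would binarize the residual and change `d_weight`, as the limitations above note.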

unpack_spikes(packed, length, axis=-1)

Invert :func:pack_spikes, recovering length values along axis.

:param packed: uint8 tensor produced by :func:pack_spikes.
:param length: original (pre-pack) size of axis; trims the zero padding introduced when length is not a multiple of 8.
:param axis: axis along which the tensor was packed (default last).
:return: uint8 tensor of {0, 1} values, length long on axis.

Source code in spyx/experimental/compress.py
def unpack_spikes(packed, length, axis=-1):
    """Invert :func:`pack_spikes`, recovering ``length`` values along ``axis``.

    :param packed: ``uint8`` tensor produced by :func:`pack_spikes`.
    :param length: original (pre-pack) size of ``axis``; trims the zero
        padding introduced when ``length`` is not a multiple of 8.
    :param axis: axis along which the tensor was packed (default last).
    :return: ``uint8`` tensor of ``{0, 1}`` values, ``length`` long on ``axis``.
    """
    return jnp.unpackbits(packed, axis=axis, count=length)

spyx.experimental.stochastic

Experimental stochastic / parallelizable spiking-neuron prototypes.

Stochastic (Bernoulli-spiking) neurons and the SPSN prototype, all built on the parallel prefix-scan (_pscan) membrane. Research-stage; the promoted, production reset-free neuron is :class:spyx.experimental.PSU_LIF (in spyx.nn). See SPSN (arXiv:2306.12666).

SPSN

Bases: Module

Prototype implementation of Stochastic Parallelizable Spiking Neuron:

https://doi.org/10.48550/arXiv.2306.12666

Source code in spyx/experimental/stochastic.py
class SPSN(nnx.Module):
    """
    Prototype implementation of Stochastic Parallelizable Spiking Neuron:

    https://doi.org/10.48550/arXiv.2306.12666
    """

    def __init__(self, hidden_shape: tuple, threshold=1, k=10, *, rngs: nnx.Rngs):
        self.hidden_shape = hidden_shape
        self.threshold = threshold
        self.spike = sigmoid_bernoulli(k, threshold)

        self.beta = nnx.Param(
            nnx.initializers.truncated_normal(stddev=0.25)(
                rngs.params(), self.hidden_shape
            )
            + 0.5
        )

    def __call__(self, key, x):
        beta = jnp.clip(self.beta[:], 0, 1)

        # per-neuron, per-timestep decay kernel B[t, c] = beta[c]**t * (1 - beta[c])
        T = x.shape[1]
        B = jnp.power(beta[None, :], jnp.arange(T)[:, None]) * (1 - beta[None, :])

        fft_B = jnp.fft.rfft(B, n=2 * T, axis=0)[None, :, :]
        fft_X = jnp.fft.rfft(x, n=2 * T, axis=1)

        V = jnp.fft.irfft(fft_X * fft_B, n=2 * T, axis=1)[:, :T, :]

        # sample stochastic spikes from the membrane trace (reset-free: V is not modified)
        spikes = self.spike(V, key)

        return spikes, V

beta = nnx.Param(nnx.initializers.truncated_normal(stddev=0.25)(rngs.params(), self.hidden_shape) + 0.5) instance-attribute

hidden_shape = hidden_shape instance-attribute

spike = sigmoid_bernoulli(k, threshold) instance-attribute

threshold = threshold instance-attribute

__call__(key, x)

Source code in spyx/experimental/stochastic.py
def __call__(self, key, x):
    beta = jnp.clip(self.beta[:], 0, 1)

    # per-neuron, per-timestep decay kernel B[t, c] = beta[c]**t * (1 - beta[c])
    T = x.shape[1]
    B = jnp.power(beta[None, :], jnp.arange(T)[:, None]) * (1 - beta[None, :])

    fft_B = jnp.fft.rfft(B, n=2 * T, axis=0)[None, :, :]
    fft_X = jnp.fft.rfft(x, n=2 * T, axis=1)

    V = jnp.fft.irfft(fft_X * fft_B, n=2 * T, axis=1)[:, :T, :]

    # sample stochastic spikes from the membrane trace (reset-free: V is not modified)
    spikes = self.spike(V, key)

    return spikes, V
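The FFT product in `__call__` is a causal convolution of the input with the kernel `B[t, c] = beta[c]**t * (1 - beta[c])`, which is the closed form of the reset-free leaky recurrence `V[t] = beta * V[t-1] + (1 - beta) * x[t]` with `V[-1] = 0`. The sketch below checks that equivalence numerically; `fft_membrane` mirrors the listing above, while `sequential_membrane` and the toy inputs are illustrative additions.

```python
import jax.numpy as jnp


def fft_membrane(beta, x):
    # Same computation as SPSN.__call__, without the spike sampling.
    T = x.shape[1]
    kernel = jnp.power(beta[None, :], jnp.arange(T)[:, None]) * (1 - beta[None, :])
    # Zero-padding to 2*T turns the circular FFT product into a linear convolution.
    fft_k = jnp.fft.rfft(kernel, n=2 * T, axis=0)[None, :, :]
    fft_x = jnp.fft.rfft(x, n=2 * T, axis=1)
    return jnp.fft.irfft(fft_x * fft_k, n=2 * T, axis=1)[:, :T, :]


def sequential_membrane(beta, x):
    # Reference: step-by-step leaky integration with no reset.
    v = jnp.zeros_like(x[:, 0, :])
    trace = []
    for t in range(x.shape[1]):
        v = beta * v + (1 - beta) * x[:, t, :]
        trace.append(v)
    return jnp.stack(trace, axis=1)


beta = jnp.array([0.2, 0.5, 0.9])                         # one decay per channel
x = jnp.sin(jnp.arange(2 * 16 * 3, dtype=jnp.float32)).reshape(2, 16, 3)  # (B, T, C)

V_fft = fft_membrane(beta, x)
V_seq = sequential_membrane(beta, x)
```

The two traces agree up to floating-point error, so the FFT path computes every timestep without a sequential loop.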

__init__(hidden_shape, threshold=1, k=10, *, rngs)

Source code in spyx/experimental/stochastic.py
def __init__(self, hidden_shape: tuple, threshold=1, k=10, *, rngs: nnx.Rngs):
    self.hidden_shape = hidden_shape
    self.threshold = threshold
    self.spike = sigmoid_bernoulli(k, threshold)

    self.beta = nnx.Param(
        nnx.initializers.truncated_normal(stddev=0.25)(
            rngs.params(), self.hidden_shape
        )
        + 0.5
    )

StochasticAssociativeCuBaLIF

Bases: Module

Source code in spyx/experimental/stochastic.py
class StochasticAssociativeCuBaLIF(nnx.Module):
    def __init__(self, hidden_shape, threshold=1, k=100, *, rngs: nnx.Rngs):
        self.hidden_shape = hidden_shape
        self.spike = refractory_sigmoid_bernoulli(k, threshold)

        self.alpha = nnx.Param(
            nnx.initializers.truncated_normal(stddev=0.25)(
                rngs.params(), self.hidden_shape
            )
            + 0.5
        )
        self.beta = nnx.Param(
            nnx.initializers.truncated_normal(stddev=0.25)(
                rngs.params(), self.hidden_shape
            )
            + 0.5
        )

    def __call__(self, key, u):
        alpha = jnp.clip(self.alpha[:], 0, 1)
        beta = jnp.clip(self.beta[:], 0, 1)

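        # Two cascaded scans, each vmapped over the batch axis: alpha filters
        # the input u into a synaptic current x, then beta filters x into the
        # membrane trace V. (The two passes can probably be condensed.)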
        _, x = jax.vmap(_pscan, in_axes=(None, 0))(alpha, u)
        _, V = jax.vmap(_pscan, in_axes=(None, 0))(beta, x)

        return self.spike(V, key)

alpha = nnx.Param(nnx.initializers.truncated_normal(stddev=0.25)(rngs.params(), self.hidden_shape) + 0.5) instance-attribute

beta = nnx.Param(nnx.initializers.truncated_normal(stddev=0.25)(rngs.params(), self.hidden_shape) + 0.5) instance-attribute

hidden_shape = hidden_shape instance-attribute

spike = refractory_sigmoid_bernoulli(k, threshold) instance-attribute

__call__(key, u)

Source code in spyx/experimental/stochastic.py
def __call__(self, key, u):
    alpha = jnp.clip(self.alpha[:], 0, 1)
    beta = jnp.clip(self.beta[:], 0, 1)

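    # Two cascaded scans, each vmapped over the batch axis: alpha filters
    # the input u into a synaptic current x, then beta filters x into the
    # membrane trace V. (The two passes can probably be condensed.)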
    _, x = jax.vmap(_pscan, in_axes=(None, 0))(alpha, u)
    _, V = jax.vmap(_pscan, in_axes=(None, 0))(beta, x)

    return self.spike(V, key)
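`_pscan` is not shown in this section, so the sketch below substitutes a hypothetical `leaky_scan` built on `jax.lax.associative_scan`. It assumes the recurrence `v[t] = decay * v[t-1] + x[t]` with zero initial state; the real `_pscan` may scale its input differently. The point is only to illustrate how two cascaded scans yield a current-based membrane trace, checked here against a plain sequential loop.

```python
import jax
import jax.numpy as jnp


def leaky_scan(decay, x):
    # Hypothetical stand-in for _pscan. x: (T, C), decay: (C,).
    # Each step is the affine map v -> a * v + b; composing such maps is
    # associative, so the whole trace can be computed with a parallel scan.
    a = jnp.broadcast_to(decay, x.shape)

    def combine(left, right):
        a_l, b_l = left
        a_r, b_r = right
        return a_l * a_r, a_r * b_l + b_r

    _, v = jax.lax.associative_scan(combine, (a, x), axis=0)
    return v


def cuba_membrane(alpha, beta, u):
    # u: (B, T, C). First scan: synaptic current; second scan: membrane.
    current = jax.vmap(leaky_scan, in_axes=(None, 0))(alpha, u)
    return jax.vmap(leaky_scan, in_axes=(None, 0))(beta, current)


def cuba_reference(alpha, beta, u):
    # Sequential reference for the same two-stage recurrence.
    i = jnp.zeros_like(u[:, 0, :])
    v = jnp.zeros_like(u[:, 0, :])
    trace = []
    for t in range(u.shape[1]):
        i = alpha * i + u[:, t, :]
        v = beta * v + i
        trace.append(v)
    return jnp.stack(trace, axis=1)


alpha = jnp.array([0.3, 0.6, 0.9])
beta = jnp.array([0.8, 0.5, 0.2])
u = jnp.cos(jnp.arange(2 * 12 * 3, dtype=jnp.float32)).reshape(2, 12, 3)

V_scan = cuba_membrane(alpha, beta, u)
V_ref = cuba_reference(alpha, beta, u)
```

Note that, unlike `SPSN` and `StochasticAssociativeLIF`, the `__call__` listed above returns only the spikes and not the membrane trace.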

__init__(hidden_shape, threshold=1, k=100, *, rngs)

Source code in spyx/experimental/stochastic.py
def __init__(self, hidden_shape, threshold=1, k=100, *, rngs: nnx.Rngs):
    self.hidden_shape = hidden_shape
    self.spike = refractory_sigmoid_bernoulli(k, threshold)

    self.alpha = nnx.Param(
        nnx.initializers.truncated_normal(stddev=0.25)(
            rngs.params(), self.hidden_shape
        )
        + 0.5
    )
    self.beta = nnx.Param(
        nnx.initializers.truncated_normal(stddev=0.25)(
            rngs.params(), self.hidden_shape
        )
        + 0.5
    )

StochasticAssociativeLIF

Bases: Module

Source code in spyx/experimental/stochastic.py
class StochasticAssociativeLIF(nnx.Module):
    def __init__(self, hidden_shape, threshold=1, k=100, spike=True, *, rngs: nnx.Rngs):
        self.hidden_shape = hidden_shape
        self.threshold = threshold
        if spike:
            self.spike = sigmoid_bernoulli(k, threshold)
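        else:
            # spike=False: skip sampling and pass the membrane trace through
            # unchanged (the second argument is the unused PRNG key).
            self.spike = lambda x, key: x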

        self.beta = nnx.Param(
            nnx.initializers.truncated_normal(stddev=0.25)(
                rngs.params(), self.hidden_shape
            )
            + 0.5
        )

    # x.shape = B, T, C
    def __call__(self, key, x):
        beta = jnp.clip(self.beta[:], 0, 1)

        _, V = jax.vmap(_pscan, in_axes=(None, 0))(beta, x)

        return self.spike(V, key), V

beta = nnx.Param(nnx.initializers.truncated_normal(stddev=0.25)(rngs.params(), self.hidden_shape) + 0.5) instance-attribute

hidden_shape = hidden_shape instance-attribute

spike = sigmoid_bernoulli(k, threshold) instance-attribute

threshold = threshold instance-attribute

__call__(key, x)

Source code in spyx/experimental/stochastic.py
def __call__(self, key, x):
    beta = jnp.clip(self.beta[:], 0, 1)

    _, V = jax.vmap(_pscan, in_axes=(None, 0))(beta, x)

    return self.spike(V, key), V

__init__(hidden_shape, threshold=1, k=100, spike=True, *, rngs)

Source code in spyx/experimental/stochastic.py
def __init__(self, hidden_shape, threshold=1, k=100, spike=True, *, rngs: nnx.Rngs):
    self.hidden_shape = hidden_shape
    self.threshold = threshold
    if spike:
        self.spike = sigmoid_bernoulli(k, threshold)
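    else:
        # spike=False: skip sampling and pass the membrane trace through
        # unchanged (the second argument is the unused PRNG key).
        self.spike = lambda x, key: x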

    self.beta = nnx.Param(
        nnx.initializers.truncated_normal(stddev=0.25)(
            rngs.params(), self.hidden_shape
        )
        + 0.5
    )

refractory_sigmoid_bernoulli(k=50, threshold=1)

Source code in spyx/experimental/stochastic.py
def refractory_sigmoid_bernoulli(k=50, threshold=1):
    freq = 2 * jnp.pi * threshold

    @jax.custom_gradient
    def activation(x, key):
        U = x - threshold
        r = jnp.cos(freq * U)
        s = jax.nn.sigmoid(k * U)
        p_n = jnp.maximum(r * s, 0)
        return jax.random.bernoulli(key, p_n).astype(U.dtype), lambda g: (g * p_n, None)

    return activation
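To see what the cosine term does, it can help to evaluate the firing probability on its own. The helper `refractory_probability` below is a hypothetical extraction of the `p_n` expression from the listing above (no sampling, no custom gradient), using the same defaults.

```python
import jax
import jax.numpy as jnp


def refractory_probability(x, k=50, threshold=1):
    # Same expression as p_n in refractory_sigmoid_bernoulli.
    U = x - threshold
    r = jnp.cos(2 * jnp.pi * threshold * U)   # periodic gate in U
    s = jax.nn.sigmoid(k * U)                 # steep soft threshold
    return jnp.maximum(r * s, 0)              # negative lobes are clipped to 0


# Membrane values below, at, and above the threshold.
x = jnp.array([0.5, 1.0, 1.5, 2.0])
p = refractory_probability(x)
```

With `threshold=1`, the probability is 0.5 exactly at the threshold, close to 1 one full unit above it, and clipped to 0 half a unit above it, where the cosine is negative. The probability is therefore not monotonic in the membrane potential.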

sigmoid_bernoulli(k=10, threshold=1.0, max_prob=0.8)

Source code in spyx/experimental/stochastic.py
def sigmoid_bernoulli(k=10, threshold=1.0, max_prob=0.8):
    @jax.custom_gradient
    def activation(x, key):
        U = x - threshold
        p_n = jax.nn.sigmoid(k * U) * max_prob
        return jax.random.bernoulli(key, p_n).astype(U.dtype), lambda g: (g * p_n, None)

    return activation
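A small experiment can illustrate the two halves of this activation: the forward pass draws Bernoulli spikes with probability `p_n = max_prob * sigmoid(k * (x - threshold))`, and the backward pass multiplies the incoming cotangent by `p_n` itself rather than by a true derivative. The function below is copied from the listing above; the sample size and test inputs are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp


def sigmoid_bernoulli(k=10, threshold=1.0, max_prob=0.8):
    # Copied from the listing above.
    @jax.custom_gradient
    def activation(x, key):
        U = x - threshold
        p_n = jax.nn.sigmoid(k * U) * max_prob
        return jax.random.bernoulli(key, p_n).astype(U.dtype), lambda g: (g * p_n, None)

    return activation


spike = sigmoid_bernoulli()
key = jax.random.PRNGKey(0)

# Forward: at x == threshold the firing probability is 0.5 * max_prob = 0.4.
at_threshold = jnp.full((20000,), 1.0)
spikes = spike(at_threshold, key)
empirical_rate = spikes.mean()

# Backward: the gradient of the summed spikes is the firing probability.
x = jnp.array([0.5, 1.0, 3.0])
surrogate = jax.grad(lambda v: spike(v, key).sum())(x)
expected = jax.nn.sigmoid(10 * (x - 1.0)) * 0.8
```

Because of `max_prob`, the firing probability (and with it the surrogate gradient) saturates at 0.8 for inputs far above the threshold.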