Skip to content

spyx.nir

Import and export Neuromorphic Intermediate Representation (NIR) graphs. Supports feed-forward networks built from Linear, Conv, SumPool, Flatten, IF, LIF, CuBaLIF, and PSU_LIF layers, plus recurrent subgraphs for RIF, RLIF, and RCuBaLIF.

from_nir(nir_graph, input_data, dt=1, *, return_all_states=False, rngs=None)

Reconstruct a Spyx/NNX model from a NIR graph and run it on input_data.

Parameters:

- nir_graph: the NIR graph to import.
- input_data: time-major input of shape (T, B, ...); scanned over the leading time axis.
- dt: simulation timestep used to convert NIR time constants back to Spyx decay factors (must match the dt used on export).
- return_all_states: when True, also return the per-layer neuron states at every timestep (e.g. membrane-potential traces) as a pytree of (T, B, ...) arrays mirroring model.initial_state.
- rngs: optional nnx.Rngs used to reconstruct the modules (defaults to nnx.Rngs(0)).

Returns: (model, outputs), where outputs has shape (T, B, ...); or (model, (outputs, states)) when return_all_states is True.

Source code in spyx/nir.py
def from_nir(
    nir_graph: nir.NIRGraph,
    input_data,
    dt: float = 1,
    *,
    return_all_states: bool = False,
    rngs: nnx.Rngs | None = None,
):
    """Rebuild a Spyx/NNX model from ``nir_graph`` and simulate it on ``input_data``.

    :param nir_graph: NIR graph describing the network to reconstruct.
    :param input_data: time-major input of shape ``(T, B, ...)``; the model is
        stepped over the leading (time) axis.
    :param dt: timestep used to turn NIR time constants back into Spyx decay
        factors; it should equal the ``dt`` that was used when exporting.
    :param return_all_states: if True, additionally collect the per-layer neuron
        state after *every* timestep (e.g. membrane-potential traces) as a
        pytree of ``(T, B, ...)`` arrays with the structure of
        ``model.initial_state``.
    :param rngs: ``nnx.Rngs`` used while building the modules; a fresh
        ``nnx.Rngs(0)`` is used when omitted.
    :return: ``(model, outputs)`` with ``outputs`` of shape ``(T, B, ...)``, or
        ``(model, (outputs, states))`` when ``return_all_states`` is True.
    """
    model = _build_model(nir_graph, dt, nnx.Rngs(0) if rngs is None else rngs)

    if return_all_states:
        # Start from the model's own initial state for this batch size
        # (axis 1 of the time-major input).
        batch_size = input_data.shape[1]

        def _scan_body(carry, frame):
            # Emit the new state alongside the output so scan stacks a full
            # per-timestep trace of every layer's state.
            y, next_carry = model(frame, carry)
            return next_carry, (y, next_carry)

        _, (outputs, states) = jax.lax.scan(
            _scan_body, model.initial_state(batch_size), input_data
        )
        return model, (outputs, states)

    # Fast path: only the outputs are needed, so the final state is discarded.
    outputs, _final_state = run(model, input_data)
    return model, outputs

reorder_layers(init_params, trained_params)

Some optimization libraries may permute the keys of the network's parameter PyTree. This matters because NIR export assumes the keys are still in the order they had when the network was initialized. This function takes the original and the trained parameters and returns the trained parameters re-ordered to match the original key order, ready for export.

Source code in spyx/nir.py
def reorder_layers(init_params, trained_params):
    """Return ``trained_params`` re-keyed in the key order of ``init_params``.

    Some optimization libraries permute the keys of the network's parameter
    PyTree, while NIR export expects the keys in the order they had right after
    initialization. The result is a new dict whose keys follow
    ``init_params``' order and whose values come from ``trained_params``.

    :param init_params: parameters as initialized; only their key order is used.
    :param trained_params: trained parameters, possibly with permuted keys.
    :return: a new dict with the trained values in the initial key order.
    :raises KeyError: if a key of ``init_params`` is missing from
        ``trained_params``.
    """
    ordered = {}
    for layer_name in init_params.keys():
        ordered[layer_name] = trained_params[layer_name]
    return ordered

to_nir(model, input_shape, output_shape, dt=1)

Converts a Spyx/NNX model to a NIR graph.

Source code in spyx/nir.py
def to_nir(model, input_shape, output_shape, dt=1) -> nir.NIRGraph:
    """Convert a Spyx/NNX model to a NIR graph.

    The model is exported as a feed-forward chain: an ``nnx.Sequential`` is
    walked layer by layer, and any other module is exported as a single layer.
    Each supported layer becomes a node ``layer_{i}`` wired
    ``input -> layer_0 -> ... -> output``. ``PSU_LIF`` becomes an ``LI`` node
    followed by a ``{node}_threshold`` node. Unsupported layers are skipped
    with a warning.

    :param model: the Spyx/NNX model to export.
    :param input_shape: passed to ``nir.Input``; its first value is also used
        as the per-sample (batch-free, channels-first) shape that is tracked
        through the layers.
    :param output_shape: passed to ``nir.Output``.
    :param dt: simulation timestep used to convert Spyx decay factors into NIR
        time constants (``tau = dt / (1 - beta)``).
    :return: the assembled ``nir.NIRGraph``.
    :raises NotImplementedError: if the model contains a ``ResonateFire`` layer.
    """
    import warnings

    nodes = {"input": nir.Input(input_shape), "output": nir.Output(output_shape)}
    edges = []

    prev_node = "input"

    # We assume a sequential model for now
    if not isinstance(model, nnx.Sequential):
        layers = [model]
    else:
        layers = model.layers

    # Track the per-sample tensor shape (NIR shapes exclude the batch axis) so
    # nodes that need it — Conv2d (spatial dims) and Flatten (full shape) — can
    # be constructed correctly. NIR uses channels-first (C, N_x, N_y).
    _in = next(iter(input_shape.values()))
    cur_shape = tuple(int(d) for d in np.ravel(np.asarray(_in)))

    for i, layer in enumerate(layers):
        node_key = f"layer_{i}"

        if isinstance(layer, nnx.Linear):
            # NNX kernels are (in, out); NIR weights are (out, in).
            if layer.bias is not None:
                nodes[node_key] = nir.Affine(
                    np.array(layer.kernel[...].T), np.array(layer.bias[...])
                )
            else:
                nodes[node_key] = nir.Linear(np.array(layer.kernel[...].T))
            cur_shape = cur_shape[:-1] + (int(layer.kernel.shape[-1]),)

        elif isinstance(layer, nnx.Conv):
            # NNX Conv is HWIO, NIR is OIHW
            weight = np.array(layer.kernel[...].transpose((3, 2, 0, 1)))
            spatial = tuple(cur_shape[-2:])  # (N_x, N_y)
            # NIR's Conv2d carries an array bias; a zero bias is the exact
            # equivalent of an NNX Conv built without one.
            if layer.bias is not None:
                bias = np.array(layer.bias[...])
            else:
                bias = np.zeros(weight.shape[0], dtype=weight.dtype)
            # Export the layer's real dilation / grouping instead of assuming
            # the defaults. For grouped convs the HWIO kernel already stores
            # in_features // groups input channels, which matches NIR's OIHW.
            dilation = getattr(layer, "kernel_dilation", None)
            groups = getattr(layer, "feature_group_count", None)
            conv_node = nir.Conv2d(
                input_shape=spatial,  # required to disambiguate the shape
                weight=weight,
                bias=bias,
                dilation=1 if dilation is None else dilation,
                stride=layer.strides,
                # Honour the layer's actual padding (SAME / VALID / explicit)
                # instead of assuming SAME, which mis-shapes VALID convs.
                padding=_nnx_pad_to_nir(layer.padding),
                groups=1 if groups is None else int(groups),
            )
            nodes[node_key] = conv_node
            # nir.Conv2d infers the exact output shape (C_out, N_x, N_y) from the
            # padding/stride/dilation, so read it back rather than assuming SAME.
            cur_shape = tuple(
                int(d) for d in np.asarray(conv_node.output_type["output"])
            )

        elif isinstance(layer, SumPool):
            ksize = np.atleast_1d(np.asarray(layer.window_shape)).astype(int)
            stride = np.atleast_1d(np.asarray(layer.strides)).astype(int)
            # A scalar window / stride applies to both spatial dims.
            if ksize.size == 1:
                ksize = np.array([int(ksize[0])] * 2)
            if stride.size == 1:
                stride = np.array([int(stride[0])] * 2)
            ch, h, w = cur_shape
            if str(layer.padding).upper() == "SAME":
                # SAME: output = ceil(in / stride); record the explicit symmetric
                # NIR pad amount that realises that output size.
                out_h = -(-h // int(stride[0]))
                out_w = -(-w // int(stride[1]))
                # Total pad each spatial dim needs to reach the SAME output size;
                # stored as-is so a non-zero value flags SAME on re-import (spyx
                # SumPool models only VALID / SAME, not per-side pad amounts).
                pad_h = max(0, (out_h - 1) * int(stride[0]) + int(ksize[0]) - h)
                pad_w = max(0, (out_w - 1) * int(stride[1]) + int(ksize[1]) - w)
                padding = np.array([pad_h, pad_w])
                cur_shape = (ch, out_h, out_w)
            else:  # VALID pooling shrinks the (channels-first) spatial dims.
                padding = np.array([0, 0])
                cur_shape = (
                    ch,
                    (h - int(ksize[0])) // int(stride[0]) + 1,
                    (w - int(ksize[1])) // int(stride[1]) + 1,
                )
            nodes[node_key] = nir.SumPool2d(
                kernel_size=ksize,
                stride=stride,
                padding=padding,
            )

        elif isinstance(layer, IF):
            # nir.IF requires array-valued r / v_threshold shaped to the layer.
            # cur_shape carries spatial dims when the neuron follows a conv.
            nodes[node_key] = nir.IF(
                r=np.ones(cur_shape, dtype=np.float32),
                v_threshold=_spyx_param_to_nir(layer.threshold, cur_shape),
            )

        elif isinstance(layer, LIF):
            beta = _spyx_param_to_nir(layer.beta[...], cur_shape)
            nodes[node_key] = nir.LIF(
                tau=dt / (1 - beta),
                v_threshold=_spyx_param_to_nir(layer.threshold, cur_shape),
                v_leak=np.zeros(cur_shape, dtype=np.float32),
                r=beta,
            )

        elif isinstance(layer, CuBaLIF):
            alpha = _spyx_param_to_nir(layer.alpha[...], cur_shape)
            beta = _spyx_param_to_nir(layer.beta[...], cur_shape)
            nodes[node_key] = nir.CubaLIF(
                tau_mem=dt / (1 - beta),
                tau_syn=dt / (1 - alpha),
                v_threshold=_spyx_param_to_nir(layer.threshold, cur_shape),
                v_leak=np.zeros(cur_shape, dtype=np.float32),
                r=beta,
            )

        elif isinstance(layer, (RIF, RLIF, RCuBaLIF)):
            # Recurrent neurons are exported as a nested subgraph.
            nodes[node_key] = _spyx_recurrent_to_nirgraph(layer, node_key, dt)

        elif isinstance(layer, PSU_LIF):
            # PSU_LIF is a *reset-free* leaky integrator followed by a threshold:
            #   V_t = clip(beta) * V_{t-1} + x_t ,   s_t = (V_t > threshold).
            # NIR has no single reset-free spiking primitive: nir.LIF (and IF)
            # always subtract / clamp on spike, so mapping to LIF would inject a
            # reset PSU_LIF does not have. Instead we export the exact two-part
            # decomposition NIR *can* represent faithfully:
            #   nir.LI       -- the reset-free linear membrane (tau=dt/(1-beta),
            #                   v_leak=0, r=1), identical to PSU_LIF's recurrence;
            #   nir.Threshold-- the pointwise spike rule s = (V > threshold),
            #                   matching the heaviside forward of the surrogate.
            # This pair round-trips back into a single PSU_LIF (see _build_model)
            # and carries no reset-semantics gap.
            beta = np.clip(_spyx_param_to_nir(layer.beta[...], cur_shape), 0.0, 1.0)
            thr_key = f"{node_key}_threshold"
            nodes[node_key] = nir.LI(
                tau=dt / (1 - beta),
                r=np.ones(cur_shape, dtype=np.float32),
                v_leak=np.zeros(cur_shape, dtype=np.float32),
            )
            nodes[thr_key] = nir.Threshold(
                threshold=_spyx_param_to_nir(layer.threshold, cur_shape),
            )
            # Wire prev -> LI -> Threshold; the threshold node is what the next
            # layer connects to. cur_shape is unchanged (pointwise threshold).
            edges.append((prev_node, node_key))
            edges.append((node_key, thr_key))
            prev_node = thr_key
            continue

        elif isinstance(layer, ResonateFire):
            # ResonateFire is a complex resonate-and-fire oscillator whose pole is
            #   a = exp(dt * (-lambda + i*omega)),  z_t = a * z_{t-1} + x_t.
            # NIR has no complex / oscillatory / resonate-and-fire primitive; all
            # of its neuron nodes carry real-valued state, so the imaginary
            # rotation (the "resonate") cannot be represented without discarding
            # the oscillation entirely -- which would be a faked mapping. We
            # therefore refuse rather than silently degrade to a real leak.
            raise NotImplementedError(
                "ResonateFire has no faithful NIR representation: NIR defines no "
                "complex / oscillatory / resonate-and-fire primitive, so the "
                "complex pole a = exp(dt*(-lambda + i*omega)) cannot be exported "
                "without discarding the oscillatory (imaginary) dynamics. Export "
                "of ResonateFire to NIR is intentionally unsupported."
            )

        elif isinstance(layer, Flatten):
            # spyx.nn.Flatten collapses every non-batch dim; NIR shapes have no
            # batch axis, so flatten the whole per-sample shape (start_dim=0).
            nodes[node_key] = nir.Flatten(input_type=cur_shape, start_dim=0, end_dim=-1)
            cur_shape = (int(np.prod(cur_shape)),)

        else:
            # Report through the warnings machinery (filterable / capturable by
            # callers) rather than printing to stdout; the layer is skipped.
            warnings.warn(
                f"[Warning] Layer {type(layer)} not recognized/supported for NIR export.",
                stacklevel=2,
            )
            continue

        edges.append((prev_node, node_key))
        prev_node = node_key

    edges.append((prev_node, "output"))

    return nir.NIRGraph(nodes, edges)