def to_nir(model, input_shape, output_shape, dt=1) -> nir.NIRGraph:
    """Convert a sequential Spyx/NNX model to a NIR graph.

    Layers are visited in order and chained
    ``input -> layer_0 -> layer_1 -> ... -> output``. While walking the
    layers the per-sample tensor shape is tracked so that nodes which need
    it (Conv2d, SumPool2d, Flatten and the array-valued neuron parameters)
    can be built with correctly shaped arguments.

    Args:
        model: An ``nnx.Sequential`` (its ``layers`` are exported in order)
            or any single layer, which is treated as a one-layer model.
        input_shape: Shape handed to ``nir.Input``. Either a dict (the first
            value is used for shape tracking) or a plain shape
            sequence/array. The shape is per-sample (no batch axis) and
            channels-first, i.e. ``(C, N_x, N_y)`` for image-like inputs.
        output_shape: Shape handed to ``nir.Output``.
        dt: Time step used to turn decay factors into time constants via
            ``tau = dt / (1 - beta)``.

    Returns:
        A ``nir.NIRGraph`` holding one node per exported layer (PSU_LIF
        contributes two: an ``LI`` node and a ``Threshold`` node) plus the
        ``"input"`` and ``"output"`` nodes.

    Raises:
        NotImplementedError: If the model contains a ``ResonateFire`` layer.
        ValueError: If a ``SumPool`` layer is reached while the tracked
            shape is not 3-D ``(C, N_x, N_y)``.

    Note:
        Layers of an unrecognized type are skipped with a printed warning;
        the neighbouring layers are then connected directly.
    """
    nodes = {"input": nir.Input(input_shape), "output": nir.Output(output_shape)}
    edges = []
    prev_node = "input"

    # We assume a sequential model for now
    if isinstance(model, nnx.Sequential):
        layers = model.layers
    else:
        layers = [model]

    # Track the per-sample tensor shape (NIR shapes exclude the batch axis) so
    # nodes that need it — Conv2d (spatial dims) and Flatten (full shape) — can
    # be constructed correctly. NIR uses channels-first (C, N_x, N_y).
    # nir.Input accepts a dict as well as a bare shape, so accept both here
    # instead of unconditionally calling ``.values()``.
    if isinstance(input_shape, dict):
        _in = next(iter(input_shape.values()))
    else:
        _in = input_shape
    cur_shape = tuple(int(d) for d in np.ravel(np.asarray(_in)))

    for i, layer in enumerate(layers):
        node_key = f"layer_{i}"

        if isinstance(layer, nnx.Linear):
            if layer.bias is not None:
                nodes[node_key] = nir.Affine(
                    np.array(layer.kernel[...].T), np.array(layer.bias[...])
                )
            else:
                nodes[node_key] = nir.Linear(np.array(layer.kernel[...].T))
            # Only the trailing (feature) dimension changes.
            cur_shape = cur_shape[:-1] + (int(layer.kernel.shape[-1]),)

        elif isinstance(layer, nnx.Conv):
            # NNX Conv is HWIO, NIR is OIHW
            weight = np.array(layer.kernel[...].transpose((3, 2, 0, 1)))
            bias = np.array(layer.bias[...]) if layer.bias is not None else None
            spatial = tuple(cur_shape[-2:])  # (N_x, N_y)

            # Honour the layer's own dilation / group count rather than
            # assuming 1, which would mis-describe dilated or grouped convs
            # (and mis-size the inferred output shape for dilated ones).
            dilation = getattr(layer, "kernel_dilation", None)
            if dilation is None:
                dilation = 1
            elif np.ndim(dilation) == 0:
                dilation = int(dilation)
            else:
                dilation = tuple(int(d) for d in dilation)
            groups = int(getattr(layer, "feature_group_count", 1))

            conv_node = nir.Conv2d(
                input_shape=spatial,  # required to disambiguate the shape
                weight=weight,
                bias=bias,
                dilation=dilation,
                stride=layer.strides,
                # Honour the layer's actual padding (SAME / VALID / explicit)
                # instead of assuming SAME, which mis-shapes VALID convs.
                padding=_nnx_pad_to_nir(layer.padding),
                groups=groups,
            )
            nodes[node_key] = conv_node
            # nir.Conv2d infers the exact output shape (C_out, N_x, N_y) from the
            # padding/stride/dilation, so read it back rather than assuming SAME.
            cur_shape = tuple(
                int(d) for d in np.asarray(conv_node.output_type["output"])
            )

        elif isinstance(layer, SumPool):
            ksize = np.atleast_1d(np.asarray(layer.window_shape)).astype(int)
            stride = np.atleast_1d(np.asarray(layer.strides)).astype(int)
            # Broadcast scalar window / stride to both spatial dims.
            if ksize.size == 1:
                ksize = np.array([int(ksize[0])] * 2)
            if stride.size == 1:
                stride = np.array([int(stride[0])] * 2)

            if len(cur_shape) != 3:
                raise ValueError(
                    f"SumPool export expects a (C, N_x, N_y) shape, got {cur_shape}"
                )
            ch, h, w = cur_shape
            k_h, k_w = int(ksize[0]), int(ksize[1])
            s_h, s_w = int(stride[0]), int(stride[1])

            if str(layer.padding).upper() == "SAME":
                # SAME: output = ceil(in / stride); record the explicit symmetric
                # NIR pad amount that realises that output size.
                out_h = -(-h // s_h)
                out_w = -(-w // s_w)
                # Total pad each spatial dim needs to reach the SAME output size;
                # stored as-is so a non-zero value flags SAME on re-import (spyx
                # SumPool models only VALID / SAME, not per-side pad amounts).
                pad_h = max(0, (out_h - 1) * s_h + k_h - h)
                pad_w = max(0, (out_w - 1) * s_w + k_w - w)
                padding = np.array([pad_h, pad_w])
                cur_shape = (ch, out_h, out_w)
            else:  # VALID pooling shrinks the (channels-first) spatial dims.
                padding = np.array([0, 0])
                cur_shape = (
                    ch,
                    (h - k_h) // s_h + 1,
                    (w - k_w) // s_w + 1,
                )
            nodes[node_key] = nir.SumPool2d(
                kernel_size=ksize,
                stride=stride,
                padding=padding,
            )

        elif isinstance(layer, IF):
            # nir.IF requires array-valued r / v_threshold shaped to the layer.
            # cur_shape carries spatial dims when the neuron follows a conv.
            nodes[node_key] = nir.IF(
                r=np.ones(cur_shape, dtype=np.float32),
                v_threshold=_spyx_param_to_nir(layer.threshold, cur_shape),
            )

        elif isinstance(layer, LIF):
            beta = _spyx_param_to_nir(layer.beta[...], cur_shape)
            nodes[node_key] = nir.LIF(
                tau=dt / (1 - beta),
                v_threshold=_spyx_param_to_nir(layer.threshold, cur_shape),
                v_leak=np.zeros(cur_shape, dtype=np.float32),
                r=beta,
            )

        elif isinstance(layer, CuBaLIF):
            alpha = _spyx_param_to_nir(layer.alpha[...], cur_shape)
            beta = _spyx_param_to_nir(layer.beta[...], cur_shape)
            nodes[node_key] = nir.CubaLIF(
                tau_mem=dt / (1 - beta),
                tau_syn=dt / (1 - alpha),
                v_threshold=_spyx_param_to_nir(layer.threshold, cur_shape),
                v_leak=np.zeros(cur_shape, dtype=np.float32),
                r=beta,
            )

        elif isinstance(layer, (RIF, RLIF, RCuBaLIF)):
            # Recurrent neurons are exported by a dedicated helper.
            nodes[node_key] = _spyx_recurrent_to_nirgraph(layer, node_key, dt)

        elif isinstance(layer, PSU_LIF):
            # PSU_LIF is a *reset-free* leaky integrator followed by a threshold:
            #     V_t = clip(beta) * V_{t-1} + x_t ,   s_t = (V_t > threshold).
            # NIR has no single reset-free spiking primitive: nir.LIF (and IF)
            # always subtract / clamp on spike, so mapping to LIF would inject a
            # reset PSU_LIF does not have. Instead we export the exact two-part
            # decomposition NIR *can* represent faithfully:
            #     nir.LI       -- the reset-free linear membrane (tau=dt/(1-beta),
            #                     v_leak=0, r=1), identical to PSU_LIF's recurrence;
            #     nir.Threshold-- the pointwise spike rule s = (V > threshold),
            #                     matching the heaviside forward of the surrogate.
            # This pair round-trips back into a single PSU_LIF (see _build_model)
            # and carries no reset-semantics gap.
            beta = np.clip(_spyx_param_to_nir(layer.beta[...], cur_shape), 0.0, 1.0)
            thr_key = f"{node_key}_threshold"
            nodes[node_key] = nir.LI(
                tau=dt / (1 - beta),
                r=np.ones(cur_shape, dtype=np.float32),
                v_leak=np.zeros(cur_shape, dtype=np.float32),
            )
            nodes[thr_key] = nir.Threshold(
                threshold=_spyx_param_to_nir(layer.threshold, cur_shape),
            )
            # Wire prev -> LI -> Threshold here; the Threshold node becomes the
            # tail that the next layer attaches to, so skip the generic wiring.
            edges.append((prev_node, node_key))
            prev_node = thr_key
            edges.append((node_key, thr_key))
            # cur_shape unchanged: the threshold preserves the tensor shape.
            continue

        elif isinstance(layer, ResonateFire):
            # ResonateFire is a complex resonate-and-fire oscillator whose pole is
            #     a = exp(dt * (-lambda + i*omega)),   z_t = a * z_{t-1} + x_t.
            # NIR has no complex / oscillatory / resonate-and-fire primitive; all
            # of its neuron nodes carry real-valued state, so the imaginary
            # rotation (the "resonate") cannot be represented without discarding
            # the oscillation entirely -- which would be a faked mapping. We
            # therefore refuse rather than silently degrade to a real leak.
            raise NotImplementedError(
                "ResonateFire has no faithful NIR representation: NIR defines no "
                "complex / oscillatory / resonate-and-fire primitive, so the "
                "complex pole a = exp(dt*(-lambda + i*omega)) cannot be exported "
                "without discarding the oscillatory (imaginary) dynamics. Export "
                "of ResonateFire to NIR is intentionally unsupported."
            )

        elif isinstance(layer, Flatten):
            # spyx.nn.Flatten collapses every non-batch dim; NIR shapes have no
            # batch axis, so flatten the whole per-sample shape (start_dim=0).
            nodes[node_key] = nir.Flatten(
                input_type=cur_shape, start_dim=0, end_dim=-1
            )
            cur_shape = (int(np.prod(cur_shape)),)

        else:
            # Best-effort export: skip the layer, leaving no node or edge for it.
            print(
                f"[Warning] Layer {type(layer)} not recognized/supported for NIR export."
            )
            continue

        edges.append((prev_node, node_key))
        prev_node = node_key

    edges.append((prev_node, "output"))
    return nir.NIRGraph(nodes, edges)