Phasor and Spiking Phasor Networks
This tutorial walks through the spyx.phasor module in three steps:
- Continuous phasor MLP: build and train a complex-valued network on a small toy task.
- Phase ↔ spike codec: convert phases to one-spike-per-cycle trains and back.
- Spiking inference: run phasor-layer weights on a spiking substrate via SpikingPhasor.
The phasor architecture is due to Bybee, Frady & Sommer (2021, arXiv 2106.11908). Our implementation stores weights as paired kernel_re / kernel_im float32 parameters so that a stock optax.adam loop converges — see the module docstring for the Wirtinger-gradient rationale.
import jax
import jax.numpy as jnp
import optax
from flax import nnx
from spyx import phasor
1. Continuous phasor MLP
A phasor unit holds a complex-valued activation z = r · e^{iθ}. PhasorLinear is a complex-valued dense layer, PhasorActivation projects each activation back onto the unit circle (the TPAM threshold function), and PhasorReadout collapses the complex output to real logits via Re(W · z).
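To make the three operations concrete, here is a plain-JAX sketch of what each one computes. The function names are hypothetical stand-ins, not the spyx API. The sketch starts from unit-modulus complex inputs; how PhasorMLP maps real inputs x to phases is not covered here. The activation uses e^{i·arg z}, which is one way to project onto the unit circle; spyx may add an epsilon or handle z = 0 differently.

```python
import jax
import jax.numpy as jnp

# Hypothetical plain-JAX versions of the three layer types (not the spyx API).
def phasor_linear(z, w_re, w_im):
    """Complex dense layer: z @ W with W = w_re + i * w_im (paired real kernels)."""
    return z @ (w_re + 1j * w_im)

def phasor_activation(z):
    """Project each complex activation onto the unit circle: e^{i * arg z}."""
    return jnp.exp(1j * jnp.angle(z))

def phasor_readout(z, w_re, w_im):
    """Collapse complex activations to real logits via Re(W . z)."""
    return jnp.real(phasor_linear(z, w_re, w_im))

keys = jax.random.split(jax.random.PRNGKey(0), 5)
theta = jax.random.uniform(keys[0], (3, 4), minval=-jnp.pi, maxval=jnp.pi)
z_in = jnp.exp(1j * theta)                    # unit-modulus inputs [B=3, 4]
h = phasor_activation(
    phasor_linear(z_in, jax.random.normal(keys[1], (4, 16)), jax.random.normal(keys[2], (4, 16)))
)                                             # hidden phasors [3, 16], |h| = 1
logits = phasor_readout(h, jax.random.normal(keys[3], (16, 2)), jax.random.normal(keys[4], (16, 2)))
```

After the activation only the phase carries information, since every hidden unit has modulus 1. The readout is the one place where the result is made real.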
Build a tiny MLP and show it learns a linearly-separable toy task:
rngs = nnx.Rngs(0)
model = phasor.PhasorMLP(
in_features=4,
hidden_features=16,
out_features=2,
depth=2,
rngs=rngs,
)
# Synthetic task: label = 1 iff x[0] > 0.5.
x = jax.random.uniform(jax.random.PRNGKey(1), (64, 4))
y = (x[:, 0] > 0.5).astype(jnp.int32)
optimizer = nnx.Optimizer(model, optax.adam(5e-3), wrt=nnx.Param)
@nnx.jit
def train_step(model, optimizer, x, y):
    # Mean softmax cross-entropy over the batch.
    def loss_fn(m):
        return optax.softmax_cross_entropy_with_integer_labels(m(x), y).mean()
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)  # in-place Adam step on the nnx.Param leaves
    return loss
losses = []
for step in range(300):
    losses.append(float(train_step(model, optimizer, x, y)))
print(f"initial loss: {losses[0]:.4f}")
print(f"final loss: {losses[-1]:.4f}")
# Training-set accuracy (this toy example has no held-out split).
preds = jnp.argmax(model(x), axis=-1)
print(f"final accuracy: {float((preds == y).mean()):.3f}")
initial loss: 0.6607
final loss: 0.0004
final accuracy: 1.000
2. Phase ↔ spike round-trip
For spiking inference, each phase is mapped to a single spike time inside a cycle of T time bins. phase_to_spikes(θ, T) emits a spike train that is one-hot along the time axis; spikes_to_phase(spikes, T) recovers the phase from the spike-time centroid. The round-trip is therefore not exact but accurate to within one quantisation step, 2π/T.
T = 64
theta = jnp.linspace(-jnp.pi + 1e-3, jnp.pi - 1e-3, 16)
spikes = phasor.phase_to_spikes(theta, T)
recovered = phasor.spikes_to_phase(spikes, T)
print("spike-train shape:", spikes.shape)
print("spikes per neuron (always 1):", int(jnp.sum(spikes, axis=0)[0]))
bin_size = 2.0 * jnp.pi / T
max_err = float(jnp.max(jnp.abs(recovered - theta)))
print(f"max round-trip error: {max_err:.4f} (bin size = {float(bin_size):.4f})")
spike-train shape: (64, 16)
spikes per neuron (always 1): 1
max round-trip error: 0.0972 (bin size = 0.0982)
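The codec can be sketched in a few lines of plain JAX. `encode_phase` and `decode_phase` are hypothetical stand-ins for phase_to_spikes / spikes_to_phase. The binning rule, floor((θ + π) / 2π · T), is an assumption. It is consistent with the output above: floor binning gives errors in [0, 2π/T), and with these same inputs the sketch's max_err works out to about 0.0972. Rounding to the nearest bin would instead halve the worst-case error, and spyx may use either convention.

```python
import jax
import jax.numpy as jnp

def encode_phase(theta, T):
    """Hypothetical phase_to_spikes stand-in: one spike per neuron per cycle.

    theta in [-pi, pi) goes to time bin floor((theta + pi) / (2*pi) * T);
    the result has time on the leading axis, shape [T, *theta.shape].
    """
    frac = (theta + jnp.pi) / (2.0 * jnp.pi)                      # phase -> [0, 1)
    t = jnp.clip(jnp.floor(frac * T).astype(jnp.int32), 0, T - 1)
    return jnp.moveaxis(jax.nn.one_hot(t, T), -1, 0)              # one-hot over time

def decode_phase(spikes, T):
    """Hypothetical spikes_to_phase stand-in: spike-time centroid -> phase."""
    times = jnp.arange(T, dtype=jnp.float32).reshape((T,) + (1,) * (spikes.ndim - 1))
    count = jnp.maximum(jnp.sum(spikes, axis=0), 1.0)             # avoid 0/0 for silent neurons
    centroid = jnp.sum(spikes * times, axis=0) / count
    return 2.0 * jnp.pi * centroid / T - jnp.pi

T = 64
theta = jnp.linspace(-jnp.pi + 1e-3, jnp.pi - 1e-3, 16)
spikes = encode_phase(theta, T)               # [64, 16]
recovered = decode_phase(spikes, T)
max_err = float(jnp.max(jnp.abs(recovered - theta)))
```

Under floor binning, the decoded phase is always the start of the bin, so `recovered` never exceeds `theta`.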
Visualising the codec
Each row below is one neuron's spike train (the plot shows spikes.T, so time runs left to right). Neurons with a larger input phase fire later in the cycle.
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(8, 4))
ax.imshow(spikes.T, aspect="auto", cmap="Greys")
ax.set_xlabel("time bin (0..T-1)")
ax.set_ylabel("neuron (phase -π → +π)")
ax.set_title("phase_to_spikes(θ, T=64): one spike per neuron per cycle")
plt.show()
3. Spiking inference
SpikingPhasor wraps a PhasorLinear for evaluation on a spiking substrate. In practice this would be a trained layer; the example below uses a freshly initialised one for brevity. On each call, SpikingPhasor:
- Recovers per-unit phases from the input spike train with spikes_to_phase
- Runs the complex matmul through the wrapped PhasorLinear
- Normalises the result back to the unit circle via PhasorActivation
- Re-emits spikes via phase_to_spikes
A full phasor SNN is a stack of SpikingPhasor layers, each passing its output spike train to the next. There is no per-timestep state and there are no leaky dynamics.
# Build a single PhasorLinear and wrap it for spiking inference.
pl = phasor.PhasorLinear(in_features=8, out_features=6, rngs=nnx.Rngs(2))
sp = phasor.SpikingPhasor(pl, period_T=32)
# Synthetic input: random phases -> spike train [T=32, B=4, C=8].
theta_in = jax.random.uniform(
jax.random.PRNGKey(3), (4, 8), minval=-jnp.pi, maxval=jnp.pi
)
spikes_in = phasor.phase_to_spikes(theta_in, T=32)
spikes_out = sp(spikes_in)
print("input shape:", spikes_in.shape)
print("output shape:", spikes_out.shape)
print("spikes per out neuron (always 1):", int(jnp.sum(spikes_out, axis=0)[0, 0]))
input shape: (32, 4, 8)
output shape: (32, 4, 6)
spikes per out neuron (always 1): 1
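The four steps SpikingPhasor performs, and the stacking idea, can be sketched in plain JAX. This is an illustration under assumptions, not spyx code. The helper names are hypothetical, the weights are random complex arrays rather than trained PhasorLinear kernels, and the encoder uses floor binning. The decoder takes the argmax over time instead of a centroid; the two agree only when each neuron fires exactly once per cycle. Step 3 reduces to keeping arg(z), because the encoder only ever looks at the phase.

```python
import jax
import jax.numpy as jnp

T = 32  # cycle length in time bins

def encode_phase(theta):
    # Hypothetical phase -> one-spike-per-cycle encoder (floor binning), shape [T, ...].
    t = jnp.clip(jnp.floor((theta + jnp.pi) / (2.0 * jnp.pi) * T).astype(jnp.int32), 0, T - 1)
    return jnp.moveaxis(jax.nn.one_hot(t, T), -1, 0)

def decode_phase(spikes):
    # With exactly one spike per neuron, the spike-time centroid is just the argmax.
    return 2.0 * jnp.pi * jnp.argmax(spikes, axis=0) / T - jnp.pi

def spiking_phasor_layer(spikes_in, w):
    theta = decode_phase(spikes_in)    # 1. spike train -> phases          [B, C_in]
    z = jnp.exp(1j * theta) @ w        # 2. complex matmul                 [B, C_out]
    theta_out = jnp.angle(z)           # 3. unit-circle projection keeps only arg(z)
    return encode_phase(theta_out)     # 4. phases -> spike train          [T, B, C_out]

k = jax.random.split(jax.random.PRNGKey(0), 5)
w1 = jax.random.normal(k[0], (8, 16)) + 1j * jax.random.normal(k[1], (8, 16))
w2 = jax.random.normal(k[2], (16, 6)) + 1j * jax.random.normal(k[3], (16, 6))
theta_in = jax.random.uniform(k[4], (4, 8), minval=-jnp.pi, maxval=jnp.pi)

spikes_in = encode_phase(theta_in)                 # [32, 4, 8]
hidden = spiking_phasor_layer(spikes_in, w1)       # [32, 4, 16]
spikes_out = spiking_phasor_layer(hidden, w2)      # [32, 4, 6]
```

Each layer is a pure function of its input spike train. This matches the point above that a phasor SNN carries no per-timestep state between layers.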
Where to go next
- Try stacking two SpikingPhasor blocks to build a hidden-layer phasor SNN.
- Compare wall-clock time against the surrogate-gradient LIF tutorial at a matched parameter count. Phasor networks have no recurrent state to scan over, which may give them an advantage on long cycles.
- Quantize the kernel_re / kernel_im tensors through spyx.quant.quantize (it treats each one as a plain nnx.Linear-style parameter).
See the module docstring in src/spyx/phasor.py for the Wirtinger-gradient background and scripts/phasor_demo.py for a non-notebook version of the forward / round-trip checks.