spyx.phasor
Complex-valued phasor layers with spike-time conversion helpers. Weights are stored as paired kernel_re / kernel_im float32 parameters so a stock optax.adam loop converges (avoiding JAX's Wirtinger-conjugate gradient surprise on complex parameters).
ResonateFire (documented below) is the complex/oscillatory sibling of spyx.nn.PSU_LIF: a reset-free resonate-and-fire neuron whose complex membrane z_t = a·z_{t-1} + x_t is a linear recurrence, so it exposes both a stepwise __call__ and a parallel(x) associative-scan path with O(log T) depth. See the parallel spiking neurons explanation.
Phasor and Spiking Phasor networks for Spyx.
Implements the deep phasor architecture of Bybee, Frady & Sommer (2022, arXiv 2106.11908) on top of Flax NNX. The forward pass uses JAX's native complex dtype, so complex arithmetic and its backward pass are handled by the autodiff engine; only the parameter storage is split into real and imaginary parts (see the note below).
The two halves of a phasor pipeline:
- Continuous (training-time): complex-valued layers with phases on the unit circle. PhasorLinear computes z_out = z_in @ W + b with a complex64 kernel W; PhasorActivation projects back onto the unit circle, mimicking the threshold function of the Frady/Sommer attractor model.
- Spiking (inference-time): each phase is mapped to a single spike inside a cycle of length T. The companion helpers phase_to_spikes and spikes_to_phase make it possible to run the same trained weights on a spiking substrate via :class:`SpikingPhasor`.
This module is intentionally minimal and targets the pattern documented in
docs/examples/phasor/phasor_intro.ipynb (issue #38).
.. note::
   Parameters that enter a complex-valued forward pass are stored as
   separate kernel_re + kernel_im float32 tensors and assembled
   on each call (see :class:`PhasorLinear`). This sidesteps the JAX
   Wirtinger-conjugate-gradient surprise that affected the first iteration of
   this module, and lets you train phasor networks with a stock
   optax.adam + nnx.Optimizer loop.
PhasorActivation
Bases: Module
Project complex activations back onto the unit circle.
This is the "threshold" function of the TPAM attractor model: it discards
the magnitude and keeps only the phase. eps prevents division-by-zero
when an activation collapses to 0 + 0j (rare but possible during early
training).
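Here is a minimal sketch of this projection in plain JAX. It assumes eps is added to the magnitude, as in z / (|z| + eps), with an illustrative value of 1e-8. The exact guard and the default eps in spyx may differ.

```python
import jax.numpy as jnp

def phasor_activation(z, eps=1e-8):
    """Project complex activations onto the unit circle (sketch).

    Assumption: eps is added to the magnitude; spyx may place it differently.
    """
    return z / (jnp.abs(z) + eps)

z = jnp.array([3.0 + 4.0j, -2.0j, 0.0 + 0.0j], dtype=jnp.complex64)
u = phasor_activation(z)
# Nonzero inputs keep their phase and get magnitude ~1; the collapsed input
# stays 0 + 0j instead of turning into NaN.
```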
Source code in spyx/phasor.py
PhasorLinear
Bases: Module
Complex-valued dense layer with real/imag parameter storage.
z_out = z_in @ kernel + bias where kernel = kernel_re + i·kernel_im
is reconstructed on each forward pass from two float32 parameters.
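The forward pass can be sketched as a plain function. The name phasor_linear and its argument names are hypothetical and are not the spyx API. jax.lax.complex builds the complex64 kernel from the two float32 arrays, and the initialisation scale is only illustrative.

```python
import jax
import jax.numpy as jnp

def phasor_linear(z_in, kernel_re, kernel_im, bias_re=None, bias_im=None):
    """Hypothetical functional form of z_out = z_in @ kernel + bias."""
    kernel = jax.lax.complex(kernel_re, kernel_im)  # complex64 from two float32 arrays
    z_out = z_in @ kernel
    if bias_re is not None:
        z_out = z_out + jax.lax.complex(bias_re, bias_im)
    return z_out

k_re, k_im = jax.random.split(jax.random.PRNGKey(0))
kernel_re = 0.5 * jax.random.normal(k_re, (4, 3))
kernel_im = 0.5 * jax.random.normal(k_im, (4, 3))

theta = jnp.linspace(-jnp.pi, jnp.pi, 8).reshape(2, 4)
z_in = jnp.exp(1j * theta)                          # batch of 2, 4 unit phasors each
z_out = phasor_linear(z_in, kernel_re, kernel_im)   # shape (2, 3), complex64
z_out_b = phasor_linear(z_in, kernel_re, kernel_im, jnp.ones(3), jnp.zeros(3))
```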
Why not store kernel as a single complex64 nnx.Param?
JAX returns the conjugate Wirtinger derivative when you take
jax.grad of a real-valued loss with respect to a complex parameter.
Optax is real-arithmetic only and does not unwind the conjugation, which
caused vanilla optax.adam steps to drift sideways on the imaginary
axis in the first iteration of this module. Splitting storage into
kernel_re + kernel_im sidesteps the whole issue: the gradients
optax sees are always real, and the complex structure shows up only in
the forward pass. This matches the pattern used by the TF reference in
wilkieolin/phasor_networks.
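A small plain-JAX experiment makes the conjugation visible. It differentiates the real loss |w|^2 = x^2 + y^2 at w = 3 + 4j twice: once with w stored as a complex number, and once split into real and imaginary parts. Both loss functions are illustrative, not spyx code.

```python
import jax
import jax.numpy as jnp

def loss_complex(w):
    # Real-valued loss of a complex parameter: |w|^2 = x^2 + y^2.
    return jnp.real(w) ** 2 + jnp.imag(w) ** 2

def loss_split(w_re, w_im):
    # Same loss, with the parameter stored as two reals and assembled
    # only inside the forward pass (the PhasorLinear pattern).
    w = jax.lax.complex(w_re, w_im)
    return jnp.real(w) ** 2 + jnp.imag(w) ** 2

w0 = jnp.asarray(3.0 + 4.0j, dtype=jnp.complex64)
g_complex = jax.grad(loss_complex)(w0)  # 6 - 8j: the conjugate of the ascent direction
g_re, g_im = jax.grad(loss_split, argnums=(0, 1))(
    jnp.array(3.0), jnp.array(4.0)
)  # 6 and 8: the true partial derivatives

lr = 0.1
naive = w0 - lr * g_complex                 # imaginary part grows: 4 -> 4.8
split = (3.0 - lr * g_re, 4.0 - lr * g_im)  # both parts move toward the minimum at 0
```

A real-arithmetic update w - lr * g applied to the complex gradient therefore moves the imaginary part away from the minimum. The split gradients move both parts toward it.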
Source code in spyx/phasor.py
bias
property
Complex bias reconstructed from the real/imag storage (or None).
kernel
property
Complex kernel reconstructed from the real/imag storage.
PhasorMLP
Bases: Module
A small phasor MLP: encode -> N x (PhasorLinear -> PhasorActivation) -> readout.
Convenience constructor for the most common phasor topology.
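The topology can be sketched with plain functions. This is an illustrative stand-in for PhasorMLP, not its implementation. The encoder follows real_to_phasor's e^{i * π * x}. The eps guard, the 1/sqrt(fan_in) initialisation, the omitted biases and the layer sizes are all assumptions.

```python
import jax
import jax.numpy as jnp

def encode(x, scale=jnp.pi):
    # real_to_phasor-style encoding: each input becomes e^{i * scale * x}.
    return jnp.exp(1j * scale * x)

def linear(z, w_re, w_im):
    # Complex dense layer with the kernel assembled from float32 parts (no bias).
    return z @ jax.lax.complex(w_re, w_im)

def activate(z, eps=1e-8):
    # Back onto the unit circle (the eps guard is an assumption).
    return z / (jnp.abs(z) + eps)

def phasor_mlp(params, x):
    z = encode(x)
    for w_re, w_im in params["hidden"]:
        z = activate(linear(z, w_re, w_im))
    w_re, w_im = params["readout"]
    return jnp.real(linear(z, w_re, w_im))  # real-valued logits

def init_params(key, sizes):
    # One (w_re, w_im) pair per layer; the last pair is the readout.
    keys = jax.random.split(key, 2 * (len(sizes) - 1))
    pairs = []
    for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        scale = 1.0 / jnp.sqrt(n_in)
        w_re = jax.random.normal(keys[2 * i], (n_in, n_out)) * scale
        w_im = jax.random.normal(keys[2 * i + 1], (n_in, n_out)) * scale
        pairs.append((w_re, w_im))
    return {"hidden": pairs[:-1], "readout": pairs[-1]}

params = init_params(jax.random.PRNGKey(0), [8, 16, 16, 3])
x = jax.random.uniform(jax.random.PRNGKey(1), (5, 8))  # inputs in [0, 1)
logits = phasor_mlp(params, x)                           # shape (5, 3), float32

# Every parameter is a real float32 array, so the gradients are real too.
grads = jax.grad(lambda p: jnp.sum(phasor_mlp(p, x)))(params)
```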
Source code in spyx/phasor.py
PhasorReadout
Bases: Module
Map complex hidden states to real-valued logits.
Implementation: take the real part of a final PhasorLinear. Equivalent
to projecting each output phasor onto the cosine basis. Works as a drop-in
replacement for the final nnx.Linear of a classifier.
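The cosine-basis reading can be checked numerically. This sketch uses hypothetical function and variable names. It compares Re(z @ W + b) with the purely real expression cos(θ) @ W_re - sin(θ) @ W_im + b_re for unit phasors z = e^{iθ}. The imaginary bias never reaches the logits.

```python
import jax
import jax.numpy as jnp

def phasor_readout(z, w_re, w_im, b_re, b_im):
    # Real part of a complex dense layer: logits = Re(z @ W + b).
    return jnp.real(z @ jax.lax.complex(w_re, w_im) + jax.lax.complex(b_re, b_im))

k = jax.random.split(jax.random.PRNGKey(0), 4)
theta = jax.random.uniform(k[0], (2, 5), minval=-jnp.pi, maxval=jnp.pi)
z = jnp.exp(1j * theta)                  # unit phasors, shape (2, 5)
w_re = jax.random.normal(k[1], (5, 3))
w_im = jax.random.normal(k[2], (5, 3))
b_re = jax.random.normal(k[3], (3,))
b_im = 5.0 * jnp.ones((3,))              # deliberately large: it never reaches the logits

logits = phasor_readout(z, w_re, w_im, b_re, b_im)
# The same logits from real arithmetic on the cosine and sine components.
logits_real = jnp.cos(theta) @ w_re - jnp.sin(theta) @ w_im + b_re
```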
Source code in spyx/phasor.py
ResonateFire
Bases: Module
Resonate-and-fire neuron: the complex/oscillatory sibling of PSU_LIF.
.. note::
   Experimental. Its supported entry point is
   :class:`spyx.experimental.ResonateFire`; the API may change without a
   deprecation cycle. It is defined here for locality with the phasor layers.
A resonate-and-fire neuron carries a complex membrane that behaves as a damped harmonic oscillator. Written reset-free, its subthreshold dynamics are a complex linear recurrence
.. math:: z_t = a \, z_{t-1} + x_t , \qquad a = e^{\,\mathrm{dt}\,(-\lambda + i\,\omega)} ,
with per-unit decay :math:`\lambda \ge 0` and angular frequency
:math:`\omega`. The real input current x_t is injected into the real
part of the membrane. Because there is no reset, the recurrence stays
linear, so exactly like :class:`spyx.nn.PSU_LIF` it can be evaluated with
:func:`jax.lax.associative_scan` in :math:`O(\log T)` parallel depth. The only
difference is that the scan now runs over a complex pole a instead of a real leak.
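The scan needs an associative operator on timesteps. One way to derive it is to treat each timestep as the affine map z ↦ a z + x_t and compose two of them. The pair notation (a_s, x_s) is introduced here for illustration; with a fixed pole, a_1 = a_2 = a.

.. math:: a_2 (a_1 z + x_1) + x_2 = (a_1 a_2)\, z + (a_2 x_1 + x_2)

.. math:: (a_1, x_1) \bullet (a_2, x_2) = (a_1 a_2,\; a_2 x_1 + x_2)

Both groupings of three steps give (a_1 a_2 a_3, a_2 a_3 x_1 + a_3 x_2 + x_3), so the operator is associative. With z_{-1} = 0, the second component of the prefix product (a, x_0) • (a, x_1) • ... • (a, x_t) is exactly z_t. For two steps it is a x_0 + x_1 = z_1.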
Spikes are emitted by a pointwise surrogate threshold on the real part of
the oscillator, :math:`s_t = \sigma(\Re(z_t) - \text{threshold})`. The rule
is reset-free, so the linear recurrence, and therefore the parallel scan,
is preserved.
Stability: the pole magnitude is |a| = exp(-dt * lambda). Storing the
decay through a softplus keeps :math:`\lambda \ge 0`, hence (for dt > 0)
:math:`|a| \le 1` and the oscillation never grows.
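Here is a sketch of the pole parameterisation. It assumes the raw decay is a float32 array named raw_lambda, as in the decay property, and that dt = 1. The function name pole is hypothetical.

```python
import jax
import jax.numpy as jnp

def pole(raw_lambda, omega, dt=1.0):
    lam = jax.nn.softplus(raw_lambda)                  # effective decay, always >= 0
    return jnp.exp(dt * jax.lax.complex(-lam, omega))  # a = exp(dt * (-lambda + i * omega))

raw_lambda = jnp.array([-5.0, 0.0, 3.0])   # any real value is allowed
omega = jnp.array([0.3, 1.0, 2.5])
a = pole(raw_lambda, omega)
# |a| = exp(-dt * softplus(raw_lambda)) stays below 1; angle(a) equals dt * omega
# here because every dt * omega lies in (-pi, pi].
```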
Parameters that enter the complex pole (lambda, omega) are stored as
real float32 nnx.Param tensors, mirroring :class:`PhasorLinear`:
the complex structure appears only in the forward pass, so a stock
optax + jax.grad loop over a real loss trains them without the
Wirtinger-conjugate surprise.
Two execution modes are provided; they compute the same result up to floating-point rounding:
- :meth:`__call__` -- one reset-free timestep (x, z) -> (spikes, z) with z = a * z + x; a drop-in for :func:`spyx.nn.run` / :class:`Sequential`.
- :meth:`parallel` -- the whole time-major sequence at once via an associative scan over the complex pole, :math:`O(\log T)` depth.
Because both modes use the same pole and surrogate and integrate the input
before spiking, scanning :meth:`__call__` over x reproduces
:meth:`parallel` up to floating-point rounding (the associative scan reorders the products).
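Both modes can be reproduced in a few lines of plain JAX. The helper names in this sketch are hypothetical. A hard Heaviside stands in for the forward value of the superspike surrogate, and the surrogate's backward pass is not modelled. The sketch computes the membrane two ways, with jax.lax.associative_scan and with a stepwise jax.lax.scan. It compares them with a tolerance because the associative scan reorders the floating-point products.

```python
import jax
import jax.numpy as jnp

def combine(earlier, later):
    # Compose two affine steps z -> a*z + x (earlier first, then later).
    a_e, z_e = earlier
    a_l, z_l = later
    return a_e * a_l, a_l * z_e + z_l

def membrane_parallel(a, x):
    """All z_t = a * z_{t-1} + x_t (z_{-1} = 0) via an associative scan over time."""
    xc = x.astype(jnp.complex64)           # real current enters the real part
    a_seq = jnp.broadcast_to(a, xc.shape)  # same pole at every timestep
    _, z = jax.lax.associative_scan(combine, (a_seq, xc), axis=0)
    return z

def membrane_sequential(a, x):
    """The same recurrence, one reset-free step at a time."""
    def step(z, x_t):
        z = a * z + x_t
        return z, z
    z0 = jnp.zeros(x.shape[1:], dtype=jnp.complex64)
    _, z = jax.lax.scan(step, z0, x.astype(jnp.complex64))
    return z

def spike(z, threshold=1.0):
    # Forward value only: Heaviside on Re(z) - threshold.
    return (jnp.real(z) - threshold > 0).astype(jnp.float32)

T, B, U = 64, 2, 4
lam = jnp.array([0.05, 0.1, 0.2, 0.5])
omega = jnp.array([0.2, 0.5, 1.0, 2.0])
a = jnp.exp(jax.lax.complex(-lam, omega))  # dt = 1
x = jax.random.normal(jax.random.PRNGKey(0), (T, B, U))

z_par = membrane_parallel(a, x)
z_seq = membrane_sequential(a, x)
s_par, s_seq = spike(z_par), spike(z_seq)
```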
Source code in spyx/phasor.py
335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 | |
a
property
Complex oscillator pole a = exp(dt * (-lambda + i * omega)).
The magnitude |a| = exp(-dt * lambda) <= 1 guarantees stability.
decay
property
Effective non-negative decay lambda = softplus(raw_lambda).
__call__(x, z)
One reset-free timestep.
:x: real input current from the previous layer, broadcastable to z.
:z: complex64 membrane state.
Injects x into the real part of the membrane and advances the
complex recurrence z = a * z + x (no reset), then emits a surrogate
spike on Re(z) so that scanning this method matches :meth:`parallel`.
Source code in spyx/phasor.py
__init__(hidden_shape, lambda_init=None, omega_init=None, threshold=1.0, dt=1.0, activation=None, *, rngs)
:hidden_shape: Per-unit shape of the layer.
:lambda_init: Membrane decay >= 0. Scalar constant if provided, else
a learnable per-unit initialisation. Stored through softplus so
the effective decay is always non-negative.
:omega_init: Angular frequency of the oscillator. Scalar constant if
provided, else a learnable per-unit initialisation.
:threshold: Real firing threshold on Re(z). Defaults to 1.
:dt: Integration timestep entering the pole exp(dt * (-lambda + i * omega)).
:activation: :class:`spyx.axn.Axon` surrogate spike; defaults to
superspike.
Source code in spyx/phasor.py
initial_state(batch_size)
parallel(x)
Score a whole time-major sequence with an associative scan.
:x: real input with shape [Time, Batch, ...].
:return: spikes with shape [Time, Batch, ...].
Computes the full complex membrane trace z_t = a * z_{t-1} + x_t
(with z_{-1} = 0) via :func:`jax.lax.associative_scan` over the time
axis in :math:`O(\log T)` depth, then applies the surrogate spike
pointwise on Re(z).
Source code in spyx/phasor.py
SpikingPhasor
Bases: Module
Spiking inference wrapper around a single :class:`PhasorLinear`.
The forward pass:
- Takes a batched spike train [T, B, in_features].
- Recovers per-unit phases via :func:`spikes_to_phase`.
- Multiplies the resulting unit-magnitude phasors through PhasorLinear.
- Applies :class:`PhasorActivation` to renormalise to the unit circle.
- Re-emits a spike train [T, B, out_features] via :func:`phase_to_spikes`.
This makes a phasor layer drop-in compatible with spyx.nn.Sequential
+ spyx.nn.run for spike-domain evaluation. For training, use
PhasorLinear directly on the complex domain (much faster) and only
convert to SpikingPhasor at deployment.
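The same pipeline can be sketched end to end with self-contained helpers. Everything here is an illustrative assumption rather than the spyx implementation: the helper bodies (including wrapping the spike index modulo T), the omitted bias, and the eps guard.

```python
import jax
import jax.numpy as jnp

def phase_to_spikes(theta, T):
    # One spike per unit at tick round((theta + pi) / (2 pi) * T), wrapped mod T.
    idx = jnp.round((theta + jnp.pi) / (2 * jnp.pi) * T).astype(jnp.int32) % T
    return jnp.moveaxis(jax.nn.one_hot(idx, T, dtype=jnp.float32), -1, 0)

def spikes_to_phase(spike_train):
    # Spike-time centroid per unit, mapped back to a phase; 0 for silent units.
    T = spike_train.shape[0]
    t = jnp.arange(T, dtype=jnp.float32).reshape((T,) + (1,) * (spike_train.ndim - 1))
    count = spike_train.sum(axis=0)
    centroid = (t * spike_train).sum(axis=0) / jnp.maximum(count, 1.0)
    theta = 2.0 * jnp.pi * centroid / T - jnp.pi
    theta = jnp.where(theta <= -jnp.pi, theta + 2.0 * jnp.pi, theta)
    return jnp.where(count > 0, theta, 0.0)

def spiking_phasor(spikes_in, w_re, w_im, eps=1e-8):
    # [T, B, in] spikes -> phases -> PhasorLinear (no bias) -> PhasorActivation -> spikes.
    T = spikes_in.shape[0]
    z = jnp.exp(1j * spikes_to_phase(spikes_in))
    z = z @ jax.lax.complex(w_re, w_im)
    z = z / (jnp.abs(z) + eps)
    return phase_to_spikes(jnp.angle(z), T)  # [T, B, out]

T, B, N_IN, N_OUT = 32, 3, 5, 4
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
theta_in = jax.random.uniform(k1, (B, N_IN), minval=-jnp.pi, maxval=jnp.pi)
spikes_in = phase_to_spikes(theta_in, T)            # [T, B, N_IN]
w_re = jax.random.normal(k2, (N_IN, N_OUT))
w_im = jax.random.normal(k3, (N_IN, N_OUT))
spikes_out = spiking_phasor(spikes_in, w_re, w_im)  # [T, B, N_OUT]
```

In this sketch, the decoded output phases agree with the continuous PhasorLinear + PhasorActivation result on the decoded input phases to within half a tick (π/T). That gap is the quantisation cost of evaluating in the spike domain.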
Source code in spyx/phasor.py
phase_of(z)
phase_to_spikes(theta, T)
Convert phases to single-spike-per-cycle spike trains.
A neuron with phase θ ∈ (-π, π] fires at timestep round((θ + π) /
(2π) * T) within a cycle of T ticks. The returned tensor has the time
axis prepended.
:theta: real array of shape (...).
:T: int, number of ticks per cycle.
:return: float32 array of shape (T, ...) with exactly one 1.0 per
neuron along the time axis.
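Here is a sketch of the encoder. Note that θ = π maps to tick T, one past the end of a cycle of T ticks. This sketch assumes the index wraps modulo T so that it lands on tick 0. How spyx handles that edge is not stated here.

```python
import jax
import jax.numpy as jnp

def phase_to_spikes(theta, T):
    """Sketch: one spike per neuron at tick round((theta + pi) / (2 pi) * T).

    Assumption: the index is wrapped modulo T so that theta = pi (tick T)
    lands on tick 0 instead of falling outside the cycle.
    """
    idx = jnp.round((theta + jnp.pi) / (2 * jnp.pi) * T).astype(jnp.int32) % T
    one_hot = jax.nn.one_hot(idx, T, dtype=jnp.float32)  # shape (..., T)
    return jnp.moveaxis(one_hot, -1, 0)                  # shape (T, ...)

theta = jnp.array([-jnp.pi / 2, 0.0, jnp.pi / 2])
s = phase_to_spikes(theta, T=8)  # spikes at ticks 2, 4 and 6
```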
Source code in spyx/phasor.py
phasor_to_real(z)
Decode phasors to real values via the real component (cos of phase).
Convenient when feeding a downstream real-valued readout / loss.
real_to_phasor(x, scale=jnp.pi)
Encode real-valued inputs as unit-magnitude phasors.
Maps each scalar x to e^{i * scale * x}. With the default
scale = π and inputs in [0, 1] this fills the upper half-circle,
which keeps the encoding monotonic in x without aliasing.
:x: real array of any shape.
:scale: phase scaling; π is the natural choice for inputs in [0, 1].
:return: complex64 array, same shape as x.
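Here is a sketch of both helpers as the docstrings describe them; the astype and clip calls are choices made here. Because cos(π x) is strictly decreasing on [0, 1], the real-part decoding is monotonic too and can be inverted with arccos.

```python
import jax.numpy as jnp

def real_to_phasor(x, scale=jnp.pi):
    # e^{i * scale * x}: unit magnitude, phase proportional to x.
    return jnp.exp(1j * scale * x).astype(jnp.complex64)

def phasor_to_real(z):
    # Real component, i.e. the cosine of the phase for unit phasors.
    return jnp.real(z)

x = jnp.linspace(0.0, 1.0, 5)
z = real_to_phasor(x)  # phases 0, pi/4, pi/2, 3pi/4, pi: the upper half-circle
c = phasor_to_real(z)  # cos(pi * x), strictly decreasing on [0, 1]
x_back = jnp.arccos(jnp.clip(c, -1.0, 1.0)) / jnp.pi
```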
Source code in spyx/phasor.py
spikes_to_phase(spike_train, T=None)
Recover phases from a spike train (inverse of :func:`phase_to_spikes`).
For each unit, computes the spike-time centroid weighted by the spike
train, then maps it back to a phase in (-π, π]. If a unit emits no
spikes the centroid is undefined; we return 0 in that case.
:spike_train: shape (T, ...).
:T: cycle length; defaults to spike_train.shape[0].
:return: real array of shape (...).
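Here is a sketch of the centroid decoder. It assumes tick k maps back to 2πk/T - π, which inverts the phase_to_spikes formula, and that -π is wrapped to π so the result stays in (-π, π].

```python
import jax.numpy as jnp

def spikes_to_phase(spike_train, T=None):
    """Sketch: per-unit spike-time centroid mapped back to a phase in (-pi, pi]."""
    n_ticks = spike_train.shape[0]
    T = n_ticks if T is None else T
    shape = (n_ticks,) + (1,) * (spike_train.ndim - 1)
    t = jnp.arange(n_ticks, dtype=jnp.float32).reshape(shape)
    count = spike_train.sum(axis=0)
    centroid = (t * spike_train).sum(axis=0) / jnp.maximum(count, 1.0)
    theta = 2.0 * jnp.pi * centroid / T - jnp.pi                      # tick k -> 2*pi*k/T - pi
    theta = jnp.where(theta <= -jnp.pi, theta + 2.0 * jnp.pi, theta)  # -pi -> pi
    return jnp.where(count > 0, theta, 0.0)                           # silent units -> 0

T = 8
s = jnp.zeros((T, 5))
s = s.at[4, 0].set(1.0)                    # tick 4 -> phase 0
s = s.at[6, 1].set(1.0)                    # tick 6 -> phase pi/2
# unit 2 stays silent -> phase 0
s = s.at[0, 3].set(1.0)                    # tick 0 -> -pi, wrapped to pi
s = s.at[2, 4].set(1.0).at[4, 4].set(1.0)  # centroid 3 -> phase -pi/4
theta = spikes_to_phase(s)
```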