spyx.data
Grain-based data pipeline. The functional encoders (rate_code, angle_code, latency_code, shift_augment) return JIT-compiled callables; the RateCode / AngleCode / LatencyCode / ShiftAugment classes are their grain.MapTransform counterparts for use inside dataset pipelines. The SHD_loader / NMNIST_loader classes require the [loaders] extra.
AngleCode
Bases: MapTransform
Grain MapTransform for angle encoding.
Source code in spyx/data.py
GrainLoader
A wrapper around Grain's DataLoader with a Spyx-compatible interface.
:dataset: grain MapDataset (or any source implementing __len__ and __getitem__).
:batch_size: items per emitted State.
:shuffle: whether to shuffle sample order.
:seed: RNG seed for the sampler.
:worker_count: number of Grain worker processes. None picks
_default_worker_count(); 0 disables multiprocessing
(useful for debugging but slow for tonic-backed sources).
Source code in spyx/data.py
LatencyCode
Bases: MapTransform
Grain MapTransform for time-to-first-spike (latency) encoding.
Counterpart to :func:latency_code, wrapped in the Grain op interface so
it can slot into an existing SHD_loader-style pipeline.
Source code in spyx/data.py
NMNIST_loader
Dataloader for the Neuromorphic MNIST dataset using Google Grain and Tonic.
:worker_count: number of Grain worker processes. None picks a sensible
default (half your CPU cores, capped at 4); 0 disables
multiprocessing. Passing a positive integer usually cuts first-batch
latency significantly when tonic is decoding samples.
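The "half your CPU cores, capped at 4" default can be sketched as follows. This is only an illustration: the function name `sketch_default_worker_count`, its `n_cpus` argument, and the floor of one worker are assumptions made for the example, not the actual spyx helper.

```python
import os


def sketch_default_worker_count(n_cpus=None):
    """Hypothetical stand-in for the 'half your cores, capped at 4' rule."""
    if n_cpus is None:
        # os.cpu_count() may return None when the count cannot be determined.
        n_cpus = os.cpu_count() or 1
    # Assumption: keep at least one worker so the default never
    # silently disables multiprocessing on single-core machines.
    return max(1, min(4, n_cpus // 2))
```

Under these assumptions a 16-core machine gets 4 workers and a 4-core laptop gets 2.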
Source code in spyx/data.py
RateCode
Bases: MapTransform
Grain MapTransform for rate encoding.
Source code in spyx/data.py
SHD_loader
Dataloader for the Spiking Heidelberg Dataset using Google Grain and Tonic.
Notes:
- Tonic's stock HSD.__getitem__ reads spikes/times as float32 and doesn't filter non-finite values, which triggers noisy NumPy warnings and, for a handful of samples, produces garbage event streams that make ToFrame allocate enormous frames. spyx.data patches HSD at import time (see :func:_patch_tonic_hsd) to cast to float64 and drop non-finite timestamps. The patch is a monkey-patch pending an upstream fix in tonic.
:worker_count: number of Grain worker processes. None picks a
sensible default (half your CPU cores, capped at 4); 0
disables multiprocessing. For the default batch_size=256
on a laptop, setting worker_count=4 cuts first-batch
latency from ~30s to a few seconds.
Source code in spyx/data.py
prestage(split='train')
Bulk-load a split into a single on-device array, fast and torch-free.
Walks the underlying tonic dataset in-process (no grain workers,
no PyTorch DataLoader) and rasterises each sample via
:class:_SHD2Raster. For the 8 k-sample SHD train split this runs
in a handful of seconds, vs. tens of seconds through grain's
streaming pipeline which pays per-iterator spinup + inter-process
shared-memory overhead. Matches the "entire dataset lives in vRAM"
pattern the Spyx paper relied on for throughput.
:split: "train" or "test".
:return: (obs_NBTC, labels_NB) — obs is
uint8[n_batches, batch_size, T_packed, channels] with
time packed along axis 2 (axis 1 is batch_size), labels is
int[n_batches, batch_size]. Trailing partial batch is
dropped so the training loop can scan over a fixed N.
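The return layout can be illustrated with a small NumPy sketch. It assumes that "T_packed" means eight time steps are bit-packed into each byte (as `np.packbits` does); the function name `sketch_prestage` and the toy shapes are hypothetical and do not reproduce the actual spyx implementation.

```python
import numpy as np


def sketch_prestage(rasters, labels, batch_size):
    """rasters: uint8[N, T, C] of 0/1 spikes; labels: int[N]."""
    rasters = np.asarray(rasters, dtype=np.uint8)
    labels = np.asarray(labels)
    # Assumption: time is bit-packed, 8 steps per byte, along axis 1.
    packed = np.packbits(rasters, axis=1)  # uint8[N, ceil(T / 8), C]
    n_batches = packed.shape[0] // batch_size
    keep = n_batches * batch_size  # drop the trailing partial batch
    obs = packed[:keep].reshape((n_batches, batch_size) + packed.shape[1:])
    lab = labels[:keep].reshape(n_batches, batch_size)
    return obs, lab


# Toy data: 10 samples, 16 time steps, 4 channels.
rng = np.random.default_rng(0)
rasters = (rng.random((10, 16, 4)) < 0.2).astype(np.uint8)
labels = np.arange(10)
obs, lab = sketch_prestage(rasters, labels, batch_size=4)
```

With 10 samples and a batch size of 4, two full batches survive and the last two samples are dropped, so a training loop can scan over a fixed leading dimension.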
Source code in spyx/data.py
ShiftAugment
Bases: MapTransform
Grain MapTransform for random shift augmentation.
Source code in spyx/data.py
angle_code(neuron_count, min_val, max_val)
Higher-order function that returns an angle encoding function. Given a continuous value, the encoder generates a one-hot vector indicating where the value falls between a specified minimum and maximum. To achieve non-linear discretization, apply a function to the continuous value before feeding it into the encoder.
:neuron_count: The number of output channels for the angle encoder.
:min_val: A lower bound on the continuous input channel.
:max_val: An upper bound on the continuous input channel.
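A minimal NumPy sketch of this kind of one-hot binning is shown below. The name `sketch_angle_code`, the equal-width bins, and the clipping of out-of-range values to the first or last bin are assumptions made for illustration; the actual spyx encoder may handle bin edges differently.

```python
import numpy as np


def sketch_angle_code(neuron_count, min_val, max_val):
    """Return a function mapping a scalar to a one-hot vector of bins."""

    def encode(value):
        # Fraction of the way from min_val to max_val.
        frac = (value - min_val) / (max_val - min_val)
        # Assumption: out-of-range values are clipped to the end bins.
        idx = int(np.clip(np.floor(frac * neuron_count), 0, neuron_count - 1))
        one_hot = np.zeros(neuron_count, dtype=np.uint8)
        one_hot[idx] = 1
        return one_hot

    return encode


encode = sketch_angle_code(neuron_count=4, min_val=0.0, max_val=1.0)
```

For a non-linear discretization, transform the input first, for example by passing `np.sqrt(x)` instead of `x` to `encode`.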
Source code in spyx/data.py
latency_code(num_steps, threshold=0.01)
Time-to-first-spike (latency) encoding.
Large input values fire earlier in the cycle, small values fire later.
Concretely, an input in [0, 1] is mapped to a spike time
t = round((1 - x) * (num_steps - 1)) and a one-hot spike train is
emitted along a new leading time axis. Inputs at or below threshold never
fire, producing all-zero spike trains.
The encoding represents each value with a single spike per neuron, which is far sparser than rate coding and matches the time-to-first-spike training scheme used in the neuromorphic hardware literature.
:param num_steps: length of the emitted spike train (time axis).
:param threshold: values <= threshold are considered silent.
:return: JIT-compiled function mapping data: [..., C] (values in
[0, 1]) to spikes: [num_steps, ..., C] of dtype uint8.
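The mapping t = round((1 - x) * (num_steps - 1)) can be checked with a small NumPy sketch. This is an illustration of the formula above, not the JIT-compiled spyx function; the name `sketch_latency_encode` is hypothetical.

```python
import numpy as np


def sketch_latency_encode(data, num_steps, threshold=0.01):
    """Map data[..., C] in [0, 1] to uint8 spikes[num_steps, ..., C]."""
    data = np.asarray(data, dtype=np.float32)
    # Larger inputs get smaller spike times, so they fire earlier.
    t = np.round((1.0 - data) * (num_steps - 1)).astype(np.int64)
    # Time indices broadcast against the data along a new leading axis.
    steps = np.arange(num_steps).reshape((num_steps,) + (1,) * data.ndim)
    # One spike at step t, suppressed when the input is <= threshold.
    spikes = (steps == t[None, ...]) & (data[None, ...] > threshold)
    return spikes.astype(np.uint8)


spikes = sketch_latency_encode([1.0, 0.0, 0.5], num_steps=5)
```

Here the input 1.0 fires at step 0, the input 0.5 fires at step 2, and the input 0.0 stays silent because it does not exceed the threshold.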
Source code in spyx/data.py
rate_code(num_steps, max_r=0.75)
Unrolls input data along axis 1 and converts it to rate-encoded spikes. The probability of spiking at each time step is the input value multiplied by a maximum rate (max_r), and each time step is a sample drawn from a Bernoulli distribution. Currently assumes input values have been rescaled to between 0 and 1.
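The Bernoulli sampling can be sketched in NumPy as follows. The name `sketch_rate_code` and the `seed` argument are assumptions for the example; the real function is JIT-compiled and may take its randomness differently.

```python
import numpy as np


def sketch_rate_code(num_steps, max_r=0.75, seed=0):
    """Return a function mapping data[batch, ...] to spikes[batch, T, ...]."""
    rng = np.random.default_rng(seed)  # hypothetical RNG handling

    def encode(data):
        data = np.asarray(data, dtype=np.float32)  # values in [0, 1]
        # Per-step spike probability, with a new time axis at axis 1.
        p = np.expand_dims(data * max_r, axis=1)
        shape = data.shape[:1] + (num_steps,) + data.shape[1:]
        # One Bernoulli draw per time step: spike where uniform noise < p.
        return (rng.random(shape) < p).astype(np.uint8)

    return encode


x = np.array([[0.0, 1.0]])  # one sample, two channels
spikes = sketch_rate_code(num_steps=100)(x)
always = sketch_rate_code(num_steps=100, max_r=1.0)(x)
```

An input of 0 never spikes, while an input of 1 spikes on roughly a fraction max_r of the time steps (on every step when max_r is 1.0).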
Source code in spyx/data.py
shift_augment(max_shift=10, axes=(-1,))
Shift data augmentation tool. Randomly rolls data along the specified axes by up to max_shift positions.
:max_shift: maximum amount by which values can be shifted.
:axes: the data axis or axes along which the input will be randomly shifted.
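A rough NumPy sketch of random roll augmentation follows. The name `sketch_shift_augment`, the `seed` argument, and the choice to draw one shift per axis uniformly from [-max_shift, max_shift] are assumptions for illustration; the spyx function may sample shifts differently.

```python
import numpy as np


def sketch_shift_augment(max_shift=10, axes=(-1,), seed=0):
    """Return a function that randomly rolls data along the given axes."""
    rng = np.random.default_rng(seed)  # hypothetical RNG handling

    def augment(data):
        # Assumption: one integer shift per axis, in [-max_shift, max_shift].
        shifts = rng.integers(-max_shift, max_shift + 1, size=len(axes))
        return np.roll(data, shift=tuple(int(s) for s in shifts), axis=tuple(axes))

    return augment


data = np.arange(12).reshape(3, 4)
out = sketch_shift_augment(max_shift=2)(data)
```

Because rolling wraps values around, the output keeps the input's shape and each row contains the same values in rotated order.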
Source code in spyx/data.py
shuffler(dataset, batch_size)
Higher-order function that builds a shuffle function for a dataset.
:dataset: jnp.array [# samples, time, channels...]
:batch_size: desired batch size.