Skip to content

spyx.data

Grain-based data pipeline. The functional encoders (rate_code, angle_code, latency_code, shift_augment) return JIT-compiled callables; the RateCode / AngleCode / LatencyCode / ShiftAugment classes are their grain.MapTransform counterparts for use inside dataset pipelines. The SHD_loader / NMNIST_loader classes require the [loaders] extra.

AngleCode

Bases: MapTransform

Grain MapTransform for angle encoding.

Source code in spyx/data.py
class AngleCode(grain.MapTransform):
    """Grain MapTransform for angle (one-hot bin) encoding.

    ``neuron_count`` evenly spaced bin edges are laid out between ``min_val``
    and ``max_val``. Each value found under ``input_key`` is replaced by a
    one-hot ``uint8`` vector (added as a trailing axis) marking the bin it
    falls into; values outside the range are clamped to the first / last bin.

    :neuron_count: number of one-hot output channels.
    :min_val: lower bound of the encoded range.
    :max_val: upper bound of the encoded range.
    :input_key: record key to read the continuous values from.
    :output_key: record key the encoded array is written to.
    """

    def __init__(
        self, neuron_count, min_val, max_val, input_key="obs", output_key="obs"
    ):
        self.input_key = input_key
        self.output_key = output_key
        self.neuron_count = neuron_count
        self.min_val = min_val
        self.max_val = max_val
        # One bin edge per output neuron.
        self.neurons = np.linspace(min_val, max_val, neuron_count)

    def map(self, record):
        """Encode ``record[input_key]`` and store it under ``output_key``.

        The record is modified in place and also returned.
        """
        last_bin = self.neuron_count - 1
        # digitize is 1-based for in-range values: shift to 0-based, then clamp.
        bin_idx = np.clip(
            np.digitize(record[self.input_key], self.neurons) - 1, 0, last_bin
        )
        one_hot_rows = np.eye(self.neuron_count, dtype=np.uint8)
        record[self.output_key] = one_hot_rows[bin_idx]
        return record

GrainLoader

A wrapper around Grain's DataLoader with a Spyx-compatible interface.

:dataset: grain MapDataset (or any len + getitem source). :batch_size: items per emitted State. :shuffle: whether to shuffle sample order. :seed: RNG seed for the sampler. :worker_count: number of Grain worker processes. None picks _default_worker_count(); 0 disables multiprocessing (useful for debugging but slow for tonic-backed sources).

Source code in spyx/data.py
class GrainLoader:
    """Thin adapter exposing a Grain ``DataLoader`` through a Spyx-style iterator.

    Iterating the loader yields one ``State(obs=..., labels=...)`` per batch.
    Batches are built with ``drop_remainder=True``, so a trailing partial
    batch is never emitted.

    :dataset: grain ``MapDataset`` (or any ``len + getitem`` source).
    :batch_size: items per emitted ``State``.
    :shuffle: whether to shuffle sample order.
    :seed: RNG seed for the sampler.
    :worker_count: number of Grain worker processes. ``None`` picks
        ``_default_worker_count()``; ``0`` disables multiprocessing
        (useful for debugging but slow for tonic-backed sources).
    """

    def __init__(self, dataset, batch_size, shuffle, seed=0, worker_count=None):
        workers = _default_worker_count() if worker_count is None else worker_count

        # Unsharded sampler over the whole dataset.
        index_sampler = grain.IndexSampler(
            num_records=len(dataset),
            shuffle=shuffle,
            seed=seed,
            shard_options=grain.NoSharding(),
        )
        batch_op = grain.Batch(batch_size, drop_remainder=True)

        self.data_loader = grain.DataLoader(
            data_source=dataset,
            sampler=index_sampler,
            worker_count=workers,
            operations=[batch_op],
        )

    def __iter__(self):
        """Yield each Grain batch repackaged as a ``State``."""
        for grain_batch in self.data_loader:
            observations = grain_batch["obs"]
            targets = grain_batch["labels"]
            yield State(obs=observations, labels=targets)

LatencyCode

Bases: MapTransform

Grain MapTransform for time-to-first-spike (latency) encoding.

Counterpart to :func:latency_code, wrapped in the Grain op interface so it can slot into an existing SHD_loader-style pipeline.

Source code in spyx/data.py
class LatencyCode(grain.MapTransform):
    """Grain MapTransform for time-to-first-spike (latency) encoding.

    NumPy counterpart to :func:`latency_code`, exposed through the Grain op
    interface so it can be dropped into a dataset pipeline. Unlike the
    functional encoder, the emitted spike train is bit-packed along the
    time axis (axis 0) with ``np.packbits``.

    :sample_T: number of time bins in the (unpacked) spike train.
    :threshold: values ``<= threshold`` (after clipping to ``[0, 1]``) stay silent.
    :input_key: record key to read the input values from.
    :output_key: record key the packed spike train is written to.
    """

    def __init__(self, sample_T, threshold=0.01, input_key="obs", output_key="obs"):
        self.input_key = input_key
        self.output_key = output_key
        self.sample_T = sample_T
        self.threshold = threshold

    def map(self, record):
        """Encode ``record[input_key]``; the record is updated in place and returned."""
        n_steps = self.sample_T
        values = np.clip(
            np.asarray(record[self.input_key], dtype=np.float32), 0.0, 1.0
        )
        # Larger values fire earlier: 1.0 -> bin 0, 0.0 -> last bin.
        fire_at = np.round((1.0 - values) * (n_steps - 1)).astype(np.int32)
        train = np.zeros((n_steps,) + values.shape, dtype=np.uint8)
        positions = np.indices(values.shape)
        below_threshold = values <= self.threshold
        # Write 1 at each unit's firing bin, or 0 if the unit is silent.
        fires = np.where(below_threshold, 0, 1).astype(np.uint8)
        train[(fire_at, *positions)] = fires
        record[self.output_key] = np.packbits(train, axis=0)
        return record

NMNIST_loader

Dataloader for the Neuromorphic MNIST dataset using Google Grain and Tonic.

:worker_count: number of Grain worker processes. None picks a sensible default (half your CPU cores, capped at 4); 0 disables multiprocessing. Passing a positive integer usually cuts first-batch latency significantly when tonic is decoding samples.

Source code in spyx/data.py
class NMNIST_loader:
    """Neuromorphic MNIST dataloader built on Google Grain and Tonic.

    Events are binned into ``sample_T`` frames with tonic's ``ToFrame`` and
    served through two :class:`GrainLoader` instances (shuffled train,
    unshuffled test).

    :batch_size: samples per emitted batch.
    :sample_T: number of time bins each sample is framed into.
    :key: seed forwarded to both loaders' samplers.
    :download_dir: directory handed to ``tonic.datasets.NMNIST``.
    :worker_count: number of Grain worker processes. ``None`` picks a sensible
        default (half your CPU cores, capped at 4); ``0`` disables
        multiprocessing. Passing a positive integer usually cuts first-batch
        latency significantly when tonic is decoding samples.
    :raises ImportError: if the optional ``[loaders]`` dependencies are missing.
    """

    def __init__(
        self,
        batch_size=32,
        sample_T=40,
        key=0,
        download_dir="./data",
        worker_count=None,
    ):
        if not tonic_installed:
            raise ImportError(
                "Please install the optional dependencies by running 'pip install spyx[loaders]' to use this feature."
            )

        self.obs_shape = (2, 34, 34)
        self.act_shape = (10,)
        self.batch_size = batch_size
        self.sample_T = sample_T

        to_frames = transforms.Compose(
            [transforms.ToFrame(sensor_size=(34, 34, 2), n_time_bins=sample_T)]
        )

        # (train, test) order throughout; the flag doubles as ``train=`` for
        # the tonic dataset and ``shuffle=`` for the loader.
        split_flags = (True, False)
        tonic_sets = [
            datasets.NMNIST(download_dir, train=is_train, transform=to_frames)
            for is_train in split_flags
        ]
        map_sets = [SpyxMapDataset(TonicSource(ds)) for ds in tonic_sets]
        self._train_dl, self._test_dl = [
            GrainLoader(
                mds, batch_size, shuffle=is_train, seed=key, worker_count=worker_count
            )
            for mds, is_train in zip(map_sets, split_flags)
        ]

    def train_epoch(self):
        """Return a fresh iterator over the shuffled training loader."""
        return iter(self._train_dl)

    def test_epoch(self):
        """Return a fresh iterator over the unshuffled test loader."""
        return iter(self._test_dl)

RateCode

Bases: MapTransform

Grain MapTransform for rate encoding.

Source code in spyx/data.py
class RateCode(grain.MapTransform):
    """Grain MapTransform for rate (Bernoulli) encoding.

    Each input value, assumed to be scaled to ``[0, 1]``, is turned into
    ``sample_T`` independent spike draws with firing probability
    ``value * max_r``. The resulting spike train is bit-packed along the
    time axis (axis 0). Randomness comes from NumPy's global ``np.random``
    state.

    :sample_T: number of time steps to sample.
    :max_r: firing probability assigned to an input value of 1.
    :input_key: record key to read the input values from.
    :output_key: record key the packed spike train is written to.
    """

    def __init__(self, sample_T, max_r=0.75, input_key="obs", output_key="obs"):
        self.input_key = input_key
        self.output_key = output_key
        self.sample_T = sample_T
        self.max_r = max_r

    def map(self, record):
        """Encode ``record[input_key]``; the record is updated in place and returned."""
        values = record[self.input_key]
        # NumPy (not JAX) is used because this runs inside the Grain pipeline.
        draws = np.random.rand(self.sample_T, *values.shape)
        fire_prob = values * self.max_r
        fired = (draws < fire_prob).astype(np.uint8)
        record[self.output_key] = np.packbits(fired, axis=0)
        return record

SHD_loader

Dataloader for the Spiking Heidelberg Dataset using Google Grain and Tonic.

Notes:

  • Tonic's stock HSD.__getitem__ reads spikes/times as float32 and doesn't filter non-finite values, which triggers noisy NumPy warnings and, for a handful of samples, produces garbage event streams that make ToFrame allocate enormous frames. spyx.data patches HSD at import time (see :func:_patch_tonic_hsd) to cast to float64 and drop non-finite timestamps. The patch is a monkey-patch pending an upstream fix in tonic.

:worker_count: number of Grain worker processes. None picks a sensible default (half your CPU cores, capped at 4); 0 disables multiprocessing. For the default batch_size=256 on a laptop, setting worker_count=4 cuts first-batch latency from ~30s to a few seconds.

Source code in spyx/data.py
class SHD_loader:
    """Spiking Heidelberg Digits dataloader built on Google Grain and Tonic.

    Notes:

    * Tonic's stock ``HSD.__getitem__`` reads ``spikes/times`` as
      float32 and doesn't filter non-finite values, which triggers
      noisy NumPy warnings and, for a handful of samples, produces
      garbage event streams that make ``ToFrame`` allocate enormous
      frames. ``spyx.data`` patches ``HSD`` at import time (see
      :func:`_patch_tonic_hsd`) to cast to float64 and drop non-finite
      timestamps. The patch is a monkey-patch pending an upstream fix
      in tonic.

    :batch_size: samples per emitted batch.
    :sample_T: number of time bins each sample is rasterised into.
    :channels: number of input channels after spatial downsampling.
    :key: seed forwarded to both loaders' samplers.
    :download_dir: directory handed to ``tonic.datasets.SHD``.
    :worker_count: number of Grain worker processes. ``None`` picks a
        sensible default (half your CPU cores, capped at 4); ``0``
        disables multiprocessing. For the default ``batch_size=256``
        on a laptop, setting ``worker_count=4`` cuts first-batch
        latency from ~30s to a few seconds.
    :raises ImportError: if the optional ``[loaders]`` dependencies are missing.
    """

    def __init__(
        self,
        batch_size=256,
        sample_T=128,
        channels=128,
        key=0,
        download_dir="./data",
        worker_count=None,
    ):
        if not tonic_installed:
            raise ImportError(
                "Please install the optional dependencies by running 'pip install spyx[loaders]' to use this feature."
            )

        self.batch_size = batch_size
        self.sample_T = sample_T
        self.obs_shape = (channels,)
        self.act_shape = (20,)

        # Why a custom raster instead of tonic's ``ToFrame(n_time_bins=T)``:
        # it is faster per sample, and it avoids the integer-floor-division
        # bug in ``SliceByTimeBins`` that silently zeroes every frame once
        # ``Downsample(time_factor=...)`` has compressed timestamps into
        # ``[0, sample_T]`` (see ``_SHD2Raster`` and tonic issue
        # https://github.com/neuromorphs/tonic/issues/313 ).
        # Why ``Downsample`` runs first: the raster allocates a
        # ``(t_max+1, channels)`` zero tensor per sample, so leaving
        # timestamps in microseconds would mean a ~1e6-row tensor per
        # sample and blow out memory.
        shd_timestep = 1e-6
        net_dt = 1 / sample_T
        downsample = transforms.Downsample(
            time_factor=shd_timestep / net_dt,
            spatial_factor=channels / 700,
        )
        rasterise = _SHD2Raster(encoding_dim=channels, sample_T=sample_T)
        transform = transforms.Compose([downsample, rasterise])

        # The raw tonic datasets are kept for the bulk ``prestage`` path.
        self._train_tonic = datasets.SHD(download_dir, train=True, transform=transform)
        self._test_tonic = datasets.SHD(download_dir, train=False, transform=transform)

        train_mds = SpyxMapDataset(TonicSource(self._train_tonic))
        test_mds = SpyxMapDataset(TonicSource(self._test_tonic))

        loader_kwargs = {"seed": key, "worker_count": worker_count}
        self._train_dl = GrainLoader(train_mds, batch_size, shuffle=True, **loader_kwargs)
        self._test_dl = GrainLoader(test_mds, batch_size, shuffle=False, **loader_kwargs)

    def prestage(self, split: str = "train"):
        """Bulk-load a split into a single on-device array, fast and torch-free.

        Walks the underlying tonic dataset in-process (no grain workers,
        no PyTorch DataLoader) and rasterises each sample via
        :class:`_SHD2Raster`. For the 8 k-sample SHD train split this runs
        in a handful of seconds, vs. tens of seconds through grain's
        streaming pipeline which pays per-iterator spinup + inter-process
        shared-memory overhead. Matches the "entire dataset lives in vRAM"
        pattern the Spyx paper relied on for throughput.

        :split: ``"train"`` or ``"test"``.
        :return: ``(obs_NBTC, labels_NB)`` — ``obs`` is
            ``uint8[n_batches, batch_size, T_packed, channels]`` with
            time packed along axis 2 (axis 1 is ``batch_size``), ``labels`` is
            ``int[n_batches, batch_size]``. Trailing partial batch is
            dropped so the training loop can scan over a fixed ``N``.
        :raises ValueError: if ``split`` is neither ``"train"`` nor ``"test"``.
        """
        if split == "train":
            source = self._train_tonic
        elif split == "test":
            source = self._test_tonic
        else:
            raise ValueError(f"split must be 'train' or 'test'; got {split!r}")

        n_samples = len(source)
        n_channels = self.obs_shape[0]
        # Eight time bins fit in one packed byte; round up.
        packed_T = (self.sample_T + 7) // 8

        # Preallocate, then fill sample by sample. Each raster is packed
        # along its time axis (axis 0) to match the rest of the pipeline.
        obs_buf = np.empty((n_samples, packed_T, n_channels), dtype=np.uint8)
        label_buf = np.empty((n_samples,), dtype=np.int64)
        for idx in range(n_samples):
            raster, target = source[idx]
            obs_buf[idx] = np.packbits(raster, axis=0)
            label_buf[idx] = int(target)

        # Keep only whole batches.
        per_batch = self.batch_size
        n_batches = n_samples // per_batch
        usable = n_batches * per_batch
        obs_NBTC = jnp.asarray(
            obs_buf[:usable].reshape(n_batches, per_batch, packed_T, n_channels)
        )
        labels_NB = jnp.asarray(label_buf[:usable].reshape(n_batches, per_batch))
        return obs_NBTC, labels_NB

    def train_epoch(self):
        """Return a fresh iterator over the shuffled training loader."""
        return iter(self._train_dl)

    def test_epoch(self):
        """Return a fresh iterator over the unshuffled test loader."""
        return iter(self._test_dl)

prestage(split='train')

Bulk-load a split into a single on-device array, fast and torch-free.

Walks the underlying tonic dataset in-process (no grain workers, no PyTorch DataLoader) and rasterises each sample via :class:_SHD2Raster. For the 8 k-sample SHD train split this runs in a handful of seconds, vs. tens of seconds through grain's streaming pipeline which pays per-iterator spinup + inter-process shared-memory overhead. Matches the "entire dataset lives in vRAM" pattern the Spyx paper relied on for throughput.

:split: "train" or "test". :return: (obs_NBTC, labels_NB) — obs is uint8[n_batches, batch_size, T_packed, channels] with time packed along axis 2 (axis 1 is batch_size), labels is int[n_batches, batch_size]. Trailing partial batch is dropped so the training loop can scan over a fixed N.

Source code in spyx/data.py
def prestage(self, split: str = "train"):
    """Bulk-load a split into a single on-device array, fast and torch-free.

    Walks the underlying tonic dataset in-process (no grain workers,
    no PyTorch DataLoader) and rasterises each sample via
    :class:`_SHD2Raster`. For the 8 k-sample SHD train split this runs
    in a handful of seconds, vs. tens of seconds through grain's
    streaming pipeline which pays per-iterator spinup + inter-process
    shared-memory overhead. Matches the "entire dataset lives in vRAM"
    pattern the Spyx paper relied on for throughput.

    :split: ``"train"`` or ``"test"``.
    :return: ``(obs_NBTC, labels_NB)`` — ``obs`` is
        ``uint8[n_batches, batch_size, T_packed, channels]`` with
        time packed along axis 2 (axis 1 is ``batch_size``), ``labels`` is
        ``int[n_batches, batch_size]``. Trailing partial batch is
        dropped so the training loop can scan over a fixed ``N``.
    :raises ValueError: if ``split`` is neither ``"train"`` nor ``"test"``.
    """
    if split == "train":
        source = self._train_tonic
    elif split == "test":
        source = self._test_tonic
    else:
        raise ValueError(f"split must be 'train' or 'test'; got {split!r}")

    n_samples = len(source)
    n_channels = self.obs_shape[0]
    # Eight time bins fit in one packed byte; round up.
    packed_T = (self.sample_T + 7) // 8

    # Preallocate, then fill sample by sample. Each raster is packed
    # along its time axis (axis 0) to match the rest of the pipeline.
    obs_buf = np.empty((n_samples, packed_T, n_channels), dtype=np.uint8)
    label_buf = np.empty((n_samples,), dtype=np.int64)
    for idx in range(n_samples):
        raster, target = source[idx]
        obs_buf[idx] = np.packbits(raster, axis=0)
        label_buf[idx] = int(target)

    # Keep only whole batches.
    per_batch = self.batch_size
    n_batches = n_samples // per_batch
    usable = n_batches * per_batch
    obs_NBTC = jnp.asarray(
        obs_buf[:usable].reshape(n_batches, per_batch, packed_T, n_channels)
    )
    labels_NB = jnp.asarray(label_buf[:usable].reshape(n_batches, per_batch))
    return obs_NBTC, labels_NB

ShiftAugment

Bases: MapTransform

Grain MapTransform for random shift augmentation.

Source code in spyx/data.py
class ShiftAugment(grain.MapTransform):
    """Grain MapTransform for random shift (roll) augmentation.

    For every axis listed in ``axes`` an integer shift is drawn uniformly
    from the inclusive range ``[-max_shift, max_shift]`` and the array under
    ``input_key`` is rolled (with wrap-around) by that amount. The result is
    written back under the same key. Randomness comes from NumPy's global
    ``np.random`` state.

    :max_shift: largest shift magnitude in either direction; ``0`` leaves
        the data unchanged.
    :axes: the data axis or axes along which the input is randomly shifted.
    :input_key: record key that is read and overwritten.
    """

    def __init__(self, max_shift=10, axes=(-1,), input_key="obs"):
        self.max_shift = max_shift
        self.axes = axes
        self.input_key = input_key

    def map(self, record):
        """Roll ``record[input_key]``; the record is updated in place and returned."""
        data = record[self.input_key]
        # ``randint`` excludes its upper bound, so use ``max_shift + 1`` to
        # make ``+max_shift`` reachable and to keep ``max_shift=0`` valid
        # (``randint(0, 0)`` would raise ``ValueError``).
        shift = np.random.randint(
            -self.max_shift, self.max_shift + 1, size=len(self.axes)
        )
        record[self.input_key] = np.roll(data, shift, axis=self.axes)
        return record

angle_code(neuron_count, min_val, max_val)

Higher-order function which returns an angle encoding function; given a continuous value, an angle converter generates a one-hot vector corresponding to where the value falls between a specified minimum and maximum. To achieve non-linear discretization, apply a function to the continuous value before feeding it into the encoder.

:neuron_count: The number of output channels for the angle encoder :min_val: A lower bound on the continuous input channel :max_val: An upper bound on the continuous input channel.

Source code in spyx/data.py
def angle_code(neuron_count, min_val, max_val):
    """
    Higher-order function which returns an angle encoding function; given a continuous value, the encoder produces a one-hot vector marking where the value falls between a specified minimum and maximum.
    Values outside the range are clamped to the first / last channel.
    To achieve non-linear discretization, apply a function to the continuous value before feeding it into the encoder.

    :neuron_count: The number of output channels for the angle encoder
    :min_val: A lower bound on the continuous input channel
    :max_val: An upper bound on the continuous input channel.
    :return: JIT-compiled function mapping ``obs`` to a one-hot array with a trailing axis of size ``neuron_count``.
    """
    bin_edges = jnp.linspace(min_val, max_val, neuron_count)
    last_bin = neuron_count - 1

    def _encode(obs):
        # digitize is 1-based for in-range values: shift to 0-based, then clamp.
        bin_idx = jnp.clip(jnp.digitize(obs, bin_edges) - 1, 0, last_bin)
        return jax.nn.one_hot(bin_idx, neuron_count)

    return jax.jit(_encode)

latency_code(num_steps, threshold=0.01)

Time-to-first-spike (latency) encoding.

Large input values fire earlier in the cycle, small values fire later. Concretely, an input in [0, 1] is mapped to a spike time t = round((1 - x) * (num_steps - 1)) and a one-hot spike train is emitted along a new leading time axis. Inputs below threshold never fire, producing all-zero rows.

The encoding preserves total information in a single spike per neuron, which is both far sparser than rate coding and matches the time-to-first-spike training scheme used in the neuromorphic hardware literature.

:param num_steps: length of the emitted spike train (time axis). :param threshold: values <= threshold are considered silent. :return: JIT-compiled function mapping data: [..., C] (values in [0, 1]) to spikes: [num_steps, ..., C] of dtype uint8.

Source code in spyx/data.py
def latency_code(num_steps, threshold=0.01):
    """Time-to-first-spike (latency) encoding.

    Larger inputs fire earlier, smaller inputs fire later: a value ``x``
    (clipped to ``[0, 1]``) spikes once at time
    ``t = round((1 - x) * (num_steps - 1))``, giving a one-hot spike train
    along a new leading time axis. Inputs at or below ``threshold`` never
    fire and yield an all-zero train.

    Each neuron carries its information in a single spike, which is far
    sparser than rate coding and matches the time-to-first-spike training
    scheme used in the neuromorphic hardware literature.

    :param num_steps: length of the emitted spike train (time axis).
    :param threshold: values ``<= threshold`` are considered silent.
    :return: JIT-compiled function mapping ``data: [..., C]`` (values in
        ``[0, 1]``) to ``spikes: [num_steps, ..., C]`` of dtype ``uint8``.
    """

    def _encode(data):
        values = jnp.clip(jnp.asarray(data, dtype=jnp.float32), 0.0, 1.0)
        fire_at = jnp.round((1.0 - values) * (num_steps - 1)).astype(jnp.int32)
        # one_hot appends the time axis last; bring it to the front -> [T, ..., C].
        train = jnp.moveaxis(
            jax.nn.one_hot(fire_at, num_steps, dtype=jnp.uint8), -1, 0
        )
        # 1 where the unit may fire, 0 where it is silenced by the threshold.
        audible = jnp.logical_not(values <= threshold).astype(jnp.uint8)
        return train * audible

    return jax.jit(_encode)

rate_code(num_steps, max_r=0.75)

Unrolls input data along axis 1 and converts to rate encoded spikes; the probability of spiking is based on the input value multiplied by a max rate, with each time step being a sample drawn from a Bernoulli distribution. Currently assumes input values have been rescaled to between 0 and 1.

Source code in spyx/data.py
def rate_code(num_steps, max_r=0.75):
    """
    Builds a rate encoder: input data is repeated ``num_steps`` times along axis 1 and each entry becomes a Bernoulli sample whose firing probability is the input value multiplied by ``max_r``.
    Currently assumes input values have been rescaled to between 0 and 1.

    :num_steps: number of repeats (time steps) along axis 1.
    :max_r: firing probability assigned to an input value of 1.
    :return: JIT-compiled function ``(data, key) -> uint8 spikes``, where ``key`` is a JAX PRNG key.
    """

    def _encode(data, key):
        # Inputs are cast to float16 before being unrolled over time.
        unrolled = jnp.repeat(jnp.array(data, dtype=jnp.float16), num_steps, axis=1)
        fired = jax.random.bernoulli(key, unrolled * max_r)
        return fired.astype(jnp.uint8)

    return jax.jit(_encode)

shift_augment(max_shift=10, axes=(-1,))

Shift data augmentation tool. Rolls data along specified axes randomly up to a certain amount.

:max_shift: maximum to which values can be shifted. :axes: the data axis or axes along which the input will be randomly shifted.

Source code in spyx/data.py
def shift_augment(max_shift=10, axes=(-1,)):
    """Shift data augmentation tool. Rolls data along specified axes randomly up to a certain amount.

    One integer shift per axis is drawn uniformly from the inclusive range
    ``[-max_shift, max_shift]`` and the data is rolled (with wrap-around)
    by that amount.

    :max_shift: maximum to which values can be shifted, in either direction.
    :axes: the data axis or axes along which the input will be randomly shifted.
    :return: JIT-compiled function ``(data, rng) -> shifted data``, where
        ``rng`` is a JAX PRNG key.
    """

    def _shift(data, rng):
        # ``jax.random.randint`` excludes ``maxval``; add 1 so that a shift
        # of ``+max_shift`` can actually be drawn.
        shift = jax.random.randint(rng, (len(axes),), -max_shift, max_shift + 1)
        return jnp.roll(data, shift, axes)

    return jax.jit(_shift)

shuffler(dataset, batch_size)

Higher-order-function which builds a shuffle function for a dataset.

:dataset: jnp.array [# samples, time, channels...] :batch_size: desired batch size.

Source code in spyx/data.py
def shuffler(dataset, batch_size):
    """
    Higher-order function which builds a shuffle-and-batch function for a dataset.

    The ``dataset`` given here is only inspected for its shapes, which fix the
    number of samples kept (a whole number of batches) and the batched output
    shape of the returned function.

    :dataset: tuple ``(x, y)`` of jnp.arrays; ``x`` is [# samples, time, channels...] and ``y`` is [# samples]
    :batch_size: desired batch size.
    :return: JIT-compiled function ``(dataset, shuffle_rng) -> (obs, labels)``.
    """
    obs_ref, labels_ref = dataset
    n_kept = (labels_ref.shape[0] // batch_size) * batch_size
    batched_obs_shape = (-1, batch_size) + obs_ref.shape[1:]
    batched_label_shape = (-1, batch_size)

    def _shuffle(dataset, shuffle_rng):
        """
        Randomly permute the samples of a dataset and split them into batches.

        :dataset: tuple of jnp.arrays with shape [# samples, time, ...] and [# samples]
        :shuffle_rng: JAX.random.PRNGKey
        :return: tuple of jnp.arrays with shape [# batches, batch size, time, ...] and [# batches, batch size]
        """
        x, y = dataset

        # Permute all samples, then drop the tail that would not fill a batch.
        order = jax.random.permutation(shuffle_rng, y.shape[0])
        kept = order[:n_kept]

        return (
            jnp.reshape(x[kept], batched_obs_shape),
            jnp.reshape(y[kept], batched_label_shape),
        )

    return jax.jit(_shuffle)