spyx.nir

Module Contents

Functions

reorder_layers(init_params, trained_params)

Restores the original key order of trained parameters, which some optimization libraries permute, before NIR export.

to_nir(spyx_pytree, input_shape, output_shape[, dt]) → nir.NIRGraph

Converts a Spyx network to a NIR graph. This function is under construction and currently supports exporting only networks without explicit recurrence or feedback.

from_nir(nir_graph, sample_batch, dt[, time_major, ...])

Converts a NIR graph to a Spyx network.

spyx.nir.reorder_layers(init_params, trained_params)

Some optimization libraries may permute the keys of the network’s PyTree. This is a problem because exporting to NIR assumes the keys are still in the order they had when the network was initialized. This function takes the initial and trained parameters and returns the trained parameters reordered to match the initial key order, ready for export.
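The reordering amounts to rebuilding the trained parameter dictionary in the initial key order. The following is a minimal sketch, assuming Haiku-style nested dicts of parameters. `reorder_like` is a hypothetical helper and is not necessarily how `spyx.nir.reorder_layers` is implemented:

```python
from typing import Any, Mapping


def reorder_like(init_params: Mapping[str, Any], trained_params: Mapping[str, Any]) -> dict:
    """Hypothetical helper: copy trained_params using init_params' key order.

    Assumes Haiku-style nested dicts such as {"linear": {"w": ..., "b": ...}}.
    """
    reordered = {}
    for key, init_value in init_params.items():
        trained_value = trained_params[key]
        if isinstance(init_value, Mapping):
            # Recurse so the per-layer parameter dicts are reordered as well.
            reordered[key] = reorder_like(init_value, trained_value)
        else:
            reordered[key] = trained_value
    return reordered


# Illustrative parameters: same keys, but the "trained" dict is permuted.
init = {"linear": {"w": 0.0, "b": 0.0}, "LIF": {"beta": 0.5}}
trained = {"LIF": {"beta": 0.9}, "linear": {"b": 0.1, "w": 1.2}}
ordered = reorder_like(init, trained)
print(list(ordered))            # ['linear', 'LIF']
print(list(ordered["linear"]))  # ['w', 'b']
```

In practice you would call the documented function, `spyx.nir.reorder_layers(init_params, trained_params)`, and pass its result to `spyx.nir.to_nir`.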

spyx.nir.to_nir(spyx_pytree, input_shape, output_shape, dt=1) → nir.NIRGraph

Converts a Spyx network to a NIR graph. This function is under construction and currently supports exporting only networks without explicit recurrence or feedback.

Return type:

nir.NIRGraph
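Export is limited to networks without feedback. One way to sanity-check a topology before exporting is to look for a cycle in its edge list. The following is a minimal sketch, assuming the connectivity is available as (source, target) node-name pairs, which is the shape NIR uses for `NIRGraph.edges`. `has_feedback` is a hypothetical helper and is not part of Spyx:

```python
from collections import defaultdict


def has_feedback(edges):
    """Hypothetical helper: return True if the directed graph has a cycle.

    `edges` is an iterable of (source, target) node-name pairs.
    """
    graph = defaultdict(list)
    nodes = set()
    for src, dst in edges:
        graph[src].append(dst)
        nodes.update((src, dst))

    WHITE, GREY, BLACK = 0, 1, 2  # unvisited, on current DFS path, finished
    color = {node: WHITE for node in nodes}

    def visit(node):
        color[node] = GREY
        for nxt in graph[node]:
            if color[nxt] == GREY:  # back edge: explicit recurrence/feedback
                return True
            if color[nxt] == WHITE and visit(nxt):
                return True
        color[node] = BLACK
        return False

    return any(color[n] == WHITE and visit(n) for n in sorted(nodes))


feedforward = [("input", "linear1"), ("linear1", "lif1"), ("lif1", "output")]
recurrent = feedforward + [("lif1", "linear1")]
print(has_feedback(feedforward))  # False
print(has_feedback(recurrent))    # True
```

A graph for which this returns True contains the kind of explicit feedback loop that the exporter does not yet handle.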

spyx.nir.from_nir(nir_graph: nir.NIRGraph, sample_batch: jax.numpy.array, dt: float, time_major: bool = False, return_all_states: bool = False)

Converts a NIR graph to a Spyx network.

Parameters:
  • nir_graph (nir.NIRGraph) – the NIR graph to convert.
  • sample_batch (jax.numpy.array) –
  • dt (float) – the simulation time step.
  • time_major (bool) –
  • return_all_states (bool) –
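The `time_major` flag determines which axis of `sample_batch` is treated as time. This sketch assumes the flag follows the usual JAX/Haiku convention: batch-major `[batch, time, ...]` when False and time-major `[time, batch, ...]` when True. Under that assumption, you can convert a batch between the two layouts by swapping its first two axes. The shapes below are illustrative only:

```python
import jax.numpy as jnp

batch, time_steps, features = 4, 100, 16

# Batch-major layout: [batch, time, features] (assumed for time_major=False).
x_batch_major = jnp.zeros((batch, time_steps, features))

# Time-major layout: [time, batch, features] (assumed for time_major=True).
x_time_major = jnp.swapaxes(x_batch_major, 0, 1)

print(x_batch_major.shape)  # (4, 100, 16)
print(x_time_major.shape)   # (100, 4, 16)
```

Under the same assumption, pass `x_time_major` as `sample_batch` together with `time_major=True`, or pass `x_batch_major` with the default `time_major=False`.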