
Training an SNN using Neuroevolution!#

This is a simple notebook showing how Spyx can be used to explore neuromorphic control. To run this example, you’ll need to install the Gymnax library (pip install gymnax).

import spyx
import spyx.nn as snn

# JAX imports
import os
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".70"  # cap GPU memory preallocation at 70%; set before JAX initializes its backend

import jax
from jax import numpy as jnp

from tqdm import tqdm

# implement our SNN in DeepMind's Haiku
import haiku as hk

# optimize the parameters using evosax
import evosax
from evosax.strategies import OpenES as ES

import gymnax

# rendering tools
import matplotlib.pyplot as plt
#%matplotlib notebook

Create Env#

rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

# Instantiate the environment & its settings.
env, env_params = gymnax.make("CartPole-v1")

# Reset the environment.
obs, state = env.reset(key_reset, env_params)

# Sample a random action.
action = env.action_space(env_params).sample(key_act)

# Perform the step transition.
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
done
/home/legion/.local/lib/python3.10/site-packages/gymnax/environments/spaces.py:38: UserWarning: Explicitly requested dtype <class 'jax.numpy.int64'> requested in astype is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  ).astype(self.dtype)
Array(False, dtype=bool)

Since the CartPole environment returns continuous observations, we need some way to map those signals to spikes. One option is to bin each angle or velocity into discrete ranges and emit a spike on the input neuron whose range contains the current value. Below we encode the cart velocity, pole angle, and pole angular velocity (the cart position is dropped), so the network receives just three spikes per time step instead of continuous values.

class binarize:
    """Map a scalar to a one-hot spike vector by binning its range."""
    def __init__(self, neuron_count, min_val, max_val):
        self.neuron_count = neuron_count
        self.min_val = min_val
        self.max_val = max_val

    def __call__(self, obs):
        # digitize returns the index of the bin containing obs;
        # one_hot turns that index into a single spike.
        digital = jnp.digitize(obs, jnp.linspace(self.min_val, self.max_val, self.neuron_count))
        return jax.nn.one_hot(digital, self.neuron_count)

class NeuromorphicCartpole:
    """Encode cart velocity, pole angle, and pole angular velocity as spikes."""
    def __init__(self, angle_neurons=16, cart_v_neurons=16, pole_w_neurons=16):
        self.angle_converter = binarize(angle_neurons, -.21, .21)
        self.v_converter = binarize(cart_v_neurons, -3.5, 3.5)
        self.w_converter = binarize(pole_w_neurons, -3.5, 3.5)

    def __call__(self, obs):
        cart_v = self.v_converter(obs[1])
        theta = self.angle_converter(obs[2])
        pole_w = self.w_converter(obs[3])

        return jnp.concatenate([cart_v, theta, pole_w])
        
adapter = NeuromorphicCartpole()
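
As a quick sanity check, we can encode one illustrative observation and confirm that the adapter emits a 48-dimensional vector containing exactly three spikes, one per encoded channel (the observation values below are made up for demonstration):

# Illustrative observation: [cart_position, cart_velocity, pole_angle, pole_angular_velocity]
example_obs = jnp.array([0.0, 1.0, 0.05, -0.5])
encoded = adapter(example_obs)
encoded.shape, jnp.sum(encoded)  # ((48,), 3.0): three one-hot groups of 16 neurons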

SNN#

Here we define a simple controller network: an LIF layer followed by a pair of non-spiking leaky-integrator (LI) neurons that work in opposition to each other.

# Push the cart left or right based on which LI neuron has the higher membrane potential.
def action_selection(li_potentials):
    return jnp.argmax(li_potentials, axis=0)
        
def controller(x, state):
    # x is the (48,) spike vector for a single time step.
    core = hk.DeepRNN([
        hk.Linear(64, with_bias=False),
        snn.LIF(64, beta=0.8),
        hk.Linear(2, with_bias=False),
        snn.LI(2)
    ])

    # Instead of unrolling the SNN over a whole sequence, we pass the network
    # state in and out manually, because the simulation and the SNN are
    # stepped in tandem.
    out, out_state = core(x, state)
    return out, out_state

key = jax.random.PRNGKey(0)
init_state = (jnp.zeros(64), jnp.zeros(2))  # initial membrane potentials for the LIF and LI layers
policy = hk.without_apply_rng(hk.transform(controller))
policy_params = policy.init(rng=key, x=adapter(obs), state=init_state)
policy.apply(policy_params, adapter(obs), init_state)
(Array([0., 0.], dtype=float32),
 (Array([ 0.16964896,  0.5141022 , -0.09112424, -0.18181202, -0.145446  ,
          0.23649067, -0.0926201 , -0.08848015,  0.1166473 ,  0.2440478 ,
          0.31869254,  0.38696027,  0.19660228, -0.03745358, -0.05304112,
          0.30566677,  0.07718392,  0.17089759, -0.213139  , -0.34491202,
         -0.2476323 , -0.04124497,  0.03419026, -0.1731914 ,  0.25182706,
         -0.29428682,  0.07899477, -0.12332907,  0.2795438 , -0.03600941,
          0.03087014, -0.3129828 ,  0.1251778 , -0.0249256 ,  0.41489142,
         -0.00283933,  0.04712863,  0.07251973, -0.17288032, -0.0770252 ,
          0.10979733, -0.00819776, -0.11719832, -0.57406926, -0.22288814,
          0.11991896, -0.25034907, -0.21484518,  0.11607233,  0.00887433,
          0.06059062, -0.19645457, -0.19393854,  0.14266899,  0.06141619,
          0.00783546,  0.1000794 , -0.25339204, -0.36541682,  0.1896359 ,
          0.17901754,  0.11905757, -0.12810999, -0.08214025], dtype=float32),
  Array([0., 0.], dtype=float32)))
adapter(obs).shape
(48,)

Evolution#

The JAX ecosystem includes evosax, a package for evolution strategies, which lets us optimize our cartpole controller via neuroevolution! In this experiment, we’ll evaluate a population of 128 networks with OpenAI-style ES (OpenES), giving each controller 32 attempts at balancing the cart. We’ll let the population evolve for 25 generations and see what happens:

# Instantiate the environment & its settings.
env, env_params = gymnax.make("CartPole-v1")

def rollout(policy_params, init_policy_state, env_params, rng_input, steps_in_episode):
    """Rollout a jitted gymnax episode with lax.scan."""
    # Reset the environment
    rng_reset, rng_episode = jax.random.split(rng_input)
    obs, env_state = env.reset(rng_reset, env_params)

    def policy_step(state_input, tmp):
        """lax.scan compatible step transition in jax env."""

        # first unpack all of our state variables.
        obs, env_state, policy_params, policy_state, prev_done, rng = state_input
        # split our RNG apart.
        rng, rng_step, rng_net = jax.random.split(rng, 3)
        # get the network activity from the controller
        activation, new_policy_state = policy.apply(policy_params, adapter(obs), policy_state)
        action = action_selection(activation)
        next_obs, next_state, reward, done, _ = env.step(
            rng_step, env_state, action, env_params
        )
        # prev_done + done stays nonzero once the episode has ended, letting us
        # mask out post-termination rewards when computing fitness.
        carry = [next_obs, next_state, policy_params, new_policy_state, prev_done + done, rng]
        return carry, [obs, action, reward, next_obs, prev_done + done]

    # Scan over episode step loop
    _, scan_out = jax.lax.scan(
        policy_step,
        [obs, env_state, policy_params, init_policy_state, 
         False, rng_episode],
        (),
        steps_in_episode
    )
    # Return the full episode trajectories; rewards are masked with the
    # accumulated done flags when fitness is computed below.
    obs, action, reward, next_obs, done = scan_out
    return obs, action, reward, next_obs, done

jit_rollout = jax.jit(rollout, static_argnums=[4]) # compile the rollout/simulation function
vector_rollout = jax.vmap(jit_rollout, (0,None,None,None,None)) # autovectorize the rollout function across our population of parameters.
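
Note the in_axes of (0, None, None, None, None): only the parameters carry a leading population axis, while the initial state, environment settings, RNG key, and episode length are broadcast. A minimal sketch of the vectorized call (the four-member population of identical copies here is purely illustrative):

# Purely illustrative: stack 4 identical copies of the parameters along a new
# leading axis to form a "population", then roll them all out in parallel.
demo_pop = jax.tree_util.tree_map(lambda p: jnp.stack([p] * 4), policy_params)
obs_b, act_b, rew_b, next_obs_b, done_b = vector_rollout(
    demo_pop, init_state, env_params, jax.random.PRNGKey(1), 500)
rew_b.shape  # (4, 500): one reward sequence per population member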
def evolution(SNN, params, epochs=25, trials=32, steps=500, key=0):

    rng = jax.random.PRNGKey(key)
    param_reshaper = evosax.ParameterReshaper(params)
    
    # Instantiate and initialize the evolution strategy
    strategy = ES(popsize=128,
                  num_dims=param_reshaper.total_params,
                  opt_name="adam")

    es_params = strategy.default_params
    es_params = es_params.replace(sigma_init=0.1, sigma_decay=0.999, sigma_limit=0.01)
    es_params = es_params.replace(opt_params=es_params.opt_params.replace(
        lrate_init=0.1, lrate_decay=0.999, lrate_limit=0.001))

    strat_state = strategy.initialize(rng, es_params)
        
    @jax.jit
    def step(rng, pop):
        rng, rng_eval = jax.random.split(rng)
        # reshape the flat genomes into network parameter pytrees
        population_params = param_reshaper.reshape(pop)
        init_policy_state = init_state

        # evaluate the whole population in parallel
        obs, action, reward, next_obs, done = \
            vector_rollout(population_params, init_policy_state, env_params, rng_eval, steps)

        # fitness: sum of rewards, masked once an episode has terminated
        total_reward = jnp.sum(reward*(1-done), axis=-1)
        return rng, total_reward, done
    
    
    # This loop (including the inner trials loop) could be refactored to be
    # fully JIT-compiled for even greater performance.
    for gen in range(epochs):

        total_reward = jnp.zeros([128])  # running fitness, one entry per population member

        rng, rng_ask = jax.random.split(rng)
        pop, strat_state = strategy.ask(rng_ask, strat_state)

        pbar = tqdm([*range(trials)])
        pbar.set_description("Epoch #{}".format(gen))
        for trials_so_far in pbar:
            rng, reward, done = step(rng, pop)
            total_reward += reward
            pbar.set_postfix(Reward=jnp.max(total_reward)/(trials_so_far+1))

        # evosax minimizes, so pass the negated mean reward as fitness
        strat_state = strategy.tell(pop, -total_reward/trials, strat_state)
            
        
    elite = param_reshaper.reshape(jnp.array([strat_state.best_member]))
    return jax.tree_util.tree_map(jnp.squeeze, elite)
elite_params = evolution(policy, policy_params)
ParameterReshaper: 3200 parameters detected for optimization.
Epoch #0: 100%|██████████| 32/32 [00:16<00:00,  1.90it/s, Reward=21.125]   
Epoch #1: 100%|██████████| 32/32 [00:14<00:00,  2.24it/s, Reward=30.90625] 
Epoch #2: 100%|██████████| 32/32 [00:12<00:00,  2.51it/s, Reward=46.96875] 
Epoch #3: 100%|██████████| 32/32 [00:12<00:00,  2.61it/s, Reward=58.96875] 
Epoch #4: 100%|██████████| 32/32 [00:12<00:00,  2.51it/s, Reward=73.9375]  
Epoch #5: 100%|██████████| 32/32 [00:12<00:00,  2.47it/s, Reward=120.96875] 
Epoch #6: 100%|██████████| 32/32 [00:13<00:00,  2.44it/s, Reward=131.65625]
Epoch #7: 100%|██████████| 32/32 [00:12<00:00,  2.52it/s, Reward=203.96875]
Epoch #8: 100%|██████████| 32/32 [00:12<00:00,  2.55it/s, Reward=217.03125]
Epoch #9: 100%|██████████| 32/32 [00:11<00:00,  2.85it/s, Reward=238.65625]
Epoch #10: 100%|██████████| 32/32 [00:11<00:00,  2.69it/s, Reward=247.28125]
Epoch #11: 100%|██████████| 32/32 [00:12<00:00,  2.51it/s, Reward=244.28125]
Epoch #12: 100%|██████████| 32/32 [00:13<00:00,  2.42it/s, Reward=288.71875]
Epoch #13: 100%|██████████| 32/32 [00:12<00:00,  2.47it/s, Reward=295.03125]
Epoch #14: 100%|██████████| 32/32 [00:12<00:00,  2.63it/s, Reward=307.3125] 
Epoch #15: 100%|██████████| 32/32 [00:11<00:00,  2.72it/s, Reward=351.1875] 
Epoch #16: 100%|██████████| 32/32 [00:11<00:00,  2.79it/s, Reward=422.6875] 
Epoch #17: 100%|██████████| 32/32 [00:11<00:00,  2.76it/s, Reward=430.25]   
Epoch #18: 100%|██████████| 32/32 [00:11<00:00,  2.73it/s, Reward=450.4375] 
Epoch #19: 100%|██████████| 32/32 [00:11<00:00,  2.68it/s, Reward=472.875]  
Epoch #20: 100%|██████████| 32/32 [00:11<00:00,  2.76it/s, Reward=472.65625]
Epoch #21: 100%|██████████| 32/32 [00:11<00:00,  2.76it/s, Reward=447.21875]
Epoch #22: 100%|██████████| 32/32 [00:11<00:00,  2.84it/s, Reward=457.4375] 
Epoch #23: 100%|██████████| 32/32 [00:12<00:00,  2.59it/s, Reward=475.3125] 
Epoch #24: 100%|██████████| 32/32 [00:13<00:00,  2.45it/s, Reward=475.3125] 
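
The matplotlib import at the top hasn’t been used yet; as a quick visualization, we can chart the per-generation best mean reward (the values below are transcribed from the progress bars above):

# Best mean reward per generation, transcribed from the training log above.
best_rewards = [21.125, 30.90625, 46.96875, 58.96875, 73.9375, 120.96875,
                131.65625, 203.96875, 217.03125, 238.65625, 247.28125,
                244.28125, 288.71875, 295.03125, 307.3125, 351.1875,
                422.6875, 430.25, 450.4375, 472.875, 472.65625, 447.21875,
                457.4375, 475.3125, 475.3125]

plt.plot(best_rewards)
plt.xlabel("Generation")
plt.ylabel("Best mean reward")
plt.show()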

Results#

After training, we can see that the population of controllers has converged to promising reward values. Let’s run the best solution again to see how it fares:

action_seq = []
state_seq, reward_seq = [], []
rng, rng_reset = jax.random.split(rng)
obs, env_state = env.reset(rng_reset, env_params)
new_policy_state = init_state
while True:
    state_seq.append(env_state)
    rng, rng_step = jax.random.split(rng, 2)
    activation, new_policy_state = policy.apply(elite_params, adapter(obs), new_policy_state)
    action = action_selection(activation)
    action_seq.append(action)
    next_obs, next_env_state, reward, done, info = env.step(
        rng_step, env_state, action, env_params
    )
    reward_seq.append(reward)
    if done:
        break
    else:
        obs = next_obs
        env_state = next_env_state

cumulative_rewards = jnp.sum(jnp.array(reward_seq))
cumulative_rewards
Array(500., dtype=float32)

It earns a perfect score on the environment! A truncated view of the action sequence is shown below; the controller settles into an alternating left/right rhythm.

action_seq
[Array(0, dtype=int32),
 Array(0, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(1, dtype=int32),
 Array(0, dtype=int32),
 Array(0, dtype=int32),
 Array(0, dtype=int32),
 Array(0, dtype=int32),
 ...
 Array(0, dtype=int32),
 Array(0, dtype=int32)]

Next Steps#

This notebook is a sampler of what’s possible for neuromorphic control research in JAX. Many more environments, including Brax, are ripe for exploration and high-throughput simulation. Spyx can also be used to record the controller’s spiking activity, which could serve as an efficiency measure when comparing SNNs against conventional neural control models.
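
As a starting point for that last idea, here is a minimal sketch of an instrumented controller. It assumes, as the hk.DeepRNN usage above implies, that snn.LIF and snn.LI follow Haiku’s RNNCore interface of (input, state) -> (output, new_state); the hidden spikes it exposes could then be accumulated over an episode as an activity/efficiency metric.

def instrumented_controller(x, state):
    # Same architecture as before, but built layer by layer so the hidden
    # LIF spikes can be returned alongside the motor output.
    lif_state, li_state = state
    hidden = hk.Linear(64, with_bias=False)(x)
    spikes, lif_state = snn.LIF(64, beta=0.8)(hidden, lif_state)
    out, li_state = snn.LI(2)(hk.Linear(2, with_bias=False)(spikes), li_state)
    # jnp.sum(spikes) per step, summed over an episode, gives a rough
    # measure of how much the network fires while solving the task.
    return (out, spikes), (lif_state, li_state)

instrumented_policy = hk.without_apply_rng(hk.transform(instrumented_controller))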