cyreal.tutorials.rl_quickstart
We provide utilities for interacting with `gymnax` environments. Whenever you update your policy
parameters, also push the new policy state into the dataloader (via `set_loader_policy_state`) so that subsequent rollouts use the updated policy.
import gymnax
import jax
import jax.numpy as jnp
from cyreal.transforms import BatchTransform
from cyreal.loader import DataLoader
from cyreal.rl import set_loader_policy_state, set_source_policy_state
from cyreal.sources import GymnaxSource
# Classic-control CartPole environment from gymnax, together with its
# default environment parameters (both are handed to GymnaxSource below).
env = gymnax.environments.classic_control.cartpole.CartPole()
env_params = env.default_params
def act(obs, policy_state, new_episode, key):
    """Sample an action from a linear softmax policy.

    GymnaxSource calls this as its ``policy_step_fn`` to choose actions while
    collecting rollouts.

    Args:
        obs: Environment observation. Its last axis must match the first axis
            of ``policy_state["params"]`` (4 for the params in this tutorial).
        policy_state: Pytree holding the policy state. It can hold nn
            parameters and recurrent states; only ``"params"`` is read here.
        new_episode: Episode-start flag. Unused by this stateless policy, but a
            recurrent policy can use it to reset recurrent states stored within
            ``policy_state`` if needed.
        key: JAX PRNG key used to sample the action.

    Returns:
        ``(action, policy_state)``: the sampled action index and the policy
        state, returned unchanged.
    """
    # Linear policy: one unnormalized log-probability (logit) per action.
    logits = obs @ policy_state["params"]
    # categorical samples an index with probability softmax(logits).
    action = jax.random.categorical(key, logits=logits)
    # No recurrent state is updated here, so policy_state is passed through as-is.
    return action, policy_state
# Initial policy state. Any pytree works; the policy above only reads "params".
policy_state = {
    # Shape (4, 2): maps a 4-dim observation to logits over 2 actions in `act`.
    "params": jnp.zeros((4, 2)),
    # Example recurrent-state slot; the linear policy in `act` does not use it.
    "recurrent_state": jnp.zeros((3,)),
}

# GymnaxSource will call policy_step_fn to sample actions from the environment.
source = GymnaxSource(
    env=env,
    env_params=env_params,
    policy_step_fn=act,
    # NOTE(review): presumably describes the structure of the policy state the
    # source will carry — confirm against GymnaxSource.
    policy_state_template=policy_state,
    # Rollouts of length 32
    steps_per_epoch=32,
)
pipeline = [
    source,
    # Rollouts are length 32, batches are length 16
    # Two batches per epoch
    BatchTransform(batch_size=16),
]
loader = DataLoader(pipeline)
state = loader.init_state(jax.random.key(0))
# Install the initial policy state in the loader before drawing any batches.
state = set_loader_policy_state(state, policy_state)

# Wrap loader.next with jax.jit once, outside the loop, so the same jitted
# callable (and its compilation cache) is reused on every step instead of
# being re-created each iteration.
next_batch = jax.jit(loader.next)

# Perform training
for epoch in range(2):
    for _ in range(loader.steps_per_epoch):
        batch, state, mask = next_batch(state)
        # ... compute a loss on `batch` and take an optimizer step here ...
        # Update the rollout policy parameters after each policy update
        policy_state.update({"params": jnp.ones((4, 2))})
        state = set_loader_policy_state(state, policy_state)
"""We provide utilities for interacting with `gymnax` environments. Take care to update the policy
state in the dataloader after updating your policy parameters.

```python
import gymnax
import jax
import jax.numpy as jnp

from cyreal.transforms import BatchTransform
from cyreal.loader import DataLoader
from cyreal.rl import set_loader_policy_state, set_source_policy_state
from cyreal.sources import GymnaxSource

env = gymnax.environments.classic_control.cartpole.CartPole()
env_params = env.default_params

def act(obs, policy_state, new_episode, key):
    # policy_state can hold nn parameters and recurrent states
    # new_episode can be used to reset recurrent states
    # within the policy_state if needed.
    logits = obs @ policy_state["params"]
    action = jax.random.categorical(key, logits=logits)
    return action, policy_state

policy_state = {
    "params": jnp.zeros((4, 2)),
    "recurrent_state": jnp.zeros((3,)),
}

# GymnaxSource will call policy_step_fn to sample actions from the environment
source = GymnaxSource(
    env=env,
    env_params=env_params,
    policy_step_fn=act,
    policy_state_template=policy_state,
    # Rollouts of length 32
    steps_per_epoch=32,
)
pipeline = [
    source,
    # Rollouts are length 32, batches are length 16
    # Two batches per epoch
    BatchTransform(batch_size=16),
]
loader = DataLoader(pipeline)
state = loader.init_state(jax.random.key(0))
state = set_loader_policy_state(state, policy_state)

# Perform training
for epoch in range(2):
    for _ in range(loader.steps_per_epoch):
        batch, state, mask = jax.jit(loader.next)(state)
        # Update the rollout policy parameters after each policy update
        policy_state.update({"params": jnp.ones((4, 2))})
        state = set_loader_policy_state(state, policy_state)
```
"""