cyreal.tutorials.rl_quickstart

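We provide utilities for interacting with `gymnax` environments. After every update to your policy parameters, write the new policy state back into the data loader's state (e.g. with `set_loader_policy_state`) so that subsequent rollouts sample actions with the updated policy.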

import gymnax
import jax
import jax.numpy as jnp

from cyreal.transforms import BatchTransform
from cyreal.loader import DataLoader
from cyreal.rl import set_loader_policy_state, set_source_policy_state
from cyreal.sources import GymnaxSource

# Build the CartPole task from gymnax's registry. `gymnax.make` returns the
# environment instance together with its default parameters, which are passed
# to every reset/step of the rollout source below.
env, env_params = gymnax.make("CartPole-v1")

def act(obs, policy_state, new_episode, key):
    """Sample an action from a linear softmax policy.

    Args:
        obs: Environment observation. It is right-multiplied by
            ``policy_state["params"]`` to obtain the action logits.
        policy_state: Dict with the policy weights under ``"params"``. It can
            also carry other state (this tutorial stores a
            ``"recurrent_state"`` entry). It is returned unchanged.
        new_episode: Episode-start flag. It is unused here; a recurrent policy
            could use it to reset the hidden state kept in ``policy_state``.
        key: JAX PRNG key used to draw the action.

    Returns:
        A ``(action, policy_state)`` pair: the sampled action index and the
        same, unmodified policy state object.
    """
    weights = policy_state["params"]
    action_logits = obs @ weights
    sampled_action = jax.random.categorical(key, logits=action_logits)
    return sampled_action, policy_state

# Initial policy state: linear-policy weights plus a placeholder recurrent
# state. The same dict is handed to GymnaxSource as `policy_state_template`.
policy_state = {
    "params": jnp.zeros((4, 2)),
    "recurrent_state": jnp.zeros((3,)),
}

# GymnaxSource will call policy_step_fn to sample actions from the environment
source = GymnaxSource(
    env=env,
    env_params=env_params,
    policy_step_fn=act,
    policy_state_template=policy_state,
    # Rollouts of length 32
    steps_per_epoch=32,
)
pipeline = [
    source,
    # Rollouts are length 32 and batches are length 16,
    # so there are two batches per epoch.
    BatchTransform(batch_size=16),
]
loader = DataLoader(pipeline)
state = loader.init_state(jax.random.key(0))
# Seed the loader state with the initial policy before the first rollout.
state = set_loader_policy_state(state, policy_state)

# Wrap the loader step with jax.jit once, outside the loop. Calling jax.jit on
# a freshly bound `loader.next` every iteration re-creates the jitted wrapper
# each step, which the JAX docs advise against.
next_batch = jax.jit(loader.next)

# Perform training
for epoch in range(2):
    for _ in range(loader.steps_per_epoch):
        batch, state, mask = next_batch(state)
        # ... compute a loss from `batch` (using `mask`) and update params ...
        new_params = jnp.ones((4, 2))
        # Rebind to a new dict instead of calling `policy_state.update(...)`:
        # the original dict object was also passed to GymnaxSource as
        # `policy_state_template`, and mutating it in place would change that
        # object too. NOTE(review): whether GymnaxSource keeps a reference to
        # the template is not visible here; the copy is safe either way.
        policy_state = {**policy_state, "params": new_params}
        # Update the rollout policy parameters after each policy update
        state = set_loader_policy_state(state, policy_state)
 1"""We provide utilities for interacting with `gymnax` environments. Take care to update the policy
 2state in the data loader (via `set_loader_policy_state`) after updating your policy parameters.
 3
 4```python
 5import gymnax
 6import jax
 7import jax.numpy as jnp
 8
 9from cyreal.transforms import BatchTransform
10from cyreal.loader import DataLoader
11from cyreal.rl import set_loader_policy_state, set_source_policy_state
12from cyreal.sources import GymnaxSource
13
14env = gymnax.environments.classic_control.cartpole.CartPole()
15env_params = env.default_params
16
17def act(obs, policy_state, new_episode, key):
18    # policy_state can hold nn parameters and recurrent states
19    # new_episode can be used to reset recurrent states
20    # within the policy_state if needed.
21    logits = obs @ policy_state["params"]
22    action = jax.random.categorical(key, logits=logits)
23    return action, policy_state
24
25policy_state = {
26    "params": jnp.zeros((4, 2)),
27    "recurrent_state": jnp.zeros((3,)),
28}
29
30# GymnaxSource will call policy_step_fn to sample actions from the environment
31source = GymnaxSource(
32    env=env,
33    env_params=env_params,
34    policy_step_fn=act,
35    policy_state_template=policy_state,
36    # Rollouts of length 32
37    steps_per_epoch=32,
38)
39pipeline = [
40    source,
41    # Rollouts are length 32, batches are length 16
42    # Two batches per epoch
43    BatchTransform(batch_size=16),
44]
45loader = DataLoader(pipeline)
46state = loader.init_state(jax.random.key(0))
47state = set_loader_policy_state(state, policy_state)
48
49# Perform training
50for epoch in range(2):
51    for _ in range(loader.steps_per_epoch):
52        batch, state, mask = jax.jit(loader.next)(state)
53        # Update the rollout policy parameters after each policy update
54        policy_state.update({"params": jnp.ones((4, 2))})
55        state = set_loader_policy_state(state, policy_state)
56```
57"""