cyreal.sources
Streaming dataset sources inspired by Grain's paradigm.
1"""Streaming dataset sources inspired by Grain's paradigm.""" 2from __future__ import annotations 3 4from dataclasses import dataclass 5from typing import Any, Callable, Literal, Protocol, TypeVar 6 7import jax 8import jax.numpy as jnp 9import numpy as np 10from jax import tree_util 11from jax.experimental import io_callback 12 13PyTree = Any 14StateT = TypeVar("StateT") 15 16 17class Source(Protocol[StateT]): 18 """Interface for stateful, JIT-friendly data streams. 19 20 A source exposes fixed-length epochs (`steps_per_epoch`) and provides 21 methods for initializing internal state and iteratively producing samples 22 plus boolean masks that denote whether a sample is valid (e.g. under 23 padding). 24 """ 25 26 steps_per_epoch: int 27 """Number of items emitted per epoch.""" 28 29 def init_state(self, key: jax.Array) -> StateT: 30 """Return an initial state for the source. 31 32 Args: 33 key: Optional PRNG key used for randomized behavior such as 34 shuffling. Implementations should fall back to a default key 35 when ``None`` is provided. 36 37 Returns: 38 Backend-specific state object that must be passed to ``next``. 39 """ 40 ... 41 42 def next(self, state: StateT) -> tuple[PyTree, jax.Array, StateT]: 43 """Advance the stream and return the next value. 44 45 Args: 46 state: Previously-initialized source state. 47 48 Returns: 49 Tuple ``(value, mask, new_state)`` where ``value`` is a PyTree of 50 arrays, ``mask`` is a boolean array indicating whether the sample 51 is valid, and ``new_state`` should be provided to the next call. 52 """ 53 ... 54 55 def element_spec(self) -> PyTree: 56 """PyTree of :class:`jax.ShapeDtypeStruct` describing emitted samples.""" 57 ... 
58 59 60@jax.tree_util.register_pytree_node_class 61@dataclass 62class _ArraySourceState: 63 indices: jax.Array 64 mask: jax.Array 65 position: jax.Array 66 key: jax.Array 67 epoch: jax.Array 68 69 def tree_flatten(self): 70 return (self.indices, self.mask, self.position, self.key, self.epoch), None 71 72 @classmethod 73 def tree_unflatten(cls, aux_data, children): 74 indices, mask, position, key, epoch = children 75 return cls(indices=indices, mask=mask, position=position, key=key, epoch=epoch) 76 77 78@jax.tree_util.register_pytree_node_class 79@dataclass 80class _DiskSourceState: 81 indices: jax.Array 82 position: jax.Array 83 key: jax.Array 84 epoch: jax.Array 85 buffer: PyTree 86 buffer_pos: jax.Array 87 buffer_count: jax.Array 88 89 def tree_flatten(self): 90 buffer_leaves, buffer_def = tree_util.tree_flatten(self.buffer) 91 children = ( 92 self.indices, 93 self.position, 94 self.key, 95 self.epoch, 96 *buffer_leaves, 97 self.buffer_pos, 98 self.buffer_count, 99 ) 100 return children, buffer_def 101 102 @classmethod 103 def tree_unflatten(cls, buffer_def, children): 104 indices, position, key, epoch, *rest = children 105 buffer_leaf_count = buffer_def.num_leaves if buffer_def is not None else 0 106 buffer_leaves = rest[:buffer_leaf_count] 107 buffer = ( 108 tree_util.tree_unflatten(buffer_def, buffer_leaves) 109 if buffer_def is not None 110 else None 111 ) 112 buffer_pos, buffer_count = rest[buffer_leaf_count:] 113 return cls( 114 indices=indices, 115 position=position, 116 key=key, 117 epoch=epoch, 118 buffer=buffer, 119 buffer_pos=buffer_pos, 120 buffer_count=buffer_count, 121 ) 122 123 124@dataclass 125class ArraySource(Source[_ArraySourceState]): 126 """Sample-level stream over an in-memory PyTree of arrays. 127 128 Args: 129 data: PyTree whose leaves are arrays with a leading sample dimension. 130 ordering: Either ``"sequential"`` or ``"shuffle"``. 
131 """ 132 133 data: PyTree 134 ordering: Literal["sequential", "shuffle"] = "shuffle" 135 136 def __post_init__(self) -> None: 137 leaves, self._treedef = tree_util.tree_flatten(self.data) 138 if not leaves: 139 raise ValueError("Data tree must contain at least one array.") 140 141 first = leaves[0] 142 self._num_samples = int(first.shape[0]) 143 for leaf in leaves[1:]: 144 if leaf.shape[0] != self._num_samples: 145 raise ValueError("All leaves must share the leading dimension.") 146 if self._num_samples == 0: 147 raise ValueError("Dataset cannot be empty.") 148 149 self.steps_per_epoch = self._num_samples 150 self._mask_template = jnp.ones(self._num_samples, dtype=bool) 151 self._element_spec = tree_util.tree_map( 152 lambda leaf: jax.ShapeDtypeStruct(shape=leaf.shape[1:], dtype=leaf.dtype), 153 self.data, 154 ) 155 156 @property 157 def num_samples(self) -> int: 158 return self._num_samples 159 160 def element_spec(self) -> PyTree: 161 """Shape/dtype metadata describing samples produced by the source.""" 162 return self._element_spec 163 164 def _build_epoch_indices(self, key: jax.Array) -> tuple[jax.Array, jax.Array]: 165 base = jnp.arange(self._num_samples) 166 if self.ordering == "shuffle": 167 base = jax.random.permutation(key, base) 168 elif self.ordering != "sequential": 169 raise ValueError(f"Unknown ordering '{self.ordering}'.") 170 171 return base, self._mask_template 172 173 def init_state(self, key: jax.Array) -> _ArraySourceState: 174 """Create the initial iteration state. 175 176 Args: 177 key: Optional PRNG key. Defaults to ``jax.random.PRNGKey(0)`` when 178 omitted. 
179 """ 180 key, perm_key = jax.random.split(key) 181 indices, mask = self._build_epoch_indices(perm_key) 182 position = jnp.array(0, dtype=jnp.int32) 183 epoch = jnp.array(0, dtype=jnp.int32) 184 return _ArraySourceState(indices=indices, mask=mask, position=position, key=key, epoch=epoch) 185 186 def next(self, state: _ArraySourceState) -> tuple[PyTree, jax.Array, _ArraySourceState]: 187 """Return the next sample (with mask) and the advanced state.""" 188 index = jax.lax.dynamic_index_in_dim(state.indices, state.position, axis=0, keepdims=False) 189 mask_value = jax.lax.dynamic_index_in_dim(state.mask, state.position, axis=0, keepdims=False) 190 sample = tree_util.tree_map( 191 lambda arr: jax.lax.dynamic_index_in_dim(arr, index, axis=0, keepdims=False), 192 self.data, 193 ) 194 195 next_position = state.position + 1 196 197 def _reset_epoch(_: None): 198 new_key, perm_key = jax.random.split(state.key) 199 indices, mask = self._build_epoch_indices(perm_key) 200 return _ArraySourceState( 201 indices=indices, 202 mask=mask, 203 position=jnp.array(0, dtype=jnp.int32), 204 key=new_key, 205 epoch=state.epoch + 1, 206 ) 207 208 def _advance(_: None): 209 return _ArraySourceState( 210 indices=state.indices, 211 mask=state.mask, 212 position=next_position, 213 key=state.key, 214 epoch=state.epoch, 215 ) 216 217 need_reset = next_position >= state.indices.shape[0] 218 new_state = jax.lax.cond(need_reset, _reset_epoch, _advance, operand=None) 219 return sample, mask_value, new_state 220 221 222@dataclass 223class DiskSource(Source[_DiskSourceState]): 224 """Sample-level stream that loads items via a Python callback (disk, RPC, etc.). 225 226 This is slow, only use this if your dataset will not fit in system memory. 
227 """ 228 229 length: int 230 """Number of samples in the dataset.""" 231 sample_fn: Callable[[int], PyTree] 232 """Python callable that takes an integer index and returns a PyTree of arrays.""" 233 sample_spec: PyTree | None = None 234 """Optional PyTree of `jax.ShapeDtypeStruct` describing the shape and dtype of samples.""" 235 ordering: Literal["sequential", "shuffle"] = "shuffle" 236 """Sample ordering strategy, either 'sequential' or 'shuffle'. The shuffling occurs over the entire dataset, not within the prefetch buffer.""" 237 prefetch_size: int = 64 238 """Number of samples to prefetch into a JAX array buffer. Set this larger to achieve better throughput at the cost of more memory usage.""" 239 240 def __post_init__(self) -> None: 241 if self.length <= 0: 242 raise ValueError("Dataset cannot be empty.") 243 if self.prefetch_size <= 0: 244 raise ValueError("prefetch_size must be positive.") 245 246 if self.sample_spec is None: 247 example = self.sample_fn(0) 248 249 def _to_spec(leaf): 250 arr = np.asarray(leaf) 251 return jax.ShapeDtypeStruct(shape=arr.shape, dtype=arr.dtype) 252 253 self.sample_spec = tree_util.tree_map(_to_spec, example) 254 255 leaves = tree_util.tree_leaves(self.sample_spec) 256 if not leaves: 257 raise ValueError("element_spec must include at least one leaf.") 258 for leaf in leaves: 259 if not isinstance(leaf, jax.ShapeDtypeStruct): 260 raise TypeError("element_spec leaves must be jax.ShapeDtypeStruct instances.") 261 262 self._num_samples = int(self.length) 263 self.steps_per_epoch = self._num_samples 264 self._element_spec = self.sample_spec 265 self.prefetch_size = int(self.prefetch_size) 266 267 def _zeros(spec: jax.ShapeDtypeStruct): 268 return np.zeros(spec.shape, dtype=np.dtype(spec.dtype)) 269 270 def _buffer_shape(spec: jax.ShapeDtypeStruct): 271 return jax.ShapeDtypeStruct( 272 shape=(self.prefetch_size, *spec.shape), 273 dtype=spec.dtype, 274 ) 275 276 self._zero_sample = tree_util.tree_map(_zeros, self.sample_spec) 277 
self._chunk_spec = tree_util.tree_map(_buffer_shape, self.sample_spec) 278 self._buffer_template = tree_util.tree_map( 279 lambda spec: jnp.zeros((self.prefetch_size, *spec.shape), dtype=spec.dtype), 280 self.sample_spec, 281 ) 282 283 def element_spec(self) -> PyTree: 284 """Shape/dtype metadata describing samples produced by the source.""" 285 return self._element_spec 286 287 def _build_epoch_indices(self, key: jax.Array) -> jax.Array: 288 base = jnp.arange(self._num_samples) 289 if self.ordering == "shuffle": 290 base = jax.random.permutation(key, base) 291 elif self.ordering != "sequential": 292 raise ValueError(f"Unknown ordering '{self.ordering}'.") 293 return base 294 295 def init_state(self, key: jax.Array) -> _DiskSourceState: 296 """Build the starting state, optionally seeding randomness with ``key``.""" 297 key, perm_key = jax.random.split(key) 298 indices = self._build_epoch_indices(perm_key) 299 position = jnp.array(0, dtype=jnp.int32) 300 epoch = jnp.array(0, dtype=jnp.int32) 301 return _DiskSourceState( 302 indices=indices, 303 position=position, 304 key=key, 305 epoch=epoch, 306 buffer=self._buffer_template, 307 buffer_pos=jnp.array(0, dtype=jnp.int32), 308 buffer_count=jnp.array(0, dtype=jnp.int32), 309 ) 310 311 def _chunk_callback(self, indices: np.ndarray, mask: np.ndarray) -> PyTree: 312 idx_array = np.asarray(indices, dtype=np.int64) 313 mask_array = np.asarray(mask, dtype=bool) 314 samples: list[PyTree] = [] 315 for keep, idx in zip(mask_array, idx_array): 316 if keep: 317 samples.append(self.sample_fn(int(idx))) 318 else: 319 samples.append(self._zero_sample) 320 return tree_util.tree_map(lambda *xs: np.stack(xs, axis=0), *samples) 321 322 def _maybe_reset_epoch(self, state: _DiskSourceState) -> _DiskSourceState: 323 def _reset(state: _DiskSourceState): 324 new_key, perm_key = jax.random.split(state.key) 325 indices = self._build_epoch_indices(perm_key) 326 return _DiskSourceState( 327 indices=indices, 328 position=jnp.array(0, 
dtype=jnp.int32), 329 key=new_key, 330 epoch=state.epoch + 1, 331 buffer=self._buffer_template, 332 buffer_pos=jnp.array(0, dtype=jnp.int32), 333 buffer_count=jnp.array(0, dtype=jnp.int32), 334 ) 335 336 return jax.lax.cond(state.position >= self._num_samples, _reset, lambda s: s, state) 337 338 def _maybe_refill_buffer(self, state: _DiskSourceState) -> _DiskSourceState: 339 def _needs(state: _DiskSourceState): 340 return jnp.logical_or(state.buffer_count == 0, state.buffer_pos >= state.buffer_count) 341 342 def _refill(state: _DiskSourceState): 343 refreshed = self._maybe_reset_epoch(state) 344 remaining = self._num_samples - refreshed.position 345 chunk = jnp.minimum(remaining, self.prefetch_size) 346 chunk = jnp.maximum(chunk, 0) 347 chunk = chunk.astype(jnp.int32) 348 offsets = jnp.arange(self.prefetch_size, dtype=jnp.int32) 349 gather_positions = jnp.minimum( 350 refreshed.position + offsets, 351 refreshed.indices.shape[0] - 1, 352 ) 353 chunk_indices = jax.vmap( 354 lambda idx: jax.lax.dynamic_index_in_dim( 355 refreshed.indices, idx, axis=0, keepdims=False 356 ) 357 )(gather_positions) 358 valid_mask = offsets < chunk 359 buffer = io_callback(self._chunk_callback, self._chunk_spec, chunk_indices, valid_mask) 360 new_position = refreshed.position + chunk 361 return _DiskSourceState( 362 indices=refreshed.indices, 363 position=new_position, 364 key=refreshed.key, 365 epoch=refreshed.epoch, 366 buffer=buffer, 367 buffer_pos=jnp.array(0, dtype=jnp.int32), 368 buffer_count=chunk, 369 ) 370 371 return jax.lax.cond(_needs(state), _refill, lambda s: s, state) 372 373 def next(self, state: _DiskSourceState) -> tuple[PyTree, jax.Array, _DiskSourceState]: 374 """Return buffered sample, all-True mask, and updated state.""" 375 state = self._maybe_refill_buffer(state) 376 sample = tree_util.tree_map( 377 lambda buf: jax.lax.dynamic_index_in_dim( 378 buf, state.buffer_pos, axis=0, keepdims=False 379 ), 380 state.buffer, 381 ) 382 mask_value = jnp.array(True, dtype=bool) 
383 new_state = _DiskSourceState( 384 indices=state.indices, 385 position=state.position, 386 key=state.key, 387 epoch=state.epoch, 388 buffer=state.buffer, 389 buffer_pos=state.buffer_pos + 1, 390 buffer_count=state.buffer_count, 391 ) 392 return sample, mask_value, new_state 393 394 395 396 397@jax.tree_util.register_pytree_node_class 398@dataclass 399class GymnaxSourceState: 400 env_state: PyTree 401 obs: PyTree 402 key: jax.Array 403 step: jax.Array 404 epoch: jax.Array 405 policy_state: PyTree | None = None 406 new_episode: jax.Array | None = None 407 408 def tree_flatten(self): 409 return ( 410 self.env_state, 411 self.obs, 412 self.key, 413 self.step, 414 self.epoch, 415 self.policy_state, 416 self.new_episode, 417 ), None 418 419 @classmethod 420 def tree_unflatten(cls, aux_data, children): 421 env_state, obs, key, step, epoch, policy_state, new_episode = children 422 return cls( 423 env_state=env_state, 424 obs=obs, 425 key=key, 426 step=step, 427 epoch=epoch, 428 policy_state=policy_state, 429 new_episode=new_episode, 430 ) 431 432 433@dataclass 434class GymnaxSource(Source[GymnaxSourceState]): 435 """Stream transitions by rolling out a Gymnax environment with a policy. 436 437 Useful for reinforcement learning. 438 439 Args: 440 env: Gymnax environment instance. 441 env_params: Parameters to pass to the environment's reset and step functions. 442 policy_step_fn: Callable that takes (observation, policy_state, new_episode, key) and 443 returns (action, new_policy_state). 444 policy_state_template: Example PyTree carrying everything required by 445 ``policy_step_fn`` (for example, policy parameters and recurrent 446 carries). This template is used only to infer the element spec; callers 447 are responsible for injecting a real policy state into the loader 448 state before calling ``next``. 449 steps_per_epoch: Number of environment steps per epoch for a single environment. 
450 """ 451 452 env: Any 453 env_params: Any 454 policy_step_fn: Callable[[PyTree, PyTree, jax.Array, jax.Array], tuple[PyTree, PyTree]] 455 policy_state_template: PyTree | None = None 456 steps_per_epoch: int = 1024 457 458 def __post_init__(self) -> None: 459 if self.steps_per_epoch <= 0: 460 raise ValueError("steps_per_epoch must be positive.") 461 if self.policy_state_template is None: 462 raise ValueError("GymnaxSource requires a policy_state_template for shape inference.") 463 464 def _sample(key, policy_state): 465 obs, env_state = self.env.reset(key, self.env_params) 466 action, next_policy_state = self.policy_step_fn( 467 obs, 468 policy_state, 469 jnp.array(True, dtype=jnp.bool_), 470 key, 471 ) 472 next_obs, _, reward, done, info = self.env.step( 473 key, 474 env_state, 475 action, 476 self.env_params, 477 ) 478 transition = { 479 "state": obs, 480 "action": action, 481 "reward": reward, 482 "next_state": next_obs, 483 "done": done, 484 "info": info, 485 } 486 return transition, next_policy_state 487 488 shaped, _ = jax.eval_shape(_sample, jax.random.PRNGKey(0), self.policy_state_template) 489 self._element_spec = tree_util.tree_map( 490 lambda arr: jax.ShapeDtypeStruct(shape=arr.shape, dtype=arr.dtype), shaped 491 ) 492 self.policy_state_template = None 493 494 def element_spec(self) -> PyTree: 495 """Shape/dtype metadata describing Gymnax transitions.""" 496 return self._element_spec 497 498 def init_state(self, key: jax.Array) -> GymnaxSourceState: 499 """Return RNG-seeded environment + policy state for iteration.""" 500 key, env_key = jax.random.split(key) 501 obs, env_state = self.env.reset(env_key, self.env_params) 502 return GymnaxSourceState( 503 env_state=env_state, 504 obs=obs, 505 key=key, 506 step=jnp.array(0, dtype=jnp.int32), 507 epoch=jnp.array(0, dtype=jnp.int32), 508 policy_state=None, 509 new_episode=jnp.array(True, dtype=jnp.bool_), 510 ) 511 512 def next(self, state: GymnaxSourceState) -> tuple[PyTree, jax.Array, GymnaxSourceState]: 
513 """Roll the environment forward one step and emit a transition.""" 514 key, policy_key, step_key, done_reset_key, epoch_reset_key = jax.random.split(state.key, 5) 515 516 if state.policy_state is None: 517 raise ValueError( 518 "GymnaxSource state is missing `policy_state`; set it explicitly before calling `next`." 519 ) 520 policy_state = state.policy_state 521 522 if state.new_episode is None: 523 raise ValueError("GymnaxSource state is missing `new_episode` flag.") 524 525 action, updated_policy_state = self.policy_step_fn( 526 state.obs, 527 policy_state, 528 state.new_episode, 529 policy_key, 530 ) 531 next_obs, next_env_state, reward, done, info = self.env.step( 532 step_key, 533 state.env_state, 534 action, 535 self.env_params, 536 ) 537 538 transition = { 539 "state": state.obs, 540 "action": action, 541 "reward": reward, 542 "next_state": next_obs, 543 "done": done, 544 "info": info, 545 } 546 mask = jnp.array(True, dtype=bool) 547 548 done_flag = jnp.asarray(done, dtype=bool) 549 done_flag = jnp.reshape(done_flag, ()) 550 reset_obs, reset_env_state = self.env.reset(done_reset_key, self.env_params) 551 552 cont_obs, cont_env_state = jax.lax.cond( 553 done_flag, 554 lambda _: (reset_obs, reset_env_state), 555 lambda _: (next_obs, next_env_state), 556 operand=None, 557 ) 558 559 next_step = state.step + 1 560 need_epoch_reset = next_step >= self.steps_per_epoch 561 562 def _reset_epoch(_: None): 563 epoch_obs, epoch_env_state = self.env.reset(epoch_reset_key, self.env_params) 564 return GymnaxSourceState( 565 env_state=epoch_env_state, 566 obs=epoch_obs, 567 key=key, 568 step=jnp.array(0, dtype=jnp.int32), 569 epoch=state.epoch + 1, 570 policy_state=updated_policy_state, 571 new_episode=jnp.array(True, dtype=jnp.bool_), 572 ) 573 574 def _continue(_: None): 575 return GymnaxSourceState( 576 env_state=cont_env_state, 577 obs=cont_obs, 578 key=key, 579 step=next_step, 580 epoch=state.epoch, 581 policy_state=updated_policy_state, 582 new_episode=done_flag, 
583 ) 584 585 new_state = jax.lax.cond(need_epoch_reset, _reset_epoch, _continue, operand=None) 586 return transition, mask, new_state
class Source(Protocol[StateT]):
    """Interface for stateful, JIT-friendly data streams.

    A source exposes fixed-length epochs (`steps_per_epoch`) and provides
    methods for initializing internal state and iteratively producing samples
    plus boolean masks that denote whether a sample is valid (e.g. under
    padding).

    The three methods below are stubs (``...``); concrete sources supply the
    behavior.
    """

    steps_per_epoch: int
    """Number of items emitted per epoch."""

    def init_state(self, key: jax.Array) -> StateT:
        """Return an initial state for the source.

        Args:
            key: Optional PRNG key used for randomized behavior such as
                shuffling. Implementations should fall back to a default key
                when ``None`` is provided.
                NOTE(review): the signature annotates ``key`` as a plain
                ``jax.Array`` with no default, so the ``None`` fallback is not
                expressed here; confirm the intended contract.

        Returns:
            Backend-specific state object that must be passed to ``next``.
        """
        ...

    def next(self, state: StateT) -> tuple[PyTree, jax.Array, StateT]:
        """Advance the stream and return the next value.

        Args:
            state: Previously-initialized source state.

        Returns:
            Tuple ``(value, mask, new_state)`` where ``value`` is a PyTree of
            arrays, ``mask`` is a boolean array indicating whether the sample
            is valid, and ``new_state`` should be provided to the next call.
        """
        ...

    def element_spec(self) -> PyTree:
        """PyTree of :class:`jax.ShapeDtypeStruct` describing emitted samples."""
        ...
Interface for stateful, JIT-friendly data streams.
A source exposes fixed-length epochs (steps_per_epoch) and provides
methods for initializing internal state and iteratively producing samples
plus boolean masks that denote whether a sample is valid (e.g. under
padding).
def _no_init_or_replace_init(self, *args, **kwargs):
    """Placeholder ``__init__`` that swaps itself for the real one on first use.

    Raises ``TypeError`` when the instance's class is flagged as a protocol.
    Otherwise, if this placeholder is still the class's ``__init__``, the MRO
    is searched for the first ``__init__`` that is not the placeholder; that
    one is installed on the class and then called with the given arguments.
    """
    klass = type(self)

    if klass._is_protocol:
        raise TypeError('Protocols cannot be instantiated')

    placeholder = _no_init_or_replace_init

    # A custom `__init__` is already in place, so there is nothing to resolve.
    # Looking one up anyway can lead to RecursionError. See bpo-45121.
    if klass.__init__ is not placeholder:
        return

    # The first instantiation of a protocol subclass lands here. Install the
    # first real `__init__` found in the MRO so later instantiations call it
    # directly and never reach this placeholder again.
    inherited = (
        base.__dict__.get('__init__', placeholder) for base in klass.__mro__
    )
    klass.__init__ = next(
        (init for init in inherited if init is not placeholder),
        object.__init__,  # fallback; should not happen
    )

    klass.__init__(self, *args, **kwargs)
def init_state(self, key: jax.Array) -> StateT:
    """Return an initial state for the source.

    This is an interface stub: the body is ``...``.

    Args:
        key: Optional PRNG key used for randomized behavior such as
            shuffling. Implementations should fall back to a default key
            when ``None`` is provided.
            NOTE(review): the signature annotates ``key`` as a plain
            ``jax.Array`` with no default, so the ``None`` fallback is not
            expressed here; confirm the intended contract.

    Returns:
        Backend-specific state object that must be passed to ``next``.
    """
    ...
Return an initial state for the source.
Args:
key: Optional PRNG key used for randomized behavior such as
shuffling. Implementations should fall back to a default key
when None is provided.
Returns:
Backend-specific state object that must be passed to next.
def next(self, state: StateT) -> tuple[PyTree, jax.Array, StateT]:
    """Advance the stream and return the next value.

    This is an interface stub: the body is ``...``.

    Args:
        state: Previously-initialized source state.

    Returns:
        Tuple ``(value, mask, new_state)`` where ``value`` is a PyTree of
        arrays, ``mask`` is a boolean array indicating whether the sample
        is valid, and ``new_state`` should be provided to the next call.
    """
    ...
Advance the stream and return the next value.
Args:
state: Previously-initialized source state.
Returns:
Tuple (value, mask, new_state) where value is a PyTree of
arrays, mask is a boolean array indicating whether the sample
is valid, and new_state should be provided to the next call.
@dataclass
class ArraySource(Source[_ArraySourceState]):
    """Sample-level stream over an in-memory PyTree of arrays.

    Args:
        data: PyTree whose leaves are arrays with a leading sample dimension.
        ordering: Either ``"sequential"`` or ``"shuffle"``.

    Raises:
        ValueError: If ``data`` has no leaves, the leaves disagree on the
            leading dimension, or the dataset is empty.
    """

    data: PyTree
    ordering: Literal["sequential", "shuffle"] = "shuffle"

    def __post_init__(self) -> None:
        leaves, self._treedef = tree_util.tree_flatten(self.data)
        if not leaves:
            raise ValueError("Data tree must contain at least one array.")

        first = leaves[0]
        self._num_samples = int(first.shape[0])
        for leaf in leaves[1:]:
            if leaf.shape[0] != self._num_samples:
                raise ValueError("All leaves must share the leading dimension.")
        if self._num_samples == 0:
            raise ValueError("Dataset cannot be empty.")

        self.steps_per_epoch = self._num_samples
        # Every in-memory sample is valid, so the mask is all ones.
        self._mask_template = jnp.ones(self._num_samples, dtype=bool)
        # A single sample drops the leading (sample) axis of each leaf.
        self._element_spec = tree_util.tree_map(
            lambda leaf: jax.ShapeDtypeStruct(shape=leaf.shape[1:], dtype=leaf.dtype),
            self.data,
        )

    @property
    def num_samples(self) -> int:
        """Number of samples along the leading dimension of ``data``."""
        return self._num_samples

    def element_spec(self) -> PyTree:
        """Shape/dtype metadata describing samples produced by the source."""
        return self._element_spec

    def _build_epoch_indices(self, key: jax.Array) -> tuple[jax.Array, jax.Array]:
        """Return ``(indices, mask)`` for one epoch; ``key`` drives the shuffle.

        Raises:
            ValueError: If ``ordering`` is neither ``"sequential"`` nor ``"shuffle"``.
        """
        base = jnp.arange(self._num_samples)
        if self.ordering == "shuffle":
            base = jax.random.permutation(key, base)
        elif self.ordering != "sequential":
            raise ValueError(f"Unknown ordering '{self.ordering}'.")

        return base, self._mask_template

    def init_state(self, key: jax.Array | None = None) -> _ArraySourceState:
        """Create the initial iteration state.

        Args:
            key: Optional PRNG key. Defaults to ``jax.random.PRNGKey(0)`` when
                omitted.
        """
        if key is None:
            # Honor the documented default instead of failing inside ``split``.
            key = jax.random.PRNGKey(0)
        key, perm_key = jax.random.split(key)
        indices, mask = self._build_epoch_indices(perm_key)
        position = jnp.array(0, dtype=jnp.int32)
        epoch = jnp.array(0, dtype=jnp.int32)
        return _ArraySourceState(indices=indices, mask=mask, position=position, key=key, epoch=epoch)

    def next(self, state: _ArraySourceState) -> tuple[PyTree, jax.Array, _ArraySourceState]:
        """Return the next sample (with mask) and the advanced state."""
        index = jax.lax.dynamic_index_in_dim(state.indices, state.position, axis=0, keepdims=False)
        mask_value = jax.lax.dynamic_index_in_dim(state.mask, state.position, axis=0, keepdims=False)
        sample = tree_util.tree_map(
            lambda arr: jax.lax.dynamic_index_in_dim(arr, index, axis=0, keepdims=False),
            self.data,
        )

        next_position = state.position + 1

        def _reset_epoch(_: None):
            # Epoch exhausted: draw a fresh ordering and bump the epoch counter.
            new_key, perm_key = jax.random.split(state.key)
            indices, mask = self._build_epoch_indices(perm_key)
            return _ArraySourceState(
                indices=indices,
                mask=mask,
                position=jnp.array(0, dtype=jnp.int32),
                key=new_key,
                epoch=state.epoch + 1,
            )

        def _advance(_: None):
            return _ArraySourceState(
                indices=state.indices,
                mask=state.mask,
                position=next_position,
                key=state.key,
                epoch=state.epoch,
            )

        need_reset = next_position >= state.indices.shape[0]
        new_state = jax.lax.cond(need_reset, _reset_epoch, _advance, operand=None)
        return sample, mask_value, new_state
Sample-level stream over an in-memory PyTree of arrays.
Args:
data: PyTree whose leaves are arrays with a leading sample dimension.
ordering: Either "sequential" or "shuffle".
def element_spec(self) -> PyTree:
    """Shape/dtype metadata describing samples produced by the source.

    Returns:
        The value stored in ``self._element_spec``, returned as-is.
    """
    return self._element_spec
Shape/dtype metadata describing samples produced by the source.
def init_state(self, key: jax.Array | None = None) -> _ArraySourceState:
    """Create the initial iteration state.

    The key is split once: one half builds the first epoch's ordering via
    ``self._build_epoch_indices`` and the other is stored in the state.
    Position and epoch counters start at zero.

    Args:
        key: Optional PRNG key. Defaults to ``jax.random.PRNGKey(0)`` when
            omitted.

    Returns:
        A fresh ``_ArraySourceState``.
    """
    if key is None:
        # Honor the documented default instead of failing inside ``split``.
        key = jax.random.PRNGKey(0)
    key, perm_key = jax.random.split(key)
    indices, mask = self._build_epoch_indices(perm_key)
    position = jnp.array(0, dtype=jnp.int32)
    epoch = jnp.array(0, dtype=jnp.int32)
    return _ArraySourceState(indices=indices, mask=mask, position=position, key=key, epoch=epoch)
Create the initial iteration state.
Args:
key: Optional PRNG key. Defaults to jax.random.PRNGKey(0) when
omitted.
def next(self, state: _ArraySourceState) -> tuple[PyTree, jax.Array, _ArraySourceState]:
    """Emit the sample at the current cursor and step the state forward.

    Args:
        state: State produced by ``init_state`` or a previous ``next`` call.

    Returns:
        Tuple ``(sample, mask_value, new_state)``. ``sample`` is ``self.data``
        indexed along axis 0 at the current epoch index, ``mask_value`` is the
        mask entry at the cursor, and ``new_state`` either advances the cursor
        or, once the epoch's indices are used up, starts a new epoch.
    """

    def _at(array, where):
        return jax.lax.dynamic_index_in_dim(array, where, axis=0, keepdims=False)

    cursor = state.position
    sample_index = _at(state.indices, cursor)
    is_valid = _at(state.mask, cursor)
    sample = tree_util.tree_map(lambda leaf: _at(leaf, sample_index), self.data)

    following = cursor + 1

    def _start_next_epoch(_: None):
        # Split the carried key: one half reorders the epoch, one is kept.
        carry_key, shuffle_key = jax.random.split(state.key)
        fresh_indices, fresh_mask = self._build_epoch_indices(shuffle_key)
        return _ArraySourceState(
            indices=fresh_indices,
            mask=fresh_mask,
            position=jnp.array(0, dtype=jnp.int32),
            key=carry_key,
            epoch=state.epoch + 1,
        )

    def _stay_in_epoch(_: None):
        return _ArraySourceState(
            indices=state.indices,
            mask=state.mask,
            position=following,
            key=state.key,
            epoch=state.epoch,
        )

    epoch_finished = following >= state.indices.shape[0]
    successor = jax.lax.cond(epoch_finished, _start_next_epoch, _stay_in_epoch, operand=None)
    return sample, is_valid, successor
Return the next sample (with mask) and the advanced state.
Inherited Members
@dataclass
class DiskSource(Source[_DiskSourceState]):
    """Sample-level stream that loads items via a Python callback (disk, RPC, etc.).

    This is slow, only use this if your dataset will not fit in system memory.

    Raises:
        ValueError: If ``length`` or ``prefetch_size`` is not positive, or the
            sample spec has no leaves.
        TypeError: If a sample-spec leaf is not a ``jax.ShapeDtypeStruct``.
    """

    length: int
    """Number of samples in the dataset."""
    sample_fn: Callable[[int], PyTree]
    """Python callable that takes an integer index and returns a PyTree of arrays."""
    sample_spec: PyTree | None = None
    """Optional PyTree of `jax.ShapeDtypeStruct` describing the shape and dtype of samples."""
    ordering: Literal["sequential", "shuffle"] = "shuffle"
    """Sample ordering strategy, either 'sequential' or 'shuffle'. The shuffling occurs over the entire dataset, not within the prefetch buffer."""
    prefetch_size: int = 64
    """Number of samples to prefetch into a JAX array buffer. Set this larger to achieve better throughput at the cost of more memory usage."""

    def __post_init__(self) -> None:
        if self.length <= 0:
            raise ValueError("Dataset cannot be empty.")
        if self.prefetch_size <= 0:
            raise ValueError("prefetch_size must be positive.")

        if self.sample_spec is None:
            # No spec supplied: infer one by loading sample 0 once.
            example = self.sample_fn(0)

            def _to_spec(leaf):
                arr = np.asarray(leaf)
                return jax.ShapeDtypeStruct(shape=arr.shape, dtype=arr.dtype)

            self.sample_spec = tree_util.tree_map(_to_spec, example)

        leaves = tree_util.tree_leaves(self.sample_spec)
        if not leaves:
            raise ValueError("element_spec must include at least one leaf.")
        for leaf in leaves:
            if not isinstance(leaf, jax.ShapeDtypeStruct):
                raise TypeError("element_spec leaves must be jax.ShapeDtypeStruct instances.")

        self._num_samples = int(self.length)
        self.steps_per_epoch = self._num_samples
        self._element_spec = self.sample_spec
        self.prefetch_size = int(self.prefetch_size)

        def _zeros(spec: jax.ShapeDtypeStruct):
            return np.zeros(spec.shape, dtype=np.dtype(spec.dtype))

        def _buffer_shape(spec: jax.ShapeDtypeStruct):
            return jax.ShapeDtypeStruct(
                shape=(self.prefetch_size, *spec.shape),
                dtype=spec.dtype,
            )

        # Placeholder written into masked-out buffer slots by the callback.
        self._zero_sample = tree_util.tree_map(_zeros, self.sample_spec)
        # Result spec handed to ``io_callback``: one sample spec per buffer slot.
        self._chunk_spec = tree_util.tree_map(_buffer_shape, self.sample_spec)
        # Empty buffer used for fresh states.
        self._buffer_template = tree_util.tree_map(
            lambda spec: jnp.zeros((self.prefetch_size, *spec.shape), dtype=spec.dtype),
            self.sample_spec,
        )

    def element_spec(self) -> PyTree:
        """Shape/dtype metadata describing samples produced by the source."""
        return self._element_spec

    def _build_epoch_indices(self, key: jax.Array) -> jax.Array:
        """Return the sample order for one epoch; ``key`` drives the shuffle.

        Raises:
            ValueError: If ``ordering`` is neither ``"sequential"`` nor ``"shuffle"``.
        """
        base = jnp.arange(self._num_samples)
        if self.ordering == "shuffle":
            base = jax.random.permutation(key, base)
        elif self.ordering != "sequential":
            raise ValueError(f"Unknown ordering '{self.ordering}'.")
        return base

    def init_state(self, key: jax.Array | None = None) -> _DiskSourceState:
        """Build the starting state, optionally seeding randomness with ``key``.

        Args:
            key: Optional PRNG key. Defaults to ``jax.random.PRNGKey(0)`` when
                omitted.
        """
        if key is None:
            # Make the key genuinely optional instead of failing inside ``split``.
            key = jax.random.PRNGKey(0)
        key, perm_key = jax.random.split(key)
        indices = self._build_epoch_indices(perm_key)
        position = jnp.array(0, dtype=jnp.int32)
        epoch = jnp.array(0, dtype=jnp.int32)
        return _DiskSourceState(
            indices=indices,
            position=position,
            key=key,
            epoch=epoch,
            buffer=self._buffer_template,
            buffer_pos=jnp.array(0, dtype=jnp.int32),
            buffer_count=jnp.array(0, dtype=jnp.int32),
        )

    def _chunk_callback(self, indices: np.ndarray, mask: np.ndarray) -> PyTree:
        """Host-side loader: fetch one buffer's worth of samples.

        Slots whose ``mask`` entry is false are filled with the zero sample
        instead of calling ``sample_fn``.
        """
        idx_array = np.asarray(indices, dtype=np.int64)
        mask_array = np.asarray(mask, dtype=bool)
        samples: list[PyTree] = []
        for keep, idx in zip(mask_array, idx_array):
            if keep:
                samples.append(self.sample_fn(int(idx)))
            else:
                samples.append(self._zero_sample)
        return tree_util.tree_map(lambda *xs: np.stack(xs, axis=0), *samples)

    def _maybe_reset_epoch(self, state: _DiskSourceState) -> _DiskSourceState:
        """Start a new epoch when every sample of the current one has been loaded."""

        def _reset(state: _DiskSourceState):
            new_key, perm_key = jax.random.split(state.key)
            indices = self._build_epoch_indices(perm_key)
            return _DiskSourceState(
                indices=indices,
                position=jnp.array(0, dtype=jnp.int32),
                key=new_key,
                epoch=state.epoch + 1,
                buffer=self._buffer_template,
                buffer_pos=jnp.array(0, dtype=jnp.int32),
                buffer_count=jnp.array(0, dtype=jnp.int32),
            )

        return jax.lax.cond(state.position >= self._num_samples, _reset, lambda s: s, state)

    def _maybe_refill_buffer(self, state: _DiskSourceState) -> _DiskSourceState:
        """Reload the prefetch buffer through ``io_callback`` once it is drained."""

        def _needs(state: _DiskSourceState):
            return jnp.logical_or(state.buffer_count == 0, state.buffer_pos >= state.buffer_count)

        def _refill(state: _DiskSourceState):
            refreshed = self._maybe_reset_epoch(state)
            remaining = self._num_samples - refreshed.position
            # Number of real samples in this chunk (the final chunk may be short).
            chunk = jnp.minimum(remaining, self.prefetch_size)
            chunk = jnp.maximum(chunk, 0)
            chunk = chunk.astype(jnp.int32)
            offsets = jnp.arange(self.prefetch_size, dtype=jnp.int32)
            # Clamp so the gather stays in bounds; clamped slots are masked below.
            gather_positions = jnp.minimum(
                refreshed.position + offsets,
                refreshed.indices.shape[0] - 1,
            )
            chunk_indices = jax.vmap(
                lambda idx: jax.lax.dynamic_index_in_dim(
                    refreshed.indices, idx, axis=0, keepdims=False
                )
            )(gather_positions)
            valid_mask = offsets < chunk
            buffer = io_callback(self._chunk_callback, self._chunk_spec, chunk_indices, valid_mask)
            new_position = refreshed.position + chunk
            return _DiskSourceState(
                indices=refreshed.indices,
                position=new_position,
                key=refreshed.key,
                epoch=refreshed.epoch,
                buffer=buffer,
                buffer_pos=jnp.array(0, dtype=jnp.int32),
                buffer_count=chunk,
            )

        return jax.lax.cond(_needs(state), _refill, lambda s: s, state)

    def next(self, state: _DiskSourceState) -> tuple[PyTree, jax.Array, _DiskSourceState]:
        """Return buffered sample, all-True mask, and updated state."""
        state = self._maybe_refill_buffer(state)
        sample = tree_util.tree_map(
            lambda buf: jax.lax.dynamic_index_in_dim(
                buf, state.buffer_pos, axis=0, keepdims=False
            ),
            state.buffer,
        )
        mask_value = jnp.array(True, dtype=bool)
        new_state = _DiskSourceState(
            indices=state.indices,
            position=state.position,
            key=state.key,
            epoch=state.epoch,
            buffer=state.buffer,
            buffer_pos=state.buffer_pos + 1,
            buffer_count=state.buffer_count,
        )
        return sample, mask_value, new_state
Sample-level stream that loads items via a Python callback (disk, RPC, etc.).
This is slow; only use it if your dataset will not fit in system memory.
Python callable that takes an integer index and returns a PyTree of arrays.
Optional PyTree of jax.ShapeDtypeStruct describing the shape and dtype of samples.
Sample ordering strategy, either 'sequential' or 'shuffle'. The shuffling occurs over the entire dataset, not within the prefetch buffer.
Number of samples to prefetch into a JAX array buffer. Set this larger to achieve better throughput at the cost of more memory usage.
def element_spec(self) -> PyTree:
    """Return the per-sample shape/dtype spec stored on this source."""
    spec = self._element_spec
    return spec
Shape/dtype metadata describing samples produced by the source.
def init_state(self, key: jax.Array) -> _DiskSourceState:
    """Create the state for the first epoch.

    ``key`` is split once: one half builds the epoch's index order, the other
    is stored in the state.  The buffer starts as the zero template with a
    count of 0.
    """
    carry_key, perm_key = jax.random.split(key)
    return _DiskSourceState(
        indices=self._build_epoch_indices(perm_key),
        position=jnp.array(0, dtype=jnp.int32),
        key=carry_key,
        epoch=jnp.array(0, dtype=jnp.int32),
        buffer=self._buffer_template,
        buffer_pos=jnp.array(0, dtype=jnp.int32),
        buffer_count=jnp.array(0, dtype=jnp.int32),
    )
Build the starting state; key seeds the epoch permutation when ordering is 'shuffle'.
def next(self, state: _DiskSourceState) -> tuple[PyTree, jax.Array, _DiskSourceState]:
    """Emit the sample at the buffer cursor and advance the cursor by one.

    The buffer is refilled first if needed.  The returned mask is always a
    scalar ``True``.
    """
    ready = self._maybe_refill_buffer(state)
    cursor = ready.buffer_pos

    def _take(buf):
        # Select one row along the leading prefetch axis.
        return jax.lax.dynamic_index_in_dim(buf, cursor, axis=0, keepdims=False)

    sample = tree_util.tree_map(_take, ready.buffer)
    advanced = _DiskSourceState(
        indices=ready.indices,
        position=ready.position,
        key=ready.key,
        epoch=ready.epoch,
        buffer=ready.buffer,
        buffer_pos=cursor + 1,
        buffer_count=ready.buffer_count,
    )
    return sample, jnp.array(True, dtype=bool), advanced
Return buffered sample, all-True mask, and updated state.
Inherited Members
@jax.tree_util.register_pytree_node_class
@dataclass
class GymnaxSourceState:
    """State threaded through successive ``GymnaxSource.next`` calls.

    Registered as a JAX pytree in which every field is a child and no
    auxiliary data is kept.
    """

    # Environment state and observation as returned by the env's reset/step.
    env_state: PyTree
    obs: PyTree
    # PRNG key carried between steps.
    key: jax.Array
    # Step counter within the current epoch, and the epoch counter.
    step: jax.Array
    epoch: jax.Array
    # Carry handed to the policy step function; None until the caller sets it.
    policy_state: PyTree | None = None
    # Flag passed to the policy to mark the first step of an episode.
    new_episode: jax.Array | None = None

    def tree_flatten(self):
        """Return ``(children, aux_data)`` with all fields as children, in declaration order."""
        children = (
            self.env_state,
            self.obs,
            self.key,
            self.step,
            self.epoch,
            self.policy_state,
            self.new_episode,
        )
        return children, None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a state from the seven children produced by ``tree_flatten``."""
        env_state, obs, key, step, epoch, policy_state, new_episode = children
        return cls(env_state, obs, key, step, epoch, policy_state, new_episode)
420 @classmethod 421 def tree_unflatten(cls, aux_data, children): 422 env_state, obs, key, step, epoch, policy_state, new_episode = children 423 return cls( 424 env_state=env_state, 425 obs=obs, 426 key=key, 427 step=step, 428 epoch=epoch, 429 policy_state=policy_state, 430 new_episode=new_episode, 431 )
@dataclass
class GymnaxSource(Source[GymnaxSourceState]):
    """Stream transitions by rolling out a Gymnax environment with a policy.

    Useful for reinforcement learning.

    Args:
        env: Gymnax environment instance.
        env_params: Parameters forwarded to the environment's reset and step functions.
        policy_step_fn: Callable taking (observation, policy_state, new_episode, key)
            and returning (action, new_policy_state).
        policy_state_template: Example PyTree holding everything
            ``policy_step_fn`` needs (for example, policy parameters and
            recurrent carries). It is only used to infer the element spec and
            is cleared afterwards; callers must inject a real policy state
            into the loader state before calling ``next``.
        steps_per_epoch: Number of environment steps per epoch for a single environment.
    """

    env: Any
    env_params: Any
    policy_step_fn: Callable[[PyTree, PyTree, jax.Array, jax.Array], tuple[PyTree, PyTree]]
    policy_state_template: PyTree | None = None
    steps_per_epoch: int = 1024

    def __post_init__(self) -> None:
        """Validate arguments and infer the transition spec by abstract evaluation.

        Raises:
            ValueError: If ``steps_per_epoch`` is not positive or no
                ``policy_state_template`` was supplied.
        """
        if self.steps_per_epoch <= 0:
            raise ValueError("steps_per_epoch must be positive.")
        if self.policy_state_template is None:
            raise ValueError("GymnaxSource requires a policy_state_template for shape inference.")

        def _probe(rng, policy_state):
            # Only traced via eval_shape, so one key is reused for every call.
            first_obs, env_state = self.env.reset(rng, self.env_params)
            action, carried_policy_state = self.policy_step_fn(
                first_obs,
                policy_state,
                jnp.array(True, dtype=jnp.bool_),
                rng,
            )
            step_out = self.env.step(rng, env_state, action, self.env_params)
            next_obs, _, reward, done, info = step_out
            probe_transition = {
                "state": first_obs,
                "action": action,
                "reward": reward,
                "next_state": next_obs,
                "done": done,
                "info": info,
            }
            return probe_transition, carried_policy_state

        abstract_transition, _ = jax.eval_shape(
            _probe, jax.random.PRNGKey(0), self.policy_state_template
        )

        def _as_spec(leaf):
            return jax.ShapeDtypeStruct(shape=leaf.shape, dtype=leaf.dtype)

        self._element_spec = tree_util.tree_map(_as_spec, abstract_transition)
        # The template is not kept once the spec has been derived.
        self.policy_state_template = None

    def element_spec(self) -> PyTree:
        """Return the shape/dtype spec of one emitted transition."""
        return self._element_spec

    def init_state(self, key: jax.Array) -> GymnaxSourceState:
        """Reset the environment and return the initial loader state.

        ``policy_state`` is left as ``None``; it has to be set on the returned
        state before ``next`` is called.
        """
        carry_key, reset_key = jax.random.split(key)
        first_obs, first_env_state = self.env.reset(reset_key, self.env_params)
        zero = jnp.array(0, dtype=jnp.int32)
        return GymnaxSourceState(
            env_state=first_env_state,
            obs=first_obs,
            key=carry_key,
            step=zero,
            epoch=jnp.array(0, dtype=jnp.int32),
            policy_state=None,
            new_episode=jnp.array(True, dtype=jnp.bool_),
        )

    def next(self, state: GymnaxSourceState) -> tuple[PyTree, jax.Array, GymnaxSourceState]:
        """Step the environment once and emit the resulting transition.

        Returns:
            ``(transition, mask, new_state)`` where ``transition`` is a dict
            with keys ``state``, ``action``, ``reward``, ``next_state``,
            ``done`` and ``info``, and ``mask`` is a scalar ``True``.

        Raises:
            ValueError: If ``state.policy_state`` or ``state.new_episode`` is ``None``.
        """
        subkeys = jax.random.split(state.key, 5)
        carry_key, policy_key, step_key, done_reset_key, epoch_reset_key = subkeys

        if state.policy_state is None:
            raise ValueError(
                "GymnaxSource state is missing `policy_state`; set it explicitly before calling `next`."
            )
        if state.new_episode is None:
            raise ValueError("GymnaxSource state is missing `new_episode` flag.")

        action, policy_carry = self.policy_step_fn(
            state.obs,
            state.policy_state,
            state.new_episode,
            policy_key,
        )
        stepped = self.env.step(step_key, state.env_state, action, self.env_params)
        next_obs, next_env_state, reward, done, info = stepped

        transition = {
            "state": state.obs,
            "action": action,
            "reward": reward,
            "next_state": next_obs,
            "done": done,
            "info": info,
        }

        # Scalar boolean used both as cond predicate and as next new_episode flag.
        episode_over = jnp.reshape(jnp.asarray(done, dtype=bool), ())
        fresh_obs, fresh_env_state = self.env.reset(done_reset_key, self.env_params)

        def _after_done(_):
            return fresh_obs, fresh_env_state

        def _after_step(_):
            return next_obs, next_env_state

        carried_obs, carried_env_state = jax.lax.cond(
            episode_over, _after_done, _after_step, None
        )

        step_count = state.step + 1

        def _start_new_epoch(_: None):
            # Epoch boundary: reset the environment again and zero the step counter.
            epoch_obs, epoch_env_state = self.env.reset(epoch_reset_key, self.env_params)
            return GymnaxSourceState(
                env_state=epoch_env_state,
                obs=epoch_obs,
                key=carry_key,
                step=jnp.array(0, dtype=jnp.int32),
                epoch=state.epoch + 1,
                policy_state=policy_carry,
                new_episode=jnp.array(True, dtype=jnp.bool_),
            )

        def _same_epoch(_: None):
            return GymnaxSourceState(
                env_state=carried_env_state,
                obs=carried_obs,
                key=carry_key,
                step=step_count,
                epoch=state.epoch,
                policy_state=policy_carry,
                new_episode=episode_over,
            )

        following = jax.lax.cond(
            step_count >= self.steps_per_epoch, _start_new_epoch, _same_epoch, None
        )
        return transition, jnp.array(True, dtype=bool), following
Stream transitions by rolling out a Gymnax environment with a policy.
Useful for reinforcement learning.
Args:
env: Gymnax environment instance.
env_params: Parameters to pass to the environment's reset and step functions.
policy_step_fn: Callable that takes (observation, policy_state, new_episode, key) and
returns (action, new_policy_state).
policy_state_template: Example PyTree carrying everything required by
policy_step_fn (for example, policy parameters and recurrent
carries). This template is used only to infer the element spec; callers
are responsible for injecting a real policy state into the loader
state before calling next.
steps_per_epoch: Number of environment steps per epoch for a single environment.
def element_spec(self) -> PyTree:
    """Return the shape/dtype spec of one emitted Gymnax transition."""
    spec = self._element_spec
    return spec
Shape/dtype metadata describing Gymnax transitions.
def init_state(self, key: jax.Array) -> GymnaxSourceState:
    """Reset the environment and return the initial loader state.

    ``policy_state`` is left as ``None``; it has to be set on the returned
    state before ``next`` is called.
    """
    carry_key, reset_key = jax.random.split(key)
    first_obs, first_env_state = self.env.reset(reset_key, self.env_params)
    zero = jnp.array(0, dtype=jnp.int32)
    return GymnaxSourceState(
        env_state=first_env_state,
        obs=first_obs,
        key=carry_key,
        step=zero,
        epoch=jnp.array(0, dtype=jnp.int32),
        policy_state=None,
        new_episode=jnp.array(True, dtype=jnp.bool_),
    )
Return the RNG-seeded initial environment state for iteration. policy_state is left as None and must be set by the caller before calling next.
def next(self, state: GymnaxSourceState) -> tuple[PyTree, jax.Array, GymnaxSourceState]:
    """Step the environment once and emit the resulting transition.

    Returns:
        ``(transition, mask, new_state)`` where ``transition`` is a dict with
        keys ``state``, ``action``, ``reward``, ``next_state``, ``done`` and
        ``info``, and ``mask`` is a scalar ``True``.

    Raises:
        ValueError: If ``state.policy_state`` or ``state.new_episode`` is ``None``.
    """
    subkeys = jax.random.split(state.key, 5)
    carry_key, policy_key, step_key, done_reset_key, epoch_reset_key = subkeys

    if state.policy_state is None:
        raise ValueError(
            "GymnaxSource state is missing `policy_state`; set it explicitly before calling `next`."
        )
    if state.new_episode is None:
        raise ValueError("GymnaxSource state is missing `new_episode` flag.")

    action, policy_carry = self.policy_step_fn(
        state.obs,
        state.policy_state,
        state.new_episode,
        policy_key,
    )
    stepped = self.env.step(step_key, state.env_state, action, self.env_params)
    next_obs, next_env_state, reward, done, info = stepped

    transition = {
        "state": state.obs,
        "action": action,
        "reward": reward,
        "next_state": next_obs,
        "done": done,
        "info": info,
    }

    # Scalar boolean used both as cond predicate and as next new_episode flag.
    episode_over = jnp.reshape(jnp.asarray(done, dtype=bool), ())
    fresh_obs, fresh_env_state = self.env.reset(done_reset_key, self.env_params)

    def _after_done(_):
        return fresh_obs, fresh_env_state

    def _after_step(_):
        return next_obs, next_env_state

    carried_obs, carried_env_state = jax.lax.cond(
        episode_over, _after_done, _after_step, None
    )

    step_count = state.step + 1

    def _start_new_epoch(_: None):
        # Epoch boundary: reset the environment again and zero the step counter.
        epoch_obs, epoch_env_state = self.env.reset(epoch_reset_key, self.env_params)
        return GymnaxSourceState(
            env_state=epoch_env_state,
            obs=epoch_obs,
            key=carry_key,
            step=jnp.array(0, dtype=jnp.int32),
            epoch=state.epoch + 1,
            policy_state=policy_carry,
            new_episode=jnp.array(True, dtype=jnp.bool_),
        )

    def _same_epoch(_: None):
        return GymnaxSourceState(
            env_state=carried_env_state,
            obs=carried_obs,
            key=carry_key,
            step=step_count,
            epoch=state.epoch,
            policy_state=policy_carry,
            new_episode=episode_over,
        )

    following = jax.lax.cond(
        step_count >= self.steps_per_epoch, _start_new_epoch, _same_epoch, None
    )
    return transition, jnp.array(True, dtype=bool), following
Roll the environment forward one step and emit a transition.