cyreal.sources

Streaming dataset sources inspired by Grain's paradigm.

  1"""Streaming dataset sources inspired by Grain's paradigm."""
  2from __future__ import annotations
  3
  4from dataclasses import dataclass
  5from typing import Any, Callable, Literal, Protocol, TypeVar
  6
  7import jax
  8import jax.numpy as jnp
  9import numpy as np
 10from jax import tree_util
 11from jax.experimental import io_callback
 12
# Annotation alias for arbitrary JAX pytrees; it is plain ``Any`` and performs
# no runtime checking.
PyTree = Any
# Type variable for the per-source iteration state threaded through ``Source``.
StateT = TypeVar("StateT")
 15
 16
 17class Source(Protocol[StateT]):
 18    """Interface for stateful, JIT-friendly data streams.
 19
 20    A source exposes fixed-length epochs (`steps_per_epoch`) and provides
 21    methods for initializing internal state and iteratively producing samples
 22    plus boolean masks that denote whether a sample is valid (e.g. under
 23    padding).
 24    """
 25
 26    steps_per_epoch: int
 27    """Number of items emitted per epoch."""
 28
 29    def init_state(self, key: jax.Array) -> StateT:
 30        """Return an initial state for the source.
 31
 32        Args:
 33            key: Optional PRNG key used for randomized behavior such as
 34                shuffling. Implementations should fall back to a default key
 35                when ``None`` is provided.
 36
 37        Returns:
 38            Backend-specific state object that must be passed to ``next``.
 39        """
 40        ...
 41
 42    def next(self, state: StateT) -> tuple[PyTree, jax.Array, StateT]:
 43        """Advance the stream and return the next value.
 44
 45        Args:
 46            state: Previously-initialized source state.
 47
 48        Returns:
 49            Tuple ``(value, mask, new_state)`` where ``value`` is a PyTree of
 50            arrays, ``mask`` is a boolean array indicating whether the sample
 51            is valid, and ``new_state`` should be provided to the next call.
 52        """
 53        ...
 54
 55    def element_spec(self) -> PyTree:
 56        """PyTree of :class:`jax.ShapeDtypeStruct` describing emitted samples."""
 57        ...
 58
 59
 60@jax.tree_util.register_pytree_node_class
 61@dataclass
 62class _ArraySourceState:
 63    indices: jax.Array
 64    mask: jax.Array
 65    position: jax.Array
 66    key: jax.Array
 67    epoch: jax.Array
 68
 69    def tree_flatten(self):
 70        return (self.indices, self.mask, self.position, self.key, self.epoch), None
 71
 72    @classmethod
 73    def tree_unflatten(cls, aux_data, children):
 74        indices, mask, position, key, epoch = children
 75        return cls(indices=indices, mask=mask, position=position, key=key, epoch=epoch)
 76
 77
 78@jax.tree_util.register_pytree_node_class
 79@dataclass
 80class _DiskSourceState:
 81    indices: jax.Array
 82    position: jax.Array
 83    key: jax.Array
 84    epoch: jax.Array
 85    buffer: PyTree
 86    buffer_pos: jax.Array
 87    buffer_count: jax.Array
 88
 89    def tree_flatten(self):
 90        buffer_leaves, buffer_def = tree_util.tree_flatten(self.buffer)
 91        children = (
 92            self.indices,
 93            self.position,
 94            self.key,
 95            self.epoch,
 96            *buffer_leaves,
 97            self.buffer_pos,
 98            self.buffer_count,
 99        )
100        return children, buffer_def
101
102    @classmethod
103    def tree_unflatten(cls, buffer_def, children):
104        indices, position, key, epoch, *rest = children
105        buffer_leaf_count = buffer_def.num_leaves if buffer_def is not None else 0
106        buffer_leaves = rest[:buffer_leaf_count]
107        buffer = (
108            tree_util.tree_unflatten(buffer_def, buffer_leaves)
109            if buffer_def is not None
110            else None
111        )
112        buffer_pos, buffer_count = rest[buffer_leaf_count:]
113        return cls(
114            indices=indices,
115            position=position,
116            key=key,
117            epoch=epoch,
118            buffer=buffer,
119            buffer_pos=buffer_pos,
120            buffer_count=buffer_count,
121        )
122
123
@dataclass
class ArraySource(Source[_ArraySourceState]):
    """Sample-level stream over an in-memory PyTree of arrays.

    Every epoch visits each sample exactly once. With ``ordering="shuffle"`` a
    fresh permutation is drawn from the state key at the start of each epoch;
    with ``"sequential"`` samples are emitted in index order. The mask returned
    by ``next`` is always ``True`` because this source never pads.

    Args:
        data: PyTree whose leaves are arrays with a leading sample dimension.
        ordering: Either ``"sequential"`` or ``"shuffle"``.

    Raises:
        ValueError: If ``data`` has no leaves, a leaf has no leading dimension,
            leaves disagree on the leading dimension, the dataset is empty, or
            ``ordering`` is not recognised.
    """

    data: PyTree
    ordering: Literal["sequential", "shuffle"] = "shuffle"

    def __post_init__(self) -> None:
        leaves, self._treedef = tree_util.tree_flatten(self.data)
        if not leaves:
            raise ValueError("Data tree must contain at least one array.")
        # Validate eagerly so a bad value fails at construction rather than at
        # the first ``init_state`` call.
        if self.ordering not in ("sequential", "shuffle"):
            raise ValueError(f"Unknown ordering '{self.ordering}'.")
        # A 0-d leaf would otherwise fail with an opaque IndexError below.
        if any(len(leaf.shape) == 0 for leaf in leaves):
            raise ValueError("All leaves must have a leading sample dimension.")

        self._num_samples = int(leaves[0].shape[0])
        if any(leaf.shape[0] != self._num_samples for leaf in leaves[1:]):
            raise ValueError("All leaves must share the leading dimension.")
        if self._num_samples == 0:
            raise ValueError("Dataset cannot be empty.")

        self.steps_per_epoch = self._num_samples
        self._mask_template = jnp.ones(self._num_samples, dtype=bool)
        # Report the dtype JAX actually produces when indexing a leaf: NumPy
        # 64-bit inputs become 32-bit when x64 mode is disabled.
        self._element_spec = tree_util.tree_map(
            lambda leaf: jax.ShapeDtypeStruct(
                shape=leaf.shape[1:],
                dtype=jax.dtypes.canonicalize_dtype(leaf.dtype),
            ),
            self.data,
        )

    @property
    def num_samples(self) -> int:
        """Number of samples along the shared leading dimension of ``data``."""
        return self._num_samples

    def element_spec(self) -> PyTree:
        """Shape/dtype metadata describing samples produced by the source."""
        return self._element_spec

    def _build_epoch_indices(self, key: jax.Array) -> tuple[jax.Array, jax.Array]:
        """Return ``(sample_order, mask)`` for one epoch.

        ``key`` is only consumed when ``ordering == "shuffle"``.
        """
        base = jnp.arange(self._num_samples)
        if self.ordering == "shuffle":
            base = jax.random.permutation(key, base)
        elif self.ordering != "sequential":
            # Unreachable after __post_init__ validation unless ``ordering``
            # was mutated afterwards; kept as a safeguard.
            raise ValueError(f"Unknown ordering '{self.ordering}'.")

        return base, self._mask_template

    def init_state(self, key: jax.Array | None = None) -> _ArraySourceState:
        """Create the initial iteration state.

        Args:
            key: Optional PRNG key. Defaults to ``jax.random.PRNGKey(0)`` when
                omitted or ``None``.
        """
        if key is None:
            key = jax.random.PRNGKey(0)
        key, perm_key = jax.random.split(key)
        indices, mask = self._build_epoch_indices(perm_key)
        position = jnp.array(0, dtype=jnp.int32)
        epoch = jnp.array(0, dtype=jnp.int32)
        return _ArraySourceState(indices=indices, mask=mask, position=position, key=key, epoch=epoch)

    def next(self, state: _ArraySourceState) -> tuple[PyTree, jax.Array, _ArraySourceState]:
        """Return the next sample (with mask) and the advanced state.

        After the last sample of an epoch the returned state starts a new epoch
        with freshly built indices (a new permutation when shuffling) and
        ``epoch`` incremented.
        """
        index = jax.lax.dynamic_index_in_dim(state.indices, state.position, axis=0, keepdims=False)
        mask_value = jax.lax.dynamic_index_in_dim(state.mask, state.position, axis=0, keepdims=False)
        sample = tree_util.tree_map(
            lambda arr: jax.lax.dynamic_index_in_dim(arr, index, axis=0, keepdims=False),
            self.data,
        )

        next_position = state.position + 1

        def _reset_epoch(_: None):
            new_key, perm_key = jax.random.split(state.key)
            indices, mask = self._build_epoch_indices(perm_key)
            return _ArraySourceState(
                indices=indices,
                mask=mask,
                position=jnp.array(0, dtype=jnp.int32),
                key=new_key,
                epoch=state.epoch + 1,
            )

        def _advance(_: None):
            return _ArraySourceState(
                indices=state.indices,
                mask=state.mask,
                position=next_position,
                key=state.key,
                epoch=state.epoch,
            )

        need_reset = next_position >= state.indices.shape[0]
        new_state = jax.lax.cond(need_reset, _reset_epoch, _advance, operand=None)
        return sample, mask_value, new_state
220
221
@dataclass
class DiskSource(Source[_DiskSourceState]):
    """Sample-level stream that loads items via a Python callback (disk, RPC, etc.).

    This is slow, only use this if your dataset will not fit in system memory.

    Samples are fetched ``prefetch_size`` at a time through
    :func:`jax.experimental.io_callback` into a fixed-shape JAX array buffer
    held in the state, so ``next`` remains traceable under ``jax.jit``.
    """

    length: int
    """Number of samples in the dataset."""
    sample_fn: Callable[[int], PyTree]
    """Python callable that takes an integer index and returns a PyTree of arrays."""
    sample_spec: PyTree | None = None
    """Optional PyTree of `jax.ShapeDtypeStruct` describing the shape and dtype of samples."""
    ordering: Literal["sequential", "shuffle"] = "shuffle"
    """Sample ordering strategy, either 'sequential' or 'shuffle'. The shuffling occurs over the entire dataset, not within the prefetch buffer."""
    prefetch_size: int = 64
    """Number of samples to prefetch into a JAX array buffer. Set this larger to achieve better throughput at the cost of more memory usage."""

    def __post_init__(self) -> None:
        if self.length <= 0:
            raise ValueError("Dataset cannot be empty.")
        if self.prefetch_size <= 0:
            raise ValueError("prefetch_size must be positive.")
        # Validate eagerly so a bad value fails at construction rather than at
        # the first ``init_state`` call.
        if self.ordering not in ("sequential", "shuffle"):
            raise ValueError(f"Unknown ordering '{self.ordering}'.")

        if self.sample_spec is None:
            # Infer the spec by loading sample 0 once on the host.
            example = self.sample_fn(0)

            def _to_spec(leaf):
                arr = np.asarray(leaf)
                return jax.ShapeDtypeStruct(shape=arr.shape, dtype=arr.dtype)

            self.sample_spec = tree_util.tree_map(_to_spec, example)

        leaves = tree_util.tree_leaves(self.sample_spec)
        if not leaves:
            raise ValueError("sample_spec must include at least one leaf.")
        for leaf in leaves:
            if not isinstance(leaf, jax.ShapeDtypeStruct):
                raise TypeError("sample_spec leaves must be jax.ShapeDtypeStruct instances.")

        def _canonicalize(spec: jax.ShapeDtypeStruct) -> jax.ShapeDtypeStruct:
            # Map dtypes to what JAX will actually store (e.g. float64 -> float32
            # when x64 is disabled) so the buffer template, the io_callback
            # result spec and the host-side chunks all agree.
            dtype = jax.dtypes.canonicalize_dtype(spec.dtype)
            if dtype == spec.dtype:
                return spec
            return jax.ShapeDtypeStruct(shape=spec.shape, dtype=dtype)

        self.sample_spec = tree_util.tree_map(_canonicalize, self.sample_spec)

        self._num_samples = int(self.length)
        self.steps_per_epoch = self._num_samples
        self._element_spec = self.sample_spec
        self.prefetch_size = int(self.prefetch_size)

        # Host-side filler for buffer slots past the end of an epoch.
        self._zero_sample = tree_util.tree_map(
            lambda spec: np.zeros(spec.shape, dtype=np.dtype(spec.dtype)),
            self.sample_spec,
        )
        # Result spec for io_callback: one prefetch chunk per leaf.
        self._chunk_spec = tree_util.tree_map(
            lambda spec: jax.ShapeDtypeStruct(
                shape=(self.prefetch_size, *spec.shape),
                dtype=spec.dtype,
            ),
            self.sample_spec,
        )
        self._buffer_template = tree_util.tree_map(
            lambda spec: jnp.zeros((self.prefetch_size, *spec.shape), dtype=spec.dtype),
            self.sample_spec,
        )

    def element_spec(self) -> PyTree:
        """Shape/dtype metadata describing samples produced by the source."""
        return self._element_spec

    def _build_epoch_indices(self, key: jax.Array) -> jax.Array:
        """Return the order in which sample indices are visited for one epoch."""
        base = jnp.arange(self._num_samples)
        if self.ordering == "shuffle":
            base = jax.random.permutation(key, base)
        elif self.ordering != "sequential":
            # Unreachable after __post_init__ validation unless ``ordering``
            # was mutated afterwards; kept as a safeguard.
            raise ValueError(f"Unknown ordering '{self.ordering}'.")
        return base

    def init_state(self, key: jax.Array | None = None) -> _DiskSourceState:
        """Build the starting state, optionally seeding randomness with ``key``.

        Args:
            key: PRNG key used to order the first epoch. Defaults to
                ``jax.random.PRNGKey(0)`` when omitted or ``None``.

        Returns:
            State with an empty buffer (``buffer_count == 0``), so the first
            ``next`` call triggers a load.
        """
        if key is None:
            key = jax.random.PRNGKey(0)
        key, perm_key = jax.random.split(key)
        indices = self._build_epoch_indices(perm_key)
        position = jnp.array(0, dtype=jnp.int32)
        epoch = jnp.array(0, dtype=jnp.int32)
        return _DiskSourceState(
            indices=indices,
            position=position,
            key=key,
            epoch=epoch,
            buffer=self._buffer_template,
            buffer_pos=jnp.array(0, dtype=jnp.int32),
            buffer_count=jnp.array(0, dtype=jnp.int32),
        )

    def _chunk_callback(self, indices: np.ndarray, mask: np.ndarray) -> PyTree:
        """Host-side loader invoked through ``io_callback``.

        Calls ``sample_fn(idx)`` for every slot whose ``mask`` entry is true,
        fills the remaining slots with zeros, and stacks the results along a
        new leading axis.
        """
        idx_array = np.asarray(indices, dtype=np.int64)
        mask_array = np.asarray(mask, dtype=bool)
        samples = [
            self.sample_fn(int(idx)) if keep else self._zero_sample
            for keep, idx in zip(mask_array, idx_array)
        ]

        def _stack(spec, *leaves):
            # io_callback requires results to match ``_chunk_spec`` exactly, so
            # coerce to the declared dtype (sample_fn may return e.g. int64
            # arrays or Python scalars).
            stacked = np.stack([np.asarray(leaf) for leaf in leaves], axis=0)
            return stacked.astype(np.dtype(spec.dtype), copy=False)

        return tree_util.tree_map(_stack, self.sample_spec, *samples)

    def _maybe_reset_epoch(self, state: _DiskSourceState) -> _DiskSourceState:
        """Start a new epoch (new ordering, empty buffer) once all indices are loaded."""

        def _reset(state: _DiskSourceState):
            new_key, perm_key = jax.random.split(state.key)
            indices = self._build_epoch_indices(perm_key)
            return _DiskSourceState(
                indices=indices,
                position=jnp.array(0, dtype=jnp.int32),
                key=new_key,
                epoch=state.epoch + 1,
                buffer=self._buffer_template,
                buffer_pos=jnp.array(0, dtype=jnp.int32),
                buffer_count=jnp.array(0, dtype=jnp.int32),
            )

        return jax.lax.cond(state.position >= self._num_samples, _reset, lambda s: s, state)

    def _maybe_refill_buffer(self, state: _DiskSourceState) -> _DiskSourceState:
        """Load the next chunk through ``io_callback`` when the buffer is exhausted.

        ``position`` counts epoch indices already loaded into a buffer, and
        ``buffer_count`` is how many slots of the current buffer hold real
        samples. The last chunk of an epoch may be partial.
        """

        def _needs(state: _DiskSourceState):
            return jnp.logical_or(state.buffer_count == 0, state.buffer_pos >= state.buffer_count)

        def _refill(state: _DiskSourceState):
            refreshed = self._maybe_reset_epoch(state)
            remaining = self._num_samples - refreshed.position
            chunk = jnp.minimum(remaining, self.prefetch_size)
            chunk = jnp.maximum(chunk, 0)
            chunk = chunk.astype(jnp.int32)
            offsets = jnp.arange(self.prefetch_size, dtype=jnp.int32)
            # Clamp so the fixed-size gather stays in bounds; slots beyond
            # ``chunk`` are masked out and zero-filled on the host.
            gather_positions = jnp.minimum(
                refreshed.position + offsets,
                refreshed.indices.shape[0] - 1,
            )
            chunk_indices = refreshed.indices[gather_positions]
            valid_mask = offsets < chunk
            buffer = io_callback(self._chunk_callback, self._chunk_spec, chunk_indices, valid_mask)
            new_position = refreshed.position + chunk
            return _DiskSourceState(
                indices=refreshed.indices,
                position=new_position,
                key=refreshed.key,
                epoch=refreshed.epoch,
                buffer=buffer,
                buffer_pos=jnp.array(0, dtype=jnp.int32),
                buffer_count=chunk,
            )

        return jax.lax.cond(_needs(state), _refill, lambda s: s, state)

    def next(self, state: _DiskSourceState) -> tuple[PyTree, jax.Array, _DiskSourceState]:
        """Return buffered sample, all-True mask, and updated state."""
        state = self._maybe_refill_buffer(state)
        sample = tree_util.tree_map(
            lambda buf: jax.lax.dynamic_index_in_dim(
                buf, state.buffer_pos, axis=0, keepdims=False
            ),
            state.buffer,
        )
        mask_value = jnp.array(True, dtype=bool)
        new_state = _DiskSourceState(
            indices=state.indices,
            position=state.position,
            key=state.key,
            epoch=state.epoch,
            buffer=state.buffer,
            buffer_pos=state.buffer_pos + 1,
            buffer_count=state.buffer_count,
        )
        return sample, mask_value, new_state
393
394
395
396
@jax.tree_util.register_pytree_node_class
@dataclass
class GymnaxSourceState:
    """Iteration state for ``GymnaxSource``, registered as a JAX pytree.

    ``policy_state`` and ``new_episode`` default to ``None``; a ``None`` child
    contributes no leaves when the state is flattened.
    """

    env_state: PyTree
    obs: PyTree
    key: jax.Array
    step: jax.Array
    epoch: jax.Array
    policy_state: PyTree | None = None
    new_episode: jax.Array | None = None

    # Children are emitted in field-declaration order, which is also the
    # positional order of the dataclass-generated ``__init__``.
    _CHILD_ORDER = ("env_state", "obs", "key", "step", "epoch", "policy_state", "new_episode")

    def tree_flatten(self):
        return tuple(getattr(self, name) for name in self._CHILD_ORDER), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)
431
432
@dataclass
class GymnaxSource(Source[GymnaxSourceState]):
    """Stream transitions by rolling out a Gymnax environment with a policy.

    Useful for reinforcement learning.

    Args:
        env: Gymnax environment instance.
        env_params: Parameters to pass to the environment's reset and step functions.
        policy_step_fn: Callable that takes (observation, policy_state, new_episode, key) and
            returns (action, new_policy_state).
        policy_state_template: Example PyTree carrying everything required by
            ``policy_step_fn`` (for example, policy parameters and recurrent
            carries). This template is used only to infer the element spec; callers
            are responsible for injecting a real policy state into the loader
            state before calling ``next``.
        steps_per_epoch: Number of environment steps per epoch for a single environment.

    Raises:
        ValueError: If ``steps_per_epoch`` is not positive or
            ``policy_state_template`` is missing.
    """

    env: Any
    env_params: Any
    policy_step_fn: Callable[[PyTree, PyTree, jax.Array, jax.Array], tuple[PyTree, PyTree]]
    policy_state_template: PyTree | None = None
    steps_per_epoch: int = 1024

    def __post_init__(self) -> None:
        if self.steps_per_epoch <= 0:
            raise ValueError("steps_per_epoch must be positive.")
        if self.policy_state_template is None:
            raise ValueError("GymnaxSource requires a policy_state_template for shape inference.")

        def _sample(key, policy_state):
            # One reset + one policy step + one env step, traced abstractly by
            # eval_shape below purely to discover the transition structure.
            obs, env_state = self.env.reset(key, self.env_params)
            action, next_policy_state = self.policy_step_fn(
                obs,
                policy_state,
                jnp.array(True, dtype=jnp.bool_),
                key,
            )
            next_obs, _, reward, done, info = self.env.step(
                key,
                env_state,
                action,
                self.env_params,
            )
            transition = {
                "state": obs,
                "action": action,
                "reward": reward,
                "next_state": next_obs,
                "done": done,
                "info": info,
            }
            return transition, next_policy_state

        shaped, _ = jax.eval_shape(_sample, jax.random.PRNGKey(0), self.policy_state_template)
        self._element_spec = tree_util.tree_map(
            lambda arr: jax.ShapeDtypeStruct(shape=arr.shape, dtype=arr.dtype), shaped
        )
        # The template is only needed for shape inference, so it is dropped here.
        self.policy_state_template = None

    def element_spec(self) -> PyTree:
        """Shape/dtype metadata describing Gymnax transitions."""
        return self._element_spec

    def init_state(self, key: jax.Array | None = None) -> GymnaxSourceState:
        """Reset the environment and return the initial iteration state.

        The returned state has ``policy_state=None``. Callers must set it to a
        real policy state before calling ``next``, which raises otherwise.

        Args:
            key: PRNG key for the environment reset and subsequent steps.
                Defaults to ``jax.random.PRNGKey(0)`` when omitted or ``None``.
        """
        if key is None:
            key = jax.random.PRNGKey(0)
        key, env_key = jax.random.split(key)
        obs, env_state = self.env.reset(env_key, self.env_params)
        return GymnaxSourceState(
            env_state=env_state,
            obs=obs,
            key=key,
            step=jnp.array(0, dtype=jnp.int32),
            epoch=jnp.array(0, dtype=jnp.int32),
            policy_state=None,
            new_episode=jnp.array(True, dtype=jnp.bool_),
        )

    def next(self, state: GymnaxSourceState) -> tuple[PyTree, jax.Array, GymnaxSourceState]:
        """Roll the environment forward one step and emit a transition.

        Returns:
            ``(transition, mask, new_state)``. ``transition`` is a dict with keys
            ``state``, ``action``, ``reward``, ``next_state``, ``done`` and
            ``info``, and ``mask`` is always ``True``. When ``done`` is set the
            environment is reset for the following call. After
            ``steps_per_epoch`` steps the environment is reset, ``step`` returns
            to 0 and ``epoch`` is incremented.

        Raises:
            ValueError: If ``state.policy_state`` or ``state.new_episode`` is ``None``.
        """
        # Structural checks first; under jit these are Python-level and free.
        if state.policy_state is None:
            raise ValueError(
                "GymnaxSource state is missing `policy_state`; set it explicitly before calling `next`."
            )
        if state.new_episode is None:
            raise ValueError("GymnaxSource state is missing `new_episode` flag.")

        key, policy_key, step_key, done_reset_key, epoch_reset_key = jax.random.split(state.key, 5)

        action, updated_policy_state = self.policy_step_fn(
            state.obs,
            state.policy_state,
            state.new_episode,
            policy_key,
        )
        next_obs, next_env_state, reward, done, info = self.env.step(
            step_key,
            state.env_state,
            action,
            self.env_params,
        )

        transition = {
            "state": state.obs,
            "action": action,
            "reward": reward,
            "next_state": next_obs,
            "done": done,
            "info": info,
        }
        mask = jnp.array(True, dtype=bool)

        # ``done`` must be scalar (single environment) for lax.cond below.
        done_flag = jnp.reshape(jnp.asarray(done, dtype=bool), ())
        # Computed unconditionally so both cond branches can close over plain values.
        reset_obs, reset_env_state = self.env.reset(done_reset_key, self.env_params)

        cont_obs, cont_env_state = jax.lax.cond(
            done_flag,
            lambda _: (reset_obs, reset_env_state),
            lambda _: (next_obs, next_env_state),
            operand=None,
        )

        next_step = state.step + 1
        need_epoch_reset = next_step >= self.steps_per_epoch

        def _reset_epoch(_: None):
            epoch_obs, epoch_env_state = self.env.reset(epoch_reset_key, self.env_params)
            return GymnaxSourceState(
                env_state=epoch_env_state,
                obs=epoch_obs,
                key=key,
                step=jnp.array(0, dtype=jnp.int32),
                epoch=state.epoch + 1,
                policy_state=updated_policy_state,
                new_episode=jnp.array(True, dtype=jnp.bool_),
            )

        def _continue(_: None):
            return GymnaxSourceState(
                env_state=cont_env_state,
                obs=cont_obs,
                key=key,
                step=next_step,
                epoch=state.epoch,
                policy_state=updated_policy_state,
                new_episode=done_flag,
            )

        new_state = jax.lax.cond(need_epoch_reset, _reset_epoch, _continue, operand=None)
        return transition, mask, new_state
PyTree = typing.Any
class Source(typing.Protocol[~StateT]):
class Source(Protocol[StateT]):
    """Structural interface shared by stateful, JIT-compatible data streams.

    Iteration progress lives in an explicit state object that callers thread
    through ``next``. Each epoch has a fixed length (``steps_per_epoch``), and
    every emitted value is paired with a boolean mask flagging whether it is a
    valid sample (for example, ``False`` for padding).
    """

    steps_per_epoch: int
    """How many items a single epoch yields."""

    def init_state(self, key: jax.Array) -> StateT:
        """Build the state consumed by the first ``next`` call.

        Args:
            key: PRNG key for randomized behavior such as shuffling. When
                ``None`` is passed, implementations are expected to substitute
                a default key.

        Returns:
            Implementation-specific state to pass to ``next``.
        """
        ...

    def next(self, state: StateT) -> tuple[PyTree, jax.Array, StateT]:
        """Produce one item together with the state that follows it.

        Args:
            state: State returned by ``init_state`` or a previous ``next`` call.

        Returns:
            ``(value, mask, new_state)``: ``value`` is a PyTree of arrays,
            ``mask`` is a boolean array that is ``True`` for valid samples, and
            ``new_state`` must be handed to the following call.
        """
        ...

    def element_spec(self) -> PyTree:
        """Describe emitted values as a PyTree of :class:`jax.ShapeDtypeStruct`."""
        ...

Interface for stateful, JIT-friendly data streams.

A source exposes fixed-length epochs (steps_per_epoch) and provides methods for initializing internal state and iteratively producing samples plus boolean masks that denote whether a sample is valid (e.g. under padding).

Source(*args, **kwargs)
1866def _no_init_or_replace_init(self, *args, **kwargs):
1867    cls = type(self)
1868
1869    if cls._is_protocol:
1870        raise TypeError('Protocols cannot be instantiated')
1871
1872    # Already using a custom `__init__`. No need to calculate correct
1873    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
1874    if cls.__init__ is not _no_init_or_replace_init:
1875        return
1876
1877    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
1878    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
1879    # searches for a proper new `__init__` in the MRO. The new `__init__`
1880    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
1881    # instantiation of the protocol subclass will thus use the new
1882    # `__init__` and no longer call `_no_init_or_replace_init`.
1883    for base in cls.__mro__:
1884        init = base.__dict__.get('__init__', _no_init_or_replace_init)
1885        if init is not _no_init_or_replace_init:
1886            cls.__init__ = init
1887            break
1888    else:
1889        # should not happen
1890        cls.__init__ = object.__init__
1891
1892    cls.__init__(self, *args, **kwargs)
steps_per_epoch: int

Number of items emitted per epoch.

def init_state(self, key: jax.Array) -> ~StateT:
30    def init_state(self, key: jax.Array) -> StateT:
31        """Return an initial state for the source.
32
33        Args:
34            key: Optional PRNG key used for randomized behavior such as
35                shuffling. Implementations should fall back to a default key
36                when ``None`` is provided.
37
38        Returns:
39            Backend-specific state object that must be passed to ``next``.
40        """
41        ...

Return an initial state for the source.

Args:
    key: Optional PRNG key used for randomized behavior such as shuffling. Implementations should fall back to a default key when `None` is provided.

Returns: Backend-specific state object that must be passed to next.

def next(self, state: ~StateT) -> tuple[typing.Any, jax.Array, ~StateT]:
43    def next(self, state: StateT) -> tuple[PyTree, jax.Array, StateT]:
44        """Advance the stream and return the next value.
45
46        Args:
47            state: Previously-initialized source state.
48
49        Returns:
50            Tuple ``(value, mask, new_state)`` where ``value`` is a PyTree of
51            arrays, ``mask`` is a boolean array indicating whether the sample
52            is valid, and ``new_state`` should be provided to the next call.
53        """
54        ...

Advance the stream and return the next value.

Args: state: Previously-initialized source state.

Returns:
    Tuple `(value, mask, new_state)`, where `value` is a PyTree of arrays, `mask` is a boolean array indicating whether the sample is valid, and `new_state` should be provided to the next call.

def element_spec(self) -> Any:
56    def element_spec(self) -> PyTree:
57        """PyTree of :class:`jax.ShapeDtypeStruct` describing emitted samples."""
58        ...

PyTree of jax.ShapeDtypeStruct describing emitted samples.

@dataclass
class ArraySource(cyreal.sources.Source[cyreal.sources._ArraySourceState]):
@dataclass
class ArraySource(Source[_ArraySourceState]):
    """Sample-level stream over an in-memory PyTree of arrays.

    Every epoch visits each sample exactly once. With ``ordering="shuffle"`` a
    fresh permutation is drawn from the state key at the start of each epoch;
    with ``"sequential"`` samples are emitted in index order. The mask returned
    by ``next`` is always ``True`` because this source never pads.

    Args:
        data: PyTree whose leaves are arrays with a leading sample dimension.
        ordering: Either ``"sequential"`` or ``"shuffle"``.

    Raises:
        ValueError: If ``data`` has no leaves, a leaf has no leading dimension,
            leaves disagree on the leading dimension, the dataset is empty, or
            ``ordering`` is not recognised.
    """

    data: PyTree
    ordering: Literal["sequential", "shuffle"] = "shuffle"

    def __post_init__(self) -> None:
        leaves, self._treedef = tree_util.tree_flatten(self.data)
        if not leaves:
            raise ValueError("Data tree must contain at least one array.")
        # Validate eagerly so a bad value fails at construction rather than at
        # the first ``init_state`` call.
        if self.ordering not in ("sequential", "shuffle"):
            raise ValueError(f"Unknown ordering '{self.ordering}'.")
        # A 0-d leaf would otherwise fail with an opaque IndexError below.
        if any(len(leaf.shape) == 0 for leaf in leaves):
            raise ValueError("All leaves must have a leading sample dimension.")

        self._num_samples = int(leaves[0].shape[0])
        if any(leaf.shape[0] != self._num_samples for leaf in leaves[1:]):
            raise ValueError("All leaves must share the leading dimension.")
        if self._num_samples == 0:
            raise ValueError("Dataset cannot be empty.")

        self.steps_per_epoch = self._num_samples
        self._mask_template = jnp.ones(self._num_samples, dtype=bool)
        # Report the dtype JAX actually produces when indexing a leaf: NumPy
        # 64-bit inputs become 32-bit when x64 mode is disabled.
        self._element_spec = tree_util.tree_map(
            lambda leaf: jax.ShapeDtypeStruct(
                shape=leaf.shape[1:],
                dtype=jax.dtypes.canonicalize_dtype(leaf.dtype),
            ),
            self.data,
        )

    @property
    def num_samples(self) -> int:
        """Number of samples along the shared leading dimension of ``data``."""
        return self._num_samples

    def element_spec(self) -> PyTree:
        """Shape/dtype metadata describing samples produced by the source."""
        return self._element_spec

    def _build_epoch_indices(self, key: jax.Array) -> tuple[jax.Array, jax.Array]:
        """Return ``(sample_order, mask)`` for one epoch.

        ``key`` is only consumed when ``ordering == "shuffle"``.
        """
        base = jnp.arange(self._num_samples)
        if self.ordering == "shuffle":
            base = jax.random.permutation(key, base)
        elif self.ordering != "sequential":
            # Unreachable after __post_init__ validation unless ``ordering``
            # was mutated afterwards; kept as a safeguard.
            raise ValueError(f"Unknown ordering '{self.ordering}'.")

        return base, self._mask_template

    def init_state(self, key: jax.Array | None = None) -> _ArraySourceState:
        """Create the initial iteration state.

        Args:
            key: Optional PRNG key. Defaults to ``jax.random.PRNGKey(0)`` when
                omitted or ``None``.
        """
        if key is None:
            key = jax.random.PRNGKey(0)
        key, perm_key = jax.random.split(key)
        indices, mask = self._build_epoch_indices(perm_key)
        position = jnp.array(0, dtype=jnp.int32)
        epoch = jnp.array(0, dtype=jnp.int32)
        return _ArraySourceState(indices=indices, mask=mask, position=position, key=key, epoch=epoch)

    def next(self, state: _ArraySourceState) -> tuple[PyTree, jax.Array, _ArraySourceState]:
        """Return the next sample (with mask) and the advanced state.

        After the last sample of an epoch the returned state starts a new epoch
        with freshly built indices (a new permutation when shuffling) and
        ``epoch`` incremented.
        """
        index = jax.lax.dynamic_index_in_dim(state.indices, state.position, axis=0, keepdims=False)
        mask_value = jax.lax.dynamic_index_in_dim(state.mask, state.position, axis=0, keepdims=False)
        sample = tree_util.tree_map(
            lambda arr: jax.lax.dynamic_index_in_dim(arr, index, axis=0, keepdims=False),
            self.data,
        )

        next_position = state.position + 1

        def _reset_epoch(_: None):
            new_key, perm_key = jax.random.split(state.key)
            indices, mask = self._build_epoch_indices(perm_key)
            return _ArraySourceState(
                indices=indices,
                mask=mask,
                position=jnp.array(0, dtype=jnp.int32),
                key=new_key,
                epoch=state.epoch + 1,
            )

        def _advance(_: None):
            return _ArraySourceState(
                indices=state.indices,
                mask=state.mask,
                position=next_position,
                key=state.key,
                epoch=state.epoch,
            )

        need_reset = next_position >= state.indices.shape[0]
        new_state = jax.lax.cond(need_reset, _reset_epoch, _advance, operand=None)
        return sample, mask_value, new_state

Sample-level stream over an in-memory PyTree of arrays.

Args:
    data: PyTree whose leaves are arrays with a leading sample dimension.
    ordering: Either `"sequential"` or `"shuffle"`.

ArraySource(data: Any, ordering: Literal['sequential', 'shuffle'] = 'shuffle')
data: Any
ordering: Literal['sequential', 'shuffle'] = 'shuffle'
num_samples: int
157    @property
158    def num_samples(self) -> int:
159        return self._num_samples
def element_spec(self) -> Any:
161    def element_spec(self) -> PyTree:
162        """Shape/dtype metadata describing samples produced by the source."""
163        return self._element_spec

Shape/dtype metadata describing samples produced by the source.

def init_state(self, key: jax.Array) -> cyreal.sources._ArraySourceState:
174    def init_state(self, key: jax.Array) -> _ArraySourceState:
175        """Create the initial iteration state.
176
177        Args:
178            key: Optional PRNG key. Defaults to ``jax.random.PRNGKey(0)`` when
179                omitted.
180        """
181        key, perm_key = jax.random.split(key)
182        indices, mask = self._build_epoch_indices(perm_key)
183        position = jnp.array(0, dtype=jnp.int32)
184        epoch = jnp.array(0, dtype=jnp.int32)
185        return _ArraySourceState(indices=indices, mask=mask, position=position, key=key, epoch=epoch)

Create the initial iteration state.

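Args:
- key: PRNG key used to build the first epoch's indices and to derive the keys for later epochs. This argument is required: it is passed directly to jax.random.split, so no default key is substituted when None is given.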

def next( self, state: cyreal.sources._ArraySourceState) -> tuple[typing.Any, jax.Array, cyreal.sources._ArraySourceState]:
187    def next(self, state: _ArraySourceState) -> tuple[PyTree, jax.Array, _ArraySourceState]:
188        """Return the next sample (with mask) and the advanced state."""
189        index = jax.lax.dynamic_index_in_dim(state.indices, state.position, axis=0, keepdims=False)
190        mask_value = jax.lax.dynamic_index_in_dim(state.mask, state.position, axis=0, keepdims=False)
191        sample = tree_util.tree_map(
192            lambda arr: jax.lax.dynamic_index_in_dim(arr, index, axis=0, keepdims=False),
193            self.data,
194        )
195
196        next_position = state.position + 1
197
198        def _reset_epoch(_: None):
199            new_key, perm_key = jax.random.split(state.key)
200            indices, mask = self._build_epoch_indices(perm_key)
201            return _ArraySourceState(
202                indices=indices,
203                mask=mask,
204                position=jnp.array(0, dtype=jnp.int32),
205                key=new_key,
206                epoch=state.epoch + 1,
207            )
208
209        def _advance(_: None):
210            return _ArraySourceState(
211                indices=state.indices,
212                mask=state.mask,
213                position=next_position,
214                key=state.key,
215                epoch=state.epoch,
216            )
217
218        need_reset = next_position >= state.indices.shape[0]
219        new_state = jax.lax.cond(need_reset, _reset_epoch, _advance, operand=None)
220        return sample, mask_value, new_state

Return the next sample (with mask) and the advanced state.

Inherited Members
Source
steps_per_epoch
@dataclass
class DiskSource(cyreal.sources.Source[cyreal.sources._DiskSourceState]):
223@dataclass
224class DiskSource(Source[_DiskSourceState]):
225    """Sample-level stream that loads items via a Python callback (disk, RPC, etc.).
226
227    This is slow, only use this if your dataset will not fit in system memory.
228    """
229
230    length: int
231    """Number of samples in the dataset."""
232    sample_fn: Callable[[int], PyTree]
233    """Python callable that takes an integer index and returns a PyTree of arrays."""
234    sample_spec: PyTree | None = None
235    """Optional PyTree of `jax.ShapeDtypeStruct` describing the shape and dtype of samples."""
236    ordering: Literal["sequential", "shuffle"] = "shuffle"
237    """Sample ordering strategy, either 'sequential' or 'shuffle'. The shuffling occurs over the entire dataset, not within the prefetch buffer."""
238    prefetch_size: int = 64
239    """Number of samples to prefetch into a JAX array buffer. Set this larger to achieve better throughput at the cost of more memory usage."""
240
241    def __post_init__(self) -> None:
242        if self.length <= 0:
243            raise ValueError("Dataset cannot be empty.")
244        if self.prefetch_size <= 0:
245            raise ValueError("prefetch_size must be positive.")
246
247        if self.sample_spec is None:
248            example = self.sample_fn(0)
249
250            def _to_spec(leaf):
251                arr = np.asarray(leaf)
252                return jax.ShapeDtypeStruct(shape=arr.shape, dtype=arr.dtype)
253
254            self.sample_spec = tree_util.tree_map(_to_spec, example)
255
256        leaves = tree_util.tree_leaves(self.sample_spec)
257        if not leaves:
258            raise ValueError("element_spec must include at least one leaf.")
259        for leaf in leaves:
260            if not isinstance(leaf, jax.ShapeDtypeStruct):
261                raise TypeError("element_spec leaves must be jax.ShapeDtypeStruct instances.")
262
263        self._num_samples = int(self.length)
264        self.steps_per_epoch = self._num_samples
265        self._element_spec = self.sample_spec
266        self.prefetch_size = int(self.prefetch_size)
267
268        def _zeros(spec: jax.ShapeDtypeStruct):
269            return np.zeros(spec.shape, dtype=np.dtype(spec.dtype))
270
271        def _buffer_shape(spec: jax.ShapeDtypeStruct):
272            return jax.ShapeDtypeStruct(
273                shape=(self.prefetch_size, *spec.shape),
274                dtype=spec.dtype,
275            )
276
277        self._zero_sample = tree_util.tree_map(_zeros, self.sample_spec)
278        self._chunk_spec = tree_util.tree_map(_buffer_shape, self.sample_spec)
279        self._buffer_template = tree_util.tree_map(
280            lambda spec: jnp.zeros((self.prefetch_size, *spec.shape), dtype=spec.dtype),
281            self.sample_spec,
282        )
283
284    def element_spec(self) -> PyTree:
285        """Shape/dtype metadata describing samples produced by the source."""
286        return self._element_spec
287
288    def _build_epoch_indices(self, key: jax.Array) -> jax.Array:
289        base = jnp.arange(self._num_samples)
290        if self.ordering == "shuffle":
291            base = jax.random.permutation(key, base)
292        elif self.ordering != "sequential":
293            raise ValueError(f"Unknown ordering '{self.ordering}'.")
294        return base
295
296    def init_state(self, key: jax.Array) -> _DiskSourceState:
297        """Build the starting state, optionally seeding randomness with ``key``."""
298        key, perm_key = jax.random.split(key)
299        indices = self._build_epoch_indices(perm_key)
300        position = jnp.array(0, dtype=jnp.int32)
301        epoch = jnp.array(0, dtype=jnp.int32)
302        return _DiskSourceState(
303            indices=indices,
304            position=position,
305            key=key,
306            epoch=epoch,
307            buffer=self._buffer_template,
308            buffer_pos=jnp.array(0, dtype=jnp.int32),
309            buffer_count=jnp.array(0, dtype=jnp.int32),
310        )
311
312    def _chunk_callback(self, indices: np.ndarray, mask: np.ndarray) -> PyTree:
313        idx_array = np.asarray(indices, dtype=np.int64)
314        mask_array = np.asarray(mask, dtype=bool)
315        samples: list[PyTree] = []
316        for keep, idx in zip(mask_array, idx_array):
317            if keep:
318                samples.append(self.sample_fn(int(idx)))
319            else:
320                samples.append(self._zero_sample)
321        return tree_util.tree_map(lambda *xs: np.stack(xs, axis=0), *samples)
322
323    def _maybe_reset_epoch(self, state: _DiskSourceState) -> _DiskSourceState:
324        def _reset(state: _DiskSourceState):
325            new_key, perm_key = jax.random.split(state.key)
326            indices = self._build_epoch_indices(perm_key)
327            return _DiskSourceState(
328                indices=indices,
329                position=jnp.array(0, dtype=jnp.int32),
330                key=new_key,
331                epoch=state.epoch + 1,
332                buffer=self._buffer_template,
333                buffer_pos=jnp.array(0, dtype=jnp.int32),
334                buffer_count=jnp.array(0, dtype=jnp.int32),
335            )
336
337        return jax.lax.cond(state.position >= self._num_samples, _reset, lambda s: s, state)
338
339    def _maybe_refill_buffer(self, state: _DiskSourceState) -> _DiskSourceState:
340        def _needs(state: _DiskSourceState):
341            return jnp.logical_or(state.buffer_count == 0, state.buffer_pos >= state.buffer_count)
342
343        def _refill(state: _DiskSourceState):
344            refreshed = self._maybe_reset_epoch(state)
345            remaining = self._num_samples - refreshed.position
346            chunk = jnp.minimum(remaining, self.prefetch_size)
347            chunk = jnp.maximum(chunk, 0)
348            chunk = chunk.astype(jnp.int32)
349            offsets = jnp.arange(self.prefetch_size, dtype=jnp.int32)
350            gather_positions = jnp.minimum(
351                refreshed.position + offsets,
352                refreshed.indices.shape[0] - 1,
353            )
354            chunk_indices = jax.vmap(
355                lambda idx: jax.lax.dynamic_index_in_dim(
356                    refreshed.indices, idx, axis=0, keepdims=False
357                )
358            )(gather_positions)
359            valid_mask = offsets < chunk
360            buffer = io_callback(self._chunk_callback, self._chunk_spec, chunk_indices, valid_mask)
361            new_position = refreshed.position + chunk
362            return _DiskSourceState(
363                indices=refreshed.indices,
364                position=new_position,
365                key=refreshed.key,
366                epoch=refreshed.epoch,
367                buffer=buffer,
368                buffer_pos=jnp.array(0, dtype=jnp.int32),
369                buffer_count=chunk,
370            )
371
372        return jax.lax.cond(_needs(state), _refill, lambda s: s, state)
373
374    def next(self, state: _DiskSourceState) -> tuple[PyTree, jax.Array, _DiskSourceState]:
375        """Return buffered sample, all-True mask, and updated state."""
376        state = self._maybe_refill_buffer(state)
377        sample = tree_util.tree_map(
378            lambda buf: jax.lax.dynamic_index_in_dim(
379                buf, state.buffer_pos, axis=0, keepdims=False
380            ),
381            state.buffer,
382        )
383        mask_value = jnp.array(True, dtype=bool)
384        new_state = _DiskSourceState(
385            indices=state.indices,
386            position=state.position,
387            key=state.key,
388            epoch=state.epoch,
389            buffer=state.buffer,
390            buffer_pos=state.buffer_pos + 1,
391            buffer_count=state.buffer_count,
392        )
393        return sample, mask_value, new_state

Sample-level stream that loads items via a Python callback (disk, RPC, etc.).

This is slow; only use it if your dataset does not fit in system memory.

DiskSource( length: int, sample_fn: Callable[[int], Any], sample_spec: Any | None = None, ordering: Literal['sequential', 'shuffle'] = 'shuffle', prefetch_size: int = 64)
length: int

Number of samples in the dataset.

sample_fn: Callable[[int], Any]

Python callable that takes an integer index and returns a PyTree of arrays.

sample_spec: Any | None = None

Optional PyTree of jax.ShapeDtypeStruct describing the shape and dtype of samples. If omitted, it is inferred at construction time by calling sample_fn(0).

ordering: Literal['sequential', 'shuffle'] = 'shuffle'

Sample ordering strategy, either 'sequential' or 'shuffle'. The shuffling occurs over the entire dataset, not within the prefetch buffer.

prefetch_size: int = 64

Number of samples to prefetch into a JAX array buffer. Set this larger to achieve better throughput at the cost of more memory usage.

def element_spec(self) -> Any:
284    def element_spec(self) -> PyTree:
285        """Shape/dtype metadata describing samples produced by the source."""
286        return self._element_spec

Shape/dtype metadata describing samples produced by the source.

def init_state(self, key: jax.Array) -> cyreal.sources._DiskSourceState:
296    def init_state(self, key: jax.Array) -> _DiskSourceState:
297        """Build the starting state, optionally seeding randomness with ``key``."""
298        key, perm_key = jax.random.split(key)
299        indices = self._build_epoch_indices(perm_key)
300        position = jnp.array(0, dtype=jnp.int32)
301        epoch = jnp.array(0, dtype=jnp.int32)
302        return _DiskSourceState(
303            indices=indices,
304            position=position,
305            key=key,
306            epoch=epoch,
307            buffer=self._buffer_template,
308            buffer_pos=jnp.array(0, dtype=jnp.int32),
309            buffer_count=jnp.array(0, dtype=jnp.int32),
310        )

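Build the starting state from key, which seeds the first epoch's ordering (shuffled when ordering is 'shuffle') and the keys for later epochs. The key is required: it is passed directly to jax.random.split, so no default key is substituted when None is given.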

def next( self, state: cyreal.sources._DiskSourceState) -> tuple[typing.Any, jax.Array, cyreal.sources._DiskSourceState]:
374    def next(self, state: _DiskSourceState) -> tuple[PyTree, jax.Array, _DiskSourceState]:
375        """Return buffered sample, all-True mask, and updated state."""
376        state = self._maybe_refill_buffer(state)
377        sample = tree_util.tree_map(
378            lambda buf: jax.lax.dynamic_index_in_dim(
379                buf, state.buffer_pos, axis=0, keepdims=False
380            ),
381            state.buffer,
382        )
383        mask_value = jnp.array(True, dtype=bool)
384        new_state = _DiskSourceState(
385            indices=state.indices,
386            position=state.position,
387            key=state.key,
388            epoch=state.epoch,
389            buffer=state.buffer,
390            buffer_pos=state.buffer_pos + 1,
391            buffer_count=state.buffer_count,
392        )
393        return sample, mask_value, new_state

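Return the next buffered sample, an always-True mask, and the updated state. When the prefetch buffer is exhausted, it is first refilled by invoking sample_fn through jax.experimental.io_callback.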

Inherited Members
Source
steps_per_epoch
@jax.tree_util.register_pytree_node_class
@dataclass
class GymnaxSourceState:
398@jax.tree_util.register_pytree_node_class
399@dataclass
400class GymnaxSourceState:
401    env_state: PyTree
402    obs: PyTree
403    key: jax.Array
404    step: jax.Array
405    epoch: jax.Array
406    policy_state: PyTree | None = None
407    new_episode: jax.Array | None = None
408
409    def tree_flatten(self):
410        return (
411            self.env_state,
412            self.obs,
413            self.key,
414            self.step,
415            self.epoch,
416            self.policy_state,
417            self.new_episode,
418        ), None
419
420    @classmethod
421    def tree_unflatten(cls, aux_data, children):
422        env_state, obs, key, step, epoch, policy_state, new_episode = children
423        return cls(
424            env_state=env_state,
425            obs=obs,
426            key=key,
427            step=step,
428            epoch=epoch,
429            policy_state=policy_state,
430            new_episode=new_episode,
431        )
GymnaxSourceState( env_state: Any, obs: Any, key: jax.Array, step: jax.Array, epoch: jax.Array, policy_state: Any | None = None, new_episode: jax.Array | None = None)
env_state: Any
obs: Any
key: jax.Array
step: jax.Array
epoch: jax.Array
policy_state: Any | None = None
new_episode: jax.Array | None = None
def tree_flatten(self):
409    def tree_flatten(self):
410        return (
411            self.env_state,
412            self.obs,
413            self.key,
414            self.step,
415            self.epoch,
416            self.policy_state,
417            self.new_episode,
418        ), None
@classmethod
def tree_unflatten(cls, aux_data, children):
420    @classmethod
421    def tree_unflatten(cls, aux_data, children):
422        env_state, obs, key, step, epoch, policy_state, new_episode = children
423        return cls(
424            env_state=env_state,
425            obs=obs,
426            key=key,
427            step=step,
428            epoch=epoch,
429            policy_state=policy_state,
430            new_episode=new_episode,
431        )
@dataclass
class GymnaxSource(cyreal.sources.Source[cyreal.sources.GymnaxSourceState]):
434@dataclass
435class GymnaxSource(Source[GymnaxSourceState]):
436    """Stream transitions by rolling out a Gymnax environment with a policy.
437    
438    Useful for reinforcement learning.
439
440    Args:
441        env: Gymnax environment instance.
442        env_params: Parameters to pass to the environment's reset and step functions.
443        policy_step_fn: Callable that takes (observation, policy_state, new_episode, key) and
444            returns (action, new_policy_state).
445        policy_state_template: Example PyTree carrying everything required by
446            ``policy_step_fn`` (for example, policy parameters and recurrent
447            carries). This template is used only to infer the element spec; callers
448            are responsible for injecting a real policy state into the loader
449            state before calling ``next``.
450        steps_per_epoch: Number of environment steps per epoch for a single environment.
451    """
452
453    env: Any
454    env_params: Any
455    policy_step_fn: Callable[[PyTree, PyTree, jax.Array, jax.Array], tuple[PyTree, PyTree]]
456    policy_state_template: PyTree | None = None
457    steps_per_epoch: int = 1024
458
459    def __post_init__(self) -> None:
460        if self.steps_per_epoch <= 0:
461            raise ValueError("steps_per_epoch must be positive.")
462        if self.policy_state_template is None:
463            raise ValueError("GymnaxSource requires a policy_state_template for shape inference.")
464
465        def _sample(key, policy_state):
466            obs, env_state = self.env.reset(key, self.env_params)
467            action, next_policy_state = self.policy_step_fn(
468                obs,
469                policy_state,
470                jnp.array(True, dtype=jnp.bool_),
471                key,
472            )
473            next_obs, _, reward, done, info = self.env.step(
474                key,
475                env_state,
476                action,
477                self.env_params,
478            )
479            transition = {
480                "state": obs,
481                "action": action,
482                "reward": reward,
483                "next_state": next_obs,
484                "done": done,
485                "info": info,
486            }
487            return transition, next_policy_state
488
489        shaped, _ = jax.eval_shape(_sample, jax.random.PRNGKey(0), self.policy_state_template)
490        self._element_spec = tree_util.tree_map(
491            lambda arr: jax.ShapeDtypeStruct(shape=arr.shape, dtype=arr.dtype), shaped
492        )
493        self.policy_state_template = None
494
495    def element_spec(self) -> PyTree:
496        """Shape/dtype metadata describing Gymnax transitions."""
497        return self._element_spec
498
499    def init_state(self, key: jax.Array) -> GymnaxSourceState:
500        """Return RNG-seeded environment + policy state for iteration."""
501        key, env_key = jax.random.split(key)
502        obs, env_state = self.env.reset(env_key, self.env_params)
503        return GymnaxSourceState(
504            env_state=env_state,
505            obs=obs,
506            key=key,
507            step=jnp.array(0, dtype=jnp.int32),
508            epoch=jnp.array(0, dtype=jnp.int32),
509            policy_state=None,
510            new_episode=jnp.array(True, dtype=jnp.bool_),
511        )
512
513    def next(self, state: GymnaxSourceState) -> tuple[PyTree, jax.Array, GymnaxSourceState]:
514        """Roll the environment forward one step and emit a transition."""
515        key, policy_key, step_key, done_reset_key, epoch_reset_key = jax.random.split(state.key, 5)
516
517        if state.policy_state is None:
518            raise ValueError(
519                "GymnaxSource state is missing `policy_state`; set it explicitly before calling `next`."
520            )
521        policy_state = state.policy_state
522
523        if state.new_episode is None:
524            raise ValueError("GymnaxSource state is missing `new_episode` flag.")
525
526        action, updated_policy_state = self.policy_step_fn(
527            state.obs,
528            policy_state,
529            state.new_episode,
530            policy_key,
531        )
532        next_obs, next_env_state, reward, done, info = self.env.step(
533            step_key,
534            state.env_state,
535            action,
536            self.env_params,
537        )
538
539        transition = {
540            "state": state.obs,
541            "action": action,
542            "reward": reward,
543            "next_state": next_obs,
544            "done": done,
545            "info": info,
546        }
547        mask = jnp.array(True, dtype=bool)
548
549        done_flag = jnp.asarray(done, dtype=bool)
550        done_flag = jnp.reshape(done_flag, ())
551        reset_obs, reset_env_state = self.env.reset(done_reset_key, self.env_params)
552
553        cont_obs, cont_env_state = jax.lax.cond(
554            done_flag,
555            lambda _: (reset_obs, reset_env_state),
556            lambda _: (next_obs, next_env_state),
557            operand=None,
558        )
559
560        next_step = state.step + 1
561        need_epoch_reset = next_step >= self.steps_per_epoch
562
563        def _reset_epoch(_: None):
564            epoch_obs, epoch_env_state = self.env.reset(epoch_reset_key, self.env_params)
565            return GymnaxSourceState(
566                env_state=epoch_env_state,
567                obs=epoch_obs,
568                key=key,
569                step=jnp.array(0, dtype=jnp.int32),
570                epoch=state.epoch + 1,
571                policy_state=updated_policy_state,
572                new_episode=jnp.array(True, dtype=jnp.bool_),
573            )
574
575        def _continue(_: None):
576            return GymnaxSourceState(
577                env_state=cont_env_state,
578                obs=cont_obs,
579                key=key,
580                step=next_step,
581                epoch=state.epoch,
582                policy_state=updated_policy_state,
583                new_episode=done_flag,
584            )
585
586        new_state = jax.lax.cond(need_epoch_reset, _reset_epoch, _continue, operand=None)
587        return transition, mask, new_state

Stream transitions by rolling out a Gymnax environment with a policy.

Useful for reinforcement learning.

Args:
- env: Gymnax environment instance.
- env_params: Parameters to pass to the environment's reset and step functions.
- policy_step_fn: Callable that takes (observation, policy_state, new_episode, key) and returns (action, new_policy_state).
- policy_state_template: Example PyTree carrying everything required by policy_step_fn (for example, policy parameters and recurrent carries). It is required; a ValueError is raised if it is omitted. It is used only to infer the element spec and is reset to None once construction finishes. Callers are responsible for injecting a real policy state into the loader state before calling next.
- steps_per_epoch: Number of environment steps per epoch for a single environment.

GymnaxSource( env: Any, env_params: Any, policy_step_fn: Callable[[Any, Any, jax.Array, jax.Array], tuple[Any, Any]], policy_state_template: Any | None = None, steps_per_epoch: int = 1024)
env: Any
env_params: Any
policy_step_fn: Callable[[Any, Any, jax.Array, jax.Array], tuple[Any, Any]]
policy_state_template: Any | None = None
steps_per_epoch: int = 1024

Number of items emitted per epoch.

def element_spec(self) -> Any:
495    def element_spec(self) -> PyTree:
496        """Shape/dtype metadata describing Gymnax transitions."""
497        return self._element_spec

Shape/dtype metadata describing Gymnax transitions.

def init_state(self, key: jax.Array) -> GymnaxSourceState:
499    def init_state(self, key: jax.Array) -> GymnaxSourceState:
500        """Return RNG-seeded environment + policy state for iteration."""
501        key, env_key = jax.random.split(key)
502        obs, env_state = self.env.reset(env_key, self.env_params)
503        return GymnaxSourceState(
504            env_state=env_state,
505            obs=obs,
506            key=key,
507            step=jnp.array(0, dtype=jnp.int32),
508            epoch=jnp.array(0, dtype=jnp.int32),
509            policy_state=None,
510            new_episode=jnp.array(True, dtype=jnp.bool_),
511        )

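Return an RNG-seeded initial state with a freshly reset environment. The returned state's policy_state is None; set it before calling next, which raises ValueError otherwise.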

def next( self, state: GymnaxSourceState) -> tuple[typing.Any, jax.Array, GymnaxSourceState]:
513    def next(self, state: GymnaxSourceState) -> tuple[PyTree, jax.Array, GymnaxSourceState]:
514        """Roll the environment forward one step and emit a transition."""
515        key, policy_key, step_key, done_reset_key, epoch_reset_key = jax.random.split(state.key, 5)
516
517        if state.policy_state is None:
518            raise ValueError(
519                "GymnaxSource state is missing `policy_state`; set it explicitly before calling `next`."
520            )
521        policy_state = state.policy_state
522
523        if state.new_episode is None:
524            raise ValueError("GymnaxSource state is missing `new_episode` flag.")
525
526        action, updated_policy_state = self.policy_step_fn(
527            state.obs,
528            policy_state,
529            state.new_episode,
530            policy_key,
531        )
532        next_obs, next_env_state, reward, done, info = self.env.step(
533            step_key,
534            state.env_state,
535            action,
536            self.env_params,
537        )
538
539        transition = {
540            "state": state.obs,
541            "action": action,
542            "reward": reward,
543            "next_state": next_obs,
544            "done": done,
545            "info": info,
546        }
547        mask = jnp.array(True, dtype=bool)
548
549        done_flag = jnp.asarray(done, dtype=bool)
550        done_flag = jnp.reshape(done_flag, ())
551        reset_obs, reset_env_state = self.env.reset(done_reset_key, self.env_params)
552
553        cont_obs, cont_env_state = jax.lax.cond(
554            done_flag,
555            lambda _: (reset_obs, reset_env_state),
556            lambda _: (next_obs, next_env_state),
557            operand=None,
558        )
559
560        next_step = state.step + 1
561        need_epoch_reset = next_step >= self.steps_per_epoch
562
563        def _reset_epoch(_: None):
564            epoch_obs, epoch_env_state = self.env.reset(epoch_reset_key, self.env_params)
565            return GymnaxSourceState(
566                env_state=epoch_env_state,
567                obs=epoch_obs,
568                key=key,
569                step=jnp.array(0, dtype=jnp.int32),
570                epoch=state.epoch + 1,
571                policy_state=updated_policy_state,
572                new_episode=jnp.array(True, dtype=jnp.bool_),
573            )
574
575        def _continue(_: None):
576            return GymnaxSourceState(
577                env_state=cont_env_state,
578                obs=cont_obs,
579                key=key,
580                step=next_step,
581                epoch=state.epoch,
582                policy_state=updated_policy_state,
583                new_episode=done_flag,
584            )
585
586        new_state = jax.lax.cond(need_epoch_reset, _reset_epoch, _continue, operand=None)
587        return transition, mask, new_state

Roll the environment forward one step and emit a transition.