cyreal.tutorials.buffer_quickstart

Example using `BufferTransform` for reservoir replay or FIFO buffering.

import jax
import jax.numpy as jnp

from cyreal.sources import ArraySource
from cyreal.transforms import BufferTransform, BatchTransform
from cyreal.loader import DataLoader
from cyreal.datasets import MNISTDataset

# Load the MNIST training split via `as_array_dict()`; it is fed to ArraySource below.
train_data = MNISTDataset(split="train").as_array_dict()

pipeline = [
    # Source stage: read `train_data` with ordering="shuffle".
    ArraySource(train_data, ordering="shuffle"),
    # BufferTransform is highly configurable and supports both reservoir
    # sampling and FIFO buffering; here write_mode="reservoir" is selected.
    # `prefill` is the number of samples to wait for before batches are
    # yielded, and `sample_size` is how many samples each step hands on.
    BufferTransform(
        capacity=128,
        prefill=16,
        sample_size=16,
        mode="shuffled",
        write_mode="reservoir",
    ),
    # The buffer hands on 16 samples per step; BatchTransform can optionally
    # subsample that output further.
    BatchTransform(batch_size=8),
]

# Build the loader from the stage list and seed its state with a JAX PRNG key.
loader = DataLoader(pipeline)
loader_state = loader.init_state(jax.random.key(0))

# Compile `loader.next` with jax.jit and take a single step; it returns the
# sample, its mask, and the advanced loader state.
sample, mask, loader_state = jax.jit(loader.next)(loader_state)
"""Example using `BufferTransform` for reservoir replay or FIFO buffering.

```python
import jax
import jax.numpy as jnp

from cyreal.sources import ArraySource
from cyreal.transforms import BufferTransform, BatchTransform
from cyreal.loader import DataLoader
from cyreal.datasets import MNISTDataset

train_data = MNISTDataset(split="train").as_array_dict()
pipeline = [
    ArraySource(train_data, ordering="shuffle"),
    # We have a lot of options for the BufferTransform
    # You can use it for either reservoir sampling or FIFO buffering
    # Prefill determines how many samples to wait before yielding batches
    BufferTransform(capacity=128, prefill=16, sample_size=16, mode="shuffled", write_mode="reservoir"),
    # BufferTransform yields 16 samples, and we can perform additional subsampling with
    # BatchTransform if necessary
    BatchTransform(batch_size=8),
]
loader = DataLoader(pipeline)
loader_state = loader.init_state(jax.random.key(0))
sample, mask, loader_state = jax.jit(loader.next)(loader_state)
```
"""