cyreal.tutorials.buffer_quickstart
Example using BufferTransform for reservoir replay or FIFO buffering.
import jax
import jax.numpy as jnp
from cyreal.sources import ArraySource
from cyreal.transforms import BufferTransform, BatchTransform
from cyreal.loader import DataLoader
from cyreal.datasets import MNISTDataset
# Load the MNIST training split through the dataset's `as_array_dict()` helper.
train_data = MNISTDataset(split="train").as_array_dict()

# The pipeline is an ordered list of stages handed to the DataLoader.
pipeline = [
    # Stage 1: the array source, configured with ordering="shuffle".
    ArraySource(train_data, ordering="shuffle"),
    # Stage 2: the buffer. BufferTransform exposes many options and can act
    # as either a reservoir-sampling buffer or a FIFO buffer. `prefill` sets
    # how many samples it waits for before it starts yielding batches.
    BufferTransform(
        capacity=128,
        prefill=16,
        sample_size=16,
        mode="shuffled",
        write_mode="reservoir",
    ),
    # Stage 3: the buffer yields 16 samples at a time; BatchTransform can
    # subsample that further when a smaller batch is wanted.
    BatchTransform(batch_size=8),
]

loader = DataLoader(pipeline)

# Loader state is created from an explicit PRNG key.
seed_key = jax.random.key(0)
loader_state = loader.init_state(seed_key)

# Wrap the step function with jax.jit, then draw one batch. The call returns
# three values, unpacked here as the sample, a mask, and the next loader state.
# NOTE(review): `mask` presumably flags which entries are valid - confirm.
jitted_next = jax.jit(loader.next)
sample, mask, loader_state = jitted_next(loader_state)
"""Example using `BufferTransform` for reservoir replay or FIFO buffering.

```python
import jax
import jax.numpy as jnp

from cyreal.sources import ArraySource
from cyreal.transforms import BufferTransform, BatchTransform
from cyreal.loader import DataLoader
from cyreal.datasets import MNISTDataset

train_data = MNISTDataset(split="train").as_array_dict()
pipeline = [
    ArraySource(train_data, ordering="shuffle"),
    # We have a lot of options for the BufferTransform
    # You can use it for either reservoir sampling or FIFO buffering
    # Prefill determines how many samples to wait before yielding batches
    BufferTransform(capacity=128, prefill=16, sample_size=16, mode="shuffled", write_mode="reservoir"),
    # BufferTransform yields 16 samples, and we can perform additional subsampling with
    # BatchTransform if necessary
    BatchTransform(batch_size=8),
]
loader = DataLoader(pipeline)
loader_state = loader.init_state(jax.random.key(0))
sample, mask, loader_state = jax.jit(loader.next)(loader_state)
```
"""