cyreal.tutorials.disk_stream
Streaming from disk using `DiskSource`. This is much slower than an in-memory dataset,
but it lets you work with datasets that do not fit into RAM. The key is to call
`make_disk_source` on your dataset class to get a disk-backed source.
import jax

from cyreal.transforms import BatchTransform
from cyreal.loader import DataLoader
from cyreal.datasets import MNISTDataset

pipeline = [
    # Prefetch 1024 examples for each disk read
    MNISTDataset.make_disk_source(split="train", ordering="shuffle", prefetch_size=1024),
    BatchTransform(batch_size=128),
]

loader = DataLoader(pipeline=pipeline)
state = loader.init_state(jax.random.PRNGKey(0))

for batch, mask in loader.iterate(state):
    ...  # stream without holding the dataset in RAM
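To build intuition for what a prefetch size means, here is a minimal standard-library sketch of the general idea: read a fixed number of examples per disk access, then cut that chunk into batches. It is not how cyreal implements `DiskSource`. The `stream_batches` helper, the toy float file, and the choice to shuffle only within each prefetched chunk are all assumptions made for illustration.

```python
import os
import random
import tempfile
from array import array


def stream_batches(path, prefetch_size, batch_size, seed=0):
    """Yield batches from a float32 file, reading prefetch_size items per disk read.

    Hypothetical helper for illustration only; not part of cyreal.
    """
    rng = random.Random(seed)
    item_bytes = array("f").itemsize
    with open(path, "rb") as f:
        while True:
            # One disk read brings in at most prefetch_size examples.
            raw = f.read(prefetch_size * item_bytes)
            if not raw:
                break
            chunk = array("f")
            chunk.frombytes(raw)
            items = list(chunk)
            # Assumption: shuffle only within the chunk currently in memory.
            rng.shuffle(items)
            for start in range(0, len(items), batch_size):
                yield items[start:start + batch_size]


# Write a toy dataset of 10 float32 values to a temporary file.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    array("f", range(10)).tofile(tmp)
    toy_path = tmp.name

batches = list(stream_batches(toy_path, prefetch_size=4, batch_size=2))
os.remove(toy_path)
```

Only one chunk of `prefetch_size` examples is held in memory at a time. In this sketch, a larger prefetch size means fewer disk reads at the cost of more memory per read.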