cyreal.tutorials.disk_stream

This tutorial shows how to stream a dataset from disk using `DiskSource`. Streaming from disk is much slower than using an in-memory dataset, but it allows you to work with datasets that do not fit into RAM. The key is to call `make_disk_source` on your dataset class to get a disk-backed source.

import jax

from cyreal.transforms import BatchTransform
from cyreal.loader import DataLoader
from cyreal.datasets import MNISTDataset

pipeline = [
    # Prefetch 1024 examples for each disk read
    MNISTDataset.make_disk_source(split="train", ordering="shuffle", prefetch_size=1024),
    BatchTransform(batch_size=128),
]

loader = DataLoader(pipeline=pipeline)
state = loader.init_state(jax.random.PRNGKey(0))

for batch, mask in loader.iterate(state):
    ...  # stream without holding the dataset in RAM
"""Streaming from disk using `DiskSource`. This is much slower than an in-memory dataset
but it allows you to work with datasets that do not fit into RAM. The key is to call
`make_disk_source` on your dataset class to get a disk-backed source.

```python
import jax

from cyreal.transforms import BatchTransform
from cyreal.loader import DataLoader
from cyreal.datasets import MNISTDataset

pipeline = [
    # Prefetch 1024 examples for each disk read
    MNISTDataset.make_disk_source(split="train", ordering="shuffle", prefetch_size=1024),
    BatchTransform(batch_size=128),
]

loader = DataLoader(pipeline=pipeline)
state = loader.init_state(jax.random.PRNGKey(0))

for batch, mask in loader.iterate(state):
    ...  # stream without holding the dataset in RAM
```
"""