cyreal.tutorials.scan_and_jit

For those premature optimizers, you can `jit` your entire training epoch.

```python
import jax
import jax.numpy as jnp

from cyreal.sources import ArraySource
from cyreal.transforms import BatchTransform
from cyreal.loader import DataLoader
from cyreal.datasets import MNISTDataset

train_data = MNISTDataset(split="test").as_array_dict()
pipeline = [
    ArraySource(train_data, ordering="shuffle"),
    BatchTransform(batch_size=128),
]
loader = DataLoader(pipeline)
loader_state = loader.init_state(jax.random.key(0))

model_state = {"params": jnp.array(0)}

@jax.jit
def train_epoch(model_state, loader_state):
    def model_update(model_state, batch, mask):
        # Update the network using your train fn
        new_model_state = {"params": model_state["params"] + 1}
        return new_model_state, None

    # scan_epoch is a helper method, but the loader itself is fully JIT-compatible
    # in case you want to roll your own training loop (see the sketch below).
    loader_state, model_state, _ = loader.scan_epoch(loader_state, model_state, model_update)
    return model_state, loader_state

model_state, loader_state = train_epoch(model_state, loader_state)
```
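
The comment in `train_epoch` notes that the loader is JIT-compatible even outside the `scan_epoch` helper. The sketch below shows the general shape a hand-rolled epoch could take with `jax.lax.scan`; the names `loader.next` and `loader.steps_per_epoch` are hypothetical placeholders for whatever per-batch stepping API the loader actually exposes, so treat this as a pattern rather than copy-paste code.

```python
# Conceptual sketch only: `loader.next` and `loader.steps_per_epoch` are
# hypothetical stand-ins for the loader's real per-batch stepping API.
@jax.jit
def custom_epoch(model_state, loader_state):
    def step(carry, _):
        model_state, loader_state = carry
        # Hypothetical call: pull one batch (plus validity mask) from the loader.
        loader_state, batch, mask = loader.next(loader_state)
        # Dummy update mirroring model_update above.
        model_state = {"params": model_state["params"] + 1}
        return (model_state, loader_state), None

    (model_state, loader_state), _ = jax.lax.scan(
        step, (model_state, loader_state), xs=None, length=loader.steps_per_epoch
    )
    return model_state, loader_state
```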
 1"""For those premature optimizers, you can `jit` your entire training epoch.
 2
 3```python
 4import jax
 5import jax.numpy as jnp
 6
 7from cyreal.sources import ArraySource
 8from cyreal.transforms import BatchTransform
 9from cyreal.loader import DataLoader
10from cyreal.datasets import MNISTDataset
11
12train_data = MNISTDataset(split="test").as_array_dict()
13pipeline = [
14    ArraySource(train_data, ordering="shuffle"),
15    BatchTransform(batch_size=128),
16]
17loader = DataLoader(pipeline)
18loader_state = loader.init_state(jax.random.key(0))
19
20model_state = {"params": jnp.array(0)}
21
22@jax.jit
23def train_epoch(model_state, loader_state):
24    def model_update(model_state, batch, mask):
25        # Update the network using your train fn
26        new_model_state = {"params": model_state['params'] + 1}
27        return new_model_state, None
28
29    # scan_epoch is a helper method, but the loader itself is fully JIT-compatible
30    # in case you want to roll your own training loop.
31    loader_state, model_state, _ = loader.scan_epoch(loader_state, model_state, model_update)
32    return model_state, loader_state
33
34model_state, loader_state = train_epoch(model_state, loader_state)
35```
36"""