cyreal.tutorials.scan_and_jit
For those premature optimizers, you can `jit` your entire training epoch.
import jax
import jax.numpy as jnp
from cyreal.sources import ArraySource
from cyreal.transforms import BatchTransform
from cyreal.loader import DataLoader
from cyreal.datasets import MNISTDataset
# Toy "model": one scalar parameter stands in for real network weights.
model_state = dict(params=jnp.array(0))

# Materialise MNIST as a dict of arrays.
# NOTE(review): this loads the "test" split even though the variable is named
# `train_data`; presumably chosen to keep the tutorial small — confirm, and use
# split="train" when training a real model.
train_data = MNISTDataset(split="test").as_array_dict()

# Pipeline stages, applied in order: serve the arrays in shuffled order, then
# group them into batches of 128.
pipeline = [
    ArraySource(train_data, ordering="shuffle"),
    BatchTransform(batch_size=128),
]
loader = DataLoader(pipeline)

# The loader is stateful only through this explicit state object, seeded
# from a JAX PRNG key.
loader_state = loader.init_state(jax.random.key(0))
@jax.jit
def train_epoch(model_state, loader_state):
    """Run one training epoch over the module-level `loader` in a single jitted call.

    Args:
        model_state: dict holding the model parameters under the "params" key.
        loader_state: state object previously returned by `loader.init_state`
            (or by an earlier call to this function).

    Returns:
        A ``(model_state, loader_state)`` tuple with both states advanced by
        one epoch.
    """

    def step(model_state, batch, mask):
        # Per-batch update hook. A real train fn would compute new parameters
        # from `batch` (and `mask`); this placeholder only increments "params".
        bumped_params = model_state["params"] + 1
        # NOTE(review): the second element appears to be a per-step output that
        # scan_epoch hands back as its third return value (discarded below) — confirm.
        return {"params": bumped_params}, None

    # scan_epoch is just a convenience helper; the loader itself is fully
    # JIT-compatible, so you can also roll your own training loop around it.
    next_loader_state, next_model_state, _ = loader.scan_epoch(
        loader_state, model_state, step
    )
    return next_model_state, next_loader_state
# Advance both states by one epoch; the first call to the jitted function
# traces and compiles it.
(model_state, loader_state) = train_epoch(model_state, loader_state)
"""For those premature optimizers, you can `jit` your entire training epoch.

```python
import jax
import jax.numpy as jnp

from cyreal.sources import ArraySource
from cyreal.transforms import BatchTransform
from cyreal.loader import DataLoader
from cyreal.datasets import MNISTDataset

train_data = MNISTDataset(split="test").as_array_dict()
pipeline = [
    ArraySource(train_data, ordering="shuffle"),
    BatchTransform(batch_size=128),
]
loader = DataLoader(pipeline)
loader_state = loader.init_state(jax.random.key(0))

model_state = {"params": jnp.array(0)}

@jax.jit
def train_epoch(model_state, loader_state):
    def model_update(model_state, batch, mask):
        # Update the network using your train fn
        new_model_state = {"params": model_state['params'] + 1}
        return new_model_state, None

    # scan_epoch is a helper method, but the loader itself is fully JIT-compatible
    # in case you want to roll your own training loop.
    loader_state, model_state, _ = loader.scan_epoch(loader_state, model_state, model_update)
    return model_state, loader_state

model_state, loader_state = train_epoch(model_state, loader_state)
```
"""