cyreal.tutorials.host_callback
`HostCallbackTransform` lets you log metrics and perform other impure I/O from inside `jax.jit`-compiled code.
import jax
import jax.numpy as jnp
import numpy as np
from cyreal.sources import ArraySource
from cyreal.transforms import BatchTransform, HostCallbackTransform
from cyreal.loader import DataLoader
from cyreal.datasets import MNISTDataset
def model(images):
    """Toy stand-in for a network that produces one scalar score per example.

    ``images`` is cast to float32 and averaged over axes 1, 2 and 3. A rank-4
    input of shape ``(batch, d1, d2, d3)`` therefore yields shape ``(batch,)``.
    """
    pixels = images.astype(jnp.float32)
    return jnp.mean(pixels, axis=(1, 2, 3))
def cross_entropy(logits, labels):
    """Return the element-wise squared error between ``logits`` and ``labels``.

    Despite its name, this is not a cross-entropy. It computes
    ``(logits - labels.astype(float32)) ** 2`` with standard broadcasting,
    so the result has the broadcast shape of the two inputs.
    """
    target = labels.astype(jnp.float32)
    residual = logits - target
    return residual ** 2
def log_loss(batch, mask):
    """Print the masked mean loss of ``batch`` and pass the batch through.

    Intended as the ``fn`` of a ``HostCallbackTransform``. It returns
    ``batch`` unchanged, so the pipeline continues with the same data.

    Args:
        batch: mapping with ``"image"`` and ``"label"`` arrays.
        mask: per-example validity mask. It is assumed to have shape
            ``(batch,)``, with nonzero entries for real examples
            (NOTE(review): confirm against BatchTransform's padding
            convention).

    Returns:
        ``batch``, unmodified.
    """
    logits = model(batch["image"])
    # One loss value per example, assuming one scalar label per example.
    per_example = cross_entropy(logits, batch["label"])
    weights = jnp.asarray(mask, dtype=jnp.float32)
    # Multiply by ``weights`` directly rather than ``weights[:, None]``.
    # The per-example loss has shape (batch,), and adding an axis would
    # broadcast it into a (batch, batch) outer product.
    # Divide by the number of valid examples so padding rows do not dilute
    # the average; max(..., 1) avoids dividing by zero on an all-padding batch.
    loss = jnp.sum(per_example * weights) / jnp.maximum(jnp.sum(weights), 1.0)
    print("loss:", float(np.asarray(loss)))
    return batch
# Shuffle the MNIST training arrays, group them into batches of 128, and hand
# every batch to `log_loss` through a HostCallbackTransform.
train_arrays = MNISTDataset(split="train").as_array_dict()
loader = DataLoader(
    pipeline=[
        ArraySource(train_arrays, ordering="shuffle"),
        BatchTransform(batch_size=128),
        HostCallbackTransform(fn=log_loss),
    ],
)

# Even with a host callback in the pipeline, `loader.next` can be jitted.
state = loader.init_state(jax.random.key(0))
jitted_next = jax.jit(loader.next)
sample, mask, state = jitted_next(state)
"""`HostCallbackTransform` allows you to log metrics and call other impure IO within jit.

```python
import jax
import jax.numpy as jnp
import numpy as np

from cyreal.sources import ArraySource
from cyreal.transforms import BatchTransform, HostCallbackTransform
from cyreal.loader import DataLoader
from cyreal.datasets import MNISTDataset

def model(images):
    return jnp.mean(images.astype(jnp.float32), axis=(1, 2, 3))

def cross_entropy(logits, labels):
    labels = labels.astype(jnp.float32)
    return (logits - labels) ** 2

def log_loss(batch, mask):
    logits = model(batch["image"])
    loss = jnp.mean(cross_entropy(logits, batch["label"]) * mask[:, None])
    print("loss:", float(np.asarray(loss)))
    return batch

loader = DataLoader(
    pipeline=[
        ArraySource(MNISTDataset(split="train").as_array_dict(), ordering="shuffle"),
        BatchTransform(batch_size=128),
        HostCallbackTransform(fn=log_loss),
    ],
)
# Still jittable
state = loader.init_state(jax.random.key(0))
sample, mask, state = jax.jit(loader.next)(state)
```
"""