cyreal.tutorials.host_callback

`HostCallbackTransform` lets you log metrics and perform other side-effecting host I/O from inside `jax.jit`-compiled code.

import jax
import jax.numpy as jnp
import numpy as np

from cyreal.sources import ArraySource
from cyreal.transforms import BatchTransform, HostCallbackTransform
from cyreal.loader import DataLoader
from cyreal.datasets import MNISTDataset

def model(images):
    """Toy stand-in for a network: average all pixels of each example.

    The input is cast to float32 and reduced over axes 1, 2 and 3, leaving a
    single score per entry along the leading (batch) axis.
    """
    pixels = images.astype(jnp.float32)
    return jnp.mean(pixels, axis=(1, 2, 3))

def cross_entropy(logits, labels):
    """Elementwise squared error between ``logits`` and ``labels``.

    NOTE: despite the name, this is *not* a cross-entropy. It returns
    ``(logits - labels) ** 2`` after casting ``labels`` to float32, which is a
    placeholder loss for this tutorial's one-score-per-image toy model.
    """
    targets = labels.astype(jnp.float32)
    residual = logits - targets
    return residual ** 2

def log_loss(batch, mask):
    """Print the masked mean loss of ``batch`` and return ``batch`` unchanged.

    Used as the ``fn`` of ``HostCallbackTransform``. The print is the side
    effect, and returning ``batch`` lets the pipeline continue with the same
    data.

    Args:
        batch: mapping with ``"image"`` and ``"label"`` arrays.
        mask: per-example validity mask for the batch. Assumed to have shape
            ``(batch,)``, nonzero/True for real rows and zero/False for
            padding (TODO confirm against ``BatchTransform``).

    Returns:
        ``batch``, unmodified.
    """
    logits = model(batch["image"])
    per_example = cross_entropy(logits, batch["label"])  # shape (batch,)
    weights = jnp.asarray(mask, dtype=jnp.float32)
    # Weight by the 1-D mask directly. The previous ``mask[:, None]`` broadcast
    # the (batch,) losses against a (batch, 1) mask, producing a
    # (batch, batch) outer product whose mean was mean(loss) * mean(mask).
    total = jnp.sum(per_example * weights)
    # Average over valid rows only. The floor of 1 avoids 0/0 when every row
    # in the batch is padding.
    loss = total / jnp.maximum(jnp.sum(weights), 1.0)
    print("loss:", float(np.asarray(loss)))
    return batch

# Pipeline: shuffle the MNIST training arrays, group them into batches of 128,
# then call ``log_loss`` on the host for every batch that flows through.
train_arrays = MNISTDataset(split="train").as_array_dict()
loader = DataLoader(
    pipeline=[
        ArraySource(train_arrays, ordering="shuffle"),
        BatchTransform(batch_size=128),
        HostCallbackTransform(fn=log_loss),
    ],
)

# The host callback does not prevent compilation: ``loader.next`` can still
# be wrapped in ``jax.jit``.
state = loader.init_state(jax.random.key(0))
jitted_next = jax.jit(loader.next)
sample, mask, state = jitted_next(state)
 1"""`HostCallbackTransform` lets you log metrics and perform other side-effecting host I/O from inside `jax.jit`-compiled code.
 2
 3```python
 4import jax
 5import jax.numpy as jnp
 6import numpy as np
 7
 8from cyreal.sources import ArraySource
 9from cyreal.transforms import BatchTransform, HostCallbackTransform
10from cyreal.loader import DataLoader
11from cyreal.datasets import MNISTDataset
12
13def model(images):
14    return jnp.mean(images.astype(jnp.float32), axis=(1, 2, 3))
15
16def cross_entropy(logits, labels):
17    labels = labels.astype(jnp.float32)
18    return (logits - labels) ** 2
19
20def log_loss(batch, mask):
21    logits = model(batch["image"])
22    loss = jnp.mean(cross_entropy(logits, batch["label"]) * mask[:, None])
23    print("loss:", float(np.asarray(loss)))
24    return batch
25
26loader = DataLoader(
27    pipeline=[
28        ArraySource(MNISTDataset(split="train").as_array_dict(), ordering="shuffle"),
29        BatchTransform(batch_size=128),
30        HostCallbackTransform(fn=log_loss),
31    ],
32)
33# Still jittable
34state = loader.init_state(jax.random.key(0))
35sample, mask, state = jax.jit(loader.next)(state)
36```
37"""