# Grain - Feeding JAX Models

Grain is an open-source Python library by Google for reading and processing data for training and evaluating JAX models. It is designed to be flexible, fast, and deterministic: given the same configuration, multiple runs of the same pipeline always produce the same output. Grain is also resilient to preemptions — its iterators support lightweight checkpointing so training can resume exactly where it left off. While optimized for JAX, Grain does not require JAX as a dependency and works with other ML frameworks.

Grain exposes two high-level APIs: `Dataset` (recommended for complex pipelines involving mixing, packing, or splitting) and `DataLoader` (recommended for simpler sequential pipelines). The `Dataset` API centers on three classes — `MapDataset`, `IterDataset`, and `DatasetIterator` — that compose lazily through method chaining. The `DataLoader` API pairs a `RandomAccessDataSource` with a `Sampler` and a flat sequence of `Transformation` objects. Both APIs support multi-process prefetching, distributed sharding, deterministic shuffling, and Orbax-based checkpointing.

---

## Installation

```bash
pip install grain
```

---

## `MapDataset.source` — Create a dataset from a sequence or data source

Wraps any Python `Sequence` or `RandomAccessDataSource` into a `MapDataset` supporting efficient random access. This is the typical starting point for a `Dataset` pipeline.

```python
import grain

# From a plain Python list
ds = grain.MapDataset.source([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
print(ds[3])    # 3
print(len(ds))  # 10

# From a custom RandomAccessDataSource
class MySource(grain.sources.RandomAccessDataSource):
    def __init__(self, data):
        self._data = data

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        return self._data[idx]

source = MySource(["a", "b", "c", "d", "e"])
ds = grain.MapDataset.source(source)
print(ds[1])     # "b"
print(list(ds))  # ['a', 'b', 'c', 'd', 'e']
```

---

## `MapDataset.range` — Create an integer range dataset

Constructs a `MapDataset` whose elements are integers, mirroring the semantics of Python's built-in `range`.

```python
import grain

ds = grain.MapDataset.range(10)
print(list(ds))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

ds = grain.MapDataset.range(2, 8, 2)
print(list(ds))  # [2, 4, 6]
```

---

## `MapDataset.shuffle` — Globally shuffle a dataset

Returns a new `MapDataset` with the same elements in a deterministically shuffled order. Each epoch uses a different permutation when used with `repeat`. Pass `seed` directly or via the pipeline-level `ds.seed(seed)`.

```python
import grain

ds = (
    grain.MapDataset.range(10)
    .shuffle(seed=42)
)
print(list(ds))  # e.g. [3, 7, 0, 5, 9, 1, 6, 2, 8, 4] (deterministic)

# Pipeline-level seed: each downstream random op derives its own unique seed
ds = (
    grain.MapDataset.range(10)
    .seed(42)
    .shuffle()  # derives seed from pipeline seed
)
print(set(ds) == set(range(10)))  # True
```

---

## `MapDataset.map` — Apply a 1:1 transformation to every element

Returns a new `MapDataset` where every element is passed through the callable or `grain.transforms.Map` subclass.
```python
import grain
import dataclasses

# Lambda shorthand
ds = grain.MapDataset.range(5).map(lambda x: x ** 2)
print(list(ds))  # [0, 1, 4, 9, 16]

# Class-based transform (recommended for serialization / DataLoader use)
@dataclasses.dataclass
class Normalize(grain.transforms.Map):
    mean: float
    std: float

    def map(self, element):
        return (element - self.mean) / self.std

ds = grain.MapDataset.source([0.0, 1.0, 2.0, 3.0]).map(Normalize(mean=1.5, std=1.0))
print(list(ds))  # [-1.5, -0.5, 0.5, 1.5]
```

---

## `MapDataset.random_map` — Apply a randomized 1:1 transformation

Calls the transform with the element **and** a `np.random.Generator` seeded deterministically per element, enabling reproducible stochastic augmentations.

```python
import grain
import dataclasses
import numpy as np

# Adds a random integer in [0, 5) to each element
ds = grain.MapDataset.range(5).random_map(
    lambda x, rng: x + rng.integers(0, 5),
    seed=0,
)
print(list(ds))  # deterministic, e.g. [2, 4, 4, 6, 5]

# Class-based with numpy image augmentation
@dataclasses.dataclass
class RandomFlip(grain.transforms.RandomMap):
    def random_map(self, image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
        if rng.random() > 0.5:
            return np.flip(image, axis=1)
        return image

ds = grain.MapDataset.source(images).seed(123).random_map(RandomFlip())
```

---

## `MapDataset.map_with_index` — Map with element index

Passes the element's position within the dataset as the first argument to the transform, enabling index-aware augmentations.

```python
import grain

ds = grain.MapDataset.source(["a", "b", "c", "d"])
ds = ds.map_with_index(lambda i, x: f"{x}_{i}")
print(list(ds))  # ['a_0', 'b_1', 'c_2', 'd_3']
```

---

## `MapDataset.filter` — Keep elements matching a predicate

Returns a new `MapDataset` where elements that fail the predicate are replaced by `None`. Iteration automatically skips `None` values, but direct indexing can return `None`.

```python
import grain

ds = grain.MapDataset.range(10).filter(lambda x: x % 2 == 0)
print(ds[0])     # 0
print(ds[1])     # None (1 is filtered out)
print(list(ds))  # [0, 2, 4, 6, 8] (Nones skipped during iteration)

# When filter and batch are both needed, filter first, then convert to IterDataset
ds = (
    grain.MapDataset.range(20)
    .filter(lambda x: x % 3 == 0)
    .to_iter_dataset()
    .batch(batch_size=3)
)
for batch in ds:
    print(batch)  # e.g. array([0, 3, 6])
```

---

## `MapDataset.batch` — Batch consecutive elements

Groups consecutive elements into fixed-size batches stacked along a new leading dimension. Dataset elements are expected to be PyTrees (numpy arrays, dicts of arrays, etc.).

```python
import grain
import numpy as np

ds = grain.MapDataset.range(10).batch(batch_size=3)
print(list(ds))
# [array([0, 1, 2]), array([3, 4, 5]), array([6, 7, 8]), array([9])]

# Drop the final partial batch
ds = grain.MapDataset.range(10).batch(batch_size=3, drop_remainder=True)
print(list(ds))
# [array([0, 1, 2]), array([3, 4, 5]), array([6, 7, 8])]

# Custom batch function (e.g., build a dict-of-arrays batch)
def dict_batch(elements):
    return {"values": np.stack([e["values"] for e in elements])}

ds = grain.MapDataset.source([{"values": np.array([i])} for i in range(6)])
ds = ds.batch(batch_size=2, batch_fn=dict_batch)
```

---

## `MapDataset.repeat` — Repeat dataset across epochs

Repeats the dataset a fixed number of times or infinitely (`num_epochs=None`). By default, random upstream transformations such as `shuffle` are re-seeded per epoch, producing a different ordering each epoch.
```python
import grain
import sys

# Finite repeat
ds = grain.MapDataset.range(3).repeat(2)
print(list(ds))  # [0, 1, 2, 0, 1, 2]

# Infinite repeat for training loops
ds = grain.MapDataset.range(5).shuffle(seed=0).repeat()
print(len(ds) == sys.maxsize)  # True (infinite datasets report sys.maxsize)

# Practical infinite training loop
ds = (
    grain.MapDataset.range(1000)
    .shuffle(seed=42)
    .map(preprocess_fn)
    .batch(batch_size=32)
    .repeat()
)
for step, batch in enumerate(ds):
    if step >= max_steps:
        break
    train_step(batch)
```

---

## `MapDataset.slice` — Shard or subset a dataset

Returns a view of the dataset restricted to the given `slice`. The primary use case is per-host data sharding in distributed training.

```python
import grain

ds = grain.MapDataset.range(10)

# Manual sharding across 4 hosts
shard_index = 0  # jax.process_index()
shard_count = 4  # jax.process_count()
sharded = ds[shard_index::shard_count]
print(list(sharded))  # [0, 4, 8]

# Or equivalently
sharded = ds.slice(slice(shard_index, None, shard_count))
```

---

## `MapDataset.mix` — Mix multiple datasets

Interleaves elements from multiple `MapDataset`s according to the given proportions. The mixed dataset ends once any input dataset is exhausted, so its length is governed by the shortest input relative to its mixing weight.

```python
import grain

ds1 = grain.MapDataset.range(5)     # [0, 1, 2, 3, 4]
ds2 = grain.MapDataset.range(5, 8)  # [5, 6, 7]

# Uniform interleaving
ds = grain.MapDataset.mix([ds1, ds2])
print(list(ds))  # [0, 5, 1, 6, 2, 7]

# Weighted mixing (70% from ds1, 30% from ds2)
ds1_large = grain.MapDataset.range(100)
ds2_large = grain.MapDataset.range(100, 200)
ds = grain.MapDataset.mix([ds1_large, ds2_large], weights=[0.7, 0.3])
```

---

## `MapDataset.concatenate` — Concatenate datasets

Returns a new dataset that sequences elements from all input datasets one after another.

```python
import grain

ds1 = grain.MapDataset.range(3)     # [0, 1, 2]
ds2 = grain.MapDataset.range(3, 7)  # [3, 4, 5, 6]
ds = grain.MapDataset.concatenate([ds1, ds2])
print(list(ds))  # [0, 1, 2, 3, 4, 5, 6]
```

---

## `MapDataset.to_iter_dataset` — Convert to an iterable dataset

Converts a `MapDataset` into an `IterDataset`, enabling multi-threaded prefetching via `ReadOptions`. This should be called as late as possible in the pipeline since some transformations (e.g., `shuffle`, `slice`) are not available on `IterDataset`.

```python
import grain

ds = (
    grain.MapDataset.range(1000)
    .shuffle(seed=0)
    .map(lambda x: x * 2)
    .to_iter_dataset(
        grain.ReadOptions(num_threads=16, prefetch_buffer_size=500)
    )
)
for batch in ds:
    process(batch)
```

---

## `IterDataset.mp_prefetch` — Multi-process prefetching

Distributes dataset processing across multiple worker processes, each handling a slice of the `MapDataset`. This is the primary way to parallelize CPU-bound preprocessing.

```python
import grain

ds = (
    grain.MapDataset.source(my_source)
    .shuffle(seed=42)
    .map(heavy_preprocess_fn)
    .batch(batch_size=64)
    .to_iter_dataset()
    .mp_prefetch(grain.MultiprocessingOptions(num_workers=8, per_worker_buffer_size=2))
)

it = iter(ds)
try:
    for step in range(total_steps):
        batch = next(it)
        train_step(batch)
finally:
    it.close()  # Explicitly release worker processes
```

---

## `IterDataset.mix` — Mix iterable datasets

Interleaves elements from multiple `IterDataset`s. Useful when mixture components are sparse (e.g., after `filter`).
```python
import grain

ds1 = (
    grain.MapDataset.source(source_a)
    .filter(lambda x: x["label"] == 0)
    .to_iter_dataset()
)
ds2 = (
    grain.MapDataset.source(source_b)
    .filter(lambda x: x["label"] == 1)
    .to_iter_dataset()
)

# Mix with equal weights (stops when either is exhausted)
mixed = grain.IterDataset.mix([ds1, ds2], weights=[0.5, 0.5])

# Named components for checkpoint recovery across topology changes
mixed = grain.IterDataset.mix(
    {"class_0": ds1, "class_1": ds2},
    weights={"class_0": 0.5, "class_1": 0.5},
)
```

---

## `DatasetIterator.get_state` / `set_state` — Manual checkpointing

`DatasetIterator` (returned by `iter(iter_dataset)`) exposes `get_state` and `set_state` for saving and restoring pipeline position. Works with Orbax for full distributed checkpointing.

```python
import grain
import orbax.checkpoint as ocp

ds = (
    grain.MapDataset.range(1000)
    .seed(42)
    .shuffle()
    .to_iter_dataset()
)
ds_iter = iter(ds)

# Advance iterator
for _ in range(100):
    x = next(ds_iter)

# Save via Orbax CheckpointManager
mngr = ocp.CheckpointManager("/tmp/my_checkpoint")
mngr.save(step=100, args=grain.checkpoint.CheckpointSave(ds_iter), force=True)
mngr.wait_until_finished()

# Restore to step 100 and continue from exactly that position
mngr.restore(100, args=grain.checkpoint.CheckpointRestore(ds_iter))
x = next(ds_iter)  # element 101, identical to original run
```

---

## `DatasetIterator.close` — Release multiprocessing resources

Explicitly closes the iterator and cleans up worker processes spawned by `mp_prefetch`. Garbage collection will also trigger this cleanup, but not necessarily promptly, so close explicitly.

```python
import grain

ds = (
    grain.MapDataset.source(my_source)
    .map(preprocess)
    .to_iter_dataset()
    .mp_prefetch(grain.MultiprocessingOptions(num_workers=4))
)

it = iter(ds)
for batch in it:
    train_step(batch)
it.close()  # or wrap the loop in try/finally to guarantee cleanup
```

---

## `MapDataset.pipe` — Chainable custom transformations

Applies any callable to the dataset using method-chaining syntax, useful for third-party transformations not built into Grain.

```python
import grain

def add_noise_dataset(ds, noise_std, seed):
    return ds.random_map(
        lambda x, rng: x + rng.normal(scale=noise_std),
        seed=seed,
    )

ds = (
    grain.MapDataset.range(10)
    .pipe(add_noise_dataset, noise_std=0.1, seed=7)
    .batch(batch_size=5)
)
print(list(ds))
```

---

## `MapDataset.apply` — Apply transformation objects declaratively

Accepts one or a sequence of `grain.transforms.Transformation` objects and dispatches to the correct dataset method. Useful when building pipelines programmatically.

```python
import grain
import dataclasses

@dataclasses.dataclass
class AddOne(grain.transforms.Map):
    def map(self, x):
        return x + 1

pipeline = [
    AddOne(),
    grain.transforms.Batch(batch_size=3),
]
ds = grain.MapDataset.range(9).apply(pipeline)
print(list(ds))  # [array([1, 2, 3]), array([4, 5, 6]), array([7, 8, 9])]
```

---

## `DataLoader` — High-level loading pipeline

The `DataLoader` API combines a `RandomAccessDataSource`, a `Sampler`, and a flat list of `Transformation`s into a single object. It is simpler than the `Dataset` API and automatically handles sharding and multiprocessing placement.
```python
import grain
import dataclasses

source = grain.ArrayRecordDataSource(["/data/train-*.arrayrecord"])

sampler = grain.IndexSampler(
    num_records=len(source),
    shard_options=grain.ShardByJaxProcess(),
    shuffle=True,
    num_epochs=None,  # infinite
    seed=42,
)

@dataclasses.dataclass
class DecodeAndAugment(grain.transforms.Map):
    def map(self, record: bytes):
        example = decode_proto(record)
        return example

loader = grain.DataLoader(
    data_source=source,
    sampler=sampler,
    operations=[
        DecodeAndAugment(),
        grain.transforms.Batch(batch_size=128, drop_remainder=True),
    ],
    worker_count=8,
    worker_buffer_size=2,
)

iterator = iter(loader)
state = iterator.get_state()  # bytes; save to disk for preemption recovery
for step in range(max_steps):
    batch = next(iterator)
    train_step(batch)
```

---

## `grain.load` — Convenience wrapper for simple pipelines

A one-call entry point that internally creates an `IndexSampler` and `DataLoader`. Best for straightforward use cases without mixing or packing.

```python
import grain
import dataclasses

source = grain.ArrayRecordDataSource(["/data/train-*.arrayrecord"])

@dataclasses.dataclass
class Decode(grain.transforms.Map):
    def map(self, record: bytes):
        return parse_example(record)

loader = grain.load(
    source=source,
    num_epochs=10,
    shuffle=True,
    seed=0,
    shard_options=grain.ShardByJaxProcess(),
    transformations=[Decode()],
    batch_size=64,
    drop_remainder=True,
    worker_count=4,
)
for batch in loader:
    train_step(batch)
```

---

## `ArrayRecordDataSource` — Efficient file-backed random access

Wraps one or more ArrayRecord files into a `RandomAccessDataSource`. Supports `FileInstruction` to read specific record ranges within sharded files.

```python
import grain

# Single file
source = grain.ArrayRecordDataSource("/data/train.arrayrecord")

# Multiple shards
source = grain.ArrayRecordDataSource([
    "/data/train-00000.arrayrecord",
    "/data/train-00001.arrayrecord",
])

# With index stored in memory for fast random access
source = grain.ArrayRecordDataSource(
    ["/data/train-*.arrayrecord"],
    reader_options={"index_storage_option": "in_memory"},
)

print(f"Total records: {len(source)}")
record_bytes = source[42]  # raw bytes; decode in a Map transform
```

---

## `ShardOptions` / `ShardByJaxProcess` / `NoSharding` — Distributed sharding

Controls how data is split across JAX processes (hosts) for distributed training. `ShardByJaxProcess` automatically reads `jax.process_index()` and `jax.process_count()`.

```python
import grain

# No sharding (single process)
shard = grain.NoSharding()

# Shard by JAX process (multi-host training)
shard = grain.ShardByJaxProcess(drop_remainder=True)

# Manual sharding
shard = grain.ShardOptions(shard_index=1, shard_count=4, drop_remainder=False)

# Use in Dataset pipeline
ds = (
    grain.MapDataset.source(source)
    .shuffle(seed=0)
    [shard.shard_index::shard.shard_count]  # slice-based sharding
)

# Or delegate to IndexSampler / DataLoader
sampler = grain.IndexSampler(
    num_records=len(source),
    shard_options=grain.ShardByJaxProcess(),
    shuffle=True,
    seed=0,
)
```

---

## `IndexSampler` — Epoch/shuffle/shard controller for `DataLoader`

Determines the order in which records are fed to the `DataLoader`. Supports infinite training, shuffled epochs, and sharding.
```python
import grain

# Single-epoch sequential
sampler = grain.IndexSampler(
    num_records=60000,
    shard_options=grain.NoSharding(),
    shuffle=False,
    num_epochs=1,
)

# Infinite shuffled training across 8 JAX hosts
sampler = grain.IndexSampler(
    num_records=1_000_000,
    shard_options=grain.ShardByJaxProcess(),
    shuffle=True,
    num_epochs=None,
    seed=42,
)
print(repr(sampler))
# IndexSampler(num_records=1000000, shard_options=ShardByJaxProcess(...), ...)
```

---

## `ReadOptions` / `MultiprocessingOptions` — Performance tuning

Dataclasses for configuring threading (within a process) and multiprocessing (across worker processes).

```python
import grain

# ReadOptions: controls per-process I/O threading
read_opts = grain.ReadOptions(
    num_threads=16,            # threads reading from the DataSource in parallel
    prefetch_buffer_size=500,  # elements buffered per process
)

# MultiprocessingOptions: controls worker processes
mp_opts = grain.MultiprocessingOptions(
    num_workers=8,             # number of worker processes
    per_worker_buffer_size=2,  # batches pre-computed per worker
    enable_profiling=False,
)

ds = (
    grain.MapDataset.source(source)
    .shuffle(seed=0)
    .map(preprocess)
    .batch(batch_size=64)
    .to_iter_dataset(read_opts)
    .mp_prefetch(mp_opts)
)
```

---

## Transformation base classes — Custom `Map`, `Filter`, `FlatMap`, `RandomMap`

Abstract base classes in `grain.transforms` for authoring reusable, serializable data transformations. Prefer class-based transforms over lambdas for use with `DataLoader` and checkpointing.

```python
import grain
import dataclasses
import numpy as np
from typing import Sequence

@dataclasses.dataclass
class Tokenize(grain.transforms.Map):
    vocab_size: int = 32000

    def map(self, text: str) -> np.ndarray:
        return tokenizer.encode(text)

@dataclasses.dataclass
class LengthFilter(grain.transforms.Filter):
    max_length: int = 512

    def filter(self, tokens: np.ndarray) -> bool:
        return len(tokens) <= self.max_length

@dataclasses.dataclass
class SplitDocument(grain.transforms.FlatMap):
    max_fan_out: int = 10
    chunk_size: int = 128

    def flat_map(self, tokens: np.ndarray) -> Sequence[np.ndarray]:
        return [tokens[i:i + self.chunk_size]
                for i in range(0, len(tokens), self.chunk_size)]

@dataclasses.dataclass
class RandomCrop(grain.transforms.RandomMap):
    crop_size: int = 224

    def random_map(self, image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
        h, w = image.shape[:2]
        top = rng.integers(0, h - self.crop_size + 1)
        left = rng.integers(0, w - self.crop_size + 1)
        return image[top:top + self.crop_size, left:left + self.crop_size]

# Use in a Dataset pipeline
ds = (
    grain.MapDataset.source(text_source)
    .map(Tokenize())
    .filter(LengthFilter(max_length=512))
    .to_iter_dataset()
    .map(lambda x: x)  # further per-element ops
)

# Use in a DataLoader pipeline
loader = grain.DataLoader(
    data_source=image_source,
    sampler=sampler,
    operations=[RandomCrop(crop_size=224), grain.transforms.Batch(32)],
    worker_count=4,
)
```

---

## Summary

Grain's primary use cases are building deterministic, preemption-resilient input pipelines for large-scale ML training. The `Dataset` API is ideal when pipelines require dataset mixing with arbitrary proportions, variable-length sequence packing, document splitting with global re-shuffling, or fine-grained control over sharding order. The `DataLoader` API suits simpler workflows: a single data source, sequential transformations, and built-in support for shuffle, shard, and epoch management via `IndexSampler`.

Both paths support the same underlying transformation primitives (`Map`, `Filter`, `RandomMap`, `FlatMap`, `Batch`), implemented as typed dataclasses for reproducibility. In distributed JAX training environments, Grain integrates cleanly with the JAX process model via `ShardByJaxProcess` and with Orbax for coordinated model+data checkpointing.

For maximum throughput, the recommended pattern is: apply all `MapDataset` transformations (including shuffle, sharding via `slice`, and map), convert to `IterDataset` with `to_iter_dataset(ReadOptions(...))` for threaded prefetching, and then call `mp_prefetch(MultiprocessingOptions(num_workers=N))` so CPU-bound preprocessing runs in parallel worker processes. Pipelines that use `mp_prefetch` should always call `DatasetIterator.close()`, for example in a `try`/`finally` block, to reclaim OS resources promptly.
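
To make that pattern concrete, here is a minimal end-to-end sketch that strings together the pieces shown above. It is illustrative only: `my_source`, `preprocess`, `train_step`, `max_steps`, the checkpoint path, and the checkpoint cadence are hypothetical placeholders, and the tuning values are not recommendations.

```python
import grain
import jax
import orbax.checkpoint as ocp

ds = (
    grain.MapDataset.source(my_source)           # placeholder RandomAccessDataSource
    .seed(42)                                    # pipeline-level seed for all random ops
    .shuffle()                                   # global shuffle, reshuffled every epoch
    [jax.process_index()::jax.process_count()]   # per-host sharding via slicing
    .map(preprocess)                             # placeholder CPU-bound transform
    .batch(batch_size=128, drop_remainder=True)
    .repeat()                                    # infinite epochs for training
    .to_iter_dataset(grain.ReadOptions(num_threads=16, prefetch_buffer_size=500))
    .mp_prefetch(grain.MultiprocessingOptions(num_workers=8, per_worker_buffer_size=2))
)

it = iter(ds)
mngr = ocp.CheckpointManager("/tmp/ckpt")        # hypothetical checkpoint directory
try:
    for step in range(max_steps):                # max_steps is a placeholder
        batch = next(it)
        train_step(batch)                        # placeholder training step
        if step % 1000 == 0:
            # Save the data iterator state so a preempted job can resume in place
            mngr.save(step, args=grain.checkpoint.CheckpointSave(it), force=True)
finally:
    it.close()                                   # release mp_prefetch worker processes
```

In a real job, the same `CheckpointManager` step would also carry the model and optimizer state, so that the model and the data position stay in sync on restore.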