### Setup Flax and LM1B Example Dependencies
Source: https://github.com/google/flax/blob/main/examples/lm1b/README.md
Clones the Flax repository, installs it locally, and then installs the specific dependencies for the LM1B example.
```bash
git clone --depth=1 --branch=main https://github.com/google/flax
cd flax
pip install -e .
cd examples/lm1b
pip install -r requirements.txt
```
--------------------------------
### Install and Run VAE Example
Source: https://github.com/google/flax/blob/main/examples/vae/README.md
Install dependencies and execute the VAE training script with default configurations.
```bash
pip install -r requirements.txt
python main.py --workdir=/tmp/mnist --config=configs/default.py
```
--------------------------------
### Setup environment and copy example files
Source: https://github.com/google/flax/blob/main/examples/ogbg_molpcba/ogbg_molpcba.ipynb
Configures the environment by cloning the repository and optionally mounting Google Drive for persistent storage of example files.
```python
# (If you run this code in Jupyter[lab], then you're already in the
# example directory and nothing needs to be done.)
#@markdown **Fetch newest Flax version and copy of example code.**
#@markdown
#@markdown **If you select no** below, then the files will be stored on the
#@markdown *ephemeral* Colab VM. **After some time of inactivity, this VM will
#@markdown be restarted and any changes will be lost**.
#@markdown
#@markdown **If you select yes** below, then you will be asked for your
#@markdown credentials to mount your personal Google Drive. In this case, all
#@markdown changes you make will *persist*. Even if you re-run this
#@markdown Colab notebook later on, the files will still exist. You can
#@markdown remove directories inside your Drive's `flax/` root if you want to
#@markdown manually revert these files.
if 'google.colab' in str(get_ipython()):
  import os
  os.chdir('/content')
  # Download Flax repo from Github.
  if not os.path.isdir('flaxrepo'):
    !git clone --depth=1 -b $branch $repo flaxrepo
  # Copy example files & change directory.
  mount_gdrive = 'no' #@param ['yes', 'no']
  if mount_gdrive == 'yes':
    DISCLAIMER = 'Note : Editing in your Google Drive, changes will persist.'
    from google.colab import drive
    drive.mount('/content/gdrive')
    example_root_path = f'/content/gdrive/My Drive/flax/{example_directory}'
  else:
    DISCLAIMER = 'WARNING : Editing in VM - changes lost after reboot!!'
    example_root_path = f'/content/{example_directory}'
    from IPython import display
    display.display(display.HTML(f'{DISCLAIMER}'))
  if not os.path.isdir(example_root_path):
    os.makedirs(example_root_path)
    !cp -r flaxrepo/$example_directory/* "$example_root_path"
  os.chdir(example_root_path)
  from google.colab import files
  for relpath in editor_relpaths:
    s = open(f'{example_root_path}/{relpath}').read()
    open(f'{example_root_path}/{relpath}', 'w').write(
        f'## {DISCLAIMER}\n' + '#' * (len(DISCLAIMER) + 3) + '\n\n' + s)
    files.view(f'{example_root_path}/{relpath}')
```
--------------------------------
### Setup imports and dummy data
Source: https://github.com/google/flax/blob/main/docs_nnx/guides/jax_and_nnx_transforms.rst
Initializes the environment with necessary imports and random input data for the examples.
```python
from flax import nnx
import jax
x = jax.random.normal(jax.random.key(0), (1, 2))
y = jax.random.normal(jax.random.key(1), (1, 3))
```
--------------------------------
### Configure Example Directory and Repository
Source: https://github.com/google/flax/blob/main/examples/imagenet/imagenet.ipynb
Sets up variables for the example directory path and repository details. These are used for cloning the repository and managing example files.
```python
example_directory = 'examples/imagenet'
editor_relpaths = ('configs/default.py', 'input_pipeline.py', 'models.py', 'train.py')
repo, branch = 'https://github.com/google/flax', 'main'
```
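As a rough illustration of how the Colab setup cell consumes these variables, the sketch below derives the shallow-clone command and the editable file paths from them. The `/content` root and the `flaxrepo` checkout name are taken from the setup cell; the helper names `clone_command` and `editor_paths` are hypothetical and not part of Flax.

```python
import posixpath

example_directory = 'examples/imagenet'
editor_relpaths = ('configs/default.py', 'input_pipeline.py', 'models.py', 'train.py')
repo, branch = 'https://github.com/google/flax', 'main'


def clone_command(repo, branch, target='flaxrepo'):
  """Returns the shallow-clone command as an argument list (hypothetical helper)."""
  return ['git', 'clone', '--depth=1', '-b', branch, repo, target]


def editor_paths(root, example_directory, relpaths):
  """Joins the example root with each file the notebook opens for editing."""
  example_root = posixpath.join(root, example_directory)
  return [posixpath.join(example_root, relpath) for relpath in relpaths]


print(' '.join(clone_command(repo, branch)))
for path in editor_paths('/content', example_directory, editor_relpaths):
  print(path)
```

Swapping `example_directory` and `editor_relpaths` for another example's values is all that changes between the notebooks.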
--------------------------------
### Setup Python Environment and Install JAX on TPU VM
Source: https://github.com/google/flax/blob/main/examples/gemma/README.md
Commands to set up a Python 3.12 virtual environment using uv, install JAX with TPU support, and verify the installation on the TPU VM.
```bash
# Setup Python 3.12 env with UV
python -m pip install uv
uv venv --python 3.12 /tmp/venv
source /tmp/venv/bin/activate
uv pip install pip
# which python && python -VV && pip --version
pip install "jax[tpu]"
# Check whether TPUs are available:
# python3 -c "import jax; print(jax.devices())"
```
--------------------------------
### Install Flax and dependencies
Source: https://github.com/google/flax/blob/main/examples/gemma/README.md
Clone the Flax repository and install required dependencies for the Gemma example.
```bash
git clone --depth=1 --branch=main https://github.com/google/flax
cd flax
pip install -e .
cd examples/gemma
pip install -r requirements.txt
```
--------------------------------
### Install Flax and Dependencies
Source: https://github.com/google/flax/blob/main/examples/wmt/README.md
Clones the Flax repository and installs required dependencies for the WMT example.
```bash
git clone --depth=1 --branch=main https://github.com/google/flax
cd flax
pip install -e .
cd examples/wmt
pip install -r requirements.txt
```
--------------------------------
### Initialize Flax Model and Dataset Setup
Source: https://github.com/google/flax/blob/main/docs_nnx/flip/1009-optimizer-api.md
Provides the necessary imports, model definition, and data loading utilities to run Flax migration examples.
```python
import functools
from typing import Callable, Sequence
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import tensorflow as tf
import tensorflow_datasets as tfds
def pp(features):
  return {
      'image': tf.cast(features['image'], tf.float32) / 255 - 0.5,
      'label': features['label'],
  }
class Model(nn.Module):
  @nn.compact
  def __call__(self, inputs):
    x = inputs.reshape([inputs.shape[0], -1])
    x = nn.normalization.BatchNorm(True)(x)
    x = nn.Dense(10)(x)
    x = nn.log_softmax(x)
    return x
def onehot(labels, num_classes, on_value=1.0, off_value=0.0):
  x = (labels[..., None] == jnp.arange(num_classes)[None])
  x = jax.lax.select(
      x, jnp.full(x.shape, on_value), jnp.full(x.shape, off_value))
  return x.astype(jnp.float32)
def xent_loss(logits, labels):
  return -jnp.sum(
      onehot(labels, num_classes=10) * logits) / labels.size
def get_learning_rate(step):
  return 0.1
model = Model()
rng = jax.random.key(0)
ds = tfds.load('mnist')['train'].take(160).map(pp).batch(16)
batch = next(iter(ds))
variables = model.init(rng, jnp.array(batch['image'][:1]))
jax.tree_util.tree_map(jnp.shape, variables)
```
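The loss above multiplies one-hot labels by the model's log-probabilities and averages over the batch. The sketch below repeats that arithmetic in plain Python on a tiny made-up batch, so the formula can be checked by hand. It is an illustration only: the three-class scores and labels are invented, and the functions are list-based stand-ins for the JAX versions rather than the code used in the guide.

```python
import math


def log_softmax(row):
  """Numerically stable log-softmax of one row of scores."""
  m = max(row)
  log_z = m + math.log(sum(math.exp(v - m) for v in row))
  return [v - log_z for v in row]


def onehot(label, num_classes, on_value=1.0, off_value=0.0):
  """List-based one-hot encoding of a single integer label."""
  return [on_value if k == label else off_value for k in range(num_classes)]


def xent_loss(log_probs, labels, num_classes):
  """Mean negative log-probability assigned to the true class."""
  total = 0.0
  for row, label in zip(log_probs, labels):
    total += sum(o * lp for o, lp in zip(onehot(label, num_classes), row))
  return -total / len(labels)


# Invented batch: two examples, three classes.
scores = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
labels = [0, 2]
log_probs = [log_softmax(row) for row in scores]
loss = xent_loss(log_probs, labels, num_classes=3)
print(loss)
```

Because the one-hot mask zeroes every entry except the true class, the sum picks out exactly one log-probability per example.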
--------------------------------
### Install Flax and dependencies
Source: https://github.com/google/flax/blob/main/examples/imagenet/README.md
Clone the Flax repository and install the required dependencies for the ImageNet example.
```shell
git clone --depth=1 --branch=main https://github.com/google/flax
cd flax
pip install -e .
cd examples/imagenet
pip install -r requirements.txt
```
--------------------------------
### Colab Environment Setup for Flax Imagenet Example
Source: https://github.com/google/flax/blob/main/examples/imagenet/imagenet.ipynb
Configures the Colab environment by cloning the Flax repository, copying example files, and optionally mounting Google Drive for persistent storage. It also modifies files to include disclaimers about data persistence.
```python
# (If you run this code in Jupyter[lab], then you're already in the
# example directory and nothing needs to be done.)
#@markdown **Fetch newest Flax, copy example code**
#@markdown
#@markdown **If you select no** below, then the files will be stored on the
#@markdown *ephemeral* Colab VM. **After some time of inactivity, this VM will
#@markdown be restarted and any changes are lost**.
#@markdown
#@markdown **If you select yes** below, then you will be asked for your
#@markdown credentials to mount your personal Google Drive. In this case, all
#@markdown changes you make will be *persisted*, and even if you re-run the
#@markdown Colab later on, the files will still be the same (you can of course
#@markdown remove directories inside your Drive's `flax/` root if you want to
#@markdown manually revert these files).
if 'google.colab' in str(get_ipython()):
  import os
  os.chdir('/content')
  # Download Flax repo from Github.
  if not os.path.isdir('flaxrepo'):
    !git clone --depth=1 -b $branch $repo flaxrepo
  # Copy example files & change directory.
  mount_gdrive = 'no' #@param ['yes', 'no']
  if mount_gdrive == 'yes':
    DISCLAIMER = 'Note : Editing in your Google Drive, changes will persist.'
    from google.colab import drive
    drive.mount('/content/gdrive')
    example_root_path = f'/content/gdrive/My Drive/flax/{example_directory}'
  else:
    DISCLAIMER = 'WARNING : Editing in VM - changes lost after reboot!!'
    example_root_path = f'/content/{example_directory}'
    from IPython import display
    display.display(display.HTML(f'{DISCLAIMER}'))
  if not os.path.isdir(example_root_path):
    os.makedirs(example_root_path)
    !cp -r flaxrepo/$example_directory/* "$example_root_path"
  os.chdir(example_root_path)
  from google.colab import files
  for relpath in editor_relpaths:
    s = open(f'{example_root_path}/{relpath}').read()
    open(f'{example_root_path}/{relpath}', 'w').write(
        f'## {DISCLAIMER}\n' + '#' * (len(DISCLAIMER) + 3) + '\n\n' + s)
    files.view(f'{example_root_path}/{relpath}')
```
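The final loop of the cell prepends a comment banner to each editable file. The sketch below isolates that string manipulation and runs it against a throwaway temporary file, so it can be tried outside Colab. The helper name `with_disclaimer` and the sample file contents are hypothetical.

```python
import os
import tempfile


def with_disclaimer(source, disclaimer):
  """Prepends the two-line comment banner that the setup cell writes."""
  banner = f'## {disclaimer}\n' + '#' * (len(disclaimer) + 3) + '\n\n'
  return banner + source


disclaimer = 'WARNING : Editing in VM - changes lost after reboot!!'

# Demonstrate on a throwaway file instead of a real example file.
with tempfile.TemporaryDirectory() as tmp:
  path = os.path.join(tmp, 'train.py')
  with open(path, 'w') as f:
    f.write('print("hello")\n')
  with open(path) as f:
    original = f.read()
  with open(path, 'w') as f:
    f.write(with_disclaimer(original, disclaimer))
  with open(path) as f:
    patched = f.read()

print(patched)
```

The `+ 3` accounts for the leading `## `, so the row of `#` characters is exactly as wide as the banner line above it. Applying the function to an already patched file stacks a second banner on top of the first.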
--------------------------------
### Define Example Paths and Repository
Source: https://github.com/google/flax/blob/main/examples/seq2seq/seq2seq.ipynb
Sets the directory paths and repository information for the seq2seq example.
```python
example_directory = 'examples/seq2seq'
editor_relpaths = ('train.py', 'input_pipeline.py', 'models.py')
repo, branch = 'https://github.com/google/flax', 'main'
```
--------------------------------
### Fetch and Configure Example Files
Source: https://github.com/google/flax/blob/main/examples/mnist/mnist.ipynb
Downloads the Flax repository and sets up the environment to edit example files either on the ephemeral VM or a mounted Google Drive.
```python
# (If you run this code in Jupyter[lab], then you're already in the
# example directory and nothing needs to be done.)
#@markdown **Fetch newest Flax, copy example code**
#@markdown
#@markdown **If you select no** below, then the files will be stored on the
#@markdown *ephemeral* Colab VM. **After some time of inactivity, this VM will
#@markdown be restarted and any changes are lost**.
#@markdown
#@markdown **If you select yes** below, then you will be asked for your
#@markdown credentials to mount your personal Google Drive. In this case, all
#@markdown changes you make will be *persisted*, and even if you re-run the
#@markdown Colab later on, the files will still be the same (you can of course
#@markdown remove directories inside your Drive's `flax/` root if you want to
#@markdown manually revert these files).
if 'google.colab' in str(get_ipython()):
  import os
  os.chdir('/content')
  # Download Flax repo from Github.
  if not os.path.isdir('flaxrepo'):
    !git clone --depth=1 -b $branch $repo flaxrepo
  # Copy example files & change directory.
  mount_gdrive = 'no' #@param ['yes', 'no']
  if mount_gdrive == 'yes':
    DISCLAIMER = 'Note : Editing in your Google Drive, changes will persist.'
    from google.colab import drive
    drive.mount('/content/gdrive')
    example_root_path = f'/content/gdrive/My Drive/flax/{example_directory}'
  else:
    DISCLAIMER = 'WARNING : Editing in VM - changes lost after reboot!!'
    example_root_path = f'/content/{example_directory}'
    from IPython import display
    display.display(display.HTML(f'{DISCLAIMER}'))
  if not os.path.isdir(example_root_path):
    os.makedirs(example_root_path)
    !cp -r flaxrepo/$example_directory/* "$example_root_path"
  os.chdir(example_root_path)
  from google.colab import files
  for relpath in editor_relpaths:
    s = open(f'{example_root_path}/{relpath}').read()
    open(f'{example_root_path}/{relpath}', 'w').write(
        f'## {DISCLAIMER}\n' + '#' * (len(DISCLAIMER) + 3) + '\n\n' + s)
    files.view(f'{example_root_path}/{relpath}')
```
--------------------------------
### Training Loop Setup
Source: https://github.com/google/flax/blob/main/examples/ogbg_molpcba/ogbg_molpcba.ipynb
Marks the start of the training loop and recommends a Colab GPU runtime for faster training. TPUs are not used in this Colab because training is not distributed with pmap().
```python
# Training loop
# Use a Colab GPU runtime to speed up training.
# We don't use TPUs in this Colab because we do not distribute our
# training using pmap() - if you're looking for an example using TPUs
# check out the Colab notebook below:
```
--------------------------------
### Define example paths and repository variables
Source: https://github.com/google/flax/blob/main/examples/ogbg_molpcba/ogbg_molpcba.ipynb
Sets the directory paths and repository information for the OGBG-MolPCBA example.
```python
example_directory = 'examples/ogbg_molpcba'
editor_relpaths = ('configs/default.py', 'input_pipeline.py', 'models.py', 'train.py')
repo, branch = 'https://github.com/google/flax', 'main'
```
--------------------------------
### Install build dependencies
Source: https://github.com/google/flax/blob/main/flaxlib_src/README.md
Install the required build tools including meson-python, ninja, and build.
```shell
pip install meson-python ninja build
```
--------------------------------
### Start TensorBoard on TPU Worker
Source: https://github.com/google/flax/blob/main/examples/gemma/README.md
Installs xprof and launches a TensorBoard instance on a specific TPU worker, forwarding the port to localhost.
```bash
gcloud compute tpus tpu-vm ssh --zone $ZONE $TPU_NAME --worker=0 --command="pip install xprof"
gcloud compute tpus tpu-vm ssh --zone $ZONE $TPU_NAME --worker=0 --command="tensorboard --logdir=$out_dir --port 7007" -- -L 7007:localhost:7007
```
--------------------------------
### Define example directory and file paths
Source: https://github.com/google/flax/blob/main/examples/sst2/sst2.ipynb
Sets the base directory and relative paths for the SST-2 example files.
```python
example_directory = 'examples/sst2'
editor_relpaths = ('configs/default.py', 'train.py', 'models.py')
```
--------------------------------
### Setup JAX Devices and Mesh
Source: https://github.com/google/flax/blob/main/docs_nnx/guides/flax_gspmd.md
Initializes JAX with a specified number of CPU devices and creates a 2x4 device mesh for distributed computation. This setup is useful for testing or when actual accelerators are not available.
```python
from functools import partial
import jax
from jax import numpy as jnp
from jax.sharding import PartitionSpec as P, NamedSharding, AxisType
import optax
import flax
from flax import nnx
# Ignore this if you are already running on a TPU or GPU
if not jax._src.xla_bridge.backends_are_initialized():
  jax.config.update('jax_num_cpu_devices', 8)
print(f'You have 8 “fake” JAX devices now: {jax.devices()}')
```
```python
# Create an auto-mode mesh of two dimensions and annotate each axis with a name.
auto_mesh = jax.make_mesh((2, 4), ('data', 'model'))
```
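To make the `(2, 4)` shape and the axis names concrete, the sketch below lays eight device indices out on a named 2x4 grid using only the standard library. It assumes a simple row-major layout; the actual device order chosen by `jax.make_mesh` may differ, so treat the printed placement as illustrative rather than as what JAX produces.

```python
from itertools import product

mesh_shape = (2, 4)
axis_names = ('data', 'model')

# Assumption: devices are laid out in row-major order. The real
# jax.make_mesh may pick a different device order for performance.
coords = list(product(*(range(n) for n in mesh_shape)))
placement = {
    device_id: dict(zip(axis_names, coord))
    for device_id, coord in enumerate(coords)
}

for device_id, coord in placement.items():
  print(device_id, coord)

# A batch of 16 rows split only along 'data' gives each of the two
# 'data' slices 8 rows, replicated across the 4 'model' positions.
batch_size = 16
rows_per_data_slice = batch_size // mesh_shape[0]
print(rows_per_data_slice)
```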
--------------------------------
### Clone and install dependencies
Source: https://github.com/google/flax/blob/main/docs_nnx/contributing.md
Clones your fork of the repository and installs the packages required for development, testing, and documentation, using either pip or uv.
```bash
git clone https://github.com/YOUR_USERNAME/flax
cd flax
pip install -e ".[all,testing,docs]"
```
```bash
uv sync --all-extras
```
--------------------------------
### Define Example Paths
Source: https://github.com/google/flax/blob/main/examples/mnist/mnist.ipynb
Sets the directory and file paths for the MNIST example and the repository source.
```python
example_directory = 'examples/mnist'
editor_relpaths = ('configs/default.py', 'train.py')
repo, branch = 'https://github.com/google/flax', 'main'
```
--------------------------------
### Run Flax Tracing and Lowering Benchmarks
Source: https://github.com/google/flax/blob/main/benchmarks/tracing/README.md
Commands to install required dependencies and execute performance profiling for Flax examples.
```bash
pip install -r benchmarks/tracing/requirements.txt
# Benchmark trace and lower timing for all workloads.
python tracing_benchmark.py
# Profile a single example.
python tracing_benchmark.py --example=wmt
# Profile just tracing for a single example.
python tracing_benchmark.py --example=wmt --mode=trace
```
--------------------------------
### TPU Setup and Device Check
Source: https://github.com/google/flax/blob/main/examples/imagenet/imagenet.ipynb
Explains how to run the Colab on TPUs, including the pinned Flax, JAX, and jaxlib versions to install, then lists the available JAX devices.
```python
# It's possible to run this Colab with TPUs:
# 1. change runtime type to TPU
# 2. install compatible flax version: `!pip install flax==0.6.4 jax==0.3.25 jaxlib==0.3.25`
# 3. uncomment lines below
# import flax, jax, jax.tools.colab_tpu
# jax.tools.colab_tpu.setup_tpu()
import jax
jax.devices()
```
--------------------------------
### Install Dependencies
Source: https://github.com/google/flax/blob/main/docs_nnx/examples/gemma.md
Install the required libraries for Flax NNX and model management.
```bash
! pip install --no-deps -U flax
! pip install jaxtyping kagglehub treescope
```
--------------------------------
### Verify JAX installation
Source: https://github.com/google/flax/blob/main/examples/gemma/README.md
Check if JAX is correctly installed and identifying devices.
```bash
python3 -c "import jax; print(jax.devices())"
```
--------------------------------
### Start TensorBoard Monitoring
Source: https://github.com/google/flax/blob/main/examples/lm1b/README.md
Starts TensorBoard to monitor training logs, assuming logs are stored in the specified directory.
```bash
tensorboard --logdir=$HOME/logs
```
--------------------------------
### Run training and monitoring
Source: https://github.com/google/flax/blob/main/examples/gemma/README.md
Execute the training script and start TensorBoard for monitoring.
```bash
python3 main.py --workdir=$HOME/logs/gemma_lm1b_256 --config.per_device_batch_size=32
```
```bash
tensorboard --logdir=$HOME/logs
```
--------------------------------
### Install Flax from head
Source: https://github.com/google/flax/blob/main/tests/import_test.ipynb
Installs the latest development version of Flax directly from the GitHub repository.
```bash
!pip install git+https://github.com/google/flax.git
```
--------------------------------
### Run MNIST Example
Source: https://github.com/google/flax/blob/main/examples/mnist/README.md
Basic command to run the MNIST example. Specify the work directory and configuration file.
```shell
python main.py --workdir=/tmp/mnist --config=configs/default.py
```
--------------------------------
### Launch MNIST Example on GCE
Source: https://github.com/google/flax/blob/main/examples/cloud/README.md
Launches the MNIST training example on a Google Compute Engine VM. Ensure you have set the PROJECT and GCS_BUCKET environment variables.
```shell
python examples/cloud/launch_gce.py \
--project=$PROJECT \
--zone=us-west1-a \
--machine_type=n2-standard-2 \
--gcs_workdir_base=gs://$GCS_BUCKET/workdir_base \
--repo=${REPO:-https://github.com/google/flax} \
--branch=${BRANCH:-main} \
--example=mnist \
--args='--config=configs/default.py' \
--name=default
```
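The `${REPO:-...}` and `${BRANCH:-...}` expansions fall back to a default when the variable is unset or empty, while `PROJECT` and `GCS_BUCKET` have no fallback. The sketch below mirrors that behaviour when assembling the same command from Python. The helper names and the placeholder project and bucket values are hypothetical; only the flags themselves come from the command above.

```python
import os
import shlex


def env_or_default(name, default, environ=os.environ):
  """Mimics shell `${NAME:-default}`: falls back when unset or empty."""
  return environ.get(name) or default


def launch_command(environ=os.environ):
  """Assembles the MNIST launch command shown above as an argument list."""
  return [
      'python', 'examples/cloud/launch_gce.py',
      # Unlike the shell, a missing PROJECT or GCS_BUCKET raises KeyError here.
      f"--project={environ['PROJECT']}",
      '--zone=us-west1-a',
      '--machine_type=n2-standard-2',
      f"--gcs_workdir_base=gs://{environ['GCS_BUCKET']}/workdir_base",
      f"--repo={env_or_default('REPO', 'https://github.com/google/flax', environ)}",
      f"--branch={env_or_default('BRANCH', 'main', environ)}",
      '--example=mnist',
      '--args=--config=configs/default.py',
      '--name=default',
  ]


# Placeholder values for illustration only.
fake_env = {'PROJECT': 'my-project', 'GCS_BUCKET': 'my-bucket', 'BRANCH': ''}
print(shlex.join(launch_command(fake_env)))
```

Using `environ.get(name) or default` rather than `environ.get(name, default)` matters for the empty `BRANCH` case: the latter would keep the empty string, whereas `:-` in the shell replaces it.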
--------------------------------
### Start TensorBoard for Training Monitoring
Source: https://github.com/google/flax/blob/main/examples/ogbg_molpcba/ogbg_molpcba.ipynb
This code snippet loads the TensorBoard extension and starts a TensorBoard server. It's useful for live monitoring of training progress, especially when using Colab.
```python
# Start TensorBoard
# Get a live update during training - use the "refresh" button!
# (In Jupyter[lab] start "tensorboard" in the local directory instead.)
if 'google.colab' in str(get_ipython()):
  %load_ext tensorboard
  %tensorboard --logdir=.
```
--------------------------------
### Install JAX on TPU VM
Source: https://github.com/google/flax/blob/main/examples/lm1b/README.md
Installs JAX with TPU support on the connected TPU VM, specifying a version compatible with TPU releases.
```bash
pip install "jax[tpu]>=0.2.16" \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
--------------------------------
### Start training on TPU
Source: https://github.com/google/flax/blob/main/examples/imagenet/README.md
Execute the training script with the appropriate configuration and backend target.
```shell
export TFDS_DATA_DIR=gs://$GCS_TFDS_BUCKET/datasets
python3 main.py --workdir=$HOME/logs/imagenet_tpu --config=configs/tpu.py \
--jax_backend_target="grpc://192.168.0.2:8470"
```
--------------------------------
### Memory-Efficient Partial Initialization Setup
Source: https://github.com/google/flax/blob/main/docs_nnx/guides/surgery.md
Initializes the state for memory-efficient partial initialization.
```python
# Some pretrained model state
old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))
```
--------------------------------
### Flax NNX Training Step Execution Example
Source: https://github.com/google/flax/blob/main/docs_nnx/migrating/haiku_to_flax.rst
Example of how to execute a Flax NNX training step with a model instance, inputs, and labels. The model is updated in-place.
```python
sample_x = jnp.ones((1, 784))
train_step(model, sample_x, jnp.ones((1,), dtype=jnp.int32))
```
--------------------------------
### Launch ImageNet Example on GCE
Source: https://github.com/google/flax/blob/main/examples/cloud/README.md
Launches the ImageNet training example on a Google Compute Engine VM with specified hardware accelerators. Ensure PROJECT, GCS_BUCKET, and GCS_TFDS_BUCKET environment variables are set.
```shell
python examples/cloud/launch_gce.py \
--project=$PROJECT \
--zone=us-west1-a \
--machine_type=n1-standard-96 \
--accelerator_type=nvidia-tesla-v100 --accelerator_count=8 \
--gcs_workdir_base=gs://$GCS_BUCKET/workdir_base \
--tfds_data_dir=gs://$GCS_TFDS_BUCKET/datasets \
--repo=${REPO:-https://github.com/google/flax} \
--branch=${BRANCH:-main} \
--example=imagenet \
--args='--config=configs/v100_x8_mixed_precision.py' \
--name=v100_x8_mixed_precision
```
--------------------------------
### Install JAX with CUDA and Flax Dependencies
Source: https://github.com/google/flax/blob/main/examples/gemma/README.md
Commands to install JAX with CUDA support, Flax, and example-specific requirements for GPU training. Includes a check for GPU availability.
```bash
pip install "jax[cuda13]"
# Check whether GPUs are available:
# python3 -c "import jax; print(jax.devices())"
git clone --depth=1 --branch=main https://github.com/google/flax
cd flax
pip install -e .
cd examples/gemma
pip install -r requirements.txt
```
--------------------------------
### Configure the build
Source: https://github.com/google/flax/blob/main/flaxlib_src/README.md
Set up the build environment by installing dependencies via meson wrap and initializing the build directory.
```shell
mkdir -p subprojects
meson wrap install robin-map
meson wrap install nanobind
meson setup builddir
```
--------------------------------
### Setup dependencies and checkpoint directory
Source: https://github.com/google/flax/blob/main/docs_nnx/guides/checkpointing.md
Imports necessary libraries and initializes a temporary directory for checkpoint storage.
```python
from flax import nnx
import orbax.checkpoint as ocp
import jax
from jax import numpy as jnp
import numpy as np
ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
```
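For readers without Orbax at hand, the sketch below approximates what the directory helper appears to do, namely delete the directory if it exists and recreate it empty, using only the standard library. This is an assumption about the helper's behaviour, not its actual implementation, and the example runs in a throwaway temporary location rather than `/tmp/my-checkpoints/`.

```python
import pathlib
import shutil
import tempfile


def erase_and_create_empty(directory):
  """Deletes `directory` if it exists, then recreates it empty."""
  path = pathlib.Path(directory)
  if path.exists():
    shutil.rmtree(path)
  path.mkdir(parents=True)
  return path


# Use a throwaway location so nothing real is deleted.
root = pathlib.Path(tempfile.mkdtemp())
ckpt_dir = erase_and_create_empty(root / 'my-checkpoints')
(ckpt_dir / 'stale.txt').write_text('old checkpoint')

# A second call wipes the stale file and leaves an empty directory.
ckpt_dir = erase_and_create_empty(ckpt_dir)
leftover = list(ckpt_dir.iterdir())
print(leftover)

shutil.rmtree(root)  # clean up the throwaway directory
```

Starting from an empty directory keeps earlier runs from leaking stale checkpoints into the examples, at the cost of destroying whatever was stored there.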
--------------------------------
### Define and Train a CNN with Flax NNX
Source: https://context7.com/google/flax/llms.txt
Full example demonstrating CNN definition, initialization, and optimizer setup.
```python
from flax import nnx
import jax
import jax.numpy as jnp
import optax
# Define a CNN model
class CNN(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)
  def __call__(self, x):
    x = nnx.relu(self.conv1(x))
    x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nnx.relu(self.conv2(x))
    x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape(x.shape[0], -1)  # Flatten
    x = nnx.relu(self.linear1(x))
    x = self.dropout(x)
    x = self.linear2(x)
    return x
# Initialize model and optimizer
model = CNN(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=nnx.Param)
```
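The input width of `linear1` (3136) follows from the spatial sizes flowing through the network. The short calculation below reproduces it, assuming 28x28 single-channel inputs (MNIST-sized images, which the snippet does not state), the default `'SAME'` padding of `nnx.Conv`, and unpadded 2x2 average pooling with stride 2.

```python
def conv_same(size):
  """'SAME' padding with stride 1 keeps the spatial size unchanged."""
  return size


def pool_valid(size, window=2, stride=2):
  """Output size of a pooling window applied without padding."""
  return (size - window) // stride + 1


# Assumption: 28x28 single-channel inputs (MNIST-sized images).
size = 28
size = pool_valid(conv_same(size))  # conv1 followed by the first avg_pool
size = pool_valid(conv_same(size))  # conv2 followed by the second avg_pool
flattened = size * size * 64        # conv2 produces 64 channels
print(size, flattened)
```

With a different input resolution the flattened width changes, and the first argument of `linear1` would need to change with it.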
--------------------------------
### Initialize JAX Environment for Multi-Device Simulation
Source: https://github.com/google/flax/blob/main/docs_nnx/guides/flax_gspmd.ipynb
Sets up a multi-device environment using fake CPU devices for testing purposes.
```python
from functools import partial
import jax
from jax import numpy as jnp
from jax.sharding import PartitionSpec as P, NamedSharding, AxisType
import optax
import flax
from flax import nnx
# Ignore this if you are already running on a TPU or GPU
if not jax._src.xla_bridge.backends_are_initialized():
  jax.config.update('jax_num_cpu_devices', 8)
print(f'You have 8 “fake” JAX devices now: {jax.devices()}')
```
--------------------------------
### Initialize Model with Partitioned Weights
Source: https://github.com/google/flax/blob/main/docs_nnx/flip/2434-general-metadata.md
Demonstrates initializing a model with partitioned weights and inspecting the resulting variable structure, including the Partitioned metadata.
```python
variables = partitioned_dense.init(rng, jnp.ones((4,)))
jax.tree.map(np.shape, variables) # => {"params": {"kernel": Partitioned(value=(4, 8), names=(None, "data")), "bias": (8,)}}
```
--------------------------------
### Setup imports and environment
Source: https://github.com/google/flax/blob/main/docs_nnx/guides/bridge_guide.ipynb
Initial configuration and imports required for using Flax NNX and Linen together.
```python
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
from flax import nnx
from flax import linen as nn
from flax.nnx import bridge
import jax
from jax import numpy as jnp
from jax.experimental import mesh_utils
from typing import *
```
--------------------------------
### Install GCSFuse
Source: https://github.com/google/flax/blob/main/examples/gemma/README.md
Configures the GCSFuse repository and installs the necessary packages.
```bash
export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s`
echo "deb [signed-by=/usr/share/keyrings/cloud.google.asc] https://packages.cloud.google.com/apt $GCSFUSE_REPO main" | tee /etc/apt/sources.list.d/gcsfuse.list
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | tee /usr/share/keyrings/cloud.google.asc
apt-get update
apt-get install -y fuse gcsfuse --no-install-recommends
```
--------------------------------
### Install SST-2 dependencies
Source: https://github.com/google/flax/blob/main/examples/sst2/sst2.ipynb
Installs required packages from the requirements.txt file.
```bash
# Install SST-2 dependencies.
!pip install -q -r requirements.txt
```
--------------------------------
### Configure multi-host TPU environment
Source: https://github.com/google/flax/blob/main/examples/gemma/README.md
Set environment variables and create a startup script for multi-host TPU training.
```bash
export ZONE=us-east5-a
export ACCELERATOR_TYPE=v5p-32
export RUNTIME_VERSION=v2-alpha-tpuv5
export TPU_NAME=flax-gemma-lm1b-${ACCELERATOR_TYPE}
export GCS_OUTPUT_BUCKET=flax-gemma-example-training
export GCS_DATA_BUCKET=flax-lm1b-arrayrecords
cat << EOF > /tmp/example_startup.sh
#!/bin/bash
set -xeu
python -m pip install uv
uv venv --python 3.12 /tmp/venv
source /tmp/venv/bin/activate
uv pip install pip
echo "source /tmp/venv/bin/activate" > /root/.bashrc
# Install JAX, FLAX and other deps
python -m pip install jax[tpu]
python -m pip install \
"absl-py~=2.2" \
"clu==0.0.12" \
"mlcroissant~=1.0" \
"numpy~=2.3" \
"optax~=0.2" \
"sentencepiece~=0.2" \
"jaxtyping~=0.3" \
"tensorflow-cpu~=2.20" \
"tensorboard~=2.20" \
"tensorflow-datasets~=4.9" \
"grain~=0.2" \
"orbax-checkpoint[gcp]~=0.11" \
"google-cloud-storage"
cd /root
git clone --depth=1 --branch=main https://github.com/google/flax
cd flax
python -m pip install -e .
# Install gcsfuse
```
--------------------------------
### Install CLU and Flax
Source: https://github.com/google/flax/blob/main/examples/seq2seq/seq2seq.ipynb
Installs the required CLU and Flax libraries in the environment.
```python
# Install CLU & Flax.
!pip install -q clu flax
```
--------------------------------
### Install Flax dependencies
Source: https://github.com/google/flax/blob/main/docs_nnx/nnx_basics.md
Use pip to install the latest version of Flax.
```bash
# ! pip install -U flax
```
--------------------------------
### Install Dependencies
Source: https://github.com/google/flax/blob/main/examples/mnist/mnist.ipynb
Installs the ml-collections library and the latest version of Flax from the GitHub repository.
```python
# Install ml-collections & latest Flax version from Github.
!pip install -q ml-collections git+https://github.com/google/flax
```
--------------------------------
### Run training with default configuration
Source: https://github.com/google/flax/blob/main/examples/ogbg_molpcba/README.md
Execute the training script using the default configuration file.
```shell
python main.py --workdir=./ogbg_molpcba --config=configs/default.py
```
--------------------------------
### Prepare Output Directories
Source: https://github.com/google/flax/blob/main/examples/gemma/README.md
Creates output directories on the TPU VM using worker 0.
```bash
export out_name=gemma3-1b_lm1b_run-$ACCELERATOR_TYPE-$(date -u +%Y%m%d-%H%M)
export out_dir=/root/logs/$out_name
export chpt_bucket=gs://$GCS_OUTPUT_BUCKET/$out_name/checkpoint
gcloud compute tpus tpu-vm ssh --zone $ZONE $TPU_NAME --worker=0 --command="export out_dir=$out_dir && mkdir -p $out_dir"
gcloud compute tpus tpu-vm ssh --zone $ZONE $TPU_NAME --worker=0 --command="ls /root/logs"
```
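To see what the exported names look like for a given run, the sketch below rebuilds them in Python with a fixed timestamp. The function name `run_paths` is hypothetical, and the accelerator type and bucket name are the example values from the multi-host configuration; substitute your own.

```python
from datetime import datetime, timezone


def run_paths(accelerator_type, output_bucket, now=None):
  """Builds the run name, log directory, and checkpoint URI for one run."""
  now = now or datetime.now(timezone.utc)
  # Same timestamp layout as `date -u +%Y%m%d-%H%M`.
  out_name = f'gemma3-1b_lm1b_run-{accelerator_type}-{now:%Y%m%d-%H%M}'
  out_dir = f'/root/logs/{out_name}'
  chpt_bucket = f'gs://{output_bucket}/{out_name}/checkpoint'
  return out_name, out_dir, chpt_bucket


# Fixed timestamp so the result is reproducible.
fixed = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
out_name, out_dir, chpt_bucket = run_paths(
    'v5p-32', 'flax-gemma-example-training', now=fixed)
print(out_name)
print(out_dir)
print(chpt_bucket)
```

Because the timestamp has minute resolution, two runs started within the same minute on the same accelerator type would share a name.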
--------------------------------
### Install dependencies
Source: https://github.com/google/flax/blob/main/examples/ogbg_molpcba/ogbg_molpcba.ipynb
Installs required libraries including clu, ml-collections, Flax, tensorflow_datasets, and jraph.
```python
# Install clu, ml-collections, latest Flax version, and tensorflow_datasets.
!pip install -U -q clu ml-collections git+https://github.com/google/flax tfds_nightly jraph
```
--------------------------------
### Create a virtual environment
Source: https://github.com/google/flax/blob/main/docs_nnx/contributing.md
Initializes a virtual environment to isolate project dependencies.
```bash
python3 -m virtualenv env
. env/bin/activate
```
--------------------------------
### Create TPU VM
Source: https://github.com/google/flax/blob/main/examples/gemma/README.md
Provisions a new TPU VM with a specified startup script.
```bash
gcloud compute tpus tpu-vm create $TPU_NAME --spot \
--zone $ZONE \
--accelerator-type=$ACCELERATOR_TYPE \
--version=$RUNTIME_VERSION \
--metadata-from-file=startup-script=/tmp/example_startup.sh
```
--------------------------------
### Install Flax
Source: https://github.com/google/flax/blob/main/docs_nnx/mnist_tutorial.md
Use pip to install Flax and JAX with CUDA support in a Python environment.
```python
# !pip install -U "jax[cuda12]"
# !pip install -U flax
```
--------------------------------
### Initialize Imports and Environment
Source: https://github.com/google/flax/blob/main/docs_nnx/guides/surgery.ipynb
Standard imports and random key initialization required for Flax NNX examples.
```python
from typing import *
from pprint import pprint
import functools
import jax
from jax import lax, numpy as jnp, tree_util as jtu
from jax.sharding import PartitionSpec, Mesh, NamedSharding
from jax.experimental import mesh_utils
import flax
from flax import nnx
import flax.traverse_util
import numpy as np
import orbax.checkpoint as orbax
key = jax.random.key(0)
```
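The `key` created at the end of the imports is typically split before use, so that each random operation receives its own subkey. The following is a small sketch of that pattern using only `jax.random`; the names `init_key`, `noise_key`, `weights`, and `noise`, and the array shapes, are illustrative assumptions rather than anything taken from the guide.

```python
import jax

key = jax.random.key(0)

# Split rather than reuse: each subkey should feed exactly one random call.
key, init_key, noise_key = jax.random.split(key, 3)

# Hypothetical arrays drawn from independent subkeys.
weights = jax.random.normal(init_key, (4, 2))
noise = jax.random.normal(noise_key, (4, 2))

print(weights.shape, noise.shape)
```

Reusing the same key for both calls would produce identical arrays, which is why the split happens first.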
--------------------------------
### Install gcsfuse on TPU VM
Source: https://github.com/google/flax/blob/main/examples/gemma/README.md
Configure and install gcsfuse to mount Google Cloud Storage buckets.
```bash
# Install gcsfuse
export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s`
echo "deb [signed-by=/usr/share/keyrings/cloud.google.asc] https://packages.cloud.google.com/apt $GCSFUSE_REPO main" | tee /etc/apt/sources.list.d/gcsfuse.list
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | tee /usr/share/keyrings/cloud.google.asc
apt-get update
apt-get install -y fuse gcsfuse --no-install-recommends
```
--------------------------------
### Run PPO with custom configuration
Source: https://github.com/google/flax/blob/main/examples/ppo/README.md
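Execute the PPO training script with a specified configuration file and custom working directory. The second command additionally overrides individual config fields (game, total frames, and learning-rate/clip decay) on the command line.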
```bash
python ppo_main.py --config=configs/default.py --workdir=/my_fav_directory
```
```bash
python ppo_main.py --config=configs/default.py --config.game=Seaquest --config.total_frames=20000000 --config.decaying_lr_and_clip_param=False --workdir=/tmp/seaquest
```
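The `--config.<field>=<value>` arguments in the second command replace single fields of the loaded configuration. The real example relies on its config-flag machinery for this; the sketch below only imitates the idea with the standard library so the mapping from flags to fields is visible. The function `apply_overrides` and the default values are hypothetical placeholders, not the contents of `configs/default.py`.

```python
import ast
from typing import Any


def apply_overrides(config: dict[str, Any], argv: list[str]) -> dict[str, Any]:
    """Apply `--config.<dotted.path>=<value>` arguments to a nested dict.

    Hypothetical helper for illustration; not part of Flax.
    """
    prefix = "--config."
    for arg in argv:
        if not arg.startswith(prefix) or "=" not in arg:
            continue  # Not a field override (e.g. --workdir or --config=<file>).
        path, raw = arg[len(prefix):].split("=", 1)
        try:
            value = ast.literal_eval(raw)  # Parses numbers and booleans.
        except (ValueError, SyntaxError):
            value = raw  # Keep plain strings such as a game name.
        *parents, leaf = path.split(".")
        node = config
        for name in parents:
            node = node[name]
        if leaf not in node:
            raise KeyError(f"Unknown config field: {path}")
        node[leaf] = value
    return config


# Placeholder defaults, assumed for this sketch only.
config = {"game": "Pong", "total_frames": 40000000,
          "decaying_lr_and_clip_param": True}
argv = [
    "--config=configs/default.py",
    "--config.game=Seaquest",
    "--config.total_frames=20000000",
    "--config.decaying_lr_and_clip_param=False",
    "--workdir=/tmp/seaquest",
]
config = apply_overrides(config, argv)
print(config)
```

Rejecting unknown field names, as done here with `KeyError`, is a design choice that catches typos in override flags early.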
--------------------------------
### Install Flax
Source: https://github.com/google/flax/blob/main/README.md
Use this command to install the Flax library from PyPI. Ensure you have Python 3.8 or later.
```bash
pip install flax
```
--------------------------------
### Initialize Toy Model and Loss Function
Source: https://github.com/google/flax/blob/main/docs_nnx/guides/optimization_cookbook.md
Sets up a two-layer sequential model and a mean squared error loss function for training demonstrations. It also configures JAX to expose 8 CPU devices.
```python
import jax
from flax import nnx
jax.config.update('jax_num_cpu_devices', 8)
import jax.numpy as jnp
import functools as ft
import optax
import matplotlib.pyplot as plt
lecun_normal = jax.nn.initializers.lecun_normal()
rngs = nnx.Rngs(0)
def make_model(rngs, init=lecun_normal):
    return nnx.Sequential(
        nnx.Linear(2,8, rngs=rngs, kernel_init=init),
        nnx.Linear(8,8, rngs=rngs, kernel_init=init))
def loss_fn(model, x, y):
    return jnp.mean((model(x) - y) ** 2)
```
```
--------------------------------
### Install Flax with All Dependencies
Source: https://github.com/google/flax/blob/main/README.md
Installs Flax together with its optional extra dependencies, such as matplotlib.
```bash
pip install "flax[all]"
```
--------------------------------
### Naive Partial Initialization Example
Source: https://github.com/google/flax/blob/main/docs_nnx/guides/surgery.ipynb
This snippet shows a naive approach to partial initialization by initializing the entire model and then swapping pre-trained parameters. It highlights potential memory overhead due to intermediate parameter creation.
```python
# Some pretrained model state
old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))
simple_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(42)))
print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}')
# This line creates an extra kernel and bias inside the new LoRALinear!
# They are wasted, because you are going to use the kernel and bias in `old_state` anyway.
simple_model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=nnx.Rngs(42))
print(f'Number of jax arrays in memory midway: {len(jax.live_arrays())}'
      ' (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)')
nnx.update(simple_model, old_state)
print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}'
      ' (2 discarded - only lora_a & lora_b are used in model)')
```
--------------------------------
### Bidirectional RNN Usage Example
Source: https://github.com/google/flax/blob/main/docs_nnx/flip/2396-rnn.md
Example of how to instantiate and use the Bidirectional RNN combinator with different RNN cell types.
```python
forward_rnn = nn.RNN(nn.LSTMCell(), cell_size=32)
backward_rnn = nn.RNN(nn.GRUCell(), cell_size=32)
# Bidirectional combinator.
bi_rnn = nn.Bidirectional(forward_rnn, backward_rnn)
# Encodes a batch of input sequences in both directions.
carry, outputs = bi_rnn(inputs, seq_lengths)
```
--------------------------------
### Haiku Training Step Execution Example
Source: https://github.com/google/flax/blob/main/docs_nnx/migrating/haiku_to_flax.rst
Example of how to execute a Haiku training step with a PRNG key, parameters, inputs, and labels.
```python
train_step(jax.random.key(0), params, sample_x, jnp.ones((1,), dtype=jnp.int32))
```
--------------------------------
### Install Jupytext
Source: https://github.com/google/flax/blob/main/docs_nnx/contributing.md
Install a specific version of jupytext required for syncing notebooks. Ensure the version matches the one specified in .pre-commit-config.yaml.
```bash
pip install jupytext==1.13.8
```
--------------------------------
### Start Training Job
Source: https://github.com/google/flax/blob/main/examples/wmt/README.md
Runs the training script with a specified work directory, per-device batch size, and JAX backend target (a gRPC endpoint).
```bash
python3 main.py --workdir=$HOME/logs/wmt_256 \
  --config.per_device_batch_size=32 \
  --jax_backend_target="grpc://192.168.0.2:8470"
```