### Setup Flax and LM1B Example Dependencies

Source: https://github.com/google/flax/blob/main/examples/lm1b/README.md

Clones the Flax repository, installs it locally, and then installs the specific dependencies for the LM1B example.

```bash
git clone --depth=1 --branch=main https://github.com/google/flax
cd flax
pip install -e .
cd examples/lm1b
pip install -r requirements.txt
```

--------------------------------

### Install and Run VAE Example

Source: https://github.com/google/flax/blob/main/examples/vae/README.md

Install dependencies and execute the VAE training script with the default configuration.

```bash
pip install -r requirements.txt
python main.py --workdir=/tmp/mnist --config=configs/default.py
```

--------------------------------

### Setup environment and copy example files

Source: https://github.com/google/flax/blob/main/examples/ogbg_molpcba/ogbg_molpcba.ipynb

Configures the environment by cloning the repository and optionally mounting Google Drive for persistent storage of example files.

```python
# (If you run this code in Jupyter[lab], then you're already in the
# example directory and nothing needs to be done.)
# Expects `repo`, `branch`, `example_directory` and `editor_relpaths`
# to be defined in an earlier cell.

#@markdown **Fetch newest Flax version and copy of example code.**
#@markdown
#@markdown **If you select no** below, then the files will be stored on the
#@markdown *ephemeral* Colab VM. **After some time of inactivity, this VM will
#@markdown be restarted and any changes will be lost**.
#@markdown
#@markdown **If you select yes** below, then you will be asked for your
#@markdown credentials to mount your personal Google Drive. In this case, all
#@markdown changes you make will *persist*. Even if you re-run this
#@markdown Colab notebook later on, the files will still exist. You can
#@markdown remove directories inside your Drive's `flax/` root if you want to
#@markdown manually revert these files.

if 'google.colab' in str(get_ipython()):
  import os
  os.chdir('/content')
  # Download Flax repo from Github.
  if not os.path.isdir('flaxrepo'):
    !git clone --depth=1 -b $branch $repo flaxrepo
  # Copy example files & change directory.
  mount_gdrive = 'no' #@param ['yes', 'no']
  if mount_gdrive == 'yes':
    DISCLAIMER = 'Note : Editing in your Google Drive, changes will persist.'
    from google.colab import drive
    drive.mount('/content/gdrive')
    example_root_path = f'/content/gdrive/My Drive/flax/{example_directory}'
  else:
    DISCLAIMER = 'WARNING : Editing in VM - changes lost after reboot!!'
    example_root_path = f'/content/{example_directory}'
  from IPython import display
  display.display(display.HTML(
      f'<p>{DISCLAIMER}</p>'))
  if not os.path.isdir(example_root_path):
    os.makedirs(example_root_path)
    !cp -r flaxrepo/$example_directory/* "$example_root_path"
  os.chdir(example_root_path)
  from google.colab import files
  for relpath in editor_relpaths:
    s = open(f'{example_root_path}/{relpath}').read()
    open(f'{example_root_path}/{relpath}', 'w').write(
        f'## {DISCLAIMER}\n' + '#' * (len(DISCLAIMER) + 3) + '\n\n' + s)
    files.view(f'{example_root_path}/{relpath}')
```

--------------------------------

### Setup imports and dummy data

Source: https://github.com/google/flax/blob/main/docs_nnx/guides/jax_and_nnx_transforms.rst

Initializes the environment with necessary imports and random input data for the examples.

```python
from flax import nnx
import jax

x = jax.random.normal(jax.random.key(0), (1, 2))
y = jax.random.normal(jax.random.key(1), (1, 3))
```

--------------------------------

### Configure Example Directory and Repository

Source: https://github.com/google/flax/blob/main/examples/imagenet/imagenet.ipynb

Sets up variables for the example directory path and repository details. These are used for cloning the repository and managing example files.

```python
example_directory = 'examples/imagenet'
editor_relpaths = ('configs/default.py', 'input_pipeline.py', 'models.py', 'train.py')

repo, branch = 'https://github.com/google/flax', 'main'
```

--------------------------------

### Setup Python Environment and Install Jax on TPU VM

Source: https://github.com/google/flax/blob/main/examples/gemma/README.md

Commands to set up a Python 3.12 virtual environment using uv, install JAX with TPU support, and verify the installation on the TPU VM.
```bash
# Setup Python 3.12 env with UV
python -m pip install uv
uv venv --python 3.12 /tmp/venv
source /tmp/venv/bin/activate
uv pip install pip
# which python && python -VV && pip --version
pip install "jax[tpu]"
# Check whether TPUs are available:
```

--------------------------------

### Install Flax and dependencies

Source: https://github.com/google/flax/blob/main/examples/gemma/README.md

Clone the Flax repository and install required dependencies for the Gemma example.

```bash
git clone --depth=1 --branch=main https://github.com/google/flax
cd flax
pip install -e .
cd examples/gemma
pip install -r requirements.txt
```

--------------------------------

### Install Flax and Dependencies

Source: https://github.com/google/flax/blob/main/examples/wmt/README.md

Clones the Flax repository and installs required dependencies for the WMT example.

```bash
git clone --depth=1 --branch=main https://github.com/google/flax
cd flax
pip install -e .
cd examples/wmt
pip install -r requirements.txt
```

--------------------------------

### Initialize Flax Model and Dataset Setup

Source: https://github.com/google/flax/blob/main/docs_nnx/flip/1009-optimizer-api.md

Provides the necessary imports, model definition, and data loading utilities to run Flax migration examples.
```python
import functools
from typing import Callable, Sequence

import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import tensorflow as tf
import tensorflow_datasets as tfds

def pp(features):
  return {
      'image': tf.cast(features['image'], tf.float32) / 255 - 0.5,
      'label': features['label'],
  }

class Model(nn.Module):

  @nn.compact
  def __call__(self, inputs):
    x = inputs.reshape([inputs.shape[0], -1])
    x = nn.normalization.BatchNorm(True)(x)
    x = nn.Dense(10)(x)
    x = nn.log_softmax(x)
    return x

def onehot(labels, num_classes, on_value=1.0, off_value=0.0):
  x = (labels[..., None] == jnp.arange(num_classes)[None])
  x = jax.lax.select(
      x, jnp.full(x.shape, on_value), jnp.full(x.shape, off_value))
  return x.astype(jnp.float32)

def xent_loss(logits, labels):
  return -jnp.sum(
      onehot(labels, num_classes=10) * logits) / labels.size

def get_learning_rate(step):
  return 0.1

model = Model()
rng = jax.random.key(0)
ds = tfds.load('mnist')['train'].take(160).map(pp).batch(16)
batch = next(iter(ds))
variables = model.init(rng, jnp.array(batch['image'][:1]))
jax.tree_util.tree_map(jnp.shape, variables)
```

--------------------------------

### Install Flax and dependencies

Source: https://github.com/google/flax/blob/main/examples/imagenet/README.md

Clone the Flax repository and install the required dependencies for the ImageNet example.

```shell
git clone --depth=1 --branch=main https://github.com/google/flax
cd flax
pip install -e .
cd examples/imagenet
pip install -r requirements.txt
```

--------------------------------

### Colab Environment Setup for Flax Imagenet Example

Source: https://github.com/google/flax/blob/main/examples/imagenet/imagenet.ipynb

Configures the Colab environment by cloning the Flax repository, copying example files, and optionally mounting Google Drive for persistent storage. It also prepends a disclaimer about data persistence to each editable example file.
```python
# (If you run this code in Jupyter[lab], then you're already in the
# example directory and nothing needs to be done.)
# Expects `repo`, `branch`, `example_directory` and `editor_relpaths`
# to be defined in an earlier cell.

#@markdown **Fetch newest Flax, copy example code**
#@markdown
#@markdown **If you select no** below, then the files will be stored on the
#@markdown *ephemeral* Colab VM. **After some time of inactivity, this VM will
#@markdown be restarted and any changes are lost**.
#@markdown
#@markdown **If you select yes** below, then you will be asked for your
#@markdown credentials to mount your personal Google Drive. In this case, all
#@markdown changes you make will be *persisted*, and even if you re-run the
#@markdown Colab later on, the files will still be the same (you can of course
#@markdown remove directories inside your Drive's `flax/` root if you want to
#@markdown manually revert these files).

if 'google.colab' in str(get_ipython()):
  import os
  os.chdir('/content')
  # Download Flax repo from Github.
  if not os.path.isdir('flaxrepo'):
    !git clone --depth=1 -b $branch $repo flaxrepo
  # Copy example files & change directory.
  mount_gdrive = 'no' #@param ['yes', 'no']
  if mount_gdrive == 'yes':
    DISCLAIMER = 'Note : Editing in your Google Drive, changes will persist.'
    from google.colab import drive
    drive.mount('/content/gdrive')
    example_root_path = f'/content/gdrive/My Drive/flax/{example_directory}'
  else:
    DISCLAIMER = 'WARNING : Editing in VM - changes lost after reboot!!'
    example_root_path = f'/content/{example_directory}'
  from IPython import display
  display.display(display.HTML(
      f'<p>{DISCLAIMER}</p>'))
  if not os.path.isdir(example_root_path):
    os.makedirs(example_root_path)
    !cp -r flaxrepo/$example_directory/* "$example_root_path"
  os.chdir(example_root_path)
  from google.colab import files
  for relpath in editor_relpaths:
    s = open(f'{example_root_path}/{relpath}').read()
    open(f'{example_root_path}/{relpath}', 'w').write(
        f'## {DISCLAIMER}\n' + '#' * (len(DISCLAIMER) + 3) + '\n\n' + s)
    files.view(f'{example_root_path}/{relpath}')
```

--------------------------------

### Define Example Paths and Repository

Source: https://github.com/google/flax/blob/main/examples/seq2seq/seq2seq.ipynb

Sets the directory paths and repository information for the seq2seq example.

```python
example_directory = 'examples/seq2seq'
editor_relpaths = ('train.py', 'input_pipeline.py', 'models.py')

repo, branch = 'https://github.com/google/flax', 'main'
```

--------------------------------

### Fetch and Configure Example Files

Source: https://github.com/google/flax/blob/main/examples/mnist/mnist.ipynb

Downloads the Flax repository and sets up the environment to edit example files either on the ephemeral VM or a mounted Google Drive.

```python
# (If you run this code in Jupyter[lab], then you're already in the
# example directory and nothing needs to be done.)
# Expects `repo`, `branch`, `example_directory` and `editor_relpaths`
# to be defined in an earlier cell.

#@markdown **Fetch newest Flax, copy example code**
#@markdown
#@markdown **If you select no** below, then the files will be stored on the
#@markdown *ephemeral* Colab VM. **After some time of inactivity, this VM will
#@markdown be restarted and any changes are lost**.
#@markdown
#@markdown **If you select yes** below, then you will be asked for your
#@markdown credentials to mount your personal Google Drive. In this case, all
#@markdown changes you make will be *persisted*, and even if you re-run the
#@markdown Colab later on, the files will still be the same (you can of course
#@markdown remove directories inside your Drive's `flax/` root if you want to
#@markdown manually revert these files).

if 'google.colab' in str(get_ipython()):
  import os
  os.chdir('/content')
  # Download Flax repo from Github.
  if not os.path.isdir('flaxrepo'):
    !git clone --depth=1 -b $branch $repo flaxrepo
  # Copy example files & change directory.
  mount_gdrive = 'no' #@param ['yes', 'no']
  if mount_gdrive == 'yes':
    DISCLAIMER = 'Note : Editing in your Google Drive, changes will persist.'
    from google.colab import drive
    drive.mount('/content/gdrive')
    example_root_path = f'/content/gdrive/My Drive/flax/{example_directory}'
  else:
    DISCLAIMER = 'WARNING : Editing in VM - changes lost after reboot!!'
    example_root_path = f'/content/{example_directory}'
  from IPython import display
  display.display(display.HTML(
      f'<p>{DISCLAIMER}</p>'))
  if not os.path.isdir(example_root_path):
    os.makedirs(example_root_path)
    !cp -r flaxrepo/$example_directory/* "$example_root_path"
  os.chdir(example_root_path)
  from google.colab import files
  for relpath in editor_relpaths:
    s = open(f'{example_root_path}/{relpath}').read()
    open(f'{example_root_path}/{relpath}', 'w').write(
        f'## {DISCLAIMER}\n' + '#' * (len(DISCLAIMER) + 3) + '\n\n' + s)
    files.view(f'{example_root_path}/{relpath}')
```

--------------------------------

### Training Loop Setup

Source: https://github.com/google/flax/blob/main/examples/ogbg_molpcba/ogbg_molpcba.ipynb

This section marks the start of the training loop and recommends a Colab GPU runtime for faster training. It also notes that TPUs are not used here because training is not distributed with `pmap()`.

```python
# Training loop

# Use a Colab GPU runtime to speed up training.
# We don't use TPUs in this Colab because we do not distribute our
# training using pmap() - if you're looking for an example using TPUs
# check out the Colab notebook below:
```

--------------------------------

### Define example paths and repository variables

Source: https://github.com/google/flax/blob/main/examples/ogbg_molpcba/ogbg_molpcba.ipynb

Sets the directory paths and repository information for the OGBG-MolPCBA example.

```python
example_directory = 'examples/ogbg_molpcba'
editor_relpaths = ('configs/default.py', 'input_pipeline.py', 'models.py', 'train.py')

repo, branch = 'https://github.com/google/flax', 'main'
```

--------------------------------

### Install build dependencies

Source: https://github.com/google/flax/blob/main/flaxlib_src/README.md

Install the required build tools including meson-python, ninja, and build.
```shell
pip install meson-python ninja build
```

--------------------------------

### Start TensorBoard on TPU Worker

Source: https://github.com/google/flax/blob/main/examples/gemma/README.md

Installs xprof and launches a TensorBoard instance on a specific TPU worker, forwarding the port to localhost.

```bash
gcloud compute tpus tpu-vm ssh --zone $ZONE $TPU_NAME --worker=0 --command="pip install xprof"
gcloud compute tpus tpu-vm ssh --zone $ZONE $TPU_NAME --worker=0 --command="tensorboard --logdir=$out_dir --port 7007" -- -L 7007:localhost:7007
```

--------------------------------

### Define example directory and file paths

Source: https://github.com/google/flax/blob/main/examples/sst2/sst2.ipynb

Sets the base directory and relative paths for the SST-2 example files.

```python
example_directory = 'examples/sst2'
editor_relpaths = ('configs/default.py', 'train.py', 'models.py')
```

--------------------------------

### Setup JAX Devices and Mesh

Source: https://github.com/google/flax/blob/main/docs_nnx/guides/flax_gspmd.md

Initializes JAX with a specified number of CPU devices and creates a 2x4 device mesh for distributed computation. This setup is useful for testing or when actual accelerators are not available.

```python
from functools import partial

import jax
from jax import numpy as jnp
from jax.sharding import PartitionSpec as P, NamedSharding, AxisType
import optax
import flax
from flax import nnx

# Ignore this if you are already running on a TPU or GPU
if not jax._src.xla_bridge.backends_are_initialized():
  jax.config.update('jax_num_cpu_devices', 8)
print(f'You have 8 “fake” JAX devices now: {jax.devices()}')
```

```python
# Create an auto-mode mesh of two dimensions and annotate each axis with a name.
auto_mesh = jax.make_mesh((2, 4), ('data', 'model'))
```

--------------------------------

### Clone and install dependencies

Source: https://github.com/google/flax/blob/main/docs_nnx/contributing.md

Clones the repository and installs required packages for development, testing, and documentation.

```bash
git clone https://github.com/YOUR_USERNAME/flax
cd flax
pip install -e ".[all,testing,docs]"
```

```bash
uv sync --all-extras
```

--------------------------------

### Define Example Paths

Source: https://github.com/google/flax/blob/main/examples/mnist/mnist.ipynb

Sets the directory and file paths for the MNIST example and the repository source.

```python
example_directory = 'examples/mnist'
editor_relpaths = ('configs/default.py', 'train.py')

repo, branch = 'https://github.com/google/flax', 'main'
```

--------------------------------

### Run Flax Tracing and Lowering Benchmarks

Source: https://github.com/google/flax/blob/main/benchmarks/tracing/README.md

Commands to install required dependencies and execute performance profiling for Flax examples.

```bash
pip install -r benchmarks/tracing/requirements.txt

# Benchmark trace and lower timing for all workloads.
python tracing_benchmark.py

# Profile a single example.
python tracing_benchmark.py --example=wmt

# Profile just tracing for a single example.
python tracing_benchmark.py --example=wmt --mode=trace
```

--------------------------------

### TPU Setup and Device Check

Source: https://github.com/google/flax/blob/main/examples/imagenet/imagenet.ipynb

Provides instructions for setting up and using TPUs with Flax, including installing compatible versions of Flax, JAX, and JAXlib. It then checks and lists the available JAX devices.

```python
# It's possible to run this Colab with TPUs:
# 1. change runtime type to TPU
# 2. install compatible flax version: `!pip install flax==0.6.4 jax==0.3.25 jaxlib==0.3.25`
# 3. uncomment lines below
# import flax, jax, jax.tools.colab_tpu
# jax.tools.colab_tpu.setup_tpu()

import jax
jax.devices()
```

--------------------------------

### Install Dependencies

Source: https://github.com/google/flax/blob/main/docs_nnx/examples/gemma.md

Install the required libraries for Flax NNX and model management.

```bash
! pip install --no-deps -U flax
! pip install jaxtyping kagglehub treescope
```

--------------------------------

### Verify JAX installation

Source: https://github.com/google/flax/blob/main/examples/gemma/README.md

Check that JAX is correctly installed and can detect the available devices.

```bash
python3 -c "import jax; print(jax.devices())"
```

--------------------------------

### Start TensorBoard Monitoring

Source: https://github.com/google/flax/blob/main/examples/lm1b/README.md

Starts TensorBoard to monitor training logs, assuming logs are stored in the specified directory.

```bash
tensorboard --logdir=$HOME/logs
```

--------------------------------

### Run training and monitoring

Source: https://github.com/google/flax/blob/main/examples/gemma/README.md

Execute the training script and start TensorBoard for monitoring.

```bash
python3 main.py --workdir=$HOME/logs/gemma_lm1b_256 --config.per_device_batch_size=32
```

```bash
tensorboard --logdir=$HOME/logs
```

--------------------------------

### Install Flax from head

Source: https://github.com/google/flax/blob/main/tests/import_test.ipynb

Installs the latest development version of Flax directly from the GitHub repository.

```bash
!pip install git+https://github.com/google/flax.git
```

--------------------------------

### Run MNIST Example

Source: https://github.com/google/flax/blob/main/examples/mnist/README.md

Basic command to run the MNIST example. Specify the work directory and configuration file.
```shell
python main.py --workdir=/tmp/mnist --config=configs/default.py
```

--------------------------------

### Launch MNIST Example on GCE

Source: https://github.com/google/flax/blob/main/examples/cloud/README.md

Launches the MNIST training example on a Google Compute Engine VM. Ensure you have set the PROJECT and GCS_BUCKET environment variables.

```shell
python examples/cloud/launch_gce.py \
  --project=$PROJECT \
  --zone=us-west1-a \
  --machine_type=n2-standard-2 \
  --gcs_workdir_base=gs://$GCS_BUCKET/workdir_base \
  --repo=${REPO:-https://github.com/google/flax} \
  --branch=${BRANCH:-main} \
  --example=mnist \
  --args='--config=configs/default.py' \
  --name=default
```

--------------------------------

### Start TensorBoard for Training Monitoring

Source: https://github.com/google/flax/blob/main/examples/ogbg_molpcba/ogbg_molpcba.ipynb

This code snippet loads the TensorBoard extension and starts a TensorBoard server. It's useful for live monitoring of training progress, especially when using Colab.

```python
# Start TensorBoard
# Get a live update during training - use the "refresh" button!
# (In Jupyter[lab] start "tensorboard" in the local directory instead.)

if 'google.colab' in str(get_ipython()):
  %load_ext tensorboard
  %tensorboard --logdir=.
```

--------------------------------

### Install JAX on TPU VM

Source: https://github.com/google/flax/blob/main/examples/lm1b/README.md

Installs JAX with TPU support on the connected TPU VM, specifying a version compatible with TPU releases.

```bash
pip install "jax[tpu]>=0.2.16" \
    -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```

--------------------------------

### Start training on TPU

Source: https://github.com/google/flax/blob/main/examples/imagenet/README.md

Execute the training script with the appropriate configuration and backend target.
```shell
export TFDS_DATA_DIR=gs://$GCS_TFDS_BUCKET/datasets
python3 main.py --workdir=$HOME/logs/imagenet_tpu --config=configs/tpu.py \
    --jax_backend_target="grpc://192.168.0.2:8470"
```

--------------------------------

### Memory-Efficient Partial Initialization Setup

Source: https://github.com/google/flax/blob/main/docs_nnx/guides/surgery.md

Creates a stand-in for a pretrained model state, which the memory-efficient partial initialization example then loads.

```python
# Some pretrained model state
old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))
```

--------------------------------

### Flax NNX Training Step Execution Example

Source: https://github.com/google/flax/blob/main/docs_nnx/migrating/haiku_to_flax.rst

Example of how to execute a Flax NNX training step with a model instance, inputs, and labels. The model is updated in-place.

```python
sample_x = jnp.ones((1, 784))
train_step(model, sample_x, jnp.ones((1,), dtype=jnp.int32))
```

--------------------------------

### Launch ImageNet Example on GCE

Source: https://github.com/google/flax/blob/main/examples/cloud/README.md

Launches the ImageNet training example on a Google Compute Engine VM with specified hardware accelerators. Ensure PROJECT, GCS_BUCKET, and GCS_TFDS_BUCKET environment variables are set.

```shell
python examples/cloud/launch_gce.py \
  --project=$PROJECT \
  --zone=us-west1-a \
  --machine_type=n1-standard-96 \
  --accelerator_type=nvidia-tesla-v100 --accelerator_count=8 \
  --gcs_workdir_base=gs://$GCS_BUCKET/workdir_base \
  --tfds_data_dir=gs://$GCS_TFDS_BUCKET/datasets \
  --repo=${REPO:-https://github.com/google/flax} \
  --branch=${BRANCH:-main} \
  --example=imagenet \
  --args='--config=configs/v100_x8_mixed_precision.py' \
  --name=v100_x8_mixed_precision
```

--------------------------------

### Install Jax with CUDA and Flax Dependencies

Source: https://github.com/google/flax/blob/main/examples/gemma/README.md

Commands to install JAX with CUDA support, Flax, and example-specific requirements for GPU training. Includes a check for GPU availability.
```bash
# Quote the extra so shells such as zsh don't treat the brackets as a glob.
pip install "jax[cuda13]"
# Check whether GPUs are available:
# python3 -c "import jax; print(jax.devices())"

git clone --depth=1 --branch=main https://github.com/google/flax
cd flax
pip install -e .
cd examples/gemma
pip install -r requirements.txt
```

--------------------------------

### Configure the build

Source: https://github.com/google/flax/blob/main/flaxlib_src/README.md

Set up the build environment by installing dependencies via meson wrap and initializing the build directory.

```shell
mkdir -p subprojects
meson wrap install robin-map
meson wrap install nanobind
meson setup builddir
```

--------------------------------

### Setup dependencies and checkpoint directory

Source: https://github.com/google/flax/blob/main/docs_nnx/guides/checkpointing.md

Imports necessary libraries and initializes a temporary directory for checkpoint storage.

```python
from flax import nnx
import orbax.checkpoint as ocp
import jax
from jax import numpy as jnp
import numpy as np

ckpt_dir = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')
```

--------------------------------

### Define and Train a CNN with Flax NNX

Source: https://context7.com/google/flax/llms.txt

Defines a CNN, then initializes the model and an AdamW optimizer over its `nnx.Param` variables. The training step itself is not included. The `3136` input size of `linear1` assumes 28x28 single-channel images such as MNIST: two 2x2 average pools reduce 28x28 to 7x7, and 7 * 7 * 64 channels = 3136.
```python
from flax import nnx
import jax
import jax.numpy as jnp
import optax

# Define a CNN model
class CNN(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)
    self.dropout = nnx.Dropout(0.5, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.conv1(x))
    x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nnx.relu(self.conv2(x))
    x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape(x.shape[0], -1)  # Flatten
    x = nnx.relu(self.linear1(x))
    x = self.dropout(x)
    x = self.linear2(x)
    return x

# Initialize model and optimizer
model = CNN(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=nnx.Param)
```

--------------------------------

### Initialize JAX Environment for Multi-Device Simulation

Source: https://github.com/google/flax/blob/main/docs_nnx/guides/flax_gspmd.ipynb

Sets up a multi-device environment using fake CPU devices for testing purposes.

```python
from functools import partial

import jax
from jax import numpy as jnp
from jax.sharding import PartitionSpec as P, NamedSharding, AxisType
import optax
import flax
from flax import nnx

# Ignore this if you are already running on a TPU or GPU
if not jax._src.xla_bridge.backends_are_initialized():
  jax.config.update('jax_num_cpu_devices', 8)
print(f'You have 8 “fake” JAX devices now: {jax.devices()}')
```

--------------------------------

### Initialize Model with Partitioned Weights

Source: https://github.com/google/flax/blob/main/docs_nnx/flip/2434-general-metadata.md

Demonstrates initializing a model with partitioned weights and inspecting the resulting variable structure, including the Partitioned metadata.
```python
variables = partitioned_dense.init(rng, jnp.ones((4,)))
jax.tree.map(np.shape, variables)
# => {"params": {"kernel": Partitioned(value=(4, 8), names=(None, "data")), "bias": (8,)}}
```

--------------------------------

### Setup imports and environment

Source: https://github.com/google/flax/blob/main/docs_nnx/guides/bridge_guide.ipynb

Initial configuration and imports required for using Flax NNX and Linen together.

```python
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

from flax import nnx
from flax import linen as nn
from flax.nnx import bridge
import jax
from jax import numpy as jnp
from jax.experimental import mesh_utils
from typing import *
```

--------------------------------

### Install GCSFuse

Source: https://github.com/google/flax/blob/main/examples/gemma/README.md

Configures the GCSFuse repository and installs the necessary packages.

```bash
export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s`
echo "deb [signed-by=/usr/share/keyrings/cloud.google.asc] https://packages.cloud.google.com/apt $GCSFUSE_REPO main" | tee /etc/apt/sources.list.d/gcsfuse.list
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | tee /usr/share/keyrings/cloud.google.asc

apt-get update
apt-get install -y fuse gcsfuse --no-install-recommends
```

--------------------------------

### Install SST-2 dependencies

Source: https://github.com/google/flax/blob/main/examples/sst2/sst2.ipynb

Installs required packages from the requirements.txt file.

```bash
# Install SST-2 dependencies.
!pip install -q -r requirements.txt
```

--------------------------------

### Configure multi-host TPU environment

Source: https://github.com/google/flax/blob/main/examples/gemma/README.md

Set environment variables and create a startup script for multi-host TPU training.
```bash
export ZONE=us-east5-a
export ACCELERATOR_TYPE=v5p-32
export RUNTIME_VERSION=v2-alpha-tpuv5
export TPU_NAME=flax-gemma-lm1b-${ACCELERATOR_TYPE}
export GCS_OUTPUT_BUCKET=flax-gemma-example-training
export GCS_DATA_BUCKET=flax-lm1b-arrayrecords

cat << EOF > /tmp/example_startup.sh
#!/bin/bash
set -xeu

python -m pip install uv
uv venv --python 3.12 /tmp/venv
source /tmp/venv/bin/activate
uv pip install pip
echo "source /tmp/venv/bin/activate" > /root/.bashrc

# Install JAX, FLAX and other deps
python -m pip install "jax[tpu]"
python -m pip install \
    "absl-py~=2.2" \
    "clu==0.0.12" \
    "mlcroissant~=1.0" \
    "numpy~=2.3" \
    "optax~=0.2" \
    "sentencepiece~=0.2" \
    "jaxtyping~=0.3" \
    "tensorflow-cpu~=2.20" \
    "tensorboard~=2.20" \
    "tensorflow-datasets~=4.9" \
    "grain~=0.2" \
    "orbax-checkpoint[gcp]~=0.11" \
    "google-cloud-storage"

cd /root
git clone --depth=1 --branch=main https://github.com/google/flax
cd flax
python -m pip install -e .

# Install gcsfuse
```

--------------------------------

### Install CLU and Flax

Source: https://github.com/google/flax/blob/main/examples/seq2seq/seq2seq.ipynb

Installs the required CLU and Flax libraries in the environment.

```python
# Install CLU & Flax.
!pip install -q clu flax
```

--------------------------------

### Install Flax dependencies

Source: https://github.com/google/flax/blob/main/docs_nnx/nnx_basics.md

Use pip to install the latest version of Flax.

```bash
# ! pip install -U flax
```

--------------------------------

### Install Dependencies

Source: https://github.com/google/flax/blob/main/examples/mnist/mnist.ipynb

Installs the ml-collections library and the latest version of Flax from the GitHub repository.

```python
# Install ml-collections & latest Flax version from Github.
!pip install -q ml-collections git+https://github.com/google/flax
```

--------------------------------

### Run training with default configuration

Source: https://github.com/google/flax/blob/main/examples/ogbg_molpcba/README.md

Execute the training script using the default configuration file.

```shell
python main.py --workdir=./ogbg_molpcba --config=configs/default.py
```

--------------------------------

### Prepare Output Directories

Source: https://github.com/google/flax/blob/main/examples/gemma/README.md

Creates output directories on the TPU VM using worker 0.

```bash
export out_name=gemma3-1b_lm1b_run-$ACCELERATOR_TYPE-$(date -u +%Y%m%d-%H%M)
export out_dir=/root/logs/$out_name
export chpt_bucket=gs://$GCS_OUTPUT_BUCKET/$out_name/checkpoint

gcloud compute tpus tpu-vm ssh --zone $ZONE $TPU_NAME --worker=0 --command="export out_dir=$out_dir && mkdir -p $out_dir"
gcloud compute tpus tpu-vm ssh --zone $ZONE $TPU_NAME --worker=0 --command="ls /root/logs"
```

--------------------------------

### Install dependencies

Source: https://github.com/google/flax/blob/main/examples/ogbg_molpcba/ogbg_molpcba.ipynb

Installs required libraries including clu, ml-collections, Flax, tensorflow_datasets, and jraph.

```python
# Install clu, ml-collections, latest Flax version, and tensorflow_datasets.
!pip install -U -q clu ml-collections git+https://github.com/google/flax tfds_nightly jraph
```

--------------------------------

### Create a virtual environment

Source: https://github.com/google/flax/blob/main/docs_nnx/contributing.md

Initializes a virtual environment to isolate project dependencies.

```bash
python3 -m virtualenv env
. env/bin/activate
```

--------------------------------

### Create TPU VM

Source: https://github.com/google/flax/blob/main/examples/gemma/README.md

Provisions a new TPU VM with a specified startup script.
```bash
gcloud compute tpus tpu-vm create $TPU_NAME --spot \
  --zone $ZONE \
  --accelerator-type=$ACCELERATOR_TYPE \
  --version=$RUNTIME_VERSION \
  --metadata-from-file=startup-script=/tmp/example_startup.sh
```

--------------------------------

### Install Flax

Source: https://github.com/google/flax/blob/main/docs_nnx/mnist_tutorial.md

Use pip to install Flax and JAX with CUDA support in a Python environment.

```python
# !pip install -U "jax[cuda12]"
# !pip install -U flax
```

--------------------------------

### Initialize Imports and Environment

Source: https://github.com/google/flax/blob/main/docs_nnx/guides/surgery.ipynb

Standard imports and random key initialization required for Flax NNX examples.

```python
from typing import *
from pprint import pprint
import functools

import jax
from jax import lax, numpy as jnp, tree_util as jtu

from jax.sharding import PartitionSpec, Mesh, NamedSharding
from jax.experimental import mesh_utils
import flax
from flax import nnx
import flax.traverse_util
import numpy as np
import orbax.checkpoint as orbax

key = jax.random.key(0)
```

--------------------------------

### Install gcsfuse on TPU VM

Source: https://github.com/google/flax/blob/main/examples/gemma/README.md

Configure and install gcsfuse to mount Google Cloud Storage buckets.

```bash
# Install gcsfuse
export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s`
echo "deb [signed-by=/usr/share/keyrings/cloud.google.asc] https://packages.cloud.google.com/apt $GCSFUSE_REPO main" | tee /etc/apt/sources.list.d/gcsfuse.list
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | tee /usr/share/keyrings/cloud.google.asc

apt-get update
apt-get install -y fuse gcsfuse --no-install-recommends
```

--------------------------------

### Run PPO with custom configuration

Source: https://github.com/google/flax/blob/main/examples/ppo/README.md

Execute the PPO training script with a specified configuration file and custom working directory.
```bash
python ppo_main.py --config=configs/default.py --workdir=/my_fav_directory
```

```bash
python ppo_main.py --config=configs/default.py \
  --config.game=Seaquest \
  --config.total_frames=20000000 \
  --config.decaying_lr_and_clip_param=False \
  --workdir=/tmp/seaquest
```

--------------------------------

### Install Flax

Source: https://github.com/google/flax/blob/main/README.md

Use this command to install the Flax library from PyPI. Ensure you have Python 3.8 or later.

```bash
pip install flax
```

--------------------------------

### Initialize Toy Model and Loss Function

Source: https://github.com/google/flax/blob/main/docs_nnx/guides/optimization_cookbook.md

Sets up a two-layer `nnx.Sequential` model (2 → 8 → 8 features) and a mean squared error loss for the training examples in the optimization cookbook. It also configures JAX to expose 8 CPU devices.

```python
import jax
from flax import nnx
# Expose 8 virtual CPU devices; this must be set before JAX initializes its backends.
jax.config.update('jax_num_cpu_devices', 8)
import jax.numpy as jnp
import functools as ft
import optax
import matplotlib.pyplot as plt

lecun_normal = jax.nn.initializers.lecun_normal()
rngs = nnx.Rngs(0)

def make_model(rngs, init=lecun_normal):
  # Two stacked linear layers: 2 -> 8 -> 8 features.
  return nnx.Sequential(
    nnx.Linear(2, 8, rngs=rngs, kernel_init=init),
    nnx.Linear(8, 8, rngs=rngs, kernel_init=init))

def loss_fn(model, x, y):
  # Mean squared error between predictions and targets.
  return jnp.mean((model(x) - y) ** 2)
```

--------------------------------

### Install Flax with All Dependencies

Source: https://github.com/google/flax/blob/main/README.md

Install Flax together with its optional extra dependencies, such as matplotlib.

```bash
pip install "flax[all]"
```

--------------------------------

### Naive Partial Initialization Example

Source: https://github.com/google/flax/blob/main/docs_nnx/guides/surgery.ipynb

This snippet shows a naive approach to partial initialization: it creates an abstract model with `nnx.eval_shape`, replaces `linear1` with a freshly initialized `nnx.LoRALinear`, and then overwrites the kernel and bias with pretrained values. Printing the number of live JAX arrays at each step exposes the memory overhead of intermediate parameters that are created only to be discarded. `TwoLayerMLP` is defined elsewhere in the guide and is not shown here.
```python
# Some pretrained model state
old_state = nnx.state(TwoLayerMLP(4, rngs=nnx.Rngs(0)))

simple_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(42)))
print(f'Number of jax arrays in memory at start: {len(jax.live_arrays())}')
# In this line, an extra kernel and bias are created inside the new LoRALinear!
# They are wasted, because you are going to use the kernel and bias in `old_state` anyway.
simple_model.linear1 = nnx.LoRALinear(4, 4, lora_rank=3, rngs=nnx.Rngs(42))
print(f'Number of jax arrays in memory midway: {len(jax.live_arrays())}'
      ' (4 new created in LoRALinear - kernel, bias, lora_a & lora_b)')
nnx.update(simple_model, old_state)
print(f'Number of jax arrays in memory at end: {len(jax.live_arrays())}'
      ' (2 discarded - only lora_a & lora_b are used in model)')
```

--------------------------------

### Bidirectional RNN Usage Example

Source: https://github.com/google/flax/blob/main/docs_nnx/flip/2396-rnn.md

Example from the RNN design proposal (FLIP) showing how to combine a forward LSTM and a backward GRU with the `Bidirectional` combinator. The API shown is from the proposal and may differ from the current `flax.linen` API.

```python
forward_rnn = nn.RNN(nn.LSTMCell(), cell_size=32)
backward_rnn = nn.RNN(nn.GRUCell(), cell_size=32)
# Bidirectional combinator.
bi_rnn = nn.Bidirectional(forward_rnn, backward_rnn)
# Encodes a batch of input sequences in both directions.
carry, outputs = bi_rnn(inputs, seq_lengths)
```

--------------------------------

### Haiku Training Step Execution Example

Source: https://github.com/google/flax/blob/main/docs_nnx/migrating/haiku_to_flax.rst

Example of how to execute a Haiku training step with a PRNG key, parameters, inputs, and labels.

```python
train_step(jax.random.key(0), params, sample_x, jnp.ones((1,), dtype=jnp.int32))
```

--------------------------------

### Install Jupytext

Source: https://github.com/google/flax/blob/main/docs_nnx/contributing.md

Install the specific version of jupytext required for syncing notebooks. The version must match the one pinned in `.pre-commit-config.yaml`.
```bash
pip install jupytext==1.13.8
```

--------------------------------

### Start Training Job

Source: https://github.com/google/flax/blob/main/examples/wmt/README.md

Runs the training script with a given working directory and per-device batch size, connecting to a remote JAX backend through `--jax_backend_target`.

```bash
python3 main.py --workdir=$HOME/logs/wmt_256 \
    --config.per_device_batch_size=32 \
    --jax_backend_target="grpc://192.168.0.2:8470"
```
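The `--config.per_device_batch_size` flag sets the batch size for each accelerator device, not for the whole job. The following is a minimal sketch of the resulting arithmetic. It assumes, as the flag name suggests, that every device processes its own batch of that size, so the effective global batch grows with the device count. The helper `global_batch_size` is hypothetical and not part of the WMT example.

```python
import jax


def global_batch_size(per_device_batch_size: int, num_devices: int) -> int:
    """Hypothetical helper: examples consumed per training step across all devices.

    Assumes pure data parallelism, i.e. each device processes its own batch of
    `per_device_batch_size` examples, so the total scales linearly.
    """
    if per_device_batch_size <= 0 or num_devices <= 0:
        raise ValueError("batch size and device count must be positive")
    return per_device_batch_size * num_devices


# Value passed via --config.per_device_batch_size in the command above.
per_device = 32
# jax.device_count() returns the total number of devices across all JAX
# processes (on a CPU-only machine this is typically 1).
n_devices = jax.device_count()
print(f"devices: {n_devices}, global batch: {global_batch_size(per_device, n_devices)}")
```

For example, 32 examples on each of 8 devices gives a global batch of 256. That would be consistent with the `wmt_256` suffix of the working directory above, although the snippet does not state the device count.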