### Setup Brax and Dependencies

Source: https://github.com/google/brax/blob/main/notebooks/training.ipynb

Installs Brax and the libraries the notebook needs. A GPU runtime is recommended for better performance. The cell also sets up the TPU when one is available.

```python
#@markdown ## ⚠️ PLEASE NOTE:
#@markdown This colab runs best using a GPU runtime. From the Colab menu, choose Runtime > Change Runtime Type, then select **'GPU'** in the dropdown.

import functools
import jax
import os

from datetime import datetime
from jax import numpy as jp
import matplotlib.pyplot as plt

from IPython.display import HTML, clear_output

try:
  import brax
except (ImportError, ModuleNotFoundError):
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

import flax
from brax import envs
from brax.io import model
from brax.io import json
from brax.io import html
from brax.training.agents.ppo import train as ppo
from brax.training.agents.sac import train as sac

if 'COLAB_TPU_ADDR' in os.environ:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()
```

--------------------------------

### Install Brax from Source

Source: https://github.com/google/brax/blob/main/README.md

Clone the repository, navigate to the directory, set up a virtual environment, and install Brax in editable mode.

```bash
git clone https://github.com/google/brax.git
cd brax
python3 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install -e .
```

--------------------------------

### Install MuJoCo, MJX, and Brax

Source: https://github.com/google/brax/blob/main/brax/experimental/barkour/tutorial.ipynb

Installs MuJoCo, MJX, and Brax. Run this cell first to set up your environment.

```python
!pip install mujoco
!pip install mujoco_mjx
!pip install brax
```

--------------------------------

### Install Brax from PyPI

Source: https://github.com/google/brax/blob/main/README.md

Use these commands to create a virtual environment, activate it, upgrade pip, and install Brax from PyPI.
```bash
python3 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install brax
```

--------------------------------

### Setup and Imports for Brax

Source: https://github.com/google/brax/blob/main/notebooks/basics.ipynb

Installs Brax if it is not already available and imports the JAX and Matplotlib modules used for physics simulation and visualization.

```python
#@title Colab setup and imports

import jax
from jax import numpy as jp

from matplotlib.lines import Line2D
from matplotlib.patches import Circle
import matplotlib.pyplot as plt

try:
  import brax
except (ImportError, ModuleNotFoundError):
  from IPython.display import clear_output
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax
```

--------------------------------

### Install Brax

Source: https://context7.com/google/brax/llms.txt

Install Brax with pip from PyPI or with conda from conda-forge. Alternatively, clone the repository and install from source.

```bash
# From PyPI
pip install brax

# From conda-forge
conda install -c conda-forge brax

# From source
git clone https://github.com/google/brax && cd brax
pip install -e .
```

--------------------------------

### Train a Model with Brax

Source: https://github.com/google/brax/blob/main/README.md

Run the `learn` command to start training a model with Brax. If you train on an NVIDIA GPU, make sure the necessary GPU support is installed first.

```bash
learn
```

--------------------------------

### ES Training Setup

Source: https://context7.com/google/brax/llms.txt

Configures and runs Evolution Strategies (ES) training. This method optimizes a policy population using gradients estimated from perturbed rollouts and supports canonical and OpenAI ES variants. The example sets up the environment and specifies the training parameters.
```python
from brax import envs
from brax.training.agents.es import train as es

env = envs.get_environment('swimmer', backend='spring')

make_inference_fn, params, metrics = es.train(
    environment=env,
    num_timesteps=5_000_000,
    episode_length=1000,
    action_repeat=1,
    population_size=64,      # number of perturbed policies per update
    learning_rate=0.01,
    l2coeff=0.005,
    perturbation_std=0.01,   # scale of the parameter noise
    normalize_observations=True,
    center_fitness=True,
)
```

--------------------------------

### Check MuJoCo Installation and GPU Setup

Source: https://github.com/google/brax/blob/main/brax/experimental/barkour/tutorial.ipynb

Verifies the MuJoCo installation and configures the environment for GPU rendering using EGL. This is crucial for performance and requires a GPU runtime.

```python
#@title Check if MuJoCo installation was successful

from google.colab import files
import os
import subprocess

if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

try:
  print('Checking that the installation succeeded:')
  import mujoco
  # Parse a minimal empty model to confirm the library loads.
  mujoco.MjModel.from_xml_string('<mujoco/>')
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".')

print('Installation successful.')
```

--------------------------------

### PPO Training Setup and Execution

Source: https://context7.com/google/brax/llms.txt

Sets up and trains a Proximal Policy Optimization (PPO) agent. The example creates the environment, defines a progress callback for monitoring training, and then uses the trained policy for inference.
```python
import functools
import jax
from brax import envs
from brax.training.agents.ppo import train as ppo

env = envs.get_environment('ant', backend='mjx')

def progress(step, metrics):
    print(f'step={step:,} reward={metrics["eval/episode_reward"]:.2f}')

make_inference_fn, params, metrics = ppo.train(
    environment=env,
    num_timesteps=50_000_000,
    num_envs=2048,
    episode_length=1000,
    learning_rate=3e-4,
    entropy_cost=1e-2,
    discounting=0.97,
    unroll_length=5,
    batch_size=512,
    num_minibatches=32,
    num_updates_per_batch=4,
    normalize_observations=True,
    reward_scaling=10.0,
    clipping_epsilon=0.3,
    gae_lambda=0.95,
    action_repeat=1,
    progress_fn=progress,
    # Optional: checkpoint_logdir='/tmp/ppo_ant'
)

# Use the trained policy for inference
inference_fn = make_inference_fn(params, deterministic=True)
rng = jax.random.PRNGKey(0)
obs = jax.numpy.zeros(env.observation_size)
action, _ = inference_fn(obs, rng)
print('action:', action)
```

--------------------------------

### Install Brax and Dependencies

Source: https://github.com/google/brax/blob/main/notebooks/training_torch.ipynb

Upgrades OpenAI Gym, a dependency of the PyTorch training notebook. This is typically run at the beginning of a Colab notebook.

```python
!pip install --upgrade gym --quiet
```

--------------------------------

### SAC Training Setup and Execution

Source: https://context7.com/google/brax/llms.txt

Configures and initiates training for a Soft Actor-Critic (SAC) agent. This off-policy algorithm uses a replay buffer and is suitable for continuous-action tasks. It requires environment setup and a progress callback.
```python
import jax
from brax import envs
from brax.training.agents.sac import train as sac

env = envs.get_environment('halfcheetah', backend='generalized')

def progress(step, metrics):
    print(f'step={step:,} reward={metrics["eval/episode_reward"]:.2f}')

make_inference_fn, params, metrics = sac.train(
    environment=env,
    num_timesteps=1_000_000,
    episode_length=1000,
    num_envs=128,
    num_eval_envs=128,
    learning_rate=6e-4,
    discounting=0.99,
    batch_size=256,
    tau=0.005,
    reward_scaling=30.0,
    normalize_observations=False,
    min_replay_size=8192,
    grad_updates_per_step=1,
    deterministic_eval=True,
    seed=0,
    progress_fn=progress,
)

inference_fn = make_inference_fn(params, deterministic=True)
obs = jax.numpy.zeros(env.observation_size)
action, _ = inference_fn(obs, jax.random.PRNGKey(1))
```

--------------------------------

### Environment Creation and Usage

Source: https://context7.com/google/brax/llms.txt

Demonstrates how to get and create Brax environments, including options for wrappers, batching, and backends. It also shows basic usage like resetting and stepping the environment, and accessing environment properties.

```APIDOC
## Environment API (`brax.envs`)

Brax ships 12 built-in environments. All environments implement the `Env` interface with `reset(rng)` → `State` and `step(state, action)` → `State`. The `create` factory adds episode management, batching, and auto-reset wrappers.
```python
import jax
from brax import envs

# Available envs: 'ant', 'humanoid', 'hopper', 'halfcheetah', 'walker2d',
#   'swimmer', 'reacher', 'pusher', 'inverted_pendulum',
#   'inverted_double_pendulum', 'humanoidstandup', 'fast'

# Simple creation (no wrappers)
env = envs.get_environment('ant', backend='mjx')

# Full creation with wrappers (for training)
env = envs.create(
    env_name='ant',
    episode_length=1000,
    action_repeat=1,
    auto_reset=True,
    batch_size=None,        # set to int for batched env
    backend='generalized',  # 'mjx' | 'generalized' | 'positional' | 'spring'
)

rng = jax.random.PRNGKey(0)
state = jax.jit(env.reset)(rng)
print(state.obs.shape)   # (observation_size,)
print(state.reward)      # scalar
print(state.done)        # scalar bool
print(state.metrics)     # dict of per-step diagnostic scalars
print(env.observation_size)
print(env.action_size)

# Step the environment
action = jax.random.uniform(rng, (env.action_size,), minval=-1, maxval=1)
state = jax.jit(env.step)(state, action)

# Register a custom environment
from brax.envs.base import PipelineEnv, State

class MyEnv(PipelineEnv):
    def reset(self, rng): ...
    def step(self, state, action): ...

envs.register_environment('my_env', MyEnv)
my_env = envs.get_environment('my_env')
```
```

--------------------------------

### Domain Randomization Setup in Brax

Source: https://context7.com/google/brax/llms.txt

Demonstrates how to set up domain randomization using `DomainRandomizationVmapWrapper` or `training.wrap` with a custom randomization function. This enables training policies that are robust to variations in physical parameters.
```python
import functools

import jax
import jax.numpy as jp
from brax import envs
from brax.envs.wrappers.training import DomainRandomizationVmapWrapper, wrap

base_env = envs.get_environment('ant', backend='generalized')

def randomize(sys, rng):
    """Return (batched_sys, in_axes) for 64-env randomization."""
    n = 64
    rngs = jax.random.split(rng, n)
    # Randomize link masses ±20%
    mass_scales = jax.vmap(
        lambda r: jax.random.uniform(r, shape=(sys.num_links(),), minval=0.8, maxval=1.2)
    )(rngs)
    sys_v = jax.vmap(lambda s: sys.tree_replace({'link.inertia.mass': sys.link.inertia.mass * s}))(mass_scales)
    # in_axes: 0 where batched, None where not batched
    in_axes = jax.tree_util.tree_map(lambda _: 0, sys_v)
    return sys_v, in_axes

# The wrapper calls the randomization function with `sys` only, so bind the key first.
rand_env = DomainRandomizationVmapWrapper(
    base_env, functools.partial(randomize, rng=jax.random.PRNGKey(0)))

# Or pass the randomization function to the trainer, which applies it when wrapping the env:
from brax.training.agents.ppo import train as ppo
make_inference_fn, params, _ = ppo.train(
    environment=base_env,
    num_timesteps=10_000_000,
    num_envs=64,
    episode_length=1000,
    randomization_fn=randomize,
)
```

--------------------------------

### Initialize Environment for Visualization

Source: https://github.com/google/brax/blob/main/brax/experimental/barkour/tutorial.ipynb

Initializes the environment for policy visualization. It gets the environment instance and creates JIT-compiled versions of the reset and step functions for efficiency.

```python
eval_env = envs.get_environment(env_name)

jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)
```

--------------------------------

### Install Brax from Conda/Mamba

Source: https://github.com/google/brax/blob/main/README.md

Install Brax using Conda or Mamba by specifying the conda-forge channel.
```bash
conda install -c conda-forge brax  # s/conda/mamba for mamba
```

--------------------------------

### Execute Training and Get Inference Function

Source: https://github.com/google/brax/blob/main/notebooks/training.ipynb

Runs training with the configured training function and environment. The call returns a factory for building the inference function together with the trained parameters.

```python
make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
```

--------------------------------

### Get Barkour Environment

Source: https://github.com/google/brax/blob/main/brax/experimental/barkour/tutorial.ipynb

Retrieves an instance of the Barkour environment using its registered name. This is the standard way to initialize an environment for use.

```python
env_name = 'barkour'
env = envs.get_environment(env_name)
```

--------------------------------

### Load Brax System from MJCF

Source: https://github.com/google/brax/blob/main/notebooks/basics.ipynb

Loads a Brax physics system defined in MJCF format. This example creates a simple scene with a sphere and a plane.

```python
from brax.io import mjcf

# NOTE: the MJCF markup for the scene (a sphere and a plane) is missing from
# the source; paste the XML between the triple quotes.
ball = mjcf.loads(
    """
    """
)
```

--------------------------------

### PipelineEnv - Building Custom Environments

Source: https://context7.com/google/brax/llms.txt

Provides a base class `PipelineEnv` for creating custom physics-based environments. It outlines the structure for implementing `reset` and `step` methods, and includes an example of a `BallBalance` environment.

```APIDOC
## `PipelineEnv` - Building Custom Environments

`PipelineEnv` is the base class for all built-in environments. Subclass it to build custom physics-based environments with selectable backends.
```python
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
import jax
import jax.numpy as jp

class BallBalance(PipelineEnv):
    def __init__(self, backend='mjx', **kwargs):
        # NOTE: the MJCF markup is missing from the source; supply the model XML here.
        xml = """
        """
        sys = mjcf.loads(xml)
        super().__init__(sys, backend=backend, n_frames=4, **kwargs)

    def reset(self, rng: jax.Array) -> State:
        q = self.sys.init_q + jax.random.uniform(rng, (self.sys.q_size(),), minval=-0.01, maxval=0.01)
        qd = jp.zeros(self.sys.qd_size())
        pipeline_state = self.pipeline_init(q, qd)
        obs = self._get_obs(pipeline_state)
        return State(pipeline_state=pipeline_state, obs=obs,
                     reward=jp.float32(0), done=jp.float32(0))

    def step(self, state: State, action: jax.Array) -> State:
        pipeline_state = self.pipeline_step(state.pipeline_state, action)
        obs = self._get_obs(pipeline_state)
        reward = -jp.sum(jp.square(pipeline_state.x.pos[0, :2]))  # penalize XY offset
        done = jp.float32(jp.abs(pipeline_state.x.pos[0, 2]) < 0.05)  # fell over
        return state.replace(pipeline_state=pipeline_state, obs=obs, reward=reward, done=done)

    def _get_obs(self, ps):
        return jp.concatenate([ps.q, ps.qd])

env = BallBalance(backend='mjx')
state = jax.jit(env.reset)(jax.random.PRNGKey(42))
print(state.obs.shape, env.observation_size, env.action_size)
```
```

--------------------------------

### Import Brax and Helper Modules

Source: https://github.com/google/brax/blob/main/notebooks/training_torch.ipynb

Imports core Brax libraries, environment wrappers, training agents (PPO), and common Python modules for data manipulation, visualization, and system operations. It includes a fallback to install Brax from source if it's not found.
```python
#@title Import Brax and some helper modules
from IPython.display import clear_output

import collections
from datetime import datetime
import functools
import math
import os
import time
from typing import Any, Callable, Dict, Optional, Sequence

try:
  import brax
except (ImportError, ModuleNotFoundError):
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

from brax import envs
from brax.envs.wrappers import gym as gym_wrapper
from brax.envs.wrappers import torch as torch_wrapper
from brax.io import metrics
from brax.training.agents.ppo import train as ppo
import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
```

--------------------------------

### ARS Training Setup

Source: https://context7.com/google/brax/llms.txt

Initializes and runs Augmented Random Search (ARS) training. This derivative-free method is effective for locomotion tasks and does not require neural network gradients. The example sets up the environment and the training parameters.

```python
from brax import envs
from brax.training.agents.ars import train as ars

env = envs.get_environment('hopper', backend='positional')

# ARS is configured through its search parameters (number_of_directions,
# step_size) rather than a learning rate or environment count.
make_inference_fn, params, metrics = ars.train(
    environment=env,
    num_timesteps=10_000_000,
    episode_length=1000,
    action_repeat=1,
    normalize_observations=True,
    number_of_directions=60,
    step_size=0.03,
)
```

--------------------------------

### Brax Trainer Configuration and Execution

Source: https://github.com/google/brax/blob/main/notebooks/training_torch.ipynb

This code configures and runs the built-in Brax trainer for an 'ant' environment using the 'spring' backend. It's designed for efficient training and includes parameters for environment setup, training duration, and optimization. Note the use of `progress_fn` for monitoring.
```python
train_sps = []

def progress(_, metrics):
  if 'training/sps' in metrics:
    train_sps.append(metrics['training/sps'])

ppo.train(
    environment=envs.create(env_name='ant', backend='spring'),
    num_timesteps = 30_000_000, num_evals = 10, reward_scaling = .1,
    episode_length = 1000, normalize_observations = True, action_repeat = 1,
    unroll_length = 5, num_minibatches = 32, num_updates_per_batch = 4,
    discounting = 0.97, learning_rate = 3e-4, entropy_cost = 1e-2,
    num_envs = 2048, batch_size = 1024, progress_fn = progress)

print(f'train steps/sec: {np.mean(train_sps[1:])}')
```

--------------------------------

### Defining a Custom Brax Environment with `PipelineEnv`

Source: https://context7.com/google/brax/llms.txt

Implement a custom physics environment by subclassing `PipelineEnv`. Define the environment's XML, reset logic, step logic, and observation calculation. This example shows a 'BallBalance' environment.

```python
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
import jax
import jax.numpy as jp

class BallBalance(PipelineEnv):
    def __init__(self, backend='mjx', **kwargs):
        # NOTE: the MJCF markup is missing from the source; supply the model XML here.
        xml = """
        """
        sys = mjcf.loads(xml)
        super().__init__(sys, backend=backend, n_frames=4, **kwargs)

    def reset(self, rng: jax.Array) -> State:
        q = self.sys.init_q + jax.random.uniform(rng, (self.sys.q_size(),), minval=-0.01, maxval=0.01)
        qd = jp.zeros(self.sys.qd_size())
        pipeline_state = self.pipeline_init(q, qd)
        obs = self._get_obs(pipeline_state)
        return State(pipeline_state=pipeline_state, obs=obs,
                     reward=jp.float32(0), done=jp.float32(0))

    def step(self, state: State, action: jax.Array) -> State:
        pipeline_state = self.pipeline_step(state.pipeline_state, action)
        obs = self._get_obs(pipeline_state)
        reward = -jp.sum(jp.square(pipeline_state.x.pos[0, :2]))  # penalize XY offset
        done = jp.float32(jp.abs(pipeline_state.x.pos[0, 2]) < 0.05)  # fell over
        return state.replace(pipeline_state=pipeline_state, obs=obs, reward=reward, done=done)

    def _get_obs(self, ps):
        return jp.concatenate([ps.q, ps.qd])

env = BallBalance(backend='mjx')
state = jax.jit(env.reset)(jax.random.PRNGKey(42))
print(state.obs.shape, env.observation_size, env.action_size)
```

--------------------------------

### Create and Train Environment

Source: https://github.com/google/brax/blob/main/notebooks/training.ipynb

Creates an environment, JIT-compiles the reset, step, and inference functions, and runs a 1000-step rollout with the trained policy. Displays the rendered rollout.

```python
env = envs.create(env_name=env_name, backend=backend)

jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.jit(inference_fn)

rollout = []
rng = jax.random.PRNGKey(seed=1)
state = jit_env_reset(rng=rng)
for _ in range(1000):
  rollout.append(state.pipeline_state)
  act_rng, rng = jax.random.split(rng)
  act, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_env_step(state, act)

HTML(html.render(env.sys.tree_replace({'opt.timestep': env.dt}), rollout))
```

--------------------------------

### Training with PPO

Source: https://github.com/google/brax/blob/main/notebooks/training.ipynb

This snippet initiates the training process for a selected environment using the Proximal Policy Optimization (PPO) algorithm. It requires the environment and backend to be set up prior to execution.

```python
#@title Training
```

--------------------------------

### Creating Brax Environments

Source: https://context7.com/google/brax/llms.txt

Create Brax environments using `envs.get_environment` for simple creation or `envs.create` for full creation with wrappers like episode management and auto-reset. Specify the backend ('mjx', 'generalized', etc.).
```python
import jax
from brax import envs

# Simple creation (no wrappers)
env = envs.get_environment('ant', backend='mjx')

# Full creation with wrappers (for training)
env = envs.create(
    env_name='ant',
    episode_length=1000,
    action_repeat=1,
    auto_reset=True,
    batch_size=None,        # set to int for batched env
    backend='generalized',  # 'mjx' | 'generalized' | 'positional' | 'spring'
)
```

--------------------------------

### Running a Brax Simulation with JIT

Source: https://context7.com/google/brax/llms.txt

Initialize and run a Brax simulation for a specified number of steps using JIT-compiled functions for efficiency.

```python
state = init_jit(sys, q, qd)
for _ in range(100):
    state = step_jit(sys, state, act)
```

--------------------------------

### Initializing and Stepping Brax Environments

Source: https://context7.com/google/brax/llms.txt

Initialize an environment state using a PRNG key and step the environment with a given action. Access observation shape, reward, done status, and metrics from the state. Environment observation and action sizes are also accessible.

```python
rng = jax.random.PRNGKey(0)
state = jax.jit(env.reset)(rng)
print(state.obs.shape)   # (observation_size,)
print(state.reward)      # scalar
print(state.done)        # scalar bool
print(state.metrics)     # dict of per-step diagnostic scalars
print(env.observation_size)
print(env.action_size)

# Step the environment
action = jax.random.uniform(rng, (env.action_size,), minval=-1, maxval=1)
state = jax.jit(env.step)(state, action)
```

--------------------------------

### Get Logits and Action

Source: https://github.com/google/brax/blob/main/notebooks/training_torch.ipynb

Normalizes observations and passes them through the policy network to obtain action logits and sample an action. This method is exported for JIT compilation.
```python
@torch.jit.export
def get_logits_action(self, observation):
  observation = self.normalize(observation)
  logits = self.policy(observation)
  loc, scale = self.dist_create(logits)
  action = self.dist_sample_no_postprocess(loc, scale)
  return logits, action
```

--------------------------------

### Load Environment and Backend

Source: https://github.com/google/brax/blob/main/notebooks/training.ipynb

Selects an environment and a physics engine backend for training. The backend affects the trade-off between physical realism and training speed.

```python
#@title Load Env { run: "auto" }

env_name = 'ant'  # @param ['ant', 'halfcheetah', 'hopper', 'humanoid', 'humanoidstandup', 'inverted_pendulum', 'inverted_double_pendulum', 'pusher', 'reacher', 'walker2d']
backend = 'positional'  # @param ['generalized', 'positional', 'spring']

env = envs.get_environment(env_name=env_name, backend=backend)
state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))

HTML(html.render(env.sys, [state.pipeline_state]))
```

--------------------------------

### Reset Quadruped Environment State in Brax

Source: https://github.com/google/brax/blob/main/brax/experimental/barkour/tutorial.ipynb

Initializes the environment state, including pipeline state, observations, and reward information. Ensures a clean start for simulations.
```python
def reset(self, rng: jax.Array) -> State:  # pytype: disable=signature-mismatch
  rng, key = jax.random.split(rng)

  pipeline_state = self.pipeline_init(self._init_q, jp.zeros(self._nv))

  state_info = {
      'rng': rng,
      'last_act': jp.zeros(12),
      'last_vel': jp.zeros(12),
      'command': self.sample_command(key),
      'last_contact': jp.zeros(4, dtype=bool),
      'feet_air_time': jp.zeros(4),
      'rewards': {k: 0.0 for k in self.reward_config.rewards.scales.keys()},
      'kick': jp.array([0.0, 0.0]),
      'step': 0,
  }

  obs_history = jp.zeros(15 * 31)  # store 15 steps of history
  obs = self._get_obs(pipeline_state, state_info, obs_history)
  reward, done = jp.zeros(2)
  metrics = {'total_dist': 0.0}
  for k in state_info['rewards']:
    metrics[k] = state_info['rewards'][k]
  state = State(pipeline_state, obs, reward, done, metrics, state_info)  # pytype: disable=wrong-arg-types
  return state
```

--------------------------------

### Import Plotting and Graphics Packages

Source: https://github.com/google/brax/blob/main/brax/experimental/barkour/tutorial.ipynb

Imports necessary libraries for plotting and creating graphics, including mediapy for media handling and matplotlib for plotting. Ensures ffmpeg is installed.

```python
#@title Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List

# Graphics and plotting.
print('Installing mediapy:')
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt

# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)
```

--------------------------------

### Physics Pipelines - init and step

Source: https://context7.com/google/brax/llms.txt

All four Brax physics pipelines (`generalized`, `positional`, `spring`, `mjx`) expose `init` and `step` functions, which are pure JAX functions, JIT-compilable, and differentiable.
```APIDOC
## Physics Pipelines - `init` and `step`

All four pipelines (`generalized`, `positional`, `spring`, `mjx`) expose the same two functions: `init` and `step`. They are pure JAX functions and are fully JIT-compilable and differentiable.

```python
import jax
import jax.numpy as jp
from brax.io import mjcf
from brax.generalized import pipeline as gen_pipeline
# or:
# from brax.positional import pipeline as pos_pipeline
# from brax.spring import pipeline as spr_pipeline
# from brax.mjx import pipeline as mjx_pipeline

sys = mjcf.load('path/to/ant.xml')

# Initialize state from q (positions) and qd (velocities)
q = sys.init_q                  # default initial positions
qd = jp.zeros(sys.qd_size())    # zero velocities
state = gen_pipeline.init(sys, q, qd)

# Single physics step
act = jp.zeros(sys.act_size())  # zero action
state = gen_pipeline.step(sys, state, act)
```
```

--------------------------------

### Initialize and Step Physics Pipelines

Source: https://context7.com/google/brax/llms.txt

Initialize the physics state using `pipeline.init` and advance the simulation by one step using `pipeline.step`. These functions are pure JAX, JIT-compilable, and differentiable.
```python
import jax
import jax.numpy as jp
from brax.io import mjcf
from brax.generalized import pipeline as gen_pipeline
# or:
# from brax.positional import pipeline as pos_pipeline
# from brax.spring import pipeline as spr_pipeline
# from brax.mjx import pipeline as mjx_pipeline

sys = mjcf.load('path/to/ant.xml')

# Initialize state from q (positions) and qd (velocities)
q = sys.init_q                  # default initial positions
qd = jp.zeros(sys.qd_size())    # zero velocities
state = gen_pipeline.init(sys, q, qd)

# Single physics step
act = jp.zeros(sys.act_size())  # zero action
state = gen_pipeline.step(sys, state, act)
```

--------------------------------

### Initialize and Wrap Environment for PyTorch

Source: https://github.com/google/brax/blob/main/notebooks/training_torch.ipynb

Sets up a vectorized Brax environment and wraps it with TorchWrapper for PyTorch compatibility. Ensures the environment is ready for agent interaction.

```python
env = gym_wrapper.VectorGymWrapper(env)
# automatically convert between jax ndarrays and torch tensors:
env = torch_wrapper.TorchWrapper(env, device=device)

# env warmup
env.reset()
action = torch.zeros(env.action_space.shape).to(device)
env.step(action)
```

--------------------------------

### Import Core Brax, MJX, and MuJoCo Packages

Source: https://github.com/google/brax/blob/main/brax/experimental/barkour/tutorial.ipynb

Imports the main components of Brax, MJX, and MuJoCo, including environment definitions, training utilities, and IO modules. This is essential for setting up and running simulations.
```python
#@title Import MuJoCo, MJX, and Brax
from datetime import datetime
import functools
from IPython.display import HTML
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union

from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model

from etils import epath
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx
```

--------------------------------

### Define Training Hyperparameters with functools.partial

Source: https://github.com/google/brax/blob/main/notebooks/training.ipynb

Defines training configurations for different environments using functools.partial. These configurations specify hyperparameters for PPO or SAC training.
```python
train_fn = {
  'inverted_pendulum': functools.partial(
      ppo.train, num_timesteps=2_000_000, num_evals=20, reward_scaling=10,
      episode_length=1000, normalize_observations=True, action_repeat=1,
      unroll_length=5, num_minibatches=32, num_updates_per_batch=4,
      discounting=0.97, learning_rate=3e-4, entropy_cost=1e-2,
      num_envs=2048, batch_size=1024, seed=1),
  'inverted_double_pendulum': functools.partial(
      ppo.train, num_timesteps=20_000_000, num_evals=20, reward_scaling=10,
      episode_length=1000, normalize_observations=True, action_repeat=1,
      unroll_length=5, num_minibatches=32, num_updates_per_batch=4,
      discounting=0.97, learning_rate=3e-4, entropy_cost=1e-2,
      num_envs=2048, batch_size=1024, seed=1),
  'ant': functools.partial(
      ppo.train, num_timesteps=50_000_000, num_evals=10, reward_scaling=10,
      episode_length=1000, normalize_observations=True, action_repeat=1,
      unroll_length=5, num_minibatches=32, num_updates_per_batch=4,
      discounting=0.97, learning_rate=3e-4, entropy_cost=1e-2,
      num_envs=4096, batch_size=2048, seed=1),
  'humanoid': functools.partial(
      ppo.train, num_timesteps=50_000_000, num_evals=10, reward_scaling=0.1,
      episode_length=1000, normalize_observations=True, action_repeat=1,
      unroll_length=10, num_minibatches=32, num_updates_per_batch=8,
      discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3,
      num_envs=2048, batch_size=1024, seed=1),
  'reacher': functools.partial(
      ppo.train, num_timesteps=50_000_000, num_evals=20, reward_scaling=5,
      episode_length=1000, normalize_observations=True, action_repeat=4,
      unroll_length=50, num_minibatches=32, num_updates_per_batch=8,
      discounting=0.95, learning_rate=3e-4, entropy_cost=1e-3,
      num_envs=2048, batch_size=256, max_devices_per_host=8, seed=1),
  'humanoidstandup': functools.partial(
      ppo.train, num_timesteps=100_000_000, num_evals=20, reward_scaling=0.1,
      episode_length=1000, normalize_observations=True, action_repeat=1,
      unroll_length=15, num_minibatches=32, num_updates_per_batch=8,
      discounting=0.97, learning_rate=6e-4, entropy_cost=1e-2,
      num_envs=2048, batch_size=1024, seed=1),
  'hopper': functools.partial(
      sac.train, num_timesteps=6_553_600, num_evals=20, reward_scaling=30,
      episode_length=1000, normalize_observations=True, action_repeat=1,
      discounting=0.997, learning_rate=6e-4, num_envs=128, batch_size=512,
      grad_updates_per_step=64, max_devices_per_host=1,
      max_replay_size=1048576, min_replay_size=8192, seed=1),
  'walker2d': functools.partial(
      sac.train, num_timesteps=7_864_320, num_evals=20, reward_scaling=5,
      episode_length=1000, normalize_observations=True, action_repeat=1,
      discounting=0.997, learning_rate=6e-4, num_envs=128, batch_size=128,
      grad_updates_per_step=32, max_devices_per_host=1,
      max_replay_size=1048576, min_replay_size=8192, seed=1),
  'halfcheetah': functools.partial(
      ppo.train, num_timesteps=50_000_000, num_evals=20, reward_scaling=1,
      episode_length=1000, normalize_observations=True, action_repeat=1,
      unroll_length=20, num_minibatches=32, num_updates_per_batch=8,
      discounting=0.95, learning_rate=3e-4, entropy_cost=0.001,
      num_envs=2048, batch_size=512, seed=3),
  'pusher': functools.partial(
      ppo.train, num_timesteps=50_000_000, num_evals=20, reward_scaling=5,
      episode_length=1000, normalize_observations=True, action_repeat=1,
      unroll_length=30, num_minibatches=16, num_updates_per_batch=8,
      discounting=0.95, learning_rate=3e-4, entropy_cost=1e-2,
      num_envs=2048, batch_size=512, seed=3),
}[env_name]
```

--------------------------------

### Brax Environment Wrappers

Source: https://context7.com/google/brax/llms.txt

Demonstrates the usage of AutoResetWrapper, EvalWrapper, and DomainRandomizationVmapWrapper for environment customization. The training.wrap function is a convenience utility for applying these wrappers.
```python
# `training`, `base_env`, and `ep_env` are the names defined in the
# Training Wrappers example (brax.envs.wrappers.training).
ar_env = training.AutoResetWrapper(ep_env)
```

```python
eval_env = training.EvalWrapper(base_env)
```

```python
import functools

def rand_fn(sys, rng):
  # Scale link masses randomly per environment in the batch
  batch = 8
  rngs = jax.random.split(rng, batch)
  scales = jax.vmap(lambda r: jax.random.uniform(r, minval=0.9, maxval=1.1))(rngs)
  # Mark every field as unbatched (None), then flag the randomized field
  # as batched along axis 0 so the wrapper vmaps over it.
  in_axes = jax.tree_util.tree_map(lambda _: None, sys)
  in_axes = in_axes.tree_replace({'link.inertia.mass': 0})
  new_mass = sys.link.inertia.mass * scales[:, None]
  sys_v = sys.tree_replace({'link.inertia.mass': new_mass})
  return sys_v, in_axes

# The wrapper calls the randomization function with the system only,
# so the PRNG key is bound ahead of time.
dr_env = training.DomainRandomizationVmapWrapper(
    base_env, functools.partial(rand_fn, rng=jax.random.PRNGKey(0)))
```

```python
wrapped = training.wrap(
    base_env,
    episode_length=1000,
    action_repeat=1,
    randomization_fn=None,  # or pass rand_fn with its rng already bound
)
```

--------------------------------

### Initialize PyTorch Agent and Optimizer

Source: https://github.com/google/brax/blob/main/notebooks/training_torch.ipynb

Defines the network architecture for the agent's policy and value functions, initializes the agent, scripts it for optimization, and sets up the Adam optimizer.

```python
policy_layers = [
    env.observation_space.shape[-1], 64, 64, env.action_space.shape[-1] * 2
]
value_layers = [env.observation_space.shape[-1], 64, 64, 1]
agent = Agent(policy_layers, value_layers, entropy_cost, discounting,
              reward_scaling, device)
agent = torch.jit.script(agent.to(device))
optimizer = optim.Adam(agent.parameters(), lr=learning_rate)
```

--------------------------------

### Training Wrappers

Source: https://context7.com/google/brax/llms.txt

Introduces training wrappers from `brax.envs.wrappers.training`, such as `VmapWrapper` for vectorization and `EpisodeWrapper` for episode management, which can be composed over existing environments.

```APIDOC
## Training Wrappers (`brax.envs.wrappers.training`)

Training wrappers compose over any `Env` to add episode truncation, vectorization, auto-reset, domain randomization, and evaluation metrics.

```python
from brax import envs
from brax.envs.wrappers import training
import jax
import jax.numpy as jp

base_env = envs.get_environment('humanoid', backend='mjx')

# VmapWrapper: vectorize across a batch of environments
batched_env = training.VmapWrapper(base_env)
rngs = jax.random.split(jax.random.PRNGKey(0), 64)
states = jax.jit(batched_env.reset)(rngs)
print(states.obs.shape)  # (64, observation_size)

# EpisodeWrapper: add step counter + truncation signal
ep_env = training.EpisodeWrapper(base_env, episode_length=1000, action_repeat=2)
```
```

--------------------------------

### Save and Reload Policy Parameters

Source: https://github.com/google/brax/blob/main/brax/experimental/barkour/tutorial.ipynb

Demonstrates how to save trained policy parameters to a file and then load them back. This is useful for resuming training or deploying a trained policy.

```python
# Save and reload params.
model_path = '/tmp/mjx_brax_quadruped_policy'
model.save_params(model_path, params)
params = model.load_params(model_path)

inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)
```

--------------------------------

### Visualize Policy with Joystick Commands

Source: https://github.com/google/brax/blob/main/brax/experimental/barkour/tutorial.ipynb

Visualizes the trained policy by controlling the quadruped with joystick-like commands (x_vel, y_vel, ang_vel). It initializes the state, rolls out a trajectory, and renders the video.

```python
# @markdown Commands **only used for Barkour Env**:
x_vel = 1.0  #@param {type: "number"}
y_vel = 0.0  #@param {type: "number"}
ang_vel = -0.5  #@param {type: "number"}

the_command = jp.array([x_vel, y_vel, ang_vel])

# initialize the state
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
state.info['command'] = the_command
rollout = [state.pipeline_state]

# grab a trajectory
n_steps = 500
render_every = 2

for i in range(n_steps):
  act_rng, rng = jax.random.split(rng)
  ctrl, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)

media.show_video(
    eval_env.render(rollout[::render_every], camera='track'),
    fps=1.0 / eval_env.dt / render_every)
```

--------------------------------

### Initialize Brax Viewer

Source: https://github.com/google/brax/blob/main/brax/visualizer/index.html

Initializes the Brax viewer with the specified DOM element and system data. The system data is expected to be a base64 encoded, gzipped JSON string.

```javascript
var system = "{{ system_json_b64 }}";
// decode base64 (convert ascii to binary)
system = atob(system);
// convert binary string to character-number array
system = system.split('').map(function(x){return x.charCodeAt(0);});
// decompress
system = pako.inflate(system);
// convert gunzipped byteArray back to ascii string
system = new TextDecoder("utf-8").decode(system);
// and load json
system = JSON.parse(system);

import {Viewer} from 'viewer';
const domElement = document.getElementById("brax-viewer");
var viewer = new Viewer(domElement, system);
```

--------------------------------

### Adding Episode Management with `EpisodeWrapper`

Source: https://context7.com/google/brax/llms.txt

Wrap an environment with `EpisodeWrapper` to automatically track steps within an episode, manage action repeats, and signal episode truncation. This is essential for training loops.

```python
from brax import envs
from brax.envs.wrappers import training
import jax
import jax.numpy as jp

base_env = envs.get_environment('humanoid', backend='mjx')

# EpisodeWrapper: add step counter + truncation signal
ep_env = training.EpisodeWrapper(base_env, episode_length=1000, action_repeat=2)
```
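
To make the three behaviors described above concrete (step counting, action repeat, truncation), here is a minimal plain-Python sketch of the bookkeeping. It is not Brax's implementation: `ToyState`, `ToyEnv`, and `ToyEpisodeWrapper` are hypothetical names invented for illustration, and the exact rules (rewards summed over repeats, the step counter advancing by `action_repeat`, truncation flagged only when the episode did not end on its own) are assumptions about the semantics.

```python
from dataclasses import dataclass, field, replace


@dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for an environment state."""
    obs: float
    reward: float = 0.0
    done: bool = False
    info: dict = field(default_factory=dict)


class ToyEnv:
    """Hypothetical environment: obs grows by the action, reward is 1 per step."""

    def reset(self):
        return ToyState(obs=0.0)

    def step(self, state, action):
        return replace(state, obs=state.obs + action, reward=1.0)


class ToyEpisodeWrapper:
    """Sketch of episode bookkeeping layered over any env with reset/step."""

    def __init__(self, env, episode_length, action_repeat):
        self.env = env
        self.episode_length = episode_length
        self.action_repeat = action_repeat

    def reset(self):
        state = self.env.reset()
        info = {**state.info, 'steps': 0, 'truncation': False}
        return replace(state, info=info)

    def step(self, state, action):
        # Apply the same action `action_repeat` times and sum the rewards.
        total_reward = 0.0
        for _ in range(self.action_repeat):
            state = self.env.step(state, action)
            total_reward += state.reward
        # Advance the counter by the number of underlying steps taken.
        steps = state.info['steps'] + self.action_repeat
        timed_out = steps >= self.episode_length
        # Truncation means the time limit ended the episode, not the env itself.
        truncation = timed_out and not state.done
        info = {**state.info, 'steps': steps, 'truncation': truncation}
        return replace(state, reward=total_reward,
                       done=state.done or timed_out, info=info)


env = ToyEpisodeWrapper(ToyEnv(), episode_length=6, action_repeat=2)
state = env.reset()
while not state.done:
    state = env.step(state, action=0.5)
print(state.info, state.reward, state.obs)
```

Keeping `truncation` separate from `done` lets a training loop tell a time-limit cutoff apart from a genuine terminal state, which matters when bootstrapping value estimates at the end of an episode.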