### Setup Brax and Dependencies
Source: https://github.com/google/brax/blob/main/notebooks/training.ipynb
Imports the libraries used for training and visualization, and installs Brax from GitHub only if it is not already available. A GPU runtime is recommended for better performance. The cell also initializes a Colab TPU when one is available.
```python
#@markdown ## ⚠️ PLEASE NOTE:
#@markdown This colab runs best using a GPU runtime. From the Colab menu, choose Runtime > Change Runtime Type, then select **'GPU'** in the dropdown.
import functools
import jax
import os
from datetime import datetime
from jax import numpy as jp
import matplotlib.pyplot as plt
from IPython.display import HTML, clear_output
try:
  import brax
except (ImportError, ModuleNotFoundError):
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax
import flax
from brax import envs
from brax.io import model
from brax.io import json
from brax.io import html
from brax.training.agents.ppo import train as ppo
from brax.training.agents.sac import train as sac
if 'COLAB_TPU_ADDR' in os.environ:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()
```
--------------------------------
### Install Brax from Source
Source: https://github.com/google/brax/blob/main/README.md
Clone the repository, navigate to the directory, set up a virtual environment, and install Brax in editable mode.
```bash
git clone https://github.com/google/brax.git
cd brax
python3 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install -e .
```
--------------------------------
### Install MuJoCo, MJX, and Brax
Source: https://github.com/google/brax/blob/main/brax/experimental/barkour/tutorial.ipynb
Installs the necessary libraries for MuJoCo, MJX, and Brax. Run this cell first to set up your environment.
```python
!pip install mujoco
!pip install mujoco_mjx
!pip install brax
```
--------------------------------
### Install Brax from PyPI
Source: https://github.com/google/brax/blob/main/README.md
Use these commands to create a virtual environment, activate it, upgrade pip, and install Brax from PyPI.
```bash
python3 -m venv env
source env/bin/activate
pip install --upgrade pip
pip install brax
```
--------------------------------
### Setup and Imports for Brax
Source: https://github.com/google/brax/blob/main/notebooks/basics.ipynb
Installs Brax if not found and imports necessary JAX and Matplotlib libraries for physics simulation and visualization.
```python
#@title Colab setup and imports
import jax
from jax import numpy as jp
from matplotlib.lines import Line2D
from matplotlib.patches import Circle
import matplotlib.pyplot as plt
try:
  import brax
except (ImportError, ModuleNotFoundError):
  from IPython.display import clear_output
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax
```
--------------------------------
### Install Brax
Source: https://context7.com/google/brax/llms.txt
Install Brax using pip from PyPI or conda-forge. Alternatively, clone the repository and install from source.
```bash
# From PyPI
pip install brax
# From conda-forge
conda install -c conda-forge brax
# From source
git clone https://github.com/google/brax && cd brax
pip install -e .
```
--------------------------------
### Train a Model with Brax
Source: https://github.com/google/brax/blob/main/README.md
Run this command to start training a model with Brax. Training on an NVIDIA GPU is supported, but CUDA, cuDNN, and JAX with GPU support must be installed first.
```bash
learn
```
--------------------------------
### ES Training Setup
Source: https://context7.com/google/brax/llms.txt
Configures and executes Evolution Strategies (ES) training. This method optimizes a policy population using estimated gradients from perturbed rollouts and supports canonical and OpenAI ES variants. Environment setup and training parameters are specified.
```python
from brax import envs
from brax.training.agents.es import train as es
env = envs.get_environment('swimmer', backend='spring')
make_inference_fn, params, metrics = es.train(
environment=env,
num_timesteps=5_000_000,
episode_length=1000,
action_repeat=1,
num_envs=64,
learning_rate=0.01,
l2coeff=0.005,
noise_std=0.01,
normalize_observations=True,
center_fitness=True,
)
```
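To make "estimated gradients from perturbed rollouts" concrete, here is a minimal NumPy sketch of one antithetic (mirrored) ES gradient estimate with centered-rank fitness shaping and an L2 penalty. It is not Brax's implementation: `es_gradient`, `population_size`, and the centered-rank shaping are assumptions for illustration, and a toy quadratic stands in for episode return.

```python
import numpy as np

def es_gradient(fitness_fn, theta, rng, population_size=64, noise_std=0.01,
                l2coeff=0.005):
  """Antithetic ES estimate of d fitness / d theta with centered-rank shaping."""
  half = population_size // 2
  eps = rng.standard_normal((half, theta.size))
  # Evaluate each perturbation in both directions (mirrored sampling).
  f_pos = np.array([fitness_fn(theta + noise_std * e) for e in eps])
  f_neg = np.array([fitness_fn(theta - noise_std * e) for e in eps])
  fitness = np.concatenate([f_pos, f_neg])
  # Centered ranks in [-0.5, 0.5] make the step insensitive to reward scale.
  ranks = np.empty_like(fitness)
  ranks[np.argsort(fitness)] = np.arange(fitness.size)
  shaped = ranks / (fitness.size - 1) - 0.5
  diff = shaped[:half] - shaped[half:]
  grad = (diff[:, None] * eps).sum(axis=0) / (population_size * noise_std)
  # Subtract an L2 (weight-decay) term on the parameters.
  return grad - l2coeff * theta

# Toy stand-in for episode return: maximized at theta == target.
target = np.full(5, 3.0)
fitness = lambda th: -np.sum((th - target) ** 2)
rng = np.random.default_rng(0)
theta = np.zeros(5)
for _ in range(400):
  theta = theta + 0.005 * es_gradient(fitness, theta, rng)  # gradient ascent
print(np.round(theta, 2))
```

Because only fitness ranks enter the update, the step size does not depend on the raw reward scale, which is one common motivation for rank shaping.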
--------------------------------
### Check MuJoCo Installation and GPU Setup
Source: https://github.com/google/brax/blob/main/brax/experimental/barkour/tutorial.ipynb
Verifies the MuJoCo installation and configures the environment for GPU rendering using EGL. This is crucial for performance and requires a GPU runtime.
```python
#@title Check if MuJoCo installation was successful
from google.colab import files
import distutils.util
import os
import subprocess
if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')
# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")
# Tell XLA to use Triton GEMM; this improves steps/sec by ~30% on some GPUs
xla_flags = os.environ.get('XLA_FLAGS', '')
xla_flags += ' --xla_gpu_triton_gemm_any=True'
os.environ['XLA_FLAGS'] = xla_flags
# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl
try:
  print('Checking that the installation succeeded:')
  import mujoco
  mujoco.MjModel.from_xml_string('<mujoco/>')  # smallest valid MJCF model
except Exception as e:
  raise e from RuntimeError(
      'Something went wrong during installation. Check the shell output above '
      'for more information.\n'
      'If using a hosted Colab runtime, make sure you enable GPU acceleration '
      'by going to the Runtime menu and selecting "Choose runtime type".')
print('Installation successful.')
print('Installation successful.')
```
--------------------------------
### PPO Training Setup and Execution
Source: https://context7.com/google/brax/llms.txt
Sets up and trains a Proximal Policy Optimization (PPO) agent. Requires environment initialization and defines a progress function for monitoring training. The trained policy can then be used for inference.
```python
import functools
import jax
from brax import envs
from brax.training.agents.ppo import train as ppo
env = envs.get_environment('ant', backend='mjx')
def progress(step, metrics):
print(f'step={step:,} reward={metrics["eval/episode_reward"]:.2f}')
make_inference_fn, params, metrics = ppo.train(
environment=env,
num_timesteps=50_000_000,
num_envs=2048,
episode_length=1000,
learning_rate=3e-4,
entropy_cost=1e-2,
discounting=0.97,
unroll_length=5,
batch_size=512,
num_minibatches=32,
num_updates_per_batch=4,
normalize_observations=True,
reward_scaling=10.0,
clipping_epsilon=0.3,
gae_lambda=0.95,
action_repeat=1,
progress_fn=progress,
# Optional: checkpoint_logdir='/tmp/ppo_ant'
)
# Use the trained policy for inference
inference_fn = make_inference_fn(params, deterministic=True)
rng = jax.random.PRNGKey(0)
obs = jax.numpy.zeros(env.observation_size)
action, _ = inference_fn(obs, rng)
print('action:', action)
```
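The `discounting`, `gae_lambda`, and `clipping_epsilon` arguments above control PPO's advantage estimate and its clipped objective. The NumPy sketch below shows the textbook form of both for a single unroll; the function names are hypothetical, and Brax's own loss is vectorized in JAX and may differ in details such as how truncated episodes are bootstrapped.

```python
import numpy as np

def compute_gae(rewards, values, bootstrap_value, dones,
                discounting=0.97, gae_lambda=0.95):
  """GAE over one unroll of length T; all inputs are 1-D arrays of length T."""
  advantages = np.zeros(len(rewards))
  next_value = bootstrap_value  # V(s_T), the value after the last step
  running = 0.0
  for t in reversed(range(len(rewards))):
    not_done = 1.0 - dones[t]
    # One-step TD error, with no bootstrapping across an episode boundary.
    delta = rewards[t] + discounting * not_done * next_value - values[t]
    running = delta + discounting * gae_lambda * not_done * running
    advantages[t] = running
    next_value = values[t]
  return advantages, advantages + values  # (advantages, value targets)

def clipped_surrogate(ratio, advantages, clipping_epsilon=0.3):
  """PPO objective to maximize: mean(min(r * A, clip(r, 1 - eps, 1 + eps) * A))."""
  clipped = np.clip(ratio, 1.0 - clipping_epsilon, 1.0 + clipping_epsilon)
  return np.mean(np.minimum(ratio * advantages, clipped * advantages))

adv, targets = compute_gae(np.ones(3), np.zeros(3), 0.0, np.zeros(3),
                           discounting=0.5, gae_lambda=1.0)
# With gae_lambda=1 and zero values, advantages are discounted returns: 1.75, 1.5, 1.0.
print(adv)
```

Clipping caps how far a single update can move the probability ratio `r` away from 1, which is what keeps PPO updates conservative.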
--------------------------------
### Upgrade OpenAI Gym
Source: https://github.com/google/brax/blob/main/notebooks/training_torch.ipynb
Upgrades OpenAI Gym at the start of the PyTorch training notebook. Brax itself is installed by the notebook's import cell if it is not already available.
```python
!pip install --upgrade gym --quiet
```
--------------------------------
### SAC Training Setup and Execution
Source: https://context7.com/google/brax/llms.txt
Configures and initiates training for a Soft Actor-Critic (SAC) agent. This off-policy algorithm uses a replay buffer and is suitable for continuous-action tasks. It requires environment setup and a progress callback.
```python
import jax
from brax import envs
from brax.training.agents.sac import train as sac
env = envs.get_environment('halfcheetah', backend='generalized')
def progress(step, metrics):
print(f'step={step:,} reward={metrics["eval/episode_reward"]:.2f}')
make_inference_fn, params, metrics = sac.train(
environment=env,
num_timesteps=1_000_000,
episode_length=1000,
num_envs=128,
num_eval_envs=128,
learning_rate=6e-4,
discounting=0.99,
batch_size=256,
tau=0.005,
reward_scaling=30.0,
normalize_observations=False,
min_replay_size=8192,
grad_updates_per_step=1,
deterministic_eval=True,
seed=0,
progress_fn=progress,
)
inference_fn = make_inference_fn(params, deterministic=True)
obs = jax.numpy.zeros(env.observation_size)
action, _ = inference_fn(obs, jax.random.PRNGKey(1))
```
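The `tau` argument sets how quickly SAC's target critic tracks the online critic through Polyak averaging. This sketch applies the rule to a dict of NumPy arrays, a hypothetical stand-in for a parameter pytree; it illustrates the update rather than reproducing Brax's code.

```python
import numpy as np

def soft_update(target_params, online_params, tau=0.005):
  """Polyak averaging: target <- (1 - tau) * target + tau * online, per array."""
  return {name: (1.0 - tau) * target_params[name] + tau * online_params[name]
          for name in target_params}

online = {'w': np.ones(3), 'b': np.array(2.0)}
target = {'w': np.zeros(3), 'b': np.array(0.0)}
for _ in range(1000):
  target = soft_update(target, online, tau=0.005)
# After n updates toward a fixed online value, the remaining gap is (1 - tau)**n.
print(target['w'])
```

With `tau=0.005`, the target closes a little over 99% of the gap after 1000 updates, so it lags the online critic and gives SAC a more stable bootstrap target.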
--------------------------------
### Environment Creation and Usage
Source: https://context7.com/google/brax/llms.txt
Demonstrates how to get and create Brax environments, including options for wrappers, batching, and backends. It also shows basic usage like resetting and stepping the environment, and accessing environment properties.
Brax ships 12 built-in environments. All environments implement the `Env` interface with `reset(rng)` → `State` and `step(state, action)` → `State`. The `create` factory adds episode management, batching, and auto-reset wrappers.
```python
import jax
from brax import envs
# Available envs: 'ant', 'humanoid', 'hopper', 'halfcheetah', 'walker2d',
# 'swimmer', 'reacher', 'pusher', 'inverted_pendulum',
# 'inverted_double_pendulum', 'humanoidstandup', 'fast'
# Simple creation (no wrappers)
env = envs.get_environment('ant', backend='mjx')
# Full creation with wrappers (for training)
env = envs.create(
    env_name='ant',
    episode_length=1000,
    action_repeat=1,
    auto_reset=True,
    batch_size=None,  # set to an int for a batched env
    backend='generalized',  # 'mjx' | 'generalized' | 'positional' | 'spring'
)
reset_rng, action_rng = jax.random.split(jax.random.PRNGKey(0))
state = jax.jit(env.reset)(reset_rng)
print(state.obs.shape)  # (observation_size,)
print(state.reward)  # scalar
print(state.done)  # scalar float: 1.0 once the episode has terminated
print(state.metrics)  # dict of per-step diagnostic scalars
print(env.observation_size)
print(env.action_size)
# Step the environment with a random action in [-1, 1]
action = jax.random.uniform(action_rng, (env.action_size,), minval=-1, maxval=1)
state = jax.jit(env.step)(state, action)
# Register a custom environment. This is a skeleton: MyEnv also needs an
# __init__ that builds its System and calls super().__init__(sys, ...)
# before get_environment can instantiate it.
from brax.envs.base import PipelineEnv, State
class MyEnv(PipelineEnv):
  def reset(self, rng):
    ...
  def step(self, state, action):
    ...
envs.register_environment('my_env', MyEnv)
my_env = envs.get_environment('my_env')
```
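The functional `reset`/`step` contract, and the auto-reset behavior that `create(auto_reset=True)` adds, can be illustrated without Brax. The toy environment and `auto_reset_step` below are hypothetical. As far as we know, Brax's `AutoResetWrapper` likewise restores the episode's first observation when `done` is set, but it selects with array operations (so it can be jitted) and also restores the physics state.

```python
from typing import NamedTuple
import numpy as np

class ToyState(NamedTuple):
  """Hypothetical, immutable stand-in for brax.envs.base.State."""
  obs: np.ndarray
  reward: float
  done: float
  first_obs: np.ndarray  # saved at reset so the wrapper can restart the episode

class ToyEnv:
  """A 1-D point moved by the action; the episode ends once |x| >= 1."""
  observation_size = 1
  action_size = 1

  def reset(self, rng: np.random.Generator) -> ToyState:
    obs = rng.uniform(-0.1, 0.1, size=(1,))
    return ToyState(obs=obs, reward=0.0, done=0.0, first_obs=obs)

  def step(self, state: ToyState, action: np.ndarray) -> ToyState:
    # Pure function: returns a new state instead of mutating the old one.
    obs = state.obs + action
    done = float(abs(obs[0]) >= 1.0)
    return state._replace(obs=obs, reward=-abs(float(obs[0])), done=done)

def auto_reset_step(env: ToyEnv, state: ToyState, action: np.ndarray) -> ToyState:
  """Steps `env`; if the episode ended, swaps in the episode's first observation."""
  state = env.step(state, action)
  if state.done:
    # Keep done=1.0 so the learner still sees the terminal transition.
    state = state._replace(obs=state.first_obs)
  return state

env = ToyEnv()
state = env.reset(np.random.default_rng(0))
state = auto_reset_step(env, state, np.array([5.0]))  # pushes past |x| >= 1
print(state.done, state.obs)
```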
--------------------------------
### Domain Randomization Setup in Brax
Source: https://context7.com/google/brax/llms.txt
Demonstrates how to set up domain randomization, either by wrapping an environment with `DomainRandomizationVmapWrapper` directly or by passing a randomization function to a trainer such as `ppo.train`, which applies the wrapper itself. Training across varied physical parameters helps policies become robust to those variations.
```python
import functools
import jax
from brax import envs
from brax.envs.wrappers.training import DomainRandomizationVmapWrapper
from brax.training.agents.ppo import train as ppo
base_env = envs.get_environment('ant', backend='generalized')
def randomize(sys, rng):
  """Return (batched_sys, in_axes); `rng` holds one PRNG key per env."""
  # Randomize link masses ±20%, independently for each env
  mass_scales = jax.vmap(
      lambda r: jax.random.uniform(r, shape=(sys.num_links(),), minval=0.8, maxval=1.2)
  )(rng)
  sys_v = jax.vmap(
      lambda s: sys.tree_replace({'link.inertia.mass': sys.link.inertia.mass * s})
  )(mass_scales)
  # in_axes: 0 where batched, None where not. vmap gives every leaf of sys_v
  # a leading env axis, so all leaves are marked 0 here.
  in_axes = jax.tree_util.tree_map(lambda _: 0, sys_v)
  return sys_v, in_axes
# Wrap directly: the wrapper calls the function with `sys` only, so bind
# one key per env (64 envs here).
rand_env = DomainRandomizationVmapWrapper(
    base_env,
    functools.partial(randomize, rng=jax.random.split(jax.random.PRNGKey(0), 64)),
)
# Or pass the function to the trainer, which supplies the per-env keys via `rng`
make_inference_fn, params, _ = ppo.train(
    environment=base_env,
    num_timesteps=10_000_000,
    num_envs=64,
    episode_length=1000,
    randomization_fn=randomize,
)
```
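The `in_axes` tree returned by `randomize` tells the wrapper which system fields carry a per-environment axis (`0`) and which are shared (`None`). The NumPy sketch below mimics that contract with a plain dict; the field names and masses are hypothetical, and it does not call Brax.

```python
import numpy as np

def randomize_masses(base_masses, num_envs, rng, low=0.8, high=1.2):
  """Per-env copies of `base_masses`, each link scaled by U(low, high)."""
  scales = rng.uniform(low, high, size=(num_envs, base_masses.shape[0]))
  return base_masses[None, :] * scales  # shape (num_envs, num_links)

def params_for_env(params, in_axes, i):
  """What vmap does per env: index leaves with axis 0, share leaves with None."""
  return {name: (value[i] if in_axes[name] == 0 else value)
          for name, value in params.items()}

base_masses = np.array([1.0, 2.0, 0.5])  # hypothetical link masses
params = {'mass': randomize_masses(base_masses, 64, np.random.default_rng(0)),
          'friction': np.array(1.0)}     # hypothetical shared parameter
in_axes = {'mass': 0, 'friction': None}
env3 = params_for_env(params, in_axes, 3)
print(env3['mass'], env3['friction'])
```

Marking only the randomized fields as batched (and the rest `None`) avoids copying unchanged parameters once per environment.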
--------------------------------
### Initialize Environment for Visualization
Source: https://github.com/google/brax/blob/main/brax/experimental/barkour/tutorial.ipynb
Initializes the environment for policy visualization. It gets the environment instance and creates JIT-compiled versions of the reset and step functions for efficiency.
```python
eval_env = envs.get_environment(env_name)
jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)
```
--------------------------------
### Install Brax from Conda/Mamba
Source: https://github.com/google/brax/blob/main/README.md
Install Brax using Conda or Mamba by specifying the conda-forge channel.
```bash
conda install -c conda-forge brax # s/conda/mamba for mamba
```
--------------------------------
### Execute Training and Get Inference Function
Source: https://github.com/google/brax/blob/main/notebooks/training.ipynb
Initiates the training process using the defined training function and environment, and obtains a function for making inferences.
```python
make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)
```
--------------------------------
### Get Barkour Environment
Source: https://github.com/google/brax/blob/main/brax/experimental/barkour/tutorial.ipynb
Retrieves an instance of the Barkour environment using its registered name. This is the standard way to initialize an environment for use.
```python
env_name = 'barkour'
env = envs.get_environment(env_name)
```
--------------------------------
### Load Brax System from MJCF
Source: https://github.com/google/brax/blob/main/notebooks/basics.ipynb
Loads a Brax physics system defined in MJCF format. This example creates a simple scene with a sphere and a plane.
```python
from brax.io import mjcf
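# The MJCF XML (a sphere above a ground plane) is elided in this excerpt;
# pass the full XML string to mjcf.loads.
ball = mjcf.loads(
    """
    """
)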
```
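Since the notebook's XML is not reproduced above, here is a hypothetical MJCF for the same kind of scene: a sphere on a free joint, 1 m above a ground plane. The sizes, mass, and timestep are assumptions, not the notebook's values. The Brax import is deferred into `load_ball` so the XML can be inspected without Brax installed.

```python
import xml.etree.ElementTree as ET

# Hypothetical scene: a 0.5 m sphere on a free joint, 1 m above a ground plane.
BALL_XML = """
<mujoco>
  <option timestep="0.005"/>
  <worldbody>
    <geom name="floor" type="plane" size="10 10 0.1"/>
    <body name="ball" pos="0 0 1">
      <freejoint/>
      <geom name="sphere" type="sphere" size="0.5" mass="1"/>
    </body>
  </worldbody>
</mujoco>
"""

def load_ball():
  """Parses BALL_XML into a Brax System (requires brax to be installed)."""
  from brax.io import mjcf
  return mjcf.loads(BALL_XML)

def geom_types(xml: str = BALL_XML):
  """Lists (name, type) for every geom, as a quick sanity check of the XML."""
  root = ET.fromstring(xml.strip())
  return [(g.get('name'), g.get('type')) for g in root.iter('geom')]

print(geom_types())
```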
--------------------------------
### PipelineEnv - Building Custom Environments
Source: https://context7.com/google/brax/llms.txt
Provides a base class `PipelineEnv` for creating custom physics-based environments. It outlines the structure for implementing `reset` and `step` methods, and includes an example of a `BallBalance` environment.
`PipelineEnv` is the base class for all built-in environments. Subclass it to build custom physics-based environments with selectable backends.
```python
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
import jax
import jax.numpy as jp
class BallBalance(PipelineEnv):
  def __init__(self, backend='mjx', **kwargs):
    # The MJCF model is elided in this excerpt; supply a full XML string here.
    xml = """
    """
    sys = mjcf.loads(xml)
    super().__init__(sys, backend=backend, n_frames=4, **kwargs)
  def reset(self, rng: jax.Array) -> State:
    q = self.sys.init_q + jax.random.uniform(rng, (self.sys.q_size(),), minval=-0.01, maxval=0.01)
    qd = jp.zeros(self.sys.qd_size())
    pipeline_state = self.pipeline_init(q, qd)
    obs = self._get_obs(pipeline_state)
    return State(pipeline_state=pipeline_state, obs=obs, reward=jp.float32(0), done=jp.float32(0))
  def step(self, state: State, action: jax.Array) -> State:
    pipeline_state = self.pipeline_step(state.pipeline_state, action)
    obs = self._get_obs(pipeline_state)
    reward = -jp.sum(jp.square(pipeline_state.x.pos[0, :2]))  # penalize XY offset
    done = jp.float32(jp.abs(pipeline_state.x.pos[0, 2]) < 0.05)  # fell over
    return state.replace(pipeline_state=pipeline_state, obs=obs, reward=reward, done=done)
  def _get_obs(self, ps):
    return jp.concatenate([ps.q, ps.qd])
env = BallBalance(backend='mjx')
state = jax.jit(env.reset)(jax.random.PRNGKey(42))
print(state.obs.shape, env.observation_size, env.action_size)
```
--------------------------------
### Import Brax and Helper Modules
Source: https://github.com/google/brax/blob/main/notebooks/training_torch.ipynb
Imports core Brax libraries, environment wrappers, training agents (PPO), and common Python modules for data manipulation, visualization, and system operations. It includes a fallback to install Brax from source if it's not found.
```python
#@title Import Brax and some helper modules
from IPython.display import clear_output
import collections
from datetime import datetime
import functools
import math
import os
import time
from typing import Any, Callable, Dict, Optional, Sequence
try:
  import brax
except (ImportError, ModuleNotFoundError):
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax
from brax import envs
from brax.envs.wrappers import gym as gym_wrapper
from brax.envs.wrappers import torch as torch_wrapper
from brax.io import metrics
from brax.training.agents.ppo import train as ppo
import gym
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
```
--------------------------------
### ARS Training Setup
Source: https://context7.com/google/brax/llms.txt
Initializes and runs Augmented Random Search (ARS) training. This derivative-free method is effective for locomotion tasks and does not require neural network gradients. It involves setting up the environment and training parameters.
```python
from brax import envs
from brax.training.agents.ars import train as ars
env = envs.get_environment('hopper', backend='positional')
make_inference_fn, params, metrics = ars.train(
environment=env,
num_timesteps=10_000_000,
episode_length=1000,
action_repeat=1,
num_envs=32,
learning_rate=0.015,
normalize_observations=True,
number_of_directions=60,
step_size=0.03,
)
```
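ARS builds its update from paired rollouts along random directions and keeps only the best directions, so no policy gradient is computed. The NumPy sketch below is the basic update with top-direction selection. `top_directions` and `noise_std` are hypothetical names; `number_of_directions` and `step_size` mirror the call above, but whether their semantics match Brax's exactly is an assumption. Observation normalization (enabled above) is omitted.

```python
import numpy as np

def ars_step(reward_fn, theta, rng, number_of_directions=60, top_directions=20,
             noise_std=0.03, step_size=0.015):
  """One basic ARS update with top-direction selection (no obs normalization)."""
  deltas = rng.standard_normal((number_of_directions, theta.size))
  r_pos = np.array([reward_fn(theta + noise_std * d) for d in deltas])
  r_neg = np.array([reward_fn(theta - noise_std * d) for d in deltas])
  # Keep the directions whose better side earned the most reward.
  top = np.argsort(np.maximum(r_pos, r_neg))[-top_directions:]
  # Normalize the step by the spread of the rewards actually used.
  sigma_r = np.concatenate([r_pos[top], r_neg[top]]).std() + 1e-8
  update = ((r_pos[top] - r_neg[top])[:, None] * deltas[top]).sum(axis=0)
  return theta + step_size / (top_directions * sigma_r) * update

target = np.full(5, 3.0)
reward = lambda th: -np.sum((th - target) ** 2)  # toy stand-in for episode return
rng = np.random.default_rng(0)
theta = np.zeros(5)
for _ in range(500):
  theta = ars_step(reward, theta, rng)
print(np.round(theta, 2))
```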
--------------------------------
### Brax Trainer Configuration and Execution
Source: https://github.com/google/brax/blob/main/notebooks/training_torch.ipynb
This code configures and runs the built-in Brax trainer for an 'ant' environment using the 'spring' backend. It's designed for efficient training and includes parameters for environment setup, training duration, and optimization. Note the use of `progress_fn` for monitoring.
```python
train_sps = []
def progress(_, metrics):
  if 'training/sps' in metrics:
    train_sps.append(metrics['training/sps'])
ppo.train(
    environment=envs.create(env_name='ant', backend='spring'),
    num_timesteps = 30_000_000, num_evals = 10, reward_scaling = .1,
    episode_length = 1000, normalize_observations = True, action_repeat = 1,
    unroll_length = 5, num_minibatches = 32, num_updates_per_batch = 4,
    discounting = 0.97, learning_rate = 3e-4, entropy_cost = 1e-2,
    num_envs = 2048, batch_size = 1024, progress_fn = progress)
print(f'train steps/sec: {np.mean(train_sps[1:])}')
```
--------------------------------
### Roll Out and Render a Trained Policy
Source: https://github.com/google/brax/blob/main/notebooks/training.ipynb
Creates an environment, JIT-compiles the reset, step, and inference functions, rolls out the trained policy for 1000 steps, and renders the rollout.
```python
# Assumes inference_fn = make_inference_fn(params), built from the trained parameters.
env = envs.create(env_name=env_name, backend=backend)
jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.jit(inference_fn)
rollout = []
rng = jax.random.PRNGKey(seed=1)
state = jit_env_reset(rng=rng)
for _ in range(1000):
  rollout.append(state.pipeline_state)
  act_rng, rng = jax.random.split(rng)
  act, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_env_step(state, act)
HTML(html.render(env.sys.tree_replace({'opt.timestep': env.dt}), rollout))
```
--------------------------------
### Training with PPO
Source: https://github.com/google/brax/blob/main/notebooks/training.ipynb
This cell trains on the selected environment with the `train_fn` configured for it (PPO for most environments, SAC for hopper and walker2d). The environment and backend must be set up before running it.
```python
#@title Training
```
--------------------------------
### Running a Brax Simulation with JIT
Source: https://context7.com/google/brax/llms.txt
Initialize and run a Brax simulation for a specified number of steps using JIT-compiled functions for efficiency.
```python
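import jax
from brax.generalized import pipeline as gen_pipeline
# sys, q, qd and act are built as in the physics pipeline init/step example.
init_jit = jax.jit(gen_pipeline.init)
step_jit = jax.jit(gen_pipeline.step)
state = init_jit(sys, q, qd)
for _ in range(100):
  state = step_jit(sys, state, act)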
```
--------------------------------
### Get Logits and Action
Source: https://github.com/google/brax/blob/main/notebooks/training_torch.ipynb
Normalizes observations, passes them through the policy network to obtain action logits, and samples an action from the resulting distribution. The `@torch.jit.export` decorator marks the method to be compiled as an entry point when the module is scripted with TorchScript.
```python
@torch.jit.export
def get_logits_action(self, observation):
  observation = self.normalize(observation)
  logits = self.policy(observation)
  loc, scale = self.dist_create(logits)
  action = self.dist_sample_no_postprocess(loc, scale)
  return logits, action
```
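`get_logits_action` relies on helper methods that this excerpt does not show. The NumPy sketch below shows what they plausibly do: split the logits into a Normal's location and a softplus scale, sample, and squash with tanh. These definitions are assumptions, not the notebook's code. Keeping the unsquashed sample lets a log-probability be evaluated under the Gaussian before the tanh is applied.

```python
import numpy as np

def dist_create(logits):
  """Splits logits into the Normal's loc and a positive scale."""
  loc, raw_scale = np.split(logits, 2, axis=-1)
  scale = np.logaddexp(0.0, raw_scale) + 0.001  # numerically stable softplus + floor
  return loc, scale

def dist_sample_no_postprocess(loc, scale, rng):
  """Draws a raw (pre-tanh) sample from Normal(loc, scale)."""
  return rng.normal(loc, scale)

def dist_postprocess(raw_action):
  """Squashes raw samples into [-1, 1]."""
  return np.tanh(raw_action)

logits = np.array([[0.0, 0.5, 0.0, -0.5]])  # batch of 1, action_size 2
loc, scale = dist_create(logits)
raw_action = dist_sample_no_postprocess(loc, scale, np.random.default_rng(0))
action = dist_postprocess(raw_action)
print(loc, scale, action)
```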
--------------------------------
### Load Environment and Backend
Source: https://github.com/google/brax/blob/main/notebooks/training.ipynb
Selects an environment and a physics engine backend for training. The backend affects the trade-off between physical realism and training speed.
```python
#@title Load Env { run: "auto" }
env_name = 'ant' # @param ['ant', 'halfcheetah', 'hopper', 'humanoid', 'humanoidstandup', 'inverted_pendulum', 'inverted_double_pendulum', 'pusher', 'reacher', 'walker2d']
backend = 'positional' # @param ['generalized', 'positional', 'spring']
env = envs.get_environment(env_name=env_name,
                           backend=backend)
state = jax.jit(env.reset)(rng=jax.random.PRNGKey(seed=0))
HTML(html.render(env.sys, [state.pipeline_state]))
```
--------------------------------
### Reset Quadruped Environment State in Brax
Source: https://github.com/google/brax/blob/main/brax/experimental/barkour/tutorial.ipynb
Initializes the environment state, including pipeline state, observations, and reward information. Ensures a clean start for simulations.
```python
def reset(self, rng: jax.Array) -> State:  # pytype: disable=signature-mismatch
  rng, key = jax.random.split(rng)
  pipeline_state = self.pipeline_init(self._init_q, jp.zeros(self._nv))
  state_info = {
      'rng': rng,
      'last_act': jp.zeros(12),
      'last_vel': jp.zeros(12),
      'command': self.sample_command(key),
      'last_contact': jp.zeros(4, dtype=bool),
      'feet_air_time': jp.zeros(4),
      'rewards': {k: 0.0 for k in self.reward_config.rewards.scales.keys()},
      'kick': jp.array([0.0, 0.0]),
      'step': 0,
  }
  obs_history = jp.zeros(15 * 31)  # store 15 steps of history
  obs = self._get_obs(pipeline_state, state_info, obs_history)
  reward, done = jp.zeros(2)
  metrics = {'total_dist': 0.0}
  for k in state_info['rewards']:
    metrics[k] = state_info['rewards'][k]
  state = State(pipeline_state, obs, reward, done, metrics, state_info)  # pytype: disable=wrong-arg-types
  return state
```
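`reset` allocates a flat buffer that holds 15 observations of 31 values each. The tutorial's `_get_obs` is not shown here, so the update rule below is an assumption: a NumPy sketch of a newest-first rolling history, equivalent to `jp.roll(obs_history, obs.size).at[:obs.size].set(obs)` in JAX.

```python
import numpy as np

OBS_SIZE = 31     # per-step observation size implied by jp.zeros(15 * 31)
HISTORY_LEN = 15  # number of stacked steps

def push_obs(obs_history, obs):
  """Newest observation first; drops the oldest OBS_SIZE entries."""
  return np.concatenate([obs, obs_history[:-obs.size]])

history = np.zeros(HISTORY_LEN * OBS_SIZE)
history = push_obs(history, np.full(OBS_SIZE, 1.0))
history = push_obs(history, np.full(OBS_SIZE, 2.0))
print(history[:3], history[OBS_SIZE:OBS_SIZE + 3])
```

A fixed-size, zero-initialized buffer keeps array shapes constant from the first step, which is what JIT compilation requires.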
--------------------------------
### Import Plotting and Graphics Packages
Source: https://github.com/google/brax/blob/main/brax/experimental/barkour/tutorial.ipynb
Imports necessary libraries for plotting and creating graphics, including mediapy for media handling and matplotlib for plotting. Ensures ffmpeg is installed.
```python
#@title Import packages for plotting and creating graphics
import time
import itertools
import numpy as np
from typing import Callable, NamedTuple, Optional, Union, List
# Graphics and plotting.
print('Installing mediapy:')
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)
!pip install -q mediapy
import mediapy as media
import matplotlib.pyplot as plt
# More legible printing from numpy.
np.set_printoptions(precision=3, suppress=True, linewidth=100)
```
--------------------------------
### Physics Pipelines - init and step
Source: https://context7.com/google/brax/llms.txt
All four Brax physics pipelines (`generalized`, `positional`, `spring`, `mjx`) expose `init` and `step` functions, which are pure JAX functions, JIT-compilable, and differentiable.
```python
import jax
import jax.numpy as jp
from brax.io import mjcf
from brax.generalized import pipeline as gen_pipeline # or:
# from brax.positional import pipeline as pos_pipeline
# from brax.spring import pipeline as spr_pipeline
# from brax.mjx import pipeline as mjx_pipeline
sys = mjcf.load('path/to/ant.xml')
# Initialize state from q (positions) and qd (velocities)
q = sys.init_q # default initial positions
qd = jp.zeros(sys.qd_size()) # zero velocities
state = gen_pipeline.init(sys, q, qd)
# Single physics step
act = jp.zeros(sys.act_size()) # zero action
state = gen_pipeline.step(sys, state, act)
```
--------------------------------
### Initialize and Wrap Environment for PyTorch
Source: https://github.com/google/brax/blob/main/notebooks/training_torch.ipynb
Sets up a vectorized Brax environment and wraps it with TorchWrapper for PyTorch compatibility. Ensures the environment is ready for agent interaction.
```python
# `env` must be a batched Brax environment, e.g. created with envs.create(..., batch_size=...).
env = gym_wrapper.VectorGymWrapper(env)
# automatically convert between jax ndarrays and torch tensors:
env = torch_wrapper.TorchWrapper(env, device=device)
# env warmup
env.reset()
action = torch.zeros(env.action_space.shape).to(device)
env.step(action)
```
--------------------------------
### Import Core Brax, MJX, and MuJoCo Packages
Source: https://github.com/google/brax/blob/main/brax/experimental/barkour/tutorial.ipynb
Imports the main components of Brax, MJX, and MuJoCo, including environment definitions, training utilities, and IO modules. This is essential for setting up and running simulations.
```python
#@title Import MuJoCo, MJX, and Brax
from datetime import datetime
import functools
from IPython.display import HTML
import jax
from jax import numpy as jp
import numpy as np
from typing import Any, Dict, Sequence, Tuple, Union
from brax import base
from brax import envs
from brax import math
from brax.base import Base, Motion, Transform
from brax.envs.base import Env, PipelineEnv, State
from brax.mjx.base import State as MjxState
from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from brax.io import html, mjcf, model
from etils import epath
from flax import struct
from matplotlib import pyplot as plt
import mediapy as media
from ml_collections import config_dict
import mujoco
from mujoco import mjx
```
--------------------------------
### Define Training Hyperparameters with functools.partial
Source: https://github.com/google/brax/blob/main/notebooks/training.ipynb
Defines training configurations for different environments using functools.partial. These configurations specify hyperparameters for PPO or SAC training.
```python
train_fn = {
'inverted_pendulum': functools.partial(ppo.train, num_timesteps=2_000_000, num_evals=20, reward_scaling=10, episode_length=1000, normalize_observations=True, action_repeat=1, unroll_length=5, num_minibatches=32, num_updates_per_batch=4, discounting=0.97, learning_rate=3e-4, entropy_cost=1e-2, num_envs=2048, batch_size=1024, seed=1),
'inverted_double_pendulum': functools.partial(ppo.train, num_timesteps=20_000_000, num_evals=20, reward_scaling=10, episode_length=1000, normalize_observations=True, action_repeat=1, unroll_length=5, num_minibatches=32, num_updates_per_batch=4, discounting=0.97, learning_rate=3e-4, entropy_cost=1e-2, num_envs=2048, batch_size=1024, seed=1),
'ant': functools.partial(ppo.train, num_timesteps=50_000_000, num_evals=10, reward_scaling=10, episode_length=1000, normalize_observations=True, action_repeat=1, unroll_length=5, num_minibatches=32, num_updates_per_batch=4, discounting=0.97, learning_rate=3e-4, entropy_cost=1e-2, num_envs=4096, batch_size=2048, seed=1),
'humanoid': functools.partial(ppo.train, num_timesteps=50_000_000, num_evals=10, reward_scaling=0.1, episode_length=1000, normalize_observations=True, action_repeat=1, unroll_length=10, num_minibatches=32, num_updates_per_batch=8, discounting=0.97, learning_rate=3e-4, entropy_cost=1e-3, num_envs=2048, batch_size=1024, seed=1),
'reacher': functools.partial(ppo.train, num_timesteps=50_000_000, num_evals=20, reward_scaling=5, episode_length=1000, normalize_observations=True, action_repeat=4, unroll_length=50, num_minibatches=32, num_updates_per_batch=8, discounting=0.95, learning_rate=3e-4, entropy_cost=1e-3, num_envs=2048, batch_size=256, max_devices_per_host=8, seed=1),
'humanoidstandup': functools.partial(ppo.train, num_timesteps=100_000_000, num_evals=20, reward_scaling=0.1, episode_length=1000, normalize_observations=True, action_repeat=1, unroll_length=15, num_minibatches=32, num_updates_per_batch=8, discounting=0.97, learning_rate=6e-4, entropy_cost=1e-2, num_envs=2048, batch_size=1024, seed=1),
'hopper': functools.partial(sac.train, num_timesteps=6_553_600, num_evals=20, reward_scaling=30, episode_length=1000, normalize_observations=True, action_repeat=1, discounting=0.997, learning_rate=6e-4, num_envs=128, batch_size=512, grad_updates_per_step=64, max_devices_per_host=1, max_replay_size=1048576, min_replay_size=8192, seed=1),
'walker2d': functools.partial(sac.train, num_timesteps=7_864_320, num_evals=20, reward_scaling=5, episode_length=1000, normalize_observations=True, action_repeat=1, discounting=0.997, learning_rate=6e-4, num_envs=128, batch_size=128, grad_updates_per_step=32, max_devices_per_host=1, max_replay_size=1048576, min_replay_size=8192, seed=1),
'halfcheetah': functools.partial(ppo.train, num_timesteps=50_000_000, num_evals=20, reward_scaling=1, episode_length=1000, normalize_observations=True, action_repeat=1, unroll_length=20, num_minibatches=32, num_updates_per_batch=8, discounting=0.95, learning_rate=3e-4, entropy_cost=0.001, num_envs=2048, batch_size=512, seed=3),
'pusher': functools.partial(ppo.train, num_timesteps=50_000_000, num_evals=20, reward_scaling=5, episode_length=1000, normalize_observations=True, action_repeat=1, unroll_length=30, num_minibatches=16, num_updates_per_batch=8, discounting=0.95, learning_rate=3e-4, entropy_cost=1e-2, num_envs=2048, batch_size=512, seed=3),
}[env_name]
```
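To get a feel for what these numbers mean, the sketch below reproduces the step bookkeeping that, as far as we can tell from `brax/training/agents/ppo/train.py`, `ppo.train` performs: each training step consumes `batch_size * unroll_length * num_minibatches * action_repeat` environment steps, and training is split into `num_evals - 1` epochs between evaluations. The helper name `ppo_schedule` is hypothetical, and we ignore `num_resets_per_eval` (treated as 1) to keep it short; check against your Brax version before relying on it.

```python
def ppo_schedule(num_timesteps, num_evals, batch_size, unroll_length,
                 num_minibatches, action_repeat, num_envs):
    """Hypothetical helper mirroring PPO's step bookkeeping (simplified)."""
    # Environment steps consumed by a single PPO training step.
    env_steps_per_training_step = (
        batch_size * unroll_length * num_minibatches * action_repeat)
    # The data collected per training step must split evenly across the envs.
    if (batch_size * num_minibatches) % num_envs != 0:
        raise ValueError('batch_size * num_minibatches must be divisible by num_envs')
    # Evaluations after the initial one; training is split into this many epochs.
    num_evals_after_init = max(num_evals - 1, 1)
    denom = num_evals_after_init * env_steps_per_training_step
    # Integer ceiling division: training steps per epoch.
    training_steps_per_epoch = -(-num_timesteps // denom)
    # Because of the ceiling, this simplified model can overshoot num_timesteps.
    total_env_steps = training_steps_per_epoch * denom
    return env_steps_per_training_step, training_steps_per_epoch, total_env_steps


# 'ant' settings from the table above.
print(ppo_schedule(num_timesteps=50_000_000, num_evals=10, batch_size=2048,
                   unroll_length=5, num_minibatches=32, action_repeat=1,
                   num_envs=4096))
# (327680, 17, 50135040)
```

The same arithmetic for `reacher` (`action_repeat=4`, `unroll_length=50`) gives 1,638,400 environment steps per training step and only 2 training steps per epoch, so in this model the run overshoots 50M steps by roughly a quarter.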
--------------------------------
### Brax Environment Wrappers
Source: https://context7.com/google/brax/llms.txt
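Demonstrates `AutoResetWrapper`, `EvalWrapper`, and `DomainRandomizationVmapWrapper` for environment customization. `training.wrap` is a convenience function that composes the standard training stack: `VmapWrapper` (or `DomainRandomizationVmapWrapper` when `randomization_fn` is given), then `EpisodeWrapper`, then `AutoResetWrapper`.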
```python
# ep_env: an environment already wrapped in training.EpisodeWrapper
ar_env = training.AutoResetWrapper(ep_env)
```
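As we read the Brax source, `AutoResetWrapper` caches the first observation and pipeline state at `reset` time and, after each `step`, swaps them in for every environment whose `done` flag is set, while leaving `done` itself at 1 so the learner still sees the episode boundary. The NumPy sketch below shows only the per-environment selection for observations; `auto_reset_obs` is a hypothetical helper, and the real wrapper applies the same `where` across the whole pipeline-state pytree with `jax.numpy`.

```python
import numpy as np


def auto_reset_obs(done, first_obs, next_obs):
    """Replace observations of finished envs with their cached initial obs."""
    # Reshape done from (batch,) to (batch, 1, ...) so it broadcasts over obs dims.
    mask = done.reshape((done.shape[0],) + (1,) * (next_obs.ndim - 1)).astype(bool)
    return np.where(mask, first_obs, next_obs)


first_obs = np.zeros((3, 2))                        # cached at reset
next_obs = np.arange(6, dtype=float).reshape(3, 2)  # produced by env.step
done = np.array([0.0, 1.0, 0.0])                    # only env 1 finished
print(auto_reset_obs(done, first_obs, next_obs))
# [[0. 1.]
#  [0. 0.]
#  [4. 5.]]
```

If this reading is right, one consequence is that each environment slot restarts from the same cached initial state rather than drawing a fresh one on every reset.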
```python
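# EvalWrapper reads info['steps'], so apply the training stack (vmap + episode + auto-reset) first
eval_env = training.EvalWrapper(training.wrap(base_env, episode_length=1000))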
```
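`EvalWrapper` records per-episode metrics for evaluation. Our understanding from the Brax source is that it keeps an `active_episodes` mask and adds each step's reward only while that mask is 1, so every environment contributes the return of its first episode even though auto-reset keeps it stepping. A simplified NumPy sketch of that masking, using the hypothetical helper `first_episode_returns`:

```python
import numpy as np


def first_episode_returns(rewards, dones):
    """Sum each env's reward over its first episode only.

    rewards, dones: arrays of shape (num_steps, batch).
    """
    num_steps, batch = rewards.shape
    returns = np.zeros(batch)
    active = np.ones(batch)
    for t in range(num_steps):
        returns += rewards[t] * active    # accumulate only while still active
        active = active * (1 - dones[t])  # freeze an env once it reports done
    return returns


rewards = np.ones((4, 2))
dones = np.array([[0, 0], [1, 0], [0, 0], [0, 0]], dtype=float)
print(first_episode_returns(rewards, dones))  # [2. 4.]
```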
```python
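import functools

import jax
from brax.envs.wrappers import training

def rand_fn(sys, rng):
  # Scale link masses randomly per environment in the batch
  batch = 8
  rngs = jax.random.split(rng, batch)
  scales = jax.vmap(lambda r: jax.random.uniform(r, minval=0.9, maxval=1.1))(rngs)
  new_mass = sys.link.inertia.mass * scales[:, None]  # (batch, num_links)
  sys_v = sys.tree_replace({'link.inertia.mass': new_mass})
  # vmap only over axis 0 of the randomized field; all other fields are shared
  in_axes = jax.tree_util.tree_map(lambda _: None, sys)
  in_axes = in_axes.tree_replace({'link.inertia.mass': 0})
  return sys_v, in_axes

# The wrapper calls randomization_fn(sys), so bind rng beforehand
dr_env = training.DomainRandomizationVmapWrapper(
    base_env, functools.partial(rand_fn, rng=jax.random.PRNGKey(0)))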
```
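The shapes in `rand_fn` are the part that is easy to get wrong. The NumPy sketch below repeats the same arithmetic on a hypothetical 4-link mass vector: `scales[:, None]` turns the `(batch,)` scale vector into a column, so broadcasting produces a `(batch, num_links)` array with one row of masses per environment. Axis 0 of that array is the one `in_axes` has to mark for `jax.vmap`. The mass values are made up for illustration.

```python
import numpy as np

base_mass = np.array([1.0, 2.0, 0.5, 0.5])  # hypothetical per-link masses
batch = 8
rng = np.random.default_rng(0)
scales = rng.uniform(0.9, 1.1, size=batch)  # one scale factor per environment
new_mass = base_mass * scales[:, None]      # (4,) * (8, 1) -> (8, 4)
print(new_mass.shape)  # (8, 4)
# Each row is the base mass vector scaled by that environment's factor.
print(np.allclose(new_mass / base_mass, scales[:, None]))  # True
```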
```python
wrapped = training.wrap(
base_env,
episode_length=1000,
action_repeat=1,
randomization_fn=None,  # or a rand_fn with its rng argument already bound
)
```
--------------------------------
### Initialize PyTorch Agent and Optimizer
Source: https://github.com/google/brax/blob/main/notebooks/training_torch.ipynb
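Defines the layer sizes for the agent's policy and value networks, builds the agent, compiles it with TorchScript (`torch.jit.script`), and sets up an Adam optimizer over its parameters. The policy's final layer has twice as many outputs as there are action dimensions: one location and one scale per dimension.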
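```python
import torch
from torch import optim

# env, device, Agent and the hyperparameters come from earlier notebook cells
policy_layers = [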
env.observation_space.shape[-1], 64, 64, env.action_space.shape[-1] * 2
]
value_layers = [env.observation_space.shape[-1], 64, 64, 1]
agent = Agent(policy_layers, value_layers, entropy_cost, discounting,
reward_scaling, device)
agent = torch.jit.script(agent.to(device))
optimizer = optim.Adam(agent.parameters(), lr=learning_rate)
```
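The policy network's last layer has `2 * action_size` outputs because it parameterizes a Gaussian per action dimension. As we recall the notebook's `Agent.dist_create`, the output is split in half into a location and a raw scale, the scale is passed through softplus plus a small constant (0.001) to keep it positive, and sampled actions are squashed with `tanh`. The NumPy sketch below mimics that split on a random MLP with the same `[obs, 64, 64, 2 * act]` layout; the sizes and helper names are hypothetical, and this is not the notebook's TorchScript code.

```python
import numpy as np


def mlp_forward(x, layer_sizes, rng):
    """Randomly initialised dense MLP with ReLU between hidden layers (illustrative)."""
    num_layers = len(layer_sizes) - 1
    for i in range(num_layers):
        w = rng.normal(scale=1.0 / np.sqrt(layer_sizes[i]),
                       size=(layer_sizes[i], layer_sizes[i + 1]))
        x = x @ w
        if i < num_layers - 1:  # no activation on the output layer
            x = np.maximum(x, 0.0)
    return x


def split_policy_output(logits, min_scale=0.001):
    """Split [loc | raw_scale] along the last axis; softplus keeps scale > 0."""
    loc, raw_scale = np.split(logits, 2, axis=-1)
    scale = np.logaddexp(0.0, raw_scale) + min_scale  # softplus(x) = log(1 + e^x)
    return loc, scale


obs_size, action_size, batch = 8, 3, 4  # hypothetical sizes
policy_layers = [obs_size, 64, 64, action_size * 2]
rng = np.random.default_rng(0)
logits = mlp_forward(rng.normal(size=(batch, obs_size)), policy_layers, rng)
loc, scale = split_policy_output(logits)
# Sample a pre-squash action, then squash it into [-1, 1] with tanh.
action = np.tanh(loc + scale * rng.normal(size=loc.shape))
print(logits.shape, loc.shape, scale.shape)  # (4, 6) (4, 3) (4, 3)
```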
--------------------------------
### Training Wrappers
Source: https://context7.com/google/brax/llms.txt
Introduces training wrappers from `brax.envs.wrappers.training`, such as `VmapWrapper` for vectorization and `EpisodeWrapper` for episode management, which can be composed over existing environments.
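Conceptually, `VmapWrapper` maps a single-environment `reset`/`step` over a batch with `jax.vmap`, which is why a batched reset returns observations with a leading batch dimension. A minimal JAX sketch of that idea, with a hypothetical `toy_reset` standing in for `base_env.reset` (it is not a Brax API):

```python
import jax

OBS_SIZE = 5  # hypothetical observation size


def toy_reset(rng):
    """Hypothetical single-env reset: returns one random observation."""
    return jax.random.normal(rng, (OBS_SIZE,))


# vmap maps over the leading axis of the rng batch; jit compiles the batched fn.
batched_reset = jax.jit(jax.vmap(toy_reset))
rngs = jax.random.split(jax.random.PRNGKey(0), 64)
obs = batched_reset(rngs)
print(obs.shape)  # (64, 5)
```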
Training wrappers in `brax.envs.wrappers.training` compose over any `Env` to add episode truncation, vectorization, auto-reset, domain randomization, and evaluation metrics.
```python
from brax import envs
from brax.envs.wrappers import training
import jax
import jax.numpy as jp
base_env = envs.get_environment('humanoid', backend='mjx')
# VmapWrapper: vectorize across a batch of environments
batched_env = training.VmapWrapper(base_env)
rngs = jax.random.split(jax.random.PRNGKey(0), 64)
states = jax.jit(batched_env.reset)(rngs)
print(states.obs.shape)  # (64, observation_size)
# EpisodeWrapper: add step counter + truncation signal
ep_env = training.EpisodeWrapper(base_env, episode_length=1000, action_repeat=2)
```
--------------------------------
### Save and Reload Policy Parameters
Source: https://github.com/google/brax/blob/main/brax/experimental/barkour/tutorial.ipynb
Demonstrates how to save trained policy parameters to a file and then load them back. This is useful for resuming training or deploying a trained policy.
```python
import jax
from brax.io import model

# Save and reload params (params and make_inference_fn come from training).
model_path = '/tmp/mjx_brax_quadruped_policy'
model.save_params(model_path, params)
params = model.load_params(model_path)
inference_fn = make_inference_fn(params)
jit_inference_fn = jax.jit(inference_fn)
```
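As far as we know, `brax.io.model.save_params` and `load_params` simply pickle the parameter pytree to and from the given path, so any picklable structure round-trips. The standard-library sketch below illustrates that idea with hypothetical helpers and a made-up parameter tuple; it is not Brax's code. Because unpickling can execute arbitrary code, only load parameter files you trust.

```python
import os
import pickle
import tempfile

import numpy as np


def save_params_sketch(path, params):
    """Hypothetical stand-in: serialize a params pytree with pickle."""
    with open(path, 'wb') as f:
        f.write(pickle.dumps(params))


def load_params_sketch(path):
    """Hypothetical stand-in: read back a pickled params pytree."""
    with open(path, 'rb') as f:
        return pickle.loads(f.read())


# Made-up structure: (normalizer statistics, policy weights).
params = ({'mean': np.zeros(3), 'std': np.ones(3)},
          {'dense_0': {'kernel': np.eye(3), 'bias': np.zeros(3)}})
path = os.path.join(tempfile.mkdtemp(), 'policy_params')
save_params_sketch(path, params)
restored = load_params_sketch(path)
print(np.array_equal(restored[1]['dense_0']['kernel'], np.eye(3)))  # True
```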
--------------------------------
### Visualize Policy with Joystick Commands
Source: https://github.com/google/brax/blob/main/brax/experimental/barkour/tutorial.ipynb
Visualizes the trained policy by controlling the quadruped with joystick-like commands (x_vel, y_vel, ang_vel). It initializes the state, rolls out a trajectory, and renders the video.
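```python
import jax
from jax import numpy as jp
import mediapy as media
from brax import envs

# env_name and jit_inference_fn come from earlier cells of the tutorial
eval_env = envs.get_environment(env_name)
jit_reset = jax.jit(eval_env.reset)
jit_step = jax.jit(eval_env.step)

# @markdown Commands **only used for Barkour Env**: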
x_vel = 1.0 #@param {type: "number"}
y_vel = 0.0 #@param {type: "number"}
ang_vel = -0.5 #@param {type: "number"}
the_command = jp.array([x_vel, y_vel, ang_vel])
# initialize the state
rng = jax.random.PRNGKey(0)
state = jit_reset(rng)
state.info['command'] = the_command
rollout = [state.pipeline_state]
# grab a trajectory
n_steps = 500
render_every = 2
for i in range(n_steps):
  act_rng, rng = jax.random.split(rng)
  ctrl, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_step(state, ctrl)
  rollout.append(state.pipeline_state)
media.show_video(
eval_env.render(rollout[::render_every], camera='track'),
fps=1.0 / eval_env.dt / render_every)
```
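The `fps` argument keeps playback at real time: each rendered frame stands for `render_every` control steps of length `eval_env.dt`. A quick check of the arithmetic, assuming a control timestep of 0.02 s (a hypothetical value; read the real one from `eval_env.dt`):

```python
dt = 0.02          # assumed control timestep in seconds (hypothetical)
n_steps = 500
render_every = 2

num_states = n_steps + 1                             # initial state + one per step
num_frames = len(range(num_states)[::render_every])  # frames actually rendered
fps = 1.0 / dt / render_every
duration = num_frames / fps
print(num_frames, fps, round(duration, 2))  # 251 25.0 10.04
```

The simulated time is `n_steps * dt`, 10 s here, so the video plays back at essentially real speed.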
--------------------------------
### Initialize Brax Viewer
Source: https://github.com/google/brax/blob/main/brax/visualizer/index.html
Initializes the Brax viewer with the given DOM element and system data. The system data arrives as a base64-encoded, gzipped JSON string, which the script decodes, decompresses with `pako`, and parses before constructing the `Viewer`.
```javascript
import {Viewer} from 'viewer';

var system = "{{ system_json_b64 }}";
// decode base64 (convert ascii to binary)
system = atob(system);
// convert binary string to character-number array
system = system.split('').map(function(x){return x.charCodeAt(0);});
// decompress
system = pako.inflate(system);
// convert gunzipped byteArray back to ascii string
system = new TextDecoder("utf-8").decode(system);
// and load json
system = JSON.parse(system);

const domElement = document.getElementById("brax-viewer");
var viewer = new Viewer(domElement, system);
```
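For reference, the Python side of that pipeline would serialize the system to JSON, gzip it, base64-encode it, and substitute the result for the `{{ system_json_b64 }}` placeholder. The sketch below uses only the standard library; `encode_system`, `decode_system`, and the dictionary contents are hypothetical, and we have not checked that `brax.io.html` uses exactly these calls. To our knowledge `pako.inflate` auto-detects the gzip header, so it can read `gzip.compress` output.

```python
import base64
import gzip
import json


def encode_system(system):
    """JSON -> UTF-8 bytes -> gzip -> base64 ASCII string."""
    raw = json.dumps(system).encode('utf-8')
    return base64.b64encode(gzip.compress(raw)).decode('ascii')


def decode_system(system_b64):
    """Inverse path, mirroring atob -> pako.inflate -> TextDecoder -> JSON.parse."""
    raw = gzip.decompress(base64.b64decode(system_b64))
    return json.loads(raw.decode('utf-8'))


system = {'bodies': [{'name': 'torso', 'mass': 1.0}], 'dt': 0.02}  # hypothetical
encoded = encode_system(system)
print(decode_system(encoded) == system)  # True
```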
--------------------------------
### Adding Episode Management with `EpisodeWrapper`
Source: https://context7.com/google/brax/llms.txt
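Wrap an environment with `EpisodeWrapper` to count steps within an episode, repeat each action `action_repeat` times, and signal truncation when the episode length is reached. Downstream components read the `steps` and `truncation` entries it adds to `state.info`: for example, `AutoResetWrapper` resets the step counter, and PPO uses `truncation` when bootstrapping value estimates.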
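The sketch below paraphrases, as we understand it from the Brax source, the bookkeeping `EpisodeWrapper.step` does after running the inner environment `action_repeat` times (the real wrapper also sums the rewards of those repeated sub-steps). `steps` advances by `action_repeat`, `done` is forced to 1 once `steps >= episode_length`, and `truncation` is 1 only for environments that hit the limit without terminating on their own. `episode_bookkeeping` is a hypothetical NumPy helper, not Brax's API. With `episode_length=1000` and `action_repeat=2`, an episode that never terminates early ends after 500 agent steps.

```python
import numpy as np


def episode_bookkeeping(steps, env_done, episode_length, action_repeat):
    """Hypothetical helper: update step count, done, and truncation flags."""
    steps = steps + action_repeat
    hit_limit = steps >= episode_length
    done = np.where(hit_limit, 1.0, env_done)
    # Truncated = stopped by the time limit, not by the env's own termination.
    truncation = np.where(hit_limit, 1.0 - env_done, 0.0)
    return steps, done, truncation


# Two envs at step 998: env 0 is still running, env 1 terminated on its own.
steps, done, truncation = episode_bookkeeping(
    np.array([998, 998]), np.array([0.0, 1.0]), episode_length=1000, action_repeat=2)
print(steps, done, truncation)  # [1000 1000] [1. 1.] [1. 0.]

# Count agent steps for an episode that never terminates early.
agent_steps, s, d = 0, np.array([0]), np.array([0.0])
while not d[0]:
    s, d, _ = episode_bookkeeping(s, np.array([0.0]), 1000, 2)
    agent_steps += 1
print(agent_steps)  # 500
```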
```python
from brax import envs
from brax.envs.wrappers import training
import jax
import jax.numpy as jp
base_env = envs.get_environment('humanoid', backend='mjx')
# EpisodeWrapper: add step counter + truncation signal
ep_env = training.EpisodeWrapper(base_env, episode_length=1000, action_repeat=2)
```