### Build and Serve Diffrax Documentation

Source: https://github.com/patrick-kidger/diffrax/blob/main/CONTRIBUTING.md

Installs the documentation dependencies and starts a local server to preview documentation changes.

```bash
pip install -e '.[docs]'
mkdocs serve
```

--------------------------------

### Adaptive Stepping for SDEs with HalfSolver and PIDController

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/usage/getting-started.md

This code illustrates adaptive stepping for SDEs in Diffrax: a base solver (e.g. Euler) is wrapped in HalfSolver to obtain error estimates, and a PIDController manages the step size. This is useful for solving SDEs whose step-size requirements vary over the integration interval. A complete, runnable sketch is given below, after the ODE example.

```Python
solver = ...  # Euler, Heun etc. as usual
solver = HalfSolver(solver)  # Computes error estimates using half-steps
stepsize_controller = PIDController(pcoeff=0.1, icoeff=0.3, rtol=..., atol=...)
```

--------------------------------

### Solve ODE with Diffrax

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/usage/getting-started.md

Solves the ODE dy/dt = -y with y(0) = 1 over the interval [0, 3] using the Dopri5 solver and adaptive step sizing. The solution is saved at specific time points.

```Python
from diffrax import diffeqsolve, Dopri5, ODETerm, SaveAt, PIDController

vector_field = lambda t, y, args: -y
term = ODETerm(vector_field)
solver = Dopri5()
saveat = SaveAt(ts=[0., 1., 2., 3.])
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)

sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
                  stepsize_controller=stepsize_controller)

print(sol.ts)  # DeviceArray([0.   , 1.   , 2.   , 3.   ])
print(sol.ys)  # DeviceArray([1.   , 0.368, 0.135, 0.0498])
```

--------------------------------
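The HalfSolver/PIDController fragment above elides the actual solve call. The following is a minimal sketch of how the pieces could fit together for an SDE, assuming an illustrative drift, diffusion, and tolerances (none of which come from the original documentation).

```Python
import jax.random as jr
from diffrax import (
    ControlTerm, Euler, HalfSolver, MultiTerm, ODETerm,
    PIDController, VirtualBrownianTree, diffeqsolve,
)

t0, t1 = 0, 3
drift = lambda t, y, args: -y            # illustrative drift
diffusion = lambda t, y, args: 0.1 * t   # illustrative diffusion
brownian_motion = VirtualBrownianTree(t0, t1, tol=1e-3, shape=(), key=jr.PRNGKey(0))
terms = MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))

solver = HalfSolver(Euler())  # wrap the base solver to obtain error estimates
stepsize_controller = PIDController(pcoeff=0.1, icoeff=0.3, rtol=1e-3, atol=1e-3)
sol = diffeqsolve(
    terms, solver, t0, t1, dt0=0.05, y0=1.0,
    stepsize_controller=stepsize_controller,
)
```

--------------------------------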
### Main Function for Kalman Filter Simulation and Optimization

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/kalman_filter.ipynb

The main entry point for the Kalman Filter example. It sets up a simulated linear system, initializes a Kalman Filter, and optionally performs gradient-based optimization to tune the Q and R matrices. It also includes plotting functionality to visualize the results.

```Python
def main(
    # evaluate at these timepoints
    ts=jnp.arange(0, 5.0, 0.01),
    # system that generates data
    sys_true=harmonic_oscillator(0.3),
    # initial state of our data generating system
    sys_true_x0=jnp.array([1.0, 0.0]),
    # standard deviation of measurement noise
    sys_true_std_measurement_noise=1.0,
    # our model for system `true`, it's not perfect
    sys_model=harmonic_oscillator(0.7),
    # initial state guess, it's not perfect
    sys_model_x0=jnp.array([0.0, 0.0]),
    # weighs how much we trust our model of the system
    Q_root=jnp.diag(jnp.ones((2,))) * 0.1,
    # weighs how much we trust in the measurements of the system
    R_root=jnp.diag(jnp.ones((1,))),
    # weighs how much we trust our initial guess
    P0=jnp.diag(jnp.ones((2,))) * 10.0,
    plot=True,
    n_gradient_steps=0,
    print_every=10,
):
    xs, ys = simulate_lti_system(
        sys_true, sys_true_x0, ts, std_measurement_noise=sys_true_std_measurement_noise
    )
    kmf = KalmanFilter(sys_model, sys_model_x0, P0, Q_root, R_root)

    initial_Q = kmf.Q_root.T @ kmf.Q_root
    initial_R = kmf.R_root.T @ kmf.R_root
    print(f"Initial Q: \n{initial_Q}\n Initial R: \n{initial_R}")

    # gradients should only be able to change Q/R parameters
    # *not* the model (well at least not in this example :)
    filter_spec = jtu.tree_map(lambda arr: False, kmf)
    filter_spec = eqx.tree_at(
        lambda tree: (tree.Q_root, tree.R_root), filter_spec, replace=(True, True)
    )

    opt = optax.adam(1e-2)
    opt_state = opt.init(kmf)

    @eqx.filter_value_and_grad
    def loss_fn(dynamic_kmf, static_kmf, ts, ys, xs):
        kmf = eqx.combine(dynamic_kmf, static_kmf)
        xhats = kmf(ts, ys)
        return jnp.mean((xs - xhats) ** 2)

    @eqx.filter_jit
    def make_step(kmf, opt_state, ts, ys, xs):
        dynamic_kmf, static_kmf = eqx.partition(kmf, filter_spec)
        value, grads = loss_fn(dynamic_kmf, static_kmf, ts, ys, xs)
        updates, opt_state = opt.update(grads, opt_state)
        kmf = eqx.apply_updates(kmf, updates)
        return value, kmf, opt_state

    for step in range(n_gradient_steps):
        value, kmf, opt_state = make_step(kmf, opt_state, ts, ys, xs)
        if step % print_every == 0:
            print("Current MSE: ", value)

    final_Q = kmf.Q_root.T @ kmf.Q_root
    final_R = kmf.R_root.T @ kmf.R_root
    print(f"Final Q: \n{final_Q}\n Final R: \n{final_R}")

    if plot:
        xhats = kmf(ts, ys)
        plt.plot(ts, xs[:, 0], label="true position", color="orange")
        plt.plot(
            ts,
            xhats[:, 0],
            label="estimated position",
            color="orange",
            linestyle="dashed",
        )
        plt.plot(ts, xs[:, 1], label="true velocity", color="blue")
        plt.plot(
            ts,
            xhats[:, 1],
            label="estimated velocity",
            color="blue",
            linestyle="dashed",
        )
        plt.xlabel("time")
        plt.ylabel("position / velocity")
        plt.grid()
        plt.legend()
        plt.title("Kalman-Filter optimization w.r.t Q/R")
```

--------------------------------

### Running Diffrax Main Function with Different Inputs

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/continuous_normalising_flow.ipynb

This snippet demonstrates how to execute the main function of the continuous normalising flow example with various input parameters. It shows examples of specifying the input image path, the number of blocks, and the width size, allowing users to reproduce specific visualizations.

```Python
main(in_path="../imgs/cat.png")
main(in_path="../imgs/butterfly.png", num_blocks=3)
main(in_path="../imgs/target.png", width_size=128)
```

--------------------------------

### Demonstration of Diffrax Main Function Execution

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/continuous_normalising_flow.ipynb

This is a specific example demonstrating the execution of the main function with a particular input path, '../imgs/cat.png'.
This command is used to reproduce one of the visualizations shown at the beginning of the documentation.

```Python
main(in_path="../imgs/cat.png")
```

--------------------------------

### Import Libraries for Symbolic Regression

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/symbolic_regression.ipynb

Imports necessary libraries for the symbolic regression example, including Equinox for neural networks, JAX for automatic differentiation and numerical computation, Optax for optimization, PySR for symbolic regression, SymPy for symbolic mathematics, and SymPy2Jax for integrating SymPy expressions with JAX.

```Python
import tempfile

import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.numpy as jnp
import optax  # https://github.com/deepmind/optax
import pysr  # https://github.com/MilesCranmer/PySR
import sympy
import sympy2jax  # https://github.com/google/sympy2jax
```

--------------------------------

### Solve CDE with QuadraticPath and Dopri5 Solver in Diffrax

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/usage/getting-started.md

This snippet demonstrates solving a Controlled Differential Equation (CDE) in Diffrax. It defines a custom control signal using a QuadraticPath class inheriting from AbstractPath and solves the CDE using the Dopri5 solver. The output shows the time points and corresponding solution values.

```Python
from diffrax import AbstractPath, ControlTerm, diffeqsolve, Dopri5


class QuadraticPath(AbstractPath):
    @property
    def t0(self):
        return 0

    @property
    def t1(self):
        return 3

    def evaluate(self, t0, t1=None, left=True):
        del left
        if t1 is not None:
            return self.evaluate(t1) - self.evaluate(t0)
        return t0 ** 2


vector_field = lambda t, y, args: -y
control = QuadraticPath()
term = ControlTerm(vector_field, control).to_ode()
solver = Dopri5()
sol = diffeqsolve(term, solver, t0=0, t1=3, dt0=0.05, y0=1)

print(sol.ts)  # DeviceArray([3.])
print(sol.ys)  # DeviceArray([0.00012341])
```

--------------------------------

### Install Diffrax

Source: https://github.com/patrick-kidger/diffrax/blob/main/README.md

Installs the Diffrax library using pip. Requires Python 3.10 or higher.

```shell
pip install diffrax
```

--------------------------------

### Install Diffrax

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/index.md

Installs the Diffrax library using pip. Requires Python 3.10 or higher.

```Shell
pip install diffrax
```

--------------------------------

### Solve Itô SDE with Euler Solver in Diffrax

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/usage/getting-started.md

This snippet demonstrates solving an Itô Stochastic Differential Equation (SDE) using the Euler-Maruyama method in Diffrax. It defines the drift and diffusion terms, a Brownian motion process, and bundles them using MultiTerm. The solution is then evaluated at a specific time point.
```Python
import jax.random as jr
from diffrax import diffeqsolve, ControlTerm, Euler, MultiTerm, ODETerm, SaveAt, VirtualBrownianTree

t0, t1 = 0, 3
drift = lambda t, y, args: -y
diffusion = lambda t, y, args: 0.1 * t
brownian_motion = VirtualBrownianTree(t0, t1, tol=1e-3, shape=(), key=jr.PRNGKey(0))
terms = MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))
solver = Euler()
saveat = SaveAt(dense=True)

sol = diffeqsolve(terms, solver, t0, t1, dt0=0.05, y0=1.0, saveat=saveat)
print(sol.evaluate(1.1))  # DeviceArray(0.89436394)
```

--------------------------------

### Install Diffrax in Development Mode

Source: https://github.com/patrick-kidger/diffrax/blob/main/CONTRIBUTING.md

Clones the Diffrax repository and installs it in editable mode for development. Also installs pre-commit hooks for code quality.

```bash
git clone https://github.com/your-username-here/diffrax.git
cd diffrax
pip install -e .
pip install pre-commit
pre-commit install
```

--------------------------------

### Import Libraries for Diffrax

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/stiff_ode.ipynb

Imports necessary libraries including Diffrax, Equinox, and JAX for numerical computations and ODE solving. Equinox is used for defining the ODE system as a PyTree.

```Python
import time

import diffrax
import equinox as eqx
import jax
import jax.numpy as jnp
```

--------------------------------

### Import Libraries for Neural SDE

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/neural_sde.ipynb

Imports necessary libraries including Diffrax, Equinox, JAX, Matplotlib, and Optax for building and training the Neural SDE model.

```Python
import diffrax
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax
```

--------------------------------

### Import Libraries for Latent ODE

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/latent_ode.ipynb

Imports necessary libraries including diffrax, equinox, JAX, and plotting tools for the Latent ODE example.

```Python
import time

import diffrax
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import optax

matplotlib.rcParams.update({"font.size": 30})
```

--------------------------------

### Diffrax EulerHeun for Cheap Stratonovich SDEs

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/usage/how-to-choose-a-solver.md

For Stratonovich SDEs where cheap, low-accuracy solves are desired, `diffrax.EulerHeun` is a suitable choice. It provides a balance between computational cost and accuracy.

```Python
import diffrax

solver = diffrax.EulerHeun()
```

--------------------------------
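The snippet above only constructs the solver. As a hedged illustration (the drift and diffusion below are assumptions rather than part of the Diffrax docs), `EulerHeun` can be used with the same `MultiTerm` structure as the Euler-Maruyama example earlier, with the result interpreted in the Stratonovich sense:

```Python
import jax.random as jr
from diffrax import (
    ControlTerm, EulerHeun, MultiTerm, ODETerm, VirtualBrownianTree, diffeqsolve,
)

t0, t1 = 0, 3
drift = lambda t, y, args: -y            # illustrative drift
diffusion = lambda t, y, args: 0.1 * t   # illustrative diffusion
brownian_motion = VirtualBrownianTree(t0, t1, tol=1e-3, shape=(), key=jr.PRNGKey(0))
terms = MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))

sol = diffeqsolve(terms, EulerHeun(), t0, t1, dt0=0.05, y0=1.0)
```

--------------------------------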
### Import Libraries for Diffrax PDE Solver

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/nonlinear_heat_pde.ipynb

Imports necessary libraries including Diffrax, Equinox, JAX, and Matplotlib for solving the nonlinear heat PDE. It also configures JAX for 64-bit precision.

```Python
from collections.abc import Callable

import diffrax
import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jaxtyping import Array, Float

jax.config.update("jax_enable_x64", True)
```

--------------------------------

### NeuralSDE Class Definition

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/neural_sde.ipynb

Starts the definition of the NeuralSDE class, which will encapsulate the neural network representing the SDE.

```Python
class NeuralSDE(eqx.Module):
```

--------------------------------

### Solve the Stiff ODE with Diffrax

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/stiff_ode.ipynb

A JIT-compiled function that sets up and solves the stiff ODE using Diffrax. It configures the ODE term, solver (Kvaerno5), initial conditions, time span, and adaptive step size controller (PIDController) with specified tolerances.

```Python
@jax.jit
def main(k1, k2, k3):
    robertson = Robertson(k1, k2, k3)
    terms = diffrax.ODETerm(robertson)
    t0 = 0.0
    t1 = 100.0
    y0 = jnp.array([1.0, 0.0, 0.0])
    dt0 = 0.0002
    solver = diffrax.Kvaerno5()
    saveat = diffrax.SaveAt(ts=jnp.array([0.0, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2]))
    stepsize_controller = diffrax.PIDController(rtol=1e-8, atol=1e-8)
    sol = diffrax.diffeqsolve(
        terms,
        solver,
        t0,
        t1,
        dt0,
        y0,
        saveat=saveat,
        stepsize_controller=stepsize_controller,
    )
    return sol
```

--------------------------------

### Main Execution Function

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/symbolic_regression.ipynb

A placeholder for the main execution function of the script, typically used to start the program's workflow.

```python
main()
```

--------------------------------

### Setup for SRK Demonstration in Diffrax

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/devdocs/srk_example.ipynb

Sets up the environment for demonstrating Stochastic Runge-Kutta (SRK) methods using Diffrax. It configures JAX for CUDA, imports necessary components from Diffrax and JAX, and sets display options for numerical precision and warnings.

```Python
%env JAX_PLATFORM_NAME=cuda

from test.helpers import (
    get_mlp_sde,
    get_time_sde,
    simple_sde_order,
)
from warnings import simplefilter

import diffrax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
from diffrax import (
    diffeqsolve,
    GeneralShARK,
    ShARK,
    SlowRK,
    SpaceTimeLevyArea,
    SPaRK,
    SRA1,
)
from jax import config

simplefilter("ignore", category=FutureWarning)
config.update("jax_enable_x64", True)
jnp.set_printoptions(precision=4, suppress=True)
```

--------------------------------

### Enable 64-bit Precision

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/stiff_ode.ipynb

Enables 64-bit floating-point precision in JAX, which is crucial for achieving high accuracy (e.g., 1e-8 tolerances) when solving stiff ODEs.

```Python
jax.config.update("jax_enable_x64", True)
```

--------------------------------

### Import Libraries for Neural ODE

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/neural_ode.ipynb

Imports necessary libraries including Diffrax for ODE solving, Equinox for neural networks, JAX for numerical computation, and Optax for optimization. These are foundational for building and training the Neural ODE model.
```Python
import time

import diffrax
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import optax
```

--------------------------------

### Main Function for Diffrax Project

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/continuous_normalising_flow.ipynb

The main entry point for the continuous normalising flow example. It initializes random keys, loads data, sets up the CNF model and optimizer, and defines the training loop. It handles file path management and configuration parameters.

```Python
def main(
    in_path,
    out_path=None,
    batch_size=500,
    virtual_batches=2,
    lr=1e-3,
    weight_decay=1e-5,
    steps=10000,
    exact_logp=True,
    num_blocks=2,
    width_size=64,
    depth=3,
    print_every=100,
    seed=5678,
):
    if out_path is None:
        out_path = here / pathlib.Path(in_path).name
    else:
        out_path = pathlib.Path(out_path)

    key = jr.PRNGKey(seed)
    model_key, loader_key, loss_key, sample_key = jr.split(key, 4)

    dataset, weights, mean, std, img, width, height = get_data(in_path)
    dataset_size, data_size = dataset.shape
    dataloader = DataLoader((dataset, weights), batch_size, key=loader_key)

    model = CNF(
        data_size=data_size,
        exact_logp=exact_logp,
        num_blocks=num_blocks,
        width_size=width_size,
        depth=depth,
        key=model_key,
    )

    optim = optax.adamw(lr, weight_decay=weight_decay)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

    # ... (loss and make_step functions defined here) ...
    # Training loop would typically follow here, calling make_step
```

--------------------------------

### Main Training Entry Point

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/neural_sde.ipynb

The main function to run the GAN training process. It initializes the PRNG key, sets up the generator and discriminator models, defines optimizers, and iterates through the dataloader to perform training steps. It also includes periodic evaluation of the loss.
```Python
def main(
    initial_noise_size=5,
    noise_size=3,
    hidden_size=16,
    width_size=16,
    depth=1,
    generator_lr=2e-5,
    discriminator_lr=1e-4,
    batch_size=1024,
    steps=10000,
    steps_per_print=200,
    dataset_size=8192,
    seed=5678,
):
    key = jr.PRNGKey(seed)
    (
        data_key,
        generator_key,
        discriminator_key,
        dataloader_key,
        train_key,
        evaluate_key,
        sample_key,
    ) = jr.split(key, 7)
    data_key = jr.split(data_key, dataset_size)

    ts, ys = get_data(data_key)
    _, _, data_size = ys.shape

    generator = NeuralSDE(
        data_size,
        initial_noise_size,
        noise_size,
        hidden_size,
        width_size,
        depth,
        key=generator_key,
    )
    discriminator = NeuralCDE(
        data_size, hidden_size, width_size, depth, key=discriminator_key
    )

    g_optim = optax.rmsprop(generator_lr)
    d_optim = optax.rmsprop(-discriminator_lr)
    g_opt_state = g_optim.init(eqx.filter(generator, eqx.is_inexact_array))
    d_opt_state = d_optim.init(eqx.filter(discriminator, eqx.is_inexact_array))

    infinite_dataloader = dataloader(
        (ts, ys), batch_size, loop=True, key=dataloader_key
    )

    for step, (ts_i, ys_i) in zip(range(steps), infinite_dataloader):
        step = jnp.asarray(step)
        generator, discriminator, g_opt_state, d_opt_state = make_step(
            generator,
            discriminator,
            g_opt_state,
            d_opt_state,
            g_optim,
            d_optim,
            ts_i,
            ys_i,
            key,
            step,
        )
        if (step % steps_per_print) == 0 or step == steps - 1:
            total_score = 0
            num_batches = 0
            for ts_i, ys_i in dataloader(
                (ts, ys), batch_size, loop=False, key=evaluate_key
            ):
                score = loss(generator, discriminator, ts_i, ys_i, sample_key)
                total_score += score.item()
                num_batches += 1
```

--------------------------------

### Run Diffrax Tests

Source: https://github.com/patrick-kidger/diffrax/blob/main/CONTRIBUTING.md

Installs testing requirements and runs all tests using pytest to verify code changes.

```bash
pip install -r test/requirements.txt
pytest
```

--------------------------------

### Set Tolerances with PIDController

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/nonlinear_heat_pde.ipynb

Configures a PID controller with specified relative and absolute tolerances for adaptive step-size control in differential equation solving.

```Python
rtol = 1e-10
atol = 1e-10
stepsize_controller = diffrax.PIDController(
    pcoeff=0.3, icoeff=0.4, rtol=rtol, atol=atol, dtmax=0.001
)
```

--------------------------------

### Execute Neural ODE Main Function

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/neural_ode.ipynb

This snippet shows how to call the main function that orchestrates the training and plotting of the Neural ODE model.

```Python
ts, ys, model = main()
```

--------------------------------

### Execute and Time the ODE Solver

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/stiff_ode.ipynb

Executes the `main` function twice: once for JIT compilation and a second time to measure execution performance. It then prints the results (time and state values), the number of steps taken, and the execution time.

```Python
main(0.04, 3e7, 1e4)

start = time.time()
sol = main(0.04, 3e7, 1e4)
end = time.time()

print("Results:")
for ti, yi in zip(sol.ts, sol.ys):
    print(f"t={ti.item()}, y={yi.tolist()}")
print(f"Took {sol.stats['num_steps']} steps in {end - start} seconds.")
```

--------------------------------

### Execute Main Function

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/latent_ode.ipynb

This snippet simply calls the main function, presumably to run the Latent ODE model and generate the data for visualization.
```python
main()
```

--------------------------------

### Main Function for Training and Visualization

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/latent_ode.ipynb

The main entry point for the project. It initializes the model, optimizer, and data loader, then iterates through training steps. In each step, it computes the loss, updates the model, and periodically visualizes the model's predictions by sampling over a longer time interval.

```Python
def main(
    dataset_size=10000,
    batch_size=256,
    lr=1e-2,
    steps=250,
    save_every=50,
    hidden_size=16,
    latent_size=16,
    width_size=16,
    depth=2,
    seed=5678,
):
    key = jr.PRNGKey(seed)
    data_key, model_key, loader_key, train_key, sample_key = jr.split(key, 5)

    ts, ys = get_data(dataset_size, key=data_key)

    model = LatentODE(
        data_size=ys.shape[-1],
        hidden_size=hidden_size,
        latent_size=latent_size,
        width_size=width_size,
        depth=depth,
        key=model_key,
    )

    @eqx.filter_value_and_grad
    def loss(model, ts_i, ys_i, key_i):
        batch_size, _ = ts_i.shape
        key_i = jr.split(key_i, batch_size)
        loss = jax.vmap(model.train)(ts_i, ys_i, key=key_i)
        return jnp.mean(loss)

    @eqx.filter_jit
    def make_step(model, opt_state, ts_i, ys_i, key_i):
        value, grads = loss(model, ts_i, ys_i, key_i)
        key_i = jr.split(key_i, 1)[0]
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return value, model, opt_state, key_i

    optim = optax.adam(lr)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

    # Plot results
    num_plots = 1 + (steps - 1) // save_every
    if ((steps - 1) % save_every) != 0:
        num_plots += 1
    fig, axs = plt.subplots(1, num_plots, figsize=(num_plots * 8, 8))
    axs[0].set_ylabel("x")
    axs = iter(axs)

    for step, (ts_i, ys_i) in zip(
        range(steps), dataloader((ts, ys), batch_size, key=loader_key)
    ):
        start = time.time()
        value, model, opt_state, train_key = make_step(
            model, opt_state, ts_i, ys_i, train_key
        )
        end = time.time()
        print(f"Step: {step}, Loss: {value}, Computation time: {end - start}")

        if (step % save_every) == 0 or step == steps - 1:
            ax = next(axs)
            # Sample over a longer time interval than we trained on. The model will be
            # sufficiently good that it will correctly extrapolate!
            sample_t = jnp.linspace(0, 12, 300)
            sample_y = model.sample(sample_t, key=sample_key)
            sample_t = np.asarray(sample_t)
            sample_y = np.asarray(sample_y)
            ax.plot(sample_t, sample_y[:, 0])
            ax.plot(sample_t, sample_y[:, 1])
```

--------------------------------

### Import Libraries for Kalman Filter

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/kalman_filter.ipynb

Imports necessary libraries for implementing and optimizing a Kalman Filter. This includes Diffrax for differential equations, Equinox for building the filter and LTI systems, JAX for numerical operations, and Optax for optimization.

```Python
from types import SimpleNamespace

import diffrax as dfx
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax
```

--------------------------------
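The Kalman filter snippets in this collection construct `LTISystem` and `KalmanFilter` objects whose definitions are not included in this excerpt. Purely as orientation, a minimal `LTISystem` consistent with how `harmonic_oscillator` (below) uses it might look like the following sketch; the field names and docstring are assumptions based on that usage, not the example's actual code.

```Python
import equinox as eqx
import jax


class LTISystem(eqx.Module):
    """Hypothetical container for a linear time-invariant system
    dx/dt = A x + B u,  y = C x, as constructed by `harmonic_oscillator` below."""

    A: jax.Array
    B: jax.Array
    C: jax.Array
```

--------------------------------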
### Create Harmonic Oscillator LTI System

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/kalman_filter.ipynb

A function that generates an LTISystem representing a harmonic oscillator. It takes optional damping and time scaling parameters to configure the system's dynamics. This serves as a concrete example of an LTI system.

```Python
def harmonic_oscillator(damping: float = 0.0, time_scaling: float = 1.0) -> LTISystem:
    A = jnp.array([[0.0, time_scaling], [-time_scaling, -2 * damping]])
    B = jnp.array([[0.0], [1.0]])
    C = jnp.array([[0.0, 1.0]])
    return LTISystem(A, B, C)
```

--------------------------------

### Generate Toy Dataset (Oscillators)

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/neural_ode.ipynb

Provides a function `get_data` to generate a toy dataset of nonlinear oscillators. It defines a simple ODE system and uses `diffrax.diffeqsolve` to sample trajectories for a given number of data points.

```Python
def _get_data(ts, *, key):
    y0 = jr.uniform(key, (2,), minval=-0.6, maxval=1)

    def f(t, y, args):
        x = y / (1 + y)
        return jnp.stack([x[1], -x[0]], axis=-1)

    solver = diffrax.Tsit5()
    dt0 = 0.1
    saveat = diffrax.SaveAt(ts=ts)
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(f), solver, ts[0], ts[-1], dt0, y0, saveat=saveat
    )
    ys = sol.ys
    return ys


def get_data(dataset_size, *, key):
    ts = jnp.linspace(0, 10, 100)
    key = jr.split(key, dataset_size)
    ys = jax.vmap(lambda key: _get_data(ts, key=key))(key)
    return ts, ys
```

--------------------------------

### Diffrax Tsit5 for Non-Stiff ODEs

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/usage/how-to-choose-a-solver.md

The `diffrax.Tsit5` solver is recommended as a good general-purpose option for non-stiff ordinary differential equations. It is generally considered more efficient than `diffrax.Dopri5`.

```Python
import diffrax

solver = diffrax.Tsit5()
```

--------------------------------

### Execute Main Training Function

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/neural_cde.ipynb

This snippet is the entry point for running the entire Neural CDE training and evaluation process. Calling this function initiates the data generation, model training, and result plotting as defined in the `main` function.

```python
main()
```

--------------------------------

### Run Kalman Filter Simulation (With Optimization)

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/kalman_filter.ipynb

Executes the main Kalman Filter simulation function with `n_gradient_steps` set to 100. This will run the filter and perform 100 gradient descent steps to optimize the Q and R matrices.

```Python
main(n_gradient_steps=100)
```

--------------------------------

### Main Training Function

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/neural_ode.ipynb

The `main` function orchestrates the training process. It initializes the model, optimizer, and data loader. It then iterates through training steps, calculating loss, gradients, and updating model parameters. A strategy is employed to initially train on shorter time series segments.

```Python
def main(
    dataset_size=256,
    batch_size=32,
    lr=3e-3,
    steps_strategy=(500, 500),
    length_strategy=(0.1, 1),
    width_size=64,
    depth=2,
    seed=5678,
    plot=True,
    print_every=100,
):
    key = jr.PRNGKey(seed)
    data_key, model_key, loader_key = jr.split(key, 3)

    ts, ys = get_data(dataset_size, key=data_key)
    _, length_size, data_size = ys.shape

    model = NeuralODE(data_size, width_size, depth, key=model_key)
    optim = optax.adabelief(lr)

    # Training loop like normal.
    #
    # Only thing to notice is that up until step 500 we train on only the first 10% of
    # each time series. This is a standard trick to avoid getting caught in a local
    # minimum.
    @eqx.filter_value_and_grad
    def grad_loss(model, ti, yi):
        y_pred = jax.vmap(model, in_axes=(None, 0))(ti, yi[:, 0])
        return jnp.mean((yi - y_pred) ** 2)

    @eqx.filter_jit
    def make_step(ti, yi, model, opt_state):
        loss, grads = grad_loss(model, ti, yi)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    for steps, length in zip(steps_strategy, length_strategy):
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
        _ts = ts[: int(length_size * length)]
        _ys = ys[:, : int(length_size * length)]
        for step, (yi,) in zip(
```

--------------------------------

### Execute Main Function

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/neural_sde.ipynb

This snippet simply calls the main function to execute the plotting and analysis logic defined elsewhere in the script.

```python
main()
```

--------------------------------

### Main Function for Symbolic Regression Workflow

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/symbolic_regression.ipynb

The main function orchestrates the symbolic regression process. It includes parameters for dataset size, PySR configuration, and fine-tuning. It first trains a neural ODE (by running another notebook), then uses PySR to find symbolic expressions for the learned vector field, and finally fine-tunes these expressions by plugging them back into the dataset and applying gradient descent.

```Python
def main(
    symbolic_dataset_size=2000,
    symbolic_num_populations=100,
    symbolic_population_size=20,
    symbolic_migration_steps=4,
    symbolic_mutation_steps=30,
    symbolic_descent_steps=50,
    pareto_coefficient=2,
    fine_tuning_steps=500,
    fine_tuning_lr=3e-3,
    quantise_to=0.01,
):
    #
    # First obtain a neural approximation to the dynamics.
    # We begin by running the previous example.
    #

    # Runs the Neural ODE example.
    # This defines the variables `ts`, `ys`, `model`.
    print("Training neural differential equation.")
    %run neural_ode.ipynb

    #
    # Now symbolically regress across the learnt vector field, to obtain a Pareto
    # frontier of symbolic equations, that trades loss against complexity of the
    # equation. Select the "best" from this frontier.
    #

    print("Symbolically regressing across the vector field.")
    vector_field = model.func.mlp  # noqa: F821
    dataset_size, length_size, data_size = ys.shape  # noqa: F821
    in_ = ys.reshape(dataset_size * length_size, data_size)  # noqa: F821
    in_ = in_[:symbolic_dataset_size]
    out = jax.vmap(vector_field)(in_)
    with tempfile.TemporaryDirectory() as tempdir:
        symbolic_regressor = pysr.PySRRegressor(
            niterations=symbolic_migration_steps,
            ncycles_per_iteration=symbolic_mutation_steps,
            populations=symbolic_num_populations,
            population_size=symbolic_population_size,
            optimizer_iterations=symbolic_descent_steps,
            optimizer_nrestarts=1,
            procs=1,
            model_selection="score",
            progress=False,
            tempdir=tempdir,
            temp_equation_file=True,
        )
        symbolic_regressor.fit(in_, out)
        best_expressions = [b.sympy_format for b in symbolic_regressor.get_best()]

    #
    # Now the constants in this expression have been optimised for regressing across
    # the neural vector field. This was good enough to obtain the symbolic expression,
    # but won't quite be perfect -- some of the constants will be slightly off.
    #
    # To fix this we now plug our symbolic function back into the original dataset
    # and apply gradient descent.
    #

    print("\nOptimising symbolic expression.")
    symbolic_fn = Stack([sympy2jax.SymbolicModule(expr) for expr in best_expressions])
```

--------------------------------

### Diffrax GeneralShARK for Non-Commutative Stratonovich SDEs

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/usage/how-to-choose-a-solver.md

For Stratonovich SDEs with non-commutative noise, `diffrax.GeneralShARK` is the most efficient choice. `diffrax.Heun` serves as a good, cheaper alternative.

```Python
import diffrax

solver = diffrax.GeneralShARK()
```

```Python
import diffrax

solver = diffrax.Heun()
```

--------------------------------

### Solve ODE with Tsit5 Solver

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/nonlinear_heat_pde.ipynb

Solves a differential equation using the Tsit5 solver, incorporating a PID step-size controller for adaptive integration. It saves the solution at specified points.

```Python
solver = diffrax.Tsit5()
sol = diffrax.diffeqsolve(
    term,
    solver,
    t0,
    t_final,
    δt,
    y0,
    saveat=saveat,
    stepsize_controller=stepsize_controller,
    max_steps=None,
)
```

--------------------------------

### Define the Robertson ODE System

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/stiff_ode.ipynb

Defines the Robertson problem as a class inheriting from `eqx.Module`. The `__call__` method implements the system of ODEs, taking time `t`, state `y`, and arguments `args` to return the derivatives.

```Python
class Robertson(eqx.Module):
    k1: float
    k2: float
    k3: float

    def __call__(self, t, y, args):
        f0 = -self.k1 * y[0] + self.k3 * y[1] * y[2]
        f1 = self.k1 * y[0] - self.k2 * y[1] ** 2 - self.k3 * y[1] * y[2]
        f2 = self.k2 * y[1] ** 2
        return jnp.stack([f0, f1, f2])
```

--------------------------------

### Import Libraries for Neural CDE

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/neural_cde.ipynb

Imports necessary libraries including diffrax, equinox, jax, jax.nn, jax.numpy, jax.random, jax.scipy, matplotlib, and optax for building and training the Neural CDE model.

```Python
import math
import time

import diffrax
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import matplotlib
import matplotlib.pyplot as plt
import optax
```

--------------------------------

### Set up Diffrax ODE Solver for Nonlinear Heat PDE

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/nonlinear_heat_pde.ipynb

Initializes the Diffrax ODE solver by defining the `ODETerm` with the `vector_field`, setting the initial condition `ic` for the spatial discretization, and specifying the temporal discretization parameters including the save-at points.

```Python
term = diffrax.ODETerm(vector_field)

ic = lambda x: x**2

# Spatial discretisation
x0 = -1
x_final = 1
n = 50
y0 = SpatialDiscretisation.discretise_fn(x0, x_final, n, ic)

# Temporal discretisation
t0 = 0
t_final = 1
δt = 0.0001
saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t_final, 50))
```

--------------------------------

### LipSwish Activation Function

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/neural_sde.ipynb

Defines the LipSwish activation function, a variant of the Swish function, used in the neural networks for the SDE-GAN.
```Python
def lipswish(x):
    return 0.909 * jnn.silu(x)
```

--------------------------------

### Diffrax SDE Term Structure Example

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/api/solvers/sde_solvers.md

Demonstrates how to define terms for solving SDEs using Diffrax. It shows the use of `ODETerm` for the drift and `ControlTerm` with `UnsafeBrownianPath` for the diffusion, combined within a `MultiTerm`.

```Python
drift = lambda t, y, args: -y
diffusion = lambda t, y, args: y[..., None]
bm = UnsafeBrownianPath(shape=(1,), key=...)
terms = MultiTerm(ODETerm(drift), ControlTerm(diffusion, bm))
diffeqsolve(terms, solver=Euler(), ...)
```

--------------------------------

### Create a Batched Dataloader

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/neural_sde.ipynb

Implements a dataloader that yields batches of data from given arrays. It supports looping through the dataset and shuffling the data using permutations. The batch size and looping behavior are configurable.

```Python
def dataloader(arrays, batch_size, loop, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        key = jr.split(key, 1)[0]
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size
        if not loop:
            break
```

--------------------------------

### Create Dataloader

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/latent_ode.ipynb

Creates a dataloader that yields batches of data from given arrays. It shuffles the data at each epoch and provides batches of a specified size. This is useful for training machine learning models.

```Python
def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        while start < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size
```

--------------------------------

### Diffrax KenCarp4 for Split Stiffness ODEs

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/usage/how-to-choose-a-solver.md

IMEX (Implicit-Explicit) methods are suitable for 'split stiffness' problems where one term is stiff and another is non-stiff. `diffrax.KenCarp4` is recommended for such cases, typically used with `diffrax.PIDController`. A hedged usage sketch appears below, after the SPaRK snippet.

```Python
import diffrax

solver = diffrax.KenCarp4()
stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)
```

--------------------------------

### Diffrax SPaRK for Adaptive Stratonovich SDEs (Non-Commutative)

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/usage/how-to-choose-a-solver.md

If an embedded method for adaptive step size control is needed for Stratonovich SDEs with non-commutative noise, `diffrax.SPaRK` is the recommended choice.

```Python
import diffrax

solver = diffrax.SPaRK()
stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)
```

--------------------------------
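For the `KenCarp4` snippet above, the right-hand side is supplied as an explicit (non-stiff) term and an implicit (stiff) term. The sketch below is a hedged illustration with an invented toy vector field; only the general `MultiTerm(ODETerm(explicit), ODETerm(implicit))` pattern is the point.

```Python
import diffrax
import jax.numpy as jnp

# Toy split right-hand side (illustrative only): the first term is treated
# explicitly, the second (stiff) term implicitly.
explicit_vf = lambda t, y, args: jnp.sin(t) - y**3
implicit_vf = lambda t, y, args: -1000.0 * y

terms = diffrax.MultiTerm(diffrax.ODETerm(explicit_vf), diffrax.ODETerm(implicit_vf))
solver = diffrax.KenCarp4()
stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)
sol = diffrax.diffeqsolve(
    terms,
    solver,
    t0=0.0,
    t1=1.0,
    dt0=0.01,
    y0=jnp.array(1.0),
    stepsize_controller=stepsize_controller,
)
```

--------------------------------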
### Create Data Loader

Source: https://github.com/patrick-kidger/diffrax/blob/main/docs/examples/neural_ode.ipynb

Implements a `dataloader` function that yields batches of data from the provided arrays. It uses JAX's random permutation to shuffle the data at each epoch, ensuring variety in training batches.

```Python
def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size
```
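For context, here is a hedged usage sketch of the dataloader defined above, with an invented toy dataset; since the generator yields shuffled batches indefinitely, the loop is bounded with `zip(range(...), ...)`.

```Python
import jax.numpy as jnp
import jax.random as jr

# Toy data: 1000 (feature, label) pairs; purely illustrative.
xs = jnp.arange(1000.0).reshape(1000, 1)
ys = 2.0 * xs

loader = dataloader((xs, ys), batch_size=32, key=jr.PRNGKey(0))
for step, (xi, yi) in zip(range(5), loader):
    print(step, xi.shape, yi.shape)  # e.g. 0 (32, 1) (32, 1)
```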