### Install Libraries

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/nnx_example.ipynb

Installs the required libraries: numpyro, arviz, flax, and matplotlib. Run this at the beginning of your environment setup.

```bash
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro arviz flax matplotlib
```

--------------------------------

### Install NumPyro and Funsor

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/discrete_imputation.ipynb

Installs NumPyro from GitHub and funsor from PyPI. This is a prerequisite for running the subsequent code examples.

```python
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro funsor
```

--------------------------------

### Import Libraries and Setup

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/effect_handlers.ipynb

Imports the required libraries, including JAX, NumPy, ArviZ, and NumPyro, and configures plotting and the number of JAX host devices. This is boilerplate shared by most NumPyro examples.

```python
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import jax
import jax.numpy as jnp

import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.infer.reparam import LocScaleReparam

numpyro.set_host_device_count(4)

plt.style.use("bmh")
plt.rcParams["figure.figsize"] = [10, 6]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"

rng_key = jax.random.key(42)

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
```

--------------------------------

### Install Dependencies

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/other_samplers.ipynb

Installs NumPyro from its Git repository, along with ArviZ, BlackJAX, and flowMC. This is a prerequisite for running the integration examples.

```shell
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro arviz blackjax flowMC
```

--------------------------------

### Initialize MAP Estimator with AutoDelta Guide

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/gmm.ipynb

Initializes the AutoDelta guide for the GMM model, keeping the best of 100 random initializations to mitigate local modes. `init_to_value` supplies plausible starting points.

```python
# Assumes `model`, `data`, and `K` are defined earlier in the notebook.
elbo = TraceEnum_ELBO()


def initialize(seed):
    global global_guide
    init_values = {
        "weights": jnp.ones(K) / K,
        "scale": jnp.sqrt(data.var() / 2),
        "locs": data[
            random.categorical(
                random.key(seed), jnp.ones(len(data)) / len(data), shape=(K,)
            )
        ],
    }
    global_model = handlers.block(
        handlers.seed(model, random.key(0)),
        hide_fn=lambda site: (
            site["name"] not in ["weights", "scale", "locs", "components"]
        ),
    )
    global_guide = AutoDelta(
        global_model, init_loc_fn=init_to_value(values=init_values)
    )
    handlers.seed(global_guide, random.key(0))(data)  # warm up the guide
    return elbo.loss(random.key(0), {}, model, global_guide, data)


# Choose the best among 100 random initializations.
loss, seed = min((initialize(seed), seed) for seed in range(100))
initialize(seed)  # initialize the global_guide
print(f"seed = {seed}, initial_loss = {loss}")
```

--------------------------------

### Install NumPyro from Git

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/bad_posterior_geometry.ipynb

Installs NumPyro directly from its GitHub repository, which gives you the latest development version. Git must be available on your system.

```python
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
```

--------------------------------

### Install Pre-Commit and Its Hooks

Source: https://github.com/pyro-ppl/numpyro/blob/master/CONTRIBUTING.md

Installs the pre-commit package and its hooks to automatically format code before each commit.
Hooks can be skipped with `git commit --no-verify`.

```sh
pip install pre-commit
pre-commit install
```

--------------------------------

### Install Dependencies

Source: https://github.com/pyro-ppl/numpyro/blob/master/docs/README.md

Installs the Python packages required to build the documentation.

```bash
pip install -r requirements.txt
```

--------------------------------

### Run SVI with Full Guide and Visualize Convergence

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/gmm.ipynb

Sets up and runs SVI with the full guide and `TraceEnum_ELBO`, then plots the SVI loss and the per-parameter gradient norms during optimization. Requires NumPyro, Optax, and Matplotlib.

```python
import matplotlib.pyplot as plt
from jax import random
import optax

from numpyro.infer import SVI, TraceEnum_ELBO

# Assumes `model`, `full_guide`, `data`, `smoke_test`, and `hook_optax`
# are defined earlier in the notebook.
optim, gradient_norms = hook_optax(optax.adam(learning_rate=0.2, b1=0.8, b2=0.99))
elbo = TraceEnum_ELBO()
full_svi = SVI(model, full_guide, optim, loss=elbo)
full_svi_result = full_svi.run(random.key(0), 200 if not smoke_test else 2, data)

plt.figure(figsize=(10, 3), dpi=100).set_facecolor("white")
plt.plot(full_svi_result.losses)
plt.xlabel("iters")
plt.ylabel("loss")
plt.yscale("log")
plt.title("Convergence of SVI")
plt.show()

plt.figure(figsize=(10, 4), dpi=100).set_facecolor("white")
for name, grad_norms in gradient_norms.items():
    plt.plot(grad_norms, label=name)
plt.xlabel("iters")
plt.ylabel("gradient norm")
plt.yscale("log")
plt.legend(loc="best")
plt.title("Gradient norms during SVI")
plt.show()
```

--------------------------------

### Install NumPyro and UC Irvine ML Repository Package

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/variationally_inferred_parameterization.ipynb

Installs the libraries needed for the tutorial. Use this to set up your environment.

```python
%pip -qq install numpyro
%pip -qq install ucimlrepo
```

--------------------------------

### Substitute Trained Parameters and Replay Globals

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/gmm.ipynb

Substitutes the trained parameters into the guide, records the guide's global sites, and replays them into the model. This is a setup step for using the `infer_discrete` handler.

```python
# Substitute the trained params.
trained_global_guide = handlers.substitute(global_guide, global_svi_result.params)
# Record the globals.
guide_trace = handlers.trace(trained_global_guide).get_trace(data)
# Replay the globals.
trained_model = handlers.replay(model, trace=guide_trace)
```

--------------------------------

### Install NumPyro from GitHub (GPJax Example)

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/gpjax_example.ipynb

Installs NumPyro directly from its GitHub repository for the GPJax example. Note that this command does not install GPJax itself. Ensure pip and Git are installed.

```python
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
```

--------------------------------

### Install NumPyro and Optax

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/tbip.ipynb

Installs NumPyro from its GitHub repository and the Optax optimization library. This is a prerequisite for running the TBIP model.

```python
%%capture
%pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
%pip install optax
```

--------------------------------

### Install Dependencies

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/hierarchical_forecasting.ipynb

Installs the required libraries, including the development version of NumPyro from its GitHub repository.

```bash
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro arviz matplotlib optax
```

--------------------------------

### Install NumPyro from Source

Source: https://github.com/pyro-ppl/numpyro/blob/master/CONTRIBUTING.md

Installs NumPyro and its development dependencies from a local clone of the repository. For CUDA support, install JAX/jaxlib first.

```sh
git clone https://github.com/pyro-ppl/numpyro.git
cd numpyro
# Install jax/jaxlib first for CUDA support.
pip install -e '.[dev,test,doc,examples]'
```

--------------------------------

### Install NumPyro with Latest CPU JAX

Source: https://github.com/pyro-ppl/numpyro/blob/master/README.md

Use this command for a standard CPU installation of NumPyro with the latest JAX version.

```bash
pip install numpyro
```

--------------------------------

### Import Libraries and Setup

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/truncated_distributions.ipynb

Imports the required libraries, including JAX, NumPyro, and its distribution modules. Sets up random number generator keys and MCMC keyword arguments.

```python
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import poisson as sp_poisson

import jax
from jax import lax, random
import jax.numpy as jnp
from jax.scipy.special import ndtri
from jax.scipy.stats import norm, poisson

import numpyro
import numpyro.distributions as dist
from numpyro.distributions import (
    Distribution,
    FoldedDistribution,
    SoftLaplace,
    StudentT,
    TruncatedDistribution,
    TruncatedNormal,
    constraints,
)
from numpyro.distributions.util import promote_shapes
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs, Predictive

numpyro.enable_x64()
RNG = random.key(0)
PRIOR_RNG, MCMC_RNG, PRED_RNG = random.split(RNG, 3)
MCMC_KWARGS = dict(
    num_warmup=2000,
    num_samples=2000,
    num_chains=4,
    chain_method="sequential",
)
```

--------------------------------

### Import Libraries and Setup

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/gmm.ipynb

Imports the libraries needed for NumPyro, plotting, and numerical operations, and sets up the tutorial environment, including a check of the NumPyro version.

```python
from collections import defaultdict
import os

import matplotlib.pyplot as plt
import scipy.stats

from jax import pure_callback, random
import jax.numpy as jnp
import optax

import numpyro
from numpyro import handlers
from numpyro.contrib.funsor import config_enumerate, infer_discrete
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, TraceEnum_ELBO, init_to_value
from numpyro.infer.autoguide import AutoDelta

%matplotlib inline

smoke_test = "CI" in os.environ
assert numpyro.__version__.startswith("0.21.0")
```

--------------------------------

### Install Dependencies

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/bayesian_cuped.ipynb

Installs the required libraries: numpyro, arviz, matplotlib, polars, and seaborn. Use this snippet to set up your environment.

```python
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro arviz matplotlib polars seaborn
```

--------------------------------

### Install NumPyro with Compatible CPU JAX Version

Source: https://github.com/pyro-ppl/numpyro/blob/master/README.md

If the default installation has compatibility issues, install NumPyro with a known-compatible CPU version of JAX.

```bash
pip install 'numpyro[cpu]'
```

--------------------------------

### Install Dependencies

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/hsgp_example.ipynb

Installs NumPyro from its Git repository and ArviZ for analyzing the inference results.

```python
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro arviz
```

--------------------------------

### Install NumPyro from Git

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/lotka_volterra_multiple.ipynb

Installs the latest development version of NumPyro directly from its GitHub repository. The command is commented out in the notebook; uncomment it if you need features that are not yet in a release.

```python
#!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro
```

--------------------------------

### Install NumPyro from Git

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/circulant_gp.ipynb

Installs NumPyro directly from its GitHub repository, along with Matplotlib for plotting.

```python
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro matplotlib
```

--------------------------------

### Import Libraries and Setup

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/consensus_mc.ipynb

Imports the required libraries, including JAX, NumPyro, ArviZ, Matplotlib, and tqdm. Sets up plotting styles, configures NumPyro for multi-device execution, and asserts the NumPyro version.

```python
import arviz as az
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.handlers import scale
from numpyro.infer import MCMC, NUTS
from numpyro.infer.hmc_util import consensus, parametric_draws
from numpyro.infer.util import Predictive

plt.style.use("bmh")
plt.rcParams["figure.figsize"] = [10, 6]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"

numpyro.set_host_device_count(n=4)

rng_key = random.key(seed=42)

assert numpyro.__version__.startswith("0.21.0")

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
```

--------------------------------

### Time and Covariate Setup

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/hierarchical_forecasting.ipynb

Initializes the time arrays and covariate arrays for the training and test periods, with assertions that verify the shapes and consistency of the data splits.

```python
n_stations = y_train.shape[-2]

time = jnp.array(range(T0, T2))
time_train = jnp.array(range(T0, T1))
t_max_train = time_train.size

time_test = jnp.array(range(T1, T2))
t_max_test = time_test.size

covariates = jnp.zeros_like(y)
covariates_train = jnp.zeros_like(y_train)
covariates_test = jnp.zeros_like(y_test)

assert time_train.size + time_test.size == time.size
assert y_train.shape == (n_stations, n_stations, t_max_train)
assert y_test.shape == (n_stations, n_stations, t_max_test)
assert covariates.shape == y.shape
assert covariates_train.shape == y_train.shape
assert covariates_test.shape == y_test.shape
```

--------------------------------

### Create SVI Object with AutoNormal Guide

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/tbip.ipynb

This snippet demonstrates how to create an SVI object using an automatically generated guide with AutoNormal.
It initializes the parameters and compares the resulting optimizer state with the one produced by the manually defined guide.

```python
from numpyro.infer import init_to_value
from numpyro.infer.autoguide import AutoNormal


def create_svi_object(guide):
    # Create the SVI object for this guide and initialize it on one batch.
    svi_batch = SVI(
        model=tbip.model,
        guide=guide,
        optim=adam(exponential_decay(learning_rate, num_steps, decay_rate)),
        loss=TraceMeanField_ELBO(),
    )

    Y_batch, I_batch, D_batch = tbip.get_batch(random.key(1), counts, author_indices)

    svi_state = svi_batch.init(
        random.key(0), Y_batch=Y_batch, d_batch=D_batch, i_batch=I_batch
    )

    return svi_state


# This state uses the guide defined manually above.
svi_state_manualguide = create_svi_object(guide=tbip.guide)

# Now create the same object with an AutoNormal guide, initialized as above.
# `init_loc_fn` expects an init strategy; `init_to_value` takes values in each
# site's constrained space.
autoguide = AutoNormal(
    model=tbip.model,
    init_loc_fn=init_to_value(
        values={"beta": initial_objective_topic_loc, "theta": initial_document_loc}
    ),
)
svi_state_autoguide = create_svi_object(guide=autoguide)

# Assert that the keys in the optimizer states are identical. AutoNormal names
# its parameters `<site>_auto_loc` and `<site>_auto_scale`, so the keys only
# match if the manual guide uses the same names.
assert svi_state_manualguide[0][1][0].keys() == svi_state_autoguide[0][1][0].keys()

# Assert that the values in the optimizer states are identical.
for key in svi_state_manualguide[0][1][0].keys():
    assert jnp.all(
        svi_state_manualguide[0][1][0][key] == svi_state_autoguide[0][1][0][key]
    )
```

--------------------------------

### Initialize and Run SVI

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/hierarchical_forecasting.ipynb

Initializes the SVI object with the model, guide, optimizer, and loss function, then runs inference for a fixed number of steps. The ELBO loss is plotted to monitor convergence.

```python
# Assumes `model` and `optimizer` are defined earlier in the notebook.
guide = AutoNormal(model)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
num_steps = 8_000

rng_key, rng_subkey = random.split(key=rng_key)
svi_result = svi.run(rng_subkey, num_steps, covariates_train, y_train)

fig, ax = plt.subplots(figsize=(9, 6))
ax.plot(svi_result.losses)
ax.set_yscale("log")
ax.set_title("ELBO loss", fontsize=18, fontweight="bold");
```

--------------------------------

### MCMC Inference Setup and Run

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/censoring.ipynb

Sets up and runs Markov chain Monte Carlo (MCMC) inference with the NUTS kernel for the censored gamma model. Requires NumPyro and JAX.

```python
censored_gamma_kernel = NUTS(censored_gamma_model)

censored_gamma_mcmc = MCMC(
    censored_gamma_kernel,
    num_warmup=inference_params.num_warmup,
    num_samples=inference_params.num_samples,
    num_chains=inference_params.num_chains,
)

rng_key, rng_subkey = random.split(rng_key)
censored_gamma_mcmc.run(
    rng_subkey,
    y=censored_gamma_samples,
    lower=censored_gamma_data_params.lower,
    upper=censored_gamma_data_params.upper,
    truncation_label=truncation_label,
)
```

--------------------------------

### Define Full Guide for GMM with Enumeration

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/gmm.ipynb

Defines a guide that enumerates the local variables to predict class membership. Uses `handlers.block` to keep the learned global parameters fixed. Requires NumPyro, JAX, and `constraints`.

```python
import jax.numpy as jnp

import numpyro
from numpyro import handlers
from numpyro.contrib.funsor import config_enumerate
import numpyro.distributions as dist
from numpyro.distributions import constraints

K = 2  # Assumes K=2 components for simplicity in this example.


# Assumes `trained_global_guide` is defined earlier in the notebook.
@config_enumerate
def full_guide(data):
    # Global variables.
    with handlers.block(
        hide=["weights_auto_loc", "locs_auto_loc", "scale_auto_loc"]
    ):  # Keep our learned values of global parameters.
        trained_global_guide(data)

    # Local variables.
    with numpyro.plate("data", len(data)):
        assignment_probs = numpyro.param(
            "assignment_probs",
            jnp.ones((len(data), K)) / K,
            constraint=constraints.simplex,
        )
        numpyro.sample("assignment", dist.Categorical(assignment_probs))
```

--------------------------------

### Import Libraries and Setup

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/censoring.ipynb

Imports the libraries used for Bayesian modeling, plotting, and data handling. It also configures Matplotlib for notebook display and sets the number of host devices for NumPyro.

```python
import os

import arviz as az
from IPython.display import set_matplotlib_formats
from jaxlib.xla_extension import ArrayImpl
import matplotlib.pyplot as plt
import preliz as pz
from pydantic import BaseModel, Field

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.handlers import mask
from numpyro.infer import MCMC, NUTS, Predictive

plt.style.use("bmh")
if "NUMPYRO_SPHINXBUILD" in os.environ:
    set_matplotlib_formats("svg")

plt.rcParams["figure.figsize"] = [8, 6]

numpyro.set_host_device_count(n=4)

rng_key = random.key(seed=0)

assert numpyro.__version__.startswith("0.21.0")

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
```

--------------------------------

### Install NumPyro and Dependencies

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/censoring.ipynb

Installs NumPyro from its Git repository along with the other required libraries: ArviZ, Matplotlib, PreliZ, and Pydantic. This is typically run in a notebook environment.

```python
!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro arviz matplotlib preliz pydantic
```

--------------------------------

### Import Libraries and Setup

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/bayesian_cuped.ipynb

Imports essential libraries for data manipulation, plotting, and Bayesian inference with NumPyro.
Sets up plotting styles and the figure backend, and asserts that the expected NumPyro version is installed.

```python
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import seaborn as sns

from jax import Array, random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

numpyro.set_host_device_count(n=4)

rng_key = random.key(seed=1)

plt.style.use("bmh")
plt.rcParams["figure.figsize"] = [10, 6]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"

assert numpyro.__version__.startswith("0.21.0")

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
```

--------------------------------

### Centered Parameterization MCMC Setup and Run

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/effect_handlers.ipynb

Sets up and runs MCMC for the centered parameterization of the Eight Schools model. This parameterization can produce divergences because of the funnel-shaped posterior geometry.

```python
kernel_centered = NUTS(
    eight_schools,
    target_accept_prob=0.9,
)

mcmc_centered = MCMC(
    kernel_centered,
    num_warmup=1_000,
    num_samples=1_000,
    num_chains=4,
)

rng_key, rng_subkey = jax.random.split(rng_key)
mcmc_centered.run(rng_subkey, J, sigma_schools, y=y_schools)
mcmc_centered.print_summary(exclude_deterministic=False)
```

--------------------------------

### Import Required Libraries

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/hsgp_nd_example.ipynb

Imports the libraries used in the example, including JAX, NumPyro, ArviZ, Matplotlib, and NumPy. It also imports specific modules for HSGP, inference, and optimization.

```python
from typing import Sequence

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray

import jax
from jax import random
import jax.numpy as jnp
from optax import linear_onecycle_schedule

import numpyro
from numpyro import distributions as dist
from numpyro.contrib.hsgp.approximation import hsgp_squared_exponential
from numpyro.infer import Predictive
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.elbo import Trace_ELBO
from numpyro.infer.hmc import NUTS
from numpyro.infer.initialization import init_to_median, init_to_uniform
from numpyro.infer.mcmc import MCMC
from numpyro.infer.svi import SVI
from numpyro.optim import Adam
```

--------------------------------

### Execute Sharded NUTS

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/consensus_mc.ipynb

Runs the sharded NUTS algorithm with a specified number of shards, showing the `run_sharded_mcmc` function in practice.

```python
%%time
# Run sharded NUTS (embarrassingly parallel).
num_shards = 20
rng_key, shard_key = random.split(rng_key)
subposteriors = run_sharded_mcmc(
    shard_key,
    x,
    y,
    num_shards=num_shards,
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=num_chains,
)
```

--------------------------------

### Run SVI and Plot Loss

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/nnx_example.ipynb

Runs SVI with a defined model, guide, and optimizer, then plots the ELBO loss over the training iterations. Ensure `numpyro`, `matplotlib.pyplot`, and `jax.random` are imported.

```python
import matplotlib.pyplot as plt
from jax import random

import numpyro
from numpyro.handlers import condition
from numpyro.infer import SVI, TraceGraph_ELBO
from numpyro.infer.autoguide import AutoNormal

# Assumes `model`, `x_train`, `y_train`, and `rng_key` are defined earlier.
conditioned_model = condition(model, data={"likelihood": y_train})
guide = AutoNormal(model=conditioned_model)
optimizer = numpyro.optim.Adam(step_size=0.005)
svi = SVI(conditioned_model, guide, optimizer, loss=TraceGraph_ELBO())
num_steps = 8_000
rng_key, rng_subkey = random.split(key=rng_key)
svi_result = svi.run(rng_subkey, num_steps, x_train)

fig, ax = plt.subplots()
ax.plot(svi_result.losses)
ax.set_title("ELBO loss", fontsize=18, fontweight="bold");
```

--------------------------------

### Build Documentation

Source: https://github.com/pyro-ppl/numpyro/blob/master/docs/README.md

Run this command from the project's top-level directory to build the documentation.

```bash
make docs
```

--------------------------------

### Import Libraries and Setup

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/hierarchical_forecasting.ipynb

Imports the required libraries, including JAX, NumPyro, ArviZ, Matplotlib, and Optax. Sets up plotting styles and the device configuration, asserts the NumPyro version, and enables autoreload for the notebook.

```python
import arviz as az
import matplotlib.pyplot as plt
import numpy as np

import jax
from jax import Array, random
import jax.numpy as jnp
import optax

import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import load_bart_od
from numpyro.infer import SVI, Predictive, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.reparam import LocScaleReparam

plt.style.use("bmh")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"

numpyro.set_host_device_count(n=4)

rng_key = random.key(seed=42)

assert numpyro.__version__.startswith("0.21.0")

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
```

--------------------------------

### Get and Initialize Data Batch

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/tbip.ipynb

Retrieves the initial data batch used for training and parameter initialization. This ensures the array dimensions are consistent with the model's configuration.

```python
# Get the initial batch. This informs the dimensions of the arrays and ensures
# they are consistent with the dimensions (N, D, K, V) defined above.
Y_batch, I_batch, D_batch = tbip.get_batch(random.key(1), counts, author_indices)

# Initialize the parameters using the initial batch.
svi_state = svi_batch.init(
    random.key(0), Y_batch=Y_batch, d_batch=D_batch, i_batch=I_batch
)
```

--------------------------------

### AutoNormal and AutoDiagonalNormal

Source: https://github.com/pyro-ppl/numpyro/blob/master/docs/source/autoguide.rst

AutoNormal and AutoDiagonalNormal are basic mean-field guides. They automatically handle non-Euclidean latent spaces by using bijective transformations, which makes them good starting points for variational inference.

```APIDOC
## AutoNormal and AutoDiagonalNormal

### Description
These are basic mean-field guides that automatically handle non-Euclidean latent spaces by using appropriate bijective transformations. They are recommended as a starting point for variational inference.

### Usage
These guides can be instantiated directly and used with NumPyro's inference algorithms.
```

--------------------------------

### Install NumPyro with CUDA 13.x.y Support

Source: https://github.com/pyro-ppl/numpyro/blob/master/README.md

Install NumPyro with GPU support for CUDA 13.x.y. Make sure CUDA is installed on your system before running this command.

```bash
pip install 'numpyro[cuda13]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

--------------------------------

### Initialize TBIP Model and SVI

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/tbip.ipynb

Sets up the TBIP model and configures SVI for training, using the Adam optimizer with exponential learning-rate decay and the `TraceMeanField_ELBO` loss. The update function is JIT-compiled for performance.

```python
from jax import jit
from optax import adam, exponential_decay

from numpyro.infer import SVI, TraceMeanField_ELBO

num_steps = 50000
batch_size = 512  # Large batches are recommended
learning_rate = 0.01
decay_rate = 0.01

tbip = TBIP(
    N=num_authors,
    D=num_documents,
    K=num_topics,
    V=num_words,
    batch_size=batch_size,
    init_mu_theta=initial_document_loc,
    init_mu_beta=initial_objective_topic_loc,
)

svi_batch = SVI(
    model=tbip.model,
    guide=tbip.guide,
    optim=adam(exponential_decay(learning_rate, num_steps, decay_rate)),
    loss=TraceMeanField_ELBO(),
)

# Compile the update function for faster training.
svi_batch_update = jit(svi_batch.update)
```

--------------------------------

### Install NumPyro with CUDA 12.x.y Support

Source: https://github.com/pyro-ppl/numpyro/blob/master/README.md

Install NumPyro with GPU support for CUDA 12.x.y.
Make sure CUDA is installed on your system before running this command.

```bash
pip install 'numpyro[cuda12]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

--------------------------------

### AutoGuideList

Source: https://github.com/pyro-ppl/numpyro/blob/master/docs/source/autoguide.rst

AutoGuideList combines multiple automatic guides into a single composite guide.

```APIDOC
## AutoGuideList

### Description
This utility composes multiple automatic guides into a single, unified guide.

### Usage
Use AutoGuideList when you want to combine the strengths of different automatic guides for a more sophisticated posterior approximation.
```

--------------------------------

### Install NumPyro with Conda

Source: https://github.com/pyro-ppl/numpyro/blob/master/README.md

Install NumPyro with the conda package manager from the conda-forge channel.

```bash
conda install -c conda-forge numpyro
```

--------------------------------

### Run MCMC with NUTS sampler

Source: https://github.com/pyro-ppl/numpyro/blob/master/docs/source/contrib.rst

Initializes and runs an MCMC simulation with the NUTS sampler: it sets up the sampler and the MCMC parameters, runs the chains with the provided data and hyperparameters, and prints a summary of the run.

```python
sampler = NUTS(model)
mcmc = MCMC(sampler=sampler, num_warmup=500, num_samples=1_000, num_chains=2)

rng_key, rng_subkey = random.split(rng_key)

ell = 1.3
m = 20
non_centered = True

mcmc.run(rng_subkey, x, ell, m, non_centered, y_obs)

mcmc.print_summary()
```

--------------------------------

### AutoSemiDAIS

Source: https://github.com/pyro-ppl/numpyro/blob/master/docs/source/autoguide.rst

AutoSemiDAIS combines a parametric variational distribution over the global latent variables with Differentiable Annealed Importance Sampling (DAIS) for the local latent variables, offering a middle ground between AutoDAIS and simpler guides.

```APIDOC
## AutoSemiDAIS

.. autoclass:: numpyro.infer.autoguide.AutoSemiDAIS
    :members:
    :undoc-members:
    :show-inheritance:
    :member-order: bysource
```

--------------------------------

### Import Libraries and Initialize Keys

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/discrete_imputation.ipynb

Imports the libraries used in the example: JAX for numerical operations, NumPyro for Bayesian modeling, and Graphviz for visualization. It also initializes the random number generator keys.

```python
from math import inf

from graphviz import Digraph

from jax import numpy as jnp, random
from jax.scipy.special import expit

import numpyro
from numpyro import distributions as dist, sample
from numpyro.infer.hmc import NUTS
from numpyro.infer.mcmc import MCMC

simkeys = random.split(random.key(0), 10)
nsim = 5000
mcmc_key = random.key(1)
```

--------------------------------

### Configure and Run FlowMC Sampler

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/other_samplers.ipynb

Sets the sampling parameters, initializes the flowMC sampler with the previously defined local sampler and normalizing-flow model, and runs the sampling process.

```python
sampler_params = {
    "n_loop_training": 7,
    "n_loop_production": 7,
    "n_local_steps": 150,
    "n_global_steps": 100,
    "learning_rate": 0.001,
    "momentum": 0.9,
    "num_epochs": 30,
    "batch_size": 10_000,
    "use_global": True,
}

rng_key, rng_subkey = random.split(rng_key)

nf_sampler = Sampler(
    n_dim=n_dim,
    rng_key=rng_subkey,
    data=data,
    local_sampler=mala_sampler,
    nf_model=nf_model,
    **sampler_params,
)

nf_sampler.sample(initial_position_array, data)
rng_key, subkey = jax.random.split(rng_key)
nf_samples = nf_sampler.sample_flow(subkey, 5_000)
```

--------------------------------

### Sample Grid and Data

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/hsgp_nd_example.ipynb

Generates a 2D grid and corresponding data points for Gaussian process modeling.
Requires JAX for random key generation and the notebook's custom `sample_grid_and_data` function.

```python
import jax

seed = 0
key = jax.random.key(seed)
# `sample_grid_and_data` and its other arguments (N_grid, N, L, amplitude,
# lengthscale, noise, D) are defined earlier in the notebook.
X_grid, y_grid, X, y = sample_grid_and_data(
    N_grid, N, L, amplitude, lengthscale, noise, key, D
)
```

--------------------------------

### AutoDelta

Source: https://github.com/pyro-ppl/numpyro/blob/master/docs/source/autoguide.rst

AutoDelta is an automatic guide that places a point mass (a Delta distribution) on each latent variable, so fitting it with SVI yields a maximum a posteriori (MAP) estimate, that is, the posterior mode rather than the posterior mean. It is simple and fast but provides no measure of posterior uncertainty.

```APIDOC
## AutoDelta

.. autoclass:: numpyro.infer.autoguide.AutoDelta
    :members:
    :undoc-members:
    :show-inheritance:
    :member-order: bysource
```

--------------------------------

### AutoMultivariateNormal and AutoLowRankMultivariateNormal

Source: https://github.com/pyro-ppl/numpyro/blob/master/docs/source/autoguide.rst

These guides construct multivariate Normal variational distributions that can capture correlations in the posterior, which makes them more flexible than mean-field guides such as AutoNormal. AutoMultivariateNormal learns a full covariance matrix, whose number of parameters grows quadratically with the number of latent dimensions, while AutoLowRankMultivariateNormal uses a low-rank-plus-diagonal covariance to reduce this cost. Both can be difficult to fit in high-dimensional settings.

```APIDOC
## AutoMultivariateNormal and AutoLowRankMultivariateNormal

### Description
These guides capture correlations in the posterior through multivariate Normal variational distributions. They can be difficult to fit in high-dimensional settings.

### Usage
Use them when posterior correlations matter. They plug into NumPyro's SVI like any other automatic guide.
```

--------------------------------

### TBIP Guide for Variational Inference

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/tbip.ipynb

Defines the variational family for the TBIP model: a parameterized distribution for each latent variable sampled in the model. The guide must have the same call signature as the model.

```python
import jax.numpy as jnp
from jax import random

import numpyro.distributions as dist
from numpyro import param, plate, sample
from numpyro.distributions import constraints


# Method of the TBIP model class defined in the notebook.
def guide(self, Y_batch, d_batch, i_batch):
    # This defines the variational family. Notice that each of the latent
    # variables defined in the sample statements in the model above has a
    # corresponding sample statement in the guide. The guide is responsible
    # for providing variational parameters for each of these latent variables.

    # Also notice that the model and the guide must have the same call signature.

    mu_x = param(
        "mu_x", init_value=-1 + 2 * random.uniform(random.key(1), (self.N,))
    )
    sigma_x = param(
        "sigma_x", init_value=jnp.ones([self.N]), constraint=constraints.positive
    )

    mu_eta = param(
        "mu_eta", init_value=random.normal(random.key(2), (self.K, self.V))
    )
    sigma_eta = param(
        "sigma_eta",
        init_value=jnp.ones([self.K, self.V]),
        constraint=constraints.positive,
    )

    mu_theta = param("mu_theta", init_value=self.init_mu_theta)
    sigma_theta = param(
        "sigma_theta",
        init_value=jnp.ones([self.D, self.K]),
        constraint=constraints.positive,
    )

    mu_beta = param("mu_beta", init_value=self.init_mu_beta)
    sigma_beta = param(
        "sigma_beta",
        init_value=jnp.ones([self.K, self.V]),
        constraint=constraints.positive,
    )

    with plate("i", self.N):
        sample("x", dist.Normal(mu_x, sigma_x))

    with plate("k", size=self.K, dim=-2):
        with plate("k_v", size=self.V, dim=-1):
            sample("beta", dist.LogNormal(mu_beta, sigma_beta))
            sample("eta", dist.Normal(mu_eta, sigma_eta))

    with plate("d", size=self.D, subsample_size=self.batch_size, dim=-2):
        with plate("d_k", size=self.K, dim=-1):
            sample("theta", dist.LogNormal(mu_theta[d_batch], sigma_theta[d_batch]))
```

--------------------------------

### Initialize FlowMC Sampler Parameters

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/other_samplers.ipynb

Sets up the initial parameters for the FlowMC sampler: the number of dimensions, the number of chains, and the initial positions.
```python
n_dim = 3  # number of parameters
n_chains = 20  # number of chains
```

```python
data = {"x": x, "y": y}

# One random starting point per chain, shape (n_chains, n_dim).
rng_key, subkey = random.split(rng_key)
initial_position_array = jax.random.normal(subkey, shape=(n_chains, n_dim))
```

--------------------------------

### AutoLaplaceApproximation

Source: https://github.com/pyro-ppl/numpyro/blob/master/docs/source/autoguide.rst

This guide computes a Laplace approximation of the posterior distribution.

```APIDOC
## AutoLaplaceApproximation

### Description
Approximates the posterior with a multivariate Normal distribution in the unconstrained latent space. The Normal is centered at the MAP point found in that space, and its covariance is the inverse of the Hessian of the negative log posterior density at that point.

### Usage
Use it to obtain a cheap Gaussian approximation of the posterior.
```

--------------------------------

### Complete Case Analysis MCMC Execution

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/discrete_imputation.ipynb

Sets up and runs the MCMC kernel for the complete case analysis model and prints the summary.

```python
# `ccmodel` and the complete-case data `Acc`, `Bcc`, `Ycc` are defined earlier in the notebook.
cckernel = NUTS(ccmodel)
ccmcmc = MCMC(cckernel, num_warmup=250, num_samples=750)
ccmcmc.run(mcmc_key, Acc, Bcc, Ycc)
ccmcmc.print_summary()
```

--------------------------------

### prng_key

Source: https://github.com/pyro-ppl/numpyro/blob/master/docs/source/primitives.rst

Draws a pseudo-random number generator (PRNG) key from the enclosing `numpyro.handlers.seed` handler. Use it when a model needs a key for its own JAX random operations, so that results are reproducible under a fixed seed.

```APIDOC
## prng_key

### Description
Draws a fresh PRNG key from the enclosing `numpyro.handlers.seed` handler and returns it. Used for reproducible randomness.

### Function Signature
`prng_key()`

### Parameters
None.
```

--------------------------------

### Run NumPyro Development Checks

Source: https://github.com/pyro-ppl/numpyro/blob/master/CONTRIBUTING.md

Runs linting, code formatting, unit tests, and doctests. `make test` runs both linting and the unit tests.
```sh
make lint
```

```sh
make format
```

```sh
make test
```

```sh
make doctest
```

--------------------------------

### Define 8 Schools Data

Source: https://github.com/pyro-ppl/numpyro/blob/master/README.md

Defines the data for the Eight Schools example with NumPy: the number of schools `J`, the estimated treatment effect `y` for each school, and the standard error `sigma` of each estimate.

```python
import numpy as np

J = 8  # number of schools
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])  # estimated treatment effects
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])  # standard errors
```

--------------------------------

### Render and Display a NumPyro Model

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/model_rendering.ipynb

Renders the graphical model defined by the `model` function and displays it. Requires `graphviz` to be installed.

```python
import jax.numpy as jnp
import numpyro

# `model` is defined earlier in the notebook.
data = jnp.ones(10)
numpyro.render_model(model, model_args=(data,))
```

--------------------------------

### MCMC Sampling and Summary

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/bayesian_imputation.ipynb

Runs MCMC with the NUTS kernel to draw posterior samples of the model parameters and the imputed values, then prints a summary of the results.

```python
mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)
mcmc.run(random.key(0), **data, survived=survived)
mcmc.print_summary()
```

--------------------------------

### Get Samples from MCMC Run

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/bayesian_regression.ipynb

Retrieves the posterior samples from a completed MCMC run for further analysis and visualization.

```python
samples_2 = mcmc.get_samples()
```

--------------------------------

### Get Truncation Label

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/censoring.ipynb

Determines the censoring label of each observation from the lower and upper bounds. The censored model requires these labels.
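The snippet calls a `get_truncation_label` helper that is defined elsewhere in the notebook. The following is a minimal sketch of what such a helper might compute, assuming (our assumption, not taken from the notebook) that observations at or below `lower` are left-censored (label -1), observations at or above `upper` are right-censored (label 1), and values strictly in between are observed exactly (label 0).

```python
import numpy as np


def get_truncation_label(y, lower, upper):
    """Hypothetical sketch of a censoring-label helper.

    Returns -1 where y <= lower (left-censored), 1 where y >= upper
    (right-censored), and 0 for values strictly between the bounds.
    """
    y = np.asarray(y)
    return np.where(y <= lower, -1, np.where(y >= upper, 1, 0))


# Toy values clipped to the interval [0, 7].
labels = get_truncation_label(y=[0, 2, 5, 7, 7], lower=0, upper=7)
print(labels.tolist())  # [-1, 0, 0, 1, 1]
```

Comparing with `<=` and `>=` rather than `==` also labels values that fall outside the bounds, which is a design choice of this sketch.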
```python
truncation_label = get_truncation_label(
    y=censored_poisson_samples,
    lower=censored_poisson_data_params.lower,
    upper=censored_poisson_data_params.upper,
)
```

--------------------------------

### Prepare Data and Render MACE Model

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/model_rendering.ipynb

Prepares sample data for positions and annotations, adjusts them for 0-based indexing, and then renders the MACE model.

```python
positions = np.array([1, 1, 1, 2, 3, 4, 5])
# fmt: off
annotations = np.array([
    [1, 3, 1, 2, 2, 2, 1, 3, 2, 2, 4, 2, 1, 2, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 3, 1, 2, 2, 4, 2, 2, 3, 1, 1, 1, 2, 1, 2],
    [1, 3, 1, 2, 2, 2, 2, 3, 2, 3, 4, 2, 1, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 3, 1, 1, 1, 1, 3, 1, 2, 2, 3, 2, 3, 3, 1, 1, 2, 3, 2, 2],
    [1, 3, 2, 2, 2, 2, 2, 3, 2, 2, 4, 2, 1, 2, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 2, 1, 1, 2, 1, 3, 1, 2, 2, 3, 1, 2, 3, 1, 1, 1, 2, 1, 2],
    [1, 4, 2, 3, 3, 3, 2, 3, 2, 2, 4, 3, 1, 3, 1, 2, 1, 1, 2, 1, 2, 2, 3, 2, 1, 1, 2, 1, 1, 1, 1, 3, 1, 2, 3, 4, 2, 3, 3, 1, 1, 2, 2, 1, 2],
    [1, 3, 1, 1, 2, 3, 1, 4, 2, 2, 4, 3, 1, 2, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 1, 2, 2, 3, 2, 2, 4, 1, 1, 1, 2, 1, 2],
    [1, 3, 2, 2, 2, 2, 1, 3, 2, 2, 4, 4, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 3, 1, 2, 3, 4, 3, 3, 3, 1, 1, 1, 2, 1, 2],
    [1, 4, 2, 1, 2, 2, 1, 3, 3, 3, 4, 3, 1, 2, 1, 1, 1, 1, 1, 2, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 3, 1, 2, 2, 3, 2, 3, 2, 1, 1, 1, 2, 1, 2],
]).T
# fmt: on

# we subtract 1 because the first index starts with 0 in Python
positions -= 1
annotations -= 1

mace_graph = numpyro.render_model(mace, model_args=(positions, annotations))
```

--------------------------------

### Load Collections Email Campaign Data

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/effect_handlers.ipynb

Loads the dataset for the email campaign example from a GitHub URL using pandas.
```python
import pandas as pd

# Load data
data_url = (
    "https://raw.githubusercontent.com/matheusfacure/"
    "python-causality-handbook/master/causal-inference-for-the-brave-and-true"
    "/data/collections_email.csv"
)
df = pd.read_csv(data_url)
df.head()
```

--------------------------------

### Define Tiny Dataset

Source: https://github.com/pyro-ppl/numpyro/blob/master/notebooks/source/gmm.ipynb

Defines a small, one-dimensional dataset with five points used for the Gaussian Mixture Model example.

```python
import jax.numpy as jnp

data = jnp.array([0.0, 1.0, 10.0, 11.0, 12.0])
```

--------------------------------

### AutoSurrogateLikelihoodDAIS

Source: https://github.com/pyro-ppl/numpyro/blob/master/docs/source/autoguide.rst

AutoSurrogateLikelihoodDAIS is an automatic guide that combines a user-provided surrogate likelihood with Differentiable Annealed Importance Sampling (DAIS). Unlike AutoDAIS, it can be used together with data subsampling.

```APIDOC
## AutoSurrogateLikelihoodDAIS

.. autoclass:: numpyro.infer.autoguide.AutoSurrogateLikelihoodDAIS
    :members:
    :undoc-members:
    :show-inheritance:
    :member-order: bysource
```

--------------------------------

### Build HTML Pages

Source: https://github.com/pyro-ppl/numpyro/blob/master/docs/README.md

Generates the HTML version of the documentation after the reStructuredText files have been created.

```bash
make html
```

--------------------------------

### AutoDAIS

Source: https://github.com/pyro-ppl/numpyro/blob/master/docs/source/autoguide.rst

AutoDAIS is an automatic guide that uses Differentiable Annealed Importance Sampling (DAIS) to construct a guide over the entire latent space.

```APIDOC
## AutoDAIS

.. autoclass:: numpyro.infer.autoguide.AutoDAIS
    :members:
    :undoc-members:
    :show-inheritance:
    :member-order: bysource
```

--------------------------------

### AutoSurrogateLikelihoodDAIS

Source: https://github.com/pyro-ppl/numpyro/blob/master/docs/source/autoguide.rst

A variational inference algorithm that uses HMC-style annealing dynamics and supports data subsampling, which makes it practical for large datasets.

```APIDOC
## AutoSurrogateLikelihoodDAIS

### Description
Combines HMC-style annealing dynamics with support for data subsampling, offering efficient variational inference on large datasets.

### Usage
Use it for large datasets when you want HMC-based variational inference with data subsampling.
```