Converting NumPyro objects to DataTree#
DataTree is the data format ArviZ relies on.
This page covers multiple ways to generate a DataTree from NumPyro MCMC and SVI objects.
See also
Conversion from Python, numpy or pandas objects
DataTree for Exploratory Analysis of Bayesian Models for an overview of
DataTree and its role within ArviZ.
We will start by importing the required packages and defining the model: the famous eight schools model.
import arviz_base as az
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, autoguide, Predictive
from jax import random
import jax.numpy as jnp
J = 8
y_obs = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
def eight_schools_model(J, sigma, y=None):
    mu = numpyro.sample("mu", dist.Normal(0, 5))
    tau = numpyro.sample("tau", dist.HalfCauchy(5))
    with numpyro.plate("J", J):
        eta = numpyro.sample("eta", dist.Normal(0, 1))
        theta = numpyro.deterministic("theta", mu + tau * eta)
        return numpyro.sample("obs", dist.Normal(theta, sigma), obs=y)
def eight_schools_custom_guide(J, sigma, y=None):
    # Variational parameters for mu
    mu_loc = numpyro.param("mu_loc", 0.0)
    mu_scale = numpyro.param("mu_scale", 1.0, constraint=dist.constraints.positive)
    mu = numpyro.sample("mu", dist.Normal(mu_loc, mu_scale))
    # Variational parameters for tau (positive support)
    tau_loc = numpyro.param("tau_loc", 1.0)
    tau_scale = numpyro.param("tau_scale", 0.5, constraint=dist.constraints.positive)
    tau = numpyro.sample("tau", dist.LogNormal(jnp.log(tau_loc), tau_scale))
    # Variational parameters for eta
    eta_loc = numpyro.param("eta_loc", jnp.zeros(J))
    eta_scale = numpyro.param("eta_scale", jnp.ones(J), constraint=dist.constraints.positive)
    with numpyro.plate("J", J):
        eta = numpyro.sample("eta", dist.Normal(eta_loc, eta_scale))
    # Deterministic transform
    numpyro.deterministic("theta", mu + tau * eta)
Convert from MCMC#
This first example shows conversion from an MCMC object.
# fit with MCMC
nuts = NUTS(eight_schools_model)
mcmc = MCMC(nuts, num_warmup=1000, num_samples=1000, num_chains=4)
mcmc.run(random.PRNGKey(0), J=J, sigma=sigma, y=y_obs, extra_fields=("num_steps", "energy"))
# sample the posterior predictive
predictive = Predictive(eight_schools_model, mcmc.get_samples())
samples_predictive = predictive(random.PRNGKey(1), J=J, sigma=sigma)
# convert to DataTree
idata_mcmc = az.from_numpyro(mcmc, posterior_predictive=samples_predictive)
idata_mcmc
/var/folders/3n/bm6t53l15kddzf7prg_kj3140000gn/T/ipykernel_86443/3262796440.py:3: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
mcmc = MCMC(nuts, num_warmup = 1000, num_samples = 1000, num_chains=4)
sample: 100%|██████████| 2000/2000 [00:00<00:00, 3284.68it/s, 7 steps of size 3.56e-01. acc. prob=0.89]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 8009.51it/s, 7 steps of size 4.66e-01. acc. prob=0.88]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 7861.79it/s, 7 steps of size 5.04e-01. acc. prob=0.84]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 7739.55it/s, 7 steps of size 3.94e-01. acc. prob=0.90]
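As the warning above notes, only one CPU device is available, so the four chains are drawn sequentially. To run them in parallel, follow the warning's suggestion and set the host device count at the very start of the program, before any JAX computation:

# must run at program start, before JAX initializes its devices
import numpyro
numpyro.set_host_device_count(4)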
(Interactive DataTree repr of idata_mcmc, listing its groups, omitted here.)
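The groups of the resulting DataTree can be accessed by name. As a quick sanity check (an illustrative aside, not part of the original notebook):

# illustrative: groups behave like nested datasets
idata_mcmc["posterior"]  # posterior draws with chain and draw dimensions
idata_mcmc["posterior_predictive"]  # the predictive samples passed in above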
Convert from SVI with Autoguide#
eight_schools_guide = autoguide.AutoNormal(eight_schools_model, init_loc_fn=numpyro.infer.init_to_median(num_samples=100))
svi = SVI(
    eight_schools_model,
    guide=eight_schools_guide,
    optim=numpyro.optim.Adam(0.01),
    loss=Trace_ELBO(),
)
svi_result = svi.run(random.PRNGKey(0), num_steps=10000, J=J, sigma=sigma, y=y_obs)
# sample the posterior predictive
predictive_svi = Predictive(eight_schools_model, guide=eight_schools_guide, params=svi_result.params, num_samples=4000)
samples_predictive_svi = predictive_svi(random.PRNGKey(1), J=J, sigma=sigma)
idata_svi = az.from_numpyro_svi(
    svi,
    svi_result=svi_result,
    model_kwargs=dict(J=J, sigma=sigma, y=y_obs),  # SVI requires providing the fit args/kwargs
    num_samples=4000,  # number of posterior samples to draw
    posterior_predictive=samples_predictive_svi,
)
idata_svi
idata_svi
100%|██████████| 10000/10000 [00:00<00:00, 10095.92it/s, init loss: 53.6608, avg. loss [9501-10000]: 31.6204]
(Interactive DataTree repr of idata_svi, listing its groups, omitted here.)
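Because both conversions produce the same structure, the SVI approximation can be compared directly against the MCMC posterior. A minimal sketch (illustrative, not part of the original notebook):

# illustrative: compare the posterior mean of mu under MCMC and SVI
print(float(idata_mcmc["posterior"]["mu"].mean()))
print(float(idata_svi["posterior"]["mu"].mean()))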
Converting from SVI with a custom guide function#
svi_custom_guide = SVI(
    eight_schools_model,
    guide=eight_schools_custom_guide,
    optim=numpyro.optim.Adam(0.01),
    loss=Trace_ELBO(),
)
svi_custom_guide_result = svi_custom_guide.run(random.PRNGKey(0), num_steps=10000, J=J, sigma=sigma, y=y_obs)
# sample the posterior predictive
predictive_svi_custom = Predictive(
    eight_schools_model,
    guide=eight_schools_custom_guide,
    params=svi_custom_guide_result.params,  # use the params fitted with the custom guide
    num_samples=4000,
)
samples_predictive_svi_custom = predictive_svi_custom(random.PRNGKey(1), J=J, sigma=sigma)
idata_svi_custom_guide = az.from_numpyro_svi(
    svi_custom_guide,
    svi_result=svi_custom_guide_result,
    model_kwargs=dict(J=J, sigma=sigma, y=y_obs),  # SVI requires providing the fit args/kwargs
    num_samples=4000,  # number of posterior samples to draw
    posterior_predictive=samples_predictive_svi_custom,
)
idata_svi_custom_guide
100%|██████████| 10000/10000 [00:00<00:00, 10246.71it/s, init loss: 34.9525, avg. loss [9501-10000]: 31.6279]
(Interactive DataTree repr of idata_svi_custom_guide, listing its groups, omitted here.)
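With a custom guide, the optimized variational parameters keep the names defined in the guide and can be read directly off the result object (an illustrative aside, not part of the original notebook):

# illustrative: fitted variational parameters, keyed by the names used in the guide
for name, value in svi_custom_guide_result.params.items():
    print(name, value)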
Automatically Labelling Event Dims#
NumPyro batch dims are automatically labelled according to their corresponding plate names. To label event dims as well, add infer={"event_dims": dim_labels} to the numpyro.sample statement, as shown below:
def eight_schools_model_zsn(J, sigma, y=None):
    mu = numpyro.sample("mu", dist.Normal(0, 5))
    tau = numpyro.sample("tau", dist.HalfCauchy(5))
    eta = numpyro.sample(
        "eta",
        dist.ZeroSumNormal(tau, event_shape=(J,)),
        # note: this allows ArviZ to infer the event dimension labels
        infer={"event_dims": ["J"]},
    )
    with numpyro.plate("J", J):
        theta = numpyro.deterministic("theta", mu + eta)
        return numpyro.sample("obs", dist.Normal(theta, sigma), obs=y)
# fit with MCMC
nuts = NUTS(eight_schools_model_zsn)
mcmc2 = MCMC(nuts, num_warmup=1000, num_samples=1000, num_chains=4)
mcmc2.run(random.PRNGKey(0), J=J, sigma=sigma, y=y_obs, extra_fields=("num_steps", "energy"))
# sample the posterior predictive
predictive2 = Predictive(eight_schools_model_zsn, mcmc2.get_samples())
samples_predictive2 = predictive2(random.PRNGKey(1), J=J, sigma=sigma)
# convert to DataTree
idata_mcmc2 = az.from_numpyro(mcmc2, posterior_predictive=samples_predictive2)
/var/folders/3n/bm6t53l15kddzf7prg_kj3140000gn/T/ipykernel_86443/306760900.py:17: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
mcmc2 = MCMC(nuts, num_warmup = 1000, num_samples = 1000, num_chains=4)
sample: 100%|██████████| 2000/2000 [00:00<00:00, 3119.59it/s, 15 steps of size 2.12e-01. acc. prob=0.82]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 7054.28it/s, 15 steps of size 2.12e-01. acc. prob=0.86]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 7063.72it/s, 15 steps of size 2.82e-01. acc. prob=0.83]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 6924.13it/s, 31 steps of size 1.85e-01. acc. prob=0.91]
Notice that eta is now labelled with the J dimension:
idata_mcmc2
(Interactive DataTree repr of idata_mcmc2, listing its groups, omitted here.)
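The labelling can also be checked programmatically (illustrative, not part of the original notebook):

# illustrative: the event dimension of eta should carry the "J" label
idata_mcmc2["posterior"]["eta"].dims  # expected: ('chain', 'draw', 'J')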
%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Wed Oct 29 2025
Python implementation: CPython
Python version : 3.12.10
IPython version : 9.4.0
arviz_base: 0.7.0.dev0
numpyro : 0.19.0
jax : 0.6.2
numpy : 2.3.2
Watermark: 2.5.0