Converting NumPyro objects to DataTree#

DataTree is the data format ArviZ relies on.

This page covers multiple ways to generate a DataTree from NumPyro MCMC and SVI objects.

See also

We will start by importing the required packages and defining the model: the famous eight schools model.

import arviz_base as az
import numpy as np

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, autoguide, Predictive
from jax import random
import jax.numpy as jnp
# Observed outcome for each school (passed to the models below as `y`).
y_obs = np.array(
    [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]
)
# Per-school scale of the observation noise (used as the Normal scale
# of the "obs" site in the models below).
sigma = np.array(
    [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]
)
# Number of schools; this is the size of the "J" plate.
J = 8
def eight_schools_model(J, sigma, y=None):
    """Eight schools model written in terms of a standardised effect ``eta``.

    Sample sites, in order: ``mu ~ Normal(0, 5)``, ``tau ~ HalfCauchy(5)``
    and, inside the ``"J"`` plate, ``eta ~ Normal(0, 1)``. The per-school
    effect ``theta = mu + tau * eta`` is recorded as a deterministic site and
    ``obs ~ Normal(theta, sigma)`` is conditioned on ``y`` when it is given.

    Parameters
    ----------
    J
        Size of the ``"J"`` plate.
    sigma
        Scale of the Normal distribution at the ``"obs"`` site.
    y
        Observed values for the ``"obs"`` site. When ``None`` the site is
        sampled instead of conditioned on.

    Returns
    -------
    The value of the ``"obs"`` sample site.
    """
    population_mean = numpyro.sample("mu", dist.Normal(0, 5))
    population_scale = numpyro.sample("tau", dist.HalfCauchy(5))
    with numpyro.plate("J", J):
        standardised_effect = numpyro.sample("eta", dist.Normal(0, 1))
        school_effect = numpyro.deterministic(
            "theta", population_mean + population_scale * standardised_effect
        )
        observed = numpyro.sample("obs", dist.Normal(school_effect, sigma), obs=y)
    return observed
    

def eight_schools_custom_guide(J, sigma, y=None):
    """Hand-written variational guide for ``eight_schools_model``.

    Declares one location and one scale parameter per latent site and samples
    ``mu`` (Normal), ``tau`` (LogNormal) and ``eta`` (Normal, inside the
    ``"J"`` plate) from them. ``theta = mu + tau * eta`` is also recorded as a
    deterministic site.

    Parameters
    ----------
    J
        Size of the ``"J"`` plate and length of the ``eta`` parameters.
    sigma, y
        Unused here; accepted so the guide has the same signature as the model.
    """
    # Variational parameters for mu
    mu_loc = numpyro.param("mu_loc", 0.0)
    mu_scale = numpyro.param("mu_scale", 1.0, constraint=dist.constraints.positive)
    mu = numpyro.sample("mu", dist.Normal(mu_loc, mu_scale))

    # Variational parameters for tau (positive support)
    # tau_loc is passed through jnp.log below, so it must stay strictly
    # positive; without the constraint the optimiser can move it to <= 0 and
    # the LogNormal location becomes NaN/-inf.
    tau_loc = numpyro.param("tau_loc", 1.0, constraint=dist.constraints.positive)
    tau_scale = numpyro.param("tau_scale", 0.5, constraint=dist.constraints.positive)
    tau = numpyro.sample("tau", dist.LogNormal(jnp.log(tau_loc), tau_scale))

    # Variational parameters for eta
    eta_loc = numpyro.param("eta_loc", jnp.zeros(J))
    eta_scale = numpyro.param("eta_scale", jnp.ones(J), constraint=dist.constraints.positive)
    with numpyro.plate("J", J):
        eta = numpyro.sample("eta", dist.Normal(eta_loc, eta_scale))

        # Deterministic transform
        numpyro.deterministic("theta", mu + tau * eta)

Convert from MCMC#

This first example shows conversion from an MCMC object.

# fit with MCMC
nuts = NUTS(eight_schools_model)
mcmc = MCMC(nuts, num_warmup = 1000, num_samples = 1000, num_chains=4)
# also request the "num_steps" and "energy" fields from the sampler
mcmc.run(random.PRNGKey(0), J=J, sigma=sigma, y=y_obs, extra_fields=("num_steps", "energy"),)

# sample the posterior predictive
# (y is not passed, so the "obs" site is sampled instead of conditioned on)
predictive = Predictive(eight_schools_model, mcmc.get_samples())
samples_predictive = predictive(random.PRNGKey(1), J=J, sigma=sigma)

# Convert to DataTree
idata_mcmc = az.from_numpyro(mcmc, posterior_predictive=samples_predictive)
idata_mcmc
/var/folders/3n/bm6t53l15kddzf7prg_kj3140000gn/T/ipykernel_86443/3262796440.py:3: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
  mcmc = MCMC(nuts, num_warmup = 1000, num_samples = 1000, num_chains=4)
sample: 100%|██████████| 2000/2000 [00:00<00:00, 3284.68it/s, 7 steps of size 3.56e-01. acc. prob=0.89] 
sample: 100%|██████████| 2000/2000 [00:00<00:00, 8009.51it/s, 7 steps of size 4.66e-01. acc. prob=0.88]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 7861.79it/s, 7 steps of size 5.04e-01. acc. prob=0.84]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 7739.55it/s, 7 steps of size 3.94e-01. acc. prob=0.90]
<xarray.DatasetView> Size: 0B
Dimensions:  ()
Data variables:
    *empty*

Convert from SVI with Autoguide#

# Automatic Normal guide for the model, initialised with `init_to_median`.
eight_schools_guide = autoguide.AutoNormal(
    eight_schools_model,
    init_loc_fn=numpyro.infer.init_to_median(num_samples=100),
)
svi = SVI(
    eight_schools_model,
    guide=eight_schools_guide,
    optim=numpyro.optim.Adam(0.01),
    loss=Trace_ELBO(),
)
svi_result = svi.run(
    random.PRNGKey(0),
    num_steps=10000,
    J=J,
    sigma=sigma,
    y=y_obs,
)

# sample the posterior predictive using the fitted guide parameters
predictive_svi = Predictive(
    eight_schools_model,
    guide=eight_schools_guide,
    params=svi_result.params,
    num_samples=4000,
)
samples_predictive_svi = predictive_svi(random.PRNGKey(1), J=J, sigma=sigma)


idata_svi = az.from_numpyro_svi(
    svi,
    svi_result=svi_result,
    # SVI requires providing the fit args/kwargs
    model_kwargs={"J": J, "sigma": sigma, "y": y_obs},
    # number of samples to draw in the posterior
    num_samples=4000,
    posterior_predictive=samples_predictive_svi,
)
idata_svi
100%|██████████| 10000/10000 [00:00<00:00, 10095.92it/s, init loss: 53.6608, avg. loss [9501-10000]: 31.6204]
<xarray.DatasetView> Size: 0B
Dimensions:  ()
Data variables:
    *empty*

Converting from SVI with a custom guide function#

svi_custom_guide = SVI(
    eight_schools_model,
    guide=eight_schools_custom_guide,
    optim=numpyro.optim.Adam(0.01),
    loss=Trace_ELBO(),
)
svi_custom_guide_result = svi_custom_guide.run(
    random.PRNGKey(0), num_steps=10000, J=J, sigma=sigma, y=y_obs
)

# sample the posterior predictive
# The params must come from the fit of *this* guide
# (`svi_custom_guide_result`). The autoguide's `svi_result.params` belong to
# a different guide, so the custom guide would not be run with its fitted
# parameters.
predictive_svi_custom = Predictive(
    eight_schools_model,
    guide=eight_schools_custom_guide,
    params=svi_custom_guide_result.params,
    num_samples=4000,
)
samples_predictive_svi_custom = predictive_svi_custom(random.PRNGKey(1), J=J, sigma=sigma)

idata_svi_custom_guide = az.from_numpyro_svi(
    svi_custom_guide,
    svi_result=svi_custom_guide_result,
    model_kwargs=dict(J=J, sigma=sigma, y=y_obs),  # SVI requires providing the fit args/kwargs
    num_samples=4000,  # number of samples to draw in the posterior
    posterior_predictive=samples_predictive_svi_custom,
)
idata_svi_custom_guide
100%|██████████| 10000/10000 [00:00<00:00, 10246.71it/s, init loss: 34.9525, avg. loss [9501-10000]: 31.6279]
<xarray.DatasetView> Size: 0B
Dimensions:  ()
Data variables:
    *empty*

Automatically Labelling Event Dims#

NumPyro batch dims are automatically labelled according to their corresponding plate names. In order to label event dims, we add infer={"event_dims": dim_labels} to the numpyro.sample statement as shown below:

def eight_schools_model_zsn(J, sigma, y=None):
    """Eight schools model whose school offsets ``eta`` are a ZeroSumNormal.

    ``eta`` is sampled outside any plate from
    ``ZeroSumNormal(tau, event_shape=(J,))``, so its length-``J`` axis is an
    event dimension. The ``infer={"event_dims": ["J"]}`` entry supplies the
    label for that dimension. ``theta = mu + eta`` is recorded as a
    deterministic site and ``obs ~ Normal(theta, sigma)`` is conditioned on
    ``y`` when it is given.

    Parameters
    ----------
    J
        Length of the ``eta`` event dimension and size of the ``"J"`` plate.
    sigma
        Scale of the Normal distribution at the ``"obs"`` site.
    y
        Observed values for the ``"obs"`` site, or ``None`` to sample it.

    Returns
    -------
    The value of the ``"obs"`` sample site.
    """
    population_mean = numpyro.sample("mu", dist.Normal(0, 5))
    population_scale = numpyro.sample("tau", dist.HalfCauchy(5))
    offsets_dist = dist.ZeroSumNormal(population_scale, event_shape=(J,))
    # note: this allows arviz to infer the event dimension labels
    event_dim_labels = {"event_dims": ["J"]}
    school_offsets = numpyro.sample("eta", offsets_dist, infer=event_dim_labels)
    with numpyro.plate("J", J):
        school_effect = numpyro.deterministic("theta", population_mean + school_offsets)
        observed = numpyro.sample("obs", dist.Normal(school_effect, sigma), obs=y)
    return observed


# fit with MCMC
nuts = NUTS(eight_schools_model_zsn)
mcmc2 = MCMC(nuts, num_warmup=1000, num_samples=1000, num_chains=4)
mcmc2.run(
    random.PRNGKey(0),
    J=J,
    sigma=sigma,
    y=y_obs,
    extra_fields=("num_steps", "energy"),
)


# sample the posterior predictive
# Use the same model that produced the posterior (`eight_schools_model_zsn`).
# `eight_schools_model` defines theta as mu + tau * eta, which does not match
# the posterior samples drawn here.
predictive2 = Predictive(eight_schools_model_zsn, mcmc2.get_samples())
samples_predictive2 = predictive2(random.PRNGKey(1), J=J, sigma=sigma)

# Convert to DataTree
idata_mcmc2 = az.from_numpyro(mcmc2, posterior_predictive=samples_predictive2)
/var/folders/3n/bm6t53l15kddzf7prg_kj3140000gn/T/ipykernel_86443/306760900.py:17: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
  mcmc2 = MCMC(nuts, num_warmup = 1000, num_samples = 1000, num_chains=4)
sample: 100%|██████████| 2000/2000 [00:00<00:00, 3119.59it/s, 15 steps of size 2.12e-01. acc. prob=0.82]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 7054.28it/s, 15 steps of size 2.12e-01. acc. prob=0.86]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 7063.72it/s, 15 steps of size 2.82e-01. acc. prob=0.83]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 6924.13it/s, 31 steps of size 1.85e-01. acc. prob=0.91]

Notice that, in the output below, eta is labelled appropriately with the J dimension.

# display the converted object to inspect the dimension labels
idata_mcmc2
<xarray.DatasetView> Size: 0B
Dimensions:  ()
Data variables:
    *empty*
# Record the date and the Python / package versions used to run this page.
%load_ext watermark
%watermark -n -u -v -iv -w
The watermark extension is already loaded. To reload it, use:
  %reload_ext watermark
Last updated: Wed Oct 29 2025

Python implementation: CPython
Python version       : 3.12.10
IPython version      : 9.4.0

arviz_base: 0.7.0.dev0
numpyro   : 0.19.0
jax       : 0.6.2
numpy     : 2.3.2

Watermark: 2.5.0