arviz_base.from_numpyro

arviz_base.from_numpyro(posterior=None, *, prior=None, posterior_predictive=None, predictions=None, constant_data=None, predictions_constant_data=None, log_likelihood=None, index_origin=None, coords=None, dims=None, pred_dims=None, extra_event_dims=None, num_chains=1)

Convert NumPyro data into a DataTree object.

For a usage example, read Converting NumPyro objects to DataTree.

If dims is not provided, batch dimension names are inferred from the plates of the NumPyro model. For event dimension names, such as those of a ZeroSumNormal, pass infer={"event_dims": dim_names} to numpyro.sample, for example:

import numpyro
import numpyro.distributions as dist

def model(n_groups):
    # equivalent to the dims entry {"gamma": ["groups"]}
    gamma = numpyro.sample(
        "gamma",
        dist.ZeroSumNormal(1, event_shape=(n_groups,)),
        infer={"event_dims": ["groups"]},
    )

The extra_event_dims argument covers remaining edge cases, such as deterministic sites with event dims, which have no infer argument for providing this metadata.

Parameters:
posterior : numpyro.infer.MCMC, optional

Fitted MCMC object from NumPyro.

prior : dict, optional

Prior samples from a NumPyro model.

posterior_predictive : dict, optional

Posterior predictive samples for the posterior.

predictions : dict, optional

Out-of-sample predictions.

constant_data : dict, optional

Dictionary containing constant data variables mapped to their values.

predictions_constant_data : dict, optional

Constant data used for out-of-sample predictions.

index_origin : int, optional

coords : dict, optional

Map of dimensions to coordinates.

dims : dict of {str : list of str}, optional

Map of variable names to their dimension names. Inferred from the model if not provided.

pred_dims : dict, optional

Dims for the predictions data: maps variable names to their dimension names. Inferred if not provided.

extra_event_dims : dict, optional

Extra event dims for deterministic sites: maps variable names to the names of event dims that could not be inferred.

num_chains : int, default 1

Number of chains used for sampling. Ignored if posterior is present.

Returns:
xarray.DataTree