NumPyro Documentation
Uber AI Labs
CHAPTER 1
Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
Docs and Examples | Forum
NumPyro is a small probabilistic programming library that provides a NumPy backend for Pyro. We rely on JAX for
automatic differentiation and JIT compilation to GPU / CPU. This is an alpha release under active development, so
beware of brittleness, bugs, and changes to the API as the design evolves.
NumPyro is designed to be lightweight and focuses on providing a flexible substrate that users can build on:
• Pyro Primitives: NumPyro programs can contain regular Python and NumPy code, in addition to Pyro primitives like sample and param. The model code should look very similar to Pyro except for some minor differences between PyTorch and NumPy's API. See the example below.
• Inference algorithms: NumPyro currently supports Hamiltonian Monte Carlo, including an implementation of
the No U-Turn Sampler. One of the motivations for NumPyro was to speed up Hamiltonian Monte Carlo by
JIT compiling the verlet integrator that includes multiple gradient computations. With JAX, we can compose
jit and grad to compile the entire integration step into an XLA optimized kernel. We also eliminate Python
overhead by JIT compiling the entire tree building stage in NUTS (this is possible using Iterative NUTS).
There is also a basic Variational Inference implementation for reparameterized distributions together with many
flexible (auto)guides for Automatic Differentiation Variational Inference (ADVI).
• Distributions: The numpyro.distributions module provides distribution classes, constraints and bijective trans-
forms. The distribution classes wrap over samplers implemented to work with JAX’s functional pseudo-random
number generator. The design of the distributions module largely follows from PyTorch. A major subset of
the API is implemented, and it contains most of the common distributions that exist in PyTorch. As a result,
Pyro and PyTorch users can rely on the same API and batching semantics as in torch.distributions.
In addition to distributions, constraints and transforms are very useful when operating on distribution
classes with bounded support.
• Effect handlers: Like Pyro, primitives like sample and param can be given nonstandard interpretations using effect handlers from the numpyro.handlers module, and these can be easily extended to implement custom inference algorithms and inference utilities.
Let us explore NumPyro using a simple example. We will use the eight schools example from Gelman et al., Bayesian
Data Analysis: Sec. 5.5, 2003, which studies the effect of coaching on SAT performance in eight schools.
The data is given by:
>>> J = 8
>>> y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
>>> sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
where y are the observed treatment effects and sigma their standard errors. We build a hierarchical model for the study where we
assume that the group-level parameters theta for each school are sampled from a Normal distribution with unknown
mean mu and standard deviation tau, while the observed data are in turn generated from a Normal distribution with
mean and standard deviation given by theta (true effect) and sigma, respectively. This allows us to estimate the
population-level parameters mu and tau by pooling from all the observations, while still allowing for individual
variation amongst the schools using the group-level theta parameters.
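In code, such a model might look like the following sketch (the priors on mu and tau here are illustrative choices and may differ from those in the original example):

>>> import numpy as np
>>> import numpyro
>>> import numpyro.distributions as dist

>>> # Centered parameterization of the eight schools model (sketch).
>>> def eight_schools(J, sigma, y=None):
...     mu = numpyro.sample('mu', dist.Normal(0., 5.))
...     tau = numpyro.sample('tau', dist.HalfCauchy(5.))
...     with numpyro.plate('J', J):
...         theta = numpyro.sample('theta', dist.Normal(mu, tau))
...         numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)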
Let us infer the values of the unknown parameters in our model by running MCMC using the No-U-Turn Sampler
(NUTS). Note the usage of the extra_fields argument in MCMC.run. By default, we only collect samples
from the target (posterior) distribution when we run inference using MCMC. However, collecting additional fields like
potential energy or the acceptance probability of a sample can be easily achieved by using the extra_fields
argument. For a list of possible fields that can be collected, see the HMCState object. In this example, we will
additionally collect the potential_energy for each sample.
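A sketch of what that MCMC run might look like (the warmup and sample counts are illustrative):

>>> from jax import random
>>> from numpyro.infer import MCMC, NUTS

>>> nuts_kernel = NUTS(eight_schools)
>>> mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
>>> rng_key = random.PRNGKey(0)
>>> mcmc.run(rng_key, J, sigma, y=y, extra_fields=('potential_energy',))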
We can print the summary of the MCMC run, and examine if we observed any divergences during inference. Additionally, since we collected the potential energy for each of the samples, we can easily compute the expected log joint
density.
>>> mcmc.print_summary()
Number of divergences: 19
>>> pe = mcmc.get_extra_fields()['potential_energy']
The values above 1 for the split Gelman-Rubin diagnostic (r_hat) indicate that the chain has not fully converged. The low value for the effective sample size (n_eff), particularly for tau, and the number of divergent transitions look problematic. Fortunately, this is a common pathology that can be rectified by using a non-centered parameterization for tau in our model. This is straightforward to do in NumPyro by using a TransformedDistribution instance
together with a reparameterization effect handler. Let us rewrite the same model but instead of sampling theta from
a Normal(mu, tau), we will instead sample it from a base Normal(0, 1) distribution that is transformed using
an AffineTransform. Note that by doing so, NumPyro runs HMC by generating samples theta_base for the base
Normal(0, 1) distribution instead. We see that the resulting chain does not suffer from the same pathology — the
Gelman Rubin diagnostic is 1 for all the parameters and the effective sample size looks quite good!
...             'theta',
...             dist.transforms.AffineTransform(mu, tau)))
>>> mcmc.print_summary(exclude_deterministic=False)
Number of divergences: 0
>>> pe = mcmc.get_extra_fields()['potential_energy']
Note that for the class of distributions with loc, scale parameters such as Normal, Cauchy, StudentT, we also provide a LocScaleReparam reparameterizer that achieves the same purpose. The corresponding code is sketched below.
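A possible version using LocScaleReparam (illustrative; centered=0 requests a fully non-centered parameterization, and the priors repeat the ones used above):

>>> from numpyro.infer.reparam import LocScaleReparam

>>> def eight_schools_noncentered(J, sigma, y=None):
...     mu = numpyro.sample('mu', dist.Normal(0., 5.))
...     tau = numpyro.sample('tau', dist.HalfCauchy(5.))
...     with numpyro.plate('J', J):
...         with numpyro.handlers.reparam(config={'theta': LocScaleReparam(centered=0)}):
...             theta = numpyro.sample('theta', dist.Normal(mu, tau))
...         numpyro.sample('obs', dist.Normal(theta, sigma), obs=y)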
Now, let us assume that we have a new school for which we have not observed any test scores, but we would like to
generate predictions. NumPyro provides a Predictive class for such a purpose. Note that in the absence of any observed
data, we simply use the population-level parameters to generate predictions. The Predictive utility conditions the
unobserved mu and tau sites to values drawn from the posterior distribution from our last MCMC run, and runs the
model forward to generate predictions.
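A sketch of this workflow (the new-school model and its priors below are illustrative):

>>> from numpyro.infer import Predictive

>>> def new_school():
...     mu = numpyro.sample('mu', dist.Normal(0., 5.))
...     tau = numpyro.sample('tau', dist.HalfCauchy(5.))
...     return numpyro.sample('obs', dist.Normal(mu, tau))

>>> predictive = Predictive(new_school, mcmc.get_samples())
>>> samples_predictive = predictive(random.PRNGKey(1))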
>>> print(np.mean(samples_predictive['obs']))
3.9886456
For some more examples on specifying models and doing inference in NumPyro:
• Bayesian Regression in NumPyro - Start here to get acquainted with writing a simple model in NumPyro,
MCMC inference API, effect handlers and writing custom inference utilities.
• Time Series Forecasting - Illustrates how to convert for loops in the model to JAX’s lax.scan primitive for
fast inference.
• Baseball example - Using NUTS for a simple hierarchical model. Compare this with the baseball example in
Pyro.
• Hidden Markov Model in NumPyro as compared to Stan.
• Variational Autoencoder - As a simple example that uses Variational Inference with neural networks. Pyro
implementation for comparison.
• Gaussian Process - Provides a simple example to use NUTS to sample from the posterior over the hyper-
parameters of a Gaussian Process.
• Statistical Rethinking with NumPyro - Notebooks containing a translation of the code in Richard McElreath's Statistical Rethinking book (second edition) to NumPyro.
• Other model examples can be found in the examples folder.
Pyro users will note that the API for model specification and inference is largely the same as Pyro, including the distributions API, by design. However, there are some important core differences (reflected in the internals) that users should be aware of. For example, in NumPyro there is no global parameter store or random state; this makes it possible to leverage JAX's JIT compilation. Also, users may need to write their models in a more functional style that works better with JAX. Refer to FAQs for a list of differences.
1.4 Installation
Limited Windows Support: Note that NumPyro is untested on Windows, and might require building
jaxlib from source. See this JAX issue for more details. Alternatively, you can install Windows Subsystem
for Linux and use NumPyro on it as on a Linux system. See also CUDA on Windows Subsystem for Linux
if you want to use GPUs on Windows.
To install NumPyro with a CPU version of JAX, you can use pip:
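For example, the basic CPU-only install (pinning matching jax and jaxlib versions may also be needed):

    pip install numpyro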
To use NumPyro on the GPU, you will need to first install jax and jaxlib with CUDA support.
To run NumPyro on Cloud TPUs, you can use pip to install NumPyro as above and setup the TPU backend as detailed
here.
Default Platform: In contrast to JAX, which uses GPU as the default platform, we use CPU as the
default platform. You can use set_platform utility to switch to other platforms such as GPU or TPU at the
beginning of your program.
You can also install NumPyro from source:
or as a higher order function:

```python
def fn():
    x = numpyro.sample('x', dist.Beta(1., 1.))  # prior reconstructed as a plausible choice; the original line for x was garbled
    y = numpyro.sample('y', dist.Bernoulli(x))
    return y

print(handlers.seed(fn, rng_seed=0)())
```
2. Can I use the same Pyro model for doing inference in NumPyro?
As you may have noticed from the examples, NumPyro supports all Pyro primitives like sample, param, plate and module, and effect handlers. Additionally, we have ensured that the distributions API is based on torch.distributions, and the inference classes like SVI and MCMC have the same interface. This, along with the similarity in the API for NumPy and PyTorch operations, ensures that models containing Pyro primitive statements can be used with either backend with some minor changes. Examples of some differences, along with the changes needed, are noted below:
• Any torch operation in your model will need to be written in terms of the corresponding jax.numpy opera-
tion. Additionally, not all torch operations have a numpy counterpart (and vice-versa), and sometimes there
are minor differences in the API.
• pyro.sample statements outside an inference context will need to be wrapped in a seed handler, as men-
tioned above.
• There is no global parameter store, and as such using numpyro.param outside an inference context will have
no effect. To retrieve the optimized parameter values from SVI, use the SVI.get_params method. Note that you
can still use param statements inside a model and NumPyro will use the substitute effect handler internally to
substitute values from the optimizer when running the model in SVI.
• PyTorch neural network modules will need to be rewritten as stax neural networks. See the VAE example for differences in syntax between the two backends.
• JAX works best with functional code, particularly if we would like to leverage JIT compilation, which NumPyro does internally for many inference subroutines. As such, if your model has side-effects that are not visible to the JAX tracer, it may need to be rewritten in a more functional style.
For most small models, changes required to run inference in NumPyro should be minor. Additionally, we are working
on pyro-api which allows you to write the same code and dispatch it to multiple backends, including NumPyro. This
will necessarily be more restrictive, but has the advantage of being backend agnostic. See the documentation for an
example, and let us know your feedback.
3. How can I contribute to the project?
Thanks for your interest in the project! You can take a look at beginner-friendly issues that are marked with the good first issue tag on Github. Also, please feel free to reach out to us on the forum.
In the near term, we plan to work on the following; please open new issues for feature requests and enhancements. If you use NumPyro, please consider citing:
@article{phan2019composable,
year={2019}
as well as
@article{bingham2018pyro,
author = {Bingham, Eli and Chen, Jonathan P. and Jankowiak, Martin and Obermeyer, Fritz and
          Pradhan, Neeraj and Karaletsos, Theofanis and Singh, Rohit and Szerlip, Paul and
year = {2018}
API Reference
2.1 Modeling
param
sample
Note: By design, the sample primitive is meant to be used inside a NumPyro model. A seed handler is then used to inject a random state into fn; in those situations, the rng_key keyword has no effect.
Parameters
• name (str) – name of the sample site.
• fn – a stochastic function that returns a sample.
• obs (numpy.ndarray) – observed value
• rng_key (jax.random.PRNGKey) – an optional random key for fn.
• sample_shape – Shape of samples to be drawn.
• infer (dict) – an optional dictionary containing additional information for inference algorithms. For example, if fn is a discrete distribution, setting infer={'enumerate': 'parallel'} tells MCMC to marginalize out this discrete latent site.
Returns sample from the stochastic fn.
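For illustration, a minimal sketch of the two ways sample can obtain randomness outside an inference run (the site names and priors are arbitrary):

from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro import handlers

# draw directly by passing an explicit rng_key (no handler on the stack)
x = numpyro.sample('x', dist.Normal(0., 1.), rng_key=random.PRNGKey(0))

# or run the statement under a seed handler, which injects the random state
with handlers.seed(rng_seed=0):
    y = numpyro.sample('y', dist.Normal(0., 1.))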
plate
Parameters
• name (str) – Name of the plate.
• size (int) – Size of the plate.
• subsample_size (int) – Optional argument denoting the size of the mini-batch. This can be used by inference algorithms to apply a scaling factor, e.g. when computing the ELBO using a mini-batch.
• dim (int) – Optional argument to specify which dimension in the tensor is used as the
plate dim. If None (default), the leftmost available dim is allocated.
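A small sketch of plate with subsampling (the model and sizes are illustrative):

import numpyro
import numpyro.distributions as dist

def model(data):
    loc = numpyro.sample('loc', dist.Normal(0., 1.))
    # `ind` holds the indices of a random mini-batch of size 10; log densities
    # inside the plate are rescaled to account for the subsampling.
    with numpyro.plate('data', len(data), subsample_size=10) as ind:
        batch = data[ind]
        numpyro.sample('obs', dist.Normal(loc, 1.), obs=batch)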
plate_stack
Parameters
• prefix (str) – Name prefix for plates.
• sizes (iterable) – An iterable of plate sizes.
• rightmost_dim (int) – The rightmost dim, counting from the right.
subsample
subsample(data, event_dim)
EXPERIMENTAL Subsampling statement to subsample data based on enclosing plate s.
This is typically called on arguments to model() when subsampling is performed automatically by plates by passing the subsample_size kwarg. For example, the following two versions are equivalent:
# Version 1. using indexing
def model(data):
    with numpyro.plate("data", len(data), subsample_size=10, dim=-data.dim()) as ind:
        data = data[ind]
        # ...
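The second, equivalent version uses numpyro.subsample inside the plate; a sketch mirroring the indexing version above:

# Version 2. using numpyro.subsample (sketch)
def model(data):
    with numpyro.plate("data", len(data), subsample_size=10, dim=-data.dim()):
        data = numpyro.subsample(data, event_dim=0)
        # ...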
Parameters
• data (numpy.ndarray) – A tensor of batched data.
• event_dim (int) – The event dimension of the data tensor. Dimensions to the left are
considered batch dimensions.
Returns A subsampled version of data
Return type ndarray
deterministic
deterministic(name, value)
Used to designate deterministic sites in the model. Note that most effect handlers will not operate on determin-
istic sites (except trace()), so deterministic sites should be side-effect free. The use case for deterministic
nodes is to record any values in the model execution trace.
Parameters
• name (str) – name of the deterministic site.
• value (numpy.ndarray) – deterministic value to record in the trace.
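For example, a minimal sketch that records a derived quantity:

import numpyro
import numpyro.distributions as dist

def model():
    x = numpyro.sample('x', dist.Normal(0., 1.))
    # recorded in the trace (and in MCMC summaries), but contributes no log density
    numpyro.deterministic('x_squared', x ** 2)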
prng_key
prng_key()
A statement to draw a pseudo-random number generator key PRNGKey() under seed handler.
Returns a PRNG key of shape (2,) and dtype uint32.
factor
factor(name, log_factor)
Factor statement to add arbitrary log probability factor to a probabilistic model.
Parameters
• name (str) – Name of the trivial sample.
• log_factor (numpy.ndarray) – A possibly batched log probability factor.
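For example, a minimal sketch that adds a custom penalty term to the joint log density:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model():
    x = numpyro.sample('x', dist.Normal(0., 1.))
    # an arbitrary extra log-probability term; the penalty below is illustrative
    numpyro.factor('smoothness_penalty', -jnp.abs(x))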
module
flax_module
haiku_module
random_flax_module
Note: Parameters of a Flax module are stored in a nested dict. For example, the module B defined as follows:
class A(nn.Module):
def apply(self, x):
return nn.Dense(x, 1, bias=False, name='dense')
class B(nn.Module):
def apply(self, x):
return A(x, name='inner')
has parameters {'inner': {'dense': {'kernel': param_value}}}. In the argument prior, to specify the kernel parameter, we join the path to it using dots: prior={"inner.dense.kernel": param_prior}.
Parameters
• name (str) – name of NumPyro module
• flax.nn.Module – the module to be registered with NumPyro
• prior – a NumPyro distribution or a Python dict with parameter names as keys and re-
spective distributions as values. For example:
net = random_flax_module("net",
                         flax.nn.Dense.partial(features=1),
                         prior={"bias": dist.Cauchy(), "kernel": dist.Normal()},
                         input_shape=(4,))
Example
>>>
>>> class Net(nn.Module):
... def apply(self, x, n_units):
... x = nn.Dense(x[..., None], features=n_units)
... x = nn.relu(x)
... x = nn.Dense(x, features=n_units)
... x = nn.relu(x)
... mean = nn.Dense(x, features=1)
... rho = nn.Dense(x, features=1)
... return mean.squeeze(), rho.squeeze()
>>>
>>> def generate_data(n_samples):
... x = np.random.normal(size=n_samples)
... y = np.cos(x * 3) + np.random.normal(size=n_samples) * np.abs(x) / 2
... return x, y
>>>
>>> def model(x, y=None, batch_size=None):
... module = Net.partial(n_units=32)
...     net = random_flax_module("nn", module, dist.Normal(0, 0.1), input_shape=())
random_haiku_module
net = random_haiku_module("net",
                          haiku.transform(lambda x: hk.Linear(1)(x)),
                          prior={"linear.b": dist.Cauchy(), "linear.w": dist.Normal()},
                          input_shape=(4,))
scan
Warning: This is an experimental utility function that allows users to use JAX control flow with NumPyro’s
effect handlers. Currently, sample and deterministic sites within the scan body f are supported. If you notice
that any effect handlers or distributions are unsupported, please file an issue.
Note: It is ambiguous to align the scan dimension inside a plate context, so the following pattern is not supported:
with numpyro.plate('N', 10):
last, ys = scan(f, init, xs)
All plate statements should be put inside f. For example, the corresponding working code is
def g(*args, **kwargs):
    with numpyro.plate('N', 10):
        return f(*args, **kwargs)
Note: We can scan over discrete latent variables in f. The joint density is evaluated using parallel-scan (refer-
ence [1]) over time dimension, which reduces parallel complexity to O(log(length)).
A trace of scan with discrete latent variables will contain the following sites:
• init sites: these sites belong to the first history traces of f. Sites at the i-th trace will have names prefixed with '_PREV_' * (2 * history - 1 - i).
• scanned sites: these sites collect the values of the remaining scan loop over f. An additional time dimension _time_foo will be added to those sites, where foo is the name of the first site that appears in f.
Not all transition functions f are supported. All of the restrictions from Pyro's enumeration tutorial [2] still apply here. In addition, no site outside of scan should depend on the first output of scan (the last carry value).
References
1. Temporal Parallelization of Bayesian Smoothers, Simo Sarkka, Angel F. Garcia-Fernandez (https://arxiv.org/abs/1905.13002)
2. Inference with Discrete Latent Variables (http://pyro.ai/examples/enumeration.html#Dependencies-among-plates)
Parameters
• f (callable) – a function to be scanned.
• init – the initial carrying state
• xs – the values over which we scan along the leading axis. This can be any JAX pytree (e.g.
list/dict of arrays).
• length – optional value specifying the length of the scan; this can be used when xs is an empty pytree (e.g. None).
• reverse (bool) – optional boolean specifying whether to run the scan iteration forward
(the default) or in reverse
• history (int) – The number of previous contexts visible from the current context. De-
faults to 1. If zero, this is similar to numpyro.plate.
Returns output of scan, quoted from jax.lax.scan() docs: “pair of type (c, [b]) where the
first element represents the final loop carry value and the second element represents the stacked
outputs of the second output of f when scanned over the leading axis of the inputs”.
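As an illustration, a minimal sketch of a Gaussian random walk written with scan (the model and length are arbitrary):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan

def model(T=10):
    def transition(carry, _):
        z_prev = carry
        # one latent per time step; scan stacks them along a leading time dimension
        z = numpyro.sample('z', dist.Normal(z_prev, 1.))
        return z, z

    z_init = jnp.zeros(())
    _, zs = scan(transition, z_init, None, length=T)
    return zs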
This provides a small set of effect handlers in NumPyro that are modeled after Pyro’s poutine module. For a tutorial on
effect handlers more generally, readers are encouraged to read Poutine: A Guide to Programming with Effect Handlers
in Pyro. These simple effect handlers can be composed together or new ones added to enable implementation of custom
inference utilities and algorithms.
Example
As an example, we are using seed, trace and substitute handlers to define the log_likelihood function below.
We first create a logistic regression model and sample from the posterior distribution over the regression parameters
using MCMC(). The log_likelihood function uses effect handlers to run the model by substituting sample sites with
values from the posterior distribution and computes the log density for a single data point. The log_predictive_density
function computes the log likelihood for each draw from the joint posterior and aggregates the results for all the data
points, but does so by using JAX’s auto-vectorize transform called vmap so that we do not need to loop over all the
data points.
>>> N, D = 3000, 3
>>> def logistic_regression(data, labels):
... coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(D), jnp.ones(D)))
... intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
... logits = jnp.sum(coefs * data + intercept, axis=-1)
... return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)
>>> mcmc.print_summary()
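The helper functions described above might look roughly like the following sketch (names and details illustrate the pattern rather than reproduce the original code exactly):

>>> from jax import random, vmap
>>> import jax.numpy as jnp
>>> from jax.scipy.special import logsumexp
>>> from numpyro import handlers

>>> def log_likelihood(rng_key, params, model, *args, **kwargs):
...     # substitute posterior values into the seeded model and read off the obs site
...     model = handlers.substitute(handlers.seed(model, rng_key), params)
...     model_trace = handlers.trace(model).get_trace(*args, **kwargs)
...     obs_node = model_trace['obs']
...     return obs_node['fn'].log_prob(obs_node['value'])

>>> def log_predictive_density(rng_key, params, model, *args, **kwargs):
...     n = list(params.values())[0].shape[0]
...     # vectorize over posterior draws instead of looping
...     log_lk_fn = vmap(lambda key, p: log_likelihood(key, p, model, *args, **kwargs))
...     log_lk_vals = log_lk_fn(random.split(rng_key, n), params)
...     return jnp.sum(logsumexp(log_lk_vals, 0) - jnp.log(n))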
block
process_message(msg)
collapse
condition
process_message(msg)
do
This is equivalent to replacing z = numpyro.sample("z", ...) with z = 1. and introducing a fresh sample site numpyro.sample("z", ...) whose value is not used elsewhere.
References:
1. Single World Intervention Graphs: A Primer, Thomas Richardson, James Robins
Parameters
• fn – a stochastic function (callable containing Pyro primitive calls)
• data – a dict mapping sample site names to interventions
Example:
process_message(msg)
infer_config
lift
lift makes param statements behave like sample statements using the distributions in prior. In this example, site s will now behave as if it was replaced with s = numpyro.sample("s", dist.Exponential(0.3)).
Parameters
• fn – function whose parameters will be lifted to random values
• prior – prior function in the form of a Distribution or a dict of Distributions
process_message(msg)
mask
reparam
Parameters config (dict or callable) – Configuration, either a dict mapping site name to
Reparam , or a function mapping site to Reparam or None.
process_message(msg)
replay
process_message(msg)
scale
scope
Parameters
• fn – Python callable with NumPyro primitives.
• prefix (str) – a string to prepend to sample names
• divider (str) – a string to join the prefix and sample name; default to ‘/’
process_message(msg)
seed
Note: Unlike in Pyro, the numpyro.sample primitive cannot be used without wrapping it in a seed handler, since there is no global random state. As such, users need to use seed as a context manager to generate samples from distributions, or as a decorator for their model callable (see below).
Example:
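A minimal sketch of both usages (the model and seed values are arbitrary):

import numpyro
import numpyro.distributions as dist
from numpyro import handlers

def model():
    return numpyro.sample('x', dist.Normal(0., 1.))

# as a context manager
with handlers.seed(rng_seed=0):
    x1 = model()

# as a wrapper around the model callable
x2 = handlers.seed(model, rng_seed=1)()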
process_message(msg)
substitute
process_message(msg)
trace
class trace(fn=None)
Bases: numpyro.primitives.Messenger
Returns a handler that records the inputs and outputs at primitive calls inside fn.
Example
'is_observed': False,
'kwargs': {'rng_key': DeviceArray([0, 0], dtype=uint32)},
'name': 'a',
'type': 'sample',
'value': DeviceArray(-0.20584235, dtype=float32)})])
postprocess_message(msg)
get_trace(*args, **kwargs)
Run the wrapped callable and return the recorded trace.
Parameters
• *args – arguments to the callable.
• **kwargs – keyword arguments to the callable.
Returns OrderedDict containing the execution trace.
2.2 Distributions
Distribution
arg_constraints = {}
support = None
has_enumerate_support = False
is_discrete = False
reparametrized_params = []
tree_flatten()
classmethod tree_unflatten(aux_data, params)
static set_default_validate_args(value)
batch_shape
Returns the shape over which the distribution parameters are batched.
Returns batch shape of the distribution.
Return type tuple
event_shape
Returns the shape of a single sample from the distribution without batching.
Returns event shape of the distribution.
Return type tuple
event_dim
Returns Number of dimensions of individual events.
Return type int
has_rsample
rsample(key, sample_shape=())
shape(sample_shape=())
The tensor shape of samples from this distribution.
Samples are of shape: sample_shape + batch_shape + event_shape.
Parameters sample_shape (tuple) – the size of the iid batch to be drawn from the distri-
bution.
Returns shape of samples.
Return type tuple
sample(key, sample_shape=())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape.
Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned
sample will be filled with iid draws from the distribution instance.
Parameters
• key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
• sample_shape (tuple) – the sample shape for the distribution.
Returns an array of shape sample_shape + batch_shape + event_shape
Return type numpy.ndarray
sample_with_intermediates(key, sample_shape=())
Same as sample except that any intermediate computations are returned (useful for TransformedDistri-
bution).
Parameters
• key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
• sample_shape (tuple) – the sample shape for the distribution.
Returns an array of shape sample_shape + batch_shape + event_shape
Return type numpy.ndarray
log_prob(value)
Evaluates the log probability density for a batch of samples given by value.
Parameters value – A batch of samples from the distribution.
Returns an array with the batch shape of value, i.e. value.shape with the rightmost self.event_dim dimensions removed
Return type numpy.ndarray
mean
Mean of the distribution.
variance
Variance of the distribution.
to_event(reinterpreted_batch_ndims=None)
Interpret the rightmost reinterpreted_batch_ndims batch dimensions as dependent event dimensions.
Parameters reinterpreted_batch_ndims – Number of rightmost batch dims to inter-
pret as event dims.
Returns An instance of Independent distribution.
Return type numpyro.distributions.distribution.Independent
enumerate_support(expand=True)
Returns an array with shape len(support) x batch_shape containing all values in the support.
expand(batch_shape)
Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.
Parameters batch_shape (tuple) – batch shape to expand to.
Returns an instance of ExpandedDistribution.
Return type ExpandedDistribution
expand_by(sample_shape)
Expands a distribution by adding sample_shape to the left side of its batch_shape. To expand
internal dims of self.batch_shape from 1 to something larger, use expand() instead.
Parameters sample_shape (tuple) – The size of the iid batch to be drawn from the distri-
bution.
Returns An expanded version of this distribution.
Return type ExpandedDistribution
mask(mask)
Masks a distribution by a boolean or boolean-valued array that is broadcastable to the distribution's batch_shape.
Parameters mask (bool or jnp.ndarray) – A boolean or boolean valued array (True
includes a site, False excludes a site).
Returns A masked copy of this distribution.
Return type MaskedDistribution
ExpandedDistribution
ImproperUniform
Note: sample method is not implemented for this distribution. In autoguide and mcmc, initial parameters for
improper sites are derived from init_to_uniform or init_to_value strategies.
Usage:
...
... # real matrix with shape (3, 4)
... y = sample('y', ImproperUniform(constraints.real, (), event_shape=(3, 4)))
...
...     # a shape-(6, 8) batch of length-5 vectors greater than 3
...     z = sample('z', ImproperUniform(constraints.greater_than(3), (6, 8), event_shape=(5,)))
If you want to set improper prior over all values greater than a, where a is another random variable, you might
use
Parameters
• support (Constraint) – the support of this distribution.
• batch_shape (tuple) – batch shape of this distribution. It is usually safe to set
batch_shape=().
• event_shape (tuple) – event shape of this distribution.
arg_constraints = {}
log_prob(*args, **kwargs)
tree_flatten()
Independent
Parameters
• base_distribution (numpyro.distribution.Distribution) – a distribu-
tion instance.
• reinterpreted_batch_ndims (int) – the number of batch dims to reinterpret as
event dims.
arg_constraints = {}
support
has_enumerate_support
bool(x) -> bool
Returns True when the argument x is true, False otherwise. The builtins True and False are the only two
instances of the class bool. The class bool is a subclass of the class int, and cannot be subclassed.
is_discrete
bool(x) -> bool
Returns True when the argument x is true, False otherwise. The builtins True and False are the only two
instances of the class bool. The class bool is a subclass of the class int, and cannot be subclassed.
reparametrized_params
mean
Mean of the distribution.
variance
Variance of the distribution.
has_rsample
rsample(key, sample_shape=())
sample(key, sample_shape=())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape.
Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned
sample will be filled with iid draws from the distribution instance.
Parameters
• key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
• sample_shape (tuple) – the sample shape for the distribution.
Returns an array of shape sample_shape + batch_shape + event_shape
Return type numpy.ndarray
log_prob(value)
Evaluates the log probability density for a batch of samples given by value.
Parameters value – A batch of samples from the distribution.
Returns an array with the batch shape of value, i.e. value.shape with the rightmost self.event_dim dimensions removed
Return type numpy.ndarray
expand(batch_shape)
Returns a new ExpandedDistribution instance with batch dimensions expanded to batch_shape.
Parameters batch_shape (tuple) – batch shape to expand to.
Returns an instance of ExpandedDistribution.
Return type ExpandedDistribution
tree_flatten()
classmethod tree_unflatten(aux_data, params)
MaskedDistribution
TransformedDistribution
Delta
Unit
Beta
Cauchy
sample(key, sample_shape=())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape.
Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned
sample will be filled with iid draws from the distribution instance.
Parameters
• key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
• sample_shape (tuple) – the sample shape for the distribution.
Returns an array of shape sample_shape + batch_shape + event_shape
Return type numpy.ndarray
log_prob(*args, **kwargs)
mean
Mean of the distribution.
variance
Variance of the distribution.
Chi2
Dirichlet
variance
Variance of the distribution.
Exponential
Gamma
mean
Mean of the distribution.
variance
Variance of the distribution.
Gumbel
GaussianRandomWalk
HalfCauchy
HalfNormal
Parameters
• key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
• sample_shape (tuple) – the sample shape for the distribution.
Returns an array of shape sample_shape + batch_shape + event_shape
Return type numpy.ndarray
log_prob(*args, **kwargs)
mean
Mean of the distribution.
variance
Variance of the distribution.
InverseGamma
Note: We keep the same notation rate as in Pyro, but it plays the role of the scale parameter of InverseGamma in the literature (e.g. wikipedia: https://en.wikipedia.org/wiki/Inverse-gamma_distribution).
Laplace
LKJ
LKJCholesky
factor proportional to $\det(M)^{\eta - 1}$. Because of that, when concentration == 1, we have a uniform distribution over Cholesky factors of correlation matrices.
When concentration > 1, the distribution favors samples with large diagonal entries (hence large determinant). This is useful when we know a priori that the underlying variables are not correlated.
When concentration < 1, the distribution favors samples with small diagonal entries (hence small determinant). This is useful when we know a priori that some underlying variables are correlated.
Parameters
• dimension (int) – dimension of the matrices
• concentration (ndarray) – concentration/shape parameter of the distribution (often
referred to as eta)
• sample_method (str) – Either “cvine” or “onion”. Both methods are proposed in [1]
and offer the same distribution over correlation matrices. But they are different in how to
generate samples. Defaults to “onion”.
References
[1] Generating random correlation matrices based on vines and extended onion method, Daniel Lewandowski,
Dorota Kurowicka, Harry Joe
arg_constraints = {'concentration': <numpyro.distributions.constraints._GreaterThan object>}
reparametrized_params = ['concentration']
support = <numpyro.distributions.constraints._CorrCholesky object>
sample(key, sample_shape=())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape.
Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned
sample will be filled with iid draws from the distribution instance.
Parameters
• key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
• sample_shape (tuple) – the sample shape for the distribution.
Returns an array of shape sample_shape + batch_shape + event_shape
Return type numpy.ndarray
log_prob(*args, **kwargs)
tree_flatten()
classmethod tree_unflatten(aux_data, params)
LogNormal
variance
Variance of the distribution.
tree_flatten()
Logistic
MultivariateNormal
LowRankMultivariateNormal
Normal
Pareto
StudentT
Parameters
• key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
• sample_shape (tuple) – the sample shape for the distribution.
Returns an array of shape sample_shape + batch_shape + event_shape
Return type numpy.ndarray
log_prob(*args, **kwargs)
mean
Mean of the distribution.
variance
Variance of the distribution.
TruncatedCauchy
TruncatedNormal
TruncatedPolyaGamma
Uniform
Bernoulli
BernoulliLogits
BernoulliProbs
BetaBinomial
variance
Variance of the distribution.
support
Binomial
BinomialLogits
BinomialProbs
sample(key, sample_shape=())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape.
Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned
sample will be filled with iid draws from the distribution instance.
Parameters
• key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
• sample_shape (tuple) – the sample shape for the distribution.
Returns an array of shape sample_shape + batch_shape + event_shape
Return type numpy.ndarray
log_prob(*args, **kwargs)
logits
mean
Mean of the distribution.
variance
Variance of the distribution.
support
enumerate_support(expand=True)
Returns an array with shape len(support) x batch_shape containing all values in the support.
Categorical
CategoricalLogits
mean
Mean of the distribution.
variance
Variance of the distribution.
support
enumerate_support(expand=True)
Returns an array with shape len(support) x batch_shape containing all values in the support.
CategoricalProbs
DirichletMultinomial
GammaPoisson
Geometric
GeometricLogits
GeometricProbs
sample(key, sample_shape=())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape.
Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned
sample will be filled with iid draws from the distribution instance.
Parameters
• key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
• sample_shape (tuple) – the sample shape for the distribution.
Returns an array of shape sample_shape + batch_shape + event_shape
Return type numpy.ndarray
log_prob(*args, **kwargs)
logits
mean
Mean of the distribution.
variance
Variance of the distribution.
Multinomial
MultinomialLogits
MultinomialProbs
OrderedLogistic
Parameters
• predictor (numpy.ndarray) – prediction in real domain; typically this is output of a
linear model.
• cutpoints (numpy.ndarray) – positions in real domain to separate categories.
Poisson
PRNGIdentity
class PRNGIdentity
Bases: numpyro.distributions.distribution.Distribution
Distribution over PRNGKey(). This can be used to draw a batch of PRNGKey() using the seed handler.
Only sample method is supported.
is_discrete = True
sample(key, sample_shape=())
Returns a sample from the distribution having shape given by sample_shape + batch_shape + event_shape.
Note that when sample_shape is non-empty, leading dimensions (of size sample_shape) of the returned
sample will be filled with iid draws from the distribution instance.
Parameters
• key (jax.random.PRNGKey) – the rng_key key to be used for the distribution.
• sample_shape (tuple) – the sample shape for the distribution.
Returns an array of shape sample_shape + batch_shape + event_shape
Return type numpy.ndarray
ZeroInflatedPoisson
VonMises
Thin wrappers around TensorFlow Probability (TFP) distributions. For details on the TFP distribution interface, see
its Distribution docs.
BijectorConstraint
class BijectorConstraint(bijector)
A constraint which is codomain of a TensorFlow bijector.
Parameters bijector (Bijector) – a TensorFlow bijector
BijectorTransform
class BijectorTransform(bijector)
A wrapper for TensorFlow bijectors to make them compatible with NumPyro’s transforms.
Parameters bijector (Bijector) – a TensorFlow bijector
TFPDistributionMixin
Autoregressive
BatchReshape
Bates
Bernoulli
Beta
BetaBinomial
Binomial
Blockwise
Categorical
Cauchy
Chi
Chi2
CholeskyLKJ
ContinuousBernoulli
DeterminantalPointProcess
Deterministic
Dirichlet
DirichletMultinomial
DoublesidedMaxwell
Empirical
ExpGamma
ExpInverseGamma
ExpRelaxedOneHotCategorical
Exponential
ExponentiallyModifiedGaussian
FiniteDiscrete
Gamma
GammaGamma
GaussianProcess
GaussianProcessRegressionModel
GeneralizedExtremeValue
GeneralizedNormal
GeneralizedPareto
Geometric
Gumbel
HalfCauchy
HalfNormal
HalfStudentT
HiddenMarkovModel
Horseshoe
Independent
InverseGamma
InverseGaussian
JohnsonSU
JointDistribution
JointDistributionCoroutine
JointDistributionCoroutineAutoBatched
JointDistributionNamed
JointDistributionNamedAutoBatched
JointDistributionSequential
JointDistributionSequentialAutoBatched
Kumaraswamy
LKJ
Laplace
LinearGaussianStateSpaceModel
LogLogistic
LogNormal
Logistic
LogitNormal
MixtureSameFamily
Moyal
Multinomial
MultivariateNormalDiag
MultivariateNormalDiagPlusLowRank
MultivariateNormalFullCovariance
MultivariateNormalLinearOperator
MultivariateNormalTriL
MultivariateStudentTLinearOperator
NegativeBinomial
Normal
OneHotCategorical
OrderedLogistic
PERT
Pareto
PlackettLuce
Poisson
PoissonLogNormalQuadratureCompound
PowerSpherical
ProbitBernoulli
QuantizedDistribution
RelaxedBernoulli
RelaxedOneHotCategorical
Sample
SinhArcsinh
Skellam
SphericalUniform
StoppingRatioLogistic
StudentT
StudentTProcess
TransformedDistribution
Triangular
TruncatedCauchy
TruncatedNormal
Uniform
VariationalGaussianProcess
VectorDeterministic
VectorExponentialDiag
VonMises
VonMisesFisher
Weibull
WishartLinearOperator
WishartTriL
2.2.6 Constraints
Constraint
class Constraint
Bases: object
Abstract base class for constraints.
A constraint object represents a region over which a variable is valid, e.g. within which a variable can be
optimized.
event_dim = 0
check(value)
Returns a byte tensor of sample_shape + batch_shape indicating whether each event in value satisfies this
constraint.
feasible_like(prototype)
Get a feasible value which has the same shape and dtype as prototype.
boolean
corr_cholesky
corr_matrix
dependent
greater_than
greater_than(lower_bound)
Abstract base class for constraints.
A constraint object represents a region over which a variable is valid, e.g. within which a variable can be
optimized.
integer_interval
integer_interval(lower_bound, upper_bound)
Abstract base class for constraints.
A constraint object represents a region over which a variable is valid, e.g. within which a variable can be
optimized.
integer_greater_than
integer_greater_than(lower_bound)
Abstract base class for constraints.
A constraint object represents a region over which a variable is valid, e.g. within which a variable can be
optimized.
interval
interval(lower_bound, upper_bound)
Abstract base class for constraints.
A constraint object represents a region over which a variable is valid, e.g. within which a variable can be
optimized.
less_than
less_than(upper_bound)
Abstract base class for constraints.
A constraint object represents a region over which a variable is valid, e.g. within which a variable can be
optimized.
lower_cholesky
multinomial
multinomial(upper_bound)
Abstract base class for constraints.
A constraint object represents a region over which a variable is valid, e.g. within which a variable can be
optimized.
nonnegative_integer
ordered_vector
positive
positive_definite
positive_integer
real
real_vector
simplex
unit_interval
2.2.7 Transforms
biject_to
biject_to(constraint)
Transform
class Transform
Bases: object
domain = <numpyro.distributions.constraints._Real object>
codomain = <numpyro.distributions.constraints._Real object>
event_dim
inv
log_abs_det_jacobian(x, y, intermediates=None)
call_with_intermediates(x)
forward_shape(shape)
Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.
inverse_shape(shape)
Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.
AbsTransform
class AbsTransform
Bases: numpyro.distributions.transforms.Transform
domain = <numpyro.distributions.constraints._Real object>
codomain = <numpyro.distributions.constraints._GreaterThan object>
AffineTransform
Note: When scale is a JAX tracer, we always assume that scale > 0 when calculating codomain.
codomain
log_abs_det_jacobian(x, y, intermediates=None)
forward_shape(shape)
Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.
inverse_shape(shape)
Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.
ComposeTransform
class ComposeTransform(parts)
Bases: numpyro.distributions.transforms.Transform
domain
codomain
log_abs_det_jacobian(x, y, intermediates=None)
call_with_intermediates(x)
forward_shape(shape)
Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.
inverse_shape(shape)
Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.
CorrCholeskyTransform
class CorrCholeskyTransform
Bases: numpyro.distributions.transforms.Transform
Transforms an unconstrained real vector x with length D * (D - 1) / 2 into the Cholesky factor of a D-dimensional correlation matrix. This Cholesky factor is a lower triangular matrix with positive diagonal and unit Euclidean norm for each row. The transform is processed as follows:
1. First we convert x into a lower triangular matrix in the following order:

   \[
   \begin{bmatrix}
   1 & 0 & 0 & 0 \\
   x_0 & 1 & 0 & 0 \\
   x_1 & x_2 & 1 & 0 \\
   x_3 & x_4 & x_5 & 1
   \end{bmatrix}
   \]

2. For each row $X_i$ of the lower triangular part, we apply a signed version of class StickBreakingTransform to transform $X_i$ into a unit Euclidean length vector using the following steps:

   a. Scales into the interval $(-1, 1)$ domain: $r_i = \tanh(X_i)$.
   b. Transforms into an unsigned domain: $z_i = r_i^2$.
   c. Applies $s_i = \mathrm{StickBreakingTransform}(z_i)$.
   d. Transforms back into the signed domain: $y_i = (\operatorname{sign}(r_i), 1) * \sqrt{s_i}$.
domain = <numpyro.distributions.constraints._IndependentConstraint object>
codomain = <numpyro.distributions.constraints._CorrCholesky object>
log_abs_det_jacobian(x, y, intermediates=None)
forward_shape(shape)
Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.
inverse_shape(shape)
Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.
ExpTransform
IdentityTransform
class IdentityTransform
Bases: numpyro.distributions.transforms.Transform
log_abs_det_jacobian(x, y, intermediates=None)
InvCholeskyTransform
LowerCholeskyAffine
LowerCholeskyTransform
class LowerCholeskyTransform
Bases: numpyro.distributions.transforms.Transform
domain = <numpyro.distributions.constraints._IndependentConstraint object>
codomain = <numpyro.distributions.constraints._LowerCholesky object>
log_abs_det_jacobian(x, y, intermediates=None)
forward_shape(shape)
Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.
inverse_shape(shape)
Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.
OrderedTransform
class OrderedTransform
Bases: numpyro.distributions.transforms.Transform
Transform a real vector to an ordered vector.
References:
1. Stan Reference Manual v2.20, section 10.6, Stan Development Team
domain = <numpyro.distributions.constraints._IndependentConstraint object>
codomain = <numpyro.distributions.constraints._OrderedVector object>
log_abs_det_jacobian(x, y, intermediates=None)
PermuteTransform
class PermuteTransform(permutation)
Bases: numpyro.distributions.transforms.Transform
domain = <numpyro.distributions.constraints._IndependentConstraint object>
codomain = <numpyro.distributions.constraints._IndependentConstraint object>
log_abs_det_jacobian(x, y, intermediates=None)
PowerTransform
class PowerTransform(exponent)
Bases: numpyro.distributions.transforms.Transform
domain = <numpyro.distributions.constraints._GreaterThan object>
codomain = <numpyro.distributions.constraints._GreaterThan object>
log_abs_det_jacobian(x, y, intermediates=None)
forward_shape(shape)
Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.
inverse_shape(shape)
Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.
SigmoidTransform
class SigmoidTransform
Bases: numpyro.distributions.transforms.Transform
codomain = <numpyro.distributions.constraints._Interval object>
log_abs_det_jacobian(x, y, intermediates=None)
StickBreakingTransform
class StickBreakingTransform
Bases: numpyro.distributions.transforms.Transform
domain = <numpyro.distributions.constraints._IndependentConstraint object>
codomain = <numpyro.distributions.constraints._Simplex object>
log_abs_det_jacobian(x, y, intermediates=None)
forward_shape(shape)
Infers the shape of the forward computation, given the input shape. Defaults to preserving shape.
inverse_shape(shape)
Infers the shapes of the inverse computation, given the output shape. Defaults to preserving shape.
2.2.8 Flows
InverseAutoregressiveTransform
BlockNeuralAutoregressiveTransform
class BlockNeuralAutoregressiveTransform(bn_arn)
Bases: numpyro.distributions.transforms.Transform
An implementation of Block Neural Autoregressive flow.
References
1. Block Neural Autoregressive Flow, Nicola De Cao, Ivan Titov, Wilker Aziz
domain = <numpyro.distributions.constraints._IndependentConstraint object>
codomain = <numpyro.distributions.constraints._IndependentConstraint object>
call_with_intermediates(x)
log_abs_det_jacobian(x, y, intermediates=None)
Calculates the elementwise determinant of the log jacobian.
Parameters
• x (numpy.ndarray) – the input to the transform
• y (numpy.ndarray) – the output of the transform
2.3 Inference
Note: Setting progress_bar=False will improve the speed for many cases.
Parameters
• sampler (MCMCKernel) – an instance of MCMCKernel that determines the sampler for
running MCMC. Currently, only HMC and NUTS are available.
• num_warmup (int) – Number of warmup steps.
• num_samples (int) – Number of samples to generate from the Markov chain.
• thinning (int) – Positive integer that controls the fraction of post-warmup samples that
are retained. For example if thinning is 2 then every other sample is retained. Defaults to 1,
i.e. no thinning.
• num_chains (int) – Number of MCMC chains to run. By default, chains will be run in parallel using jax.pmap(); failing that, chains will be run in sequence.
• postprocess_fn – Post-processing callable - used to convert a collection of uncon-
strained sample values returned from the sampler to constrained values that lie within the
support of the sample sites. Additionally, this is used to return values at deterministic sites
in the model.
• chain_method (str) – One of 'parallel' (default), 'sequential', 'vectorized'. The method 'parallel' is used to execute the drawing process in parallel on XLA devices (CPUs/GPUs/TPUs). If there are not enough devices for 'parallel', we fall back to the 'sequential' method to draw chains sequentially. The 'vectorized' method is an experimental feature which vectorizes the drawing method, hence allowing us to collect samples in parallel on a single device.
• progress_bar (bool) – Whether to enable progress bar updates. Defaults to True.
• jit_model_args (bool) – If set to True, this will compile the potential energy computation as a function of model arguments. As such, calling MCMC.run again on a dataset of the same size but with different values will not result in additional compilation cost.
post_warmup_state
The state before the sampling phase. If this attribute is not None, run() will skip the warmup phase and
start with the state specified in this attribute.
Note: This attribute can be used to sequentially draw MCMC samples. For example,
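a sketch (the model, keys, and sample counts are illustrative):

mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
mcmc.run(random.PRNGKey(0))
first_100_samples = mcmc.get_samples()
mcmc.post_warmup_state = mcmc.last_state
mcmc.run(mcmc.post_warmup_state.rng_key)  # or mcmc.run(random.PRNGKey(1))
second_100_samples = mcmc.get_samples()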
last_state
The final MCMC state at the end of the sampling phase.
warmup(rng_key, *args, extra_fields=(), collect_warmup=False, init_params=None, **kwargs)
Run the MCMC warmup adaptation phase. After this call, self.warmup_state will be set and the run()
method will skip the warmup adaptation phase. To run warmup again for the new data, it is required to run
warmup() again.
Parameters
• rng_key (random.PRNGKey) – Random number generator key to be used for the
sampling.
• args – Arguments to be provided to the numpyro.infer.mcmc.MCMCKernel.
init() method. These are typically the arguments needed by the model.
• extra_fields (tuple or list) – Extra fields (aside from
default_fields()) from the state object (e.g. numpyro.infer.hmc.
HMCState for HMC) to collect during the MCMC run.
• collect_warmup (bool) – Whether to collect samples from the warmup phase. De-
faults to False.
• init_params – Initial parameters to begin sampling. The type must be consistent with
the input type to potential_fn.
• kwargs – Keyword arguments to be provided to the numpyro.infer.mcmc.
MCMCKernel.init() method. These are typically the keyword arguments needed by
the model.
run(rng_key, *args, extra_fields=(), init_params=None, **kwargs)
Run the MCMC samplers and collect samples.
Parameters
• rng_key (random.PRNGKey) – Random number generator key to be used for the
sampling. For multi-chains, a batch of num_chains keys can be supplied. If rng_key
does not have batch_size, it will be split into a batch of num_chains keys.
• args – Arguments to be provided to the numpyro.infer.mcmc.MCMCKernel.
init() method. These are typically the arguments needed by the model.
• extra_fields (tuple or list) – Extra fields (aside from z, diverging) from
numpyro.infer.mcmc.HMCState to collect during the MCMC run.
• init_params – Initial parameters to begin sampling. The type must be consistent with
the input type to potential_fn.
• kwargs – Keyword arguments to be provided to the numpyro.infer.mcmc.
MCMCKernel.init() method. These are typically the keyword arguments needed by
the model.
Note: JAX allows Python code to continue even when the compiled code has not finished yet. This can cause trouble when trying to profile the code for speed. See https://jax.readthedocs.io/en/latest/async_dispatch.html and https://jax.readthedocs.io/en/latest/profiling.html for pointers on profiling JAX programs.
get_samples(group_by_chain=False)
Get samples from the MCMC run.
Parameters group_by_chain (bool) – Whether to preserve the chain dimension. If True,
all samples will have num_chains as the size of their leading dimension.
Returns Samples having the same data type as init_params. The data type is a dict keyed on site
names if a model containing Pyro primitives is used, but can be any jaxlib.pytree(),
more generally (e.g. when defining a potential_fn for HMC that takes list args).
get_extra_fields(group_by_chain=False)
Get extra fields from the MCMC run.
Parameters group_by_chain (bool) – Whether to preserve the chain dimension. If True,
all samples will have num_chains as the size of their leading dimension.
Returns Extra fields keyed by field names which are specified in the extra_fields keyword of
run().
print_summary(prob=0.9, exclude_deterministic=True)
Print the statistics of posterior samples collected during running this MCMC instance.
Parameters
• prob (float) – the probability mass of samples within the credible interval.
• exclude_deterministic (bool) – whether or not to print out the statistics at deterministic sites.
MCMC Kernels
class MCMCKernel
Bases: abc.ABC
Defines the interface for the Markov transition kernel that is used for MCMC inference.
Example:
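A sketch of a custom kernel implementing this interface, here a toy random-walk Metropolis sampler over a user-supplied potential function (the class, potential, and hyperparameters are illustrative, not part of the library):

from collections import namedtuple

from jax import random
import jax.numpy as jnp

import numpyro.distributions as dist
from numpyro.infer import MCMC
from numpyro.infer.mcmc import MCMCKernel

MHState = namedtuple("MHState", ["u", "rng_key"])

class MetropolisHastings(MCMCKernel):
    # the field of the state namedtuple that holds the MCMC draw
    sample_field = "u"

    def __init__(self, potential_fn, step_size=0.1):
        self.potential_fn = potential_fn
        self.step_size = step_size

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        return MHState(init_params, rng_key)

    def sample(self, state, model_args, model_kwargs):
        u, rng_key = state
        rng_key, key_proposal, key_accept = random.split(rng_key, 3)
        u_proposal = dist.Normal(u, self.step_size).sample(key_proposal)
        # Metropolis acceptance ratio in terms of the potential (negative log density)
        accept_prob = jnp.exp(self.potential_fn(u) - self.potential_fn(u_proposal))
        u_new = jnp.where(dist.Uniform().sample(key_accept) < accept_prob, u_proposal, u)
        return MHState(u_new, rng_key)

def potential_fn(z):
    # a toy quadratic potential centered at 2
    return 0.5 * jnp.sum((z - 2.0) ** 2)

kernel = MetropolisHastings(potential_fn)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), init_params=jnp.array([1.0, 2.0]))
samples = mcmc.get_samples()

The pieces MCMC relies on here are sample_field (the attribute of the state that holds the draw), init (returning an initial state pytree), and sample (one Markov transition).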
postprocess_fn(model_args, model_kwargs)
Get a function that transforms unconstrained values at sample sites to values constrained to the site’s
support, in addition to returning deterministic sites in the model.
Parameters
• model_args – Arguments to the model.
• model_kwargs – Keyword arguments to the model.
init(rng_key, num_warmup, init_params, model_args, model_kwargs)
Initialize the MCMCKernel and return an initial state to begin sampling from.
Parameters
• rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
• num_warmup (int) – Number of warmup steps. This can be useful when doing adapta-
tion during warmup.
• init_params (tuple) – Initial parameters to begin sampling. The type must be con-
sistent with the input type to potential_fn.
• model_args – Arguments provided to the model.
• model_kwargs – Keyword arguments provided to the model.
Returns The initial state representing the state of the kernel. This can be any class that is regis-
tered as a pytree.
sample(state, model_args, model_kwargs)
Given the current state, return the next state using the given transition kernel.
Parameters
• state – A pytree class representing the state for the kernel. For HMC, this is given by
HMCState. In general, this could be any class that supports getattr.
• model_args – Arguments provided to the model.
Parameters
• model – Python callable containing Pyro primitives. If model is provided, potential_fn
will be inferred using the model.
• potential_fn – Python callable that computes the potential energy given input param-
eters. The input parameters to potential_fn can be any python collection type, provided that
init_params argument to init() has the same type.
• kinetic_fn – Python callable that returns the kinetic energy given inverse mass matrix
and momentum. If not provided, the default is euclidean kinetic energy.
• step_size (float) – Determines the size of a single step taken by the verlet integrator
while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set
to 1.
• adapt_step_size (bool) – A flag to decide if we want to adapt step_size during warm-
up phase using Dual Averaging scheme.
• adapt_mass_matrix (bool) – A flag to decide if we want to adapt mass matrix during
warm-up phase using Welford scheme.
• dense_mass (bool) – A flag to decide if mass matrix is dense or diagonal (default when
dense_mass=False)
• target_accept_prob (float) – Target acceptance probability for step size adaptation
using Dual Averaging. Increasing this value will lead to a smaller step size, hence the
sampling will be slower but more robust. Default to 0.8.
• trajectory_length (float) – Length of a MCMC trajectory for HMC. Default value
is 2𝜋.
• init_strategy (callable) – a per-site initialization function. See Initialization
Strategies section for available functions.
model
sample_field
The attribute of the state object passed to sample() that denotes the MCMC sample. This is used by
postprocess_fn() and for reporting results in MCMC.print_summary().
default_fields
The attributes of the state object to be collected by default during the MCMC run (when MCMC.run() is
called).
get_diagnostics_str(state)
Given the current state, returns the diagnostics string to be added to progress bar for diagnostics purpose.
init(rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={})
Initialize the MCMCKernel and return an initial state to begin sampling from.
Parameters
• rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
• num_warmup (int) – Number of warmup steps. This can be useful when doing adapta-
tion during warmup.
• init_params (tuple) – Initial parameters to begin sampling. The type must be con-
sistent with the input type to potential_fn.
• model_args – Arguments provided to the model.
• model_kwargs – Keyword arguments provided to the model.
Returns
The initial state representing the state of the kernel. This can be any class that is registered
as a pytree.
postprocess_fn(args, kwargs)
Get a function that transforms unconstrained values at sample sites to values constrained to the site’s
support, in addition to returning deterministic sites in the model.
Parameters
• model_args – Arguments to the model.
• model_kwargs – Keyword arguments to the model.
sample(state, model_args, model_kwargs)
Run HMC from the given HMCState and return the resulting HMCState.
Parameters
• state (HMCState) – Represents the current state.
• model_args – Arguments provided to the model.
• model_kwargs – Keyword arguments provided to the model.
Parameters
• model – Python callable containing Pyro primitives. If model is provided, potential_fn
will be inferred using the model.
• potential_fn – Python callable that computes the potential energy given input param-
eters. The input parameters to potential_fn can be any python collection type, provided that
init_params argument to init_kernel has the same type.
• kinetic_fn – Python callable that returns the kinetic energy given inverse mass matrix
and momentum. If not provided, the default is euclidean kinetic energy.
• step_size (float) – Determines the size of a single step taken by the verlet integrator
while computing the trajectory using Hamiltonian dynamics. If not specified, it will be set
to 1.
• adapt_step_size (bool) – A flag to decide if we want to adapt step_size during warm-
up phase using Dual Averaging scheme.
• adapt_mass_matrix (bool) – A flag to decide if we want to adapt mass matrix during
warm-up phase using Welford scheme.
• dense_mass (bool) – A flag to decide whether the mass matrix is dense or diagonal. Defaults to False, i.e. a diagonal mass matrix.
• target_accept_prob (float) – Target acceptance probability for step size adaptation
using Dual Averaging. Increasing this value will lead to a smaller step size, hence the
sampling will be slower but more robust. Default to 0.8.
• trajectory_length (float) – Length of a MCMC trajectory for HMC. This arg has
no effect in NUTS sampler.
• max_tree_depth (int) – Max depth of the binary tree created during the doubling
scheme of NUTS sampler. Defaults to 10.
• init_strategy (callable) – a per-site initialization function. See Initialization
Strategies section for available functions.
• find_heuristic_step_size (bool) – whether to use a heuristic function to adjust the step size at the beginning of each adaptation window. Defaults to False.
• forward_mode_differentiation (bool) – whether to use forward-mode differentiation or reverse-mode differentiation. By default, we use reverse mode, but forward mode can be useful in some cases to improve performance. In addition, some JAX control flow constructs (e.g. jax.lax.while_loop, jax.lax.fori_loop) only support forward-mode differentiation.
sample_field = 'z'
model
get_diagnostics_str(state)
Given the current state, returns the diagnostics string to be added to progress bar for diagnostics purpose.
postprocess_fn(args, kwargs)
Get a function that transforms unconstrained values at sample sites to values constrained to the site’s
support, in addition to returning deterministic sites in the model.
Parameters
• model_args – Arguments to the model.
• model_kwargs – Keyword arguments to the model.
init(rng_key, num_warmup, init_params, model_args, model_kwargs)
Initialize the MCMCKernel and return an initial state to begin sampling from.
Parameters
• rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
• num_warmup (int) – Number of warmup steps. This can be useful when doing adapta-
tion during warmup.
• init_params (tuple) – Initial parameters to begin sampling. The type must be con-
sistent with the input type to potential_fn.
• model_args – Arguments provided to the model.
• model_kwargs – Keyword arguments provided to the model.
Returns
The initial state representing the state of the kernel. This can be any class that is registered
as a pytree.
sample(state, model_args, model_kwargs)
Given the current state, return the next state using the given transition kernel.
Parameters
• state – A pytree class representing the state for the kernel. For HMC, this is given by
HMCState. In general, this could be any class that supports getattr.
• model_args – Arguments provided to the model.
• model_kwargs – Keyword arguments provided to the model.
Returns Next state.
class DiscreteHMCGibbs(inner_kernel, *, random_walk=False, modified=False)
Bases: numpyro.infer.hmc_gibbs.HMCGibbs
[EXPERIMENTAL INTERFACE]
A subclass of HMCGibbs which performs Metropolis updates for discrete latent sites.
Note: This class supports enumeration of discrete latent variables. To marginalize out a discrete latent site, we can specify infer={'enumerate': 'parallel'} keyword in its corresponding sample() statement.
Parameters
• inner_kernel – One of HMC or NUTS.
• discrete_sites (list) – a list of site names for the discrete latent variables that are
covered by the Gibbs sampler.
• random_walk (bool) – If False, Gibbs sampling will be used to draw a sample from the
conditional p(gibbs_site | remaining sites). Otherwise, a sample will be drawn uniformly
from the domain of gibbs_site.
• modified (bool) – whether to use a modified proposal, as suggested in reference [1],
which always proposes a new state for the current Gibbs site. The modified scheme appears
in the literature under the name “modified Gibbs sampler” or “Metropolised Gibbs sampler”.
References:
1. Peskun’s theorem and a modified discrete-state Gibbs sampler, Liu, J. S. (1996)
Example
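The example body is not shown here. A minimal sketch of running this kernel on a model with one discrete latent site might look like the following (the mixture model, probability values, and sample counts are illustrative):

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs

def model(probs, locs):
    # discrete assignment "c" is updated by Metropolis-within-Gibbs,
    # the continuous site "x" is updated by the inner NUTS kernel
    c = numpyro.sample("c", dist.Categorical(probs))
    numpyro.sample("x", dist.Normal(locs[c], 0.5))

probs = jnp.array([0.15, 0.3, 0.3, 0.25])
locs = jnp.array([-2.0, 0.0, 2.0, 4.0])
kernel = DiscreteHMCGibbs(NUTS(model), modified=True)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, progress_bar=False)
mcmc.run(random.PRNGKey(0), probs, locs)
mcmc.print_summary()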
Note: New subsample indices are proposed randomly with replacement at each MCMC step.
References:
1. Hamiltonian Monte Carlo with energy conserving subsampling, Dang, K. D., Quiroz, M., Kohn, R., Minh-Ngoc, T., & Villani, M. (2019)
2. Speeding Up MCMC by Efficient Data Subsampling, Quiroz, M., Kohn, R., Villani, M., & Tran, M. N. (2018)
3. The Block Pseudo-Marginal Sampler, Tran, M.-N., Kohn, R., Quiroz, M., Villani, M. (2017)
Parameters
• inner_kernel – One of HMC or NUTS.
• num_blocks (int) – Number of blocks to partition subsample into.
Example
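The class header and example body are not shown in this extract; judging from the references and parameters, this is the energy-conserving subsampling kernel (named HMCECS in recent NumPyro releases). A hedged usage sketch, assuming that name, could look like:

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import HMCECS, MCMC, NUTS

def model(data):
    x = numpyro.sample("x", dist.Normal(0.0, 1.0))
    # the likelihood is evaluated on a random mini-batch of 100 points per step
    with numpyro.plate("N", data.shape[0], subsample_size=100):
        batch = numpyro.subsample(data, event_dim=0)
        numpyro.sample("obs", dist.Normal(x, 1.0), obs=batch)

data = random.normal(random.PRNGKey(1), (10000,)) + 1.0
kernel = HMCECS(NUTS(model), num_blocks=10)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), data)
samples = mcmc.get_samples()["x"]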
Parameters
• rng_key (random.PRNGKey) – Random number generator key to initialize the kernel.
• num_warmup (int) – Number of warmup steps. This can be useful when doing adapta-
tion during warmup.
• init_params (tuple) – Initial parameters to begin sampling. The type must be con-
sistent with the input type to potential_fn.
• model_args – Arguments provided to the model.
• model_kwargs – Keyword arguments provided to the model.
Returns
The initial state representing the state of the kernel. This can be any class that is registered
as a pytree.
sample(state, model_args, model_kwargs)
Given the current state, return the next state using the given transition kernel.
Parameters
• state – A pytree class representing the state for the kernel. For HMC, this is given by
HMCState. In general, this could be any class that supports getattr.
• model_args – Arguments provided to the model.
• model_kwargs – Keyword arguments provided to the model.
Returns Next state.
class SA(model=None, potential_fn=None, adapt_state_size=None, dense_mass=True,
init_strategy=<function init_to_uniform>)
Bases: numpyro.infer.mcmc.MCMCKernel
Sample Adaptive MCMC, a gradient-free sampler.
This is a very fast (in terms of n_eff / s) sampler but requires many warmup (burn-in) steps. In each MCMC step, we only need to evaluate the potential function at one point.
Note that unlike in reference [1], we return a randomly selected (i.e. thinned) subset of approximate posterior
samples of size num_chains x num_samples instead of num_chains x num_samples x adapt_state_size.
Note: We recommend using this kernel with progress_bar=False in MCMC to reduce JAX's dispatch overhead.
References:
1. Sample Adaptive MCMC (https://papers.nips.cc/paper/9107-sample-adaptive-mcmc), Michael Zhu
Parameters
• model – Python callable containing Pyro primitives. If model is provided, potential_fn
will be inferred using the model.
• potential_fn – Python callable that computes the potential energy given input param-
eters. The input parameters to potential_fn can be any python collection type, provided that
init_params argument to init() has the same type.
• adapt_state_size (int) – The number of points to generate proposal distribution.
Defaults to 2 times latent size.
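A short usage sketch for this kernel follows (the toy model, data, and the large warmup budget are illustrative; note progress_bar=False as recommended above):

from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, SA

def model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 10.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

data = random.normal(random.PRNGKey(1), (100,)) + 2.0
kernel = SA(model)
mcmc = MCMC(kernel, num_warmup=10000, num_samples=10000, progress_bar=False)
mcmc.run(random.PRNGKey(0), data)
mcmc.print_summary()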
References:
1. MCMC Using Hamiltonian Dynamics, Radford M. Neal
2. The No-U-turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo, Matthew D. Hoff-
man, and Andrew Gelman.
3. A Conceptual Introduction to Hamiltonian Monte Carlo, Michael Betancourt
Parameters
• potential_fn – Python callable that computes the potential energy given input param-
eters. The input parameters to potential_fn can be any python collection type, provided that
init_params argument to init_kernel has the same type.
• potential_fn_gen – Python callable that when provided with model arguments / keyword arguments returns potential_fn. This may be provided to do inference on the same model with changing data. If the data shape remains the same, we can compile sample_kernel once, and use the same for multiple inference runs.
• kinetic_fn – Python callable that returns the kinetic energy given inverse mass matrix
and momentum. If not provided, the default is euclidean kinetic energy.
• algo (str) – Whether to run HMC with fixed number of steps or NUTS with adaptive path
length. Default is NUTS.
Returns a tuple of callables (init_kernel, sample_kernel): the first initializes the sampler, and the second generates samples given an existing state.
Warning: Instead of using this interface directly, we highly recommend the higher-level numpyro.infer.MCMC API.
Example
>>> import jax.numpy as jnp
>>> from jax import random
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer.util import initialize_model
>>>
>>> dim = 3
>>> data = random.normal(random.PRNGKey(2), (2000, dim))
>>> labels = dist.Bernoulli(logits=(jnp.arange(1., dim + 1.) * data).sum(-1)).sample(random.PRNGKey(3))
>>>
>>> def model(data, labels):
...     coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(3), jnp.ones(3)))
...     intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
...     return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels)
>>>
>>> model_info = initialize_model(random.PRNGKey(0), model, model_args=(data, labels,))
• mean_accept_prob - Mean acceptance probability until current iteration during warmup or sampling (for
diagnostics).
• diverging - A boolean value to indicate whether the new sample potential energy is diverging from the
current one.
• adapt_state - A SAAdaptState namedtuple which contains adaptation information:
– zs - Step size to be used by the integrator in the next iteration.
– pes - Potential energies of zs.
– loc - Mean of those zs.
– inv_mass_matrix_sqrt - If using dense mass matrix, this is Cholesky of the covariance of zs. Other-
wise, this is standard deviation of those zs.
• rng_key - random number generator seed used for the iteration.
TensorFlow Kernels
Thin wrappers around TensorFlow Probability (TFP) MCMC kernels. For details on the TFP MCMC kernel interface, see its TransitionKernel docs.
TFPKernel
Note: By default, uncalibrated kernels will be inner kernels of the MetropolisHastings kernel.
Note: For ReplicaExchangeMC, TFP requires that the shape of step_size of the inner kernel must be
[len(inverse_temperatures), 1] or [len(inverse_temperatures), latent_size].
Parameters
• model – Python callable containing Pyro primitives. If model is provided, potential_fn
will be inferred using the model.
• potential_fn – Python callable that computes the target potential energy given input
parameters. The input parameters to potential_fn can be any python collection type, pro-
vided that init_params argument to init() has the same type.
• init_strategy (callable) – a per-site initialization function. See Initialization
Strategies section for available functions.
• kernel_kwargs – other arguments to be passed to TFP kernel constructor.
HamiltonianMonteCarlo
MetropolisAdjustedLangevinAlgorithm
NoUTurnSampler
RandomWalkMetropolis
ReplicaExchangeMC
SliceSampler
UncalibratedHamiltonianMonteCarlo
UncalibratedLangevin
UncalibratedRandomWalk
MCMC Utilities
Returns a namedtuple ModelInfo which contains the fields (param_info, potential_fn, postprocess_fn, model_trace), where param_info is a namedtuple ParamInfo containing values from the prior used to initiate MCMC, their corresponding potential energy, and their gradients; postprocess_fn is a callable that uses inverse transforms to convert unconstrained HMC samples to constrained values that lie within the site's support, in addition to returning values at deterministic sites in the model.
fori_collect(lower, upper, body_fun, init_val, transform=<function identity>, progbar=True, re-
turn_last_val=False, collection_size=None, thinning=1, **progbar_opts)
This looping construct works like fori_loop() but with the additional effect of collecting values from the
loop body. In addition, this allows for post-processing of these samples via transform, and progress bar updates.
Note that, progbar=False will be faster, especially when collecting a lot of samples. Refer to example usage in
hmc().
Parameters
• lower (int) – the index at which to start collecting values; in other words, the first lower values from the loop are skipped.
• upper (int) – number of times to run the loop body.
• body_fun – a callable that takes a collection of np.ndarray and returns a collection with
the same shape and dtype.
• init_val – initial value to pass as argument to body_fun. Can be any Python collection
type containing np.ndarray objects.
• transform – a callable to post-process the values returned by body_fn.
• progbar – whether to post progress bar updates.
• return_last_val (bool) – If True, the last value is also returned. This has the same
type as init_val.
• thinning – Positive integer that controls the thinning ratio for retained values. Defaults
to 1, i.e. no thinning.
• collection_size (int) – Size of the returned collection. If not specified, the size will be (upper - lower) // thinning. If the size is larger than (upper - lower) // thinning, only the top (upper - lower) // thinning entries will be non-zero.
• **progbar_opts – optional additional progress bar arguments. A diagnostics_fn can be
supplied which when passed the current value from body_fun returns a string that is used
to update the progress bar postfix. Also a progbar_desc keyword argument can be supplied
which is used to label the progress bar.
Returns collection with the same type as init_val with values collected along the leading axis of
np.ndarray objects.
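A toy sketch of fori_collect on a deterministic loop body (the values below are illustrative):

import jax.numpy as jnp
from numpyro.util import fori_collect

# Iterate x -> x + 1 twenty times, skip the first ten values, and stack
# the remaining ten along the leading axis.
xs = fori_collect(10, 20, lambda x: x + 1.0, jnp.array(0.0), progbar=False)
assert xs.shape == (10,)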
consensus(subposteriors, num_draws=None, diagonal=False, rng_key=None)
Merges subposteriors following the consensus Monte Carlo algorithm.
References:
1. Bayes and big data: The consensus Monte Carlo algorithm, Steven L. Scott, Alexander W. Blocker, Fer-
nando V. Bonassi, Hugh A. Chipman, Edward I. George, Robert E. McCulloch
Parameters
• subposteriors (list) – a list in which each element is a collection of samples.
• num_draws (int) – number of draws from the merged posterior.
parametric(subposteriors, diagonal=False)
Merges subposteriors following the (embarrassingly parallel) parametric Monte Carlo algorithm.
References:
1. Asymptotically Exact, Embarrassingly Parallel MCMC, Willie Neiswanger, Chong Wang, Eric Xing
Parameters
• subposteriors (list) – a list in which each element is a collection of samples.
• diagonal (bool) – whether to compute weights using variance or covariance, defaults to
False (using covariance).
Returns the estimated mean and variance/covariance parameters of the joined posterior
parametric_draws(subposteriors, num_draws, diagonal=False, rng_key=None)
Merges subposteriors following the (embarrassingly parallel) parametric Monte Carlo algorithm and returns draws from the merged posterior.
Parameters
• subposteriors (list) – a list in which each element is a collection of samples.
• num_draws (int) – number of draws from the merged posterior.
• diagonal (bool) – whether to compute weights using variance or covariance, defaults to
False (using covariance).
• rng_key (jax.random.PRNGKey) – source of the randomness, defaults to
jax.random.PRNGKey(0).
Returns a collection of num_draws samples with the same data structure as each subposterior.
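A small sketch of merging two subposteriors (the module path numpyro.infer.hmc_util and the toy dict-valued subposteriors below are assumptions made for illustration):

from jax import random
from numpyro.infer.hmc_util import consensus, parametric

# Two toy subposteriors, e.g. from MCMC runs on two shards of the data;
# each one is a dict of samples keyed by site name.
sub1 = {"mu": random.normal(random.PRNGKey(0), (1000,)) + 1.0}
sub2 = {"mu": random.normal(random.PRNGKey(1), (1000,)) + 1.2}

merged_draws = consensus([sub1, sub2], num_draws=1000, rng_key=random.PRNGKey(2))
loc, cov = parametric([sub1, sub2], diagonal=False)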
... constraint=constraints.positive)
... numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))
Parameters
• model – Python callable with Pyro primitives for the model.
• guide – Python callable with Pyro primitives for the guide (recognition network).
• optim – an instance of _NumpyroOptim.
• loss – ELBO loss, i.e. negative Evidence Lower Bound, to minimize.
• static_kwargs – static arguments for the model / guide, i.e. arguments that remain
constant during fitting.
Returns tuple of (init_fn, update_fn, evaluate).
Note: For a complex training process (e.g. one that requires early stopping, epoch-based training, or varying args/kwargs), we recommend using the more flexible methods init(), update(), and evaluate() to customize your training procedure.
Parameters
• rng_key (jax.random.PRNGKey) – random number generator seed.
• num_steps (int) – the number of optimization steps.
• args – arguments to the model / guide
• progress_bar (bool) – Whether to enable progress bar updates. Defaults to True.
• kwargs – keyword arguments to the model / guide
Returns a namedtuple with fields params and losses where params holds the optimized values
at numpyro.param sites, and losses is the collected loss during the process.
Return type SVIRunResult
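Putting the pieces above together, a compact end-to-end SVI sketch could look as follows (the Beta-Bernoulli model, initial parameter values, and step counts are illustrative):

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, Trace_ELBO

def model(data):
    f = numpyro.sample("latent_fairness", dist.Beta(10.0, 10.0))
    with numpyro.plate("N", data.shape[0]):
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

def guide(data):
    alpha_q = numpyro.param("alpha_q", 15.0, constraint=constraints.positive)
    beta_q = numpyro.param("beta_q", 15.0, constraint=constraints.positive)
    numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)])
optimizer = numpyro.optim.Adam(step_size=0.0005)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 2000, data)
params, losses = svi_result.params, svi_result.losses
inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])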
ELBO
class ELBO(num_particles=1)
Bases: numpyro.infer.elbo.Trace_ELBO
Trace_ELBO
class Trace_ELBO(num_particles=1)
Bases: object
A trace implementation of ELBO-based SVI. The estimator is constructed along the lines of references [1] and
[2]. There are no restrictions on the dependency structure of the model or the guide.
This is the most basic implementation of the Evidence Lower Bound, which is the fundamental objective in Vari-
ational Inference. This implementation has various limitations (for example it only supports random variables
with reparameterized samplers) but can be used as a template to build more sophisticated loss objectives.
For more details, refer to http://pyro.ai/examples/svi_part_i.html.
References:
1. Automated Variational Inference in Probabilistic Programming, David Wingate, Theo Weber
2. Black Box Variational Inference, Rajesh Ranganath, Sean Gerrish, David M. Blei
Parameters num_particles – The number of particles/samples used to form the ELBO (gradi-
ent) estimators.
TraceMeanField_ELBO
class TraceMeanField_ELBO(num_particles=1)
Bases: numpyro.infer.elbo.Trace_ELBO
A trace implementation of ELBO-based SVI. This is currently the only ELBO estimator in NumPyro that uses
analytic KL divergences when those are available.
Warning: This estimator may give incorrect results if the mean-field condition is not satisfied. The mean-field condition is a sufficient but not necessary condition for this estimator to be correct. The precise condition is that for every latent variable z in the guide, its parents in the model must not include any latent variables that are descendants of z in the guide. Here 'parents in the model' and 'descendants in the guide' are with respect to the corresponding (statistical) dependency structure. For example, this condition is always satisfied if the model and guide have identical dependency structures.
RenyiELBO
Note: Setting 𝛼 < 1 gives a better bound than the usual ELBO.
Parameters
• alpha (float) – The order of α-divergence. Here α ≠ 1. Default is 0.
• num_particles – The number of particles/samples used to form the objective (gradient)
estimator. Default is 2.
References:
1. Renyi Divergence Variational Inference, Yingzhen Li, Richard E. Turner
2. Importance Weighted Autoencoders, Yuri Burda, Roger Grosse, Ruslan Salakhutdinov
loss(rng_key, param_map, model, guide, *args, **kwargs)
Evaluates the Renyi ELBO with an estimator that uses num_particles many samples/particles.
Parameters
• rng_key (jax.random.PRNGKey) – random number generator seed.
• param_map (dict) – dictionary of current parameter values keyed by site name.
• model – Python callable with NumPyro primitives for the model.
• guide – Python callable with NumPyro primitives for the guide.
• args – arguments to the model / guide (these can possibly vary during the course of
fitting).
• kwargs – keyword arguments to the model / guide (these can possibly vary during the
course of fitting).
Returns negative of the Renyi Evidence Lower Bound (ELBO) to be minimized.
AutoContinuous
Parameters
• model (callable) – A NumPyro model.
• prefix (str) – a prefix that will be prefixed to all param internal sites.
• init_loc_fn (callable) – A per-site initialization function. See Initialization Strate-
gies section for available functions.
get_base_dist()
Returns the base distribution of the posterior when reparameterized as a
TransformedDistribution. This should not depend on the model’s *args, **kwargs.
get_transform(params)
Returns the transformation learned by the guide to generate samples from the unconstrained (approximate)
posterior.
Parameters params (dict) – Current parameters of model and autoguide. The parameters
can be obtained using get_params() method from SVI.
Returns the transform of posterior distribution
Return type Transform
get_posterior(params)
Returns the posterior distribution.
Parameters params (dict) – Current parameters of model and autoguide. The parameters
can be obtained using get_params() method from SVI.
sample_posterior(rng_key, params, sample_shape=())
Get samples from the learned posterior.
Parameters
• rng_key (jax.random.PRNGKey) – random key to be used draw samples.
• params (dict) – Current parameters of model and autoguide. The parameters can be
obtained using get_params() method from SVI.
Parameters
• params (dict) – A dict containing parameter values. The parameters can be obtained
using get_params() method from SVI.
• quantiles (list) – A list of requested quantiles between 0 and 1.
Returns A dict mapping sample site name to a list of quantile values.
Return type dict
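A hedged sketch of fitting an AutoContinuous guide with SVI and then querying it (the toy model, optimizer settings, and step counts are illustrative):

from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDiagonalNormal

def model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 10.0))
    scale = numpyro.sample("scale", dist.LogNormal(0.0, 1.0))
    with numpyro.plate("N", data.shape[0]):
        numpyro.sample("obs", dist.Normal(loc, scale), obs=data)

data = random.normal(random.PRNGKey(1), (200,)) * 2.0 + 3.0
guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 2000, data)

# Draw approximate posterior samples and summary quantiles from the fitted guide.
posterior = guide.sample_posterior(random.PRNGKey(2), svi_result.params, sample_shape=(1000,))
quantiles = guide.quantiles(svi_result.params, [0.05, 0.5, 0.95])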
AutoBNAFNormal
References
1. Block Neural Autoregressive Flow, Nicola De Cao, Ivan Titov, Wilker Aziz
Parameters
• model (callable) – a generative model.
• prefix (str) – a prefix that will be prefixed to all param internal sites.
• init_loc_fn (callable) – A per-site initialization function.
• num_flows (int) – the number of flows to be used, defaults to 3.
get_base_dist()
Returns the base distribution of the posterior when reparameterized as a
TransformedDistribution. This should not depend on the model’s *args, **kwargs.
AutoDiagonalNormal
get_base_dist()
Returns the base distribution of the posterior when reparameterized as a
TransformedDistribution. This should not depend on the model’s *args, **kwargs.
get_transform(params)
Returns the transformation learned by the guide to generate samples from the unconstrained (approximate)
posterior.
Parameters params (dict) – Current parameters of model and autoguide. The parameters
can be obtained using get_params() method from SVI.
Returns the transform of posterior distribution
Return type Transform
get_posterior(params)
Returns a diagonal Normal posterior distribution.
median(params)
Returns the posterior median value of each latent variable.
Parameters params (dict) – A dict containing parameter values. The parameters can be
obtained using get_params() method from SVI.
Returns A dict mapping sample site name to median tensor.
Return type dict
quantiles(params, quantiles)
Returns posterior quantiles for each latent variable. Example:
Parameters
• params (dict) – A dict containing parameter values. The parameters can be obtained
using get_params() method from SVI.
• quantiles (list) – A list of requested quantiles between 0 and 1.
AutoMultivariateNormal
get_base_dist()
Returns the base distribution of the posterior when reparameterized as a
TransformedDistribution. This should not depend on the model’s *args, **kwargs.
get_transform(params)
Returns the transformation learned by the guide to generate samples from the unconstrained (approximate)
posterior.
Parameters params (dict) – Current parameters of model and autoguide. The parameters
can be obtained using get_params() method from SVI.
Returns the transform of posterior distribution
Return type Transform
get_posterior(params)
Returns a multivariate Normal posterior distribution.
median(params)
Returns the posterior median value of each latent variable.
Parameters params (dict) – A dict containing parameter values. The parameters can be
obtained using get_params() method from SVI.
Returns A dict mapping sample site name to median tensor.
Return type dict
quantiles(params, quantiles)
Returns posterior quantiles for each latent variable. Example:
Parameters
• params (dict) – A dict containing parameter values. The parameters can be obtained
using get_params() method from SVI.
• quantiles (list) – A list of requested quantiles between 0 and 1.
Returns A dict mapping sample site name to a list of quantile values.
Return type dict
AutoIAFNormal
Parameters
• model (callable) – a generative model.
• prefix (str) – a prefix that will be prefixed to all param internal sites.
• init_loc_fn (callable) – A per-site initialization function.
• num_flows (int) – the number of flows to be used, defaults to 3.
• hidden_dims (list) – the dimensionality of the hidden units per layer. Defaults to
[latent_dim, latent_dim].
• skip_connections (bool) – whether to add skip connections from the input to the
output of each flow. Defaults to False.
• nonlinearity (callable) – the nonlinearity to use in the feedforward network. De-
faults to jax.experimental.stax.Elu().
get_base_dist()
Returns the base distribution of the posterior when reparameterized as a
TransformedDistribution. This should not depend on the model’s *args, **kwargs.
AutoLaplaceApproximation
get_base_dist()
Returns the base distribution of the posterior when reparameterized as a
TransformedDistribution. This should not depend on the model’s *args, **kwargs.
get_transform(params)
Returns the transformation learned by the guide to generate samples from the unconstrained (approximate)
posterior.
Parameters params (dict) – Current parameters of model and autoguide. The parameters
can be obtained using get_params() method from SVI.
Returns the transform of posterior distribution
Return type Transform
get_posterior(params)
Returns a multivariate Normal posterior distribution.
sample_posterior(rng_key, params, sample_shape=())
Get samples from the learned posterior.
Parameters
• rng_key (jax.random.PRNGKey) – random key to be used draw samples.
• params (dict) – Current parameters of model and autoguide. The parameters can be
obtained using get_params() method from SVI.
• sample_shape (tuple) – batch shape of each latent sample, defaults to ().
Returns a dict containing samples drawn from this guide.
Return type dict
median(params)
Returns the posterior median value of each latent variable.
Parameters params (dict) – A dict containing parameter values. The parameters can be
obtained using get_params() method from SVI.
Returns A dict mapping sample site name to median tensor.
Return type dict
quantiles(params, quantiles)
Returns posterior quantiles for each latent variable. Example:
Parameters
• params (dict) – A dict containing parameter values. The parameters can be obtained
using get_params() method from SVI.
• quantiles (list) – A list of requested quantiles between 0 and 1.
Returns A dict mapping sample site name to a list of quantile values.
Return type dict
AutoLowRankMultivariateNormal
get_base_dist()
Returns the base distribution of the posterior when reparameterized as a
TransformedDistribution. This should not depend on the model’s *args, **kwargs.
get_transform(params)
Returns the transformation learned by the guide to generate samples from the unconstrained (approximate)
posterior.
Parameters params (dict) – Current parameters of model and autoguide. The parameters
can be obtained using get_params() method from SVI.
Returns the transform of posterior distribution
Return type Transform
get_posterior(params)
Returns a lowrank multivariate Normal posterior distribution.
median(params)
Returns the posterior median value of each latent variable.
Parameters params (dict) – A dict containing parameter values. The parameters can be
obtained using get_params() method from SVI.
Returns A dict mapping sample site name to median tensor.
Return type dict
quantiles(params, quantiles)
Returns posterior quantiles for each latent variable. Example:
Parameters
• params (dict) – A dict containing parameter values. The parameters can be obtained
using get_params() method from SVI.
• quantiles (list) – A list of requested quantiles between 0 and 1.
Returns A dict mapping sample site name to a list of quantile values.
Return type dict
AutoNormal
This should be equivalent to AutoDiagonalNormal, but with more convenient site names and with better support for mean field ELBO.
Usage:
guide = AutoNormal(model)
svi = SVI(model, guide, ...)
Parameters
• model (callable) – A NumPyro model.
• prefix (str) – a prefix that will be prefixed to all param internal sites.
• init_loc_fn (callable) – A per-site initialization function. See Initialization Strate-
gies section for available functions.
• init_scale (float) – Initial scale for the standard deviation of each (unconstrained
transformed) latent variable.
• create_plates (callable) – An optional function inputting the same *args, **kwargs as model() and returning a numpyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
AutoDelta
Usage:
guide = AutoDelta(model)
svi = SVI(model, guide, ...)
Parameters
• model (callable) – A NumPyro model.
• prefix (str) – a prefix that will be prefixed to all param internal sites.
• init_loc_fn (callable) – A per-site initialization function. See Initialization Strate-
gies section for available functions.
• create_plates (callable) – An optional function inputting the same *args, **kwargs as model() and returning a numpyro.plate or iterable of plates. Plates not returned will be created automatically as usual. This is useful for data subsampling.
2.3.4 Reparameterizers
Loc-Scale Decentering
Parameters
• centered (float) – optional centered parameter. If None (default) learn a per-site per-
element centering parameter in [0,1]. If 0, fully decenter the distribution; if 1, preserve
the centered distribution unchanged.
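A small sketch of applying this loc-scale reparameterizer (LocScaleReparam) to a hierarchical Normal site via the reparam handler (the model below is illustrative):

import numpyro
import numpyro.distributions as dist
from numpyro.handlers import reparam
from numpyro.infer.reparam import LocScaleReparam

def model():
    mu = numpyro.sample("mu", dist.Normal(0.0, 5.0))
    tau = numpyro.sample("tau", dist.HalfCauchy(5.0))
    with numpyro.plate("groups", 10):
        # sample "theta" in a fully decentered parameterization
        with reparam(config={"theta": LocScaleReparam(centered=0)}):
            numpyro.sample("theta", dist.Normal(mu, tau))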
Neural Transport
This reparameterization works only for latent variables, not likelihoods. Note that all sites must share a single
common NeuTraReparam instance, and that the model must have static structure.
[1] Hoffman, M. et al. (2019) "NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport" https://arxiv.org/abs/1903.03704
Parameters
• guide (AutoContinuous) – A guide.
• params – trained parameters of the guide.
reparam(fn=None)
__call__(name, fn, obs)
Parameters
• name (str) – A sample site name.
• fn (Distribution) – A distribution.
• obs (numpy.ndarray) – Observed value or None.
Returns A pair (new_fn, value).
transform_sample(latent)
Given latent samples from the warped posterior (with possible batch dimensions), return a dict of samples
from the latent sites in the model.
Parameters latent – sample from the warped posterior (possibly batched).
Returns a dict of samples keyed by latent sites in the model.
Return type dict
Transformed Distributions
class TransformReparam
Bases: numpyro.infer.reparam.Reparam
Reparameterizer for TransformedDistribution latent variables.
This is useful for transformed distributions with complex, geometry-changing transforms, where the posterior
has simple shape in the space of base_dist.
This reparameterization works only for latent variables, not likelihoods.
__call__(name, fn, obs)
Parameters
• name (str) – A sample site name.
• fn (Distribution) – A distribution.
• obs (numpy.ndarray) – Observed value or None.
Returns A pair (new_fn, value).
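A brief sketch of using TransformReparam on a TransformedDistribution latent site (the affine example below is illustrative):

import numpyro
import numpyro.distributions as dist
from numpyro.distributions import transforms
from numpyro.handlers import reparam
from numpyro.infer.reparam import TransformReparam

def model():
    # "x" is sampled in the base (standard Normal) space and then
    # deterministically pushed through the affine transform.
    with reparam(config={"x": TransformReparam()}):
        numpyro.sample(
            "x",
            dist.TransformedDistribution(
                dist.Normal(0.0, 1.0), transforms.AffineTransform(3.0, 2.0)
            ),
        )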
Effect handlers
Parameters
• x – An object.
• output (funsor.domains.Domain) – An optional output hint to uniquely convert a
data to a Funsor (e.g. when x is a string).
• dim_to_name (OrderedDict) – An optional mapping from negative batch dimensions
to name strings.
• dim_type (int) – Either 0, 1, or 2. This optional argument indicates a dimension should
be treated as ‘local’, ‘global’, or ‘visible’, which can be used to interact with the global
DimStack.
Returns A Funsor equivalent to x.
Return type funsor.terms.Funsor
class trace(fn=None)
Bases: numpyro.handlers.trace
This version of trace handler records information necessary to do packing after execution.
Each sample site is annotated with a “dim_to_name” dictionary, which can be passed directly to
to_funsor().
postprocess_message(msg)
Inference Utilities
config_enumerate(fn, default='parallel')
Configures enumeration for all relevant sites in a NumPyro model.
When configuring for exhaustive enumeration of discrete variables, this configures all sample sites whose dis-
tribution satisfies .has_enumerate_support == True.
This can be used as either a function:
model = config_enumerate(model)
or as a decorator:
@config_enumerate
def model(*args, **kwargs):
...
Parameters
• fn (callable) – Python callable with NumPyro primitives.
• default (str) – Which enumerate strategy to use, one of “sequential”, “parallel”, or
None. Defaults to “parallel”.
Parameters
• model – Python callable containing NumPyro primitives. Typically, the model has been
enumerated by using enum handler:
enum_model = numpyro.contrib.funsor.enum(model)
with plate_to_enum_plate():
    model_trace = numpyro.contrib.funsor.trace(enum_model).get_trace(
        *model_args, **model_kwargs)
2.3.6 Optimizers
Optimizer classes defined here are light wrappers over the corresponding optimizers sourced from jax.experimental.optimizers, with an interface that is better suited for working with NumPyro inference algorithms.
Adam
Adagrad
ClippedAdam
Minimize
Momentum
RMSProp
RMSPropMomentum
SGD
SM3
Returns a pair of the output of objective function and the new optimizer state.
get_params(state: Tuple[int, _OptState]) → _Params
Get current parameter values.
Parameters state – current optimizer state.
Returns collection with current value for parameters.
init(params: _Params) → Tuple[int, _OptState]
Initialize the optimizer with parameters designated to be optimized.
Parameters params – a collection of numpy arrays.
Returns initial optimizer state.
update(g: _Params, state: Tuple[int, _OptState]) → Tuple[int, _OptState]
Gradient update for the optimizer.
Parameters
• g – gradient information for parameters.
• state – current optimizer state.
Returns new optimizer state after the update.
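A toy sketch of the init / update / get_params cycle using one of the wrapped optimizers (the quadratic objective and step counts are illustrative):

import jax.numpy as jnp
from numpyro import optim

def grad_fn(params):
    # gradient of the quadratic loss (params - 3)^2
    return 2.0 * (params - 3.0)

optimizer = optim.Adam(step_size=0.1)
state = optimizer.init(jnp.array(0.0))
for _ in range(200):
    params = optimizer.get_params(state)
    state = optimizer.update(grad_fn(params), state)
final_params = optimizer.get_params(state)  # converges towards 3.0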
2.3.7 Diagnostics
This provides a small set of utilities in NumPyro that are used to diagnose posterior samples.
Autocorrelation
autocorrelation(x, axis=0)
Computes the autocorrelation of samples at dimension axis.
Parameters
• x (numpy.ndarray) – the input array.
• axis (int) – the dimension to calculate autocorrelation.
Returns autocorrelation of x.
Return type numpy.ndarray
Autocovariance
autocovariance(x, axis=0)
Computes the autocovariance of samples at dimension axis.
Parameters
• x (numpy.ndarray) – the input array.
• axis (int) – the dimension to calculate autocovariance.
Returns autocovariance of x.
Return type numpy.ndarray
effective_sample_size(x)
Computes effective sample size of input x, where the first dimension of x is chain dimension and the second
dimension of x is draw dimension.
References:
1. Introduction to Markov Chain Monte Carlo, Charles J. Geyer
2. Stan Reference Manual version 2.18, Stan Development Team
Gelman Rubin
gelman_rubin(x)
Computes R-hat over chains of samples x, where the first dimension of x is chain dimension and the second
dimension of x is draw dimension. It is required that x.shape[0] >= 2 and x.shape[1] >= 2.
Parameters x (numpy.ndarray) – the input array.
Returns R-hat of x.
Return type numpy.ndarray
split_gelman_rubin(x)
Computes split R-hat over chains of samples x, where the first dimension of x is chain dimension and the second
dimension of x is draw dimension. It is required that x.shape[1] >= 4.
Parameters x (numpy.ndarray) – the input array.
Returns split R-hat of x.
Return type numpy.ndarray
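A small sketch of these diagnostics applied to an array of draws with shape (num_chains, num_draws); the synthetic draws below are illustrative:

import numpy as np
from numpyro.diagnostics import (autocorrelation, autocovariance,
                                 effective_sample_size, gelman_rubin,
                                 split_gelman_rubin)

# synthetic "posterior draws" from 4 chains of 1000 draws each
x = np.random.normal(size=(4, 1000))

acorr = autocorrelation(x, axis=1)   # same shape as x
acov = autocovariance(x, axis=1)     # same shape as x
n_eff = effective_sample_size(x)     # effective sample size across chains
r_hat = gelman_rubin(x)              # needs >= 2 chains and >= 2 draws
split_r_hat = split_gelman_rubin(x)  # needs >= 4 draws per chain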
HPDI
Summary
enable_validation
enable_validation(is_validate=True)
Enable or disable validation checks in NumPyro. Validation checks provide useful warnings and errors, e.g.
NaN checks, validating distribution arguments and support values, etc. which is useful for debugging.
Note: This utility does not take effect under JAX's JIT compilation or vectorized transformation jax.vmap().
validation_enabled
validation_enabled(is_validate=True)
Context manager that is useful when temporarily enabling/disabling validation checks.
Parameters is_validate (bool) – whether to enable validation checks.
enable_x64
enable_x64(use_x64=True)
Changes the default array type to use 64 bit precision as in NumPy.
Parameters use_x64 (bool) – when True, JAX arrays will use 64 bits by default; else 32 bits.
set_platform
set_platform(platform=None)
Changes platform to CPU, GPU, or TPU. This utility only takes effect at the beginning of your program.
Parameters platform (str) – either ‘cpu’, ‘gpu’, or ‘tpu’.
set_host_device_count
set_host_device_count(n)
By default, XLA considers all CPU cores as one device. This utility tells XLA that there are n host (CPU) devices available to use. As a consequence, this allows parallel mapping in JAX via jax.pmap() to work on the CPU platform.
Note: This utility only takes effect at the beginning of your program. Under the hood, this sets the environment variable XLA_FLAGS=--xla_force_host_platform_device_count=[num_devices], where [num_devices] is the desired number of CPU devices n.
Warning: Our understanding of the side effects of using the xla_force_host_platform_device_count flag in
XLA is incomplete. If you observe some strange phenomenon when using this utility, please let us know
through our issue or forum page. More information is available in this JAX issue.
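For example, to run four parallel chains on a multi-core CPU, the device count should be set before any JAX computation is dispatched (the model and sampler below are placeholders):

import numpyro

numpyro.set_host_device_count(4)  # must run before JAX/XLA is initialized

# later, e.g.:
# from numpyro.infer import MCMC, NUTS
# mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000,
#             num_chains=4, chain_method="parallel")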
Predictive
Warning: The interface for the Predictive class is experimental, and might change in the future.
Parameters
• model – Python callable containing Pyro primitives.
• posterior_samples (dict) – dictionary of samples from the posterior.
• guide (callable) – optional guide to get posterior samples of sites not present in pos-
terior_samples.
• params (dict) – dictionary of values for param sites of model/guide.
• num_samples (int) – number of samples
• return_sites (list) – sites to return; by default only sample sites not present in pos-
terior_samples are returned.
• parallel (bool) – whether to predict in parallel using JAX vectorized map jax.vmap(). Defaults to False.
• batch_ndims – the number of batch dimensions in posterior samples. Some usages:
– set batch_ndims=0 to get prediction for 1 single sample
– set batch_ndims=1 to get prediction for posterior_samples with shapes (num_samples x
...)
– set batch_ndims=2 to get prediction for posterior_samples with shapes (num_chains x N x ...). Note that if the num_samples argument is not None, its value should be equal to num_chains x N.
Returns dict of samples from the predictive distribution.
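A compact sketch of prior and posterior predictive sampling with Predictive (the toy regression model and sample counts are illustrative):

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

def model(x, y=None):
    w = numpyro.sample("w", dist.Normal(0.0, 1.0))
    numpyro.sample("y", dist.Normal(w * x, 1.0), obs=y)

x = jnp.linspace(-1.0, 1.0, 20)
y = 0.5 * x + 0.1 * random.normal(random.PRNGKey(1), (20,))

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), x, y)

# Prior predictive: num_samples fresh draws from the prior.
prior_predictive = Predictive(model, num_samples=100)(random.PRNGKey(2), x)
# Posterior predictive: one prediction per posterior draw.
posterior_predictive = Predictive(model, mcmc.get_samples())(random.PRNGKey(3), x)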
log_density
transform_fn
constrain_fn
potential_energy
log_likelihood
find_valid_initial_params
Initialization Strategies
init_to_feasible
init_to_feasible(site=None)
Initialize to an arbitrary feasible point, ignoring distribution parameters.
init_to_median
init_to_median(site=None, num_samples=15)
Initialize to the prior median. For priors with no .sample method implemented, we defer to the
init_to_uniform() strategy.
Parameters num_samples (int) – number of prior points to calculate median.
init_to_sample
init_to_sample(site=None)
Initialize to a prior sample. For priors with no .sample method implemented, we defer to the
init_to_uniform() strategy.
init_to_uniform
init_to_uniform(site=None, radius=2)
Initialize to a random point in the area (-radius, radius) of unconstrained domain.
Parameters radius (float) – specifies the range to draw an initial point in the unconstrained
domain.
init_to_value
init_to_value(site=None, values={})
Initialize to the value specified in values. We defer to init_to_uniform() strategy for sites which do not
appear in values.
Parameters values (dict) – dictionary of initial values keyed by site name.
Tensor Indexing
vindex(tensor, args)
Vectorized advanced indexing with broadcasting semantics.
See also the convenience wrapper Vindex.
This is useful for writing indexing code that is compatible with batching and enumeration, especially for select-
ing mixture components with discrete random variables.
For example, suppose x is a parameter with len(x.shape) == 3 and we wish to generalize the expression x[i, :, j] from integer i, j to tensors i, j with batch dims and enum dims (but no event dims). Then we can write the generalized version using Vindex:
xij = Vindex(x)[i, :, j]
To handle the case when x may also contain batch dimensions (e.g. if x was sampled in a plated context, as when using vectorized particles), vindex() uses the special convention that Ellipsis denotes batch dimensions (hence ... can appear only on the left, never in the middle or on the right). Suppose x has event dim 3. Then we can write:
old_batch_shape = x.shape[:-3]
old_event_shape = x.shape[-3:]
Note that this special handling of Ellipsis differs from the NEP [1].
Formally, this function assumes:
1. Each arg is either Ellipsis, slice(None), an integer, or a batched integer tensor (i.e. with empty event shape). This function does not support nontrivial slices or boolean tensor masks. Ellipsis can only appear on the left, as args[0].
2. If args[0] is not Ellipsis then tensor is not batched, and its event dim is equal to
len(args).
3. If args[0] is Ellipsis then tensor is batched and its event dim is equal to len(args[1:]).
Dims of tensor to the left of the event dims are considered batch dims and will be broadcasted with dims
of tensor args.
Note that if none of the args is a tensor with len(shape) > 0, then this function behaves like standard
indexing:
References
[1] https://www.numpy.org/neps/nep-0021-advanced-indexing.html introduces vindex as a helper for
vectorized indexing. This implementation is similar to the proposed notation x.vindex[] except for
slightly different handling of Ellipsis.
Parameters
• tensor (jnp.ndarray) – A tensor to be indexed.
• args (tuple) – An index, as args to __getitem__.
Returns A nonstandard interpretation of tensor[args].
Return type jnp.ndarray
class Vindex(tensor)
Bases: object
Convenience wrapper around vindex().
The following are equivalent:
Vindex(x)[..., i, j, :]
vindex(x, (Ellipsis, i, j, slice(None)))
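A concrete sketch of the broadcasting semantics (the import path numpyro.contrib.indexing is assumed for this documentation version; it moved in later releases):

import jax.numpy as jnp
from numpyro.contrib.indexing import Vindex, vindex

x = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)  # event dims only, no batch dims
i = jnp.array([0, 1])                        # batched integer indices
j = jnp.array([1, 3])

# Generalization of x[i, :, j]: xij[k, m] == x[i[k], m, j[k]], with shape (2, 3).
xij = Vindex(x)[i, :, j]
assert jnp.all(xij == vindex(x, (i, slice(None), j)))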
In this tutorial, we will explore how to do Bayesian regression in NumPyro, using a simple example adapted from Statistical Rethinking [1]. In particular, we would like to explore the following:
• Write a simple model using the sample NumPyro primitive.
• Run inference using MCMC in NumPyro, in particular, using the No U-Turn Sampler (NUTS) to get a posterior
distribution over our regression parameters of interest.
• Learn about inference utilities such as Predictive and log_likelihood.
• Learn how we can use effect-handlers in NumPyro to generate execution traces from the model, condition on sample statements, seed models with RNG seeds, etc., and use these to implement various utilities that are useful for MCMC, e.g. computing the model log likelihood and generating an empirical distribution over the posterior predictive.
1. Dataset
2. Regression Model to Predict Divorce Rate
• Model-1: Predictor-Marriage Rate
• Posterior Distribution over the Regression Parameters
• Posterior Predictive Distribution
• Predictive Utility With Effect Handlers
• Model Predictive Density
• Model-2: Predictor-Median Age of Marriage
• Model-3: Predictor-Marriage Rate and Median Age of Marriage
• Divorce Rate Residuals by State
[ ]: %reset -s -f
[2]: import os
# vmap, logsumexp, and Predictive are used in later cells of this tutorial.
from IPython.display import set_matplotlib_formats
import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.special import logsumexp
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS, Predictive
plt.style.use('bmh')
if "NUMPYRO_SPHINXBUILD" in os.environ:
    set_matplotlib_formats('svg')
assert numpyro.__version__.startswith('0.5.0')
3.2 Dataset
For this example, we will use the WaffleDivorce dataset from Chapter 05, Statistical Rethinking [1]. The dataset
contains divorce rates in each of the 50 states in the USA, along with predictors such as population, median age of
marriage, whether it is a Southern state and, curiously, number of Waffle Houses.
[Truncated preview of the WaffleDivorce dataset, showing the first rows of the Population1860 and PropSlaves1860 columns.]
Let us plot the pair-wise relationship amongst the main variables in the dataset, using seaborn.pairplot.
From the plots above, we can clearly observe that there is a relationship between divorce rates and marriage rates in a
state (as might be expected), and also between divorce rates and median age of marriage.
There is also a weak relationship between number of Waffle Houses and divorce rates, which is not obvious from the
plot above, but will be clearer if we regress Divorce against WaffleHouse and plot the results.
This is an example of a spurious association. We do not expect the number of Waffle Houses in a state to affect the
divorce rate, but it is likely correlated with other factors that have an effect on the divorce rate. We will not delve into
this spurious association in this tutorial, but the interested reader is encouraged to read Chapters 5 and 6 of [1] which
explores the problem of causal association in the presence of multiple predictors.
For simplicity, we will primarily focus on marriage rate and the median age of marriage as our predictors for divorce
rate throughout the remaining tutorial.
Let us now write a regression model in NumPyro to predict the divorce rate as a linear function of marriage rate and median age of marriage in each of the states.
First, note that our predictor variables have somewhat different scales. It is a good practice to standardize our predictors
and response variables to mean 0 and standard deviation 1, which should result in faster inference.
standardize = lambda x: (x - x.mean()) / x.std()
dset['AgeScaled'] = dset.MedianAgeMarriage.pipe(standardize)
dset['MarriageScaled'] = dset.Marriage.pipe(standardize)
dset['DivorceScaled'] = dset.Divorce.pipe(standardize)
We write the NumPyro model as follows. While the code should largely be self-explanatory, take note of the following:
• In NumPyro, model code is any Python callable which can optionally accept additional arguments and keywords.
For HMC which we will be using for this tutorial, these arguments and keywords remain static during inference,
but we can reuse the same model to generate predictions on new data.
• In addition to regular Python statements, the model code also contains primitives like sample. These primitives
can be interpreted with various side-effects using effect handlers. For more on effect handlers, refer to [3],
[4]. For now, just remember that a sample statement makes this a stochastic function that samples some
latent parameters from a prior distribution. Our goal is to infer the posterior distribution of these parameters
conditioned on observed data.
• The reason why we have kept our predictors as optional keyword arguments is to be able to reuse the same
model as we vary the set of predictors. Likewise, the reason why the response variable is optional is that we
would like to reuse this model to sample from the posterior predictive distribution. See the section on plotting
the posterior predictive distribution, as an example.
[7]: def model(marriage=None, age=None, divorce=None):
    a = numpyro.sample('a', dist.Normal(0., 0.2))
    M, A = 0., 0.
    if marriage is not None:
        bM = numpyro.sample('bM', dist.Normal(0., 0.5))
        M = bM * marriage
    if age is not None:
        bA = numpyro.sample('bA', dist.Normal(0., 0.5))
        A = bA * age
    sigma = numpyro.sample('sigma', dist.Exponential(1.))
    mu = a + M + A
    numpyro.sample('obs', dist.Normal(mu, sigma), obs=divorce)
We first try to model the divorce rate as depending on a single variable, marriage rate. As mentioned above, we can
use the same model code as earlier, but only pass values for marriage and divorce keyword arguments. We will
use the No U-Turn Sampler (see [5] for more details on the NUTS algorithm) to run inference on this simple model.
The Hamiltonian Monte Carlo (or, the NUTS) implementation in NumPyro takes in a potential energy function. This is
the negative log joint density for the model. Therefore, for our model description above, we need to construct a function
which given the parameter values returns the potential energy (or negative log joint density). Additionally, the verlet
integrator in HMC (or, NUTS) returns sample values simulated using Hamiltonian dynamics in the unconstrained
space. As such, continuous variables with bounded support need to be transformed into unconstrained space using
bijective transforms. We also need to transform these samples back to their constrained support before returning these
values to the user. Thankfully, this is handled on the backend for us, within a convenience class for doing MCMC
inference that has the following methods:
• run(...): runs warmup, adapts step size and mass matrix, and does sampling using the sample from the warmup phase.
• print_summary(): print diagnostic information like quantiles, effective sample size, and the Gelman-Rubin
diagnostic.
• get_samples(): gets samples from the posterior distribution.
Note the following:
• JAX uses functional PRNGs. Unlike other languages / frameworks which maintain a global random state, in
JAX, every call to a sampler requires an explicit PRNGKey. We will split our initial random seed for subsequent
operations, so that we do not accidentally reuse the same seed.
• We run inference with the NUTS sampler. To run vanilla HMC, we can instead use the HMC class.
[8]: # Start from this source of randomness. We will split keys for subsequent operations.
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
num_warmup, num_samples = 1000, 2000
# Run NUTS.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup, num_samples)
mcmc.run(rng_key_, marriage=dset.MarriageScaled.values, divorce=dset.DivorceScaled.values)
mcmc.print_summary()
samples_1 = mcmc.get_samples()
Number of divergences: 0
We notice that the progress bar gives us online statistics on the acceptance probability, step size and number of steps
taken per sample while running NUTS. In particular, during warmup, we adapt the step size and mass matrix to achieve
a certain target acceptance probability which is 0.8, by default. We were able to successfully adapt our step size to
achieve this target in the warmup phase.
During warmup, the aim is to adapt hyper-parameters such as step size and mass matrix (the HMC algorithm is very
sensitive to these hyper-parameters), and to reach the typical set (see [6] for more details). If there are any issues in the
model specification, the first signal to notice would be low acceptance probabilities or very high number of steps. We
use the sample from the end of the warmup phase to seed the MCMC chain (denoted by the second sample progress
bar) from which we generate the desired number of samples from our target distribution.
At the end of inference, NumPyro prints the mean, std and 90% CI values for each of the latent parameters. Note that since we standardized our predictors and response variable, we would expect the intercept to have mean 0, as can be seen here. It also prints other convergence diagnostics on the latent parameters in the model, including effective sample size and the Gelman-Rubin diagnostic (R̂). The values of these diagnostics indicate that the chain has converged to the target distribution. In our case, the "target distribution" is the posterior distribution over the latent parameters that we are interested in. Note that this is often worth verifying with multiple chains for more complicated models. In the end, samples_1 is a collection (in our case, a dict since init_samples was a dict) containing samples from the posterior distribution for each of the latent parameters in the model.
To look at our regression fit, let us plot the regression line using our posterior estimates for the regression parameters,
along with the 90% Credibility Interval (CI). Note that the hpdi function in NumPyro’s diagnostics module can be
used to compute CI. In the functions below, note that the collected samples from the posterior are all along the leading
axis.
[9]: def plot_regression(x, y_mean, y_hpdi):
    # Sort values for plotting by x axis
    idx = jnp.argsort(x)
    marriage = x[idx]
    mean = y_mean[idx]
    hpdi = y_hpdi[:, idx]
    divorce = dset.DivorceScaled.values[idx]
    # Plot
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
    ax.plot(marriage, mean)
    ax.plot(marriage, divorce, 'o')
    ax.fill_between(marriage, hpdi[0], hpdi[1], alpha=0.3, interpolate=True)
    return ax
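The cell that builds the posterior regression line is not shown above; a sketch reconstructing it (assuming samples_1 holds the posterior samples from the run above) is:

# Compute the empirical posterior distribution over mu = a + bM * MarriageScaled
posterior_mu = (jnp.expand_dims(samples_1['a'], -1)
                + jnp.expand_dims(samples_1['bM'], -1) * dset.MarriageScaled.values)
mean_mu = jnp.mean(posterior_mu, axis=0)
hpdi_mu = hpdi(posterior_mu, 0.9)
ax = plot_regression(dset.MarriageScaled.values, mean_mu, hpdi_mu)
ax.set(xlabel='Marriage rate', ylabel='Divorce rate', title='Regression line with 90% CI');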
We can see from the plot that the CI broadens towards the tails where the data is relatively sparse, as can be expected.
Let us now look at the posterior predictive distribution to see how our predictive distribution looks with respect to
the observed divorce rates. To get samples from the posterior predictive distribution, we need to run the model by
substituting the latent parameters with samples from the posterior. NumPyro provides a handy Predictive utility for
this purpose. Note that by default we generate a single prediction for each sample from the joint posterior distribution,
but this can be controlled using the num_samples argument.
To remove the magic behind Predictive, let us see how we can combine effect handlers with the vmap JAX
primitive to implement our own simplified predictive utility function that can do vectorized predictions.
Note the use of the condition, seed and trace effect handlers in the predict function.
• The seed effect-handler is used to wrap a stochastic function with an initial PRNGKey seed. When a sample
statement inside the model is called, it uses the existing seed to sample from a distribution but this effect-handler
also splits the existing key to ensure that future sample calls in the model use the newly split key instead. This
is to prevent us from having to explicitly pass in a PRNGKey to each sample statement in the model.
• The condition effect handler conditions the latent sample sites to certain values. In our case, we are condi-
tioning on values from the posterior distribution returned by MCMC.
• The trace effect handler runs the model and records the execution trace within an OrderedDict. This trace
object contains execution metadata that is useful for computing quantities such as the log joint density.
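The code for this cell is not shown above; a hedged sketch of the predict function described here, using the handlers and vmap imports from the setup cell, is:

def predict(rng_key, post_samples, model, *args, **kwargs):
    # condition the model on posterior values, seed it, and record one execution trace
    model = handlers.seed(handlers.condition(model, post_samples), rng_key)
    model_trace = handlers.trace(model).get_trace(*args, **kwargs)
    return model_trace['obs']['value']

# Vectorize over posterior draws: one PRNG key and one set of latent values per draw.
n_samples = samples_1['a'].shape[0]
predict_fn = vmap(lambda key, samples: predict(key, samples, model,
                                               marriage=dset.MarriageScaled.values))
predictions_1 = predict_fn(random.split(rng_key_, n_samples), samples_1)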
It should be clear now that the predict function simply runs the model by substituting the latent parameters with samples from the posterior (generated by the mcmc function) to generate predictions. Note the use of JAX's auto-vectorization transform called vmap to vectorize predictions. If we didn't use vmap, we would have to use a native Python for loop over the samples, which would be much slower. Each draw from the posterior can be used to get predictions over all the 50 states. When we vectorize this over all the samples from the posterior using vmap, we will get a predictions_1 array of shape (num_samples, 50). We can then compute the mean and 90% CI of these samples to plot the posterior predictive distribution. We note that our mean predictions match those obtained from the Predictive utility class.
[12]: # Using the same key as we used for Predictive - note that the results are identical.
We have used the same plot_regression function as earlier. We notice that our CI for the predictive distribution
is much broader as compared to the last plot due to the additional noise introduced by the sigma parameter. Most
data points lie well within the 90% CI, which indicates a good fit.
Likewise, making use of effect-handlers and vmap, we can also compute the log likelihood for this model given the
dataset, and the log posterior predictive density [6] which is given by
$$\log \prod_{i=1}^{n} \int p(y_i \mid \theta)\, p_{\text{post}}(\theta)\, d\theta \;\approx\; \sum_{i=1}^{n} \log \frac{\sum_{s} p(y_i \mid \theta^{s})}{S} \;=\; \sum_{i=1}^{n} \left( \log \sum_{s} p(y_i \mid \theta^{s}) - \log S \right).$$
Here, 𝑖 indexes the observed data points 𝑦 and 𝑠 indexes the posterior samples over the latent parameters 𝜃. If the
posterior predictive density for a model has a comparatively high value, it indicates that the observed data-points have
higher probability under the given model.
Note that NumPyro provides a log_likelihood utility function that can be used directly to compute the log likelihood, as in the first function above, for any general model. In this tutorial, we would like to emphasize that there is nothing magical about such utility functions, and you can roll your own inference utilities using NumPyro's effect handling stack.
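For illustration, here is a sketch of such a computation (the function name log_pred_density is an assumption) built from log_likelihood and logsumexp:

import jax.numpy as jnp
from jax.scipy.special import logsumexp
from numpyro.infer import log_likelihood

def log_pred_density(samples, model, *args, **kwargs):
    n = list(samples.values())[0].shape[0]            # number of posterior draws S
    # per-draw, per-datapoint log p(y_i | theta^s), shape (S, num_data)
    log_lk = log_likelihood(model, samples, *args, **kwargs)['obs']
    # log-mean-exp over draws, summed over data points
    return jnp.sum(logsumexp(log_lk, axis=0) - jnp.log(n))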
We will now model the divorce rate as a function of the median age of marriage. The computations are mostly a
reproduction of what we did for Model 1. Notice the following:
• Divorce rate is inversely related to the age of marriage. Hence states where the median age of marriage is low
will likely have a higher divorce rate.
• We get a higher log likelihood as compared to Model 1, indicating that the median age of marriage is likely a much better predictor of divorce rate.
Number of divergences: 0
Finally, we will also model divorce rate as depending on both marriage rate as well as the median age of marriage. Note
that the model’s posterior predictive density is similar to Model 2 which likely indicates that the marginal information
from marriage rate in predicting divorce rate is low when the median age of marriage is already known.
[20]: rng_key, rng_key_ = random.split(rng_key)
mcmc.run(rng_key_, marriage=dset.MarriageScaled.values,
age=dset.AgeScaled.values, divorce=dset.DivorceScaled.values)
mcmc.print_summary()
samples_3 = mcmc.get_samples()
sample: 100%| | 3000/3000 [00:07<00:00, 389.02it/s, 7 steps of size 5.15e-01. acc. prob=0.92]
Number of divergences: 0
The regression plots above show that the observed divorce rates for many states differ considerably from the mean regression line. To dig deeper into how the last model (Model 3) under-predicts or over-predicts for each of the states, we
will plot the posterior predictive and residuals (Observed divorce rate - Predicted divorce rate)
for each of the states.
[22]: # Predictions for Model 3.
rng_key, rng_key_ = random.split(rng_key)
predictions_3 = Predictive(model, samples_3)(rng_key_,
marriage=dset.MarriageScaled.values,
age=dset.AgeScaled.values)['obs']
y = jnp.arange(50)
# Plot residuals
residuals_3 = dset.DivorceScaled.values - predictions_3
residuals_mean = jnp.mean(residuals_3, axis=0)
ax[1].plot(jnp.zeros(50), y, '--')
ax[1].errorbar(residuals_mean[idx], y, xerr=err[idx],
marker='o', ms=5, mew=4, ls='none', alpha=0.8)
ax[1].set(xlabel='Residuals', ylabel='State', title='Residuals with 90% CI')
ax[1].set_yticks(y)
ax[1].set_yticklabels(dset.Loc.values[idx], fontsize=10);
The plot on the left shows the mean predictions with 90% CI for each of the states using Model 3. The gray markers
indicate the actual observed divorce rates. The right plot shows the residuals for each of the states, and both these plots
are sorted by the residuals, i.e. at the bottom, we are looking at states where the model predictions are higher than the
observed rates, whereas at the top, the reverse is true.
Overall, the model fit seems good because most observed data points lie within a 90% CI around the mean predictions.
However, notice how the model over-predicts by a large margin for states like Idaho (bottom left), and at the other end under-predicts for states like Maine (top right). This likely indicates that our model is missing other factors that affect the divorce rate across different states. Even ignoring other socio-political variables, one such factor
that we have not yet modeled is the measurement noise given by Divorce SE in the dataset. We will explore this in
the next section.
Note that in our previous models, each data point influences the regression line equally. Is this well justified? We
will build on the previous model to incorporate measurement error given by Divorce SE variable in the dataset.
Incorporating measurement noise will be useful in ensuring that observations that have higher confidence (i.e. lower
measurement noise) have a greater impact on the regression line. On the other hand, this will also help us better model
outliers with high measurement errors. For more details on modeling errors due to measurement noise, refer to Chapter
14 of [1].
To do this, we will reuse Model 3, with the only change that the final observed value has a measurement error given
by divorce_sd (notice that this has to be standardized since the divorce variable itself has been standardized to
mean 0 and std 1).
[24]: # Standardize
dset['DivorceScaledSD'] = dset['Divorce SE'] / jnp.std(dset.Divorce.values)
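A sketch of such a model is shown below; the priors and site names are assumptions meant to mirror Model 3, with an extra latent divorce_rate site that the noisy observation is drawn around.

import numpyro
import numpyro.distributions as dist

def model_se(marriage, age, divorce_sd, divorce=None):
    a = numpyro.sample('a', dist.Normal(0., 0.2))
    bM = numpyro.sample('bM', dist.Normal(0., 0.5))
    bA = numpyro.sample('bA', dist.Normal(0., 0.5))
    sigma = numpyro.sample('sigma', dist.Exponential(1.))
    mu = a + bM * marriage + bA * age
    # latent 'true' divorce rate for each state
    divorce_rate = numpyro.sample('divorce_rate', dist.Normal(mu, sigma))
    # observed rate = true rate corrupted by known, per-state measurement noise
    numpyro.sample('obs', dist.Normal(divorce_rate, divorce_sd), obs=divorce)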
Number of divergences: 0
Notice that our values for the regression coefficients are very similar to Model 3. However, introducing measurement
noise allows us to more closely match our predictive distribution to the observed values. We can see this if we plot the
residuals as earlier.
[27]: sd = dset.DivorceScaledSD.values
residuals_4 = dset.DivorceScaled.values - predictions_4
residuals_mean = jnp.mean(residuals_4, axis=0)
residuals_hpdi = hpdi(residuals_4, 0.9)
err = residuals_hpdi[1] - residuals_mean
idx = jnp.argsort(residuals_mean)
y = jnp.arange(50)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 16))
# Plot Residuals
ax.plot(jnp.zeros(50), y, '--')
ax.errorbar(residuals_mean[idx], y, xerr=err[idx],
marker='o', ms=5, mew=4, ls='none', alpha=0.8)
# Plot SD
ax.errorbar(residuals_mean[idx], y, xerr=sd[idx],
ls='none', color='orange', alpha=0.9)
The plot above shows the residuals for each of the states, along with the measurement noise given by inner error bar.
The gray dots are the mean residuals from our earlier Model 3. Notice how having an additional degree of freedom
to model the measurement noise has shrunk the residuals. In particular, for Idaho and Maine, our predictions are now
much closer to the observed values after incorporating measurement noise in the model.
To better see how measurement noise affects the movement of the regression line, let us plot the residuals with respect
to the measurement noise.
The plot above shows what has happened in more detail: the regression line itself has moved to ensure a better fit for
observations with low measurement noise (left of the plot) where the residuals have shrunk very close to 0. That is to
say that data points with low measurement error have a concomitantly higher contribution in determining the regression
line. On the other hand, for states with high measurement error (right of the plot), incorporating measurement noise
allows us to move our posterior distribution mass closer to the observations resulting in a shrinkage of residuals as
well.
3.5 References
1. McElreath, R. (2016). Statistical Rethinking: A Bayesian Course with Examples in R and Stan. CRC Press.
2. Stan Development Team. Stan User's Guide
3. Goodman, N.D., and Stuhlmüller, A. (2014). The Design and Implementation of Probabilistic Programming
Languages
4. Pyro Development Team. Poutine: A Guide to Programming with Effect Handlers in Pyro
5. Hoffman, M.D., Gelman, A. (2011). The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian
Monte Carlo.
6. Betancourt, M. (2017). A Conceptual Introduction to Hamiltonian Monte Carlo.
7. JAX Development Team (2018). Composable transformations of Python+NumPy programs: differentiate, vec-
torize, JIT to GPU/TPU, and more
8. Gelman, A., Hwang, J., and Vehtari A. Understanding predictive information criteria for Bayesian models
Pulmonary fibrosis is a disorder with no known cause and no known cure, created by scarring of the lungs. In this
competition, we were asked to predict a patient’s severity of decline in lung function. Lung function is assessed based
on output from a spirometer, which measures the forced vital capacity (FVC), i.e. the volume of air exhaled.
In medical applications, it is useful to evaluate a model’s confidence in its decisions. Accordingly, the metric used to
rank the teams was designed to reflect both the accuracy and certainty of each prediction. It’s a modified version
of the Laplace Log Likelihood (more details on that later).
Let's explore the data and see what that's all about:
In the dataset, we were provided with a baseline chest CT scan and associated clinical information for a set of patients.
A patient has an image acquired at time Week = 0 and has numerous follow up visits over the course of approximately
1-2 years, at which time their FVC is measured. For this tutorial, I will use only the Patient ID, the weeks and the FVC
measurements, discarding all the rest. Using only these columns enabled our team to achieve a competitive score,
which shows the power of Bayesian hierarchical linear regression models especially when gauging uncertainty is an
important part of the problem.
Since this is real medical data, the relative timing of FVC measurements varies widely, as shown in the 3 sample
patients below:
On average, each of the 176 provided patients made 9 visits at which FVC was measured. The visits happened in specific weeks in the [-12, 133] interval. The decline in lung capacity is very clear, but we also see that it varies considerably from patient to patient.
We were asked to predict every patient's FVC measurement for every possible week in the [-12, 133] interval, along with the confidence for each prediction. In other words: we were asked to fill a matrix like the one below, and provide a confidence score for each prediction:
The task was a perfect fit for Bayesian inference. However, the vast majority of solutions shared by the Kaggle community used discriminative machine learning models, disregarding the fact that most discriminative methods are very poor at providing realistic uncertainty estimates. Because they are typically trained in a manner that optimizes the parameters
to minimize some loss criterion (e.g. the predictive error), they do not, in general, encode any uncertainty in either
their parameters or the subsequent predictions. Though many methods can produce uncertainty estimates either as a
by-product or from a post-processing step, these are typically heuristic based, rather than stemming naturally from a
statistically principled estimate of the target uncertainty distribution [2].
The simplest possible linear regression, not hierarchical, would assume all FVC decline curves have the same 𝛼 and
𝛽. That’s the pooled model. In the other extreme, we could assume a model where each patient has a personalized
FVC decline curve, and these curves are completely unrelated. That’s the unpooled model, where each patient has
completely separate regressions.
Here, I’ll use the middle ground: Partial pooling. Specifically, I’ll assume that while 𝛼’s and 𝛽’s are different for
each patient as in the unpooled case, the coefficients all share similarity. We can model this by assuming that each
individual coefficient comes from a common group distribution. The image below represents this model graphically:
where t is the time in weeks. Those are very uninformative priors, but that’s ok: our model will converge!
Implementing this model in NumPyro is pretty straightforward:
unique_patient_IDs = np.unique(PatientID)
n_patients = len(unique_patient_IDs)
𝜎 = numpyro.sample("𝜎", dist.HalfNormal(100.))
FVC_est = 𝛼[PatientID] + 𝛽[PatientID] * Weeks
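Since only a fragment of the model appears above, here is a self-contained sketch of the partial-pooling model described in this section; the Greek parameter names of the original are spelled out, and the exact priors shown are assumptions.

import numpy as np
import numpyro
import numpyro.distributions as dist

def model(PatientID, Weeks, FVC_obs=None):
    mu_alpha = numpyro.sample("mu_alpha", dist.Normal(0., 100.))
    sigma_alpha = numpyro.sample("sigma_alpha", dist.HalfNormal(100.))
    mu_beta = numpyro.sample("mu_beta", dist.Normal(0., 100.))
    sigma_beta = numpyro.sample("sigma_beta", dist.HalfNormal(100.))

    n_patients = len(np.unique(PatientID))
    with numpyro.plate("plate_i", n_patients):
        # per-patient intercept and slope, drawn from common group-level distributions
        alpha = numpyro.sample("alpha", dist.Normal(mu_alpha, sigma_alpha))
        beta = numpyro.sample("beta", dist.Normal(mu_beta, sigma_beta))

    sigma = numpyro.sample("sigma", dist.HalfNormal(100.))
    FVC_est = alpha[PatientID] + beta[PatientID] * Weeks

    with numpyro.plate("data", len(PatientID)):
        numpyro.sample("obs", dist.Normal(FVC_est, sigma), obs=FVC_obs)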
A great achievement of Probabilistic Programming Languages such as NumPyro is to decouple model specification
and inference. After specifying my generative model, with priors, condition statements and data likelihood, I can leave
the hard work to NumPyro’s inference engine.
Calling it requires just a few lines. Before we do it, let’s add a numerical Patient ID for each patient code. That can be
easily done with scikit-learn’s LabelEncoder:
[6]: from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()
train['PatientID'] = le.fit_transform(train['Patient'].values)
FVC_obs = train['FVC'].values
Weeks = train['Weeks'].values
PatientID = train['PatientID'].values
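The inference call itself is just a few lines; this is a sketch, and the number of warmup and posterior samples here is an assumption consistent with the progress bar shown below.

from jax import random
from numpyro.infer import MCMC, NUTS

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=2000, num_samples=2000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, PatientID, Weeks, FVC_obs=FVC_obs)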
posterior_samples = mcmc.get_samples()
sample: 100%| | 4000/4000 [00:20<00:00, 195.69it/s, 63 steps of size 1.06e-01. acc. prob=0.89]
First, let’s inspect the parameters learned. To do that, I will use ArviZ, which perfectly integrates with NumPyro:
data = az.from_numpyro(mcmc)
az.plot_trace(data, compact=True);
Looks like our model learned personalized alphas and betas for each patient!
Now, let’s visually inspect FVC decline curves predicted by our model. We will completely fill in the FVC table,
predicting all missing values. The first step is to create a table to fill:
[9]: pred_template = []
for i in range(train['Patient'].nunique()):
df = pd.DataFrame(columns=['PatientID', 'Weeks'])
df['Weeks'] = np.arange(-12, 134)
df['PatientID'] = i
pred_template.append(df)
pred_template = pd.concat(pred_template, ignore_index=True)
Predicting the missing values in the FVC table and confidence (sigma) for each value becomes really easy:
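A sketch of that step (the site and variable names are assumptions consistent with the model sketch above):

from jax import random
from numpyro.infer import Predictive

predictive = Predictive(model, posterior_samples, return_sites=["sigma", "obs"])
# FVC_obs is left as None so the 'obs' site is sampled for every (PatientID, Week) pair
samples_predictive = predictive(random.PRNGKey(0),
                                pred_template["PatientID"].values,
                                pred_template["Weeks"].values,
                                None)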
Let’s now put the predictions together with the true values, to visualize them:
(Table: head of the merged dataframe of predictions and true values; FVC_true is NaN for weeks without an actual measurement.)
4.3 Computing the modified Laplace Log Likelihood and RMSE
As mentioned earlier, the competition was evaluated on a modified version of the Laplace Log Likelihood. In medical
applications, it is useful to evaluate a model’s confidence in its decisions. Accordingly, the metric is designed to reflect
both the accuracy and certainty of each prediction.
For each true FVC measurement, we predicted both an FVC and a confidence measure (standard deviation 𝜎). The
metric was computed as:
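Reconstructed here from the computation in the code cell below, the per-observation metric is

$$\sigma_{\text{clipped}} = \max(\sigma, 70), \qquad \Delta = \min\big(|FVC_{\text{true}} - FVC_{\text{pred}}|,\ 1000\big), \qquad \text{metric} = -\frac{\sqrt{2}\,\Delta}{\sigma_{\text{clipped}}} - \ln\big(\sqrt{2}\,\sigma_{\text{clipped}}\big),$$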
averaging the metric across all (Patient, Week) pairs. Note that metric values will be negative and higher is better.
Next, we calculate the metric and RMSE:
[13]: y = df.dropna()
rmse = ((y['FVC_pred'] - y['FVC_true']) ** 2).mean() ** (1/2)
print(f'RMSE: {rmse:.1f} ml')
sigma_c = y['sigma'].values
sigma_c[sigma_c < 70] = 70
delta = (y['FVC_pred'] - y['FVC_true']).abs()
delta[delta > 1000] = 1000
lll = - np.sqrt(2) * delta / sigma_c - np.log(np.sqrt(2) * sigma_c)
print(f'Laplace Log Likelihood: {lll.mean():.4f}')
RMSE: 122.1 ml
Laplace Log Likelihood: -6.1376
What do these numbers mean? They mean that if you adopted this approach, you would outperform most of the public solutions in the competition. Curiously, the vast majority of public solutions adopted a standard deterministic neural network, modelling uncertainty through a quantile loss; most people still adopt a frequentist approach.
Uncertainty for individual predictions is becoming more and more important in machine learning and is often a requirement. Especially when the consequences of a wrong prediction are high, we need to know what the probability distribution of an individual prediction is. For perspective, Kaggle just launched a new competition sponsored by Lyft, to build
motion prediction models for self-driving vehicles. “We ask that you predict a few trajectories for every agent and
provide a confidence score for each of them.”
Finally, I hope the great work done by the Pyro/NumPyro developers helps democratize Bayesian methods, empowering
an ever growing community of researchers and practitioners to create models that can not only generate predictions,
but also assess uncertainty in their predictions.
4.5 References
1. Ghahramani, Z. Probabilistic machine learning and artificial intelligence. Nature 521, 452–459 (2015). https:
//doi.org/10.1038/nature14541
2. Rainforth, Thomas William Gamlen. Automating Inference, Learning, and Design Using Probabilistic Pro-
gramming. University of Oxford, 2017.
import argparse
import os
import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import BASEBALL, load_dataset
from numpyro.infer import HMC, MCMC, NUTS, SA, Predictive, log_likelihood
def main(args):
_, fetch_train = load_dataset(BASEBALL, split='train', shuffle=False)
train, player_names = fetch_train()
_, fetch_test = load_dataset(BASEBALL, split='test', shuffle=False)
test, _ = fetch_test()
at_bats, hits = train[:, 0], train[:, 1]
season_at_bats, season_hits = test[:, 0], test[:, 1]
for i, model in enumerate((fully_pooled,
not_pooled,
partially_pooled,
partially_pooled_with_logit,
)):
rng_key, rng_key_predict = random.split(random.PRNGKey(i + 1))
zs = run_inference(model, at_bats, hits, rng_key, args)
predict(model, at_bats, hits, zs, rng_key_predict, player_names)
predict(model, season_at_bats, season_hits, zs, rng_key_predict, player_names, train=False)
if __name__ == "__main__":
assert numpyro.__version__.startswith('0.5.0')
parser = argparse.ArgumentParser(description="Baseball batting average using MCMC
˓→")
args = parser.parse_args()
numpyro.set_platform(args.device)
numpyro.set_host_device_count(args.num_chains)
main(args)
import argparse
import inspect
import os
import time
import numpyro
from numpyro import optim
import numpyro.distributions as dist
from numpyro.examples.datasets import MNIST, load_dataset
from numpyro.infer import SVI, Trace_ELBO
RESULTS_DIR = os.path.abspath(os.path.join(os.path.dirname(inspect.getfile(lambda: None)), '.results'))
os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)
@jit
def binarize(rng_key, batch):
return random.bernoulli(rng_key, batch).astype(batch.dtype)
def main(args):
encoder_nn = encoder(args.hidden_dim, args.z_dim)
decoder_nn = decoder(args.hidden_dim, 28 * 28)
adam = optim.Adam(args.learning_rate)
svi = SVI(model, guide, adam, Trace_ELBO(), hidden_dim=args.hidden_dim, z_dim=args.z_dim)
rng_key = PRNGKey(0)
train_init, train_fetch = load_dataset(MNIST, batch_size=args.batch_size, split='train')
@jit
def epoch_train(svi_state, rng_key):
def body_fn(i, val):
loss_sum, svi_state = val
rng_key_binarize = random.fold_in(rng_key, i)
batch = binarize(rng_key_binarize, train_fetch(i, train_idx)[0])
svi_state, loss = svi.update(svi_state, batch)
loss_sum += loss
return loss_sum, svi_state
@jit
def eval_test(svi_state, rng_key):
def body_fun(i, loss_sum):
rng_key_binarize = random.fold_in(rng_key, i)
batch = binarize(rng_key_binarize, test_fetch(i, test_idx)[0])
# FIXME: does this lead to a requirement for an rng_key arg in svi_eval?
loss = svi.evaluate(svi_state, batch) / len(batch)
loss_sum += loss
return loss_sum
z = dist.Normal(z_mean, z_var).sample(rng_key_sample)
img_loc = decoder_nn[1](params['decoder$params'], z).reshape([28, 28])
plt.imsave(os.path.join(RESULTS_DIR, 'recons_epoch={}.png'.format(epoch)), img_loc, cmap='gray')
for i in range(args.num_epochs):
rng_key, rng_key_train, rng_key_test, rng_key_reconstruct = random.split(rng_key, 4)
t_start = time.time()
num_train, train_idx = train_init()
_, svi_state = epoch_train(svi_state, rng_key_train)
rng_key, rng_key_test, rng_key_reconstruct = random.split(rng_key, 3)
num_test, test_idx = test_init()
test_loss = eval_test(svi_state, rng_key_test)
reconstruct_img(i, rng_key_reconstruct)
print("Epoch {}: loss = {} ({:.2f} s.)".format(i, test_loss, time.time() - t_
˓→start))
if __name__ == '__main__':
assert numpyro.__version__.startswith('0.5.0')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-epochs', default=15, type=int, help='number of training epochs')
args = parser.parse_args()
main(args)
This example, which is adapted from [1], illustrates how to leverage non-centered parameterization using the
reparam handler. We will examine the difference between two types of parameterizations on the 10-dimensional
Neal's funnel distribution. As we will see, HMC runs into trouble at the neck of the funnel if the centered parameterization is used. The problem can be solved by using a non-centered parameterization instead.
Using non-centered parameterization through LocScaleReparam or TransformReparam in NumPyro has the
same effect as the automatic reparameterisation technique introduced in [2].
References:
1. Stan User’s Guide, https://mc-stan.org/docs/2_19/stan-users-guide/reparameterization-section.html
2. Maria I. Gorinova, Dave Moore, Matthew D. Hoffman (2019), “Automatic Reparameterisation of Probabilistic
Programs”, (https://arxiv.org/abs/1906.03028)
import argparse
import os
import numpyro
import numpyro.distributions as dist
from numpyro.handlers import reparam
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.infer.reparam import LocScaleReparam
def model(dim=10):
y = numpyro.sample('y', dist.Normal(0, 3))
numpyro.sample('x', dist.Normal(jnp.zeros(dim - 1), jnp.exp(y / 2)))
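A minimal sketch (consistent with the imports above, though not necessarily the example's exact code) of applying the non-centered parameterization to this model and running NUTS on it:

from jax import random
from numpyro.handlers import reparam
from numpyro.infer import MCMC, NUTS
from numpyro.infer.reparam import LocScaleReparam

# reparameterize the 'x' site: HMC samples an auxiliary standard-normal variable
# and x is recovered deterministically as loc + scale * x_decentered
reparam_model = reparam(model, config={'x': LocScaleReparam(0)})

mcmc = MCMC(NUTS(reparam_model), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0))
samples = mcmc.get_samples()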
def main(args):
rng_key = random.PRNGKey(0)
random.PRNGKey(1))
# make plots
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(8, 8), constrained_layout=True)
plt.savefig('funnel_plot.pdf')
if __name__ == "__main__":
assert numpyro.__version__.startswith('0.5.0')
parser = argparse.ArgumentParser(description="Non-centered reparameterization
˓→example")
args = parser.parse_args()
numpyro.set_platform(args.device)
numpyro.set_host_device_count(args.num_chains)
main(args)
Generative model:
$$\begin{aligned}
\sigma &\sim \text{Exponential}(50) && (8.1)\\
\nu &\sim \text{Exponential}(0.1) && (8.2)\\
s_i &\sim \text{Normal}(s_{i-1}, \sigma^{-2}) && (8.3)\\
r_i &\sim \text{StudentT}(\nu, 0, \exp(s_i)) && (8.4)
\end{aligned}$$
This example is from PyMC3 [1], which itself is adapted from the original experiment from [2]. A discussion about
translating this in Pyro appears in [3].
We take this example to illustrate how to use the functional interface hmc. However, we recommend that readers use the MCMC class as in other examples, because it is more stable and supports more features.
References:
1. Stochastic Volatility Model, https://docs.pymc.io/notebooks/stochastic_volatility.html
2. The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo, https://arxiv.org/pdf/
1111.4246.pdf
3. Pyro forum discussion, https://forum.pyro.ai/t/problems-transforming-a-pymc3-model-to-pyro-mcmc/208/14
import argparse
import os
import matplotlib
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import SP500, load_dataset
from numpyro.infer.hmc import hmc
from numpyro.infer.util import initialize_model
from numpyro.util import fori_collect
def model(returns):
step_size = numpyro.sample('sigma', dist.Exponential(50.))
s = numpyro.sample('s', dist.GaussianRandomWalk(scale=step_size, num_steps=jnp.shape(returns)[0]))
nu = numpyro.sample('nu', dist.Exponential(.1))
return numpyro.sample('r', dist.StudentT(df=nu, loc=0., scale=jnp.exp(s)),
obs=returns)
def main(args):
_, fetch = load_dataset(SP500, shuffle=False)
dates, returns = fetch()
init_rng_key, sample_rng_key = random.split(random.PRNGKey(args.rng_seed))
model_info = initialize_model(init_rng_key, model, model_args=(returns,))
init_kernel, sample_kernel = hmc(model_info.potential_fn, algo='NUTS')
hmc_state = init_kernel(model_info.param_info, args.num_warmup, rng_key=sample_rng_key)
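The collection step that produces the hmc_states referred to below is omitted in the excerpt above; a sketch of it, assuming default settings, uses fori_collect to run the sampler and gather the post-warmup states:

# collect num_samples post-warmup states; postprocess_fn maps unconstrained
# values back to the original (constrained) parameter space
hmc_states = fori_collect(args.num_warmup, args.num_warmup + args.num_samples,
                          sample_kernel, hmc_state,
                          transform=lambda state: model_info.postprocess_fn(state.z))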
print_results(hmc_states, dates)
plt.savefig("stochastic_volatility_plot.pdf")
if __name__ == "__main__":
assert numpyro.__version__.startswith('0.5.0')
parser = argparse.ArgumentParser(description="Stochastic Volatility Model")
parser.add_argument('-n', '--num-samples', nargs='?', default=600, type=int)
parser.add_argument('--num-warmup', nargs='?', default=600, type=int)
parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".')
args = parser.parse_args()
numpyro.set_platform(args.device)
main(args)
In this example, we run MCMC for various crowdsourced annotation models in [1].
All models have discrete latent variables. Under the hood, we enumerate over (marginalize out) those discrete latent sites during inference. These models have different complexity, so they are great references for those who are new to the Pyro/NumPyro enumeration mechanism. We recommend readers compare the implementations with the corresponding plate diagrams in [1] to see how concise a Pyro/NumPyro program is.
Interested readers can also refer to [3] for more explanation about enumeration.
The data is taken from Table 1 of reference [2].
Currently, this example does not include postprocessing steps to deal with “Label Switching” issue (mentioned in
section 6.2 of [1]).
References:
1. Paun, S., Carpenter, B., Chamberlain, J., Hovy, D., Kruschwitz, U., and Poesio, M. (2018). “Comparing
bayesian models of annotation” (https://www.aclweb.org/anthology/Q18-1040/)
2. Dawid, A. P., and Skene, A. M. (1979). “Maximum likelihood estimation of observer error-rates using the EM
algorithm”
3. “Inference with Discrete Latent Variables” (http://pyro.ai/examples/enumeration.html)
import argparse
import os
import numpy as np
import numpyro
from numpyro import handlers
from numpyro.contrib.indexing import Vindex
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
def get_data():
"""
:return: a tuple of annotator indices and class indices. The first term has shape
`num_positions` whose entries take values from `0` to `num_annotators - 1`.
The second term has shape `num_items x num_positions` whose entries take
˓→values
def multinomial(annotations):
"""
This model corresponds to the plate diagram in Figure 1 of reference [1].
"""
num_classes = int(np.max(annotations)) + 1
num_items, num_positions = annotations.shape
pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
# here we use Vindex to allow broadcasting for the second index `c`
# ref: http://num.pyro.ai/en/latest/utilities.html#numpyro.contrib.indexing.vindex
def item_difficulty(annotations):
"""
This model corresponds to the plate diagram in Figure 5 of reference [1].
"""
num_classes = int(np.max(annotations)) + 1
num_items, num_positions = annotations.shape
pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
NAME_TO_MODEL = {
"mn": multinomial,
"ds": dawid_skene,
"mace": mace,
"hds": hierarchical_dawid_skene,
"id": item_difficulty,
"lre": logistic_random_effects,
}
def main(args):
annotators, annotations = get_data()
model = NAME_TO_MODEL[args.model]
data = (annotations,) if model in [multinomial, item_difficulty] else (annotators, annotations)
mcmc = MCMC(
NUTS(model),
args.num_warmup,
args.num_samples,
num_chains=args.num_chains,
progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
)
if __name__ == "__main__":
assert numpyro.__version__.startswith("0.5.0")
parser = argparse.ArgumentParser(description="Bayesian Models of Annotation")
parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
parser.add_argument("--num-chains", nargs="?", default=1, type=int)
parser.add_argument(
"--model",
nargs="?",
default="ds",
help='one of "mn" (multinomial), "ds" (dawid_skene), "mace",'
' "hds" (hierarchical_dawid_skene),'
' "id" (item_difficulty), "lre" (logistic_random_effects)',
)
parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".
˓→')
args = parser.parse_args()
numpyro.set_platform(args.device)
numpyro.set_host_device_count(args.num_chains)
main(args)
This example is ported from [1], which shows how to marginalize out discrete model variables in Pyro.
This combines MCMC with a variable elimination algorithm, where we use enumeration to exactly marginalize out
some variables from the joint density.
To marginalize out discrete variables x:
1. Verify that the variable dependency structure in your model admits tractable inference, i.e. the dependency
graph among enumerated variables should have narrow treewidth.
2. Ensure your model can handle broadcasting of the sample values of those variables.
Note that, unlike [1], which uses a Python loop, here we use scan() to reduce the compilation time of the model (only one step needs to be compiled). Under the hood, scan stacks all the priors' parameters and values into an additional time dimension. This allows us to compute the joint density in parallel. In addition, the stacked form allows us to use the parallel-scan algorithm in [2], which reduces the parallel complexity from O(length) to O(log(length)).
Data are taken from [3]. However, the original source of the data seems to be the Institut fuer Algorithmen und
Kognitive Systeme at Universitaet Karlsruhe.
References:
1. Pyro’s Hidden Markov Model example, (https://pyro.ai/examples/hmm.html)
2. Temporal Parallelization of Bayesian Smoothers, Simo Sarkka, Angel F. Garcia-Fernandez (https://arxiv.org/
abs/1905.13002)
3. Modeling Temporal Dependencies in High-Dimensional Sequences: Application to Polyphonic Music Genera-
tion and Transcription, Boulanger-Lewandowski, N., Bengio, Y. and Vincent, P.
4. Tensor Variable Elimination for Plated Factor Graphs, Fritz Obermeyer, Eli Bingham, Martin Jankowiak, Justin
Chiu, Neeraj Pradhan, Alexander Rush, Noah Goodman (https://arxiv.org/abs/1902.03210)
import argparse
import logging
import os
import time
import numpyro
from numpyro.contrib.control_flow import scan
from numpyro.contrib.indexing import Vindex
import numpyro.distributions as dist
from numpyro.examples.datasets import JSB_CHORALES, load_dataset
from numpyro.handlers import mask
from numpyro.infer import HMC, MCMC, NUTS
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
probs_y = numpyro.sample("probs_y",
dist.Beta(0.1, 0.9)
.expand([args.hidden_dim, 2, data_dim])
.to_event(3))
Next let’s consider a second-order HMM model in which x[t+1] depends on both x[t] and x[t-1].
# _______>______
# _____>_____/______ \
# / / \ \
# x[t-1] --> x[t] --> x[t+1] --> x[t+2]
# | | | |
# V V V V
# y[t-1] y[t] y[t+1] y[t+2]
#
# Note that in this model (in contrast to the previous model) we treat
# the transition and emission probabilities as parameters (so they have no prior).
#
# Note that this is the "2HMM" model in reference [4].
def model_6(sequences, lengths, args, include_prior=False):
num_sequences, max_length, data_dim = sequences.shape
with mask(mask=include_prior):
# Explicitly parameterize the full tensor of transition probabilities, which
# has hidden_dim cubed entries.
probs_x = numpyro.sample("probs_x",
dist.Dirichlet(0.9 * jnp.eye(args.hidden_dim) + 0.1)
.expand([args.hidden_dim, args.hidden_dim])
.to_event(2))
probs_y = numpyro.sample("probs_y",
dist.Beta(0.1, 0.9)
.expand([args.hidden_dim, data_dim])
.to_event(2))
Do inference
def main(args):
model = models[args.model]
logger.info('-' * 40)
logger.info('Training {} on {} sequences'.format(
model.__name__, len(sequences)))
# find all the notes that are present at least once in the training set
present_notes = ((sequences == 1).sum(0).sum(0) > 0)
# remove notes that are never played (we remove 37/88 notes with default args)
sequences = sequences[..., present_notes]
if args.truncate:
lengths = lengths.clip(0, args.truncate)
sequences = sequences[:, :args.truncate]
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="HMC for HMMs")
parser.add_argument("-m", "--model", default="1", type=str,
help="one of: {}".format(", ".join(sorted(models.keys()))))
parser.add_argument('-n', '--num-samples', nargs='?', default=1000, type=int)
parser.add_argument("-d", "--hidden-dim", default=16, type=int)
parser.add_argument('-t', "--truncate", type=int)
parser.add_argument("--num-sequences", type=int)
parser.add_argument("--kernel", default='nuts', type=str)
parser.add_argument('--num-warmup', nargs='?', default=500, type=int)
parser.add_argument("--num-chains", nargs='?', default=1, type=int)
parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".')
args = parser.parse_args()
numpyro.set_platform(args.device)
main(args)
import argparse
import os
import numpyro
from numpyro import handlers
from numpyro.contrib.control_flow import scan
import numpyro.distributions as dist
from numpyro.examples.datasets import DIPPER_VOLE, load_dataset
from numpyro.infer import HMC, MCMC, NUTS
from numpyro.infer.reparam import LocScaleReparam
Our first and simplest CJS model variant only has two continuous (scalar) latent random variables: i) the survival prob-
ability phi; and ii) the recapture probability rho. These are treated as fixed effects with no temporal or individual/group
variation.
z = jnp.ones(N, dtype=jnp.int32)
# we use this mask to eliminate extraneous log probabilities
# that arise for a given individual before its first capture.
first_capture_mask = capture_history[:, 0].astype(bool)
# NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
In our second model variant there is a time-varying survival probability phi_t for T-1 of the T time periods of the
capture data; each phi_t is treated as a fixed effect.
z = jnp.ones(N, dtype=jnp.int32)
# we use this mask to eliminate extraneous log probabilities
# that arise for a given individual before its first capture.
first_capture_mask = capture_history[:, 0].astype(bool)
# NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
In our third model variant there is a survival probability phi_t for T-1 of the T time periods of the capture data (just
like in model_2), but here each phi_t is treated as a random effect.
def model_3(capture_history, sex):
N, T = capture_history.shape
phi_mean = numpyro.sample("phi_mean", dist.Uniform(0.0, 1.0))  # mean survival probability
phi_logit_mean = logit(phi_mean)
# controls temporal variability of survival probability
phi_sigma = numpyro.sample("phi_sigma", dist.Uniform(0.0, 10.0))
rho = numpyro.sample("rho", dist.Uniform(0.0, 1.0)) # recapture probability
phi_t = expit(phi_logit_t)
with numpyro.plate("animals", N, dim=-1):
with handlers.mask(mask=first_capture_mask):
mu_z_t = first_capture_mask * phi_t * z + (1 - first_capture_mask)
# NumPyro exactly sums out the discrete states z_t.
z = numpyro.sample("z", dist.Bernoulli(dist.util.clamp_probs(mu_z_t)))
mu_y_t = rho * z
numpyro.sample("y", dist.Bernoulli(dist.util.clamp_probs(mu_y_t)),
˓→obs=y)
z = jnp.ones(N, dtype=jnp.int32)
# we use this mask to eliminate extraneous log probabilities
# that arise for a given individual before its first capture.
first_capture_mask = capture_history[:, 0].astype(bool)
# NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
In our fourth model variant we include group-level fixed effects for sex (male, female).
z = jnp.ones(N, dtype=jnp.int32)
# we use this mask to eliminate extraneous log probabilities
# that arise for a given individual before its first capture.
first_capture_mask = capture_history[:, 0].astype(bool)
# NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
In our final model variant we include both fixed group effects and fixed time effects for the survival probability phi: logit(phi_t) = beta_group + gamma_t. We need to take care that the model is not overparameterized; to do this, we effectively let a single scalar beta encode the difference between male and female survival probabilities.
z = jnp.ones(N, dtype=jnp.int32)
# we use this mask to eliminate extraneous log probabilities
# that arise for a given individual before its first capture.
first_capture_mask = capture_history[:, 0].astype(bool)
# NB swapaxes: we move time dimension of `capture_history` to the front to scan over it
Do inference
models = {name[len('model_'):]: model
for name, model in globals().items()
if name.startswith('model_')}
def main(args):
# load data
if args.dataset == "dipper":
capture_history, sex = load_dataset(DIPPER_VOLE, split='dipper', shuffle=False)[1]()
N, T = capture_history.shape
print("Loaded {} capture history for {} individuals collected over {} time
˓→periods.".format(
model = models[args.model]
rng_key = random.PRNGKey(args.rng_seed)
run_inference(model, capture_history, sex, rng_key, args)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="CJS capture-recapture model for
˓→ecological data")
CHAPTER 12
Bayesian Imputation for Missing Values in Discrete Covariates
Missing data is a very widespread problem in practical applications, both in covariates ('explanatory variables') and outcomes. When performing Bayesian inference with MCMC, imputing discrete missing values is not possible using Hamiltonian Monte Carlo techniques. One way around this problem is to create a new model that enumerates the discrete variables and does inference over the new model, which, for a single discrete variable, is a mixture model (see e.g. Stan's user guide on Latent Discrete Parameters). Enumerating the discrete latent sites requires some manual math work that can get tedious for complex models. Inference by automatic enumeration of discrete variables is implemented in NumPyro and allows for a very convenient way of dealing with missing discrete data.
[1]: import numpyro
from jax import numpy as jnp, random, ops
from jax.scipy.special import expit
from numpyro import distributions as dist, sample
from numpyro.infer.mcmc import MCMC
from numpyro.infer.hmc import NUTS
from math import inf
from graphviz import Digraph
First we will simulate data with correlated binary covariates. The assumption is that we wish to estimate parameter
for some parametric model without bias (e.g. for inferring a causal effect). For several different missing data patterns
we will see how to impute the values to lead to unbiased models.
The basic data structure is as follows. Z is a latent variable that gives rise to the marginal dependence between A and
B, the observed covariates. We will consider different missing data mechanisms for variable A, where variable B and
Y are fully observed. The effects of A and B on Y are the effects of interest.
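Before drawing the graph below, the data might be simulated along these lines; this is a sketch, and the coefficients, sample size, and helper names are assumptions rather than the tutorial's exact code.

import jax.numpy as jnp
from jax import random
from jax.scipy.special import expit

rng_key = random.PRNGKey(42)
simkeys = random.split(rng_key, 10)
ndata = 1000

# shared latent Z induces marginal dependence between the binary covariates A and B
Z = random.normal(simkeys[0], (ndata,))
A = random.bernoulli(simkeys[1], expit(Z)).astype(jnp.int32)
B = random.bernoulli(simkeys[2], expit(Z)).astype(jnp.int32)

# outcome Y depends on A and B; b_A and b_B are the effects of interest
b_A, b_B = 0.5, -0.5
Y = b_A * A + b_B * B + random.normal(simkeys[3], (ndata,))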
[2]: dot = Digraph()
dot.node('A')
dot.node('B')
dot.node('Z')
dot.node('Y')
According to Rubin's classic definitions, there are 3 distinct missing data mechanisms:
1. missing completely at random (MCAR)
2. missing at random, conditional on observed data (MAR)
3. missing not at random, even after conditioning on observed data (MNAR)
Missing data mechanisms 1. and 2. are ‘easy’ to handle as they depend on observed data only. Mechanism 3. (MNAR)
is trickier as it depends on data that is not observed, but may still be relevant to the outcome you are modeling (see
below for a concrete example).
First we will generate missing values in A, conditional on the value of Y (thus it is a MAR mechanism).
This graph depicts the data-generating mechanism, where Y is the only cause of missingness in A, denoted M. This means that the missingness M is random, conditional on Y.
As an example consider this simplified scenario:
• A represents a history of heart illness
• B represents the age of a patient
• Y represents whether or not the patient will visit the general practitioner
A general practitioner wants to find out why patients that are assigned to her clinic will visit the clinic or not. She
thinks that having a history of heart illness and age are potential causes of doctor visits. Data on patient ages are
available through their registration forms, but information on prior heart illness may be available only after they have visited the clinic. This makes the missingness in A (history of heart disease) dependent on the outcome (visiting the
clinic).
[5]: A_isobs = random.bernoulli(simkeys[4], expit(3*(Y - Y.mean())))
Aobs = jnp.where(A_isobs, A, -1)
A_obsidx = jnp.where(A_isobs)
Number of divergences: 0
# cancel out enumerated values that are not equal to observed values
log_prob = jnp.where(A_isobs & (Aimp != A), -inf, log_prob)
Number of divergences: 0
As we can see, when data are missing conditionally on Y, imputation leads to consistent estimation of the parameter
of interest (b_A and b_B).
When data are missing conditional on unobserved data, things get trickier. Here we will generate missing values in A, conditional on the value of A itself: missing not at random (MNAR) with respect to the observed data, though missing at random conditional on A.
As an example consider patients who have cancer:
Number of divergences: 0
Number of divergences: 0
Perhaps surprisingly, imputing missing values when the missingness mechanism depends on the variable itself will
actually lead to bias, while complete case analysis is unbiased! See e.g. Bias and efficiency of multiple imputation
compared with complete-case analysis for missing covariate values.
However, complete case analysis may be undesirable as well, e.g. because it leads to lower precision in estimating the parameter from B to Y, or because there is an expected interaction between the value of A and the parameter from A to Y. To deal with this situation, an explicit model for the reason of missingness (or observation) is required. We will add one below.
[14]: def impmissmodel(A, B, Y):
ntotal = A.shape[0]
A_isobs = A >= 0
# cancel out enumerated values that are not equal to observed values
log_prob = jnp.where(A_isobs & (Aimp != A), -inf, log_prob)
Number of divergences: 0
We can now estimate the parameters b_A and b_B without bias, while still utilizing all observations. Obviously, modeling the missingness mechanism relies on assumptions that need to either be substantiated with prior evidence or analyzed through sensitivity analysis.
For more reading on missing data in Bayesian inference, see:
• Presentation Bayesian Methods for missing data (pdf)
• Bayesian Approaches for Missing Not at Random Outcome Data: The Role of Identifying Restrictions
(doi:10.1214/17-STS630)
CHAPTER 13
In this tutorial, we will demonstrate how to build a model for time series forecasting in NumPyro. Specifically, we
will replicate the Seasonal, Global Trend (SGT) model from the Rlgt: Bayesian Exponential Smoothing Models with
Trend Modifications package. The time series data that we will use for this tutorial is the lynx dataset, which contains
annual numbers of lynx trappings from 1821 to 1934 in Canada.
[1]: import os
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro.diagnostics import autocorrelation, hpdi
from numpyro.infer import MCMC, NUTS, Predictive
if "NUMPYRO_SPHINXBUILD" in os.environ:
set_matplotlib_formats("svg")
numpyro.set_host_device_count(4)
assert numpyro.__version__.startswith("0.5.0")
13.1 Data
The time series has a length of 114 (a data point for each year), and by looking at the plot, we can observe seasonality in this dataset, i.e. the recurrence of similar patterns at specific time periods. For example, in this dataset we observe a cyclical pattern every 10 years, but there is also a less obvious yet clear spike in the number of trappings every 40 years. Let us see if we can model this effect in NumPyro.
In this tutorial, we will use the first 80 values for training and the last 34 values for testing.
13.2 Model
The model we are going to use is called Seasonal, Global Trend (SGT); when tested on the 3,003 time series of the M-3 competition, it has been shown to outperform other models that originally participated in the competition:
A more detailed explanation for SGT model can be found in this vignette from the authors of the Rlgt package. Here
we summarize the core ideas of this model:
• Student’s t-distribution, which has heavier tails than normal distribution, is used for the likelihood.
• The expected value exp_val consists of a trending component and a seasonal component:
• The trend is governed by the map 𝑥 ↦→ 𝑥 + 𝑎𝑥^𝑏, where 𝑥 is level, 𝑎 is coef_trend, and 𝑏 is pow_trend. Note that when 𝑏 ∼ 0, the trend is linear with slope 𝑎, and when 𝑏 ∼ 1, the trend is exponential with rate 𝑎. Hence this function covers a large family of trends.
• When time changes, level and s are updated to new values. Coefficients level_sm and s_sm are used to
make the transition smoothly.
• When powx is near 0, the error 𝜎_𝑡 will be nearly constant, while when powx is near 1, the error will be proportional to the expected value.
• There are several varieties of SGT. In this tutorial, we use generalized seasonality and seasonal average method.
We are ready to specify the model using NumPyro primitives. In NumPyro, we use the primitive sample(name,
prior) to declare a latent random variable with a corresponding prior. These primitives can have custom inter-
pretations depending on the effect handlers that are used by NumPyro inference algorithms in the backend. e.g. we
can condition on specific values using the condition handler, or record values at these sample sites in the execu-
tion trace using the trace handler. Note that these details are not important for specifying the model, or running
inference, but curious readers are encouraged to read the tutorial on effect handlers in Pyro.
[4]: def sgt(y, seasonality, future=0):
# heuristically, the standard deviation of the Cauchy prior depends on
# the max value of data
cauchy_sd = jnp.max(y) / 150
moving_sum = (
moving_sum + y[t] - jnp.where(t >= seasonality, y[t - seasonality], 0.0)
N = y.shape[0]
level_init = y[0]
s_init = jnp.concatenate([init_s[1:], init_s[:1]], axis=0)
moving_sum = level_init
with numpyro.handlers.condition(data={"y": y[1:]}):
_, ys = scan(
transition_fn, (level_init, s_init, moving_sum), jnp.arange(1, N + future)
)
if future > 0:
numpyro.deterministic("y_forecast", ys[-future:])
Note that level and s are updated recursively while we collect the expected value at each time step. NumPyro uses
JAX in the backend to JIT compile many critical parts of the NUTS algorithm, including the verlet integrator and the
tree building process. However, doing so using Python’s for loop in the model will result in a long compilation time
for the model, so we use scan, which is a wrapper of lax.scan with support for NumPyro primitives and handlers. A
detailed explanation for using this utility can be found in NumPyro documentation. Here we use it to collect y values
while the triple (level, s, moving_sum) plays the role of carrying state.
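As a standalone illustration of the scan pattern (this is not the sgt model itself; the AR(1) model and its names are assumptions), the carry is threaded through the transition function while one value is emitted per step:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan

def ar1(y):
    rho = numpyro.sample("rho", dist.Uniform(-1., 1.))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.))

    def transition_fn(carry, t):
        prev = carry
        # inside scan, this sample site automatically gets a leading time dimension
        curr = numpyro.sample("y", dist.Normal(rho * prev, sigma))
        return curr, curr

    # condition the scanned "y" site on the observed series (all but the first value)
    with numpyro.handlers.condition(data={"y": y[1:]}):
        _, ys = scan(transition_fn, y[0], jnp.arange(1, y.shape[0]))
    return ys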
Another note is that instead of declaring the observation site y in transition_fn, we have used the condition handler here. The reason is that we also want to use this model for forecasting. In forecasting,
future values of y are non-observable, so obs=y[t] does not make sense when t >= len(y) (caution: index
out-of-bound errors do not get raised in JAX, e.g. jnp.arange(3)[10] == 2). Using condition, when the
length of scan is larger than the length of the conditioned/observed site, unobserved values will be sampled from the
distribution of that site.
13.3 Inference
First, we want to choose a good value for seasonality. Following the demo in Rlgt, we will set
seasonality=38. Indeed, this value can be guessed by looking at the plot of the training data, where the sec-
ond order seasonality effect has a periodicity around 40 years. Note that 38 is also one of the highest-autocorrelation
lags.
[ 0 67 57 38 68 1 29 58 37 56 28 10 19 39 66 78 47 77 9 79 48 76 30 18
20 11 46 59 69 27 55 36 2 8 40 49 17 21 75 12 65 45 31 26 7 54 35 41
50 3 22 60 70 16 44 13 6 25 74 53 42 32 23 43 51 4 15 14 34 24 5 52
73 64 33 71 72 61 63 62]
Now, let us run 4 MCMC chains (using the No-U-Turn Sampler algorithm) with 5000 warmup steps and 5000 sampling steps per chain. The returned value will be a collection of 20000 samples.
[6]: %%time
kernel = NUTS(sgt)
mcmc = MCMC(kernel, num_warmup=5000, num_samples=5000, num_chains=4)
mcmc.run(random.PRNGKey(0), y_train, seasonality=38)
mcmc.print_summary()
samples = mcmc.get_samples()
13.4 Forecasting
Given samples from mcmc, we want to do forecasting for the testing dataset y_test. NumPyro provides a conve-
nient utility Predictive to get predictive distribution. Let’s see how to use it to get forecasting values.
Notice that in the sgt model defined above, there is a keyword future which controls the execution of the model
- depending on whether future > 0 or future == 0. The following code predicts the last 34 values from the
original time-series.
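A sketch of that prediction step (assuming the variable names y_train and y_test from the train/test split above):

from jax import random
from numpyro.infer import Predictive

predictive = Predictive(sgt, samples, return_sites=["y_forecast"])
forecast_marginal = predictive(random.PRNGKey(1), y_train, seasonality=38,
                               future=len(y_test))["y_forecast"]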
Let’s get sMAPE, root mean square error of the prediction, and visualize the result with the mean prediction and the
90% highest posterior density interval (HPDI).
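A sketch of those computations (the sMAPE definition used here, scaled to a percentage, is an assumption):

import jax.numpy as jnp
from numpyro.diagnostics import hpdi

y_pred = jnp.mean(forecast_marginal, axis=0)
sMAPE = jnp.mean(jnp.abs(y_pred - y_test) / (y_pred + y_test)) * 200
rmse = jnp.sqrt(jnp.mean((y_pred - y_test) ** 2))
forecast_hpdi = hpdi(forecast_marginal, prob=0.9)   # 90% HPDI band for plotting
print("sMAPE: {:.2f}, RMSE: {:.2f}".format(sMAPE, rmse))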
Finally, let’s plot the result to verify that we get the expected one.
As we can observe, the model has been able to learn both the first and second order seasonality effects, i.e. a cyclical
pattern with a periodicity of around 10, as well as spikes that can be seen once every 40 or so years. Moreover, we
not only have point estimates for the forecast but can also use the uncertainty estimates from the model to bound our
forecasts.
13.5 Acknowledgements
We would like to thank Slawek Smyl for many helpful resources and suggestions. Fast inference would not have been
possible without the support of JAX and the XLA teams, so we would like to thank them for providing such a great
open-source platform for us to build on, and for their responsiveness in dealing with our feature requests and bug
reports.
13.6 References
[1] Rlgt: Bayesian Exponential Smoothing Models with Trend Modifications, Slawek
Smyl, Christoph Bergmeir, Erwin Wibowo, To Wang Ng, Trustees of Columbia University
Ordinal Regression
Some data are discrete but intrinsically ordered; these are called **ordinal** data. One example is the Likert scale for questionnaires ("this is an informative tutorial": 1. strongly disagree / 2. disagree / 3. neither agree nor disagree / 4. agree / 5. strongly agree). Ordinal data is also ubiquitous in the medical world (e.g. the Glasgow Coma Scale for measuring neurological dysfunction).
This poses a challenge for statistical modeling, as the data do not fit the most well-known modelling approaches (e.g. linear regression). Modeling the data as categorical is one possibility, but it disregards the inherent ordering in the data and may be less statistically efficient. There are multiple approaches for modeling ordered data. Here we will show how to use the OrderedLogistic distribution with cutpoints that are sampled from a Normal distribution, with the additional constraint that the cutpoints are ordered. For a more in-depth discussion of Bayesian modeling of ordinal data, see e.g. Michael Betancourt's blog.
[1]: from jax import numpy as np, random
import numpyro
from numpyro import sample
from numpyro.distributions import (Categorical, ImproperUniform, Normal, OrderedLogistic,
                                   TransformedDistribution, constraints, transforms)
from numpyro.infer import MCMC, NUTS
for i in range(nclasses):
print(f"mean(X) for Y == {i}: {X[np.where(Y==i)].mean():.3f}")
value counts of Y:
1 19
2 16
0 15
Name: Y, dtype: int64
mean(X) for Y == 0: 0.042
mean(X) for Y == 1: 0.832
mean(X) for Y == 2: 1.448
We will model the outcomes Y as coming from an OrderedLogistic distribution, conditional on X. The
OrderedLogistic distribution in numpyro requires ordered cutpoints. We can use the ImproperUniform distribution to introduce a parameter with an arbitrary support that is otherwise completely uninformative, and then add an ordered_vector constraint.
[4]: def model1(X, Y, nclasses=3):
b_X_eta = sample('b_X_eta', Normal(0, 5))
c_y = sample('c_y', ImproperUniform(support=constraints.ordered_vector,
batch_shape=(),
event_shape=(nclasses-1,)))
with numpyro.plate('obs', X.shape[0]):
eta = X * b_X_eta
sample('Y', OrderedLogistic(eta, c_y), obs=Y)
mcmc_key = random.PRNGKey(1234)
kernel = NUTS(model1)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc.run(mcmc_key, X,Y, nclasses)
mcmc.print_summary()
sample: 100%| | 1000/1000 [00:07<00:00, 126.55it/s, 7 steps of size 4.34e-01. acc. prob=0.95]
Number of divergences: 0
The ImproperUniform distribution allows us to use parameters with constraints on their domain, without adding
any additional information e.g. about the location or scale of the prior distribution on that parameter.
If we want to incorporate such information, for instance that the values of the cut-points should not be too far from
zero, we can add an additional sample statement that uses another prior, coupled with an obs argument. In the
example below we first sample cutpoints c_y from the ImproperUniform distribution with constraints.ordered_vector as before, and then sample a dummy parameter from a Normal distribution while conditioning on c_y using obs=c_y. Effectively, we've created an improper / unnormalized prior that results from restricting the support of a Normal distribution to the ordered domain.
[5]: def model2(X, Y, nclasses=3):
b_X_eta = sample('b_X_eta', Normal(0, 5))
c_y = sample('c_y', ImproperUniform(support=constraints.ordered_vector,
batch_shape=(),
event_shape=(nclasses-1,)))
sample('c_y_smp', Normal(0,1), obs=c_y)
with numpyro.plate('obs', X.shape[0]):
eta = X * b_X_eta
sample('Y', OrderedLogistic(eta, c_y), obs=Y)
kernel = NUTS(model2)
mcmc = MCMC(kernel, num_warmup=250, num_samples=750)
mcmc.run(mcmc_key, X,Y, nclasses)
mcmc.print_summary()
sample: 100%| | 1000/1000 [00:03<00:00, 315.02it/s, 7 steps of size 4.80e-01. acc. prob=0.94]
Number of divergences: 0
If having a proper prior for those cutpoints c_y is desirable (e.g. to sample from that prior and get prior predictive),
we can use TransformedDistribution with an OrderedTransform transform as follows.
[6]: def model3(X, Y, nclasses=3):
         b_X_eta = sample('b_X_eta', Normal(0, 5))
         c_y = sample("c_y", TransformedDistribution(Normal(0, 1).expand([nclasses - 1]),
                                                     transforms.OrderedTransform()))
         with numpyro.plate('obs', X.shape[0]):
             eta = X * b_X_eta
             sample('Y', OrderedLogistic(eta, c_y), obs=Y)
kernel = NUTS(model3)
Number of divergences: 0
Bayesian Imputation
Real-world datasets often contain many missing values. In those situations, we have to either remove those missing data (also known as "complete case" analysis) or replace them by some values. Though using complete cases is pretty straightforward, it is only applicable when the number of missing entries is so small that throwing away those entries would not much affect the power of the analysis we are conducting on the data. The second strategy, also known as imputation, is more widely applicable and will be our focus in this tutorial.
Probably the most popular way to perform imputation is to fill a missing value with the mean, median, or mode of its
corresponding feature. In that case, we implicitly assume that the feature containing missing values has no correlation
with the remaining features of our dataset. This is a pretty strong assumption and might not be true in general. In
addition, it does not encode any uncertainty that we might put on those values. Below, we will construct a Bayesian
setting to resolve those issues. In particular, given a model on the dataset, we will
• create a generative model for the feature with missing value
• and consider missing values as unobserved latent variables.
[1]: # first, we need some imports
     import os

     import matplotlib.pyplot as plt
     import numpy as np
     import pandas as pd

     from jax import ops, random
     import jax.numpy as jnp

     import numpyro
     from numpyro import distributions as dist
     from numpyro.distributions import constraints
     from numpyro.infer import MCMC, NUTS, Predictive

     plt.style.use("seaborn")
assert numpyro.__version__.startswith("0.5.0")
15.1 Dataset
The data is taken from the competition Titanic: Machine Learning from Disaster hosted on Kaggle. It contains
information about passengers in the Titanic accident such as name, age, and gender, and our target is to predict whether
a person is likely to survive.
Looking at the data info, we see that there are missing values in the Age, Cabin, and Embarked columns. Although
Cabin is an important feature (because the position of a cabin in the ship can affect the chance of survival for people
in that cabin), we will skip it in this tutorial for simplicity. In the dataset, there are many categorical columns and two
numerical columns, Age and Fare. Let's first look at the distribution of those categorical columns:
3 491
1 216
2 184
Name: Pclass, dtype: int64
male 577
female 314
Name: Sex, dtype: int64
0 608
1 209
2 28
4 18
3 16
8 7
5 5
Name: SibSp, dtype: int64
0 678
1 118
2 80
5 5
3 5
4 4
6 1
Name: Parch, dtype: int64
S 644
C 168
Q 77
Name: Embarked, dtype: int64
First, we will merge rare groups in the SibSp and Parch columns. In addition, we'll fill the 2 missing entries in
Embarked with the mode S. Note that we could build a generative model for those missing entries in Embarked as
well, but we skip doing so for simplicity.
Looking closer at the data, we can observe that each name contains a title. We know that age is correlated with the title
of the name: e.g. those with the Mrs. title tend to be older than those with the Miss. title (on average), so it might be
good to create that feature. The distribution of titles is:
We will make a new column Title, where rare titles are merged into one group Misc..
[6]: train_df["Title"] = (
train_df.Name.str.split(", ")
.str.get(1)
.str.split(" ")
.str.get(0)
.apply(lambda x: x if x in ["Mr.", "Miss.", "Mrs.", "Master."] else "Misc.")
)
Now we are ready to turn the dataframe, which includes categorical values, into NumPy arrays. We also perform
standardization (a good practice for regression models) for the Age column.
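A minimal preprocessing sketch along these lines is shown below. The column and variable names (AgeStd, the
integer-coded categoricals, and the reuse of train_df from above) are illustrative and not necessarily the tutorial's
exact code.

# a minimal preprocessing sketch, assuming `train_df` is the pandas DataFrame above
age_mean, age_std = train_df["Age"].mean(), train_df["Age"].std()
train_df["AgeStd"] = (train_df["Age"] - age_mean) / age_std   # standardize Age

# encode categorical columns as integer codes and extract NumPy arrays
title = train_df["Title"].astype("category").cat.codes.values
sex = train_df["Sex"].astype("category").cat.codes.values
embarked = train_df["Embarked"].astype("category").cat.codes.values
age = train_df["AgeStd"].values        # still contains NaN where Age is missing
survived = train_df["Survived"].values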
15.3 Modelling
An equivalent way to model the partially observed data x is:

def model1b():
    x = numpyro.sample("x", dist.Normal(0, 1).expand([10]).mask(False))
    numpyro.sample("x_obs", dist.Normal(0, 1).expand([10]), obs=x)

Both approaches to model the partially observed data x are equivalent. For the model below, we will use the latter
method.
[8]: def model(age, pclass, title, sex, sibsp, parch, embarked, survived=None,
               bayesian_impute=True):
         ...
         if bayesian_impute:
             ...
             age = ops.index_update(age, age_nanidx, age_impute)
             numpyro.sample("age", dist.Normal(age_mu, age_sigma), obs=age)
         else:
             # fill missing data by the mean of ages for each title
             age_impute = age_mean_by_title[title][age_nanidx]
             age = ops.index_update(age, age_nanidx, age_impute)
         ...
Note that in the model, the prior for age is dist.Normal(age_mu, age_sigma), where the values of age_mu
and age_sigma depend on title. Because there are missing values in age, we will encode those missing values
in the latent parameter age_impute. Then we can replace NaN entries in age with the vector age_impute.
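The snippet below isolates that imputation pattern in a self-contained sketch. It assumes age is a NumPy array with
NaN entries, that age_mu and age_sigma are already per-passenger arrays (e.g. indexed by title), and it uses the older
jax.ops.index_update API that this tutorial targets.

import numpy as np
from jax import ops
import numpyro
import numpyro.distributions as dist

# indices of missing entries, precomputed with plain NumPy so shapes stay static
age_nanidx = np.nonzero(np.isnan(age))[0]

# inside the model: sample one latent value per missing entry; mask(False) keeps
# this sample statement from contributing its own log-density
age_impute = numpyro.sample(
    "age_impute",
    dist.Normal(age_mu[age_nanidx], age_sigma[age_nanidx]).mask(False),
)
# write the imputed values into the NaN slots, then score the full vector
age = ops.index_update(age, age_nanidx, age_impute)
numpyro.sample("age", dist.Normal(age_mu, age_sigma), obs=age)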
15.4 Sampling
We will use MCMC with the NUTS kernel to sample both the regression coefficients and the imputed values.
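A run along the following lines produces the summary below; the warmup and sample counts here are illustrative
choices rather than necessarily the ones used in the original notebook.

mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), age, pclass, title, sex, sibsp, parch, embarked, survived)
mcmc.print_summary()
posterior = mcmc.get_samples()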
Number of divergences: 0
To double-check that the assumption "age is correlated with title" is reasonable, let's look at the inferred age by title.
Recall that we performed standardization on age, so here we need to scale back to the original domain.
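For example, assuming the posterior contains a per-title site age_mu on the standardized scale, the scale-back is just:

# posterior["age_mu"] is assumed to hold one standardized mean per title
inferred_age_by_title = age_mean + age_std * posterior["age_mu"].mean(axis=0)
print(inferred_age_by_title)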
The inferred result confirms our assumption that Age is correlated with Title:
• those with the Master. title have a pretty small age (in other words, they are children on the ship) compared to the
other groups,
• those with the Mrs. title have a larger age than those with the Miss. title (on average).
We can also see that the result is similar to the actual statistical mean of Age given Title in our training dataset:
[11]: train_df.groupby("Title")["Age"].mean()
[11]: Title
Master. 4.574167
Misc. 42.384615
Miss. 21.773973
Mr. 32.368090
Mrs. 35.898148
Name: Age, dtype: float64
So far so good; we have a lot of information about the regression coefficients, as well as the imputed values and their
uncertainties. Let's inspect those results a bit:
• The mean value -0.44 of b_Age implies that passengers with smaller ages have a better chance to survive.
• The mean value (1.11, -1.07) of b_Sex implies that female passengers have a higher chance to survive
than male passengers.
15.5 Prediction
In NumPyro, we can use the Predictive utility to make predictions from posterior samples. Let's check how well the
model performs on the training dataset. For simplicity, we will get a survived prediction for each posterior sample
and then apply a majority rule over those predictions.
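A sketch of that procedure is given below; the site name "survived" and the argument list mirror the model above
and are assumptions about the notebook's exact code.

posterior = mcmc.get_samples()
# one predicted label per posterior sample, shape (num_samples, num_passengers)
pred = Predictive(model, posterior)(
    random.PRNGKey(1), age, pclass, title, sex, sibsp, parch, embarked
)["survived"]
# majority rule: predict survival if more than half of the posterior samples say so
survived_pred = (pred.mean(axis=0) >= 0.5).astype(jnp.uint8)
print("Accuracy:", (survived_pred == survived).mean())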
This is a pretty good result using a simple logistic regression model. Let’s see how the model performs if we don’t use
Bayesian imputation here.
Accuracy: 0.82042646
[13]: predict         0         1
      actual
      0        0.872495  0.204678
      1        0.163934  0.736842
We can see that Bayesian imputation does a little bit better here.
Remark. When using posterior samples to perform prediction on new data, we need to marginalize out
age_impute because those imputed values are specific to the training data:
posterior.pop("age_impute")
survived_pred = Predictive(model, posterior)(random.PRNGKey(3), **new_data)
15.6 References
1. McElreath, R. (2016). Statistical Rethinking: A Bayesian Course with Examples in R and Stan.
2. Kaggle competition: Titanic: Machine Learning from Disaster
In this example we show how to use NUTS to sample from the posterior over the hyperparameters of a Gaussian
process.
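Before walking through the full script, here is a minimal sketch of such a model: an RBF kernel whose variance,
length scale, and observation noise get LogNormal priors, with the observations modeled as a multivariate normal.
The kernel form and prior choices below are illustrative.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def rbf_kernel(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
    # squared-exponential kernel on 1-d inputs
    deltaXsq = jnp.power((X[:, None] - Z) / length, 2.0)
    k = var * jnp.exp(-0.5 * deltaXsq)
    if include_noise:
        k += (noise + jitter) * jnp.eye(X.shape[0])
    return k

def gp_model(X, Y):
    var = numpyro.sample("kernel_var", dist.LogNormal(0.0, 1.0))
    length = numpyro.sample("kernel_length", dist.LogNormal(0.0, 1.0))
    noise = numpyro.sample("kernel_noise", dist.LogNormal(0.0, 1.0))
    k = rbf_kernel(X, X, var, length, noise)
    numpyro.sample("Y", dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]),
                                                covariance_matrix=k), obs=Y)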
import argparse
import os
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import jax
from jax import vmap
import jax.numpy as jnp
import jax.random as random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import (MCMC, NUTS, init_to_feasible, init_to_median,
                           init_to_sample, init_to_uniform, init_to_value)
# compute kernel
k = kernel(X, X, var, length, noise)
obs=Y)
# do GP prediction for a given set of hyperparameters. this makes use of the well-known
return X, Y, X_test
def main(args):
    X, Y, X_test = get_data(N=args.num_data)

    # do inference
    rng_key, rng_key_predict = random.split(random.PRNGKey(0))
    samples = run_inference(model, args, rng_key, X, Y)

    # do prediction
    vmap_args = (random.split(rng_key_predict, samples['kernel_var'].shape[0]),
                 samples['kernel_var'], samples['kernel_length'], samples['kernel_noise'])

    # make plots
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
    plt.savefig("gp_plot.pdf")
if __name__ == "__main__":
    assert numpyro.__version__.startswith('0.5.0')
    parser = argparse.ArgumentParser(description="Gaussian Process example")
    parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
    parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)
    parser.add_argument("--num-chains", nargs='?', default=1, type=int)
    parser.add_argument("--thinning", nargs='?', default=2, type=int)
    parser.add_argument("--num-data", nargs='?', default=25, type=int)
    parser.add_argument("--device", default='cpu', type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)
We demonstrate how to use NUTS to do inference on a simple (small) Bayesian neural network with two hidden
layers.
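As a minimal sketch of what such a model looks like (the priors, layer sizes, and tanh nonlinearity below are
illustrative choices):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def nonlin(x):
    return jnp.tanh(x)

def bnn(X, Y, D_H):
    D_X, D_Y = X.shape[1], 1
    # unit normal priors on all weights
    w1 = numpyro.sample("w1", dist.Normal(jnp.zeros((D_X, D_H)), jnp.ones((D_X, D_H))))
    z1 = nonlin(jnp.matmul(X, w1))                       # first hidden layer
    w2 = numpyro.sample("w2", dist.Normal(jnp.zeros((D_H, D_H)), jnp.ones((D_H, D_H))))
    z2 = nonlin(jnp.matmul(z1, w2))                      # second hidden layer
    w3 = numpyro.sample("w3", dist.Normal(jnp.zeros((D_H, D_Y)), jnp.ones((D_H, D_Y))))
    z3 = jnp.matmul(z2, w3)                              # network output
    # observation noise via a Gamma prior on the precision
    prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
    sigma_obs = 1.0 / jnp.sqrt(prec_obs)
    numpyro.sample("Y", dist.Normal(z3, sigma_obs), obs=Y)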
import argparse
import os
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
    # sample first layer (we put unit normal priors on all weights)
    w1 = numpyro.sample("w1", dist.Normal(jnp.zeros((D_X, D_H)), jnp.ones((D_X, D_H))))  # D_X D_H
    z1 = nonlin(jnp.matmul(X, w1))  # N D_H  <= first layer of activations

    # observe data
    numpyro.sample("Y", dist.Normal(z3, sigma_obs), obs=Y)
    return X, Y, X_test

def main(args):
    N, D_X, D_H = args.num_data, 3, args.num_hidden
    X, Y, X_test = get_data(N=N, D_X=D_X)

    # do inference
    rng_key, rng_key_predict = random.split(random.PRNGKey(0))
    samples = run_inference(model, args, rng_key, X, Y, D_H)

    predictions = predictions[..., 0]

    # make plots
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
    plt.savefig('bnn_plot.pdf')
if __name__ == "__main__":
    assert numpyro.__version__.startswith('0.5.0')
    parser = argparse.ArgumentParser(description="Bayesian neural network example")
    parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int)
    parser.add_argument("--num-warmup", nargs='?', default=1000, type=int)
    parser.add_argument("--num-chains", nargs='?', default=1, type=int)
    parser.add_argument("--num-data", nargs='?', default=100, type=int)
    parser.add_argument("--num-hidden", nargs='?', default=5, type=int)
    parser.add_argument("--device", default='cpu', type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)
We demonstrate how to do (fully Bayesian) sparse linear regression using the approach described in [1]. This approach
is particularly suitable for situations with many feature dimensions (large P) but not too many datapoints (small N). In
particular we consider a quadratic regressor of the form (a small helper that evaluates this regressor is sketched after
the references):
f(X) = \text{constant} + \sum_i \theta_i X_i + \sum_{i<j} \theta_{ij} X_i X_j + \text{observation noise}
References:
1. Raj Agrawal, Jonathan H. Huggins, Brian Trippe, Tamara Broderick (2019), “The Kernel Interaction Trick: Fast
Bayesian Discovery of Pairwise Interactions in High Dimensions”, (https://arxiv.org/abs/1905.06501)
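As a small illustration of this functional form, the helper below evaluates f(X) given linear coefficients theta and a
matrix Theta whose strict upper triangle holds the pairwise coefficients; it is not part of the example script.

import jax.numpy as jnp

def quadratic_regressor(X, theta, Theta, constant=0.0):
    """Evaluate constant + sum_i theta_i X_i + sum_{i<j} Theta_ij X_i X_j row-wise."""
    iu, ju = jnp.triu_indices(X.shape[1], k=1)
    linear = X @ theta
    pairwise = (X[:, iu] * X[:, ju]) @ Theta[iu, ju]
    return constant + linear + pairwise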
import argparse
import itertools
import os
import time
import numpy as np
import jax
from jax import vmap
import jax.numpy as jnp
import jax.random as random
from jax.scipy.linalg import cho_factor, cho_solve, solve_triangular
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
# Most of the model code is concerned with constructing the sparsity inducing prior.
def model(X, Y, hypers):
    S, P, N = hypers['expected_sparsity'], X.shape[1], X.shape[0]

    # compute kernel
    kX = kappa * X
    k = kernel(kX, kX, eta1, eta2, hypers['c']) + sigma ** 2 * jnp.eye(N)
    assert k.shape == (N, N)

        obs=Y)
# Compute the mean and variance of coefficient theta_i (where i = dimension) for a
# MCMC sample of the kernel hyperparameters (eta1, xisq, ...).
# Compare to theorem 5.1 in reference [1].
def compute_singleton_mean_variance(X, Y, dimension, msq, lam, eta1, xisq, c, sigma):
    P, N = X.shape[1], X.shape[0]

    kX = kappa * X
    kprobe = kappa * probe
# Compute the mean and variance of coefficient theta_ij for a MCMC sample of the
# kernel hyperparameters (eta1, xisq, ...). Compare to theorem 5.1 in reference [1].
def compute_pairwise_mean_variance(X, Y, dim1, dim2, msq, lam, eta1, xisq, c, sigma):
    P, N = X.shape[1], X.shape[0]

    kX = kappa * X
    kprobe = kappa * probe
# Sample coefficients theta from the posterior for a given MCMC sample.
# The first P returned values are {theta_1, theta_2, ...., theta_P}, while
# the remaining values are {theta_ij} for i,j in the list `active_dims`,
# sorted so that i < j.
def sample_theta_space(X, Y, active_dims, msq, lam, eta1, xisq, c, sigma):
    P, N, M = X.shape[1], X.shape[0], len(active_dims)
    # the total number of coefficients we return
start1 += 2
start2 += 1
kX = kappa * X
kprobe = kappa * probe
return sample
    X = np.random.randn(N, P)
    # generate S coefficients with non-negligible magnitude
    W = 0.5 + 2.5 * np.random.rand(S)
    # generate data using the S coefficients and a single pairwise interaction
    Y = np.sum(X[:, 0:S] * W, axis=-1) + X[:, 0] * X[:, 1] + sigma_obs * np.random.randn(N)
    Y -= jnp.mean(Y)
    Y_std = jnp.std(Y)
# Helper function for analyzing the posterior statistics for coefficient theta_i
def analyze_dimension(samples, X, Y, dimension, hypers):
    vmap_args = (samples['msq'], samples['lambda'], samples['eta1'], samples['xisq'], samples['sigma'])
# Helper function for analyzing the posterior statistics for coefficient theta_ij
def analyze_pair_of_dimensions(samples, X, Y, dim1, dim2, hypers):
    vmap_args = (samples['msq'], samples['lambda'], samples['eta1'], samples['xisq'], samples['sigma'])
def main(args):
    X, Y, expected_thetas, expected_pairwise = get_data(N=args.num_data, P=args.num_dimensions,
                                                        S=args.active_dimensions)

    # setup hyperparameters
    hypers = {'expected_sparsity': max(1.0, args.num_dimensions / 10),
              'alpha1': 3.0, 'beta1': 1.0,
              'alpha2': 3.0, 'beta2': 1.0,
              'alpha3': 1.0, 'c': 1.0}

    # do inference
    rng_key = random.PRNGKey(0)
    samples = run_inference(model, args, rng_key, X, Y, hypers)

    # compute the mean and square root variance of each coefficient theta_i
    means, stds = vmap(lambda dim: analyze_dimension(samples, X, Y, dim, hypers))(jnp.arange(args.num_dimensions))

    active_dimensions = []

    # Compute the mean and square root variance of coefficients theta_ij for i,j active dimensions.
    # Note that the resulting numbers are only meaningful for i != j.
    if len(active_dimensions) > 0:
        dim_pairs = jnp.array(list(itertools.product(active_dimensions, active_dimensions)))
if __name__ == "__main__":
    assert numpyro.__version__.startswith('0.5.0')
    parser = argparse.ArgumentParser(description="Sparse regression example")
    parser.add_argument("-n", "--num-samples", nargs="?", default=1000, type=int)
    parser.add_argument("--num-warmup", nargs='?', default=500, type=int)
    parser.add_argument("--num-chains", nargs='?', default=1, type=int)
    parser.add_argument("--num-data", nargs='?', default=100, type=int)
    parser.add_argument("--num-dimensions", nargs='?', default=20, type=int)
    parser.add_argument("--active-dimensions", nargs='?', default=3, type=int)
    parser.add_argument("--device", default='cpu', type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)
You are managing a business and want to test if calling your customers will increase their chance of making a purchase.
You get 100,000 customer records, call roughly half of them, and record whether they make a purchase in the next three
months. You do the same for the half that did not get called. After three months, the data is in: did calling help?
This example answers this question by estimating a logistic regression model where the covariates are whether the
customer got called and their gender. We place a multivariate normal prior on the regression coefficients. We report
the 95% highest posterior density interval for the effect of making a call.
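A minimal sketch of this setup is shown below; the model mirrors the description above, while the design-matrix
layout (column 1 being the "called" indicator), the warmup/sample counts, and the use of make_dataset's outputs
are assumptions.

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import hpdi
from numpyro.infer import MCMC, NUTS

def model(design_matrix, outcome=None):
    # multivariate normal prior on the regression coefficients
    beta = numpyro.sample("coefficients",
                          dist.MultivariateNormal(loc=0.0,
                                                  covariance_matrix=jnp.eye(design_matrix.shape[1])))
    logits = design_matrix @ beta
    with numpyro.plate("data", design_matrix.shape[0]):
        numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=outcome)

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), design_matrix, outcome)
# 95% highest posterior density interval for the effect of making a call
print(hpdi(mcmc.get_samples()["coefficients"][:, 1], prob=0.95))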
import argparse
import os
from typing import Tuple
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
num_calls = 51342
num_no_calls = 48658
    logits = design_matrix.dot(beta)

    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup, num_samples, num_chains,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(rng_key, design_matrix, outcome)
def main(args):
    rng_key, _ = random.split(random.PRNGKey(3))
    design_matrix, response = make_dataset(rng_key)
    run_inference(design_matrix, response, rng_key,
                  args.num_warmup,
                  args.num_samples,
                  args.num_chains,
                  args.interval_size)

if __name__ == '__main__':
    assert numpyro.__version__.startswith('0.5.0')
    parser = argparse.ArgumentParser(description='Testing whether ')
    parser.add_argument('-n', '--num-samples', nargs='?', default=500, type=int)
    parser.add_argument('--num-warmup', nargs='?', default=1500, type=int)
    parser.add_argument('--num-chains', nargs='?', default=1, type=int)
    parser.add_argument('--interval-size', nargs='?', default=0.95, type=float)
    parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)
The UCBadmit data is sourced from the study [1] of gender bias in graduate admissions at UC Berkeley in Fall
1973:
This example replicates the multilevel model m_glmm5 in [3], which is used to evaluate whether the data contain
evidence of gender bias in admissions across departments. This is a Generalized Linear Mixed Model for a
binomial regression problem, which models
• varying intercepts across departments,
• varying slopes (i.e. the effects of being male) across departments,
• correlation between intercepts and slopes,
and uses a non-centered parameterization (or "whitening"); a minimal sketch of such a model is given after the references.
A more comprehensive explanation of binomial regression and non-centered parameterization can be found in Chapter
10 (Counting and Classification) and Chapter 13 (Adventures in Covariance) of [2].
References:
1. Bickel, P. J., Hammel, E. A., and O’Connell, J. W. (1975), “Sex Bias in Graduate Admissions: Data from
Berkeley”, Science, 187(4175), 398-404.
2. McElreath, R. (2018), “Statistical Rethinking: A Bayesian Course with Examples in R and Stan”, Chapman and
Hall/CRC.
3. https://github.com/rmcelreath/rethinking/tree/Experimental#multilevel-model-formulas
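The sketch below shows the non-centered construction in miniature. It is loosely modeled on m_glmm5; the priors,
the hard-coded six departments, and the site names are illustrative rather than the exact code of the example.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def glmm(dept, male, applications, admit=None):
    num_dept = 6                                   # departments A-F in the UCBadmit data
    v_mu = numpyro.sample("v_mu", dist.Normal(0.0, jnp.array([4.0, 1.0])))
    sigma = numpyro.sample("sigma", dist.HalfNormal(jnp.ones(2)))
    L_Rho = numpyro.sample("L_Rho", dist.LKJCholesky(2, concentration=2.0))
    scale_tril = sigma[..., None] * L_Rho          # Cholesky factor of the covariance

    # non-centered parameterization: standard normals, then an affine transform
    z = numpyro.sample("z", dist.Normal(0.0, 1.0).expand([num_dept, 2]))
    v = jnp.dot(scale_tril, z.T).T                 # per-department (intercept, slope) offsets

    logits = v_mu[0] + v[dept, 0] + (v_mu[1] + v[dept, 1]) * male
    numpyro.sample("admit", dist.Binomial(applications, logits=logits), obs=admit)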
import argparse
import os
import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import UCBADMIT, load_dataset
from numpyro.infer import MCMC, NUTS, Predictive
def main(args):
    _, fetch_train = load_dataset(UCBADMIT, split='train', shuffle=False)
    dept, male, applications, admit = fetch_train()
    rng_key, rng_key_predict = random.split(random.PRNGKey(1))
    zs = run_inference(dept, male, applications, admit, rng_key, args)
    pred_probs = Predictive(glmm, zs)(rng_key_predict, dept, male, applications)['probs']

    # make plots
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
    ax.legend()
    plt.savefig("ucbadmit_plot.pdf")

if __name__ == '__main__':
    assert numpyro.__version__.startswith('0.5.0')
    parser = argparse.ArgumentParser(description='UCBadmit gender discrimination using HMC')
    args = parser.parse_args()

    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)
In this example, we will follow [1] to construct a semi-supervised Hidden Markov Model for a generative model
whose observations are words and whose latent variables are categories. Instead of automatically marginalizing all
discrete latent variables (as in [2]), we will use the "forward algorithm" (which exploits the conditional independence
structure of a Markov model - see [3]) to iteratively do this marginalization.
The semi-supervised problem is chosen instead of an unsupervised one because it is hard to make inference work
for an unsupervised model (see the discussion in [4]). On the other hand, this example also illustrates the usage of JAX's
lax.scan primitive, which greatly reduces compilation time for the model; a sketch of the forward pass with lax.scan
is given after the references.
References:
1. https://mc-stan.org/docs/2_19/stan-users-guide/hmms-section.html
2. http://pyro.ai/examples/hmm.html
3. https://en.wikipedia.org/wiki/Forward_algorithm
4. https://discourse.pymc.io/t/how-to-marginalized-markov-chain-with-categorical/2230
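A sketch of the forward pass, close in spirit to the example's implementation, is shown below. Here prev_log_prob
holds log p(category_t = k, words_1:t), transition_log_prob is a (K, K) matrix indexed (previous, current), and
emission_log_prob is (K, num_words); the function and argument names are assumptions.

import jax.numpy as jnp
from jax import lax
from jax.scipy.special import logsumexp

def forward_one_step(prev_log_prob, curr_word, transition_log_prob, emission_log_prob):
    # sum over the previous category, then account for the emitted word
    log_prob_tmp = jnp.expand_dims(prev_log_prob, axis=1) + transition_log_prob
    log_prob = log_prob_tmp + emission_log_prob[:, curr_word]
    return logsumexp(log_prob, axis=0)

def forward_log_prob(init_log_prob, words, transition_log_prob, emission_log_prob):
    def scan_fn(log_prob, word):
        return forward_one_step(log_prob, word, transition_log_prob, emission_log_prob), None

    # lax.scan replaces the naive Python loop and keeps compilation time small
    log_prob, _ = lax.scan(scan_fn, init_log_prob, words)
    return log_prob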
import argparse
import os
import time
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
    transition_prior = jnp.ones(num_categories)
    emission_prior = jnp.repeat(0.1, num_words)
    transition_prob = dist.Dirichlet(transition_prior).sample(key=rng_key_transition,
                                                              sample_shape=(num_categories,))
    emission_prob = dist.Dirichlet(emission_prior).sample(key=rng_key_emission,
                                                          sample_shape=(num_categories,))
    word = dist.Categorical(emission_prob[category]).sample(key=rng_key_emission)
    categories.append(category)
    words.append(word)
    # Note: The following naive implementation will make it very slow to compile
    # and do inference. So we use lax.scan instead.
    #
    # >>> log_prob = init_log_prob
    # >>> for word in words:
    # ...     log_prob = forward_one_step(log_prob, word, transition_log_prob, emission_log_prob)
    else:
        log_prob, _ = lax.scan(scan_fn, init_log_prob, words)
    return log_prob
                   obs=supervised_categories[1:])
    numpyro.sample('supervised_words', dist.Categorical(emission_prob[supervised_categories]),
                   obs=supervised_words)
def main(args):
    print('Simulating data...')
    (transition_prior, emission_prior, transition_prob, emission_prob,
     supervised_categories, supervised_words, unsupervised_words) = simulate_data(
        random.PRNGKey(1),
        num_categories=args.num_categories,
        num_words=args.num_words,
        num_supervised_data=args.num_supervised,
        num_unsupervised_data=args.num_unsupervised,
    )
    print('Starting inference...')
    rng_key = random.PRNGKey(2)
    start = time.time()
    kernel = NUTS(semi_supervised_hmm)
    mcmc = MCMC(kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains,
                progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True)
    mcmc.run(rng_key, transition_prior, emission_prior, supervised_categories,
             supervised_words, unsupervised_words, args.unroll_loop)
    samples = mcmc.get_samples()
    print_results(samples, transition_prob, emission_prob)
    print('\nMCMC elapsed time:', time.time() - start)

    # make plots
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)
    x = np.linspace(0, 1, 101)
    for i in range(transition_prob.shape[0]):
        for j in range(transition_prob.shape[1]):
            ax.plot(x, gaussian_kde(samples['transition_prob'][:, i, j])(x),
                    label="trans_prob[{}, {}], true value = {:.2f}"
                    .format(i, j, transition_prob[i, j]))
    ax.set(xlabel="Probability", ylabel="Frequency",
           title="Transition probability posterior")
    ax.legend()
    plt.savefig("hmm_plot.pdf")
if __name__ == '__main__':
    assert numpyro.__version__.startswith('0.5.0')
    parser = argparse.ArgumentParser(description='Semi-supervised Hidden Markov Model')
    args = parser.parse_args()

    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)
This example replicates the great case study [1], which leverages the Lotka-Volterra equation [2] to describe the
dynamics of Canada lynx (predator) and snowshoe hare (prey) populations. We will use the dataset obtained from [3]
and run MCMC to get inferences about the parameters of the differential equation governing the dynamics; a sketch
of the ODE model is given after the references.
References:
1. Bob Carpenter (2018), “Predator-Prey Population Dynamics: the Lotka-Volterra model in Stan”.
2. https://en.wikipedia.org/wiki/Lotka-Volterra_equations
3. http://people.whitman.edu/~hundledr/courses/M250F03/M250.html
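A sketch of the ODE model is given below. The vector field follows the Lotka-Volterra equations; the priors and
solver tolerances are illustrative and may differ from the example script.

import jax.numpy as jnp
from jax.experimental.ode import odeint
import numpyro
import numpyro.distributions as dist

def dz_dt(z, t, theta):
    """Lotka-Volterra vector field: z = (hare, lynx), theta = (alpha, beta, gamma, delta)."""
    u, v = z
    alpha, beta, gamma, delta = theta
    du_dt = (alpha - beta * v) * u
    dv_dt = (-gamma + delta * u) * v
    return jnp.stack([du_dt, dv_dt])

def model(N, y=None):
    # initial populations and ODE parameters
    z_init = numpyro.sample("z_init", dist.LogNormal(jnp.log(10.0), 1.0).expand([2]))
    theta = numpyro.sample("theta",
                           dist.TruncatedNormal(low=0.0,
                                                loc=jnp.array([1.0, 0.05, 1.0, 0.05]),
                                                scale=jnp.array([0.5, 0.05, 0.5, 0.05])))
    # integrate the ODE over N yearly time steps
    z = odeint(dz_dt, z_init, jnp.arange(float(N)), theta, rtol=1e-6, atol=1e-5, mxstep=1000)
    # measurement noise on the log scale
    sigma = numpyro.sample("sigma", dist.LogNormal(-1.0, 1.0).expand([2]))
    numpyro.sample("y", dist.LogNormal(jnp.log(z), sigma), obs=y)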
import argparse
import os
import matplotlib
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import LYNXHARE, load_dataset
from numpyro.infer import MCMC, NUTS, Predictive
def main(args):
    _, fetch = load_dataset(LYNXHARE, shuffle=False)
    year, data = fetch()  # data is in hare -> lynx order

    # predict populations
    pop_pred = Predictive(model, mcmc.get_samples())(PRNGKey(2), data.shape[0])["y"]
    mu, pi = jnp.mean(pop_pred, 0), jnp.percentile(pop_pred, (10, 90), 0)
    plt.figure(figsize=(8, 6), constrained_layout=True)
    plt.plot(year, data[:, 0], "ko", mfc="none", ms=4, label="true hare", alpha=0.67)
    plt.plot(year, data[:, 1], "bx", label="true lynx")
    plt.plot(year, mu[:, 0], "k-.", label="pred hare", lw=1, alpha=0.67)
    plt.savefig("ode_plot.pdf")

if __name__ == '__main__':
    assert numpyro.__version__.startswith('0.5.0')
    parser = argparse.ArgumentParser(description='Predator-Prey Model')
    parser.add_argument('-n', '--num-samples', nargs='?', default=1000, type=int)
    parser.add_argument('--num-warmup', nargs='?', default=1000, type=int)
    parser.add_argument("--num-chains", nargs='?', default=1, type=int)
    parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)
This example illustrates how to use a trained AutoBNAFNormal autoguide to transform a posterior into a Gaussian-like
one. The transform is then used to obtain a better mixing rate for the NUTS sampler; a sketch of this workflow is given
after the references.
References:
1. Hoffman, M. et al. (2019), “NeuTra-lizing Bad Geometry in Hamiltonian Monte Carlo Using Neural Transport”,
(https://arxiv.org/abs/1903.03704)
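A sketch of the workflow: first fit the AutoBNAFNormal guide with SVI, then wrap the model with NeuTraReparam
so NUTS runs in the transformed (roughly Gaussian) space. The optimizer, learning rate, step counts, and the
"auto_shared_latent" site name below are assumptions, not the example's exact code.

from jax import random
from numpyro import optim
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoBNAFNormal
from numpyro.infer.reparam import NeuTraReparam

# 1. train a BNAF normalizing-flow guide on the original model
guide = AutoBNAFNormal(dual_moon_model, hidden_factors=[8, 8])
svi = SVI(dual_moon_model, guide, optim.Adam(0.003), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(1), 10000)

# 2. reparameterize the model through the trained flow and run NUTS on it
neutra = NeuTraReparam(guide, svi_result.params)
neutra_model = neutra.reparam(dual_moon_model)
mcmc = MCMC(NUTS(neutra_model), num_warmup=1000, num_samples=4000)
mcmc.run(random.PRNGKey(2))
# map the base-space draws back to the original latent site(s)
samples = neutra.transform_sample(mcmc.get_samples()["auto_shared_latent"])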
import argparse
import os

import numpyro
from numpyro import optim
from numpyro.diagnostics import print_summary
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoBNAFNormal
from numpyro.infer.reparam import NeuTraReparam
class DualMoonDistribution(dist.Distribution):
    support = constraints.real_vector

    def __init__(self):
        super(DualMoonDistribution, self).__init__(event_shape=(2,))

def dual_moon_model():
    numpyro.sample('x', DualMoonDistribution())
def main(args):
    print("Start vanilla HMC...")
    nuts_kernel = NUTS(dual_moon_model)
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples, num_chains=args.num_chains,
    # make plots
    guide_trans_samples = neutra.transform_sample(guide_base_samples)['x']

    x1 = jnp.linspace(-3, 3, 100)
    x2 = jnp.linspace(-3, 3, 100)
    X1, X2 = jnp.meshgrid(x1, x2)
    P = jnp.exp(DualMoonDistribution().log_prob(jnp.stack([X1, X2], axis=-1)))

    ax1.plot(svi_result.losses[1000:])
    ax1.set_title('Autoguide training loss\n(after 1000 steps)')
    plt.savefig("neutra.pdf")
if __name__ == "__main__":
    assert numpyro.__version__.startswith('0.5.0')
    parser = argparse.ArgumentParser(description="NeuTra HMC")
    parser.add_argument('-n', '--num-samples', nargs='?', default=4000, type=int)
    parser.add_argument('--num-warmup', nargs='?', default=1000, type=int)
    parser.add_argument("--num-chains", nargs='?', default=1, type=int)
    parser.add_argument('--hidden-factor', nargs='?', default=8, type=int)
    parser.add_argument('--num-iters', nargs='?', default=10000, type=int)
    parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".')
    args = parser.parse_args()

    numpyro.set_platform(args.device)
    numpyro.set_host_device_count(args.num_chains)

    main(args)