Variational API quickstart

The variational inference (VI) API is focused on approximating posterior distributions for Bayesian models. Common use cases to which this module can be applied include:

  • Sampling from the model posterior and computing arbitrary expressions

  • Conducting Monte Carlo approximation of expectations, variances, and other statistics

  • Removing symbolic dependence on PyMC3 random nodes and evaluating expressions (using eval)

  • Providing a bridge to arbitrary Theano code

Sounds good, doesn’t it?

The module provides an interface to a variety of inference methods, so you are free to choose what is most appropriate for the problem.

[1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pymc3 as pm
import theano

np.random.seed(42)
pm.set_tt_rng(42)

Basic setup

We do not need complex models to play with the VI API; let’s begin with a simple mixture model:

[2]:
w = pm.floatX([0.2, 0.8])
mu = pm.floatX([-0.3, 0.5])
sd = pm.floatX([0.1, 0.1])

with pm.Model() as model:
    x = pm.NormalMixture("x", w=w, mu=mu, sigma=sd, dtype=theano.config.floatX)
    x2 = x ** 2
    sin_x = pm.math.sin(x)

We can’t compute analytical expectations for this model. However, we can obtain an approximation using Markov chain Monte Carlo methods; let’s use NUTS first.

To allow samples of the expressions to be saved, we need to wrap them in Deterministic objects:

[3]:
with model:
    pm.Deterministic("x2", x2)
    pm.Deterministic("sin_x", sin_x)
[4]:
with model:
    trace = pm.sample(50000)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [x]
100.00% [204000/204000 03:06<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 50_000 draw iterations (4_000 + 200_000 draws total) took 186 seconds.
The estimated number of effective samples is smaller than 200 for some parameters.
[5]:
pm.traceplot(trace);
/dependencies/arviz/arviz/data/io_pymc3.py:89: FutureWarning: Using `from_pymc3` without the model will be deprecated in a future release. Not using the model will return less accurate and less useful results. Make sure you use the model argument or call from_pymc3 within a model context.
  FutureWarning,
../../../_images/pymc-examples_examples_variational_inference_variational_api_quickstart_7_1.png

Above are traces for \(x^2\) and \(\sin(x)\). We can see that there is clear multi-modality in this model. One drawback is that you need to know in advance exactly what you want to see in the trace and wrap it in a Deterministic.

The VI API takes an alternative approach: you first fit an approximation to the model, then calculate expressions from that approximation afterwards.

Let’s use the same model:

[6]:
with pm.Model() as model:

    x = pm.NormalMixture("x", w=w, mu=mu, sigma=sd, dtype=theano.config.floatX)
    x2 = x ** 2
    sin_x = pm.math.sin(x)

Here we will use automatic differentiation variational inference (ADVI).

[7]:
with model:
    mean_field = pm.fit(method="advi")
100.00% [10000/10000 00:01<00:00 Average Loss = 2.2413]
Finished [100%]: Average Loss = 2.2687
[8]:
pm.plot_posterior(mean_field.sample(1000), color="LightSeaGreen");
/dependencies/arviz/arviz/data/io_pymc3.py:89: FutureWarning: Using `from_pymc3` without the model will be deprecated in a future release. Not using the model will return less accurate and less useful results. Make sure you use the model argument or call from_pymc3 within a model context.
  FutureWarning,
../../../_images/pymc-examples_examples_variational_inference_variational_api_quickstart_12_1.png

Notice that ADVI has failed to approximate the multimodal distribution, since it uses a Gaussian distribution that has a single mode.

Checking convergence

[9]:
help(pm.callbacks.CheckParametersConvergence)
Help on class CheckParametersConvergence in module pymc3.variational.callbacks:

class CheckParametersConvergence(Callback)
 |  CheckParametersConvergence(every=100, tolerance=0.001, diff='relative', ord=inf)
 |
 |  Convergence stopping check
 |
 |  Parameters
 |  ----------
 |  every: int
 |      check frequency
 |  tolerance: float
 |      if diff norm < tolerance: break
 |  diff: str
 |      difference type one of {'absolute', 'relative'}
 |  ord: {non-zero int, inf, -inf, 'fro', 'nuc'}, optional
 |      see more info in :func:`numpy.linalg.norm`
 |
 |  Examples
 |  --------
 |  >>> with model:
 |  ...     approx = pm.fit(
 |  ...         n=10000, callbacks=[
 |  ...             CheckParametersConvergence(
 |  ...                 every=50, diff='absolute',
 |  ...                 tolerance=1e-4)
 |  ...         ]
 |  ...     )
 |
 |  Method resolution order:
 |      CheckParametersConvergence
 |      Callback
 |      builtins.object
 |
 |  Methods defined here:
 |
 |  __call__(self, approx, _, i)
 |      Call self as a function.
 |
 |  __init__(self, every=100, tolerance=0.001, diff='relative', ord=inf)
 |      Initialize self.  See help(type(self)) for accurate signature.
 |
 |  ----------------------------------------------------------------------
 |  Static methods defined here:
 |
 |  flatten_shared(shared_list)
 |
 |  ----------------------------------------------------------------------
 |  Data descriptors inherited from Callback:
 |
 |  __dict__
 |      dictionary for instance variables (if defined)
 |
 |  __weakref__
 |      list of weak references to the object (if defined)

Let’s use the default arguments for CheckParametersConvergence as they seem to be reasonable.

[10]:
from pymc3.variational.callbacks import CheckParametersConvergence

with model:
    mean_field = pm.fit(method="advi", callbacks=[CheckParametersConvergence()])
100.00% [10000/10000 00:01<00:00 Average Loss = 2.2559]
Finished [100%]: Average Loss = 2.2763

We can access the inference history via the .hist attribute.

[11]:
plt.plot(mean_field.hist);
../../../_images/pymc-examples_examples_variational_inference_variational_api_quickstart_18_0.png

This is not a good convergence plot, despite the fact that we ran many iterations. The reason is that the mean of the ADVI approximation is close to zero, and therefore taking the relative difference (the default method) is unstable for checking convergence.
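As a minimal illustration (not the callback's exact code) of why relative differences misbehave near zero:

# Two successive (hypothetical) parameter values close to zero
prev, curr = 1e-6, 2e-6
abs(curr - prev)              # 1e-06: tiny, looks converged
abs(curr - prev) / abs(prev)  # 1.0: large, looks far from converged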

[12]:
with model:
    mean_field = pm.fit(
        method="advi", callbacks=[pm.callbacks.CheckParametersConvergence(diff="absolute")]
    )
42.63% [4263/10000 00:00<00:00 Average Loss = 3.2279]
Convergence achieved at 4700
Interrupted at 4,699 [46%]: Average Loss = 4.7996
[13]:
plt.plot(mean_field.hist);
../../../_images/pymc-examples_examples_variational_inference_variational_api_quickstart_21_0.png

That’s much better! We’ve reached convergence after less than 5000 iterations.

Tracking parameters

Another useful callback allows users to track arbitrary statistics during inference, though it can be memory-hungry. Using the fit function, we do not have direct access to the approximation before inference; however, tracking parameters requires access to the approximation. We can get around this constraint by using the object-oriented (OO) API for inference.

[14]:
with model:
    advi = pm.ADVI()
[15]:
advi.approx
[15]:
<pymc3.variational.approximations.MeanField at 0x7f7dcca2b090>

Different approximations have different hyperparameters. In mean-field ADVI, we have \(\rho\) and \(\mu\) (inspired by Bayes by Backprop).

[16]:
advi.approx.shared_params
[16]:
{'mu': mu, 'rho': rho}
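Here mu is the mean of the Gaussian approximation, while rho parametrizes its standard deviation on an unconstrained scale. Assuming the usual softplus mapping \(\sigma = \log(1 + e^{\rho})\), the initial value \(\rho = 0\) corresponds to \(\sigma = \log 2 \approx 0.693\), which matches the standard deviation printed in the next cell:

import numpy as np

rho = 0.0                      # assumed initial value of rho
sigma = np.log1p(np.exp(rho))  # softplus maps unconstrained rho to a positive std
sigma                          # ~0.693, i.e. log(2)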

There are convenient shortcuts to relevant statistics associated with the approximation. This can be useful, for example, when specifying a mass matrix for NUTS sampling:

[17]:
advi.approx.mean.eval(), advi.approx.std.eval()
[17]:
(array([0.34]), array([0.69314718]))
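For example, here is a minimal sketch of the mass-matrix idea, assuming we pass the approximation's variances through the standard scaling and is_cov arguments of pm.NUTS:

with model:
    # After fitting, use the mean-field variances as a diagonal covariance for NUTS
    step = pm.NUTS(scaling=advi.approx.std.eval() ** 2, is_cov=True)
    # nuts_trace = pm.sample(1000, step=step)  # sketch only, not run here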

We can roll these statistics into the Tracker callback.

[18]:
tracker = pm.callbacks.Tracker(
    mean=advi.approx.mean.eval,  # callable that returns mean
    std=advi.approx.std.eval,  # callable that returns std
)

Now, calling advi.fit will record the mean and standard deviation of the approximation as it runs.

[19]:
approx = advi.fit(20000, callbacks=[tracker])
100.00% [20000/20000 00:05<00:00 Average Loss = 1.9568]
Finished [100%]: Average Loss = 1.9589

We can now plot both the evidence lower bound and parameter traces:

[20]:
fig = plt.figure(figsize=(16, 9))
mu_ax = fig.add_subplot(221)
std_ax = fig.add_subplot(222)
hist_ax = fig.add_subplot(212)
mu_ax.plot(tracker["mean"])
mu_ax.set_title("Mean track")
std_ax.plot(tracker["std"])
std_ax.set_title("Std track")
hist_ax.plot(advi.hist)
hist_ax.set_title("Negative ELBO track");
../../../_images/pymc-examples_examples_variational_inference_variational_api_quickstart_36_0.png

Notice that there are convergence issues with the mean, and that lack of convergence does not seem to change the ELBO trajectory significantly. As we are using the OO API, we can run the approximation longer until convergence is achieved.

[21]:
advi.refine(100000)
100.00% [100000/100000 00:24<00:00 Average Loss = 1.8638]
Finished [100%]: Average Loss = 1.8422

Let’s take a look:

[22]:
fig = plt.figure(figsize=(16, 9))
mu_ax = fig.add_subplot(221)
std_ax = fig.add_subplot(222)
hist_ax = fig.add_subplot(212)
mu_ax.plot(tracker["mean"])
mu_ax.set_title("Mean track")
std_ax.plot(tracker["std"])
std_ax.set_title("Std track")
hist_ax.plot(advi.hist)
hist_ax.set_title("Negative ELBO track");
../../../_images/pymc-examples_examples_variational_inference_variational_api_quickstart_40_0.png

We still see evidence for lack of convergence, as the mean has devolved into a random walk. This could be the result of choosing a poor algorithm for inference. At any rate, it is unstable and can produce very different results even using different random seeds.

Let’s compare results with the NUTS output:

[23]:
import seaborn as sns

ax = sns.kdeplot(trace["x"], label="NUTS")
sns.kdeplot(approx.sample(10000)["x"], label="ADVI");
../../../_images/pymc-examples_examples_variational_inference_variational_api_quickstart_42_0.png

Again, we see that ADVI is not able to cope with multimodality; we can instead use SVGD, which generates an approximation based on a large number of particles.

[24]:
with model:
    svgd_approx = pm.fit(
        300,
        method="svgd",
        inf_kwargs=dict(n_particles=1000),
        obj_optimizer=pm.sgd(learning_rate=0.01),
    )
100.00% [300/300 02:05<00:00]
[25]:
ax = sns.kdeplot(trace["x"], label="NUTS")
sns.kdeplot(approx.sample(10000)["x"], label="ADVI")
sns.kdeplot(svgd_approx.sample(2000)["x"], label="SVGD");
../../../_images/pymc-examples_examples_variational_inference_variational_api_quickstart_45_0.png

That did the trick, as we now have a multimodal approximation using SVGD.

With this variational approximation, it is possible to calculate arbitrary functions of the parameters. For example, we can calculate \(x^2\) and \(\sin(x)\), as we did with the NUTS model.

[26]:
# recall x ~ NormalMixture
a = x ** 2
b = pm.math.sin(x)

To evaluate these expressions with the approximation, we need approx.sample_node.

[27]:
help(svgd_approx.sample_node)
Help on method sample_node in module pymc3.variational.opvi:

sample_node(node, size=None, deterministic=False, more_replacements=None) method of pymc3.variational.approximations.Empirical instance
    Samples given node or nodes over shared posterior

    Parameters
    ----------
    node: Theano Variables (or Theano expressions)
    size: None or scalar
        number of samples
    more_replacements: `dict`
        add custom replacements to graph, e.g. change input source
    deterministic: bool
        whether to use zeros as initial distribution
        if True - zero initial point will produce constant latent variables

    Returns
    -------
    sampled node(s) with replacements

[28]:
a_sample = svgd_approx.sample_node(a)
a_sample.eval()
[28]:
array(0.20617133)
[29]:
a_sample.eval()
[29]:
array(0.23059109)
[30]:
a_sample.eval()
[30]:
array(0.01689826)

Every call yields a different value from the same Theano node because the node is stochastic.

By applying replacements, we are now free of the dependence on the PyMC3 model; instead, we now depend on the approximation. Changing it will change the distribution for stochastic nodes:

[31]:
sns.kdeplot(np.array([a_sample.eval() for _ in range(2000)]))
plt.title("$x^2$ distribution");
../../../_images/pymc-examples_examples_variational_inference_variational_api_quickstart_54_0.png
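Because the expression now depends only on the approximation, swapping in a different one changes the draws. For example, a quick sketch, assuming the mean_field approximation fit earlier is still in scope:

# Draws of the same expression, now taken from the (unimodal) ADVI approximation
a_sample_mf = mean_field.sample_node(a)
a_sample_mf.eval()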

There is a more convenient way to get many samples at once: the size argument of sample_node.

[32]:
a_samples = svgd_approx.sample_node(a, size=1000)
[33]:
sns.kdeplot(a_samples.eval())
plt.title("$x^2$ distribution");
../../../_images/pymc-examples_examples_variational_inference_variational_api_quickstart_57_0.png

With a size argument, sample_node adds an extra sample dimension, so expectations and variances are computed along axis=0.

[34]:
a_samples.var(0).eval()  # variance
[34]:
array(0.0963961)
[35]:
a_samples.mean(0).eval()  # mean
[35]:
array(0.24696937)

A symbolic sample size can also be specified:

[36]:
i = theano.tensor.iscalar("i")
i.tag.test_value = 1
a_samples_i = svgd_approx.sample_node(a, size=i)
[37]:
a_samples_i.eval({i: 100}).shape
[37]:
(100,)
[38]:
a_samples_i.eval({i: 10000}).shape
[38]:
(10000,)

Unfortunately the size must be a scalar value.

Converting a Trace to an Approximation

We can convert an MCMC trace into an Approximation object. It has the same API as the approximations above, including the same sample_node methods:

[39]:
trace_approx = pm.Empirical(trace, model=model)
trace_approx
[39]:
<pymc3.variational.approximations.Empirical at 0x7f7e3a00af10>

We can then draw samples from the Empirical object:

[40]:
pm.plot_posterior(trace_approx.sample(10000));
/dependencies/arviz/arviz/data/io_pymc3.py:89: FutureWarning: Using `from_pymc3` without the model will be deprecated in a future release. Not using the model will return less accurate and less useful results. Make sure you use the model argument or call from_pymc3 within a model context.
  FutureWarning,
../../../_images/pymc-examples_examples_variational_inference_variational_api_quickstart_69_1.png

Multilabel logistic regression

Let’s illustrate the use of Tracker with the famous Iris dataset. We’ll attempt multi-label classification and compute the expected accuracy score as a diagnostic.

[41]:
import pandas as pd
import theano.tensor as tt

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y)


A relatively simple model will be sufficient here because the classes are roughly linearly separable; we will fit a multinomial logistic regression.

[42]:
Xt = theano.shared(X_train)
yt = theano.shared(y_train)

with pm.Model() as iris_model:

    # Coefficients for features
    β = pm.Normal("β", 0, sigma=1e2, shape=(4, 3))
    # Intercepts for each class
    a = pm.Flat("a", shape=(3,))
    # Softmax transforms the linear predictor to the unit interval
    p = tt.nnet.softmax(Xt.dot(β) + a)

    observed = pm.Categorical("obs", p=p, observed=yt)
/env/miniconda3/lib/python3.7/site-packages/theano/tensor/subtensor.py:2197: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.
  rval = inputs[0].__getitem__(inputs[1:])
/env/miniconda3/lib/python3.7/site-packages/theano/tensor/subtensor.py:2197: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.
  rval = inputs[0].__getitem__(inputs[1:])

Applying replacements in practice

PyMC3 models have symbolic inputs for latent variables. To evaluate an expression that requires knowledge of latent variables, one needs to provide fixed values. We can use values approximated by VI for this purpose. The function sample_node removes the symbolic dependencies.

sample_node will use the whole distribution at each step, so we will use it here. We can apply additional replacements in a single function call using the more_replacements keyword argument, which is available in both replacement functions.

HINT: You can use the more_replacements argument when calling fit too:

  • pm.fit(more_replacements={full_data: minibatch_data})

  • inference.fit(more_replacements={full_data: minibatch_data})

[43]:
with iris_model:

    # We'll use SVGD
    inference = pm.SVGD(n_particles=500, jitter=1)

    # Local reference to approximation
    approx = inference.approx

    # Here we need `more_replacements` to change train_set to test_set
    test_probs = approx.sample_node(p, more_replacements={Xt: X_test}, size=100)

    # For train set no more replacements needed
    train_probs = approx.sample_node(p)

With the code above, we now have 100 sampled probability vectors for each test observation (for the training set, the default size of None gives a single draw).

Next we create symbolic expressions for sampled accuracy scores:

[44]:
test_ok = tt.eq(test_probs.argmax(-1), y_test)
train_ok = tt.eq(train_probs.argmax(-1), y_train)
test_accuracy = test_ok.mean(-1)
train_accuracy = train_ok.mean(-1)

Tracker expects callables, so we can pass the .eval method of a Theano node, which is itself a function.

The compiled function is cached, so repeated calls can reuse it.

[45]:
eval_tracker = pm.callbacks.Tracker(
    test_accuracy=test_accuracy.eval, train_accuracy=train_accuracy.eval
)
[46]:
inference.fit(100, callbacks=[eval_tracker]);
100.00% [100/100 00:38<00:00]
/env/miniconda3/lib/python3.7/site-packages/theano/tensor/subtensor.py:2197: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.
  rval = inputs[0].__getitem__(inputs[1:])
/env/miniconda3/lib/python3.7/site-packages/theano/tensor/subtensor.py:2197: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.
  rval = inputs[0].__getitem__(inputs[1:])
[47]:
_, ax = plt.subplots(1, 1)
df = pd.DataFrame(eval_tracker["test_accuracy"]).T.melt()
sns.lineplot(x="variable", y="value", data=df, color="red", ax=ax)
ax.plot(eval_tracker["train_accuracy"], color="blue")
ax.set_xlabel("epoch")
plt.legend(["test_accuracy", "train_accuracy"])
plt.title("Training Progress");
../../../_images/pymc-examples_examples_variational_inference_variational_api_quickstart_83_0.png

Training does not seem to be working here. Let’s use a different optimizer and boost the learning rate.

[48]:
inference.fit(400, obj_optimizer=pm.adamax(learning_rate=0.1), callbacks=[eval_tracker]);
100.00% [400/400 02:33<00:00]
/env/miniconda3/lib/python3.7/site-packages/theano/tensor/subtensor.py:2197: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.
  rval = inputs[0].__getitem__(inputs[1:])
[49]:
_, ax = plt.subplots(1, 1)
df = pd.DataFrame(np.asarray(eval_tracker["test_accuracy"])).T.melt()
sns.lineplot(x="variable", y="value", data=df, color="red", ax=ax)
ax.plot(eval_tracker["train_accuracy"], color="blue")
ax.set_xlabel("epoch")
plt.legend(["test_accuracy", "train_accuracy"])
plt.title("Training Progress");
../../../_images/pymc-examples_examples_variational_inference_variational_api_quickstart_86_0.png

This is much better!

So, Tracker allows us to monitor our approximation and choose a good training schedule.

Minibatches

When dealing with large datasets, using minibatch training can drastically speed up and improve approximation performance. Large datasets impose a hefty cost on the computation of gradients.

There is a nice API in pymc3 to handle these cases, which is available through the pm.Minibatch class. The minibatch is just a highly specialized Theano tensor:

[50]:
issubclass(pm.Minibatch, theano.tensor.TensorVariable)
[50]:
True

To demonstrate, let’s simulate a large quantity of data:

[51]:
# Raw values
data = np.random.rand(40000, 100)
# Scaled values
data *= np.random.randint(1, 10, size=(100,))
# Shifted values
data += np.random.rand(100) * 10

For comparison, let’s fit a model without minibatch processing:

[52]:
with pm.Model() as model:
    mu = pm.Flat("mu", shape=(100,))
    sd = pm.HalfNormal("sd", shape=(100,))
    lik = pm.Normal("lik", mu, sd, observed=data)

Just for fun, let’s create a custom special-purpose callback that causes a hard stop when the approximation runs too slowly:

[53]:
def stop_after_10(approx, loss_history, i):
    if (i > 0) and (i % 10) == 0:
        raise StopIteration("I was slow, sorry")
[54]:
with model:
    advifit = pm.fit(callbacks=[stop_after_10])
0.09% [9/10000 00:01<25:51 Average Loss = 7.7692e+08]
I was slow, sorry
Interrupted at 9 [0%]: Average Loss = 5.6736e+08

Inference is too slow here; as the progress bar estimate shows, fitting the full approximation would have taken nearly half an hour!

Now let’s use minibatches. At every iteration, we will draw 500 random rows from the data:

Remember to set total_size in observed

total_size is an important parameter that allows pymc3 to infer the right way of rescaling densities. If it is not set, you are likely to get completely wrong results. For more information please refer to the comprehensive documentation of pm.Minibatch.
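To make the rescaling concrete: with \(N\) total observations and minibatches of size \(n\), the observed log-likelihood is multiplied by roughly \(N/n\) (here \(40000 / 500 = 80\)), so that the stochastic ELBO remains an unbiased estimate of the full-data ELBO. This is a simplified sketch of the scaling rather than PyMC3's exact internal computation.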

[55]:
X = pm.Minibatch(data, batch_size=500)

with pm.Model() as model:

    mu = pm.Flat("mu", shape=(100,))
    sd = pm.HalfNormal("sd", shape=(100,))
    likelihood = pm.Normal("likelihood", mu, sd, observed=X, total_size=data.shape)
/dependencies/pymc3/pymc3/data.py:307: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.
  self.shared = theano.shared(data[in_memory_slc])
[56]:
with model:
    advifit = pm.fit()
100.00% [10000/10000 00:16<00:00 Average Loss = 1.546e+05]
Finished [100%]: Average Loss = 1.5452e+05
[57]:
plt.plot(advifit.hist);
../../../_images/pymc-examples_examples_variational_inference_variational_api_quickstart_100_0.png

Minibatch inference is dramatically faster. Multidimensional minibatches may be needed for some corner cases where you do matrix factorization or the model is very wide.
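As an illustrative sketch (following the batch_size list form documented in the docstring below), a two-dimensional minibatch over the same data could look like this:

# Hypothetical 2-d minibatch: 500 random rows and 50 random columns per iteration
X2d = pm.Minibatch(data, batch_size=[500, 50])
# The corresponding observed node would still need total_size=data.shape, i.e. (40000, 100)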

Here is the docstring for Minibatch to illustrate how it can be customized.

[58]:
print(pm.Minibatch.__doc__)
Multidimensional minibatch that is pure TensorVariable

    Parameters
    ----------
    data: np.ndarray
        initial data
    batch_size: ``int`` or ``List[int|tuple(size, random_seed)]``
        batch size for inference, random seed is needed
        for child random generators
    dtype: ``str``
        cast data to specific type
    broadcastable: tuple[bool]
        change broadcastable pattern that defaults to ``(False, ) * ndim``
    name: ``str``
        name for tensor, defaults to "Minibatch"
    random_seed: ``int``
        random seed that is used by default
    update_shared_f: ``callable``
        returns :class:`ndarray` that will be carefully
        stored to underlying shared variable
        you can use it to change source of
        minibatches programmatically
    in_memory_size: ``int`` or ``List[int|slice|Ellipsis]``
        data size for storing in ``theano.shared``

    Attributes
    ----------
    shared: shared tensor
        Used for storing data
    minibatch: minibatch tensor
        Used for training

    Notes
    -----
    Below is a common use case of Minibatch with variational inference.
    Importantly, we need to make PyMC3 "aware" that a minibatch is being used in inference.
    Otherwise, we will get the wrong :math:`logp` for the model.
    To do so, we need to pass the ``total_size`` parameter to the observed node, which correctly scales
    the density of the model ``logp`` that is affected by Minibatch. See more in the examples below.

    Examples
    --------
    Consider we have `data` as follows:

    >>> data = np.random.rand(100, 100)

    if we want a 1d slice of size 10 we do

    >>> x = Minibatch(data, batch_size=10)

    Note that your data is cast to ``floatX`` if it is not integer type,
    but you can still add the ``dtype`` kwarg for :class:`Minibatch`
    if you need more control.

    If we want 10 sampled rows and columns
    ``[(size, seed), (size, seed)]`` we can use

    >>> x = Minibatch(data, batch_size=[(10, 42), (10, 42)], dtype='int32')
    >>> assert str(x.dtype) == 'int32'


    Or, more simply, we can use the default random seed = 42
    ``[size, size]``

    >>> x = Minibatch(data, batch_size=[10, 10])


    In the above, `x` is a regular :class:`TensorVariable` that supports any math operations:


    >>> assert x.eval().shape == (10, 10)


    You can pass the Minibatch `x` to your desired model:

    >>> with pm.Model() as model:
    ...     mu = pm.Flat('mu')
    ...     sd = pm.HalfNormal('sd')
    ...     lik = pm.Normal('lik', mu, sd, observed=x, total_size=(100, 100))


    Then you can perform regular Variational Inference out of the box


    >>> with model:
    ...     approx = pm.fit()


    Important note: :class:``Minibatch`` has ``shared`` and ``minibatch`` attributes
    that you can access later:

    >>> x.set_value(np.random.laplace(size=(100, 100)))

    and minibatches will then be drawn from the new storage;
    it directly affects ``x.shared``.
    A less convenient, but more explicit, way to achieve the same
    thing:

    >>> x.shared.set_value(pm.floatX(np.random.laplace(size=(100, 100))))

    The programmatic way to change storage is as follows
    (we import ``partial`` for simplicity):
    >>> from functools import partial
    >>> datagen = partial(np.random.laplace, size=(100, 100))
    >>> x = Minibatch(datagen(), batch_size=10, update_shared_f=datagen)
    >>> x.update_shared()

    To be more concrete about how we create a minibatch, here is a demo:
    1. create a shared variable

        >>> shared = theano.shared(data)

    2. take a random slice of size 10:

        >>> ridx = pm.tt_rng().uniform(size=(10,), low=0, high=data.shape[0]-1e-10).astype('int64')

    3. take the resulting slice:

        >>> minibatch = shared[ridx]

    That's done. Now you can use this minibatch somewhere else.
    You can see that the implementation does not require a fixed shape
    for the shared variable, but feel free to put a fixed shape on the
    shared variable if appropriate.

    Suppose you need to make some replacements in the graph, e.g. change the minibatch to testdata

    >>> node = x ** 2  # arbitrary expressions on minibatch `x`
    >>> testdata = pm.floatX(np.random.laplace(size=(1000, 10)))

    Then you should create a `dict` with replacements:

    >>> replacements = {x: testdata}
    >>> rnode = theano.clone(node, replacements)
    >>> assert (testdata ** 2 == rnode.eval()).all()

    To replace a minibatch with its shared variable, you do
    the same thing. The Minibatch variable is accessible through the `minibatch` attribute.
    For example

    >>> replacements = {x.minibatch: x.shared}
    >>> rnode = theano.clone(node, replacements)

    For more complex slices, a bit more code is needed, and it may seem less clear at first:

    >>> moredata = np.random.rand(10, 20, 30, 40, 50)

    The default ``total_size`` that can be passed to ``PyMC3`` random node
    is then ``(10, 20, 30, 40, 50)`` but can be less verbose in some cases

    1. Advanced indexing, ``total_size = (10, Ellipsis, 50)``

        >>> x = Minibatch(moredata, [2, Ellipsis, 10])

        We take the slice only for the first and last dimension

        >>> assert x.eval().shape == (2, 20, 30, 40, 10)

    2. Skipping a particular dimension, ``total_size = (10, None, 30)``:

        >>> x = Minibatch(moredata, [2, None, 20])
        >>> assert x.eval().shape == (2, 20, 20, 40, 50)

    3. Mixing both of these together, ``total_size = (10, None, 30, Ellipsis, 50)``:

        >>> x = Minibatch(moredata, [2, None, 20, Ellipsis, 10])
        >>> assert x.eval().shape == (2, 20, 20, 40, 10)

[59]:
%load_ext watermark
%watermark -n -u -v -iv -w
pandas  1.0.4
pymc3   3.9.0
theano  1.0.4
numpy   1.18.5
seaborn 0.10.1
last updated: Mon Jun 15 2020

CPython 3.7.7
IPython 7.15.0
watermark 2.0.2