Storage backends for traces
The NDArray (pymc3.backends.NDArray) backend holds the entire trace in memory.
Selecting values from a backend
After a backend is finished sampling, it returns a MultiTrace object. Values can be accessed in a few ways. The easiest way is to index the backend object with a variable or variable name.
>>> trace['x'] # or trace.x or trace[x]
The call will return the sampling values of x, with the values for all chains concatenated. (For a single call to sample, the number of chains will correspond to the cores argument.)
To discard the first N values of each chain, slicing syntax can be used.
>>> trace['x', 1000:]
The get_values method offers more control over which values are returned. The call below will discard the first 1000 iterations from each chain and keep the values for each chain as separate arrays.
>>> trace.get_values('x', burn=1000, combine=False)
The chains parameter of get_values can be used to limit the chains that are retrieved.
>>> trace.get_values('x', burn=1000, chains=[0, 2])
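The effect of the burn, combine, and chains options can be sketched without PyMC3. The snippet below is only an illustration under stated assumptions: `select_values` and `chains_of_x` are hypothetical names, plain Python lists stand in for NumPy arrays, and `thin` is assumed to keep every thin-th draw after burn-in. It is not the library's implementation.

```python
# Illustrative stand-in for a multi-chain trace: chain id -> draws of one variable.
# Plain lists are used instead of NumPy arrays; this is not PyMC3 code.
chains_of_x = {
    0: [0, 1, 2, 3, 4, 5],
    1: [10, 11, 12, 13, 14, 15],
    2: [20, 21, 22, 23, 24, 25],
}


def select_values(samples, burn=0, thin=1, combine=True, chains=None):
    """Hypothetical helper mirroring the options described for get_values."""
    if chains is None:
        chains = sorted(samples)
    # Drop the first `burn` draws of each chain, then keep every `thin`-th draw.
    per_chain = [samples[c][burn::thin] for c in chains]
    if combine:
        # Concatenate the selected chains in order.
        return [value for chain in per_chain for value in chain]
    return per_chain


print(select_values(chains_of_x, burn=4))                 # all chains, concatenated
print(select_values(chains_of_x, burn=4, combine=False))  # one list per chain
print(select_values(chains_of_x, burn=4, chains=[0, 2]))  # only chains 0 and 2
```

With combine=False the per-chain structure is preserved, which is usually what convergence checks that compare chains need.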
MultiTrace objects also support slicing. For example, the following call would return a new trace object without the first 1000 sampling iterations for all traces and variables.
>>> sliced_trace = trace[1000:]
The backend for the new trace is always NDArray, regardless of the type of original trace.
Loading a saved backend
Saved backends can be loaded using arviz.from_netcdf.
NumPy array trace backend
Store sampling values in memory as a NumPy array.
- class pymc3.backends.ndarray.NDArray(name=None, model=None, vars=None, test_point=None)
NDArray trace object
- name: str
Name of backend. This has no meaning for the NDArray backend.
- model: Model
If None, the model is taken from the with context.
- vars: list of variables
Sampling values will be stored for these variables. If None, model.unobserved_RVs is used.
Close the database backend.
This is called after sampling has finished.
- get_values(varname: str, burn=0, thin=1) -> numpy.ndarray
Get values from trace.
- varname: str
- burn: int
- thin: int
- Returns: a NumPy array
- point(idx) -> Dict[str, Any]
Return dictionary of point values at idx for current chain with variable names as keys.
- record(point, sampler_stats=None) -> None
Record results of a sampling iteration.
- point: dict
Values mapped to variable names
- setup(draws, chain, sampler_vars=None) -> None
Perform chain-specific setup.
- draws: int
Expected number of draws
- chain: int
- sampler_vars: list of dicts
Names and dtypes of the variables that are exported by the samplers.
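The methods above suggest a simple per-chain lifecycle: setup, then one record call per iteration, then close, after which get_values and point read the stored draws. A minimal sketch of that lifecycle follows. `ListBackend` is a hypothetical class written for illustration; it stores draws in Python lists, ignores sampler statistics, and is not how NDArray is implemented.

```python
class ListBackend:
    """Hypothetical in-memory backend for one chain (illustration only)."""

    def __init__(self, varnames):
        self.varnames = list(varnames)
        self.chain = None
        self.samples = {}

    def setup(self, draws, chain):
        # Chain-specific setup: remember the chain id and start empty storage.
        # `draws` is only the expected count, so it is not enforced here.
        self.chain = chain
        self.samples = {name: [] for name in self.varnames}

    def record(self, point):
        # Store one sampling iteration; `point` maps variable names to values.
        for name in self.varnames:
            self.samples[name].append(point[name])

    def close(self):
        # Nothing to release for in-memory storage.
        pass

    def get_values(self, varname, burn=0, thin=1):
        # Discard the first `burn` draws, then keep every `thin`-th draw.
        return self.samples[varname][burn::thin]

    def point(self, idx):
        # Values of every variable at draw `idx`, keyed by variable name.
        return {name: self.samples[name][idx] for name in self.varnames}


backend = ListBackend(["x", "y"])
backend.setup(draws=4, chain=0)
for i in range(4):
    backend.record({"x": i, "y": i * i})
backend.close()

print(backend.get_values("y", burn=1, thin=2))  # draws 1 and 3 of y
print(backend.point(2))                         # all variables at draw 2
```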
- pymc3.backends.ndarray.load_trace(directory: str, model=None) -> pymc3.backends.base.MultiTrace
Loads a multitrace that has been written to file.
The model used for the trace must be passed in, or the command must be run in a model context.
- directory: str
Path to a pymc3 serialized trace
- model: pm.Model (optional)
Model used to create the trace. Can also be inferred from context
- Returns: the pm.MultiTrace that was saved in the directory
- pymc3.backends.ndarray.point_list_to_multitrace(point_list: List[Dict[str, numpy.ndarray]], model: Optional[pymc3.model.Model] = None) -> pymc3.backends.base.MultiTrace
Transform a point list into a MultiTrace.
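Conceptually, a point list is row-oriented (one dict per draw), while a trace is queried per variable. The sketch below shows that reshaping with plain Python containers. `points_to_columns` is a hypothetical helper, and it assumes each point becomes one draw of a single chain and that every point has the same keys as the first one. It does not build a real MultiTrace.

```python
def points_to_columns(point_list):
    """Hypothetical helper: turn a list of point dicts into per-variable draw lists."""
    if not point_list:
        return {}
    # Variable names are taken from the first point (an assumption of this sketch).
    varnames = list(point_list[0])
    return {name: [point[name] for point in point_list] for name in varnames}


points = [
    {"x": 0.1, "y": [1, 2]},
    {"x": 0.2, "y": [3, 4]},
    {"x": 0.3, "y": [5, 6]},
]
columns = points_to_columns(points)
print(columns["x"])  # one entry per point, in order
print(columns["y"])
```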
- pymc3.backends.ndarray.save_trace(trace: pymc3.backends.base.MultiTrace, directory: Optional[str] = None, overwrite=False) -> str
Save multitrace to file.
TODO: Also save warnings.
This is a custom data format for PyMC3 traces. Each chain is saved in its own directory, and each directory contains a metadata JSON file and a compressed NumPy file. See https://docs.scipy.org/doc/numpy/neps/npy-format.html for more information about the NumPy file format.
- trace: pm.MultiTrace
trace to save to disk
- directory: str (optional)
path to a directory to save the trace
- overwrite: bool (default False)
whether to overwrite an existing directory.
- str, path to the directory where the trace was saved
Functions for converting traces into a table-like format
- pymc3.backends.tracetab.trace_to_dataframe(trace, chains=None, varnames=None, include_transformed=False)
Convert trace to pandas DataFrame.
- trace: NDArray trace
- chains: int or list of ints
Chains to include. If None, all chains are used. A single chain value can also be given.
- varnames: list of variable names
Variables to be included in the DataFrame. If None, all variables are included.
- include_transformed: boolean
If True, transformed variables will be included in the resulting DataFrame.
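A table-like format needs one column per scalar quantity, so vector-valued variables have to be split into several columns. The sketch below illustrates that flattening with plain lists. `to_table` is a hypothetical helper, and the `name__i` column naming is an assumption made for this example; the actual column names produced by trace_to_dataframe should be checked against its output.

```python
def to_table(samples):
    """Hypothetical helper: flatten {name: draws} into (header, rows)."""
    header, columns = [], []
    for name, draws in samples.items():
        if isinstance(draws[0], (list, tuple)):
            # Vector-valued variable: one column per element.
            for i in range(len(draws[0])):
                header.append(f"{name}__{i}")
                columns.append([draw[i] for draw in draws])
        else:
            # Scalar variable: a single column named after the variable.
            header.append(name)
            columns.append(list(draws))
    # Transpose columns into one row per draw.
    rows = [list(row) for row in zip(*columns)]
    return header, rows


samples = {"mu": [0.1, 0.2], "beta": [[1, 2], [3, 4]]}
header, rows = to_table(samples)
print(header)
for row in rows:
    print(row)
```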