A JointDistribution wishlist
We commonly make use of TensorFlow Probability's `JointDistributionCoroutineAutoBatched` to embed state transition models within larger Bayesian hierarchical DAGs, e.g. the IDDInf gemlib end-to-end example tutorial. `JointDistributionCoroutineAutoBatched` makes life easy when expressing probability models, and has a couple of very useful features:
- Each stochastic node in the Bayesian DAG is marked with a Python `yield` statement, drawing the reader's eye towards the structure of the DAG.
- The model object provides a `sample` method which optionally takes "pinned" values of any known parameters, as well as a `log_prob` method for evaluating the log probability density function.
- The `AutoBatched` version treats the output of the `sample` method as a single sample from the joint probability model, such that the `log_prob` function maps the entire sample to a scalar log-probability consistent with the Giry monad.
However, it lacks a couple of features that are important for Bayesians:
- Sampling over batches of parameters, as is required for estimating predictive distributions, is currently not possible due to seed handling within the `JointDistributionCoroutineAutoBatched` logic (see here). Instead, the user must manually `vmap` over (parameter) samples from a posterior distribution.
- It is currently not possible to "trace" values of interest from within a probability model, such as deterministic quantities that are computed between instantiations of stochastic quantities. The best we can do is to use a `tfd.Deterministic` "distribution", and then feed dummy values to the JD's `log_prob` method with a large value for the `atol` argument.
- We'd like to be able to directly parameterise the `JointDistributionCoroutineAutoBatched` object. In fact, it would be better if the decorator emitted a class, rather than an object of type `JointDistributionCoroutineAutoBatched`.
Given that the base `JointDistributionCoroutine` is a highly sophisticated class that tries to handle batching for inference purposes, it may be better to roll our own simplified class based on its principles, sacrificing the TFP batching semantics for our own more domain-specialised semantics.
Edited by Chris Jewell