A JointDistribution wishlist

We commonly make use of TensorFlow Probability's JointDistributionCoroutineAutoBatched to embed state transition models within larger Bayesian hierarchical DAGs, e.g. in the IDDInf gemlib end-to-end example tutorial. JDCoroutineAB makes life easy for expressing probability models, and has a few very useful features (sketched in the example after this list):

  1. Each stochastic node in the Bayesian DAG is marked with a Python yield statement, drawing the reader's eye towards the structure of the DAG.
  2. The model object provides a sample method which optionally takes "pinned" values of any known parameters, as well as a log_prob method for evaluating the log probability density function.
  3. The AutoBatched version treats the output of the sample method as a single sample from the joint probability model, such that the log_prob function maps the entire sample to a scalar log-probability consistent with the Giry monad.
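For concreteness, here is a minimal sketch of this pattern. The model and distribution choices are invented purely for illustration, using the JAX substrate of TFP:

```python
import jax
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions


# Each `yield` marks a stochastic node in the Bayesian DAG.
@tfd.JointDistributionCoroutineAutoBatched
def model():
    beta = yield tfd.Gamma(concentration=2.0, rate=2.0, name="beta")
    gamma = yield tfd.Gamma(concentration=2.0, rate=2.0, name="gamma")
    yield tfd.Exponential(rate=beta + gamma, name="infection_time")


# Draw a full joint sample...
sample = model.sample(seed=jax.random.PRNGKey(0))

# ...or pin known parameters and sample only the rest.
partial = model.sample(beta=0.5, seed=jax.random.PRNGKey(1))

# `log_prob` maps the entire joint sample to a scalar log-density.
lp = model.log_prob(sample)
```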

However, it lacks a few features that are important for Bayesians:

  1. Sampling over batches of parameters, as is required for estimating predictive distributions, is currently not possible due to seed handling within the JDCoroutineAB logic (see here). Instead, the user must manually vmap over (parameter) samples drawn from a posterior distribution, as in the first sketch after this list.
  2. It is currently not possible to "trace" values of interest from within a probability model, such as deterministic quantities that are computed between instantiations of stochastic quantities. The best we can do is to wrap such a quantity in a tfd.Deterministic "distribution" constructed with a large value for its atol argument, and then feed dummy values to the JD.log_prob method (second sketch below).
  3. We'd like to be able to directly parameterise the JointDistributionCoroutineAutoBatched object. In fact, it would be better if the decorator emitted a class, rather than an object of type JointDistributionCoroutineAutoBatched (a hypothetical sketch of such an interface closes this post).
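A sketch of the manual workaround for the first point, assuming the JAX substrate and a hypothetical vector of posterior draws for beta:

```python
import jax
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions


@tfd.JointDistributionCoroutineAutoBatched
def model():
    beta = yield tfd.Gamma(concentration=2.0, rate=2.0, name="beta")
    yield tfd.Exponential(rate=beta, name="infection_time")


# Hypothetical posterior draws for beta, shape [num_samples].
posterior_beta = jax.numpy.linspace(0.1, 1.0, num=100)

# `model.sample` cannot take a batch of betas directly, so we vmap one
# pinned sample per posterior draw, splitting the seed so that the
# draws are independent.
seeds = jax.random.split(jax.random.PRNGKey(0), posterior_beta.shape[0])
predictive = jax.vmap(lambda b, s: model.sample(beta=b, seed=s))(
    posterior_beta, seeds
)
```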
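And a sketch of the tfd.Deterministic workaround for the second point (again with an invented model): the large atol means an arbitrary dummy value fed to log_prob contributes 0 to the log-density rather than -inf.

```python
import jax
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions


@tfd.JointDistributionCoroutineAutoBatched
def model():
    beta = yield tfd.Gamma(concentration=2.0, rate=2.0, name="beta")
    # "Trace" a derived quantity by wrapping it in a Deterministic node;
    # the huge tolerance makes its log_prob contribution 0 for any
    # dummy value supplied later.
    yield tfd.Deterministic(loc=2.0 * beta, atol=1e9, name="doubled_beta")
    yield tfd.Exponential(rate=beta, name="infection_time")


# The traced value appears in samples...
sample = model.sample(seed=jax.random.PRNGKey(0))

# ...and an arbitrary dummy value still evaluates cleanly.
lp = model.log_prob(sample._replace(doubled_beta=0.0))
```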

Given that the base JointDistributionCoroutine is a highly sophisticated class that tries to handle batching for inference purposes, it may be better to roll our own simplified class based on the same principles, one that trades the TFP batching semantics for our own more domain-specialised semantics.
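To make the third wish concrete, here is a hypothetical sketch of what a class-emitting decorator might look like. Nothing below exists in TFP today; the decorator, class, and model names are all invented.

```python
import functools

import jax
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions


def parameterised_jd(model_fn):
    """Hypothetical decorator: emits a class whose constructor takes
    model parameters, rather than a ready-built JD object."""

    class ParameterisedJD:
        def __init__(self, **params):
            # Build the underlying JD with the parameters closed over.
            self._jd = tfd.JointDistributionCoroutineAutoBatched(
                functools.partial(model_fn, **params)
            )

        def __getattr__(self, name):
            # Delegate sample/log_prob/etc. to the wrapped JD.
            return getattr(self._jd, name)

    return ParameterisedJD


@parameterised_jd
def epidemic_model(infection_rate):
    incidence = yield tfd.Poisson(rate=infection_rate, name="incidence")
    yield tfd.Poisson(rate=0.1 * incidence, name="deaths")


# The decorator now emits a class: instantiate with parameters, then use.
model = epidemic_model(infection_rate=0.3)
sample = model.sample(seed=jax.random.PRNGKey(0))
```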
