
Enforce static incidence matrix in *TimeStateTransitionModel

Despite the type annotations, a user can pass a jax.Array as the incidence matrix to the constructor of DiscreteTimeStateTransitionModel. Unfortunately, the computation of the source and destination states for each transition arrow then depends on the values of this array. If those values are dynamic (as with jax.Array), we get a JAX tracing error.

See here

A solution would be to explicitly coerce the incidence matrix to a numpy.ndarray as a defensive measure.
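A minimal sketch of what that coercion could look like, using NumPy only. The helper names `coerce_incidence_matrix` and `source_and_destination_states` are hypothetical and not part of the library, and the layout (rows are states, columns are transitions, with -1 marking the source and +1 the destination) is an assumption made for illustration.

```python
import numpy as np


def coerce_incidence_matrix(incidence_matrix):
    """Hypothetical helper: return a static, read-only NumPy copy of the input."""
    # np.array copies the input into a concrete numpy.ndarray, so later
    # structural computations see plain host-side values.
    matrix = np.array(incidence_matrix)
    if matrix.ndim != 2:
        raise ValueError("incidence matrix must be 2-dimensional")
    # Mark the copy read-only so it cannot be mutated after construction.
    matrix.setflags(write=False)
    return matrix


def source_and_destination_states(incidence_matrix):
    """Hypothetical helper: find the source and destination state per transition."""
    matrix = coerce_incidence_matrix(incidence_matrix)
    # Assumed layout: in each column, -1 marks the source state and
    # +1 marks the destination state.
    sources = np.argmin(matrix, axis=0).tolist()
    destinations = np.argmax(matrix, axis=0).tolist()
    return sources, destinations


# SIR-style example: S -> I and I -> R.
sir = [[-1, 0],
       [1, -1],
       [0, 1]]
sources, destinations = source_and_destination_states(sir)
```

Coercing once in the constructor would keep the structural computation on static values regardless of what array type the caller supplies. If the input is a traced value rather than a concrete array, the conversion itself should fail, which at least moves the error to the point where the matrix is passed in.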