JAX pointwise jacobian compile time
In `pyadjoint_utils/jax_adjoint/jax.py`, there are a bunch of functions named `_compiled_stack_XXX`, such as
```python
@partial(jax.jit, static_argnums=(1,))
def _compiled_stack_j(point_jacs, j):
    return jnp.stack(tuple(p[j] for p in point_jacs))
```
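This results in very long compile times for large meshes. The Python-level loop in the generator `(p[j] for p in point_jacs)` is unrolled when JAX traces the function, so the program handed to XLA contains one indexing operation per entry of `point_jacs`, and its size (and compile time) grows with the mesh. We should replace these functions with ones that use JAX's loop primitives (e.g. `jax.lax.fori_loop` or `jax.lax.scan`) to improve compile times.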
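A quick way to see the unrolling, as a sketch: `jax.make_jaxpr` shows the program JAX traces, and the number of equations in it grows with the length of `point_jacs`. The helpers `stack_j_unjitted` and `count_eqns`, and the `(3, 4)` shapes, are illustrative assumptions, not code from the repository.

```python
import jax
import jax.numpy as jnp


def stack_j_unjitted(point_jacs, j):
    # Same body as _compiled_stack_j, without the jit wrapper,
    # so make_jaxpr exposes the traced operations directly.
    return jnp.stack(tuple(p[j] for p in point_jacs))


def count_eqns(n_points, j=0):
    # n_points dummy "point Jacobians", each of shape (3, 4).
    point_jacs = tuple(jnp.ones((3, 4)) for _ in range(n_points))
    closed = jax.make_jaxpr(lambda pj: stack_j_unjitted(pj, j))(point_jacs)
    return len(closed.jaxpr.eqns)


small = count_eqns(2)
large = count_eqns(50)
# `large` is much bigger than `small`: the generator was unrolled at trace
# time, emitting at least one indexing operation per point Jacobian.
print(small, large)
```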
The reason this isn't straightforward is that, inside a JAX-compiled loop, the loop index is a tracer (of type `Traced<ShapedArray(int64[]):JaxprTrace(level=1/1)>`) rather than a Python `int`. Tracers do not support the `__index__` method, so the loop index can't be used to index a tuple or a list (unless we unroll the loop in Python, as we're doing currently).
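A minimal reproduction, assuming a small tuple of arrays standing in for `point_jacs`: inside `jax.lax.fori_loop` the loop index is a tracer, and using it to index a Python tuple raises a `TypeError` (recent JAX versions raise `TracerIntegerConversionError`, which subclasses `TypeError`). The names `body` and `try_traced_tuple_index` are hypothetical.

```python
import jax
import jax.numpy as jnp

# Stand-in for point_jacs: arrays with different leading sizes.
point_jacs = (jnp.ones((3, 2)), 2.0 * jnp.ones((5, 2)))


def body(i, acc):
    # `i` is a tracer here, so tuple indexing calls i.__index__(), which fails.
    return acc + point_jacs[i][0]


def try_traced_tuple_index():
    # Returns the name of the raised exception, or None if indexing worked.
    try:
        jax.lax.fori_loop(0, len(point_jacs), body, jnp.zeros(2))
    except TypeError as err:
        return type(err).__name__
    return None


print(try_traced_tuple_index())
```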
We might choose to pad the `point_jacs` tuple of arrays so that we can call `jnp.asarray()` on it (right now we can't, because it contains arrays of different sizes), but that would also be awkward and would probably require maintaining another large array of indices to unpack the padded array into the correctly-shaped result.