
torch adjoint

Emily Jakobs requested to merge emily/torch-adjoint into main

This MR adds a pyadjoint_utils.torch_adjoint module, intended as the PyTorch counterpart of our jax_adjoint module for JAX: it introduces a function overload_torch() that does for PyTorch functions what overload_jax() does for JAX functions, along with an overloaded torch.Tensor type.
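
To give a feel for the intended usage, something like the following should work once the module settles (my_residual is a made-up example function, and the exact overload_torch call signature is my assumption; jit=True is the keyword discussed further down):

```python
import torch
from pyadjoint_utils.torch_adjoint import overload_torch


def my_residual(x, y):
    # Plain PyTorch function we want recorded on the pyadjoint tape.
    return torch.sum(x * y) + torch.norm(x)


# jit=True is the keyword discussed below (it replaces the old nojit parameter).
overloaded_residual = overload_torch(my_residual, jit=True)
```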

More to come as I write this module over the next few weeks/months (I'm mostly writing it to learn torch internals better for other reasons, so progress might not be fast). Feel free to contribute if you have ideas or just want this done faster than I'm personally working on it.

CC @mikemccabe since I know this is something you've asked about before

Okay, here is how everything works: I have created the directory pyadjoint_utils/numpy_backend, which contains a base class, NumpyBackend (in base.py), that defines a numpy backend as a class implementing the numpy API plus the following functions (a rough sketch of such a backend follows the list):

  1. jit(self, f, *args, **kwargs) (implements an equivalent of the jax.jit API)
  2. vmap(self, f, in_axes=0, out_axes=0) (implements the jax.vmap API)
  3. init_rng(self, seed=None) (seeds the RNG)
  4. rand(self, *args, **kwargs) (implements the numpy.random.rand API)
  5. script_if_tracing(self, f) (implements torch.jit.script_if_tracing, which is a no-op in JAX and plain numpy)
  6. vjp(self, f, *inputs) (implements jax.vjp-like API)
  7. array(self, x) (like jnp.array)
  8. concatenate(self, tensors, dim=0) (implements jnp.concatenate)
  9. einsum(self, *args, optimize=None, **kwargs) (papers over the fact that JAX's einsum has an optimize kwarg, which we want to be True, while torch's einsum does not)
  10. index_update(self, x, idx, y) (implements x[idx] = y and returns the updated x, since JAX arrays are immutable)
  11. tensordot(self, *args, axes=0, **kwargs) (papers over the fact that torch calls the axes argument dims)

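To make the contract concrete, here is a rough sketch of what a plain-numpy backend implementing a few of these methods could look like. The method names follow the list above, but the class name and bodies are my own illustrative versions, not the actual base.py or numpy implementation:

```python
import numpy as np


class PlainNumpyBackend:
    """Illustrative plain-numpy backend; jit/script_if_tracing are no-ops here."""

    def __init__(self, seed=None):
        self.init_rng(seed)

    def jit(self, f, *args, **kwargs):
        # Nothing to compile in plain numpy, so just hand the function back.
        return f

    def script_if_tracing(self, f):
        return f

    def init_rng(self, seed=None):
        self._rng = np.random.default_rng(seed)

    def rand(self, *args):
        return self._rng.random(size=args or None)

    def array(self, x):
        return np.asarray(x)

    def concatenate(self, tensors, dim=0):
        return np.concatenate(tensors, axis=dim)

    def einsum(self, *args, optimize=None, **kwargs):
        # JAX wants optimize=True; numpy accepts it too, torch would not.
        return np.einsum(*args, optimize=bool(optimize), **kwargs)

    def index_update(self, x, idx, y):
        # numpy arrays are mutable, so update in place; the JAX backend has to
        # return a freshly constructed array instead.
        x[idx] = y
        return x
```
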
All calls to jnp.xxx, jax.xxx, or torch.xxx in the relevant parts of the crikit codebase (such as crikit.invariants and crikit.cr.CR) are thus replaced with backend.xxx, so they can run with any backend without changing the source code. Additionally, the jax_utils.py and torch_utils.py files provide JAXFunctionJITTracer and TorchFunctionJITTracer, which both implement trace_and_overload(self, f, args, jit=True, argnums=None, pointwise=None, out_pointwise=None, strict=False, check=False). This method optionally traces and compiles a function given inputs args (iff jit == True) and always overloads it. Calls to overload_jax and overload_torch are replaced with calls to compiler.trace_and_overload on an instance (named compiler) of JAXFunctionJITTracer or TorchFunctionJITTracer.
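
As a rough usage sketch (I'm guessing at the import path and constructor here, and inner/example_inputs are made-up placeholders; the keyword arguments follow the signature above):

```python
import torch
from pyadjoint_utils.torch_utils import TorchFunctionJITTracer  # assumed path

compiler = TorchFunctionJITTracer()  # assumed no-argument constructor


def inner(x):
    return torch.tanh(x) * x


example_inputs = (torch.ones(4),)

# Trace and compile (because jit=True), then overload, in one call instead of
# calling overload_torch directly.
overloaded_inner = compiler.trace_and_overload(inner, example_inputs, jit=True)
```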

The user can now create a crikit.cr.CR with either a jax or a torch inner function, simply by passing backend='jax' or backend='torch' to CR.__init__().

The rest of the changes are mostly just adding torch versions of the various JAX classes (e.g. TorchTensor is the torch version of JAXArrays) and functions (e.g. overload_torch to go along with overload_jax).

I also changed the API for overload_jax and overload_torch, replacing the nojit=False parameter with jit=True (@jedbrown thinks jit is a better API than nojit, and I agree).

I also fixed an issue in crikit.covering.get_map(), which contained the following erroneous code:

if source == target:
    return IdentityPointMap(s, t)

which has been replaced with

if source == target:
    return IdentityPointMap(source, target)

Additionally, we now have limited support for ONNX CRs (see test_2d_torch_saved_cr in tests/crikit/test_invariant_cr.py), with the constraint that we can't vmap or differentiate them. Until we figure that out, we will have trouble using those trained ONNX CRs in actual simulations, since we can't get the Jacobian or use them on a different mesh size than they were saved with (in order to trace the model and convert it to ONNX, we have to fix the input size). But the framework is there, so it's likely just a matter of getting functorch.jacobian to behave as we want (and writing a traceable pointwise Jacobian implementation for efficiency). Also, if we want to export the entire CR and not just the inner function, we'll have to resolve this issue I opened with PyTorch.
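
To make the fixed-input-size constraint concrete, export currently goes through tracing, roughly like this (InnerCR and the shapes are made-up stand-ins, not the actual test code):

```python
import torch


class InnerCR(torch.nn.Module):
    # Stand-in for a CR's inner pointwise function.
    def forward(self, x):
        return torch.tanh(x) @ x.transpose(-1, -2)


model = InnerCR()
# Tracing for ONNX export bakes in this example shape, which is why the saved
# CR can't currently be reused on a different mesh size.
example_input = torch.zeros(128, 3, 3)
torch.onnx.export(model, (example_input,), "inner_cr.onnx")
```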

