torch adjoint
This MR adds a `pyadjoint_utils.torch_adjoint` module, intended to be a PyTorch version of what our `jax_adjoint` module is for JAX, by introducing a function `overload_torch()` that does for PyTorch functions what `overload_jax()` does for JAX functions. It also has an overloaded `torch.Tensor` object.
More to come as I write this module over the next few weeks/months (I'm mostly writing this to learn the internals of torch better for other reasons, so it might not move very fast). Feel free to contribute if you have ideas or just want this done faster than I'm personally working on it.
CC @mikemccabe since I know this is something you've asked about before
Okay, here is how everything works: I have created the directory `pyadjoint_utils/numpy_backend`, which contains a base class, `NumpyBackend` (in `base.py`), that defines a numpy backend as a class that implements the numpy API plus the following functions (see the sketch after this list):
- `jit(self, f, *args, **kwargs)` (implements an equivalent of the `jax.jit` API)
- `vmap(self, f, in_axes=0, out_axes=0)` (implements the `jax.vmap` API)
- `init_rng(self, seed=None)` (seeds the RNG)
- `rand(self, *args, **kwargs)` (implements the `numpy.random.rand` API)
- `script_if_tracing(self, f)` (implements `torch.jit.script_if_tracing`, which is a no-op in JAX and plain numpy)
- `vjp(self, f, *inputs)` (implements a `jax.vjp`-like API)
- `array(self, x)` (like `jnp.array`)
- `concatenate(self, tensors, dim=0)` (implements `jnp.concatenate`)
- `einsum(self, *args, optimize=None, **kwargs)` (papers over the fact that JAX's einsum has an `optimize` kwarg that we want to be `True`, but torch's does not)
- `index_update(self, x, idx, y)` (implements `x[idx] = y` and returns the updated `x`, since JAX tensors are immutable)
- `tensordot(self, *args, axes=0, **kwargs)` (papers over the fact that torch names `axes` `dims`)
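To make the papering-over concrete, here is a minimal sketch of what a torch implementation of a few of these methods could look like. The class name and details are illustrative, not the actual code in this MR:

```python
import torch

# Illustrative sketch only; the real backend classes under
# pyadjoint_utils/numpy_backend may differ in detail.
class TorchBackendSketch:
    def index_update(self, x, idx, y):
        # JAX-style functional update: copy, assign, return the copy
        out = x.clone()
        out[idx] = y
        return out

    def tensordot(self, *args, axes=0, **kwargs):
        # torch.tensordot calls this argument `dims`, not `axes`
        return torch.tensordot(*args, dims=axes, **kwargs)

    def einsum(self, *args, optimize=None, **kwargs):
        # torch.einsum has no `optimize` kwarg, so accept and drop it
        return torch.einsum(*args, **kwargs)
```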
All calls to `jnp.xxx`, `jax.xxx`, or `torch.xxx` in the relevant parts of the crikit codebase (such as `crikit.invariants` and `crikit.cr.CR`) are thus replaced with `backend.xxx`, so that they can be called with any backend without changing the source code. Additionally, the `jax_utils.py` and `torch_utils.py` files provide `JAXFunctionJITTracer` and `TorchFunctionJITTracer`, which both implement `trace_and_overload(self, f, args, jit=True, argnums=None, pointwise=None, out_pointwise=None, strict=False, check=False)`, which optionally traces and compiles (iff `jit == True`) a function given inputs `args`, and overloads that function (not optional). We replace calls to `overload_jax` and `overload_torch` with calls to `compiler.trace_and_overload` for a given instance (named `compiler`) of a `JAXFunctionJITTracer` or `TorchFunctionJITTracer`.
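For illustration, a call site might look like the following. The import path, the constructor arguments, and passing `args` as a tuple are assumptions; the `trace_and_overload` signature is the one listed above:

```python
import torch
from pyadjoint_utils.torch_utils import TorchFunctionJITTracer  # assumed path

def inner(x):
    return torch.sin(x) * x

compiler = TorchFunctionJITTracer()  # constructor args, if any, are assumed
x = torch.ones(4)
# Trace/compile `inner` at this example input (jit=True) and overload it
overloaded = compiler.trace_and_overload(inner, (x,), jit=True)
```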
The user can now create a `crikit.cr.CR` with either a JAX or a torch inner function, simply by passing `backend='jax'` or `backend='torch'` to `CR.__init__()`.
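Schematically (the other constructor arguments are placeholders here, since only the `backend` kwarg is described by this MR):

```python
from crikit.cr import CR

# Placeholder positional arguments; the backend kwarg is the point.
cr_jax = CR(..., backend='jax')      # inner function written with jax.numpy
cr_torch = CR(..., backend='torch')  # inner function written with torch
```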
The rest of the changes are mostly just adding torch versions of the various JAX classes (e.g. `TorchTensor` is the torch version of `JAXArrays`) and functions (e.g. `overload_torch` to go along with `overload_jax`).
I also changed the API for `overload_jax` and `overload_torch`, replacing the `nojit=False` param with `jit=True` (@jedbrown thinks that `jit` is a better API than `nojit`, and I agree).
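In other words (the import path is an assumption, and other parameters are elided):

```python
from pyadjoint_utils.jax_adjoint import overload_jax  # assumed import path
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

# previously: overloaded = overload_jax(f, nojit=False)
overloaded = overload_jax(f, jit=True)  # new spelling of the same behavior
```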
However, I also fixed an issue in `crikit.covering.get_map()`, which had the following erroneous line of code:

```python
if source == target:
    return IdentityPointMap(s, t)
```

which has been replaced with

```python
if source == target:
    return IdentityPointMap(source, target)
```
Additionally, we now have limited support for ONNX CRs (see `test_2d_torch_saved_cr` in `tests/crikit/test_invariant_cr.py`), with the constraint that we can't `vmap` or differentiate them. Until we can figure that out, we will have trouble using those trained ONNX CRs in actual simulations, since we can't get the Jacobian or use them with a different mesh size than they were saved with (in order to trace the model and convert it to ONNX, we have to fix the input size). But the framework is there, so it's likely just a matter of getting `functorch.jacobian` to behave as we want (and of writing a traceable pointwise Jacobian implementation for efficiency). Also, if we want to export the entire CR and not just the inner function, we'll have to resolve this issue I opened with PyTorch.
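For reference, the fixed-input-size constraint comes from tracing at export time. A minimal sketch, where the module is a stand-in rather than the actual CR code:

```python
import torch

class InnerCR(torch.nn.Module):
    # Stand-in for a CR's traced inner function
    def forward(self, invariants):
        return invariants * invariants

example = torch.randn(128, 3)  # the batch/mesh size gets baked in here
torch.onnx.export(InnerCR(), (example,), "inner_cr.onnx")
# The exported graph is specialized to inputs of shape (128, 3).
```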