Skip to content

Add TAO-pytorch interface

Hong Zhang requested to merge denera/tao-pytorch-optimizer-bindings-rebased into main

Copied from the old MR:

Provides a TAOtorch() object in Python that makes TAO solvers look like PyTorch Optimizer objects and work for standard PyTorch training workflows.

The interface is tested on a canonical MNIST classification problem.

Limitations:

  • PyTorch has to run in whatever precision PETSc has been compiled with.
  • All globalization is disabled, and TAOtorch implements its own adaptive learning rate scale derived from the AMSgrad with options to disable or set a user-defined scale.
Edited by Hong Zhang

Merge request reports