Major JAX efficiency upgrade by building from source
I've found that building JAX from source can give a major performance improvement: in my experience, compiled JAX code executes in about half the time. The build is easy to do (much easier than building TensorFlow from source), and simple enough that I'm fairly confident building JAX from source in our setup.py wouldn't cause problems.

However, the build does take a while (20+ minutes on my laptop), and compiled JAX code is probably not the bottleneck for most solves, especially larger ones. So I'm not sure whether we should build JAX from source ourselves in setup.py, or just add documentation recommending that users do it themselves, along with instructions for how to do it. I would probably lean towards the second option (install JAX from PyPI by default and document how to build from source), but building JAX at install time seems viable too. Does anybody else have thoughts on this?
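To check whether a source build actually helps for a given workload, one option is to run the same timing script once with the PyPI jaxlib and once with the source-built jaxlib, then compare. Below is a minimal sketch; `example_kernel` and `median_runtime` are hypothetical names, and the matrix workload is only a placeholder for a representative compiled function from our own code. It excludes the first call (which includes tracing and XLA compilation) and calls `block_until_ready()` because JAX dispatches work asynchronously.

```python
import statistics
import timeit

import jax
import jax.numpy as jnp
import jaxlib


@jax.jit
def example_kernel(x):
    # Hypothetical placeholder workload; swap in a real compiled function from the codebase.
    return jnp.sum(jnp.tanh(x @ x.T))


def median_runtime(fn, *args, repeats=20):
    """Median wall-clock time in seconds of fn(*args), excluding the compile step."""
    # The first call traces and compiles the function, so it is not timed.
    fn(*args).block_until_ready()
    times = []
    for _ in range(repeats):
        start = timeit.default_timer()
        # Wait for the asynchronous computation to finish before stopping the clock.
        fn(*args).block_until_ready()
        times.append(timeit.default_timer() - start)
    return statistics.median(times)


if __name__ == "__main__":
    x = jnp.ones((512, 512), dtype=jnp.float32)
    # Print versions so results from the PyPI and source-built installs can be told apart.
    print("jax", jax.__version__, "jaxlib", jaxlib.__version__)
    print("median runtime: %.6f s" % median_runtime(example_kernel, x))
```

Running it in two separate environments (one per jaxlib install) gives a like-for-like comparison; the ratio will depend on the hardware and the workload, so it may not match the roughly 2x I've seen.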