A little more involved than I first thought, since `np.<whatever>` operations have to be replaced with `jnp.<whatever>`. But aside from a few caveats, like not using in-place operations, most Layer implementations would be otherwise identical, so this could be added as a backend (and would be much simpler than TF: just define `evaluate_jax` and still use scipy, but with a hook to use the gradient version). Without configuring GPU usage this would still be slower than TF, but it may be a good intermediate option that's still much faster than vanilla scipy/numpy and easier for new users to implement.
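For concreteness, a minimal sketch of what the `np` -> `jnp` swap looks like for a single evaluate (the function names and the rectification step are just illustrative, not an existing Layer):

```python
import numpy as np
import jax.numpy as jnp

def evaluate(coefficients, x):
    # Existing numpy-style evaluate, with an in-place update.
    out = np.dot(x, coefficients)
    out[out < 0] = 0.0              # in-place assignment: fine for numpy
    return out

def evaluate_jax(coefficients, x):
    # Same math with jax.numpy. JAX arrays are immutable, so the in-place
    # masking becomes jnp.where (or, more generally, out.at[...].set(...)).
    out = jnp.dot(x, coefficients)
    out = jnp.where(out < 0, 0.0, out)
    return out
```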
Separately, try adding `numba` to the standard scipy evaluates (http://numba.pydata.org/). Unlike JAX, it's supposed to work with standard numpy, so the improvements may be simple to integrate.
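If it works as advertised, the change might be as small as decorating an existing evaluate. A hypothetical example (names made up), not a tested integration:

```python
import numpy as np
from numba import njit

# numba compiles plain-numpy code directly (no np -> jnp rewrite needed),
# so in principle an existing evaluate can just be decorated.
@njit(cache=True)
def evaluate_weighted_sum(coefficients, x):
    # x: (time, channels), coefficients: (channels,)
    return np.dot(x, coefficients)
```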
Notes on Numba: doesn't seem like it will be worth it for most of the core Layers. Numpy functions that are already heavily optimized (like `np.dot` and `np.convolve`) perform about as well as or better than their numba-compiled versions. We might see a big benefit for slow Layers with lots of for-loops and plain Python (like STP), but that still needs to be tested.
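The kind of code that should benefit is a sample-by-sample recurrence that can't be vectorized. A rough sketch below, where the update rule is a stand-in, not the real STP equations:

```python
import numpy as np
from numba import njit

@njit(cache=True)
def stp_like_recurrence(x, u, tau):
    # Placeholder depression-style update (NOT the actual STP math):
    # each step depends on the previous state, so the loop can't be
    # replaced with a single vectorized numpy call -- exactly the case
    # where numba compilation should pay off.
    out = np.empty_like(x)
    d = 1.0
    for t in range(x.size):
        d += (1.0 - d) / tau - u * d * x[t]
        d = min(max(d, 0.0), 1.0)
        out[t] = x[t] * d
    return out
```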
Try out the suggestion here re: using the JAX library to compute the cost function gradient and providing that information to scipy.
(for their specific example, quoting a ~5000x speedup)
https://stackoverflow.com/questions/68507176/faster-scipy-optimizations
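The gist of that thread, adapted to a generic least-squares cost (the model and cost here are placeholders, not the fitter's actual internals):

```python
import numpy as np
import jax
import jax.numpy as jnp
from scipy.optimize import minimize

# scipy works in float64; JAX defaults to float32.
jax.config.update("jax_enable_x64", True)

def cost(params, x, y):
    # Placeholder linear model + MSE loss, standing in for whatever
    # evaluate + metric the fitter would normally assemble.
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y) ** 2)

# Compile the cost and its gradient once, then reuse them every iteration.
cost_jit = jax.jit(cost)
grad_jit = jax.jit(jax.grad(cost))

def fit(x, y, x0):
    # Passing the analytic gradient via `jac` replaces scipy's
    # finite-difference estimates -- that's where the speedup comes from.
    result = minimize(
        lambda p: float(cost_jit(p, x, y)),
        x0,
        jac=lambda p: np.asarray(grad_jit(p, x, y)),
        method="L-BFGS-B",
    )
    return result.x
```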