
speed up scipy optimize #25

Open · jacobpennington opened this issue Aug 10, 2022 · 2 comments
Labels: enhancement (New feature or request)

Comments

jacobpennington (Owner) commented Aug 10, 2022

Try out the suggestion here re: using the JAX library to compute the cost function gradient and providing that information to scipy (for their specific example, they quote a ~5000x speedup):
https://stackoverflow.com/questions/68507176/faster-scipy-optimizations

A little more involved than I first thought, since np.<whatever> operations have to be replaced with jnp.<whatever>. But aside from a few caveats, like not using in-place operations, most Layer implementations would otherwise be identical, so this could be added as a backend (and would be much simpler than TF: just define evaluate_jax and still use scipy, but with a hook to use the gradient version). Without configuring GPU usage this would still be slower than TF, but it may be a good intermediate option that's still much faster than vanilla scipy/numpy and easier for new users to implement.
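A minimal sketch of the pattern from that Stack Overflow answer (the cost function here is a toy stand-in, not a real Layer; only the jax.grad → jac= wiring is the point):

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize

def cost(params, x, y):
    # Toy cost written with jnp ops so JAX can trace/differentiate it;
    # stands in for a model evaluate + error metric.
    pred = params[0] * x + params[1]
    return jnp.mean((pred - y) ** 2)

# JIT-compile the cost and its gradient once, reuse every iteration.
cost_jit = jax.jit(cost)
grad_jit = jax.jit(jax.grad(cost))

def grad_np(params, x, y):
    # scipy expects float64 numpy arrays; JAX defaults to float32.
    return np.asarray(grad_jit(params, x, y), dtype=np.float64)

x = np.linspace(0, 1, 100)
y = 2.0 * x + 1.0

result = minimize(cost_jit, x0=np.zeros(2), args=(x, y),
                  jac=grad_np, method='L-BFGS-B')
print(result.x)  # ~[2., 1.]
```

The key line is `jac=grad_np`: scipy no longer has to approximate the gradient by finite differences, which is where the reported speedup comes from.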

Separately, try adding numba to the standard scipy evaluates (http://numba.pydata.org/). Unlike JAX, it's supposed to work with standard numpy, so integrating the improvements may be simple. A sketch of what that could look like is below.
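A minimal sketch of the numba idea (the recursive loop is a made-up stand-in, not an actual Layer):

```python
import numpy as np
from numba import njit

# @njit compiles this plain-python loop to machine code on first call;
# the numpy calls inside stay as-is.
@njit
def evaluate_loop(x, a):
    out = np.empty_like(x)
    state = 0.0
    for i in range(x.shape[0]):
        state = a * state + x[i]   # per-sample update, hard to vectorize
        out[i] = state
    return out

x = np.random.randn(10000)
print(evaluate_loop(x, 0.9)[:5])
```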

jacobpennington added the enhancement (New feature or request) label on Aug 10, 2022
jacobpennington (Owner, Author) commented:
Follow-up note: looks like there's even a jax.scipy.optimize.minimize function, so I'm hopeful this would slot in without too many headaches.
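For reference, a quick sketch of that function; note that, per the JAX docs, it currently only supports method='BFGS', and it derives the gradient automatically so no jac argument is needed:

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def cost(params, x, y):
    pred = params[0] * x + params[1]
    return jnp.mean((pred - y) ** 2)

x = jnp.linspace(0, 1, 100)
y = 2.0 * x + 1.0

# The whole solve stays inside JAX (jit/grad applied internally).
result = minimize(cost, x0=jnp.zeros(2), args=(x, y), method='BFGS')
print(result.x)  # ~[2., 1.]
```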

jacobpennington (Owner, Author) commented:
Notes on Numba: it doesn't seem like it will be worth it for most of the core Layers. Numpy functions that are already heavily optimized (like np.dot and np.convolve) perform about as well or better without it. We might see a big benefit for slow Layers with lots of for loops and plain python (like STP), but that still needs testing, e.g. with a micro-benchmark like the one below.
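A minimal way to run that test (the loop is an illustrative STP-style recursive update, not the real Layer):

```python
import timeit
import numpy as np
from numba import njit

def stp_like(x, tau):
    # Stand-in for an STP-style per-sample recursive update.
    out = np.empty_like(x)
    u = 1.0
    for i in range(x.shape[0]):
        u += (1.0 - u) / tau - u * x[i]
        out[i] = u * x[i]
    return out

stp_like_nb = njit(stp_like)

x = np.random.rand(100_000)
stp_like_nb(x, 50.0)  # trigger compilation before timing

print('python:', timeit.timeit(lambda: stp_like(x, 50.0), number=10))
print('numba: ', timeit.timeit(lambda: stp_like_nb(x, 50.0), number=10))
```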
