-
Notifications
You must be signed in to change notification settings - Fork 78
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add dtype parameter to External #274
Conversation
@AntonioMirarchi do you see any issues with this? It should be retrocompatible. |
From a users point of view maybe we should consider to use "single" and "double" precision as init args and then a dict to convert them. I use the external module from a yaml file in torchmd or openMM and I cannot specify a torch dtype. Furthermore, usually platform as torchmd and openMM have their own control over the precision so I don't know if maybe some assertion is needed in order to have compatibility between model and platform (but of course this is out of this PR). |
With OpenMM one should not use External. TorchForce and TorchMD_Net get along well. I made it so that dtype can be passed as a string. So in a yaml you can put: External:
dtype: float32 Or any valid torch.dtype (double, float64, half,...) |
Niceee, I will try to run something to be sure but seems fine! |
Allow External to take kwargs for load_model
LGTM, I can run sims with it using torchmd |
This allows to use External with any dtype.
In addition to being useful for TorchMD integration, I normally use External to put models into CUDA graphs, as it handles all the nuisances transparently.