Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add dtype parameter to External #274

Merged
merged 11 commits into from
Feb 19, 2024
Merged
70 changes: 33 additions & 37 deletions tests/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,49 +12,45 @@
from utils import create_example_batch


def test_compare_forward():
    """External.calculate must reproduce a direct forward pass of the same checkpoint.

    Loads the example checkpoint both through the External calculator wrapper
    and via load_model directly, runs one configuration on identical inputs,
    and checks that energies and forces agree numerically.
    """
    ckpt_path = join(dirname(dirname(__file__)), "tests", "example.ckpt")
    z, pos, _ = create_example_batch(multiple_batches=False)

    # Wrapper under test vs. reference model loaded directly.
    external = External(ckpt_path, z.unsqueeze(0))
    reference = load_model(ckpt_path, derivative=True)

    e_calc, f_calc = external.calculate(pos, None)
    e_pred, f_pred = reference(z, pos)

    assert_allclose(e_calc, e_pred)
    # External returns forces with a leading batch dimension; add it to the reference.
    assert_allclose(f_calc, f_pred.unsqueeze(0))

@pytest.mark.parametrize("box", [None, torch.eye(3)])
def test_compare_forward_cuda_graph(box):
if not torch.cuda.is_available():
@pytest.mark.parametrize("use_cuda_graphs", [True, False])
def test_compare_forward(box, use_cuda_graphs):
if use_cuda_graphs and not torch.cuda.is_available():
pytest.skip("CUDA not available")
checkpoint = join(dirname(dirname(__file__)), "tests", "example.ckpt")
args = {"model": "tensornet",
"embedding_dimension": 128,
"num_layers": 2,
"num_rbf": 32,
"rbf_type": "expnorm",
"trainable_rbf": False,
"activation": "silu",
"cutoff_lower": 0.0,
"cutoff_upper": 5.0,
"max_z": 100,
"max_num_neighbors": 128,
"equivariance_invariance_group": "O(3)",
"prior_model": None,
"atom_filter": -1,
"derivative": True,
"output_model": "Scalar",
"reduce_op": "sum",
"precision": 32 }
model = create_model(args).to(device="cuda")
args = {
"model": "tensornet",
"embedding_dimension": 128,
"num_layers": 2,
"num_rbf": 32,
"rbf_type": "expnorm",
"trainable_rbf": False,
"activation": "silu",
"cutoff_lower": 0.0,
"cutoff_upper": 5.0,
"max_z": 100,
"max_num_neighbors": 128,
"equivariance_invariance_group": "O(3)",
"prior_model": None,
"atom_filter": -1,
"derivative": True,
"output_model": "Scalar",
"reduce_op": "sum",
"precision": 32,
}
device = "cpu" if not use_cuda_graphs else "cuda"
model = create_model(args).to(device=device)
z, pos, _ = create_example_batch(multiple_batches=False)
z = z.to("cuda")
pos = pos.to("cuda")
calc = External(checkpoint, z.unsqueeze(0), use_cuda_graph=False, device="cuda")
calc_graph = External(checkpoint, z.unsqueeze(0), use_cuda_graph=True, device="cuda")
z = z.to(device)
pos = pos.to(device)
calc = External(checkpoint, z.unsqueeze(0), use_cuda_graph=False, device=device)
calc_graph = External(
checkpoint, z.unsqueeze(0), use_cuda_graph=use_cuda_graphs, device=device
)
calc.model = model
calc_graph.model = model
if box is not None:
box = (box * 2 * args["cutoff_upper"]).unsqueeze(0)
for _ in range(10):
e_calc, f_calc = calc.calculate(pos, box)
e_pred, f_pred = calc_graph.calculate(pos, box)
Expand Down
43 changes: 38 additions & 5 deletions torchmdnet/calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch
from torchmdnet.models.model import load_model
import warnings

# dict of preset transforms
tranforms = {
Expand Down Expand Up @@ -42,6 +43,10 @@ class External:
Whether to use CUDA graphs to speed up the calculation. Default: False
cuda_graph_warmup_steps : int, optional
Number of steps to run as warmup before recording the CUDA graph. Default: 12
dtype : torch.dtype or str, optional
Cast the input to this dtype if defined. If passed as a string it should be a valid torch dtype. Default: torch.float32
kwargs : dict, optional
Extra arguments to pass to the model when loading it.
"""

def __init__(
Expand All @@ -52,10 +57,28 @@ def __init__(
output_transform=None,
use_cuda_graph=False,
cuda_graph_warmup_steps=12,
dtype=torch.float32,
**kwargs,
):
if isinstance(netfile, str):
self.model = load_model(netfile, device=device, derivative=True)
extra_args = kwargs
if use_cuda_graph:
warnings.warn(
"CUDA graphs are enabled, setting static_shapes=True and check_errors=False"
)
extra_args["static_shapes"] = True
extra_args["check_errors"] = False
self.model = load_model(
netfile,
device=device,
derivative=True,
**extra_args,
)
elif isinstance(netfile, torch.nn.Module):
if kwargs:
warnings.warn(
"Warning: extra arguments are being ignored when passing a torch.nn.Module"
)
self.model = netfile
else:
raise ValueError(
Expand Down Expand Up @@ -87,6 +110,12 @@ def __init__(
self.forces = None
self.box = None
self.pos = None
if isinstance(dtype, str):
try:
dtype = getattr(torch, dtype)
except AttributeError:
raise ValueError(f"Unknown torch dtype {dtype}")
self.dtype = dtype

def _init_cuda_graph(self):
stream = torch.cuda.Stream()
Expand All @@ -101,7 +130,7 @@ def _init_cuda_graph(self):
self.embeddings, self.pos, self.batch, self.box
)

def calculate(self, pos, box = None):
def calculate(self, pos, box=None):
"""Calculate the energy and forces of the system.

Parameters
Expand All @@ -118,7 +147,9 @@ def calculate(self, pos, box = None):
forces : torch.Tensor
Forces on the atoms in the system.
"""
pos = pos.to(self.device).type(torch.float32).reshape(-1, 3)
pos = pos.to(self.device).to(self.dtype).reshape(-1, 3)
if box is not None:
box = box.to(self.device).to(self.dtype)
if self.use_cuda_graph:
if self.pos is None:
self.pos = (
Expand All @@ -128,10 +159,12 @@ def calculate(self, pos, box = None):
.requires_grad_(pos.requires_grad)
)
if self.box is None and box is not None:
self.box = box.clone().to(self.device).detach()
self.box = box.clone().to(self.device).to(self.dtype).detach()
if self.cuda_graph is None:
self._init_cuda_graph()
assert self.cuda_graph is not None, "CUDA graph is not initialized. This should not had happened."
assert (
self.cuda_graph is not None
), "CUDA graph is not initialized. This should not had happened."
with torch.no_grad():
self.pos.copy_(pos)
if box is not None:
Expand Down
Loading