Skip to content

Commit

Permalink
Force static_shapes in cuda graph mode
Browse files Browse the repository at this point in the history
Allow External to take kwargs for load_model
  • Loading branch information
RaulPPelaez committed Feb 19, 2024
1 parent 1713911 commit 39b2399
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion torchmdnet/calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class External:
Number of steps to run as warmup before recording the CUDA graph. Default: 12
dtype : torch.dtype or str, optional
Cast the input to this dtype if defined. If passed as a string it should be a valid torch dtype. Default: torch.float32
kwargs : dict, optional
Extra arguments to pass to the model when loading it.
"""

def __init__(
Expand All @@ -55,9 +57,19 @@ def __init__(
use_cuda_graph=False,
cuda_graph_warmup_steps=12,
dtype=torch.float32,
**kwargs,
):
if isinstance(netfile, str):
self.model = load_model(netfile, device=device, derivative=True)
extra_args = kwargs
if use_cuda_graph:
extra_args["static_shapes"] = True
extra_args["check_errors"] = False
self.model = load_model(
netfile,
device=device,
derivative=True,
**extra_args,
)
elif isinstance(netfile, torch.nn.Module):
self.model = netfile
else:
Expand Down

0 comments on commit 39b2399

Please sign in to comment.