Fix bug with custom kernels not working on non-default device. (#87)
* add thread lock to custom kernels, remove unneeded import

* update test case dtypes
lubbersnick authored Aug 14, 2024
1 parent d565b18 commit cf16e9f
Showing 4 changed files with 38 additions and 12 deletions.
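For context on the fix below: the likely failure mode is that kernel launchers such as CuPy and Triton go by the current CUDA device rather than the device a tensor actually lives on, so custom kernels fed tensors on a non-default device (e.g. `cuda:1`) could be launched against `cuda:0`. A minimal sketch of that mismatch and of the `torch.cuda.device_of` guard the commit introduces (not part of the commit itself; `x` is just an illustrative tensor and the snippet needs a multi-GPU machine to show the effect):

```python
import torch

if torch.cuda.device_count() >= 2:
    x = torch.ones(8, device="cuda:1")
    # The "current" device is still 0 by default, even though x lives on cuda:1,
    # so a launcher that reads the current device would target the wrong GPU.
    print(torch.cuda.current_device())  # 0
    print(x.device)                     # cuda:1
    # Activating the tensor's device before launching fixes the mismatch.
    # device_of is also a no-op for CPU tensors, unlike torch.cuda.device.
    with torch.cuda.device_of(x):
        print(torch.cuda.current_device())  # 1
```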
35 changes: 32 additions & 3 deletions hippynn/custom_kernels/autograd_wrapper.py
@@ -3,6 +3,32 @@
"""
import torch.autograd

+import threading
+from contextlib import contextmanager
+
+_DEVICE_CONTEXT_LOCK = threading.Lock()
+_DEVICE_TIMEOUT = 10  # if custom kernels have held the lock for 10 s, throw an error
+
+
+@contextmanager
+def _lock_device(tensor):
+    """
+    This function locks the torch.cuda.device, which affects how
+    triton and cupy try to launch their kernels.
+    :param tensor:
+    :return:
+    """
+    acquired = _DEVICE_CONTEXT_LOCK.acquire(timeout=_DEVICE_TIMEOUT)
+
+    if not acquired:
+        raise TimeoutError(f"Custom kernel device-lock appears deadlocked. (exceeded timeout {_DEVICE_TIMEOUT})")
+    try:
+        # Developer note: device_of is safe for CPU tensors, but torch.cuda.device is not!
+        with torch.cuda.device_of(tensor):
+            yield
+    finally:
+        _DEVICE_CONTEXT_LOCK.release()
+

def wrap_envops(envsum_impl, sensesum_impl, featsum_impl):
"""
@@ -22,7 +48,8 @@ def forward(ctx, sense, feat, pfirst, psecond):
if n_pair != 0 or psecond.shape[0] != 0:
raise ValueError("Inconsistent shapes for envsum.")
return torch.zeros((n_atom, n_nu, n_feat), dtype=feat.dtype, device=feat.device)
-env = envsum_impl(sense, feat, pfirst, psecond)
+with _lock_device(feat):
+    env = envsum_impl(sense, feat, pfirst, psecond)
return env

@staticmethod
@@ -52,7 +79,8 @@ def forward(ctx, env, feat, pfirst, psecond):
if psecond.shape[0] != 0 or n_atom0 != n_atom1 or n_feat0 != n_feat1:
raise ValueError("Inconsistent shapes for sensesum")
return torch.zeros((0, n_nu), dtype=feat.dtype, device=feat.device)
-sense = sensesum_impl(env, feat, pfirst, psecond)
+with _lock_device(feat):
+    sense = sensesum_impl(env, feat, pfirst, psecond)
return sense

@staticmethod
@@ -75,7 +103,8 @@ def forward(ctx, env, sense, pfirst, psecond):
if psecond.shape[0] != 0 or n_nu0 != n_nu1:
raise ValueError("Inconsistent shapes for featsum")
return torch.zeros((n_atom, n_feat), dtype=env.dtype, device=env.device)
-feat = featsum_impl(env, sense, pfirst, psecond)
+with _lock_device(env):
+    feat = featsum_impl(env, sense, pfirst, psecond)
return feat

@staticmethod
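As a design note, the new `_lock_device` helper serializes entries to the device context across threads and uses an acquire timeout so a stuck launch raises an error instead of hanging. A self-contained sketch of the same lock-plus-device-context pattern, with a dummy `fake_kernel` standing in for the real CuPy/Triton implementations (none of these names are hippynn APIs, and the sketch also runs on CPU since `device_of` is a no-op there):

```python
import threading
from contextlib import contextmanager

import torch

_lock = threading.Lock()

@contextmanager
def locked_device(tensor, timeout=10):
    # Fail loudly if the lock cannot be taken, rather than deadlocking silently.
    if not _lock.acquire(timeout=timeout):
        raise TimeoutError("device lock appears deadlocked")
    try:
        # Activate the device the tensor lives on; no-op for CPU tensors.
        with torch.cuda.device_of(tensor):
            yield
    finally:
        _lock.release()

def fake_kernel(x):
    # Stand-in for a launcher that depends on the current CUDA device.
    return x * 2

def worker(x, results, i):
    with locked_device(x):
        results[i] = fake_kernel(x)

results = [None, None]
threads = [threading.Thread(target=worker, args=(torch.ones(4), results, i)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results)
```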
3 changes: 0 additions & 3 deletions hippynn/layers/pairs/periodic.py
@@ -1,9 +1,6 @@
import torch

-from scipy.spatial import KDTree
-
from .open import _PairIndexer, PairMemory
-from torch.profiler import profile, record_function, ProfilerActivity

# Deprecated?
class StaticImagePeriodicPairIndexer(_PairIndexer):
6 changes: 3 additions & 3 deletions tests/optimizer/__init__.py
@@ -2,7 +2,7 @@

# Below are some test cases for the optimizer.

-from test_configs import c2h6_config as C2H6, qm9b1_config as QM9b1
+from .test_configs import c2h6_config as C2H6, qm9b1_config as QM9b1

# C2H6 contain 15 ethane conformers with C-C bond elongated.
# 'E'/'F' is the B97-3c energy/forces calculated by ORCA4
@@ -13,8 +13,8 @@
# To test the numerical stability of the optimizer,
# I add meaningless 0 paddings to C2H6 to form this C2H6_padded
C2H6_padded = {
-'Z': torch.cat((C2H6['Z'].clone(), torch.zeros(C2H6['Z'].shape[0], 2)), dim=1),
+'Z': torch.cat((C2H6['Z'].clone(), torch.zeros(C2H6['Z'].shape[0], 2,dtype=torch.int64)), dim=1),
'R': torch.cat((C2H6['R'].clone(), torch.zeros(C2H6['R'].shape[0], 2, 3)), dim=1),
'E': C2H6['E'].clone(),
'F': torch.cat((C2H6['F'].clone(), torch.zeros(C2H6['F'].shape[0], 2, 3)), dim=1),
-}
+}
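The `dtype` tweaks in the test data matter because `torch.zeros` defaults to `float32`: padding an integer species array without an explicit dtype either promotes the concatenated result to floating point (recent PyTorch) or fails on the mismatch (older versions), and species tensors are generally used as integer indices. A small illustration, not taken from the repository (the array values are hypothetical, matching the C2H6-style layout above):

```python
import torch

Z = torch.tensor([[6, 6, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)

# Default float32 zeros change the result's dtype; explicit int64 zeros keep it integer.
padded_bad = torch.cat((Z, torch.zeros(Z.shape[0], 2)), dim=1)
padded_good = torch.cat((Z, torch.zeros(Z.shape[0], 2, dtype=torch.int64)), dim=1)
print(padded_bad.dtype, padded_good.dtype)  # torch.float32 torch.int64
```

The switch from `int32` to `int64` in `test_configs.py` follows the same idea: integer index tensors in PyTorch are conventionally `torch.int64` (long).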
6 changes: 3 additions & 3 deletions tests/optimizer/test_configs.py
@@ -19,7 +19,7 @@
[6, 6, 1, 1, 1, 1, 1, 1],
[6, 6, 1, 1, 1, 1, 1, 1],
],
-dtype=torch.int32,
+dtype=torch.int64,
),
"R": torch.tensor(
[
@@ -358,8 +358,8 @@
[6, 6, 6, 6, 6, 6, 6, 6, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[6, 6, 7, 6, 6, 6, 8, 6, 8, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[6, 6, 6, 8, 6, 8, 6, 6, 8, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-]
-),
+],
+dtype=torch.int64),
"R": torch.tensor(
[
[
