diff --git a/docs/source/installation.rst b/docs/source/installation.rst index ad21b2d3..01a10a48 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -54,11 +54,9 @@ It is recommended to install the same version as the one used by torch. .. code-block:: shell - conda install -c conda-forge cuda-nvcc cuda-libraries-dev cuda-version "gxx<12" pytorch=*=*cuda* + conda install -c conda-forge cuda-nvcc cuda-libraries-dev cuda-version gxx pytorch=*=*cuda* + - -.. warning:: gxx<12 is required due to a `bug in GCC+CUDA12 `_ that prevents pybind11 from compiling correctly - * CUDA<12 diff --git a/environment.yml b/environment.yml index 697fe808..5116787f 100644 --- a/environment.yml +++ b/environment.yml @@ -6,9 +6,9 @@ dependencies: - matplotlib-base - nnpops - pip - - pytorch<2.2 - - pytorch_geometric<2.5 - - lightning<2.2 + - pytorch + - pytorch_geometric + - lightning - pydantic - torchmetrics - tqdm @@ -16,4 +16,4 @@ dependencies: - flake8 - pytest - psutil - - gxx<12 + - gxx diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index dc2b113e..c40e9e8a 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -669,6 +669,10 @@ def test_per_batch_box(device, strategy, n_batches, use_forward): @pytest.mark.parametrize("loop", [True, False]) @pytest.mark.parametrize("include_transpose", [True, False]) def test_torch_compile(device, dtype, loop, include_transpose): + import sys + + if sys.version_info >= (3, 12): + pytest.skip("Not available in this version") if torch.__version__ < "2.0.0": pytest.skip("Not available in this version") if device == "cuda" and not torch.cuda.is_available(): diff --git a/torchmdnet/module.py b/torchmdnet/module.py index 108a1915..8fee82af 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -25,7 +25,7 @@ def __init__(self, dtype=torch.float64): super(FloatCastDatasetWrapper, self).__init__() self._dtype = dtype - def forward(self, data): + def __call__(self, data): for key, value in data: if torch.is_tensor(value) and torch.is_floating_point(value): setattr(data, key, value.to(self._dtype)) @@ -41,7 +41,7 @@ def __init__(self, atomref): super(EnergyRefRemover, self).__init__() self._atomref = atomref - def forward(self, data): + def __call__(self, data): self._atomref = self._atomref.to(data.z.device).type(data.y.dtype) if "y" in data: data.y.index_add_(0, data.batch, -self._atomref[data.z])