From 8e6ea002ee11ab45238eb8bfc8fbccf93b4cab42 Mon Sep 17 00:00:00 2001
From: RaulPPealez
Date: Mon, 10 Jun 2024 16:52:31 +0200
Subject: [PATCH 1/4] Update env file

---
 environment.yml | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/environment.yml b/environment.yml
index 697fe8084..5116787f8 100644
--- a/environment.yml
+++ b/environment.yml
@@ -6,9 +6,9 @@ dependencies:
   - matplotlib-base
   - nnpops
   - pip
-  - pytorch<2.2
-  - pytorch_geometric<2.5
-  - lightning<2.2
+  - pytorch
+  - pytorch_geometric
+  - lightning
   - pydantic
   - torchmetrics
   - tqdm
@@ -16,4 +16,4 @@ dependencies:
   - flake8
   - pytest
   - psutil
-  - gxx<12
+  - gxx

From dbc1681261a241fe123d47191b41673bed5ba11f Mon Sep 17 00:00:00 2001
From: RaulPPealez
Date: Mon, 10 Jun 2024 17:01:36 +0200
Subject: [PATCH 2/4] Disable a test that uses torch.compile in python 3.12

---
 tests/test_neighbors.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py
index dc2b113ee..c40e9e8a3 100644
--- a/tests/test_neighbors.py
+++ b/tests/test_neighbors.py
@@ -669,6 +669,10 @@ def test_per_batch_box(device, strategy, n_batches, use_forward):
 @pytest.mark.parametrize("loop", [True, False])
 @pytest.mark.parametrize("include_transpose", [True, False])
 def test_torch_compile(device, dtype, loop, include_transpose):
+    import sys
+
+    if sys.version_info >= (3, 12):
+        pytest.skip("Not available in this version")
     if torch.__version__ < "2.0.0":
         pytest.skip("Not available in this version")
     if device == "cuda" and not torch.cuda.is_available():

From 84d6e6a489567297933a495acd3680da75a82917 Mon Sep 17 00:00:00 2001
From: RaulPPealez
Date: Mon, 10 Jun 2024 17:01:53 +0200
Subject: [PATCH 3/4] Use __call__ instead of forward for compatibility with
 previous versions of geometric

---
 torchmdnet/module.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index 108a1915e..8fee82afa 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ -25,7 +25,7 @@ def __init__(self, dtype=torch.float64):
         super(FloatCastDatasetWrapper, self).__init__()
         self._dtype = dtype

-    def forward(self, data):
+    def __call__(self, data):
         for key, value in data:
             if torch.is_tensor(value) and torch.is_floating_point(value):
                 setattr(data, key, value.to(self._dtype))
@@ -41,7 +41,7 @@ def __init__(self, atomref):
         super(EnergyRefRemover, self).__init__()
         self._atomref = atomref

-    def forward(self, data):
+    def __call__(self, data):
         self._atomref = self._atomref.to(data.z.device).type(data.y.dtype)
         if "y" in data:
             data.y.index_add_(0, data.batch, -self._atomref[data.z])

From e16508535abd9aaca6f49ca99609dd95e15cba07 Mon Sep 17 00:00:00 2001
From: RaulPPealez
Date: Mon, 10 Jun 2024 17:10:36 +0200
Subject: [PATCH 4/4] Remove old mention of a bug in mamba

---
 docs/source/installation.rst | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/docs/source/installation.rst b/docs/source/installation.rst
index ad21b2d3b..01a10a480 100644
--- a/docs/source/installation.rst
+++ b/docs/source/installation.rst
@@ -54,11 +54,9 @@ It is recommended to install the same version as the one used by torch.

 .. code-block:: shell

-   conda install -c conda-forge cuda-nvcc cuda-libraries-dev cuda-version "gxx<12" pytorch=*=*cuda*
-
-.. warning:: gxx<12 is required due to a `bug in GCC+CUDA12 `_ that prevents pybind11 from compiling correctly
-
+   conda install -c conda-forge cuda-nvcc cuda-libraries-dev cuda-version gxx pytorch=*=*cuda*
+
 * CUDA<12
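
Note on PATCH 2/4: torch.compile did not support Python 3.12 at the time, so the test bails out before exercising it. The in-body pytest.skip call used in the patch is the most direct fix; the sketch below shows an equivalent declarative guard using pytest.mark.skipif. This is an illustration only, not code from the repository, and the test body is a hypothetical stand-in:

    import sys

    import pytest
    import torch

    # Declarative equivalent of the run-time guard added in PATCH 2/4.
    # skipif evaluates at collection time, so the test never starts at all
    # on Python >= 3.12 or on PyTorch releases without torch.compile.
    @pytest.mark.skipif(
        sys.version_info >= (3, 12),
        reason="torch.compile unsupported on this Python version",
    )
    @pytest.mark.skipif(
        torch.__version__ < "2.0.0",
        reason="torch.compile requires PyTorch >= 2.0",
    )
    def test_torch_compile_smoke():
        # Compile a trivial function and check it still computes correctly.
        fn = torch.compile(lambda x: x * 2)
        assert torch.equal(fn(torch.ones(3)), torch.full((3,), 2.0))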
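
Note on PATCH 3/4: PyTorch Geometric applies dataset transforms by calling the transform object directly. Recent releases of geometric route __call__ to forward through BaseTransform, but older releases do not, so a transform defining only forward is never executed there; overriding __call__ works on both. A minimal standalone sketch of the pattern (the class name and the commented usage are illustrative, not from the repository):

    import torch

    class FloatCastTransform:
        """Casts every floating-point tensor attribute of a data object.

        Overrides __call__ rather than forward so the transform is invoked
        on any version of geometric.
        """

        def __init__(self, dtype=torch.float64):
            self._dtype = dtype

        def __call__(self, data):
            # 'data' is any object iterable as (key, value) pairs that
            # supports setattr, e.g. a torch_geometric.data.Data instance.
            for key, value in data:
                if torch.is_tensor(value) and torch.is_floating_point(value):
                    setattr(data, key, value.to(self._dtype))
            return data

    # Hypothetical usage with a geometric dataset:
    # from torch_geometric.datasets import QM9
    # dataset = QM9(root="data/qm9", transform=FloatCastTransform())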