Skip to content

Commit

Permalink
Merge pull request #329 from RaulPPelaez/update_env
Browse files Browse the repository at this point in the history
Update environment file
  • Loading branch information
RaulPPelaez authored Jun 10, 2024
2 parents e224dc5 + e165085 commit 440f985
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 10 deletions.
6 changes: 2 additions & 4 deletions docs/source/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,9 @@ It is recommended to install the same version as the one used by torch.

.. code-block:: shell
conda install -c conda-forge cuda-nvcc cuda-libraries-dev cuda-version "gxx<12" pytorch=*=*cuda*
conda install -c conda-forge cuda-nvcc cuda-libraries-dev cuda-version gxx pytorch=*=*cuda*
.. warning:: gxx<12 is required due to a `bug in GCC+CUDA12 <https://github.com/pybind/pybind11/issues/4606>`_ that prevents pybind11 from compiling correctly

* CUDA<12

Expand Down
8 changes: 4 additions & 4 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ dependencies:
- matplotlib-base
- nnpops
- pip
- pytorch<2.2
- pytorch_geometric<2.5
- lightning<2.2
- pytorch
- pytorch_geometric
- lightning
- pydantic
- torchmetrics
- tqdm
# Dev tools
- flake8
- pytest
- psutil
- gxx<12
- gxx
4 changes: 4 additions & 0 deletions tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,10 @@ def test_per_batch_box(device, strategy, n_batches, use_forward):
@pytest.mark.parametrize("loop", [True, False])
@pytest.mark.parametrize("include_transpose", [True, False])
def test_torch_compile(device, dtype, loop, include_transpose):
import sys

if sys.version_info >= (3, 12):
pytest.skip("Not available in this version")
if torch.__version__ < "2.0.0":
pytest.skip("Not available in this version")
if device == "cuda" and not torch.cuda.is_available():
Expand Down
4 changes: 2 additions & 2 deletions torchmdnet/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, dtype=torch.float64):
super(FloatCastDatasetWrapper, self).__init__()
self._dtype = dtype

def forward(self, data):
def __call__(self, data):
for key, value in data:
if torch.is_tensor(value) and torch.is_floating_point(value):
setattr(data, key, value.to(self._dtype))
Expand All @@ -41,7 +41,7 @@ def __init__(self, atomref):
super(EnergyRefRemover, self).__init__()
self._atomref = atomref

def forward(self, data):
def __call__(self, data):
self._atomref = self._atomref.to(data.z.device).type(data.y.dtype)
if "y" in data:
data.y.index_add_(0, data.batch, -self._atomref[data.z])
Expand Down

0 comments on commit 440f985

Please sign in to comment.