diff --git a/modelforge/tests/test_models.py b/modelforge/tests/test_models.py index f64cce95..41e406af 100644 --- a/modelforge/tests/test_models.py +++ b/modelforge/tests/test_models.py @@ -68,6 +68,7 @@ def test_energy_scaling_and_offset(): def test_forward_pass(inference_model, batch_QM9_ANI2x): # this test sends a single batch from different datasets through the model + batch = batch_QM9_ANI2x nnp_input = batch.nnp_input nr_of_mols = nnp_input.atomic_subsystem_indices.unique().shape[0] @@ -82,6 +83,7 @@ def test_calculate_energies_and_forces(inference_model, batch_QM9_ANI2x): """ Test the calculation of energies and forces for a molecule. """ + batch = batch_QM9_ANI2x import torch nnp_input = batch.nnp_input @@ -319,10 +321,11 @@ def test_pairlist_on_dataset(initialized_dataset): assert shapePairlist[0] == 2 -def test_casting(batch, inference_model): +def test_casting(batch_QM9_ANI2x, inference_model): # test dtype casting import torch + batch = batch_QM9_ANI2x batch_ = batch.to(dtype=torch.float64) assert batch_.nnp_input.positions.dtype == torch.float64 batch_ = batch_.to(dtype=torch.float32) @@ -354,6 +357,7 @@ def test_equivariant_energies_and_forces( Test the calculation of energies and forces for a molecule. NOTE: test will be adapted once we have a trained model. """ + batch = batch_QM9_ANI2x import torch from dataclasses import replace