Skip to content

Commit

Permalink
Update neighbor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez committed Jan 17, 2024
1 parent 1151cc1 commit 5ff701d
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,8 +580,9 @@ def test_cuda_graph_compatible_backward(

@pytest.mark.parametrize(("device", "strategy"), [("cpu", "brute"), ("cuda", "brute"), ("cuda", "shared")])
@pytest.mark.parametrize("n_batches", [1, 128])
@pytest.mark.parametrize("use_forward", [True, False])
def test_per_batch_box(
device, strategy, n_batches,
device, strategy, n_batches, use_forward
):
dtype = torch.float32
cutoff = 1.0
Expand Down Expand Up @@ -619,12 +620,12 @@ def test_per_batch_box(
cutoff_upper=cutoff,
max_num_pairs=max_num_pairs,
strategy=strategy,
box=box,
box=box if not use_forward else None,
return_vecs=True,
include_transpose=include_transpose,
)
batch.to(device)
neighbors, distances, distance_vecs = nl(pos, batch)
neighbors, distances, distance_vecs = nl(pos, batch, box=box if use_forward else None)
neighbors = neighbors.cpu().detach().numpy()
distance_vecs = distance_vecs.cpu().detach().numpy()
distances = distances.cpu().detach().numpy()
Expand Down

0 comments on commit 5ff701d

Please sign in to comment.