From 7a99e4b11a86d0e7932654d4b5aa03f2782cae5b Mon Sep 17 00:00:00 2001 From: Stefan Doerr Date: Tue, 21 Jan 2025 15:02:41 +0200 Subject: [PATCH] fixes for Windows compilation of CUDA extensions. Replacing 'and' with && and 'or' with || and fixing the call to std::max which can sometimes have long long vs long type --- torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh | 2 +- torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh | 6 +++--- torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh b/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh index 7df3317d4..e6c89c8bb 100644 --- a/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh +++ b/torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh @@ -99,7 +99,7 @@ forward_brute(const Tensor& positions, const Tensor& batch, const Tensor& in_box const CUDAStreamGuard guard(stream); const uint64_t num_all_pairs = num_atoms * (num_atoms - 1UL) / 2UL; const uint64_t num_threads = 128; - const uint64_t num_blocks = std::max((num_all_pairs + num_threads - 1UL) / num_threads, 1UL); + const uint64_t num_blocks = std::max(static_cast((num_all_pairs + num_threads - 1UL) / num_threads), static_cast(1UL)); AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_forward", [&]() { PairListAccessor list_accessor(list); auto box = triclinic::get_box_accessor(box_vectors, use_periodic); diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh b/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh index 193a79898..7144a87ed 100644 --- a/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh +++ b/torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh @@ -266,7 +266,7 @@ template struct CellListAccessor { template __device__ void addNeighborPair(PairListAccessor& list, const int i, const int j, scalar_t distance2, scalar3 delta) { - const bool requires_transpose = list.include_transpose and (j != i); + const bool requires_transpose = list.include_transpose && (j != i); const int ni = max(i, j); const int nj = min(i, j); const scalar_t delta_sign = (ni == i) ? scalar_t(1.0) : scalar_t(-1.0); @@ -292,8 +292,8 @@ __device__ void addNeighborsForCell(const Particle& i_atom, int j_cell const auto last_particle = cl.cell_end[j_cell]; for (int cur_j = first_particle; cur_j < last_particle; cur_j++) { const auto j_batch = cl.sorted_batch[cur_j]; - if ((j_batch == i_atom.batch) and - ((cur_j < i_atom.index) || (list.loop and cur_j == i_atom.index))) { + if ((j_batch == i_atom.batch) && + ((cur_j < i_atom.index) || (list.loop && cur_j == i_atom.index))) { const auto position_j = fetchPosition(cl.sorted_positions, cur_j); const auto delta = rect::compute_distance(i_atom.position, position_j, list.use_periodic, box_size); diff --git a/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh b/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh index 9c4523f50..e7cdfc257 100644 --- a/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh +++ b/torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh @@ -50,7 +50,7 @@ __global__ void forward_kernel_shared(uint32_t num_atoms, const Accessor