Skip to content

Commit

Permalink
fixes for Windows compilation of CUDA extensions. Replacing 'and' wit…
Browse files Browse the repository at this point in the history
…h && and 'or' with || and fixing the call to std::max which can sometimes have long long vs long type
  • Loading branch information
stefdoerr committed Jan 21, 2025
1 parent afd08df commit 7a99e4b
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion torchmdnet/extensions/neighbors/neighbors_cuda_brute.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ forward_brute(const Tensor& positions, const Tensor& batch, const Tensor& in_box
const CUDAStreamGuard guard(stream);
const uint64_t num_all_pairs = num_atoms * (num_atoms - 1UL) / 2UL;
const uint64_t num_threads = 128;
const uint64_t num_blocks = std::max((num_all_pairs + num_threads - 1UL) / num_threads, 1UL);
const uint64_t num_blocks = std::max(static_cast<uint64_t>((num_all_pairs + num_threads - 1UL) / num_threads), static_cast<uint64_t>(1UL));
AT_DISPATCH_FLOATING_TYPES(positions.scalar_type(), "get_neighbor_pairs_forward", [&]() {
PairListAccessor<scalar_t> list_accessor(list);
auto box = triclinic::get_box_accessor<scalar_t>(box_vectors, use_periodic);
Expand Down
6 changes: 3 additions & 3 deletions torchmdnet/extensions/neighbors/neighbors_cuda_cell.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ template <class scalar_t> struct CellListAccessor {
template <class scalar_t>
__device__ void addNeighborPair(PairListAccessor<scalar_t>& list, const int i, const int j,
scalar_t distance2, scalar3<scalar_t> delta) {
const bool requires_transpose = list.include_transpose and (j != i);
const bool requires_transpose = list.include_transpose && (j != i);
const int ni = max(i, j);
const int nj = min(i, j);
const scalar_t delta_sign = (ni == i) ? scalar_t(1.0) : scalar_t(-1.0);
Expand All @@ -292,8 +292,8 @@ __device__ void addNeighborsForCell(const Particle<scalar_t>& i_atom, int j_cell
const auto last_particle = cl.cell_end[j_cell];
for (int cur_j = first_particle; cur_j < last_particle; cur_j++) {
const auto j_batch = cl.sorted_batch[cur_j];
if ((j_batch == i_atom.batch) and
((cur_j < i_atom.index) || (list.loop and cur_j == i_atom.index))) {
if ((j_batch == i_atom.batch) &&
((cur_j < i_atom.index) || (list.loop && cur_j == i_atom.index))) {
const auto position_j = fetchPosition(cl.sorted_positions, cur_j);
const auto delta = rect::compute_distance<scalar_t>(i_atom.position, position_j,
list.use_periodic, box_size);
Expand Down
2 changes: 1 addition & 1 deletion torchmdnet/extensions/neighbors/neighbors_cuda_shared.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ __global__ void forward_kernel_shared(uint32_t num_atoms, const Accessor<scalar_
if (!active)
break; // An out of bounds thread must be masked
const int cur_j = tile * blockDim.x + counter;
const bool testPair = cur_j < num_atoms and (cur_j < id or (list.loop and cur_j == id));
const bool testPair = cur_j < num_atoms && (cur_j < id || (list.loop && cur_j == id));
if (testPair) {
const auto batch_j = sh_batch[counter];
if (batch_i == batch_j) {
Expand Down

0 comments on commit 7a99e4b

Please sign in to comment.