diff --git a/pytorch3d/csrc/iou_box3d/iou_box3d.cu b/pytorch3d/csrc/iou_box3d/iou_box3d.cu index 524a5bf4b..a315550f6 100644 --- a/pytorch3d/csrc/iou_box3d/iou_box3d.cu +++ b/pytorch3d/csrc/iou_box3d/iou_box3d.cu @@ -12,8 +12,6 @@ #include #include #include -#include -#include #include "iou_box3d/iou_utils.cuh" // Parallelize over N*M computations which can each be done diff --git a/pytorch3d/csrc/iou_box3d/iou_utils.cuh b/pytorch3d/csrc/iou_box3d/iou_utils.cuh index 4caf0099c..5ad5b165d 100644 --- a/pytorch3d/csrc/iou_box3d/iou_utils.cuh +++ b/pytorch3d/csrc/iou_box3d/iou_utils.cuh @@ -8,7 +8,6 @@ #include #include -#include #include #include "utils/float_math.cuh" diff --git a/pytorch3d/csrc/marching_cubes/marching_cubes.cu b/pytorch3d/csrc/marching_cubes/marching_cubes.cu index 527bced5d..8c8fe3925 100644 --- a/pytorch3d/csrc/marching_cubes/marching_cubes.cu +++ b/pytorch3d/csrc/marching_cubes/marching_cubes.cu @@ -9,8 +9,6 @@ #include #include #include -#include -#include #include #include "marching_cubes/tables.h" @@ -40,20 +38,6 @@ through" each cube in the grid. // EPS: Used to indicate if two float values are close __constant__ const float EPSILON = 1e-5; -// Thrust wrapper for exclusive scan -// -// Args: -// output: pointer to on-device output array -// input: pointer to on-device input array, where scan is performed -// numElements: number of elements for the input array -// -void ThrustScanWrapper(int* output, int* input, int numElements) { - thrust::exclusive_scan( - thrust::device_ptr(input), - thrust::device_ptr(input + numElements), - thrust::device_ptr(output)); -} - // Linearly interpolate the position where an isosurface cuts an edge // between two vertices, based on their scalar values // @@ -455,19 +439,24 @@ std::tuple MarchingCubesCuda( grid.x = 65535; } + using at::indexing::None; + using at::indexing::Slice; + auto d_voxelVerts = - at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt)) + at::zeros({numVoxels + 1}, at::TensorOptions().dtype(at::kInt)) .to(vol.device()); + auto d_voxelVerts_ = d_voxelVerts.index({Slice(1, None)}); auto d_voxelOccupied = - at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt)) + at::zeros({numVoxels + 1}, at::TensorOptions().dtype(at::kInt)) .to(vol.device()); + auto d_voxelOccupied_ = d_voxelOccupied.index({Slice(1, None)}); // Execute "ClassifyVoxelKernel" kernel to precompute // two arrays - d_voxelOccupied and d_voxelVertices to global memory, // which stores the occupancy state and number of voxel vertices per voxel. ClassifyVoxelKernel<<>>( - d_voxelVerts.packed_accessor32(), - d_voxelOccupied.packed_accessor32(), + d_voxelVerts_.packed_accessor32(), + d_voxelOccupied_.packed_accessor32(), vol.packed_accessor32(), isolevel); AT_CUDA_CHECK(cudaGetLastError()); @@ -477,18 +466,12 @@ std::tuple MarchingCubesCuda( // count for voxels in the grid and compute the number of active voxels. // If the number of active voxels is 0, return zero tensor for verts and // faces. - auto d_voxelOccupiedScan = - at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt)) - .to(vol.device()); - ThrustScanWrapper( - d_voxelOccupiedScan.data_ptr(), - d_voxelOccupied.data_ptr(), - numVoxels); + + auto d_voxelOccupiedScan = at::cumsum(d_voxelOccupied, 0); + auto d_voxelOccupiedScan_ = d_voxelOccupiedScan.index({Slice(1, None)}); // number of active voxels - int lastElement = d_voxelVerts[numVoxels - 1].cpu().item(); - int lastScan = d_voxelOccupiedScan[numVoxels - 1].cpu().item(); - int activeVoxels = lastElement + lastScan; + int activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item(); const int device_id = vol.device().index(); auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id); @@ -509,22 +492,17 @@ std::tuple MarchingCubesCuda( CompactVoxelsKernel<<>>( d_compVoxelArray.packed_accessor32(), d_voxelOccupied.packed_accessor32(), - d_voxelOccupiedScan.packed_accessor32(), + d_voxelOccupiedScan_.packed_accessor32(), numVoxels); AT_CUDA_CHECK(cudaGetLastError()); cudaDeviceSynchronize(); // Scan d_voxelVerts array to generate offsets of vertices for each voxel - auto d_voxelVertsScan = at::zeros({numVoxels}, opt); - ThrustScanWrapper( - d_voxelVertsScan.data_ptr(), - d_voxelVerts.data_ptr(), - numVoxels); + auto d_voxelVertsScan = at::cumsum(d_voxelVerts, 0); + auto d_voxelVertsScan_ = d_voxelVertsScan.index({Slice(1, None)}); // total number of vertices - lastElement = d_voxelVerts[numVoxels - 1].cpu().item(); - lastScan = d_voxelVertsScan[numVoxels - 1].cpu().item(); - int totalVerts = lastElement + lastScan; + int totalVerts = d_voxelVertsScan[numVoxels].cpu().item(); // Execute "GenerateFacesKernel" kernel // This runs only on the occupied voxels. @@ -544,7 +522,7 @@ std::tuple MarchingCubesCuda( faces.packed_accessor(), ids.packed_accessor(), d_compVoxelArray.packed_accessor32(), - d_voxelVertsScan.packed_accessor32(), + d_voxelVertsScan_.packed_accessor32(), activeVoxels, vol.packed_accessor32(), faceTable.packed_accessor32(),