Skip to content

Commit

Permalink
Merge pull request #44 from masadcv/fix-multiple-cuda
Browse files Browse the repository at this point in the history
Fix cuda target on multi-gpu host
  • Loading branch information
masadcv authored Mar 2, 2023
2 parents 7101eb3 + 4d42f9a commit 7558018
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
9 changes: 9 additions & 0 deletions FastGeodis/fastgeodis_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>

Expand Down Expand Up @@ -283,6 +284,10 @@ torch::Tensor generalised_geodesic2d_cuda(
const int &iterations
)
{
int device = image.get_device();
// std::cout << "Running with CUDA Device: " << device << std::endl;
c10::cuda::CUDAGuard device_guard(device);

torch::Tensor image_local = image.clone();
torch::Tensor distance = v * mask.clone();

Expand Down Expand Up @@ -589,6 +594,10 @@ torch::Tensor generalised_geodesic3d_cuda(
const int &iterations
)
{
int device = image.get_device();
// std::cout << "Running with CUDA Device: " << device << std::endl;
c10::cuda::CUDAGuard device_guard(device);

// square spacing with transform
std::transform(spacing.begin(), spacing.end(), spacing.begin(), spacing.begin(), std::multiplies<float>());

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def get_extensions():

setup(
name="FastGeodis",
-version="1.0.1",
+version="1.0.2",
description="Fast Implementation of Generalised Geodesic Distance Transform for CPU (OpenMP) and GPU (CUDA)",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 7558018

Please sign in to comment.