diff --git a/FastGeodis/fastgeodis_cuda.cu b/FastGeodis/fastgeodis_cuda.cu index 671f3a6..efdb0c8 100644 --- a/FastGeodis/fastgeodis_cuda.cu +++ b/FastGeodis/fastgeodis_cuda.cu @@ -29,6 +29,7 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include +#include #include #include @@ -283,6 +284,10 @@ torch::Tensor generalised_geodesic2d_cuda( const int &iterations ) { + int device = image.get_device(); + // std::cout << "Running with CUDA Device: " << device << std::endl; + c10::cuda::CUDAGuard device_guard(device); + torch::Tensor image_local = image.clone(); torch::Tensor distance = v * mask.clone(); @@ -589,6 +594,10 @@ torch::Tensor generalised_geodesic3d_cuda( const int &iterations ) { + int device = image.get_device(); + // std::cout << "Running with CUDA Device: " << device << std::endl; + c10::cuda::CUDAGuard device_guard(device); + // square spacing with transform std::transform(spacing.begin(), spacing.end(), spacing.begin(), spacing.begin(), std::multiplies()); diff --git a/setup.py b/setup.py index 337c0f4..4b2515d 100755 --- a/setup.py +++ b/setup.py @@ -122,7 +122,7 @@ def get_extensions(): setup( name="FastGeodis", - version="1.0.1", + version="1.0.2", description="Fast Implementation of Generalised Geodesic Distance Transform for CPU (OpenMP) and GPU (CUDA)", long_description=long_description, long_description_content_type="text/markdown",