Skip to content

Commit

Permalink
Merge pull request #13 from masadcv/cuda-multichannel-support
Browse files Browse the repository at this point in the history
Add multi-channel support for CUDA distance transforms
  • Loading branch information
masadcv authored Jul 16, 2022
2 parents ac36e64 + b6f3b3d commit ab7a599
Show file tree
Hide file tree
Showing 7 changed files with 429 additions and 355 deletions.
10 changes: 5 additions & 5 deletions FastGeodis/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ void check_spatial_shape_match(const torch::Tensor &in1, const torch::Tensor &in
{
if (in1.dim() != in2.dim())
{
throw std::invalid_argument("dimensions of input tensors do not match "
+ std::to_string(in1.dim() - 2) + " vs " + std::to_string(in2.dim() - 2));
throw std::invalid_argument("dimensions of input tensors do not match " + \
std::to_string(in1.dim() - 2) + " vs " + std::to_string(in2.dim() - 2));
}
for(int i=0; i < dims; i++)
for (int i = 0; i < dims; i++)
{
if(in1.size(2+i) != in2.size(2+i))
if (in1.size(2 + i) != in2.size(2 + i))
{
std::cout << "Tensor1 ";
print_shape(in1);
Expand Down Expand Up @@ -116,5 +116,5 @@ void check_input_dimensions(const torch::Tensor &image, const torch::Tensor &mas
check_single_batch(mask);

// check spatial shapes match
check_spatial_shape_match(image, mask, num_dims-2);
check_spatial_shape_match(image, mask, num_dims - 2);
}
74 changes: 35 additions & 39 deletions FastGeodis/fastgeodis.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,57 +36,53 @@

#ifdef WITH_CUDA
torch::Tensor generalised_geodesic2d_cuda(
torch::Tensor &image,
const torch::Tensor &mask,
const float &v,
const float &l_grad,
const float &l_eucl,
torch::Tensor &image,
const torch::Tensor &mask,
const float &v,
const float &l_grad,
const float &l_eucl,
const int &iterations);

torch::Tensor generalised_geodesic3d_cuda(
torch::Tensor &image,
const torch::Tensor &mask,
const std::vector<float> &spacing,
const float &v,
const float &l_grad,
const float &l_eucl,
torch::Tensor &image,
const torch::Tensor &mask,
const std::vector<float> &spacing,
const float &v,
const float &l_grad,
const float &l_eucl,
const int &iterations);
#endif

torch::Tensor generalised_geodesic2d_cpu(
torch::Tensor &image,
const torch::Tensor &mask,
const float &v,
const float &l_grad,
const float &l_eucl,
torch::Tensor &image,
const torch::Tensor &mask,
const float &v,
const float &l_grad,
const float &l_eucl,
const int &iterations);

torch::Tensor generalised_geodesic3d_cpu(
torch::Tensor &image,
const torch::Tensor &mask,
const std::vector<float> &spacing,
const float &v,
const float &l_grad,
const float &l_eucl,
torch::Tensor &image,
const torch::Tensor &mask,
const std::vector<float> &spacing,
const float &v,
const float &l_grad,
const float &l_eucl,
const int &iterations);

torch::Tensor generalised_geodesic2d(
torch::Tensor &image,
const torch::Tensor &mask,
const float &v,
const float &l_grad,
const float &l_eucl,
const int &iterations
);
torch::Tensor &image,
const torch::Tensor &mask,
const float &v,
const float &l_grad,
const float &l_eucl,
const int &iterations);

torch::Tensor generalised_geodesic3d(
torch::Tensor &image,
const torch::Tensor &mask,
const std::vector<float> &spacing,
const float &v,
const float &l_grad,
const float &l_eucl,
const int &iterations
);


torch::Tensor &image,
const torch::Tensor &mask,
const std::vector<float> &spacing,
const float &v,
const float &l_grad,
const float &l_eucl,
const int &iterations);
Loading

0 comments on commit ab7a599

Please sign in to comment.