Skip to content

Commit

Permalink
Merge pull request #26 from masadcv/add-fastmarch
Browse files Browse the repository at this point in the history
Add Fast Marching method for Geodesic and Euclidean distance transform
  • Loading branch information
masadcv authored Aug 15, 2022
2 parents 6893299 + f5d98c9 commit 0ce4f89
Show file tree
Hide file tree
Showing 29 changed files with 2,655 additions and 691 deletions.
218 changes: 212 additions & 6 deletions FastGeodis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def generalised_geodesic2d_toivanen(
"Geos: Geodesic image segmentation."
European Conference on Computer Vision, Berlin, Heidelberg, 2008.
The function expects input as torch.Tensor, which can be run on CPU or GPU depending on Tensor's device location
The function expects input as torch.Tensor, which can be run on CPU only using Tensor's device location
Args:
image: input image, can be grayscale or multiple channels.
Expand Down Expand Up @@ -218,7 +218,7 @@ def generalised_geodesic3d_toivanen(
"Geos: Geodesic image segmentation."
European Conference on Computer Vision, Berlin, Heidelberg, 2008.
The function expects input as torch.Tensor, which can be run on CPU or GPU depending on Tensor's device location
The function expects input as torch.Tensor, which can be run on CPU only using Tensor's device location
Args:
image: input image, can be grayscale or multiple channels.
Expand Down Expand Up @@ -254,7 +254,7 @@ def signed_generalised_geodesic2d_toivanen(
"Geos: Geodesic image segmentation."
European Conference on Computer Vision, Berlin, Heidelberg, 2008.
The function expects input as torch.Tensor, which can be run on CPU or GPU depending on Tensor's device location
The function expects input as torch.Tensor, which can be run on CPU only using Tensor's device location
Args:
image: input image, can be grayscale or multiple channels.
Expand Down Expand Up @@ -291,7 +291,7 @@ def signed_generalised_geodesic3d_toivanen(
"Geos: Geodesic image segmentation."
European Conference on Computer Vision, Berlin, Heidelberg, 2008.
The function expects input as torch.Tensor, which can be run on CPU or GPU depending on Tensor's device location
The function expects input as torch.Tensor, which can be run on CPU only using Tensor's device location
Args:
image: input image, can be grayscale or multiple channels.
Expand All @@ -308,6 +308,145 @@ def signed_generalised_geodesic3d_toivanen(
image, softmask, spacing, v, lamb, 1 - lamb, iter
)

def generalised_geodesic2d_fastmarch(
image: torch.Tensor,
softmask: torch.Tensor,
v: float,
lamb: float
):
r"""Computes Generalised Geodesic Distance using Fast Marching method from:
Sethian, James A.
"Fast marching methods."
SIAM review 41.2 (1999): 199-235.
For more details on generalised geodesic distance, check the following reference:
Criminisi, Antonio, Toby Sharp, and Andrew Blake.
"Geos: Geodesic image segmentation."
European Conference on Computer Vision, Berlin, Heidelberg, 2008.
The function expects input as torch.Tensor, which can be run on CPU only using Tensor's device location
Args:
image: input image, can be grayscale or multiple channels.
softmask: softmask in range [0, 1] with seed information.
v: weighting factor for establishing relationship between unary and spatial distances.
lamb: weighting factor between 0.0 and 1.0. 0.0 returns euclidean distance, whereas 1.0 returns geodesic distance
Returns:
torch.Tensor with distance transform
"""
return FastGeodisCpp.generalised_geodesic2d_fastmarch(
image, softmask, v, lamb, 1 - lamb
)


def generalised_geodesic3d_fastmarch(
image: torch.Tensor,
softmask: torch.Tensor,
spacing: List,
v: float,
lamb: float
):
r"""Computes Generalised Geodesic Distance using Fast Marching method from:
TSethian, James A.
"Fast marching methods."
SIAM review 41.2 (1999): 199-235.
For more details on generalised geodesic distance, check the following reference:
Criminisi, Antonio, Toby Sharp, and Andrew Blake.
"Geos: Geodesic image segmentation."
European Conference on Computer Vision, Berlin, Heidelberg, 2008.
The function expects input as torch.Tensor, which can be run on CPU only using Tensor's device location
Args:
image: input image, can be grayscale or multiple channels.
softmask: softmask in range [0, 1] with seed information.
spacing: spacing for 3D data
v: weighting factor for establishing relationship between unary and spatial distances.
lamb: weighting factor between 0.0 and 1.0. 0.0 returns euclidean distance, whereas 1.0 returns geodesic distance
Returns:
torch.Tensor with distance transform
"""
return FastGeodisCpp.generalised_geodesic3d_fastmarch(
image, softmask, spacing, v, lamb, 1 - lamb
)

def signed_generalised_geodesic2d_fastmarch(
image: torch.Tensor,
softmask: torch.Tensor,
v: float,
lamb: float
):
r"""Computes Signed Generalised Geodesic Distance using Fast Marching method from:
Sethian, James A.
"Fast marching methods."
SIAM review 41.2 (1999): 199-235.
For more details on generalised geodesic distance, check the following reference:
Criminisi, Antonio, Toby Sharp, and Andrew Blake.
"Geos: Geodesic image segmentation."
European Conference on Computer Vision, Berlin, Heidelberg, 2008.
The function expects input as torch.Tensor, which can be run on CPU only using Tensor's device location
Args:
image: input image, can be grayscale or multiple channels.
softmask: softmask in range [0, 1] with seed information.
v: weighting factor for establishing relationship between unary and spatial distances.
lamb: weighting factor between 0.0 and 1.0. 0.0 returns euclidean distance, whereas 1.0 returns geodesic distance
Returns:
torch.Tensor with distance transform
"""
return FastGeodisCpp.signed_generalised_geodesic2d_fastmarch(
image, softmask, v, lamb, 1 - lamb
)


def signed_generalised_geodesic3d_fastmarch(
image: torch.Tensor,
softmask: torch.Tensor,
spacing: List,
v: float,
lamb: float
):
r"""Computes Signed Generalised Geodesic Distance using Fast Marching method from:
Sethian, James A.
"Fast marching methods."
SIAM review 41.2 (1999): 199-235.
For more details on generalised geodesic distance, check the following reference:
Criminisi, Antonio, Toby Sharp, and Andrew Blake.
"Geos: Geodesic image segmentation."
European Conference on Computer Vision, Berlin, Heidelberg, 2008.
The function expects input as torch.Tensor, which can be run on CPU only using Tensor's device location
Args:
image: input image, can be grayscale or multiple channels.
softmask: softmask in range [0, 1] with seed information.
spacing: spacing for 3D data
v: weighting factor for establishing relationship between unary and spatial distances.
lamb: weighting factor between 0.0 and 1.0. 0.0 returns euclidean distance, whereas 1.0 returns geodesic distance
iter: number of passes of the iterative distance transform method
Returns:
torch.Tensor with distance transform
"""
return FastGeodisCpp.signed_generalised_geodesic3d_fastmarch(
image, softmask, spacing, v, lamb, 1 - lamb
)

def GSF2d(
image: torch.Tensor,
softmask: torch.Tensor,
Expand Down Expand Up @@ -389,7 +528,7 @@ def GSF2d_toivanen(
"Geos: Geodesic image segmentation."
European Conference on Computer Vision, Berlin, Heidelberg, 2008.
The function expects input as torch.Tensor, which can be run on CPU or GPU depending on Tensor's device location
The function expects input as torch.Tensor, which can be run on CPU only using Tensor's device location
Args:
image: input image, can be grayscale or multiple channels.
Expand Down Expand Up @@ -425,7 +564,7 @@ def GSF3d_toivanen(
"Geos: Geodesic image segmentation."
European Conference on Computer Vision, Berlin, Heidelberg, 2008.
The function expects input as torch.Tensor, which can be run on CPU or GPU depending on Tensor's device location
The function expects input as torch.Tensor, which can be run on CPU only using Tensor's device location
Args:
image: input image, can be grayscale or multiple channels.
Expand All @@ -439,3 +578,70 @@ def GSF3d_toivanen(
torch.Tensor with distance transform
"""
return FastGeodisCpp.GSF3d_toivanen(image, softmask, theta, spacing, v, lamb, iter)

def GSF2d_fastmarch(
image: torch.Tensor,
softmask: torch.Tensor,
theta: float,
v: float,
lamb: float
):
r"""Computes Geodesic Symmetric Filtering (GSF) using Fast Marching method from:
Sethian, James A.
"Fast marching methods."
SIAM review 41.2 (1999): 199-235.
For more details on GSF, check the following reference:
Criminisi, Antonio, Toby Sharp, and Andrew Blake.
"Geos: Geodesic image segmentation."
European Conference on Computer Vision, Berlin, Heidelberg, 2008.
The function expects input as torch.Tensor, which can be run on CPU only using Tensor's device location
Args:
image: input image, can be grayscale or multiple channels.
softmask: softmask in range [0, 1] with seed information.
v: weighting factor for establishing relationship between unary and spatial distances.
lamb: weighting factor between 0.0 and 1.0. 0.0 returns euclidean distance, whereas 1.0 returns geodesic distance
Returns:
torch.Tensor with distance transform
"""
return FastGeodisCpp.GSF2d_fastmarch(image, softmask, theta, v, lamb)


def GSF3d_fastmarch(
image: torch.Tensor,
softmask: torch.Tensor,
theta: float,
spacing: List,
v: float,
lamb: float,
):
r"""Computes Geodesic Symmetric Filtering (GSF) using Fast Marching method from:
Sethian, James A.
"Fast marching methods."
SIAM review 41.2 (1999): 199-235.
For more details on GSF, check the following reference:
Criminisi, Antonio, Toby Sharp, and Andrew Blake.
"Geos: Geodesic image segmentation."
European Conference on Computer Vision, Berlin, Heidelberg, 2008.
The function expects input as torch.Tensor, which can be run on CPU only using Tensor's device location
Args:
image: input image, can be grayscale or multiple channels.
softmask: softmask in range [0, 1] with seed information.
spacing: spacing for 3D data
v: weighting factor for establishing relationship between unary and spatial distances.
lamb: weighting factor between 0.0 and 1.0. 0.0 returns euclidean distance, whereas 1.0 returns geodesic distance
Returns:
torch.Tensor with distance transform
"""
return FastGeodisCpp.GSF3d_fastmarch(image, softmask, theta, spacing, v, lamb)
81 changes: 81 additions & 0 deletions FastGeodis/fastgeodis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,36 @@ torch::Tensor generalised_geodesic3d_toivanen(torch::Tensor &image, const torch:
return generalised_geodesic3d_toivanen_cpu(image, mask, spacing, v, l_grad, l_eucl, iterations);
}

torch::Tensor generalised_geodesic2d_fastmarch(torch::Tensor &image, const torch::Tensor &mask, const float &v, const float &l_grad, const float &l_eucl)
{

// check input dimensions
check_input_dimensions(image, mask, 4);

// fastmarch method is only implementable on cpu
check_cpu(image);
check_cpu(mask);

return generalised_geodesic2d_fastmarch_cpu(image, mask, v, l_grad, l_eucl);
}

torch::Tensor generalised_geodesic3d_fastmarch(torch::Tensor &image, const torch::Tensor &mask, const std::vector<float> &spacing, const float &v, const float &l_grad, const float &l_eucl)
{
// check input dimensions
check_input_dimensions(image, mask, 5);

// fastmarch method is only implementable on cpu
check_cpu(image);
check_cpu(mask);

if (spacing.size() != 3)
{
throw std::invalid_argument(
"function only supports 3D spacing inputs, received " + std::to_string(spacing.size()));
}

return generalised_geodesic3d_fastmarch_cpu(image, mask, spacing, v, l_grad, l_eucl);
}

torch::Tensor getDs2d(torch::Tensor &image, const torch::Tensor &mask, const float &v, const float &l_grad, const float &l_eucl, const int &iterations)
{
Expand Down Expand Up @@ -260,6 +290,48 @@ torch::Tensor GSF3d_toivanen(torch::Tensor &image, const torch::Tensor &mask, co
return Dd_Md + De_Me;
}

torch::Tensor getDs2d_fastmarch(torch::Tensor &image, const torch::Tensor &mask, const float &v, const float &l_grad, const float &l_eucl)
{
torch::Tensor D_M = generalised_geodesic2d_fastmarch(image, mask, v, l_grad, l_eucl);
torch::Tensor D_Mb = generalised_geodesic2d_fastmarch(image, 1 - mask, v, l_grad, l_eucl);

return D_M - D_Mb;
}

torch::Tensor GSF2d_fastmarch(torch::Tensor &image, const torch::Tensor &mask, const float &theta, const float &v, const float &lambda)
{
torch::Tensor Ds_M = getDs2d_fastmarch(image, mask, v, lambda, 1 - lambda);

torch::Tensor Md = (Ds_M > theta).type_as(Ds_M);
torch::Tensor Me = (Ds_M > -theta).type_as(Ds_M);

torch::Tensor Dd_Md = -getDs2d_fastmarch(image, 1 - Md, v, lambda, 1 - lambda);
torch::Tensor De_Me = getDs2d_fastmarch(image, Me, v, lambda, 1 - lambda);

return Dd_Md + De_Me;
}

torch::Tensor getDs3d_fastmarch(torch::Tensor &image, const torch::Tensor &mask, const std::vector<float> &spacing, const float &v, const float &l_grad, const float &l_eucl)
{
torch::Tensor D_M = generalised_geodesic3d_fastmarch(image, mask, spacing, v, l_grad, l_eucl);
torch::Tensor D_Mb = generalised_geodesic3d_fastmarch(image, 1 - mask, spacing, v, l_grad, l_eucl);

return D_M - D_Mb;
}

torch::Tensor GSF3d_fastmarch(torch::Tensor &image, const torch::Tensor &mask, const float &theta, const std::vector<float> &spacing, const float &v, const float &lambda)
{
torch::Tensor Ds_M = getDs3d_fastmarch(image, mask, spacing, v, lambda, 1 - lambda);

torch::Tensor Md = (Ds_M > theta).type_as(Ds_M);
torch::Tensor Me = (Ds_M > -theta).type_as(Ds_M);

torch::Tensor Dd_Md = -getDs3d_fastmarch(image, 1 - Md, spacing, v, lambda, 1 - lambda);
torch::Tensor De_Me = getDs3d_fastmarch(image, Me, spacing, v, lambda, 1 - lambda);

return Dd_Md + De_Me;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("generalised_geodesic2d", &generalised_geodesic2d, "Generalised Geodesic distance 2d");
Expand All @@ -270,11 +342,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("GSF2d_toivanen", &GSF2d_toivanen, "Geodesic Symmetric Filtering 2d using Toivanen's method");
m.def("signed_generalised_geodesic2d_toivanen", &getDs2d_toivanen, "Signed Generalised Geodesic distance 2d using Toivanen's method");

m.def("generalised_geodesic2d_fastmarch", &generalised_geodesic2d_fastmarch, "Generalised Geodesic distance 2d using Fast Marching method");
m.def("GSF2d_fastmarch", &GSF2d_fastmarch, "Geodesic Symmetric Filtering 2d using Fast Marching method");
m.def("signed_generalised_geodesic2d_fastmarch", &getDs2d_fastmarch, "Signed Generalised Geodesic distance 2d using Fast Marching method");

m.def("generalised_geodesic3d", &generalised_geodesic3d, "Generalised Geodesic distance 3d");
m.def("GSF3d", &GSF3d, "Geodesic Symmetric Filtering 3d");
m.def("signed_generalised_geodesic3d", &getDs3d, "Signed Generalised Geodesic distance 3d");

m.def("generalised_geodesic3d_toivanen", &generalised_geodesic3d_toivanen, "Generalised Geodesic distance 3d using Toivanen's method");
m.def("GSF3d_toivanen", &GSF3d_toivanen, "Geodesic Symmetric Filtering 3d using Toivanen's method");
m.def("signed_generalised_geodesic3d_toivanen", &getDs3d_toivanen, "Signed Generalised Geodesic distance 3d using Toivanen's method");

m.def("generalised_geodesic3d_fastmarch", &generalised_geodesic3d_fastmarch, "Generalised Geodesic distance 3d using Fast Marching method");
m.def("GSF3d_fastmarch", &GSF3d_fastmarch, "Geodesic Symmetric Filtering 3d using Fast Marching method");
m.def("signed_generalised_geodesic3d_fastmarch", &getDs3d_fastmarch, "Signed Generalised Geodesic distance 3d using Fast Marching method");

}
Loading

0 comments on commit 0ce4f89

Please sign in to comment.