Skip to content

Commit

Permalink
Working attempt at addressing #1
Browse files Browse the repository at this point in the history
  • Loading branch information
tvercaut committed Jul 31, 2023
1 parent 2904637 commit ef475a9
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 96 deletions.
41 changes: 31 additions & 10 deletions bindings/python/py-distance-transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace py = pybind11;
template <typename T, int N>
py::array_t<T> DistanceTransformImpl(
py::array_t<T, py::array::c_style | py::array::forcecast> maskarray,
bool computeSquareDistance) {
bool computeSquareDistance, const std::vector<T>& alphas) {
// Get input shape
typename dope::DopeVector<T, N>::IndexD masksize, pymasksize;
std::copy_n(maskarray.shape(), N, pymasksize.begin());
Expand All @@ -31,35 +31,55 @@ py::array_t<T> DistanceTransformImpl(
dope::Grid<T, N> dopefield(masksize);

dt::DistanceTransform::distanceTransformL2(dopemask, dopefield,
computeSquareDistance);
computeSquareDistance, alphas);

return py::array_t<T>(pymasksize, dopefield.data());
}

template <typename T>
py::array_t<T> DistanceTransform(
py::array_t<T, py::array::c_style | py::array::forcecast> maskarray,
bool computeSquareDistance = false) {
bool computeSquareDistance = false,
std::optional<py::array_t<T, py::array::c_style | py::array::forcecast> >
optalphas = py::none()) {
// std::cout<<"Got input with dtype="<<maskarray.dtype()<<",
// ndim="<<maskarray.ndim()<<std::endl;

std::vector<T> alphas(maskarray.ndim());
if (optalphas.has_value()) {
// Essentially doing alphas = optalphas.value();
if (optalphas.value().size() != maskarray.ndim()) {
throw std::out_of_range("Alpha vector size is not equal to dimension.");
}
std::copy_n(optalphas.value().data(), maskarray.ndim(), alphas.begin());
} else {
alphas = std::vector<T>(maskarray.ndim(), 1.0);
}

switch (maskarray.ndim()) {
case 1: {
return DistanceTransformImpl<T, 1>(maskarray, computeSquareDistance);
return DistanceTransformImpl<T, 1>(maskarray, computeSquareDistance,
alphas);
}
case 2: {
return DistanceTransformImpl<T, 2>(maskarray, computeSquareDistance);
return DistanceTransformImpl<T, 2>(maskarray, computeSquareDistance,
alphas);
}
case 3: {
return DistanceTransformImpl<T, 3>(maskarray, computeSquareDistance);
return DistanceTransformImpl<T, 3>(maskarray, computeSquareDistance,
alphas);
}
case 4: {
return DistanceTransformImpl<T, 4>(maskarray, computeSquareDistance);
return DistanceTransformImpl<T, 4>(maskarray, computeSquareDistance,
alphas);
}
case 5: {
return DistanceTransformImpl<T, 5>(maskarray, computeSquareDistance);
return DistanceTransformImpl<T, 5>(maskarray, computeSquareDistance,
alphas);
}
case 6: {
return DistanceTransformImpl<T, 6>(maskarray, computeSquareDistance);
return DistanceTransformImpl<T, 6>(maskarray, computeSquareDistance,
alphas);
}
default: {
throw std::out_of_range("Dimension " + std::to_string(maskarray.ndim()) +
Expand All @@ -82,5 +102,6 @@ PYBIND11_MODULE(py_distance_transform, m) {
Compute the distance transform
https://github.com/tvercaut/distance_transform
)pbdoc",
py::arg("maskarray"), py::arg("computeSquareDistance") = false);
py::arg("maskarray"), py::arg("computeSquareDistance") = false,
py::arg("alphas") = py::none());
}
24 changes: 16 additions & 8 deletions example/distance_transform_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ int runexample() {

std::chrono::steady_clock::time_point start =
std::chrono::steady_clock::now();
dt::DistanceTransform::distanceTransformL2(f, f, indices, false, 1);
dt::DistanceTransform::distanceTransformL2(f, f, indices, false,
std::vector<float>(2, 1.0), 1);
std::cout << std::endl
<< "2D distance function computed in: "
<< std::chrono::duration_cast<std::chrono::nanoseconds>(
Expand Down Expand Up @@ -137,7 +138,8 @@ int runexample() {
}

start = std::chrono::steady_clock::now();
dt::DistanceTransform::distanceTransformL2(fWin, fWin, indicesWin, true, 1);
dt::DistanceTransform::distanceTransformL2(fWin, fWin, indicesWin, true,
std::vector<float>(2, 1.0), 1);
std::cout << std::endl
<< "2D distance function computed on the window in: "
<< std::chrono::duration_cast<std::chrono::nanoseconds>(
Expand Down Expand Up @@ -168,7 +170,8 @@ int runexample() {
f2D[i][j] = std::numeric_limits<float>::max();
f2D[0][0] = 0.0f;
start = std::chrono::steady_clock::now();
dt::DistanceTransform::distanceTransformL2(f2D, f2D, false, 1);
dt::DistanceTransform::distanceTransformL2(f2D, f2D, false,
std::vector<float>(2, 1.0), 1);
std::cout << std::endl
<< size[0] << 'x' << size[1] << " distance function computed in: "
<< std::chrono::duration_cast<std::chrono::milliseconds>(
Expand All @@ -185,7 +188,8 @@ int runexample() {
f3D[i][j][k] = std::numeric_limits<float>::max();
f3D[0][0][0] = 0.0f;
start = std::chrono::steady_clock::now();
dt::DistanceTransform::distanceTransformL2(f3D, f3D, false, 1);
dt::DistanceTransform::distanceTransformL2(f3D, f3D, false,
std::vector<float>(3, 1.0), 1);
std::cout << std::endl
<< size3D[0] << 'x' << size3D[1] << 'x' << size3D[2]
<< " distance function computed in: "
Expand All @@ -202,7 +206,8 @@ int runexample() {
f3D[0][0][0] = 0.0f;
start = std::chrono::steady_clock::now();
dt::DistanceTransform::distanceTransformL2(
f3D, f3D, false, std::thread::hardware_concurrency());
f3D, f3D, false, std::vector<float>(3, 1.0),
std::thread::hardware_concurrency());
std::cout << std::endl
<< size3D[0] << 'x' << size3D[1] << 'x' << size3D[2]
<< " distance function (concurrently) computed in: "
Expand All @@ -220,7 +225,8 @@ int runexample() {
f2DBig[i][j] = std::numeric_limits<float>::max();
f2DBig[0][0] = 0.0f;
start = std::chrono::steady_clock::now();
dt::DistanceTransform::distanceTransformL2(f2DBig, f2DBig, false, 1);
dt::DistanceTransform::distanceTransformL2(f2DBig, f2DBig, false,
std::vector<float>(2, 1.0), 1);
std::cout << std::endl
<< size[0] << 'x' << size[1] << " distance function computed in: "
<< std::chrono::duration_cast<std::chrono::milliseconds>(
Expand All @@ -235,7 +241,8 @@ int runexample() {
f2DBig[0][0] = 0.0f;
start = std::chrono::steady_clock::now();
dt::DistanceTransform::distanceTransformL2(
f2DBig, f2DBig, false, std::thread::hardware_concurrency());
f2DBig, f2DBig, false, std::vector<float>(2, 1.0),
std::thread::hardware_concurrency());
std::cout << std::endl
<< size[0] << 'x' << size[1]
<< " distance function (concurrently) computed in: "
Expand All @@ -257,7 +264,8 @@ int runexample() {
f6D[i][j][k][l][m][n] = std::numeric_limits<float>::max();
f6D[0][0][0][0][0][0] = 0.0f;
start = std::chrono::steady_clock::now();
dt::DistanceTransform::distanceTransformL2(f6D, f6D, false, 1);
dt::DistanceTransform::distanceTransformL2(f6D, f6D, false,
std::vector<float>(6, 1.0), 1);
std::cout << std::endl
<< size6D[0] << 'x' << size6D[1] << 'x' << size6D[2] << 'x'
<< size6D[3] << 'x' << size6D[4] << 'x' << size6D[5]
Expand Down
10 changes: 6 additions & 4 deletions example/py-distance-transform-minimal-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@


def distance_transform_example():
mask = np.ones((51,100), dtype=float)*1e2
mask[10,23] = 0
mask[35,84] = 0
mask = np.ones((51, 100), dtype=float) * 1e3
mask[10, 23] = 0
mask[35, 84] = 0

distance_map = dt.distance_transform(mask,True)
spacings = [2, 1.0]

distance_map = dt.distance_transform(mask, False, spacings)

print("Max dist:", np.max(distance_map[:]))
print(distance_map.dtype)
Expand Down
33 changes: 25 additions & 8 deletions include/distance_transform/distance_transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define INCLUDE_DISTANCE_TRANSFORM_DISTANCE_TRANSFORM_H_

#include <thread>
#include <vector>

#include "dope_vector/Grid.h"

Expand All @@ -32,6 +33,8 @@ class DistanceTransform {
* @param D The resulting distance field of f.
* @param squared Compute squared distances (L2)^2 - avoiding to
* compute square roots - (true) or keep them normal (false - default).
* @param alphas Weighting factor for anisotropic distances (square
* of pixel/voxel spacing)
* @param nThreads The number of threads for parallel computation. If
* <= 1, the computation will be sequential.
* @note Arrays f and D can also be the same.
Expand All @@ -40,6 +43,7 @@ class DistanceTransform {
inline static void distanceTransformL2(
const dope::DopeVector<Scalar, DIM> &f, dope::DopeVector<Scalar, DIM> &D,
const bool squared = false,
std::vector<Scalar> alphas = std::vector<Scalar>(DIM, 1.0),
const std::size_t nThreads = std::thread::hardware_concurrency());

/**
Expand All @@ -49,6 +53,8 @@ class DistanceTransform {
* @param D The resulting distance field of f.
* @param squared Compute squared distances (L2)^2 - avoiding to
* compute square roots - (true) or keep them normal (false - default).
* @param alphas Weighting factor for anisotropic distances (square
* of pixel/voxel spacing)
* @param nThreads The number of threads for parallel computation.
* Actually NOT used, since it's not easy to run a single row computation in
* parallel.
Expand All @@ -58,6 +64,7 @@ class DistanceTransform {
inline static void distanceTransformL2(
const dope::DopeVector<Scalar, 1> &f, dope::DopeVector<Scalar, 1> &D,
const bool squared = false,
std::vector<Scalar> alphas = std::vector<Scalar>(1, 1.0),
const std::size_t nThreads = std::thread::hardware_concurrency());

/**
Expand All @@ -70,6 +77,8 @@ class DistanceTransform {
* local minimum for each sample.
* @param squared Compute squared distances (L2)^2 - avoiding to
* compute square roots - (true) or keep them normal (false - default).
* @param alphas Weighting factor for anisotropic distances (square
* of pixel/voxel spacing)
* @param nThreads The number of threads for parallel computation. If
* <= 1, the computation will be sequential.
* @note Arrays f and D can also be the same. I should be first
Expand All @@ -79,6 +88,7 @@ class DistanceTransform {
inline static void distanceTransformL2(
const dope::DopeVector<Scalar, DIM> &f, dope::DopeVector<Scalar, DIM> &D,
dope::DopeVector<dope::SizeType, DIM> &I, const bool squared = false,
std::vector<Scalar> alphas = std::vector<Scalar>(DIM, 1.0),
const std::size_t nThreads = std::thread::hardware_concurrency());

/**
Expand All @@ -91,6 +101,8 @@ class DistanceTransform {
* local minimum for each sample.
* @param squared Compute squared distances (L2)^2 - avoiding to
* compute square roots - (true) or keep them normal (false - default).
* @param alphas Weighting factor for anisotropic distances (square
* of pixel/voxel spacing)
* @param nThreads The number of threads for parallel computation.
* Actually NOT used, since it's not easy to run a single row computation in
* parallel.
Expand All @@ -100,6 +112,7 @@ class DistanceTransform {
inline static void distanceTransformL2(
const dope::DopeVector<Scalar, 1> &f, dope::DopeVector<Scalar, 1> &D,
dope::DopeVector<dope::SizeType, 1> &I, const bool squared = false,
std::vector<Scalar> alphas = std::vector<Scalar>(1, 1.0),
const std::size_t nThreads = std::thread::hardware_concurrency());

/**
Expand All @@ -126,12 +139,13 @@ class DistanceTransform {
* window, in multi-threading).
* @param D The resulting distance field of f (a window, in
* multi-threading).
* @param d The dimension where to slice.
* @param order The order in which to permute the slices.
* @param alpha A multiplier stretching each parabola (L2 distance)
* vertically.
*/
template <typename Scalar, dope::SizeType DIM>
inline static void distanceL2Helper(const dope::DopeVector<Scalar, DIM> &f,
dope::DopeVector<Scalar, DIM> &D);
dope::DopeVector<Scalar, DIM> &D,
const Scalar alpha);

/**
* @brief The actual distance field computation is done by recursive calls
Expand All @@ -141,7 +155,8 @@ class DistanceTransform {
*/
template <typename Scalar, dope::SizeType DIM>
inline static void distanceL2(const dope::DopeVector<Scalar, DIM> &f,
dope::DopeVector<Scalar, DIM> &D);
dope::DopeVector<Scalar, DIM> &D,
const Scalar alpha);

/**
* @brief The actual distance field computation as in the "Distance
Expand All @@ -152,7 +167,8 @@ class DistanceTransform {
*/
template <typename Scalar>
inline static void distanceL2(const dope::DopeVector<Scalar, 1> &f,
dope::DopeVector<Scalar, 1> &D);
dope::DopeVector<Scalar, 1> &D,
const Scalar alpha);

/**
* @brief The loop iteration process that can be executed sequentially and
Expand All @@ -170,7 +186,7 @@ class DistanceTransform {
inline static void distanceL2Helper(
const dope::DopeVector<Scalar, DIM> &f, dope::DopeVector<Scalar, DIM> &D,
const dope::DopeVector<dope::SizeType, DIM> &Ipre,
dope::DopeVector<dope::SizeType, DIM> &Ipost);
dope::DopeVector<dope::SizeType, DIM> &Ipost, const Scalar alpha);

/**
* @brief The actual distance field computation is done by recursive calls
Expand All @@ -184,7 +200,7 @@ class DistanceTransform {
inline static void distanceL2(
const dope::DopeVector<Scalar, DIM> &f, dope::DopeVector<Scalar, DIM> &D,
const dope::DopeVector<dope::SizeType, DIM> &Ipre,
dope::DopeVector<dope::SizeType, DIM> &Ipost);
dope::DopeVector<dope::SizeType, DIM> &Ipost, const Scalar alpha);

/**
* @brief The actual distance field computation as in the "Distance
Expand All @@ -199,7 +215,8 @@ class DistanceTransform {
inline static void distanceL2(const dope::DopeVector<Scalar, 1> &f,
dope::DopeVector<Scalar, 1> &D,
const dope::DopeVector<dope::SizeType, 1> &Ipre,
dope::DopeVector<dope::SizeType, 1> &Ipost);
dope::DopeVector<dope::SizeType, 1> &Ipost,
const Scalar alpha);

public:
/**
Expand Down
Loading

0 comments on commit ef475a9

Please sign in to comment.