Skip to content

Commit

Permalink
Rework and simplify the constructors of NonUniformPointSampling (#694)
Browse files Browse the repository at this point in the history
  • Loading branch information
tpadioleau authored Dec 5, 2024
1 parent c02adf7 commit f5c0a2c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 31 deletions.
35 changes: 8 additions & 27 deletions include/ddc/non_uniform_point_sampling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,48 +66,29 @@ class NonUniformPointSampling : detail::NonUniformPointSamplingBase
Impl() = default;

/// @brief Construct a `NonUniformPointSampling` using a brace-list, i.e. `NonUniformPointSampling mesh({0., 1.})`
Impl(std::initializer_list<continuous_element_type> points)
Impl(std::initializer_list<continuous_element_type> const points)
: Impl(points.begin(), points.end())
{
if (!std::is_sorted(points.begin(), points.end())) {
throw std::runtime_error("Input points must be sorted");
}
std::vector<continuous_element_type> host_points(points.begin(), points.end());
Kokkos::View<continuous_element_type*, Kokkos::HostSpace> const
host(host_points.data(), host_points.size());
Kokkos::resize(m_points, host.extent(0));
Kokkos::deep_copy(m_points, host);
}

/// @brief Construct a `NonUniformPointSampling` using a C++20 "common range".
template <class InputRange>
explicit Impl(InputRange const& points)
explicit Impl(InputRange const& points) : Impl(points.begin(), points.end())
{
if (!std::is_sorted(points.begin(), points.end())) {
throw std::runtime_error("Input points must be sorted");
}
if constexpr (Kokkos::is_view_v<InputRange>) {
Kokkos::deep_copy(m_points, points);
} else {
std::vector<continuous_element_type> host_points(points.begin(), points.end());
Kokkos::View<continuous_element_type*, Kokkos::HostSpace> const
host(host_points.data(), host_points.size());
Kokkos::resize(m_points, host.extent(0));
Kokkos::deep_copy(m_points, host);
}
}

/// @brief Construct a `NonUniformPointSampling` using a pair of iterators.
template <class InputIt>
Impl(InputIt points_begin, InputIt points_end)
Impl(InputIt const points_begin, InputIt const points_end)
{
using view_type = Kokkos::View<continuous_element_type*, MemorySpace>;
if (!std::is_sorted(points_begin, points_end)) {
throw std::runtime_error("Input points must be sorted");
}
// Make a contiguous copy of [points_begin, points_end[
std::vector<continuous_element_type> host_points(points_begin, points_end);
Kokkos::View<continuous_element_type*, Kokkos::HostSpace> const
host(host_points.data(), host_points.size());
Kokkos::resize(m_points, host.extent(0));
Kokkos::deep_copy(m_points, host);
m_points = view_type("NonUniformPointSampling::points", host_points.size());
Kokkos::deep_copy(m_points, view_type(host_points.data(), host_points.size()));
}

template <class OriginMemorySpace>
Expand Down
17 changes: 13 additions & 4 deletions tests/non_uniform_point_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// SPDX-License-Identifier: MIT

#include <array>
#include <list>
#include <sstream>
#include <string>
#include <utility>
Expand Down Expand Up @@ -39,10 +40,11 @@ struct DDimY : ddc::NonUniformPointSampling<DimY>
{
};

std::array<double, 4> const array_points_x VALUES_X;
std::vector<double> const vector_points_x VALUES_X;
std::array<ddc::Coordinate<DimX>, 4> const array_points_x VALUES_X;
std::list<ddc::Coordinate<DimX>> const list_points_x VALUES_X;
std::vector<ddc::Coordinate<DimX>> const vector_points_x VALUES_X;

std::vector<double> const vector_points_y VALUES_Y;
std::vector<ddc::Coordinate<DimY>> const vector_points_y VALUES_Y;

ddc::DiscreteElement<DDimX> constexpr point_ix(2);
ddc::Coordinate<DimX> constexpr point_rx(0.3);
Expand All @@ -55,7 +57,7 @@ ddc::Coordinate<DimX, DimY> constexpr point_rxy(0.3, 0.2);

} // namespace DDC_HIP_5_7_ANONYMOUS_NAMESPACE_WORKAROUND(NON_UNIFORM_POINT_SAMPLING_CPP)

TEST(NonUniformPointSamplingTest, ListConstructor)
TEST(NonUniformPointSamplingTest, InitializerListConstructor)
{
DDimX::Impl<DDimX, Kokkos::HostSpace> const ddim_x(VALUES_X);
EXPECT_EQ(ddim_x.size(), 4);
Expand All @@ -76,6 +78,13 @@ TEST(NonUniformPointSamplingTest, VectorConstructor)
EXPECT_EQ(ddim_x.coordinate(point_ix), point_rx);
}

TEST(NonUniformPointSamplingTest, ListConstructor)
{
DDimX::Impl<DDimX, Kokkos::HostSpace> const ddim_x(list_points_x);
EXPECT_EQ(ddim_x.size(), list_points_x.size());
EXPECT_EQ(ddim_x.coordinate(point_ix), point_rx);
}

TEST(NonUniformPointSamplingTest, NotSortedVectorConstructor)
{
std::vector unordered_vector_points_x = vector_points_x;
Expand Down

0 comments on commit f5c0a2c

Please sign in to comment.