Skip to content

Commit

Permalink
Add support for generalized domains in for_each and transform_reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
tpadioleau committed Dec 9, 2024
1 parent 73c8437 commit 47b0348
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 27 deletions.
3 changes: 1 addition & 2 deletions include/ddc/chunk_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,7 @@ class ChunkCommon
template <class QueryDDim>
KOKKOS_FUNCTION constexpr size_type stride() const
{
return m_allocation_mdspan.stride(
type_seq_rank_v<QueryDDim, detail::ToTypeSeq<SupportType>>);
return m_allocation_mdspan.stride(type_seq_rank_v<QueryDDim, to_type_seq_t<SupportType>>);
}

/** Provide access to the domain on which this chunk is defined
Expand Down
6 changes: 6 additions & 0 deletions include/ddc/discrete_domain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,12 @@ class DiscreteDomain<>
return *this;
}

KOKKOS_FUNCTION constexpr DiscreteElement<> operator()(
DiscreteVector<> const& dvect) const noexcept
{
return DiscreteElement<>();
}

#if defined(DDC_BUILD_DEPRECATED_CODE)
template <class... ODims>
[[deprecated("Use `restrict_with` instead")]] KOKKOS_FUNCTION constexpr DiscreteDomain restrict(
Expand Down
23 changes: 10 additions & 13 deletions include/ddc/for_each.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ namespace ddc {

namespace detail {

template <class RetType, class Element, std::size_t N, class Functor, class... Is>
template <class Support, class Element, std::size_t N, class Functor, class... Is>
void for_each_serial(
std::array<Element, N> const& begin,
std::array<Element, N> const& end,
Support const& support,
std::array<Element, N> const& size,
Functor const& f,
Is const&... is) noexcept
{
static constexpr std::size_t I = sizeof...(Is);
if constexpr (I == N) {
f(RetType(is...));
f(support(typename Support::discrete_vector_type(is...)));
} else {
for (Element ii = begin[I]; ii < end[I]; ++ii) {
for_each_serial<RetType>(begin, end, f, is..., ii);
for (Element ii = 0; ii < size[I]; ++ii) {
for_each_serial(support, size, f, is..., ii);
}
}
}
Expand All @@ -39,14 +39,11 @@ void for_each_serial(
* @param[in] domain the domain over which to iterate
* @param[in] f a functor taking an index as parameter
*/
template <class... DDims, class Functor>
void for_each(DiscreteDomain<DDims...> const& domain, Functor&& f) noexcept
template <class Support, class Functor>
void for_each(Support const& domain, Functor&& f) noexcept
{
DiscreteElement<DDims...> const ddc_begin = domain.front();
DiscreteElement<DDims...> const ddc_end = domain.front() + domain.extents();
std::array const begin = detail::array(ddc_begin);
std::array const end = detail::array(ddc_end);
detail::for_each_serial<DiscreteElement<DDims...>>(begin, end, std::forward<Functor>(f));
std::array const size = detail::array(domain.extents());
detail::for_each_serial(domain, size, std::forward<Functor>(f));
}

} // namespace ddc
6 changes: 6 additions & 0 deletions include/ddc/strided_discrete_domain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,12 @@ class StridedDiscreteDomain<>
return *this;
}

KOKKOS_FUNCTION constexpr DiscreteElement<> operator()(
DiscreteVector<> const& dvect) const noexcept
{
return DiscreteElement<>();
}

#if defined(DDC_BUILD_DEPRECATED_CODE)
template <class... ODims>
[[deprecated(
Expand Down
20 changes: 8 additions & 12 deletions include/ddc/transform_reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "ddc/detail/macros.hpp"
#include "ddc/discrete_domain.hpp"
#include "ddc/discrete_element.hpp"
#include "ddc/strided_discrete_domain.hpp"

namespace ddc {

Expand All @@ -23,24 +24,19 @@ namespace detail {
* range. The return type must be acceptable as input to reduce
* @param[in] dcoords discrete elements from dimensions already in a loop
*/
template <
class... DDims,
class T,
class BinaryReductionOp,
class UnaryTransformOp,
class... DCoords>
template <class Support, class T, class BinaryReductionOp, class UnaryTransformOp, class... DCoords>
T transform_reduce_serial(
DiscreteDomain<DDims...> const& domain,
Support const& domain,
[[maybe_unused]] T const neutral,
BinaryReductionOp const& reduce,
UnaryTransformOp const& transform,
DCoords const&... dcoords) noexcept
{
DDC_IF_NVCC_THEN_PUSH_AND_SUPPRESS(implicit_return_from_non_void_function)
if constexpr (sizeof...(DCoords) == sizeof...(DDims)) {
return transform(DiscreteElement<DDims...>(dcoords...));
if constexpr (sizeof...(DCoords) == Support::rank()) {
return transform(typename Support::discrete_element_type(dcoords...));
} else {
using CurrentDDim = type_seq_element_t<sizeof...(DCoords), detail::TypeSeq<DDims...>>;
using CurrentDDim = type_seq_element_t<sizeof...(DCoords), to_type_seq_t<Support>>;
T result = neutral;
for (DiscreteElement<CurrentDDim> const ii : select<CurrentDDim>(domain)) {
result = reduce(
Expand All @@ -62,9 +58,9 @@ T transform_reduce_serial(
* @param[in] transform a unary FunctionObject that will be applied to each element of the input
* range. The return type must be acceptable as input to reduce
*/
template <class... DDims, class T, class BinaryReductionOp, class UnaryTransformOp>
template <class Support, class T, class BinaryReductionOp, class UnaryTransformOp>
T transform_reduce(
DiscreteDomain<DDims...> const& domain,
Support const& domain,
T neutral,
BinaryReductionOp&& reduce,
UnaryTransformOp&& transform) noexcept
Expand Down

0 comments on commit 47b0348

Please sign in to comment.