Skip to content

Commit

Permalink
extract() <=> data buffer is owning (#610)
Browse files Browse the repository at this point in the history
* extract is now possible for owning data buffers

* update docu

* add ampersand and white space

* include specialization in internal namespace

* fix clang error
  • Loading branch information
mschimek authored Jan 8, 2024
1 parent 3d2617e commit c89aa70
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 49 deletions.
5 changes: 4 additions & 1 deletion include/kamping/data_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@ class DataBuffer : private ParameterObjectBase {
/// @brief Indicates whether the buffer is allocated by KaMPIng.
static constexpr bool is_lib_allocated = allocation == BufferAllocation::lib_allocated;

static constexpr bool is_owning =
ownership == BufferOwnership::owning; ///< Indicates whether the buffer owns its underlying storage.

static constexpr bool is_modifiable =
modifiability == BufferModifiability::modifiable; ///< Indicates whether the underlying storage is modifiable.
static constexpr bool is_single_element =
Expand Down Expand Up @@ -502,7 +505,7 @@ class DataBuffer : private ParameterObjectBase {
/// state.
///
/// @return Moves the underlying container out of the DataBuffer.
template <bool enabled = allocation == BufferAllocation::lib_allocated, std::enable_if_t<enabled, bool> = true>
template <bool enabled = is_owning, std::enable_if_t<enabled, bool> = true>
MemberTypeWithConst extract() {
static_assert(
ownership == BufferOwnership::owning,
Expand Down
53 changes: 25 additions & 28 deletions include/kamping/result.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// This file is part of KaMPIng.
//
// Copyright 2021-2023 The KaMPIng Authors
// Copyright 2021-2024 The KaMPIng Authors
//
// KaMPIng is free software : you can redistribute it and/or modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
Expand Down Expand Up @@ -36,6 +36,15 @@ inline constexpr bool has_extract_v = has_member_extract_v<T>;

/// @brief Use this type if one of the template parameters of MPIResult is not used for a specific wrapped \c MPI call.
struct ResultCategoryNotUsed {};

/// @brief Helper for implementing the extract_* functions in \ref MPIResult. Is \c true if the passed buffer type owns
/// its underlying storage and is an output buffer.
template <typename Buffer>
inline constexpr bool is_extractable = Buffer::is_owning&& Buffer::is_out_buffer;

/// @brief Specialization of helper for implementing the extract_* functions in \ref MPIResult. Is always \c false;
template <>
inline constexpr bool is_extractable<internal::ResultCategoryNotUsed> = false;
} // namespace internal

/// @brief MPIResult contains the result of a \c MPI call wrapped by KaMPIng.
Expand Down Expand Up @@ -129,7 +138,7 @@ class MPIResult {
///
/// This function is only available if the underlying status is owned by the
/// MPIResult object.
/// @tparam StatusType_ Template parameter helper only needed to remove this
/// @tparam StatusObject_ Template parameter helper only needed to remove this
/// function if StatusType does not possess a member function \c extract().
/// @return Returns the underlying status object.
template <
Expand All @@ -144,9 +153,9 @@ class MPIResult {
/// This function is only available if the underlying memory is owned by the
/// MPIResult object.
/// @tparam RecvBuf_ Template parameter helper only needed to remove this
/// function if RecvBuf does not possess a member function \c extract().
/// function if RecvBuf should not be extracted (it does not own its underlying memory or is not an out-buffer).
/// @return Returns the underlying storage containing the received elements.
template <typename RecvBuf_ = RecvBuf, std::enable_if_t<kamping::internal::has_extract_v<RecvBuf_>, bool> = true>
template <typename RecvBuf_ = RecvBuf, std::enable_if_t<internal::is_extractable<RecvBuf_>, bool> = true>
decltype(auto) extract_recv_buffer() {
return _recv_buffer.extract();
}
Expand All @@ -157,9 +166,7 @@ class MPIResult {
/// @tparam RecvCounts_ Template parameter helper only needed to remove this function if RecvCounts does not possess
/// a member function \c extract().
/// @return Returns the underlying storage containing the receive counts.
template <
typename RecvCounts_ = RecvCounts,
std::enable_if_t<kamping::internal::has_extract_v<RecvCounts_>, bool> = true>
template <typename RecvCounts_ = RecvCounts, std::enable_if_t<internal::is_extractable<RecvCounts_>, bool> = true>
decltype(auto) extract_recv_counts() {
return _recv_counts.extract();
}
Expand All @@ -170,9 +177,7 @@ class MPIResult {
/// @tparam RecvCount_ Template parameter helper only needed to remove this function if RecvCount does not
/// possess a member function \c extract().
/// @return Returns the underlying storage containing the recv count.
template <
typename RecvCount_ = RecvCount,
std::enable_if_t<kamping::internal::has_extract_v<RecvCount_>, bool> = true>
template <typename RecvCount_ = RecvCount, std::enable_if_t<internal::is_extractable<RecvCount_>, bool> = true>
decltype(auto) extract_recv_count() {
return _recv_count.extract();
}
Expand All @@ -183,9 +188,7 @@ class MPIResult {
/// @tparam RecvDispls_ Template parameter helper only needed to remove this function if RecvDispls does not possess
/// a member function \c extract().
/// @return Returns the underlying storage containing the receive displacements.
template <
typename RecvDispls_ = RecvDispls,
std::enable_if_t<kamping::internal::has_extract_v<RecvDispls_>, bool> = true>
template <typename RecvDispls_ = RecvDispls, std::enable_if_t<internal::is_extractable<RecvDispls_>, bool> = true>
decltype(auto) extract_recv_displs() {
return _recv_displs.extract();
}
Expand All @@ -196,9 +199,7 @@ class MPIResult {
/// @tparam SendCounts_ Template parameter helper only needed to remove this function if SendCounts does not possess
/// a member function \c extract().
/// @return Returns the underlying storage containing the send counts.
template <
typename SendCounts_ = SendCounts,
std::enable_if_t<kamping::internal::has_extract_v<SendCounts_>, bool> = true>
template <typename SendCounts_ = SendCounts, std::enable_if_t<internal::is_extractable<SendCounts_>, bool> = true>
decltype(auto) extract_send_counts() {
return _send_counts.extract();
}
Expand All @@ -209,9 +210,7 @@ class MPIResult {
/// @tparam SendCount_ Template parameter helper only needed to remove this function if SendCount does not
/// possess a member function \c extract().
/// @return Returns the underlying storage containing the send count.
template <
typename SendCount_ = SendCount,
std::enable_if_t<kamping::internal::has_extract_v<SendCount_>, bool> = true>
template <typename SendCount_ = SendCount, std::enable_if_t<internal::is_extractable<SendCount_>, bool> = true>
decltype(auto) extract_send_count() {
return _send_count.extract();
}
Expand All @@ -222,9 +221,7 @@ class MPIResult {
/// @tparam SendDispls_ Template parameter helper only needed to remove this function if SendDispls does not possess
/// a member function \c extract().
/// @return Returns the underlying storage containing the send displacements.
template <
typename SendDispls_ = SendDispls,
std::enable_if_t<kamping::internal::has_extract_v<SendDispls_>, bool> = true>
template <typename SendDispls_ = SendDispls, std::enable_if_t<internal::is_extractable<SendDispls_>, bool> = true>
decltype(auto) extract_send_displs() {
return _send_displs.extract();
}
Expand All @@ -236,8 +233,8 @@ class MPIResult {
/// possess a member function \c extract().
/// @return Returns the underlying storage containing the send_recv_count.
template <
typename SendRecvCount_ = SendRecvCount,
std::enable_if_t<kamping::internal::has_extract_v<SendRecvCount_>, bool> = true>
typename SendRecvCount_ = SendRecvCount,
std::enable_if_t<internal::is_extractable<SendRecvCount_>, bool> = true>
decltype(auto) extract_send_recv_count() {
return _send_recv_count.extract();
}
Expand All @@ -248,7 +245,7 @@ class MPIResult {
/// @tparam SendType_ Template parameter helper only needed to remove this function if SendType does not
/// possess a member function \c extract().
/// @return Returns the underlying storage containing the send_type.
template <typename SendType_ = SendType, std::enable_if_t<kamping::internal::has_extract_v<SendType_>, bool> = true>
template <typename SendType_ = SendType, std::enable_if_t<internal::is_extractable<SendType_>, bool> = true>
decltype(auto) extract_send_type() {
return _send_type.extract();
}
Expand All @@ -259,7 +256,7 @@ class MPIResult {
/// @tparam RecvType_ Template parameter helper only needed to remove this function if RecvType does not
/// possess a member function \c extract().
/// @return Returns the underlying storage containing the send_type.
template <typename RecvType_ = RecvType, std::enable_if_t<kamping::internal::has_extract_v<RecvType_>, bool> = true>
template <typename RecvType_ = RecvType, std::enable_if_t<internal::is_extractable<RecvType_>, bool> = true>
decltype(auto) extract_recv_type() {
return _recv_type.extract();
}
Expand All @@ -270,8 +267,8 @@ class MPIResult {
/// possess a member function \c extract().
/// @return Returns the underlying storage containing the send_type.
template <
typename SendRecvType_ = SendRecvType,
std::enable_if_t<kamping::internal::has_extract_v<SendRecvType_>, bool> = true>
typename SendRecvType_ = SendRecvType,
std::enable_if_t<internal::is_extractable<SendRecvType_>, bool> = true>
decltype(auto) extract_send_recv_type() {
return _send_recv_type.extract();
}
Expand Down
29 changes: 20 additions & 9 deletions tests/data_buffer_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -591,15 +591,26 @@ TEST(DataBufferTest, has_extract) {
"Library allocated DataBuffers must have an extract() member function"
);
static_assert(
!has_extract_v<DataBuffer<
has_extract_v<DataBuffer<
int,
ParameterType::send_buf,
BufferModifiability::modifiable,
BufferOwnership::owning,
BufferType::in_buffer,
BufferResizePolicy::no_resize,
BufferAllocation::user_allocated>>,
"User allocated DataBuffers must not have an extract() member function"
"User allocated owning DataBuffers must have an extract() member function"
);
static_assert(
!has_extract_v<DataBuffer<
int,
ParameterType::send_buf,
BufferModifiability::modifiable,
BufferOwnership::referencing,
BufferType::in_buffer,
BufferResizePolicy::no_resize,
BufferAllocation::user_allocated>>,
"User allocated referencing DataBuffers must not have an extract() member function"
);
}

Expand Down Expand Up @@ -711,14 +722,14 @@ TEST(LibAllocatedContainerBasedBufferTest, prevent_usage_after_extraction) {
}

TEST(LibAllocatedContainerBasedBufferTest, prevent_usage_after_extraction_via_mpi_result) {
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::recv_buf, BufferType::in_buffer> recv_buffer;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::recv_counts, BufferType::in_buffer> recv_counts;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::recv_displs, BufferType::in_buffer> recv_displs;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::send_counts, BufferType::in_buffer> send_counts;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::send_displs, BufferType::in_buffer> send_displs;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::recv_buf, BufferType::out_buffer> recv_buffer;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::recv_counts, BufferType::out_buffer> recv_counts;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::recv_displs, BufferType::out_buffer> recv_displs;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::send_counts, BufferType::out_buffer> send_counts;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::send_displs, BufferType::out_buffer> send_displs;
// we use out_buffer here because extracting is only done from out buffers
LibAllocatedContainerBasedBuffer<int, ParameterType::recv_count, BufferType::in_buffer> recv_count;
LibAllocatedContainerBasedBuffer<int, ParameterType::send_count, BufferType::in_buffer> send_count;
LibAllocatedContainerBasedBuffer<int, ParameterType::recv_count, BufferType::out_buffer> recv_count;
LibAllocatedContainerBasedBuffer<int, ParameterType::send_count, BufferType::out_buffer> send_count;
LibAllocatedContainerBasedBuffer<int, ParameterType::send_recv_count, BufferType::out_buffer> send_recv_count;
LibAllocatedContainerBasedBuffer<MPI_Datatype, ParameterType::send_type, BufferType::out_buffer> send_type;
LibAllocatedContainerBasedBuffer<MPI_Datatype, ParameterType::recv_type, BufferType::out_buffer> recv_type;
Expand Down
12 changes: 6 additions & 6 deletions tests/named_parameters_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1495,7 +1495,7 @@ TEST(ParameterFactoriesTest, make_data_buffer) {
"Owning buffers must hold their data directly."
);
// extract() as proxy for lib allocated DataBuffers
EXPECT_FALSE(has_extract_v<decltype(data_buf)>);
EXPECT_TRUE(has_extract_v<decltype(data_buf)>);
}

{
Expand Down Expand Up @@ -1552,7 +1552,7 @@ TEST(ParameterFactoriesTest, make_data_buffer) {
"Owning buffers must hold their data directly."
);
// extract() as proxy for lib allocated DataBuffers
EXPECT_FALSE(has_extract_v<decltype(data_buf)>);
EXPECT_TRUE(has_extract_v<decltype(data_buf)>);
}
{
// Constant, container, owning, user_allocated with initializer_list
Expand All @@ -1570,7 +1570,7 @@ TEST(ParameterFactoriesTest, make_data_buffer) {
"Owning buffers must hold their data directly."
);
// extract() as proxy for lib allocated DataBuffers
EXPECT_FALSE(has_extract_v<decltype(data_buf)>);
EXPECT_TRUE(has_extract_v<decltype(data_buf)>);
}
}

Expand Down Expand Up @@ -1656,7 +1656,7 @@ TEST(ParameterFactoriesTest, make_data_buffer_boolean_value) {
"Owning buffers must hold their data directly."
);
// extract() as proxy for lib allocated DataBuffers
EXPECT_FALSE(has_extract_v<decltype(data_buf)>);
EXPECT_TRUE(has_extract_v<decltype(data_buf)>);
}

{
Expand Down Expand Up @@ -1714,7 +1714,7 @@ TEST(ParameterFactoriesTest, make_data_buffer_boolean_value) {
"Initializer lists of type bool have to be converted to std::vector<kabool>."
);
// extract() as proxy for lib allocated DataBuffers
EXPECT_FALSE(has_extract_v<decltype(data_buf)>);
EXPECT_TRUE(has_extract_v<decltype(data_buf)>);
}
{
// Constant, container, owning, user_allocated with initializer_list
Expand All @@ -1733,7 +1733,7 @@ TEST(ParameterFactoriesTest, make_data_buffer_boolean_value) {
"Initializer lists of type bool have to be converted to std::vector<kabool>."
);
// extract() as proxy for lib allocated DataBuffers
EXPECT_FALSE(has_extract_v<decltype(data_buf)>);
EXPECT_TRUE(has_extract_v<decltype(data_buf)>);
}
}

Expand Down
10 changes: 5 additions & 5 deletions tests/result_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ void test_recv_count_in_MPIResult() {
using namespace kamping;
using namespace kamping::internal;

LibAllocatedSingleElementBuffer<int, ParameterType::recv_count, BufferType::in_buffer> recv_count_wrapper{};
LibAllocatedSingleElementBuffer<int, ParameterType::recv_count, BufferType::out_buffer> recv_count_wrapper{};
recv_count_wrapper.underlying() = 42;
MPIResult mpi_result{
ResultCategoryNotUsed{},
Expand Down Expand Up @@ -415,7 +415,7 @@ KAMPING_MAKE_HAS_MEMBER(extract_send_recv_type)
TEST(MpiResultTest, removed_extract_functions) {
using namespace ::kamping;
using namespace ::kamping::internal;
constexpr BufferType btype = BufferType::in_buffer;
constexpr BufferType btype = BufferType::out_buffer;
{
// All of these should be extractable (used to make sure that the above macros work correctly)
StatusParam<StatusParamType::owning> status_sanity_check;
Expand Down Expand Up @@ -642,7 +642,7 @@ TEST(MpiResultTest, removed_extract_functions) {

TEST(MakeMpiResultTest, pass_random_order_buffer) {
{
constexpr BufferType btype = BufferType::in_buffer;
constexpr BufferType btype = BufferType::out_buffer;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::recv_counts, btype> recv_counts;
LibAllocatedContainerBasedBuffer<std::vector<char>, ParameterType::recv_buf, btype> recv_buf;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::recv_displs, btype> recv_displs;
Expand All @@ -663,7 +663,7 @@ TEST(MakeMpiResultTest, pass_random_order_buffer) {
ASSERT_EQ(result_status.tag(), 42);
}
{
constexpr BufferType btype = BufferType::in_buffer;
constexpr BufferType btype = BufferType::out_buffer;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::recv_counts, btype> recv_counts;
LibAllocatedContainerBasedBuffer<std::vector<double>, ParameterType::recv_buf, btype> recv_buf;

Expand All @@ -686,7 +686,7 @@ TEST(MakeMpiResultTest, pass_send_recv_buf) {
}

TEST(MakeMpiResultTest, check_content) {
constexpr BufferType btype = BufferType::in_buffer;
constexpr BufferType btype = BufferType::out_buffer;

std::vector<int> recv_buf_data(20);
std::iota(recv_buf_data.begin(), recv_buf_data.end(), 0);
Expand Down

0 comments on commit c89aa70

Please sign in to comment.