Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extract() <=> data buffer is owning #610

Merged
merged 5 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion include/kamping/data_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@ class DataBuffer : private ParameterObjectBase {
/// @brief Indicates whether the buffer is allocated by KaMPIng.
static constexpr bool is_lib_allocated = allocation == BufferAllocation::lib_allocated;

static constexpr bool is_owning =
ownership == BufferOwnership::owning; ///< Indicates whether the buffer owns its underlying storage.

static constexpr bool is_modifiable =
modifiability == BufferModifiability::modifiable; ///< Indicates whether the underlying storage is modifiable.
static constexpr bool is_single_element =
Expand Down Expand Up @@ -502,7 +505,7 @@ class DataBuffer : private ParameterObjectBase {
/// state.
///
/// @return Moves the underlying container out of the DataBuffer.
template <bool enabled = allocation == BufferAllocation::lib_allocated, std::enable_if_t<enabled, bool> = true>
template <bool enabled = is_owning, std::enable_if_t<enabled, bool> = true>
MemberTypeWithConst extract() {
static_assert(
ownership == BufferOwnership::owning,
Expand Down
53 changes: 25 additions & 28 deletions include/kamping/result.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// This file is part of KaMPIng.
//
// Copyright 2021-2023 The KaMPIng Authors
// Copyright 2021-2024 The KaMPIng Authors
//
// KaMPIng is free software : you can redistribute it and/or modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
Expand Down Expand Up @@ -36,6 +36,15 @@ inline constexpr bool has_extract_v = has_member_extract_v<T>;

/// @brief Use this type if one of the template parameters of MPIResult is not used for a specific wrapped \c MPI call.
struct ResultCategoryNotUsed {};

/// @brief Helper for implementing the extract_* functions in \ref MPIResult. Is \c true if the passed buffer type owns
/// its underlying storage and is an output buffer.
template <typename Buffer>
inline constexpr bool is_extractable = Buffer::is_owning&& Buffer::is_out_buffer;

/// @brief Specialization of helper for implementing the extract_* functions in \ref MPIResult. Is always \c false;
template <>
inline constexpr bool is_extractable<internal::ResultCategoryNotUsed> = false;
} // namespace internal

/// @brief MPIResult contains the result of a \c MPI call wrapped by KaMPIng.
Expand Down Expand Up @@ -129,7 +138,7 @@ class MPIResult {
///
/// This function is only available if the underlying status is owned by the
/// MPIResult object.
/// @tparam StatusType_ Template parameter helper only needed to remove this
/// @tparam StatusObject_ Template parameter helper only needed to remove this
/// function if StatusType does not possess a member function \c extract().
/// @return Returns the underlying status object.
template <
Expand All @@ -144,9 +153,9 @@ class MPIResult {
/// This function is only available if the underlying memory is owned by the
/// MPIResult object.
/// @tparam RecvBuf_ Template parameter helper only needed to remove this
/// function if RecvBuf does not possess a member function \c extract().
/// function if RecvBuf should not be extracted (it does not own its underlying memory or is not an out-buffer).
/// @return Returns the underlying storage containing the received elements.
template <typename RecvBuf_ = RecvBuf, std::enable_if_t<kamping::internal::has_extract_v<RecvBuf_>, bool> = true>
template <typename RecvBuf_ = RecvBuf, std::enable_if_t<internal::is_extractable<RecvBuf_>, bool> = true>
decltype(auto) extract_recv_buffer() {
return _recv_buffer.extract();
}
Expand All @@ -157,9 +166,7 @@ class MPIResult {
/// @tparam RecvCounts_ Template parameter helper only needed to remove this function if RecvCounts does not possess
/// a member function \c extract().
/// @return Returns the underlying storage containing the receive counts.
template <
typename RecvCounts_ = RecvCounts,
std::enable_if_t<kamping::internal::has_extract_v<RecvCounts_>, bool> = true>
template <typename RecvCounts_ = RecvCounts, std::enable_if_t<internal::is_extractable<RecvCounts_>, bool> = true>
decltype(auto) extract_recv_counts() {
return _recv_counts.extract();
}
Expand All @@ -170,9 +177,7 @@ class MPIResult {
/// @tparam RecvCount_ Template parameter helper only needed to remove this function if RecvCount does not
/// possess a member function \c extract().
/// @return Returns the underlying storage containing the recv count.
template <
typename RecvCount_ = RecvCount,
std::enable_if_t<kamping::internal::has_extract_v<RecvCount_>, bool> = true>
template <typename RecvCount_ = RecvCount, std::enable_if_t<internal::is_extractable<RecvCount_>, bool> = true>
decltype(auto) extract_recv_count() {
return _recv_count.extract();
}
Expand All @@ -183,9 +188,7 @@ class MPIResult {
/// @tparam RecvDispls_ Template parameter helper only needed to remove this function if RecvDispls does not possess
/// a member function \c extract().
/// @return Returns the underlying storage containing the receive displacements.
template <
typename RecvDispls_ = RecvDispls,
std::enable_if_t<kamping::internal::has_extract_v<RecvDispls_>, bool> = true>
template <typename RecvDispls_ = RecvDispls, std::enable_if_t<internal::is_extractable<RecvDispls_>, bool> = true>
decltype(auto) extract_recv_displs() {
return _recv_displs.extract();
}
Expand All @@ -196,9 +199,7 @@ class MPIResult {
/// @tparam SendCounts_ Template parameter helper only needed to remove this function if SendCounts does not possess
/// a member function \c extract().
/// @return Returns the underlying storage containing the send counts.
template <
typename SendCounts_ = SendCounts,
std::enable_if_t<kamping::internal::has_extract_v<SendCounts_>, bool> = true>
template <typename SendCounts_ = SendCounts, std::enable_if_t<internal::is_extractable<SendCounts_>, bool> = true>
decltype(auto) extract_send_counts() {
return _send_counts.extract();
}
Expand All @@ -209,9 +210,7 @@ class MPIResult {
/// @tparam SendCount_ Template parameter helper only needed to remove this function if SendCount does not
/// possess a member function \c extract().
/// @return Returns the underlying storage containing the send count.
template <
typename SendCount_ = SendCount,
std::enable_if_t<kamping::internal::has_extract_v<SendCount_>, bool> = true>
template <typename SendCount_ = SendCount, std::enable_if_t<internal::is_extractable<SendCount_>, bool> = true>
decltype(auto) extract_send_count() {
return _send_count.extract();
}
Expand All @@ -222,9 +221,7 @@ class MPIResult {
/// @tparam SendDispls_ Template parameter helper only needed to remove this function if SendDispls does not possess
/// a member function \c extract().
/// @return Returns the underlying storage containing the send displacements.
template <
typename SendDispls_ = SendDispls,
std::enable_if_t<kamping::internal::has_extract_v<SendDispls_>, bool> = true>
template <typename SendDispls_ = SendDispls, std::enable_if_t<internal::is_extractable<SendDispls_>, bool> = true>
decltype(auto) extract_send_displs() {
return _send_displs.extract();
}
Expand All @@ -236,8 +233,8 @@ class MPIResult {
/// possess a member function \c extract().
/// @return Returns the underlying storage containing the send_recv_count.
template <
typename SendRecvCount_ = SendRecvCount,
std::enable_if_t<kamping::internal::has_extract_v<SendRecvCount_>, bool> = true>
typename SendRecvCount_ = SendRecvCount,
std::enable_if_t<internal::is_extractable<SendRecvCount_>, bool> = true>
decltype(auto) extract_send_recv_count() {
return _send_recv_count.extract();
}
Expand All @@ -248,7 +245,7 @@ class MPIResult {
/// @tparam SendType_ Template parameter helper only needed to remove this function if SendType does not
/// possess a member function \c extract().
/// @return Returns the underlying storage containing the send_type.
template <typename SendType_ = SendType, std::enable_if_t<kamping::internal::has_extract_v<SendType_>, bool> = true>
template <typename SendType_ = SendType, std::enable_if_t<internal::is_extractable<SendType_>, bool> = true>
decltype(auto) extract_send_type() {
return _send_type.extract();
}
Expand All @@ -259,7 +256,7 @@ class MPIResult {
/// @tparam RecvType_ Template parameter helper only needed to remove this function if RecvType does not
/// possess a member function \c extract().
/// @return Returns the underlying storage containing the send_type.
template <typename RecvType_ = RecvType, std::enable_if_t<kamping::internal::has_extract_v<RecvType_>, bool> = true>
template <typename RecvType_ = RecvType, std::enable_if_t<internal::is_extractable<RecvType_>, bool> = true>
decltype(auto) extract_recv_type() {
return _recv_type.extract();
}
Expand All @@ -270,8 +267,8 @@ class MPIResult {
/// possess a member function \c extract().
/// @return Returns the underlying storage containing the send_type.
template <
typename SendRecvType_ = SendRecvType,
std::enable_if_t<kamping::internal::has_extract_v<SendRecvType_>, bool> = true>
typename SendRecvType_ = SendRecvType,
std::enable_if_t<internal::is_extractable<SendRecvType_>, bool> = true>
decltype(auto) extract_send_recv_type() {
return _send_recv_type.extract();
}
Expand Down
29 changes: 20 additions & 9 deletions tests/data_buffer_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -591,15 +591,26 @@ TEST(DataBufferTest, has_extract) {
"Library allocated DataBuffers must have an extract() member function"
);
static_assert(
!has_extract_v<DataBuffer<
has_extract_v<DataBuffer<
int,
ParameterType::send_buf,
BufferModifiability::modifiable,
BufferOwnership::owning,
BufferType::in_buffer,
BufferResizePolicy::no_resize,
BufferAllocation::user_allocated>>,
"User allocated DataBuffers must not have an extract() member function"
"User allocated owning DataBuffers must have an extract() member function"
);
static_assert(
!has_extract_v<DataBuffer<
int,
ParameterType::send_buf,
BufferModifiability::modifiable,
BufferOwnership::referencing,
BufferType::in_buffer,
BufferResizePolicy::no_resize,
BufferAllocation::user_allocated>>,
"User allocated referencing DataBuffers must not have an extract() member function"
);
}

Expand Down Expand Up @@ -711,14 +722,14 @@ TEST(LibAllocatedContainerBasedBufferTest, prevent_usage_after_extraction) {
}

TEST(LibAllocatedContainerBasedBufferTest, prevent_usage_after_extraction_via_mpi_result) {
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::recv_buf, BufferType::in_buffer> recv_buffer;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::recv_counts, BufferType::in_buffer> recv_counts;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::recv_displs, BufferType::in_buffer> recv_displs;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::send_counts, BufferType::in_buffer> send_counts;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::send_displs, BufferType::in_buffer> send_displs;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::recv_buf, BufferType::out_buffer> recv_buffer;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::recv_counts, BufferType::out_buffer> recv_counts;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::recv_displs, BufferType::out_buffer> recv_displs;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::send_counts, BufferType::out_buffer> send_counts;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::send_displs, BufferType::out_buffer> send_displs;
// we use out_buffer here because extracting is only done from out buffers
LibAllocatedContainerBasedBuffer<int, ParameterType::recv_count, BufferType::in_buffer> recv_count;
LibAllocatedContainerBasedBuffer<int, ParameterType::send_count, BufferType::in_buffer> send_count;
LibAllocatedContainerBasedBuffer<int, ParameterType::recv_count, BufferType::out_buffer> recv_count;
LibAllocatedContainerBasedBuffer<int, ParameterType::send_count, BufferType::out_buffer> send_count;
LibAllocatedContainerBasedBuffer<int, ParameterType::send_recv_count, BufferType::out_buffer> send_recv_count;
LibAllocatedContainerBasedBuffer<MPI_Datatype, ParameterType::send_type, BufferType::out_buffer> send_type;
LibAllocatedContainerBasedBuffer<MPI_Datatype, ParameterType::recv_type, BufferType::out_buffer> recv_type;
Expand Down
12 changes: 6 additions & 6 deletions tests/named_parameters_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1495,7 +1495,7 @@ TEST(ParameterFactoriesTest, make_data_buffer) {
"Owning buffers must hold their data directly."
);
// extract() as proxy for lib allocated DataBuffers
EXPECT_FALSE(has_extract_v<decltype(data_buf)>);
EXPECT_TRUE(has_extract_v<decltype(data_buf)>);
}

{
Expand Down Expand Up @@ -1552,7 +1552,7 @@ TEST(ParameterFactoriesTest, make_data_buffer) {
"Owning buffers must hold their data directly."
);
// extract() as proxy for lib allocated DataBuffers
EXPECT_FALSE(has_extract_v<decltype(data_buf)>);
EXPECT_TRUE(has_extract_v<decltype(data_buf)>);
}
{
// Constant, container, owning, user_allocated with initializer_list
Expand All @@ -1570,7 +1570,7 @@ TEST(ParameterFactoriesTest, make_data_buffer) {
"Owning buffers must hold their data directly."
);
// extract() as proxy for lib allocated DataBuffers
EXPECT_FALSE(has_extract_v<decltype(data_buf)>);
EXPECT_TRUE(has_extract_v<decltype(data_buf)>);
}
}

Expand Down Expand Up @@ -1656,7 +1656,7 @@ TEST(ParameterFactoriesTest, make_data_buffer_boolean_value) {
"Owning buffers must hold their data directly."
);
// extract() as proxy for lib allocated DataBuffers
EXPECT_FALSE(has_extract_v<decltype(data_buf)>);
EXPECT_TRUE(has_extract_v<decltype(data_buf)>);
}

{
Expand Down Expand Up @@ -1714,7 +1714,7 @@ TEST(ParameterFactoriesTest, make_data_buffer_boolean_value) {
"Initializer lists of type bool have to be converted to std::vector<kabool>."
);
// extract() as proxy for lib allocated DataBuffers
EXPECT_FALSE(has_extract_v<decltype(data_buf)>);
EXPECT_TRUE(has_extract_v<decltype(data_buf)>);
}
{
// Constant, container, owning, user_allocated with initializer_list
Expand All @@ -1733,7 +1733,7 @@ TEST(ParameterFactoriesTest, make_data_buffer_boolean_value) {
"Initializer lists of type bool have to be converted to std::vector<kabool>."
);
// extract() as proxy for lib allocated DataBuffers
EXPECT_FALSE(has_extract_v<decltype(data_buf)>);
EXPECT_TRUE(has_extract_v<decltype(data_buf)>);
}
}

Expand Down
10 changes: 5 additions & 5 deletions tests/result_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ void test_recv_count_in_MPIResult() {
using namespace kamping;
using namespace kamping::internal;

LibAllocatedSingleElementBuffer<int, ParameterType::recv_count, BufferType::in_buffer> recv_count_wrapper{};
LibAllocatedSingleElementBuffer<int, ParameterType::recv_count, BufferType::out_buffer> recv_count_wrapper{};
recv_count_wrapper.underlying() = 42;
MPIResult mpi_result{
ResultCategoryNotUsed{},
Expand Down Expand Up @@ -415,7 +415,7 @@ KAMPING_MAKE_HAS_MEMBER(extract_send_recv_type)
TEST(MpiResultTest, removed_extract_functions) {
using namespace ::kamping;
using namespace ::kamping::internal;
constexpr BufferType btype = BufferType::in_buffer;
constexpr BufferType btype = BufferType::out_buffer;
{
// All of these should be extractable (used to make sure that the above macros work correctly)
StatusParam<StatusParamType::owning> status_sanity_check;
Expand Down Expand Up @@ -642,7 +642,7 @@ TEST(MpiResultTest, removed_extract_functions) {

TEST(MakeMpiResultTest, pass_random_order_buffer) {
{
constexpr BufferType btype = BufferType::in_buffer;
constexpr BufferType btype = BufferType::out_buffer;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::recv_counts, btype> recv_counts;
LibAllocatedContainerBasedBuffer<std::vector<char>, ParameterType::recv_buf, btype> recv_buf;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::recv_displs, btype> recv_displs;
Expand All @@ -663,7 +663,7 @@ TEST(MakeMpiResultTest, pass_random_order_buffer) {
ASSERT_EQ(result_status.tag(), 42);
}
{
constexpr BufferType btype = BufferType::in_buffer;
constexpr BufferType btype = BufferType::out_buffer;
LibAllocatedContainerBasedBuffer<std::vector<int>, ParameterType::recv_counts, btype> recv_counts;
LibAllocatedContainerBasedBuffer<std::vector<double>, ParameterType::recv_buf, btype> recv_buf;

Expand All @@ -686,7 +686,7 @@ TEST(MakeMpiResultTest, pass_send_recv_buf) {
}

TEST(MakeMpiResultTest, check_content) {
constexpr BufferType btype = BufferType::in_buffer;
constexpr BufferType btype = BufferType::out_buffer;

std::vector<int> recv_buf_data(20);
std::iota(recv_buf_data.begin(), recv_buf_data.end(), 0);
Expand Down