Skip to content

Commit

Permalink
Add visit-like pattern for RecordComponent (#1544)
Browse files Browse the repository at this point in the history
* Fixes for the datatype macros

* Introduce std::variant for dataset, attribute and non-vector types

Also use them to erase some repetition

* Add switchDatasetType

Also refactor the switchType functions to use the datatype macros

* Main implementation: Add variant-based loadChunk API

* Testing, examples
  • Loading branch information
franzpoeschel authored Dec 22, 2023
1 parent 2e89f87 commit b131cd2
Show file tree
Hide file tree
Showing 13 changed files with 321 additions and 219 deletions.
17 changes: 10 additions & 7 deletions examples/10_streaming_read.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ using namespace openPMD;
int main()
{
#if openPMD_HAVE_ADIOS2
using position_t = double;
auto backends = openPMD::getFileExtensions();
if (std::find(backends.begin(), backends.end(), "sst") == backends.end())
{
Expand Down Expand Up @@ -40,15 +39,15 @@ int main()
std::cout << "Current iteration: " << iteration.iterationIndex
<< std::endl;
Record electronPositions = iteration.particles["e"]["position"];
std::array<std::shared_ptr<position_t>, 3> loadedChunks;
std::array<RecordComponent::shared_ptr_dataset_types, 3> loadedChunks;
std::array<Extent, 3> extents;
std::array<std::string, 3> const dimensions{{"x", "y", "z"}};

for (size_t i = 0; i < 3; ++i)
{
std::string const &dim = dimensions[i];
RecordComponent rc = electronPositions[dim];
loadedChunks[i] = rc.loadChunk<position_t>(
loadedChunks[i] = rc.loadChunkVariant(
Offset(rc.getDimensionality(), 0), rc.getExtent());
extents[i] = rc.getExtent();
}
Expand All @@ -64,10 +63,14 @@ int main()
Extent const &extent = extents[i];
std::cout << "\ndim: " << dim << "\n" << std::endl;
auto chunk = loadedChunks[i];
for (size_t j = 0; j < extent[0]; ++j)
{
std::cout << chunk.get()[j] << ", ";
}
std::visit(
[&extent](auto &shared_ptr) {
for (size_t j = 0; j < extent[0]; ++j)
{
std::cout << shared_ptr.get()[j] << ", ";
}
},
chunk);
std::cout << "\n----------\n" << std::endl;
}
}
Expand Down
23 changes: 14 additions & 9 deletions examples/2_read_serial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,20 +73,25 @@ int main()

Offset chunk_offset = {1, 1, 1};
Extent chunk_extent = {2, 2, 1};
auto chunk_data = E_x.loadChunk<double>(chunk_offset, chunk_extent);
// Loading without explicit datatype here
auto chunk_data = E_x.loadChunkVariant(chunk_offset, chunk_extent);
cout << "Queued the loading of a single chunk from disk, "
"ready to execute\n";
series.flush();
cout << "Chunk has been read from disk\n"
<< "Read chunk contains:\n";
for (size_t row = 0; row < chunk_extent[0]; ++row)
{
for (size_t col = 0; col < chunk_extent[1]; ++col)
cout << "\t" << '(' << row + chunk_offset[0] << '|'
<< col + chunk_offset[1] << '|' << 1 << ")\t"
<< chunk_data.get()[row * chunk_extent[1] + col];
cout << '\n';
}
std::visit(
[&chunk_offset, &chunk_extent](auto &shared_ptr) {
for (size_t row = 0; row < chunk_extent[0]; ++row)
{
for (size_t col = 0; col < chunk_extent[1]; ++col)
cout << "\t" << '(' << row + chunk_offset[0] << '|'
<< col + chunk_offset[1] << '|' << 1 << ")\t"
<< shared_ptr.get()[row * chunk_extent[1] + col];
cout << '\n';
}
},
chunk_data);

auto all_data = E_x.loadChunk<double>();

Expand Down
18 changes: 15 additions & 3 deletions examples/4_read_parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ int main(int argc, char *argv[])
Offset chunk_offset = {static_cast<long unsigned int>(mpi_rank) + 1, 1, 1};
Extent chunk_extent = {2, 2, 1};

auto chunk_data = E_x.loadChunk<double>(chunk_offset, chunk_extent);
// If you know the datatype, use `loadChunk<double>(...)` instead.
auto chunk_data = E_x.loadChunkVariant(chunk_offset, chunk_extent);

if (0 == mpi_rank)
cout << "Queued the loading of a single chunk per MPI rank from "
Expand All @@ -72,9 +73,20 @@ int main(int argc, char *argv[])
for (size_t row = 0; row < chunk_extent[0]; ++row)
{
for (size_t col = 0; col < chunk_extent[1]; ++col)
{
cout << "\t" << '(' << row + chunk_offset[0] << '|'
<< col + chunk_offset[1] << '|' << 1 << ")\t"
<< chunk_data.get()[row * chunk_extent[1] + col];
<< col + chunk_offset[1] << '|' << 1 << ")\t";
/*
* For hot loops, the std::visit(...) call should be moved
* further up.
*/
std::visit(
[row, col, &chunk_extent](auto &shared_ptr) {
cout << shared_ptr
.get()[row * chunk_extent[1] + col];
},
chunk_data);
}
cout << std::endl;
}
}
Expand Down
52 changes: 52 additions & 0 deletions include/openPMD/Datatype.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
#include "openPMD/auxiliary/TypeTraits.hpp"
#include "openPMD/auxiliary/UniquePtr.hpp"

// comment to prevent clang-format from moving this #include up
// datatype macros may be included and un-included in other headers
#include "openPMD/DatatypeMacros.hpp"

#include <array>
#include <climits>
#include <complex>
Expand All @@ -35,6 +39,7 @@
#include <tuple>
#include <type_traits>
#include <utility> // std::declval
#include <variant>
#include <vector>

namespace openPMD
Expand Down Expand Up @@ -94,6 +99,33 @@ enum class Datatype : int
*/
std::vector<Datatype> openPMD_Datatypes();

namespace detail
{
struct bottom
{};

// std::variant, but ignore first template parameter
// little trick to avoid trailing commas in the macro expansions below
template <typename Arg, typename... Args>
using variant_tail_t = std::variant<Args...>;
} // namespace detail

#define OPENPMD_ENUMERATE_TYPES(type) , type

using dataset_types =
detail::variant_tail_t<detail::bottom OPENPMD_FOREACH_DATASET_DATATYPE(
OPENPMD_ENUMERATE_TYPES)>;

using non_vector_types =
detail::variant_tail_t<detail::bottom OPENPMD_FOREACH_NONVECTOR_DATATYPE(
OPENPMD_ENUMERATE_TYPES)>;

using attribute_types =
detail::variant_tail_t<detail::bottom OPENPMD_FOREACH_DATATYPE(
OPENPMD_ENUMERATE_TYPES)>;

#undef OPENPMD_ENUMERATE_TYPES

/** @brief Fundamental equivalence check for two given types T and U.
*
* This checks whether the fundamental datatype (i.e. that of a single value
Expand Down Expand Up @@ -782,6 +814,25 @@ template <typename Action, typename... Args>
constexpr auto switchNonVectorType(Datatype dt, Args &&...args)
-> decltype(Action::template call<char>(std::forward<Args>(args)...));

/**
* Generalizes switching over an openPMD datatype.
*
* Will call the function template found at Action::call< T >(), instantiating T
* with the C++ internal datatype corresponding to the openPMD datatype.
* Specializes only on those types that can occur in a dataset.
*
* @tparam ReturnType The function template's return type.
* @tparam Action The struct containing the function template.
* @tparam Args The function template's argument types.
* @param dt The openPMD datatype.
* @param args The function template's arguments.
* @return Passes on the result of invoking the function template with the given
* arguments and with the template parameter specified by dt.
*/
template <typename Action, typename... Args>
constexpr auto switchDatasetType(Datatype dt, Args &&...args)
-> decltype(Action::template call<char>(std::forward<Args>(args)...));

} // namespace openPMD

#if !defined(_MSC_VER)
Expand Down Expand Up @@ -811,3 +862,4 @@ inline bool operator!=(openPMD::Datatype d, openPMD::Datatype e)
#endif

#include "openPMD/Datatype.tpp"
#include "openPMD/UndefDatatypeMacros.hpp"
Loading

0 comments on commit b131cd2

Please sign in to comment.