Skip to content

Commit

Permalink
#2281: Working Rabenseifner with updated StateHolder
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Sep 12, 2024
1 parent 2eb3b11 commit 9cba850
Show file tree
Hide file tree
Showing 10 changed files with 451 additions and 362 deletions.
12 changes: 9 additions & 3 deletions src/vt/collective/reduce/allreduce/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
//@HEADER
*/

#include "vt/configs/debug/debug_printconst.h"
#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_HELPERS_H
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_HELPERS_H

Expand Down Expand Up @@ -109,7 +110,7 @@ struct DataHelper {
dest = DataHan::toVec(std::forward<Args>(data)...);
}

static MsgPtr<RabenseifnerMsg<Scalar, DataT>> createMessage(
static auto createMessage(
const std::vector<Scalar>& payload, size_t begin, size_t count, size_t id,
int32_t step = 0) {
return vt::makeMessage<RabenseifnerMsg<Scalar, DataT>>(
Expand Down Expand Up @@ -157,7 +158,7 @@ struct DataHelper<Scalar, Kokkos::View<Scalar*, Kokkos::HostSpace>> {
dest = {std::forward<Args>(data)...};
}

static MsgPtr<RabenseifnerMsg<Scalar, DataT>> createMessage(
static auto createMessage(
const DataT& payload, size_t begin, size_t count, size_t id,
int32_t step = 0) {
return vt::makeMessage<RabenseifnerMsg<Scalar, DataT>>(
Expand Down Expand Up @@ -186,7 +187,7 @@ struct DataHelper<Scalar, Kokkos::View<Scalar*, Kokkos::HostSpace>> {
template <template <typename Arg> class Op, typename... Args>
static void reduce(
DataT& dest, Args&&... val) {
auto view_val = {std::forward<Args>(val)...};
auto view_val = DataT{std::forward<Args>(val)...};
Kokkos::parallel_for(
"Rabenseifner::reduce", view_val.extent(0), KOKKOS_LAMBDA(const int i) {
Op<Scalar>()(dest(i), view_val(i));
Expand All @@ -205,6 +206,7 @@ struct StateBase {
virtual ~StateBase() = default;
size_t size_ = {};

uint32_t local_col_wait_count_ = 0;
bool finished_adjustment_part_ = false;

int32_t mask_ = 1;
Expand Down Expand Up @@ -255,6 +257,8 @@ struct State : StateBase {
MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>> right_adjust_message_ = nullptr;
std::vector<MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>>> scatter_messages_ = {};
std::vector<MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>>> gather_messages_ = {};

vt::pipe::callback::cbunion::CallbackTyped<DataT> final_handler_ = {};
};

#if MAGISTRATE_KOKKOS_ENABLED
Expand All @@ -268,6 +272,8 @@ struct State<Scalar, Kokkos::View<Scalar*, Kokkos::HostSpace>> : StateBase {
MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>> right_adjust_message_ = nullptr;
std::vector<MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>>> scatter_messages_ = {};
std::vector<MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>>> gather_messages_ = {};

vt::pipe::callback::cbunion::CallbackTyped<DataT> final_handler_ = {};
};
#endif //MAGISTRATE_KOKKOS_ENABLED

Expand Down
76 changes: 43 additions & 33 deletions src/vt/collective/reduce/allreduce/rabenseifner.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,34 +76,27 @@ struct ObjgroupAllreduceT {};
* \tparam Op Reduction operation (e.g., sum, max, min).
* \tparam finalHandler Callback handler for the final result.
*/
template <
typename DataT, template <typename Arg> class Op, auto f
>
template <template <typename Arg> class Op>
struct Rabenseifner {
using Data = DataT;
using DataType = DataHandler<DataT>;
using Scalar = typename DataType::Scalar;
using ReduceOp = Op<Scalar>;
using DataHelperT = DataHelper<Scalar, DataT>;
using StateT = State<Scalar, DataT>;

using Trait = ObjFuncTraits<decltype(f)>;
using CallbackType =
typename Trait::template WrapType<pipe::PipeManagerTL::CallbackRetType>;

template <typename ...Args>
Rabenseifner(detail::StrongVrtProxy proxy, detail::StrongGroup group, size_t num_elems, Args&&... args);

template <typename ...Args>
Rabenseifner(detail::StrongGroup group, Args&&... args);

template <typename ...Args>
Rabenseifner(detail::StrongObjGroup objgroup, size_t id, Args&&... args);

void setFinalHandler(const CallbackType& fin) {
final_handler_ = fin;
}
template <typename ...Args>
// using Data = DataT;
// using DataType = DataHandler<DataT>;
// using Scalar = typename DataType::Scalar;
// using ReduceOp = Op<Scalar>;
// using DataHelperT = DataHelper<Scalar, DataT>;
// using StateT = State<Scalar, DataT>;

// using Trait = ObjFuncTraits<decltype(f)>;
// using CallbackType =
// typename Trait::template WrapType<pipe::PipeManagerTL::CallbackRetType>;

Rabenseifner(detail::StrongVrtProxy proxy, detail::StrongGroup group, size_t num_elems);
Rabenseifner(detail::StrongGroup group);
Rabenseifner(detail::StrongObjGroup objgroup);

template <typename DataT, typename CallbackType>
void setFinalHandler(const CallbackType& fin, size_t id);

template <typename DataT, typename... Args>
void localReduce(size_t id, Args&&... args);
/**
* \brief Initialize the allreduce algorithm.
Expand All @@ -112,22 +105,25 @@ struct Rabenseifner {
*
* \param args Additional arguments for initializing the data value.
*/
template <typename ...Args>
template <typename DataT, typename ...Args>
void initialize(size_t id, Args&&... args);

template <typename DataT>
void initializeState(size_t id);
size_t generateNewId() { return id_++; }

/**
* \brief Execute the final handler callback with the reduced result.
*/
template <typename DataT>
void executeFinalHan(size_t id);

/**
* \brief Perform the allreduce operation.
*
* This function starts the allreduce operation, adjusting for non-power-of-two process counts if necessary.
*/
template <typename DataT>
void allreduce(size_t id);

/**
Expand All @@ -136,6 +132,7 @@ struct Rabenseifner {
* This function performs additional steps to handle non-power-of-two process counts, ensuring that the
* main scatter-reduce and gather-allgather phases can proceed with a power-of-two number of processes.
*/
template <typename DataT>
void adjustForPowerOfTwo(size_t id);

/**
Expand All @@ -145,6 +142,7 @@ struct Rabenseifner {
*
* \param msg Message containing the data from the partner process.
*/
template <typename DataT, typename Scalar = typename DataHandler<DataT>::Scalar>
void adjustForPowerOfTwoRightHalf(RabenseifnerMsg<Scalar, DataT>* msg);

/**
Expand All @@ -154,6 +152,7 @@ struct Rabenseifner {
*
* \param msg Message containing the data from the partner process.
*/
template <typename DataT, typename Scalar = typename DataHandler<DataT>::Scalar>
void adjustForPowerOfTwoLeftHalf(RabenseifnerMsg<Scalar, DataT>* msg);

/**
Expand All @@ -163,41 +162,47 @@ struct Rabenseifner {
*
* \param msg Message containing the data from the partner process.
*/
template <typename DataT, typename Scalar = typename DataHandler<DataT>::Scalar>
void adjustForPowerOfTwoFinalPart(RabenseifnerMsg<Scalar, DataT>* msg);

/**
* \brief Check if all scatter messages have been received.
*
* \return True if all scatter messages have been received, false otherwise.
*/
template <typename DataT>
bool scatterAllMessagesReceived(size_t id);

/**
* \brief Check if the scatter phase is complete.
*
* \return True if the scatter phase is complete, false otherwise.
*/
template <typename DataT>
bool scatterIsDone(size_t id);

/**
* \brief Check if the scatter phase is ready to proceed.
*
* \return True if the scatter phase is ready to proceed, false otherwise.
*/
template <typename DataT>
bool scatterIsReady(size_t id);

/**
* \brief Try to reduce the received scatter messages.
*
* \param step The current step in the scatter phase.
*/
template <typename DataT>
void scatterTryReduce(size_t id, int32_t step);

/**
* \brief Perform the scatter-reduce iteration.
*
* This function sends data to the appropriate partner process and proceeds to the next step in the scatter phase.
*/
template <typename DataT>
void scatterReduceIter(size_t id);

/**
Expand All @@ -207,41 +212,47 @@ struct Rabenseifner {
*
* \param msg Message containing the data from the partner process.
*/
template <typename DataT, typename Scalar>
void scatterReduceIterHandler(RabenseifnerMsg<Scalar, DataT>* msg);

/**
* \brief Check if all gather messages have been received.
*
* \return True if all gather messages have been received, false otherwise.
*/
template <typename DataT>
bool gatherAllMessagesReceived(size_t id);

/**
* \brief Check if the gather phase is complete.
*
* \return True if the gather phase is complete, false otherwise.
*/
template <typename DataT>
bool gatherIsDone(size_t id);

/**
* \brief Check if the gather phase is ready to proceed.
*
* \return True if the gather phase is ready to proceed, false otherwise.
*/
template <typename DataT>
bool gatherIsReady(size_t id);

/**
* \brief Try to reduce the received gather messages.
*
* \param step The current step in the gather phase.
*/
template <typename DataT>
void gatherTryReduce(size_t id, int32_t step);

/**
* \brief Perform the gather iteration.
*
* This function sends data to the appropriate partner process and proceeds to the next step in the gather phase.
*/
template <typename DataT>
void gatherIter(size_t id);

/**
Expand All @@ -251,20 +262,23 @@ struct Rabenseifner {
*
* \param msg Message containing the data from the partner process.
*/
template <typename DataT, typename Scalar>
void gatherIterHandler(RabenseifnerMsg<Scalar, DataT>* msg);

/**
* \brief Perform the final part of the allreduce operation.
*
* This function completes the allreduce operation, handling any remaining steps and invoking the final handler.
*/
template <typename DataT>
void finalPart(size_t id);

/**
* \brief Send the result to excluded nodes.
*
* This function handles the final step for non-power-of-two process counts, sending the reduced result to excluded nodes.
*/
template <typename DataT>
void sendToExcludedNodes(size_t id);

/**
Expand All @@ -274,20 +288,17 @@ struct Rabenseifner {
*
* \param msg Message containing the final result.
*/
template <typename DataT, typename Scalar>
void sendToExcludedNodesHandler(RabenseifnerMsg<Scalar, DataT>* msg);

vt::objgroup::proxy::Proxy<Rabenseifner> proxy_ = {};

CallbackType final_handler_ = {};

VirtualProxyType collection_proxy_ = u64empty;
ObjGroupProxyType objgroup_proxy_ = u64empty;

uint32_t local_col_wait_count_ = {};
size_t local_num_elems_ = {};

size_t id_ = 0;
// std::unordered_map<size_t, StateT> states_ = {};

/// Sorted list of Nodes that take part in allreduce
std::vector<NodeType> nodes_ = {};
Expand All @@ -314,7 +325,6 @@ struct Rabenseifner {

static inline const std::string name_ = "Rabenseifner";
static inline constexpr ReducerType type_ = ReducerType::Rabenseifner;
static constexpr bool KokkosPaylod = ShouldUseView_v<Scalar, DataT>;
};

} // namespace vt::collective::reduce::allreduce
Expand Down
Loading

0 comments on commit 9cba850

Please sign in to comment.