Skip to content

Commit

Permalink
#2281: Code refactor and minor bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Sep 30, 2024
1 parent d0cda33 commit effd316
Show file tree
Hide file tree
Showing 15 changed files with 279 additions and 346 deletions.
49 changes: 24 additions & 25 deletions src/vt/collective/reduce/allreduce/allreduce_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,23 @@

namespace vt::collective::reduce::allreduce {

void AllreduceHolder::createAllreducers(detail::StrongGroup strong_group) {
addRabensifnerAllreducer(strong_group);
addRecursiveDoublingAllreducer(strong_group);
template <typename MapT>
inline static void removeImpl(MapT& map, uint64_t key){
auto it = map.find(key);

if (it != map.end()) {
auto& [rabenseifner, recursive_doubling] = map.at(key);

if(rabenseifner) {
delete rabenseifner;
}

if(recursive_doubling) {
delete recursive_doubling;
}

map.erase(key);
}
}

Rabenseifner* AllreduceHolder::addRabensifnerAllreducer(
Expand All @@ -60,7 +74,7 @@ Rabenseifner* AllreduceHolder::addRabensifnerAllreducer(
col_reducers_[coll_proxy].first = obj_proxy;

vt_debug_print(
verbose, allreduce, "Adding new Rabenseifner reducer for collection={:x}",
verbose, allreduce, "Adding new Rabenseifner reducer for collection={:x}\n",
coll_proxy
);

Expand All @@ -79,7 +93,7 @@ AllreduceHolder::addRecursiveDoublingAllreducer(

vt_debug_print(
verbose, allreduce,
"Adding new RecursiveDoubling reducer for collection={:x}", coll_proxy
"Adding new RecursiveDoubling reducer for collection={:x}\n", coll_proxy
);

return obj_proxy;
Expand All @@ -96,7 +110,7 @@ AllreduceHolder::addRabensifnerAllreducer(detail::StrongGroup strong_group) {

vt_debug_print(
verbose, allreduce,
"Adding new Rabenseifner reducer for group={:x}", group
"Adding new Rabenseifner reducer for group={:x}\n", group
);

return obj_proxy;
Expand All @@ -112,7 +126,7 @@ AllreduceHolder::addRecursiveDoublingAllreducer(

vt_debug_print(
verbose, allreduce,
"Adding new Rabenseifner reducer for group={:x}", group
"Adding new RecursiveDoubling reducer for group={:x}\n", group
);

group_reducers_[group].second = obj_proxy;
Expand Down Expand Up @@ -157,32 +171,17 @@ AllreduceHolder::addRecursiveDoublingAllreducer(

void AllreduceHolder::remove(detail::StrongVrtProxy strong_proxy) {
auto const key = strong_proxy.get();

auto it = col_reducers_.find(key);

if (it != col_reducers_.end()) {
col_reducers_.erase(key);
}
removeImpl(col_reducers_, key);
}

void AllreduceHolder::remove(detail::StrongGroup strong_group) {
auto const key = strong_group.get();

auto it = group_reducers_.find(key);

if (it != group_reducers_.end()) {
group_reducers_.erase(key);
}
removeImpl(group_reducers_, key);
}

void AllreduceHolder::remove(detail::StrongObjGroup strong_objgroup) {
auto const key = strong_objgroup.get();

auto it = objgroup_reducers_.find(key);

if (it != objgroup_reducers_.end()) {
objgroup_reducers_.erase(key);
}
removeImpl(objgroup_reducers_, key);
}

} // namespace vt::collective::reduce::allreduce
147 changes: 22 additions & 125 deletions src/vt/collective/reduce/allreduce/allreduce_holder.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
#include "vt/configs/types/types_sentinels.h"
#include "vt/objgroup/proxy/proxy_objgroup.h"

#include <type_traits>
#include <unordered_map>

namespace vt::collective::reduce::allreduce {
Expand All @@ -61,143 +60,43 @@ struct AllreduceHolder {
using RabenseifnerProxy = ObjGroupProxyType;
using RecursiveDoublingProxy = ObjGroupProxyType;

static void createAllreducers(detail::StrongGroup strong_group);

template <typename ReducerT>
static auto getAllreducer(
detail::StrongVrtProxy strong_proxy) {
auto const coll_proxy = strong_proxy.get();

auto it = col_reducers_.find(coll_proxy);
if(it == col_reducers_.end()){
col_reducers_[coll_proxy] = {nullptr, nullptr};
}

if constexpr(std::is_same_v<ReducerT, RabenseifnerT>){
return col_reducers_.at(coll_proxy).first;
}else {
return col_reducers_.at(coll_proxy).second;
}
}
static decltype(auto) getAllreducer(detail::StrongVrtProxy strong_proxy);

template <typename ReducerT>
static auto getAllreducer(
detail::StrongGroup strong_group) {
auto const group = strong_group.get();

auto it = group_reducers_.find(group);
if(it == group_reducers_.end()){
group_reducers_[group] = {nullptr, nullptr};
}

if constexpr(std::is_same_v<ReducerT, RabenseifnerT>){
return group_reducers_.at(group).first;
}else {
return group_reducers_.at(group).second;
}
}
static decltype(auto) getAllreducer(detail::StrongGroup strong_group);

template <typename ReducerT>
static auto getAllreducer(
detail::StrongObjGroup strong_objgroup) {
auto const objgroup = strong_objgroup.get();

auto it = objgroup_reducers_.find(objgroup);
if(it == objgroup_reducers_.end()){
objgroup_reducers_[objgroup] = {nullptr, nullptr};
}

if constexpr(std::is_same_v<ReducerT, RabenseifnerT>){
return objgroup_reducers_.at(objgroup).first;
}else {
return objgroup_reducers_.at(objgroup).second;
}
}
static decltype(auto) getAllreducer(detail::StrongObjGroup strong_objgroup);

template <typename ReducerT>
static auto getOrCreateAllreducer(
static decltype(auto) getOrCreateAllreducer(
detail::StrongVrtProxy strong_proxy, detail::StrongGroup strong_group,
size_t num_elems) {
auto const coll_proxy = strong_proxy.get();

if (col_reducers_.find(coll_proxy) == col_reducers_.end()) {
col_reducers_[coll_proxy] = {nullptr, nullptr};
}

if constexpr (std::is_same_v<ReducerT, RabenseifnerT>) {
auto reducer = col_reducers_.at(coll_proxy).first;
if (reducer == nullptr) {
return addRabensifnerAllreducer(strong_proxy, strong_group, num_elems);
} else {
return reducer;
}
} else {
auto reducer = col_reducers_.at(coll_proxy).second;
if (reducer == nullptr) {
return addRecursiveDoublingAllreducer(
strong_proxy, strong_group, num_elems);
} else {
return reducer;
}
}
}
size_t num_elems);

template <typename ReducerT>
static auto getOrCreateAllreducer(detail::StrongGroup strong_group) {
auto const group = strong_group.get();

if (auto it = group_reducers_.find(group); it == group_reducers_.end()) {
group_reducers_[group] = {nullptr, nullptr};
}

if constexpr (std::is_same_v<ReducerT, RabenseifnerT>) {
auto reducer = group_reducers_.at(group).first;
if (reducer == nullptr) {
return addRabensifnerAllreducer(strong_group);
} else {
return reducer;
}
} else {
auto reducer = group_reducers_.at(group).second;
if (reducer == nullptr) {
return addRecursiveDoublingAllreducer(strong_group);
} else {
return reducer;
}
}
}
static decltype(auto) getOrCreateAllreducer(detail::StrongGroup strong_group);

template <typename ReducerT>
static auto getOrCreateAllreducer(detail::StrongObjGroup strong_objgroup) {
auto const objgroup = strong_objgroup.get();

if (auto it = objgroup_reducers_.find(objgroup); it == objgroup_reducers_.end()) {
objgroup_reducers_[objgroup] = {nullptr, nullptr};
}

if constexpr (std::is_same_v<ReducerT, RabenseifnerT>) {
auto reducer = objgroup_reducers_.at(objgroup).first;
if (reducer == nullptr) {
return addRabensifnerAllreducer(strong_objgroup);
} else {
return reducer;
}
} else {
auto reducer = objgroup_reducers_.at(objgroup).second;
if (reducer == nullptr) {
return addRecursiveDoublingAllreducer(strong_objgroup);
} else {
return reducer;
}
}
}
static decltype(auto)
getOrCreateAllreducer(detail::StrongObjGroup strong_objgroup);

static void remove(detail::StrongVrtProxy strong_proxy);
static void remove(detail::StrongGroup strong_group);
static void remove(detail::StrongObjGroup strong_group);

private:
template <typename ReducerT, typename MapT>
static decltype(auto) getAllreducerImpl(MapT& map, uint64_t id);

template <typename ReducerT, typename MapT, typename... Args>
static decltype(auto) getOrCreateAllreducerImpl(MapT& map, uint64_t id, Args&&... args);

static Rabenseifner* addRabensifnerAllreducer(
detail::StrongVrtProxy strong_proxy, detail::StrongGroup strong_group,
size_t num_elems);

static RecursiveDoubling*
addRecursiveDoublingAllreducer(
static RecursiveDoubling* addRecursiveDoublingAllreducer(
detail::StrongVrtProxy strong_proxy, detail::StrongGroup strong_group,
size_t num_elems);

Expand All @@ -211,10 +110,6 @@ struct AllreduceHolder {
static RecursiveDoubling*
addRecursiveDoublingAllreducer(detail::StrongObjGroup strong_group);

static void remove(detail::StrongVrtProxy strong_proxy);
static void remove(detail::StrongGroup strong_group);
static void remove(detail::StrongObjGroup strong_group);

static inline std::unordered_map<
VirtualProxyType, std::pair<Rabenseifner*, RecursiveDoubling*>>
col_reducers_ = {};
Expand Down Expand Up @@ -242,4 +137,6 @@ static inline auto* getAllreducer(ComponentInfo type) {

} // namespace vt::collective::reduce::allreduce

#include "vt/collective/reduce/allreduce/allreduce_holder.impl.h"

#endif /*INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_ALLREDUCE_HOLDER_H*/
Loading

0 comments on commit effd316

Please sign in to comment.