From ad13693fe8d2ea10df68b871010a31f72fbb47bd Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Tue, 5 Sep 2023 14:41:08 +0000 Subject: [PATCH 1/5] IB gather WIP --- include/mscclpp/core.hpp | 3 +- src/communicator.cc | 5 +- src/connection.cc | 4 +- src/ib.cc | 121 +++++++++++++++++++++++++++---------- src/include/connection.hpp | 2 +- src/include/ib.hpp | 17 ++++-- test/mp_unit/ib_tests.cu | 2 +- 7 files changed, 109 insertions(+), 45 deletions(-) diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index 1e9e6abd8..67ed523a3 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -542,7 +542,8 @@ class Communicator { /// @param ibMaxWrPerSend The maximum number of work requests per send for IB. Unused if transport is not IB. /// @return std::shared_ptr A shared pointer to the connection. std::shared_ptr connectOnSetup(int remoteRank, int tag, Transport transport, int ibMaxCqSize = 1024, - int ibMaxCqPollNum = 1, int ibMaxSendWr = 8192, int ibMaxWrPerSend = 64); + int ibMaxCqPollNum = 1, int ibMaxSendWr = 8192, int ibMaxWrPerSend = 64, + int ibMaxNumSgesPerWr = 16); /// Add a custom Setuppable object to a list of objects to be setup later, when @ref setup() is called. /// diff --git a/src/communicator.cc b/src/communicator.cc index cc0323556..d5b49fa16 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -98,7 +98,8 @@ MSCCLPP_API_CPP std::shared_ptr Communicator::connectOnSetup(int rem int ibMaxCqSize /*=1024*/, int ibMaxCqPollNum /*=1*/, int ibMaxSendWr /*=8192*/, - int ibMaxWrPerSend /*=64*/) { + int ibMaxWrPerSend /*=64*/, + int ibMaxNumSgesPerWr /*=16*/) { std::shared_ptr conn; if (transport == Transport::CudaIpc) { // sanity check: make sure the IPC connection is being made within a node @@ -116,7 +117,7 @@ MSCCLPP_API_CPP std::shared_ptr Communicator::connectOnSetup(int rem pimpl->rankToHash_[remoteRank]); } else if (AllIBTransports.has(transport)) { auto ibConn = std::make_shared(remoteRank, tag, transport, ibMaxCqSize, ibMaxCqPollNum, ibMaxSendWr, - ibMaxWrPerSend, *pimpl); + ibMaxWrPerSend, ibMaxNumSgesPerWr, *pimpl); conn = ibConn; INFO(MSCCLPP_NET, "IB connection between rank %d(%lx) via %s and remoteRank %d(%lx) created", pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()], diff --git a/src/connection.cc b/src/connection.cc index 112e11783..d8a055b52 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ -83,13 +83,13 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) { // IBConnection IBConnection::IBConnection(int remoteRank, int tag, Transport transport, int maxCqSize, int maxCqPollNum, int maxSendWr, - int maxWrPerSend, Communicator::Impl& commImpl) + int maxWrPerSend, int maxNumSgesPerWr, Communicator::Impl& commImpl) : ConnectionBase(remoteRank, tag), transport_(transport), remoteTransport_(Transport::Unknown), numSignaledSends(0), dummyAtomicSource_(std::make_unique(0)) { - qp = commImpl.getIbContext(transport)->createQp(maxCqSize, maxCqPollNum, maxSendWr, 0, maxWrPerSend); + qp = commImpl.getIbContext(transport)->createQp(maxCqSize, maxCqPollNum, maxSendWr, 0, maxWrPerSend, maxNumSgesPerWr); dummyAtomicSourceMem_ = RegisteredMemory(std::make_shared( dummyAtomicSource_.get(), sizeof(uint64_t), commImpl.bootstrap_->getRank(), transport, commImpl)); validateTransport(dummyAtomicSourceMem_, transport); diff --git a/src/ib.cc b/src/ib.cc index 7a93a650e..0dcb33fdc 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -16,6 +16,16 @@ #include "api.h" #include "debug.h" +static ibv_device_attr getDeviceAttr(ibv_context *ctx) { + ibv_device_attr devAttr; + if (ibv_query_device(ctx, &devAttr) != 0) { + std::stringstream err; + err << "ibv_query_device failed (errno " << errno << ")"; + throw mscclpp::IbError(err.str(), errno); + } + return devAttr; +} + namespace mscclpp { IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) { @@ -53,8 +63,8 @@ const void* IbMr::getBuff() const { return this->buff; } uint32_t IbMr::getLkey() const { return this->mr->lkey; } IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, - int maxWrPerSend) - : maxCqPollNum(maxCqPollNum), maxWrPerSend(maxWrPerSend) { + int maxWrPerSend, int maxNumSgesPerWr) + : maxCqPollNum_(maxCqPollNum), maxWrPerSend_(maxWrPerSend), maxNumSgesPerWr_(maxNumSgesPerWr) { this->cq = ibv_create_cq(ctx, maxCqSize, nullptr, nullptr, 0); if (this->cq == nullptr) { std::stringstream err; @@ -117,10 +127,11 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN throw mscclpp::IbError(err.str(), errno); } this->qp = _qp; - this->wrn = 0; - this->wrs = std::make_unique(maxWrPerSend); - this->sges = std::make_unique(maxWrPerSend); - this->wcs = std::make_unique(maxCqPollNum); + this->wrs = std::make_unique(maxWrPerSend_); + this->sges = std::make_unique(maxWrPerSend_ * maxNumSgesPerWr_); + this->wcs = std::make_unique(maxCqPollNum_); + numStagedWrs_ = 0; + numStagedSges_ = 0; } IbQp::~IbQp() { @@ -181,29 +192,34 @@ void IbQp::rts() { } } -IbQp::WrInfo IbQp::getNewWrInfo() { - if (this->wrn >= this->maxWrPerSend) { +IbQp::WrInfo IbQp::getNewWrInfo(int numSges) { + if (numStagedWrs_ >= maxWrPerSend_) { std::stringstream err; - err << "too many outstanding work requests. limit is " << this->maxWrPerSend; + err << "too many outstanding work requests. limit is " << maxWrPerSend_; + throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage); + } + if (numSges > maxNumSgesPerWr_) { + std::stringstream err; + err << "too many sges per work request. limit is " << maxNumSgesPerWr_; throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage); } - int wrn = this->wrn; - ibv_send_wr* wr_ = &this->wrs[wrn]; - ibv_sge* sge_ = &this->sges[wrn]; + ibv_send_wr* wr_ = &this->wrs[numStagedWrs_]; + ibv_sge* sge_ = &this->sges[numStagedSges_]; wr_->sg_list = sge_; - wr_->num_sge = 1; + wr_->num_sge = numSges; wr_->next = nullptr; - if (wrn > 0) { - this->wrs[wrn - 1].next = wr_; + if (numStagedWrs_ > 0) { + this->wrs[numStagedWrs_ - 1].next = wr_; } - this->wrn++; + numStagedWrs_++; + numStagedSges_ += numSges; return IbQp::WrInfo{wr_, sge_}; } void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled) { - auto wrInfo = this->getNewWrInfo(); + auto wrInfo = this->getNewWrInfo(1); wrInfo.wr->wr_id = wrId; wrInfo.wr->opcode = IBV_WR_RDMA_WRITE; wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0; @@ -215,7 +231,7 @@ void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64 } void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint64_t dstOffset, uint64_t addVal) { - auto wrInfo = this->getNewWrInfo(); + auto wrInfo = this->getNewWrInfo(1); wrInfo.wr->wr_id = wrId; wrInfo.wr->opcode = IBV_WR_ATOMIC_FETCH_AND_ADD; wrInfo.wr->send_flags = 0; // atomic op cannot be signaled @@ -229,7 +245,7 @@ void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, u void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData) { - auto wrInfo = this->getNewWrInfo(); + auto wrInfo = this->getNewWrInfo(1); wrInfo.wr->wr_id = wrId; wrInfo.wr->opcode = IBV_WR_RDMA_WRITE_WITH_IMM; wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0; @@ -241,8 +257,28 @@ void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, wrInfo.sge->lkey = mr->getLkey(); } +void IbQp::stageSendGather(const std::vector& srcMrs, const IbMrInfo& dstInfo, const std::vector& srcSizes, + uint64_t wrId, const std::vector& srcOffsets, uint64_t dstOffset, bool signaled) { + size_t numSrcs = srcMrs.size(); + if (numSrcs != srcSizes.size() || numSrcs != srcOffsets.size()) { + std::stringstream err; + err << "invalid srcs: srcMrs.size()=" << numSrcs << ", srcSizes.size()=" << srcSizes.size() + << ", srcOffsets.size()=" << srcOffsets.size(); + throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage); + } + auto wrInfo = this->getNewWrInfo(numSrcs); + wrInfo.wr->wr_id = wrId; + wrInfo.wr->opcode = IBV_WR_RDMA_READ; + wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0; + wrInfo.wr->wr.rdma.remote_addr = (uint64_t)(dstInfo.addr) + dstOffset; + wrInfo.wr->wr.rdma.rkey = dstInfo.rkey; + // wrInfo.sge->addr = (uint64_t)(mr->getBuff()) + srcOffset; + // wrInfo.sge->length = size; + // wrInfo.sge->lkey = mr->getLkey(); +} + void IbQp::postSend() { - if (this->wrn == 0) { + if (numStagedWrs_ == 0) { return; } struct ibv_send_wr* bad_wr; @@ -252,7 +288,8 @@ void IbQp::postSend() { err << "ibv_post_send failed (errno " << errno << ")"; throw mscclpp::IbError(err.str(), errno); } - this->wrn = 0; + numStagedWrs_ = 0; + numStagedSges_ = 0; } void IbQp::postRecv(uint64_t wrId) { @@ -269,7 +306,7 @@ void IbQp::postRecv(uint64_t wrId) { } } -int IbQp::pollCq() { return ibv_poll_cq(this->cq, this->maxCqPollNum, this->wcs.get()); } +int IbQp::pollCq() { return ibv_poll_cq(this->cq, maxCqPollNum_, this->wcs.get()); } IbQpInfo& IbQp::getInfo() { return this->info; } @@ -321,12 +358,7 @@ bool IbCtx::isPortUsable(int port) const { } int IbCtx::getAnyActivePort() const { - struct ibv_device_attr devAttr; - if (ibv_query_device(this->ctx, &devAttr) != 0) { - std::stringstream err; - err << "ibv_query_device failed (errno " << errno << ")"; - throw mscclpp::IbError(err.str(), errno); - } + ibv_device_attr devAttr = getDeviceAttr(this->ctx); for (uint8_t port = 1; port <= devAttr.phys_port_cnt; ++port) { if (this->isPortUsable(port)) { return port; @@ -335,17 +367,42 @@ int IbCtx::getAnyActivePort() const { return -1; } -IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, +void IbCtx::validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, + int port) const { + if (!this->isPortUsable(port)) { + throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InvalidUsage); + } + ibv_device_attr devAttr = getDeviceAttr(this->ctx); + if (maxCqSize > devAttr.max_cqe || maxCqSize < 1) { + throw mscclpp::Error("invalid maxCqSize: " + std::to_string(maxCqSize), ErrorCode::InvalidUsage); + } + if (maxCqPollNum > maxCqSize || maxCqPollNum < 1) { + throw mscclpp::Error("invalid maxCqPollNum: " + std::to_string(maxCqPollNum), ErrorCode::InvalidUsage); + } + if (maxSendWr > devAttr.max_qp_wr || maxSendWr < 1) { + throw mscclpp::Error("invalid maxSendWr: " + std::to_string(maxSendWr), ErrorCode::InvalidUsage); + } + if (maxRecvWr > devAttr.max_qp_wr || maxRecvWr < 1) { + throw mscclpp::Error("invalid maxRecvWr: " + std::to_string(maxRecvWr), ErrorCode::InvalidUsage); + } + if (maxWrPerSend > maxSendWr || maxWrPerSend < 1) { + throw mscclpp::Error("invalid maxWrPerSend: " + std::to_string(maxWrPerSend), ErrorCode::InvalidUsage); + } + if (maxNumSgesPerWr > devAttr.max_sge || maxNumSgesPerWr < 1) { + throw mscclpp::Error("invalid maxNumSgesPerWr: " + std::to_string(maxNumSgesPerWr), ErrorCode::InvalidUsage); + } +} + +IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, int port /*=-1*/) { if (port == -1) { port = this->getAnyActivePort(); if (port == -1) { throw mscclpp::Error("No active port found", ErrorCode::InternalError); } - } else if (!this->isPortUsable(port)) { - throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InternalError); } - qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend)); + validateConfig(maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr, port); + qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr)); return qps.back().get(); } diff --git a/src/include/connection.hpp b/src/include/connection.hpp index 0475691c9..106b86d99 100644 --- a/src/include/connection.hpp +++ b/src/include/connection.hpp @@ -57,7 +57,7 @@ class IBConnection : public ConnectionBase { public: IBConnection(int remoteRank, int tag, Transport transport, int maxCqSize, int maxCqPollNum, int maxSendWr, - int maxWrPerSend, Communicator::Impl& commImpl); + int maxWrPerSend, int maxNumSgesPerWr, Communicator::Impl& commImpl); Transport transport() override; diff --git a/src/include/ib.hpp b/src/include/ib.hpp index 1bec30b85..0126ef898 100644 --- a/src/include/ib.hpp +++ b/src/include/ib.hpp @@ -66,6 +66,8 @@ class IbQp { void stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint64_t dstOffset, uint64_t addVal); void stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData); + void stageSendGather(const std::vector& srcMrs, const IbMrInfo& dstInfo, const std::vector& srcSizes, + uint64_t wrId, const std::vector& srcOffsets, uint64_t dstOffset, bool signaled); void postSend(); void postRecv(uint64_t wrId); int pollCq(); @@ -80,8 +82,8 @@ class IbQp { }; IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, - int maxWrPerSend); - WrInfo getNewWrInfo(); + int maxWrPerSend, int maxNumSgesPerWr); + WrInfo getNewWrInfo(int numSges); IbQpInfo info; @@ -90,10 +92,12 @@ class IbQp { std::unique_ptr wcs; std::unique_ptr wrs; std::unique_ptr sges; - int wrn; + int numStagedWrs_; + int numStagedSges_; - const int maxCqPollNum; - const int maxWrPerSend; + const int maxCqPollNum_; + const int maxWrPerSend_; + const int maxNumSgesPerWr_; friend class IbCtx; }; @@ -103,7 +107,7 @@ class IbCtx { IbCtx(const std::string& devName); ~IbCtx(); - IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int port = -1); + IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, int port = -1); const IbMr* registerMr(void* buff, std::size_t size); const std::string& getDevName() const; @@ -111,6 +115,7 @@ class IbCtx { private: bool isPortUsable(int port) const; int getAnyActivePort() const; + void validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, int port) const; const std::string devName; ibv_context* ctx; diff --git a/test/mp_unit/ib_tests.cu b/test/mp_unit/ib_tests.cu index 7ab892b51..a38f992e6 100644 --- a/test/mp_unit/ib_tests.cu +++ b/test/mp_unit/ib_tests.cu @@ -36,7 +36,7 @@ void IbPeerToPeerTest::SetUp() { bootstrap->initialize(id); ibCtx = std::make_shared(ibDevName); - qp = ibCtx->createQp(1024, 1, 8192, 0, 64); + qp = ibCtx->createQp(1024, 1, 8192, 0, 64, 1); qpInfo[gEnv->rank] = qp->getInfo(); bootstrap->allGather(qpInfo.data(), sizeof(mscclpp::IbQpInfo)); From 89cad567210b7b447b6ee9ba905b6dd14b4a1b40 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Wed, 6 Sep 2023 02:33:21 +0000 Subject: [PATCH 2/5] updates --- src/communicator.cc | 9 ++--- src/ib.cc | 95 ++++++++++++++++++++++++--------------------- src/include/ib.hpp | 7 +++- 3 files changed, 59 insertions(+), 52 deletions(-) diff --git a/src/communicator.cc b/src/communicator.cc index d5b49fa16..b87388b78 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -94,12 +94,9 @@ MSCCLPP_API_CPP NonblockingFuture Communicator::recvMemoryOnSe return NonblockingFuture(memoryReceiver->memoryPromise_.get_future()); } -MSCCLPP_API_CPP std::shared_ptr Communicator::connectOnSetup(int remoteRank, int tag, Transport transport, - int ibMaxCqSize /*=1024*/, - int ibMaxCqPollNum /*=1*/, - int ibMaxSendWr /*=8192*/, - int ibMaxWrPerSend /*=64*/, - int ibMaxNumSgesPerWr /*=16*/) { +MSCCLPP_API_CPP std::shared_ptr Communicator::connectOnSetup( + int remoteRank, int tag, Transport transport, int ibMaxCqSize /*=1024*/, int ibMaxCqPollNum /*=1*/, + int ibMaxSendWr /*=8192*/, int ibMaxWrPerSend /*=64*/, int ibMaxNumSgesPerWr /*=16*/) { std::shared_ptr conn; if (transport == Transport::CudaIpc) { // sanity check: make sure the IPC connection is being made within a node diff --git a/src/ib.cc b/src/ib.cc index 0dcb33fdc..63eb1040d 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -16,7 +16,7 @@ #include "api.h" #include "debug.h" -static ibv_device_attr getDeviceAttr(ibv_context *ctx) { +static ibv_device_attr getDeviceAttr(ibv_context* ctx) { ibv_device_attr devAttr; if (ibv_query_device(ctx, &devAttr) != 0) { std::stringstream err; @@ -26,6 +26,12 @@ static ibv_device_attr getDeviceAttr(ibv_context *ctx) { return devAttr; } +static ibv_qp_attr createQpAttr() { + ibv_qp_attr qpAttr; + std::memset(&qpAttr, 0, sizeof(qpAttr)); + return qpAttr; +} + namespace mscclpp { IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) { @@ -115,8 +121,7 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN this->info.iid = gid.global.interface_id; } - struct ibv_qp_attr qpAttr; - memset(&qpAttr, 0, sizeof(qpAttr)); + ibv_qp_attr qpAttr = createQpAttr(); qpAttr.qp_state = IBV_QPS_INIT; qpAttr.pkey_index = 0; qpAttr.port_num = port; @@ -140,30 +145,29 @@ IbQp::~IbQp() { } void IbQp::rtr(const IbQpInfo& info) { - struct ibv_qp_attr qp_attr; - std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr)); - qp_attr.qp_state = IBV_QPS_RTR; - qp_attr.path_mtu = static_cast(info.mtu); - qp_attr.dest_qp_num = info.qpn; - qp_attr.rq_psn = 0; - qp_attr.max_dest_rd_atomic = 1; - qp_attr.min_rnr_timer = 0x12; + ibv_qp_attr qpAttr = createQpAttr(); + qpAttr.qp_state = IBV_QPS_RTR; + qpAttr.path_mtu = static_cast(info.mtu); + qpAttr.dest_qp_num = info.qpn; + qpAttr.rq_psn = 0; + qpAttr.max_dest_rd_atomic = 1; + qpAttr.min_rnr_timer = 0x12; if (info.linkLayer == IBV_LINK_LAYER_ETHERNET || info.is_grh) { - qp_attr.ah_attr.is_global = 1; - qp_attr.ah_attr.grh.dgid.global.subnet_prefix = info.spn; - qp_attr.ah_attr.grh.dgid.global.interface_id = info.iid; - qp_attr.ah_attr.grh.flow_label = 0; - qp_attr.ah_attr.grh.sgid_index = 0; - qp_attr.ah_attr.grh.hop_limit = 255; - qp_attr.ah_attr.grh.traffic_class = 0; + qpAttr.ah_attr.is_global = 1; + qpAttr.ah_attr.grh.dgid.global.subnet_prefix = info.spn; + qpAttr.ah_attr.grh.dgid.global.interface_id = info.iid; + qpAttr.ah_attr.grh.flow_label = 0; + qpAttr.ah_attr.grh.sgid_index = 0; + qpAttr.ah_attr.grh.hop_limit = 255; + qpAttr.ah_attr.grh.traffic_class = 0; } else { - qp_attr.ah_attr.is_global = 0; + qpAttr.ah_attr.is_global = 0; } - qp_attr.ah_attr.dlid = info.lid; - qp_attr.ah_attr.sl = 0; - qp_attr.ah_attr.src_path_bits = 0; - qp_attr.ah_attr.port_num = info.port; - int ret = ibv_modify_qp(this->qp, &qp_attr, + qpAttr.ah_attr.dlid = info.lid; + qpAttr.ah_attr.sl = 0; + qpAttr.ah_attr.src_path_bits = 0; + qpAttr.ah_attr.port_num = info.port; + int ret = ibv_modify_qp(this->qp, &qpAttr, IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER); if (ret != 0) { @@ -174,16 +178,15 @@ void IbQp::rtr(const IbQpInfo& info) { } void IbQp::rts() { - struct ibv_qp_attr qp_attr; - std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr)); - qp_attr.qp_state = IBV_QPS_RTS; - qp_attr.timeout = 18; - qp_attr.retry_cnt = 7; - qp_attr.rnr_retry = 7; - qp_attr.sq_psn = 0; - qp_attr.max_rd_atomic = 1; + ibv_qp_attr qpAttr = createQpAttr(); + qpAttr.qp_state = IBV_QPS_RTS; + qpAttr.timeout = 18; + qpAttr.retry_cnt = 7; + qpAttr.rnr_retry = 7; + qpAttr.sq_psn = 0; + qpAttr.max_rd_atomic = 1; int ret = ibv_modify_qp( - this->qp, &qp_attr, + this->qp, &qpAttr, IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC); if (ret != 0) { std::stringstream err; @@ -257,8 +260,9 @@ void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, wrInfo.sge->lkey = mr->getLkey(); } -void IbQp::stageSendGather(const std::vector& srcMrs, const IbMrInfo& dstInfo, const std::vector& srcSizes, - uint64_t wrId, const std::vector& srcOffsets, uint64_t dstOffset, bool signaled) { +void IbQp::stageSendGather(const std::vector& srcMrs, const IbMrInfo& dstInfo, + const std::vector& srcSizes, uint64_t wrId, + const std::vector& srcOffsets, uint64_t dstOffset, bool signaled) { size_t numSrcs = srcMrs.size(); if (numSrcs != srcSizes.size() || numSrcs != srcOffsets.size()) { std::stringstream err; @@ -272,9 +276,11 @@ void IbQp::stageSendGather(const std::vector& srcMrs, const IbMrInfo& dst wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0; wrInfo.wr->wr.rdma.remote_addr = (uint64_t)(dstInfo.addr) + dstOffset; wrInfo.wr->wr.rdma.rkey = dstInfo.rkey; - // wrInfo.sge->addr = (uint64_t)(mr->getBuff()) + srcOffset; - // wrInfo.sge->length = size; - // wrInfo.sge->lkey = mr->getLkey(); + for (size_t i = 0; i < numSrcs; ++i) { + wrInfo.sge[i].addr = (uint64_t)(srcMrs[i]->getBuff()) + srcOffsets[i]; + wrInfo.sge[i].length = srcSizes[i]; + wrInfo.sge[i].lkey = srcMrs[i]->getLkey(); + } } void IbQp::postSend() { @@ -367,8 +373,8 @@ int IbCtx::getAnyActivePort() const { return -1; } -void IbCtx::validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, - int port) const { +void IbCtx::validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, + int maxNumSgesPerWr, int port) const { if (!this->isPortUsable(port)) { throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InvalidUsage); } @@ -393,16 +399,17 @@ void IbCtx::validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int m } } -IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, - int port /*=-1*/) { +IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, + int maxNumSgesPerWr, int port /*=-1*/) { if (port == -1) { port = this->getAnyActivePort(); if (port == -1) { throw mscclpp::Error("No active port found", ErrorCode::InternalError); } } - validateConfig(maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr, port); - qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr)); + this->validateConfig(maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr, port); + qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, + maxNumSgesPerWr)); return qps.back().get(); } diff --git a/src/include/ib.hpp b/src/include/ib.hpp index 0126ef898..cb909111b 100644 --- a/src/include/ib.hpp +++ b/src/include/ib.hpp @@ -7,6 +7,7 @@ #include #include #include +#include // Forward declarations of IB structures struct ibv_context; @@ -107,7 +108,8 @@ class IbCtx { IbCtx(const std::string& devName); ~IbCtx(); - IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, int port = -1); + IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, + int port = -1); const IbMr* registerMr(void* buff, std::size_t size); const std::string& getDevName() const; @@ -115,7 +117,8 @@ class IbCtx { private: bool isPortUsable(int port) const; int getAnyActivePort() const; - void validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, int port) const; + void validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, + int maxNumSgesPerWr, int port) const; const std::string devName; ibv_context* ctx; From 8232ec731ffa1ea814a428017ea729e7643b6093 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Wed, 6 Sep 2023 12:15:12 +0000 Subject: [PATCH 3/5] Working --- include/mscclpp/core.hpp | 2 +- python/mscclpp/core_py.cpp | 2 +- src/communicator.cc | 2 +- src/ib.cc | 104 +++++++++++---------- src/include/ib.hpp | 27 +++--- test/mp_unit/ib_tests.cu | 163 ++++++++++++++++++++++++++++----- test/mp_unit/mp_unit_tests.hpp | 12 ++- 7 files changed, 224 insertions(+), 88 deletions(-) diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index 67ed523a3..76332c60f 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -543,7 +543,7 @@ class Communicator { /// @return std::shared_ptr A shared pointer to the connection. std::shared_ptr connectOnSetup(int remoteRank, int tag, Transport transport, int ibMaxCqSize = 1024, int ibMaxCqPollNum = 1, int ibMaxSendWr = 8192, int ibMaxWrPerSend = 64, - int ibMaxNumSgesPerWr = 16); + int ibMaxNumSgesPerWr = 1); /// Add a custom Setuppable object to a list of objects to be setup later, when @ref setup() is called. /// diff --git a/python/mscclpp/core_py.cpp b/python/mscclpp/core_py.cpp index a65a443a6..ce0fc606e 100644 --- a/python/mscclpp/core_py.cpp +++ b/python/mscclpp/core_py.cpp @@ -141,7 +141,7 @@ void register_core(nb::module_& m) { .def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag")) .def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("transport"), nb::arg("ibMaxCqSize") = 1024, nb::arg("ibMaxCqPollNum") = 1, - nb::arg("ibMaxSendWr") = 8192, nb::arg("ibMaxWrPerSend") = 64) + nb::arg("ibMaxSendWr") = 8192, nb::arg("ibMaxWrPerSend") = 64, nb::arg("ibMaxNumSgesPerWr") = 1) .def("setup", &Communicator::setup); } diff --git a/src/communicator.cc b/src/communicator.cc index b87388b78..891e9a917 100644 --- a/src/communicator.cc +++ b/src/communicator.cc @@ -96,7 +96,7 @@ MSCCLPP_API_CPP NonblockingFuture Communicator::recvMemoryOnSe MSCCLPP_API_CPP std::shared_ptr Communicator::connectOnSetup( int remoteRank, int tag, Transport transport, int ibMaxCqSize /*=1024*/, int ibMaxCqPollNum /*=1*/, - int ibMaxSendWr /*=8192*/, int ibMaxWrPerSend /*=64*/, int ibMaxNumSgesPerWr /*=16*/) { + int ibMaxSendWr /*=8192*/, int ibMaxWrPerSend /*=64*/, int ibMaxNumSgesPerWr /*=1*/) { std::shared_ptr conn; if (transport == Transport::CudaIpc) { // sanity check: make sure the IPC connection is being made within a node diff --git a/src/ib.cc b/src/ib.cc index 63eb1040d..50fddc965 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -16,16 +16,6 @@ #include "api.h" #include "debug.h" -static ibv_device_attr getDeviceAttr(ibv_context* ctx) { - ibv_device_attr devAttr; - if (ibv_query_device(ctx, &devAttr) != 0) { - std::stringstream err; - err << "ibv_query_device failed (errno " << errno << ")"; - throw mscclpp::IbError(err.str(), errno); - } - return devAttr; -} - static ibv_qp_attr createQpAttr() { ibv_qp_attr qpAttr; std::memset(&qpAttr, 0, sizeof(qpAttr)); @@ -34,17 +24,13 @@ static ibv_qp_attr createQpAttr() { namespace mscclpp { -IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) { - if (size == 0) { - throw std::invalid_argument("invalid size: " + std::to_string(size)); - } +IbMr::IbMr(ibv_pd* pd, void* buff, size_t alignedSize) : buff(buff) { static __thread uintptr_t pageSize = 0; if (pageSize == 0) { pageSize = sysconf(_SC_PAGESIZE); } uintptr_t addr = reinterpret_cast(buff) & -pageSize; - std::size_t pages = (size + (reinterpret_cast(buff) - addr) + pageSize - 1) / pageSize; - this->mr = ibv_reg_mr(pd, reinterpret_cast(addr), pages * pageSize, + this->mr = ibv_reg_mr(pd, reinterpret_cast(addr), alignedSize, IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_RELAXED_ORDERING | IBV_ACCESS_REMOTE_ATOMIC); if (this->mr == nullptr) { @@ -52,7 +38,7 @@ IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) { err << "ibv_reg_mr failed (errno " << errno << ")"; throw mscclpp::IbError(err.str(), errno); } - this->size = pages * pageSize; + this->size = alignedSize; } IbMr::~IbMr() { ibv_dereg_mr(this->mr); } @@ -220,8 +206,8 @@ IbQp::WrInfo IbQp::getNewWrInfo(int numSges) { return IbQp::WrInfo{wr_, sge_}; } -void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, - uint64_t dstOffset, bool signaled) { +void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint32_t srcOffset, + uint32_t dstOffset, bool signaled) { auto wrInfo = this->getNewWrInfo(1); wrInfo.wr->wr_id = wrId; wrInfo.wr->opcode = IBV_WR_RDMA_WRITE; @@ -233,7 +219,7 @@ void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64 wrInfo.sge->lkey = mr->getLkey(); } -void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint64_t dstOffset, uint64_t addVal) { +void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint32_t dstOffset, uint64_t addVal) { auto wrInfo = this->getNewWrInfo(1); wrInfo.wr->wr_id = wrId; wrInfo.wr->opcode = IBV_WR_ATOMIC_FETCH_AND_ADD; @@ -246,8 +232,8 @@ void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, u wrInfo.sge->lkey = mr->getLkey(); } -void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, - uint64_t dstOffset, bool signaled, unsigned int immData) { +void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint32_t srcOffset, + uint32_t dstOffset, bool signaled, unsigned int immData) { auto wrInfo = this->getNewWrInfo(1); wrInfo.wr->wr_id = wrId; wrInfo.wr->opcode = IBV_WR_RDMA_WRITE_WITH_IMM; @@ -260,26 +246,26 @@ void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, wrInfo.sge->lkey = mr->getLkey(); } -void IbQp::stageSendGather(const std::vector& srcMrs, const IbMrInfo& dstInfo, - const std::vector& srcSizes, uint64_t wrId, - const std::vector& srcOffsets, uint64_t dstOffset, bool signaled) { - size_t numSrcs = srcMrs.size(); - if (numSrcs != srcSizes.size() || numSrcs != srcOffsets.size()) { +void IbQp::stageSendGather(const std::vector& srcMrList, const IbMrInfo& dstMrInfo, + const std::vector& srcSizeList, uint64_t wrId, + const std::vector& srcOffsetList, uint32_t dstOffset, bool signaled) { + size_t numSrcs = srcMrList.size(); + if (numSrcs != srcSizeList.size() || numSrcs != srcOffsetList.size()) { std::stringstream err; - err << "invalid srcs: srcMrs.size()=" << numSrcs << ", srcSizes.size()=" << srcSizes.size() - << ", srcOffsets.size()=" << srcOffsets.size(); + err << "invalid srcs: srcMrList.size()=" << numSrcs << ", srcSizeList.size()=" << srcSizeList.size() + << ", srcOffsetList.size()=" << srcOffsetList.size(); throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage); } auto wrInfo = this->getNewWrInfo(numSrcs); wrInfo.wr->wr_id = wrId; - wrInfo.wr->opcode = IBV_WR_RDMA_READ; + wrInfo.wr->opcode = IBV_WR_RDMA_WRITE; wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0; - wrInfo.wr->wr.rdma.remote_addr = (uint64_t)(dstInfo.addr) + dstOffset; - wrInfo.wr->wr.rdma.rkey = dstInfo.rkey; + wrInfo.wr->wr.rdma.remote_addr = (uint64_t)(dstMrInfo.addr) + dstOffset; + wrInfo.wr->wr.rdma.rkey = dstMrInfo.rkey; for (size_t i = 0; i < numSrcs; ++i) { - wrInfo.sge[i].addr = (uint64_t)(srcMrs[i]->getBuff()) + srcOffsets[i]; - wrInfo.sge[i].length = srcSizes[i]; - wrInfo.sge[i].lkey = srcMrs[i]->getLkey(); + wrInfo.sge[i].addr = (uint64_t)(srcMrList[i]->getBuff()) + srcOffsetList[i]; + wrInfo.sge[i].length = srcSizeList[i]; + wrInfo.sge[i].lkey = srcMrList[i]->getLkey(); } } @@ -339,6 +325,13 @@ IbCtx::IbCtx(const std::string& devName) : devName(devName) { err << "ibv_alloc_pd failed (errno " << errno << ")"; throw mscclpp::IbError(err.str(), errno); } + // TODO: do not use new + this->devAttr = new ibv_device_attr; + if (ibv_query_device(this->ctx, this->devAttr) != 0) { + std::stringstream err; + err << "ibv_query_device failed (errno " << errno << ")"; + throw mscclpp::IbError(err.str(), errno); + } } IbCtx::~IbCtx() { @@ -350,6 +343,8 @@ IbCtx::~IbCtx() { if (this->ctx != nullptr) { ibv_close_device(this->ctx); } + // TODO: do not use delete + delete this->devAttr; } bool IbCtx::isPortUsable(int port) const { @@ -364,8 +359,7 @@ bool IbCtx::isPortUsable(int port) const { } int IbCtx::getAnyActivePort() const { - ibv_device_attr devAttr = getDeviceAttr(this->ctx); - for (uint8_t port = 1; port <= devAttr.phys_port_cnt; ++port) { + for (uint8_t port = 1; port <= this->devAttr->phys_port_cnt; ++port) { if (this->isPortUsable(port)) { return port; } @@ -373,28 +367,27 @@ int IbCtx::getAnyActivePort() const { return -1; } -void IbCtx::validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, - int maxNumSgesPerWr, int port) const { +void IbCtx::validateQpConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, + int maxNumSgesPerWr, int port) const { if (!this->isPortUsable(port)) { throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InvalidUsage); } - ibv_device_attr devAttr = getDeviceAttr(this->ctx); - if (maxCqSize > devAttr.max_cqe || maxCqSize < 1) { + if (maxCqSize > this->devAttr->max_cqe || maxCqSize < 1) { throw mscclpp::Error("invalid maxCqSize: " + std::to_string(maxCqSize), ErrorCode::InvalidUsage); } if (maxCqPollNum > maxCqSize || maxCqPollNum < 1) { throw mscclpp::Error("invalid maxCqPollNum: " + std::to_string(maxCqPollNum), ErrorCode::InvalidUsage); } - if (maxSendWr > devAttr.max_qp_wr || maxSendWr < 1) { + if (maxSendWr > this->devAttr->max_qp_wr) { throw mscclpp::Error("invalid maxSendWr: " + std::to_string(maxSendWr), ErrorCode::InvalidUsage); } - if (maxRecvWr > devAttr.max_qp_wr || maxRecvWr < 1) { + if (maxRecvWr > this->devAttr->max_qp_wr) { throw mscclpp::Error("invalid maxRecvWr: " + std::to_string(maxRecvWr), ErrorCode::InvalidUsage); } if (maxWrPerSend > maxSendWr || maxWrPerSend < 1) { throw mscclpp::Error("invalid maxWrPerSend: " + std::to_string(maxWrPerSend), ErrorCode::InvalidUsage); } - if (maxNumSgesPerWr > devAttr.max_sge || maxNumSgesPerWr < 1) { + if (maxNumSgesPerWr > this->devAttr->max_sge || maxNumSgesPerWr < 1) { throw mscclpp::Error("invalid maxNumSgesPerWr: " + std::to_string(maxNumSgesPerWr), ErrorCode::InvalidUsage); } } @@ -407,14 +400,31 @@ IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRec throw mscclpp::Error("No active port found", ErrorCode::InternalError); } } - this->validateConfig(maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr, port); + this->validateQpConfig(maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr, port); qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr)); return qps.back().get(); } -const IbMr* IbCtx::registerMr(void* buff, std::size_t size) { - mrs.emplace_back(new IbMr(this->pd, buff, size)); +const IbMr* IbCtx::registerMr(void* buff, uint32_t size) { + if (size == 0) { + throw mscclpp::Error("invalid size: " + std::to_string(size), ErrorCode::InvalidUsage); + } + static __thread uintptr_t pageSize = 0; + if (pageSize == 0) { + pageSize = sysconf(_SC_PAGESIZE); + } + uintptr_t addr = reinterpret_cast(buff) & -pageSize; + size_t pages = (size + (reinterpret_cast(buff) - addr) + pageSize - 1) / pageSize; + + size_t alignedSize = pages * pageSize; + if (alignedSize > this->devAttr->max_mr_size) { + std::stringstream err; + err << "invalid MR size: " << alignedSize << " max " << this->devAttr->max_mr_size; + throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage); + } + + mrs.emplace_back(new IbMr(this->pd, buff, alignedSize)); return mrs.back().get(); } diff --git a/src/include/ib.hpp b/src/include/ib.hpp index cb909111b..440c208ce 100644 --- a/src/include/ib.hpp +++ b/src/include/ib.hpp @@ -11,6 +11,7 @@ // Forward declarations of IB structures struct ibv_context; +struct ibv_device_attr; struct ibv_pd; struct ibv_mr; struct ibv_qp; @@ -35,11 +36,11 @@ class IbMr { uint32_t getLkey() const; private: - IbMr(ibv_pd* pd, void* buff, std::size_t size); + IbMr(ibv_pd* pd, void* buff, size_t alignedSize); ibv_mr* mr; void* buff; - std::size_t size; + size_t size; friend class IbCtx; }; @@ -62,13 +63,14 @@ class IbQp { void rtr(const IbQpInfo& info); void rts(); - void stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, - uint64_t dstOffset, bool signaled); - void stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint64_t dstOffset, uint64_t addVal); - void stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset, - uint64_t dstOffset, bool signaled, unsigned int immData); - void stageSendGather(const std::vector& srcMrs, const IbMrInfo& dstInfo, const std::vector& srcSizes, - uint64_t wrId, const std::vector& srcOffsets, uint64_t dstOffset, bool signaled); + void stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint32_t srcOffset, + uint32_t dstOffset, bool signaled); + void stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint32_t dstOffset, uint64_t addVal); + void stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint32_t srcOffset, + uint32_t dstOffset, bool signaled, unsigned int immData); + void stageSendGather(const std::vector& srcMrList, const IbMrInfo& dstMrInfo, + const std::vector& srcSizeList, uint64_t wrId, + const std::vector& srcOffsetList, uint32_t dstOffset, bool signaled); void postSend(); void postRecv(uint64_t wrId); int pollCq(); @@ -110,19 +112,20 @@ class IbCtx { IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr, int port = -1); - const IbMr* registerMr(void* buff, std::size_t size); + const IbMr* registerMr(void* buff, uint32_t size); const std::string& getDevName() const; private: bool isPortUsable(int port) const; int getAnyActivePort() const; - void validateConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, - int maxNumSgesPerWr, int port) const; + void validateQpConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, + int maxNumSgesPerWr, int port) const; const std::string devName; ibv_context* ctx; ibv_pd* pd; + ibv_device_attr* devAttr; std::list> qps; std::list> mrs; }; diff --git a/test/mp_unit/ib_tests.cu b/test/mp_unit/ib_tests.cu index a38f992e6..2a7d38412 100644 --- a/test/mp_unit/ib_tests.cu +++ b/test/mp_unit/ib_tests.cu @@ -38,39 +38,57 @@ void IbPeerToPeerTest::SetUp() { ibCtx = std::make_shared(ibDevName); qp = ibCtx->createQp(1024, 1, 8192, 0, 64, 1); - qpInfo[gEnv->rank] = qp->getInfo(); - bootstrap->allGather(qpInfo.data(), sizeof(mscclpp::IbQpInfo)); + int remoteRank = (gEnv->rank == 0) ? 1 : 0; + + mscclpp::IbQpInfo localQpInfo = qp->getInfo(); + bootstrap->send(&localQpInfo, sizeof(mscclpp::IbQpInfo), remoteRank, /*tag=*/0); + bootstrap->recv(&remoteQpInfo, sizeof(mscclpp::IbQpInfo), remoteRank, /*tag=*/0); } -void IbPeerToPeerTest::registerBufferAndConnect(void* buf, size_t size) { - bufSize = size; - mr = ibCtx->registerMr(buf, size); - mrInfo[gEnv->rank] = mr->getInfo(); - bootstrap->allGather(mrInfo.data(), sizeof(mscclpp::IbMrInfo)); - - for (int i = 0; i < bootstrap->getNranks(); ++i) { - if (i == gEnv->rank) continue; - qp->rtr(qpInfo[i]); - qp->rts(); - break; +void IbPeerToPeerTest::registerBuffersAndConnect(const std::vector& bufList, + const std::vector& sizeList) { + size_t numMrs = bufList.size(); + if (numMrs != sizeList.size()) { + throw std::runtime_error("bufList.size() != sizeList.size()"); + } + + // Assume the remote side registers the same number of MRs + std::vector localMrInfo; + for (size_t i = 0; i < numMrs; ++i) { + const mscclpp::IbMr* mr = ibCtx->registerMr(bufList[i], sizeList[i]); + localMrList.push_back(mr); + localMrInfo.emplace_back(mr->getInfo()); } + + int remoteRank = (gEnv->rank == 0) ? 1 : 0; + + // Send the number of MRs and the MR info to the remote side + bootstrap->send(&numMrs, sizeof(numMrs), remoteRank, /*tag=*/0); + bootstrap->send(localMrInfo.data(), sizeof(mscclpp::IbMrInfo) * numMrs, remoteRank, /*tag=*/1); + + // Receive the number of MRs and the MR info from the remote side + size_t numRemoteMrs; + bootstrap->recv(&numRemoteMrs, sizeof(numRemoteMrs), remoteRank, /*tag=*/0); + remoteMrInfoList.resize(numRemoteMrs); + bootstrap->recv(remoteMrInfoList.data(), sizeof(mscclpp::IbMrInfo) * numRemoteMrs, remoteRank, /*tag=*/1); + + qp->rtr(remoteQpInfo); + qp->rts(); + bootstrap->barrier(); } void IbPeerToPeerTest::stageSend(uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled) { - const mscclpp::IbMrInfo& remoteMrInfo = mrInfo[(gEnv->rank == 1) ? 0 : 1]; - qp->stageSend(mr, remoteMrInfo, size, wrId, srcOffset, dstOffset, signaled); + qp->stageSend(localMrList[0], remoteMrInfoList[0], size, wrId, srcOffset, dstOffset, signaled); } void IbPeerToPeerTest::stageAtomicAdd(uint64_t wrId, uint64_t dstOffset, uint64_t addVal) { - const mscclpp::IbMrInfo& remoteMrInfo = mrInfo[(gEnv->rank == 1) ? 0 : 1]; - qp->stageAtomicAdd(mr, remoteMrInfo, wrId, dstOffset, addVal); + qp->stageAtomicAdd(localMrList[0], remoteMrInfoList[0], wrId, dstOffset, addVal); } void IbPeerToPeerTest::stageSendWithImm(uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData) { - const mscclpp::IbMrInfo& remoteMrInfo = mrInfo[(gEnv->rank == 1) ? 0 : 1]; - qp->stageSendWithImm(mr, remoteMrInfo, size, wrId, srcOffset, dstOffset, signaled, immData); + qp->stageSendWithImm(localMrList[0], remoteMrInfoList[0], size, wrId, srcOffset, dstOffset, signaled, immData); } TEST_F(IbPeerToPeerTest, SimpleSendRecv) { @@ -85,7 +103,7 @@ TEST_F(IbPeerToPeerTest, SimpleSendRecv) { const int nelem = 1; auto data = mscclpp::allocUniqueCuda(nelem); - registerBufferAndConnect(data.get(), sizeof(int) * nelem); + registerBuffersAndConnect({data.get()}, {sizeof(int) * nelem}); if (gEnv->rank == 1) { mscclpp::Timer timer; @@ -194,7 +212,7 @@ TEST_F(IbPeerToPeerTest, MemoryConsistency) { const uint64_t nelem = 65536 + 1; auto data = mscclpp::allocUniqueCuda(nelem); - registerBufferAndConnect(data.get(), sizeof(uint64_t) * nelem); + registerBuffersAndConnect({data.get()}, {sizeof(uint64_t) * nelem}); uint64_t res = 0; uint64_t iter = 0; @@ -288,3 +306,106 @@ TEST_F(IbPeerToPeerTest, MemoryConsistency) { EXPECT_EQ(res, 0); } + +TEST_F(IbPeerToPeerTest, SendGather) { + if (gEnv->rank >= 2) { + // This test needs only two ranks + return; + } + + mscclpp::Timer timeout(3); + + const int numDataSrcs = 1; + const int nelemPerMr = 1024; + + // Gather send from rank 0 to 1 + if (gEnv->rank == 0) { + std::vector> dataList; + for (int i = 0; i < numDataSrcs; ++i) { + auto data = mscclpp::allocUniqueCuda(nelemPerMr); + // Fill in data for correctness check + std::vector hostData(nelemPerMr, i + 1); + mscclpp::memcpyCuda(data.get(), hostData.data(), nelemPerMr); + dataList.emplace_back(std::move(data)); + } + + std::vector dataRefList; + for (int i = 0; i < numDataSrcs; ++i) { + dataRefList.emplace_back(dataList[i].get()); + } + + // For sending a completion signal to the remote side + uint64_t outboundSema = 1; + + dataRefList.push_back(&outboundSema); + + std::vector sizeList(numDataSrcs, sizeof(int) * nelemPerMr); + sizeList.push_back(sizeof(outboundSema)); + + registerBuffersAndConnect(dataRefList, sizeList); + + auto& remoteDataMrInfo = remoteMrInfoList[0]; + auto& remoteSemaMrInfo = remoteMrInfoList[1]; + auto& localSemaMr = localMrList[numDataSrcs]; + + std::vector gatherLocalMrList; + for (int i = 0; i < numDataSrcs; ++i) { + gatherLocalMrList.emplace_back(localMrList[i]); + } + std::vector gatherSizeList(numDataSrcs, sizeof(int) * nelemPerMr); + std::vector gatherOffsetList(numDataSrcs, 0); + + qp->stageSendGather(gatherLocalMrList, remoteDataMrInfo, gatherSizeList, /*wrId=*/0, gatherOffsetList, + /*dstOffset=*/0, /*signaled=*/true); + qp->postSend(); + + qp->stageAtomicAdd(localSemaMr, remoteSemaMrInfo, /*wrId=*/0, /*dstOffset=*/0, /*addVal=*/1); + qp->postSend(); + + // Wait for send completion + bool waiting = true; + int spin = 0; + while (waiting) { + int wcNum = qp->pollCq(); + ASSERT_GE(wcNum, 0); + for (int i = 0; i < wcNum; ++i) { + const ibv_wc* wc = qp->getWc(i); + EXPECT_EQ(wc->status, IBV_WC_SUCCESS); + waiting = false; + break; + } + if (spin++ > 1000000) { + FAIL() << "Polling is stuck."; + } + } + } else { + // Data array to receive + auto data = mscclpp::allocUniqueCuda(nelemPerMr * numDataSrcs); + + // For receiving a completion signal from the remote side + uint64_t inboundSema = 0; + + registerBuffersAndConnect({data.get(), &inboundSema}, + {sizeof(int) * nelemPerMr * numDataSrcs, sizeof(inboundSema)}); + + // Wait for a signal from the remote side + volatile uint64_t* ptrInboundSema = &inboundSema; + int spin = 0; + while (*ptrInboundSema == 0) { + if (spin++ > 1000000) { + FAIL() << "Polling is stuck."; + } + } + + // Correctness check + std::vector hostData(nelemPerMr * numDataSrcs); + mscclpp::memcpyCuda(hostData.data(), data.get(), nelemPerMr * numDataSrcs); + for (int i = 0; i < numDataSrcs; ++i) { + for (int j = 0; j < nelemPerMr; ++j) { + EXPECT_EQ(hostData[i * nelemPerMr + j], i + 1); + } + } + } + + bootstrap->barrier(); +} diff --git a/test/mp_unit/mp_unit_tests.hpp b/test/mp_unit/mp_unit_tests.hpp index 393255638..0ec62e79f 100644 --- a/test/mp_unit/mp_unit_tests.hpp +++ b/test/mp_unit/mp_unit_tests.hpp @@ -67,7 +67,7 @@ class IbPeerToPeerTest : public IbTestBase { protected: void SetUp() override; - void registerBufferAndConnect(void* buf, size_t size); + void registerBuffersAndConnect(const std::vector& bufList, const std::vector& sizeList); void stageSend(uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled); @@ -76,14 +76,16 @@ class IbPeerToPeerTest : public IbTestBase { void stageSendWithImm(uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled, unsigned int immData); + void stageSendGather(const std::vector& sizeList, uint64_t wrId, const std::vector& srcOffsetList, + uint32_t dstOffset, bool signaled); + std::shared_ptr bootstrap; std::shared_ptr ibCtx; mscclpp::IbQp* qp; - const mscclpp::IbMr* mr; - size_t bufSize; + std::vector localMrList; - std::array qpInfo; - std::array mrInfo; + mscclpp::IbQpInfo remoteQpInfo; + std::vector remoteMrInfoList; }; class CommunicatorTestBase : public MultiProcessTest { From e48f942068b535f9bc8e06cf2eecc0f0fd100344 Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Wed, 6 Sep 2023 12:30:04 +0000 Subject: [PATCH 4/5] Gather size 4 --- src/ib.cc | 4 ++-- test/mp_unit/ib_tests.cu | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ib.cc b/src/ib.cc index 50fddc965..9fd16082b 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -72,8 +72,8 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN qpInitAttr.qp_type = IBV_QPT_RC; qpInitAttr.cap.max_send_wr = maxSendWr; qpInitAttr.cap.max_recv_wr = maxRecvWr; - qpInitAttr.cap.max_send_sge = 1; - qpInitAttr.cap.max_recv_sge = 1; + qpInitAttr.cap.max_send_sge = maxNumSgesPerWr; + qpInitAttr.cap.max_recv_sge = maxNumSgesPerWr; qpInitAttr.cap.max_inline_data = 0; struct ibv_qp* _qp = ibv_create_qp(pd, &qpInitAttr); diff --git a/test/mp_unit/ib_tests.cu b/test/mp_unit/ib_tests.cu index 2a7d38412..d4c5539a8 100644 --- a/test/mp_unit/ib_tests.cu +++ b/test/mp_unit/ib_tests.cu @@ -36,7 +36,7 @@ void IbPeerToPeerTest::SetUp() { bootstrap->initialize(id); ibCtx = std::make_shared(ibDevName); - qp = ibCtx->createQp(1024, 1, 8192, 0, 64, 1); + qp = ibCtx->createQp(1024, 1, 8192, 0, 64, 4); int remoteRank = (gEnv->rank == 0) ? 1 : 0; @@ -315,7 +315,7 @@ TEST_F(IbPeerToPeerTest, SendGather) { mscclpp::Timer timeout(3); - const int numDataSrcs = 1; + const int numDataSrcs = 4; const int nelemPerMr = 1024; // Gather send from rank 0 to 1 From d38df98326a37edbdb28ea335a8343998c24840e Mon Sep 17 00:00:00 2001 From: Changho Hwang Date: Wed, 6 Sep 2023 17:03:04 +0000 Subject: [PATCH 5/5] resolve a code scanning issue --- src/ib.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ib.cc b/src/ib.cc index 9fd16082b..cdcfdd04b 100644 --- a/src/ib.cc +++ b/src/ib.cc @@ -119,7 +119,7 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN } this->qp = _qp; this->wrs = std::make_unique(maxWrPerSend_); - this->sges = std::make_unique(maxWrPerSend_ * maxNumSgesPerWr_); + this->sges = std::make_unique((size_t)maxWrPerSend_ * maxNumSgesPerWr_); this->wcs = std::make_unique(maxCqPollNum_); numStagedWrs_ = 0; numStagedSges_ = 0;