diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp
index 1e9e6abd8..76332c60f 100644
--- a/include/mscclpp/core.hpp
+++ b/include/mscclpp/core.hpp
@@ -542,7 +542,8 @@ class Communicator {
   /// @param ibMaxWrPerSend The maximum number of work requests per send for IB. Unused if transport is not IB.
   /// @return std::shared_ptr<Connection> A shared pointer to the connection.
   std::shared_ptr<Connection> connectOnSetup(int remoteRank, int tag, Transport transport, int ibMaxCqSize = 1024,
-                                             int ibMaxCqPollNum = 1, int ibMaxSendWr = 8192, int ibMaxWrPerSend = 64);
+                                             int ibMaxCqPollNum = 1, int ibMaxSendWr = 8192, int ibMaxWrPerSend = 64,
+                                             int ibMaxNumSgesPerWr = 1);
 
   /// Add a custom Setuppable object to a list of objects to be setup later, when @ref setup() is called.
   ///
diff --git a/python/mscclpp/core_py.cpp b/python/mscclpp/core_py.cpp
index a65a443a6..ce0fc606e 100644
--- a/python/mscclpp/core_py.cpp
+++ b/python/mscclpp/core_py.cpp
@@ -141,7 +141,7 @@ void register_core(nb::module_& m) {
       .def("recv_memory_on_setup", &Communicator::recvMemoryOnSetup, nb::arg("remoteRank"), nb::arg("tag"))
       .def("connect_on_setup", &Communicator::connectOnSetup, nb::arg("remoteRank"), nb::arg("tag"),
            nb::arg("transport"), nb::arg("ibMaxCqSize") = 1024, nb::arg("ibMaxCqPollNum") = 1,
-           nb::arg("ibMaxSendWr") = 8192, nb::arg("ibMaxWrPerSend") = 64)
+           nb::arg("ibMaxSendWr") = 8192, nb::arg("ibMaxWrPerSend") = 64, nb::arg("ibMaxNumSgesPerWr") = 1)
       .def("setup", &Communicator::setup);
 }
diff --git a/src/communicator.cc b/src/communicator.cc
index cc0323556..891e9a917 100644
--- a/src/communicator.cc
+++ b/src/communicator.cc
@@ -94,11 +94,9 @@ MSCCLPP_API_CPP NonblockingFuture<RegisteredMemory> Communicator::recvMemoryOnSe
   return NonblockingFuture<RegisteredMemory>(memoryReceiver->memoryPromise_.get_future());
 }
 
-MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int remoteRank, int tag, Transport transport,
-                                                                         int ibMaxCqSize /*=1024*/,
-                                                                         int ibMaxCqPollNum /*=1*/,
-                                                                         int ibMaxSendWr /*=8192*/,
-                                                                         int ibMaxWrPerSend /*=64*/) {
+MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(
+    int remoteRank, int tag, Transport transport, int ibMaxCqSize /*=1024*/, int ibMaxCqPollNum /*=1*/,
+    int ibMaxSendWr /*=8192*/, int ibMaxWrPerSend /*=64*/, int ibMaxNumSgesPerWr /*=1*/) {
   std::shared_ptr<Connection> conn;
   if (transport == Transport::CudaIpc) {
     // sanity check: make sure the IPC connection is being made within a node
@@ -116,7 +114,7 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Communicator::connectOnSetup(int rem
          pimpl->rankToHash_[remoteRank]);
   } else if (AllIBTransports.has(transport)) {
     auto ibConn = std::make_shared<IBConnection>(remoteRank, tag, transport, ibMaxCqSize, ibMaxCqPollNum, ibMaxSendWr,
-                                                 ibMaxWrPerSend, *pimpl);
+                                                 ibMaxWrPerSend, ibMaxNumSgesPerWr, *pimpl);
     conn = ibConn;
     INFO(MSCCLPP_NET, "IB connection between rank %d(%lx) via %s and remoteRank %d(%lx) created",
          pimpl->bootstrap_->getRank(), pimpl->rankToHash_[pimpl->bootstrap_->getRank()],
diff --git a/src/connection.cc b/src/connection.cc
index 112e11783..d8a055b52 100644
--- a/src/connection.cc
+++ b/src/connection.cc
@@ -83,13 +83,13 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) {
 // IBConnection
 
 IBConnection::IBConnection(int remoteRank, int tag, Transport transport, int maxCqSize, int maxCqPollNum, int maxSendWr,
-                           int maxWrPerSend, Communicator::Impl& commImpl)
+                           int maxWrPerSend, int maxNumSgesPerWr, Communicator::Impl& commImpl)
     : ConnectionBase(remoteRank, tag),
       transport_(transport),
       remoteTransport_(Transport::Unknown),
       numSignaledSends(0),
       dummyAtomicSource_(std::make_unique<uint64_t>(0)) {
-  qp = commImpl.getIbContext(transport)->createQp(maxCqSize, maxCqPollNum, maxSendWr, 0, maxWrPerSend);
+  qp = commImpl.getIbContext(transport)->createQp(maxCqSize, maxCqPollNum, maxSendWr, 0, maxWrPerSend, maxNumSgesPerWr);
   dummyAtomicSourceMem_ = RegisteredMemory(std::make_shared<RegisteredMemory::Impl>(
       dummyAtomicSource_.get(), sizeof(uint64_t), commImpl.bootstrap_->getRank(), transport, commImpl));
   validateTransport(dummyAtomicSourceMem_, transport);
diff --git a/src/ib.cc b/src/ib.cc
index 7a93a650e..cdcfdd04b 100644
--- a/src/ib.cc
+++ b/src/ib.cc
@@ -16,19 +16,21 @@
 #include "api.h"
 #include "debug.h"
 
+static ibv_qp_attr createQpAttr() {
+  ibv_qp_attr qpAttr;
+  std::memset(&qpAttr, 0, sizeof(qpAttr));
+  return qpAttr;
+}
+
 namespace mscclpp {
 
-IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
-  if (size == 0) {
-    throw std::invalid_argument("invalid size: " + std::to_string(size));
-  }
+IbMr::IbMr(ibv_pd* pd, void* buff, size_t alignedSize) : buff(buff) {
   static __thread uintptr_t pageSize = 0;
   if (pageSize == 0) {
     pageSize = sysconf(_SC_PAGESIZE);
   }
   uintptr_t addr = reinterpret_cast<uintptr_t>(buff) & -pageSize;
-  std::size_t pages = (size + (reinterpret_cast<uintptr_t>(buff) - addr) + pageSize - 1) / pageSize;
-  this->mr = ibv_reg_mr(pd, reinterpret_cast<void*>(addr), pages * pageSize,
+  this->mr = ibv_reg_mr(pd, reinterpret_cast<void*>(addr), alignedSize,
                         IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ |
                             IBV_ACCESS_RELAXED_ORDERING | IBV_ACCESS_REMOTE_ATOMIC);
   if (this->mr == nullptr) {
@@ -36,7 +38,7 @@ IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
     err << "ibv_reg_mr failed (errno " << errno << ")";
     throw mscclpp::IbError(err.str(), errno);
   }
-  this->size = pages * pageSize;
+  this->size = alignedSize;
 }
 
 IbMr::~IbMr() { ibv_dereg_mr(this->mr); }
@@ -53,8 +55,8 @@ const void* IbMr::getBuff() const { return this->buff; }
 uint32_t IbMr::getLkey() const { return this->mr->lkey; }
 
 IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr,
-           int maxWrPerSend)
-    : maxCqPollNum(maxCqPollNum), maxWrPerSend(maxWrPerSend) {
+           int maxWrPerSend, int maxNumSgesPerWr)
+    : maxCqPollNum_(maxCqPollNum), maxWrPerSend_(maxWrPerSend), maxNumSgesPerWr_(maxNumSgesPerWr) {
   this->cq = ibv_create_cq(ctx, maxCqSize, nullptr, nullptr, 0);
   if (this->cq == nullptr) {
     std::stringstream err;
@@ -70,8 +72,8 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN
   qpInitAttr.qp_type = IBV_QPT_RC;
   qpInitAttr.cap.max_send_wr = maxSendWr;
   qpInitAttr.cap.max_recv_wr = maxRecvWr;
-  qpInitAttr.cap.max_send_sge = 1;
-  qpInitAttr.cap.max_recv_sge = 1;
+  qpInitAttr.cap.max_send_sge = maxNumSgesPerWr;
+  qpInitAttr.cap.max_recv_sge = maxNumSgesPerWr;
   qpInitAttr.cap.max_inline_data = 0;
 
   struct ibv_qp* _qp = ibv_create_qp(pd, &qpInitAttr);
@@ -105,8 +107,7 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN
     this->info.iid = gid.global.interface_id;
   }
 
-  struct ibv_qp_attr qpAttr;
-  memset(&qpAttr, 0, sizeof(qpAttr));
+  ibv_qp_attr qpAttr = createQpAttr();
   qpAttr.qp_state = IBV_QPS_INIT;
   qpAttr.pkey_index = 0;
   qpAttr.port_num = port;
@@ -117,10 +118,11 @@ IbQp::IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollN
     throw mscclpp::IbError(err.str(), errno);
   }
   this->qp = _qp;
-  this->wrn = 0;
-  this->wrs = std::make_unique<ibv_send_wr[]>(maxWrPerSend);
-  this->sges = std::make_unique<ibv_sge[]>(maxWrPerSend);
-  this->wcs = std::make_unique<ibv_wc[]>(maxCqPollNum);
+  this->wrs = std::make_unique<ibv_send_wr[]>(maxWrPerSend_);
+  this->sges = std::make_unique<ibv_sge[]>((size_t)maxWrPerSend_ * maxNumSgesPerWr_);
+  this->wcs = std::make_unique<ibv_wc[]>(maxCqPollNum_);
+  numStagedWrs_ = 0;
+  numStagedSges_ = 0;
 }
 
 IbQp::~IbQp() {
@@ -129,30 +131,29 @@
 }
 
 void IbQp::rtr(const IbQpInfo& info) {
-  struct ibv_qp_attr qp_attr;
-  std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr));
-  qp_attr.qp_state = IBV_QPS_RTR;
-  qp_attr.path_mtu = static_cast<ibv_mtu>(info.mtu);
-  qp_attr.dest_qp_num = info.qpn;
-  qp_attr.rq_psn = 0;
-  qp_attr.max_dest_rd_atomic = 1;
-  qp_attr.min_rnr_timer = 0x12;
+  ibv_qp_attr qpAttr = createQpAttr();
+  qpAttr.qp_state = IBV_QPS_RTR;
+  qpAttr.path_mtu = static_cast<ibv_mtu>(info.mtu);
+  qpAttr.dest_qp_num = info.qpn;
+  qpAttr.rq_psn = 0;
+  qpAttr.max_dest_rd_atomic = 1;
+  qpAttr.min_rnr_timer = 0x12;
   if (info.linkLayer == IBV_LINK_LAYER_ETHERNET || info.is_grh) {
-    qp_attr.ah_attr.is_global = 1;
-    qp_attr.ah_attr.grh.dgid.global.subnet_prefix = info.spn;
-    qp_attr.ah_attr.grh.dgid.global.interface_id = info.iid;
-    qp_attr.ah_attr.grh.flow_label = 0;
-    qp_attr.ah_attr.grh.sgid_index = 0;
-    qp_attr.ah_attr.grh.hop_limit = 255;
-    qp_attr.ah_attr.grh.traffic_class = 0;
+    qpAttr.ah_attr.is_global = 1;
+    qpAttr.ah_attr.grh.dgid.global.subnet_prefix = info.spn;
+    qpAttr.ah_attr.grh.dgid.global.interface_id = info.iid;
+    qpAttr.ah_attr.grh.flow_label = 0;
+    qpAttr.ah_attr.grh.sgid_index = 0;
+    qpAttr.ah_attr.grh.hop_limit = 255;
+    qpAttr.ah_attr.grh.traffic_class = 0;
   } else {
-    qp_attr.ah_attr.is_global = 0;
+    qpAttr.ah_attr.is_global = 0;
   }
-  qp_attr.ah_attr.dlid = info.lid;
-  qp_attr.ah_attr.sl = 0;
-  qp_attr.ah_attr.src_path_bits = 0;
-  qp_attr.ah_attr.port_num = info.port;
-  int ret = ibv_modify_qp(this->qp, &qp_attr,
+  qpAttr.ah_attr.dlid = info.lid;
+  qpAttr.ah_attr.sl = 0;
+  qpAttr.ah_attr.src_path_bits = 0;
+  qpAttr.ah_attr.port_num = info.port;
+  int ret = ibv_modify_qp(this->qp, &qpAttr,
                           IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
                               IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER);
   if (ret != 0) {
@@ -163,16 +164,15 @@ void IbQp::rtr(const IbQpInfo& info) {
 }
 
 void IbQp::rts() {
-  struct ibv_qp_attr qp_attr;
-  std::memset(&qp_attr, 0, sizeof(struct ibv_qp_attr));
-  qp_attr.qp_state = IBV_QPS_RTS;
-  qp_attr.timeout = 18;
-  qp_attr.retry_cnt = 7;
-  qp_attr.rnr_retry = 7;
-  qp_attr.sq_psn = 0;
-  qp_attr.max_rd_atomic = 1;
+  ibv_qp_attr qpAttr = createQpAttr();
+  qpAttr.qp_state = IBV_QPS_RTS;
+  qpAttr.timeout = 18;
+  qpAttr.retry_cnt = 7;
+  qpAttr.rnr_retry = 7;
+  qpAttr.sq_psn = 0;
+  qpAttr.max_rd_atomic = 1;
   int ret = ibv_modify_qp(
-      this->qp, &qp_attr,
+      this->qp, &qpAttr,
       IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC);
   if (ret != 0) {
     std::stringstream err;
@@ -181,29 +181,34 @@
   }
 }
 
-IbQp::WrInfo IbQp::getNewWrInfo() {
-  if (this->wrn >= this->maxWrPerSend) {
+IbQp::WrInfo IbQp::getNewWrInfo(int numSges) {
+  if (numStagedWrs_ >= maxWrPerSend_) {
+    std::stringstream err;
+    err << "too many outstanding work requests. limit is " << maxWrPerSend_;
+    throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
+  }
+  if (numSges > maxNumSgesPerWr_) {
     std::stringstream err;
-    err << "too many outstanding work requests. limit is " << this->maxWrPerSend;
+    err << "too many sges per work request. limit is " << maxNumSgesPerWr_;
     throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
   }
-  int wrn = this->wrn;
-  ibv_send_wr* wr_ = &this->wrs[wrn];
-  ibv_sge* sge_ = &this->sges[wrn];
+  ibv_send_wr* wr_ = &this->wrs[numStagedWrs_];
+  ibv_sge* sge_ = &this->sges[numStagedSges_];
   wr_->sg_list = sge_;
-  wr_->num_sge = 1;
+  wr_->num_sge = numSges;
   wr_->next = nullptr;
-  if (wrn > 0) {
-    this->wrs[wrn - 1].next = wr_;
+  if (numStagedWrs_ > 0) {
+    this->wrs[numStagedWrs_ - 1].next = wr_;
   }
-  this->wrn++;
+  numStagedWrs_++;
+  numStagedSges_ += numSges;
   return IbQp::WrInfo{wr_, sge_};
 }
 
-void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
-                     uint64_t dstOffset, bool signaled) {
-  auto wrInfo = this->getNewWrInfo();
+void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint32_t srcOffset,
+                     uint32_t dstOffset, bool signaled) {
+  auto wrInfo = this->getNewWrInfo(1);
   wrInfo.wr->wr_id = wrId;
   wrInfo.wr->opcode = IBV_WR_RDMA_WRITE;
   wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
@@ -214,8 +219,8 @@ void IbQp::stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64
   wrInfo.sge->lkey = mr->getLkey();
 }
 
-void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint64_t dstOffset, uint64_t addVal) {
-  auto wrInfo = this->getNewWrInfo();
+void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint32_t dstOffset, uint64_t addVal) {
+  auto wrInfo = this->getNewWrInfo(1);
   wrInfo.wr->wr_id = wrId;
   wrInfo.wr->opcode = IBV_WR_ATOMIC_FETCH_AND_ADD;
   wrInfo.wr->send_flags = 0;  // atomic op cannot be signaled
@@ -227,9 +232,9 @@ void IbQp::stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, u
   wrInfo.sge->lkey = mr->getLkey();
 }
 
-void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
-                            uint64_t dstOffset, bool signaled, unsigned int immData) {
-  auto wrInfo = this->getNewWrInfo();
+void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint32_t srcOffset,
+                            uint32_t dstOffset, bool signaled, unsigned int immData) {
+  auto wrInfo = this->getNewWrInfo(1);
   wrInfo.wr->wr_id = wrId;
   wrInfo.wr->opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
   wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
@@ -241,8 +246,31 @@ void IbQp::stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size,
   wrInfo.sge->lkey = mr->getLkey();
 }
 
+void IbQp::stageSendGather(const std::vector<const IbMr*>& srcMrList, const IbMrInfo& dstMrInfo,
+                           const std::vector<uint32_t>& srcSizeList, uint64_t wrId,
+                           const std::vector<uint32_t>& srcOffsetList, uint32_t dstOffset, bool signaled) {
+  size_t numSrcs = srcMrList.size();
+  if (numSrcs != srcSizeList.size() || numSrcs != srcOffsetList.size()) {
+    std::stringstream err;
+    err << "invalid srcs: srcMrList.size()=" << numSrcs << ", srcSizeList.size()=" << srcSizeList.size()
+        << ", srcOffsetList.size()=" << srcOffsetList.size();
+    throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
+  }
+  auto wrInfo = this->getNewWrInfo(numSrcs);
+  wrInfo.wr->wr_id = wrId;
+  wrInfo.wr->opcode = IBV_WR_RDMA_WRITE;
+  wrInfo.wr->send_flags = signaled ? IBV_SEND_SIGNALED : 0;
+  wrInfo.wr->wr.rdma.remote_addr = (uint64_t)(dstMrInfo.addr) + dstOffset;
+  wrInfo.wr->wr.rdma.rkey = dstMrInfo.rkey;
+  for (size_t i = 0; i < numSrcs; ++i) {
+    wrInfo.sge[i].addr = (uint64_t)(srcMrList[i]->getBuff()) + srcOffsetList[i];
+    wrInfo.sge[i].length = srcSizeList[i];
+    wrInfo.sge[i].lkey = srcMrList[i]->getLkey();
+  }
+}
+
 void IbQp::postSend() {
-  if (this->wrn == 0) {
+  if (numStagedWrs_ == 0) {
     return;
   }
   struct ibv_send_wr* bad_wr;
@@ -252,7 +280,8 @@
     err << "ibv_post_send failed (errno " << errno << ")";
     throw mscclpp::IbError(err.str(), errno);
   }
-  this->wrn = 0;
+  numStagedWrs_ = 0;
+  numStagedSges_ = 0;
 }
 
 void IbQp::postRecv(uint64_t wrId) {
@@ -269,7 +298,7 @@
   }
 }
 
-int IbQp::pollCq() { return ibv_poll_cq(this->cq, this->maxCqPollNum, this->wcs.get()); }
+int IbQp::pollCq() { return ibv_poll_cq(this->cq, maxCqPollNum_, this->wcs.get()); }
 
 IbQpInfo& IbQp::getInfo() { return this->info; }
 
@@ -296,6 +325,13 @@ IbCtx::IbCtx(const std::string& devName) : devName(devName) {
     err << "ibv_alloc_pd failed (errno " << errno << ")";
     throw mscclpp::IbError(err.str(), errno);
   }
+  // TODO: do not use new
+  this->devAttr = new ibv_device_attr;
+  if (ibv_query_device(this->ctx, this->devAttr) != 0) {
+    std::stringstream err;
+    err << "ibv_query_device failed (errno " << errno << ")";
+    throw mscclpp::IbError(err.str(), errno);
+  }
 }
 
 IbCtx::~IbCtx() {
@@ -307,6 +343,8 @@ IbCtx::~IbCtx() {
   if (this->ctx != nullptr) {
     ibv_close_device(this->ctx);
   }
+  // TODO: do not use delete
+  delete this->devAttr;
 }
 
 bool IbCtx::isPortUsable(int port) const {
@@ -321,13 +359,7 @@ bool IbCtx::isPortUsable(int port) const {
 }
 
 int IbCtx::getAnyActivePort() const {
-  struct ibv_device_attr devAttr;
-  if (ibv_query_device(this->ctx, &devAttr) != 0) {
-    std::stringstream err;
-    err << "ibv_query_device failed (errno " << errno << ")";
-    throw mscclpp::IbError(err.str(), errno);
-  }
-  for (uint8_t port = 1; port <= devAttr.phys_port_cnt; ++port) {
+  for (uint8_t port = 1; port <= this->devAttr->phys_port_cnt; ++port) {
     if (this->isPortUsable(port)) {
       return port;
     }
@@ -335,22 +367,64 @@ int IbCtx::getAnyActivePort() const {
   return -1;
 }
 
+void IbCtx::validateQpConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend,
+                             int maxNumSgesPerWr, int port) const {
+  if (!this->isPortUsable(port)) {
+    throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InvalidUsage);
+  }
+  if (maxCqSize > this->devAttr->max_cqe || maxCqSize < 1) {
+    throw mscclpp::Error("invalid maxCqSize: " + std::to_string(maxCqSize), ErrorCode::InvalidUsage);
+  }
+  if (maxCqPollNum > maxCqSize || maxCqPollNum < 1) {
+    throw mscclpp::Error("invalid maxCqPollNum: " + std::to_string(maxCqPollNum), ErrorCode::InvalidUsage);
+  }
+  if (maxSendWr > this->devAttr->max_qp_wr) {
+    throw mscclpp::Error("invalid maxSendWr: " + std::to_string(maxSendWr), ErrorCode::InvalidUsage);
+  }
+  if (maxRecvWr > this->devAttr->max_qp_wr) {
+    throw mscclpp::Error("invalid maxRecvWr: " + std::to_string(maxRecvWr), ErrorCode::InvalidUsage);
+  }
+  if (maxWrPerSend > maxSendWr || maxWrPerSend < 1) {
+    throw mscclpp::Error("invalid maxWrPerSend: " + std::to_string(maxWrPerSend), ErrorCode::InvalidUsage);
+  }
+  if (maxNumSgesPerWr > this->devAttr->max_sge || maxNumSgesPerWr < 1) {
+    throw mscclpp::Error("invalid maxNumSgesPerWr: " + std::to_string(maxNumSgesPerWr), ErrorCode::InvalidUsage);
+  }
+}
+
 IbQp* IbCtx::createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend,
-                      int port /*=-1*/) {
+                      int maxNumSgesPerWr, int port /*=-1*/) {
   if (port == -1) {
     port = this->getAnyActivePort();
     if (port == -1) {
       throw mscclpp::Error("No active port found", ErrorCode::InternalError);
     }
-  } else if (!this->isPortUsable(port)) {
-    throw mscclpp::Error("invalid IB port: " + std::to_string(port), ErrorCode::InternalError);
   }
-  qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend));
+  this->validateQpConfig(maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend, maxNumSgesPerWr, port);
+  qps.emplace_back(new IbQp(this->ctx, this->pd, port, maxCqSize, maxCqPollNum, maxSendWr, maxRecvWr, maxWrPerSend,
+                            maxNumSgesPerWr));
   return qps.back().get();
 }
 
-const IbMr* IbCtx::registerMr(void* buff, std::size_t size) {
-  mrs.emplace_back(new IbMr(this->pd, buff, size));
+const IbMr* IbCtx::registerMr(void* buff, uint32_t size) {
+  if (size == 0) {
+    throw mscclpp::Error("invalid size: " + std::to_string(size), ErrorCode::InvalidUsage);
+  }
+  static __thread uintptr_t pageSize = 0;
+  if (pageSize == 0) {
+    pageSize = sysconf(_SC_PAGESIZE);
+  }
+  uintptr_t addr = reinterpret_cast<uintptr_t>(buff) & -pageSize;
+  size_t pages = (size + (reinterpret_cast<uintptr_t>(buff) - addr) + pageSize - 1) / pageSize;
+
+  size_t alignedSize = pages * pageSize;
+  if (alignedSize > this->devAttr->max_mr_size) {
+    std::stringstream err;
+    err << "invalid MR size: " << alignedSize << " max " << this->devAttr->max_mr_size;
+    throw mscclpp::Error(err.str(), ErrorCode::InvalidUsage);
+  }
+
+  mrs.emplace_back(new IbMr(this->pd, buff, alignedSize));
   return mrs.back().get();
 }
diff --git a/src/include/connection.hpp b/src/include/connection.hpp
index 0475691c9..106b86d99 100644
--- a/src/include/connection.hpp
+++ b/src/include/connection.hpp
@@ -57,7 +57,7 @@ class IBConnection : public ConnectionBase {
 
  public:
   IBConnection(int remoteRank, int tag, Transport transport, int maxCqSize, int maxCqPollNum, int maxSendWr,
-               int maxWrPerSend, Communicator::Impl& commImpl);
+               int maxWrPerSend, int maxNumSgesPerWr, Communicator::Impl& commImpl);
 
   Transport transport() override;
 
diff --git a/src/include/ib.hpp b/src/include/ib.hpp
index 1bec30b85..440c208ce 100644
--- a/src/include/ib.hpp
+++ b/src/include/ib.hpp
@@ -7,9 +7,11 @@
 #include <list>
 #include <memory>
 #include <string>
+#include <vector>
 
 // Forward declarations of IB structures
 struct ibv_context;
+struct ibv_device_attr;
 struct ibv_pd;
 struct ibv_mr;
 struct ibv_qp;
@@ -34,11 +36,11 @@ class IbMr {
   uint32_t getLkey() const;
 
  private:
-  IbMr(ibv_pd* pd, void* buff, std::size_t size);
+  IbMr(ibv_pd* pd, void* buff, size_t alignedSize);
 
   ibv_mr* mr;
   void* buff;
-  std::size_t size;
+  size_t size;
 
   friend class IbCtx;
 };
@@ -61,11 +63,14 @@ class IbQp {
   void rtr(const IbQpInfo& info);
   void rts();
-  void stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
-                 uint64_t dstOffset, bool signaled);
-  void stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint64_t dstOffset, uint64_t addVal);
-  void stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint64_t srcOffset,
-                        uint64_t dstOffset, bool signaled, unsigned int immData);
+  void stageSend(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint32_t srcOffset,
+                 uint32_t dstOffset, bool signaled);
+  void stageAtomicAdd(const IbMr* mr, const IbMrInfo& info, uint64_t wrId, uint32_t dstOffset, uint64_t addVal);
+  void stageSendWithImm(const IbMr* mr, const IbMrInfo& info, uint32_t size, uint64_t wrId, uint32_t srcOffset,
+                        uint32_t dstOffset, bool signaled, unsigned int immData);
+  void stageSendGather(const std::vector<const IbMr*>& srcMrList, const IbMrInfo& dstMrInfo,
+                       const std::vector<uint32_t>& srcSizeList, uint64_t wrId,
+                       const std::vector<uint32_t>& srcOffsetList, uint32_t dstOffset, bool signaled);
   void postSend();
   void postRecv(uint64_t wrId);
   int pollCq();
@@ -80,8 +85,8 @@ class IbQp {
   };
 
   IbQp(ibv_context* ctx, ibv_pd* pd, int port, int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr,
-       int maxWrPerSend);
-  WrInfo getNewWrInfo();
+       int maxWrPerSend, int maxNumSgesPerWr);
+  WrInfo getNewWrInfo(int numSges);
 
   IbQpInfo info;
 
@@ -90,10 +95,12 @@ class IbQp {
   std::unique_ptr<ibv_wc[]> wcs;
   std::unique_ptr<ibv_send_wr[]> wrs;
  std::unique_ptr<ibv_sge[]> sges;
-  int wrn;
+  int numStagedWrs_;
+  int numStagedSges_;
 
-  const int maxCqPollNum;
-  const int maxWrPerSend;
+  const int maxCqPollNum_;
+  const int maxWrPerSend_;
+  const int maxNumSgesPerWr_;
 
   friend class IbCtx;
 };
@@ -103,18 +110,22 @@ class IbCtx {
   IbCtx(const std::string& devName);
   ~IbCtx();
 
-  IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int port = -1);
-  const IbMr* registerMr(void* buff, std::size_t size);
+  IbQp* createQp(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend, int maxNumSgesPerWr,
+                 int port = -1);
+  const IbMr* registerMr(void* buff, uint32_t size);
 
   const std::string& getDevName() const;
 
  private:
   bool isPortUsable(int port) const;
   int getAnyActivePort() const;
+  void validateQpConfig(int maxCqSize, int maxCqPollNum, int maxSendWr, int maxRecvWr, int maxWrPerSend,
+                        int maxNumSgesPerWr, int port) const;
 
   const std::string devName;
   ibv_context* ctx;
   ibv_pd* pd;
+  ibv_device_attr* devAttr;
   std::list<std::unique_ptr<IbQp>> qps;
   std::list<std::unique_ptr<IbMr>> mrs;
 };
diff --git a/test/mp_unit/ib_tests.cu b/test/mp_unit/ib_tests.cu
index 7ab892b51..d4c5539a8 100644
--- a/test/mp_unit/ib_tests.cu
+++ b/test/mp_unit/ib_tests.cu
@@ -36,41 +36,59 @@ void IbPeerToPeerTest::SetUp() {
   bootstrap->initialize(id);
 
   ibCtx = std::make_shared<mscclpp::IbCtx>(ibDevName);
-  qp = ibCtx->createQp(1024, 1, 8192, 0, 64);
+  qp = ibCtx->createQp(1024, 1, 8192, 0, 64, 4);
 
-  qpInfo[gEnv->rank] = qp->getInfo();
-  bootstrap->allGather(qpInfo.data(), sizeof(mscclpp::IbQpInfo));
+  int remoteRank = (gEnv->rank == 0) ? 1 : 0;
+
+  mscclpp::IbQpInfo localQpInfo = qp->getInfo();
+  bootstrap->send(&localQpInfo, sizeof(mscclpp::IbQpInfo), remoteRank, /*tag=*/0);
+  bootstrap->recv(&remoteQpInfo, sizeof(mscclpp::IbQpInfo), remoteRank, /*tag=*/0);
 }
 
-void IbPeerToPeerTest::registerBufferAndConnect(void* buf, size_t size) {
-  bufSize = size;
-  mr = ibCtx->registerMr(buf, size);
-  mrInfo[gEnv->rank] = mr->getInfo();
-  bootstrap->allGather(mrInfo.data(), sizeof(mscclpp::IbMrInfo));
-
-  for (int i = 0; i < bootstrap->getNranks(); ++i) {
-    if (i == gEnv->rank) continue;
-    qp->rtr(qpInfo[i]);
-    qp->rts();
-    break;
+void IbPeerToPeerTest::registerBuffersAndConnect(const std::vector<void*>& bufList,
+                                                 const std::vector<size_t>& sizeList) {
+  size_t numMrs = bufList.size();
+  if (numMrs != sizeList.size()) {
+    throw std::runtime_error("bufList.size() != sizeList.size()");
+  }
+
+  // Assume the remote side registers the same number of MRs
+  std::vector<mscclpp::IbMrInfo> localMrInfo;
+  for (size_t i = 0; i < numMrs; ++i) {
+    const mscclpp::IbMr* mr = ibCtx->registerMr(bufList[i], sizeList[i]);
+    localMrList.push_back(mr);
+    localMrInfo.emplace_back(mr->getInfo());
   }
+
+  int remoteRank = (gEnv->rank == 0) ? 1 : 0;
+
+  // Send the number of MRs and the MR info to the remote side
+  bootstrap->send(&numMrs, sizeof(numMrs), remoteRank, /*tag=*/0);
+  bootstrap->send(localMrInfo.data(), sizeof(mscclpp::IbMrInfo) * numMrs, remoteRank, /*tag=*/1);
+
+  // Receive the number of MRs and the MR info from the remote side
+  size_t numRemoteMrs;
+  bootstrap->recv(&numRemoteMrs, sizeof(numRemoteMrs), remoteRank, /*tag=*/0);
+  remoteMrInfoList.resize(numRemoteMrs);
+  bootstrap->recv(remoteMrInfoList.data(), sizeof(mscclpp::IbMrInfo) * numRemoteMrs, remoteRank, /*tag=*/1);
+
+  qp->rtr(remoteQpInfo);
+  qp->rts();
+
   bootstrap->barrier();
 }
 
 void IbPeerToPeerTest::stageSend(uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled) {
-  const mscclpp::IbMrInfo& remoteMrInfo = mrInfo[(gEnv->rank == 1) ? 0 : 1];
-  qp->stageSend(mr, remoteMrInfo, size, wrId, srcOffset, dstOffset, signaled);
+  qp->stageSend(localMrList[0], remoteMrInfoList[0], size, wrId, srcOffset, dstOffset, signaled);
 }
 
 void IbPeerToPeerTest::stageAtomicAdd(uint64_t wrId, uint64_t dstOffset, uint64_t addVal) {
-  const mscclpp::IbMrInfo& remoteMrInfo = mrInfo[(gEnv->rank == 1) ? 0 : 1];
-  qp->stageAtomicAdd(mr, remoteMrInfo, wrId, dstOffset, addVal);
+  qp->stageAtomicAdd(localMrList[0], remoteMrInfoList[0], wrId, dstOffset, addVal);
 }
 
 void IbPeerToPeerTest::stageSendWithImm(uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset,
                                         bool signaled, unsigned int immData) {
-  const mscclpp::IbMrInfo& remoteMrInfo = mrInfo[(gEnv->rank == 1) ? 0 : 1];
-  qp->stageSendWithImm(mr, remoteMrInfo, size, wrId, srcOffset, dstOffset, signaled, immData);
+  qp->stageSendWithImm(localMrList[0], remoteMrInfoList[0], size, wrId, srcOffset, dstOffset, signaled, immData);
 }
 
 TEST_F(IbPeerToPeerTest, SimpleSendRecv) {
@@ -85,7 +103,7 @@
 
   const int nelem = 1;
   auto data = mscclpp::allocUniqueCuda<int>(nelem);
-  registerBufferAndConnect(data.get(), sizeof(int) * nelem);
+  registerBuffersAndConnect({data.get()}, {sizeof(int) * nelem});
 
   if (gEnv->rank == 1) {
     mscclpp::Timer timer;
@@ -194,7 +212,7 @@ TEST_F(IbPeerToPeerTest, MemoryConsistency) {
 
   const uint64_t nelem = 65536 + 1;
   auto data = mscclpp::allocUniqueCuda<uint64_t>(nelem);
-  registerBufferAndConnect(data.get(), sizeof(uint64_t) * nelem);
+  registerBuffersAndConnect({data.get()}, {sizeof(uint64_t) * nelem});
 
   uint64_t res = 0;
   uint64_t iter = 0;
@@ -288,3 +306,106 @@ TEST_F(IbPeerToPeerTest, MemoryConsistency) {
 
   EXPECT_EQ(res, 0);
 }
+
+TEST_F(IbPeerToPeerTest, SendGather) {
+  if (gEnv->rank >= 2) {
+    // This test needs only two ranks
+    return;
+  }
+
+  mscclpp::Timer timeout(3);
+
+  const int numDataSrcs = 4;
+  const int nelemPerMr = 1024;
+
+  // Gather send from rank 0 to 1
+  if (gEnv->rank == 0) {
+    std::vector<mscclpp::UniqueCudaPtr<int>> dataList;
+    for (int i = 0; i < numDataSrcs; ++i) {
+      auto data = mscclpp::allocUniqueCuda<int>(nelemPerMr);
+      // Fill in data for correctness check
+      std::vector<int> hostData(nelemPerMr, i + 1);
+      mscclpp::memcpyCuda(data.get(), hostData.data(), nelemPerMr);
+      dataList.emplace_back(std::move(data));
+    }
+
+    std::vector<void*> dataRefList;
+    for (int i = 0; i < numDataSrcs; ++i) {
+      dataRefList.emplace_back(dataList[i].get());
+    }
+
+    // For sending a completion signal to the remote side
+    uint64_t outboundSema = 1;
+
+    dataRefList.push_back(&outboundSema);
+
+    std::vector<size_t> sizeList(numDataSrcs, sizeof(int) * nelemPerMr);
+    sizeList.push_back(sizeof(outboundSema));
+
+    registerBuffersAndConnect(dataRefList, sizeList);
+
+    auto& remoteDataMrInfo = remoteMrInfoList[0];
+    auto& remoteSemaMrInfo = remoteMrInfoList[1];
+    auto& localSemaMr = localMrList[numDataSrcs];
+
+    std::vector<const mscclpp::IbMr*> gatherLocalMrList;
+    for (int i = 0; i < numDataSrcs; ++i) {
+      gatherLocalMrList.emplace_back(localMrList[i]);
+    }
+    std::vector<uint32_t> gatherSizeList(numDataSrcs, sizeof(int) * nelemPerMr);
+    std::vector<uint32_t> gatherOffsetList(numDataSrcs, 0);
+
+    qp->stageSendGather(gatherLocalMrList, remoteDataMrInfo, gatherSizeList, /*wrId=*/0, gatherOffsetList,
+                        /*dstOffset=*/0, /*signaled=*/true);
+    qp->postSend();
+
+    qp->stageAtomicAdd(localSemaMr, remoteSemaMrInfo, /*wrId=*/0, /*dstOffset=*/0, /*addVal=*/1);
+    qp->postSend();
+
+    // Wait for send completion
+    bool waiting = true;
+    int spin = 0;
+    while (waiting) {
+      int wcNum = qp->pollCq();
+      ASSERT_GE(wcNum, 0);
+      for (int i = 0; i < wcNum; ++i) {
+        const ibv_wc* wc = qp->getWc(i);
+        EXPECT_EQ(wc->status, IBV_WC_SUCCESS);
+        waiting = false;
+        break;
+      }
+      if (spin++ > 1000000) {
+        FAIL() << "Polling is stuck.";
+      }
+    }
+  } else {
+    // Data array to receive
+    auto data = mscclpp::allocUniqueCuda<int>(nelemPerMr * numDataSrcs);
+
+    // For receiving a completion signal from the remote side
+    uint64_t inboundSema = 0;
+
+    registerBuffersAndConnect({data.get(), &inboundSema},
+                              {sizeof(int) * nelemPerMr * numDataSrcs, sizeof(inboundSema)});
+
+    // Wait for a signal from the remote side
+    volatile uint64_t* ptrInboundSema = &inboundSema;
+    int spin = 0;
+    while (*ptrInboundSema == 0) {
+      if (spin++ > 1000000) {
+        FAIL() << "Polling is stuck.";
+      }
+    }
+
+    // Correctness check
+    std::vector<int> hostData(nelemPerMr * numDataSrcs);
+    mscclpp::memcpyCuda(hostData.data(), data.get(), nelemPerMr * numDataSrcs);
+    for (int i = 0; i < numDataSrcs; ++i) {
+      for (int j = 0; j < nelemPerMr; ++j) {
+        EXPECT_EQ(hostData[i * nelemPerMr + j], i + 1);
+      }
+    }
+  }
+
+  bootstrap->barrier();
+}
diff --git a/test/mp_unit/mp_unit_tests.hpp b/test/mp_unit/mp_unit_tests.hpp
index 393255638..0ec62e79f 100644
--- a/test/mp_unit/mp_unit_tests.hpp
+++ b/test/mp_unit/mp_unit_tests.hpp
@@ -67,7 +67,7 @@ class IbPeerToPeerTest : public IbTestBase {
  protected:
   void SetUp() override;
 
-  void registerBufferAndConnect(void* buf, size_t size);
+  void registerBuffersAndConnect(const std::vector<void*>& bufList, const std::vector<size_t>& sizeList);
 
   void stageSend(uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled);
 
@@ -76,14 +76,16 @@ class IbPeerToPeerTest : public IbTestBase {
   void stageSendWithImm(uint32_t size, uint64_t wrId, uint64_t srcOffset, uint64_t dstOffset, bool signaled,
                         unsigned int immData);
 
+  void stageSendGather(const std::vector<uint32_t>& sizeList, uint64_t wrId,
+                       const std::vector<uint32_t>& srcOffsetList, uint32_t dstOffset, bool signaled);
+
   std::shared_ptr<mscclpp::TcpBootstrap> bootstrap;
   std::shared_ptr<mscclpp::IbCtx> ibCtx;
   mscclpp::IbQp* qp;
-  const mscclpp::IbMr* mr;
-  size_t bufSize;
+  std::vector<const mscclpp::IbMr*> localMrList;
 
-  std::array qpInfo;
-  std::array mrInfo;
+  mscclpp::IbQpInfo remoteQpInfo;
+  std::vector<mscclpp::IbMrInfo> remoteMrInfoList;
 };
 
 class CommunicatorTestBase : public MultiProcessTest {
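
Usage note (illustrative, not part of the patch): the sketch below shows one way caller code could drive the multi-SGE path added here, using only APIs defined in this diff (IbCtx::registerMr, IbQp::stageSendGather, postSend, pollCq). It mirrors the SendGather test above. The helper name gatherWriteExample, the include path, and the assumption that the QP was created with a sufficient maxNumSgesPerWr and already connected (rtr/rts) with the remote IbMrInfo exchanged out of band are all assumptions for the sake of the example.

// Illustrative sketch only; assumes QP setup and MR-info exchange as in IbPeerToPeerTest.
#include <vector>

#include "ib.hpp"

void gatherWriteExample(mscclpp::IbCtx& ibCtx, mscclpp::IbQp* qp, const std::vector<void*>& srcBufs,
                        const std::vector<uint32_t>& srcSizes, const mscclpp::IbMrInfo& remoteDstMrInfo) {
  // Register each local source buffer; registerMr page-aligns the size and
  // checks it against ibv_device_attr::max_mr_size.
  std::vector<const mscclpp::IbMr*> srcMrs;
  for (size_t i = 0; i < srcBufs.size(); ++i) {
    srcMrs.push_back(ibCtx.registerMr(srcBufs[i], srcSizes[i]));
  }

  // Stage a single RDMA write whose work request carries srcBufs.size() SGEs,
  // written contiguously at offset 0 of the remote MR. The QP must have been
  // created with maxNumSgesPerWr >= srcBufs.size(), e.g.
  // ibCtx.createQp(1024, 1, 8192, 0, 64, /*maxNumSgesPerWr=*/4).
  std::vector<uint32_t> srcOffsets(srcBufs.size(), 0);
  qp->stageSendGather(srcMrs, remoteDstMrInfo, srcSizes, /*wrId=*/0, srcOffsets,
                      /*dstOffset=*/0, /*signaled=*/true);
  qp->postSend();

  // Poll the completion queue until the signaled write completes.
  while (qp->pollCq() == 0) {
  }
}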