diff --git a/.gitmodules b/.gitmodules index 7207ef9b6..1cfc7c11d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -11,3 +11,6 @@ [submodule "third_party/libnop"] path = third_party/libnop url = https://github.com/google/libnop.git +[submodule "third_party/libfabric"] + path = third_party/libfabric + url = https://github.com/ofiwg/libfabric diff --git a/tensorpipe/CMakeLists.txt b/tensorpipe/CMakeLists.txt index 5c3606471..acbdc6188 100644 --- a/tensorpipe/CMakeLists.txt +++ b/tensorpipe/CMakeLists.txt @@ -164,6 +164,27 @@ if(TP_ENABLE_IBV) set(TENSORPIPE_HAS_IBV_TRANSPORT 1) endif() +### EFA + +tp_conditional_backend( + TP_ENABLE_EFA "Enable EFA transport" "LINUX") +if(TP_ENABLE_EFA) + list(APPEND TP_SRCS + transport/efa/connection_impl.cc + transport/efa/context_impl.cc + transport/efa/error.cc + transport/efa/factory.cc + transport/efa/listener_impl.cc + transport/efa/reactor.cc + transport/efa/sockaddr.cc + transport/efa/utility.cc) + list(APPEND TP_PUBLIC_HDRS + transport/efa/error.h + transport/efa/factory.h + transport/efa/utility.h) + set(TENSORPIPE_HAS_EFA_TRANSPORT 1) + list(APPEND TP_INCLUDE_DIRS $) +endif() ## MAC OS specific library deps diff --git a/tensorpipe/benchmark/transport_registry.cc b/tensorpipe/benchmark/transport_registry.cc index b18778a19..d033a5755 100644 --- a/tensorpipe/benchmark/transport_registry.cc +++ b/tensorpipe/benchmark/transport_registry.cc @@ -42,6 +42,16 @@ std::shared_ptr makeUvContext() { TP_REGISTER_CREATOR(TensorpipeTransportRegistry, uv, makeUvContext); +// EFA + +#if TENSORPIPE_HAS_EFA_TRANSPORT +std::shared_ptr makeEfaContext() { + return tensorpipe::transport::efa::create(); +} + +TP_REGISTER_CREATOR(TensorpipeTransportRegistry, efa, makeEfaContext); +#endif // TENSORPIPE_HAS_EFA_TRANSPORT + void validateTransportContext( std::shared_ptr context) { if (!context) { diff --git a/tensorpipe/common/efa.h b/tensorpipe/common/efa.h new file mode 100644 index 000000000..215dd0180 --- /dev/null +++ b/tensorpipe/common/efa.h @@ -0,0 +1,259 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace tensorpipe { + +static const int FABRIC_VERSION = FI_VERSION(1, 10); + +#define TP_CHECK_EFA_RET(ret, msg) \ + do { \ + if (ret < 0) { \ + TP_THROW_ASSERT() << msg << ". Return Code: " << ret; \ + } \ + } while (false) + +struct FabricDeleter { + void operator()(fi_info* info) { + if (info) + efaLib->fi_freeinfo_op(info); + } + void operator()(fid* fid) { + if (fid) + fi_close(fid); + } + void operator()(fid_domain* fid) { + if (fid) + fi_close((fid_t)fid); + } + void operator()(fid_fabric* fid) { + if (fid) + fi_close((fid_t)fid); + } + void operator()(fid_cq* fid) { + if (fid) + fi_close((fid_t)fid); + } + void operator()(fid_av* fid) { + if (fid) + fi_close((fid_t)fid); + } + void operator()(fid_ep* fid) { + if (fid) + fi_close((fid_t)fid); + } + void operator()(fid_eq* fid) { + if (fid) + fi_close((fid_t)fid); + } + + EfaLib* efaLib; +}; + +template +using UniqueFabricPtr = std::unique_ptr; + +struct EfaAddress { + // endpoint name + char name[64] = {}; + // length of endpoint name + size_t len = sizeof(name); + + std::string debugStr() const { + std::stringstream ss; + ss << "["; + for (size_t i = 0; i < len; i++) { + ss << std::to_string(name[i]) << ","; + } + ss << "]"; + return ss.str(); + } + + std::string str() const { + return std::string(name, len); + } + + void copyFrom(void* epName, const size_t epNameLen) { + len = epNameLen; + memcpy(name, epName, sizeof(name)); + } + + void copyTo(char* epName, size_t* epNameLen) { + *(epNameLen) = len; + memcpy(epName, name, sizeof(name)); + } +}; + +inline EfaLib::device* getEfaDevices(EfaLib& efaLib) { + EfaLib::device* hints = efaLib.fi_dupinfo_op((const fi_info*)NULL); + hints->mode = FI_CONTEXT; + hints->ep_attr->type = FI_EP_RDM; // Reliable Datagram + hints->caps = FI_TAGGED | FI_MSG | FI_REMOTE_COMM | FI_DIRECTED_RECV | + FI_LOCAL_COMM | FI_SOURCE; + hints->tx_attr->msg_order = FI_ORDER_SAS; + hints->rx_attr->msg_order = FI_ORDER_SAS; + hints->domain_attr->control_progress = FI_PROGRESS_AUTO; + hints->domain_attr->data_progress = FI_PROGRESS_AUTO; + hints->domain_attr->caps = + FI_LOCAL_COMM | FI_REMOTE_COMM; // Enable local loopback + hints->domain_attr->av_type = FI_AV_TABLE; + hints->fabric_attr->prov_name = strdup("efa"); + // info. + struct fi_info* info_; + int ret = + efaLib.fi_getinfo_op(FABRIC_VERSION, nullptr, nullptr, 0, hints, &info_); + return info_; +} + +using EfaFabric = UniqueFabricPtr; +inline EfaFabric createEfaFabric(EfaLib& efaLib, EfaLib::device* info) { + struct fid_fabric* fabric_; + int ret = efaLib.fi_fabric_op(info->fabric_attr, &fabric_, nullptr); + TP_CHECK_EFA_RET(ret, "Couldn't open a fabric provider"); + return EfaFabric(fabric_, FabricDeleter{&efaLib}); +} + +using EfaDomain = UniqueFabricPtr; +inline EfaDomain createEfaDomain( + EfaLib& efaLib, + EfaFabric& fabric, + EfaLib::device* info) { + struct fid_domain* domain_; + int ret = fi_domain(fabric.get(), info, &domain_, nullptr); + TP_CHECK_EFA_RET(ret, "Couldn't open a fabric access domain"); + return EfaDomain(domain_, FabricDeleter{&efaLib}); +} + +using EfaEndpoint = UniqueFabricPtr; +inline EfaEndpoint createEfaEndpoint( + EfaLib& efaLib, + EfaDomain& domain, + EfaLib::device* info) { + struct fid_ep* ep_; + int ret = fi_endpoint(domain.get(), info, &ep_, nullptr); + TP_CHECK_EFA_RET(ret, "Couldn't allocate endpoint"); + return EfaEndpoint(ep_, FabricDeleter{&efaLib}); +} + +using EfaCompletionQueue = UniqueFabricPtr; +inline EfaCompletionQueue createEfaCompletionQueue( + EfaLib& efaLib, + EfaDomain& domain, + EfaLib::device* info) { + struct fid_cq* cq_; + struct fi_cq_attr cq_attr = {}; + cq_attr.format = FI_CQ_FORMAT_TAGGED; + cq_attr.size = info->rx_attr->size; + int ret = fi_cq_open(domain.get(), &cq_attr, &cq_, nullptr); + TP_CHECK_EFA_RET(ret, "Couldn't open CQ"); + return EfaCompletionQueue(cq_, FabricDeleter{&efaLib}); +} + +using EfaAdressVector = UniqueFabricPtr; +inline EfaAdressVector createEfaAdressVector( + EfaLib& efaLib, + EfaDomain& domain) { + struct fi_av_attr av_attr = {}; + struct fid_av* av_; + int ret = fi_av_open(domain.get(), &av_attr, &av_, nullptr); + TP_CHECK_EFA_RET(ret, "Couldn't open AV"); + return EfaAdressVector(av_, FabricDeleter{&efaLib}); +} + +inline EfaAddress enableEndpoint( + EfaLib& efaLib, + EfaEndpoint& ep, + EfaAdressVector& av, + EfaCompletionQueue& cq) { + // fi_ep_bind: bind CQ and AV to the endpoint + int ret; + ret = fi_ep_bind(ep.get(), (fid_t)cq.get(), FI_RECV | FI_TRANSMIT); + TP_CHECK_EFA_RET(ret, "Couldn't bind EP-CQ"); + ret = fi_ep_bind(ep.get(), (fid_t)av.get(), 0); + TP_CHECK_EFA_RET(ret, "Couldn't bind EP-AV"); + + // fi_enable: enable endpoint for communication + ret = fi_enable(ep.get()); + TP_CHECK_EFA_RET(ret, "Couldn't enable endpoint"); + + // fi_getname: get endpoint name + EfaAddress addr; + ret = fi_getname((fid_t)ep.get(), addr.name, &addr.len); + TP_CHECK_EFA_RET(ret, "Call to fi_getname() failed"); + return addr; +} + +class EfaDeviceList { + private: + EfaDeviceList(EfaLib& efaLib, EfaLib::device* ptr, int size) + : deviceList_(ptr, Deleter{&efaLib}), size_(size) {} + + public: + EfaDeviceList() = default; + + static std::tuple create(EfaLib& efaLib) { + int size; + EfaLib::device* ptr = getEfaDevices(efaLib); + EfaLib::device* firstDevice = ptr; + if (ptr == nullptr) { + return std::make_tuple( + TP_CREATE_ERROR(SystemError, "fi_getinfo", -1), EfaDeviceList()); + } + size = 1; + while (ptr->next != nullptr) { + ptr = ptr->next; + size++; + }; + return std::make_tuple( + Error::kSuccess, EfaDeviceList(efaLib, firstDevice, size)); + } + + int size() { + return size_; + } + + EfaLib::device& operator[](int index) { + EfaLib::device* ptr = deviceList_.get(); + for (int j = 0; j < index; j++) { + ptr = ptr->next; + } + return *ptr; + } + + void reset() { + deviceList_.reset(); + } + + private: + struct Deleter { + void operator()(EfaLib::device* ptr) { + efaLib->fi_freeinfo_op(ptr); + } + + EfaLib* efaLib; + }; + std::unique_ptr deviceList_; + int size_; +}; + +} // namespace tensorpipe \ No newline at end of file diff --git a/tensorpipe/common/efa_lib.h b/tensorpipe/common/efa_lib.h new file mode 100644 index 000000000..8f4281a7a --- /dev/null +++ b/tensorpipe/common/efa_lib.h @@ -0,0 +1,100 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tensorpipe { + +#define TP_FORALL_FABRIC_SYMBOLS(_) \ + _(fi_freeinfo) \ + _(fi_dupinfo) \ + _(fi_fabric) \ + _(fi_strerror) \ + _(fi_getinfo) + +// Wrapper for libfabric. + +class EfaLib { + public: + using device = struct fi_info; + + private: + explicit EfaLib(DynamicLibraryHandle dlhandle) + : dlhandle_(std::move(dlhandle)) {} + + DynamicLibraryHandle dlhandle_; + +#define TP_DECLARE_FIELD(function_name) \ + decltype(&function_name) function_name##_ptr_ = nullptr; + TP_FORALL_FABRIC_SYMBOLS(TP_DECLARE_FIELD) +#undef TP_DECLARE_FIELD + + public: + EfaLib() = default; + +#define TP_FORWARD_CALL(function_name) \ + template \ + auto function_name##_op(Args&&... args) { \ + return (*function_name##_ptr_)(std::forward(args)...); \ + } + TP_FORALL_FABRIC_SYMBOLS(TP_FORWARD_CALL) +#undef TP_FORWARD_CALL + + static std::tuple create() { + Error error; + DynamicLibraryHandle dlhandle; + // To keep things "neat" and contained, we open in "local" mode (as opposed + // to global) so that the ibverbs symbols can only be resolved through this + // handle and are not exposed (a.k.a., "leaded") to other shared objects. + std::tie(error, dlhandle) = + DynamicLibraryHandle::create("libfabric.so", RTLD_LOCAL | RTLD_LAZY); + if (error) { + TP_LOG_WARNING() << "Load so fail"; + return std::make_tuple(std::move(error), EfaLib()); + } + // Log at level 9 as we can't know whether this will be used in a transport + // or channel, thus err on the side of this being as low-level as possible + // because we don't expect this to be of interest that often. + TP_VLOG(9) << [&]() -> std::string { + std::string filename; + std::tie(error, filename) = dlhandle.getFilename(); + if (error) { + return "Couldn't determine location of shared library libfabric.so: " + + error.what(); + } + return "Found shared library libfabric.so at " + filename; + }(); + EfaLib lib(std::move(dlhandle)); +#define TP_LOAD_SYMBOL(function_name) \ + { \ + void* ptr; \ + std::tie(error, ptr) = lib.dlhandle_.loadSymbol(#function_name); \ + if (error) { \ + return std::make_tuple(std::move(error), EfaLib()); \ + } \ + TP_THROW_ASSERT_IF(ptr == nullptr); \ + lib.function_name##_ptr_ = \ + reinterpret_cast(ptr); \ + } + TP_FORALL_FABRIC_SYMBOLS(TP_LOAD_SYMBOL) +#undef TP_LOAD_SYMBOL + return std::make_tuple(Error::kSuccess, std::move(lib)); + } +}; + +} // namespace tensorpipe diff --git a/tensorpipe/common/efa_read_write_ops.h b/tensorpipe/common/efa_read_write_ops.h new file mode 100644 index 000000000..72db31cd9 --- /dev/null +++ b/tensorpipe/common/efa_read_write_ops.h @@ -0,0 +1,265 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tensorpipe { + +// The read operation captures all state associated with reading a +// fixed length chunk of data from the underlying connection. All +// reads are required to include a word-sized header containing the +// number of bytes in the operation. This makes it possible for the +// read side of the connection to either 1) not know how many bytes +// to expected, and dynamically allocate, or 2) know how many bytes +// to expect, and preallocate the destination memory. +class EFAReadOperation { + public: + enum Mode { + WAIT_TO_POST, + WAIT_TO_COMPLETE, + COMPLETE, + }; + + public: + using read_callback_fn = + std::function; + + explicit inline EFAReadOperation(void* opContext, read_callback_fn fn); + + inline EFAReadOperation( + void* ptr, + size_t length, + void* opContext, + read_callback_fn fn); + + // Called when a buffer is needed to read data from stream. + inline void allocFromLoop(); + + // Called when data has been read from stream. + // inline void readFromLoop(); + + // Returns if this read operation is complete. + inline bool completed() const; + inline bool posted() const; + + inline void setCompleted(); + + inline void setWaitToCompleted(); + + inline size_t getReadLength(); + + inline size_t* getLengthPtr(); + inline char* getBufferPtr(); + + // Get op context + inline void* getOpContext(); + + // Invoke user callback. + inline void callbackFromLoop(const Error& error); + + private: + Mode mode_{WAIT_TO_POST}; + char* ptr_{nullptr}; + void* opContext_{nullptr}; + + // Number of bytes as specified by the user (if applicable). + optional givenLength_; + + // Number of bytes to expect as read from the connection. + size_t readLength_{0}; + + // Number of bytes read from the connection. + // This is reset to 0 when we advance from READ_LENGTH to READ_PAYLOAD. + size_t bytesRead_{0}; + + // Holds temporary allocation if no length was specified. + std::unique_ptr buffer_{nullptr}; + + // User callback. + read_callback_fn fn_; +}; + +EFAReadOperation::EFAReadOperation(void* opContext, read_callback_fn fn) + : opContext_(opContext), fn_(std::move(fn)) {} + +EFAReadOperation::EFAReadOperation( + void* ptr, + size_t length, + void* opContext, + read_callback_fn fn) + : ptr_(static_cast(ptr)), + givenLength_(length), + opContext_(opContext), + fn_(std::move(fn)) {} + +void EFAReadOperation::allocFromLoop() { + if (givenLength_.has_value()) { + TP_DCHECK(ptr_ != nullptr || givenLength_.value() == 0); + TP_DCHECK_EQ(readLength_, givenLength_.value()); + } else { + TP_DCHECK(ptr_ == nullptr); + buffer_ = std::make_unique(readLength_); + ptr_ = buffer_.get(); + } +} + +inline size_t* EFAReadOperation::getLengthPtr() { + return &readLength_; +}; +inline char* EFAReadOperation::getBufferPtr() { + return ptr_; +}; + +inline size_t EFAReadOperation::getReadLength() { + return readLength_; +}; + +bool EFAReadOperation::completed() const { + return mode_ == COMPLETE; +} + +bool EFAReadOperation::posted() const { + return !(mode_ == WAIT_TO_POST); +} + +void EFAReadOperation::setCompleted() { + mode_ = COMPLETE; +} + +void EFAReadOperation::setWaitToCompleted() { + mode_ = WAIT_TO_COMPLETE; +} + +void EFAReadOperation::callbackFromLoop(const Error& error) { + fn_(error, ptr_, readLength_); +} + +void* EFAReadOperation::getOpContext() { + return opContext_; +} + +// The write operation captures all state associated with writing a +// fixed length chunk of data from the underlying connection. The +// write includes a word-sized header containing the length of the +// write. This header is a member field on this class and therefore +// the instance must be kept alive and the reference to the instance +// must remain valid until the write callback has been called. +class EFAWriteOperation { + public: + enum Mode { + WAIT_TO_POST, + WAIT_TO_COMPLETE, + COMPLETE, + }; + + using write_callback_fn = std::function; + + inline EFAWriteOperation( + const void* ptr, + size_t length, + void* opContext, + write_callback_fn fn); + + struct Buf { + char* base; + size_t len; + }; + + inline std::tuple getBufs(); + + // Invoke user callback. + inline void callbackFromLoop(const Error& error); + // set mode to WAIT_TO_COMPLETE + inline void setWaitComplete(); + + inline bool posted(); + + // Returns if this write operation is complete. + inline bool completed() const; + // set mode to complete + inline void setCompleted(); + // get length + inline size_t getLength() const; + // get op context + inline void* getOpContext(); + + private: + Mode mode_{WAIT_TO_POST}; + const char* ptr_; + const size_t length_; + fi_addr_t peerAddr_; + void* opContext_{nullptr}; + + // Buffers (structs with pointers and lengths) to write to stream. + std::array bufs_; + + // User callback. + write_callback_fn fn_; +}; + +EFAWriteOperation::EFAWriteOperation( + const void* ptr, + size_t length, + void* opContext, + write_callback_fn fn) + : ptr_(static_cast(ptr)), + length_(length), + opContext_(opContext), + fn_(std::move(fn)) { + bufs_[0].base = const_cast(reinterpret_cast(&length_)); + bufs_[0].len = sizeof(length_); + bufs_[1].base = const_cast(ptr_); + bufs_[1].len = length_; +} + +std::tuple EFAWriteOperation::getBufs() { + size_t numBuffers = length_ == 0 ? 1 : 2; + return std::make_tuple(bufs_.data(), numBuffers); +} + +void EFAWriteOperation::callbackFromLoop(const Error& error) { + fn_(error); +} + +bool EFAWriteOperation::posted() { + return !(mode_ == WAIT_TO_POST); +} + +size_t EFAWriteOperation::getLength() const { + return length_; +} + +void EFAWriteOperation::setWaitComplete() { + mode_ = WAIT_TO_COMPLETE; +} + +void EFAWriteOperation::setCompleted() { + mode_ = COMPLETE; +} + +bool EFAWriteOperation::completed() const { + return mode_ == COMPLETE; +} + +void* EFAWriteOperation::getOpContext() { + return opContext_; +} + +} // namespace tensorpipe diff --git a/tensorpipe/common/epoll_loop.cc b/tensorpipe/common/epoll_loop.cc index b1c5e1df6..5fc07f912 100644 --- a/tensorpipe/common/epoll_loop.cc +++ b/tensorpipe/common/epoll_loop.cc @@ -129,7 +129,7 @@ bool EpollLoop::hasRegisteredHandlers() { } void EpollLoop::loop() { - setThreadName("TP_IBV_loop"); + setThreadName("TP_epoll_loop"); // Stop when another thread has asked the loop the close and when all // handlers have been unregistered except for the wakeup eventfd one. diff --git a/tensorpipe/config.h.in b/tensorpipe/config.h.in index ff5fa16a0..e26726f51 100644 --- a/tensorpipe/config.h.in +++ b/tensorpipe/config.h.in @@ -10,5 +10,6 @@ #cmakedefine01 TENSORPIPE_HAS_SHM_TRANSPORT #cmakedefine01 TENSORPIPE_HAS_IBV_TRANSPORT +#cmakedefine01 TENSORPIPE_HAS_EFA_TRANSPORT #cmakedefine01 TENSORPIPE_HAS_CMA_CHANNEL diff --git a/tensorpipe/tensorpipe.h b/tensorpipe/tensorpipe.h index 15b54f97c..d720806aa 100644 --- a/tensorpipe/tensorpipe.h +++ b/tensorpipe/tensorpipe.h @@ -41,6 +41,12 @@ #include #endif // TENSORPIPE_HAS_IBV_TRANSPORT +#if TENSORPIPE_HAS_EFA_TRANSPORT +#include +#include +#include +#endif // TENSORPIPE_HAS_EFA_TRANSPORT + // Channels #include diff --git a/tensorpipe/transport/efa/connection_impl.cc b/tensorpipe/transport/efa/connection_impl.cc new file mode 100644 index 000000000..834e0fb03 --- /dev/null +++ b/tensorpipe/transport/efa/connection_impl.cc @@ -0,0 +1,349 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tensorpipe { +namespace transport { +namespace efa { + +ConnectionImpl::ConnectionImpl( + ConstructorToken token, + std::shared_ptr context, + std::string id, + Socket socket) + : ConnectionImplBoilerplate( + token, + std::move(context), + std::move(id)), + socket_(std::move(socket)) {} + +ConnectionImpl::ConnectionImpl( + ConstructorToken token, + std::shared_ptr context, + std::string id, + std::string addr) + : ConnectionImplBoilerplate( + token, + std::move(context), + std::move(id)), + sockaddr_(Sockaddr::createInetSockAddr(addr)) {} + +void ConnectionImpl::initImplFromLoop() { + context_->enroll(*this); + Error error; + // The connection either got a socket or an address, but not both. + + TP_DCHECK(socket_.hasValue() ^ sockaddr_.has_value()); + if (!socket_.hasValue()) { + std::tie(error, socket_) = + Socket::createForFamily(sockaddr_->addr()->sa_family); + if (error) { + setError(std::move(error)); + return; + } + error = socket_.reuseAddr(true); + if (error) { + setError(std::move(error)); + return; + } + error = socket_.connect(sockaddr_.value()); + if (error) { + setError(std::move(error)); + return; + } + } + // Ensure underlying control socket is non-blocking such that it + // works well with event driven I/O. + error = socket_.block(false); + if (error) { + setError(std::move(error)); + return; + } + + // We're sending address first, so wait for writability. + state_ = SEND_ADDR; + context_->registerDescriptor(socket_.fd(), EPOLLOUT, shared_from_this()); +} + +void ConnectionImpl::readImplFromLoop(read_callback_fn fn) { + readOperations_.emplace_back(this, std::move(fn)); + + processReadOperationsFromLoop(); +} + +void ConnectionImpl::readImplFromLoop( + void* ptr, + size_t length, + read_callback_fn fn) { + readOperations_.emplace_back(ptr, length, this, std::move(fn)); + + // If the inbox already contains some data, we may be able to process this + // operation right away. + processReadOperationsFromLoop(); +} + +void ConnectionImpl::writeImplFromLoop( + const void* ptr, + size_t length, + write_callback_fn fn) { + writeOperations_.emplace_back(ptr, length, this, std::move(fn)); + + // If the outbox has some free space, we may be able to process this operation + // right away. + processWriteOperationsFromLoop(); +} + +void ConnectionImpl::handleEventsFromLoop(int events) { + TP_DCHECK(context_->inLoop()); + TP_VLOG(9) << "Connection " << id_ << " is handling an event on its socket (" + << EpollLoop::formatEpollEvents(events) << ")"; + + // Handle only one of the events in the mask. Events on the control + // file descriptor are rare enough for the cost of having epoll call + // into this function multiple times to not matter. The benefit is + // that every handler can close and unregister the control file + // descriptor from the event loop, without worrying about the next + // handler trying to do so as well. + // In some cases the socket could be in a state where it's both in an error + // state and readable/writable. If we checked for EPOLLIN or EPOLLOUT first + // and then returned after handling them, we would keep doing so forever and + // never reach the error handling. So we should keep the error check first. + if (events & EPOLLERR) { + int error; + socklen_t errorlen = sizeof(error); + int rv = getsockopt( + socket_.fd(), + SOL_SOCKET, + SO_ERROR, + reinterpret_cast(&error), + &errorlen); + if (rv == -1) { + setError(TP_CREATE_ERROR(SystemError, "getsockopt", rv)); + } else { + setError(TP_CREATE_ERROR(SystemError, "async error on socket", error)); + } + return; + } + if (events & EPOLLIN) { + handleEventInFromLoop(); + return; + } + if (events & EPOLLOUT) { + handleEventOutFromLoop(); + return; + } + // Check for hangup last, as there could be cases where we get EPOLLHUP but + // there's still data to be read from the socket, so we want to deal with that + // before dealing with the hangup. + if (events & EPOLLHUP) { + setError(TP_CREATE_ERROR(EOFError)); + return; + } +} + +void ConnectionImpl::handleEventInFromLoop() { + TP_DCHECK(context_->inLoop()); + if (state_ == RECV_ADDR) { + struct EfaAddress addr; + // auto x = &addr.name; + auto err = socket_.read(addr.name, sizeof(addr.name)); + // Crossing our fingers that the exchange information is small enough that + // it can be read in a single chunk. + if (err != sizeof(addr.name)) { + setError(TP_CREATE_ERROR(ShortReadError, sizeof(addr.name), err)); + return; + } + + peerAddr_ = context_->getReactor().addPeerAddr(addr); + + // The connection is usable now. + state_ = ESTABLISHED; + // Trigger read operations in case a pair of local read() and remote + // write() happened before connection is established. Otherwise read() + // callback would lose if it's the only read() request. + processReadOperationsFromLoop(); + processWriteOperationsFromLoop(); + return; + } + + if (state_ == ESTABLISHED) { + // We don't expect to read anything on this socket once the + // connection has been established. If we do, assume it's a + // zero-byte read indicating EOF. + setError(TP_CREATE_ERROR(EOFError)); + return; + } + + TP_THROW_ASSERT() << "EPOLLIN event not handled in state " << state_; +} + +void ConnectionImpl::handleEventOutFromLoop() { + TP_DCHECK(context_->inLoop()); + if (state_ == SEND_ADDR) { + EfaAddress addr = context_->getReactor().getEfaAddress(); + auto err = + socket_.write(reinterpret_cast(addr.name), sizeof(addr.name)); + // Crossing our fingers that the exchange information is small enough that + // it can be written in a single chunk. + if (err != sizeof(addr.name)) { + setError(TP_CREATE_ERROR(ShortWriteError, sizeof(addr.name), err)); + return; + } + + // Sent our address. Wait for address from peer. + state_ = RECV_ADDR; + context_->registerDescriptor(socket_.fd(), EPOLLIN, shared_from_this()); + return; + } + + TP_THROW_ASSERT() << "EPOLLOUT event not handled in state " << state_; +} + +void ConnectionImpl::processReadOperationsFromLoop() { + TP_DCHECK(context_->inLoop()); + + // Process all read read operations that we can immediately serve, only + // when connection is established. + if (state_ != ESTABLISHED) { + return; + } + + for (int i = 0; i < readOperations_.size(); i++) { + EFAReadOperation& readOperation = readOperations_[i]; + if (!readOperation.posted()) { + // context_->getReactor().; + context_->getReactor().postRecv( + readOperation.getLengthPtr(), + sizeof(size_t), + kLength | recvIdx_, + peerAddr_, + 0, + &readOperation); + readOperation.setWaitToCompleted(); + recvIdx_++; + } else { + // if the operation is posted, all operations back should be posted + // we can skip more checks + // break; + } + } +} + +void ConnectionImpl::onWriteCompleted() { + while (!writeOperations_.empty()) { + EFAWriteOperation& writeOperation = writeOperations_.front(); + if (writeOperation.completed()) { + writeOperation.callbackFromLoop(Error::kSuccess); + writeOperations_.pop_front(); + } else { + break; + } + } +} + +void ConnectionImpl::onReadCompleted() { + while (!readOperations_.empty()) { + EFAReadOperation& readOperation = readOperations_.front(); + if (readOperation.completed()) { + readOperation.callbackFromLoop(Error::kSuccess); + readOperations_.pop_front(); + } else { + break; + } + } +} + +void ConnectionImpl::processWriteOperationsFromLoop() { + TP_DCHECK(context_->inLoop()); + + if (state_ != ESTABLISHED) { + return; + } + + for (int i = 0; i < writeOperations_.size(); i++) { + EFAWriteOperation& writeOperation = writeOperations_[i]; + if (!writeOperation.posted()) { + EFAWriteOperation::Buf* bufArray; + size_t size; + std::tie(bufArray, size) = writeOperation.getBufs(); + context_->getReactor().postSend( + bufArray[0].base, + bufArray[0].len, + kLength | sendIdx_, + peerAddr_, + &writeOperation); + if (size > 1) { + context_->getReactor().postSend( + bufArray[1].base, + bufArray[1].len, + kPayload | sendIdx_, + peerAddr_, + &writeOperation); + } + writeOperation.setWaitComplete(); + sendIdx_++; + } else { + // if the operation is posted, all operations back should be posted + // we can skip more checks + // break; + } + } +} + +void ConnectionImpl::handleErrorImpl() { + for (auto& readOperation : readOperations_) { + readOperation.callbackFromLoop(error_); + } + readOperations_.clear(); + + for (auto& writeOperation : writeOperations_) { + writeOperation.callbackFromLoop(error_); + } + writeOperations_.clear(); + + cleanup(); + + if (socket_.hasValue()) { + if (state_ > INITIALIZING) { + context_->unregisterDescriptor(socket_.fd()); + } + socket_.reset(); + } + + context_->unenroll(*this); +} + +void ConnectionImpl::cleanup() { + TP_DCHECK(context_->inLoop()); + TP_VLOG(8) << "Connection " << id_ << " is cleaning up"; + context_->getReactor().removePeerAddr(peerAddr_); +} + +} // namespace efa +} // namespace transport +} // namespace tensorpipe diff --git a/tensorpipe/transport/efa/connection_impl.h b/tensorpipe/transport/efa/connection_impl.h new file mode 100644 index 000000000..57ae9aff8 --- /dev/null +++ b/tensorpipe/transport/efa/connection_impl.h @@ -0,0 +1,131 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tensorpipe { +namespace transport { +namespace efa { + +class ContextImpl; +class ListenerImpl; + +class ConnectionImpl final : public ConnectionImplBoilerplate< + ContextImpl, + ListenerImpl, + ConnectionImpl>, + public EpollLoop::EventHandler { + enum State { + INITIALIZING = 1, + SEND_ADDR, + RECV_ADDR, + ESTABLISHED, + }; + + public: + // Create a connection that is already connected (e.g. from a listener). + ConnectionImpl( + ConstructorToken token, + std::shared_ptr context, + std::string id, + Socket socket); + + // Create a connection that connects to the specified address. + ConnectionImpl( + ConstructorToken token, + std::shared_ptr context, + std::string id, + std::string addr); + + // Implementation of EventHandler. + void handleEventsFromLoop(int events) override; + + void onWriteCompleted(); + void onReadCompleted(); + + protected: + // Implement the entry points called by ConnectionImplBoilerplate. + void initImplFromLoop() override; + void readImplFromLoop(read_callback_fn fn) override; + void readImplFromLoop(void* ptr, size_t length, read_callback_fn fn) override; + void writeImplFromLoop(const void* ptr, size_t length, write_callback_fn fn) + override; + void handleErrorImpl() override; + + private: + // Handle events of type EPOLLIN on the UNIX domain socket. + // + // The only data that is expected on that socket is the address and other + // setup information for the other side's queue pair and inbox. + void handleEventInFromLoop(); + + // Handle events of type EPOLLOUT on the UNIX domain socket. + // + // Once the socket is writable we send the address and other setup information + // for this side's queue pair and inbox. + void handleEventOutFromLoop(); + + State state_{INITIALIZING}; + Socket socket_; + optional sockaddr_; + + fi_addr_t peerAddr_; + + uint32_t sendIdx_ = 0; + uint32_t recvIdx_ = 0; + + // Pending read operations. + std::deque readOperations_; + + // Pending write operations. + std::deque writeOperations_; + + // Process pending read operations if in an operational state. + // + // This may be triggered by the other side of the connection (by pushing this + // side's inbox token to the reactor) when it has written some new data to its + // outbox (which is this side's inbox). It is also called by this connection + // when it moves into an established state or when a new read operation is + // queued, in case data was already available before this connection was ready + // to consume it. + void processReadOperationsFromLoop(); + + // Process pending write operations if in an operational state. + // + // This may be triggered by the other side of the connection (by pushing this + // side's outbox token to the reactor) when it has read some data from its + // inbox (which is this side's outbox). This is important when some of this + // side's writes couldn't complete because the outbox was full, and thus they + // needed to wait for some of its data to be read. This method is also called + // by this connection when it moves into an established state, in case some + // writes were queued before the connection was ready to process them, or when + // a new write operation is queued. + void processWriteOperationsFromLoop(); + + void cleanup(); +}; + +} // namespace efa +} // namespace transport +} // namespace tensorpipe diff --git a/tensorpipe/transport/efa/constants.h b/tensorpipe/transport/efa/constants.h new file mode 100644 index 000000000..920b40efd --- /dev/null +++ b/tensorpipe/transport/efa/constants.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace { + +// We should probably allow these to be user-configured. But, for now, we'll set +// them to the lowest value they can have, the rationale being that this way +// they will always be valid. +constexpr uint8_t kPortNum = 1; +constexpr uint8_t kGlobalIdentifierIndex = 0; + +// FIXME Instead of hardcoding the next three values, we could use +// efa_query_device to obtain max_cqe, max_qp_wr and max_srq_wr and deduce from +// them the maximum allowed values for these parameters. + +// How many simultaneous receive requests to keep queued on the shared receive +// queue. Incoming RDMA writes and sends will consume one such request. The +// reactor loop will fill the SRQ back up to this value once some requests +// complete. So this number should just be large enough to accommodate all the +// requests that could finish between two reactor loop iterations. And, even if +// this number ends up being too low, the excess incoming requests will just +// retry, causing a performance penalty but not a failure. +constexpr uint32_t kNumPendingRecvReqs = 1024; + +// How many RDMA write requests can be pending at the same time across all +// connections. We need to put a limit on them because they all use the same +// global completion queue which has a fixed capacity and if it overruns it will +// enter an unrecoverable error state. This value is also set as the capacity of +// the send queue of each queue pair. +constexpr uint32_t kNumPendingWriteReqs = 1024; + +// How many elements the completion queue should be able to hold. These elements +// will be either the completed receive requests of the SRQ, or the completed +// send requests from a connection's queue pair. We can bound the former value +// but not the latter, so we try to add some margin. +constexpr int kCompletionQueueSize = kNumPendingRecvReqs + kNumPendingWriteReqs; + +// How many work completions to poll from the completion queue at each reactor +// iteration. +constexpr int kNumPolledWorkCompletions = 64; + +} // namespace diff --git a/tensorpipe/transport/efa/context_impl.cc b/tensorpipe/transport/efa/context_impl.cc new file mode 100644 index 000000000..3e14c873f --- /dev/null +++ b/tensorpipe/transport/efa/context_impl.cc @@ -0,0 +1,102 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +namespace tensorpipe { +namespace transport { +namespace efa { + +namespace { + +// Prepend descriptor with transport name so it's easy to +// disambiguate descriptors when debugging. +const std::string kDomainDescriptorPrefix{"efa:"}; + +std::string generateDomainDescriptor() { + // It would be very cool if we could somehow obtain an "identifier" for the + // InfiniBand subnet that our device belongs to, but nothing of that sort + // seems to be available. So instead we say that if the user is trying to + // connect two processes which both have access to an InfiniBand device then + // they must know what they are doing and probably must have set up things + // properly. + return kDomainDescriptorPrefix + "*"; +} + +} // namespace + +std::shared_ptr ContextImpl::create() { + Error error; + EfaLib efaLib; + std::tie(error, efaLib) = EfaLib::create(); + if (error) { + TP_VLOG(7) + << "efa transport is not viable because libfabric couldn't be loaded: " + << error.what(); + return nullptr; + } + + EfaDeviceList deviceList; + std::tie(error, deviceList) = EfaDeviceList::create(efaLib); + if (error) { + TP_VLOG(7) << "EFA transport is not viable because it couldn't find any" + << "EFA devices"; + return nullptr; + } + TP_THROW_ASSERT_IF(error) + << "Couldn't get list of EFA devices: " << error.what(); + + return std::make_shared( + std::move(efaLib), std::move(deviceList)); +} + +ContextImpl::ContextImpl(EfaLib efaLib, EfaDeviceList deviceList) + : ContextImplBoilerplate( + generateDomainDescriptor()), + reactor_(std::move(efaLib), std::move(deviceList)) {} + +void ContextImpl::handleErrorImpl() { + loop_.close(); + reactor_.close(); +} + +void ContextImpl::joinImpl() { + loop_.join(); + reactor_.join(); +} + +bool ContextImpl::inLoop() const { + return reactor_.inLoop(); +}; + +void ContextImpl::deferToLoop(std::function fn) { + reactor_.deferToLoop(std::move(fn)); +}; + +void ContextImpl::registerDescriptor( + int fd, + int events, + std::shared_ptr h) { + loop_.registerDescriptor(fd, events, std::move(h)); +} + +void ContextImpl::unregisterDescriptor(int fd) { + loop_.unregisterDescriptor(fd); +} + +Reactor& ContextImpl::getReactor() { + return reactor_; +} + +} // namespace efa +} // namespace transport +} // namespace tensorpipe diff --git a/tensorpipe/transport/efa/context_impl.h b/tensorpipe/transport/efa/context_impl.h new file mode 100644 index 000000000..813ce0c6b --- /dev/null +++ b/tensorpipe/transport/efa/context_impl.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace tensorpipe { +namespace transport { +namespace efa { + +class ConnectionImpl; +class ListenerImpl; + +class ContextImpl final + : public ContextImplBoilerplate { + public: + static std::shared_ptr create(); + + ContextImpl(EfaLib efaLib, EfaDeviceList deviceList); + ContextImpl(); + + // Implement the DeferredExecutor interface. + bool inLoop() const override; + void deferToLoop(std::function fn) override; + + void registerDescriptor( + int fd, + int events, + std::shared_ptr h); + + void unregisterDescriptor(int fd); + + Reactor& getReactor(); + + protected: + // Implement the entry points called by ContextImplBoilerplate. + void handleErrorImpl() override; + void joinImpl() override; + + private: + Reactor reactor_; + EpollLoop loop_{this->reactor_}; +}; + +} // namespace efa +} // namespace transport +} // namespace tensorpipe diff --git a/tensorpipe/transport/efa/error.cc b/tensorpipe/transport/efa/error.cc new file mode 100644 index 000000000..17c8f01ba --- /dev/null +++ b/tensorpipe/transport/efa/error.cc @@ -0,0 +1,37 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +// #include + +namespace tensorpipe { +namespace transport { +namespace efa { + +std::string EfaError::what() const { + return error_; +} + +std::string GetaddrinfoError::what() const { + std::ostringstream ss; + ss << "getaddrinfo: " << gai_strerror(error_); + return ss.str(); +} + +std::string NoAddrFoundError::what() const { + return "no address found"; +} + +} // namespace efa +} // namespace transport +} // namespace tensorpipe diff --git a/tensorpipe/transport/efa/error.h b/tensorpipe/transport/efa/error.h new file mode 100644 index 000000000..a37dc632e --- /dev/null +++ b/tensorpipe/transport/efa/error.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +namespace tensorpipe { +namespace transport { +namespace efa { + +class EfaError final : public BaseError { + public: + explicit EfaError(std::string error) : error_(error) {} + + std::string what() const override; + + private: + std::string error_; +}; + +class GetaddrinfoError final : public BaseError { + public: + explicit GetaddrinfoError(int error) : error_(error) {} + + std::string what() const override; + + private: + int error_; +}; + +class NoAddrFoundError final : public BaseError { + public: + NoAddrFoundError() {} + + std::string what() const override; +}; + +} // namespace efa +} // namespace transport +} // namespace tensorpipe diff --git a/tensorpipe/transport/efa/factory.cc b/tensorpipe/transport/efa/factory.cc new file mode 100644 index 000000000..c37363ed2 --- /dev/null +++ b/tensorpipe/transport/efa/factory.cc @@ -0,0 +1,27 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include + +namespace tensorpipe { +namespace transport { +namespace efa { + +std::shared_ptr create() { + return std::make_shared< + ContextBoilerplate>(); +} + +} // namespace efa +} // namespace transport +} // namespace tensorpipe diff --git a/tensorpipe/transport/efa/factory.h b/tensorpipe/transport/efa/factory.h new file mode 100644 index 000000000..76611507e --- /dev/null +++ b/tensorpipe/transport/efa/factory.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include + +namespace tensorpipe { +namespace transport { +namespace efa { + +std::shared_ptr create(); + +} // namespace efa +} // namespace transport +} // namespace tensorpipe diff --git a/tensorpipe/transport/efa/listener_impl.cc b/tensorpipe/transport/efa/listener_impl.cc new file mode 100644 index 000000000..d01aff347 --- /dev/null +++ b/tensorpipe/transport/efa/listener_impl.cc @@ -0,0 +1,158 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tensorpipe { +namespace transport { +namespace efa { + +ListenerImpl::ListenerImpl( + ConstructorToken token, + std::shared_ptr context, + std::string id, + std::string addr) + : ListenerImplBoilerplate( + token, + std::move(context), + std::move(id)), + sockaddr_(Sockaddr::createInetSockAddr(addr)) {} + +void ListenerImpl::initImplFromLoop() { + context_->enroll(*this); + + Error error; + TP_DCHECK(!socket_.hasValue()); + std::tie(error, socket_) = + Socket::createForFamily(sockaddr_.addr()->sa_family); + if (error) { + setError(std::move(error)); + return; + } + error = socket_.reuseAddr(true); + if (error) { + setError(std::move(error)); + return; + } + error = socket_.bind(sockaddr_); + if (error) { + setError(std::move(error)); + return; + } + error = socket_.block(false); + if (error) { + setError(std::move(error)); + return; + } + error = socket_.listen(128); + if (error) { + setError(std::move(error)); + return; + } + + struct sockaddr_storage addr; + socklen_t addrlen; + std::tie(error, addr, addrlen) = socket_.getSockName(); + if (error) { + setError(std::move(error)); + return; + } + sockaddr_ = Sockaddr(reinterpret_cast(&addr), addrlen); +} + +void ListenerImpl::handleErrorImpl() { + if (!fns_.empty()) { + context_->unregisterDescriptor(socket_.fd()); + } + socket_.reset(); + for (auto& fn : fns_) { + fn(error_, std::shared_ptr()); + } + fns_.clear(); + + context_->unenroll(*this); +} + +void ListenerImpl::acceptImplFromLoop(accept_callback_fn fn) { + fns_.push_back(std::move(fn)); + + // Only register if we go from 0 to 1 pending callbacks. In other cases we + // already had a pending callback and thus we were already registered. + if (fns_.size() == 1) { + // Register with loop for readability events. + context_->registerDescriptor(socket_.fd(), EPOLLIN, shared_from_this()); + } +} + +std::string ListenerImpl::addrImplFromLoop() const { + return sockaddr_.str(); +} + +void ListenerImpl::handleEventsFromLoop(int events) { + TP_DCHECK(context_->inLoop()); + TP_VLOG(9) << "Listener " << id_ << " is handling an event on its socket (" + << EpollLoop::formatEpollEvents(events) << ")"; + + if (events & EPOLLERR) { + int error; + socklen_t errorlen = sizeof(error); + int rv = getsockopt( + socket_.fd(), + SOL_SOCKET, + SO_ERROR, + reinterpret_cast(&error), + &errorlen); + if (rv == -1) { + setError(TP_CREATE_ERROR(SystemError, "getsockopt", rv)); + } else { + setError(TP_CREATE_ERROR(SystemError, "async error on socket", error)); + } + return; + } + if (events & EPOLLHUP) { + setError(TP_CREATE_ERROR(EOFError)); + return; + } + TP_ARG_CHECK_EQ(events, EPOLLIN); + + Error error; + Socket socket; + std::tie(error, socket) = socket_.accept(); + if (error) { + setError(std::move(error)); + return; + } + + TP_DCHECK(!fns_.empty()) + << "when the callback is disarmed the listener's descriptor is supposed " + << "to be unregistered"; + auto fn = std::move(fns_.front()); + fns_.pop_front(); + if (fns_.empty()) { + context_->unregisterDescriptor(socket_.fd()); + } + fn(Error::kSuccess, createAndInitConnection(std::move(socket))); +} + +} // namespace efa +} // namespace transport +} // namespace tensorpipe diff --git a/tensorpipe/transport/efa/listener_impl.h b/tensorpipe/transport/efa/listener_impl.h new file mode 100644 index 000000000..e4c09ef8b --- /dev/null +++ b/tensorpipe/transport/efa/listener_impl.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace tensorpipe { +namespace transport { +namespace efa { + +class ConnectionImpl; +class ContextImpl; + +class ListenerImpl final + : public ListenerImplBoilerplate, + public EpollLoop::EventHandler { + public: + // Create a listener that listens on the specified address. + ListenerImpl( + ConstructorToken token, + std::shared_ptr context, + std::string id, + std::string addr); + + // Implementation of EventHandler. + void handleEventsFromLoop(int events) override; + + protected: + // Implement the entry points called by ListenerImplBoilerplate. + void initImplFromLoop() override; + void acceptImplFromLoop(accept_callback_fn fn) override; + std::string addrImplFromLoop() const override; + void handleErrorImpl() override; + + private: + Socket socket_; + Sockaddr sockaddr_; + std::deque fns_; +}; + +} // namespace efa +} // namespace transport +} // namespace tensorpipe diff --git a/tensorpipe/transport/efa/mode.md b/tensorpipe/transport/efa/mode.md new file mode 100644 index 000000000..3857f2bf0 --- /dev/null +++ b/tensorpipe/transport/efa/mode.md @@ -0,0 +1,27 @@ +# EFA + +The EFA communication model can be considered as a simplified ibverbs. The send/recv operation is async by the event queue and completion queue. And the operation itself doesn't need memory registration like ibverbs and also doesn't act like stream operation in socket. The complexity at the memory part is handled by the underlying provider(libfabric+efa). The overall implementation can be considered as Reactor from ibverbs + StreamOperation from uv. + +EFA supports the send-after-send order guarantees for data operation, which means the message order is preserved. However, the completion order is not guaranteed when reading events from completion queue. +For example, sender posts S1, S2, S3 three send operations; receiver posts R1, R2, R3 three recv operations. These operations are exactly matched due to send-after-send guarantee. But when reading from the completion queue at receiver side, the completion order might be R2, R1, R3 or other. Same for the sender side. + +This brings complexity in the busy polling thread when dealing with completion events, since the callback of write operations should be executed in order. To address this issue, the pointer of the `EFAWriteOperation` is passed as operation context when post send event. And in the completion stage, it will set the mode of `EFAWriteOperation` to completed. A seperate function is executed later by iterating the `writeOperations_` deque from front, and execute callback if the operation is done. + +For the receiver part, it's more complex. For example there are two incoming writeOperations. It will become 4 send operation at sender side, SEND_SIZE1, SEND_PAYLOAD1, SEND_SIZE2, SEND_PAYLOAD2. At the receiver side, the expected behavior is +1. Post receive event of single 64bits size, such as RECV_SIZE1 +2. Poll from cq when RECV_SIZE1 is done, and post RECV_PAYLOAD1 with the size in RECV_SIZE1 +3. Did the same thing for the second operation + +However when the four send operation issued concurrently, the first completion event at receiver side might be RECV_SIZE2. If we follow the process above, the recv order will be messed up. To address this problem, the implementation used tag matching. That each operation will have a index decided at the sender side. Two indicator, kLength=1ULL<<32, kPayload=1ULL<<33, are used to indicate the type of the message. The message tag is a 64bit integer, that the high 32 bits are indicators (kLength or kPayload), and the low 32 bits are operation ids. + +Send side: + +1. `post_send(buffer=&length, size=sizeof(int64_t), tag=kLength | msg_id, ...)` +2. `post_send(buffer=buffer, size=length, tag = kPayload | msg_id, ...)` +3. `msg_id++` msg_id is uint32_t + +Receiver side: +At receiver size, we first recv the message with high 32 bits equaling kLength and decode the index from low 32 bits. And then post a recv event with tag `kPayload | msg_id` +1. `post_recv(buffer=&length, size=sizeof(int64_t), tag=kLength, ignore=0xffffffff, ...)` ignore=0xfffffff means ignore lower 32 bits when matching tag +2. decode message id from the incoming message tag (take lower 32bits) +3. `post_recv(buffer=buffer, size=length, tag=kPayload | msg_id, ignore=0, ...)` ignore=0 means the tag should be exactly the same to match \ No newline at end of file diff --git a/tensorpipe/transport/efa/reactor.cc b/tensorpipe/transport/efa/reactor.cc new file mode 100644 index 000000000..cefcb8dcb --- /dev/null +++ b/tensorpipe/transport/efa/reactor.cc @@ -0,0 +1,219 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include + +namespace tensorpipe { +namespace transport { +namespace efa { + +Reactor::Reactor(EfaLib efaLib, EfaDeviceList efaDeviceList) { + efaLib_ = std::move(efaLib); + // AWS p4d instances may have multiple EFAs. Only use device 0 for now + EfaLib::device* device = &efaDeviceList[0]; + fabric_ = createEfaFabric(efaLib, device); + domain_ = createEfaDomain(efaLib, fabric_, device); + ep_ = createEfaEndpoint(efaLib, domain_, device); + av_ = createEfaAdressVector(efaLib, domain_); + cq_ = createEfaCompletionQueue(efaLib, domain_, device); + addr_ = enableEndpoint(efaLib, ep_, av_, cq_); + startThread("TP_efa_reactor"); +} + +void Reactor::postSend( + void* buffer, + size_t size, + uint64_t tag, + fi_addr_t peerAddr, + void* context) { + pendingSends_.emplace_back(EfaEvent((new fi_msg_tagged{ + /* msg_iov */ new iovec{.iov_base = buffer, .iov_len = size}, + /* desc */ 0, + /* iov_count */ 1, + /* peer addr */ peerAddr, + /* tag */ tag, + /* ignore */ 0, + /* context */ context, + /* data */ 0}))); + postPendingRecvs(); +} + +void Reactor::postRecv( + void* buffer, + size_t size, + uint64_t tag, + fi_addr_t peerAddr, + uint64_t ignore, + void* context) { + pendingRecvs_.emplace_back(EfaEvent(new fi_msg_tagged{ + /* msg_iov */ new iovec{.iov_base = buffer, .iov_len = size}, + /* desc */ 0, + /* iov_count */ 1, + /* peer addr */ peerAddr, + /* tag */ tag, + /* ignore */ ignore, + /* context */ context, + /* data */ 0})); + postPendingRecvs(); +} + +int Reactor::postPendingSends() { + while (!pendingSends_.empty()) { + fi_msg_tagged* sevent = pendingSends_.front().get(); + int ret = fi_tsendmsg(ep_.get(), sevent, 0); + if (ret == 0) { + // Send successfully, pop out events + pendingSends_.pop_front(); + } else if (ret == -FI_EAGAIN) { + return pendingSends_.size(); + } else if (ret < 0) { + // Unknown failure, raise exception + TP_CHECK_EFA_RET(ret, "Unable to do fi_tsend message"); + } + } + + return 0; +} + +fi_addr_t Reactor::addPeerAddr(EfaAddress& addr) { + fi_addr_t peerAddr; + int ret = fi_av_insert(av_.get(), addr.name, 1, &peerAddr, 0, nullptr); + TP_THROW_ASSERT_IF(ret != 1) << "Unable to add address to endpoint"; + TP_CHECK_EFA_RET(ret, "Unable to add address to endpoint"); + efaAddrSet_.emplace(peerAddr); + return peerAddr; +} + +void Reactor::removePeerAddr(fi_addr_t faddr) { + int ret = fi_av_remove(av_.get(), &faddr, 1, 0); + TP_CHECK_EFA_RET(ret, "Unable to remove address from endpoint"); + efaAddrSet_.erase(faddr); +}; + +int Reactor::postPendingRecvs() { + while (!pendingRecvs_.empty()) { + fi_msg_tagged* revent = pendingRecvs_.front().get(); + int ret = fi_trecvmsg(ep_.get(), revent, 0); + if (ret == 0) { + // Send successfully, pop out events + pendingRecvs_.pop_front(); + } else if (ret == -FI_EAGAIN) { + return pendingRecvs_.size(); + } else if (ret < 0) { + // Unknown failure, raise exception + TP_CHECK_EFA_RET(ret, "Unable to do fi_trecv message"); + } + } + return 0; +} + +void Reactor::setId(std::string id) { + id_ = std::move(id); +} + +void Reactor::close() { + if (!closed_.exchange(true)) { + stopBusyPolling(); + } +} + +void Reactor::join() { + close(); + + if (!joined_.exchange(true)) { + joinThread(); + } +} + +Reactor::~Reactor() { + join(); +} + +bool Reactor::pollOnce() { + std::array cqEntries; + std::array srcAddrs; + + postPendingSends(); + postPendingRecvs(); + int rv = fi_cq_readfrom( + cq_.get(), cqEntries.data(), cqEntries.size(), srcAddrs.data()); + if (rv == 0 || rv == -FI_EAGAIN) { + return false; + } else { + TP_CHECK_EFA_RET(rv, "Completion queue poll error."); + } + + int numRecvs = 0; + int numWrites = 0; + int numAcks = 0; + for (int cqIdx = 0; cqIdx < rv; cqIdx++) { + struct fi_cq_tagged_entry& cq = cqEntries[cqIdx]; + fi_addr_t& srcAddr = srcAddrs[cqIdx]; + uint32_t msgIdx = static_cast(cq.tag); + if (cq.flags & FI_SEND) { + // Send event + if (cq.tag & kLength) { + // Send size finished, check whether it's zero sized message + auto* operationPtr = static_cast(cq.op_context); + if (operationPtr->getLength() == 0) { + operationPtr->setCompleted(); + reinterpret_cast(operationPtr->getOpContext()) + ->onWriteCompleted(); + } + } else if (cq.tag & kPayload) { + auto* operationPtr = static_cast(cq.op_context); + operationPtr->setCompleted(); + reinterpret_cast(operationPtr->getOpContext()) + ->onWriteCompleted(); + } + } else if (cq.flags & FI_RECV) { + // Receive event + if (cq.tag & kLength) { + // Received length information + auto* operationPtr = static_cast(cq.op_context); + if (operationPtr->getReadLength() == 0) { + operationPtr->setCompleted(); + reinterpret_cast(operationPtr->getOpContext()) + ->onReadCompleted(); + } else { + // operation_ptr->mode_ = EFAReadOperation::Mode::READ_PAYLOAD; + operationPtr->allocFromLoop(); + postRecv( + operationPtr->getBufferPtr(), + operationPtr->getReadLength(), + kPayload | msgIdx, + srcAddr, + 0, // Exact match of tag + operationPtr); + operationPtr->setWaitToCompleted(); + } + } else if (cq.tag & kPayload) { + // Received payload + auto* operationPtr = static_cast(cq.op_context); + operationPtr->setCompleted(); + reinterpret_cast(operationPtr->getOpContext()) + ->onReadCompleted(); + } + } + } + + return true; +} + +bool Reactor::readyToClose() { + return efaAddrSet_.size() == 0; +} + +} // namespace efa +} // namespace transport +} // namespace tensorpipe diff --git a/tensorpipe/transport/efa/reactor.h b/tensorpipe/transport/efa/reactor.h new file mode 100644 index 000000000..a3b96b361 --- /dev/null +++ b/tensorpipe/transport/efa/reactor.h @@ -0,0 +1,133 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace tensorpipe { +namespace transport { +namespace efa { + +enum EfaTag : uint64_t { + kLength = 1ULL << 32, + kPayload = 1ULL << 33, +}; + +// Reactor loop. +// +// Companion class to the event loop in `loop.h` that executes +// functions on triggers. The triggers are posted to a shared memory +// ring buffer, so this can be done by other processes on the same +// machine. It uses extra data in the ring buffer header to store a +// mutex and condition variable to avoid a busy loop. +// +class Reactor final : public BusyPollingLoop { + public: + Reactor(EfaLib efaLib, EfaDeviceList efaDeviceList); + + const EfaLib& getefaLib() { + return efaLib_; + } + + EfaDomain& getefaDomain() { + return domain_; + } + + EfaCompletionQueue& getefaCq() { + return cq_; + } + + const EfaAddress& getEfaAddress() { + return addr_; + } + + void postSend( + void* buffer, + size_t size, + uint64_t tag, + fi_addr_t peerAddr, + void* context); + + void postRecv( + void* buffer, + size_t size, + uint64_t tag, + fi_addr_t peerAddr, + uint64_t ignore, + void* context); + + fi_addr_t addPeerAddr(EfaAddress& addr); + + void removePeerAddr(fi_addr_t faddr); + + void setId(std::string id); + + void close(); + + void join(); + + ~Reactor(); + + protected: + bool pollOnce() override; + + bool readyToClose() override; + + class EfaEventDeleter { + public: + void operator()(fi_msg_tagged* msg) { + delete msg->msg_iov; + } + }; + using EfaEvent = std::unique_ptr; + + private: + EfaLib efaLib_; + EfaFabric fabric_; + EfaDomain domain_; + EfaEndpoint ep_; + EfaCompletionQueue cq_; + EfaAdressVector av_; + EfaAddress addr_; + + int postPendingRecvs(); + int postPendingSends(); + + std::atomic closed_{false}; + std::atomic joined_{false}; + + // An identifier for the context, composed of the identifier for the context, + // combined with the transport's name. It will only be used for logging and + // debugging purposes. + std::string id_{"N/A"}; + + // The registered connections for each queue pair. + std::unordered_set efaAddrSet_; + + std::deque pendingSends_; + std::deque pendingRecvs_; +}; + +} // namespace efa +} // namespace transport +} // namespace tensorpipe diff --git a/tensorpipe/transport/efa/sockaddr.cc b/tensorpipe/transport/efa/sockaddr.cc new file mode 100644 index 000000000..89b02e01d --- /dev/null +++ b/tensorpipe/transport/efa/sockaddr.cc @@ -0,0 +1,142 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include + +#include +#include + +#include + +namespace tensorpipe { +namespace transport { +namespace efa { + +Sockaddr Sockaddr::createInetSockAddr(const std::string& str) { + int port = 0; + std::string addrStr; + std::string portStr; + + // If the input string is an IPv6 address with port, the address + // itself must be wrapped with brackets. + if (addrStr.empty()) { + auto start = str.find("["); + auto stop = str.find("]"); + if (start < stop && start != std::string::npos && + stop != std::string::npos) { + addrStr = str.substr(start + 1, stop - (start + 1)); + if (stop + 1 < str.size() && str[stop + 1] == ':') { + portStr = str.substr(stop + 2); + } + } + } + + // If the input string is an IPv4 address with port, we expect + // at least a single period and a single colon in the string. + if (addrStr.empty()) { + auto period = str.find("."); + auto colon = str.find(":"); + if (period != std::string::npos && colon != std::string::npos) { + addrStr = str.substr(0, colon); + portStr = str.substr(colon + 1); + } + } + + // Fallback to using entire input string as address without port. + if (addrStr.empty()) { + addrStr = str; + } + + // Parse port number if specified. + if (!portStr.empty()) { + port = std::stoi(portStr); + if (port < 0 || port > std::numeric_limits::max()) { + TP_THROW_EINVAL() << str; + } + } + + // Try to convert an IPv4 address. + { + struct sockaddr_in addr; + std::memset(&addr, 0, sizeof(addr)); + auto rv = inet_pton(AF_INET, addrStr.c_str(), &addr.sin_addr); + TP_THROW_SYSTEM_IF(rv < 0, errno); + if (rv == 1) { + addr.sin_family = AF_INET; + addr.sin_port = ntohs(port); + return Sockaddr(reinterpret_cast(&addr), sizeof(addr)); + } + } + + // Try to convert an IPv6 address. + { + struct sockaddr_in6 addr; + std::memset(&addr, 0, sizeof(addr)); + + auto interfacePos = addrStr.find('%'); + if (interfacePos != std::string::npos) { + addr.sin6_scope_id = + if_nametoindex(addrStr.substr(interfacePos + 1).c_str()); + addrStr = addrStr.substr(0, interfacePos); + } + + auto rv = inet_pton(AF_INET6, addrStr.c_str(), &addr.sin6_addr); + TP_THROW_SYSTEM_IF(rv < 0, errno); + if (rv == 1) { + addr.sin6_family = AF_INET6; + addr.sin6_port = ntohs(port); + return Sockaddr(reinterpret_cast(&addr), sizeof(addr)); + } + } + + // Invalid address. + TP_THROW_EINVAL() << str; + + // Return bogus to silence "return from non-void function" warning. + // Note: we don't reach this point per the throw above. + return Sockaddr(nullptr, 0); +} + +std::string Sockaddr::str() const { + std::ostringstream oss; + + if (addr_.ss_family == AF_INET) { + std::array buf; + auto in = reinterpret_cast(&addr_); + auto rv = inet_ntop(AF_INET, &in->sin_addr, buf.data(), buf.size()); + TP_THROW_SYSTEM_IF(rv == nullptr, errno); + oss << buf.data() << ":" << htons(in->sin_port); + } else if (addr_.ss_family == AF_INET6) { + std::array buf; + auto in6 = reinterpret_cast(&addr_); + auto rv = inet_ntop(AF_INET6, &in6->sin6_addr, buf.data(), buf.size()); + TP_THROW_SYSTEM_IF(rv == nullptr, errno); + oss << "[" << buf.data(); + if (in6->sin6_scope_id > 0) { + std::array scopeBuf; + rv = if_indextoname(in6->sin6_scope_id, scopeBuf.data()); + TP_THROW_SYSTEM_IF(rv == nullptr, errno); + oss << "%" << scopeBuf.data(); + } + oss << "]:" << htons(in6->sin6_port); + + } else { + TP_THROW_EINVAL() << "invalid address family: " << addr_.ss_family; + } + + return oss.str(); +} + +} // namespace efa +} // namespace transport +} // namespace tensorpipe diff --git a/tensorpipe/transport/efa/sockaddr.h b/tensorpipe/transport/efa/sockaddr.h new file mode 100644 index 000000000..cb6bbfd07 --- /dev/null +++ b/tensorpipe/transport/efa/sockaddr.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include + +#include + +namespace tensorpipe { +namespace transport { +namespace efa { + +class Sockaddr final : public tensorpipe::Sockaddr { + public: + static Sockaddr createInetSockAddr(const std::string& str); + + Sockaddr(const struct sockaddr* addr, socklen_t addrlen) { + TP_ARG_CHECK(addr != nullptr); + TP_ARG_CHECK_LE(addrlen, sizeof(addr_)); + // Ensure the sockaddr_storage is zeroed, because we don't always + // write to all fields in the `sockaddr_[in|in6]` structures. + std::memset(&addr_, 0, sizeof(addr_)); + std::memcpy(&addr_, addr, addrlen); + addrlen_ = addrlen; + } + + inline const struct sockaddr* addr() const override { + return reinterpret_cast(&addr_); + } + + inline struct sockaddr* addr() { + return reinterpret_cast(&addr_); + } + + inline socklen_t addrlen() const override { + return addrlen_; + } + + std::string str() const; + + private: + struct sockaddr_storage addr_; + socklen_t addrlen_; +}; + +} // namespace efa +} // namespace transport +} // namespace tensorpipe diff --git a/tensorpipe/transport/efa/utility.cc b/tensorpipe/transport/efa/utility.cc new file mode 100644 index 000000000..8df5572e5 --- /dev/null +++ b/tensorpipe/transport/efa/utility.cc @@ -0,0 +1,178 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tensorpipe { +namespace transport { +namespace efa { + +namespace { + +struct InterfaceAddressesDeleter { + void operator()(struct ifaddrs* ptr) { + ::freeifaddrs(ptr); + } +}; + +using InterfaceAddresses = + std::unique_ptr; + +std::tuple createInterfaceAddresses() { + struct ifaddrs* ifaddrs; + auto rv = ::getifaddrs(&ifaddrs); + if (rv < 0) { + return std::make_tuple( + TP_CREATE_ERROR(SystemError, "getifaddrs", errno), + InterfaceAddresses()); + } + return std::make_tuple(Error::kSuccess, InterfaceAddresses(ifaddrs)); +} + +std::tuple getHostname() { + std::array hostname; + auto rv = ::gethostname(hostname.data(), hostname.size()); + if (rv < 0) { + return std::make_tuple( + TP_CREATE_ERROR(SystemError, "gethostname", errno), std::string()); + } + return std::make_tuple(Error::kSuccess, std::string(hostname.data())); +} + +struct AddressInfoDeleter { + void operator()(struct addrinfo* ptr) { + ::freeaddrinfo(ptr); + } +}; + +using AddressInfo = std::unique_ptr; + +std::tuple createAddressInfo(std::string host) { + struct addrinfo hints; + std::memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + + struct addrinfo* result; + auto rv = ::getaddrinfo(host.c_str(), nullptr, &hints, &result); + if (rv != 0) { + return std::make_tuple( + TP_CREATE_ERROR(GetaddrinfoError, rv), AddressInfo()); + } + return std::make_tuple(Error::kSuccess, AddressInfo(result)); +} + +} // namespace + +std::tuple lookupAddrForIface(std::string iface) { + Error error; + InterfaceAddresses addresses; + std::tie(error, addresses) = createInterfaceAddresses(); + if (error) { + return std::make_tuple(std::move(error), std::string()); + } + + struct ifaddrs* ifa; + for (ifa = addresses.get(); ifa != nullptr; ifa = ifa->ifa_next) { + // Skip entry if ifa_addr is NULL (see getifaddrs(3)) + if (ifa->ifa_addr == nullptr) { + continue; + } + + if (iface != ifa->ifa_name) { + continue; + } + + switch (ifa->ifa_addr->sa_family) { + case AF_INET: + return std::make_tuple( + Error::kSuccess, + Sockaddr(ifa->ifa_addr, sizeof(struct sockaddr_in)).str()); + case AF_INET6: + return std::make_tuple( + Error::kSuccess, + Sockaddr(ifa->ifa_addr, sizeof(struct sockaddr_in6)).str()); + } + } + + return std::make_tuple(TP_CREATE_ERROR(NoAddrFoundError), std::string()); +} + +std::tuple lookupAddrForHostname() { + Error error; + std::string hostname; + std::tie(error, hostname) = getHostname(); + if (error) { + return std::make_tuple(std::move(error), std::string()); + } + + AddressInfo info; + std::tie(error, info) = createAddressInfo(std::move(hostname)); + if (error) { + return std::make_tuple(std::move(error), std::string()); + } + + Error firstError; + for (struct addrinfo* rp = info.get(); rp != nullptr; rp = rp->ai_next) { + TP_DCHECK(rp->ai_family == AF_INET || rp->ai_family == AF_INET6); + TP_DCHECK_EQ(rp->ai_socktype, SOCK_STREAM); + TP_DCHECK_EQ(rp->ai_protocol, IPPROTO_TCP); + + Sockaddr addr = Sockaddr(rp->ai_addr, rp->ai_addrlen); + + Socket socket; + std::tie(error, socket) = Socket::createForFamily(rp->ai_family); + + if (!error) { + error = socket.bind(addr); + } + + if (error) { + // Record the first binding error we encounter and return that in the end + // if no working address is found, in order to help with debugging. + if (!firstError) { + firstError = error; + } + continue; + } + + return std::make_tuple(Error::kSuccess, addr.str()); + } + + if (firstError) { + return std::make_tuple(std::move(firstError), std::string()); + } else { + return std::make_tuple(TP_CREATE_ERROR(NoAddrFoundError), std::string()); + } +} + +} // namespace efa +} // namespace transport +} // namespace tensorpipe diff --git a/tensorpipe/transport/efa/utility.h b/tensorpipe/transport/efa/utility.h new file mode 100644 index 000000000..e4ec1a4de --- /dev/null +++ b/tensorpipe/transport/efa/utility.h @@ -0,0 +1,26 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include + +namespace tensorpipe { +namespace transport { +namespace efa { + +std::tuple lookupAddrForIface(std::string iface); + +std::tuple lookupAddrForHostname(); + +} // namespace efa +} // namespace transport +} // namespace tensorpipe diff --git a/third_party/libfabric b/third_party/libfabric new file mode 160000 index 000000000..4c47e0b0c --- /dev/null +++ b/third_party/libfabric @@ -0,0 +1 @@ +Subproject commit 4c47e0b0cf92bec1fc9003ac046612cd05490992