Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Context and Endpoint classes to enable non-Communicator use-cases #166

Merged
merged 17 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ install(TARGETS mscclpp_static

# Tests
if (BUILD_TESTS)
enable_testing() # Called here to allow ctest from the build directory
add_subdirectory(test)
endif()

Expand Down
218 changes: 162 additions & 56 deletions include/mscclpp/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,10 @@ class TcpBootstrap : public Bootstrap {
void barrier() override;

private:
/// Implementation class for @ref TcpBootstrap.
class Impl;
/// Pointer to the implementation class for @ref TcpBootstrap.
// The interal implementation.
struct Impl;

// Pointer to the internal implementation.
std::unique_ptr<Impl> pimpl_;
};

Expand Down Expand Up @@ -303,23 +304,15 @@ std::string getIBDeviceName(Transport ibTransport);
/// @return The InfiniBand transport associated with the specified device name.
Transport getIBTransportByDeviceName(const std::string& ibDeviceName);

class Communicator;
class Context;
class Connection;

/// Represents a block of memory that has been registered to a @ref Communicator.
/// Represents a block of memory that has been registered to a @ref Context.
class RegisteredMemory {
protected:
struct Impl;

public:
/// Default constructor.
RegisteredMemory() = default;

/// Constructor that takes a shared pointer to an implementation object.
///
/// @param pimpl A shared pointer to an implementation object.
RegisteredMemory(std::shared_ptr<Impl> pimpl);

/// Destructor.
~RegisteredMemory();

Expand All @@ -333,11 +326,6 @@ class RegisteredMemory {
/// @return The size of the memory block.
size_t size();

/// Get the rank of the process that owns the memory block.
///
/// @return The rank of the process that owns the memory block.
int rank();

/// Get the transport flags associated with the memory block.
///
/// @return The transport flags associated with the memory block.
Expand All @@ -354,14 +342,54 @@ class RegisteredMemory {
/// @return A deserialized RegisteredMemory object.
static RegisteredMemory deserialize(const std::vector<char>& data);

private:
// The interal implementation.
struct Impl;

// Internal constructor.
RegisteredMemory(std::shared_ptr<Impl> pimpl);

// Pointer to the internal implementation. A shared_ptr is used since RegisteredMemory is immutable.
std::shared_ptr<Impl> pimpl_;

friend class Context;
friend class Connection;
friend class IBConnection;
friend class Communicator;
};

/// Represents one end of a connection.
class Endpoint {
public:
/// Default constructor.
Endpoint() = default;

/// Get the transport used.
///
/// @return The transport used.
Transport transport();

/// Serialize the Endpoint object to a vector of characters.
///
/// @return A vector of characters representing the serialized Endpoint object.
std::vector<char> serialize();

/// Deserialize a Endpoint object from a vector of characters.
///
/// @param data A vector of characters representing a serialized Endpoint object.
/// @return A deserialized Endpoint object.
static Endpoint deserialize(const std::vector<char>& data);

private:
// A shared_ptr is used since RegisteredMemory is functionally immutable, although internally some state is populated
// lazily.
std::shared_ptr<Impl> pimpl;
// The interal implementation.
struct Impl;

// Internal constructor.
Endpoint(std::shared_ptr<Impl> pimpl);

// Pointer to the internal implementation. A shared_ptr is used since Endpoint is immutable.
std::shared_ptr<Impl> pimpl_;

friend class Context;
friend class Connection;
};

/// Represents a connection between two processes.
Expand All @@ -388,16 +416,6 @@ class Connection {
/// Flush any pending writes to the remote process.
virtual void flush(int64_t timeoutUsec = 3e7) = 0;

/// Get the rank of the remote process.
///
/// @return The rank of the remote process.
virtual int remoteRank() = 0;

/// Get the tag associated with the connection.
///
/// @return The tag associated with the connection.
virtual int tag() = 0;

/// Get the transport used by the local process.
///
/// @return The transport used by the local process.
Expand All @@ -409,11 +427,89 @@ class Connection {
virtual Transport remoteTransport() = 0;

protected:
/// Get the implementation object associated with a @ref RegisteredMemory object.
// Internal methods for getting implementation pointers.
static std::shared_ptr<RegisteredMemory::Impl> getImpl(RegisteredMemory& memory);
static std::shared_ptr<Endpoint::Impl> getImpl(Endpoint& memory);
};

/// Used to configure an endpoint.
struct EndpointConfig {
static const int DefaultMaxCqSize = 1024;
static const int DefaultMaxCqPollNum = 1;
static const int DefaultMaxSendWr = 8192;
static const int DefaultMaxWrPerSend = 64;

Transport transport;
int ibMaxCqSize = DefaultMaxCqSize;
int ibMaxCqPollNum = DefaultMaxCqPollNum;
int ibMaxSendWr = DefaultMaxSendWr;
int ibMaxWrPerSend = DefaultMaxWrPerSend;

/// Default constructor. Sets transport to Transport::Unknown.
EndpointConfig() : transport(Transport::Unknown) {}

/// Constructor that takes a transport and sets the other fields to their default values.
///
/// @param memory The @ref RegisteredMemory object.
/// @return A shared pointer to the implementation object.
static std::shared_ptr<RegisteredMemory::Impl> getRegisteredMemoryImpl(RegisteredMemory& memory);
/// @param transport The transport to use.
EndpointConfig(Transport transport) : transport(transport) {}
};

/// Represents a context for communication. This provides a low-level interface for forming connections in use-cases
/// where the process group abstraction offered by @ref Communicator is not suitable, e.g., ephemeral client-server
/// connections. Correct use of this class requires external synchronization when finalizing connections with the
/// @ref connect() method.
///
/// As an example, a client-server scenario where the server will write to the client might proceed as follows:
/// 1. The client creates an endpoint with @ref createEndpoint() and sends it to the server.
/// 2. The server receives the client endpoint, creates its own endpoint with @ref createEndpoint(), sends it to the
/// client, and creates a connection with @ref connect().
/// 4. The client receives the server endpoint, creates a connection with @ref connect() and sends a
/// @ref RegisteredMemory to the server.
/// 5. The server receives the @ref RegisteredMemory and writes to it using the previously created connection.
/// The client waiting to create a connection before sending the @ref RegisteredMemory ensures that the server can not
/// write to the @ref RegisteredMemory before the connection is established.
///
/// While some transports may have more relaxed implementation behavior, this should not be relied upon.
class Context {
public:
/// Create a context.
Context();

/// Destroy the context.
~Context();

/// Register a region of GPU memory for use in this context.
///
/// @param ptr Base pointer to the memory.
/// @param size Size of the memory region in bytes.
/// @param transports Transport flags.
/// @return RegisteredMemory A handle to the buffer.
RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports);

/// Create an endpoint for establishing connections.
///
/// @param config The configuration for the endpoint.
/// @return The newly created endpoint.
Endpoint createEndpoint(EndpointConfig config);

/// Establish a connection between two endpoints. While this method immediately returns a connection object, the
/// connection is only safe to use after the corresponding connection on the remote endpoint has been established.
/// This method must be called on both endpoints to establish a connection.
///
/// @param localEndpoint The local endpoint.
/// @param remoteEndpoint The remote endpoint.
/// @return std::shared_ptr<Connection> A shared pointer to the connection.
std::shared_ptr<Connection> connect(Endpoint localEndpoint, Endpoint remoteEndpoint);

private:
// The interal implementation.
struct Impl;

// Pointer to the internal implementation.
std::unique_ptr<Impl> pimpl_;

friend class RegisteredMemory;
friend class Endpoint;
};

/// A base class for objects that can be set up during @ref Communicator::setup().
Expand Down Expand Up @@ -479,14 +575,12 @@ class NonblockingFuture {
/// 6. All done; use connections and registered memories to build channels.
///
class Communicator {
protected:
struct Impl;

public:
/// Initializes the communicator with a given bootstrap implementation.
///
/// @param bootstrap An implementation of the Bootstrap that the communicator will use.
Communicator(std::shared_ptr<Bootstrap> bootstrap);
/// @param context An optional context to use for the communicator. If not provided, a new context will be created.
Communicator(std::shared_ptr<Bootstrap> bootstrap, std::shared_ptr<Context> context = nullptr);

/// Destroy the communicator.
~Communicator();
Expand All @@ -496,7 +590,12 @@ class Communicator {
/// @return std::shared_ptr<Bootstrap> The bootstrap held by this communicator.
std::shared_ptr<Bootstrap> bootstrap();

/// Register a region of GPU memory for use in this communicator.
/// Returns the context held by this communicator.
///
/// @return std::shared_ptr<Context> The context held by this communicator.
std::shared_ptr<Context> context();

/// Register a region of GPU memory for use in this communicator's context.
///
/// @param ptr Base pointer to the memory.
/// @param size Size of the memory region in bytes.
Expand Down Expand Up @@ -534,15 +633,22 @@ class Communicator {
///
/// @param remoteRank The rank of the remote process.
/// @param tag The tag of the connection for identifying it.
/// @param transport The type of transport to be used.
/// @param ibMaxCqSize The maximum number of completion queue entries for IB. Unused if transport is not IB.
/// @param ibMaxCqPollNum The maximum number of completion queue entries to poll for IB. Unused if transport is not
/// IB.
/// @param ibMaxSendWr The maximum number of outstanding send work requests for IB. Unused if transport is not IB.
/// @param ibMaxWrPerSend The maximum number of work requests per send for IB. Unused if transport is not IB.
/// @return std::shared_ptr<Connection> A shared pointer to the connection.
std::shared_ptr<Connection> connectOnSetup(int remoteRank, int tag, Transport transport, int ibMaxCqSize = 1024,
int ibMaxCqPollNum = 1, int ibMaxSendWr = 8192, int ibMaxWrPerSend = 64);
/// @param config The configuration for the local endpoint.
/// @return NonblockingFuture<NonblockingFuture<std::shared_ptr<Connection>>> A non-blocking future of shared pointer
/// to the connection.
NonblockingFuture<std::shared_ptr<Connection>> connectOnSetup(int remoteRank, int tag, EndpointConfig localConfig);

/// Get the remote rank a connection is connected to.
///
/// @param connection The connection to get the remote rank for.
/// @return The remote rank the connection is connected to.
int remoteRankOf(const Connection& connection);

/// Get the tag a connection was made with.
///
/// @param connection The connection to get the tag for.
/// @return The tag the connection was made with.
int tagOf(const Connection& connection);

/// Add a custom Setuppable object to a list of objects to be setup later, when @ref setup() is called.
///
Expand All @@ -556,12 +662,12 @@ class Communicator {
/// that have been registered after the (n-1)-th call.
void setup();

friend class RegisteredMemory::Impl;
friend class IBConnection;

private:
/// Unique pointer to the implementation of the Communicator class.
std::unique_ptr<Impl> pimpl;
// The interal implementation.
struct Impl;

// Pointer to the internal implementation.
std::unique_ptr<Impl> pimpl_;
};

/// A constant TransportFlags object representing no transports.
Expand Down
8 changes: 5 additions & 3 deletions python/examples/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,16 @@ def setup_connections(comm, rank, world_size, element_size, proxy_service):
remote_memories.append(remote_mem)
comm.setup()

connections = [conn.get() for conn in connections]

# Create simple proxy channels
for i, conn in enumerate(connections):
proxy_channel = mscclpp.SimpleProxyChannel(
proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(conn)),
proxy_service.proxy_channel(proxy_service.build_and_add_semaphore(comm, conn)),
proxy_service.add_memory(remote_memories[i].get()),
proxy_service.add_memory(reg_mem),
)
simple_proxy_channels.append(mscclpp.device_handle(proxy_channel))
simple_proxy_channels.append(proxy_channel.device_handle())
comm.setup()

# Create sm channels
Expand All @@ -66,7 +68,7 @@ def setup_connections(comm, rank, world_size, element_size, proxy_service):
for i, conn in enumerate(sm_semaphores):
sm_chan = mscclpp.SmChannel(sm_semaphores[i], remote_memories[i].get(), ptr)
sm_channels.append(sm_chan)
return simple_proxy_channels, [mscclpp.device_handle(sm_chan) for sm_chan in sm_channels]
return simple_proxy_channels, [sm_chan.device_handle() for sm_chan in sm_channels]


def run(rank, args):
Expand Down
Loading