Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi-stream CUDA IPC #326

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions include/mscclpp/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@
#ifndef MSCCLPP_DEVICE_HPP_
#define MSCCLPP_DEVICE_HPP_

#if defined(__HIP_PLATFORM_AMD__)
#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
#include <hip/hip_runtime.h>
#endif // defined(__HIP_PLATFORM_AMD__)
#endif // defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)

#if (defined(__NVCC__) || defined(__HIP_PLATFORM_AMD__))

#define MSCCLPP_DEVICE_COMPILE
#define MSCCLPP_DEVICE_INLINE __forceinline__ __device__
#define MSCCLPP_HOST_DEVICE_INLINE __forceinline__ __host__ __device__
#if defined(__HIP_PLATFORM_AMD__)
#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
#define MSCCLPP_DEVICE_HIP
#else // !(defined(__HIP_PLATFORM_AMD__)
#else // !(defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1))
#define MSCCLPP_DEVICE_CUDA
#endif // !(defined(__HIP_PLATFORM_AMD__))
#endif // !(defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1))

#else // !(defined(__NVCC__) || defined(__HIP_PLATFORM_AMD__))

Expand Down
2 changes: 1 addition & 1 deletion include/mscclpp/gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#ifndef MSCCLPP_GPU_HPP_
#define MSCCLPP_GPU_HPP_

#if defined(__HIP_PLATFORM_AMD__)
#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)

#include <hip/hip_runtime.h>

Expand Down
10 changes: 7 additions & 3 deletions src/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,14 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Context::connect(Endpoint localEndpo
if (remoteEndpoint.transport() != Transport::CudaIpc) {
throw mscclpp::Error("Local transport is CudaIpc but remote is not", ErrorCode::InvalidUsage);
}
if (!(pimpl_->ipcStream_)) {
pimpl_->ipcStream_ = std::make_shared<CudaStreamWithFlags>(cudaStreamNonBlocking);
#if defined(__HIP_PLATFORM_AMD__)
pimpl_->ipcStreams_.emplace_back(std::make_shared<CudaStreamWithFlags>(cudaStreamNonBlocking));
#else
if (pimpl_->ipcStreams_.empty()) {
pimpl_->ipcStreams_.emplace_back(std::make_shared<CudaStreamWithFlags>(cudaStreamNonBlocking));
}
conn = std::make_shared<CudaIpcConnection>(localEndpoint, remoteEndpoint, cudaStream_t(*(pimpl_->ipcStream_)));
#endif
conn = std::make_shared<CudaIpcConnection>(localEndpoint, remoteEndpoint, *(pimpl_->ipcStreams_.back()));
} else if (AllIBTransports.has(localEndpoint.transport())) {
if (!AllIBTransports.has(remoteEndpoint.transport())) {
throw mscclpp::Error("Local transport is IB but remote is not", ErrorCode::InvalidUsage);
Expand Down
1 change: 1 addition & 0 deletions src/include/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace mscclpp {
struct Context::Impl {
std::vector<std::shared_ptr<Connection>> connections_;
std::unordered_map<Transport, std::unique_ptr<IbCtx>> ibContexts_;
std::vector<std::shared_ptr<CudaStreamWithFlags>> ipcStreams_;
std::shared_ptr<CudaStreamWithFlags> ipcStream_;
CUmemGenericAllocationHandle mcHandle_;

Expand Down
40 changes: 23 additions & 17 deletions test/allgather_test_cpp.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ static int nranksPerNode = 8;
} \
} while (false)

#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif

// Measure current time in second.
static double getTime(void) {
struct timespec tspec;
Expand All @@ -47,14 +53,14 @@ __device__ void allgather0(DeviceHandle<mscclpp::ProxyChannel> proxyChan, int ra

// this thread's role is a sender role
// put your data asynchronously
if ((threadIdx.x % 32) == 0) proxyChan.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
if ((threadIdx.x % WARP_SIZE) == 0) proxyChan.putWithSignal(rank * nelemsPerGPU * sizeof(int), nelemsPerGPU * sizeof(int));
// make sure everyone is put their data before some thread randomly blocks everyone else in signal
__syncthreads();
// push with flag and sync to make sure the data is received
if ((threadIdx.x % 32) == 0) proxyChan.flush();
if ((threadIdx.x % WARP_SIZE) == 0) proxyChan.flush();

// this thread's role is a receiver role. wait on the semaphore to make sure the data is ready
if ((threadIdx.x % 32) == 0) proxyChan.wait();
if ((threadIdx.x % WARP_SIZE) == 0) proxyChan.wait();
}

__device__ void localAllGather(DeviceHandle<mscclpp::ProxyChannel> proxyChan, int rank, int nranksPerNode,
Expand All @@ -68,17 +74,17 @@ __device__ void localAllGather(DeviceHandle<mscclpp::ProxyChannel> proxyChan, in
for (int i = 1; i < nranksPerNode; i++) {
if ((remoteRank % nranksPerNode) == ((rank + i) % nranksPerNode)) {
// put your data to GPU (rank+i) % nranksPerNode and signal in one call
if ((threadIdx.x % 32) == 0) proxyChan.putWithSignal(offset, size);
if ((threadIdx.x % WARP_SIZE) == 0) proxyChan.putWithSignal(offset, size);
}
// wait for the data from GPU (rank-i) % nranksPerNode to arrive
if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode)) {
if ((threadIdx.x % 32) == 0) proxyChan.wait();
if ((threadIdx.x % WARP_SIZE) == 0) proxyChan.wait();
}
#if defined(__HIP_PLATFORM_AMD__)
#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
// NOTE: we actually need a group barrier here for better performance, but __syncthreads() is still correct.
__syncthreads();
#else
asm volatile("bar.sync %0, %1;" ::"r"(11), "r"((nranksPerNode - 1) * 32) : "memory");
asm volatile("bar.sync %0, %1;" ::"r"(11), "r"((nranksPerNode - 1) * WARP_SIZE) : "memory");
#endif
}
}
Expand All @@ -88,7 +94,7 @@ __device__ void allgather1(DeviceHandle<mscclpp::ProxyChannel> proxyChan, int ra
localAllGather(proxyChan, rank, nranksPerNode, remoteRank, rank * nelemsPerGPU * sizeof(int),
nelemsPerGPU * sizeof(int));
if (remoteRank / nranksPerNode == rank / nranksPerNode)
if ((threadIdx.x % 32) == 0) proxyChan.flush();
if ((threadIdx.x % WARP_SIZE) == 0) proxyChan.flush();
}

__device__ void allgather2(DeviceHandle<mscclpp::ProxyChannel> proxyChan, int rank, int world_size, int nranksPerNode,
Expand All @@ -114,10 +120,10 @@ __device__ void allgather2(DeviceHandle<mscclpp::ProxyChannel> proxyChan, int ra
// cross-node exchange
if (remoteRank % nranksPerNode == rank % nranksPerNode) {
// opposite side
if ((threadIdx.x % 32) == 0)
if ((threadIdx.x % WARP_SIZE) == 0)
proxyChan.putWithSignal(rank * nelemsPerGPU * sizeof(int),
(nelemsPerGPU * (pipelineSize - 1)) / pipelineSize * sizeof(int));
if ((threadIdx.x % 32) == 0) proxyChan.wait();
if ((threadIdx.x % WARP_SIZE) == 0) proxyChan.wait();
}

__syncthreads();
Expand All @@ -133,10 +139,10 @@ __device__ void allgather2(DeviceHandle<mscclpp::ProxyChannel> proxyChan, int ra
// cross-node exchange
if (remoteRank % nranksPerNode == rank % nranksPerNode) {
// opposite side
if ((threadIdx.x % 32) == 0)
if ((threadIdx.x % WARP_SIZE) == 0)
proxyChan.putWithSignal((rank * nelemsPerGPU + (nelemsPerGPU * (pipelineSize - 1)) / pipelineSize) * sizeof(int),
nelemsPerGPU / pipelineSize * sizeof(int));
if ((threadIdx.x % 32) == 0) proxyChan.wait();
if ((threadIdx.x % WARP_SIZE) == 0) proxyChan.wait();
}

__syncthreads();
Expand All @@ -150,13 +156,13 @@ __device__ void allgather2(DeviceHandle<mscclpp::ProxyChannel> proxyChan, int ra
}

if (remoteRank / nranksPerNode == rank / nranksPerNode || remoteRank % nranksPerNode == rank % nranksPerNode) {
if ((threadIdx.x % 32) == 0) proxyChan.flush();
if ((threadIdx.x % WARP_SIZE) == 0) proxyChan.flush();
}
}

__global__ void kernel(int rank, int world_size, int nranksPerNode, size_t nelemsPerGPU, int kernel) {
// find the mapping between remoteRank and proxyChans
int warpId = threadIdx.x / 32;
int warpId = threadIdx.x / WARP_SIZE;
int remoteRank = (warpId < rank) ? warpId : warpId + 1;
// Each warp is responsible for one of the remote ranks
DeviceHandle<mscclpp::ProxyChannel> proxyChan = constProxyChans[warpId];
Expand Down Expand Up @@ -410,7 +416,7 @@ int main(int argc, const char* argv[]) {
cudaStream_t stream;
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
CUDACHECK(cudaDeviceSynchronize());
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
kernel<<<1, WARP_SIZE * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
CUDACHECK(cudaDeviceSynchronize());
CUDACHECK(cudaMemcpy(data_h, data_d, dataSize, cudaMemcpyDeviceToHost));

Expand All @@ -432,7 +438,7 @@ int main(int argc, const char* argv[]) {
CUDACHECK(cudaStreamSynchronize(stream));
bootstrap->allGather(tmp, sizeof(int));
for (int i = 0; i < iterwithoutcudagraph; ++i) {
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
kernel<<<1, WARP_SIZE * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
}
CUDACHECK(cudaStreamSynchronize(stream));
bootstrap->allGather(tmp, sizeof(int));
Expand All @@ -444,7 +450,7 @@ int main(int argc, const char* argv[]) {
cudaGraphExec_t instance;
CUDACHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal));
for (int i = 0; i < cudagraphiter; ++i) {
kernel<<<1, 32 * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
kernel<<<1, WARP_SIZE * (world_size - 1), 0, stream>>>(rank, world_size, nranksPerNode, nelemsPerGPU, kernelNum);
}
CUDACHECK(cudaStreamEndCapture(stream, &graph));
CUDACHECK(cudaGraphInstantiate(&instance, graph, NULL, NULL, 0));
Expand Down
Loading