Double buffering for NCCL APIs (#324)
Use two scratch buffers in each peer to exchange data, alternating between them on successive calls.

---------

Co-authored-by: Changho Hwang <[email protected]>
caiomcbr and chhwang authored Jul 15, 2024
1 parent 5f9ee27 commit 7493e2f
Showing 7 changed files with 230 additions and 212 deletions.
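
The pattern across these diffs: each collective previously derived its scratch offset inside the kernel from the low bit of a per-call flag; this commit hoists that selection into a new channelScratchOffset parameter supplied by the caller (the call sites are in files not rendered on this page) and doubles SCRATCH_SIZE so successive calls can alternate between two halves of the scratch buffer. A minimal host-side sketch of the idea; the helper name and half-size constant are illustrative assumptions, not code from the commit:

#include <cstddef>
#include <cstdint>

// Two halves of scratch; value mirrors SCRATCH_SIZE in common.hpp below.
constexpr size_t kScratchBytes = 2ull * 1024 * 1024 * 70;

// Call N writes packets into one half while peers may still be draining
// call N-1's packets from the other; the parity of the per-call flag picks
// the half, mirroring the `(flag & 1) ? 0 : ...` logic removed from the
// kernels below.
inline size_t scratchOffsetForCall(uint32_t flag) {
  return (flag & 1) ? 0 : kScratchBytes / 2;
}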
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -20,7 +20,7 @@ jobs:
       - name: Run cpplint
         run: |
-          CPPSOURCES=$(find ./src ./include ./python ./test -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)')
+          CPPSOURCES=$(find ./src ./include ./python ./test ./apps -regextype posix-extended -regex '.*\.(c|cpp|h|hpp|cc|cxx|cu)')
           clang-format -style=file --verbose --Werror --dry-run ${CPPSOURCES}
   pylint:
264 changes: 137 additions & 127 deletions apps/nccl/include/nccl.h

Large diffs are not rendered by default.

32 changes: 19 additions & 13 deletions apps/nccl/src/allgather.hpp
@@ -14,8 +14,8 @@

 template <bool IsOutOfPlace>
 __global__ void __launch_bounds__(1024, 1)
-    allgather6(void* sendbuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels, size_t channelOutOffset, size_t rank,
-               [[maybe_unused]] size_t worldSize, size_t nRanksPerNode, size_t nelemsPerGPU) {
+    allgather6(void* sendbuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels, size_t channelOutOffset,
+               size_t rank, [[maybe_unused]] size_t worldSize, size_t nRanksPerNode, size_t nelemsPerGPU) {
   const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
   const size_t lid = tid % WARP_SIZE;
   const size_t wid = tid / WARP_SIZE;
@@ -53,8 +53,10 @@ __global__ void __launch_bounds__(1024, 1)
     char* src = reinterpret_cast<char*>(smChans[peerIdx].src_);
     char* buff = reinterpret_cast<char*>(sendbuff);
     const size_t offsetWithinRank = (wid / nPeer) * unitBytesPerWarp;
-    smChans[peerIdx].copy<16, false>(src + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid, WARP_SIZE);
-    smChans[peerIdx].copy<16, false>(dst + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid, WARP_SIZE);
+    smChans[peerIdx].copy<16, false>(src + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid,
+                                     WARP_SIZE);
+    smChans[peerIdx].copy<16, false>(dst + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid,
+                                     WARP_SIZE);
   } else {
     smChans[peerIdx].put<16, false>(offset + channelOutOffset, unitBytesPerWarp, lid, WARP_SIZE);
   }
@@ -69,8 +71,10 @@ __global__ void __launch_bounds__(1024, 1)
     char* src = reinterpret_cast<char*>(smChans[peerIdx].src_);
     char* buff = reinterpret_cast<char*>(sendbuff);
     const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp;
-    smChans[peerIdx].copy<16, false>(src + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid, WARP_SIZE);
-    smChans[peerIdx].copy<16, false>(dst + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid, WARP_SIZE);
+    smChans[peerIdx].copy<16, false>(src + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid,
+                                     WARP_SIZE);
+    smChans[peerIdx].copy<16, false>(dst + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid,
+                                     WARP_SIZE);
   } else {
     smChans[peerIdx].put<16, false>(offset + channelOutOffset, unitBytesPerWarp, lid, WARP_SIZE);
   }
@@ -89,8 +93,10 @@ __global__ void __launch_bounds__(1024, 1)
     char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);
     char* src = reinterpret_cast<char*>(smChans[peerIdx].src_);
     char* buff = reinterpret_cast<char*>(sendbuff);
-    smChans[peerIdx].copy<16, true>(src + offset + channelOutOffset, buff + offsetWithinRank, remainBytes, lid, WARP_SIZE);
-    smChans[peerIdx].copy<16, true>(dst + offset + channelOutOffset, buff + offsetWithinRank, remainBytes, lid, WARP_SIZE);
+    smChans[peerIdx].copy<16, true>(src + offset + channelOutOffset, buff + offsetWithinRank, remainBytes, lid,
+                                    WARP_SIZE);
+    smChans[peerIdx].copy<16, true>(dst + offset + channelOutOffset, buff + offsetWithinRank, remainBytes, lid,
+                                    WARP_SIZE);
   } else {
     smChans[peerIdx].put<16, true>(offset + channelOutOffset, remainBytes, lid, WARP_SIZE);
   }
@@ -100,11 +106,11 @@ __global__ void __launch_bounds__(1024, 1)

 template <bool IsOutOfPlace, typename T>
 cudaError_t allgather(T* buff, [[maybe_unused]] T* scratch, [[maybe_unused]] T* resultBuff,
-                      mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels, size_t channelOutOffset, int rank, int nRanksPerNode, int worldSize,
-                      size_t nelems, cudaStream_t stream) {
-  allgather6<IsOutOfPlace><<<28, 1024, 0, stream>>>((void*)buff, smChannels, channelOutOffset, rank, worldSize, nRanksPerNode,
-                                                    nelems * sizeof(T) / sizeof(int));
+                      mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels, size_t channelOutOffset, int rank,
+                      int nRanksPerNode, int worldSize, size_t nelems, cudaStream_t stream) {
+  allgather6<IsOutOfPlace><<<28, 1024, 0, stream>>>((void*)buff, smChannels, channelOutOffset, rank, worldSize,
+                                                    nRanksPerNode, nelems * sizeof(T) / sizeof(int));
   return cudaGetLastError();
 }

-#endif  // ALLGATHER_HPP_
+#endif  // ALLGATHER_HPP_
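
As context for the reflowed copy<16, ...> calls above: each warp moves one unitBytesPerWarp-sized chunk to one peer, with lane id lid striding within the warp. A standalone sketch of the warp-to-(peer, chunk) mapping implied by offsetWithinRank = (wid / nPeer) * unitBytesPerWarp; the complementary peerIdx = wid % nPeer is an inference, not shown in this diff:

#include <cstddef>
#include <cstdio>

int main() {
  const size_t nPeer = 7;               // NRANKS_PER_NODE - 1 peers on an 8-GPU node
  const size_t unitBytesPerWarp = 512;  // illustrative chunk size
  for (size_t wid = 0; wid < 14; ++wid) {
    const size_t peerIdx = wid % nPeer;                                // peer served by this warp (assumed)
    const size_t offsetWithinRank = (wid / nPeer) * unitBytesPerWarp;  // byte offset within this rank's data
    std::printf("warp %2zu -> peer %zu, offset %4zu\n", wid, peerIdx, offsetWithinRank);
  }
  return 0;
}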
48 changes: 24 additions & 24 deletions apps/nccl/src/allreduce.hpp
@@ -132,8 +132,8 @@ __forceinline__ __device__ void vectorSum(T* dst, T* src, size_t nElem) {
 template <typename T>
 __global__ void __launch_bounds__(32, 1)
     allreduceAllToAll(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
-                      size_t channelDataOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems,
-                      uint32_t flag) {
+                      size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
+                      size_t nelems, uint32_t flag) {
   // This version of allreduce only works for single nodes
   if (worldSize != nRanksPerNode) return;
   if (sizeof(T) == 2) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int);
@@ -142,11 +142,9 @@
   const int localBlockIdx = blockIdx.x % nBlocksPerPeer;
   const int tid = threadIdx.x + localBlockIdx * blockDim.x;
   const int peerIdx = blockIdx.x / nBlocksPerPeer;
-  // Double buffering
-  size_t scratchBaseOffset = (flag & 1) ? 0 : 4 * worldSize * nelems * sizeof(mscclpp::LL8Packet);
   size_t srcOffset = channelDataOffset;
-  size_t scratchOffset = scratchBaseOffset + rank * nelems * sizeof(mscclpp::LL8Packet);
-  void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset);
+  size_t scratchOffset = channelScratchOffset + rank * nelems * sizeof(mscclpp::LL8Packet);
+  void* scratchBuff = (void*)((char*)scratch + channelScratchOffset);
   uint32_t* src = (uint32_t*)((char*)buff);
   uint32_t* dst = (uint32_t*)((char*)resultBuff);
@@ -178,7 +176,8 @@
 template <typename T>
 __global__ void __launch_bounds__(1024, 1)
     allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
-               size_t channelDataOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems, uint32_t flag) {
+               size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
+               size_t nelems, uint32_t flag) {
   // This version of allreduce only works for single nodes
   if (worldSize != nRanksPerNode) return;
   nelems = nelems / (sizeof(int) / sizeof(T));
@@ -192,12 +191,9 @@
   const int peerIdx = blockIdx.x / nBlocksPerPeer;
   const int remoteRank = peerIdx < rank ? peerIdx : peerIdx + 1;
   const int tid = threadIdx.x + localBlockIdx * blockDim.x;
-  // double buffering
-  size_t scratchBaseOffset = (flag & 1) ? 0 : nPkts * sizeof(mscclpp::LL8Packet);
-  void* scratchBuff = (void*)((char*)scratch + scratchBaseOffset);
-  size_t scratchOffset = scratchBaseOffset + rank * nPktsPerRank * sizeof(mscclpp::LL8Packet);
-  size_t scratchResultOffset =
-      (flag & 1) ? 2 * nPkts * sizeof(mscclpp::LL8Packet) : 3 * nPkts * sizeof(mscclpp::LL8Packet);
+  void* scratchBuff = (void*)((char*)scratch + channelScratchOffset);
+  size_t scratchOffset = channelScratchOffset + rank * nPktsPerRank * sizeof(mscclpp::LL8Packet);
+  size_t scratchResultOffset = channelScratchOffset + 2 * nPkts * sizeof(mscclpp::LL8Packet);
   size_t srcOffset = remoteRank * nelemsPerRank * sizeof(int) + channelDataOffset;
   uint32_t* src = (uint32_t*)((char*)buff + rank * nelemsPerRank * sizeof(int));
   uint32_t* dst = (uint32_t*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int));
@@ -246,8 +242,8 @@
 template <typename T>
 __global__ void __launch_bounds__(512, 1)
     allreduce8(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
-               mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelOutDataOffset, int rank,
-               int nRanksPerNode, int worldSize, size_t nelems) {
+               mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelOutDataOffset,
+               size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems) {
   const int nPeer = nRanksPerNode - 1;
   const size_t chanOffset = nPeer * blockIdx.x;
   // assume (nelems * sizeof(T)) is divisible by (16 * worldSize)
@@ -257,7 +253,7 @@ __global__ void __launch_bounds__(512, 1)
   auto smOutChans = smOutChannels + chanOffset;

   int4* buff4 = reinterpret_cast<int4*>(buff);
-  int4* scratch4 = reinterpret_cast<int4*>(scratch);
+  int4* scratch4 = reinterpret_cast<int4*>((char*)scratch + channelScratchOffset);
   int4* resultBuff4 = reinterpret_cast<int4*>(resultBuff);

   // Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4`
@@ -278,6 +274,7 @@ __global__ void __launch_bounds__(512, 1)
   const size_t chunkSizePerRank = nNeededBlocks * nInt4PerChunk;
   const size_t blockOffset = nInt4PerChunk * blockIdx.x;
   const size_t scratchChunkRankOffset = chunkSizePerRank * rank;
+  const size_t scratchBaseOffsetInt4 = channelScratchOffset / sizeof(int4);

   __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> channels[NRANKS_PER_NODE - 1];
   __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> outChannels[NRANKS_PER_NODE - 1];
@@ -301,7 +298,7 @@ __global__ void __launch_bounds__(512, 1)
       const int peerIdx = (i + blockIdx.x) % nPeer;
       const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
       int4 val = buff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock];
-      channels[peerIdx].write(scratchChunkRankOffset + blockOffset + idx, val);
+      channels[peerIdx].write(scratchBaseOffsetInt4 + scratchChunkRankOffset + blockOffset + idx, val);
     }
   }

@@ -338,7 +335,7 @@ __global__ void __launch_bounds__(512, 1)
       const int peerIdx = (i + blockIdx.x) % nPeer;
       const int remoteRank = (peerIdx < rank) ? peerIdx : peerIdx + 1;
       int4 val = buff4[nInt4PerRank * remoteRank + idx + offsetOfThisBlock];
-      channels[peerIdx].write(scratchChunkRankOffset + blockOffset + idx, val);
+      channels[peerIdx].write(scratchBaseOffsetInt4 + scratchChunkRankOffset + blockOffset + idx, val);
     }
   }

@@ -367,29 +364,32 @@
 template <typename T>
 cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
                       mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelInOffset,
-                      size_t channelOutOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems,
-                      cudaStream_t stream) {
+                      size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
+                      size_t nelems, cudaStream_t stream) {
   static uint32_t flag = 1;

   if (sizeof(T) * nelems < worldSize * sizeof(int)) {
     int nBlocks = 7;
     int nThreadsPerBlock = 32;
     allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, channelInOffset,
-                                                                rank, nRanksPerNode, worldSize, nelems, flag++);
+                                                                channelScratchOffset, rank, nRanksPerNode, worldSize,
+                                                                nelems, flag++);
   } else if (sizeof(T) * nelems <= (1 << 20)) {
     int nBlocks = 28;
     int nThreadsPerBlock = 1024;
     if (nelems >= 8192) {
       nBlocks = 56;
       nThreadsPerBlock = (nelems <= 76800) ? 512 : 1024;
     }
-    allreduce7<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, channelInOffset, rank,
-                                                         nRanksPerNode, worldSize, nelems, flag++);
+    allreduce7<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, channelInOffset,
+                                                         channelScratchOffset, rank, nRanksPerNode, worldSize, nelems,
+                                                         flag++);
   } else {
     int nBlocks = 35;
     int nThreadsPerBlock = 512;
     allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, smOutChannels,
-                                                         channelOutOffset, rank, nRanksPerNode, worldSize, nelems);
+                                                         channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
+                                                         worldSize, nelems);
   }

   return cudaGetLastError();
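
One subtlety in the allreduce8 changes above: channelScratchOffset is a byte offset, while channels[peerIdx].write() appears to index in int4 elements, hence the precomputed scratchBaseOffsetInt4 = channelScratchOffset / sizeof(int4). A small sketch of that unit conversion, assuming (as the kernel's 16-byte accesses already do) that the offset is int4-aligned:

#include <cassert>
#include <cstddef>

struct Int4 { int x, y, z, w; };  // stand-in for CUDA's built-in int4 (16 bytes)

// Convert a byte offset into the element-index space used by the channel
// writes; the assert documents the 16-byte alignment the kernel relies on.
size_t toInt4Index(size_t byteOffset) {
  assert(byteOffset % sizeof(Int4) == 0);
  return byteOffset / sizeof(Int4);
}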
4 changes: 2 additions & 2 deletions apps/nccl/src/common.hpp
@@ -12,6 +12,6 @@
 #endif

 constexpr int NRANKS_PER_NODE = 8;
-constexpr int SCRATCH_SIZE = 1024 * 1024 * 70;  // 35 thread-blocks * 8 ranks * 256KB = 70MB
+constexpr int SCRATCH_SIZE = 2 * 1024 * 1024 * 70;  // double buffer * 35 thread-blocks * 8 ranks * 256KB = 70MB

-#endif  // NCCL_COMMON_HPP_
+#endif  // NCCL_COMMON_HPP_
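
For the record, the doubled constant totals 140MB; the 70MB in the trailing comment is the size of one half (35 thread-blocks * 8 ranks * 256KB). A compile-time check of that arithmetic:

static_assert(35 * 8 * 256 * 1024 == 70 * 1024 * 1024, "one scratch half: 35 blocks * 8 ranks * 256KB = 70MB");
static_assert(2 * 1024 * 1024 * 70 == 140 * 1024 * 1024, "SCRATCH_SIZE covers two 70MB halves, 140MB total");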