Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Renaming channels #436

Merged
merged 48 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
75a2ac5
Tackle build warnings
chhwang Dec 19, 2024
b8a1360
Add `gpuMemAlloc` function
chhwang Dec 19, 2024
1d1d07d
Merge branch 'main' into chhwang/malloc
chhwang Dec 20, 2024
78da4a4
Refine `gpu_utils.hpp`
chhwang Dec 20, 2024
fbd0f6c
Fixes
chhwang Jan 2, 2025
08815b9
Merge branch 'main' into chhwang/malloc
chhwang Jan 2, 2025
434898c
fix
chhwang Jan 2, 2025
35b4598
Fix a test
chhwang Jan 2, 2025
24d883e
Fix npkit
chhwang Jan 2, 2025
0862717
Merge branch 'main' into chhwang/malloc
chhwang Jan 3, 2025
a99522d
Python interface
chhwang Jan 3, 2025
fcf8392
Merge branch 'main' into chhwang/malloc
chhwang Jan 3, 2025
11c2bf9
lint
chhwang Jan 3, 2025
4f80f7c
Fix names & make uncached alloc available only on AMD
chhwang Jan 3, 2025
a5c3653
lint
chhwang Jan 3, 2025
1f72e94
SmChannel to MemoryChannel
chhwang Jan 3, 2025
afc9d20
Fix NVLS memory allocation
chhwang Jan 4, 2025
6a5ce05
Merge branch 'main' into chhwang/malloc
chhwang Jan 4, 2025
3abd219
lint
chhwang Jan 4, 2025
ec4452a
Merge branch 'main' into chhwang/malloc
chhwang Jan 5, 2025
b2a17cd
Merge branch 'main' into chhwang/rename-channels
chhwang Jan 5, 2025
ea0c4a3
pipeline fix
chhwang Jan 4, 2025
3f9c653
C++ class
chhwang Jan 6, 2025
35fa922
Add `gpuMemAlloc` function
chhwang Dec 19, 2024
6fb486b
Refine `gpu_utils.hpp`
chhwang Dec 20, 2024
a78e1a8
Fixes
chhwang Jan 2, 2025
7228d13
fix
chhwang Jan 2, 2025
6ac313d
Fix a test
chhwang Jan 2, 2025
18259c9
Fix npkit
chhwang Jan 2, 2025
d0e9f49
Python interface
chhwang Jan 3, 2025
1d425b9
lint
chhwang Jan 3, 2025
12a0ac4
Fix names & make uncached alloc available only on AMD
chhwang Jan 3, 2025
27e7d96
lint
chhwang Jan 3, 2025
e03ba29
Fix NVLS memory allocation
chhwang Jan 4, 2025
3885ee0
pipeline fix
chhwang Jan 4, 2025
ca06267
C++ class
chhwang Jan 6, 2025
5da9618
more changes
chhwang Jan 6, 2025
8be9d6e
Merge branch 'chhwang/malloc' into chhwang/rename-channels
chhwang Jan 6, 2025
3305264
lint
chhwang Jan 6, 2025
22edc9b
ProxyChannel to PortChannel
chhwang Jan 6, 2025
38b27c5
Fix
chhwang Jan 6, 2025
22c3a9b
Merge branch 'chhwang/malloc' into chhwang/rename-channels
chhwang Jan 6, 2025
6c0ce7a
Merge branch 'main' into chhwang/rename-channels
chhwang Jan 23, 2025
b995bc0
Merge branch 'main' into chhwang/rename-channels
chhwang Jan 23, 2025
204bbd4
more updates
chhwang Jan 23, 2025
66c63eb
fix lint
chhwang Jan 23, 2025
8ccb3a2
fixes
chhwang Jan 23, 2025
79c27f8
BC
chhwang Jan 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ The following highlights key concepts of MSCCL++.
MSCCL++ provides peer-to-peer communication methods between GPUs. A peer-to-peer connection between two GPUs is called a *Channel*. Channels are constructed by MSCCL++ host-side interfaces and copied to GPUs during initialization. Channels provide *GPU-side interfaces*, which means that all communication methods are defined as a device function to be called from a GPU kernel code. For example, the `put()` method in the following example copies 1KB data from the local GPU to a remote GPU.

```cpp
// `ProxyChannel` will be explained in the following section.
__device__ mscclpp::DeviceHandle<mscclpp::ProxyChannel> channel;
// `PortChannel` will be explained in the following section.
__device__ mscclpp::DeviceHandle<mscclpp::PortChannel> channel;
__global__ void gpuKernel() {
...
// Only one thread is needed for this method.
Expand Down Expand Up @@ -79,15 +79,15 @@ __device__ void barrier() {

MSCCL++ provides consistent interfaces, i.e., the above interfaces are used regardless of the location of the remote GPU (either on the local node or on a remote node) or the underlying link (either NVLink/xGMI or InfiniBand).

### ProxyChannel and SmChannel
### PortChannel and MemoryChannel

MSCCL++ delivers two types of channels, **ProxyChannel** and **SmChannel**. `ProxyChannel` provides (R)DMA-based data copy and synchronization methods. When called, these methods send/receive a signal to/from a host-side proxy (hence the name `ProxyChannel`), which will trigger (R)DMA (such as `cudaMemcpy*` or `ibv_post_send`) or issue synchronization methods (such as `cudaStreamSynchronize` or `ibv_poll_cq`). Since the key functionalities are run by the proxy, `ProxyChannel` requires only a single GPU thread to call its methods. See all `ProxyChannel` methods from [here](./include/mscclpp/proxy_channel_device.hpp).
MSCCL++ delivers two types of channels, **PortChannel** and **MemoryChannel**. `PortChannel` provides port-mapping-based data copy and synchronization methods. When called, these methods send/receive a signal to/from a host-side proxy, which will trigger (R)DMA (such as `cudaMemcpy*` or `ibv_post_send`) or issue synchronization methods (such as `cudaStreamSynchronize` or `ibv_poll_cq`). Since the key functionalities are run by the proxy, `PortChannel` requires only a single GPU thread to call its methods. See all `PortChannel` methods from [here](./include/mscclpp/port_channel_device.hpp).

On the other hand, `SmChannel` provides memory-mapping-based copy and synchronization methods. When called, these methods will directly use GPU threads to read/write from/to the remote GPU's memory space. Comparing against `ProxyChannel`, `SmChannel` is especially performant for low-latency scenarios, while it may need many GPU threads to call copying methods at the same time to achieve high copying bandwidth. See all `SmChannel` methods from [here](./include/mscclpp/sm_channel_device.hpp).
On the other hand, `MemoryChannel` provides memory-mapping-based copy and synchronization methods. When called, these methods will directly use GPU threads to read/write from/to the remote GPU's memory space. Comparing against `PortChannel`, `MemoryChannel` is especially performant for low-latency scenarios, while it may need many GPU threads to call copying methods at the same time to achieve high copying bandwidth. See all `MemoryChannel` methods from [here](./include/mscclpp/memory_channel_device.hpp).

### Host-Side Communication Proxy

MSCCL++ provides a default implementation of a host-side proxy for ProxyChannels, which is a background host thread that busy polls triggers from GPUs and conducts functionalities accordingly. For example, the following is a typical host-side code for MSCCL++.
MSCCL++ provides a default implementation of a host-side proxy for PortChannels, which is a background host thread that busy polls triggers from GPUs and conducts functionalities accordingly. For example, the following is a typical host-side code for MSCCL++.

```cpp
// Bootstrap: initialize control-plane connections between all ranks
Expand Down
62 changes: 31 additions & 31 deletions apps/nccl/src/allgather.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/gpu.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/sm_channel_device.hpp>
#include <mscclpp/memory_channel.hpp>
#include <mscclpp/memory_channel_device.hpp>

#include "common.hpp"

template <bool IsOutOfPlace>
__global__ void __launch_bounds__(1024, 1)
allgather6(void* sendbuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels, size_t channelOutOffset,
allgather6(void* sendbuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels, size_t channelOutOffset,
size_t rank, [[maybe_unused]] size_t worldSize, size_t nRanksPerNode, size_t nelemsPerGPU) {
const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
const size_t lid = tid % WARP_SIZE;
Expand All @@ -24,11 +24,11 @@ __global__ void __launch_bounds__(1024, 1)
const size_t nWarp = nThread / WARP_SIZE;
const size_t nPeer = nRanksPerNode - 1;
const size_t chanOffset = nPeer * blockIdx.x;
auto smChans = smChannels + chanOffset;
auto memChans = memoryChannels + chanOffset;

if (threadIdx.x < nPeer) {
smChans[threadIdx.x].relaxedSignal();
smChans[threadIdx.x].wait();
memChans[threadIdx.x].relaxedSignal();
memChans[threadIdx.x].wait();
}
__syncthreads();

Expand All @@ -49,16 +49,16 @@ __global__ void __launch_bounds__(1024, 1)
const size_t peerIdx = wid % nPeer;
const size_t offset = bytesPerGPU * rank + (wid / nPeer) * unitBytesPerWarp;
if constexpr (IsOutOfPlace) {
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);
char* src = reinterpret_cast<char*>(smChans[peerIdx].src_);
char* dst = reinterpret_cast<char*>(memChans[peerIdx].dst_);
char* src = reinterpret_cast<char*>(memChans[peerIdx].src_);
char* buff = reinterpret_cast<char*>(sendbuff);
const size_t offsetWithinRank = (wid / nPeer) * unitBytesPerWarp;
smChans[peerIdx].copy<16, false>(src + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid,
WARP_SIZE);
smChans[peerIdx].copy<16, false>(dst + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid,
WARP_SIZE);
memChans[peerIdx].copy<16, false>(src + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid,
WARP_SIZE);
memChans[peerIdx].copy<16, false>(dst + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid,
WARP_SIZE);
} else {
smChans[peerIdx].put<16, false>(offset + channelOutOffset, unitBytesPerWarp, lid, WARP_SIZE);
memChans[peerIdx].put<16, false>(offset + channelOutOffset, unitBytesPerWarp, lid, WARP_SIZE);
}
}

Expand All @@ -67,16 +67,16 @@ __global__ void __launch_bounds__(1024, 1)
const size_t peerIdx = gWid % nPeer;
const size_t offset = bytesPerGPU * rank + (gWid / nPeer) * unitBytesPerWarp;
if constexpr (IsOutOfPlace) {
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);
char* src = reinterpret_cast<char*>(smChans[peerIdx].src_);
char* dst = reinterpret_cast<char*>(memChans[peerIdx].dst_);
char* src = reinterpret_cast<char*>(memChans[peerIdx].src_);
char* buff = reinterpret_cast<char*>(sendbuff);
const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp;
smChans[peerIdx].copy<16, false>(src + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid,
WARP_SIZE);
smChans[peerIdx].copy<16, false>(dst + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid,
WARP_SIZE);
memChans[peerIdx].copy<16, false>(src + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid,
WARP_SIZE);
memChans[peerIdx].copy<16, false>(dst + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid,
WARP_SIZE);
} else {
smChans[peerIdx].put<16, false>(offset + channelOutOffset, unitBytesPerWarp, lid, WARP_SIZE);
memChans[peerIdx].put<16, false>(offset + channelOutOffset, unitBytesPerWarp, lid, WARP_SIZE);
}
}

Expand All @@ -90,30 +90,30 @@ __global__ void __launch_bounds__(1024, 1)
: unitBytesPerWarp;
if (remainBytes > 0) {
if constexpr (IsOutOfPlace) {
char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);
char* src = reinterpret_cast<char*>(smChans[peerIdx].src_);
char* dst = reinterpret_cast<char*>(memChans[peerIdx].dst_);
char* src = reinterpret_cast<char*>(memChans[peerIdx].src_);
char* buff = reinterpret_cast<char*>(sendbuff);
smChans[peerIdx].copy<16, true>(src + offset + channelOutOffset, buff + offsetWithinRank, remainBytes, lid,
WARP_SIZE);
smChans[peerIdx].copy<16, true>(dst + offset + channelOutOffset, buff + offsetWithinRank, remainBytes, lid,
WARP_SIZE);
memChans[peerIdx].copy<16, true>(src + offset + channelOutOffset, buff + offsetWithinRank, remainBytes, lid,
WARP_SIZE);
memChans[peerIdx].copy<16, true>(dst + offset + channelOutOffset, buff + offsetWithinRank, remainBytes, lid,
WARP_SIZE);
} else {
smChans[peerIdx].put<16, true>(offset + channelOutOffset, remainBytes, lid, WARP_SIZE);
memChans[peerIdx].put<16, true>(offset + channelOutOffset, remainBytes, lid, WARP_SIZE);
}
}
}

deviceSyncer.sync(gridDim.x);

if (threadIdx.x < nPeer) {
smChans[threadIdx.x].relaxedSignal();
smChans[threadIdx.x].wait();
memChans[threadIdx.x].relaxedSignal();
memChans[threadIdx.x].wait();
}
}

template <bool IsOutOfPlace, typename T>
cudaError_t allgather(T* buff, [[maybe_unused]] T* scratch, [[maybe_unused]] T* resultBuff,
mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels, size_t channelOutOffset, int rank,
mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels, size_t channelOutOffset, int rank,
int nRanksPerNode, int worldSize, size_t nelems, cudaStream_t stream) {
int nBlocks = 28;
if (nelems <= 4096) {
Expand All @@ -123,7 +123,7 @@ cudaError_t allgather(T* buff, [[maybe_unused]] T* scratch, [[maybe_unused]] T*
} else if (nelems >= 2097152) {
nBlocks = 35;
}
allgather6<IsOutOfPlace><<<nBlocks, 1024, 0, stream>>>((void*)buff, smChannels, channelOutOffset, rank, worldSize,
allgather6<IsOutOfPlace><<<nBlocks, 1024, 0, stream>>>((void*)buff, memoryChannels, channelOutOffset, rank, worldSize,
nRanksPerNode, nelems * sizeof(T) / sizeof(int));
return cudaGetLastError();
}
Expand Down
50 changes: 25 additions & 25 deletions apps/nccl/src/allreduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
#include <mscclpp/core.hpp>
#include <mscclpp/gpu.hpp>
#include <mscclpp/gpu_data_types.hpp>
#include <mscclpp/memory_channel.hpp>
#include <mscclpp/memory_channel_device.hpp>
#include <mscclpp/packet_device.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/sm_channel_device.hpp>

#if defined(ENABLE_NPKIT)
#include <mscclpp/npkit/npkit.hpp>
Expand Down Expand Up @@ -196,7 +196,7 @@ __forceinline__ __device__ void vectorSum(T* dst, T* src, size_t nElem) {

template <typename T>
__global__ void __launch_bounds__(32, 1)
allreduceAllToAll(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
allreduceAllToAll(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
size_t nelems, uint32_t flag) {
// This version of allreduce only works for single nodes
Expand All @@ -213,10 +213,10 @@ __global__ void __launch_bounds__(32, 1)
uint32_t* src = (uint32_t*)((char*)buff);
uint32_t* dst = (uint32_t*)((char*)resultBuff);

__shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> channels[NRANKS_PER_NODE - 1];
__shared__ mscclpp::DeviceHandle<mscclpp::MemoryChannel> channels[NRANKS_PER_NODE - 1];
const int lid = tid % WARP_SIZE;
if (lid < nPeers) {
channels[lid] = smChannels[lid];
channels[lid] = memoryChannels[lid];
}
__syncwarp();

Expand All @@ -240,7 +240,7 @@ __global__ void __launch_bounds__(32, 1)

template <typename T>
__global__ void __launch_bounds__(1024, 1)
allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
allreduce7(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
size_t channelDataOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
size_t nelems, uint32_t flag
#if defined(ENABLE_NPKIT)
Expand Down Expand Up @@ -304,10 +304,10 @@ __global__ void __launch_bounds__(1024, 1)
uint2* dst = (uint2*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int));

// Put channels into shared memory, read channel info from global memory is unexpectable slow.
__shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> channels[NRANKS_PER_NODE - 1];
__shared__ mscclpp::DeviceHandle<mscclpp::MemoryChannel> channels[NRANKS_PER_NODE - 1];
const int lid = tid % WARP_SIZE;
if (lid < nPeers) {
channels[lid] = smChannels[lid];
channels[lid] = memoryChannels[lid];
}
__syncwarp();

Expand Down Expand Up @@ -361,16 +361,16 @@ __global__ void __launch_bounds__(1024, 1)

template <typename T>
__global__ void __launch_bounds__(512, 1)
allreduce8(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelOutDataOffset,
allreduce8(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryOutChannels, size_t channelOutDataOffset,
size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize, size_t nelems) {
const int nPeer = nRanksPerNode - 1;
const size_t chanOffset = nPeer * blockIdx.x;
// assume (nelems * sizeof(T)) is divisible by (16 * worldSize)
const size_t nInt4 = nelems * sizeof(T) / sizeof(int4);
const size_t nInt4PerRank = nInt4 / worldSize;
auto smChans = smChannels + chanOffset;
auto smOutChans = smOutChannels + chanOffset;
auto memoryChans = memoryChannels + chanOffset;
auto memoryOutChans = memoryOutChannels + chanOffset;

int4* buff4 = reinterpret_cast<int4*>(buff);
int4* scratch4 = reinterpret_cast<int4*>((char*)scratch + channelScratchOffset);
Expand All @@ -396,12 +396,12 @@ __global__ void __launch_bounds__(512, 1)
const size_t scratchChunkRankOffset = chunkSizePerRank * rank;
const size_t scratchBaseOffsetInt4 = channelScratchOffset / sizeof(int4);

__shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> channels[NRANKS_PER_NODE - 1];
__shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> outChannels[NRANKS_PER_NODE - 1];
__shared__ mscclpp::DeviceHandle<mscclpp::MemoryChannel> channels[NRANKS_PER_NODE - 1];
__shared__ mscclpp::DeviceHandle<mscclpp::MemoryChannel> outChannels[NRANKS_PER_NODE - 1];
const int lid = threadIdx.x % WARP_SIZE;
if (lid < nPeer) {
channels[lid] = smChans[lid];
outChannels[lid] = smOutChans[lid];
channels[lid] = memoryChans[lid];
outChannels[lid] = memoryOutChans[lid];
}
__syncwarp();

Expand Down Expand Up @@ -496,18 +496,18 @@ __global__ void __launch_bounds__(512, 1)
}

template <typename T>
cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelInOffset,
cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels,
mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryOutChannels, size_t channelInOffset,
size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
size_t nelems, cudaStream_t stream) {
static uint32_t flag = 1;

if (sizeof(T) * nelems < worldSize * sizeof(int)) {
int nBlocks = 7;
int nThreadsPerBlock = 32;
allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, channelInOffset,
channelScratchOffset, rank, nRanksPerNode, worldSize,
nelems, flag++);
allreduceAllToAll<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, memoryChannels,
channelInOffset, channelScratchOffset, rank,
nRanksPerNode, worldSize, nelems, flag++);
} else if (sizeof(T) * nelems <= (1 << 20)) {
int nBlocks = 28;
int nThreadsPerBlock = 1024;
Expand All @@ -518,17 +518,17 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
#if defined(ENABLE_NPKIT)
size_t NpkitSharedMemSize = NPKIT_SHM_NUM_EVENTS * sizeof(NpKitEvent);
allreduce7<<<nBlocks, nThreadsPerBlock, NpkitSharedMemSize, stream>>>(
buff, scratch, resultBuff, smChannels, channelInOffset, channelScratchOffset, rank, nRanksPerNode, worldSize,
nelems, flag++, NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
buff, scratch, resultBuff, memoryChannels, channelInOffset, channelScratchOffset, rank, nRanksPerNode,
worldSize, nelems, flag++, NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
#else
allreduce7<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, channelInOffset,
allreduce7<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, memoryChannels, channelInOffset,
channelScratchOffset, rank, nRanksPerNode, worldSize, nelems,
flag++);
#endif
} else {
int nBlocks = 35;
int nThreadsPerBlock = 512;
allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, smOutChannels,
allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, memoryChannels, memoryOutChannels,
channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
worldSize, nelems);
}
Expand Down
Loading
Loading