Fix & improve perf for ROCm (#232)
Co-authored-by: Binyang Li <[email protected]>
chhwang and Binyang2014 authored Dec 18, 2023
1 parent 5a9998b, commit 5ff8bc5

Showing 9 changed files with 115 additions and 15 deletions.
CMakeLists.txt (3 changes: 2 additions & 1 deletion)

@@ -108,14 +108,15 @@ include(${PROJECT_SOURCE_DIR}/cmake/AddFormatTargets.cmake)
 # Find ibverbs and libnuma
 find_package(IBVerbs REQUIRED)
 find_package(NUMA REQUIRED)
+find_package(Threads REQUIRED)

 add_library(mscclpp_obj OBJECT)
 target_include_directories(mscclpp_obj
     PRIVATE
     ${GPU_INCLUDE_DIRS}
     ${IBVERBS_INCLUDE_DIRS}
     ${NUMA_INCLUDE_DIRS})
-target_link_libraries(mscclpp_obj PRIVATE ${GPU_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES})
+target_link_libraries(mscclpp_obj PRIVATE ${GPU_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES} Threads::Threads)
 set_target_properties(mscclpp_obj PROPERTIES LINKER_LANGUAGE CXX POSITION_INDEPENDENT_CODE 1 VERSION ${MSCCLPP_VERSION} SOVERSION ${MSCCLPP_SOVERSION})
 if(USE_CUDA)
   target_compile_definitions(mscclpp_obj PRIVATE USE_CUDA)
include/mscclpp/device.hpp (10 changes: 5 additions & 5 deletions)

@@ -4,20 +4,20 @@
 #ifndef MSCCLPP_DEVICE_HPP_
 #define MSCCLPP_DEVICE_HPP_

-#if defined(__HIP_PLATFORM_AMD__)
+#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
 #include <hip/hip_runtime.h>
-#endif // defined(__HIP_PLATFORM_AMD__)
+#endif // defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)

 #if (defined(__NVCC__) || defined(__HIP_PLATFORM_AMD__))

 #define MSCCLPP_DEVICE_COMPILE
 #define MSCCLPP_DEVICE_INLINE __forceinline__ __device__
 #define MSCCLPP_HOST_DEVICE_INLINE __forceinline__ __host__ __device__
-#if defined(__HIP_PLATFORM_AMD__)
+#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
 #define MSCCLPP_DEVICE_HIP
-#else // !defined(__HIP_PLATFORM_AMD__)
+#else // !(defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1))
 #define MSCCLPP_DEVICE_CUDA
-#endif // !defined(__HIP_PLATFORM_AMD__)
+#endif // !(defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1))

 #else // !(defined(__NVCC__) || defined(__HIP_PLATFORM_AMD__))

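Why require the value 1 rather than mere definedness: #if defined(X) is also true when a build defines the macro as 0, so the old guard could take the HIP path on a non-AMD build. A standalone sketch of the difference (hypothetical flags, not part of this commit):

// Compiled with -D__HIP_PLATFORM_AMD__=0 (a hypothetical misconfiguration):
#if defined(__HIP_PLATFORM_AMD__)
// old guard: this branch would still be taken
#endif
#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
// new guard: taken only when the macro is defined with value 1, which is
// how AMD toolchains set it (an assumption consistent with the new guard)
#endif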
include/mscclpp/fifo_device.hpp (6 changes: 3 additions & 3 deletions)

@@ -70,9 +70,9 @@ struct FifoDeviceHandle {
 #if defined(MSCCLPP_DEVICE_CUDA)
   asm volatile("st.global.relaxed.sys.v2.u64 [%0], {%1,%2};" ::"l"(triggerPtr), "l"(trigger.fst), "l"(trigger.snd));
 #else // !defined(MSCCLPP_DEVICE_CUDA)
-  // TODO: both atomic and clang built-ins are buggy here
-  triggerPtr->fst = trigger.fst;
-  triggerPtr->snd = trigger.snd;
+  // store snd no later than fst.
+  atomicStore(&(triggerPtr->snd), trigger.snd, memoryOrderRelaxed);
+  atomicStore(&(triggerPtr->fst), trigger.fst, memoryOrderRelaxed);
 #endif // !defined(MSCCLPP_DEVICE_CUDA)

   return curFifoHead;
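For context: the trigger written here is consumed by the host-side proxy, which polls the FIFO slot. A minimal consumer-side sketch, assuming the ProxyTrigger layout used above (fst/snd fields), the same atomicLoad/memoryOrderRelaxed helpers being usable from host code, and a slot that is zeroed after consumption; the function name and signature are illustrative, not mscclpp's API:

// Spin until fst becomes nonzero, then read snd. The new store order above
// (snd no later than fst, per the commit's comment) is what makes polling
// fst alone sufficient to detect a complete trigger.
ProxyTrigger pollTrigger(ProxyTrigger* slot) {
  ProxyTrigger t;
  do {
    t.fst = atomicLoad(&(slot->fst), memoryOrderRelaxed);
  } while (t.fst == 0);
  t.snd = atomicLoad(&(slot->snd), memoryOrderRelaxed);
  return t;
}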
include/mscclpp/gpu.hpp (2 changes: 1 addition & 1 deletion)

@@ -4,7 +4,7 @@
 #ifndef MSCCLPP_GPU_HPP_
 #define MSCCLPP_GPU_HPP_

-#if defined(__HIP_PLATFORM_AMD__)
+#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)

 #include <hip/hip_runtime.h>

include/mscclpp/packet_device.hpp (2 changes: 0 additions & 2 deletions)

@@ -41,7 +41,6 @@ union alignas(16) LLPacket {
 #else // !defined(MSCCLPP_DEVICE_CUDA)
     uint4 reg = make_uint4(val1, flag, val2, flag);
     ulonglong2* p = reinterpret_cast<ulonglong2*>(&reg);
-    // TODO: clang built-ins are buggy here
     atomicStore(&(raw_.x), p->x, memoryOrderRelaxed);
     atomicStore(&(raw_.y), p->y, memoryOrderRelaxed);
 #endif

@@ -65,7 +64,6 @@ union alignas(16) LLPacket {
     return (flag1 != flag) || (flag2 != flag);
 #else // !defined(MSCCLPP_DEVICE_CUDA)
     ulonglong2 reg;
-    // TODO: clang built-ins are buggy here
     reg.x = atomicLoad(&(raw_.x), memoryOrderRelaxed);
     reg.y = atomicLoad(&(raw_.y), memoryOrderRelaxed);
     uint4* ptr = reinterpret_cast<uint4*>(&reg);
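These atomics implement the usual low-latency (LL) packet trick: a 16-byte packet interleaves two 4-byte data words with two 4-byte flag words, and the reader re-reads until both flags carry the expected value, so a torn or stale packet is never consumed. A hedged sketch of the full read loop these fragments belong to (a free function with illustrative names; the real code is a member of LLPacket):

__device__ uint2 readLLPacket(LLPacket* p, uint32_t flag) {
  // Layout per the write above: {val1, flag, val2, flag}.
  uint4 data;
  do {
    ulonglong2 reg;
    reg.x = atomicLoad(&(p->raw_.x), memoryOrderRelaxed);
    reg.y = atomicLoad(&(p->raw_.y), memoryOrderRelaxed);
    data = *reinterpret_cast<uint4*>(&reg);
  } while (data.y != flag || data.w != flag);
  return make_uint2(data.x, data.z);  // the two payload words
}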
include/mscclpp/utils.hpp (2 changes: 1 addition & 1 deletion)

@@ -17,7 +17,7 @@ struct Timer {

   ~Timer();

-  /// Returns the elapsed time in milliseconds.
+  /// Returns the elapsed time in microseconds.
   int64_t elapsed() const;

   void set(int timeout);
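This change is documentation only: the diff touches just the comment, so elapsed() evidently returned microseconds all along and the old comment misstated the unit. A hedged usage sketch (assuming the timer starts counting at construction):

mscclpp::Timer timer;          // assumption: starts timing when constructed
doWork();                      // hypothetical workload under measurement
int64_t us = timer.elapsed();  // microseconds, not milliseconds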
test/CMakeLists.txt (2 changes: 1 addition & 1 deletion)

@@ -3,7 +3,7 @@

 find_package(MPI)

-set(TEST_LIBS_COMMON mscclpp ${GPU_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES})
+set(TEST_LIBS_COMMON mscclpp ${GPU_LIBRARIES} ${NUMA_LIBRARIES} ${IBVERBS_LIBRARIES} Threads::Threads)
 set(TEST_LIBS_GTEST GTest::gtest_main GTest::gmock_main)
 set(TEST_INC_COMMON PRIVATE ${PROJECT_SOURCE_DIR}/include ${GPU_INCLUDE_DIRS})
 set(TEST_INC_INTERNAL PRIVATE ${PROJECT_SOURCE_DIR}/src/include)
test/allgather_test_cpp.cu (2 changes: 1 addition & 1 deletion)

@@ -74,7 +74,7 @@ __device__ void localAllGather(DeviceHandle<mscclpp::SimpleProxyChannel> proxyCh
     if ((remoteRank % nranksPerNode) == ((rank - i + nranksPerNode) % nranksPerNode)) {
       if ((threadIdx.x % 32) == 0) proxyChan.wait();
     }
-#if defined(__HIP_PLATFORM_AMD__)
+#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
     // NOTE: we actually need a group barrier here for better performance, but __syncthreads() is still correct.
     __syncthreads();
 #else
test/mscclpp-test/allreduce_test.cu (101 changes: 101 additions & 0 deletions)
@@ -371,6 +371,68 @@ __device__ void localReduceScatterSm2(int* buff, int rank, int nRanksPerNode, si
   }
 }

+__device__ void localReduceScatterSm3(int* buff, int rank, int nRanksPerNode, size_t chunkSize, size_t nelems,
+                                      int nBlocks) {
+  if (nRanksPerNode == 1) return;
+  if ((int)blockIdx.x >= nBlocks) return;
+  const int nPeer = nRanksPerNode - 1;
+  DeviceHandle<mscclpp::SmChannel>* smChans = constSmOutOfPlaceGetChans;
+
+  const size_t localRankIndexInNode = rank % nRanksPerNode;
+  const size_t indexOffset = localRankIndexInNode * chunkSize;
+  const size_t indexOffset4 = indexOffset / 4;
+
+  int4* buff4 = (int4*)buff;
+
+  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid < nPeer) {
+    smChans[tid].signal();
+  }
+  const int waitStart = nBlocks * blockDim.x - nPeer;
+  if (tid >= waitStart && tid < (int)(nBlocks * blockDim.x)) {
+    smChans[tid - waitStart].wait();
+  }
+  reduceScatterDeviceSyncer.sync(nBlocks);
+
+  const size_t nInt4 = nelems / 4;
+
+  size_t base = 0;
+  const size_t unitNInt4 = blockDim.x * nBlocks;
+  for (; base + unitNInt4 < nInt4; base += unitNInt4) {
+    for (int index = 0; index < nPeer; ++index) {
+      int4 val;
+      int peerIdx = (index + localRankIndexInNode) % nPeer;
+      for (size_t idx = base + threadIdx.x + blockIdx.x * blockDim.x; idx < base + unitNInt4;
+           idx += blockDim.x * nBlocks) {
+        val = smChans[peerIdx].read<int4>(indexOffset4 + idx);
+        buff4[indexOffset4 + idx].w += val.w;
+        buff4[indexOffset4 + idx].x += val.x;
+        buff4[indexOffset4 + idx].y += val.y;
+        buff4[indexOffset4 + idx].z += val.z;
+      }
+    }
+  }
+  for (int index = 0; index < nPeer; ++index) {
+    int4 val;
+    int peerIdx = (index + localRankIndexInNode) % nPeer;
+    for (size_t idx = base + threadIdx.x + blockIdx.x * blockDim.x; idx < nInt4; idx += blockDim.x * nBlocks) {
+      val = smChans[peerIdx].read<int4>(indexOffset4 + idx);
+      buff4[indexOffset4 + idx].w += val.w;
+      buff4[indexOffset4 + idx].x += val.x;
+      buff4[indexOffset4 + idx].y += val.y;
+      buff4[indexOffset4 + idx].z += val.z;
+    }
+  }
+
+  const size_t nLastInts = nelems % 4;
+  for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
+    for (size_t idx = threadIdx.x + blockIdx.x * blockDim.x; idx < nLastInts; idx += blockDim.x * nBlocks) {
+      int val = smChans[(localRankIndexInNode + peerIdx) % nPeer].read<int>(indexOffset + nInt4 * 4 + idx);
+      buff[indexOffset + nInt4 * 4 + idx] += val;
+    }
+  }
+}
+
 __device__ void reduceScatterSm(int* buff, int* scratch, int rank, int nRanksPerNode, int worldSize,
                                 size_t nelems  // must be divisible by 3
 ) {
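The prologue of localReduceScatterSm3 is a grid-wide handshake: the first nPeer threads of the whole grid signal one peer each, the last nPeer threads wait on one peer each, and the device syncer then publishes the arrival to every block before any remote reads begin. A condensed sketch of just that pattern, assuming the test's channel and syncer types (the helper name is ours, not the test's):

__device__ void gridHandshake(DeviceHandle<mscclpp::SmChannel>* chans, int nPeer, int nBlocks) {
  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  if (tid < nPeer) chans[tid].signal();  // first nPeer threads signal
  const int waitStart = nBlocks * blockDim.x - nPeer;
  if (tid >= waitStart && tid < nBlocks * blockDim.x) {
    chans[tid - waitStart].wait();       // last nPeer threads wait
  }
  // The caller follows with a DeviceSyncer::sync(nBlocks) so every block
  // observes that all peers have arrived.
}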
@@ -520,6 +582,39 @@ __device__ void localRingAllGatherSm(int rank, int nRanksPerNode, uint64_t size,
   }
 }

+__device__ void localRingAllGatherSm2(size_t rank, size_t nRanksPerNode, size_t size, size_t nBlocks) {
+  if (nRanksPerNode == 1) return;
+  if (blockIdx.x >= nBlocks) return;
+
+  size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
+  const size_t nPeer = nRanksPerNode - 1;
+
+  if (tid < nPeer) {
+    constSmInPlaceChans[tid].signal();
+  }
+  size_t waitStart = nBlocks * blockDim.x - nPeer;
+  if (tid >= waitStart && tid < nBlocks * blockDim.x) {
+    constSmInPlaceChans[tid - waitStart].wait();
+  }
+  allGatherDeviceSyncer.sync(nBlocks);
+  const size_t unitSize = 16 * blockDim.x * nBlocks;
+  size_t base = 0;
+  for (; base + unitSize < size; base += unitSize) {
+    for (size_t i = 0; i < nPeer; ++i) {
+      size_t peerIdx = (i + rank) % nPeer;
+      const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
+      size_t offset = size * remoteRankLocalIndex + base;
+      constSmInPlaceChans[peerIdx].get(offset, unitSize, tid, blockDim.x * nBlocks);
+    }
+  }
+  for (size_t i = 0; i < nPeer; ++i) {
+    size_t peerIdx = (i + rank) % nPeer;
+    const size_t remoteRankLocalIndex = (peerIdx < rank ? peerIdx : peerIdx + 1);
+    size_t offset = size * remoteRankLocalIndex + base;
+    constSmInPlaceChans[peerIdx].get(offset, size - base, tid, blockDim.x * nBlocks);
+  }
+}
+
 // This is an allgather4 equivalent
 __device__ void allGatherSm(int rank, int worldSize, int nRanksPerNode, size_t nelemsPerGPU) {
   // this allgather is a pipelined and hierarchical one and only works for two nodes
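Both new functions index channels by peer, not by rank: the per-rank channel array excludes the local rank, so channel indices at or above it map to a rank one higher. A worked example of that mapping (the helper is illustrative):

__host__ __device__ inline size_t remoteRankOf(size_t peerIdx, size_t rank) {
  // Indices below `rank` are the rank itself; indices at or above it shift up by one.
  return peerIdx < rank ? peerIdx : peerIdx + 1;
}
// e.g. nRanksPerNode = 4, rank = 2: channels {0, 1, 2} map to peer ranks {0, 1, 3},
// so the offsets computed from it land on each peer's own chunk of the buffer.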
@@ -861,9 +956,15 @@ __global__ void allreduce4(int* buff, int* scratch, int rank, int nRanksPerNode,
 }

 __global__ void allreduce5(int* buff, int rank, int nRanksPerNode, int worldSize, size_t nelems) {
+#if defined(__HIP_PLATFORM_AMD__) && (__HIP_PLATFORM_AMD__ == 1)
+  localReduceScatterSm3(buff, rank, nRanksPerNode, nelems / worldSize, nelems / worldSize, gridDim.x);
+  deviceSyncer.sync(gridDim.x);
+  localRingAllGatherSm2(rank, nRanksPerNode, nelems / worldSize * sizeof(int), gridDim.x);
+#else
   localReduceScatterSm2(buff, rank, nRanksPerNode, nelems / worldSize, nelems / worldSize, gridDim.x);
   deviceSyncer.sync(gridDim.x);
   localRingAllGatherSm(rank, nRanksPerNode, nelems / worldSize * sizeof(int), gridDim.x);
+#endif
 }

 __global__ void allreduce6(int* buff, int* scratch, void* resultBuff, int rank, int nRanksPerNode, int worldSize,
