Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Simple protocol plans using scratch buffer #371

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
24 changes: 11 additions & 13 deletions src/executor/execution_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ namespace mscclpp {

template <typename PacketType>
void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, void* src, void* dst, void* scratch,
size_t scratchSize, DataType dataType, DeviceExecutionPlan* plan,
size_t sharedMemSize, cudaStream_t stream, uint32_t flag) {
DataType dataType, DeviceExecutionPlan* plan, size_t sharedMemSize,
cudaStream_t stream, uint32_t flag) {
switch (dataType) {
case DataType::INT32:
executionKernel<int32_t, PacketType><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, scratchSize, plan, flag
rank, (int32_t*)src, (int32_t*)dst, (int32_t*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
Expand All @@ -23,7 +23,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
break;
case DataType::UINT32:
executionKernel<uint32_t><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, scratchSize, plan, flag
rank, (uint32_t*)src, (uint32_t*)dst, (uint32_t*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
Expand All @@ -33,7 +33,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
break;
case DataType::FLOAT16:
executionKernel<half><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (half*)src, (half*)dst, (half*)scratch, scratchSize, plan, flag
rank, (half*)src, (half*)dst, (half*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
Expand All @@ -43,7 +43,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
break;
case DataType::FLOAT32:
executionKernel<float><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (float*)src, (float*)dst, (float*)scratch, scratchSize, plan, flag
rank, (float*)src, (float*)dst, (float*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
Expand All @@ -53,7 +53,7 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
break;
case DataType::BFLOAT16:
executionKernel<__bfloat16><<<nthreadblocks, nthreads, sharedMemSize, stream>>>(
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, scratchSize, plan, flag
rank, (__bfloat16*)src, (__bfloat16*)dst, (__bfloat16*)scratch, plan, flag
#if defined(ENABLE_NPKIT)
,
NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp());
Expand All @@ -65,12 +65,10 @@ void ExecutionKernel::launchKernel(int rank, int nthreadblocks, int nthreads, vo
}

template void ExecutionKernel::launchKernel<LL16Packet>(int rank, int nthreadblocks, int nthreads, void* src, void* dst,
void* scratch, size_t scratchSize, DataType dataType,
DeviceExecutionPlan* plan, size_t sharedMemSize,
cudaStream_t stream, uint32_t flag);
void* scratch, DataType dataType, DeviceExecutionPlan* plan,
size_t sharedMemSize, cudaStream_t stream, uint32_t flag);
template void ExecutionKernel::launchKernel<LL8Packet>(int rank, int nthreadblocks, int nthreads, void* src, void* dst,
void* scratch, size_t scratchSize, DataType dataType,
DeviceExecutionPlan* plan, size_t sharedMemSize,
cudaStream_t stream, uint32_t flag);
void* scratch, DataType dataType, DeviceExecutionPlan* plan,
size_t sharedMemSize, cudaStream_t stream, uint32_t flag);
} // namespace mscclpp
#endif
65 changes: 49 additions & 16 deletions src/executor/execution_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ auto getOpType = [](const std::string& str) {
return mscclpp::OperationType::WAIT;
} else if (str == "flush") {
return mscclpp::OperationType::FLUSH;
} else if (str == "re") {
} else if (str == "reduce") {
return mscclpp::OperationType::REDUCE;
} else if (str == "rs") {
return mscclpp::OperationType::REDUCE_SEND;
Expand Down Expand Up @@ -100,7 +100,7 @@ std::vector<ChannelInfo> ExecutionPlan::Impl::getChannelInfos(int rank, BufferTy
}

std::vector<ChannelInfo> ExecutionPlan::Impl::getChannelInfosByDstRank(int rank, BufferType bufferType) const {
auto pred = [rank, bufferType](const ChannelInfo& info) { return info.dstBufferType == bufferType; };
auto pred = [bufferType](const ChannelInfo& info) { return info.dstBufferType == bufferType; };
return filter(this->channelInfosByDstRank.at(rank), pred);
}

Expand Down Expand Up @@ -148,7 +148,8 @@ std::vector<BufferType> ExecutionPlan::Impl::getConnectedBufferTypes(int rank) c
}
return std::vector<BufferType>(bufferTypes.begin(), bufferTypes.end());
}
size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize, size_t outputSize) const {

void ExecutionPlan::Impl::calcScratchBufferSizeAndOffset(int rank, size_t inputSize, size_t outputSize, int flag) {
size_t sizePerRank;
if (this->inputChunks.at(rank) != 0)
sizePerRank = inputSize / this->inputChunks.at(rank);
Expand All @@ -157,11 +158,18 @@ size_t ExecutionPlan::Impl::getScratchBufferSize(int rank, size_t inputSize, siz
else
throw mscclpp::Error("Output or Input chunks must be greater than 0", mscclpp::ErrorCode::ExecutorError);

this->scratchBufferSize = sizePerRank * this->scratchChunks.at(rank);
this->scratchBufferOffset = (this->isUsingDoubleScratchBuffer && (flag % 2) == 0) ? this->scratchBufferSize : 0;
if (this->isUsingPacket) {
return sizePerRank * this->scratchChunks.at(rank) * 2 /* data + flag*/ * 2 /*double buffer*/;
this->scratchBufferSize *= 2; /* data + flag */
}
if (this->isUsingDoubleScratchBuffer) {
this->scratchBufferSize *= 2; /* double buffer */
}
return sizePerRank * this->scratchChunks.at(rank);
}

size_t ExecutionPlan::Impl::getScratchBufferSize() const { return this->scratchBufferSize; }

std::vector<Operation> ExecutionPlan::Impl::getOperations(int rank, int threadblock) const {
return this->operations.at(rank)[threadblock];
}
Expand All @@ -170,8 +178,9 @@ int ExecutionPlan::Impl::getThreadblockCount(int rank) const { return this->oper

int ExecutionPlan::Impl::getNThreadsPerBlock() const { return this->nThreadsPerBlock; }

void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset,
size_t constDstOffset) {
void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize, size_t constSrcOffset,
size_t constDstOffset, int selfRank, size_t inputBufferSize,
size_t outputBufferSize, int flag) {
std::ifstream file(this->planPath);
json obj = json::parse(file);
if (this->name != obj["name"]) {
Expand All @@ -182,6 +191,7 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize,
this->isUsingPacket = true;
}
this->nThreadsPerBlock = obj.value("num_threads_per_block", 1024);
this->isUsingDoubleScratchBuffer = obj["use_double_scratch_buffer"];
const auto& gpus = obj["gpus"];

for (const auto& gpu : gpus) {
Expand All @@ -195,11 +205,13 @@ void ExecutionPlan::Impl::loadExecutionPlan(size_t inputSize, size_t outputSize,

this->inputSize = inputSize;
this->outputSize = outputSize;
this->setupOperations(gpus, contsSrcOffset, constDstOffset);
this->calcScratchBufferSizeAndOffset(selfRank, inputBufferSize, outputBufferSize, flag);
this->setupOperations(gpus, constSrcOffset, constDstOffset);
}

void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t outputSize, size_t contsSrcOffset,
size_t constDstOffset) {
void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t outputSize, size_t constSrcOffset,
size_t constDstOffset, int selfRank, size_t inputBufferSize,
size_t outputBufferSize, int flag) {
std::ifstream file(this->planPath);
json obj = json::parse(file);
if (this->name != obj["name"]) {
Expand All @@ -209,6 +221,7 @@ void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t output
if (protocol == "LL") {
this->isUsingPacket = true;
}
this->isUsingDoubleScratchBuffer = obj["use_double_scratch_buffer"];
const auto& gpus = obj["gpus"];

for (const auto& gpu : gpus) {
Expand All @@ -221,7 +234,8 @@ void ExecutionPlan::Impl::lightLoadExecutionPlan(size_t inputSize, size_t output

this->inputSize = inputSize;
this->outputSize = outputSize;
this->setupOperations(gpus, contsSrcOffset, constDstOffset);
this->calcScratchBufferSizeAndOffset(selfRank, inputBufferSize, outputBufferSize, flag);
this->setupOperations(gpus, constSrcOffset, constDstOffset);
}

// Construct the channel info. Step 1. Flatten SM and PROXY channels into separate vectors.
Expand Down Expand Up @@ -291,7 +305,16 @@ void ExecutionPlan::Impl::setupChannels(const json& gpus) {
}
}

void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffset, size_t constDstOffset) {
void ExecutionPlan::Impl::checkChannelsPerOperation(int channels) {
if (channels > MAX_CHANNEL_PER_OPERATION) {
throw Error("Executor plan has " + std::to_string(channels) +
" channels per operation, exceeding executor support (" +
std::to_string(MAX_CHANNEL_PER_OPERATION) + ")",
ErrorCode::ExecutorError);
}
}

void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffset, size_t constDstOffset) {
// setup threadblocks and operations
for (const auto& gpu : gpus) {
int rank = gpu["id"];
Expand All @@ -318,6 +341,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
}
if (op.contains("i_cids")) {
operation.nInputs = op["i_cids"].size();
checkChannelsPerOperation(operation.nInputs);
for (int i = 0; i < operation.nInputs; i++) {
BufferType srcBufferType = convertToBufferType(op["i_buff"]["src"]);
BufferType dstBufferType = convertToBufferType(op["i_buff"]["dst"]);
Expand All @@ -326,42 +350,45 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["i_cids"][i]["id"]];
operation.inputOffsets[i] =
this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["i_cids"][i]["off"]) +
(srcBufferType != BufferType::SCRATCH ? contsSrcOffset : 0);
(srcBufferType != BufferType::SCRATCH ? constSrcOffset : this->scratchBufferOffset);
chunkIndexes.push_back((uint32_t)op["i_cids"][i]["off"]);
}
}
// will have either srcs or i_cids
if (op.contains("srcs")) {
operation.nInputs = op["srcs"].size();
checkChannelsPerOperation(operation.nInputs);
operation.inputBufferType = convertToBufferType(op["srcs"][0]["buff"]);
for (int i = 0; i < operation.nInputs; i++) {
operation.inputOffsets[i] =
this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["srcs"][i]["off"]) +
(operation.inputBufferType != BufferType::SCRATCH ? contsSrcOffset : 0);
(operation.inputBufferType != BufferType::SCRATCH ? constSrcOffset : this->scratchBufferOffset);
chunkIndexes.push_back((uint32_t)op["srcs"][i]["off"]);
}
}
if (op.contains("o_cids")) {
operation.nOutputs = op["o_cids"].size();
checkChannelsPerOperation(operation.nOutputs);
for (int i = 0; i < operation.nOutputs; i++) {
BufferType srcBufferType = convertToBufferType(op["o_buff"]["src"]);
BufferType dstBufferType = convertToBufferType(op["o_buff"]["dst"]);
operation.outputChannelIndexes[i] =
channelIndexes[{srcBufferType, dstBufferType, operation.channelType}][op["o_cids"][i]["id"]];
operation.outputOffsets[i] =
this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["o_cids"][i]["off"]) +
(dstBufferType != BufferType::SCRATCH ? constDstOffset : 0);
(dstBufferType != BufferType::SCRATCH ? constDstOffset : this->scratchBufferOffset);
chunkIndexes.push_back((uint32_t)op["o_cids"][i]["off"]);
}
}
// will have either dsts or o_cids
if (op.contains("dsts")) {
operation.nOutputs = op["dsts"].size();
checkChannelsPerOperation(operation.nOutputs);
operation.outputBufferType = convertToBufferType(op["dsts"][0]["buff"]);
for (int i = 0; i < operation.nOutputs; i++) {
operation.outputOffsets[i] =
this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["dsts"][i]["off"]) +
(operation.outputBufferType != BufferType::SCRATCH ? constDstOffset : 0);
(operation.outputBufferType != BufferType::SCRATCH ? constDstOffset : this->scratchBufferOffset);
chunkIndexes.push_back((uint32_t)op["dsts"][i]["off"]);
}
}
Expand All @@ -370,13 +397,19 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t contsSrcOffse
}
if (op.contains("srcoff")) {
operation.srcOffset = this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["srcoff"]);
if (operation.srcBufferType == BufferType::SCRATCH) {
operation.srcOffset += this->scratchBufferOffset;
}
chunkIndexes.push_back((uint32_t)op["srcoff"]);
}
if (op.contains("dstbuff")) {
operation.dstBufferType = convertToBufferType(op["dstbuff"]);
}
if (op.contains("dstoff")) {
operation.dstOffset = this->getOffset(rank, this->inputSize, this->outputSize, (uint32_t)op["dstoff"]);
if (operation.dstBufferType == BufferType::SCRATCH) {
operation.dstOffset += this->scratchBufferOffset;
}
chunkIndexes.push_back((uint32_t)op["dstoff"]);
}
if (op.contains("cnt")) {
Expand Down
Loading
Loading