This repository has been archived by the owner on Apr 28, 2023. It is now read-only.

Tuner timeout #486

Open · wants to merge 2 commits into master

Changes from 1 commit

6 changes: 5 additions & 1 deletion tc/core/compiler-inl.h
@@ -66,6 +66,10 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
       options);
   return std::unique_ptr<typename Backend::ExecutorType>(
       new typename Backend::ExecutorType(
-          inputsInfo, outputsInfo, halideComponents, compilationResult));
+          inputsInfo,
+          outputsInfo,
+          halideComponents,
+          compilationResult,
+          options));
 }
 } // namespace tc

2 changes: 2 additions & 0 deletions tc/core/constants.h
@@ -30,5 +30,7 @@ constexpr auto kWriteIdName = "write";
 constexpr auto kSyncIdPrefix = "_sync_";
 constexpr auto kWarpSyncIdPrefix = "_warpSync_";
+
+constexpr auto kTimeoutCheckPrefix = "_timeoutCheck_";
 
 } // namespace polyhedral
 } // namespace tc

3 changes: 2 additions & 1 deletion tc/core/cpu/cpu_tc_executor.cc
@@ -29,7 +29,8 @@ CpuTcExecutor::CpuTcExecutor(
     const std::vector<TensorInfo>& inputsInfo,
     const std::vector<TensorInfo>& outputsInfo,
     const tc2halide::HalideComponents& halideComponents,
-    const typename CpuBackend::CompilationResultType& compilationResult)
+    const typename CpuBackend::CompilationResultType& compilationResult,
+    const typename CpuBackend::MappingOptionsType& options)
     : TcExecutor<CpuBackend>(
           inputsInfo,
           outputsInfo,

3 changes: 2 additions & 1 deletion tc/core/cpu/cpu_tc_executor.h
@@ -30,7 +30,8 @@ class CpuTcExecutor : public TcExecutor<CpuBackend> {
     const std::vector<TensorInfo>& inputsInfo,
     const std::vector<TensorInfo>& outputsInfo,
     const tc2halide::HalideComponents& halideComponents,
-    const typename CpuBackend::CompilationResultType& compilationResult);
+    const typename CpuBackend::CompilationResultType& compilationResult,
+    const typename CpuBackend::MappingOptionsType& options);
 
   /// This is the "low-latency" mode in which we just propagate raw pointers to
   /// data in the address space where kernel is executed.

8 changes: 7 additions & 1 deletion tc/core/cuda/cuda_mapping_options.cc
@@ -287,6 +287,11 @@ CudaMappingOptions& CudaMappingOptions::useReadOnlyCache(bool b) {
   return *this;
 }
 
+CudaMappingOptions& CudaMappingOptions::timeout(uint32_t ms) {
+  ownedProto_.set_timeout(ms);
+  return *this;
+}
+
 CudaMappingOptions& CudaMappingOptions::mapToThreads(
     const std::string& commaSeparatedSizes) {
   auto sizes = parseCommaSeparatedIntegers<uint64_t>(commaSeparatedSizes);

@@ -318,7 +323,8 @@ CudaMappingOptions CudaMappingOptions::makeUnmappedMappingOptions() {
       .useSharedMemory(false)
       .usePrivateMemory(false)
       .unrollCopyShared(false)
-      .useReadOnlyCache(false);
+      .useReadOnlyCache(false)
+      .timeout(FLAGS_timeout);
   return mo;
 }

Review comment (Contributor), on .timeout(FLAGS_timeout):
I don't see a good justification for adding debug type information to the protos; I would kill this with fire.

5 changes: 5 additions & 0 deletions tc/core/cuda/cuda_mapping_options.h
@@ -197,6 +197,11 @@ class CudaMappingOptions {
   CudaMappingOptions& useReadOnlyCache(bool b);
   ///@}
 
+  /// Change kernel timeout
+  ///@{
+  CudaMappingOptions& timeout(uint32_t ms);
+  ///@}
+
   /// Static constructors for predefined strategies.
   ///@{
   static CudaMappingOptions makeNaiveMappingOptions();

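For context, a caller would opt into the watchdog roughly like this. A minimal sketch: only CudaMappingOptions::timeout() and makeNaiveMappingOptions() are from this PR; the surrounding tuning pipeline is elided.

// Sketch: request a ~100 ms kernel watchdog via the new option.
auto options = tc::CudaMappingOptions::makeNaiveMappingOptions()
                   .timeout(100); // milliseconds; 0 (the default) disables
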
3 changes: 3 additions & 0 deletions tc/core/cuda/cuda_mapping_options_cpp_printer.cc
@@ -38,6 +38,9 @@ CudaMappingOptionsCppPrinter& operator<<(
     prn.printValueOption(
         "maxSharedMemory", cudaOptions.proto().max_shared_memory());
   }
+  if (cudaOptions.proto().has_timeout()) {
+    prn.printValueOption("timeout", cudaOptions.proto().timeout());
+  }
   prn.endStmt();
   return prn;
 }

27 changes: 25 additions & 2 deletions tc/core/cuda/cuda_rtc.cc
@@ -28,7 +28,10 @@
 namespace tc {
 std::mutex nvrtc_mutex;
 
-CudaRTCFunction::CudaRTCFunction() {}
+CudaRTCFunction::CudaRTCFunction() {
+  TC_CUDA_RUNTIMEAPI_ENFORCE(
+      cudaMalloc((void**)&startTimeDev, sizeof(unsigned long long)));
+}
 
 CudaRTCFunction::~CudaRTCFunction() {
   if (!cleared_) {

Review comment (Contributor), on the cudaMalloc:
Using the clock instructions should make this unnecessary.

@@ -43,6 +46,7 @@ void CudaRTCFunction::clear() {
     WithCudaDevice(kvp.first);
     TC_CUDA_DRIVERAPI_ENFORCE(cuModuleUnload(kvp.second));
   }
+  TC_CUDA_RUNTIMEAPI_ENFORCE(cudaFree((void*)startTimeDev));
   cleared_ = true;
 }

@@ -136,7 +140,9 @@ Duration CudaRTCFunction::Launch(
     std::vector<long> params,
     std::vector<void*> outputs,
     std::vector<const void*> inputs,
+    uint32_t timeout,
     bool profile) const {
+  uint64_t timeoutInNs = timeout * 1000 * 1000;
   int dev;
   TC_CUDA_RUNTIMEAPI_ENFORCE(cudaGetDevice(&dev));
   if (perGpuModule_.count(dev) == 0) {

@@ -152,11 +158,19 @@
   constexpr size_t kNumMaxParameters = 100;
   std::array<void*, kNumMaxParameters> args_voidp{0};
-  CHECK_GE(kNumMaxParameters, params.size() + outputs.size() + inputs.size());
+  CHECK_GE(
+      kNumMaxParameters,
+      params.size() + outputs.size() + inputs.size() + (timeout != 0));
   int ind = 0;
   for (auto& p : params) {
     args_voidp[ind++] = &p;
   }
+  if (timeout != 0) {
+    args_voidp[ind++] =
+        const_cast<void*>(static_cast<const void*>(&startTimeDev));
+    args_voidp[ind++] =
+        const_cast<void*>(static_cast<const void*>(&timeoutInNs));
+  }
   for (auto& o : outputs) {
     args_voidp[ind++] = &o;
   }

@@ -171,6 +185,15 @@
   unsigned int bx = block[0];
   unsigned int by = block[1];
   unsigned int bz = block[2];
+  if (timeout != 0) {
+    unsigned long long startTime = 0;
+    TC_CUDA_RUNTIMEAPI_ENFORCE(cudaMemcpy(
+        (void*)startTimeDev,
+        (void*)&startTime,
+        sizeof(unsigned long long),
+        cudaMemcpyHostToDevice));
+  }
 
   auto launch = [&]() {
     TC_CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel(
         perGpuKernel_.at(dev),

Review comment (Contributor), on the startTime reset:
Significant, unnecessary complexity once you use clock.

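One subtlety in the Launch changes: cuLaunchKernel takes an array of pointers to the argument values, which is why the code pushes &startTimeDev (the address of the device pointer) and &timeoutInNs rather than the values themselves. A stripped-down sketch of that marshalling, with hypothetical names; note that widening to 64 bits before multiplying, as below, avoids the uint32_t wrap-around that the PR's timeout * 1000 * 1000 expression would hit for timeouts above roughly 4.3 seconds:

// Sketch: how the extra watchdog arguments reach the kernel. The kernel's
// appended parameters are (unsigned long long* startTime,
// unsigned long long timeout), so the launch array stores the address of
// each argument value.
unsigned long long* startTimeDev = nullptr; // cudaMalloc'ed once, as above
uint64_t timeoutInNs = uint64_t(timeoutMs) * 1000 * 1000; // ms -> ns, no wrap
void* args[] = {
    &startTimeDev, // pointer to the device pointer, not the device pointer
    &timeoutInNs}; // pointer to the 64-bit timeout value
// cuLaunchKernel(kernel, gx, gy, gz, bx, by, bz, 0, stream, args, nullptr);
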
2 changes: 2 additions & 0 deletions tc/core/cuda/cuda_rtc.h
@@ -57,13 +57,15 @@ class CudaRTCFunction {
     std::vector<long> params,
     std::vector<void*> outputs,
     std::vector<const void*> inputs,
+    uint32_t timeout,
     bool profile = false) const;
 
   void clear();
 
  private:
   mutable std::unordered_map<size_t, CUmodule> perGpuModule_;
   mutable std::unordered_map<size_t, CUfunction> perGpuKernel_;
+  unsigned long long* startTimeDev;
   std::string specializedName;
   std::vector<char> nvrtc_ptx;
   bool cleared_;

Review comment (Contributor), on the new timeout parameter:
Definitely not; we JIT-compile kernels, so let's hardcode the timeouts instead of increasing the complexity of the flow.

10 changes: 7 additions & 3 deletions tc/core/cuda/cuda_tc_executor.cc
@@ -46,12 +46,14 @@ CudaTcExecutor::CudaTcExecutor(
     const std::vector<TensorInfo>& inputsInfo,
     const std::vector<TensorInfo>& outputsInfo,
     const tc2halide::HalideComponents& halideComponents,
-    const typename CudaBackend::CompilationResultType& compilationResult)
+    const typename CudaBackend::CompilationResultType& compilationResult,
+    const typename CudaBackend::MappingOptionsType& options)
     : TcExecutor<CudaBackend>(
           inputsInfo,
           outputsInfo,
           halideComponents,
-          compilationResult) {
+          compilationResult),
+      timeout_(options.proto().timeout()) {
   auto t0 = std::chrono::high_resolution_clock::now();
   // force unloading in case we JIT with the same name/input/outputs with
   // different options.

@@ -121,7 +123,8 @@ void CudaTcExecutor::uncheckedRun(
       info.stream,
       parameters_,
       outputs,
-      inputs);
+      inputs,
+      timeout_);
 }
 
 ProfilingInfo CudaTcExecutor::profileUnchecked(

@@ -140,6 +143,7 @@ ProfilingInfo CudaTcExecutor::profileUnchecked(
       parameters_,
      outputs,
       inputs,
+      timeout_,
       true));
   // The CPU overhead is the total time minus the (synchronized) kernel runtime
   Duration cpuOverhead(Duration::since(start));

5 changes: 4 additions & 1 deletion tc/core/cuda/cuda_tc_executor.h
@@ -30,7 +30,8 @@ class CudaTcExecutor : public TcExecutor<CudaBackend> {
     const std::vector<TensorInfo>& inputsInfo,
     const std::vector<TensorInfo>& outputsInfo,
     const tc2halide::HalideComponents& halideComponents,
-    const typename CudaBackend::CompilationResultType& compilationResult);
+    const typename CudaBackend::CompilationResultType& compilationResult,
+    const typename CudaBackend::MappingOptionsType& options);
 
   /// This is the "low-latency" mode in which we just propagate raw pointers to
   /// data in the address space where kernel is executed.

@@ -63,5 +64,7 @@
   // GPU-specific results of compilation
   Grid grid_;
   Block block_;
+
+  uint32_t timeout_;
 };
 } // namespace tc

6 changes: 6 additions & 0 deletions tc/core/flags.cc
@@ -38,6 +38,12 @@ DEFINE_bool(
 DEFINE_bool(dump_cuda, false, "Print the generated source");
 DEFINE_bool(dump_ptx, false, "Dump the generated PTX");
 
+DEFINE_uint32(
+    timeout_check_frequency,
+    100,
+    "The minimum number of loop iterations between two timeout checks");
+DEFINE_uint32(timeout, 0, "The cuda kernel timeout in ms");
+
 // CPU codegen options
 DEFINE_bool(llvm_dump_before_opt, false, "Print IR before optimization");
 DEFINE_bool(llvm_dump_after_opt, false, "Print IR after optimization");

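Both are ordinary gflags, so the watchdog can be enabled per run without recompiling; for illustration, an invocation might look like the following (the binary name is hypothetical):

./tc_autotune_example --timeout=100 --timeout_check_frequency=100
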
4 changes: 4 additions & 0 deletions tc/core/flags.h
@@ -31,6 +31,10 @@ DECLARE_bool(debug_tuner);
 DECLARE_bool(dump_cuda);
 DECLARE_bool(dump_ptx);
 
+// Cuda timeout
+DECLARE_uint32(timeout_check_frequency);
+DECLARE_uint32(timeout);
+
 // llvm codegen
 DECLARE_bool(llvm_dump_before_opt);
 DECLARE_bool(llvm_dump_after_opt);

7 changes: 7 additions & 0 deletions tc/core/libraries.h
@@ -43,6 +43,13 @@ constexpr auto defines = R"C(
 #define inf __longlong_as_double(0x7ff0000000000000LL)
 )C";
 
+constexpr auto timestampFunction = R"C(
+__device__ unsigned long long __timestamp() {
+  unsigned long long startns;
+  asm volatile("mov.u64 %0,%%globaltimer;" : "=l"(startns));
+  return startns;
+})C";
+
 constexpr auto warpSyncFunctions = R"C(
 // Before CUDA 9, syncwarp is a noop since warps are always synchronized.
 #if __CUDACC_VER_MAJOR__ < 9

Review comment (Contributor), on the %globaltimer read:
Since you mention NVIDIA explicitly states we should not use %globaltimer, why not use something based on clock64 (see an example here)? We can get the frequency from the device properties and turn that into time. This should also be significantly lower overhead than fetching data from global memory.

45 changes: 39 additions & 6 deletions tc/core/polyhedral/cuda/codegen.cc
@@ -95,15 +95,24 @@ struct AstPrinter {
   bool inReduction_ = false;
 };
 
-vector<string> emitParams(const Scop& scop) {
+vector<string> emitParams(const MappedScop& mappedScop) {
   vector<string> res;
+  const auto& scop = mappedScop.scop();
   res.reserve(scop.halide.params.size());
   // Halide params. One of these two vectors will be empty.
   for (auto p : scop.halide.params) {
     stringstream ss;
     ss << p.type() << " " << p.name();
     res.push_back(ss.str());
   }
+  if (mappedScop.useTimeout != 0) {
+    stringstream ssStartTime;
+    ssStartTime << "unsigned long long* startTime";
+    res.push_back(ssStartTime.str());
+    stringstream ssTimeout;
+    ssTimeout << "unsigned long long timeout";
+    res.push_back(ssTimeout.str());
+  }
   return res;
 }

@@ -136,9 +145,10 @@ vector<string> emitTypedTensorNames(const vector<Halide::ImageParam>& tensors) {
   return res;
 }
 
-void emitArgs(stringstream& ss, const Scop& scop) {
+void emitArgs(stringstream& ss, const MappedScop& mappedScop) {
   // Order is: params, outs, ins
-  auto sigVec = emitParams(scop);
+  const auto& scop = mappedScop.scop();
+  auto sigVec = emitParams(mappedScop);
   sigVec = sigVec + emitTypedTensorNames(scop.halide.outputs);
   sigVec = sigVec + emitTypedTensorNames(scop.halide.inputs);
   for (auto& s : sigVec) {

@@ -152,10 +162,10 @@ void emitArgs(stringstream& ss, const Scop& scop) {
 void emitKernelSignature(
     stringstream& ss,
     const std::string& specializedName,
-    const Scop& scop) {
+    const MappedScop& mappedScop) {
   CHECK_NE(specializedName, "") << "name not provided";
   ss << "__global__ void " << specializedName << "(";
-  emitArgs(ss, scop);
+  emitArgs(ss, mappedScop);
   ss << ") {" << endl;
 }

@@ -452,6 +462,10 @@ void AstPrinter::emitStmt(isl::ast_node_user node) {
   } else if (
       stmtId.get_name() == kReadIdName || stmtId.get_name() == kWriteIdName) {
     emitCopyStmt(statementContext);
+  } else if (context_.scop().isTimeoutCheckId(stmtId)) {
+    context_.ss << "if(__timestamp() - startns > timeout) {\n";
+    context_.ss << ws.tab() << ws.tab() << "return;\n";
+    context_.ss << ws.tab() << "}" << std::endl;
   } else { // regular statement
     auto mappedStmtId = statementContext.statementId();
     CHECK_EQ(stmtId, mappedStmtId)

@@ -668,6 +682,22 @@ void emitThreadIdInit(stringstream& ss, const MappedScop& scop) {
   ss << "int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;\n";
 }
 
+void emitTimestampInit(stringstream& ss) {
+  WS ws;
+  ss << ws.tab();
+  ss << "unsigned long long startns = __timestamp();\n";
+  ss << ws.tab();
+  ss << "unsigned long long old_startns = startns;\n";
+  ss << ws.tab();
+  ss << "old_startns = atomicCAS(startTime, 0, startns);\n";
+  ss << ws.tab();
+  ss << "if(old_startns < startns && startns - old_startns > timeout && old_startns != 0) {\n";
+  ss << ws.tab() << ws.tab();
+  ss << "return;\n";
+  ss << ws.tab();
+  ss << "}\n";
+}
+
 void emitTmpDecl(stringstream& ss, const Scop& scop) {
   for (const auto& kvp : scop.treeSyncUpdateMap) {
     WS ws;

@@ -752,12 +782,15 @@ string emitCudaKernel(
   }
 
   stringstream ss;
-  emitKernelSignature(ss, specializedName, scop);
+  emitKernelSignature(ss, specializedName, mscop);
   emitThreadIdInit(ss, mscop);
   emitTensorViews(ss, scop.halide.outputs, paramValues);
   emitTensorViews(ss, scop.halide.inputs, paramValues);
   emitTmpDecl(ss, scop);
   emitPromotedArrayViewsHalide(ss, scop);
+  if (mscop.useTimeout) {
+    emitTimestampInit(ss);
+  }
   NodeInfoMapType nodeInfoMap;
   auto collect = [&nodeInfoMap](
       isl::ast_node n, isl::ast_build b) -> isl::ast_node {

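Pieced together, a kernel generated with a timeout would begin roughly as follows. This is a reconstruction from emitKernelSignature, emitTimestampInit, and the timeout-check emission above; the kernel name, tensor parameters, and loop body are illustrative placeholders:

__global__ void kernel(int P0, // Halide scalar params (elided)
                       unsigned long long* startTime, // appended by emitParams
                       unsigned long long timeout, // in ns
                       float* O0, const float* I0) { // outputs, then inputs
  int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
  // emitTimestampInit: the first block to run publishes its start time via
  // atomicCAS; any block starting more than `timeout` ns later exits early.
  unsigned long long startns = __timestamp();
  unsigned long long old_startns = startns;
  old_startns = atomicCAS(startTime, 0, startns);
  if (old_startns < startns && startns - old_startns > timeout && old_startns != 0) {
    return;
  }
  // ... loop nest; each _timeoutCheck_ statement inserted into the schedule
  // (at least FLAGS_timeout_check_frequency iterations apart) expands to:
  if (__timestamp() - startns > timeout) {
    return;
  }
  // ...
}
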