-
Notifications
You must be signed in to change notification settings - Fork 212
Tuner timeout #486
base: master
Are you sure you want to change the base?
Tuner timeout #486
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,10 @@ | |
namespace tc { | ||
std::mutex nvrtc_mutex; | ||
|
||
CudaRTCFunction::CudaRTCFunction() {} | ||
CudaRTCFunction::CudaRTCFunction() { | ||
TC_CUDA_RUNTIMEAPI_ENFORCE( | ||
cudaMalloc((void**)&startTimeDev, sizeof(unsigned long long))); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Using the clock instructions should make this unnecessary. |
||
} | ||
|
||
CudaRTCFunction::~CudaRTCFunction() { | ||
if (!cleared_) { | ||
|
@@ -43,6 +46,7 @@ void CudaRTCFunction::clear() { | |
WithCudaDevice(kvp.first); | ||
TC_CUDA_DRIVERAPI_ENFORCE(cuModuleUnload(kvp.second)); | ||
} | ||
TC_CUDA_RUNTIMEAPI_ENFORCE(cudaFree((void*)startTimeDev)); | ||
cleared_ = true; | ||
} | ||
} | ||
|
@@ -136,7 +140,9 @@ Duration CudaRTCFunction::Launch( | |
std::vector<long> params, | ||
std::vector<void*> outputs, | ||
std::vector<const void*> inputs, | ||
uint32_t timeout, | ||
bool profile) const { | ||
uint64_t timeoutInNs = timeout * 1000 * 1000; | ||
int dev; | ||
TC_CUDA_RUNTIMEAPI_ENFORCE(cudaGetDevice(&dev)); | ||
if (perGpuModule_.count(dev) == 0) { | ||
|
@@ -152,11 +158,19 @@ Duration CudaRTCFunction::Launch( | |
|
||
constexpr size_t kNumMaxParameters = 100; | ||
std::array<void*, kNumMaxParameters> args_voidp{0}; | ||
CHECK_GE(kNumMaxParameters, params.size() + outputs.size() + inputs.size()); | ||
CHECK_GE( | ||
kNumMaxParameters, | ||
params.size() + outputs.size() + inputs.size() + (timeout != 0)); | ||
int ind = 0; | ||
for (auto& p : params) { | ||
args_voidp[ind++] = &p; | ||
} | ||
if (timeout != 0) { | ||
args_voidp[ind++] = | ||
const_cast<void*>(static_cast<const void*>(&startTimeDev)); | ||
args_voidp[ind++] = | ||
const_cast<void*>(static_cast<const void*>(&timeoutInNs)); | ||
} | ||
for (auto& o : outputs) { | ||
args_voidp[ind++] = &o; | ||
} | ||
|
@@ -171,6 +185,15 @@ Duration CudaRTCFunction::Launch( | |
unsigned int bx = block[0]; | ||
unsigned int by = block[1]; | ||
unsigned int bz = block[2]; | ||
if (timeout != 0) { | ||
unsigned long long startTime = 0; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Significant, unnecessary complexity once you use [inline code lost in extraction — presumably the clock instructions]. |
||
TC_CUDA_RUNTIMEAPI_ENFORCE(cudaMemcpy( | ||
(void*)startTimeDev, | ||
(void*)&startTime, | ||
sizeof(unsigned long long), | ||
cudaMemcpyHostToDevice)); | ||
} | ||
|
||
auto launch = [&]() { | ||
TC_CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel( | ||
perGpuKernel_.at(dev), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,13 +57,15 @@ class CudaRTCFunction { | |
std::vector<long> params, | ||
std::vector<void*> outputs, | ||
std::vector<const void*> inputs, | ||
uint32_t timeout, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Definitely not: we JIT-compile kernels, so let's hardcode the timeouts instead of increasing the complexity of the flow. |
||
bool profile = false) const; | ||
|
||
void clear(); | ||
|
||
private: | ||
mutable std::unordered_map<size_t, CUmodule> perGpuModule_; | ||
mutable std::unordered_map<size_t, CUfunction> perGpuKernel_; | ||
unsigned long long* startTimeDev; | ||
std::string specializedName; | ||
std::vector<char> nvrtc_ptx; | ||
bool cleared_; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,6 +43,13 @@ constexpr auto defines = R"C( | |
#define inf __longlong_as_double(0x7ff0000000000000LL) | ||
)C"; | ||
|
||
constexpr auto timestampFunction = R"C( | ||
__device__ unsigned long long __timestamp() { | ||
unsigned long long startns; | ||
asm volatile("mov.u64 %0,%%globaltimer;" : "=l"(startns)); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since you mention that NVIDIA explicitly states we should not use globaltimer, why not use something based on [inline code lost in extraction — presumably the clock instructions]? |
||
return startns; | ||
})C"; | ||
|
||
constexpr auto warpSyncFunctions = R"C( | ||
// Before CUDA 9, syncwarp is a noop since warps are always synchronized. | ||
#if __CUDACC_VER_MAJOR__ < 9 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see a good justification for adding debug type information to the protos; I would kill this with fire.