
Tuner timeout #486

Open · wants to merge 2 commits into master
Conversation

@math-fehr (Contributor) commented on Jun 7, 2018

Implements a timeout for the CUDA backend using a mapping option.
Also adds a flag to change the default value of the timeout mapping option.

The aim of this PR is to allow the use of timeouts in the autotuner, where kernels taking 1 s sometimes appear even though 5 ms is achievable. For now, the timeout flag can be used to set a timeout for all kernels produced by the autotuner.

Closes #394
Tag #381

@ftynse requested a review from nicolasvasilache on June 7, 2018 at 09:26
Also added a timeout flag to set the default timeout on the options. However, the flag
does not take effect when the options are initialized before gflags is parsed. The flag
sets the timeout in ms.
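For reference, the flag described above might look roughly like this with gflags (the flag name and type are assumptions, not necessarily the PR's exact declaration):

#include <gflags/gflags.h>

// Minimal sketch, assuming a gflags-based flag (names hypothetical).
DEFINE_uint64(timeout, 0, "Default kernel timeout in ms (0 disables the check)");

// Caveat from the commit message: options constructed during static
// initialization read FLAGS_timeout before gflags::ParseCommandLineFlags()
// has run, so they only ever see the default value.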

The __timestamp() function is used to get the current timestamp. According to NVIDIA,
this mechanism should not be used, so it could behave differently on some devices.

To have a timeout in CUDA, each block must first know the timestamp of the kernel
launch. To do that, the first instruction of every block retrieves the timestamp
stored in global memory. If the value is 0 (as it is at the start of the kernel),
the block sets it to the current timestamp. All of this is done atomically. The
stored timestamp might not be the lowest timestamp computed by any block, but it is close.
After that, timeout checks are inserted into the kernel code; they check whether the
kernel has run for more than n ns. The checks are inserted after some for loops
and some sequences, and they are placed so that there are at least
timeout_check_frequency for-loop iterations between two checks, to ensure
that the checks do not influence the results much. timeout_check_frequency can be
modified by a flag.
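A minimal sketch of the scheme just described, with hypothetical identifiers (an illustration, not the PR's exact code):

// start points to a global-memory slot that the host zeroes before each launch.
__device__ bool __timed_out(unsigned long long* start, unsigned long long timeout_ns) {
  unsigned long long now = __timestamp();
  // The first block to arrive publishes its timestamp atomically; later blocks reuse it.
  unsigned long long seen = atomicCAS(start, 0ULL, now);
  if (seen == 0ULL) {
    seen = now;
  }
  return now - seen > timeout_ns;
}

// A check inserted after a selected loop or sequence then looks like:
//   if (__timed_out(startTimeDev, TIMEOUT_NS)) return;

The PR's __timestamp() helper, which the review comment below refers to: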
constexpr auto timestampFunction = R"C(
__device__ unsigned long long __timestamp() {
  unsigned long long startns;
  asm volatile("mov.u64 %0,%%globaltimer;" : "=l"(startns));
  return startns;
}
)C";
Since you mention NVIDIA explicitly states we should not use globaltimer, why not use something based on clock64 (see an example here)? We can get the frequency from the device properties and turn that into time. This should also have significantly lower overhead than fetching data from global memory.
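A rough sketch of the clock64-based alternative suggested here (variable names are assumptions):

// Host side: convert the wall-clock budget into a cycle budget before launching.
// cudaDeviceProp::clockRate is in kHz, i.e. cycles per millisecond.
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
long long cycleBudget = static_cast<long long>(timeoutMs) * prop.clockRate;

// Device side: each block records clock64() once at entry, then checks against the budget.
__device__ __forceinline__ bool timedOut(long long blockStart, long long cycleBudget) {
  return clock64() - blockStart > cycleBudget;
}

Since clock64() is a per-multiprocessor counter, this gives a per-block rather than a per-kernel timeout, which is the trade-off discussed later in the thread.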

@@ -318,7 +323,8 @@ CudaMappingOptions CudaMappingOptions::makeUnmappedMappingOptions() {
       .useSharedMemory(false)
       .usePrivateMemory(false)
       .unrollCopyShared(false)
-      .useReadOnlyCache(false);
+      .useReadOnlyCache(false)
+      .timeout(FLAGS_timeout);
I don't see a good justification for adding debug-type information to the protos; I would kill this with fire.

-CudaRTCFunction::CudaRTCFunction() {}
+CudaRTCFunction::CudaRTCFunction() {
+  TC_CUDA_RUNTIMEAPI_ENFORCE(
+      cudaMalloc((void**)&startTimeDev, sizeof(unsigned long long)));
Using the clock instruction should make this unnecessary.

@@ -171,6 +185,15 @@ Duration CudaRTCFunction::Launch(
   unsigned int bx = block[0];
   unsigned int by = block[1];
   unsigned int bz = block[2];
+  if (timeout != 0) {
+    unsigned long long startTime = 0;
significant, unnecessary complexity once you use clock

@@ -57,13 +57,15 @@ class CudaRTCFunction {
       std::vector<long> params,
       std::vector<void*> outputs,
       std::vector<const void*> inputs,
+      uint32_t timeout,
Definitely not: we JIT-compile kernels, so let's hardcode the timeouts instead of increasing the complexity of the flow.

@nicolasvasilache (Contributor)

The general idea is important, and if it already works, that's great. However, it comes with too much technical debt for my taste. Let's try to:

  1. use the clock instruction instead of inserting new variables in memory and using a feature that is explicitly discouraged by NVIDIA
  2. not change the proto and the executor API for this; we JIT-compile code, so let's hardcode the value directly in the string (see the sketch below). There is no value in passing this as an extra parameter to RTC and significantly complicating the compilation flow.
  3. I would just insert the test above thread-mapping nodes and call it a day, plus have a sanity check that a block should not execute more than a certain number of iterations (we can then prune early)

I would much rather have a global variable (as much as I hate them) than increase the complexity of the APIs. But maybe with the use of the clock instruction and the heuristics above, we don't even need to turn it off during tuning.
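A sketch of what "hardcode the value directly in the string" could look like on the host side (the helper name is hypothetical):

#include <string>

// Prepend the budget as a compile-time constant to the JIT-compiled CUDA source,
// so neither the proto nor the executor API needs a new field.
std::string withTimeout(const std::string& cudaSource, unsigned long long timeoutNs) {
  return "#define TC_TIMEOUT_NS " + std::to_string(timeoutNs) + "ULL\n" + cudaSource;
}

// The generated kernel can then compare elapsed time against TC_TIMEOUT_NS directly.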

@nicolasvasilache left a comment
Please address comments, thanks!

@math-fehr (Contributor, Author)

  1. I tried the clock function, but it has a main issue: it is multiprocessor-dependent, which means you cannot 'synchronize' the blocks to know when the kernel started. clock allows a timeout within a block, while we may want a timeout for the whole kernel, and the difference between the two can be an order of magnitude. If the first blocks time out, the following blocks will likely time out as well, but if they aren't scheduled on the device yet, they simply don't know whether they should time out earlier. The timestamp function, on the other hand, allows a synchronization, so each block knows when the kernel started as soon as the block begins.
     But yes, one of the big problems with the timestamp is that we don't currently know whether it behaves the same way on all GPUs.

  2. While I agree that this code adds unwanted complexity, I think hardcoding the value in the kernel might pose other problems. The idea is that it would be nice to use the timeout in the tuner and to change it over the course of a generation to reduce the tuning time for that generation (which can be really long for the first generation). To do that, it would be nice to be able to change the value when executing the code rather than when compiling it. I implemented it in a way that makes it easy to modify the code to change the timeout when launching the kernel (a sketch of this launch-time approach follows below).

  3. Yes, I will insert the test above the thread-mapping nodes instead of after them.
     I think it is really difficult to know when there are "too many" iterations in a kernel. For instance, in the most extreme cases there can be one or two orders of magnitude of difference between using uncoalesced global memory and using registers, for the same number of iterations. Moreover, the kernel execution time depends a lot on how many blocks are effectively scheduled at the same time on the device, and on whether the code is well pipelined. To know when to prune early, we would need some sort of performance model, which I believe would require too much work.
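To make point 2 concrete, here is a hypothetical sketch of the launch-time alternative being argued for (not necessarily what this PR implements): the budget is passed as an ordinary kernel argument, so the tuner can tighten it between runs without recompiling.

// Generated kernel takes the budget as a regular parameter (signature hypothetical):
__global__ void tc_kernel(float* out, const float* in, unsigned long long timeout_ns);

// Host side: append the timeout to the argument list handed to the driver API.
void* args[] = {&outPtr, &inPtr, &timeoutNs};
// cuLaunchKernel(func, gx, gy, gz, bx, by, bz, sharedBytes, stream, args, nullptr);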

@nicolasvasilache (Contributor)

> I tried the clock function, but it has a main issue: it is multiprocessor-dependent

It should be pretty easy to transform your kernel timeout into a per-block timeout: we always know the number of SMs and the number of blocks statically. So something like ceil((timeout * #sms) / #blocks) should do it. OK, it won't be a perfect approximation, but I wouldn't hesitate for one second to kill off the complexity. And for the level of granularity we want to catch with this timeout, this seems more than sufficient to me.
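A sketch of that conversion on the host side (variable names assumed):

// Per-block budget: ceil((kernel timeout * number of SMs) / number of blocks).
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
unsigned long long perBlockTimeoutNs =
    (kernelTimeoutNs * prop.multiProcessorCount + numBlocks - 1) / numBlocks;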

> The idea is that it would be nice to use the timeout in the tuner and to change it over the course of a generation to reduce the tuning time for that generation

We compile a new version for each set of options at each generation, so I don't see the issue: just implement a dynamic timeout in the tuner and hardcode it. One could object that this would prevent memoizing the generated CUDA/PTX during tuning, but that is premature optimization IMO. The value of keeping a simple and clean abstraction is orders of magnitude more important.

> I think it is really difficult to know when there are "too many" iterations in a kernel

Agreed. OTOH, I would argue that we don't need something precise, just something that works "well enough". The whole point here is to skip catastrophically bad cases, ones so bad that even our chainsaw pruning didn't catch them (or at least not on the first iteration).

Additionally, @ftynse, it seems some of the issues this PR wants to catch occur on the first iteration, where there is no "best kernel" to compare to and we just end up executing really bad ones multiple times. How about putting the timeout value in the pruning function? That would guarantee we only execute a bad function once, which already catches 90%+ of the useful cases. I still think having a way to interrupt kernels is valuable, but simplicity first, without any hesitation.

> need some sort of performance model, which I believe would require too much work

We're working on learning some :) We definitely don't want to engineer those.
