Add Flash Attention backward to benchmarks/triton_kernels_benchmark
#3108
base: main
Conversation
run: |
  cd benchmarks/triton_kernels_benchmark
  FA_KERNEL_MODE="bwd" \
  BENCHMARKING_METHOD="ELAPSED_TIME" python flash_attention_fwd_benchmark.py --reports $REPORTS
The default UPSTREAM_PYTORCH_PROFILER method returns zero values for all providers in bwd mode, so this specifies the ELAPSED_TIME method instead (the one used in the 06 tutorial).
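For context, a minimal sketch of how such an environment-variable switch can be implemented. The helper names below (bench_elapsed_time, do_bench) are hypothetical, not the suite's actual API; the timing mirrors the synchronized wall-clock approach of the 06 tutorial.

```python
import os
import time
import torch


def bench_elapsed_time(fn, n_warmup=10, n_repeat=100):
    # Hypothetical ELAPSED_TIME-style helper: synchronized wall-clock timing.
    for _ in range(n_warmup):
        fn()
    torch.xpu.synchronize()
    start = time.perf_counter()
    for _ in range(n_repeat):
        fn()
    torch.xpu.synchronize()
    return (time.perf_counter() - start) * 1e3 / n_repeat  # ms per call


def do_bench(fn):
    # Pick the timing method from the environment, as the CI step above does;
    # FA_KERNEL_MODE="bwd" would analogously select the backward pass to time.
    method = os.getenv("BENCHMARKING_METHOD", "UPSTREAM_PYTORCH_PROFILER")
    if method == "ELAPSED_TIME":
        return bench_elapsed_time(fn)
    raise NotImplementedError(f"{method} is not sketched here")
```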
Do you know why? FYI @anmyachev
Maybe the launcher hasn't been rebuilt with the injected PyTorch profiling code? @ESI-SYD, could you clean ~/.triton/cache and restart the benchmarks with UPSTREAM_PYTORCH_PROFILER?
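In case it helps reproduce this, clearing the cache takes a couple of lines of plain Python (the path is the default Triton cache location mentioned above):

```python
import shutil
from pathlib import Path

cache_dir = Path.home() / ".triton" / "cache"
if cache_dir.exists():
    shutil.rmtree(cache_dir)  # forces Triton to rebuild kernels and launchers
```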
Cleaning ~/.triton/cache doesn't help.
> Cleaning ~/.triton/cache doesn't help.

For both Triton and XeTLA?
> Cleaning ~/.triton/cache doesn't help. For both Triton and XeTLA?

Only for Triton.
Yes, for the Triton FA bwd kernel, __profile_kernel_of_func does not include the kernel execution time; XeTLA works. Maybe reverting commit c83c0ed (kernel_name deprecated) would help. Should we only keep the elapsed_time method in the end?
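To illustrate why the reported time can come out as zero, here is a hedged sketch of how a profiler-based measurement of this kind typically works (an assumption about the mechanism, not the suite's exact code, and it presumes a PyTorch build with XPU profiler support): if the kernels are launched from a different thread, they never become children of the named region and its device time stays at zero.

```python
import torch
from torch.profiler import profile, record_function, ProfilerActivity


def profiled_device_time_ms(fn):
    # Run fn under the profiler inside a named region, then read back the
    # device time attributed to that region.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
        with record_function("__profile_kernel_of_func"):
            fn()
    events = [e for e in prof.events() if e.name == "__profile_kernel_of_func"]
    # Kernels launched from autograd's backward thread are not cpu_children of
    # this region, so their time is missing here (hence 0 for the Triton bwd).
    # Attribute names vary across PyTorch versions (xpu_time / device_time_total).
    return sum(getattr(e, "device_time_total", 0) for e in events) / 1e3
```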
Thanks for the information. I think we can leave elapsed_time for now, and in the meantime I'll look into why UPSTREAM_PYTORCH_PROFILER mode doesn't work.
I think I figured out the reason: forward and backward run in different threads, so the kernels from backward are not included in __profile_kernel_of_func's cpu_children. For example:
(Pdb) functions[0]
<FunctionEvent id=3 name=__profile_kernel_of_func device_type=DeviceType.CPU node_id=-1 cpu_time=88.778ms start_us=12366.193 end_us=101144.535 cpu_children=[] xpu_time=0.000us name=__profile_kernel_of_func thread=1 input_shapes=[] cpu_memory_usage=0 xpu_memory_usage=0 is_async=False is_remote=False seq_nr=-1 is_legacy=False>
(Pdb) functions[1]
<FunctionEvent id=524 name=__profile_kernel_of_func2 device_type=DeviceType.CPU node_id=-1 cpu_time=497.294us start_us=13000.303 end_us=13497.597 cpu_children=[525, 526] xpu_time=87.021ms name=__profile_kernel_of_func2 thread=2 input_shapes=[] cpu_memory_usage=0 xpu_memory_usage=0 is_async=False is_remote=False seq_nr=-1 is_legacy=False>
In order to take into account the kernels from the other thread (from the backward function), I added the following:
from torch.profiler import record_function

with record_function("__profile_kernel_of_func2"):
    # Give the backward kernel launches their own named region so the profiler
    # can attribute their device time (they run on autograd's backward thread).
    _attn_bwd_preprocess[pre_grid](
        o, do,  #
        delta,  #
        BATCH, N_HEAD, N_CTX,  #
        BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM  #
    )
    grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
    _attn_bwd[grid](
        q, arg_k, v, ctx.sm_scale, do, dq, dk, dv,  #
        M, delta,  #
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
        N_HEAD, N_CTX,  #
        BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1,  #
        BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2,  #
        BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
        HEAD_DIM=ctx.HEAD_DIM,  #
        num_warps=NUM_WARPS,  #
        num_stages=NUM_STAGES  #
    )
This problem can be solved by adding additional record_function calls, but maybe we should just not use inheritance from torch.autograd.Function?
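Either way, the harness then has to collect device time from every named region rather than a single one. A hedged sketch of that aggregation (assuming a prof object from torch.profiler as in the sketch above); dropping the torch.autograd.Function wrapper would instead keep all launches on the calling thread, so a single region would suffice:

```python
def total_device_time_ms(prof, prefix="__profile_kernel_of_func"):
    # Sum device time over every region whose name starts with the prefix,
    # regardless of which thread recorded it (this also covers the
    # __profile_kernel_of_func2 region added around the backward launches).
    return sum(
        getattr(e, "device_time_total", 0)
        for e in prof.events()
        if e.name.startswith(prefix)
    ) / 1e3
```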
Co-authored-by: Whitney Tsang <[email protected]>
Co-authored-by: Whitney Tsang <[email protected]>
This reverts commit e0f25b4.