
Add Flash Attention backward to benchmarks/triton_kernels_benchmark #3108

Open · wants to merge 7 commits into base: main
Conversation

@ESI-SYD (Contributor) commented Jan 7, 2025

No description provided.

@ESI-SYD marked this pull request as draft on January 7, 2025 09:11
run: |
cd benchmarks/triton_kernels_benchmark
FA_KERNEL_MODE="bwd" \
BENCHMARKING_METHOD="ELAPSED_TIME" python flash_attention_fwd_benchmark.py --reports $REPORTS
Contributor Author (ESI-SYD):

The default UPSTREAM_PYTORCH_PROFILER method returns zero values for all providers in bwd mode, so this specifies the ELAPSED_TIME method instead (the one the 06 tutorial uses).
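
For context, a minimal sketch of what an elapsed-time style measurement looks like; this is an illustration only, not the benchmark suite's actual implementation, and it assumes an XPU-enabled PyTorch build where torch.xpu.synchronize() is available:

import time

import torch


def elapsed_time_ms(fn, n_warmup=10, n_repeat=100):
    # Warm up, then time n_repeat launches between device synchronizations.
    # Assumes torch.xpu.synchronize() exists (XPU-enabled PyTorch build).
    for _ in range(n_warmup):
        fn()
    torch.xpu.synchronize()
    start = time.perf_counter()
    for _ in range(n_repeat):
        fn()
    torch.xpu.synchronize()
    return (time.perf_counter() - start) * 1000.0 / n_repeat

Because the measurement brackets the calls with device synchronizations, it captures the backward kernels no matter which thread autograd launches them from, which is why it does not depend on the profiler's event tree.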

Contributor:

Do you know why? FYI @anmyachev

Contributor:

Maybe the launcher hasn't been rebuilt with the injected PyTorch? @ESI-SYD could you clean ~/.triton/cache and rerun the benchmarks with UPSTREAM_PYTORCH_PROFILER?

Contributor:

Cleaning ~/.triton/cache doesn't help.

Contributor:

> Cleaning ~/.triton/cache doesn't help.

for both Triton and XeTLA?

Contributor:

> > Cleaning ~/.triton/cache doesn't help.
>
> for both Triton and XeTLA?

Only for Triton.

Contributor Author (ESI-SYD):

Yes, for the Triton FA bwd kernels, __profile_kernel_of_func does not include the kernel execution time; XeTLA works.

[screenshot of the profiler output]

Maybe reverting commit c83c0ed (kernel_name deprecated) would help.
Should we keep only the elapsed_time method in the end?

Contributor:

Thanks for the information. I think we can leave elapsed_time for now; in the meantime I'll look into why UPSTREAM_PYTORCH_PROFILER mode doesn't work.

Contributor:

I think I figured out the reason: forward and backward run in different threads, so the kernels from backward are not included in the cpu_children of __profile_kernel_of_func. For example:

(Pdb) functions[0]
<FunctionEvent id=3 name=__profile_kernel_of_func device_type=DeviceType.CPU node_id=-1 cpu_time=88.778ms start_us=12366.193 end_us=101144.535 cpu_children=[] xpu_time=0.000us name=__profile_kernel_of_func thread=1 input_shapes=[] cpu_memory_usage=0 xpu_memory_usage=0 is_async=False is_remote=False seq_nr=-1 is_legacy=False>
(Pdb) functions[1]
<FunctionEvent id=524 name=__profile_kernel_of_func2 device_type=DeviceType.CPU node_id=-1 cpu_time=497.294us start_us=13000.303 end_us=13497.597 cpu_children=[525, 526] xpu_time=87.021ms name=__profile_kernel_of_func2 thread=2 input_shapes=[] cpu_memory_usage=0 xpu_memory_usage=0 is_async=False is_remote=False seq_nr=-1 is_legacy=False>

To take into account the kernels from the other thread (from the backward function), I added the following:

        from torch.profiler import record_function

        # Wrap the backward kernel launches in their own named profiler region
        # so their device time is still recorded even though autograd runs
        # them on a separate thread.
        with record_function("__profile_kernel_of_func2"):
            _attn_bwd_preprocess[pre_grid](
                o, do,  #
                delta,  #
                BATCH, N_HEAD, N_CTX,  #
                BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM  #
            )
            grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
            _attn_bwd[grid](
                q, arg_k, v, ctx.sm_scale, do, dq, dk, dv,  #
                M, delta,  #
                q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
                N_HEAD, N_CTX,  #
                BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1,  #
                BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2,  #
                BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
                HEAD_DIM=ctx.HEAD_DIM,  #
                num_warps=NUM_WARPS,  #
                num_stages=NUM_STAGES  #
            )

This problem can be solved by adding extra record_function calls, but maybe we should just avoid inheriting from torch.autograd.Function?
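
As an illustration of the thread issue (this is not code from the PR), here is a minimal sketch that aggregates time across all profiler events whose names start with the record_function label, regardless of which thread emitted them; the toy workload and the name prefix are assumptions for the example:

import torch
from torch.profiler import ProfilerActivity, profile, record_function


def run_step():
    x = torch.randn(512, 512, requires_grad=True)
    with record_function("__profile_kernel_of_func"):
        y = (x @ x).sum()
    # Backward kernels may be launched from an autograd worker thread,
    # so they end up under a separate top-level event.
    with record_function("__profile_kernel_of_func2"):
        y.backward()


with profile(activities=[ProfilerActivity.CPU]) as prof:
    run_step()

# Sum over every matching event instead of walking cpu_children of a single
# event; use device_time_total instead of cpu_time_total when profiling on a device.
total_us = sum(
    evt.cpu_time_total
    for evt in prof.events()
    if evt.name.startswith("__profile_kernel_of_func")
)
print(f"total measured time: {total_us:.1f} us")

With name-based aggregation, the per-thread split shown in the Pdb output above no longer hides the backward kernels' time.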

@ESI-SYD marked this pull request as ready for review January 8, 2025 02:53
@ESI-SYD linked an issue Jan 8, 2025 that may be closed by this pull request
Outdated review threads (resolved): scripts/test-triton.sh (×3), .github/workflows/triton-benchmarks.yml (×2)
Successfully merging this pull request may close this issue: Add Flash Attention backward to benchmarks/triton_kernels_benchmark
4 participants