clean up and make TMA, scheduling autotunable (#54)
Summary:
Variants:
- triton_tutorial_flash_v2_opt: no TMA, with computation pipelining
- triton_tutorial_flash_v2_tma: TMA, with computation pipelining
- triton_tutorial_flash_v2_tma_ws: TMA, with computation pipelining and Warp Spec
- triton_tutorial_flash_v2_ws: no TMA, with computation pipelining and Warp Spec

Pull Request resolved: #54

Test Plan:
```
CUDA_VISIBLE_DEVICES=5 TORCH_CUDA_ARCH_LIST=9.0a python run.py --op flash_attention --only triton_tutorial_flash_v2_opt,triton_tutorial_flash_v2_tma,triton_tutorial_flash_v2 --num-inputs 1 --seq-len 13 --metrics tflops --batch 8 --n-heads 16 --d-head 128

CUDA_VISIBLE_DEVICES=5 TORCH_CUDA_ARCH_LIST=9.0a python run.py --op flash_attention --only triton_tutorial_flash_v2_opt,triton_tutorial_flash_v2_tma,triton_tutorial_flash_v2 --num-inputs 1 --seq-len 13 --metrics accuracy --batch 8 --n-heads 16 --d-head 128 --baseline triton_tutorial_flash_v2

On a compiler supporting WarpSpec:

CUDA_VISIBLE_DEVICES=5 TORCH_CUDA_ARCH_LIST=9.0a python run.py --op flash_attention --only triton_tutorial_flash_v2_ws,triton_tutorial_flash_v2_tma_ws,triton_tutorial_flash_v2 --num-inputs 1 --seq-len 13 --metrics tflops --batch 8 --n-heads 16 --d-head 128

CUDA_VISIBLE_DEVICES=5 TORCH_CUDA_ARCH_LIST=9.0a python run.py --op flash_attention --only triton_tutorial_flash_v2_ws,triton_tutorial_flash_v2_tma_ws,triton_tutorial_flash_v2 --num-inputs 1 --seq-len 13 --metrics accuracy --batch 8 --n-heads 16 --d-head 128 --baseline triton_tutorial_flash_v2
```

Reviewed By: htyu

Differential Revision: D66109428

Pulled By: manman-ren

fbshipit-source-id: 52d89e555ae717f2258dddfc17b4011414ef0e83
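For context on what "autotunable" means here: in Triton, scheduling knobs such as num_stages (software-pipelining depth) and num_warps are typically exposed to the autotuner through the @triton.autotune decorator, so the kernel sweeps those configs at first launch. The sketch below is illustrative only; the kernel, its parameters, and the config values are hypothetical and are not the actual TritonBench flash_attention code.

```
# Minimal, hypothetical example of autotuning scheduling knobs in Triton.
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        # num_stages / num_warps are the scheduling knobs swept by the autotuner.
        triton.Config({"BLOCK_SIZE": 128}, num_stages=2, num_warps=4),
        triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE": 256}, num_stages=3, num_warps=8),
    ],
    key=["n_elements"],  # re-tune when the problem size changes
)
@triton.jit
def scaled_copy_kernel(x_ptr, y_ptr, scale, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance copies and scales one BLOCK_SIZE-wide slice.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, x * scale, mask=mask)


def scaled_copy(x: torch.Tensor, scale: float) -> torch.Tensor:
    y = torch.empty_like(x)
    n_elements = x.numel()
    # Grid size depends on the BLOCK_SIZE picked by the autotuner.
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    scaled_copy_kernel[grid](x, y, scale, n_elements)
    return y
```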