-
Notifications
You must be signed in to change notification settings - Fork 46
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add int8 to gemm w/ addmatrix #3040
Conversation
I think we treat them as 2 separate benchmarks in terms of reporting (So we have 2 lines of Then we either run benchmark script twice with different dypes to generate 2 separate report files (I'd prefer this), or modify our report script to add some filtering capability. About onednn and int8 support. Do we want to measure onednn? If not, and we run it only for validation, maybe we could run it in other precision, like fp32 or bf16 just for validation of the output. |
My only issue with this PR right now is that all charts and GeoMeans for addmatrix benchmark will be discontinued, due to new parameters. Hence, my suggestion to introduce separate benchmark for int8 |
164fd00
to
3cc0b6e
Compare
I removed the onednn related bits because for onednn we only measure kernel time, and the add step is not fused into the main gemm kernel so the comparison would not be appropriate. I modified the int8 code to only run as a separate benchmark step controllable by environment variable, so the existing bfloat16 time should not be affected. The default run mode is bfloat16, with an optional int8 mode or all dtypes mode for local runs. |
benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
Outdated
Show resolved
Hide resolved
benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
Outdated
Show resolved
Hide resolved
benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
Outdated
Show resolved
Hide resolved
benchmarks/triton_kernels_benchmark/gemm_postop_addmatrix_benchmark.py
Outdated
Show resolved
Hide resolved
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch') | ||
if dtype.is_floating_point or [B, M, N, K] in [[1, 1024, 1024, 1024], [1, 2048, 2048, 2048], | ||
[1, 512, 8192, 32768], [4, 32768, 4096, 128]]: | ||
# torch int8 matmul on GPU is not supported. only check a few int8 shapes to reduce runtime |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we have a env var for all benchmarks to control if we verify the result?
Don't think we should skip checking correctness for some shapes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The correctness checks are very slow for int8 because they have to be done on CPU. I don't think it is worth doubling (at least) the benchmarks runtime to do correctness checks for all shapes. I tried to pick a small sampling of different shapes (batched and not batched) so we had some int8 correctness checks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, I agree is a good compromise due to torch int8 matmul not supported on GPU, but my concern is it is not obvious to users that some int8 shapes are not tested for correctness. I don't have a better suggestion anyways.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we wanted to do a CI run I would add the env variable, but if a user is running the benchmarks they can just comment out the if
(and unindent the line after it).
1146674
to
530d13c
Compare
add int8 to addmatrix benchmark 2/? add int8 to addmatrix benchmark 3/3
530d13c
to
72dfa2a
Compare
Update the gemm addmatrix benchmark to support int8 inputs as well as bfloat16.
The int8 benchmark is pretty slow - not because Triton performance is bad (it is at least on par with bfloat16) but because PyTorch does not support int8 matmul on GPU, so we have to do the matmul on the GPU. This makes the benchmark something like 20x slower. To fix that, I changed the PyTorch accuracy check to only run for a few shapes instead of all the shapes - I tried to pick shapes that I thought were representative of different cases but am open to suggestions. Now the benchmark runs in reasonable time.
A few open items need to be addressed:
cc #3014