From 5a0d78f36a3a941fcbb92a195774b61a574677ed Mon Sep 17 00:00:00 2001 From: Pratham Sood <62192153+prathams417@users.noreply.github.com> Date: Wed, 24 Jan 2024 19:53:06 -0500 Subject: [PATCH] Update tutorial tests 05 and 11 to include fixme comments for tolerance (#352) Also added test 11 to CI --- .github/workflows/build_and_test.yml | 1 + python/tutorials/05-layer-norm.py | 1 + python/tutorials/11-grouped-gemm.py | 1 + scripts/test-triton.sh | 1 + 4 files changed, 4 insertions(+) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index e308e56d3f..6b2c561767 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -195,6 +195,7 @@ jobs: python3 04-low-memory-dropout.py python3 05-layer-norm.py python3 07-math-functions.py + python3 11-grouped-gemm.py - name: Run CXX unittests run: | diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index 03cf87d682..43e2152c3f 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -318,6 +318,7 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='xpu'): assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0) assert torch.allclose(dx_tri, dx_ref, atol=1e-2, rtol=0) assert torch.allclose(db_tri, db_ref, atol=1e-2, rtol=0) + # FIXME: tolerance was increased from 1e-2 to 2e-2 to allow test to run assert torch.allclose(dw_tri, dw_ref, atol=2e-2, rtol=0) diff --git a/python/tutorials/11-grouped-gemm.py b/python/tutorials/11-grouped-gemm.py index 54abefe6a6..60a9b55e2d 100644 --- a/python/tutorials/11-grouped-gemm.py +++ b/python/tutorials/11-grouped-gemm.py @@ -208,6 +208,7 @@ def group_gemm_fn(group_A, group_B): tri_out = group_gemm_fn(group_A, group_B) ref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)] for i in range(group_size): + # FIXME: tolerance was increased from 1e-2 to 3e-1 to allow test to run assert torch.allclose(ref_out[i], tri_out[i], atol=3e-1, rtol=0) diff --git 
a/scripts/test-triton.sh b/scripts/test-triton.sh index cb8cb4c72d..c383380333 100755 --- a/scripts/test-triton.sh +++ b/scripts/test-triton.sh @@ -165,6 +165,7 @@ function run_tutorial_tests { run_tutorial_test "04-low-memory-dropout" 04-low-memory-dropout.py run_tutorial_test "05-layer-norm" 05-layer-norm.py run_tutorial_test "07-math-functions" 07-math-functions.py + run_tutorial_test "11-grouped-gemm" 11-grouped-gemm.py } function test_triton {