Skip to content

Commit

Permalink
Update tutorial tests 05 and 11 to include fixme comments for toleran…
Browse files Browse the repository at this point in the history
…ce (#352)

Also added test 11 to CI
  • Loading branch information
prathams417 authored Jan 25, 2024
1 parent 062ccc2 commit 5a0d78f
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ jobs:
python3 04-low-memory-dropout.py
python3 05-layer-norm.py
python3 07-math-functions.py
python3 11-grouped-gemm.py
- name: Run CXX unittests
run: |
Expand Down
1 change: 1 addition & 0 deletions python/tutorials/05-layer-norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='xpu'):
assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)
assert torch.allclose(dx_tri, dx_ref, atol=1e-2, rtol=0)
assert torch.allclose(db_tri, db_ref, atol=1e-2, rtol=0)
# FIXME: tolerance was increased from 1e-2 to 2e-2 to allow test to run
assert torch.allclose(dw_tri, dw_ref, atol=2e-2, rtol=0)


Expand Down
1 change: 1 addition & 0 deletions python/tutorials/11-grouped-gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def group_gemm_fn(group_A, group_B):
tri_out = group_gemm_fn(group_A, group_B)
ref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)]
for i in range(group_size):
# FIXME: tolerance was increased from 1e-2 to 3e-2 to allow test to run
assert torch.allclose(ref_out[i], tri_out[i], atol=3e-1, rtol=0)


Expand Down
1 change: 1 addition & 0 deletions scripts/test-triton.sh
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ function run_tutorial_tests {
run_tutorial_test "04-low-memory-dropout" 04-low-memory-dropout.py
run_tutorial_test "05-layer-norm" 05-layer-norm.py
run_tutorial_test "07-math-functions" 07-math-functions.py
run_tutorial_test "11-grouped-gemm" 11-grouped-gemm.py
}

function test_triton {
Expand Down

0 comments on commit 5a0d78f

Please sign in to comment.