test(qlinear): increase tolerance when using Marlin FP8
dacorvo committed Aug 28, 2024
1 parent 7830ca7 commit 455a7c7
Showing 1 changed file with 5 additions and 1 deletion.
test/nn/test_qlinear.py: 5 additions & 1 deletion
@@ -119,7 +119,11 @@ def test_quantize_linear_float32_activations_float8(
 def test_quantize_linear_float16_weight_only(batch_size, tokens, embeddings, use_bias, weights, device):
     if device.type == "mps" and weights == qfloat8:
         pytest.skip("Float 8 are not supported on MPS device")
-    _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, None, torch.float16, device)
+    atol = None
+    if device.type == "cuda" and weights == qfloat8 and embeddings % 64 == 0:
+        # FIXME: accuracy is slightly worse using MARLIN FP8 kernels
+        atol = 1e-2
+    _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, None, torch.float16, device, atol)
 
 
 @pytest.mark.parametrize("batch_size", [1, 10])
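The change threads an optional tolerance through to the shared helper: on CUDA with qfloat8 weights and embeddings divisible by 64 (the case where the Marlin FP8 kernels are used), the looser atol of 1e-2 is passed instead of the default. The helper's body is not part of this diff, so the sketch below is only an assumption about how such an atol override could be consumed; the _check_outputs name and the fallback default are hypothetical.

import torch

def _check_outputs(ref: torch.Tensor, out: torch.Tensor, atol: float | None = None):
    # Hypothetical fallback: a tighter default tolerance for the
    # non-Marlin code paths, overridden by the caller when needed.
    if atol is None:
        atol = 1e-3
    max_diff = (ref - out).abs().max().item()
    # torch.allclose also applies a relative tolerance (rtol=1e-5 by default).
    assert torch.allclose(ref, out, atol=atol), f"max abs diff {max_diff:.2e} exceeds atol={atol}"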
