diff --git a/test/nn/test_qlinear.py b/test/nn/test_qlinear.py
index 08320f3a..2351d057 100644
--- a/test/nn/test_qlinear.py
+++ b/test/nn/test_qlinear.py
@@ -119,7 +119,11 @@ def test_quantize_linear_float32_activations_float8(
 def test_quantize_linear_float16_weight_only(batch_size, tokens, embeddings, use_bias, weights, device):
     if device.type == "mps" and weights == qfloat8:
         pytest.skip("Float 8 are not supported on MPS device")
-    _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, None, torch.float16, device)
+    atol = None
+    if device.type == "cuda" and weights == qfloat8 and embeddings % 64 == 0:
+        # FIXME: accuracy is slightly worse using MARLIN FP8 kernels
+        atol = 1e-2
+    _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, None, torch.float16, device, atol)
 
 
 @pytest.mark.parametrize("batch_size", [1, 10])
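
The body of `_test_quantize_linear` is not part of this diff, so here is a minimal sketch of how the trailing `atol` argument could be consumed: pass an explicit absolute tolerance only when the caller provides one (e.g. `1e-2` on the CUDA + `qfloat8` path where `embeddings % 64 == 0` selects the MARLIN FP8 kernels), otherwise fall back to `torch.testing.assert_close`'s dtype-based defaults. The helper name `_assert_similar`, the `rtol=0.0` choice, and the stand-in tensors below are assumptions for illustration, not the repository's actual code.

```python
import torch

def _assert_similar(t_ref, t_q, atol=None):
    # Hypothetical comparison helper. torch.testing.assert_close requires
    # rtol and atol to be given together, so when an explicit atol is
    # supplied we pair it with a zero rtol and rely on atol alone.
    if atol is None:
        torch.testing.assert_close(t_q, t_ref)
    else:
        torch.testing.assert_close(t_q, t_ref, atol=atol, rtol=0.0)

# Usage mirroring the test change: loosen the tolerance only for the
# kernel path whose accuracy is slightly worse.
ref = torch.randn(10, 256, dtype=torch.float16)
quantized_out = ref + torch.randn_like(ref) * 5e-3  # stand-in for the dequantized output
_assert_similar(ref, quantized_out, atol=1e-2)
```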