Skip to content

Commit

Permalink
Remove dead code branches from test_reduce1d (#3265)
Browse files Browse the repository at this point in the history
Signed-off-by: Anatoly Myachev <[email protected]>
  • Loading branch information
anmyachev authored Jan 25, 2025
1 parent 8c8a722 commit ec278de
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2285,18 +2285,16 @@ def kernel(X, Z, BLOCK: tl.constexpr):
'min': np.min,
'max-with-indices': np.max,
'min-with-indices': np.min,
'argmin-tie-break-fast': np.argmin,
'argmin-tie-break-left': np.argmin,
'argmax-tie-break-fast': np.argmax,
'argmax-tie-break-left': np.argmax,
}[op]
if 'tie-break-left' in op:
x[3:10] = x[numpy_op(x)]
x_tri = to_triton(x, device=device)
# numpy result
z_dtype_str = 'int32' if op in ('argmin', 'argmax') else dtype_str
z_dtype_str = 'int32' if 'tie-break-left' in op else dtype_str
z_tri_dtype_str = z_dtype_str
if op not in ['argmin', 'argmax'] and dtype_str == 'bfloat16':
if 'tie-break-left' not in op and dtype_str == 'bfloat16':
z_dtype_str = 'float32'
z_ref = numpy_op(x).astype(getattr(np, z_dtype_str))
# trunc mantissa for a fair comparison of accuracy
Expand All @@ -2316,7 +2314,7 @@ def kernel(X, Z, BLOCK: tl.constexpr):
if op == 'sum':
np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
else:
if op in ('argmin', 'argmax'):
if 'tie-break-left' in op:
# argmin and argmax can have multiple valid indices.
# so instead we compare the values pointed by indices
np.testing.assert_equal(x[z_ref], x[z_tri])
Expand Down

0 comments on commit ec278de

Please sign in to comment.