From 6e62004edac3697c02d1b902db527fa5ffe162ec Mon Sep 17 00:00:00 2001 From: Xu Zhao Date: Mon, 18 Nov 2024 16:44:26 -0500 Subject: [PATCH] Fix ufmt --- test/test_gpu/main.py | 9 ++++++--- tritonbench/operators/op_task.py | 2 ++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/test/test_gpu/main.py b/test/test_gpu/main.py index ad515a3d..517a9ef9 100644 --- a/test/test_gpu/main.py +++ b/test/test_gpu/main.py @@ -39,6 +39,7 @@ print(f"Testing operators: {TEST_OPERATORS}") + def check_ci_output(op): from tritonbench.utils.triton_op import REGISTERED_BENCHMARKS @@ -54,9 +55,7 @@ def check_ci_output(op): ), f"output impls: {output_impls} != ci_enabled impls: {ci_enabled_impls}" -def _run_one_operator( - args: List[str] -): +def _run_one_operator(args: List[str]): parser = get_parser(args) tb_args, extra_args = parser.parse_known_args(args) if tb_args.op in skip_tests: @@ -81,8 +80,10 @@ def _run_one_operator( f"Operator {op.name} does not support backward, skipping backward test." ) + def _run_operator_in_task(op: str, args: List[str]): from tritonbench.operators.op_task import OpTask + if op in skip_tests: # If the op itself is in the skip list, skip all tests if not skip_tests[op]: @@ -100,6 +101,7 @@ def _run_operator_in_task(op: str, args: List[str]): task.run() task.check_output() + def make_test(operator): def test_case(self): # Add `--test-only` to disable Triton autotune in tests @@ -116,6 +118,7 @@ def test_case(self): _run_one_operator(args) else: _run_operator_in_task(op=operator, args=args) + return test_case diff --git a/tritonbench/operators/op_task.py b/tritonbench/operators/op_task.py index 7693c5c9..adc4c1de 100644 --- a/tritonbench/operators/op_task.py +++ b/tritonbench/operators/op_task.py @@ -122,6 +122,7 @@ def make_operator_instance( args: List[str], ) -> None: from tritonbench.utils.parser import get_parser + parser = get_parser() tb_args, extra_args = parser.parse_known_args(args) Operator = globals()["Operator"] @@ -187,6 +188,7 @@ def get_attribute( def check_output() -> None: op = globals()["op"] from tritonbench.utils.triton_op import REGISTERED_BENCHMARKS + output = op.output output_impls = output.result[0][1].keys() ci_enabled_impls = [