Skip to content

Commit

Permalink
Fix ufmt
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Nov 18, 2024
1 parent 9655a3a commit 6e62004
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
9 changes: 6 additions & 3 deletions test/test_gpu/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

print(f"Testing operators: {TEST_OPERATORS}")


def check_ci_output(op):
from tritonbench.utils.triton_op import REGISTERED_BENCHMARKS

Expand All @@ -54,9 +55,7 @@ def check_ci_output(op):
), f"output impls: {output_impls} != ci_enabled impls: {ci_enabled_impls}"


def _run_one_operator(
args: List[str]
):
def _run_one_operator(args: List[str]):
parser = get_parser(args)
tb_args, extra_args = parser.parse_known_args(args)
if tb_args.op in skip_tests:
Expand All @@ -81,8 +80,10 @@ def _run_one_operator(
f"Operator {op.name} does not support backward, skipping backward test."
)


def _run_operator_in_task(op: str, args: List[str]):
from tritonbench.operators.op_task import OpTask

if op in skip_tests:
# If the op itself is in the skip list, skip all tests
if not skip_tests[op]:
Expand All @@ -100,6 +101,7 @@ def _run_operator_in_task(op: str, args: List[str]):
task.run()
task.check_output()


def make_test(operator):
def test_case(self):
# Add `--test-only` to disable Triton autotune in tests
Expand All @@ -116,6 +118,7 @@ def test_case(self):
_run_one_operator(args)
else:
_run_operator_in_task(op=operator, args=args)

return test_case


Expand Down
2 changes: 2 additions & 0 deletions tritonbench/operators/op_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def make_operator_instance(
args: List[str],
) -> None:
from tritonbench.utils.parser import get_parser

parser = get_parser()
tb_args, extra_args = parser.parse_known_args(args)
Operator = globals()["Operator"]
Expand Down Expand Up @@ -187,6 +188,7 @@ def get_attribute(
def check_output() -> None:
op = globals()["op"]
from tritonbench.utils.triton_op import REGISTERED_BENCHMARKS

output = op.output
output_impls = output.result[0][1].keys()
ci_enabled_impls = [
Expand Down

0 comments on commit 6e62004

Please sign in to comment.