diff --git a/test/test_gpu/main.py b/test/test_gpu/main.py index 2738ba77..5e900d76 100644 --- a/test/test_gpu/main.py +++ b/test/test_gpu/main.py @@ -66,17 +66,13 @@ def _run_one_operator(args: List[str]): op = Operator(tb_args=tb_args, extra_args=extra_args) op.run() check_ci_output(op) - del op # Test backward (if applicable) - try: + if op.has_bwd(): + del op tb_args.mode = "bwd" op = Operator(tb_args=tb_args, extra_args=extra_args) op.run() check_ci_output(op) - except NotImplementedError: - logger.info( - f"Operator {op.name} does not support backward, skipping backward test." - ) def _run_operator_in_task(op: str, args: List[str]): @@ -92,16 +88,13 @@ def _run_operator_in_task(op: str, args: List[str]): task.make_operator_instance(args=args) task.run() task.check_output() - task.del_op_instance() # Test backward (if applicable) - try: + if task.get_attribute("has_bwd", method=True): + task.del_op_instance() args.extend(["--bwd"]) task.make_operator_instance(args=args) task.run() task.check_output() - except NotImplementedError: - # Operator does not support backward, skip the test - pass def make_test(operator): diff --git a/tritonbench/operators/op_task.py b/tritonbench/operators/op_task.py index adc4c1de..ff365a49 100644 --- a/tritonbench/operators/op_task.py +++ b/tritonbench/operators/op_task.py @@ -165,7 +165,10 @@ def run(self) -> None: @base_task.run_in_worker(scoped=True) @staticmethod def get_attribute( - attr: str, field: Optional[str] = None, classattr: bool = False + attr: str, + field: Optional[str] = None, + classattr: bool = False, + method: bool = False, ) -> Any: if classattr: op = globals()["Operator"] @@ -173,10 +176,10 @@ def get_attribute( op = globals()["op"] if hasattr(op, attr): if field: - op_attr = getattr(op, attr) - return getattr(op_attr, field) + op_attr = getattr(getattr(op, attr), field) else: - return getattr(op, attr) + op_attr = getattr(op, attr) + return op_attr() if method else op_attr else: return None diff --git a/tritonbench/utils/triton_op.py b/tritonbench/utils/triton_op.py index 785583c7..ccaf88cb 100644 --- a/tritonbench/utils/triton_op.py +++ b/tritonbench/utils/triton_op.py @@ -1517,3 +1517,7 @@ def run_and_capture(self, *args, **kwargs): ir_dir / f"{fn._name}_{kernel.name}_{input_id}.sass", "w" ) as f: f.write(sass) + + @classmethod + def has_bwd(cls) -> bool: + return cls.get_bwd_fn is not BenchmarkOperator.get_bwd_fn