Skip to content

Commit

Permalink
Fix bwd code
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Nov 18, 2024
1 parent 8a85210 commit 90bbb73
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tritonbench/utils/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,15 +556,15 @@ def __init__(
self.tb_args.cudagraph if self.tb_args.cudagraph else self.use_cuda_graphs
)
# we accept both "fwd" and "eval"
if self.tb_args.mode == "fwd":
if self.tb_args.mode == "fwd" or self.tb_args.fwd:
self.mode = Mode.FWD
elif self.tb_args.mode == "fwd_bwd":
elif self.tb_args.mode == "fwd_bwd" or self.tb_args.fwd_bwd:
self.mode = Mode.FWD_BWD
elif self.tb_args.mode == "fwd_no_grad":
elif self.tb_args.mode == "fwd_no_grad" or self.tb_args.fwd_no_grad:
self.mode = Mode.FWD_NO_GRAD
else:
assert (
self.tb_args.mode == "bwd"
self.tb_args.mode == "bwd" or self.tb_args.bwd
), f"We only accept 3 test modes: fwd(eval), fwd_bwd(train), or bwd."
self.mode = Mode.BWD
self.device = tb_args.device
Expand Down

0 comments on commit 90bbb73

Please sign in to comment.