diff --git a/axlearn/common/compiler_options.py b/axlearn/common/compiler_options.py index c1c056cd..bff095c0 100644 --- a/axlearn/common/compiler_options.py +++ b/axlearn/common/compiler_options.py @@ -144,7 +144,7 @@ def xla_flags_from_options(xla_options: dict[str, Union[str, bool, int]]) -> str flags = [] for k, v in xla_options.items(): if isinstance(v, bool): - v = "1" if v else "0" + v = "true" if v else "false" flags.append(f"--{k}={v}") return " ".join(flags) diff --git a/axlearn/common/compiler_options_test.py b/axlearn/common/compiler_options_test.py index ff146c2e..ed26f5da 100644 --- a/axlearn/common/compiler_options_test.py +++ b/axlearn/common/compiler_options_test.py @@ -24,10 +24,10 @@ def f(x: Tensor) -> Tensor: ) self.assertEqual(f_compiled(5), 15) - def atest_xla_flags_from_options(self): + def test_xla_flags_from_options(self): options = dict(a="true", b="false", c=True, d=False, long_option_name=True) result = compiler_options.xla_flags_from_options(options) - self.assertEqual(result, "--a=true --b=false --c=1 --d=0 --long_option_name=1") + self.assertEqual(result, "--a=true --b=false --c=true --d=false --long_option_name=true") def test_xsc_compiler_options(self): options = compiler_options.infer_xsc_compiler_options(