diff --git a/src/setfit/training_args.py b/src/setfit/training_args.py index 6121794e..f33e0986 100644 --- a/src/setfit/training_args.py +++ b/src/setfit/training_args.py @@ -249,8 +249,6 @@ def __post_init__(self) -> None: if self.report_to in (None, "all", ["all"]): self.report_to = get_available_reporting_integrations() - elif self.report_to in ("none", ["none"]): - self.report_to = [] elif not isinstance(self.report_to, list): self.report_to = [self.report_to] diff --git a/tests/test_training_args.py b/tests/test_training_args.py index ecce4f42..5e035cd7 100644 --- a/tests/test_training_args.py +++ b/tests/test_training_args.py @@ -64,9 +64,9 @@ def test_learning_rates(self): def test_report_to(self): args = TrainingArguments(report_to="none") - self.assertEqual(args.report_to, []) + self.assertEqual(args.report_to, ["none"]) args = TrainingArguments(report_to=["none"]) - self.assertEqual(args.report_to, []) + self.assertEqual(args.report_to, ["none"]) args = TrainingArguments(report_to="hello") self.assertEqual(args.report_to, ["hello"])