From ceda820948aa47f75a95236704e5acf75bb8809e Mon Sep 17 00:00:00 2001 From: Bea Date: Wed, 7 Jul 2021 15:52:38 -0400 Subject: [PATCH 1/6] add formatter argument --- fire/core.py | 11 ++++++++--- fire/core_test.py | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/fire/core.py b/fire/core.py index 8ca142c7..c1af54d9 100644 --- a/fire/core.py +++ b/fire/core.py @@ -78,7 +78,7 @@ def main(argv): import asyncio # pylint: disable=import-error,g-import-not-at-top # pytype: disable=import-error -def Fire(component=None, command=None, name=None): +def Fire(component=None, command=None, name=None, formatter=None): """This function, Fire, is the main entrypoint for Python Fire. Executes a command either from the `command` argument or from sys.argv by @@ -164,7 +164,7 @@ def Fire(component=None, command=None, name=None): raise FireExit(0, component_trace) # The command succeeded normally; print the result. - _PrintResult(component_trace, verbose=component_trace.verbose) + _PrintResult(component_trace, verbose=component_trace.verbose, formatter=formatter) result = component_trace.GetResult() return result @@ -241,12 +241,17 @@ def _IsHelpShortcut(component_trace, remaining_args): return show_help -def _PrintResult(component_trace, verbose=False): +def _PrintResult(component_trace, verbose=False, formatter=None): """Prints the result of the Fire call to stdout in a human readable way.""" # TODO(dbieber): Design human readable deserializable serialization method # and move serialization to its own module. result = component_trace.GetResult() + # Allow users to modify the return value of the component and provide + # custom formatting. + if callable(formatter): + result = formatter(result) + if value_types.HasCustomStr(result): # If the object has a custom __str__ method, rather than one inherited from # object, then we use that to serialize the object. diff --git a/fire/core_test.py b/fire/core_test.py index 27c9f418..e9d12e96 100644 --- a/fire/core_test.py +++ b/fire/core_test.py @@ -194,6 +194,22 @@ def testClassMethod(self): 7, ) + def testCustomFormatter(self): + def formatter(x): + if isinstance(x, list): + return ', '.join(str(xi) for xi in x) + if isinstance(x, dict): + return ', '.join('{}={!r}'.format(k, v) for k, v in x.items()) + return x + + with self.assertOutputMatches(stdout='a, b', stderr=None): + result = core.Fire(lambda x: list(x), command=['[a,b]'], formatter=formatter) + with self.assertOutputMatches(stdout='a=5, b=6', stderr=None): + result = core.Fire(lambda x: dict(x), command=['{a:5,b:6}'], formatter=formatter) + with self.assertOutputMatches(stdout='asdf', stderr=None): + result = core.Fire(lambda x: str(x), command=['asdf'], formatter=formatter) + + @testutils.skipIf(six.PY2, 'lru_cache is Python 3 only.') def testLruCacheDecoratorBoundArg(self): self.assertEqual( From ffc8077b2a6783a348409bb99e8ed14ec985e133 Mon Sep 17 00:00:00 2001 From: Bea Date: Wed, 7 Jul 2021 15:59:45 -0400 Subject: [PATCH 2/6] added another test case --- fire/core_test.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/fire/core_test.py b/fire/core_test.py index e9d12e96..d79fc523 100644 --- a/fire/core_test.py +++ b/fire/core_test.py @@ -200,14 +200,20 @@ def formatter(x): return ', '.join(str(xi) for xi in x) if isinstance(x, dict): return ', '.join('{}={!r}'.format(k, v) for k, v in x.items()) + if x == 'special': + return 'SURPRISE!!' return x + + ident = lambda x: x with self.assertOutputMatches(stdout='a, b', stderr=None): - result = core.Fire(lambda x: list(x), command=['[a,b]'], formatter=formatter) + result = core.Fire(ident, command=['[a,b]'], formatter=formatter) with self.assertOutputMatches(stdout='a=5, b=6', stderr=None): - result = core.Fire(lambda x: dict(x), command=['{a:5,b:6}'], formatter=formatter) + result = core.Fire(ident, command=['{a:5,b:6}'], formatter=formatter) with self.assertOutputMatches(stdout='asdf', stderr=None): - result = core.Fire(lambda x: str(x), command=['asdf'], formatter=formatter) + result = core.Fire(ident, command=['asdf'], formatter=formatter) + with self.assertOutputMatches(stdout='SURPRISE!!', stderr=None): + result = core.Fire(ident, command=['special'], formatter=formatter) @testutils.skipIf(six.PY2, 'lru_cache is Python 3 only.') From 8a4963d8fc9617759419de7c191a05940462d277 Mon Sep 17 00:00:00 2001 From: Bea Date: Wed, 7 Jul 2021 16:01:35 -0400 Subject: [PATCH 3/6] improved test case --- fire/core_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fire/core_test.py b/fire/core_test.py index d79fc523..6bdc611a 100644 --- a/fire/core_test.py +++ b/fire/core_test.py @@ -201,7 +201,7 @@ def formatter(x): if isinstance(x, dict): return ', '.join('{}={!r}'.format(k, v) for k, v in x.items()) if x == 'special': - return 'SURPRISE!!' + return ['SURPRISE!!', "I'm a list!"] return x ident = lambda x: x @@ -212,7 +212,7 @@ def formatter(x): result = core.Fire(ident, command=['{a:5,b:6}'], formatter=formatter) with self.assertOutputMatches(stdout='asdf', stderr=None): result = core.Fire(ident, command=['asdf'], formatter=formatter) - with self.assertOutputMatches(stdout='SURPRISE!!', stderr=None): + with self.assertOutputMatches(stdout="SURPRISE!!\nI'm a list!\n", stderr=None): result = core.Fire(ident, command=['special'], formatter=formatter) From e837db20d1c0032f34eedc975a6e9eefb96471b1 Mon Sep 17 00:00:00 2001 From: Bea Date: Sun, 5 Dec 2021 17:10:27 -0500 Subject: [PATCH 4/6] rename formatter= to serialize= --- fire/core.py | 10 +++++----- fire/core_test.py | 12 ++++++------ 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/fire/core.py b/fire/core.py index c1af54d9..71c913a7 100644 --- a/fire/core.py +++ b/fire/core.py @@ -78,7 +78,7 @@ def main(argv): import asyncio # pylint: disable=import-error,g-import-not-at-top # pytype: disable=import-error -def Fire(component=None, command=None, name=None, formatter=None): +def Fire(component=None, command=None, name=None, serialize=None): """This function, Fire, is the main entrypoint for Python Fire. Executes a command either from the `command` argument or from sys.argv by @@ -164,7 +164,7 @@ def Fire(component=None, command=None, name=None, formatter=None): raise FireExit(0, component_trace) # The command succeeded normally; print the result. - _PrintResult(component_trace, verbose=component_trace.verbose, formatter=formatter) + _PrintResult(component_trace, verbose=component_trace.verbose, serialize=serialize) result = component_trace.GetResult() return result @@ -241,7 +241,7 @@ def _IsHelpShortcut(component_trace, remaining_args): return show_help -def _PrintResult(component_trace, verbose=False, formatter=None): +def _PrintResult(component_trace, verbose=False, serialize=None): """Prints the result of the Fire call to stdout in a human readable way.""" # TODO(dbieber): Design human readable deserializable serialization method # and move serialization to its own module. @@ -249,8 +249,8 @@ def _PrintResult(component_trace, verbose=False, formatter=None): # Allow users to modify the return value of the component and provide # custom formatting. - if callable(formatter): - result = formatter(result) + if callable(serialize): + result = serialize(result) if value_types.HasCustomStr(result): # If the object has a custom __str__ method, rather than one inherited from diff --git a/fire/core_test.py b/fire/core_test.py index 6bdc611a..87479f5d 100644 --- a/fire/core_test.py +++ b/fire/core_test.py @@ -194,8 +194,8 @@ def testClassMethod(self): 7, ) - def testCustomFormatter(self): - def formatter(x): + def testCustomSerialize(self): + def serialize(x): if isinstance(x, list): return ', '.join(str(xi) for xi in x) if isinstance(x, dict): @@ -207,13 +207,13 @@ def formatter(x): ident = lambda x: x with self.assertOutputMatches(stdout='a, b', stderr=None): - result = core.Fire(ident, command=['[a,b]'], formatter=formatter) + result = core.Fire(ident, command=['[a,b]'], serialize=serialize) with self.assertOutputMatches(stdout='a=5, b=6', stderr=None): - result = core.Fire(ident, command=['{a:5,b:6}'], formatter=formatter) + result = core.Fire(ident, command=['{a:5,b:6}'], serialize=serialize) with self.assertOutputMatches(stdout='asdf', stderr=None): - result = core.Fire(ident, command=['asdf'], formatter=formatter) + result = core.Fire(ident, command=['asdf'], serialize=serialize) with self.assertOutputMatches(stdout="SURPRISE!!\nI'm a list!\n", stderr=None): - result = core.Fire(ident, command=['special'], formatter=formatter) + result = core.Fire(ident, command=['special'], serialize=serialize) @testutils.skipIf(six.PY2, 'lru_cache is Python 3 only.') From 278334297e1ebbfcf9fcec5d894762132e65e03c Mon Sep 17 00:00:00 2001 From: Bea Date: Mon, 6 Dec 2021 12:59:50 -0500 Subject: [PATCH 5/6] raise error if serialize is not callable --- fire/core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fire/core.py b/fire/core.py index 71c913a7..456034c3 100644 --- a/fire/core.py +++ b/fire/core.py @@ -249,7 +249,9 @@ def _PrintResult(component_trace, verbose=False, serialize=None): # Allow users to modify the return value of the component and provide # custom formatting. - if callable(serialize): + if serialize: + if not callable(serialize): + raise FireError("serialize must be callable.") result = serialize(result) if value_types.HasCustomStr(result): From 1b2976685cf8cff8bb9f9bcece4502557cf74611 Mon Sep 17 00:00:00 2001 From: Bea Date: Mon, 6 Dec 2021 13:05:51 -0500 Subject: [PATCH 6/6] improve serialize non-callable error message and add test case --- fire/core.py | 2 +- fire/core_test.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/fire/core.py b/fire/core.py index 456034c3..6fd1bf7a 100644 --- a/fire/core.py +++ b/fire/core.py @@ -251,7 +251,7 @@ def _PrintResult(component_trace, verbose=False, serialize=None): # custom formatting. if serialize: if not callable(serialize): - raise FireError("serialize must be callable.") + raise FireError("serialize argument {} must be empty or callable.".format(serialize)) result = serialize(result) if value_types.HasCustomStr(result): diff --git a/fire/core_test.py b/fire/core_test.py index 87479f5d..a0576ee9 100644 --- a/fire/core_test.py +++ b/fire/core_test.py @@ -214,6 +214,8 @@ def serialize(x): result = core.Fire(ident, command=['asdf'], serialize=serialize) with self.assertOutputMatches(stdout="SURPRISE!!\nI'm a list!\n", stderr=None): result = core.Fire(ident, command=['special'], serialize=serialize) + with self.assertRaises(core.FireError): + core.Fire(ident, command=['asdf'], serialize=55) @testutils.skipIf(six.PY2, 'lru_cache is Python 3 only.')