From 4c0f84d620ae7f7515d4a874bcd8f48dc3c6afb0 Mon Sep 17 00:00:00 2001 From: Boris Filippov Date: Thu, 19 Aug 2021 19:42:37 +0400 Subject: [PATCH] Add decorator to parse function type hints --- docs/guide.md | 50 +++++++++++++++++++++++++++++++++++++++++ fire/decorators.py | 40 +++++++++++++++++++++++++++++++++ fire/decorators_test.py | 36 +++++++++++++++++++++++++++++ 3 files changed, 126 insertions(+) diff --git a/docs/guide.md b/docs/guide.md index d5da3212..4d99cbc6 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -692,6 +692,56 @@ flag (as in `--obj=True`), or by making sure there's another flag after any boolean flag argument. +#### Type hints + +Fire can be configured to use type hints information by decorating functions with `UseTypeHints()` decorator. +Only `int`, `float` and `str` type hints are respected by default, everything else is ignored (parsed as usual). +Quite common usecase is to instruct fire not to convert strings to integer/floats by supplying `str` +type annotation. + +See minimal example below: + +```python +import fire + +from fire.decorators import UseTypeHints + + +@UseTypeHints() # () are mandatory here +def main(a: str, b: float): + print(type(a), type(b)) + + +if __name__ == "__main__": + fire.Fire(main) +``` + +When invoked with `python command.py 1 2` this code will print `str float`. + +You can set custom parsers for type hints via decorator argument, following example shows how to parse string to `pathlib.Path` object: + +```python +import fire + +from pathlib import Path +from fire.decorators import UseTypeHints + + +@UseTypeHints({Path: Path}) +def main(a: Path, b: str): + print(a) + + +if __name__ == "__main__": + fire.Fire(main) +``` + +This code will convert argument `a` to `pathlib.Path`. + +To override default behavior for `int`, `str`, and `float` type hints you need to add them into dictionary supplied to +`UseTypeHints` decorator. + + ### Using Fire Flags Fire CLIs all come with a number of flags. These flags should be separated from diff --git a/fire/decorators.py b/fire/decorators.py index 9e56d6df..60bbe93e 100644 --- a/fire/decorators.py +++ b/fire/decorators.py @@ -29,6 +29,46 @@ ACCEPTS_POSITIONAL_ARGS = 'ACCEPTS_POSITIONAL_ARGS' +def UseTypeHints(type_hints_mapping=None): + """Instruct fire to use type hints information when parsing args for this + function. + + Args: + type_hints_mapping: mapping of type hints into parsing functions, by + default floats, ints and strings are treated, and all other type + hints are ignored (parsed as usual) + Returns: + The decorated function, which now has metadata telling Fire how to perform + according to type hints. + + Examples: + @UseTypeHints() + def main(a, b:int, c:float=2.0) + assert isinstance(b, int) + assert isinstance(c, float) + + @UseTypeHints({list: lambda s: s.split(";")}) + def main(a, c: list): + assert isinstance(c, list) + """ + mapping = {float: float, int: int, str: str} + if type_hints_mapping is not None: + mapping.update(type_hints_mapping) + type_hints_mapping = mapping + + def _Decorator(fn): + signature = inspect.signature(fn) + named = {} + for name, param in signature.parameters.items(): + has_type_hint = param.annotation is not param.empty + if has_type_hint and param.annotation in type_hints_mapping: + named[name] = type_hints_mapping[param.annotation] + decorator = SetParseFns(**named) + decorated_func = decorator(fn) + return decorated_func + return _Decorator + + def SetParseFn(fn, *arguments): """Sets the fn for Fire to use to parse args when calling the decorated fn. diff --git a/fire/decorators_test.py b/fire/decorators_test.py index cc7d6203..7216e7a0 100644 --- a/fire/decorators_test.py +++ b/fire/decorators_test.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys +import unittest from fire import core from fire import decorators @@ -90,6 +92,22 @@ def example7(self, arg1, arg2=None, *varargs, **kwargs): # pylint: disable=keyw return arg1, arg2, varargs, kwargs +if sys.version_info >= (3, 5): + from pathlib import Path + + + class WithTypeHints(object): + + @decorators.UseTypeHints() + def example8(self, a: int, b: str, c, d : float = None): + return a, b, c, d + + @decorators.UseTypeHints({list: lambda arg: list(map(int, arg.split(";"))), + Path: Path}) + def example9(self, a: Path, b, c: list, d : list = None): + return a, b, c, d + + class FireDecoratorsTest(testutils.BaseTestCase): def testSetParseFnsNamedArgs(self): @@ -169,6 +187,24 @@ def testSetParseFn(self): command=['example7', '1', '--arg2=2', '3', '4', '--kwarg=5']), ('1', '2', ('3', '4'), {'kwarg': '5'})) + @unittest.skipIf(sys.version_info < (3, 5), + 'Type hints were introduced in python 3.5') + def testDefaultTypeHints(self): + self.assertEqual( + core.Fire(WithTypeHints, + command=['example8', '1', '2', '3', '--d=4']), + (1, '2', 3, 4) + ) + + @unittest.skipIf(sys.version_info < (3, 5), + 'Type hints were introduced in python 3.5') + def testCustomTypeHints(self): + self.assertEqual( + core.Fire(WithTypeHints, + command=['example9', '1', '2', '3', '--d=4;5;6']), + (Path('1'), 2, [3], [4, 5, 6]) + ) + if __name__ == '__main__': testutils.main()