From 0978de8104046250593806c927b416e4ffbbe293 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Wed, 7 Aug 2024 08:33:22 -0700 Subject: [PATCH] Passes `gr.Request` if type hint is `Request | None` (#9011) * changes * fix for python 3.8, 3.9 * fix * add changeset --------- Co-authored-by: gradio-pr-bot --- .changeset/salty-terms-grow.md | 5 +++++ gradio/helpers.py | 2 +- gradio/utils.py | 14 +++++++++++--- test/test_utils.py | 8 +++++++- 4 files changed, 24 insertions(+), 5 deletions(-) create mode 100644 .changeset/salty-terms-grow.md diff --git a/.changeset/salty-terms-grow.md b/.changeset/salty-terms-grow.md new file mode 100644 index 0000000000000..48c21e35e129a --- /dev/null +++ b/.changeset/salty-terms-grow.md @@ -0,0 +1,5 @@ +--- +"gradio": patch +--- + +fix:Passes `gr.Request` if type hint is `Request | None` diff --git a/gradio/helpers.py b/gradio/helpers.py index 6f5072647d1d2..22149e763e280 100644 --- a/gradio/helpers.py +++ b/gradio/helpers.py @@ -902,7 +902,7 @@ def special_args( progress_index = i if inputs is not None: inputs.insert(i, param.default) - elif type_hint == routes.Request: + elif type_hint in (routes.Request, Optional[routes.Request]): if inputs is not None: inputs.insert(i, request) elif type_hint in ( diff --git a/gradio/utils.py b/gradio/utils.py index 573a7a1836f34..84d4ca8f8a98c 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -923,12 +923,20 @@ def get_type_hints(fn): for name, param in sig.parameters.items(): if param.annotation is inspect.Parameter.empty: continue - if param.annotation == "gr.OAuthProfile | None": + if param.annotation in ["gr.OAuthProfile | None", "None | gr.OAuthProfile"]: # Special case: we want to inject the OAuthProfile value even on Python 3.9 type_hints[name] = Optional[OAuthProfile] - if param.annotation == "gr.OAuthToken | None": + if param.annotation == ["gr.OAuthToken | None", "None | gr.OAuthToken"]: # Special case: we want to inject the OAuthToken value even on Python 3.9 type_hints[name] = Optional[OAuthToken] + if param.annotation in [ + "gr.Request | None", + "Request | None", + "None | gr.Request", + "None | Request", + ]: + # Special case: we want to inject the Request value even on Python 3.9 + type_hints[name] = Optional[Request] if "|" in str(param.annotation): continue # To convert the string annotation to a class, we use the @@ -953,7 +961,7 @@ def is_special_typed_parameter(name, parameter_types): hint = parameter_types.get(name) if not hint: return False - is_request = hint == Request + is_request = hint in (Request, Optional[Request]) is_oauth_arg = hint in ( OAuthProfile, Optional[OAuthProfile], diff --git a/test/test_utils.py b/test/test_utils.py index d3ad92818dae7..ed09b314c6204 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -306,7 +306,7 @@ class GenericObject: assert len(get_type_hints(GenericObject())) == 0 def test_is_special_typed_parameter(self): - def func(a: list[str], b: Literal["a", "b"], c, d: Request): + def func(a: list[str], b: Literal["a", "b"], c, d: Request, e: Request | None): pass hints = get_type_hints(func) @@ -314,6 +314,7 @@ def func(a: list[str], b: Literal["a", "b"], c, d: Request): assert not is_special_typed_parameter("b", hints) assert not is_special_typed_parameter("c", hints) assert is_special_typed_parameter("d", hints) + assert is_special_typed_parameter("e", hints) def test_is_special_typed_parameter_with_pipe(self): def func(a: Request, b: str | int, c: list[str]): @@ -502,6 +503,11 @@ def func(a, r: Request, b=10): assert get_function_params(func) == [("a", False, None), ("b", True, 10)] + def func2(a, r: Request | None = None, b="abc"): + pass + + assert get_function_params(func2) == [("a", False, None), ("b", True, "abc")] + def test_class_method_skip_first_param(self): class MyClass: def method(self, arg1, arg2=42):