diff --git a/dataclass_rest/rest.py b/dataclass_rest/rest.py index 9394811..21e03b0 100644 --- a/dataclass_rest/rest.py +++ b/dataclass_rest/rest.py @@ -1,19 +1,20 @@ -from functools import partial -from typing import Any, Dict, Optional, Callable +from typing import Any, Callable, Dict, Optional, TypeVar, cast from .boundmethod import BoundMethod from .method import Method -from .parse_func import parse_func, DEFAULT_BODY_PARAM +from .parse_func import DEFAULT_BODY_PARAM, parse_func + +_Func = TypeVar("_Func", bound=Callable[..., Any]) def rest( - url_template: str, - *, - method: str, - body_name: str = DEFAULT_BODY_PARAM, - additional_params: Optional[Dict[str, Any]] = None, - method_class: Optional[Callable[..., BoundMethod]] = None, - send_json: bool = True, + url_template: str, + *, + method: str, + body_name: str = DEFAULT_BODY_PARAM, + additional_params: Optional[Dict[str, Any]] = None, + method_class: Optional[Callable[..., BoundMethod]] = None, + send_json: bool = True, ) -> Callable[[Callable], Method]: if additional_params is None: additional_params = {} @@ -32,8 +33,15 @@ def dec(func: Callable) -> Method: return dec -get = partial(rest, method="GET") -post = partial(rest, method="POST") -put = partial(rest, method="PUT") -patch = partial(rest, method="PATCH") -delete = partial(rest, method="DELETE") +def _rest_method(func: _Func, method: str) -> _Func: + def wrapper(*args, **kwargs): + return func(*args, **kwargs, method=method) + + return cast(_Func, wrapper) + + +get = _rest_method(rest, method="GET") +post = _rest_method(rest, method="POST") +put = _rest_method(rest, method="PUT") +patch = _rest_method(rest, method="PATCH") +delete = _rest_method(rest, method="DELETE")