Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
lubaskinc0de committed Jul 29, 2024
1 parent faffe2a commit d8ca305
Show file tree
Hide file tree
Showing 13 changed files with 132 additions and 106 deletions.
6 changes: 5 additions & 1 deletion src/dataclass_rest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
__all__ = [
"File",
"rest",
"get", "put", "post", "patch", "delete",
"get",
"put",
"post",
"patch",
"delete",
]

from .http_request import File
Expand Down
47 changes: 24 additions & 23 deletions src/dataclass_rest/boundmethod.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy

from abc import ABC, abstractmethod
from inspect import getcallargs
Expand All @@ -15,11 +14,11 @@

class BoundMethod(ClientMethodProtocol, ABC):
def __init__(
self,
name: str,
method_spec: MethodSpec,
client: ClientProtocol,
on_error: Optional[Callable[[Any], Any]],
self,
name: str,
method_spec: MethodSpec,
client: ClientProtocol,
on_error: Optional[Callable[[Any], Any]],
):
self.name = name
self.method_spec = method_spec
Expand All @@ -28,29 +27,31 @@ def __init__(

def _apply_args(self, *args, **kwargs) -> Dict:
return getcallargs(
self.method_spec.func, self.client, *args, **kwargs,
self.method_spec.func,
self.client,
*args,
**kwargs,
)

def _get_url(self, args) -> str:
args = copy.copy(args)

if not self.method_spec.url_template_func_pop_args:
return self.method_spec.url_template_func(**args)

for arg in self.method_spec.url_template_func_pop_args:
args.pop(arg)

return self.method_spec.url_template_func(**args)
args = {
arg: value
for arg, value in args.items()
if arg in self.method_spec.url_params
}
return self.method_spec.url_template(**args)

def _get_body(self, args) -> Any:
python_body = args.get(self.method_spec.body_param_name)
return self.client.request_body_factory.dump(
python_body, self.method_spec.body_type,
python_body,
self.method_spec.body_type,
)

def _get_query_params(self, args) -> Any:
return self.client.request_args_factory.dump(
args, self.method_spec.query_params_type,
args,
self.method_spec.query_params_type,
)

def _get_files(self, args) -> Dict[str, File]:
Expand All @@ -61,11 +62,11 @@ def _get_files(self, args) -> Dict[str, File]:
}

def _create_request(
self,
url: str,
query_params: Any,
files: Dict[str, File],
data: Any,
self,
url: str,
query_params: Any,
files: Dict[str, File],
data: Any,
) -> HttpRequest:
return HttpRequest(
method=self.method_spec.http_method,
Expand Down
7 changes: 5 additions & 2 deletions src/dataclass_rest/client_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def load(self, data: Any, class_: Type[TypeT]) -> TypeT:
raise NotImplementedError

def dump(
self, data: TypeT, class_: Optional[Type[TypeT]] = None,
self,
data: TypeT,
class_: Optional[Type[TypeT]] = None,
) -> Any:
raise NotImplementedError

Expand All @@ -37,6 +39,7 @@ class ClientProtocol(Protocol):
method_class: Optional[Callable]

def do_request(
self, request: HttpRequest,
self,
request: HttpRequest,
) -> Any:
raise NotImplementedError
9 changes: 5 additions & 4 deletions src/dataclass_rest/http/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ class AiohttpClient(BaseClient):
method_class = AiohttpMethod

def __init__(
self,
base_url: str,
session: Optional[ClientSession] = None,
self,
base_url: str,
session: Optional[ClientSession] = None,
):
super().__init__()
self.session = session or ClientSession()
Expand All @@ -68,7 +68,8 @@ async def do_request(self, request: HttpRequest) -> Any:
for name, file in request.files.items():
data.add_field(
name,
filename=file.filename, content_type=file.content_type,
filename=file.filename,
content_type=file.content_type,
value=file.contents,
)
try:
Expand Down
7 changes: 3 additions & 4 deletions src/dataclass_rest/http/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@


class RequestsMethod(SyncMethod):

def _on_error_default(self, response: Response) -> Any:
if 400 <= response.status_code < 500:
raise ClientError(response.status_code)
Expand All @@ -39,9 +38,9 @@ class RequestsClient(BaseClient):
method_class = RequestsMethod

def __init__(
self,
base_url: str,
session: Optional[Session] = None,
self,
base_url: str,
session: Optional[Session] = None,
):
super().__init__()
self.session = session or Session()
Expand Down
10 changes: 6 additions & 4 deletions src/dataclass_rest/method.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

class Method:
def __init__(
self,
method_spec: MethodSpec,
method_class: Optional[Callable[..., BoundMethod]] = None,
self,
method_spec: MethodSpec,
method_class: Optional[Callable[..., BoundMethod]] = None,
):
self.name = method_spec.func.__name__
self.method_spec = method_spec
Expand All @@ -29,7 +29,9 @@ def __set_name__(self, owner, name):
)

def __get__(
self, instance: Optional[ClientProtocol], objtype=None,
self,
instance: Optional[ClientProtocol],
objtype=None,
) -> BoundMethod:
return self.method_class(
name=self.name,
Expand Down
30 changes: 14 additions & 16 deletions src/dataclass_rest/methodspec.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,24 @@
from typing import Any, Dict, Type, Callable, List, Optional
from typing import Any, Callable, Dict, List, Type


class MethodSpec:
def __init__(
self,
func: Callable,
url_template: Optional[str],
url_template_func: Optional[Callable[..., str]],
url_template_func_pop_args: Optional[List[str]],
http_method: str,
response_type: Type,
body_param_name: str,
body_type: Type,
is_json_request: bool,
query_params_type: Type,
file_param_names: List[str],
additional_params: Dict[str, Any],
self,
func: Callable,
url_template: Callable[..., str],
url_params: List[str],
http_method: str,
response_type: Type,
body_param_name: str,
body_type: Type,
is_json_request: bool, # noqa: FBT001
query_params_type: Type,
file_param_names: List[str],
additional_params: Dict[str, Any],
):
self.func = func
self.url_template = url_template
self.url_template_func = url_template_func
self.url_template_func_pop_args = url_template_func_pop_args
self.url_params = url_params
self.http_method = http_method
self.response_type = response_type
self.body_param_name = body_param_name
Expand Down
63 changes: 33 additions & 30 deletions src/dataclass_rest/parse_func.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,33 @@
import string
from inspect import getfullargspec, FullArgSpec, isclass
from typing import Callable, List, Sequence, Any, Type, TypedDict, Dict, Union
from inspect import FullArgSpec, getfullargspec, isclass
from typing import (
Any,
Callable,
Dict,
List,
Sequence,
Type,
TypeAlias,
TypedDict,
Union,
)

from .http_request import File
from .methodspec import MethodSpec

DEFAULT_BODY_PARAM = "body"
UrlTemplate: TypeAlias = Union[str, Callable[..., str]]


def get_url_params_from_string(url_template: str) -> List[str]:
parsed_format = string.Formatter().parse(url_template)
return [x[1] for x in parsed_format]
return [x[1] for x in parsed_format if x[1]]


def create_query_params_type(
spec: FullArgSpec,
func: Callable,
skipped: Sequence[str],
spec: FullArgSpec,
func: Callable,
skipped: Sequence[str],
) -> Type:
fields = {}
self_processed = False
Expand All @@ -31,14 +42,14 @@ def create_query_params_type(


def create_body_type(
spec: FullArgSpec,
body_param_name: str,
spec: FullArgSpec,
body_param_name: str,
) -> Type:
return spec.annotations.get(body_param_name, Any)


def create_response_type(
spec: FullArgSpec,
spec: FullArgSpec,
) -> Type:
return spec.annotations.get("return", Any)

Expand All @@ -52,31 +63,24 @@ def get_file_params(spec):


def parse_func(
func: Callable,
method: str,
url_template: Union[str, Callable[..., str]],
additional_params: Dict[str, Any],
is_json_request: bool,
body_param_name: str,
func: Callable,
method: str,
url_template: UrlTemplate,
additional_params: Dict[str, Any],
is_json_request: bool, # noqa: FBT001
body_param_name: str,
) -> MethodSpec:
spec = getfullargspec(func)
file_params = get_file_params(spec)

is_string_url_template = isinstance(url_template, str)
url_template_func = url_template.format if is_string_url_template else url_template

url_template_func_pop_args = None
url_template_callable = (
url_template.format if is_string_url_template else url_template
)

if not is_string_url_template:
url_template_func_arg_spec = getfullargspec(url_template_func)
url_template_func_args = url_template_func_arg_spec.args

url_template_func_args_set = set(url_template_func_args)
diff_kwargs = set(spec.kwonlyargs).difference(url_template_func_args_set)
diff_args = set(spec.args).difference(url_template_func_args_set)

url_template_func_pop_args = diff_args.union(diff_kwargs)
url_params = url_template_func_args
url_template_func_arg_spec = getfullargspec(url_template_callable)
url_params = url_template_func_arg_spec.args
else:
url_params = get_url_params_from_string(url_template)

Expand All @@ -85,9 +89,8 @@ def parse_func(
return MethodSpec(
func=func,
http_method=method,
url_template=url_template if is_string_url_template else None,
url_template_func=url_template_func,
url_template_func_pop_args=url_template_func_pop_args,
url_template=url_template_callable,
url_params=url_params,
query_params_type=create_query_params_type(spec, func, skipped_params),
body_type=create_body_type(spec, body_param_name),
response_type=create_response_type(spec),
Expand Down
4 changes: 2 additions & 2 deletions src/dataclass_rest/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from .boundmethod import BoundMethod
from .method import Method
from .parse_func import DEFAULT_BODY_PARAM, parse_func
from .parse_func import DEFAULT_BODY_PARAM, UrlTemplate, parse_func

_Func = TypeVar("_Func", bound=Callable[..., Any])


def rest(
url_template: str,
url_template: UrlTemplate,
*,
method: str,
body_name: str = DEFAULT_BODY_PARAM,
Expand Down
3 changes: 2 additions & 1 deletion tests/requests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def session():
@pytest.fixture
def mocker(session):
with requests_mock.Mocker(
session=session, case_sensitive=True,
session=session,
case_sensitive=True,
) as session_mock:
yield session_mock
24 changes: 15 additions & 9 deletions tests/requests/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,25 @@ class ResponseBody:
def test_body(session, mocker):
class Api(RequestsClient):
def _init_request_body_factory(self) -> Retort:
return Retort(recipe=[
name_mapping(name_style=NameStyle.CAMEL),
])
return Retort(
recipe=[
name_mapping(name_style=NameStyle.CAMEL),
],
)

def _init_request_args_factory(self) -> Retort:
return Retort(recipe=[
name_mapping(name_style=NameStyle.UPPER_DOT),
])
return Retort(
recipe=[
name_mapping(name_style=NameStyle.UPPER_DOT),
],
)

def _init_response_body_factory(self) -> Retort:
return Retort(recipe=[
name_mapping(name_style=NameStyle.LOWER_KEBAB),
])
return Retort(
recipe=[
name_mapping(name_style=NameStyle.LOWER_KEBAB),
],
)

@patch("/post/")
def post_x(self, long_param: str, body: RequestBody) -> ResponseBody:
Expand Down
Loading

0 comments on commit d8ca305

Please sign in to comment.