diff --git a/requirements/common.txt b/requirements/common.txt index d17ab22..82884e0 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -1,2 +1,3 @@ httpx>=0.26.0 jsonschema>=3.2.0 +urllib3>=2.2.1 diff --git a/requirements/tests.txt b/requirements/tests.txt index 45b26c5..e12bdbb 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -2,5 +2,6 @@ -r common.txt pytest pytest-cov +pytest-asyncio tox selenium==3.141.0 diff --git a/src/pyDataverse/api.py b/src/pyDataverse/api.py index 14c4826..1b79d70 100644 --- a/src/pyDataverse/api.py +++ b/src/pyDataverse/api.py @@ -1,6 +1,7 @@ """Dataverse API wrapper for all it's API's.""" import json +from typing import Any, Dict, Optional import httpx import subprocess as sp from urllib.parse import urljoin @@ -37,7 +38,10 @@ class Api: """ def __init__( - self, base_url: str, api_token: str = None, api_version: str = "latest" + self, + base_url: str, + api_token: Optional[str] = None, + api_version: str = "latest", ): """Init an Api() class. @@ -64,6 +68,7 @@ def __init__( raise ApiUrlError("base_url {0} is not a string.".format(base_url)) self.base_url = base_url + self.client = None if not isinstance(api_version, ("".__class__, "".__class__)): raise ApiUrlError("api_version {0} is not a string.".format(api_version)) @@ -120,28 +125,17 @@ def get_request(self, url, params=None, auth=False): if self.api_token: params["key"] = str(self.api_token) - try: - url = urljoin(self.base_url_api, url) - resp = httpx.get(url, params=params) - if resp.status_code == 401: - error_msg = resp.json()["message"] - raise ApiAuthorizationError( - "ERROR: GET - Authorization invalid {0}. MSG: {1}.".format( - url, error_msg - ) - ) - elif resp.status_code >= 300: - if resp.text: - error_msg = resp.text - raise OperationFailedError( - "ERROR: GET HTTP {0} - {1}. MSG: {2}".format( - resp.status_code, url, error_msg - ) - ) - return resp - except ConnectError: - raise ConnectError( - "ERROR: GET - Could not establish connection to api {0}.".format(url) + if self.client is None: + return self._sync_request( + method=httpx.get, + url=url, + params=params, + ) + else: + return self._async_request( + method=self.client.get, + url=url, + params=params, ) def post_request(self, url, data=None, auth=False, params=None, files=None): @@ -174,19 +168,21 @@ def post_request(self, url, data=None, auth=False, params=None, files=None): if self.api_token: params["key"] = self.api_token - try: - resp = httpx.post(url, data=data, params=params, files=files) - if resp.status_code == 401: - error_msg = resp.json()["message"] - raise ApiAuthorizationError( - "ERROR: POST HTTP 401 - Authorization error {0}. MSG: {1}".format( - url, error_msg - ) - ) - return resp - except ConnectError: - raise ConnectError( - "ERROR: POST - Could not establish connection to API: {0}".format(url) + if self.client is None: + return self._sync_request( + method=httpx.post, + url=url, + data=data, + params=params, + files=files, + ) + else: + return self._async_request( + method=self.client.post, + url=url, + data=data, + params=params, + files=files, ) def put_request(self, url, data=None, auth=False, params=None): @@ -215,19 +211,19 @@ def put_request(self, url, data=None, auth=False, params=None): if self.api_token: params["key"] = self.api_token - try: - resp = httpx.put(url, data=data, params=params) - if resp.status_code == 401: - error_msg = resp.json()["message"] - raise ApiAuthorizationError( - "ERROR: PUT HTTP 401 - Authorization error {0}. MSG: {1}".format( - url, error_msg - ) - ) - return resp - except ConnectError: - raise ConnectError( - "ERROR: PUT - Could not establish connection to api '{0}'.".format(url) + if self.client is None: + return self._sync_request( + method=httpx.put, + url=url, + data=data, + params=params, + ) + else: + return self._async_request( + method=self.client.put, + url=url, + data=data, + params=params, ) def delete_request(self, url, auth=False, params=None): @@ -254,13 +250,141 @@ def delete_request(self, url, auth=False, params=None): if self.api_token: params["key"] = self.api_token + if self.client is None: + return self._sync_request( + method=httpx.delete, + url=url, + params=params, + ) + else: + return self._async_request( + method=self.client.delete, + url=url, + params=params, + ) + + def _sync_request( + self, + method, + **kwargs, + ): + """ + Sends a synchronous request to the specified URL using the specified HTTP method. + + Args: + method (function): The HTTP method to use for the request. + **kwargs: Additional keyword arguments to be passed to the method. + + Returns: + requests.Response: The response object returned by the request. + + Raises: + ApiAuthorizationError: If the response status code is 401 (Authorization error). + ConnectError: If a connection to the API cannot be established. + """ + assert "url" in kwargs, "URL is required for a request." + + kwargs = self._filter_kwargs(kwargs) + try: - return httpx.delete(url, params=params) + resp = method(**kwargs) + + if resp.status_code == 401: + error_msg = resp.json()["message"] + raise ApiAuthorizationError( + "ERROR: HTTP 401 - Authorization error {0}. MSG: {1}".format( + kwargs["url"], error_msg + ) + ) + + return resp + except ConnectError: raise ConnectError( - "ERROR: DELETE could not establish connection to api {}.".format(url) + "ERROR - Could not establish connection to api '{0}'.".format( + kwargs["url"] + ) ) + async def _async_request( + self, + method, + **kwargs, + ): + """ + Sends an asynchronous request to the specified URL using the specified HTTP method. + + Args: + method (callable): The HTTP method to use for the request. + **kwargs: Additional keyword arguments to be passed to the method. + + Raises: + ApiAuthorizationError: If the response status code is 401 (Authorization error). + ConnectError: If a connection to the API cannot be established. + + Returns: + The response object. + + """ + assert "url" in kwargs, "URL is required for a request." + + kwargs = self._filter_kwargs(kwargs) + + try: + resp = await method(**kwargs) + + if resp.status_code == 401: + error_msg = resp.json()["message"] + raise ApiAuthorizationError( + "ERROR: HTTP 401 - Authorization error {0}. MSG: {1}".format( + kwargs["url"], error_msg + ) + ) + + return resp + + except ConnectError: + raise ConnectError( + "ERROR - Could not establish connection to api '{0}'.".format( + kwargs["url"] + ) + ) + + @staticmethod + def _filter_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]: + """ + Filters out any keyword arguments that are `None` from the specified dictionary. + + Args: + kwargs (Dict[str, Any]): The dictionary to filter. + + Returns: + Dict[str, Any]: The filtered dictionary. + """ + return {k: v for k, v in kwargs.items() if v is not None} + + async def __aenter__(self): + """ + Context manager method that initializes an instance of httpx.AsyncClient. + + Returns: + httpx.AsyncClient: An instance of httpx.AsyncClient. + """ + self.client = httpx.AsyncClient() + + async def __aexit__(self, exc_type, exc_value, traceback): + """ + Closes the client connection when exiting a context manager. + + Args: + exc_type (type): The type of the exception raised, if any. + exc_value (Exception): The exception raised, if any. + traceback (traceback): The traceback object associated with the exception, if any. + """ + + await self.client.aclose() + self.client = None + class DataAccessApi(Api): """Class to access Dataverse's Data Access API. diff --git a/tests/api/test_async_api.py b/tests/api/test_async_api.py new file mode 100644 index 0000000..61b1910 --- /dev/null +++ b/tests/api/test_async_api.py @@ -0,0 +1,16 @@ +import asyncio +import pytest + + +class TestAsyncAPI: + + @pytest.mark.asyncio + async def test_async_api(self, native_api): + + async with native_api: + tasks = [native_api.get_info_version() for _ in range(10)] + responses = await asyncio.gather(*tasks) + + assert len(responses) == 10 + for response in responses: + assert response.status_code == 200, "Request failed."