From 597ca5bebc1f2e858ef6410a98aac715ef4c0378 Mon Sep 17 00:00:00 2001 From: Kevin James Date: Fri, 23 Dec 2022 13:53:14 -0600 Subject: [PATCH] refactor(lint): enable add-trailing-comma flake8-commas is deprecated --- .pre-commit-config.yaml | 43 +- auth/gcloud/aio/auth/iam.py | 96 +-- auth/gcloud/aio/auth/session.py | 263 +++++---- auth/gcloud/aio/auth/token.py | 75 ++- auth/tests/integration/smoke_test.py | 34 +- auth/tests/unit/token_test.py | 8 +- auth/tests/unit/utils_test.py | 10 +- bigquery/gcloud/aio/bigquery/bigquery.py | 29 +- bigquery/gcloud/aio/bigquery/dataset.py | 52 +- bigquery/gcloud/aio/bigquery/job.py | 75 ++- bigquery/gcloud/aio/bigquery/table.py | 168 ++++-- bigquery/gcloud/aio/bigquery/utils.py | 43 +- bigquery/tests/integration/smoke_test.py | 62 +- bigquery/tests/unit/bigquery_test.py | 9 +- bigquery/tests/unit/utils_test.py | 556 ++++++++++++------ datastore/gcloud/aio/datastore/datastore.py | 240 +++++--- .../aio/datastore/datastore_operation.py | 16 +- datastore/gcloud/aio/datastore/entity.py | 43 +- datastore/gcloud/aio/datastore/filter.py | 18 +- datastore/gcloud/aio/datastore/key.py | 38 +- datastore/gcloud/aio/datastore/lat_lng.py | 3 +- datastore/gcloud/aio/datastore/mutation.py | 8 +- .../gcloud/aio/datastore/property_order.py | 9 +- datastore/gcloud/aio/datastore/query.py | 124 ++-- datastore/gcloud/aio/datastore/value.py | 21 +- datastore/tests/integration/smoke_test.py | 154 +++-- datastore/tests/unit/filter_test.py | 39 +- datastore/tests/unit/gql_query_test.py | 32 +- datastore/tests/unit/property_order_test.py | 4 +- datastore/tests/unit/query_test.py | 23 +- datastore/tests/unit/value_test.py | 124 ++-- kms/gcloud/aio/kms/kms.py | 30 +- pubsub/gcloud/aio/pubsub/metrics.py | 22 +- pubsub/gcloud/aio/pubsub/metrics_agent.py | 16 +- pubsub/gcloud/aio/pubsub/publisher_client.py | 59 +- pubsub/gcloud/aio/pubsub/subscriber.py | 379 +++++++----- pubsub/gcloud/aio/pubsub/subscriber_client.py | 103 ++-- .../gcloud/aio/pubsub/subscriber_message.py | 42 +- pubsub/gcloud/aio/pubsub/utils.py | 6 +- pubsub/tests/unit/subscriber_test.py | 321 ++++++---- pubsub/tests/unit/subscription_test.py | 17 +- storage/gcloud/aio/storage/blob.py | 113 ++-- storage/gcloud/aio/storage/bucket.py | 45 +- storage/gcloud/aio/storage/storage.py | 427 ++++++++------ .../tests/integration/download_range_test.py | 37 +- .../tests/integration/download_stream_test.py | 19 +- storage/tests/integration/metadata_test.py | 150 +++-- storage/tests/integration/signed_url_test.py | 6 +- storage/tests/integration/smoke_test.py | 28 +- .../integration/upload_multipart_test.py | 41 +- .../integration/upload_resumable_test.py | 41 +- storage/tests/unit/upload_retry_test.py | 3 +- taskqueue/gcloud/aio/taskqueue/queue.py | 56 +- taskqueue/tests/integration/conftest.py | 14 +- taskqueue/tests/integration/pushqueue_test.py | 2 +- 55 files changed, 2799 insertions(+), 1597 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8331d2571..7f86217ba 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,6 +35,10 @@ repos: - id: double-quote-string-fixer - id: name-tests-test - id: requirements-txt-fixer +- repo: https://github.com/asottile/yesqa + rev: v1.4.0 + hooks: + - id: yesqa - repo: https://github.com/PyCQA/pylint rev: v2.15.9 hooks: @@ -73,6 +77,25 @@ repos: hooks: - id: reorder-python-imports args: [--py26-plus] +- repo: https://github.com/asottile/add-trailing-comma + rev: v2.4.0 + hooks: + - id: add-trailing-comma +- repo: https://github.com/PyCQA/flake8 + rev: 5.0.4 + hooks: + - id: flake8 + additional_dependencies: + - flake8-2020==1.7.0 + - flake8-broken-line==0.6.0 + - flake8-comprehensions==3.10.1 + - importlib-metadata<5 # TODO: nuke once we upgrade past py3.7 + args: + - --ignore=A003,E501,W503,F401,F811 +- repo: https://github.com/pre-commit/mirrors-autopep8 + rev: v2.0.1 + hooks: + - id: autopep8 - repo: https://github.com/pre-commit/mirrors-mypy rev: v0.991 hooks: @@ -151,26 +174,6 @@ repos: - types-pkg_resources - types-requests files: taskqueue/ -- repo: https://github.com/asottile/yesqa - rev: v1.4.0 - hooks: - - id: yesqa -- repo: https://github.com/pre-commit/mirrors-autopep8 - rev: v2.0.1 - hooks: - - id: autopep8 -- repo: https://github.com/PyCQA/flake8 - rev: 5.0.4 - hooks: - - id: flake8 - additional_dependencies: - - flake8-2020==1.7.0 - - flake8-broken-line==0.6.0 - # - flake8-commas==2.1.0 # TODO: enable me - - flake8-comprehensions==3.10.1 - - importlib-metadata<5 # TODO: nuke once we upgrade past py3.7 - args: - - --ignore=A003,E501,W503,F401,F811 - repo: local hooks: - &poetry-check diff --git a/auth/gcloud/aio/auth/iam.py b/auth/gcloud/aio/auth/iam.py index 4fdf00dfe..20d73a262 100644 --- a/auth/gcloud/aio/auth/iam.py +++ b/auth/gcloud/aio/auth/iam.py @@ -25,18 +25,25 @@ class IamClient: - def __init__(self, service_file: Optional[Union[str, IO[AnyStr]]] = None, - session: Optional[Session] = None, - token: Optional[Token] = None) -> None: + def __init__( + self, service_file: Optional[Union[str, IO[AnyStr]]] = None, + session: Optional[Session] = None, + token: Optional[Token] = None, + ) -> None: self.session = AioSession(session) self.token = token or Token( service_file=service_file, scopes=SCOPES, - session=self.session.session) # type: ignore[arg-type] - - if self.token.token_type not in {Type.GCE_METADATA, - Type.SERVICE_ACCOUNT}: - raise TypeError('IAM Credentials Client is only valid for use ' - 'with Service Accounts or GCE Metadata') + session=self.session.session, # type: ignore[arg-type] + ) + + if self.token.token_type not in { + Type.GCE_METADATA, + Type.SERVICE_ACCOUNT, + }: + raise TypeError( + 'IAM Credentials Client is only valid for use ' + 'with Service Accounts or GCE Metadata', + ) async def headers(self) -> Dict[str, str]: token = await self.token.get() @@ -49,22 +56,28 @@ def service_account_email(self) -> Optional[str]: return self.token.service_data.get('client_email') # https://cloud.google.com/iam/reference/rest/v1/projects.serviceAccounts.keys/get - async def get_public_key(self, key_id: Optional[str] = None, - key: Optional[str] = None, - service_account_email: Optional[str] = None, - project: Optional[str] = None, - session: Optional[Session] = None, - timeout: int = 10) -> Dict[str, str]: - service_account_email = (service_account_email - or self.service_account_email) + async def get_public_key( + self, key_id: Optional[str] = None, + key: Optional[str] = None, + service_account_email: Optional[str] = None, + project: Optional[str] = None, + session: Optional[Session] = None, + timeout: int = 10, + ) -> Dict[str, str]: + service_account_email = ( + service_account_email + or self.service_account_email + ) project = project or await self.token.get_project() if not key_id and not key: raise ValueError('get_public_key must have either key_id or key') if not key: - key = (f'projects/{project}/serviceAccounts/' - f'{service_account_email}/keys/{key_id}') + key = ( + f'projects/{project}/serviceAccounts/' + f'{service_account_email}/keys/{key_id}' + ) url = f'{API_ROOT_IAM}/{key}?publicKeyType=TYPE_X509_PEM_FILE' headers = await self.headers() @@ -81,13 +94,18 @@ async def list_public_keys( self, service_account_email: Optional[str] = None, project: Optional[str] = None, session: Optional[Session] = None, - timeout: int = 10) -> List[Dict[str, str]]: - service_account_email = (service_account_email - or self.service_account_email) + timeout: int = 10, + ) -> List[Dict[str, str]]: + service_account_email = ( + service_account_email + or self.service_account_email + ) project = project or await self.token.get_project() - url = (f'{API_ROOT_IAM}/projects/{project}/' - f'serviceAccounts/{service_account_email}/keys') + url = ( + f'{API_ROOT_IAM}/projects/{project}/' + f'serviceAccounts/{service_account_email}/keys' + ) headers = await self.headers() @@ -99,16 +117,22 @@ async def list_public_keys( return data # https://cloud.google.com/iam/credentials/reference/rest/v1/projects.serviceAccounts/signBlob - async def sign_blob(self, payload: Optional[Union[str, bytes]], - service_account_email: Optional[str] = None, - delegates: Optional[List[str]] = None, - session: Optional[Session] = None, - timeout: int = 10) -> Dict[str, str]: - service_account_email = (service_account_email - or self.service_account_email) + async def sign_blob( + self, payload: Optional[Union[str, bytes]], + service_account_email: Optional[str] = None, + delegates: Optional[List[str]] = None, + session: Optional[Session] = None, + timeout: int = 10, + ) -> Dict[str, str]: + service_account_email = ( + service_account_email + or self.service_account_email + ) if not service_account_email: - raise TypeError('sign_blob must have a valid ' - 'service_account_email') + raise TypeError( + 'sign_blob must have a valid ' + 'service_account_email', + ) resource_name = f'projects/-/serviceAccounts/{service_account_email}' url = f'{API_ROOT_IAM_CREDENTIALS}/{resource_name}:signBlob' @@ -126,8 +150,10 @@ async def sign_blob(self, payload: Optional[Union[str, bytes]], s = AioSession(session) if session else self.session - resp = await s.post(url=url, data=json_str, headers=headers, - timeout=timeout) + resp = await s.post( + url=url, data=json_str, headers=headers, + timeout=timeout, + ) data: Dict[str, Any] = await resp.json() return data diff --git a/auth/gcloud/aio/auth/session.py b/auth/gcloud/aio/auth/session.py index df10e68f2..c51f1b215 100644 --- a/auth/gcloud/aio/auth/session.py +++ b/auth/gcloud/aio/auth/session.py @@ -26,8 +26,10 @@ class BaseSession: __metaclass__ = ABCMeta - def __init__(self, session: Optional[Session] = None, timeout: float = 10, - verify_ssl: bool = True) -> None: + def __init__( + self, session: Optional[Session] = None, timeout: float = 10, + verify_ssl: bool = True, + ) -> None: self._shared_session = bool(session) self._session = session self._ssl = verify_ssl @@ -38,38 +40,50 @@ def session(self) -> Optional[Session]: return self._session @abstractmethod - async def post(self, url: str, headers: Dict[str, str], - data: Optional[Union[bytes, str]], timeout: float, - params: Optional[Dict[str, Union[int, str]]]) -> Response: + async def post( + self, url: str, headers: Dict[str, str], + data: Optional[Union[bytes, str]], timeout: float, + params: Optional[Dict[str, Union[int, str]]], + ) -> Response: pass @abstractmethod - async def get(self, url: str, headers: Optional[Dict[str, str]], - timeout: float, params: Optional[Dict[str, Union[int, str]]], - stream: bool) -> Response: + async def get( + self, url: str, headers: Optional[Dict[str, str]], + timeout: float, params: Optional[Dict[str, Union[int, str]]], + stream: bool, + ) -> Response: pass @abstractmethod - async def patch(self, url: str, headers: Dict[str, str], - data: Optional[str], timeout: float, - params: Optional[Dict[str, Union[int, str]]]) -> Response: + async def patch( + self, url: str, headers: Dict[str, str], + data: Optional[str], timeout: float, + params: Optional[Dict[str, Union[int, str]]], + ) -> Response: pass @abstractmethod - async def put(self, url: str, headers: Dict[str, str], data: IO[Any], - timeout: float) -> Response: + async def put( + self, url: str, headers: Dict[str, str], data: IO[Any], + timeout: float, + ) -> Response: pass @abstractmethod - async def delete(self, url: str, headers: Dict[str, str], - params: Optional[Dict[str, Union[int, str]]], - timeout: float) -> Response: + async def delete( + self, url: str, headers: Dict[str, str], + params: Optional[Dict[str, Union[int, str]]], + timeout: float, + ) -> Response: pass @abstractmethod - async def request(self, method: str, url: str, headers: Dict[str, str], - auto_raise_for_status: bool = True, - **kwargs: Any) -> Response: + async def request( + self, method: str, url: str, headers: Dict[str, str], + auto_raise_for_status: bool = True, + **kwargs: Any + ) -> Response: pass @abstractmethod @@ -106,10 +120,12 @@ async def _raise_for_status(resp: aiohttp.ClientResponse) -> None: # Google's error messages are useful, pass 'em through body = await resp.text(errors='replace') resp.release() - raise aiohttp.ClientResponseError(resp.request_info, resp.history, - status=resp.status, - message=f'{resp.reason}: {body}', - headers=resp.headers) + raise aiohttp.ClientResponseError( + resp.request_info, resp.history, + status=resp.status, + message=f'{resp.reason}: {body}', + headers=resp.headers, + ) class AioSession(BaseSession): _session: aiohttp.ClientSession # type: ignore[assignment] @@ -125,68 +141,94 @@ def session(self) -> aiohttp.ClientSession: # type: ignore[override] else: timeout = aiohttp.ClientTimeout(total=self._timeout) - self._session = aiohttp.ClientSession(connector=connector, - timeout=timeout) + self._session = aiohttp.ClientSession( + connector=connector, + timeout=timeout, + ) return self._session - async def post(self, url: str, # type: ignore[override] - headers: Dict[str, str], - data: Optional[Union[bytes, str]] = None, - timeout: Timeout = 10, - params: Optional[Dict[str, Union[int, str]]] = None - ) -> aiohttp.ClientResponse: - resp = await self.session.post(url, data=data, headers=headers, - timeout=timeout, params=params) + async def post( # type: ignore[override] + self, url: str, + headers: Dict[str, str], + data: Optional[Union[bytes, str]] = None, + timeout: Timeout = 10, + params: Optional[Dict[str, Union[int, str]]] = None, + ) -> aiohttp.ClientResponse: + resp = await self.session.post( + url, data=data, headers=headers, + timeout=timeout, params=params, + ) await _raise_for_status(resp) return resp - async def get(self, url: str, # type: ignore[override] - headers: Optional[Dict[str, str]] = None, - timeout: Timeout = 10, - params: Optional[Dict[str, Union[int, str]]] = None, - stream: Optional[bool] = None) -> aiohttp.ClientResponse: + async def get( # type: ignore[override] + self, url: str, + headers: Optional[Dict[str, str]] = None, + timeout: Timeout = 10, + params: Optional[Dict[str, Union[int, str]]] = None, + stream: Optional[bool] = None, + ) -> aiohttp.ClientResponse: if stream is not None: - log.warning('passed unused argument stream=%s to AioSession: ' - 'this argument is only used by SyncSession', - stream) - resp = await self.session.get(url, headers=headers, - timeout=timeout, params=params) + log.warning( + 'passed unused argument stream=%s to AioSession: ' + 'this argument is only used by SyncSession', + stream, + ) + resp = await self.session.get( + url, headers=headers, + timeout=timeout, params=params, + ) await _raise_for_status(resp) return resp - async def patch(self, url: str, # type: ignore[override] - headers: Dict[str, str], data: Optional[str] = None, - timeout: Timeout = 10, - params: Optional[Dict[str, Union[int, str]]] = None - ) -> aiohttp.ClientResponse: - resp = await self.session.patch(url, data=data, headers=headers, - timeout=timeout, params=params) + async def patch( # type: ignore[override] + self, url: str, + headers: Dict[str, str], data: Optional[str] = None, + timeout: Timeout = 10, + params: Optional[Dict[str, Union[int, str]]] = None, + ) -> aiohttp.ClientResponse: + resp = await self.session.patch( + url, data=data, headers=headers, + timeout=timeout, params=params, + ) await _raise_for_status(resp) return resp - async def put(self, url: str, # type: ignore[override] - headers: Dict[str, str], data: IO[Any], - timeout: Timeout = 10) -> aiohttp.ClientResponse: - resp = await self.session.put(url, data=data, headers=headers, - timeout=timeout) + async def put( # type: ignore[override] + self, url: str, + headers: Dict[str, str], data: IO[Any], + timeout: Timeout = 10, + ) -> aiohttp.ClientResponse: + resp = await self.session.put( + url, data=data, headers=headers, + timeout=timeout, + ) await _raise_for_status(resp) return resp - async def delete(self, url: str, # type: ignore[override] - headers: Dict[str, str], - params: Optional[Dict[str, Union[int, str]]] = None, - timeout: Timeout = 10) -> aiohttp.ClientResponse: - resp = await self.session.delete(url, headers=headers, - params=params, timeout=timeout) + async def delete( # type: ignore[override] + self, url: str, + headers: Dict[str, str], + params: Optional[Dict[str, Union[int, str]]] = None, + timeout: Timeout = 10, + ) -> aiohttp.ClientResponse: + resp = await self.session.delete( + url, headers=headers, + params=params, timeout=timeout, + ) await _raise_for_status(resp) return resp - async def request(self, method: str, # type: ignore[override] - url: str, headers: Dict[str, str], - auto_raise_for_status: bool = True, - **kwargs: Any) -> aiohttp.ClientResponse: - resp = await self.session.request(method, url, headers=headers, - **kwargs) + async def request( # type: ignore[override] + self, method: str, + url: str, headers: Dict[str, str], + auto_raise_for_status: bool = True, + **kwargs: Any + ) -> aiohttp.ClientResponse: + resp = await self.session.request( + method, url, headers=headers, + **kwargs + ) if auto_raise_for_status: await _raise_for_status(resp) return resp @@ -214,60 +256,81 @@ def session(self) -> Session: # N.B.: none of these will be `async` in compiled form, but adding the # symbol ensures we match the base class's definition for static # analysis. - async def post(self, url: str, headers: Dict[str, str], - data: Optional[Union[bytes, str]] = None, - timeout: float = 10, - params: Optional[Dict[str, Union[int, str]]] = None - ) -> Response: + async def post( + self, url: str, headers: Dict[str, str], + data: Optional[Union[bytes, str]] = None, + timeout: float = 10, + params: Optional[Dict[str, Union[int, str]]] = None, + ) -> Response: with self.google_api_lock: - resp = self.session.post(url, data=data, headers=headers, - timeout=timeout, params=params) + resp = self.session.post( + url, data=data, headers=headers, + timeout=timeout, params=params, + ) resp.raise_for_status() return resp - async def get(self, url: str, headers: Optional[Dict[str, str]] = None, - timeout: float = 10, - params: Optional[Dict[str, Union[int, str]]] = None, - stream: bool = False) -> Response: + async def get( + self, url: str, headers: Optional[Dict[str, str]] = None, + timeout: float = 10, + params: Optional[Dict[str, Union[int, str]]] = None, + stream: bool = False, + ) -> Response: with self.google_api_lock: - resp = self.session.get(url, headers=headers, timeout=timeout, - params=params, stream=stream) + resp = self.session.get( + url, headers=headers, timeout=timeout, + params=params, stream=stream, + ) resp.raise_for_status() return resp - async def patch(self, url: str, headers: Dict[str, str], - data: Optional[str] = None, timeout: float = 10, - params: Optional[Dict[str, Union[int, str]]] = None - ) -> Response: + async def patch( + self, url: str, headers: Dict[str, str], + data: Optional[str] = None, timeout: float = 10, + params: Optional[Dict[str, Union[int, str]]] = None, + ) -> Response: with self.google_api_lock: - resp = self.session.patch(url, data=data, headers=headers, - timeout=timeout, params=params) + resp = self.session.patch( + url, data=data, headers=headers, + timeout=timeout, params=params, + ) resp.raise_for_status() return resp - async def put(self, url: str, headers: Dict[str, str], data: IO[Any], - timeout: float = 10) -> Response: + async def put( + self, url: str, headers: Dict[str, str], data: IO[Any], + timeout: float = 10, + ) -> Response: with self.google_api_lock: - resp = self.session.put(url, data=data, headers=headers, - timeout=timeout) + resp = self.session.put( + url, data=data, headers=headers, + timeout=timeout, + ) resp.raise_for_status() return resp - async def delete(self, url: str, headers: Dict[str, str], - params: Optional[Dict[str, Union[int, str]]] = None, - timeout: float = 10) -> Response: + async def delete( + self, url: str, headers: Dict[str, str], + params: Optional[Dict[str, Union[int, str]]] = None, + timeout: float = 10, + ) -> Response: with self.google_api_lock: - resp = self.session.delete(url, params=params, headers=headers, - timeout=timeout) + resp = self.session.delete( + url, params=params, headers=headers, + timeout=timeout, + ) resp.raise_for_status() return resp - async def request(self, method: str, url: str, headers: Dict[str, str], - auto_raise_for_status: bool = True, **kwargs: Any - ) -> Response: + async def request( + self, method: str, url: str, headers: Dict[str, str], + auto_raise_for_status: bool = True, **kwargs: Any + ) -> Response: with self.google_api_lock: - resp = self.session.request(method, url, headers=headers, - **kwargs) + resp = self.session.request( + method, url, headers=headers, + **kwargs + ) if auto_raise_for_status: resp.raise_for_status() return resp diff --git a/auth/gcloud/aio/auth/token.py b/auth/gcloud/aio/auth/token.py index e495f8d10..540cf39f7 100644 --- a/auth/gcloud/aio/auth/token.py +++ b/auth/gcloud/aio/auth/token.py @@ -48,8 +48,10 @@ GCE_METADATA_BASE = 'http://metadata.google.internal/computeMetadata/v1' GCE_METADATA_HEADERS = {'metadata-flavor': 'Google'} GCE_ENDPOINT_PROJECT = (f'{GCE_METADATA_BASE}/project/project-id') -GCE_ENDPOINT_TOKEN = (f'{GCE_METADATA_BASE}/instance/service-accounts' - '/default/token?recursive=true') +GCE_ENDPOINT_TOKEN = ( + f'{GCE_METADATA_BASE}/instance/service-accounts' + '/default/token?recursive=true' +) GCLOUD_TOKEN_DURATION = 3600 REFRESH_HEADERS = {'Content-Type': 'application/x-www-form-urlencoded'} @@ -61,7 +63,8 @@ class Type(enum.Enum): def get_service_data( - service: Optional[Union[str, IO[AnyStr]]]) -> Dict[str, Any]: + service: Optional[Union[str, IO[AnyStr]]], +) -> Dict[str, Any]: """ Get the service data dictionary for the current auth method. @@ -82,14 +85,18 @@ def get_service_data( if cloudsdk_config is not None: sdkpath = cloudsdk_config elif os.name != 'nt': - sdkpath = os.path.join(os.path.expanduser('~'), '.config', - 'gcloud') + sdkpath = os.path.join( + os.path.expanduser('~'), '.config', + 'gcloud', + ) else: try: sdkpath = os.path.join(os.environ['APPDATA'], 'gcloud') except KeyError: - sdkpath = os.path.join(os.environ.get('SystemDrive', 'C:'), - '\\', 'gcloud') + sdkpath = os.path.join( + os.environ.get('SystemDrive', 'C:'), + '\\', 'gcloud', + ) service = os.path.join(sdkpath, 'application_default_credentials.json') set_explicitly = bool(cloudsdk_config) @@ -103,8 +110,10 @@ def get_service_data( # also support passing IO objects directly rather than strictly paths # on disk try: - with open(service, # type: ignore[arg-type] - encoding='utf-8') as f: + with open( + service, # type: ignore[arg-type] + encoding='utf-8', + ) as f: data: Dict[str, Any] = json.loads(f.read()) return data except TypeError: @@ -125,14 +134,17 @@ def get_service_data( class Token: # pylint: disable=too-many-instance-attributes - def __init__(self, service_file: Optional[Union[str, IO[AnyStr]]] = None, - session: Optional[Session] = None, - scopes: Optional[List[str]] = None) -> None: + def __init__( + self, service_file: Optional[Union[str, IO[AnyStr]]] = None, + session: Optional[Session] = None, + scopes: Optional[List[str]] = None, + ) -> None: self.service_data = get_service_data(service_file) if self.service_data: self.token_type = Type(self.service_data['type']) self.token_uri = self.service_data.get( - 'token_uri', 'https://oauth2.googleapis.com/token') + 'token_uri', 'https://oauth2.googleapis.com/token', + ) else: # At this point, all we can do is assume we're running somewhere # with default credentials, eg. GCE. @@ -142,8 +154,10 @@ def __init__(self, service_file: Optional[Union[str, IO[AnyStr]]] = None, self.session = AioSession(session) self.scopes = ' '.join(scopes or []) if self.token_type == Type.SERVICE_ACCOUNT and not self.scopes: - raise Exception('scopes must be provided when token type is ' - 'service account') + raise Exception( + 'scopes must be provided when token type is ' + 'service account', + ) self.access_token: Optional[str] = None self.access_token_duration = 0 @@ -152,14 +166,18 @@ def __init__(self, service_file: Optional[Union[str, IO[AnyStr]]] = None, self.acquiring: Optional['asyncio.Future[Any]'] = None async def get_project(self) -> Optional[str]: - project = (os.environ.get('GOOGLE_CLOUD_PROJECT') - or os.environ.get('GCLOUD_PROJECT') - or os.environ.get('APPLICATION_ID')) + project = ( + os.environ.get('GOOGLE_CLOUD_PROJECT') + or os.environ.get('GCLOUD_PROJECT') + or os.environ.get('APPLICATION_ID') + ) if self.token_type == Type.GCE_METADATA: await self.ensure_token() - resp = await self.session.get(GCE_ENDPOINT_PROJECT, timeout=10, - headers=GCE_METADATA_HEADERS) + resp = await self.session.get( + GCE_ENDPOINT_PROJECT, timeout=10, + headers=GCE_METADATA_HEADERS, + ) if not project: try: @@ -200,12 +218,14 @@ async def _refresh_authorized_user(self, timeout: int) -> Response: resp: Response = await self.session.post( # type: ignore[assignment] url=self.token_uri, data=payload, headers=REFRESH_HEADERS, - timeout=timeout) + timeout=timeout, + ) return resp async def _refresh_gce_metadata(self, timeout: int) -> Response: resp: Response = await self.session.get( # type: ignore[assignment] - url=self.token_uri, headers=GCE_METADATA_HEADERS, timeout=timeout) + url=self.token_uri, headers=GCE_METADATA_HEADERS, timeout=timeout, + ) return resp async def _refresh_service_account(self, timeout: int) -> Response: @@ -219,9 +239,11 @@ async def _refresh_service_account(self, timeout: int) -> Response: } # N.B. algorithm='RS256' requires an extra 240MB in dependencies... - assertion = jwt.encode(assertion_payload, - self.service_data['private_key'], - algorithm='RS256') + assertion = jwt.encode( + assertion_payload, + self.service_data['private_key'], + algorithm='RS256', + ) payload = urlencode({ 'assertion': assertion, 'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer', @@ -229,7 +251,8 @@ async def _refresh_service_account(self, timeout: int) -> Response: resp: Response = await self.session.post( # type: ignore[assignment] self.token_uri, data=payload, headers=REFRESH_HEADERS, - timeout=timeout) + timeout=timeout, + ) return resp @backoff.on_exception(backoff.expo, Exception, max_tries=5) diff --git a/auth/tests/integration/smoke_test.py b/auth/tests/integration/smoke_test.py index 34c5e2f71..8acba1ecb 100644 --- a/auth/tests/integration/smoke_test.py +++ b/auth/tests/integration/smoke_test.py @@ -61,13 +61,17 @@ async def test_token_does_not_require_creds() -> None: # https://cloud.google.com/appengine/docs/standard/python/appidentity/#asserting_identity_to_third-party_services async def verify_signature(data, signature, key_name, iam_client): key_data = await iam_client.get_public_key(key_name) - cert = x509.load_pem_x509_certificate(decode(key_data['publicKeyData']), - backend=default_backend()) + cert = x509.load_pem_x509_certificate( + decode(key_data['publicKeyData']), + backend=default_backend(), + ) pubkey = cert.public_key() # raises on failure - pubkey.verify(decode(signature), data.encode(), padding.PKCS1v15(), - hashes.SHA256()) + pubkey.verify( + decode(signature), data.encode(), padding.PKCS1v15(), + hashes.SHA256(), + ) @pytest.mark.asyncio @@ -94,21 +98,29 @@ async def test_get_service_account_public_key(creds: str) -> None: async with Session(timeout=10) as s: iam_client = IamClient(service_file=creds, session=s) resp = await iam_client.list_public_keys(session=s) - pub_key_data = await iam_client.get_public_key(key=resp[0]['name'], - session=s) + pub_key_data = await iam_client.get_public_key( + key=resp[0]['name'], + session=s, + ) assert pub_key_data['name'] == resp[0]['name'] assert 'publicKeyData' in pub_key_data key_id = resp[0]['name'].split('/')[-1] - pub_key_by_key_id_data = await iam_client.get_public_key(key_id=key_id, - session=s) + pub_key_by_key_id_data = await iam_client.get_public_key( + key_id=key_id, + session=s, + ) # Sometimes, one or both keys will be created with "no" expiry. pub_key_time = pub_key_data.pop('validBeforeTime') pub_key_by_key_id_time = pub_key_by_key_id_data.pop('validBeforeTime') - assert (pub_key_time == pub_key_by_key_id_time - or '9999-12-31T23:59:59Z' in {pub_key_time, - pub_key_by_key_id_time}) + assert ( + pub_key_time == pub_key_by_key_id_time + or '9999-12-31T23:59:59Z' in { + pub_key_time, + pub_key_by_key_id_time, + } + ) assert pub_key_data == pub_key_by_key_id_data diff --git a/auth/tests/unit/token_test.py b/auth/tests/unit/token_test.py index 2844c66fb..64ae03d13 100644 --- a/auth/tests/unit/token_test.py +++ b/auth/tests/unit/token_test.py @@ -18,7 +18,7 @@ async def test_service_as_io(): 'auth_uri': 'https://accounts.google.com/o/oauth2/auth', 'token_uri': 'https://oauth2.googleapis.com/token', 'auth_provider_x509_cert_url': 'https://www.googleapis.com/oauth2/v1/certs', - 'client_x509_cert_url': 'https://www.googleapis.com/robot/v1/metadata/x509/gcloud-aio%40random-project-123.iam.gserviceaccount.com' + 'client_x509_cert_url': 'https://www.googleapis.com/robot/v1/metadata/x509/gcloud-aio%40random-project-123.iam.gserviceaccount.com', } # io.StringIO does not like str inputs in python2. So in `py3to2` step in @@ -26,8 +26,10 @@ async def test_service_as_io(): # turns this seemingly noop operation to allow the literal string to get # converted to unicode. service_file = io.StringIO(f'{json.dumps(service_data)}') - t = token.Token(service_file=service_file, - scopes=['https://google.com/random-scope']) + t = token.Token( + service_file=service_file, + scopes=['https://google.com/random-scope'], + ) assert t.token_type == token.Type.SERVICE_ACCOUNT assert t.token_uri == 'https://oauth2.googleapis.com/token' diff --git a/auth/tests/unit/utils_test.py b/auth/tests/unit/utils_test.py index 244dda4a1..7a7577f37 100644 --- a/auth/tests/unit/utils_test.py +++ b/auth/tests/unit/utils_test.py @@ -4,9 +4,13 @@ from gcloud.aio.auth import utils -@pytest.mark.parametrize('str_or_bytes', ['Hello Test', - 'UTF-8 Bytes'.encode('utf-8'), - pickle.dumps([])]) +@pytest.mark.parametrize( + 'str_or_bytes', [ + 'Hello Test', + 'UTF-8 Bytes'.encode('utf-8'), + pickle.dumps([]), + ], +) def test_encode_decode(str_or_bytes): encoded = utils.encode(str_or_bytes) expected = str_or_bytes diff --git a/bigquery/gcloud/aio/bigquery/bigquery.py b/bigquery/gcloud/aio/bigquery/bigquery.py index 41fc56c27..d62b4d75b 100644 --- a/bigquery/gcloud/aio/bigquery/bigquery.py +++ b/bigquery/gcloud/aio/bigquery/bigquery.py @@ -75,13 +75,16 @@ def __init__( self.session = AioSession(session) self.token = token or Token( service_file=service_file, scopes=SCOPES, - session=self.session.session) # type: ignore[arg-type] + session=self.session.session, # type: ignore[arg-type] + ) self._project = project if self._api_is_dev and not project: - self._project = (os.environ.get('BIGQUERY_PROJECT_ID') - or os.environ.get('GOOGLE_CLOUD_PROJECT') - or 'dev') + self._project = ( + os.environ.get('BIGQUERY_PROJECT_ID') + or os.environ.get('GOOGLE_CLOUD_PROJECT') + or 'dev' + ) async def project(self) -> str: if self._project: @@ -104,7 +107,8 @@ async def headers(self) -> Dict[str, str]: async def _post_json( self, url: str, body: Dict[str, Any], session: Optional[Session], - timeout: int) -> Dict[str, Any]: + timeout: int, + ) -> Dict[str, Any]: payload = json.dumps(body).encode('utf-8') headers = await self.headers() @@ -115,20 +119,25 @@ async def _post_json( s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - resp = await s.post(url, data=payload, # type: ignore[arg-type] - headers=headers, timeout=timeout) + resp = await s.post( + url, data=payload, # type: ignore[arg-type] + headers=headers, timeout=timeout, + ) data: Dict[str, Any] = await resp.json() return data async def _get_url( self, url: str, session: Optional[Session], timeout: int, - params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + params: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: headers = await self.headers() s = AioSession(session) if session else self.session - resp = await s.get(url, headers=headers, timeout=timeout, - params=params or {}) + resp = await s.get( + url, headers=headers, timeout=timeout, + params=params or {}, + ) data: Dict[str, Any] = await resp.json() return data diff --git a/bigquery/gcloud/aio/bigquery/dataset.py b/bigquery/gcloud/aio/bigquery/dataset.py index 4f222bd1f..48b2b2b65 100644 --- a/bigquery/gcloud/aio/bigquery/dataset.py +++ b/bigquery/gcloud/aio/bigquery/dataset.py @@ -26,29 +26,37 @@ def __init__( api_root: Optional[str] = None, ) -> None: self.dataset_name = dataset_name - super().__init__(project=project, service_file=service_file, - session=session, token=token, api_root=api_root) + super().__init__( + project=project, service_file=service_file, + session=session, token=token, api_root=api_root, + ) # https://cloud.google.com/bigquery/docs/reference/rest/v2/tables/list async def list_tables( self, session: Optional[Session] = None, timeout: int = 60, - params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + params: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: """List tables in a dataset.""" project = await self.project() if not self.dataset_name: - raise ValueError('could not determine dataset,' - ' please set it manually') + raise ValueError( + 'could not determine dataset,' + ' please set it manually', + ) - url = (f'{self._api_root}/projects/{project}/datasets/' - f'{self.dataset_name}/tables') + url = ( + f'{self._api_root}/projects/{project}/datasets/' + f'{self.dataset_name}/tables' + ) return await self._get_url(url, session, timeout, params=params) # https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/list async def list_datasets( self, session: Optional[Session] = None, timeout: int = 60, - params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + params: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: """List datasets in current project.""" project = await self.project() @@ -56,23 +64,31 @@ async def list_datasets( return await self._get_url(url, session, timeout, params=params) # https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/get - async def get(self, session: Optional[Session] = None, - timeout: int = 60, - params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + async def get( + self, session: Optional[Session] = None, + timeout: int = 60, + params: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: """Get a specific dataset in current project.""" project = await self.project() if not self.dataset_name: - raise ValueError('could not determine dataset,' - ' please set it manually') + raise ValueError( + 'could not determine dataset,' + ' please set it manually', + ) - url = (f'{self._api_root}/projects/{project}/datasets/' - f'{self.dataset_name}') + url = ( + f'{self._api_root}/projects/{project}/datasets/' + f'{self.dataset_name}' + ) return await self._get_url(url, session, timeout, params=params) # https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/insert - async def insert(self, dataset: Dict[str, Any], - session: Optional[Session] = None, - timeout: int = 60) -> Dict[str, Any]: + async def insert( + self, dataset: Dict[str, Any], + session: Optional[Session] = None, + timeout: int = 60, + ) -> Dict[str, Any]: """Create datasets in current project.""" project = await self.project() diff --git a/bigquery/gcloud/aio/bigquery/job.py b/bigquery/gcloud/aio/bigquery/job.py index a41434401..969f9e968 100644 --- a/bigquery/gcloud/aio/bigquery/job.py +++ b/bigquery/gcloud/aio/bigquery/job.py @@ -26,8 +26,10 @@ def __init__( api_root: Optional[str] = None, ) -> None: self.job_id = job_id - super().__init__(project=project, service_file=service_file, - session=session, token=token, api_root=api_root) + super().__init__( + project=project, service_file=service_file, + session=session, token=token, api_root=api_root, + ) @staticmethod def _make_query_body( @@ -35,7 +37,8 @@ def _make_query_body( write_disposition: Disposition, use_query_cache: bool, dry_run: bool, use_legacy_sql: bool, - destination_table: Optional[Any]) -> Dict[str, Any]: + destination_table: Optional[Any], + ) -> Dict[str, Any]: return { 'configuration': { 'query': { @@ -54,8 +57,10 @@ def _make_query_body( } # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/get - async def get_job(self, session: Optional[Session] = None, - timeout: int = 60) -> Dict[str, Any]: + async def get_job( + self, session: Optional[Session] = None, + timeout: int = 60, + ) -> Dict[str, Any]: """Get the specified job resource by job ID.""" project = await self.project() @@ -64,10 +69,11 @@ async def get_job(self, session: Optional[Session] = None, return await self._get_url(url, session, timeout) # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/getQueryResults - async def get_query_results(self, session: Optional[Session] = None, - timeout: int = 60, - params: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: + async def get_query_results( + self, session: Optional[Session] = None, + timeout: int = 60, + params: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: """Get the specified jobQueryResults by job ID.""" project = await self.project() @@ -76,20 +82,26 @@ async def get_query_results(self, session: Optional[Session] = None, return await self._get_url(url, session, timeout, params=params) # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/cancel - async def cancel(self, session: Optional[Session] = None, - timeout: int = 60) -> Dict[str, Any]: + async def cancel( + self, session: Optional[Session] = None, + timeout: int = 60, + ) -> Dict[str, Any]: """Cancel the specified job by job ID.""" project = await self.project() - url = (f'{self._api_root}/projects/{project}/queries/{self.job_id}' - '/cancel') + url = ( + f'{self._api_root}/projects/{project}/queries/{self.job_id}' + '/cancel' + ) return await self._post_json(url, {}, session, timeout) # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/query - async def query(self, query_request: Dict[str, Any], - session: Optional[Session] = None, - timeout: int = 60) -> Dict[str, Any]: + async def query( + self, query_request: Dict[str, Any], + session: Optional[Session] = None, + timeout: int = 60, + ) -> Dict[str, Any]: """Runs a query synchronously and returns query results if completes within a specified timeout.""" project = await self.project() @@ -98,9 +110,11 @@ async def query(self, query_request: Dict[str, Any], return await self._post_json(url, query_request, session, timeout) # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/insert - async def insert(self, job: Dict[str, Any], - session: Optional[Session] = None, - timeout: int = 60) -> Dict[str, Any]: + async def insert( + self, job: Dict[str, Any], + session: Optional[Session] = None, + timeout: int = 60, + ) -> Dict[str, Any]: """Insert a new asynchronous job.""" project = await self.project() url = f'{self._api_root}/projects/{project}/jobs' @@ -117,24 +131,29 @@ async def insert_via_query( write_disposition: Disposition = Disposition.WRITE_EMPTY, timeout: int = 60, use_query_cache: bool = True, dry_run: bool = False, use_legacy_sql: bool = True, - destination_table: Optional[Any] = None) -> Dict[str, Any]: + destination_table: Optional[Any] = None, + ) -> Dict[str, Any]: """Create table as a result of the query""" project = await self.project() url = f'{self._api_root}/projects/{project}/jobs' - body = self._make_query_body(query=query, - write_disposition=write_disposition, - use_query_cache=use_query_cache, - dry_run=dry_run, - use_legacy_sql=use_legacy_sql, - destination_table=destination_table) + body = self._make_query_body( + query=query, + write_disposition=write_disposition, + use_query_cache=use_query_cache, + dry_run=dry_run, + use_legacy_sql=use_legacy_sql, + destination_table=destination_table, + ) response = await self._post_json(url, body, session, timeout) if not dry_run: self.job_id = response['jobReference']['jobId'] return response - async def result(self, - session: Optional[Session] = None) -> Dict[str, Any]: + async def result( + self, + session: Optional[Session] = None, + ) -> Dict[str, Any]: data = await self.get_job(session) status = data.get('status', {}) if status.get('state') == 'DONE': diff --git a/bigquery/gcloud/aio/bigquery/table.py b/bigquery/gcloud/aio/bigquery/table.py index b93f39564..54173e5e4 100644 --- a/bigquery/gcloud/aio/bigquery/table.py +++ b/bigquery/gcloud/aio/bigquery/table.py @@ -37,8 +37,10 @@ def __init__( ) -> None: self.dataset_name = dataset_name self.table_name = table_name - super().__init__(project=project, service_file=service_file, - session=session, token=token, api_root=api_root) + super().__init__( + project=project, service_file=service_file, + session=session, token=token, api_root=api_root, + ) @staticmethod def _mk_unique_insert_id(row: Dict[str, Any]) -> str: @@ -48,7 +50,8 @@ def _mk_unique_insert_id(row: Dict[str, Any]) -> str: def _make_copy_body( self, source_project: str, destination_project: str, destination_dataset: str, - destination_table: str) -> Dict[str, Any]: + destination_table: str, + ) -> Dict[str, Any]: return { 'configuration': { 'copy': { @@ -62,24 +65,27 @@ def _make_copy_body( 'projectId': source_project, 'datasetId': self.dataset_name, 'tableId': self.table_name, - } - } - } + }, + }, + }, } @staticmethod def _make_insert_body( rows: List[Dict[str, Any]], *, skip_invalid: bool, ignore_unknown: bool, template_suffix: Optional[str], - insert_id_fn: Callable[[Dict[str, Any]], str]) -> Dict[str, Any]: + insert_id_fn: Callable[[Dict[str, Any]], str] + ) -> Dict[str, Any]: body = { 'kind': 'bigquery#tableDataInsertAllRequest', 'skipInvalidRows': skip_invalid, 'ignoreUnknownValues': ignore_unknown, - 'rows': [{ - 'insertId': insert_id_fn(row), - 'json': row, - } for row in rows], + 'rows': [ + { + 'insertId': insert_id_fn(row), + 'json': row, + } for row in rows + ], } if template_suffix is not None: @@ -92,7 +98,7 @@ def _make_load_body( source_format: SourceFormat, write_disposition: Disposition, ignore_unknown_values: bool, - schema_update_options: List[SchemaUpdateOption] + schema_update_options: List[SchemaUpdateOption], ) -> Dict[str, Any]: return { 'configuration': { @@ -103,7 +109,8 @@ def _make_load_body( 'sourceFormat': source_format.value, 'writeDisposition': write_disposition.value, 'schemaUpdateOptions': [ - e.value for e in schema_update_options], + e.value for e in schema_update_options + ], 'destinationTable': { 'projectId': project, 'datasetId': self.dataset_name, @@ -117,7 +124,8 @@ def _make_query_body( self, query: str, project: str, write_disposition: Disposition, use_query_cache: bool, - dry_run: bool) -> Dict[str, Any]: + dry_run: bool, + ) -> Dict[str, Any]: return { 'configuration': { 'query': { @@ -135,35 +143,43 @@ def _make_query_body( } # https://cloud.google.com/bigquery/docs/reference/rest/v2/tables/insert - async def create(self, table: Dict[str, Any], - session: Optional[Session] = None, - timeout: int = 60) -> Dict[str, Any]: + async def create( + self, table: Dict[str, Any], + session: Optional[Session] = None, + timeout: int = 60, + ) -> Dict[str, Any]: """Create the table specified by tableId from the dataset.""" project = await self.project() - url = (f'{self._api_root}/projects/{project}/datasets/' - f'{self.dataset_name}/tables') + url = ( + f'{self._api_root}/projects/{project}/datasets/' + f'{self.dataset_name}/tables' + ) table['tableReference'] = { 'projectId': project, 'datasetId': self.dataset_name, - 'tableId': self.table_name + 'tableId': self.table_name, } return await self._post_json(url, table, session, timeout) # https://cloud.google.com/bigquery/docs/reference/rest/v2/tables/patch - async def patch(self, table: Dict[str, Any], - session: Optional[Session] = None, - timeout: int = 60) -> Dict[str, Any]: + async def patch( + self, table: Dict[str, Any], + session: Optional[Session] = None, + timeout: int = 60, + ) -> Dict[str, Any]: """Patch an existing table specified by tableId from the dataset.""" project = await self.project() - url = (f'{self._api_root}/projects/{project}/datasets/' - f'{self.dataset_name}/tables/{self.table_name}') + url = ( + f'{self._api_root}/projects/{project}/datasets/' + f'{self.dataset_name}/tables/{self.table_name}' + ) table['tableReference'] = { 'projectId': project, 'datasetId': self.dataset_name, - 'tableId': self.table_name + 'tableId': self.table_name, } table_data = json.dumps(table).encode('utf-8') @@ -171,25 +187,33 @@ async def patch(self, table: Dict[str, Any], s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - resp = await s.patch(url, data=table_data, # type: ignore[arg-type] - headers=headers, timeout=timeout) + resp = await s.patch( + url, data=table_data, # type: ignore[arg-type] + headers=headers, timeout=timeout, + ) data: Dict[str, Any] = await resp.json() return data # https://cloud.google.com/bigquery/docs/reference/rest/v2/tables/delete - async def delete(self, - session: Optional[Session] = None, - timeout: int = 60) -> Dict[str, Any]: + async def delete( + self, + session: Optional[Session] = None, + timeout: int = 60, + ) -> Dict[str, Any]: """Deletes the table specified by tableId from the dataset.""" project = await self.project() - url = (f'{self._api_root}/projects/{project}/datasets/' - f'{self.dataset_name}/tables/{self.table_name}') + url = ( + f'{self._api_root}/projects/{project}/datasets/' + f'{self.dataset_name}/tables/{self.table_name}' + ) headers = await self.headers() s = AioSession(session) if session else self.session - resp = await s.session.delete(url, headers=headers, params=None, - timeout=timeout) + resp = await s.session.delete( + url, headers=headers, params=None, + timeout=timeout, + ) try: data: Dict[str, Any] = await resp.json() except Exception: # pylint: disable=broad-except @@ -206,11 +230,14 @@ async def delete(self, # https://cloud.google.com/bigquery/docs/reference/rest/v2/tables/get async def get( self, session: Optional[Session] = None, - timeout: int = 60) -> Dict[str, Any]: + timeout: int = 60, + ) -> Dict[str, Any]: """Gets the specified table resource by table ID.""" project = await self.project() - url = (f'{self._api_root}/projects/{project}/datasets/' - f'{self.dataset_name}/tables/{self.table_name}') + url = ( + f'{self._api_root}/projects/{project}/datasets/' + f'{self.dataset_name}/tables/{self.table_name}' + ) return await self._get_url(url, session, timeout) @@ -238,13 +265,16 @@ async def insert( return {} project = await self.project() - url = (f'{self._api_root}/projects/{project}/datasets/' - f'{self.dataset_name}/tables/{self.table_name}/insertAll') + url = ( + f'{self._api_root}/projects/{project}/datasets/' + f'{self.dataset_name}/tables/{self.table_name}/insertAll' + ) body = self._make_insert_body( rows, skip_invalid=skip_invalid, ignore_unknown=ignore_unknown, template_suffix=template_suffix, - insert_id_fn=insert_id_fn or self._mk_unique_insert_id) + insert_id_fn=insert_id_fn or self._mk_unique_insert_id, + ) return await self._post_json(url, body, session, timeout) # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/insert @@ -252,18 +282,22 @@ async def insert( async def insert_via_copy( self, destination_project: str, destination_dataset: str, destination_table: str, session: Optional[Session] = None, - timeout: int = 60) -> Job: + timeout: int = 60, + ) -> Job: """Copy BQ table to another table in BQ""" project = await self.project() url = f'{self._api_root}/projects/{project}/jobs' body = self._make_copy_body( project, destination_project, - destination_dataset, destination_table) + destination_dataset, destination_table, + ) response = await self._post_json(url, body, session, timeout) - return Job(response['jobReference']['jobId'], self._project, - session=self.session.session, # type: ignore[arg-type] - token=self.token) + return Job( + response['jobReference']['jobId'], self._project, + session=self.session.session, # type: ignore[arg-type] + token=self.token, + ) # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/insert # https://cloud.google.com/bigquery/docs/reference/rest/v2/Job#JobConfigurationLoad @@ -274,7 +308,7 @@ async def insert_via_load( write_disposition: Disposition = Disposition.WRITE_TRUNCATE, timeout: int = 60, ignore_unknown_values: bool = False, - schema_update_options: Optional[List[SchemaUpdateOption]] = None + schema_update_options: Optional[List[SchemaUpdateOption]] = None, ) -> Job: """Loads entities from storage to BigQuery.""" project = await self.project() @@ -282,12 +316,14 @@ async def insert_via_load( body = self._make_load_body( source_uris, project, autodetect, source_format, write_disposition, - ignore_unknown_values, schema_update_options or [] + ignore_unknown_values, schema_update_options or [], ) response = await self._post_json(url, body, session, timeout) - return Job(response['jobReference']['jobId'], self._project, - session=self.session.session, # type: ignore[arg-type] - token=self.token) + return Job( + response['jobReference']['jobId'], self._project, + session=self.session.session, # type: ignore[arg-type] + token=self.token, + ) # https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs/insert # https://cloud.google.com/bigquery/docs/reference/rest/v2/Job#JobConfigurationQuery @@ -295,27 +331,37 @@ async def insert_via_query( self, query: str, session: Optional[Session] = None, write_disposition: Disposition = Disposition.WRITE_EMPTY, timeout: int = 60, use_query_cache: bool = True, - dry_run: bool = False) -> Job: + dry_run: bool = False, + ) -> Job: """Create table as a result of the query""" - warnings.warn('using Table#insert_via_query is deprecated.' - 'use Job#insert_via_query instead', DeprecationWarning) + warnings.warn( + 'using Table#insert_via_query is deprecated.' + 'use Job#insert_via_query instead', DeprecationWarning, + ) project = await self.project() url = f'{self._api_root}/projects/{project}/jobs' - body = self._make_query_body(query, project, write_disposition, - use_query_cache, dry_run) + body = self._make_query_body( + query, project, write_disposition, + use_query_cache, dry_run, + ) response = await self._post_json(url, body, session, timeout) job_id = response['jobReference']['jobId'] if not dry_run else None - return Job(job_id, self._project, token=self.token, - session=self.session.session) # type: ignore[arg-type] + return Job( + job_id, self._project, token=self.token, + session=self.session.session, # type: ignore[arg-type] + ) # https://cloud.google.com/bigquery/docs/reference/rest/v2/tabledata/list async def list_tabledata( self, session: Optional[Session] = None, timeout: int = 60, - params: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + params: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: """List the content of a table in rows.""" project = await self.project() - url = (f'{self._api_root}/projects/{project}/datasets/' - f'{self.dataset_name}/tables/{self.table_name}/data') + url = ( + f'{self._api_root}/projects/{project}/datasets/' + f'{self.dataset_name}/tables/{self.table_name}/data' + ) return await self._get_url(url, session, timeout, params) diff --git a/bigquery/gcloud/aio/bigquery/utils.py b/bigquery/gcloud/aio/bigquery/utils.py index 66ad704b6..04a33c91a 100644 --- a/bigquery/gcloud/aio/bigquery/utils.py +++ b/bigquery/gcloud/aio/bigquery/utils.py @@ -16,8 +16,10 @@ except AttributeError: # build our own UTC for Python 2 class UTC(datetime.tzinfo): - def utcoffset(self, - _dt: Optional[datetime.datetime]) -> datetime.timedelta: + def utcoffset( + self, + _dt: Optional[datetime.datetime], + ) -> datetime.timedelta: return datetime.timedelta(0) def tzname(self, _dt: Optional[datetime.datetime]) -> str: @@ -74,17 +76,20 @@ def parse(field: Dict[str, Any], value: Any) -> Any: try: convert: Callable[[Any], Any] = { # type: ignore[assignment] 'BIGNUMERIC': lambda x: decimal.Decimal( - x, decimal.Context(prec=77)), + x, decimal.Context(prec=77), + ), 'BOOLEAN': lambda x: x == 'true', 'BYTES': bytes, 'FLOAT': float, 'INTEGER': int, 'NUMERIC': lambda x: decimal.Decimal( - x, decimal.Context(prec=38)), + x, decimal.Context(prec=38), + ), 'RECORD': dict, 'STRING': str, 'TIMESTAMP': lambda x: datetime.datetime.fromtimestamp( - float(x), tz=utc), + float(x), tz=utc, + ), }[field['type']] except KeyError: # TODO: determine the proper methods for converting the following: @@ -92,9 +97,11 @@ def parse(field: Dict[str, Any], value: Any) -> Any: # DATETIME -> datetime? # GEOGRAPHY -> ?? # TIME -> datetime? - log.error('Unsupported field type %s. Please open a bug report with ' - 'the following data: %s, %s', field['type'], field['mode'], - flatten(value)) + log.error( + 'Unsupported field type %s. Please open a bug report with ' + 'the following data: %s, %s', field['type'], field['mode'], + flatten(value), + ) raise if field['mode'] == 'NULLABLE' and value is None: @@ -102,15 +109,19 @@ def parse(field: Dict[str, Any], value: Any) -> Any: if field['mode'] == 'REPEATED': if field['type'] == 'RECORD': - return [{f['name']: parse(f, x) - for f, x in zip(field['fields'], xs)} - for xs in flatten(value)] + return [{ + f['name']: parse(f, x) + for f, x in zip(field['fields'], xs) + } + for xs in flatten(value)] return [convert(x) for x in flatten(value)] if field['type'] == 'RECORD': - return {f['name']: parse(f, x) - for f, x in zip(field['fields'], flatten(value))} + return { + f['name']: parse(f, x) + for f, x in zip(field['fields'], flatten(value)) + } return convert(flatten(value)) @@ -128,5 +139,7 @@ def query_response_to_dict(response: Dict[str, Any]) -> List[Dict[str, Any]]: """ fields = response['schema'].get('fields', []) rows = [x['f'] for x in response.get('rows', [])] - return [{k['name']: parse(k, v) for k, v in zip(fields, row)} - for row in rows] + return [ + {k['name']: parse(k, v) for k, v in zip(fields, row)} + for row in rows + ] diff --git a/bigquery/tests/integration/smoke_test.py b/bigquery/tests/integration/smoke_test.py index d3caa28df..610cdcdca 100644 --- a/bigquery/tests/integration/smoke_test.py +++ b/bigquery/tests/integration/smoke_test.py @@ -19,21 +19,27 @@ @pytest.mark.asyncio -async def test_data_is_inserted(creds: str, dataset: str, project: str, - table: str) -> None: +async def test_data_is_inserted( + creds: str, dataset: str, project: str, + table: str, +) -> None: rows = [{'key': uuid.uuid4().hex, 'value': uuid.uuid4().hex} for _ in range(3)] async with Session() as s: # TODO: create this table (with a random name) - t = Table(dataset, table, project=project, service_file=creds, - session=s) + t = Table( + dataset, table, project=project, service_file=creds, + session=s, + ) await t.insert(rows) @pytest.mark.asyncio -async def test_table_load_copy(creds: str, dataset: str, project: str, - export_bucket_name: str) -> None: +async def test_table_load_copy( + creds: str, dataset: str, project: str, + export_bucket_name: str, +) -> None: # pylint: disable=too-many-locals # N.B. this test relies on Datastore.export -- see `test_datastore_export` # in the `gcloud-aio-datastore` smoke tests. @@ -44,15 +50,19 @@ async def test_table_load_copy(creds: str, dataset: str, project: str, async with Session() as s: ds = Datastore(project=project, service_file=creds, session=s) - await ds.insert(Key(project, [PathElement(kind)]), - properties={'rand_str': rand_uuid}) + await ds.insert( + Key(project, [PathElement(kind)]), + properties={'rand_str': rand_uuid}, + ) operation = await ds.export(export_bucket_name, kinds=[kind]) count = 0 - while (count < 10 - and operation - and operation.metadata['common']['state'] == 'PROCESSING'): + while ( + count < 10 + and operation + and operation.metadata['common']['state'] == 'PROCESSING' + ): await sleep(10) operation = await ds.get_datastore_operation(operation.name) count += 1 @@ -64,13 +74,19 @@ async def test_table_load_copy(creds: str, dataset: str, project: str, backup_entity_table = f'public_test_backup_entity_{uuid_}' copy_entity_table = f'{backup_entity_table}_copy' - t = Table(dataset, backup_entity_table, project=project, - service_file=creds, session=s) + t = Table( + dataset, backup_entity_table, project=project, + service_file=creds, session=s, + ) gs_prefix = operation.metadata['outputUrlPrefix'] - gs_file = (f'{gs_prefix}/all_namespaces/kind_{kind}/' - f'all_namespaces_kind_{kind}.export_metadata') - await t.insert_via_load([gs_file], - source_format=SourceFormat.DATASTORE_BACKUP) + gs_file = ( + f'{gs_prefix}/all_namespaces/kind_{kind}/' + f'all_namespaces_kind_{kind}.export_metadata' + ) + await t.insert_via_load( + [gs_file], + source_format=SourceFormat.DATASTORE_BACKUP, + ) await sleep(20) @@ -79,8 +95,10 @@ async def test_table_load_copy(creds: str, dataset: str, project: str, await t.insert_via_copy(project, dataset, copy_entity_table) await sleep(20) - t1 = Table(dataset, copy_entity_table, project=project, - service_file=creds, session=s) + t1 = Table( + dataset, copy_entity_table, project=project, + service_file=creds, session=s, + ) copy_table = await t1.get() assert copy_table['numRows'] == source_table['numRows'] @@ -94,7 +112,9 @@ async def test_table_load_copy(creds: str, dataset: str, project: str, export_path = operation.metadata['outputUrlPrefix'][prefix_len:] storage = Storage(service_file=creds, session=s) - files = await storage.list_objects(export_bucket_name, - params={'prefix': export_path}) + files = await storage.list_objects( + export_bucket_name, + params={'prefix': export_path}, + ) for file in files['items']: await storage.delete(export_bucket_name, file['name']) diff --git a/bigquery/tests/unit/bigquery_test.py b/bigquery/tests/unit/bigquery_test.py index 4a0221fb0..06723732e 100644 --- a/bigquery/tests/unit/bigquery_test.py +++ b/bigquery/tests/unit/bigquery_test.py @@ -5,7 +5,8 @@ def test_make_insert_body(): body = Table._make_insert_body( # pylint: disable=protected-access [{'foo': 'herp', 'bar': 42}, {'foo': 'derp', 'bar': 13}], skip_invalid=False, ignore_unknown=False, template_suffix=None, - insert_id_fn=lambda b: b['bar']) + insert_id_fn=lambda b: b['bar'], + ) expected = { 'kind': 'bigquery#tableDataInsertAllRequest', @@ -24,7 +25,8 @@ def test_make_insert_body_template_suffix(): body = Table._make_insert_body( # pylint: disable=protected-access [{'foo': 'herp', 'bar': 42}, {'foo': 'derp', 'bar': 13}], skip_invalid=False, ignore_unknown=False, template_suffix='suffix', - insert_id_fn=lambda b: b['bar']) + insert_id_fn=lambda b: b['bar'], + ) expected = { 'kind': 'bigquery#tableDataInsertAllRequest', @@ -45,7 +47,8 @@ def test_make_insert_body_defult_id_fn(): body = Table._make_insert_body( # pylint: disable=protected-access [{'foo': 'herp', 'bar': 42}, {'foo': 'derp', 'bar': 13}], skip_invalid=False, ignore_unknown=False, template_suffix=None, - insert_id_fn=insert_id) + insert_id_fn=insert_id, + ) assert len(body['rows']) == 2 assert all(r['insertId'] for r in body['rows']) diff --git a/bigquery/tests/unit/utils_test.py b/bigquery/tests/unit/utils_test.py index 657860ecb..fdd158009 100644 --- a/bigquery/tests/unit/utils_test.py +++ b/bigquery/tests/unit/utils_test.py @@ -1,3 +1,4 @@ +# pylint: disable=line-too-long import datetime import pytest @@ -7,70 +8,98 @@ from gcloud.aio.bigquery.utils import utc -@pytest.mark.parametrize('data,expected', [ - ({'v': None}, None), - ({'v': 'foo'}, 'foo'), - ({'v': [{'v': 0}, {'v': 1}]}, [0, 1]), - ({'v': {'f': [{'v': 'foo'}]}}, ['foo']), - ({'v': {'f': [{'v': 'foo'}, {'v': 'bar'}]}}, ['foo', 'bar']), - ({'v': {'f': [{'v': {'f': [{'v': 0}, {'v': 1}]}}, - {'v': {'f': [{'v': 2}, {'v': 3}]}}]}}, [[0, 1], [2, 3]]), -]) +@pytest.mark.parametrize( + 'data,expected', [ + ({'v': None}, None), + ({'v': 'foo'}, 'foo'), + ({'v': [{'v': 0}, {'v': 1}]}, [0, 1]), + ({'v': {'f': [{'v': 'foo'}]}}, ['foo']), + ({'v': {'f': [{'v': 'foo'}, {'v': 'bar'}]}}, ['foo', 'bar']), + ( + { + 'v': { + 'f': [ + {'v': {'f': [{'v': 0}, {'v': 1}]}}, + {'v': {'f': [{'v': 2}, {'v': 3}]}}, + ], + }, + }, [[0, 1], [2, 3]], + ), + ], +) def test_flatten(data, expected): assert flatten({'f': [data]}) == [expected] -@pytest.mark.parametrize('field,value,expected', [ - ({'type': 'BIGNUMERIC', 'mode': 'NULLABLE'}, '0.0', 0.0), - ({'type': 'BIGNUMERIC', 'mode': 'NULLABLE'}, '1.25', 1.25), +@pytest.mark.parametrize( + 'field,value,expected', [ + ({'type': 'BIGNUMERIC', 'mode': 'NULLABLE'}, '0.0', 0.0), + ({'type': 'BIGNUMERIC', 'mode': 'NULLABLE'}, '1.25', 1.25), - ({'type': 'BOOLEAN', 'mode': 'NULLABLE'}, 'false', False), - ({'type': 'BOOLEAN', 'mode': 'NULLABLE'}, 'true', True), + ({'type': 'BOOLEAN', 'mode': 'NULLABLE'}, 'false', False), + ({'type': 'BOOLEAN', 'mode': 'NULLABLE'}, 'true', True), - ({'type': 'FLOAT', 'mode': 'NULLABLE'}, '0.0', 0.0), - ({'type': 'FLOAT', 'mode': 'NULLABLE'}, '1.25', 1.25), + ({'type': 'FLOAT', 'mode': 'NULLABLE'}, '0.0', 0.0), + ({'type': 'FLOAT', 'mode': 'NULLABLE'}, '1.25', 1.25), - ({'type': 'INTEGER', 'mode': 'NULLABLE'}, '0', 0), - ({'type': 'INTEGER', 'mode': 'NULLABLE'}, '1', 1), + ({'type': 'INTEGER', 'mode': 'NULLABLE'}, '0', 0), + ({'type': 'INTEGER', 'mode': 'NULLABLE'}, '1', 1), - ({'type': 'NUMERIC', 'mode': 'NULLABLE'}, '0.0', 0.0), - ({'type': 'NUMERIC', 'mode': 'NULLABLE'}, '1.25', 1.25), + ({'type': 'NUMERIC', 'mode': 'NULLABLE'}, '0.0', 0.0), + ({'type': 'NUMERIC', 'mode': 'NULLABLE'}, '1.25', 1.25), - ({'type': 'RECORD', 'mode': 'NULLABLE', 'fields': [ - {'type': 'INTEGER', 'mode': 'REQUIRED'}, - ]}, [], {}), - ({'type': 'RECORD', 'mode': 'NULLABLE', 'fields': [ - {'name': 'x', 'type': 'INTEGER', 'mode': 'REQUIRED'}, - {'name': 'y', 'type': 'INTEGER', 'mode': 'REQUIRED'}, - ]}, [1, 2], {'x': 1, 'y': 2}), + ( + { + 'type': 'RECORD', 'mode': 'NULLABLE', 'fields': [ + {'type': 'INTEGER', 'mode': 'REQUIRED'}, + ], + }, [], {}, + ), + ( + { + 'type': 'RECORD', 'mode': 'NULLABLE', 'fields': [ + {'name': 'x', 'type': 'INTEGER', 'mode': 'REQUIRED'}, + {'name': 'y', 'type': 'INTEGER', 'mode': 'REQUIRED'}, + ], + }, [1, 2], {'x': 1, 'y': 2}, + ), - ({'type': 'STRING', 'mode': 'NULLABLE'}, '', ''), - ({'type': 'STRING', 'mode': 'NULLABLE'}, 'foo', 'foo'), + ({'type': 'STRING', 'mode': 'NULLABLE'}, '', ''), + ({'type': 'STRING', 'mode': 'NULLABLE'}, 'foo', 'foo'), - ({'type': 'TIMESTAMP', 'mode': 'NULLABLE'}, '0.0', - datetime.datetime(1970, 1, 1, 0, tzinfo=utc)), - ({'type': 'TIMESTAMP', 'mode': 'NULLABLE'}, '1656511192.51', - datetime.datetime(2022, 6, 29, 13, 59, 52, 510000, tzinfo=utc)), + ( + {'type': 'TIMESTAMP', 'mode': 'NULLABLE'}, '0.0', + datetime.datetime(1970, 1, 1, 0, tzinfo=utc), + ), + ( + {'type': 'TIMESTAMP', 'mode': 'NULLABLE'}, '1656511192.51', + datetime.datetime(2022, 6, 29, 13, 59, 52, 510000, tzinfo=utc), + ), - ({'type': 'STRING', 'mode': 'REQUIRED'}, '', ''), - ({'type': 'STRING', 'mode': 'REQUIRED'}, 'foo', 'foo'), + ({'type': 'STRING', 'mode': 'REQUIRED'}, '', ''), + ({'type': 'STRING', 'mode': 'REQUIRED'}, 'foo', 'foo'), - ({'type': 'STRING', 'mode': 'REPEATED'}, - {'v': [{'v': 'foo'}, {'v': 'bar'}]}, ['foo', 'bar']), + ( + {'type': 'STRING', 'mode': 'REPEATED'}, + {'v': [{'v': 'foo'}, {'v': 'bar'}]}, ['foo', 'bar'], + ), -]) + ], +) def test_parse(field, value, expected): assert parse(field, value) == expected -@pytest.mark.parametrize('kind', [ - 'BOOLEAN', - 'FLOAT', - 'INTEGER', - 'RECORD', - 'STRING', - 'TIMESTAMP', -]) +@pytest.mark.parametrize( + 'kind', [ + 'BOOLEAN', + 'FLOAT', + 'INTEGER', + 'RECORD', + 'STRING', + 'TIMESTAMP', + ], +) def test_parse_nullable(kind): field = {'type': kind, 'mode': 'NULLABLE'} # make sure we never convert to a falsey typed equivalent @@ -78,145 +107,322 @@ def test_parse_nullable(kind): assert parse(field, None) is None -@pytest.mark.parametrize('fields,rows,expected', [ - # collection of misc data - ([ - {'name': 'id', 'type': 'STRING', 'mode': 'NULLABLE'}, - {'name': 'unixtime', 'type': 'INTEGER', 'mode': 'NULLABLE'}, - {'name': 'isfakedata', 'type': 'BOOLEAN', 'mode': 'NULLABLE'}, - {'name': 'nested', 'type': 'RECORD', 'mode': 'REPEATED', 'fields': [ - {'name': 'nestedagain', 'type': 'RECORD', 'mode': 'REPEATED', - 'fields': [ - {'name': 'item', 'type': 'STRING', 'mode': 'NULLABLE'}, - {'name': 'value', 'type': 'FLOAT', 'mode': 'NULLABLE'}]}]}, - {'name': 'repeated', 'type': 'STRING', 'mode': 'REPEATED'}, - {'name': 'record', 'type': 'RECORD', 'mode': 'REQUIRED', 'fields': [ - {'name': 'item', 'type': 'STRING', 'mode': 'NULLABLE'}, - {'name': 'value', 'type': 'INTEGER', 'mode': 'NULLABLE'}]}, - {'name': 'PARTITIONTIME', 'type': 'TIMESTAMP', 'mode': 'NULLABLE'}], - [ - {'f': [ - {'v': 'ident1'}, - {'v': '1654122422181'}, - {'v': 'true'}, - {'v': [ - {'v': {'f': [{'v': [ - {'v': {'f': [{'v': 'apples'}, {'v': '1.23'}]}}, - {'v': {'f': [{'v': 'oranges'}, {'v': '2.34'}]}}]}]}}, - {'v': {'f': [{'v': [ - {'v': {'f': [{'v': 'aardvarks'}, {'v': '9000.1'}]}}]}]}}]}, - {'v': [{'v': 'foo'}, {'v': 'bar'}]}, - {'v': {'f': [{'v': 'slothtoes'}, {'v': 3}]}}, - {'v': '1.6540416E9'}]}, - {'f': [ - {'v': 'ident2'}, - {'v': '1654122422181'}, - {'v': 'false'}, - {'v': []}, - {'v': [{'v': 'foo'}, {'v': 'bar'}]}, - {'v': {'f': [{'v': 'slothtoes'}, {'v': 3}]}}, - {'v': '1.6540416E9'}]}], - [{ - 'PARTITIONTIME': datetime.datetime(2022, 6, 1, 0, 0, tzinfo=utc), - 'id': 'ident1', - 'isfakedata': True, - 'nested': [ - { - 'nestedagain': [ - {'item': 'apples', 'value': 1.23}, - {'item': 'oranges', 'value': 2.34}, - ], - }, - { - 'nestedagain': [ - {'item': 'aardvarks', 'value': 9000.1}, - ], - } - ], - 'record': {'item': 'slothtoes', 'value': 3}, - 'repeated': ['foo', 'bar'], - 'unixtime': 1654122422181, - }, { - 'PARTITIONTIME': datetime.datetime(2022, 6, 1, 0, 0, tzinfo=utc), - 'id': 'ident2', - 'isfakedata': False, - 'nested': [], - 'record': {'item': 'slothtoes', 'value': 3}, - 'repeated': ['foo', 'bar'], - 'unixtime': 1654122422181}], - ), - - # double-nested RECORDs - ([{ - 'name': 'paragraph', - 'type': 'RECORD', - 'mode': 'REPEATED', - 'fields': [ - { - 'name': 'sentence', +@pytest.mark.parametrize( + 'fields,rows,expected', [ + # collection of misc data + ( + [ + {'name': 'id', 'type': 'STRING', 'mode': 'NULLABLE'}, + {'name': 'unixtime', 'type': 'INTEGER', 'mode': 'NULLABLE'}, + {'name': 'isfakedata', 'type': 'BOOLEAN', 'mode': 'NULLABLE'}, + { + 'name': 'nested', 'type': 'RECORD', 'mode': 'REPEATED', 'fields': [ + { + 'name': 'nestedagain', 'type': 'RECORD', 'mode': 'REPEATED', + 'fields': [ + { + 'name': 'item', 'type': 'STRING', + 'mode': 'NULLABLE', + }, + { + 'name': 'value', 'type': 'FLOAT', + 'mode': 'NULLABLE', + }, + ], + }, + ], + }, + {'name': 'repeated', 'type': 'STRING', 'mode': 'REPEATED'}, + { + 'name': 'record', 'type': 'RECORD', 'mode': 'REQUIRED', 'fields': [ + {'name': 'item', 'type': 'STRING', 'mode': 'NULLABLE'}, + {'name': 'value', 'type': 'INTEGER', 'mode': 'NULLABLE'}, + ], + }, + {'name': 'PARTITIONTIME', 'type': 'TIMESTAMP', 'mode': 'NULLABLE'}, + ], + [ + { + 'f': [ + {'v': 'ident1'}, + {'v': '1654122422181'}, + {'v': 'true'}, + { + 'v': [ + { + 'v': { + 'f': [{ + 'v': [ + { + 'v': { + 'f': [{'v': 'apples'}, {'v': '1.23'}], + }, + }, + { + 'v': { + 'f': [{'v': 'oranges'}, {'v': '2.34'}], + }, + }, + ], + }], + }, + }, + { + 'v': { + 'f': [{ + 'v': [ + { + 'v': { + 'f': [{'v': 'aardvarks'}, {'v': '9000.1'}], + }, + }, + ], + }], + }, + }, + ], + }, + {'v': [{'v': 'foo'}, {'v': 'bar'}]}, + {'v': {'f': [{'v': 'slothtoes'}, {'v': 3}]}}, + {'v': '1.6540416E9'}, + ], + }, + { + 'f': [ + {'v': 'ident2'}, + {'v': '1654122422181'}, + {'v': 'false'}, + {'v': []}, + {'v': [{'v': 'foo'}, {'v': 'bar'}]}, + {'v': {'f': [{'v': 'slothtoes'}, {'v': 3}]}}, + {'v': '1.6540416E9'}, + ], + }, + ], + [ + { + 'PARTITIONTIME': datetime.datetime(2022, 6, 1, 0, 0, tzinfo=utc), + 'id': 'ident1', + 'isfakedata': True, + 'nested': [ + { + 'nestedagain': [ + {'item': 'apples', 'value': 1.23}, + {'item': 'oranges', 'value': 2.34}, + ], + }, + { + 'nestedagain': [ + {'item': 'aardvarks', 'value': 9000.1}, + ], + }, + ], + 'record': {'item': 'slothtoes', 'value': 3}, + 'repeated': ['foo', 'bar'], + 'unixtime': 1654122422181, + }, { + 'PARTITIONTIME': datetime.datetime(2022, 6, 1, 0, 0, tzinfo=utc), + 'id': 'ident2', + 'isfakedata': False, + 'nested': [], + 'record': {'item': 'slothtoes', 'value': 3}, + 'repeated': ['foo', 'bar'], + 'unixtime': 1654122422181, + }, + ], + ), + + # double-nested RECORDs + ( + [{ + 'name': 'paragraph', 'type': 'RECORD', 'mode': 'REPEATED', 'fields': [ { - 'name': 'word', - 'type': 'STRING', - 'mode': 'NULLABLE' + 'name': 'sentence', + 'type': 'RECORD', + 'mode': 'REPEATED', + 'fields': [ + { + 'name': 'word', + 'type': 'STRING', + 'mode': 'NULLABLE', + }, + { + 'name': 'timestamp', + 'type': 'FLOAT', + 'mode': 'NULLABLE', + }, + ], }, + ], + }], + [{ + 'f': [{ + 'v': [ + { + 'v': { + 'f': [{ + 'v': [{ + 'v': { + 'f': [ + {'v': 'hello'}, + {'v': '2.34'}, + ], + }, + }], + }], + }, + }, + { + 'v': { + 'f': [{ + 'v': [{ + 'v': { + 'f': [ + {'v': 'hey'}, + {'v': '5.22'}, + ], + }, + }], + }], + }, + }, + { + 'v': { + 'f': [{ + 'v': [ + { + 'v': { + 'f': [ + {'v': "I'm"}, + {'v': '7.86'}, + ], + }, + }, + { + 'v': { + 'f': [ + {'v': 'good'}, + {'v': '8.31'}, + ], + }, + }, + { + 'v': { + 'f': [ + {'v': "I'm"}, + {'v': '8.46'}, + ], + }, + }, + { + 'v': { + 'f': [ + {'v': 'very'}, + {'v': '8.76'}, + ], + }, + }, + { + 'v': { + 'f': [ + {'v': 'caffeinated'}, + {'v': '9.45'}, + ], + }, + }, + { + 'v': { + 'f': [ + {'v': 'this'}, + {'v': '9.66'}, + ], + }, + }, + { + 'v': { + 'f': [ + {'v': 'morning'}, + {'v': '10.05'}, + ], + }, + }, + { + 'v': { + 'f': [ + {'v': 'how'}, + {'v': '10.92'}, + ], + }, + }, + { + 'v': { + 'f': [ + {'v': 'are'}, + {'v': '11.04'}, + ], + }, + }, + { + 'v': { + 'f': [ + {'v': 'you'}, + {'v': '11.13'}, + ], + }, + }, + { + 'v': { + 'f': [ + {'v': 'doing'}, + {'v': '11.4'}, + ], + }, + }, + ], + }], + }, + }, + ], + }], + }], + [{ + 'paragraph': [ { - 'name': 'timestamp', - 'type': 'FLOAT', - 'mode': 'NULLABLE' - }]}]}], - [{'f': [{'v': [ - {'v': {'f': [{'v': [{'v': {'f': [{'v': 'hello'}, - {'v': '2.34'}]}}]}]}}, - {'v': {'f': [{'v': [{'v': {'f': [{'v': 'hey'}, - {'v': '5.22'}]}}]}]}}, - {'v': {'f': [{'v': [{'v': {'f': [{'v': "I'm"}, - {'v': '7.86'}]}}, - {'v': {'f': [{'v': 'good'}, - {'v': '8.31'}]}}, - {'v': {'f': [{'v': "I'm"}, - {'v': '8.46'}]}}, - {'v': {'f': [{'v': 'very'}, - {'v': '8.76'}]}}, - {'v': {'f': [{'v': 'caffeinated'}, - {'v': '9.45'}]}}, - {'v': {'f': [{'v': 'this'}, - {'v': '9.66'}]}}, - {'v': {'f': [{'v': 'morning'}, - {'v': '10.05'}]}}, - {'v': {'f': [{'v': 'how'}, - {'v': '10.92'}]}}, - {'v': {'f': [{'v': 'are'}, - {'v': '11.04'}]}}, - {'v': {'f': [{'v': 'you'}, - {'v': '11.13'}]}}, - {'v': {'f': [{'v': 'doing'}, - {'v': '11.4'}]}}]}]}}]}]}], - [{'paragraph': [{ - 'sentence': [{'word': 'hello', 'timestamp': 2.34}]}, { - 'sentence': [{'word': 'hey', 'timestamp': 5.22}]}, { - 'sentence': [{'word': "I'm", 'timestamp': 7.86}, - {'word': 'good', 'timestamp': 8.31}, - {'word': "I'm", 'timestamp': 8.46}, - {'word': 'very', 'timestamp': 8.76}, - {'word': 'caffeinated', 'timestamp': 9.45}, - {'word': 'this', 'timestamp': 9.66}, - {'word': 'morning', 'timestamp': 10.05}, - {'word': 'how', 'timestamp': 10.92}, - {'word': 'are', 'timestamp': 11.04}, - {'word': 'you', 'timestamp': 11.13}, - {'word': 'doing', 'timestamp': 11.4}]}]}], - ), -]) + 'sentence': [{'word': 'hello', 'timestamp': 2.34}], + }, { + 'sentence': [{'word': 'hey', 'timestamp': 5.22}], + }, { + 'sentence': [ + {'word': "I'm", 'timestamp': 7.86}, + {'word': 'good', 'timestamp': 8.31}, + {'word': "I'm", 'timestamp': 8.46}, + {'word': 'very', 'timestamp': 8.76}, + {'word': 'caffeinated', 'timestamp': 9.45}, + {'word': 'this', 'timestamp': 9.66}, + {'word': 'morning', 'timestamp': 10.05}, + {'word': 'how', 'timestamp': 10.92}, + {'word': 'are', 'timestamp': 11.04}, + {'word': 'you', 'timestamp': 11.13}, + {'word': 'doing', 'timestamp': 11.4}, + ], + }, + ], + }], + ), + ], +) def test_query_response_to_dict(fields, rows, expected): resp = { 'kind': 'bigquery#queryResponse', 'schema': {'fields': fields}, - 'jobReference': {'projectId': 'sample-project', - 'jobId': 'job_Tlpl-66ca7a8e365a28084c39ffc52d402671', - 'location': 'US'}, + 'jobReference': { + 'projectId': 'sample-project', + 'jobId': 'job_Tlpl-66ca7a8e365a28084c39ffc52d402671', + 'location': 'US', + }, 'rows': rows, 'totalRows': str(len(rows)), 'totalBytesProcessed': '0', diff --git a/datastore/gcloud/aio/datastore/datastore.py b/datastore/gcloud/aio/datastore/datastore.py index b10979ebe..a6cde5d7f 100644 --- a/datastore/gcloud/aio/datastore/datastore.py +++ b/datastore/gcloud/aio/datastore/datastore.py @@ -74,13 +74,16 @@ def __init__( self.session = AioSession(session) self.token = token or Token( service_file=service_file, scopes=SCOPES, - session=self.session.session) # type: ignore[arg-type] + session=self.session.session, # type: ignore[arg-type] + ) self._project = project if self._api_is_dev and not project: - self._project = (os.environ.get('DATASTORE_PROJECT_ID') - or os.environ.get('GOOGLE_CLOUD_PROJECT') - or 'dev') + self._project = ( + os.environ.get('DATASTORE_PROJECT_ID') + or os.environ.get('GOOGLE_CLOUD_PROJECT') + or 'dev' + ) async def project(self) -> str: if self._project: @@ -93,15 +96,19 @@ async def project(self) -> str: raise Exception('could not determine project, please set it manually') @staticmethod - def _make_commit_body(mutations: List[Dict[str, Any]], - transaction: Optional[str] = None, - mode: Mode = Mode.TRANSACTIONAL) -> Dict[str, Any]: + def _make_commit_body( + mutations: List[Dict[str, Any]], + transaction: Optional[str] = None, + mode: Mode = Mode.TRANSACTIONAL, + ) -> Dict[str, Any]: if not mutations: raise Exception('at least one mutation record is required') if transaction is None and mode != Mode.NON_TRANSACTIONAL: - raise Exception('a transaction ID must be provided when mode is ' - 'transactional') + raise Exception( + 'a transaction ID must be provided when mode is ' + 'transactional', + ) data = { 'mode': mode.value, @@ -124,7 +131,8 @@ async def headers(self) -> Dict[str, str]: @classmethod def make_mutation( cls, operation: Operation, key: Key, - properties: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + properties: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: if operation == Operation.DELETE: return {operation.value: key.to_repr()} @@ -137,13 +145,15 @@ def make_mutation( operation.value: { 'key': key.to_repr(), 'properties': mutation_properties, - } + }, } # https://cloud.google.com/datastore/docs/reference/data/rest/v1/projects/allocateIds - async def allocateIds(self, keys: List[Key], - session: Optional[Session] = None, - timeout: int = 10) -> List[Key]: + async def allocateIds( + self, keys: List[Key], + session: Optional[Session] = None, + timeout: int = 10, + ) -> List[Key]: project = await self.project() url = f'{self._api_root}/projects/{project}:allocateIds' @@ -159,16 +169,20 @@ async def allocateIds(self, keys: List[Key], s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - resp = await s.post(url, data=payload, # type: ignore[arg-type] - headers=headers, timeout=timeout) + resp = await s.post( + url, data=payload, # type: ignore[arg-type] + headers=headers, timeout=timeout, + ) data = await resp.json() return [self.key_kind.from_repr(k) for k in data['keys']] # https://cloud.google.com/datastore/docs/reference/data/rest/v1/projects/beginTransaction # TODO: support readwrite vs readonly transaction types - async def beginTransaction(self, session: Optional[Session] = None, - timeout: int = 10) -> str: + async def beginTransaction( + self, session: Optional[Session] = None, + timeout: int = 10, + ) -> str: project = await self.project() url = f'{self._api_root}/projects/{project}:beginTransaction' headers = await self.headers() @@ -185,16 +199,20 @@ async def beginTransaction(self, session: Optional[Session] = None, return transaction # https://cloud.google.com/datastore/docs/reference/data/rest/v1/projects/commit - async def commit(self, mutations: List[Dict[str, Any]], - transaction: Optional[str] = None, - mode: Mode = Mode.TRANSACTIONAL, - session: Optional[Session] = None, - timeout: int = 10) -> Dict[str, Any]: + async def commit( + self, mutations: List[Dict[str, Any]], + transaction: Optional[str] = None, + mode: Mode = Mode.TRANSACTIONAL, + session: Optional[Session] = None, + timeout: int = 10, + ) -> Dict[str, Any]: project = await self.project() url = f'{self._api_root}/projects/{project}:commit' - body = self._make_commit_body(mutations, transaction=transaction, - mode=mode) + body = self._make_commit_body( + mutations, transaction=transaction, + mode=mode, + ) payload = json.dumps(body).encode('utf-8') headers = await self.headers() @@ -205,23 +223,29 @@ async def commit(self, mutations: List[Dict[str, Any]], s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - resp = await s.post(url, data=payload, # type: ignore[arg-type] - headers=headers, timeout=timeout) + resp = await s.post( + url, data=payload, # type: ignore[arg-type] + headers=headers, timeout=timeout, + ) data: Dict[str, Any] = await resp.json() return { - 'mutationResults': [self.mutation_result_kind.from_repr(r) - for r in data.get('mutationResults', [])], + 'mutationResults': [ + self.mutation_result_kind.from_repr(r) + for r in data.get('mutationResults', []) + ], 'indexUpdates': data.get('indexUpdates', 0), } # https://cloud.google.com/datastore/docs/reference/admin/rest/v1/projects/export - async def export(self, output_bucket_prefix: str, - kinds: Optional[List[str]] = None, - namespaces: Optional[List[str]] = None, - labels: Optional[Dict[str, str]] = None, - session: Optional[Session] = None, - timeout: int = 10) -> DatastoreOperation: + async def export( + self, output_bucket_prefix: str, + kinds: Optional[List[str]] = None, + namespaces: Optional[List[str]] = None, + labels: Optional[Dict[str, str]] = None, + session: Optional[Session] = None, + timeout: int = 10, + ) -> DatastoreOperation: project = await self.project() url = f'{self._api_root}/projects/{project}:export' @@ -242,16 +266,20 @@ async def export(self, output_bucket_prefix: str, s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - resp = await s.post(url, data=payload, # type: ignore[arg-type] - headers=headers, timeout=timeout) + resp = await s.post( + url, data=payload, # type: ignore[arg-type] + headers=headers, timeout=timeout, + ) data: Dict[str, Any] = await resp.json() return self.datastore_operation_kind.from_repr(data) # https://cloud.google.com/datastore/docs/reference/data/rest/v1/projects.operations/get - async def get_datastore_operation(self, name: str, - session: Optional[Session] = None, - timeout: int = 10) -> DatastoreOperation: + async def get_datastore_operation( + self, name: str, + session: Optional[Session] = None, + timeout: int = 10, + ) -> DatastoreOperation: url = f'{self._api_root}/{name}' headers = await self.headers() @@ -269,7 +297,7 @@ async def get_datastore_operation(self, name: str, async def lookup( self, keys: List[Key], transaction: Optional[str] = None, consistency: Consistency = Consistency.STRONG, - session: Optional[Session] = None, timeout: int = 10 + session: Optional[Session] = None, timeout: int = 10, ) -> Dict[str, List[Union[EntityResult, Key]]]: project = await self.project() url = f'{self._api_root}/projects/{project}:lookup' @@ -291,24 +319,34 @@ async def lookup( s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - resp = await s.post(url, data=payload, # type: ignore[arg-type] - headers=headers, timeout=timeout) + resp = await s.post( + url, data=payload, # type: ignore[arg-type] + headers=headers, timeout=timeout, + ) data: Dict[str, List[Any]] = await resp.json() return { - 'found': [self.entity_result_kind.from_repr(e) - for e in data.get('found', [])], - 'missing': [self.entity_result_kind.from_repr(e) - for e in data.get('missing', [])], - 'deferred': [self.key_kind.from_repr(k) - for k in data.get('deferred', [])], + 'found': [ + self.entity_result_kind.from_repr(e) + for e in data.get('found', []) + ], + 'missing': [ + self.entity_result_kind.from_repr(e) + for e in data.get('missing', []) + ], + 'deferred': [ + self.key_kind.from_repr(k) + for k in data.get('deferred', []) + ], } # https://cloud.google.com/datastore/docs/reference/data/rest/v1/projects/reserveIds - async def reserveIds(self, keys: List[Key], database_id: str = '', - session: Optional[Session] = None, - timeout: int = 10) -> None: + async def reserveIds( + self, keys: List[Key], database_id: str = '', + session: Optional[Session] = None, + timeout: int = 10, + ) -> None: project = await self.project() url = f'{self._api_root}/projects/{project}:reserveIds' @@ -325,13 +363,17 @@ async def reserveIds(self, keys: List[Key], database_id: str = '', s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - await s.post(url, data=payload, # type: ignore[arg-type] - headers=headers, timeout=timeout) + await s.post( + url, data=payload, # type: ignore[arg-type] + headers=headers, timeout=timeout, + ) # https://cloud.google.com/datastore/docs/reference/data/rest/v1/projects/rollback - async def rollback(self, transaction: str, - session: Optional[Session] = None, - timeout: int = 10) -> None: + async def rollback( + self, transaction: str, + session: Optional[Session] = None, + timeout: int = 10, + ) -> None: project = await self.project() url = f'{self._api_root}/projects/{project}:rollback' @@ -347,15 +389,19 @@ async def rollback(self, transaction: str, s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - await s.post(url, data=payload, # type: ignore[arg-type] - headers=headers, timeout=timeout) + await s.post( + url, data=payload, # type: ignore[arg-type] + headers=headers, timeout=timeout, + ) # https://cloud.google.com/datastore/docs/reference/data/rest/v1/projects/runQuery - async def runQuery(self, query: BaseQuery, - transaction: Optional[str] = None, - consistency: Consistency = Consistency.EVENTUAL, - session: Optional[Session] = None, - timeout: int = 10) -> QueryResultBatch: + async def runQuery( + self, query: BaseQuery, + transaction: Optional[str] = None, + consistency: Consistency = Consistency.EVENTUAL, + session: Optional[Session] = None, + timeout: int = 10, + ) -> QueryResultBatch: project = await self.project() url = f'{self._api_root}/projects/{project}:runQuery' @@ -380,39 +426,59 @@ async def runQuery(self, query: BaseQuery, s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - resp = await s.post(url, data=payload, # type: ignore[arg-type] - headers=headers, timeout=timeout) + resp = await s.post( + url, data=payload, # type: ignore[arg-type] + headers=headers, timeout=timeout, + ) data: Dict[str, Any] = await resp.json() return self.query_result_batch_kind.from_repr(data['batch']) - async def delete(self, key: Key, - session: Optional[Session] = None) -> Dict[str, Any]: + async def delete( + self, key: Key, + session: Optional[Session] = None, + ) -> Dict[str, Any]: return await self.operate(Operation.DELETE, key, session=session) - async def insert(self, key: Key, properties: Dict[str, Any], - session: Optional[Session] = None) -> Dict[str, Any]: - return await self.operate(Operation.INSERT, key, properties, - session=session) - - async def update(self, key: Key, properties: Dict[str, Any], - session: Optional[Session] = None) -> Dict[str, Any]: - return await self.operate(Operation.UPDATE, key, properties, - session=session) - - async def upsert(self, key: Key, properties: Dict[str, Any], - session: Optional[Session] = None) -> Dict[str, Any]: - return await self.operate(Operation.UPSERT, key, properties, - session=session) + async def insert( + self, key: Key, properties: Dict[str, Any], + session: Optional[Session] = None, + ) -> Dict[str, Any]: + return await self.operate( + Operation.INSERT, key, properties, + session=session, + ) + + async def update( + self, key: Key, properties: Dict[str, Any], + session: Optional[Session] = None, + ) -> Dict[str, Any]: + return await self.operate( + Operation.UPDATE, key, properties, + session=session, + ) + + async def upsert( + self, key: Key, properties: Dict[str, Any], + session: Optional[Session] = None, + ) -> Dict[str, Any]: + return await self.operate( + Operation.UPSERT, key, properties, + session=session, + ) # TODO: accept Entity rather than key/properties? - async def operate(self, operation: Operation, key: Key, - properties: Optional[Dict[str, Any]] = None, - session: Optional[Session] = None) -> Dict[str, Any]: + async def operate( + self, operation: Operation, key: Key, + properties: Optional[Dict[str, Any]] = None, + session: Optional[Session] = None, + ) -> Dict[str, Any]: transaction = await self.beginTransaction(session=session) mutation = self.make_mutation(operation, key, properties=properties) - return await self.commit([mutation], transaction=transaction, - session=session) + return await self.commit( + [mutation], transaction=transaction, + session=session, + ) async def close(self) -> None: await self.session.close() diff --git a/datastore/gcloud/aio/datastore/datastore_operation.py b/datastore/gcloud/aio/datastore/datastore_operation.py index 0167f4c0f..b45f7df9e 100644 --- a/datastore/gcloud/aio/datastore/datastore_operation.py +++ b/datastore/gcloud/aio/datastore/datastore_operation.py @@ -4,10 +4,12 @@ class DatastoreOperation: - def __init__(self, name: str, done: bool, - metadata: Optional[Dict[str, Any]] = None, - error: Optional[Dict[str, str]] = None, - response: Optional[Dict[str, Any]] = None) -> None: + def __init__( + self, name: str, done: bool, + metadata: Optional[Dict[str, Any]] = None, + error: Optional[Dict[str, str]] = None, + response: Optional[Dict[str, Any]] = None, + ) -> None: self.name = name self.done = done @@ -20,8 +22,10 @@ def __repr__(self) -> str: @classmethod def from_repr(cls, data: Dict[str, Any]) -> 'DatastoreOperation': - return cls(data['name'], data.get('done', False), data.get('metadata'), - data.get('error'), data.get('response')) + return cls( + data['name'], data.get('done', False), data.get('metadata'), + data.get('error'), data.get('response'), + ) def to_repr(self) -> Dict[str, Any]: return { diff --git a/datastore/gcloud/aio/datastore/entity.py b/datastore/gcloud/aio/datastore/entity.py index bf4202bf9..81933fe3d 100644 --- a/datastore/gcloud/aio/datastore/entity.py +++ b/datastore/gcloud/aio/datastore/entity.py @@ -12,17 +12,22 @@ class Entity: def __init__( self, key: Optional[Key], - properties: Optional[Dict[str, Dict[str, Any]]] = None) -> None: + properties: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> None: self.key = key - self.properties = {k: self.value_kind.from_repr(v).value - for k, v in (properties or {}).items()} + self.properties = { + k: self.value_kind.from_repr(v).value + for k, v in (properties or {}).items() + } def __eq__(self, other: Any) -> bool: if not isinstance(other, Entity): return False - return bool(self.key == other.key - and self.properties == other.properties) + return bool( + self.key == other.key + and self.properties == other.properties, + ) def __repr__(self) -> str: return str(self.to_repr()) @@ -40,16 +45,20 @@ def from_repr(cls, data: Dict[str, Any]) -> 'Entity': def to_repr(self) -> Dict[str, Any]: return { 'key': self.key.to_repr() if self.key else None, - 'properties': {k: self.value_kind(v).to_repr() - for k, v in self.properties.items()}, + 'properties': { + k: self.value_kind(v).to_repr() + for k, v in self.properties.items() + }, } class EntityResult: entity_kind = Entity - def __init__(self, entity: Entity, version: str = '', - cursor: str = '') -> None: + def __init__( + self, entity: Entity, version: str = '', + cursor: str = '', + ) -> None: self.entity = entity self.version = version self.cursor = cursor @@ -58,18 +67,22 @@ def __eq__(self, other: Any) -> bool: if not isinstance(other, EntityResult): return False - return bool(self.entity == other.entity - and self.version == other.version - and self.cursor == self.cursor) + return bool( + self.entity == other.entity + and self.version == other.version + and self.cursor == self.cursor, + ) def __repr__(self) -> str: return str(self.to_repr()) @classmethod def from_repr(cls, data: Dict[str, Any]) -> 'EntityResult': - return cls(cls.entity_kind.from_repr(data['entity']), - data.get('version', ''), - data.get('cursor', '')) + return cls( + cls.entity_kind.from_repr(data['entity']), + data.get('version', ''), + data.get('cursor', ''), + ) def to_repr(self) -> Dict[str, Any]: data: Dict[str, Any] = { diff --git a/datastore/gcloud/aio/datastore/filter.py b/datastore/gcloud/aio/datastore/filter.py index 21b04b53f..0d289ee85 100644 --- a/datastore/gcloud/aio/datastore/filter.py +++ b/datastore/gcloud/aio/datastore/filter.py @@ -54,8 +54,10 @@ def to_repr(self) -> Dict[str, Any]: class CompositeFilter(BaseFilter): json_key = 'compositeFilter' - def __init__(self, operator: CompositeFilterOperator, - filters: List[Filter]) -> None: + def __init__( + self, operator: CompositeFilterOperator, + filters: List[Filter], + ) -> None: self.operator = operator self.filters = filters @@ -65,7 +67,8 @@ def __eq__(self, other: Any) -> bool: return bool( self.operator == other.operator - and self.filters == other.filters) + and self.filters == other.filters, + ) @classmethod def from_repr(cls, data: Dict[str, Any]) -> 'CompositeFilter': @@ -84,8 +87,10 @@ def to_repr(self) -> Dict[str, Any]: class PropertyFilter(BaseFilter): json_key = 'propertyFilter' - def __init__(self, prop: str, operator: PropertyFilterOperator, - value: Value) -> None: + def __init__( + self, prop: str, operator: PropertyFilterOperator, + value: Value, + ) -> None: self.prop = prop self.operator = operator self.value = value @@ -97,7 +102,8 @@ def __eq__(self, other: Any) -> bool: return bool( self.prop == other.prop and self.operator == other.operator - and self.value == other.value) + and self.value == other.value, + ) @classmethod def from_repr(cls, data: Dict[str, Any]) -> 'PropertyFilter': diff --git a/datastore/gcloud/aio/datastore/key.py b/datastore/gcloud/aio/datastore/key.py index 768d54758..d1a10c825 100644 --- a/datastore/gcloud/aio/datastore/key.py +++ b/datastore/gcloud/aio/datastore/key.py @@ -5,8 +5,10 @@ class PathElement: - def __init__(self, kind: str, *, id_: Optional[int] = None, - name: Optional[str] = None) -> None: + def __init__( + self, kind: str, *, id_: Optional[int] = None, + name: Optional[str] = None + ) -> None: self.kind = kind self.id = id_ @@ -18,8 +20,10 @@ def __eq__(self, other: Any) -> bool: if not isinstance(other, PathElement): return False - return bool(self.kind == other.kind and self.id == other.id - and self.name == other.name) + return bool( + self.kind == other.kind and self.id == other.id + and self.name == other.name, + ) def __repr__(self) -> str: return str(self.to_repr()) @@ -44,8 +48,10 @@ def to_repr(self) -> Dict[str, Any]: class Key: path_element_kind = PathElement - def __init__(self, project: str, path: List[PathElement], - namespace: str = '') -> None: + def __init__( + self, project: str, path: List[PathElement], + namespace: str = '', + ) -> None: self.project = project self.namespace = namespace self.path = path @@ -54,19 +60,25 @@ def __eq__(self, other: Any) -> bool: if not isinstance(other, Key): return False - return bool(self.project == other.project - and self.namespace == other.namespace - and self.path == other.path) + return bool( + self.project == other.project + and self.namespace == other.namespace + and self.path == other.path, + ) def __repr__(self) -> str: return str(self.to_repr()) @classmethod def from_repr(cls, data: Dict[str, Any]) -> 'Key': - return cls(data['partitionId']['projectId'], - path=[cls.path_element_kind.from_repr(p) - for p in data['path']], - namespace=data['partitionId'].get('namespaceId', '')) + return cls( + data['partitionId']['projectId'], + path=[ + cls.path_element_kind.from_repr(p) + for p in data['path'] + ], + namespace=data['partitionId'].get('namespaceId', ''), + ) def to_repr(self) -> Dict[str, Any]: return { diff --git a/datastore/gcloud/aio/datastore/lat_lng.py b/datastore/gcloud/aio/datastore/lat_lng.py index 6076c803e..3b38d8d0b 100644 --- a/datastore/gcloud/aio/datastore/lat_lng.py +++ b/datastore/gcloud/aio/datastore/lat_lng.py @@ -14,7 +14,8 @@ def __eq__(self, other: Any) -> bool: return bool( self.lat == other.lat - and self.lon == other.lon) + and self.lon == other.lon, + ) def __repr__(self) -> str: return str(self.to_repr()) diff --git a/datastore/gcloud/aio/datastore/mutation.py b/datastore/gcloud/aio/datastore/mutation.py index 3deaa4016..d0fb0b238 100644 --- a/datastore/gcloud/aio/datastore/mutation.py +++ b/datastore/gcloud/aio/datastore/mutation.py @@ -15,8 +15,10 @@ class Mutation: class MutationResult: key_kind = Key - def __init__(self, key: Optional[Key], version: str, - conflict_detected: bool) -> None: + def __init__( + self, key: Optional[Key], version: str, + conflict_detected: bool, + ) -> None: self.key = key self.version = version self.conflict_detected = conflict_detected @@ -43,7 +45,7 @@ def from_repr(cls, data: Dict[str, Any]) -> 'MutationResult': def to_repr(self) -> Dict[str, Any]: data = { 'version': self.version, - 'conflictDetected': self.conflict_detected + 'conflictDetected': self.conflict_detected, } if self.key: data['key'] = self.key.to_repr() diff --git a/datastore/gcloud/aio/datastore/property_order.py b/datastore/gcloud/aio/datastore/property_order.py index 416de71db..b1adf531c 100644 --- a/datastore/gcloud/aio/datastore/property_order.py +++ b/datastore/gcloud/aio/datastore/property_order.py @@ -6,8 +6,10 @@ # https://cloud.google.com/datastore/docs/reference/data/rest/v1/projects/runQuery#PropertyOrder class PropertyOrder: - def __init__(self, prop: str, - direction: Direction = Direction.ASCENDING) -> None: + def __init__( + self, prop: str, + direction: Direction = Direction.ASCENDING, + ) -> None: self.prop = prop self.direction = direction @@ -17,7 +19,8 @@ def __eq__(self, other: Any) -> bool: return bool( self.prop == other.prop - and self.direction == other.direction) + and self.direction == other.direction, + ) def __repr__(self) -> str: return str(self.to_repr()) diff --git a/datastore/gcloud/aio/datastore/query.py b/datastore/gcloud/aio/datastore/query.py index ae0ba74be..46e497177 100644 --- a/datastore/gcloud/aio/datastore/query.py +++ b/datastore/gcloud/aio/datastore/query.py @@ -32,12 +32,14 @@ class Query(BaseQuery): # pylint: disable=too-many-instance-attributes json_key = 'query' - def __init__(self, kind: str = '', query_filter: Optional[Filter] = None, - order: Optional[List[PropertyOrder]] = None, - start_cursor: str = '', end_cursor: str = '', - offset: Optional[int] = None, limit: Optional[int] = None, - projection: Optional[List[Projection]] = None, - distinct_on: Optional[List[str]] = None) -> None: + def __init__( + self, kind: str = '', query_filter: Optional[Filter] = None, + order: Optional[List[PropertyOrder]] = None, + start_cursor: str = '', end_cursor: str = '', + offset: Optional[int] = None, limit: Optional[int] = None, + projection: Optional[List[Projection]] = None, + distinct_on: Optional[List[str]] = None, + ) -> None: self.kind = kind self.query_filter = query_filter self.orders = order or [] @@ -54,7 +56,8 @@ def __eq__(self, other: Any) -> bool: return bool( self.kind == other.kind - and self.query_filter == other.query_filter) + and self.query_filter == other.query_filter, + ) @classmethod def from_repr(cls, data: Dict[str, Any]) -> 'Query': @@ -68,17 +71,21 @@ def from_repr(cls, data: Dict[str, Any]) -> 'Query': end_cursor = data.get('endCursor') or '' offset = int(data['offset']) if 'offset' in data else None limit = int(data['limit']) if 'limit' in data else None - projection = [Projection.from_repr(p) - for p in data.get('projection', [])] + projection = [ + Projection.from_repr(p) + for p in data.get('projection', []) + ] distinct_on = [d['name'] for d in data.get('distinct_on', [])] filter_ = data.get('filter') query_filter = Filter.from_repr(filter_) if filter_ else None - return cls(kind=kind, query_filter=query_filter, order=orders, - start_cursor=start_cursor, end_cursor=end_cursor, - offset=offset, limit=limit, - projection=projection, distinct_on=distinct_on) + return cls( + kind=kind, query_filter=query_filter, order=orders, + start_cursor=start_cursor, end_cursor=end_cursor, + offset=offset, limit=limit, + projection=projection, distinct_on=distinct_on, + ) def to_repr(self) -> Dict[str, Any]: data: Dict[str, Any] = { @@ -107,9 +114,11 @@ def to_repr(self) -> Dict[str, Any]: class GQLQuery(BaseQuery): json_key = 'gqlQuery' - def __init__(self, query_string: str, allow_literals: bool = True, - named_bindings: Optional[Dict[str, Any]] = None, - positional_bindings: Optional[List[Any]] = None) -> None: + def __init__( + self, query_string: str, allow_literals: bool = True, + named_bindings: Optional[Dict[str, Any]] = None, + positional_bindings: Optional[List[Any]] = None, + ) -> None: self.query_string = query_string self.allow_literals = allow_literals self.named_bindings = named_bindings or {} @@ -123,19 +132,26 @@ def __eq__(self, other: Any) -> bool: self.query_string == other.query_string and self.allow_literals == other.allow_literals and self.named_bindings == other.named_bindings - and self.positional_bindings == other.positional_bindings) + and self.positional_bindings == other.positional_bindings, + ) @classmethod def from_repr(cls, data: Dict[str, Any]) -> 'GQLQuery': allow_literals = data['allowLiterals'] query_string = data['queryString'] - named_bindings = {k: cls._param_from_repr(v) - for k, v in data.get('namedBindings', {}).items()} - positional_bindings = [cls._param_from_repr(v) - for v in data.get('positionalBindings', [])] - return cls(query_string, allow_literals=allow_literals, - named_bindings=named_bindings, - positional_bindings=positional_bindings) + named_bindings = { + k: cls._param_from_repr(v) + for k, v in data.get('namedBindings', {}).items() + } + positional_bindings = [ + cls._param_from_repr(v) + for v in data.get('positionalBindings', []) + ] + return cls( + query_string, allow_literals=allow_literals, + named_bindings=named_bindings, + positional_bindings=positional_bindings, + ) @classmethod def _param_from_repr(cls, param_repr: Dict[str, Any]) -> Any: @@ -148,10 +164,14 @@ def to_repr(self) -> Dict[str, Any]: return { 'allowLiterals': self.allow_literals, 'queryString': self.query_string, - 'namedBindings': {k: self._param_to_repr(v) - for k, v in self.named_bindings.items()}, - 'positionalBindings': [self._param_to_repr(v) - for v in self.positional_bindings], + 'namedBindings': { + k: self._param_to_repr(v) + for k, v in self.named_bindings.items() + }, + 'positionalBindings': [ + self._param_to_repr(v) + for v in self.positional_bindings + ], } def _param_to_repr(self, param: Any) -> Dict[str, Any]: @@ -173,12 +193,14 @@ def __eq__(self, other: Any) -> bool: class QueryResultBatch: entity_result_kind = EntityResult - def __init__(self, end_cursor: str, - entity_result_type: ResultType = ResultType.UNSPECIFIED, - entity_results: Optional[List[EntityResult]] = None, - more_results: MoreResultsType = MoreResultsType.UNSPECIFIED, - skipped_cursor: str = '', skipped_results: int = 0, - snapshot_version: str = '') -> None: + def __init__( + self, end_cursor: str, + entity_result_type: ResultType = ResultType.UNSPECIFIED, + entity_results: Optional[List[EntityResult]] = None, + more_results: MoreResultsType = MoreResultsType.UNSPECIFIED, + skipped_cursor: str = '', skipped_results: int = 0, + snapshot_version: str = '', + ) -> None: self.end_cursor = end_cursor self.entity_result_type = entity_result_type @@ -192,13 +214,15 @@ def __eq__(self, other: Any) -> bool: if not isinstance(other, QueryResultBatch): return False - return bool(self.end_cursor == other.end_cursor - and self.entity_result_type == other.entity_result_type - and self.entity_results == other.entity_results - and self.more_results == other.more_results - and self.skipped_cursor == other.skipped_cursor - and self.skipped_results == other.skipped_results - and self.snapshot_version == other.snapshot_version) + return bool( + self.end_cursor == other.end_cursor + and self.entity_result_type == other.entity_result_type + and self.entity_results == other.entity_results + and self.more_results == other.more_results + and self.skipped_cursor == other.skipped_cursor + and self.skipped_results == other.skipped_results + and self.snapshot_version == other.snapshot_version, + ) def __repr__(self) -> str: return str(self.to_repr()) @@ -207,17 +231,21 @@ def __repr__(self) -> str: def from_repr(cls, data: Dict[str, Any]) -> 'QueryResultBatch': end_cursor = data['endCursor'] entity_result_type = ResultType(data['entityResultType']) - entity_results = [cls.entity_result_kind.from_repr(er) - for er in data.get('entityResults', [])] + entity_results = [ + cls.entity_result_kind.from_repr(er) + for er in data.get('entityResults', []) + ] more_results = MoreResultsType(data['moreResults']) skipped_cursor = data.get('skippedCursor', '') skipped_results = data.get('skippedResults', 0) snapshot_version = data.get('snapshotVersion', '') - return cls(end_cursor, entity_result_type=entity_result_type, - entity_results=entity_results, more_results=more_results, - skipped_cursor=skipped_cursor, - skipped_results=skipped_results, - snapshot_version=snapshot_version) + return cls( + end_cursor, entity_result_type=entity_result_type, + entity_results=entity_results, more_results=more_results, + skipped_cursor=skipped_cursor, + skipped_results=skipped_results, + snapshot_version=snapshot_version, + ) def to_repr(self) -> Dict[str, Any]: data = { diff --git a/datastore/gcloud/aio/datastore/value.py b/datastore/gcloud/aio/datastore/value.py index 08741e2fe..c32c02583 100644 --- a/datastore/gcloud/aio/datastore/value.py +++ b/datastore/gcloud/aio/datastore/value.py @@ -22,7 +22,8 @@ def __eq__(self, other: Any) -> bool: return bool( self.excludeFromIndexes == other.excludeFromIndexes - and self.value == other.value) + and self.value == other.value, + ) def __repr__(self) -> str: return str(self.to_repr()) @@ -39,8 +40,10 @@ def from_repr(cls, data: Dict[str, Any]) -> 'Value': value = base64.b64decode(data[json_key]) elif value_type == datetime: date_string = data[json_key].rstrip('Z')[:26] - date_fmt = ('%Y-%m-%dT%H:%M:%S.%f' - if '.' in date_string else '%Y-%m-%dT%H:%M:%S') + date_fmt = ( + '%Y-%m-%dT%H:%M:%S.%f' + if '.' in date_string else '%Y-%m-%dT%H:%M:%S' + ) value = datetime.strptime(date_string, date_fmt) elif hasattr(value_type, 'from_repr'): value = value_type.from_repr(data[json_key]) @@ -51,7 +54,8 @@ def from_repr(cls, data: Dict[str, Any]) -> 'Value': supported = [name.value for name in supported_types.values()] raise NotImplementedError( f'{data.keys()} does not contain a supported value type ' - f'(any of: {supported})') + f'(any of: {supported})', + ) # Google may not populate that field. This can happen with both # indexed and non-indexed fields. @@ -61,8 +65,10 @@ def from_repr(cls, data: Dict[str, Any]) -> 'Value': def to_repr(self) -> Dict[str, Any]: value_type = self._infer_type(self.value) - if value_type in {TypeName.ARRAY, TypeName.ENTITY, TypeName.GEOPOINT, - TypeName.KEY}: + if value_type in { + TypeName.ARRAY, TypeName.ENTITY, TypeName.GEOPOINT, + TypeName.KEY, + }: value = self.value.to_repr() elif value_type == TypeName.TIMESTAMP: value = self.value.strftime('%Y-%m-%dT%H:%M:%S.%f000Z') @@ -85,7 +91,8 @@ def _infer_type(self, value: Any) -> TypeName: except KeyError: raise NotImplementedError( # pylint: disable=raise-missing-from f'{kind} is not a supported value type (any of: ' - f'{supported_types})') + f'{supported_types})', + ) @classmethod def _get_supported_types(cls) -> Dict[Any, TypeName]: diff --git a/datastore/tests/integration/smoke_test.py b/datastore/tests/integration/smoke_test.py index 3fd597151..5cf7e4305 100644 --- a/datastore/tests/integration/smoke_test.py +++ b/datastore/tests/integration/smoke_test.py @@ -81,8 +81,9 @@ async def test_mutation_result(creds: str, kind: str, project: str) -> None: @pytest.mark.asyncio -async def test_insert_value_object(creds: str, kind: str, project: str - ) -> None: +async def test_insert_value_object( + creds: str, kind: str, project: str, +) -> None: key = Key(project, [PathElement(kind)]) async with Session() as s: @@ -104,10 +105,14 @@ async def test_transaction(creds: str, kind: str, project: str) -> None: assert len(actual['missing']) == 1 mutations = [ - ds.make_mutation(Operation.INSERT, key, - properties={'animal': 'three-toed sloth'}), - ds.make_mutation(Operation.UPDATE, key, - properties={'animal': 'aardvark'}), + ds.make_mutation( + Operation.INSERT, key, + properties={'animal': 'three-toed sloth'}, + ), + ds.make_mutation( + Operation.UPDATE, key, + properties={'animal': 'aardvark'}, + ), ] await ds.commit(mutations, transaction=transaction, session=s) @@ -125,19 +130,24 @@ async def test_rollback(creds: str, project: str) -> None: @pytest.mark.asyncio -async def test_query_with_key_projection(creds: str, kind: str, - project: str) -> None: +async def test_query_with_key_projection( + creds: str, kind: str, + project: str, +) -> None: async with Session() as s: ds = Datastore(project=project, service_file=creds, session=s) # setup test data await ds.insert(Key(project, [PathElement(kind)]), {'value': 30}, s) property_filter = PropertyFilter( prop='value', operator=PropertyFilterOperator.EQUAL, - value=Value(30)) + value=Value(30), + ) projection = [Projection.from_repr({'property': {'name': '__key__'}})] - query = Query(kind=kind, query_filter=Filter(property_filter), limit=1, - projection=projection) + query = Query( + kind=kind, query_filter=Filter(property_filter), limit=1, + projection=projection, + ) result = await ds.runQuery(query, session=s) assert result.entity_results[0].entity.properties == {} assert result.entity_result_type.value == 'KEY_ONLY' @@ -146,16 +156,20 @@ async def test_query_with_key_projection(creds: str, kind: str, @pytest.mark.asyncio -async def test_query_with_value_projection(creds: str, kind: str, - project: str) -> None: +async def test_query_with_value_projection( + creds: str, kind: str, + project: str, +) -> None: async with Session() as s: ds = Datastore(project=project, service_file=creds, session=s) # setup test data await ds.insert(Key(project, [PathElement(kind)]), {'value': 30}, s) projection = [Projection.from_repr({'property': {'name': 'value'}})] - query = Query(kind=kind, limit=1, - projection=projection) + query = Query( + kind=kind, limit=1, + projection=projection, + ) result = await ds.runQuery(query, session=s) assert result.entity_result_type.value == 'PROJECTION' # clean up test data @@ -163,8 +177,10 @@ async def test_query_with_value_projection(creds: str, kind: str, @pytest.mark.asyncio -async def test_query_with_distinct_on(creds: str, kind: str, - project: str) -> None: +async def test_query_with_distinct_on( + creds: str, kind: str, + project: str, +) -> None: keys1 = [Key(project, [PathElement(kind)]) for i in range(3)] keys2 = [Key(project, [PathElement(kind)]) for i in range(3)] async with Session() as s: @@ -195,7 +211,8 @@ async def test_query(creds: str, kind: str, project: str) -> None: property_filter = PropertyFilter( prop='value', operator=PropertyFilterOperator.EQUAL, - value=Value(42)) + value=Value(42), + ) query = Query(kind=kind, query_filter=Filter(property_filter)) before = await ds.runQuery(query, session=s) @@ -203,12 +220,16 @@ async def test_query(creds: str, kind: str, project: str) -> None: transaction = await ds.beginTransaction(session=s) mutations = [ - ds.make_mutation(Operation.INSERT, - Key(project, [PathElement(kind)]), - properties={'value': 42}), - ds.make_mutation(Operation.INSERT, - Key(project, [PathElement(kind)]), - properties={'value': 42}), + ds.make_mutation( + Operation.INSERT, + Key(project, [PathElement(kind)]), + properties={'value': 42}, + ), + ds.make_mutation( + Operation.INSERT, + Key(project, [PathElement(kind)]), + properties={'value': 42}, + ), ] await ds.commit(mutations, transaction=transaction, session=s) @@ -222,23 +243,31 @@ async def test_gql_query(creds: str, kind: str, project: str) -> None: async with Session() as s: ds = Datastore(project=project, service_file=creds, session=s) - query = GQLQuery(f'SELECT * FROM {kind} WHERE value = @value', - named_bindings={'value': 42}) + query = GQLQuery( + f'SELECT * FROM {kind} WHERE value = @value', + named_bindings={'value': 42}, + ) before = await ds.runQuery(query, session=s) num_results = len(before.entity_results) transaction = await ds.beginTransaction(session=s) mutations = [ - ds.make_mutation(Operation.INSERT, - Key(project, [PathElement(kind)]), - properties={'value': 42}), - ds.make_mutation(Operation.INSERT, - Key(project, [PathElement(kind)]), - properties={'value': 42}), - ds.make_mutation(Operation.INSERT, - Key(project, [PathElement(kind)]), - properties={'value': 42}), + ds.make_mutation( + Operation.INSERT, + Key(project, [PathElement(kind)]), + properties={'value': 42}, + ), + ds.make_mutation( + Operation.INSERT, + Key(project, [PathElement(kind)]), + properties={'value': 42}, + ), + ds.make_mutation( + Operation.INSERT, + Key(project, [PathElement(kind)]), + properties={'value': 42}, + ), ] await ds.commit(mutations, transaction=transaction, session=s) @@ -249,41 +278,52 @@ async def test_gql_query(creds: str, kind: str, project: str) -> None: @pytest.mark.asyncio @pytest.mark.xfail(strict=False) async def test_gql_query_pagination( - creds: str, kind: str, project: str) -> None: + creds: str, kind: str, project: str, +) -> None: async with Session() as s: - query_string = (f'SELECT __key__ FROM {kind}' - 'WHERE value = @value LIMIT @limit OFFSET @offset') + query_string = ( + f'SELECT __key__ FROM {kind}' + 'WHERE value = @value LIMIT @limit OFFSET @offset' + ) named_bindings = {'value': 42, 'limit': 2 ** 31 - 1, 'offset': 0} ds = Datastore(project=project, service_file=creds, session=s) before = await ds.runQuery( - GQLQuery(query_string, named_bindings=named_bindings), session=s) + GQLQuery(query_string, named_bindings=named_bindings), session=s, + ) insertion_count = 8 transaction = await ds.beginTransaction(session=s) - mutations = [ds.make_mutation(Operation.INSERT, - Key(project, [PathElement(kind)]), - properties=named_bindings) - ] * insertion_count + mutations = [ + ds.make_mutation( + Operation.INSERT, + Key(project, [PathElement(kind)]), + properties=named_bindings, + ), + ] * insertion_count await ds.commit(mutations, transaction=transaction, session=s) page_size = 5 named_bindings['limit'] = page_size named_bindings['offset'] = GQLCursor(before.end_cursor) first_page = await ds.runQuery( - GQLQuery(query_string, named_bindings=named_bindings), session=s) + GQLQuery(query_string, named_bindings=named_bindings), session=s, + ) assert (len(first_page.entity_results)) == page_size named_bindings['offset'] = GQLCursor(first_page.end_cursor) second_page = await ds.runQuery( - GQLQuery(query_string, named_bindings=named_bindings), session=s) + GQLQuery(query_string, named_bindings=named_bindings), session=s, + ) assert len(second_page.entity_results) == insertion_count - page_size @pytest.mark.asyncio -async def test_datastore_export(creds: str, project: str, - export_bucket_name: str): +async def test_datastore_export( + creds: str, project: str, + export_bucket_name: str, +): # N.B. when modifying this test, please also see `test_table_load_copy` in # `gcloud-aio-bigquery`. kind = 'PublicTestDatastoreExportModel' @@ -293,15 +333,19 @@ async def test_datastore_export(creds: str, project: str, async with Session() as s: ds = Datastore(project=project, service_file=creds, session=s) - await ds.insert(Key(project, [PathElement(kind)]), - properties={'rand_str': rand_uuid}) + await ds.insert( + Key(project, [PathElement(kind)]), + properties={'rand_str': rand_uuid}, + ) operation = await ds.export(export_bucket_name, kinds=[kind]) count = 0 - while (count < 10 - and operation - and operation.metadata['common']['state'] == 'PROCESSING'): + while ( + count < 10 + and operation + and operation.metadata['common']['state'] == 'PROCESSING' + ): await sleep(10) operation = await ds.get_datastore_operation(operation.name) count += 1 @@ -312,7 +356,9 @@ async def test_datastore_export(creds: str, project: str, export_path = operation.metadata['outputUrlPrefix'][prefix_len:] storage = Storage(service_file=creds, session=s) - files = await storage.list_objects(export_bucket_name, - params={'prefix': export_path}) + files = await storage.list_objects( + export_bucket_name, + params={'prefix': export_path}, + ) for file in files['items']: await storage.delete(export_bucket_name, file['name']) diff --git a/datastore/tests/unit/filter_test.py b/datastore/tests/unit/filter_test.py index 99a4adba7..fc1a33fff 100644 --- a/datastore/tests/unit/filter_test.py +++ b/datastore/tests/unit/filter_test.py @@ -17,10 +17,10 @@ def test_property_filter_from_repr(property_filters): original_filter = property_filters[0] data = { 'property': { - 'name': original_filter.prop + 'name': original_filter.prop, }, 'op': original_filter.operator, - 'value': original_filter.value.to_repr() + 'value': original_filter.value.to_repr(), } output_filter = PropertyFilter.from_repr(data) @@ -34,7 +34,8 @@ def test_property_filter_to_repr(self, property_filters): r = query_filter.to_repr() self._assert_is_correct_prop_dict_for_property_filter( - r['propertyFilter'], property_filter) + r['propertyFilter'], property_filter, + ) @staticmethod def test_composite_filter_from_repr(property_filters): @@ -44,7 +45,8 @@ def test_composite_filter_from_repr(property_filters): filters=[ Filter(property_filters[0]), Filter(property_filters[1]), - ]) + ], + ) data = { 'op': original_filter.operator, 'filters': [ @@ -78,8 +80,9 @@ def test_composite_filter_to_repr(self, property_filters): operator=CompositeFilterOperator.AND, filters=[ Filter(property_filters[0]), - Filter(property_filters[1]) - ]) + Filter(property_filters[1]), + ], + ) query_filter = Filter(composite_filter) r = query_filter.to_repr() @@ -88,17 +91,19 @@ def test_composite_filter_to_repr(self, property_filters): assert composite_filter_dict['op'] == 'AND' self._assert_is_correct_prop_dict_for_property_filter( composite_filter_dict['filters'][0]['propertyFilter'], - property_filters[0]) + property_filters[0], + ) self._assert_is_correct_prop_dict_for_property_filter( composite_filter_dict['filters'][1]['propertyFilter'], - property_filters[1]) + property_filters[1], + ) @staticmethod def test_filter_from_repr(composite_filter): original_filter = Filter(inner_filter=composite_filter) data = { - 'compositeFilter': original_filter.inner_filter.to_repr() + 'compositeFilter': original_filter.inner_filter.to_repr(), } output_filter = Filter.from_repr(data) @@ -109,7 +114,7 @@ def test_filter_from_repr(composite_filter): def test_filter_from_repr_unexpected_filter_name(): unexpected_filter_name = 'unexpectedFilterName' data = { - unexpected_filter_name: 'DoesNotMatter' + unexpected_filter_name: 'DoesNotMatter', } with pytest.raises(ValueError) as ex_info: @@ -136,13 +141,13 @@ def property_filters() -> List[PropertyFilter]: PropertyFilter( prop='prop1', operator=PropertyFilterOperator.LESS_THAN, - value=Value('value1') + value=Value('value1'), ), PropertyFilter( prop='prop2', operator=PropertyFilterOperator.GREATER_THAN, - value=Value(1234) - ) + value=Value(1234), + ), ] @staticmethod @@ -152,8 +157,9 @@ def composite_filter(property_filters) -> CompositeFilter: operator=CompositeFilterOperator.AND, filters=[ Filter(property_filters[0]), - Filter(property_filters[1]) - ]) + Filter(property_filters[1]), + ], + ) @staticmethod @pytest.fixture(scope='session') @@ -167,7 +173,8 @@ def value() -> Value: @staticmethod def _assert_is_correct_prop_dict_for_property_filter( - prop_dict: Dict[str, Any], property_filter: PropertyFilter): + prop_dict: Dict[str, Any], property_filter: PropertyFilter, + ): assert prop_dict['property']['name'] == property_filter.prop assert prop_dict['op'] == property_filter.operator.value assert prop_dict['value'] == property_filter.value.to_repr() diff --git a/datastore/tests/unit/gql_query_test.py b/datastore/tests/unit/gql_query_test.py index a1599cfa3..b1d2b6e04 100644 --- a/datastore/tests/unit/gql_query_test.py +++ b/datastore/tests/unit/gql_query_test.py @@ -16,16 +16,17 @@ def test_from_repr(query): 'namedBindings': { 'string_param': { 'value': { - 'stringValue': 'foo' - } + 'stringValue': 'foo', + }, }, 'cursor_param': { - 'cursor': 'startCursor' - } + 'cursor': 'startCursor', + }, }, 'positionalBindings': [ {'value': {'integerValue': '123'}}, - {'cursor': 'endCursor'}], + {'cursor': 'endCursor'}, + ], } output_query = GQLQuery.from_repr(data) @@ -40,16 +41,17 @@ def test_to_repr(query): 'string_param': { 'value': { 'excludeFromIndexes': False, - 'stringValue': 'foo' - } + 'stringValue': 'foo', + }, }, 'cursor_param': { - 'cursor': 'startCursor' - } + 'cursor': 'startCursor', + }, }, 'positionalBindings': [ {'value': {'excludeFromIndexes': False, 'integerValue': 123}}, - {'cursor': 'endCursor'}], + {'cursor': 'endCursor'}, + ], } output_data = query.to_repr() @@ -69,16 +71,18 @@ def test_repr_returns_to_repr_as_string(query): @staticmethod @pytest.fixture(scope='session') def query(named_bindings, positional_bindings) -> GQLQuery: - return GQLQuery('query_string', - named_bindings=named_bindings, - positional_bindings=positional_bindings) + return GQLQuery( + 'query_string', + named_bindings=named_bindings, + positional_bindings=positional_bindings, + ) @staticmethod @pytest.fixture(scope='session') def named_bindings() -> Dict[str, Any]: return { 'string_param': 'foo', - 'cursor_param': GQLCursor('startCursor') + 'cursor_param': GQLCursor('startCursor'), } @staticmethod diff --git a/datastore/tests/unit/property_order_test.py b/datastore/tests/unit/property_order_test.py index 3e39def3c..15b4c18a2 100644 --- a/datastore/tests/unit/property_order_test.py +++ b/datastore/tests/unit/property_order_test.py @@ -13,9 +13,9 @@ def test_order_from_repr(property_order): original_order = property_order data = { 'property': { - 'name': original_order.prop + 'name': original_order.prop, }, - 'direction': original_order.direction + 'direction': original_order.direction, } output_order = PropertyOrder.from_repr(data) diff --git a/datastore/tests/unit/query_test.py b/datastore/tests/unit/query_test.py index 26286c292..b189dc1d1 100644 --- a/datastore/tests/unit/query_test.py +++ b/datastore/tests/unit/query_test.py @@ -14,7 +14,7 @@ def test_from_repr(query): original_query = query data = { 'kind': original_query.kind, - 'filter': original_query.query_filter.to_repr() + 'filter': original_query.query_filter.to_repr(), } output_query = Query.from_repr(data) @@ -26,7 +26,7 @@ def test_from_repr_query_without_kind(query_filter): original_query = Query(kind='', query_filter=query_filter) data = { 'kind': [], - 'filter': original_query.query_filter.to_repr() + 'filter': original_query.query_filter.to_repr(), } output_query = Query.from_repr(data) @@ -37,7 +37,7 @@ def test_from_repr_query_without_kind(query_filter): def test_from_repr_query_with_several_orders(): orders = [ PropertyOrder('property1', direction=Direction.ASCENDING), - PropertyOrder('property2', direction=Direction.DESCENDING) + PropertyOrder('property2', direction=Direction.DESCENDING), ] original_query = Query(order=orders) @@ -46,17 +46,17 @@ def test_from_repr_query_with_several_orders(): 'order': [ { 'property': { - 'name': orders[0].prop + 'name': orders[0].prop, }, - 'direction': orders[0].direction + 'direction': orders[0].direction, }, { 'property': { - 'name': orders[1].prop + 'name': orders[1].prop, }, - 'direction': orders[1].direction - } - ] + 'direction': orders[1].direction, + }, + ], } output_query = Query.from_repr(data) @@ -94,7 +94,7 @@ def test_to_repr_query_with_filter(query_filter): def test_to_repr_query_with_several_orders(): orders = [ PropertyOrder('property1', direction=Direction.ASCENDING), - PropertyOrder('property2', direction=Direction.DESCENDING) + PropertyOrder('property2', direction=Direction.DESCENDING), ] query = Query(order=orders) @@ -124,5 +124,6 @@ def query_filter() -> Filter: inner_filter = PropertyFilter( prop='property_name', operator=PropertyFilterOperator.EQUAL, - value=Value(123)) + value=Value(123), + ) return Filter(inner_filter) diff --git a/datastore/tests/unit/value_test.py b/datastore/tests/unit/value_test.py index 815461cac..faa96df1c 100644 --- a/datastore/tests/unit/value_test.py +++ b/datastore/tests/unit/value_test.py @@ -10,20 +10,22 @@ class TestValue: @staticmethod - @pytest.mark.parametrize('json_key,json_value', [ - ('booleanValue', True), - ('doubleValue', 34.48), - ('integerValue', 8483), - ('stringValue', 'foobar'), - ('booleanValue', False), - ('doubleValue', 0.0), - ('integerValue', 0), - ('stringValue', ''), - ]) + @pytest.mark.parametrize( + 'json_key,json_value', [ + ('booleanValue', True), + ('doubleValue', 34.48), + ('integerValue', 8483), + ('stringValue', 'foobar'), + ('booleanValue', False), + ('doubleValue', 0.0), + ('integerValue', 0), + ('stringValue', ''), + ], + ) def test_from_repr(json_key, json_value): data = { 'excludeFromIndexes': False, - json_key: json_value + json_key: json_value, } value = Value.from_repr(data) @@ -35,7 +37,7 @@ def test_from_repr(json_key, json_value): def test_from_repr_with_null_value(): data = { 'excludeFromIndexes': False, - 'nullValue': 'NULL_VALUE' + 'nullValue': 'NULL_VALUE', } value = Value.from_repr(data) @@ -44,37 +46,57 @@ def test_from_repr_with_null_value(): assert value.value is None @staticmethod - @pytest.mark.parametrize('v,expected', [ - ('1998-07-12T11:22:33.456789000Z', - datetime(year=1998, month=7, day=12, hour=11, - minute=22, second=33, microsecond=456789)), - ('1998-07-12T11:22:33.456789Z', - datetime(year=1998, month=7, day=12, hour=11, - minute=22, second=33, microsecond=456789)), - ('1998-07-12T11:22:33.456Z', - datetime(year=1998, month=7, day=12, hour=11, - minute=22, second=33, microsecond=456000)), - ('1998-07-12T11:22:33', - datetime(year=1998, month=7, day=12, hour=11, - minute=22, second=33, microsecond=0)), - ]) + @pytest.mark.parametrize( + 'v,expected', [ + ( + '1998-07-12T11:22:33.456789000Z', + datetime( + year=1998, month=7, day=12, hour=11, + minute=22, second=33, microsecond=456789, + ), + ), + ( + '1998-07-12T11:22:33.456789Z', + datetime( + year=1998, month=7, day=12, hour=11, + minute=22, second=33, microsecond=456789, + ), + ), + ( + '1998-07-12T11:22:33.456Z', + datetime( + year=1998, month=7, day=12, hour=11, + minute=22, second=33, microsecond=456000, + ), + ), + ( + '1998-07-12T11:22:33', + datetime( + year=1998, month=7, day=12, hour=11, + minute=22, second=33, microsecond=0, + ), + ), + ], + ) def test_from_repr_with_datetime_value(v, expected): data = { 'excludeFromIndexes': False, - 'timestampValue': v + 'timestampValue': v, } value = Value.from_repr(data) assert value.value == expected @staticmethod - @pytest.mark.skipif(sys.version_info[0] < 3, - reason='skipping because python2 has same ' - 'type for str and bytes') + @pytest.mark.skipif( + sys.version_info[0] < 3, + reason='skipping because python2 has same ' + 'type for str and bytes', + ) def test_from_repr_with_blob_value(): data = { 'excludedFromIndexed': False, - 'blobValue': 'Zm9vYmFy' + 'blobValue': 'Zm9vYmFy', } value = Value.from_repr(data) @@ -84,7 +106,7 @@ def test_from_repr_with_blob_value(): def test_from_repr_with_key_value(key): data = { 'excludeFromIndexes': False, - 'keyValue': key.to_repr() + 'keyValue': key.to_repr(), } value = Value.from_repr(data) @@ -95,7 +117,7 @@ def test_from_repr_with_key_value(key): def test_from_repr_with_geo_point_value(lat_lng): data = { 'excludeFromIndexes': False, - 'geoPointValue': lat_lng.to_repr() + 'geoPointValue': lat_lng.to_repr(), } value = Value.from_repr(data) @@ -114,16 +136,18 @@ def test_from_repr_could_not_find_supported_value_key(): assert 'excludeFromIndexes' in ex_info.value.args[0] @staticmethod - @pytest.mark.parametrize('v,expected_json_key', [ - (True, 'booleanValue'), - (34.48, 'doubleValue'), - (8483, 'integerValue'), - ('foobar', 'stringValue'), - (False, 'booleanValue'), - (0.0, 'doubleValue'), - (0, 'integerValue'), - ('', 'stringValue'), - ]) + @pytest.mark.parametrize( + 'v,expected_json_key', [ + (True, 'booleanValue'), + (34.48, 'doubleValue'), + (8483, 'integerValue'), + ('foobar', 'stringValue'), + (False, 'booleanValue'), + (0.0, 'doubleValue'), + (0, 'integerValue'), + ('', 'stringValue'), + ], + ) def test_to_repr(v, expected_json_key): value = Value(v) @@ -143,8 +167,10 @@ def test_to_repr_with_null_value(): @staticmethod def test_to_repr_with_datetime_value(): - dt = datetime(year=2018, month=7, day=15, hour=11, minute=22, - second=33, microsecond=456789) + dt = datetime( + year=2018, month=7, day=15, hour=11, minute=22, + second=33, microsecond=456789, + ) value = Value(dt) r = value.to_repr() @@ -152,9 +178,11 @@ def test_to_repr_with_datetime_value(): assert r['timestampValue'] == '2018-07-15T11:22:33.456789000Z' @staticmethod - @pytest.mark.skipif(sys.version_info[0] < 3, - reason='skipping because python2 has same ' - 'type for str and bytes') + @pytest.mark.skipif( + sys.version_info[0] < 3, + reason='skipping because python2 has same ' + 'type for str and bytes', + ) def test_to_repr_with_blob_value(): value = Value(b'foobar') diff --git a/kms/gcloud/aio/kms/kms.py b/kms/gcloud/aio/kms/kms.py index e5ef59d33..1d4880032 100644 --- a/kms/gcloud/aio/kms/kms.py +++ b/kms/gcloud/aio/kms/kms.py @@ -51,13 +51,15 @@ def __init__( self._api_is_dev, self._api_root = init_api_root(api_root) self._api_root = ( f'{self._api_root}/projects/{keyproject}/locations/{location}/' - f'keyRings/{keyring}/cryptoKeys/{keyname}') + f'keyRings/{keyring}/cryptoKeys/{keyname}' + ) self.session = AioSession(session) self.token = token or Token( service_file=service_file, session=self.session.session, # type: ignore[arg-type] - scopes=SCOPES) + scopes=SCOPES, + ) async def headers(self) -> Dict[str, str]: if self._api_is_dev: @@ -70,8 +72,10 @@ async def headers(self) -> Dict[str, str]: } # https://cloud.google.com/kms/docs/reference/rest/v1/projects.locations.keyRings.cryptoKeys/decrypt - async def decrypt(self, ciphertext: str, - session: Optional[Session] = None) -> str: + async def decrypt( + self, ciphertext: str, + session: Optional[Session] = None, + ) -> str: url = f'{self._api_root}:decrypt' body = json.dumps({ 'ciphertext': ciphertext, @@ -79,15 +83,19 @@ async def decrypt(self, ciphertext: str, s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - resp = await s.post(url, headers=await self.headers(), - data=body) # type: ignore[arg-type] + resp = await s.post( + url, headers=await self.headers(), + data=body, # type: ignore[arg-type] + ) plaintext: str = (await resp.json())['plaintext'] return plaintext # https://cloud.google.com/kms/docs/reference/rest/v1/projects.locations.keyRings.cryptoKeys/encrypt - async def encrypt(self, plaintext: str, - session: Optional[Session] = None) -> str: + async def encrypt( + self, plaintext: str, + session: Optional[Session] = None, + ) -> str: url = f'{self._api_root}:encrypt' body = json.dumps({ 'plaintext': plaintext, @@ -95,8 +103,10 @@ async def encrypt(self, plaintext: str, s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - resp = await s.post(url, headers=await self.headers(), - data=body) # type: ignore[arg-type] + resp = await s.post( + url, headers=await self.headers(), + data=body, # type: ignore[arg-type] + ) ciphertext: str = (await resp.json())['ciphertext'] return ciphertext diff --git a/pubsub/gcloud/aio/pubsub/metrics.py b/pubsub/gcloud/aio/pubsub/metrics.py index db1ccbe31..64e18829f 100644 --- a/pubsub/gcloud/aio/pubsub/metrics.py +++ b/pubsub/gcloud/aio/pubsub/metrics.py @@ -14,15 +14,19 @@ namespace=_NAMESPACE, subsystem=_SUBSYSTEM, unit='size', - buckets=(0, 1, 5, 10, 25, 50, 100, 150, 250, 500, 1000, 1500, 2000, - 5000, float('inf'))) + buckets=( + 0, 1, 5, 10, 25, 50, 100, 150, 250, 500, 1000, 1500, 2000, + 5000, float('inf'), + ), + ) CONSUME = prometheus_client.Counter( 'subscriber_consume', 'Counter of the outcomes of PubSub message consume attempts', ['outcome'], namespace=_NAMESPACE, - subsystem=_SUBSYSTEM) + subsystem=_SUBSYSTEM, + ) CONSUME_LATENCY = prometheus_client.Histogram( 'subscriber_consume_latency', @@ -30,24 +34,28 @@ ['phase'], namespace=_NAMESPACE, subsystem=_SUBSYSTEM, - unit='seconds') + unit='seconds', + ) BATCH_STATUS = prometheus_client.Counter( 'subscriber_batch_status', 'Counter for success/failure to process PubSub message batches', ['component', 'outcome'], namespace=_NAMESPACE, - subsystem=_SUBSYSTEM) + subsystem=_SUBSYSTEM, + ) MESSAGES_PROCESSED = prometheus_client.Counter( 'subscriber_messages_processed', 'Counter of successfully acked/nacked messages', ['component'], namespace=_NAMESPACE, - subsystem=_SUBSYSTEM) + subsystem=_SUBSYSTEM, + ) MESSAGES_RECEIVED = prometheus_client.Counter( 'subscriber_messages_received', 'Counter of messages pulled from subscription', namespace=_NAMESPACE, - subsystem=_SUBSYSTEM) + subsystem=_SUBSYSTEM, + ) diff --git a/pubsub/gcloud/aio/pubsub/metrics_agent.py b/pubsub/gcloud/aio/pubsub/metrics_agent.py index 524f39e6a..447c3b181 100644 --- a/pubsub/gcloud/aio/pubsub/metrics_agent.py +++ b/pubsub/gcloud/aio/pubsub/metrics_agent.py @@ -4,12 +4,16 @@ class MetricsAgent: to be compatible with subscriber.subscribe """ - def histogram(self, - metric: str, - value: float) -> None: + def histogram( + self, + metric: str, + value: float, + ) -> None: pass - def increment(self, - metric: str, - value: float = 1) -> None: + def increment( + self, + metric: str, + value: float = 1, + ) -> None: pass diff --git a/pubsub/gcloud/aio/pubsub/publisher_client.py b/pubsub/gcloud/aio/pubsub/publisher_client.py index d9368aa97..587eca184 100644 --- a/pubsub/gcloud/aio/pubsub/publisher_client.py +++ b/pubsub/gcloud/aio/pubsub/publisher_client.py @@ -55,7 +55,8 @@ def __init__( self.session = AioSession(session, verify_ssl=not self._api_is_dev) self.token = token or Token( service_file=service_file, scopes=SCOPES, - session=self.session.session) # type: ignore[arg-type] + session=self.session.session, # type: ignore[arg-type] + ) @staticmethod def project_path(project: str) -> str: @@ -71,7 +72,7 @@ def topic_path(cls, project: str, topic: str) -> str: async def _headers(self) -> Dict[str, str]: headers = { - 'Content-Type': 'application/json' + 'Content-Type': 'application/json', } if self._api_is_dev: return headers @@ -84,26 +85,32 @@ async def _headers(self) -> Dict[str, str]: # https://github.com/googleapis/python-pubsub/blob/master/google/cloud/pubsub_v1/gapic/publisher_client.py # https://cloud.google.com/pubsub/docs/reference/rest/v1/projects.topics/list - async def list_topics(self, project: str, - query_params: Optional[Dict[str, str]] = None, *, - session: Optional[Session] = None, - timeout: int = 10) -> Dict[str, Any]: + async def list_topics( + self, project: str, + query_params: Optional[Dict[str, str]] = None, *, + session: Optional[Session] = None, + timeout: int = 10 + ) -> Dict[str, Any]: """ List topics """ url = f'{self._api_root}/{project}/topics' headers = await self._headers() s = AioSession(session) if session else self.session - resp = await s.get(url, headers=headers, params=query_params, - timeout=timeout) + resp = await s.get( + url, headers=headers, params=query_params, + timeout=timeout, + ) result: Dict[str, Any] = await resp.json() return result # https://cloud.google.com/pubsub/docs/reference/rest/v1/projects.topics/create - async def create_topic(self, topic: str, - body: Optional[Dict[str, Any]] = None, *, - session: Optional[Session] = None, - timeout: int = 10) -> Dict[str, Any]: + async def create_topic( + self, topic: str, + body: Optional[Dict[str, Any]] = None, *, + session: Optional[Session] = None, + timeout: int = 10 + ) -> Dict[str, Any]: """ Create topic. """ @@ -112,15 +119,19 @@ async def create_topic(self, topic: str, encoded = json.dumps(body or {}).encode() s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - resp = await s.put(url, data=encoded, # type: ignore[arg-type] - headers=headers, timeout=timeout) + resp = await s.put( + url, data=encoded, # type: ignore[arg-type] + headers=headers, timeout=timeout, + ) result: Dict[str, Any] = await resp.json() return result # https://cloud.google.com/pubsub/docs/reference/rest/v1/projects.topics/delete - async def delete_topic(self, topic: str, *, - session: Optional[Session] = None, - timeout: int = 10) -> None: + async def delete_topic( + self, topic: str, *, + session: Optional[Session] = None, + timeout: int = 10 + ) -> None: """ Delete topic. """ @@ -130,9 +141,11 @@ async def delete_topic(self, topic: str, *, await s.delete(url, headers=headers, timeout=timeout) # https://cloud.google.com/pubsub/docs/reference/rest/v1/projects.topics/publish - async def publish(self, topic: str, messages: List[PubsubMessage], - session: Optional[Session] = None, - timeout: int = 10) -> Dict[str, Any]: + async def publish( + self, topic: str, messages: List[PubsubMessage], + session: Optional[Session] = None, + timeout: int = 10, + ) -> Dict[str, Any]: if not messages: return {} @@ -146,8 +159,10 @@ async def publish(self, topic: str, messages: List[PubsubMessage], s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - resp = await s.post(url, data=payload, # type: ignore[arg-type] - headers=headers, timeout=timeout) + resp = await s.post( + url, data=payload, # type: ignore[arg-type] + headers=headers, timeout=timeout, + ) data: Dict[str, Any] = await resp.json() return data diff --git a/pubsub/gcloud/aio/pubsub/subscriber.py b/pubsub/gcloud/aio/pubsub/subscriber.py index b535d375f..248be8924 100644 --- a/pubsub/gcloud/aio/pubsub/subscriber.py +++ b/pubsub/gcloud/aio/pubsub/subscriber.py @@ -25,8 +25,12 @@ log = logging.getLogger(__name__) if TYPE_CHECKING: - MessageQueue = asyncio.Queue[Tuple[SubscriberMessage, # pylint: disable=unsubscriptable-object - float]] + MessageQueue = asyncio.Queue[ + Tuple[ + SubscriberMessage, # pylint: disable=unsubscriptable-object + float, + ] + ] else: MessageQueue = asyncio.Queue @@ -34,8 +38,10 @@ T = TypeVar('T') class AckDeadlineCache: - def __init__(self, subscriber_client: SubscriberClient, - subscription: str, cache_timeout: float): + def __init__( + self, subscriber_client: SubscriberClient, + subscription: str, cache_timeout: float, + ): self.subscriber_client = subscriber_client self.subscription = subscription self.cache_timeout = cache_timeout @@ -50,38 +56,48 @@ async def get(self) -> float: async def refresh(self) -> None: try: sub = await self.subscriber_client.get_subscription( - self.subscription) + self.subscription, + ) self.ack_deadline = float(sub['ackDeadlineSeconds']) except Exception as e: - log.warning('failed to refresh ackDeadlineSeconds value', - exc_info=e) + log.warning( + 'failed to refresh ackDeadlineSeconds value', + exc_info=e, + ) self.last_refresh = time.perf_counter() def cache_outdated(self) -> bool: - if (time.perf_counter() - self.last_refresh > self.cache_timeout - or self.ack_deadline == float('inf')): + if ( + time.perf_counter() - self.last_refresh > self.cache_timeout + or self.ack_deadline == float('inf') + ): return True return False - async def _budgeted_queue_get(queue: 'asyncio.Queue[T]', - time_budget: float) -> List[T]: + async def _budgeted_queue_get( + queue: 'asyncio.Queue[T]', + time_budget: float, + ) -> List[T]: result = [] while time_budget > 0: start = time.perf_counter() try: message = await asyncio.wait_for( - queue.get(), timeout=time_budget) + queue.get(), timeout=time_budget, + ) result.append(message) except asyncio.TimeoutError: break time_budget -= (time.perf_counter() - start) return result - async def acker(subscription: str, - ack_queue: 'asyncio.Queue[str]', - subscriber_client: 'SubscriberClient', - ack_window: float, - metrics_client: MetricsAgent) -> None: + async def acker( + subscription: str, + ack_queue: 'asyncio.Queue[str]', + subscriber_client: 'SubscriberClient', + ack_window: float, + metrics_client: MetricsAgent, + ) -> None: ack_ids: List[str] = [] while True: if not ack_ids: @@ -92,30 +108,39 @@ async def acker(subscription: str, # acknowledge endpoint limit is 524288 bytes # which is ~2744 ack_ids if len(ack_ids) > 2500: - log.error('acker is falling behind, dropping unacked messages', - extra={'count': len(ack_ids) - 2500}) + log.error( + 'acker is falling behind, dropping unacked messages', + extra={'count': len(ack_ids) - 2500}, + ) ack_ids = ack_ids[-2500:] for _ in range(len(ack_ids) - 2500): ack_queue.task_done() try: - await subscriber_client.acknowledge(subscription, - ack_ids=ack_ids) + await subscriber_client.acknowledge( + subscription, + ack_ids=ack_ids, + ) for _ in ack_ids: ack_queue.task_done() except aiohttp.client_exceptions.ClientResponseError as e: if e.status == 400: - log.exception('unrecoverable ack error, one or more ' - 'messages may be dropped: %s', e) + log.exception( + 'unrecoverable ack error, one or more ' + 'messages may be dropped: %s', e, + ) async def maybe_ack(ack_id: str) -> None: try: await subscriber_client.acknowledge( subscription, - ack_ids=[ack_id]) + ack_ids=[ack_id], + ) except Exception as ex: - log.warning('ack failed', extra={'ack_id': ack_id}, - exc_info=ex) + log.warning( + 'ack failed', extra={'ack_id': ack_id}, + exc_info=ex, + ) finally: ack_queue.task_done() @@ -123,37 +148,50 @@ async def maybe_ack(ack_id: str) -> None: asyncio.ensure_future(maybe_ack(ack_id)) ack_ids = [] - log.warning('ack request failed, better luck next batch', - exc_info=e) + log.warning( + 'ack request failed, better luck next batch', + exc_info=e, + ) metrics_client.increment('pubsub.acker.batch.failed') - metrics.BATCH_STATUS.labels(component='acker', - outcome='failed').inc() + metrics.BATCH_STATUS.labels( + component='acker', + outcome='failed', + ).inc() continue except asyncio.CancelledError: raise except Exception as e: - log.warning('ack request failed, better luck next batch', - exc_info=e) + log.warning( + 'ack request failed, better luck next batch', + exc_info=e, + ) metrics_client.increment('pubsub.acker.batch.failed') - metrics.BATCH_STATUS.labels(component='acker', - outcome='failed').inc() + metrics.BATCH_STATUS.labels( + component='acker', + outcome='failed', + ).inc() continue metrics_client.histogram('pubsub.acker.batch', len(ack_ids)) - metrics.BATCH_STATUS.labels(component='acker', - outcome='succeeded').inc() + metrics.BATCH_STATUS.labels( + component='acker', + outcome='succeeded', + ).inc() metrics.MESSAGES_PROCESSED.labels(component='acker').inc( - len(ack_ids)) + len(ack_ids), + ) ack_ids = [] - async def nacker(subscription: str, - nack_queue: 'asyncio.Queue[str]', - subscriber_client: 'SubscriberClient', - nack_window: float, - metrics_client: MetricsAgent) -> None: + async def nacker( + subscription: str, + nack_queue: 'asyncio.Queue[str]', + subscriber_client: 'SubscriberClient', + nack_window: float, + metrics_client: MetricsAgent, + ) -> None: ack_ids: List[str] = [] while True: if not ack_ids: @@ -164,8 +202,10 @@ async def nacker(subscription: str, # modifyAckDeadline endpoint limit is 524288 bytes # which is ~2744 ack_ids if len(ack_ids) > 2500: - log.error('nacker is falling behind, dropping unacked ' - 'messages', extra={'count': len(ack_ids) - 2500}) + log.error( + 'nacker is falling behind, dropping unacked ' + 'messages', extra={'count': len(ack_ids) - 2500}, + ) ack_ids = ack_ids[-2500:] for _ in range(len(ack_ids) - 2500): nack_queue.task_done() @@ -173,23 +213,29 @@ async def nacker(subscription: str, await subscriber_client.modify_ack_deadline( subscription, ack_ids=ack_ids, - ack_deadline_seconds=0) + ack_deadline_seconds=0, + ) for _ in ack_ids: nack_queue.task_done() except aiohttp.client_exceptions.ClientResponseError as e: if e.status == 400: - log.exception('unrecoverable nack error, one or more ' - 'messages may be dropped: %s', e) + log.exception( + 'unrecoverable nack error, one or more ' + 'messages may be dropped: %s', e, + ) async def maybe_nack(ack_id: str) -> None: try: await subscriber_client.modify_ack_deadline( subscription, ack_ids=[ack_id], - ack_deadline_seconds=0) + ack_deadline_seconds=0, + ) except Exception as ex: - log.warning('nack failed', - extra={'ack_id': ack_id}, exc_info=ex) + log.warning( + 'nack failed', + extra={'ack_id': ack_id}, exc_info=ex, + ) finally: nack_queue.task_done() @@ -197,48 +243,61 @@ async def maybe_nack(ack_id: str) -> None: asyncio.ensure_future(maybe_nack(ack_id)) ack_ids = [] - log.warning('nack request failed, better luck next batch', - exc_info=e) + log.warning( + 'nack request failed, better luck next batch', + exc_info=e, + ) metrics_client.increment('pubsub.nacker.batch.failed') metrics.BATCH_STATUS.labels( - component='nacker', outcome='failed').inc() + component='nacker', outcome='failed', + ).inc() continue except asyncio.CancelledError: raise except Exception as e: - log.warning('nack request failed, better luck next batch', - exc_info=e) + log.warning( + 'nack request failed, better luck next batch', + exc_info=e, + ) metrics_client.increment('pubsub.nacker.batch.failed') metrics.BATCH_STATUS.labels( - component='nacker', outcome='failed').inc() + component='nacker', outcome='failed', + ).inc() continue metrics_client.histogram('pubsub.nacker.batch', len(ack_ids)) - metrics.BATCH_STATUS.labels(component='nacker', - outcome='succeeded').inc() + metrics.BATCH_STATUS.labels( + component='nacker', + outcome='succeeded', + ).inc() metrics.MESSAGES_PROCESSED.labels(component='nacker').inc( - len(ack_ids)) + len(ack_ids), + ) ack_ids = [] - async def _execute_callback(message: SubscriberMessage, - callback: ApplicationHandler, - ack_queue: 'asyncio.Queue[str]', - nack_queue: 'Optional[asyncio.Queue[str]]', - insertion_time: float, - metrics_client: MetricsAgent - ) -> None: + async def _execute_callback( + message: SubscriberMessage, + callback: ApplicationHandler, + ack_queue: 'asyncio.Queue[str]', + nack_queue: 'Optional[asyncio.Queue[str]]', + insertion_time: float, + metrics_client: MetricsAgent, + ) -> None: try: start = time.perf_counter() metrics.CONSUME_LATENCY.labels(phase='queueing').observe( - start - insertion_time) + start - insertion_time, + ) with metrics.CONSUME_LATENCY.labels(phase='runtime').time(): await callback(message) await ack_queue.put(message.ack_id) - metrics_client.histogram('pubsub.consumer.latency.runtime', - time.perf_counter() - start) + metrics_client.histogram( + 'pubsub.consumer.latency.runtime', + time.perf_counter() - start, + ) metrics_client.increment('pubsub.consumer.succeeded') metrics.CONSUME.labels(outcome='succeeded').inc() @@ -262,12 +321,15 @@ async def consumer( # pylint: disable=too-many-locals ack_deadline_cache: AckDeadlineCache, max_tasks: int, nack_queue: 'Optional[asyncio.Queue[str]]', - metrics_client: MetricsAgent) -> None: + metrics_client: MetricsAgent, + ) -> None: try: semaphore = asyncio.Semaphore(max_tasks) - async def _consume_one(message: SubscriberMessage, - pulled_at: float) -> None: + async def _consume_one( + message: SubscriberMessage, + pulled_at: float, + ) -> None: await semaphore.acquire() ack_deadline = await ack_deadline_cache.get() @@ -282,18 +344,22 @@ async def _consume_one(message: SubscriberMessage, # https://cloud.google.com/pubsub/docs/reference/rest/v1/PubsubMessage recv_latency = time.time() - message.publish_time.timestamp() metrics_client.histogram( - 'pubsub.consumer.latency.receive', recv_latency) + 'pubsub.consumer.latency.receive', recv_latency, + ) metrics.CONSUME_LATENCY.labels(phase='receive').observe( - recv_latency) - - task = asyncio.ensure_future(_execute_callback( - message, - callback, - ack_queue, - nack_queue, - time.perf_counter(), - metrics_client, - )) + recv_latency, + ) + + task = asyncio.ensure_future( + _execute_callback( + message, + callback, + ack_queue, + nack_queue, + time.perf_counter(), + metrics_client, + ), + ) task.add_done_callback(lambda _f: semaphore.release()) message_queue.task_done() @@ -316,7 +382,8 @@ async def producer( message_queue: MessageQueue, subscriber_client: 'SubscriberClient', max_messages: int, - metrics_client: MetricsAgent) -> None: + metrics_client: MetricsAgent, + ) -> None: try: while True: new_messages = [] @@ -330,13 +397,16 @@ async def producer( # hanging on a server which will cause delay in # message delivery or even false deadlettering if # it is enabled - timeout=30)) + timeout=30, + ), + ) new_messages = await asyncio.shield(pull_task) except (asyncio.TimeoutError, KeyError): continue metrics_client.histogram( - 'pubsub.producer.batch', len(new_messages)) + 'pubsub.producer.batch', len(new_messages), + ) metrics.MESSAGES_RECEIVED.inc(len(new_messages)) metrics.BATCH_SIZE.observe(len(new_messages)) @@ -365,66 +435,91 @@ async def producer( log.info('producer terminated gracefully') raise - async def subscribe(subscription: str, # pylint: disable=too-many-locals - handler: ApplicationHandler, - subscriber_client: SubscriberClient, - *, - num_producers: int = 1, - max_messages_per_producer: int = 100, - ack_window: float = 0.3, - ack_deadline_cache_timeout: float = float('inf'), - num_tasks_per_consumer: int = 1, - enable_nack: bool = True, - nack_window: float = 0.3, - metrics_client: Optional[MetricsAgent] = None - ) -> None: + async def subscribe( + subscription: str, + handler: ApplicationHandler, + subscriber_client: SubscriberClient, + *, + num_producers: int = 1, + max_messages_per_producer: int = 100, + ack_window: float = 0.3, + ack_deadline_cache_timeout: float = float('inf'), + num_tasks_per_consumer: int = 1, + enable_nack: bool = True, + nack_window: float = 0.3, + metrics_client: Optional[MetricsAgent] = None + ) -> None: + # pylint: disable=too-many-locals ack_queue: 'asyncio.Queue[str]' = asyncio.Queue( - maxsize=(max_messages_per_producer * num_producers)) + maxsize=(max_messages_per_producer * num_producers), + ) nack_queue: 'Optional[asyncio.Queue[str]]' = None - ack_deadline_cache = AckDeadlineCache(subscriber_client, - subscription, - ack_deadline_cache_timeout) + ack_deadline_cache = AckDeadlineCache( + subscriber_client, + subscription, + ack_deadline_cache_timeout, + ) if metrics_client is not None: - warnings.warn('Using MetricsAgent in subscribe() is deprecated. ' - 'Refer to Prometheus metrics instead.', - DeprecationWarning) + warnings.warn( + 'Using MetricsAgent in subscribe() is deprecated. ' + 'Refer to Prometheus metrics instead.', + DeprecationWarning, + ) metrics_client = metrics_client or MetricsAgent() acker_tasks = [] consumer_tasks = [] producer_tasks = [] try: - acker_tasks.append(asyncio.ensure_future( - acker(subscription, ack_queue, subscriber_client, - ack_window=ack_window, metrics_client=metrics_client) - )) + acker_tasks.append( + asyncio.ensure_future( + acker( + subscription, ack_queue, subscriber_client, + ack_window=ack_window, metrics_client=metrics_client, + ), + ), + ) if enable_nack: nack_queue = asyncio.Queue( - maxsize=(max_messages_per_producer * num_producers)) - acker_tasks.append(asyncio.ensure_future( - nacker(subscription, nack_queue, subscriber_client, - nack_window=nack_window, - metrics_client=metrics_client) - )) + maxsize=(max_messages_per_producer * num_producers), + ) + acker_tasks.append( + asyncio.ensure_future( + nacker( + subscription, nack_queue, subscriber_client, + nack_window=nack_window, + metrics_client=metrics_client, + ), + ), + ) for _ in range(num_producers): q: MessageQueue = asyncio.Queue( - maxsize=max_messages_per_producer) - consumer_tasks.append(asyncio.ensure_future( - consumer(q, - handler, - ack_queue, - ack_deadline_cache, - num_tasks_per_consumer, - nack_queue, - metrics_client=metrics_client) - )) - producer_tasks.append(asyncio.ensure_future( - producer(subscription, - q, - subscriber_client, - max_messages=max_messages_per_producer, - metrics_client=metrics_client) - )) + maxsize=max_messages_per_producer, + ) + consumer_tasks.append( + asyncio.ensure_future( + consumer( + q, + handler, + ack_queue, + ack_deadline_cache, + num_tasks_per_consumer, + nack_queue, + metrics_client=metrics_client, + ), + ), + ) + producer_tasks.append( + asyncio.ensure_future( + producer( + subscription, + q, + subscriber_client, + max_messages=max_messages_per_producer, + metrics_client=metrics_client, + ), + ), + ) # TODO: since this is in a `not BUILD_GCLOUD_REST` section, we # shouldn't have to care about py2 support. Using splat syntax @@ -432,8 +527,10 @@ async def subscribe(subscription: str, # pylint: disable=too-many-locals # though it would never be loaded at runtime in py2. # all_tasks = [*producer_tasks, *consumer_tasks, *acker_tasks] all_tasks = producer_tasks + consumer_tasks + acker_tasks - done, _ = await asyncio.wait(all_tasks, - return_when=asyncio.FIRST_COMPLETED) + done, _ = await asyncio.wait( + all_tasks, + return_when=asyncio.FIRST_COMPLETED, + ) for task in done: task.result() raise Exception('a subscriber worker shut down unexpectedly') @@ -441,13 +538,17 @@ async def subscribe(subscription: str, # pylint: disable=too-many-locals log.warning('subscriber exited', exc_info=e) for task in producer_tasks: task.cancel() - await asyncio.wait(producer_tasks, - return_when=asyncio.ALL_COMPLETED) + await asyncio.wait( + producer_tasks, + return_when=asyncio.ALL_COMPLETED, + ) for task in consumer_tasks: task.cancel() - await asyncio.wait(consumer_tasks, - return_when=asyncio.ALL_COMPLETED) + await asyncio.wait( + consumer_tasks, + return_when=asyncio.ALL_COMPLETED, + ) for task in acker_tasks: task.cancel() diff --git a/pubsub/gcloud/aio/pubsub/subscriber_client.py b/pubsub/gcloud/aio/pubsub/subscriber_client.py index 51a1a0e3c..e29129af2 100644 --- a/pubsub/gcloud/aio/pubsub/subscriber_client.py +++ b/pubsub/gcloud/aio/pubsub/subscriber_client.py @@ -22,7 +22,7 @@ from aiohttp import ClientSession as Session # type: ignore[assignment] SCOPES = [ - 'https://www.googleapis.com/auth/pubsub' + 'https://www.googleapis.com/auth/pubsub', ] @@ -51,11 +51,12 @@ def __init__( self.session = AioSession(session, verify_ssl=not self._api_is_dev) self.token = token or Token( service_file=service_file, scopes=SCOPES, - session=self.session.session) # type: ignore[arg-type] + session=self.session.session, # type: ignore[arg-type] + ) async def _headers(self) -> Dict[str, str]: headers = { - 'Content-Type': 'application/json' + 'Content-Type': 'application/json', } if self._api_is_dev: return headers @@ -65,13 +66,15 @@ async def _headers(self) -> Dict[str, str]: return headers # https://cloud.google.com/pubsub/docs/reference/rest/v1/projects.subscriptions/create - async def create_subscription(self, - subscription: str, - topic: str, - body: Optional[Dict[str, Any]] = None, - *, - session: Optional[Session] = None, - timeout: int = 10) -> Dict[str, Any]: + async def create_subscription( + self, + subscription: str, + topic: str, + body: Optional[Dict[str, Any]] = None, + *, + session: Optional[Session] = None, + timeout: int = 10 + ) -> Dict[str, Any]: """ Create subscription. """ @@ -83,15 +86,19 @@ async def create_subscription(self, encoded = json.dumps(payload).encode() s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - resp = await s.put(url, data=encoded, # type: ignore[arg-type] - headers=headers, timeout=timeout) + resp = await s.put( + url, data=encoded, # type: ignore[arg-type] + headers=headers, timeout=timeout, + ) result: Dict[str, Any] = await resp.json() return result # https://cloud.google.com/pubsub/docs/reference/rest/v1/projects.subscriptions/delete - async def delete_subscription(self, subscription: str, *, - session: Optional[Session] = None, - timeout: int = 10) -> None: + async def delete_subscription( + self, subscription: str, *, + session: Optional[Session] = None, + timeout: int = 10 + ) -> None: """ Delete subscription. """ @@ -101,9 +108,11 @@ async def delete_subscription(self, subscription: str, *, await s.delete(url, headers=headers, timeout=timeout) # https://cloud.google.com/pubsub/docs/reference/rest/v1/projects.subscriptions/pull - async def pull(self, subscription: str, max_messages: int, - *, session: Optional[Session] = None, - timeout: int = 30) -> List[SubscriberMessage]: + async def pull( + self, subscription: str, max_messages: int, + *, session: Optional[Session] = None, + timeout: int = 30 + ) -> List[SubscriberMessage]: """ Pull messages from subscription """ @@ -115,8 +124,10 @@ async def pull(self, subscription: str, max_messages: int, encoded = json.dumps(payload).encode() s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - resp = await s.post(url, data=encoded, # type: ignore[arg-type] - headers=headers, timeout=timeout) + resp = await s.post( + url, data=encoded, # type: ignore[arg-type] + headers=headers, timeout=timeout, + ) data = await resp.json() return [ SubscriberMessage.from_repr(m) @@ -124,9 +135,11 @@ async def pull(self, subscription: str, max_messages: int, ] # https://cloud.google.com/pubsub/docs/reference/rest/v1/projects.subscriptions/acknowledge - async def acknowledge(self, subscription: str, ack_ids: List[str], - *, session: Optional[Session] = None, - timeout: int = 10) -> None: + async def acknowledge( + self, subscription: str, ack_ids: List[str], + *, session: Optional[Session] = None, + timeout: int = 10 + ) -> None: """ Acknowledge messages by ackIds """ @@ -138,15 +151,19 @@ async def acknowledge(self, subscription: str, ack_ids: List[str], encoded = json.dumps(payload).encode() s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - await s.post(url, data=encoded, # type: ignore[arg-type] - headers=headers, timeout=timeout) + await s.post( + url, data=encoded, # type: ignore[arg-type] + headers=headers, timeout=timeout, + ) # https://cloud.google.com/pubsub/docs/reference/rest/v1/projects.subscriptions/modifyAckDeadline - async def modify_ack_deadline(self, subscription: str, - ack_ids: List[str], - ack_deadline_seconds: int, - *, session: Optional[Session] = None, - timeout: int = 10) -> None: + async def modify_ack_deadline( + self, subscription: str, + ack_ids: List[str], + ack_deadline_seconds: int, + *, session: Optional[Session] = None, + timeout: int = 10 + ) -> None: """ Modify messages' ack deadline. Set ack deadline to 0 to nack messages. @@ -159,13 +176,17 @@ async def modify_ack_deadline(self, subscription: str, }).encode('utf-8') s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - await s.post(url, data=data, # type: ignore[arg-type] - headers=headers, timeout=timeout) + await s.post( + url, data=data, # type: ignore[arg-type] + headers=headers, timeout=timeout, + ) # https://cloud.google.com/pubsub/docs/reference/rest/v1/projects.subscriptions/get - async def get_subscription(self, subscription: str, - *, session: Optional[Session] = None, - timeout: int = 10) -> Dict[str, Any]: + async def get_subscription( + self, subscription: str, + *, session: Optional[Session] = None, + timeout: int = 10 + ) -> Dict[str, Any]: """ Get Subscription """ @@ -177,10 +198,12 @@ async def get_subscription(self, subscription: str, return result # https://cloud.google.com/pubsub/docs/reference/rest/v1/projects.subscriptions/list - async def list_subscriptions(self, project: str, - query_params: Optional[Dict[str, str]] = None, - *, session: Optional[Session] = None, - timeout: int = 10) -> Dict[str, Any]: + async def list_subscriptions( + self, project: str, + query_params: Optional[Dict[str, str]] = None, + *, session: Optional[Session] = None, + timeout: int = 10 + ) -> Dict[str, Any]: """ List subscriptions """ @@ -196,7 +219,7 @@ async def list_subscriptions(self, project: str, url, headers=headers, params=next_query_params, - timeout=timeout + timeout=timeout, ) page: Dict[str, Any] = await resp.json() all_results['subscriptions'] += page['subscriptions'] diff --git a/pubsub/gcloud/aio/pubsub/subscriber_message.py b/pubsub/gcloud/aio/pubsub/subscriber_message.py index 805e3a469..908965e78 100644 --- a/pubsub/gcloud/aio/pubsub/subscriber_message.py +++ b/pubsub/gcloud/aio/pubsub/subscriber_message.py @@ -8,18 +8,22 @@ def parse_publish_time(publish_time: str) -> datetime.datetime: try: return datetime.datetime.strptime( - publish_time, '%Y-%m-%dT%H:%M:%S.%fZ') + publish_time, '%Y-%m-%dT%H:%M:%S.%fZ', + ) except ValueError: return datetime.datetime.strptime( - publish_time, '%Y-%m-%dT%H:%M:%SZ') + publish_time, '%Y-%m-%dT%H:%M:%SZ', + ) class SubscriberMessage: - def __init__(self, ack_id: str, message_id: str, - publish_time: 'datetime.datetime', - data: Optional[bytes], - attributes: Optional[Dict[str, Any]], - delivery_attempt: Optional[int] = None): + def __init__( + self, ack_id: str, message_id: str, + publish_time: 'datetime.datetime', + data: Optional[bytes], + attributes: Optional[Dict[str, Any]], + delivery_attempt: Optional[int] = None, + ): self.ack_id = ack_id self.message_id = message_id self.publish_time = publish_time @@ -28,28 +32,34 @@ def __init__(self, ack_id: str, message_id: str, self.delivery_attempt = delivery_attempt @staticmethod - def from_repr(received_message: Dict[str, Any] - ) -> 'SubscriberMessage': + def from_repr( + received_message: Dict[str, Any], + ) -> 'SubscriberMessage': ack_id = received_message['ackId'] message_id = received_message['message']['messageId'] raw_data = received_message['message'].get('data') data = base64.b64decode(raw_data) if raw_data is not None else None attributes = received_message['message'].get('attributes') publish_time: datetime.datetime = parse_publish_time( - received_message['message']['publishTime']) + received_message['message']['publishTime'], + ) delivery_attempt = received_message.get('deliveryAttempt') - return SubscriberMessage(ack_id=ack_id, message_id=message_id, - publish_time=publish_time, data=data, - attributes=attributes, - delivery_attempt=delivery_attempt) + return SubscriberMessage( + ack_id=ack_id, message_id=message_id, + publish_time=publish_time, data=data, + attributes=attributes, + delivery_attempt=delivery_attempt, + ) def to_repr(self) -> Dict[str, Any]: r: Dict[str, Any] = { 'ackId': self.ack_id, 'message': { 'messageId': self.message_id, - 'publishTime': self.publish_time.strftime('%Y-%m-%dT%H:%M:%SZ') - } + 'publishTime': ( + self.publish_time.strftime('%Y-%m-%dT%H:%M:%SZ') + ), + }, } if self.attributes is not None: r['message']['attributes'] = self.attributes diff --git a/pubsub/gcloud/aio/pubsub/utils.py b/pubsub/gcloud/aio/pubsub/utils.py index 53372a322..94cfb59dc 100644 --- a/pubsub/gcloud/aio/pubsub/utils.py +++ b/pubsub/gcloud/aio/pubsub/utils.py @@ -7,8 +7,10 @@ # https://cloud.google.com/pubsub/docs/reference/rest/v1/PubsubMessage class PubsubMessage: - def __init__(self, data: Union[bytes, str], ordering_key: str = '', - **kwargs: Any) -> None: + def __init__( + self, data: Union[bytes, str], ordering_key: str = '', + **kwargs: Any + ) -> None: self.data = data self.attributes = kwargs self.ordering_key = ordering_key diff --git a/pubsub/tests/unit/subscriber_test.py b/pubsub/tests/unit/subscriber_test.py index 292fe988e..707cde48d 100644 --- a/pubsub/tests/unit/subscriber_test.py +++ b/pubsub/tests/unit/subscriber_test.py @@ -77,7 +77,8 @@ def application_callback(): @pytest.mark.asyncio async def test_ack_deadline_cache_defaults(subscriber_client): cache = AckDeadlineCache( - subscriber_client, 'fake_subscription', 1) + subscriber_client, 'fake_subscription', 1, + ) assert cache.cache_timeout == 1 assert cache.ack_deadline == float('inf') assert cache.last_refresh == float('-inf') @@ -85,7 +86,8 @@ async def test_ack_deadline_cache_defaults(subscriber_client): @pytest.mark.asyncio async def test_ack_deadline_cache_cache_outdated_false(subscriber_client): cache = AckDeadlineCache( - subscriber_client, 'fake_subscription', 1000) + subscriber_client, 'fake_subscription', 1000, + ) cache.ack_deadline = 10 cache.last_refresh = time.perf_counter() assert not cache.cache_outdated() @@ -93,38 +95,44 @@ async def test_ack_deadline_cache_cache_outdated_false(subscriber_client): @pytest.mark.asyncio async def test_ack_deadline_cache_cache_outdated_true(subscriber_client): cache = AckDeadlineCache( - subscriber_client, 'fake_subscription', 0) + subscriber_client, 'fake_subscription', 0, + ) cache.last_refresh = time.perf_counter() assert cache.cache_outdated() cache = AckDeadlineCache( - subscriber_client, 'fake_subscription', 1000) + subscriber_client, 'fake_subscription', 1000, + ) cache.last_refresh = time.perf_counter() cache.ack_deadline = float('inf') assert cache.cache_outdated() @pytest.mark.asyncio async def test_ack_deadline_cache_refresh_updates_value_and_last_refresh( - subscriber_client + subscriber_client, ): cache = AckDeadlineCache( - subscriber_client, 'fake_subscription', 1) + subscriber_client, 'fake_subscription', 1, + ) await cache.refresh() assert cache.ack_deadline == 42 assert cache.last_refresh subscriber_client.get_subscription.assert_called_once_with( - 'fake_subscription') + 'fake_subscription', + ) @pytest.mark.asyncio async def test_ack_deadline_cache_refresh_is_cool_about_failures( - subscriber_client + subscriber_client, ): f = asyncio.Future() f.set_exception(RuntimeError) subscriber_client.get_subscription = MagicMock( - return_value=f) + return_value=f, + ) cache = AckDeadlineCache( - subscriber_client, 'fake_subscription', 1) + subscriber_client, 'fake_subscription', 1, + ) cache.ack_deadline = 55.0 await cache.refresh() assert cache.ack_deadline == 55.0 @@ -132,19 +140,21 @@ async def test_ack_deadline_cache_refresh_is_cool_about_failures( @pytest.mark.asyncio async def test_ack_deadline_cache_get_calls_refresh_first_time( - subscriber_client + subscriber_client, ): cache = AckDeadlineCache( - subscriber_client, 'fake_subscription', 1) + subscriber_client, 'fake_subscription', 1, + ) assert await cache.get() == 42 assert cache.last_refresh @pytest.mark.asyncio async def test_ack_deadline_cache_get_no_call_if_not_outdated( - subscriber_client + subscriber_client, ): cache = AckDeadlineCache( - subscriber_client, 'fake_subscription', 1000) + subscriber_client, 'fake_subscription', 1000, + ) cache.ack_deadline = 33 cache.last_refresh = time.perf_counter() assert await cache.get() == 33 @@ -152,19 +162,22 @@ async def test_ack_deadline_cache_get_no_call_if_not_outdated( @pytest.mark.asyncio async def test_ack_deadline_cache_get_call_first_time( - subscriber_client + subscriber_client, ): cache = AckDeadlineCache( - subscriber_client, 'fake_subscription', 1000) + subscriber_client, 'fake_subscription', 1000, + ) cache.last_refresh = time.perf_counter() assert await cache.get() == 42 subscriber_client.get_subscription.assert_called() @pytest.mark.asyncio async def test_ack_deadline_cache_get_refreshes_if_outdated( - subscriber_client): + subscriber_client, + ): cache = AckDeadlineCache( - subscriber_client, 'fake_subscription', 0) + subscriber_client, 'fake_subscription', 0, + ) cache.ack_deadline = 33 assert await cache.get() == 42 assert cache.last_refresh @@ -175,9 +188,11 @@ async def test_ack_deadline_cache_first_get_failed(subscriber_client): f = asyncio.Future() f.set_exception(RuntimeError) subscriber_client.get_subscription = MagicMock( - return_value=f) + return_value=f, + ) cache = AckDeadlineCache( - subscriber_client, 'fake_subscription', 10) + subscriber_client, 'fake_subscription', 10, + ) assert await cache.get() == float('inf') assert cache.last_refresh subscriber_client.get_subscription.assert_called_once() @@ -195,8 +210,8 @@ async def test_producer_fetches_messages(subscriber_client): queue, subscriber_client, max_messages=1, - metrics_client=MagicMock() - ) + metrics_client=MagicMock(), + ), ) message, pulled_at = await asyncio.wait_for(queue.get(), 0.1) producer_task.cancel() @@ -220,8 +235,8 @@ async def f(*args, **kwargs): queue, subscriber_client, max_messages=1, - metrics_client=MagicMock() - ) + metrics_client=MagicMock(), + ), ) await asyncio.sleep(0) await asyncio.sleep(0) @@ -248,8 +263,8 @@ async def f(*args, **kwargs): queue, subscriber_client, max_messages=1, - metrics_client=MagicMock() - ) + metrics_client=MagicMock(), + ), ) await asyncio.sleep(0) await asyncio.sleep(0) @@ -276,8 +291,8 @@ async def f(*args, **kwargs): queue, subscriber_client, max_messages=1, - metrics_client=MagicMock() - ) + metrics_client=MagicMock(), + ), ) await asyncio.sleep(0) await asyncio.sleep(0) @@ -291,8 +306,10 @@ async def f(*args, **kwargs): @pytest.mark.asyncio async def test_producer_gracefully_shutsdown(subscriber_client): - with patch('time.perf_counter', - side_effect=(asyncio.CancelledError, 1)): + with patch( + 'time.perf_counter', + side_effect=(asyncio.CancelledError, 1), + ): queue = asyncio.Queue() producer_task = asyncio.ensure_future( producer( @@ -300,8 +317,8 @@ async def test_producer_gracefully_shutsdown(subscriber_client): queue, subscriber_client, max_messages=1, - metrics_client=MagicMock() - ) + metrics_client=MagicMock(), + ), ) await asyncio.sleep(0) await asyncio.sleep(0) @@ -316,7 +333,8 @@ async def test_producer_gracefully_shutsdown(subscriber_client): @pytest.mark.asyncio async def test_producer_fetches_once_then_waits_for_consumer( - subscriber_client): + subscriber_client, + ): queue = asyncio.Queue() producer_task = asyncio.ensure_future( producer( @@ -324,8 +342,8 @@ async def test_producer_fetches_once_then_waits_for_consumer( queue, subscriber_client, max_messages=1, - metrics_client=MagicMock() - ) + metrics_client=MagicMock(), + ), ) await asyncio.sleep(0) await asyncio.wait_for(queue.get(), 1.0) @@ -339,16 +357,21 @@ async def test_producer_fetches_once_then_waits_for_consumer( # ======== @pytest.mark.asyncio - async def test_consumer_calls_none_means_ack(ack_deadline_cache, - message, - application_callback): + async def test_consumer_calls_none_means_ack( + ack_deadline_cache, + message, + application_callback, + ): queue = asyncio.Queue() ack_queue = asyncio.Queue() nack_queue = asyncio.Queue() consumer_task = asyncio.ensure_future( - consumer(queue, application_callback, ack_queue, - ack_deadline_cache, 1, nack_queue, MagicMock())) + consumer( + queue, application_callback, ack_queue, + ack_deadline_cache, 1, nack_queue, MagicMock(), + ), + ) await queue.put((message, 0.0)) await asyncio.sleep(0) @@ -382,8 +405,11 @@ async def callback(mock): mock4 = make_message_mock() consumer_task = asyncio.ensure_future( - consumer(queue, callback, ack_queue, ack_deadline_cache, 2, None, - MagicMock())) + consumer( + queue, callback, ack_queue, ack_deadline_cache, 2, None, + MagicMock(), + ), + ) for m in [mock1, mock2, mock3, mock4]: await queue.put((m, 0.0)) @@ -409,9 +435,11 @@ async def callback(mock): await asyncio.wait_for(consumer_task, 1) @pytest.mark.asyncio - async def test_consumer_drops_expired_messages(ack_deadline_cache, - message, - application_callback): + async def test_consumer_drops_expired_messages( + ack_deadline_cache, + message, + application_callback, + ): f = asyncio.Future() f.set_result(0.0) ack_deadline_cache.get = MagicMock(return_value=f) @@ -420,8 +448,11 @@ async def test_consumer_drops_expired_messages(ack_deadline_cache, ack_queue = asyncio.Queue() nack_queue = asyncio.Queue() consumer_task = asyncio.ensure_future( - consumer(queue, application_callback, ack_queue, - ack_deadline_cache, 1, nack_queue, MagicMock())) + consumer( + queue, application_callback, ack_queue, + ack_deadline_cache, 1, nack_queue, MagicMock(), + ), + ) await queue.put((message, 0.0)) await asyncio.sleep(0) @@ -434,7 +465,7 @@ async def test_consumer_drops_expired_messages(ack_deadline_cache, @pytest.mark.asyncio async def test_consumer_handles_callback_exception_no_nack( - ack_deadline_cache, message + ack_deadline_cache, message, ): queue = asyncio.Queue() ack_queue = asyncio.Queue() @@ -445,8 +476,11 @@ async def f(*args): raise RuntimeError consumer_task = asyncio.ensure_future( - consumer(queue, f, ack_queue, ack_deadline_cache, 1, None, - MagicMock())) + consumer( + queue, f, ack_queue, ack_deadline_cache, 1, None, + MagicMock(), + ), + ) await queue.put((message, 0.0)) await asyncio.sleep(0.1) consumer_task.cancel() @@ -457,7 +491,7 @@ async def f(*args): @pytest.mark.asyncio async def test_consumer_handles_callback_exception_nack( - ack_deadline_cache, message + ack_deadline_cache, message, ): queue = asyncio.Queue() ack_queue = asyncio.Queue() @@ -469,8 +503,11 @@ async def f(*args): raise RuntimeError consumer_task = asyncio.ensure_future( - consumer(queue, f, ack_queue, ack_deadline_cache, 1, nack_queue, - MagicMock())) + consumer( + queue, f, ack_queue, ack_deadline_cache, 1, nack_queue, + MagicMock(), + ), + ) await queue.put((message, 0.0)) await asyncio.sleep(0.1) @@ -490,7 +527,7 @@ async def f(*args): @pytest.mark.asyncio async def test_consumer_gracefull_shutdown( - ack_deadline_cache, message + ack_deadline_cache, message, ): queue = asyncio.Queue() ack_queue = asyncio.Queue() @@ -510,8 +547,8 @@ async def f(*args): ack_deadline_cache, 1, nack_queue, - MagicMock() - ) + MagicMock(), + ), ) await queue.put((message, 0.0)) await asyncio.sleep(0.1) @@ -529,7 +566,7 @@ async def f(*args): @pytest.mark.asyncio async def test_consumer_gracefull_shutdown_without_pending_tasks( - ack_deadline_cache + ack_deadline_cache, ): queue = asyncio.Queue() ack_queue = asyncio.Queue() @@ -543,8 +580,8 @@ async def test_consumer_gracefull_shutdown_without_pending_tasks( ack_deadline_cache, 1, nack_queue, - MagicMock() - ) + MagicMock(), + ), ) await asyncio.sleep(0.1) consumer_task.cancel() @@ -564,13 +601,14 @@ async def test_acker_does_ack(subscriber_client): queue, subscriber_client, 0.0, - MagicMock() - ) + MagicMock(), + ), ) await queue.put('ack_id') await queue.join() subscriber_client.acknowledge.assert_called_once_with( - 'fake_subscription', ack_ids=['ack_id']) + 'fake_subscription', ack_ids=['ack_id'], + ) assert queue.qsize() == 0 acker_task.cancel() @@ -591,8 +629,8 @@ async def f(*args, **kwargs): queue, subscriber_client, 0.0, - MagicMock() - ) + MagicMock(), + ), ) await queue.put('ack_id') await asyncio.sleep(0) @@ -611,15 +649,16 @@ async def test_acker_does_batching(subscriber_client): queue, subscriber_client, 0.1, - MagicMock() - ) + MagicMock(), + ), ) await queue.put('ack_id_1') await queue.put('ack_id_2') await asyncio.sleep(0.2) acker_task.cancel() subscriber_client.acknowledge.assert_called_once_with( - 'fake_subscription', ack_ids=['ack_id_1', 'ack_id_2']) + 'fake_subscription', ack_ids=['ack_id_1', 'ack_id_2'], + ) assert queue.qsize() == 0 @pytest.mark.asyncio @@ -639,8 +678,8 @@ async def f(*args, **kwargs): queue, subscriber_client, 0.1, - MagicMock() - ) + MagicMock(), + ), ) await queue.put('ack_id_1') await queue.put('ack_id_2') @@ -651,21 +690,26 @@ async def f(*args, **kwargs): [ call('fake_subscription', ack_ids=['ack_id_1', 'ack_id_2']), call('fake_subscription', ack_ids=['ack_id_1', 'ack_id_2']), - ] + ], ) @pytest.mark.asyncio - async def test_acker_batches_not_retried_on_400(caplog, - subscriber_client): - caplog.set_level(logging.WARNING, - logger='gcloud.aio.pubsub.subscriber') + async def test_acker_batches_not_retried_on_400( + caplog, + subscriber_client, + ): + caplog.set_level( + logging.WARNING, + logger='gcloud.aio.pubsub.subscriber', + ) mock = MagicMock() async def f(*args, **kwargs): await asyncio.sleep(0) mock(*args, **kwargs) raise aiohttp.client_exceptions.ClientResponseError( - MagicMock(), None, status=400) + MagicMock(), None, status=400, + ) subscriber_client.acknowledge = f queue = asyncio.Queue() @@ -675,8 +719,8 @@ async def f(*args, **kwargs): queue, subscriber_client, 0.1, - MagicMock() - ) + MagicMock(), + ), ) await queue.put('ack_id_1') await queue.put('ack_id_2') @@ -688,12 +732,14 @@ async def f(*args, **kwargs): call('fake_subscription', ack_ids=['ack_id_1', 'ack_id_2']), call('fake_subscription', ack_ids=['ack_id_1']), call('fake_subscription', ack_ids=['ack_id_2']), - ] + ], + ) + ack_fails = sum( + 1 for (logger, level, message) in caplog.record_tuples + if logger == 'gcloud.aio.pubsub.subscriber' + and level == logging.WARNING + and message == 'ack failed' ) - ack_fails = sum(1 for (logger, level, message) in caplog.record_tuples - if logger == 'gcloud.aio.pubsub.subscriber' - and level == logging.WARNING - and message == 'ack failed') assert ack_fails == 2 # ======== @@ -709,13 +755,14 @@ async def test_nacker_does_modify_ack_deadline(subscriber_client): queue, subscriber_client, 0.0, - MagicMock() - ) + MagicMock(), + ), ) await queue.put('ack_id') await queue.join() subscriber_client.modify_ack_deadline.assert_called_once_with( - 'fake_subscription', ack_ids=['ack_id'], ack_deadline_seconds=0) + 'fake_subscription', ack_ids=['ack_id'], ack_deadline_seconds=0, + ) assert queue.qsize() == 0 nacker_task.cancel() @@ -736,8 +783,8 @@ async def f(*args, **kwargs): queue, subscriber_client, 0.0, - MagicMock() - ) + MagicMock(), + ), ) await queue.put('ack_id') await asyncio.sleep(0) @@ -756,8 +803,8 @@ async def test_nacker_does_batching(subscriber_client): queue, subscriber_client, 0.1, - MagicMock() - ) + MagicMock(), + ), ) await queue.put('ack_id_1') await queue.put('ack_id_2') @@ -766,7 +813,8 @@ async def test_nacker_does_batching(subscriber_client): subscriber_client.modify_ack_deadline.assert_called_once_with( 'fake_subscription', ack_ids=['ack_id_1', 'ack_id_2'], - ack_deadline_seconds=0) + ack_deadline_seconds=0, + ) assert queue.qsize() == 0 @pytest.mark.asyncio @@ -786,8 +834,8 @@ async def f(*args, **kwargs): queue, subscriber_client, 0.1, - MagicMock() - ) + MagicMock(), + ), ) await queue.put('ack_id_1') await queue.put('ack_id_2') @@ -796,18 +844,26 @@ async def f(*args, **kwargs): assert queue.qsize() == 0 mock.assert_has_calls( [ - call('fake_subscription', - ack_ids=['ack_id_1', 'ack_id_2'], ack_deadline_seconds=0), - call('fake_subscription', - ack_ids=['ack_id_1', 'ack_id_2'], ack_deadline_seconds=0), - ] + call( + 'fake_subscription', + ack_ids=['ack_id_1', 'ack_id_2'], ack_deadline_seconds=0, + ), + call( + 'fake_subscription', + ack_ids=['ack_id_1', 'ack_id_2'], ack_deadline_seconds=0, + ), + ], ) @pytest.mark.asyncio - async def test_nacker_batches_not_retried_on_400(caplog, - subscriber_client): - caplog.set_level(logging.WARNING, - logger='gcloud.aio.pubsub.subscriber') + async def test_nacker_batches_not_retried_on_400( + caplog, + subscriber_client, + ): + caplog.set_level( + logging.WARNING, + logger='gcloud.aio.pubsub.subscriber', + ) mock = MagicMock() @@ -815,7 +871,8 @@ async def f(*args, **kwargs): await asyncio.sleep(0) mock(*args, **kwargs) raise aiohttp.client_exceptions.ClientResponseError( - MagicMock(), None, status=400) + MagicMock(), None, status=400, + ) subscriber_client.modify_ack_deadline = f queue = asyncio.Queue() @@ -825,8 +882,8 @@ async def f(*args, **kwargs): queue, subscriber_client, 0.1, - MagicMock() - ) + MagicMock(), + ), ) await queue.put('ack_id_1') await queue.put('ack_id_2') @@ -835,18 +892,26 @@ async def f(*args, **kwargs): assert queue.qsize() == 0 mock.assert_has_calls( [ - call('fake_subscription', - ack_ids=['ack_id_1', 'ack_id_2'], ack_deadline_seconds=0), - call('fake_subscription', - ack_ids=['ack_id_1'], ack_deadline_seconds=0), - call('fake_subscription', - ack_ids=['ack_id_2'], ack_deadline_seconds=0), - ] - ) - nack_fails = sum(1 for (logger, level, message) in caplog.record_tuples - if logger == 'gcloud.aio.pubsub.subscriber' - and level == logging.WARNING - and message == 'nack failed') + call( + 'fake_subscription', + ack_ids=['ack_id_1', 'ack_id_2'], ack_deadline_seconds=0, + ), + call( + 'fake_subscription', + ack_ids=['ack_id_1'], ack_deadline_seconds=0, + ), + call( + 'fake_subscription', + ack_ids=['ack_id_2'], ack_deadline_seconds=0, + ), + ], + ) + nack_fails = sum( + 1 for (logger, level, message) in caplog.record_tuples + if logger == 'gcloud.aio.pubsub.subscriber' + and level == logging.WARNING + and message == 'nack failed' + ) assert nack_fails == 2 # ========= @@ -854,21 +919,27 @@ async def f(*args, **kwargs): # ========= @pytest.mark.asyncio - async def test_subscribe_integrates_whole_chain(subscriber_client, - application_callback): + async def test_subscribe_integrates_whole_chain( + subscriber_client, + application_callback, + ): subscribe_task = asyncio.ensure_future( - subscribe('fake_subscription', application_callback, - subscriber_client, num_producers=1, - max_messages_per_producer=100, ack_window=0.0, - ack_deadline_cache_timeout=1000, - num_tasks_per_consumer=1, enable_nack=True, - nack_window=0.0)) + subscribe( + 'fake_subscription', application_callback, + subscriber_client, num_producers=1, + max_messages_per_producer=100, ack_window=0.0, + ack_deadline_cache_timeout=1000, + num_tasks_per_consumer=1, enable_nack=True, + nack_window=0.0, + ), + ) await asyncio.sleep(0.1) subscribe_task.cancel() application_callback.assert_called() subscriber_client.acknowledge.assert_called_with( - 'fake_subscription', ack_ids=['ack_id']) + 'fake_subscription', ack_ids=['ack_id'], + ) # verify that the subscriber shuts down gracefully with pytest.raises(asyncio.CancelledError): diff --git a/pubsub/tests/unit/subscription_test.py b/pubsub/tests/unit/subscription_test.py index f4e160457..5171e9a76 100644 --- a/pubsub/tests/unit/subscription_test.py +++ b/pubsub/tests/unit/subscription_test.py @@ -17,12 +17,13 @@ def test_construct_subscriber_message_from_message(): 'ackId': 'some_ack_id', 'message': { 'data': base64.b64encode( - json.dumps({'foo': 'bar'}).encode('utf-8')), + json.dumps({'foo': 'bar'}).encode('utf-8'), + ), 'attributes': {'attr_key': 'attr_value'}, 'messageId': '123', - 'publishTime': '2020-01-01T00:00:01.000Z' + 'publishTime': '2020-01-01T00:00:01.000Z', }, - 'deliveryAttempt': 1 + 'deliveryAttempt': 1, } message = SubscriberMessage.from_repr(message_dict) assert message.ack_id == 'some_ack_id' @@ -30,7 +31,8 @@ def test_construct_subscriber_message_from_message(): assert message.message_id == '123' assert message.data == b'{"foo": "bar"}' assert message.publish_time == datetime.datetime( - 2020, 1, 1, 0, 0, 1) + 2020, 1, 1, 0, 0, 1, + ) assert message.delivery_attempt == 1 @@ -39,8 +41,8 @@ def test_construct_subscriber_message_no_metadata(): 'ackId': 'some_ack_id', 'message': { 'messageId': '123', - 'publishTime': '2020-01-01T00:00:01.000Z' - } + 'publishTime': '2020-01-01T00:00:01.000Z', + }, } message = SubscriberMessage.from_repr(message_dict) assert message.ack_id == 'some_ack_id' @@ -48,5 +50,6 @@ def test_construct_subscriber_message_no_metadata(): assert message.message_id == '123' assert message.data is None assert message.publish_time == datetime.datetime( - 2020, 1, 1, 0, 0, 1) + 2020, 1, 1, 0, 0, 1, + ) assert message.delivery_attempt is None diff --git a/storage/gcloud/aio/storage/blob.py b/storage/gcloud/aio/storage/blob.py index 16e6fb270..f8846c392 100644 --- a/storage/gcloud/aio/storage/blob.py +++ b/storage/gcloud/aio/storage/blob.py @@ -30,10 +30,14 @@ HOST = 'storage.googleapis.com' -PKCS1_MARKER = ('-----BEGIN RSA PRIVATE KEY-----', - '-----END RSA PRIVATE KEY-----') -PKCS8_MARKER = ('-----BEGIN PRIVATE KEY-----', - '-----END PRIVATE KEY-----') +PKCS1_MARKER = ( + '-----BEGIN RSA PRIVATE KEY-----', + '-----END RSA PRIVATE KEY-----', +) +PKCS8_MARKER = ( + '-----BEGIN PRIVATE KEY-----', + '-----END PRIVATE KEY-----', +) PKCS8_SPEC = PrivateKeyInfo() @@ -62,8 +66,10 @@ class PemKind(enum.Enum): class Blob: - def __init__(self, bucket: 'Bucket', name: str, - metadata: Dict[str, Any]) -> None: + def __init__( + self, bucket: 'Bucket', name: str, + metadata: Dict[str, Any], + ) -> None: self.__dict__.update(**metadata) self.bucket = bucket @@ -74,17 +80,24 @@ def __init__(self, bucket: 'Bucket', name: str, def chunk_size(self) -> int: return self.size + (262144 - (self.size % 262144)) - async def download(self, timeout: int = DEFAULT_TIMEOUT, - session: Optional[Session] = None) -> Any: - return await self.bucket.storage.download(self.bucket.name, - self.name, - timeout=timeout, - session=session) - - async def upload(self, data: Any, - session: Optional[Session] = None) -> Dict[str, Any]: + async def download( + self, timeout: int = DEFAULT_TIMEOUT, + session: Optional[Session] = None, + ) -> Any: + return await self.bucket.storage.download( + self.bucket.name, + self.name, + timeout=timeout, + session=session, + ) + + async def upload( + self, data: Any, + session: Optional[Session] = None, + ) -> Dict[str, Any]: metadata = await self.bucket.storage.upload( - self.bucket.name, self.name, data, session=session) + self.bucket.name, self.name, data, session=session, + ) self.__dict__.update(metadata) return metadata @@ -92,7 +105,8 @@ async def upload(self, data: Any, async def get_signed_url( # pylint: disable=too-many-locals self, expiration: int, headers: Optional[Dict[str, str]] = None, query_params: Optional[Dict[str, Any]] = None, - http_method: str = 'GET', token: Optional[Token] = None) -> str: + http_method: str = 'GET', token: Optional[Token] = None, + ) -> str: """ Create a temporary access URL for Storage Blob accessible by anyone with the link. @@ -101,8 +115,10 @@ async def get_signed_url( # pylint: disable=too-many-locals https://cloud.google.com/storage/docs/access-control/signing-urls-manually#python-sample """ if expiration > 604800: - raise ValueError("expiration time can't be longer than 604800 " - 'seconds (7 days)') + raise ValueError( + "expiration time can't be longer than 604800 " + 'seconds (7 days)', + ) quoted_name = quote(self.name, safe=b'/~') canonical_uri = f'/{quoted_name}' @@ -121,10 +137,12 @@ async def get_signed_url( # pylint: disable=too-many-locals client_email = token.service_data.get('client_email') private_key = token.service_data.get('private_key') if not client_email or not private_key: - raise KeyError('Blob signing is only suported for tokens with ' - 'explicit client_email and private_key data; ' - 'please check your token points to a JSON service ' - 'account file') + raise KeyError( + 'Blob signing is only suported for tokens with ' + 'explicit client_email and private_key data; ' + 'please check your token points to a JSON service ' + 'account file', + ) credential_scope = f'{datestamp}/auto/storage/goog4_request' credential = f'{client_email}/{credential_scope}' @@ -135,10 +153,12 @@ async def get_signed_url( # pylint: disable=too-many-locals ordered_headers = collections.OrderedDict(sorted(headers.items())) canonical_headers = ''.join( f'{str(k).lower()}:{str(v).lower()}\n' - for k, v in ordered_headers.items()) + for k, v in ordered_headers.items() + ) signed_headers = ';'.join( - f'{str(k).lower()}' for k in ordered_headers.keys()) + f'{str(k).lower()}' for k in ordered_headers.keys() + ) query_params = query_params or {} query_params['X-Goog-Algorithm'] = 'GOOG4-RSA-SHA256' @@ -148,33 +168,44 @@ async def get_signed_url( # pylint: disable=too-many-locals query_params['X-Goog-SignedHeaders'] = signed_headers ordered_query_params = collections.OrderedDict( - sorted(query_params.items())) + sorted(query_params.items()), + ) canonical_query_str = '&'.join( f'{quote(str(k), safe="")}={quote(str(v), safe="")}' - for k, v in ordered_query_params.items()) - - canonical_req = '\n'.join([http_method, canonical_uri, - canonical_query_str, canonical_headers, - signed_headers, 'UNSIGNED-PAYLOAD']) + for k, v in ordered_query_params.items() + ) + + canonical_req = '\n'.join([ + http_method, canonical_uri, + canonical_query_str, canonical_headers, + signed_headers, 'UNSIGNED-PAYLOAD', + ]) canonical_req_hash = hashlib.sha256(canonical_req.encode()).hexdigest() - str_to_sign = '\n'.join(['GOOG4-RSA-SHA256', request_timestamp, - credential_scope, canonical_req_hash]) + str_to_sign = '\n'.join([ + 'GOOG4-RSA-SHA256', request_timestamp, + credential_scope, canonical_req_hash, + ]) # N.B. see the ``PemKind`` enum marker_id, key_bytes = pem.readPemBlocksFromFile( - io.StringIO(private_key), PKCS1_MARKER, PKCS8_MARKER) + io.StringIO(private_key), PKCS1_MARKER, PKCS8_MARKER, + ) if marker_id == PemKind.INVALID.value: raise ValueError('private key is invalid or unsupported') if marker_id == PemKind.PKCS8.value: # convert from pkcs8 to pkcs1 - key_info, remaining = decoder.decode(key_bytes, - asn1Spec=PKCS8_SPEC) + key_info, remaining = decoder.decode( + key_bytes, + asn1Spec=PKCS8_SPEC, + ) if remaining != b'': - raise ValueError('could not read PKCS8 key: found extra bytes', - remaining) + raise ValueError( + 'could not read PKCS8 key: found extra bytes', + remaining, + ) private_key_info = key_info.getComponentByName('privateKey') key_bytes = private_key_info.asOctets() @@ -184,5 +215,7 @@ async def get_signed_url( # pylint: disable=too-many-locals signature = binascii.hexlify(signed_blob).decode() - return (f'https://{self.bucket.name}.{HOST}{canonical_uri}?' - f'{canonical_query_str}&X-Goog-Signature={signature}') + return ( + f'https://{self.bucket.name}.{HOST}{canonical_uri}?' + f'{canonical_query_str}&X-Goog-Signature={signature}' + ) diff --git a/storage/gcloud/aio/storage/bucket.py b/storage/gcloud/aio/storage/bucket.py index 1bd956ac7..f8e23a9ca 100644 --- a/storage/gcloud/aio/storage/bucket.py +++ b/storage/gcloud/aio/storage/bucket.py @@ -16,7 +16,8 @@ from requests import Session else: from aiohttp import ( # type: ignore[assignment] - ClientResponseError as ResponseError) + ClientResponseError as ResponseError, + ) from aiohttp import ClientSession as Session # type: ignore[assignment] if TYPE_CHECKING: @@ -31,16 +32,22 @@ def __init__(self, storage: 'Storage', name: str) -> None: self.storage = storage self.name = name - async def get_blob(self, blob_name: str, timeout: int = DEFAULT_TIMEOUT, - session: Optional[Session] = None) -> Blob: - metadata = await self.storage.download_metadata(self.name, blob_name, - timeout=timeout, - session=session) + async def get_blob( + self, blob_name: str, timeout: int = DEFAULT_TIMEOUT, + session: Optional[Session] = None, + ) -> Blob: + metadata = await self.storage.download_metadata( + self.name, blob_name, + timeout=timeout, + session=session, + ) return Blob(self, blob_name, metadata) - async def blob_exists(self, blob_name: str, - session: Optional[Session] = None) -> bool: + async def blob_exists( + self, blob_name: str, + session: Optional[Session] = None, + ) -> bool: try: await self.get_blob(blob_name, session=session) return True @@ -54,14 +61,18 @@ async def blob_exists(self, blob_name: str, raise e - async def list_blobs(self, prefix: str = '', - session: Optional[Session] = None) -> List[str]: + async def list_blobs( + self, prefix: str = '', + session: Optional[Session] = None, + ) -> List[str]: params = {'prefix': prefix, 'pageToken': ''} items = [] while True: - content = await self.storage.list_objects(self.name, - params=params, - session=session) + content = await self.storage.list_objects( + self.name, + params=params, + session=session, + ) items.extend([x['name'] for x in content.get('items', [])]) params['pageToken'] = content.get('nextPageToken', '') @@ -75,7 +86,9 @@ def new_blob(self, blob_name: str) -> Blob: async def get_metadata( self, params: Optional[Dict[str, Any]] = None, - session: Optional[Session] = None + session: Optional[Session] = None, ) -> Dict[str, Any]: - return await self.storage.get_bucket_metadata(self.name, params=params, - session=session) + return await self.storage.get_bucket_metadata( + self.name, params=params, + session=session, + ) diff --git a/storage/gcloud/aio/storage/storage.py b/storage/gcloud/aio/storage/storage.py index 09774697b..34e1c1962 100644 --- a/storage/gcloud/aio/storage/storage.py +++ b/storage/gcloud/aio/storage/storage.py @@ -33,7 +33,8 @@ from aiofiles import open as file_open # type: ignore[no-redef] from asyncio import sleep # type: ignore[assignment] from aiohttp import ( # type: ignore[assignment] - ClientResponseError as ResponseError) + ClientResponseError as ResponseError, + ) from aiohttp import ClientSession as Session # type: ignore[assignment] MAX_CONTENT_LENGTH_SIMPLE_UPLOAD = 5 * 1024 * 1024 # 5 MB @@ -63,8 +64,10 @@ def choose_boundary() -> str: return boundary.decode('ascii') -def encode_multipart_formdata(fields: List[Tuple[Dict[str, str], bytes]], - boundary: str) -> Tuple[bytes, str]: +def encode_multipart_formdata( + fields: List[Tuple[Dict[str, str], bytes]], + boundary: str, +) -> Tuple[bytes, str]: """ Stolen from urllib3.filepost.encode_multipart_formdata() as of v1.26.2. @@ -79,8 +82,10 @@ def encode_multipart_formdata(fields: List[Tuple[Dict[str, str], bytes]], # The below is from RequestFields.render_headers() # Since we only use Content-Type, we could simplify the below to a # single line... but probably best to be safe for future modifications. - for field in ['Content-Disposition', 'Content-Type', - 'Content-Location']: + for field in [ + 'Content-Disposition', 'Content-Type', + 'Content-Location', + ]: value = headers.pop(field, None) if value: body.append(f'{field}: {value}\r\n'.encode('utf-8')) @@ -159,7 +164,8 @@ def __init__( self.session = AioSession(session, verify_ssl=not self._api_is_dev) self.token = token or Token( service_file=service_file, scopes=SCOPES, - session=self.session.session) # type: ignore[arg-type] + session=self.session.session, # type: ignore[arg-type] + ) async def _headers(self) -> Dict[str, str]: if self._api_is_dev: @@ -174,13 +180,15 @@ def get_bucket(self, bucket_name: str) -> Bucket: return Bucket(self, bucket_name) # pylint: disable=too-many-locals - async def copy(self, bucket: str, object_name: str, - destination_bucket: str, *, new_name: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - params: Optional[Dict[str, str]] = None, - headers: Optional[Dict[str, str]] = None, - timeout: int = DEFAULT_TIMEOUT, - session: Optional[Session] = None) -> Dict[str, Any]: + async def copy( + self, bucket: str, object_name: str, + destination_bucket: str, *, new_name: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + timeout: int = DEFAULT_TIMEOUT, + session: Optional[Session] = None + ) -> Dict[str, Any]: """ When files are too large, multiple calls to `rewriteTo` are made. We @@ -201,21 +209,26 @@ async def copy(self, bucket: str, object_name: str, if not new_name: new_name = object_name - url = (f'{self._api_root_read}/{bucket}/o/' - f'{quote(object_name, safe="")}/rewriteTo/b/' - f'{destination_bucket}/o/{quote(new_name, safe="")}') + url = ( + f'{self._api_root_read}/{bucket}/o/' + f'{quote(object_name, safe="")}/rewriteTo/b/' + f'{destination_bucket}/o/{quote(new_name, safe="")}' + ) # We may optionally supply metadata* to apply to the rewritten # object, which explains why `rewriteTo` is a POST endpoint; when no # metadata is given, we have to send an empty body. # * https://cloud.google.com/storage/docs/json_api/v1/objects#resource metadict = (metadata or {}).copy() - metadict = {self._format_metadata_key(k): v - for k, v in metadict.items()} + metadict = { + self._format_metadata_key(k): v + for k, v in metadict.items() + } if 'metadata' in metadict: metadict['metadata'] = { str(k): str(v) if v is not None else None - for k, v in metadict['metadata'].items()} + for k, v in metadict['metadata'].items() + } metadata_ = json.dumps(metadict) @@ -229,24 +242,30 @@ async def copy(self, bucket: str, object_name: str, params = params or {} s = AioSession(session) if session else self.session - resp = await s.post(url, headers=headers, params=params, - timeout=timeout, data=metadata_) + resp = await s.post( + url, headers=headers, params=params, + timeout=timeout, data=metadata_, + ) data: Dict[str, Any] = await resp.json(content_type=None) while not data.get('done') and data.get('rewriteToken'): params['rewriteToken'] = data['rewriteToken'] - resp = await s.post(url, headers=headers, params=params, - timeout=timeout) + resp = await s.post( + url, headers=headers, params=params, + timeout=timeout, + ) data = await resp.json(content_type=None) return data - async def delete(self, bucket: str, object_name: str, *, - timeout: int = DEFAULT_TIMEOUT, - params: Optional[Dict[str, str]] = None, - headers: Optional[Dict[str, str]] = None, - session: Optional[Session] = None) -> str: + async def delete( + self, bucket: str, object_name: str, *, + timeout: int = DEFAULT_TIMEOUT, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + session: Optional[Session] = None + ) -> str: # https://cloud.google.com/storage/docs/request-endpoints#encoding encoded_object_name = quote(object_name, safe='') url = f'{self._api_root_read}/{bucket}/o/{encoded_object_name}' @@ -254,8 +273,10 @@ async def delete(self, bucket: str, object_name: str, *, headers.update(await self._headers()) s = AioSession(session) if session else self.session - resp = await s.delete(url, headers=headers, params=params or {}, - timeout=timeout) + resp = await s.delete( + url, headers=headers, params=params or {}, + timeout=timeout, + ) try: data: str = await resp.text() @@ -264,39 +285,49 @@ async def delete(self, bucket: str, object_name: str, *, return data - async def download(self, bucket: str, object_name: str, *, - headers: Optional[Dict[str, Any]] = None, - timeout: int = DEFAULT_TIMEOUT, - session: Optional[Session] = None) -> bytes: - return await self._download(bucket, object_name, headers=headers, - timeout=timeout, params={'alt': 'media'}, - session=session) - - async def download_to_filename(self, bucket: str, object_name: str, - filename: str, **kwargs: Any) -> None: + async def download( + self, bucket: str, object_name: str, *, + headers: Optional[Dict[str, Any]] = None, + timeout: int = DEFAULT_TIMEOUT, + session: Optional[Session] = None + ) -> bytes: + return await self._download( + bucket, object_name, headers=headers, + timeout=timeout, params={'alt': 'media'}, + session=session, + ) + + async def download_to_filename( + self, bucket: str, object_name: str, + filename: str, **kwargs: Any + ) -> None: async with file_open( # type: ignore[attr-defined] filename, mode='wb+', ) as file_object: await file_object.write( - await self.download(bucket, object_name, **kwargs) + await self.download(bucket, object_name, **kwargs), ) - async def download_metadata(self, bucket: str, object_name: str, *, - headers: Optional[Dict[str, Any]] = None, - session: Optional[Session] = None, - timeout: int = DEFAULT_TIMEOUT - ) -> Dict[str, Any]: - data = await self._download(bucket, object_name, headers=headers, - timeout=timeout, session=session) + async def download_metadata( + self, bucket: str, object_name: str, *, + headers: Optional[Dict[str, Any]] = None, + session: Optional[Session] = None, + timeout: int = DEFAULT_TIMEOUT + ) -> Dict[str, Any]: + data = await self._download( + bucket, object_name, headers=headers, + timeout=timeout, session=session, + ) metadata: Dict[str, Any] = json.loads(data.decode()) return metadata - async def download_stream(self, bucket: str, object_name: str, *, - headers: Optional[Dict[str, Any]] = None, - timeout: int = DEFAULT_TIMEOUT, - session: Optional[Session] = None - ) -> StreamResponse: + async def download_stream( + self, bucket: str, object_name: str, *, + headers: Optional[Dict[str, Any]] = None, + timeout: int = DEFAULT_TIMEOUT, + session: Optional[Session] = None + ) -> StreamResponse: """Download a GCS object in a buffered stream. Args: @@ -314,36 +345,44 @@ async def download_stream(self, bucket: str, object_name: str, *, StreamResponse: A object encapsulating the stream, similar to io.BufferedIOBase, but it only supports the read() function. """ - return await self._download_stream(bucket, object_name, - headers=headers, timeout=timeout, - params={'alt': 'media'}, - session=session) - - async def list_objects(self, bucket: str, *, - params: Optional[Dict[str, str]] = None, - headers: Optional[Dict[str, Any]] = None, - session: Optional[Session] = None, - timeout: int = DEFAULT_TIMEOUT) -> Dict[str, Any]: + return await self._download_stream( + bucket, object_name, + headers=headers, timeout=timeout, + params={'alt': 'media'}, + session=session, + ) + + async def list_objects( + self, bucket: str, *, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, Any]] = None, + session: Optional[Session] = None, + timeout: int = DEFAULT_TIMEOUT + ) -> Dict[str, Any]: url = f'{self._api_root_read}/{bucket}/o' headers = headers or {} headers.update(await self._headers()) s = AioSession(session) if session else self.session - resp = await s.get(url, headers=headers, params=params or {}, - timeout=timeout) + resp = await s.get( + url, headers=headers, params=params or {}, + timeout=timeout, + ) data: Dict[str, Any] = await resp.json(content_type=None) return data # https://cloud.google.com/storage/docs/json_api/v1/how-tos/upload # pylint: disable=too-many-locals - async def upload(self, bucket: str, object_name: str, file_data: Any, - *, content_type: Optional[str] = None, - parameters: Optional[Dict[str, str]] = None, - headers: Optional[Dict[str, str]] = None, - metadata: Optional[Dict[str, Any]] = None, - session: Optional[Session] = None, - force_resumable_upload: Optional[bool] = None, - timeout: int = 30) -> Dict[str, Any]: + async def upload( + self, bucket: str, object_name: str, file_data: Any, + *, content_type: Optional[str] = None, + parameters: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + metadata: Optional[Dict[str, Any]] = None, + session: Optional[Session] = None, + force_resumable_upload: Optional[bool] = None, + timeout: int = 30 + ) -> Dict[str, Any]: url = f'{self._api_root_write}/{bucket}/o' stream = self._preprocess_data(file_data) @@ -367,35 +406,44 @@ async def upload(self, bucket: str, object_name: str, file_data: Any, 'Content-Type': content_type or '', }) - upload_type = self._decide_upload_type(force_resumable_upload, - content_length) + upload_type = self._decide_upload_type( + force_resumable_upload, + content_length, + ) log.debug('using %r gcloud storage upload method', upload_type) if upload_type == UploadType.RESUMABLE: return await self._upload_resumable( url, object_name, stream, parameters, headers, - metadata=metadata, session=session, timeout=timeout) + metadata=metadata, session=session, timeout=timeout, + ) if upload_type == UploadType.SIMPLE: if metadata: return await self._upload_multipart( url, object_name, stream, parameters, headers, metadata, - session=session, timeout=timeout) + session=session, timeout=timeout, + ) return await self._upload_simple( url, object_name, stream, parameters, headers, session=session, - timeout=timeout) + timeout=timeout, + ) raise TypeError(f'upload type {upload_type} not supported') - async def upload_from_filename(self, bucket: str, object_name: str, - filename: str, - **kwargs: Any) -> Dict[str, Any]: + async def upload_from_filename( + self, bucket: str, object_name: str, + filename: str, + **kwargs: Any + ) -> Dict[str, Any]: async with file_open( # type: ignore[attr-defined] filename, mode='rb', ) as file_object: contents = await file_object.read() - return await self.upload(bucket, object_name, contents, - **kwargs) + return await self.upload( + bucket, object_name, contents, + **kwargs + ) @staticmethod def _get_stream_len(stream: IO[AnyStr]) -> int: @@ -420,8 +468,10 @@ def _preprocess_data(data: Any) -> IO[Any]: raise TypeError(f'unsupported upload type: "{type(data)}"') @staticmethod - def _decide_upload_type(force_resumable_upload: Optional[bool], - content_length: int) -> UploadType: + def _decide_upload_type( + force_resumable_upload: Optional[bool], + content_length: int, + ) -> UploadType: # force resumable if force_resumable_upload is True: return UploadType.RESUMABLE @@ -459,11 +509,13 @@ def _format_metadata_key(key: str) -> str: parts = [parts[0].lower()] + [p.capitalize() for p in parts[1:]] return ''.join(parts) - async def _download(self, bucket: str, object_name: str, *, - params: Optional[Dict[str, str]] = None, - headers: Optional[Dict[str, str]] = None, - timeout: int = DEFAULT_TIMEOUT, - session: Optional[Session] = None) -> bytes: + async def _download( + self, bucket: str, object_name: str, *, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + timeout: int = DEFAULT_TIMEOUT, + session: Optional[Session] = None + ) -> bytes: # https://cloud.google.com/storage/docs/request-endpoints#encoding encoded_object_name = quote(object_name, safe='') url = f'{self._api_root_read}/{bucket}/o/{encoded_object_name}' @@ -471,8 +523,10 @@ async def _download(self, bucket: str, object_name: str, *, headers.update(await self._headers()) s = AioSession(session) if session else self.session - response = await s.get(url, headers=headers, params=params or {}, - timeout=timeout) + response = await s.get( + url, headers=headers, params=params or {}, + timeout=timeout, + ) # N.B. the GCS API sometimes returns 'application/octet-stream' when a # string was uploaded. To avoid potential weirdness, always return a @@ -484,12 +538,13 @@ async def _download(self, bucket: str, object_name: str, *, return data - async def _download_stream(self, bucket: str, object_name: str, *, - params: Optional[Dict[str, str]] = None, - headers: Optional[Dict[str, str]] = None, - timeout: int = DEFAULT_TIMEOUT, - session: Optional[Session] = None - ) -> StreamResponse: + async def _download_stream( + self, bucket: str, object_name: str, *, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + timeout: int = DEFAULT_TIMEOUT, + session: Optional[Session] = None + ) -> StreamResponse: # https://cloud.google.com/storage/docs/request-endpoints#encoding encoded_object_name = quote(object_name, safe='') url = f'{self._api_root_read}/{bucket}/o/{encoded_object_name}' @@ -501,45 +556,62 @@ async def _download_stream(self, bucket: str, object_name: str, *, if BUILD_GCLOUD_REST: # stream argument is only expected by requests.Session. # pylint: disable=unexpected-keyword-arg - return StreamResponse(s.get(url, headers=headers, - params=params or {}, - timeout=timeout, stream=True)) - return StreamResponse(await s.get(url, headers=headers, - params=params or {}, - timeout=timeout)) - - async def _upload_simple(self, url: str, object_name: str, - stream: IO[AnyStr], params: Dict[str, str], - headers: Dict[str, str], *, - session: Optional[Session] = None, - timeout: int = 30) -> Dict[str, Any]: + return StreamResponse( + s.get( + url, headers=headers, + params=params or {}, + timeout=timeout, stream=True, + ), + ) + return StreamResponse( + await s.get( + url, headers=headers, + params=params or {}, + timeout=timeout, + ), + ) + + async def _upload_simple( + self, url: str, object_name: str, + stream: IO[AnyStr], params: Dict[str, str], + headers: Dict[str, str], *, + session: Optional[Session] = None, + timeout: int = 30 + ) -> Dict[str, Any]: # https://cloud.google.com/storage/docs/json_api/v1/how-tos/simple-upload params['name'] = object_name params['uploadType'] = 'media' s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - resp = await s.post(url, data=stream, # type: ignore[arg-type] - headers=headers, params=params, timeout=timeout) + resp = await s.post( + url, data=stream, # type: ignore[arg-type] + headers=headers, params=params, timeout=timeout, + ) data: Dict[str, Any] = await resp.json(content_type=None) return data - async def _upload_multipart(self, url: str, object_name: str, - stream: IO[AnyStr], params: Dict[str, str], - headers: Dict[str, str], - metadata: Dict[str, Any], *, - session: Optional[Session] = None, - timeout: int = 30) -> Dict[str, Any]: + async def _upload_multipart( + self, url: str, object_name: str, + stream: IO[AnyStr], params: Dict[str, str], + headers: Dict[str, str], + metadata: Dict[str, Any], *, + session: Optional[Session] = None, + timeout: int = 30 + ) -> Dict[str, Any]: # https://cloud.google.com/storage/docs/json_api/v1/how-tos/multipart-upload params['uploadType'] = 'multipart' metadata_headers = {'Content-Type': 'application/json; charset=UTF-8'} - metadata = {self._format_metadata_key(k): v - for k, v in metadata.items()} + metadata = { + self._format_metadata_key(k): v + for k, v in metadata.items() + } if 'metadata' in metadata: metadata['metadata'] = { str(k): str(v) if v is not None else None - for k, v in metadata['metadata'].items()} + for k, v in metadata['metadata'].items() + } metadata['name'] = object_name @@ -558,7 +630,7 @@ async def _upload_multipart(self, url: str, object_name: str, headers.update({ 'Content-Type': content_type, 'Content-Length': str(len(body)), - 'Accept': 'application/json' + 'Accept': 'application/json', }) s = AioSession(session) if session else self.session @@ -568,38 +640,51 @@ async def _upload_multipart(self, url: str, object_name: str, body = io.BytesIO(body) # type: ignore[assignment] # TODO: the type issue will be fixed in auth-4.0.2 - resp = await s.post(url, data=body, # type: ignore[arg-type] - headers=headers, params=params, timeout=timeout) + resp = await s.post( + url, data=body, # type: ignore[arg-type] + headers=headers, params=params, timeout=timeout, + ) data: Dict[str, Any] = await resp.json(content_type=None) return data - async def _upload_resumable(self, url: str, object_name: str, - stream: IO[AnyStr], params: Dict[str, str], - headers: Dict[str, str], *, - metadata: Optional[Dict[str, Any]] = None, - session: Optional[Session] = None, - timeout: int = 30) -> Dict[str, Any]: + async def _upload_resumable( + self, url: str, object_name: str, + stream: IO[AnyStr], params: Dict[str, str], + headers: Dict[str, str], *, + metadata: Optional[Dict[str, Any]] = None, + session: Optional[Session] = None, + timeout: int = 30 + ) -> Dict[str, Any]: # https://cloud.google.com/storage/docs/json_api/v1/how-tos/resumable-upload - session_uri = await self._initiate_upload(url, object_name, params, - headers, metadata=metadata, - session=session) - return await self._do_upload(session_uri, stream, headers=headers, - session=session, timeout=timeout) - - async def _initiate_upload(self, url: str, object_name: str, - params: Dict[str, str], headers: Dict[str, str], - *, metadata: Optional[Dict[str, Any]] = None, - timeout: int = DEFAULT_TIMEOUT, - session: Optional[Session] = None) -> str: + session_uri = await self._initiate_upload( + url, object_name, params, + headers, metadata=metadata, + session=session, + ) + return await self._do_upload( + session_uri, stream, headers=headers, + session=session, timeout=timeout, + ) + + async def _initiate_upload( + self, url: str, object_name: str, + params: Dict[str, str], headers: Dict[str, str], + *, metadata: Optional[Dict[str, Any]] = None, + timeout: int = DEFAULT_TIMEOUT, + session: Optional[Session] = None + ) -> str: params['uploadType'] = 'resumable' metadict = (metadata or {}).copy() - metadict = {self._format_metadata_key(k): v - for k, v in metadict.items()} + metadict = { + self._format_metadata_key(k): v + for k, v in metadict.items() + } if 'metadata' in metadict: metadict['metadata'] = { str(k): str(v) if v is not None else None - for k, v in metadict['metadata'].items()} + for k, v in metadict['metadata'].items() + } metadict.update({'name': object_name}) metadata_ = json.dumps(metadict) @@ -609,19 +694,23 @@ async def _initiate_upload(self, url: str, object_name: str, 'Content-Length': str(len(metadata_)), 'Content-Type': 'application/json; charset=UTF-8', 'X-Upload-Content-Type': headers['Content-Type'], - 'X-Upload-Content-Length': headers['Content-Length'] + 'X-Upload-Content-Length': headers['Content-Length'], }) s = AioSession(session) if session else self.session - resp = await s.post(url, headers=post_headers, params=params, - data=metadata_, timeout=timeout) + resp = await s.post( + url, headers=post_headers, params=params, + data=metadata_, timeout=timeout, + ) session_uri: str = resp.headers['Location'] return session_uri - async def _do_upload(self, session_uri: str, stream: IO[AnyStr], - headers: Dict[str, str], *, retries: int = 5, - session: Optional[Session] = None, - timeout: int = 30) -> Dict[str, Any]: + async def _do_upload( + self, session_uri: str, stream: IO[AnyStr], + headers: Dict[str, str], *, retries: int = 5, + session: Optional[Session] = None, + timeout: int = 30 + ) -> Dict[str, Any]: s = AioSession(session) if session else self.session original_close = stream.close @@ -631,14 +720,16 @@ async def _do_upload(self, session_uri: str, stream: IO[AnyStr], try: for tries in range(retries): try: - resp = await s.put(session_uri, headers=headers, - data=stream, timeout=timeout) + resp = await s.put( + session_uri, headers=headers, + data=stream, timeout=timeout, + ) except ResponseError: headers.update({'Content-Range': '*/*'}) stream.seek(original_position) await sleep( # type: ignore[func-returns-value] - 2. ** tries + 2. ** tries, ) else: break @@ -653,7 +744,8 @@ async def patch_metadata( *, params: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None, session: Optional[Session] = None, - timeout: int = DEFAULT_TIMEOUT) -> Dict[str, Any]: + timeout: int = DEFAULT_TIMEOUT + ) -> Dict[str, Any]: # https://cloud.google.com/storage/docs/json_api/v1/objects/patch encoded_object_name = quote(object_name, safe='') url = f'{self._api_root_read}/{bucket}/o/{encoded_object_name}' @@ -665,24 +757,29 @@ async def patch_metadata( s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - resp = await s.patch(url, data=body, # type: ignore[arg-type] - headers=headers, params=params, timeout=timeout) + resp = await s.patch( + url, data=body, # type: ignore[arg-type] + headers=headers, params=params, timeout=timeout, + ) data: Dict[str, Any] = await resp.json(content_type=None) return data - async def get_bucket_metadata(self, bucket: str, *, - params: Optional[Dict[str, str]] = None, - headers: Optional[Dict[str, str]] = None, - session: Optional[Session] = None, - timeout: int = DEFAULT_TIMEOUT - ) -> Dict[str, Any]: + async def get_bucket_metadata( + self, bucket: str, *, + params: Optional[Dict[str, str]] = None, + headers: Optional[Dict[str, str]] = None, + session: Optional[Session] = None, + timeout: int = DEFAULT_TIMEOUT + ) -> Dict[str, Any]: url = f'{self._api_root_read}/{bucket}' headers = headers or {} headers.update(await self._headers()) s = AioSession(session) if session else self.session - resp = await s.get(url, headers=headers, params=params or {}, - timeout=timeout) + resp = await s.get( + url, headers=headers, params=params or {}, + timeout=timeout, + ) data: Dict[str, Any] = await resp.json(content_type=None) return data diff --git a/storage/tests/integration/download_range_test.py b/storage/tests/integration/download_range_test.py index de429a056..9c3898a54 100644 --- a/storage/tests/integration/download_range_test.py +++ b/storage/tests/integration/download_range_test.py @@ -26,17 +26,29 @@ @pytest.mark.asyncio @pytest.mark.parametrize( 'uploaded_data,range_header,expected_data,file_extension', [ - (json.dumps([1, 2, 3]), 'bytes=0-1', json.dumps( - [1, 2, 3]).encode('utf-8')[0:2], 'json'), - ('test'.encode('utf-8'), 'bytes=2-3', 'test'.encode('utf-8')[2:4], - 'bin'), - (io.BytesIO(RANDOM_BINARY), 'bytes=1-1000', RANDOM_BINARY[1:1001], - 'bin'), - (io.StringIO(RANDOM_STRING), 'bytes=10-100', - RANDOM_STRING.encode('utf-8')[10:101], 'txt'), - ]) -async def test_download_range(bucket_name, creds, uploaded_data, range_header, - expected_data, file_extension): + ( + json.dumps([1, 2, 3]), 'bytes=0-1', json.dumps( + [1, 2, 3], + ).encode('utf-8')[0:2], 'json', + ), + ( + 'test'.encode('utf-8'), 'bytes=2-3', 'test'.encode('utf-8')[2:4], + 'bin', + ), + ( + io.BytesIO(RANDOM_BINARY), 'bytes=1-1000', RANDOM_BINARY[1:1001], + 'bin', + ), + ( + io.StringIO(RANDOM_STRING), 'bytes=10-100', + RANDOM_STRING.encode('utf-8')[10:101], 'txt', + ), + ], +) +async def test_download_range( + bucket_name, creds, uploaded_data, range_header, + expected_data, file_extension, +): object_name = f'{uuid.uuid4().hex}/{uuid.uuid4().hex}.{file_extension}' async with Session() as session: @@ -44,7 +56,8 @@ async def test_download_range(bucket_name, creds, uploaded_data, range_header, res = await storage.upload(bucket_name, object_name, uploaded_data) downloaded_data = await storage.download( - bucket_name, res['name'], headers={'Range': range_header}) + bucket_name, res['name'], headers={'Range': range_header}, + ) assert expected_data == downloaded_data await storage.delete(bucket_name, res['name']) diff --git a/storage/tests/integration/download_stream_test.py b/storage/tests/integration/download_stream_test.py index ca2c8d743..b831a52c1 100644 --- a/storage/tests/integration/download_stream_test.py +++ b/storage/tests/integration/download_stream_test.py @@ -23,12 +23,16 @@ @pytest.mark.asyncio -@pytest.mark.parametrize('uploaded_data,expected_data', [ - (io.BytesIO(RANDOM_BINARY), RANDOM_BINARY), - (io.StringIO(RANDOM_STRING), RANDOM_STRING.encode('utf-8')), -]) -async def test_download_stream(bucket_name, creds, uploaded_data, - expected_data): +@pytest.mark.parametrize( + 'uploaded_data,expected_data', [ + (io.BytesIO(RANDOM_BINARY), RANDOM_BINARY), + (io.StringIO(RANDOM_STRING), RANDOM_STRING.encode('utf-8')), + ], +) +async def test_download_stream( + bucket_name, creds, uploaded_data, + expected_data, +): object_name = f'{uuid.uuid4().hex}/{uuid.uuid4().hex}' async with Session() as session: @@ -37,7 +41,8 @@ async def test_download_stream(bucket_name, creds, uploaded_data, with io.BytesIO(b'') as downloaded_data: download_stream = await storage.download_stream( - bucket_name, res['name']) + bucket_name, res['name'], + ) while True: chunk = await download_stream.read(4096) if not chunk: diff --git a/storage/tests/integration/metadata_test.py b/storage/tests/integration/metadata_test.py index 724b75e72..1a3c95b3c 100644 --- a/storage/tests/integration/metadata_test.py +++ b/storage/tests/integration/metadata_test.py @@ -15,35 +15,49 @@ async def test_metadata_multipart(bucket_name, creds): object_name = f'{uuid.uuid4().hex}/{uuid.uuid4().hex}.txt' original_data = f'{uuid.uuid4().hex}' - original_metadata = {'Content-Disposition': 'inline', - 'metadata': - {'a': 1, - 'b': 2, - 'c': [1, 2, 3], - 'd': {'a': 4, 'b': 5}}} + original_metadata = { + 'Content-Disposition': 'inline', + 'metadata': + { + 'a': 1, + 'b': 2, + 'c': [1, 2, 3], + 'd': {'a': 4, 'b': 5}, + }, + } # Google casts all metadata elements as string. - google_metadata = {'Content-Disposition': 'inline', - 'metadata': - {'a': str(1), - 'b': str(2), - 'c': str([1, 2, 3]), - 'd': str({'a': 4, 'b': 5})}} + google_metadata = { + 'Content-Disposition': 'inline', + 'metadata': + { + 'a': str(1), + 'b': str(2), + 'c': str([1, 2, 3]), + 'd': str({'a': 4, 'b': 5}), + }, + } async with Session() as session: storage = Storage(service_file=creds, session=session) # Without metadata - res0 = await storage.upload(bucket_name, object_name, original_data, - force_resumable_upload=False) + res0 = await storage.upload( + bucket_name, object_name, original_data, + force_resumable_upload=False, + ) data0 = await storage.download(bucket_name, res0['name']) await storage.download_metadata(bucket_name, res0['name']) # With metadata - res = await storage.upload(bucket_name, object_name, original_data, - metadata=original_metadata) + res = await storage.upload( + bucket_name, object_name, original_data, + metadata=original_metadata, + ) data = await storage.download(bucket_name, res['name']) - data_metadata = await storage.download_metadata(bucket_name, - res['name']) + data_metadata = await storage.download_metadata( + bucket_name, + res['name'], + ) assert res['name'] == object_name assert str(data, 'utf-8') == original_data @@ -57,36 +71,50 @@ async def test_metadata_multipart(bucket_name, creds): async def test_metadata_resumable(bucket_name, creds): object_name = f'{uuid.uuid4().hex}/{uuid.uuid4().hex}.txt' original_data = f'{uuid.uuid4().hex}' - original_metadata = {'Content-Disposition': 'inline', - 'metadata': - {'a': 1, - 'b': 2, - 'c': [1, 2, 3], - 'd': {'a': 4, 'b': 5}}} + original_metadata = { + 'Content-Disposition': 'inline', + 'metadata': + { + 'a': 1, + 'b': 2, + 'c': [1, 2, 3], + 'd': {'a': 4, 'b': 5}, + }, + } # Google casts all metadata elements as string. - google_metadata = {'Content-Disposition': 'inline', - 'metadata': - {'a': str(1), - 'b': str(2), - 'c': str([1, 2, 3]), - 'd': str({'a': 4, 'b': 5})}} + google_metadata = { + 'Content-Disposition': 'inline', + 'metadata': + { + 'a': str(1), + 'b': str(2), + 'c': str([1, 2, 3]), + 'd': str({'a': 4, 'b': 5}), + }, + } async with Session() as session: storage = Storage(service_file=creds, session=session) # Without metadata - res0 = await storage.upload(bucket_name, object_name, original_data, - force_resumable_upload=True) + res0 = await storage.upload( + bucket_name, object_name, original_data, + force_resumable_upload=True, + ) data0 = await storage.download(bucket_name, res0['name']) await storage.download_metadata(bucket_name, res0['name']) # With metadata - res = await storage.upload(bucket_name, object_name, original_data, - metadata=original_metadata, - force_resumable_upload=True) + res = await storage.upload( + bucket_name, object_name, original_data, + metadata=original_metadata, + force_resumable_upload=True, + ) data = await storage.download(bucket_name, res['name']) - data_metadata = await storage.download_metadata(bucket_name, - res['name']) + data_metadata = await storage.download_metadata( + bucket_name, + res['name'], + ) assert res['name'] == object_name assert str(data, 'utf-8') == original_data @@ -101,35 +129,47 @@ async def test_metadata_copy(bucket_name, creds): object_name = f'{uuid.uuid4().hex}/{uuid.uuid4().hex}.txt' copied_object_name = f'{object_name}.copy' original_data = f'{uuid.uuid4().hex}' - original_metadata = {'Content-Disposition': 'inline', - 'metadata': - {'a': 1, - 'b': 2, - 'c': [1, 2, 3], - 'd': {'a': 4, 'b': 5}}} + original_metadata = { + 'Content-Disposition': 'inline', + 'metadata': + { + 'a': 1, + 'b': 2, + 'c': [1, 2, 3], + 'd': {'a': 4, 'b': 5}, + }, + } # Google casts all metadata elements as string. - google_metadata = {'Content-Disposition': 'inline', - 'metadata': - {'a': str(1), - 'b': str(2), - 'c': str([1, 2, 3]), - 'd': str({'a': 4, 'b': 5})}} + google_metadata = { + 'Content-Disposition': 'inline', + 'metadata': + { + 'a': str(1), + 'b': str(2), + 'c': str([1, 2, 3]), + 'd': str({'a': 4, 'b': 5}), + }, + } async with Session() as session: storage = Storage(service_file=creds, session=session) # Without metadata - res0 = await storage.upload(bucket_name, object_name, original_data, - force_resumable_upload=True) + res0 = await storage.upload( + bucket_name, object_name, original_data, + force_resumable_upload=True, + ) data0 = await storage.download(bucket_name, res0['name']) - await storage.copy(bucket_name, object_name, bucket_name, - new_name=copied_object_name, - metadata=original_metadata) + await storage.copy( + bucket_name, object_name, bucket_name, + new_name=copied_object_name, + metadata=original_metadata, + ) data = await storage.download(bucket_name, copied_object_name) data_metadata = await storage.download_metadata( - bucket_name, copied_object_name + bucket_name, copied_object_name, ) assert data == data0 diff --git a/storage/tests/integration/signed_url_test.py b/storage/tests/integration/signed_url_test.py index 26fea1c6e..780fb18e8 100644 --- a/storage/tests/integration/signed_url_test.py +++ b/storage/tests/integration/signed_url_test.py @@ -19,8 +19,10 @@ async def test_gcs_signed_url(bucket_name, creds, data): async with Session() as session: storage = Storage(service_file=creds, session=session) - await storage.upload(bucket_name, object_name, data, - force_resumable_upload=True) + await storage.upload( + bucket_name, object_name, data, + force_resumable_upload=True, + ) bucket = Bucket(storage, bucket_name) blob = await bucket.get_blob(object_name, session=session) diff --git a/storage/tests/integration/smoke_test.py b/storage/tests/integration/smoke_test.py index 3fa1939f2..c02565196 100644 --- a/storage/tests/integration/smoke_test.py +++ b/storage/tests/integration/smoke_test.py @@ -15,13 +15,21 @@ @pytest.mark.asyncio -@pytest.mark.parametrize('uploaded_data,expected_data,file_extension', [ - ('test', b'test', 'txt'), - (b'test', b'test', 'bin'), - (json.dumps({'data': 1}), json.dumps({'data': 1}).encode('utf-8'), 'json'), -]) -async def test_object_life_cycle(bucket_name, creds, uploaded_data, - expected_data, file_extension): +@pytest.mark.parametrize( + 'uploaded_data,expected_data,file_extension', [ + ('test', b'test', 'txt'), + (b'test', b'test', 'bin'), + ( + json.dumps({'data': 1}), json.dumps( + {'data': 1}, + ).encode('utf-8'), 'json', + ), + ], +) +async def test_object_life_cycle( + bucket_name, creds, uploaded_data, + expected_data, file_extension, +): object_name = f'{uuid.uuid4().hex}/{uuid.uuid4().hex}.{file_extension}' copied_object_name = f'copyof_{object_name}' @@ -37,8 +45,10 @@ async def test_object_life_cycle(bucket_name, creds, uploaded_data, direct_result = await storage.download(bucket_name, object_name) assert direct_result == expected_data - await storage.copy(bucket_name, object_name, bucket_name, - new_name=copied_object_name) + await storage.copy( + bucket_name, object_name, bucket_name, + new_name=copied_object_name, + ) direct_result = await storage.download(bucket_name, copied_object_name) assert direct_result == expected_data diff --git a/storage/tests/integration/upload_multipart_test.py b/storage/tests/integration/upload_multipart_test.py index 96e610c50..1a804b513 100644 --- a/storage/tests/integration/upload_multipart_test.py +++ b/storage/tests/integration/upload_multipart_test.py @@ -24,16 +24,24 @@ @pytest.mark.asyncio -@pytest.mark.parametrize('uploaded_data,expected_data,file_extension', [ - ('test', b'test', 'txt'), - (json.dumps({'data': 1}), json.dumps({'data': 1}).encode('utf-8'), 'json'), - (json.dumps([1, 2, 3]), json.dumps([1, 2, 3]).encode('utf-8'), 'json'), - ('test'.encode('utf-8'), 'test'.encode('utf-8'), 'bin'), - (io.BytesIO(RANDOM_BINARY), RANDOM_BINARY, 'bin'), - (io.StringIO(RANDOM_STRING), RANDOM_STRING.encode('utf-8'), 'txt'), -]) -async def test_upload_multipart(bucket_name, creds, uploaded_data, - expected_data, file_extension): +@pytest.mark.parametrize( + 'uploaded_data,expected_data,file_extension', [ + ('test', b'test', 'txt'), + ( + json.dumps({'data': 1}), json.dumps( + {'data': 1}, + ).encode('utf-8'), 'json', + ), + (json.dumps([1, 2, 3]), json.dumps([1, 2, 3]).encode('utf-8'), 'json'), + ('test'.encode('utf-8'), 'test'.encode('utf-8'), 'bin'), + (io.BytesIO(RANDOM_BINARY), RANDOM_BINARY, 'bin'), + (io.StringIO(RANDOM_STRING), RANDOM_STRING.encode('utf-8'), 'txt'), + ], +) +async def test_upload_multipart( + bucket_name, creds, uploaded_data, + expected_data, file_extension, +): object_name = f'{uuid.uuid4().hex}/{uuid.uuid4().hex}.{file_extension}' async with Session() as session: @@ -42,8 +50,11 @@ async def test_upload_multipart(bucket_name, creds, uploaded_data, bucket_name, object_name, uploaded_data, - metadata={'Content-Disposition': 'inline', - 'metadata': {'a': 1, 'b': 2}}) + metadata={ + 'Content-Disposition': 'inline', + 'metadata': {'a': 1, 'b': 2}, + }, + ) try: assert res['name'] == object_name @@ -51,8 +62,10 @@ async def test_upload_multipart(bucket_name, creds, uploaded_data, downloaded_data = await storage.download(bucket_name, res['name']) assert downloaded_data == expected_data - downloaded_metadata = await storage.download_metadata(bucket_name, - res['name']) + downloaded_metadata = await storage.download_metadata( + bucket_name, + res['name'], + ) assert downloaded_metadata.pop('contentDisposition') == 'inline' assert downloaded_metadata['metadata']['a'] == '1' assert downloaded_metadata['metadata']['b'] == '2' diff --git a/storage/tests/integration/upload_resumable_test.py b/storage/tests/integration/upload_resumable_test.py index 80ecb8d13..097ea4ec4 100644 --- a/storage/tests/integration/upload_resumable_test.py +++ b/storage/tests/integration/upload_resumable_test.py @@ -24,16 +24,24 @@ @pytest.mark.asyncio -@pytest.mark.parametrize('uploaded_data,expected_data,file_extension', [ - ('test', b'test', 'txt'), - (json.dumps({'data': 1}), json.dumps({'data': 1}).encode('utf-8'), 'json'), - (json.dumps([1, 2, 3]), json.dumps([1, 2, 3]).encode('utf-8'), 'json'), - ('test'.encode('utf-8'), 'test'.encode('utf-8'), 'bin'), - (io.BytesIO(RANDOM_BINARY), RANDOM_BINARY, 'bin'), - (io.StringIO(RANDOM_STRING), RANDOM_STRING.encode('utf-8'), 'txt'), -]) -async def test_upload_resumable(bucket_name, creds, uploaded_data, - expected_data, file_extension): +@pytest.mark.parametrize( + 'uploaded_data,expected_data,file_extension', [ + ('test', b'test', 'txt'), + ( + json.dumps({'data': 1}), json.dumps( + {'data': 1}, + ).encode('utf-8'), 'json', + ), + (json.dumps([1, 2, 3]), json.dumps([1, 2, 3]).encode('utf-8'), 'json'), + ('test'.encode('utf-8'), 'test'.encode('utf-8'), 'bin'), + (io.BytesIO(RANDOM_BINARY), RANDOM_BINARY, 'bin'), + (io.StringIO(RANDOM_STRING), RANDOM_STRING.encode('utf-8'), 'txt'), + ], +) +async def test_upload_resumable( + bucket_name, creds, uploaded_data, + expected_data, file_extension, +): object_name = f'{uuid.uuid4().hex}/{uuid.uuid4().hex}.{file_extension}' async with Session() as session: @@ -43,14 +51,19 @@ async def test_upload_resumable(bucket_name, creds, uploaded_data, object_name, uploaded_data, force_resumable_upload=True, - metadata={'Content-Disposition': 'inline', - 'metadata': {'a': 1, 'b': 2}}) + metadata={ + 'Content-Disposition': 'inline', + 'metadata': {'a': 1, 'b': 2}, + }, + ) downloaded_data = await storage.download(bucket_name, res['name']) assert expected_data == downloaded_data - downloaded_metadata = await storage.download_metadata(bucket_name, - res['name']) + downloaded_metadata = await storage.download_metadata( + bucket_name, + res['name'], + ) assert downloaded_metadata.pop('contentDisposition') == 'inline' assert downloaded_metadata['metadata']['a'] == '1' assert downloaded_metadata['metadata']['b'] == '2' diff --git a/storage/tests/unit/upload_retry_test.py b/storage/tests/unit/upload_retry_test.py index be771a8df..6332954df 100644 --- a/storage/tests/unit/upload_retry_test.py +++ b/storage/tests/unit/upload_retry_test.py @@ -79,6 +79,7 @@ async def test_upload_retry(fake_server): # pylint: disable=redefined-outer-nam response = await storage.upload( bucket_name, object_name, content_type='text/plain', - file_data=data_stream, force_resumable_upload=True) + file_data=data_stream, force_resumable_upload=True, + ) assert response.get('data') == 'test data' diff --git a/taskqueue/gcloud/aio/taskqueue/queue.py b/taskqueue/gcloud/aio/taskqueue/queue.py index 70884d81c..5b35cefe9 100644 --- a/taskqueue/gcloud/aio/taskqueue/queue.py +++ b/taskqueue/gcloud/aio/taskqueue/queue.py @@ -55,12 +55,14 @@ def __init__( self._api_is_dev, self._api_root = init_api_root(api_root) self._api_root_queue = ( f'{self._api_root}/projects/{project}/locations/{location}/' - f'queues/{taskqueue}') + f'queues/{taskqueue}' + ) self.session = AioSession(session) self.token = token or Token( service_file=service_file, scopes=SCOPES, - session=self.session.session) # type: ignore[arg-type] + session=self.session.session, # type: ignore[arg-type] + ) async def headers(self) -> Dict[str, str]: if self._api_is_dev: @@ -74,8 +76,10 @@ async def headers(self) -> Dict[str, str]: # https://cloud.google.com/tasks/docs/reference/rest/v2beta3/projects.locations.queues.tasks/create @backoff.on_exception(backoff.expo, Exception, max_tries=3) - async def create(self, task: Dict[str, Any], - session: Optional[Session] = None) -> Any: + async def create( + self, task: Dict[str, Any], + session: Optional[Session] = None, + ) -> Any: url = f'{self._api_root_queue}/tasks' payload = json.dumps({ 'task': task, @@ -86,14 +90,18 @@ async def create(self, task: Dict[str, Any], s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - resp = await s.post(url, headers=headers, - data=payload) # type: ignore[arg-type] + resp = await s.post( + url, headers=headers, + data=payload, # type: ignore[arg-type] + ) return await resp.json() # https://cloud.google.com/tasks/docs/reference/rest/v2beta3/projects.locations.queues.tasks/delete @backoff.on_exception(backoff.expo, Exception, max_tries=3) - async def delete(self, tname: str, - session: Optional[Session] = None) -> Any: + async def delete( + self, tname: str, + session: Optional[Session] = None, + ) -> Any: url = f'{self._api_root}/{tname}' headers = await self.headers() @@ -104,8 +112,10 @@ async def delete(self, tname: str, # https://cloud.google.com/tasks/docs/reference/rest/v2beta3/projects.locations.queues.tasks/get @backoff.on_exception(backoff.expo, Exception, max_tries=3) - async def get(self, tname: str, full: bool = False, - session: Optional[Session] = None) -> Any: + async def get( + self, tname: str, full: bool = False, + session: Optional[Session] = None, + ) -> Any: url = f'{self._api_root}/{tname}' params = { 'responseView': 'FULL' if full else 'BASIC', @@ -119,9 +129,11 @@ async def get(self, tname: str, full: bool = False, # https://cloud.google.com/tasks/docs/reference/rest/v2beta3/projects.locations.queues.tasks/list @backoff.on_exception(backoff.expo, Exception, max_tries=3) - async def list(self, full: bool = False, page_size: int = 1000, - page_token: str = '', - session: Optional[Session] = None) -> Any: + async def list( + self, full: bool = False, page_size: int = 1000, + page_token: str = '', + session: Optional[Session] = None, + ) -> Any: url = f'{self._api_root_queue}/tasks' params: Dict[str, Union[int, str]] = { 'responseView': 'FULL' if full else 'BASIC', @@ -133,14 +145,18 @@ async def list(self, full: bool = False, page_size: int = 1000, s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - resp = await s.get(url, headers=headers, - params=params) # type: ignore[arg-type] + resp = await s.get( + url, headers=headers, + params=params, # type: ignore[arg-type] + ) return await resp.json() # https://cloud.google.com/tasks/docs/reference/rest/v2beta3/projects.locations.queues.tasks/run @backoff.on_exception(backoff.expo, Exception, max_tries=3) - async def run(self, tname: str, full: bool = False, - session: Optional[Session] = None) -> Any: + async def run( + self, tname: str, full: bool = False, + session: Optional[Session] = None, + ) -> Any: url = f'{self._api_root}/{tname}:run' payload = json.dumps({ 'responseView': 'FULL' if full else 'BASIC', @@ -150,8 +166,10 @@ async def run(self, tname: str, full: bool = False, s = AioSession(session) if session else self.session # TODO: the type issue will be fixed in auth-4.0.2 - resp = await s.post(url, headers=headers, - data=payload) # type: ignore[arg-type] + resp = await s.post( + url, headers=headers, + data=payload, # type: ignore[arg-type] + ) return await resp.json() async def close(self) -> None: diff --git a/taskqueue/tests/integration/conftest.py b/taskqueue/tests/integration/conftest.py index c165447fa..c6b3bed81 100644 --- a/taskqueue/tests/integration/conftest.py +++ b/taskqueue/tests/integration/conftest.py @@ -14,7 +14,7 @@ def pytest_configure(config): config.addinivalue_line( - 'markers', 'slow: marks tests as slow (deselect with `-m "not slow"`)' + 'markers', 'slow: marks tests as slow (deselect with `-m "not slow"`)', ) @@ -46,7 +46,11 @@ async def session() -> str: @pytest.fixture(scope='function') -async def push_queue(project, creds, push_queue_name, push_queue_location, - session): - return PushQueue(project, push_queue_name, service_file=creds, - location=push_queue_location, session=session) +async def push_queue( + project, creds, push_queue_name, push_queue_location, + session, +): + return PushQueue( + project, push_queue_name, service_file=creds, + location=push_queue_location, session=session, + ) diff --git a/taskqueue/tests/integration/pushqueue_test.py b/taskqueue/tests/integration/pushqueue_test.py index 6b83f54ab..555cd18ba 100644 --- a/taskqueue/tests/integration/pushqueue_test.py +++ b/taskqueue/tests/integration/pushqueue_test.py @@ -18,7 +18,7 @@ async def test_task_lifecycle_in_push_queue(push_queue): # something that we know won't work, # so that 'run' task operation doesn't end up deleting the task. 'relativeUri': '/some/test/uri', - } + }, } # CREATE