diff --git a/storage/gcloud/aio/storage/storage.py b/storage/gcloud/aio/storage/storage.py index 827182f33..265afc27a 100644 --- a/storage/gcloud/aio/storage/storage.py +++ b/storage/gcloud/aio/storage/storage.py @@ -4,8 +4,10 @@ import logging import mimetypes import os +import sys from typing import Any from typing import Dict +from typing import Generator from typing import Optional from typing import Tuple from typing import Union @@ -16,6 +18,12 @@ from gcloud.aio.auth import Token # pylint: disable=no-name-in-module from gcloud.aio.storage.bucket import Bucket +if sys.version_info >= (3, 6): + from typing import AsyncGenerator # pylint: disable=ungrouped-imports + GENERATORS = (AsyncGenerator, Generator) +else: + GENERATORS = (Generator,) + # Selectively load libraries based on the package if BUILD_GCLOUD_REST: from time import sleep @@ -26,6 +34,7 @@ from aiohttp import ClientResponseError as ResponseError from aiohttp import ClientSession as Session +StreamTypes = GENERATORS + (io.IOBase,) API_ROOT = 'https://www.googleapis.com/storage/v1/b' API_ROOT_UPLOAD = 'https://www.googleapis.com/upload/storage/v1/b' @@ -185,8 +194,9 @@ async def list_objects(self, bucket: str, *, params: dict = None, # TODO: if `metadata` is set, use multipart upload: # https://cloud.google.com/storage/docs/json_api/v1/how-tos/upload # pylint: disable=too-many-locals - async def upload(self, bucket: str, object_name: str, file_data: Any, - *, content_type: str = None, parameters: dict = None, + async def upload(self, bucket: str, object_name: str, + file_data: Union[StreamTypes], *, + content_type: str = None, parameters: dict = None, headers: dict = None, metadata: dict = None, session: Optional[Session] = None, timeout: int = 30, force_resumable_upload: bool = None) -> dict: @@ -208,10 +218,9 @@ async def upload(self, bucket: str, object_name: str, file_data: Any, headers = headers or {} headers.update(await self._headers()) - headers.update({ - 'Content-Length': str(content_length), - 'Content-Type': content_type or '', - }) + headers['Content-Type'] = content_type or '' + if content_length > 0: + headers['Content-Length'] = str(content_length) upload_type = self._decide_upload_type(force_resumable_upload, content_length) @@ -239,7 +248,10 @@ async def upload_from_filename(self, bucket: str, object_name: str, **kwargs) @staticmethod - def _get_stream_len(stream: io.IOBase) -> int: + def _get_stream_len(stream: Union[StreamTypes]) -> int: + if isinstance(stream, GENERATORS): + # generator length is not known, return 0 + return 0 current = stream.tell() try: return stream.seek(0, os.SEEK_END) @@ -255,7 +267,7 @@ def _preprocess_data(data: Any) -> io.IOBase: return io.BytesIO(data) if isinstance(data, str): return io.StringIO(data) - if isinstance(data, io.IOBase): + if isinstance(data, StreamTypes): return data raise TypeError(f'unsupported upload type: "{type(data)}"') @@ -272,7 +284,8 @@ def _decide_upload_type(force_resumable_upload: Optional[bool], return UploadType.SIMPLE # decide based on Content-Length - if content_length > MAX_CONTENT_LENGTH_SIMPLE_UPLOAD: + if (content_length == 0 or + content_length > MAX_CONTENT_LENGTH_SIMPLE_UPLOAD): return UploadType.RESUMABLE return UploadType.SIMPLE @@ -327,7 +340,7 @@ async def _upload_simple(self, url: str, object_name: str, return data async def _upload_resumable(self, url: str, object_name: str, - stream: io.IOBase, params: dict, + stream: Union[StreamTypes], params: dict, headers: dict, *, metadata: dict = None, session: Optional[Session] = None, timeout: int = 30) -> dict: @@ -354,16 +367,16 @@ async def _initiate_upload(self, url: str, object_name: str, params: dict, 'Content-Length': str(len(metadata)), 'Content-Type': 'application/json; charset=UTF-8', 'X-Upload-Content-Type': headers['Content-Type'], - 'X-Upload-Content-Length': headers['Content-Length'] }) - + if 'Content-Length' in headers: + post_headers['X-Upload-Content-Length'] = headers['Content-Length'] s = AioSession(session) if session else self.session resp = await s.post(url, headers=post_headers, params=params, data=metadata, timeout=10) session_uri: str = resp.headers['Location'] return session_uri - async def _do_upload(self, session_uri: str, stream: io.IOBase, + async def _do_upload(self, session_uri: str, stream: Union[StreamTypes], headers: dict, *, retries: int = 5, session: Optional[Session] = None, timeout: int = 30) -> dict: