diff --git a/synapseclient/core/download/download_async.py b/synapseclient/core/download/download_async.py index 1c9c65c47..b98550c5d 100644 --- a/synapseclient/core/download/download_async.py +++ b/synapseclient/core/download/download_async.py @@ -25,7 +25,6 @@ DEFAULT_MAX_BACK_OFF_ASYNC, RETRYABLE_CONNECTION_ERRORS, RETRYABLE_CONNECTION_EXCEPTIONS, - with_retry_async, with_retry_time_based, ) from synapseclient.core.transfer_bar import get_or_create_download_progress_bar @@ -272,15 +271,10 @@ async def download_file(self) -> None: """ url_provider = PresignedUrlProvider(self._syn, request=self._download_request) - file_size = await with_retry_async( - function=lambda: _get_file_size_wrapper( - syn=self._syn, - url_provider=url_provider, - debug=self._download_request.debug, - ), - verbose=self._download_request.debug, - retry_status_codes=[403, 429, 500, 502, 503, 504], - log_for_retry=True, + file_size = await _get_file_size_wrapper( + syn=self._syn, + url_provider=url_provider, + debug=self._download_request.debug, ) self._progress_bar = get_or_create_download_progress_bar( file_size=file_size, diff --git a/synapseclient/core/retry.py b/synapseclient/core/retry.py index a70fff07f..5851df5bc 100644 --- a/synapseclient/core/retry.py +++ b/synapseclient/core/retry.py @@ -206,140 +206,6 @@ def foo(a, b, c): return [a, b, c] return response -async def with_retry_async( - function: Coroutine[Any, Any, Any], - verbose=False, - retry_status_codes=[429, 500, 502, 503, 504], - expected_status_codes=[], - retry_errors=[], - retry_exceptions=[], - retries=DEFAULT_RETRIES, - wait=DEFAULT_WAIT, - back_off=DEFAULT_BACK_OFF, - max_wait=DEFAULT_MAX_WAIT, - log_for_retry=False, -): - """ - Retries the given function under certain conditions. - - Arguments: - function: A function with no arguments. If arguments are needed, use a lambda (see example). - retry_status_codes: What status codes to retry upon in the case of a SynapseHTTPError. 
- expected_status_codes: If specified responses with any other status codes result in a retry. - retry_errors: What reasons to retry upon, if function().response.json()['reason'] exists. - retry_exceptions: What types of exceptions, specified as strings or Exception classes, to retry upon. - retries: How many times to retry maximum. - wait: How many seconds to wait between retries. - back_off: Exponential constant to increase wait for between progressive failures. - max_wait: back_off between requests will not exceed this value - log_for_retry: Determine if a message indiciating a retry will occur is logged. - - Returns: - function() - - Example: Using with_retry - Using ``with_retry`` to consolidate inputs into a list. - - from synapseclient.core.retry import with_retry - - def foo(a, b, c): return [a, b, c] - result = await with_retry_async(lambda: foo("1", "2", "3"), **STANDARD_RETRY_PARAMS) - """ - - if verbose: - logger = logging.getLogger(DEBUG_LOGGER_NAME) - else: - logger = logging.getLogger(DEFAULT_LOGGER_NAME) - - # Retry until we succeed or run out of tries - total_wait = 0 - while True: - # Start with a clean slate - exc = None - exc_info = None - retry = False - response = None - - # Try making the call - try: - response = await function() - except Exception as ex: - exc = ex - exc_info = sys.exc_info() - logger.debug(DEBUG_EXCEPTION, function, exc_info=True) - if hasattr(ex, "response"): - response = ex.response - - # Check if we got a retry-able error - if response is not None and hasattr(response, "status_code"): - if ( - expected_status_codes - and response.status_code not in expected_status_codes - ) or (retry_status_codes and response.status_code in retry_status_codes): - response_message = _get_message(response) - retry = True - logger.debug("retrying on status code: %s" % str(response.status_code)) - logger.debug(str(response_message)) - if (response.status_code == 429) and (wait > 10): - logger.warning("%s...\n" % response_message) - 
logger.warning("Retrying in %i seconds" % wait) - - elif response.status_code not in range(200, 299): - # For all other non 200 messages look for retryable errors in the body or reason field - response_message = _get_message(response) - if any( - [msg.lower() in response_message.lower() for msg in retry_errors] - ): - retry = True - logger.debug("retrying %s" % response_message) - # special case for message throttling - elif ( - "Please slow down. You may send a maximum of 10 message" - in response - ): - retry = True - wait = 16 - logger.debug("retrying " + response_message) - - # Check if we got a retry-able exception - if exc is not None: - if ( - exc.__class__.__name__ in retry_exceptions - or exc.__class__ in retry_exceptions - or any( - [msg.lower() in str(exc_info[1]).lower() for msg in retry_errors] - ) - ): - retry = True - logger.debug("retrying exception: " + str(exc)) - - # Wait then retry - retries -= 1 - if retries >= 0 and retry: - if log_for_retry: - logger.info(f"Retrying action in {wait} seconds") - - randomized_wait = wait * random.uniform(0.5, 1.5) - logger.debug( - "total wait time {total_wait:5.0f} seconds\n " - "... Retrying in {wait:5.1f} seconds...".format( - total_wait=total_wait, wait=randomized_wait - ) - ) - total_wait += randomized_wait - doze(randomized_wait) - wait = min(max_wait, wait * back_off) - continue - - # Out of retries, re-raise the exception or return the response - if exc_info is not None and exc_info[0] is not None: - logger.debug( - "retries have run out. 
re-raising the exception", exc_info=True - ) - raise exc - return response - - def calculate_exponential_backoff( retries: int, base_wait: float, diff --git a/synapseclient/models/mixins/storable_container.py b/synapseclient/models/mixins/storable_container.py index 49b1f48f4..002086aeb 100644 --- a/synapseclient/models/mixins/storable_container.py +++ b/synapseclient/models/mixins/storable_container.py @@ -2,7 +2,7 @@ import asyncio import os -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, NoReturn, Optional, Union from typing_extensions import Self @@ -55,6 +55,42 @@ class StorableContainer(StorableContainerSynchronousProtocol): async def get_async(self, *, synapse_client: Optional[Synapse] = None) -> None: """Used to satisfy the usage in this mixin from the parent class.""" + async def _worker( + self, + queue: asyncio.Queue, + failure_strategy: FailureStrategy, + synapse_client: Synapse, + ) -> NoReturn: + """ + Coroutine that will process the queue of work items. This loops indefinitely, + waiting for the next work item, and is stopped by cancelling the worker task. + This will be used to download files in parallel. + + Arguments: + queue: The queue of work items to process. + failure_strategy: Determines how to handle failures when retrieving items + out of the queue and an exception occurs. + synapse_client: The Synapse client to use to download the files. + """ + while True: + # Get a "work item" out of the queue. 
+ work_item = await queue.get() + + try: + result = await work_item + except asyncio.CancelledError as ex: + raise ex + except Exception as ex: + result = ex + + self._resolve_sync_from_synapse_result( + result=result, + failure_strategy=failure_strategy, + synapse_client=synapse_client, + ) + + queue.task_done() + @otel_trace_method( method_to_trace_name=lambda self, **kwargs: f"{self.__class__.__name__}_sync_from_synapse: {self.id}" ) @@ -68,6 +104,7 @@ async def sync_from_synapse_async( include_activity: bool = True, follow_link: bool = False, link_hops: int = 1, + queue: asyncio.Queue = None, *, synapse_client: Optional[Synapse] = None, ) -> Self: @@ -239,6 +276,7 @@ async def sync_from_synapse_async( include_activity=include_activity, follow_link=follow_link, link_hops=link_hops, + queue=queue, synapse_client=syn, ) @@ -252,6 +290,7 @@ async def _sync_from_synapse_async( include_activity: bool = True, follow_link: bool = False, link_hops: int = 1, + queue: asyncio.Queue = None, *, synapse_client: Optional[Synapse] = None, ) -> Self: @@ -277,6 +316,21 @@ async def _sync_from_synapse_async( ), ) + create_workers = not queue + + queue = queue or asyncio.Queue() + worker_tasks = [] + if create_workers: + for _ in range(max(syn.max_threads * 2, 1)): + task = asyncio.create_task( + self._worker( + queue=queue, + failure_strategy=failure_strategy, + synapse_client=syn, + ) + ) + worker_tasks.append(task) + pending_tasks = [] self.folders = [] self.files = [] @@ -294,6 +348,7 @@ async def _sync_from_synapse_async( include_activity=include_activity, follow_link=follow_link, link_hops=link_hops, + queue=queue, ) ) @@ -305,6 +360,14 @@ async def _sync_from_synapse_async( synapse_client=syn, ) + if create_workers: + try: + # Wait until the queue is fully processed. 
+ await queue.join() + finally: + for task in worker_tasks: + task.cancel() + return self def flatten_file_list(self) -> List["File"]: @@ -419,6 +482,7 @@ def _retrieve_children( async def _wrap_recursive_get_children( self, folder: "Folder", + queue: asyncio.Queue, recursive: bool = False, path: Optional[str] = None, download_file: bool = False, @@ -451,11 +515,13 @@ async def _wrap_recursive_get_children( follow_link=follow_link, link_hops=link_hops, synapse_client=synapse_client, + queue=queue, ) def _create_task_for_child( self, child, + queue: asyncio.Queue, recursive: bool = False, path: Optional[str] = None, download_file: bool = False, @@ -525,6 +591,7 @@ def _create_task_for_child( follow_link=follow_link, link_hops=link_hops, synapse_client=synapse_client, + queue=queue, ) ) ) @@ -546,17 +613,14 @@ def _create_task_for_child( if if_collision: file.if_collision = if_collision - pending_tasks.append( - asyncio.create_task( - wrap_coroutine( - file.get_async( - include_activity=include_activity, - synapse_client=synapse_client, - ) + queue.put_nowait( + wrap_coroutine( + file.get_async( + include_activity=include_activity, + synapse_client=synapse_client, ) ) ) - elif link_hops > 0 and synapse_id and child_type == LINK_ENTITY: pending_tasks.append( asyncio.create_task( @@ -572,6 +636,7 @@ def _create_task_for_child( include_activity=include_activity, follow_link=follow_link, link_hops=link_hops - 1, + queue=queue, ) ) ) @@ -582,6 +647,7 @@ def _create_task_for_child( async def _follow_link( self, child, + queue: asyncio.Queue, recursive: bool = False, path: Optional[str] = None, download_file: bool = False, @@ -634,6 +700,7 @@ async def _follow_link( include_activity=include_activity, follow_link=follow_link, link_hops=link_hops, + queue=queue, synapse_client=synapse_client, ) for task in asyncio.as_completed(pending_tasks):