diff --git a/alipcs_py/alipcs/api.py b/alipcs_py/alipcs/api.py index d3f067e..47038e7 100644 --- a/alipcs_py/alipcs/api.py +++ b/alipcs_py/alipcs/api.py @@ -633,6 +633,21 @@ def download_link(self, file_id: str) -> Optional[PcsDownloadUrl]: info = self._alipcs.download_link(file_id) return PcsDownloadUrl.from_(info) + def update_download_url(self, pcs_file: PcsFile) -> PcsFile: + """Update the download url of the `pcs_file` if it is expired + + Return a new `PcsFile` with the updated download url. + """ + + assert pcs_file.is_file, f"{pcs_file} is not a file" + + pcs_file = deepcopy(pcs_file) + if pcs_file.download_url_expires(): + pcs_url = self.download_link(pcs_file.file_id) + if pcs_url: + pcs_file.download_url = pcs_url.url + return pcs_file + def file_stream( self, file_id: str, @@ -982,6 +997,21 @@ def download_link(self, file_id: str) -> Optional[PcsDownloadUrl]: info = self._aliopenpcs.download_link(file_id) return PcsDownloadUrl.from_(info) + def update_download_url(self, pcs_file: PcsFile) -> PcsFile: + """Update the download url of the `pcs_file` if it is expired + + Return a new `PcsFile` with the updated download url. + """ + + assert pcs_file.is_file, f"{pcs_file} is not a file" + + pcs_file = deepcopy(pcs_file) + if pcs_file.download_url_expires(): + pcs_url = self.download_link(pcs_file.file_id) + if pcs_url: + pcs_file.download_url = pcs_url.url + return pcs_file + def file_stream( self, file_id: str, @@ -1053,11 +1083,24 @@ def __init__( def download_link(self, file_id: str) -> Optional[PcsDownloadUrl]: """Get the download link of the `file_id`""" - if self._aliopenpcsapi: + if self._aliopenpcsapi is not None: return self._aliopenpcsapi.download_link(file_id) else: return super().download_link(file_id) + def update_download_url(self, pcs_file: PcsFile) -> PcsFile: + """Update the download url of the `pcs_file` if it is expired + + Return a new `PcsFile` with the updated download url. + """ + + assert pcs_file.is_file, f"{pcs_file} is not a file" + + if self._aliopenpcsapi is not None: + return self._aliopenpcsapi.update_download_url(pcs_file) + else: + return super().update_download_url(pcs_file) + def file_stream( self, file_id: str, @@ -1067,7 +1110,7 @@ def file_stream( ) -> Optional[RangeRequestIO]: """File stream as a normal io""" - if self._aliopenpcsapi: + if self._aliopenpcsapi is not None: return self._aliopenpcsapi.file_stream( file_id, max_chunk_size=max_chunk_size, callback=callback, encrypt_password=encrypt_password ) diff --git a/alipcs_py/alipcs/errors.py b/alipcs_py/alipcs/errors.py index 69a23b6..8213e46 100644 --- a/alipcs_py/alipcs/errors.py +++ b/alipcs_py/alipcs/errors.py @@ -58,7 +58,7 @@ def refresh(*args, **kwargs): share_auth = self.__class__.SHARE_AUTHS.get(share_id) if share_auth: - share_auth.expire_time = 0.0 + share_auth.expire_time = 0 continue elif code == "ParamFlowException": diff --git a/alipcs_py/alipcs/inner.py b/alipcs_py/alipcs/inner.py index 41ff541..371e1fc 100644 --- a/alipcs_py/alipcs/inner.py +++ b/alipcs_py/alipcs/inner.py @@ -4,6 +4,7 @@ import time import re import urllib.parse +import warnings from alipcs_py.common.date import iso_8601_to_timestamp, now_timestamp @@ -178,6 +179,11 @@ def download_url_expires(self) -> bool: def update_download_url(self, api: "AliPCSApi"): """Update the download url if it expires""" + warnings.warn( + "This method is deprecated and will be removed in a future version, use `update_download_url` in `AliPCSApi` instead", + DeprecationWarning, + ) + if self.is_file: if self.download_url_expires(): pcs_url = api.download_link(self.file_id) @@ -394,7 +400,7 @@ class SharedAuth: share_id: str share_password: str share_token: str - expire_time: float + expire_time: int expires_in: int info: Any @@ -578,6 +584,7 @@ def from_(info) -> "PcsRateLimit": @dataclass class PcsDownloadUrl: url: Optional[str] = None + download_url: Optional[str] = None # url and download_url seem the same, the download rate is same internal_url: Optional[str] = None cdn_url: Optional[str] = None size: Optional[int] = None @@ -594,6 +601,7 @@ def from_(info) -> "PcsDownloadUrl": return PcsDownloadUrl( url=info.get("url"), + download_url=info.get("download_url"), internal_url=info.get("internal_url"), cdn_url=info.get("cdn_url"), size=info.get("size"), diff --git a/alipcs_py/alipcs/pcs.py b/alipcs_py/alipcs/pcs.py index 67968f3..6cc5ae5 100644 --- a/alipcs_py/alipcs/pcs.py +++ b/alipcs_py/alipcs/pcs.py @@ -25,8 +25,9 @@ ALIYUNDRIVE_COM_API = "https://api.aliyundrive.com" ALIYUNDRIVE_OPENAPI_DOMAIN = "https://openapi.aliyundrive.com" +# TODO: Update UA PCS_UA = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36" -PCS_HEADERS = {"Origin": ALIYUNDRIVE_COM[:-1], "Referer": ALIYUNDRIVE_COM + "/", "User-Agent": PCS_UA} +PCS_HEADERS = {"Origin": ALIYUNDRIVE_COM, "Referer": ALIYUNDRIVE_COM + "/", "User-Agent": PCS_UA} CheckNameMode = Literal[ "overwrite", # 直接覆盖,以后多版本有用 @@ -769,14 +770,15 @@ def get_share_token(self, share_id: str, share_password: str = ""): """Get share token""" shared_auth = self.__class__.SHARE_AUTHS.get(share_id) - if shared_auth and not shared_auth.is_expired(): + if shared_auth is not None: share_password = share_password or shared_auth.share_password - return shared_auth.info + + if not shared_auth.is_expired(): + return shared_auth.info url = PcsNode.ShareToken.url() data = dict(share_id=share_id, share_pwd=share_password) resp = self._request(Method.Post, url, json=data) - info = resp.json() if info.get("share_token"): # Store share password for refreshing share token @@ -883,6 +885,10 @@ def user_info(self): @assert_ok @handle_error def download_link(self, file_id: str): + info = self.meta(file_id)["responses"][0]["body"] + if info.get("url") or info.get("download_url"): + return info + url = PcsNode.DownloadUrl.url() data = dict(drive_id=self.default_drive_id, file_id=file_id) headers = dict(PCS_HEADERS) @@ -898,7 +904,7 @@ def file_stream( encrypt_password: bytes = b"", ) -> Optional[RangeRequestIO]: info = self.download_link(file_id) - url = info["url"] + url = info.get("url") or info.get("download_url") headers = { "User-Agent": PCS_UA, @@ -1207,8 +1213,8 @@ def meta(self, *file_ids: str, share_id: str = None): responses = [] for file_id in file_ids: - data = dict(file_id=file_id, drive_id=self.default_drive_id) - url = PcsNode.Meta.url() + data = dict(file_id=file_id, fields="*", drive_id=self.default_drive_id) + url = OpenPcsNode.Meta.url() resp = self._request(Method.Post, url, json=data) info = resp.json() responses.append(dict(body=info)) @@ -1276,7 +1282,7 @@ def list( assert limit <= 200, "`limit` should be less than 200" - url = PcsNode.FileList.url() + url = OpenPcsNode.FileList.url() orderby = "name" if name: orderby = "name" diff --git a/alipcs_py/commands/download.py b/alipcs_py/commands/download.py index 04dd1d4..d7a5071 100644 --- a/alipcs_py/commands/download.py +++ b/alipcs_py/commands/download.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Dict, Any, Callable +from typing import Iterable, Optional, List, Tuple from types import SimpleNamespace from enum import Enum from pathlib import Path @@ -6,13 +6,15 @@ import time import shutil import subprocess -from concurrent.futures import Future +import random from alipcs_py.alipcs import AliPCSApi, PcsFile +from alipcs_py.alipcs.errors import AliPCSError from alipcs_py.alipcs.pcs import PCS_UA +from alipcs_py.common.concurrent import Executor from alipcs_py.utils import human_size_to_int from alipcs_py.common import constant -from alipcs_py.common.io import to_decryptio, DecryptIO, READ_SIZE +from alipcs_py.common.io import RangeRequestIO, to_decryptio, DecryptIO, READ_SIZE from alipcs_py.common.downloader import MeDownloader from alipcs_py.common.progress_bar import ( _progress, @@ -32,7 +34,7 @@ USER_AGENT = PCS_UA -DEFAULT_CONCURRENCY = 5 +DEFAULT_CONCURRENCY = os.cpu_count() or 1 DEFAULT_CHUNK_SIZE = str(100 * constant.OneM) # This is the threshold of range request setted by Ali server @@ -43,6 +45,7 @@ class DownloadParams(SimpleNamespace): concurrency: int = DEFAULT_CONCURRENCY chunk_size: str = DEFAULT_CHUNK_SIZE quiet: bool = False + retries: int = 2 DEFAULT_DOWNLOADPARAMS = DownloadParams() @@ -76,21 +79,14 @@ def download( localpath_tmp = localpath + ".tmp" - def done_callback(fut: Future): - err = fut.exception() - if not err: - shutil.move(localpath_tmp, localpath) - else: - logger.info("`download`: MeDownloader fails: error: %s", err) - if self == Downloader.me: self._me_download( url, localpath_tmp, downloadparams=downloadparams, - done_callback=done_callback, encrypt_password=encrypt_password, ) + shutil.move(localpath_tmp, localpath) return elif self == Downloader.aget_py: cmd = self._aget_py_cmd(url, localpath_tmp, downloadparams) @@ -134,7 +130,6 @@ def _me_download( url: str, localpath: str, downloadparams: DownloadParams = DEFAULT_DOWNLOADPARAMS, - done_callback: Optional[Callable[[Future], Any]] = None, encrypt_password: bytes = b"", ): headers = { @@ -148,41 +143,40 @@ def _me_download( init_progress_bar() task_id = _progress.add_task("MeDownloader", start=False, title=localpath) - def _wrap_done_callback(fut: Future): + def done_callback(): remove_progress_task(task_id) - if done_callback: - done_callback(fut) - def monit_callback(task_id: Optional[TaskID], offset: int): + def monitor_callback(offset: int): if task_id is not None: _progress.update(task_id, completed=offset + 1) - def except_callback(task_id: Optional[TaskID]): + def except_callback(err): reset_progress_task(task_id) chunk_size_int = human_size_to_int(downloadparams.chunk_size) - meDownloader = MeDownloader( + io = RangeRequestIO( "GET", url, headers=headers, - max_workers=downloadparams.concurrency, max_chunk_size=chunk_size_int, - callback=monit_callback, + callback=monitor_callback, encrypt_password=encrypt_password, ) if task_id is not None: - length = len(meDownloader) + length = len(io) _progress.update(task_id, total=length) _progress.start_task(task_id) - meDownloader.download( - Path(localpath), - task_id=task_id, + meDownloader = MeDownloader( + io, + localpath=Path(localpath), continue_=True, - done_callback=_wrap_done_callback, + retries=downloadparams.retries, + done_callback=done_callback, except_callback=except_callback, ) + meDownloader.download() def _aget_py_cmd( self, @@ -274,6 +268,7 @@ def download_file( out_cmd: bool = False, encrypt_password: bytes = b"", ): + quiet = downloadparams.quiet localpath = Path(localdir) / pcs_file.name # Make sure parent directory existed @@ -281,13 +276,11 @@ def download_file( localpath.parent.mkdir(parents=True) if not out_cmd and localpath.exists(): - print(f"[yellow]{localpath}[/yellow] is ready existed.") + if not quiet: + print(f"[yellow]{localpath}[/yellow] is ready existed.") return - if not pcs_file: - return - - if downloader != Downloader.me: + if not quiet and downloader != Downloader.me: print(f"[italic blue]Download[/italic blue]: {pcs_file.path or pcs_file.name} to {localpath}") download_url: Optional[str] @@ -305,66 +298,70 @@ def download_file( if not pcs_file or pcs_file.is_dir: return - pcs_file.update_download_url(api) + while True: + try: + pcs_file = api.update_download_url(pcs_file) + break + except AliPCSError as err: + if err.error_code == "TooManyRequests": + time.sleep(random.randint(1, 2)) + continue + raise err + download_url = pcs_file.download_url assert download_url - downloader.download( - download_url, - str(localpath), - downloadparams=downloadparams, - out_cmd=out_cmd, - encrypt_password=encrypt_password, - ) + try: + downloader.download( + download_url, + str(localpath), + downloadparams=downloadparams, + out_cmd=out_cmd, + encrypt_password=encrypt_password, + ) + except Exception as err: + logger.error("`download_file` fails: error: %s", err) + if not quiet: + print(f"[red]ERROR[/red]: `{pcs_file.path or pcs_file.name}` download fails.") if share_id: api.remove(pcs_file.file_id) -def download_dir( +def walk_remote_paths( api: AliPCSApi, - pcs_file: PcsFile, + pcs_files: List[PcsFile], localdir: str, share_id: str = None, sifters: List[Sifter] = [], recursive: bool = False, from_index: int = 0, - downloader: Downloader = DEFAULT_DOWNLOADER, - downloadparams=DEFAULT_DOWNLOADPARAMS, - out_cmd: bool = False, - encrypt_password: bytes = b"", -): - remotefiles = list(api.list_iter(pcs_file.file_id, share_id=share_id)) - remotefiles = sift(remotefiles, sifters, recursive=recursive) - for rp in remotefiles[from_index:]: - if rp.is_file: - download_file( - api, - rp, - localdir, - share_id=share_id, - downloader=downloader, - downloadparams=downloadparams, - out_cmd=out_cmd, - encrypt_password=encrypt_password, - ) - else: # is_dir - if recursive: - _localdir = Path(localdir) / os.path.basename(rp.path) - download_dir( - api, - rp, - str(_localdir), - share_id=share_id, - sifters=sifters, - recursive=recursive, - from_index=from_index, - downloader=downloader, - downloadparams=downloadparams, - out_cmd=out_cmd, - encrypt_password=encrypt_password, - ) + deep: int = 0, +) -> Iterable[Tuple[PcsFile, str]]: + pcs_files = [pf for pf in sift(pcs_files, sifters, recursive=recursive)] + for pf in pcs_files: + if pf.is_file: + yield pf, localdir + else: + if deep > 0 and not recursive: + continue + + _localdir = Path(localdir) / pf.name + for pcs_file in api.list_iter(pf.file_id, share_id=share_id): + if pcs_file.is_file: + yield pcs_file, str(_localdir) + else: + yield from walk_remote_paths( + api, + [pcs_file], + str(_localdir), + share_id=share_id, + sifters=sifters, + recursive=recursive, + from_index=from_index, + deep=deep + 1, + ) def download( @@ -399,73 +396,47 @@ def download( bool(encrypt_password), ) + quiet = downloadparams.quiet + + pcs_files = [] for rp in remotepaths: - rpf = api.path(rp, share_id=share_id) - if not rpf: - print(f"[yellow]WARNING[/yellow]: `{rp}` does not exist.") + pf = api.path(rp, share_id=share_id) + if pf is None: + if not quiet: + print(f"[yellow]WARNING[/yellow]: `{rp}` does not exist.") continue - - if rpf.is_file: - download_file( - api, - rpf, - localdir, - share_id=share_id, - downloader=downloader, - downloadparams=downloadparams, - out_cmd=out_cmd, - encrypt_password=encrypt_password, - ) - else: - _localdir = str(Path(localdir) / rpf.name) - download_dir( - api, - rpf, - _localdir, - share_id=share_id, - sifters=sifters, - recursive=recursive, - from_index=from_index, - downloader=downloader, - downloadparams=downloadparams, - out_cmd=out_cmd, - encrypt_password=encrypt_password, - ) + pcs_files.append(pf) for file_id in file_ids: - rpf = api.meta(file_id, share_id=share_id)[0] - if not rpf: - print(f"[yellow]WARNING[/yellow]: file_id `{file_id}` does not exist.") + info = api.meta(file_id, share_id=share_id) + if len(info) == 0: + if not quiet: + print(f"[yellow]WARNING[/yellow]: `{file_id}` does not exist.") continue - - if rpf.is_file: - download_file( - api, - rpf, - localdir, - share_id=share_id, - downloader=downloader, - downloadparams=downloadparams, - out_cmd=out_cmd, - encrypt_password=encrypt_password, - ) - else: - _localdir = str(Path(localdir) / rpf.name) - download_dir( + pcs_files.append(info[0]) + + using_me_downloader = downloader == Downloader.me + with Executor(downloadparams.concurrency if using_me_downloader else 1) as executor: + for pf, _localdir in walk_remote_paths( + api, + pcs_files, + localdir, + share_id=share_id, + sifters=sifters, + recursive=recursive, + from_index=from_index, + ): + executor.submit( + download_file, api, - rpf, + pf, _localdir, share_id=share_id, - sifters=sifters, - recursive=recursive, - from_index=from_index, downloader=downloader, downloadparams=downloadparams, out_cmd=out_cmd, encrypt_password=encrypt_password, ) - if downloader == Downloader.me: - MeDownloader._exit_executor() - - _progress.stop() + if not quiet: + _progress.stop() diff --git a/alipcs_py/commands/play.py b/alipcs_py/commands/play.py index 39e5ef4..5ab5734 100644 --- a/alipcs_py/commands/play.py +++ b/alipcs_py/commands/play.py @@ -9,6 +9,7 @@ from urllib.parse import quote from alipcs_py.alipcs import AliPCSApi, PcsFile +from alipcs_py.alipcs.errors import AliPCSError from alipcs_py.commands.sifter import Sifter, sift from alipcs_py.commands.download import USER_AGENT from alipcs_py.commands.errors import CommandError @@ -118,31 +119,46 @@ def play_file( print(f"[italic blue]Play[/italic blue]: {pcs_file.path or pcs_file.name}") - # For typing - download_url: Optional[str] = None - use_local_server = bool(local_server) if share_id: + shared_pcs_file_id = pcs_file.file_id + shared_pcs_filename = pcs_file.name use_local_server = False remote_temp_dir = "/__alipcs_py_temp__" pcs_temp_dir = api.path(remote_temp_dir) or api.makedir_path(remote_temp_dir) - pcs_file = api.transfer_shared_files([pcs_file.file_id], pcs_temp_dir.file_id, share_id)[0] - # download_url = api.shared_file_download_url(pcs_file.file_id, share_id) - + pf = api.transfer_shared_files([shared_pcs_file_id], pcs_temp_dir.file_id, share_id)[0] + target_file_id = pf.file_id while True: - pcs_file = api.meta(pcs_file.file_id)[0] - if pcs_file.download_url: - break - time.sleep(2) + pfs = api.search_all(shared_pcs_filename) + for pf_ in pfs: + if pf_.file_id == target_file_id: + pcs_file = pf_ + break + else: + time.sleep(2) + continue + + break + download_url: Optional[str] = None if use_local_server: download_url = f"{local_server}/__fileid__/?file_id={pcs_file.file_id}" print("url:", download_url) else: if not pcs_file or pcs_file.is_dir: return - pcs_file.update_download_url(api) + + while True: + try: + pcs_file = api.update_download_url(pcs_file) + break + except AliPCSError as err: + if err.error_code == "TooManyRequests": + time.sleep(random.randint(1, 2)) + continue + raise err + download_url = pcs_file.download_url if download_url: diff --git a/alipcs_py/commands/server.py b/alipcs_py/commands/server.py index d541830..2db5768 100644 --- a/alipcs_py/commands/server.py +++ b/alipcs_py/commands/server.py @@ -237,9 +237,9 @@ def start_server( make_http_server(path) log_config = copy.deepcopy(uvicorn.config.LOGGING_CONFIG) - log_config["formatters"]["access"][ - "fmt" - ] = '%(asctime)s - %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s - %(msecs)d ms' + log_config["formatters"]["access"]["fmt"] = ( + '%(asctime)s - %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s - %(msecs)d ms' + ) uvicorn.run( "alipcs_py.commands.server:app", host=host, diff --git a/alipcs_py/commands/upload.py b/alipcs_py/commands/upload.py index a33b44e..2e109aa 100644 --- a/alipcs_py/commands/upload.py +++ b/alipcs_py/commands/upload.py @@ -1,5 +1,5 @@ from hashlib import sha1 -from typing import Optional, List, Tuple, IO +from typing import Callable, Optional, List, Tuple, IO import os import time @@ -509,10 +509,20 @@ def upload_many( logger.debug("`upload_many`: Upload: index: %s, task_id: %s", idx, task_id) + retry_upload_file = retry( + -1, + except_callback=lambda err, fail_count: logger.warning( + "`upload_file`: fails: error: %s, fail_count: %s", + err, + fail_count, + exc_info=err, + ), + )(upload_file) + fut = executor.submit( sure_release, semaphore, - upload_file, + retry_upload_file, api, from_to, check_name_mode, @@ -546,15 +556,6 @@ def upload_many( _progress.console.print(table) -@retry( - -1, - except_callback=lambda err, fail_count: logger.warning( - "`upload_file`: fails: error: %s, fail_count: %s", - err, - fail_count, - exc_info=err, - ), -) def upload_file( api: AliPCSApi, from_to: FromTo, @@ -565,6 +566,7 @@ def upload_file( task_id: Optional[TaskID] = None, user_id: Optional[str] = None, user_name: Optional[str] = None, + callback_for_monitor: Optional[Callable[[MultipartEncoderMonitor], None]] = None, ): """Upload one file with one connection""" @@ -687,7 +689,11 @@ def callback_for_slice(monitor: MultipartEncoderMonitor): ) upload_url = upload_urls[slice_idx] - api.upload_slice(io, upload_url, callback=callback_for_slice) + api.upload_slice( + io, + upload_url, + callback=callback_for_slice if callback_for_monitor is None else callback_for_monitor, + ) slice_idx += 1 break except Exception as err: diff --git a/alipcs_py/common/concurrent.py b/alipcs_py/common/concurrent.py index d2428b9..31649f6 100644 --- a/alipcs_py/common/concurrent.py +++ b/alipcs_py/common/concurrent.py @@ -1,16 +1,19 @@ +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Optional, Callable, Any from functools import wraps from threading import Semaphore def sure_release(semaphore: Semaphore, func, *args, **kwargs): + """Release semaphore after func is done.""" + try: return func(*args, **kwargs) finally: semaphore.release() -def retry(times: int, except_callback: Optional[Callable[..., Any]] = None): +def retry(times: int, except_callback: Optional[Callable[[Exception, int], Any]] = None): """Retry times when func fails""" def wrap(func): @@ -34,3 +37,33 @@ def retry_it(*args, **kwargs): return retry_it return wrap + + +class Executor: + """ + Executor is a ThreadPoolExecutor when max_workers > 1, else a single thread executor. + """ + + def __init__(self, max_workers: int = 1): + self._max_workers = max_workers + self._pool = ThreadPoolExecutor(max_workers=max_workers) if max_workers > 1 else None + self._semaphore = Semaphore(max_workers) + self._futures = [] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._pool is not None: + as_completed(self._futures) + self._pool.shutdown() + self._futures.clear() + + def submit(self, func, *args, **kwargs): + if self._pool is not None: + self._semaphore.acquire() + fut = self._pool.submit(sure_release, self._semaphore, func, *args, **kwargs) + self._futures.append(fut) + return fut + else: + return func(*args, **kwargs) diff --git a/alipcs_py/common/date.py b/alipcs_py/common/date.py index ec3b885..d671bf5 100644 --- a/alipcs_py/common/date.py +++ b/alipcs_py/common/date.py @@ -1,5 +1,6 @@ import time from datetime import datetime, timezone +from dateutil import parser def now_timestamp() -> int: @@ -9,17 +10,15 @@ def now_timestamp() -> int: def iso_8601_to_timestamp(date_string: str) -> int: - """Convert ISO 8601 datetime string to timestamp + """Convert ISO 8601 datetime string to the timestamp (integer) Args: date_string (str): ISO 8601 format. e.g. "2021-06-22T07:16:03Z" or "2021-06-22T07:16:03.032Z" """ - if len(date_string) == 20: - return int(datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%S%z").timestamp()) - else: - return int(datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%S.%f%z").timestamp()) + date_obj = parser.parse(date_string) + return int(date_obj.timestamp()) def timestamp_to_iso_8601(timestamp: int) -> str: diff --git a/alipcs_py/common/downloader.py b/alipcs_py/common/downloader.py index 7d620a8..3c0ae1d 100644 --- a/alipcs_py/common/downloader.py +++ b/alipcs_py/common/downloader.py @@ -1,113 +1,76 @@ -from typing import Optional, List, Any, Callable +from typing import Optional, Any, Callable from os import PathLike from pathlib import Path -from threading import Semaphore -from concurrent.futures import ThreadPoolExecutor, as_completed, Future -from alipcs_py.common.constant import CPU_NUM from alipcs_py.common.io import RangeRequestIO -from alipcs_py.common.concurrent import sure_release, retry +from alipcs_py.common.concurrent import retry -from rich.progress import TaskID DEFAULT_MAX_WORKERS = 5 -class MeDownloader(RangeRequestIO): - _executor: ThreadPoolExecutor - _semaphore: Semaphore - _futures: List[Future] - - @classmethod - def _set_executor( - cls, - max_workers: int = CPU_NUM, - ): - cls._executor = ThreadPoolExecutor(max_workers=max_workers) - cls._semaphore = Semaphore(max_workers) - cls._futures = [] - - @classmethod - def _exit_executor(cls): - if getattr(cls, "_executor", None): - as_completed(cls._futures) - cls._futures = [] - cls._executor.__exit__(None, None, None) - - def __init__(self, *args, max_workers: int = CPU_NUM, **kwargs): - super().__init__(*args, **kwargs) - if not getattr(self.__class__, "_executor", None): - self.__class__._set_executor(max_workers) - - def download( +class MeDownloader: + def __init__( self, + range_request_io: RangeRequestIO, localpath: PathLike, - task_id: Optional[TaskID], continue_: bool = False, - done_callback: Optional[Callable[[Future], Any]] = None, - except_callback: Optional[Callable[..., Any]] = None, - ): + retries: int = 2, + done_callback: Optional[Callable[..., Any]] = None, + except_callback: Optional[Callable[[Exception], Any]] = None, + ) -> None: + self.range_request_io = range_request_io + self.localpath = localpath + self.continue_ = continue_ + self.retries = retries + self.done_callback = done_callback + self.except_callback = except_callback + + def _init_fd(self): + if self.continue_: + path = Path(self.localpath) + if self.range_request_io.seekable(): + offset = path.stat().st_size if path.exists() else 0 + fd = path.open("ab") + fd.seek(offset, 0) + else: + offset = 0 + fd = path.open("wb") + else: + offset = 0 + fd = open(self.localpath, "wb") + + self.offset = offset + self.fd = fd + + def download(self): """ Download the url content to `localpath` - The downloading work executing in the class ThreadPoolExecutor - Args: continue_ (bool): If set to True, only downloading the remain content depended on the size of `localpath` """ - self._localpath = localpath - self._task_id = task_id - self._except_callback = except_callback - self.continue_ = continue_ - - cls = self.__class__ - cls._semaphore.acquire() - - fut = cls._executor.submit( - sure_release, - cls._semaphore, - self.work, + @retry( + self.retries, + except_callback=lambda err, fails: ( + self.range_request_io.reset(), + self.except_callback(err) if self.except_callback else None, + ), ) - if done_callback: - fut.add_done_callback(done_callback) - cls._futures.append(fut) + def _download(): + self._init_fd() - def _init_fd(self): - if self.continue_: - _path = Path(self._localpath) - if self.seekable(): - _offset = _path.stat().st_size if _path.exists() else 0 - _fd = _path.open("ab") - _fd.seek(_offset, 0) - else: - _offset = 0 - _fd = _path.open("wb") - else: - _offset = 0 - _fd = open(self._localpath, "wb") + self.range_request_io.seek(self.offset) - self._offset = _offset - self._fd = _fd + for buf in self.range_request_io.read_iter(): + self.fd.write(buf) + self.offset += len(buf) - @retry(30) - def work(self): - self._init_fd() + if self.done_callback: + self.done_callback() - try: - start, end = self._offset, len(self) + self.fd.close() - for b in self._auto_decrypt_request.read((start, end)): - self._fd.write(b) - self._offset += len(b) - # Call callback - if self._callback: - self._callback(self._task_id, self._offset) - except Exception as err: - if self._except_callback: - self._except_callback(self._task_id) - self.reset() - raise err - finally: - self._fd.close() + _download() diff --git a/alipcs_py/common/io.py b/alipcs_py/common/io.py index 44bb49e..12dcdd4 100644 --- a/alipcs_py/common/io.py +++ b/alipcs_py/common/io.py @@ -1,4 +1,5 @@ from typing import ( + Iterable, Optional, List, Tuple, @@ -6,7 +7,6 @@ Union, Any, Callable, - Generator, IO, cast, ) @@ -965,7 +965,7 @@ def __len__(self) -> int: assert self._content_length return self._content_length - self._total_head_len - def read(self, _range: Tuple[int, int]) -> Generator[bytes, None, None]: + def read(self, _range: Tuple[int, int]) -> Iterable[bytes]: self._init() start, end = _range @@ -1065,6 +1065,22 @@ def read(self, size: int = -1) -> Optional[bytes]: self._callback(self._offset) return buf + def read_iter(self, size: int = -1) -> Iterable[bytes]: + if size == 0: + return b"" + + if size == -1: + size = len(self) - self._offset + + start, end = self._offset, self._offset + size + + for buf in self._auto_decrypt_request.read((start, end)): + self._offset += len(buf) + # Call callback + if self._callback: + self._callback(self._offset) + yield buf + def seek(self, offset: int, whence: int = 0) -> int: if whence == 0: self._offset = offset diff --git a/pyproject.toml b/pyproject.toml index 8e5d867..ba77d25 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ line-length = 119 [tool.ruff] -ignore = ["E501", "E402", "F401", "F403", "F841"] +lint.ignore = ["E501", "E402", "F401", "F403", "F841"] line-length = 119 [tool.poetry.dependencies] @@ -33,6 +33,7 @@ requests = ">=2.31" requests-toolbelt = ">=1.0" peewee = ">=3.17" toml = ">=0.10" +python-dateutil = ">=2.8" qrcode = ">=7.4" rich = ">=13.7" pillow = ">=10.1" @@ -50,7 +51,7 @@ passlib = ">=1.7" [tool.poetry.group.dev.dependencies] pytest = ">=7.4" -ruff = ">=0.2" +ruff = ">=0.3" setuptools = ">=69.0" cython = ">=3.0" diff --git a/tests/test_common.py b/tests/test_common.py index c0f9a4a..f5c1b22 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -6,6 +6,7 @@ import requests +from alipcs_py.common.concurrent import Executor from alipcs_py.common import constant from alipcs_py.common.number import u64_to_u8x8, u8x8_to_u64 from alipcs_py.common.path import join_path @@ -430,3 +431,21 @@ def test_human_size(): s_int = human_size_to_int(s_str) assert s == s_int + + +def test_executor(): + def f(n): + return n + + with Executor(max_workers=1) as executor: + r = executor.submit(f, 1) + assert r == 1 + + futs = [] + with Executor(max_workers=2) as executor: + fut = executor.submit(f, 1) + futs.append(fut) + + for fut in futs: + r = fut.result() + assert r == 1