diff --git a/pdm.lock b/pdm.lock index d83da56..33edffc 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "testing"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:ff215917aa0a4f634788c7f83cab0540c5d5446b0b6d551af9ab635b03e14469" +content_hash = "sha256:90a2e0d19ce5af6c64699269d1f82ca785c5ef752121c570efdee808a295bb3d" [[package]] name = "annotated-types" @@ -1056,6 +1056,17 @@ files = [ {file = "starlette-0.35.1.tar.gz", hash = "sha256:3e2639dac3520e4f58734ed22553f950d3f3cb1001cd2eaac4d57e8cdc5f66bc"}, ] +[[package]] +name = "tenacity" +version = "8.2.3" +requires_python = ">=3.7" +summary = "Retry code until it succeeds" +groups = ["default"] +files = [ + {file = "tenacity-8.2.3-py3-none-any.whl", hash = "sha256:ce510e327a630c9e1beaf17d42e6ffacc88185044ad85cf74c0a8887c6a0f88c"}, + {file = "tenacity-8.2.3.tar.gz", hash = "sha256:5398ef0d78e63f40007c1fb4c0bff96e1911394d2fa8d194f77619c05ff6cc8a"}, +] + [[package]] name = "tomli" version = "2.0.1" diff --git a/playground/generate_image.py b/playground/generate_image.py index 12ea1e6..c16b703 100644 --- a/playground/generate_image.py +++ b/playground/generate_image.py @@ -38,6 +38,7 @@ async def main(): session=globe_s, remove_sign=True ) except APIError as e: + print(str(e)) print(e.response) return diff --git a/pyproject.toml b/pyproject.toml index 497dd68..96e5d75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "argon2-cffi>=23.1.0", "opencv-python>=4.8.1.78", "fake-useragent>=1.4.0", + "tenacity>=8.2.3", ] requires-python = ">=3.8" readme = "README.md" diff --git a/src/novelai_python/sdk/ai/generate_image/__init__.py b/src/novelai_python/sdk/ai/generate_image/__init__.py index e9057af..cde9550 100644 --- a/src/novelai_python/sdk/ai/generate_image/__init__.py +++ b/src/novelai_python/sdk/ai/generate_image/__init__.py @@ -18,6 +18,7 @@ from curl_cffi.requests import AsyncSession from loguru import logger from pydantic import BaseModel, ConfigDict, PrivateAttr, field_validator, model_validator, Field +from tenacity import retry, stop_after_attempt, wait_random, retry_if_exception from typing_extensions import override from ._enum import Model, Sampler, NoiseSchedule, ControlNetModel, Action, UCPreset @@ -327,6 +328,12 @@ async def necessary_headers(self, request_data) -> dict: "Cache-Control": "no-cache", } + @retry( + wait=wait_random(min=1, max=3), + stop=stop_after_attempt(3), + retry=retry_if_exception(lambda e: hasattr(e, "code") and str(e.code) == "500"), + reraise=True + ) async def request(self, session: Union[AsyncSession, "CredentialBase"], *, diff --git a/src/novelai_python/sdk/ai/upscale.py b/src/novelai_python/sdk/ai/upscale.py index bdecb53..f21fb65 100644 --- a/src/novelai_python/sdk/ai/upscale.py +++ b/src/novelai_python/sdk/ai/upscale.py @@ -15,6 +15,7 @@ from curl_cffi.requests import AsyncSession from loguru import logger from pydantic import ConfigDict, PrivateAttr, model_validator +from tenacity import wait_random, retry, stop_after_attempt, retry_if_exception from ..schema import ApiBaseModel from ..._exceptions import APIError, AuthError, SessionHttpError @@ -87,6 +88,12 @@ async def necessary_headers(self, request_data) -> dict: "Cache-Control": "no-cache", } + @retry( + wait=wait_random(min=1, max=3), + stop=stop_after_attempt(3), + retry=retry_if_exception(lambda e: hasattr(e, "code") and str(e.code) == "500"), + reraise=True + ) async def request(self, session: Union[AsyncSession, "CredentialBase"], *,