diff --git a/README.md b/README.md index c8495f2..bfccd8d 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # novelai-python -[![PyPI version](https://badge.fury.io/py/novelai_python.svg)](https://badge.fury.io/py/novelai_python) +[![PyPI version](https://badge.fury.io/py/novelai-python.svg)](https://badge.fury.io/py/novelai-python) [![Downloads](https://pepy.tech/badge/novelai_python)](https://pepy.tech/project/novelai_python) The goal of this repository is to use Pydantic to build legitimate requests to access the Novelai API service. @@ -11,16 +11,20 @@ The goal of this repository is to use Pydantic to build legitimate requests to a - [x] /user/subscription - [x] /user/login - [x] /user/information +- [x] /ai/upscale - [ ] /ai/generate-image/suggest-tags - [ ] /ai/annotate-image - [ ] /ai/classify -- [ ] /ai/upscale - [ ] /ai/generate-prompt - [ ] /ai/generate - [ ] /ai/generate-voice ### Usage 🖥️ +```shell +pip install novelai-python +``` + More examples can be found in the [playground](https://github.com/LlmKira/novelai-python/tree/main/playground) directory. ```python diff --git a/playground/generate_image.py b/playground/generate_image.py index a18b3e2..12ea1e6 100644 --- a/playground/generate_image.py +++ b/playground/generate_image.py @@ -12,6 +12,7 @@ from novelai_python import APIError, Login from novelai_python import GenerateImageInfer, ImageGenerateResp, JwtCredential from novelai_python.sdk.ai.generate_image import Action +from novelai_python.utils.useful import enum_to_list load_dotenv() @@ -25,13 +26,14 @@ async def main(): _res = await Login.build(user_name=os.getenv("NOVELAI_USER"), password=os.getenv("NOVELAI_PASS") ).request() try: + print(f"Action List:{enum_to_list(Action)}") gen = GenerateImageInfer.build( prompt=f"1girl, winter, jacket, sfw, angel, flower,{enhance}", action=Action.GENERATE, ) cost = gen.calculate_cost(is_opus=True) print(f"charge: {cost} if you are vip3") - print(f"charge: {gen.calculate_cost(is_opus=True)}") + print(f"charge: {gen.calculate_cost(is_opus=False)} if you are not vip3") _res = await gen.request( session=globe_s, remove_sign=True ) diff --git a/playground/upscale.py b/playground/upscale.py new file mode 100644 index 0000000..5dd94a3 --- /dev/null +++ b/playground/upscale.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# @Time : 2024/2/13 上午11:58 +# @Author : sudoskys +# @File : upscale.py +# @Software: PyCharm +# -*- coding: utf-8 -*- +# @Time : 2024/2/14 下午5:20 +# @Author : sudoskys +# @File : upscale_demo.py +# @Software: PyCharm +# To run the demo, you need an event loop, for instance by using asyncio +import asyncio +import os + +from dotenv import load_dotenv +from pydantic import SecretStr + +from novelai_python import APIError, Upscale +from novelai_python import UpscaleResp, JwtCredential +from novelai_python.sdk.ai.generate_image import Action +from novelai_python.utils.useful import enum_to_list + +load_dotenv() + +token = None +jwt = os.getenv("NOVELAI_JWT") or token + + +async def main(): + globe_s = JwtCredential(jwt_token=SecretStr(jwt)) + if not os.path.exists("generate_image.png"): + raise FileNotFoundError("generate_image.png not found") + with open("generate_image.png", "rb") as f: + data = f.read() + try: + print(f"Action List:{enum_to_list(Action)}") + upscale = Upscale(image=data) # Auto detect image size | base64 + + _res = await upscale.request( + session=globe_s, remove_sign=True + ) + except APIError as e: + print(e.response) + return + + # Meta + _res: UpscaleResp + print(_res.meta.endpoint) + file = _res.files + with open("upscale.py.png", "wb") as f: + f.write(file[1]) + + +loop = asyncio.get_event_loop() +loop.run_until_complete(main()) diff --git a/src/novelai_python/__init__.py b/src/novelai_python/__init__.py index f229c0a..fbeaefe 100644 --- a/src/novelai_python/__init__.py +++ b/src/novelai_python/__init__.py @@ -16,11 +16,15 @@ from .sdk import Information, InformationResp from .sdk import Login, LoginResp from .sdk import Subscription, SubscriptionResp +from .sdk import Upscale, UpscaleResp __all__ = [ "GenerateImageInfer", "ImageGenerateResp", + "Upscale", + "UpscaleResp", + "Subscription", "SubscriptionResp", diff --git a/src/novelai_python/_response/ai/generate_image.py b/src/novelai_python/_response/ai/generate_image.py index 5786f93..673802e 100644 --- a/src/novelai_python/_response/ai/generate_image.py +++ b/src/novelai_python/_response/ai/generate_image.py @@ -7,8 +7,10 @@ from pydantic import BaseModel +from ..schema import RespBase -class ImageGenerateResp(BaseModel): + +class ImageGenerateResp(RespBase): class RequestParams(BaseModel): endpoint: str raw_request: dict = None diff --git a/src/novelai_python/_response/ai/upscale.py b/src/novelai_python/_response/ai/upscale.py new file mode 100644 index 0000000..fa56421 --- /dev/null +++ b/src/novelai_python/_response/ai/upscale.py @@ -0,0 +1,19 @@ +# -*- coding: utf-8 -*- +# @Time : 2024/2/13 上午11:29 +# @Author : sudoskys +# @File : upscale.py +# @Software: PyCharm +from typing import Tuple + +from pydantic import BaseModel + +from ..schema import RespBase + + +class UpscaleResp(RespBase): + class RequestParams(BaseModel): + endpoint: str + raw_request: dict = None + + meta: RequestParams + files: Tuple[str, bytes] = None diff --git a/src/novelai_python/_response/schema.py b/src/novelai_python/_response/schema.py new file mode 100644 index 0000000..d978b90 --- /dev/null +++ b/src/novelai_python/_response/schema.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- +# @Time : 2024/2/13 上午11:30 +# @Author : sudoskys +# @File : schema.py +# @Software: PyCharm + +from pydantic import BaseModel + + +class RespBase(BaseModel): + pass diff --git a/src/novelai_python/_response/user/information.py b/src/novelai_python/_response/user/information.py index 8ff6920..939a633 100644 --- a/src/novelai_python/_response/user/information.py +++ b/src/novelai_python/_response/user/information.py @@ -4,10 +4,12 @@ # @File : information.py # @Software: PyCharm -from pydantic import BaseModel, Field +from pydantic import Field +from ..schema import RespBase -class InformationResp(BaseModel): + +class InformationResp(RespBase): emailVerified: bool = Field(..., description="Email verification status") emailVerificationLetterSent: bool = Field(..., description="Email verification letter sent status") trialActivated: bool = Field(..., description="Trial activation status") diff --git a/src/novelai_python/_response/user/login.py b/src/novelai_python/_response/user/login.py index 7aa2622..745fa5a 100644 --- a/src/novelai_python/_response/user/login.py +++ b/src/novelai_python/_response/user/login.py @@ -3,8 +3,9 @@ # @Author : sudoskys # @File : login.py # @Software: PyCharm -from pydantic import BaseModel +from ..schema import RespBase -class LoginResp(BaseModel): + +class LoginResp(RespBase): accessToken: str diff --git a/src/novelai_python/_response/user/subscription.py b/src/novelai_python/_response/user/subscription.py index 535abf9..4ab9d29 100644 --- a/src/novelai_python/_response/user/subscription.py +++ b/src/novelai_python/_response/user/subscription.py @@ -7,6 +7,8 @@ from pydantic import BaseModel, Field +from ..schema import RespBase + class TrainingSteps(BaseModel): fixedTrainingStepsLeft: int @@ -30,7 +32,7 @@ class Perks(BaseModel): contextTokens: int -class SubscriptionResp(BaseModel): +class SubscriptionResp(RespBase): tier: int = Field(..., description="Subscription tier") active: bool = Field(..., description="Subscription status") expiresAt: int = Field(..., description="Subscription expiration time") diff --git a/src/novelai_python/credential/JwtToken.py b/src/novelai_python/credential/JwtToken.py index 699698b..7550ac9 100644 --- a/src/novelai_python/credential/JwtToken.py +++ b/src/novelai_python/credential/JwtToken.py @@ -24,6 +24,7 @@ async def get_session(self, timeout: int = 180, update_headers: dict = None): self._session = AsyncSession(timeout=timeout, headers={ "Accept": "*/*", "Accept-Language": "zh-CN,zh;q=0.8,zh-TW;q=0.7,zh-HK;q=0.5,en-US;q=0.3,en;q=0.2", + "User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:123.0) Gecko/20100101 Firefox/123.0", "Accept-Encoding": "gzip, deflate, br", "Authorization": f"Bearer {self.jwt_token.get_secret_value()}", "Content-Type": "application/json", diff --git a/src/novelai_python/credential/UserAuth.py b/src/novelai_python/credential/UserAuth.py index 1c0bf92..89ca8a9 100644 --- a/src/novelai_python/credential/UserAuth.py +++ b/src/novelai_python/credential/UserAuth.py @@ -30,6 +30,7 @@ async def get_session(self, timeout: int = 180, update_headers: dict = None): resp = await Login.build(user_name=self.username, password=self.password.get_secret_value()).request() self._session = AsyncSession(timeout=timeout, headers={ "Accept": "*/*", + "User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:123.0) Gecko/20100101 Firefox/123.0", "Accept-Language": "zh-CN,zh;q=0.8,zh-TW;q=0.7,zh-HK;q=0.5,en-US;q=0.3,en;q=0.2", "Accept-Encoding": "gzip, deflate, br", "Authorization": f"Bearer {resp.accessToken}", diff --git a/src/novelai_python/sdk/__init__.py b/src/novelai_python/sdk/__init__.py index 2f079aa..dcefc18 100644 --- a/src/novelai_python/sdk/__init__.py +++ b/src/novelai_python/sdk/__init__.py @@ -8,3 +8,4 @@ from .user.information import Information, InformationResp # noqa 401 from .user.login import Login, LoginResp # noqa 401 from .user.subscription import Subscription, SubscriptionResp # noqa 401 +from .ai.upscale import Upscale, UpscaleResp # noqa 401 \ No newline at end of file diff --git a/src/novelai_python/sdk/ai/generate_image.py b/src/novelai_python/sdk/ai/generate_image.py index 0527136..231910d 100644 --- a/src/novelai_python/sdk/ai/generate_image.py +++ b/src/novelai_python/sdk/ai/generate_image.py @@ -22,7 +22,7 @@ from ..schema import ApiBaseModel from ..._exceptions import APIError, AuthError, ConcurrentGenerationError, SessionHttpError -from ..._response import ImageGenerateResp +from ..._response.ai.generate_image import ImageGenerateResp from ...credential import CredentialBase from ...utils import try_jsonfy, NovelAiMetadata @@ -102,7 +102,6 @@ class Resolution(Enum): class GenerateImageInfer(ApiBaseModel): _endpoint: Optional[str] = PrivateAttr("https://api.novelai.net") - _charge: bool = PrivateAttr(False) class Params(BaseModel): # Inpaint @@ -412,14 +411,30 @@ async def request(self, request_data = self.model_dump(mode="json", exclude_none=True) # Header if isinstance(session, AsyncSession): - session.headers.update(self.necessary_headers(request_data)) + session.headers.update(await self.necessary_headers(request_data)) elif isinstance(session, CredentialBase): update_header = await self.necessary_headers(request_data) session = await session.get_session(update_headers=update_header) if override_headers: session.headers.clear() session.headers.update(override_headers) - logger.debug(f"Request Data: {request_data}") + try: + _log_data = request_data.copy() + if self.action == Action.GENERATE: + logger.debug(f"Request Data: {_log_data}") + else: + _log_data.get("parameters", {}).update({ + "image": "base64 data" if self.parameters.image else "None", + } + ) + _log_data.get("parameters", {}).update( + { + "mask": "base64 data" if self.parameters.mask else "None", + } + ) + logger.debug(f"Request Data: {request_data}") + except Exception as e: + logger.warning(f"Error when print log data: {e}") try: assert hasattr(session, "post"), "session must have post method." response = await session.post( diff --git a/src/novelai_python/sdk/ai/upscale.py b/src/novelai_python/sdk/ai/upscale.py new file mode 100644 index 0000000..cc4de17 --- /dev/null +++ b/src/novelai_python/sdk/ai/upscale.py @@ -0,0 +1,204 @@ +# -*- coding: utf-8 -*- +# @Time : 2024/2/13 上午10:48 +# @Author : sudoskys +# @File : upscale.py +# @Software: PyCharm +import base64 +import json +from io import BytesIO +from typing import Optional, Union +from urllib.parse import urlparse +from zipfile import ZipFile + +import curl_cffi +import httpx +from curl_cffi.requests import AsyncSession +from loguru import logger +from pydantic import ConfigDict, PrivateAttr, model_validator + +from ..schema import ApiBaseModel +from ..._exceptions import APIError, AuthError, SessionHttpError +from ..._response.ai.upscale import UpscaleResp +from ...credential import CredentialBase +from ...utils import try_jsonfy, NovelAiMetadata + + +class Upscale(ApiBaseModel): + _endpoint: Optional[str] = PrivateAttr("https://api.novelai.net") + image: Union[str, bytes] # base64 + width: Optional[int] = None + height: Optional[int] = None + scale: float = 4 + model_config = ConfigDict(extra="ignore") + + @model_validator(mode="after") + def validate_model(self): + if isinstance(self.image, str) and self.image.startswith("data:image/"): + raise ValueError("Invalid image format, must be base64 encoded.") + if isinstance(self.image, bytes): + self.image = base64.b64encode(self.image).decode("utf-8") + # Auto detect image size + try: + from PIL import Image + with Image.open(BytesIO(base64.b64decode(self.image))) as img: + width, height = img.size + except Exception as e: + logger.warning(f"Error when validate image size: {e}") + if self.width is None or self.height is None: + raise ValueError("Invalid image size and cant auto detect, must be set width and height.") + else: + if self.width is None: + self.width = width + if self.height is None: + self.height = height + return self + + @property + def base_url(self): + return f"{self.endpoint.strip('/')}/ai/upscale" + + @property + def endpoint(self): + return self._endpoint + + @endpoint.setter + def endpoint(self, value): + self._endpoint = value + + async def necessary_headers(self, request_data) -> dict: + """ + :param request_data: + :return: + """ + return { + "Host": urlparse(self.endpoint).netloc, + "Accept": "*/*", + "User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:123.0) Gecko/20100101 Firefox/123.0", + "Accept-Language": "zh-CN,zh;q=0.8,zh-TW;q=0.7,zh-HK;q=0.5,en-US;q=0.3,en;q=0.2", + "Accept-Encoding": "gzip, deflate, br", + "Referer": "https://novelai.net/", + "Content-Type": "application/json", + "Origin": "https://novelai.net", + "Content-Length": str(len(json.dumps(request_data).encode("utf-8"))), + "Connection": "keep-alive", + "Sec-Fetch-Dest": "empty", + "Sec-Fetch-Mode": "cors", + "Sec-Fetch-Site": "same-site", + "Pragma": "no-cache", + "Cache-Control": "no-cache", + } + + async def request(self, + session: Union[AsyncSession, "CredentialBase"], + *, + override_headers: Optional[dict] = None, + remove_sign: bool = False + ) -> UpscaleResp: + """ + 生成图片 + :param override_headers: + :param session: session + :param remove_sign: 移除追踪信息 + :return: + """ + # Data Build + request_data = self.model_dump(mode="json", exclude_none=True) + # Header + if isinstance(session, AsyncSession): + session.headers.update(await self.necessary_headers(request_data)) + elif isinstance(session, CredentialBase): + update_header = await self.necessary_headers(request_data) + session = await session.get_session(update_headers=update_header) + if override_headers: + session.headers.clear() + session.headers.update(override_headers) + try: + _log_data = request_data.copy() + _log_data.update({"image": "base64 data"}) if isinstance(_log_data.get("image"), str) else None + logger.info(f"Upscale request data: {_log_data}") + except Exception as e: + logger.warning(f"Error when print log data: {e}") + try: + assert hasattr(session, "post"), "session must have post method." + response = await session.post( + self.base_url, + data=json.dumps(request_data).encode("utf-8") + ) + if response.headers.get('Content-Type') not in ['binary/octet-stream', 'application/x-zip-compressed']: + logger.error( + f"Error with content type: {response.headers.get('Content-Type')} and code: {response.status_code}" + ) + try: + _msg = response.json() + except Exception as e: + logger.warning(e) + if not isinstance(response.content, str) and len(response.content) < 50: + raise APIError( + message=f"Unexpected content type: {response.headers.get('Content-Type')}", + request=request_data, + code=response.status_code, + response=try_jsonfy(response.content) + ) + else: + _msg = {"statusCode": response.status_code, "message": response.content} + status_code = _msg.get("statusCode", response.status_code) + message = _msg.get("message", "Unknown error") + if status_code in [400, 401, 402]: + # 400 : validation error + # 401 : unauthorized + # 402 : payment required + # 409 : conflict + raise AuthError(message, request=request_data, code=status_code, response=_msg) + if status_code in [409]: + # conflict error + raise APIError(message, request=request_data, code=status_code, response=_msg) + """ + if status_code in [429]: + # concurrent error + raise ConcurrentGenerationError( + message=message, + request=request_data, + code=status_code, + response=_msg + ) + """ + raise APIError(message, request=request_data, code=status_code, response=_msg) + zip_file = ZipFile(BytesIO(response.content)) + unzip_content = [] + with zip_file as zf: + file_list = zf.namelist() + if not file_list: + raise APIError( + message="No file in zip", + request=request_data, + code=response.status_code, + response=try_jsonfy(response.content) + ) + for filename in file_list: + data = zip_file.read(filename) + if remove_sign: + try: + data = NovelAiMetadata.rehash(BytesIO(data), remove_stealth=True) + if not isinstance(data, bytes): + data = data.getvalue() + except Exception as e: + logger.exception(f"SdkWarn:Remove sign error: {e}") + unzip_content.append((filename, data)) + return UpscaleResp( + meta=UpscaleResp.RequestParams( + endpoint=self.base_url, + raw_request=request_data, + ), + files=unzip_content[0] + ) + except curl_cffi.requests.errors.RequestsError as exc: + logger.exception(exc) + raise SessionHttpError("An AsyncSession RequestsError occurred, maybe SSL error. Try again later!") + except httpx.HTTPError as exc: + logger.exception(exc) + raise SessionHttpError("An HTTPError occurred, maybe SSL error. Try again later!") + except APIError as e: + raise e + except Exception as e: + logger.opt(exception=e).exception("An Unexpected error occurred") + raise e diff --git a/src/novelai_python/sdk/user/login.py b/src/novelai_python/sdk/user/login.py index 7e73d3e..a623ae1 100644 --- a/src/novelai_python/sdk/user/login.py +++ b/src/novelai_python/sdk/user/login.py @@ -77,7 +77,7 @@ async def request(self, # Data Build request_data = self.model_dump(mode="json", exclude_none=True) if isinstance(session, AsyncSession): - session.headers.update(self.necessary_headers(request_data)) + session.headers.update(await self.necessary_headers(request_data)) elif isinstance(session, CredentialBase): session = await session.get_session(update_headers=await self.necessary_headers(request_data)) # Header diff --git a/src/novelai_python/utils/useful.py b/src/novelai_python/utils/useful.py index 37c43dd..5a4a7e2 100644 --- a/src/novelai_python/utils/useful.py +++ b/src/novelai_python/utils/useful.py @@ -5,7 +5,7 @@ # @Software: PyCharm import collections import random -from typing import List +from typing import List, Union def enum_to_list(enum_): @@ -22,7 +22,8 @@ def __init__(self, data: List[str]): self.used = collections.deque() self.users = {} - def get(self, user_id: int) -> str: + def get(self, user_id: Union[int, str]) -> str: + user_id = str(user_id) if user_id not in self.users: self.users[user_id] = {'data': self.data.copy(), 'used': collections.deque()} @@ -41,6 +42,3 @@ def get(self, user_id: int) -> str: user_used.append(selected) return selected - - - diff --git a/tests/test_upscale.py b/tests/test_upscale.py new file mode 100644 index 0000000..ac13740 --- /dev/null +++ b/tests/test_upscale.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +# @Time : 2024/2/13 下午12:03 +# @Author : sudoskys +# @File : test_upscale.py +# @Software: PyCharm +# -*- coding: utf-8 -*- +# @Time : 2024/2/14 下午4:20 +# @Author : sudoskys +# @File : test_upscale.py +# @Software: PyCharm +from unittest import mock + +import pytest +from curl_cffi.requests import AsyncSession + +from novelai_python import APIError, Upscale, AuthError + + +@pytest.mark.asyncio +async def test_validation_error_during_upscale(): + validation_error_response = mock.Mock() + validation_error_response.headers = {"Content-Type": "application/json"} + validation_error_response.status_code = 400 + validation_error_response.json.return_value = { + "statusCode": 400, + "message": "A validation error occurred." + } + session = mock.MagicMock(spec=AsyncSession) + session.post = mock.AsyncMock(return_value=validation_error_response) + session.headers = {} + + upscale = Upscale(image="base64_encoded_image", height=100, width=100) + with pytest.raises(AuthError) as e: + await upscale.request(session=session) + assert e.type is AuthError + expected_message = 'A validation error occurred.' + assert expected_message == str(e.value) + + +@pytest.mark.asyncio +async def test_unauthorized_error_during_upscale(): + unauthorized_error_response = mock.Mock() + unauthorized_error_response.headers = {"Content-Type": "application/json"} + unauthorized_error_response.status_code = 401 + unauthorized_error_response.json.return_value = { + "statusCode": 401, + "message": "Unauthorized." + } + session = mock.MagicMock(spec=AsyncSession) + session.post = mock.AsyncMock(return_value=unauthorized_error_response) + session.headers = {} + + upscale = Upscale(image="base64_encoded_image", height=100, width=100) + with pytest.raises(APIError) as e: + await upscale.request(session=session) + assert e.type is AuthError + expected_message = 'Unauthorized.' + assert expected_message == str(e.value) + + +@pytest.mark.asyncio +async def test_unknown_error_during_upscale(): + unknown_error_response = mock.Mock() + unknown_error_response.headers = {"Content-Type": "application/json"} + unknown_error_response.status_code = 500 + unknown_error_response.json.return_value = { + "statusCode": 500, + "message": "Unknown error occurred." + } + session = mock.MagicMock(spec=AsyncSession) + session.post = mock.AsyncMock(return_value=unknown_error_response) + session.headers = {} + + upscale = Upscale(image="base64_encoded_image", height=100, width=100) + with pytest.raises(APIError) as e: + await upscale.request(session=session) + expected_message = 'Unknown error occurred.' + assert expected_message == str(e.value)