From 0b5319670047451526cd1c143c6fcd1332251763 Mon Sep 17 00:00:00 2001 From: sudoskys Date: Wed, 7 Feb 2024 10:29:04 +0800 Subject: [PATCH 1/2] :sparkles: feat: Add subscription endpoint Added a new endpoint `/user/subscription` to retrieve user subscription information. This endpoint makes use of the `Subscription` class from the `sdk.user` module. --- playground/subscription.py | 41 ++++++++++ pyproject.toml | 2 +- src/novelai_python/__init__.py | 9 ++- src/novelai_python/_response/__init__.py | 5 +- src/novelai_python/_response/user/__init__.py | 5 ++ .../_response/user/subscription.py | 52 ++++++++++++ src/novelai_python/sdk/__init__.py | 2 +- src/novelai_python/sdk/user/__init__.py | 7 ++ src/novelai_python/sdk/user/subscription.py | 81 +++++++++++++++++++ src/novelai_python/server.py | 21 +++++ tests/test_server_run.py | 38 ++++++--- 11 files changed, 247 insertions(+), 16 deletions(-) create mode 100644 playground/subscription.py create mode 100644 src/novelai_python/_response/user/__init__.py create mode 100644 src/novelai_python/_response/user/subscription.py create mode 100644 src/novelai_python/sdk/user/__init__.py create mode 100644 src/novelai_python/sdk/user/subscription.py diff --git a/playground/subscription.py b/playground/subscription.py new file mode 100644 index 0000000..6ed7c74 --- /dev/null +++ b/playground/subscription.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# @Time : 2024/2/7 上午10:16 +# @Author : sudoskys +# @File : subscription.py +# @Software: PyCharm +import asyncio +import os + +from dotenv import load_dotenv +from loguru import logger +from pydantic import SecretStr + +from novelai_python import APIError, SubscriptionResp +from novelai_python import Subscription, JwtCredential + +load_dotenv() + +enhance = "year 2023,dynamic angle, best quality, amazing quality, very aesthetic, absurdres" +token = None +jwt = os.getenv("NOVELAI_JWT") or token + + +async def main(): + globe_s = JwtCredential(jwt_token=SecretStr(jwt)) + try: + _res = await Subscription().request( + session=globe_s + ) + except APIError as e: + logger.exception(e) + print(e.response) + return + + _res: SubscriptionResp + print(_res) + print(_res.is_active) + print(_res.anlas_left) + + +loop = asyncio.get_event_loop() +loop.run_until_complete(main()) diff --git a/pyproject.toml b/pyproject.toml index 6fd2709..ef9be76 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "novelai-python" -version = "0.1.7" +version = "0.1.8" description = "Novelai Python Binding With Pydantic" authors = [ { name = "sudoskys", email = "coldlando@hotmail.com" }, diff --git a/src/novelai_python/__init__.py b/src/novelai_python/__init__.py index 11bd4dc..8d2521d 100644 --- a/src/novelai_python/__init__.py +++ b/src/novelai_python/__init__.py @@ -9,14 +9,19 @@ AuthError, NovelAiError, ) # noqa: F401, F403 -from ._response import ImageGenerateResp # noqa: F401, F403 from .credential import JwtCredential # noqa: F401, F403 -from .sdk.ai import GenerateImageInfer # noqa: F401, F403 +from .sdk.ai import GenerateImageInfer, ImageGenerateResp # noqa: F401, F403 +from .sdk.user import Subscription, SubscriptionResp # noqa: F401, F403 __all__ = [ "GenerateImageInfer", "ImageGenerateResp", + + "Subscription", + "SubscriptionResp", + "JwtCredential", + "APIError", "AuthError", "NovelAiError", diff --git a/src/novelai_python/_response/__init__.py b/src/novelai_python/_response/__init__.py index 4c30a65..7f694ba 100644 --- a/src/novelai_python/_response/__init__.py +++ b/src/novelai_python/_response/__init__.py @@ -4,7 +4,8 @@ # @File : __init__.py.py # @Software: PyCharm from .ai.generate_image import ImageGenerateResp - +from .user.subscription import SubscriptionResp __all__ = [ - "ImageGenerateResp" + "ImageGenerateResp", + "SubscriptionResp", ] diff --git a/src/novelai_python/_response/user/__init__.py b/src/novelai_python/_response/user/__init__.py new file mode 100644 index 0000000..3cbe0b7 --- /dev/null +++ b/src/novelai_python/_response/user/__init__.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- +# @Time : 2024/2/7 上午9:57 +# @Author : sudoskys +# @File : __init__.py.py +# @Software: PyCharm diff --git a/src/novelai_python/_response/user/subscription.py b/src/novelai_python/_response/user/subscription.py new file mode 100644 index 0000000..35e6455 --- /dev/null +++ b/src/novelai_python/_response/user/subscription.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# @Time : 2024/2/7 上午9:57 +# @Author : sudoskys +# @File : subscription.py +# @Software: PyCharm +from typing import Optional, Dict, Any, List + +from pydantic import BaseModel, Field + + +class TrainingSteps(BaseModel): + fixedTrainingStepsLeft: int + purchasedTrainingSteps: int + + +class ImageGenerationLimit(BaseModel): + resolution: int + maxPrompts: int + + +class Perks(BaseModel): + maxPriorityActions: int + startPriority: int + moduleTrainingSteps: int + unlimitedMaxPriority: bool + voiceGeneration: bool + imageGeneration: bool + unlimitedImageGeneration: bool + unlimitedImageGenerationLimits: List[ImageGenerationLimit] + contextTokens: int + + +class SubscriptionResp(BaseModel): + tier: int = Field(..., description="Subscription tier") + active: bool = Field(..., description="Subscription status") + expiresAt: int = Field(..., description="Subscription expiration time") + perks: Perks = Field(..., description="Subscription perks") + paymentProcessorData: Optional[Dict[Any, Any]] + trainingStepsLeft: TrainingSteps = Field(..., description="Training steps left") + accountType: int = Field(..., description="Account type") + + @property + def is_active(self): + return self.active + + @property + def anlas_left(self): + return self.trainingStepsLeft.fixedTrainingStepsLeft + + @property + def is_unlimited_image_generation(self): + return self.perks.unlimitedImageGeneration diff --git a/src/novelai_python/sdk/__init__.py b/src/novelai_python/sdk/__init__.py index 3349aff..bdff43d 100644 --- a/src/novelai_python/sdk/__init__.py +++ b/src/novelai_python/sdk/__init__.py @@ -5,4 +5,4 @@ # @Software: PyCharm from .ai.generate_image import GenerateImageInfer, ImageGenerateResp # noqa 401 - +from .user.subscription import Subscription, SubscriptionResp # noqa 401 diff --git a/src/novelai_python/sdk/user/__init__.py b/src/novelai_python/sdk/user/__init__.py new file mode 100644 index 0000000..d049979 --- /dev/null +++ b/src/novelai_python/sdk/user/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- +# @Time : 2024/2/7 上午10:03 +# @Author : sudoskys +# @File : __init__.py.py +# @Software: PyCharm + +from .subscription import * # noqa: F403 diff --git a/src/novelai_python/sdk/user/subscription.py b/src/novelai_python/sdk/user/subscription.py new file mode 100644 index 0000000..4dc311c --- /dev/null +++ b/src/novelai_python/sdk/user/subscription.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# @Time : 2024/2/7 上午10:04 +# @Author : sudoskys +# @File : subscription.py.py +# @Software: PyCharm +from typing import Optional, Union + +import httpx +from curl_cffi.requests import AsyncSession +from loguru import logger +from pydantic import BaseModel, PrivateAttr + +from ... import APIError, AuthError +from ..._response.user.subscription import SubscriptionResp +from ...credential import JwtCredential +from ...utils import try_jsonfy + + +class Subscription(BaseModel): + _endpoint: Optional[str] = PrivateAttr("https://api.novelai.net") + + @property + def base_url(self): + return f"{self.endpoint.strip('/')}/user/subscription" + + @property + def endpoint(self): + return self._endpoint + + @endpoint.setter + def endpoint(self, value): + self._endpoint = value + + async def request(self, + session: Union[AsyncSession, JwtCredential], + ) -> SubscriptionResp: + """ + Request to get user subscription information + :param session: + :return: + """ + if isinstance(session, JwtCredential): + session = session.session + request_data = {} + logger.debug("Subscription") + try: + assert hasattr(session, "get"), "session must have get method." + response = await session.get( + self.base_url, + ) + if "application/json" not in response.headers.get('Content-Type'): + logger.error(f"Unexpected content type: {response.headers.get('Content-Type')}") + try: + _msg = response.json() + except Exception: + raise APIError( + message=f"Unexpected content type: {response.headers.get('Content-Type')}", + request=request_data, + status_code=response.status_code, + response=try_jsonfy(response.content) + ) + status_code = _msg.get("statusCode", response.status_code) + message = _msg.get("message", "Unknown error") + if status_code in [400, 401, 402]: + # 400 : validation error + # 401 : unauthorized + # 402 : payment required + # 409 : conflict + raise AuthError(message, request=request_data, status_code=status_code, response=_msg) + if status_code in [500]: + # An unknown error occured. + raise APIError(message, request=request_data, status_code=status_code, response=_msg) + raise APIError(message, request=request_data, status_code=status_code, response=_msg) + return SubscriptionResp.model_validate(response.json()) + except httpx.HTTPError as exc: + raise RuntimeError(f"An HTTP error occurred: {exc}") + except APIError as e: + raise e + except Exception as e: + logger.opt(exception=e).exception("An Unexpected error occurred") + raise e diff --git a/src/novelai_python/server.py b/src/novelai_python/server.py index 7f62cb5..3566191 100644 --- a/src/novelai_python/server.py +++ b/src/novelai_python/server.py @@ -10,10 +10,12 @@ import uvicorn from fastapi import FastAPI, Depends, Security from fastapi.security import APIKeyHeader +from loguru import logger from starlette.responses import JSONResponse, StreamingResponse from .credential import JwtCredential, SecretStr from .sdk.ai.generate_image import GenerateImageInfer +from .sdk.user.subscription import Subscription app = FastAPI() token_key = APIKeyHeader(name="Authorization") @@ -35,6 +37,24 @@ async def health(): return {"status": "ok"} +@app.get("/user/subscription") +async def subscription( + current_token: str = Depends(get_current_token) +): + """ + 订阅 + :param current_token: Authorization + :param req: Subscription + :return: + """ + try: + _result = await Subscription().request(session=get_session(current_token)) + return _result.model_dump() + except Exception as e: + logger.exception(e) + return JSONResponse(status_code=500, content=e.__dict__) + + @app.post("/ai/generate_image") async def generate_image( req: GenerateImageInfer, @@ -58,6 +78,7 @@ async def generate_image( 'Content-Disposition': 'attachment;filename=multiple_files.zip' }) except Exception as e: + logger.exception(e) return JSONResponse(status_code=500, content=e.__dict__) diff --git a/tests/test_server_run.py b/tests/test_server_run.py index dc6d885..ef7e1b7 100644 --- a/tests/test_server_run.py +++ b/tests/test_server_run.py @@ -3,9 +3,11 @@ # @Author : sudoskys # @File : test_server_run.py # @Software: PyCharm + +from unittest.mock import patch + from fastapi.testclient import TestClient -from src.novelai_python.sdk.ai.generate_image import GenerateImageInfer from src.novelai_python.server import app, get_session client = TestClient(app) @@ -17,12 +19,28 @@ def test_health_check(): assert response.json() == {"status": "ok"} -def test_generate_image_with_valid_token(): - valid_token = "valid_token" - get_session(valid_token) # to simulate a valid session - response = client.post( - "/ai/generate_image", - headers={"Authorization": valid_token}, - json=GenerateImageInfer(input="1girl").model_dump() - ) - assert response.status_code == 500 +@patch('src.novelai_python.server.Subscription') +def test_subscription_without_api_key(mock_subscription): + mock_subscription.return_value.request.return_value.model_dump.return_value = {"status": "subscribed"} + response = client.get("/user/subscription") + assert response.status_code == 403 + + +@patch('src.novelai_python.server.GenerateImageInfer') +def test_generate_image_without_api_key(mock_generate_image): + mock_generate_image.return_value.generate.return_value = {"status": "image generated"} + response = client.post("/ai/generate_image") + assert response.status_code == 403 + + +def test_get_session_new_token(): + token = "new_token" + session = get_session(token) + assert session.jwt_token.get_secret_value() == token + + +def test_get_session_existing_token(): + token = "existing_token" + get_session(token) + session = get_session(token) + assert session.jwt_token.get_secret_value() == token From 739e4c7f9bea1b216ff7a491a5e306200e0ea7b8 Mon Sep 17 00:00:00 2001 From: sudoskys Date: Wed, 7 Feb 2024 10:30:01 +0800 Subject: [PATCH 2/2] :sparkles: feat(novelai-python): Add user subscription endpoint --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index b0565e3..dea5219 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ The goal of this repository is to use Pydantic to build legitimate requests to a ### Roadmap 🚧 - [x] /ai/generate-image +- [x] /user/subscription - [ ] /ai/generate-image/suggest-tags - [ ] /ai/annotate-image - [ ] /ai/classify