Skip to content

Commit

Permalink
Merge pull request #4 from LlmKira/dev
Browse files Browse the repository at this point in the history
feat(novelai-python): Add user subscription endpoint
  • Loading branch information
sudoskys authored Feb 7, 2024
2 parents ae32ac3 + 739e4c7 commit bea4338
Show file tree
Hide file tree
Showing 12 changed files with 248 additions and 16 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The goal of this repository is to use Pydantic to build legitimate requests to a
### Roadmap 🚧

- [x] /ai/generate-image
- [x] /user/subscription
- [ ] /ai/generate-image/suggest-tags
- [ ] /ai/annotate-image
- [ ] /ai/classify
Expand Down
41 changes: 41 additions & 0 deletions playground/subscription.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
# @Time : 2024/2/7 上午10:16
# @Author : sudoskys
# @File : subscription.py
# @Software: PyCharm
import asyncio
import os

from dotenv import load_dotenv
from loguru import logger
from pydantic import SecretStr

from novelai_python import APIError, SubscriptionResp
from novelai_python import Subscription, JwtCredential

load_dotenv()

enhance = "year 2023,dynamic angle, best quality, amazing quality, very aesthetic, absurdres"
token = None
jwt = os.getenv("NOVELAI_JWT") or token


async def main():
globe_s = JwtCredential(jwt_token=SecretStr(jwt))
try:
_res = await Subscription().request(
session=globe_s
)
except APIError as e:
logger.exception(e)
print(e.response)
return

_res: SubscriptionResp
print(_res)
print(_res.is_active)
print(_res.anlas_left)


loop = asyncio.get_event_loop()
loop.run_until_complete(main())
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "novelai-python"
version = "0.1.7"
version = "0.1.8"
description = "Novelai Python Binding With Pydantic"
authors = [
{ name = "sudoskys", email = "[email protected]" },
Expand Down
9 changes: 7 additions & 2 deletions src/novelai_python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,19 @@
AuthError,
NovelAiError,
) # noqa: F401, F403
from ._response import ImageGenerateResp # noqa: F401, F403
from .credential import JwtCredential # noqa: F401, F403
from .sdk.ai import GenerateImageInfer # noqa: F401, F403
from .sdk.ai import GenerateImageInfer, ImageGenerateResp # noqa: F401, F403
from .sdk.user import Subscription, SubscriptionResp # noqa: F401, F403

__all__ = [
"GenerateImageInfer",
"ImageGenerateResp",

"Subscription",
"SubscriptionResp",

"JwtCredential",

"APIError",
"AuthError",
"NovelAiError",
Expand Down
5 changes: 3 additions & 2 deletions src/novelai_python/_response/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
# @File : __init__.py.py
# @Software: PyCharm
from .ai.generate_image import ImageGenerateResp

from .user.subscription import SubscriptionResp
__all__ = [
"ImageGenerateResp"
"ImageGenerateResp",
"SubscriptionResp",
]
5 changes: 5 additions & 0 deletions src/novelai_python/_response/user/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# -*- coding: utf-8 -*-
# @Time : 2024/2/7 上午9:57
# @Author : sudoskys
# @File : __init__.py.py
# @Software: PyCharm
52 changes: 52 additions & 0 deletions src/novelai_python/_response/user/subscription.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
# @Time : 2024/2/7 上午9:57
# @Author : sudoskys
# @File : subscription.py
# @Software: PyCharm
from typing import Optional, Dict, Any, List

from pydantic import BaseModel, Field


class TrainingSteps(BaseModel):
fixedTrainingStepsLeft: int
purchasedTrainingSteps: int


class ImageGenerationLimit(BaseModel):
resolution: int
maxPrompts: int


class Perks(BaseModel):
maxPriorityActions: int
startPriority: int
moduleTrainingSteps: int
unlimitedMaxPriority: bool
voiceGeneration: bool
imageGeneration: bool
unlimitedImageGeneration: bool
unlimitedImageGenerationLimits: List[ImageGenerationLimit]
contextTokens: int


class SubscriptionResp(BaseModel):
tier: int = Field(..., description="Subscription tier")
active: bool = Field(..., description="Subscription status")
expiresAt: int = Field(..., description="Subscription expiration time")
perks: Perks = Field(..., description="Subscription perks")
paymentProcessorData: Optional[Dict[Any, Any]]
trainingStepsLeft: TrainingSteps = Field(..., description="Training steps left")
accountType: int = Field(..., description="Account type")

@property
def is_active(self):
return self.active

@property
def anlas_left(self):
return self.trainingStepsLeft.fixedTrainingStepsLeft

@property
def is_unlimited_image_generation(self):
return self.perks.unlimitedImageGeneration
2 changes: 1 addition & 1 deletion src/novelai_python/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
# @Software: PyCharm

from .ai.generate_image import GenerateImageInfer, ImageGenerateResp # noqa 401

from .user.subscription import Subscription, SubscriptionResp # noqa 401
7 changes: 7 additions & 0 deletions src/novelai_python/sdk/user/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
# @Time : 2024/2/7 上午10:03
# @Author : sudoskys
# @File : __init__.py.py
# @Software: PyCharm

from .subscription import * # noqa: F403
81 changes: 81 additions & 0 deletions src/novelai_python/sdk/user/subscription.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
# @Time : 2024/2/7 上午10:04
# @Author : sudoskys
# @File : subscription.py.py
# @Software: PyCharm
from typing import Optional, Union

import httpx
from curl_cffi.requests import AsyncSession
from loguru import logger
from pydantic import BaseModel, PrivateAttr

from ... import APIError, AuthError
from ..._response.user.subscription import SubscriptionResp
from ...credential import JwtCredential
from ...utils import try_jsonfy


class Subscription(BaseModel):
_endpoint: Optional[str] = PrivateAttr("https://api.novelai.net")

@property
def base_url(self):
return f"{self.endpoint.strip('/')}/user/subscription"

@property
def endpoint(self):
return self._endpoint

@endpoint.setter
def endpoint(self, value):
self._endpoint = value

async def request(self,
session: Union[AsyncSession, JwtCredential],
) -> SubscriptionResp:
"""
Request to get user subscription information
:param session:
:return:
"""
if isinstance(session, JwtCredential):
session = session.session
request_data = {}
logger.debug("Subscription")
try:
assert hasattr(session, "get"), "session must have get method."
response = await session.get(
self.base_url,
)
if "application/json" not in response.headers.get('Content-Type'):
logger.error(f"Unexpected content type: {response.headers.get('Content-Type')}")
try:
_msg = response.json()
except Exception:
raise APIError(
message=f"Unexpected content type: {response.headers.get('Content-Type')}",
request=request_data,
status_code=response.status_code,
response=try_jsonfy(response.content)
)
status_code = _msg.get("statusCode", response.status_code)
message = _msg.get("message", "Unknown error")
if status_code in [400, 401, 402]:
# 400 : validation error
# 401 : unauthorized
# 402 : payment required
# 409 : conflict
raise AuthError(message, request=request_data, status_code=status_code, response=_msg)
if status_code in [500]:
# An unknown error occured.
raise APIError(message, request=request_data, status_code=status_code, response=_msg)
raise APIError(message, request=request_data, status_code=status_code, response=_msg)
return SubscriptionResp.model_validate(response.json())
except httpx.HTTPError as exc:
raise RuntimeError(f"An HTTP error occurred: {exc}")
except APIError as e:
raise e
except Exception as e:
logger.opt(exception=e).exception("An Unexpected error occurred")
raise e
21 changes: 21 additions & 0 deletions src/novelai_python/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
import uvicorn
from fastapi import FastAPI, Depends, Security
from fastapi.security import APIKeyHeader
from loguru import logger
from starlette.responses import JSONResponse, StreamingResponse

from .credential import JwtCredential, SecretStr
from .sdk.ai.generate_image import GenerateImageInfer
from .sdk.user.subscription import Subscription

app = FastAPI()
token_key = APIKeyHeader(name="Authorization")
Expand All @@ -35,6 +37,24 @@ async def health():
return {"status": "ok"}


@app.get("/user/subscription")
async def subscription(
current_token: str = Depends(get_current_token)
):
"""
订阅
:param current_token: Authorization
:param req: Subscription
:return:
"""
try:
_result = await Subscription().request(session=get_session(current_token))
return _result.model_dump()
except Exception as e:
logger.exception(e)
return JSONResponse(status_code=500, content=e.__dict__)


@app.post("/ai/generate_image")
async def generate_image(
req: GenerateImageInfer,
Expand All @@ -58,6 +78,7 @@ async def generate_image(
'Content-Disposition': 'attachment;filename=multiple_files.zip'
})
except Exception as e:
logger.exception(e)
return JSONResponse(status_code=500, content=e.__dict__)


Expand Down
38 changes: 28 additions & 10 deletions tests/test_server_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
# @Author : sudoskys
# @File : test_server_run.py
# @Software: PyCharm

from unittest.mock import patch

from fastapi.testclient import TestClient

from src.novelai_python.sdk.ai.generate_image import GenerateImageInfer
from src.novelai_python.server import app, get_session

client = TestClient(app)
Expand All @@ -17,12 +19,28 @@ def test_health_check():
assert response.json() == {"status": "ok"}


def test_generate_image_with_valid_token():
valid_token = "valid_token"
get_session(valid_token) # to simulate a valid session
response = client.post(
"/ai/generate_image",
headers={"Authorization": valid_token},
json=GenerateImageInfer(input="1girl").model_dump()
)
assert response.status_code == 500
@patch('src.novelai_python.server.Subscription')
def test_subscription_without_api_key(mock_subscription):
mock_subscription.return_value.request.return_value.model_dump.return_value = {"status": "subscribed"}
response = client.get("/user/subscription")
assert response.status_code == 403


@patch('src.novelai_python.server.GenerateImageInfer')
def test_generate_image_without_api_key(mock_generate_image):
mock_generate_image.return_value.generate.return_value = {"status": "image generated"}
response = client.post("/ai/generate_image")
assert response.status_code == 403


def test_get_session_new_token():
token = "new_token"
session = get_session(token)
assert session.jwt_token.get_secret_value() == token


def test_get_session_existing_token():
token = "existing_token"
get_session(token)
session = get_session(token)
assert session.jwt_token.get_secret_value() == token

0 comments on commit bea4338

Please sign in to comment.