Skip to content

Commit

Permalink
Merge pull request #19 from LlmKira/dev
Browse files Browse the repository at this point in the history
/ai/upscale
  • Loading branch information
sudoskys authored Feb 13, 2024
2 parents b2f3b2b + 571ec3c commit 7095acd
Show file tree
Hide file tree
Showing 18 changed files with 419 additions and 19 deletions.
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# novelai-python

[![PyPI version](https://badge.fury.io/py/novelai_python.svg)](https://badge.fury.io/py/novelai_python)
[![PyPI version](https://badge.fury.io/py/novelai-python.svg)](https://badge.fury.io/py/novelai-python)
[![Downloads](https://pepy.tech/badge/novelai_python)](https://pepy.tech/project/novelai_python)

The goal of this repository is to use Pydantic to build legitimate requests to access the Novelai API service.
Expand All @@ -11,16 +11,20 @@ The goal of this repository is to use Pydantic to build legitimate requests to a
- [x] /user/subscription
- [x] /user/login
- [x] /user/information
- [x] /ai/upscale
- [ ] /ai/generate-image/suggest-tags
- [ ] /ai/annotate-image
- [ ] /ai/classify
- [ ] /ai/upscale
- [ ] /ai/generate-prompt
- [ ] /ai/generate
- [ ] /ai/generate-voice

### Usage 🖥️

```shell
pip install novelai-python
```

More examples can be found in the [playground](https://github.com/LlmKira/novelai-python/tree/main/playground) directory.

```python
Expand Down
4 changes: 3 additions & 1 deletion playground/generate_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from novelai_python import APIError, Login
from novelai_python import GenerateImageInfer, ImageGenerateResp, JwtCredential
from novelai_python.sdk.ai.generate_image import Action
from novelai_python.utils.useful import enum_to_list

load_dotenv()

Expand All @@ -25,13 +26,14 @@ async def main():
_res = await Login.build(user_name=os.getenv("NOVELAI_USER"), password=os.getenv("NOVELAI_PASS")
).request()
try:
print(f"Action List:{enum_to_list(Action)}")
gen = GenerateImageInfer.build(
prompt=f"1girl, winter, jacket, sfw, angel, flower,{enhance}",
action=Action.GENERATE,
)
cost = gen.calculate_cost(is_opus=True)
print(f"charge: {cost} if you are vip3")
print(f"charge: {gen.calculate_cost(is_opus=True)}")
print(f"charge: {gen.calculate_cost(is_opus=False)} if you are not vip3")
_res = await gen.request(
session=globe_s, remove_sign=True
)
Expand Down
55 changes: 55 additions & 0 deletions playground/upscale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
# @Time : 2024/2/13 上午11:58
# @Author : sudoskys
# @File : upscale.py
# @Software: PyCharm
# -*- coding: utf-8 -*-
# @Time : 2024/2/14 下午5:20
# @Author : sudoskys
# @File : upscale_demo.py
# @Software: PyCharm
# To run the demo, you need an event loop, for instance by using asyncio
import asyncio
import os

from dotenv import load_dotenv
from pydantic import SecretStr

from novelai_python import APIError, Upscale
from novelai_python import UpscaleResp, JwtCredential
from novelai_python.sdk.ai.generate_image import Action
from novelai_python.utils.useful import enum_to_list

load_dotenv()

token = None
jwt = os.getenv("NOVELAI_JWT") or token


async def main():
globe_s = JwtCredential(jwt_token=SecretStr(jwt))
if not os.path.exists("generate_image.png"):
raise FileNotFoundError("generate_image.png not found")
with open("generate_image.png", "rb") as f:
data = f.read()
try:
print(f"Action List:{enum_to_list(Action)}")
upscale = Upscale(image=data) # Auto detect image size | base64

_res = await upscale.request(
session=globe_s, remove_sign=True
)
except APIError as e:
print(e.response)
return

# Meta
_res: UpscaleResp
print(_res.meta.endpoint)
file = _res.files
with open("upscale.py.png", "wb") as f:
f.write(file[1])


loop = asyncio.get_event_loop()
loop.run_until_complete(main())
4 changes: 4 additions & 0 deletions src/novelai_python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@
from .sdk import Information, InformationResp
from .sdk import Login, LoginResp
from .sdk import Subscription, SubscriptionResp
from .sdk import Upscale, UpscaleResp

__all__ = [
"GenerateImageInfer",
"ImageGenerateResp",

"Upscale",
"UpscaleResp",

"Subscription",
"SubscriptionResp",

Expand Down
4 changes: 3 additions & 1 deletion src/novelai_python/_response/ai/generate_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

from pydantic import BaseModel

from ..schema import RespBase

class ImageGenerateResp(BaseModel):

class ImageGenerateResp(RespBase):
class RequestParams(BaseModel):
endpoint: str
raw_request: dict = None
Expand Down
19 changes: 19 additions & 0 deletions src/novelai_python/_response/ai/upscale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# -*- coding: utf-8 -*-
# @Time : 2024/2/13 上午11:29
# @Author : sudoskys
# @File : upscale.py
# @Software: PyCharm
from typing import Tuple

from pydantic import BaseModel

from ..schema import RespBase


class UpscaleResp(RespBase):
class RequestParams(BaseModel):
endpoint: str
raw_request: dict = None

meta: RequestParams
files: Tuple[str, bytes] = None
11 changes: 11 additions & 0 deletions src/novelai_python/_response/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-
# @Time : 2024/2/13 上午11:30
# @Author : sudoskys
# @File : schema.py
# @Software: PyCharm

from pydantic import BaseModel


class RespBase(BaseModel):
pass
6 changes: 4 additions & 2 deletions src/novelai_python/_response/user/information.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
# @File : information.py
# @Software: PyCharm

from pydantic import BaseModel, Field
from pydantic import Field

from ..schema import RespBase

class InformationResp(BaseModel):

class InformationResp(RespBase):
emailVerified: bool = Field(..., description="Email verification status")
emailVerificationLetterSent: bool = Field(..., description="Email verification letter sent status")
trialActivated: bool = Field(..., description="Trial activation status")
Expand Down
5 changes: 3 additions & 2 deletions src/novelai_python/_response/user/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
# @Author : sudoskys
# @File : login.py
# @Software: PyCharm
from pydantic import BaseModel

from ..schema import RespBase

class LoginResp(BaseModel):

class LoginResp(RespBase):
accessToken: str
4 changes: 3 additions & 1 deletion src/novelai_python/_response/user/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from pydantic import BaseModel, Field

from ..schema import RespBase


class TrainingSteps(BaseModel):
fixedTrainingStepsLeft: int
Expand All @@ -30,7 +32,7 @@ class Perks(BaseModel):
contextTokens: int


class SubscriptionResp(BaseModel):
class SubscriptionResp(RespBase):
tier: int = Field(..., description="Subscription tier")
active: bool = Field(..., description="Subscription status")
expiresAt: int = Field(..., description="Subscription expiration time")
Expand Down
1 change: 1 addition & 0 deletions src/novelai_python/credential/JwtToken.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ async def get_session(self, timeout: int = 180, update_headers: dict = None):
self._session = AsyncSession(timeout=timeout, headers={
"Accept": "*/*",
"Accept-Language": "zh-CN,zh;q=0.8,zh-TW;q=0.7,zh-HK;q=0.5,en-US;q=0.3,en;q=0.2",
"User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:123.0) Gecko/20100101 Firefox/123.0",
"Accept-Encoding": "gzip, deflate, br",
"Authorization": f"Bearer {self.jwt_token.get_secret_value()}",
"Content-Type": "application/json",
Expand Down
1 change: 1 addition & 0 deletions src/novelai_python/credential/UserAuth.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ async def get_session(self, timeout: int = 180, update_headers: dict = None):
resp = await Login.build(user_name=self.username, password=self.password.get_secret_value()).request()
self._session = AsyncSession(timeout=timeout, headers={
"Accept": "*/*",
"User-Agent": "Mozilla/5.0 (X11; Linux x86_64; rv:123.0) Gecko/20100101 Firefox/123.0",
"Accept-Language": "zh-CN,zh;q=0.8,zh-TW;q=0.7,zh-HK;q=0.5,en-US;q=0.3,en;q=0.2",
"Accept-Encoding": "gzip, deflate, br",
"Authorization": f"Bearer {resp.accessToken}",
Expand Down
1 change: 1 addition & 0 deletions src/novelai_python/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .user.information import Information, InformationResp # noqa 401
from .user.login import Login, LoginResp # noqa 401
from .user.subscription import Subscription, SubscriptionResp # noqa 401
from .ai.upscale import Upscale, UpscaleResp # noqa 401
23 changes: 19 additions & 4 deletions src/novelai_python/sdk/ai/generate_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from ..schema import ApiBaseModel
from ..._exceptions import APIError, AuthError, ConcurrentGenerationError, SessionHttpError
from ..._response import ImageGenerateResp
from ..._response.ai.generate_image import ImageGenerateResp
from ...credential import CredentialBase
from ...utils import try_jsonfy, NovelAiMetadata

Expand Down Expand Up @@ -102,7 +102,6 @@ class Resolution(Enum):

class GenerateImageInfer(ApiBaseModel):
_endpoint: Optional[str] = PrivateAttr("https://api.novelai.net")
_charge: bool = PrivateAttr(False)

class Params(BaseModel):
# Inpaint
Expand Down Expand Up @@ -412,14 +411,30 @@ async def request(self,
request_data = self.model_dump(mode="json", exclude_none=True)
# Header
if isinstance(session, AsyncSession):
session.headers.update(self.necessary_headers(request_data))
session.headers.update(await self.necessary_headers(request_data))
elif isinstance(session, CredentialBase):
update_header = await self.necessary_headers(request_data)
session = await session.get_session(update_headers=update_header)
if override_headers:
session.headers.clear()
session.headers.update(override_headers)
logger.debug(f"Request Data: {request_data}")
try:
_log_data = request_data.copy()
if self.action == Action.GENERATE:
logger.debug(f"Request Data: {_log_data}")
else:
_log_data.get("parameters", {}).update({
"image": "base64 data" if self.parameters.image else "None",
}
)
_log_data.get("parameters", {}).update(
{
"mask": "base64 data" if self.parameters.mask else "None",
}
)
logger.debug(f"Request Data: {request_data}")
except Exception as e:
logger.warning(f"Error when print log data: {e}")
try:
assert hasattr(session, "post"), "session must have post method."
response = await session.post(
Expand Down
Loading

0 comments on commit 7095acd

Please sign in to comment.