Skip to content

Commit

Permalink
Merge pull request #74 from LlmKira/dev
Browse files Browse the repository at this point in the history
(fix): If the model is a string, a cost calculation error will be raised
  • Loading branch information
sudoskys authored Sep 23, 2024
2 parents 5d0135d + 8c3bedf commit a04ca0f
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 94 deletions.
8 changes: 5 additions & 3 deletions playground/generate_image_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import base64
import os
import pathlib

from dotenv import load_dotenv
from loguru import logger
from pydantic import SecretStr

from novelai_python import APIError, Login, LoginCredential
from novelai_python import APIError, LoginCredential
from novelai_python import GenerateImageInfer, ImageGenerateResp, JwtCredential
from novelai_python.sdk.ai.generate_image import Action, Sampler
from novelai_python.utils.useful import enum_to_list
Expand Down Expand Up @@ -53,8 +54,9 @@ async def generate(
print(f"charge: {agent.calculate_cost(is_opus=True)} if you are vip3")
print(f"charge: {agent.calculate_cost(is_opus=False)} if you are not vip3")
result = await agent.request(
session=credential
session=_login_credential
)
logger.info("Using login credential")
except APIError as e:
print(f"Error: {e.message}")
return None
Expand All @@ -68,5 +70,5 @@ async def generate(


load_dotenv()
loop = asyncio.get_event_loop()
loop = asyncio.new_event_loop()
loop.run_until_complete(generate(image_path="static_refer.png"))
4 changes: 2 additions & 2 deletions playground/vibe_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ async def generate(
print(f"charge: {agent.calculate_cost(is_opus=True)} if you are vip3")
print(f"charge: {agent.calculate_cost(is_opus=False)} if you are not vip3")
result = await agent.request(
session=credential
session=_login_credential
)
except APIError as e:
print(f"Error: {e.message}")
Expand All @@ -69,5 +69,5 @@ async def generate(


load_dotenv()
loop = asyncio.get_event_loop()
loop = asyncio.new_event_loop()
loop.run_until_complete(generate())
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "novelai-python"
version = "0.4.16"
version = "0.4.17"
description = "NovelAI Python Binding With Pydantic"
authors = [
{ name = "sudoskys", email = "[email protected]" },
Expand Down
11 changes: 6 additions & 5 deletions src/novelai_python/sdk/ai/_cost.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import math
import random
from typing import List, Optional
from typing import List, Optional, Union

from pydantic import BaseModel

from novelai_python.sdk.ai._const import map, initialN, initial_n, step, newN
from novelai_python.sdk.ai._enum import Sampler, Model, ModelGroups, get_model_group
from novelai_python.sdk.ai._enum import Sampler, Model, ModelGroups, get_model_group, ModelTypeAlias


class Args(BaseModel):
Expand Down Expand Up @@ -43,9 +43,10 @@ def calculate(
is_sm_dynamic: bool,
is_account_active: bool,
sampler: Optional[Sampler],
model: Optional[Model],
model: ModelTypeAlias,
tool: str = None,
is_tool_active: bool = False) -> int:
is_tool_active: bool = False
) -> int:
return CostCalculator.calculate_cost(
Args(
height=height,
Expand All @@ -59,7 +60,7 @@ def calculate(
sampler=sampler,
tool=tool
),
model_group=get_model_group(model.value) if model else None,
model_group=get_model_group(model) if model else None,
is_account_active=is_account_active,
account_tier=account_tier,
is_tool_active=is_tool_active
Expand Down
158 changes: 82 additions & 76 deletions src/novelai_python/sdk/ai/_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# @File : _enum.py
# @Software: PyCharm
from enum import Enum, IntEnum
from typing import List
from typing import List, Optional, Union

from pydantic.dataclasses import dataclass

Expand Down Expand Up @@ -36,56 +36,6 @@ class NoiseSchedule(Enum):
POLYEXPONENTIAL = "polyexponential"


def get_supported_noise_schedule(sample_type: Sampler) -> List[NoiseSchedule]:
"""
Get supported noise schedule for a given sample type
:param sample_type: Sampler
:return: List[NoiseSchedule]
"""
if sample_type in [
Sampler.K_EULER_ANCESTRAL,
Sampler.K_DPMPP_2S_ANCESTRAL,
Sampler.K_DPMPP_2M,
Sampler.K_DPMPP_2M_SDE,
Sampler.K_DPMPP_SDE,
Sampler.K_EULER
]:
return [
NoiseSchedule.NATIVE,
NoiseSchedule.KARRAS,
NoiseSchedule.EXPONENTIAL,
NoiseSchedule.POLYEXPONENTIAL
]
elif sample_type in [Sampler.K_DPM_2]:
return [
NoiseSchedule.EXPONENTIAL,
NoiseSchedule.POLYEXPONENTIAL
]
else:
return []


def get_default_noise_schedule(sample_type: Sampler) -> NoiseSchedule:
"""
Get default noise schedule for a given sample type
:param sample_type: Sampler
:return: NoiseSchedule
"""
if sample_type in [
Sampler.K_EULER_ANCESTRAL,
Sampler.K_DPMPP_2S_ANCESTRAL,
Sampler.K_DPMPP_2M,
Sampler.K_DPMPP_2M_SDE,
Sampler.K_DPMPP_SDE,
Sampler.K_EULER
]:
return NoiseSchedule.KARRAS
elif sample_type in [Sampler.K_DPM_2]:
return NoiseSchedule.EXPONENTIAL
else:
return NoiseSchedule.NATIVE


class UCPreset(IntEnum):
TYPE0 = 0
TYPE1 = 1
Expand Down Expand Up @@ -171,30 +121,6 @@ class ModelGroups(Enum):
STABLE_DIFFUSION_XL_FURRY = "stable_diffusion_xl_furry"


def get_model_group(model: str) -> ModelGroups:
if isinstance(model, Enum):
model = model.value
mapping = {
"stable-diffusion": ModelGroups.STABLE_DIFFUSION,
"nai-diffusion": ModelGroups.STABLE_DIFFUSION,
"safe-diffusion": ModelGroups.STABLE_DIFFUSION,
"waifu-diffusion": ModelGroups.STABLE_DIFFUSION,
"nai-diffusion-furry": ModelGroups.STABLE_DIFFUSION,
"curated-diffusion-test": ModelGroups.STABLE_DIFFUSION,
"nai-diffusion-inpainting": ModelGroups.STABLE_DIFFUSION,
"safe-diffusion-inpainting": ModelGroups.STABLE_DIFFUSION,
"furry-diffusion-inpainting": ModelGroups.STABLE_DIFFUSION,
"nai-diffusion-2": ModelGroups.STABLE_DIFFUSION_GROUP_2,
"nai-diffusion-xl": ModelGroups.STABLE_DIFFUSION_XL,
"nai-diffusion-3": ModelGroups.STABLE_DIFFUSION_XL,
"nai-diffusion-3-inpainting": ModelGroups.STABLE_DIFFUSION_XL,
"custom": ModelGroups.STABLE_DIFFUSION_XL,
"nai-diffusion-furry-3": ModelGroups.STABLE_DIFFUSION_XL_FURRY,
"nai-diffusion-furry-3-inpainting": ModelGroups.STABLE_DIFFUSION_XL_FURRY,
}
return mapping.get(model, ModelGroups.STABLE_DIFFUSION)


INPAINTING_MODEL_LIST = [
Model.NAI_DIFFUSION_3_INPAINTING,
Model.NAI_DIFFUSION_INPAINTING,
Expand Down Expand Up @@ -224,8 +150,88 @@ def get_model_group(model: str) -> ModelGroups:
"Stable Diffusion XL 9CC2F394": Model.NAI_DIFFUSION_FURRY_3,
}

ModelTypeAlias = Optional[Union[Model, str]]
ImageBytesTypeAlias = Optional[Union[str, bytes]]
UCPresetTypeAlias = Optional[Union[UCPreset, int]]


def get_supported_noise_schedule(sample_type: Sampler) -> List[NoiseSchedule]:
"""
Get supported noise schedule for a given sample type
:param sample_type: Sampler
:return: List[NoiseSchedule]
"""
if sample_type in [
Sampler.K_EULER_ANCESTRAL,
Sampler.K_DPMPP_2S_ANCESTRAL,
Sampler.K_DPMPP_2M,
Sampler.K_DPMPP_2M_SDE,
Sampler.K_DPMPP_SDE,
Sampler.K_EULER
]:
return [
NoiseSchedule.NATIVE,
NoiseSchedule.KARRAS,
NoiseSchedule.EXPONENTIAL,
NoiseSchedule.POLYEXPONENTIAL
]
elif sample_type in [Sampler.K_DPM_2]:
return [
NoiseSchedule.EXPONENTIAL,
NoiseSchedule.POLYEXPONENTIAL
]
else:
return []


def get_default_noise_schedule(sample_type: Sampler) -> NoiseSchedule:
"""
Get default noise schedule for a given sample type
:param sample_type: Sampler
:return: NoiseSchedule
"""
if sample_type in [
Sampler.K_EULER_ANCESTRAL,
Sampler.K_DPMPP_2S_ANCESTRAL,
Sampler.K_DPMPP_2M,
Sampler.K_DPMPP_2M_SDE,
Sampler.K_DPMPP_SDE,
Sampler.K_EULER
]:
return NoiseSchedule.KARRAS
elif sample_type in [Sampler.K_DPM_2]:
return NoiseSchedule.EXPONENTIAL
else:
return NoiseSchedule.NATIVE


def get_model_group(model: ModelTypeAlias) -> ModelGroups:
if isinstance(model, Enum):
model = model.value
else:
model = str(model)
mapping = {
"stable-diffusion": ModelGroups.STABLE_DIFFUSION,
"nai-diffusion": ModelGroups.STABLE_DIFFUSION,
"safe-diffusion": ModelGroups.STABLE_DIFFUSION,
"waifu-diffusion": ModelGroups.STABLE_DIFFUSION,
"nai-diffusion-furry": ModelGroups.STABLE_DIFFUSION,
"curated-diffusion-test": ModelGroups.STABLE_DIFFUSION,
"nai-diffusion-inpainting": ModelGroups.STABLE_DIFFUSION,
"safe-diffusion-inpainting": ModelGroups.STABLE_DIFFUSION,
"furry-diffusion-inpainting": ModelGroups.STABLE_DIFFUSION,
"nai-diffusion-2": ModelGroups.STABLE_DIFFUSION_GROUP_2,
"nai-diffusion-xl": ModelGroups.STABLE_DIFFUSION_XL,
"nai-diffusion-3": ModelGroups.STABLE_DIFFUSION_XL,
"nai-diffusion-3-inpainting": ModelGroups.STABLE_DIFFUSION_XL,
"custom": ModelGroups.STABLE_DIFFUSION_XL,
"nai-diffusion-furry-3": ModelGroups.STABLE_DIFFUSION_XL_FURRY,
"nai-diffusion-furry-3-inpainting": ModelGroups.STABLE_DIFFUSION_XL_FURRY,
}
return mapping.get(model, ModelGroups.STABLE_DIFFUSION)


def get_default_uc_preset(model: str, uc_preset: int) -> str:
def get_default_uc_preset(model: ModelTypeAlias, uc_preset: int) -> str:
if isinstance(model, Enum):
model = model.value
if isinstance(uc_preset, Enum):
Expand Down
13 changes: 7 additions & 6 deletions src/novelai_python/sdk/ai/generate_image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@

from novelai_python.sdk.ai._cost import CostCalculator
from novelai_python.sdk.ai._enum import Model, Sampler, NoiseSchedule, ControlNetModel, Action, UCPreset, \
INPAINTING_MODEL_LIST, get_default_noise_schedule, get_supported_noise_schedule, get_default_uc_preset
INPAINTING_MODEL_LIST, get_default_noise_schedule, get_supported_noise_schedule, get_default_uc_preset, \
ModelTypeAlias, ImageBytesTypeAlias, UCPresetTypeAlias
from ...schema import ApiBaseModel
from ...._exceptions import APIError, AuthError, ConcurrentGenerationError, SessionHttpError
from ...._response.ai.generate_image import ImageGenerateResp, RequestParams
Expand All @@ -46,7 +47,7 @@ class Params(BaseModel):
叠加原始图像
防止现有图像发生更改,但可能会沿蒙版边缘引入接缝。
"""
mask: Optional[Union[str, bytes]] = None
mask: ImageBytesTypeAlias = None
"""Mask for Inpainting"""
cfg_rescale: Optional[float] = Field(0, ge=0, le=1, multiple_of=0.02)
"""Prompt Guidance Rescale"""
Expand All @@ -56,7 +57,7 @@ class Params(BaseModel):
"""Decrisp:Reduce artifacts caused by high prompt guidance values"""
height: Optional[int] = Field(1216, ge=64, le=49152)
"""Height For Generate Image"""
image: Optional[Union[str, bytes]] = None
image: ImageBytesTypeAlias = None
"""Image for img2img"""
strength: Optional[float] = Field(default=0.5, ge=0.01, le=0.99, multiple_of=0.01)
"""Strength for img2img"""
Expand Down Expand Up @@ -138,7 +139,7 @@ def reference_strength_multiple_validator(cls, v):
# TODO: find out the usage
steps: Optional[int] = Field(23, ge=1, le=50)
"""Steps"""
ucPreset: Optional[Union[UCPreset, int]] = Field(None, ge=0)
ucPreset: UCPresetTypeAlias = Field(None, ge=0)
"""The Negative Prompt Preset, Bigger or equal to 0"""
uncond_scale: Optional[float] = Field(1.0, ge=0, le=1.5, multiple_of=0.05)
"""Undesired Content Strength"""
Expand Down Expand Up @@ -324,7 +325,7 @@ def endpoint(self, value):

action: Union[str, Action] = Field(Action.GENERATE, description="Mode for img generate")
input: str = "1girl, best quality, amazing quality, very aesthetic, absurdres"
model: Optional[Union[Model, str]] = "nai-diffusion-3"
model: ModelTypeAlias = "nai-diffusion-3"
parameters: Params = Params()
model_config = ConfigDict(extra="ignore")

Expand Down Expand Up @@ -420,7 +421,7 @@ def build(cls,
model: Union[Model, str] = "nai-diffusion-3",
action: Union[Action, str] = 'generate',
negative_prompt: str = "",
ucPreset: Optional[Union[UCPreset, int]] = UCPreset.TYPE0,
ucPreset: UCPresetTypeAlias = UCPreset.TYPE0,
steps: int = 28,
seed: int = None,
scale: float = 5.0,
Expand Down
2 changes: 1 addition & 1 deletion src/novelai_python/sdk/user/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ async def request(self,
if override_headers:
sess.headers.clear()
sess.headers.update(override_headers)
logger.debug("Login")
logger.debug("Fetching login-credential")
try:
assert hasattr(sess, "post"), "session must have get method."
response = await sess.post(
Expand Down

0 comments on commit a04ca0f

Please sign in to comment.