diff --git a/playground/generate_image_img2img.py b/playground/generate_image_img2img.py index fe58766..97cc120 100755 --- a/playground/generate_image_img2img.py +++ b/playground/generate_image_img2img.py @@ -7,11 +7,12 @@ import base64 import os import pathlib + from dotenv import load_dotenv from loguru import logger from pydantic import SecretStr -from novelai_python import APIError, Login, LoginCredential +from novelai_python import APIError, LoginCredential from novelai_python import GenerateImageInfer, ImageGenerateResp, JwtCredential from novelai_python.sdk.ai.generate_image import Action, Sampler from novelai_python.utils.useful import enum_to_list @@ -53,8 +54,9 @@ async def generate( print(f"charge: {agent.calculate_cost(is_opus=True)} if you are vip3") print(f"charge: {agent.calculate_cost(is_opus=False)} if you are not vip3") result = await agent.request( - session=credential + session=_login_credential ) + logger.info("Using login credential") except APIError as e: print(f"Error: {e.message}") return None @@ -68,5 +70,5 @@ async def generate( load_dotenv() -loop = asyncio.get_event_loop() +loop = asyncio.new_event_loop() loop.run_until_complete(generate(image_path="static_refer.png")) diff --git a/playground/vibe_img2img.py b/playground/vibe_img2img.py index f06e111..ab13145 100644 --- a/playground/vibe_img2img.py +++ b/playground/vibe_img2img.py @@ -55,7 +55,7 @@ async def generate( print(f"charge: {agent.calculate_cost(is_opus=True)} if you are vip3") print(f"charge: {agent.calculate_cost(is_opus=False)} if you are not vip3") result = await agent.request( - session=credential + session=_login_credential ) except APIError as e: print(f"Error: {e.message}") @@ -69,5 +69,5 @@ async def generate( load_dotenv() -loop = asyncio.get_event_loop() +loop = asyncio.new_event_loop() loop.run_until_complete(generate()) diff --git a/pyproject.toml b/pyproject.toml index 2db000b..67e9c59 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "novelai-python" -version = "0.4.16" +version = "0.4.17" description = "NovelAI Python Binding With Pydantic" authors = [ { name = "sudoskys", email = "coldlando@hotmail.com" }, diff --git a/src/novelai_python/sdk/ai/_cost.py b/src/novelai_python/sdk/ai/_cost.py index a0a8bd5..353a278 100644 --- a/src/novelai_python/sdk/ai/_cost.py +++ b/src/novelai_python/sdk/ai/_cost.py @@ -1,11 +1,11 @@ import math import random -from typing import List, Optional +from typing import List, Optional, Union from pydantic import BaseModel from novelai_python.sdk.ai._const import map, initialN, initial_n, step, newN -from novelai_python.sdk.ai._enum import Sampler, Model, ModelGroups, get_model_group +from novelai_python.sdk.ai._enum import Sampler, Model, ModelGroups, get_model_group, ModelTypeAlias class Args(BaseModel): @@ -43,9 +43,10 @@ def calculate( is_sm_dynamic: bool, is_account_active: bool, sampler: Optional[Sampler], - model: Optional[Model], + model: ModelTypeAlias, tool: str = None, - is_tool_active: bool = False) -> int: + is_tool_active: bool = False + ) -> int: return CostCalculator.calculate_cost( Args( height=height, @@ -59,7 +60,7 @@ def calculate( sampler=sampler, tool=tool ), - model_group=get_model_group(model.value) if model else None, + model_group=get_model_group(model) if model else None, is_account_active=is_account_active, account_tier=account_tier, is_tool_active=is_tool_active diff --git a/src/novelai_python/sdk/ai/_enum.py b/src/novelai_python/sdk/ai/_enum.py index 85f7518..dc1d184 100644 --- a/src/novelai_python/sdk/ai/_enum.py +++ b/src/novelai_python/sdk/ai/_enum.py @@ -4,7 +4,7 @@ # @File : _enum.py # @Software: PyCharm from enum import Enum, IntEnum -from typing import List +from typing import List, Optional, Union from pydantic.dataclasses import dataclass @@ -36,56 +36,6 @@ class NoiseSchedule(Enum): POLYEXPONENTIAL = "polyexponential" -def get_supported_noise_schedule(sample_type: Sampler) -> List[NoiseSchedule]: - """ - Get supported noise schedule for a given sample type - :param sample_type: Sampler - :return: List[NoiseSchedule] - """ - if sample_type in [ - Sampler.K_EULER_ANCESTRAL, - Sampler.K_DPMPP_2S_ANCESTRAL, - Sampler.K_DPMPP_2M, - Sampler.K_DPMPP_2M_SDE, - Sampler.K_DPMPP_SDE, - Sampler.K_EULER - ]: - return [ - NoiseSchedule.NATIVE, - NoiseSchedule.KARRAS, - NoiseSchedule.EXPONENTIAL, - NoiseSchedule.POLYEXPONENTIAL - ] - elif sample_type in [Sampler.K_DPM_2]: - return [ - NoiseSchedule.EXPONENTIAL, - NoiseSchedule.POLYEXPONENTIAL - ] - else: - return [] - - -def get_default_noise_schedule(sample_type: Sampler) -> NoiseSchedule: - """ - Get default noise schedule for a given sample type - :param sample_type: Sampler - :return: NoiseSchedule - """ - if sample_type in [ - Sampler.K_EULER_ANCESTRAL, - Sampler.K_DPMPP_2S_ANCESTRAL, - Sampler.K_DPMPP_2M, - Sampler.K_DPMPP_2M_SDE, - Sampler.K_DPMPP_SDE, - Sampler.K_EULER - ]: - return NoiseSchedule.KARRAS - elif sample_type in [Sampler.K_DPM_2]: - return NoiseSchedule.EXPONENTIAL - else: - return NoiseSchedule.NATIVE - - class UCPreset(IntEnum): TYPE0 = 0 TYPE1 = 1 @@ -171,30 +121,6 @@ class ModelGroups(Enum): STABLE_DIFFUSION_XL_FURRY = "stable_diffusion_xl_furry" -def get_model_group(model: str) -> ModelGroups: - if isinstance(model, Enum): - model = model.value - mapping = { - "stable-diffusion": ModelGroups.STABLE_DIFFUSION, - "nai-diffusion": ModelGroups.STABLE_DIFFUSION, - "safe-diffusion": ModelGroups.STABLE_DIFFUSION, - "waifu-diffusion": ModelGroups.STABLE_DIFFUSION, - "nai-diffusion-furry": ModelGroups.STABLE_DIFFUSION, - "curated-diffusion-test": ModelGroups.STABLE_DIFFUSION, - "nai-diffusion-inpainting": ModelGroups.STABLE_DIFFUSION, - "safe-diffusion-inpainting": ModelGroups.STABLE_DIFFUSION, - "furry-diffusion-inpainting": ModelGroups.STABLE_DIFFUSION, - "nai-diffusion-2": ModelGroups.STABLE_DIFFUSION_GROUP_2, - "nai-diffusion-xl": ModelGroups.STABLE_DIFFUSION_XL, - "nai-diffusion-3": ModelGroups.STABLE_DIFFUSION_XL, - "nai-diffusion-3-inpainting": ModelGroups.STABLE_DIFFUSION_XL, - "custom": ModelGroups.STABLE_DIFFUSION_XL, - "nai-diffusion-furry-3": ModelGroups.STABLE_DIFFUSION_XL_FURRY, - "nai-diffusion-furry-3-inpainting": ModelGroups.STABLE_DIFFUSION_XL_FURRY, - } - return mapping.get(model, ModelGroups.STABLE_DIFFUSION) - - INPAINTING_MODEL_LIST = [ Model.NAI_DIFFUSION_3_INPAINTING, Model.NAI_DIFFUSION_INPAINTING, @@ -224,8 +150,88 @@ def get_model_group(model: str) -> ModelGroups: "Stable Diffusion XL 9CC2F394": Model.NAI_DIFFUSION_FURRY_3, } +ModelTypeAlias = Optional[Union[Model, str]] +ImageBytesTypeAlias = Optional[Union[str, bytes]] +UCPresetTypeAlias = Optional[Union[UCPreset, int]] + + +def get_supported_noise_schedule(sample_type: Sampler) -> List[NoiseSchedule]: + """ + Get supported noise schedule for a given sample type + :param sample_type: Sampler + :return: List[NoiseSchedule] + """ + if sample_type in [ + Sampler.K_EULER_ANCESTRAL, + Sampler.K_DPMPP_2S_ANCESTRAL, + Sampler.K_DPMPP_2M, + Sampler.K_DPMPP_2M_SDE, + Sampler.K_DPMPP_SDE, + Sampler.K_EULER + ]: + return [ + NoiseSchedule.NATIVE, + NoiseSchedule.KARRAS, + NoiseSchedule.EXPONENTIAL, + NoiseSchedule.POLYEXPONENTIAL + ] + elif sample_type in [Sampler.K_DPM_2]: + return [ + NoiseSchedule.EXPONENTIAL, + NoiseSchedule.POLYEXPONENTIAL + ] + else: + return [] + + +def get_default_noise_schedule(sample_type: Sampler) -> NoiseSchedule: + """ + Get default noise schedule for a given sample type + :param sample_type: Sampler + :return: NoiseSchedule + """ + if sample_type in [ + Sampler.K_EULER_ANCESTRAL, + Sampler.K_DPMPP_2S_ANCESTRAL, + Sampler.K_DPMPP_2M, + Sampler.K_DPMPP_2M_SDE, + Sampler.K_DPMPP_SDE, + Sampler.K_EULER + ]: + return NoiseSchedule.KARRAS + elif sample_type in [Sampler.K_DPM_2]: + return NoiseSchedule.EXPONENTIAL + else: + return NoiseSchedule.NATIVE + + +def get_model_group(model: ModelTypeAlias) -> ModelGroups: + if isinstance(model, Enum): + model = model.value + else: + model = str(model) + mapping = { + "stable-diffusion": ModelGroups.STABLE_DIFFUSION, + "nai-diffusion": ModelGroups.STABLE_DIFFUSION, + "safe-diffusion": ModelGroups.STABLE_DIFFUSION, + "waifu-diffusion": ModelGroups.STABLE_DIFFUSION, + "nai-diffusion-furry": ModelGroups.STABLE_DIFFUSION, + "curated-diffusion-test": ModelGroups.STABLE_DIFFUSION, + "nai-diffusion-inpainting": ModelGroups.STABLE_DIFFUSION, + "safe-diffusion-inpainting": ModelGroups.STABLE_DIFFUSION, + "furry-diffusion-inpainting": ModelGroups.STABLE_DIFFUSION, + "nai-diffusion-2": ModelGroups.STABLE_DIFFUSION_GROUP_2, + "nai-diffusion-xl": ModelGroups.STABLE_DIFFUSION_XL, + "nai-diffusion-3": ModelGroups.STABLE_DIFFUSION_XL, + "nai-diffusion-3-inpainting": ModelGroups.STABLE_DIFFUSION_XL, + "custom": ModelGroups.STABLE_DIFFUSION_XL, + "nai-diffusion-furry-3": ModelGroups.STABLE_DIFFUSION_XL_FURRY, + "nai-diffusion-furry-3-inpainting": ModelGroups.STABLE_DIFFUSION_XL_FURRY, + } + return mapping.get(model, ModelGroups.STABLE_DIFFUSION) + -def get_default_uc_preset(model: str, uc_preset: int) -> str: +def get_default_uc_preset(model: ModelTypeAlias, uc_preset: int) -> str: if isinstance(model, Enum): model = model.value if isinstance(uc_preset, Enum): diff --git a/src/novelai_python/sdk/ai/generate_image/__init__.py b/src/novelai_python/sdk/ai/generate_image/__init__.py index e2cbae1..4e998fb 100755 --- a/src/novelai_python/sdk/ai/generate_image/__init__.py +++ b/src/novelai_python/sdk/ai/generate_image/__init__.py @@ -25,7 +25,8 @@ from novelai_python.sdk.ai._cost import CostCalculator from novelai_python.sdk.ai._enum import Model, Sampler, NoiseSchedule, ControlNetModel, Action, UCPreset, \ - INPAINTING_MODEL_LIST, get_default_noise_schedule, get_supported_noise_schedule, get_default_uc_preset + INPAINTING_MODEL_LIST, get_default_noise_schedule, get_supported_noise_schedule, get_default_uc_preset, \ + ModelTypeAlias, ImageBytesTypeAlias, UCPresetTypeAlias from ...schema import ApiBaseModel from ...._exceptions import APIError, AuthError, ConcurrentGenerationError, SessionHttpError from ...._response.ai.generate_image import ImageGenerateResp, RequestParams @@ -46,7 +47,7 @@ class Params(BaseModel): 叠加原始图像 防止现有图像发生更改,但可能会沿蒙版边缘引入接缝。 """ - mask: Optional[Union[str, bytes]] = None + mask: ImageBytesTypeAlias = None """Mask for Inpainting""" cfg_rescale: Optional[float] = Field(0, ge=0, le=1, multiple_of=0.02) """Prompt Guidance Rescale""" @@ -56,7 +57,7 @@ class Params(BaseModel): """Decrisp:Reduce artifacts caused by high prompt guidance values""" height: Optional[int] = Field(1216, ge=64, le=49152) """Height For Generate Image""" - image: Optional[Union[str, bytes]] = None + image: ImageBytesTypeAlias = None """Image for img2img""" strength: Optional[float] = Field(default=0.5, ge=0.01, le=0.99, multiple_of=0.01) """Strength for img2img""" @@ -138,7 +139,7 @@ def reference_strength_multiple_validator(cls, v): # TODO: find out the usage steps: Optional[int] = Field(23, ge=1, le=50) """Steps""" - ucPreset: Optional[Union[UCPreset, int]] = Field(None, ge=0) + ucPreset: UCPresetTypeAlias = Field(None, ge=0) """The Negative Prompt Preset, Bigger or equal to 0""" uncond_scale: Optional[float] = Field(1.0, ge=0, le=1.5, multiple_of=0.05) """Undesired Content Strength""" @@ -324,7 +325,7 @@ def endpoint(self, value): action: Union[str, Action] = Field(Action.GENERATE, description="Mode for img generate") input: str = "1girl, best quality, amazing quality, very aesthetic, absurdres" - model: Optional[Union[Model, str]] = "nai-diffusion-3" + model: ModelTypeAlias = "nai-diffusion-3" parameters: Params = Params() model_config = ConfigDict(extra="ignore") @@ -420,7 +421,7 @@ def build(cls, model: Union[Model, str] = "nai-diffusion-3", action: Union[Action, str] = 'generate', negative_prompt: str = "", - ucPreset: Optional[Union[UCPreset, int]] = UCPreset.TYPE0, + ucPreset: UCPresetTypeAlias = UCPreset.TYPE0, steps: int = 28, seed: int = None, scale: float = 5.0, diff --git a/src/novelai_python/sdk/user/login.py b/src/novelai_python/sdk/user/login.py index e10f127..f88f5bf 100755 --- a/src/novelai_python/sdk/user/login.py +++ b/src/novelai_python/sdk/user/login.py @@ -74,7 +74,7 @@ async def request(self, if override_headers: sess.headers.clear() sess.headers.update(override_headers) - logger.debug("Login") + logger.debug("Fetching login-credential") try: assert hasattr(sess, "post"), "session must have get method." response = await sess.post(