diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ea131dc4..53bbd729 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -15,7 +15,7 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v4.5.0 with: - python-version: 3.9 + python-version: "3.10" - name: Cache dependencies uses: actions/cache@v3.2.4 id: cache @@ -35,7 +35,7 @@ jobs: - uses: actions/checkout@v3 - uses: actions/setup-python@v4.5.0 with: - python-version: 3.9 + python-version: "3.10" - name: Cache dependencies uses: actions/cache@v3.2.4 id: cache @@ -73,3 +73,18 @@ jobs: CUDA_LAUNCH_BLOCKING: 1 run: | pytest --durations=10 -v + type-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4.5.0 + with: + python-version: "3.10" + cache: pip + cache-dependency-path: requirements-dev.txt + - name: Install dependencies + run: | + python -m pip install -r requirements-dev.txt . --upgrade + - name: Run mypy + run: | + make type-check diff --git a/Makefile b/Makefile index 2d53582c..84589542 100644 --- a/Makefile +++ b/Makefile @@ -40,6 +40,9 @@ lint: ## Run the code linter. @ruff check --config tests/ruff.toml . @echo -e "No linting errors - well done! ✨ 🍰 ✨" +type-check: ## Run the type checker. + @mypy --config-file tox.ini . + deploy: ## Deploy the package to pypi.org pip install twine wheel -git tag $$(python setup.py -V) diff --git a/README.md b/README.md index 71b901c8..166b56c1 100644 --- a/README.md +++ b/README.md @@ -787,7 +787,7 @@ Use with `--model SD-2.1` or `--model SD-2.0-v` **6.1.0** - feature: use different default steps and image sizes depending on sampler and model selected - fix: #110 use proper version in image metadata -- refactor: samplers all have their own class that inherits from ImageSampler +- refactor: solvers all have their own class that inherits from ImageSolver - feature: 🎉🎉🎉 Stable Diffusion 2.0 - `--model SD-2.0` to use (it makes worse images than 1.5 though...) - Tested on macOS and Linux diff --git a/docs/examples/generate_doc_examples.py b/docs/examples/generate_doc_examples.py index 23f92bc7..f0e05fdd 100644 --- a/docs/examples/generate_doc_examples.py +++ b/docs/examples/generate_doc_examples.py @@ -1,4 +1,5 @@ -from imaginairy import ImaginePrompt, LazyLoadingImage, imagine_image_files +from imaginairy.api import imagine_image_files +from imaginairy.schema import ImaginePrompt, LazyLoadingImage def main(): diff --git a/docs/examples/immortal_pearl_earring.py b/docs/examples/immortal_pearl_earring.py index aaaa5d91..ebd678ab 100644 --- a/docs/examples/immortal_pearl_earring.py +++ b/docs/examples/immortal_pearl_earring.py @@ -4,8 +4,9 @@ from PIL import ImageDraw, ImageFont from tqdm import tqdm -from imaginairy import ImaginePrompt, LazyLoadingImage, WeightedPrompt, imagine +from imaginairy.api import imagine from imaginairy.log_utils import configure_logging +from imaginairy.schema import ImaginePrompt, LazyLoadingImage, WeightedPrompt def generate_image_morph_video(): diff --git a/docs/todo.md b/docs/todo.md index e95bd8f5..0dadf76c 100644 --- a/docs/todo.md +++ b/docs/todo.md @@ -3,19 +3,25 @@ ### v14 todo - configurable composition cutoff - - rename model parameter weights - - rename model_config parameter to architecture and make it case insensitive - - add --size parameter that accepts strings (e.g. 
256x256, 4k, uhd, 8k, etc) - - detect if cuda torch missing and give better error message - - add method to install correct torch version + - ✅ rename model parameter weights + - ✅ rename model_config parameter to architecture and make it case insensitive + - ✅ add --size parameter that accepts strings (e.g. 256x256, 4k, uhd, 8k, etc) + - ✅ detect if cuda torch missing and give better error message + - ✅ add method to install correct torch version + - ✅ make cli run faster again + - ✅ add tests for cli commands + - add type checker + - only output the main image unless some flag is set - allow selection of output video format - chain multiple operations together imggen => videogen + - https://github.com/pallets/click/tree/main/examples/imagepipe + + - add interface for loading diffusers weights + - https://huggingface.co/playgroundai/playground-v2-1024px-aesthetic - make sure terminal output on windows doesn't suck - - add karras schedule to refiners - add method to show cache size - add method to clear model cache - add method to clear cached items not recently used (does diffusers have one?) - ### Old Todo - Inference Performance Optimizations diff --git a/imaginairy/__init__.py b/imaginairy/__init__.py index 5203ae88..aede5c9a 100644 --- a/imaginairy/__init__.py +++ b/imaginairy/__init__.py @@ -4,22 +4,3 @@ os.putenv("PYTORCH_ENABLE_MPS_FALLBACK", "1") # use more memory than we should os.putenv("PYTORCH_MPS_HIGH_WATERMARK_RATIO", "0.0") - -import sys # noqa - -from .api import imagine, imagine_image_files # noqa -from .schema import ( # noqa - ImaginePrompt, - ImagineResult, - LazyLoadingImage, - WeightedPrompt, -) - -# if python version is 3.11 or higher, throw an exception -if sys.version_info >= (3, 11): - msg = ( - "Imaginairy is not compatible with Python 3.11 or higher. Please use Python 3.8 - 3.10.\n" - "This is due to torch 1.13 not supporting Python 3.11 and this library not having yet switched " - "to torch 2.0" - ) - raise RuntimeError(msg) diff --git a/imaginairy/api.py b/imaginairy/api.py index b7f2be6a..d98f4343 100755 --- a/imaginairy/api.py +++ b/imaginairy/api.py @@ -1,35 +1,39 @@ import logging import os import re +from typing import TYPE_CHECKING, Any, Callable -from imaginairy.schema import ControlNetInput, SafetyMode +from imaginairy.utils.named_resolutions import normalize_image_size + +if TYPE_CHECKING: + from imaginairy.schema import ImaginePrompt, LazyLoadingImage logger = logging.getLogger(__name__) # leave undocumented. I'd ask that no one publicize this flag. Just want a # slight barrier to entry. Please don't use this is any way that's gonna cause # the media or politicians to freak out about AI... 
-IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", SafetyMode.STRICT) +IMAGINAIRY_SAFETY_MODE = os.getenv("IMAGINAIRY_SAFETY_MODE", "strict") if IMAGINAIRY_SAFETY_MODE in {"disabled", "classify"}: - IMAGINAIRY_SAFETY_MODE = SafetyMode.RELAXED + IMAGINAIRY_SAFETY_MODE = "relaxed" elif IMAGINAIRY_SAFETY_MODE == "filter": - IMAGINAIRY_SAFETY_MODE = SafetyMode.STRICT + IMAGINAIRY_SAFETY_MODE = "strict" # we put this in the global scope so it can be used in the interactive shell _most_recent_result = None def imagine_image_files( - prompts, - outdir, - precision="autocast", - record_step_images=False, - output_file_extension="jpg", - print_caption=False, - make_gif=False, - make_compare_gif=False, - return_filename_type="generated", - videogen=False, + prompts: "list[ImaginePrompt] | ImaginePrompt", + outdir: str, + precision: str = "autocast", + record_step_images: bool = False, + output_file_extension: str = "jpg", + print_caption: bool = False, + make_gif: bool = False, + make_compare_gif: bool = False, + return_filename_type: str = "generated", + videogen: bool = False, ): from PIL import ImageDraw @@ -46,6 +50,9 @@ def imagine_image_files( if output_file_extension not in {"jpg", "png"}: raise ValueError("Must output a png or jpg") + if not isinstance(prompts, list): + prompts = [prompts] + def _record_step(img, description, image_count, step_count, prompt): steps_path = os.path.join(outdir, "steps", f"{base_count:08}_S{prompt.seed}") os.makedirs(steps_path, exist_ok=True) @@ -74,7 +81,7 @@ def _record_step(img, description, image_count, step_count, prompt): if prompt.init_image: img_str = f"_img2img-{prompt.init_image_strength}" basefilename = ( - f"{base_count:06}_{prompt.seed}_{prompt.sampler_type.replace('_', '')}{prompt.steps}_" + f"{base_count:06}_{prompt.seed}_{prompt.solver_type.replace('_', '')}{prompt.steps}_" f"PS{prompt.prompt_strength}{img_str}_{prompt_normalized(prompt.prompt_text)}" ) @@ -139,15 +146,15 @@ def _record_step(img, description, image_count, step_count, prompt): def imagine( - prompts, - precision="autocast", - debug_img_callback=None, - progress_img_callback=None, - progress_img_interval_steps=3, + prompts: "list[ImaginePrompt] | str | ImaginePrompt", + precision: str = "autocast", + debug_img_callback: Callable | None = None, + progress_img_callback: Callable | None = None, + progress_img_interval_steps: int = 3, progress_img_interval_min_s=0.1, half_mode=None, - add_caption=False, - unsafe_retry_count=1, + add_caption: bool = False, + unsafe_retry_count: int = 1, ): import torch.nn @@ -209,7 +216,7 @@ def imagine( def _generate_single_image_compvis( - prompt, + prompt: "ImaginePrompt", debug_img_callback=None, progress_img_callback=None, progress_img_interval_steps=3, @@ -248,9 +255,9 @@ def _generate_single_image_compvis( from imaginairy.modules.midas.api import torch_image_to_depth_map from imaginairy.outpaint import outpaint_arg_str_parse, prepare_image_for_outpaint from imaginairy.safety import create_safety_score - from imaginairy.samplers import SAMPLER_LOOKUP + from imaginairy.samplers import SOLVER_LOOKUP from imaginairy.samplers.editing import CFGEditingDenoiser - from imaginairy.schema import ImaginePrompt, ImagineResult + from imaginairy.schema import ControlInput, ImagineResult, MaskMode from imaginairy.utils import get_device, randn_seeded latent_channels = 4 @@ -274,7 +281,7 @@ def _generate_single_image_compvis( if control_inputs: control_modes = [c.mode for c in prompt.control_inputs] if inpaint_method == "auto": - if prompt.model in 
{"SD-1.5", "SD-2.0"}: + if prompt.model_weights in {"SD-1.5", "SD-2.0"}: inpaint_method = "finetune" else: inpaint_method = "controlnet" @@ -282,8 +289,8 @@ def _generate_single_image_compvis( if for_inpainting and inpaint_method == "controlnet": control_modes.append("inpaint") model = get_diffusion_model( - weights_location=prompt.model, - config_path=prompt.model_config_path, + weights_location=prompt.model_weights, + config_path=prompt.model_architecture, control_weights_locations=control_modes, half_mode=half_mode, for_inpainting=for_inpainting and inpaint_method == "finetune", @@ -326,22 +333,26 @@ def latent_logger(latents): prompt.height // downsampling_factor, prompt.width // downsampling_factor, ] - SamplerCls = SAMPLER_LOOKUP[prompt.sampler_type.lower()] - sampler = SamplerCls(model) - mask_latent = mask_image = mask_image_orig = mask_grayscale = None - t_enc = init_latent = control_image = None + SolverCls = SOLVER_LOOKUP[prompt.solver_type.lower()] + solver = SolverCls(model) + mask_image: Image.Image | LazyLoadingImage | None = None + mask_latent = mask_image_orig = mask_grayscale = None + init_latent: torch.Tensor | None = None + t_enc = None starting_image = None denoiser_cls = None c_cat = [] c_cat_neutral = None - result_images = {} + result_images: dict[str, torch.Tensor | Image.Image | None] = {} + assert prompt.seed is not None seed_everything(prompt.seed) noise = randn_seeded(seed=prompt.seed, size=shape).to(get_device()) control_strengths = [] if prompt.init_image: starting_image = prompt.init_image + assert prompt.init_image_strength is not None generation_strength = 1 - prompt.init_image_strength if model.cond_stage_key == "edit" or generation_strength >= 1: @@ -360,18 +371,18 @@ def latent_logger(latents): starting_image, mask_image = prepare_image_for_outpaint( starting_image, mask_image, **outpaint_kwargs ) - + assert starting_image is not None init_image = pillow_fit_image_within( starting_image, max_height=prompt.height, max_width=prompt.width, ) - init_image_t = pillow_img_to_torch_image(init_image) - init_image_t = init_image_t.to(get_device()) + init_image_t = pillow_img_to_torch_image(init_image).to(get_device()) init_latent = model.get_first_stage_encoding( model.encode_first_stage(init_image_t) ) - shape = init_latent.shape + assert init_latent is not None + shape = list(init_latent.shape) log_latent(init_latent, "init_latent") @@ -385,7 +396,7 @@ def latent_logger(latents): log_img(mask_image, "init mask") - if prompt.mask_mode == ImaginePrompt.MaskMode.REPLACE: + if prompt.mask_mode == MaskMode.REPLACE: mask_image = ImageOps.invert(mask_image) mask_image_orig = mask_image @@ -396,11 +407,11 @@ def latent_logger(latents): if inpaint_method == "controlnet": result_images["control-inpaint"] = mask_image control_inputs.append( - ControlNetInput(mode="inpaint", image=mask_image) + ControlInput(mode="inpaint", image=mask_image) ) - + assert prompt.seed is not None seed_everything(prompt.seed) - noise = randn_seeded(seed=prompt.seed, size=init_latent.shape).to( + noise = randn_seeded(seed=prompt.seed, size=list(init_latent.shape)).to( get_device() ) # noise = noise[:, :, : init_latent.shape[2], : init_latent.shape[3]] @@ -444,8 +455,13 @@ def latent_logger(latents): control_image = control_input.image_raw elif control_input.image is not None: control_image = control_input.image + else: + raise RuntimeError("Control image must be provided") + assert control_image is not None control_image = control_image.convert("RGB") log_img(control_image, 
"control_image_input") + assert control_image is not None + control_image_input = pillow_fit_image_within( control_image, max_height=prompt.height, @@ -457,11 +473,11 @@ def latent_logger(latents): if control_input.image_raw is None: control_prep_function = CONTROL_MODES[control_input.mode] if control_input.mode == "inpaint": - control_image_t = control_prep_function( + control_image_t = control_prep_function( # type: ignore control_image_input_t, init_image_t ) else: - control_image_t = control_prep_function(control_image_input_t) + control_image_t = control_prep_function(control_image_input_t) # type: ignore else: control_image_t = (control_image_input_t + 1) / 2 @@ -492,6 +508,8 @@ def latent_logger(latents): elif hasattr(model, "masked_image_key"): # inpainting model + assert mask_image_orig is not None + assert mask_latent is not None mask_t = pillow_img_to_torch_image(ImageOps.invert(mask_image_orig)).to( get_device() ) @@ -512,6 +530,7 @@ def latent_logger(latents): elif model.cond_stage_key == "edit": # pix2pix model c_cat = [model.encode_first_stage(init_image_t)] + assert init_latent is not None c_cat_neutral = [torch.zeros_like(init_latent)] denoiser_cls = CFGEditingDenoiser if c_cat: @@ -520,37 +539,46 @@ def latent_logger(latents): if c_cat_neutral is None: c_cat_neutral = c_cat - positive_conditioning = { + positive_conditioning_d: dict[str, Any] = { "c_concat": c_cat, "c_crossattn": [positive_conditioning], } - neutral_conditioning = { + neutral_conditioning_d: dict[str, Any] = { "c_concat": c_cat_neutral, "c_crossattn": [neutral_conditioning], } + del neutral_conditioning + del positive_conditioning if control_strengths and is_controlnet_model: - positive_conditioning["control_strengths"] = torch.Tensor(control_strengths) - neutral_conditioning["control_strengths"] = torch.Tensor(control_strengths) + positive_conditioning_d["control_strengths"] = torch.Tensor( + control_strengths + ) + neutral_conditioning_d["control_strengths"] = torch.Tensor( + control_strengths + ) if ( prompt.allow_compose_phase and not is_controlnet_model and model.cond_stage_key != "edit" ): + default_size = get_model_default_image_size( + prompt.model_weights.architecture + ) if prompt.init_image: comp_image = _generate_composition_image( prompt=prompt, target_height=init_image.height, target_width=init_image.width, - cutoff=get_model_default_image_size(prompt.model), + cutoff=default_size, ) else: comp_image = _generate_composition_image( prompt=prompt, target_height=prompt.height, target_width=prompt.width, - cutoff=get_model_default_image_size(prompt.model), + cutoff=default_size, ) if comp_image is not None: result_images["composition"] = comp_image @@ -563,10 +591,10 @@ def latent_logger(latents): model.encode_first_stage(comp_image_t) ) with lc.timing("sampling"): - samples = sampler.sample( + samples = solver.sample( num_steps=prompt.steps, - positive_conditioning=positive_conditioning, - neutral_conditioning=neutral_conditioning, + positive_conditioning=positive_conditioning_d, + neutral_conditioning=neutral_conditioning_d, guidance_scale=prompt.prompt_strength, t_start=t_enc, mask=mask_latent, @@ -632,15 +660,16 @@ def latent_logger(latents): caption_text = prompt.caption_text.format(prompt=prompt.prompt_text) add_caption_to_image(gen_img, caption_text) + result_images["upscaled"] = upscaled_img + result_images["modified_original"] = rebuilt_orig_img + result_images["mask_binary"] = mask_image_orig + result_images["mask_grayscale"] = mask_grayscale + result = ImagineResult( img=gen_img, 
prompt=prompt, - upscaled_img=upscaled_img, is_nsfw=safety_score.is_nsfw, safety_score=safety_score, - modified_original=rebuilt_orig_img, - mask_binary=mask_image_orig, - mask_grayscale=mask_grayscale, result_images=result_images, timings=lc.get_timings(), progress_latents=progress_latents.copy(), @@ -660,18 +689,15 @@ def _prompts_to_embeddings(prompts, model): return conditioning -def calc_scale_to_fit_within( - height, - width, - max_size, -): - if max(height, width) < max_size: +def calc_scale_to_fit_within(height: int, width: int, max_size) -> float: + max_width, max_height = normalize_image_size(max_size) + if width <= max_width and height <= max_height: return 1 - if width > height: - return max_size / width + width_ratio = max_width / width + height_ratio = max_height / height - return max_size / height + return min(width_ratio, height_ratio) def _scale_latent( @@ -690,14 +716,19 @@ def _scale_latent( def _generate_composition_image( - prompt, target_height, target_width, cutoff=512, dtype=None + prompt, + target_height, + target_width, + cutoff: tuple[int, int] = (512, 512), + dtype=None, ): from PIL import Image from imaginairy.api_refiners import _generate_single_image from imaginairy.utils import default, get_default_dtype - if prompt.width <= cutoff and prompt.height <= cutoff: + cutoff = normalize_image_size(cutoff) + if prompt.width <= cutoff[0] and prompt.height <= cutoff[1]: return None, None dtype = default(dtype, get_default_dtype) @@ -711,12 +742,15 @@ def _generate_composition_image( composition_prompt = prompt.full_copy( deep=True, update={ - "width": int(prompt.width * shrink_scale), - "height": int(prompt.height * shrink_scale), + "size": ( + int(prompt.width * shrink_scale), + int(prompt.height * shrink_scale), + ), "steps": None, "upscale": False, "fix_faces": False, "mask_modify_original": False, + "allow_compose_phase": False, }, ) diff --git a/imaginairy/api_refiners.py b/imaginairy/api_refiners.py index 5731a236..4599611d 100644 --- a/imaginairy/api_refiners.py +++ b/imaginairy/api_refiners.py @@ -1,30 +1,25 @@ import logging from typing import List, Optional -from imaginairy import WeightedPrompt -from imaginairy.config import CONTROLNET_CONFIG_SHORTCUTS -from imaginairy.model_manager import load_controlnet_adapter +from imaginairy.config import CONTROL_CONFIG_SHORTCUTS +from imaginairy.schema import ControlInput, ImaginePrompt, MaskMode, WeightedPrompt logger = logging.getLogger(__name__) def _generate_single_image( - prompt, + prompt: ImaginePrompt, debug_img_callback=None, progress_img_callback=None, progress_img_interval_steps=3, progress_img_interval_min_s=0.1, add_caption=False, - # controlnet, finetune, naive, auto - inpaint_method="finetune", return_latent=False, dtype=None, half_mode=None, ): - import gc - import torch.nn - from PIL import ImageOps + from PIL import Image, ImageOps from pytorch_lightning import seed_everything from refiners.foundationals.latent_diffusion.schedulers import DDIM, DPMSolver from tqdm import tqdm @@ -55,32 +50,22 @@ def _generate_single_image( ) from imaginairy.outpaint import outpaint_arg_str_parse, prepare_image_for_outpaint from imaginairy.safety import create_safety_score - from imaginairy.samplers import SamplerName - from imaginairy.schema import ImaginePrompt, ImagineResult + from imaginairy.samplers import SolverName + from imaginairy.schema import ImagineResult from imaginairy.utils import get_device, randn_seeded if dtype is None: dtype = torch.float16 if half_mode else torch.float32 get_device() - 
gc.collect() - torch.cuda.empty_cache() + clear_gpu_cache() prompt = prompt.make_concrete_copy() - control_modes = [] - control_inputs = prompt.control_inputs or [] - control_inputs = control_inputs.copy() - for_inpainting = bool(prompt.mask_image or prompt.mask_prompt or prompt.outpaint) - - if control_inputs: - control_modes = [c.mode for c in prompt.control_inputs] - sd = get_diffusion_model_refiners( - weights_location=prompt.model, - config_path=prompt.model_config_path, - control_weights_locations=tuple(control_modes), + weights_config=prompt.model_weights, + for_inpainting=prompt.should_use_inpainting + and prompt.inpaint_method == "finetune", dtype=dtype, - for_inpainting=for_inpainting and inpaint_method == "finetune", ) seed_everything(prompt.seed) @@ -90,7 +75,6 @@ def _generate_single_image( mask_image = None mask_image_orig = None - prompt = prompt.make_concrete_copy() def latent_logger(latents): progress_latents.append(latents) @@ -116,8 +100,8 @@ def latent_logger(latents): ) clip_text_embedding = clip_text_embedding.to(device=sd.device, dtype=sd.dtype) - result_images = {} - progress_latents = [] + result_images: dict[str, torch.Tensor | None | Image.Image] = {} + progress_latents: list[torch.Tensor] = [] first_step = 0 mask_grayscale = None @@ -130,9 +114,18 @@ def latent_logger(latents): init_latent = None noise_step = None + + control_modes = [] + control_inputs = prompt.control_inputs or [] + control_inputs = control_inputs.copy() + + if control_inputs: + control_modes = [c.mode for c in prompt.control_inputs] + if prompt.init_image: starting_image = prompt.init_image - first_step = int((prompt.steps) * prompt.init_image_strength) + assert prompt.init_image_strength is not None + first_step = int(prompt.steps * prompt.init_image_strength) # noise_step = int((prompt.steps - 1) * prompt.init_image_strength) if prompt.mask_prompt: @@ -157,7 +150,7 @@ def latent_logger(latents): init_image_t = init_image_t.to(device=sd.device, dtype=sd.dtype) init_latent = sd.lda.encode(init_image_t) - shape = init_latent.shape + shape = list(init_latent.shape) log_latent(init_latent, "init_latent") @@ -171,7 +164,7 @@ def latent_logger(latents): log_img(mask_image, "init mask") - if prompt.mask_mode == ImaginePrompt.MaskMode.REPLACE: + if prompt.mask_mode == MaskMode.REPLACE: mask_image = ImageOps.invert(mask_image) mask_image_orig = mask_image @@ -179,13 +172,14 @@ def latent_logger(latents): pillow_mask_to_latent_mask( mask_image, downsampling_factor=downsampling_factor ).to(get_device()) - # if inpaint_method == "controlnet": - # result_images["control-inpaint"] = mask_image - # control_inputs.append( - # ControlNetInput(mode="inpaint", image=mask_image) - # ) + if prompt.inpaint_method == "controlnet": + result_images["control-inpaint"] = mask_image + control_inputs.append( + ControlInput(mode="inpaint", image=mask_image) + ) seed_everything(prompt.seed) + assert prompt.seed is not None noise = randn_seeded(seed=prompt.seed, size=shape).to( get_device(), dtype=sd.dtype @@ -194,7 +188,6 @@ def latent_logger(latents): controlnets = [] if control_modes: - control_strengths = [] from imaginairy.img_processors.control_modes import CONTROL_MODES for control_input in control_inputs: @@ -218,11 +211,11 @@ def latent_logger(latents): if control_input.image_raw is None: control_prep_function = CONTROL_MODES[control_input.mode] if control_input.mode == "inpaint": - control_image_t = control_prep_function( + control_image_t = control_prep_function( # type: ignore control_image_input_t, 
init_image_t ) else: - control_image_t = control_prep_function(control_image_input_t) + control_image_t = control_prep_function(control_image_input_t) # type: ignore else: control_image_t = (control_image_input_t + 1) / 2 @@ -231,10 +224,10 @@ def latent_logger(latents): log_img(control_image_disp, "control_image") if len(control_image_t.shape) == 3: - raise RuntimeError("Control image must be 4D") + raise ValueError("Control image must be 4D") if control_image_t.shape[1] != 3: - raise RuntimeError("Control image must have 3 channels") + raise ValueError("Control image must have 3 channels") if ( control_input.mode != "inpaint" @@ -242,43 +235,47 @@ def latent_logger(latents): or control_image_t.max() > 1 ): msg = f"Control image must be in [0, 1] but we received {control_image_t.min()} and {control_image_t.max()}" - raise RuntimeError(msg) + raise ValueError(msg) if control_image_t.max() == control_image_t.min(): msg = f"No control signal found in control image {control_input.mode}." - raise RuntimeError(msg) - - control_strengths.append(control_input.strength) + raise ValueError(msg) - control_weights_path = CONTROLNET_CONFIG_SHORTCUTS.get( - control_input.mode, None - ).weights_url + control_config = CONTROL_CONFIG_SHORTCUTS.get(control_input.mode, None) + if not control_config: + msg = f"Unknown control mode: {control_input.mode}" + raise ValueError(msg) + from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter - controlnet = load_controlnet_adapter( + controlnet = SD1ControlnetAdapter( # type: ignore name=control_input.mode, - control_weights_location=control_weights_path, - target_unet=sd.unet, - scale=control_input.strength, + target=sd.unet, # type: ignore + weights_location=control_config.weights_location, ) + controlnet.set_scale(control_input.strength) + controlnets.append((controlnet, control_image_t)) if prompt.allow_compose_phase: + cutoff_size = get_model_default_image_size(prompt.model_architecture) + cutoff_size = (int(cutoff_size[0] * 1.30), int(cutoff_size[1] * 1.30)) + compose_kwargs = { + "prompt": prompt, + "target_height": prompt.height, + "target_width": prompt.width, + "cutoff": cutoff_size, + "dtype": dtype, + } + if prompt.init_image: - comp_image, comp_img_orig = _generate_composition_image( - prompt=prompt, - target_height=init_image.height, - target_width=init_image.width, - cutoff=get_model_default_image_size(prompt.model), - dtype=dtype, - ) - else: - comp_image, comp_img_orig = _generate_composition_image( - prompt=prompt, - target_height=prompt.height, - target_width=prompt.width, - cutoff=get_model_default_image_size(prompt.model), - dtype=dtype, + compose_kwargs.update( + { + "target_height": init_image.height, + "target_width": init_image.width, + } ) + comp_image, comp_img_orig = _generate_composition_image(**compose_kwargs) + if comp_image is not None: result_images["composition"] = comp_img_orig result_images["composition-upscaled"] = comp_image @@ -296,17 +293,17 @@ def latent_logger(latents): control_image_t.to(device=sd.device, dtype=sd.dtype) ) controlnet.inject() - if prompt.sampler_type.lower() == SamplerName.K_DPMPP_2M: + if prompt.solver_type.lower() == SolverName.DPMPP: sd.scheduler = DPMSolver(num_inference_steps=prompt.steps) - elif prompt.sampler_type.lower() == SamplerName.DDIM: + elif prompt.solver_type.lower() == SolverName.DDIM: sd.scheduler = DDIM(num_inference_steps=prompt.steps) else: - msg = f"Unknown sampler type: {prompt.sampler_type}" + msg = f"Unknown solver type: {prompt.solver_type}" raise ValueError(msg) 
sd.scheduler.to(device=sd.device, dtype=sd.dtype) sd.set_num_inference_steps(prompt.steps) - if hasattr(sd, "mask_latents"): + if hasattr(sd, "mask_latents") and mask_image is not None: sd.set_inpainting_conditions( target_image=init_image, mask=ImageOps.invert(mask_image), @@ -327,8 +324,7 @@ def latent_logger(latents): # if "cuda" in str(sd.lda.device): # sd.lda.to("cpu") - gc.collect() - torch.cuda.empty_cache() + clear_gpu_cache() # print(f"moving unet to {sd.device}") # sd.unet.to(device=sd.device, dtype=sd.dtype) for step in tqdm(sd.steps[first_step:], bar_format=" {l_bar}{bar}{r_bar}"): @@ -340,26 +336,12 @@ def latent_logger(latents): condition_scale=prompt.prompt_strength, ) - # z = sd( - # randn_seeded(seed=prompt.seed, size=[1, 4, 8, 8]).to( - # device=sd.device, dtype=sd.dtype - # ), - # step=step, - # clip_text_embedding=clip_text_embedding, - # condition_scale=prompt.prompt_strength, - # ) - - if "cuda" in str(sd.unet.device): - # print("moving unet to cpu") - # sd.unet.to("cpu") - gc.collect() - torch.cuda.empty_cache() + clear_gpu_cache() logger.debug("Decoding image") if x.device != sd.lda.device: sd.lda.to(x.device) - gc.collect() - torch.cuda.empty_cache() + clear_gpu_cache() gen_img = sd.lda.decode_latents(x) if mask_image_orig and init_image: @@ -414,18 +396,24 @@ def latent_logger(latents): caption_text = prompt.caption_text.format(prompt=prompt.prompt_text) add_caption_to_image(gen_img, caption_text) + # todo: do something smarter + result_images.update( + { + "upscaled": upscaled_img, + "modified_original": rebuilt_orig_img, + "mask_binary": mask_image_orig, + "mask_grayscale": mask_grayscale, + } + ) + result = ImagineResult( img=gen_img, prompt=prompt, - upscaled_img=upscaled_img, is_nsfw=safety_score.is_nsfw, safety_score=safety_score, - modified_original=rebuilt_orig_img, - mask_binary=mask_image_orig, - mask_grayscale=mask_grayscale, result_images=result_images, - timings={}, - progress_latents=[], + timings=lc.get_timings(), + progress_latents=[], # todo ) _most_recent_result = result @@ -433,14 +421,16 @@ def latent_logger(latents): logger.info(f"Image Generated. 
Timings: {result.timings_str()}") for controlnet, _ in controlnets: controlnet.eject() - gc.collect() - torch.cuda.empty_cache() + clear_gpu_cache() return result def _prompts_to_embeddings(prompts, text_encoder): import torch + if not prompts: + prompts = [WeightedPrompt(text="")] + total_weight = sum(wp.weight for wp in prompts) if str(text_encoder.device) == "cpu": text_encoder = text_encoder.to(dtype=torch.float32) @@ -473,3 +463,13 @@ def _calc_conditioning( tensors=(neutral_conditioning, positive_conditioning), dim=0 ) return clip_text_embedding + + +def clear_gpu_cache(): + import gc + + import torch + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() diff --git a/imaginairy/cli/arg_schedule.py b/imaginairy/cli/arg_schedule.py new file mode 100644 index 00000000..1569462f --- /dev/null +++ b/imaginairy/cli/arg_schedule.py @@ -0,0 +1,72 @@ +from typing import Iterable + +from imaginairy.utils import frange + + +def with_arg_schedule(f): + """Decorator to add arg-schedule functionality to a click command.""" + + def new_func(*args, **kwargs): + arg_schedules = kwargs.pop("arg_schedules", None) + + if arg_schedules: + schedules = parse_schedule_strs(arg_schedules) + schedule_length = len(next(iter(schedules.values()))) + for i in range(schedule_length): + for attr_name, schedule in schedules.items(): + kwargs[attr_name] = schedule[i] + f(*args, **kwargs) + else: + f(*args, **kwargs) + + return new_func + + +def parse_schedule_strs(schedule_strs: Iterable[str]) -> dict: + """Parse and validate input prompt schedules.""" + schedules = {} + for schedule_str in schedule_strs: + arg_name, arg_values = parse_schedule_str(schedule_str) + schedules[arg_name] = arg_values + + # Validate that all schedules have the same length + schedule_lengths = [len(v) for v in schedules.values()] + if len(set(schedule_lengths)) > 1: + raise ValueError("All schedules must have the same length") + + return schedules + + +def parse_schedule_str(schedule_str): + """Parse a schedule string into a list of values.""" + import re + + pattern = re.compile(r"([a-zA-Z0-9_-]+)\[([a-zA-Z0-9_:,. 
-]+)\]") + match = pattern.match(schedule_str) + if not match: + msg = f"Invalid kwarg schedule: {schedule_str}" + raise ValueError(msg) + + arg_name = match.group(1).replace("-", "_") + + arg_values = match.group(2) + if ":" in arg_values: + start, end, step = arg_values.split(":") + arg_values = list(frange(float(start), float(end), float(step))) + else: + arg_values = parse_csv_line(arg_values) + return arg_name, arg_values + + +def parse_csv_line(line): + import csv + + reader = csv.reader([line]) + for row in reader: + parsed_row = [] + for value in row: + try: + parsed_row.append(float(value)) + except ValueError: + parsed_row.append(value) + return parsed_row diff --git a/imaginairy/cli/clickshell_mod.py b/imaginairy/cli/clickshell_mod.py index 4ef3c0e0..aa30cf9f 100644 --- a/imaginairy/cli/clickshell_mod.py +++ b/imaginairy/cli/clickshell_mod.py @@ -88,6 +88,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.help_headers_color = "yellow" self.help_options_color = "green" + from imaginairy.cli.unslow_the_cli import unslowify_scripts_safe + + unslowify_scripts_safe() def parse_args(self, ctx, args): # run the parser for ourselves to preserve the passed order diff --git a/imaginairy/cli/colorize.py b/imaginairy/cli/colorize.py index fd4870f7..156b0cfc 100644 --- a/imaginairy/cli/colorize.py +++ b/imaginairy/cli/colorize.py @@ -36,9 +36,9 @@ def colorize_cmd(image_filepaths, outdir, repeats, caption): from tqdm import tqdm - from imaginairy import LazyLoadingImage from imaginairy.colorize import colorize_img from imaginairy.log_utils import configure_logging + from imaginairy.schema import LazyLoadingImage configure_logging() diff --git a/imaginairy/cli/describe.py b/imaginairy/cli/describe.py index c265df21..78ab9d13 100644 --- a/imaginairy/cli/describe.py +++ b/imaginairy/cli/describe.py @@ -7,8 +7,8 @@ def describe_cmd(image_filepaths): """Generate text descriptions of images.""" import os - from imaginairy import LazyLoadingImage from imaginairy.enhancers.describe_image_blip import generate_caption + from imaginairy.schema import LazyLoadingImage imgs = [] for p in image_filepaths: diff --git a/imaginairy/cli/edit.py b/imaginairy/cli/edit.py index 62e7b9eb..c25120b5 100644 --- a/imaginairy/cli/edit.py +++ b/imaginairy/cli/edit.py @@ -31,9 +31,9 @@ @click.option( "--model-weights-path", "--model", - help=f"Model to use. Should be one of {', '.join(config.MODEL_SHORT_NAMES)}, or a path to custom weights.", + help=f"Model to use. Should be one of {', '.join(config.IMAGE_WEIGHTS_SHORT_NAMES)}, or a path to custom weights.", show_default=True, - default="SD-1.5", + default="sd15", ) @click.option( "--negative-prompt", @@ -53,15 +53,13 @@ def edit_cmd( outdir, output_file_extension, repeats, - height, - width, size, steps, seed, upscale, fix_faces, fix_faces_fidelity, - sampler_type, + solver, log_level, quiet, show_work, @@ -76,7 +74,7 @@ def edit_cmd( caption, precision, model_weights_path, - model_config_path, + model_architecture, prompt_library_path, version, make_gif, @@ -95,16 +93,10 @@ def edit_cmd( Same as calling `aimg imagine --model edit --init-image my-dog.jpg --init-image-strength 1` except this command can batch edit images. 
""" - from imaginairy.schema import ControlNetInput + from imaginairy.schema import ControlInput allow_compose_phase = False - control_inputs = [ - ControlNetInput( - image=None, - image_raw=None, - mode="edit", - ) - ] + control_inputs = [ControlInput(image=None, image_raw=None, mode="edit", strength=1)] return _imagine_cmd( ctx, @@ -116,15 +108,13 @@ def edit_cmd( outdir, output_file_extension, repeats, - height, - width, size, steps, seed, upscale, fix_faces, fix_faces_fidelity, - sampler_type, + solver, log_level, quiet, show_work, @@ -140,7 +130,7 @@ def edit_cmd( caption, precision, model_weights_path, - model_config_path, + model_architecture, prompt_library_path, version, make_gif, diff --git a/imaginairy/cli/imagine.py b/imaginairy/cli/imagine.py index ada0d570..9c234c62 100644 --- a/imaginairy/cli/imagine.py +++ b/imaginairy/cli/imagine.py @@ -1,7 +1,12 @@ import click from imaginairy.cli.clickshell_mod import ImagineColorsCommand -from imaginairy.cli.shared import _imagine_cmd, add_options, common_options +from imaginairy.cli.shared import ( + _imagine_cmd, + add_options, + common_options, + imaginairy_click_context, +) @click.command( @@ -67,6 +72,7 @@ help="Turns the generated photo into video", ) @click.pass_context +@imaginairy_click_context() def imagine_cmd( ctx, prompt_texts, @@ -77,15 +83,13 @@ def imagine_cmd( outdir, output_file_extension, repeats, - height, - width, size, steps, seed, upscale, fix_faces, fix_faces_fidelity, - sampler_type, + solver, log_level, quiet, show_work, @@ -101,7 +105,7 @@ def imagine_cmd( caption, precision, model_weights_path, - model_config_path, + model_architecture, prompt_library_path, version, make_gif, @@ -120,7 +124,7 @@ def imagine_cmd( Can be invoked via either `aimg imagine` or just `imagine`. 
""" - from imaginairy.schema import ControlNetInput, LazyLoadingImage + from imaginairy.schema import ControlInput, LazyLoadingImage # hacky method of getting order of control images (mixing raw and normal images) control_images = [ @@ -128,13 +132,18 @@ def imagine_cmd( for o, path in ImagineColorsCommand._option_order if o.name in ("control_image", "control_image_raw") ] + control_strengths = [ + strength + for o, strength in ImagineColorsCommand._option_order + if o.name == "control_strength" + ] + control_inputs = [] if control_mode: for i, cm in enumerate(control_mode): - try: - option = control_images[i] - except IndexError: - option = None + option = index_default(control_images, i, None) + control_strength = index_default(control_strengths, i, 1.0) + if option is None: control_image = None control_image_raw = None @@ -149,10 +158,10 @@ def imagine_cmd( if control_image_raw and control_image_raw.startswith("http"): control_image_raw = LazyLoadingImage(url=control_image_raw) control_inputs.append( - ControlNetInput( + ControlInput( image=control_image, image_raw=control_image_raw, - strength=float(control_strength[i]), + strength=float(control_strength), mode=cm, ) ) @@ -167,15 +176,13 @@ def imagine_cmd( outdir, output_file_extension, repeats, - height, - width, size, steps, seed, upscale, fix_faces, fix_faces_fidelity, - sampler_type, + solver, log_level, quiet, show_work, @@ -191,7 +198,7 @@ def imagine_cmd( caption, precision, model_weights_path, - model_config_path, + model_architecture, prompt_library_path, version, make_gif, @@ -204,5 +211,12 @@ def imagine_cmd( ) +def index_default(items, index, default): + try: + return items[index] + except IndexError: + return default + + if __name__ == "__main__": imagine_cmd() diff --git a/imaginairy/cli/main.py b/imaginairy/cli/main.py index 3722f5ed..aab73584 100644 --- a/imaginairy/cli/main.py +++ b/imaginairy/cli/main.py @@ -45,8 +45,6 @@ def aimg(ctx): aimg.add_command(edit_cmd, name="edit") aimg.add_command(edit_demo_cmd, name="edit-demo") aimg.add_command(imagine_cmd, name="imagine") -# aimg.add_command(prep_images_cmd, name="prep-images") -# aimg.add_command(prune_ckpt_cmd, name="prune-ckpt") aimg.add_command(upscale_cmd, name="upscale") aimg.add_command(run_server_cmd, name="server") aimg.add_command(videogen_cmd, name="videogen") @@ -84,17 +82,16 @@ def model_list_cmd(): """Print list of available models.""" from imaginairy import config - print(f"{'ALIAS': <10} {'NAME': <18} {'DESCRIPTION'}") - for model_config in config.MODEL_CONFIGS: - print( - f"{model_config.alias: <10} {model_config.short_name: <18} {model_config.description}" - ) + print("\nWEIGHT NAMES") + print(f"{'ALIAS': <25} {'NAME': <25} ") + for model_config in config.MODEL_WEIGHT_CONFIGS: + print(f"{model_config.aliases[0]: <25} {model_config.name: <25}") - print("\nCONTROL MODES:") - print(f"{'ALIAS': <10} {'NAME': <18} {'CONTROL TYPE'}") - for control_mode in config.CONTROLNET_CONFIGS: + print("\nCONTROL MODES") + print(f"{'ALIAS': <14} {'NAME': <35} {'CONTROL TYPE'}") + for control_mode in config.CONTROL_CONFIGS: print( - f"{control_mode.alias: <10} {control_mode.short_name: <18} {control_mode.control_type}" + f"{control_mode.aliases[0]: <14} {control_mode.name: <35} {control_mode.control_type}" ) diff --git a/imaginairy/cli/shared.py b/imaginairy/cli/shared.py index 62dae7ee..a7f02832 100644 --- a/imaginairy/cli/shared.py +++ b/imaginairy/cli/shared.py @@ -15,13 +15,12 @@ def imaginairy_click_context(log_level="INFO"): from imaginairy.log_utils import 
configure_logging - errors_to_catch = (FileNotFoundError, ValidationError) + errors_to_catch = (FileNotFoundError, ValidationError, ValueError) configure_logging(level=log_level) try: yield except errors_to_catch as e: logger.error(e) - exit(1) def _imagine_cmd( @@ -34,15 +33,13 @@ def _imagine_cmd( outdir, output_file_extension, repeats, - height, - width, size, steps, seed, upscale, fix_faces, fix_faces_fidelity, - sampler_type, + solver, log_level, quiet, show_work, @@ -58,7 +55,7 @@ def _imagine_cmd( caption, precision, model_weights_path, - model_config_path, + model_architecture, prompt_library_path, version=False, make_gif=False, @@ -96,15 +93,6 @@ def _imagine_cmd( configure_logging(log_level) - if (height is not None or width is not None) and size is not None: - msg = "You cannot specify both --size and --height/--width. Please choose one." - raise ValueError(msg) - - if size is not None: - from imaginairy.utils.named_resolutions import get_named_resolution - - width, height = get_named_resolution(size) - init_images = [init_image] if isinstance(init_image, str) else init_image from imaginairy.utils import glob_expand_paths @@ -121,7 +109,8 @@ def _imagine_cmd( f"Received {len(prompt_texts)} prompt(s) and {len(init_images)} input image(s). Will repeat the generations {repeats} times to create {total_image_count} images." ) - from imaginairy import ImaginePrompt, LazyLoadingImage, imagine_image_files + from imaginairy.api import imagine_image_files + from imaginairy.schema import ImaginePrompt, LazyLoadingImage new_init_images = [] for _init_image in init_images: @@ -144,6 +133,15 @@ def _imagine_cmd( prompt_expanding_iterators = {} from imaginairy.enhancers.prompt_expansion import expand_prompts + if model_weights_path.lower() not in config.MODEL_WEIGHT_CONFIG_LOOKUP: + model_weights_path = config.ModelWeightsConfig( + name="custom weights", + aliases=["custom"], + weights_location=model_weights_path, + architecture=model_architecture, + defaults={"negative_prompt": config.DEFAULT_NEGATIVE_PROMPT}, + ) + for _ in range(repeats): for prompt_text in prompt_texts: if prompt_text not in prompt_expanding_iterators: @@ -171,10 +169,9 @@ def _imagine_cmd( init_image_strength=init_image_strength, control_inputs=control_inputs, seed=seed, - sampler_type=sampler_type, + solver_type=solver, steps=steps, - height=height, - width=width, + size=size, mask_image=mask_image, mask_prompt=mask_prompt, mask_mode=mask_mode, @@ -185,8 +182,7 @@ def _imagine_cmd( fix_faces_fidelity=fix_faces_fidelity, tile_mode=_tile_mode, allow_compose_phase=allow_compose_phase, - model=model_weights_path, - model_config_path=model_config_path, + model_weights=model_weights_path, caption_text=caption_text, ) from imaginairy.prompt_schedules import ( @@ -318,28 +314,12 @@ def temp_f(): type=int, help="How many times to repeat the renders. If you provide two prompts and --repeat=3 then six images will be generated.", ), - click.option( - "-h", - "--height", - default=None, - show_default=True, - type=int, - help="Image height. Should be multiple of 8.", - ), - click.option( - "-w", - "--width", - default=None, - show_default=True, - type=int, - help="Image width. Should be multiple of 8.", - ), click.option( "--size", default=None, show_default=True, type=str, - help="Image size as a string. Can be a named size or WIDTHxHEIGHT format. Should be multiple of 8. Examples: 512x512, 4k, UHD, 8k, ", + help="Image size as a string. Can be a named size, WIDTHxHEIGHT, or single integer. Should be multiple of 8. 
Examples: 512x512, 4k, UHD, 8k, 512, 1080p", ), click.option( "--steps", @@ -363,18 +343,18 @@ def temp_f(): help="How faithful to the original should face enhancement be. 1 = best fidelity, 0 = best looking face.", ), click.option( - "--sampler-type", + "--solver", "--sampler", - default=config.DEFAULT_SAMPLER, + default=config.DEFAULT_SOLVER, show_default=True, - type=click.Choice(config.SAMPLER_TYPE_OPTIONS), - help="What sampling strategy to use.", + type=click.Choice(config.SOLVER_TYPE_NAMES, case_sensitive=False), + help="Solver algorithm to generate the image with. (AKA 'Sampler' or 'Scheduler' in other libraries.", ), click.option( "--log-level", default="INFO", show_default=True, - type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"]), + type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"], case_sensitive=False), help="What level of logs to show.", ), click.option( @@ -429,7 +409,7 @@ def temp_f(): "--mask-mode", default="replace", show_default=True, - type=click.Choice(["keep", "replace"]), + type=click.Choice(["keep", "replace"], case_sensitive=False), help="Should we replace the masked area or keep it?", ), click.option( @@ -458,20 +438,20 @@ def temp_f(): click.option( "--precision", help="Evaluate at this precision.", - type=click.Choice(["full", "autocast"]), + type=click.Choice(["full", "autocast"], case_sensitive=False), default="autocast", show_default=True, ), click.option( "--model-weights-path", "--model", - help=f"Model to use. Should be one of {', '.join(config.MODEL_SHORT_NAMES)}, or a path to custom weights.", + help=f"Model to use. Should be one of {', '.join(config.IMAGE_WEIGHTS_SHORT_NAMES)}, or a path to custom weights.", show_default=True, - default=config.DEFAULT_MODEL, + default=config.DEFAULT_MODEL_WEIGHTS, ), click.option( - "--model-config-path", - help="Model config file to use. If a model name is specified, the appropriate config will be used.", + "--model-architecture", + help="Model architecture. When specifying custom weights the model architecture must be specified. (sd15, sdxl, etc).", show_default=True, default=None, ), diff --git a/imaginairy/cli/train.py b/imaginairy/cli/train.py index ec86d17e..b952f14b 100644 --- a/imaginairy/cli/train.py +++ b/imaginairy/cli/train.py @@ -44,9 +44,9 @@ "--model-weights-path", "--model", "model", - help=f"Model to use. Should be one of {', '.join(config.MODEL_SHORT_NAMES)}, or a path to custom weights.", + help=f"Model to use. Should be one of {', '.join(config.IMAGE_WEIGHTS_SHORT_NAMES)}, or a path to custom weights.", show_default=True, - default=config.DEFAULT_MODEL, + default=config.DEFAULT_MODEL_WEIGHTS, ) @click.option( "--person", diff --git a/imaginairy/cli/unslow_the_cli.py b/imaginairy/cli/unslow_the_cli.py new file mode 100644 index 00000000..83892d79 --- /dev/null +++ b/imaginairy/cli/unslow_the_cli.py @@ -0,0 +1,79 @@ +""" +horrible hack to overcome horrible design choices by easy_install/setuptools + +If we don't do this then the scripts will be slow to start up because of +pkg_resources.require() which is called by setuptools to ensure the +"correct" version of the package is installed. 
+""" +import os + + +def log(text): + # for debugging + pass + # print(text) + + +def find_script_path(script_name): + for path in os.environ["PATH"].split(os.pathsep): + script_path = os.path.join(path, script_name) + if os.path.isfile(script_path): + return script_path + return None + + +def is_already_modified(): + return bool(os.environ.get("IMAGINAIRY_SCRIPT_MODIFIED")) + + +def remove_pkg_resources_requirement(script_path): + import shutil + import tempfile + + with open(script_path) as file: + lines = file.readlines() + + with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file: + for line in lines: + if "__import__('pkg_resources').require" not in line: + temp_file.write(line) + else: + temp_file.write( + '\nimport os\nos.environ["IMAGINAIRY_SCRIPT_MODIFIED"] = "1"\n' + ) + log(f"Writing to {temp_file.name}") + + # Preserve the original file permissions + original_permissions = os.stat(script_path).st_mode + os.chmod(temp_file.name, original_permissions) + + # Replace the original file with the modified one + shutil.move(temp_file.name, script_path) + log(f"Replaced {script_path}") + + +has_run = False + + +def unslowify_scripts(): + global has_run + + if has_run or is_already_modified(): + return + + has_run = True + script_names = ["aimg", "imagine"] + + for script_name in script_names: + script_path = find_script_path(script_name) + log(f"Found script {script_name} at {script_path}") + + if script_path: + remove_pkg_resources_requirement(script_path) + + +def unslowify_scripts_safe(): + try: # noqa + unslowify_scripts() + except Exception: # noqa + pass diff --git a/imaginairy/cli/upscale.py b/imaginairy/cli/upscale.py index 9684239e..64945131 100644 --- a/imaginairy/cli/upscale.py +++ b/imaginairy/cli/upscale.py @@ -29,9 +29,9 @@ def upscale_cmd(image_filepaths, outdir, fix_faces, fix_faces_fidelity): from tqdm import tqdm - from imaginairy import LazyLoadingImage from imaginairy.enhancers.face_restoration_codeformer import enhance_faces from imaginairy.enhancers.upscale_realesrgan import upscale_image + from imaginairy.schema import LazyLoadingImage from imaginairy.utils import glob_expand_paths os.makedirs(outdir, exist_ok=True) diff --git a/imaginairy/colorize.py b/imaginairy/colorize.py index c3ec5361..f50f8052 100644 --- a/imaginairy/colorize.py +++ b/imaginairy/colorize.py @@ -2,9 +2,9 @@ from PIL import Image, ImageEnhance, ImageStat -from imaginairy import ImaginePrompt, imagine +from imaginairy.api import imagine from imaginairy.enhancers.describe_image_blip import generate_caption -from imaginairy.schema import ControlNetInput +from imaginairy.schema import ControlInput, ImaginePrompt logger = logging.getLogger(__name__) @@ -23,7 +23,7 @@ def colorize_img(img, max_width=1024, max_height=1024, caption=None): caption = caption.replace(" old ", " ") logger.info(caption) control_inputs = [ - ControlNetInput(mode="colorize", image=img, strength=2), + ControlInput(mode="colorize", image=img, strength=2), ] prompt_add = ". 
color photo, sharp-focus, highly detailed, intricate, Canon 5D" prompt = ImaginePrompt( @@ -31,8 +31,7 @@ def colorize_img(img, max_width=1024, max_height=1024, caption=None): init_image=img, init_image_strength=0.0, control_inputs=control_inputs, - width=min(img.width, max_width), - height=min(img.height, max_height), + size=(min(img.width, max_width), min(img.height, max_height)), steps=30, prompt_strength=12, ) diff --git a/imaginairy/config.py b/imaginairy/config.py index 3e61fc5f..cc43ab47 100644 --- a/imaginairy/config.py +++ b/imaginairy/config.py @@ -1,7 +1,9 @@ from dataclasses import dataclass +from typing import Any, List -DEFAULT_MODEL = "SD-1.5" -DEFAULT_SAMPLER = "ddim" +DEFAULT_MODEL_WEIGHTS = "sd15" +DEFAULT_MODEL_ARCHITECTURE = "sd15" +DEFAULT_SOLVER = "ddim" DEFAULT_NEGATIVE_PROMPT = ( "Ugly, duplication, duplicates, mutilation, deformed, mutilated, mutation, twisted body, disfigured, bad anatomy, " @@ -12,229 +14,313 @@ "grainy, blurred, blurry, writing, calligraphy, signature, text, watermark, bad art," ) -SPLITMEM_ENABLED = False +midas_url = "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt" @dataclass -class ModelConfig: - description: str - short_name: str - config_path: str - weights_url: str - default_image_size: int - forced_attn_precision: str = "default" - default_negative_prompt: str = DEFAULT_NEGATIVE_PROMPT - alias: str = None +class ModelArchitecture: + name: str + aliases: List[str] + output_modality: str + defaults: dict[str, Any] + config_path: str | None = None -midas_url = "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt" - -MODEL_CONFIGS = [ - ModelConfig( - description="Stable Diffusion 1.5", - short_name="SD-1.5", +MODEL_ARCHITECTURES = [ + ModelArchitecture( + name="Stable Diffusion 1.5", + aliases=["sd15", "sd-15", "sd1.5", "sd-1.5"], + output_modality="image", + defaults={"size": "512"}, config_path="configs/stable-diffusion-v1.yaml", - weights_url="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt", - default_image_size=512, - alias="sd15", ), - ModelConfig( - description="Stable Diffusion 1.5 - Inpainting", - short_name="SD-1.5-inpaint", + ModelArchitecture( + name="Stable Diffusion 1.5 - Inpainting", + aliases=[ + "sd15inpaint", + "sd15-inpaint", + "sd-15-inpaint", + "sd1.5inpaint", + "sd1.5-inpaint", + "sd-1.5-inpaint", + ], + output_modality="image", + defaults={"size": "512"}, config_path="configs/stable-diffusion-v1-inpaint.yaml", - weights_url="https://huggingface.co/julienacquaviva/inpainting/resolve/2155ff7fe38b55f4c0d99c2f1ab9b561f8311ca7/sd-v1-5-inpainting.ckpt", - default_image_size=512, - alias="sd15in", - ), - # ModelConfig( - # description="Instruct Pix2Pix - Photo Editing", - # short_name="instruct-pix2pix", - # config_path="configs/instruct-pix2pix.yaml", - # weights_url="https://huggingface.co/imaginairy/instruct-pix2pix/resolve/ea0009b3d0d4888f410a40bd06d69516d0b5a577/instruct-pix2pix-00-22000-pruned.ckpt", - # default_image_size=512, - # default_negative_prompt="", - # alias="edit", - # ), - ModelConfig( - description="OpenJourney V1", - short_name="openjourney-v1", - config_path="configs/stable-diffusion-v1.yaml", - weights_url="https://huggingface.co/prompthero/openjourney/resolve/7428477dad893424c92f6ea1cc29d45f6d1448c1/mdjrny-v4.safetensors", - default_image_size=512, - default_negative_prompt="", - alias="oj1", - ), - ModelConfig( - description="OpenJourney V2", - 
short_name="openjourney-v2", - config_path="configs/stable-diffusion-v1.yaml", - weights_url="https://huggingface.co/prompthero/openjourney-v2/resolve/47257274a40e93dab7fbc0cd2cfd5f5704cfeb60/openjourney-v2.ckpt", - default_image_size=512, - default_negative_prompt="", - alias="oj2", - ), - ModelConfig( - description="OpenJourney V4", - short_name="openjourney-v4", - config_path="configs/stable-diffusion-v1.yaml", - weights_url="https://huggingface.co/prompthero/openjourney/resolve/e291118e93d5423dc88ac1ed93c02362b17d698f/mdjrny-v4.safetensors", - default_image_size=512, - default_negative_prompt="", - alias="oj4", + ), + ModelArchitecture( + name="Stable Diffusion XL", + aliases=["sdxl", "sd-xl"], + output_modality="image", + defaults={"size": "512"}, + ), + ModelArchitecture( + name="Stable Video Diffusion", + aliases=["svd", "stablevideo"], + output_modality="video", + defaults={"size": "1024x576"}, + config_path="configs/svd.yaml", + ), + ModelArchitecture( + name="Stable Video Diffusion - Image Decoder", + aliases=["svd-image-decoder", "svd-imdec"], + output_modality="video", + defaults={"size": "1024x576"}, + config_path="configs/svd_image_decoder.yaml", + ), + ModelArchitecture( + name="Stable Video Diffusion - XT", + aliases=["svd-xt", "svd25f", "svd-25f", "stablevideoxt", "svdxt"], + output_modality="video", + defaults={"size": "1024x576"}, + config_path="configs/svd_xt.yaml", + ), + ModelArchitecture( + name="Stable Video Diffusion - XT - Image Decoder", + aliases=[ + "svd-xt-image-decoder", + "svd-xt-imdec", + "svd-25f-imdec", + "svdxt-imdec", + "svdxtimdec", + "svd25fimdec", + "svdxtimdec", + ], + output_modality="video", + defaults={"size": "1024x576"}, + config_path="configs/svd_xt_image_decoder.yaml", ), ] +MODEL_ARCHITECTURE_LOOKUP = {} +for m in MODEL_ARCHITECTURES: + for a in m.aliases: + MODEL_ARCHITECTURE_LOOKUP[a] = m + + +@dataclass +class ModelWeightsConfig: + name: str + aliases: List[str] + architecture: ModelArchitecture + defaults: dict[str, Any] + weights_location: str + + def __post_init__(self): + if isinstance(self.architecture, str): + self.architecture = MODEL_ARCHITECTURE_LOOKUP[self.architecture] + if not isinstance(self.architecture, ModelArchitecture): + msg = f"You must specify an architecture {self.architecture}" + raise ValueError(msg) # noqa + -video_models = [ - { - "short_name": "svd", - "description": "Stable Video Diffusion", - "default_frames": 14, - "default_steps": 25, - "config_path": "configs/svd.yaml", - "weights_url": "https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd.fp16.safetensors", - }, - { - "short_name": "svd_image_decoder", - "description": "Stable Video Diffusion - Image Decoder", - "default_frames": 14, - "default_steps": 25, - "config_path": "configs/svd_image_decoder.yaml", - "weights_url": "https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_image_decoder.fp16.safetensors", - }, - { - "short_name": "svd_xt", - "description": "Stable Video Diffusion - XT", - "default_frames": 25, - "default_steps": 30, - "config_path": "configs/svd_xt.yaml", - "weights_url": "https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_xt.fp16.safetensors", - }, - { - "short_name": "svd_xt_image_decoder", - "description": "Stable Video Diffusion - XT - Image Decoder", - "default_frames": 25, - "default_steps": 30, - "config_path": "configs/svd_xt_image_decoder.yaml", - 
"weights_url": "https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_xt_image_decoder.fp16.safetensors", - }, +MODEL_WEIGHT_CONFIGS = [ + ModelWeightsConfig( + name="Stable Diffusion 1.5", + aliases=MODEL_ARCHITECTURE_LOOKUP["sd15"].aliases, + architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"], + defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT}, + weights_location="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/889b629140e71758e1e0006e355c331a5744b4bf/v1-5-pruned-emaonly.ckpt", + ), + ModelWeightsConfig( + name="Stable Diffusion 1.5 - Inpainting", + aliases=MODEL_ARCHITECTURE_LOOKUP["sd15inpaint"].aliases, + architecture=MODEL_ARCHITECTURE_LOOKUP["sd15inpaint"], + defaults={"negative_prompt": DEFAULT_NEGATIVE_PROMPT}, + weights_location="https://huggingface.co/julienacquaviva/inpainting/resolve/2155ff7fe38b55f4c0d99c2f1ab9b561f8311ca7/sd-v1-5-inpainting.ckpt", + ), + ModelWeightsConfig( + name="OpenJourney V1", + aliases=["openjourney-v1", "oj1", "ojv1", "openjourney1"], + architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"], + defaults={"negative_prompt": "poor quality"}, + weights_location="https://huggingface.co/prompthero/openjourney/resolve/7428477dad893424c92f6ea1cc29d45f6d1448c1/mdjrny-v4.safetensors", + ), + ModelWeightsConfig( + name="OpenJourney V2", + aliases=["openjourney-v2", "oj2", "ojv2", "openjourney2"], + architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"], + weights_location="https://huggingface.co/prompthero/openjourney-v2/resolve/47257274a40e93dab7fbc0cd2cfd5f5704cfeb60/openjourney-v2.ckpt", + defaults={"negative_prompt": "poor quality"}, + ), + ModelWeightsConfig( + name="OpenJourney V4", + aliases=["openjourney-v4", "oj4", "ojv4", "openjourney4", "openjourney", "oj"], + architecture=MODEL_ARCHITECTURE_LOOKUP["sd15"], + weights_location="https://huggingface.co/prompthero/openjourney/resolve/e291118e93d5423dc88ac1ed93c02362b17d698f/mdjrny-v4.safetensors", + defaults={"negative_prompt": "poor quality"}, + ), + # Video Weights + ModelWeightsConfig( + name="Stable Video Diffusion", + aliases=MODEL_ARCHITECTURE_LOOKUP["svd"].aliases, + architecture=MODEL_ARCHITECTURE_LOOKUP["svd"], + weights_location="https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd.fp16.safetensors", + defaults={"frames": 14, "steps": 25}, + ), + ModelWeightsConfig( + name="Stable Video Diffusion - Image Decoder", + aliases=MODEL_ARCHITECTURE_LOOKUP["svd-image-decoder"].aliases, + architecture=MODEL_ARCHITECTURE_LOOKUP["svd-image-decoder"], + weights_location="https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_image_decoder.fp16.safetensors", + defaults={"frames": 14, "steps": 25}, + ), + ModelWeightsConfig( + name="Stable Video Diffusion - XT", + aliases=MODEL_ARCHITECTURE_LOOKUP["svdxt"].aliases, + architecture=MODEL_ARCHITECTURE_LOOKUP["svdxt"], + weights_location="https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_xt.fp16.safetensors", + defaults={"frames": 25, "steps": 30}, + ), + ModelWeightsConfig( + name="Stable Video Diffusion - XT - Image Decoder", + aliases=MODEL_ARCHITECTURE_LOOKUP["svd-xt-image-decoder"].aliases, + architecture=MODEL_ARCHITECTURE_LOOKUP["svd-xt-image-decoder"], + weights_location="https://huggingface.co/imaginairy/stable-video-diffusion/resolve/f9dce2757a0713da6262f35438050357c2be7ee6/svd_xt_image_decoder.fp16.safetensors", + 
defaults={"frames": 25, "steps": 30}, + ), ] -video_models = {m["short_name"]: m for m in video_models} -MODEL_CONFIG_SHORTCUTS = {m.short_name: m for m in MODEL_CONFIGS} -for m in MODEL_CONFIGS: - if m.alias: - MODEL_CONFIG_SHORTCUTS[m.alias] = m +MODEL_WEIGHT_CONFIG_LOOKUP = {} +for mw in MODEL_WEIGHT_CONFIGS: + for a in mw.aliases: + MODEL_WEIGHT_CONFIG_LOOKUP[a] = mw -MODEL_CONFIG_SHORTCUTS["openjourney"] = MODEL_CONFIG_SHORTCUTS["openjourney-v2"] -MODEL_CONFIG_SHORTCUTS["oj"] = MODEL_CONFIG_SHORTCUTS["openjourney-v2"] -MODEL_SHORT_NAMES = sorted(MODEL_CONFIG_SHORTCUTS.keys()) +IMAGE_WEIGHTS_SHORT_NAMES = [ + k + for k, mw in MODEL_WEIGHT_CONFIG_LOOKUP.items() + if mw.architecture.output_modality == "image" +] +IMAGE_WEIGHTS_SHORT_NAMES.sort() @dataclass -class ControlNetConfig: - short_name: str +class ControlConfig: + name: str + aliases: List[str] control_type: str config_path: str - weights_url: str - alias: str = None + weights_location: str -CONTROLNET_CONFIGS = [ - ControlNetConfig( - short_name="canny15", +CONTROL_CONFIGS = [ + ControlConfig( + name="Canny Edge Control", + aliases=["canny", "canny15"], control_type="canny", config_path="configs/control-net-v15.yaml", - weights_url="https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/115a470d547982438f70198e353a921996e2e819/diffusion_pytorch_model.fp16.safetensors", - alias="canny", + weights_location="https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/115a470d547982438f70198e353a921996e2e819/diffusion_pytorch_model.fp16.safetensors", ), - ControlNetConfig( - short_name="depth15", + ControlConfig( + name="Depth Control", + aliases=["depth", "depth15"], control_type="depth", config_path="configs/control-net-v15.yaml", - weights_url="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/539f99181d33db39cf1af2e517cd8056785f0a87/diffusion_pytorch_model.fp16.safetensors", - alias="depth", + weights_location="https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/539f99181d33db39cf1af2e517cd8056785f0a87/diffusion_pytorch_model.fp16.safetensors", ), - ControlNetConfig( - short_name="normal15", + ControlConfig( + name="Normal Map Control", + aliases=["normal", "normal15"], control_type="normal", config_path="configs/control-net-v15.yaml", - weights_url="https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/cb7296e6587a219068e9d65864e38729cd862aa8/diffusion_pytorch_model.fp16.safetensors", - alias="normal", + weights_location="https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/cb7296e6587a219068e9d65864e38729cd862aa8/diffusion_pytorch_model.fp16.safetensors", ), - ControlNetConfig( - short_name="hed15", + ControlConfig( + name="Soft Edge Control (HED)", + aliases=["hed", "hed15"], control_type="hed", config_path="configs/control-net-v15.yaml", - weights_url="https://huggingface.co/lllyasviel/control_v11p_sd15_softedge/resolve/b5bcad0c48e9b12f091968cf5eadbb89402d6bc9/diffusion_pytorch_model.fp16.safetensors", - alias="hed", + weights_location="https://huggingface.co/lllyasviel/control_v11p_sd15_softedge/resolve/b5bcad0c48e9b12f091968cf5eadbb89402d6bc9/diffusion_pytorch_model.fp16.safetensors", ), - ControlNetConfig( - short_name="openpose15", + ControlConfig( + name="Pose Control", control_type="openpose", config_path="configs/control-net-v15.yaml", - weights_url="https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/9ae9f970358db89e211b87c915f9535c6686d5ba/diffusion_pytorch_model.fp16.safetensors", - alias="openpose", + 
weights_location="https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/9ae9f970358db89e211b87c915f9535c6686d5ba/diffusion_pytorch_model.fp16.safetensors", + aliases=["openpose", "pose", "pose15", "openpose15"], ), - ControlNetConfig( - short_name="shuffle15", + ControlConfig( + name="Shuffle Control", control_type="shuffle", config_path="configs/control-net-v15-pool.yaml", - weights_url="https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/8cf275970f984acf5cc0fdfa537db8be098936a3/diffusion_pytorch_model.fp16.safetensors", - alias="shuffle", + weights_location="https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/8cf275970f984acf5cc0fdfa537db8be098936a3/diffusion_pytorch_model.fp16.safetensors", + aliases=["shuffle", "shuffle15"], ), # "instruct pix2pix" - ControlNetConfig( - short_name="edit15", + ControlConfig( + name="Edit Prompt Control", + aliases=["edit", "edit15"], control_type="edit", config_path="configs/control-net-v15.yaml", - weights_url="https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p/resolve/1fed6ebb905c61929a60514830eb05b039969d6d/diffusion_pytorch_model.fp16.safetensors", - alias="edit", + weights_location="https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p/resolve/1fed6ebb905c61929a60514830eb05b039969d6d/diffusion_pytorch_model.fp16.safetensors", ), - ControlNetConfig( - short_name="inpaint15", + ControlConfig( + name="Inpaint Control", + aliases=["inpaint", "inpaint15"], control_type="inpaint", config_path="configs/control-net-v15.yaml", - weights_url="https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint/resolve/c96e03a807e64135568ba8aecb66b3a306ec73bd/diffusion_pytorch_model.fp16.safetensors", - alias="inpaint", + weights_location="https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint/resolve/c96e03a807e64135568ba8aecb66b3a306ec73bd/diffusion_pytorch_model.fp16.safetensors", ), - ControlNetConfig( - short_name="details15", + ControlConfig( + name="Details Control (Upscale Tile)", + aliases=["details", "details15"], control_type="details", config_path="configs/control-net-v15.yaml", - weights_url="https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/resolve/3f877705c37010b7221c3d10743307d6b5b6efac/diffusion_pytorch_model.bin", - alias="details", + weights_location="https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/resolve/3f877705c37010b7221c3d10743307d6b5b6efac/diffusion_pytorch_model.bin", ), - ControlNetConfig( - short_name="colorize15", + ControlConfig( + name="Brightness Control (Colorize)", + aliases=["colorize", "colorize15"], control_type="colorize", config_path="configs/control-net-v15.yaml", - weights_url="https://huggingface.co/ioclab/control_v1p_sd15_brightness/resolve/8509361eb1ba89c03839040ed8c75e5f11bbd9c5/diffusion_pytorch_model.safetensors", - alias="colorize", + weights_location="https://huggingface.co/ioclab/control_v1p_sd15_brightness/resolve/8509361eb1ba89c03839040ed8c75e5f11bbd9c5/diffusion_pytorch_model.safetensors", ), ] -CONTROLNET_CONFIG_SHORTCUTS = {} -for m in CONTROLNET_CONFIGS: - if m.alias: - CONTROLNET_CONFIG_SHORTCUTS[m.alias] = m - -for m in CONTROLNET_CONFIGS: - CONTROLNET_CONFIG_SHORTCUTS[m.short_name] = m - -SAMPLER_TYPE_OPTIONS = [ - # "plms", - "ddim", - "k_dpmpp_2m" - # "k_dpm_fast", - # "k_dpm_adaptive", - # "k_lms", - # "k_dpm_2", - # "k_dpm_2_a", - # "k_dpmpp_2m", - # "k_dpmpp_2s_a", - # "k_euler", - # "k_euler_a", - # "k_heun", +CONTROL_CONFIG_SHORTCUTS: dict[str, ControlConfig] = {} +for cc in CONTROL_CONFIGS: + for ca in cc.aliases: + 
CONTROL_CONFIG_SHORTCUTS[ca] = cc + + +@dataclass +class SolverConfig: + name: str + short_name: str + aliases: List[str] + papers: List[str] + implementations: List[str] + + +SOLVER_CONFIGS = [ + SolverConfig( + name="DDIM", + short_name="DDIM", + aliases=["ddim"], + papers=["https://arxiv.org/abs/2010.02502"], + implementations=[ + "https://github.com/ermongroup/ddim", + "https://github.com/Stability-AI/stablediffusion/blob/cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf/ldm/models/diffusion/ddim.py#L10", + "https://github.com/huggingface/diffusers/blob/76c645d3a641c879384afcb43496f0b7db8cc5cb/src/diffusers/schedulers/scheduling_ddim.py#L131", + ], + ), + SolverConfig( + name="DPM-Solver++", + short_name="DPMPP", + aliases=["dpmpp", "dpm++", "dpmsolver"], + papers=["https://arxiv.org/abs/2211.01095"], + implementations=[ + "https://github.com/LuChengTHU/dpm-solver/blob/52bc3fbcd5de56d60917b826b15d2b69460fc2fa/dpm_solver_pytorch.py#L337", + "https://github.com/apple/ml-stable-diffusion/blob/7449ce46a4b23c94413b714704202e4ea4c55080/swift/StableDiffusion/pipeline/DPMSolverMultistepScheduler.swift#L27", + "https://github.com/crowsonkb/k-diffusion/blob/045515774882014cc14c1ba2668ab5bad9cbf7c0/k_diffusion/sampling.py#L509", + ], + ), ] + +SOLVER_TYPE_NAMES = [s.aliases[0] for s in SOLVER_CONFIGS] + +SOLVER_LOOKUP = {} +for s in SOLVER_CONFIGS: + for a in s.aliases: + SOLVER_LOOKUP[a.lower()] = s diff --git a/imaginairy/enhancers/clip_masking.py b/imaginairy/enhancers/clip_masking.py index 74c78a61..9eb31bd5 100644 --- a/imaginairy/enhancers/clip_masking.py +++ b/imaginairy/enhancers/clip_masking.py @@ -9,6 +9,7 @@ from imaginairy.img_utils import pillow_fit_image_within from imaginairy.log_utils import log_img +from imaginairy.schema import LazyLoadingImage from imaginairy.vendored.clipseg import CLIPDensePredT weights_url = "https://github.com/timojl/clipseg/raw/master/weights/rd64-uni.pth" @@ -32,7 +33,7 @@ def clip_mask_model(): def get_img_mask( - img: PIL.Image.Image, + img: PIL.Image.Image | LazyLoadingImage, mask_description_statement: str, threshold: Optional[float] = None, ): diff --git a/imaginairy/http_app/stablestudio/dist/assets/index-80230e31.js b/imaginairy/http_app/stablestudio/dist/assets/index-80230e31.js index e1432808..89089453 100644 --- a/imaginairy/http_app/stablestudio/dist/assets/index-80230e31.js +++ b/imaginairy/http_app/stablestudio/dist/assets/index-80230e31.js @@ -223,7 +223,7 @@ For more info see: https://github.com/konvajs/react-konva/issues/194 right: 0; bottom: 0; } - `,onMouseMove:o=>{if(o.cancelBubble=!0,(o.evt.buttons===4||o.evt.button===0&&a==="select")&&i.current){const l=o.target.getStage(),u=l.position();l.position({x:u.x+o.evt.movementX,y:u.y+o.evt.movementY}),a!=="brush"&&(document.body.style.cursor="grabbing")}},onWheel:o=>{Ah.changeZoomRelative(o.target.getStage(),o.evt.deltaY)},onMouseDown:o=>{(o.evt.buttons===4||o.evt.button===0&&a==="select")&&(i.current=!0),a==="select"&&s()},onMouseUp:()=>{i.current=!1,a!=="brush"&&(document.body.style.cursor="default")},onMouseLeave:()=>{i.current=!1},children:[C(am,{children:C(QK,{})}),e,C(am,{children:C(Mh,{})})]})})}(e=>{e.use=()=>Zl.useCanvas(),e.useGetContainer=()=>{const f=(0,e.use)();return useCallback(()=>f.current?.container().parentElement,[f])},e.useResize=()=>{const f=(0,e.use)(),p=(0,e.useGetContainer)();return useCallback(()=>{if(!f.current)return;const m=p();if(!m)return;const g=m.offsetWidth,v=m.offsetHeight;!g||!v||f.current.setSize({width:g,height:v})},[f,p])};function t(f,p,m=[]){const 
g=(0,e.use)();return useEffect(()=>{if(g&&g.current)return g.current.on(f,p),()=>{g.current?.off(f,p)}},[f,p,g,...m]),g?.current||null}e.useStageEvent=t;function n(f,p=[]){return t("mousemove",f,p)}e.useMouseMove=n;function r(f,p=[]){return t("mouseup",f,p)}e.useMouseUp=r;function i(f,p=[]){return t("mousedown",f,p)}e.useMouseDown=i;function a(f){const p=[],m=[];return f.forEach(g=>{const v=g.position(),x=g.width(),w=g.height();p.push(v.x,v.x+x),m.push(v.y,v.y+w)}),{vertical:p,horizontal:m}}e.collectSnapLines=a;function s(f,p,m,g,v,x,w,b){return f<=v&&p<=x&&f+m>=v+w&&p+g>=x+b}e.rectContainsRect=s;function o(f,p){const m=p.scaleX(),g={x:p.width()/2,y:p.height()/2},v={x:(g.x-p.x())/m,y:(g.y-p.y())/m},x=f?m*f:1;p.scale({x,y:x});const w={x:g.x-v.x*x,y:g.y-v.y*x};p.position(w),p.batchDraw()}e.changeZoom=o;function l(f,p,m){if(p===0)return;const g=f.scaleX(),v=f.getPointerPosition(),x=m||{x:(v.x-f.x())/g,y:(v.y-f.y())/g},b=(p>0?-1:1)>0?g*(1+.02):g*(1-.02);f.scale({x:b,y:b});const S={x:v.x-x.x*b,y:v.y-x.y*b};f.position(S),f.batchDraw();const P=f.scaleX();Zl.setZoom(P)}e.changeZoomRelative=l;function u(){const f=(0,e.use)();return useCallback(p=>{p.deltaY!==0&&f&&f.current&&(f.current.setPointersPositions(p),l(f.current,p.deltaY))},[f])}e.useTriggerWheelEvent=u;function c(){const f=(0,e.use)();return useCallback(p=>{f&&f.current&&f.current.fire("mousemove",{evt:p},!0)},[f])}e.useTriggerMouseMoveEvent=c,e.Render=Ix,e.ExportBox=Mh})(Ah||(Ah={}));var Tx;(e=>{e.useCopy=()=>useCallback(async t=>{const n=await ne.Image.blob(t);if(!n)return;const r=new ClipboardItem({[n.type]:n});await navigator.clipboard.write([r])},[])})(Tx||(Tx={}));function XK({className:e,children:t,y:n}){return C(mr.div,{className:classes("m-3 my-4 flex justify-end gap-2",e),initial:{y:n},transition:{duration:.15},children:t})}function kp({button:e,className:t,name:n,...r}){const i={...r,outline:!1,className:classes(t,"pointer-events-auto bg-transparent",!r.noBg&&!r.disabled&&"hover:bg-brand-400 dark:hover:bg-brand-400")},a=e?e(i):C(G.Button,{...i});return C(Qe,{children:n?C(G.Tooltip,{content:n,placement:"top",className:t,children:a}):a})}function e0({image:e}){const t=ne.Image.Download.use(e),n=ne.Image.Session.useSetInitialImage(),r=ne.Image.Session.useCreateVariations(e),i=J.Import.use(e),{input:a}=ne.Image.Input.use(e.inputID),s=useMemo(()=>!!a?.init&&"src"in a.init&&a.init?.src===e.src,[a,e.src]),o=useMemo(()=>a?.id&&!ne.Image.Input.isUpscaling(a)&&C(kp,{button:f=>C(ne.Image.Variations.Create.Button,{...f,id:a.id,icon:G.Icon.Variation,onIdleClick:r,noBrand:!0}),transparent:!0,className:"mr-auto -ml-1"}),[a,r]),l=useMemo(()=>C(kp,{icon:G.Icon.Edit,name:"Edit image",onClick:i,color:"zinc",transparent:!0}),[i]),u=useMemo(()=>C(kp,{icon:G.Icon.Image,name:s?"Initial image":"Set as initial image",disabled:s,className:classes(s&&"opacity-50"),onClick:()=>n(e),color:"zinc",transparent:!0}),[s,n,e]),c=G.useIsMobileDevice();return useMemo(()=>te(Qe,{children:[te("div",{className:"pointer-events-none absolute flex h-full w-full flex-col justify-between opacity-0 duration-150 group-hover:opacity-75",children:[C("div",{className:"h-[6rem] bg-gradient-to-b from-black to-transparent"}),C("div",{className:"h-[6rem] bg-gradient-to-b from-transparent to-black sm:hidden"})]}),te("div",{className:"pointer-events-none absolute flex h-full w-full flex-col justify-between opacity-0 duration-150 group-hover:opacity-100",children:[te(ne.Image.Controls.Buttons,{y:-6,children:[o,!c&&te(Qe,{children:[l,u,C(ne.Image.Controls.Button,{name:"Download 
image",icon:G.Icon.Download,onClick:()=>t(),transparent:!0}),C(ne.Images.Delete.Button,{images:[e.id]})]})]}),c&&te(ne.Image.Controls.Buttons,{className:"justify-start",y:6,children:[u,C(ne.Image.Controls.Button,{name:"Download image",icon:G.Icon.Download,onClick:()=>t(),transparent:!0}),C(ne.Images.Delete.Button,{images:[e.id]})]})]})]}),[o,c,l,u,t,e.id])}e0.Button=kp;e0.Buttons=XK;var Jl;(e=>{function t(){const[n,r]=e.use();return C(G.Slider,{title:"Image count",min:1,max:10,value:n,onChange:r})}e.Slider=t})(Jl||(Jl={}));(e=>{e.preset=()=>1,e.get=()=>Jy.getState().count,e.set=t=>Jy.getState().setCount(t),e.use=()=>Jy(({count:t,setCount:n})=>[t,n],qe.shallow)})(Jl||(Jl={}));const Jy=qe.create(e=>({count:Jl.preset(),setCount:t=>e({count:t})}));var Mx={exports:{}};(function(e,t){Object.defineProperty(t,"__esModule",{value:!0});function n(r,i,a){a===void 0&&(a=!1),a&&(i=i/r,r=1);var s=[],o=0,l=0,u,c=function(){var f=o+i,p=Date.now();if(pi&&(o=v,l=0),l++{if(a=tQ(a),a in f4)return;f4[a]=!0;const s=a.endsWith(".css"),o=s?'[rel="stylesheet"]':"";if(!!r)for(let c=i.length-1;c>=0;c--){const f=i[c];if(f.href===a&&(!s||f.rel==="stylesheet"))return}else if(document.querySelector(`link[href="${a}"]${o}`))return;const u=document.createElement("link");if(u.rel=s?"stylesheet":eQ,s||(u.as="script",u.crossOrigin=""),u.href=a,document.head.appendChild(u),s)return new Promise((c,f)=>{u.addEventListener("load",c),u.addEventListener("error",()=>f(new Error(`Unable to preload CSS for ${a}`)))})})).then(()=>t())};var p4=e=>{let t,n=new Set,r=(s,o)=>{let l=typeof s=="function"?s(t):s;if(!Object.is(l,t)){let u=t;t=o??typeof l!="object"?l:Object.assign({},t,l),n.forEach(c=>c(t,u))}},i=()=>t,a={setState:r,getState:i,subscribe:s=>(n.add(s),()=>n.delete(s)),destroy:()=>{({VITE_GIT_HASH:"7b568ee739a614875ad45d0c7a84de5be7c1c2a2",BASE_URL:"/",MODE:"production",DEV:!1,PROD:!0,SSR:!1}&&"production")!=="production"&&console.warn("[DEPRECATED] The `destroy` method will be unsupported in a future version. Instead use unsubscribe function returned by subscribe. 
Everything will be garbage-collected if store is garbage-collected."),n.clear()}};return t=e(r,i,a),a},rQ=e=>e?p4(e):p4,iQ=e=>t=>rQ((n,r)=>e({set:n,get:r,context:t})),Zf=`${window.location.origin}/api/stablestudio`,aQ=iQ(({set:e,get:t})=>({imagesGeneratedSoFar:0,manifest:{name:"imaginAIry Local Diffusion Plugin",author:"Bryce Drennan",link:"https://github.com/brycedrennan/imaginAIry",icon:`${window.location.origin}/DummyImage.png`,version:"0.0.1",license:"MIT",description:"Generate images using imaginAIry."},createStableDiffusionImages:async n=>{console.log(n),e(({imagesGeneratedSoFar:s})=>({imagesGeneratedSoFar:s+4}));let r=t().settings.apiUrl.value??Zf,i=await(await fetch(r+"/generate",{method:"POST",headers:{"Content-Type":"application/json"},body:await oQ(n)})).json();console.log(i);let a=i.images.map(s=>{let o=sQ(s.blob,"image/jpeg");return{id:s.id,createdAt:new Date(s.createdAt),blob:o}});return{id:`${Math.random()*1e7}`,images:a}},getStableDiffusionModels:async()=>{let n=t().settings.apiUrl.value??Zf;return await(await fetch(n+"/models")).json()},getStableDiffusionSamplers:async()=>{let n=t().settings.apiUrl.value??Zf;return await(await fetch(n+"/samplers")).json()},getStableDiffusionDefaultCount:()=>1,getStableDiffusionDefaultInput:()=>(console.log("getStableDiffusionDefaultInput"),{steps:16,sampler:{id:"k_dpmpp_2m",name:"k_dpmpp_2m"},model:"SD-1.5"}),getStatus:()=>{let{imagesGeneratedSoFar:n}=t();return{indicator:"success",text:n>0?`${n} images generated`:"Ready"}},settings:{apiUrl:{type:"string",default:"",placeholder:"URL to imaginAIry API",value:localStorage.getItem("imaginairy-apiUrl")??Zf}},setSetting:(n,r)=>{e(({settings:a})=>({settings:{[n]:{...a[n],value:r}}}));let i="imaginairy-"+n;console.log(i+" : "+r),localStorage.setItem(i,r)}}));function sQ(e,t=""){let n=atob(e),r=new Array(n.length);for(let a=0;a{let r=new FileReader;r.onloadend=()=>{let i=r.result;t(i.split(",")[1])},r.onerror=n,r.readAsDataURL(e)})}async function oQ(e){let t=JSON.parse(JSON.stringify(e));if(e?.input?.initialImage?.blob){let n=await m4(e.input.initialImage.blob);t.input.initialImage.blob=n}if(e?.input?.maskImage?.blob){let n=await m4(e.input.maskImage.blob);t.input.maskImage.blob=n}return JSON.stringify(t)}const lQ=Object.freeze(Object.defineProperty({__proto__:null,createPlugin:aQ},Symbol.toStringTag,{value:"Module"}));function uQ(e,t){var n=Object.setPrototypeOf;n?n(e,t):e.__proto__=t}function cQ(e,t){t===void 0&&(t=e.constructor);var n=Error.captureStackTrace;n&&n(e,t)}var hQ=function(){var e=function(n,r){return e=Object.setPrototypeOf||{__proto__:[]}instanceof Array&&function(i,a){i.__proto__=a}||function(i,a){for(var s in a)Object.prototype.hasOwnProperty.call(a,s)&&(i[s]=a[s])},e(n,r)};return function(t,n){if(typeof n!="function"&&n!==null)throw new TypeError("Class extends value "+String(n)+" is not a constructor or null");e(t,n);function r(){this.constructor=t}t.prototype=n===null?Object.create(n):(r.prototype=n.prototype,new r)}}(),dQ=function(e){hQ(t,e);function t(n,r){var i=this.constructor,a=e.call(this,n,r)||this;return Object.defineProperty(a,"name",{value:i.name,enumerable:!1,configurable:!0}),uQ(a,i.prototype),cQ(a),a}return t}(Error),Rh;(e=>{const t={VITE_GIT_HASH:"7b568ee739a614875ad45d0c7a84de5be7c1c2a2",VITE_USE_EXAMPLE_PLUGIN:{}.VITE_USE_EXAMPLE_PLUGIN??"false",VITE_USE_IMAGINAIRY_PLUGIN:{}.VITE_USE_IMAGINAIRY_PLUGIN??"false"};function n(a){return t[`VITE_${a}`]}e.get=n;class r extends dQ{constructor(s){s.map(o=>`\`${o}\` is \`undefined\`!`),super(s.join(` + 
`,onMouseMove:o=>{if(o.cancelBubble=!0,(o.evt.buttons===4||o.evt.button===0&&a==="select")&&i.current){const l=o.target.getStage(),u=l.position();l.position({x:u.x+o.evt.movementX,y:u.y+o.evt.movementY}),a!=="brush"&&(document.body.style.cursor="grabbing")}},onWheel:o=>{Ah.changeZoomRelative(o.target.getStage(),o.evt.deltaY)},onMouseDown:o=>{(o.evt.buttons===4||o.evt.button===0&&a==="select")&&(i.current=!0),a==="select"&&s()},onMouseUp:()=>{i.current=!1,a!=="brush"&&(document.body.style.cursor="default")},onMouseLeave:()=>{i.current=!1},children:[C(am,{children:C(QK,{})}),e,C(am,{children:C(Mh,{})})]})})}(e=>{e.use=()=>Zl.useCanvas(),e.useGetContainer=()=>{const f=(0,e.use)();return useCallback(()=>f.current?.container().parentElement,[f])},e.useResize=()=>{const f=(0,e.use)(),p=(0,e.useGetContainer)();return useCallback(()=>{if(!f.current)return;const m=p();if(!m)return;const g=m.offsetWidth,v=m.offsetHeight;!g||!v||f.current.setSize({width:g,height:v})},[f,p])};function t(f,p,m=[]){const g=(0,e.use)();return useEffect(()=>{if(g&&g.current)return g.current.on(f,p),()=>{g.current?.off(f,p)}},[f,p,g,...m]),g?.current||null}e.useStageEvent=t;function n(f,p=[]){return t("mousemove",f,p)}e.useMouseMove=n;function r(f,p=[]){return t("mouseup",f,p)}e.useMouseUp=r;function i(f,p=[]){return t("mousedown",f,p)}e.useMouseDown=i;function a(f){const p=[],m=[];return f.forEach(g=>{const v=g.position(),x=g.width(),w=g.height();p.push(v.x,v.x+x),m.push(v.y,v.y+w)}),{vertical:p,horizontal:m}}e.collectSnapLines=a;function s(f,p,m,g,v,x,w,b){return f<=v&&p<=x&&f+m>=v+w&&p+g>=x+b}e.rectContainsRect=s;function o(f,p){const m=p.scaleX(),g={x:p.width()/2,y:p.height()/2},v={x:(g.x-p.x())/m,y:(g.y-p.y())/m},x=f?m*f:1;p.scale({x,y:x});const w={x:g.x-v.x*x,y:g.y-v.y*x};p.position(w),p.batchDraw()}e.changeZoom=o;function l(f,p,m){if(p===0)return;const g=f.scaleX(),v=f.getPointerPosition(),x=m||{x:(v.x-f.x())/g,y:(v.y-f.y())/g},b=(p>0?-1:1)>0?g*(1+.02):g*(1-.02);f.scale({x:b,y:b});const S={x:v.x-x.x*b,y:v.y-x.y*b};f.position(S),f.batchDraw();const P=f.scaleX();Zl.setZoom(P)}e.changeZoomRelative=l;function u(){const f=(0,e.use)();return useCallback(p=>{p.deltaY!==0&&f&&f.current&&(f.current.setPointersPositions(p),l(f.current,p.deltaY))},[f])}e.useTriggerWheelEvent=u;function c(){const f=(0,e.use)();return useCallback(p=>{f&&f.current&&f.current.fire("mousemove",{evt:p},!0)},[f])}e.useTriggerMouseMoveEvent=c,e.Render=Ix,e.ExportBox=Mh})(Ah||(Ah={}));var Tx;(e=>{e.useCopy=()=>useCallback(async t=>{const n=await ne.Image.blob(t);if(!n)return;const r=new ClipboardItem({[n.type]:n});await navigator.clipboard.write([r])},[])})(Tx||(Tx={}));function XK({className:e,children:t,y:n}){return C(mr.div,{className:classes("m-3 my-4 flex justify-end gap-2",e),initial:{y:n},transition:{duration:.15},children:t})}function kp({button:e,className:t,name:n,...r}){const i={...r,outline:!1,className:classes(t,"pointer-events-auto bg-transparent",!r.noBg&&!r.disabled&&"hover:bg-brand-400 dark:hover:bg-brand-400")},a=e?e(i):C(G.Button,{...i});return C(Qe,{children:n?C(G.Tooltip,{content:n,placement:"top",className:t,children:a}):a})}function e0({image:e}){const t=ne.Image.Download.use(e),n=ne.Image.Session.useSetInitialImage(),r=ne.Image.Session.useCreateVariations(e),i=J.Import.use(e),{input:a}=ne.Image.Input.use(e.inputID),s=useMemo(()=>!!a?.init&&"src"in 
a.init&&a.init?.src===e.src,[a,e.src]),o=useMemo(()=>a?.id&&!ne.Image.Input.isUpscaling(a)&&C(kp,{button:f=>C(ne.Image.Variations.Create.Button,{...f,id:a.id,icon:G.Icon.Variation,onIdleClick:r,noBrand:!0}),transparent:!0,className:"mr-auto -ml-1"}),[a,r]),l=useMemo(()=>C(kp,{icon:G.Icon.Edit,name:"Edit image",onClick:i,color:"zinc",transparent:!0}),[i]),u=useMemo(()=>C(kp,{icon:G.Icon.Image,name:s?"Initial image":"Set as initial image",disabled:s,className:classes(s&&"opacity-50"),onClick:()=>n(e),color:"zinc",transparent:!0}),[s,n,e]),c=G.useIsMobileDevice();return useMemo(()=>te(Qe,{children:[te("div",{className:"pointer-events-none absolute flex h-full w-full flex-col justify-between opacity-0 duration-150 group-hover:opacity-75",children:[C("div",{className:"h-[6rem] bg-gradient-to-b from-black to-transparent"}),C("div",{className:"h-[6rem] bg-gradient-to-b from-transparent to-black sm:hidden"})]}),te("div",{className:"pointer-events-none absolute flex h-full w-full flex-col justify-between opacity-0 duration-150 group-hover:opacity-100",children:[te(ne.Image.Controls.Buttons,{y:-6,children:[o,!c&&te(Qe,{children:[l,u,C(ne.Image.Controls.Button,{name:"Download image",icon:G.Icon.Download,onClick:()=>t(),transparent:!0}),C(ne.Images.Delete.Button,{images:[e.id]})]})]}),c&&te(ne.Image.Controls.Buttons,{className:"justify-start",y:6,children:[u,C(ne.Image.Controls.Button,{name:"Download image",icon:G.Icon.Download,onClick:()=>t(),transparent:!0}),C(ne.Images.Delete.Button,{images:[e.id]})]})]})]}),[o,c,l,u,t,e.id])}e0.Button=kp;e0.Buttons=XK;var Jl;(e=>{function t(){const[n,r]=e.use();return C(G.Slider,{title:"Image count",min:1,max:10,value:n,onChange:r})}e.Slider=t})(Jl||(Jl={}));(e=>{e.preset=()=>1,e.get=()=>Jy.getState().count,e.set=t=>Jy.getState().setCount(t),e.use=()=>Jy(({count:t,setCount:n})=>[t,n],qe.shallow)})(Jl||(Jl={}));const Jy=qe.create(e=>({count:Jl.preset(),setCount:t=>e({count:t})}));var Mx={exports:{}};(function(e,t){Object.defineProperty(t,"__esModule",{value:!0});function n(r,i,a){a===void 0&&(a=!1),a&&(i=i/r,r=1);var s=[],o=0,l=0,u,c=function(){var f=o+i,p=Date.now();if(pi&&(o=v,l=0),l++{if(a=tQ(a),a in f4)return;f4[a]=!0;const s=a.endsWith(".css"),o=s?'[rel="stylesheet"]':"";if(!!r)for(let c=i.length-1;c>=0;c--){const f=i[c];if(f.href===a&&(!s||f.rel==="stylesheet"))return}else if(document.querySelector(`link[href="${a}"]${o}`))return;const u=document.createElement("link");if(u.rel=s?"stylesheet":eQ,s||(u.as="script",u.crossOrigin=""),u.href=a,document.head.appendChild(u),s)return new Promise((c,f)=>{u.addEventListener("load",c),u.addEventListener("error",()=>f(new Error(`Unable to preload CSS for ${a}`)))})})).then(()=>t())};var p4=e=>{let t,n=new Set,r=(s,o)=>{let l=typeof s=="function"?s(t):s;if(!Object.is(l,t)){let u=t;t=o??typeof l!="object"?l:Object.assign({},t,l),n.forEach(c=>c(t,u))}},i=()=>t,a={setState:r,getState:i,subscribe:s=>(n.add(s),()=>n.delete(s)),destroy:()=>{({VITE_GIT_HASH:"7b568ee739a614875ad45d0c7a84de5be7c1c2a2",BASE_URL:"/",MODE:"production",DEV:!1,PROD:!0,SSR:!1}&&"production")!=="production"&&console.warn("[DEPRECATED] The `destroy` method will be unsupported in a future version. Instead use unsubscribe function returned by subscribe. 
Everything will be garbage-collected if store is garbage-collected."),n.clear()}};return t=e(r,i,a),a},rQ=e=>e?p4(e):p4,iQ=e=>t=>rQ((n,r)=>e({set:n,get:r,context:t})),Zf=`${window.location.origin}/api/stablestudio`,aQ=iQ(({set:e,get:t})=>({imagesGeneratedSoFar:0,manifest:{name:"imaginAIry Local Diffusion Plugin",author:"Bryce Drennan",link:"https://github.com/brycedrennan/imaginAIry",icon:`${window.location.origin}/DummyImage.png`,version:"0.0.1",license:"MIT",description:"Generate images using imaginAIry."},createStableDiffusionImages:async n=>{console.log(n),e(({imagesGeneratedSoFar:s})=>({imagesGeneratedSoFar:s+4}));let r=t().settings.apiUrl.value??Zf,i=await(await fetch(r+"/generate",{method:"POST",headers:{"Content-Type":"application/json"},body:await oQ(n)})).json();console.log(i);let a=i.images.map(s=>{let o=sQ(s.blob,"image/jpeg");return{id:s.id,createdAt:new Date(s.createdAt),blob:o}});return{id:`${Math.random()*1e7}`,images:a}},getStableDiffusionModels:async()=>{let n=t().settings.apiUrl.value??Zf;return await(await fetch(n+"/models")).json()},getStableDiffusionSamplers:async()=>{let n=t().settings.apiUrl.value??Zf;return await(await fetch(n+"/samplers")).json()},getStableDiffusionDefaultCount:()=>1,getStableDiffusionDefaultInput:()=>(console.log("getStableDiffusionDefaultInput"),{steps:50,sampler:{id:"ddim",name:"ddim"},model:"SD-1.5"}),getStatus:()=>{let{imagesGeneratedSoFar:n}=t();return{indicator:"success",text:n>0?`${n} images generated`:"Ready"}},settings:{apiUrl:{type:"string",default:"",placeholder:"URL to imaginAIry API",value:localStorage.getItem("imaginairy-apiUrl")??Zf}},setSetting:(n,r)=>{e(({settings:a})=>({settings:{[n]:{...a[n],value:r}}}));let i="imaginairy-"+n;console.log(i+" : "+r),localStorage.setItem(i,r)}}));function sQ(e,t=""){let n=atob(e),r=new Array(n.length);for(let a=0;a{let r=new FileReader;r.onloadend=()=>{let i=r.result;t(i.split(",")[1])},r.onerror=n,r.readAsDataURL(e)})}async function oQ(e){let t=JSON.parse(JSON.stringify(e));if(e?.input?.initialImage?.blob){let n=await m4(e.input.initialImage.blob);t.input.initialImage.blob=n}if(e?.input?.maskImage?.blob){let n=await m4(e.input.maskImage.blob);t.input.maskImage.blob=n}return JSON.stringify(t)}const lQ=Object.freeze(Object.defineProperty({__proto__:null,createPlugin:aQ},Symbol.toStringTag,{value:"Module"}));function uQ(e,t){var n=Object.setPrototypeOf;n?n(e,t):e.__proto__=t}function cQ(e,t){t===void 0&&(t=e.constructor);var n=Error.captureStackTrace;n&&n(e,t)}var hQ=function(){var e=function(n,r){return e=Object.setPrototypeOf||{__proto__:[]}instanceof Array&&function(i,a){i.__proto__=a}||function(i,a){for(var s in a)Object.prototype.hasOwnProperty.call(a,s)&&(i[s]=a[s])},e(n,r)};return function(t,n){if(typeof n!="function"&&n!==null)throw new TypeError("Class extends value "+String(n)+" is not a constructor or null");e(t,n);function r(){this.constructor=t}t.prototype=n===null?Object.create(n):(r.prototype=n.prototype,new r)}}(),dQ=function(e){hQ(t,e);function t(n,r){var i=this.constructor,a=e.call(this,n,r)||this;return Object.defineProperty(a,"name",{value:i.name,enumerable:!1,configurable:!0}),uQ(a,i.prototype),cQ(a),a}return t}(Error),Rh;(e=>{const t={VITE_GIT_HASH:"7b568ee739a614875ad45d0c7a84de5be7c1c2a2",VITE_USE_EXAMPLE_PLUGIN:{}.VITE_USE_EXAMPLE_PLUGIN??"false",VITE_USE_IMAGINAIRY_PLUGIN:{}.VITE_USE_IMAGINAIRY_PLUGIN??"false"};function n(a){return t[`VITE_${a}`]}e.get=n;class r extends dQ{constructor(s){s.map(o=>`\`${o}\` is \`undefined\`!`),super(s.join(` 
`))}}e.MissingVariablesError=r;function i({children:a}){return useEffect(()=>{const s=Object.entries(t).filter(([,o])=>`${o??""}`=="").map(([o])=>o);if(s[0])throw new e.MissingVariablesError(s)},[]),C(Qe,{children:a})}e.Provider=i})(Rh||(Rh={}));var bn;(e=>{e.get=()=>{const{rootPlugin:n,activePluginID:r,plugins:i}=Ya.use.getState();return(r?i[r]?.plugin:n).getState()};function t(n=r=>r){const r=Ya.use(({rootPlugin:i,activePluginID:a,plugins:s})=>a&&s[a]?.plugin?s[a].plugin:i);return qe.useStore(r,n)}e.use=t,e.useUnload=()=>useCallback(n=>{Ya.use.getState().setPlugins(({[n]:r,...i})=>i),Ya.use.getState().setActivePluginID()},[]),e.useSetup=()=>{const[n,r]=useState(void 0),[i,a]=useState(!1),s=useCallback(async o=>{try{a(!0);const u=(await nQ(()=>import(o),[])).createPlugin({getGitHash:()=>Rh.get("GIT_HASH"),getStableDiffusionRandomPrompt:()=>ne.Image.Prompt.Random.get()}),c=ID.create();return Ya.use.getState().setPlugins(f=>({...f,[c]:{enabled:!0,index:Object.keys(f).length,plugin:u}})),Ya.use.getState().setActivePluginID(c),a(!1),u}catch(l){a(!1),r(l instanceof Error?l:new Error("Failed to load plugin"));return}},[]);return{error:n,setError:r,isLoading:i,loadFromURL:s,getFromURL:doNothing}}})(bn||(bn={}));var Ya;(e=>{e.use=qe.create(t=>{const{createPlugin:n}=lQ;return{rootPlugin:n({getGitHash:()=>Rh.get("GIT_HASH"),getStableDiffusionRandomPrompt:()=>ne.Image.Prompt.Random.get()}),setActivePluginID:r=>t({activePluginID:r}),plugins:{},setPlugins:r=>t(typeof r=="function"?({plugins:i})=>({plugins:r(i)}):{plugins:r})}})})(Ya||(Ya={}));function fQ({id:e,noTitle:t,noBrand:n,disabled:r,onIdleClick:i,onClick:a,children:s,...o}){const l=ne.Image.Create.useIsEnabled(),{input:u}=ne.Image.Input.use(e),c=useCallback(p=>{a?.(p),i?.(p)},[i,a]),f=useMemo(()=>u&&ne.Image.Model.StableDiffusionV1.validate(u),[u]);return u?C(G.Button,{size:"lg",color:n?"zinc":"brand",icon:G.Icon.Dream,disabled:r||!l||!f,onClick:c,...o,children:s??(!t&&C(Qe,{children:"Dream"}))}):null}var Ax;(e=>{e.Button=fQ;let t;(r=>{const o=JK(1,500,!0);r.wait=()=>o(()=>Promise.resolve())})(t||(t={})),e.execute=async({count:r=ne.Image.Count.preset(),input:i,onStarted:a=doNothing,onException:s=doNothing,onSuccess:o=doNothing,onFinished:l=doNothing})=>{const{createStableDiffusionImages:u}=bn.get();try{if(!u)throw new Error("Plugin not found");n.set(new Date),a(),await t.wait();const c=await ne.Image.Input.resizeInit(i),f=await ne.Image.Input.toInput(c?{...i,init:{base64:c,weight:i.init?.weight??1,mask:i.init?.mask??!1}}:i);ne.Image.Input.isUpscaling(i)||(f.height=Math.ceil((f.height??512)/64)*64,f.width=Math.ceil((f.width??512)/64)*64);const p=[],m=await u({input:f,count:r});if(m instanceof Error)throw m;if(!m||!m?.images||m?.images?.length<=0)throw new Error;const g={};for(const v of m.images){const x=ID.create(),w={...ne.Image.Input.initial(x),...i,seed:v.input?.seed??i.seed,id:x},b=await pQ(v,w);b&&(p.push(b),g[x]=w)}return ne.Image.Inputs.set({...ne.Image.Inputs.get(),...g}),o(p),l(p),p}catch(c){const f=ne.Image.Exception.create(c);return s(f),l(f),f}},e.use=()=>{const r=ne.Image.Exception.Snackbar.use();return useCallback(async({inputID:i,onStarted:a=doNothing,onException:s=doNothing,onSuccess:o=doNothing,onFinished:l=doNothing,modifiers:u={}})=>{let c=ne.Image.Input.get(i);if(!c)return;c={...c,...u};const 
f=ne.Image.Output.requested(i,u);return(0,e.execute)({count:u.count??ne.Image.Count.get(),input:c,onStarted:()=>{ne.Image.Output.set(f),a(f)},onException:p=>{r(p),s(p),ne.Image.Output.clear(f.id)},onSuccess:p=>{p.forEach(ne.Image.add),o(p)},onFinished:p=>{ne.Image.Output.received(f.id,p),l(p)}})},[r])},e.useIsEnabled=()=>bn.use(({createStableDiffusionImages:r})=>!!r);let n;(r=>{r.get=()=>i.get().latest,r.set=a=>i.get().setLatest(a),r.use=()=>i.use(({latest:a})=>a,qe.shallow);let i;(a=>{const s=qe.create(o=>({setLatest:l=>o({latest:l})}));a.get=()=>s.getState(),a.use=s})(i||(i={}))})(n=e.Latest||(e.Latest={}))})(Ax||(Ax={}));function pQ(e,t){return new Promise(n=>{const r=e.id,i=e.blob;if(!i||!r)return n();const a=document.createElement("canvas");a.width=t.width,a.height=t.height;const s=a.getContext("2d"),o=new window.Image;o.src=URL.createObjectURL(i),o.onload=()=>{s.drawImage(o,0,0,t.width,t.height,0,0,t.width,t.height),a.toBlob(l=>{if(l){const u=URL.createObjectURL(l);n({id:r,inputID:t.id,created:new Date,src:u,finishReason:0})}})}})}function Dh(){const{image:e,setImage:t,fileName:n,setFileName:r,upscale:i,setUpscale:a}=Dh.State.use(),s=ne.Image.Session.useUpscale(e),[o,l]=useState(!1),u=ne.Image.Download.use(e),c=e&&ne.Image.Input.get(e.inputID),f=ne.Image.Input.isUpscaling(c||{});useEffect(()=>{e&&(c&&r(ne.Image.Download.fileName(c)),a(f))},[f,e,c,r,a]);const p=useCallback(async()=>{if(!e)return;l(!0);const g=await s();if(l(!1),Array.isArray(g))return g[0]},[s,e]),m=i&&!f;return C(G.Modal,{modalName:"Download",open:!!e,onClose:()=>t(),children:e&&te(G.Modal.Panel,{className:"flex w-[25rem] grow",children:[C(G.Modal.TopBar,{onClose:()=>t(),children:C(G.Modal.Title,{className:"text-lg",children:"Download"})}),te("div",{className:"flex flex-col gap-3 p-2",children:[te("div",{className:"flex items-center justify-between",children:[te("div",{className:"flex flex-col gap-1",children:[C(G.Label,{className:"mb-0 ml-0",children:"File name"}),C(G.Input,{fullWidth:!0,onFocus:g=>{const v=g.target.value,x=v.slice(v.lastIndexOf("."));g.target.setSelectionRange(0,v.length-x.length)},placeholder:"File name",value:n,onChange:r})]}),te("div",{className:"flex flex-col gap-1",children:[C(G.Label,{className:"mb-0 ml-0",children:"Upscale"}),C(G.Dropdown,{fullWidth:!0,className:"mx-0",options:[{label:"1x",value:"1x"},{label:"2x",value:"2x"}],value:i?"2x":"1x",onChange:g=>a(g.value==="2x"),disabled:f})]})]}),C(G.Button,{fullWidth:!0,size:"lg",color:"brand",icon:G.Icon.Download,loading:o,badgeRight:m&&C(G.Badge,{color:"brand",children:"0.2"}),onClick:()=>{!e||!c||(m?p().then(g=>{g&&u(g),t()}):(u(),t()))},children:o?"Upscaling...":m?"Upscale and download":"Download"})]})]})})}(e=>{(t=>{t.use=qe.create(n=>({setImage:r=>n(i=>({...i,image:r})),setFileName:r=>n(i=>({...i,fileName:r})),setUpscale:r=>n(i=>({...i,upscale:r})),upscale:!1}))})(e.State||(e.State={}))})(Dh||(Dh={}));var Rx;(e=>{e.Modal=Dh,e.fileName=t=>{const r=(t.prompts.map(({text:s})=>s).join(" ")??"").slice(0,50),i=t.model.replace("stable-diffusion-","").replace("stable-","");return`${t.seed}_${r}_${i}.png`},e.execute=async(t,n,r)=>{const i=await ne.Image.blobURL(t);if(!i)return;const a=document.createElement("a");a.href=i,a.download=r??(0,e.fileName)(n),a.click()},e.use=t=>{const{image:n,fileName:r,setImage:i}=Dh.State.use();return useCallback(async a=>{if(a){const o=ne.Image.Input.get(a.inputID);if(o)return(0,e.execute)(a,o)}if(!t)return;const 
s=ne.Image.Input.get(t.inputID);s&&(t===n?await(0,e.execute)(t,s,r):i(t))},[r,t,n,i])}})(Rx||(Rx={}));class mQ extends Error{constructor(t,n="UNKNOWN",r){super(t),this.name="RpcError",Object.setPrototypeOf(this,new.target.prototype),this.code=n,this.meta=r??{}}toString(){const t=[this.name+": "+this.message];this.code&&(t.push(""),t.push("Code: "+this.code)),this.serviceName&&this.methodName&&t.push("Method: "+this.serviceName+"/"+this.methodName);let n=Object.entries(this.meta);if(n.length){t.push(""),t.push("Meta:");for(let[r,i]of n)t.push(` ${r}: ${i}`)}return t.join(` `)}}var Dx;(e=>{e.use=()=>{const{enqueueSnackbar:t}=G.Snackbar.use();return useCallback(n=>t(n.description,{variant:"error"}),[t])}})(Dx||(Dx={}));var Ox;(e=>{e.Snackbar=Dx;const t="Something went wrong on our end, please try again later";function n(o){if(!o&&typeof o!="object")return!1;const l=o;return l.cause instanceof Error&&(l.status===void 0||typeof l.status=="string")&&typeof l.description=="string"}e.is=n;function r(o={}){if(console.error(o),o instanceof mQ||o instanceof Error&&o.name==="RpcError"){const u=o;return{cause:u,description:i(u),status:a(u)?"BANNED_TERM":s(u)?"OUT_OF_CREDITS":u.code}}const l=o instanceof Error?o:new Error(toJSON(o));return{cause:l,description:l.message}}e.create=r;function i(o){switch(!0){case o.message.includes("Completion canceled by user."):return"Canceled";case o.message.includes("Unable to remove the only API key remaining."):return"Sorry, you can't delete your only API key";case o.message.includes("Unable to get organization for request."):return"We couldn't find the organization you're asking for";case s(o):return"Not enough credits";case a(o):return"Something isn't quite right with your prompts";case o.message.includes("Unable to get user for request."):return"We couldn't find the user you're asking for";case o.message.includes("Unable to locate request organization."):return"We couldn't find the organization you're asking for ";case o.message.includes("Error filtering prompts."):return"Something isn't quite right with your prompts";case o.message.includes("Invalid prompts detected"):return"Something isn't quite right with your prompts";case o.message.includes("no prompts provided"):return"You have to provide at least one prompt";case o.message.includes("Unable to create a new billing ticket for this request."):return"We couldn't create a charge for some reason";case o.message.includes("cannot use empty Id for project file"):return"You have to specify a project";case o.message.includes("Incorrect time range provided."):return"Something was wrong with the times you provided";case o.message.includes("must provide image parameters"):return"Something isn't right with the image you provided";case o.message.includes("cannot create project with deleted status"):return"That project was already deleted";case o.message.includes("Incorrect amount value."):return"You can't charge less than $0.00";case o.message.includes('API key with ID "'):return"We couldn't find the API key you're asking for";case o.message.includes('Unable to find organization with ID "'):return"We couldn't find the organization you're asking for";case o.message.includes("No auto-charge intent."):return"You aren't set up for auto-charging";case o.message.includes("You do not have permission to access this resource."):return"You don't the right permissions";case o.message.includes("EmailNotVerifiedMessage"):return"You still need to verify your email address";case o.message.includes("user is not a member of the requested 
organization"):return"You aren't a member of that organization";case o.message.includes("You have insufficient privileges to access this resource."):return"You don't have the right permissions";case o.message.includes("You have too many API keys."):return"There are too many API keys";case o.message.includes("Unable to set default organization."):return"We couldn't set your default organization";case o.message.includes("Unable to update client settings."):return"We couldn't update your settings";case o.message.includes("Unable to connect to the prompt filter requested"):return"We couldn't properly filter your prompt";case o.message.includes("Unable to generate cost env for this request."):return"We couldn't generate a cost estimate for your request";case o.message.includes("Unable to connect to the engine requested"):return"Something's wrong with the model your're trying to use";case o.message.includes("Unable to connect to the classifier requested"):return"We couldn't use the classifier you're asking for";case o.message.includes("An unexpected server error occurred."):return t;case o.message.includes("Unable to create a new API key."):return"We couldn't create an API key for some reason";case o.message.includes("Unable to delete the API key."):return"We couldn't delete the API key for some reason";case o.message.includes("Unable to create an auto-charge intent:"):return"We couldn't set up auto-charge for some reason";case o.message.includes("Unable to get auto-charge intent."):return"We couldn't get your auto-charge settings";case o.message.includes("Unable to create a checkout session for this charge."):return"We couldn't create a charge for some reason";case o.message.includes("Unable to get charges."):return"We couldn't get your charges";case o.message.includes("Unable to delete account, contact support"):return"We couldn't delete your account, please contact support";case o.message.includes("image dimensions must be multiples of 64"):return"We somehow sent an image with the wrong dimensions";default:return t}}function a(o){return o.code==="BANNED_TERM"}e.isBannedTermError=a;function s(o){return o.message.includes("does not have enough balance")}e.isOutOfCreditsError=s})(Ox||(Ox={}));var Oh;(e=>{e.useImage=()=>qc.use(({image:t})=>t),e.useStart=()=>{const t=qc.use(({setImage:n})=>n);return useCallback(n=>t(n),[t])},e.useStop=()=>{const t=qc.use(({setImage:n})=>n);return useCallback(n=>t(void 0),[t])}})(Oh||(Oh={}));var qc;(e=>{e.use=qe.create(t=>({image:void 0,setImage:n=>t({image:n})}))})(qc||(qc={}));function rk({image:e,src:t,onLoadingChange:n,onClick:r,className:i}){const a=useRef(null),s=Oh.useStart(),o=Oh.useStop(),[l,u]=useState(!0),c=t??e?.src;return useEffect(()=>{!a.current?.complete&&u(!0)},[c]),useEffect(()=>{n?.(l)},[l,n]),c?C("img",{src:c,onDragStart:()=>s(e),onDragEnd:()=>o(e),onClick:r,ref:f=>{a.current=f,f&&(f.onload=()=>u(!1))},className:classes("h-full w-full object-cover opacity-0 duration-500",!l&&"opacity-100",i)}):null}rk.Dragging=Oh;/** * react-virtual diff --git a/imaginairy/http_app/stablestudio/models.py b/imaginairy/http_app/stablestudio/models.py index 77cb39ef..a0fb2978 100644 --- a/imaginairy/http_app/stablestudio/models.py +++ b/imaginairy/http_app/stablestudio/models.py @@ -26,7 +26,7 @@ class StableStudioStyle(BaseModel): image: Optional[HttpUrl] = None -class StableStudioSampler(BaseModel): +class StableStudioSolver(BaseModel): id: str name: Optional[str] = None @@ -55,7 +55,7 @@ class StableStudioInput(BaseModel, extra=Extra.forbid): style: 
Optional[str] = None width: Optional[int] = None height: Optional[int] = None - sampler: Optional[StableStudioSampler] = None + solver: Optional[StableStudioSolver] = Field(None, alias="sampler") cfg_scale: Optional[float] = Field(None, alias="cfgScale") steps: Optional[int] = None seed: Optional[int] = None @@ -88,18 +88,17 @@ def to_imagine_prompt(self): mask_image = self.mask_image.blob if self.mask_image else None - sampler_type = self.sampler.id if self.sampler else None + solver_type = self.solver.id if self.solver else None return ImaginePrompt( prompt=positive_prompt, prompt_strength=self.cfg_scale, negative_prompt=negative_prompt, - model=self.model, - sampler_type=sampler_type, + model_weights=self.model, + solver_type=solver_type, seed=self.seed, steps=self.steps, - height=self.height, - width=self.width, + size=(self.width, self.height), init_image=Image.open(BytesIO(init_image)) if init_image else None, init_image_strength=init_image_strength, mask_image=Image.open(BytesIO(mask_image)) if mask_image else None, diff --git a/imaginairy/http_app/stablestudio/routes.py b/imaginairy/http_app/stablestudio/routes.py index 8aa3e5c8..bebe5e22 100644 --- a/imaginairy/http_app/stablestudio/routes.py +++ b/imaginairy/http_app/stablestudio/routes.py @@ -8,7 +8,7 @@ StableStudioBatchResponse, StableStudioImage, StableStudioModel, - StableStudioSampler, + StableStudioSolver, ) from imaginairy.http_app.utils import generate_image_b64 @@ -37,11 +37,14 @@ async def generate(studio_request: StableStudioBatchRequest): @router.get("/samplers") async def list_samplers(): - from imaginairy.config import SAMPLER_TYPE_OPTIONS + from imaginairy.config import SOLVER_CONFIGS sampler_objs = [] - for sampler_type in SAMPLER_TYPE_OPTIONS: - sampler_obj = StableStudioSampler(id=sampler_type, name=sampler_type) + + for solver_config in SOLVER_CONFIGS: + sampler_obj = StableStudioSolver( + id=solver_config.aliases[0], name=solver_config.aliases[0] + ) sampler_objs.append(sampler_obj) return sampler_objs @@ -49,16 +52,18 @@ async def list_samplers(): @router.get("/models") async def list_models(): - from imaginairy.config import MODEL_CONFIGS + from imaginairy.config import MODEL_WEIGHT_CONFIGS model_objs = [] - for model_config in MODEL_CONFIGS: - if "inpaint" in model_config.description.lower(): + for model_config in MODEL_WEIGHT_CONFIGS: + if "inpaint" in model_config.name.lower(): + continue + if model_config.architecture.output_modality != "image": continue model_obj = StableStudioModel( - id=model_config.short_name, - name=model_config.description, - description=model_config.description, + id=model_config.aliases[0], + name=model_config.name, + description=model_config.name, ) model_objs.append(model_obj) diff --git a/imaginairy/http_app/utils.py b/imaginairy/http_app/utils.py index 68551961..2eb76189 100644 --- a/imaginairy/http_app/utils.py +++ b/imaginairy/http_app/utils.py @@ -1,7 +1,7 @@ import base64 from io import BytesIO -from imaginairy import imagine +from imaginairy.api import imagine def generate_image(prompt): diff --git a/imaginairy/img_processors/control_modes.py b/imaginairy/img_processors/control_modes.py index 2679d1ef..74b16f8d 100644 --- a/imaginairy/img_processors/control_modes.py +++ b/imaginairy/img_processors/control_modes.py @@ -1,7 +1,12 @@ """Functions to create hint images for controlnet.""" +from typing import TYPE_CHECKING, Callable, Dict, Union + +if TYPE_CHECKING: + import numpy as np + from torch import Tensor # noqa -def create_canny_edges(img): +def 
create_canny_edges(img: "Tensor") -> "Tensor": import cv2 import numpy as np import torch @@ -33,7 +38,7 @@ def create_canny_edges(img): return canny_image -def create_depth_map(img): +def create_depth_map(img: "Tensor") -> "Tensor": import torch orig_size = img.shape[2:] @@ -56,7 +61,7 @@ def create_depth_map(img): return depth_pt -def _create_depth_map_raw(img): +def _create_depth_map_raw(img: "Tensor") -> "Tensor": import torch from imaginairy.modules.midas.api import MiDaSInference, midas_device @@ -83,7 +88,7 @@ def _create_depth_map_raw(img): return depth_pt -def create_normal_map(img): +def create_normal_map(img: "Tensor") -> "Tensor": import torch from imaginairy.vendored.imaginairy_normal_map.model import ( @@ -97,7 +102,7 @@ def create_normal_map(img): return normal_img_t -def create_hed_edges(img_t): +def create_hed_edges(img_t: "Tensor") -> "Tensor": import torch from imaginairy.img_processors.hed_boundary import create_hed_map @@ -120,7 +125,7 @@ def create_hed_edges(img_t): return hint_t -def create_pose_map(img_t): +def create_pose_map(img_t: "Tensor"): from imaginairy.img_processors.openpose import create_body_pose_img from imaginairy.utils import get_device @@ -130,7 +135,7 @@ def create_pose_map(img_t): return pose_t -def make_noise_disk(H, W, C, F): +def make_noise_disk(H: int, W: int, C: int, F: int) -> "np.ndarray": import cv2 import numpy as np @@ -144,7 +149,7 @@ def make_noise_disk(H, W, C, F): return noise -def shuffle_map_np(img, h=None, w=None, f=256): +def shuffle_map_np(img: "np.ndarray", h=None, w=None, f=256) -> "np.ndarray": import cv2 import numpy as np @@ -160,7 +165,7 @@ def shuffle_map_np(img, h=None, w=None, f=256): return cv2.remap(img, flow, None, cv2.INTER_LINEAR) -def shuffle_map_torch(tensor, h=None, w=None, f=256): +def shuffle_map_torch(tensor: "Tensor", h=None, w=None, f=256) -> "Tensor": import torch # Assuming the input tensor is in shape (B, C, H, W) @@ -187,7 +192,7 @@ def shuffle_map_torch(tensor, h=None, w=None, f=256): return shuffled_tensor.to(device) -def inpaint_prep(mask_image_t, target_image_t): +def inpaint_prep(mask_image_t: "Tensor", target_image_t: "Tensor") -> "Tensor": """ Combines the masked image and target image into a single tensor. 
@@ -207,7 +212,7 @@ def inpaint_prep(mask_image_t, target_image_t): return output_image_t -def to_grayscale(img): +def to_grayscale(img: "Tensor") -> "Tensor": # The dimensions of input should be (batch_size, channels, height, width) if img.dim() != 4: raise ValueError("Input should be a 4d tensor") @@ -228,11 +233,13 @@ def to_grayscale(img): return (gray_3_channels + 1.0) / 2.0 -def noop(img): +def noop(img: "Tensor") -> "Tensor": return (img + 1.0) / 2.0 -CONTROL_MODES = { +FunctionType = Union["Callable[[Tensor, Tensor], Tensor]", "Callable[[Tensor], Tensor]"] + +CONTROL_MODES: Dict[str, FunctionType] = { "canny": create_canny_edges, "depth": create_depth_map, "normal": create_normal_map, diff --git a/imaginairy/img_utils.py b/imaginairy/img_utils.py index b4bc4664..961c0ca3 100644 --- a/imaginairy/img_utils.py +++ b/imaginairy/img_utils.py @@ -23,8 +23,12 @@ def pillow_fit_image_within( - image: PIL.Image.Image, max_height=512, max_width=512, convert="RGB", snap_size=8 -): + image: PIL.Image.Image | LazyLoadingImage, + max_height=512, + max_width=512, + convert="RGB", + snap_size=8, +) -> PIL.Image.Image: image = image.convert(convert) w, h = image.size resize_ratio = 1 @@ -45,17 +49,21 @@ def pillow_fit_image_within( return image -def pillow_img_to_torch_image(img: PIL.Image.Image, convert="RGB"): +def pillow_img_to_torch_image( + img: PIL.Image.Image | LazyLoadingImage, convert="RGB" +) -> torch.Tensor: if convert: img = img.convert(convert) - img = np.array(img).astype(np.float32) / 255.0 + img_np = np.array(img).astype(np.float32) / 255.0 # b, h, w, c => b, c, h, w - img = img[None].transpose(0, 3, 1, 2) - img = torch.from_numpy(img) - return 2.0 * img - 1.0 + img_np = img_np[None].transpose(0, 3, 1, 2) + img_t = torch.from_numpy(img_np) + return 2.0 * img_t - 1.0 -def pillow_mask_to_latent_mask(mask_img: PIL.Image.Image, downsampling_factor): +def pillow_mask_to_latent_mask( + mask_img: PIL.Image.Image | LazyLoadingImage, downsampling_factor +) -> torch.Tensor: mask_img = mask_img.resize( ( mask_img.width // downsampling_factor, @@ -66,31 +74,31 @@ def pillow_mask_to_latent_mask(mask_img: PIL.Image.Image, downsampling_factor): mask = np.array(mask_img).astype(np.float32) / 255.0 mask = mask[None, None] - mask = torch.from_numpy(mask) - return mask + mask_t = torch.from_numpy(mask) + return mask_t -def pillow_img_to_opencv_img(img: PIL.Image.Image): +def pillow_img_to_opencv_img(img: PIL.Image.Image | LazyLoadingImage): open_cv_image = np.array(img) # Convert RGB to BGR open_cv_image = open_cv_image[:, :, ::-1].copy() return open_cv_image -def torch_image_to_openvcv_img(img: torch.Tensor): +def torch_image_to_openvcv_img(img: torch.Tensor) -> np.ndarray: img = (img + 1) / 2 - img = img.detach().cpu().numpy() + img_np = img.detach().cpu().numpy() # assert there is only one image - assert img.shape[0] == 1 - img = img[0] - img = img.transpose(1, 2, 0) - img = (img * 255).astype(np.uint8) + assert img_np.shape[0] == 1 + img_np = img_np[0] + img_np = img_np.transpose(1, 2, 0) + img_np = (img_np * 255).astype(np.uint8) # RGB to BGR - img = img[:, :, ::-1] - return img + img_np = img_np[:, :, ::-1] + return img_np -def torch_img_to_pillow_img(img_t: torch.Tensor): +def torch_img_to_pillow_img(img_t: torch.Tensor) -> PIL.Image.Image: img_t = img_t.to(torch.float32).detach().cpu() if len(img_t.shape) == 3: img_t = img_t.unsqueeze(0) @@ -129,7 +137,9 @@ def model_latents_to_pillow_imgs(latents: torch.Tensor) -> Sequence[PIL.Image.Im return [model_latent_to_pillow_img(latent) for latent in 
latents] -def pillow_img_to_model_latent(model, img, batch_size=1, half=True): +def pillow_img_to_model_latent( + model, img: PIL.Image.Image | LazyLoadingImage, batch_size=1, half=True +): init_image = pillow_img_to_torch_image(img).to(get_device()) init_image = repeat(init_image, "1 ... -> b ...", b=batch_size) if half: @@ -152,14 +162,18 @@ def imgpaths_to_imgs(imgpaths): def add_caption_to_image( - img, caption, font_size=16, font_path=f"{PKG_ROOT}/data/DejaVuSans.ttf" + img: PIL.Image.Image | LazyLoadingImage, + caption, + font_size=16, + font_path=f"{PKG_ROOT}/data/DejaVuSans.ttf", ): - draw = ImageDraw.Draw(img) + img_pil = img.as_pillow() if isinstance(img, LazyLoadingImage) else img + draw = ImageDraw.Draw(img_pil) font = ImageFont.truetype(font_path, font_size) x = 15 - y = img.height - 15 - font_size + y = img_pil.height - 15 - font_size draw.text( (x, y), diff --git a/imaginairy/model_manager.py b/imaginairy/model_manager.py index a2a18c0d..80dcee62 100644 --- a/imaginairy/model_manager.py +++ b/imaginairy/model_manager.py @@ -13,16 +13,17 @@ try_to_load_from_cache, ) from omegaconf import OmegaConf -from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter, SD1UNet +from refiners.foundationals.latent_diffusion import SD1UNet +from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel from safetensors.torch import load_file from imaginairy import config as iconfig -from imaginairy.config import MODEL_SHORT_NAMES +from imaginairy.config import IMAGE_WEIGHTS_SHORT_NAMES, ModelArchitecture from imaginairy.modules import attention from imaginairy.paths import PKG_ROOT from imaginairy.utils import get_device, instantiate_from_config from imaginairy.utils.model_cache import memory_managed_model -from imaginairy.weight_management.conversion import cast_weights +from imaginairy.utils.named_resolutions import normalize_image_size logger = logging.getLogger(__name__) @@ -66,7 +67,7 @@ def load_state_dict(weights_location, half_mode=False, device=None): except FileNotFoundError as e: if e.errno == 2: logger.error( - f'Error: "{ckpt_path}" not a valid path to model weights.\nPreconfigured models you can use: {MODEL_SHORT_NAMES}.' + f'Error: "{ckpt_path}" not a valid path to model weights.\nPreconfigured models you can use: {IMAGE_WEIGHTS_SHORT_NAMES}.' ) sys.exit(1) raise @@ -149,7 +150,7 @@ def add_controlnet(base_state_dict, controlnet_state_dict): def get_diffusion_model( - weights_location=iconfig.DEFAULT_MODEL, + weights_location=iconfig.DEFAULT_MODEL_WEIGHTS, config_path="configs/stable-diffusion-v1.yaml", control_weights_locations=None, half_mode=None, @@ -174,7 +175,7 @@ def get_diffusion_model( f"Failed to load inpainting model. Attempting to fall-back to standard model. 
{e!s}" ) return _get_diffusion_model( - iconfig.DEFAULT_MODEL, + iconfig.DEFAULT_MODEL_WEIGHTS, config_path, half_mode, for_inpainting=False, @@ -184,8 +185,8 @@ def get_diffusion_model( def _get_diffusion_model( - weights_location=iconfig.DEFAULT_MODEL, - config_path="configs/stable-diffusion-v1.yaml", + weights_location=iconfig.DEFAULT_MODEL_WEIGHTS, + model_architecture="configs/stable-diffusion-v1.yaml", half_mode=None, for_inpainting=False, control_weights_locations=None, @@ -197,24 +198,20 @@ def _get_diffusion_model( """ global MOST_RECENTLY_LOADED_MODEL - ( - model_config, - weights_location, - config_path, - control_weights_locations, - ) = resolve_model_paths( - weights_path=weights_location, - config_path=config_path, - control_weights_paths=control_weights_locations, + model_weights_config = resolve_model_weights_config( + model_weights=weights_location, + default_model_architecture=model_architecture, for_inpainting=for_inpainting, ) # some models need the attention calculated in float32 - if model_config is not None: - attention.ATTENTION_PRECISION_OVERRIDE = model_config.forced_attn_precision + if model_weights_config is not None: + attention.ATTENTION_PRECISION_OVERRIDE = ( + model_weights_config.forced_attn_precision + ) else: attention.ATTENTION_PRECISION_OVERRIDE = "default" diffusion_model = _load_diffusion_model( - config_path=config_path, + config_path=model_weights_config.architecture.config_path, weights_location=weights_location, half_mode=half_mode, ) @@ -229,74 +226,25 @@ def _get_diffusion_model( def get_diffusion_model_refiners( - weights_location=iconfig.DEFAULT_MODEL, - config_path="configs/stable-diffusion-v1.yaml", - control_weights_locations=None, - dtype=None, - for_inpainting=False, -): - """ - Load a diffusion model. - - Weights location may also be shortcut name, e.g. "SD-1.5" - """ - try: - return _get_diffusion_model_refiners( - weights_location, - config_path, - for_inpainting, - dtype=dtype, - control_weights_locations=control_weights_locations, - ) - except HuggingFaceAuthorizationError as e: - if for_inpainting: - logger.warning( - f"Failed to load inpainting model. Attempting to fall-back to standard model. {e!s}" - ) - return _get_diffusion_model_refiners( - iconfig.DEFAULT_MODEL, - config_path, - dtype=dtype, - for_inpainting=False, - control_weights_locations=control_weights_locations, - ) - raise - - -def _get_diffusion_model_refiners( - weights_location=iconfig.DEFAULT_MODEL, - config_path="configs/stable-diffusion-v1.yaml", + weights_config: iconfig.ModelWeightsConfig, for_inpainting=False, - control_weights_locations=None, - device=None, - dtype=torch.float16, -): - """ - Load a diffusion model. - - Weights location may also be shortcut name, e.g. "SD-1.5" - """ - - sd = _get_diffusion_model_refiners_only( - weights_location=weights_location, - config_path=config_path, + dtype=None, +) -> LatentDiffusionModel: + """Load a diffusion model.""" + return _get_diffusion_model_refiners( + weights_location=weights_config.weights_location, for_inpainting=for_inpainting, - device=device, dtype=dtype, ) - return sd - @lru_cache(maxsize=1) -def _get_diffusion_model_refiners_only( - weights_location=iconfig.DEFAULT_MODEL, - config_path="configs/stable-diffusion-v1.yaml", - for_inpainting=False, - control_weights_locations=None, +def _get_diffusion_model_refiners( + weights_location: str, + for_inpainting: bool = False, device=None, dtype=torch.float16, -): +) -> LatentDiffusionModel: """ Load a diffusion model. 
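
# --- Illustrative sketch, not part of this patch: the refactored loading API ---
# Shows the assumed call pattern after this change: callers resolve a
# ModelWeightsConfig first and hand it to get_diffusion_model_refiners(),
# instead of passing weights/config path strings. The "SD-1.5" alias is an
# assumption based on the shortcut names this module already documents;
# substitute any registered alias or a path to custom weights.

from imaginairy.model_manager import (
    get_diffusion_model_refiners,
    resolve_model_weights_config,
)

weights_config = resolve_model_weights_config(model_weights="SD-1.5")  # lookup is lower-cased
sd = get_diffusion_model_refiners(weights_config, for_inpainting=False)
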
@@ -312,29 +260,13 @@ def _get_diffusion_model_refiners_only( device = device or get_device() - ( - model_config, - weights_location, - config_path, - control_weights_locations, - ) = resolve_model_paths( - weights_path=weights_location, - config_path=config_path, - control_weights_paths=control_weights_locations, - for_inpainting=for_inpainting, - ) - # some models need the attention calculated in float32 - if model_config is not None: - attention.ATTENTION_PRECISION_OVERRIDE = model_config.forced_attn_precision - else: - attention.ATTENTION_PRECISION_OVERRIDE = "default" - ( vae_weights, unet_weights, text_encoder_weights, ) = load_stable_diffusion_compvis_weights(weights_location) + StableDiffusionCls: type[LatentDiffusionModel] if for_inpainting: unet = SD1UNet(in_channels=9) StableDiffusionCls = StableDiffusion_1_Inpainting @@ -379,32 +311,6 @@ def _load_diffusion_model(config_path, weights_location, half_mode): return model -def load_controlnet_adapter( - name, - control_weights_location, - target_unet, - scale=1.0, -): - controlnet_state_dict = load_state_dict(control_weights_location, half_mode=False) - controlnet_state_dict = cast_weights( - source_weights=controlnet_state_dict, - source_model_name="controlnet-1-1", - source_component_name="all", - source_format="diffusers", - dest_format="refiners", - ) - - for key in controlnet_state_dict: - controlnet_state_dict[key] = controlnet_state_dict[key].to( - device=target_unet.device, dtype=target_unet.dtype - ) - adapter = SD1ControlnetAdapter( - target=target_unet, name=name, scale=scale, weights=controlnet_state_dict - ) - - return adapter - - @memory_managed_model("controlnet") def load_controlnet(control_weights_location, half_mode): controlnet_state_dict = load_state_dict( @@ -422,58 +328,82 @@ def load_controlnet(control_weights_location, half_mode): return controlnet -def resolve_model_paths( - weights_path=iconfig.DEFAULT_MODEL, - config_path=None, - control_weights_paths=None, - for_inpainting=False, -): +def resolve_model_weights_config( + model_weights: str | iconfig.ModelWeightsConfig, + default_model_architecture: str | None = None, + for_inpainting: bool = False, +) -> iconfig.ModelWeightsConfig: """Resolve weight and config path if they happen to be shortcuts.""" - model_metadata_w = iconfig.MODEL_CONFIG_SHORTCUTS.get(weights_path, None) - model_metadata_c = iconfig.MODEL_CONFIG_SHORTCUTS.get(config_path, None) - - control_weights_paths = control_weights_paths or [] - control_net_metadatas = [ - iconfig.CONTROLNET_CONFIG_SHORTCUTS.get(control_weights_path, None) - for control_weights_path in control_weights_paths - ] - - if not control_net_metadatas and for_inpainting: - model_metadata_w = iconfig.MODEL_CONFIG_SHORTCUTS.get( - f"{weights_path}-inpaint", model_metadata_w + if isinstance(model_weights, iconfig.ModelWeightsConfig): + return model_weights + + if not isinstance(model_weights, str): + msg = f"Invalid model weights: {model_weights}" + raise ValueError(msg) # noqa + + if default_model_architecture is not None and not isinstance( + default_model_architecture, str + ): + msg = f"Invalid model architecture: {default_model_architecture}" + raise ValueError(msg) + + if for_inpainting: + model_weights_config = iconfig.MODEL_WEIGHT_CONFIG_LOOKUP.get( + f"{model_weights.lower()}-inpaint", None ) - model_metadata_c = iconfig.MODEL_CONFIG_SHORTCUTS.get( - f"{config_path}-inpaint", model_metadata_c + if model_weights_config: + return model_weights_config + + model_weights_config = iconfig.MODEL_WEIGHT_CONFIG_LOOKUP.get( 
+ model_weights.lower(), None + ) + if model_weights_config: + return model_weights_config + + if not default_model_architecture: + msg = "You must specify the model architecture when loading custom weights." + raise ValueError(msg) + + default_model_architecture = default_model_architecture.lower() + model_architecture_config = None + if for_inpainting: + model_architecture_config = iconfig.MODEL_ARCHITECTURE_LOOKUP.get( + f"{default_model_architecture}-inpaint", None ) - if model_metadata_w: - if config_path is None: - config_path = model_metadata_w.config_path + if not model_architecture_config: + model_architecture_config = iconfig.MODEL_ARCHITECTURE_LOOKUP.get( + default_model_architecture, None + ) - weights_path = model_metadata_w.weights_url + if model_architecture_config is None: + msg = f"Invalid model architecture: {default_model_architecture}" + raise ValueError(msg) - if model_metadata_c: - config_path = model_metadata_c.config_path + model_weights_config = iconfig.ModelWeightsConfig( + name="Custom Loaded", + aliases=[], + architecture=model_architecture_config, + weights_location=model_weights, + defaults={}, + ) - if config_path is None: - config_path = iconfig.MODEL_CONFIG_SHORTCUTS[iconfig.DEFAULT_MODEL].config_path - if control_net_metadatas: - if "stable-diffusion-v1" not in config_path: - msg = "Control net is only supported for stable diffusion v1. Please use a different model." - raise ValueError(msg) - control_weights_paths = [cnm.weights_url for cnm in control_net_metadatas] - config_path = control_net_metadatas[0].config_path - model_metadata = model_metadata_w or model_metadata_c - logger.debug(f"Loading model weights from: {weights_path}") - logger.debug(f"Loading model config from: {config_path}") - return model_metadata, weights_path, config_path, control_weights_paths + return model_weights_config -def get_model_default_image_size(weights_location): - model_config = iconfig.MODEL_CONFIG_SHORTCUTS.get(weights_location, None) - if model_config: - return model_config.default_image_size - return 512 +def get_model_default_image_size(model_architecture: str | ModelArchitecture | None): + if isinstance(model_architecture, str): + model_architecture = iconfig.MODEL_ARCHITECTURE_LOOKUP.get( + model_architecture, None + ) + default_size = None + if model_architecture: + default_size = model_architecture.defaults.get("size") + + if default_size is None: + default_size = 512 + default_size = normalize_image_size(default_size) + return default_size def get_current_diffusion_model(): @@ -680,7 +610,6 @@ def open_weights(filepath, device=None): return state_dict -@lru_cache def load_stable_diffusion_compvis_weights(weights_url): from imaginairy.model_manager import get_cached_url_path from imaginairy.utils import get_device diff --git a/imaginairy/modules/cldm.py b/imaginairy/modules/cldm.py index 8a470818..fc77df37 100644 --- a/imaginairy/modules/cldm.py +++ b/imaginairy/modules/cldm.py @@ -2,7 +2,7 @@ from torch import nn from imaginairy.modules.attention import SpatialTransformer -from imaginairy.modules.diffusion.ddpm import LatentDiffusion +from imaginairy.modules.diffusion.ddpm import LatentDiffusion # type: ignore from imaginairy.modules.diffusion.openaimodel import ( AttentionBlock, Downsample, diff --git a/imaginairy/modules/diffusion/ddpm.py b/imaginairy/modules/diffusion/ddpm.py index 8ed146fc..61e5de78 100644 --- a/imaginairy/modules/diffusion/ddpm.py +++ b/imaginairy/modules/diffusion/ddpm.py @@ -1,3 +1,5 @@ +# type: ignore + """ wild mixture of 
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py @@ -888,7 +890,7 @@ def _get_denoise_row_from_list(self, samples, desc=""): denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) return denoise_grid - def get_first_stage_encoding(self, encoder_posterior): + def get_first_stage_encoding(self, encoder_posterior) -> torch.Tensor: if isinstance(encoder_posterior, DiagonalGaussianDistribution): z = encoder_posterior.mode() elif isinstance(encoder_posterior, torch.Tensor): diff --git a/imaginairy/modules/diffusion/model.py b/imaginairy/modules/diffusion/model.py index 043e22ed..89053e62 100644 --- a/imaginairy/modules/diffusion/model.py +++ b/imaginairy/modules/diffusion/model.py @@ -1,7 +1,6 @@ # pytorch_diffusion + derived encoder decoder import gc import math -from typing import Any, Optional import numpy as np import torch @@ -300,7 +299,7 @@ def __init__(self, in_channels): self.proj_out = torch.nn.Conv2d( in_channels, in_channels, kernel_size=1, stride=1, padding=0 ) - self.attention_op: Optional[Any] = None + self.attention_op = None def forward(self, x): h_ = x diff --git a/imaginairy/modules/refiners_sd.py b/imaginairy/modules/refiners_sd.py index 893774c7..86a37c95 100644 --- a/imaginairy/modules/refiners_sd.py +++ b/imaginairy/modules/refiners_sd.py @@ -1,13 +1,20 @@ +import logging import math +from functools import lru_cache from typing import Literal import torch from refiners.fluxion.layers.attentions import ScaledDotProductAttention from refiners.fluxion.layers.chain import ChainError from refiners.foundationals.latent_diffusion import ( + SD1ControlnetAdapter, + SD1UNet, StableDiffusion_1 as RefinerStableDiffusion_1, StableDiffusion_1_Inpainting as RefinerStableDiffusion_1_Inpainting, ) +from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import ( + Controlnet, +) from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import ( SD1Autoencoder, ) @@ -16,7 +23,9 @@ from torch.nn.modules.utils import _pair from imaginairy.feather_tile import rebuild_image, tile_image -from imaginairy.modules.autoencoder import logger +from imaginairy.weight_management.conversion import cast_weights + +logger = logging.getLogger(__name__) TileModeType = Literal["", "x", "y", "xy"] @@ -53,13 +62,13 @@ def set_tile_mode(self, tile_mode: TileModeType = ""): if isinstance(m, nn.Conv2d): if not hasattr(m, "_orig_conv_forward"): # patch with a function that can handle tiling in a single direction - m._initial_padding_mode = m.padding_mode - m._orig_conv_forward = m._conv_forward - m._conv_forward = _tile_mode_conv2d_conv_forward.__get__( + m._initial_padding_mode = m.padding_mode # type: ignore + m._orig_conv_forward = m._conv_forward # type: ignore + m._conv_forward = _tile_mode_conv2d_conv_forward.__get__( # type: ignore m, nn.Conv2d ) - m.padding_modeX = "circular" if tile_x else "constant" - m.padding_modeY = "circular" if tile_y else "constant" + m.padding_modeX = "circular" if tile_x else "constant" # type: ignore + m.padding_modeY = "circular" if tile_y else "constant" # type: ignore if m.padding_modeY == m.padding_modeX: m.padding_mode = m.padding_modeX m.paddingX = ( @@ -67,13 +76,13 @@ def set_tile_mode(self, tile_mode: TileModeType = ""): m._reversed_padding_repeated_twice[1], 0, 0, - ) + ) # type: ignore m.paddingY = ( 0, 0, m._reversed_padding_repeated_twice[2], m._reversed_padding_repeated_twice[3], - ) + ) # type: ignore class 
StableDiffusion_1(TileModeMixin, RefinerStableDiffusion_1): @@ -258,3 +267,56 @@ def _process_attention(self, query, key, value, is_causal=None): add_sliced_attention_to_scaled_dot_product_attention(ScaledDotProductAttention) + + +@lru_cache +def monkeypatch_sd1controlnetadapter(): + """ + Another horrible thing. + + I needed to be able to cache the controlnet objects so I wouldn't be making new ones on every image generation. + """ + + def __init__( + self, + target: SD1UNet, + name: str, + weights_location: str, + ) -> None: + self.name = name + controlnet = get_controlnet( + name=name, + weights_location=weights_location, + device=target.device, + dtype=target.dtype, + ) + + self._controlnet: list[Controlnet] = [ # type: ignore + controlnet + ] # not registered by PyTorch + + with self.setup_adapter(target): + super(SD1ControlnetAdapter, self).__init__(target) + + SD1ControlnetAdapter.__init__ = __init__ + + +monkeypatch_sd1controlnetadapter() + + +@lru_cache(maxsize=4) +def get_controlnet(name, weights_location, device, dtype): + from imaginairy.model_manager import load_state_dict + + controlnet_state_dict = load_state_dict(weights_location, half_mode=False) + controlnet_state_dict = cast_weights( + source_weights=controlnet_state_dict, + source_model_name="controlnet-1-1", + source_component_name="all", + source_format="diffusers", + dest_format="refiners", + ) + + controlnet = Controlnet(name=name, scale=1, device=device, dtype=dtype) + controlnet.load_state_dict(controlnet_state_dict) + return controlnet diff --git a/imaginairy/prompt_schedules.py b/imaginairy/prompt_schedules.py index 3072180a..31ba094e 100644 --- a/imaginairy/prompt_schedules.py +++ b/imaginairy/prompt_schedules.py @@ -2,7 +2,7 @@ import re from copy import copy -from imaginairy import ImaginePrompt +from imaginairy.schema import ImaginePrompt from imaginairy.utils import frange diff --git a/imaginairy/safety.py b/imaginairy/safety.py index 4db7e32e..5e1d1d1f 100644 --- a/imaginairy/safety.py +++ b/imaginairy/safety.py @@ -21,7 +21,7 @@ def __init__(self): self.special_care_scores = {} self.is_filtered = False - def add_special_care_score(self, concept_idx, abs_score, threshold): + def add_special_care_score(self, concept_idx: int, abs_score, threshold): adjustment = self._default_adjustment adjusted_score = round(abs_score - threshold + adjustment, 3) try: @@ -138,8 +138,8 @@ def cosine_distance_float32(image_embeds, text_embeds): safety_checker_mod.cosine_distance = cosine_distance_float32 -_CONCEPT_DESCRIPTIONS = [] -_SPECIAL_CARE_DESCRIPTIONS = [] +_CONCEPT_DESCRIPTIONS: list[str] = [] +_SPECIAL_CARE_DESCRIPTIONS: list[str] = [] def create_safety_score(img, safety_mode=SafetyMode.STRICT): diff --git a/imaginairy/samplers/__init__.py b/imaginairy/samplers/__init__.py index 229a5b3b..825b6cd9 100644 --- a/imaginairy/samplers/__init__.py +++ b/imaginairy/samplers/__init__.py @@ -1,21 +1,20 @@ -from imaginairy.samplers import kdiff -from imaginairy.samplers.base import SamplerName # noqa -from imaginairy.samplers.ddim import DDIMSampler +from imaginairy.samplers.base import SolverName # noqa +from imaginairy.samplers.ddim import DDIMSolver -SAMPLERS = [ - # PLMSSampler, - DDIMSampler, +SOLVERS = [ + # PLMSSolver, + DDIMSolver, # kdiff.DPMFastSampler, # kdiff.DPMAdaptiveSampler, # kdiff.LMSSampler, # kdiff.DPM2Sampler, # kdiff.DPM2AncestralSampler, - kdiff.DPMPP2MSampler, + # kdiff.DPMPP2MSampler, # kdiff.DPMPP2SAncestralSampler, # kdiff.EulerSampler, # kdiff.EulerAncestralSampler, # kdiff.HeunSampler, ] 
-SAMPLER_LOOKUP = {sampler.short_name: sampler for sampler in SAMPLERS} -SAMPLER_TYPE_OPTIONS = [sampler.short_name for sampler in SAMPLERS] +SOLVER_LOOKUP = {s.short_name: s for s in SOLVERS} +SOLVER_TYPE_OPTIONS = [s.short_name for s in SOLVERS] diff --git a/imaginairy/samplers/base.py b/imaginairy/samplers/base.py index 5f411df8..e0203dc3 100644 --- a/imaginairy/samplers/base.py +++ b/imaginairy/samplers/base.py @@ -16,9 +16,10 @@ logger = logging.getLogger(__name__) -class SamplerName: +class SolverName: PLMS = "plms" DDIM = "ddim" + DPMPP = "dpmpp" K_DPM_FAST = "k_dpm_fast" K_DPM_ADAPTIVE = "k_dpm_adaptive" K_LMS = "k_lms" @@ -31,7 +32,7 @@ class SamplerName: K_HEUN = "k_heun" -class ImageSampler(ABC): +class ImageSolver(ABC): short_name: str name: str default_steps: int diff --git a/imaginairy/samplers/ddim.py b/imaginairy/samplers/ddim.py index 4405688a..f3319e9f 100644 --- a/imaginairy/samplers/ddim.py +++ b/imaginairy/samplers/ddim.py @@ -8,9 +8,9 @@ from imaginairy.log_utils import increment_step, log_latent from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like from imaginairy.samplers.base import ( - ImageSampler, + ImageSolver, NoiseSchedule, - SamplerName, + SolverName, get_noise_prediction, mask_blend, ) @@ -19,14 +19,14 @@ logger = logging.getLogger(__name__) -class DDIMSampler(ImageSampler): +class DDIMSolver(ImageSolver): """ Denoising Diffusion Implicit Models. https://arxiv.org/abs/2010.02502 """ - short_name = SamplerName.DDIM + short_name = SolverName.DDIM name = "Denoising Diffusion Implicit Models" default_steps = 50 diff --git a/imaginairy/samplers/kdiff.py b/imaginairy/samplers/kdiff.py index bbe8e7c8..474ba872 100644 --- a/imaginairy/samplers/kdiff.py +++ b/imaginairy/samplers/kdiff.py @@ -1,13 +1,14 @@ # pylama:ignore=W0613 from abc import ABC +from typing import Callable import torch from torch import nn from imaginairy.log_utils import increment_step, log_latent from imaginairy.samplers.base import ( - ImageSampler, - SamplerName, + ImageSolver, + SolverName, get_noise_prediction, mask_blend, ) @@ -57,8 +58,8 @@ def sample_dpm_fast(model, x, sigmas, extra_args=None, disable=False, callback=N ) -class KDiffusionSampler(ImageSampler, ABC): - sampler_func: callable +class KDiffusionSolver(ImageSolver, ABC): + sampler_func: Callable def __init__(self, model): super().__init__(model) @@ -98,9 +99,9 @@ def sample( # see https://github.com/crowsonkb/k-diffusion/issues/43#issuecomment-1305195666 if self.short_name in ( - SamplerName.K_DPM_2, - SamplerName.K_DPMPP_2M, - SamplerName.K_DPM_2_ANCESTRAL, + SolverName.K_DPM_2, + SolverName.K_DPMPP_2M, + SolverName.K_DPM_2_ANCESTRAL, ): sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) @@ -152,73 +153,73 @@ def callback(data): # -# class DPMFastSampler(KDiffusionSampler): -# short_name = SamplerName.K_DPM_FAST +# class DPMFastSampler(KDiffusionSolver): +# short_name = SolverName.K_DPM_FAST # name = "Diffusion probabilistic models - fast" # default_steps = 15 # sampler_func = staticmethod(sample_dpm_fast) # # -# class DPMAdaptiveSampler(KDiffusionSampler): -# short_name = SamplerName.K_DPM_ADAPTIVE +# class DPMAdaptiveSampler(KDiffusionSolver): +# short_name = SolverName.K_DPM_ADAPTIVE # name = "Diffusion probabilistic models - adaptive" # default_steps = 40 # sampler_func = staticmethod(sample_dpm_adaptive) # # -# class DPM2Sampler(KDiffusionSampler): -# short_name = SamplerName.K_DPM_2 +# class DPM2Sampler(KDiffusionSolver): +# short_name = SolverName.K_DPM_2 # name = "Diffusion probabilistic models - 2" 
# default_steps = 40 # sampler_func = staticmethod(k_sampling.sample_dpm_2) # # -# class DPM2AncestralSampler(KDiffusionSampler): -# short_name = SamplerName.K_DPM_2_ANCESTRAL +# class DPM2AncestralSampler(KDiffusionSolver): +# short_name = SolverName.K_DPM_2_ANCESTRAL # name = "Diffusion probabilistic models - 2 ancestral" # default_steps = 40 # sampler_func = staticmethod(k_sampling.sample_dpm_2_ancestral) # -class DPMPP2MSampler(KDiffusionSampler): - short_name = SamplerName.K_DPMPP_2M +class DPMPP2MSampler(KDiffusionSolver): + short_name = SolverName.K_DPMPP_2M name = "Diffusion probabilistic models - 2m" default_steps = 15 sampler_func = staticmethod(k_sampling.sample_dpmpp_2m) # -# class DPMPP2SAncestralSampler(KDiffusionSampler): -# short_name = SamplerName.K_DPMPP_2S_ANCESTRAL +# class DPMPP2SAncestralSampler(KDiffusionSolver): +# short_name = SolverName.K_DPMPP_2S_ANCESTRAL # name = "Ancestral sampling with DPM-Solver++(2S) second-order steps." # default_steps = 15 # sampler_func = staticmethod(k_sampling.sample_dpmpp_2s_ancestral) # # -# class EulerSampler(KDiffusionSampler): -# short_name = SamplerName.K_EULER +# class EulerSampler(KDiffusionSolver): +# short_name = SolverName.K_EULER # name = "Algorithm 2 (Euler steps) from Karras et al. (2022)" # default_steps = 40 # sampler_func = staticmethod(k_sampling.sample_euler) # # -# class EulerAncestralSampler(KDiffusionSampler): -# short_name = SamplerName.K_EULER_ANCESTRAL +# class EulerAncestralSampler(KDiffusionSolver): +# short_name = SolverName.K_EULER_ANCESTRAL # name = "Euler ancestral" # default_steps = 40 # sampler_func = staticmethod(k_sampling.sample_euler_ancestral) # # -# class HeunSampler(KDiffusionSampler): -# short_name = SamplerName.K_HEUN +# class HeunSampler(KDiffusionSolver): +# short_name = SolverName.K_HEUN # name = "Algorithm 2 (Heun steps) from Karras et al. (2022)." # default_steps = 40 # sampler_func = staticmethod(k_sampling.sample_heun) # # -# class LMSSampler(KDiffusionSampler): -# short_name = SamplerName.K_LMS +# class LMSSampler(KDiffusionSolver): +# short_name = SolverName.K_LMS # name = "LMS" # default_steps = 40 # sampler_func = staticmethod(k_sampling.sample_lms) diff --git a/imaginairy/samplers/plms.py b/imaginairy/samplers/plms.py index acc0ad3b..6406e6ef 100644 --- a/imaginairy/samplers/plms.py +++ b/imaginairy/samplers/plms.py @@ -8,9 +8,9 @@ from imaginairy.log_utils import increment_step, log_latent from imaginairy.modules.diffusion.util import extract_into_tensor, noise_like from imaginairy.samplers.base import ( - ImageSampler, + ImageSolver, NoiseSchedule, - SamplerName, + SolverName, get_noise_prediction, mask_blend, ) @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) -class PLMSSampler(ImageSampler): +class PLMSSolver(ImageSolver): """ probabilistic least-mean-squares. 
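
# --- Illustrative sketch, not part of this patch: the renamed solver registry ---
# A minimal usage example of the Sampler -> Solver rename above: SOLVER_LOOKUP
# maps each solver's short_name to its class, mirroring the old SAMPLER_LOOKUP.
# With only DDIMSolver registered in SOLVERS, the options list is just ["ddim"].

from imaginairy.samplers import SOLVER_LOOKUP, SOLVER_TYPE_OPTIONS, SolverName

print(SOLVER_TYPE_OPTIONS)                   # ["ddim"] while only DDIMSolver is enabled
solver_cls = SOLVER_LOOKUP[SolverName.DDIM]  # SolverName.DDIM == "ddim"
print(solver_cls.name, solver_cls.default_steps)  # Denoising Diffusion Implicit Models, 50
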
@@ -29,7 +29,7 @@ class PLMSSampler(ImageSampler): https://github.com/luping-liu/PNDM """ - short_name = SamplerName.PLMS + short_name = SolverName.PLMS name = "probabilistic least-mean-squares sampler" default_steps = 40 diff --git a/imaginairy/schema.py b/imaginairy/schema.py index c871a698..cecb70a1 100644 --- a/imaginairy/schema.py +++ b/imaginairy/schema.py @@ -7,43 +7,32 @@ import os.path import random from datetime import datetime, timezone +from enum import Enum from io import BytesIO -from typing import TYPE_CHECKING, Any, List, Literal, Optional +from typing import TYPE_CHECKING, Any, List, Literal, cast from pydantic import ( BaseModel, + ConfigDict, Field, GetCoreSchemaHandler, field_validator, model_validator, ) from pydantic_core import core_schema +from typing_extensions import Self from imaginairy import config if TYPE_CHECKING: + from pathlib import Path + from PIL import Image -else: - Image = Any logger = logging.getLogger(__name__) -def save_image_as_base64(image: "Image.Image") -> str: - buffered = io.BytesIO() - image.save(buffered, format="PNG") - img_bytes = buffered.getvalue() - return base64.b64encode(img_bytes).decode() - - -def load_image_from_base64(image_str: str) -> "Image.Image": - from PIL import Image - - img_bytes = base64.b64decode(image_str) - return Image.open(io.BytesIO(img_bytes)) - - class InvalidUrlError(ValueError): pass @@ -52,7 +41,12 @@ class LazyLoadingImage: """Image file encoded as base64 string.""" def __init__( - self, *, filepath=None, url=None, img: Image = None, b64: Optional[str] = None + self, + *, + filepath: str | None = None, + url: str | None = None, + img: "Image.Image | None" = None, + b64: str | None = None, ): if not filepath and not url and not img and not b64: msg = "You must specify a url or filepath or img or base64 string" @@ -188,7 +182,7 @@ def load_image_from_base64(image_str: str) -> "Image.Image": def as_base64(self): self._load_img() - return self.save_image_as_base64(self._img) # type: ignore + return self.save_image_as_base64(self._img) def as_pillow(self): self._load_img() @@ -208,10 +202,10 @@ def __repr__(self): return f"" -class ControlNetInput(BaseModel): +class ControlInput(BaseModel): mode: str - image: Optional[LazyLoadingImage] = None - image_raw: Optional[LazyLoadingImage] = None + image: LazyLoadingImage | None = None + image_raw: LazyLoadingImage | None = None strength: float = Field(1, ge=0, le=1000) # @field_validator("image", "image_raw", mode="before") @@ -233,8 +227,8 @@ def image_raw_validate(cls, v, info: core_schema.FieldValidationInfo): @field_validator("mode") def mode_validate(cls, v): - if v not in config.CONTROLNET_CONFIG_SHORTCUTS: - valid_modes = list(config.CONTROLNET_CONFIG_SHORTCUTS.keys()) + if v not in config.CONTROL_CONFIG_SHORTCUTS: + valid_modes = list(config.CONTROL_CONFIG_SHORTCUTS.keys()) valid_modes = ", ".join(valid_modes) msg = f"Invalid controlnet mode: '{v}'. 
Valid modes are: {valid_modes}" raise ValueError(msg) @@ -249,43 +243,53 @@ def __repr__(self): return f"{self.weight}*({self.text})" +class MaskMode(str, Enum): + REPLACE = "replace" + KEEP = "keep" + + +MaskInput = MaskMode | str +PromptInput = str | WeightedPrompt | list[WeightedPrompt] | list[str] | None +InpaintMethod = Literal["finetune", "control"] + + class ImaginePrompt(BaseModel, protected_namespaces=()): - prompt: Optional[List[WeightedPrompt]] = Field(default=None, validate_default=True) - negative_prompt: Optional[List[WeightedPrompt]] = Field( - default=None, validate_default=True - ) - prompt_strength: Optional[float] = Field( - default=7.5, le=10_000, ge=-10_000, validate_default=True + model_config = ConfigDict(extra="forbid", validate_assignment=True) + + prompt: List[WeightedPrompt] = Field(default=None, validate_default=True) # type: ignore + negative_prompt: List[WeightedPrompt] = Field( + default_factory=list, validate_default=True ) - init_image: Optional[LazyLoadingImage] = Field( + prompt_strength: float = Field(default=7.5, le=50, ge=-50, validate_default=True) + init_image: LazyLoadingImage | None = Field( None, description="base64 encoded image", validate_default=True ) - init_image_strength: Optional[float] = Field( + init_image_strength: float | None = Field( ge=0, le=1, default=None, validate_default=True ) - control_inputs: List[ControlNetInput] = Field( + control_inputs: List[ControlInput] = Field( default_factory=list, validate_default=True ) - mask_prompt: Optional[str] = Field( + mask_prompt: str | None = Field( default=None, description="text description of the things to be masked", validate_default=True, ) - mask_image: Optional[LazyLoadingImage] = Field(default=None, validate_default=True) - mask_mode: Optional[Literal["keep", "replace"]] = "replace" + mask_image: LazyLoadingImage | None = Field(default=None, validate_default=True) + mask_mode: MaskMode = MaskMode.REPLACE mask_modify_original: bool = True - outpaint: Optional[str] = "" - model: str = Field(default=config.DEFAULT_MODEL, validate_default=True) - model_config_path: Optional[str] = None - sampler_type: str = Field(default=config.DEFAULT_SAMPLER, validate_default=True) - seed: Optional[int] = Field(default=None, validate_default=True) - steps: Optional[int] = Field(default=None, validate_default=True) - height: Optional[int] = Field(None, ge=1, le=100_000, validate_default=True) - width: Optional[int] = Field(None, ge=1, le=100_000, validate_default=True) + outpaint: str | None = "" + model_weights: config.ModelWeightsConfig = Field( # type: ignore + default=config.DEFAULT_MODEL_WEIGHTS, validate_default=True + ) + solver_type: str = Field(default=config.DEFAULT_SOLVER, validate_default=True) + seed: int | None = Field(default=None, validate_default=True) + steps: int = Field(validate_default=True) + size: tuple[int, int] = Field(validate_default=True) upscale: bool = False fix_faces: bool = False - fix_faces_fidelity: Optional[float] = Field(0.2, ge=0, le=1, validate_default=True) - conditioning: Optional[str] = None + fix_faces_fidelity: float | None = Field(0.2, ge=0, le=1, validate_default=True) + conditioning: str | None = None tile_mode: str = "" allow_compose_phase: bool = True is_intermediate: bool = False @@ -293,23 +297,89 @@ class ImaginePrompt(BaseModel, protected_namespaces=()): caption_text: str = Field( "", description="text to be overlaid on the image", validate_default=True ) + inpaint_method: InpaintMethod = "finetune" - class MaskMode: - REPLACE = "replace" - KEEP = 
"keep" - - def __init__(self, prompt=None, **kwargs): - # allows `prompt` to be positional - super().__init__(prompt=prompt, **kwargs) + def __init__( + self, + prompt: PromptInput = "", + *, + negative_prompt: PromptInput = None, + prompt_strength: float | None = 7.5, + init_image: LazyLoadingImage | None = None, + init_image_strength: float | None = None, + control_inputs: List[ControlInput] | None = None, + mask_prompt: str | None = None, + mask_image: LazyLoadingImage | None = None, + mask_mode: MaskInput = MaskMode.REPLACE, + mask_modify_original: bool = True, + outpaint: str | None = "", + model_weights: str | config.ModelWeightsConfig = config.DEFAULT_MODEL_WEIGHTS, + solver_type: str = config.DEFAULT_SOLVER, + seed: int | None = None, + steps: int | None = None, + size: int | str | tuple[int, int] | None = None, + upscale: bool = False, + fix_faces: bool = False, + fix_faces_fidelity: float | None = 0.2, + conditioning: str | None = None, + tile_mode: str = "", + allow_compose_phase: bool = True, + is_intermediate: bool = False, + collect_progress_latents: bool = False, + caption_text: str = "", + inpaint_method: InpaintMethod = "finetune", + ): + super().__init__( + prompt=prompt, + negative_prompt=negative_prompt, + prompt_strength=prompt_strength, + init_image=init_image, + init_image_strength=init_image_strength, + control_inputs=control_inputs, + mask_prompt=mask_prompt, + mask_image=mask_image, + mask_mode=mask_mode, + mask_modify_original=mask_modify_original, + outpaint=outpaint, + model_weights=model_weights, + solver_type=solver_type, + seed=seed, + steps=steps, + size=size, + upscale=upscale, + fix_faces=fix_faces, + fix_faces_fidelity=fix_faces_fidelity, + conditioning=conditioning, + tile_mode=tile_mode, + allow_compose_phase=allow_compose_phase, + is_intermediate=is_intermediate, + collect_progress_latents=collect_progress_latents, + caption_text=caption_text, + inpaint_method=inpaint_method, + ) @field_validator("prompt", "negative_prompt", mode="before") - @classmethod - def make_into_weighted_prompts(cls, v): - if isinstance(v, str): - v = [WeightedPrompt(text=v)] - elif isinstance(v, WeightedPrompt): - v = [v] - return v + def make_into_weighted_prompts( + cls, + value: PromptInput, + ) -> list[WeightedPrompt]: + match value: + case None: + return [] + + case str(): + if value is not None: + return [WeightedPrompt(text=value)] + else: + return [] + case WeightedPrompt(): + return [value] + case list(): + if all(isinstance(item, str) for item in value): + return [WeightedPrompt(text=str(p)) for p in value] + elif all(isinstance(item, WeightedPrompt) for item in value): + return cast(List[WeightedPrompt], value) + raise ValueError("Invalid prompt input") @field_validator("prompt", "negative_prompt", mode="after") @classmethod @@ -328,19 +398,17 @@ def sort_prompts(cls, v): @model_validator(mode="after") def validate_negative_prompt(self): - if self.negative_prompt is None: - model_config = config.MODEL_CONFIG_SHORTCUTS.get(self.model, None) - if model_config: - self.negative_prompt = [ - WeightedPrompt(text=model_config.default_negative_prompt) - ] - else: - self.negative_prompt = [ - WeightedPrompt(text=config.DEFAULT_NEGATIVE_PROMPT) - ] + if self.negative_prompt == []: + default_negative_prompt = config.DEFAULT_NEGATIVE_PROMPT + if self.model_weights: + default_negative_prompt = self.model_weights.defaults.get( + "negative_prompt", default_negative_prompt + ) + + self.negative_prompt = [WeightedPrompt(text=default_negative_prompt)] return self - 
@field_validator("prompt_strength") + @field_validator("prompt_strength", mode="before") def validate_prompt_strength(cls, v): return 7.5 if v is None else v @@ -426,12 +494,30 @@ def validate_mask_prompt(cls, v, info: core_schema.FieldValidationInfo): raise ValueError(msg) return v - @field_validator("model", mode="before") - def set_default_diffusion_model(cls, v): - if v is None: - return config.DEFAULT_MODEL + @model_validator(mode="before") + def resolve_model_weights(cls, data: Any): + if not isinstance(data, dict): + return data - return v + model_weights = data.get("model_weights") + if model_weights is None: + model_weights = config.DEFAULT_MODEL_WEIGHTS + from imaginairy.model_manager import resolve_model_weights_config + + should_use_inpainting = bool( + data.get("mask_image") or data.get("mask_prompt") or data.get("outpaint") + ) + should_use_inpainting_weights = ( + should_use_inpainting and data.get("inpaint_method") == "finetune" + ) + model_weights_config = resolve_model_weights_config( + model_weights=model_weights, + default_model_architecture=None, + for_inpainting=should_use_inpainting_weights, + ) + data["model_weights"] = model_weights_config + + return data @field_validator("seed") def validate_seed(cls, v): @@ -444,35 +530,37 @@ def validate_fix_faces_fidelity(cls, v): return v - @field_validator("sampler_type", mode="after") - def validate_sampler_type(cls, v, info: core_schema.FieldValidationInfo): - from imaginairy.samplers import SamplerName + @field_validator("solver_type", mode="after") + def validate_solver_type(cls, v, info: core_schema.FieldValidationInfo): + from imaginairy.samplers import SolverName if v is None: - v = config.DEFAULT_SAMPLER + v = config.DEFAULT_SOLVER v = v.lower() - if info.data.get("model") == "SD-2.0-v" and v == SamplerName.PLMS: - raise ValueError("PLMS sampler is not supported for SD-2.0-v model.") + if info.data.get("model") == "SD-2.0-v" and v == SolverName.PLMS: + raise ValueError("PLMS solvers is not supported for SD-2.0-v model.") if info.data.get("model") == "edit" and v in ( - SamplerName.PLMS, - SamplerName.DDIM, + SolverName.PLMS, + SolverName.DDIM, ): - msg = "PLMS and DDIM samplers are not supported for pix2pix edit model." + msg = "PLMS and DDIM solvers are not supported for pix2pix edit model." 
raise ValueError(msg) return v - @field_validator("steps") + @field_validator("steps", mode="before") def validate_steps(cls, v, info: core_schema.FieldValidationInfo): - from imaginairy.samplers import SAMPLER_LOOKUP + steps_lookup = {"ddim": 50, "dpmpp": 20} if v is None: - SamplerCls = SAMPLER_LOOKUP[info.data["sampler_type"]] - v = SamplerCls.default_steps + v = steps_lookup[info.data["solver_type"]] - return int(v) + try: + return int(v) + except (OverflowError, TypeError) as e: + raise ValueError("Steps must be an integer") from e @model_validator(mode="after") def validate_init_image_strength(self): @@ -486,13 +574,30 @@ def validate_init_image_strength(self): return self - @field_validator("height", "width") + @field_validator("size", mode="before") def validate_image_size(cls, v, info: core_schema.FieldValidationInfo): from imaginairy.model_manager import get_model_default_image_size + from imaginairy.utils.named_resolutions import normalize_image_size if v is None: - v = get_model_default_image_size(info.data["model"]) + v = get_model_default_image_size(info.data["model_weights"].architecture) + + width, height = normalize_image_size(v) + + return width, height + @field_validator("size", mode="after") + def validate_image_size_after(cls, v, info: core_schema.FieldValidationInfo): + width, height = v + min_size = 8 + max_size = 100_000 + if not min_size <= width <= max_size: + msg = f"Width must be between {min_size} and {max_size}. Got: {width}" + raise ValueError(msg) + + if not min_size <= height <= max_size: + msg = f"Height must be between {min_size} and {max_size}. Got: {height}" + raise ValueError(msg) return v @field_validator("caption_text", mode="before") @@ -507,7 +612,7 @@ def prompts(self): return self.prompt @property - def prompt_text(self): + def prompt_text(self) -> str: if not self.prompt: return "" if len(self.prompt) == 1: @@ -515,18 +620,43 @@ def prompt_text(self): return "|".join(str(p) for p in self.prompt) @property - def negative_prompt_text(self): + def negative_prompt_text(self) -> str: if not self.negative_prompt: return "" if len(self.negative_prompt) == 1: return self.negative_prompt[0].text return "|".join(str(p) for p in self.negative_prompt) + @property + def width(self) -> int: + return self.size[0] + + @property + def height(self) -> int: + return self.size[1] + + @property + def should_use_inpainting(self) -> bool: + return bool(self.outpaint or self.mask_image or self.mask_prompt) + + @property + def should_use_inpainting_weights(self) -> bool: + return self.should_use_inpainting and self.inpaint_method == "finetune" + + @property + def model_architecture(self) -> config.ModelArchitecture: + return self.model_weights.architecture + def prompt_description(self): return ( f'"{self.prompt_text}" {self.width}x{self.height}px ' f'negative-prompt:"{self.negative_prompt_text}" ' - f"seed:{self.seed} prompt-strength:{self.prompt_strength} steps:{self.steps} sampler-type:{self.sampler_type} init-image-strength:{self.init_image_strength} model:{self.model}" + f"seed:{self.seed} " + f"prompt-strength:{self.prompt_strength} " + f"steps:{self.steps} solver-type:{self.solver_type} " + f"init-image-strength:{self.init_image_strength} " + f"arch:{self.model_architecture.aliases[0]} " + f"weights: {self.model_weights.aliases[0]}" ) def logging_dict(self): @@ -547,7 +677,7 @@ def full_copy(self, deep=True, update=None): new_prompt = new_prompt.model_validate(dict(new_prompt)) return new_prompt - def make_concrete_copy(self): + def make_concrete_copy(self) -> 
Self: seed = self.seed if self.seed is not None else random.randint(1, 1_000_000_000) return self.full_copy( deep=False, @@ -574,10 +704,6 @@ def __init__( prompt: ImaginePrompt, is_nsfw, safety_score, - upscaled_img=None, - modified_original=None, - mask_binary=None, - mask_grayscale=None, result_images=None, timings=None, progress_latents=None, @@ -594,20 +720,10 @@ def __init__( self.images = {"generated": img} - if upscaled_img: - self.images["upscaled"] = upscaled_img - - if modified_original: - self.images["modified_original"] = modified_original - - if mask_binary: - self.images["mask_binary"] = mask_binary - - if mask_grayscale: - self.images["mask_grayscale"] = mask_grayscale - if result_images: for img_type, r_img in result_images.items(): + if r_img is None: + continue if isinstance(r_img, torch.Tensor): if r_img.shape[1] == 4: r_img = model_latent_to_pillow_img(r_img) @@ -620,7 +736,6 @@ def __init__( # for backward compat self.img = img - self.upscaled_img = upscaled_img self.is_nsfw = is_nsfw self.safety_score = safety_score @@ -628,7 +743,7 @@ def __init__( self.torch_backend = get_device() self.hardware_name = get_hardware_description(get_device()) - def md5(self): + def md5(self) -> str: return hashlib.md5(self.img.tobytes()).hexdigest() def metadata_dict(self): @@ -636,27 +751,27 @@ def metadata_dict(self): "prompt": self.prompt.logging_dict(), } - def timings_str(self): + def timings_str(self) -> str: if not self.timings: return "" return " ".join(f"{k}:{v:.2f}s" for k, v in self.timings.items()) - def _exif(self): + def _exif(self) -> "Image.Exif": from PIL import Image exif = Image.Exif() exif[ExifCodes.ImageDescription] = self.prompt.prompt_description() exif[ExifCodes.UserComment] = json.dumps(self.metadata_dict()) # help future web scrapes not ingest AI generated art - sd_version = self.prompt.model - if len(sd_version) > 20: + sd_version = self.prompt.model_weights.name + if len(sd_version) > 40: sd_version = "custom weights" exif[ExifCodes.Software] = f"Imaginairy / Stable Diffusion {sd_version}" exif[ExifCodes.DateTime] = self.created_at.isoformat(sep=" ")[:19] exif[ExifCodes.HostComputer] = f"{self.torch_backend}:{self.hardware_name}" return exif - def save(self, save_path, image_type="generated"): + def save(self, save_path: "Path | str", image_type: str = "generated") -> None: img = self.images.get(image_type, None) if img is None: msg = f"Image of type {image_type} not stored. Options are: {self.images.keys()}" @@ -665,6 +780,6 @@ def save(self, save_path, image_type="generated"): img.convert("RGB").save(save_path, exif=self._exif()) -class SafetyMode: +class SafetyMode(str, Enum): STRICT = "strict" RELAXED = "relaxed" diff --git a/imaginairy/surprise_me.py b/imaginairy/surprise_me.py index b82fee79..71b4a893 100644 --- a/imaginairy/surprise_me.py +++ b/imaginairy/surprise_me.py @@ -1,63 +1,71 @@ -""" - - -aimg. 
-""" - +import logging import os.path -from imaginairy import ImaginePrompt, LazyLoadingImage, imagine_image_files from imaginairy.animations import make_gif_animation +from imaginairy.api import imagine_image_files from imaginairy.enhancers.facecrop import detect_faces from imaginairy.img_utils import add_caption_to_image, pillow_fit_image_within -from imaginairy.schema import ControlNetInput +from imaginairy.schema import ControlInput, ImaginePrompt, LazyLoadingImage + +logger = logging.getLogger(__name__) preserve_head_kwargs = { "mask_prompt": "head|face", "mask_mode": "keep", } +preserve_face_kwargs = { + "mask_prompt": "face", + "mask_mode": "keep", +} + generic_prompts = [ - ("add confetti", 6, {}), + ("add confetti", 15, {}), # ("add sparkles", 14, {}), - ("make it christmas", 15, preserve_head_kwargs), - ("make it halloween", 15, {}), - ("give it a dark omninous vibe", 15, {}), - ("give it a bright cheery vibe", 15, {}), + ("make it christmas", 15, preserve_face_kwargs), + ("make it halloween", 15, preserve_face_kwargs), + ("give it a depressing vibe", 10, {}), + ("give it a bright cheery vibe", 10, {}), # weather - ("make it look like a snowstorm", 20, {}), - ("make it midnight", 15, {}), - ("make it a sunny day", 15, {}), - ("add misty fog", 15, {}), + ("make it look like a snowstorm", 15, preserve_face_kwargs), + ("make it sunset", 15, preserve_face_kwargs), + ("make it a sunny day", 15, preserve_face_kwargs), + ("add misty fog", 10, {}), ("make it flooded", 10, {}), # setting - ("make it underwater", 15, {}), + ("make it underwater", 10, {}), ("add fireworks to the sky", 15, {}), # ("make it in a forest", 10, {}), # ("make it grassy", 11, {}), ("make it on mars", 14, {}), # style ("add glitter", 10, {}), - ("turn it into a still from a western", 15, {}), + ("turn it into a still from a western", 10, {}), ("old 1900s photo", 11.5, {}), - ("Daguerreotype", 12, {}), - ("make it anime style", 18, {}), + ("Daguerreotype", 14, {}), + ("make it anime style", 15, {}), + ("watercolor painting", 10, {}), + ("crayon drawing", 10, {}), # ("make it pen and ink style", 20, {}), - ("graphite pencil", 15, {}), - # ("make it a thomas kinkade painting", 20, {}), - ("make it pixar style", 20, {}), + ("graphite pencil", 10, {"negative_prompt": "low quality"}), + ("make it a thomas kinkade painting", 10, {}), + ("make it pixar style", 18, {}), ("low-poly", 20, {}), - ("make it stained glass", 10, {}), - ("make it pop art", 12, {}), - # ("make it street graffiti", 15, {}), + ("make it stained glass", 15, {}), + ("make it pop art", 15, {}), + ("oil painting", 11, {}), + ("street graffiti", 10, {}), + ("photorealistic", 8, {}), ("vector art", 8, {}), ("comic book style. happy", 9, {}), ("starry night painting", 15, {}), ("make it minecraft", 12, {}), # materials ("make it look like a marble statue", 15, {}), + ("marble statue", 15, {}), ("make it look like a golden statue", 15, {}), - # ("make it claymation", 8, {}), + ("golden statue", 15, {}), + # ("make it claymation", 15, {}), ("play-doh", 15, {}), ("voxel", 15, {}), # ("lego", 15, {}), @@ -80,22 +88,26 @@ person_prompt_configs = [ # face - ("make the person close their eyes", 10, only_face_kwargs), - # ( - # "make the person wear intricate highly detailed facepaint. ornate, artistic", - # 9, - # only_face_kwargs, - # ), - # ("make the person wear makeup. professional photoshoot", 8, only_face_kwargs), + ("make the person close their eyes", 7, only_face_kwargs), + ( + "make the person wear intricate highly detailed facepaint. 
ornate, artistic", + 6, + only_face_kwargs, + ), + # ("make the person wear makeup. professional photoshoot", 15, only_face_kwargs), # ("make the person wear mime makeup. intricate, artistic", 7, only_face_kwargs), - # ("make the person wear clown makeup. intricate, artistic", 6, only_face_kwargs), + ("make the person wear clown makeup. intricate, artistic", 7, only_face_kwargs), ("make the person a cyborg", 14, {}), # clothes - ("make the person wear shiny metal clothes", 14, preserve_head_kwargs), - ("make the person wear a tie-dye shirt", 7.5, preserve_head_kwargs), - ("make the person wear a suit", 7.5, preserve_head_kwargs), - # ("make the person bald", 7.5, {}), - ("change the hair to pink", 7.5, {"mask_mode": "replace", "mask_prompt": "hair"}), + ("make the person wear shiny metal clothes", 14, preserve_face_kwargs), + ("make the person wear a tie-dye shirt", 14, preserve_face_kwargs), + ("make the person wear a suit", 14, preserve_face_kwargs), + ("make the person bald", 15, {}), + ( + "change the hair to pink", + 7.5, + {"mask_mode": "keep", "mask_prompt": "face"}, + ), # ("change the hair to black", 7.5, {"mask_mode": "replace", "mask_prompt": "hair"}), # ("change the hair to blonde", 7.5, {"mask_mode": "replace", "mask_prompt": "hair"}), # ("change the hair to red", 7.5, {"mask_mode": "replace", "mask_prompt": "hair"}), @@ -107,22 +119,22 @@ # ("change the hair to silver", 7.5, {"mask_mode": "replace", "mask_prompt": "hair"}), ( "professional corporate photo headshot. Canon EOS, sharp focus, high resolution", - 10, - {"negative_prompt": "old, ugly"}, + 6, + {"negative_prompt": "low quality"}, ), - # ("make the person stoic. pensive", 10, only_face_kwargs), - # ("make the person sad", 20, {}), - # ("make the person angry", 20, {}), - # ("make the person look like a celebrity", 10, {}), - ("make the person younger", 11, {}), - ("make the person 70 years old", 9, {}), - ("make the person a disney cartoon character", 7.5, {}), + ("make the person stoic. pensive", 7, only_face_kwargs), + ("make the person sad", 7, only_face_kwargs), + ("make the person angry", 7, only_face_kwargs), + ("make the person look like a celebrity", 10, {}), + ("make the person younger", 7, {}), + ("make the person 70 years old", 10, {}), + ("make the person a disney cartoon character", 9, {}), ("turn the humans into robots", 13, {}), - ("make the person darth vader", 15, {}), - ("make the person a starfleet officer", 15, preserve_head_kwargs), - ("make the person a superhero", 15, {}), - ("make the person a tiger", 15, only_face_kwargs), - # ("lego minifig", 15, {}), + ("make the person a jedi knight. star wars", 15, preserve_head_kwargs), + ("make the person a starfleet officer. star trek", 15, preserve_head_kwargs), + ("make the person a superhero", 15, preserve_head_kwargs), + # ("a tiger", 15, only_face_kwargs), + ("lego minifig", 15, {}), ] @@ -138,22 +150,21 @@ def surprise_me_prompts( if person is None: person = bool(detect_faces(img)) prompts = [] - + logger.info("Person detected in photo. 
Adjusting edits accordingly.") + init_image_strength = 0.3 for prompt_text, strength, kwargs in generic_prompts: + kwargs.setdefault("negative_prompt", None) + kwargs.setdefault("init_image_strength", init_image_strength) if use_controlnet: - strength = 5 - control_input = ControlNetInput(mode="edit", strength=2) + control_input = ControlInput(mode="edit") prompts.append( ImaginePrompt( prompt_text, - negative_prompt="", init_image=img, - init_image_strength=0.3, prompt_strength=strength, control_inputs=[control_input], steps=steps, - width=width, - height=height, + size=(width, height), **kwargs, ) ) @@ -163,10 +174,9 @@ def surprise_me_prompts( prompt_text, init_image=img, prompt_strength=strength, - model="edit", + model_weights="edit", steps=steps, - width=width, - height=height, + size=(width, height), **kwargs, ) ) @@ -178,19 +188,19 @@ def surprise_me_prompts( for prompt_subconfig in prompt_subconfigs: prompt_text, strength, kwargs = prompt_subconfig if use_controlnet: - control_input = ControlNetInput( + control_input = ControlInput( mode="edit", ) + kwargs.setdefault("negative_prompt", None) + kwargs.setdefault("init_image_strength", init_image_strength) prompts.append( ImaginePrompt( prompt_text, init_image=img, - init_image_strength=0.05, prompt_strength=strength, control_inputs=[control_input], steps=steps, - width=width, - height=height, + size=(width, height), seed=seed, **kwargs, ) @@ -201,10 +211,9 @@ def surprise_me_prompts( prompt_text, init_image=img, prompt_strength=strength, - model="edit", + model_weights="edit", steps=steps, - width=width, - height=height, + size=(width, height), seed=seed, **kwargs, ) @@ -247,7 +256,7 @@ def create_surprise_me_images( gif_imgs.append(gen_img) - make_gif_animation(outpath=new_filename, imgs=gif_imgs) + make_gif_animation(outpath=new_filename, imgs=gif_imgs, frame_duration_ms=1000) if __name__ == "__main__": diff --git a/imaginairy/train.py b/imaginairy/train.py new file mode 100644 index 00000000..41f42120 --- /dev/null +++ b/imaginairy/train.py @@ -0,0 +1,533 @@ +import datetime +import logging +import os +import signal +import time +from functools import partial + +import numpy as np +import pytorch_lightning as pl +import torch +import torchvision +from omegaconf import OmegaConf +from PIL import Image +from pytorch_lightning import seed_everything +from pytorch_lightning.callbacks import Callback, LearningRateMonitor + +try: + from pytorch_lightning.strategies import DDPStrategy +except ImportError: + # let's not break all of imaginairy just because a training import doesn't exist in an older version of PL + # Use >= 1.6.0 to make this work + DDPStrategy = None # type: ignore +import contextlib + +from pytorch_lightning.trainer import Trainer +from pytorch_lightning.utilities import rank_zero_info +from pytorch_lightning.utilities.distributed import rank_zero_only +from torch.utils.data import DataLoader, Dataset + +from imaginairy import config +from imaginairy.model_manager import get_diffusion_model +from imaginairy.training_tools.single_concept import SingleConceptDataset +from imaginairy.utils import get_device, instantiate_from_config + +mod_logger = logging.getLogger(__name__) + +referenced_by_string = [LearningRateMonitor] + + +class WrappedDataset(Dataset): + """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset.""" + + def __init__(self, dataset): + self.data = dataset + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + +def 
worker_init_fn(_): + worker_info = torch.utils.data.get_worker_info() + + dataset = worker_info.dataset + worker_id = worker_info.id + + if isinstance(dataset, SingleConceptDataset): + # split_size = dataset.num_records // worker_info.num_workers + # reset num_records to the true number to retain reliable length information + # dataset.sample_ids = dataset.valid_ids[ + # worker_id * split_size : (worker_id + 1) * split_size + # ] + current_id = np.random.choice(len(np.random.get_state()[1]), 1) + return np.random.seed(np.random.get_state()[1][current_id] + worker_id) + + return np.random.seed(np.random.get_state()[1][0] + worker_id) + + +class DataModuleFromConfig(pl.LightningDataModule): + def __init__( + self, + batch_size, + train=None, + validation=None, + test=None, + predict=None, + wrap=False, + num_workers=None, + shuffle_test_loader=False, + use_worker_init_fn=False, + shuffle_val_dataloader=False, + num_val_workers=0, + ): + super().__init__() + self.batch_size = batch_size + self.dataset_configs = {} + self.num_workers = num_workers if num_workers is not None else batch_size * 2 + if num_val_workers is None: + self.num_val_workers = self.num_workers + else: + self.num_val_workers = num_val_workers + self.use_worker_init_fn = use_worker_init_fn + if train is not None: + self.dataset_configs["train"] = train + self.train_dataloader = self._train_dataloader + if validation is not None: + self.dataset_configs["validation"] = validation + self.val_dataloader = partial( + self._val_dataloader, shuffle=shuffle_val_dataloader + ) + if test is not None: + self.dataset_configs["test"] = test + self.test_dataloader = partial( + self._test_dataloader, shuffle=shuffle_test_loader + ) + if predict is not None: + self.dataset_configs["predict"] = predict + self.predict_dataloader = self._predict_dataloader + self.wrap = wrap + self.datasets = None + + def prepare_data(self): + for data_cfg in self.dataset_configs.values(): + instantiate_from_config(data_cfg) + + def setup(self, stage=None): + self.datasets = { + k: instantiate_from_config(c) for k, c in self.dataset_configs.items() + } + if self.wrap: + self.datasets = {k: WrappedDataset(v) for k, v in self.datasets.items()} + + def _train_dataloader(self): + is_iterable_dataset = isinstance(self.datasets["train"], SingleConceptDataset) + if is_iterable_dataset or self.use_worker_init_fn: + pass + else: + pass + return DataLoader( + self.datasets["train"], + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=True, + worker_init_fn=worker_init_fn, + ) + + def _val_dataloader(self, shuffle=False): + if ( + isinstance(self.datasets["validation"], SingleConceptDataset) + or self.use_worker_init_fn + ): + init_fn = worker_init_fn + else: + init_fn = None + return DataLoader( + self.datasets["validation"], + batch_size=self.batch_size, + num_workers=self.num_val_workers, + worker_init_fn=init_fn, + shuffle=shuffle, + ) + + def _test_dataloader(self, shuffle=False): + is_iterable_dataset = isinstance(self.datasets["train"], SingleConceptDataset) + if is_iterable_dataset or self.use_worker_init_fn: + init_fn = worker_init_fn + else: + init_fn = None + is_iterable_dataset = False + + # do not shuffle dataloader for iterable dataset + shuffle = shuffle and (not is_iterable_dataset) + + return DataLoader( + self.datasets["test"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + shuffle=shuffle, + ) + + def _predict_dataloader(self, shuffle=False): + if ( + isinstance(self.datasets["predict"], 
SingleConceptDataset) + or self.use_worker_init_fn + ): + init_fn = worker_init_fn + else: + init_fn = None + return DataLoader( + self.datasets["predict"], + batch_size=self.batch_size, + num_workers=self.num_workers, + worker_init_fn=init_fn, + ) + + +class SetupCallback(Callback): + def __init__( + self, + resume, + now, + logdir, + ckptdir, + cfgdir, + ): + super().__init__() + self.resume = resume + self.now = now + self.logdir = logdir + self.ckptdir = ckptdir + self.cfgdir = cfgdir + + def on_keyboard_interrupt(self, trainer, pl_module): + if trainer.global_rank == 0: + mod_logger.info("Stopping execution and saving final checkpoint.") + ckpt_path = os.path.join(self.ckptdir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + def on_fit_start(self, trainer, pl_module): + if trainer.global_rank == 0: + # Create logdirs and save configs + os.makedirs(self.logdir, exist_ok=True) + os.makedirs(self.ckptdir, exist_ok=True) + os.makedirs(self.cfgdir, exist_ok=True) + + else: + # ModelCheckpoint callback created log directory --- remove it + if not self.resume and os.path.exists(self.logdir): + dst, name = os.path.split(self.logdir) + dst = os.path.join(dst, "child_runs", name) + os.makedirs(os.path.split(dst)[0], exist_ok=True) + with contextlib.suppress(FileNotFoundError): + os.rename(self.logdir, dst) + + +class ImageLogger(Callback): + def __init__( + self, + batch_frequency, + max_images, + clamp=True, + increase_log_steps=True, + rescale=True, + disabled=False, + log_on_batch_idx=False, + log_first_step=False, + log_images_kwargs=None, + log_all_val=False, + concept_label=None, + ): + super().__init__() + self.rescale = rescale + self.batch_freq = batch_frequency + self.max_images = max_images + self.logger_log_images = {} + self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)] + if not increase_log_steps: + self.log_steps = [self.batch_freq] + self.clamp = clamp + self.disabled = disabled + self.log_on_batch_idx = log_on_batch_idx + self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} + self.log_first_step = log_first_step + self.log_all_val = log_all_val + self.concept_label = concept_label + + @rank_zero_only + def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx): + root = os.path.join(save_dir, "logs", "images", split) + for k in images: + grid = torchvision.utils.make_grid(images[k], nrow=4) + if self.rescale: + grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w + grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) + grid = grid.numpy() + grid = (grid * 255).astype(np.uint8) + filename = ( + f"{k}_gs-{global_step:06}_e-{current_epoch:06}_b-{batch_idx:06}.png" + ) + path = os.path.join(root, filename) + os.makedirs(os.path.split(path)[0], exist_ok=True) + Image.fromarray(grid).save(path) + + def log_img(self, pl_module, batch, batch_idx, split="train"): + # always generate the concept label + batch["txt"][0] = self.concept_label + + check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step + if self.log_all_val and split == "val": + should_log = True + else: + should_log = self.check_frequency(check_idx) + if ( + should_log + and (batch_idx % self.batch_freq == 0) + and hasattr(pl_module, "log_images") + and callable(pl_module.log_images) + and self.max_images > 0 + ): + logger = type(pl_module.logger) + + is_train = pl_module.training + if is_train: + pl_module.eval() + + with torch.no_grad(): + images = pl_module.log_images( + batch, split=split, **self.log_images_kwargs + ) + + for k in 
images: + N = min(images[k].shape[0], self.max_images) + images[k] = images[k][:N] + if isinstance(images[k], torch.Tensor): + images[k] = images[k].detach().cpu() + if self.clamp: + images[k] = torch.clamp(images[k], -1.0, 1.0) + + self.log_local( + pl_module.logger.save_dir, + split, + images, + pl_module.global_step, + pl_module.current_epoch, + batch_idx, + ) + + logger_log_images = self.logger_log_images.get( + logger, lambda *args, **kwargs: None + ) + logger_log_images(pl_module, images, pl_module.global_step, split) + + if is_train: + pl_module.train() + + def check_frequency(self, check_idx): + if (check_idx % self.batch_freq) == 0 and ( + check_idx > 0 or self.log_first_step + ): + return True + return False + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if not self.disabled and (pl_module.global_step > 0 or self.log_first_step): + self.log_img(pl_module, batch, batch_idx, split="train") + + def on_validation_batch_end( + self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx + ): + if not self.disabled and pl_module.global_step > 0: + self.log_img(pl_module, batch, batch_idx, split="val") + if ( + hasattr(pl_module, "calibrate_grad_norm") + and (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) + and batch_idx > 0 + ): + self.log_gradients(trainer, pl_module, batch_idx=batch_idx) + + +class CUDACallback(Callback): + # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py + def on_train_epoch_start(self, trainer, pl_module): + # Reset the memory use counter + if "cuda" in get_device(): + torch.cuda.reset_peak_memory_stats(trainer.strategy.root_device.index) + torch.cuda.synchronize(trainer.strategy.root_device.index) + self.start_time = time.time() + + def on_train_epoch_end(self, trainer, pl_module): + if "cuda" in get_device(): + torch.cuda.synchronize(trainer.strategy.root_device.index) + max_memory = ( + torch.cuda.max_memory_allocated(trainer.strategy.root_device.index) + / 2**20 + ) + epoch_time = time.time() - self.start_time + + try: + max_memory = trainer.training_type_plugin.reduce(max_memory) + epoch_time = trainer.training_type_plugin.reduce(epoch_time) + + rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") + rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB") + except AttributeError: + pass + + +def train_diffusion_model( + concept_label, + concept_images_dir, + class_label, + class_images_dir, + weights_location=config.DEFAULT_MODEL_WEIGHTS, + logdir="logs", + learning_rate=1e-6, + accumulate_grad_batches=32, + resume=None, +): + """ + Train a diffusion model on a single concept. 
+ + accumulate_grad_batches used to simulate a bigger batch size - https://arxiv.org/pdf/1711.00489.pdf + """ + if DDPStrategy is None: + msg = "Please install pytorch-lightning>=1.6.0 to train a model" + raise ImportError(msg) + + batch_size = 1 + seed = 23 + num_workers = 1 + num_val_workers = 0 + now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") # noqa: DTZ005 + logdir = os.path.join(logdir, now) + + ckpt_output_dir = os.path.join(logdir, "checkpoints") + cfg_output_dir = os.path.join(logdir, "configs") + seed_everything(seed) + model = get_diffusion_model( + weights_location=weights_location, half_mode=False, for_training=True + )._model + model.learning_rate = learning_rate * accumulate_grad_batches * batch_size + + # add callback which sets up log directory + default_callbacks_cfg = { + "setup_callback": { + "target": "imaginairy.train.SetupCallback", + "params": { + "resume": False, + "now": now, + "logdir": logdir, + "ckptdir": ckpt_output_dir, + "cfgdir": cfg_output_dir, + }, + }, + "image_logger": { + "target": "imaginairy.train.ImageLogger", + "params": { + "batch_frequency": 10, + "max_images": 1, + "clamp": True, + "increase_log_steps": False, + "log_first_step": True, + "log_all_val": True, + "concept_label": concept_label, + "log_images_kwargs": { + "use_ema_scope": True, + "inpaint": False, + "plot_progressive_rows": False, + "plot_diffusion_rows": False, + "N": 1, + "unconditional_guidance_scale:": 7.5, + "unconditional_guidance_label": [""], + "ddim_steps": 20, + }, + }, + }, + "learning_rate_logger": { + "target": "imaginairy.train.LearningRateMonitor", + "params": { + "logging_interval": "step", + # "log_momentum": True + }, + }, + "cuda_callback": {"target": "imaginairy.train.CUDACallback"}, + } + + default_modelckpt_cfg = { + "target": "pytorch_lightning.callbacks.ModelCheckpoint", + "params": { + "dirpath": ckpt_output_dir, + "filename": "{epoch:06}", + "verbose": True, + "save_last": True, + "every_n_train_steps": 50, + "save_top_k": -1, + "monitor": None, + }, + } + + modelckpt_cfg = OmegaConf.create(default_modelckpt_cfg) + default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg}) + + callbacks_cfg = OmegaConf.create(default_callbacks_cfg) + + dataset_config = { + "concept_label": concept_label, + "concept_images_dir": concept_images_dir, + "class_label": class_label, + "class_images_dir": class_images_dir, + "image_transforms": [ + { + "target": "torchvision.transforms.Resize", + "params": {"size": 512, "interpolation": 3}, + }, + {"target": "torchvision.transforms.RandomCrop", "params": {"size": 512}}, + ], + } + + data_module_config = { + "batch_size": batch_size, + "num_workers": num_workers, + "num_val_workers": num_val_workers, + "train": { + "target": "imaginairy.training_tools.single_concept.SingleConceptDataset", + "params": dataset_config, + }, + } + trainer = Trainer( + benchmark=True, + num_sanity_val_steps=0, + accumulate_grad_batches=accumulate_grad_batches, + strategy=DDPStrategy(), + callbacks=[instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg], + gpus=1, + default_root_dir=".", + ) + trainer.logdir = logdir + + data = DataModuleFromConfig(**data_module_config) + # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html + # calling these ourselves should not be necessary but it is. 
+ # lightning still takes care of proper multiprocessing though + data.prepare_data() + data.setup() + + def melk(*args, **kwargs): + if trainer.global_rank == 0: + mod_logger.info("Summoning checkpoint.") + ckpt_path = os.path.join(ckpt_output_dir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + signal.signal(signal.SIGUSR1, melk) + try: + try: + trainer.fit(model, data) + except Exception: + melk() + raise + finally: + mod_logger.info(trainer.profiler.summary()) diff --git a/imaginairy/training_tools/image_prep.py b/imaginairy/training_tools/image_prep.py index 88d07f19..5a1a331b 100644 --- a/imaginairy/training_tools/image_prep.py +++ b/imaginairy/training_tools/image_prep.py @@ -6,10 +6,11 @@ from PIL import Image from tqdm import tqdm -from imaginairy import ImaginePrompt, LazyLoadingImage, imagine +from imaginairy.api import imagine from imaginairy.enhancers.face_restoration_codeformer import enhance_faces from imaginairy.enhancers.facecrop import detect_faces, generate_face_crops from imaginairy.enhancers.upscale_realesrgan import upscale_image +from imaginairy.schema import ImaginePrompt, LazyLoadingImage from imaginairy.vendored.smart_crop import SmartCrop logger = logging.getLogger(__name__) diff --git a/imaginairy/utils/__init__.py b/imaginairy/utils/__init__.py index c8cbebe4..feb125cc 100644 --- a/imaginairy/utils/__init__.py +++ b/imaginairy/utils/__init__.py @@ -4,7 +4,7 @@ import time from contextlib import contextmanager, nullcontext from functools import lru_cache -from typing import Any, List, Optional, Union +from typing import Any, List, Optional import torch from torch import Tensor, autocast @@ -62,7 +62,7 @@ def get_obj_from_str(import_path: str, reload=False) -> Any: return getattr(module, obj_name) -def instantiate_from_config(config: Union[dict, str]) -> Any: +def instantiate_from_config(config: dict) -> Any: """Instantiate an object from a config dict.""" if "target" not in config: if config == "__is_first_stage__": diff --git a/imaginairy/utils/data_distorter.py b/imaginairy/utils/data_distorter.py index 3d133a30..a55f6ca9 100644 --- a/imaginairy/utils/data_distorter.py +++ b/imaginairy/utils/data_distorter.py @@ -170,7 +170,21 @@ def replace_value_at_path(data, path, new_value): parent = get_path(data, path[:-1]) last_key = path[-1] if new_value == NODE_DELETE: - del parent[last_key] + if isinstance(parent, tuple): + grandparent = get_path(data, path[:-2]) + grandparent_key = path[-2] + new_parent = list(parent) + del new_parent[last_key] + grandparent[grandparent_key] = tuple(new_parent) + else: + del parent[last_key] else: - parent[last_key] = new_value + if isinstance(parent, tuple): + grandparent = get_path(data, path[:-2]) + grandparent_key = path[-2] + new_parent = list(parent) + new_parent[last_key] = new_value + grandparent[grandparent_key] = tuple(new_parent) + else: + parent[last_key] = new_value return data diff --git a/imaginairy/utils/named_resolutions.py b/imaginairy/utils/named_resolutions.py index 8e77cda8..e10447b5 100644 --- a/imaginairy/utils/named_resolutions.py +++ b/imaginairy/utils/named_resolutions.py @@ -43,27 +43,42 @@ "SVD": (1024, 576), # stable video diffusion } +_NAMED_RESOLUTIONS = {k.upper(): v for k, v in _NAMED_RESOLUTIONS.items()} -def get_named_resolution(resolution: str): - resolution = resolution.upper() - size = _NAMED_RESOLUTIONS.get(resolution) +def normalize_image_size(resolution: str | int | tuple[int, int]) -> tuple[int, int]: + size = _normalize_image_size(resolution) + if any(s <= 0 for s in size): + msg = 
f"Invalid resolution: {resolution!r}" + raise ValueError(msg) + return size - if size is None: - # is it WIDTHxHEIGHT format? - try: - width, height = resolution.split("X") - size = (int(width), int(height)) - except ValueError: - pass - if size is None: - # is it just a single number? - with contextlib.suppress(ValueError): - size = (int(resolution), int(resolution)) +def _normalize_image_size(resolution: str | int | tuple[int, int]) -> tuple[int, int]: + match resolution: + case (int(), int()): + return resolution # type: ignore + case int(): + return resolution, resolution + case str(): + resolution = resolution.strip().upper() + resolution = resolution.replace(" ", "").replace("X", ",").replace("*", ",") + if resolution.upper() in _NAMED_RESOLUTIONS: + return _NAMED_RESOLUTIONS[resolution.upper()] - if size is None: - msg = f"Unknown resolution: {resolution}" - raise ValueError(msg) + # is it WIDTH,HEIGHT format? + try: + width, height = resolution.split(",") + return int(width), int(height) + except ValueError: + pass - return size + # is it just a single number? + with contextlib.suppress(ValueError): + return int(resolution), int(resolution) + + msg = f"Invalid resolution: '{resolution}'" + raise ValueError(msg) + case _: + msg = f"Invalid resolution: {resolution!r}" + raise ValueError(msg) diff --git a/imaginairy/video_sample.py b/imaginairy/video_sample.py index dfe5a272..32c92cb3 100644 --- a/imaginairy/video_sample.py +++ b/imaginairy/video_sample.py @@ -6,7 +6,7 @@ import time from glob import glob from pathlib import Path -from typing import Optional +from typing import Any, Optional import cv2 import numpy as np @@ -16,9 +16,10 @@ from PIL import Image from torchvision.transforms import ToTensor -from imaginairy import LazyLoadingImage, config +from imaginairy import config from imaginairy.model_manager import get_cached_url_path from imaginairy.paths import PKG_ROOT +from imaginairy.schema import LazyLoadingImage from imaginairy.utils import ( default, get_device, @@ -30,9 +31,10 @@ def generate_video( - input_path: str = "other/images/sound-music.jpg", # Can either be image file or folder with image files - num_frames: Optional[int] = None, - num_steps: Optional[int] = None, + input_path: str, # Can either be image file or folder with image files + output_folder: str | None = None, + num_frames: int = 6, + num_steps: int = 30, model_name: str = "svd_xt", fps_id: int = 6, output_fps: int = 6, @@ -41,7 +43,6 @@ def generate_video( seed: Optional[int] = None, decoding_t: int = 1, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. device: Optional[str] = None, - output_folder: Optional[str] = None, repetitions=1, ): """ @@ -69,15 +70,16 @@ def generate_video( seed = default(seed, random.randint(0, 1000000)) output_fps = default(output_fps, fps_id) - video_model_config = config.video_models.get(model_name, None) + video_model_config = config.MODEL_WEIGHT_CONFIG_LOOKUP.get(model_name, None) if video_model_config is None: msg = f"Version {model_name} does not exist." 
raise ValueError(msg) - num_frames = default(num_frames, video_model_config["default_frames"]) - num_steps = default(num_steps, video_model_config["default_steps"]) - output_folder = default(output_folder, "outputs/video/") - video_config_path = f"{PKG_ROOT}/{video_model_config['config_path']}" + num_frames = default(num_frames, video_model_config.defaults.get("frames", 12)) + num_steps = default(num_steps, video_model_config.defaults.get("steps", 30)) + output_folder_str = default(output_folder, "outputs/video/") + del output_folder + video_config_path = f"{PKG_ROOT}/{video_model_config.architecture.config_path}" logger.info( f"Generating a {num_frames} frame video from {input_path}. Device:{device} seed:{seed}" @@ -87,7 +89,7 @@ def generate_video( device="cpu", num_frames=num_frames, num_steps=num_steps, - weights_url=video_model_config["weights_url"], + weights_url=video_model_config.weights_location, ) torch.manual_seed(seed) @@ -118,11 +120,10 @@ def generate_video( for _ in range(repetitions): for input_path in all_img_paths: if input_path.startswith("http"): - image = LazyLoadingImage(url=input_path) + image = LazyLoadingImage(url=input_path).as_pillow() else: - image = LazyLoadingImage(filepath=input_path) + image = LazyLoadingImage(filepath=input_path).as_pillow() crop_coords = None - image = image.as_pillow() if image.mode == "RGBA": image = image.convert("RGB") if image.size != expected_size: @@ -179,7 +180,7 @@ def generate_video( "Large fps value! This may lead to suboptimal performance." ) - value_dict = {} + value_dict: dict[str, Any] = {} value_dict["motion_bucket_id"] = motion_bucket_id value_dict["fps_id"] = fps_id value_dict["cond_aug"] = cond_aug @@ -249,14 +250,14 @@ def denoiser(_input, sigma, c): left, upper, right, lower = crop_coords samples = samples[:, :, upper:lower, left:right] - os.makedirs(output_folder, exist_ok=True) - base_count = len(glob(os.path.join(output_folder, "*.mp4"))) + 1 + os.makedirs(output_folder_str, exist_ok=True) + base_count = len(glob(os.path.join(output_folder_str, "*.mp4"))) + 1 source_slug = make_safe_filename(input_path) video_filename = f"{base_count:06d}_{model_name}_{seed}_{fps_id}fps_{source_slug}.mp4" - video_path = os.path.join(output_folder, video_filename) + video_path = os.path.join(output_folder_str, video_filename) writer = cv2.VideoWriter( video_path, - cv2.VideoWriter_fourcc(*"MP4V"), + cv2.VideoWriter_fourcc(*"MP4V"), # type: ignore output_fps, (samples.shape[-1], samples.shape[-2]), ) @@ -329,20 +330,20 @@ def get_batch(keys, value_dict, N, T, device): def load_model( config: str, device: str, num_frames: int, num_steps: int, weights_url: str ): - config = OmegaConf.load(config) + oconfig = OmegaConf.load(config) ckpt_path = get_cached_url_path(weights_url) - config["model"]["params"]["ckpt_path"] = ckpt_path + oconfig["model"]["params"]["ckpt_path"] = ckpt_path # type: ignore if device == "cuda": - config.model.params.conditioner_config.params.emb_models[ + oconfig.model.params.conditioner_config.params.emb_models[ 0 ].params.open_clip_embedding_config.params.init_device = device - config.model.params.sampler_config.params.num_steps = num_steps - config.model.params.sampler_config.params.guider_config.params.num_frames = ( + oconfig.model.params.sampler_config.params.num_steps = num_steps + oconfig.model.params.sampler_config.params.guider_config.params.num_frames = ( num_frames ) - model = instantiate_from_config(config.model).to(device).half().eval() + model = 
instantiate_from_config(oconfig.model).to(device).half().eval() # safety_filter = DeepFloydDataFiltering(verbose=False, device=device) def safety_filter(x): @@ -406,13 +407,3 @@ def make_safe_filename(input_string): safe_name = re.sub(r"[^a-zA-Z0-9\-]", "", name_without_extension) return safe_name - - -if __name__ == "__main__": - # configure logging - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(levelname)s %(name)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - generate_video() diff --git a/imaginairy/weight_management/conversion.py b/imaginairy/weight_management/conversion.py index 9c2087f4..6f43c643 100644 --- a/imaginairy/weight_management/conversion.py +++ b/imaginairy/weight_management/conversion.py @@ -6,7 +6,7 @@ from imaginairy.weight_management import utils if TYPE_CHECKING: - from torch import Tensor + from torch import Tensor # noqa @dataclass @@ -69,8 +69,8 @@ def could_convert(self, source_weights): source_keys = set(source_weights.keys()) return source_keys.issubset(self.all_valid_prefixes) - def cast_weights(self, source_weights): - converted_state_dict: dict[str, Tensor] = {} + def cast_weights(self, source_weights) -> dict[str, "Tensor"]: + converted_state_dict: dict[str, "Tensor"] = {} for source_key in source_weights: source_prefix, suffix = source_key.rsplit(sep=".", maxsplit=1) # handle aliases @@ -89,20 +89,24 @@ def cast_weights(self, source_weights): @lru_cache(maxsize=None) -def load_state_dict_conversion_maps(): +def load_state_dict_conversion_maps() -> dict[str, dict]: import json conversion_maps = {} from importlib.resources import files for file in files("imaginairy").joinpath("weight_conversion/maps").iterdir(): - if file.is_file() and file.suffix == ".json": + if file.is_file() and file.suffix == ".json": # type: ignore conversion_maps[file.name] = json.loads(file.read_text()) return conversion_maps def cast_weights( - source_weights, source_model_name, source_component_name, source_format, dest_format + source_weights, + source_model_name: str, + source_component_name: str, + source_format: str, + dest_format: str, ): weight_map = WeightMap( model_name=source_model_name, diff --git a/imaginairy/weight_management/generate_weight_info.py b/imaginairy/weight_management/generate_weight_info.py index 3cccb86d..723d0b17 100644 --- a/imaginairy/weight_management/generate_weight_info.py +++ b/imaginairy/weight_management/generate_weight_info.py @@ -3,7 +3,7 @@ from imaginairy.model_manager import ( get_cached_url_path, open_weights, - resolve_model_paths, + resolve_model_weights_config, ) from imaginairy.weight_management import utils from imaginairy.weight_management.pattern_collapse import find_state_dict_key_patterns @@ -11,15 +11,12 @@ def save_compvis_patterns(): - ( - model_metadata, - weights_url, - config_path, - control_weights_paths, - ) = resolve_model_paths( - weights_path="openjourney-v1", + model_weights_config = resolve_model_weights_config( + model_weights="openjourney-v1", + ) + weights_path = get_cached_url_path( + model_weights_config.weights_location, category="weights" ) - weights_path = get_cached_url_path(weights_url, category="weights") with safetensors.safe_open(weights_path, "pytorch") as f: weights_keys = f.keys() @@ -98,7 +95,7 @@ def save_weight_info( model_name, component_name, format_name, weights_url=None, weights_keys=None ): if weights_keys is None and weights_url is None: - msg = "Either weights_keys or weights_url must be provided" + msg = "Either weights_keys or weights_location must be provided" 
raise ValueError(msg) if weights_keys is None: diff --git a/requirements-dev.in b/requirements-dev.in index 6234432f..4955aab2 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -1,10 +1,15 @@ black coverage +mypy ruff pytest pytest-randomly pytest-sugar responses +types-pillow +types-psutil +types-requests +types-tqdm wheel -c tests/constraints.txt \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 3ecc2c95..afc472ae 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -71,7 +71,7 @@ frozenlist==1.4.0 # via # aiohttp # aiosignal -fsspec[http]==2023.12.0 +fsspec[http]==2023.12.1 # via # huggingface-hub # pytorch-lightning @@ -126,8 +126,12 @@ multidict==6.0.4 # via # aiohttp # yarl +mypy==1.7.1 + # via -r requirements-dev.in mypy-extensions==1.0.0 - # via black + # via + # black + # mypy networkx==3.2.1 # via torch numba==0.58.1 @@ -172,7 +176,7 @@ packaging==23.2 # pytorch-lightning # torchmetrics # transformers -pathspec==0.11.2 +pathspec==0.12.1 # via black pillow==10.1.0 # via @@ -276,6 +280,7 @@ tokenizers==0.15.0 tomli==2.0.1 # via # black + # mypy # pytest torch==2.1.1 # via @@ -314,13 +319,22 @@ transformers==4.35.2 # via imaginAIry (setup.py) typeguard==2.13.3 # via jaxtyping -typing-extensions==4.8.0 +types-pillow==10.1.0.2 + # via -r requirements-dev.in +types-psutil==5.9.5.17 + # via -r requirements-dev.in +types-requests==2.31.0.10 + # via -r requirements-dev.in +types-tqdm==4.66.0.5 + # via -r requirements-dev.in +typing-extensions==4.9.0 # via # black # fastapi # huggingface-hub # jaxtyping # lightning-utilities + # mypy # pydantic # pydantic-core # pytorch-lightning @@ -330,13 +344,14 @@ urllib3==2.1.0 # via # requests # responses + # types-requests uvicorn==0.24.0.post1 # via imaginAIry (setup.py) wcwidth==0.2.12 # via ftfy wheel==0.42.0 # via -r requirements-dev.in -yarl==1.9.3 +yarl==1.9.4 # via aiohttp zipp==3.17.0 # via importlib-metadata diff --git a/scripts/asses_memory_usage.py b/scripts/asses_memory_usage.py index afe68e22..9359a09f 100644 --- a/scripts/asses_memory_usage.py +++ b/scripts/asses_memory_usage.py @@ -1,22 +1,22 @@ import torch from torch.cuda import OutOfMemoryError -from imaginairy import ImaginePrompt, imagine_image_files +from imaginairy.api import imagine_image_files +from imaginairy.schema import ImaginePrompt from imaginairy.utils import get_device def assess_memory_usage(): assert get_device() == "cuda" img_size = 3048 - prompt = ImaginePrompt("strawberries", width=64, height=64, seed=1) + prompt = ImaginePrompt("strawberries", size=64, seed=1) imagine_image_files([prompt], outdir="outputs") datalog = [] while True: torch.cuda.reset_peak_memory_stats() prompt = ImaginePrompt( "beautiful landscape, Unreal Engine 5, RTX, AAA Game, Detailed 3D Render, Cinema4D", - width=img_size, - height=img_size, + size=img_size, seed=1, steps=2, ) diff --git a/setup.py b/setup.py index baaab491..9825d751 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ if is_for_windows: scripts = None - entry_points = { + entry_points: dict | None = { "console_scripts": [ "imagine=imaginairy.cli.main:imagine_cmd", "aimg=imaginairy.cli.main:aimg", diff --git a/tests/__init__.py b/tests/__init__.py index ce4152c8..635455c1 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +1,4 @@ import os.path TESTS_FOLDER = os.path.abspath(os.path.dirname(__file__)) +PROJECT_FOLDER = os.path.abspath(os.path.join(TESTS_FOLDER, "..")) diff --git a/tests/conftest.py b/tests/conftest.py index f7fc6a6e..56193e29 100644 
--- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,6 @@ import contextlib +import csv +import gc import logging import os import sys @@ -7,12 +9,14 @@ import pytest import responses +import torch.cuda from tqdm import tqdm from urllib3 import HTTPConnectionPool -from imaginairy import ImaginePrompt, api, imagine +from imaginairy import api +from imaginairy.api import imagine from imaginairy.log_utils import configure_logging, suppress_annoying_logs_and_warnings -from imaginairy.samplers import SAMPLER_TYPE_OPTIONS +from imaginairy.schema import ImaginePrompt from imaginairy.utils import ( fix_torch_group_norm, fix_torch_nn_layer_norm, @@ -26,13 +30,13 @@ logger = logging.getLogger(__name__) -SAMPLERS_FOR_TESTING = SAMPLER_TYPE_OPTIONS -if get_device() == "mps:0": - SAMPLERS_FOR_TESTING = ["plms", "k_euler_a"] -elif get_device() == "cpu": - SAMPLERS_FOR_TESTING = [] +# SOLVERS_FOR_TESTING = SOLVER_TYPE_OPTIONS +# if get_device() == "mps:0": +# SOLVERS_FOR_TESTING = ["plms", "k_euler_a"] +# elif get_device() == "cpu": +# SOLVERS_FOR_TESTING = [] -SAMPLERS_FOR_TESTING = ["ddim", "k_dpmpp_2m"] +SOLVERS_FOR_TESTING = ["ddim", "dpmpp"] @pytest.fixture(scope="session", autouse=True) @@ -90,8 +94,8 @@ def filename_base_for_orig_outputs(request): return filename_base -@pytest.fixture(params=SAMPLERS_FOR_TESTING) -def sampler_type(request): +@pytest.fixture(params=SOLVERS_FOR_TESTING) +def solver_type(request): return request.param @@ -118,23 +122,51 @@ def default_model_loaded(): """ prompt = ImaginePrompt( "dogs lying on a hot pink couch", - width=64, - height=64, + size=64, steps=2, seed=1, - sampler_type="ddim", + solver_type="ddim", ) next(imagine(prompt)) +cuda_tests_node_ids = [] +cuda_test_tracker_filepath = f"{TESTS_FOLDER}/data/cuda-tests.csv" + + +@pytest.fixture(autouse=True) +def detect_cuda_tests(request): + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + start_memory = torch.cuda.max_memory_allocated() + yield + if torch.cuda.is_available(): + end_memory = torch.cuda.max_memory_allocated() + memory_diff = end_memory - start_memory + if memory_diff > 0: + test_name = request.node.name + print(f"Test {test_name} used {memory_diff} bytes of GPU memory") + cuda_tests_node_ids.append(test_name) + + torch.cuda.empty_cache() + gc.collect() + + @pytest.hookimpl() def pytest_collection_modifyitems(config, items): """Only select a subset of tests to run, based on the --subset option.""" + + node_ids_to_mark = read_stored_cuda_test_nodes() + for item in items: + if item.nodeid in node_ids_to_mark: + item.add_marker(pytest.mark.gputest) + filtered_node_ids = set() node_ids = [f.nodeid for f in items] node_ids.sort() subset = config.getoption("--subset") + if subset: partition_no, total_partitions = subset.split("/") partition_no, total_partitions = int(partition_no), int(total_partitions) @@ -168,3 +200,26 @@ def pytest_sessionstart(session): if "nvidia_smi" in debug_info: print(debug_info["nvidia_smi"]) + + +def pytest_sessionfinish(session, exitstatus): + existing_node_ids = read_stored_cuda_test_nodes() + updated_node_ids = existing_node_ids.union(set(cuda_tests_node_ids)) + + # Write updated, sorted list of node IDs to file + with open(cuda_test_tracker_filepath, "w", newline="") as file: + writer = csv.writer(file) + for node_id in sorted(updated_node_ids): + writer.writerow([node_id]) + + +def read_stored_cuda_test_nodes(): + node_ids = set() + try: + with open(cuda_test_tracker_filepath, newline="") as file: + reader = csv.reader(file) + for row in reader: + 
node_ids.add(row[0]) + except FileNotFoundError: + pass # File does not exist yet + return node_ids diff --git a/tests/data/cuda-tests.csv b/tests/data/cuda-tests.csv new file mode 100644 index 00000000..71e42fe9 --- /dev/null +++ b/tests/data/cuda-tests.csv @@ -0,0 +1,51 @@ +test_cache_ordering +test_clip_masking +test_clip_text_comparison +test_cliptext_inpainting_pearl_doctor +test_colorize_cmd +test_control_images[depth-create_depth_map] +test_control_images[hed-create_hed_edges] +test_control_images[normal-create_normal_map] +test_control_images[openpose-create_pose_map] +test_controlnet[canny] +test_controlnet[colorize] +test_controlnet[depth] +test_controlnet[edit] +test_controlnet[hed] +test_controlnet[inpaint] +test_controlnet[normal] +test_controlnet[openpose] +test_controlnet[shuffle] +test_describe_cmd +test_describe_picture +test_edit_cmd +test_edit_demo +test_fix_faces +test_get_existing_move_to_gpu +test_imagine[ddim] +test_imagine[dpmpp] +test_imagine_cmd +test_img2img_beach_to_sunset[ddim] +test_img2img_beach_to_sunset[dpmpp] +test_img2img_low_noise[ddim] +test_img2img_low_noise[dpmpp] +test_img_to_file +test_img_to_img_from_url_cats[ddim] +test_img_to_img_from_url_cats[dpmpp] +test_img_to_img_fruit_2_gold[ddim-0.05] +test_img_to_img_fruit_2_gold[ddim-0.2] +test_img_to_img_fruit_2_gold[ddim-0] +test_img_to_img_fruit_2_gold[ddim-1] +test_img_to_img_fruit_2_gold[dpmpp-0.05] +test_img_to_img_fruit_2_gold[dpmpp-0.2] +test_img_to_img_fruit_2_gold[dpmpp-0] +test_img_to_img_fruit_2_gold[dpmpp-1] +test_img_to_img_fruit_2_gold_repeat +test_inpainting_bench +test_large_image +test_model_versions[SD-1.5] +test_nonlinearity +test_outpainting_outpaint +test_set_cpu_full +test_text_conditioning +test_tile_mode diff --git a/tests/enhancers/test_facecrop.py b/tests/enhancers/test_facecrop.py index 63f73cee..b27b1cbc 100644 --- a/tests/enhancers/test_facecrop.py +++ b/tests/enhancers/test_facecrop.py @@ -1,7 +1,7 @@ import logging -from imaginairy import LazyLoadingImage from imaginairy.enhancers.facecrop import generate_face_crops +from imaginairy.schema import LazyLoadingImage from tests import TESTS_FOLDER logger = logging.getLogger(__name__) diff --git a/tests/expected_output/test_imagine[dpmpp]_.png b/tests/expected_output/test_imagine[dpmpp]_.png new file mode 100644 index 00000000..ea9ebfea Binary files /dev/null and b/tests/expected_output/test_imagine[dpmpp]_.png differ diff --git a/tests/expected_output/test_img2img_beach_to_sunset[dpmpp]_.png b/tests/expected_output/test_img2img_beach_to_sunset[dpmpp]_.png new file mode 100644 index 00000000..24d4fa79 Binary files /dev/null and b/tests/expected_output/test_img2img_beach_to_sunset[dpmpp]_.png differ diff --git a/tests/expected_output/test_img2img_low_noise[dpmpp]_.png b/tests/expected_output/test_img2img_low_noise[dpmpp]_.png new file mode 100644 index 00000000..a396df6b Binary files /dev/null and b/tests/expected_output/test_img2img_low_noise[dpmpp]_.png differ diff --git a/tests/expected_output/test_img_to_img_from_url_cats[dpmpp]_.png b/tests/expected_output/test_img_to_img_from_url_cats[dpmpp]_.png new file mode 100644 index 00000000..ed8aa77f Binary files /dev/null and b/tests/expected_output/test_img_to_img_from_url_cats[dpmpp]_.png differ diff --git a/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-0.05]_.png b/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-0.05]_.png new file mode 100644 index 00000000..d4ce6288 Binary files /dev/null and 
b/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-0.05]_.png differ diff --git a/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-0.2]_.png b/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-0.2]_.png new file mode 100644 index 00000000..26137634 Binary files /dev/null and b/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-0.2]_.png differ diff --git a/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-0]_.png b/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-0]_.png new file mode 100644 index 00000000..264f6a82 Binary files /dev/null and b/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-0]_.png differ diff --git a/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-1]_.png b/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-1]_.png new file mode 100644 index 00000000..e32a8429 Binary files /dev/null and b/tests/expected_output/test_img_to_img_fruit_2_gold[dpmpp-1]_.png differ diff --git a/tests/expected_output/test_large_image_.png b/tests/expected_output/test_large_image_.png index 23f33c21..bcdcf72e 100644 Binary files a/tests/expected_output/test_large_image_.png and b/tests/expected_output/test_large_image_.png differ diff --git a/tests/expected_output/test_model_versions__a headshot photo of a happy couple smiling at the camera_SD-1.5.png b/tests/expected_output/test_model_versions__a headshot photo of a happy couple smiling at the camera_sd15.png similarity index 100% rename from tests/expected_output/test_model_versions__a headshot photo of a happy couple smiling at the camera_SD-1.5.png rename to tests/expected_output/test_model_versions__a headshot photo of a happy couple smiling at the camera_sd15.png diff --git a/tests/expected_output/test_model_versions__a painting of a beautiful cloudy sunset at the beach_SD-1.5.png b/tests/expected_output/test_model_versions__a painting of a beautiful cloudy sunset at the beach_sd15.png similarity index 100% rename from tests/expected_output/test_model_versions__a painting of a beautiful cloudy sunset at the beach_SD-1.5.png rename to tests/expected_output/test_model_versions__a painting of a beautiful cloudy sunset at the beach_sd15.png diff --git a/tests/expected_output/test_model_versions__a photo of a bowl of fruit_SD-1.5.png b/tests/expected_output/test_model_versions__a photo of a bowl of fruit_sd15.png similarity index 100% rename from tests/expected_output/test_model_versions__a photo of a bowl of fruit_SD-1.5.png rename to tests/expected_output/test_model_versions__a photo of a bowl of fruit_sd15.png diff --git a/tests/expected_output/test_model_versions__a photo of a dog_SD-1.5.png b/tests/expected_output/test_model_versions__a photo of a dog_sd15.png similarity index 100% rename from tests/expected_output/test_model_versions__a photo of a dog_SD-1.5.png rename to tests/expected_output/test_model_versions__a photo of a dog_sd15.png diff --git a/tests/expected_output/test_model_versions__a photo of a handshake_SD-1.5.png b/tests/expected_output/test_model_versions__a photo of a handshake_sd15.png similarity index 100% rename from tests/expected_output/test_model_versions__a photo of a handshake_SD-1.5.png rename to tests/expected_output/test_model_versions__a photo of a handshake_sd15.png diff --git a/tests/expected_output/test_model_versions__a photo of an astronaut riding a horse on the moon. the earth visible in the background_SD-1.5.png b/tests/expected_output/test_model_versions__a photo of an astronaut riding a horse on the moon. 
the earth visible in the background_sd15.png similarity index 100% rename from tests/expected_output/test_model_versions__a photo of an astronaut riding a horse on the moon. the earth visible in the background_SD-1.5.png rename to tests/expected_output/test_model_versions__a photo of an astronaut riding a horse on the moon. the earth visible in the background_sd15.png diff --git a/tests/img_processors/test_control_modes.py b/tests/img_processors/test_control_modes.py index 90f31e80..1d02d59a 100644 --- a/tests/img_processors/test_control_modes.py +++ b/tests/img_processors/test_control_modes.py @@ -1,9 +1,9 @@ import pytest from lightning_fabric import seed_everything -from imaginairy import LazyLoadingImage from imaginairy.img_processors.control_modes import CONTROL_MODES from imaginairy.img_utils import pillow_img_to_torch_image, torch_img_to_pillow_img +from imaginairy.schema import LazyLoadingImage from tests import TESTS_FOLDER from tests.utils import assert_image_similar_to_expectation diff --git a/tests/modules/diffusion/test_model.py b/tests/modules/diffusion/test_model.py index 0303a481..33e86a68 100644 --- a/tests/modules/diffusion/test_model.py +++ b/tests/modules/diffusion/test_model.py @@ -1,23 +1,7 @@ -import time - import torch from imaginairy.utils import get_device - - -class Timer: - def __init__(self, name): - self.name = name - self.start = None - - def __enter__(self): - self.start = time.perf_counter() - return self - - def __exit__(self, *args): - elapsed = time.perf_counter() - self.start - - print(f"{self.name} took {elapsed*1000:.2f} ms") +from tests.utils import Timer def test_nonlinearity(): diff --git a/tests/modules/test_autoencoders.py b/tests/modules/test_autoencoders.py index f68fe472..7523d10b 100644 --- a/tests/modules/test_autoencoders.py +++ b/tests/modules/test_autoencoders.py @@ -3,7 +3,6 @@ from PIL import Image from torch.nn.functional import interpolate -from imaginairy import LazyLoadingImage from imaginairy.enhancers.upscale_riverwing import upscale_latent from imaginairy.img_utils import ( pillow_fit_image_within, @@ -11,6 +10,7 @@ torch_img_to_pillow_img, ) from imaginairy.model_manager import get_diffusion_model +from imaginairy.schema import LazyLoadingImage from imaginairy.utils import get_device from tests import TESTS_FOLDER diff --git a/tests/test_api.py b/tests/test_api.py index cf8eb586..084df2e7 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -2,26 +2,25 @@ import pytest -from imaginairy import LazyLoadingImage from imaginairy.api import imagine, imagine_image_files from imaginairy.img_processors.control_modes import CONTROL_MODES from imaginairy.img_utils import pillow_fit_image_within -from imaginairy.schema import ControlNetInput, ImaginePrompt +from imaginairy.schema import ControlInput, ImaginePrompt, LazyLoadingImage, MaskMode from imaginairy.utils import get_device from . import TESTS_FOLDER from .utils import assert_image_similar_to_expectation -def test_imagine(sampler_type, filename_base_for_outputs): +def test_imagine(solver_type, filename_base_for_outputs): prompt_text = "a scenic old-growth forest with diffuse light poking through the canopy. 
high resolution nature photography" prompt = ImaginePrompt( - prompt_text, width=512, height=512, steps=20, seed=1, sampler_type=sampler_type + prompt_text, size=512, steps=20, seed=1, solver_type=solver_type ) result = next(imagine(prompt)) threshold_lookup = {"k_dpm_2_a": 26000} - threshold = threshold_lookup.get(sampler_type, 10000) + threshold = threshold_lookup.get(solver_type, 10000) img_path = f"{filename_base_for_outputs}.png" assert_image_similar_to_expectation( @@ -49,25 +48,25 @@ def test_model_versions(filename_base_for_orig_outputs, model_version): ImaginePrompt( prompt_text, seed=1, - model=model_version, + model_weights=model_version, ) ) threshold = 35000 - - for i, result in enumerate(imagine(prompts)): - img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model}.png" + results = list(imagine(prompts)) + for i, result in enumerate(results): + img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model_weights.aliases[0]}.png" result.img.save(img_path) - for i, result in enumerate(imagine(prompts)): - img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model}.png" + for i, result in enumerate(results): + img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model_weights.aliases[0]}.png" assert_image_similar_to_expectation( result.img, img_path=img_path, threshold=threshold ) def test_img2img_beach_to_sunset( - sampler_type, filename_base_for_outputs, filename_base_for_orig_outputs + solver_type, filename_base_for_outputs, filename_base_for_orig_outputs ): img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/beach_at_sainte_adresse.jpg") prompt = ImaginePrompt( @@ -77,11 +76,10 @@ def test_img2img_beach_to_sunset( prompt_strength=15, mask_prompt="(sky|clouds) AND !(buildings|trees)", mask_mode="replace", - width=512, - height=512, + size=512, steps=40 * 2, seed=1, - sampler_type=sampler_type, + solver_type=solver_type, ) result = next(imagine(prompt)) @@ -91,7 +89,7 @@ def test_img2img_beach_to_sunset( def test_img_to_img_from_url_cats( - sampler_type, + solver_type, filename_base_for_outputs, mocked_responses, filename_base_for_orig_outputs, @@ -113,11 +111,10 @@ def test_img_to_img_from_url_cats( "dogs lying on a hot pink couch", init_image=img, init_image_strength=0.5, - width=512, - height=512, + size=512, steps=50, seed=1, - sampler_type=sampler_type, + solver_type=solver_type, ) result = next(imagine(prompt)) @@ -130,7 +127,7 @@ def test_img_to_img_from_url_cats( def test_img2img_low_noise( filename_base_for_outputs, - sampler_type, + solver_type, ): fruit_path = os.path.join(TESTS_FOLDER, "data", "bowl_of_fruit.jpg") img = LazyLoadingImage(filepath=fruit_path) @@ -144,17 +141,18 @@ def test_img2img_low_noise( mask_mode="replace", # steps=40, seed=1, - sampler_type=sampler_type, + solver_type=solver_type, ) result = next(imagine(prompt)) threshold_lookup = { + "dpmpp": 26000, "k_dpm_2_a": 26000, "k_euler_a": 18000, "k_dpm_adaptive": 13000, } - threshold = threshold_lookup.get(sampler_type, 14000) + threshold = threshold_lookup.get(solver_type, 14000) img_path = f"{filename_base_for_outputs}.png" assert_image_similar_to_expectation( @@ -165,7 +163,7 @@ def test_img2img_low_noise( @pytest.mark.parametrize("init_strength", [0, 0.05, 0.2, 1]) def test_img_to_img_fruit_2_gold( filename_base_for_outputs, - sampler_type, + solver_type, init_strength, filename_base_for_orig_outputs, ): @@ -183,7 +181,7 @@ def 
test_img_to_img_fruit_2_gold( mask_mode="replace", steps=needed_steps, seed=1, - sampler_type=sampler_type, + solver_type=solver_type, ) result = next(imagine(prompt)) @@ -194,7 +192,7 @@ def test_img_to_img_fruit_2_gold( "k_dpm_adaptive": 13000, "k_dpmpp_2s": 16000, } - threshold = threshold_lookup.get(sampler_type, 16000) + threshold = threshold_lookup.get(solver_type, 16000) pillow_fit_image_within(img).save(f"{filename_base_for_orig_outputs}__orig.jpg") img_path = f"{filename_base_for_outputs}.png" @@ -227,7 +225,7 @@ def test_img_to_img_fruit_2_gold_repeat(): ] for result in imagine(prompts, debug_img_callback=None): result.img.save( - f"{TESTS_FOLDER}/test_output/img2img_fruit_2_gold_{result.prompt.sampler_type}_{get_device()}_run-{run_count:02}.jpg" + f"{TESTS_FOLDER}/test_output/img2img_fruit_2_gold_{result.prompt.solver_type}_{get_device()}_run-{run_count:02}.jpg" ) run_count += 1 @@ -236,9 +234,8 @@ def test_img_to_img_fruit_2_gold_repeat(): def test_img_to_file(): prompt = ImaginePrompt( "an old growth forest, diffuse light poking through the canopy. high-resolution, nature photography, nat geo photo", - width=512 + 64, - height=512 - 64, - steps=20, + size=(512 + 64, 512 - 64), + steps=2, seed=2, upscale=True, ) @@ -254,8 +251,7 @@ def test_inpainting_bench(filename_base_for_outputs, filename_base_for_orig_outp init_image=img, init_image_strength=0.4, mask_image=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2_mask.png"), - width=512, - height=512, + size=512, steps=40, seed=1, ) @@ -279,9 +275,8 @@ def test_cliptext_inpainting_pearl_doctor( init_image=img, init_image_strength=0.2, mask_prompt="face AND NOT (bandana OR hair OR blue fabric){*5}", - mask_mode=ImaginePrompt.MaskMode.KEEP, - width=512, - height=512, + mask_mode=MaskMode.KEEP, + size=512, steps=40, seed=181509347, ) @@ -297,8 +292,7 @@ def test_tile_mode(filename_base_for_outputs): prompt_text = "gold coins" prompt = ImaginePrompt( prompt_text, - width=400, - height=400, + size=400, steps=15, seed=1, tile_mode="xy", @@ -317,7 +311,7 @@ def test_tile_mode(filename_base_for_outputs): def test_controlnet(filename_base_for_outputs, control_mode): prompt_text = "a photo of a woman sitting on a bench" img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2.png") - control_input = ControlNetInput( + control_input = ControlInput( mode=control_mode, image=img, ) @@ -327,30 +321,27 @@ def test_controlnet(filename_base_for_outputs, control_mode): prompt_text = "a wise old man" seed = 1 mask_image = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/bench2_mask.png") - control_input = ControlNetInput( + control_input = ControlInput( mode=control_mode, image=mask_image, ) prompt = ImaginePrompt( prompt_text, - width=512, - height=512, + size=512, steps=45, seed=seed, init_image=img, init_image_strength=0, control_inputs=[control_input], fix_faces=True, - sampler="ddim", + solver_type="ddim", ) prompt.steps = 1 - prompt.width = 256 - prompt.height = 256 + prompt.size = 256 result = next(imagine(prompt)) prompt.steps = 15 - prompt.width = 512 - prompt.height = 512 + prompt.size = 512 result = next(imagine(prompt)) img_path = f"{filename_base_for_outputs}.png" @@ -365,8 +356,7 @@ def test_large_image(filename_base_for_outputs): prompt_text = "a stormy ocean. 
oil painting" prompt = ImaginePrompt( prompt_text, - width=1920, - height=1080, + size="1080p", steps=30, seed=0, ) diff --git a/tests/test_cli/__init__.py b/tests/test_cli/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_cmds.py b/tests/test_cli/test_cmds.py similarity index 66% rename from tests/test_cmds.py rename to tests/test_cli/test_cmds.py index 5da66508..a46e890c 100644 --- a/tests/test_cmds.py +++ b/tests/test_cli/test_cmds.py @@ -1,15 +1,69 @@ +import subprocess from unittest import mock +import pytest from click.testing import CliRunner -from imaginairy import ImaginePrompt, LazyLoadingImage, surprise_me +from imaginairy import surprise_me from imaginairy.cli.edit import edit_cmd from imaginairy.cli.edit_demo import edit_demo_cmd from imaginairy.cli.imagine import imagine_cmd from imaginairy.cli.main import aimg from imaginairy.cli.upscale import upscale_cmd +from imaginairy.schema import ImaginePrompt, LazyLoadingImage from imaginairy.utils.model_cache import GPUModelCache -from tests import TESTS_FOLDER +from tests import PROJECT_FOLDER, TESTS_FOLDER +from tests.utils import Timer + + +@pytest.mark.parametrize("subcommand_name", aimg.commands.keys()) +def test_cmd_help_time(subcommand_name): + cmd_parts = [ + "aimg", + subcommand_name, + "--help", + ] + with Timer(f"{subcommand_name} --help") as t: + result = subprocess.run( + cmd_parts, check=False, capture_output=True, cwd=PROJECT_FOLDER + ) + assert result.returncode == 0, result.stderr + assert t.elapsed < 1.0, f"{t.elapsed} > 1.0" + + +def test_model_info_cmd(): + runner = CliRunner() + result = runner.invoke( + aimg, + [ + "model-list", + ], + ) + assert result.exit_code == 0, result.stdout + + +def test_describe_cmd(): + runner = CliRunner() + result = runner.invoke( + aimg, + [ + "describe", + f"{TESTS_FOLDER}/data/dog.jpg", + ], + ) + assert result.exit_code == 0, result.stdout + + +def test_colorize_cmd(): + runner = CliRunner() + result = runner.invoke( + aimg, + [ + "colorize", + f"{TESTS_FOLDER}/data/dog.jpg", + ], + ) + assert result.exit_code == 0, result.stdout def test_imagine_cmd(monkeypatch): @@ -25,8 +79,6 @@ def test_imagine_cmd(monkeypatch): f"{TESTS_FOLDER}/test_output", "--seed", "703425280", - # "--model", - # "empty", "--outdir", f"{TESTS_FOLDER}/test_output", ], @@ -72,8 +124,7 @@ def mock_surprise_me_prompts(*args, **kwargs): ImaginePrompt( "", steps=1, - width=256, - height=256, + size=256, # model="empty", ) ] @@ -89,7 +140,7 @@ def mock_surprise_me_prompts(*args, **kwargs): f"{TESTS_FOLDER}/test_output", ], ) - assert result.exit_code == 0 + assert result.exit_code == 0, result.stdout def test_upscale(monkeypatch): diff --git a/tests/test_config.py b/tests/test_config.py index 5bc784c7..34a7ac80 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,6 @@ -from imaginairy import config -from imaginairy.samplers import SAMPLER_TYPE_OPTIONS - - -def test_sampler_options(): - assert set(config.SAMPLER_TYPE_OPTIONS) == set(SAMPLER_TYPE_OPTIONS) +# from imaginairy import config +# from imaginairy.samplers import SOLVER_TYPE_OPTIONS +# +# +# def test_sampler_options(): +# assert set(config.SOLVER_TYPE_NAMES) == set(SOLVER_TYPE_OPTIONS) diff --git a/tests/test_enhancers.py b/tests/test_enhancers.py index 087b10c8..26426618 100644 --- a/tests/test_enhancers.py +++ b/tests/test_enhancers.py @@ -2,12 +2,13 @@ from PIL import Image from pytorch_lightning import seed_everything -from imaginairy import ImaginePrompt, imagine +from imaginairy.api import imagine 
from imaginairy.enhancers.bool_masker import MASK_PROMPT from imaginairy.enhancers.clip_masking import get_img_mask from imaginairy.enhancers.describe_image_blip import generate_caption from imaginairy.enhancers.describe_image_clip import find_img_text_similarity from imaginairy.enhancers.face_restoration_codeformer import enhance_faces +from imaginairy.schema import ImaginePrompt from imaginairy.utils import get_device from tests import TESTS_FOLDER from tests.utils import assert_image_similar_to_expectation @@ -58,7 +59,7 @@ def test_clip_masking(filename_base_for_outputs): upscale=False, fix_faces=True, seed=42, - # sampler_type="plms", + # solver_type="plms", ) result = next(imagine(prompt)) diff --git a/tests/test_feather_tile.py b/tests/test_feather_tile.py index 748f52ea..0b2d8044 100644 --- a/tests/test_feather_tile.py +++ b/tests/test_feather_tile.py @@ -2,9 +2,9 @@ import pytest -from imaginairy import LazyLoadingImage from imaginairy.feather_tile import rebuild_image, tile_image, tile_setup from imaginairy.img_utils import pillow_img_to_torch_image, torch_img_to_pillow_img +from imaginairy.schema import LazyLoadingImage from tests import TESTS_FOLDER img_ratios = [0.2, 0.242, 0.3, 0.33333333, 0.5, 0.75, 1, 4 / 3.0, 16 / 9.0, 2, 21 / 9.0] diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py index ac1be827..16bcb0d5 100644 --- a/tests/test_model_manager.py +++ b/tests/test_model_manager.py @@ -1,24 +1,25 @@ from imaginairy import config -from imaginairy.model_manager import resolve_model_paths +from imaginairy.model_manager import resolve_model_weights_config def test_resolved_paths(): """Test that the resolved model path is correct.""" - ( - model_metadata, - weights_path, - config_path, - control_weights_path, - ) = resolve_model_paths() - assert model_metadata.short_name == config.DEFAULT_MODEL - assert model_metadata.config_path == config_path - default_config_path = config_path + model_weights_config = resolve_model_weights_config(config.DEFAULT_MODEL_WEIGHTS) + assert config.DEFAULT_MODEL_WEIGHTS.lower() in model_weights_config.aliases + assert ( + config.DEFAULT_MODEL_ARCHITECTURE in model_weights_config.architecture.aliases + ) - ( - model_metadata, - weights_path, - config_path, - control_weights_path, - ) = resolve_model_paths(weights_path="foo.ckpt") - assert weights_path == "foo.ckpt" - assert config_path == default_config_path + model_weights_config = resolve_model_weights_config( + model_weights="foo.ckpt", + default_model_architecture="sd15", + ) + print(model_weights_config) + assert model_weights_config.aliases == [] + assert "sd15" in model_weights_config.architecture.aliases + + model_weights_config = resolve_model_weights_config( + model_weights="foo.ckpt", default_model_architecture="sd15", for_inpainting=True + ) + assert model_weights_config.aliases == [] + assert "sd15-inpaint" in model_weights_config.architecture.aliases diff --git a/tests/test_outpaint.py b/tests/test_outpaint.py index 5027fb74..4edb2b11 100644 --- a/tests/test_outpaint.py +++ b/tests/test_outpaint.py @@ -1,7 +1,8 @@ import pytest -from imaginairy import ImaginePrompt, LazyLoadingImage, imagine +from imaginairy.api import imagine from imaginairy.outpaint import outpaint_arg_str_parse +from imaginairy.schema import ImaginePrompt, LazyLoadingImage from imaginairy.utils import get_device from tests import TESTS_FOLDER from tests.utils import assert_image_similar_to_expectation diff --git a/tests/test_schema/test_controlnetinput.py b/tests/test_schema/test_controlnetinput.py 
index 90d9a091..3452bb1e 100644 --- a/tests/test_schema/test_controlnetinput.py +++ b/tests/test_schema/test_controlnetinput.py @@ -1,8 +1,7 @@ import pytest from pydantic import ValidationError -from imaginairy import LazyLoadingImage -from imaginairy.schema import ControlNetInput +from imaginairy.schema import ControlInput, LazyLoadingImage from tests import TESTS_FOLDER @@ -12,29 +11,29 @@ def _lazy_img(): def test_controlnetinput_basic(lazy_img): - ControlNetInput(mode="canny", image=lazy_img) - ControlNetInput(mode="canny", image_raw=lazy_img) + ControlInput(mode="canny", image=lazy_img) + ControlInput(mode="canny", image_raw=lazy_img) def test_controlnetinput_invalid_mode(lazy_img): with pytest.raises(ValueError, match=r".*Invalid controlnet mode.*"): - ControlNetInput(mode="pizza", image=lazy_img) + ControlInput(mode="pizza", image=lazy_img) def test_controlnetinput_both_images(lazy_img): with pytest.raises(ValueError, match=r".*cannot specify both.*"): - ControlNetInput(mode="canny", image=lazy_img, image_raw=lazy_img) + ControlInput(mode="canny", image=lazy_img, image_raw=lazy_img) def test_controlnetinput_filepath_input(lazy_img): """Test that we accept filepaths here.""" - c = ControlNetInput(mode="canny", image=f"{TESTS_FOLDER}/data/red.png") + c = ControlInput(mode="canny", image=f"{TESTS_FOLDER}/data/red.png") c.image.convert("RGB") - c = ControlNetInput(mode="canny", image_raw=f"{TESTS_FOLDER}/data/red.png") + c = ControlInput(mode="canny", image_raw=f"{TESTS_FOLDER}/data/red.png") c.image_raw.convert("RGB") def test_controlnetinput_big(lazy_img): - ControlNetInput(mode="canny", strength=2) + ControlInput(mode="canny", strength=2) with pytest.raises(ValidationError, match=r".*float_type.*"): - ControlNetInput(mode="canny", strength=2**2048) + ControlInput(mode="canny", strength=2**2048) diff --git a/tests/test_schema/test_imagineprompt.py b/tests/test_schema/test_imagineprompt.py index 1245e082..5cca2c9b 100644 --- a/tests/test_schema/test_imagineprompt.py +++ b/tests/test_schema/test_imagineprompt.py @@ -1,14 +1,41 @@ import pytest from pydantic import ValidationError -from imaginairy import LazyLoadingImage, config -from imaginairy.schema import ControlNetInput, ImaginePrompt, WeightedPrompt +from imaginairy import config +from imaginairy.schema import ( + ControlInput, + ImaginePrompt, + LazyLoadingImage, + WeightedPrompt, +) from imaginairy.utils.data_distorter import DataDistorter from tests import TESTS_FOLDER +def test_imagine_prompt_default(): + prompt = ImaginePrompt() + assert prompt.prompt == [WeightedPrompt(text="")] + assert prompt.negative_prompt == [ + WeightedPrompt(text=config.DEFAULT_NEGATIVE_PROMPT) + ] + + prompt = ImaginePrompt(negative_prompt="") + assert prompt.negative_prompt == [WeightedPrompt(text="")] + + assert prompt.width == 512 + + def test_imagine_prompt_has_default_negative(): - prompt = ImaginePrompt("fruit salad", model="foobar") + prompt = ImaginePrompt( + "fruit salad", + model_weights=config.ModelWeightsConfig( + name="foobar", + aliases=["foobar"], + weights_location="foobar", + architecture="sd15", + defaults={}, + ), + ) assert isinstance(prompt.prompt[0], WeightedPrompt) assert isinstance(prompt.negative_prompt[0], WeightedPrompt) @@ -21,10 +48,10 @@ def test_imagine_prompt_custom_negative_prompt(): def test_imagine_prompt_model_specific_negative_prompt(): - prompt = ImaginePrompt("fruit salad", model="openjourney-v1") + prompt = ImaginePrompt("fruit salad", model_weights="openjourney-v1") assert isinstance(prompt.prompt[0], 
WeightedPrompt) assert isinstance(prompt.negative_prompt[0], WeightedPrompt) - assert prompt.negative_prompt[0].text == "" + assert prompt.negative_prompt[0].text == "poor quality" def test_imagine_prompt_weighted_prompts(): @@ -84,7 +111,7 @@ def test_imagine_prompt_control_inputs(): prompt = ImaginePrompt( "fruit", control_inputs=[ - ControlNetInput(mode="depth", image=img), + ControlInput(mode="depth", image=img), ], ) prompt.control_inputs[0].image.convert("RGB") @@ -98,7 +125,7 @@ def test_imagine_prompt_control_inputs(): "fruit", init_image=img, control_inputs=[ - ControlNetInput(mode="depth"), + ControlInput(mode="depth"), ], ) assert prompt.control_inputs[0].image is not None @@ -107,7 +134,7 @@ def test_imagine_prompt_control_inputs(): prompt = ImaginePrompt( "fruit", control_inputs=[ - ControlNetInput(mode="depth"), + ControlInput(mode="depth"), ], ) assert prompt.control_inputs[0].image is None @@ -136,8 +163,8 @@ def test_imagine_prompt_mask_params(): def test_imagine_prompt_default_model(): - prompt = ImaginePrompt("fruit", model=None) - assert prompt.model == config.DEFAULT_MODEL + prompt = ImaginePrompt("fruit", model_weights=None) + assert config.DEFAULT_MODEL_WEIGHTS in prompt.model_weights.aliases def test_imagine_prompt_default_negative(): @@ -152,7 +179,7 @@ def test_imagine_prompt_fix_faces_fidelity(): def test_imagine_prompt_init_strength_zero(): lazy_img = LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png") prompt = ImaginePrompt( - "fruit", control_inputs=[ControlNetInput(mode="depth", image=lazy_img)] + "fruit", control_inputs=[ControlInput(mode="depth", image=lazy_img)] ) assert prompt.init_image_strength == 0.0 @@ -171,12 +198,12 @@ def test_distorted_prompts(): init_image=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png"), init_image_strength=0.5, control_inputs=[ - ControlNetInput( + ControlInput( mode="details", image=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png"), strength=2, ), - ControlNetInput( + ControlInput( mode="depth", image_raw=LazyLoadingImage(filepath=f"{TESTS_FOLDER}/data/red.png"), strength=3, @@ -187,13 +214,11 @@ def test_distorted_prompts(): mask_mode="replace", mask_modify_original=False, outpaint="all5,up0,down20", - model=config.DEFAULT_MODEL, - model_config_path=None, - sampler_type=config.DEFAULT_SAMPLER, + model_weights=config.DEFAULT_MODEL_WEIGHTS, + solver_type=config.DEFAULT_SOLVER, seed=42, steps=10, - height=256, - width=256, + size=256, upscale=True, fix_faces=True, fix_faces_fidelity=0.7, diff --git a/tests/test_schema/test_lazy_load_image.py b/tests/test_schema/test_lazy_load_image.py index 57bf9d67..b200401e 100644 --- a/tests/test_schema/test_lazy_load_image.py +++ b/tests/test_schema/test_lazy_load_image.py @@ -5,8 +5,7 @@ from PIL import Image from pydantic import BaseModel -from imaginairy import LazyLoadingImage -from imaginairy.schema import InvalidUrlError +from imaginairy.schema import InvalidUrlError, LazyLoadingImage from tests import TESTS_FOLDER diff --git a/tests/test_utils/test_model_cache.py b/tests/test_utils/test_model_cache.py index d75edd61..21a06213 100644 --- a/tests/test_utils/test_model_cache.py +++ b/tests/test_utils/test_model_cache.py @@ -1,7 +1,8 @@ import pytest from torch import nn -from imaginairy import ImaginePrompt, imagine +from imaginairy.api import imagine +from imaginairy.schema import ImaginePrompt from imaginairy.utils import get_device from imaginairy.utils.model_cache import GPUModelCache @@ -40,7 +41,9 @@ def create_model_of_n_bytes(n): def 
test_memory_usage(filename_base_for_orig_outputs, model_version): """Test that we can switch between model versions.""" prompt_text = "valley, fairytale treehouse village covered, , matte painting, highly detailed, dynamic lighting, cinematic, realism, realistic, photo real, sunset, detailed, high contrast, denoised, centered, michael whelan" - prompts = [ImaginePrompt(prompt_text, model=model_version, seed=1, steps=30)] + prompts = [ + ImaginePrompt(prompt_text, model_weights=model_version, seed=1, steps=30) + ] for i, result in enumerate(imagine(prompts)): img_path = f"{filename_base_for_orig_outputs}_{result.prompt.prompt_text}_{result.prompt.model}.png" diff --git a/tests/test_utils/test_named_resolutions.py b/tests/test_utils/test_named_resolutions.py index d5162c59..07d6e808 100644 --- a/tests/test_utils/test_named_resolutions.py +++ b/tests/test_utils/test_named_resolutions.py @@ -1,6 +1,6 @@ import pytest -from imaginairy.utils.named_resolutions import get_named_resolution +from imaginairy.utils.named_resolutions import normalize_image_size valid_cases = [ ("HD", (1280, 720)), @@ -12,11 +12,25 @@ ("1920x1080", (1920, 1080)), ("1280x720", (1280, 720)), ("1024x768", (1024, 768)), + ("1024,768", (1024, 768)), + ("1024*768", (1024, 768)), + ("1024, 768", (1024, 768)), ("800", (800, 800)), ("1024", (1024, 1024)), + ("1080p", (1920, 1080)), + ("1080P", (1920, 1080)), + (512, (512, 512)), + ((512, 512), (512, 512)), + ("1x1", (1, 1)), ] invalid_cases = [ + None, + 3.14, + (3.14, 3.14), + "", + " ", "abc", + "-512", "1920xABC", "1920x1080x1234", "x1920", @@ -30,10 +44,10 @@ @pytest.mark.parametrize(("named_resolution", "expected"), valid_cases) def test_named_resolutions(named_resolution, expected): - assert get_named_resolution(named_resolution) == expected + assert normalize_image_size(named_resolution) == expected @pytest.mark.parametrize("named_resolution", invalid_cases) def test_invalid_inputs(named_resolution): - with pytest.raises(ValueError, match="Unknown resolution"): - get_named_resolution(named_resolution) + with pytest.raises(ValueError, match="Invalid resolution"): + normalize_image_size(named_resolution) diff --git a/tests/utils.py b/tests/utils.py index 770e2712..17658665 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,3 +1,5 @@ +import time + import numpy as np from PIL import Image @@ -23,3 +25,21 @@ def calc_norm_sum_sq_diff(img, img2): ) norm_sum_sq_diff = sum_sq_diff / np.sqrt(sum_sq_diff) return norm_sum_sq_diff + + +class Timer: + def __init__(self, name): + self.name = name + self.start = None + self.elapsed = None + self.end = None + + def __enter__(self): + self.start = time.perf_counter() + return self + + def __exit__(self, *args): + self.end = time.perf_counter() + self.elapsed = self.end - self.start + + print(f"{self.name} took {self.elapsed*1000:.2f} ms") diff --git a/tox.ini b/tox.ini index f5079623..90befd69 100644 --- a/tox.ini +++ b/tox.ini @@ -4,3 +4,16 @@ norecursedirs = build dist downloads other prolly_delete imaginairy/vendored filterwarnings = ignore::DeprecationWarning ignore::UserWarning +markers = + gputest: uses the gpu + +[mypy] +plugins = pydantic.mypy +exclude = ^(\./|)(downloads|dist|build|other|testing_support|imaginairy/vendored|imaginairy/modules/sgm) +ignore_missing_imports = True +warn_unused_configs = True +warn_unused_ignores = True + +[mypy-imaginairy.vendored.*] +follow_imports = skip +ignore_errors = True \ No newline at end of file
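A minimal usage sketch of the new resolution parsing (not part of the change set itself), inferred from the normalize_image_size implementation and the valid_cases added in tests/test_utils/test_named_resolutions.py; it assumes the package is importable as shown in the diff:

    from imaginairy.utils.named_resolutions import normalize_image_size

    # named resolutions are matched case-insensitively
    assert normalize_image_size("1080p") == (1920, 1080)
    # WIDTHxHEIGHT strings also accept "," and "*" as separators
    assert normalize_image_size("1024*768") == (1024, 768)
    # a bare integer (or numeric string) becomes a square size
    assert normalize_image_size(512) == (512, 512)
    # width/height tuples pass through unchanged
    assert normalize_image_size((512, 512)) == (512, 512)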
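The gputest marker registered in tox.ini, together with the cuda-tests.csv tracking added in conftest.py, presumably exists so that tests previously observed to allocate CUDA memory can be selected or skipped with an ordinary pytest marker expression, for example:

    pytest -m "not gputest"   # skip tests recorded as GPU-using in tests/data/cuda-tests.csv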