feature: support loading sdxl compvis weights (#449)
brycedrennan authored Jan 13, 2024
1 parent 907e80d commit 700cb45
Showing 13 changed files with 3,053 additions and 134 deletions.
22 changes: 21 additions & 1 deletion docs/changelog.md
@@ -1,7 +1,27 @@

## ChangeLog

**14.0.0**
**14.1.0**
- 🎉 feature: make video generation smooth by adding frame interpolation
- feature: SDXL weights in the compvis format can now be used
- feature: allow video generation at any size specified by user
- feature: video generations output in "bounce" format
- feature: choose video output format: mp4, webp, or gif
- feature: fix random seed handling in video generation
- docs: auto-publish docs on push to master
- build: remove imageio dependency
- build: vendorize facexlib so we don't install its unneeded dependencies


**14.0.4**
- docs: add a documentation website at https://brycedrennan.github.io/imaginAIry/
- build: remove fairscale dependency
- fix: video generation was broken

**14.0.3**
- fix: several critical bugs with package
- tests: add a wheel smoketest to detect these issues in the future

**14.0.0**
- 🎉 video generation using [Stable Video Diffusion](https://github.com/Stability-AI/generative-models)
- add `--videogen` to any image generation to create a short video from the generated image
2 changes: 1 addition & 1 deletion imaginairy/enhancers/video_interpolation/rife/RIFE_HDv3.py
@@ -6,7 +6,7 @@
class Model:
def __init__(self):
self.flownet = IFNet()
self.version: float
self.version = None

def eval(self):
self.flownet.eval()
8 changes: 7 additions & 1 deletion imaginairy/schema.py
@@ -760,6 +760,12 @@ def width(self) -> int:
def height(self) -> int:
return self.size[1]

@property
def aspect_ratio(self) -> str:
from imaginairy.utils.img_utils import aspect_ratio

return aspect_ratio(width=self.width, height=self.height)

@property
def should_use_inpainting(self) -> bool:
return bool(self.outpaint or self.mask_image or self.mask_prompt)
@@ -787,7 +793,7 @@ def prompt_description(self):
" "
f"negative-prompt:{neg_prompt}\n"
" "
f"size:{self.width}x{self.height}px "
f"size:{self.width}x{self.height}px-({self.aspect_ratio}) "
f"seed:{self.seed} "
f"prompt-strength:{self.prompt_strength} "
f"steps:{self.steps} solver-type:{self.solver_type} "
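For illustration, a sketch of the size fragment before and after this change, for a hypothetical 1024x768 generation (where the new `aspect_ratio` property returns `"4:3"`):

```python
width, height, ratio = 1024, 768, "4:3"  # hypothetical prompt dimensions
print(f"size:{width}x{height}px ")            # old fragment
print(f"size:{width}x{height}px-({ratio}) ")  # new fragment, e.g. size:1024x768px-(4:3)
```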
23 changes: 23 additions & 0 deletions imaginairy/utils/img_utils.py
@@ -258,3 +258,26 @@ def calc_scale_to_fit_within(height: int, width: int, max_size) -> float:
height_ratio = max_height / height

return min(width_ratio, height_ratio)


def aspect_ratio(width, height):
"""
Calculate the aspect ratio of a given width and height.
Args:
width (int): The width dimension.
height (int): The height dimension.
Returns:
str: The aspect ratio in the format 'X:Y'.
"""
from math import gcd

# Calculate the greatest common divisor
divisor = gcd(width, height)

# Calculate the aspect ratio
x = width // divisor
y = height // divisor

return f"{x}:{y}"
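
A minimal usage sketch of the new helper (example dimensions only); this is what backs the `aspect_ratio` property added to `imaginairy/schema.py` above:

```python
from imaginairy.utils.img_utils import aspect_ratio

print(aspect_ratio(width=1920, height=1080))  # "16:9" (gcd is 120)
print(aspect_ratio(width=512, height=768))    # "2:3"  (gcd is 256)
```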
70 changes: 70 additions & 0 deletions imaginairy/utils/log_utils.py
@@ -5,6 +5,9 @@
import re
import time
import warnings
from contextlib import contextmanager
from functools import lru_cache
from logging import Logger
from typing import Callable

import torch.cuda
@@ -57,6 +60,73 @@ def increment_step():
_CURRENT_LOGGING_CONTEXT.step_count += 1


@contextmanager
def timed_log_method(logger, level, msg, *args, hide_below_ms=0, **kwargs):
start_time = time.perf_counter()
try:
yield
finally:
end_time = time.perf_counter()
elapsed_ms = (end_time - start_time) * 1000
if elapsed_ms < hide_below_ms:
return
full_msg = f"{msg} (in {elapsed_ms:.1f}ms)"
logger.log(level, full_msg, *args, **kwargs, stacklevel=3)


@lru_cache
def add_timed_methods_to_logger():
"""Monkey patches the default python logger to have timed logs"""

def create_timed_method(level):
def timed_method(self, msg, *args, hide_below_ms=0, **kwargs):
return timed_log_method(
self, level, msg, *args, hide_below_ms=hide_below_ms, **kwargs
)

return timed_method

logging.Logger.timed_debug = create_timed_method(logging.DEBUG)
logging.Logger.timed_info = create_timed_method(logging.INFO)
logging.Logger.timed_warning = create_timed_method(logging.WARNING)
logging.Logger.timed_error = create_timed_method(logging.ERROR)
logging.Logger.timed_critical = create_timed_method(logging.CRITICAL)


add_timed_methods_to_logger()


class TimedLogger(Logger):
def timed_debug(self, msg, *args, hide_below_ms=0, **kwargs):
return timed_log_method(
self, logging.DEBUG, msg, *args, hide_below_ms=hide_below_ms, **kwargs
)

def timed_info(self, msg, *args, hide_below_ms=0, **kwargs):
return timed_log_method(
self, logging.INFO, msg, *args, hide_below_ms=hide_below_ms, **kwargs
)

def timed_warning(self, msg, *args, hide_below_ms=0, **kwargs):
return timed_log_method(
self, logging.WARNING, msg, *args, hide_below_ms=hide_below_ms, **kwargs
)

def timed_error(self, msg, *args, hide_below_ms=0, **kwargs):
return timed_log_method(
self, logging.ERROR, msg, *args, hide_below_ms=hide_below_ms, **kwargs
)

def timed_critical(self, msg, *args, hide_below_ms=0, **kwargs):
return timed_log_method(
self, logging.CRITICAL, msg, *args, hide_below_ms=hide_below_ms, **kwargs
)


def getLogger(name) -> TimedLogger:
return logging.getLogger(name) # type: ignore


class TimingContext:
"""Tracks time and memory usage of a block of code"""

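A minimal sketch of how the new timed logging helpers could be used; the logger name, message, and sleep are illustrative only:

```python
import logging
import time

from imaginairy.utils.log_utils import getLogger

logging.basicConfig(level=logging.INFO)
logger = getLogger("demo")

# Emits "loaded weights (in 200.Xms)" when the block exits; blocks that finish
# faster than hide_below_ms milliseconds are silently skipped.
with logger.timed_info("loaded weights", hide_below_ms=50):
    time.sleep(0.2)
```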
166 changes: 126 additions & 40 deletions imaginairy/utils/model_manager.py
@@ -38,6 +38,12 @@
LatentDiffusionModel,
)
from imaginairy.weight_management import translators
from imaginairy.weight_management.translators import (
DoubleTextEncoderTranslator,
diffusers_autoencoder_kl_to_refiners_translator,
diffusers_unet_sdxl_to_refiners_translator,
load_weight_map,
)

logger = logging.getLogger(__name__)

@@ -108,36 +114,6 @@ def load_model_from_config(config, weights_location, half_mode=False):
return model


def load_model_from_config_old(
config, weights_location, control_weights_locations=None, half_mode=False
):
model = instantiate_from_config(config.model)
base_model_dict = load_state_dict(weights_location, half_mode=half_mode)
model.init_from_state_dict(base_model_dict)

control_weights_locations = control_weights_locations or []
controlnets = []
for control_weights_location in control_weights_locations:
controlnet_state_dict = load_state_dict(
control_weights_location, half_mode=half_mode
)
controlnet_state_dict = {
k.replace("control_model.", ""): v for k, v in controlnet_state_dict.items()
}
controlnet = instantiate_from_config(model.control_stage_config)
controlnet.load_state_dict(controlnet_state_dict)
controlnet.to(get_device())
controlnets.append(controlnet)
model.set_control_models(controlnets)

if half_mode:
model = model.half()

model.to(get_device())
model.eval()
return model


def add_controlnet(base_state_dict, controlnet_state_dict):
"""Merges a base sd15 model with a controlnet model."""
for key in controlnet_state_dict:
@@ -286,7 +262,7 @@ def _get_diffusion_model_refiners(

architecture = iconfig.MODEL_ARCHITECTURE_LOOKUP[architecture_alias]
if architecture.primary_alias in ("sd15", "sd15inpaint"):
sd = _get_sd15_diffusion_model_refiners(
sd = load_sd15_pipeline(
weights_location=weights_location,
for_inpainting=for_inpainting,
device=device,
@@ -301,18 +277,19 @@
MOST_RECENTLY_LOADED_MODEL = sd

msg = (
f"sd dtype:{sd.dtype} device:{sd.device}\n"
f"sd.unet dtype:{sd.unet.dtype} device:{sd.unet.device}\n"
f"sd.lda dtype:{sd.lda.dtype} device:{sd.lda.device}\n"
f"sd.clip_text_encoder dtype:{sd.clip_text_encoder.dtype} device:{sd.clip_text_encoder.device}\n"
"Pipeline loaded "
f"sd[dtype:{sd.dtype} device:{sd.device}] "
f"sd.unet[dtype:{sd.unet.dtype} device:{sd.unet.device}] "
f"sd.lda[dtype:{sd.lda.dtype} device:{sd.lda.device}]"
f"sd.clip_text_encoder[dtype:{sd.clip_text_encoder.dtype} device:{sd.clip_text_encoder.device}]"
)
logger.debug(msg)

return sd


# new
def _get_sd15_diffusion_model_refiners(
def load_sd15_pipeline(
weights_location: str,
for_inpainting: bool = False,
device=None,
@@ -756,7 +733,9 @@ def load_sdxl_diffusers_weights(base_url: str, device=None):
return vae_weights, unet_weights, text_encoder_weights


def load_sdxl_diffusers_weights(base_url: str, device=None, dtype=torch.float16):
def load_sdxl_pipeline_from_diffusers_weights(
base_url: str, device=None, dtype=torch.float16
):
from imaginairy.utils import get_device

device = device or get_device()
@@ -817,13 +796,47 @@ def load_sdxl_diffusers_weights(base_url: str, device=None, dtype=torch.float16)
return sd


def load_sdxl_pipeline(base_url, device=None):
logger.info(f"Loading SDXL weights from {base_url}")
def load_sdxl_pipeline_from_compvis_weights(
base_url: str, device=None, dtype=torch.float16
):
from imaginairy.utils import get_device

device = device or get_device()
sd = load_sdxl_diffusers_weights(base_url, device=device)
unet_weights, vae_weights, text_encoder_weights = load_sdxl_compvis_weights(
base_url
)
lda = SDXLAutoencoderSliced(device="cpu", dtype=dtype)
lda.load_state_dict(vae_weights, assign=True)
del vae_weights

unet = SDXLUNet(device="cpu", dtype=dtype, in_channels=4)
unet.load_state_dict(unet_weights, assign=True)
del unet_weights

text_encoder = DoubleTextEncoder(device="cpu", dtype=torch.float32)
text_encoder.load_state_dict(text_encoder_weights, assign=True)
del text_encoder_weights
lda = lda.to(device=device, dtype=torch.float32)
unet = unet.to(device=device)
text_encoder = text_encoder.to(device=device)
sd = StableDiffusion_XL(
device=device, dtype=None, lda=lda, unet=unet, clip_text_encoder=text_encoder
)

return sd


def load_sdxl_pipeline(base_url, device=None):
device = device or get_device()

with logger.timed_info(f"Loaded SDXL pipeline from {base_url}"):
if is_diffusers_repo_url(base_url):
sd = load_sdxl_pipeline_from_diffusers_weights(base_url, device=device)
else:
sd = load_sdxl_pipeline_from_compvis_weights(base_url, device=device)
return sd
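
A rough usage sketch of the dispatch above; the checkpoint path is hypothetical. A local `.safetensors` file is not a diffusers repo URL, so it takes the new compvis branch, while a diffusers repo URL keeps the existing diffusers loader:

```python
from imaginairy.utils.model_manager import load_sdxl_pipeline

# Hypothetical local compvis-format SDXL checkpoint; a diffusers repo URL
# would route through load_sdxl_pipeline_from_diffusers_weights instead.
sd = load_sdxl_pipeline("/path/to/sd_xl_base_1.0.safetensors", device="cuda")
```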


def open_weights(filepath, device=None):
from imaginairy.utils import get_device

@@ -940,3 +953,76 @@ def load_stable_diffusion_compvis_weights(weights_url):
)

return vae_state_dict, unet_state_dict, text_encoder_state_dict


def load_sdxl_compvis_weights(url):
from safetensors import safe_open

weights_path = get_cached_url_path(url)
state_dict = {}
unet_state_dict = {}
vae_state_dict = {}
text_encoder_1_state_dict = {}
text_encoder_2_state_dict = {}
with safe_open(weights_path, framework="pt") as f:
for key in f.keys(): # noqa
if key.startswith("model.diffusion_model."):
unet_state_dict[key] = f.get_tensor(key)
elif key.startswith("first_stage_model"):
vae_state_dict[key] = f.get_tensor(key)
elif key.startswith("conditioner.embedders.0."):
text_encoder_1_state_dict[key] = f.get_tensor(key)
elif key.startswith("conditioner.embedders.1."):
text_encoder_2_state_dict[key] = f.get_tensor(key)
else:
state_dict[key] = f.get_tensor(key)
logger.warning(f"Unused key {key}")

unet_weightmap = load_weight_map("Compvis-UNet-SDXL-to-Diffusers")
vae_weightmap = load_weight_map("Compvis-Autoencoder-SDXL-to-Diffusers")
text_encoder_1_weightmap = load_weight_map("Compvis-TextEncoder-SDXL-to-Diffusers")
text_encoder_2_weightmap = load_weight_map(
"Compvis-OpenClipTextEncoder-SDXL-to-Diffusers"
)

diffusers_unet_state_dict = unet_weightmap.translate_weights(unet_state_dict)

refiners_unet_state_dict = (
diffusers_unet_sdxl_to_refiners_translator().translate_weights(
diffusers_unet_state_dict
)
)

diffusers_vae_state_dict = vae_weightmap.translate_weights(vae_state_dict)

refiners_vae_state_dict = (
diffusers_autoencoder_kl_to_refiners_translator().translate_weights(
diffusers_vae_state_dict
)
)

diffusers_text_encoder_1_state_dict = text_encoder_1_weightmap.translate_weights(
text_encoder_1_state_dict
)

for key in list(text_encoder_2_state_dict.keys()):
if key.endswith((".in_proj_bias", ".in_proj_weight")):
value = text_encoder_2_state_dict[key]
q, k, v = value.chunk(3, dim=0)
text_encoder_2_state_dict[f"{key}.0"] = q
text_encoder_2_state_dict[f"{key}.1"] = k
text_encoder_2_state_dict[f"{key}.2"] = v
del text_encoder_2_state_dict[key]

diffusers_text_encoder_2_state_dict = text_encoder_2_weightmap.translate_weights(
text_encoder_2_state_dict
)

refiners_text_encoder_weights = DoubleTextEncoderTranslator().translate_weights(
diffusers_text_encoder_1_state_dict, diffusers_text_encoder_2_state_dict
)
return (
refiners_unet_state_dict,
refiners_vae_state_dict,
refiners_text_encoder_weights,
)
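
The in_proj handling above splits each fused attention projection into separate q/k/v tensors before the weight-map translation; a tiny standalone sketch of that step with dummy shapes:

```python
import torch

embed_dim = 8  # small stand-in for the real text-encoder width
state = {"attn.in_proj_weight": torch.randn(3 * embed_dim, embed_dim)}

for key in list(state.keys()):
    if key.endswith((".in_proj_bias", ".in_proj_weight")):
        q, k, v = state[key].chunk(3, dim=0)  # fused QKV rows split into three equal blocks
        state[f"{key}.0"], state[f"{key}.1"], state[f"{key}.2"] = q, k, v
        del state[key]

assert state["attn.in_proj_weight.0"].shape == (embed_dim, embed_dim)
```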