build: remove pytorch lightning dependency #436

Merged · 2 commits · Jan 3, 2024
2 changes: 1 addition & 1 deletion imaginairy/api/generate_compvis.py
@@ -2,6 +2,7 @@
from typing import TYPE_CHECKING, Any

from imaginairy.api.generate import IMAGINAIRY_SAFETY_MODE
+from imaginairy.utils import seed_everything
from imaginairy.utils.img_utils import calc_scale_to_fit_within, combine_image
from imaginairy.utils.named_resolutions import normalize_image_size

@@ -25,7 +26,6 @@ def _generate_single_image_compvis(
):
import torch.nn
from PIL import Image, ImageOps
-from pytorch_lightning import seed_everything

from imaginairy.enhancers.clip_masking import get_img_mask
from imaginairy.enhancers.describe_image_blip import generate_caption
3 changes: 1 addition & 2 deletions imaginairy/api/generate_refiners.py
@@ -6,7 +6,7 @@

from imaginairy.config import CONTROL_CONFIG_SHORTCUTS
from imaginairy.schema import ControlInput, ImaginePrompt, MaskMode
-from imaginairy.utils import clear_gpu_cache
+from imaginairy.utils import clear_gpu_cache, seed_everything
from imaginairy.utils.log_utils import ImageLoggingContext

logger = logging.getLogger(__name__)
@@ -27,7 +27,6 @@ def generate_single_image(
):
import torch.nn
from PIL import Image, ImageOps
-from pytorch_lightning import seed_everything
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DPMSolver
from tqdm import tqdm

2 changes: 2 additions & 0 deletions imaginairy/cli/shared.py
@@ -23,6 +23,8 @@ def imaginairy_click_context(log_level="INFO"):
yield
except errors_to_catch as e:
logger.error(e)
+# import traceback
+# traceback.print_exc()


def _imagine_cmd(
6 changes: 3 additions & 3 deletions imaginairy/img_processors/control_modes.py
@@ -15,7 +15,7 @@ def create_canny_edges(img: "Tensor") -> "Tensor":
img = torch.clamp((img + 1.0) / 2.0, min=0.0, max=1.0)
img = einops.rearrange(img[0], "c h w -> h w c")
img = (255.0 * img).cpu().numpy().astype(np.uint8).squeeze()
-blurred = cv2.GaussianBlur(img, (5, 5), 0).astype(np.uint8)
+blurred = cv2.GaussianBlur(img, (5, 5), 0).astype(np.uint8) # type: ignore

if len(blurred.shape) > 2:
blurred = cv2.cvtColor(blurred, cv2.COLOR_BGR2GRAY)
@@ -143,7 +143,7 @@ def make_noise_disk(H: int, W: int, C: int, F: int) -> "np.ndarray":
import numpy as np

noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
-noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
+noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) # type: ignore
noise = noise[F : F + H, F : F + W]
noise -= np.min(noise)
noise /= np.max(noise)
@@ -165,7 +165,7 @@ def shuffle_map_np(img: "np.ndarray", h=None, w=None, f=256) -> "np.ndarray":
x = make_noise_disk(h, w, 1, f) * float(W - 1)
y = make_noise_disk(h, w, 1, f) * float(H - 1)
flow = np.concatenate([x, y], axis=2).astype(np.float32)
-return cv2.remap(img, flow, None, cv2.INTER_LINEAR)
+return cv2.remap(img, flow, None, cv2.INTER_LINEAR) # type: ignore


def shuffle_map_torch(tensor: "Tensor", h=None, w=None, f=256) -> "Tensor":
4 changes: 2 additions & 2 deletions imaginairy/modules/autoencoder.py
@@ -5,8 +5,8 @@
import math
from contextlib import contextmanager

-import pytorch_lightning as pl
import torch
+from torch import nn
from torch.cuda import OutOfMemoryError

from imaginairy.modules.diffusion.model import Decoder, Encoder
@@ -18,7 +18,7 @@
logger = logging.getLogger(__name__)


-class AutoencoderKL(pl.LightningModule):
+class AutoencoderKL(nn.Module):
def __init__(
self,
ddconfig,
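Note: `pl.LightningModule` is itself a subclass of `torch.nn.Module`, so for models used only for inference the base-class swap drops the Lightning dependency without affecting `state_dict` loading or `forward`. A minimal sketch of the pattern, using an illustrative toy class rather than the real imaginairy modules:

```python
import torch
from torch import nn


class TinyAutoencoder(nn.Module):  # previously this kind of class subclassed pl.LightningModule
    def __init__(self, dim: int = 8):
        super().__init__()
        self.encoder = nn.Linear(dim, dim // 2)
        self.decoder = nn.Linear(dim // 2, dim)

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.encoder(x))


model = TinyAutoencoder()
model.load_state_dict(model.state_dict())  # checkpoint round-trips are unaffected by the swap
_ = model(torch.randn(2, 8))
```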
5 changes: 2 additions & 3 deletions imaginairy/modules/diffusion/ddpm.py
@@ -14,7 +14,6 @@
from typing import Optional

import numpy as np
-import pytorch_lightning as pl
import torch
from einops import rearrange, repeat
from omegaconf import ListConfig
@@ -93,7 +92,7 @@ def uniform_on_device(r1, r2, shape, device):
return (r1 - r2) * torch.rand(*shape, device=device) + r2


-class DDPM(pl.LightningModule):
+class DDPM(nn.Module):
# classic DDPM with Gaussian diffusion, in image space
def __init__(
self,
@@ -1711,7 +1710,7 @@ def to_rgb(self, x):
return x


-class DiffusionWrapper(pl.LightningModule):
+class DiffusionWrapper(nn.Module):
def __init__(self, diff_model_config, conditioning_key):
super().__init__()
self.diffusion_model = instantiate_from_config(diff_model_config)
3 changes: 1 addition & 2 deletions imaginairy/modules/sgm/autoencoder.py
@@ -7,7 +7,6 @@
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

-import pytorch_lightning as pl
import torch
import torch.nn as nn
from einops import rearrange
@@ -32,7 +31,7 @@
logpy = logging.getLogger(__name__)


-class AbstractAutoencoder(pl.LightningModule):
+class AbstractAutoencoder(nn.Module):
"""
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
unCLIP models, etc. Hence, it is fairly general, and specific features
4 changes: 2 additions & 2 deletions imaginairy/modules/sgm/diffusion.py
@@ -5,10 +5,10 @@
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple, Union

-import pytorch_lightning as pl
import torch
from omegaconf import ListConfig, OmegaConf
from safetensors.torch import load_file as load_safetensors
+from torch import nn
from torch.optim.lr_scheduler import LambdaLR

from imaginairy.modules.ema import LitEma
@@ -30,7 +30,7 @@
OPENAIUNETWRAPPER = "imaginairy.modules.sgm.diffusionmodules.wrappers.OpenAIWrapper"


-class DiffusionEngine(pl.LightningModule):
+class DiffusionEngine(nn.Module):
def __init__(
self,
network_config,
12 changes: 12 additions & 0 deletions imaginairy/utils/__init__.py
@@ -1,12 +1,14 @@
import importlib
import logging
import platform
+import random
import re
import time
from contextlib import contextmanager, nullcontext
from functools import lru_cache
from typing import Any, List, Optional

+import numpy as np
import torch
from torch import Tensor, autocast
from torch.nn import functional
@@ -334,3 +336,13 @@ def clear_gpu_cache():
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()


+def seed_everything(seed: int | None = None) -> None:
+    if seed is None:
+        seed = random.randint(0, 2**32 - 1)
+        logger.info(f"Using random seed: {seed}")
+    random.seed(a=seed)
+    np.random.seed(seed=seed)
+    torch.manual_seed(seed=seed)
+    torch.cuda.manual_seed_all(seed=seed)
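Note: a short usage sketch of the new utility (assuming this branch is installed); it mirrors how the removed `pytorch_lightning.seed_everything` was used at the call sites above:

```python
from imaginairy.utils import seed_everything

seed_everything(42)  # fixed seed: python, numpy, and torch RNGs become deterministic
seed_everything()    # no argument: a random 32-bit seed is chosen and logged
```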
7 changes: 5 additions & 2 deletions imaginairy/utils/log_utils.py
@@ -389,7 +389,10 @@ def disable_transformers_custom_logging():


def disable_pytorch_lighting_custom_logging():
-from pytorch_lightning import _logger as pytorch_logger
+try:
+    from pytorch_lightning import _logger as pytorch_logger
+except ImportError:
+    return

try:
from pytorch_lightning.utilities.seed import log
@@ -419,7 +422,7 @@ def disable_common_warnings():

def suppress_annoying_logs_and_warnings():
disable_transformers_custom_logging()
-disable_pytorch_lighting_custom_logging()
+# disable_pytorch_lighting_custom_logging()
disable_common_warnings()


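Note: with the import guarded and the call commented out, pytorch_lightning logging is no longer silenced even when the package happens to be installed. If that behavior is still wanted, one alternative sketch (an assumption, not what this PR does) checks for the package before importing it; it presumes `pytorch_lightning._logger` is a standard `logging.Logger`, as the original import suggests:

```python
import importlib.util


def pytorch_lightning_available() -> bool:
    # find_spec looks the package up on sys.path without importing it
    return importlib.util.find_spec("pytorch_lightning") is not None


def maybe_disable_pytorch_lightning_logging() -> None:
    if not pytorch_lightning_available():
        return
    from pytorch_lightning import _logger as pytorch_logger

    pytorch_logger.setLevel("ERROR")
```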
5 changes: 4 additions & 1 deletion imaginairy/weight_management/conversion.py
@@ -74,7 +74,10 @@ def could_convert(self, source_weights):
def cast_weights(self, source_weights) -> dict[str, "Tensor"]:
converted_state_dict: dict[str, "Tensor"] = {}
for source_key in source_weights:
-source_prefix, suffix = source_key.rsplit(sep=".", maxsplit=1)
+try:
+    source_prefix, suffix = source_key.rsplit(sep=".", maxsplit=1)
+except ValueError:
+    continue
# handle aliases
source_prefix = self.source_aliases.get(source_prefix, source_prefix)
try:
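Note: the new `except ValueError` covers state-dict keys that contain no `.` separator, where `rsplit` returns a single element and the two-variable unpack fails. A small repro sketch with hypothetical key names:

```python
key = "unet.down_blocks.0.weight"
prefix, suffix = key.rsplit(sep=".", maxsplit=1)  # prefix = "unet.down_blocks.0", suffix = "weight"

bare_key = "alpha"
parts = bare_key.rsplit(sep=".", maxsplit=1)  # ["alpha"] -- a single element
# prefix, suffix = bare_key.rsplit(sep=".", maxsplit=1) would raise ValueError,
# so such keys are now skipped instead of crashing cast_weights()
```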