Skip to content

Commit

Permalink
Feature/sdxl lightning (#36)
Browse files Browse the repository at this point in the history
* feature: add SDXL-Lightning image generator

* fix: minor example fix

* feature: add prompt weighting to sdxl-lightning

* docs: update README

* test: add sdxl-lightning tests

* [Automated] Updated coverage badge

---------

Co-authored-by: GitHub Actions <[email protected]>
  • Loading branch information
sokovninn and actions-user authored Feb 27, 2024
1 parent e8a9e11 commit cc26649
Show file tree
Hide file tree
Showing 9 changed files with 256 additions and 12 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ datadreamer --save_dir <directory> --class_names <objects> --prompts_number <num
- `--task`: Choose between `detection` and `classification`. Default is `detection`.
- `--num_objects_range`: Range of objects in a prompt. Default is 1 to 3.
- `--prompt_generator`: Choose between `simple`, `lm` (language model) and `tiny` (tiny LM). Default is `simple`.
- `--image_generator`: Choose image generator, e.g., `sdxl` or `sdxl-turbo`. Default is `sdxl-turbo`.
- `--image_generator`: Choose image generator, e.g., `sdxl`, `sdxl-turbo` or `sdxl-lightning`. Default is `sdxl-turbo`.
- `--image_annotator`: Specify the image annotator, like `owlv2`. Default is `owlv2`.
- `--conf_threshold`: Confidence threshold for object detection. Default is 0.15.
- `--use_tta`: Toggle test time augmentation for object detection. Default is True.
Expand All @@ -132,6 +132,7 @@ datadreamer --save_dir <directory> --class_names <objects> --prompts_number <num
| | Simple random generator | Joins randomly chosen object names |
| Image Generation | [SDXL-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) | Slow and accurate (1024x1024 images) |
| | [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo) | Fast and less accurate (512x512 images) |
| | [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning) | Fast and accurate (1024x1024 images) |
| Image Annotation | [OWLv2](https://huggingface.co/google/owlv2-base-patch16-ensemble) | Open-Vocabulary object detector |

<a name="example"></a>
Expand Down
7 changes: 6 additions & 1 deletion datadreamer/image_generation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from .sdxl_image_generator import StableDiffusionImageGenerator
from .sdxl_lightning_image_generator import StableDiffusionLightningImageGenerator
from .sdxl_turbo_image_generator import StableDiffusionTurboImageGenerator

__all__ = ["StableDiffusionImageGenerator", "StableDiffusionTurboImageGenerator"]
__all__ = [
"StableDiffusionImageGenerator",
"StableDiffusionTurboImageGenerator",
"StableDiffusionLightningImageGenerator",
]
2 changes: 1 addition & 1 deletion datadreamer/image_generation/sdxl_image_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def release(self, empty_cuda_cache=False) -> None:
"A photo of a bicycle pedaling alongside an aeroplane taking off, showcasing the harmony between human-powered and mechanical transportation.",
"A photo of bicycles along a scenic mountain path, where the riders seem to have taken a moment to appreciate the stunning views.",
]
prompt_objects = [["aeroplane", "boat", "bicycle"], ["bicycle"]]
prompt_objects = [["aeroplane", "bicycle"], ["bicycle"]]

image_paths = []
counter = 0
Expand Down
173 changes: 173 additions & 0 deletions datadreamer/image_generation/sdxl_lightning_image_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
from typing import List, Optional

import torch
from compel import Compel, ReturnedEmbeddingsType
from diffusers import (
EulerDiscreteScheduler,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file

from datadreamer.image_generation.image_generator import ImageGenerator


class StableDiffusionLightningImageGenerator(ImageGenerator):
    """A subclass of ImageGenerator that generates images with the SDXL-Lightning
    UNet checkpoint loaded into a Stable Diffusion XL pipeline.

    Attributes:
        pipe (StableDiffusionXLPipeline): The SDXL pipeline carrying the Lightning UNet.
        compel (Compel): Prompt encoder used to turn weighted prompts into embeddings.

    Methods:
        _init_gen_model(): Initializes the Stable Diffusion Lightning model.
        _init_compel(): Initializes the Compel model for text prompt weighting.
        generate_images_batch(prompts, negative_prompt, prompt_objects): Generates a batch of images based on the provided prompts.
        release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache.
    """

    # Single source of truth for the step count: it selects the checkpoint file
    # AND is passed to the pipeline, so the two can no longer drift apart.
    # NOTE(review): the ByteDance/SDXL-Lightning repo presumably only ships
    # UNet checkpoints for a few step counts -- confirm before changing this.
    _NUM_INFERENCE_STEPS = 4

    def __init__(self, *args, **kwargs):
        """Initializes the StableDiffusionLightningImageGenerator with the given
        arguments.

        All arguments are forwarded unchanged to ImageGenerator.__init__; the
        pipeline and the Compel encoder are built afterwards.
        """
        super().__init__(*args, **kwargs)
        self.pipe = self._init_gen_model()
        self.compel = self._init_compel()

    def _init_gen_model(self) -> StableDiffusionXLPipeline:
        """Initializes the Stable Diffusion Lightning model for image generation.

        Returns:
            StableDiffusionXLPipeline: The initialized Stable Diffusion Lightning model.
        """
        base = "stabilityai/stable-diffusion-xl-base-1.0"
        repo = "ByteDance/SDXL-Lightning"
        # The checkpoint must match the number of inference steps used later.
        ckpt = f"sdxl_lightning_{self._NUM_INFERENCE_STEPS}step_unet.safetensors"

        # Load model.
        if self.device == "cpu":
            print("Loading SDXL Lightning on CPU...")
            unet = UNet2DConditionModel.from_config(base, subfolder="unet")
            unet.load_state_dict(load_file(hf_hub_download(repo, ckpt)))
            pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet)
        else:
            print("Loading SDXL Lightning on GPU...")
            # Half precision is only used on the non-CPU path.
            unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(
                self.device, torch.float16
            )
            unet.load_state_dict(
                load_file(hf_hub_download(repo, ckpt), device=self.device)
            )
            pipe = StableDiffusionXLPipeline.from_pretrained(
                base, unet=unet, torch_dtype=torch.float16, variant="fp16"
            ).to(self.device)
            # NOTE(review): offloading is enabled after the pipeline was already
            # moved to the device -- confirm this ordering is intended.
            pipe.enable_model_cpu_offload()

        # Ensure sampler uses "trailing" timesteps.
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing"
        )

        return pipe

    def _init_compel(self) -> Compel:
        """Initializes the Compel model for text prompt weighting.

        Both tokenizer/text-encoder pairs of the pipeline are passed in; only
        the second encoder is asked for pooled embeddings.

        Returns:
            Compel: The initialized Compel model.
        """
        compel = Compel(
            tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
            text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True],
        )
        return compel

    @staticmethod
    def _weight_prompt_objects(
        prompts: List[str],
        prompt_objects: Optional[List[List[str]]],
        weight: float = 1.5,
    ) -> List[str]:
        """Returns a copy of ``prompts`` with each listed object wrapped as
        ``(object)<weight>``; the input list is never modified.

        Only the first occurrence of each object in its prompt is wrapped.

        Raises:
            IndexError: If ``prompt_objects`` has more entries than ``prompts``.
        """
        weighted = list(prompts)
        if prompt_objects is None:
            return weighted
        for idx, objects in enumerate(prompt_objects):
            for obj in objects:
                weighted[idx] = weighted[idx].replace(obj, f"({obj}){weight}", 1)
        return weighted

    def generate_images_batch(
        self,
        prompts: List[str],
        negative_prompt: str,
        prompt_objects: Optional[List[List[str]]] = None,
        batch_size: int = 1,
    ) -> List[Image.Image]:
        """Generates a batch of images using the Stable Diffusion Lightning model based
        on the provided prompts.

        Args:
            prompts (List[str]): A list of positive prompts to guide image generation. Not modified.
            negative_prompt (str): The negative prompt to avoid certain features in the image.
            prompt_objects (Optional[List[List[str]]]): Optional list of objects for each prompt; each object is up-weighted in its prompt before encoding.
            batch_size (int): Accepted for interface compatibility; not used by this method.

        Returns:
            List[Image.Image]: A list of generated images.

        Raises:
            IndexError: If ``prompt_objects`` has more entries than ``prompts``.
        """
        # Work on a copy so the weighting markup never leaks into the caller's
        # list (re-using that list would otherwise stack the markup).
        weighted_prompts = self._weight_prompt_objects(prompts, prompt_objects)

        conditioning, pooled = self.compel(weighted_prompts)
        conditioning_neg, pooled_neg = self.compel(
            [negative_prompt] * len(weighted_prompts)
        )
        # NOTE(review): with guidance_scale=0.0 the negative embeddings are
        # presumably ignored by the pipeline -- confirm against diffusers docs.
        images = self.pipe(
            prompt_embeds=conditioning,
            pooled_prompt_embeds=pooled,
            negative_prompt_embeds=conditioning_neg,
            negative_pooled_prompt_embeds=pooled_neg,
            guidance_scale=0.0,
            num_inference_steps=self._NUM_INFERENCE_STEPS,
        ).images

        return images

    def release(self, empty_cuda_cache=False) -> None:
        """Releases the model and optionally empties the CUDA cache.

        Args:
            empty_cuda_cache (bool): When True, calls torch.cuda.empty_cache() after moving the pipeline to the CPU.
        """
        self.pipe = self.pipe.to("cpu")
        if self.use_clip_image_tester:
            self.clip_image_tester.release()
        if empty_cuda_cache:
            with torch.no_grad():
                torch.cuda.empty_cache()


if __name__ == "__main__":
    import os

    # Demo run: build the generator on CPU, render a handful of prompts and
    # write every produced image into the current directory.
    image_generator = StableDiffusionLightningImageGenerator(
        seed=42,
        use_clip_image_tester=False,
        image_tester_patience=1,
        batch_size=4,
        device="cpu",
    )

    # Each prompt is kept next to the objects that appear in it.
    examples = [
        ("A photo of a bicycle pedaling alongside an aeroplane.", ["aeroplane", "bicycle"]),
        ("A photo of a dragonfly flying in the sky.", ["dragonfly"]),
        ("A photo of a dog walking in the park.", ["dog"]),
        ("A photo of an alien exploring the galaxy.", ["alien"]),
        ("A photo of a robot working on a computer.", ["robot", "computer"]),
    ]
    prompts = [text for text, _ in examples]
    prompt_objects = [objects for _, objects in examples]

    image_paths = []
    for batch in image_generator.generate_images(prompts, prompt_objects):
        for picture in batch:
            # The number of images saved so far doubles as the file index.
            target = os.path.join("./", f"image_lightning_{len(image_paths)}.jpg")
            picture.save(target)
            image_paths.append(target)

    image_generator.release(empty_cuda_cache=True)
2 changes: 1 addition & 1 deletion datadreamer/image_generation/sdxl_turbo_image_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def release(self, empty_cuda_cache=False) -> None:
prompts = [
"A photo of a bicycle pedaling alongside an aeroplane taking off, showcasing the harmony between human-powered and mechanical transportation.",
] * 16
prompt_objects = [["aeroplane", "boat", "bicycle"]] * 16
prompt_objects = [["aeroplane", "bicycle"]] * 16

image_paths = []
counter = 0
Expand Down
4 changes: 3 additions & 1 deletion datadreamer/pipelines/generate_dataset_from_scratch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from datadreamer.dataset_annotation import OWLv2Annotator
from datadreamer.image_generation import (
StableDiffusionImageGenerator,
StableDiffusionLightningImageGenerator,
StableDiffusionTurboImageGenerator,
)
from datadreamer.prompt_generation import (
Expand All @@ -30,6 +31,7 @@
image_generators = {
"sdxl": StableDiffusionImageGenerator,
"sdxl-turbo": StableDiffusionTurboImageGenerator,
"sdxl-lightning": StableDiffusionLightningImageGenerator,
}

annotators = {"owlv2": OWLv2Annotator}
Expand Down Expand Up @@ -84,7 +86,7 @@ def parse_args():
"--image_generator",
type=str,
default="sdxl-turbo",
choices=["sdxl", "sdxl-turbo"],
choices=["sdxl", "sdxl-turbo", "sdxl-lightning"],
help="Image generator to use",
)
parser.add_argument(
Expand Down
4 changes: 2 additions & 2 deletions media/coverage_badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
44 changes: 44 additions & 0 deletions tests/integration/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,50 @@ def test_cuda_simple_sdxl_detection_pipeline():
_check_detection_pipeline(cmd, target_folder)


@pytest.mark.skipif(
    total_memory < 16 or total_disk_space < 35,
    reason="Test requires at least 16GB of RAM and 35GB of HDD",
)
def test_cpu_simple_sdxl_lightning_detection_pipeline():
    """Runs the detection pipeline with the sdxl-lightning generator on CPU."""
    # Output directory for this run
    target_folder = "data/data-det-cpu-simple-sdxl-lightning/"
    # Command-line pieces, joined with single spaces into the final command
    cli_parts = [
        f"datadreamer --save_dir {target_folder}",
        "--class_names alien mars cat",
        "--prompts_number 1",
        "--prompt_generator simple",
        "--num_objects_range 1 2",
        "--image_generator sdxl-lightning",
        "--use_image_tester",
        "--device cpu",
    ]
    # Hand the command and its output folder to the shared pipeline check
    _check_detection_pipeline(" ".join(cli_parts), target_folder)


@pytest.mark.skipif(
    not torch.cuda.is_available() or total_memory < 16 or total_disk_space < 35,
    reason="Test requires GPU, at least 16GB of RAM and 35GB of HDD",
)
def test_cuda_simple_sdxl_lightning_detection_pipeline():
    """Runs the detection pipeline with the sdxl-lightning generator on CUDA."""
    # Output directory for this run
    target_folder = "data/data-det-cuda-simple-sdxl-lightning/"
    # Command-line pieces, joined with single spaces into the final command
    cli_parts = [
        f"datadreamer --save_dir {target_folder}",
        "--class_names alien mars cat",
        "--prompts_number 1",
        "--prompt_generator simple",
        "--num_objects_range 1 2",
        "--image_generator sdxl-lightning",
        "--use_image_tester",
        "--device cuda",
    ]
    # Hand the command and its output folder to the shared pipeline check
    _check_detection_pipeline(" ".join(cli_parts), target_folder)


# =========================================================
# DETECTION - LLM
# =========================================================
Expand Down
29 changes: 24 additions & 5 deletions tests/unittests/test_image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
import torch
from PIL import Image

from datadreamer.image_generation.clip_image_tester import ClipImageTester
from datadreamer.image_generation.sdxl_image_generator import (
from datadreamer.image_generation import (
StableDiffusionImageGenerator,
)
from datadreamer.image_generation.sdxl_turbo_image_generator import (
StableDiffusionLightningImageGenerator,
StableDiffusionTurboImageGenerator,
)
from datadreamer.image_generation.clip_image_tester import ClipImageTester

# Get the total memory in GB
total_memory = psutil.virtual_memory().total / (1024**3)
Expand Down Expand Up @@ -55,7 +54,11 @@ def test_cpu_clip_image_tester():

def _check_image_generator(
image_generator_class: Type[
Union[StableDiffusionImageGenerator, StableDiffusionTurboImageGenerator]
Union[
StableDiffusionImageGenerator,
StableDiffusionTurboImageGenerator,
StableDiffusionLightningImageGenerator,
]
],
device: str,
):
Expand Down Expand Up @@ -101,3 +104,19 @@ def test_cuda_sdxl_turbo_image_generator():
)
def test_cpu_sdxl_turbo_image_generator():
_check_image_generator(StableDiffusionTurboImageGenerator, "cpu")


@pytest.mark.skipif(
    not torch.cuda.is_available() or total_memory < 16 or total_disk_space < 25,
    reason="Test requires GPU, at least 16GB of RAM and 25GB of HDD",
)
def test_cuda_sdxl_lightning_image_generator():
    """Runs the shared image-generator check for SDXL-Lightning on CUDA."""
    _check_image_generator(
        image_generator_class=StableDiffusionLightningImageGenerator,
        device="cuda",
    )


@pytest.mark.skipif(
    total_memory < 16 or total_disk_space < 25,
    reason="Test requires at least 16GB of RAM and 25GB of HDD",
)
def test_cpu_sdxl_lightning_image_generator():
    """Runs the shared image-generator check for SDXL-Lightning on CPU."""
    _check_image_generator(
        image_generator_class=StableDiffusionLightningImageGenerator,
        device="cpu",
    )

0 comments on commit cc26649

Please sign in to comment.