diff --git a/README.md b/README.md index 2e47f38..1381492 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ datadreamer --save_dir --class_names --prompts_number --class_names --prompts_number diff --git a/datadreamer/image_generation/__init__.py b/datadreamer/image_generation/__init__.py index 6462c2d..9c81fec 100644 --- a/datadreamer/image_generation/__init__.py +++ b/datadreamer/image_generation/__init__.py @@ -1,4 +1,9 @@ from .sdxl_image_generator import StableDiffusionImageGenerator +from .sdxl_lightning_image_generator import StableDiffusionLightningImageGenerator from .sdxl_turbo_image_generator import StableDiffusionTurboImageGenerator -__all__ = ["StableDiffusionImageGenerator", "StableDiffusionTurboImageGenerator"] +__all__ = [ + "StableDiffusionImageGenerator", + "StableDiffusionTurboImageGenerator", + "StableDiffusionLightningImageGenerator", +] diff --git a/datadreamer/image_generation/sdxl_image_generator.py b/datadreamer/image_generation/sdxl_image_generator.py index e97e8d4..6ded0c2 100644 --- a/datadreamer/image_generation/sdxl_image_generator.py +++ b/datadreamer/image_generation/sdxl_image_generator.py @@ -173,7 +173,7 @@ def release(self, empty_cuda_cache=False) -> None: "A photo of a bicycle pedaling alongside an aeroplane taking off, showcasing the harmony between human-powered and mechanical transportation.", "A photo of bicycles along a scenic mountain path, where the riders seem to have taken a moment to appreciate the stunning views.", ] - prompt_objects = [["aeroplane", "boat", "bicycle"], ["bicycle"]] + prompt_objects = [["aeroplane", "bicycle"], ["bicycle"]] image_paths = [] counter = 0 diff --git a/datadreamer/image_generation/sdxl_lightning_image_generator.py b/datadreamer/image_generation/sdxl_lightning_image_generator.py new file mode 100644 index 0000000..c590cc1 --- /dev/null +++ b/datadreamer/image_generation/sdxl_lightning_image_generator.py @@ -0,0 +1,173 @@ +from typing import List, Optional + +import torch +from compel import Compel, ReturnedEmbeddingsType +from diffusers import ( + EulerDiscreteScheduler, + StableDiffusionXLPipeline, + UNet2DConditionModel, +) +from huggingface_hub import hf_hub_download +from PIL import Image +from safetensors.torch import load_file + +from datadreamer.image_generation.image_generator import ImageGenerator + + +class StableDiffusionLightningImageGenerator(ImageGenerator): + """A subclass of ImageGenerator specifically designed to use the Stable Diffusion + Lightning model for faster image generation. + + Attributes: + pipe (StableDiffusionXLPipeline): The Stable Diffusion Lightning model for image generation. + + Methods: + _init_gen_model(): Initializes the Stable Diffusion Lightning model. + _init_compel(): Initializes the Compel model for text prompt weighting. + generate_images_batch(prompts, negative_prompt, prompt_objects): Generates a batch of images based on the provided prompts. + release(empty_cuda_cache): Releases resources and optionally empties the CUDA cache. + """ + + def __init__(self, *args, **kwargs): + """Initializes the StableDiffusionLightningImageGenerator with the given + arguments.""" + super().__init__(*args, **kwargs) + self.pipe = self._init_gen_model() + self.compel = self._init_compel() + + def _init_gen_model(self): + """Initializes the Stable Diffusion Lightning model for image generation. + + Returns: + StableDiffusionXLPipeline: The initialized Stable Diffusion Lightning model. + """ + base = "stabilityai/stable-diffusion-xl-base-1.0" + repo = "ByteDance/SDXL-Lightning" + ckpt = "sdxl_lightning_4step_unet.safetensors" # Use the correct ckpt for your step setting! + + # Load model. + if self.device == "cpu": + print("Loading SDXL Lightning on CPU...") + unet = UNet2DConditionModel.from_config(base, subfolder="unet") + unet.load_state_dict(load_file(hf_hub_download(repo, ckpt))) + pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet) + else: + print("Loading SDXL Lightning on GPU...") + unet = UNet2DConditionModel.from_config(base, subfolder="unet").to( + self.device, torch.float16 + ) + unet.load_state_dict( + load_file(hf_hub_download(repo, ckpt), device=self.device) + ) + pipe = StableDiffusionXLPipeline.from_pretrained( + base, unet=unet, torch_dtype=torch.float16, variant="fp16" + ).to(self.device) + pipe.enable_model_cpu_offload() + + # Ensure sampler uses "trailing" timesteps. + pipe.scheduler = EulerDiscreteScheduler.from_config( + pipe.scheduler.config, timestep_spacing="trailing" + ) + + return pipe + + def _init_compel(self): + """Initializes the Compel model for text prompt weighting. + + Returns: + Compel: The initialized Compel model. + """ + compel = Compel( + tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2], + text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2], + returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, + requires_pooled=[False, True], + ) + return compel + + def generate_images_batch( + self, + prompts: List[str], + negative_prompt: str, + prompt_objects: Optional[List[List[str]]] = None, + batch_size: int = 1, + ) -> List[Image.Image]: + """Generates a batch of images using the Stable Diffusion Lightning model based + on the provided prompts. + + Args: + prompts (List[str]): A list of positive prompts to guide image generation. + negative_prompt (str): The negative prompt to avoid certain features in the image. + prompt_objects (Optional[List[List[str]]]): Optional list of objects for each prompt for CLIP model testing. + batch_size (int): The number of images to generate in each batch. + + Returns: + List[Image.Image]: A list of generated images. + """ + + if prompt_objects is not None: + for i in range(len(prompt_objects)): + for obj in prompt_objects[i]: + prompts[i] = prompts[i].replace(obj, f"({obj})1.5", 1) + + conditioning, pooled = self.compel(prompts) + conditioning_neg, pooled_neg = self.compel([negative_prompt] * len(prompts)) + images = self.pipe( + prompt_embeds=conditioning, + pooled_prompt_embeds=pooled, + negative_prompt_embeds=conditioning_neg, + negative_pooled_prompt_embeds=pooled_neg, + guidance_scale=0.0, + num_inference_steps=4, + ).images + + return images + + def release(self, empty_cuda_cache=False) -> None: + """Releases the model and optionally empties the CUDA cache.""" + self.pipe = self.pipe.to("cpu") + if self.use_clip_image_tester: + self.clip_image_tester.release() + if empty_cuda_cache: + with torch.no_grad(): + torch.cuda.empty_cache() + + +if __name__ == "__main__": + import os + + # Create the generator + image_generator = StableDiffusionLightningImageGenerator( + seed=42, + use_clip_image_tester=False, + image_tester_patience=1, + batch_size=4, + device="cpu", + ) + prompts = [ + "A photo of a bicycle pedaling alongside an aeroplane.", + "A photo of a dragonfly flying in the sky.", + "A photo of a dog walking in the park.", + "A photo of an alien exploring the galaxy.", + "A photo of a robot working on a computer.", + ] + prompt_objects = [ + ["aeroplane", "bicycle"], + ["dragonfly"], + ["dog"], + ["alien"], + ["robot", "computer"], + ] + + image_paths = [] + counter = 0 + for generated_images_batch in image_generator.generate_images( + prompts, prompt_objects + ): + for generated_image in generated_images_batch: + image_path = os.path.join("./", f"image_lightning_{counter}.jpg") + generated_image.save(image_path) + image_paths.append(image_path) + counter += 1 + + image_generator.release(empty_cuda_cache=True) diff --git a/datadreamer/image_generation/sdxl_turbo_image_generator.py b/datadreamer/image_generation/sdxl_turbo_image_generator.py index 5198a85..0abc29e 100644 --- a/datadreamer/image_generation/sdxl_turbo_image_generator.py +++ b/datadreamer/image_generation/sdxl_turbo_image_generator.py @@ -105,7 +105,7 @@ def release(self, empty_cuda_cache=False) -> None: prompts = [ "A photo of a bicycle pedaling alongside an aeroplane taking off, showcasing the harmony between human-powered and mechanical transportation.", ] * 16 - prompt_objects = [["aeroplane", "boat", "bicycle"]] * 16 + prompt_objects = [["aeroplane", "bicycle"]] * 16 image_paths = [] counter = 0 diff --git a/datadreamer/pipelines/generate_dataset_from_scratch.py b/datadreamer/pipelines/generate_dataset_from_scratch.py index db52cfb..843fdc0 100644 --- a/datadreamer/pipelines/generate_dataset_from_scratch.py +++ b/datadreamer/pipelines/generate_dataset_from_scratch.py @@ -12,6 +12,7 @@ from datadreamer.dataset_annotation import OWLv2Annotator from datadreamer.image_generation import ( StableDiffusionImageGenerator, + StableDiffusionLightningImageGenerator, StableDiffusionTurboImageGenerator, ) from datadreamer.prompt_generation import ( @@ -30,6 +31,7 @@ image_generators = { "sdxl": StableDiffusionImageGenerator, "sdxl-turbo": StableDiffusionTurboImageGenerator, + "sdxl-lightning": StableDiffusionLightningImageGenerator, } annotators = {"owlv2": OWLv2Annotator} @@ -84,7 +86,7 @@ def parse_args(): "--image_generator", type=str, default="sdxl-turbo", - choices=["sdxl", "sdxl-turbo"], + choices=["sdxl", "sdxl-turbo", "sdxl-lightning"], help="Image generator to use", ) parser.add_argument( diff --git a/media/coverage_badge.svg b/media/coverage_badge.svg index 1581a9a..a5804ea 100644 --- a/media/coverage_badge.svg +++ b/media/coverage_badge.svg @@ -15,7 +15,7 @@ coverage coverage - 48% - 48% + 46% + 46% diff --git a/tests/integration/test_pipeline.py b/tests/integration/test_pipeline.py index 2a5cdc8..5d25b95 100644 --- a/tests/integration/test_pipeline.py +++ b/tests/integration/test_pipeline.py @@ -318,6 +318,50 @@ def test_cuda_simple_sdxl_detection_pipeline(): _check_detection_pipeline(cmd, target_folder) +@pytest.mark.skipif( + total_memory < 16 or total_disk_space < 35, + reason="Test requires at least 16GB of RAM and 35GB of HDD", +) +def test_cpu_simple_sdxl_lightning_detection_pipeline(): + # Define target folder + target_folder = "data/data-det-cpu-simple-sdxl-lightning/" + # Define the command to run the datadreamer + cmd = ( + f"datadreamer --save_dir {target_folder} " + f"--class_names alien mars cat " + f"--prompts_number 1 " + f"--prompt_generator simple " + f"--num_objects_range 1 2 " + f"--image_generator sdxl-lightning " + f"--use_image_tester " + f"--device cpu" + ) + # Check the run of the pipeline + _check_detection_pipeline(cmd, target_folder) + + +@pytest.mark.skipif( + not torch.cuda.is_available() or total_memory < 16 or total_disk_space < 35, + reason="Test requires GPU, at least 16GB of RAM and 35GB of HDD", +) +def test_cuda_simple_sdxl_lightning_detection_pipeline(): + # Define target folder + target_folder = "data/data-det-cuda-simple-sdxl-lightning/" + # Define the command to run the datadreamer + cmd = ( + f"datadreamer --save_dir {target_folder} " + f"--class_names alien mars cat " + f"--prompts_number 1 " + f"--prompt_generator simple " + f"--num_objects_range 1 2 " + f"--image_generator sdxl-lightning " + f"--use_image_tester " + f"--device cuda" + ) + # Check the run of the pipeline + _check_detection_pipeline(cmd, target_folder) + + # ========================================================= # DETECTION - LLM # ========================================================= diff --git a/tests/unittests/test_image_generation.py b/tests/unittests/test_image_generation.py index 1ccd5db..6cce53c 100644 --- a/tests/unittests/test_image_generation.py +++ b/tests/unittests/test_image_generation.py @@ -6,13 +6,12 @@ import torch from PIL import Image -from datadreamer.image_generation.clip_image_tester import ClipImageTester -from datadreamer.image_generation.sdxl_image_generator import ( +from datadreamer.image_generation import ( StableDiffusionImageGenerator, -) -from datadreamer.image_generation.sdxl_turbo_image_generator import ( + StableDiffusionLightningImageGenerator, StableDiffusionTurboImageGenerator, ) +from datadreamer.image_generation.clip_image_tester import ClipImageTester # Get the total memory in GB total_memory = psutil.virtual_memory().total / (1024**3) @@ -55,7 +54,11 @@ def test_cpu_clip_image_tester(): def _check_image_generator( image_generator_class: Type[ - Union[StableDiffusionImageGenerator, StableDiffusionTurboImageGenerator] + Union[ + StableDiffusionImageGenerator, + StableDiffusionTurboImageGenerator, + StableDiffusionLightningImageGenerator, + ] ], device: str, ): @@ -101,3 +104,19 @@ def test_cuda_sdxl_turbo_image_generator(): ) def test_cpu_sdxl_turbo_image_generator(): _check_image_generator(StableDiffusionTurboImageGenerator, "cpu") + + +@pytest.mark.skipif( + not torch.cuda.is_available() or total_memory < 16 or total_disk_space < 25, + reason="Test requires GPU, at least 16GB of RAM and 25GB of HDD", +) +def test_cuda_sdxl_lightning_image_generator(): + _check_image_generator(StableDiffusionLightningImageGenerator, "cuda") + + +@pytest.mark.skipif( + total_memory < 16 or total_disk_space < 25, + reason="Test requires at least 16GB of RAM and 25GB of HDD", +) +def test_cpu_sdxl_lightning_image_generator(): + _check_image_generator(StableDiffusionLightningImageGenerator, "cpu")