
introducing components manager #10572

Open

yiyixuxu wants to merge 1 commit into main

Conversation

yiyixuxu (Collaborator) commented Jan 14, 2025

This PR introduces a "components manager" (a.k.a. a diffusers built-in "model management system").

It was mainly made for modular diffusers, but IMO it is also useful for regular pipelines. I think I'll be able to demonstrate its behavior with the regular pipeline use case, so I'm making a separate PR here for easier review, and I can start iterating on this feature before the modular diffusers PR is ready for review.

Motivations/objective :

  1. The original diffusers design couples model components with pipelines. However, different pipelines often share the same components, and it is best practice to reuse these components as much as possible. This PR introduces a dedicated class to manage model components independently of pipelines, so we take a step further in promoting the reuse of model components and maximizing memory efficiency.
  2. With this PR, we also provide a way to apply a global model offloading strategy across different pipelines/pipeline blocks. This is particularly useful when the user works with multiple workflows at the same time; I will walk you through an example below. Our goals here are: (1) provide a strategy that is simple and works reasonably well out of the box; (2) more importantly, allow users to apply custom offloading strategies through a flexible API.

How does it work?

we add an accelerate hook to every model, where we:

  1. specify the execution device for that model
  2. give it access to the other models, with a handle to offload them if needed

before the forward pass of each model, if the model is not already on the execution device, the hook moves the model there along with all its inputs, and runs a custom function to decide whether it needs to offload other models (for example, a UI can commonly decide whether to offload based on the available memory and model size, and it may have its own custom logic for deciding which models to offload)

so it is a little bit similar to our sequential CPU offloading, but it is a lot more flexible. A minimal sketch of such a hook is shown below.
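To make the mechanics concrete, here is a hedged sketch of such a hook built on top of accelerate's ModelHook API. This is not the PR's actual implementation; the class name SimpleOffloadHook and the offload_fn callable are illustrative only.

import torch
from accelerate.hooks import ModelHook, add_hook_to_module
from accelerate.utils import send_to_device


class SimpleOffloadHook(ModelHook):
    def __init__(self, execution_device, other_models, offload_fn=None):
        self.execution_device = torch.device(execution_device)
        # handles to the other managed models, so this hook can offload them if needed
        self.other_models = other_models
        # user-supplied callable that decides which models to move off the device
        self.offload_fn = offload_fn

    def pre_forward(self, module, *args, **kwargs):
        if next(module.parameters()).device != self.execution_device:
            if self.offload_fn is not None:
                # ask the custom strategy which other models to push back to CPU
                for model in self.offload_fn(module, self.other_models, self.execution_device):
                    model.to("cpu")
            module.to(self.execution_device)
        # the inputs are moved to the execution device along with the model
        args = send_to_device(args, self.execution_device)
        kwargs = send_to_device(kwargs, self.execution_device)
        return args, kwargs


# attach the hook to each model managed by the components manager, e.g.:
# add_hook_to_module(transformer, SimpleOffloadHook("cuda:0", other_models, offload_fn))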

an example

in this example, we'll work with 3 Flux-related workflows at the same time: Flux text2img, Flux Canny control, and Flux Depth control. we demo a default strategy we made that's basically:

  1. only offload model(s) when there is not enough memory: it checks the available memory on the device and compares it with the size of the model it plans to move there. There is also an argument memory_reserve_margin that you can use to adjust how aggressively it offloads, e.g. if your model size is 20G and you think the actual memory used would be around 25G, set memory_reserve_margin=5G. So if you're getting an OOM because the offloading strategy wasn't aggressive enough, i.e. there were still unused models left on the device, you can increase this number; otherwise, reduce it.
  2. when offloading, it tries to offload as little as possible, i.e. it finds the model(s) with the smallest total size that meets the memory requirement.

this strategy is really just an example to show users how they can set their own strategies; you could totally use a different one, so feel free to help brainstorm. A minimal sketch of the selection logic is shown below.
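For intuition, here is a hedged sketch of that smallest-total-size selection. The PR's actual AutoOffloadStrategy may differ; select_models_to_offload and its arguments are made up for illustration, and all sizes are in bytes.

from itertools import combinations

import torch


def select_models_to_offload(model_size, on_device_models, device, memory_reserve_margin=0):
    # how much memory we still need to free on the execution device (bytes)
    free_mem, _total = torch.cuda.mem_get_info(torch.device(device).index)
    needed = model_size + memory_reserve_margin - free_mem
    if needed <= 0:
        return []  # enough free memory already, nothing to offload

    sizes = {
        name: sum(p.nelement() * p.element_size() for p in m.parameters())
        for name, m in on_device_models.items()
    }
    best_names, best_total = None, None
    # brute force is fine for the handful of models a workflow typically holds
    for r in range(1, len(sizes) + 1):
        for combo in combinations(sizes, r):
            total = sum(sizes[n] for n in combo)
            if total >= needed and (best_total is None or total < best_total):
                best_names, best_total = combo, total
    # if no combination frees enough memory, offload everything we can
    return list(best_names) if best_names is not None else list(sizes)

With only a handful of managed models, the brute-force search above is negligible compared to the cost of actually moving weights between devices.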

I made a Colab notebook here too: https://colab.research.google.com/drive/1EVVS8ai4qIW5Ca2CcSGz5VsdvD_N4N8_?usp=sharing

first, let's set up and define the inputs, etc.

import torch
from diffusers.pipelines.components_manager import ComponentsManager
from diffusers import FluxTransformer2DModel, FluxPipeline, FluxControlPipeline
from diffusers.utils import load_image
import gc

from image_gen_aux import DepthPreprocessor
from controlnet_aux import CannyDetector

import logging
logging.getLogger().setLevel(logging.INFO)
logging.getLogger("diffusers").setLevel(logging.INFO)

def reset_memory():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def clear_memory():
    torch.cuda.empty_cache()

dtype = torch.bfloat16
device = "cuda:0"


# create inputs
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."

control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")

canny_processor = CannyDetector()
depth_processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")    

control_image_canny = canny_processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)
control_image_depth = depth_processor(control_image)[0].convert("RGB")

now let's create the components manager and add all the models we need to use: that includes the CLIP/T5 text encoders, the 3 Flux transformers, and the VAE

# set up components manager
repo = "black-forest-labs/FLUX.1-dev"
canny_repo = "black-forest-labs/FLUX.1-Canny-dev"
depth_repo = "black-forest-labs/FLUX.1-Depth-dev"

components = ComponentsManager()
components.add_from_pretrained(repo, torch_dtype=dtype)

canny = FluxTransformer2DModel.from_pretrained(canny_repo, subfolder="transformer", torch_dtype=dtype)
depth = FluxTransformer2DModel.from_pretrained(depth_repo, subfolder="transformer", torch_dtype=dtype)
components.add("canny", canny)
components.add("depth", depth)

you can print out the components manager to get an overview of all the model components you have and their device/dtype info; this is what I have before we run any pipelines. You can see that all the models add up to around 76 GB, and in the Colab notebook instance I have around 40 GB of GPU memory

components: Components:
===============================================================================================
Models:
-----------------------------------------------------------------------------------------------
Model ID        | Class                           | Device     | Dtype           | Size (GB) 
-----------------------------------------------------------------------------------------------
vae             | AutoencoderKL                   | cpu        | torch.bfloat16  | 0.16
text_encoder    | CLIPTextModel                   | cpu        | torch.bfloat16  | 0.23
text_encoder_2  | T5EncoderModel                  | cpu        | torch.bfloat16  | 8.87
transformer     | FluxTransformer2DModel          | cpu        | torch.bfloat16  | 22.17
canny           | FluxTransformer2DModel          | cpu        | torch.bfloat16  | 22.17
depth           | FluxTransformer2DModel          | cpu        | torch.bfloat16  | 22.17
-----------------------------------------------------------------------------------------------

Other Components:
-----------------------------------------------------------------------------------------------
Component ID    | Class                          
-----------------------------------------------------------------------------------------------
tokenizer       | CLIPTokenizer                  
tokenizer_2     | T5TokenizerFast                
scheduler       | FlowMatchEulerDiscreteScheduler
-----------------------------------------------------------------------------------------------

now let's apply the custom offloading strategy to the components manager

components.enable_auto_cpu_offload(device)

run the first workflow

# use case 1: regular text2img
pipe = FluxPipeline.from_pretrained(repo, **components.get(["transformer","text_encoder","text_encoder_2","vae"]), torch_dtype=dtype)

reset_memory()
print(f"components: {components}")
print(f"memory: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB")

generator = torch.Generator(device="cpu").manual_seed(42)
image = pipe(
    prompt=prompt,
    num_inference_steps=28,
    guidance_scale=3.5,
    generator=generator,
).images[0]
image.save("yiyi_test_3_output_text2img.png")

print(f" after text2img:")
print(components)
print(f"memory: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB")

this is what I got after the first workflow: you can see that all the models needed for text2img were moved to the device and kept there; the max memory was 36.17 GB (less than the available memory), so we were able to run this workflow without offloading any models

after text2img:
Components:
===============================================================================================
Models:
-----------------------------------------------------------------------------------------------
Model ID        | Class                           | Device     | Dtype           | Size (GB) 
-----------------------------------------------------------------------------------------------
vae             | AutoencoderKL                   | cuda:0     | torch.bfloat16  | 0.16
text_encoder    | CLIPTextModel                   | cuda:0     | torch.bfloat16  | 0.23
text_encoder_2  | T5EncoderModel                  | cuda:0     | torch.bfloat16  | 8.87
transformer     | FluxTransformer2DModel          | cuda:0     | torch.bfloat16  | 22.17
canny           | FluxTransformer2DModel          | cpu        | torch.bfloat16  | 22.17
depth           | FluxTransformer2DModel          | cpu        | torch.bfloat16  | 22.17
-----------------------------------------------------------------------------------------------

Other Components:
-----------------------------------------------------------------------------------------------
Component ID    | Class                          
-----------------------------------------------------------------------------------------------
tokenizer       | CLIPTokenizer                  
tokenizer_2     | T5TokenizerFast                
scheduler       | FlowMatchEulerDiscreteScheduler
-----------------------------------------------------------------------------------------------

memory: 36.17 GB

now let's clear the memory cache and run the second workflow

clear_memory()

# use case 2: canny control
pipe_canny = FluxControlPipeline.from_pipe(
    pipe, transformer=components.get("canny"), torch_dtype=dtype
)


image = pipe_canny(
    prompt=prompt,
    control_image=control_image_canny,
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=30.0,
).images[0]
image.save("yiyi_test_3_output_canny.png")

print(f" after canny:")
print(components)
print(f"memory: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB")

this is what I got: the base transformer is now moved to CPU and canny was moved to the device; max memory stays at 36.28 GB

search for models to offload in order to free up 19.68 GB memory
 time taken to apply offload strategy for canny: 0.02 seconds
moving canny to cuda:0, offloading transformer to cpu
 after canny:
Components:
===============================================================================================
Models:
-----------------------------------------------------------------------------------------------
Model ID        | Class                           | Device     | Dtype           | Size (GB) 
-----------------------------------------------------------------------------------------------
vae             | AutoencoderKL                   | cuda:0     | torch.bfloat16  | 0.16
text_encoder    | CLIPTextModel                   | cuda:0     | torch.bfloat16  | 0.23
text_encoder_2  | T5EncoderModel                  | cuda:0     | torch.bfloat16  | 8.87
transformer     | FluxTransformer2DModel          | cpu        | torch.bfloat16  | 22.17
canny           | FluxTransformer2DModel          | cuda:0     | torch.bfloat16  | 22.17
depth           | FluxTransformer2DModel          | cpu        | torch.bfloat16  | 22.17
-----------------------------------------------------------------------------------------------

Other Components:
-----------------------------------------------------------------------------------------------
Component ID    | Class                          
-----------------------------------------------------------------------------------------------
tokenizer       | CLIPTokenizer                  
tokenizer_2     | T5TokenizerFast                
scheduler       | FlowMatchEulerDiscreteScheduler
-----------------------------------------------------------------------------------------------

memory: 36.28 GB

now run the last workflow

pipe_depth = FluxControlPipeline.from_pipe(
    pipe, transformer=components.get("depth"), torch_dtype=dtype
)


image = pipe_depth(
    prompt=prompt,
    control_image=control_image_depth,
    height=1024,
    width=1024,
    num_inference_steps=30,
    guidance_scale=10.0,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("yiyi_test_3_output_depth.png")

print(f" after depth:")
print(components)
print(f"memory: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB")
after depth:
Components:
===============================================================================================
Models:
-----------------------------------------------------------------------------------------------
Model ID        | Class                           | Device     | Dtype           | Size (GB) 
-----------------------------------------------------------------------------------------------
vae             | AutoencoderKL                   | cuda:0     | torch.bfloat16  | 0.16
text_encoder    | CLIPTextModel                   | cuda:0     | torch.bfloat16  | 0.23
text_encoder_2  | T5EncoderModel                  | cuda:0     | torch.bfloat16  | 8.87
transformer     | FluxTransformer2DModel          | cpu        | torch.bfloat16  | 22.17
canny           | FluxTransformer2DModel          | cpu        | torch.bfloat16  | 22.17
depth           | FluxTransformer2DModel          | cuda:0     | torch.bfloat16  | 22.17
-----------------------------------------------------------------------------------------------

Other Components:
-----------------------------------------------------------------------------------------------
Component ID    | Class                          
-----------------------------------------------------------------------------------------------
tokenizer       | CLIPTokenizer                  
tokenizer_2     | T5TokenizerFast                
scheduler       | FlowMatchEulerDiscreteScheduler
-----------------------------------------------------------------------------------------------

memory: 36.28 GB

@yiyixuxu requested review from SunMarc, sayakpaul and DN6 on January 14, 2025 at 08:37
@yiyixuxu (Collaborator, Author)

cc @a-r-r-o-w here too, since you're working on some offloading strategies targeted at UI use cases (e.g. #10503); we should make it work with this API

@yiyixuxu (Collaborator, Author)

cc @vladmandic too, I think it might be useful for SD Next; if that's the case, let us know if you have any feedback!


@vladmandic (Contributor)

cc @vladmandic too, I think it might be useful for SD Next; if that's the case, let us know if you have any feedback!

we've implemented our custom offloading (on top of accelerate hooks) a while back, and it works as goal-based: move model components based on their size until the goal is reached (the goal being configurable min/max VRAM usage thresholds).

i'm not sure how this compares, since there are no notes on:

  • what if some model components are quantized?
  • how can some components be offloaded to CPU again after they've been moved to the GPU? e.g. the text encoder needs to move to the GPU for encode_prompt, but then it should move back to the CPU for the main model to free up as much memory as possible; and at the end, the unet/transformer should be offloaded to give the VAE as much room as possible. offloading once is not useful, and offloading sequentially is too expensive.
  • applying LoRAs, especially with the above two scenarios: a) partially quantized models and b) LoRAs that contain weights for multiple components.

@sayakpaul (Member) left a comment

Clean 💚

Apart from the in-line comments, I have some notes below:

runs a custom function to decide whether it needs to offload other models (for example, a UI can commonly decide whether to offload based on the available memory and model size, and it may have its own custom logic for deciding which models to offload)

Should this follow a common template so that users can supply their own implementations?

when offloading, it tries to offload as little as possible, i.e. it finds the model(s) with the smallest total size that meets the memory requirement.

Let's say I have exhausted the available GPU memory. What happens in that case? Do I offload some parts of the model to CPU and possibly to disk i.e., the ones that didn't fit the GPU?

components.enable_auto_cpu_offload(device)

It's not a blocker, but we could think of allowing users to pass multiple devices here as well. But perhaps this is best done with a separate offload class. Or should it rather be handled completely with device_map (pipeline-level)? Okay with me if this feels like we're digressing, or we can table the discussion for later.

you can print out the components manager to get an overview of all the model components you have and their device/dtype info; this is what I have before we run any pipelines.

This is SO GOOD 🔥

pipe = FluxPipeline.from_pretrained(repo, **components.get(["transformer","text_encoder","text_encoder_2","vae"]), torch_dtype=dtype)

So, in this case, users don't have to do any kind of device placement on the pipe, as the components have already been mapped when we called components.enable_auto_cpu_offload(device). Yeah?

Definitely not a blocker, but **components.get(["transformer","text_encoder","text_encoder_2","vae"]) assumes that the user knows the exact names of the components. Would it be possible to automatically infer these names to make it a tad bit easier? Or is it far-fetched?
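One hedged possibility (illustrative only, not part of this PR): read the component names a pipeline expects from its __init__ signature and pull the matching entries from the manager.

import inspect

from diffusers import FluxPipeline

# component names the pipeline expects, taken from its constructor signature
expected = [name for name in inspect.signature(FluxPipeline.__init__).parameters if name != "self"]
# hypothetical usage, assuming components.get() accepts a list of names
# (names not present in the manager would need to be filtered out or made optional):
# pipe = FluxPipeline.from_pretrained(repo, **components.get(expected), torch_dtype=dtype)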

this is what I got: the base transformer is now moved to CPU and canny was moved to the device; max memory stays at 36.28 GB

This is nice, but it also assumes that we have enough CPU memory, which starts to add up when there are multiple models to be kept on the CPU. Should we expose an argument that lets users remove these components completely?

LMK if anything's unclear. Excited to see this getting shipped soon.

Comment on lines +21 to +24
from ..utils import (
    is_accelerate_available,
    logging,
)
Member

Nit.

Suggested change
from ..utils import (
    is_accelerate_available,
    logging,
)
from ..utils import is_accelerate_available, logging

Comment on lines +38 to +54
def get_memory_footprint(self, return_buffers=True):
    r"""
    Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
    Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the
    PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2

    Arguments:
        return_buffers (`bool`, *optional*, defaults to `True`):
            Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
            are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
            norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
    """
    mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
    if return_buffers:
        mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
        mem = mem + mem_bufs
    return mem
Member

I think it's okay to have this independent method.

Perhaps the model-level implementation of get_memory_footprint() can call this method, which allows us to reuse it. Can be revisited later.

Comment on lines +79 to +88
def set_strategy(self, offload_strategy: "AutoOffloadStrategy"):
    self.offload_strategy = offload_strategy

def add_other_hook(self, hook: "UserCustomOffloadHook"):
    """
    Add a hook to the list of hooks to consider for offloading.
    """
    if self.other_hooks is None:
        self.other_hooks = []
    self.other_hooks.append(hook)
Member

Might be nice to have utilities like:

  • delete_hooks()
  • list_current_hooks()
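A hedged sketch of what these could look like on the hook class (assuming other_hooks stays a plain list as in the diff above, and that each registered hook exposes a remove() method as its docstring suggests; the class name and structure here are illustrative only):

class CustomOffloadHookWithUtils:
    # minimal stand-in showing where the utilities would live; the real hook class has more state
    def __init__(self):
        self.other_hooks = None

    def add_other_hook(self, hook):
        if self.other_hooks is None:
            self.other_hooks = []
        self.other_hooks.append(hook)

    def delete_hooks(self):
        """Detach and drop all registered extra hooks."""
        if self.other_hooks:
            for hook in self.other_hooks:
                hook.remove()  # assumption: UserCustomOffloadHook exposes remove()
        self.other_hooks = None

    def list_current_hooks(self):
        """Return the currently registered extra hooks (empty list if none)."""
        return list(self.other_hooks or [])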

import time

# YiYi Notes: only logging time for now to monitor the overhead of offloading strategy (remove later)
start_time = time.perf_counter()
Member

No strong opinions but I think we could keep these with logger.debug(), it's very useful, IMO.

def __init__(
    self,
    execution_device: Optional[Union[str, int, torch.device]] = None,
    other_hooks: Optional[List["UserCustomOffloadHook"]] = None,
Member

Should this just be called hooks? Or we're calling it other_hooks to avoid any potential overlaps with naming?

Edit: Think other_hooks is better.


if hooks_to_offload:
    clear_device_cache()
module.to(self.execution_device)
Member

Please help me understand why we need to do this additional device placement given the call to self.offload_strategy() above (when it's not None)? Do we have to guard the placement with if self.offload_strategy is None?

Comment on lines +126 to +130
class UserCustomOffloadHook:
    """
    A simple hook grouping a model and a `CustomOffloadHook`, which provides easy APIs to call the init method of
    the hook or remove it entirely.
    """
Member

Users need to override this class in case they want to customize one?


current_module_size = get_memory_footprint(model)

mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0]
Member

Needs to be guarded if CUDA is available. So, perhaps we can have a dispatching system to obtain the memory on the device based on the device being used.

Regardless of the dispatching system, this call needs to be guarded I think.
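A hedged sketch of such a guarded, device-dispatched query (get_free_memory is a hypothetical helper, not from this PR; the CPU branch assumes psutil is available):

import torch


def get_free_memory(device) -> int:
    # dispatch the free-memory query by device type instead of calling torch.cuda directly
    device = torch.device(device)
    if device.type == "cuda":
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA device requested but CUDA is not available")
        free, _total = torch.cuda.mem_get_info(device.index if device.index is not None else 0)
        return free
    if device.type == "cpu":
        import psutil  # assumption: psutil is installed for a CPU fallback

        return psutil.virtual_memory().available
    raise NotImplementedError(f"free-memory query not implemented for device type {device.type!r}")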
