Module Group Offloading #10503

Open · wants to merge 31 commits into main
Conversation

@a-r-r-o-w (Member) commented Jan 9, 2025

  • enable_model_cpu_offload onloads the entire transformer model at once. The minimum memory requirement is therefore determined by the size of the transformer. For large models, it is sometimes impossible to even fit the model in GPU memory.
  • enable_sequential_cpu_offload has very minimal memory requirements, but is too slow because of the many synchronous device transfers. We can speed this up with async CUDA streams to "hide" the HtoD and DtoH transfer latency by overlapping it with computation (a minimal sketch of this idea follows below). A CUDA-stream implementation would need to come from accelerate, since we rely on it for memory management in this case.
  • UIs usually rely on such a model-management system. They require fine-grained control over which layers to offload/onload to a device, and the requirement may change dynamically at runtime based on chosen user settings. This PR enables that, in a limited capacity for now.
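The overlap mentioned above can be pictured with a minimal, self-contained PyTorch sketch (illustrative only; this is not the hook-based implementation in this PR, which additionally handles pinned memory, buffers, and bookkeeping): the next layer's weights are prefetched on a side CUDA stream while the current layer computes on the default stream.

```python
import torch
import torch.nn as nn

# Illustrative layer stack kept on CPU and onloaded one step ahead of use.
layers = nn.ModuleList([nn.Linear(4096, 4096) for _ in range(4)])
transfer_stream = torch.cuda.Stream()

def prefetch(layer: nn.Module) -> None:
    # Launch the HtoD copy on a side stream so it overlaps with compute running
    # on the default stream. True overlap also needs the CPU weights pinned.
    with torch.cuda.stream(transfer_stream):
        layer.to("cuda", non_blocking=True)

@torch.no_grad()
def forward(x: torch.Tensor) -> torch.Tensor:
    layers[0].to("cuda")
    for i, layer in enumerate(layers):
        if i + 1 < len(layers):
            prefetch(layers[i + 1])
        x = layer(x)
        # Make sure the prefetched weights are ready before the next iteration,
        # then send the finished layer back to CPU.
        torch.cuda.current_stream().wait_stream(transfer_stream)
        layer.to("cpu")
    return x

out = forward(torch.randn(64, 4096, device="cuda"))
```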
Benchmark results (time in seconds, memory in GB):

| model_id       | offloading_type | use_stream | num_blocks | time    | model_memory | inference_memory |
|----------------|-----------------|------------|------------|---------|--------------|------------------|
| cogvideox-1.0  | none            | False      |            | 242.485 | 19.697       | 24.396           |
| cogvideox-1.0  | model           | False      |            | 328.429 | 0.059        | 15.314           |
| cogvideox-1.0* | block_level     | False      | 8          | 289.809 | 9.35         | 16.023           |
| cogvideox-1.0* | block_level     | True       | 8          | 247.966 | 9.332        | 17.996           |
| cogvideox-1.0* | block_level     | False      | 1          | 282.885 | 9.332        | 14.457           |
| cogvideox-1.0* | block_level     | True       | 1          | 244.795 | 9.35         | 14.549           |
| cogvideox-1.0  | leaf_level      | True       |            | 248.396 | 0.479        | 5.344            |
| hunyuan_video  | none            | False      |            | 69.436  | 24.436       | 26.451           |
| hunyuan_video  | model           | False      |            | 141.01  | 0.041        | 25.949           |
| hunyuan_video  | block_level     | False      | 8          | 139.316 | 0.525        | 8.273            |
| hunyuan_video  | block_level     | True       | 8          | 83.975  | 0.701        | 13.508           |
| hunyuan_video  | block_level     | False      | 1          | 137.182 | 0.525        | 3.844            |
| hunyuan_video  | block_level     | True       | 1          | 77.245  | 0.721        | 4.666            |
| hunyuan_video  | leaf_level      | True       |            | 80.375  | 0.533        | 2.908            |
| ltx_video      | none            | False      |            | 36.702  | 4.85         | 6.475            |
| ltx_video      | model           | False      |            | 51.525  | 0.025        | 5.678            |
| ltx_video      | block_level     | False      | 8          | 56.477  | 0.848        | 3.564            |
| ltx_video      | block_level     | True       | 8          | 41.053  | 1.121        | 5.025            |
| ltx_video      | block_level     | False      | 1          | 55.057  | 0.848        | 2.658            |
| ltx_video      | block_level     | True       | 1          | 38.443  | 1.121        | 3.057            |
| ltx_video      | leaf_level      | True       |            | 38.456  | 1.121        | 2.826            |
| flux           | none            | False      |            | 16.811  | 31.467       | 32.088           |
| flux           | model           | False      |            | 178.134 | 0.041        | 22.84            |
| flux*          | block_level     | False      | 8          | 116.088 | 9.307        | 14.891           |
| flux*          | block_level     | True       | 8          | 50.228  | 9.309        | 20.168           |
| flux*          | block_level     | False      | 1          | 119.193 | 9.307        | 10.461           |
| flux*          | block_level     | True       | 1          | 50.097  | 9.309        | 11.307           |
| flux           | leaf_level      | True       |            | 52.519  | 0.227        | 1.09             |

*These benchmarks were run with a mistake in the offloading code that left the text encoder on the GPU instead of offloading it, making the comparison unfair to the runs not marked with a *.

Benchmark
import argparse
import gc
import pathlib
import traceback

import git
import pandas as pd
import torch
from diffusers import (
    AllegroPipeline,
    CogVideoXPipeline,
    FluxPipeline,
    HunyuanVideoPipeline,
    LattePipeline,
    LTXPipeline,
    MochiPipeline,
)
from diffusers.hooks import apply_group_offloading
from diffusers.models import HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_info, set_verbosity_debug
from tabulate import tabulate


repo = git.Repo(path="/home/aryan/work/diffusers")
branch = repo.active_branch


def pretty_print_results(results, precision: int = 3):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))


def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output


def prepare_allegro(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "rhymes-ai/Allegro"
    cache_dir = None

    pipe = AllegroPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")
    pipe.vae.enable_tiling()

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this location might be a popular spot for docking fishing boats.",
        "height": 720,
        "width": 1280,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_cogvideox_1_0(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "THUDM/CogVideoX-5b"
    cache_dir = None

    pipe = CogVideoXPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        prompt=(
            "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
            "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
            "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
            "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
            "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
            "atmosphere of this unique musical performance."
        ),
        device="cuda",
        dtype=dtype,
    )

    pipe.text_encoder.to("cpu")
    del pipe.text_encoder

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "negative_prompt_embeds": negative_prompt_embeds,
        "height": 480,
        "width": 720,
        "num_frames": 49,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_flux(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "black-forest-labs/Flux.1-Dev"
    cache_dir = "/raid/.cache/huggingface"

    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    prompt_embeds, pooled_prompt_embeds, _ = pipe.encode_prompt(
        prompt="A cat holding a sign that says hello world", prompt_2=None, device="cuda"
    )

    pipe.text_encoder.to("cpu")
    pipe.text_encoder_2.to("cpu")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "pooled_prompt_embeds": pooled_prompt_embeds,
        "height": 768,
        "width": 768,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_hunyuan_video(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "hunyuanvideo-community/HunyuanVideo"
    cache_dir = None

    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        model_id, subfolder="transformer", torch_dtype=torch.bfloat16
    )
    pipe = HunyuanVideoPipeline.from_pretrained(
        model_id, transformer=transformer, torch_dtype=torch.float16, cache_dir=cache_dir
    )
    pipe.to("cuda")

    prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
        prompt="A cat wearing sunglasses and working as a lifeguard at pool.", device="cuda", dtype=torch.float16
    )
    pipe.text_encoder.to("cpu")
    pipe.text_encoder_2.to("cpu")
    del pipe.text_encoder, pipe.text_encoder_2

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "pooled_prompt_embeds": pooled_prompt_embeds,
        "prompt_attention_mask": prompt_attention_mask,
        "height": 320,
        "width": 512,
        "num_frames": 61,
        "num_inference_steps": 30,
    }

    return pipe, generation_kwargs


def prepare_latte(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "maxin-cn/Latte-1"
    cache_dir = None

    pipe = LattePipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        prompt="A cat wearing sunglasses and working as a lifeguard at pool.",
        do_classifier_free_guidance=True,
        num_videos_per_prompt=1,
        device="cuda",
    )
    pipe.text_encoder.to("cpu")
    del pipe.text_encoder

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "negative_prompt_embeds": negative_prompt_embeds,
        "height": 512,
        "width": 512,
        "video_length": 16,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def prepare_ltx_video(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "a-r-r-o-w/LTX-Video-diffusers"
    cache_dir = None

    pipe = LTXPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    (
        prompt_embeds,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_prompt_attention_mask,
    ) = pipe.encode_prompt(
        prompt="A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage",
        negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
        do_classifier_free_guidance=True,
        num_videos_per_prompt=1,
        device="cuda",
    )
    pipe.text_encoder.to("cpu")
    del pipe.text_encoder

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)
    
    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "prompt_attention_mask": prompt_attention_mask,
        "negative_prompt_embeds": negative_prompt_embeds,
        "negative_prompt_attention_mask": negative_prompt_attention_mask,
        "width": 768,
        "height": 512,
        "num_frames": 161,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def prepare_mochi(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "genmo/mochi-1-preview"
    cache_dir = None

    pipe = MochiPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")
    pipe.vae.enable_tiling()

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
        "height": 480,
        "width": 848,
        "num_frames": 85,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def decode_allegro(pipe: AllegroPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_cogvideox_1_0(pipe: CogVideoXPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_flux(pipe: FluxPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    height = kwargs["height"]
    width = kwargs["width"]
    filename = f"{filename.as_posix()}.png"
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]
    image.save(filename)
    return filename


def decode_hunyuan_video(pipe: HunyuanVideoPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_latte(pipe: LattePipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents, video_length=kwargs["video_length"])
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_ltx_video(pipe: LTXPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latent_num_frames = (kwargs["num_frames"] - 1) // pipe.vae_temporal_compression_ratio + 1
    latent_height = kwargs["height"] // pipe.vae_spatial_compression_ratio
    latent_width = kwargs["width"] // pipe.vae_spatial_compression_ratio

    latents = pipe._unpack_latents(
        latents,
        latent_num_frames,
        latent_height,
        latent_width,
        pipe.transformer_spatial_patch_size,
        pipe.transformer_temporal_patch_size,
    )
    latents = pipe._denormalize_latents(
        latents, pipe.vae.latents_mean, pipe.vae.latents_std, pipe.vae.config.scaling_factor
    )
    latents = latents.to(pipe.vae.dtype)

    timestep = None
    video = pipe.vae.decode(latents, timestep, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
    export_to_video(video, filename, fps=24)
    return filename


def decode_mochi(pipe: MochiPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
    latents = latents * latents_std / pipe.vae.config.scaling_factor + latents_mean
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def reset_memory():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()


MODEL_MAPPING = {
    "allegro": {
        "prepare": prepare_allegro,
        "decode": decode_allegro,
    },
    "cogvideox-1.0": {
        "prepare": prepare_cogvideox_1_0,
        "decode": decode_cogvideox_1_0,
    },
    "flux": {
        "prepare": prepare_flux,
        "decode": decode_flux,
    },
    "hunyuan_video": {
        "prepare": prepare_hunyuan_video,
        "decode": decode_hunyuan_video,
    },
    "latte": {
        "prepare": prepare_latte,
        "decode": decode_latte,
    },
    "ltx_video": {
        "prepare": prepare_ltx_video,
        "decode": decode_ltx_video,
    },
    "mochi": {
        "prepare": prepare_mochi,
        "decode": decode_mochi,
    },
}

STR_TO_COMPUTE_DTYPE = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}


def run_inference(pipe, generation_kwargs):
    generator = torch.Generator().manual_seed(181201)
    output = pipe(generator=generator, output_type="latent", **generation_kwargs)[0]
    torch.cuda.synchronize()
    return output


@torch.no_grad()
def main(model_id: str, output_dir: str, dtype: str, offloading_type: str, num_blocks_per_group: int, use_stream: bool, compile: bool):
    if model_id not in MODEL_MAPPING.keys():
        raise ValueError("Unsupported `model_id` specified.")

    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    csv_filename = output_dir / f"{model_id}.csv"

    compute_dtype = STR_TO_COMPUTE_DTYPE[dtype]
    model = MODEL_MAPPING[model_id]
    reset_memory()

    try:
        # 1. Prepare inputs and generation kwargs
        pipe, generation_kwargs = model["prepare"](dtype=compute_dtype)

        # 2. Apply group offloading
        if offloading_type == "model":
            pipe.enable_model_cpu_offload()
        elif offloading_type == "sequential":
            pipe.enable_sequential_cpu_offload()
        elif offloading_type in ["block_level", "leaf_level"]:
            apply_group_offloading(
                pipe.transformer,
                offload_type=offloading_type,
                num_blocks_per_group=num_blocks_per_group,
                offload_device=torch.device("cpu"),
                onload_device=torch.device("cuda"),
                # force_offload=True for a more fair comparison against model offloading
                # If we set to True -> lower memory
                # If we set to False -> lower time required
                force_offload=True,
                non_blocking=True,
                use_stream=use_stream,
            )

        reset_memory()
        model_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        if compile:
            pipe.transformer = torch.compile(
                pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
            )

        # 3. Warmup
        num_warmups = 1
        original_num_inference_steps = generation_kwargs["num_inference_steps"]
        generation_kwargs["num_inference_steps"] = 2
        for _ in range(num_warmups):
            run_inference(pipe, generation_kwargs)
        generation_kwargs["num_inference_steps"] = original_num_inference_steps

        # 4. Benchmark
        time, latents = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 5. Decode latents
        filename = output_dir / f"{model_id}---dtype-{dtype}---offloading_type-{offloading_type}---num_blocks_per_group-{num_blocks_per_group}---use_stream-{use_stream}---compile-{compile}"
        filename = model["decode"](
            pipe,
            latents,
            filename,
            height=generation_kwargs["height"],
            width=generation_kwargs["width"],
            num_frames=generation_kwargs.get("num_frames", None),
            video_length=generation_kwargs.get("video_length", None),
        )

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "offloading_type": offloading_type,
            "use_stream": use_stream,
            "num_blocks": num_blocks_per_group,
            "time": time,
            "model_memory": model_max_memory_reserved,
            "inference_memory": inference_max_memory_reserved,
            "compile": compile,
            "compute_dtype": dtype,
            "branch": branch,
            "filename": filename,
            "exception": None,
        }

    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "offloading_type": offloading_type,
            "use_stream": use_stream,
            "num_blocks": num_blocks_per_group,
            "time": None,
            "model_memory": None,
            "inference_memory": None,
            "compile": compile,
            "compute_dtype": dtype,
            "branch": branch,
            "filename": None,
            "exception": str(e),
        }

    pretty_print_results(info, precision=3)

    df = pd.DataFrame([info])
    df.to_csv(csv_filename.as_posix(), mode="a", index=False, header=not csv_filename.is_file())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="flux",
        choices=["flux", "cogvideox-1.0", "latte", "allegro", "hunyuan_video", "mochi", "ltx_video"],
        help="Model to run benchmark for.",
    )
    parser.add_argument(
        "--output_dir", required=True, type=str, help="Path where the benchmark artifacts and outputs are the be saved."
    )
    parser.add_argument("--dtype", type=str, help="torch.dtype to use for inference")
    parser.add_argument("--offloading_type", type=str, default="none", choices=["none", "model", "block_level", "leaf_level"], help="Type of offloading to use.")
    parser.add_argument("--num_blocks_per_group", type=int, default=None, help="Number of layers per group for group offloading.")
    parser.add_argument("--use_stream", action="store_true", default=False, help="Whether to use CUDA streams for offloading.")
    parser.add_argument(
        "--compile",
        action="store_true",
        default=False,
        help="Whether to torch.compile the denoiser.",
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging.")
    args = parser.parse_args()

    if args.verbose:
        set_verbosity_debug()
    else:
        set_verbosity_info()

    main(
        args.model_id,
        args.output_dir,
        args.dtype,
        args.offloading_type,
        args.num_blocks_per_group,
        args.use_stream,
        args.compile,
    )

Some goals of this PR:

  • Opt-in choice to completely eliminate/hide device transfer latency where possible by overlapping computation and transfer. This usually comes with a slightly higher memory requirement than choosing not to hide the latency.
  • Fully compatible with torch.compile. A few recompiles are triggered; not really sure how to avoid them yet.

In a way, these changes can enable both enable_model_cpu_offload and enable_sequential_cpu_offload because of the way offload_group_patterns can be leveraged, but that is not the goal.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

a-r-r-o-w added the roadmap (Add to current release roadmap) label on Jan 9, 2025
@yiyixuxu (Collaborator) commented Jan 9, 2025

I think this fits well in the offloading I'm working on in modular diffusers

@yiyixuxu (Collaborator) commented Jan 9, 2025

Maybe we should consolidate a bit - I will separate the offloading part into its own PR

from .hooks import HookRegistry, ModelHook


_COMMON_STACK_IDENTIFIERS = {

Review comment (Collaborator):

I think it might be better to have this as an attribute within each model.

a-r-r-o-w (Member, Author) replied:

Actually, I will remove this completely. This should be applicable to any model containing a ModuleList or Sequential, because we know for sure, at least in Diffusers, that the call order of these layers is sequential and not some weird access pattern.

So the check will just look for the above two classes with isinstance (roughly as sketched below).
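For instance, something along these lines (a minimal sketch of the idea, not the PR's actual helper; the function name is hypothetical):

```python
import torch.nn as nn

def find_sequential_block_containers(model: nn.Module):
    # Collect submodules whose children are known to be invoked sequentially.
    # In Diffusers, ModuleList/Sequential stacks (e.g. transformer blocks) fit this.
    return [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, (nn.ModuleList, nn.Sequential))
    ]
```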

buffer.data = buffer.data.to(onload_device)


def _apply_group_offloading_group_patterns(

Review comment (Collaborator):

I think these can be consolidated into a single function and use the offload_group_pattern. If we add something like a _group_offload_modules to the Model class, we can just extend it with the offload_group_patterns argument here.

return module


class HookRegistry:

Review comment (Collaborator):

This looks good 👍🏽

@a-r-r-o-w (Member, Author) commented:

Some more numbers after the latest changes:

| model_id   | offloading_type | num_blocks | non_blocking | time   | model_memory | inference_memory | cuda_stream |
|------------|-----------------|------------|--------------|--------|--------------|------------------|-------------|
| ltx_video  | none            |            | False        | 36.852 | 4.85         | 6.475            | False       |
| ltx_video  | group           | 8          | True         | 53.5   | 1.205        | 3.531            | False       |
| ltx_video  | group           | 8          | True         | 36.715 | 0.848        | 4.787            | True        |
| ltx_video  | group           | 1          | True         | 52.181 | 1.205        | 2.811            | False       |
| ltx_video  | group           | 1          | True         | 36.611 | 0.848        | 2.818            | True        |

Continuing from our internal thread, we have a positive signal that sequential CPU offloading can be done without any hit to inference time when using CUDA streams for transfers.

a-r-r-o-w requested a review from DN6 on January 26, 2025
@a-r-r-o-w (Member, Author) commented:

After some discussion and re-iteration, the two main offloading strategies in this PR are now:

  • Taking a module with a stack of ModuleList/Sequential layers and being able to offload/onload "groups" of them when required
  • Taking a module and being able to offload/onload the leaf-level modules when required (similar to our current enable_sequential_cpu_offload)

The latter has minimal memory requirements but can be very slow. If layer prefetching is utilized, there is some time overhead, but not much if there is sufficient computation to overlap with (video models are a great use case, as are image models with bigger batch sizes).

The former is more beneficial for non-CUDA devices, since it allows offloading at the inner-module level. This helps reduce memory usage compared to keeping the entire model on the GPU (normal CPU offloading, aka enable_model_cpu_offload). Initially, I wanted to get to a stage where arbitrary groups of layers could be grouped together, but there's no real benefit to it, so a simpler implementation is now used (only ModuleLists and Sequentials). A block_level example is included with the code examples below.

From this, it might be apparent that the main target users are people with CUDA devices supporting streams - low memory usage without much overhead to generation time.

Also, if you have tested the PR before, you might find the latest version slightly faster and using a few hundred megabytes less :)

Some code examples:

Offloading LTX Transformer and Text encoder
import torch
from diffusers import LTXPipeline
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug

set_verbosity_debug()

pipe = LTXPipeline.from_pretrained("a-r-r-o-w/LTX-Video-diffusers", torch_dtype=torch.bfloat16)
pipe.vae.to("cuda")
pipe.vae.enable_tiling()

apply_group_offloading(
    pipe.text_encoder,
    offload_type="leaf_level",
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    use_stream=True,
)
apply_group_offloading(
    pipe.transformer,
    offload_type="leaf_level",
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    use_stream=True,
)

prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    width=704,
    height=480,
    num_frames=161,
    num_inference_steps=50,
).frames[0]
print(f"Max memory reserved: {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB")
export_to_video(video, "output.mp4", fps=24)
HunyuanVideo LoRA inference
import argparse
import os
from pathlib import Path

import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_debug
from diffusers.hooks.group_offloading import apply_group_offloading

set_verbosity_debug()


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--lora_path", type=str, default="none")
    parser.add_argument("--id_token", type=str, default="")
    parser.add_argument("--prompts_path", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--lora_strength", type=float, default=1.0)
    parser.add_argument("--height", type=int, default=320)
    parser.add_argument("--width", type=int, default=512)
    parser.add_argument("--num_frames", type=int, default=61)
    parser.add_argument("--num_inference_steps", type=int, default=30)
    return parser.parse_args()


def string_to_filename(x):
    return x.replace(" ", "_").replace(",", "").replace(".", "").replace(":", "").replace(";", "").replace("!", "").replace("?", "").replace("'", "").replace('"', "")


args = get_args()
output_dir = Path(args.output_dir)

output_dir.mkdir(parents=True, exist_ok=True)

model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
if args.lora_path != "none":
    pipe.load_lora_weights(args.lora_path, adapter_name="hunyuan-lora")
    pipe.set_adapters("hunyuan-lora", args.lora_strength)
pipe.vae.enable_tiling()

pipe.text_encoder.to("cuda")
pipe.text_encoder_2.to("cuda")
pipe.vae.to("cuda")
apply_group_offloading(
    pipe.transformer,
    offload_type="leaf_level",
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    force_offload=True,
    non_blocking=True,
    use_stream=True,
)

with open(args.prompts_path) as file:
    prompts = [line.strip() for line in file if len(line.strip()) > 0]

for prompt in prompts:
    if args.id_token:
        prompt = f"{args.id_token} {prompt}"
    print(prompt)
    output = pipe(
        prompt=prompt,
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        num_inference_steps=args.num_inference_steps,
        generator=torch.Generator().manual_seed(42),
    ).frames[0]
    filename = string_to_filename(prompt)[:25]
    filename = f"{filename}---lora_strength-{args.lora_strength}---height-{args.height}---width-{args.width}---num_frames-{args.num_frames}---num_inference_steps-{args.num_inference_steps}"
    filepath = output_dir / f"{filename}.mp4"
    export_to_video(output, filepath.as_posix(), fps=15)

@DN6 @yiyixuxu Could you give this a review?
