Tensor-parallel like FeedForward to lower memory requirements #10623

Open
wants to merge 2 commits into main

Conversation

a-r-r-o-w
Member

In tensor parallelism, we split the internal layers of modules either column-wise or row-wise across multiple GPUs, perform individual forward passes for each split, and then all_reduce with a sum to gather outputs. We can do the same thing sequentially on a single GPU. Doing so:

  • does not reduce memory usage from model weights ❌
  • does reduce memory usage from intermediate activations ✅ (goal of this PR)
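
To make the equivalence concrete, here is a minimal, self-contained sketch in plain PyTorch (not the PR's implementation): split the input projection column-wise and the output projection row-wise, run the splits one after another, and sum the partial outputs. This reproduces the full feed-forward because the activation is elementwise and the output projection is a sum over the hidden dimension.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, inner_dim, num_splits = 64, 256, 4
x = torch.randn(2, 8, dim)
w1 = torch.randn(inner_dim, dim)   # proj-in weight (bias omitted for brevity)
w2 = torch.randn(dim, inner_dim)   # proj-out weight

# Full feed-forward: one large (..., inner_dim) intermediate tensor.
full = F.gelu(x @ w1.t()) @ w2.t()

# Sequential "tensor-parallel" version: each step only materializes an
# intermediate of width inner_dim // num_splits, and the partial outputs are accumulated.
partial = torch.zeros_like(full)
for w1_i, w2_i in zip(w1.chunk(num_splits, dim=0), w2.chunk(num_splits, dim=1)):
    partial += F.gelu(x @ w1_i.t()) @ w2_i.t()

print((full - partial).abs().max())  # tiny; only the floating-point accumulation order differs
```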

Typically, there is a 4x-8x expansion in the intermediate hidden dimension of the FFs, which results in an intermediate tensor 4x-8x larger than the input tensor. Given that we now have large models like HunyuanVideo, the sequence length can be very large (> 2**14 for decent frames/resolution), so FFs end up allocating a lot of additional memory. This is much worse for training without gradient checkpointing (or partial gradient checkpointing #10611), I think, because all intermediate tensors need to be stored. We can, however, avoid allocating a large intermediate tensor.
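
As a rough back-of-the-envelope estimate (illustrative numbers only, assuming bf16 activations and a 4x expansion, not measured on any particular model):

```python
# Size of the intermediate hidden tensor created by one feed-forward, in bf16.
batch, seq_len, dim, mult = 1, 2**14, 4096, 4
inner_dim = dim * mult
bytes_per_elem = 2  # bf16
gib = batch * seq_len * inner_dim * bytes_per_elem / 1024**3
print(f"{gib:.2f} GiB")  # 0.50 GiB for a single FF's hidden activation, per layer
```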

There are some gotchas, however. Applying this directly to an arbitrary model will most likely not show any memory savings. For this to have any effect, we first need to optimize the memory usage of a model to the point where the FeedForward hidden-dim expansion actually starts to affect the peak memory required. There are many ways to reach the tipping point where FFs end up causing the peaks (sometimes several of the techniques below need to be combined to get there):

  • Quantization or FP8 Layerwise-upcasting
  • Memory-optimized attention
  • Chunked feed-forward (we will be deprecating the way we do it currently soon)
  • Split-Inference (reference is FreeNoise here)
  • CPU/Sequential offloading of internal/leaf modules - Group offloading (see Module Group Offloading #10503)
  • ... other things I'm forgetting

With the upcoming group offloading support, we know that VRAM usage can be reduced significantly without any penalty to throughput, given adequate CPU RAM. This reduces the memory peaks from model weights. The spiky points on the memory trace are now mainly due to either:

  • Default SDPA backend - can be somewhat flattened with the flash attention backend, or by using the flash-attn library directly with a custom attention processor
  • FeedForward - this was also a huge problem when we added FreeNoise, but split-inference helped improve things back then

We can reduce these memory peaks by borrowing ideas from tensor and sequence parallelism and applying them sequentially in the single-GPU case. This PR implements the tensor-parallel equivalent of FFs. I will do a follow-up for the sequence-parallelism-based optimization in the near future, after some more important things are taken care of.

Onto the benchmark numbers!

Benchmark
import gc

import torch

from diffusers.models.attention import FeedForward
from diffusers.models.memory_utils import apply_memory_optimized_feedforward
from diffusers.utils.logging import set_verbosity_debug

# set_verbosity_debug()


class DummyModel(torch.nn.Module):
    def __init__(self, dim: int, dim_out: int, mult: int, activation_fn: str, num_layers: int = 5):
        super().__init__()
        
        self.blocks = torch.nn.ModuleList([
            FeedForward(dim=dim, dim_out=dim_out, mult=mult, activation_fn=activation_fn)
            for _ in range(num_layers)
        ])
    
    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output


def reset_memory():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()


ACTIVATION_FUNCTIONS = ["gelu", "gelu-approximate", "geglu", "geglu-approximate", "swiglu", "linear-silu"]

batch_size = 2
sequence_length = 8192
dim = 4096
mult = 4
num_layers = 5

@torch.no_grad()
def main(act_fn, device="cuda"):
    model = DummyModel(dim=dim, dim_out=dim, mult=mult, activation_fn=act_fn, num_layers=num_layers)
    model.to(device)

    input = torch.randn(batch_size, sequence_length, dim, device=device)
    
    # Benchmark normal FF
    model(input)  # warmup
    reset_memory()
    elapsed_time, output1 = benchmark_fn(model, input)
    max_memory_allocated = torch.cuda.max_memory_allocated() / 1024 ** 3
    print(f"Normal FF: {elapsed_time=:.2f}s, {max_memory_allocated=:.2f}GB")

    # Benchmark memory optimized FF
    model(input)  # warmup
    reset_memory()
    model = apply_memory_optimized_feedforward(model)
    elapsed_time, output2 = benchmark_fn(model, input)
    max_memory_allocated = torch.cuda.max_memory_allocated() / 1024 ** 3
    print(f"Memory optimized FF: {elapsed_time=:.2f}s, {max_memory_allocated=:.2f}GB")

    diff = output1 - output2
    absmax = diff.abs().max()
    absmean = diff.abs().mean()
    print(f"Output comparison: {absmax=:.5f}, {absmean=:.5f}")


if __name__ == "__main__":
    for act_fn in ACTIVATION_FUNCTIONS:
        # for device in ["cuda", "cpu"]:
        for device in ["cuda"]:
            print(f"Activation function: {act_fn}, device={device}")
            main(act_fn, device=device)
            print()

Benchmark results:

Activation function: gelu, device=cuda
Normal FF: elapsed_time=1.29s, max_memory_allocated=5.01GB
Memory optimized FF: elapsed_time=1.40s, max_memory_allocated=4.26GB
Output comparison: absmax=0.00000, absmean=0.00000

Activation function: gelu-approximate, device=cuda
Normal FF: elapsed_time=1.29s, max_memory_allocated=5.01GB
Memory optimized FF: elapsed_time=1.28s, max_memory_allocated=4.26GB
Output comparison: absmax=0.00000, absmean=0.00000

Activation function: geglu, device=cuda
Normal FF: elapsed_time=1.92s, max_memory_allocated=8.26GB
Memory optimized FF: elapsed_time=1.92s, max_memory_allocated=6.01GB
Output comparison: absmax=0.00015, absmean=0.00003

Activation function: geglu-approximate, device=cuda
Normal FF: elapsed_time=1.30s, max_memory_allocated=6.01GB
Memory optimized FF: elapsed_time=1.70s, max_memory_allocated=4.51GB
Output comparison: absmax=0.00000, absmean=0.00000

Activation function: swiglu, device=cuda
Normal FF: elapsed_time=1.92s, max_memory_allocated=8.26GB
Memory optimized FF: elapsed_time=1.92s, max_memory_allocated=6.01GB
Output comparison: absmax=0.00014, absmean=0.00003

Activation function: linear-silu, device=cuda
Normal FF: elapsed_time=1.29s, max_memory_allocated=5.01GB
Memory optimized FF: elapsed_time=1.29s, max_memory_allocated=4.26GB
Output comparison: absmax=0.00000, absmean=0.00000

The same results in a table, to make them easier to parse:

| Activation Function | Type | Elapsed Time | Max Memory Allocated |
|---|---|---|---|
| gelu | Normal FF | 1.29s | 5.01GB |
| gelu | Memory Optimized FF | 1.40s | 4.26GB |
| gelu-approximate | Normal FF | 1.29s | 5.01GB |
| gelu-approximate | Memory Optimized FF | 1.28s | 4.26GB |
| geglu | Normal FF | 1.92s | 8.26GB |
| geglu | Memory Optimized FF | 1.92s | 6.01GB |
| geglu-approximate | Normal FF | 1.30s | 6.01GB |
| geglu-approximate | Memory Optimized FF | 1.70s | 4.51GB |
| swiglu | Normal FF | 1.92s | 8.26GB |
| swiglu | Memory Optimized FF | 1.92s | 6.01GB |
| linear-silu | Normal FF | 1.29s | 5.01GB |
| linear-silu | Memory Optimized FF | 1.29s | 4.26GB |
Minimal reproducer with memory trace
import gc
import torch
from diffusers.models.attention import FeedForward


class MemoryOptimizedFeedForward(torch.nn.Module):
    def __init__(self, dim, dim_out, mult, activation_fn, num_splits=4):
        super().__init__()
        self.dim = dim
        self.dim_out = dim_out
        self.mult = mult
        self.activation_fn = activation_fn
        self.num_splits = num_splits
        self.inner_dim = dim * mult

        # Column-wise split of the input projection and row-wise split of the output
        # projection; this reproducer only handles the "gelu" activation.
        self.proj = torch.nn.ModuleList([
            torch.nn.Linear(dim, self.inner_dim // num_splits) for _ in range(num_splits)
        ])
        self.act_fn = torch.nn.GELU()
        self.out = torch.nn.ModuleList([
            torch.nn.Linear(self.inner_dim // num_splits, dim_out, bias=False) for _ in range(num_splits)
        ])
        # The output bias must be added once, not once per split, so it lives outside the loop.
        self.bias = torch.nn.Parameter(torch.zeros(dim_out))

    def forward(self, x):
        outputs = x.new_zeros((*x.shape[:-1], self.dim_out))
        for i in range(self.num_splits):
            # Each split only materializes an intermediate of width inner_dim // num_splits.
            out = self.proj[i](x)
            out = self.act_fn(out)
            out = self.out[i](out)
            outputs += out
        outputs += self.bias
        return outputs


def reset_memory():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()


def run_forward(model, *args):
    output = model(*args)
    return output

torch.cuda.memory._record_memory_history()

dim = 4096
mult = 4
ff_normal = FeedForward(dim=dim, mult=mult, dim_out=dim, activation_fn="gelu")
ff_split = MemoryOptimizedFeedForward(dim=dim, dim_out=dim, mult=mult, activation_fn="gelu", num_splits=mult)

for i, layer in enumerate(ff_split.proj):
    layer.weight.data.copy_(ff_normal.net[0].proj.weight.data.chunk(ff_split.num_splits, dim=0)[i])
    layer.bias.data.copy_(ff_normal.net[0].proj.bias.data.chunk(ff_split.num_splits, dim=0)[i])

for i, layer in enumerate(ff_split.out):
    layer.weight.data.copy_(ff_normal.net[2].weight.data.chunk(ff_split.num_splits, dim=1)[i])

ff_split.bias.data.copy_(ff_normal.net[2].bias.data)

device = "cuda"
ff_normal.to(device)
ff_split.to(device)

batch_size = 2
sequence_length = 8192
input = torch.randn((batch_size, sequence_length, dim), device=device)

with torch.no_grad():
    reset_memory()
    output1 = run_forward(ff_normal, input)
    print(f"Max memory reserved: {torch.cuda.max_memory_reserved() / 1024 ** 3:.3f} GB")
    
    reset_memory()
    output2 = run_forward(ff_split, input)
    print(f"Max memory reserved: {torch.cuda.max_memory_reserved() / 1024 ** 3:.3f} GB")

    diff = output1 - output2
    print(diff.abs().max(), diff.abs().mean())

torch.cuda.memory._dump_snapshot("ff.pickle")

Results:

Max memory reserved: 3.271 GB
Max memory reserved: 2.271 GB
tensor(6.7353e-06, device='cuda:0') tensor(3.3257e-07, device='cuda:0')
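
(The ff.pickle snapshot dumped above can be loaded into PyTorch's memory visualizer at https://pytorch.org/memory_viz to inspect where the allocation peaks come from.)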

@DN6 @yiyixuxu I'd like a first-pass review before making further changes, to gather feedback on what should change. I can add docs and think about how to expose a single API for applying memory optimizations so that we don't confuse users.

cc @bghira (As an admirer and power-user of SimpleTuner, I think some of the latest optimizations [group offloading with CUDA stream prefetching + this] will benefit training as well. Would love to see how low we can go when training the biggest available models with negligible impact on speed.)

a-r-r-o-w requested review from DN6 and yiyixuxu on January 21, 2025 at 22:57
# Apply feedforward pass sequentially since this is intended for memory optimization on a single GPU
for proj_in, proj_out in zip(self.proj_in, self.proj_out):
    out = proj_in(hidden_states)
    out = self.dropout(out)
a-r-r-o-w (Member Author):

I think dropout is probably incorrect here. Since we split the embed dimension, applying dropout to each split will cause num_splits times more features to be dropped. I think dividing the original dropout rate by num_splits should have an effect equivalent to the normal feedforward 🤔

