Tensor-parallel like FeedForward to lower memory requirements #10623
In tensor parallelism, we split the internal layers of modules either column-wise or row-wise across multiple GPUs, perform individual forward passes for each split, and then all_reduce with a sum to combine the outputs. We can do the same thing sequentially on a single GPU, which avoids ever materializing the full intermediate tensor at once.
Typically, there is a 4x-8x expansion in the intermediate hidden dimension of the FFs, so the intermediate tensor is 4x-8x larger than the input tensor. With large models like HunyuanVideo, the sequence length can also be very large (> 2**14 for decent frames/resolution), so the FFs end up allocating a lot of additional memory -- this is much worse for training without gradient checkpointing (or partial gradient checkpointing, #10611), I think, because all the intermediate tensors need to be stored. We can, however, avoid allocating the large intermediate tensor altogether.
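To make the idea concrete, here is a minimal sketch (not the implementation in this PR) of a plain two-linear FeedForward whose first projection is split column-wise and whose second projection is split row-wise, with the shards processed sequentially and their partial outputs summed. The class name, `num_splits`, and the GELU activation are illustrative assumptions; the actual `FeedForward` in diffusers also supports gated activations like GEGLU.

```python
# Minimal sketch: sequential "tensor-parallel" FeedForward on a single GPU.
# The first projection is split column-wise (over output features) and the
# second row-wise (over input features); partial outputs are summed, so only
# a (batch, seq, inner_dim / num_splits) intermediate exists at any time.
import torch
import torch.nn as nn
import torch.nn.functional as F


class SequentialTPFeedForward(nn.Module):
    def __init__(self, dim: int, mult: int = 4, num_splits: int = 4) -> None:
        super().__init__()
        inner_dim = dim * mult
        assert inner_dim % num_splits == 0
        self.num_splits = num_splits
        self.proj_in = nn.Linear(dim, inner_dim)
        self.proj_out = nn.Linear(inner_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        chunk = self.proj_in.out_features // self.num_splits
        out = torch.zeros(*x.shape[:-1], self.proj_out.out_features, dtype=x.dtype, device=x.device)
        for i in range(self.num_splits):
            s = slice(i * chunk, (i + 1) * chunk)
            # Column shard of proj_in -> small intermediate (inner_dim / num_splits wide).
            h = F.gelu(F.linear(x, self.proj_in.weight[s], self.proj_in.bias[s]))
            # Matching row shard of proj_out; the partial results sum to the full output
            # (the single-GPU analogue of the all_reduce-sum in tensor parallelism).
            out = out + F.linear(h, self.proj_out.weight[:, s])
        return out + self.proj_out.bias
```

Because GELU is applied element-wise, the chunked result matches the unsplit `proj_out(gelu(proj_in(x)))` up to floating-point accumulation order, while the large intermediate activation shrinks by a factor of `num_splits`.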
There are some gotchas, however. Applying this directly to an arbitrary model will most likely not show any memory savings. For it to have any effect, we first need to optimize the memory usage of the model to the point where the `FeedForward` hidden-dim expansion actually starts to dominate the peak memory required. There are many ways to reach that tipping point where the FFs end up causing the peaks (sometimes multiple techniques, such as the offloading mentioned here, need to be combined to get there). With the latest group offloading support upcoming, we know that VRAM usage can be reduced significantly without any penalty to throughput, given adequate CPU RAM. That removes the memory peaks coming from model weights; the remaining spiky points on the memory trace are then due to intermediate activations, in particular the FF expansion addressed here and the peaks that the sequence-parallelism-style follow-up mentioned below will target.
We can reduce these memory peaks by borrowing ideas from tensor and sequence parallelism and applying them sequentially in the single-GPU case. This PR implements the tensor-parallel equivalent for FFs. I will do a follow-up for the sequence-parallelism-based optimization in the near future, after some more important things are taken care of.
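For anyone who wants to check where their own peaks come from, the snippet below (illustrative, not part of this PR) records a memory trace with PyTorch's built-in snapshot tooling and compares the peak allocation of a plain FF against the `SequentialTPFeedForward` sketch above; the sizes are placeholder values, not the exact benchmark configuration.

```python
# Illustrative only: record a memory trace and compare peak allocation of a
# plain FF vs. the sequential-split sketch above, to see whether the
# intermediate hidden-dim expansion is what causes the spikes.
import torch
import torch.nn as nn

torch.cuda.memory._record_memory_history(max_entries=100_000)  # optional: full allocation trace

dim, mult, seq_len = 3072, 4, 2**14  # placeholder sizes for illustration
x = torch.randn(1, seq_len, dim, device="cuda", dtype=torch.bfloat16)

plain_ff = nn.Sequential(
    nn.Linear(dim, dim * mult), nn.GELU(), nn.Linear(dim * mult, dim)
).to("cuda", torch.bfloat16)

torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    _ = plain_ff(x)
print(f"plain FF peak:   {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")

chunked_ff = SequentialTPFeedForward(dim, mult=mult, num_splits=4).to("cuda", torch.bfloat16)

torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    _ = chunked_ff(x)
print(f"chunked FF peak: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")

torch.cuda.memory._dump_snapshot("ff_memory_trace.pickle")  # inspect at https://pytorch.org/memory_viz
```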
Onto the benchmark numbers!
Benchmark
Benchmark results:
To make it easier to parse, here are the results as a table:
Minimal reproducer with memory trace
Results:
@DN6 @yiyixuxu I would like a first-pass review before making further changes, to gather feedback on what should be changed. I can then add docs and think about how to expose a single API for applying memory optimizations so that we don't confuse users.
cc @bghira (as an admirer and power-user of SimpleTuner, I think some of the latest optimizations will benefit training as well [group offloading with CUDA stream prefetching + this]. Would love to see how low we can go when training the biggest available models, with negligible impact on speed.)