Enabling FP8 all-gather for TE Float8Tensor when using Torch FSDP2 #1358
Merged
Commits (20)
47ad24a  draft implementation of fsdp2 fp8 all gather  (youngeunkwon0405)
2f4c102  fix the convergence issue  (youngeunkwon0405)
03a98c0  Merge branch 'main' into fsdp2  (youngeunkwon0405)
76ff010  Add warning  (youngeunkwon0405)
aed545b  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot])
8d81f56  Merge branch 'main' into fsdp2  (youngeunkwon0405)
38e060d  disable lint error  (youngeunkwon0405)
f01245e  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot])
4e7694d  fix the lint error  (youngeunkwon0405)
ff6d1d6  fix lint error  (youngeunkwon0405)
16eb7b3  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot])
74f4d17  fix lint error  (youngeunkwon0405)
fb7690d  Merge branch 'fsdp2' of github.com:youngeunkwon0405/TransformerEngine…  (youngeunkwon0405)
aeb851f  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot])
941dbcb  fix lint error  (youngeunkwon0405)
daba5a6  add comments  (youngeunkwon0405)
689e30a  add ref  (youngeunkwon0405)
7ecfe04  add related tests  (youngeunkwon0405)
e4cf960  [pre-commit.ci] auto fixes from pre-commit.com hooks  (pre-commit-ci[bot])
5c1f189  Merge branch 'main' into fsdp2  (youngeunkwon0405)
@@ -0,0 +1,181 @@ (new file: run_fsdp2_model.py, the toy FSDP2 training script launched by the test below)
#!/usr/bin/python3

# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import sys
import argparse

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn, optim
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from contextlib import nullcontext


class SimpleNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleNet, self).__init__()
        self.fc1 = te.Linear(input_size, hidden_size)
        self.fc2 = te.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


def save_custom_attrs(module):
    custom_attrs = {}
    for name, param in module.named_parameters():
        attrs = vars(param)
        custom_attrs[name] = {k: v for k, v in attrs.items()}
    return custom_attrs


def restore_custom_attrs(module, custom_attrs):
    for name, param in module.named_parameters():
        if name in custom_attrs:
            for attr_name, attr_value in custom_attrs[name].items():
                setattr(param, attr_name, attr_value)


def _parse_args(argv=None, namespace=None):
    parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
    parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model")
    parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
    parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
    parser.add_argument("--batch-size", type=int, default=2048, help="Batch size for the input data")
    parser.add_argument(
        "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
    )
    parser.add_argument(
        "--iter", type=int, default=10, help="Number of iterations for forward pass"
    )
    parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
    # Sharding dimensions are passed as a space-separated list of ints
    parser.add_argument(
        "--sharding-dims",
        type=int,
        nargs="+",
        help='FSDP/HSDP sharding dimensions ("replicate", "shard")',
    )
    args = parser.parse_args(argv, namespace)
    if args.sharding_dims:
        assert len(args.sharding_dims) <= 2
    return args


sub_modules_to_wrap = [te.Linear]


def _train(args):
    assert "TORCHELASTIC_RUN_ID" in os.environ
    WORLD_RANK = int(os.getenv("RANK", "0"))
    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
    LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
    LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
    assert LOCAL_SIZE == WORLD_SIZE

    # Set device and initialize RNG states
    torch.cuda.set_device(WORLD_RANK)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Initialize torch.distributed global process group and get DP/TP groups
    dist_init_kwargs = {
        "backend": "nccl",
        "rank": WORLD_RANK,
        "world_size": WORLD_SIZE,
    }
    assert dist.is_nccl_available()
    dist.init_process_group(**dist_init_kwargs)
    nccl_world = dist.new_group(backend="nccl")
    device = torch.device(f"cuda:{LOCAL_RANK}")

    # FP8 configuration
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
    # Build model context (FP8 primary-weight init is opt-in via --fp8-init)
    build_model_context = nullcontext
    build_model_context_args = {}
    if args.fp8_init:
        from transformer_engine.pytorch import fp8_model_init

        build_model_context = fp8_model_init
        build_model_context_args["enabled"] = True

    # Build the model with the specified context
    with build_model_context(**build_model_context_args):
        model = SimpleNet(args.input_size, args.hidden_size, args.output_size)

    # Move the model to the correct device
    model.to(device)
    if LOCAL_RANK == 0:
        print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...")
    # Creating a DeviceMesh for fully_shard
    world_size = int(WORLD_SIZE)
    device_ids = list(range(world_size))
    if LOCAL_RANK == 0:
        print(f"sharding-dims:{args.sharding_dims}")
    # Setup the sharding mesh for FSDP/HSDP
    if args.sharding_dims is None:  # FSDP
        mesh = DeviceMesh("cuda", device_ids)
    elif len(args.sharding_dims) == 1:
        assert args.sharding_dims[0] == device_ids[-1] + 1
        mesh = DeviceMesh("cuda", device_ids)
    elif len(args.sharding_dims) == 2:  # HSDP
        assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1
        mesh = init_device_mesh(
            "cuda",
            (args.sharding_dims[0], args.sharding_dims[1]),
            mesh_dim_names=("replicate", "shard"),
        )
    else:
        assert False

    # Apply FSDP/HSDP
    custom_attrs = save_custom_attrs(model)
    for sub_module in model.modules():
        if any(
            isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap
        ):
            fully_shard(sub_module, mesh=mesh)
    fully_shard(model, mesh=mesh)
    restore_custom_attrs(model, custom_attrs)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for iteration in range(args.iter):
        # Zero the parameter gradients
        optimizer.zero_grad()
        input_data = torch.randn(args.batch_size, args.input_size).to(device)
        output = model(input_data)
        target = torch.randn(args.batch_size, args.output_size).to(device)
        loss = F.mse_loss(output, target)
        loss.backward()
        optimizer.step()
        if LOCAL_RANK == 0:
            print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")

    dist.destroy_process_group()
    if LOCAL_RANK == 0:
        print(f"Rank {LOCAL_RANK}: Done...")
    return 0


if __name__ == "__main__":
    sys.exit(_train(_parse_args()))
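For reference, the test below launches this script through `torchrun`. An equivalent manual invocation on a single node (assuming 4 GPUs) would be `torchrun --nproc_per_node=4 run_fsdp2_model.py --fp8-init --sharding-dims 2 2` for HSDP with a 2x2 (replicate, shard) mesh, or `--sharding-dims 4` (or omitting `--sharding-dims` entirely) for plain FSDP.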
@@ -0,0 +1,67 @@ (new file: the pytest wrapper that launches the script above under torchrun)
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import pytest
import subprocess
from pathlib import Path
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import torch
from packaging.version import Version as PkgVersion


def get_torch_version():
    """Get pytorch version from __version__"""

    def get_torch_version_str():
        import torch

        return str(torch.__version__)

    return PkgVersion(get_torch_version_str())


if torch.cuda.device_count() < 4:
    pytest.skip("FSDP2 test requires at least 4 GPUs.", allow_module_level=True)

if torch.cuda.device_count() % 2 != 0:
    pytest.skip("Number of devices should be divisible by 2.", allow_module_level=True)

if not get_torch_version() >= PkgVersion("2.4"):
    pytest.skip("FSDP2 requires PyTorch >= 2.4.0 with FSDP 2 support.", allow_module_level=True)

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = torch.cuda.device_count()
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]


def _run_test(fp_init, sharding_dims):
    test_path = TEST_ROOT / "run_fsdp2_model.py"
    test_cmd = LAUNCH_CMD + [str(test_path)]

    if fp_init:
        test_cmd += ["--fp8-init"]
    if len(sharding_dims) == 1:
        test_cmd += ["--sharding-dims", str(sharding_dims[0])]
    elif len(sharding_dims) == 2:
        test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])]
    else:
        assert False
    result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
    if result.returncode != 0:
        raise AssertionError(result.stderr.decode())


all_boolean = [True, False]
sharding_dims = [[NUM_PROCS], [2, NUM_PROCS // 2]]


@pytest.mark.parametrize("sharding_dims", sharding_dims)
@pytest.mark.parametrize("fp8_init", all_boolean)
def test_distributed(fp8_init, sharding_dims):
    if fp8_init and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    _run_test(fp8_init, sharding_dims)
Having `_to_copy` in this list puts us in a weird position. @youngeunkwon0405 Where did you get this list of ops, and could we figure out a way to remove `_to_copy`?

The trouble is that we are implicitly using `torch.Tensor.to` as a dequantize function, so we always expect the `_to_copy` op to output a plain PyTorch tensor. The reason for this design was to work with Mcore's logic for maintaining FP32 master weights (see the logic for DDP and distopt). With this PR, we now see many spurious errors whenever we dequantize an FP8 tensor with `to`/`float`/`half`/etc.

If the current implementation of `_to_copy` leads to insurmountable problems with FSDP2, we'll probably need to remove the implicit dequantization and change Mcore so that it explicitly calls `Float8Tensor.dequantize`.
I was following Torch AO's implementation:
https://github.com/pytorch/ao/blob/main/torchao/float8/fsdp_utils.py#L86-L99

I checked that it is currently okay for `_to_copy` not to preserve the tensor class, but I left the warning in for future reference. This PR does not change the functional behavior of the Float8Tensor `_to_copy`; it only adds a warning here.
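For readers who want to see the mechanism under discussion in isolation, below is a minimal sketch written for this thread, not taken from TransformerEngine: the `ToyQuantizedTensor` class, its fields, and the warning text are hypothetical. It shows how a quantized tensor subclass can intercept `aten._to_copy` in `__torch_dispatch__`, implicitly dequantize to a plain `torch.Tensor` (the behavior `.to()`/`.float()`/`.half()` rely on), and emit the kind of warning this PR adds when the tensor class is not preserved:

# Minimal, illustrative sketch only -- NOT TransformerEngine's Float8Tensor.
# ToyQuantizedTensor and its fields are hypothetical stand-ins for a quantized
# tensor subclass, used to show the _to_copy interception pattern discussed above.
import warnings

import torch
from torch.utils._pytree import tree_map


class ToyQuantizedTensor(torch.Tensor):
    """Toy wrapper subclass holding a low-precision payload plus a scale."""

    @staticmethod
    def __new__(cls, data, scale):
        # The wrapper advertises the dequantized dtype/shape to callers.
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=torch.float32, device=data.device
        )

    def __init__(self, data, scale):
        self._data = data    # low-precision payload (FP8 in a real implementation)
        self._scale = scale

    def dequantize(self):
        return self._data.float() * self._scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(a):
            return a.dequantize() if isinstance(a, cls) else a

        if func is torch.ops.aten._to_copy.default:
            # `.to()` / `.float()` / `.half()` implicitly dequantize and return a
            # plain torch.Tensor; warn so the lost tensor class does not go unnoticed.
            warnings.warn("_to_copy does not preserve the quantized tensor class")
        # Run the op on dequantized, plain-tensor arguments.
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))


t = ToyQuantizedTensor(torch.randint(0, 8, (4,), dtype=torch.uint8), scale=0.5)
print(type(t.float()))  # <class 'torch.Tensor'>, after the warning fires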