Enabling FP8 all-gather for TE Float8Tensor when using Torch FSDP2 #1358

Merged
merged 20 commits on Dec 16, 2024
1 change: 1 addition & 0 deletions qa/L1_pytorch_distributed_unittest/test.sh
@@ -11,4 +11,5 @@ pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
181 changes: 181 additions & 0 deletions tests/pytorch/distributed/run_fsdp2_model.py
@@ -0,0 +1,181 @@
#!/usr/bin/python3

# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import sys
import argparse

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn, optim
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
from contextlib import nullcontext


class SimpleNet(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(SimpleNet, self).__init__()
self.fc1 = te.Linear(input_size, hidden_size)
self.fc2 = te.Linear(hidden_size, output_size)

def forward(self, x):
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x


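# fully_shard() re-wraps module parameters (as DTensor-backed parameters), which drops any
# Python-level attributes TE attaches to its weights, so we snapshot those attributes before
# sharding and restore them afterwards.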
def save_custom_attrs(module):
custom_attrs = {}
for name, param in module.named_parameters():
attrs = vars(param)
custom_attrs[name] = {k: v for k, v in attrs.items()}
return custom_attrs


def restore_custom_attrs(module, custom_attrs):
for name, param in module.named_parameters():
if name in custom_attrs:
for attr_name, attr_value in custom_attrs[name].items():
setattr(param, attr_name, attr_value)


def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model")
parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model")
parser.add_argument(
"--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
)
parser.add_argument(
"--iter", type=int, default=10, help="Number of iterations for forward pass"
)
parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
    # FSDP/HSDP sharding dimensions, given as one or two space-separated integers
    parser.add_argument(
        "--sharding-dims",
        type=int,
        nargs="+",
        help=(
            "FSDP/HSDP sharding dimensions: one value (the world size) for FSDP, "
            'or two values ("replicate", "shard") for HSDP'
        ),
    )
args = parser.parse_args(argv, namespace)
if args.sharding_dims:
assert len(args.sharding_dims) <= 2
return args


sub_modules_to_wrap = [te.Linear]


def _train(args):
assert "TORCHELASTIC_RUN_ID" in os.environ
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
assert LOCAL_SIZE == WORLD_SIZE

# Set device and initialize RNG states
torch.cuda.set_device(WORLD_RANK)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

# Initialize torch.distributed global process group and get DP/TP groups
dist_init_kwargs = {
"backend": "nccl",
"rank": WORLD_RANK,
"world_size": WORLD_SIZE,
}
assert dist.is_nccl_available()
dist.init_process_group(**dist_init_kwargs)
nccl_world = dist.new_group(backend="nccl")
device = torch.device(f"cuda:{LOCAL_RANK}")

# FP8 Configuration
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")

    # Build model context: fp8_model_init when primary weights should be stored in FP8,
    # otherwise a no-op context
    build_model_context = nullcontext
    build_model_context_args = {}
    if args.fp8_init:
        from transformer_engine.pytorch import fp8_model_init

        build_model_context = fp8_model_init
        build_model_context_args["enabled"] = True

    # Build the model with the specified context
    with build_model_context(**build_model_context_args):
        model = SimpleNet(args.input_size, args.hidden_size, args.output_size)

    # Move the model to the correct device
    model.to(device)

if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...")
# Creating a DeviceMesh for fully_shard
world_size = int(WORLD_SIZE)
device_ids = list(range(world_size))
if LOCAL_RANK == 0:
print(f"sharding-dims:{args.sharding_dims}")
# Setup the sharding mesh for FSDP/HSDP
    if args.sharding_dims is None:  # FSDP
mesh = DeviceMesh("cuda", device_ids)
elif len(args.sharding_dims) == 1:
assert args.sharding_dims[0] == device_ids[-1] + 1
mesh = DeviceMesh("cuda", device_ids)
elif len(args.sharding_dims) == 2: # HSDP
assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1
mesh = init_device_mesh(
"cuda",
(args.sharding_dims[0], args.sharding_dims[1]),
mesh_dim_names=("replicate", "shard"),
)
else:
assert False

# Apply FSDP/HSDP
custom_attrs = save_custom_attrs(model)
for sub_module in model.modules():
if any(
isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap
):
fully_shard(sub_module, mesh=mesh)
fully_shard(model, mesh=mesh)
restore_custom_attrs(model, custom_attrs)

optimizer = optim.Adam(model.parameters(), lr=1e-3)

for iteration in range(args.iter):
# Zero the parameter gradients
optimizer.zero_grad()
input_data = torch.randn(args.batch_size, args.input_size).to(device)
output = model(input_data)
target = torch.randn(args.batch_size, args.output_size).to(device)
loss = F.mse_loss(output, target)
loss.backward()
optimizer.step()
if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")

dist.destroy_process_group()
if LOCAL_RANK == 0:
print(f"Rank {LOCAL_RANK}: Done...")
return 0


if __name__ == "__main__":
sys.exit(_train(_parse_args()))
67 changes: 67 additions & 0 deletions tests/pytorch/distributed/test_torch_fsdp2.py
@@ -0,0 +1,67 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import pytest
import subprocess
from pathlib import Path
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
import torch
from packaging.version import Version as PkgVersion


def get_torch_version():
"""Get pytorch version from __version__"""

def get_torch_version_str():
import torch

return str(torch.__version__)

return PkgVersion(get_torch_version_str())


if torch.cuda.device_count() < 4:
    pytest.skip("FSDP2 test requires at least 4 GPUs.", allow_module_level=True)

if torch.cuda.device_count() % 2 != 0:
    pytest.skip("Number of GPUs must be divisible by 2.", allow_module_level=True)

if get_torch_version() < PkgVersion("2.4"):
    pytest.skip("FSDP2 requires PyTorch >= 2.4.0 with FSDP 2 support.", allow_module_level=True)

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = torch.cuda.device_count()
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]


def _run_test(fp_init, sharding_dims):
test_path = TEST_ROOT / "run_fsdp2_model.py"
test_cmd = LAUNCH_CMD + [str(test_path)]

if fp_init:
test_cmd += ["--fp8-init"]
if len(sharding_dims) == 1:
test_cmd += ["--sharding-dims", str(sharding_dims[0])]
elif len(sharding_dims) == 2:
test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])]
else:
assert False
result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
if result.returncode != 0:
raise AssertionError(result.stderr.decode())


all_boolean = [True, False]
sharding_dims = [[NUM_PROCS], [2, NUM_PROCS // 2]]


@pytest.mark.parametrize("sharding_dims", sharding_dims)
@pytest.mark.parametrize("fp8_init", all_boolean)
def test_distributed(fp8_init, sharding_dims):
if fp8_init and not fp8_available:
pytest.skip(reason_for_no_fp8)
_run_test(fp8_init, sharding_dims)
92 changes: 91 additions & 1 deletion transformer_engine/pytorch/tensor/float8_tensor.py
@@ -24,6 +24,19 @@
aten = torch.ops.aten
updated_fp8_params = {}

_ops_to_preserve_subclass_in_fsdp2 = {
torch.ops.aten.empty_like.default,
torch.ops.aten.new_zeros.default,
torch.ops.aten.slice.Tensor,
torch.ops.aten.copy_.default,
torch.ops.aten.view.default,
torch.ops.aten.as_strided.default,
torch.ops.aten._to_copy.default,
Collaborator (review comment on torch.ops.aten._to_copy.default):

Having _to_copy in this list puts us in a weird position. @youngeunkwon0405 Where did you get this list of ops and could we figure out a way to remove _to_copy?

The trouble is that we are implicitly using torch.Tensor.to as a dequantize function, so we always expect the _to_copy op to output a plain PyTorch tensor. The reason for this design was to work with Mcore's logic for maintaining FP32 master weights (see the logic for DDP and distopt). With this PR, we now see many spurious errors whenever we dequantize an FP8 tensor with to/float/half/etc.

If the current impl of _to_copy leads to insurmountable problems with FSDP2, we'll probably need to remove the implicit dequantization and change Mcore so that it explicitly calls Float8Tensor.dequantize.

youngeunkwon0405 (Collaborator, Author), Jan 8, 2025:

I was following Torch AO's implementation:
https://github.com/pytorch/ao/blob/main/torchao/float8/fsdp_utils.py#L86-L99

I checked that it is currently okay for _to_copy not to preserve the tensor class, but left the warning in for future reference. This PR does not change the functional behavior of Float8Tensor's _to_copy; it only adds a warning (see the sketch after the op list below).

torch.ops.aten._pin_memory.default,
torch.ops.aten.split.Tensor,
torch.ops.aten.clone.default,
}
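For context on the thread above, here is a minimal sketch (not part of this PR; it assumes a CUDA GPU and an FP8-capable Transformer Engine build) of the implicit dequantization being discussed: casting a Float8Tensor routes through aten._to_copy and returns a plain torch.Tensor, which is why _to_copy only triggers the new warning rather than being required to preserve the subclass.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch import fp8_model_init
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor

with fp8_model_init(enabled=True):
    linear = te.Linear(128, 128)  # primary weight is stored as a Float8Tensor

weight = linear.weight
assert isinstance(weight, Float8Tensor)

# Implicit dequantization: .float()/.to() dispatch to aten._to_copy, which returns a
# regular torch.Tensor instead of preserving the Float8Tensor subclass (and, with this
# PR, emits the warning tied to _ops_to_preserve_subclass_in_fsdp2).
dequantized = weight.float()
print(type(dequantized))  # <class 'torch.Tensor'>
```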


def _make_fp8_attr_property_funcs(name: str) -> Any:
"""Make accessors for an FP8 attribute
@@ -430,6 +443,37 @@ def __new__(

return self

    def fsdp_pre_all_gather(self, mesh):  # pylint: disable=unused-argument
        """
        Hook called by torch FSDP2 before all-gather.

        Returns (tensors to all-gather,), (metadata for rebuilding the subclass afterwards,).
        Ref: https://github.com/pytorch/pytorch/pull/122908
        """
        return (self._data,), (self,)

def fsdp_post_all_gather(
self,
all_gather_outputs: Tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype, # pylint: disable=unused-argument
*,
out: Optional[torch.Tensor] = None,
):
"""
A hook function used in torch fsdp2, called after all-gather
return (Float8Tensor class instance of all-gathered input), (Things to free after forward)
Ref: https://github.com/pytorch/pytorch/pull/122908

"""
(data,) = all_gather_outputs
(sample,) = metadata
if out is not None:
assert isinstance(out, Float8Tensor), f"{type(out)}"
return None
return Float8Tensor.make_like(sample, data=data), all_gather_outputs
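A rough sketch of the contract these two hooks implement for an FSDP2-style caller (not part of the PR; the function name, the flat 1-D shard layout, and the bfloat16 param_dtype are illustrative assumptions). The real all-gather path in torch FSDP2 batches parameters and handles shapes and strides more carefully (see pytorch/pytorch#122908).

```python
import torch
import torch.distributed as dist


def _mock_fsdp2_unshard(sharded_param, group):
    """Simplified stand-in for FSDP2's unshard step built on the two hooks above."""
    # 1. Ask the subclass which raw tensors to all-gather and what metadata it needs later.
    (local_data,), metadata = sharded_param.fsdp_pre_all_gather(mesh=None)

    # 2. All-gather the raw uint8 shards (FSDP2 flattens and batches these in practice).
    gathered = torch.empty(
        local_data.numel() * group.size(), dtype=local_data.dtype, device=local_data.device
    )
    dist.all_gather_into_tensor(gathered, local_data.contiguous(), group=group)

    # 3. Let the subclass rebuild a Float8Tensor from the gathered bytes; the second
    #    return value lists buffers FSDP2 may free after the forward pass.
    unsharded, _to_free = sharded_param.fsdp_post_all_gather(
        (gathered,), metadata, param_dtype=torch.bfloat16
    )
    return unsharded
```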

@classmethod
def make_like(
cls,
@@ -902,7 +946,53 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
)
return Float8Tensor.make_like(tensor, data=data_view)

# Default case
# Related to FSDP2
if func == aten.split.Tensor:
tensor = args[0]
data = tensor._data
func_out = data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
)
return [Float8Tensor.make_like(tensor, data=split_tensor) for split_tensor in func_out]
if func == aten.new_zeros.default:
tensor = args[0]
data = tensor._data
func_out = data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
)
return Float8Tensor.make_like(tensor, data=func_out)
if func == torch.ops.aten.as_strided.default:
tensor = args[0]
data = tensor._data
func_out = data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
)
return Float8Tensor.make_like(tensor, data=func_out)
if func == torch.ops.aten.detach.default:
return cls.detach(args[0])
if func == torch.ops.aten.clone.default:
return cls.clone(args[0])
if func == torch.ops.aten.copy_.default:
# Implementation in the superclass (QuantizedTensor) returns a proper output
pass
        elif func in _ops_to_preserve_subclass_in_fsdp2:
            # Ops in _ops_to_preserve_subclass_in_fsdp2 should return an instance of the same
            # class so that torch FSDP2 can handle the tensor subclass correctly.
            warnings.warn(
                f"The function call {func} on {cls} may not return a {cls} as output. "
                "This might cause an error in torch FSDP2!"
            )
else:
pass

return super().__torch_dispatch__(func, types, args, kwargs)

@classmethod