diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh
index 9a11ccc008..4e52153db9 100644
--- a/qa/L1_pytorch_distributed_unittest/test.sh
+++ b/qa/L1_pytorch_distributed_unittest/test.sh
@@ -11,4 +11,5 @@ pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py
 pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py
+pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py
 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py
new file mode 100644
index 0000000000..0f00a6717b
--- /dev/null
+++ b/tests/pytorch/distributed/run_fsdp2_model.py
@@ -0,0 +1,181 @@
+#!/usr/bin/python3
+
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
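+# Toy example that applies torch FSDP2's fully_shard() to a model built from TE modules.
+# Launched by tests/pytorch/distributed/test_torch_fsdp2.py via torchrun, e.g. (assuming 4 GPUs):
+#   torchrun --nproc_per_node=4 run_fsdp2_model.py --fp8-init --sharding-dims 2 2
+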
+import os
+import sys
+import argparse
+
+import transformer_engine.pytorch as te
+from transformer_engine.common.recipe import Format, DelayedScaling
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import nn, optim
+from torch.distributed import DeviceMesh
+from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed.device_mesh import init_device_mesh
+from contextlib import nullcontext
+
+
+class SimpleNet(nn.Module):
+    def __init__(self, input_size, hidden_size, output_size):
+        super().__init__()
+        self.fc1 = te.Linear(input_size, hidden_size)
+        self.fc2 = te.Linear(hidden_size, output_size)
+
+    def forward(self, x):
+        x = F.relu(self.fc1(x))
+        x = self.fc2(x)
+        return x
+
+
+def save_custom_attrs(module):
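+    """Record the custom attributes attached to each parameter (e.g. TE's FP8 metadata)
+    so they can be re-attached after fully_shard() swaps the parameters out."""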
+    custom_attrs = {}
+    for name, param in module.named_parameters():
+        attrs = vars(param)
+        custom_attrs[name] = {k: v for k, v in attrs.items()}
+    return custom_attrs
+
+
+def restore_custom_attrs(module, custom_attrs):
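+    """Re-attach the recorded custom attributes to the module's (possibly replaced) parameters."""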
+    for name, param in module.named_parameters():
+        if name in custom_attrs:
+            for attr_name, attr_value in custom_attrs[name].items():
+                setattr(param, attr_name, attr_value)
+
+
+def _parse_args(argv=None, namespace=None):
+    parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
+    parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model")
+    parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
+    parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
+    parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model")
+    parser.add_argument(
+        "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
+    )
+    parser.add_argument(
+        "--iter", type=int, default=10, help="Number of iterations for forward pass"
+    )
+    parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
+    # Accept the sharding dims as a space-separated list of ints
+    parser.add_argument(
+        "--sharding-dims",
+        type=int,
+        nargs="+",
+        help="FSDP/HSDP sharding dimensions: one value for FSDP, or two values (replicate, shard) for HSDP",
+    )
+    args = parser.parse_args(argv, namespace)
+    if args.sharding_dims:
+        assert len(args.sharding_dims) <= 2
+    return args
+
+
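+# TE module types that receive their own fully_shard() group in _train() below.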
+sub_modules_to_wrap = [te.Linear]
+
+
+def _train(args):
+    assert "TORCHELASTIC_RUN_ID" in os.environ
+    WORLD_RANK = int(os.getenv("RANK", "0"))
+    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
+    LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
+    LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
+    assert LOCAL_SIZE == WORLD_SIZE
+
+    # Set device and initialize RNG states
+    torch.cuda.set_device(WORLD_RANK)
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed(args.seed)
+
+    # Initialize torch.distributed global process group and get DP/TP groups
+    dist_init_kwargs = {
+        "backend": "nccl",
+        "rank": WORLD_RANK,
+        "world_size": WORLD_SIZE,
+    }
+    assert dist.is_nccl_available()
+    dist.init_process_group(**dist_init_kwargs)
+    nccl_world = dist.new_group(backend="nccl")
+    device = torch.device(f"cuda:{LOCAL_RANK}")
+
+    # FP8 Configuration
+    fp8_format = Format.HYBRID
+    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
+
+    if args.fp8_init:
+        # Initialize primary weights directly in FP8
+        from transformer_engine.pytorch import fp8_model_init
+
+        build_model_context = fp8_model_init
+        build_model_context_args = {"enabled": True}
+    else:
+        build_model_context = nullcontext
+        build_model_context_args = {}
+
+    # Build the model with the specified context
+    with build_model_context(**build_model_context_args):
+        model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
+
+    # Move the model to the correct device
+    model.to(device)
+
+    if LOCAL_RANK == 0:
+        print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...")
+    # Create a DeviceMesh for fully_shard
+    device_ids = list(range(WORLD_SIZE))
+    if LOCAL_RANK == 0:
+        print(f"sharding-dims: {args.sharding_dims}")
+    # Set up the sharding mesh for FSDP/HSDP
+    if args.sharding_dims is None:  # FSDP across all ranks
+        mesh = DeviceMesh("cuda", device_ids)
+    elif len(args.sharding_dims) == 1:  # FSDP with an explicit shard dimension
+        assert args.sharding_dims[0] == WORLD_SIZE
+        mesh = DeviceMesh("cuda", device_ids)
+    elif len(args.sharding_dims) == 2:  # HSDP with a (replicate, shard) mesh
+        assert args.sharding_dims[0] * args.sharding_dims[1] == WORLD_SIZE
+        mesh = init_device_mesh(
+            "cuda",
+            (args.sharding_dims[0], args.sharding_dims[1]),
+            mesh_dim_names=("replicate", "shard"),
+        )
+    else:
+        assert False, f"Invalid --sharding-dims: {args.sharding_dims}"
+
+    # Apply FSDP/HSDP
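+    # fully_shard() replaces nn.Parameters with sharded DTensors, so snapshot the custom
+    # parameter attributes first and re-attach them once sharding is done.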
+    custom_attrs = save_custom_attrs(model)
+    for sub_module in model.modules():
+        if any(
+            isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap
+        ):
+            fully_shard(sub_module, mesh=mesh)
+    fully_shard(model, mesh=mesh)
+    restore_custom_attrs(model, custom_attrs)
+
+    optimizer = optim.Adam(model.parameters(), lr=1e-3)
+
+    for iteration in range(args.iter):
+        # Zero the parameter gradients
+        optimizer.zero_grad()
+        input_data = torch.randn(args.batch_size, args.input_size).to(device)
+        output = model(input_data)
+        target = torch.randn(args.batch_size, args.output_size).to(device)
+        loss = F.mse_loss(output, target)
+        loss.backward()
+        optimizer.step()
+        if LOCAL_RANK == 0:
+            print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")
+
+    dist.destroy_process_group()
+    if LOCAL_RANK == 0:
+        print(f"Rank {LOCAL_RANK}: Done...")
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(_train(_parse_args()))
diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py
new file mode 100644
index 0000000000..3c9197c322
--- /dev/null
+++ b/tests/pytorch/distributed/test_torch_fsdp2.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import os
+import pytest
+import subprocess
+from pathlib import Path
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+import torch
+from packaging.version import Version as PkgVersion
+
+
+def get_torch_version():
+    """Get PyTorch version from torch.__version__"""
+    return PkgVersion(str(torch.__version__))
+
+
+if torch.cuda.device_count() < 4:
+    pytest.skip("FSDP2 test requires at least 4 GPUs.", allow_module_level=True)
+
+if torch.cuda.device_count() % 2 != 0:
+    pytest.skip("Number of GPUs must be divisible by 2.", allow_module_level=True)
+
+if get_torch_version() < PkgVersion("2.4"):
+    pytest.skip("FSDP2 requires PyTorch >= 2.4.0 with FSDP 2 support.", allow_module_level=True)
+
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+
+TEST_ROOT = Path(__file__).parent.resolve()
+NUM_PROCS: int = torch.cuda.device_count()
+LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
+
+
+def _run_test(fp8_init, sharding_dims):
+    test_path = TEST_ROOT / "run_fsdp2_model.py"
+    test_cmd = LAUNCH_CMD + [str(test_path)]
+
+    if fp8_init:
+        test_cmd += ["--fp8-init"]
+    if len(sharding_dims) == 1:
+        test_cmd += ["--sharding-dims", str(sharding_dims[0])]
+    elif len(sharding_dims) == 2:
+        test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])]
+    else:
+        assert False, f"Invalid sharding_dims: {sharding_dims}"
+    result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
+    if result.returncode != 0:
+        raise AssertionError(result.stderr.decode())
+
+
+all_boolean = [True, False]
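+# First case: pure FSDP across all NUM_PROCS ranks; second case: HSDP on a
+# (replicate=2, shard=NUM_PROCS // 2) mesh.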
+sharding_dims = [[NUM_PROCS], [2, NUM_PROCS // 2]]
+
+
+@pytest.mark.parametrize("sharding_dims", sharding_dims)
+@pytest.mark.parametrize("fp8_init", all_boolean)
+def test_distributed(fp8_init, sharding_dims):
+    if fp8_init and not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    _run_test(fp8_init, sharding_dims)
diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py
index 7ace68a222..414e819f53 100644
--- a/transformer_engine/pytorch/tensor/float8_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_tensor.py
@@ -24,6 +24,19 @@
 aten = torch.ops.aten
 updated_fp8_params = {}
 
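+# aten ops that torch FSDP2 may issue on sharded Float8Tensor parameters; these are
+# expected to keep returning Float8Tensor (see the handling in __torch_dispatch__ below).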
+_ops_to_preserve_subclass_in_fsdp2 = {
+    torch.ops.aten.empty_like.default,
+    torch.ops.aten.new_zeros.default,
+    torch.ops.aten.slice.Tensor,
+    torch.ops.aten.copy_.default,
+    torch.ops.aten.view.default,
+    torch.ops.aten.as_strided.default,
+    torch.ops.aten._to_copy.default,
+    torch.ops.aten._pin_memory.default,
+    torch.ops.aten.split.Tensor,
+    torch.ops.aten.clone.default,
+}
+
 
 def _make_fp8_attr_property_funcs(name: str) -> Any:
     """Make accessors for an FP8 attribute
@@ -430,6 +443,37 @@ def __new__(
 
         return self
 
+    def fsdp_pre_all_gather(self, mesh):  # pylint: disable=unused-argument
+        """
+        A hook function used in torch fsdp2, called before all-gather
+        return (all-gather input), (metadata)
+        Ref: https://github.com/pytorch/pytorch/pull/122908
+
+        """
+
+        return (self._data,), (self,)
+
+    def fsdp_post_all_gather(
+        self,
+        all_gather_outputs: Tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,  # pylint: disable=unused-argument
+        *,
+        out: Optional[torch.Tensor] = None,
+    ):
+        """
+        A hook function used in torch fsdp2, called after all-gather
+        return (Float8Tensor class instance of all-gathered input), (Things to free after forward)
+        Ref: https://github.com/pytorch/pytorch/pull/122908
+
+        """
+        (data,) = all_gather_outputs
+        (sample,) = metadata
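+        # When FSDP2 supplies a preallocated `out`, its underlying data is assumed to
+        # already alias the all-gather output buffer, so no new tensor is built here.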
+        if out is not None:
+            assert isinstance(out, Float8Tensor), f"{type(out)}"
+            return None
+        return Float8Tensor.make_like(sample, data=data), all_gather_outputs
+
     @classmethod
     def make_like(
         cls,
@@ -902,7 +946,53 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
             )
             return Float8Tensor.make_like(tensor, data=data_view)
 
-        # Default case
+        # FSDP2-related ops: run the op on the underlying data tensor and re-wrap
+        # the result so the Float8Tensor subclass is preserved across sharding.
+        if func == aten.split.Tensor:
+            tensor = args[0]
+            data = tensor._data
+            func_out = data.__torch_dispatch__(
+                func,
+                types,
+                [data] + list(args[1:]),
+                kwargs,
+            )
+            return [Float8Tensor.make_like(tensor, data=split_tensor) for split_tensor in func_out]
+        if func == aten.new_zeros.default:
+            tensor = args[0]
+            data = tensor._data
+            func_out = data.__torch_dispatch__(
+                func,
+                types,
+                [data] + list(args[1:]),
+                kwargs,
+            )
+            return Float8Tensor.make_like(tensor, data=func_out)
+        if func == torch.ops.aten.as_strided.default:
+            tensor = args[0]
+            data = tensor._data
+            func_out = data.__torch_dispatch__(
+                func,
+                types,
+                [data] + list(args[1:]),
+                kwargs,
+            )
+            return Float8Tensor.make_like(tensor, data=func_out)
+        if func == torch.ops.aten.detach.default:
+            return cls.detach(args[0])
+        if func == torch.ops.aten.clone.default:
+            return cls.clone(args[0])
+        if func == torch.ops.aten.copy_.default:
+            # The implementation in the superclass (QuantizedTensor) already returns
+            # a proper output.
+            pass
+        elif func in _ops_to_preserve_subclass_in_fsdp2:
+            # Ops in _ops_to_preserve_subclass_in_fsdp2 are expected to return an
+            # instance of the same subclass for torch FSDP2 to work correctly.
+            warnings.warn(
+                f"The function {func} called on {cls} may not return a {cls} as its"
+                " output. This can cause errors in torch FSDP2!"
+            )
+
         return super().__torch_dispatch__(func, types, args, kwargs)
 
     @classmethod