Move torchmetrics to device when using FSDP (#18954)
awaelchli authored Nov 8, 2023
1 parent 07461a1 commit 964364b
Showing 4 changed files with 52 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/lightning/fabric/strategies/fsdp.py
@@ -33,6 +33,7 @@
)

import torch
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from torch.nn import Module, Parameter
from torch.optim import Optimizer
@@ -292,6 +293,8 @@ def setup_module(self, module: Module) -> Module:
**self._fsdp_kwargs,
)

_move_torchmetrics_to_device(module, self.root_device)

# activation checkpointing needs to be set up after wrapping the model
if _TORCH_GREATER_EQUAL_1_13:
_setup_activation_checkpointing(module, self._activation_checkpointing_kwargs)
@@ -886,3 +889,15 @@ def _has_meta_device_parameters(obj: Union[Module, Optimizer]) -> bool:
if isinstance(obj, Module):
return any(t.is_meta for t in obj.parameters())
raise TypeError(f"Expected `torch.nn.Module` or `torch.optim.Optimizer`, got: {type(obj).__name__}")


def _move_torchmetrics_to_device(module: torch.nn.Module, device: torch.device) -> None:
# FSDP doesn't move modules without parameters (e.g. Metrics) to the device
# https://github.com/pytorch/pytorch/issues/113113
if not RequirementCache("torchmetrics"):
return

from torchmetrics import Metric

for metric in (m for m in module.modules() if isinstance(m, Metric)):
metric.to(device) # `.to()` is in-place
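
For context, a minimal standalone sketch (not part of the commit) of the traversal this helper performs. The `ToyModel` class, the device selection, and the final assert are illustrative assumptions; only the Metric-moving loop mirrors the helper above.

```python
import torch
from torch import nn
from torchmetrics import Accuracy, Metric


class ToyModel(nn.Module):
    """Hypothetical module mixing a trainable layer with a parameterless metric."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(32, 10)
        self.metric = Accuracy(task="multiclass", num_classes=10)


device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
model = ToyModel()

# FSDP only moves the parameters and buffers it manages (see pytorch#113113),
# so parameterless Metric submodules must be moved explicitly. `Metric.to()`
# is in-place and also moves the metric's internal state tensors.
for metric in (m for m in model.modules() if isinstance(m, Metric)):
    metric.to(device)

assert model.metric.device == device
```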
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -49,6 +49,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue causing permission errors on Windows when attempting to create a symlink for the "last" checkpoint ([#18942](https://github.com/Lightning-AI/lightning/issues/18942))


- Fixed an issue where Metric instances from `torchmetrics` wouldn't get moved to the device when using FSDP ([#18954](https://github.com/Lightning-AI/lightning/issues/18954))


## [2.1.0] - 2023-10-11

### Added
3 changes: 3 additions & 0 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -39,6 +39,7 @@
_is_full_checkpoint,
_is_sharded_checkpoint,
_load_raw_module_state,
_move_torchmetrics_to_device,
_optimizer_has_flat_params,
_setup_activation_checkpointing,
)
@@ -292,6 +293,8 @@ def _setup_model(self, model: Module) -> Module:
**self.kwargs,
)

_move_torchmetrics_to_device(model, self.root_device)

# activation checkpointing needs to be set up after wrapping the model
if _TORCH_GREATER_EQUAL_1_13:
_setup_activation_checkpointing(model, self._activation_checkpointing_kwargs)
31 changes: 31 additions & 0 deletions tests/tests_pytorch/strategies/test_fsdp.py
@@ -28,6 +28,7 @@
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import always_wrap_policy, size_based_auto_wrap_policy, wrap
from torchmetrics import Accuracy

from tests_pytorch.helpers.runif import RunIf

@@ -239,6 +240,36 @@ def test_fsdp_strategy_sync_batchnorm(tmpdir):
_run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt"))


@RunIf(min_cuda_gpus=1, skip_windows=True)
def test_fsdp_modules_without_parameters(tmp_path):
"""Test that TorchMetrics get moved to the device despite not having any parameters."""

class MetricsModel(BoringModel):
def __init__(self):
super().__init__()
self.metric = Accuracy("multiclass", num_classes=10)
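# `tp` (true positives) is one of the Accuracy metric's state tensors; the assertions below check that the state follows the metric's device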
assert self.metric.device == self.metric.tp.device == torch.device("cpu")

def setup(self, stage) -> None:
assert self.metric.device == self.metric.tp.device == torch.device("cpu")

def training_step(self, batch, batch_idx):
loss = super().training_step(batch, batch_idx)
assert self.metric.device == self.metric.tp.device == torch.device("cuda", 0)
self.metric(torch.rand(2, 10, device=self.device), torch.randint(0, 10, size=(2,), device=self.device))
return loss

model = MetricsModel()
trainer = Trainer(
default_root_dir=tmp_path,
accelerator="cuda",
devices=1,
strategy="fsdp",
max_steps=1,
)
trainer.fit(model)


@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True)
@pytest.mark.parametrize("precision", ["16-mixed", pytest.param("bf16-mixed", marks=RunIf(bf16_cuda=True))])
def test_fsdp_strategy_checkpoint(tmpdir, precision):
