Refined FSDP saving logic and error messaging when path exists (#18884)
awaelchli authored Oct 30, 2023
1 parent 2526c90 commit e66be67
Showing 6 changed files with 126 additions and 15 deletions.
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -35,6 +35,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed false-positive warnings about method calls on the Fabric-wrapped module ([#18819](https://github.com/Lightning-AI/lightning/pull/18819))


- Refined the FSDP saving logic and error messaging when path exists ([#18884](https://github.com/Lightning-AI/lightning/pull/18884))


## [2.1.0] - 2023-10-11

### Added
12 changes: 9 additions & 3 deletions src/lightning/fabric/strategies/fsdp.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from contextlib import ExitStack
from datetime import timedelta
from functools import partial
@@ -432,8 +432,8 @@ def save_checkpoint(

# broadcast the path from rank 0 to ensure all the states are saved in a common path
path = Path(self.broadcast(path))
if path.is_dir() and os.listdir(path):
raise FileExistsError(f"The checkpoint directory already exists and is not empty: {path}")
if path.is_dir() and self._state_dict_type == "full" and not _is_sharded_checkpoint(path):
raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}")

from torch.distributed.checkpoint import FileSystemWriter, save_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -454,7 +454,10 @@ def save_checkpoint(
module = modules[0]

if self._state_dict_type == "sharded":
if path.is_file():
path.unlink()
path.mkdir(parents=True, exist_ok=True)

state_dict_ctx = _get_sharded_state_dict_context(module)

# replace the modules and optimizer objects in the state with their local state dict
@@ -483,6 +486,9 @@ def save_checkpoint(
torch.save(metadata, path / _METADATA_FILENAME)

elif self._state_dict_type == "full":
if _is_sharded_checkpoint(path):
shutil.rmtree(path)

state_dict_ctx = _get_full_state_dict_context(module, world_size=self.world_size)
full_state: Dict[str, Any] = {}
with state_dict_ctx:
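Note: the `_is_sharded_checkpoint` helper that gates both branches is referenced but not shown in this diff. Based on how the tests further down construct a sharded checkpoint (a directory containing a `meta.pt` metadata file), a plausible sketch of the check is the following; the constant name and exact filename are assumptions:

from pathlib import Path

_METADATA_FILENAME = "meta.pt"  # assumption: the metadata file the strategy writes next to the shards

def _is_sharded_checkpoint(path: Path) -> bool:
    # A directory is treated as a sharded checkpoint only if it holds the metadata file.
    return path.is_dir() and (path / _METADATA_FILENAME).is_file()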
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -39,6 +39,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue saving the `last.ckpt` file when using `ModelCheckpoint` on a remote filesystem and no logger is used ([#18867](https://github.com/Lightning-AI/lightning/issues/18867))


- Refined the FSDP saving logic and error messaging when path exists ([#18884](https://github.com/Lightning-AI/lightning/pull/18884))


## [2.1.0] - 2023-10-11

### Added
10 changes: 7 additions & 3 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
from contextlib import contextmanager, nullcontext
from datetime import timedelta
from pathlib import Path
@@ -522,12 +522,14 @@ def save_checkpoint(
)

path = Path(self.broadcast(filepath))
if path.is_dir() and os.listdir(path):
raise FileExistsError(f"The checkpoint directory already exists and is not empty: {path}")
if path.is_dir() and self._state_dict_type == "full" and not _is_sharded_checkpoint(path):
raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}")

if self._state_dict_type == "sharded":
from torch.distributed.checkpoint import FileSystemWriter, save_state_dict

if path.is_file():
path.unlink()
path.mkdir(parents=True, exist_ok=True)

converted_state = {"model": checkpoint.pop("state_dict")}
@@ -542,6 +544,8 @@
if self.global_rank == 0:
torch.save(checkpoint, path / _METADATA_FILENAME)
elif self._state_dict_type == "full":
if _is_sharded_checkpoint(path):
shutil.rmtree(path)
return super().save_checkpoint(checkpoint=checkpoint, filepath=path)
else:
raise ValueError(f"Unknown state_dict_type: {self._state_dict_type}")
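Taken together, the Fabric and PyTorch strategies now resolve an existing path the same way. The following condensation is illustrative only: the strategies inline this logic in `save_checkpoint`, and `_handle_existing_path` is a hypothetical name introduced here for clarity.

import shutil
from pathlib import Path

def _handle_existing_path(path: Path, state_dict_type: str) -> None:
    if state_dict_type == "full":
        if _is_sharded_checkpoint(path):
            shutil.rmtree(path)  # overwrite a previous sharded checkpoint with a full one
        elif path.is_dir():
            # refuse to clobber an arbitrary, non-checkpoint directory
            raise IsADirectoryError(f"The checkpoint path exists and is a directory: {path}")
        # a plain file at `path` is simply overwritten by the subsequent save
    elif state_dict_type == "sharded":
        if path.is_file():
            path.unlink()  # replace a stray file with the checkpoint directory
        path.mkdir(parents=True, exist_ok=True)  # an existing directory is reused and overwritten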
55 changes: 51 additions & 4 deletions tests/tests_fabric/strategies/test_fsdp.py
@@ -27,6 +27,7 @@
_FSDPBackwardSyncControl,
_get_full_state_dict_context,
_has_meta_device_parameters,
_is_sharded_checkpoint,
)
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision
@@ -275,15 +276,61 @@ def test_fsdp_save_checkpoint_storage_options(tmp_path):


@RunIf(min_torch="2.0.0")
@mock.patch("torch.distributed.checkpoint.save_state_dict", return_value=MagicMock())
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
def test_fsdp_save_checkpoint_folder_exists(tmp_path):
path = tmp_path / "exists"
@mock.patch("lightning.fabric.strategies.fsdp._get_full_state_dict_context", return_value=MagicMock())
@mock.patch("lightning.fabric.strategies.fsdp._get_sharded_state_dict_context", return_value=MagicMock())
@mock.patch("lightning.fabric.strategies.fsdp.torch.save", return_value=Mock())
@mock.patch("lightning.fabric.strategies.fsdp.shutil", return_value=MagicMock())
def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, ____, tmp_path):
strategy = FSDPStrategy(state_dict_type="full")

# state_dict_type='full', path exists, path is not a sharded checkpoint: error
path = tmp_path / "not-empty"
path.mkdir()
(path / "file").touch()
strategy = FSDPStrategy()
with pytest.raises(FileExistsError, match="exists and is not empty"):
assert not _is_sharded_checkpoint(path)
with pytest.raises(IsADirectoryError, match="exists and is a directory"):
strategy.save_checkpoint(path=path, state=Mock())

# state_dict_type='full', path exists, path is a sharded checkpoint: no error (overwrite)
path = tmp_path / "sharded-checkpoint"
path.mkdir()
(path / "meta.pt").touch()
assert _is_sharded_checkpoint(path)
model = Mock(spec=FullyShardedDataParallel)
model.modules.return_value = [model]
strategy.save_checkpoint(path=path, state={"model": model})
shutil_mock.rmtree.assert_called_once_with(path)

# state_dict_type='full', path exists, path is a file: no error (overwrite)
path = tmp_path / "file.pt"
path.touch()
model = Mock(spec=FullyShardedDataParallel)
model.modules.return_value = [model]
torch_save_mock.reset_mock()
strategy.save_checkpoint(path=path, state={"model": model})
torch_save_mock.assert_called_once()

strategy = FSDPStrategy(state_dict_type="sharded")

# state_dict_type='sharded', path exists, path is a folder: no error (overwrite)
path = tmp_path / "not-empty-2"
path.mkdir()
(path / "file").touch()
model = Mock(spec=FullyShardedDataParallel)
model.modules.return_value = [model]
strategy.save_checkpoint(path=path, state={"model": model})
assert (path / "file").exists()

# state_dict_type='sharded', path exists, path is a file: no error (overwrite)
path = tmp_path / "file-2.pt"
path.touch()
model = Mock(spec=FullyShardedDataParallel)
model.modules.return_value = [model]
strategy.save_checkpoint(path=path, state={"model": model})
assert path.is_dir()


@RunIf(min_torch="2.0.0")
@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
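From the user's side, the change means repeated saves to the same path now overwrite the previous checkpoint instead of raising `FileExistsError`. A minimal, hypothetical Fabric usage sketch (assumes a machine with at least two CUDA GPUs; the checkpoint path is arbitrary):

import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

fabric = Fabric(accelerator="cuda", devices=2, strategy=FSDPStrategy(state_dict_type="sharded"))
fabric.launch()
model = fabric.setup(torch.nn.Linear(32, 2))
state = {"model": model}

fabric.save("checkpoints/latest", state)  # creates the checkpoint directory
fabric.save("checkpoints/latest", state)  # second call overwrites it in place, no error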
58 changes: 53 additions & 5 deletions tests/tests_pytorch/strategies/test_fsdp.py
@@ -13,6 +13,7 @@
import torch
import torch.nn as nn
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint
from lightning.fabric.utilities.imports import (
_TORCH_GREATER_EQUAL_2_0,
_TORCH_GREATER_EQUAL_2_1,
@@ -760,14 +761,61 @@ def test_save_checkpoint_storage_options(tmp_path):
strategy.save_checkpoint(filepath=tmp_path, checkpoint=Mock(), storage_options=Mock())


@RunIf(min_torch="2.0.0")
@mock.patch("torch.distributed.checkpoint.save_state_dict", return_value=MagicMock())
@mock.patch("lightning.pytorch.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
def test_save_checkpoint_folder_exists(tmp_path):
path = tmp_path / "exists"
@mock.patch("lightning.pytorch.strategies.fsdp._get_full_state_dict_context", return_value=MagicMock())
@mock.patch("lightning.pytorch.strategies.fsdp._get_sharded_state_dict_context", return_value=MagicMock())
@mock.patch("lightning.fabric.plugins.io.torch_io._atomic_save", return_value=Mock())
@mock.patch("lightning.pytorch.strategies.fsdp.shutil", return_value=MagicMock())
def test_fsdp_save_checkpoint_path_exists(shutil_mock, torch_save_mock, __, ___, ____, tmp_path):
strategy = FSDPStrategy(state_dict_type="full")

# state_dict_type='full', path exists, path is not a sharded checkpoint: error
path = tmp_path / "not-empty"
path.mkdir()
(path / "file").touch()
strategy = FSDPStrategy()
with pytest.raises(FileExistsError, match="exists and is not empty"):
strategy.save_checkpoint(filepath=tmp_path, checkpoint=Mock())
assert not _is_sharded_checkpoint(path)
with pytest.raises(IsADirectoryError, match="exists and is a directory"):
strategy.save_checkpoint(Mock(), filepath=path)

# state_dict_type='full', path exists, path is a sharded checkpoint: no error (overwrite)
path = tmp_path / "sharded-checkpoint"
path.mkdir()
(path / "meta.pt").touch()
assert _is_sharded_checkpoint(path)
model = Mock(spec=FullyShardedDataParallel)
model.modules.return_value = [model]
strategy.save_checkpoint(Mock(), filepath=path)
shutil_mock.rmtree.assert_called_once_with(path)

# state_dict_type='full', path exists, path is a file: no error (overwrite)
path = tmp_path / "file.pt"
path.touch()
model = Mock(spec=FullyShardedDataParallel)
model.modules.return_value = [model]
torch_save_mock.reset_mock()
strategy.save_checkpoint(Mock(), filepath=path)
torch_save_mock.assert_called_once()

strategy = FSDPStrategy(state_dict_type="sharded")

# state_dict_type='sharded', path exists, path is a folder: no error (overwrite)
path = tmp_path / "not-empty-2"
path.mkdir()
(path / "file").touch()
model = Mock(spec=FullyShardedDataParallel)
model.modules.return_value = [model]
strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path)
assert (path / "file").exists()

# state_dict_type='sharded', path exists, path is a file: no error (overwrite)
path = tmp_path / "file-2.pt"
path.touch()
model = Mock(spec=FullyShardedDataParallel)
model.modules.return_value = [model]
strategy.save_checkpoint({"state_dict": {}, "optimizer_states": {"": {}}}, filepath=path)
assert path.is_dir()


@mock.patch("lightning.pytorch.strategies.fsdp.FSDPStrategy.broadcast", lambda _, x: x)
