Fix/best model checkpoint fix #35885

Open
wants to merge 11 commits into main
30 changes: 27 additions & 3 deletions src/transformers/testing_utils.py
@@ -21,6 +21,7 @@
import importlib
import inspect
import logging
import math
import multiprocessing
import os
import re
@@ -38,14 +39,15 @@
from functools import wraps
from io import StringIO
from pathlib import Path
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
from typing import Callable, Dict, Generator, Iterable, Iterator, List, Optional, Union
from unittest import mock
from unittest.mock import patch

import huggingface_hub.utils
import urllib3
from huggingface_hub import delete_repo

from transformers import Trainer
from transformers import logging as transformers_logging

from .integrations import (
@@ -1405,6 +1407,28 @@ def get_tests_dir(append_path=None):
return tests_dir


def get_steps_per_epoch(trainer: Trainer) -> int:
train_dataloader = trainer.get_train_dataloader()
batches_per_epoch = len(train_dataloader)
steps_per_epoch = math.ceil(batches_per_epoch / trainer.args.gradient_accumulation_steps)

return steps_per_epoch
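
(Editor's sketch, not part of this diff.) A quick worked example of the ceiling division in `get_steps_per_epoch`, using made-up numbers rather than anything from this PR:

import math

batches_per_epoch = 5  # hypothetical len(train_dataloader)
gradient_accumulation_steps = 2  # hypothetical trainer.args.gradient_accumulation_steps
steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
assert steps_per_epoch == 3  # the leftover odd batch still triggers an optimizer step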


def evaluate_side_effect_factory(
side_effect_values: List[Dict[str, float]],
) -> Generator[Dict[str, float], None, None]:
"""
Generator of side effects for the mocked `_evaluate` method: yields the given metric dicts
in order, then repeats the last one indefinitely. Useful when the exact number of
`_evaluate` calls is not known in advance.
"""
for side_effect_value in side_effect_values:
yield side_effect_value

while True:
yield side_effect_values[-1]
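
(Editor's sketch, not part of this diff.) This mirrors how the new tests below consume the generator: it is passed to `unittest.mock.patch.object` as `side_effect`, so the mocked `_evaluate` returns the scripted metrics in order and then keeps repeating the last one:

from unittest.mock import patch

# `trainer` stands for the regression trainer built in the tests below.
with patch.object(
    trainer,
    "_evaluate",
    side_effect=evaluate_side_effect_factory(
        [{"eval_accuracy": 0.59}, {"eval_accuracy": 0.57}]
    ),
):
    trainer.train()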


#
# Helper functions for dealing with testing text outputs
# The original code came from:
@@ -2170,7 +2194,7 @@ def pytest_terminal_summary_main(tr, id):
f.write("slowest durations\n")
for i, rep in enumerate(dlist):
if rep.duration < durations_min:
f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted")
f.write(f"{len(dlist) - i} durations < {durations_min} secs were omitted")
break
f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")

@@ -2555,7 +2579,7 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
process.join(timeout=timeout)

if results["error"] is not None:
test_case.fail(f'{results["error"]}')
test_case.fail(f"{results['error']}")


def run_test_using_subprocess(func):
15 changes: 10 additions & 5 deletions src/transformers/trainer.py
@@ -3195,12 +3195,10 @@ def _determine_best_metric(self, metrics, trial):
self.state.best_metric = float("-inf") if self.args.greater_is_better else float("inf")

if operator(metric_value, self.state.best_metric):
run_dir = self._get_output_dir(trial=trial)
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
output_dir = os.path.join(run_dir, checkpoint_folder)

self.state.best_metric = metric_value
self.state.best_model_checkpoint = output_dir

if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH]:
self.state.best_global_step = self.state.global_step

is_new_best_metric = True

@@ -3221,6 +3219,13 @@ def _save_checkpoint(self, model, trial):
output_dir = os.path.join(run_dir, checkpoint_folder)
self.save_model(output_dir, _internal_call=True)

if self.args.save_strategy in [SaveStrategy.STEPS, SaveStrategy.EPOCH] and self.state.best_global_step:
best_checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.best_global_step}"
best_checkpoint_dir = os.path.join(run_dir, best_checkpoint_folder)

if os.path.exists(best_checkpoint_dir):
self.state.best_model_checkpoint = best_checkpoint_dir

if not self.args.save_only_model:
# Save optimizer and scheduler
self._save_optimizer_and_scheduler(output_dir)
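
(Editor's sketch, not part of this diff.) The net effect of the two hunks above: `_determine_best_metric` now only records `best_global_step`, and `_save_checkpoint` promotes it to `best_model_checkpoint` only when a checkpoint for exactly that step was written. With mismatched schedules, e.g. eval_steps=3 and save_steps=2 as in Case 6 of the new tests, the best step is never saved and `best_model_checkpoint` stays None:

import os

run_dir = "/tmp/example_run"  # hypothetical output directory
best_global_step = 3  # best eval metric seen at step 3
best_checkpoint_dir = os.path.join(run_dir, f"checkpoint-{best_global_step}")

# Only checkpoint-2 and checkpoint-4 exist under run_dir in this scenario, so the lookup
# fails and best_model_checkpoint is left as None.
best_model_checkpoint = best_checkpoint_dir if os.path.exists(best_checkpoint_dir) else None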
4 changes: 4 additions & 0 deletions src/transformers/trainer_callback.py
@@ -73,6 +73,9 @@ class TrainerState:
The list of logs done since the beginning of training.
best_metric (`float`, *optional*):
When tracking the best model, the value of the best metric encountered so far.
best_global_step (`int`, *optional*):
When tracking the best model, the step at which the best metric was encountered.
Used for setting `best_model_checkpoint`.
best_model_checkpoint (`str`, *optional*):
When tracking the best model, the value of the name of the checkpoint for the best model encountered so
far.
@@ -102,6 +105,7 @@ class TrainerState:
total_flos: float = 0
log_history: List[Dict[str, float]] = None
best_metric: Optional[float] = None
best_global_step: Optional[int] = None
best_model_checkpoint: Optional[str] = None
is_local_process_zero: bool = True
is_world_process_zero: bool = True
189 changes: 188 additions & 1 deletion tests/trainer/test_trainer.py
@@ -61,8 +61,10 @@
TemporaryHubRepo,
TestCasePlus,
backend_device_count,
evaluate_side_effect_factory,
execute_subprocess_async,
get_gpu_count,
get_steps_per_epoch,
get_tests_dir,
is_staging_test,
require_accelerate,
@@ -656,7 +658,7 @@ def convert_to_sharded_checkpoint(self, folder, save_safe=True, load_safe=True):
keys = list(state_dict.keys())

shard_files = [
shard_name.replace(f".{extension}", f"-{idx+1:05d}-of-{len(keys):05d}.{extension}")
shard_name.replace(f".{extension}", f"-{idx + 1:05d}-of-{len(keys):05d}.{extension}")
for idx in range(len(keys))
]
index = {"metadata": {}, "weight_map": {key: shard_files[i] for i, key in enumerate(keys)}}
@@ -4423,6 +4425,191 @@ def test_metric_for_best_model_behavior(self):
)
self.assertTrue(trainer.args.metric_for_best_model == "loss")

def test_best_model_checkpoint_behavior(self):
# Case 1. Never evaluated, save_total_limit > 1 and save_steps == 1.
# Both best_metric and best_model_checkpoint should be None.
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
output_dir=tmpdir,
eval_strategy="steps",
save_strategy="steps",
save_steps=1,
metric_for_best_model="accuracy",
greater_is_better=True,
)
trainer.train()

assert trainer.state.best_metric is None
assert trainer.state.best_model_checkpoint is None
assert len(os.listdir(tmpdir)) == trainer.state.global_step

# Case 2. Never evaluated and save_total_limit == 1.
# Both best_metric and best_model_checkpoint should be None.
# Only the last checkpoint should remain.
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
output_dir=tmpdir,
eval_strategy="steps",
save_strategy="steps",
save_steps=1,
metric_for_best_model="accuracy",
greater_is_better=True,
save_total_limit=1,
)
trainer.train()

num_steps = trainer.state.global_step

assert trainer.state.best_metric is None
assert trainer.state.best_model_checkpoint is None
assert len(os.listdir(tmpdir)) == 1

ckpt = os.path.join(tmpdir, f"{PREFIX_CHECKPOINT_DIR}-{num_steps}")
assert os.path.isdir(ckpt)
assert os.listdir(tmpdir)[0] == f"{PREFIX_CHECKPOINT_DIR}-{num_steps}"

# Case 3. eval_strategy == save_strategy.
# best_model_checkpoint should be at epoch 1.
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
output_dir=tmpdir,
eval_strategy="epoch",
save_strategy="epoch",
metric_for_best_model="accuracy",
compute_metrics=AlmostAccuracy(),
greater_is_better=True,
load_best_model_at_end=False,
)

with patch.object(
trainer,
"_evaluate",
side_effect=evaluate_side_effect_factory(
[
{"eval_accuracy": 0.59},
{"eval_accuracy": 0.57},
{"eval_accuracy": 0.55},
]
),
):
trainer.train()

steps_per_epoch = get_steps_per_epoch(trainer)

assert trainer.state.best_metric == 0.59
assert trainer.state.best_global_step == steps_per_epoch

best_ckpt = os.path.join(tmpdir, f"{PREFIX_CHECKPOINT_DIR}-{trainer.state.best_global_step}")
assert trainer.state.best_model_checkpoint == best_ckpt

assert len(os.listdir(tmpdir)) == trainer.state.num_train_epochs

# Case 4. eval_strategy != save_strategy.
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
output_dir=tmpdir,
eval_strategy="epoch",
save_strategy="steps",
save_steps=1,
metric_for_best_model="accuracy",
compute_metrics=AlmostAccuracy(),
greater_is_better=True,
load_best_model_at_end=False,
)

with patch.object(
trainer,
"_evaluate",
side_effect=evaluate_side_effect_factory(
[
{"eval_accuracy": 0.59},
{"eval_accuracy": 0.57},
{"eval_accuracy": 0.55},
]
),
):
trainer.train()

steps_per_epoch = get_steps_per_epoch(trainer)

assert trainer.state.best_metric == 0.59
assert trainer.state.best_global_step == steps_per_epoch

best_ckpt = os.path.join(tmpdir, f"{PREFIX_CHECKPOINT_DIR}-{trainer.state.best_global_step}")
assert trainer.state.best_model_checkpoint == best_ckpt

assert len(os.listdir(tmpdir)) == trainer.state.global_step

# Case 5. Multiple checkpoints, save_total_limit == 1.
# Best metric is found at step 1 and that checkpoint should be saved.
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
output_dir=tmpdir,
eval_strategy="steps",
eval_steps=1,
save_strategy="steps",
save_steps=1,
metric_for_best_model="accuracy",
compute_metrics=AlmostAccuracy(),
greater_is_better=True,
save_total_limit=1,
)

with patch.object(
trainer,
"_evaluate",
side_effect=evaluate_side_effect_factory(
[
{"eval_accuracy": 0.90},
{"eval_accuracy": 0.80},
{"eval_accuracy": 0.70},
]
),
):
trainer.train()

assert trainer.state.best_metric == 0.90
assert trainer.state.best_global_step == 1

best_ckpt = os.path.join(tmpdir, f"{PREFIX_CHECKPOINT_DIR}-{trainer.state.best_global_step}")
assert trainer.state.best_model_checkpoint == best_ckpt

assert len(os.listdir(tmpdir)) == 1

# Case 6. Saving happens more often and eval/save mismatch.
# `best_model_checkpoint` should be None due to a step mismatch.
with tempfile.TemporaryDirectory() as tmpdir:
trainer = get_regression_trainer(
output_dir=tmpdir,
eval_strategy="steps",
eval_steps=3,
save_strategy="steps",
save_steps=2,
metric_for_best_model="accuracy",
compute_metrics=AlmostAccuracy(),
greater_is_better=True,
)

with patch.object(
trainer,
"_evaluate",
side_effect=evaluate_side_effect_factory(
[
{"eval_accuracy": 0.90},
{"eval_accuracy": 0.80},
{"eval_accuracy": 0.70},
]
),
):
trainer.train()

assert trainer.state.best_metric == 0.90
assert trainer.state.best_global_step == 3

assert trainer.state.best_model_checkpoint is None

assert len(os.listdir(tmpdir)) == trainer.state.global_step // 2


@require_torch
@is_staging_test