Skip to content

Commit

Permalink
Update evaluation logging test (#18896)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Nov 18, 2023
1 parent b8a96fe commit de7faf9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 31 deletions.
52 changes: 21 additions & 31 deletions tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from contextlib import redirect_stdout
from io import StringIO
from unittest import mock
from unittest.mock import call
from unittest.mock import ANY, call

import numpy as np
import pytest
Expand Down Expand Up @@ -527,57 +527,47 @@ def test_step(self, batch, batch_idx):
trainer = Trainer(
default_root_dir=tmpdir,
logger=TensorBoardLogger(tmpdir),
limit_train_batches=2,
limit_train_batches=1,
limit_val_batches=2,
limit_test_batches=2,
log_every_n_steps=1,
max_epochs=2,
)

# Train the model ⚡
trainer.fit(model)

# hp_metric + 2 steps + epoch + 2 steps + epoch
expected_num_calls = 1 + 2 + 1 + 2 + 1

assert set(trainer.callback_metrics) == {
"train_loss",
"valid_loss_0_epoch",
"valid_loss_0",
"valid_loss_1",
}
assert len(mock_log_metrics.mock_calls) == expected_num_calls
assert mock_log_metrics.mock_calls[0] == call({"hp_metric": -1}, 0)
assert mock_log_metrics.mock_calls == [
call({"hp_metric": -1}, 0),
call(metrics={"train_loss": ANY, "epoch": 0}, step=0),
call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=0),
call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=1),
call(metrics={"valid_loss_0_epoch": ANY, "valid_loss_1": ANY, "epoch": 0}, step=0),
call(metrics={"train_loss": ANY, "epoch": 1}, step=1),
call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=2),
call(metrics={"valid_loss_0_step": ANY, "valid_loss_2": ANY}, step=3),
call(metrics={"valid_loss_0_epoch": ANY, "valid_loss_1": ANY, "epoch": 1}, step=1),
]

def get_metrics_at_idx(idx):
mock_call = mock_log_metrics.mock_calls[idx]
return mock_call.kwargs["metrics"] if _PYTHON_GREATER_EQUAL_3_8_0 else mock_call[2]["metrics"]

expected = {"valid_loss_0_step", "valid_loss_2"}
assert set(get_metrics_at_idx(1)) == expected
assert set(get_metrics_at_idx(2)) == expected

assert get_metrics_at_idx(1)["valid_loss_0_step"] == model.val_losses[2]
assert get_metrics_at_idx(2)["valid_loss_0_step"] == model.val_losses[3]

assert set(get_metrics_at_idx(3)) == {"valid_loss_0_epoch", "valid_loss_1", "epoch"}

assert get_metrics_at_idx(3)["valid_loss_1"] == torch.stack(model.val_losses[2:4]).mean()

expected = {"valid_loss_0_step", "valid_loss_2"}
assert set(get_metrics_at_idx(4)) == expected
assert set(get_metrics_at_idx(5)) == expected

assert get_metrics_at_idx(4)["valid_loss_0_step"] == model.val_losses[4]
assert get_metrics_at_idx(5)["valid_loss_0_step"] == model.val_losses[5]

assert set(get_metrics_at_idx(6)) == {"valid_loss_0_epoch", "valid_loss_1", "epoch"}

assert get_metrics_at_idx(6)["valid_loss_1"] == torch.stack(model.val_losses[4:]).mean()
assert get_metrics_at_idx(2)["valid_loss_0_step"] == model.val_losses[2]
assert get_metrics_at_idx(3)["valid_loss_0_step"] == model.val_losses[3]
assert get_metrics_at_idx(4)["valid_loss_1"] == torch.stack(model.val_losses[2:4]).mean()
assert get_metrics_at_idx(6)["valid_loss_0_step"] == model.val_losses[4]
assert get_metrics_at_idx(7)["valid_loss_0_step"] == model.val_losses[5]
assert get_metrics_at_idx(8)["valid_loss_1"] == torch.stack(model.val_losses[4:]).mean()

results = trainer.test(model)
assert set(trainer.callback_metrics) == {
"test_loss",
}
assert set(trainer.callback_metrics) == {"test_loss"}
assert set(results[0]) == {"test_loss"}


Expand Down
3 changes: 3 additions & 0 deletions tests/tests_pytorch/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1925,6 +1925,9 @@ def training_step(self, batch, batch_idx):
def test_trainer_config_strategy(monkeypatch, trainer_kwargs, strategy_cls, accelerator_cls, devices):
if trainer_kwargs.get("accelerator") == "cuda":
mock_cuda_count(monkeypatch, trainer_kwargs["devices"])
if trainer_kwargs.get("accelerator") == "auto":
# current parametrizations assume non-CUDA env
mock_cuda_count(monkeypatch, 0)

trainer = Trainer(**trainer_kwargs)

Expand Down

0 comments on commit de7faf9

Please sign in to comment.