Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NNCF] NNCF common accuracy aware training code pass mypy checks #2521

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .mypy.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[mypy]
files = nncf/common/sparsity
files = nncf/common/accuracy_aware_training
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You shouldn't replace the path, you should extend it.

follow_imports = silent
strict = True

Expand Down
94 changes: 51 additions & 43 deletions nncf/common/accuracy_aware_training/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pathlib
from abc import ABC
from abc import abstractmethod
from typing import Callable, Dict, List, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

from nncf.api.compression import CompressionAlgorithmController
from nncf.api.compression import CompressionStage
Expand Down Expand Up @@ -127,8 +127,8 @@ def initialize_training_loop_fns(
validate_fn: Callable[[TModel, Optional[int]], float],
configure_optimizers_fn: Callable[[], Tuple[OptimizerType, LRSchedulerType]],
dump_checkpoint_fn: Callable[[TModel, CompressionAlgorithmController, "TrainingRunner", str], None],
**kwargs,
):
**kwargs: Any,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This type hint is incorrect, review your general Python knowledge as to what type the **kwargs parameter is constrained to.

) -> None:
"""
Register the user-supplied functions to be used to control the training process.

Expand All @@ -146,7 +146,7 @@ def initialize_logging(
self,
log_dir: Optional[Union[str, pathlib.Path]] = None,
tensorboard_writer: Optional[TensorboardWriterType] = None,
):
) -> None:
"""
Initialize logging related variables

Expand All @@ -164,7 +164,7 @@ def load_best_checkpoint(self, model: TModel) -> float:
"""

@abstractmethod
def is_model_fully_compressed(self, compression_controller) -> bool:
def is_model_fully_compressed(self, compression_controller: CompressionAlgorithmController) -> bool:
"""
Check if model is fully compressed

Expand Down Expand Up @@ -194,7 +194,7 @@ def __init__(
self.maximal_absolute_accuracy_drop = accuracy_aware_training_params.get(
"maximal_absolute_accuracy_degradation"
)
self.maximal_total_epochs = accuracy_aware_training_params.get("maximal_total_epochs", AA_MAXIMAL_TOTAL_EPOCHS)
self.maximal_total_epochs: int = accuracy_aware_training_params.get("maximal_total_epochs", AA_MAXIMAL_TOTAL_EPOCHS)

self.verbose = verbose
self.dump_checkpoints = dump_checkpoints
Expand All @@ -213,8 +213,8 @@ def __init__(
self.current_val_metric_value = 0
self.current_loss = 0

self._compressed_training_history = []
self._best_checkpoint = None
self._compressed_training_history: List[Tuple[float, float]] = []
self._best_checkpoint: Optional[Tuple[str, float]] = None

self._train_epoch_fn = None
self._validate_fn = None
Expand All @@ -224,11 +224,15 @@ def __init__(
self._early_stopping_fn = None
self._update_learning_rate_fn = None

self._log_dir = None
self._checkpoint_save_dir = None
self._tensorboard_writer = None
self._log_dir: Optional[Union[str, pathlib.Path]] = None
self._checkpoint_save_dir: Optional[Union[str, pathlib.Path]] = None
self._tensorboard_writer: Optional[TensorboardWriterType] = None

def train_epoch(self, model, compression_controller):
def train_epoch(
self,
model: TModel,
compression_controller: CompressionAlgorithmController,
) -> None:
compression_controller.scheduler.epoch_step()
# assuming that epoch number is only used for logging in train_fn:
self.current_loss = self._train_epoch_fn(
Expand All @@ -241,7 +245,7 @@ def train_epoch(self, model, compression_controller):
self.training_epoch_count += 1
self.cumulative_epoch_count += 1

def dump_statistics(self, model, compression_controller):
def dump_statistics(self, model: TModel, compression_controller: CompressionAlgorithmController) -> None:
statistics = compression_controller.statistics()

if self.verbose:
Expand All @@ -259,15 +263,19 @@ def dump_statistics(self, model, compression_controller):

self.dump_checkpoint(model, compression_controller)

def calculate_minimal_tolerable_accuracy(self, uncompressed_model_accuracy: float):
def calculate_minimal_tolerable_accuracy(self, uncompressed_model_accuracy: float) -> None:
    """Compute and store the lowest accuracy the compressed model may reach.

    An absolute degradation budget, when configured, takes precedence;
    otherwise the relative (percentage) budget is applied to the
    uncompressed model accuracy.

    :param uncompressed_model_accuracy: Metric value of the original model.
    """
    absolute_drop = self.maximal_absolute_accuracy_drop
    if absolute_drop is None:
        relative_factor = 1 - 0.01 * self.maximal_relative_accuracy_drop
        self.minimal_tolerable_accuracy = uncompressed_model_accuracy * relative_factor
    else:
        self.minimal_tolerable_accuracy = uncompressed_model_accuracy - absolute_drop

def dump_checkpoint(self, model, compression_controller):
def dump_checkpoint(
self,
model: TModel,
compression_controller: CompressionAlgorithmController
) -> None:
is_best_checkpoint = (
self.best_val_metric_value == self.current_val_metric_value
and self.is_model_fully_compressed(compression_controller)
Expand All @@ -285,19 +293,19 @@ def dump_checkpoint(self, model, compression_controller):
if is_best_checkpoint:
self._save_best_checkpoint(model, compression_controller)

def configure_optimizers(self) -> None:
    """Build the optimizer and LR scheduler via the user-registered factory."""
    optimizer, lr_scheduler = self._configure_optimizers_fn()
    self.optimizer = optimizer
    self.lr_scheduler = lr_scheduler

def initialize_training_loop_fns(
self,
train_epoch_fn,
validate_fn,
configure_optimizers_fn,
dump_checkpoint_fn,
load_checkpoint_fn=None,
early_stopping_fn=None,
update_learning_rate_fn=None,
):
train_epoch_fn: Callable[[TModel, CompressionAlgorithmController], None],
validate_fn: Callable[[TModel, Optional[int]], float],
configure_optimizers_fn: Callable[[], Tuple[OptimizerType, LRSchedulerType]],
dump_checkpoint_fn: Callable[[TModel, CompressionAlgorithmController, "TrainingRunner", str], None],
load_checkpoint_fn: Optional[Callable[[TModel, str], None]] = None,
early_stopping_fn: Optional[Callable[[float], bool]] = None,
update_learning_rate_fn: Optional[Callable[[LRSchedulerType, float, float, float], None]] = None,
) -> None:
self._train_epoch_fn = train_epoch_fn
self._validate_fn = validate_fn
self._configure_optimizers_fn = configure_optimizers_fn
Expand All @@ -306,34 +314,34 @@ def initialize_training_loop_fns(
self._early_stopping_fn = early_stopping_fn
self._update_learning_rate_fn = update_learning_rate_fn

def initialize_logging(self, log_dir=None, tensorboard_writer=None):
self._log_dir = log_dir if log_dir is not None else osp.join(os.getcwd(), "runs")
def initialize_logging(
    self,
    log_dir: Optional[Union[str, pathlib.Path]] = None,
    tensorboard_writer: Optional[TensorboardWriterType] = None,
) -> None:
    """Initialize logging-related state.

    :param log_dir: Directory for logs and checkpoints; defaults to
        "<cwd>/runs" when not given.
    :param tensorboard_writer: Optional writer used for tensorboard logging.
    """
    # Normalize a possible pathlib.Path to str before augmenting the path.
    base_dir = str(log_dir) if log_dir is not None else osp.join(os.getcwd(), "runs")
    self._log_dir = configure_accuracy_aware_paths(base_dir)
    # Checkpoints are stored alongside the logs.
    self._checkpoint_save_dir = self._log_dir
    self._tensorboard_writer = tensorboard_writer

def stop_training(self, compression_controller):
def stop_training(self, compression_controller: CompressionAlgorithmController) -> bool:
    """Ask the user-supplied early-stopping hook whether training should stop.

    The hook is only consulted once the model is fully compressed; until
    then (or when no hook was registered) training always continues.
    """
    fully_compressed = self.is_model_fully_compressed(compression_controller)
    if not fully_compressed or self._early_stopping_fn is None:
        return False
    return self._early_stopping_fn(self.current_val_metric_value)

def _save_best_checkpoint(self, model, compression_controller):
def _save_best_checkpoint(self, model: TModel, compression_controller: CompressionAlgorithmController) -> None:
    """Persist the current model as the best checkpoint seen so far."""
    checkpoint_path = self._make_checkpoint_path(is_best=True)
    # Remember where the best model lives and its compression rate so that
    # load_best_checkpoint can restore it later.
    self._best_checkpoint = (checkpoint_path, compression_controller.compression_rate)
    self._save_checkpoint(model, compression_controller, checkpoint_path)
    nncf_logger.info(f"Saved the best model to {checkpoint_path}")

def load_best_checkpoint(self, model):
def load_best_checkpoint(self, model: TModel) -> float:
    """Restore the best checkpoint recorded during training.

    :param model: The model to load the checkpoint into.
    :return: The compression rate associated with the restored checkpoint.
    """
    checkpoint_path, checkpoint_compression_rate = self._best_checkpoint
    nncf_logger.info(f"Loading the best checkpoint found during training: {checkpoint_path}")
    self._load_checkpoint(model, checkpoint_path)
    return checkpoint_compression_rate

def is_model_fully_compressed(self, compression_controller) -> bool:
def is_model_fully_compressed(self, compression_controller: CompressionAlgorithmController) -> bool:
    """Return True when the controller reports the FULLY_COMPRESSED stage."""
    current_stage = compression_controller.compression_stage()
    return current_stage == CompressionStage.FULLY_COMPRESSED

@abstractmethod
def add_tensorboard_scalar(self, key, data, step):
def add_tensorboard_scalar(self, key: str, data: float, step: int) -> None:
"""
Add a scalar to tensorboard

Expand All @@ -343,7 +351,7 @@ def add_tensorboard_scalar(self, key, data, step):
"""

@abstractmethod
def add_tensorboard_image(self, key, data, step):
def add_tensorboard_image(self, key: str, data: PIL.Image.Image, step: int) -> None:
"""
Add an image to tensorboard

Expand Down Expand Up @@ -375,7 +383,7 @@ def _load_checkpoint(self, model: TModel, checkpoint_path: str) -> None:
"""

@abstractmethod
def _make_checkpoint_path(self, is_best, compression_rate=None):
def _make_checkpoint_path(self, is_best: bool, compression_rate: Optional[float] = None) -> str:
"""
Make a path to save the checkpoint there

Expand Down Expand Up @@ -423,15 +431,15 @@ def __init__(
self.maximal_compression_rate = maximal_compression_rate

self._best_checkpoints = {}
self._compression_rate_target = None
self.adaptive_controller = None
self.was_compression_increased_on_prev_step = None
self._compression_rate_target: Optional[float] = None
self.adaptive_controller: Optional[CompressionAlgorithmController] = None
self.was_compression_increased_on_prev_step: Optional[bool] = None

def dump_statistics(self, model, compression_controller):
def dump_statistics(self, model: TModel, compression_controller: CompressionAlgorithmController) -> None:
    """Record the current (rate, metric) point in the training history, then
    delegate the actual statistics dump to the base runner."""
    current_rate = self.compression_rate_target
    current_metric = self.current_val_metric_value
    self.update_training_history(current_rate, current_metric)
    super().dump_statistics(model, compression_controller)

def _save_best_checkpoint(self, model, compression_controller):
def _save_best_checkpoint(self, model: TModel, compression_controller: CompressionAlgorithmController) -> None:
best_path = self._make_checkpoint_path(is_best=True, compression_rate=self.compression_rate_target)

accuracy_budget = self.best_val_metric_value - self.minimal_tolerable_accuracy
Expand All @@ -445,7 +453,7 @@ def _save_best_checkpoint(self, model, compression_controller):
self._save_checkpoint(model, compression_controller, best_path)
nncf_logger.info(f"Saved the best model to {best_path}")

def load_best_checkpoint(self, model):
def load_best_checkpoint(self, model: TModel) -> float:
# load checkpoint with the highest compression rate and positive acc budget
possible_checkpoint_rates = self.get_compression_rates_with_positive_acc_budget()
if len(possible_checkpoint_rates) == 0:
Expand Down Expand Up @@ -473,16 +481,16 @@ def load_best_checkpoint(self, model):
return best_checkpoint_compression_rate

@property
def compression_rate_target(self):
def compression_rate_target(self) -> float:
    """Target compression rate; falls back to the adaptive controller's
    current rate when no explicit target has been set."""
    explicit_target = self._compression_rate_target
    if explicit_target is not None:
        return explicit_target
    return self.adaptive_controller.compression_rate

@compression_rate_target.setter
def compression_rate_target(self, value):
def compression_rate_target(self, value: float) -> None:
    """Set an explicit compression rate target, overriding the adaptive
    controller's current rate."""
    self._compression_rate_target = value

def update_training_history(self, compression_rate, metric_value):
def update_training_history(self, compression_rate: float, metric_value: float) -> None:
accuracy_budget = metric_value - self.minimal_tolerable_accuracy
self._compressed_training_history.append((compression_rate, accuracy_budget))

Expand All @@ -500,7 +508,7 @@ def update_training_history(self, compression_rate, metric_value):
plt.close(fig)

@property
def compressed_training_history(self):
def compressed_training_history(self) -> Dict[float, float]:
    """Mapping of compression rate -> accuracy budget accumulated during training."""
    return {rate: budget for rate, budget in self._compressed_training_history}

def get_compression_rates_with_positive_acc_budget(self) -> List[float]:
Expand Down
2 changes: 1 addition & 1 deletion nncf/common/accuracy_aware_training/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class TrainingLoopStatistics(Statistics):
relative_accuracy_degradation: float
accuracy_budget: float

def to_str(self):
def to_str(self) -> str:
stats_str = (
f"Uncompressed model accuracy: {self.uncompressed_accuracy:.4f}\n"
f"Compressed model accuracy: {self.compressed_accuracy:.4f}\n"
Expand Down
Loading
Loading