diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index eb6848cc..208fab0f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -37,7 +37,7 @@ jobs: strategy: fail-fast: true matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} diff --git a/VERSION b/VERSION index 9c6d6293..fdd3be6d 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.6.1 +1.6.2 diff --git a/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py b/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py index 434c949c..1d8c76aa 100644 --- a/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py +++ b/pytorch_widedeep/self_supervised_training/_base_contrastive_denoising_trainer.py @@ -1,7 +1,9 @@ import os import sys +import json import warnings from abc import ABC, abstractmethod +from pathlib import Path import numpy as np import torch @@ -31,6 +33,11 @@ from pytorch_widedeep.preprocessing.tab_preprocessor import TabPreprocessor +# There is quite a lot of code repetition between the +# BaseContrastiveDenoisingTrainer and the BaseEncoderDecoderTrainer. Given +# how differently they are instantiated I am happy to tolerate this +# repetition. However, if the code base grows, it might be worth refactoring +# this code class BaseContrastiveDenoisingTrainer(ABC): def __init__( self, @@ -96,45 +103,82 @@ def pretrain( ): raise NotImplementedError("Trainer.pretrain method not implemented") - @abstractmethod def save( self, path: str, save_state_dict: bool, + save_optimizer: bool, model_filename: str, ): - raise NotImplementedError("Trainer.save method not implemented") + r"""Saves the model, training and evaluation history (if any) to disk - def _set_loss_fn(self, **kwargs): - if self.loss_type in ["contrastive", "both"]: - temperature = kwargs.get("temperature", 0.1) - reduction = kwargs.get("reduction", "mean") - self.contrastive_loss = InfoNCELoss(temperature, reduction) + Parameters + ---------- + path: str + path to the directory where the model and the feature importance + attribute will be saved. 
+ save_state_dict: bool, default = False + Boolean indicating whether to save directly the model or the + model's state dictionary + save_optimizer: bool, default = False + Boolean indicating whether to save the optimizer or not + model_filename: str, Optional, default = "ed_model.pt" + filename where the model weights will be store + """ - if self.loss_type in ["denoising", "both"]: - lambda_cat = kwargs.get("lambda_cat", 1.0) - lambda_cont = kwargs.get("lambda_cont", 1.0) - reduction = kwargs.get("reduction", "mean") - self.denoising_loss = DenoisingLoss(lambda_cat, lambda_cont, reduction) + self._save_history(path) - def _compute_loss( - self, - g_projs: Optional[Tuple[Tensor, Tensor]], - x_cat_and_cat_: Optional[Tuple[Tensor, Tensor]], - x_cont_and_cont_: Optional[Tuple[Tensor, Tensor]], - ) -> Tensor: - contrastive_loss = ( - self.contrastive_loss(g_projs) - if self.loss_type in ["contrastive", "both"] - else torch.tensor(0.0) + self._save_model_and_optimizer( + path, save_state_dict, save_optimizer, model_filename ) - denoising_loss = ( - self.denoising_loss(x_cat_and_cat_, x_cont_and_cont_) - if self.loss_type in ["denoising", "both"] - else torch.tensor(0.0) + + def _save_history(self, path: str): + # 'history' here refers to both, the training/evaluation history and + # the lr history + save_dir = Path(path) + history_dir = save_dir / "history" + history_dir.mkdir(exist_ok=True, parents=True) + + # the trainer is run with the History Callback by default + with open(history_dir / "train_eval_history.json", "w") as teh: + json.dump(self.history, teh) # type: ignore[attr-defined] + + has_lr_history = any( + [clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks] ) + if self.lr_scheduler is not None and has_lr_history: + with open(history_dir / "lr_history.json", "w") as lrh: + json.dump(self.lr_history, lrh) # type: ignore[attr-defined] - return contrastive_loss + denoising_loss + def _save_model_and_optimizer( + self, + path: str, + save_state_dict: bool, + save_optimizer: bool, + model_filename: str, + ): + + model_path = Path(path) / model_filename + if save_state_dict and save_optimizer: + torch.save( + { + "model_state_dict": self.cd_model.state_dict(), + "optimizer_state_dict": self.optimizer.state_dict(), + }, + model_path, + ) + elif save_state_dict and not save_optimizer: + torch.save(self.cd_model.state_dict(), model_path) + elif not save_state_dict and save_optimizer: + torch.save( + { + "model": self.cd_model, + "optimizer": self.optimizer, # this can be a MultipleOptimizer + }, + model_path, + ) + else: + torch.save(self.cd_model, model_path) def _set_reduce_on_plateau_criterion( self, lr_scheduler, reducelronplateau_criterion @@ -233,6 +277,37 @@ def _set_device_and_num_workers(**kwargs): num_workers = kwargs.get("num_workers", default_num_workers) return device, num_workers + def _set_loss_fn(self, **kwargs): + if self.loss_type in ["contrastive", "both"]: + temperature = kwargs.get("temperature", 0.1) + reduction = kwargs.get("reduction", "mean") + self.contrastive_loss = InfoNCELoss(temperature, reduction) + + if self.loss_type in ["denoising", "both"]: + lambda_cat = kwargs.get("lambda_cat", 1.0) + lambda_cont = kwargs.get("lambda_cont", 1.0) + reduction = kwargs.get("reduction", "mean") + self.denoising_loss = DenoisingLoss(lambda_cat, lambda_cont, reduction) + + def _compute_loss( + self, + g_projs: Optional[Tuple[Tensor, Tensor]], + x_cat_and_cat_: Optional[Tuple[Tensor, Tensor]], + x_cont_and_cont_: Optional[Tuple[Tensor, Tensor]], + ) -> Tensor: + 
contrastive_loss = ( + self.contrastive_loss(g_projs) + if self.loss_type in ["contrastive", "both"] + else torch.tensor(0.0) + ) + denoising_loss = ( + self.denoising_loss(x_cat_and_cat_, x_cont_and_cont_) + if self.loss_type in ["denoising", "both"] + else torch.tensor(0.0) + ) + + return contrastive_loss + denoising_loss + @staticmethod def _check_model_is_supported(model: ModelWithAttention): if model.__class__.__name__ == "TabPerceiver": diff --git a/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py b/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py index a2b644e5..fe0aa4a8 100644 --- a/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py +++ b/pytorch_widedeep/self_supervised_training/_base_encoder_decoder_trainer.py @@ -1,7 +1,9 @@ import os import sys +import json import warnings from abc import ABC, abstractmethod +from pathlib import Path import numpy as np import torch @@ -66,7 +68,7 @@ def __init__( def pretrain( self, X_tab: np.ndarray, - X_val: Optional[np.ndarray], + X_tab_val: Optional[np.ndarray], val_split: Optional[float], validation_freq: int, n_epochs: int, @@ -74,14 +76,82 @@ def pretrain( ): raise NotImplementedError("Trainer.pretrain method not implemented") - @abstractmethod def save( self, path: str, save_state_dict: bool, + save_optimizer: bool, model_filename: str, ): - raise NotImplementedError("Trainer.save method not implemented") + r"""Saves the model, training and evaluation history (if any) to disk + + Parameters + ---------- + path: str + path to the directory where the model and the feature importance + attribute will be saved. + save_state_dict: bool, default = False + Boolean indicating whether to save directly the model or the + model's state dictionary + save_optimizer: bool, default = False + Boolean indicating whether to save the optimizer or not + model_filename: str, Optional, default = "ed_model.pt" + filename where the model weights will be store + """ + + self._save_history(path) + + self._save_model_and_optimizer( + path, save_state_dict, save_optimizer, model_filename + ) + + def _save_history(self, path: str): + # 'history' here refers to both, the training/evaluation history and + # the lr history + save_dir = Path(path) + history_dir = save_dir / "history" + history_dir.mkdir(exist_ok=True, parents=True) + + # the trainer is run with the History Callback by default + with open(history_dir / "train_eval_history.json", "w") as teh: + json.dump(self.history, teh) # type: ignore[attr-defined] + + has_lr_history = any( + [clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks] + ) + if self.lr_scheduler is not None and has_lr_history: + with open(history_dir / "lr_history.json", "w") as lrh: + json.dump(self.lr_history, lrh) # type: ignore[attr-defined] + + def _save_model_and_optimizer( + self, + path: str, + save_state_dict: bool, + save_optimizer: bool, + model_filename: str, + ): + + model_path = Path(path) / model_filename + if save_state_dict and save_optimizer: + torch.save( + { + "model_state_dict": self.ed_model.state_dict(), + "optimizer_state_dict": self.optimizer.state_dict(), + }, + model_path, + ) + elif save_state_dict and not save_optimizer: + torch.save(self.ed_model.state_dict(), model_path) + elif not save_state_dict and save_optimizer: + torch.save( + { + "model": self.ed_model, + "optimizer": self.optimizer, # this can be a MultipleOptimizer + }, + model_path, + ) + else: + torch.save(self.ed_model, model_path) def 
_set_reduce_on_plateau_criterion( self, lr_scheduler, reducelronplateau_criterion diff --git a/pytorch_widedeep/self_supervised_training/contrastive_denoising_trainer.py b/pytorch_widedeep/self_supervised_training/contrastive_denoising_trainer.py index 99475edd..15c4d6ad 100644 --- a/pytorch_widedeep/self_supervised_training/contrastive_denoising_trainer.py +++ b/pytorch_widedeep/self_supervised_training/contrastive_denoising_trainer.py @@ -1,6 +1,3 @@ -import json -from pathlib import Path - import numpy as np import torch from tqdm import trange @@ -259,46 +256,6 @@ def fit( X_tab, X_tab_val, val_split, validation_freq, n_epochs, batch_size ) - def save( - self, - path: str, - save_state_dict: bool = False, - model_filename: str = "cd_model.pt", - ): - r"""Saves the model, training and evaluation history (if any) to disk - - Parameters - ---------- - path: str - path to the directory where the model and the feature importance - attribute will be saved. - save_state_dict: bool, default = False - Boolean indicating whether to save directly the model or the - model's state dictionary - model_filename: str, Optional, default = "cd_model.pt" - filename where the model weights will be store - """ - save_dir = Path(path) - history_dir = save_dir / "history" - history_dir.mkdir(exist_ok=True, parents=True) - - # the trainer is run with the History Callback by default - with open(history_dir / "train_eval_history.json", "w") as teh: - json.dump(self.history, teh) # type: ignore[attr-defined] - - has_lr_history = any( - [clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks] - ) - if self.lr_scheduler is not None and has_lr_history: - with open(history_dir / "lr_history.json", "w") as lrh: - json.dump(self.lr_history, lrh) # type: ignore[attr-defined] - - model_path = save_dir / model_filename - if save_state_dict: - torch.save(self.cd_model.state_dict(), model_path) - else: - torch.save(self.cd_model, model_path) - def _train_step(self, X_tab: Tensor, batch_idx: int) -> float: X = X_tab.to(self.device) @@ -337,7 +294,7 @@ def _train_eval_split( train_set = TensorDataset(torch.from_numpy(X)) eval_set = TensorDataset(torch.from_numpy(X_tab_val)) elif val_split is not None: - X_tr, X_tab_val = train_test_split( + X_tr, X_tab_val = train_test_split( # type: ignore X, test_size=val_split, random_state=self.seed ) train_set = TensorDataset(torch.from_numpy(X_tr)) diff --git a/pytorch_widedeep/self_supervised_training/encoder_decoder_trainer.py b/pytorch_widedeep/self_supervised_training/encoder_decoder_trainer.py index dcce9158..6b2ce440 100644 --- a/pytorch_widedeep/self_supervised_training/encoder_decoder_trainer.py +++ b/pytorch_widedeep/self_supervised_training/encoder_decoder_trainer.py @@ -1,6 +1,3 @@ -import json -from pathlib import Path - import numpy as np import torch from tqdm import trange @@ -211,46 +208,6 @@ def fit( X_tab, X_tab_val, val_split, validation_freq, n_epochs, batch_size ) - def save( - self, - path: str, - save_state_dict: bool = False, - model_filename: str = "ed_model.pt", - ): - r"""Saves the model, training and evaluation history (if any) to disk - - Parameters - ---------- - path: str - path to the directory where the model and the feature importance - attribute will be saved. 
- save_state_dict: bool, default = False - Boolean indicating whether to save directly the model or the - model's state dictionary - model_filename: str, Optional, default = "ed_model.pt" - filename where the model weights will be store - """ - save_dir = Path(path) - history_dir = save_dir / "history" - history_dir.mkdir(exist_ok=True, parents=True) - - # the trainer is run with the History Callback by default - with open(history_dir / "train_eval_history.json", "w") as teh: - json.dump(self.history, teh) # type: ignore[attr-defined] - - has_lr_history = any( - [clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks] - ) - if self.lr_scheduler is not None and has_lr_history: - with open(history_dir / "lr_history.json", "w") as lrh: - json.dump(self.lr_history, lrh) # type: ignore[attr-defined] - - model_path = save_dir / model_filename - if save_state_dict: - torch.save(self.ed_model.state_dict(), model_path) - else: - torch.save(self.ed_model, model_path) - def explain(self, X_tab: np.ndarray, save_step_masks: bool = False): raise NotImplementedError( "The 'explain' is currently not implemented for Self Supervised Pretraining" @@ -294,7 +251,7 @@ def _train_eval_split( train_set = TensorDataset(torch.from_numpy(X)) eval_set = TensorDataset(torch.from_numpy(X_tab_val)) elif val_split is not None: - X_tr, X_tab_val = train_test_split( + X_tr, X_tab_val = train_test_split( # type: ignore X, test_size=val_split, random_state=self.seed ) train_set = TensorDataset(torch.from_numpy(X_tr)) diff --git a/pytorch_widedeep/training/_base_trainer.py b/pytorch_widedeep/training/_base_trainer.py index 36cc17f3..971031e4 100644 --- a/pytorch_widedeep/training/_base_trainer.py +++ b/pytorch_widedeep/training/_base_trainer.py @@ -1,7 +1,9 @@ import os import sys +import json import warnings from abc import ABC, abstractmethod +from pathlib import Path import numpy as np import torch @@ -108,6 +110,7 @@ def save( self, path: str, save_state_dict: bool, + save_optimizer: bool, model_filename: str, ): raise NotImplementedError("Trainer.save method not implemented") @@ -318,6 +321,61 @@ def _set_callbacks_and_metrics( self.callback_container.set_model(self.model) self.callback_container.set_trainer(self) + def _save_history(self, path: str): + # 'history' here refers to both, the training/evaluation history and + # the lr history + save_dir = Path(path) + history_dir = save_dir / "history" + history_dir.mkdir(exist_ok=True, parents=True) + + # the trainer is run with the History Callback by default + with open(history_dir / "train_eval_history.json", "w") as teh: + json.dump(self.history, teh) # type: ignore[attr-defined] + + has_lr_history = any( + [clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks] + ) + if self.lr_scheduler is not None and has_lr_history: + with open(history_dir / "lr_history.json", "w") as lrh: + json.dump(self.lr_history, lrh) # type: ignore[attr-defined] + + def _save_model_and_optimizer( + self, + path: str, + save_state_dict: bool, + save_optimizer: bool, + model_filename: str, + ): + + model_path = Path(path) / model_filename + if save_state_dict and save_optimizer: + torch.save( + { + "model_state_dict": self.model.state_dict(), + "optimizer_state_dict": ( + self.optimizer.state_dict() + if not isinstance(self.optimizer, MultipleOptimizer) + else { + k: v.state_dict() # type: ignore[union-attr] + for k, v in self.optimizer._optimizers.items() + } + ), + }, + model_path, + ) + elif save_state_dict and not save_optimizer: + 
torch.save(self.model.state_dict(), model_path) + elif not save_state_dict and save_optimizer: + torch.save( + { + "model": self.model, + "optimizer": self.optimizer, # this can be a MultipleOptimizer + }, + model_path, + ) + else: + torch.save(self.model, model_path) + @staticmethod def _check_inputs( model, diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py index 85b5d591..13617bc4 100644 --- a/pytorch_widedeep/training/trainer.py +++ b/pytorch_widedeep/training/trainer.py @@ -467,21 +467,15 @@ def fit( # noqa: C901 self.transforms, **lds_args, ) - if isinstance(custom_dataloader, type): - if issubclass(custom_dataloader, DataLoader): - train_loader = custom_dataloader( # type: ignore[misc] - dataset=train_set, - batch_size=batch_size, - num_workers=self.num_workers, - **dataloader_args, - ) - else: - NotImplementedError( - "Custom DataLoader must be a subclass of " - "torch.utils.data.DataLoader, please see the " - "pytorch documentation or examples in " - "pytorch_widedeep.dataloaders" - ) + if custom_dataloader is not None: + # make sure is callable (and HAS to be an subclass of DataLoader) + assert isinstance(custom_dataloader, type) + train_loader = custom_dataloader( # type: ignore[misc] + dataset=train_set, + batch_size=batch_size, + num_workers=self.num_workers, + **dataloader_args, + ) else: train_loader = DataLoaderDefault( dataset=train_set, @@ -794,6 +788,7 @@ def save( self, path: str, save_state_dict: bool = False, + save_optimizer: bool = False, model_filename: str = "wd_model.pt", ): r"""Saves the model, training and evaluation history, and the @@ -822,35 +817,23 @@ def save( path to the directory where the model and the feature importance attribute will be saved. save_state_dict: bool, default = False - Boolean indicating whether to save directly the model or the - model's state dictionary + Boolean indicating whether to save directly the model + (and optimizer) or the model's (and optimizer's) state + dictionary + save_optimizer: bool, default = False + Boolean indicating whether to save the optimizer model_filename: str, Optional, default = "wd_model.pt" filename where the model weights will be store """ - save_dir = Path(path) - history_dir = save_dir / "history" - history_dir.mkdir(exist_ok=True, parents=True) - - # the trainer is run with the History Callback by default - with open(history_dir / "train_eval_history.json", "w") as teh: - json.dump(self.history, teh) # type: ignore[attr-defined] + self._save_history(path) - has_lr_history = any( - [clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks] + self._save_model_and_optimizer( + path, save_state_dict, save_optimizer, model_filename ) - if self.lr_scheduler is not None and has_lr_history: - with open(history_dir / "lr_history.json", "w") as lrh: - json.dump(self.lr_history, lrh) # type: ignore[attr-defined] - - model_path = save_dir / model_filename - if save_state_dict: - torch.save(self.model.state_dict(), model_path) - else: - torch.save(self.model, model_path) if self.model.is_tabnet: - with open(save_dir / "feature_importance.json", "w") as fi: + with open(Path(path) / "feature_importance.json", "w") as fi: json.dump(self.feature_importance, fi) @alias("n_epochs", ["finetune_epochs", "warmup_epochs"]) diff --git a/pytorch_widedeep/training/trainer_from_folder.py b/pytorch_widedeep/training/trainer_from_folder.py index fc58e4c7..e2159b71 100644 --- a/pytorch_widedeep/training/trainer_from_folder.py +++ b/pytorch_widedeep/training/trainer_from_folder.py @@ 
-1,6 +1,3 @@ -import json -from pathlib import Path - import numpy as np import torch import torch.nn.functional as F @@ -408,28 +405,29 @@ def save( self, path: str, save_state_dict: bool = False, + save_optimizer: bool = False, model_filename: str = "wd_model.pt", ): # pragma: no cover - save_dir = Path(path) - history_dir = save_dir / "history" - history_dir.mkdir(exist_ok=True, parents=True) - - # the trainer is run with the History Callback by default - with open(history_dir / "train_eval_history.json", "w") as teh: - json.dump(self.history, teh) # type: ignore[attr-defined] + """ + Parameters + ---------- + path: str + path to the directory where the model and the feature importance + attribute will be saved. + save_state_dict: bool, default = False + Boolean indicating whether to save directly the model + (and optimizer) or the model's (and optimizer's) state + dictionary + save_optimizer: bool, default = False + Boolean indicating whether to save the optimizer + model_filename: str, Optional, default = "wd_model.pt" + filename where the model weights will be store + """ + self._save_history(path) - has_lr_history = any( - [clbk.__class__.__name__ == "LRHistory" for clbk in self.callbacks] + self._save_model_and_optimizer( + path, save_state_dict, save_optimizer, model_filename ) - if self.lr_scheduler is not None and has_lr_history: - with open(history_dir / "lr_history.json", "w") as lrh: - json.dump(self.lr_history, lrh) # type: ignore[attr-defined] - - model_path = save_dir / model_filename - if save_state_dict: - torch.save(self.model.state_dict(), model_path) - else: - torch.save(self.model, model_path) @alias("n_epochs", ["finetune_epochs", "warmup_epochs"]) @alias("max_lr", ["finetune_max_lr", "warmup_max_lr"]) diff --git a/pytorch_widedeep/version.py b/pytorch_widedeep/version.py index f49459c7..51bbb3f2 100644 --- a/pytorch_widedeep/version.py +++ b/pytorch_widedeep/version.py @@ -1 +1 @@ -__version__ = "1.6.1" +__version__ = "1.6.2" diff --git a/tests/test_model_functioning/test_save_optimizer.py b/tests/test_model_functioning/test_save_optimizer.py new file mode 100644 index 00000000..6376639f --- /dev/null +++ b/tests/test_model_functioning/test_save_optimizer.py @@ -0,0 +1,165 @@ +import os +import shutil + +import numpy as np +import torch +import pandas as pd +import pytest + +from pytorch_widedeep import Trainer +from pytorch_widedeep.models import Wide, TabMlp, WideDeep +from pytorch_widedeep.metrics import Accuracy +from pytorch_widedeep.preprocessing import TabPreprocessor, WidePreprocessor + +full_path = os.path.realpath(__file__) +path = os.path.split(full_path)[0] +save_path = os.path.join(path, "test_save_optimizer_dir") + + +data = { + "categorical_1": ["a", "b", "c", "d"] * 16, + "categorical_2": ["e", "f", "g", "h"] * 16, + "continuous_1": [1, 2, 3, 4] * 16, + "continuous_2": [5, 6, 7, 8] * 16, + "target": [0, 1] * 32, +} + +df = pd.DataFrame(data) + + +cat_cols = ["categorical_1", "categorical_2"] +wide_preprocessor = WidePreprocessor(wide_cols=cat_cols) +X_wide = wide_preprocessor.fit_transform(df) + +tab_preprocessor = TabPreprocessor( + cat_embed_cols=cat_cols, + continuous_cols=["continuous_1", "continuous_2"], + scale=True, +) +X_tab = tab_preprocessor.fit_transform(df) + +wide = Wide(input_dim=np.unique(X_wide).shape[0], pred_dim=1) + +tab_mlp = TabMlp( + column_idx=tab_preprocessor.column_idx, + cat_embed_input=tab_preprocessor.cat_embed_input, + continuous_cols=["continuous_1", "continuous_2"], + mlp_hidden_dims=[16, 8], +) + + 
+@pytest.mark.parametrize("save_state_dict", [True, False]) +def test_save_one_optimizer(save_state_dict): + + model = WideDeep(wide=wide, deeptabular=tab_mlp) + + trainer = Trainer( + model, + objective="binary", + optimizer=torch.optim.AdamW(model.parameters(), lr=0.001), + metrics=[Accuracy()], + ) + + trainer.fit(X_wide=X_wide, X_tab=X_tab, target=df["target"].values, n_epochs=1) + + trainer.save( + path=save_path, + save_state_dict=save_state_dict, + save_optimizer=True, + model_filename="model_and_optimizer.pt", + ) + + checkpoint = torch.load(os.path.join(save_path, "model_and_optimizer.pt")) + + if save_state_dict: + new_model = WideDeep(wide=wide, deeptabular=tab_mlp) + # just to change the initial weights + new_model.wide.wide_linear.weight.data = torch.nn.init.xavier_normal_( + new_model.wide.wide_linear.weight + ) + new_optimizer = torch.optim.AdamW(new_model.parameters(), lr=0.001) + new_model.load_state_dict(checkpoint["model_state_dict"]) + new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + else: + # This else statement is mostly testing that it runs, as it does not + # involved loading a state_dict + saved_objects = torch.load(os.path.join(save_path, "model_and_optimizer.pt")) + new_model = saved_objects["model"] + new_optimizer = saved_objects["optimizer"] + + shutil.rmtree(save_path) + + assert torch.all( + new_model.wide.wide_linear.weight.data == model.wide.wide_linear.weight.data + ) and torch.all( + new_optimizer.state_dict()["state"][1]["exp_avg"] + == trainer.optimizer.state_dict()["state"][1]["exp_avg"] + ) + + +@pytest.mark.parametrize("save_state_dict", [True, False]) +def test_save_multiple_optimizers(save_state_dict): + + model = WideDeep(wide=wide, deeptabular=tab_mlp) + + wide_opt = torch.optim.AdamW(model.wide.parameters(), lr=0.001) + deep_opt = torch.optim.AdamW(model.deeptabular.parameters(), lr=0.001) + + optimizers = {"wide": wide_opt, "deeptabular": deep_opt} + + trainer = Trainer( + model, + objective="binary", + optimizers=optimizers, + metrics=[Accuracy()], + ) + + trainer.fit(X_wide=X_wide, X_tab=X_tab, target=df["target"].values, n_epochs=1) + + trainer.save( + path=save_path, + save_state_dict=save_state_dict, + save_optimizer=True, + model_filename="model_and_optimizer.pt", + ) + + checkpoint = torch.load(os.path.join(save_path, "model_and_optimizer.pt")) + + if save_state_dict: + new_model = WideDeep(wide=wide, deeptabular=tab_mlp) + # just to change the initial weights + new_model.wide.wide_linear.weight.data = torch.nn.init.xavier_normal_( + new_model.wide.wide_linear.weight + ) + + new_wide_opt = torch.optim.AdamW(model.wide.parameters(), lr=0.001) + new_deep_opt = torch.optim.AdamW(model.deeptabular.parameters(), lr=0.001) + new_model.load_state_dict(checkpoint["model_state_dict"]) + new_wide_opt.load_state_dict(checkpoint["optimizer_state_dict"]["wide"]) + new_deep_opt.load_state_dict(checkpoint["optimizer_state_dict"]["deeptabular"]) + else: + # This else statement is mostly testing that it runs, as it does not + # involved loading a state_dict + saved_objects = torch.load(os.path.join(save_path, "model_and_optimizer.pt")) + new_model = saved_objects["model"] + new_optimizers = saved_objects["optimizer"] + new_wide_opt = new_optimizers._optimizers["wide"] + new_deep_opt = new_optimizers._optimizers["deeptabular"] + + shutil.rmtree(save_path) + + assert ( + torch.all( + new_model.wide.wide_linear.weight.data == model.wide.wide_linear.weight.data + ) + and torch.all( + new_wide_opt.state_dict()["state"][1]["exp_avg"] + == 
trainer.optimizer._optimizers["wide"].state_dict()["state"][1]["exp_avg"] + ) + and torch.all( + new_deep_opt.state_dict()["state"][1]["exp_avg"] + == trainer.optimizer._optimizers["deeptabular"].state_dict()["state"][1][ + "exp_avg" + ] + ) + ) diff --git a/tests/test_self_supervised/test_ss_miscellaneous.py b/tests/test_self_supervised/test_ss_miscellaneous.py index f6f2d54b..244f14ac 100644 --- a/tests/test_self_supervised/test_ss_miscellaneous.py +++ b/tests/test_self_supervised/test_ss_miscellaneous.py @@ -1,6 +1,7 @@ import os import shutil import string +from copy import deepcopy import numpy as np import torch @@ -103,7 +104,12 @@ def test_save_and_load(model_type): embed_module = model.cat_embed.embed embeddings = embed_module.weight.data - trainer.save("tests/test_self_supervised/model_dir/", model_filename="ss_model.pt") + trainer.save( + path="tests/test_self_supervised/model_dir/", + save_optimizer=False, + save_state_dict=False, + model_filename="ss_model.pt", + ) new_model = torch.load("tests/test_self_supervised/model_dir/ss_model.pt") if model_type == "mlp": @@ -117,6 +123,98 @@ def test_save_and_load(model_type): assert torch.allclose(embeddings, new_embeddings) +@pytest.mark.parametrize( + "model_type", + ["encoder_decoder", "contrastive_denoising"], +) +@pytest.mark.parametrize( + "save_state_dict", + [True, False], +) +def test_save_model_and_optimizer(model_type, save_state_dict): + if model_type == "encoder_decoder": + model = TabMlp( + column_idx=non_transf_preprocessor.column_idx, + cat_embed_input=non_transf_preprocessor.cat_embed_input, + continuous_cols=non_transf_preprocessor.continuous_cols, + mlp_hidden_dims=[16, 8], + ) + X = X_tab + elif model_type == "contrastive_denoising": + model = TabTransformer( + column_idx=transf_preprocessor.column_idx, + cat_embed_input=transf_preprocessor.cat_embed_input, + continuous_cols=transf_preprocessor.continuous_cols, + embed_continuous=True, + embed_continuous_method="standard", + n_heads=2, + n_blocks=2, + ) + X = X_tab_transf + + if model_type == "encoder_decoder": + trainer = EncoderDecoderTrainer( + encoder=model, + callbacks=[LRHistory(n_epochs=5)], + masked_prob=0.2, + verbose=0, + ) + elif model_type == "contrastive_denoising": + trainer = ContrastiveDenoisingTrainer( + model=model, + preprocessor=transf_preprocessor, + callbacks=[LRHistory(n_epochs=5)], + verbose=0, + ) + + trainer.pretrain(X, n_epochs=2, batch_size=16) + + trainer.save( + path="tests/test_self_supervised/model_dir/", + save_optimizer=True, + save_state_dict=save_state_dict, + model_filename="model_and_optimizer.pt", + ) + + checkpoint = torch.load( + os.path.join("tests/test_self_supervised/model_dir/", "model_and_optimizer.pt") + ) + + if save_state_dict: + if model_type == "encoder_decoder": + new_model = deepcopy(trainer.ed_model) + # just to change some weights + new_model.encoder.cat_embed.embed_layers.emb_layer_col1.weight.data = ( + torch.nn.init.xavier_normal_( + new_model.encoder.cat_embed.embed_layers.emb_layer_col1.weight + ) + ) + new_optimizer = torch.optim.AdamW(new_model.parameters()) + + new_model.load_state_dict(checkpoint["model_state_dict"]) + new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + else: + # Best unit test ever! 
but this is to avoid the "Only Tensors
+            # created explicitly by the user (graph leaves) support the
+            # deepcopy protocol at the moment" error
+            return True
+    else:
+        # This else statement is mostly testing that it runs, as it does not
+        # involve loading a state_dict
+        saved_objects = torch.load(
+            os.path.join(
+                "tests/test_self_supervised/model_dir/", "model_and_optimizer.pt"
+            )
+        )
+        new_optimizer = saved_objects["optimizer"]
+
+    shutil.rmtree("tests/test_self_supervised/model_dir/")
+    assert torch.all(
+        new_optimizer.state_dict()["state"][1]["exp_avg"]
+        == trainer.optimizer.state_dict()["state"][1]["exp_avg"]
+    )
+
+
 def _build_model_and_trainer(model_type):
     if model_type == "mlp":
         model = TabMlp(
@@ -170,6 +268,7 @@ def test_save_and_load_dict(model_type):  # noqa: C901
         "tests/test_self_supervised/model_dir/",
         model_filename="ss_model.pt",
         save_state_dict=True,
+        save_optimizer=False,
     )
 
     model2, trainer2 = _build_model_and_trainer(model_type)
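
Reviewer note: a minimal usage sketch of the new `save_optimizer` flag added to `Trainer.save` in this diff. It mirrors the round trip exercised in tests/test_model_functioning/test_save_optimizer.py; the `df`, `X_wide`, `X_tab`, `wide` and `tab_mlp` objects are assumed to be built exactly as in the module-level setup of that test file, and the `checkpoints` directory name is an illustrative assumption, not part of the diff.

import torch
from pytorch_widedeep import Trainer
from pytorch_widedeep.models import WideDeep
from pytorch_widedeep.metrics import Accuracy

# `wide`, `tab_mlp`, `X_wide`, `X_tab` and `df` as in the new test's setup
model = WideDeep(wide=wide, deeptabular=tab_mlp)

trainer = Trainer(
    model,
    objective="binary",
    optimizer=torch.optim.AdamW(model.parameters(), lr=0.001),
    metrics=[Accuracy()],
)
trainer.fit(X_wide=X_wide, X_tab=X_tab, target=df["target"].values, n_epochs=1)

# save both the model's and the optimizer's state dicts in a single checkpoint
trainer.save(
    path="checkpoints",  # hypothetical directory
    save_state_dict=True,
    save_optimizer=True,
    model_filename="wd_model.pt",
)

# later: restore both to resume training
checkpoint = torch.load("checkpoints/wd_model.pt")
model.load_state_dict(checkpoint["model_state_dict"])
new_optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])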
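
A companion sketch for the `MultipleOptimizer` branch of `_save_model_and_optimizer`: when the trainer is built with a dict of optimizers, the checkpoint's "optimizer_state_dict" entry is itself a dict keyed by component name, as exercised in test_save_multiple_optimizers. The same illustrative assumptions as above apply (`model`, `X_wide`, `X_tab`, `df`, and the `checkpoints` directory).

import torch
from pytorch_widedeep import Trainer
from pytorch_widedeep.metrics import Accuracy

# one optimizer per model component, keyed by component name
wide_opt = torch.optim.AdamW(model.wide.parameters(), lr=0.001)
deep_opt = torch.optim.AdamW(model.deeptabular.parameters(), lr=0.001)

trainer = Trainer(
    model,
    objective="binary",
    optimizers={"wide": wide_opt, "deeptabular": deep_opt},
    metrics=[Accuracy()],
)
trainer.fit(X_wide=X_wide, X_tab=X_tab, target=df["target"].values, n_epochs=1)
trainer.save(
    path="checkpoints",
    save_state_dict=True,
    save_optimizer=True,
    model_filename="wd_model.pt",
)

# the saved "optimizer_state_dict" is a nested dict: one state dict per component
checkpoint = torch.load("checkpoints/wd_model.pt")
new_wide_opt = torch.optim.AdamW(model.wide.parameters(), lr=0.001)
new_deep_opt = torch.optim.AdamW(model.deeptabular.parameters(), lr=0.001)
new_wide_opt.load_state_dict(checkpoint["optimizer_state_dict"]["wide"])
new_deep_opt.load_state_dict(checkpoint["optimizer_state_dict"]["deeptabular"])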