From d313a487ef21fdeae5c0264672e57ffe845085ce Mon Sep 17 00:00:00 2001 From: nabenabe0928 Date: Mon, 10 May 2021 12:29:49 +0900 Subject: [PATCH] [refactor] Update the split functions to be able to call function directly --- autoPyTorch/datasets/resampling_strategy.py | 330 ++++++++++---------- 1 file changed, 159 insertions(+), 171 deletions(-) diff --git a/autoPyTorch/datasets/resampling_strategy.py b/autoPyTorch/datasets/resampling_strategy.py index a1e599dd6..f6e6ae570 100644 --- a/autoPyTorch/datasets/resampling_strategy.py +++ b/autoPyTorch/datasets/resampling_strategy.py @@ -1,5 +1,6 @@ -from enum import IntEnum -from typing import Any, Dict, List, Optional, Tuple, Union +from enum import Enum +from functools import partial +from typing import List, NamedTuple, Optional, Tuple, Union import numpy as np @@ -12,187 +13,69 @@ train_test_split ) -from typing_extensions import Protocol +from torch.utils.data import Dataset -# Use callback protocol as workaround, since callable with function fields count 'self' as argument -class CrossValFunc(Protocol): - def __call__(self, - random_state: np.random.RandomState, - num_splits: int, - indices: np.ndarray, - stratify: Optional[Any]) -> List[Tuple[np.ndarray, np.ndarray]]: - ... +class _ResamplingStrategyArgs(NamedTuple): + val_share: float = 0.33 + num_splits: int = 5 + shuffle: bool = False + stratify: bool = False -class HoldOutFunc(Protocol): - def __call__(self, random_state: np.random.RandomState, val_share: float, - indices: np.ndarray, stratify: Optional[Any] - ) -> Tuple[np.ndarray, np.ndarray]: - ... - - -class CrossValTypes(IntEnum): - """The type of cross validation - - This class is used to specify the cross validation function - and is not supposed to be instantiated. - - Examples: This class is supposed to be used as follows - >>> cv_type = CrossValTypes.k_fold_cross_validation - >>> print(cv_type.name) - - k_fold_cross_validation - - >>> for cross_val_type in CrossValTypes: - print(cross_val_type.name, cross_val_type.value) - - stratified_k_fold_cross_validation 1 - k_fold_cross_validation 2 - stratified_shuffle_split_cross_validation 3 - shuffle_split_cross_validation 4 - time_series_cross_validation 5 - """ - stratified_k_fold_cross_validation = 1 - k_fold_cross_validation = 2 - stratified_shuffle_split_cross_validation = 3 - shuffle_split_cross_validation = 4 - time_series_cross_validation = 5 - - def is_stratified(self) -> bool: - stratified = [self.stratified_k_fold_cross_validation, - self.stratified_shuffle_split_cross_validation] - return getattr(self, self.name) in stratified - - -class HoldoutValTypes(IntEnum): - """TODO: change to enum using functools.partial""" - """The type of hold out validation (refer to CrossValTypes' doc-string)""" - holdout_validation = 6 - stratified_holdout_validation = 7 - - def is_stratified(self) -> bool: - stratified = [self.stratified_holdout_validation] - return getattr(self, self.name) in stratified - - -# TODO: replace it with another way -RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes] - -DEFAULT_RESAMPLING_PARAMETERS = { - HoldoutValTypes.holdout_validation: { - 'val_share': 0.33, - }, - HoldoutValTypes.stratified_holdout_validation: { - 'val_share': 0.33, - }, - CrossValTypes.k_fold_cross_validation: { - 'num_splits': 5, - }, - CrossValTypes.stratified_k_fold_cross_validation: { - 'num_splits': 5, - }, - CrossValTypes.shuffle_split_cross_validation: { - 'num_splits': 5, - }, - CrossValTypes.time_series_cross_validation: { - 'num_splits': 5, - }, -} # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]] - - -class HoldOutFuncs(): +class HoldoutFuncs(): @staticmethod - def holdout_validation(random_state: np.random.RandomState, - val_share: float, - indices: np.ndarray, - **kwargs: Any - ) -> Tuple[np.ndarray, np.ndarray]: - shuffle = kwargs.get('shuffle', True) - train, val = train_test_split(indices, test_size=val_share, - shuffle=shuffle, - random_state=random_state if shuffle else None, - ) + def holdout_validation( + random_state: np.random.RandomState, + val_share: float, + indices: np.ndarray, + shuffle: bool = False, + labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None + ): + + train, val = train_test_split( + indices, test_size=val_share, shuffle=shuffle, + random_state=random_state if shuffle else None, + stratify=labels_to_stratify + ) return train, val - @staticmethod - def stratified_holdout_validation(random_state: np.random.RandomState, - val_share: float, - indices: np.ndarray, - **kwargs: Any - ) -> Tuple[np.ndarray, np.ndarray]: - train, val = train_test_split(indices, test_size=val_share, shuffle=True, stratify=kwargs["stratify"], - random_state=random_state) - return train, val - - @classmethod - def get_holdout_validators(cls, *holdout_val_types: HoldoutValTypes) -> Dict[str, HoldOutFunc]: - - holdout_validators = { - holdout_val_type.name: getattr(cls, holdout_val_type.name) - for holdout_val_type in holdout_val_types - } - return holdout_validators - class CrossValFuncs(): - @staticmethod - def shuffle_split_cross_validation(random_state: np.random.RandomState, - num_splits: int, - indices: np.ndarray, - **kwargs: Any - ) -> List[Tuple[np.ndarray, np.ndarray]]: - cv = ShuffleSplit(n_splits=num_splits, random_state=random_state) - splits = list(cv.split(indices)) - return splits - - @staticmethod - def stratified_shuffle_split_cross_validation(random_state: np.random.RandomState, - num_splits: int, - indices: np.ndarray, - **kwargs: Any - ) -> List[Tuple[np.ndarray, np.ndarray]]: - cv = StratifiedShuffleSplit(n_splits=num_splits, random_state=random_state) - splits = list(cv.split(indices, kwargs["stratify"])) - return splits - - @staticmethod - def stratified_k_fold_cross_validation(random_state: np.random.RandomState, - num_splits: int, - indices: np.ndarray, - **kwargs: Any - ) -> List[Tuple[np.ndarray, np.ndarray]]: - cv = StratifiedKFold(n_splits=num_splits, random_state=random_state) - splits = list(cv.split(indices, kwargs["stratify"])) - return splits + # (shuffle, is_stratify) -> split_fn + _args2split_fn = { + (True, True): StratifiedShuffleSplit, + (True, False): ShuffleSplit, + (False, True): StratifiedKFold, + (False, False): KFold, + } @staticmethod - def k_fold_cross_validation(random_state: np.random.RandomState, - num_splits: int, - indices: np.ndarray, - **kwargs: Any - ) -> List[Tuple[np.ndarray, np.ndarray]]: + def k_fold_cross_validation( + random_state: np.random.RandomState, + num_splits: int, + indices: np.ndarray, + shuffle: bool = False, + labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None + ) -> List[Tuple[np.ndarray, np.ndarray]]: """ - Standard k fold cross validation. - - Args: - indices (np.ndarray): array of indices to be split - num_splits (int): number of cross validation splits - Returns: splits (List[Tuple[List, List]]): list of tuples of training and validation indices """ - shuffle = kwargs.get('shuffle', True) - cv = KFold(n_splits=num_splits, random_state=random_state if shuffle else None, shuffle=shuffle) + + split_fn = CrossValFuncs._args2split_fn[(shuffle, labels_to_stratify is not None)] + cv = split_fn(n_splits=num_splits, random_state=random_state) splits = list(cv.split(indices)) return splits @staticmethod - def time_series_cross_validation(random_state: np.random.RandomState, - num_splits: int, - indices: np.ndarray, - **kwargs: Any - ) -> List[Tuple[np.ndarray, np.ndarray]]: + def time_series( + random_state: np.random.RandomState, + num_splits: int, + indices: np.ndarray, + shuffle: bool = False, + labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None + ) -> List[Tuple[np.ndarray, np.ndarray]]: """ Returns train and validation indices respecting the temporal ordering of the data. @@ -215,10 +98,115 @@ def time_series_cross_validation(random_state: np.random.RandomState, splits = list(cv.split(indices)) return splits - @classmethod - def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, CrossValFunc]: - cross_validators = { - cross_val_type.name: getattr(cls, cross_val_type.name) - for cross_val_type in cross_val_types - } - return cross_validators + +class CrossValTypes(Enum): + """The type of cross validation + + This class is used to specify the cross validation function + and is not supposed to be instantiated. + + Examples: This class is supposed to be used as follows + >>> cv_type = CrossValTypes.k_fold_cross_validation + >>> print(cv_type.name) + + k_fold_cross_validation + + >>> for cross_val_type in CrossValTypes: + print(cross_val_type.name, cross_val_type.value) + + k_fold_cross_validation functools.partial() + time_series + """ + k_fold_cross_validation = partial(CrossValFuncs.k_fold_cross_validation) + time_series = partial(CrossValFuncs.time_series) + + def __call__( + self, + random_state: np.random.RandomState, + indices: np.ndarray, + num_splits: int = 5, + shuffle: bool = False, + labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None + ) -> List[Tuple[np.ndarray, np.ndarray]]: + """ + This function allows to call and type-check the specified function. + + Args: + random_state (np.random.RandomState): random number genetor for the reproducibility + num_splits (int): The number of splits in cross validation + indices (np.ndarray): The indices of data points in a dataset + shuffle (bool): If shuffle the indices or not + labels_to_stratify (Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]]): + The labels of the corresponding data points. It is used for the stratification. + + Returns: + splits (List[Tuple[np.ndarray, np.ndarray]]): + splits[a split identifier][0: train, 1: val][a data point identifier] + + """ + return self.value( + random_state=random_state, + num_splits=num_splits, + indices=indices, + shuffle=shuffle, + labels_to_stratify=labels_to_stratify + ) + + +class HoldoutValTypes(Enum): + """The type of holdout validation + + This class is used to specify the holdout validation function + and is not supposed to be instantiated. + + Examples: This class is supposed to be used as follows + >>> holdout_type = HoldoutValTypes.holdout_validation + >>> print(holdout_type.name) + + holdout_validation + + >>> print(holdout_type.value) + + functools.partial() + + >>> for holdout_type in HoldoutValTypes: + print(holdout_type.name) + + holdout_validation + + Additionally, HoldoutValTypes. can be called directly. + """ + + holdout = partial(HoldoutFuncs.holdout_validation) + + def __call__( + self, + random_state: np.random.RandomState, + indices: np.ndarray, + val_share: float = 0.33, + shuffle: bool = False, + labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None + ) -> List[Tuple[np.ndarray, np.ndarray]]: + """ + This function allows to call and type-check the specified function. + + Args: + random_state (np.random.RandomState): random number genetor for the reproducibility + val_share (float): The ratio of validation dataset vs the given dataset + indices (np.ndarray): The indices of data points in a dataset + shuffle (bool): If shuffle the indices or not + labels_to_stratify (Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]]): + The labels of the corresponding data points. It is used for the stratification. + + Returns: + splits (List[Tuple[np.ndarray, np.ndarray]]): + splits[a split identifier][0: train, 1: val][a data point identifier] + + """ + return self.value( + random_state=random_state, + val_share=val_share, + indices=indices, + shuffle=shuffle, + labels_to_stratify=labels_to_stratify + )