Skip to content

Commit

Permalink
[refactor] Update the split functions to be able to call function dir…
Browse files Browse the repository at this point in the history
…ectly
  • Loading branch information
nabenabe0928 committed May 10, 2021
1 parent 3191642 commit d313a48
Showing 1 changed file with 159 additions and 171 deletions.
330 changes: 159 additions & 171 deletions autoPyTorch/datasets/resampling_strategy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from enum import IntEnum
from typing import Any, Dict, List, Optional, Tuple, Union
from enum import Enum
from functools import partial
from typing import List, NamedTuple, Optional, Tuple, Union

import numpy as np

Expand All @@ -12,187 +13,69 @@
train_test_split
)

from typing_extensions import Protocol
from torch.utils.data import Dataset


# Use callback protocol as workaround, since callable with function fields count 'self' as argument
class CrossValFunc(Protocol):
def __call__(self,
random_state: np.random.RandomState,
num_splits: int,
indices: np.ndarray,
stratify: Optional[Any]) -> List[Tuple[np.ndarray, np.ndarray]]:
...
class _ResamplingStrategyArgs(NamedTuple):
val_share: float = 0.33
num_splits: int = 5
shuffle: bool = False
stratify: bool = False


class HoldOutFunc(Protocol):
def __call__(self, random_state: np.random.RandomState, val_share: float,
indices: np.ndarray, stratify: Optional[Any]
) -> Tuple[np.ndarray, np.ndarray]:
...


class CrossValTypes(IntEnum):
"""The type of cross validation
This class is used to specify the cross validation function
and is not supposed to be instantiated.
Examples: This class is supposed to be used as follows
>>> cv_type = CrossValTypes.k_fold_cross_validation
>>> print(cv_type.name)
k_fold_cross_validation
>>> for cross_val_type in CrossValTypes:
print(cross_val_type.name, cross_val_type.value)
stratified_k_fold_cross_validation 1
k_fold_cross_validation 2
stratified_shuffle_split_cross_validation 3
shuffle_split_cross_validation 4
time_series_cross_validation 5
"""
stratified_k_fold_cross_validation = 1
k_fold_cross_validation = 2
stratified_shuffle_split_cross_validation = 3
shuffle_split_cross_validation = 4
time_series_cross_validation = 5

def is_stratified(self) -> bool:
stratified = [self.stratified_k_fold_cross_validation,
self.stratified_shuffle_split_cross_validation]
return getattr(self, self.name) in stratified


class HoldoutValTypes(IntEnum):
"""TODO: change to enum using functools.partial"""
"""The type of hold out validation (refer to CrossValTypes' doc-string)"""
holdout_validation = 6
stratified_holdout_validation = 7

def is_stratified(self) -> bool:
stratified = [self.stratified_holdout_validation]
return getattr(self, self.name) in stratified


# TODO: replace it with another way
RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes]

DEFAULT_RESAMPLING_PARAMETERS = {
HoldoutValTypes.holdout_validation: {
'val_share': 0.33,
},
HoldoutValTypes.stratified_holdout_validation: {
'val_share': 0.33,
},
CrossValTypes.k_fold_cross_validation: {
'num_splits': 5,
},
CrossValTypes.stratified_k_fold_cross_validation: {
'num_splits': 5,
},
CrossValTypes.shuffle_split_cross_validation: {
'num_splits': 5,
},
CrossValTypes.time_series_cross_validation: {
'num_splits': 5,
},
} # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]


class HoldOutFuncs():
class HoldoutFuncs():
@staticmethod
def holdout_validation(random_state: np.random.RandomState,
val_share: float,
indices: np.ndarray,
**kwargs: Any
) -> Tuple[np.ndarray, np.ndarray]:
shuffle = kwargs.get('shuffle', True)
train, val = train_test_split(indices, test_size=val_share,
shuffle=shuffle,
random_state=random_state if shuffle else None,
)
def holdout_validation(
random_state: np.random.RandomState,
val_share: float,
indices: np.ndarray,
shuffle: bool = False,
labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
):

train, val = train_test_split(
indices, test_size=val_share, shuffle=shuffle,
random_state=random_state if shuffle else None,
stratify=labels_to_stratify
)
return train, val

@staticmethod
def stratified_holdout_validation(random_state: np.random.RandomState,
val_share: float,
indices: np.ndarray,
**kwargs: Any
) -> Tuple[np.ndarray, np.ndarray]:
train, val = train_test_split(indices, test_size=val_share, shuffle=True, stratify=kwargs["stratify"],
random_state=random_state)
return train, val

@classmethod
def get_holdout_validators(cls, *holdout_val_types: HoldoutValTypes) -> Dict[str, HoldOutFunc]:

holdout_validators = {
holdout_val_type.name: getattr(cls, holdout_val_type.name)
for holdout_val_type in holdout_val_types
}
return holdout_validators


class CrossValFuncs():
@staticmethod
def shuffle_split_cross_validation(random_state: np.random.RandomState,
num_splits: int,
indices: np.ndarray,
**kwargs: Any
) -> List[Tuple[np.ndarray, np.ndarray]]:
cv = ShuffleSplit(n_splits=num_splits, random_state=random_state)
splits = list(cv.split(indices))
return splits

@staticmethod
def stratified_shuffle_split_cross_validation(random_state: np.random.RandomState,
num_splits: int,
indices: np.ndarray,
**kwargs: Any
) -> List[Tuple[np.ndarray, np.ndarray]]:
cv = StratifiedShuffleSplit(n_splits=num_splits, random_state=random_state)
splits = list(cv.split(indices, kwargs["stratify"]))
return splits

@staticmethod
def stratified_k_fold_cross_validation(random_state: np.random.RandomState,
num_splits: int,
indices: np.ndarray,
**kwargs: Any
) -> List[Tuple[np.ndarray, np.ndarray]]:
cv = StratifiedKFold(n_splits=num_splits, random_state=random_state)
splits = list(cv.split(indices, kwargs["stratify"]))
return splits
# (shuffle, is_stratify) -> split_fn
_args2split_fn = {
(True, True): StratifiedShuffleSplit,
(True, False): ShuffleSplit,
(False, True): StratifiedKFold,
(False, False): KFold,
}

@staticmethod
def k_fold_cross_validation(random_state: np.random.RandomState,
num_splits: int,
indices: np.ndarray,
**kwargs: Any
) -> List[Tuple[np.ndarray, np.ndarray]]:
def k_fold_cross_validation(
random_state: np.random.RandomState,
num_splits: int,
indices: np.ndarray,
shuffle: bool = False,
labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
) -> List[Tuple[np.ndarray, np.ndarray]]:
"""
Standard k fold cross validation.
Args:
indices (np.ndarray): array of indices to be split
num_splits (int): number of cross validation splits
Returns:
splits (List[Tuple[List, List]]): list of tuples of training and validation indices
"""
shuffle = kwargs.get('shuffle', True)
cv = KFold(n_splits=num_splits, random_state=random_state if shuffle else None, shuffle=shuffle)

split_fn = CrossValFuncs._args2split_fn[(shuffle, labels_to_stratify is not None)]
cv = split_fn(n_splits=num_splits, random_state=random_state)
splits = list(cv.split(indices))
return splits

@staticmethod
def time_series_cross_validation(random_state: np.random.RandomState,
num_splits: int,
indices: np.ndarray,
**kwargs: Any
) -> List[Tuple[np.ndarray, np.ndarray]]:
def time_series(
random_state: np.random.RandomState,
num_splits: int,
indices: np.ndarray,
shuffle: bool = False,
labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
) -> List[Tuple[np.ndarray, np.ndarray]]:
"""
Returns train and validation indices respecting the temporal ordering of the data.
Expand All @@ -215,10 +98,115 @@ def time_series_cross_validation(random_state: np.random.RandomState,
splits = list(cv.split(indices))
return splits

@classmethod
def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, CrossValFunc]:
cross_validators = {
cross_val_type.name: getattr(cls, cross_val_type.name)
for cross_val_type in cross_val_types
}
return cross_validators

class CrossValTypes(Enum):
"""The type of cross validation
This class is used to specify the cross validation function
and is not supposed to be instantiated.
Examples: This class is supposed to be used as follows
>>> cv_type = CrossValTypes.k_fold_cross_validation
>>> print(cv_type.name)
k_fold_cross_validation
>>> for cross_val_type in CrossValTypes:
print(cross_val_type.name, cross_val_type.value)
k_fold_cross_validation functools.partial(<function CrossValFuncs.k_fold_cross_validation at ...>)
time_series <function CrossValFuncs.time_series>
"""
k_fold_cross_validation = partial(CrossValFuncs.k_fold_cross_validation)
time_series = partial(CrossValFuncs.time_series)

def __call__(
self,
random_state: np.random.RandomState,
indices: np.ndarray,
num_splits: int = 5,
shuffle: bool = False,
labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
) -> List[Tuple[np.ndarray, np.ndarray]]:
"""
This function allows to call and type-check the specified function.
Args:
random_state (np.random.RandomState): random number genetor for the reproducibility
num_splits (int): The number of splits in cross validation
indices (np.ndarray): The indices of data points in a dataset
shuffle (bool): If shuffle the indices or not
labels_to_stratify (Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]]):
The labels of the corresponding data points. It is used for the stratification.
Returns:
splits (List[Tuple[np.ndarray, np.ndarray]]):
splits[a split identifier][0: train, 1: val][a data point identifier]
"""
return self.value(
random_state=random_state,
num_splits=num_splits,
indices=indices,
shuffle=shuffle,
labels_to_stratify=labels_to_stratify
)


class HoldoutValTypes(Enum):
"""The type of holdout validation
This class is used to specify the holdout validation function
and is not supposed to be instantiated.
Examples: This class is supposed to be used as follows
>>> holdout_type = HoldoutValTypes.holdout_validation
>>> print(holdout_type.name)
holdout_validation
>>> print(holdout_type.value)
functools.partial(<function HoldoutValTypes.holdout_validation at ...>)
>>> for holdout_type in HoldoutValTypes:
print(holdout_type.name)
holdout_validation
Additionally, HoldoutValTypes.<function> can be called directly.
"""

holdout = partial(HoldoutFuncs.holdout_validation)

def __call__(
self,
random_state: np.random.RandomState,
indices: np.ndarray,
val_share: float = 0.33,
shuffle: bool = False,
labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
) -> List[Tuple[np.ndarray, np.ndarray]]:
"""
This function allows to call and type-check the specified function.
Args:
random_state (np.random.RandomState): random number genetor for the reproducibility
val_share (float): The ratio of validation dataset vs the given dataset
indices (np.ndarray): The indices of data points in a dataset
shuffle (bool): If shuffle the indices or not
labels_to_stratify (Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]]):
The labels of the corresponding data points. It is used for the stratification.
Returns:
splits (List[Tuple[np.ndarray, np.ndarray]]):
splits[a split identifier][0: train, 1: val][a data point identifier]
"""
return self.value(
random_state=random_state,
val_share=val_share,
indices=indices,
shuffle=shuffle,
labels_to_stratify=labels_to_stratify
)

0 comments on commit d313a48

Please sign in to comment.