[Refactor] Made CrossValTypes, HoldoutValTypes to have split functions directly #108
base: master
class HoldoutValTypes(Enum):
    """The type of hold out validation (refer to CrossValTypes' doc-string)"""
    holdout_validation = partial(HoldoutValFuncs.holdout_validation)
    stratified_holdout_validation = partial(HoldoutValFuncs.stratified_holdout_validation)
Major change: IntEnum -> Enum, with the members now holding the split functions directly.
    def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any]
                 ) -> Tuple[np.ndarray, np.ndarray]:
        return self.value(val_share=val_share, indices=indices, stratify=stratify)
Now we can call the function directly, e.g. HoldoutValTypes.holdout_validation().
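A minimal sketch of the pattern described above, using only the standard library. The names `_toy_holdout` and `ToyHoldoutTypes` are hypothetical stand-ins for illustration, not the PR's actual code, and the split itself is a deliberately simple unshuffled cut. Python 3.11+ also offers `enum.member()` for storing callables as members; `partial` is used here only to mirror the PR.

```python
from enum import Enum
from functools import partial
from typing import List, Optional, Sequence, Tuple


def _toy_holdout(val_share: float,
                 indices: Sequence[int],
                 stratify: Optional[Sequence[int]] = None
                 ) -> Tuple[List[int], List[int]]:
    """Hypothetical split: the last `val_share` fraction becomes validation."""
    n_val = int(len(indices) * val_share)
    cut = len(indices) - n_val
    return list(indices[:cut]), list(indices[cut:])


class ToyHoldoutTypes(Enum):
    # A plain function assigned in an Enum body is treated as a method,
    # not as a member, so it is wrapped in partial (as the PR does).
    holdout_validation = partial(_toy_holdout)

    def __call__(self, val_share: float, indices: Sequence[int],
                 stratify: Optional[Sequence[int]] = None
                 ) -> Tuple[List[int], List[int]]:
        # Forward the call to the function stored as the member's value.
        return self.value(val_share=val_share, indices=indices, stratify=stratify)


# The enum member can be called like the function it holds.
train, val = ToyHoldoutTypes.holdout_validation(val_share=0.25, indices=list(range(8)))
print(train, val)
```

Because the member is callable, code that selects a strategy by enum value no longer needs a separate lookup table from enum members to split functions.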
    def __call__(self,
                 num_splits: int,
                 indices: np.ndarray,
                 stratify: Optional[Any]) -> List[Tuple[np.ndarray, np.ndarray]]:
        ...


- class HoldOutFunc(Protocol):
+ class HoldoutValFunc(Protocol):
Since we often use holdout_validator, I unified the name.
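For readers unfamiliar with callable protocols, here is a rough sketch of how a `Protocol` with `__call__` describes the expected signature of a split function. Plain integer sequences stand in for `np.ndarray`, and `tail_holdout` and `run_split` are hypothetical helpers introduced only for this illustration.

```python
from typing import List, Optional, Protocol, Sequence, Tuple


class HoldoutValFunc(Protocol):
    """Structural type: any callable with this signature matches."""

    def __call__(self,
                 val_share: float,
                 indices: Sequence[int],
                 stratify: Optional[Sequence[int]]
                 ) -> Tuple[List[int], List[int]]:
        ...


def tail_holdout(val_share: float,
                 indices: Sequence[int],
                 stratify: Optional[Sequence[int]]
                 ) -> Tuple[List[int], List[int]]:
    # Hypothetical split function: the tail of `indices` is held out.
    cut = len(indices) - int(len(indices) * val_share)
    return list(indices[:cut]), list(indices[cut:])


def run_split(split_fn: HoldoutValFunc,
              indices: Sequence[int]) -> Tuple[List[int], List[int]]:
    # A static type checker verifies that `split_fn` matches the protocol.
    return split_fn(val_share=0.5, indices=indices, stratify=None)


protocol_train, protocol_val = run_split(tail_holdout, [0, 1, 2, 3])
print(protocol_train, protocol_val)
```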
62b326e to c7fd2d5
c7fd2d5 to a7e8a7f
1e82b21 to d313a48
Because the previous code defaulted to shuffle = True and shuffled the indices before splitting, the test cases for CV and Holdout no longer matched. More specifically, I could reproduce the original outputs by bringing back the following:
1. Bring back _get_indices in BaseDataset.
2. Make the default value of self.shuffle in BaseDataset True.
3. Pass shuffle = True to KFold instead of using ShuffleSplit.
Note that KFold(shuffle=True) and ShuffleSplit() are not identical: even with the same random_state, they do not produce the same splits.
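The structural difference between the two splitters can be sketched as follows, assuming scikit-learn and NumPy are installed. The dataset size, `n_splits`, `test_size`, and `random_state` values are arbitrary choices for illustration, not the values used in the PR's tests.

```python
import numpy as np
from sklearn.model_selection import KFold, ShuffleSplit

indices = np.arange(10)

# KFold shuffles once, then partitions the data into disjoint validation folds.
kfold = KFold(n_splits=5, shuffle=True, random_state=0)
kfold_val = [val for _, val in kfold.split(indices)]

# ShuffleSplit draws an independent random train/validation split each time,
# so validation sets from different iterations may overlap.
shuffle_split = ShuffleSplit(n_splits=5, test_size=0.2, random_state=0)
shuffle_val = [val for _, val in shuffle_split.split(indices)]

# Every index appears in exactly one KFold validation fold.
kfold_union = np.sort(np.concatenate(kfold_val))

print("KFold validation folds:", [v.tolist() for v in kfold_val])
print("ShuffleSplit validation sets:", [v.tolist() for v in shuffle_val])
```

Since the two classes generate splits by different procedures, sharing a `random_state` does not make their outputs line up, which is consistent with the mismatch in test outputs reported above.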
af8059b to 6ef981d
                 indices: np.ndarray,
                 **kwargs: Any
                 ) -> List[Tuple[np.ndarray, np.ndarray]]:
Additionally, HoldoutValTypes.<function> can be called directly.
Can you add an example of using it directly?
class CrossValFuncs():
    # (shuffle, is_stratify) -> split_fn
Can we also have documentation similar to HoldoutFuncs here?
Hey, thanks a lot for this PR. I have left a few suggestions. Also, could you state the reason for making this PR? What issues were there in the previous implementation? How does this PR solve them?
I made the changes while keeping them as small as possible.