Skip to content

Commit

Permalink
Release
Browse files Browse the repository at this point in the history
  • Loading branch information
cerlymarco committed Aug 7, 2021
1 parent cbaea24 commit f0046b5
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 9 deletions.
24 changes: 19 additions & 5 deletions kerashypetune/kerashypetune.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import random
import inspect
import numpy as np
from copy import deepcopy

from .utils import (ParameterSampler, _check_param, _check_data,
_clear_callbacks, _create_fold)
_clear_callbacks, _create_fold, _is_multioutput)


class _KerasSearch:
Expand Down Expand Up @@ -497,12 +498,25 @@ def _search(self,
"Expected cv as cross-validation object with split method to "
"generate indices to split data into training and test set "
"(like from sklearn.model_selection).")
else:
split_args = inspect.signature(self.cv.split).parameters
split_args = {k: v.default for k, v in split_args.items()}
split_need_y = split_args['y'] is not None

_x = _check_data(x)

if y is not None:
_y = _check_data(y, is_target=True)
if _is_multioutput(y) and split_need_y:
raise ValueError(
"{} not supports multioutput.".format(self.cv))
else:
_y = None

_check_data(x)
if y is not None: _check_data(y)
if sample_weight is not None: _check_data(sample_weight)
if sample_weight is not None:
_ = _check_data(sample_weight)

for fold, (train_id, val_id) in enumerate(self.cv.split(x, y, groups)):
for fold, (train_id, val_id) in enumerate(self.cv.split(_x, _y, groups)):

if self.tuner_verbose > 0:
print("\n{}\n{} Fold {} {}\n{}".format(
Expand Down
39 changes: 36 additions & 3 deletions kerashypetune/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ def _check_param(values):

if isinstance(values, (list, tuple, np.ndarray)):
return list(set(values))

elif hasattr(values, 'rvs'):
return values

else:
return [values]

Expand Down Expand Up @@ -44,7 +46,7 @@ def _create_fold(X, ids):
Returns
-------
array/list or arrays/dict of arrays containing fold data.
Data fold.
"""

if isinstance(X, list):
Expand All @@ -57,31 +59,62 @@ def _create_fold(X, ids):
return X[ids]


def _check_data(X):
def _check_data(X, is_target=False):
"""Data controls for cross validation."""

if isinstance(X, list):
data_len = []
for x in X:
if not isinstance(x, np.ndarray):
raise ValueError(
"Received data in list format. Take care to cast each "
"value of the list to numpy array.")
data_len.append(len(x))

if len(set(data_len)) > 1:
raise ValueError("Data must have the same cardinality. "
"Got {}".format(data_len))

elif isinstance(X, dict):
data_len = []
for x in X.values():
if not isinstance(x, np.ndarray):
raise ValueError(
"Received data in dict format. Take care to cast each "
"value of the dict to numpy array.")
data_len.append(len(x))

if len(set(data_len)) > 1:
raise ValueError("Data must have the same cardinality. "
"Got {}".format(data_len))

elif isinstance(X, np.ndarray):
pass
x = X
data_len = [len(x)]

else:
raise ValueError(
"Data format not appropriate for Keras CV search. "
"Supported types are list, dict or numpy array.")

if not is_target:
x = np.zeros(data_len[0])

return x


def _is_multioutput(y):
"""Check if multioutput task."""

if isinstance(y, list):
return len(y) > 1

elif isinstance(y, dict):
return len(y) > 1

else:
return False


class ParameterSampler(object):
# modified from scikit-learn ParameterSampler
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

HERE = pathlib.Path(__file__).parent

VERSION = '0.1.2'
VERSION = '0.1.3'
PACKAGE_NAME = 'keras-hypetune'
AUTHOR = 'Marco Cerliani'
AUTHOR_EMAIL = '[email protected]'
Expand Down

0 comments on commit f0046b5

Please sign in to comment.