Skip to content

Commit

Permalink
param ordering with defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
ksaur committed Apr 5, 2023
1 parent 0711d50 commit 866422a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class OneHotEncoderString(PhysicalOperator, torch.nn.Module):
Because we are dealing with tensors, strings require additional length information for processing.
"""

def __init__(self, logical_operator, categories, handle_unknown, device, infrequent=None, extra_config={}):
def __init__(self, logical_operator, categories, device, extra_config={}, handle_unknown='error', infrequent=None):
super(OneHotEncoderString, self).__init__(logical_operator, transformer=True)

self.num_columns = len(categories)
Expand Down Expand Up @@ -113,7 +113,7 @@ class OneHotEncoder(PhysicalOperator, torch.nn.Module):
Class implementing OneHotEncoder operators for ints in PyTorch.
"""

def __init__(self, logical_operator, categories, handle_unknown, device, infrequent=None):
def __init__(self, logical_operator, categories, device, handle_unknown='error', infrequent=None):
super(OneHotEncoder, self).__init__(logical_operator, transformer=True)

self.num_columns = len(categories)
Expand Down
9 changes: 5 additions & 4 deletions hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,12 @@ def convert_sklearn_one_hot_encoder(operator, device, extra_config):
]
):
categories = [[str(x) for x in c.tolist()] for c in operator.raw_operator.categories_]
return OneHotEncoderString(operator, categories, operator.raw_operator.handle_unknown,
device, infrequent, extra_config)
return OneHotEncoderString(operator, categories, device, extra_config=extra_config,
handle_unknown=operator.raw_operator.handle_unknown,
infrequent=infrequent)
else:
return OneHotEncoder(operator, operator.raw_operator.categories_, operator.raw_operator.handle_unknown,
device, infrequent)
return OneHotEncoder(operator, operator.raw_operator.categories_, device,
handle_unknown=operator.raw_operator.handle_unknown, infrequent=infrequent)


register_converter("SklearnOneHotEncoder", convert_sklearn_one_hot_encoder)

0 comments on commit 866422a

Please sign in to comment.