Updates for OneHotEncoder in newer SKL versions #696

Draft
wants to merge 12 commits into main
@@ -22,11 +22,13 @@ class OneHotEncoderString(PhysicalOperator, torch.nn.Module):
Because we are dealing with tensors, strings require additional length information for processing.
"""

def __init__(self, logical_operator, categories, device, extra_config={}):
def __init__(self, logical_operator, categories, device, extra_config={}, handle_unknown='error', infrequent=None):
super(OneHotEncoderString, self).__init__(logical_operator, transformer=True)

self.num_columns = len(categories)
self.max_word_length = max([max([len(c) for c in cat]) for cat in categories])
self.handle_unknown = handle_unknown
self.mask = None

# Strings are cast to int32, therefore we need to properly size the tensor to be divisible by 4.
while self.max_word_length % 4 != 0:
@@ -55,17 +57,55 @@ def __init__(self, logical_operator, categories, device, extra_config={}):
self.condition_tensors = torch.nn.Parameter(torch.IntTensor(condition_tensors), requires_grad=False)
self.categories_idx = categories_idx

if infrequent is not None:
infrequent_tensors = []
categories_idx = [0]
for arr in infrequent:
cats = (
np.array(arr, dtype="|S" + str(self.max_word_length))  # Encode objects into fixed-length byte strings, viewed as int32 below.
.view("int32")
.reshape(-1, self.max_word_length // 4)
.tolist()
)
# We merge the infrequent categories for all columns into a single tensor.
infrequent_tensors.extend(cats)
# Since all categories are merged together, we need to keep track of indexes to retrieve them at inference time.
categories_idx.append(categories_idx[-1] + len(cats))
self.infrequent_tensors = torch.nn.Parameter(torch.IntTensor(infrequent_tensors), requires_grad=False)

# Build a mask with only the frequent categories by filtering out condition rows that appear (row-wise) among the infrequent ones.
mask = []
for i in range(len(self.condition_tensors)):
if not (self.condition_tensors[i] == self.infrequent_tensors).all(dim=1).any():
mask.append(self.condition_tensors[i])
self.mask = torch.nn.Parameter(torch.stack(mask), requires_grad=False)
else:
self.infrequent_tensors = None

def forward(self, x):
encoded_tensors = []

# TODO: implement 'error' case separately
if self.handle_unknown == "ignore" or self.handle_unknown == "error":
compare_tensors = self.condition_tensors
elif self.handle_unknown == "infrequent_if_exist":
compare_tensors = self.mask if self.mask is not None else self.condition_tensors
else:
raise RuntimeError("Unsupported handle_unknown setting: {0}".format(self.handle_unknown))

for i in range(self.num_columns):
# First we fetch the condition for the particular column.
conditions = self.condition_tensors[self.categories_idx[i] : self.categories_idx[i + 1], :].view(
conditions = compare_tensors[self.categories_idx[i] : self.categories_idx[i + 1], :].view(
1, -1, self.max_word_length // 4
)
# Unlike the numeric case, where eq is enough, here we need to aggregate per object (dim = 2)
# because objects can span multiple integers. We use prod since all ints must match to produce an encoding of 1.
encoded_tensors.append(torch.prod(torch.eq(x[:, i : i + 1, :], conditions), dim=2))

# If there are infrequent categories, append a column that is the logical "not" of whether any frequent category matched.
if self.infrequent_tensors is not None:
matched = torch.sum(torch.cat(encoded_tensors, dim=1), dim=1, keepdim=True)
encoded_tensors.append(torch.logical_not(matched))

return torch.cat(encoded_tensors, dim=1).float()

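For intuition, here is a minimal standalone sketch (not part of the diff) of the string-comparison trick used above; the category list and word length are made up for illustration, while the |S view and the eq/prod aggregation mirror the operator's logic:

    import numpy as np
    import torch

    categories = ["apple", "banana", "cherry"]  # hypothetical single column
    max_word_length = 8  # longest string, rounded up to a multiple of 4

    # View each fixed-length byte string as a row of int32 values.
    conditions = (
        np.array(categories, dtype="|S" + str(max_word_length))
        .view("int32")
        .reshape(-1, max_word_length // 4)
    )
    conditions = torch.IntTensor(conditions)

    # Encode one input the same way and compare: eq matches int by int,
    # prod over the last dim requires every int of the object to match.
    x = np.array(["banana"], dtype="|S" + str(max_word_length)).view("int32").reshape(1, 1, -1)
    one_hot = torch.prod(torch.eq(torch.IntTensor(x), conditions), dim=2)
    print(one_hot.float())  # tensor([[0., 1., 0.]]) -> "banana" matches only the second category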

@@ -74,19 +114,45 @@ class OneHotEncoder(PhysicalOperator, torch.nn.Module):
Class implementing OneHotEncoder operators for ints in PyTorch.
"""

def __init__(self, logical_operator, categories, device):
def __init__(self, logical_operator, categories, device, handle_unknown='error', infrequent=None):
super(OneHotEncoder, self).__init__(logical_operator, transformer=True)

self.num_columns = len(categories)
self.handle_unknown = handle_unknown
self.mask = None

condition_tensors = []
for arr in categories:
condition_tensors.append(torch.nn.Parameter(torch.LongTensor(arr).detach().clone(), requires_grad=False))
self.condition_tensors = torch.nn.ParameterList(condition_tensors)

if infrequent is not None:
infrequent_tensors = []
for arr in infrequent:
infrequent_tensors.append(torch.nn.Parameter(torch.LongTensor(arr).detach().clone(), requires_grad=False))
self.infrequent_tensors = torch.nn.ParameterList(infrequent_tensors)

# Filter out infrequent categories by creating a per-column mask of the frequent ones.
mask = []
for i in range(len(self.condition_tensors)):
row_mask = []
for j in range(len(self.condition_tensors[i])):
if self.condition_tensors[i][j] not in self.infrequent_tensors[i]:
row_mask.append(self.condition_tensors[i][j])
mask.append(torch.nn.Parameter(torch.tensor(row_mask), requires_grad=False))
self.mask = torch.nn.ParameterList(mask)
else:
self.infrequent_tensors = None

def forward(self, *x):
encoded_tensors = []

if self.handle_unknown == "ignore" or self.handle_unknown == "error": # TODO: error
compare_tensors = self.condition_tensors
elif self.handle_unknown == "infrequent_if_exist":
compare_tensors = self.mask if self.mask is not None else self.condition_tensors
else:
raise RuntimeError("Unsupported handle_unknown setting: {0}".format(self.handle_unknown))

if len(x) > 1:
assert len(x) == self.num_columns

@@ -95,14 +161,20 @@ def forward(self, *x):
if input.dtype != torch.int64:
input = input.long()

encoded_tensors.append(torch.eq(input, self.condition_tensors[i]))
encoded_tensors.append(torch.eq(input, compare_tensors[i]))
else:
# This is already a tensor.
x = x[0]
if x.dtype != torch.int64:
x = x.long()

for i in range(self.num_columns):
encoded_tensors.append(torch.eq(x[:, i : i + 1], self.condition_tensors[i]))
curr_column = torch.eq(x[:, i : i + 1], compare_tensors[i])
encoded_tensors.append(curr_column)

# If there are infrequent categories, append a column that is the logical "not"
# of whether any frequent category of the *current* column matched.
if self.infrequent_tensors is not None:
encoded_tensors.append(torch.logical_not(torch.sum(curr_column, dim=1, keepdim=True)))

return torch.cat(encoded_tensors, dim=1).float()
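As a sanity check on the intended semantics, here is a small hand-rolled sketch (made-up data, not part of the diff) of how the infrequent bucket should behave on the int path: frequent categories keep their own columns, while infrequent and unseen values collapse into a single trailing column:

    import torch

    frequent = torch.LongTensor([2])  # hypothetical column where only category 2 is frequent
    x = torch.LongTensor([[2], [1], [5]])  # 1 is infrequent, 5 was never seen during fit

    curr_column = torch.eq(x, frequent)  # one column per frequent category
    bucket = torch.logical_not(torch.sum(curr_column, dim=1, keepdim=True))
    print(torch.cat([curr_column, bucket], dim=1).float())
    # tensor([[1., 0.],
    #         [0., 1.],
    #         [0., 1.]])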
19 changes: 17 additions & 2 deletions hummingbird/ml/operator_converters/sklearn/one_hot_encoder.py
@@ -29,16 +29,31 @@ def convert_sklearn_one_hot_encoder(operator, device, extra_config):
"""
assert operator is not None, "Cannot convert None operator"

# scikit-learn >= 1.1 with handle_unknown = 'infrequent_if_exist'
if hasattr(operator.raw_operator, "infrequent_categories_"):
infrequent = operator.raw_operator.infrequent_categories_
else:
infrequent = None

# Note: min_frequency and max_categories are only used during fit. Both trigger the creation of the
# "infrequent" categories but are not used again afterwards, so we can ignore them for inference in HB.
# See https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/preprocessing/_encoders.py#L178
# and the comment on line 503 of the same file.

if all(
[
np.array(c).dtype == object or np.array(c).dtype.kind in constants.SUPPORTED_STRING_TYPES
for c in operator.raw_operator.categories_
]
):
categories = [[str(x) for x in c.tolist()] for c in operator.raw_operator.categories_]
return OneHotEncoderString(operator, categories, device, extra_config)
return OneHotEncoderString(operator, categories, device, extra_config=extra_config,
handle_unknown=operator.raw_operator.handle_unknown, infrequent=infrequent)
else:
return OneHotEncoder(operator, operator.raw_operator.categories_, device)
return OneHotEncoder(operator, operator.raw_operator.categories_, device,
handle_unknown=operator.raw_operator.handle_unknown, infrequent=infrequent)


register_converter("SklearnOneHotEncoder", convert_sklearn_one_hot_encoder)
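For reference, a quick sketch of the fitted attributes the converter reads, mirroring the string test below (infrequent_categories_ only exists on scikit-learn >= 1.1 encoders fit with min_frequency or max_categories):

    import numpy as np
    from sklearn.preprocessing import OneHotEncoder

    X = np.array([["a"] * 5 + ["b"] * 2000 + ["c"] * 10 + ["d"] * 3]).T
    enc = OneHotEncoder(handle_unknown="infrequent_if_exist", min_frequency=15).fit(X)

    print(enc.categories_)             # [array(['a', 'b', 'c', 'd'], dtype=object)]
    print(enc.infrequent_categories_)  # [array(['a', 'c', 'd'], dtype=object)]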
97 changes: 96 additions & 1 deletion tests/test_sklearn_one_hot_encoder_converter.py
@@ -4,10 +4,12 @@
import unittest

import numpy as np
import torch
import sklearn
from sklearn.preprocessing import OneHotEncoder
import hummingbird.ml

from packaging.version import Version, parse


class TestSklearnOneHotEncoderConverter(unittest.TestCase):
def test_model_one_hot_encoder_int(self):
@@ -91,6 +93,99 @@ def test_model_one_hot_encoder_ts_string_not_mod4_len(self):

np.testing.assert_allclose(model.transform(data).todense(), pytorch_model.transform(data), rtol=1e-06, atol=1e-06)

@unittest.skipIf(parse(sklearn.__version__) < Version("1.2"), "Skipping test because sklearn version is too old (sparse_output requires >= 1.2).")
def test_infrequent_if_exists_str(self):

# This test is a copy of the test in sklearn.
# https://github.com/scikit-learn/scikit-learn/blob/
# ecb9a70e82d4ee352e2958c555536a395b53d2bd/sklearn/preprocessing/tests/test_encoders.py#L868

X_train = np.array([["a"] * 5 + ["b"] * 2000 + ["c"] * 10 + ["d"] * 3]).T
model = OneHotEncoder(
categories=[["a", "b", "c", "d"]],
handle_unknown="infrequent_if_exist",
sparse_output=False,
min_frequency=15,
).fit(X_train)
np.testing.assert_array_equal(model.infrequent_categories_, [["a", "c", "d"]])

pytorch_model = hummingbird.ml.convert(model, "torch", device="cpu")
self.assertIsNotNone(pytorch_model)

X_test = [["b"], ["a"], ["c"], ["d"], ["e"]]
expected = np.array([[1, 0], [0, 1], [0, 1], [0, 1], [0, 1]])
orig = model.transform(X_test)
np.testing.assert_allclose(expected, orig)

hb = pytorch_model.transform(X_test)

np.testing.assert_allclose(orig, hb, rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(orig.shape, hb.shape, rtol=1e-06, atol=1e-06)

@unittest.skipIf(parse(sklearn.__version__) < Version("1.2"), "Skipping test because sklearn version is too old (sparse_output requires >= 1.2).")
def test_infrequent_if_exists_int(self):

X_train = np.array([[1] * 5 + [2] * 2000 + [3] * 10 + [4] * 3]).T
model = OneHotEncoder(
categories=[[1, 2, 3, 4]],
handle_unknown="infrequent_if_exist",
sparse_output=False,
min_frequency=15,
).fit(X_train)
np.testing.assert_array_equal(model.infrequent_categories_, [[1, 3, 4]])

pytorch_model = hummingbird.ml.convert(model, "torch", device="cpu")
self.assertIsNotNone(pytorch_model)

X_test = [[2], [1], [3], [4], [5]]
expected = np.array([[1, 0], [0, 1], [0, 1], [0, 1], [0, 1]])
orig = model.transform(X_test)
np.testing.assert_allclose(expected, orig)

hb = pytorch_model.transform(X_test)

np.testing.assert_allclose(orig, hb, rtol=1e-06, atol=1e-06)
np.testing.assert_allclose(orig.shape, hb.shape, rtol=1e-06, atol=1e-06)

@unittest.skipIf(parse(sklearn.__version__) < Version("1.2"), "Skipping test because sklearn version is too old (sparse_output requires >= 1.2).")
def test_2d_infrequent(self):

X_train = np.array([[10.0, 1.0]] * 3 + [[14.0, 2.0]] * 2)
ohe = OneHotEncoder(sparse_output=False, handle_unknown="infrequent_if_exist", min_frequency=0.49).fit(X_train)

hb = hummingbird.ml.convert(ohe, "pytorch", device="cpu")

# Quick check on a dataset where all values have been seen during training
np.testing.assert_allclose(ohe.transform(X_train), hb.transform(X_train), rtol=1e-06, atol=1e-06)

# Now check data not seen during training
X_test = np.array([[10.0, 1.0]] * 3 + [[14.0, 3.0]] * 2)
np.testing.assert_allclose(ohe.transform(X_test), hb.transform(X_test), rtol=1e-06, atol=1e-06)

@unittest.skipIf(parse(sklearn.__version__) < Version("1.2"), "Skipping test because sklearn version is too old (sparse_output requires >= 1.2).")
def test_user_provided_example(self):
pass

# from sklearn.impute import SimpleImputer
# from sklearn.pipeline import Pipeline

# X_train = np.array(
#     [[22.0, 1.0, 0.0, 1251.0, 123.0, 124.0, 123.0, 0, 0, 0, 0, 0, 0, 0, 0]] * 10
#     + [[10.0, 1.0, 0.0, 1251.0, 123.0, 124.0, 134.0, 0, 0, 0, 0, 0, 0, 0, 0]] * 2
#     + [[14.0, 1.0, 0.0, 1251.0, 123.0, 124.0, 134.0, 0, 0, 0, 0, 0, 0, 0, 0]] * 3
#     + [[12.0, 2.0, 0.0, 1251.0, 123.0, 124.0, 134.0, 0, 0, 0, 0, 0, 0, 0, 0]] * 1
# )
# pipe = Pipeline(
# [
# ("imputer", SimpleImputer(strategy="most_frequent")),
# ("encoder", OneHotEncoder(sparse_output=False, handle_unknown="infrequent_if_exist", min_frequency=9)),
# ],
# verbose=True,
# ).fit(X_train)

# hb = hummingbird.ml.convert(pipe, "pytorch", device="cpu")

# np.testing.assert_allclose(pipe.transform(X_train), hb.transform(X_train), rtol=1e-06, atol=1e-06)


if __name__ == "__main__":
unittest.main()