Skip to content

Commit

Permalink
utils logic
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Jan 17, 2024
1 parent b4000b5 commit d7b5c4b
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 62 deletions.
33 changes: 16 additions & 17 deletions rdt/transformers/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from rdt.errors import TransformerInputError
from rdt.transformers.base import BaseTransformer
from rdt.transformers.utils import fill_nan_with_none
from rdt.transformers.utils import check_nan_in_transform, fill_nan_with_none

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(self, order_by=None):
)

self.order_by = order_by
self.is_integer = None
self._is_integer = False

def _order_categories(self, unique_data):
nans = pd.isna(unique_data)
Expand Down Expand Up @@ -123,7 +123,7 @@ def _fit(self, data):
Data to fit the transformer to.
"""
self.dtype = data.dtypes
self.is_integer = pd.api.types.is_integer_dtype(self.dtype)
self._is_integer = pd.api.types.is_integer_dtype(self.dtype)
data = fill_nan_with_none(data)
labels = pd.unique(data)
labels = self._order_categories(labels)
Expand Down Expand Up @@ -179,11 +179,11 @@ def _reverse_transform(self, data):
Returns:
pandas.Series
"""
convert_to_float = check_nan_in_transform(data, self._is_integer)
data = data.clip(0, 1)
bins = [0]
labels = []
nan_name = 'NaN'
convert_to_float = False
while nan_name in self.intervals.keys():
nan_name += '_'

Expand All @@ -194,19 +194,6 @@ def _reverse_transform(self, data):
else:
labels.append(key)

if pd.isna(data).any():
message = (
'There are null values in the transformed data. The reversed '
'transformed data will contain null values'
)
if self.is_integer:
message += " of type 'float'."
convert_to_float = True
else:
message += '.'

warnings.warn(message)

result = pd.cut(data, bins=bins, labels=labels, include_lowest=True)
result = result.replace(nan_name, np.nan)

Expand Down Expand Up @@ -271,6 +258,7 @@ def _fit(self, data):
Data to fit the transformer to.
"""
self.dtype = data.dtypes
self._is_integer = pd.api.types.is_integer_dtype(self.dtype)
data = fill_nan_with_none(data)
self._check_unknown_categories(data)

Expand Down Expand Up @@ -354,6 +342,7 @@ def __init__(self, add_noise=False):
)
super().__init__()
self.add_noise = add_noise
self._is_integer = False

@staticmethod
def _get_intervals(data):
Expand Down Expand Up @@ -423,6 +412,7 @@ def _fit(self, data):
Data to fit the transformer to.
"""
self.dtype = data.dtype
self._is_integer = pd.api.types.is_integer_dtype(self.dtype)
self.intervals, self.means, self.starts = self._get_intervals(data)

@staticmethod
Expand Down Expand Up @@ -537,6 +527,7 @@ def _reverse_transform(self, data):
Returns:
pandas.Series
"""
check_nan_in_transform(data, self._is_integer)
data = data.clip(0, 1)
num_rows = len(data)
num_categories = len(self.means)
Expand Down Expand Up @@ -732,6 +723,7 @@ def __init__(self, add_noise=False, order_by=None):
)

self.order_by = order_by
self._is_integer = False

def _order_categories(self, unique_data):
if self.order_by == 'alphabetical':
Expand Down Expand Up @@ -766,6 +758,7 @@ def _fit(self, data):
Data to fit the transformer to.
"""
self.dtype = data.dtype
self._is_integer = pd.api.types.is_integer_dtype(self.dtype)
unique_data = pd.unique(data.fillna(np.nan))
unique_data = self._order_categories(unique_data)
self.values_to_categories = dict(enumerate(unique_data))
Expand Down Expand Up @@ -822,12 +815,16 @@ def _reverse_transform(self, data):
Returns:
pandas.Series
"""
convert_to_float = check_nan_in_transform(data, self._is_integer)
if self.add_noise:
data = np.floor(data)

data = data.clip(min(self.values_to_categories), max(self.values_to_categories))
data = data.round().map(self.values_to_categories)

if convert_to_float:
return data.astype(float)

return data.astype(self.dtype)


Expand Down Expand Up @@ -886,6 +883,8 @@ def _fit(self, data):
data (pandas.Series):
Data to fit the transformer to.
"""
self.dtype = data.dtype
self._is_integer = pd.api.types.is_integer_dtype(self.dtype)
data = data.fillna(np.nan)
missing = list(data[~data.isin(self.order)].unique())
if len(missing) > 0:
Expand Down
32 changes: 32 additions & 0 deletions rdt/transformers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import re
import string
import warnings

import numpy as np
import pandas as pd

import sre_parse # isort:skip

Expand Down Expand Up @@ -184,3 +186,33 @@ def flatten_column_list(column_list):
flattened.append(column)

return flattened


def check_nan_in_transform(data, is_integer=False):
    """Check if there are null values in the transformed data.

    When nulls are found a ``UserWarning`` is emitted telling the user that the
    reverse transformed data will contain null values. If the data the transformer
    was fitted on was integer, the warning also states that the nulls will be of
    type ``float`` and the function signals that a float conversion is required.

    Args:
        data (pd.Series or numpy.ndarray):
            Data that has been transformed. Scalars, lists and DataFrames are
            also accepted.
        is_integer (bool):
            Indicates if the initial data was integer. Defaults to ``False``.

    Returns:
        bool:
            Indicates if the transformed data has to be converted to float.
    """
    # ``pd.isna`` returns a plain bool for scalars and a DataFrame mask for
    # DataFrames; reducing through ``np.asarray`` yields a single truth value
    # for every supported input instead of failing on ``.any()``.
    has_nulls = bool(np.asarray(pd.isna(data)).any())
    if not has_nulls:
        return False

    message = (
        'There are null values in the transformed data. The reversed '
        'transformed data will contain null values'
    )
    convert_to_float = False
    if is_integer:
        # Integer dtypes cannot hold NaN, so the caller must fall back to float.
        message += " of type 'float'."
        convert_to_float = True
    else:
        message += '.'

    warnings.warn(message)

    return convert_to_float
100 changes: 56 additions & 44 deletions tests/unit/transformers/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,10 +271,12 @@ def test__transform_user_warning(self):
assert transformed.iloc[4] >= 0
assert transformed.iloc[4] < 1

def test__reverse_transform(self):
@patch('rdt.transformers.categorical.check_nan_in_transform')
def test__reverse_transform(self, mock_check_nan):
"""Test the ``_reverse_transform``."""
# Setup
data = pd.Series([1, 2, 3, 2, 2, 1, 3, 3, 2])
mock_check_nan.return_value = False
transformer = UniformEncoder()
transformer.dtype = np.int64
transformer.frequencies = {
Expand All @@ -295,11 +297,14 @@ def test__reverse_transform(self):

# Asserts
pd.testing.assert_series_equal(output, data)
mock_check_nan.assert_called_once()

def test__reverse_transform_nans(self):
@patch('rdt.transformers.categorical.check_nan_in_transform')
def test__reverse_transform_nans(self, mock_check_nan):
"""Test ``_reverse_transform`` for data with NaNs."""
# Setup
data = pd.Series(['a', 'b', 'NaN', np.nan, 'NaN', 'b', 'b', 'a', 'b', np.nan])
mock_check_nan.return_value = False
transformer = UniformEncoder()
transformer.dtype = object
transformer.frequencies = {
Expand All @@ -323,51 +328,24 @@ def test__reverse_transform_nans(self):
# Asserts
pd.testing.assert_series_equal(output, data)

def test__reverse_transform_nan_in_transform(self):
"""Test ``_reverse_transform`` when there is NaNs in the transform data.
def test__reverse_transform_integer_and_nans(self):
"""Test the ``reverse_transform`` method with integers and nans.
In this situation, a warning should be raised. If the data was integer, it should be
converted to float.
Test that the method correctly reverse transforms the data
when the initial data is integers and the transformed data has nans.
"""
# Setup
object_categories = ['a', 'b', 'c']
integer_categories = [1, 2, 3]
frequencies = [0.2, 0.3, 0.5]
intervals = [[0, 0.2], [0.2, 0.5], [0.5, 1]]

transformer_object = UniformEncoder()
transformer_object.dtype = object

transformer_integer = UniformEncoder()
transformer_integer.dtype = np.int64
transformer_integer.is_integer = True

transformer_object.frequencies = dict(zip(object_categories, frequencies))
transformer_object.intervals = dict(zip(object_categories, intervals))
transformer_integer.frequencies = dict(zip(integer_categories, frequencies))
transformer_integer.intervals = dict(zip(integer_categories, intervals))

transformed = pd.Series([0.1026, 0.1651, np.nan, 0.3116, 0.6546, 0.8541, 0.7041])
transformer = UniformEncoder()
transformer._is_integer = True
transformer.frequencies = {11: 0.2, 12: 0.3, 13: 0.5}
transformer.intervals = {11: [0, 0.2], 12: [0.2, 0.5], 13: [0.5, 1]}
data = pd.Series([0.1, 0.25, np.nan, 0.65])

# Run
expected_message = (
'There are null values in the transformed data. The reversed '
'transformed data will contain null values'
)
expected_message_object = expected_message + '.'
expected_message_integer = expected_message + " of type 'float'."
with pytest.warns(UserWarning, match=expected_message_object):
output_object = transformer_object._reverse_transform(transformed)

with pytest.warns(UserWarning, match=expected_message_integer):
output_integer = transformer_integer._reverse_transform(transformed)

# Asserts
expected_output_object = pd.Series(['a', 'a', np.nan, 'b', 'c', 'c', 'c'])
expected_output_integer = pd.Series([1, 1, np.nan, 2, 3, 3, 3])
out = transformer._reverse_transform(data)

pd.testing.assert_series_equal(output_object, expected_output_object)
pd.testing.assert_series_equal(output_integer, expected_output_integer)
# Assert
pd.testing.assert_series_equal(out, pd.Series([11, 12, np.nan, 13]))


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -764,7 +742,8 @@ def test__get_value_add_noise_true(self, norm_mock):
# Asserts
assert result == 0.2745

def test__reverse_transform_series(self):
@patch('rdt.transformers.categorical.check_nan_in_transform')
def test__reverse_transform_series(self, mock_check_nan):
"""Test reverse_transform a pandas Series"""
# Setup
data = pd.Series(['foo', 'bar', 'bar', 'foo', 'foo', 'tar'])
Expand All @@ -776,6 +755,10 @@ def test__reverse_transform_series(self):
result = transformer._reverse_transform(rt_data)

# Asserts
mock_input_data = mock_check_nan.call_args.args[0]
mock_input_boolean = mock_check_nan.call_args.args[1]
pd.testing.assert_series_equal(mock_input_data, rt_data)
assert mock_input_boolean is False
expected_intervals = {
'foo': (
0,
Expand Down Expand Up @@ -1169,7 +1152,8 @@ def test__reverse_transform_by_row_called(self):
np.testing.assert_array_equal(reverse_arg, data.clip(0, 1))
assert reverse == categorical_transformer_mock._reverse_transform_by_row.return_value

def test__reverse_transform_by_row(self):
@patch('rdt.transformers.categorical.check_nan_in_transform')
def test__reverse_transform_by_row(self, mock_check_nan):
"""Test the _reverse_transform_by_row method with numerical data.
Expect that the transformed data is correctly reverse transformed.
Expand Down Expand Up @@ -1202,6 +1186,10 @@ def test__reverse_transform_by_row(self):
reverse = transformer._reverse_transform(transformed)

# Assert
mock_input_data = mock_check_nan.call_args.args[0]
mock_input_boolean = mock_check_nan.call_args.args[1]
pd.testing.assert_series_equal(mock_input_data, transformed)
assert mock_input_boolean is False
pd.testing.assert_series_equal(data, reverse)


Expand Down Expand Up @@ -2214,7 +2202,8 @@ def test__reverse_transform_clips_values(self):
# Assert
pd.testing.assert_series_equal(out, pd.Series(['a', 'b', 'c']))

def test__reverse_transform_add_noise(self):
@patch('rdt.transformers.categorical.check_nan_in_transform')
def test__reverse_transform_add_noise(self, mock_check_nan):
"""Test the ``_reverse_transform`` method with ``add_noise``.
Test that the method correctly reverse transforms the data
Expand All @@ -2229,12 +2218,35 @@ def test__reverse_transform_add_noise(self):
transformer = LabelEncoder(add_noise=True)
transformer.values_to_categories = {0: 'a', 1: 'b', 2: 'c'}
data = pd.Series([0.5, 1.0, 10.9])
mock_check_nan.return_value = False

# Run
out = transformer._reverse_transform(data)

# Assert
pd.testing.assert_series_equal(out, pd.Series(['a', 'b', 'c']))
mock_input_data = mock_check_nan.call_args.args[0]
mock_input_boolean = mock_check_nan.call_args.args[1]
pd.testing.assert_series_equal(mock_input_data, data)
assert mock_input_boolean is False

def test__reverse_transform_integer_and_nans(self):
"""Test the ``reverse_transform`` method with integers and nans.
Test that the method correctly reverse transforms the data
when the initial data is integers and the transformed data has nans.
"""
# Setup
transformer = LabelEncoder()
transformer._is_integer = True
transformer.values_to_categories = {0: 11, 1: 12, 2: 13}
data = pd.Series([0, 1, np.nan])

# Run
out = transformer._reverse_transform(data)

# Assert
pd.testing.assert_series_equal(out, pd.Series([11, 12, np.nan]))


class TestOrderedLabelEncoder:
Expand Down
Loading

0 comments on commit d7b5c4b

Please sign in to comment.