Commit: LogScaler class and some tests

rwedge committed Jan 16, 2025
1 parent 0f5e910 commit 53d70fe
Showing 4 changed files with 315 additions and 2 deletions.
2 changes: 2 additions & 0 deletions rdt/transformers/__init__.py
@@ -28,6 +28,7 @@
ClusterBasedNormalizer,
FloatFormatter,
GaussianNormalizer,
LogScaler,
)
from rdt.transformers.pii.anonymizer import (
AnonymizedFaker,
@@ -46,6 +47,7 @@
'FrequencyEncoder',
'GaussianNormalizer',
'LabelEncoder',
'LogScaler',
'NullTransformer',
'OneHotEncoder',
'OptimizedTimestampEncoder',
113 changes: 112 additions & 1 deletion rdt/transformers/numerical.py
@@ -8,7 +8,7 @@
import pandas as pd
import scipy

from rdt.errors import TransformerInputError
from rdt.errors import InvalidDataError, TransformerInputError
from rdt.transformers.base import BaseTransformer
from rdt.transformers.null import NullTransformer
from rdt.transformers.utils import learn_rounding_digits
@@ -626,3 +626,114 @@ def _reverse_transform(self, data):
recovered_data = np.stack([recovered_data, data[:, -1]], axis=1) # noqa: PD013

return super()._reverse_transform(recovered_data)


class LogScaler(FloatFormatter):
"""Transformer for numerical data using log.
This transformer replaces integer values with their float equivalent.
Non null float values are not modified.
Null values are replaced using a ``NullTransformer``.
Args:
missing_value_replacement (object):
Indicate what to replace the null values with. If an integer or float is given,
replace them with the given value. If the strings ``'mean'`` or ``'mode'``
are given, replace them with the corresponding aggregation and if ``'random'``
replace each null value with a random value in the data range. Defaults to ``mean``.
missing_value_generation (str or None):
The way missing values are being handled. There are three strategies:
* ``random``: Randomly generates missing values based on the percentage of
missing values.
* ``from_column``: Creates a binary column that describes whether the original
value was missing. Then use it to recreate missing values.
* ``None``: Do nothing with the missing values on the reverse transform. Simply
pass whatever data we get through.
constant (float):
The constant to set as the 0-value for the log-based transform. Default to 0
(do not modify the 0-value of the data).
invert (bool):
Whether to invert the data with respect to the constant value. If False, do not
invert the data (all values will be greater than the constant value). If True,
invert the data (all the values will be less than the constant value).
Defaults to False.
learn_rounding_scheme (bool):
Whether or not to learn what place to round to based on the data seen during ``fit``.
If ``True``, the data returned by ``reverse_transform`` will be rounded to that place.
Defaults to ``False``.
"""

def __init__(
self,
missing_value_replacement='mean',
missing_value_generation='random',
constant: float = 0,
invert: bool = False,
learn_rounding_scheme: bool = False,
):
self.constant = constant
self.invert = invert
super().__init__(
missing_value_replacement=missing_value_replacement,
missing_value_generation=missing_value_generation,
learn_rounding_scheme=learn_rounding_scheme,
)

def _validate_data(self, data: pd.Series):
column_name = self.get_input_column()
if self.invert:
if not all(data < self.constant):
raise InvalidDataError(
f"Unable to apply a log transform to column '{column_name}' due to constant"
' being too small.'
)
else:
if not all(data > self.constant):
raise InvalidDataError(
f"Unable to apply a log transform to column '{column_name}' due to constant"
' being too large.'
)

def _fit(self, data):
super()._fit(data)
data = super()._transform(data)
if data.ndim > 1:
self._validate_data(data[:, 0])
else:
self._validate_data(data)

def _transform(self, data):
data = super()._transform(data)

if data.ndim > 1:
self._validate_data(data[:, 0])
if self.invert:
data[:, 0] = np.log(self.constant - data[:, 0])
else:
data[:, 0] = np.log(data[:, 0] - self.constant)
else:
self._validate_data(data)
if self.invert:
data = np.log(self.constant - data)
else:
data = np.log(data - self.constant)
return data

def _reverse_transform(self, data):
if not isinstance(data, np.ndarray):
data = data.to_numpy()

if data.ndim > 1:
if self.invert:
data[:, 0] = self.constant - np.exp(data[:, 0])
else:
data[:, 0] = np.exp(data[:, 0]) + self.constant
else:
if self.invert:
data = self.constant - np.exp(data)
else:
data = np.exp(data) + self.constant

return super()._reverse_transform(data)
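
For context, here is a minimal usage sketch of the new transformer (not part of this commit). It only relies on the ``fit``/``transform``/``reverse_transform`` API exercised in the tests below; the column name ``'amount'`` and the sample values are illustrative assumptions.

import pandas as pd

from rdt.transformers.numerical import LogScaler

# Illustrative data: all values are strictly greater than the default constant of 0.
data = pd.DataFrame({'amount': [1.0, 2.0, 10.0]})

# Default settings (constant=0, invert=False): the forward transform is log(x - 0).
scaler = LogScaler()
scaler.fit(data, 'amount')
transformed = scaler.transform(data)               # roughly [0.0, 0.693, 2.303]
recovered = scaler.reverse_transform(transformed)  # exp(y) + 0, recovering the input

# invert=True flips the data around the constant: the transform becomes log(constant - x),
# so every value must lie strictly below the constant or _validate_data raises InvalidDataError.
inverted = LogScaler(constant=20, invert=True)
inverted.fit(data, 'amount')
shifted = inverted.transform(data)                 # log(20 - x) for each value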
1 change: 1 addition & 0 deletions tests/integration/test_transformers.py
@@ -23,6 +23,7 @@
'FloatFormatter': {'missing_value_generation': 'from_column'},
'GaussianNormalizer': {'missing_value_generation': 'from_column'},
'ClusterBasedNormalizer': {'missing_value_generation': 'from_column'},
'LogScaler': {'constant': -40000000000, 'missing_value_generation': 'from_column'},
}

# Mapping of rdt sdtype to dtype
201 changes: 200 additions & 1 deletion tests/unit/transformers/test_numerical.py
@@ -9,12 +9,13 @@
from copulas import univariate
from pandas.api.types import is_float_dtype

from rdt.errors import TransformerInputError
from rdt.errors import InvalidDataError, TransformerInputError
from rdt.transformers.null import NullTransformer
from rdt.transformers.numerical import (
ClusterBasedNormalizer,
FloatFormatter,
GaussianNormalizer,
LogScaler,
)


@@ -1863,3 +1864,201 @@ def test__reverse_transform_missing_value_replacement_missing_value_replacement_
call_data,
rtol=1e-1,
)


class TestLogScaler:
def test___init__super_attrs(self):
"""super() arguments are properly passed and set as attributes."""
ls = LogScaler(
missing_value_generation='random',
learn_rounding_scheme=False,
)

assert ls.missing_value_replacement == 'mean'
assert ls.missing_value_generation == 'random'
assert ls.learn_rounding_scheme is False

def test___init__constant(self):
"""Test constant parameter is set as an attribute."""
# Setup
ls_set = LogScaler(constant=2.5)
ls_default = LogScaler()

# Test
assert ls_set.constant == 2.5
assert ls_default.constant == 0.0

def test___init__invert(self):
"""Test invert parameter is set as an attribute."""
# Setup
ls_set = LogScaler(invert=True)
ls_default = LogScaler()

# Test
assert ls_set.invert
assert not ls_default.invert

def test__validate_data(self):
"""Test the ``_validate_data`` method"""
# Setup
ls = LogScaler()
ls.columns = ['test_col']
valid_data = pd.Series([1, 2, 3])
invalid_data = pd.Series([-1, 2, 4])
message = (
"Unable to apply a log transform to column 'test_col' due to constant being too large."
)
# Run and Assert
ls._validate_data(valid_data)

with pytest.raises(InvalidDataError, match=message):
ls._validate_data(invalid_data)

def test__validate_data_invert(self):
"""Test the ``_validate_data`` method"""
# Setup
ls = LogScaler(invert=True)
ls.columns = ['test']
valid_data = pd.Series([-1, -2, -3])
invalid_data = pd.Series([-1, 2, 4])
message = (
"Unable to apply a log transform to column 'test' due to constant being too small."
)

# Run and Assert
ls._validate_data(valid_data)

with pytest.raises(InvalidDataError, match=message):
ls._validate_data(invalid_data)

@patch('rdt.transformers.LogScaler._validate_data')
def test__fit(self, mock_validate):
"""Test the ``_fit`` method."""
# Setup
data = pd.Series([0.5, np.nan, 1.0])
ls = LogScaler()

# Run
ls._fit(data)

# Assert
mock_validate.assert_called_once()
call_value = mock_validate.call_args_list[0]
np.testing.assert_array_equal(call_value[0][0], np.array([0.5, 0.75, 1.0]))

def test__transform(self):
"""Test the ``_transform`` method."""
# Setup
ls = LogScaler()
ls.fit(pd.DataFrame({'test': [0.25, 0.5, 0.75]}), 'test')
data = pd.DataFrame({'test': [0.1, 1.0, 2.0]})

# Run
transformed_data = ls.transform(data)

# Assert
expected = np.array([-2.30259, 0, 0.69314])
np.testing.assert_allclose(transformed_data['test'], expected, rtol=1e-3)

def test__transform_invert(self):
"""Test the ``_transform`` method with ``invert=True``"""
# Setup
ls = LogScaler(constant=3, invert=True)
ls.fit(pd.DataFrame({'test': [0.25, 0.5, 0.75]}), 'test')
data = pd.DataFrame({'test': [0.1, 1.0, 2.0]})

# Run
transformed_data = ls.transform(data)

# Assert
expected = np.array([1.06471, 0.69315, 0])
np.testing.assert_allclose(transformed_data['test'], expected, rtol=1e-3)

def test__transform_invalid_data(self):
# Setup
ls = LogScaler()
ls.fit(pd.DataFrame({'test': [0.25, 0.5, 0.75]}), 'test')
data = pd.DataFrame({'test': [-0.1, 1.0, 2.0]})
message = (
"Unable to apply a log transform to column 'test' due to constant being too large."
)

# Run and Assert
with pytest.raises(InvalidDataError, match=message):
ls.transform(data)

def test__transform_missing_value_generation_is_random(self):
"""Test the ``_transform`` method.
Validate that ``_transform`` produces the correct values when ``missing_value_generation``
is ``random``.
"""
# Setup
data = pd.Series([1.0, 2.0, 1.0])
ls = LogScaler()
ls.columns = ['test']
ls.null_transformer = NullTransformer('mean', missing_value_generation='random')

# Run
ls.null_transformer.fit(data)
transformed_data = ls._transform(data)

# Assert
expected = np.array([0, 0.69315, 0])
np.testing.assert_allclose(transformed_data, expected, rtol=1e-3)

def test__reverse_transform(self):
"""Test the ``_reverse_transform`` method.
Validate that ``_reverse_transform`` produces the correct values when
``missing_value_generation`` is 'from_column'.
"""
# Setup
data = np.array([
[0, 0.6931471805599453, 0],
[0, 0, 1.0],
]).T
expected = pd.Series([1.0, 2.0, np.nan])
ls = LogScaler()
ls.null_transformer = NullTransformer(
missing_value_replacement='mean',
missing_value_generation='from_column',
)

# Run
ls.null_transformer.fit(expected)
transformed_data = ls._reverse_transform(data)

# Assert
np.testing.assert_allclose(transformed_data, expected, rtol=1e-3)

def test__reverse_transform_missing_value_generation(self):
"""Test the ``_reverse_transform`` method.
Validate that ``_reverse_transform`` produces the correct values when
``missing_value_generation`` is 'random'.
"""
# Setup
data = np.array([0, 0.6931471805599453, 0])
expected = pd.Series([1.0, 2.0, 1.0])
ls = LogScaler()
ls.null_transformer = NullTransformer(None, missing_value_generation='random')

# Run
ls.null_transformer.fit(expected)
transformed_data = ls._reverse_transform(data)

# Assert
np.testing.assert_allclose(transformed_data, expected, rtol=1e-3)

def test_print(self, capsys):
"""Test the class can be printed. GH#883"""
# Setup
transformer = LogScaler()

# Run
print(transformer) # noqa: T201 `print` found

# Assert
captured = capsys.readouterr()
assert captured.out == 'LogScaler()\n'
