-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add LogScaler transformer #932
Closed
Closed
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
f1dbfa5
LogScaler class and some tests
rwedge 1f7541e
update LogScaler docstring
rwedge d1c054c
test _transform
rwedge 3f1bf3d
add more tests
rwedge 1c5289d
lint
rwedge 57e2a94
set new constant value for LogScaler test
rwedge c7414ec
use int64 minimum instead
rwedge d5bc2ed
add more integration tests
rwedge 8c41306
fix typo
rwedge d63495b
make helper functions for transform and reverse transform
rwedge 3c3b211
update test docstrings
rwedge 2e33bc0
validate constant and invert params
rwedge e13e166
accept int values for constant
rwedge a1b7753
move validation into _log_transform
rwedge 9ab83b6
rename reversed to reverse_values
rwedge File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,7 @@ | |
import pandas as pd | ||
import scipy | ||
|
||
from rdt.errors import TransformerInputError | ||
from rdt.errors import InvalidDataError, TransformerInputError | ||
from rdt.transformers.base import BaseTransformer | ||
from rdt.transformers.null import NullTransformer | ||
from rdt.transformers.utils import learn_rounding_digits | ||
|
@@ -626,3 +626,122 @@ def _reverse_transform(self, data): | |
recovered_data = np.stack([recovered_data, data[:, -1]], axis=1) # noqa: PD013 | ||
|
||
return super()._reverse_transform(recovered_data) | ||
|
||
|
||
class LogScaler(FloatFormatter): | ||
"""Transformer for numerical data using log. | ||
|
||
This transformer scales numerical values using log and an optional constant. | ||
|
||
Null values are replaced using a ``NullTransformer``. | ||
|
||
Args: | ||
missing_value_replacement (object): | ||
Indicate what to replace the null values with. If an integer or float is given, | ||
replace them with the given value. If the strings ``'mean'`` or ``'mode'`` | ||
are given, replace them with the corresponding aggregation and if ``'random'`` | ||
replace each null value with a random value in the data range. Defaults to ``mean``. | ||
missing_value_generation (str or None): | ||
The way missing values are being handled. There are three strategies: | ||
|
||
* ``random``: Randomly generates missing values based on the percentage of | ||
missing values. | ||
* ``from_column``: Creates a binary column that describes whether the original | ||
value was missing. Then use it to recreate missing values. | ||
* ``None``: Do nothing with the missing values on the reverse transform. Simply | ||
pass whatever data we get through. | ||
constant (float): | ||
The constant to set as the 0-value for the log-based transform. Defaults to 0 | ||
(do not modify the 0-value of the data). | ||
invert (bool): | ||
Whether to invert the data with respect to the constant value. If False, do not | ||
invert the data (all values will be greater than the constant value). If True, | ||
invert the data (all the values will be less than the constant value). | ||
Defaults to False. | ||
learn_rounding_scheme (bool): | ||
Whether or not to learn what place to round to based on the data seen during ``fit``. | ||
If ``True``, the data returned by ``reverse_transform`` will be rounded to that place. | ||
Defaults to ``False``. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
missing_value_replacement='mean', | ||
missing_value_generation='random', | ||
constant: float = 0.0, | ||
invert: bool = False, | ||
learn_rounding_scheme: bool = False, | ||
): | ||
if isinstance(constant, (int, float)): | ||
self.constant = constant | ||
else: | ||
raise ValueError('The constant parameter must be a float or int.') | ||
if isinstance(invert, bool): | ||
self.invert = invert | ||
else: | ||
raise ValueError('The invert parameter must be a bool.') | ||
|
||
super().__init__( | ||
missing_value_replacement=missing_value_replacement, | ||
missing_value_generation=missing_value_generation, | ||
learn_rounding_scheme=learn_rounding_scheme, | ||
) | ||
|
||
def _validate_data(self, data: pd.Series): | ||
column_name = self.get_input_column() | ||
if self.invert: | ||
if not all(data < self.constant): | ||
raise InvalidDataError( | ||
f"Unable to apply a log transform to column '{column_name}' due to constant" | ||
' being too small.' | ||
) | ||
else: | ||
if not all(data > self.constant): | ||
raise InvalidDataError( | ||
f"Unable to apply a log transform to column '{column_name}' due to constant" | ||
' being too large.' | ||
) | ||
|
||
def _fit(self, data): | ||
super()._fit(data) | ||
data = super()._transform(data) | ||
|
||
if data.ndim > 1: | ||
self._validate_data(data[:, 0]) | ||
else: | ||
self._validate_data(data) | ||
|
||
def _log_transform(self, data): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You can move `self._validate_data` here as well, no? |
||
self._validate_data(data) | ||
|
||
if self.invert: | ||
return np.log(self.constant - data) | ||
else: | ||
return np.log(data - self.constant) | ||
|
||
def _transform(self, data): | ||
data = super()._transform(data) | ||
|
||
if data.ndim > 1: | ||
data[:, 0] = self._log_transform(data[:, 0]) | ||
else: | ||
data = self._log_transform(data) | ||
|
||
return data | ||
|
||
def _reverse_log(self, data): | ||
if self.invert: | ||
return self.constant - np.exp(data) | ||
else: | ||
return np.exp(data) + self.constant | ||
|
||
def _reverse_transform(self, data): | ||
if not isinstance(data, np.ndarray): | ||
data = data.to_numpy() | ||
|
||
if data.ndim > 1: | ||
data[:, 0] = self._reverse_log(data[:, 0]) | ||
else: | ||
data = self._reverse_log(data) | ||
|
||
return super()._reverse_transform(data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
ClusterBasedNormalizer, | ||
FloatFormatter, | ||
GaussianNormalizer, | ||
LogScaler, | ||
) | ||
|
||
|
||
|
@@ -560,3 +561,61 @@ def test_out_of_bounds_reverse_transform(self): | |
|
||
# Assert | ||
assert isinstance(reverse, pd.DataFrame) | ||
|
||
|
||
class TestLogScaler: | ||
def test_learn_rounding(self): | ||
"""Test that transformer learns rounding scheme from data.""" | ||
# Setup | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Add short docstrings to these tests. |
||
data = pd.DataFrame({'test': [1.0, np.nan, 1.5]}) | ||
transformer = LogScaler( | ||
missing_value_generation=None, | ||
missing_value_replacement='mean', | ||
learn_rounding_scheme=True, | ||
) | ||
expected = pd.DataFrame({'test': [1.0, 1.2, 1.5]}) | ||
|
||
# Run | ||
transformer.fit(data, 'test') | ||
transformed = transformer.transform(data) | ||
reversed_values = transformer.reverse_transform(transformed) | ||
|
||
# Assert | ||
np.testing.assert_array_equal(reversed_values, expected) | ||
|
||
def test_missing_value_generation_from_column(self): | ||
"""Test from_column missing value generation with nans present.""" | ||
# Setup | ||
data = pd.DataFrame({'test': [1.0, np.nan, 1.5]}) | ||
transformer = LogScaler( | ||
missing_value_generation='from_column', | ||
missing_value_replacement='mean', | ||
) | ||
|
||
# Run | ||
transformer.fit(data, 'test') | ||
transformed = transformer.transform(data) | ||
reversed_values = transformer.reverse_transform(transformed) | ||
|
||
# Assert | ||
np.testing.assert_array_equal(reversed_values, data) | ||
|
||
def test_missing_value_generation_random(self): | ||
"""Test random missing_value_generation with nans present.""" | ||
# Setup | ||
data = pd.DataFrame({'test': [1.0, np.nan, 1.5, 1.5]}) | ||
transformer = LogScaler( | ||
missing_value_generation='random', | ||
missing_value_replacement='mode', | ||
invert=True, | ||
constant=3.0, | ||
) | ||
expected = pd.DataFrame({'test': [np.nan, 1.5, 1.5, 1.5]}) | ||
|
||
# Run | ||
transformer.fit(data, 'test') | ||
transformed = transformer.transform(data) | ||
reversed_values = transformer.reverse_transform(transformed) | ||
|
||
# Assert | ||
np.testing.assert_array_equal(reversed_values, expected) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"Default to 0" -> "Defaults to 0".
Also, either add the `` quotation marks around the 0, False, and True values here, or remove them from the other values in the docstring, so that it's consistent.