Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
R-Palazzo committed Feb 5, 2025
1 parent 1b64e12 commit 791c7d5
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 16 deletions.
6 changes: 3 additions & 3 deletions sdmetrics/single_table/data_augmentation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _transform_preprocess(self, tables):
Args:
tables (dict[str, pandas.DataFrame]):
The tables to transform.
Dict containing `real_training_data`, `synthetic_data` and `real_validation_data`.
"""
tables_result = {}
for table_name, table in tables.items():
Expand All @@ -82,8 +82,8 @@ def _get_best_threshold(self, train_data, train_target):
"""Find the best threshold for the classifier model."""
target_probabilities = self._classifier.predict_proba(train_data)[:, 1]
precision, recall, thresholds = precision_recall_curve(train_target, target_probabilities)
# To assess the precision efficacy, we have to fix the recall, and vice versa
metric = precision if self.metric_name == 'recall' else recall
metric_map = {'precision': precision, 'recall': recall}
metric = metric_map[self._metric_to_fix]
best_threshold = 0.0
valid_idx = np.where(metric >= self.fixed_value)[0]
if valid_idx.size:
Expand Down
46 changes: 34 additions & 12 deletions sdmetrics/single_table/data_augmentation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,64 @@
import pandas as pd


def _validate_parameters(
real_training_data,
synthetic_data,
real_validation_data,
metadata,
prediction_column_name,
classifier,
fixed_recall_value,
):
"""Validate the parameters of the Data Augmentation metrics."""
def _validate_tables(real_training_data, synthetic_data, real_validation_data):
    """Check that the three input tables are all pandas DataFrames.

    Args:
        real_training_data:
            Object expected to be a ``pandas.DataFrame``.
        synthetic_data:
            Object expected to be a ``pandas.DataFrame``.
        real_validation_data:
            Object expected to be a ``pandas.DataFrame``.

    Raises:
        ValueError:
            If at least one of the three arguments is not a ``pandas.DataFrame``.
    """
    candidates = (real_training_data, synthetic_data, real_validation_data)
    all_dataframes = all(isinstance(candidate, pd.DataFrame) for candidate in candidates)
    if not all_dataframes:
        # A single message covers all three tables; it does not say which one failed.
        raise ValueError(
            '`real_training_data`, `synthetic_data` and `real_validation_data` must be '
            'pandas DataFrames.'
        )


def _validate_metadata(metadata):
    """Check that the metadata is given as a plain dictionary.

    Args:
        metadata:
            Object expected to be a ``dict``.

    Raises:
        TypeError:
            If ``metadata`` is not a ``dict``. The message names the received type.
    """
    if isinstance(metadata, dict):
        return

    received_type = type(metadata).__name__
    raise TypeError(
        f"Expected a dictionary but received a '{received_type}' instead."
        " For SDV metadata objects, please use the 'to_dict' function to convert it"
        ' to a dictionary.'
    )


def _validate_prediction_column_name(prediction_column_name):
    """Check that the prediction column name is a string.

    Args:
        prediction_column_name:
            Object expected to be a ``str``.

    Raises:
        TypeError:
            If ``prediction_column_name`` is not a ``str``.
    """
    if isinstance(prediction_column_name, str):
        return

    raise TypeError('`prediction_column_name` must be a string.')


def _validate_classifier(classifier):
    """Check that the requested classifier is of a valid type and is supported.

    Args:
        classifier:
            Object expected to be the string ``'XGBoost'``.

    Raises:
        TypeError:
            If ``classifier`` is neither ``None`` nor a ``str``.
        ValueError:
            If ``classifier`` is anything other than ``'XGBoost'``. Note that
            ``None`` passes the type check above but is still rejected here.
    """
    has_valid_type = classifier is None or isinstance(classifier, str)
    if not has_valid_type:
        raise TypeError('`classifier` must be a string or None.')

    # Only one classifier name is accepted; every other value (including None) fails.
    if classifier != 'XGBoost':
        raise ValueError('Currently only `XGBoost` is supported as classifier.')


def _validate_fixed_recall_value(fixed_recall_value):
    """Check that the fixed recall value is a number strictly between 0 and 1.

    Args:
        fixed_recall_value:
            Object expected to be an ``int`` or ``float`` in the open interval (0, 1).

    Raises:
        TypeError:
            If ``fixed_recall_value`` is not an ``int`` or ``float``, or if it does
            not lie strictly between 0 and 1. The same exception type is used for
            both the type failure and the range failure.
    """
    is_number = isinstance(fixed_recall_value, (int, float))
    # The range comparison only runs for numeric input thanks to short-circuiting.
    if not (is_number and 0 < fixed_recall_value < 1):
        raise TypeError('`fixed_recall_value` must be a float in the range (0, 1).')


def _validate_parameters(
    real_training_data,
    synthetic_data,
    real_validation_data,
    metadata,
    prediction_column_name,
    classifier,
    fixed_recall_value,
):
    """Run every individual parameter check of the Data Augmentation metrics.

    The checks run in a fixed order: tables, metadata, prediction column name,
    classifier, fixed recall value. The first check that fails raises, and its
    exception propagates to the caller unchanged.

    Args:
        real_training_data:
            Forwarded to ``_validate_tables``.
        synthetic_data:
            Forwarded to ``_validate_tables``.
        real_validation_data:
            Forwarded to ``_validate_tables``.
        metadata:
            Forwarded to ``_validate_metadata``.
        prediction_column_name:
            Forwarded to ``_validate_prediction_column_name``.
        classifier:
            Forwarded to ``_validate_classifier``.
        fixed_recall_value:
            Forwarded to ``_validate_fixed_recall_value``.
    """
    # Each entry pairs a validator with the positional arguments it receives.
    ordered_checks = (
        (_validate_tables, (real_training_data, synthetic_data, real_validation_data)),
        (_validate_metadata, (metadata,)),
        (_validate_prediction_column_name, (prediction_column_name,)),
        (_validate_classifier, (classifier,)),
        (_validate_fixed_recall_value, (fixed_recall_value,)),
    )
    for validator, arguments in ordered_checks:
        validator(*arguments)


def _validate_data_and_metadata(
real_training_data,
synthetic_data,
Expand Down Expand Up @@ -89,10 +110,11 @@ def _validate_data_and_metadata(
synthetic_labels = set(synthetic_data[prediction_column_name].unique())
real_labels = set(real_training_data[prediction_column_name].unique())
if not synthetic_labels.issubset(real_labels):
to_print = "', '".join(sorted(synthetic_labels - real_labels))
raise ValueError(
f'The ``{prediction_column_name}`` column must have the same values in the real '
'and synthetic data. The synthetic data has the following unseen values: '
f'{sorted(synthetic_labels - real_labels)}'
'and synthetic data. The following values are present in the synthetic data and'
f" not the real data: '{to_print}'"
)


Expand Down
4 changes: 3 additions & 1 deletion tests/unit/single_table/data_augmentation/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
import pytest
from sklearn.metrics import precision_score, recall_score
from xgboost import XGBClassifier

from sdmetrics.single_table.data_augmentation.base import BaseDataAugmentationMetric

Expand Down Expand Up @@ -103,7 +104,7 @@ def test__fit(self, real_training_data, metadata):
assert metric.fixed_value == fixed_recall_value
assert metric._metric_method == recall_score
assert metric._classifier_name == classifier
# assert metric._classifier == 'XGBClassifier()'
assert isinstance(metric._classifier, XGBClassifier)

@patch('sdmetrics.single_table.data_augmentation.base.precision_recall_curve')
def test__get_best_threshold(self, mock_precision_recall_curve, real_training_data):
Expand All @@ -120,6 +121,7 @@ def test__get_best_threshold(self, mock_precision_recall_curve, real_training_da
np.array([0.02, 0.15, 0.25, 0.35, 0.42, 0.51, 0.63, 0.77, 0.82, 0.93, 0.97]),
]
metric.metric_name = 'recall'
metric._metric_to_fix = 'precision'
metric.fixed_value = 0.69
train_data = real_training_data[['numerical']]
train_target = real_training_data['target']
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/single_table/data_augmentation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ def test__validate_data_and_metadata():
'the column `target` for the real validation data. The `precision`and `recall`'
' are undefined for this case.'
)
expected_error_synthetic_wrong_label = re.escape(
'The ``target`` column must have the same values in the real and synthetic data. '
'The following values are present in the synthetic data and not the real'
" data: 'wrong_1', 'wrong_2'"
)

# Run and Assert
_validate_data_and_metadata(**inputs)
Expand Down Expand Up @@ -146,6 +151,11 @@ def test__validate_data_and_metadata():
with pytest.raises(ValueError, match=expected_error_missing_minority):
_validate_data_and_metadata(**missing_minority_class_label_validation)

wrong_synthetic_label = deepcopy(inputs)
wrong_synthetic_label['synthetic_data'] = pd.DataFrame({'target': [0, 1, 'wrong_1', 'wrong_2']})
with pytest.raises(ValueError, match=expected_error_synthetic_wrong_label):
_validate_data_and_metadata(**wrong_synthetic_label)


@patch('sdmetrics.single_table.data_augmentation.utils._validate_parameters')
@patch('sdmetrics.single_table.data_augmentation.utils._validate_data_and_metadata')
Expand Down

0 comments on commit 791c7d5

Please sign in to comment.