Skip to content

Commit

Permalink
Fix RelationshipValidity issue
Browse files Browse the repository at this point in the history
  • Loading branch information
fealho committed Dec 6, 2024
1 parent 36e4f90 commit 9ee8ce5
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 7 deletions.
6 changes: 4 additions & 2 deletions sdv/sampling/hierarchical_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def _sample_children(self, table_name, sampled_data, scale=1.0):
self._enforce_table_size(child_name, table_name, scale, sampled_data)

if child_name not in sampled_data: # Sample based on only 1 parent
for _, row in sampled_data[table_name].iterrows():
for _, row in sampled_data[table_name].astype(object).iterrows():
self._add_child_rows(
child_name=child_name,
parent_name=table_name,
Expand All @@ -219,7 +219,9 @@ def _sample_children(self, table_name, sampled_data, scale=1.0):

if child_name not in sampled_data: # No child rows sampled, force row creation
num_rows_key = f'__{child_name}__{foreign_key}__num_rows'
max_num_child_index = sampled_data[table_name][num_rows_key].idxmax()
max_num_child_index = pd.to_numeric(
sampled_data[table_name][num_rows_key], errors='coerce'
).idxmax()
parent_row = sampled_data[table_name].iloc[max_num_child_index]

self._add_child_rows(
Expand Down
73 changes: 73 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1480,6 +1480,79 @@ def test_sampling_with_unknown_sdtype_numerical_column(self):
assert all(pd.api.types.is_numeric_dtype(dtype) for dtype in numeric_data)
assert all(dtype == 'object' for dtype in object_data)

def test_large_integer_ids(self):
    """Test that HMASynthesizer can handle large integer IDs correctly GH#919."""
    # Setup: a parent/child table pair whose id columns use a 17-digit
    # regex format, linked by a single relationship on col_1 -> col_A.
    parent_df = pd.DataFrame({
        'col_1': [1, 2, 3],
        'col_3': [7, 8, 9],
        'col_2': [4, 5, 6],
    })
    child_df = pd.DataFrame({
        'col_A': [1, 1, 2],
        'col_B': ['d', 'e', 'f'],
        'col_C': ['g', 'h', 'i'],
    })
    metadata = Metadata.load_from_dict({
        'tables': {
            'table_1': {
                'columns': {
                    'col_1': {'sdtype': 'id', 'regex_format': '[1-9]{17}'},
                    'col_2': {'sdtype': 'numerical'},
                    'col_3': {'sdtype': 'numerical'},
                },
                'primary_key': 'col_1',
            },
            'table_2': {
                'columns': {
                    'col_A': {'sdtype': 'id', 'regex_format': '[1-9]{17}'},
                    'col_B': {'sdtype': 'categorical'},
                    'col_C': {'sdtype': 'categorical'},
                },
            },
        },
        'relationships': [
            {
                'parent_table_name': 'table_1',
                'child_table_name': 'table_2',
                'parent_primary_key': 'col_1',
                'child_foreign_key': 'col_A',
            }
        ],
    })
    data = {'table_1': parent_df, 'table_2': child_df}

    # Run
    synthesizer = HMASynthesizer(metadata, verbose=False)
    synthesizer.fit(data)
    synthetic_data = synthesizer.sample()

    # Assert: every sampled id value matches the regex pattern —
    # exactly 17 characters, all digits.
    for table_name, table in synthetic_data.items():
        for col in table.columns:
            if metadata.tables[table_name].columns[col].get('sdtype') != 'id':
                continue

            values = table[col].astype(str)
            assert all(len(str(v)) == 17 for v in values), (
                f'ID length mismatch in {table_name}.{col}'
            )
            assert all(v.isdigit() for v in values), (
                f'Non-digit characters in {table_name}.{col}'
            )

    # Assert: referential integrity — every child foreign key must
    # point at a sampled parent primary key.
    child_fks = set(synthetic_data['table_2']['col_A'])
    parent_pks = set(synthetic_data['table_1']['col_1'])
    assert child_fks.issubset(parent_pks), 'Foreign key constraint violated'

    # Assert: the diagnostic report scores a perfect 1.0.
    report = DiagnosticReport()
    report.generate(data, synthetic_data, metadata.to_dict(), verbose=False)
    assert report.get_score() == 1.0


@pytest.mark.parametrize('num_rows', [(10), (1000)])
def test_hma_0_1_child(num_rows):
Expand Down
20 changes: 15 additions & 5 deletions tests/unit/sampling/test_hierarchical_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,13 @@ def _add_child_rows(child_name, parent_name, parent_row, sampled_data, num_rows=
call(
child_name='sessions',
parent_name='users',
parent_row=SeriesMatcher(pd.Series({'user_id': 1}, name=0)),
parent_row=SeriesMatcher(pd.Series({'user_id': 1}, name=0, dtype=object)),
sampled_data=result,
),
call(
child_name='sessions',
parent_name='users',
parent_row=SeriesMatcher(pd.Series({'user_id': 3}, name=1)),
parent_row=SeriesMatcher(pd.Series({'user_id': 3}, name=1, dtype=object)),
sampled_data=result,
),
]
Expand Down Expand Up @@ -277,13 +277,20 @@ def _add_child_rows(child_name, parent_name, parent_row, sampled_data, num_rows=
instance._null_foreign_key_percentages = {'__sessions__user_id': 0}

# Run
result = {'users': pd.DataFrame({'user_id': [1], '__sessions__user_id__num_rows': [1]})}
result = {
'users': pd.DataFrame({
'user_id': [1],
'__sessions__user_id__num_rows': pd.Series([1], dtype=object),
})
}
BaseHierarchicalSampler._sample_children(
self=instance, table_name='users', sampled_data=result
)

# Assert
expected_parent_row = pd.Series({'user_id': 1, '__sessions__user_id__num_rows': 1}, name=0)
expected_parent_row = pd.Series(
{'user_id': 1, '__sessions__user_id__num_rows': 1}, name=0, dtype=object
)
expected_calls = [
call(
child_name='sessions',
Expand All @@ -300,7 +307,10 @@ def _add_child_rows(child_name, parent_name, parent_row, sampled_data, num_rows=
),
]
expected_result = {
'users': pd.DataFrame({'user_id': [1], '__sessions__user_id__num_rows': [1]}),
'users': pd.DataFrame({
'user_id': [1],
'__sessions__user_id__num_rows': pd.Series([1], dtype=object),
}),
'sessions': pd.DataFrame({
'user_id': [1],
'session_id': ['a'],
Expand Down

0 comments on commit 9ee8ce5

Please sign in to comment.