Skip to content

Commit

Permalink
address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
gsheni committed Jul 8, 2024
1 parent fd50434 commit 4c11961
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(

self._sequence_index = self.metadata.sequence_index
self.context_columns = context_columns or []
self._validate_sequence_key_and_context_columns(self._sequence_key, self.context_columns)
self._validate_sequence_key_and_context_columns()
self._extra_context_columns = {}
self.extended_columns = {}
self.segment_size = segment_size
Expand Down Expand Up @@ -195,7 +195,7 @@ def add_custom_constraint_class(self, class_object, class_name):
"""Error that tells the user custom constraints can't be used in the ``PARSynthesizer``."""
raise SynthesizerInputError('The PARSynthesizer cannot accommodate custom constraints.')

def _validate_sequence_key_and_context_columns(self, sequence_key, context_columns):
def _validate_sequence_key_and_context_columns(self):
"""Check that the sequence key is not present in the context colums.
Args:
Expand All @@ -204,9 +204,9 @@ def _validate_sequence_key_and_context_columns(self, sequence_key, context_colum
context_columns (list[str]):
A list of strings, representing the columns that do not vary in a sequence.
"""
if set(sequence_key).intersection(set(context_columns)):
if set(self._sequence_key).intersection(set(self.context_columns)):
raise SynthesizerInputError(
f'The sequence key {sequence_key} cannot be a context column. '
f'The sequence key {self._sequence_key} cannot be a context column. '
'To proceed, please remove the sequence key from the context_columns parameter.'
)

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def test_par_sequence_index_is_numerical():
assert sample.columns.to_list() == data.columns.to_list()


def test_par_error_on_context_columns():
def test_init_error_sequence_key_in_context():
# Setup
metadata_dict = {
'columns': {
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/sequential/test_par.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,7 @@ def test_load(self, mock_file, cloudpickle_mock):
cloudpickle_mock.load.assert_called_once_with(mock_file.return_value)
assert loaded_instance == synthesizer_mock

def test__par_error_on_context_columns(self):
def test___init___error_sequence_key_in_context(self):
"""Test that the sequence_key is not a context column"""
# Setup
metadata = self.get_metadata(add_sequence_key=True)
Expand Down

0 comments on commit 4c11961

Please sign in to comment.