diff --git a/sdv/sequential/par.py b/sdv/sequential/par.py index 6fe4d0965..8401f982c 100644 --- a/sdv/sequential/par.py +++ b/sdv/sequential/par.py @@ -110,7 +110,7 @@ def __init__( self._sequence_index = self.metadata.sequence_index self.context_columns = context_columns or [] - self._validate_sequence_key_and_context_columns(self._sequence_key, self.context_columns) + self._validate_sequence_key_and_context_columns() self._extra_context_columns = {} self.extended_columns = {} self.segment_size = segment_size @@ -195,7 +195,7 @@ def add_custom_constraint_class(self, class_object, class_name): """Error that tells the user custom constraints can't be used in the ``PARSynthesizer``.""" raise SynthesizerInputError('The PARSynthesizer cannot accommodate custom constraints.') - def _validate_sequence_key_and_context_columns(self, sequence_key, context_columns): + def _validate_sequence_key_and_context_columns(self): """Check that the sequence key is not present in the context colums. Args: @@ -204,9 +204,9 @@ def _validate_sequence_key_and_context_columns(self, sequence_key, context_colum context_columns (list[str]): A list of strings, representing the columns that do not vary in a sequence. """ - if set(sequence_key).intersection(set(context_columns)): + if set(self._sequence_key).intersection(set(self.context_columns)): raise SynthesizerInputError( - f'The sequence key {sequence_key} cannot be a context column. ' + f'The sequence key {self._sequence_key} cannot be a context column. ' 'To proceed, please remove the sequence key from the context_columns parameter.' ) diff --git a/tests/integration/sequential/test_par.py b/tests/integration/sequential/test_par.py index 5d78f905f..724b91b52 100644 --- a/tests/integration/sequential/test_par.py +++ b/tests/integration/sequential/test_par.py @@ -387,7 +387,7 @@ def test_par_sequence_index_is_numerical(): assert sample.columns.to_list() == data.columns.to_list() -def test_par_error_on_context_columns(): +def test_init_error_sequence_key_in_context(): # Setup metadata_dict = { 'columns': { diff --git a/tests/unit/sequential/test_par.py b/tests/unit/sequential/test_par.py index 4187acd88..3ee048f29 100644 --- a/tests/unit/sequential/test_par.py +++ b/tests/unit/sequential/test_par.py @@ -932,7 +932,7 @@ def test_load(self, mock_file, cloudpickle_mock): cloudpickle_mock.load.assert_called_once_with(mock_file.return_value) assert loaded_instance == synthesizer_mock - def test__par_error_on_context_columns(self): + def test___init___error_sequence_key_in_context(self): """Test that the sequence_key is not a context column""" # Setup metadata = self.get_metadata(add_sequence_key=True)