From f330c6f20f585876385d249df3df2a41bc218c72 Mon Sep 17 00:00:00 2001
From: R-Palazzo <116157184+R-Palazzo@users.noreply.github.com>
Date: Fri, 5 Apr 2024 15:26:04 +0100
Subject: [PATCH] Move out sdtype validations from multi-column transformers (#779)

---
Review notes (ignored by `git am`; not part of the commit message):
  * GPSValidator._validate_uniqueness_sdtype also checks the number of columns,
    so two latitude columns plus one longitude column are rejected; the typo in
    its error message ("on longitude") is fixed.
  * test__validate_sdtypes (Address and GPS) use patch.object so the mocked
    validators are restored after each test.
  * The blob ids on the `index` lines still refer to the original commit's
    contents.

 rdt/transformers/_validators.py             | 177 +++++++++++
 tests/unit/transformers/test__validators.py | 315 ++++++++++++++++++++
 2 files changed, 492 insertions(+)
 create mode 100644 rdt/transformers/_validators.py
 create mode 100644 tests/unit/transformers/test__validators.py

diff --git a/rdt/transformers/_validators.py b/rdt/transformers/_validators.py
new file mode 100644
index 00000000..5c5f7fd8
--- /dev/null
+++ b/rdt/transformers/_validators.py
@@ -0,0 +1,177 @@
+"""Validations for multi-column transformers."""
+import importlib
+
+from rdt.errors import TransformerInputError
+
+
+class BaseValidator:
+    """Base validation class.
+
+    The validation classes ensure that the input data is compatible with the transformers
+    and that they can be imported.
+    """
+
+    SUPPORTED_SDTYPES = []
+    VALIDATION_TYPE = None
+
+    @classmethod
+    def _validate_supported_sdtypes(cls, columns_to_sdtypes):
+        """Raise an error listing every column whose sdtype is not in ``SUPPORTED_SDTYPES``."""
+        message = ''
+        for column, sdtype in columns_to_sdtypes.items():
+            if sdtype not in cls.SUPPORTED_SDTYPES:
+                message += f"Column '{column}' has an unsupported sdtype '{sdtype}'.\n"
+
+        if message:
+            message += (
+                f'Please provide a column that is compatible with {cls.VALIDATION_TYPE} data.'
+            )
+            raise TransformerInputError(message)
+
+    @classmethod
+    def validate_sdtypes(cls, columns_to_sdtypes):
+        """Validate the columns to sdtypes mapping.
+
+        This method aims to call all other sdtype validation methods in the class.
+
+        Args:
+            columns_to_sdtypes (dict):
+                Mapping of column names to sdtypes.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def validate_imports(cls):
+        """Check that the transformers can be imported."""
+        raise NotImplementedError
+
+    @classmethod
+    def validate(cls, columns_to_sdtypes):
+        """Validate the input data.
+
+        Args:
+            columns_to_sdtypes (dict):
+                Mapping of column names to sdtypes.
+        """
+        cls.validate_sdtypes(columns_to_sdtypes)
+        cls.validate_imports()
+
+
+class AddressValidator(BaseValidator):
+    """Validation class for Address data."""
+
+    SUPPORTED_SDTYPES = [
+        'country_code', 'administrative_unit', 'city', 'postcode',
+        'street_address', 'secondary_address', 'state', 'state_abbr'
+    ]
+    VALIDATION_TYPE = 'Address'
+
+    @classmethod
+    def _validate_number_columns(cls, columns_to_sdtypes):
+        """Raise an error if more than 7 columns are provided."""
+        if len(columns_to_sdtypes) > 7:
+            raise TransformerInputError(
+                f'{cls.VALIDATION_TYPE} transformers takes up to 7 columns to transform. '
+                'Please provide address data with valid fields.'
+            )
+
+    @staticmethod
+    def _validate_uniqueness_sdtype(columns_to_sdtypes):
+        """Raise an error if two or more columns share the same sdtype."""
+        sdtypes_to_columns = {}
+        for column, sdtype in columns_to_sdtypes.items():
+            if sdtype not in sdtypes_to_columns:
+                sdtypes_to_columns[sdtype] = []
+
+            sdtypes_to_columns[sdtype].append(column)
+
+        duplicate_fields = {
+            value: keys for value, keys in sdtypes_to_columns.items() if len(keys) > 1
+        }
+
+        if duplicate_fields:
+            message = ''
+            for sdtype, columns in duplicate_fields.items():
+                to_print = "', '".join(columns)
+                message += f"Columns '{to_print}' have the same sdtype '{sdtype}'.\n"
+
+            message += 'Your address data cannot have duplicate fields.'
+            raise TransformerInputError(message)
+
+    @classmethod
+    def _validate_administrative_unit(cls, columns_to_sdtypes):
+        """Raise an error if more than one column is 'state' or 'administrative_unit'."""
+        num_column_administrative_unit = sum(
+            1 for itm in columns_to_sdtypes.values() if itm in ['administrative_unit', 'state']
+        )
+        if num_column_administrative_unit > 1:
+            raise TransformerInputError(
+                f"The {cls.__name__} can have up to 1 column with sdtype 'state'"
+                f" or 'administrative_unit'. Please provide address data with valid fields."
+            )
+
+    @classmethod
+    def validate_sdtypes(cls, columns_to_sdtypes):
+        """Validate the columns to sdtypes mapping."""
+        cls._validate_supported_sdtypes(columns_to_sdtypes)
+        cls._validate_number_columns(columns_to_sdtypes)
+        cls._validate_uniqueness_sdtype(columns_to_sdtypes)
+        cls._validate_administrative_unit(columns_to_sdtypes)
+
+    @classmethod
+    def validate_imports(cls):
+        """Check that the address transformers can be imported."""
+        error_message = (
+            'You must have SDV Enterprise with the address add-on to use the address features.'
+        )
+
+        try:
+            address_module = importlib.import_module('rdt.transformers.address')
+        except ModuleNotFoundError:
+            raise ImportError(error_message) from None
+
+        required_classes = ['RandomLocationGenerator', 'RegionalAnonymizer']
+        for class_name in required_classes:
+            if not hasattr(address_module, class_name):
+                raise ImportError(error_message)
+
+
+class GPSValidator(BaseValidator):
+    """Validation class for GPS data."""
+
+    SUPPORTED_SDTYPES = ['latitude', 'longitude']
+    VALIDATION_TYPE = 'GPS'
+
+    @staticmethod
+    def _validate_uniqueness_sdtype(columns_to_sdtypes):
+        """Raise an error unless there are exactly two columns with two distinct sdtypes."""
+        # Counting distinct sdtypes alone would accept e.g. two latitude columns plus one
+        # longitude column, so the number of columns is checked as well.
+        if len(columns_to_sdtypes) != 2 or len(set(columns_to_sdtypes.values())) != 2:
+            raise TransformerInputError(
+                'The GPS columns must have one latitude and one longitude columns sdtypes. '
+                'Please provide GPS data with valid fields.'
+            )
+
+    @classmethod
+    def validate_sdtypes(cls, columns_to_sdtypes):
+        """Validate the columns to sdtypes mapping."""
+        cls._validate_supported_sdtypes(columns_to_sdtypes)
+        cls._validate_uniqueness_sdtype(columns_to_sdtypes)
+
+    @classmethod
+    def validate_imports(cls):
+        """Check that the GPS transformers can be imported."""
+        error_message = (
+            'You must have SDV Enterprise with the gps add-on to use the GPS features.'
+        )
+
+        try:
+            gps_module = importlib.import_module('rdt.transformers.gps')
+        except ModuleNotFoundError:
+            raise ImportError(error_message) from None
+
+        required_classes = ['RandomLocationGenerator', 'MetroAreaAnonymizer', 'GPSNoiser']
+        for class_name in required_classes:
+            if not hasattr(gps_module, class_name):
+                raise ImportError(error_message)
diff --git a/tests/unit/transformers/test__validators.py b/tests/unit/transformers/test__validators.py
new file mode 100644
index 00000000..dc1c2347
--- /dev/null
+++ b/tests/unit/transformers/test__validators.py
@@ -0,0 +1,315 @@
+import re
+from unittest.mock import Mock, patch
+
+import pytest
+
+from rdt.errors import TransformerInputError
+from rdt.transformers._validators import AddressValidator, BaseValidator, GPSValidator
+
+
+class TestBaseValidator:
+
+    @patch('rdt.transformers._validators.BaseValidator.SUPPORTED_SDTYPES', ['numerical'])
+    @patch('rdt.transformers._validators.BaseValidator.VALIDATION_TYPE', 'Base')
+    def test_validate_supported_sdtypes(self):
+        """Test ``_validate_supported_sdtypes`` method."""
+        # Setup
+        columns_to_sdtypes_valid = {
+            'col1': 'numerical',
+            'col2': 'numerical',
+        }
+        columns_to_sdtypes_invalid = {
+            'col1': 'numerical',
+            'col2': 'categorical',
+            'col3': 'categorical',
+        }
+
+        expected_message = re.escape(
+            "Column 'col2' has an unsupported sdtype 'categorical'.\n"
+            "Column 'col3' has an unsupported sdtype 'categorical'.\n"
+            'Please provide a column that is compatible with Base data.'
+        )
+
+        # Run and Assert
+        BaseValidator._validate_supported_sdtypes(columns_to_sdtypes_valid)
+        with pytest.raises(TransformerInputError, match=expected_message):
+            BaseValidator._validate_supported_sdtypes(columns_to_sdtypes_invalid)
+
+    def test_validate_sdtypes(self):
+        """Test ``validate_sdtypes`` method."""
+        # Setup
+        columns_to_sdtypes = {
+            'col1': 'numerical',
+            'col2': 'categorical',
+        }
+
+        # Run and Assert
+        with pytest.raises(NotImplementedError):
+            BaseValidator.validate_sdtypes(columns_to_sdtypes)
+
+    def test_validate_imports(self):
+        """Test ``validate_imports`` method."""
+        # Run and Assert
+        with pytest.raises(NotImplementedError):
+            BaseValidator.validate_imports()
+
+    @patch('rdt.transformers._validators.BaseValidator.validate_sdtypes')
+    @patch('rdt.transformers._validators.BaseValidator.validate_imports')
+    def test_validate(self, mock_validate_imports, mock_validate_sdtypes):
+        """Test ``validate`` method."""
+        # Setup
+        columns_to_sdtypes = {
+            'col1': 'numerical',
+            'col2': 'categorical',
+        }
+
+        # Run
+        BaseValidator.validate(columns_to_sdtypes)
+
+        # Assert
+        mock_validate_sdtypes.assert_called_once_with(columns_to_sdtypes)
+        mock_validate_imports.assert_called_once()
+
+
+class TestAddressValidator:
+    def test__validate_number_columns(self):
+        """Test ``_validate_number_columns`` method."""
+        # Setup
+        columns_to_sdtypes_valid = {
+            'col_1': 'country_code',
+            'col_2': 'administrative_unit',
+        }
+        column_to_sdtypes_invalid = {
+            'col_1': 'country_code',
+            'col_2': 'administrative_unit',
+            'col_3': 'city',
+            'col_4': 'postcode',
+            'col_5': 'street_address',
+            'col_6': 'secondary_address',
+            'col_7': 'country_code',
+            'col_8': 'administrative_unit'
+        }
+
+        # Run and Assert
+        AddressValidator._validate_number_columns(columns_to_sdtypes_valid)
+
+        expected_message = (
+            'Address transformers takes up to 7 columns to transform. Please provide address'
+            ' data with valid fields.'
+        )
+        with pytest.raises(TransformerInputError, match=re.escape(expected_message)):
+            AddressValidator._validate_number_columns(column_to_sdtypes_invalid)
+
+    def test__validate_uniqueness_sdtype(self):
+        """Test ``_validate_uniqueness_sdtype`` method."""
+        # Setup
+        columns_to_sdtypes_valid = {
+            'col_1': 'country_code',
+            'col_2': 'administrative_unit',
+        }
+        columns_to_sdtypes_invalid = {
+            'col_1': 'country_code',
+            'col_2': 'country_code',
+            'col_3': 'city',
+            'col_4': 'city'
+        }
+
+        # Run and Assert
+        AddressValidator._validate_uniqueness_sdtype(columns_to_sdtypes_valid)
+
+        expected_message = re.escape(
+            "Columns 'col_1', 'col_2' have the same sdtype 'country_code'.\n"
+            "Columns 'col_3', 'col_4' have the same sdtype 'city'.\n"
+            'Your address data cannot have duplicate fields.'
+        )
+        with pytest.raises(TransformerInputError, match=expected_message):
+            AddressValidator._validate_uniqueness_sdtype(columns_to_sdtypes_invalid)
+
+    def test__validate_supported_sdtype(self):
+        """Test ``_validate_supported_sdtype`` method."""
+        # Setup
+        columns_to_sdtypes_valid = {
+            'col_1': 'country_code',
+            'col_2': 'administrative_unit',
+        }
+        columns_to_sdtypes_invalid = {
+            'col_1': 'country_code',
+            'col_2': 'numerical',
+            'col_3': 'categorical',
+        }
+
+        # Run and Assert
+        AddressValidator._validate_supported_sdtypes(columns_to_sdtypes_valid)
+
+        expected_message = re.escape(
+            "Column 'col_2' has an unsupported sdtype 'numerical'.\n"
+            "Column 'col_3' has an unsupported sdtype 'categorical'.\n"
+            'Please provide a column that is compatible with Address data.'
+        )
+        with pytest.raises(TransformerInputError, match=expected_message):
+            AddressValidator._validate_supported_sdtypes(columns_to_sdtypes_invalid)
+
+    def test__validate_administrative_unit(self):
+        """Test ``_validate_administrative_unit`` method."""
+        # Setup
+        columns_to_sdtypes_valid = {
+            'col_1': 'country_code',
+            'col_2': 'administrative_unit',
+        }
+        columns_to_sdtypes_invalid = {
+            'col_1': 'administrative_unit',
+            'col_2': 'state'
+        }
+
+        # Run and Assert
+        AddressValidator._validate_administrative_unit(columns_to_sdtypes_valid)
+
+        expected_message = (
+            "The AddressValidator can have up to 1 column with sdtype 'state'"
+            " or 'administrative_unit'. Please provide address data with valid fields."
+        )
+        with pytest.raises(TransformerInputError, match=re.escape(expected_message)):
+            AddressValidator._validate_administrative_unit(columns_to_sdtypes_invalid)
+
+    @patch.object(AddressValidator, '_validate_administrative_unit')
+    @patch.object(AddressValidator, '_validate_supported_sdtypes')
+    @patch.object(AddressValidator, '_validate_uniqueness_sdtype')
+    @patch.object(AddressValidator, '_validate_number_columns')
+    def test__validate_sdtypes(self, mock_number, mock_uniqueness, mock_supported, mock_admin):
+        """Test ``validate_sdtypes`` method."""
+        # Setup
+        columns_to_sdtypes = {
+            'country': 'country_code',
+            'region': 'administrative_unit',
+        }
+
+        # Run
+        AddressValidator.validate_sdtypes(columns_to_sdtypes)
+
+        # Assert
+        mock_number.assert_called_once_with(columns_to_sdtypes)
+        mock_uniqueness.assert_called_once_with(columns_to_sdtypes)
+        mock_supported.assert_called_once_with(columns_to_sdtypes)
+        mock_admin.assert_called_once_with(columns_to_sdtypes)
+
+    def test__validate_imports_without_address_module(self):
+        """Test ``validate_imports`` when address module doesn't exist."""
+        # Run and Assert
+        expected_message = (
+            'You must have SDV Enterprise with the address add-on to use the address features'
+        )
+        with pytest.raises(ImportError, match=expected_message):
+            AddressValidator.validate_imports()
+
+    @patch('rdt.transformers._validators.importlib.import_module')
+    def test__validate_imports_without_premium_features(self, mock_import_module):
+        """Test ``validate_imports`` when the user doesn't have the transformers."""
+        # Setup
+        mock_address = Mock()
+        del mock_address.RandomLocationGenerator
+        del mock_address.RegionalAnonymizer
+        mock_import_module.return_value = mock_address
+
+        # Run and Assert
+        expected_message = (
+            'You must have SDV Enterprise with the address add-on to use the address features'
+        )
+        with pytest.raises(ImportError, match=expected_message):
+            AddressValidator.validate_imports()
+
+
+class TestGPSValidator:
+    def test__validate_uniqueness_sdtype(self):
+        """Test ``_validate_uniqueness_sdtype`` method."""
+        # Setup
+        columns_to_sdtypes_valid = {
+            'col_1': 'latitude',
+            'col_2': 'longitude',
+        }
+        columns_to_sdtypes_invalid = {
+            'col_1': 'latitude',
+            'col_2': 'latitude',
+        }
+        columns_to_sdtypes_too_many = {
+            'col_1': 'latitude',
+            'col_2': 'latitude',
+            'col_3': 'longitude',
+        }
+
+        # Run and Assert
+        GPSValidator._validate_uniqueness_sdtype(columns_to_sdtypes_valid)
+
+        expected_message = re.escape(
+            'The GPS columns must have one latitude and one longitude columns sdtypes. '
+            'Please provide GPS data with valid fields.'
+        )
+        with pytest.raises(TransformerInputError, match=expected_message):
+            GPSValidator._validate_uniqueness_sdtype(columns_to_sdtypes_invalid)
+
+        with pytest.raises(TransformerInputError, match=expected_message):
+            GPSValidator._validate_uniqueness_sdtype(columns_to_sdtypes_too_many)
+
+    def test__validate_supported_sdtype(self):
+        """Test ``_validate_supported_sdtype`` method."""
+        # Setup
+        columns_to_sdtypes_valid = {
+            'col_1': 'latitude',
+            'col_2': 'longitude',
+        }
+        columns_to_sdtypes_invalid = {
+            'col_1': 'latitude',
+            'col_2': 'postal_code',
+        }
+
+        # Run and Assert
+        GPSValidator._validate_supported_sdtypes(columns_to_sdtypes_valid)
+
+        expected_message = re.escape(
+            "Column 'col_2' has an unsupported sdtype 'postal_code'.\n"
+            'Please provide a column that is compatible with GPS data.'
+        )
+        with pytest.raises(TransformerInputError, match=expected_message):
+            GPSValidator._validate_supported_sdtypes(columns_to_sdtypes_invalid)
+
+    @patch.object(GPSValidator, '_validate_supported_sdtypes')
+    @patch.object(GPSValidator, '_validate_uniqueness_sdtype')
+    def test__validate_sdtypes(self, mock_uniqueness, mock_supported):
+        """Test ``validate_sdtypes`` method."""
+        # Setup
+        columns_to_sdtypes = {
+            'latitude_column': 'latitude',
+            'longitude_column': 'longitude',
+        }
+
+        # Run
+        GPSValidator.validate_sdtypes(columns_to_sdtypes)
+
+        # Assert
+        mock_uniqueness.assert_called_once_with(columns_to_sdtypes)
+        mock_supported.assert_called_once_with(columns_to_sdtypes)
+
+    def test_validate_import_gps_transformers_without_gps_module(self):
+        """Test ``validate_imports`` when gps module doesn't exist."""
+        # Run and Assert
+        expected_message = (
+            'You must have SDV Enterprise with the gps add-on to use the GPS features'
+        )
+        with pytest.raises(ImportError, match=expected_message):
+            GPSValidator.validate_imports()
+
+    @patch('rdt.transformers._validators.importlib.import_module')
+    def test_validate_import_gps_transformers_without_premium_features(self, mock_import_module):
+        """Test ``validate_imports`` when the user doesn't have the transformers."""
+        # Setup
+        mock_gps = Mock()
+        del mock_gps.RandomLocationGenerator
+        del mock_gps.MetroAreaAnonymizer
+        del mock_gps.GPSNoiser
+        mock_import_module.return_value = mock_gps
+
+        # Run and Assert
+        expected_message = (
+            'You must have SDV Enterprise with the gps add-on to use the GPS features'
+        )
+        with pytest.raises(ImportError, match=expected_message):
+            GPSValidator.validate_imports()