diff --git a/tests/test_bigbio.py b/tests/test_bigbio.py
index 6a3fb2d3..5e9b1ba7 100644
--- a/tests/test_bigbio.py
+++ b/tests/test_bigbio.py
@@ -1,5 +1,7 @@
 """
 Unit-tests to ensure tasks adhere to big-bio schema.
+
+NOTE: If bypass keys/splits are present, statistics are STILL printed.
 """
 import argparse
 from collections import defaultdict
@@ -89,6 +91,10 @@ class TestDataLoader(unittest.TestCase):
 
     PATH: str
     NAME: str
     DATA_DIR: Optional[str]
+    BYPASS_SPLITS: List[str]
+    BYPASS_KEYS: List[str]
+    BYPASS_SPLIT_KEY_PAIRS: List[str]
+
 
     def runTest(self):
@@ -96,6 +102,7 @@ def runTest(self):
         logger.info(f"self.NAME: {self.NAME}")
         logger.info(f"self.DATA_DIR: {self.DATA_DIR}")
 
+        self._warn_bypass()
         logger.info("importing module .... ")
         module_name = self.PATH
 
@@ -224,17 +231,25 @@ def _assert_ids_globally_unique(
             for elem in collection:
                 self._assert_ids_globally_unique(elem, ids_seen)
 
+
     def test_are_ids_globally_unique(self, dataset_bigbio: DatasetDict):
         """
         Tests each example in a split has a unique ID.
         """
         logger.info("Checking global ID uniqueness")
-        for split in dataset_bigbio.values():
+        for split_name, split in dataset_bigbio.items():
+
+            # Skip entire data split
+            if split_name in self.BYPASS_SPLITS:
+                logger.info(f"\tSkipping unique ID check on {split_name}")
+                continue
+
             ids_seen = set()
             for example in split:
                 self._assert_ids_globally_unique(example, ids_seen=ids_seen)
         logger.info("Found {} unique IDs".format(len(ids_seen)))
 
+
     def _get_referenced_ids(self, example):
 
         referenced_ids = []
@@ -255,6 +270,7 @@ def _get_referenced_ids(self, example):
 
         return referenced_ids
 
+
     def _get_existing_referable_ids(self, example):
 
         existing_ids = []
@@ -272,7 +288,13 @@ def test_do_all_referenced_ids_exist(self, dataset_bigbio: DatasetDict):
         Checks if referenced IDs are correctly labeled.
""" logger.info("Checking if referenced IDs are properly mapped") - for split in dataset_bigbio.values(): + for split_name, split in dataset_bigbio.items(): + + # skip entire split + if split_name in self.BYPASS_SPLITS: + logger.info(f"\tSkipping referenced ids on {split_name}") + continue + for example in split: referenced_ids = set() existing_ids = set() @@ -281,6 +303,12 @@ def test_do_all_referenced_ids_exist(self, dataset_bigbio: DatasetDict): existing_ids.update(self._get_existing_referable_ids(example)) for ref_id, ref_type in referenced_ids: + + if self._skipkey_or_keysplit(ref_type, split_name): + split_keys = (split_name, ref_type) + logger.warning(f"\tSkipping referenced ids on {split_keys}") + continue + if ref_type == "event": if not ( (ref_id, "entity") in existing_ids @@ -307,6 +335,15 @@ def test_passages_offsets(self, dataset_bigbio: DatasetDict): logger.info("KB ONLY: Checking passage offsets") for split in dataset_bigbio: + # skip entire split + if split in self.BYPASS_SPLITS: + logger.info(f"\tSkipping passage offsets on {split}") + continue + + if self._skipkey_or_keysplit("passages", split): + logger.warning(f"Skipping passages offsets for split='{split}'") + continue + if "passages" in dataset_bigbio[split].features: for example in dataset_bigbio[split]: @@ -409,6 +446,15 @@ def test_entities_offsets(self, dataset_bigbio: DatasetDict): for split in dataset_bigbio: + # skip entire split + if split in self.BYPASS_SPLITS: + logger.info(f"\tSkipping entities offsets on {split}") + continue + + if self._skipkey_or_keysplit("entities", split): + logger.warning(f"Skipping entities offsets for split='{split}'") + continue + if "entities" in dataset_bigbio[split].features: for example in dataset_bigbio[split]: @@ -445,6 +491,15 @@ def test_events_offsets(self, dataset_bigbio: DatasetDict): for split in dataset_bigbio: + # skip entire split + if split in self.BYPASS_SPLITS: + logger.info(f"\tSkipping events offsets on {split}") + continue + + if self._skipkey_or_keysplit("events", split): + logger.warning(f"Skipping events offsets for split='{split}'") + continue + if "events" in dataset_bigbio[split].features: for example in dataset_bigbio[split]: @@ -479,6 +534,15 @@ def test_coref_ids(self, dataset_bigbio: DatasetDict): logger.info("KB ONLY: Checking coref offsets") for split in dataset_bigbio: + # skip entire split + if split in self.BYPASS_SPLITS: + logger.info(f"\tSkipping coref ids on {split}") + continue + + if self._skipkey_or_keysplit("coreferences", split): + logger.warning(f"Skipping coreferences ids for split='{split}'") + continue + if "coreferences" in dataset_bigbio[split].features: for example in dataset_bigbio[split]: @@ -499,23 +563,40 @@ def test_multiple_choice(self, dataset_bigbio: DatasetDict): logger.info("QA ONLY: Checking multiple choice") for split in dataset_bigbio: + # skip entire split + if split in self.BYPASS_SPLITS: + logger.info(f"\tSkipping multiple-choice on {split}") + continue + for example in dataset_bigbio[split]: - if len(example["choices"]) > 0: - # can change "==" to "in" if we include ranking later - assert ( - example["type"] == "multiple_choice" - ), f"`choices` is populated, but type is not 'multiple_choice' {example}" + if self._skipkey_or_keysplit("choices", split): + logger.warning("Skipping multiple choice for key=choices, split='{split}'") + continue - if example["type"] == "multiple_choice": - assert ( - len(example["choices"]) > 0 - ), f"type is 'multiple_choice' but no values in 'choices' {example}" + else: + + if 
len(example["choices"]) > 0: + # can change "==" to "in" if we include ranking later + assert ( + example["type"] == "multiple_choice" + ), f"`choices` is populated, but type is not 'multiple_choice' {example}" - for answer in example["answer"]: + if example["type"] == "multiple_choice": assert ( - answer in example["choices"] - ), f"answer is not present in 'choices' {example}" + len(example["choices"]) > 0 + ), f"type is 'multiple_choice' but no values in 'choices' {example}" + + + if self._skipkey_or_keysplit("answer", split): + logger.warning("Skipping multiple choice for key=answer, split='{split}'") + continue + + else: + for answer in example["answer"]: + assert ( + answer in example["choices"] + ), f"answer is not present in 'choices' {example}" def test_entities_multilabel_db_id(self, dataset_bigbio: DatasetDict): @@ -530,13 +611,21 @@ def test_entities_multilabel_db_id(self, dataset_bigbio: DatasetDict): # one warning is enough to prompt a cleaning pass for split in dataset_bigbio: + # skip entire split + if split in self.BYPASS_SPLITS: + logger.info(f"\tSkipping entities multilabel db on {split}") + continue + if warning_raised: break if "entities" not in dataset_bigbio[split].features: - continue + if self._skipkey_or_keysplit("entities", split): + logger.warning(f"Skipping multilabel entities for split='{split}'") + continue + for example in dataset_bigbio[split]: if warning_raised: @@ -547,17 +636,14 @@ def test_entities_multilabel_db_id(self, dataset_bigbio: DatasetDict): for entity in example["entities"]: if warning_raised: - break normalized = entity.get("normalized", []) - entity_id = entity["id"] for norm in normalized: db_id = norm["db_id"] - match = re.search(_CONNECTORS, db_id) if match is not None: @@ -586,8 +672,17 @@ def test_multilabel_type(self, dataset_bigbio: DatasetDict): for split in dataset_bigbio: + # skip entire split + if split in self.BYPASS_SPLITS: + logger.info(f"\tSkipping multilabel type on {split}") + continue + for feature_name in features_with_type: + if self._skipkey_or_keysplit(feature_name, split): + logger.warning(f"Skipping multilabel type for splitkey = '{(split, feature_name)}'") + continue + if ( feature_name not in dataset_bigbio[split].features or warning_raised[feature_name] @@ -597,17 +692,14 @@ def test_multilabel_type(self, dataset_bigbio: DatasetDict): for example in dataset_bigbio[split]: if warning_raised[feature_name]: - break example_id = example["id"] - features = example[feature_name] for feature in features: feature_type = feature["type"] - match = re.search(_CONNECTORS, feature_type) if match is not None: @@ -648,9 +740,22 @@ def test_schema(self, schema: str): print() for split_name, split in self.dataset.items(): + + # Skip entire data split + if split_name in self.BYPASS_SPLITS: + logger.info(f"Skipping schema on {split_name}") + continue + + logger.info("Testing schema for: " + str(split_name)) self.assertEqual(split.info.features, features) + for non_empty_feature in non_empty_features: - if split_to_feature_counts[split_name][non_empty_feature] == 0: + + if self._skipkey_or_keysplit(non_empty_feature, split_name): + logger.warning(f"Skipping schema for split, key = '{(split_name, non_empty_feature)}'") + continue + + if (split_to_feature_counts[split_name][non_empty_feature] == 0): raise AssertionError( f"Required key '{non_empty_feature}' does not have any instances" ) @@ -680,11 +785,37 @@ def _test_has_only_one_item(self, msg: str, field: list): ): self.assertEqual(len(field), 1) + def _warn_bypass(self): + """ Warn 
+
+        if len(self.BYPASS_SPLITS) > 0:
+            logger.warning(f"Splits ignored = '{self.BYPASS_SPLITS}'")
+
+        if len(self.BYPASS_KEYS) > 0:
+            logger.warning(f"Keys ignored = '{self.BYPASS_KEYS}'")
+
+        if len(self.BYPASS_SPLIT_KEY_PAIRS) > 0:
+            logger.warning(
+                f"Split and key pairs ignored = '{self.BYPASS_SPLIT_KEY_PAIRS}'"
+            )
+            self.BYPASS_SPLIT_KEY_PAIRS = [i.split(",") for i in self.BYPASS_SPLIT_KEY_PAIRS]
+
+    def _skipkey_or_keysplit(self, key: str, split: str):
+        """Check whether a key or a (split, key) pair should be omitted."""
+        flag = False
+        if key in self.BYPASS_KEYS:
+            flag = True
+
+        if [split, key] in self.BYPASS_SPLIT_KEY_PAIRS:
+            flag = True
+
+        return flag
 
 
 if __name__ == "__main__":
 
     logging.basicConfig(level=logging.INFO)
 
     parser = argparse.ArgumentParser(description="Unit tests for BigBio dataloaders.")
+
     parser.add_argument(
         "dataloader_path",
         type=str,
@@ -698,9 +829,34 @@
     parser.add_argument(
         "--config_name",
         type=str,
-        help="use to run on a single config name (defualt is to run on all config names)",
+        help="use to run on a single config name (default is to run on all config names)",
+    )
+
+    parser.add_argument(
+        "--bypass_splits",
+        default=[],
+        required=False,
+        nargs="*",
+        help="Skip a data split (e.g. 'train', 'dev') during testing. List all splits space-separated (ex: --bypass_splits train dev)",
+    )
+
+    parser.add_argument(
+        "--bypass_keys",
+        default=[],
+        required=False,
+        nargs="*",
+        help="Skip a required key (e.g. 'entities' for NER) during testing. List all keys space-separated (ex: --bypass_keys entities events)",
     )
+
+    parser.add_argument(
+        "--bypass_split_key_pairs",
+        default=[],
+        required=False,
+        nargs="*",
+        help="Skip a key in a data split (e.g. skip 'entities' in 'test'). List all pairs comma-separated (ex: --bypass_split_key_pairs test,entities train,events)",
+    )
+
 
     args = parser.parse_args()
     logger.info(f"args: {args}")
@@ -718,4 +874,7 @@
     TestDataLoader.PATH = args.dataloader_path
     TestDataLoader.NAME = config_name
     TestDataLoader.DATA_DIR = args.data_dir
+    TestDataLoader.BYPASS_SPLITS = args.bypass_splits
+    TestDataLoader.BYPASS_KEYS = args.bypass_keys
+    TestDataLoader.BYPASS_SPLIT_KEY_PAIRS = args.bypass_split_key_pairs
     unittest.TextTestRunner().run(TestDataLoader())
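
For reference, an invocation of the updated runner with the new flags might look like the following (the dataloader path is a placeholder; pass any dataloader script accepted by the positional "dataloader_path" argument):

    python tests/test_bigbio.py <dataloader_path> \
        --bypass_splits train dev \
        --bypass_keys events \
        --bypass_split_key_pairs test,entities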
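
The sketch below is a minimal, hypothetical illustration (not part of the diff) of the bypass semantics introduced above: nargs="*" collects space-separated values, _warn_bypass splits each comma-separated pair into a [split, key] list, and _skipkey_or_keysplit matches a key either globally or per split. The skip() helper and the example flag values are made up for demonstration.

    import argparse

    # Stand-alone parser mirroring the three flags added in this diff.
    parser = argparse.ArgumentParser()
    parser.add_argument("--bypass_splits", default=[], nargs="*")
    parser.add_argument("--bypass_keys", default=[], nargs="*")
    parser.add_argument("--bypass_split_key_pairs", default=[], nargs="*")

    args = parser.parse_args(
        ["--bypass_keys", "events", "--bypass_split_key_pairs", "test,entities"]
    )

    # As in _warn_bypass: "test,entities" -> ["test", "entities"]
    pairs = [i.split(",") for i in args.bypass_split_key_pairs]

    # As in _skipkey_or_keysplit: skip when the key is bypassed globally
    # or when this exact (split, key) pair is bypassed.
    def skip(key, split):
        return key in args.bypass_keys or [split, key] in pairs

    assert skip("events", "train")        # key bypassed in every split
    assert skip("entities", "test")       # pair bypassed in 'test' only
    assert not skip("entities", "train")  # 'entities' still checked in 'train'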