Skip to content

Commit

Permalink
Unit-test bypass args (bigscience-workshop#533)
Browse files Browse the repository at this point in the history
* make a pip installable package biomed_datasets

* rm src directory idea

* drop n2c2_2014, was there for PR review

* remove requirements.txt

* added basic setup.py and deploy requirements in setup.cfg

* change package name to bigbio

* Revert "Merge branch 'minimal_make_package' of github.com:bigscience-workshop/biomedical"

This reverts commit 3b5119a, reversing
changes made to 8409ab7.

* doc: typo fix

* adds bypass split_keys args

* typo fix

* fix key name

* adds splits; WIP keys + splitkeys

* implement keys + splitkeys

* fix: makes helper function to streamline key check

* fix: split name pass in schema

* fix: incorrect reference in logger fixed for schema

* minor clarification on text

Co-authored-by: Gabriel Altay <[email protected]>
  • Loading branch information
hakunanatasha and galtay authored May 5, 2022
1 parent 4c7066e commit f113917
Showing 1 changed file with 182 additions and 23 deletions.
205 changes: 182 additions & 23 deletions tests/test_bigbio.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""
Unit-tests to ensure tasks adhere to big-bio schema.
NOTE: If bypass keys/splits present, statistics are STILL printed.
"""
import argparse
from collections import defaultdict
Expand Down Expand Up @@ -89,13 +91,18 @@ class TestDataLoader(unittest.TestCase):
PATH: str
NAME: str
DATA_DIR: Optional[str]
BYPASS_SPLITS: List[str]
BYPASS_KEYS: List[str]
BYPASS_SPLIT_KEY_PAIRS: List[str]


def runTest(self):

logger.info(f"self.PATH: {self.PATH}")
logger.info(f"self.NAME: {self.NAME}")
logger.info(f"self.DATA_DIR: {self.DATA_DIR}")

self._warn_bypass()

logger.info("importing module .... ")
module_name = self.PATH
Expand Down Expand Up @@ -224,17 +231,25 @@ def _assert_ids_globally_unique(
for elem in collection:
self._assert_ids_globally_unique(elem, ids_seen)


def test_are_ids_globally_unique(self, dataset_bigbio: DatasetDict):
"""
Tests each example in a split has a unique ID.
"""
logger.info("Checking global ID uniqueness")
for split in dataset_bigbio.values():
for split_name, split in dataset_bigbio.items():

# Skip entire data split
if split_name in self.BYPASS_SPLITS:
logger.info(f"\tSkipping unique ID check on {split_name}")
continue

ids_seen = set()
for example in split:
self._assert_ids_globally_unique(example, ids_seen=ids_seen)
logger.info("Found {} unique IDs".format(len(ids_seen)))


def _get_referenced_ids(self, example):
referenced_ids = []

Expand All @@ -255,6 +270,7 @@ def _get_referenced_ids(self, example):

return referenced_ids


def _get_existing_referable_ids(self, example):
existing_ids = []

Expand All @@ -272,7 +288,13 @@ def test_do_all_referenced_ids_exist(self, dataset_bigbio: DatasetDict):
Checks if referenced IDs are correctly labeled.
"""
logger.info("Checking if referenced IDs are properly mapped")
for split in dataset_bigbio.values():
for split_name, split in dataset_bigbio.items():

# skip entire split
if split_name in self.BYPASS_SPLITS:
logger.info(f"\tSkipping referenced ids on {split_name}")
continue

for example in split:
referenced_ids = set()
existing_ids = set()
Expand All @@ -281,6 +303,12 @@ def test_do_all_referenced_ids_exist(self, dataset_bigbio: DatasetDict):
existing_ids.update(self._get_existing_referable_ids(example))

for ref_id, ref_type in referenced_ids:

if self._skipkey_or_keysplit(ref_type, split_name):
split_keys = (split_name, ref_type)
logger.warning(f"\tSkipping referenced ids on {split_keys}")
continue

if ref_type == "event":
if not (
(ref_id, "entity") in existing_ids
Expand All @@ -307,6 +335,15 @@ def test_passages_offsets(self, dataset_bigbio: DatasetDict):
logger.info("KB ONLY: Checking passage offsets")
for split in dataset_bigbio:

# skip entire split
if split in self.BYPASS_SPLITS:
logger.info(f"\tSkipping passage offsets on {split}")
continue

if self._skipkey_or_keysplit("passages", split):
logger.warning(f"Skipping passages offsets for split='{split}'")
continue

if "passages" in dataset_bigbio[split].features:

for example in dataset_bigbio[split]:
Expand Down Expand Up @@ -409,6 +446,15 @@ def test_entities_offsets(self, dataset_bigbio: DatasetDict):

for split in dataset_bigbio:

# skip entire split
if split in self.BYPASS_SPLITS:
logger.info(f"\tSkipping entities offsets on {split}")
continue

if self._skipkey_or_keysplit("entities", split):
logger.warning(f"Skipping entities offsets for split='{split}'")
continue

if "entities" in dataset_bigbio[split].features:

for example in dataset_bigbio[split]:
Expand Down Expand Up @@ -445,6 +491,15 @@ def test_events_offsets(self, dataset_bigbio: DatasetDict):

for split in dataset_bigbio:

# skip entire split
if split in self.BYPASS_SPLITS:
logger.info(f"\tSkipping events offsets on {split}")
continue

if self._skipkey_or_keysplit("events", split):
logger.warning(f"Skipping events offsets for split='{split}'")
continue

if "events" in dataset_bigbio[split].features:

for example in dataset_bigbio[split]:
Expand Down Expand Up @@ -479,6 +534,15 @@ def test_coref_ids(self, dataset_bigbio: DatasetDict):
logger.info("KB ONLY: Checking coref offsets")
for split in dataset_bigbio:

# skip entire split
if split in self.BYPASS_SPLITS:
logger.info(f"\tSkipping coref ids on {split}")
continue

if self._skipkey_or_keysplit("coreferences", split):
logger.warning(f"Skipping coreferences ids for split='{split}'")
continue

if "coreferences" in dataset_bigbio[split].features:

for example in dataset_bigbio[split]:
Expand All @@ -499,23 +563,40 @@ def test_multiple_choice(self, dataset_bigbio: DatasetDict):
logger.info("QA ONLY: Checking multiple choice")
for split in dataset_bigbio:

# skip entire split
if split in self.BYPASS_SPLITS:
logger.info(f"\tSkipping multiple-choice on {split}")
continue

for example in dataset_bigbio[split]:

if len(example["choices"]) > 0:
# can change "==" to "in" if we include ranking later
assert (
example["type"] == "multiple_choice"
), f"`choices` is populated, but type is not 'multiple_choice' {example}"
if self._skipkey_or_keysplit("choices", split):
logger.warning("Skipping multiple choice for key=choices, split='{split}'")
continue

if example["type"] == "multiple_choice":
assert (
len(example["choices"]) > 0
), f"type is 'multiple_choice' but no values in 'choices' {example}"
else:

if len(example["choices"]) > 0:
# can change "==" to "in" if we include ranking later
assert (
example["type"] == "multiple_choice"
), f"`choices` is populated, but type is not 'multiple_choice' {example}"

for answer in example["answer"]:
if example["type"] == "multiple_choice":
assert (
answer in example["choices"]
), f"answer is not present in 'choices' {example}"
len(example["choices"]) > 0
), f"type is 'multiple_choice' but no values in 'choices' {example}"


if self._skipkey_or_keysplit("answer", split):
logger.warning("Skipping multiple choice for key=answer, split='{split}'")
continue

else:
for answer in example["answer"]:
assert (
answer in example["choices"]
), f"answer is not present in 'choices' {example}"


def test_entities_multilabel_db_id(self, dataset_bigbio: DatasetDict):
Expand All @@ -530,13 +611,21 @@ def test_entities_multilabel_db_id(self, dataset_bigbio: DatasetDict):
# one warning is enough to prompt a cleaning pass
for split in dataset_bigbio:

# skip entire split
if split in self.BYPASS_SPLITS:
logger.info(f"\tSkipping entities multilabel db on {split}")
continue

if warning_raised:
break

if "entities" not in dataset_bigbio[split].features:

continue

if self._skipkey_or_keysplit("entities", split):
logger.warning(f"Skipping multilabel entities for split='{split}'")
continue

for example in dataset_bigbio[split]:

if warning_raised:
Expand All @@ -547,17 +636,14 @@ def test_entities_multilabel_db_id(self, dataset_bigbio: DatasetDict):
for entity in example["entities"]:

if warning_raised:

break

normalized = entity.get("normalized", [])

entity_id = entity["id"]

for norm in normalized:

db_id = norm["db_id"]

match = re.search(_CONNECTORS, db_id)

if match is not None:
Expand Down Expand Up @@ -586,8 +672,17 @@ def test_multilabel_type(self, dataset_bigbio: DatasetDict):

for split in dataset_bigbio:

# skip entire split
if split in self.BYPASS_SPLITS:
logger.info(f"\tSkipping multilabel type on {split}")
continue

for feature_name in features_with_type:

if self._skipkey_or_keysplit(feature_name, split):
logger.warning(f"Skipping multilabel type for splitkey = '{(split, feature_name)}'")
continue

if (
feature_name not in dataset_bigbio[split].features
or warning_raised[feature_name]
Expand All @@ -597,17 +692,14 @@ def test_multilabel_type(self, dataset_bigbio: DatasetDict):
for example in dataset_bigbio[split]:

if warning_raised[feature_name]:

break

example_id = example["id"]

features = example[feature_name]

for feature in features:

feature_type = feature["type"]

match = re.search(_CONNECTORS, feature_type)

if match is not None:
Expand Down Expand Up @@ -648,9 +740,22 @@ def test_schema(self, schema: str):
print()

for split_name, split in self.dataset.items():

# Skip entire data split
if split_name in self.BYPASS_SPLITS:
logger.info(f"Skipping schema on {split_name}")
continue

logger.info("Testing schema for: " + str(split_name))
self.assertEqual(split.info.features, features)

for non_empty_feature in non_empty_features:
if split_to_feature_counts[split_name][non_empty_feature] == 0:

if self._skipkey_or_keysplit(non_empty_feature, split_name):
logger.warning(f"Skipping schema for split, key = '{(split_name, non_empty_feature)}'")
continue

if (split_to_feature_counts[split_name][non_empty_feature] == 0):
raise AssertionError(
f"Required key '{non_empty_feature}' does not have any instances"
)
Expand Down Expand Up @@ -680,11 +785,37 @@ def _test_has_only_one_item(self, msg: str, field: list):
):
self.assertEqual(len(field), 1)

def _warn_bypass(self):
    """Warn when data splits, keys, or (split, key) pairs are bypassed.

    Also normalizes ``BYPASS_SPLIT_KEY_PAIRS`` from the CLI's
    ``"split,key"`` strings into ``[split, key]`` lists so that
    ``_skipkey_or_keysplit`` can compare against them directly.

    Fixes vs. original:
      - components are whitespace-stripped, so ``"train, events"``
        (as shown in the CLI help example) matches the ``events`` key;
      - only string entries are split, making the normalization
        idempotent — calling this twice no longer raises
        ``AttributeError`` on the already-converted lists.
    """
    if len(self.BYPASS_SPLITS) > 0:
        logger.warning(f"Splits ignored = '{self.BYPASS_SPLITS}'")

    if len(self.BYPASS_KEYS) > 0:
        logger.warning(f"Keys ignored = '{self.BYPASS_KEYS}'")

    if len(self.BYPASS_SPLIT_KEY_PAIRS) > 0:
        logger.warning(
            f"Split and key pairs ignored ='{self.BYPASS_SPLIT_KEY_PAIRS}'"
        )
        self.BYPASS_SPLIT_KEY_PAIRS = [
            [part.strip() for part in pair.split(",")]
            if isinstance(pair, str)
            else pair
            for pair in self.BYPASS_SPLIT_KEY_PAIRS
        ]

def _skipkey_or_keysplit(self, key: str, split: str):
"""Check if key or (split, key) pair should be omitted"""
flag = False
if key in self.BYPASS_KEYS:
flag = True

if [split, key] in self.BYPASS_SPLIT_KEY_PAIRS:
flag = True

return flag

if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)

parser = argparse.ArgumentParser(description="Unit tests for BigBio dataloaders.")

parser.add_argument(
"dataloader_path",
type=str,
Expand All @@ -698,9 +829,34 @@ def _test_has_only_one_item(self, msg: str, field: list):
parser.add_argument(
"--config_name",
type=str,
help="use to run on a single config name (defualt is to run on all config names)",
help="use to run on a single config name (default is to run on all config names)",
)

parser.add_argument(
"--bypass_splits",
default=[],
required=False,
nargs="*",
help="Skip a data split (e.g. 'train', 'dev') from testing. List all splits as space separated (ex: --bypass_splits train dev)",
)

parser.add_argument(
"--bypass_keys",
default=[],
required=False,
nargs="*",
help="Skip a required key (e.g. 'entities' for NER) from testing. List all keys as space separated (ex: --bypass_keys entities events)",
)

parser.add_argument(
"--bypass_split_key_pairs",
default=[],
required=False,
nargs="*",
help="Skip a key in a data split (e.g. skip 'entities' in 'test'). List all key-pairs comma separated. (ex: --bypass_split_key_pairs test,entities train, events)",
)


args = parser.parse_args()
logger.info(f"args: {args}")

Expand All @@ -718,4 +874,7 @@ def _test_has_only_one_item(self, msg: str, field: list):
TestDataLoader.PATH = args.dataloader_path
TestDataLoader.NAME = config_name
TestDataLoader.DATA_DIR = args.data_dir
TestDataLoader.BYPASS_SPLITS = args.bypass_splits
TestDataLoader.BYPASS_KEYS = args.bypass_keys
TestDataLoader.BYPASS_SPLIT_KEY_PAIRS = args.bypass_split_key_pairs
unittest.TextTestRunner().run(TestDataLoader())

0 comments on commit f113917

Please sign in to comment.