diff --git a/tests/data/testdb_210818/generate-results.py b/tests/data/testdb_210818/generate-results.py index 920df44..e9c8633 100755 --- a/tests/data/testdb_210818/generate-results.py +++ b/tests/data/testdb_210818/generate-results.py @@ -9,47 +9,39 @@ import sys from pathlib import Path -from csv import DictReader from gambit.seq import SequenceFile -from gambit.db import ReferenceDatabase, reportable_taxon -from gambit.query import QueryParams, query_parse +from gambit.db import reportable_taxon +from gambit.query import QueryParams, QueryResults, query_parse from gambit.results import ResultsArchiveWriter from gambit.util.misc import zip_strict +THISDIR = Path(__file__).parent +ROOTDIR = THISDIR.parent.parent.parent + +sys.path.insert(0, str(ROOTDIR)) +from tests.testdb import TestDB, TestQueryGenome + + PARAMS = { 'non_strict': QueryParams(classify_strict=False, report_closest=10), 'strict': QueryParams(classify_strict=True, report_closest=10), } -def load_query_data(): - with open('queries/queries.csv', newline='') as f: - rows = list(DictReader(f)) - - genomes_dir = Path('queries/genomes') +def check_results(queries: list[TestQueryGenome], query_files: list[SequenceFile], results: QueryResults): + """Check query results object against queries.csv table before exporting.""" - for row in rows: - row['warnings'] = row['warnings'].lower() == 'true' - row['file'] = SequenceFile( - path=genomes_dir / (row['name'] + '.fasta'), - format='fasta', - ) - - return rows - - -def check_results(queries, results): strict = results.params.classify_strict - for query, item in zip_strict(queries, results.items): + for query, query_file, item in zip_strict(queries, query_files, results.items): warnings = [] clsresult = item.classifier_result predicted = clsresult.predicted_taxon - assert item.input.file == query['file'] + assert item.input.file == query_file # No errors assert clsresult.success @@ -84,7 +76,6 @@ def check_results(queries, results): # Closest matches assert len(item.closest_genomes) == results.params.report_closest assert item.closest_genomes[0] == clsresult.closest_match - assert item.closest_genomes[0].genome.description == query['closest'] for i in range(1, results.params.report_closest): assert item.closest_genomes[i].distance >= item.closest_genomes[i-1].distance @@ -111,19 +102,22 @@ def check_results(queries, results): def main(): - queries = load_query_data() - query_files = [query['file'] for query in queries] - db = ReferenceDatabase.load_from_dir('.') + testdb = TestDB(THISDIR) + db = testdb.refdb + query_files = testdb.get_query_files(relative=True) writer = ResultsArchiveWriter(pretty=True) for label, params in PARAMS.items(): + print('Running query:', label) results = query_parse(db, query_files, params) - check_results(queries, results) + check_results(testdb.query_genomes, query_files, results) with open(f'results/{label}.json', 'wt') as f: writer.export(f, results) + print('done!\n\n') + if __name__ == '__main__': main()