Skip to content

Commit

Permalink
Update testdb results generation script to use TestDB class
Browse files: browse the repository at this point in the history
  • Loading branch information
jlumpe committed Aug 14, 2024
1 parent b097e8f commit 90fe8db
Showing 1 changed file with 20 additions and 26 deletions.
46 changes: 20 additions & 26 deletions tests/data/testdb_210818/generate-results.py
Original file line number Diff line number Diff line change
@@ -9,47 +9,39 @@

import sys
from pathlib import Path
from csv import DictReader

from gambit.seq import SequenceFile
from gambit.db import ReferenceDatabase, reportable_taxon
from gambit.query import QueryParams, query_parse
from gambit.db import reportable_taxon
from gambit.query import QueryParams, QueryResults, query_parse
from gambit.results import ResultsArchiveWriter
from gambit.util.misc import zip_strict


THISDIR = Path(__file__).parent
ROOTDIR = THISDIR.parent.parent.parent

sys.path.insert(0, str(ROOTDIR))
from tests.testdb import TestDB, TestQueryGenome


PARAMS = {
'non_strict': QueryParams(classify_strict=False, report_closest=10),
'strict': QueryParams(classify_strict=True, report_closest=10),
}


def load_query_data():
with open('queries/queries.csv', newline='') as f:
rows = list(DictReader(f))

genomes_dir = Path('queries/genomes')
def check_results(queries: list[TestQueryGenome], query_files: list[SequenceFile], results: QueryResults):
"""Check query results object against queries.csv table before exporting."""

for row in rows:
row['warnings'] = row['warnings'].lower() == 'true'
row['file'] = SequenceFile(
path=genomes_dir / (row['name'] + '.fasta'),
format='fasta',
)

return rows


def check_results(queries, results):
strict = results.params.classify_strict

for query, item in zip_strict(queries, results.items):
for query, query_file, item in zip_strict(queries, query_files, results.items):
warnings = []

clsresult = item.classifier_result
predicted = clsresult.predicted_taxon

assert item.input.file == query['file']
assert item.input.file == query_file

# No errors
assert clsresult.success
@@ -84,7 +76,6 @@ def check_results(queries, results):
# Closest matches
assert len(item.closest_genomes) == results.params.report_closest
assert item.closest_genomes[0] == clsresult.closest_match
assert item.closest_genomes[0].genome.description == query['closest']

for i in range(1, results.params.report_closest):
assert item.closest_genomes[i].distance >= item.closest_genomes[i-1].distance
@@ -111,19 +102,22 @@ def check_results(queries, results):


# Script entry point: for every parameter set in PARAMS, run the query against
# the test database, validate the results with check_results(), and export
# them to results/<label>.json.
#
# NOTE(review): this listing is a rendered diff whose +/- markers and
# indentation were lost, so several statements below appear twice (once in
# their removed form, once in their added form). The old/new attribution in
# the comments is inferred from the commit title ("use TestDB class") --
# confirm against the actual file.
def main():
# Presumably the removed (pre-change) setup: query rows come from
# load_query_data() and the database is loaded from the current directory.
queries = load_query_data()
query_files = [query['file'] for query in queries]
db = ReferenceDatabase.load_from_dir('.')
# Presumably the added (post-change) setup: a TestDB rooted at this script's
# directory supplies both the reference database and the query files.
testdb = TestDB(THISDIR)
db = testdb.refdb
query_files = testdb.get_query_files(relative=True)

# A single writer instance is reused for every parameter set.
writer = ResultsArchiveWriter(pretty=True)

# One query run per labelled parameter set; the label also names the output file.
for label, params in PARAMS.items():
print('Running query:', label)
results = query_parse(db, query_files, params)
# Presumably the removed call (two-argument check_results signature).
check_results(queries, results)
# Presumably the added call: query genomes and their files are passed separately.
check_results(testdb.query_genomes, query_files, results)

# The output path is relative to the current working directory, not THISDIR.
with open(f'results/{label}.json', 'wt') as f:
writer.export(f, results)

print('done!\n\n')


if __name__ == '__main__':
main()

0 comments on commit 90fe8db

Please sign in to comment.