diff --git a/src/gambit/query.py b/src/gambit/query.py index 9634a5a..05c7b7b 100644 --- a/src/gambit/query.py +++ b/src/gambit/query.py @@ -2,18 +2,19 @@ from warnings import warn from datetime import datetime -from typing import Sequence, Optional, Union, Any +from typing import Sequence, Optional, Any +from pathlib import Path from attr import attrs, attrib +from attr.converters import optional as optional_converter import numpy as np from gambit import __version__ as GAMBIT_VERSION from gambit.classify import classify, ClassifierResult, GenomeMatch from gambit.db import ReferenceDatabase, Taxon, ReferenceGenomeSet, reportable_taxon from gambit.seq import SequenceFile -from gambit.sigs import KmerSignature, SignaturesMeta +from gambit.sigs import KmerSignature, SignaturesMeta, AnnotatedSignatures from gambit.metric import jaccarddist_matrix -from gambit.util.misc import zip_strict from gambit.util.progress import progress_config, iter_progress @@ -36,54 +37,27 @@ class QueryParams: report_closest: int = attrib(default=10) -@attrs() -class QueryInput: - """Information on a query genome. - - Attributes - ---------- - label - Some unique label for the input, probably the file name. - file - Source file (optional). - """ - label: str = attrib() - file: Optional[SequenceFile] = attrib(default=None, repr=False) - - @classmethod - def convert(cls, x: Union['QueryInput', SequenceFile, str]) -> 'QueryInput': - """Convenience function to convert flexible argument types into QueryInput. - - Accepts single string label, ``SequenceFile`` (uses file path for label), or existing - ``QueryInput`` instance (returned unchanged). - """ - if isinstance(x, QueryInput): - return x - if isinstance(x, str): - return QueryInput(x) - if isinstance(x, SequenceFile): - return QueryInput(str(x.path), x) - raise TypeError(f'Cannot convert {type(x)} instance to QueryInput') - - @attrs() class QueryResultItem: """Result for a single query sequence. 
Attributes ---------- - input - Information on input genome. + label + Unique label describing query. classifier_result Result of running classifier. + file + Path to file containing query genome (optional). report_taxon Final taxonomy prediction to be reported to the user. closest_genomes List of closest reference genomes to query. Length determined by :attr:`.QueryParams.report_closest`. """ - input: QueryInput = attrib() + label: str = attrib() classifier_result: ClassifierResult = attrib() + file: Optional[Path] = attrib(default=None, converter=optional_converter(Path)) report_taxon: Optional[Taxon] = attrib(default=None) closest_genomes: list[GenomeMatch] = attrib(factory=list) @@ -122,7 +96,7 @@ def query(db: ReferenceDatabase, queries: Sequence[KmerSignature], params: Optional[QueryParams] = None, *, - inputs: Optional[Sequence[Union[QueryInput, SequenceFile, str]]] = None, + labels: Optional[Sequence[str]], progress = None, **kw, ) -> QueryResults: @@ -137,10 +111,10 @@ def query(db: ReferenceDatabase, params ``QueryParams`` instance defining parameter values. If None take values from additional keyword arguments or use defaults. - inputs - Description for each input, converted to :class:`.QueryInput` in results - object. Only used for reporting, does not any other aspect of results. Items can be - ``QueryInput``, ``SequenceFile`` or ``str``. + labels + Optional list of string labels for each query. Only used for reporting (sets ``label`` + attribute of :class:`QueryResultItem` in results object), does not affect any other aspect of + results. progress Report progress for distance matrix calculation and classification. See :func:`gambit.util.progress.get_progress` for description of allowed values. 
@@ -152,18 +126,22 @@ def query(db: ReferenceDatabase, elif kw: warn('Additional keyword arguments ignored if "params" argument is not None.') - queries = list(queries) pconf = progress_config(progress) if len(queries) == 0: raise ValueError('Must supply at least one query.') - if inputs is not None: - inputs = list(map(QueryInput.convert, inputs)) - if len(inputs) != len(queries): - raise ValueError('Number of inputs does not match number of queries.') + # Labels + if labels is not None: + if len(labels) != len(queries): + raise ValueError('Number of labels does not match number of queries.') + + elif isinstance(queries, AnnotatedSignatures): + # Get default labels from queries of AnnotatedSignatures object + labels = list(map(str, queries.ids)) + else: - inputs = [QueryInput(str(i + 1)) for i in range(len(queries))] + labels = [str(i + 1) for i in range(len(queries))] # Calculate distances # (This will only be about 200kB per row/query [50k float32's] so having the whole thing in @@ -177,8 +155,11 @@ def query(db: ReferenceDatabase, ) # Classify inputs and create result items - with iter_progress(inputs, pconf, desc='Classifying') as inputs_iter: - items = [get_result_item(db, params, dmat[i, :], input) for i, input in enumerate(inputs_iter)] + with iter_progress(labels, pconf, desc='Classifying') as labels_iter: + items = [ + get_result_item(db, params, dmat[i, :], label) + for i, label in enumerate(labels_iter) + ] return QueryResults( items=items, @@ -188,7 +169,7 @@ def query(db: ReferenceDatabase, ) -def get_result_item(db:ReferenceDatabase, params: QueryParams, dists: np.ndarray, input: QueryInput) -> QueryResultItem: +def get_result_item(db: ReferenceDatabase, params: QueryParams, dists: np.ndarray, label: str) -> QueryResultItem: """Perform classification and create result item object for single query input. 
Parameters @@ -196,14 +177,14 @@ def get_result_item(db:ReferenceDatabase, params: QueryParams, dists: np.ndarray db params dists - Distances from query to reference genomes. - input + 1D array of distances from query to all reference genomes. + label """ clsresult = classify(db.genomes, dists, strict=params.classify_strict) closest = [GenomeMatch(db.genomes[i], dists[i]) for i in np.argsort(dists)[:params.report_closest]] return QueryResultItem( - input=input, + label=label, classifier_result=clsresult, report_taxon=reportable_taxon(clsresult.predicted_taxon), closest_genomes=closest, @@ -244,10 +225,12 @@ def query_parse(db: ReferenceDatabase, parse_kw.setdefault('progress', pconf.update(desc='Parsing input')) if file_labels is None: - inputs = files + labels = [str(file.path) for file in files] else: - inputs = [QueryInput(label, file) for label, file in zip_strict(file_labels, files)] + labels = file_labels + if len(labels) != len(files): + raise ValueError('Number of labels does not match number of files') query_sigs = calc_file_signatures(db.signatures.kmerspec, files, **parse_kw) - return query(db, query_sigs, params, inputs=inputs, progress=pconf, **kw) + return query(db, query_sigs, params, labels=labels, progress=pconf, **kw) diff --git a/tests/test_query.py b/tests/test_query.py index c2baea2..4de1477 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -2,8 +2,7 @@ import pytest -from gambit.query import QueryInput, QueryResults, query, query_parse -from gambit.seq import SequenceFile +from gambit.query import QueryResults, query, query_parse from gambit.util.misc import zip_strict from gambit import __version__ as GAMBIT_VERSION @@ -11,21 +10,6 @@ from .results import compare_result_items, check_results -class TestQueryInput: - """Test QueryInput class.""" - - def test_convert(self): - file = SequenceFile('path/to/file.fa', 'fasta') - qi = QueryInput('foo', file) - - assert QueryInput.convert(qi) is qi - assert QueryInput.convert('foo') == 
QueryInput('foo', None) - assert QueryInput.convert(file) == QueryInput(str(file.path), file) - - with pytest.raises(TypeError): - QueryInput.convert(3.4) - - @pytest.mark.parametrize('strict', [False, True]) class TestQuery: """Run a full query using the Python API.""" @@ -55,8 +39,8 @@ def test_query(self, testdb: TestDB, strict: bool): self.check_results(results, ref_results) for sigid, item in zip_strict(query_sigs.ids, results.items): - assert item.input.file is None - # assert item.input.label == sigid + assert item.file is None + assert item.label == sigid def test_query_parse(self, testdb: TestDB, strict: bool): """Test the query_parse() function.""" @@ -69,5 +53,5 @@ def test_query_parse(self, testdb: TestDB, strict: bool): self.check_results(results, ref_results) for file, item in zip_strict(query_files, results.items): - assert item.input.file == file - assert item.input.label == str(file.path) + assert item.file == file.path + assert item.label == str(file.path)