Skip to content

Commit

Permalink
WIP remove QueryInput class
Browse files Browse the repository at this point in the history
  • Loading branch information
jlumpe committed Dec 1, 2024
1 parent ecb9751 commit ca4ead2
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 77 deletions.
95 changes: 39 additions & 56 deletions src/gambit/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@

from warnings import warn
from datetime import datetime
from typing import Sequence, Optional, Union, Any
from typing import Sequence, Optional, Any
from pathlib import Path

from attr import attrs, attrib
from attr.converters import optional as optional_converter
import numpy as np

from gambit import __version__ as GAMBIT_VERSION
from gambit.classify import classify, ClassifierResult, GenomeMatch
from gambit.db import ReferenceDatabase, Taxon, ReferenceGenomeSet, reportable_taxon
from gambit.seq import SequenceFile
from gambit.sigs import KmerSignature, SignaturesMeta
from gambit.sigs import KmerSignature, SignaturesMeta, AnnotatedSignatures
from gambit.metric import jaccarddist_matrix
from gambit.util.misc import zip_strict
from gambit.util.progress import progress_config, iter_progress


Expand All @@ -36,54 +37,27 @@ class QueryParams:
report_closest: int = attrib(default=10)


@attrs()
class QueryInput:
    """Describes one query genome for reporting purposes.

    Attributes
    ----------
    label
        Some unique label for the input, probably the file name.
    file
        Source file (optional). Excluded from the generated ``repr``.
    """
    label: str = attrib()
    file: Optional[SequenceFile] = attrib(default=None, repr=False)

    @classmethod
    def convert(cls, x: Union['QueryInput', SequenceFile, str]) -> 'QueryInput':
        """Coerce a flexibly-typed argument into a ``QueryInput``.

        An existing ``QueryInput`` is returned unchanged, a ``str`` is used as the label with no
        file attached, and a ``SequenceFile`` is wrapped using its path (as a string) for the label.

        Raises
        ------
        TypeError
            If ``x`` is not one of the supported types.
        """
        if isinstance(x, QueryInput):
            converted = x
        elif isinstance(x, str):
            converted = QueryInput(x)
        elif isinstance(x, SequenceFile):
            converted = QueryInput(str(x.path), x)
        else:
            raise TypeError(f'Cannot convert {type(x)} instance to QueryInput')
        return converted


@attrs()
class QueryResultItem:
"""Result for a single query sequence.
Attributes
----------
input
Information on input genome.
label
Unique label describing query.
classifier_result
Result of running classifier.
file
Path to file containing query genome (optional).
report_taxon
Final taxonomy prediction to be reported to the user.
closest_genomes
List of closest reference genomes to query. Length determined by
:attr:`.QueryParams.report_closest`.
"""
input: QueryInput = attrib()
label: str = attrib()
classifier_result: ClassifierResult = attrib()
file: Optional[Path] = attrib(default=None, converter=optional_converter(Path))
report_taxon: Optional[Taxon] = attrib(default=None)
closest_genomes: list[GenomeMatch] = attrib(factory=list)

Expand Down Expand Up @@ -122,7 +96,7 @@ def query(db: ReferenceDatabase,
queries: Sequence[KmerSignature],
params: Optional[QueryParams] = None,
*,
inputs: Optional[Sequence[Union[QueryInput, SequenceFile, str]]] = None,
labels: Optional[Sequence[str]],
progress = None,
**kw,
) -> QueryResults:
Expand All @@ -137,10 +111,10 @@ def query(db: ReferenceDatabase,
params
``QueryParams`` instance defining parameter values. If None take values from additional
keyword arguments or use defaults.
inputs
Description for each input, converted to :class:`.QueryInput` in results
object. Only used for reporting, does not affect any other aspect of results. Items can be
``QueryInput``, ``SequenceFile`` or ``str``.
labels
Optional list of string labels for each query. Only used for reporting (sets ``label``
attribute of :class:`QueryResultItem` in results object), does not affect any other aspect of
results.
progress
Report progress for distance matrix calculation and classification. See
:func:`gambit.util.progress.get_progress` for description of allowed values.
Expand All @@ -152,18 +126,22 @@ def query(db: ReferenceDatabase,
elif kw:
warn('Additional keyword arguments ignored if "params" argument is not None.')

queries = list(queries)
pconf = progress_config(progress)

if len(queries) == 0:
raise ValueError('Must supply at least one query.')

if inputs is not None:
inputs = list(map(QueryInput.convert, inputs))
if len(inputs) != len(queries):
raise ValueError('Number of inputs does not match number of queries.')
# Labels
if labels is not None:
if len(labels) != len(queries):
raise ValueError('Number of labels does not match number of queries.')

elif isinstance(queries, AnnotatedSignatures):
# Get default labels from queries of AnnotatedSignatures object
labels = list(map(str, queries.ids))

else:
inputs = [QueryInput(str(i + 1)) for i in range(len(queries))]
labels = [str(i + 1) for i in range(len(queries))]

# Calculate distances
# (This will only be about 200kB per row/query [50k float32's] so having the whole thing in
Expand All @@ -177,8 +155,11 @@ def query(db: ReferenceDatabase,
)

# Classify inputs and create result items
with iter_progress(inputs, pconf, desc='Classifying') as inputs_iter:
items = [get_result_item(db, params, dmat[i, :], input) for i, input in enumerate(inputs_iter)]
with iter_progress(labels, pconf, desc='Classifying') as labels_iter:
items = [
get_result_item(db, params, dmat[i, :], label)
for i, label in enumerate(labels_iter)
]

return QueryResults(
items=items,
Expand All @@ -188,22 +169,22 @@ def query(db: ReferenceDatabase,
)


def get_result_item(db:ReferenceDatabase, params: QueryParams, dists: np.ndarray, input: QueryInput) -> QueryResultItem:
def get_result_item(db: ReferenceDatabase, params: QueryParams, dists: np.ndarray, label: str) -> QueryResultItem:
"""Perform classification and create result item object for single query input.
Parameters
----------
db
params
dists
Distances from query to reference genomes.
input
1D array of distances from query to all reference genomes.
label
"""
clsresult = classify(db.genomes, dists, strict=params.classify_strict)
closest = [GenomeMatch(db.genomes[i], dists[i]) for i in np.argsort(dists)[:params.report_closest]]

return QueryResultItem(
input=input,
label=label,
classifier_result=clsresult,
report_taxon=reportable_taxon(clsresult.predicted_taxon),
closest_genomes=closest,
Expand Down Expand Up @@ -244,10 +225,12 @@ def query_parse(db: ReferenceDatabase,
parse_kw.setdefault('progress', pconf.update(desc='Parsing input'))

if file_labels is None:
inputs = files
labels = [str(file.path) for file in files]
else:
inputs = [QueryInput(label, file) for label, file in zip_strict(file_labels, files)]
labels = file_labels
if len(labels) != len(files):
raise ValueError('Number of labels does not match number of files')

query_sigs = calc_file_signatures(db.signatures.kmerspec, files, **parse_kw)

return query(db, query_sigs, params, inputs=inputs, progress=pconf, **kw)
return query(db, query_sigs, params, labels=labels, progress=pconf, **kw)
26 changes: 5 additions & 21 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,14 @@

import pytest

from gambit.query import QueryInput, QueryResults, query, query_parse
from gambit.seq import SequenceFile
from gambit.query import QueryResults, query, query_parse
from gambit.util.misc import zip_strict
from gambit import __version__ as GAMBIT_VERSION

from .testdb import TestDB
from .results import compare_result_items, check_results


class TestQueryInput:
    """Test QueryInput class."""

    def test_convert(self):
        """Exercise ``QueryInput.convert`` with each accepted type and one rejected type."""
        seqfile = SequenceFile('path/to/file.fa', 'fasta')
        existing = QueryInput('foo', seqfile)

        # An existing instance must come back as the very same object.
        assert QueryInput.convert(existing) is existing

        # A bare string becomes the label, with no file attached.
        from_label = QueryInput.convert('foo')
        assert from_label == QueryInput('foo', None)

        # A SequenceFile is wrapped, taking its label from the file path.
        from_file = QueryInput.convert(seqfile)
        assert from_file == QueryInput(str(seqfile.path), seqfile)

        # Anything else is rejected.
        with pytest.raises(TypeError):
            QueryInput.convert(3.4)


@pytest.mark.parametrize('strict', [False, True])
class TestQuery:
"""Run a full query using the Python API."""
Expand Down Expand Up @@ -55,8 +39,8 @@ def test_query(self, testdb: TestDB, strict: bool):
self.check_results(results, ref_results)

for sigid, item in zip_strict(query_sigs.ids, results.items):
assert item.input.file is None
# assert item.input.label == sigid
assert item.file is None
assert item.label == sigid

def test_query_parse(self, testdb: TestDB, strict: bool):
"""Test the query_parse() function."""
Expand All @@ -69,5 +53,5 @@ def test_query_parse(self, testdb: TestDB, strict: bool):
self.check_results(results, ref_results)

for file, item in zip_strict(query_files, results.items):
assert item.input.file == file
assert item.input.label == str(file.path)
assert item.file == file.path
assert item.label == str(file.path)

0 comments on commit ca4ead2

Please sign in to comment.