Skip to content

Commit

Permalink
Updates to TestDB class
Browse files Browse the repository at this point in the history
  • Loading branch information
jlumpe committed Aug 14, 2024
1 parent abec7ba commit b097e8f
Showing 1 changed file with 42 additions and 44 deletions.
86 changes: 42 additions & 44 deletions tests/testdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,12 @@ class TestQueryGenome(TypedDict):
primary: str
closest: str
warnings: bool
file: SequenceFile
file_gz: SequenceFile


class TestRefGenome(TypedDict):
name: str
key: str
taxon: str
file: SequenceFile
file_gz: SequenceFile


class TestDB:
Expand Down Expand Up @@ -162,18 +158,6 @@ def refdb(self) -> ReferenceDatabase:
gset = only_genomeset(session)
return ReferenceDatabase(gset, self.ref_signatures)

@classmethod
def _add_file_cols(cls, genomes_dir, row):
row['file'] = SequenceFile(
path=genomes_dir / (row['name'] + '.fasta'),
format='fasta',
)
row['file_gz'] = SequenceFile(
path=genomes_dir / (row['name'] + '.fasta.gz'),
format='fasta',
compression='gzip',
)

@lazy
def query_genomes(self) -> list[TestQueryGenome]:
"""Query genomes and their expected results."""
Expand All @@ -184,7 +168,6 @@ def query_genomes(self) -> list[TestQueryGenome]:
for row in rows:
# Convert "warnings" column to bool
row['warnings'] = row['warnings'].lower() == 'true'
self._add_file_cols(self.paths.query_genomes_dir, row)

return rows # type: ignore

Expand All @@ -195,42 +178,57 @@ def ref_genomes(self) -> list[TestRefGenome]:
with open(self.paths.refs_table, newline='') as f:
rows = list(DictReader(f))

for row in rows:
self._add_file_cols(self.paths.ref_genomes_dir, row)

return rows # type: ignore

@classmethod
def _ensure_gz(cls, items):
"""Ensure gzipped versions of the query/ref files are available.
def _ensure_gz(cls, file: Path, file_gz: Path):
"""Ensure gzipped version of the query/ref file is available.
These aren't added to version control, so they are created the first time they are needed.
"""
for item in items:
dst = item['file_gz'].path
if dst.is_file():
continue
if file_gz.is_file():
return

with open(item['file'].path) as f:
content = f.read()
with open(file) as f:
content = f.read()

with gzip.open(dst, 'wt') as f:
f.write(content)
with gzip.open(file_gz, 'wt') as f:
f.write(content)

@classmethod
def _get_genome_files(cls, items, gzipped):
if gzipped:
col = 'file_gz'
cls._ensure_gz(items)
else:
col = 'file'
return [q[col] for q in items]

def get_query_files(self, gzipped: bool=False) -> list[SequenceFile]:
return self._get_genome_files(self.query_genomes, gzipped)

def get_ref_files(self, gzipped: bool=False) -> list[SequenceFile]:
return self._get_genome_files(self.ref_genomes, gzipped)
def _get_genome_files(self, base: Path, names: list[str], gzipped: bool, relative: bool) -> list[SequenceFile]:
base2 = base.relative_to(self.paths.root) if relative else base

files = []

for name in names:
fname = name + '.fasta'

if gzipped:
fname_gz = fname + '.gz'
self._ensure_gz(base / fname, base / fname_gz)
path = base2 / fname_gz
else:
path = base2 / fname

files.append(SequenceFile(path, 'fasta', 'gzip' if gzipped else None))

return files

def get_query_files(self, gzipped: bool = False, relative: bool = False) -> list[SequenceFile]:
return self._get_genome_files(
self.paths.query_genomes_dir,
[genome['name'] for genome in self.query_genomes],
gzipped=gzipped,
relative=relative,
)

def get_ref_files(self, gzipped: bool = False, relative: bool = False) -> list[SequenceFile]:
return self._get_genome_files(
self.paths.ref_genomes_dir,
[genome['name'] for genome in self.ref_genomes],
gzipped=gzipped,
relative=relative,
)

def get_query_results(self, strict: bool, session=None) -> QueryResults:
"""Pre-calculated query results."""
Expand Down

0 comments on commit b097e8f

Please sign in to comment.