Skip to content

Commit

Permalink
Type aliases
Browse files Browse the repository at this point in the history
  • Loading branch information
jlumpe committed Aug 4, 2024
1 parent 2d5675f commit f1f718a
Show file tree
Hide file tree
Showing 15 changed files with 62 additions and 42 deletions.
5 changes: 5 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@
autodoc_member_order = 'groupwise'
autodoc_typehints = 'description'

autodoc_type_aliases = {
'FilePath': 'FilePath',
'DNASeq': 'DNASeq',
}

intersphinx_mapping = {
'python': ('https://docs.python.org/3', None),
'numpy': ('https://numpy.org/doc/stable/', None),
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ install_requires =
click>=7.0
h5py~=3.0
scipy~=1.7
typing-extensions>=4.0

tests_require =
pytest
Expand Down
2 changes: 1 addition & 1 deletion src/gambit/cli/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def strip_seq_file_ext(filename: str) -> str:
return filename


def get_file_id(path: FilePath, strip_dir: bool = True, strip_ext: bool = True) -> str:
def get_file_id(path: 'FilePath', strip_dir: bool = True, strip_ext: bool = True) -> str:
"""Get sequence file ID derived from file path.
Parameters
Expand Down
4 changes: 2 additions & 2 deletions src/gambit/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def check_clade(clade):
assert root_i == nleaves * 2 - 2


def dump_dmat_csv(file: Union[FilePath, TextIO],
def dump_dmat_csv(file: Union['FilePath', TextIO],
dmat: np.ndarray,
row_ids: Sequence,
col_ids: Sequence,
Expand All @@ -136,7 +136,7 @@ def dump_dmat_csv(file: Union[FilePath, TextIO],
writer.writerow([str(row_id), *values_str])


def load_dmat_csv(file: Union[FilePath, TextIO]) -> tuple[np.ndarray, list[str], list[str]]:
def load_dmat_csv(file: Union['FilePath', TextIO]) -> tuple[np.ndarray, list[str], list[str]]:
"""Load distance matrix from CSV file.
Returns
Expand Down
8 changes: 4 additions & 4 deletions src/gambit/db/refdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, msg, directory=None, genomes_file=None, signatures_file=None)
self.signatures_file = signatures_file


def load_genomeset(db_file: FilePath) -> tuple[Session, ReferenceGenomeSet]:
def load_genomeset(db_file: 'FilePath') -> tuple[Session, ReferenceGenomeSet]:
"""Get the only :class:`gambit.db.models.ReferenceGenomeSet` from a genomes database file."""
session = file_sessionmaker(db_file)()
gset = only_genomeset(session)
Expand Down Expand Up @@ -211,7 +211,7 @@ def __init__(self, genomeset: ReferenceGenomeSet, signatures: ReferenceSignature
raise ValueError(f'{missing} of {n} genomes not matched to signature IDs. Is the id_attr attribute of the signatures metadata correct?')

@classmethod
def locate_files(cls, path: FilePath) -> tuple[Path, Path]:
def locate_files(cls, path: 'FilePath') -> tuple[Path, Path]:
"""Locate an SQLite genome database file and HDF5 signatures file in a directory.
Files are located by extension, ``.gdb`` or ``.db`` for SQLite file and ``.gs`` or ``.h5``
Expand Down Expand Up @@ -258,14 +258,14 @@ def check_single_match(matches, desc: str):
return genomes_file, signatures_file

@classmethod
def load(cls, genomes_file: FilePath, signatures_file: FilePath) -> 'ReferenceDatabase':
def load(cls, genomes_file: 'FilePath', signatures_file: 'FilePath') -> 'ReferenceDatabase':
"""Load complete database given paths to SQLite genomes database file and HDF5 signatures file."""
session, gset = load_genomeset(genomes_file)
sigs = load_signatures(signatures_file)
return cls(gset, sigs)

@classmethod
def load_from_dir(cls, path: FilePath) -> 'ReferenceDatabase':
def load_from_dir(cls, path: 'FilePath') -> 'ReferenceDatabase':
"""
Load complete database given directory containing SQLite genomes database file and HDF5
signatures file.
Expand Down
2 changes: 1 addition & 1 deletion src/gambit/db/sqla.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def process_result_value(self, value, dialect):
return None if value is None else gjson.loads(value)


def file_sessionmaker(path: FilePath, readonly: bool = True, cls: type = None, **kw) -> sessionmaker:
def file_sessionmaker(path: 'FilePath', readonly: bool = True, cls: type = None, **kw) -> sessionmaker:
"""Get an SQLAlchemy ``sessionmaker`` for an sqlite database file.
Parameters
Expand Down
10 changes: 5 additions & 5 deletions src/gambit/kmers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def index_dtype(k: int) -> Optional[np.dtype]:
return None


def kmer_to_index(kmer: DNASeq) -> int:
def kmer_to_index(kmer: 'DNASeq') -> int:
"""Convert a k-mer to its integer index.
Raises
Expand All @@ -41,7 +41,7 @@ def kmer_to_index(kmer: DNASeq) -> int:
return ckmers.kmer_to_index(seq_to_bytes(kmer))


def kmer_to_index_rc(kmer: DNASeq) -> int:
def kmer_to_index_rc(kmer: 'DNASeq') -> int:
"""Get the integer index of a k-mer's reverse complement.
Raises
Expand Down Expand Up @@ -84,7 +84,7 @@ class KmerSpec(Jsonable):
nkmers: int = attrib(eq=False)
index_dtype: np.dtype = attrib(eq=False)

def __init__(self, k: int, prefix: DNASeq):
def __init__(self, k: int, prefix: 'DNASeq'):
"""
Parameters
----------
Expand Down Expand Up @@ -143,7 +143,7 @@ class KmerMatch:
If the match is on the reverse strand.
"""
kmerspec: KmerSpec = attrib()
seq: DNASeq = attrib()
seq: 'DNASeq' = attrib()
pos: int = attrib()
reverse: bool = attrib()

Expand Down Expand Up @@ -178,7 +178,7 @@ def kmer_index(self) -> int:
return kmer_to_index_rc(kmer) if self.reverse else kmer_to_index(kmer)


def find_kmers(kmerspec: KmerSpec, seq: DNASeq) -> Iterator[KmerMatch]:
def find_kmers(kmerspec: KmerSpec, seq: 'DNASeq') -> Iterator[KmerMatch]:
"""Locate k-mers with the given prefix in a DNA sequence.
Searches sequence both backwards and forwards (reverse complement). The sequence may contain
Expand Down
8 changes: 4 additions & 4 deletions src/gambit/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class AbstractResultsExporter(ABC):
"""

@abstractmethod
def export(self, file_or_path: Union[FilePath, IO], results: QueryResults):
def export(self, file_or_path: Union['FilePath', IO], results: QueryResults):
"""Write query results to file.
Parameters
Expand Down Expand Up @@ -55,7 +55,7 @@ def to_json(self, obj):
"""Convert object to JSON-compatible format (need not work recursively)."""
return gjson.to_json(obj)

def export(self, file_or_path: Union[FilePath, TextIO], results: QueryResults):
def export(self, file_or_path: Union['FilePath', TextIO], results: QueryResults):
opts = dict(indent=4, sort_keys=True) if self.pretty else dict()
with maybe_open(file_or_path, 'w') as f:
json.dump(results, f, default=self.to_json, **opts)
Expand Down Expand Up @@ -112,7 +112,7 @@ def get_row(self, item: QueryResultItem) -> list:
"""Get row values for single result item."""
return [getattr_nested(item, attrs, pass_none=True) for _, attrs in self.COLUMNS]

def export(self, file_or_path: Union[FilePath, TextIO], results: QueryResults):
def export(self, file_or_path: Union['FilePath', TextIO], results: QueryResults):
with maybe_open(file_or_path, 'w') as f:
writer = csv.writer(f, **self.format_opts)

Expand Down Expand Up @@ -229,7 +229,7 @@ def _init_converter(self):
self._converter.register_structure_hook(AnnotatedGenome, self._structure_genome)
self._converter.register_structure_hook(Taxon, self._structure_taxon)

def read(self, file_or_path: Union[FilePath, IO]) -> QueryResults:
def read(self, file_or_path: Union['FilePath', IO]) -> QueryResults:
"""Read query results from JSON file.
Parameters
Expand Down
22 changes: 14 additions & 8 deletions src/gambit/seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,28 @@
Note that all code in this package operates on DNA sequences as sequences of
bytes containing ascii-encoded nucleotide codes.
.. data:: NUCLEOTIDES
``bytes`` corresponding to the four DNA nucleotides. Ascii-encoded upper
case letters ``ACGT``. Note that the order, while arbitrary, is important
in this variable as it defines how unique indices are assigned to k-mer
sequences.
.. class:: DNASeq
Type alias for DNA sequence types accepted for k-mer search / signature calculation
(``str``, ``bytes``, ``bytearray``, or :class:`Bio.Seq.Seq`).
"""

from pathlib import Path
from typing import Union, Optional, IO, Iterable
from os import PathLike

from Bio import SeqIO
from Bio.Seq import Seq
from attr import attrs, attrib
from typing_extensions import TypeAlias

from gambit._cython.kmers import revcomp
from gambit.util.io import FilePath
Expand All @@ -29,14 +37,12 @@

SEQ_TYPES = (str, bytes, bytearray, Seq)

#: Union of DNA sequence types accepted for k-mer search / signature calculation.
DNASeq = Union[SEQ_TYPES]

#: Sequence types accepted directly by native (Cython) code.
DNASeqBytes = Union[bytes, bytearray]
DNASeq: TypeAlias = Union[SEQ_TYPES]
# Type alias for sequence types accepted directly by native (Cython) code.
DNASeqBytes: TypeAlias = Union[bytes, bytearray]


def seq_to_bytes(seq: DNASeq) -> DNASeqBytes:
def seq_to_bytes(seq: 'DNASeq') -> 'DNASeqBytes':
"""Convert generic DNA sequence to byte string representation.
This is for passing sequence data to Cython functions.
Expand All @@ -52,7 +58,7 @@ def seq_to_bytes(seq: DNASeq) -> DNASeqBytes:
raise TypeError(f'Expected sequence type, got {type(seq)}')


def validate_dna_seq_bytes(seq : bytes):
def validate_dna_seq_bytes(seq: DNASeqBytes):
"""Check that a sequence contains only valid nucleotide codes (upper case).
Parameters
Expand Down Expand Up @@ -171,7 +177,7 @@ def absolute(self) -> 'SequenceFile':

@classmethod
def from_paths(cls,
paths: Iterable[FilePath],
paths: Iterable['FilePath'],
format: str,
compression: Optional[str] = None,
) -> list['SequenceFile']:
Expand Down
6 changes: 3 additions & 3 deletions src/gambit/sigs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ class SignaturesFileError(Exception):
filename: str
format: str

def __init__(self, message: str, filename: Optional[FilePath], format: Optional[str]):
def __init__(self, message: str, filename: Optional['FilePath'], format: Optional[str]):
self.message = message
self.filename = str(filename)
self.format = format
Expand All @@ -418,7 +418,7 @@ def __str__(self):
return self.message


def load_signatures(path: FilePath, **kw) -> AbstractSignatureArray:
def load_signatures(path: 'FilePath', **kw) -> AbstractSignatureArray:
"""Load signatures from file.
Currently the only format used to store signatures is the one in :mod:`gambit.sigs.hdf5`, but
Expand All @@ -435,7 +435,7 @@ def load_signatures(path: FilePath, **kw) -> AbstractSignatureArray:
return load_signatures_hdf5(path, **kw)


def dump_signatures(path: FilePath,
def dump_signatures(path: 'FilePath',
signatures: AbstractSignatureArray,
format: str = 'hdf5',
**kw,
Expand Down
4 changes: 2 additions & 2 deletions src/gambit/sigs/calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def default_accumulator(k: int) -> KmerAccumulator:
return SetAccumulator(k) if k > 11 else ArrayAccumulator(k)


def accumulate_kmers(accumulator: KmerAccumulator, kmerspec: KmerSpec, seq: DNASeq):
def accumulate_kmers(accumulator: KmerAccumulator, kmerspec: KmerSpec, seq: 'DNASeq'):
"""Find k-mer matches in sequence and add their indices to an accumulator."""
for match in find_kmers(kmerspec, seq):
try:
Expand All @@ -138,7 +138,7 @@ def accumulate_kmers(accumulator: KmerAccumulator, kmerspec: KmerSpec, seq: DNAS


def calc_signature(kmerspec: KmerSpec,
seqs: Union[DNASeq, Iterable[DNASeq]],
seqs: Union['DNASeq', Iterable['DNASeq']],
*,
accumulator: Optional[KmerAccumulator] = None,
) -> KmerSignature:
Expand Down
4 changes: 2 additions & 2 deletions src/gambit/sigs/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def create(cls,
return cls(group)


def load_signatures_hdf5(path: FilePath, **kw) -> HDF5Signatures:
def load_signatures_hdf5(path: 'FilePath', **kw) -> HDF5Signatures:
"""Open HDF5 signature file.
Parameters
Expand Down Expand Up @@ -254,7 +254,7 @@ def load_signatures_hdf5(path: FilePath, **kw) -> HDF5Signatures:
raise


def dump_signatures_hdf5(path: FilePath, signatures: AbstractSignatureArray, **kw):
def dump_signatures_hdf5(path: 'FilePath', signatures: AbstractSignatureArray, **kw):
"""Write k-mer signatures and associated metadata to an HDF5 file.
Parameters
Expand Down
22 changes: 15 additions & 7 deletions src/gambit/util/io.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
"""Utility code for reading/writing data files."""
"""Utility code for reading/writing data files.
.. class:: FilePath
Alias for types which can represent a file system path (``str`` or :class:`os.PathLike`).
"""

import os
from io import TextIOWrapper
from typing import Union, Optional, IO, TextIO, BinaryIO, ContextManager, Iterable, TypeVar
from contextlib import nullcontext

#: Alias for types which can represent a file system path
FilePath = Union[str, os.PathLike]
from typing_extensions import TypeAlias


FilePath: TypeAlias = Union[str, os.PathLike]

T = TypeVar('T')

Expand Down Expand Up @@ -69,7 +77,7 @@ def guess_compression(fobj: BinaryIO) -> Optional[str]:


def open_compressed(compression: Optional[str],
path: FilePath,
path: 'FilePath',
mode: str = 'rt',
**kwargs,
) -> IO:
Expand Down Expand Up @@ -172,7 +180,7 @@ def __exit__(self, *args):
self.close()


def maybe_open(file_or_path: Union[FilePath, IO], mode: str = 'r', **open_kw) -> ContextManager[IO]:
def maybe_open(file_or_path: Union['FilePath', IO], mode: str = 'r', **open_kw) -> ContextManager[IO]:
"""Open a file given a file path as an argument, but pass existing file objects though.
Intended to be used by API functions that take either type as an argument. If a file path is
Expand Down Expand Up @@ -208,7 +216,7 @@ def maybe_open(file_or_path: Union[FilePath, IO], mode: str = 'r', **open_kw) ->
return open(path, mode, **open_kw)


def read_lines(file_or_path: Union[FilePath, TextIO], strip: bool=True, skip_empty: bool=False) -> Iterable[str]:
def read_lines(file_or_path: Union['FilePath', TextIO], strip: bool=True, skip_empty: bool=False) -> Iterable[str]:
"""Iterate over lines in text file.
Parameters
Expand All @@ -232,7 +240,7 @@ def read_lines(file_or_path: Union[FilePath, TextIO], strip: bool=True, skip_emp
yield line


def write_lines(lines: Iterable, file_or_path: Union[FilePath, TextIO]):
def write_lines(lines: Iterable, file_or_path: Union['FilePath', TextIO]):
"""Write strings to text file, one per line.
Parameters
Expand Down
2 changes: 1 addition & 1 deletion tests/cli/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_strip_seq_file_ext():
class TestGetSequenceFiles:
"""Test the get_sequence_files() function."""

def check_ids(self, ids: Iterable[str], paths: Iterable[FilePath], strip_dir: bool, strip_ext: bool):
def check_ids(self, ids: Iterable[str], paths: Iterable['FilePath'], strip_dir: bool, strip_ext: bool):
for id_, path in zip_strict(ids, paths):
if strip_dir:
expected = Path(path).name
Expand Down
4 changes: 2 additions & 2 deletions tests/cli/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

def make_args(testdb: TestDB, *,
positional_files: Optional[Iterable[SequenceFile]] = None,
list_file: Optional[FilePath] = None,
list_file: Optional['FilePath'] = None,
sig_file: bool = False,
output: Optional[FilePath] = None,
output: Optional['FilePath'] = None,
outfmt: Optional[str] = None,
strict: bool=False,
) -> list[str]:
Expand Down

0 comments on commit f1f718a

Please sign in to comment.