Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
rzlim08 committed Apr 25, 2024
1 parent ac8293c commit fc068dc
Show file tree
Hide file tree
Showing 6 changed files with 404 additions and 193 deletions.
136 changes: 90 additions & 46 deletions lib/dagless/dagless/m8/call_hits.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,39 +7,45 @@
from ..utils.parsing import BlastnOutput6Reader, BlastnOutput6Writer, HitSummaryWriter

log = logging.getLogger(__name__)
log.log_context = lambda x, y: print(x, y) # monkey patch log function
log.write = lambda x: print(x) # monkey patch log function
log.log_context = lambda x, y: print(x, y) # monkey patch log function
log.write = lambda x: print(x) # monkey patch log function

# mock lineage


# Local stand-in ("mock") for a lineage module: it only carries the null-taxid
# sentinel constants that the code in this file refers to.
class lineage:
# Sentinel taxids (stored as strings), one each for species, genus and family.
NULL_SPECIES_ID = "-100"
NULL_GENUS_ID = "-200"
NULL_FAMILY_ID = "-300"
# (species, genus, family) tuple of the sentinels above; get_lineage falls back
# to it when an accession's taxid has no entry in lineage_map.
NULL_LINEAGE = (NULL_SPECIES_ID, NULL_GENUS_ID, NULL_FAMILY_ID)


################
### CP PASTE ###
################
def build_should_keep_filter(
deuterostome_path,
taxon_whitelist_path,
taxon_blacklist_path
deuterostome_path, taxon_whitelist_path, taxon_blacklist_path
):

# See also HOMO_SAPIENS_TAX_IDS in idseq-web
taxids_to_remove = set(['9605', '9606'])
taxids_to_remove = set(["9605", "9606"])

if taxon_blacklist_path:
with log.log_context("generate_taxon_count_json_from_m8", {"substep": "read_blacklist_into_set"}):
with log.log_context(
"generate_taxon_count_json_from_m8", {"substep": "read_blacklist_into_set"}
):
taxids_to_remove.update(read_file_into_set(taxon_blacklist_path))

if deuterostome_path:
with log.log_context("generate_taxon_count_json_from_m8", {"substep": "read_file_into_set"}):
with log.log_context(
"generate_taxon_count_json_from_m8", {"substep": "read_file_into_set"}
):
taxids_to_remove.update(read_file_into_set(deuterostome_path))

if taxon_whitelist_path:
with log.log_context("generate_taxon_count_json_from_m8", {"substep": "read_whitelist_into_set"}):
with log.log_context(
"generate_taxon_count_json_from_m8", {"substep": "read_whitelist_into_set"}
):
taxids_to_keep = read_file_into_set(taxon_whitelist_path)

def is_blacklisted(hits: Iterable[str]):
Expand All @@ -66,13 +72,22 @@ def should_keep(hits: Iterable[str]):
return should_keep


def _call_hits_m8_work(input_blastn_6_path, lineage_map, accession2taxid_dict,
output_blastn_6_path, output_summary, min_alignment_length,
deuterostome_path, taxon_whitelist_path, taxon_blacklist_path):
def _call_hits_m8_work(
input_blastn_6_path,
lineage_map,
accession2taxid_dict,
output_blastn_6_path,
output_summary,
min_alignment_length,
deuterostome_path,
taxon_whitelist_path,
taxon_blacklist_path,
):
lineage_cache = {}

should_keep = build_should_keep_filter(
deuterostome_path, taxon_whitelist_path, taxon_blacklist_path)
deuterostome_path, taxon_whitelist_path, taxon_blacklist_path
)

# Helper functions
def get_lineage(accession_id):
Expand All @@ -82,8 +97,7 @@ def get_lineage(accession_id):
"""
if accession_id in lineage_cache:
return lineage_cache[accession_id]
accession_taxid = accession2taxid_dict.get(
accession_id.split(".")[0], "NA")
accession_taxid = accession2taxid_dict.get(accession_id.split(".")[0], "NA")
result = lineage_map.get(accession_taxid, lineage.NULL_LINEAGE)
lineage_cache[accession_id] = result
return result
Expand All @@ -100,8 +114,7 @@ def accumulate(hits, accession_id):
# provided by other accessions. This occurs a lot and
# handling it in this way seems to work well.
continue
accession_list = hits[level].get(
taxid_at_level, []) + [accession_id]
accession_list = hits[level].get(taxid_at_level, []) + [accession_id]
hits[level][taxid_at_level] = accession_list

def most_frequent_accession(accession_list):
Expand All @@ -117,7 +130,7 @@ def most_frequent_accession(accession_list):
randgen = random.Random(x=4) # chosen by fair dice roll, guaranteed to be random

def call_hit_level_v2(hits):
''' Always call hit at the species level with the taxid with most matches '''
""" Always call hit at the species level with the taxid with most matches """
species_level_hits = hits[0]
max_match = 0
taxid_candidates = []
Expand All @@ -132,8 +145,7 @@ def call_hit_level_v2(hits):
selected_taxid = taxid_candidates[0]
if len(taxid_candidates) > 1:
selected_taxid = randgen.sample(taxid_candidates, 1)[0]
accession_id = most_frequent_accession(
species_level_hits[selected_taxid])
accession_id = most_frequent_accession(species_level_hits[selected_taxid])
return 1, selected_taxid, accession_id
return -1, "-1", None

Expand All @@ -143,16 +155,26 @@ def call_hit_level_v2(hits):
LOG_INCREMENT = 50000
log.write(f"Starting to summarize hits from {input_blastn_6_path}.")
with open(input_blastn_6_path) as input_blastn_6_f:
for row in BlastnOutput6Reader(input_blastn_6_f, filter_invalid=True, min_alignment_length=min_alignment_length):
read_id, accession_id, bitscore = row["qseqid"], row["sseqid"], row["bitscore"]
for row in BlastnOutput6Reader(
input_blastn_6_f,
filter_invalid=True,
min_alignment_length=min_alignment_length,
):
read_id, accession_id, bitscore = (
row["qseqid"],
row["sseqid"],
row["bitscore"],
)
# The Expect value (E) is a parameter that describes the number of
# hits one can 'expect' to see by chance when searching a database of
# a particular size. It decreases exponentially as the Score (S) of
# the match increases. Essentially, the E value describes the random
# background noise. https://blast.ncbi.nlm.nih.gov/Blast.cgi?CMD=Web
# &PAGE_TYPE=BlastDocs&DOC_TYPE=FAQ
# We have since moved to using the bitscore rather than the e-value
my_best_bitscore, hits, _ = summary.get(read_id, (float("-inf"), [{}, {}, {}], None))
my_best_bitscore, hits, _ = summary.get(
read_id, (float("-inf"), [{}, {}, {}], None)
)
if my_best_bitscore < bitscore:
# If we find a new better bitscore we want to start accumulation over
hits = [{}, {}, {}]
Expand All @@ -164,14 +186,18 @@ def call_hit_level_v2(hits):
summary[read_id] = my_best_bitscore, hits, call_hit_level_v2(hits)
count += 1
if count % LOG_INCREMENT == 0:
log.write(f"Summarized hits for {count} read ids from {input_blastn_6_path}, and counting.")
log.write(
f"Summarized hits for {count} read ids from {input_blastn_6_path}, and counting."
)

log.write(f"Summarized hits for all {count} read ids from {input_blastn_6_path}.")

# Generate output files. blastn_6_out_f is the main output m8 file and
# hit_summary_out_f is the summary level info.
emitted = set()
with open(output_blastn_6_path, "w") as blastn_6_out_f, open(output_summary, "w") as hit_summary_out_f, open(input_blastn_6_path) as input_blastn_6_f:
with open(output_blastn_6_path, "w") as blastn_6_out_f, open(
output_summary, "w"
) as hit_summary_out_f, open(input_blastn_6_path) as input_blastn_6_f:
blastn_6_writer = BlastnOutput6Writer(blastn_6_out_f)
hit_summary_writer = HitSummaryWriter(hit_summary_out_f)
# Iterate over the lines of the m8 file. Emit the hit with the
Expand All @@ -188,15 +214,26 @@ def call_hit_level_v2(hits):
# TODO: Consider all hits within a fixed margin of the best e-value.
# This change may need to be accompanied by a change to
# GSNAP/RAPSearch2 parameters.
for row in BlastnOutput6Reader(input_blastn_6_f, filter_invalid=True, min_alignment_length=min_alignment_length):
read_id, accession_id, bitscore = row["qseqid"], row["sseqid"], row["bitscore"]
for row in BlastnOutput6Reader(
input_blastn_6_f,
filter_invalid=True,
min_alignment_length=min_alignment_length,
):
read_id, accession_id, bitscore = (
row["qseqid"],
row["sseqid"],
row["bitscore"],
)
if read_id in emitted:
continue

# Read the fields from the summary level info
best_bitscore, _, (hit_level, taxid,
best_accession_id) = summary[read_id]
if best_bitscore == bitscore and best_accession_id in (None, accession_id) and should_keep([taxid]):
best_bitscore, _, (hit_level, taxid, best_accession_id) = summary[read_id]
if (
best_bitscore == bitscore
and best_accession_id in (None, accession_id)
and should_keep([taxid])
):
# Read out the hit with the best value that provides the
# most specific taxonomy information.
emitted.add(read_id)
Expand All @@ -206,17 +243,21 @@ def call_hit_level_v2(hits):
family_taxid = -1
if best_accession_id != None:
(species_taxid, genus_taxid, family_taxid) = get_lineage(
best_accession_id)

hit_summary_writer.writerow({
"read_id": read_id,
"level": hit_level,
"taxid": taxid,
"accession_id": best_accession_id,
"species_taxid": species_taxid,
"genus_taxid": genus_taxid,
"family_taxid": family_taxid,
})
best_accession_id
)

hit_summary_writer.writerow(
{
"read_id": read_id,
"level": hit_level,
"taxid": taxid,
"accession_id": best_accession_id,
"species_taxid": species_taxid,
"genus_taxid": genus_taxid,
"family_taxid": family_taxid,
}
)


##############
### END CP ###
Expand All @@ -231,7 +272,7 @@ def call_hits_m8(
deuterostome_path,
taxon_whitelist_path,
taxon_blacklist_path,
output_prefix
output_prefix,
):
with open_file_db_by_extension(
lineage_map_path, "lll"
Expand All @@ -258,13 +299,16 @@ def main():
description="Given an m8 file, call hits for each query"
)
parser.add_argument("-f", "--file", help="Path to the input m8 file", required=True)
parser.add_argument("-l", "--lineage", help="Path to the lineage file", required=True)
parser.add_argument(
"-a", "--accession2taxid", help="Path to the accession2taxid file", required=True
"-l", "--lineage", help="Path to the lineage file", required=True
)
parser.add_argument(
"-o", "--output", help="Output prefix", required=True
"-a",
"--accession2taxid",
help="Path to the accession2taxid file",
required=True,
)
parser.add_argument("-o", "--output", help="Output prefix", required=True)

parser.add_argument(
"--min-alignment-length",
Expand Down Expand Up @@ -292,7 +336,7 @@ def main():
args.deuterostome_path,
args.taxon_whitelist_path,
args.taxon_blacklist_path,
args.output
args.output,
)


Expand Down
Loading

0 comments on commit fc068dc

Please sign in to comment.