Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #302

Merged
merged 5 commits into from
Jun 18, 2024
Merged

Dev #302

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 59 additions & 32 deletions goatools/godag/go_tasks.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
"""item-DAG tasks."""

__copyright__ = "Copyright (C) 2010-present, DV Klopfenstein, H Tang, All rights reserved."
__copyright__ = (
"Copyright (C) 2010-present, DV Klopfenstein, H Tang, All rights reserved."
)
__author__ = "DV Klopfenstein"

from goatools.godag.consts import RELATIONSHIP_SET
from ..godag.consts import RELATIONSHIP_SET


# ------------------------------------------------------------------------------------
def get_go2parents(go2obj, relationships):
"""Get set of parents GO IDs, including parents through user-specfied relationships"""
if go2obj and not hasattr(next(iter(go2obj.values())), 'relationship') or not relationships:
if (
go2obj
and not hasattr(next(iter(go2obj.values())), "relationship")
or not relationships
):
return get_go2parents_isa(go2obj)
go2parents = {}
for goid_main, goterm in go2obj.items():
Expand All @@ -21,10 +26,14 @@ def get_go2parents(go2obj, relationships):
go2parents[goid_main] = parents_goids
return go2parents

# ------------------------------------------------------------------------------------

def get_go2children(go2obj, relationships):
"""Get set of children GO IDs, including children through user-specfied relationships"""
if go2obj and not hasattr(next(iter(go2obj.values())), 'relationship') or not relationships:
if (
go2obj
and not hasattr(next(iter(go2obj.values())), "relationship")
or not relationships
):
return get_go2children_isa(go2obj)
go2children = {}
for goid_main, goterm in go2obj.items():
Expand All @@ -36,7 +45,7 @@ def get_go2children(go2obj, relationships):
go2children[goid_main] = children_goids
return go2children

# ------------------------------------------------------------------------------------

def get_go2parents_isa(go2obj):
"""Get set of immediate parents GO IDs"""
go2parents = {}
Expand All @@ -46,7 +55,7 @@ def get_go2parents_isa(go2obj):
go2parents[goid_main] = parents_goids
return go2parents

# ------------------------------------------------------------------------------------

def get_go2children_isa(go2obj):
"""Get set of immediate children GO IDs"""
go2children = {}
Expand All @@ -56,84 +65,96 @@ def get_go2children_isa(go2obj):
go2children[goid_main] = children_goids
return go2children

# ------------------------------------------------------------------------------------

def get_go2ancestors(terms, relationships, prt=None):
"""Get GO-to- ancestors (all parents)"""
if not relationships:
if prt is not None:
prt.write('up: is_a\n')
prt.write("up: is_a\n")
return get_id2parents(terms)
if relationships == RELATIONSHIP_SET or relationships is True:
if prt is not None:
prt.write('up: is_a and {Rs}\n'.format(
Rs=' '.join(sorted(RELATIONSHIP_SET))))
prt.write(
"up: is_a and {Rs}\n".format(Rs=" ".join(sorted(RELATIONSHIP_SET)))
)
return get_id2upper(terms)
if prt is not None:
prt.write('up: is_a and {Rs}\n'.format(
Rs=' '.join(sorted(relationships))))
prt.write("up: is_a and {Rs}\n".format(Rs=" ".join(sorted(relationships))))
return get_id2upperselect(terms, relationships)


def get_go2descendants(terms, relationships, prt=None):
"""Get GO-to- descendants"""
if not relationships:
if prt is not None:
prt.write('down: is_a\n')
prt.write("down: is_a\n")
return get_id2children(terms)
if relationships == RELATIONSHIP_SET or relationships is True:
if prt is not None:
prt.write('down: is_a and {Rs}\n'.format(
Rs=' '.join(sorted(RELATIONSHIP_SET))))
prt.write(
"down: is_a and {Rs}\n".format(Rs=" ".join(sorted(RELATIONSHIP_SET)))
)
return get_id2lower(terms)
if prt is not None:
prt.write('down: is_a and {Rs}\n'.format(
Rs=' '.join(sorted(relationships))))
prt.write("down: is_a and {Rs}\n".format(Rs=" ".join(sorted(relationships))))
return get_id2lowerselect(terms, relationships)

# ------------------------------------------------------------------------------------

def get_go2depth(goobjs, relationships):
"""Get depth of each object"""
if not relationships:
return {o.item_id:o.depth for o in goobjs}
return {o.item_id: o.depth for o in goobjs}
from goatools.godag.reldepth import get_go2reldepth

return get_go2reldepth(goobjs, relationships)

# ------------------------------------------------------------------------------------

def get_id2parents(objs):
"""Get all parent IDs up the hierarchy"""
id2parents = {}
for obj in objs:
_get_id2parents(id2parents, obj.item_id, obj)
return {e:es for e, es in id2parents.items() if es}
return {e: es for e, es in id2parents.items() if es}


def get_id2children(objs):
"""Get all child IDs down the hierarchy"""
id2children = {}
for obj in objs:
_get_id2children(id2children, obj.item_id, obj)
return {e:es for e, es in id2children.items() if es}
return {e: es for e, es in id2children.items() if es}


def get_id2upper(objs):
"""Get all ancestor IDs, including all parents and IDs up all relationships"""
id2upper = {}
for obj in objs:
_get_id2upper(id2upper, obj.item_id, obj)
return {e:es for e, es in id2upper.items() if es}
return {e: es for e, es in id2upper.items() if es}


def get_id2lower(objs):
"""Get all descendant IDs, including all children and IDs down all relationships"""
id2lower = {}
cache = set()
for obj in objs:
_get_id2lower(id2lower, obj.item_id, obj)
return {e:es for e, es in id2lower.items() if es}
item_id = obj.item_id
if item_id in cache:
continue
_get_id2lower(id2lower, obj.item_id, obj, cache)
return {e: es for e, es in id2lower.items() if es}


def get_id2upperselect(objs, relationship_set):
"""Get all ancestor IDs, including all parents and IDs up selected relationships"""
return IdToUpperSelect(objs, relationship_set).id2upperselect


def get_id2lowerselect(objs, relationship_set):
"""Get all descendant IDs, including all children and IDs down selected relationships"""
return IdToLowerSelect(objs, relationship_set).id2lowerselect


def get_relationship_targets(item_ids, relationships, id2rec):
"""Get item ID set of item IDs in a relationship target set"""
# Requirements to use this function:
Expand All @@ -148,7 +169,7 @@ def get_relationship_targets(item_ids, relationships, id2rec):
reltgt_objs_all.update(reltgt_objs_cur)
return reltgt_objs_all

# ------------------------------------------------------------------------------------

# pylint: disable=too-few-public-methods
class IdToUpperSelect:
"""Get all ancestor IDs, including all parents and IDs up selected relationships"""
Expand Down Expand Up @@ -178,6 +199,7 @@ def _get_id2upperselect(self, item_id, item_obj):
id2upperselect[item_id] = parent_ids
return parent_ids


class IdToLowerSelect:
"""Get all descendant IDs, including all children and IDs down selected relationships"""

Expand Down Expand Up @@ -206,7 +228,6 @@ def _get_id2lowerselect(self, item_id, item_obj):
id2lowerselect[item_id] = child_ids
return child_ids

# ------------------------------------------------------------------------------------

def _get_id2parents(id2parents, item_id, item_obj):
"""Add the parent item IDs for one item object and their parents."""
Expand All @@ -220,6 +241,7 @@ def _get_id2parents(id2parents, item_id, item_obj):
id2parents[item_id] = parent_ids
return parent_ids


def _get_id2children(id2children, item_id, item_obj):
"""Add the child item IDs for one item object and their children."""
if item_id in id2children:
Expand All @@ -232,6 +254,7 @@ def _get_id2children(id2children, item_id, item_obj):
id2children[item_id] = child_ids
return child_ids


def _get_id2upper(id2upper, item_id, item_obj):
"""Add the parent item IDs for one item object and their upper."""
if item_id in id2upper:
Expand All @@ -244,19 +267,23 @@ def _get_id2upper(id2upper, item_id, item_obj):
id2upper[item_id] = upper_ids
return upper_ids

def _get_id2lower(id2lower, item_id, item_obj):

def _get_id2lower(id2lower, item_id, item_obj, cache: set):
"""Add the lower item IDs for one item object and the objects below them."""
if item_id in id2lower:
return id2lower[item_id]
lower_ids = set()
cache.add(item_id)
for lower_obj in item_obj.get_goterms_lower():
lower_id = lower_obj.item_id
lower_ids.add(lower_id)
lower_ids |= _get_id2lower(id2lower, lower_id, lower_obj)
if lower_id in cache:
continue
lower_ids |= _get_id2lower(id2lower, lower_id, lower_obj, cache)
id2lower[item_id] = lower_ids
return lower_ids

# ------------------------------------------------------------------------------------

class CurNHigher:
"""Fill id2obj with item IDs in relationships."""

Expand Down
36 changes: 27 additions & 9 deletions goatools/nt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
import datetime
import collections as cx


def get_dict_w_id2nts(ids, id2nts, flds, dflt_null=""):
"""Return a new dict of namedtuples by combining "dicts" of namedtuples or objects."""
assert len(ids) == len(set(ids)), "NOT ALL IDs ARE UNIQUE: {IDs}".format(IDs=ids)
assert len(flds) == len(set(flds)), "DUPLICATE FIELDS: {IDs}".format(
IDs=cx.Counter(flds).most_common())
IDs=cx.Counter(flds).most_common()
)
usr_id_nt = []
# 1. Instantiate namedtuple object
ntobj = cx.namedtuple("Nt", " ".join(flds))
Expand All @@ -23,6 +25,7 @@ def get_dict_w_id2nts(ids, id2nts, flds, dflt_null=""):
usr_id_nt.append((item_id, ntobj._make(vals)))
return cx.OrderedDict(usr_id_nt)


def get_list_w_id2nts(ids, id2nts, flds, dflt_null=""):
"""Return a new list of namedtuples by combining "dicts" of namedtuples or objects."""
combined_nt_list = []
Expand All @@ -36,48 +39,61 @@ def get_list_w_id2nts(ids, id2nts, flds, dflt_null=""):
combined_nt_list.append(ntobj._make(vals))
return combined_nt_list


def combine_nt_lists(lists, flds, dflt_null=""):
"""Return a new list of namedtuples by zipping "lists" of namedtuples or objects."""
combined_nt_list = []
# Check that all lists are the same length
lens = [len(lst) for lst in lists]
assert len(set(lens)) == 1, \
"LIST LENGTHS MUST BE EQUAL: {Ls}".format(Ls=" ".join(str(l) for l in lens))
assert len(set(lens)) == 1, "LIST LENGTHS MUST BE EQUAL: {Ls}".format(
Ls=" ".join(str(l) for l in lens)
)
# 1. Instantiate namedtuple object
ntobj = cx.namedtuple("Nt", " ".join(flds))
# 2. Loop through zipped list
for lst0_lstn in zip(*lists):
# 2a. Combine various namedtuples into a single namedtuple
combined_nt_list.append(ntobj._make(_combine_nt_vals(lst0_lstn, flds, dflt_null)))
combined_nt_list.append(
ntobj._make(_combine_nt_vals(lst0_lstn, flds, dflt_null))
)
return combined_nt_list


def wr_py_nts(fout_py, nts, docstring=None, varname="nts"):
"""Save namedtuples into a Python module."""
if nts:
with open(fout_py, 'w') as prt:
with open(fout_py, "w") as prt:
prt.write('"""{DOCSTRING}"""\n\n'.format(DOCSTRING=docstring))
prt.write("# Created: {DATE}\n".format(DATE=str(datetime.date.today())))
prt_nts(prt, nts, varname)
sys.stdout.write(" {N:7,} items WROTE: {PY}\n".format(N=len(nts), PY=fout_py))
sys.stdout.write(
" {N:7,} items WROTE: {PY}\n".format(N=len(nts), PY=fout_py)
)

def prt_nts(prt, nts, varname, spc=' '):

def prt_nts(prt, nts, varname, spc=" "):
"""Print namedtuples into a Python module."""
first_nt = nts[0]
nt_name = type(first_nt).__name__
prt.write("import collections as cx\n\n")
prt.write("import numpy as np\n\n")
prt.write("NT_FIELDS = [\n")
for fld in first_nt._fields:
prt.write('{SPC}"{F}",\n'.format(SPC=spc, F=fld))
prt.write("]\n\n")
prt.write('{NtName} = cx.namedtuple("{NtName}", " ".join(NT_FIELDS))\n\n'.format(
NtName=nt_name))
prt.write(
'{NtName} = cx.namedtuple("{NtName}", " ".join(NT_FIELDS))\n\n'.format(
NtName=nt_name
)
)
prt.write("# {N:,} items\n".format(N=len(nts)))
prt.write("# pylint: disable=line-too-long\n")
prt.write("{VARNAME} = [\n".format(VARNAME=varname))
for ntup in nts:
prt.write("{SPC}{NT},\n".format(SPC=spc, NT=ntup))
prt.write("]\n")


def get_unique_fields(fld_lists):
"""Get unique namedtuple fields, despite potential duplicates in lists of fields."""
flds = []
Expand All @@ -93,6 +109,7 @@ def get_unique_fields(fld_lists):
assert len(flds) == len(fld_set)
return flds


# -- Internal methods ----------------------------------------------------------------
def _combine_nt_vals(lst0_lstn, flds, dflt_null):
"""Given a list of lists of nts, return a single namedtuple."""
Expand All @@ -110,4 +127,5 @@ def _combine_nt_vals(lst0_lstn, flds, dflt_null):
vals.append(dflt_null)
return vals


# Copyright (C) 2016-2018, DV Klopfenstein, H Tang. All rights reserved.
3 changes: 3 additions & 0 deletions tests/test_dcnt_r01.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import sys
import timeit
import numpy as np
import pytest

from numpy.random import shuffle
from scipy import stats

Expand All @@ -14,6 +16,7 @@
from goatools.obo_parser import GODag


@pytest.mark.skip(reason="Latest obo (`releases/2024-06-10`) is not DAG")
def test_go_pools():
"""Print a comparison of GO terms from different species in two different comparisons."""
objr = _Run()
Expand Down
Loading