Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: iteratively prune branches #272

Merged
merged 1 commit into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
@email: [email protected]

Module that removes doubled fragments and trims branches that pass by each
other from a NeuroGraph.
other from a FragmentsGraph.

"""
from collections import defaultdict
Expand All @@ -21,57 +21,70 @@
QUERY_DIST = 15


# --- Curvy Removal ---
def remove_curvy(graph, max_length, ratio=0.5):
deleted_ids = set()
components = [c for c in connected_components(graph) if len(c) == 2]
for nodes in tqdm(components, desc="Curvy Filter:"):
if len(nodes) == 2:
i, j = tuple(nodes)
length = graph.edges[i, j]["length"]
endpoint_dist = graph.dist(i, j)
if endpoint_dist / length < ratio and length < max_length:
deleted_ids.add(graph.edges[i, j]["swc_id"])
delete_fragment(graph, i, j)
return len(deleted_ids)


# --- Doubles Removal ---
def remove_doubles(neurograph, max_size, node_spacing, output_dir=None):
def remove_doubles(graph, max_length, node_spacing, output_dir=None):
"""
Removes connected components from "neurgraph" that are likely to be a
double.

Parameters
----------
neurograph : NeuroGraph
graph : FragmentsGraph
Graph to be searched for doubles.
max_size : int
max_length : int
Maximum size of connected components to be searched.
node_spacing : int
Expected distance in microns between nodes in "neurograph".
Expected distance in microns between nodes in "graph".
output_dir : str or None, optional
Directory that doubles will be written to. The default is None.

Returns
-------
NeuroGraph
graph
Graph with doubles removed.

"""
# Initializations
components = [c for c in connected_components(neurograph) if len(c) == 2]
deleted = set()
kdtree = neurograph.get_kdtree()
components = [c for c in connected_components(graph) if len(c) == 2]
deleted_ids = set()
kdtree = graph.get_kdtree()
if output_dir:
util.mkdir(output_dir, delete=True)

# Main
desc = "Doubles Detection"
desc = "Doubles Filtering"
for idx in tqdm(np.argsort([len(c) for c in components]), desc=desc):
i, j = tuple(components[idx])
swc_id = neurograph.nodes[i]["swc_id"]
if swc_id not in deleted:
if len(neurograph.edges[i, j]["xyz"]) * node_spacing < max_size:
swc_id = graph.nodes[i]["swc_id"]
if swc_id not in deleted_ids:
if graph.edges[i, j]["length"] < max_length:
# Check doubles criteria
n_points = len(neurograph.edges[i, j]["xyz"])
hits = compute_projections(neurograph, kdtree, (i, j))
n_points = len(graph.edges[i, j]["xyz"])
hits = compute_projections(graph, kdtree, (i, j))
if check_doubles_criteria(hits, n_points):
if output_dir:
neurograph.to_swc(
output_dir, components[idx], color=COLOR
)
neurograph = delete(neurograph, components[idx], swc_id)
deleted.add(swc_id)
return len(deleted)
graph.to_swc(output_dir, [i, j], color=COLOR)
delete_fragment(graph, i, j)
deleted_ids.add(swc_id)
return len(deleted_ids)


def compute_projections(neurograph, kdtree, edge):
def compute_projections(graph, kdtree, edge):
"""
Given a fragment defined by "edge", this routine iterates of every xyz in
the fragment and projects it onto the closest fragment. For each detected
Expand All @@ -80,11 +93,11 @@ def compute_projections(neurograph, kdtree, edge):

Parameters
----------
neurograph : NeuroGraph
graph : graph
Graph that contains "edge".
kdtree : KDTree
KD-Tree that contains all xyz coordinates of every fragment in
"neurograph".
"graph".
edge : tuple
Pair of leaf nodes that define a fragment.

Expand All @@ -96,13 +109,13 @@ def compute_projections(neurograph, kdtree, edge):

"""
hits = defaultdict(list)
query_id = neurograph.edges[edge]["swc_id"]
for i, xyz in enumerate(neurograph.edges[edge]["xyz"]):
query_id = graph.edges[edge]["swc_id"]
for i, xyz in enumerate(graph.edges[edge]["xyz"]):
# Compute projections
best_id = None
best_dist = np.inf
for hit_xyz in geometry.query_ball(kdtree, xyz, QUERY_DIST):
hit_id = neurograph.xyz_to_swc(hit_xyz)
hit_id = graph.xyz_to_swc(hit_xyz)
if hit_id is not None and hit_id != query_id:
if geometry.dist(hit_xyz, xyz) < best_dist:
best_dist = geometry.dist(hit_xyz, xyz)
Expand Down Expand Up @@ -144,56 +157,54 @@ def check_doubles_criteria(hits, n_points):
return False


def delete(neurograph, nodes, swc_id):
def delete_fragment(graph, i, j):
"""
Deletes "nodes" from "neurograph".
Deletes nodes "i" and "j" from "graph", where these nodes form a connected
component.

Parameters
----------
neurograph : NeuroGraph
Graph that contains "nodes".
nodes : list[int]
Nodes to be removed.
swc_id : str
swc id corresponding to nodes which comprise a connected component in
"neurograph".
graph : FragmentsGraph
Graph that contains nodes to be deleted.
i : int
Node to be removed.
j : int
Node to be removed.

Returns
-------
NeuroGraph
graph
Graph with nodes removed.

"""
i, j = tuple(nodes)
neurograph = remove_xyz_entries(neurograph, i, j)
neurograph.remove_nodes_from([i, j])
neurograph.swc_ids.remove(swc_id)
return neurograph
graph = remove_xyz_entries(graph, i, j)
graph.swc_ids.remove(graph.nodes[i]["swc_id"])
graph.remove_nodes_from([i, j])


def remove_xyz_entries(neurograph, i, j):
def remove_xyz_entries(graph, i, j):
"""
Removes dictionary entries from "neurograph.xyz_to_edge" corresponding to
Removes dictionary entries from "graph.xyz_to_edge" corresponding to
the edge {i, j}.

Parameters
----------
neurograph : NeuroGraph
graph : graph
Graph to be updated.
i : int
Node in "neurograph".
Node in "graph".
j : int
Node in "neurograph".
Node in "graph".

Returns
-------
NeuroGraph
graph
Updated graph.

"""
for xyz in neurograph.edges[i, j]["xyz"]:
del neurograph.xyz_to_edge[tuple(xyz)]
return neurograph
for xyz in graph.edges[i, j]["xyz"]:
del graph.xyz_to_edge[tuple(xyz)]
return graph


def upd_hits(hits, key, value):
Expand Down
Loading
Loading