-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: nearly finished merge deletion (#85)
Co-authored-by: anna-grim <[email protected]>
- Loading branch information
Showing
5 changed files
with
279 additions
and
138 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
""" | ||
Created on Sat March 26 17:30:00 2024 | ||
@author: Anna Grim | ||
@email: [email protected] | ||
Deletes merges from predicted swc files in the case when there are ground | ||
truth swc files. | ||
""" | ||
|
||
import networkx as nx | ||
import numpy as np | ||
from deep_neurographs.densegraph import DenseGraph | ||
from deep_neurographs import geometry, utils | ||
|
||
CLOSE_THRESHOLD = 3.5 | ||
DELETION_RADIUS = 5 | ||
MERGE_DIST_THRESHOLD = 30 | ||
MIN_INTERSECTION = 10 | ||
|
||
|
||
def delete_merges( | ||
target_swc_paths, | ||
pred_swc_paths, | ||
output_dir, | ||
deletion_radius=DELETION_RADIUS, | ||
): | ||
""" | ||
Deletes merges from predicted swc files in the case when there are ground | ||
truth swc files. | ||
Parameters | ||
---------- | ||
target_swc_paths : list[str] | ||
List of paths to ground truth swc files. | ||
pred_swc_paths : list[str] | ||
List of paths to predicted swc files. | ||
output_dir : str | ||
Directory that updated graphs will be written to. | ||
deletion_radius : int, optional | ||
Each node within "deletion_radius" is deleted. The default is the | ||
global variable "DELETION_RADIUS". | ||
Returns | ||
------- | ||
None | ||
""" | ||
target_densegraph = DenseGraph(target_swc_paths) | ||
pred_densegraph = DenseGraph(pred_swc_paths) | ||
for swc_id in pred_densegraph.graphs.keys(): | ||
# Detection | ||
pred_graph = pred_densegraph.graphs[swc_id] | ||
merged_nodes = detect_merge(target_densegraph, pred_graph) | ||
|
||
# Deletion | ||
if len(merged_nodes.keys()) > 0: | ||
visited = set() | ||
delete_nodes = set() | ||
for key_1 in merged_nodes.keys(): | ||
for key_2 in merged_nodes.keys(): | ||
pair = frozenset((key_1, key_2)) | ||
if key_1 != key_2 and pair not in visited: | ||
sites, d = locate_site( | ||
pred_graph, merged_nodes[key_1], merged_nodes[key_2] | ||
) | ||
if d < MERGE_DIST_THRESHOLD: | ||
print(sites, d) | ||
# delete just like a connector | ||
|
||
pred_densegraph.graphs[swc_id] = pred_graph | ||
|
||
# Save | ||
pred_densegraph.save(output_dir) | ||
|
||
|
||
def detect_merge(target_densegraph, pred_graph): | ||
""" | ||
Determines whether the "pred_graph" contains a merge mistake. This routine | ||
projects each node in "pred_graph" onto "target_neurograph", then computes | ||
the projection distance. ... | ||
Parameters | ||
---------- | ||
target_densegraph : DenseGraph | ||
Graph built from ground truth swc files. | ||
pred_graph : networkx.Graph | ||
Graph build from a predicted swc file. | ||
Returns | ||
------- | ||
set | ||
Set of nodes that are part of a merge mistake. | ||
""" | ||
# Compute projections | ||
hits = dict() | ||
for i in pred_graph.nodes: | ||
xyz = tuple(pred_graph.nodes[i]["xyz"]) | ||
hat_xyz = target_densegraph.get_projection(xyz) | ||
hat_swc_id = target_densegraph.xyz_to_swc[hat_xyz] | ||
if geometry.dist(hat_xyz, xyz) < CLOSE_THRESHOLD: | ||
hits = utils.append_dict_value(hits, hat_swc_id, i) | ||
|
||
# Remove spurious intersections | ||
keys = [key for key in hits.keys() if len(hits[key]) < MIN_INTERSECTION] | ||
return utils.remove_items(hits, keys) | ||
|
||
|
||
def locate_site(graph, merged_1, merged_2): | ||
min_dist = np.inf | ||
node_pair = [None, None] | ||
for i in merged_1: | ||
for j in merged_2: | ||
xyz_i = graph.nodes[i]["xyz"] | ||
xyz_j = graph.nodes[j]["xyz"] | ||
if geometry.dist(xyz_i, xyz_j) < min_dist: | ||
min_dist = geometry.dist(xyz_i, xyz_j) | ||
node_pair = [i, j] | ||
return node_pair, min_dist | ||
|
||
|
||
def delete_merge(graph, root, radius): | ||
delete_nodes = get_nearby_nodes(graph, root, radius) | ||
graph.remove_nodes_from(delete_nodes) | ||
return graph | ||
|
||
|
||
def get_nearby_nodes(graph, root, radius): | ||
nearby_nodes = set() | ||
for _, j in nx.dfs_edges(graph, source=root, depth_limit=radius): | ||
nearby_nodes.add(j) | ||
return nearby_nodes |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,19 +4,21 @@ | |
@author: Anna Grim | ||
@email: [email protected] | ||
Class of graphs built from swc files. Each swc file is stored as a distinct | ||
graph and each node in this graph. | ||
Class of graphs built from swc files where each entry in the swc file | ||
corresponds to a node in the graph. | ||
""" | ||
|
||
import os | ||
from random import sample | ||
from scipy.spatial import KDTree | ||
from time import time | ||
|
||
import networkx as nx | ||
|
||
from deep_neurographs import swc_utils, utils | ||
|
||
DELETION_RADIUS = 10 | ||
|
||
class DenseGraph: | ||
""" | ||
|
@@ -25,7 +27,7 @@ class DenseGraph: | |
""" | ||
|
||
def __init__(self, swc_paths, image_patch_origin, image_patch_shape): | ||
def __init__(self, swc_paths): | ||
""" | ||
Constructs a DenseGraph object from a directory of swc files. | ||
|
@@ -40,101 +42,84 @@ def __init__(self, swc_paths, image_patch_origin, image_patch_shape): | |
None | ||
""" | ||
self.init_graph(swc_paths) | ||
self.origin = image_patch_origin | ||
self.shape = image_patch_shape | ||
self.init_graphs(swc_paths) | ||
self.init_kdtree() | ||
|
||
def init_graph(self, paths): | ||
def init_graphs(self, paths): | ||
""" | ||
Initializes graphs by reading swc files in "swc_paths". Graphs are | ||
Initializes graphs by reading swc files in "paths". Graphs are | ||
stored in a hash table where the entries are filename-graph pairs. | ||
Parameters | ||
---------- | ||
swc_paths : list[str] | ||
List of paths to swc files which are used to construct a hash | ||
table in which the entries are filename-graph pairs. | ||
paths : list[str] | ||
List of paths to swc files that are used to construct a dictionary | ||
in which the items are filename-graph pairs. | ||
Returns | ||
------- | ||
None | ||
""" | ||
# Initializations | ||
print("Building graph...") | ||
self.graph = nx.Graph() | ||
self.graphs = dict() | ||
self.xyz_to_swc = dict() | ||
swc_dicts, _ = swc_utils.process_local_paths(paths) | ||
|
||
# Run | ||
cnt = 1 | ||
t0, t1 = utils.init_timers() | ||
chunk_size = max(int(len(swc_dicts) * 0.02), 1) | ||
for i, swc_dict in enumerate(swc_dicts): | ||
# Construct Graph | ||
swc_id = swc_dict["swc_id"] | ||
graph, _ = swc_utils.to_graph(swc_dict, set_attrs=True) | ||
graph = add_swc_id(graph, swc_id) | ||
self.graph = nx.disjoint_union(self.graph, graph) | ||
|
||
# Report progress | ||
if i > cnt * chunk_size: | ||
cnt, t1 = report_progress( | ||
i, len(swc_dicts), chunk_size, cnt, t0, t1 | ||
) | ||
|
||
def trim(self): | ||
pass | ||
|
||
def save(self, path): | ||
for i, component in enumerate(nx.connected_components(self.graph)): | ||
node = sample(component, 1)[0] | ||
swc_id = self.graph.nodes[node]["swc_id"] | ||
component_path = os.path.join(path, f"{swc_id}.swc") | ||
self.component_to_swc(component_path, component) | ||
|
||
def component_to_swc(self, path, component): | ||
self.store_xyz(graph, swc_id) | ||
self.graphs[swc_id] = graph | ||
|
||
def store_xyz(self, graph, swc_id): | ||
for i in graph.nodes: | ||
self.xyz_to_swc[tuple(graph.nodes[i]["xyz"])] = swc_id | ||
|
||
def init_kdtree(self): | ||
""" | ||
Builds a KD-Tree from the xyz coordinates from every node stored in | ||
self.graphs. | ||
Parameters | ||
---------- | ||
None | ||
Returns | ||
------- | ||
None | ||
""" | ||
self.kdtree = KDTree(list(self.xyz_to_swc.keys())) | ||
|
||
def get_projection(self, xyz): | ||
_, idx = self.kdtree.query(xyz, k=1) | ||
return tuple(self.kdtree.data[idx]) | ||
|
||
def save(self, output_dir): | ||
for swc_id, graph in self.graphs.items(): | ||
cnt = 0 | ||
for component in nx.connected_components(graph): | ||
entry_list = self.make_entries(graph, component) | ||
path = os.path.join(output_dir, f"{swc_id}.swc") | ||
while os.path.exists(path): | ||
path = os.path.join(output_dir, f"{swc_id}.{cnt}.swc") | ||
cnt += 1 | ||
swc_utils.write(path, entry_list) | ||
|
||
def make_entries(self, graph, component): | ||
node_to_idx = dict() | ||
entry_list = [] | ||
for i, j in nx.dfs_edges(self.graph.subgraph(component)): | ||
for i, j in nx.dfs_edges(graph.subgraph(component)): | ||
# Initialize | ||
if len(entry_list) == 0: | ||
x, y, z = tuple(self.graph.nodes[i]["xyz"]) | ||
r = self.graph.nodes[i]["radius"] | ||
entry_list.append([1, 2, x, y, z, r, -1]) | ||
node_to_idx[i] = 1 | ||
x, y, z = tuple(graph.nodes[i]["xyz"]) | ||
r = graph.nodes[i]["radius"] | ||
entry_list.append([1, 2, x, y, z, r, -1]) | ||
|
||
# Create entry | ||
node_to_idx[j] = len(entry_list) + 1 | ||
x, y, z = tuple(self.graph.nodes[j]["xyz"]) | ||
r = self.graph.nodes[j]["radius"] | ||
x, y, z = tuple(graph.nodes[j]["xyz"]) | ||
r = graph.nodes[j]["radius"] | ||
entry_list.append([node_to_idx[j], 2, x, y, z, r, node_to_idx[i]]) | ||
|
||
swc_utils.write(path, entry_list) | ||
|
||
|
||
def add_swc_id(graph, swc_id): | ||
for i in graph.nodes: | ||
graph.nodes[i]["swc_id"] = swc_id | ||
return graph | ||
|
||
|
||
def report_progress(current, total, chunk_size, cnt, t0, t1): | ||
eta = get_eta(current, total, chunk_size, t1) | ||
runtime = get_runtime(current, total, chunk_size, t0, t1) | ||
utils.progress_bar(current, total, eta=eta, runtime=runtime) | ||
return cnt + 1, time() | ||
|
||
|
||
def get_eta(current, total, chunk_size, t0, return_str=True): | ||
chunk_runtime = time() - t0 | ||
remaining = total - current | ||
eta = remaining * (chunk_runtime / chunk_size) | ||
t, unit = utils.time_writer(eta) | ||
return f"{round(t, 4)} {unit}" if return_str else eta | ||
|
||
|
||
def get_runtime(current, total, chunk_size, t0, t1): | ||
eta = get_eta(current, total, chunk_size, t1, return_str=False) | ||
total_runtime = time() - t0 + eta | ||
t, unit = utils.time_writer(total_runtime) | ||
return f"{round(t, 4)} {unit}" | ||
return entry_list | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.