Skip to content

Commit

Permalink
feat: nearly finished merge deletion (#85)
Browse files Browse the repository at this point in the history
Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Mar 27, 2024
1 parent aae3041 commit 06ad479
Show file tree
Hide file tree
Showing 5 changed files with 279 additions and 138 deletions.
134 changes: 134 additions & 0 deletions src/deep_neurographs/delete_merges_gt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
Created on Sat March 26 17:30:00 2024
@author: Anna Grim
@email: [email protected]
Deletes merges from predicted swc files in the case when there are ground
truth swc files.
"""

from itertools import combinations

import networkx as nx
import numpy as np

from deep_neurographs import geometry, utils
from deep_neurographs.densegraph import DenseGraph

CLOSE_THRESHOLD = 3.5
DELETION_RADIUS = 5
MERGE_DIST_THRESHOLD = 30
MIN_INTERSECTION = 10


def delete_merges(
    target_swc_paths,
    pred_swc_paths,
    output_dir,
    deletion_radius=DELETION_RADIUS,
):
    """
    Detects merges in predicted swc files by comparing them against ground
    truth swc files, reports the located merge sites, and writes the
    predicted graphs to "output_dir".

    Parameters
    ----------
    target_swc_paths : list[str]
        List of paths to ground truth swc files.
    pred_swc_paths : list[str]
        List of paths to predicted swc files.
    output_dir : str
        Directory that updated graphs will be written to.
    deletion_radius : int, optional
        Radius intended for deleting nodes around a merge site. It is
        currently not used because the deletion step is not implemented yet
        (see the TODO below). The default is the global variable
        "DELETION_RADIUS".

    Returns
    -------
    None

    """
    target_densegraph = DenseGraph(target_swc_paths)
    pred_densegraph = DenseGraph(pred_swc_paths)
    for swc_id in pred_densegraph.graphs.keys():
        # Detection
        pred_graph = pred_densegraph.graphs[swc_id]
        merged_nodes = detect_merge(target_densegraph, pred_graph)

        # Deletion
        # "combinations" yields each unordered pair of target swc ids exactly
        # once, so a merge site is never located (and reported) twice. It
        # yields nothing when fewer than two target swc ids were hit.
        for key_1, key_2 in combinations(merged_nodes, 2):
            sites, d = locate_site(
                pred_graph, merged_nodes[key_1], merged_nodes[key_2]
            )
            if d < MERGE_DIST_THRESHOLD:
                print(sites, d)
                # TODO: delete just like a connector

        pred_densegraph.graphs[swc_id] = pred_graph

    # Save
    pred_densegraph.save(output_dir)


def detect_merge(target_densegraph, pred_graph):
    """
    Finds the ground truth swc ids that "pred_graph" runs close to. Each node
    in "pred_graph" is projected onto "target_densegraph". When the
    projection distance is less than "CLOSE_THRESHOLD", the node is recorded
    under the swc id of the ground truth point it was projected onto. Swc ids
    with fewer than "MIN_INTERSECTION" recorded nodes are treated as spurious
    and removed.

    Parameters
    ----------
    target_densegraph : DenseGraph
        Graph built from ground truth swc files.
    pred_graph : networkx.Graph
        Graph built from a predicted swc file.

    Returns
    -------
    dict
        Output of "utils.remove_items". Presumably maps each remaining
        ground truth swc id to the nodes of "pred_graph" that are close to
        it -- confirm against "utils.append_dict_value".

    """
    # Compute projections
    close_nodes = dict()
    for node in pred_graph.nodes:
        node_xyz = tuple(pred_graph.nodes[node]["xyz"])
        proj_xyz = target_densegraph.get_projection(node_xyz)
        proj_swc_id = target_densegraph.xyz_to_swc[proj_xyz]
        proj_dist = geometry.dist(proj_xyz, node_xyz)
        if proj_dist < CLOSE_THRESHOLD:
            close_nodes = utils.append_dict_value(
                close_nodes, proj_swc_id, node
            )

    # Remove spurious intersections
    spurious_ids = []
    for target_id in close_nodes.keys():
        if len(close_nodes[target_id]) < MIN_INTERSECTION:
            spurious_ids.append(target_id)
    return utils.remove_items(close_nodes, spurious_ids)


def locate_site(graph, merged_1, merged_2):
    """
    Finds the closest pair of nodes (i, j) with "i" taken from "merged_1" and
    "j" taken from "merged_2", where closeness is measured by
    "geometry.dist" on the "xyz" attribute of the nodes.

    Parameters
    ----------
    graph : networkx.Graph
        Graph that contains the nodes in "merged_1" and "merged_2". Each of
        these nodes must have an "xyz" attribute.
    merged_1 : iterable
        Nodes from "graph".
    merged_2 : iterable
        Nodes from "graph". It is iterated once per node in "merged_1", so
        it must support repeated iteration.

    Returns
    -------
    list
        Pair of nodes [i, j] that attains the smallest distance, or
        [None, None] if no pair was examined.
    float
        Distance between the returned pair of nodes, or "np.inf" if no pair
        was examined.

    """
    min_dist = np.inf
    node_pair = [None, None]
    for i in merged_1:
        # Does not depend on j, so look it up once per outer iteration
        xyz_i = graph.nodes[i]["xyz"]
        for j in merged_2:
            # Compute the distance once and reuse it for the update
            d = geometry.dist(xyz_i, graph.nodes[j]["xyz"])
            if d < min_dist:
                min_dist = d
                node_pair = [i, j]
    return node_pair, min_dist


def delete_merge(graph, root, radius):
    """
    Deletes a merge site from "graph" by removing "root" together with the
    nodes that "get_nearby_nodes" finds around it.

    Parameters
    ----------
    graph : networkx.Graph
        Graph to be updated in place.
    root : hashable
        Node at the merge site.
    radius : int
        Depth limit passed to "get_nearby_nodes".

    Returns
    -------
    networkx.Graph
        The same graph object with the nodes removed.

    """
    delete_nodes = set(get_nearby_nodes(graph, root, radius))
    # "get_nearby_nodes" only collects the far endpoint of each traversed
    # edge, so the merge site itself must be added explicitly; otherwise it
    # would be left behind as an isolated node.
    delete_nodes.add(root)
    graph.remove_nodes_from(delete_nodes)
    return graph


def get_nearby_nodes(graph, root, radius):
    """
    Collects the nodes reached by a depth-first traversal of "graph" that
    starts at "root" and is limited to a depth of "radius".

    Parameters
    ----------
    graph : networkx.Graph
        Graph to be searched.
    root : hashable
        Node where the traversal starts.
    radius : int
        Depth limit of the traversal.

    Returns
    -------
    set
        Second endpoint of every edge yielded by the traversal. NOTE: "root"
        is the traversal source, so it is not expected to be in this set.

    """
    traversal = nx.dfs_edges(graph, source=root, depth_limit=radius)
    return {reached for _, reached in traversal}
137 changes: 61 additions & 76 deletions src/deep_neurographs/densegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,21 @@
@author: Anna Grim
@email: [email protected]
Class of graphs built from swc files. Each swc file is stored as a distinct
graph and each node in this graph.
Class of graphs built from swc files where each entry in the swc file
corresponds to a node in the graph.
"""

import os
from random import sample
from scipy.spatial import KDTree
from time import time

import networkx as nx

from deep_neurographs import swc_utils, utils

DELETION_RADIUS = 10

class DenseGraph:
"""
Expand All @@ -25,7 +27,7 @@ class DenseGraph:
"""

def __init__(self, swc_paths, image_patch_origin, image_patch_shape):
def __init__(self, swc_paths):
"""
Constructs a DenseGraph object from a directory of swc files.
Expand All @@ -40,101 +42,84 @@ def __init__(self, swc_paths, image_patch_origin, image_patch_shape):
None
"""
self.init_graph(swc_paths)
self.origin = image_patch_origin
self.shape = image_patch_shape
self.init_graphs(swc_paths)
self.init_kdtree()

def init_graph(self, paths):
def init_graphs(self, paths):
"""
Initializes graphs by reading swc files in "swc_paths". Graphs are
Initializes graphs by reading swc files in "paths". Graphs are
stored in a hash table where the entries are filename-graph pairs.
Parameters
----------
swc_paths : list[str]
List of paths to swc files which are used to construct a hash
table in which the entries are filename-graph pairs.
paths : list[str]
List of paths to swc files that are used to construct a dictionary
in which the items are filename-graph pairs.
Returns
-------
None
"""
# Initializations
print("Building graph...")
self.graph = nx.Graph()
self.graphs = dict()
self.xyz_to_swc = dict()
swc_dicts, _ = swc_utils.process_local_paths(paths)

# Run
cnt = 1
t0, t1 = utils.init_timers()
chunk_size = max(int(len(swc_dicts) * 0.02), 1)
for i, swc_dict in enumerate(swc_dicts):
# Construct Graph
swc_id = swc_dict["swc_id"]
graph, _ = swc_utils.to_graph(swc_dict, set_attrs=True)
graph = add_swc_id(graph, swc_id)
self.graph = nx.disjoint_union(self.graph, graph)

# Report progress
if i > cnt * chunk_size:
cnt, t1 = report_progress(
i, len(swc_dicts), chunk_size, cnt, t0, t1
)

def trim(self):
pass

def save(self, path):
for i, component in enumerate(nx.connected_components(self.graph)):
node = sample(component, 1)[0]
swc_id = self.graph.nodes[node]["swc_id"]
component_path = os.path.join(path, f"{swc_id}.swc")
self.component_to_swc(component_path, component)

def component_to_swc(self, path, component):
self.store_xyz(graph, swc_id)
self.graphs[swc_id] = graph

def store_xyz(self, graph, swc_id):
    """
    Records "swc_id" in "self.xyz_to_swc" under the xyz coordinate of every
    node in "graph".

    Parameters
    ----------
    graph : networkx.Graph
        Graph whose nodes have an "xyz" attribute.
    swc_id : str
        Id stored as the value for each coordinate of "graph".

    Returns
    -------
    None

    """
    for node in graph.nodes:
        # Coordinates are converted to tuples so they can be used as keys
        xyz = tuple(graph.nodes[node]["xyz"])
        self.xyz_to_swc[xyz] = swc_id

def init_kdtree(self):
    """
    Builds a KD-Tree from the xyz coordinates that are the keys of
    "self.xyz_to_swc" and stores it in "self.kdtree".

    Parameters
    ----------
    None

    Returns
    -------
    None

    """
    coordinates = list(self.xyz_to_swc)
    self.kdtree = KDTree(coordinates)

def get_projection(self, xyz):
    """
    Gets the point stored in "self.kdtree" that is the nearest neighbor of
    "xyz".

    Parameters
    ----------
    xyz : array-like
        Coordinate to be projected.

    Returns
    -------
    tuple
        Coordinate of the nearest point stored in "self.kdtree".

    """
    # "query" returns (distance, index); only the index is needed
    nearest_idx = self.kdtree.query(xyz, k=1)[1]
    nearest_xyz = self.kdtree.data[nearest_idx]
    return tuple(nearest_xyz)

def save(self, output_dir):
    """
    Writes every connected component of every graph in "self.graphs" to its
    own swc file in "output_dir". The first choice of filename is
    "{swc_id}.swc"; if that path already exists, numbered names of the form
    "{swc_id}.{cnt}.swc" are tried until an unused path is found.

    Parameters
    ----------
    output_dir : str
        Directory that the swc files are written to.

    Returns
    -------
    None

    """
    for swc_id, graph in self.graphs.items():
        # Counter is shared by all components of the same swc id
        suffix = 0
        for nodes in nx.connected_components(graph):
            entries = self.make_entries(graph, nodes)
            out_path = os.path.join(output_dir, f"{swc_id}.swc")
            while os.path.exists(out_path):
                filename = f"{swc_id}.{suffix}.swc"
                out_path = os.path.join(output_dir, filename)
                suffix += 1
            swc_utils.write(out_path, entries)

def make_entries(self, graph, component):
node_to_idx = dict()
entry_list = []
for i, j in nx.dfs_edges(self.graph.subgraph(component)):
for i, j in nx.dfs_edges(graph.subgraph(component)):
# Initialize
if len(entry_list) == 0:
x, y, z = tuple(self.graph.nodes[i]["xyz"])
r = self.graph.nodes[i]["radius"]
entry_list.append([1, 2, x, y, z, r, -1])
node_to_idx[i] = 1
x, y, z = tuple(graph.nodes[i]["xyz"])
r = graph.nodes[i]["radius"]
entry_list.append([1, 2, x, y, z, r, -1])

# Create entry
node_to_idx[j] = len(entry_list) + 1
x, y, z = tuple(self.graph.nodes[j]["xyz"])
r = self.graph.nodes[j]["radius"]
x, y, z = tuple(graph.nodes[j]["xyz"])
r = graph.nodes[j]["radius"]
entry_list.append([node_to_idx[j], 2, x, y, z, r, node_to_idx[i]])

swc_utils.write(path, entry_list)


def add_swc_id(graph, swc_id):
for i in graph.nodes:
graph.nodes[i]["swc_id"] = swc_id
return graph


def report_progress(current, total, chunk_size, cnt, t0, t1):
eta = get_eta(current, total, chunk_size, t1)
runtime = get_runtime(current, total, chunk_size, t0, t1)
utils.progress_bar(current, total, eta=eta, runtime=runtime)
return cnt + 1, time()


def get_eta(current, total, chunk_size, t0, return_str=True):
chunk_runtime = time() - t0
remaining = total - current
eta = remaining * (chunk_runtime / chunk_size)
t, unit = utils.time_writer(eta)
return f"{round(t, 4)} {unit}" if return_str else eta


def get_runtime(current, total, chunk_size, t0, t1):
eta = get_eta(current, total, chunk_size, t1, return_str=False)
total_runtime = time() - t0 + eta
t, unit = utils.time_writer(total_runtime)
return f"{round(t, 4)} {unit}"
return entry_list

23 changes: 8 additions & 15 deletions src/deep_neurographs/groundtruth_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from deep_neurographs import utils
from deep_neurographs.geometry import dist as get_dist

CLOSE_THRESHOLD = 3.5
MIN_INTERSECTION = 10


def init_targets(target_neurograph, pred_neurograph):
# Initializations
Expand Down Expand Up @@ -104,13 +107,15 @@ def is_component_aligned(target_neurograph, pred_neurograph, component):
# Check whether there's a merge
hits = []
for key in dists.keys():
if len(dists[key]) > 8 and np.mean(dists[key]) < 10:
if len(dists[key]) > 10 and np.mean(dists[key]) < CLOSE_THRESHOLD:
hits.append(key)
if len(hits) > 1:
print(pred_neurograph.edges[edge]["swc_id"])
print("pred_swc_id:", pred_neurograph.edges[edge]["swc_id"])
print("target_swc_id:", list(dists.keys()))
print("")

# Determine whether aligned
hat_swc_id = find_best(dists)
hat_swc_id = utils.find_best(dists)
dists = np.array(dists[hat_swc_id])
aligned_score = np.mean(dists[dists < np.percentile(dists, 90)])
if aligned_score < 4 and hat_swc_id:
Expand Down Expand Up @@ -240,18 +245,6 @@ def upd_dict_cnts(my_dict, key):
return my_dict


def find_best(my_dict):
best_key = None
best_vote_cnt = 0
for key in my_dict.keys():
val_type = type(my_dict[key])
vote_cnt = my_dict[key] if val_type == int else len(my_dict[key])
if vote_cnt > best_vote_cnt:
best_key = key
best_vote_cnt = vote_cnt
return best_key


def orient_branch(branch_i, branch_j):
"""
Flips branches so that "all(branch_i[0] == branch_j[0])" is True.
Expand Down
Loading

0 comments on commit 06ad479

Please sign in to comment.