Skip to content

Commit

Permalink
feat: trim graph to fit image bbox (#81)
Browse files Browse the repository at this point in the history
Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Mar 18, 2024
1 parent 1043e24 commit ebceb0d
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 7 deletions.
9 changes: 5 additions & 4 deletions src/deep_neurographs/densegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
"""

import os
import networkx as nx
import numpy as np
from random import sample
from time import time

import networkx as nx

from deep_neurographs import swc_utils, utils


Expand Down Expand Up @@ -75,7 +75,7 @@ def init_graph(self, paths):
graph, _ = swc_utils.to_graph(swc_dict, set_attrs=True)
graph = add_swc_id(graph, swc_id)
self.graph = nx.disjoint_union(self.graph, graph)

# Report progress
if i > cnt * chunk_size:
cnt, t1 = report_progress(
Expand Down Expand Up @@ -111,12 +111,13 @@ def component_to_swc(self, path, component):

swc_utils.write(path, entry_list)


def add_swc_id(graph, swc_id):
for i in graph.nodes:
graph.nodes[i]["swc_id"] = swc_id
return graph


def report_progress(current, total, chunk_size, cnt, t0, t1):
eta = get_eta(current, total, chunk_size, t1)
runtime = get_runtime(current, total, chunk_size, t0, t1)
Expand Down
13 changes: 13 additions & 0 deletions src/deep_neurographs/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

def get_irreducibles(
swc_dict,
bbox=None,
prune_connectors=False,
prune_spurious=True,
connector_length=8,
Expand Down Expand Up @@ -71,6 +72,7 @@ def get_irreducibles(
# Build dense graph
swc_dict["idx"] = dict(zip(swc_dict["id"], range(len(swc_dict["id"]))))
graph, _ = swc_utils.to_graph(swc_dict, set_attrs=True)
graph = trim_branches(graph, bbox)
graph, connector_centroids = prune_branches(
graph,
prune_connectors=prune_connectors,
Expand All @@ -91,6 +93,17 @@ def get_irreducibles(
return irreducibles, connector_centroids


def trim_branches(graph, bbox):
if bbox:
delete_nodes = set()
for i in graph.nodes:
xyz = utils.to_img(graph.nodes[i]["xyz"])
if not utils.is_contained(bbox, xyz):
delete_nodes.add(i)
graph.remove_nodes_from(delete_nodes)
return graph


def prune_branches(
graph,
prune_connectors=False,
Expand Down
3 changes: 3 additions & 0 deletions src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def build_neurograph(
print("# connected components:", utils.reformat_number(n_components))
irreducibles, n_nodes, n_edges = get_irreducibles(
swc_dicts,
bbox=img_bbox,
progress_bar=progress_bar,
prune_connectors=prune_connectors,
prune_spurious=prune_spurious,
Expand Down Expand Up @@ -315,6 +316,7 @@ def build_neurograph(

def get_irreducibles(
swc_dicts,
bbox=None,
progress_bar=True,
prune_connectors=PRUNE_CONNECTORS,
prune_spurious=PRUNE_SPURIOUS,
Expand All @@ -333,6 +335,7 @@ def get_irreducibles(
processes[i] = executor.submit(
gutils.get_irreducibles,
swc_dict,
bbox,
prune_connectors,
prune_spurious,
connector_length,
Expand Down
2 changes: 0 additions & 2 deletions src/deep_neurographs/machine_learning/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@ def predict(dataset, model, model_type):
for batch in dataloader:
with torch.no_grad():
x_i = batch["inputs"]
y_i = batch["targets"]
y_pred_i = sigmoid(model(x_i))
y_pred.extend(np.array(y_pred_i).tolist())
#print((np.sum((np.array(y_pred_i) > 0.5) == (np.array(y_i) > 0))) / len(y_i))
else:
y_pred = model.predict_proba(dataset["inputs"])[:, 1]
return np.array(y_pred)
1 change: 1 addition & 0 deletions src/deep_neurographs/machine_learning/ml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from deep_neurographs import feature_extraction as extracter
from deep_neurographs.machine_learning.datasets import (
ImgProposalDataset,
MultiModalDataset,
ProposalDataset,
)
from deep_neurographs.machine_learning.models import (
Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def read_tensorstore(arr, xyz, shape, from_center=True):
def get_chunk(arr, xyz, shape, from_center=True):
start, end = get_start_end(xyz, shape, from_center=from_center)
return deepcopy(
arr[start[0] : end[0], start[1] : end[1], start[2] : end[2]]
arr[start[0]: end[0], start[1]: end[1], start[2]: end[2]]
)


Expand Down

0 comments on commit ebceb0d

Please sign in to comment.