Commit
bug: target edge generation (#82)
Co-authored-by: anna-grim <[email protected]>
anna-grim and anna-grim authored Mar 20, 2024
1 parent ebceb0d commit 631adad
Showing 7 changed files with 58 additions and 51 deletions.
4 changes: 4 additions & 0 deletions src/deep_neurographs/evaluation.py
@@ -10,6 +10,7 @@
import numpy as np

METRICS_LIST = [
"accuracy_dif",
"accuracy",
"precision",
"recall",
@@ -146,10 +147,13 @@ def get_stats(neurograph, proposals, pred_edges):
"METRICS_LIST".
"""
n_pos = len([e for e in proposals if e in neurograph.target_edges])
a_baseline = n_pos / (len(proposals) if len(proposals) > 0 else 1)
tp, fp, a, p, r, f1 = get_accuracy(neurograph, proposals, pred_edges)
stats = {
"# splits fixed": tp,
"# merges created": fp,
"accuracy_dif": a - a_baseline,
"accuracy": a,
"precision": p,
"recall": r,
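For context, the new "accuracy_dif" metric reports how far the model's accuracy sits above the trivial baseline of accepting every proposal, which scores n_pos / len(proposals). A minimal sketch of that computation, with illustrative names rather than the module's own API:

def accuracy_above_baseline(proposals, target_edges, accuracy):
    # Baseline: accept every proposal, so accuracy equals the fraction of
    # proposals that are ground-truth target edges.
    n_pos = sum(1 for e in proposals if e in target_edges)
    baseline = n_pos / len(proposals) if proposals else 0.0
    return accuracy - baseline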
13 changes: 6 additions & 7 deletions src/deep_neurographs/feature_extraction.py
@@ -345,7 +345,7 @@ def __multiblock_feature_matrix(neurographs, features, blocks, model_type):
neurographs[block_id], features[block_id], shift=idx_shift
)
else:
X_i, y_i, idx_to_edge_i = get_feature_vectors(
X_i, y_i, idxs_i, idx_to_edge_i = get_feature_vectors(
neurographs[block_id], features[block_id], shift=idx_shift
)

@@ -354,18 +354,15 @@ def __multiblock_feature_matrix(neurographs, features, blocks, model_type):
X = deepcopy(X_i)
y = deepcopy(y_i)
if model_type == "MultiModalNet":
print("if")
x = deepcopy(x_i)
else:
X = np.concatenate((X, X_i), axis=0)
y = np.concatenate((y, y_i), axis=0)
if model_type == "MultiModalNet":
print("else")
x = np.concatenate((x, x_i), axis=0)

# Update dicts
idxs = set(np.arange(idx_shift, idx_shift + len(idx_to_edge_i)))
block_to_idxs[block_id] = idxs
block_to_idxs[block_id] = idxs_i
idx_to_edge.update(idx_to_edge_i)

if model_type == "MultiModalNet":
@@ -391,13 +388,15 @@ def get_feature_vectors(neurograph, features, shift=0):

# Build
idx_to_edge = dict()
idxs = set()
X = np.zeros((neurograph.n_proposals(), len(features[key])))
y = np.zeros((neurograph.n_proposals()))
for i, edge in enumerate(features.keys()):
idx_to_edge[i + shift] = edge
X[i, :] = features[edge]
y[i] = 1 if edge in neurograph.target_edges else 0
return X, y, idx_to_edge
idxs.add(i + shift)
idx_to_edge[i + shift] = edge
return X, y, idxs, idx_to_edge


def get_multimodal_features(neurograph, features, shift=0):
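The refactor moves the index bookkeeping into get_feature_vectors, which now returns the set of shifted row indices alongside idx_to_edge, so __multiblock_feature_matrix stores idxs_i directly instead of rebuilding the range from idx_shift; the stray debug prints also go away. A toy sketch of the same bookkeeping, under the assumed name block_matrices (a dict of per-block feature arrays, not from the source):

import numpy as np

def stack_blocks(block_matrices):
    # Each block occupies a contiguous range of rows in the stacked matrix;
    # block_to_idxs records which global row indices belong to which block.
    X, block_to_idxs, shift = None, dict(), 0
    for block_id, X_i in block_matrices.items():
        block_to_idxs[block_id] = set(range(shift, shift + len(X_i)))
        X = X_i if X is None else np.concatenate((X, X_i), axis=0)
        shift += len(X_i)
    return X, block_to_idxs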
12 changes: 7 additions & 5 deletions src/deep_neurographs/groundtruth_generation.py
@@ -14,9 +14,10 @@
import networkx as nx
import numpy as np

from deep_neurographs import graph_utils as gutils
from deep_neurographs import graph_utils as gutils, geometry
from deep_neurographs.geometry import dist as get_dist

site = np.array([22086.158, 10681.918, 9549.215])

def init_targets(target_neurograph, pred_neurograph):
# Initializations
@@ -58,6 +59,7 @@ def get_valid_proposals(target_neurograph, pred_neurograph):
# Check whether aligned to same/adjacent target edges
branches_i = pred_neurograph.get_branches(i)
branches_j = pred_neurograph.get_branches(j)

if is_mutually_aligned(target_neurograph, branches_i, branches_j):
valid_proposals.append(edge)

@@ -94,7 +96,7 @@ def is_component_aligned(target_neurograph, pred_neurograph, component):
dists.append(get_dist(hat_xyz, xyz))
dists = np.array(dists)
aligned_score = np.mean(dists[dists < np.percentile(dists, 90)])
return True if aligned_score < 5 else False
return True if aligned_score < 6 else False


def is_mutually_aligned(target_neurograph, branches_i, branches_j):
@@ -104,8 +106,8 @@ def is_mutually_aligned(target_neurograph, branches_i, branches_j):

# Check if edges either identical or adjacent
identical = hat_edge_i == hat_edge_j
adjacent = is_adjacent(target_neurograph, hat_edge_i, hat_edge_j)
if identical or adjacent:
adjacent = is_adjacent(target_neurograph, hat_edge_i, hat_edge_j)
if identical:
return True
else:
return False
@@ -137,7 +139,7 @@ def is_adjacent(neurograph, edge_i, edge_j):
"""
for i in edge_i:
for j in edge_j:
if neurograph.is_nb(i, j):
if i == j:
return True
return False

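Two behavioral tweaks are visible here: the alignment-score cutoff is relaxed from 5 to 6, and is_adjacent now treats target edges as adjacent when they share a node id rather than asking the graph whether the endpoints are neighbors. A standalone sketch of the new test (function name is illustrative):

def share_endpoint(edge_i, edge_j):
    # Two target edges count as adjacent iff they share a node id,
    # replacing the earlier neighbor query on the graph.
    # e.g. share_endpoint((1, 2), (2, 3)) -> True; ((1, 2), (3, 4)) -> False.
    return any(i == j for i in edge_i for j in edge_j)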
27 changes: 19 additions & 8 deletions src/deep_neurographs/machine_learning/inference.py
@@ -15,16 +15,27 @@


def predict(dataset, model, model_type):
dataset = dataset["dataset"]
accuracy = []
accuracy_baseline = []
data = dataset["dataset"]
if "Net" in model_type:
model.eval()
dataloader = DataLoader(dataset, batch_size=16, shuffle=False)
y_pred = []
for batch in dataloader:
hat_y = []
for batch in DataLoader(data, batch_size=32, shuffle=False):
# Run model
with torch.no_grad():
x_i = batch["inputs"]
y_pred_i = sigmoid(model(x_i))
y_pred.extend(np.array(y_pred_i).tolist())
hat_y_i = sigmoid(model(x_i))

# Postprocess
hat_y_i = np.array(hat_y_i)
y_i = np.array(batch["targets"])
hat_y.extend(hat_y_i.tolist())
accuracy_baseline.extend((y_i > 0).tolist())
accuracy.extend(((hat_y_i > 0.5) == (y_i > 0)).tolist())
accuracy = np.mean(accuracy)
accuracy_baseline = np.sum(accuracy_baseline) / len(accuracy_baseline)
print("Accuracy +/-:", accuracy - accuracy_baseline)
else:
y_pred = model.predict_proba(dataset["inputs"])[:, 1]
return np.array(y_pred)
hat_y = model.predict_proba(data["inputs"])[:, 1]
return np.array(hat_y)
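The rewritten predict pulls the dataset out of the wrapper dict, batches it itself, and for networks also prints accuracy relative to the accept-everything baseline. A stripped-down sketch of just the probability pass, assuming each dataset item is a dict with an "inputs" tensor and using torch.sigmoid in place of the module's sigmoid import:

import numpy as np
import torch
from torch.utils.data import DataLoader

def predict_probs(model, data, batch_size=32):
    # Batched inference: run the network without gradient tracking and
    # squash the outputs with a sigmoid to get per-proposal probabilities.
    model.eval()
    probs = []
    with torch.no_grad():
        for batch in DataLoader(data, batch_size=batch_size, shuffle=False):
            hat_y_i = torch.sigmoid(model(batch["inputs"]))
            probs.extend(hat_y_i.flatten().tolist())
    return np.array(probs)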
5 changes: 1 addition & 4 deletions src/deep_neurographs/machine_learning/models.py
@@ -35,8 +35,7 @@ def __init__(self, num_features):
"""
nn.Module.__init__(self)
self.fc1 = self._init_fc_layer(num_features, 2 * num_features)
self.fc2 = self._init_fc_layer(2 * num_features, num_features)
self.fc3 = self._init_fc_layer(num_features, num_features // 2)
self.fc2 = self._init_fc_layer(2 * num_features, num_features // 2)
self.output = nn.Linear(num_features // 2, 1)

def _init_fc_layer(self, D_in, D_out):
@@ -77,8 +76,6 @@ def forward(self, x):
"""
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
x = self.output(x)
return x

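The middle hidden layer is dropped: fc2 now maps 2 * num_features straight down to num_features // 2, and fc3 disappears from both __init__ and forward. A bare sketch of the resulting shape, using plain nn.Linear layers where the real _init_fc_layer presumably adds more (activation, dropout, etc.):

import torch.nn as nn

class TwoLayerHead(nn.Module):
    # Illustrative stand-in for the slimmed-down classifier:
    # num_features -> 2 * num_features -> num_features // 2 -> 1.
    def __init__(self, num_features):
        super().__init__()
        self.fc1 = nn.Linear(num_features, 2 * num_features)
        self.fc2 = nn.Linear(2 * num_features, num_features // 2)
        self.output = nn.Linear(num_features // 2, 1)

    def forward(self, x):
        return self.output(self.fc2(self.fc1(x)))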
2 changes: 2 additions & 0 deletions src/deep_neurographs/machine_learning/trainer.py
@@ -93,6 +93,7 @@ def fit_deep_model(
pylightning_trainer.fit(lit_model, train_loader, valid_loader)

# Return best model
print(ckpt_callback.best_model_path)
ckpt = torch.load(ckpt_callback.best_model_path)
lit_model.model.load_state_dict(ckpt["state_dict"])
return lit_model.model
@@ -134,6 +135,7 @@ def validation_step(self, batch, batch_idx):
X = self.get_example(batch, "inputs")
y = self.get_example(batch, "targets")
y_hat = self.model(X)
self.log("val_loss", self.criterion(y_hat, y))
self.compute_stats(y_hat, y, prefix="val_")

def compute_stats(self, y_hat, y, prefix=""):
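Besides printing the best checkpoint path, the trainer now logs the validation loss on every validation step. A minimal LightningModule sketch of that logging (class name and batch keys are illustrative; it assumes the checkpoint callback monitors "val_loss"):

from pytorch_lightning import LightningModule

class LitSketch(LightningModule):
    # Recording the validation loss under "val_loss" gives a checkpoint
    # callback monitoring that key a metric to rank epochs by.
    def __init__(self, model, criterion):
        super().__init__()
        self.model = model
        self.criterion = criterion

    def validation_step(self, batch, batch_idx):
        y_hat = self.model(batch["inputs"])
        self.log("val_loss", self.criterion(y_hat, batch["targets"]))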
46 changes: 19 additions & 27 deletions src/deep_neurographs/neurograph.py
@@ -182,36 +182,28 @@ def generate_proposals(
)
if self.degree[node] >= 3:
continue
if swc_id in existing_connections.keys() and restrict:
if leaf_swc_id in existing_connections[swc_id].keys():
edge = existing_connections[swc_id][leaf_swc_id]
len1 = self.node_xyz_dist(leaf, xyz)
len2 = self.proposal_length(edge)
if len1 < len2:
node1, node2 = tuple(edge)
self.nodes[node1]["proposals"].remove(node2)
self.nodes[node2]["proposals"].remove(node1)
del self.proposals[edge]
else:
continue

# Add edge
if self.degree[node] < 3:

pair_id = frozenset((swc_id, leaf_swc_id))
if pair_id in existing_connections.keys() and restrict:
edge = existing_connections[pair_id]
len1 = self.node_xyz_dist(leaf, xyz)
len2 = self.proposal_length(edge)
if len1 < len2:
node1, node2 = tuple(edge)
self.nodes[node1]["proposals"].discard(node2)
self.nodes[node2]["proposals"].discard(node1)
del self.proposals[edge]
del existing_connections[pair_id]
else:
continue

# Add proposal
if self.degree[node] < 2:
edge = frozenset({leaf, node})
self.proposals[edge] = {"xyz": np.array([xyz_leaf, xyz])}
self.nodes[node]["proposals"].add(leaf)
self.nodes[leaf]["proposals"].add(node)

# Update existing connections
if leaf_swc_id in existing_connections.keys():
existing_connections[leaf_swc_id][swc_id] = edge
else:
existing_connections[leaf_swc_id] = {swc_id: edge}

if swc_id in existing_connections.keys():
existing_connections[swc_id][leaf_swc_id] = edge
else:
existing_connections[swc_id] = {leaf_swc_id: edge}
self.proposals[edge] = {"xyz": np.array([xyz_leaf, xyz])}
existing_connections[pair_id] = edge

# print("# doubles:", len(doubles))

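This hunk is the heart of the fix: the nested dict that had to be updated under both swc ids (and could fall out of sync) is replaced by a flat dict keyed on frozenset((swc_id, leaf_swc_id)), the surviving proposal for a pair is the shorter one, set removal switches from remove to discard, and the degree cap before adding a proposal tightens from 3 to 2. A tiny sketch of why the unordered key closes the loophole (ids and values are hypothetical):

def pair_key(swc_id_a, swc_id_b):
    # Keying existing_connections by an unordered pair of swc ids means
    # (a, b) and (b, a) hit the same entry, so at most one proposal is kept
    # per pair of components.
    return frozenset((swc_id_a, swc_id_b))

# Illustrative usage:
existing_connections = dict()
existing_connections[pair_key("swc_1", "swc_2")] = "proposal_a"
assert pair_key("swc_2", "swc_1") in existing_connections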
