Skip to content

Commit

Permalink
Optimize learning (#84)
Browse files Browse the repository at this point in the history
* refactor: gt generation, black

* refactor: ground truth generation

---------

Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Mar 26, 2024
1 parent a21c1ae commit aae3041
Show file tree
Hide file tree
Showing 11 changed files with 63 additions and 21 deletions.
27 changes: 21 additions & 6 deletions src/deep_neurographs/groundtruth_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,23 @@ def is_component_aligned(target_neurograph, pred_neurograph, component):
hat_swc_id = target_neurograph.xyz_to_swc(hat_xyz)
d = get_dist(hat_xyz, xyz)
dists = utils.append_dict_value(dists, hat_swc_id, d)

# Check whether there's a merge
hits = []
for key in dists.keys():
if len(dists[key]) > 8 and np.mean(dists[key]) < 10:
hits.append(key)
if len(hits) > 1:
print(pred_neurograph.edges[edge]["swc_id"])

# Deterine whether aligned
hat_swc_id = find_best(dists)
dists = np.array(dists[hat_swc_id])
aligned_score = np.mean(dists[dists < np.percentile(dists, 95)])
return True if aligned_score < 3 else False, hat_swc_id
aligned_score = np.mean(dists[dists < np.percentile(dists, 90)])
if aligned_score < 4 and hat_swc_id:
return True, hat_swc_id
else:
return False, None


def is_valid(target_neurograph, pred_neurograph, target_id, edge):
Expand Down Expand Up @@ -136,10 +147,14 @@ def is_valid(target_neurograph, pred_neurograph, target_id, edge):
hat_edge_i = proj_branch(target_neurograph, pred_neurograph, target_id, i)
hat_edge_j = proj_branch(target_neurograph, pred_neurograph, target_id, j)

# Check if edges either identical or adjacent
if hat_edge_i == hat_edge_j:
# Check if edges are identical or None
if not hat_edge_i or not hat_edge_j:
return False
elif hat_edge_i == hat_edge_j:
return True
elif is_adjacent(target_neurograph, hat_edge_i, hat_edge_j):

# Check if edges are adjacent
if is_adjacent(target_neurograph, hat_edge_i, hat_edge_j):
hat_branch_i = target_neurograph.edges[hat_edge_i]["xyz"]
hat_branch_j = target_neurograph.edges[hat_edge_j]["xyz"]
xyz_i = pred_neurograph.nodes[i]["xyz"]
Expand Down Expand Up @@ -171,7 +186,7 @@ def proj_branch(target_neurograph, pred_neurograph, target_id, i):
if d < min_dist:
min_dist = d
best_edge = edge
else:
elif len(hits.keys()) == 1:
best_edge = list(hits.keys())[0]
return best_edge

Expand Down
2 changes: 1 addition & 1 deletion src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def build_neurograph_from_gcs_zips(
Parameters
----------
bucket_name : str
Name of GCS bucket where zips are stored.
Name of GCS bucket where zips of swc files are stored.
cloud_path : str
Path within GCS bucket to directory containing zips.
anisotropy : list[float], optional
Expand Down
9 changes: 5 additions & 4 deletions src/deep_neurographs/machine_learning/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ProposalDataset(Dataset):
"""

def __init__(self, inputs, targets, search_radius=10, transform=False):
def __init__(self, inputs, targets, search_radius=10, transform=False, lengths=[]):
"""
Constructs ProposalDataset object.
Expand All @@ -41,7 +41,7 @@ def __init__(self, inputs, targets, search_radius=10, transform=False):
"""
self.inputs = inputs.astype(np.float32)
self.targets = reformat(targets)
self.search_radius = search_radius
self.lengths = lengths
self.transform = transform

def __len__(self):
Expand Down Expand Up @@ -77,8 +77,9 @@ def __getitem__(self, idx):
"""
inputs_i = self.inputs[idx]
if self.transform:
if inputs_i[0] > self.search_radius or np.random.random() > 0.75:
inputs_i[0] = abs(np.random.normal(0, 5, 1))
if np.random.random() > 0.6:
p = 100 * np.random.random()
inputs_i[0] = np.percentile(self.lengths, p)
return {"inputs": inputs_i, "targets": self.targets[idx]}


Expand Down
File renamed without changes.
File renamed without changes.
22 changes: 18 additions & 4 deletions src/deep_neurographs/machine_learning/ml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np
from sklearn.ensemble import AdaBoostClassifier, RandomForestClassifier

from deep_neurographs import feature_extraction as extracter
from deep_neurographs.machine_learning import feature_extraction as extracter
from deep_neurographs.machine_learning.datasets import (
ImgProposalDataset,
MultiModalDataset,
Expand Down Expand Up @@ -80,7 +80,7 @@ def init_model(model_type):
return MultiModalNet(n_features)


def get_dataset(inputs, targets, model_type, transform):
def get_dataset(inputs, targets, model_type, transform, lengths):
"""
Gets classification model to be fit.
Expand All @@ -102,7 +102,7 @@ def get_dataset(inputs, targets, model_type, transform):
"""
if model_type == "FeedForwardNet":
return ProposalDataset(inputs, targets, transform=transform)
return ProposalDataset(inputs, targets, transform=transform, lengths=lengths)
elif model_type == "ConvNet":
return ImgProposalDataset(inputs, targets, transform=transform)
elif model_type == "MultiModalNet":
Expand All @@ -114,12 +114,26 @@ def get_dataset(inputs, targets, model_type, transform):
def init_dataset(
neurographs, features, model_type, block_ids, transform=False
):
# Extract features
inputs, targets, block_to_idx, idx_to_edge = extracter.get_feature_matrix(
neurographs, features, model_type, block_ids=block_ids
)
lens = []
if transform:
for block_id in block_ids:
lens.extend(get_lengths(neurographs[block_id]))

dataset = {
"dataset": get_dataset(inputs, targets, model_type, transform),
"dataset": get_dataset(inputs, targets, model_type, transform, lens),
"block_to_idxs": block_to_idx,
"idx_to_edge": idx_to_edge,
}
return dataset


def get_lengths(neurograph):
lengths = []
for edge in neurograph.proposals.keys():
lengths.append(neurograph.proposal_length(edge))
return lengths

2 changes: 1 addition & 1 deletion src/deep_neurographs/machine_learning/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _init_fc_layer(self, D_in, D_out):
"""
fc_layer = nn.Sequential(
nn.Linear(D_in, D_out), nn.LeakyReLU(), nn.Dropout(p=0.3)
nn.Linear(D_in, D_out), nn.LeakyReLU(), nn.Dropout(p=0.25)
)
return fc_layer

Expand Down
File renamed without changes.
19 changes: 15 additions & 4 deletions src/deep_neurographs/machine_learning/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@ def fit_model(model, dataset):


def fit_deep_model(
model, dataset, batch_size=BATCH_SIZE, logger=False, lr=1e-3, max_epochs=50
model,
dataset,
batch_size=BATCH_SIZE,
criterion=None,
logger=False,
lr=1e-3,
max_epochs=1000,
):
"""
Fits a neural network to a dataset.
Expand Down Expand Up @@ -76,7 +82,7 @@ def fit_deep_model(
valid_loader = DataLoader(valid_set, batch_size=batch_size)

# Configure trainer
lit_model = LitModel(model=model, lr=lr)
lit_model = LitModel(criterion=criterion, model=model, lr=lr)
ckpt_callback = ModelCheckpoint(save_top_k=1, monitor="val_f1", mode="max")

# Fit model
Expand Down Expand Up @@ -107,11 +113,16 @@ def random_split(train_set, train_ratio=0.8):

# -- Lightning Module --
class LitModel(pl.LightningModule):
def __init__(self, model=None, lr=1e-3):
def __init__(self, criterion=None, model=None, lr=1e-3):
super().__init__()
self.criterion = nn.BCEWithLogitsLoss()
self.model = model
self.lr = lr
if criterion:
self.criterion = criterion
else:
pos_weight = torch.tensor([0.75], device=0)
self.criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)


def forward(self, batch):
x = self.get_example(batch, "inputs")
Expand Down
1 change: 1 addition & 0 deletions src/deep_neurographs/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,6 @@ def get_structure_aware_prediction(
# Add remaining viable edges
for edge in remaining_proposals:
if not gutils.creates_cycle(pred_graph, tuple(edge)):
pred_graph.add_edges_from([edge])
positive_predictions.append(edge)
return positive_predictions
2 changes: 1 addition & 1 deletion src/deep_neurographs/swc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def parse(contents, anisotropy=[1.0, 1.0, 1.0]):
for i, line in enumerate(contents):
parts = line.split()
swc_dict["id"][i] = parts[0]
swc_dict["radius"][i] = parts[-2]
swc_dict["radius"][i] = float(parts[-2])
swc_dict["pid"][i] = parts[-1]
swc_dict["xyz"][i] = read_xyz(
parts[2:5], anisotropy=anisotropy, offset=offset
Expand Down

0 comments on commit aae3041

Please sign in to comment.