Skip to content

Commit

Permalink
refactor: improvements to intake (#77)
Browse files Browse the repository at this point in the history
Co-authored-by: anna-grim <[email protected]>
  • Loading branch information
anna-grim and anna-grim authored Mar 16, 2024
1 parent 6445d4b commit 2848d67
Show file tree
Hide file tree
Showing 12 changed files with 194 additions and 166 deletions.
Empty file.
9 changes: 6 additions & 3 deletions src/deep_neurographs/intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ def build_neurograph_from_local(
prune_depth=prune_depth,
smooth=smooth,
)

# Delete nodes outside bbox
if img_bbox:
neurograph.delete_isolated()

return neurograph


Expand Down Expand Up @@ -185,11 +190,9 @@ def download_gcs_zips(bucket_name, cloud_path, min_size, anisotropy):
"""
# Initializations
storage_client = storage.Client()
bucket = storage_client.bucket(bucket_name)
bucket = storage.Client().bucket(bucket_name)
zip_paths = utils.list_gcs_filenames(bucket, cloud_path, ".zip")
chunk_size = int(len(zip_paths) * 0.02)
print(f"# zip files: {len(zip_paths)}")

# Parse
cnt = 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(self, inputs, labels, transform=True):
"""
self.inputs = inputs.astype(np.float32)
self.labels = reformat(labels)
self.transform = Augmentator() if transform else None
self.transform = AugmentImages() if transform else None

def __len__(self):
"""
Expand Down Expand Up @@ -186,7 +186,7 @@ def __init__(self, inputs, labels, transform=True):
self.img_inputs = inputs["imgs"].astype(np.float32)
self.feature_inputs = inputs["features"].astype(np.float32)
self.labels = reformat(labels)
self.transform = Augmentator() if transform else None
self.transform = AugmentImages() if transform else None

def __len__(self):
"""
Expand Down Expand Up @@ -227,16 +227,16 @@ def __getitem__(self, idx):
return {"inputs": inputs, "labels": self.labels[idx]}


# Miscellaneous
class Augmentator:
# Augmentation
class AugmentImages:
"""
Applies augmentation to an image chunk.
"""

def __init__(self):
"""
Constructs an Augmentator object.
Constructs an AugmentImages object.
Parameters
----------
Expand Down Expand Up @@ -276,6 +276,7 @@ def run(self, arr):
return self.transform(arr)


# -- utils --
def reformat(arr):
"""
Reformats a label vector for training by adding a dimension and casting it
Expand Down
File renamed without changes.
115 changes: 115 additions & 0 deletions src/deep_neurographs/machine_learning/ml_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""
Created on Sat November 04 15:30:00 2023
@author: Anna Grim
@email: [email protected]
Helper routines for training machine learning models.
"""

import numpy as np
from random import sample
from deep_neurographs.machine_learning.models import ConvNet, FeedForwardNet, MultiModalNet
from deep_neurographs import feature_extraction as extracter

SUPPORTED_MODELS = [
"AdaBoost",
"RandomForest",
"FeedForwardNet",
"ConvNet",
"MultiModalNet",
]


def get_kfolds(filenames, k):
    """
    Partitions "filenames" into k-folds to perform cross validation.

    Parameters
    ----------
    filenames : list[str]
        List of filenames of samples for training.
    k : int
        Number of folds to be used in k-fold cross validation.

    Returns
    -------
    folds : list[list[str]]
        Disjoint, randomly drawn folds of "filenames". Every fold contains
        floor(len(filenames) / k) entries, so leftover filenames are not
        assigned to any fold. Fewer than k folds are returned when
        "filenames" contains duplicates and too few distinct names remain.

    """
    n_samples = len(filenames) // k
    assert n_samples > 0, "Sample size is too small for {}-folds".format(k)

    # Deduplicate into a list rather than a set: random.sample() raises a
    # TypeError when it is given a set on Python >= 3.11.
    unique = list(dict.fromkeys(filenames))

    # A single shuffle followed by slicing gives disjoint random folds.
    shuffled = sample(unique, len(unique))
    n_folds = min(k, len(shuffled) // n_samples)
    return [
        shuffled[i * n_samples:(i + 1) * n_samples] for i in range(n_folds)
    ]


def get_model_type(model):
    """
    Gets the name of the network class that "model" is an instance of.

    Parameters
    ----------
    model : FeedForwardNet, ConvNet or MultiModalNet
        Network object whose type is to be determined.

    Returns
    -------
    str or None
        "FeedForwardNet", "ConvNet" or "MultiModalNet" when the type of
        "model" is exactly that class. Otherwise a hint is printed and None
        is returned.

    """
    # "model" is an object, so it is deliberately not checked against the
    # name strings in SUPPORTED_MODELS; that check rejected every network.
    known_types = (
        (FeedForwardNet, "FeedForwardNet"),
        (ConvNet, "ConvNet"),
        (MultiModalNet, "MultiModalNet"),
    )
    for net_class, name in known_types:
        # Exact type match (no subclasses), as in the original comparison.
        if type(model) is net_class:
            return name
    print("Input model instead of model_type")
    return None


def init_model(model_type):
    """
    Creates a new, unfitted model of the requested type.

    Parameters
    ----------
    model_type : str
        Name of the model to create. Must be an entry of SUPPORTED_MODELS.

    Returns
    -------
    object
        Instance of the model class that corresponds to "model_type". The
        networks "FeedForwardNet" and "MultiModalNet" are constructed with
        the value of "extracter.count_features(model_type)".

    """
    # NOTE(review): AdaBoostClassifier and RandomForestClassifier are not
    # among the imports shown for this module -- confirm they are in scope.
    assert model_type in SUPPORTED_MODELS, "Model not supported!"

    # Models that are constructed without arguments
    if model_type == "AdaBoost":
        return AdaBoostClassifier()
    if model_type == "RandomForest":
        return RandomForestClassifier()
    if model_type == "ConvNet":
        return ConvNet()

    # Networks that are sized by the number of features
    if model_type in ("FeedForwardNet", "MultiModalNet"):
        n_features = extracter.count_features(model_type)
        if model_type == "FeedForwardNet":
            return FeedForwardNet(n_features)
        return MultiModalNet(n_features)
    return None


def init_dataloader(model_type, augmentation=False, data=None):
    """
    Initializes a network together with the dataset used to fit it.

    Parameters
    ----------
    model_type : str
        Indication of type of model. Options are "FeedForwardNet",
        "ConvNet", and "MultiModalNet".
    augmentation : bool, optional
        Passed as "transform" to the dataset built for "FeedForwardNet". The
        datasets for "ConvNet" and "MultiModalNet" are always built with
        "transform=True". The default is False.
    data : dict, optional
        Training data used to fit model. This dictionary must contain the keys
        "inputs" and "labels" which correspond to the feature matrix and
        target labels to be learned. The default is None.

    Returns
    -------
    net : object
        Network returned by "init_model(model_type)".
    dataset : object
        Dataset built from data["inputs"] and data["labels"].

    Raises
    ------
    ValueError
        If "data" is None or no dataset is defined for "model_type".

    """
    # Imported locally because this module does not import these names at
    # the top of the file.
    from deep_neurographs.machine_learning import datasets as ds
    from deep_neurographs.machine_learning import models

    if data is None:
        raise ValueError("data must contain the keys 'inputs' and 'labels'")
    if model_type not in ("FeedForwardNet", "ConvNet", "MultiModalNet"):
        raise ValueError(f"No dataset is defined for model_type={model_type}")

    # NOTE(review): "net" was never assigned in the previous version; creating
    # it with init_model is an assumption -- confirm this is the intent.
    net = init_model(model_type)
    if model_type == "FeedForwardNet":
        dataset = ds.ProposalDataset(
            data["inputs"], data["labels"], transform=augmentation
        )
    elif model_type == "ConvNet":
        dataset = ds.ImgProposalDataset(
            data["inputs"], data["labels"], transform=True
        )
    else:
        models.init_weights(net)
        dataset = ds.MultiModalDataset(
            data["inputs"], data["labels"], transform=True
        )
    return net, dataset
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@
)

from deep_neurographs import feature_extraction as extracter
from deep_neurographs.deep_learning import datasets as ds
from deep_neurographs.deep_learning import loss, models
#from deep_neurographs.deep_learning.datasets import ConvNet, FeedForwardNet, MultiModalNet
from deep_neurographs.machine_learning import datasets as ds
from deep_neurographs.machine_learning import loss, models, ml_utils
#from deep_neurographs.deep_learning.models import ConvNet, FeedForwardNet, MultiModalNet

logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)

Expand All @@ -44,55 +44,19 @@
]


# -- Cross Validation --
def get_kfolds(filenames, k):
    """
    Partitions "filenames" into k-folds to perform cross validation.

    Parameters
    ----------
    filenames : list[str]
        List of filenames of samples for training.
    k : int
        Number of folds to be used in k-fold cross validation.

    Returns
    -------
    folds : list[list[str]]
        Disjoint, randomly drawn folds of "filenames". Every fold contains
        floor(len(filenames) / k) entries, so leftover filenames are not
        assigned to any fold. Fewer than k folds are returned when
        "filenames" contains duplicates and too few distinct names remain.

    """
    n_samples = len(filenames) // k
    assert n_samples > 0, "Sample size is too small for {}-folds".format(k)

    # Deduplicate into a list rather than a set: random.sample() raises a
    # TypeError when it is given a set on Python >= 3.11.
    unique = list(dict.fromkeys(filenames))

    # A single shuffle followed by slicing gives disjoint random folds.
    shuffled = sample(unique, len(unique))
    n_folds = min(k, len(shuffled) // n_samples)
    return [
        shuffled[i * n_samples:(i + 1) * n_samples] for i in range(n_folds)
    ]


# -- Training --
def run_on_blocks(neurographs, features, dataset, model, block_ids=None):
# Set model_type
if type(model) == FeedForwardNet:
model_type = "FeedForwardNet"
elif type(model) == ConvNet:
model_type = "ConvNet"
elif type(model) == MultiModalNet:
model_type = "MultiModalNet"
else:
print("Input model instead of model_type")
def run(neurographs, features, model, block_ids=None):
i


def run_on_blocks(neurographs, features, model, block_ids):
# Initialize data
model_type = ml_utils.get_model_type(model)
X_train, y_train, _, _ = extracter.get_feature_matrix(
neurographs,
features,
model_type,
block_ids=train_blocks,
block_ids=block_ids,
)


Expand Down
58 changes: 35 additions & 23 deletions src/deep_neurographs/neurograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,13 @@ def copy_graph(self, add_attrs=False):
graph = nx.Graph()
graph.add_nodes_from(self.nodes(data=add_attrs))
if add_attrs:
for edge in self.get_edges_temp():
for edge in self.edges:
i, j = tuple(edge)
graph.add_edge(i, j, **self.get_edge_data(i, j))
else:
graph.add_edges_from(self.get_edges_temp())
graph.add_edges_from(self.edges)
return graph

def get_edges_temp(self):
    """
    Gets the edges of the graph that are not proposals.

    Returns
    -------
    list[frozenset]
        Every edge in "self.edges", converted to a frozenset, that is not
        contained in "self.proposals".

    """
    return [
        pair for pair in map(frozenset, self.edges)
        if pair not in self.proposals
    ]

# --- Add nodes or edges ---
def add_swc_id(self, swc_id):
self.swc_ids.add(swc_id)
Expand All @@ -107,17 +99,18 @@ def add_component(self, irreducibles):
# Add edges
for edge, values in irreducibles["edges"].items():
i, j = edge
self.add_edge(
node_ids[i],
node_ids[j],
radius=values["radius"],
xyz=values["xyz"],
swc_id=swc_id,
)
edge = (node_ids[i], node_ids[j])
for xyz in values["xyz"][::2]:
self.xyz_to_edge[tuple(xyz)] = edge
self.xyz_to_edge[tuple(values["xyz"][-1])] = edge
if self.branch_contained(values["xyz"]):
self.add_edge(
node_ids[i],
node_ids[j],
radius=values["radius"],
xyz=values["xyz"],
swc_id=swc_id,
)
edge = (node_ids[i], node_ids[j])
for xyz in values["xyz"][::2]:
self.xyz_to_edge[tuple(xyz)] = edge
self.xyz_to_edge[tuple(values["xyz"][-1])] = edge

def __add_nodes(self, nodes, key, node_ids, cur_id, swc_id):
for i in nodes[key].keys():
Expand All @@ -135,7 +128,7 @@ def __add_nodes(self, nodes, key, node_ids, cur_id, swc_id):
self.junctions.add(cur_id)
cur_id += 1
return node_ids, cur_id

# --- Proposal and Ground Truth Generation ---
def generate_proposals(
self,
Expand Down Expand Up @@ -493,6 +486,12 @@ def is_contained(self, node_or_xyz, buffer=0):
else:
return True

def branch_contained(self, xyz_list):
    """
    Checks whether every point of a branch passes "self.is_contained" with
    "buffer=-32".

    Parameters
    ----------
    xyz_list : iterable
        Points of the branch; each one is passed to "self.is_contained".

    Returns
    -------
    bool
        True if "self.bbox" is falsy or every point is contained, otherwise
        False.

    """
    # Without a bounding box there is nothing to check against
    if not self.bbox:
        return True
    checks = [self.is_contained(xyz, buffer=-32) for xyz in xyz_list]
    return all(checks)

def to_img(self, node_or_xyz, shift=False):
shift = self.origin if shift else np.zeros((3))
if type(node_or_xyz) == int:
Expand Down Expand Up @@ -630,6 +629,17 @@ def filter_nodes(self):
nbs = list(self.neighbors(i))
self.absorb_node(i, nbs[0], nbs[1])

def delete_isolated(self):
    """
    Deletes every node that has degree zero.

    Each deleted node is also removed from "self.leafs" or, if it is not a
    leaf, from "self.junctions".

    Returns
    -------
    None

    """
    isolated = {i for i in self.nodes if self.degree[i] == 0}

    # Keep the leaf and junction bookkeeping consistent with the graph
    for i in isolated:
        if i in self.leafs:
            self.leafs.remove(i)
        elif i in self.junctions:
            self.junctions.remove(i)
    self.remove_nodes_from(isolated)

def absorb_node(self, i, nb_1, nb_2):
# Get attributes
xyz = self.get_branches(i, key="xyz")
Expand Down Expand Up @@ -663,7 +673,9 @@ def merge_proposal(self, edge):

def to_swc(self, path):
    """
    Writes each connected component of the graph to its own swc file.

    Parameters
    ----------
    path : str
        Directory that the swc files are written to. Each file is named
        "{swc_id}.swc", where "swc_id" is read from a node of the component.

    Returns
    -------
    None

    """
    for component in nx.connected_components(self):
        # Components are sets and random.sample() raises a TypeError for
        # sets on Python >= 3.11, so take an arbitrary node directly.
        node = next(iter(component))
        swc_id = self.nodes[node]["swc_id"]

        # NOTE(review): two components with the same swc_id would be written
        # to the same path -- confirm swc_ids are unique per component.
        component_path = os.path.join(path, f"{swc_id}.swc")
        self.component_to_swc(component_path, component)

def component_to_swc(self, path, component):
Expand Down
4 changes: 2 additions & 2 deletions src/deep_neurographs/seedbased_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ def build_from_soma(
pass


def get_swc_ids(path, xyz, chunk_shape, from_center=True):
    """
    Gets the distinct values contained in a chunk of an image.

    Parameters
    ----------
    path : str
        Location of the image, which is opened with
        "utils.open_tensorstore" using the "neuroglancer_precomputed" driver.
    xyz : array-like
        Coordinate passed to "utils.read_tensorstore" to locate the chunk.
    chunk_shape : array-like
        Shape of the chunk that is read.
    from_center : bool, optional
        Forwarded to "utils.read_tensorstore"; presumably indicates whether
        "xyz" is the center of the chunk -- confirm against that routine.
        The default is True.

    Returns
    -------
    set[int]
        Unique values in the chunk, cast to int.

    """
    img = utils.open_tensorstore(path, "neuroglancer_precomputed")
    img = utils.read_tensorstore(
        img, xyz, chunk_shape, from_center=from_center
    )
    return set(fastremap.unique(img).astype(int))


Expand Down
Loading

0 comments on commit 2848d67

Please sign in to comment.