Wandb #80

Merged · 7 commits · Feb 6, 2024
2 changes: 2 additions & 0 deletions configs/params.toml
@@ -1,6 +1,8 @@
record_path = "C:/Users/lorenz/Desktop/angelo_lab/MIBI_test/TNBC_CD45.tfrecord"
path = "C:/Users/lorenz/OneDrive/Desktop/angelo_lab/"
experiment = "test"
project = "Nimbus"
logging_mode = "offline"
model = "ModelBuilder"
num_steps = 20
lr = 1e-3
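The new project and logging_mode keys are consumed by wandb.init in train() (below). A minimal sketch of the wiring, assuming only the shown keys (offline mode queues runs locally for a later `wandb sync`; entity omitted for brevity):

import toml
import wandb

params = toml.load("configs/params.toml")
run = wandb.init(
    project=params["project"],      # "Nimbus"
    mode=params["logging_mode"],    # "offline": log locally, sync later
    config=params,                  # hyperparameters become the run config
)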
7 changes: 5 additions & 2 deletions pyproject.toml
@@ -21,8 +21,11 @@ dependencies = [
"seaborn>=0.12",
"alpineer>=0.1.5",
"natsort>=7.1",
"tensorflow==2.8",
"protobuf<=3.20",
"tensorflow~=2.8.0",
"tensorflow_addons~=0.16.1",
"pydot>=1.4.2,<2",
"protobuf",
"wandb"
]
name = "cell_classification"
authors = [{ name = "Angelo Lab", email = "[email protected]" }]
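The hard tensorflow==2.8 pin becomes a compatible-release pin: ~=2.8.0 accepts patch releases (>=2.8.0, <2.9.0) but excludes minor bumps, and the protobuf<=3.20 cap is dropped. An illustration using the packaging library (illustrative only, not part of the PR):

from packaging.specifiers import SpecifierSet

spec = SpecifierSet("~=2.8.0")  # compatible release: >=2.8.0, <2.9.0
print("2.8.4" in spec)          # True  - patch upgrades allowed
print("2.9.0" in spec)          # False - minor bumps excluded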
2 changes: 1 addition & 1 deletion src/cell_classification/application.py
@@ -103,7 +103,7 @@ def __init__(
# exclude segmentation channel from analysis
seg_name = os.path.basename(self.segmentation_naming_convention(self.fov_paths[0]))
self.exclude_channels.append(seg_name.split(".")[0])
if self.output_dir is not '':
if self.output_dir != '':
os.makedirs(self.output_dir, exist_ok=True)

# initialize model and parent class
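The old condition tested object identity (is not) instead of value equality; whether two equal strings are the same object is implementation-dependent, and CPython 3.8+ emits a SyntaxWarning for is comparisons against literals. A minimal illustration (not from the PR):

output_dir = ""
if output_dir != '':        # value comparison - the intended "non-empty" check
    pass
if output_dir is not '':    # identity comparison with a literal:
    pass                    # SyntaxWarning on CPython 3.8+, result not guaranteed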
24 changes: 0 additions & 24 deletions src/cell_classification/metrics.py
@@ -6,30 +6,6 @@
import pandas as pd
from sklearn.metrics import auc, confusion_matrix, roc_curve

from cell_classification.model_builder import ModelBuilder
from cell_classification.promix_naive import PromixNaive


def load_model(params):
"""Load model and validation data from params dict
Args:
params (dict):
dictionary containing model and validation data paths
Returns:
model (ModelBuilder):
trained model
val_data (tf.data.Dataset):
validation dataset
"""
params["eval"] = True
if params["model"] == "ModelBuilder":
model = ModelBuilder(params)
elif params["model"] == "PromixNaive":
model = PromixNaive(params)
model.prep_data()
model.load_model(params["model_path"])
return model


def calc_roc(pred_list, gt_key="marker_activity_mask", pred_key="prediction", cell_level=False):
"""Calculate ROC curve
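Dropping load_model here resolves a circular import: model_builder.py now imports calc_scores from this module (see below), so metrics.py can no longer import ModelBuilder. A sketch of what a caller could inline instead, mirroring the deleted helper (hypothetical, not part of the PR):

from cell_classification.model_builder import ModelBuilder
from cell_classification.promix_naive import PromixNaive

def load_model(params):
    """Caller-side stand-in for the removed metrics.load_model."""
    params["eval"] = True
    model_cls = PromixNaive if params["model"] == "PromixNaive" else ModelBuilder
    model = model_cls(params)
    model.prep_data()
    model.load_model(params["model_path"])
    return model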
135 changes: 85 additions & 50 deletions src/cell_classification/model_builder.py
@@ -15,14 +15,15 @@
from tensorflow.keras.optimizers.schedules import CosineDecay
from deepcell.semantic_head import create_semantic_head
from tqdm import tqdm

from cell_classification.metrics import calc_scores
from cell_classification.augmentation_pipeline import (
get_augmentation_pipeline, prepare_tf_aug, py_aug)
from cell_classification.loss import Loss
from cell_classification.post_processing import (merge_activity_df,
process_to_cells)
from cell_classification.segmentation_data_prep import (feature_description,
parse_dict)
import wandb


class ModelBuilder:
@@ -42,6 +43,14 @@ def __init__(self, params):
self.prep_batches = self.gen_prep_batches_fn()
# make prep_batches a callable static method
self.prep_batches = staticmethod(self.prep_batches).__func__
# prepare folders
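# (moved here from prep_model so model_dir and log_dir exist before
# wandb.init(dir=...) is called in train())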
self.params["model_dir"] = os.path.join(
os.path.normpath(self.params["path"]), self.params["experiment"]
)
self.params["log_dir"] = os.path.join(self.params["model_dir"], "logs", str(int(time())))
os.makedirs(self.params["model_dir"], exist_ok=True)
os.makedirs(self.params["log_dir"], exist_ok=True)


def prep_data(self):
"""Prepares training and validation data"""
@@ -155,13 +164,6 @@ def prep_data(self):

def prep_model(self):
"""Prepares the model for training"""
# prepare folders
self.params["model_dir"] = os.path.join(
os.path.normpath(self.params["path"]), self.params["experiment"]
)
self.params["log_dir"] = os.path.join(self.params["model_dir"], "logs", str(int(time())))
os.makedirs(self.params["model_dir"], exist_ok=True)
os.makedirs(self.params["log_dir"], exist_ok=True)
if "model_path" not in self.params.keys() or self.params["model_path"] is None:
self.params["model_path"] = os.path.join(
self.params["model_dir"], "{}.h5".format(self.params["experiment"])
@@ -232,6 +234,15 @@ def train(self):
# initialize data and model
self.prep_data()

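# start the W&B run; project and logging_mode come from configs/params.toml,
# and mode="offline" keeps all logs local for a later `wandb sync`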
wandb.init(
name=self.params["experiment"],
project=self.params["project"],
entity="kainmueller-lab",
config=self.params,
dir=self.params["log_dir"],
mode=self.params["logging_mode"]
)

# make transformations on the training dataset
augmentation_pipeline = get_augmentation_pipeline(self.params)
tf_aug = prepare_tf_aug(augmentation_pipeline)
@@ -259,10 +270,8 @@
with open(os.path.join(self.params["model_dir"], "params.toml"), "w") as f:
toml.dump(self.params, f)

self.summary_writer = tf.summary.create_file_writer(self.params["log_dir"])
self.step = 0
self.global_val_loss = []
self.val_loss_history = {}
self.val_f1_history = []
self.train_loss_tmp = []
while self.step < self.params["num_steps"]:
for x, y in tqdm(self.train_dataset):
@@ -272,6 +281,7 @@
self.tensorboard_callbacks(x, y)
if self.step > self.params["num_steps"]:
break
wandb.finish()

def tensorboard_callbacks(self, x, y):
"""Logs training metrics to Tensorboard
@@ -280,13 +290,11 @@
y (tf.Tensor): ground truth labels
"""
if self.step % 10 == 0:
with self.summary_writer.as_default():
tf.summary.scalar(
"train_loss", tf.reduce_mean(self.train_loss_tmp), step=self.step
)
tf.summary.scalar(
"lr", self.model.optimizer._decayed_lr(tf.float32), step=self.step
)
wandb.log({
"train_loss": tf.reduce_mean(self.train_loss_tmp),
"lr": self.model.optimizer._decayed_lr(tf.float32),
"step": self.step
})
print(
"Step: {step}, loss {loss}".format(
step=self.step, loss=tf.reduce_mean(self.train_loss_tmp))
@@ -301,37 +309,64 @@
y = self.strategy.experimental_local_results(y)[0]
else:
y_pred = self.model(x, training=False)
with self.summary_writer.as_default():
tf.summary.image(
"x_0 | y | y_pred",
tf.concat([
x[:1, ..., :1],
x[:1, ..., 1:2] * 0.25 + tf.cast(y[:1, ..., :1], tf.float32),
y_pred[:1, ..., :1]], axis=0,
),
step=self.step,
)
wandb.log({
"x_0": wandb.Image(x[:1, ..., :1]),
"y": wandb.Image(x[:1, ..., 1:2] * 0.25 + tf.cast(y[:1, ..., :1], tf.float32)),
"y_pred": wandb.Image(y_pred[:1, ..., :1]),
"step": self.step
})
# run validation and log metrics to wandb
if self.step % self.params["val_steps"] == 0:
print("Running validation...")
metric_dict = {}
for validation_dataset, dataset_name in zip(
self.validation_datasets, self.dataset_names
):
validation_dataset = validation_dataset.map(
self.prep_batches, num_parallel_calls=tf.data.AUTOTUNE
)
val_loss = self.model.evaluate(validation_dataset, verbose=1)
print("Validation loss:", val_loss)
if dataset_name not in self.val_loss_history.keys():
self.val_loss_history[dataset_name] = []
self.val_loss_history[dataset_name].append(val_loss)
with self.summary_writer.as_default():
tf.summary.scalar(dataset_name + "_val", val_loss, step=self.step)
val_loss = np.mean([val_loss[-1] for val_loss in self.val_loss_history.values()])
self.global_val_loss.append(val_loss)
with self.summary_writer.as_default():
tf.summary.scalar("global_val", val_loss, step=self.step)
if val_loss <= tf.reduce_min(self.global_val_loss):
activity_df = self.predict_dataset_list(validation_dataset, save_predictions=False)
for marker in activity_df.marker.unique():
tmp_df = activity_df[activity_df.marker == marker]
metrics = calc_scores(
gt=tmp_df["activity"].values, pred=tmp_df["prediction"].values, threshold=0.5
)
metric_dict[dataset_name + "/" + marker] = {
"precision": metrics["precision"],
"recall": metrics["recall"],
"f1_score": metrics["f1_score"],
"specificity": metrics["specificity"],
}
# average over all markers
metric_dict[dataset_name + "/avg"] = {
"precision": np.mean(
[v["precision"] for k, v in metric_dict.items() if dataset_name in k]
),
"recall": np.mean(
[v["recall"] for k, v in metric_dict.items() if dataset_name in k]
),
"f1_score": np.mean(
[v["f1_score"] for k, v in metric_dict.items() if dataset_name in k]
),
"specificity": np.mean(
[v["specificity"] for k, v in metric_dict.items() if dataset_name in k]
),
}
# average over all datasets
metric_dict["avg"] = {
"precision": np.mean(
[v["precision"] for k, v in metric_dict.items() if "avg" in k]
),
"recall": np.mean(
[v["recall"] for k, v in metric_dict.items() if "avg" in k]
),
"f1_score": np.mean(
[v["f1_score"] for k, v in metric_dict.items() if "avg" in k]
),
"specificity": np.mean(
[v["specificity"] for k, v in metric_dict.items() if "avg" in k]
),
}
self.val_f1_history.append(metric_dict["avg"]["f1_score"])
wandb.log(metric_dict)
if metric_dict["avg"]["f1_score"] >= tf.reduce_max(self.val_f1_history):
print("Saving model to", self.params["model_path"])
self.model.save_weights(self.params["model_path"])
# run external validation
@@ -347,8 +382,7 @@
if dataset_name not in self.val_loss_history.keys():
self.val_loss_history[dataset_name] = []
self.val_loss_history[dataset_name].append(val_loss)
with self.summary_writer.as_default():
tf.summary.scalar(dataset_name + "_val", val_loss, step=self.step)
wandb.log({dataset_name + "_val": val_loss})
if "save_model_on_dataset_name" in self.params.keys():
current = self.val_loss_history[self.params["save_model_on_dataset_name"]][-1]
if current <= self.best_val_loss[self.params["save_model_on_dataset_name"]]:
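The per-marker, per-dataset, and global averaging blocks above repeat one pattern; a hypothetical helper (not part of the PR) makes the intent explicit:

import numpy as np

def average_scores(metric_dict, key_filter):
    """Mean of each score over entries whose key contains key_filter."""
    return {
        score: float(np.mean(
            [v[score] for k, v in metric_dict.items() if key_filter in k]
        ))
        for score in ("precision", "recall", "f1_score", "specificity")
    }

# metric_dict[dataset_name + "/avg"] = average_scores(metric_dict, dataset_name)
# metric_dict["avg"] = average_scores(metric_dict, "avg")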
@@ -636,7 +670,6 @@ def predicate(dataset, marker):
)
return dataset


def predict_dataset_list(
self, datasets, save_predictions=True, fname="predictions", ckpt_path=None
):
@@ -652,17 +685,18 @@
if ckpt_path:
self.load_model(ckpt_path)
print("Loaded model from", ckpt_path)
else:
elif not hasattr(self, "model") or self.model is None:
self.load_model(self.params["model_path"])
print("Loaded model from", self.params["model_path"])
if isinstance(datasets, tf.data.Dataset):
datasets = [datasets]

df_list = []
for dataset in datasets:
dataset = dataset.prefetch(tf.data.AUTOTUNE)
for example in dataset:
for j, example in enumerate(dataset):
x_batch, _ = self.prep_batches(example)
prediction = self.predict(x_batch)

for i, df in enumerate(example["activity_df"].numpy()):
df = pd.read_json(df.decode())
cell_ids, mean_per_cell = segment_mean(
Expand All @@ -675,6 +709,7 @@ def predict_dataset_list(
df["marker"] = [example["marker"].numpy()[i].decode()] * len(df)
df_list.append(df)
df = pd.concat(df_list)
df = df[df["labels"] != 0] # remove background
if save_predictions:
df.to_csv(os.path.join(self.params["model_dir"], fname+".csv"))
return df
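A usage sketch for the reworked loading guard (attribute names taken from prep_data, flow assumed): an explicit ckpt_path wins, an already-built model is reused, and weights on disk are only read as a fallback:

builder = ModelBuilder(toml.load("configs/params.toml"))
builder.prep_data()
# no ckpt_path and no model built yet -> weights load from params["model_path"]
df = builder.predict_dataset_list(builder.validation_datasets, save_predictions=False)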
@@ -691,4 +726,4 @@
args = parser.parse_args()
params = toml.load(args.params)
trainer = ModelBuilder(params)
trainer.train()
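# assumed invocation (flag name inferred from args.params; the add_argument
# call sits in the collapsed part of the diff):
#   python src/cell_classification/model_builder.py --params configs/params.toml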