Skip to content

Commit

Permalink
Added magnification levels for scaling and extended Multiplexed datas…
Browse files Browse the repository at this point in the history
…et for GT
  • Loading branch information
JLrumberger committed Nov 14, 2024
1 parent d5e12b9 commit 17d89f6
Show file tree
Hide file tree
Showing 6 changed files with 419 additions and 192 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies = [
"natsort",
"ipython",
"zarr",
"lmdb",
]

[[project.source]]
Expand Down
44 changes: 8 additions & 36 deletions src/nimbus_inference/nimbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,23 +90,25 @@ class Nimbus(nn.Module):
dataset (MultiplexDataset): Path to directory containing fovs.
output_dir (str): Path to directory to save output.
save_predictions (bool): Whether to save predictions.
half_resolution (bool): Whether to run model on half resolution images.
model_magnification (int): Expected magnification of images.
batch_size (int): Batch size for model inference.
test_time_aug (bool): Whether to use test time augmentation.
input_shape (list): Shape of input images.
suffix (str): Suffix of images to load.
device (str): Device to run model on, either "auto" (either "mps" or "cuda"
, with "cpu" as a fallback), "cpu", "cuda", or "mps". Defaults to "auto".
checkpoint: which checkpoint to use for the model, either "latest" or one of the local
checkpoints.
"""
def __init__(
self, dataset: MultiplexDataset, output_dir: str, save_predictions: bool=True,
half_resolution: bool=True, batch_size: int=4, test_time_aug: bool=True,
input_shape: list=[1024, 1024], device: str="auto",
batch_size: int=4, test_time_aug: bool=True, model_magnification: int=10,
input_shape: list=[1024, 1024], device: str="auto",
):
super(Nimbus, self).__init__()
self.dataset = dataset
self.output_dir = output_dir
self.half_resolution = half_resolution
self.model_magnification = model_magnification
self.save_predictions = save_predictions
self.batch_size = batch_size
self.checked_inputs = False
Expand Down Expand Up @@ -201,32 +203,6 @@ def initialize_model(self, padding="reflect"):
print(f"Loaded weights from {self.checkpoint_path}")
self.model = model.to(self.device).eval()

def prepare_normalization_dict(
self, quantile=0.999, clip_values=(0, 2), n_subset=10, multiprocessing=False,
overwrite=False,
):
"""Load or prepare and save normalization dictionary for Nimbus model.
Args:
quantile (float): Quantile to use for normalization.
clip_values (list): Values to clip images to after normalization.
n_subset (int): Number of fovs to use for normalization.
multiprocessing (bool): Whether to use multiprocessing.
overwrite (bool): Whether to overwrite existing normalization dict.
Returns:
dict: Dictionary of normalization factors.
"""
self.clip_values = tuple(clip_values)
self.normalization_dict_path = os.path.join(self.output_dir, "normalization_dict.json")
if os.path.exists(self.normalization_dict_path) and not overwrite:
self.normalization_dict = json.load(open(self.normalization_dict_path))
else:
n_jobs = os.cpu_count() if multiprocessing else 1
self.normalization_dict = prepare_normalization_dict(
self.dataset, self.output_dir, quantile, n_subset,
n_jobs
)

def predict_fovs(self):
"""Predicts cell classification for input data.
Expand All @@ -235,18 +211,15 @@ def predict_fovs(self):
"""
if self.checked_inputs == False:
self.check_inputs()
if not hasattr(self, "normalization_dict"):
self.prepare_normalization_dict()
# check if GPU is available
gpus = torch.cuda.device_count()
print("Available GPUs: ", gpus)
print("Predictions will be saved in {}".format(self.output_dir))
print("Iterating through fovs will take a while...")
self.cell_table = predict_fovs(
nimbus=self, dataset=self.dataset, output_dir=self.output_dir,
normalization_dict=self.normalization_dict, save_predictions=self.save_predictions,
half_resolution=self.half_resolution, batch_size=self.batch_size,
test_time_augmentation=self.test_time_aug, suffix=self.dataset.suffix,
save_predictions=self.save_predictions, batch_size=self.batch_size,
test_time_augmentation=self.test_time_aug,
)
self.cell_table.to_csv(os.path.join(self.output_dir, "nimbus_cell_table.csv"), index=False)
return self.cell_table
Expand All @@ -261,7 +234,6 @@ def predict_segmentation(self, input_data, preprocess_kwargs):
Returns:
np.array: Predicted segmentation.
"""
preprocess_kwargs["clip_values"] = self.clip_values
input_data = nimbus_preprocess(input_data, **preprocess_kwargs)
if np.all(np.greater_equal(self.input_shape, input_data.shape[-2:])):
if not hasattr(self, "model") or self.model.padding != "reflect":
Expand Down
176 changes: 159 additions & 17 deletions src/nimbus_inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import sys, os
import logging
import os, sys
import lmdb


class HidePrints:
Expand Down Expand Up @@ -149,10 +150,17 @@ class MultiplexDataset():
path
suffix (str): suffix of channel images
silent (bool): whether to print messages
groundtruth_df (pd.DataFrame): groundtruth dataframe with columns fov, cell_id, channel and
activity (0: negative, 1: positive, 2: ambiguous)
magnification (int): magnification factor of the images (default: 20)
validation_fovs (list): list of fovs to use for validation
output_dir (str): path to output directory
"""
def __init__(
self, fov_paths: list, segmentation_naming_convention: Callable = None,
include_channels: list = [], suffix: str = ".tiff", silent=False,
include_channels: list = [], suffix: str = ".tiff", silent: bool = False,
groundtruth_df: pd.DataFrame = None, magnification: int = 20,
validation_fovs: list = [], output_dir: str = ""
):
self.fov_paths = fov_paths
self.segmentation_naming_convention = segmentation_naming_convention
Expand All @@ -166,6 +174,17 @@ def __init__(
self.check_inputs()
self.fovs = self.get_fovs()
self.channels = self.filter_channels(self.channels)
self.groundtruth_df = groundtruth_df
self.magnification = magnification
self.output_dir = output_dir

if validation_fovs and groundtruth_df is not None:
self.validation_fovs = validation_fovs
self.training_fovs = [fov for fov in self.fovs if fov not in self.validation_fovs]
elif groundtruth_df is not None:
num_validation_fovs = len(self.fovs)//10 if len(self.fovs) > 10 else 1
self.validation_fovs = self.fovs[-num_validation_fovs:]
self.training_fovs = self.fovs[:-num_validation_fovs]

def filter_channels(self, channels):
"""Filter channels based on include_channels
Expand All @@ -178,6 +197,31 @@ def filter_channels(self, channels):
if self.include_channels:
return [channel for channel in channels if channel in self.include_channels]
return channels

def get_groundtruth(self, fov: str, channel: str):
"""Get the groundtruth for a fov / channel combination
Args:
fov (str): name of a fov
channel (str): channel name
Returns:
np.array: groundtruth activity mask (0: negative, 1: positive, 2: ambiguous)
"""
if self.groundtruth_df is None:
raise ValueError("No groundtruth dataframe provided.")
subset_df = self.groundtruth_df[
(self.groundtruth_df["fov"] == fov) & (self.groundtruth_df["channel"] == channel)
]
positive_cells = subset_df[subset_df["activity"] == 1].cell_id.values
ambiguous_cells = subset_df[subset_df["activity"] == 2].cell_id.values
instance_mask = self.get_segmentation(fov)
groundtruth = np.zeros_like(instance_mask)
# get all positions of positive cells in instance mask without a for loop
positive_positions = np.where(np.isin(instance_mask, positive_cells))
ambiguous_positions = np.where(np.isin(instance_mask, ambiguous_cells))
groundtruth[positive_positions] = 1
groundtruth[ambiguous_positions] = 2
return groundtruth[np.newaxis,...] # 1, h, w

def check_inputs(self):
"""Check inputs for Nimbus model"""
Expand Down Expand Up @@ -292,6 +336,32 @@ def get_segmentation(self, fov: str):
instance_mask = instance_mask.astype(np.uint32)
return instance_mask

def prepare_normalization_dict(
self, quantile=0.999, clip_values=(0, 2), n_subset=10, multiprocessing=False,
overwrite=False,
):
"""Load or prepare and save normalization dictionary for Nimbus model.
Args:
quantile (float): Quantile to use for normalization.
clip_values (list): Values to clip images to after normalization.
n_subset (int): Number of fovs to use for normalization.
multiprocessing (bool): Whether to use multiprocessing.
overwrite (bool): Whether to overwrite existing normalization dict.
Returns:
dict: Dictionary of normalization factors.
"""
self.clip_values = tuple(clip_values)
self.normalization_dict_path = os.path.join(self.output_dir, "normalization_dict.json")
if os.path.exists(self.normalization_dict_path) and not overwrite:
self.normalization_dict = json.load(open(self.normalization_dict_path))
self.normalization_dict = {k: float(v) for k, v in self.normalization_dict.items()}
else:
n_jobs = os.cpu_count() if multiprocessing else 1
self.normalization_dict = prepare_normalization_dict(
self, self.output_dir, quantile, n_subset,
n_jobs
)

def prepare_input_data(mplex_img, instance_mask):
"""Prepares the input data for the segmentation model
Expand Down Expand Up @@ -327,7 +397,8 @@ def segment_mean(instance_mask, prediction):


def test_time_aug(
input_data, channel, app, normalization_dict, rotate=True, flip=True, batch_size=4
input_data, channel, app, normalization_dict, rotate=True, flip=True, batch_size=4,
clip_values=(0, 2)
):
"""Performs test time augmentation
Expand Down Expand Up @@ -367,7 +438,8 @@ def test_time_aug(
preprocess_kwargs={
"normalize": True,
"marker": channel,
"normalization_dict": normalization_dict},
"normalization_dict": normalization_dict,
"clip_values": clip_values},
)
if not isinstance(seg_map, torch.Tensor):
seg_map = torch.from_numpy(seg_map)
Expand All @@ -380,20 +452,17 @@ def test_time_aug(


def predict_fovs(
nimbus, dataset: MultiplexDataset, normalization_dict: dict,
output_dir: str, suffix: str="tiff", save_predictions: bool=True,
half_resolution: bool=False, batch_size: int=4, test_time_augmentation: bool=True
nimbus, dataset: MultiplexDataset, output_dir: str, suffix: str=".tiff",
save_predictions: bool=True, batch_size: int=4, test_time_augmentation: bool=True
):
"""Predicts the segmentation map for each mplex image in each fov
Args:
nimbus (Nimbus): nimbus object
dataset (MultiplexDataset): dataset object
normalization_dict (dict): dict with channel names as keys and norm factors as values
output_dir (str): path to output dir
suffix (str): suffix of mplex images
save_predictions (bool): whether to save predictions
half_resolution (bool): whether to use half resolution
batch_size (int): batch size
test_time_augmentation (bool): whether to use test time augmentation
Returns:
Expand All @@ -410,8 +479,8 @@ def predict_fovs(
for channel_name in tqdm(dataset.channels):
mplex_img = dataset.get_channel(fov, channel_name)
input_data = prepare_input_data(mplex_img, instance_mask)
if half_resolution:
scale = 0.5
if dataset.magnification != nimbus.model_magnification:
scale = nimbus.model_magnification / dataset.magnification
input_data = np.squeeze(input_data)
_, h,w = input_data.shape
img = cv2.resize(input_data[0], [int(w*scale), int(h*scale)])
Expand All @@ -421,20 +490,22 @@ def predict_fovs(
input_data = np.stack([img, binary_mask], axis=0)[np.newaxis,...]
if test_time_augmentation:
prediction = test_time_aug(
input_data, channel_name, nimbus, normalization_dict, batch_size=batch_size
input_data, channel_name, nimbus, dataset.normalization_dict,
batch_size=batch_size, clip_values=dataset.clip_values
)
else:
prediction = nimbus.predict_segmentation(
input_data,
preprocess_kwargs={
"normalize": True, "marker": channel_name,
"normalization_dict": normalization_dict
"normalization_dict": dataset.normalization_dict,
"clip_values": dataset.clip_values
},
)
if not isinstance(prediction, np.ndarray):
prediction = prediction.cpu().numpy()
prediction = np.squeeze(prediction)
if half_resolution:
if dataset.magnification != nimbus.model_magnification:
prediction = cv2.resize(prediction, (w, h), interpolation=cv2.INTER_NEAREST)
df = pd.DataFrame(segment_mean(instance_mask, prediction))
if df_fov.empty:
Expand All @@ -445,7 +516,8 @@ def predict_fovs(
os.makedirs(out_fov_path, exist_ok=True)
pred_int = (prediction*255.0).astype(np.uint8)
io.imwrite(
os.path.join(out_fov_path, channel_name + suffix), pred_int, photometric="minisblack",
os.path.join(out_fov_path, channel_name + suffix), pred_int,
photometric="minisblack",
# compress=0,
)
fov_dict_list.append(df_fov)
Expand Down Expand Up @@ -557,8 +629,78 @@ def prepare_normalization_dict(
norm_values = np.mean(norm_values)
if np.isnan(norm_values):
norm_values = 1e-8
normalization_dict[channel] = norm_values
# save normalization dict
normalization_dict[channel] = float(norm_values)
# turn numbers to strings and save normalization dict
normalization_dict_str = {k: str(v) for k, v in normalization_dict.items()}
with open(os.path.join(output_dir, output_name), 'w') as f:
json.dump(normalization_dict, f)
json.dump(normalization_dict_str, f)
return normalization_dict


def prepare_training_data(
nimbus, dataset: MultiplexDataset, output_dir: str, tile_size: int=512,
map_size: int=5
):
"""Prepares the training data and stores it into lmdb files for fine-tuning Nimbus
Args:
nimbus (Nimbus): nimbus object
dataset (MultiplexDataset): dataset object
output_dir (str): path to output directory
tile_size (int): size of the training tiles
map_size (int): size of the lmdb database in gigabytes
"""
# create lmdb env
for split, fovs in ((
"training", dataset.training_fovs),
("validation", dataset.validation_fovs)):
print(f"Preparing {split} data. storing data in {os.path.join(output_dir, split)}")
env = lmdb.open(
os.path.join(output_dir, split),
map_size=map_size*1024**3,
map_async=True,
max_dbs=0,
create=True
)
with env.begin(write=True) as txn:
for fov in tqdm(fovs):
instance_mask = dataset.get_segmentation(fov)
for channel_name in dataset.channels:
# load data
mplex_img = dataset.get_channel(fov, channel_name)
input_data = prepare_input_data(mplex_img, instance_mask)
input_data = nimbus_preprocess(
input_data, normalize=True, marker=channel_name,
normalization_dict=dataset.normalization_dict
)
groundtruth = dataset.get_groundtruth(fov, channel_name)
# resize data if necessary
if dataset.magnification != nimbus.model_magnification:
scale = nimbus.model_magnification / dataset.magnification
input_data = np.squeeze(input_data)
groundtruth = np.squeeze(groundtruth)
_, h,w = input_data.shape
img = cv2.resize(input_data[0], [int(w*scale), int(h*scale)])
binary_mask = cv2.resize(
input_data[1], [int(w*scale), int(h*scale)], interpolation=0
)
input_data = np.stack([img, binary_mask], axis=0) # 2, h, w
groundtruth = cv2.resize(
groundtruth.astype(np.uint8), [int(w*scale), int(h*scale)],
interpolation=0
)[np.newaxis, ...] # 1, h, w
# mirror pad and tile data
h, w = input_data.shape[-2:]
h_pad = h % tile_size
w_pad = w % tile_size
input_data = np.pad(input_data, ((0, 0), (0,h_pad), (0,w_pad)), mode="reflect")
groundtruth = np.pad(groundtruth, ((0, 0), (0,h_pad), (0,w_pad)), mode="reflect")
h, w = input_data.shape[-2:]
for i in range(0, h, tile_size):
for j in range(0, w, tile_size):
input_tile = input_data[..., i:i+tile_size, j:j+tile_size] # 2, h, w
gt_tile = groundtruth[..., i:i+tile_size, j:j+tile_size] # 1, h, w
sample_tile = np.concatenate([input_tile, gt_tile], axis=0) # 3, h, w
tile_key = f"{fov}_{channel_name}_{i}_{j}"
txn.put(tile_key.encode(), sample_tile.tobytes())

Loading

0 comments on commit 17d89f6

Please sign in to comment.