From c7793a8524e7893ca321672a8a87101ba2a8977d Mon Sep 17 00:00:00 2001 From: Lenz Date: Tue, 13 Feb 2024 15:48:52 +0100 Subject: [PATCH 01/31] Added torch model --- ...ec_mskc_mskp_2_channel_halfres_512_bs32.pt | 3 + src/nimbus_inference/unet.py | 267 ++++++++++++++++++ 2 files changed, 270 insertions(+) create mode 100644 src/nimbus_inference/assets/resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt create mode 100644 src/nimbus_inference/unet.py diff --git a/src/nimbus_inference/assets/resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt b/src/nimbus_inference/assets/resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt new file mode 100644 index 0000000..6d98b5e --- /dev/null +++ b/src/nimbus_inference/assets/resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a9f607c19ecf3b1ff998b52b75a8290efefb315ba2808e7d314e6444c8fda885 +size 142545376 diff --git a/src/nimbus_inference/unet.py b/src/nimbus_inference/unet.py new file mode 100644 index 0000000..3442b53 --- /dev/null +++ b/src/nimbus_inference/unet.py @@ -0,0 +1,267 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def _get_filter_count(layer_idx, filters_root): + return 2 ** layer_idx * filters_root + + +class Pad2D(nn.Module): + def __init__(self, padding=(1, 1), mode="constant"): + """ Padding for 2D input (e.g. images). + Args: + padding: tuple of 2 ints, how many zeros to add at the beginning and at the end of + the 2 padding dimensions (rows and cols) + mode: "constant", "reflect", or "replicate" + """ + super(Pad2D, self).__init__() + if mode not in ["constant", "reflect", "replicate", "valid"]: + raise ValueError("Padding mode must be 'valid', 'constant', 'reflect', or 'replicate'") + self.padding = (padding[1], padding[1], padding[0], padding[0]) # PyTorch expects padding in reverse order + self.mode = mode + + def forward(self, x): + if self.mode =="valid": + return x + else: + return nn.functional.pad(x, self.padding, mode=self.mode) + + +def maybe_crop(x, target_shape, data_format): + """Center crops x to target_shape if necessary. + Args: + x: input tensor + target_shape: shape of a reference tensor in BHWC or BCHW format + data_format: data format, either "channels_last" or "channels_first" + Returns: + cropped tensor + """ + if data_format == "channels_last": + target_shape = target_shape[1:3] + x_shape = x.size()[1:3] + h_diff = x_shape[0] - target_shape[0] + w_diff = x_shape[1] - target_shape[1] + h_crop_start = h_diff // 2 + w_crop_start = w_diff // 2 + h_crop_end = h_diff - h_crop_start + w_crop_end = w_diff - w_crop_start + x = x[:, h_crop_start:-h_crop_end, w_crop_start:-w_crop_end, :] + elif data_format == "channels_first": + target_shape = target_shape[2:4] + x_shape = x.size()[2:4] + h_diff = x_shape[0] - target_shape[0] + w_diff = x_shape[1] - target_shape[1] + h_crop_start = h_diff // 2 + w_crop_start = w_diff // 2 + h_crop_end = h_diff - h_crop_start + w_crop_end = w_diff - w_crop_start + x = x[:, :, h_crop_start:-h_crop_end, w_crop_start:-w_crop_end] + return x + + +class ConvBlock(nn.Module): + """Convolutional block consisting of two convolutional layers with same number of filters + and a batch normalization layer in between. + """ + def __init__( + self, layer_idx, filters_root, kernel_size, padding, activation, up=False, **kwargs, + ): + """Initialize ConvBlock. 
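Two conventions above are easy to misread: _get_filter_count doubles the channel width at every depth level, and Pad2D reverses its tuple because torch.nn.functional.pad expects padding for the last dimension first, i.e. (left, right, top, bottom) for an NCHW tensor. A minimal standalone sketch of both (plain torch, no project code assumed):

    import torch
    import torch.nn.functional as F

    # filters per level for filters_root=64: 64, 128, 256, ...
    assert [2 ** i * 64 for i in range(3)] == [64, 128, 256]

    x = torch.zeros(1, 1, 4, 6)                              # NCHW
    rows, cols = 1, 2
    y = F.pad(x, (cols, cols, rows, rows), mode="constant")  # (left, right, top, bottom)
    assert y.shape == (1, 1, 4 + 2 * rows, 6 + 2 * cols)
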
+ Args: + layer_idx: index of the layer, used to compute the number of filters + filters_root: number of filters in the first convolutional layer + kernel_size: size of convolutional kernels + padding: padding, either "VALID", "CONSTANT", "REFLECT", or "SYMMETRIC" + activation: activation to be used + data_format: data format, either "channels_last" or "channels_first" + """ + super(ConvBlock, self).__init__(**kwargs) + self.layer_idx=layer_idx + self.filters_root=filters_root + self.kernel_size=kernel_size + self.padding=padding + if layer_idx == 0: + # get in_channels from kwargs + in_ch = kwargs.get('in_channels', 2) + self.activation=getattr(nn, activation)() + filters = _get_filter_count(layer_idx, filters_root) + self.padding_layer = Pad2D(padding=(1, 1), mode=padding) + + if up: + self.conv2d_0 = nn.Conv2d( + in_channels=_get_filter_count(layer_idx+1, filters_root), + out_channels=filters, kernel_size=(1, 1), stride=1, padding='valid' + ) + else: + self.conv2d_0 = nn.Conv2d( + in_channels=_get_filter_count(layer_idx-1, filters_root) if layer_idx != 0 else in_ch, + out_channels=filters, kernel_size=(1, 1), stride=1, padding='valid' + ) + self.conv2d_1 = nn.Conv2d( + in_channels=filters, out_channels=filters, + kernel_size=(3, 3), stride=1, padding='valid' + ) + self.bn_1 = nn.BatchNorm2d(filters) + self.conv2d_2 = nn.Conv2d( + in_channels=filters, out_channels=filters,kernel_size=(3, 3), stride=1, padding='valid' + ) + self.bn_2 = nn.BatchNorm2d(filters) + + + + def forward(self, x): + skip = self.conv2d_0(x) + x = self.padding_layer(skip) + x = self.conv2d_1(x) + x = self.bn_1(x) + x = self.activation(x) + x = self.padding_layer(x) + x = self.conv2d_2(x) + x = self.bn_2(x) + x = self.activation(x) + if self.padding == "valid": + skip = maybe_crop(skip, x.size(), self.data_format) + x = x + skip + return x + + +class UpconvBlock(nn.Module): + """Upconvolutional block consisting of an upsampling layer and a convolutional layer. + """ + def __init__( + self, layer_idx, filters_root, kernel_size, pool_size, padding, activation, **kwargs + ): + """UpconvBlock initializer. + Args: + layer_idx: index of the layer, used to compute the number of filters + filters_root: number of filters in the first convolutional layer + kernel_size: size of convolutional kernels + pool_size: size of the pooling layer + padding: padding, either "VALID", "CONSTANT", "REFLECT", or "SYMMETRIC" + activation: activation to be used + data_format: data format, either "channels_last" or "channels_first" + """ + super(UpconvBlock, self).__init__(**kwargs) + self.layer_idx=layer_idx + self.filters_root=filters_root + self.kernel_size=kernel_size + self.pool_size=pool_size + self.padding=padding + self.activation=activation + + filters = _get_filter_count(layer_idx, filters_root) + self.padding_layer = Pad2D(padding=(1, 1), mode=padding) + self.upconv = nn.ConvTranspose2d( + in_channels=filters, out_channels=filters // 2, + kernel_size=pool_size, stride=pool_size, padding=0 + ) + self.activation_1 = getattr(nn, activation)() + + def forward(self, x): + # x = self.padding_layer(x) + x = self.upconv(x) + x = self.activation_1(x) + return x + + +class CropConcatBlock(nn.Module): + """CropConcatBlock that crops spatial dimensions and concatenates filter maps. + """ + def __init__(self, **kwargs): + """CropConcatBlock initializer. 
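The forward pass above is a residual block: a 1x1 projection forms the skip, two padded 3x3 convolutions form the body, and with "valid" padding each unpadded convolution shrinks the feature map by one pixel per border, so the skip must be center-cropped before the addition (which is what maybe_crop does). A shape-only sketch of that bookkeeping, using bare convolutions in place of ConvBlock:

    import torch
    import torch.nn as nn

    x = torch.rand(1, 8, 64, 64)
    skip = nn.Conv2d(8, 16, kernel_size=1)(x)      # 1x1 projection, stays 64x64
    body = nn.Conv2d(16, 16, kernel_size=3)(skip)  # valid: 64 -> 62
    body = nn.Conv2d(16, 16, kernel_size=3)(body)  # valid: 62 -> 60
    crop = (skip.shape[-1] - body.shape[-1]) // 2  # 2 px per side
    out = body + skip[..., crop:-crop, crop:-crop]
    assert out.shape[-2:] == (60, 60)
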
+ Args: + data_format: data format, either "channels_last" or "channels_first" + """ + super(CropConcatBlock, self).__init__(**kwargs) + + def forward(self, x, down_layer): + """Apply CropConcatBlock to inputs. + Args: + x: input tensor + down_layer: tensor from the contracting path + Returns: + output tensor + """ + x1_shape = down_layer.shape + x2_shape = x.shape + height_diff = abs(x1_shape[2] - x2_shape[2]) // 2 + width_diff = abs(x1_shape[3] - x2_shape[3]) // 2 + down_layer_cropped = down_layer[:,:, + height_diff: (x2_shape[2] + height_diff), + width_diff: (x2_shape[3] + width_diff)] + x = torch.cat([down_layer_cropped, x], dim=1) + return x + + +class UNet(nn.Module): + def __init__(self, nx: int = 512, + ny: int = 512, + channels: int = 1, + num_classes: int = 2, + layer_depth: int = 5, + filters_root: int = 64, + data_format = "channels_first", + kernel_size: int = 3, + pool_size: int = 2, + padding: str = "reflect", + activation: str = 'ReLU'): + super(UNet, self).__init__() + + self.layer_depth = layer_depth + self.contracting_layers = nn.ModuleList() + self.expanding_layers = nn.ModuleList() + + for layer_idx in range(layer_depth-1): + conv_block = ConvBlock( + layer_idx = layer_idx, filters_root=filters_root, kernel_size=kernel_size, + padding=padding, activation=activation) + self.contracting_layers.append(conv_block) + self.contracting_layers.append( + nn.MaxPool2d(kernel_size=(pool_size, pool_size)) + ) + self.bottle_neck = ConvBlock( + layer_idx=layer_idx+1, filters_root=filters_root, kernel_size=kernel_size, + padding=padding, activation=activation + ) + for layer_idx in range(layer_depth-2, -1, -1): + upconv_block = UpconvBlock( + layer_idx=layer_idx+1, filters_root=filters_root, kernel_size=kernel_size, + padding=padding, activation=activation, pool_size=(pool_size, pool_size) + ) + crop_concat_block = CropConcatBlock() + conv_block = ConvBlock( + layer_idx=layer_idx, filters_root=filters_root, kernel_size=kernel_size, + padding=padding, activation=activation, up=True + ) + self.expanding_layers += [upconv_block, crop_concat_block, conv_block] + + self.final_conv = nn.Conv2d(filters_root, num_classes, kernel_size=(1, 1)) + if data_format == "channels_last": + self.to(memory_format=torch.channels_last) + + def forward(self, x): + contracting_outputs = [] + for layer in self.contracting_layers: + x = layer(x) + if isinstance(layer, ConvBlock): + contracting_outputs.append(x) + x = self.bottle_neck(x) + + for layer in self.expanding_layers: + if isinstance(layer, CropConcatBlock): + x = layer(x, contracting_outputs.pop(-1)) + else: + x = layer(x) + x = self.final_conv(x) + return torch.sigmoid(x) + + +if __name__ == "__main__": + model = UNet(num_classes=1) + x = torch.rand(1,2,512,512) + out = model(x) + from torchsummary import summary + model = model.cuda() + summary(model, (2, 512, 512)) + print(out.shape) \ No newline at end of file From 376af208d6146e24099e03b26bcb03f0b9f8bbaf Mon Sep 17 00:00:00 2001 From: Lenz Date: Fri, 16 Feb 2024 15:36:08 +0100 Subject: [PATCH 02/31] Added tile and stitch inference on nimbus object --- src/nimbus_inference/nimbus.py | 332 +++++++++++++++++++++++++++++++++ src/nimbus_inference/unet.py | 17 +- src/nimbus_inference/utils.py | 272 +++++++++++++++++++++++++++ tests/test_nimbus.py | 88 +++++++++ 4 files changed, 701 insertions(+), 8 deletions(-) create mode 100644 src/nimbus_inference/nimbus.py create mode 100644 src/nimbus_inference/utils.py create mode 100644 tests/test_nimbus.py diff --git a/src/nimbus_inference/nimbus.py 
b/src/nimbus_inference/nimbus.py new file mode 100644 index 0000000..247a9fd --- /dev/null +++ b/src/nimbus_inference/nimbus.py @@ -0,0 +1,332 @@ +# nimbus model class +# loads the image +# does inference on images, decides if tile & stitch or whole image inference +# preprocesses the predictions +# saves the output to the output folder + +from alpineer import io_utils +from skimage.util.shape import view_as_windows +import nimbus_inference +from nimbus_inference.utils import prepare_normalization_dict, predict_fovs, predict_ome_fovs, nimbus_preprocess +from nimbus_inference.unet import UNet +from tqdm.autonotebook import tqdm +from pathlib import Path +from torch import nn +from glob import glob +import numpy as np +import torch +import json +import os + + +def nimbus_preprocess(image, **kwargs): + """Preprocess input data for Nimbus model. + Args: + image: array to be processed + Returns: + np.array: processed image array + """ + output = np.copy(image) + if len(image.shape) != 4: + raise ValueError("Image data must be 4D, got image of shape {}".format(image.shape)) + + normalize = kwargs.get('normalize', True) + if normalize: + marker = kwargs.get('marker', None) + normalization_dict = kwargs.get('normalization_dict', {}) + if marker in normalization_dict.keys(): + norm_factor = normalization_dict[marker] + else: + print("Norm_factor not found for marker {}, calculating directly from the image. \ + ".format(marker)) + norm_factor = np.quantile(output[..., 0], 0.999) + # normalize only marker channel in chan 0 not binary mask in chan 1 + output[..., 0] /= norm_factor + output = output.clip(0, 1) + return output + + +def prep_naming_convention(deepcell_output_dir): + """Prepares the naming convention for the segmentation data + Args: + deepcell_output_dir (str): path to directory where segmentation data is saved + Returns: + segmentation_naming_convention (function): function that returns the path to the + segmentation data for a given fov + """ + def segmentation_naming_convention(fov_path): + """Prepares the path to the segmentation data for a given fov + Args: + fov_path (str): path to fov + Returns: + seg_path (str): paths to segmentation fovs + """ + fov_name = os.path.basename(fov_path) + return os.path.join( + deepcell_output_dir, fov_name + "_whole_cell.tiff" + ) + return segmentation_naming_convention + + +class Nimbus(nn.Module): + """Nimbus application class for predicting marker activity for cells in multiplexed images. + """ + def __init__( + self, fov_paths, segmentation_naming_convention, output_dir, + save_predictions=True, include_channels=[], half_resolution=True, + batch_size=4, test_time_aug=True, input_shape=[1024,1024], suffix=".tiff", + ): + """Initializes a Nimbus Application. + Args: + fov_paths (list): List of paths to fovs to be analyzed. + exclude_channels (list): List of channels to exclude from analysis. + segmentation_naming_convention (function): Function that returns the path to the + segmentation mask for a given fov path. + output_dir (str): Path to directory to save output. + save_predictions (bool): Whether to save predictions. + half_resolution (bool): Whether to run model on half resolution images. + batch_size (int): Batch size for model inference. + test_time_aug (bool): Whether to use test time augmentation. + input_shape (list): Shape of input images. 
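prep_naming_convention covers the DeepCell layout (`<fov>_whole_cell.tiff` in a single output folder), but Nimbus only requires some callable that maps a fov path to its mask path. A hypothetical convention for masks stored inside each fov folder (the file name is an assumption, not part of this patch):

    import os

    def segmentation_in_fov_folder(fov_path):
        # hypothetical layout: fov_0/segmentation.tiff, fov_1/segmentation.tiff, ...
        return os.path.join(fov_path, "segmentation.tiff")

    # nimbus = Nimbus(fov_paths, segmentation_in_fov_folder, output_dir="out")
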
+ """ + super(Nimbus, self).__init__() + self.fov_paths = fov_paths + self.include_channels = include_channels + self.segmentation_naming_convention = segmentation_naming_convention + self.output_dir = output_dir + self.half_resolution = half_resolution + self.save_predictions = save_predictions + self.batch_size = batch_size + self.checked_inputs = False + self.test_time_aug = test_time_aug + self.input_shape = input_shape + self.suffix = suffix + if self.output_dir != '': + os.makedirs(self.output_dir, exist_ok=True) + + def check_inputs(self): + """ check inputs for Nimbus model + """ + # check if all paths in fov_paths exists + io_utils.validate_paths(self.fov_paths) + + # check if segmentation_naming_convention returns valid paths + path_to_segmentation = self.segmentation_naming_convention(self.fov_paths[0]) + if not os.path.exists(path_to_segmentation): + raise FileNotFoundError("Function segmentation_naming_convention does not return valid\ + path. Segmentation path {} does not exist."\ + .format(path_to_segmentation)) + # check if output_dir exists + io_utils.validate_paths([self.output_dir]) + + if isinstance(self.include_channels, str): + self.include_channels = [self.include_channels] + self.checked_inputs = True + print("All inputs are valid.") + + def initialize_model(self, padding="reflect"): + """Initializes the model and load weights. + """ + model = UNet(num_classes=1, padding=padding) + # make sure path can be resolved on any OS and when importing from anywhere + self.checkpoint_path = os.path.normpath( + "src/nimbus_inference/assets/resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt" + ) + if not os.path.exists(self.checkpoint_path): + path = os.path.abspath(nimbus_inference.__file__) + path = Path(path).resolve() + self.checkpoint_path = os.path.join( + *path.parts[:-3], 'assets', + 'resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt' + ) + if not os.path.exists(self.checkpoint_path): + self.checkpoint_path = os.path.abspath(*glob( + '**/resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt' + ) + ) + + if not os.path.exists(self.checkpoint_path): + self.checkpoint_path = os.path.join( + os.getcwd(), 'assets', 'resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt' + ) + + if os.path.exists(self.checkpoint_path): + model.load_state_dict(torch.load(self.checkpoint_path)) + print("Loaded weights from {}".format(self.checkpoint_path)) + else: + raise FileNotFoundError("Could not find Nimbus weights at {ckpt_path}. \ + Current path is {current_path} and directory contains {dir_c},\ + path to cell_clasification i{p}".format( + ckpt_path=self.checkpoint_path, + current_path=os.getcwd(), + dir_c=os.listdir(os.getcwd()), + p=os.path.abspath(nimbus_inference.__file__) + ) + ) + self.model = model + + def prepare_normalization_dict( + self, quantile=0.999, n_subset=10, multiprocessing=False, overwrite=False, + ): + """Load or prepare and save normalization dictionary for Nimbus model. + Args: + quantile (float): Quantile to use for normalization. + n_subset (int): Number of fovs to use for normalization. + multiprocessing (bool): Whether to use multiprocessing. + overwrite (bool): Whether to overwrite existing normalization dict. + Returns: + dict: Dictionary of normalization factors. 
+ """ + self.normalization_dict_path = os.path.join(self.output_dir, "normalization_dict.json") + if os.path.exists(self.normalization_dict_path) and not overwrite: + self.normalization_dict = json.load(open(self.normalization_dict_path)) + else: + + n_jobs = os.cpu_count() if multiprocessing else 1 + self.normalization_dict = prepare_normalization_dict( + self.fov_paths, self.output_dir, quantile, self.include_channels, n_subset, n_jobs + ) + if self.include_channels == []: + self.include_channels = list(self.normalization_dict.keys()) + + def predict_fovs(self): + """Predicts cell classification for input data. + Returns: + np.array: Predicted cell classification. + """ + if self.checked_inputs == False: + self.check_inputs() + if not hasattr(self, "normalization_dict"): + self.prepare_normalization_dict() + # check if GPU is available + print("Available GPUs: ", torch.cuda.device_count()) + print("Predictions will be saved in {}".format(self.output_dir)) + print("Iterating through fovs will take a while...") + if self.suffix.lower() in [".ome.tif", ".ome.tiff"]: + self.cell_table = predict_ome_fovs( + self.fov_paths, self.output_dir, self, self.normalization_dict, + self.segmentation_naming_convention, self.include_channels, self.save_predictions, + self.half_resolution, batch_size=self.batch_size, + test_time_augmentation=self.test_time_aug, + ) + elif self.suffix.lower() in [".tiff",".tif", ".jpg", ".jpeg", ".png"]: + self.cell_table = predict_fovs( + self.fov_paths, self.output_dir, self, self.normalization_dict, + self.segmentation_naming_convention, self.include_channels, self.save_predictions, + self.half_resolution, batch_size=self.batch_size, + test_time_augmentation=self.test_time_aug, + ) + self.cell_table.to_csv( + os.path.join(self.output_dir,"nimbus_cell_table.csv"), index=False + ) + return self.cell_table + + def predict_segmentation(self, input_data, preprocess_kwargs, batch_size): + """Predicts segmentation for input data. + Args: + input_data (np.array): Input data to predict segmentation for. + preprocess_kwargs (dict): Keyword arguments for preprocessing. + batch_size (int): Batch size for prediction. + Returns: + np.array: Predicted segmentation. + """ + input_data = nimbus_preprocess(input_data, **preprocess_kwargs) + if np.all(np.greater_equal(self.input_shape, input_data.shape[-2:])): + if not hasattr(self, "model") or self.model.padding != "reflect": + self.initialize_model(padding="reflect") + with torch.no_grad(): + prediction = self.model.predict(input_data) + else: + if not hasattr(self, "model") or self.model.padding != "valid": + self.initialize_model(padding="valid") + prediction = self._tile_and_stitch(input_data, batch_size) + return prediction + + def _tile_and_stitch(self, input_data): + """Predicts segmentation for input data using tile and stitch method. + Args: + input_data (np.array): Input data to predict segmentation for. + batch_size (int): Batch size for prediction. + Returns: + np.array: Predicted segmentation. 
+ """ + with torch.no_grad(): + output_shape = self.model(torch.rand(1, 2, *self.input_shape)).shape[-2:] + input_shape = input_data.shape + # f^dl crop to have perfect shift equivariance inference + self.crop_by = np.array(output_shape) % 2**5 + output_shape = output_shape - self.crop_by + tiled_input, padding = self._tile_input( + input_data, tile_size=self.input_shape, output_shape=output_shape + ) + shape_diff = (self.input_shape - output_shape) // 2 + padding = [ + padding[0] - shape_diff[0], padding[1] - shape_diff[0], + padding[2] - shape_diff[1], padding[3] - shape_diff[1] + ] + h_t, h_w = tiled_input.shape[:2] + tiled_input = tiled_input.reshape(h_t*h_w, input_shape[1], *self.input_shape) # h_t,w_t,c,h,w -> h_t*w_t,c,h,w + # predict tiles + prediction = [] + for i in tqdm(range(0, len(tiled_input), self.batch_size)): + batch = torch.from_numpy(tiled_input[i:i + self.batch_size]).float() + if torch.cuda.is_available(): + batch = batch.cuda() + with torch.no_grad(): + pred = self.model(batch).cpu().numpy() + # crop pred + if self.crop_by.any(): + pred = pred[ + ..., self.crop_by[0]//2:-self.crop_by[0]//2, self.crop_by[1]//2:-self.crop_by[1]//2 + ] + prediction += [pred] + prediction = np.concatenate(prediction) # h_t*w_t,c,h,w + prediction = prediction.reshape(h_t, h_w, *prediction.shape[1:]) # h_t,w_t,c,h,w + # stitch tiles + prediction = self._stitch_tiles(prediction, padding) + return prediction + + def _tile_input(self, image, tile_size, output_shape, pad_mode="reflect"): + """Tiles input image for model inference. + Args: + image (np.array): Image to tile b,c,h,w. + tile_size (list): Size of input tiles. + output_shape (list): Shape of model output. + pad_mode (str): Padding mode for tiling. + Returns: + list: List of tiled images. + """ + overlap_px = (np.array(tile_size) - np.array(output_shape)) // 2 + # pad image to be divisible by tile size + pad_h0, pad_w0 = np.array(tile_size) - (np.array(image.shape[-2:]) % np.array(output_shape)) + pad_h1, pad_w1 = pad_h0 // 2, pad_h0 - pad_h0 // 2 + pad_h0, pad_w0 = pad_h0 - pad_h1, pad_w0 - pad_w1 + image = np.pad(image, ((0, 0), (0, 0), (pad_h0, pad_h1), (pad_w0, pad_w1)), mode=pad_mode) + b,c = image.shape[:2] + # tile image with np.lib.stride_tricks.as_strided + view = np.squeeze(view_as_windows( + image, [b,c]+list(tile_size), step=[b,c]+list(output_shape) + ) + ) # h_t,w_t,c,h,w + padding = [pad_h0, pad_h1, pad_w0, pad_w1] + return view, padding + + def _stitch_tiles(self, tiles, padding): + """Stitches tiles to reconstruct full image. + Args: + tiles (np.array): Tiled predictions n_tiles x c x h x w. + input_shape (list): Shape of input image. + Returns: + np.array: Reconstructed image. + """ + # stitch tiles + h_t,w_t,c,h,w = tiles.shape + stitched = np.zeros((c, h_t*h, w_t*w)) + for i in range(h_t): + for j in range(w_t): + stitched[:, i*h:(i+1)*h, j*w:(j+1)*w] = tiles[i, j] + # remove padding + stitched = stitched[ + :, padding[0]:-padding[1], padding[2]:-padding[3] + ] + return stitched diff --git a/src/nimbus_inference/unet.py b/src/nimbus_inference/unet.py index 3442b53..7682cb3 100644 --- a/src/nimbus_inference/unet.py +++ b/src/nimbus_inference/unet.py @@ -28,7 +28,7 @@ def forward(self, x): return nn.functional.pad(x, self.padding, mode=self.mode) -def maybe_crop(x, target_shape, data_format): +def maybe_crop(x, target_shape, data_format="channels_first"): """Center crops x to target_shape if necessary. 
Args: x: input tensor @@ -86,7 +86,7 @@ def __init__( in_ch = kwargs.get('in_channels', 2) self.activation=getattr(nn, activation)() filters = _get_filter_count(layer_idx, filters_root) - self.padding_layer = Pad2D(padding=(1, 1), mode=padding) + self.padding_layer = Pad2D(padding=(1, 1), mode=self.padding) if up: self.conv2d_0 = nn.Conv2d( @@ -121,7 +121,7 @@ def forward(self, x): x = self.bn_2(x) x = self.activation(x) if self.padding == "valid": - skip = maybe_crop(skip, x.size(), self.data_format) + skip = maybe_crop(skip, x.size()) x = x + skip return x @@ -151,7 +151,7 @@ def __init__( self.activation=activation filters = _get_filter_count(layer_idx, filters_root) - self.padding_layer = Pad2D(padding=(1, 1), mode=padding) + self.padding_layer = Pad2D(padding=(1, 1), mode=self.padding) self.upconv = nn.ConvTranspose2d( in_channels=filters, out_channels=filters // 2, kernel_size=pool_size, stride=pool_size, padding=0 @@ -211,28 +211,29 @@ def __init__(self, nx: int = 512, self.layer_depth = layer_depth self.contracting_layers = nn.ModuleList() self.expanding_layers = nn.ModuleList() + self.padding = padding for layer_idx in range(layer_depth-1): conv_block = ConvBlock( layer_idx = layer_idx, filters_root=filters_root, kernel_size=kernel_size, - padding=padding, activation=activation) + padding=self.padding, activation=activation) self.contracting_layers.append(conv_block) self.contracting_layers.append( nn.MaxPool2d(kernel_size=(pool_size, pool_size)) ) self.bottle_neck = ConvBlock( layer_idx=layer_idx+1, filters_root=filters_root, kernel_size=kernel_size, - padding=padding, activation=activation + padding=self.padding, activation=activation ) for layer_idx in range(layer_depth-2, -1, -1): upconv_block = UpconvBlock( layer_idx=layer_idx+1, filters_root=filters_root, kernel_size=kernel_size, - padding=padding, activation=activation, pool_size=(pool_size, pool_size) + padding=self.padding, activation=activation, pool_size=(pool_size, pool_size) ) crop_concat_block = CropConcatBlock() conv_block = ConvBlock( layer_idx=layer_idx, filters_root=filters_root, kernel_size=kernel_size, - padding=padding, activation=activation, up=True + padding=self.padding, activation=activation, up=True ) self.expanding_layers += [upconv_block, crop_concat_block, conv_block] diff --git a/src/nimbus_inference/utils.py b/src/nimbus_inference/utils.py new file mode 100644 index 0000000..6b3829a --- /dev/null +++ b/src/nimbus_inference/utils.py @@ -0,0 +1,272 @@ +import os +import cv2 +import json +import torch +import random +import numpy as np +import pandas as pd +from skimage import io +from tqdm.autonotebook import tqdm +from joblib import Parallel, delayed +from skimage.segmentation import find_boundaries + + +def calculate_normalization(channel_path, quantile): + """Calculates the normalization value for a given channel + Args: + channel_path (str): path to channel + quantile (float): quantile to use for normalization + Returns: + normalization_value (float): normalization value + """ + mplex_img = io.imread(channel_path) + normalization_value = np.quantile(mplex_img, quantile) + chan = os.path.basename(channel_path).split(".")[0] + return chan, normalization_value + + +def prepare_normalization_dict( + fov_paths, output_dir, quantile=0.999, include_channels=[], n_subset=10, n_jobs=1, + output_name="normalization_dict.json" + ): + """Prepares the normalization dict for a list of fovs + Args: + fov_paths (list): list of paths to fovs + output_dir (str): path to output directory + quantile (float): quantile 
to use for normalization + exclude_channels (list): list of channels to exclude + n_subset (int): number of fovs to use for normalization + n_jobs (int): number of jobs to use for joblib multiprocessing + output_name (str): name of output file + Returns: + normalization_dict (dict): dict with channel names as keys and norm factors as values + """ + normalization_dict = {} + if n_subset is not None: + random.shuffle(fov_paths) + fov_paths = fov_paths[:n_subset] + print("Iterate over fovs...") + for fov_path in tqdm(fov_paths): + channels = os.listdir(fov_path) + if include_channels: + channels = [ + channel for channel in channels if channel.split(".")[0] in include_channels + ] + channel_paths = [os.path.join(fov_path, channel) for channel in channels] + if n_jobs > 1: + normalization_values = Parallel(n_jobs=n_jobs)( + delayed(calculate_normalization)(channel_path, quantile) + for channel_path in channel_paths + ) + else: + normalization_values = [ + calculate_normalization(channel_path, quantile) + for channel_path in channel_paths + ] + for channel, normalization_value in normalization_values: + if channel not in normalization_dict: + normalization_dict[channel] = [] + normalization_dict[channel].append(normalization_value) + for channel in normalization_dict.keys(): + normalization_dict[channel] = np.mean(normalization_dict[channel]) + # save normalization dict + with open(os.path.join(output_dir, output_name), 'w') as f: + json.dump(normalization_dict, f) + return normalization_dict + + +def prepare_input_data(mplex_img, instance_mask): + """Prepares the input data for the segmentation model + Args: + mplex_img (np.array): multiplex image + instance_mask (np.array): instance mask + Returns: + input_data (np.array): input data for segmentation model + """ + edge = find_boundaries(instance_mask, mode="inner").astype(np.uint8) + binary_mask = np.logical_and(edge == 0, instance_mask > 0).astype(np.float32) + input_data = np.stack([mplex_img, binary_mask], axis=-1)[np.newaxis,...] 
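A toy check of the two-channel input assembled above: channel 1 keeps cell interiors but zeroes the one-pixel inner boundary that find_boundaries marks, so touching cells stay separated in the mask channel. Sketch:

    import numpy as np
    from skimage.segmentation import find_boundaries

    inst = np.zeros((6, 6), dtype=np.int32)
    inst[:3, :], inst[3:, :] = 1, 2               # two touching cells
    edge = find_boundaries(inst, mode="inner")
    interior = np.logical_and(edge == 0, inst > 0)
    assert edge[2, :].all() and edge[3, :].all()  # boundary on both sides of the contact
    assert interior.sum() < (inst > 0).sum()      # boundary pixels removed from the mask
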
# bhwc + return input_data + + +def segment_mean(instance_mask, prediction): + """Calculates the mean prediction per instance + Args: + instance_mask (np.array): instance mask + prediction (np.array): prediction + Returns: + uniques (np.array): unique instance ids + mean_per_cell (np.array): mean prediction per instance + """ + instance_mask_flat = tf.cast(tf.reshape(instance_mask, -1), tf.int32) # (h*w) + pred_flat = tf.cast(tf.reshape(prediction, -1), tf.float32) + sort_order = tf.argsort(instance_mask_flat) + instance_mask_flat = tf.gather(instance_mask_flat, sort_order) + uniques, _ = tf.unique(instance_mask_flat) + pred_flat = tf.gather(pred_flat, sort_order) + mean_per_cell = tf.math.segment_mean(pred_flat, instance_mask_flat) + mean_per_cell = tf.gather(mean_per_cell, uniques) + return [uniques.numpy()[1:], mean_per_cell.numpy()[1:]] # discard background + + +def test_time_aug( + input_data, channel, app, normalization_dict, rotate=True, flip=True, batch_size=4 + ): + """Performs test time augmentation + Args: + input_data (np.array): input data for segmentation model, mplex_img and binary mask + channel (str): channel name + app (tf.keras.Model): segmentation model + normalization_dict (dict): dict with channel names as keys and norm factors as values + rotate (bool): whether to rotate + flip (bool): whether to flip + batch_size (int): batch size + Returns: + seg_map (np.array): predicted segmentation map + """ + forward_augmentations = [] + backward_augmentations = [] + if rotate: + for k in [0,1,2,3]: + forward_augmentations.append(lambda x: torch.rot90(x, k=k)) + backward_augmentations.append(lambda x: torch.rot90(x, k=-k)) + if flip: + forward_augmentations += [ + lambda x: torch.flip(x, [2]), + lambda x: torch.flip(x, [3]) + ] + backward_augmentations += [ + lambda x: torch.flip(x, [2]), + lambda x: torch.flip(x, [3]) + ] + input_batch = [] + for forw_aug in forward_augmentations: + input_data_tmp = forw_aug(input_data).numpy() # bhwc + input_batch.append(np.concatenate(input_data_tmp)) + input_batch = np.stack(input_batch, 0) + seg_map = app._predict_segmentation( + input_batch, + batch_size=batch_size, + preprocess_kwargs={ + "normalize": True, + "marker": channel, + "normalization_dict": normalization_dict}, + ) + tmp = [] + for backw_aug, seg_map_tmp in zip(backward_augmentations, seg_map): + seg_map_tmp = backw_aug(seg_map_tmp[np.newaxis,...]) + seg_map_tmp = np.squeeze(seg_map_tmp) + tmp.append(seg_map_tmp) + seg_map = np.stack(tmp, -1) + seg_map = np.mean(seg_map, axis = -1, keepdims = True) + return seg_map + + +def predict_fovs( + nimbus, fov_paths, normalization_dict, segmentation_naming_convention, output_dir, + suffix, include_channels=[], save_predictions=True, half_resolution=False, batch_size=4, + test_time_augmentation=True + ): + """Predicts the segmentation map for each mplex image in each fov + Args: + nimbus (Nimbus): nimbus object + fov_paths (list): list of fov paths + normalization_dict (dict): dict with channel names as keys and norm factors as values + segmentation_naming_convention (function): function to get instance mask path from fov path + output_dir (str): path to output dir + suffix (str): suffix of mplex images + include_channels (list): list of channels to include + save_predictions (bool): whether to save predictions + half_resolution (bool): whether to use half resolution + batch_size (int): batch size + test_time_augmentation (bool): whether to use test time augmentation + Returns: + cell_table (pd.DataFrame): cell table with predicted 
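test_time_aug above only averages correctly if every forward augmentation is undone exactly by its backward partner (flips are their own inverse). A sanity sketch over the spatial axes of an NCHW tensor; note that torch.rot90 defaults to dims=[0, 1], the batch and channel axes, which is why patch 03 below adds dims=[2, 3] to these lambdas:

    import torch

    x = torch.rand(1, 2, 8, 8)
    for k in range(4):
        fwd = torch.rot90(x, k=k, dims=[2, 3])
        assert torch.equal(torch.rot90(fwd, k=-k, dims=[2, 3]), x)
    assert torch.equal(torch.flip(torch.flip(x, [2]), [2]), x)  # flips are involutions
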
confidence scores per fov and cell + """ + fov_dict_list = [] + for fov_path in tqdm(fov_paths): + out_fov_path = os.path.join( + os.path.normpath(output_dir), os.path.basename(fov_path) + ) + fov_dict = {} + for channel in os.listdir(fov_path): + channel_path = os.path.join(fov_path, channel) + channel_ = channel.split(".")[0] + if not channel.endswith(suffix) or channel not in include_channels: + continue + mplex_img = np.squeeze(io.imread(channel_path)) + instance_path = segmentation_naming_convention(fov_path) + instance_mask = np.squeeze(io.imread(instance_path)) + input_data = prepare_input_data(mplex_img, instance_mask) + if half_resolution: + scale = 0.5 + input_data = np.squeeze(input_data) + h,w,_ = input_data.shape + img = cv2.resize(input_data[...,0], [int(h*scale), int(w*scale)]) + binary_mask = cv2.resize( + input_data[...,1], [int(h*scale), int(w*scale)], interpolation=0 + ) + input_data = np.stack([img, binary_mask], axis=-1)[np.newaxis,...] + if test_time_augmentation: + prediction = test_time_aug( + input_data, channel, nimbus, normalization_dict, batch_size=batch_size + ) + else: + prediction = nimbus._predict_segmentation( + input_data, + preprocess_kwargs={ + "normalize": True, "marker": channel, + "normalization_dict": normalization_dict + }, + batch_size=batch_size + ) + prediction = np.squeeze(prediction) + if half_resolution: + prediction = cv2.resize(prediction, (h, w)) + instance_mask = np.expand_dims(instance_mask, axis=-1) + labels, mean_per_cell = segment_mean(instance_mask, prediction) + if "label" not in fov_dict.keys(): + fov_dict["fov"] = [os.path.basename(fov_path)]*len(labels) + fov_dict["label"] = labels + fov_dict[channel+"_pred"] = mean_per_cell + if save_predictions: + os.makedirs(out_fov_path, exist_ok=True) + pred_int = tf.cast(prediction*255.0, tf.uint8).numpy() + io.imsave( + os.path.join(out_fov_path, channel+".tiff"), pred_int, + photometric="minisblack", compression="zlib" + ) + fov_dict_list.append(pd.DataFrame(fov_dict)) + cell_table = pd.concat(fov_dict_list, ignore_index=True) + return cell_table + + +def nimbus_preprocess(image, **kwargs): + """Preprocess input data for Nimbus model. + Args: + image: array to be processed + Returns: + np.array: processed image array + """ + output = np.copy(image) + if len(image.shape) != 4: + raise ValueError("Image data must be 4D, got image of shape {}".format(image.shape)) + + normalize = kwargs.get('normalize', True) + if normalize: + marker = kwargs.get('marker', None) + normalization_dict = kwargs.get('normalization_dict', {}) + if marker in normalization_dict.keys(): + norm_factor = normalization_dict[marker] + else: + print("Norm_factor not found for marker {}, calculating directly from the image. 
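One OpenCV convention worth flagging in the half-resolution branch of predict_fovs above: cv2.resize takes dsize as (width, height), the reverse of numpy's (rows, cols) shape order. The fovs here are square, so the call is unaffected; a sketch of the explicit ordering for non-square images:

    import cv2
    import numpy as np

    img = np.random.rand(100, 200).astype(np.float32)  # h=100, w=200
    h, w = img.shape
    half = cv2.resize(img, (w // 2, h // 2))           # dsize = (width, height)
    assert half.shape == (50, 100)                     # numpy shape is (rows, cols) again
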
\ + ".format(marker)) + norm_factor = np.quantile(output[..., 0], 0.999) + # normalize only marker channel in chan 0 not binary mask in chan 1 + output[..., 0] /= norm_factor + output = output.clip(0, 1) + return output + + +def predict_ome_fovs(): + pass \ No newline at end of file diff --git a/tests/test_nimbus.py b/tests/test_nimbus.py new file mode 100644 index 0000000..2daa799 --- /dev/null +++ b/tests/test_nimbus.py @@ -0,0 +1,88 @@ +from test_dataset import prepare_ome_tif_data, prepare_tif_data +import tempfile +from nimbus_inference.nimbus import Nimbus, prep_naming_convention +from nimbus_inference.unet import UNet +from skimage.data import astronaut +from skimage.transform import rescale +import numpy as np +import torch +import os + + +def test_check_inputs(): + with tempfile.TemporaryDirectory() as temp_dir: + num_samples = 5 + selected_markers = ["CD45", "CD3", "CD8", "ChyTr"] + fov_paths = prepare_tif_data(num_samples, temp_dir, selected_markers) + naming_convention = prep_naming_convention(os.path.join(temp_dir, "deepcell_output")) + nimbus = Nimbus( + fov_paths=[temp_dir], segmentation_naming_convention=naming_convention, output_dir=temp_dir + ) + nimbus.check_inputs() + # check if the model is initialized + assert isinstance(nimbus.model, UNet) + + +def test_prepare_normalization_dict(): + # test if normalization dict gets prepared and saved, in-depth tests are in inference_test.py + with tempfile.TemporaryDirectory() as temp_dir: + + num_samples = 5 + selected_markers = ["CD45", "CD3", "CD8", "ChyTr"] + fov_paths = prepare_tif_data(num_samples, temp_dir, selected_markers) + naming_convention = prep_naming_convention(os.path.join(temp_dir, "deepcell_output")) + nimbus = Nimbus( + fov_paths, naming_convention, temp_dir, + include_channels=["CD45", "CD3", "CD8"] + ) + # test if normalization dict gets prepared and saved + nimbus.prepare_normalization_dict(overwrite=True) + assert os.path.exists(os.path.join(temp_dir, "normalization_dict.json")) + assert "ChyTr" not in nimbus.normalization_dict.keys() + + # test if normalization dict gets loaded + nimbus_2 = Nimbus( + fov_paths, naming_convention, temp_dir, include_channels=["CD45", "CD3", "CD8"] + ) + nimbus_2.prepare_normalization_dict() + assert nimbus_2.normalization_dict == nimbus.normalization_dict + + +def test_tile_input(): + image = torch.rand([1,2,768,768]) + tile_size = (512, 512) + output_shape = (320,320) + nimbus = Nimbus(fov_paths=[""], segmentation_naming_convention="", output_dir="") + nimbus.model = lambda x: x[..., 96:-96, 96:-96] + tiled_input, padding = nimbus._tile_input(image, tile_size, output_shape) + assert tiled_input.shape == (3,3,2,512,512) + assert padding == (192, 192, 192, 192) + + +def test_tile_and_stitch(): + # tests _tile_and_stitch which chains _tile_input, model.forward and _stitch_tiles + image = rescale(astronaut(), 1.5, channel_axis=-1) + image = np.moveaxis(image, -1, 0)[np.newaxis, ...] 
+ nimbus = Nimbus( + fov_paths=[""], segmentation_naming_convention="", output_dir="", + input_shape=[512,512], batch_size=4 + ) + # check if tile and stitch works for mock model unequal input and output shape + # mock model only center crops the input, so that the stitched output is equal to the input + for s in [41, 89, 96]: + nimbus.model = lambda x: x[..., s:-s, s:-s] + out = nimbus._tile_and_stitch(image) + assert np.all( + np.isclose(np.transpose(image[0], (1,2,0)), np.transpose(out, (1,2,0)), rtol=1e-4) + ) + # check if tile and stitch works with the real model + nimbus.initialize_model(padding="valid") + image = np.random.rand(1, 2, 768, 768) + prediction = nimbus._tile_and_stitch(image) + assert prediction.shape == (1, 768, 768) + assert prediction.max() <= 1 + assert prediction.min() >= 0 + + +def test_predict_segmentation(): + pass \ No newline at end of file From 7058559e3a69b5c4664bd8e7be978607bac49107 Mon Sep 17 00:00:00 2001 From: Lenz Date: Mon, 19 Feb 2024 18:16:18 +0100 Subject: [PATCH 03/31] Added tests and functionality for inference --- .github/workflows/test.yaml | 2 +- src/nimbus_inference/nimbus.py | 167 +++++++++++++---------- src/nimbus_inference/utils.py | 62 ++++----- tests/__init__.py | 0 tests/test_basic.py | 12 -- tests/test_nimbus.py | 30 +++-- tests/test_utils.py | 239 +++++++++++++++++++++++++++++++++ 7 files changed, 386 insertions(+), 126 deletions(-) create mode 100644 tests/__init__.py delete mode 100644 tests/test_basic.py create mode 100644 tests/test_utils.py diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 7aff0c4..59fdc4d 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -35,7 +35,7 @@ jobs: - name: Run Tests run: | - pytest + python -m pytest - name: Archive Coverage uses: actions/upload-artifact@v4 diff --git a/src/nimbus_inference/nimbus.py b/src/nimbus_inference/nimbus.py index 247a9fd..203945a 100644 --- a/src/nimbus_inference/nimbus.py +++ b/src/nimbus_inference/nimbus.py @@ -1,13 +1,18 @@ # nimbus model class # loads the image # does inference on images, decides if tile & stitch or whole image inference -# preprocesses the predictions +# preprocesses the predictions # saves the output to the output folder from alpineer import io_utils from skimage.util.shape import view_as_windows import nimbus_inference -from nimbus_inference.utils import prepare_normalization_dict, predict_fovs, predict_ome_fovs, nimbus_preprocess +from nimbus_inference.utils import ( + prepare_normalization_dict, + predict_fovs, + predict_ome_fovs, + nimbus_preprocess, +) from nimbus_inference.unet import UNet from tqdm.autonotebook import tqdm from pathlib import Path @@ -30,15 +35,19 @@ def nimbus_preprocess(image, **kwargs): if len(image.shape) != 4: raise ValueError("Image data must be 4D, got image of shape {}".format(image.shape)) - normalize = kwargs.get('normalize', True) + normalize = kwargs.get("normalize", True) if normalize: - marker = kwargs.get('marker', None) - normalization_dict = kwargs.get('normalization_dict', {}) + marker = kwargs.get("marker", None) + normalization_dict = kwargs.get("normalization_dict", {}) if marker in normalization_dict.keys(): norm_factor = normalization_dict[marker] else: - print("Norm_factor not found for marker {}, calculating directly from the image. \ - ".format(marker)) + print( + "Norm_factor not found for marker {}, calculating directly from the image. 
\ + ".format( + marker + ) + ) norm_factor = np.quantile(output[..., 0], 0.999) # normalize only marker channel in chan 0 not binary mask in chan 1 output[..., 0] /= norm_factor @@ -54,6 +63,7 @@ def prep_naming_convention(deepcell_output_dir): segmentation_naming_convention (function): function that returns the path to the segmentation data for a given fov """ + def segmentation_naming_convention(fov_path): """Prepares the path to the segmentation data for a given fov Args: @@ -62,32 +72,32 @@ def segmentation_naming_convention(fov_path): seg_path (str): paths to segmentation fovs """ fov_name = os.path.basename(fov_path) - return os.path.join( - deepcell_output_dir, fov_name + "_whole_cell.tiff" - ) + return os.path.join(deepcell_output_dir, fov_name + "_whole_cell.tiff") + return segmentation_naming_convention class Nimbus(nn.Module): - """Nimbus application class for predicting marker activity for cells in multiplexed images. - """ + """Nimbus application class for predicting marker activity for cells in multiplexed images.""" + def __init__( - self, fov_paths, segmentation_naming_convention, output_dir, - save_predictions=True, include_channels=[], half_resolution=True, - batch_size=4, test_time_aug=True, input_shape=[1024,1024], suffix=".tiff", - ): + self, fov_paths, segmentation_naming_convention, output_dir, save_predictions=True, + include_channels=[], half_resolution=True, batch_size=4, test_time_aug=True, + input_shape=[1024, 1024], suffix=".tiff", + ): """Initializes a Nimbus Application. Args: fov_paths (list): List of paths to fovs to be analyzed. - exclude_channels (list): List of channels to exclude from analysis. segmentation_naming_convention (function): Function that returns the path to the segmentation mask for a given fov path. output_dir (str): Path to directory to save output. save_predictions (bool): Whether to save predictions. + include_channels (list): List of channels to include in analysis. half_resolution (bool): Whether to run model on half resolution images. batch_size (int): Batch size for model inference. test_time_aug (bool): Whether to use test time augmentation. input_shape (list): Shape of input images. + suffix (str): Suffix of images to load. """ super(Nimbus, self).__init__() self.fov_paths = fov_paths @@ -101,21 +111,25 @@ def __init__( self.test_time_aug = test_time_aug self.input_shape = input_shape self.suffix = suffix - if self.output_dir != '': + if self.output_dir != "": os.makedirs(self.output_dir, exist_ok=True) - + def check_inputs(self): - """ check inputs for Nimbus model - """ + """check inputs for Nimbus model""" # check if all paths in fov_paths exists + if not isinstance(self.fov_paths, (list, tuple)): + self.fov_paths = [self.fov_paths] io_utils.validate_paths(self.fov_paths) # check if segmentation_naming_convention returns valid paths path_to_segmentation = self.segmentation_naming_convention(self.fov_paths[0]) if not os.path.exists(path_to_segmentation): - raise FileNotFoundError("Function segmentation_naming_convention does not return valid\ - path. Segmentation path {} does not exist."\ - .format(path_to_segmentation)) + raise FileNotFoundError( + "Function segmentation_naming_convention does not return valid\ + path. Segmentation path {} does not exist.".format( + path_to_segmentation + ) + ) # check if output_dir exists io_utils.validate_paths([self.output_dir]) @@ -126,6 +140,8 @@ def check_inputs(self): def initialize_model(self, padding="reflect"): """Initializes the model and load weights. 
+ Args: + padding (str): Padding mode for model, either "reflect" or "valid". """ model = UNet(num_classes=1, padding=padding) # make sure path can be resolved on any OS and when importing from anywhere @@ -136,38 +152,43 @@ def initialize_model(self, padding="reflect"): path = os.path.abspath(nimbus_inference.__file__) path = Path(path).resolve() self.checkpoint_path = os.path.join( - *path.parts[:-3], 'assets', - 'resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt' + *path.parts[:-3], + "assets", + "resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt" ) if not os.path.exists(self.checkpoint_path): - self.checkpoint_path = os.path.abspath(*glob( - '**/resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt' + self.checkpoint_path = os.path.abspath( + *glob( + "**/resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt" ) ) if not os.path.exists(self.checkpoint_path): self.checkpoint_path = os.path.join( - os.getcwd(), 'assets', 'resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt' + os.getcwd(), + "assets", + "resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt", ) if os.path.exists(self.checkpoint_path): model.load_state_dict(torch.load(self.checkpoint_path)) print("Loaded weights from {}".format(self.checkpoint_path)) else: - raise FileNotFoundError("Could not find Nimbus weights at {ckpt_path}. \ + raise FileNotFoundError( + "Could not find Nimbus weights at {ckpt_path}. \ Current path is {current_path} and directory contains {dir_c},\ path to cell_clasification i{p}".format( - ckpt_path=self.checkpoint_path, - current_path=os.getcwd(), - dir_c=os.listdir(os.getcwd()), - p=os.path.abspath(nimbus_inference.__file__) - ) + ckpt_path=self.checkpoint_path, + current_path=os.getcwd(), + dir_c=os.listdir(os.getcwd()), + p=os.path.abspath(nimbus_inference.__file__), + ) ) self.model = model def prepare_normalization_dict( - self, quantile=0.999, n_subset=10, multiprocessing=False, overwrite=False, - ): + self, quantile=0.999, n_subset=10, multiprocessing=False, overwrite=False, + ): """Load or prepare and save normalization dictionary for Nimbus model. Args: quantile (float): Quantile to use for normalization. 
@@ -205,23 +226,21 @@ def predict_fovs(self): if self.suffix.lower() in [".ome.tif", ".ome.tiff"]: self.cell_table = predict_ome_fovs( self.fov_paths, self.output_dir, self, self.normalization_dict, - self.segmentation_naming_convention, self.include_channels, self.save_predictions, - self.half_resolution, batch_size=self.batch_size, + self.segmentation_naming_convention, self.include_channels, + self.save_predictions, self.half_resolution, batch_size=self.batch_size, test_time_augmentation=self.test_time_aug, ) - elif self.suffix.lower() in [".tiff",".tif", ".jpg", ".jpeg", ".png"]: + elif self.suffix.lower() in [".tiff", ".tif", ".jpg", ".jpeg", ".png"]: self.cell_table = predict_fovs( self.fov_paths, self.output_dir, self, self.normalization_dict, - self.segmentation_naming_convention, self.include_channels, self.save_predictions, - self.half_resolution, batch_size=self.batch_size, + self.segmentation_naming_convention, self.include_channels, + self.save_predictions, self.half_resolution, batch_size=self.batch_size, test_time_augmentation=self.test_time_aug, ) - self.cell_table.to_csv( - os.path.join(self.output_dir,"nimbus_cell_table.csv"), index=False - ) + self.cell_table.to_csv(os.path.join(self.output_dir, "nimbus_cell_table.csv"), index=False) return self.cell_table - - def predict_segmentation(self, input_data, preprocess_kwargs, batch_size): + + def predict_segmentation(self, input_data, preprocess_kwargs): """Predicts segmentation for input data. Args: input_data (np.array): Input data to predict segmentation for. @@ -235,13 +254,16 @@ def predict_segmentation(self, input_data, preprocess_kwargs, batch_size): if not hasattr(self, "model") or self.model.padding != "reflect": self.initialize_model(padding="reflect") with torch.no_grad(): - prediction = self.model.predict(input_data) + if not isinstance(input_data, torch.Tensor): + input_data = torch.tensor(input_data).float() + prediction = self.model(input_data) + prediction = prediction.cpu().squeeze(0).numpy() else: if not hasattr(self, "model") or self.model.padding != "valid": self.initialize_model(padding="valid") - prediction = self._tile_and_stitch(input_data, batch_size) + prediction = self._tile_and_stitch(input_data, self.batch_size) return prediction - + def _tile_and_stitch(self, input_data): """Predicts segmentation for input data using tile and stitch method. 
Args: @@ -254,7 +276,7 @@ def _tile_and_stitch(self, input_data): output_shape = self.model(torch.rand(1, 2, *self.input_shape)).shape[-2:] input_shape = input_data.shape # f^dl crop to have perfect shift equivariance inference - self.crop_by = np.array(output_shape) % 2**5 + self.crop_by = np.array(output_shape) % 2 ** 5 output_shape = output_shape - self.crop_by tiled_input, padding = self._tile_input( input_data, tile_size=self.input_shape, output_shape=output_shape @@ -262,14 +284,16 @@ def _tile_and_stitch(self, input_data): shape_diff = (self.input_shape - output_shape) // 2 padding = [ padding[0] - shape_diff[0], padding[1] - shape_diff[0], - padding[2] - shape_diff[1], padding[3] - shape_diff[1] + padding[2] - shape_diff[1], padding[3] - shape_diff[1], ] h_t, h_w = tiled_input.shape[:2] - tiled_input = tiled_input.reshape(h_t*h_w, input_shape[1], *self.input_shape) # h_t,w_t,c,h,w -> h_t*w_t,c,h,w + tiled_input = tiled_input.reshape( + h_t * h_w, input_shape[1], *self.input_shape + ) # h_t,w_t,c,h,w -> h_t*w_t,c,h,w # predict tiles prediction = [] for i in tqdm(range(0, len(tiled_input), self.batch_size)): - batch = torch.from_numpy(tiled_input[i:i + self.batch_size]).float() + batch = torch.from_numpy(tiled_input[i : i + self.batch_size]).float() if torch.cuda.is_available(): batch = batch.cuda() with torch.no_grad(): @@ -277,15 +301,17 @@ def _tile_and_stitch(self, input_data): # crop pred if self.crop_by.any(): pred = pred[ - ..., self.crop_by[0]//2:-self.crop_by[0]//2, self.crop_by[1]//2:-self.crop_by[1]//2 + ..., + self.crop_by[0] // 2 : -self.crop_by[0] // 2, + self.crop_by[1] // 2 : -self.crop_by[1] // 2, ] prediction += [pred] - prediction = np.concatenate(prediction) # h_t*w_t,c,h,w - prediction = prediction.reshape(h_t, h_w, *prediction.shape[1:]) # h_t,w_t,c,h,w + prediction = np.concatenate(prediction) # h_t*w_t,c,h,w + prediction = prediction.reshape(h_t, h_w, *prediction.shape[1:]) # h_t,w_t,c,h,w # stitch tiles prediction = self._stitch_tiles(prediction, padding) return prediction - + def _tile_input(self, image, tile_size, output_shape, pad_mode="reflect"): """Tiles input image for model inference. Args: @@ -298,16 +324,17 @@ def _tile_input(self, image, tile_size, output_shape, pad_mode="reflect"): """ overlap_px = (np.array(tile_size) - np.array(output_shape)) // 2 # pad image to be divisible by tile size - pad_h0, pad_w0 = np.array(tile_size) - (np.array(image.shape[-2:]) % np.array(output_shape)) + pad_h0, pad_w0 = np.array(tile_size) - ( + np.array(image.shape[-2:]) % np.array(output_shape) + ) pad_h1, pad_w1 = pad_h0 // 2, pad_h0 - pad_h0 // 2 pad_h0, pad_w0 = pad_h0 - pad_h1, pad_w0 - pad_w1 image = np.pad(image, ((0, 0), (0, 0), (pad_h0, pad_h1), (pad_w0, pad_w1)), mode=pad_mode) - b,c = image.shape[:2] - # tile image with np.lib.stride_tricks.as_strided - view = np.squeeze(view_as_windows( - image, [b,c]+list(tile_size), step=[b,c]+list(output_shape) - ) - ) # h_t,w_t,c,h,w + b, c = image.shape[:2] + # tile image + view = np.squeeze( + view_as_windows(image, [b, c] + list(tile_size), step=[b, c] + list(output_shape)) + ) # h_t,w_t,c,h,w padding = [pad_h0, pad_h1, pad_w0, pad_w1] return view, padding @@ -320,13 +347,11 @@ def _stitch_tiles(self, tiles, padding): np.array: Reconstructed image. 
""" # stitch tiles - h_t,w_t,c,h,w = tiles.shape - stitched = np.zeros((c, h_t*h, w_t*w)) + h_t, w_t, c, h, w = tiles.shape + stitched = np.zeros((c, h_t * h, w_t * w)) for i in range(h_t): for j in range(w_t): - stitched[:, i*h:(i+1)*h, j*w:(j+1)*w] = tiles[i, j] + stitched[:, i * h : (i + 1) * h, j * w : (j + 1) * w] = tiles[i, j] # remove padding - stitched = stitched[ - :, padding[0]:-padding[1], padding[2]:-padding[3] - ] + stitched = stitched[:, padding[0] : -padding[1], padding[2] : -padding[3]] return stitched diff --git a/src/nimbus_inference/utils.py b/src/nimbus_inference/utils.py index 6b3829a..a609e2c 100644 --- a/src/nimbus_inference/utils.py +++ b/src/nimbus_inference/utils.py @@ -9,6 +9,7 @@ from tqdm.autonotebook import tqdm from joblib import Parallel, delayed from skimage.segmentation import find_boundaries +from skimage.measure import regionprops_table def calculate_normalization(channel_path, quantile): @@ -85,7 +86,7 @@ def prepare_input_data(mplex_img, instance_mask): """ edge = find_boundaries(instance_mask, mode="inner").astype(np.uint8) binary_mask = np.logical_and(edge == 0, instance_mask > 0).astype(np.float32) - input_data = np.stack([mplex_img, binary_mask], axis=-1)[np.newaxis,...] # bhwc + input_data = np.stack([mplex_img, binary_mask], axis=0)[np.newaxis,...] # bhwc return input_data @@ -98,15 +99,11 @@ def segment_mean(instance_mask, prediction): uniques (np.array): unique instance ids mean_per_cell (np.array): mean prediction per instance """ - instance_mask_flat = tf.cast(tf.reshape(instance_mask, -1), tf.int32) # (h*w) - pred_flat = tf.cast(tf.reshape(prediction, -1), tf.float32) - sort_order = tf.argsort(instance_mask_flat) - instance_mask_flat = tf.gather(instance_mask_flat, sort_order) - uniques, _ = tf.unique(instance_mask_flat) - pred_flat = tf.gather(pred_flat, sort_order) - mean_per_cell = tf.math.segment_mean(pred_flat, instance_mask_flat) - mean_per_cell = tf.gather(mean_per_cell, uniques) - return [uniques.numpy()[1:], mean_per_cell.numpy()[1:]] # discard background + props_df = regionprops_table( + label_image=instance_mask, intensity_image=prediction, + properties=['label' ,'intensity_mean'] + ) + return props_df def test_time_aug( @@ -126,10 +123,12 @@ def test_time_aug( """ forward_augmentations = [] backward_augmentations = [] + if not isinstance(input_data, torch.Tensor): + input_data = torch.tensor(input_data) if rotate: for k in [0,1,2,3]: - forward_augmentations.append(lambda x: torch.rot90(x, k=k)) - backward_augmentations.append(lambda x: torch.rot90(x, k=-k)) + forward_augmentations.append(lambda x: torch.rot90(x, k=k, dims=[2,3])) + backward_augmentations.append(lambda x: torch.rot90(x, k=-k, dims=[2,3])) if flip: forward_augmentations += [ lambda x: torch.flip(x, [2]), @@ -144,21 +143,21 @@ def test_time_aug( input_data_tmp = forw_aug(input_data).numpy() # bhwc input_batch.append(np.concatenate(input_data_tmp)) input_batch = np.stack(input_batch, 0) - seg_map = app._predict_segmentation( + seg_map = app.predict_segmentation( input_batch, - batch_size=batch_size, preprocess_kwargs={ "normalize": True, "marker": channel, "normalization_dict": normalization_dict}, ) + seg_map = torch.from_numpy(seg_map) tmp = [] for backw_aug, seg_map_tmp in zip(backward_augmentations, seg_map): seg_map_tmp = backw_aug(seg_map_tmp[np.newaxis,...]) seg_map_tmp = np.squeeze(seg_map_tmp) tmp.append(seg_map_tmp) - seg_map = np.stack(tmp, -1) - seg_map = np.mean(seg_map, axis = -1, keepdims = True) + seg_map = np.stack(tmp, 0) + seg_map = 
np.mean(seg_map, axis = 0) return seg_map @@ -188,11 +187,13 @@ def predict_fovs( out_fov_path = os.path.join( os.path.normpath(output_dir), os.path.basename(fov_path) ) - fov_dict = {} + df_fov = pd.DataFrame() for channel in os.listdir(fov_path): channel_path = os.path.join(fov_path, channel) channel_ = channel.split(".")[0] - if not channel.endswith(suffix) or channel not in include_channels: + if not channel.endswith(suffix) or ( + include_channels != [] and channel_ not in include_channels + ): continue mplex_img = np.squeeze(io.imread(channel_path)) instance_path = segmentation_naming_convention(fov_path) @@ -201,12 +202,12 @@ def predict_fovs( if half_resolution: scale = 0.5 input_data = np.squeeze(input_data) - h,w,_ = input_data.shape - img = cv2.resize(input_data[...,0], [int(h*scale), int(w*scale)]) + _, h,w = input_data.shape + img = cv2.resize(input_data[0], [int(h*scale), int(w*scale)]) binary_mask = cv2.resize( - input_data[...,1], [int(h*scale), int(w*scale)], interpolation=0 + input_data[1], [int(h*scale), int(w*scale)], interpolation=0 ) - input_data = np.stack([img, binary_mask], axis=-1)[np.newaxis,...] + input_data = np.stack([img, binary_mask], axis=0)[np.newaxis,...] if test_time_augmentation: prediction = test_time_aug( input_data, channel, nimbus, normalization_dict, batch_size=batch_size @@ -223,20 +224,19 @@ def predict_fovs( prediction = np.squeeze(prediction) if half_resolution: prediction = cv2.resize(prediction, (h, w)) - instance_mask = np.expand_dims(instance_mask, axis=-1) - labels, mean_per_cell = segment_mean(instance_mask, prediction) - if "label" not in fov_dict.keys(): - fov_dict["fov"] = [os.path.basename(fov_path)]*len(labels) - fov_dict["label"] = labels - fov_dict[channel+"_pred"] = mean_per_cell + df = pd.DataFrame(segment_mean(instance_mask, prediction)) + if df_fov.empty: + df_fov["label"] = df["label"] + df_fov["fov"] = os.path.basename(fov_path) + df_fov[channel.split(".")[0]] = df["intensity_mean"] if save_predictions: os.makedirs(out_fov_path, exist_ok=True) - pred_int = tf.cast(prediction*255.0, tf.uint8).numpy() + pred_int = (prediction*255.0).astype(np.uint8) io.imsave( - os.path.join(out_fov_path, channel+".tiff"), pred_int, + os.path.join(out_fov_path, channel), pred_int, photometric="minisblack", compression="zlib" ) - fov_dict_list.append(pd.DataFrame(fov_dict)) + fov_dict_list.append(df_fov) cell_table = pd.concat(fov_dict_list, ignore_index=True) return cell_table diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_basic.py b/tests/test_basic.py deleted file mode 100644 index 757e4da..0000000 --- a/tests/test_basic.py +++ /dev/null @@ -1,12 +0,0 @@ -import pytest - -import nimbus_inference - - -def test_package_has_version(): - assert nimbus_inference.__version__ is not None - - -@pytest.mark.skip(reason="This decorator should be removed when test passes.") -def test_example(): - assert 1 == 0 # This test is designed to fail. 
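The segment_mean rewrite in this patch swaps the stray TensorFlow ops for skimage's regionprops_table, which averages the prediction over every labeled instance in one call and excludes background (label 0) by construction. A minimal sketch:

    import numpy as np
    from skimage.measure import regionprops_table

    mask = np.array([[0, 1], [2, 2]])
    pred = np.array([[0.9, 0.2], [0.4, 0.6]])
    props = regionprops_table(label_image=mask, intensity_image=pred,
                              properties=["label", "intensity_mean"])
    # props == {'label': array([1, 2]), 'intensity_mean': array([0.2, 0.5])}
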
diff --git a/tests/test_nimbus.py b/tests/test_nimbus.py index 2daa799..55dc465 100644 --- a/tests/test_nimbus.py +++ b/tests/test_nimbus.py @@ -1,4 +1,4 @@ -from test_dataset import prepare_ome_tif_data, prepare_tif_data +from tests.test_utils import prepare_ome_tif_data, prepare_tif_data import tempfile from nimbus_inference.nimbus import Nimbus, prep_naming_convention from nimbus_inference.unet import UNet @@ -13,14 +13,26 @@ def test_check_inputs(): with tempfile.TemporaryDirectory() as temp_dir: num_samples = 5 selected_markers = ["CD45", "CD3", "CD8", "ChyTr"] - fov_paths = prepare_tif_data(num_samples, temp_dir, selected_markers) + fov_paths, _ = prepare_tif_data(num_samples, temp_dir, selected_markers) naming_convention = prep_naming_convention(os.path.join(temp_dir, "deepcell_output")) nimbus = Nimbus( - fov_paths=[temp_dir], segmentation_naming_convention=naming_convention, output_dir=temp_dir + fov_paths=fov_paths, segmentation_naming_convention=naming_convention, + output_dir=temp_dir ) nimbus.check_inputs() - # check if the model is initialized - assert isinstance(nimbus.model, UNet) + + +def test_initialize_model(): + nimbus = Nimbus( + fov_paths=[""], segmentation_naming_convention="", output_dir="", + input_shape=[512,512], batch_size=4 + ) + nimbus.initialize_model(padding="valid") + assert isinstance(nimbus.model, UNet) + assert nimbus.model.padding == "valid" + nimbus.initialize_model(padding="reflect") + assert isinstance(nimbus.model, UNet) + assert nimbus.model.padding == "reflect" def test_prepare_normalization_dict(): @@ -29,7 +41,7 @@ def test_prepare_normalization_dict(): num_samples = 5 selected_markers = ["CD45", "CD3", "CD8", "ChyTr"] - fov_paths = prepare_tif_data(num_samples, temp_dir, selected_markers) + fov_paths,_ = prepare_tif_data(num_samples, temp_dir, selected_markers) naming_convention = prep_naming_convention(os.path.join(temp_dir, "deepcell_output")) nimbus = Nimbus( fov_paths, naming_convention, temp_dir, @@ -56,7 +68,7 @@ def test_tile_input(): nimbus.model = lambda x: x[..., 96:-96, 96:-96] tiled_input, padding = nimbus._tile_input(image, tile_size, output_shape) assert tiled_input.shape == (3,3,2,512,512) - assert padding == (192, 192, 192, 192) + assert padding == [192, 192, 192, 192] def test_tile_and_stitch(): @@ -82,7 +94,3 @@ def test_tile_and_stitch(): assert prediction.shape == (1, 768, 768) assert prediction.max() <= 1 assert prediction.min() >= 0 - - -def test_predict_segmentation(): - pass \ No newline at end of file diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..976aef2 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,239 @@ +from nimbus_inference.utils import prepare_normalization_dict, calculate_normalization +from nimbus_inference.utils import predict_fovs, predict_ome_fovs, prepare_input_data +from nimbus_inference.utils import test_time_aug as tt_aug +from nimbus_inference.nimbus import Nimbus +from skimage import io +from pyometiff import OMETIFFWriter +import numpy as np +import tempfile +import torch +import json +import os + + +class MockModel(torch.nn.Module): + def __init__(self, padding): + super(MockModel, self).__init__() + self.padding = padding + self.fn = torch.nn.Identity() + + def forward(self, x): + return self.fn(x) + + +def prepare_tif_data(num_samples, temp_dir, selected_markers, random=False, std=1): + np.random.seed(42) + fov_paths = [] + inst_paths = [] + deepcell_dir = os.path.join(temp_dir, "deepcell_output") + os.makedirs(deepcell_dir, exist_ok=True) + if 
isinstance(std, (int, float)) or len(std) != len(selected_markers):
+        std = [std] * len(selected_markers)
+    for i in range(num_samples):
+        folder = os.path.join(temp_dir, "fov_" + str(i))
+        os.makedirs(folder, exist_ok=True)
+        for marker, scale in zip(selected_markers, std):
+            if random:
+                img = np.random.rand(256, 256) * scale
+            else:
+                img = np.ones([256, 256])
+            io.imsave(
+                os.path.join(folder, marker + ".tiff"),
+                img,
+            )
+        inst_path = os.path.join(deepcell_dir, f"fov_{i}_whole_cell.tiff")
+        io.imsave(
+            inst_path, np.array(
+                [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+            ).repeat(64, axis=1).repeat(64, axis=0)
+        )
+        if folder not in fov_paths:
+            fov_paths.append(folder)
+            inst_paths.append(inst_path)
+    return fov_paths, inst_paths
+
+
+def prepare_ome_tif_data(num_samples, temp_dir, selected_markers, random=False, std=1):
+    np.random.seed(42)
+    metadata_dict = {
+        "PhysicalSizeX" : "0.88",
+        "PhysicalSizeXUnit" : "µm",
+        "PhysicalSizeY" : "0.88",
+        "PhysicalSizeYUnit" : "µm",
+        "PhysicalSizeZ" : "3.3",
+        "PhysicalSizeZUnit" : "µm",
+    }
+
+    for i in range(num_samples):
+        metadata_dict["Channels"] = {}
+        channels = []
+        for marker in selected_markers:
+            if random:
+                img = np.random.rand(256, 256) * std
+            else:
+                img = np.ones([256, 256])
+            channels.append(img)
+            metadata_dict["Channels"][marker] = {
+                "Name" : marker,
+                "SamplesPerPixel": 1,
+            }
+        channel_data = np.stack(channels, axis=0)
+        sample_name = os.path.join(temp_dir, f"fov_{i}.ome.tiff")
+        dimension_order = "CYX"
+        writer = OMETIFFWriter(
+            fpath=sample_name,
+            dimension_order=dimension_order,
+            array=channel_data,
+            metadata=metadata_dict,
+            explicit_tiffdata=False)
+        writer.write()
+    return None
+
+
+def test_calculate_normalization():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        fov_paths, _ = prepare_tif_data(
+            num_samples=1, temp_dir=temp_dir, selected_markers=["CD4"], random=True, std=[0.5]
+        )
+        channel = "CD4"
+        channel_path = os.path.join(fov_paths[0], channel + ".tiff")
+        channel_out, norm_val = calculate_normalization(channel_path, 0.999)
+        # test if we get the correct channel and normalization value
+        assert channel_out == channel
+        assert np.isclose(norm_val, 0.5, 0.01)
+
+
+def test_prepare_normalization_dict():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        scales = [0.5, 1.0, 1.5, 2.0, 5.0]
+        channels = ["CD4", "CD11c", "CD14", "CD56", "CD57"]
+        fov_paths, _ = prepare_tif_data(
+            num_samples=5, temp_dir=temp_dir, selected_markers=channels, random=True, std=scales
+        )
+        normalization_dict = prepare_normalization_dict(
+            fov_paths, temp_dir, quantile=0.999, n_subset=10, n_jobs=1,
+            output_name="normalization_dict.json"
+        )
+        # test if normalization dict got saved
+        assert os.path.exists(os.path.join(temp_dir, "normalization_dict.json"))
+        assert normalization_dict == json.load(
+            open(os.path.join(temp_dir, "normalization_dict.json"))
+        )
+        # test if normalization dict is correct
+        for channel, scale in zip(channels, scales):
+            assert np.isclose(normalization_dict[channel], scale, 0.01)
+
+        # test if multiprocessing yields approximately the same results
+        normalization_dict_mp = prepare_normalization_dict(
+            fov_paths, temp_dir, quantile=0.999, n_subset=10, n_jobs=2,
+            output_name="normalization_dict.json"
+        )
+        for key in normalization_dict.keys():
+            assert np.isclose(normalization_dict[key], normalization_dict_mp[key], 1e-6)
+
+
+def test_prepare_input_data():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        scales = [0.5]
+        channels = ["CD4"]
+        fov_paths, inst_paths = 
prepare_tif_data( + num_samples=1, temp_dir=temp_dir, selected_markers=channels, random=True, std=scales + ) + mplex_img = io.imread(os.path.join(fov_paths[0], "CD4.tiff")) + instance_mask = io.imread(inst_paths[0]) + input_data = prepare_input_data(mplex_img, instance_mask) + # check shape + assert input_data.shape == (1, 2, 256, 256) + # check if instance mask got binarized and eroded + assert np.alltrue(np.unique(input_data[:,1]) == np.array([0, 1])) + assert np.sum(input_data[:,1]) < np.sum(instance_mask) + # check if mplex image is the same as before + assert np.alltrue(input_data[0, 0] == mplex_img) + + +def test_tt_aug(): + with tempfile.TemporaryDirectory() as temp_dir: + def segmentation_naming_convention(fov_path): + temp_dir_, fov_ = os.path.split(fov_path) + return os.path.join(temp_dir_, "deepcell_output", fov_ + "_whole_cell.tiff") + channel = "CD4" + fov_paths, inst_paths = prepare_tif_data( + num_samples=1, temp_dir=temp_dir, selected_markers=[channel] + ) + output_dir = os.path.join(temp_dir, "nimbus_output") + nimbus = Nimbus( + fov_paths, segmentation_naming_convention, output_dir, + ) + nimbus.prepare_normalization_dict() + mplex_img = io.imread(os.path.join(fov_paths[0], channel+".tiff")) + instance_mask = io.imread(inst_paths[0]) + input_data = prepare_input_data(mplex_img, instance_mask) + nimbus.model = MockModel(padding="reflect") + pred_map = tt_aug( + input_data, channel, nimbus, nimbus.normalization_dict, rotate=True, flip=True, + batch_size=32 + ) + # check if we get the correct shape + assert pred_map.shape == (2, 256, 256) + + pred_map_2 = tt_aug( + input_data, channel, nimbus, nimbus.normalization_dict, rotate=False, flip=True, + batch_size=32 + ) + pred_map_3 = tt_aug( + input_data, channel, nimbus, nimbus.normalization_dict, rotate=True, flip=False, + batch_size=32 + ) + pred_map_no_tt_aug = nimbus.predict_segmentation( + input_data, + preprocess_kwargs={ + "normalize": True, + "marker": channel, + "normalization_dict": nimbus.normalization_dict}, + ) + # check if we get roughly the same results for non augmented and augmented predictions + assert np.allclose(pred_map, pred_map_no_tt_aug, atol=0.05) + assert np.allclose(pred_map_2, pred_map_no_tt_aug, atol=0.05) + assert np.allclose(pred_map_3, pred_map_no_tt_aug, atol=0.05) + + +def test_predict_fovs(): + with tempfile.TemporaryDirectory() as temp_dir: + def segmentation_naming_convention(fov_path): + temp_dir_, fov_ = os.path.split(fov_path) + return os.path.join(temp_dir_, "deepcell_output", fov_ + "_whole_cell.tiff") + + fov_paths, _ = prepare_tif_data( + num_samples=1, temp_dir=temp_dir, selected_markers=["CD4", "CD56"] + ) + output_dir = os.path.join(temp_dir, "nimbus_output") + nimbus = Nimbus( + fov_paths, segmentation_naming_convention, output_dir, + ) + output_dir = os.path.join(temp_dir, "nimbus_output") + nimbus.prepare_normalization_dict() + cell_table = predict_fovs( + nimbus=nimbus, fov_paths=fov_paths, output_dir=output_dir, + normalization_dict=nimbus.normalization_dict, + segmentation_naming_convention=segmentation_naming_convention, suffix=".tiff", + save_predictions=False, half_resolution=True, + ) + # check if we get the correct number of cells + assert len(cell_table) == 15 + # check if we get the correct columns (fov, label, CD4, CD56) + assert np.alltrue( + set(cell_table.columns) == set(["fov", "label", "CD4", "CD56"]) + ) + # check if predictions don't get written to output_dir + assert not os.path.exists(os.path.join(output_dir, "fov_0", "CD4.tiff")) + assert not 
os.path.exists(os.path.join(output_dir, "fov_0", "CD56.tiff")) + # + # run again with save_predictions=True and check if predictions get written to output_dir + cell_table = predict_fovs( + nimbus=nimbus, fov_paths=fov_paths, output_dir=output_dir, + normalization_dict=nimbus.normalization_dict, + segmentation_naming_convention=segmentation_naming_convention, suffix=".tiff", + save_predictions=True, half_resolution=True, + ) + assert os.path.exists(os.path.join(output_dir, "fov_0", "CD4.tiff")) + assert os.path.exists(os.path.join(output_dir, "fov_0", "CD56.tiff")) From 573366623795c55b2ce52c71a2fe44c553ab6a02 Mon Sep 17 00:00:00 2001 From: Lenz Date: Mon, 19 Feb 2024 18:23:33 +0100 Subject: [PATCH 04/31] Updated dependencies --- pyproject.toml | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8bfa1aa..0bb09fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,15 @@ urls.Documentation = "https://Nimbus-Inference.readthedocs.io/" urls.Source = "https://github.com/angelolab/Nimbus-Inference" urls.Home-page = "https://github.com/angelolab/Nimbus-Inference" dependencies = [ - "anndata", + "torch", + "torchvision", + "alpineer", + "scikit-image", + "tqdm", + "opencv-python", + "numpy", + "pandas", + "joblib", # for debug logging (referenced from the issue template) "session-info", ] From ff74b43bea7090bcac285fc9a63f098181e49d50 Mon Sep 17 00:00:00 2001 From: Lenz Date: Mon, 19 Feb 2024 18:25:45 +0100 Subject: [PATCH 05/31] Fixed torch and torchvision version dependency --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0bb09fb..e10e341 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,8 +19,8 @@ urls.Documentation = "https://Nimbus-Inference.readthedocs.io/" urls.Source = "https://github.com/angelolab/Nimbus-Inference" urls.Home-page = "https://github.com/angelolab/Nimbus-Inference" dependencies = [ - "torch", - "torchvision", + "torch==2.2.0", + "torchvision=0.17.0", "alpineer", "scikit-image", "tqdm", From 8172e4e4d870cbcfc934f3871dc80d1f3781b782 Mon Sep 17 00:00:00 2001 From: Lenz Date: Mon, 19 Feb 2024 18:27:39 +0100 Subject: [PATCH 06/31] Fixed typo --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e10e341..796efd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ urls.Source = "https://github.com/angelolab/Nimbus-Inference" urls.Home-page = "https://github.com/angelolab/Nimbus-Inference" dependencies = [ "torch==2.2.0", - "torchvision=0.17.0", + "torchvision==0.17.0", "alpineer", "scikit-image", "tqdm", From b14d607416d6a3eddab8cfe78c920482d97c42bb Mon Sep 17 00:00:00 2001 From: Lenz Date: Mon, 19 Feb 2024 18:31:25 +0100 Subject: [PATCH 07/31] Deleted python files from repo init --- src/nimbus_inference/pl/__init__.py | 1 - src/nimbus_inference/pl/basic.py | 63 ----------------------------- src/nimbus_inference/pp/__init__.py | 1 - src/nimbus_inference/pp/basic.py | 17 -------- src/nimbus_inference/tl/__init__.py | 1 - src/nimbus_inference/tl/basic.py | 17 -------- 6 files changed, 100 deletions(-) delete mode 100644 src/nimbus_inference/pl/__init__.py delete mode 100644 src/nimbus_inference/pl/basic.py delete mode 100644 src/nimbus_inference/pp/__init__.py delete mode 100644 src/nimbus_inference/pp/basic.py delete mode 100644 src/nimbus_inference/tl/__init__.py delete mode 100644 src/nimbus_inference/tl/basic.py diff --git 
a/src/nimbus_inference/pl/__init__.py b/src/nimbus_inference/pl/__init__.py deleted file mode 100644 index c2315dd..0000000 --- a/src/nimbus_inference/pl/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .basic import BasicClass, basic_plot diff --git a/src/nimbus_inference/pl/basic.py b/src/nimbus_inference/pl/basic.py deleted file mode 100644 index ed390ef..0000000 --- a/src/nimbus_inference/pl/basic.py +++ /dev/null @@ -1,63 +0,0 @@ -from anndata import AnnData - - -def basic_plot(adata: AnnData) -> int: - """Generate a basic plot for an AnnData object. - - Parameters - ---------- - adata - The AnnData object to preprocess. - - Returns - ------- - Some integer value. - """ - print("Import matplotlib and implement a plotting function here.") - return 0 - - -class BasicClass: - """A basic class. - - Parameters - ---------- - adata - The AnnData object to preprocess. - """ - - my_attribute: str = "Some attribute." - my_other_attribute: int = 0 - - def __init__(self, adata: AnnData): - print("Implement a class here.") - - def my_method(self, param: int) -> int: - """A basic method. - - Parameters - ---------- - param - A parameter. - - Returns - ------- - Some integer value. - """ - print("Implement a method here.") - return 0 - - def my_other_method(self, param: str) -> str: - """Another basic method. - - Parameters - ---------- - param - A parameter. - - Returns - ------- - Some integer value. - """ - print("Implement a method here.") - return "" diff --git a/src/nimbus_inference/pp/__init__.py b/src/nimbus_inference/pp/__init__.py deleted file mode 100644 index 5e7e293..0000000 --- a/src/nimbus_inference/pp/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .basic import basic_preproc diff --git a/src/nimbus_inference/pp/basic.py b/src/nimbus_inference/pp/basic.py deleted file mode 100644 index 5db1ec0..0000000 --- a/src/nimbus_inference/pp/basic.py +++ /dev/null @@ -1,17 +0,0 @@ -from anndata import AnnData - - -def basic_preproc(adata: AnnData) -> int: - """Run a basic preprocessing on the AnnData object. - - Parameters - ---------- - adata - The AnnData object to preprocess. - - Returns - ------- - Some integer value. - """ - print("Implement a preprocessing function here.") - return 0 diff --git a/src/nimbus_inference/tl/__init__.py b/src/nimbus_inference/tl/__init__.py deleted file mode 100644 index 95a32cd..0000000 --- a/src/nimbus_inference/tl/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .basic import basic_tool diff --git a/src/nimbus_inference/tl/basic.py b/src/nimbus_inference/tl/basic.py deleted file mode 100644 index d215ade..0000000 --- a/src/nimbus_inference/tl/basic.py +++ /dev/null @@ -1,17 +0,0 @@ -from anndata import AnnData - - -def basic_tool(adata: AnnData) -> int: - """Run a tool on the AnnData object. - - Parameters - ---------- - adata - The AnnData object to preprocess. - - Returns - ------- - Some integer value. - """ - print("Implement a tool to run on the AnnData object.") - return 0 From 5d4a8b014e27d9d9fdced299f240198c47097ace Mon Sep 17 00:00:00 2001 From: Lenz Date: Mon, 19 Feb 2024 18:34:54 +0100 Subject: [PATCH 08/31] Got rid of references to repo init files --- src/nimbus_inference/__init__.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/nimbus_inference/__init__.py b/src/nimbus_inference/__init__.py index ed5031e..3058dd4 100644 --- a/src/nimbus_inference/__init__.py +++ b/src/nimbus_inference/__init__.py @@ -1,7 +1,3 @@ from importlib.metadata import version -from . 
import pl, pp, tl - -__all__ = ["pl", "pp", "tl"] - __version__ = version("Nimbus-Inference") From 95aa27c84633bd223ec002d99abd0e94de9efe5d Mon Sep 17 00:00:00 2001 From: Lenz Date: Mon, 19 Feb 2024 18:43:17 +0100 Subject: [PATCH 09/31] added more dependencies --- pyproject.toml | 4 ++++ src/nimbus_inference/nimbus.py | 6 ------ 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 796efd1..faf4bb5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,10 @@ dependencies = [ "numpy", "pandas", "joblib", + "pandas", + "json", + "pathlib", + "glob", # for debug logging (referenced from the issue template) "session-info", ] diff --git a/src/nimbus_inference/nimbus.py b/src/nimbus_inference/nimbus.py index 203945a..7d77587 100644 --- a/src/nimbus_inference/nimbus.py +++ b/src/nimbus_inference/nimbus.py @@ -1,9 +1,3 @@ -# nimbus model class -# loads the image -# does inference on images, decides if tile & stitch or whole image inference -# preprocesses the predictions -# saves the output to the output folder - from alpineer import io_utils from skimage.util.shape import view_as_windows import nimbus_inference From 0154359cf782284979e187038bca0904354c6e12 Mon Sep 17 00:00:00 2001 From: Lenz Date: Mon, 19 Feb 2024 18:45:53 +0100 Subject: [PATCH 10/31] Changed dependencies again.. --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index faf4bb5..ac5f1df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,6 @@ dependencies = [ "pandas", "json", "pathlib", - "glob", # for debug logging (referenced from the issue template) "session-info", ] From 5878dd9b86c687b22c6c6fb442e8a39a9319796f Mon Sep 17 00:00:00 2001 From: Lenz Date: Mon, 19 Feb 2024 18:49:16 +0100 Subject: [PATCH 11/31] Changed dependencies again.. --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ac5f1df..010ec4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,6 @@ dependencies = [ "pandas", "joblib", "pandas", - "json", "pathlib", # for debug logging (referenced from the issue template) "session-info", From e07e8e6542b17bb2297dd2736a4b6f068fb0ab5f Mon Sep 17 00:00:00 2001 From: Lenz Date: Mon, 19 Feb 2024 18:51:43 +0100 Subject: [PATCH 12/31] Changed dependencies again.. 
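
A note on the dependency churn in patches 09 through 12: glob and json were removed again because they ship with the Python standard library and are not installable from PyPI, while pyometiff (added in the diff below) is a genuine third-party requirement that the OME-TIFF test fixtures later in this series rely on. A minimal sketch of the writer API it pulls in, mirroring the usage in tests/test_utils.py; the output path, channel names, and pixel sizes here are placeholders:

    import numpy as np
    from pyometiff import OMETIFFWriter

    # two dummy single-channel planes, stacked channel-first to match "CYX"
    channels = {"CD4": np.ones((256, 256)), "CD8": np.ones((256, 256))}
    metadata = {
        "PhysicalSizeX": "0.88", "PhysicalSizeXUnit": "µm",
        "PhysicalSizeY": "0.88", "PhysicalSizeYUnit": "µm",
        "Channels": {name: {"Name": name, "SamplesPerPixel": 1} for name in channels},
    }

    writer = OMETIFFWriter(
        fpath="fov_0.ome.tiff",          # placeholder output path
        dimension_order="CYX",           # channel, then y, then x
        array=np.stack(list(channels.values()), axis=0),
        metadata=metadata,
        explicit_tiffdata=False,
    )
    writer.write()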
--- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 010ec4b..40bb2ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "joblib", "pandas", "pathlib", + "pyometiff", # for debug logging (referenced from the issue template) "session-info", ] From 0fe60ac33114c8443f4065cf91d86c1fc2225eba Mon Sep 17 00:00:00 2001 From: Lenz Date: Tue, 20 Feb 2024 10:51:12 +0100 Subject: [PATCH 13/31] Added lfs pull to build.yaml --- .github/workflows/build.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 3221efe..c199cfe 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -19,6 +19,7 @@ jobs: uses: actions/checkout@v4 with: fetch-depth: 0 + lfs: true - name: Build Wheels and Source Distribution run: pipx run build --wheel --sdist From de19af242fce254956809844f89908fef1cc86d8 Mon Sep 17 00:00:00 2001 From: Lenz Date: Tue, 20 Feb 2024 11:03:49 +0100 Subject: [PATCH 14/31] Added lfs to ci and test workflows --- .github/workflows/ci.yaml | 1 + .github/workflows/test.yaml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 5bccd89..f7f7117 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -42,6 +42,7 @@ jobs: uses: actions/checkout@v4 with: fetch-depth: 0 + lfs: true - name: Download Coverage Artifact uses: actions/download-artifact@v4 diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 59fdc4d..b2f16ac 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -21,6 +21,7 @@ jobs: uses: actions/checkout@v3 with: fetch-depth: 0 + lfs: true - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 From e44bdac05936f5f6cb37cadcacea0e3c27b72c34 Mon Sep 17 00:00:00 2001 From: Lenz Date: Tue, 20 Feb 2024 11:12:24 +0100 Subject: [PATCH 15/31] Added git lfs pull to all checkout workflows --- .github/workflows/build.yaml | 1 + .github/workflows/ci.yaml | 1 + .github/workflows/test.yaml | 1 + 3 files changed, 3 insertions(+) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index c199cfe..504e2ce 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -20,6 +20,7 @@ jobs: with: fetch-depth: 0 lfs: true + run: git lfs pull - name: Build Wheels and Source Distribution run: pipx run build --wheel --sdist diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index f7f7117..03867d0 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -43,6 +43,7 @@ jobs: with: fetch-depth: 0 lfs: true + run: git lfs pull - name: Download Coverage Artifact uses: actions/download-artifact@v4 diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index b2f16ac..37ad24f 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -22,6 +22,7 @@ jobs: with: fetch-depth: 0 lfs: true + run: git lfs pull - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 From 22f31cc2f2ed1fefdb6a2a6b7d959d7f4142832c Mon Sep 17 00:00:00 2001 From: Lenz Date: Tue, 20 Feb 2024 11:15:47 +0100 Subject: [PATCH 16/31] Changed git lfs pull, because it killed all workflows --- .github/workflows/build.yaml | 4 +++- .github/workflows/ci.yaml | 4 +++- .github/workflows/test.yaml | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml 
index 504e2ce..91a1f58 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -20,7 +20,9 @@ jobs: with: fetch-depth: 0 lfs: true - run: git lfs pull + + - name: LFS Pull + run: git lfs fetch --all - name: Build Wheels and Source Distribution run: pipx run build --wheel --sdist diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 03867d0..b4ef570 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -43,7 +43,9 @@ jobs: with: fetch-depth: 0 lfs: true - run: git lfs pull + + - name: LFS Pull + run: git lfs fetch --all - name: Download Coverage Artifact uses: actions/download-artifact@v4 diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 37ad24f..682fbbb 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -22,7 +22,9 @@ jobs: with: fetch-depth: 0 lfs: true - run: git lfs pull + + - name: LFS Pull + run: git lfs fetch --all - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 From 7eb3b4b3afe1e1bf15f1d356983f42c05f406256 Mon Sep 17 00:00:00 2001 From: Lenz Date: Tue, 20 Feb 2024 11:27:11 +0100 Subject: [PATCH 17/31] Set checkout lfs to false and added lfs fetch afterwards --- .github/workflows/build.yaml | 1 - .github/workflows/ci.yaml | 1 - .github/workflows/test.yaml | 1 - 3 files changed, 3 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 91a1f58..f5925d5 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -19,7 +19,6 @@ jobs: uses: actions/checkout@v4 with: fetch-depth: 0 - lfs: true - name: LFS Pull run: git lfs fetch --all diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index b4ef570..af158bf 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -42,7 +42,6 @@ jobs: uses: actions/checkout@v4 with: fetch-depth: 0 - lfs: true - name: LFS Pull run: git lfs fetch --all diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 682fbbb..cbed93d 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -21,7 +21,6 @@ jobs: uses: actions/checkout@v3 with: fetch-depth: 0 - lfs: true - name: LFS Pull run: git lfs fetch --all From 092af1f9a4da8590072301d51d06024bffbe4fb9 Mon Sep 17 00:00:00 2001 From: Lenz Date: Tue, 20 Feb 2024 11:45:10 +0100 Subject: [PATCH 18/31] Swapped git lfs fetch --all to git lfs pull --- .github/workflows/build.yaml | 2 +- .github/workflows/ci.yaml | 2 +- .github/workflows/test.yaml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index f5925d5..3354ebe 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -21,7 +21,7 @@ jobs: fetch-depth: 0 - name: LFS Pull - run: git lfs fetch --all + run: git lfs pull - name: Build Wheels and Source Distribution run: pipx run build --wheel --sdist diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index af158bf..ffaa3fa 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -44,7 +44,7 @@ jobs: fetch-depth: 0 - name: LFS Pull - run: git lfs fetch --all + run: git lfs pull - name: Download Coverage Artifact uses: actions/download-artifact@v4 diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index cbed93d..5ad010d 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -23,7 +23,7 @@ jobs: fetch-depth: 0 - name: LFS Pull - run: git lfs fetch --all + run: git lfs pull - 
name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 From 09a8874d5666b8b2873127308cfe79c499505f3f Mon Sep 17 00:00:00 2001 From: Lenz Date: Tue, 20 Feb 2024 11:59:17 +0100 Subject: [PATCH 19/31] Added format to coveralls workflow --- .github/workflows/ci.yaml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ffaa3fa..035d9dd 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -43,9 +43,6 @@ jobs: with: fetch-depth: 0 - - name: LFS Pull - run: git lfs pull - - name: Download Coverage Artifact uses: actions/download-artifact@v4 # if `name` is not specified, all artifacts are downloaded. @@ -54,3 +51,4 @@ jobs: uses: coverallsapp/github-action@v2 with: github-token: ${{ secrets.GITHUB_TOKEN }} + format: lcov From 05e3da3862931150d96d939b51449847b5b17ecf Mon Sep 17 00:00:00 2001 From: Lenz Date: Tue, 20 Feb 2024 12:18:35 +0100 Subject: [PATCH 20/31] Changed pytest -cov argument --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 40bb2ba..7edb29c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,6 @@ testpaths = ["tests"] xfail_strict = true addopts = [ "--import-mode=importlib", # allow using test files with same name - "--cov=nimbus_inference", "--cov-report=lcov", ] From 7f72d2f78ae38d67465b2669fe7ad9b48573b421 Mon Sep 17 00:00:00 2001 From: Lenz Date: Tue, 20 Feb 2024 12:26:31 +0100 Subject: [PATCH 21/31] Reverted change in pyproject.toml --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 7edb29c..40bb2ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ testpaths = ["tests"] xfail_strict = true addopts = [ "--import-mode=importlib", # allow using test files with same name + "--cov=nimbus_inference", "--cov-report=lcov", ] From 560b6a77d0d44e5a68a21f00c459c6486e90010e Mon Sep 17 00:00:00 2001 From: Lenz Date: Tue, 20 Feb 2024 15:48:20 +0100 Subject: [PATCH 22/31] Added notebook and fixed gpu inference --- pyproject.toml | 2 +- src/nimbus_inference/nimbus.py | 61 ++++--- src/nimbus_inference/utils.py | 16 +- templates/1_Nimbus_Predict.ipynb | 303 +++++++++++++++++++++++++++++++ tests/test_utils.py | 2 +- 5 files changed, 348 insertions(+), 36 deletions(-) create mode 100644 templates/1_Nimbus_Predict.ipynb diff --git a/pyproject.toml b/pyproject.toml index 40bb2ba..85708e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ urls.Documentation = "https://Nimbus-Inference.readthedocs.io/" urls.Source = "https://github.com/angelolab/Nimbus-Inference" urls.Home-page = "https://github.com/angelolab/Nimbus-Inference" dependencies = [ - "torch==2.2.0", + "torch==2.2.0+cu118", "torchvision==0.17.0", "alpineer", "scikit-image", diff --git a/src/nimbus_inference/nimbus.py b/src/nimbus_inference/nimbus.py index 7d77587..4fceec9 100644 --- a/src/nimbus_inference/nimbus.py +++ b/src/nimbus_inference/nimbus.py @@ -16,6 +16,7 @@ import torch import json import os +import re def nimbus_preprocess(image, **kwargs): @@ -32,6 +33,8 @@ def nimbus_preprocess(image, **kwargs): normalize = kwargs.get("normalize", True) if normalize: marker = kwargs.get("marker", None) + if re.search(".tiff|.tiff|.png|.jpg|.jpeg", marker, re.IGNORECASE): + marker = marker.split(".")[0] normalization_dict = kwargs.get("normalization_dict", {}) if marker in normalization_dict.keys(): norm_factor = normalization_dict[marker] @@ -107,6 +110,7 @@ def __init__( 
self.suffix = suffix if self.output_dir != "": os.makedirs(self.output_dir, exist_ok=True) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def check_inputs(self): """check inputs for Nimbus model""" @@ -138,18 +142,14 @@ def initialize_model(self, padding="reflect"): padding (str): Padding mode for model, either "reflect" or "valid". """ model = UNet(num_classes=1, padding=padding) - # make sure path can be resolved on any OS and when importing from anywhere - self.checkpoint_path = os.path.normpath( - "src/nimbus_inference/assets/resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt" + # relative path to weights + path = os.path.dirname(nimbus_inference.__file__) + path = Path(path).resolve() + self.checkpoint_path = os.path.join( + path, + "assets", + "resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt" ) - if not os.path.exists(self.checkpoint_path): - path = os.path.abspath(nimbus_inference.__file__) - path = Path(path).resolve() - self.checkpoint_path = os.path.join( - *path.parts[:-3], - "assets", - "resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt" - ) if not os.path.exists(self.checkpoint_path): self.checkpoint_path = os.path.abspath( *glob( @@ -171,14 +171,15 @@ def initialize_model(self, padding="reflect"): raise FileNotFoundError( "Could not find Nimbus weights at {ckpt_path}. \ Current path is {current_path} and directory contains {dir_c},\ - path to cell_clasification i{p}".format( + path to nimbus_inference {p}".format( ckpt_path=self.checkpoint_path, current_path=os.getcwd(), dir_c=os.listdir(os.getcwd()), - p=os.path.abspath(nimbus_inference.__file__), + p=nimbus_inference.__file__, ) ) - self.model = model + self.model = model.to(self.device) + def prepare_normalization_dict( self, quantile=0.999, n_subset=10, multiprocessing=False, overwrite=False, @@ -214,22 +215,27 @@ def predict_fovs(self): if not hasattr(self, "normalization_dict"): self.prepare_normalization_dict() # check if GPU is available - print("Available GPUs: ", torch.cuda.device_count()) + gpus = torch.cuda.device_count() + print("Available GPUs: ", gpus) print("Predictions will be saved in {}".format(self.output_dir)) print("Iterating through fovs will take a while...") if self.suffix.lower() in [".ome.tif", ".ome.tiff"]: self.cell_table = predict_ome_fovs( - self.fov_paths, self.output_dir, self, self.normalization_dict, - self.segmentation_naming_convention, self.include_channels, - self.save_predictions, self.half_resolution, batch_size=self.batch_size, - test_time_augmentation=self.test_time_aug, + nimbus=self, fov_paths=self.fov_paths, output_dir=self.output_dir, + normalization_dict=self.normalization_dict, + segmentation_naming_convention=self.segmentation_naming_convention, + include_channels=self.include_channels, save_predictions=self.save_predictions, + half_resolution=self.half_resolution, batch_size=self.batch_size, + test_time_augmentation=self.test_time_aug, suffix=self.suffix, ) elif self.suffix.lower() in [".tiff", ".tif", ".jpg", ".jpeg", ".png"]: self.cell_table = predict_fovs( - self.fov_paths, self.output_dir, self, self.normalization_dict, - self.segmentation_naming_convention, self.include_channels, - self.save_predictions, self.half_resolution, batch_size=self.batch_size, - test_time_augmentation=self.test_time_aug, + nimbus=self, fov_paths=self.fov_paths, output_dir=self.output_dir, + normalization_dict=self.normalization_dict, + 
segmentation_naming_convention=self.segmentation_naming_convention, + include_channels=self.include_channels, save_predictions=self.save_predictions, + half_resolution=self.half_resolution, batch_size=self.batch_size, + test_time_augmentation=self.test_time_aug, suffix=self.suffix, ) self.cell_table.to_csv(os.path.join(self.output_dir, "nimbus_cell_table.csv"), index=False) return self.cell_table @@ -244,17 +250,19 @@ def predict_segmentation(self, input_data, preprocess_kwargs): np.array: Predicted segmentation. """ input_data = nimbus_preprocess(input_data, **preprocess_kwargs) - if np.all(np.greater_equal(self.input_shape, input_data.shape[-2:])): + if np.all(np.greater(self.input_shape, input_data.shape[-2:])): if not hasattr(self, "model") or self.model.padding != "reflect": self.initialize_model(padding="reflect") with torch.no_grad(): if not isinstance(input_data, torch.Tensor): input_data = torch.tensor(input_data).float() + input_data = input_data.to(self.device) prediction = self.model(input_data) prediction = prediction.cpu().squeeze(0).numpy() else: if not hasattr(self, "model") or self.model.padding != "valid": self.initialize_model(padding="valid") + prediction = self._tile_and_stitch(input_data, self.batch_size) return prediction @@ -267,7 +275,7 @@ def _tile_and_stitch(self, input_data): np.array: Predicted segmentation. """ with torch.no_grad(): - output_shape = self.model(torch.rand(1, 2, *self.input_shape)).shape[-2:] + output_shape = self.model(torch.rand(1, 2, *self.input_shape).to(self.device)).shape[-2:] input_shape = input_data.shape # f^dl crop to have perfect shift equivariance inference self.crop_by = np.array(output_shape) % 2 ** 5 @@ -288,8 +296,7 @@ def _tile_and_stitch(self, input_data): prediction = [] for i in tqdm(range(0, len(tiled_input), self.batch_size)): batch = torch.from_numpy(tiled_input[i : i + self.batch_size]).float() - if torch.cuda.is_available(): - batch = batch.cuda() + batch = batch.to(self.device) with torch.no_grad(): pred = self.model(batch).cpu().numpy() # crop pred diff --git a/src/nimbus_inference/utils.py b/src/nimbus_inference/utils.py index a609e2c..632d8bd 100644 --- a/src/nimbus_inference/utils.py +++ b/src/nimbus_inference/utils.py @@ -21,6 +21,7 @@ def calculate_normalization(channel_path, quantile): normalization_value (float): normalization value """ mplex_img = io.imread(channel_path) + mplex_img = mplex_img.astype(np.float32) normalization_value = np.quantile(mplex_img, quantile) chan = os.path.basename(channel_path).split(".")[0] return chan, normalization_value @@ -84,6 +85,7 @@ def prepare_input_data(mplex_img, instance_mask): Returns: input_data (np.array): input data for segmentation model """ + mplex_img = mplex_img.astype(np.float32) edge = find_boundaries(instance_mask, mode="inner").astype(np.uint8) binary_mask = np.logical_and(edge == 0, instance_mask > 0).astype(np.float32) input_data = np.stack([mplex_img, binary_mask], axis=0)[np.newaxis,...] 
# bchw
@@ -183,12 +185,13 @@
         cell_table (pd.DataFrame): cell table with predicted confidence scores per fov and cell
     """
     fov_dict_list = []
-    for fov_path in tqdm(fov_paths):
+    for fov_path in fov_paths:
+        print(f"Predicting {fov_path}...")
         out_fov_path = os.path.join(
             os.path.normpath(output_dir), os.path.basename(fov_path)
         )
         df_fov = pd.DataFrame()
-        for channel in os.listdir(fov_path):
+        for channel in tqdm(os.listdir(fov_path)):
             channel_path = os.path.join(fov_path, channel)
             channel_ = channel.split(".")[0]
             if not channel.endswith(suffix) or (
@@ -213,13 +216,12 @@
                     input_data, channel, nimbus, normalization_dict, batch_size=batch_size
                 )
             else:
-                prediction = nimbus._predict_segmentation(
+                prediction = nimbus.predict_segmentation(
                     input_data,
                     preprocess_kwargs={
                         "normalize": True,
                         "marker": channel,
                         "normalization_dict": normalization_dict
                     },
-                    batch_size=batch_size
                 )
             prediction = np.squeeze(prediction)
             if half_resolution:
@@ -233,8 +235,8 @@
             if save_predictions:
                 os.makedirs(out_fov_path, exist_ok=True)
                 pred_int = (prediction*255.0).astype(np.uint8)
                 io.imsave(
-                    os.path.join(out_fov_path, channel), pred_int,
-                    photometric="minisblack", compression="zlib"
+                    os.path.join(out_fov_path, channel), pred_int, check_contrast=False,
+                    plugin="tifffile", photometric="minisblack", compression="zlib",
                 )
         fov_dict_list.append(df_fov)
     cell_table = pd.concat(fov_dict_list, ignore_index=True)
@@ -248,7 +250,7 @@ def nimbus_preprocess(image, **kwargs):
     Returns:
         np.array: processed image array
     """
-    output = np.copy(image)
+    output = np.copy(image.astype(np.float32))
     if len(image.shape) != 4:
         raise ValueError("Image data must be 4D, got image of shape {}".format(image.shape))

diff --git a/templates/1_Nimbus_Predict.ipynb b/templates/1_Nimbus_Predict.ipynb
new file mode 100644
index 0000000..5c5694d
--- /dev/null
+++ b/templates/1_Nimbus_Predict.ipynb
@@ -0,0 +1,303 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "673d8e3a",
+   "metadata": {},
+   "source": [
+    "# Nimbus prediction notebook "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f920e689",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# import required packages\n",
+    "import warnings\n",
+    "warnings.simplefilter(\"ignore\")\n",
+    "import os\n",
+    "from IPython.display import display, HTML\n",
+    "display(HTML(\"\"))\n",
+    "from nimbus_inference.nimbus import Nimbus, prep_naming_convention\n",
+    "from alpineer import io_utils\n",
+    "from ark.utils import example_dataset\n",
+    "from nimbus_inference.viewer_widget import NimbusViewer"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e4642fe2",
+   "metadata": {},
+   "source": [
+    "## 0: Set root directory and download example dataset\n",
+    "Here we are using the example data located in `/data/example_dataset/input_data`. To modify this notebook to run using your own data, simply change `base_dir` to point to your own sub-directory within the data folder. Set `base_dir`, the path to all of your imaging data (i.e. multiplexed images and segmentation masks). Subdirectory `nimbus_output` will contain all of the data generated by this notebook. 
In the following, we expect this folder structure:\n",
+    "```\n",
+    "|-- base_dir\n",
+    "| |-- image_data\n",
+    "| | |-- fov_1\n",
+    "| | |-- fov_2\n",
+    "| |-- segmentation\n",
+    "| | |-- deepcell_output\n",
+    "| |-- nimbus_output\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "974f8dda",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# set up the base directory\n",
+    "base_dir = os.path.normpath(\"../data/example_dataset\")\n",
+    "# base_dir = os.path.normpath(\"C:/Users/lorenz/Desktop/angelo_lab/data/example_dataset\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0ade450f",
+   "metadata": {},
+   "source": [
+    "If you would like to test Nimbus with an example dataset, run the cell below. It will download a dataset consisting of 10 FOVs with 22 channels. You may find more information about the example dataset in the [ark-analysis README](https://github.com/angelolab/ark-analysis/blob/bc6685050dfbef4607874fbbadebd4289251c173/README.md#example-dataset). If you want to use your own data, skip the cell below.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "37733de5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "example_dataset.get_example_dataset(dataset=\"cluster_pixels\", save_dir = base_dir, overwrite_existing = False)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9cd2ab6c",
+   "metadata": {},
+   "source": [
+    "## 1: Set file paths and parameters\n",
+    "\n",
+    "### All data, images, files, etc. must be placed in the 'data' directory, and referenced via '../data/path_to_your_data'\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "292e4524",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# set up file paths\n",
+    "tiff_dir = os.path.join(base_dir, \"image_data\")\n",
+    "deepcell_output_dir = os.path.join(base_dir, \"segmentation\", \"deepcell_output\")\n",
+    "nimbus_output_dir = os.path.join(base_dir, \"nimbus_output\")\n",
+    "\n",
+    "# Create nimbus output directory\n",
+    "os.makedirs(nimbus_output_dir, exist_ok=True)\n",
+    "\n",
+    "# Check if paths exist\n",
+    "io_utils.validate_paths([base_dir, tiff_dir, deepcell_output_dir, nimbus_output_dir])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ae89442a",
+   "metadata": {},
+   "source": [
+    "## 2: Set up input paths and the naming convention for the segmentation data\n",
+    "Store names of channels to include in the list below. Either predict all FOVs or manually specify the ones you want to apply Nimbus on."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "65a319c9",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# define the channels to include\n",
+    "include_channels = [\n",
+    "    \"CD3\", \"CD4\", \"CD8\", \"CD14\", \"CD20\", \"CD31\", \"CD45\", \"CD68\", \"CD163\", \"CK17\", \"Collagen1\",\n",
+    "    \"ECAD\", \"Fibronectin\", \"GLUT1\", \"HLADR\", \"IDO\", \"Ki67\", \"PD1\", \"SMA\", \"Vim\"\n",
+    "]\n",
+    "\n",
+    "# either get all fovs in the folder...\n",
+    "fov_names = os.listdir(tiff_dir)\n",
+    "# ... or optionally, select a specific set of fovs manually\n",
+    "# fov_names = [\"fov0\", \"fov1\"]\n",
+    "\n",
+    "# construct paths for fovs\n",
+    "fov_paths = [os.path.join(tiff_dir, fov_name) for fov_name in fov_names]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8c85f682",
+   "metadata": {},
+   "source": [
+    "Define the naming convention for the segmentation data in the function `segmentation_naming_convention`, which maps the `fov_name` to the path of the associated segmentation output. 
The function `prep_naming_convention` used below assumes that all segmentation outputs are stored in one folder, with the `fov_name` as the prefix and `_whole_cell.tiff` as the suffix, as shown below in the visualization of the folder structure. If this does not apply to your data, you have to define a function `segmentation_naming_convention` that takes an element from `fov_paths` and returns a valid path to the segmentation label map you want to use for that fov.\n",
+    "\n",
+    "```\n",
+    "|-- base_dir\n",
+    "| |-- image_data\n",
+    "| | |-- fov_1\n",
+    "| | |-- fov_2\n",
+    "| |-- segmentation\n",
+    "| | |-- deepcell_output\n",
+    "| | | |-- fov_1_whole_cell.tiff\n",
+    "| | | |-- fov_2_whole_cell.tiff\n",
+    "| |-- nimbus_output\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fc8256e6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Prepare segmentation naming convention that maps a fov_path to the corresponding segmentation label map\n",
+    "segmentation_naming_convention = prep_naming_convention(deepcell_output_dir)\n",
+    "\n",
+    "# test segmentation_naming_convention\n",
+    "if os.path.exists(segmentation_naming_convention(fov_paths[0])):\n",
+    "    print(\"Segmentation data exists for fov 0 and naming convention is correct\")\n",
+    "else:\n",
+    "    print(\"Segmentation data does not exist for fov 0 or naming convention is incorrect\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "839e5240",
+   "metadata": {},
+   "source": [
+    "## 3: Load model and initialize Nimbus application\n",
+    "The following code initializes the Nimbus application and loads the model checkpoint. The model was trained on a diverse set of tissues, protein markers, imaging platforms and cell types and doesn't need re-training. If you want to use the model on a machine without a GPU, set `test_time_aug=False` to speed up inference. If you run it on a laptop GPU and run into out-of-memory errors, consider reducing the `batch_size` to 1 and the `input_shape` to `[512,512]`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7fd0a575",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "nimbus = Nimbus(\n",
+    "    fov_paths=fov_paths,\n",
+    "    segmentation_naming_convention=segmentation_naming_convention,\n",
+    "    output_dir=nimbus_output_dir,\n",
+    "    include_channels=include_channels,\n",
+    "    save_predictions=True,\n",
+    "    batch_size=4,\n",
+    "    test_time_aug=True,\n",
+    "    input_shape=[1024,1024],\n",
+    "    suffix=\".tiff\",\n",
+    ")\n",
+    "\n",
+    "# check if all inputs are valid\n",
+    "nimbus.check_inputs()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "bbce682e",
+   "metadata": {},
+   "source": [
+    "## 4: Prepare normalization dictionary \n",
+    "The next step is to iterate through all the fovs and calculate the 0.999 marker expression quantile for each marker individually. This is used for normalizing the marker expressions prior to predicting marker confidence scores with our model. You can set `n_subset` to estimate the quantiles on a small subset of the data and you can set `multiprocessing=True` to speed up computation."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41b100e7", + "metadata": {}, + "outputs": [], + "source": [ + "nimbus.prepare_normalization_dict(\n", + " n_subset=50,\n", + " multiprocessing=True,\n", + " overwrite=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9e782794", + "metadata": {}, + "source": [ + "## 5: Make predictions with the model\n", + "Nimbus will iterate through your samples and store predictions and a file named `nimbus_cell_table.csv` that contains the mean-per-cell predicted marker confidence scores in the sub-directory called `nimbus_output`." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "76225704", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "cell_table = nimbus.predict_fovs()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca222e0e", + "metadata": {}, + "outputs": [], + "source": [ + "cell_table" + ] + }, + { + "cell_type": "markdown", + "id": "fdef2ab9", + "metadata": {}, + "source": [ + "## 6: View multiplexed channels and Nimbus predictions side-by-side\n", + "Select an FOV and one marker image per channel to inspect the imaging data and associated Nimbus predictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f95e351", + "metadata": {}, + "outputs": [], + "source": [ + "viewer = NimbusViewer(input_dir=tiff_dir, output_dir=nimbus_output_dir)\n", + "viewer.display()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/test_utils.py b/tests/test_utils.py index 976aef2..7882bae 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -148,7 +148,7 @@ def test_prepare_input_data(): assert np.alltrue(np.unique(input_data[:,1]) == np.array([0, 1])) assert np.sum(input_data[:,1]) < np.sum(instance_mask) # check if mplex image is the same as before - assert np.alltrue(input_data[0, 0] == mplex_img) + assert np.alltrue(input_data[0, 0] == mplex_img.astype(np.float32)) def test_tt_aug(): From eb8dd853e619f683bd6f3fc9730f2f5c168e82f3 Mon Sep 17 00:00:00 2001 From: Lenz Date: Wed, 21 Feb 2024 11:11:44 +0100 Subject: [PATCH 23/31] Fixed image normalization --- src/nimbus_inference/nimbus.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nimbus_inference/nimbus.py b/src/nimbus_inference/nimbus.py index 4fceec9..77a1af8 100644 --- a/src/nimbus_inference/nimbus.py +++ b/src/nimbus_inference/nimbus.py @@ -45,9 +45,9 @@ def nimbus_preprocess(image, **kwargs): marker ) ) - norm_factor = np.quantile(output[..., 0], 0.999) + norm_factor = np.quantile(output[:,0,...], 0.999) # normalize only marker channel in chan 0 not binary mask in chan 1 - output[..., 0] /= norm_factor + output[:,0,...] 
/= norm_factor output = output.clip(0, 1) return output From 23e70a86943dbac0aace4b87f7bd349b1faf7593 Mon Sep 17 00:00:00 2001 From: Lenz Date: Thu, 22 Feb 2024 16:32:08 +0100 Subject: [PATCH 24/31] Deleted lfs checkpoint and added huggingface download --- pyproject.toml | 1 + ...ec_mskc_mskp_2_channel_halfres_512_bs32.pt | 3 -- src/nimbus_inference/nimbus.py | 36 +++++-------------- 3 files changed, 10 insertions(+), 30 deletions(-) delete mode 100644 src/nimbus_inference/assets/resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt diff --git a/pyproject.toml b/pyproject.toml index 85708e6..3d1b609 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "pandas", "pathlib", "pyometiff", + "huggingface_hub", # for debug logging (referenced from the issue template) "session-info", ] diff --git a/src/nimbus_inference/assets/resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt b/src/nimbus_inference/assets/resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt deleted file mode 100644 index 6d98b5e..0000000 --- a/src/nimbus_inference/assets/resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a9f607c19ecf3b1ff998b52b75a8290efefb315ba2808e7d314e6444c8fda885 -size 142545376 diff --git a/src/nimbus_inference/nimbus.py b/src/nimbus_inference/nimbus.py index 77a1af8..6b5b520 100644 --- a/src/nimbus_inference/nimbus.py +++ b/src/nimbus_inference/nimbus.py @@ -7,6 +7,7 @@ predict_ome_fovs, nimbus_preprocess, ) +from huggingface_hub import hf_hub_download from nimbus_inference.unet import UNet from tqdm.autonotebook import tqdm from pathlib import Path @@ -151,36 +152,17 @@ def initialize_model(self, padding="reflect"): "resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt" ) if not os.path.exists(self.checkpoint_path): - self.checkpoint_path = os.path.abspath( - *glob( - "**/resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt" - ) - ) - - if not os.path.exists(self.checkpoint_path): - self.checkpoint_path = os.path.join( - os.getcwd(), - "assets", - "resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt", - ) - - if os.path.exists(self.checkpoint_path): - model.load_state_dict(torch.load(self.checkpoint_path)) - print("Loaded weights from {}".format(self.checkpoint_path)) - else: - raise FileNotFoundError( - "Could not find Nimbus weights at {ckpt_path}. 
\ - Current path is {current_path} and directory contains {dir_c},\ - path to nimbus_inference {p}".format( - ckpt_path=self.checkpoint_path, - current_path=os.getcwd(), - dir_c=os.listdir(os.getcwd()), - p=nimbus_inference.__file__, - ) + local_dir = os.path.join(path, "assets") + self.checkpoint_path = hf_hub_download( + repo_id="JLrumberger/Nimbus-Inference", + filename="resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt", + local_dir=local_dir, + local_dir_use_symlinks=False, ) + model.load_state_dict(torch.load(self.checkpoint_path)) + print("Loaded weights from {}".format(self.checkpoint_path)) self.model = model.to(self.device) - def prepare_normalization_dict( self, quantile=0.999, n_subset=10, multiprocessing=False, overwrite=False, ): From 8b70ccd8dafe03a4bd4a324a5549aa52aa760dd0 Mon Sep 17 00:00:00 2001 From: Lenz Date: Thu, 22 Feb 2024 16:48:56 +0100 Subject: [PATCH 25/31] Changed dependencies --- pyproject.toml | 3 ++- src/nimbus_inference/nimbus.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3d1b609..014c425 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ urls.Documentation = "https://Nimbus-Inference.readthedocs.io/" urls.Source = "https://github.com/angelolab/Nimbus-Inference" urls.Home-page = "https://github.com/angelolab/Nimbus-Inference" dependencies = [ - "torch==2.2.0+cu118", + "torch==2.2.0", "torchvision==0.17.0", "alpineer", "scikit-image", @@ -36,6 +36,7 @@ dependencies = [ "session-info", ] + [project.optional-dependencies] dev = ["pre-commit", "twine>=4.0.2"] doc = [ diff --git a/src/nimbus_inference/nimbus.py b/src/nimbus_inference/nimbus.py index 6b5b520..80a13f3 100644 --- a/src/nimbus_inference/nimbus.py +++ b/src/nimbus_inference/nimbus.py @@ -153,6 +153,7 @@ def initialize_model(self, padding="reflect"): ) if not os.path.exists(self.checkpoint_path): local_dir = os.path.join(path, "assets") + print("Downloading weights from Hugging Face Hub...") self.checkpoint_path = hf_hub_download( repo_id="JLrumberger/Nimbus-Inference", filename="resUnet_baseline_hickey_tonic_dec_mskc_mskp_2_channel_halfres_512_bs32.pt", From d00a26835a1c27b04b8838d0dacdb9a54fe780fc Mon Sep 17 00:00:00 2001 From: Lenz Date: Thu, 22 Feb 2024 17:01:06 +0100 Subject: [PATCH 26/31] Got rid of lfs pull in workflows --- .github/workflows/build.yaml | 3 --- .github/workflows/test.yaml | 3 --- 2 files changed, 6 deletions(-) diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 3354ebe..3221efe 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -20,9 +20,6 @@ jobs: with: fetch-depth: 0 - - name: LFS Pull - run: git lfs pull - - name: Build Wheels and Source Distribution run: pipx run build --wheel --sdist diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 5ad010d..59fdc4d 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -22,9 +22,6 @@ jobs: with: fetch-depth: 0 - - name: LFS Pull - run: git lfs pull - - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: From ee5c75ecfa36dac1f5afe7901a7fc03ea869ac96 Mon Sep 17 00:00:00 2001 From: Lenz Date: Fri, 23 Feb 2024 18:30:00 +0100 Subject: [PATCH 27/31] Fixed tile and stitch for batches --- src/nimbus_inference/nimbus.py | 28 +++++++++++++++------------- src/nimbus_inference/utils.py | 12 ++++++++---- tests/test_nimbus.py | 8 ++++---- 3 files changed, 27 insertions(+), 21 deletions(-) diff --git 
a/src/nimbus_inference/nimbus.py b/src/nimbus_inference/nimbus.py index 80a13f3..012f277 100644 --- a/src/nimbus_inference/nimbus.py +++ b/src/nimbus_inference/nimbus.py @@ -49,7 +49,7 @@ def nimbus_preprocess(image, **kwargs): norm_factor = np.quantile(output[:,0,...], 0.999) # normalize only marker channel in chan 0 not binary mask in chan 1 output[:,0,...] /= norm_factor - output = output.clip(0, 1) + # output = output.clip(0, 1) return output @@ -245,8 +245,7 @@ def predict_segmentation(self, input_data, preprocess_kwargs): else: if not hasattr(self, "model") or self.model.padding != "valid": self.initialize_model(padding="valid") - - prediction = self._tile_and_stitch(input_data, self.batch_size) + prediction = self._tile_and_stitch(input_data) return prediction def _tile_and_stitch(self, input_data): @@ -271,9 +270,9 @@ def _tile_and_stitch(self, input_data): padding[0] - shape_diff[0], padding[1] - shape_diff[0], padding[2] - shape_diff[1], padding[3] - shape_diff[1], ] - h_t, h_w = tiled_input.shape[:2] + h_t, h_w, b, c = tiled_input.shape[:4] tiled_input = tiled_input.reshape( - h_t * h_w, input_shape[1], *self.input_shape + h_t * h_w * b, c, *self.input_shape ) # h_t,w_t,c,h,w -> h_t*w_t,c,h,w # predict tiles prediction = [] @@ -291,7 +290,7 @@ def _tile_and_stitch(self, input_data): ] prediction += [pred] prediction = np.concatenate(prediction) # h_t*w_t,c,h,w - prediction = prediction.reshape(h_t, h_w, *prediction.shape[1:]) # h_t,w_t,c,h,w + prediction = prediction.reshape(h_t, h_w, b, *prediction.shape[1:]) # h_t,w_t,b,c,h,w # stitch tiles prediction = self._stitch_tiles(prediction, padding) return prediction @@ -316,9 +315,11 @@ def _tile_input(self, image, tile_size, output_shape, pad_mode="reflect"): image = np.pad(image, ((0, 0), (0, 0), (pad_h0, pad_h1), (pad_w0, pad_w1)), mode=pad_mode) b, c = image.shape[:2] # tile image - view = np.squeeze( - view_as_windows(image, [b, c] + list(tile_size), step=[b, c] + list(output_shape)) - ) # h_t,w_t,c,h,w + view = np.squeeze( + view_as_windows(image, [b, c] + list(tile_size), step=[b, c] + list(output_shape)), + axis=(0,1) + ) + # h_t,w_t,b,c,h,w padding = [pad_h0, pad_h1, pad_w0, pad_w1] return view, padding @@ -331,11 +332,12 @@ def _stitch_tiles(self, tiles, padding): np.array: Reconstructed image. 
""" # stitch tiles - h_t, w_t, c, h, w = tiles.shape - stitched = np.zeros((c, h_t * h, w_t * w)) + h_t, w_t, b, c, h, w = tiles.shape + stitched = np.zeros((b, c, h_t * h, w_t * w)) for i in range(h_t): for j in range(w_t): - stitched[:, i * h : (i + 1) * h, j * w : (j + 1) * w] = tiles[i, j] + for b_ in range(b): + stitched[b_, :, i * h : (i + 1) * h, j * w : (j + 1) * w] = tiles[i, j, b_] # remove padding - stitched = stitched[:, padding[0] : -padding[1], padding[2] : -padding[3]] + stitched = stitched[:, :, padding[0] : -padding[1], padding[2] : -padding[3]] return stitched diff --git a/src/nimbus_inference/utils.py b/src/nimbus_inference/utils.py index 632d8bd..406b356 100644 --- a/src/nimbus_inference/utils.py +++ b/src/nimbus_inference/utils.py @@ -5,9 +5,11 @@ import random import numpy as np import pandas as pd -from skimage import io +import imageio as io +# from skimage import io from tqdm.autonotebook import tqdm from joblib import Parallel, delayed +from joblib.externals.loky import get_reusable_executor from skimage.segmentation import find_boundaries from skimage.measure import regionprops_table @@ -69,6 +71,8 @@ def prepare_normalization_dict( if channel not in normalization_dict: normalization_dict[channel] = [] normalization_dict[channel].append(normalization_value) + if n_jobs > 1: + get_reusable_executor().shutdown(wait=True) for channel in normalization_dict.keys(): normalization_dict[channel] = np.mean(normalization_dict[channel]) # save normalization dict @@ -234,9 +238,9 @@ def predict_fovs( if save_predictions: os.makedirs(out_fov_path, exist_ok=True) pred_int = (prediction*255.0).astype(np.uint8) - io.imsave( - os.path.join(out_fov_path, channel), pred_int, check_contrast=False, - plugin="tifffile", photometric="minisblack", compression="zlib", + io.imwrite( + os.path.join(out_fov_path, channel), pred_int, photometric="minisblack", + # compress=0, ) fov_dict_list.append(df_fov) cell_table = pd.concat(fov_dict_list, ignore_index=True) diff --git a/tests/test_nimbus.py b/tests/test_nimbus.py index 55dc465..a954e20 100644 --- a/tests/test_nimbus.py +++ b/tests/test_nimbus.py @@ -67,9 +67,9 @@ def test_tile_input(): nimbus = Nimbus(fov_paths=[""], segmentation_naming_convention="", output_dir="") nimbus.model = lambda x: x[..., 96:-96, 96:-96] tiled_input, padding = nimbus._tile_input(image, tile_size, output_shape) - assert tiled_input.shape == (3,3,2,512,512) + assert tiled_input.shape == (3,3,1,2,512,512) assert padding == [192, 192, 192, 192] - + def test_tile_and_stitch(): # tests _tile_and_stitch which chains _tile_input, model.forward and _stitch_tiles @@ -85,12 +85,12 @@ def test_tile_and_stitch(): nimbus.model = lambda x: x[..., s:-s, s:-s] out = nimbus._tile_and_stitch(image) assert np.all( - np.isclose(np.transpose(image[0], (1,2,0)), np.transpose(out, (1,2,0)), rtol=1e-4) + np.isclose(image, out, rtol=1e-4) ) # check if tile and stitch works with the real model nimbus.initialize_model(padding="valid") image = np.random.rand(1, 2, 768, 768) prediction = nimbus._tile_and_stitch(image) - assert prediction.shape == (1, 768, 768) + assert prediction.shape == (1, 1, 768, 768) assert prediction.max() <= 1 assert prediction.min() >= 0 From 4089704940497b2cefe7ef4c08a967b647ce081e Mon Sep 17 00:00:00 2001 From: Lenz Date: Mon, 26 Feb 2024 13:06:56 +0100 Subject: [PATCH 28/31] Put model into eval mode and added viewer widget updates --- src/nimbus_inference/nimbus.py | 8 +- src/nimbus_inference/unet.py | 6 +- src/nimbus_inference/viewer_widget.py | 194 
From 4089704940497b2cefe7ef4c08a967b647ce081e Mon Sep 17 00:00:00 2001
From: Lenz
Date: Mon, 26 Feb 2024 13:06:56 +0100
Subject: [PATCH 28/31] Put model into eval mode and added viewer widget updates

---
 src/nimbus_inference/nimbus.py        |   8 +-
 src/nimbus_inference/unet.py          |   6 +-
 src/nimbus_inference/viewer_widget.py | 194 ++++++++++++++++++++++++++
 tests/test_viewer_widget.py           |  33 +++++
 4 files changed, 231 insertions(+), 10 deletions(-)
 create mode 100644 src/nimbus_inference/viewer_widget.py
 create mode 100644 tests/test_viewer_widget.py

diff --git a/src/nimbus_inference/nimbus.py b/src/nimbus_inference/nimbus.py
index 012f277..90efcc6 100644
--- a/src/nimbus_inference/nimbus.py
+++ b/src/nimbus_inference/nimbus.py
@@ -162,7 +162,7 @@ def initialize_model(self, padding="reflect"):
         )
         model.load_state_dict(torch.load(self.checkpoint_path))
         print("Loaded weights from {}".format(self.checkpoint_path))
-        self.model = model.to(self.device)
+        self.model = model.to(self.device).eval()
 
     def prepare_normalization_dict(
         self, quantile=0.999, n_subset=10, multiprocessing=False, overwrite=False,
@@ -258,7 +258,6 @@ def _tile_and_stitch(self, input_data):
         """
         with torch.no_grad():
             output_shape = self.model(torch.rand(1, 2, *self.input_shape).to(self.device)).shape[-2:]
-        input_shape = input_data.shape
         # f^dl crop to have perfect shift equivariance inference
         self.crop_by = np.array(output_shape) % 2 ** 5
         output_shape = output_shape - self.crop_by
@@ -285,8 +284,8 @@ def _tile_and_stitch(self, input_data):
             if self.crop_by.any():
                 pred = pred[
                     ...,
-                    self.crop_by[0] // 2 : -self.crop_by[0] // 2,
-                    self.crop_by[1] // 2 : -self.crop_by[1] // 2,
+                    self.crop_by[0]//2 : -self.crop_by[0]//2,
+                    self.crop_by[1]//2 : -self.crop_by[1]//2,
                 ]
             prediction += [pred]
         prediction = np.concatenate(prediction)  # h_t*w_t,c,h,w
@@ -305,7 +304,6 @@ def _tile_input(self, image, tile_size, output_shape, pad_mode="reflect"):
         Returns:
             list: List of tiled images.
         """
-        overlap_px = (np.array(tile_size) - np.array(output_shape)) // 2
         # pad image to be divisible by tile size
         pad_h0, pad_w0 = np.array(tile_size) - (
             np.array(image.shape[-2:]) % np.array(output_shape)
diff --git a/src/nimbus_inference/unet.py b/src/nimbus_inference/unet.py
index 7682cb3..f050bed 100644
--- a/src/nimbus_inference/unet.py
+++ b/src/nimbus_inference/unet.py
@@ -108,8 +108,6 @@ def __init__(
         )
         self.bn_2 = nn.BatchNorm2d(filters)
 
-
-
     def forward(self, x):
         skip = self.conv2d_0(x)
         x = self.padding_layer(skip)
@@ -195,9 +193,7 @@ def forward(self, x, down_layer):
 
 
 class UNet(nn.Module):
-    def __init__(self, nx: int = 512,
-                 ny: int = 512,
-                 channels: int = 1,
+    def __init__(self,
                  num_classes: int = 2,
                  layer_depth: int = 5,
                  filters_root: int = 64,
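The `.eval()` change matters because the UNet's ConvBlocks contain BatchNorm layers: in train mode they normalize with per-batch statistics, in eval mode with stored running statistics, so inference without `.eval()` is non-deterministic across batch compositions. A toy comparison, with an illustrative network rather than the actual UNet:

    import torch
    import torch.nn as nn

    net = nn.Sequential(nn.Conv2d(2, 8, 3, padding=1), nn.BatchNorm2d(8))
    x = torch.rand(1, 2, 64, 64)

    net.train()
    y_train = net(x)  # BatchNorm normalizes with this batch's statistics

    net.eval()
    with torch.no_grad():
        y_eval = net(x)  # BatchNorm uses stored running statistics instead

    assert not torch.allclose(y_train, y_eval)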
+ """ + self.image_width = img_width + self.input_dir = input_dir + self.output_dir = output_dir + self.fov_names = [os.path.basename(p) for p in os.listdir(output_dir) if \ + os.path.isdir(os.path.join(output_dir, p))] + self.fov_names = natsorted(self.fov_names) + self.update_button = widgets.Button(description="Update Image") + self.update_button.on_click(self.update_button_click) + + self.fov_select = widgets.Select( + options=self.fov_names, + description='FOV:', + disabled=False + ) + self.fov_select.observe(self.select_fov, names='value') + + self.red_select = widgets.Select( + options=[], + description='Red:', + disabled=False + ) + self.green_select = widgets.Select( + options=[], + description='Green:', + disabled=False + ) + self.blue_select = widgets.Select( + options=[], + description='Blue:', + disabled=False + ) + self.input_image = widgets.Image() + self.output_image = widgets.Image() + + def select_fov(self, change): + """Selects fov to display. + Args: + change (dict): Change dictionary from ipywidgets. + """ + fov_path = os.path.join(self.output_dir, self.fov_select.value) + channels = [ + ch for ch in os.listdir(fov_path) if os.path.isfile(os.path.join(fov_path, ch)) + ] + self.red_select.options = natsorted(channels) + self.green_select.options = natsorted(channels) + self.blue_select.options = natsorted(channels) + + def create_composite_image(self, path_dict): + """Creates composite image from input paths. + Args: + path_dict (dict): Dictionary of paths to images. + Returns: + composite_image (np.array): Composite image. + """ + for k in ["red", "green", "blue"]: + if k not in path_dict.keys(): + path_dict[k] = None + output_image = [] + for k, p in path_dict.items(): + if p: + img = io.imread(p) + output_image.append(img) + else: + non_none = [p for p in path_dict.values() if p] + img = io.imread(non_none[0]) + output_image.append(img*0) + + composite_image = np.stack(output_image, axis=-1) + return composite_image + + def layout(self): + """Creates layout for viewer.""" + channel_selectors = widgets.VBox([ + self.red_select, + self.green_select, + self.blue_select + ]) + self.input_image.layout.width = self.image_width + self.output_image.layout.width = self.image_width + viewer_html = widgets.HTML("

Select files

") + input_html = widgets.HTML("

Input

") + output_html = widgets.HTML("

Nimbus Output

") + + layout = widgets.HBox([ + widgets.VBox([ + viewer_html, + self.fov_select, + channel_selectors, + self.update_button + ]), + widgets.VBox([ + input_html, + self.input_image + ]), + widgets.VBox([ + output_html, + self.output_image + ]) + ]) + display(layout) + + def search_for_similar(self, select_value): + """Searches for similar filename in input directory. + Args: + select_value (str): Filename to search for. + Returns: + similar_path (str): Path to similar filename. + """ + in_f_path = os.path.join(self.input_dir, self.fov_select.value) + # search for similar filename in in_f_path + in_f_files = [ + f for f in os.listdir(in_f_path) if os.path.isfile(os.path.join(in_f_path, f)) + ] + similar_path = None + for f in in_f_files: + if select_value.split(".")[0]+"." in f: + similar_path = os.path.join(self.input_dir, self.fov_select.value, f) + return similar_path + + def update_img(self, image_viewer, composite_image): + """Updates image in viewer by saving it as png and loading it with the viewer widget. + Args: + image_viewer (ipywidgets.Image): Image widget to update. + composite_image (np.array): Composite image to display. + """ + # Convert composite image to bytes and assign it to the output_image widget + with BytesIO() as output_buffer: + io.imsave(output_buffer, composite_image, format="png") + output_buffer.seek(0) + image_viewer.value = output_buffer.read() + + def update_composite(self): + """Updates composite image in viewer.""" + path_dict = { + "red": None, + "green": None, + "blue": None + } + in_path_dict = copy(path_dict) + if self.red_select.value: + path_dict["red"] = os.path.join( + self.output_dir, self.fov_select.value, self.red_select.value + ) + in_path_dict["red"] = self.search_for_similar(self.red_select.value) + if self.green_select.value: + path_dict["green"] = os.path.join( + self.output_dir, self.fov_select.value, self.green_select.value + ) + in_path_dict["green"] = self.search_for_similar(self.green_select.value) + if self.blue_select.value: + path_dict["blue"] = os.path.join( + self.output_dir, self.fov_select.value, self.blue_select.value + ) + in_path_dict["blue"] = self.search_for_similar(self.blue_select.value) + non_none = [p for p in path_dict.values() if p] + if not non_none: + return + composite_image = self.create_composite_image(path_dict) + in_composite_image = self.create_composite_image(in_path_dict) + in_composite_image = in_composite_image / np.quantile( + in_composite_image, 0.999, axis=(0,1) + ) + in_composite_image = np.clip(in_composite_image*255, 0, 255).astype(np.uint8) + # update image viewers + self.update_img(self.input_image, in_composite_image) + self.update_img(self.output_image, composite_image) + + def update_button_click(self, button): + """Updates composite image in viewer when update button is clicked.""" + self.update_composite() + + def display(self): + """Displays viewer.""" + self.select_fov(None) + self.layout() + self.update_composite() \ No newline at end of file diff --git a/tests/test_viewer_widget.py b/tests/test_viewer_widget.py new file mode 100644 index 0000000..163271c --- /dev/null +++ b/tests/test_viewer_widget.py @@ -0,0 +1,33 @@ +from nimbus_inference.viewer_widget import NimbusViewer +from tests.test_utils import prepare_ome_tif_data, prepare_tif_data +import numpy as np +import tempfile +import os + + +def test_NimbusViewer(): + with tempfile.TemporaryDirectory() as temp_dir: + _ = prepare_tif_data( + num_samples=2, temp_dir=temp_dir, selected_markers=["CD4", "CD11c", "CD56"] + ) + viewer_widget = 
diff --git a/tests/test_viewer_widget.py b/tests/test_viewer_widget.py
new file mode 100644
index 0000000..163271c
--- /dev/null
+++ b/tests/test_viewer_widget.py
@@ -0,0 +1,33 @@
+from nimbus_inference.viewer_widget import NimbusViewer
+from tests.test_utils import prepare_ome_tif_data, prepare_tif_data
+import numpy as np
+import tempfile
+import os
+
+
+def test_NimbusViewer():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        _ = prepare_tif_data(
+            num_samples=2, temp_dir=temp_dir, selected_markers=["CD4", "CD11c", "CD56"]
+        )
+        viewer_widget = NimbusViewer(temp_dir, temp_dir)
+        assert isinstance(viewer_widget, NimbusViewer)
+
+
+def test_composite_image():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        _ = prepare_tif_data(
+            num_samples=2, temp_dir=temp_dir, selected_markers=["CD4", "CD11c", "CD56"]
+        )
+        viewer_widget = NimbusViewer(temp_dir, temp_dir)
+        path_dict = {
+            "red": os.path.join(temp_dir, "fov_0", "CD4.tiff"),
+            "green": os.path.join(temp_dir, "fov_0", "CD11c.tiff"),
+        }
+        composite_image = viewer_widget.create_composite_image(path_dict)
+        assert isinstance(composite_image, np.ndarray)
+        assert composite_image.shape == (256, 256, 3)
+
+        path_dict["blue"] = os.path.join(temp_dir, "fov_0", "CD56.tiff")
+        composite_image = viewer_widget.create_composite_image(path_dict)
+        assert composite_image.shape == (256, 256, 3)
\ No newline at end of file

From 9727988f17b228ce139aa6265036b63e9d1c1a17 Mon Sep 17 00:00:00 2001
From: Lenz
Date: Mon, 26 Feb 2024 13:10:59 +0100
Subject: [PATCH 29/31] Added ipywidgets, natsort and ipython to dependencies

---
 pyproject.toml | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/pyproject.toml b/pyproject.toml
index 014c425..9642cc8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,6 +34,9 @@ dependencies = [
     "huggingface_hub",
     # for debug logging (referenced from the issue template)
     "session-info",
+    "ipywidgets",
+    "natsort",
+    "ipython",
 ]
 

From dabeb1d5819dc9dabfde4f9cc6022838606c9216 Mon Sep 17 00:00:00 2001
From: Lenz
Date: Tue, 27 Feb 2024 13:26:41 +0100
Subject: [PATCH 30/31] Changed whole-image vs. tile&stitch inference control flow

---
 src/nimbus_inference/nimbus.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/nimbus_inference/nimbus.py b/src/nimbus_inference/nimbus.py
index 90efcc6..56616b2 100644
--- a/src/nimbus_inference/nimbus.py
+++ b/src/nimbus_inference/nimbus.py
@@ -233,7 +233,7 @@ def predict_segmentation(self, input_data, preprocess_kwargs):
             np.array: Predicted segmentation.
         """
         input_data = nimbus_preprocess(input_data, **preprocess_kwargs)
-        if np.all(np.greater(self.input_shape, input_data.shape[-2:])):
+        if np.all(np.greater_equal(self.input_shape, input_data.shape[-2:])):
             if not hasattr(self, "model") or self.model.padding != "reflect":
                 self.initialize_model(padding="reflect")
             with torch.no_grad():

From f34237dae0327dcf4af50633d4663ee1bc52c6b3 Mon Sep 17 00:00:00 2001
From: Lenz
Date: Fri, 1 Mar 2024 10:15:35 +0100
Subject: [PATCH 31/31] Normalization only on non-zero pixels

---
 src/nimbus_inference/utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/nimbus_inference/utils.py b/src/nimbus_inference/utils.py
index 406b356..7d74fea 100644
--- a/src/nimbus_inference/utils.py
+++ b/src/nimbus_inference/utils.py
@@ -24,7 +24,8 @@ def calculate_normalization(channel_path, quantile):
     """
     mplex_img = io.imread(channel_path)
     mplex_img = mplex_img.astype(np.float32)
-    normalization_value = np.quantile(mplex_img, quantile)
+    foreground = mplex_img[mplex_img > 0]
+    normalization_value = np.quantile(foreground, quantile)
     chan = os.path.basename(channel_path).split(".")[0]
     return chan, normalization_value
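The motivation for the last patch: on a sparse marker channel, almost all pixels are zero, so a quantile taken over the whole image can collapse to zero and the later division by the normalization factor would blow up. A small illustrative example with made-up values:

    import numpy as np

    # A sparse marker channel: only 32 of 65,536 pixels carry signal.
    img = np.zeros((256, 256), dtype=np.float32)
    img[0, :32] = np.linspace(1.0, 10.0, 32)

    q_all = np.quantile(img, 0.999)           # 0.0 -- the zeros swamp the quantile
    q_fg  = np.quantile(img[img > 0], 0.999)  # ~10.0 -- tracks the real signal range

    print(q_all, q_fg)

Restricting the quantile to foreground pixels yields a normalization value that reflects the actual dynamic range of the marker rather than the background.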