From 910295021fbd009f108904d34a798d1a9c5fa2e4 Mon Sep 17 00:00:00 2001
From: Tim Kempchen <80586406+TKempchen@users.noreply.github.com>
Date: Thu, 27 Jun 2024 10:18:03 +0200
Subject: [PATCH] adding updates as discussed yesterday

---
 notebooks/2_preprocessing.ipynb        |  10 -
 src/spacec/_shared/segmentation.py     |  68 +++++-
 src/spacec/helperfunctions/_general.py |  13 +
 src/spacec/plotting/_general.py        |  54 ++++-
 src/spacec/plotting/_segmentation.py   |  17 +-
 src/spacec/tools/_general.py           | 172 ++++++++++++-
 src/spacec/tools/_segmentation.py      | 318 ++++++++++++++++++++-----
 7 files changed, 568 insertions(+), 84 deletions(-)

diff --git a/notebooks/2_preprocessing.ipynb b/notebooks/2_preprocessing.ipynb
index d7e0fba..7ab4306 100755
--- a/notebooks/2_preprocessing.ipynb
+++ b/notebooks/2_preprocessing.ipynb
@@ -14,16 +14,6 @@
     "## Set up environment"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "%load_ext autoreload\n",
-    "%autoreload 2"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": 2,
diff --git a/src/spacec/_shared/segmentation.py b/src/spacec/_shared/segmentation.py
index 9d28ed8..4eb1db7 100644
--- a/src/spacec/_shared/segmentation.py
+++ b/src/spacec/_shared/segmentation.py
@@ -5,6 +5,23 @@
 
 
 def create_multichannel_tiff(input_dir, output_dir, output_filename):
+    """
+    Create a multi-channel TIFF image from individual TIFF files.
+
+    Parameters
+    ----------
+    input_dir : str
+        Directory containing the input TIFF files.
+    output_dir : str
+        Directory to save the output TIFF file.
+    output_filename : str
+        Name of the output TIFF file.
+
+    Returns
+    -------
+    list of str
+        List of channel names.
+    """
     # Get a list of all TIFF files in the input directory
     tiff_files = [f for f in os.listdir(input_dir) if f.endswith((".tiff", ".tif"))]
 
@@ -35,18 +52,36 @@ def create_multichannel_tiff(input_dir, output_dir, output_filename):
 
 # combine multiple channels in one image and add as new image to image_dict with the name segmentation_channel
 def combine_channels(image_dict, channel_list, new_channel_name):
+    """
+    Combine multiple channels into a single channel.
+
+    Parameters
+    ----------
+    image_dict : dict
+        Dictionary with channel names as keys and images as values.
+    channel_list : list of str
+        List of channel names to combine.
+    new_channel_name : str
+        Name of the new channel.
+
+    Returns
+    -------
+    dict
+        Updated dictionary with the new channel added.
+    """
+    # Determine bit depth of input images
+    bit_depth = image_dict[channel_list[0]].dtype
+
     # Create empty image
     new_image = np.zeros(
-        (image_dict[channel_list[0]].shape[0], image_dict[channel_list[0]].shape[1])
+        (image_dict[channel_list[0]].shape[0], image_dict[channel_list[0]].shape[1]),
+        dtype=bit_depth
     )
 
     # Add channels to image as maximum projection
     for channel in channel_list:
         new_image = np.maximum(new_image, image_dict[channel])
 
-    # generate greyscale image
-    new_image = np.uint8(new_image)
-
     # Add image to image_dict
     image_dict[new_channel_name] = new_image
 
@@ -62,6 +97,31 @@ def format_CODEX(
     stack=True,
     input_format="Multichannel",
 ):
+    """
+    Format images based on the input format.
+
+    Parameters
+    ----------
+    image : ndarray or str
+        Input image or directory containing images.
+    channel_names : list of str, optional
+        List of channel names.
+    number_cycles : int, optional
+        Number of cycles in the CODEX format.
+    images_per_cycle : int, optional
+        Number of images per cycle in the CODEX format.
+    stack : bool, default=True
+        If True, stack the images in the list.
+    input_format : str, default="Multichannel"
+        Format of the input images. Options are "CODEX", "Multichannel", and "Channels".
+
+    Returns
+    -------
+    dict
+        Dictionary with channel names as keys and images as values. If `stack` is True and `input_format` is "CODEX",
+        also returns a stacked image as a numpy array.
+    """
+
     if input_format == "CODEX":
         total_images = number_cycles * images_per_cycle
         image_list = [None] * total_images  # pre-allocated list
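The combine_channels change above keeps the bit depth of the input channels instead of down-casting the maximum projection to uint8. Below is a minimal, self-contained sketch of that behaviour, assuming made-up channel names and synthetic 16-bit images; it is an illustration, not part of the patch.

import numpy as np

# Hypothetical 16-bit stains standing in for real channels.
image_dict = {
    "CD45": np.random.randint(0, 65535, size=(512, 512), dtype=np.uint16),
    "CD3": np.random.randint(0, 65535, size=(512, 512), dtype=np.uint16),
}

# Maximum projection that keeps the input dtype, mirroring the updated combine_channels.
bit_depth = image_dict["CD45"].dtype
combined = np.zeros(image_dict["CD45"].shape, dtype=bit_depth)
for channel in ("CD45", "CD3"):
    combined = np.maximum(combined, image_dict[channel])

image_dict["segmentation_channel"] = combined
print(combined.dtype)  # uint16, so intensities above 255 are no longer clipped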
diff --git a/src/spacec/helperfunctions/_general.py b/src/spacec/helperfunctions/_general.py
index 3c0321f..de7d9a4 100644
--- a/src/spacec/helperfunctions/_general.py
+++ b/src/spacec/helperfunctions/_general.py
@@ -1053,6 +1053,18 @@ def is_dark(color):
 
 
 def check_for_gpu():
+    """
+    Check if a GPU is available for use by TensorFlow and PyTorch.
+
+    This function queries both TensorFlow and PyTorch for a visible GPU.
+    It prints a message indicating whether a GPU is available for each library,
+    and returns a boolean indicating whether a GPU is available to PyTorch.
+
+    Returns
+    -------
+    bool
+        True if a GPU is available for PyTorch, False otherwise.
+    """
     if tf.config.list_physical_devices("GPU"):
         print("GPU is available to Tensorflow")
     else:
@@ -1061,3 +1073,4 @@ def check_for_gpu():
     use_GPU = use_gpu()
     yn = ["GPU is not available to Pytorch", "GPU is available to Pytorch"]
     print(f"{yn[use_GPU]}")
+    return use_GPU
diff --git a/src/spacec/plotting/_general.py b/src/spacec/plotting/_general.py
index a1c1132..7030504 100644
--- a/src/spacec/plotting/_general.py
+++ b/src/spacec/plotting/_general.py
@@ -3816,6 +3816,34 @@ def cn_map(
     output_dir="./",
     rand_seed=1,
 ):
+    """
+    Generates a CNMap plot using the provided data and parameters.
+
+    Parameters
+    ----------
+    adata : anndata.AnnData
+        Annotated data matrix.
+    cnmap_dict : dict
+        Dictionary containing graph, tops, e0, e1, and simp_freqs.
+    cn_col : str
+        Column name in adata to be used for color coding.
+    palette : dict, optional
+        Color palette to use for the plot. If None, a random color palette is generated.
+    figsize : tuple, optional
+        Size of the figure. Defaults to (40, 20).
+    savefig : bool, optional
+        Whether to save the figure or not. Defaults to False.
+    output_fname : str, optional
+        The filename for the saved figure. Required if savefig is True. Defaults to "".
+    output_dir : str, optional
+        The directory where the figure will be saved. Defaults to "./".
+    rand_seed : int, optional
+        Seed for random number generator. Defaults to 1.
+
+    Returns
+    -------
+    None
+    """
     graph = cnmap_dict["g"]
     tops = cnmap_dict["tops"]
     e0 = cnmap_dict["e0"]
@@ -3854,6 +3882,28 @@ def cn_map(
             c=col,
             zorder=-1,
         )
+        # Dummy scatter plots for legend
+        freqs = simp_freqs * 10000
+        max_size = max(freqs)
+        sizes = [round(max_size)/4, round(max_size)/2, round(max_size)]  # Replace with the sizes you want in the legend
+        labels = [str(round(max_size/100)/4) + '%', str(round(max_size/100)/2) + '%', str(round(max_size/100)) + '%']  # Replace with the labels you want in the legend
+
+        # Add legend
+        legend_elements = [plt.Line2D([0], [0], marker='o', color='w', label=label,
+                                      markerfacecolor='black', markersize=size**0.5)
+                           for size, label in zip(sizes, labels)]
+
+        # Add first legend
+        legend1 = plt.legend(handles=legend_elements,
+                             loc='lower right',
+                             title='Total frequency',
+                             title_fontsize=30,
+                             fontsize=30,
+                             handlelength=6,
+                             handletextpad=1,
+                             bbox_to_anchor=(0.0, -0.15, 1.0, 0.102))
+
+
         if n in tops:
             plt.text(
                 pos[n][0],
@@ -3903,10 +3953,11 @@ def cn_map(
     ]
 
     # Add legend to bottom of plot
+    plt.gca().add_artist(legend1)
     plt.legend(
         handles=legend_patches,
         bbox_to_anchor=(0.0, -0.15, 1.0, 0.102),
-        loc="lower center",
+        loc="lower left",
         ncol=3,
         borderaxespad=0.0,
         fontsize=35,
@@ -3920,6 +3971,7 @@ def cn_map(
 
     plt.show()
 
+
 def coordinates_on_image(
     df,
     overlay_data,
diff --git a/src/spacec/plotting/_segmentation.py b/src/spacec/plotting/_segmentation.py
index 5c96a39..479d3d4 100644
--- a/src/spacec/plotting/_segmentation.py
+++ b/src/spacec/plotting/_segmentation.py
@@ -13,7 +13,9 @@ def segmentation_ch(
     file_name,  # image for segmentation
     channel_file,  # all channels used for staining
-    output_dir,  #
+    output_dir,
+    savefig=False,  # new
+    output_fname="",  # new
     extra_seg_ch_list=None,  # channels used for membrane segmentation
     nuclei_channel="DAPI",
     input_format="Multichannel",  # CODEX or Phenocycler --> This depends on the machine you are using and the resulting file format (see documentation above)
@@ -68,7 +70,18 @@ def segmentation_ch(
     ax[1].imshow(image_dict["segmentation_channel"])
     ax[0].set_title("nuclei")
     ax[1].set_title("membrane")
-    plt.show()
+
+    # save or plot figure
+    if savefig:
+        plt.savefig(
+            output_dir + output_fname + ".pdf",
+            format="pdf",
+            dpi=300,
+            transparent=True,
+            bbox_inches="tight",
+        )
+    else:
+        plt.show()
 
 
 def show_masks(
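The cn_map additions above build a marker-size legend out of dummy Line2D handles whose markersize is the square root of the scatter area. The snippet below reproduces just that matplotlib pattern in isolation, with synthetic frequencies; it is an illustrative sketch, not part of spacec.

import matplotlib.pyplot as plt
import numpy as np

# Synthetic node frequencies standing in for simp_freqs * 10000 in cn_map.
freqs = np.array([0.002, 0.01, 0.03, 0.05]) * 10000
x, y = np.arange(len(freqs)), np.ones(len(freqs))

fig, ax = plt.subplots(figsize=(6, 3))
ax.scatter(x, y, s=freqs, c="black")  # scatter sizes are areas (points**2)

# Dummy handles: invisible lines whose markers encode three reference sizes.
max_size = freqs.max()
sizes = [max_size / 4, max_size / 2, max_size]
labels = [f"{s / 100:.2f}%" for s in sizes]
handles = [
    plt.Line2D([0], [0], marker="o", color="w", label=label,
               markerfacecolor="black", markersize=size ** 0.5)  # markersize is a diameter in points
    for size, label in zip(sizes, labels)
]
ax.legend(handles=handles, title="Total frequency", loc="lower right")
plt.show()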
diff --git a/src/spacec/tools/_general.py b/src/spacec/tools/_general.py
index 9e42ba3..c5b52ca 100644
--- a/src/spacec/tools/_general.py
+++ b/src/spacec/tools/_general.py
@@ -2403,6 +2403,31 @@ def plot_selected_neighbors_with_shapes(
     plot=True,
     identification_column="community",
 ):
+    """
+    Plot points and identify points within a specified radius of selected points,
+    highlighting points from different clusters with distinct shapes.
+
+    Parameters
+    ----------
+    full_df : pandas.DataFrame
+        The full dataset containing all points.
+    selected_df : pandas.DataFrame
+        The dataset containing selected points to be analyzed.
+    target_df : pandas.DataFrame
+        The dataset from which neighbors are to be identified.
+    radius : float
+        The radius within which to identify neighboring points.
+    plot : bool, optional
+        Whether to plot the results. The default is True.
+    identification_column : str, optional
+        The column name used to identify different clusters. The default is "community".
+
+    Returns
+    -------
+    pandas.DataFrame
+        A DataFrame containing points within the specified radius from the selected points,
+        excluding points from the same cluster as the selected point.
+    """
     # Get unique clusters from the full DataFrame
     unique_clusters = full_df[identification_column].unique()
 
@@ -2491,7 +2516,6 @@ def plot_selected_neighbors_with_shapes(
 
     return all_in_circle_diff_cluster
 
-
 def process_cluster(args):
     (
         df,
@@ -2515,7 +2539,45 @@ def process_cluster(args):
         points[:, :2],
         length_threshold=concave_hull_length_threshold,
     )
+    """
+    Processes a single cluster by computing its concave hull, identifying nearest neighbors of the hull points,
+    and plotting selected neighbors within a specified radius, highlighting points from different clusters.
+
+    Parameters
+    ----------
+    args : tuple
+        A tuple containing the following parameters:
+        df : pandas.DataFrame
+            The DataFrame containing the data points.
+        cluster : int or str
+            The cluster identifier to process.
+        cluster_column : str
+            The column name in `df` that contains cluster identifiers.
+        x_column : str
+            The column name in `df` that contains the x-coordinates.
+        y_column : str
+            The column name in `df` that contains the y-coordinates.
+        concave_hull_length_threshold : float
+            The length threshold for the concave hull computation.
+        edge_neighbours : int
+            The number of neighbors to consider for each edge point of the hull.
+        full_df : pandas.DataFrame
+            The full DataFrame containing all points, used for neighbor identification.
+        radius : float
+            The radius within which to identify neighboring points.
+        plot : bool
+            Whether to plot the results.
+        identification_column : str
+            The column name used to identify different clusters.
+
+    Returns
+    -------
+    tuple
+        A tuple containing two pandas.DataFrames:
+        - The first DataFrame contains points within the specified radius from the hull points,
+          excluding points from the same cluster.
+        - The second DataFrame contains the nearest neighbors of the hull points.
+    """
     # Get hull points from the DataFrame
     hull_points = pd.DataFrame(points[idxes], columns=["x", "y"])
 
@@ -2540,7 +2602,6 @@ def process_cluster(args):
 
     return prox_points, hull_nearest_neighbors
 
-
 def identify_points_in_proximity(
     df,
     full_df,
@@ -2553,6 +2614,41 @@ def identify_points_in_proximity(
     plot=True,
     concave_hull_length_threshold=50,
 ):
+    """
+    Identify points within a specified proximity and generate outlines based on clusters.
+
+    Parameters
+    ----------
+    df : DataFrame
+        The DataFrame containing the data points to be analyzed.
+    full_df : DataFrame
+        The full DataFrame from which `df` is derived, used for additional context or data not present in `df`.
+    identification_column : str
+        The name of the column in `df` used to uniquely identify each point.
+    cluster_column : str, optional
+        The name of the column in `df` that contains cluster identifiers. Default is "cluster".
+    x_column : str, optional
+        The name of the column in `df` that contains the x coordinates of each point. Default is "x".
+    y_column : str, optional
+        The name of the column in `df` that contains the y coordinates of each point. Default is "y".
+    radius : int, optional
+        The radius within which to identify points in proximity. Default is 200.
+    edge_neighbours : int, optional
+        The minimum number of neighbors a point must have to be considered an edge point. Default is 3.
+    plot : bool, optional
+        Whether to plot the results. Default is True.
+    concave_hull_length_threshold : int, optional
+        The length threshold used when generating the concave hull. Edges longer than this threshold are not included in the hull. Default is 50.
+
+    Returns
+    -------
+    tuple of DataFrame
+        A tuple containing two DataFrames.
+        The first DataFrame contains the result of the proximity analysis, including the identified points
+        and their respective cluster identifiers. The second DataFrame contains the outlines of the clusters
+        identified in the analysis.
+
+    Notes
+    -----
+    This function utilizes multiprocessing to parallelize the analysis of clusters, aiming to improve performance on systems with multiple CPU cores.
+    """
     num_processes = max(
         1, os.cpu_count() - 2
     )  # Use all available CPUs minus 2, but at least 1
@@ -3253,7 +3349,49 @@ def tm_viewer(
     include_masks=True,
     open_viewer=True,
     add_UMAP=True,
-):
+    ):
+    """
+    Prepares and visualizes tissue microscopy data using TissUUmaps.
+
+    Parameters
+    ----------
+    adata : AnnData
+        Annotated data matrix from scanpy, with observations (cells) as rows and variables (genes) as columns.
+    images_pickle_path : str
+        Path to the pickle file containing segmented images and masks.
+    directory : str
+        Directory path where the images and CSV files will be cached.
+    region_column : str, optional
+        Column name in `adata.obs` that specifies the region, by default "unique_region".
+    region : str, optional
+        Specific region to process; the default "" processes all regions.
+    xSelector : str, optional
+        Column name to use as the x-coordinate for visualization, by default "x".
+    ySelector : str, optional
+        Column name to use as the y-coordinate for visualization, by default "y".
+    color_by : str, optional
+        Column name to color cells by in the visualization, by default "celltype_fine".
+    keep_list : list of str, optional
+        List of column names to keep from `adata.obs`, by default None keeps [region_column, xSelector, ySelector, color_by].
+    include_masks : bool, optional
+        Whether to include cell masks in the visualization, by default True.
+    open_viewer : bool, optional
+        Whether to automatically open the TissUUmaps viewer, by default True.
+    add_UMAP : bool, optional
+        Whether to add UMAP coordinates to the output, by default True.
+
+    Returns
+    -------
+    tuple
+        A tuple containing two elements:
+        - A list of paths to the saved image files.
+        - A list of paths to the saved CSV files containing the data for visualization.
+
+    Notes
+    -----
+    This function requires the `scanpy` library for handling `adata`, and the `TissUUmaps` library for visualization.
+    It assumes that the `images_pickle_path` file contains a dictionary with keys "image_dict" and "masks" pointing to the respective data structures.
+    """
     segmented_matrix = adata.obs
 
     with open(images_pickle_path, "rb") as f:
@@ -3267,6 +3405,10 @@ def tm_viewer(
 
     print("Preparing TissUUmaps input...")
 
+    if directory == "":
+        print("No directory specified. Using current working directory.")
+        directory = os.getcwd()
+
     cache_dir = pathlib.Path(directory) / region
     cache_dir.mkdir(parents=True, exist_ok=True)
 
@@ -3511,6 +3653,30 @@ def anndata_to_CPU(
 
 
 def install_stellar(CUDA=12):
+    """
+    Installs PyTorch and PyTorch Geometric along with their dependencies for a specific CUDA version.
+
+    This function automates the installation of PyTorch, PyTorch Geometric, and related libraries
+    (torch_scatter, torch_sparse, torch_cluster, torch_spline_conv) optimized for the specified CUDA version.
+    It uses pip for installation, ensuring compatibility with the CUDA version provided.
+
+    Parameters
+    ----------
+    CUDA : float, optional
+        The CUDA version for which the libraries should be installed. Supported values are 12 and 11.8.
+        Defaults to 12.
+
+    Raises
+    ------
+    subprocess.CalledProcessError
+        Raised if any of the pip installation commands fail.
+
+    Notes
+    -----
+    - The function checks the CUDA version and selects the appropriate PyTorch and PyTorch Geometric wheels for installation.
+    - It is assumed that pip and pip3 are available in the system's PATH.
+    - The function prints a message if an unsupported CUDA version is provided, directing the user to the official installation guides.
+    """
     if CUDA == 12:
         subprocess.run(["pip3", "install", "torch"], check=True)
         subprocess.run(["pip", "install", "torch_geometric"], check=True)
diff --git a/src/spacec/tools/_segmentation.py b/src/spacec/tools/_segmentation.py
index c1b8cf3..4a39b68 100644
--- a/src/spacec/tools/_segmentation.py
+++ b/src/spacec/tools/_segmentation.py
@@ -14,6 +14,7 @@
 from deepcell.utils.plot_utils import create_rgb_image, make_outline_overlay
 from skimage.measure import regionprops_table
 from tensorflow.keras.models import load_model
+import tensorflow as tf
 from tqdm import tqdm
 
 from .._shared.segmentation import (
@@ -43,6 +44,7 @@ def cell_segmentation(
     model_path="./models",
     resize_factor=1,
     custom_model=False,
+    differentiate_nucleus_cytoplasm=False,  # experimental feature!
 ):
     """
     Perform cell segmentation on an image.
@@ -82,11 +84,21 @@ def cell_segmentation(
         Whether to save the segmentation mask as a PNG file. Default is False.
     model_path : str, optional
         The path to the model. Default is './models'.
+    differentiate_nucleus_cytoplasm : bool, optional
+        Whether to differentiate between nucleus and cytoplasm. Default is False.
 
     Returns
     -------
     dict
         A dictionary containing the original image ('img'), the segmentation masks ('masks'), and the image dictionary ('image_dict').
     """
+    if use_gpu == True:
+        gpus = tf.config.list_physical_devices('GPU')
+
+        if gpus:
+            print("GPU(s) available")
+            for gpu in gpus:
+                tf.config.experimental.set_memory_growth(gpu, True)
+
     print("Create image channels!")
 
     # check input format
@@ -139,72 +151,242 @@ def cell_segmentation(
             )
             # Replace the original image with the resized image in the dictionary
             segmentation_image_dict[channel] = resized_img
-    if seg_method == "mesmer":
-        print("Segmenting with Mesmer!")
-        if membrane_channel_list is None:
-            print(
-                "Mesmer expects two-channel images as input, where the first channel must be a nuclear channel (e.g. DAPI) and the second channel must be a membrane or cytoplasmic channel (e.g. E-Cadherin)."
-            )
-            sys.exit("Please provide any membrane or cytoplasm channel!")
+
+    if differentiate_nucleus_cytoplasm == True:
+        if membrane_channel_list == None:
+            print("Provide membrane channel for differentiation between nucleus and cytoplasm")
+            return
         else:
-            masks = mesmer_segmentation(
-                nuclei_image=segmentation_image_dict[nuclei_channel],
-                membrane_image=segmentation_image_dict["segmentation_channel"],
-                plot_predictions=plot_predictions,  # plot segmentation results
-                compartment=compartment,
-                model_path=model_path,
-            )  # segment whole cells or nuclei only
-    else:
-        print("Segmenting with Cellpose!")
-        if membrane_channel_list is None:
-            masks, flows, styles, input_image, rgb_channels = cellpose_segmentation(
-                image_dict=segmentation_image_dict,
-                output_dir=output_dir,
-                membrane_channel=None,
-                cytoplasm_channel=cytoplasm_channel_list,
-                nucleus_channel=nuclei_channel,
-                use_gpu=use_gpu,
-                model=model,
-                custom_model=custom_model,
-                diameter=diameter,
-                save_mask_as_png=save_mask_as_png,
+            if seg_method == "mesmer":
+                print("Segmenting with Mesmer!")
+
+                masks_nuclei = mesmer_segmentation(
+                    nuclei_image=segmentation_image_dict[nuclei_channel],
+                    membrane_image=None,
+                    plot_predictions=plot_predictions,  # plot segmentation results
+                    compartment="nuclear",
+                    model_path=model_path,
+                )  # segment whole cells or nuclei only
+
+                masks_whole_cell = mesmer_segmentation(
+                    nuclei_image=segmentation_image_dict[nuclei_channel],
+                    membrane_image=segmentation_image_dict["segmentation_channel"],
+                    plot_predictions=plot_predictions,  # plot segmentation results
+                    compartment=compartment,
+                    model_path=model_path,
+                )  # segment whole cells or nuclei only
+            else:
+                print("Segmenting with Cellpose!")
+
+                masks_nuclei, flows, styles, input_image, rgb_channels = cellpose_segmentation(
+                    image_dict=segmentation_image_dict,
+                    output_dir=output_dir,
+                    membrane_channel=None,
+                    cytoplasm_channel=cytoplasm_channel_list,
+                    nucleus_channel=nuclei_channel,
+                    use_gpu=use_gpu,
+                    model=model,
+                    custom_model=custom_model,
+                    diameter=diameter,
+                    save_mask_as_png=save_mask_as_png,
+                )
+
+                masks_whole_cell, flows, styles, input_image, rgb_channels = cellpose_segmentation(
+                    image_dict=segmentation_image_dict,
+                    output_dir=output_dir,
+                    membrane_channel="segmentation_channel",
+                    cytoplasm_channel=cytoplasm_channel_list,
+                    nucleus_channel=nuclei_channel,
+                    use_gpu=use_gpu,
+                    model=model,
+                    custom_model=custom_model,
+                    diameter=diameter,
+                    save_mask_as_png=save_mask_as_png,
+                )
+
+
+            # Remove single-dimensional entries from the shape of segmentation_masks
+            masks_whole_cell = masks_whole_cell.squeeze()
+            # Get the original dimensions of any one of the images
+            original_height, original_width = image_dict[
+                nuclei_channel
+            ].shape  # or any other channel
+            # Resize the masks back to the original size
+            masks_whole_cell = cv2.resize(
+                masks_whole_cell, (original_width, original_height), interpolation=cv2.INTER_NEAREST
             )
-        else:
-            masks, flows, styles, input_image, rgb_channels = cellpose_segmentation(
-                image_dict=segmentation_image_dict,
-                output_dir=output_dir,
-                membrane_channel="segmentation_channel",
-                cytoplasm_channel=cytoplasm_channel_list,
-                nucleus_channel=nuclei_channel,
-                use_gpu=use_gpu,
-                model=model,
-                custom_model=custom_model,
-                diameter=diameter,
-                save_mask_as_png=save_mask_as_png,
+
+            # Remove single-dimensional entries from the shape of segmentation_masks
+            masks_nuclei = masks_nuclei.squeeze()
+            # Get the original dimensions of any one of the images
+            original_height, original_width = image_dict[
+                nuclei_channel
+            ].shape  # or any other channel
+            # Resize the masks back to the original size
+            masks_nuclei = cv2.resize(
+                masks_nuclei, (original_width, original_height), interpolation=cv2.INTER_NEAREST
             )
-    # Remove single-dimensional entries from the shape of segmentation_masks
-    masks = masks.squeeze()
-    # Get the original dimensions of any one of the images
-    original_height, original_width = image_dict[
-        nuclei_channel
-    ].shape  # or any other channel
-    # Resize the masks back to the original size
-    masks = cv2.resize(
-        masks, (original_width, original_height), interpolation=cv2.INTER_NEAREST
-    )
-    print("Quantifying features after segmentation!")
-    extract_features(
-        image_dict=image_dict,  # image dictionary
-        segmentation_masks=masks,  # segmentation masks generated by cellpose
-        channels_to_quantify=channel_names,  # list of channels to quantify (here: all channels)
-        output_file=pathlib.Path(output_dir)
-        / (
-            output_fname + "_" + seg_method + "_result.csv"
-        ),  # output path to store results as csv
-        size_cutoff=size_cutoff,
-    )  # size cutoff for segmentation masks (default = 0)
-    print("Done!")
-    return {"img": img, "masks": masks, "image_dict": image_dict}
+
+            # Create binary masks
+            binary_masks_nuclei = masks_nuclei > 0
+            binary_masks_whole_cell = masks_whole_cell > 0
+
+            # Subtract the binary nuclei mask from the binary whole cell mask
+            binary_masks_cytoplasm = binary_masks_whole_cell & ~binary_masks_nuclei
+
+            # Relabel the cytoplasm mask using `label` from `scipy.ndimage`
+            from scipy.ndimage import label
+            masks_cytoplasm, num_labels = label(binary_masks_cytoplasm)
+
+            print("Quantifying features after segmentation!")
+            print("Quantifying features nuclei")
+            nuc = extract_features(
+                image_dict=image_dict,  # image dictionary
+                segmentation_masks=masks_nuclei,  # segmentation masks generated by cellpose
+                channels_to_quantify=channel_names,  # list of channels to quantify (here: all channels)
+                output_file=pathlib.Path(output_dir)
+                / (
+                    output_fname + "_" + seg_method + "_nuclei_result.csv"
+                ),  # output path to store results as csv
+                size_cutoff=size_cutoff,
+            )  # size cutoff for segmentation masks (default = 0)
+            print("Quantifying features cytoplasm")
+            cyto = extract_features(
+                image_dict=image_dict,  # image dictionary
+                segmentation_masks=masks_cytoplasm,  # segmentation masks generated by cellpose
+                channels_to_quantify=channel_names,  # list of channels to quantify (here: all channels)
+                output_file=pathlib.Path(output_dir)
+                / (
+                    output_fname + "_" + seg_method + "_cytoplasm_result.csv"
+                ),  # output path to store results as csv
+                size_cutoff=size_cutoff,
+            )  # size cutoff for segmentation masks (default = 0)
+
+            print("Quantifying features whole cell")
+            whole = extract_features(
+                image_dict=image_dict,  # image dictionary
+                segmentation_masks=masks_whole_cell,  # segmentation masks generated by cellpose
+                channels_to_quantify=channel_names,  # list of channels to quantify (here: all channels)
+                output_file=pathlib.Path(output_dir)
+                / (
+                    output_fname + "_" + seg_method + "_whole_cell_result.csv"
+                ),  # output path to store results as csv
+                size_cutoff=size_cutoff,
+            )  # size cutoff for segmentation masks (default = 0)
+            print("Done!")
+
+
+            # morphology and position columns shared by all three tables
+            out = ["x",
+                "y",
+                "eccentricity",
+                "perimeter",
+                "convex_area",
+                "area",
+                "axis_major_length",
+                "axis_minor_length",
+                "label",]
+
+            # keep metadata
+            whole_meta = whole[out]
+
+            # remove from nuc
+            nuc = nuc.drop(out, axis=1)
+            # add whole-cell metadata to nuc
+            nuc_save = pd.concat([nuc, whole_meta], axis=1)
+            nuc_save.to_csv(output_dir + output_fname + "_" + seg_method + "_nuclei_intensities_result.csv")
+            # remove from cyto
+            cyto = cyto.drop(out, axis=1)
+            # add whole metadata to cyto
+            cyto_save = pd.concat([cyto, whole_meta], axis=1)
+            cyto_save.to_csv(output_dir + output_fname + "_" + seg_method + "_cytoplasm_intensities_result.csv")
+
+            whole.to_csv(output_dir + output_fname + "_" + seg_method + "_whole_cell_intensities_result.csv")
+            whole = whole.drop(out, axis=1)
+
+            # add identifier to each column name
+            nuc.columns = [str(col) + '_nuc' for col in nuc.columns]
+            cyto.columns = [str(col) + '_cyto' for col in cyto.columns]
+            whole.columns = [str(col) + '_whole' for col in whole.columns]
+
+            # combine the dataframes and save as csv
+            result = pd.concat([nuc, cyto, whole, whole_meta], axis=1)
+            result = result.loc[:, ~result.columns.str.contains('Unnamed: 0')]
+            result.to_csv(output_dir + output_fname + "_" + seg_method + "_segmentation_results_combined.csv")
+
+            return {"img": img, "masks": masks_whole_cell, "image_dict": image_dict, "masks_cytoplasm": masks_cytoplasm, "masks_nuclei": masks_nuclei,}
+
+    else:
+        if seg_method == "mesmer":
+            print("Segmenting with Mesmer!")
+            if membrane_channel_list is None:
+                masks = mesmer_segmentation(
+                    nuclei_image=segmentation_image_dict[nuclei_channel],
+                    membrane_image=None,
+                    plot_predictions=plot_predictions,  # plot segmentation results
+                    compartment="nuclear",
+                    model_path=model_path,
+                )  # segment whole cells or nuclei only
+            else:
+                masks = mesmer_segmentation(
+                    nuclei_image=segmentation_image_dict[nuclei_channel],
+                    membrane_image=segmentation_image_dict["segmentation_channel"],
+                    plot_predictions=plot_predictions,  # plot segmentation results
+                    compartment=compartment,
+                    model_path=model_path,
+                )  # segment whole cells or nuclei only
+        else:
+            print("Segmenting with Cellpose!")
+            if membrane_channel_list is None:
+                masks, flows, styles, input_image, rgb_channels = cellpose_segmentation(
+                    image_dict=segmentation_image_dict,
+                    output_dir=output_dir,
+                    membrane_channel=None,
+                    cytoplasm_channel=cytoplasm_channel_list,
+                    nucleus_channel=nuclei_channel,
+                    use_gpu=use_gpu,
+                    model=model,
+                    custom_model=custom_model,
+                    diameter=diameter,
+                    save_mask_as_png=save_mask_as_png,
+                )
+            else:
+                masks, flows, styles, input_image, rgb_channels = cellpose_segmentation(
+                    image_dict=segmentation_image_dict,
+                    output_dir=output_dir,
+                    membrane_channel="segmentation_channel",
+                    cytoplasm_channel=cytoplasm_channel_list,
+                    nucleus_channel=nuclei_channel,
+                    use_gpu=use_gpu,
+                    model=model,
+                    custom_model=custom_model,
+                    diameter=diameter,
+                    save_mask_as_png=save_mask_as_png,
+                )
+        # Remove single-dimensional entries from the shape of segmentation_masks
+        masks = masks.squeeze()
+        # Get the original dimensions of any one of the images
+        original_height, original_width = image_dict[
+            nuclei_channel
+        ].shape  # or any other channel
+        # Resize the masks back to the original size
+        masks = cv2.resize(
+            masks, (original_width, original_height), interpolation=cv2.INTER_NEAREST
+        )
+        print("Quantifying features after segmentation!")
+        extract_features(
+            image_dict=image_dict,  # image dictionary
+            segmentation_masks=masks,  # segmentation masks generated by cellpose
+            channels_to_quantify=channel_names,  # list of channels to quantify (here: all channels)
+            output_file=pathlib.Path(output_dir)
+            / (
+                output_fname + "_" + seg_method + "_result.csv"
+            ),  # output path to store results as csv
+            size_cutoff=size_cutoff,
+        )  # size cutoff for segmentation masks (default = 0)
+        print("Done!")
+        return {"img": img, "masks": masks, "image_dict": image_dict}
 
 
 def extract_features(
@@ -291,6 +473,8 @@ def extract_features(
 
     # Export to CSV
     markers.to_csv(output_file)
+
+    return markers
 
 
 def cellpose_segmentation(
@@ -622,7 +806,13 @@ def mesmer_segmentation(
 
     # Create a combined image stack
     # Assumes nuclei_image and membrane_image are numpy arrays of the same shape
-    combined_image = np.stack([nuclei_image, membrane_image], axis=-1)
+    if membrane_image is None:
+        # generate empty membrane image
+        print("No membrane image provided. Nuclear segmentation only.")
+        membrane_image = np.zeros_like(nuclei_image)
+        combined_image = np.stack([nuclei_image, membrane_image], axis=-1)
+    else:
+        combined_image = np.stack([nuclei_image, membrane_image], axis=-1)
 
     # Add an extra dimension to make it compatible with Mesmer's input requirements
     # Changes shape from (height, width, channels) to (1, height, width, channels)
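The heart of the new differentiate_nucleus_cytoplasm branch is simple mask arithmetic: the binary nuclei mask is subtracted from the binary whole-cell mask and the remaining cytoplasmic pixels are relabelled. A toy, self-contained sketch of that step with made-up 10x10 masks (independent of the patch):

import numpy as np
from scipy.ndimage import label

# Toy labelled masks: one cell occupying a 6x6 block with a 2x2 nucleus inside it.
masks_whole_cell = np.zeros((10, 10), dtype=np.int32)
masks_whole_cell[2:8, 2:8] = 1
masks_nuclei = np.zeros((10, 10), dtype=np.int32)
masks_nuclei[4:6, 4:6] = 1

# Same logic as the differentiate_nucleus_cytoplasm branch above.
binary_nuclei = masks_nuclei > 0
binary_whole_cell = masks_whole_cell > 0
binary_cytoplasm = binary_whole_cell & ~binary_nuclei

# Relabel the cytoplasmic ring so it can be quantified like any other mask.
masks_cytoplasm, num_labels = label(binary_cytoplasm)
print(num_labels)             # 1
print(masks_cytoplasm[4, 4])  # 0, the nuclear pixels are excluded from the cytoplasm mask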