From 5cbfa7aea53a5556c3dde3e21b8a9cf57c93fbd3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sophia=20M=C3=A4dler?= <15019107+sophiamaedler@users.noreply.github.com>
Date: Wed, 10 Jan 2024 17:43:00 +0100
Subject: [PATCH] add functionality to complete an incomplete segmentation
 filtering without needing to recalculate values for completed tiles

---
 .../pipeline/filter_segmentation.py | 135 +++++++++++++++++-
 src/sparcscore/pipeline/project.py  |  12 ++
 2 files changed, 143 insertions(+), 4 deletions(-)

diff --git a/src/sparcscore/pipeline/filter_segmentation.py b/src/sparcscore/pipeline/filter_segmentation.py
index 5d497d9..aba1e26 100644
--- a/src/sparcscore/pipeline/filter_segmentation.py
+++ b/src/sparcscore/pipeline/filter_segmentation.py
@@ -5,6 +5,7 @@
 from multiprocessing import Pool
 import shutil
 import pandas as pd
+from collections import defaultdict
 
 import traceback
 
@@ -168,6 +169,26 @@ def initialize_tile_list(self, tileing_plan, input_path):
 
         return _tile_list
 
+    def initialize_tile_list_incomplete(self, tileing_plan, incomplete_indexes, input_path):
+        _tile_list = []
+
+        self.input_path = input_path
+
+        for i, window in zip(incomplete_indexes, tileing_plan):
+            local_tile_directory = os.path.join(self.tile_directory, str(i))
+            current_tile = self.method(
+                self.config,
+                local_tile_directory,
+                project_location=self.project_location,
+                debug=self.debug,
+                overwrite=self.overwrite,
+                intermediate_output=self.intermediate_output,
+            )
+            current_tile.initialize_as_tile(i, window, self.input_path, zarr_status=False)
+            _tile_list.append(current_tile)
+
+        return _tile_list
+
     def calculate_tileing_plan(self, mask_size):
         #save tileing plan to file
         tileing_plan_path = f"{self.directory}/tileing_plan.csv"
@@ -270,10 +291,31 @@ def resolve_tileing(self, tileing_plan):
 
         #perform sanity check that no cytosol_id is listed twice
         filtered_classes_combined = {int(k): int(v) for k, v in (s.split(":") for s in filtered_classes_combined)}
+
+        #only run the more expensive duplicate resolution if this check fails
         if len(filtered_classes_combined.values()) != len(set(filtered_classes_combined.values())):
-            print(pd.Series(filtered_classes_combined.values()).value_counts())
-            print(filtered_classes_combined)
-            sys.exit("Duplicate values found. Some issues with filtering. Please contact the developers.")
+
+            #remove all entries where a cytosol_id is assigned to several nuclei
+            #to ensure that only one nucleus_id is assigned to each cytosol_id
+            cytosol_count = defaultdict(int)
+
+            # count the occurrences of each cytosol value
+            for cytosol in filtered_classes_combined.values():
+                cytosol_count[cytosol] += 1
+
+            # find nuclei whose cytosol is assigned to more than one nucleus
+            multi_nucleated_nucleus_ids = []
+
+            for nucleus, cytosol in filtered_classes_combined.items():
+                if cytosol_count[cytosol] > 1:
+                    multi_nucleated_nucleus_ids.append(nucleus)
+
+            #deletion happens in a separate loop because removing entries while
+            #iterating would change the dictionary size and raise an error
+            for nucleus in multi_nucleated_nucleus_ids:
+                del filtered_classes_combined[nucleus]
+
+            self.log(f"Found {len(multi_nucleated_nucleus_ids)} nuclei assigned to the same cytosol as another nucleus. All of these entries were removed.")
 
         # save newly generated class list to file
         filtered_path = os.path.join(self.directory, self.DEFAULT_FILTER_FILE)
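For illustration, the duplicate-resolution logic added above is self-contained and can be exercised outside the pipeline. A minimal sketch; the nucleus-to-cytosol mapping is hypothetical example data:

    # Sketch of the duplicate-cytosol filtering from the hunk above.
    # The mapping is made-up example data: nuclei 2 and 3 share cytosol 11.
    from collections import defaultdict

    filtered_classes_combined = {1: 10, 2: 11, 3: 11, 4: 12}

    cytosol_count = defaultdict(int)
    for cytosol in filtered_classes_combined.values():
        cytosol_count[cytosol] += 1

    # collect first, delete second, so the dict never changes size mid-iteration
    to_remove = [n for n, c in filtered_classes_combined.items() if cytosol_count[c] > 1]
    for nucleus in to_remove:
        del filtered_classes_combined[nucleus]

    assert filtered_classes_combined == {1: 10, 4: 12}
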
@@ -343,4 +385,89 @@ def process(self, input_path):
         self.resolve_tileing(tileing_plan)
 
         #make sure to cleanup temp directories
-        self.log("=== finished filtering === ")
\ No newline at end of file
+        self.log("=== finished filtering === ")
+
+    def complete_segmentation(self, input_path):
+
+        self.tile_directory = os.path.join(self.directory, self.DEFAULT_TILES_FOLDER)
+
+        if not os.path.isdir(self.tile_directory):
+            sys.exit("No tile directory found for the given project. Cannot complete a segmentation filtering run that was never started. Please rerun the segmentation filter method.")
+
+        #check to see which tiles are incomplete
+        tile_directories = os.listdir(self.tile_directory)
+        incomplete_indexes = []
+
+        for tile in tile_directories:
+            if not os.path.isfile(f"{self.tile_directory}/{tile}/filtered_classes.csv"):
+                incomplete_indexes.append(int(tile))
+                self.log(f"Tile with ID {tile} not completed.")
+
+        # get the mask size needed to recalculate the tileing plan
+        with h5py.File(input_path, "r") as hf:
+            self.mask_size = hf["labels"].shape[1:]
+
+        if self.config["tile_size"] >= np.prod(self.mask_size):
+            target_size = self.config["tile_size"]
+            self.log(f"target size {target_size} is equal to or larger than the input mask {np.prod(self.mask_size)}. Tileing will not be used.")
+
+            tileing_plan = [
+                (slice(0, self.mask_size[0]), slice(0, self.mask_size[1]))
+            ]
+
+        else:
+            target_size = self.config["tile_size"]
+            self.log(f"target size {target_size} is smaller than the input mask {np.prod(self.mask_size)}. Tileing will be used.")
+
+            #read tileing plan from file
+            with open(f"{self.directory}/tileing_plan.csv", "r") as f:
+                tileing_plan = [eval(line) for line in f.readlines()]
+
+            self.log(f"Tileing plan read from file {self.directory}/tileing_plan.csv")
+
+        #check to make sure that the calculated tileing plan matches the existing tileing results
+        if len(tileing_plan) != len(tile_directories):
+            sys.exit("Calculated a different number of tiles than found tile directories. This indicates a mismatch between the currently loaded config file and the config file used to generate the existing partial segmentation. Please rerun the complete segmentation filtering to ensure accurate results.")
+
+        #select only those tiles that did not complete successfully for further processing
+        tileing_plan_complete = tileing_plan
+
+        if len(incomplete_indexes) == 0:
+            if os.path.isfile(f"{self.directory}/filtered_classes.csv"):
+                self.log("Segmentation filtering already done.")
+            else:
+                self.log("Segmentation filtering on individual tiles already completed. Unifying results of individual tiles.")
+                self.resolve_tileing(tileing_plan_complete)
+
+                #make sure to cleanup temp directories
+                self.log("=== finished filtering === ")
+        else:
+            tileing_plan = [tile for i, tile in enumerate(tileing_plan) if i in incomplete_indexes]
+            self.log(f"Adjusted tileing plan to only proceed with the {len(incomplete_indexes)} incomplete tiles.")
+
+            tile_list = self.initialize_tile_list_incomplete(tileing_plan, incomplete_indexes, input_path)
+
+            self.log(
+                f"tileing plan with {len(tileing_plan)} elements generated, tileing with {self.config['threads']} threads begins"
+            )
+
+            with Pool(processes=self.config['threads']) as pool:
+                results = list(
+                    tqdm(
+                        pool.imap(self.method.call_as_tile, tile_list),
+                        total=len(tile_list),
+                    )
+                )
+                pool.close()
+                pool.join()
+                print("All filtering steps are done.", flush=True)
+
+            #free up memory
+            del tile_list
+            gc.collect()
+
+            self.log("Finished tiled filtering.")
+            self.resolve_tileing(tileing_plan)
+
+            #make sure to cleanup temp directories
+            self.log("=== finished filtering === ")
\ No newline at end of file
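The eval-based reader above assumes calculate_tileing_plan persists one tuple-of-slices repr per line, which round-trips because repr(slice(0, 1000, None)) evaluates back to an equal slice. A minimal sketch of that round trip; the windows and file path are made-up example values:

    # Round-trip sketch for the tileing-plan persistence format implied above.
    import os
    import tempfile

    tileing_plan = [
        (slice(0, 1000), slice(0, 1000)),
        (slice(0, 1000), slice(1000, 2000)),
    ]

    path = os.path.join(tempfile.mkdtemp(), "tileing_plan.csv")

    with open(path, "w") as f:
        for window in tileing_plan:
            f.write(f"{window}\n")  # one "(slice(...), slice(...))" repr per line

    with open(path, "r") as f:
        restored = [eval(line) for line in f.readlines()]

    # slices compare equal by (start, stop, step), so the plan survives the round trip
    assert restored == tileing_plan
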
diff --git a/src/sparcscore/pipeline/project.py b/src/sparcscore/pipeline/project.py
index c49f354..bb0dfbf 100644
--- a/src/sparcscore/pipeline/project.py
+++ b/src/sparcscore/pipeline/project.py
@@ -629,6 +629,18 @@ def filter_segmentation(self, *args, **kwargs):
         input_segmentation = self.segmentation_f.get_output()
         self.segmentation_filtering_f(input_segmentation, *args, **kwargs)
 
+    def complete_filter_segmentation(self, *args, **kwargs):
+        """
+        Complete an aborted or failed segmentation filtering run.
+        """
+        self.log("completing incomplete segmentation filtering")
+
+        if self.segmentation_filtering_f is None:
+            raise ValueError("No filtering method for refining segmentation masks defined.")
+
+        input_segmentation = self.segmentation_f.get_output()
+        self.segmentation_filtering_f.complete_segmentation(input_segmentation, *args, **kwargs)
+
     def extract(self, *args, **kwargs):
         """
         Extract single cells with the defined extraction method.
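Taken together, an interrupted filtering run can be resumed from the project level. A hypothetical usage sketch; the constructor arguments and the filtering class name are placeholders that must match however the project was originally configured, not the library's exact signature:

    # Hypothetical resume flow; all arguments below are placeholders.
    from sparcscore.pipeline.project import Project

    project = Project(
        "path/to/project_location",              # placeholder project path
        config_path="path/to/config.yml",        # placeholder config
        segmentation_filtering_f=MyTiledFilter,  # placeholder filtering class
    )

    # a first run may be interrupted partway through the tiles:
    #   project.filter_segmentation()

    # resuming skips every tile that already has a filtered_classes.csv,
    # recomputes only the incomplete ones, then unifies the results:
    project.complete_filter_segmentation()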