Skip to content

Commit

Permalink
add functionality to complete an incomplete segmentation filtering wi…
Browse files Browse the repository at this point in the history
…thout needing to recalculate values for completed tiles
  • Loading branch information
sophiamaedler committed Jan 10, 2024
1 parent cc64062 commit 5cbfa7a
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 4 deletions.
135 changes: 131 additions & 4 deletions src/sparcscore/pipeline/filter_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from multiprocessing import Pool
import shutil
import pandas as pd
from collections import defaultdict

import traceback

Expand Down Expand Up @@ -168,6 +169,26 @@ def initialize_tile_list(self, tileing_plan, input_path):

return _tile_list

def initialize_tile_list_incomplete(self, tileing_plan, incomplete_indexes, input_path):
    """Create tile objects for the tiles that still need to be processed.

    Each entry of ``incomplete_indexes`` is paired positionally with the entry
    of ``tileing_plan`` at the same position, so both must be in the same order
    and have the same length.

    Args:
        tileing_plan: Windows of the tiles to process, one per incomplete tile.
        incomplete_indexes: IDs of the incomplete tiles. Each ID is also used
            as the name of the tile's sub-directory in ``self.tile_directory``.
        input_path: Input path handed to every tile; also stored on
            ``self.input_path``.

    Returns:
        list: One initialised ``self.method`` instance per incomplete tile.

    Raises:
        ValueError: If the number of indexes and windows differ. ``zip`` would
            otherwise silently drop the surplus tiles.
    """
    tileing_plan = list(tileing_plan)
    incomplete_indexes = list(incomplete_indexes)

    if len(tileing_plan) != len(incomplete_indexes):
        raise ValueError(
            f"Number of incomplete tile indexes ({len(incomplete_indexes)}) does not match "
            f"the number of windows in the tileing plan ({len(tileing_plan)})."
        )

    self.input_path = input_path

    _tile_list = []
    for i, window in zip(incomplete_indexes, tileing_plan):
        local_tile_directory = os.path.join(self.tile_directory, str(i))
        current_tile = self.method(
            self.config,
            local_tile_directory,
            project_location=self.project_location,
            debug=self.debug,
            overwrite=self.overwrite,
            intermediate_output=self.intermediate_output,
        )
        current_tile.initialize_as_tile(i, window, self.input_path, zarr_status=False)
        _tile_list.append(current_tile)

    return _tile_list

def calculate_tileing_plan(self, mask_size):
#save tileing plan to file
tileing_plan_path = f"{self.directory}/tileing_plan.csv"
Expand Down Expand Up @@ -270,10 +291,31 @@ def resolve_tileing(self, tileing_plan):

#perform sanity check that no cytosol_id is listed twice
filtered_classes_combined = {int(k): int(v) for k, v in (s.split(":") for s in filtered_classes_combined)}

#only perform this is if this check fails (otherwise more computationally expensive)
if len(filtered_classes_combined.values()) != len(set(filtered_classes_combined.values())):
print(pd.Series(filtered_classes_combined.values()).value_counts())
print(filtered_classes_combined)
sys.exit("Duplicate values found. Some issues with filtering. Please contact the developers.")

#remove all entries where a cytosol_id is assigned to several nuclei
#check to ensure that only one nucleus_id is assigned to each cytosol_id
cytosol_count = defaultdict(int)

# Count the occurrences of each cytosol value
for cytosol in filtered_classes_combined.values():
cytosol_count[cytosol] += 1

# Find cytosol values assigned to more than one nucleus and remove from dictionary
multi_nucleated_nulceus_ids = []

for nucleus, cytosol in filtered_classes_combined.items():
if cytosol_count[cytosol] > 1:
multi_nucleated_nulceus_ids.append(nucleus)

#remove entries from dictionary
# this needs to be done in a separate loop: deleting entries while iterating over the dictionary changes its size and raises an error
for nucleus in multi_nucleated_nulceus_ids:
del filtered_classes_combined[nucleus]

self.log(f"Found several nuclei {len(multi_nucleated_nulceus_ids)} assigned to the same cytosol. All of these entries were removed.")

# save newly generated class list to file
filtered_path = os.path.join(self.directory, self.DEFAULT_FILTER_FILE)
Expand Down Expand Up @@ -343,4 +385,89 @@ def process(self, input_path):
self.resolve_tileing(tileing_plan)

#make sure to cleanup temp directories
self.log("=== finished filtering === ")
self.log("=== finished filtering === ")

def complete_segmentation(self, input_path):
    """Finish a segmentation filtering run that stopped before all tiles completed.

    Tiles whose directory already contains ``filtered_classes.csv`` are treated
    as complete and are not recalculated. Only the remaining tiles are processed
    again; afterwards the results of all tiles are unified with
    ``resolve_tileing``.

    Args:
        input_path: Path of the HDF5 file whose ``labels`` dataset is filtered.

    Raises:
        SystemExit: If no tile directory exists for the project, or if the
            number of tiles in the tileing plan differs from the number of tile
            directories found on disk.
    """
    self.tile_directory = os.path.join(self.directory, self.DEFAULT_TILES_FOLDER)

    if not os.path.isdir(self.tile_directory):
        sys.exit("No tile Directory found for the given project. Can not complete a segmentation filter which has not started. Please rerun the segmentation filter method.")

    # check to see which tiles are incomplete
    # (list the tile directory itself: it is the directory probed for results below)
    tile_directories = os.listdir(self.tile_directory)
    incomplete_indexes = []

    for tile in tile_directories:
        if not os.path.isfile(f"{self.tile_directory}/{tile}/filtered_classes.csv"):
            incomplete_indexes.append(int(tile))
            self.log(f"Tile with ID {tile} not completed.")

    # os.listdir returns entries in arbitrary order, while the reduced tileing plan
    # below is built in ascending index order; sort so both line up when paired
    incomplete_indexes.sort()

    # calculate tileing plan
    with h5py.File(input_path, "r") as hf:
        self.mask_size = hf["labels"].shape[1:]

    target_size = self.config["tile_size"]
    mask_pixels = np.prod(self.mask_size)

    if target_size >= mask_pixels:
        self.log(f"target size {target_size} is equal or larger to input mask {mask_pixels}. Tileing will not be used.")

        tileing_plan = [
            (slice(0, self.mask_size[0]), slice(0, self.mask_size[1]))
        ]

    else:
        self.log(f"target size {target_size} is smaller than input mask {mask_pixels}. Tileing will be used.")

        # read tileing plan from file
        # NOTE(review): eval() executes whatever is stored in tileing_plan.csv. This is
        # only safe while the file is written exclusively by this pipeline; consider a
        # dedicated parser if the file can come from an untrusted source.
        with open(f"{self.directory}/tileing_plan.csv", "r") as f:
            tileing_plan = [eval(line) for line in f]

        self.log(f"Tileing plan read from file {self.directory}/tileing_plan.csv")

    # check to make sure that calculated sharding plan matches to existing sharding results
    if len(tileing_plan) != len(tile_directories):
        sys.exit("Calculated a different number of tiles than found tile directories. This indicates a mismatch between the current loaded config file and the config file used to generate the exisiting partial segmentation. Please rerun the complete segmentation to ensure accurate results.")

    # keep the full plan: unifying the results needs every tile, not only the reprocessed ones
    tileing_plan_complete = tileing_plan

    if not incomplete_indexes:
        if os.path.isfile(f"{self.directory}/filtered_classes.csv"):
            self.log("Segmentation filtering already done.")
        else:
            self.log("Segmentation filtering on individual tiles already completed. Unifying results of individual tiles.")
            self.resolve_tileing(tileing_plan_complete)

            # make sure to cleanup temp directories
            self.log("=== finished filtering === ")
        return

    # select only those tiles that did not complete successfully for further processing
    incomplete_lookup = set(incomplete_indexes)
    tileing_plan = [tile for i, tile in enumerate(tileing_plan) if i in incomplete_lookup]
    self.log(f"Adjusted tileing plan to only proceed with the {len(incomplete_indexes)} incomplete tiles.")

    tile_list = self.initialize_tile_list_incomplete(tileing_plan, incomplete_indexes, input_path)

    self.log(
        f"tileing plan with {len(tileing_plan)} elements generated, tileing with {self.config['threads']} threads begins"
    )

    with Pool(processes=self.config['threads']) as pool:
        # drain the iterator so every tile is processed; the return values are not needed
        for _ in tqdm(
            pool.imap(self.method.call_as_tile, tile_list),
            total=len(tile_list),
        ):
            pass
        pool.close()
        pool.join()
        print("All Filtering Steps are done.", flush=True)

    # free up memory
    del tile_list
    gc.collect()

    self.log("Finished tiled filtering.")
    self.resolve_tileing(tileing_plan_complete)

    # make sure to cleanup temp directories
    self.log("=== finished filtering === ")
12 changes: 12 additions & 0 deletions src/sparcscore/pipeline/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,18 @@ def filter_segmentation(self, *args, **kwargs):
input_segmentation = self.segmentation_f.get_output()
self.segmentation_filtering_f(input_segmentation, *args, **kwargs)

def complete_filter_segmentation(self, *args, **kwargs):
    """Complete an aborted or failed segmentation filtering run.

    Fetches the output of the segmentation method via
    ``self.segmentation_f.get_output()`` and hands it to the filtering method's
    ``complete_segmentation``. All positional and keyword arguments are
    forwarded unchanged.

    Raises:
        ValueError: If no segmentation filtering method is defined.
    """
    self.log("completing incomplete segmentation filtering")

    if self.segmentation_filtering_f is None:
        raise ValueError("No filtering method for refining segmentation masks defined.")

    input_segmentation = self.segmentation_f.get_output()
    # the filtering class names its resume entry point ``complete_segmentation``
    self.segmentation_filtering_f.complete_segmentation(input_segmentation, *args, **kwargs)

def extract(self, *args, **kwargs):
"""
Extract single cells with the defined extraction method.
Expand Down

0 comments on commit 5cbfa7a

Please sign in to comment.