diff --git a/src/nimbus_inference/nimbus.py b/src/nimbus_inference/nimbus.py index a394bff..ae5464c 100644 --- a/src/nimbus_inference/nimbus.py +++ b/src/nimbus_inference/nimbus.py @@ -18,19 +18,30 @@ import re -def prep_naming_convention(deepcell_output_dir): +def prep_naming_convention(deepcell_output_dir, approx=False): """Prepares the naming convention for the segmentation data produced with the DeepCell library. Args: deepcell_output_dir (str): path to directory where segmentation data is saved + approx (bool): whether to use an approximate naming convention Returns: function: function that returns the path to the segmentation data for a given fov """ - def segmentation_naming_convention(fov_path): """Prepares the path to the segmentation data for a given fov + Args: + fov_path (str): path to fov + Returns: + str: paths to segmentation fovs + """ + fov_name = os.path.basename(fov_path).replace(".ome.tiff", "") + return os.path.join(deepcell_output_dir, fov_name + "_whole_cell.tiff") + + def segmentation_naming_convention_approx(fov_path): + """Prepares the path to the segmentation data for a given fov + Args: fov_path (str): path to fov Returns: @@ -48,7 +59,11 @@ def segmentation_naming_convention(fov_path): if len(fnames) > 1: raise ValueError(f"Multiple segmentation data found for fov {fov_name}") return fnames[0] - return segmentation_naming_convention + + if approx: + return segmentation_naming_convention_approx + else: + return segmentation_naming_convention class Nimbus(nn.Module):