diff --git a/PostProcessSegmentationMask.py b/PostProcessSegmentationMask.py
index 015521d..fa1c8cd 100644
--- a/PostProcessSegmentationMask.py
+++ b/PostProcessSegmentationMask.py
@@ -3,52 +3,10 @@
 
 import cv2
 import numpy as np
-import skimage.measure
-from skimage.morphology import remove_small_objects
 import scipy.ndimage as ndi
-# from numba import jit
-
-
-def remove_small_objects_from_image(image, min_size=100):
-    image_copy = image.copy()
-    image_copy[image > 0] = 1
-    image_copy = image_copy.astype(np.bool)
-    removed_red_channel = remove_small_objects(image_copy, min_size=min_size).astype(np.uint8)
-    image[removed_red_channel == 0] = 0
-    return image
-
-
-def remove_background_noise(mask, mask_boundary):
-    labeled = skimage.measure.label(mask, background=0)
-    padding = 5
-    for i in range(1, len(np.unique(labeled))):
-        component = np.zeros_like(mask)
-        component[labeled == i] = mask[labeled == i]
-        component_bound = np.zeros_like(mask_boundary)
-        component_bound[max(0, min(np.nonzero(component)[0]) - padding): min(mask_boundary.shape[1], max(np.nonzero(component)[0]) + padding),
-                        max(0, min(np.nonzero(component)[1]) - padding): min(mask_boundary.shape[1], max(np.nonzero(component)[1]) + padding)] \
-            = mask_boundary[max(0, min(np.nonzero(component)[0]) - padding): min(mask_boundary.shape[1], max(np.nonzero(component)[0]) + padding),
-                            max(0, min(np.nonzero(component)[1]) - padding): min(mask_boundary.shape[1], max(np.nonzero(component)[1]) + padding)]
-        if len(np.nonzero(component_bound)[0]) < len(np.nonzero(component)[0]) / 3:
-            mask[labeled == i] = 0
-    return mask
-
-
-def remove_cell_noise(mask1, mask2):
-    labeled = skimage.measure.label(mask1, background=0)
-    padding = 2
-    for i in range(1, len(np.unique(labeled))):
-        component = np.zeros_like(mask1)
-        component[labeled == i] = mask1[labeled == i]
-        component_bound = np.zeros_like(mask2)
-        component_bound[max(0, min(np.nonzero(component)[0]) - padding): min(mask2.shape[1], max(np.nonzero(component)[0]) + padding),
-                        max(0, min(np.nonzero(component)[1]) - padding): min(mask2.shape[1], max(np.nonzero(component)[1]) + padding)] \
-            = mask2[max(0, min(np.nonzero(component)[0]) - padding): min(mask2.shape[1], max(np.nonzero(component)[0]) + padding),
-                    max(0, min(np.nonzero(component)[1]) - padding): min(mask2.shape[1], max(np.nonzero(component)[1]) + padding)]
-        if len(np.nonzero(component_bound)[0]) > len(np.nonzero(component)[0]) / 3:
-            mask1[labeled == i] = 0
-            mask2[labeled == i] = 255
-    return mask1, mask2
+
+from deepliif.postprocessing import overlay, refine, remove_cell_noise, remove_background_noise, \
+    remove_small_objects_from_image
 
 
 def align_seg_on_image(input_image, input_mask, output_image, thresh=100, noise_objects_size=100):
@@ -58,9 +16,9 @@
     final_mask = orig_image.copy()
     processed_mask = np.zeros_like(orig_image)
 
-    red = seg_image[:,:,0]
-    blue = seg_image[:,:,2]
-    boundary = seg_image[:,:,1]
+    red = seg_image[:, :, 0]
+    blue = seg_image[:, :, 2]
+    boundary = seg_image[:, :, 1]
 
     boundary[boundary < thresh] = 0
 
@@ -98,7 +56,6 @@
                                            cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
     cv2.drawContours(final_mask, contours, -1, (0, 0, 255), 2)
 
-
     processed_mask[positive_cells > 0] = (0, 0, 255)
     processed_mask[negative_cells > 0] = (255, 0, 0)
 
@@ -118,59 +75,11 @@ def align_seg_on_image2(input_image, input_mask, output_image, thresh=100, noise
     seg_image = cv2.cvtColor(cv2.imread(input_mask), cv2.COLOR_BGR2RGB)
     orig_image = cv2.cvtColor(cv2.imread(input_image), cv2.COLOR_BGR2RGB)
 
-    final_mask = orig_image.copy()
-    processed_mask = np.zeros_like(orig_image)
-
-    red = seg_image[:,:,0]
-    blue = seg_image[:,:,2]
-    boundary = seg_image[:,:,1]
-
-    boundary[boundary < thresh] = 0
-
-    positive_cells = np.zeros((seg_image.shape[0], seg_image.shape[1]), dtype=np.uint8)
-    negative_cells = np.zeros((seg_image.shape[0], seg_image.shape[1]), dtype=np.uint8)
-
-    positive_cells[red > thresh] = 255
-    positive_cells[boundary > thresh] = 0
-    negative_cells[blue > thresh] = 255
-    negative_cells[boundary > thresh] = 0
-
-    negative_cells[red >= blue] = 0
-    positive_cells[blue > red] = 0
-
-    negative_cells = remove_small_objects_from_image(negative_cells, noise_objects_size)
-    negative_cells = ndi.binary_fill_holes(negative_cells).astype(np.uint8)
-
-    positive_cells = remove_small_objects_from_image(positive_cells, noise_objects_size)
-    positive_cells = ndi.binary_fill_holes(positive_cells).astype(np.uint8)
-
-    positive_cells = cv2.morphologyEx(positive_cells, cv2.MORPH_DILATE, kernel=np.ones((2, 2)))
-    negative_cells = cv2.morphologyEx(negative_cells, cv2.MORPH_DILATE, kernel=np.ones((2, 2)))
-
-    boundary = cv2.morphologyEx(boundary, cv2.MORPH_ERODE, kernel=np.ones((2, 2)))
-
-    contours, hierarchy = cv2.findContours(positive_cells,
-                                           cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
-    cv2.drawContours(final_mask, contours, -1, (255, 0, 0), 2)
+    overlaid_mask = overlay(orig_image, seg_image, thresh, noise_objects_size)
+    cv2.imwrite(output_image, overlaid_mask)
 
-    contours, hierarchy = cv2.findContours(negative_cells,
-                                           cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
-    cv2.drawContours(final_mask, contours, -1, (0, 0, 255), 2)
-
-
-    processed_mask[positive_cells > 0] = (0, 0, 255)
-    processed_mask[negative_cells > 0] = (255, 0, 0)
-
-    contours, hierarchy = cv2.findContours(positive_cells,
-                                           cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
-    cv2.drawContours(processed_mask, contours, -1, (0, 255, 0), 2)
-
-    contours, hierarchy = cv2.findContours(negative_cells,
-                                           cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
-    cv2.drawContours(processed_mask, contours, -1, (0, 255, 0), 2)
-
-    cv2.imwrite(output_image, cv2.cvtColor(final_mask, cv2.COLOR_BGR2RGB))
-    cv2.imwrite(output_image.replace('Overlaid', 'Refined'), processed_mask)
+    refined_mask = refine(orig_image, seg_image, thresh, noise_objects_size)
+    cv2.imwrite(output_image.replace('Overlaid', 'Refined'), refined_mask)
 
 
 def post_process_segmentation_mask(input_dir, seg_thresh=100, noise_object_size=100):
@@ -178,14 +87,14 @@ def post_process_segmentation_mask(input_dir, seg_thresh=100, noise_object_size=
     for img in images:
         if '_fake_B_5.png' in img:
             align_seg_on_image2(os.path.join(input_dir, img.replace('_fake_B_5', '_real_A')),
-                            os.path.join(input_dir, img),
-                            os.path.join(input_dir, img.replace('_fake_B_5', '_Seg_Overlaid_')),
-                            thresh=seg_thresh, noise_objects_size=noise_object_size)
+                                os.path.join(input_dir, img),
+                                os.path.join(input_dir, img.replace('_fake_B_5', '_Seg_Overlaid_')),
+                                thresh=seg_thresh, noise_objects_size=noise_object_size)
         elif '_Seg.png' in img:
             align_seg_on_image2(os.path.join(input_dir, img.replace('_Seg', '')),
-                            os.path.join(input_dir, img),
-                            os.path.join(input_dir, img.replace('_Seg', '_SegOverlaid')),
-                            thresh=seg_thresh, noise_objects_size=noise_object_size)
+                                os.path.join(input_dir, img),
+                                os.path.join(input_dir, img.replace('_Seg', '_SegOverlaid')),
+                                thresh=seg_thresh, noise_objects_size=noise_object_size)
 
 
 if __name__ == '__main__':
diff --git a/deepliif/postprocessing.py b/deepliif/postprocessing.py
index 96638c6..abc90d6 100644
--- a/deepliif/postprocessing.py
+++ b/deepliif/postprocessing.py
@@ -1,4 +1,9 @@
+import cv2
 from PIL import Image
+import skimage.measure
+from skimage.morphology import remove_small_objects
+import numpy as np
+import scipy.ndimage as ndi
 
 from deepliif.util import util
 
@@ -21,4 +26,110 @@ def stitch(tiles, overlap_size):
 
         new_im.paste(img, (t.j * tile_size, t.i * tile_size))
 
-    return new_im
\ No newline at end of file
+    return new_im
+
+
+def remove_small_objects_from_image(img, min_size=100):
+    image_copy = img.copy()
+    image_copy[img > 0] = 1
+    image_copy = image_copy.astype(bool)
+    removed_red_channel = remove_small_objects(image_copy, min_size=min_size).astype(np.uint8)
+    img[removed_red_channel == 0] = 0
+
+    return img
+
+
+def remove_background_noise(mask, mask_boundary):
+    labeled = skimage.measure.label(mask, background=0)
+    padding = 5
+    for i in range(1, len(np.unique(labeled))):
+        component = np.zeros_like(mask)
+        component[labeled == i] = mask[labeled == i]
+        component_bound = np.zeros_like(mask_boundary)
+        component_bound[max(0, min(np.nonzero(component)[0]) - padding): min(mask_boundary.shape[1],
+                                                                             max(np.nonzero(component)[0]) + padding),
+                        max(0, min(np.nonzero(component)[1]) - padding): min(mask_boundary.shape[1],
+                                                                             max(np.nonzero(component)[1]) + padding)] \
+            = mask_boundary[max(0, min(np.nonzero(component)[0]) - padding): min(mask_boundary.shape[1], max(
+                np.nonzero(component)[0]) + padding),
+                            max(0, min(np.nonzero(component)[1]) - padding): min(mask_boundary.shape[1],
+                                                                                 max(np.nonzero(component)[1]) + padding)]
+        if len(np.nonzero(component_bound)[0]) < len(np.nonzero(component)[0]) / 3:
+            mask[labeled == i] = 0
+    return mask
+
+
+def remove_cell_noise(mask1, mask2):
+    labeled = skimage.measure.label(mask1, background=0)
+    padding = 2
+    for i in range(1, len(np.unique(labeled))):
+        component = np.zeros_like(mask1)
+        component[labeled == i] = mask1[labeled == i]
+        component_bound = np.zeros_like(mask2)
+        component_bound[
+            max(0, min(np.nonzero(component)[0]) - padding): min(mask2.shape[1], max(np.nonzero(component)[0]) + padding),
+            max(0, min(np.nonzero(component)[1]) - padding): min(mask2.shape[1], max(np.nonzero(component)[1]) + padding)] \
+            = mask2[max(0, min(np.nonzero(component)[0]) - padding): min(mask2.shape[1],
+                                                                         max(np.nonzero(component)[0]) + padding),
+                    max(0, min(np.nonzero(component)[1]) - padding): min(mask2.shape[1],
+                                                                         max(np.nonzero(component)[1]) + padding)]
+        if len(np.nonzero(component_bound)[0]) > len(np.nonzero(component)[0]) / 3:
+            mask1[labeled == i] = 0
+            mask2[labeled == i] = 255
+    return mask1, mask2
+
+
+def positive_negative_masks(mask, thresh=100, noise_objects_size=20):
+    positive_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.uint8)
+    negative_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.uint8)
+
+    red = mask[:, :, 0]
+    blue = mask[:, :, 2]
+    boundary = mask[:, :, 1]
+
+    positive_mask[red > thresh] = 255
+    positive_mask[boundary > thresh] = 0
+    positive_mask[blue > red] = 0
+
+    negative_mask[blue > thresh] = 255
+    negative_mask[boundary > thresh] = 0
+    negative_mask[red >= blue] = 0
+
+    def inner(img):
+        img = remove_small_objects_from_image(img, noise_objects_size)
+        img = ndi.binary_fill_holes(img).astype(np.uint8)
+
+        return cv2.morphologyEx(img, cv2.MORPH_DILATE, kernel=np.ones((2, 2)))
+
+    return inner(positive_mask), inner(negative_mask)
+
+
+def refine(img, seg_img, thresh=100, noise_objects_size=20):
+    positive_mask, negative_mask = positive_negative_masks(seg_img, thresh, noise_objects_size)
+
+    refined_mask = np.zeros_like(img)
+
+    refined_mask[positive_mask > 0] = (0, 0, 255)
+    refined_mask[negative_mask > 0] = (255, 0, 0)
+
+    contours, _ = cv2.findContours(positive_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
+    cv2.drawContours(refined_mask, contours, -1, (0, 255, 0), 2)
+
+    contours, _ = cv2.findContours(negative_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
+    cv2.drawContours(refined_mask, contours, -1, (0, 255, 0), 2)
+
+    return refined_mask
+
+
+def overlay(img, seg_img, thresh=100, noise_objects_size=20):
+    positive_mask, negative_mask = positive_negative_masks(seg_img, thresh, noise_objects_size)
+
+    overlaid_mask = img.copy()
+
+    contours, _ = cv2.findContours(positive_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
+    cv2.drawContours(overlaid_mask, contours, -1, (0, 0, 255), 2)
+
+    contours, _ = cv2.findContours(negative_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
+    cv2.drawContours(overlaid_mask, contours, -1, (255, 0, 0), 2)
+
+    return overlaid_mask
diff --git a/deepliif/preprocessing.py b/deepliif/preprocessing.py
index 6760f96..bd1ffb6 100644
--- a/deepliif/preprocessing.py
+++ b/deepliif/preprocessing.py
@@ -27,7 +27,6 @@ def output_size(img, tile_size):
     return round(img.width / tile_size) * tile_size, round(img.height / tile_size) * tile_size
 
 
-@util.timeit
 def generate_tiles(img, tile_size, overlap_size):
     img = img.resize(output_size(img, tile_size))
 
@@ -47,7 +46,6 @@ def generate_tiles(img, tile_size, overlap_size):
         )).resize((tile_size, tile_size)))
 
 
-@util.timeit
 def transform(img):
     return default_collate([transforms.Compose([
         transforms.Lambda(lambda i: __make_power_2(i, base=4, method=Image.BICUBIC)),
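
Review note: a minimal usage sketch of the relocated helpers, mirroring what align_seg_on_image2 now does after this change. The image file names below are hypothetical examples, not part of the patch.

# Sketch only: overlay() draws positive/negative cell contours on the original image,
# refine() produces the cleaned mask with filled regions and green boundaries,
# exactly as the new deepliif.postprocessing module defines them in this diff.
import cv2

from deepliif.postprocessing import overlay, refine

# Original IHC tile and its raw RGB segmentation mask (red = positive, blue = negative, green = boundary).
orig_image = cv2.cvtColor(cv2.imread('example_real_A.png'), cv2.COLOR_BGR2RGB)
seg_image = cv2.cvtColor(cv2.imread('example_Seg.png'), cv2.COLOR_BGR2RGB)

# Contours of the cleaned positive/negative cells drawn on top of the original tile.
cv2.imwrite('example_Seg_Overlaid.png', overlay(orig_image, seg_image, thresh=100, noise_objects_size=100))

# Filled positive/negative regions with green cell boundaries on a blank canvas.
cv2.imwrite('example_Seg_Refined.png', refine(orig_image, seg_image, thresh=100, noise_objects_size=100))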