diff --git a/decimer_segmentation/complete_structure.py b/decimer_segmentation/complete_structure.py index ef5bceaf..4a593fd8 100644 --- a/decimer_segmentation/complete_structure.py +++ b/decimer_segmentation/complete_structure.py @@ -286,7 +286,11 @@ def get_mask_center(mask_array: np.array) -> Tuple[int, int]: return None, None -def get_seeds(image_array: np.array, mask_array: np.array) -> List[Tuple[int, int]]: +def get_seeds( + image_array: np.array, + mask_array: np.array, + exclusion_mask: np.array, +) -> List[Tuple[int, int]]: """ This function takes an array that represents an image and a mask. It returns a list of tuples with indices of seeds in the structure @@ -295,6 +299,7 @@ def get_seeds(image_array: np.array, mask_array: np.array) -> List[Tuple[int, in Args: image_array (np.array): Image mask_array (np.array): Mask + exclusion_mask (np.array): Exclusion mask Returns: List[Tuple[int, int]]: [(x,y), (x,y), ...] @@ -312,32 +317,36 @@ def get_seeds(image_array: np.array, mask_array: np.array) -> List[Tuple[int, in if not mask_array[y_center, x_center + n]: up = False if not image_array[y_center, x_center + n]: - seed_pixels.append((x_center + n, y_center)) - up = False + if not exclusion_mask[y_center, x_center + n]: + seed_pixels.append((x_center + n, y_center)) + up = False # Check for seeds below center if down: if x_center - n >= 0: if not mask_array[y_center, x_center - n]: down = False if not image_array[y_center, x_center - n]: - seed_pixels.append((x_center - n, y_center)) - down = False + if not exclusion_mask[y_center, x_center - n]: + seed_pixels.append((x_center - n, y_center)) + down = False # Check for seeds left from center if left: if y_center + n < image_array.shape[0]: if not mask_array[y_center + n, x_center]: left = False if not image_array[y_center + n, x_center]: - seed_pixels.append((x_center, y_center + n)) - left = False + if not exclusion_mask[y_center + n, x_center]: + seed_pixels.append((x_center, y_center + n)) + left = False # Check for seeds right from center if right: if y_center - n >= 0: if not mask_array[y_center - n, x_center]: right = False if not image_array[y_center - n, x_center]: - seed_pixels.append((x_center, y_center - n)) - right = False + if not exclusion_mask[y_center - n, x_center]: + seed_pixels.append((x_center, y_center - n)) + right = False return seed_pixels @@ -450,7 +459,9 @@ def expansion_coordination( the mask expansion. It returns the expanded mask. The purpose of this function is wrapping up the expansion procedure in a map function. """ - seed_pixels = get_seeds(image_array, mask_array) + seed_pixels = get_seeds(image_array, + mask_array, + exclusion_mask) if seed_pixels != []: mask_array = expand_masks(image_array, seed_pixels, mask_array, exclusion_mask) else: @@ -498,7 +509,7 @@ def complete_structure_mask( if mask_array.size != 0: # Binarization of input image - binarized_image_array = binarize_image(image_array, threshold=0.85) + binarized_image_array = binarize_image(image_array, threshold=0.72) # Apply dilation with a resolution-dependent kernel to the image blur_factor = ( int(image_array.shape[1] / 185) if image_array.shape[1] / 185 >= 2 else 2 diff --git a/decimer_segmentation/decimer_segmentation.py b/decimer_segmentation/decimer_segmentation.py index 8be2fa78..51fa79d1 100644 --- a/decimer_segmentation/decimer_segmentation.py +++ b/decimer_segmentation/decimer_segmentation.py @@ -114,6 +114,10 @@ def segment_chemical_structures( if len(segments) > 0: segments, bboxes = sort_segments_bboxes(segments, bboxes) + segments = [segment for segment in segments + if segment.shape[0] > 0 + if segment.shape[1] > 0] + return segments @@ -235,6 +239,7 @@ def get_expanded_masks(image: np.array) -> np.array: image_array=image, mask_array=masks, max_depiction_size=size, + debug=False ) return expanded_masks diff --git a/tests/test_mask_expansion.py b/tests/test_mask_expansion.py index c008eeb4..a7c2719a 100644 --- a/tests/test_mask_expansion.py +++ b/tests/test_mask_expansion.py @@ -117,10 +117,11 @@ def test_get_mask_center(): def test_get_seeds(): - test_image_array = np.array([(3, 2)]) - test_mask_array = np.array([(9, 5, 9)]) + test_image_array = np.array([[0, 1, 0],[1, 0, 1],[0, 1, 0]]) + test_mask_array = np.ones(test_image_array.shape) + exclusion_mask = np.zeros(test_image_array.shape) expected_result = [] - actual_result = get_seeds(test_image_array, test_mask_array) + actual_result = get_seeds(test_image_array, test_mask_array, exclusion_mask) for index in range(len(expected_result)): assert expected_result[index] == actual_result[index]