Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved line detection #103

Merged
merged 8 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 83 additions & 16 deletions DECIMER_Segmentation_notebook.ipynb

Large diffs are not rendered by default.

153 changes: 131 additions & 22 deletions decimer_segmentation/complete_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import itertools
from skimage.color import rgb2gray
from skimage.filters import threshold_otsu
from skimage.morphology import binary_erosion
from skimage.morphology import binary_erosion, binary_dilation
from typing import List, Tuple
from scipy.ndimage import label

Expand Down Expand Up @@ -98,8 +98,7 @@ def detect_horizontal_and_vertical_lines(
) -> np.ndarray:
"""
This function takes an image and returns a binary mask that labels the pixels that
are part of long horizontal or vertical lines. [Definition of long: 1/5 of the
width/height of the image].
are part of long horizontal or vertical lines.

Args:
image (np.ndarray): binarised image (np.array; type bool) as it is returned by
Expand Down Expand Up @@ -130,11 +129,86 @@ def detect_horizontal_and_vertical_lines(
return horizontal_mask + vertical_mask


def find_equidistant_points(
    x1: int,
    y1: int,
    x2: int,
    y2: int,
    num_points: int = 5
) -> List[Tuple[float, float]]:
    """
    Finds equidistant points on the straight line segment between two points.

    The segment is split into num_points equal parts, so the returned list
    contains num_points + 1 points: the start point, the end point and all
    intermediate points. The coordinates are linearly interpolated and are NOT
    rounded, so they are floats; callers that use them as array indices need to
    convert them to int.

    Args:
        x1 (int): x coordinate of first point
        y1 (int): y coordinate of first point
        x2 (int): x coordinate of second point
        y2 (int): y coordinate of second point
        num_points (int, optional): Number of equal parts the segment is split
            into (the result holds num_points + 1 points). Defaults to 5.

    Raises:
        ValueError: If num_points is smaller than 1.

    Returns:
        List[Tuple[float, float]]: Equidistant (x, y) points on the given line,
            ordered from (x1, y1) to (x2, y2)
    """
    # Guard against a division by zero (num_points == 0) and against silently
    # returning an empty list (num_points < 0)
    if num_points < 1:
        raise ValueError("num_points must be at least 1")
    points = []
    for i in range(num_points + 1):
        # t runs from 0.0 (first point) to 1.0 (second point)
        t = i / num_points
        x = x1 * (1 - t) + x2 * t
        y = y1 * (1 - t) + y2 * t
        points.append((x, y))
    return points


def detect_lines(
    image: np.ndarray,
    max_depiction_size: Tuple[int, int],
    segmentation_mask: np.ndarray
) -> np.ndarray:
    """
    This function runs a probabilistic Hough transform on the given image and
    returns a mask of the detected line segments that do not run through a
    chemical structure depiction.

    Args:
        image (np.ndarray): binarised image (np.array; type bool) as it is returned by
            binary_erosion() in complete_structure_mask()
        max_depiction_size (Tuple[int, int]): height, width; a quarter of the
            bigger value is used as the minimal length of a detected line
        segmentation_mask (np.ndarray): Indicates whether or not a pixel is part of a
            chemical structure depiction (shape: (height, width))

    Returns:
        np.ndarray: Exclusion mask (dtype uint8, same shape as the input image)
            where the kept line segments are drawn with the value 255 and a
            thickness of 2; every other pixel is 0
    """
    # Invert the image and scale it to 0/255 (uint8) before the line detection
    inverted_image = (np.invert(image) * 255).astype("uint8")
    min_line_length = int(max(max_depiction_size) / 4)
    detected_lines = cv2.HoughLinesP(
        inverted_image,
        1,
        np.pi / 180,
        threshold=5,
        minLineLength=min_line_length,
        maxLineGap=10,
    )
    exclusion_mask = np.zeros_like(inverted_image)
    # Nothing detected -> nothing to exclude
    if detected_lines is None:
        return exclusion_mask
    for detected_line in detected_lines:
        x1, y1, x2, y2 = detected_line[0]
        sample_points = find_equidistant_points(x1, y1, x2, y2, num_points=7)
        # Only the inner sample points are checked; the two end points of the
        # line segment are ignored
        runs_through_structure = any(
            segmentation_mask[int(y), int(x)] for x, y in sample_points[1:-1]
        )
        if not runs_through_structure:
            cv2.line(exclusion_mask, (x1, y1), (x2, y2), 255, 2)
    return exclusion_mask


def expand_masks(
image_array: np.array,
seed_pixels: List[Tuple[int, int]],
mask_array: np.array,
exclusion_mask: np.array,
) -> np.array:
"""
This function generates a mask array where the given masks have been
Expand All @@ -144,20 +218,16 @@ def expand_masks(
image_array (np.array): array that represents an image (float values)
seed_pixels (List[Tuple[int, int]]): [(x, y), ...]
mask_array (np.array): MRCNN output; shape: (y, x, mask_index)
exclusion_mask (np.array]: indicates whether or not a pixel is excluded from
expansion
contour_expansion (bool, optional): Indicates whether or not to expand
from contours. Defaults to False.

Returns:
np.array: Expanded masks
"""
image_with_exclusion = np.invert(image_array) * np.invert(exclusion_mask)
labeled_array, _ = label(image_with_exclusion)
image_array = np.invert(image_array)
labeled_array, _ = label(image_array)
mask_array = np.zeros_like(image_array)
for seed_pixel in seed_pixels:
x, y = seed_pixel
if mask_array[y, x] or exclusion_mask[y, x]:
if mask_array[y, x]:
continue
label_value = labeled_array[y, x]
if label_value > 0:
Expand All @@ -176,7 +246,7 @@ def expansion_coordination(
seed_pixels = get_seeds(image_array,
mask_array,
exclusion_mask)
mask_array = expand_masks(image_array, seed_pixels, mask_array, exclusion_mask)
mask_array = expand_masks(image_array, seed_pixels, mask_array)
return mask_array


Expand All @@ -200,7 +270,8 @@ def complete_structure_mask(
image_array (np.array): input image
mask_array (np.array): shape: y, x, n where n is the amount of masks
max_depiction_size (Tuple[int, int]): height, width
debug (bool, optional): More verbose if True. Defaults to False.
debug (bool, optional): You get visualisations in a Jupyter Notebook if True.
Defaults to False.

Returns:
np.array: expanded mask array
Expand All @@ -209,13 +280,12 @@ def complete_structure_mask(
if mask_array.size != 0:
# Binarization of input image
binarized_image_array = binarize_image(image_array, threshold=0.72)
if debug:
plot_it(binarized_image_array)
# Apply dilation with a resolution-dependent kernel to the image
blur_factor = (
int(image_array.shape[1] / 185) if image_array.shape[1] / 185 >= 2 else 2
)
if debug:
plot_it(binarized_image_array)
# Define kernel and apply
kernel = np.ones((blur_factor, blur_factor))
blurred_image_array = binary_erosion(binarized_image_array, footprint=kernel)
if debug:
Expand All @@ -224,22 +294,61 @@ def complete_structure_mask(
split_mask_arrays = np.array(
[mask_array[:, :, index] for index in range(mask_array.shape[2])]
)
exclusion_mask = detect_horizontal_and_vertical_lines(
# Detect horizontal and vertical lines
horizontal_vertical_lines = detect_horizontal_and_vertical_lines(
blurred_image_array, max_depiction_size
)
# Run expansion the expansion
image_repeat = itertools.repeat(blurred_image_array, mask_array.shape[2])
exclusion_mask_repeat = itertools.repeat(exclusion_mask, mask_array.shape[2])

hough_lines = detect_lines(
binarized_image_array,
max_depiction_size,
segmentation_mask=np.any(mask_array, axis=2).astype(np.bool)
)
hough_lines = binary_dilation(hough_lines, footprint=kernel)
exclusion_mask = horizontal_vertical_lines + hough_lines
image_with_exclusion = np.invert(
np.invert(blurred_image_array) * np.invert(exclusion_mask)
)
if debug:
plot_it(horizontal_vertical_lines)
plot_it(hough_lines)
plot_it(exclusion_mask)
plot_it(image_with_exclusion)
# Run expansion
image_repeat = itertools.repeat(image_with_exclusion, mask_array.shape[2])
exclusion_repeat = itertools.repeat(exclusion_mask, mask_array.shape[2])
# Faster with map function
expanded_split_mask_arrays = map(
expansion_coordination,
split_mask_arrays,
image_repeat,
exclusion_mask_repeat,
exclusion_repeat,
)
# Stack mask arrays to give the desired output format
# Filter duplicates and stack mask arrays to give the desired output format
expanded_split_mask_arrays = filter_duplicate_masks(expanded_split_mask_arrays)
mask_array = np.stack(expanded_split_mask_arrays, -1)
return mask_array
else:
print("No masks found.")
return mask_array


def filter_duplicate_masks(array_list: List[np.ndarray]) -> List[np.ndarray]:
    """
    This function takes a list (or any other iterable) of arrays and returns a
    list of unique arrays. The order of first occurrence is preserved.

    Two arrays are treated as duplicates if they have the same shape, the same
    dtype and the same content.

    Args:
        array_list (List[np.ndarray]): Masks

    Returns:
        List[np.ndarray]: Unique masks
    """
    seen = set()
    unique_list = []
    for arr in array_list:
        # Use the raw bytes (plus shape and dtype) as a hashable key. This avoids
        # creating one Python object per pixel, which is very slow for
        # image-sized masks, and keeps arrays with different shapes apart.
        key = (arr.shape, arr.dtype.str, arr.tobytes())
        if key not in seen:
            seen.add(key)
            unique_list.append(arr)
    return unique_list
1 change: 1 addition & 0 deletions decimer_segmentation/decimer_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class InferenceConfig(moldetect.MolDetectConfig):
# Run detection on one image at a time
GPU_COUNT = 1
IMAGES_PER_GPU = 1
DETECTION_MIN_CONFIDENCE = 0.7


def segment_chemical_structures_from_file(
Expand Down
4 changes: 1 addition & 3 deletions tests/test_mask_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,8 @@ def test_expand_masks():
test_image_array,
test_seed_pixels,
test_mask_array,
np.zeros(test_image_array.shape, dtype=bool),
)
expected_result.all() == actual_result.all()
# assert expected_result.all() == actual_result.all()
assert expected_result.all() == actual_result.all()


def test_expansion_coordination():
Expand Down
Loading