Skip to content

Commit

Permalink
Merge pull request #103 from Kohulan/improved_line_detection
Browse files Browse the repository at this point in the history
Improved line detection
  • Loading branch information
OBrink authored Dec 13, 2023
2 parents 2e0b78f + 411d663 commit c990944
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 41 deletions.
99 changes: 83 additions & 16 deletions DECIMER_Segmentation_notebook.ipynb

Large diffs are not rendered by default.

153 changes: 131 additions & 22 deletions decimer_segmentation/complete_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import itertools
from skimage.color import rgb2gray
from skimage.filters import threshold_otsu
from skimage.morphology import binary_erosion
from skimage.morphology import binary_erosion, binary_dilation
from typing import List, Tuple
from scipy.ndimage import label

Expand Down Expand Up @@ -98,8 +98,7 @@ def detect_horizontal_and_vertical_lines(
) -> np.ndarray:
"""
This function takes an image and returns a binary mask that labels the pixels that
are part of long horizontal or vertical lines. [Definition of long: 1/5 of the
width/height of the image].
are part of long horizontal or vertical lines.
Args:
image (np.ndarray): binarised image (np.array; type bool) as it is returned by
Expand Down Expand Up @@ -130,11 +129,86 @@ def detect_horizontal_and_vertical_lines(
return horizontal_mask + vertical_mask


def find_equidistant_points(
x1: int,
y1: int,
x2: int,
y2: int,
num_points: int = 5
) -> List[Tuple[int, int]]:
"""
Finds equidistant points between two points.
Args:
x1 (int): x coordinate of first point
y1 (int): y coordinate of first point
x2 (int): x coordinate of second point
y2 (int): y coordinate of second point
num_points (int, optional): Number of points to return. Defaults to 5.
Returns:
List[Tuple[int, int]]: Equidistant points on the given line
"""
points = []
for i in range(num_points + 1):
t = i / num_points
x = x1 * (1 - t) + x2 * t
y = y1 * (1 - t) + y2 * t
points.append((x, y))
return points


def detect_lines(
image: np.ndarray,
max_depiction_size: Tuple[int, int],
segmentation_mask: np.ndarray
) -> np.ndarray:
"""
This function takes an image and returns a binary mask that labels the pixels that
are part of lines that are not part of chemical structures (like arrays, tables).
Args:
image (np.ndarray): binarised image (np.array; type bool) as it is returned by
binary_erosion() in complete_structure_mask()
max_depiction_size (Tuple[int, int]): height, width; used for thresholds
segmentation_mask (np.ndarray): Indicates whether or not a pixel is part of a
chemical structure depiction (shape: (height, width))
Returns:
np.ndarray: Exclusion mask that contains indices of pixels that are part of
horizontal or vertical lines
"""
image = ~image * 255
image = image.astype("uint8")
# Detect lines using the Hough Transform
lines = cv2.HoughLinesP(image,
1,
np.pi / 180,
threshold=5,
minLineLength=int(max(max_depiction_size)/4),
maxLineGap=10)
# Generate exclusion mask based on detected lines
exclusion_mask = np.zeros_like(image)
if lines is None:
return exclusion_mask
for line in lines:
x1, y1, x2, y2 = line[0]
# Check if any of the lines is in a chemical structure depiction
points = find_equidistant_points(x1, y1, x2, y2, num_points=7)
points_in_structure = False
for x, y in points[1:-1]:
if segmentation_mask[int(y), int(x)]:
points_in_structure = True
break
if points_in_structure:
continue
cv2.line(exclusion_mask, (x1, y1), (x2, y2), 255, 2)
return exclusion_mask


def expand_masks(
image_array: np.array,
seed_pixels: List[Tuple[int, int]],
mask_array: np.array,
exclusion_mask: np.array,
) -> np.array:
"""
This function generates a mask array where the given masks have been
Expand All @@ -144,20 +218,16 @@ def expand_masks(
image_array (np.array): array that represents an image (float values)
seed_pixels (List[Tuple[int, int]]): [(x, y), ...]
mask_array (np.array): MRCNN output; shape: (y, x, mask_index)
exclusion_mask (np.array]: indicates whether or not a pixel is excluded from
expansion
contour_expansion (bool, optional): Indicates whether or not to expand
from contours. Defaults to False.
Returns:
np.array: Expanded masks
"""
image_with_exclusion = np.invert(image_array) * np.invert(exclusion_mask)
labeled_array, _ = label(image_with_exclusion)
image_array = np.invert(image_array)
labeled_array, _ = label(image_array)
mask_array = np.zeros_like(image_array)
for seed_pixel in seed_pixels:
x, y = seed_pixel
if mask_array[y, x] or exclusion_mask[y, x]:
if mask_array[y, x]:
continue
label_value = labeled_array[y, x]
if label_value > 0:
Expand All @@ -176,7 +246,7 @@ def expansion_coordination(
seed_pixels = get_seeds(image_array,
mask_array,
exclusion_mask)
mask_array = expand_masks(image_array, seed_pixels, mask_array, exclusion_mask)
mask_array = expand_masks(image_array, seed_pixels, mask_array)
return mask_array


Expand All @@ -200,7 +270,8 @@ def complete_structure_mask(
image_array (np.array): input image
mask_array (np.array): shape: y, x, n where n is the amount of masks
max_depiction_size (Tuple[int, int]): height, width
debug (bool, optional): More verbose if True. Defaults to False.
debug (bool, optional): You get visualisations in a Jupyter Notebook if True.
Defaults to False.
Returns:
np.array: expanded mask array
Expand All @@ -209,13 +280,12 @@ def complete_structure_mask(
if mask_array.size != 0:
# Binarization of input image
binarized_image_array = binarize_image(image_array, threshold=0.72)
if debug:
plot_it(binarized_image_array)
# Apply dilation with a resolution-dependent kernel to the image
blur_factor = (
int(image_array.shape[1] / 185) if image_array.shape[1] / 185 >= 2 else 2
)
if debug:
plot_it(binarized_image_array)
# Define kernel and apply
kernel = np.ones((blur_factor, blur_factor))
blurred_image_array = binary_erosion(binarized_image_array, footprint=kernel)
if debug:
Expand All @@ -224,22 +294,61 @@ def complete_structure_mask(
split_mask_arrays = np.array(
[mask_array[:, :, index] for index in range(mask_array.shape[2])]
)
exclusion_mask = detect_horizontal_and_vertical_lines(
# Detect horizontal and vertical lines
horizontal_vertical_lines = detect_horizontal_and_vertical_lines(
blurred_image_array, max_depiction_size
)
# Run expansion the expansion
image_repeat = itertools.repeat(blurred_image_array, mask_array.shape[2])
exclusion_mask_repeat = itertools.repeat(exclusion_mask, mask_array.shape[2])

hough_lines = detect_lines(
binarized_image_array,
max_depiction_size,
segmentation_mask=np.any(mask_array, axis=2).astype(np.bool)
)
hough_lines = binary_dilation(hough_lines, footprint=kernel)
exclusion_mask = horizontal_vertical_lines + hough_lines
image_with_exclusion = np.invert(
np.invert(blurred_image_array) * np.invert(exclusion_mask)
)
if debug:
plot_it(horizontal_vertical_lines)
plot_it(hough_lines)
plot_it(exclusion_mask)
plot_it(image_with_exclusion)
# Run expansion
image_repeat = itertools.repeat(image_with_exclusion, mask_array.shape[2])
exclusion_repeat = itertools.repeat(exclusion_mask, mask_array.shape[2])
# Faster with map function
expanded_split_mask_arrays = map(
expansion_coordination,
split_mask_arrays,
image_repeat,
exclusion_mask_repeat,
exclusion_repeat,
)
# Stack mask arrays to give the desired output format
# Filter duplicates and stack mask arrays to give the desired output format
expanded_split_mask_arrays = filter_duplicate_masks(expanded_split_mask_arrays)
mask_array = np.stack(expanded_split_mask_arrays, -1)
return mask_array
else:
print("No masks found.")
return mask_array


def filter_duplicate_masks(array_list: List[np.array]) -> List[np.array]:
"""
This function takes a list of arrays and returns a list of unique arrays.
Args:
array_list (List[np.array]): Masks
Returns:
List[np.array]: Unique masks
"""
seen = set()
unique_list = []
for arr in array_list:
# Convert the array to a hashable tuple
arr_tuple = tuple(arr.ravel())
if arr_tuple not in seen:
seen.add(arr_tuple)
unique_list.append(arr)
return unique_list
1 change: 1 addition & 0 deletions decimer_segmentation/decimer_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class InferenceConfig(moldetect.MolDetectConfig):
# Run detection on one image at a time
GPU_COUNT = 1
IMAGES_PER_GPU = 1
DETECTION_MIN_CONFIDENCE = 0.7


def segment_chemical_structures_from_file(
Expand Down
4 changes: 1 addition & 3 deletions tests/test_mask_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,8 @@ def test_expand_masks():
test_image_array,
test_seed_pixels,
test_mask_array,
np.zeros(test_image_array.shape, dtype=bool),
)
expected_result.all() == actual_result.all()
# assert expected_result.all() == actual_result.all()
assert expected_result.all() == actual_result.all()


def test_expansion_coordination():
Expand Down

0 comments on commit c990944

Please sign in to comment.