Fix task chain for Det -> Cls / Seg (#4105)
* fix linter

* return recipe back

* added roi extraction for multi-class classification dataset

* fix linter

* add same logic to semantic seg

* added test for OTXDataset

* add clip and raise an error when coordinates are invalid.

* rewrite value error
kprokofi authored Nov 8, 2024
1 parent 88ab4b8 commit 844fc2e
Showing 11 changed files with 182 additions and 35 deletions.
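Most of the diff below threads an optional `roi` item attribute through `_get_img_data_and_shape`. Judging only from the fields the changed code reads (`roi["shape"]["x1"]`..., `roi["labels"][i]["label"]["_id"]` and `["domain"]`), an ROI attribute looks roughly like the sketch below; the concrete values and the label id are invented for illustration.

```python
# Illustrative only: the field names come from the diff below, the values are made up.
roi = {
    "shape": {"x1": 0.25, "y1": 0.10, "x2": 0.80, "y2": 0.95},  # normalized to the image size
    "labels": [
        {"label": {"_id": "cat", "domain": "CLASSIFICATION"}},
    ],
}
```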
2 changes: 1 addition & 1 deletion src/otx/algo/utils/xai_utils.py
@@ -225,7 +225,7 @@ def _get_image_data_name(
subset = datamodule.subsets[subset_name]
item = subset.dm_subset[img_id]
img = item.media_as(Image)
- img_data, _ = subset._get_img_data_and_shape(img) # noqa: SLF001
+ img_data, _, _ = subset._get_img_data_and_shape(img) # noqa: SLF001
image_save_name = "".join([char if char.isalnum() else "_" for char in item.id])
return img_data, image_save_name

2 changes: 1 addition & 1 deletion src/otx/core/data/dataset/anomaly.py
@@ -79,7 +79,7 @@ def _get_item_impl(
datumaro_item = self.dm_subset[index]
img = datumaro_item.media_as(Image)
# returns image in RGB format if self.image_color_channel is RGB
- img_data, img_shape = self._get_img_data_and_shape(img)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img)

label = self._get_label(datumaro_item)

58 changes: 48 additions & 10 deletions src/otx/core/data/dataset/base.py
@@ -8,7 +8,7 @@
from abc import abstractmethod
from collections.abc import Iterable
from contextlib import contextmanager
- from typing import TYPE_CHECKING, Callable, Generic, Iterator, List, Union
+ from typing import TYPE_CHECKING, Any, Callable, Generic, Iterator, List, Union

import cv2
import numpy as np
@@ -92,6 +92,7 @@ def __init__(
self.image_color_channel = image_color_channel
self.stack_images = stack_images
self.to_tv_image = to_tv_image

if self.dm_subset.categories():
self.label_info = LabelInfo.from_dm_label_groups(self.dm_subset.categories()[AnnotationType.label])
else:
@@ -141,11 +142,31 @@ def __getitem__(self, index: int) -> T_OTXDataEntity:
msg = f"Reach the maximum refetch number ({self.max_refetch})"
raise RuntimeError(msg)

- def _get_img_data_and_shape(self, img: Image) -> tuple[np.ndarray, tuple[int, int]]:
+ def _get_img_data_and_shape(
+ self,
+ img: Image,
+ roi: dict[str, Any] | None = None,
+ ) -> tuple[np.ndarray, tuple[int, int], dict[str, Any] | None]:
+ """Get image data and shape.
+ This method is used to get image data and shape from Datumaro image object.
+ If ROI is provided, the image data is extracted from the ROI.
+ Args:
+ img (Image): Image object from Datumaro.
+ roi (dict[str, Any] | None, Optional): Region of interest.
+ Represented by dict with coordinates and some meta information.
+ Returns:
+ The image data, shape, and ROI meta information
+ """
key = img.path if isinstance(img, ImageFromFile) else id(img)
+ roi_meta = None

- if (img_data := self.mem_cache_handler.get(key=key)[0]) is not None:
- return img_data, img_data.shape[:2]
+ # check if the image is already in the cache
+ img_data, roi_meta = self.mem_cache_handler.get(key=key)
+ if img_data is not None:
+ return img_data, img_data.shape[:2], roi_meta

with image_decode_context():
img_data = (
@@ -158,11 +179,28 @@
msg = "Cannot get image data"
raise RuntimeError(msg)

- img_data = self._cache_img(key=key, img_data=img_data.astype(np.uint8))
+ if roi:
+ # extract ROI from image
+ shape = roi["shape"]
+ h, w = img_data.shape[:2]
+ x1, y1, x2, y2 = (
+ int(np.clip(np.trunc(shape["x1"] * w), 0, w)),
+ int(np.clip(np.trunc(shape["y1"] * h), 0, h)),
+ int(np.clip(np.ceil(shape["x2"] * w), 0, w)),
+ int(np.clip(np.ceil(shape["y2"] * h), 0, h)),
+ )
+ if (x2 - x1) * (y2 - y1) <= 0:
+ msg = f"ROI has zero or negative area. ROI coordinates: {x1}, {y1}, {x2}, {y2}"
+ raise ValueError(msg)

+ img_data = img_data[y1:y2, x1:x2]
+ roi_meta = {"x1": x1, "y1": y1, "x2": x2, "y2": y2, "orig_image_shape": (h, w)}

+ img_data = self._cache_img(key=key, img_data=img_data.astype(np.uint8), meta=roi_meta)

- return img_data, img_data.shape[:2]
+ return img_data, img_data.shape[:2], roi_meta

- def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray:
+ def _cache_img(self, key: str | int, img_data: np.ndarray, meta: dict[str, Any] | None = None) -> np.ndarray:
"""Cache an image after resizing.
If there is available space in the memory pool, the input image is cached.
@@ -182,14 +220,14 @@ def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray:
return img_data

if self.mem_cache_img_max_size is None:
- self.mem_cache_handler.put(key=key, data=img_data, meta=None)
+ self.mem_cache_handler.put(key=key, data=img_data, meta=meta)
return img_data

height, width = img_data.shape[:2]
max_height, max_width = self.mem_cache_img_max_size

if height <= max_height and width <= max_width:
- self.mem_cache_handler.put(key=key, data=img_data, meta=None)
+ self.mem_cache_handler.put(key=key, data=img_data, meta=meta)
return img_data

# Preserve the image size ratio and fit to max_height or max_width
@@ -206,7 +244,7 @@ def _cache_img(self, key: str | int, img_data: np.ndarray) -> np.ndarray:
self.mem_cache_handler.put(
key=key,
data=resized_img,
- meta=None,
+ meta=meta,
)
return resized_img

28 changes: 14 additions & 14 deletions src/otx/core/data/dataset/classification.py
@@ -32,18 +32,18 @@ class OTXMulticlassClsDataset(OTXDataset[MulticlassClsDataEntity]):
def _get_item_impl(self, index: int) -> MulticlassClsDataEntity | None:
item = self.dm_subset[index]
img = item.media_as(Image)
- img_data, img_shape = self._get_img_data_and_shape(img)
+ roi = item.attributes.get("roi", None)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img, roi)
+ if roi:
+ # extract labels from ROI
+ labels_ids = [
+ label["label"]["_id"] for label in roi["labels"] if label["label"]["domain"] == "CLASSIFICATION"
+ ]
+ label_anns = [self.label_info.label_names.index(label_id) for label_id in labels_ids]
+ else:
+ # extract labels from annotations
+ label_anns = [ann.label for ann in item.annotations if isinstance(ann, Label)]

- label_anns = []
- for ann in item.annotations:
- if isinstance(ann, Label):
- label_anns.append(ann)
- else:
- # If the annotation is not Label, it should be converted to Label.
- # For Chained Task: Detection (Bbox) -> Classification (Label)
- label = Label(label=ann.label)
- if label not in label_anns:
- label_anns.append(label)
if len(label_anns) > 1:
msg = f"Multi-class Classification can't use the multi-label, currently len(labels) = {len(label_anns)}"
raise ValueError(msg)
@@ -56,7 +56,7 @@ def _get_item_impl(self, index: int) -> MulticlassClsDataEntity | None:
ori_shape=img_shape,
image_color_channel=self.image_color_channel,
),
- labels=torch.as_tensor([ann.label for ann in label_anns]),
+ labels=torch.as_tensor(label_anns),
)

return self._apply_transforms(entity)
Expand All @@ -78,7 +78,7 @@ def _get_item_impl(self, index: int) -> MultilabelClsDataEntity | None:
item = self.dm_subset[index]
img = item.media_as(Image)
ignored_labels: list[int] = [] # This should be assigned form item
- img_data, img_shape = self._get_img_data_and_shape(img)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img)

label_anns = []
for ann in item.annotations:
@@ -195,7 +195,7 @@ def _get_item_impl(self, index: int) -> HlabelClsDataEntity | None:
item = self.dm_subset[index]
img = item.media_as(Image)
ignored_labels: list[int] = [] # This should be assigned form item
- img_data, img_shape = self._get_img_data_and_shape(img)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img)

label_anns = []
for ann in item.annotations:
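As a standalone illustration of the new multi-class branch above (a sketch only, not part of the commit; `roi` follows the illustrative dict near the top and `label_names` is a made-up label list standing in for `self.label_info.label_names`):

```python
# Mirrors the ROI branch added above, outside the dataset class; names are placeholders.
label_names = ["cat", "dog"]
roi = {"labels": [{"label": {"_id": "cat", "domain": "CLASSIFICATION"}}]}

labels_ids = [
    lbl["label"]["_id"]
    for lbl in roi["labels"]
    if lbl["label"]["domain"] == "CLASSIFICATION"
]
label_anns = [label_names.index(label_id) for label_id in labels_ids]  # -> [0]
# label_anns now holds integer class indices, ready for torch.as_tensor(label_anns).
```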
2 changes: 1 addition & 1 deletion src/otx/core/data/dataset/detection.py
@@ -26,7 +26,7 @@ def _get_item_impl(self, index: int) -> DetDataEntity | None:
item = self.dm_subset[index]
img = item.media_as(Image)
ignored_labels: list[int] = [] # This should be assigned form item
- img_data, img_shape = self._get_img_data_and_shape(img)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img)

bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)]

2 changes: 1 addition & 1 deletion src/otx/core/data/dataset/instance_segmentation.py
@@ -40,7 +40,7 @@ def _get_item_impl(self, index: int) -> InstanceSegDataEntity | None:
item = self.dm_subset[index]
img = item.media_as(Image)
ignored_labels: list[int] = []
- img_data, img_shape = self._get_img_data_and_shape(img)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img)

gt_bboxes, gt_labels, gt_masks, gt_polygons = [], [], [], []

2 changes: 1 addition & 1 deletion src/otx/core/data/dataset/keypoint_detection.py
@@ -86,7 +86,7 @@ def _get_item_impl(self, index: int) -> KeypointDetDataEntity | None:
item = self.dm_subset[index]
img = item.media_as(Image)
ignored_labels: list[int] = [] # This should be assigned form item
- img_data, img_shape = self._get_img_data_and_shape(img)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img)

bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)]
bboxes = (
9 changes: 7 additions & 2 deletions src/otx/core/data/dataset/segmentation.py
@@ -202,9 +202,14 @@ def _get_item_impl(self, index: int) -> SegDataEntity | None:
item = self.dm_subset[index]
img = item.media_as(Image)
ignored_labels: list[int] = []
- img_data, img_shape = self._get_img_data_and_shape(img)
+ roi = item.attributes.get("roi", None)
+ img_data, img_shape, roi_meta = self._get_img_data_and_shape(img, roi)
if item.annotations:
- extracted_mask = _extract_class_mask(item=item, img_shape=img_shape, ignore_index=self.ignore_index)
+ ori_shape = roi_meta["orig_image_shape"] if roi_meta else img_shape
+ extracted_mask = _extract_class_mask(item=item, img_shape=ori_shape, ignore_index=self.ignore_index)
+ if roi_meta:
+ extracted_mask = extracted_mask[roi_meta["y1"] : roi_meta["y2"], roi_meta["x1"] : roi_meta["x2"]]

masks = tv_tensors.Mask(extracted_mask[None])
else:
# semi-supervised learning, unlabeled dataset
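The semantic segmentation change above follows the same idea: build the class mask on the original image shape, then crop it with the ROI metadata. A minimal sketch of that step (standalone, not part of the commit; `roi_meta` is the dict returned by `_get_img_data_and_shape`):

```python
import numpy as np


def crop_mask_to_roi(extracted_mask: np.ndarray, roi_meta: dict | None) -> np.ndarray:
    """Crop a full-image HxW class mask to the same window that was used for the image."""
    if roi_meta is None:
        return extracted_mask
    return extracted_mask[roi_meta["y1"] : roi_meta["y2"], roi_meta["x1"] : roi_meta["x2"]]
```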
4 changes: 2 additions & 2 deletions src/otx/core/data/dataset/tile.py
@@ -370,7 +370,7 @@ def _get_item_impl(self, index: int) -> TileDetDataEntity: # type: ignore[overr
"""
item = self.dm_subset[index]
img = item.media_as(Image)
- img_data, img_shape = self._get_img_data_and_shape(img)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img)

bbox_anns = [ann for ann in item.annotations if isinstance(ann, Bbox)]

@@ -461,7 +461,7 @@ def _get_item_impl(self, index: int) -> TileInstSegDataEntity: # type: ignore[o
"""
item = self.dm_subset[index]
img = item.media_as(Image)
- img_data, img_shape = self._get_img_data_and_shape(img)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img)

gt_bboxes, gt_labels, gt_masks, gt_polygons = [], [], [], []

4 changes: 2 additions & 2 deletions src/otx/core/data/dataset/visual_prompting.py
@@ -79,7 +79,7 @@ def __init__(
def _get_item_impl(self, index: int) -> VisualPromptingDataEntity | None:
item = self.dm_subset[index]
img = item.media_as(dmImage)
- img_data, img_shape = self._get_img_data_and_shape(img)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img)

gt_bboxes, gt_points = [], []
gt_masks = defaultdict(list)
@@ -214,7 +214,7 @@ def __init__(
def _get_item_impl(self, index: int) -> ZeroShotVisualPromptingDataEntity | None:
item = self.dm_subset[index]
img = item.media_as(dmImage)
- img_data, img_shape = self._get_img_data_and_shape(img)
+ img_data, img_shape, _ = self._get_img_data_and_shape(img)

gt_prompts: list[tvBoundingBoxes | Points] = []
gt_masks: list[tvMask] = []
104 changes: 104 additions & 0 deletions tests/unit/core/data/dataset/test_base.py
@@ -0,0 +1,104 @@
from unittest import mock

import numpy as np
import pytest
from datumaro.components.media import Image
from otx.core.data.dataset.base import OTXDataset


class TestOTXDataset:
@pytest.fixture()
def mock_image(self) -> Image:
img = mock.Mock(spec=Image)
img.data = np.random.randint(0, 256, (10, 10, 3), dtype=np.uint8)
img.path = "test_path"
return img

@pytest.fixture()
def mock_mem_cache_handler(self):
mem_cache_handler = mock.MagicMock()
mem_cache_handler.frozen = False
return mem_cache_handler

@pytest.fixture()
def otx_dataset(self, mock_mem_cache_handler):
class MockOTXDataset(OTXDataset):
def _get_item_impl(self, idx: int) -> None:
return None

@property
def collate_fn(self) -> None:
return None

dm_subset = mock.Mock()
dm_subset.categories = mock.MagicMock()
dm_subset.categories.return_value = None

return MockOTXDataset(
dm_subset=dm_subset,
transforms=None,
mem_cache_handler=mock_mem_cache_handler,
mem_cache_img_max_size=None,
)

def test_get_img_data_and_shape_no_cache(self, otx_dataset, mock_image, mock_mem_cache_handler):
mock_mem_cache_handler.get.return_value = (None, None)
img_data, img_shape, roi_meta = otx_dataset._get_img_data_and_shape(mock_image)
assert img_data.shape == (10, 10, 3)
assert img_shape == (10, 10)
assert roi_meta is None

def test_get_img_data_and_shape_with_cache(self, otx_dataset, mock_image, mock_mem_cache_handler):
mock_mem_cache_handler.get.return_value = (np.random.randint(0, 256, (10, 10, 3), dtype=np.uint8), None)
img_data, img_shape, roi_meta = otx_dataset._get_img_data_and_shape(mock_image)
assert img_data.shape == (10, 10, 3)
assert img_shape == (10, 10)
assert roi_meta is None

def test_get_img_data_and_shape_with_roi(self, otx_dataset, mock_image, mock_mem_cache_handler):
roi = {"shape": {"x1": 0.1, "y1": 0.1, "x2": 0.9, "y2": 0.9}}
mock_mem_cache_handler.get.return_value = (None, None)
img_data, img_shape, roi_meta = otx_dataset._get_img_data_and_shape(mock_image, roi)
assert img_data.shape == (8, 8, 3)
assert img_shape == (8, 8)
assert roi_meta == {"x1": 1, "y1": 1, "x2": 9, "y2": 9, "orig_image_shape": (10, 10)}

def test_cache_img_no_resize(self, otx_dataset):
img_data = np.random.randint(0, 256, (50, 50, 3), dtype=np.uint8)
key = "test_key"

cached_img = otx_dataset._cache_img(key, img_data)

assert np.array_equal(cached_img, img_data)
otx_dataset.mem_cache_handler.put.assert_called_once_with(key=key, data=img_data, meta=None)

def test_cache_img_with_resize(self, otx_dataset, mock_mem_cache_handler):
otx_dataset.mem_cache_img_max_size = (100, 100)
img_data = np.random.randint(0, 256, (200, 200, 3), dtype=np.uint8)
key = "test_key"

cached_img = otx_dataset._cache_img(key, img_data)

assert cached_img.shape == (100, 100, 3)
mock_mem_cache_handler.put.assert_called_once()
assert mock_mem_cache_handler.put.call_args[1]["data"].shape == (100, 100, 3)

def test_cache_img_no_max_size(self, otx_dataset, mock_mem_cache_handler):
otx_dataset.mem_cache_img_max_size = None
img_data = np.random.randint(0, 256, (200, 200, 3), dtype=np.uint8)
key = "test_key"

cached_img = otx_dataset._cache_img(key, img_data)

assert np.array_equal(cached_img, img_data)
mock_mem_cache_handler.put.assert_called_once_with(key=key, data=img_data, meta=None)

def test_cache_img_frozen_handler(self, otx_dataset, mock_mem_cache_handler):
mock_mem_cache_handler.frozen = True
img_data = np.random.randint(0, 256, (200, 200, 3), dtype=np.uint8)
key = "test_key"

cached_img = otx_dataset._cache_img(key, img_data)

assert np.array_equal(cached_img, img_data)
mock_mem_cache_handler.put.assert_not_called()
