Skip to content

Commit

Permalink
adjust tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dominik737 committed Jan 28, 2025
1 parent 4c25b88 commit af3c493
Showing 1 changed file with 39 additions and 10 deletions.
49 changes: 39 additions & 10 deletions tests/unittests/test_nodes/test_detection_config_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,19 @@
from depthai_nodes.nodes.utils.detection_config_generator import generate_script_content


def test_rvc3_unsupported():
@pytest.fixture
def resize_width():
return 256


@pytest.fixture
def resize_height():
return 256


def test_rvc3_unsupported(resize_width, resize_height):
with pytest.raises(ValueError, match="Unsupported"):
generate_script_content("rvc3")
generate_script_content("rvc3", resize_width, resize_height)


class ImageManipConfigV2(dai.ImageManipConfigV2):
Expand Down Expand Up @@ -159,8 +169,15 @@ def node_input_detections(node) -> List[dai.ImgDetections]:


@pytest.mark.parametrize("platform", ["rvc2", "rvc4"])
def test_passthrough(node, node_input_detections, node_input_frames, platform):
script = generate_script_content(platform)
def test_passthrough(
node,
node_input_detections,
node_input_frames,
platform,
resize_width,
resize_height,
):
script = generate_script_content(platform, resize_width, resize_height)
expected_frames = []
for frame, detections in zip(node_input_frames, node_input_detections):
for _ in detections.detections:
Expand All @@ -176,15 +193,23 @@ def test_passthrough(node, node_input_detections, node_input_frames, platform):

@pytest.mark.parametrize(("platform", "labels"), [("rvc2", [1]), ("rvc4", [1, 2])])
def test_label_validation(
node, node_input_detections, node_input_frames, platform, labels
node,
node_input_detections,
node_input_frames,
platform,
labels,
resize_width,
resize_height,
):
expected_frames: List[Frame] = []
for detections, frame in zip(node_input_detections, node_input_frames):
for detection in detections.detections:
if detection.label not in labels:
continue
expected_frames.append(frame)
script = generate_script_content(platform, valid_labels=labels)
script = generate_script_content(
platform, resize_width, resize_height, valid_labels=labels
)
try:
run_script(node, script)
except Warning:
Expand Down Expand Up @@ -216,7 +241,7 @@ def test_rvc4_output_size(node, resize):


@pytest.mark.parametrize("padding", [0, 0.1, 0.2, -0.1, -0.2])
def test_rvc2_crop(node, node_input_detections, padding):
def test_rvc2_crop(node, node_input_detections, padding, resize_width, resize_height):
expected_rects: List[dai.ImageManipConfig.CropRect] = []
for input_dets in node_input_detections:
for detection in input_dets.detections:
Expand All @@ -226,7 +251,9 @@ def test_rvc2_crop(node, node_input_detections, padding):
rect.ymin = max(detection.ymin - padding, 0)
rect.ymax = min(detection.ymax + padding, 1)
expected_rects.append(rect)
script = generate_script_content("rvc2", padding=padding)
script = generate_script_content(
"rvc2", resize_width, resize_height, padding=padding
)
try:
run_script(node, script)
except Warning:
Expand All @@ -243,7 +270,7 @@ def test_rvc2_crop(node, node_input_detections, padding):


@pytest.mark.parametrize("padding", [0, 0.1, 0.2, -0.1, -0.2])
def test_rvc4_crop(node, node_input_detections, padding):
def test_rvc4_crop(node, node_input_detections, padding, resize_width, resize_height):
ANGLE = 0
expected_rects: List[dai.RotatedRect] = []
for input_dets in node_input_detections:
Expand All @@ -256,7 +283,9 @@ def test_rvc4_crop(node, node_input_detections, padding):
rect.size.width = detection.xmax - detection.xmin + rect_padding
rect.size.height = detection.ymax - detection.ymin + rect_padding
expected_rects.append(rect)
script = generate_script_content("rvc4", padding=padding)
script = generate_script_content(
"rvc4", resize_width, resize_height, padding=padding
)
try:
run_script(node, script)
except Warning:
Expand Down

0 comments on commit af3c493

Please sign in to comment.