Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filter OOB points while training #2061

Open
wants to merge 12 commits into
base: develop
Choose a base branch
from
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def linkcode_resolve(domain, info):
# These paths are either relative to html_static_path
# or fully qualified paths (eg. https://...)
html_css_files = [
'css/tabs.css',
"css/tabs.css",
]

# Custom sidebar templates, must be a dictionary that maps document names
Expand Down
22 changes: 16 additions & 6 deletions sleap/nn/data/instance_cropping.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Optional, List, Text
import sleap
from sleap.nn.config import InstanceCroppingConfig
from sleap.nn.data.utils import filter_oob_points


def find_instance_crop_size(
Expand Down Expand Up @@ -42,12 +43,21 @@ def find_instance_crop_size(
# Calculate crop size
min_crop_size_no_pad = min_crop_size - padding
max_length = 0.0
for inst in labels.user_instances:
pts = inst.points_array
pts *= input_scaling
max_length = np.maximum(max_length, np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0]))
max_length = np.maximum(max_length, np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1]))
max_length = np.maximum(max_length, min_crop_size_no_pad)
for lf in labels:
for inst in lf:
if isinstance(inst, sleap.PredictedInstance):
continue

pts = filter_oob_points(inst.numpy(), lf.image.shape[:2])

pts *= input_scaling
max_length: float = np.nanmax(
[max_length, np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0])]
)
max_length: float = np.nanmax(
[max_length, np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1])]
)
max_length: float = np.nanmax([max_length, min_crop_size_no_pad])

max_length += float(padding)
crop_size = np.math.ceil(max_length / float(maximum_stride)) * maximum_stride
Expand Down
63 changes: 43 additions & 20 deletions sleap/nn/data/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import attr
from typing import Text, Optional, List, Sequence, Union, Tuple
import sleap
from sleap.instance import Instance
from sleap.nn.data.utils import filter_oob_points


@attr.s(auto_attribs=True)
Expand Down Expand Up @@ -202,30 +204,51 @@ def py_fetch_lf(ind):
insts = lf.user_instances
else:
insts = lf.instances
insts = [inst for inst in insts if len(inst) > 0]
if self.with_track_only:
insts = [inst for inst in insts if inst.track is not None]
n_instances = len(insts)
n_nodes = len(insts[0].skeleton) if n_instances > 0 else 0

instances = np.full((n_instances, n_nodes, 2), np.nan, dtype="float32")
for i, instance in enumerate(insts):
instances[i] = instance.numpy()

skeleton_inds = np.array(
[self.labels.skeletons.index(inst.skeleton) for inst in insts]
).astype("int32")
track_inds = np.array(
[
self.tracks.index(inst.track) if inst.track is not None else -1
for inst in insts
]
).astype("int32")

instances = []

for inst in insts:

# Filter OOB
pts = filter_oob_points(inst.numpy(), raw_image_size[:2])

instance = Instance.from_numpy(pts, inst.skeleton, inst.track)
gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved

if len(instance) > 0:

if self.with_track_only:
if instance.track is not None:
instances.append(instance)

else:
instances.append(instance)

n_instances = len(instances)
n_nodes = len(instances[0].skeleton) if n_instances > 0 else 0

insts = np.full((n_instances, n_nodes, 2), np.nan, dtype="float32")
track_inds = []
skeleton_inds = []
for i, instance in enumerate(instances):

track_inds.append(
self.tracks.index(instance.track)
if instance.track is not None
else -1
)

skeleton_inds.append(self.labels.skeletons.index(instance.skeleton))

insts[i] = instance.numpy()

track_inds = np.array(track_inds).astype("int32")
skeleton_inds = np.array(skeleton_inds).astype("int32")
gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved

n_tracks = np.array(len(self.tracks)).astype("int32")
return (
raw_image,
raw_image_size,
instances,
insts,
video_ind,
frame_ind,
skeleton_inds,
Expand Down
10 changes: 10 additions & 0 deletions sleap/nn/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@
from typing import Any, List, Tuple, Dict, Text, Optional


def filter_oob_points(pts: np.ndarray, img_hw: tuple) -> np.ndarray:
"""Convert negative/ out-of-boundary pts to NaNs."""
pts[pts < 0] = np.NaN
height, width = img_hw
pts[:, 0][pts[:, 0] > width - 1] = np.NaN
pts[:, 1][pts[:, 1] > height - 1] = np.NaN

return pts


def ensure_list(x: Any) -> List[Any]:
"""Convert the input into a list if it is not already."""
if not isinstance(x, list):
Expand Down
18 changes: 18 additions & 0 deletions tests/nn/data/test_instance_cropping.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,24 @@
from sleap.nn.config import InstanceCroppingConfig


def test_find_instance_crop_size(min_labels):
labels = min_labels.copy()
assert len(labels.labeled_frames[0].instances) == 2

crop_size = instance_cropping.find_instance_crop_size(labels)
assert crop_size == 74

assert labels[0].instances[0].numpy().shape[0] == 2 # 2 nodes

labels[0].instances[1][0] = (390, 187.9) # exceeds img height
crop_size = instance_cropping.find_instance_crop_size(labels)
assert crop_size == 60

labels[0].instances[1][0] = (-100, 187.9) # exceeds img height
crop_size = instance_cropping.find_instance_crop_size(labels)
assert crop_size == 60
gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved


def test_normalize_bboxes():
bbox = tf.convert_to_tensor([[0, 0, 3, 3]], tf.float32)
norm_bbox = instance_cropping.normalize_bboxes(bbox, 9, 9)
Expand Down
21 changes: 21 additions & 0 deletions tests/nn/data/test_providers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pytest
import tensorflow as tf
from sleap.nn.system import use_cpu_only

Expand Down Expand Up @@ -68,6 +69,26 @@ def test_labels_reader_no_visible_points(min_labels):
assert len(labels_reader) == 0


@pytest.mark.parametrize(
"oob_point,test_case",
[((390, 187.9), "exceeding_image_dim"), ((-100, 187.9), "negative_coordinates")],
)
def test_labels_filter_oob_points(min_labels, oob_point, test_case):
# There should be two instances in the labels dataset
labels = min_labels.copy()
assert len(labels.labeled_frames[0].instances) == 2

assert labels[0].instances[0].numpy().shape[0] == 2 # 2 nodes

labels[0].instances[0][0] = oob_point

labels_reader = providers.LabelsReader.from_user_instances(labels)
examples = list(iter(labels_reader.make_dataset()))
assert len(examples) == 1

assert all(np.isnan(examples[0]["instances"][0][0]))


def test_labels_reader_subset(min_labels):
labels = sleap.Labels([min_labels[0], min_labels[0], min_labels[0]])
assert len(labels) == 3
Expand Down
8 changes: 6 additions & 2 deletions tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def test_topdown_model(test_pipeline):
assert tuple(out["instance_peak_vals"].shape) == (8, 2, 2)
assert tuple(out["n_valid"].shape) == (8,)

assert (out["n_valid"] == [1, 1, 1, 2, 2, 2, 2, 2]).all()
assert (out["n_valid"] == [1, 1, 1, 2, 2, 2, 1, 1]).all()


def test_inference_layer():
Expand Down Expand Up @@ -2039,7 +2039,11 @@ def test_movenet_predictor(min_dance_labels, movenet_video):
[labels_pr[0][0].numpy(), labels_pr[1][0].numpy()], axis=0
)

np.testing.assert_allclose(points_gt, points_pr, atol=0.75)
assert_allclose(
points_gt[~np.isnan(points_gt).any(axis=1)],
points_pr[~np.isnan(points_gt).any(axis=1)],
atol=0.75,
)


@pytest.mark.parametrize(
Expand Down
24 changes: 24 additions & 0 deletions tests/nn/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,30 @@ def test_train_topdown(training_labels, cfg):
assert tuple(trainer.keras_model.outputs[0].shape) == (None, 96, 96, 2)


@pytest.mark.parametrize(
"oob_point,test_case",
[((390, 187.9), "exceeding_image_dim"), ((-100, 187.9), "negative_coordinates")],
)
def test_train_topdown_with_oob_pts(min_labels, cfg, oob_point, test_case):
# pt exceeding img dim
labels = min_labels
labels.append(
sleap.LabeledFrame(
video=labels.videos[0], frame_idx=1, instances=labels[0].instances
)
)
labels[0].instances[1][0] = oob_point # crop size=60

cfg.model.heads.centered_instance = CenteredInstanceConfmapsHeadConfig(
sigma=1.5, output_stride=1, offset_refinement=False
)
trainer = TopdownConfmapsModelTrainer.from_config(cfg, training_labels=labels)
trainer.setup()
trainer.train()
assert trainer.keras_model.output_names[0] == "CenteredInstanceConfmapsHead"
assert tuple(trainer.keras_model.outputs[0].shape) == (None, 80, 80, 2)


def test_train_topdown_with_offset(training_labels, cfg):
cfg.model.heads.centered_instance = CenteredInstanceConfmapsHeadConfig(
sigma=1.5, output_stride=1, offset_refinement=True
Expand Down
Loading