Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filter OOB points while training #2061

Open
wants to merge 12 commits into
base: develop
Choose a base branch
from
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def linkcode_resolve(domain, info):
# These paths are either relative to html_static_path
# or fully qualified paths (eg. https://...)
html_css_files = [
'css/tabs.css',
"css/tabs.css",
]

# Custom sidebar templates, must be a dictionary that maps document names
Expand Down
23 changes: 17 additions & 6 deletions sleap/nn/data/instance_cropping.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,23 @@ def find_instance_crop_size(
# Calculate crop size
min_crop_size_no_pad = min_crop_size - padding
max_length = 0.0
for inst in labels.user_instances:
pts = inst.points_array
pts *= input_scaling
max_length = np.maximum(max_length, np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0]))
max_length = np.maximum(max_length, np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1]))
max_length = np.maximum(max_length, min_crop_size_no_pad)
for lf in labels:
for inst in lf.user_instances:
gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved
pts = inst.points_array

pts[pts < 0] = np.NaN
height, width = lf.image.shape[:2]
pts[:, 0][pts[:, 0] > width - 1] = np.NaN
pts[:, 1][pts[:, 1] > height - 1] = np.NaN

pts *= input_scaling
max_length = np.maximum(
max_length, np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0])
)
gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved
max_length = np.maximum(
max_length, np.nanmax(pts[:, 1]) - np.nanmin(pts[:, 1])
)
max_length = np.maximum(max_length, min_crop_size_no_pad)

max_length += float(padding)
crop_size = np.math.ceil(max_length / float(maximum_stride)) * maximum_stride
Expand Down
19 changes: 19 additions & 0 deletions sleap/nn/data/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import attr
from typing import Text, Optional, List, Sequence, Union, Tuple
import sleap
from sleap.instance import Instance


@attr.s(auto_attribs=True)
Expand Down Expand Up @@ -198,6 +199,24 @@ def py_fetch_lf(ind):
raw_image = lf.image
raw_image_size = np.array(raw_image.shape).astype("int32")

height, width = raw_image_size[:2]

# Filter OOB points
instances = []
for instance in lf.instances:
pts = instance.numpy()
# negative coords
pts[pts < 0] = np.NaN

# coordinates outside img frame
pts[:, 0][pts[:, 0] > width - 1] = np.NaN
pts[:, 1][pts[:, 1] > height - 1] = np.NaN

instances.append(
Instance.from_numpy(pts, instance.skeleton, instance.track)
)
lf.instances = instances

if self.user_instances_only:
insts = lf.user_instances
else:
Expand Down
18 changes: 18 additions & 0 deletions tests/nn/data/test_instance_cropping.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,24 @@
from sleap.nn.config import InstanceCroppingConfig


def test_find_instance_crop_size(min_labels):
labels = min_labels.copy()
assert len(labels.labeled_frames[0].instances) == 2

crop_size = instance_cropping.find_instance_crop_size(labels)
assert crop_size == 74

assert labels[0].instances[0].numpy().shape[0] == 2 # 2 nodes

labels[0].instances[1][0] = (390, 187.9) # exceeds img height
crop_size = instance_cropping.find_instance_crop_size(labels)
assert crop_size == 60

labels[0].instances[1][0] = (-100, 187.9) # exceeds img height
crop_size = instance_cropping.find_instance_crop_size(labels)
assert crop_size == 60
gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved


def test_normalize_bboxes():
bbox = tf.convert_to_tensor([[0, 0, 3, 3]], tf.float32)
norm_bbox = instance_cropping.normalize_bboxes(bbox, 9, 9)
Expand Down
30 changes: 30 additions & 0 deletions tests/nn/data/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,36 @@ def test_labels_reader_no_visible_points(min_labels):
assert len(labels_reader) == 0


def test_labels_filter_oob_points(min_labels):
# There should be two instances in the labels dataset
labels = min_labels.copy()
assert len(labels.labeled_frames[0].instances) == 2

assert labels[0].instances[0].numpy().shape[0] == 2 # 2 nodes

labels[0].instances[0][0] = (390, 100) # exceeds img height

labels_reader = providers.LabelsReader.from_user_instances(labels)
examples = list(iter(labels_reader.make_dataset()))
assert len(examples) == 1

assert all(np.isnan(examples[0]["instances"][0][0]))

# test with negative keypoints
labels = min_labels.copy()
assert len(labels.labeled_frames[0].instances) == 2

assert labels[0].instances[0].numpy().shape[0] == 2 # 2 nodes

labels[0].instances[0][1] = (-10, 100)

eberrigan marked this conversation as resolved.
Show resolved Hide resolved
labels_reader = providers.LabelsReader.from_user_instances(labels)
examples = list(iter(labels_reader.make_dataset()))
assert len(examples) == 1

assert all(np.isnan(examples[0]["instances"][0][1]))


def test_labels_reader_subset(min_labels):
labels = sleap.Labels([min_labels[0], min_labels[0], min_labels[0]])
assert len(labels) == 3
Expand Down
8 changes: 6 additions & 2 deletions tests/nn/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def test_topdown_model(test_pipeline):
assert tuple(out["instance_peak_vals"].shape) == (8, 2, 2)
assert tuple(out["n_valid"].shape) == (8,)

assert (out["n_valid"] == [1, 1, 1, 2, 2, 2, 2, 2]).all()
assert (out["n_valid"] == [1, 1, 1, 2, 2, 2, 1, 1]).all()


def test_inference_layer():
Expand Down Expand Up @@ -2039,7 +2039,11 @@ def test_movenet_predictor(min_dance_labels, movenet_video):
[labels_pr[0][0].numpy(), labels_pr[1][0].numpy()], axis=0
)

np.testing.assert_allclose(points_gt, points_pr, atol=0.75)
assert_allclose(
points_gt[~np.isnan(points_gt).any(axis=1)],
points_pr[~np.isnan(points_gt).any(axis=1)],
atol=0.75,
)


@pytest.mark.parametrize(
Expand Down
38 changes: 38 additions & 0 deletions tests/nn/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,44 @@ def test_train_topdown(training_labels, cfg):
assert tuple(trainer.keras_model.outputs[0].shape) == (None, 96, 96, 2)


def test_train_topdown_with_oob_pts(min_labels, cfg):
# pt exceeding img dim
labels = min_labels
labels.append(
sleap.LabeledFrame(
video=labels.videos[0], frame_idx=1, instances=labels[0].instances
)
)
labels[0].instances[1][0] = (390, 187.9) # crop size=60

cfg.model.heads.centered_instance = CenteredInstanceConfmapsHeadConfig(
sigma=1.5, output_stride=1, offset_refinement=False
)
trainer = TopdownConfmapsModelTrainer.from_config(cfg, training_labels=labels)
trainer.setup()
trainer.train()
assert trainer.keras_model.output_names[0] == "CenteredInstanceConfmapsHead"
assert tuple(trainer.keras_model.outputs[0].shape) == (None, 80, 80, 2)

# negative pts
labels = min_labels
labels.append(
sleap.LabeledFrame(
video=labels.videos[0], frame_idx=1, instances=labels[0].instances
)
)
labels[0].instances[1][0] = (-100, 187.9) # crop size=60

cfg.model.heads.centered_instance = CenteredInstanceConfmapsHeadConfig(
sigma=1.5, output_stride=1, offset_refinement=False
)
trainer = TopdownConfmapsModelTrainer.from_config(cfg, training_labels=labels)
trainer.setup()
trainer.train()
assert trainer.keras_model.output_names[0] == "CenteredInstanceConfmapsHead"
assert tuple(trainer.keras_model.outputs[0].shape) == (None, 80, 80, 2)

gitttt-1234 marked this conversation as resolved.
Show resolved Hide resolved

def test_train_topdown_with_offset(training_labels, cfg):
cfg.model.heads.centered_instance = CenteredInstanceConfmapsHeadConfig(
sigma=1.5, output_stride=1, offset_refinement=True
Expand Down
Loading