Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(3e -> 3a) Add InstanceGroup class #1618

Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
cc4f1db
Add method to get single instance permutations
roomrys Nov 3, 2023
7fd89ec
Clean-up
roomrys Nov 3, 2023
b9af5c3
Add method and (failing) test to get instance grouping
Nov 15, 2023
d89adff
Rename functions and add a per instance error function
Nov 15, 2023
8268a7c
Append a dummy instance for missing instances
Nov 20, 2023
2fbaf1f
Merge branch 'liezl/add-method-to-get-single-instance-permutations' o…
Nov 20, 2023
f7fd938
Update tests to accept a dummy instance
Nov 21, 2023
b863c6b
Correct 'permutations' to 'products'
Nov 21, 2023
8b73638
Merge branch 'liezl/add-method-to-get-single-instance-permutations' o…
Nov 21, 2023
c5001ea
Rename last few permutations to products
Nov 21, 2023
bd0f32d
Adapt instance grouping from single to multi instance
Nov 21, 2023
7ee80b0
Add function to get error per view
Nov 27, 2023
70bdc67
Add initial InstanceGroup class
roomrys Nov 30, 2023
ad5fb9e
Lint
roomrys Nov 30, 2023
4037122
Lint
roomrys Nov 30, 2023
89550a1
Few extra tests for `InstanceGroup`
roomrys Nov 30, 2023
51f7fd3
Replace track-based triangulation with hypothesis-based
roomrys Dec 1, 2023
ab3b86d
Uncomment some commented out code (for testing)
roomrys Dec 1, 2023
509e592
Merge branch 'liezl/add-method-for-multi-instance-products' of https:…
roomrys Dec 1, 2023
5b40e9e
Fix typehinting, add comments
roomrys Dec 1, 2023
2d1e23f
Merge branch 'liezl/add-method-for-multi-instance-products' of https:…
roomrys Dec 1, 2023
c66154a
Remember instance grouping after testing hypotheses
roomrys Dec 2, 2023
d5d6a43
Fix failing tests
roomrys Dec 2, 2023
c33ccc0
Merge branch 'liezl/add-method-to-match-instances-across-views' of ht…
roomrys Dec 6, 2023
8ff1cc5
Typehinting
roomrys Dec 6, 2023
892863b
Merge branch 'liezl/add-method-to-get-single-instance-permutations' o…
roomrys Dec 6, 2023
1b21838
Typehinting
roomrys Dec 6, 2023
134a83a
Merge branch 'liezl/add-method-to-test-instance-grouping' of https://…
roomrys Dec 6, 2023
a3aca14
Typehinting
roomrys Dec 6, 2023
2e926d5
Merge branch 'liezl/add-method-for-multi-instance-products' of https:…
roomrys Dec 6, 2023
ae665ed
Typehinting
roomrys Dec 6, 2023
5964e5d
Use reconsumable iterator for reprojected coords
roomrys Dec 6, 2023
aac86d5
Catch race condition early
roomrys Dec 6, 2023
83aa558
Only triangulate user instances, add fixture, update tests
roomrys Dec 6, 2023
2feb9f0
Lint
roomrys Dec 6, 2023
38e37f2
Normalize instance reprojection errors
Dec 7, 2023
4a50c81
Merge branch 'liezl/add-method-to-match-instances-across-views' of ht…
roomrys Mar 20, 2024
9a2fb9a
Add `locked`, `_dummy_instance`, `numpy`, and `update_points`
roomrys Mar 21, 2024
8e0ceac
Allow `PredictedPoint`s to be updated as well
roomrys Mar 21, 2024
0072f7a
Add tests for new attributes and methods
roomrys Mar 21, 2024
01bbdad
Lint
roomrys Mar 21, 2024
17e734c
Add methods to create, add, replace, and remove instances
Mar 22, 2024
95071c1
Use PredictedInstance for new/dummy instances
Apr 11, 2024
0c057c1
(3f -> 3e) Add `FrameGroup` class (#1665)
roomrys Apr 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 207 additions & 0 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class which inherits from `AppCommand` (or a more specialized class such as
import traceback
from enum import Enum
from glob import glob
from itertools import permutations, product
from pathlib import Path, PurePath
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Type, Union, cast

Expand Down Expand Up @@ -3750,6 +3751,210 @@ def get_all_views_at_frame(

return views

@staticmethod
def get_instance_grouping(
instances: Dict[int, Dict[Camcorder, List[Instance]]],
reprojection_error_per_frame: Dict[int, float],
) -> Dict[int, Dict[Camcorder, List[Instance]]]:
"""Get instance grouping for triangulation."""

frame_with_min_error = min(
reprojection_error_per_frame, key=reprojection_error_per_frame.get
)

best_instances = instances[frame_with_min_error]
best_instances_correct_format = {frame_with_min_error: best_instances}

return best_instances_correct_format

@staticmethod
def _calculate_reprojection_error(
session: RecordingSession,
instances: Dict[int, Dict[Camcorder, List[Instance]]],
per_instance: bool = False,
per_view: bool = False,
) -> Union[
Dict[int, float], Dict[int, Dict[Camcorder, List[Tuple[Instance, float]]]]
]:
"""Calculate reprojection error per frame or per instance.

Args:
session: The `RecordingSession` containing the `Camcorder`s.
instances: Dict with frame identifier keys (not the frame index) and values
of another inner dict with `Camcorder` keys and `List[Instance]` values.
per_instance: If True, then return a dict with frame identifier keys and
values of another inner dict with `Camcorder` keys and
`List[Tuple[Instance, float]]` values.
per_view: If True, then return a dict with frame identifier keys and values
of another inner dict with `Camcorder` keys and
`Tuple[Tuple[str, str], float]` values. If per_instance is True, then that takes precendence.

Returns:
Dict with frame identifier keys (not the frame index) and values of another
inner dict with `Camcorder` keys and `List[Tuple[Instance, float]]` values
if per_instance is True, otherwise a dict with frame identifier keys and
values of reprojection error for the frame.
"""

reprojection_error_per_frame = {}

# Triangulate and reproject instance coordinates.
instances_and_coords: Dict[
int, Dict[Camcorder, Iterator[Tuple[Instance, np.ndarray]]]
] = TriangulateSession.calculate_reprojected_points(
session=session, instances=instances
)
for frame_id, instances_in_frame in instances_and_coords.items():
frame_error = {} if per_instance or per_view else 0
for cam, instances_in_view in instances_in_frame.items():
# Compare instance coordinates here
instance_ids = []
view_error = [] if per_instance else 0
for inst, inst_coords in instances_in_view:
node_errors = np.nan_to_num(inst.numpy() - inst_coords)
instance_error = np.linalg.norm(node_errors)

if per_instance:
view_error.append((inst, instance_error))
else:
view_error += instance_error

inst_id = inst.track if inst.track is not None else "None"
instance_ids.append(inst_id)

if per_instance:
frame_error[cam] = view_error
elif per_view:
frame_error[cam] = (tuple(instance_ids), view_error)
else:
frame_error += view_error

reprojection_error_per_frame[frame_id] = frame_error

return reprojection_error_per_frame

@staticmethod
def calculate_error_per_instance(
session: RecordingSession, instances: Dict[int, Dict[Camcorder, List[Instance]]]
) -> Dict[int, float]:
"""Calculate reprojection error per instance."""

reprojection_error_per_instance = (
TriangulateSession._calculate_reprojection_error(
session=session, instances=instances, per_instance=True
)
)

return reprojection_error_per_instance

@staticmethod
def calculate_error_per_view(
session: RecordingSession, instances: Dict[int, Dict[Camcorder, List[Instance]]]
) -> Dict[int, float]:
"""Calculate reprojection error per instance."""

reprojection_error_per_view = TriangulateSession._calculate_reprojection_error(
session=session, instances=instances, per_view=True
)

return reprojection_error_per_view

@staticmethod
def calculate_error_per_frame(
session: RecordingSession, instances: Dict[int, Dict[Camcorder, List[Instance]]]
) -> Dict[int, float]:
"""Calculate reprojection error per frame."""

reprojection_error_per_frame = TriangulateSession._calculate_reprojection_error(
session=session, instances=instances, per_instance=False
)

return reprojection_error_per_frame

@staticmethod
def get_products_of_instances(
selected_instance: Instance,
session: RecordingSession,
frame_idx: int,
cams_to_include: Optional[List[Camcorder]] = None,
) -> Dict[int, Dict[Camcorder, List[Instance]]]:
"""Get all (single-instance) possible products of instances across views.

Args:
selected_instance: The `Instance` to add to permutations of instances in
views other than that of the `selected_instance`.
session: The `RecordingSession` containing the `Camcorder`s.
frame_idx: Frame index to get instances from (0-indexed).
cams_to_include: List of `Camcorder`s to include. Default is all.
require_multiple_views: If True, then raise and error if one or less views
or instances are found.

Raises:
ValueError if one or less views or instances are found.

Returns:
Dict with frame identifier keys (not the frame index) and values of another
inner dict with `Camcorder` keys and `List[Instance]` values. Each
`List[Instance]` is of length 1.
"""

cam_selected = session.get_camera(selected_instance.video)
cam_selected = cast(Camcorder, cam_selected) # Could be None if not in session

# Get all instances accross views at this frame index, then remove selected
instances: Dict[
Camcorder, List[Instance]
] = TriangulateSession.get_instances_across_views(
session=session,
frame_idx=frame_idx,
cams_to_include=cams_to_include,
track=-1, # Get all instances regardless of track.
require_multiple_views=True,
)

# Find max number of instances in other views
max_num_instances = max([len(instances) for instances in instances.values()])

# Create a dummy instance of all nan values
dummy_instance = Instance.from_numpy(
np.full(
shape=(len(selected_instance.skeleton.nodes), 2),
fill_value=np.nan,
),
skeleton=selected_instance.skeleton,
)

# Get permutations of instances from other views
instances_permutations: Dict[Camcorder, Iterator[Tuple]] = {}
for cam, instances_in_view in instances.items():
# Append a dummy instance to all lists of instances if less than the max length
num_missing = 1
num_instances = len(instances_in_view)
if num_instances < max_num_instances:
num_missing = max_num_instances - num_instances

# Extend the list first
instances_in_view.extend([dummy_instance] * num_missing)

# Permute instances into all possible orderings w/in a view
instances_permutations[cam] = permutations(instances_in_view)

# Get products of instances from other views into all possible groupings
# Ordering of dict_values is preserved in Python 3.7+
products_of_instances: Iterator[Iterator[Tuple]] = product(
*instances_permutations.values()
)

# Reorganize products by cam and add selected instance to each permutation
instances_hypotheses: Dict[int, Dict[Camcorder, List[Instance]]] = {}
for frame_id, prod in enumerate(products_of_instances):
instances_hypotheses[frame_id] = {
cam: [*inst] for cam, inst in zip(instances.keys(), prod)
}

# Expect "max # instances in view" ** "# views" frames (a.k.a. hypotheses)
return instances_hypotheses

@staticmethod
def get_instances_matrices(
instances: Dict[int, Dict[Camcorder, List[Instance]]],
Expand Down Expand Up @@ -3804,7 +4009,9 @@ def get_instances_matrices(
inst_coords_frames.append(
inst_coords_views
) # len=frame_idx, List[M x T x N x 2]

inst_coords = np.stack(inst_coords_frames, axis=1) # M x F x T x N x 2
cams_ordered = cast(List[Camcorder], cams_ordered) # Could be None if no frames

return inst_coords, cams_ordered

Expand Down
146 changes: 144 additions & 2 deletions sleap/io/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import logging
import tempfile
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast

import cattr
import numpy as np
import toml
from aniposelib.cameras import Camera, CameraGroup, FisheyeCamera
from attrs import define, field
from attrs.validators import deep_iterable, instance_of
from sleap_anipose import reproject, triangulate

# from sleap.io.dataset import Labels # TODO(LM): Circular import, implement Observer
from sleap.io.video import Video
Expand Down Expand Up @@ -755,3 +754,146 @@ def make_cattr(videos_list: List[Video]):
RecordingSession, lambda x: x.to_session_dict(video_to_idx)
)
return sessions_cattr


@define
class InstanceGroup:
"""Defines a group of instances across the same frame index.

Args:
camera_cluster: `CameraCluster` object.
instances: List of `Instance` objects.

"""

frame_idx: int = field(validator=instance_of(int))
camera_cluster: Optional[CameraCluster] = None
_instance_by_camcorder: Dict[Camcorder, "Instance"] = field(factory=dict)
_camcorder_by_instance: Dict["Instance", Camcorder] = field(factory=dict)

def __attrs_post_init__(self):
"""Initialize `InstanceGroup` object."""

for cam, instance in self._instance_by_camcorder.items():
self._camcorder_by_instance[instance] = cam

@property
def instances(self) -> List["Instance"]:
"""List of `Instance` objects."""
return list(self._instance_by_camcorder.values())

def get_instance(self, cam: Camcorder) -> Optional["Instance"]:
"""Retrieve `Instance` linked to `Camcorder`.

Args:
camcorder: `Camcorder` object.

Returns:
If `Camcorder` in `self.camera_cluster`, then `Instance` object if found, else
`None` if `Camcorder` has no linked `Instance`.
"""

if cam not in self._instance_by_camcorder:
logger.warning(
f"Camcorder {cam.name} is not linked to a video in this "
f"RecordingSession."
)
return None

return self._instance_by_camcorder[cam]

def get_cam(self, instance: "Instance") -> Optional[Camcorder]:
"""Retrieve `Camcorder` linked to `Instance`.

Args:
instance: `Instance` object.

Returns:
`Camcorder` object if found, else `None`.
"""

if instance not in self._camcorder_by_instance:
logger.warning(
f"{instance} is not in this InstanceGroup's Instances: \n\t{self.instances}."
)
return None

return self._camcorder_by_instance[instance]

def __getitem__(
self, idx_or_key: Union[int, Camcorder, "Instance"]
) -> Union[Camcorder, "Instance"]:
"""Grab a `Camcorder` of `Instance` from the `InstanceGroup`."""

# Try to find in `self.camera_cluster.cameras`
if isinstance(idx_or_key, int):
try:
return self.instances[idx_or_key]
except IndexError:
pass

# Return a `Instance` if `idx_or_key` is a `Camcorder``
if isinstance(idx_or_key, Camcorder):
return self.get_instance(idx_or_key)

else:
# isinstance(idx_or_key, "Instance"):
try:
return self.get_cam(idx_or_key)
except:
pass

raise KeyError(
f"Key {idx_or_key} not found in {self.__class__.__name__} or "
"associated metadata."
)

def __len__(self):
return len(self.instances)

def __repr__(self):
return f"{self.__class__.__name__}(frame_idx={self.frame_idx}, instances={len(self)}, camera_cluster={self.camera_cluster})"

@classmethod
def from_dict(cls, d: dict) -> "InstanceGroup":
"""Creates an `InstanceGroup` object from a dictionary.

Args:
d: Dictionary with `Camcorder` keys and `Instance` values.

Returns:
`InstanceGroup` object.
"""

frame_idx = None
for cam, instance in d.copy().items():
camera_cluster = cam.camera_cluster

# Remove dummy instances (determined by not having a frame index)
if instance.frame_idx is None:
d.pop(cam)
# Grab the frame index from non-dummy instances
elif frame_idx is None:
frame_idx = instance.frame_idx
# Ensure all instances have the same frame index
else:
try:
assert frame_idx == instance.frame_idx
except AssertionError:
logger.warning(
f"Cannot create `InstanceGroup`: Frame index {frame_idx} "
f"does not match instance frame index {instance.frame_idx}."
)

if len(d) == 0:
logger.warning("Cannot create `InstanceGroup`: No real instances found.")

frame_idx = cast(
int, frame_idx
) # Could be None if no real instances in dictionary
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use of cast to convert frame_idx to an integer may not be safe if frame_idx can be None. Consider adding a check to ensure frame_idx is not None before casting.


return cls(
frame_idx=frame_idx,
camera_cluster=camera_cluster,
instance_by_camcorder=d,
)
Loading
Loading