From cef0da6720f78473864b52057683b9b1e5bcd2a6 Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Mon, 7 Oct 2024 20:08:23 -0700 Subject: [PATCH 01/22] and note where `target_instance_count` is initialized --- sleap/nn/tracking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 558aa9309..510df1275 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -574,7 +574,7 @@ class Tracker(BaseTracker): max_tracking: bool = False # To enable maximum tracking. cleaner: Optional[Callable] = None # TODO: deprecate - target_instance_count: int = 0 + target_instance_count: int = 0 # TODO: deprecate pre_cull_function: Optional[Callable] = None post_connect_single_breaks: bool = False robust_best_instance: float = 1.0 From dea73692cc77b79e99c7caf5884f5dfd18e61d6f Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Mon, 7 Oct 2024 20:09:03 -0700 Subject: [PATCH 02/22] `target_instance_count` is not available in the GUI but `max_tracks` is --- sleap/nn/tracking.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 510df1275..a74ae4138 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -824,8 +824,13 @@ def final_pass(self, frames: List[LabeledFrame]): # "tracking." # ) self.cleaner.run(frames) - elif self.target_instance_count and self.post_connect_single_breaks: + elif (self.target_instance_count or self.max_tracks) and self.post_connect_single_breaks: + if not self.target_instance_count: + # If target_instance_count is not set, use max_tracks instead + # target_instance_count not available in the GUI + self.target_instance_count = self.max_tracks connect_single_track_breaks(frames, self.target_instance_count) + print("Connecting single track breaks.") def get_name(self): tracker_name = self.candidate_maker.__class__.__name__ From 43c0f0aad85e2e122c8c79fd759cb9057f50ba1e Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Mon, 7 Oct 2024 20:09:31 -0700 Subject: [PATCH 03/22] add note where `target_instance_count` is initialized --- sleap/nn/tracking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index a74ae4138..666f79990 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -855,7 +855,7 @@ def make_tracker_by_name( of_max_levels: int = 3, save_shifted_instances: bool = False, # Pre-tracking options to cull instances - target_instance_count: int = 0, + target_instance_count: int = 0, # TODO: deprecate target_instance_count pre_cull_to_target: bool = False, pre_cull_iou_threshold: Optional[float] = None, # Post-tracking options to connect broken tracks From 60a40abf46fdc07d6208d0d52b3eaece895a5741 Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Mon, 7 Oct 2024 20:10:06 -0700 Subject: [PATCH 04/22] add note since neither `target_instance_count` nor `pre_cull_to_target` are options in the GUI --- sleap/nn/tracking.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 666f79990..44e3f0347 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -926,6 +926,7 @@ def make_tracker_by_name( pre_cull_function = None if target_instance_count and pre_cull_to_target: + # Right now this is not accessible from the GUI def pre_cull_function(inst_list): cull_frame_instances( From 06ab653695ecb623f4fa03c84be964110edecd1f Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Mon, 7 Oct 2024 20:11:04 -0700 Subject: [PATCH 05/22] accept either max_tracks or target_instance_count for compatibility with both CLI and GUI --- sleap/nn/tracking.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 44e3f0347..f9af3171d 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -946,11 +946,17 @@ def pre_cull_function(inst_list): pre_cull_function=pre_cull_function, max_tracking=max_tracking, max_tracks=max_tracks, - target_instance_count=target_instance_count, + target_instance_count=target_instance_count, # TODO: deprecate target_instance_count post_connect_single_breaks=post_connect_single_breaks, ) - if target_instance_count and kf_init_frame_count: + # Kalman filter requires deprecated target_instance_count + if (max_tracks or target_instance_count) and kf_init_frame_count: + if not target_instance_count: + # If target_instance_count is not set, use max_tracks instead + # target_instance_count not available in the GUI + target_instance_count = max_tracks + kalman_obj = KalmanTracker.make_tracker( init_tracker=tracker_obj, init_frame_count=kf_init_frame_count, @@ -960,8 +966,8 @@ def pre_cull_function(inst_list): ) return kalman_obj - elif kf_init_frame_count and not target_instance_count: - raise ValueError("Kalman filter requires target instance count.") + elif kf_init_frame_count and not (max_tracks or target_instance_count): + raise ValueError("Kalman filter requires max tracks or target instance count.") else: return tracker_obj From 4c70d275eeddab71319a4afb917ad96c23cfafd8 Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Mon, 7 Oct 2024 20:12:31 -0700 Subject: [PATCH 06/22] TypeError: track() got an unexpected keyword argument 'img_hw' since `init_tracker` has `img_hw` --- sleap/nn/tracking.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index f9af3171d..90080620e 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -1398,6 +1398,7 @@ def track( untracked_instances: List[InstanceType], img: Optional[np.ndarray] = None, t: int = None, + **kwargs, ) -> List[InstanceType]: """Tracks individual frame, using Kalman filters if possible.""" From 474d0d6a8af621ba1cc568cb2bd148a0892acc83 Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Mon, 7 Oct 2024 20:12:57 -0700 Subject: [PATCH 07/22] useful print statements --- sleap/nn/tracking.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 90080620e..0feff4ef2 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -1381,6 +1381,8 @@ def cull_function(inst_list): if init_tracker.pre_cull_function is None: init_tracker.pre_cull_function = cull_function + print(f"Using {init_tracker.get_name()} to track {init_frame_count} frames for Kalman filters.") + return cls( init_tracker=init_tracker, kalman_tracker=kalman_tracker, @@ -1433,7 +1435,7 @@ def track( # Initialize the Kalman filters self.kalman_tracker.init_filters(self.init_set.instances) - # print(f"Kalman filters initialized (frame {t})") + print(f"Kalman filters initialized (frame {t})") # Clear the data used to init filters, so that if the filters # stop tracking and we need to re-init, we won't re-use the From 1eafd5f7388f5938423b175b19ab2e2b571754c0 Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Mon, 7 Oct 2024 20:54:07 -0700 Subject: [PATCH 08/22] black --- sleap/nn/tracking.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 0feff4ef2..83f994627 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -574,7 +574,7 @@ class Tracker(BaseTracker): max_tracking: bool = False # To enable maximum tracking. cleaner: Optional[Callable] = None # TODO: deprecate - target_instance_count: int = 0 # TODO: deprecate + target_instance_count: int = 0 # TODO: deprecate pre_cull_function: Optional[Callable] = None post_connect_single_breaks: bool = False robust_best_instance: float = 1.0 @@ -824,7 +824,9 @@ def final_pass(self, frames: List[LabeledFrame]): # "tracking." # ) self.cleaner.run(frames) - elif (self.target_instance_count or self.max_tracks) and self.post_connect_single_breaks: + elif ( + self.target_instance_count or self.max_tracks + ) and self.post_connect_single_breaks: if not self.target_instance_count: # If target_instance_count is not set, use max_tracks instead # target_instance_count not available in the GUI @@ -855,7 +857,7 @@ def make_tracker_by_name( of_max_levels: int = 3, save_shifted_instances: bool = False, # Pre-tracking options to cull instances - target_instance_count: int = 0, # TODO: deprecate target_instance_count + target_instance_count: int = 0, # TODO: deprecate target_instance_count pre_cull_to_target: bool = False, pre_cull_iou_threshold: Optional[float] = None, # Post-tracking options to connect broken tracks @@ -946,7 +948,7 @@ def pre_cull_function(inst_list): pre_cull_function=pre_cull_function, max_tracking=max_tracking, max_tracks=max_tracks, - target_instance_count=target_instance_count, # TODO: deprecate target_instance_count + target_instance_count=target_instance_count, # TODO: deprecate target_instance_count post_connect_single_breaks=post_connect_single_breaks, ) @@ -967,7 +969,9 @@ def pre_cull_function(inst_list): return kalman_obj elif kf_init_frame_count and not (max_tracks or target_instance_count): - raise ValueError("Kalman filter requires max tracks or target instance count.") + raise ValueError( + "Kalman filter requires max tracks or target instance count." + ) else: return tracker_obj @@ -1381,7 +1385,9 @@ def cull_function(inst_list): if init_tracker.pre_cull_function is None: init_tracker.pre_cull_function = cull_function - print(f"Using {init_tracker.get_name()} to track {init_frame_count} frames for Kalman filters.") + print( + f"Using {init_tracker.get_name()} to track {init_frame_count} frames for Kalman filters." + ) return cls( init_tracker=init_tracker, From 1d6ed7c3a6f0a46219cc76d8ffffe8b27e2aa3f2 Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Mon, 16 Dec 2024 13:01:03 -0800 Subject: [PATCH 09/22] np.bool is deprecated --- sleap/nn/tracker/kalman.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/nn/tracker/kalman.py b/sleap/nn/tracker/kalman.py index 2b0343927..774a4634e 100644 --- a/sleap/nn/tracker/kalman.py +++ b/sleap/nn/tracker/kalman.py @@ -608,7 +608,7 @@ def remove_second_bests_from_cost_matrix( cost matrix with invalid matches set to specified invalid value. """ - valid_match_mask = np.full_like(cost_matrix, True, dtype=np.bool) + valid_match_mask = np.full_like(cost_matrix, True, dtype=bool) rows, columns = cost_matrix.shape From a3e64a93468c0e2f4f5c12de1acd30b29f36f3f5 Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Tue, 17 Dec 2024 10:26:32 -0800 Subject: [PATCH 10/22] debug --- sleap/nn/inference.py | 12 +++++++++--- sleap_tracking_debug.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 3 deletions(-) create mode 100644 sleap_tracking_debug.py diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 3f01a1c3c..c27382e52 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -1129,9 +1129,11 @@ def export_model( info["predicted_tensors"] = tensors full_model = tf.function( - lambda x: sleap.nn.data.utils.unrag_example(model(x), numpy=False) - if unrag_outputs - else model(x) + lambda x: ( + sleap.nn.data.utils.unrag_example(model(x), numpy=False) + if unrag_outputs + else model(x) + ) ) full_model = full_model.get_concrete_function( @@ -5717,3 +5719,7 @@ def main(args: Optional[list] = None): "To retrack on predictions, must specify tracker. " "Use \"sleap-track --tracking.tracker ...' to specify tracker to use." ) + + +if __name__ == "__main__": + main() diff --git a/sleap_tracking_debug.py b/sleap_tracking_debug.py new file mode 100644 index 000000000..83661cbc8 --- /dev/null +++ b/sleap_tracking_debug.py @@ -0,0 +1,34 @@ +import sleap + +PREDICTIONS_FILE = ( + "/Users/elizabethberrigan/repos/sleap/tests/data/tracks/clip.2node.slp" +) + +# Load predictions +labels = sleap.load_file(PREDICTIONS_FILE) + +# Here I'm removing the tracks so we just have instances without any tracking applied. +for instance in labels.instances(): + instance.track = None +labels.tracks = [] + +tracker = sleap.nn.tracking.Tracker.make_tracker_by_name( + tracker="flow", + track_window=5, + # Matching options + similarity="instance", + match="hungarian", + max_tracking=True, + max_tracks=2, + kf_node_indices=[0, 1], + kf_init_frame_count=10, +) + +tracked_lfs = [] +for lf in labels: + lf.instances = tracker.track(lf.instances, img=lf.image) + tracked_lfs.append(lf) +tracked_labels = sleap.Labels(tracked_lfs) +tracked_labels.save( + "/Users/elizabethberrigan/repos/sleap/tests/data/tracks/clip.2node.tracked.slp" +) From 24a442df2973654aa293338ff28179a152478089 Mon Sep 17 00:00:00 2001 From: eberrigan Date: Tue, 17 Dec 2024 11:01:16 -0800 Subject: [PATCH 11/22] add params for testing kalman filter --- tests/nn/test_tracking_integration.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index 625302fd0..b730d9cc7 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -146,6 +146,11 @@ def main(f, dir): 0.25, ) + kalman_params = dict( + kf_node_indices=[0, 1], + kf_init_frame_count=10, + ) + def make_tracker( tracker_name, matcher_name, sim_name, max_tracks, max_tracking=False, scale=0 ): From 96a08cc4175ff7ecd7080a658e64b43ac97c868e Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Tue, 17 Dec 2024 11:09:47 -0800 Subject: [PATCH 12/22] remove params because this function isn't used --- tests/nn/test_tracking_integration.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index b730d9cc7..625302fd0 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -146,11 +146,6 @@ def main(f, dir): 0.25, ) - kalman_params = dict( - kf_node_indices=[0, 1], - kf_init_frame_count=10, - ) - def make_tracker( tracker_name, matcher_name, sim_name, max_tracks, max_tracking=False, scale=0 ): From bd364cd0e5ca0f9a8c1980741ea5510ff95ae373 Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Tue, 17 Dec 2024 12:41:28 -0800 Subject: [PATCH 13/22] debugging --- sleap/nn/tracking.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 83f994627..9cd091730 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -287,10 +287,14 @@ def flow_shift_instances( This function relies on the Lucas-Kanade method for optical flow estimation. """ + print(f"Image type before converting: {type(ref_img)}") + # Convert to uint8 for cv2.calcOpticalFlowPyrLK ref_img = ensure_int(ref_img) new_img = ensure_int(new_img) + print(f"Image type after converting: {type(ref_img)}") + # Convert tensors to ndarays if hasattr(ref_img, "numpy"): ref_img = ref_img.numpy() @@ -1081,9 +1085,9 @@ def get_by_name_factory_options(cls): option = dict(name="of_window_size", default=21) option["type"] = int - option[ - "help" - ] = "For optical-flow: Optical flow window size to consider at each pyramid " + option["help"] = ( + "For optical-flow: Optical flow window size to consider at each pyramid " + ) "scale level" options.append(option) @@ -1110,9 +1114,9 @@ def int_list_func(s): option = dict(name="kf_init_frame_count", default="0") option["type"] = int - option[ - "help" - ] = "For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used." + option["help"] = ( + "For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used." + ) options.append(option) def float_list_func(s): From 8722813fcb9a0b729993a89a111c87af9f5e02fd Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Tue, 17 Dec 2024 12:41:43 -0800 Subject: [PATCH 14/22] test kalman filter tracking --- tests/nn/test_tracking_integration.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index 625302fd0..98e4e3d35 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -2,13 +2,35 @@ import operator import os import time - +import pytest import sleap from sleap.nn.inference import main as inference_cli import sleap.nn.tracker.components from sleap.io.dataset import Labels, LabeledFrame +@pytest.mark.parametrize( + "tracker_name", ["simple", "simplemaxtracks", "flow", "flowmaxtracks"] +) +def test_kalman_tracker(tmpdir, centered_pair_predictions_slp_path, tracker_name): + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + "--frames 200-300 " + "--tracking.similarity instance " + "--tracking.match hungarian " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + f"-o {tmpdir}/{tracker_name}.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + + labels = sleap.load_file(f"{tmpdir}/{tracker_name}.slp") + assert len(labels.tracks) == 2 + + def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path): cli = ( "--tracking.tracker simple " From 0fcce5c7f646df24dca0c363b1572e402075717e Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Wed, 18 Dec 2024 12:48:20 -0800 Subject: [PATCH 15/22] add documentation --- docs/guides/cli.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/guides/cli.md b/docs/guides/cli.md index 134461c60..3f4c37ab2 100644 --- a/docs/guides/cli.md +++ b/docs/guides/cli.md @@ -230,7 +230,7 @@ optional arguments: --tracking.kf_node_indices TRACKING.KF_NODE_INDICES For Kalman filter: Indices of nodes to track. (default: ) --tracking.kf_init_frame_count TRACKING.KF_INIT_FRAME_COUNT - For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used. (default: 0) + For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used. (default: 0) Kalman filters require TRACKING.KF_NODE_INDICES, TRACKING.MAX_TRACKING and TRACKING.MAX_TRACKS or TRACKING.TARGET_INSTANCE_COUNT, TRACKING.TRACKER to be simple or simplemaxtracks, and TRACKING.SIMILARITY to not be normalized_instance. ``` #### Examples: From 7322b3c7a9c34129b4de57303046992fb18ddf80 Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Wed, 18 Dec 2024 12:48:44 -0800 Subject: [PATCH 16/22] kalman filter needs node indices, simple tracking and similarity anything besides normalized --- sleap/nn/tracking.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 9cd091730..0799a056c 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -958,6 +958,23 @@ def pre_cull_function(inst_list): # Kalman filter requires deprecated target_instance_count if (max_tracks or target_instance_count) and kf_init_frame_count: + if not kf_node_indices: + raise ValueError( + "Kalman filter requires node indices for instance tracking." + ) + + if tracker == "flow" or tracker == "flowmaxtracks": + # Tracking with Kalman filter requires initial tracker object to be simple + raise ValueError( + "Kalman filter requires simple tracker for initial tracking." + ) + + if similarity == "normalized_instance": + # Kalman filter doesnot support normalized_instance_similarity + raise ValueError( + "Kalman filter does not support normalized_instance_similarity." + ) + if not target_instance_count: # If target_instance_count is not set, use max_tracks instead # target_instance_count not available in the GUI From 266d8ae6d452e00ad374c29814fb20e07b29a183 Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Wed, 18 Dec 2024 12:49:04 -0800 Subject: [PATCH 17/22] add tests for every combination related to kalman args --- tests/nn/test_tracking_integration.py | 182 +++++++++++++++++++++++--- 1 file changed, 166 insertions(+), 16 deletions(-) diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index 98e4e3d35..21d34380f 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -9,26 +9,176 @@ from sleap.io.dataset import Labels, LabeledFrame +similarity_args = [ + "instance", + "normalized_instance", + "object_keypoint", + "centroid", + "iou", +] +match_args = ["hungarian", "greedy"] + + @pytest.mark.parametrize( "tracker_name", ["simple", "simplemaxtracks", "flow", "flowmaxtracks"] ) -def test_kalman_tracker(tmpdir, centered_pair_predictions_slp_path, tracker_name): - cli = ( - f"--tracking.tracker {tracker_name} " - "--tracking.max_tracking 1 --tracking.max_tracks 2 " - "--frames 200-300 " - "--tracking.similarity instance " - "--tracking.match hungarian " - "--tracking.track_window 5 " - "--tracking.kf_init_frame_count 10 " - "--tracking.kf_node_indices 0,1 " - f"-o {tmpdir}/{tracker_name}.slp " - f"{centered_pair_predictions_slp_path}" - ) - inference_cli(cli.split(" ")) +@pytest.mark.parametrize("similarity", similarity_args) +@pytest.mark.parametrize("match", match_args) +def test_kalman_tracker( + tmpdir, centered_pair_predictions_slp_path, tracker_name, similarity, match +): + + if tracker_name == "flow" or tracker_name == "flowmaxtracks": + # Expecting ValueError for "flow" or "flowmaxtracks" due to Kalman filter requiring a simple tracker + with pytest.raises( + ValueError, + match="Kalman filter requires simple tracker for initial tracking.", + ): + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + f"-o {tmpdir}/{tracker_name}.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + else: + # For simple or simplemaxtracks, continue with other tests + # Check for ValueError when similarity is "normalized_instance" + if similarity == "normalized_instance": + with pytest.raises( + ValueError, + match="Kalman filter does not support normalized_instance_similarity.", + ): + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + f"-o {tmpdir}/{tracker_name}.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + return + + # Check for ValueError when kf_node_indices is None which is the default + with pytest.raises( + ValueError, + match="Kalman filter requires node indices for instance tracking.", + ): + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + f"-o {tmpdir}/{tracker_name}.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + + # Test for missing max_tracks and target_instance_count with kf_init_frame_count + with pytest.raises( + ValueError, + match="Kalman filter requires max tracks or target instance count.", + ): + cli = ( + f"--tracking.tracker {tracker_name} " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + f"-o {tmpdir}/{tracker_name}.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + + # Test with target_instance_count and without max_tracks + cli = ( + f"--tracking.tracker {tracker_name} " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + "--tracking.target_instance_count 2 " + f"-o {tmpdir}/{tracker_name}_target_instance_count.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + + labels = sleap.load_file(f"{tmpdir}/{tracker_name}_target_instance_count.slp") + assert len(labels.tracks) == 2 + + # Test with target_instance_count and with max_tracks + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + "--tracking.target_instance_count 2 " + f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) - labels = sleap.load_file(f"{tmpdir}/{tracker_name}.slp") - assert len(labels.tracks) == 2 + labels = sleap.load_file( + f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count.slp" + ) + assert len(labels.tracks) == 2 + + # Test with "--tracking.pre_cull_iou_threshold", "0.8" + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + "--tracking.target_instance_count 2 " + "--tracking.pre_cull_iou_threshold 0.8 " + f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count_iou.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + + labels = sleap.load_file( + f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count_iou.slp" + ) + assert len(labels.tracks) == 2 + + # Test with "--tracking.pre_cull_to_target", "1" + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + "--tracking.target_instance_count 2 " + "--tracking.pre_cull_to_target 1 " + f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count_to_target.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + labels = sleap.load_file( + f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count_to_target.slp" + ) + assert len(labels.tracks) == 2 def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path): From 00cbdaf6e04605f70a846d32887404bcb0c471fb Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Wed, 18 Dec 2024 13:42:31 -0800 Subject: [PATCH 18/22] add example to documentation --- docs/guides/cli.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/guides/cli.md b/docs/guides/cli.md index 3f4c37ab2..339c5405b 100644 --- a/docs/guides/cli.md +++ b/docs/guides/cli.md @@ -285,6 +285,12 @@ sleap-track --gpu 1 ... sleap-track -m "models/my_model" --frames 1000-2000 "input_video.mp4" ``` +**9. Use Kalman tracker (not recommended since flow is preferred):** + +```none +sleap-track -m "models/my_model" --tracking.similarity instance --tracking.tracker simplemaxtracks --tracking.max_tracking 1 --tracking.max_tracks 4 --tracking.kf_init_frame_count 10 --tracking.kf_node_indices 0,1 -o "output_predictions.slp" "input_video.mp4" +``` + ## Dataset files (sleap-convert)= From 6482f13eb13f28a9c16d06449bea036da82aacd1 Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Wed, 18 Dec 2024 13:43:41 -0800 Subject: [PATCH 19/22] delete debug scripts --- sleap_tracking_debug.py | 34 ---------------------------------- 1 file changed, 34 deletions(-) delete mode 100644 sleap_tracking_debug.py diff --git a/sleap_tracking_debug.py b/sleap_tracking_debug.py deleted file mode 100644 index 83661cbc8..000000000 --- a/sleap_tracking_debug.py +++ /dev/null @@ -1,34 +0,0 @@ -import sleap - -PREDICTIONS_FILE = ( - "/Users/elizabethberrigan/repos/sleap/tests/data/tracks/clip.2node.slp" -) - -# Load predictions -labels = sleap.load_file(PREDICTIONS_FILE) - -# Here I'm removing the tracks so we just have instances without any tracking applied. -for instance in labels.instances(): - instance.track = None -labels.tracks = [] - -tracker = sleap.nn.tracking.Tracker.make_tracker_by_name( - tracker="flow", - track_window=5, - # Matching options - similarity="instance", - match="hungarian", - max_tracking=True, - max_tracks=2, - kf_node_indices=[0, 1], - kf_init_frame_count=10, -) - -tracked_lfs = [] -for lf in labels: - lf.instances = tracker.track(lf.instances, img=lf.image) - tracked_lfs.append(lf) -tracked_labels = sleap.Labels(tracked_lfs) -tracked_labels.save( - "/Users/elizabethberrigan/repos/sleap/tests/data/tracks/clip.2node.tracked.slp" -) From cbf5ca83a4fd751b6be533221e88e7d4e23b3150 Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Wed, 18 Dec 2024 13:44:30 -0800 Subject: [PATCH 20/22] delete print statements --- sleap/nn/tracking.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index 0799a056c..d10327222 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -287,14 +287,10 @@ def flow_shift_instances( This function relies on the Lucas-Kanade method for optical flow estimation. """ - print(f"Image type before converting: {type(ref_img)}") - # Convert to uint8 for cv2.calcOpticalFlowPyrLK ref_img = ensure_int(ref_img) new_img = ensure_int(new_img) - print(f"Image type after converting: {type(ref_img)}") - # Convert tensors to ndarays if hasattr(ref_img, "numpy"): ref_img = ref_img.numpy() From c29921b89eea334f930d5ccda828a123fddeeac7 Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Wed, 18 Dec 2024 13:48:38 -0800 Subject: [PATCH 21/22] black --- sleap/nn/tracking.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sleap/nn/tracking.py b/sleap/nn/tracking.py index d10327222..231b004f5 100644 --- a/sleap/nn/tracking.py +++ b/sleap/nn/tracking.py @@ -1098,9 +1098,9 @@ def get_by_name_factory_options(cls): option = dict(name="of_window_size", default=21) option["type"] = int - option["help"] = ( - "For optical-flow: Optical flow window size to consider at each pyramid " - ) + option[ + "help" + ] = "For optical-flow: Optical flow window size to consider at each pyramid " "scale level" options.append(option) @@ -1127,9 +1127,9 @@ def int_list_func(s): option = dict(name="kf_init_frame_count", default="0") option["type"] = int - option["help"] = ( - "For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used." - ) + option[ + "help" + ] = "For Kalman filter: Number of frames to track with other tracker. 0 means no Kalman filters will be used." options.append(option) def float_list_func(s): From fa5a82bdbba4b4750ffdea58c211413a35f5b39f Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Wed, 18 Dec 2024 14:50:22 -0800 Subject: [PATCH 22/22] add test for connect single breaks --- tests/nn/test_tracking_integration.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/nn/test_tracking_integration.py b/tests/nn/test_tracking_integration.py index 21d34380f..4a601ac00 100644 --- a/tests/nn/test_tracking_integration.py +++ b/tests/nn/test_tracking_integration.py @@ -180,6 +180,26 @@ def test_kalman_tracker( ) assert len(labels.tracks) == 2 + # Test with 'tracking.post_connect_single_breaks': 0 + cli = ( + f"--tracking.tracker {tracker_name} " + "--tracking.max_tracking 1 --tracking.max_tracks 2 " + f"--tracking.similarity {similarity} " + f"--tracking.match {match} " + "--tracking.track_window 5 " + "--tracking.kf_init_frame_count 10 " + "--tracking.kf_node_indices 0,1 " + "--tracking.target_instance_count 2 " + "--tracking.post_connect_single_breaks 0 " + f"-o {tmpdir}/{tracker_name}_max_tracks_target_instance_count_single_breaks.slp " + f"{centered_pair_predictions_slp_path}" + ) + inference_cli(cli.split(" ")) + labels = sleap.load_file( + f"{tmpdir}/{tracker_name}_max_tracks_target_instance_count_single_breaks.slp" + ) + assert len(labels.tracks) == 2 + def test_simple_tracker(tmpdir, centered_pair_predictions_slp_path): cli = (