@classmethod
def make_video_for_image_list(
    cls, image_dir, filenames
) -> Tuple[List[Video], List[int], List[int]]:
    """Create Video objects from frame images, grouped by image shape.

    Images are grouped by their (height, width) so that each resulting video
    only contains frames of a single size. If the shapes cannot be determined
    (or the grouped videos cannot be created), a single video containing all
    images is created instead.

    Args:
        image_dir: Directory where images are stored.
        filenames: Iterable of image filenames. Only the basename of each
            entry is used; it is re-rooted at `image_dir`.

    Returns:
        Tuple containing:
            - List of Video objects created from the images.
            - List of video indices for each image (index into the first list),
              in the same order as `filenames`.
            - List of frame indices for each image within its video, in the
              same order as `filenames`.
    """

    # The image filenames in the csv may not match where the user has them,
    # so we re-root every image at the directory that contains the csv.
    def fix_img_path(img_dir, img_filename):
        # Normalize Windows separators *before* taking the basename: on POSIX
        # a backslash is not a path separator, so `Path(...).name` would
        # otherwise return the entire "dir\\sub\\img.png" string.
        basename = Path(img_filename.replace("\\", "/")).name
        return (Path(img_dir) / basename).as_posix().replace("\\", "/")

    def get_shape(filename):
        # Imported lazily so that cv2 is only needed when images are read.
        import cv2

        img = cv2.imread(filename)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"Unable to read image: {filename}")
        return img.shape[:2]

    # Fix image paths to match the CSV directory.
    filenames = [fix_img_path(image_dir, f) for f in filenames]

    try:
        # Group by shape while recording, in input order, which video and
        # which frame within that video each image ends up in. Video indices
        # follow the order in which each shape is first encountered.
        groups = {}  # shape -> (video index, list of filenames)
        video_inds = []
        frame_inds = []
        for filename in filenames:
            shape = get_shape(filename)
            if shape not in groups:
                groups[shape] = (len(groups), [])
            video_ind, img_fns = groups[shape]
            video_inds.append(video_ind)
            frame_inds.append(len(img_fns))
            img_fns.append(filename)

        # Create one video per shape group.
        videos = [
            Video.from_image_filenames(img_fns, height=shape[0], width=shape[1])
            for shape, (_, img_fns) in groups.items()
        ]
    except Exception:
        # Best-effort fallback: if we couldn't group by shape (e.g., cv2 is
        # unavailable or an image is unreadable), create a single video for
        # all images. KeyboardInterrupt/SystemExit are deliberately not caught.
        videos = [Video.from_image_filenames(filenames)]
        video_inds = [0] * len(filenames)
        frame_inds = list(range(len(filenames)))

    return videos, video_inds, frame_inds
+ video = full_video + # Extract "0123" from "path/img0123.png" as original frame index. frame_idx_match = re.search("(?<=img)(\\d+)(?=\\.png)", img_files[i]) @@ -174,7 +224,9 @@ def read_frames( f"Unable to determine frame index for image {img_files[i]}" ) else: - frame_idx = i + # Get from pregrouped list. + video = videos[video_inds[i]] + frame_idx = frame_inds[i] instances = [] if is_multianimal: