diff --git a/src/datachain/__init__.py b/src/datachain/__init__.py
index 98c7b8641..aff69941c 100644
--- a/src/datachain/__init__.py
+++ b/src/datachain/__init__.py
@@ -1,4 +1,4 @@
-from datachain.lib import func
+from datachain.lib import func, models
 from datachain.lib.data_model import DataModel, DataType, is_chain_type
 from datachain.lib.dc import C, Column, DataChain, Sys
 from datachain.lib.file import (
@@ -38,5 +38,6 @@
     "func",
     "is_chain_type",
     "metrics",
+    "models",
     "param",
 ]
diff --git a/src/datachain/lib/file.py b/src/datachain/lib/file.py
index 41cd6369f..caf19015a 100644
--- a/src/datachain/lib/file.py
+++ b/src/datachain/lib/file.py
@@ -20,9 +20,6 @@
 from pyarrow.dataset import dataset
 from pydantic import Field, field_validator
 
-if TYPE_CHECKING:
-    from typing_extensions import Self
-
 from datachain.client.fileslice import FileSlice
 from datachain.lib.data_model import DataModel
 from datachain.lib.utils import DataChainError
diff --git a/src/datachain/lib/models/__init__.py b/src/datachain/lib/models/__init__.py
new file mode 100644
index 000000000..b4335237b
--- /dev/null
+++ b/src/datachain/lib/models/__init__.py
@@ -0,0 +1,5 @@
+from . import yolo
+from .bbox import BBox
+from .pose import Pose, Pose3D
+
+__all__ = ["BBox", "Pose", "Pose3D", "yolo"]
diff --git a/src/datachain/lib/models/bbox.py b/src/datachain/lib/models/bbox.py
new file mode 100644
index 000000000..501ed6296
--- /dev/null
+++ b/src/datachain/lib/models/bbox.py
@@ -0,0 +1,51 @@
+from typing import Optional
+
+from pydantic import Field
+
+from datachain.lib.data_model import DataModel
+
+
+class BBox(DataModel):
+    """
+    A data model for representing bounding boxes.
+
+    Attributes:
+        title (str): The title of the bounding box.
+        x1 (float): The x-coordinate of the top-left corner of the bounding box.
+        y1 (float): The y-coordinate of the top-left corner of the bounding box.
+        x2 (float): The x-coordinate of the bottom-right corner of the bounding box.
+        y2 (float): The y-coordinate of the bottom-right corner of the bounding box.
+
+    The bounding box is defined by two points:
+        - (x1, y1): The top-left corner of the box.
+        - (x2, y2): The bottom-right corner of the box.
+    """
+
+    title: str = Field(default="")
+    x1: float = Field(default=0)
+    y1: float = Field(default=0)
+    x2: float = Field(default=0)
+    y2: float = Field(default=0)
+
+    @staticmethod
+    def from_xywh(bbox: list[float], title: Optional[str] = None) -> "BBox":
+        """
+        Converts a bounding box in (x, y, width, height) format
+        to a BBox data model instance.
+
+        Args:
+            bbox (list[float]): A bounding box, represented as a list
+                                of four floats [x, y, width, height].
+            title (str, optional): The title of the bounding box.
+                                   Defaults to an empty string.
+
+        Returns:
+            BBox: An instance of the BBox data model.
+
+        Raises:
+            ValueError: If `bbox` does not have exactly four elements.
+        """
+        if len(bbox) != 4:
+            raise ValueError(f"Bounding box must have 4 elements, got {len(bbox)}")
+        x, y, w, h = bbox
+        return BBox(title=title or "", x1=x, y1=y, x2=x + w, y2=y + h)
diff --git a/src/datachain/lib/models/pose.py b/src/datachain/lib/models/pose.py
new file mode 100644
index 000000000..5cb95a29b
--- /dev/null
+++ b/src/datachain/lib/models/pose.py
@@ -0,0 +1,37 @@
+from pydantic import Field
+
+from datachain.lib.data_model import DataModel
+
+
+class Pose(DataModel):
+    """
+    A data model for representing pose keypoints.
+
+    Attributes:
+        x (list[float]): The x-coordinates of the keypoints.
+        y (list[float]): The y-coordinates of the keypoints.
+
+    The keypoints are represented as lists of x and y coordinates, where each index
+    corresponds to a specific body part.
+    """
+
+    x: list[float] = Field(default=None)
+    y: list[float] = Field(default=None)
+
+
+class Pose3D(DataModel):
+    """
+    A data model for representing 3D pose keypoints.
+
+    Attributes:
+        x (list[float]): The x-coordinates of the keypoints.
+        y (list[float]): The y-coordinates of the keypoints.
+        visible (list[float]): The visibility of the keypoints.
+
+    The keypoints are represented as lists of x, y, and visibility values,
+    where each index corresponds to a specific body part.
+    """
+
+    x: list[float] = Field(default=None)
+    y: list[float] = Field(default=None)
+    visible: list[float] = Field(default=None)
diff --git a/src/datachain/lib/models/yolo.py b/src/datachain/lib/models/yolo.py
new file mode 100644
index 000000000..4231240a6
--- /dev/null
+++ b/src/datachain/lib/models/yolo.py
@@ -0,0 +1,39 @@
+"""
+This module contains the YOLO models.
+
+YOLO stands for "You Only Look Once", a family of object detection models that
+are designed to be fast and accurate. The models are trained to detect objects
+in images by dividing the image into a grid and predicting the bounding boxes
+and class probabilities for each grid cell.
+
+More information about YOLO can be found here:
+- https://pjreddie.com/darknet/yolo/
+- https://docs.ultralytics.com/
+"""
+
+
+class PoseBodyPart:
+    """
+    An enumeration of body parts for YOLO pose keypoints.
+
+    More information about the body parts can be found here:
+    https://docs.ultralytics.com/tasks/pose/
+    """
+
+    nose = 0
+    left_eye = 1
+    right_eye = 2
+    left_ear = 3
+    right_ear = 4
+    left_shoulder = 5
+    right_shoulder = 6
+    left_elbow = 7
+    right_elbow = 8
+    left_wrist = 9
+    right_wrist = 10
+    left_hip = 11
+    right_hip = 12
+    left_knee = 13
+    right_knee = 14
+    left_ankle = 15
+    right_ankle = 16
diff --git a/tests/unit/lib/test_models.py b/tests/unit/lib/test_models.py
new file mode 100644
index 000000000..c3ea2b463
--- /dev/null
+++ b/tests/unit/lib/test_models.py
@@ -0,0 +1,57 @@
+import pytest
+
+from datachain.lib import models
+
+
+def test_bbox():
+    bbox = models.BBox(title="BBox", x1=0.5, y1=1.5, x2=2.5, y2=3.5)
+    assert bbox.model_dump() == {
+        "title": "BBox",
+        "x1": 0.5,
+        "y1": 1.5,
+        "x2": 2.5,
+        "y2": 3.5,
+    }
+
+
+def test_bbox_from_xywh():
+    bbox = models.BBox.from_xywh([0.5, 1.5, 2.5, 3.5])
+    assert bbox.model_dump() == {"title": "", "x1": 0.5, "y1": 1.5, "x2": 3, "y2": 5}
+
+    bbox = models.BBox.from_xywh([0.5, 1.5, 2.5, 3.5], title="BBox")
+    assert bbox.model_dump() == {
+        "title": "BBox",
+        "x1": 0.5,
+        "y1": 1.5,
+        "x2": 3,
+        "y2": 5,
+    }
+
+
+def test_bbox_from_xywh_invalid_length():
+    with pytest.raises(ValueError, match="must have 4 elements, got 3"):
+        models.BBox.from_xywh([0.5, 1.5, 2.5])
+
+
+def test_pose():
+    x = [x * 0.5 for x in range(17)]
+    y = [y * 1.5 for y in range(17)]
+    pose = models.Pose(x=x, y=y)
+    assert pose.model_dump() == {"x": x, "y": y}
+    assert pose.x[models.yolo.PoseBodyPart.nose] == 0
+    assert pose.x[models.yolo.PoseBodyPart.left_eye] == 0.5
+    assert pose.x[models.yolo.PoseBodyPart.right_eye] == 1
+    assert pose.x[models.yolo.PoseBodyPart.left_ear] == 1.5
+    assert pose.x[models.yolo.PoseBodyPart.right_ear] == 2
+    assert pose.x[models.yolo.PoseBodyPart.left_shoulder] == 2.5
+    assert pose.x[models.yolo.PoseBodyPart.right_shoulder] == 3
+    assert pose.x[models.yolo.PoseBodyPart.left_elbow] == 3.5
+    assert pose.x[models.yolo.PoseBodyPart.right_elbow] == 4
+    assert pose.x[models.yolo.PoseBodyPart.left_wrist] == 4.5
+    assert pose.x[models.yolo.PoseBodyPart.right_wrist] == 5
+    assert pose.x[models.yolo.PoseBodyPart.left_hip] == 5.5
+    assert pose.x[models.yolo.PoseBodyPart.right_hip] == 6
+    assert pose.x[models.yolo.PoseBodyPart.left_knee] == 6.5
+    assert pose.x[models.yolo.PoseBodyPart.right_knee] == 7
+    assert pose.x[models.yolo.PoseBodyPart.left_ankle] == 7.5
+    assert pose.x[models.yolo.PoseBodyPart.right_ankle] == 8