From 70a18551ea79a7a0beda4f4dd43a4418d4ea2668 Mon Sep 17 00:00:00 2001 From: Ben Eisner Date: Thu, 9 May 2024 19:13:23 -0400 Subject: [PATCH] add the real world stuff --- configs/benchmark/real_world_mug.yaml | 2 + configs/dataset/real_world_mug/mug_place.yaml | 49 +++++++++ taxpose/datasets/point_cloud_dataset.py | 6 + taxpose/datasets/real_world_mug.py | 103 ++++++++++++++++++ 4 files changed, 160 insertions(+) create mode 100644 configs/benchmark/real_world_mug.yaml create mode 100644 configs/dataset/real_world_mug/mug_place.yaml create mode 100644 taxpose/datasets/real_world_mug.py diff --git a/configs/benchmark/real_world_mug.yaml b/configs/benchmark/real_world_mug.yaml new file mode 100644 index 0000000..2de90d4 --- /dev/null +++ b/configs/benchmark/real_world_mug.yaml @@ -0,0 +1,2 @@ +name: real_world_mug +dataset_root: /home/beisner/code/multi_project/cam_ready_trainingdata diff --git a/configs/dataset/real_world_mug/mug_place.yaml b/configs/dataset/real_world_mug/mug_place.yaml new file mode 100644 index 0000000..0a9c37a --- /dev/null +++ b/configs/dataset/real_world_mug/mug_place.yaml @@ -0,0 +1,49 @@ +# Dataset Settings +train_dset: + demo_dset: + dataset_type: real_world_mug + dataset_root: ${benchmark.dataset_root} + dataset_indices: + - 0 + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 10 + start_anchor: True + min_num_cameras: 4 + max_num_cameras: 4 + num_points: 1024 + action_class: ${task.action_class} + anchor_class: ${task.anchor_class} + +val_dset: + demo_dset: + dataset_type: ${...train_dset.demo_dset.dataset_type} + dataset_root: ${benchmark.dataset_root} + dataset_indices: + - 11 + - 12 + - 13 + - 14 + - 15 + - 16 + - 17 + - 18 + - 19 + start_anchor: ${...train_dset.demo_dset.start_anchor} + min_num_cameras: ${...train_dset.demo_dset.min_num_cameras} + max_num_cameras: ${...train_dset.demo_dset.max_num_cameras} + num_points: ${...train_dset.demo_dset.num_points} + action_class: ${task.action_class} + anchor_class: ${task.anchor_class} + +test_dset: + demo_dset: + dataset_indices: null + num_demo: null diff --git a/taxpose/datasets/point_cloud_dataset.py b/taxpose/datasets/point_cloud_dataset.py index 5b4aef2..555162b 100644 --- a/taxpose/datasets/point_cloud_dataset.py +++ b/taxpose/datasets/point_cloud_dataset.py @@ -35,6 +35,12 @@ def make_dataset( import taxpose.datasets.ndf as ndf return ndf.NDFPointCloudDataset(cast(ndf.NDFPointCloudDatasetConfig, cfg)) + elif cfg.dataset_type == "real_world_mug": + import taxpose.datasets.real_world_mug as real_world_mug + + return real_world_mug.RealWorldMugPointCloudDataset( + cast(real_world_mug.RealWorldMugPointCloudDatasetConfig, cfg) + ) else: raise NotImplementedError(f"Unknown dataset type: {cfg.dataset_type}") diff --git a/taxpose/datasets/real_world_mug.py b/taxpose/datasets/real_world_mug.py new file mode 100644 index 0000000..5b98db1 --- /dev/null +++ b/taxpose/datasets/real_world_mug.py @@ -0,0 +1,103 @@ +import functools +import pickle +from dataclasses import dataclass +from pathlib import Path +from typing import ClassVar, List, Optional + +import numpy as np +import torch +from torch.utils.data import Dataset + +from taxpose.datasets.base import PlacementPointCloudData + + +@dataclass +class RealWorldMugPointCloudDatasetConfig: + dataset_type: ClassVar[str] = "real_world_mug" + dataset_root: Path + dataset_indices: Optional[List[int]] = None + start_anchor: bool = False + min_num_cameras: int = 4 + max_num_cameras: int = 4 + num_points: int = 1024 + + +class RealWorldMugPointCloudDataset(Dataset[PlacementPointCloudData]): + def __init__(self, cfg: RealWorldMugPointCloudDatasetConfig): + # Check that all the files in the dataset are there. + + assert cfg.dataset_indices is not None and len(cfg.dataset_indices) > 0 + + self.filenames = [ + Path(cfg.dataset_root) / f"{idx}.pkl" for idx in cfg.dataset_indices + ] + + # Make sure each file exists. + for filename in self.filenames: + assert filename.exists() + + self.start_anchor = cfg.start_anchor + self.min_num_cameras = cfg.min_num_cameras + self.max_num_cameras = cfg.max_num_cameras + self.num_points = cfg.num_points + + # @staticmethod + @functools.cache + def load_data(self, filename, start_anchor=False): + with open(filename, "rb") as f: + sensor_data = pickle.load(f) + + points_action_np = sensor_data["end_action"] + if start_anchor: + points_anchor_np = sensor_data["start_anchor"] + else: + points_anchor_np = sensor_data["end_anchor"] + + points_action_cls = sensor_data["end_action_cam_label"] + if start_anchor: + points_anchor_cls = sensor_data["start_anchor_cam_label"] + else: + points_anchor_cls = sensor_data["end_anchor_cam_label"] + + if self.min_num_cameras < 4: + num_cameras = np.random.randint( + low=self.min_num_cameras, high=self.max_num_cameras + 1 + ) + sampled_camera_idxs = np.random.choice(4, num_cameras, replace=False) + action_valid_idxs = np.isin(points_action_cls, sampled_camera_idxs) + points_action_np = points_action_np[action_valid_idxs] + anchor_valid_idxs = np.isin(points_anchor_cls, sampled_camera_idxs) + points_anchor_np = points_anchor_np[anchor_valid_idxs] + + points_action_mean_np = points_action_np.mean(axis=0) + points_action_np = points_action_np - points_action_mean_np + points_anchor_np = points_anchor_np - points_action_mean_np + + points_action = torch.from_numpy(points_action_np).float().unsqueeze(0) + points_anchor = torch.from_numpy(points_anchor_np).float().unsqueeze(0) + + if points_action.shape[1] < self.num_points: + n = self.num_points // points_action.shape[1] + 1 + points_action = torch.cat([points_action] * n, dim=1) + if points_anchor.shape[1] < self.num_points: + n = self.num_points // points_anchor.shape[1] + 1 + points_anchor = torch.cat([points_action] * n, dim=1) + + return { + "points_action": points_action.numpy(), + "points_anchor": points_anchor.numpy(), + "action_symmetry_features": np.ones_like(points_action[:, :, :1]), + "anchor_symmetry_features": np.ones_like(points_anchor[:, :, :1]), + "action_symmetry_rgb": np.zeros_like( + points_action[:, :, :3], dtype=np.uint8 + ), + "anchor_symmetry_rgb": np.zeros_like( + points_anchor[:, :, :3], dtype=np.uint8 + ), + } + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, idx): + return self.load_data(self.filenames[idx], self.start_anchor)