Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add the real world stuff #44

Merged
merged 1 commit into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions configs/benchmark/real_world_mug.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
name: real_world_mug
dataset_root: /home/beisner/code/multi_project/cam_ready_trainingdata
49 changes: 49 additions & 0 deletions configs/dataset/real_world_mug/mug_place.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Dataset Settings
train_dset:
demo_dset:
dataset_type: real_world_mug
dataset_root: ${benchmark.dataset_root}
dataset_indices:
- 0
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
start_anchor: True
min_num_cameras: 4
max_num_cameras: 4
num_points: 1024
action_class: ${task.action_class}
anchor_class: ${task.anchor_class}

val_dset:
demo_dset:
dataset_type: ${...train_dset.demo_dset.dataset_type}
dataset_root: ${benchmark.dataset_root}
dataset_indices:
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
start_anchor: ${...train_dset.demo_dset.start_anchor}
min_num_cameras: ${...train_dset.demo_dset.min_num_cameras}
max_num_cameras: ${...train_dset.demo_dset.max_num_cameras}
num_points: ${...train_dset.demo_dset.num_points}
action_class: ${task.action_class}
anchor_class: ${task.anchor_class}

test_dset:
demo_dset:
dataset_indices: null
num_demo: null
6 changes: 6 additions & 0 deletions taxpose/datasets/point_cloud_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ def make_dataset(
import taxpose.datasets.ndf as ndf

return ndf.NDFPointCloudDataset(cast(ndf.NDFPointCloudDatasetConfig, cfg))
elif cfg.dataset_type == "real_world_mug":
import taxpose.datasets.real_world_mug as real_world_mug

return real_world_mug.RealWorldMugPointCloudDataset(
cast(real_world_mug.RealWorldMugPointCloudDatasetConfig, cfg)
)
else:
raise NotImplementedError(f"Unknown dataset type: {cfg.dataset_type}")

Expand Down
103 changes: 103 additions & 0 deletions taxpose/datasets/real_world_mug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import functools
import pickle
from dataclasses import dataclass
from pathlib import Path
from typing import ClassVar, List, Optional

import numpy as np
import torch
from torch.utils.data import Dataset

from taxpose.datasets.base import PlacementPointCloudData


@dataclass
class RealWorldMugPointCloudDatasetConfig:
dataset_type: ClassVar[str] = "real_world_mug"
dataset_root: Path
dataset_indices: Optional[List[int]] = None
start_anchor: bool = False
min_num_cameras: int = 4
max_num_cameras: int = 4
num_points: int = 1024


class RealWorldMugPointCloudDataset(Dataset[PlacementPointCloudData]):
def __init__(self, cfg: RealWorldMugPointCloudDatasetConfig):
# Check that all the files in the dataset are there.

assert cfg.dataset_indices is not None and len(cfg.dataset_indices) > 0

self.filenames = [
Path(cfg.dataset_root) / f"{idx}.pkl" for idx in cfg.dataset_indices
]

# Make sure each file exists.
for filename in self.filenames:
assert filename.exists()

self.start_anchor = cfg.start_anchor
self.min_num_cameras = cfg.min_num_cameras
self.max_num_cameras = cfg.max_num_cameras
self.num_points = cfg.num_points

# @staticmethod
@functools.cache
def load_data(self, filename, start_anchor=False):
with open(filename, "rb") as f:
sensor_data = pickle.load(f)

points_action_np = sensor_data["end_action"]
if start_anchor:
points_anchor_np = sensor_data["start_anchor"]
else:
points_anchor_np = sensor_data["end_anchor"]

points_action_cls = sensor_data["end_action_cam_label"]
if start_anchor:
points_anchor_cls = sensor_data["start_anchor_cam_label"]
else:
points_anchor_cls = sensor_data["end_anchor_cam_label"]

if self.min_num_cameras < 4:
num_cameras = np.random.randint(
low=self.min_num_cameras, high=self.max_num_cameras + 1
)
sampled_camera_idxs = np.random.choice(4, num_cameras, replace=False)
action_valid_idxs = np.isin(points_action_cls, sampled_camera_idxs)
points_action_np = points_action_np[action_valid_idxs]
anchor_valid_idxs = np.isin(points_anchor_cls, sampled_camera_idxs)
points_anchor_np = points_anchor_np[anchor_valid_idxs]

points_action_mean_np = points_action_np.mean(axis=0)
points_action_np = points_action_np - points_action_mean_np
points_anchor_np = points_anchor_np - points_action_mean_np

points_action = torch.from_numpy(points_action_np).float().unsqueeze(0)
points_anchor = torch.from_numpy(points_anchor_np).float().unsqueeze(0)

if points_action.shape[1] < self.num_points:
n = self.num_points // points_action.shape[1] + 1
points_action = torch.cat([points_action] * n, dim=1)
if points_anchor.shape[1] < self.num_points:
n = self.num_points // points_anchor.shape[1] + 1
points_anchor = torch.cat([points_action] * n, dim=1)

return {
"points_action": points_action.numpy(),
"points_anchor": points_anchor.numpy(),
"action_symmetry_features": np.ones_like(points_action[:, :, :1]),
"anchor_symmetry_features": np.ones_like(points_anchor[:, :, :1]),
"action_symmetry_rgb": np.zeros_like(
points_action[:, :, :3], dtype=np.uint8
),
"anchor_symmetry_rgb": np.zeros_like(
points_anchor[:, :, :3], dtype=np.uint8
),
}

def __len__(self):
return len(self.filenames)

def __getitem__(self, idx):
return self.load_data(self.filenames[idx], self.start_anchor)
Loading