Skip to content

Commit

Permalink
added motion dataset, #4
Browse files Browse the repository at this point in the history
  • Loading branch information
belledon committed Dec 21, 2022
1 parent d8d273e commit c3467d6
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 195 deletions.
3 changes: 2 additions & 1 deletion cusanus/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .field import FieldDataset, write_ffcv, load_ffcv
from .geometry import SphericalGeometryDataset, MeshGeometryDataset
from .geometry import write_ffcv
from .motion import KinematicsFieldDataset
64 changes: 20 additions & 44 deletions cusanus/datasets/geometry.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,20 @@
import torch
import numpy as np
from copy import deepcopy
from torch.utils.data import Dataset
from ffcv.writer import DatasetWriter
from ffcv.fields import NDArrayField
from ffcv.fields.decoders import NDArrayDecoder
from ffcv.loader import Loader, OrderOption
from ffcv.transforms import (Convert, NormalizeImage, ToTensor,
ToDevice)

import trimesh
import numpy as np
from abc import ABC

from cusanus.pytypes import *
from cusanus.datasets import FieldDataset
from cusanus.utils import grids_along_depth
from cusanus.utils.meshes import sheered_rect_prism, ramp_mesh

_pipe = [NDArrayDecoder(),
ToTensor(),
Convert(torch.float32)]
_pipelines = {'qs': _pipe, 'ys': _pipe}

def write_ffcv(d:Dataset, k:int, path:str):
qshape = (k, 3)
yshape = (k, 1)
fields = {
'qs': NDArrayField(dtype = np.dtype('float32'),
shape = qshape),
'ys': NDArrayField(dtype = np.dtype('float32'),
shape = yshape),
}
writer = DatasetWriter(path, fields)
writer.from_indexed_dataset(d)

def load_ffcv(p:str, device, **kwargs):
ps = {}
for k in ['qs', 'ys']:
ps[k] = deepcopy(_pipe)
if not device is None:
ps[k].append(ToDevice(device))
return Loader(p, pipelines = ps, order = OrderOption(3),
**kwargs)

class OccupancyFieldDataset(Dataset):
ffcv_pipelines = _pipelines

def write_ffcv(self, path:str):
write_ffcv(self,self.k_queries,path)

class OccupancyFieldDataset(FieldDataset, ABC):
@property
def qsize(self):
return 3
@property
def ysize(self):
return 1

class SphericalGeometryDataset(OccupancyFieldDataset):

Expand All @@ -55,11 +23,15 @@ def __init__(self, n_shapes:int = 1000, k_queries:int = 100,
r_min:float = 0.1, r_max:float = 0.8,
sigma:float=3.0) -> None:
self.n_shapes = n_shapes
self.k_queries = k_queries
self._k_queries = k_queries
self.r_min = r_min
self.r_max = r_max
self.sigma = sigma

@property
def k_queries(self):
return self._k_queries

def __len__(self):
return self.n_shapes

Expand Down Expand Up @@ -88,14 +60,18 @@ def __init__(self, n_shapes:int = 1000, k_queries:int = 100,
obs_extents:List[float] = [1.5,1.5,2.5],
ramp_extents:List[float]= [4.0, 1.5, .1]) -> None:
self.n_shapes = n_shapes
self.k_queries = k_queries
self._k_queries = k_queries
self.delta_y = delta_y
self.delta_size = delta_size
self.qsigma = qsigma
self.axis = np.array([0., 1., 0.])
self.obs_extents = obs_extents
self.ramp_extents = ramp_extents

@property
def k_queries(self):
return self._k_queries

def __len__(self):
return self.n_shapes

Expand Down
2 changes: 2 additions & 0 deletions scripts/configs/motion_field_dataset.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
train:
n_scenes: 10000
74 changes: 0 additions & 74 deletions scripts/write_mesh_dataset.py

This file was deleted.

31 changes: 31 additions & 0 deletions scripts/write_motion_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#!/usr/bin/env python

import os
import yaml
import argparse
import torch

from cusanus.datasets import write_ffcv, KinematicsFieldDataset

name = 'motion_field'

def main():
parser = argparse.ArgumentParser(
description = 'Generates occupancy field dataset via ffcv',
formatter_class = argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument('--num_workers', type = int,
help = 'Number of write workers',
default = -1)
args = parser.parse_args()


with open(f"/project/scripts/configs/{name}_dataset.yaml", 'r') as file:
config = yaml.safe_load(file)

d = KinematicsFieldDataset(**config['train'])
dpath = f"/spaths/datasets/{name}_train_dataset.beton"
d.write_ffcv(dpath, num_workers = args.num_workers)

if __name__ == '__main__':
main()
76 changes: 0 additions & 76 deletions scripts/write_spherical_geo_dataset.py

This file was deleted.

0 comments on commit c3467d6

Please sign in to comment.