# test_solve_manipulate.py
import numpy as np
import argparse
import os
import time
import yaml, json
from helper_functions import solve_manipulate, generate_subtasks, load_prior, init_dirs, NumpyArrayEncoder, distort_traj
import warnings
warnings.simplefilter("ignore")
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options for this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the arguments are taken from ``sys.argv[1:]``,
            which is the behavior of a plain ``parse_args()`` call.

    Returns:
        The parsed namespace with the attributes ``gui`` (bool flag),
        ``name`` (required string) and ``dt`` (float, default ``1e-4``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--gui", action="store_true")
    parser.add_argument("--name", required=True, help="Name of task")
    parser.add_argument(
        "--dt", type=float, default=1e-4, help="Simulation timestep"
    )
    return parser.parse_args(argv)
# --- Script body: configure the scene, build subtasks, solve "0_manipulate" ---

args = parse_args()

# NOTE(review): yaml.FullLoader can build more than plain data types. That is
# acceptable for a trusted local config; switch to yaml.safe_load if the config
# only ever contains plain mappings/lists/scalars.
with open("./config.yaml", "r") as cfg_file:
    cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)

# Add objects to sim
args.objects = ["025_mug"]
# One row per object; presumably position (x, y, z) followed by an orientation
# quaternion -- TODO confirm against load_prior / the simulator.
args.object_positions = np.array([[0.4, -0.3, 0.65, 0.0, 0.0, 1.0, 0.0]])
object_of_interest = args.objects[0]

# Load the prior trajectory and split it into named subtasks.
compressed_traj, clean_traj = load_prior(args, cfg, squish_mu=15e-3)
print(compressed_traj)
subtasks = generate_subtasks(compressed_traj)
print(subtasks)

subtask_name = "0_manipulate"
subtask = subtasks[subtask_name]
# Optional experiment: perturb the trajectory at its explore indices.
# subtask["trajectory"] = distort_traj(subtask["trajectory"], indices=subtask["explore_indices"], mean=0.0, var=0.05)

# Output folder name: "<script name without .py>_<YYYYmmdd>_<HHMMSS>".
timestr = time.strftime('_%Y%m%d_%H%M%S')
savefolder = os.path.basename(__file__).replace(".py", "") + timestr

# save config used
savepath = os.path.join(cfg['save_data']['ROLLOUTS'], savefolder)
init_dirs([savepath])
# Write inside `with` blocks so both files are flushed and closed before the
# solver starts.
with open(os.path.join(savepath, "cfg.yaml"), "w") as cfg_out:
    yaml.safe_dump(cfg, cfg_out)
with open(os.path.join(savepath, "args.json"), "w") as args_out:
    json.dump(vars(args), args_out, cls=NumpyArrayEncoder, indent=2)

# Approach trajectory with its final row dropped.
approach_traj = np.array(subtasks["0_approach"]["trajectory"])[:-1, :]
print(approach_traj)

solve_manipulate(cfg, args, subtask, subtask_name, object_of_interest, savefolder, approach_traj)