-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathreplay_episodes.py
66 lines (57 loc) · 2.05 KB
/
replay_episodes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# Author: Jimmy Wu
# Date: October 2024
import argparse
import time
from itertools import count
from pathlib import Path
import cv2 as cv
from constants import POLICY_CONTROL_PERIOD
from episode_storage import EpisodeReader
from mujoco_env import MujocoEnv
def replay_episode(env, episode_dir, show_images=False, execute_obs=False):
    """Replay one recorded episode in ``env`` at the policy control rate.

    The environment is reset, the episode stored in ``episode_dir`` is loaded
    with ``EpisodeReader``, and each recorded (observation, action) pair is
    replayed in order. Step ``i`` is not sent before
    ``i * POLICY_CONTROL_PERIOD`` seconds have elapsed since replay began.
    Pairs come from ``zip``, so any surplus items in the longer of
    ``reader.observations`` / ``reader.actions`` are ignored.

    Args:
        env: Environment exposing ``reset()`` and ``step(...)``.
        episode_dir: Directory holding a single recorded episode.
        show_images: If True, show every 3-D observation array in its own
            OpenCV window before each step.
        execute_obs: If True, pass the recorded observation to ``env.step``
            instead of the recorded action.
    """
    env.reset()

    reader = EpisodeReader(episode_dir)
    print(f'Loaded episode from {episode_dir}')

    replay_start = time.time()
    recorded_steps = zip(reader.observations, reader.actions)
    for i, (observation, recorded_action) in enumerate(recorded_steps):
        # Deadlines are measured from replay_start (not from the previous
        # step), so a slow step does not push back all later steps.
        deadline = replay_start + i * POLICY_CONTROL_PERIOD
        while time.time() < deadline:
            time.sleep(0.0001)

        if show_images:
            _show_image_observations(observation)

        env.step(observation if execute_obs else recorded_action)


def _show_image_observations(observation):
    """Show each 3-D array in ``observation`` in a horizontally tiled OpenCV window.

    ``observation`` is presumably a dict of numpy arrays (only ``.items()`` and
    ``.ndim`` are used here). Windows are 640 px apart, left to right.
    """
    # Lazy generator keeps the ndim check and imshow calls interleaved.
    image_items = ((name, value) for name, value in observation.items() if value.ndim == 3)
    for column, (name, image) in enumerate(image_items):
        # Frames are converted RGB -> BGR, the channel order cv.imshow expects.
        cv.imshow(name, cv.cvtColor(image, cv.COLOR_RGB2BGR))
        cv.moveWindow(name, 640 * column, 0)
    cv.waitKey(1)
def main(args):
    """Replay every episode found directly under ``args.input_dir``.

    Each immediate subdirectory of ``args.input_dir`` is treated as one
    episode, and episodes are replayed in sorted path order. ``args.sim``
    selects ``MujocoEnv`` (with image rendering disabled); otherwise
    ``RealEnv`` is used. The environment is closed even if a replay raises.

    Args:
        args: Parsed CLI namespace with ``input_dir``, ``sim``,
            ``show_images`` and ``execute_obs`` attributes.
    """
    if args.sim:
        env = MujocoEnv(render_images=False)
    else:
        # Imported only on this branch, presumably so simulation-only runs do
        # not need the real-robot dependencies installed.
        from real_env import RealEnv
        env = RealEnv()

    try:
        root = Path(args.input_dir)
        episode_dirs = sorted(path for path in root.iterdir() if path.is_dir())
        for episode_dir in episode_dirs:
            replay_episode(
                env,
                episode_dir,
                show_images=args.show_images,
                execute_obs=args.execute_obs,
            )
            # To pause between episodes, uncomment:
            # input('Press <Enter> to continue...')
    finally:
        env.close()
if __name__ == '__main__':
    # Command-line entry point: parse flags and hand the namespace to main().
    # Help strings were added so `--help` explains each flag; flag names,
    # defaults and actions are unchanged.
    parser = argparse.ArgumentParser(
        description='Replay recorded episodes, either in the MuJoCo simulator (--sim) or in the real environment.',
    )
    parser.add_argument(
        '--input-dir',
        default='data/demos',
        help='directory whose subdirectories each hold one recorded episode (default: %(default)s)',
    )
    parser.add_argument(
        '--sim',
        action='store_true',
        help='replay in MujocoEnv instead of RealEnv',
    )
    parser.add_argument(
        '--show-images',
        action='store_true',
        help='display 3-D (image) observations in OpenCV windows during replay',
    )
    parser.add_argument(
        '--execute-obs',
        action='store_true',
        help='send recorded observations to env.step instead of recorded actions',
    )
    main(parser.parse_args())