Skip to content

Commit

Permalink
Better parallel training and RNG fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
danijar committed Feb 22, 2023
1 parent 688e04f commit 84ecf19
Show file tree
Hide file tree
Showing 18 changed files with 114 additions and 79 deletions.
2 changes: 1 addition & 1 deletion dreamerv3/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ ENV GCS_READ_REQUEST_TIMEOUT_SECS=300
ENV GCS_WRITE_REQUEST_TIMEOUT_SECS=600

# Embodied
RUN pip3 install numpy cloudpickle ruamel.yaml rich
RUN pip3 install numpy cloudpickle ruamel.yaml rich zmq msgpack
COPY . /embodied
RUN chown -R 1000:root /embodied && chmod -R 775 /embodied

Expand Down
5 changes: 0 additions & 5 deletions dreamerv3/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,6 @@ def report(self, data):
report.update({f'expl_{k}': v for k, v in mets.items()})
return report

def dataset(self, generator):
return embodied.Prefetch(
sources=[generator] * self.config.batch_size,
workers=self.config.data_loaders, prefetch=4)

def preprocess(self, obs):
obs = obs.copy()
for key, value in obs.items():
Expand Down
17 changes: 13 additions & 4 deletions dreamerv3/configs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,21 @@ defaults:
replay: uniform
replay_size: 1e6
replay_online: False
jax: {platform: gpu, jit: True, precision: float16, prealloc: True, debug_nans: False, logical_cpus: 0, debug: False, policy_devices: [0], train_devices: [0]}
eval_dir: ''
filter: '.*'

# Loop
jax:
platform: gpu
jit: True
precision: float16
prealloc: True
debug_nans: False
logical_cpus: 0
debug: False
policy_devices: [0]
train_devices: [0]
metrics_every: 10

run:
script: train
steps: 1e10
Expand All @@ -31,12 +41,11 @@ defaults:
log_keys_mean: '(log_entropy)'
log_keys_max: '^$'
from_checkpoint: ''
sync_every: 200
sync_every: 10
# actor_addr: 'tcp://127.0.0.1:5551'
actor_addr: 'ipc:///tmp/5551'
actor_batch: 32

# Envs
envs: {amount: 4, parallel: process, length: 0, reset: True, restart: True, discretize: 0, checks: False}
wrapper: {length: 0, reset: True, discretize: 0, checks: False}
env:
Expand Down
2 changes: 1 addition & 1 deletion dreamerv3/embodied/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .parallel import Parallel
from .timer import Timer
from .worker import Worker
from .prefetch import Prefetch
from .batcher import Batcher
from .metrics import Metrics
from .uuid import uuid

Expand Down
3 changes: 1 addition & 2 deletions dreamerv3/embodied/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@ def __init__(self, obs_space, act_space, step, config):
pass

def dataset(self, generator_fn):
# TODO: Go from iterable to iterable instead.
raise NotImplementedError(
'dataset(generator_fn) -> iterable')
'dataset(generator_fn) -> generator_fn')

def policy(self, obs, state=None, mode='train'):
raise NotImplementedError(
Expand Down
dreamerv3/embodied/core/prefetch.py → dreamerv3/embodied/core/batcher.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
import queue as queuelib
import sys
import threading
import time
import traceback

import numpy as np


class Prefetch:
"""Implements zip() with multi-threaded prefetching. The sources are expected
to yield dicts of Numpy arrays and the iterator will return dicts of batched
Numpy arrays."""
class Batcher:

def __init__(self, sources, workers=0, prefetch=4):
self.workers = workers
def __init__(
self, sources, workers=0, postprocess=None,
prefetch_source=4, prefetch_batch=2):
self._workers = workers
self._postprocess = postprocess
if workers:
# Round-robin assign sources to workers.
self._running = True
self._threads = []
self._queues = []
assignments = [([], []) for _ in range(workers)]
for index, source in enumerate(sources):
queue = queuelib.Queue(prefetch)
queue = queuelib.Queue(prefetch_source)
self._queues.append(queue)
assignments[index % workers][0].append(source)
assignments[index % workers][1].append(queue)
Expand All @@ -29,7 +30,7 @@ def __init__(self, sources, workers=0, prefetch=4):
target=self._creator, args=args, daemon=True)
creator.start()
self._threads.append(creator)
self._batches = queuelib.Queue(prefetch)
self._batches = queuelib.Queue(prefetch_batch)
batcher = threading.Thread(
target=self._batcher, args=(self._queues, self._batches),
daemon=True)
Expand All @@ -40,7 +41,7 @@ def __init__(self, sources, workers=0, prefetch=4):
self._once = False

def close(self):
if self.workers:
if self._workers:
self._running = False
for thread in self._threads:
thread.close()
Expand All @@ -53,8 +54,11 @@ def __iter__(self):
self._once = True
return self

def __call__(self):
return self.__iter__()

def __next__(self):
if self.workers:
if self._workers:
batch = self._batches.get()
else:
elems = [next(x) for x in self._iterators]
Expand All @@ -67,8 +71,14 @@ def _creator(self, sources, outputs):
try:
iterators = [source() for source in sources]
while self._running:
waiting = True
for iterator, queue in zip(iterators, outputs):
if queue.full():
continue
queue.put(next(iterator))
waiting = False
if waiting:
time.sleep(0.001)
except Exception as e:
e.stacktrace = ''.join(traceback.format_exception(*sys.exc_info()))
outputs[0].put(e)
Expand All @@ -82,7 +92,9 @@ def _batcher(self, sources, output):
if isinstance(elem, Exception):
raise elem
batch = {k: np.stack([x[k] for x in elems], 0) for k in elems[0]}
output.put(batch)
if self._postprocess:
batch = self._postprocess(batch)
output.put(batch) # Will wait here if the queue is full.
except Exception as e:
e.stacktrace = ''.join(traceback.format_exception(*sys.exc_info()))
output.put(e)
Expand Down
10 changes: 9 additions & 1 deletion dreamerv3/embodied/core/distr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import ctypes
import multiprocessing
import sys
import threading
import time
Expand Down Expand Up @@ -161,12 +160,17 @@ def terminate(self):
class Process:

lock = None
initializers = []

def __init__(self, fn, *args, name=None):
import multiprocessing
import cloudpickle
mp = multiprocessing.get_context('spawn')
if Process.lock is None:
Process.lock = mp.Lock()
name = name or fn.__name__
initializers = cloudpickle.dumps(self.initializers)
args = (initializers,) + args
self._process = mp.Process(
target=self._wrapper, args=(Process.lock, fn, *args),
name=name)
Expand All @@ -188,6 +192,10 @@ def terminate(self):

def _wrapper(self, lock, fn, *args):
try:
import cloudpickle
initializers, *args = args
for initializer in cloudpickle.loads(initializers):
initializer()
fn(*args)
except Exception:
with lock:
Expand Down
1 change: 0 additions & 1 deletion dreamerv3/embodied/core/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,6 @@ def __init__(self, run_name=None, resume_id=None, config=None, prefix=None):
self._setup(run_name, resume_id, config)

def __call__(self, summaries):
timestamp = datetime.datetime.now().timestamp()
bystep = collections.defaultdict(dict)
for step, name, value in summaries:
if len(value.shape) == 0 and self._pattern.search(name):
Expand Down
6 changes: 4 additions & 2 deletions dreamerv3/embodied/run/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def learner(step, agent, replay, logger, timer, args):
should_log = embodied.when.Clock(args.log_every)
should_save = embodied.when.Clock(args.save_every)
should_sync = embodied.when.Every(args.sync_every)
updates = embodied.Counter()

checkpoint = embodied.Checkpoint(logdir / 'checkpoint.ckpt')
checkpoint.step = step
Expand All @@ -93,16 +94,17 @@ def learner(step, agent, replay, logger, timer, args):
checkpoint.load(args.from_checkpoint)
checkpoint.load_or_save()

dataset = iter(agent.dataset(replay.dataset))
dataset = agent.dataset(replay.dataset)
state = None
stats = dict(last_time=time.time(), last_step=int(step), batch_entries=0)
while True:
batch = next(dataset)
outs, state, mets = agent.train(batch, state)
metrics.add(mets)
updates.increment()
stats['batch_entries'] += batch['is_first'].size

if should_sync(step):
if should_sync(updates):
agent.sync()

if should_log():
Expand Down
6 changes: 4 additions & 2 deletions dreamerv3/embodied/run/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def train(agent, env, replay, logger, args):
should_save = embodied.when.Clock(args.save_every)
should_sync = embodied.when.Every(args.sync_every)
step = logger.step
updates = embodied.Counter()
metrics = embodied.Metrics()
print('Observation space:', embodied.format(env.obs_space), sep='\n')
print('Action space:', embodied.format(env.act_space), sep='\n')
Expand Down Expand Up @@ -65,7 +66,7 @@ def per_episode(ep):
logger.add(metrics.result())
logger.write()

dataset = iter(agent.dataset(replay.dataset))
dataset = agent.dataset(replay.dataset)
state = [None] # To be writable from train step function below.
batch = [None]
def train_step(tran, worker):
Expand All @@ -76,7 +77,8 @@ def train_step(tran, worker):
metrics.add(mets, prefix='train')
if 'priority' in outs:
replay.prioritize(outs['key'], outs['priority'])
if should_sync(step):
updates.increment()
if should_sync(updates):
agent.sync()
if should_log(step):
agg = metrics.result()
Expand Down
8 changes: 5 additions & 3 deletions dreamerv3/embodied/run/train_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def train_eval(
should_eval = embodied.when.Every(args.eval_every, args.eval_initial)
should_sync = embodied.when.Every(args.sync_every)
step = logger.step
updates = embodied.Counter()
metrics = embodied.Metrics()
print('Observation space:', embodied.format(train_env.obs_space), sep='\n')
print('Action space:', embodied.format(train_env.act_space), sep='\n')
Expand Down Expand Up @@ -70,8 +71,8 @@ def per_episode(ep, mode):
logger.add(metrics.result())
logger.write()

dataset_train = iter(agent.dataset(train_replay.dataset))
dataset_eval = iter(agent.dataset(eval_replay.dataset))
dataset_train = agent.dataset(train_replay.dataset)
dataset_eval = agent.dataset(eval_replay.dataset)
state = [None] # To be writable from train step function below.
batch = [None]
def train_step(tran, worker):
Expand All @@ -82,7 +83,8 @@ def train_step(tran, worker):
metrics.add(mets, prefix='train')
if 'priority' in outs:
train_replay.prioritize(outs['key'], outs['priority'])
if should_sync(step):
updates.inc()
if should_sync(updates):
agent.sync()
if should_log(step):
logger.add(metrics.result())
Expand Down
8 changes: 5 additions & 3 deletions dreamerv3/embodied/run/train_holdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def train_holdout(agent, env, train_replay, eval_replay, logger, args):
should_save = embodied.when.Clock(args.save_every)
should_sync = embodied.when.Every(args.sync_every)
step = logger.step
updates = embodied.Counter()
metrics = embodied.Metrics()
print('Observation space:', embodied.format(env.obs_space), sep='\n')
print('Action space:', embodied.format(env.act_space), sep='\n')
Expand Down Expand Up @@ -70,8 +71,8 @@ def per_episode(ep):
logger.add(metrics.result())
logger.write()

dataset_train = iter(agent.dataset(train_replay.dataset))
dataset_eval = iter(agent.dataset(eval_replay.dataset))
dataset_train = agent.dataset(train_replay.dataset)
dataset_eval = agent.dataset(eval_replay.dataset)
state = [None] # To be writable from train step function below.
batch = [None]
def train_step(tran, worker):
Expand All @@ -82,7 +83,8 @@ def train_step(tran, worker):
metrics.add(mets, prefix='train')
if 'priority' in outs:
train_replay.prioritize(outs['key'], outs['priority'])
if should_sync(step):
updates.increment()
if should_sync(updates):
agent.sync()
if should_log(step):
logger.add(metrics.result())
Expand Down
6 changes: 4 additions & 2 deletions dreamerv3/embodied/run/train_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def train_save(agent, env, replay, logger, args):
should_save = embodied.when.Clock(args.save_every)
should_sync = embodied.when.Every(args.sync_every)
step = logger.step
updates = embodied.Counter()
metrics = embodied.Metrics()
print('Observation space:', embodied.format(env.obs_space), sep='\n')
print('Action space:', embodied.format(env.act_space), sep='\n')
Expand Down Expand Up @@ -84,7 +85,7 @@ def save(ep):
logger.add(metrics.result())
logger.write()

dataset = iter(agent.dataset(replay.dataset))
dataset = agent.dataset(replay.dataset)
state = [None] # To be writable from train step function below.
batch = [None]
def train_step(tran, worker):
Expand All @@ -95,7 +96,8 @@ def train_step(tran, worker):
metrics.add(mets, prefix='train')
if 'priority' in outs:
replay.prioritize(outs['key'], outs['priority'])
if should_sync(step):
updates.increment()
if should_sync(updates):
agent.sync()
if should_log(step):
agg = metrics.result()
Expand Down
Loading

0 comments on commit 84ecf19

Please sign in to comment.