remote_trainer_with_communicator.py
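"""Remote trainer: sends rllib rollouts and policy weights to an external
stable_baselines server over HTTP and loads the trained weights back into
the rllib policies (summary inferred from the code below)."""
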
import os
import pickle
import uuid
from copy import deepcopy
from typing import List

import ray
from ray.rllib.agents import with_common_config
from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG as config_ppo
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.rollout_ops import (ParallelRollouts, ConcatBatches,
                                             StandardizeFields, SelectExperiences)
from ray.rllib.policy.sample_batch import SampleBatch, DEFAULT_POLICY_ID, MultiAgentBatch

from ap_rllib.helpers import (filter_pickleable, dict_get_any_value,
                              save_gym_space, unlink_ignore_error)
from frankenstein.remote_communicator import RemoteHTTPPickleCommunicator
from gym_compete_rllib.load_gym_compete_policy import nets_to_weights, load_weights_from_vars


def rllib_samples_to_dict(samples):
    """Convert rllib MultiAgentBatch to a dict."""
    samples = samples.policy_batches
    samples = {x: dict(y) for x, y in samples.items()}
    return samples
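
# For illustration: a MultiAgentBatch whose .policy_batches maps
# "player_1" -> SampleBatch(obs=..., actions=...) becomes
# {"player_1": {"obs": ..., "actions": ...}} (names here are hypothetical).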


def train_external(policies, samples, config):
    """Train via a remote stable_baselines server (HTTP), return infos."""
    infos = {}
    answer_paths = {}
    data_paths = {}

    # lr == 0 means "do nothing" (used by make_video.py)
    if config['lr'] == 0:
        return {}

    samples_dict = rllib_samples_to_dict(samples)

    # only training policies with data
    to_train = set(policies)
    to_train = to_train.intersection(samples_dict.keys())

    config_orig = deepcopy(config)
    config = filter_pickleable(config_orig)

    # gym spaces are not directly picklable; serialize them into the config to send
    p = dict_get_any_value(config_orig['multiagent']['policies'])
    print(config_orig['multiagent']['policies'])
    obs_space, act_space = p[1], p[2]
    config['_observation_space'] = save_gym_space(obs_space)
    config['_action_space'] = save_gym_space(act_space)
    communicator = RemoteHTTPPickleCommunicator(config['http_remote_port'])

    # requesting training for all policies with data
    for policy in to_train:
        # only training the requested policies
        if policy not in config['multiagent']['policies_to_train']:
            continue

        # identifier for this run
        run_uid = config['run_uid']
        # identifier for the run + policy
        run_policy_uid = f"{run_uid}_policy_{policy}"
        # unique identifier for the current step
        iteration = str(uuid.uuid1())
        # identifier for run + policy + current step
        run_policy_step_uid = f"{run_uid}_policy_{policy}_step{iteration}"
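        # e.g. "run1_policy_player_1_step<uuid1-hex>" for run_uid "run1"
        # and policy "player_1" (values illustrative)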
        # data to pickle
        data_policy = {'rollouts': samples_dict[policy],
                       'weights': nets_to_weights(policies[policy].model._nets),
                       'config': config}

        # paths for the data/answer pickle files
        tmp_dir = os.getcwd()  # config['tmp_dir']
        data_path = os.path.join(tmp_dir, run_policy_step_uid + '.pkl')
        answer_path = os.path.join(tmp_dir, run_policy_step_uid + '.answer.pkl')
        data_paths[policy] = data_path
        answer_paths[policy] = answer_path

        # saving pickle data (with-statement ensures the file is closed)
        with open(data_path, 'wb') as f:
            pickle.dump(data_policy, f)

        # submitting the job to the training server
        communicator.submit_job(client_id=run_policy_uid, data_path=data_path,
                                answer_path=answer_path, data=data_policy)
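
        # From here on, the server side is expected to read data_path,
        # train, and write its reply to answer_path; get_result() below
        # waits for that file (protocol inferred from this file's usage).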

    def load_weights(model, weights):
        """Load weights into an rllib model (policy and value nets)."""
        load_weights_from_vars(weights, model._nets['value'], model._nets['policy'])

    # obtaining the resulting policies; iterating only over policies that
    # were actually submitted above (to_train may contain skipped ones)
    for policy in answer_paths:
        answer_path = answer_paths[policy]
        data_path = data_paths[policy]
        weights_info = communicator.get_result(answer_path)

        # checking that remote training succeeded
        if weights_info[0] is not True:
            raise Exception(weights_info[1])
        weights = weights_info[1]['weights']
        info = weights_info[1]['info']

        # loading the new weights into the model
        load_weights(policies[policy].model, weights)

        # removing pickle files to save space
        unlink_ignore_error(data_path)
        unlink_ignore_error(answer_path)

        # collecting the training info
        infos[policy] = dict(info)

    return infos
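
# Illustrative usage (hypothetical objects, not defined in this file):
#
#   policies = {"player_1": trainer.get_policy("player_1")}
#   samples = ...            # MultiAgentBatch of collected rollouts
#   infos = train_external(policies, samples, trainer.config)
#
# The config is expected to carry 'run_uid', 'http_remote_port', 'lr'
# and the usual rllib 'multiagent' keys used above.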