diff --git a/handyrl/worker.py b/handyrl/worker.py index 312b40bb..a76d875b 100755 --- a/handyrl/worker.py +++ b/handyrl/worker.py @@ -84,8 +84,8 @@ def run(self): send_recv(self.conn, ('result', result)) -def make_worker_args(args, n_ga, gaid, base_wid, wid, conn): - return args, conn, base_wid + wid * n_ga + gaid +def make_worker_args(args, base_wid, wid, conn): + return args, conn, base_wid + wid def open_worker(args, conn, wid): @@ -94,25 +94,20 @@ def open_worker(args, conn, wid): class Gather(QueueCommunicator): - def __init__(self, args, conn, gaid): - print('started gather %d' % gaid) + def __init__(self, args, conn, gather_id, base_worker_id, num_workers): + print('started gather %d' % gather_id) super().__init__() - self.gather_id = gaid + self.gather_id = gather_id self.server_conn = conn self.args_queue = deque([]) self.data_map = {'model': {}} self.result_send_map = {} self.result_send_cnt = 0 - n_pro, n_ga = args['worker']['num_parallel'], args['worker']['num_gathers'] - - num_workers_per_gather = (n_pro // n_ga) + int(gaid < n_pro % n_ga) - base_wid = args['worker'].get('base_worker_id', 0) - worker_conns = open_multiprocessing_connections( - num_workers_per_gather, + num_workers, open_worker, - functools.partial(make_worker_args, args, n_ga, gaid, base_wid) + functools.partial(make_worker_args, args, base_worker_id) ) for conn in worker_conns: @@ -162,9 +157,25 @@ def run(self): self.result_send_cnt = 0 -def gather_loop(args, conn, gaid): +def gather_loop(args, conn, gather_id): + n_pro, n_ga = args['worker']['num_parallel'], args['worker']['num_gathers'] + n_pro_w = (n_pro // n_ga) + int(gather_id < n_pro % n_ga) + args['worker']['num_parallel_per_gather'] = n_pro_w + base_worker_id = 0 + + if conn is None: + # entry + conn = connect_socket_connection(args['worker']['server_address'], 9998) + conn.send(args['worker']) + args = conn.recv() + + if gather_id == 0: # call once at every machine + print(args) + prepare_env(args['env']) + base_worker_id = args['worker'].get('base_worker_id', 0) + try: - gather = Gather(args, conn, gaid) + gather = Gather(args, conn, gather_id, base_worker_id, n_pro_w) gather.run() finally: gather.shutdown() @@ -203,7 +214,7 @@ def worker_server(port): worker_args = conn.recv() print('accepted connection from %s!' % worker_args['address']) worker_args['base_worker_id'] = self.total_worker_count - self.total_worker_count += worker_args['num_parallel'] + self.total_worker_count += worker_args['num_parallel_per_gather'] args = copy.deepcopy(self.args) args['worker'] = worker_args conn.send(args) @@ -228,18 +239,8 @@ def run(self): process = [] try: for i in range(self.args['num_gathers']): - # entry - conn = connect_socket_connection(self.args['server_address'], 9998) - conn.send(self.args) - args = conn.recv() - - if i == 0: # call once at every machine - print(args) - prepare_env(args['env']) - - p = mp.Process(target=gather_loop, args=(args, conn, i)) + p = mp.Process(target=gather_loop, args=({'worker': self.args}, None, i)) p.start() - conn.close() process.append(p) while True: time.sleep(100)