-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
22 lines (17 loc) · 835 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from parallelAgent import MADDPG
from environment import Environment
import constants as C
import torch.multiprocessing
def parse_layers(argv=None):
    """Return the layer count requested on the command line.

    Replaces the previously hard-coded ``layers = 3`` (and the commented-out
    ``input()`` prompt) with an optional ``--layers`` flag. Running the script
    with no arguments keeps the old default of 3.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The requested number of layers, a positive int.

    Raises:
        SystemExit: if the value is missing, not an integer, or < 1
            (argparse prints a usage message and exits).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Run parallel MADDPG training episodes."
    )
    parser.add_argument(
        "--layers",
        type=int,
        default=3,
        help="number of layers; also used as the number of agents (default: 3)",
    )
    args = parser.parse_args(argv)
    if args.layers < 1:
        parser.error("--layers must be a positive integer")
    return args.layers


def main(argv=None):
    """Build the environment and MADDPG controller, then run training.

    Args:
        argv: Optional argument list forwarded to ``parse_layers``.
    """
    layers = parse_layers(argv)

    # The agent count lives in the shared constants module and is read below
    # (input_dims). Set it before constructing anything that might consult it.
    # NOTE(review): with the 'spawn' start method, worker processes re-import
    # `constants`, so this runtime assignment is NOT visible inside workers
    # unless it is passed to them explicitly -- confirm workers don't read it.
    C.NUM_AGENTS = layers
    env = Environment(layers)

    controller = MADDPG(
        env,
        num_agents=layers,
        alpha=C.ALPHA,
        beta=C.BETA,
        tau=C.TAU,
        input_dims=[C.NUM_AGENTS],
        n_actions=C.N_ACTIONS,
        hd1_dims=C.H1_DIMS,
        hd2_dims=C.H2_DIMS,
        mem_size=C.BUF_LEN,
        gamma=C.GAMMA,
        batch_size=C.BATCH_SIZE,
    )
    controller.run_parallel_episodes(
        max_episodes=C.MAX_EPISODES, max_steps=C.MAX_STEPS
    )


if __name__ == '__main__':
    # freeze_support() must be called straight after the main guard; it is a
    # no-op except in frozen Windows executables.
    torch.multiprocessing.freeze_support()
    # Select the start method once, before any worker process is created.
    torch.multiprocessing.set_start_method('spawn')
    main()