launch_script.py
#!/usr/bin/env python
import argparse
import multiprocessing as mp

import torch

from config import Config
from train.game_generator import GameGenerator
from train.policy_improver import PolicyImprover
from train.self_challenge import Champion
from train.train import save_trained, load_model

parser = argparse.ArgumentParser(description='Launcher for distributed Chess trainer')
parser.add_argument('--batch-size', type=int, default=25, help='input batch size for training (default: 25)')
parser.add_argument('--epochs', type=int, default=1, help='number of epochs to train (default: 1)')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate (default: 0.0002)')
parser.add_argument('--championship-rounds', type=int, default=10,
                    help='number of rounds in the championship (default: 10)')
parser.add_argument('--checkpoint-path', type=str, default=None, help='path for checkpointing')
parser.add_argument('--data-path', type=str, default='./data', help='path to data')
parser.add_argument('--workers', type=int, default=Config.default_workers,
                    help='number of workers used for generating games')
parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
parser.add_argument('--no-cuda', action='store_true', default=True, help='disables GPU use')
parser.add_argument('--pretrain', action='store_true', default=True, help='pretrain value function')
args = parser.parse_args()

# Use CUDA only when it is available and not explicitly disabled.
# Note: --no-cuda defaults to True, so GPU use stays off unless that default is changed.
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)

def main():
    print("Launching Deep Pepper...")
    print("Running {} workers on {} cores".format(args.workers, mp.cpu_count()))

    # Worker pool shared by game generation and policy improvement.
    pool = mp.Pool(args.workers)
    print("Created processing pool of size {}...".format(args.workers))

    # Load the latest model and its checkpoint counter, and register it as the
    # current champion.
    model, i = load_model()
    champion = Champion(model)
    generator = GameGenerator(champion, pool, args.batch_size, args.workers)
    improver = PolicyImprover(champion, args.championship_rounds)

    # Training loop: generate self-play games, improve the policy against the
    # current champion, then checkpoint the model.
    while True:
        games = generator.generate_games()
        improver.improve_policy(games, pool)
        i += 1
        save_trained(model, i)
        print("Saving model {}".format(i))


if __name__ == '__main__':
    main()
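
For reference, a typical invocation using the flags defined above might look like the line below. This is only a sketch: it assumes the script is run from the repository root so that the `config` and `train` modules are importable, and the worker and round counts shown are illustrative, not recommended values.

    python launch_script.py --workers 4 --batch-size 25 --championship-rounds 10 --data-path ./data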