forked from ZiYang-xie/PyCAPTCHA
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: launcher.py
41 lines (34 loc) · 1.23 KB
/
launcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from model.model import captcha_model, model_conv, model_resnet
from data.datamodule import captcha_dm
import pytorch_lightning as pl
import torch.optim as optim
import torch
from utils.config_util import configGetter
from utils.arg_parsers import train_arg_parser
# Solver hyper-parameters, read once at import time from the 'SOLVER'
# section returned by configGetter.
cfg = configGetter('SOLVER')
lr, batch_size, epoch = cfg['LR'], cfg['BATCH_SIZE'], cfg['EPOCH']
def main(arg):
    """Train the CAPTCHA model and save its weights.

    Builds a ``captcha_model`` around a ``model_resnet()`` backbone and a
    ``captcha_dm`` data module. Trains them with a PyTorch Lightning
    ``Trainer`` that logs to TensorBoard, then writes the trained model's
    ``state_dict`` to ``arg.save_path`` with ``torch.save``.

    Args:
        arg: Parsed command-line options (as produced by
            ``train_arg_parser()``). This function reads ``arg.log_dir``,
            ``arg.exp_name``, ``arg.gpus`` and ``arg.save_path``.
    """
    # Fixed seed, together with deterministic=True below, for repeatable runs.
    pl.seed_everything(42)

    # lr, batch_size and epoch are the module-level values from the SOLVER config.
    model = captcha_model(model=model_resnet(), lr=lr)
    dm = captcha_dm(batch_size=batch_size)

    # All CLI options come from the ``arg`` parameter rather than a module
    # global, so main() also works when imported and called from elsewhere.
    # NOTE(review): version=2 is fixed, so every run logs into the same
    # version directory under log_dir/exp_name; confirm that is intended.
    tb_logger = pl.loggers.TensorBoardLogger(
        arg.log_dir, name=arg.exp_name, version=2, default_hp_metric=False)
    trainer = pl.Trainer(
        deterministic=True,
        # NOTE(review): gpus / auto_select_gpus / stochastic_weight_avg are
        # PyTorch Lightning 1.x Trainer arguments; confirm the pinned version.
        gpus=arg.gpus,
        auto_select_gpus=True,
        precision=32,
        logger=tb_logger,
        fast_dev_run=False,
        max_epochs=epoch,
        log_every_n_steps=50,
        stochastic_weight_avg=True,
    )
    trainer.fit(model, datamodule=dm)
    torch.save(model.state_dict(), arg.save_path)
if __name__ == "__main__":
    # Script entry point: parse the CLI into the module-level name ``args``,
    # then hand it to main() explicitly.
    args = train_arg_parser()
    main(arg=args)