-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_reader.py
31 lines (23 loc) · 1.07 KB
/
train_reader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from argparse import ArgumentParser, Namespace
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.trainer import Trainer, seed_everything
from soseki.reader.modeling import ReaderLightningModule
def main(args: Namespace) -> None:
    """Train the reader model using options parsed from the command line.

    Seeds the RNGs, builds a ``ReaderLightningModule`` from ``args``, and
    fits it with a Lightning ``Trainer`` configured from the same namespace.
    The trainer's ``default_root_dir`` is ``args.output_dir``.

    Args:
        args: Parsed command-line options. Must provide ``random_seed`` and
            ``output_dir``, plus whatever model- and Trainer-specific options
            ``ReaderLightningModule`` and ``Trainer.from_argparse_args`` read.
    """
    # Seed before the model is built so its initialisation is reproducible;
    # workers=True also asks Lightning to seed dataloader workers.
    seed_everything(args.random_seed, workers=True)
    reader = ReaderLightningModule(args)

    training_callbacks = [
        # Keep a single "best" checkpoint ranked by validation answer accuracy
        # (higher is better), and also save the most recent one as "last".
        ModelCheckpoint(
            filename="best",
            monitor="val_answer_accuracy",
            save_last=True,
            save_top_k=1,
            mode="max",
        ),
        # Record the learning rate at every optimizer step.
        LearningRateMonitor(logging_interval="step"),
    ]

    trainer = Trainer.from_argparse_args(
        args,
        default_root_dir=args.output_dir,
        callbacks=training_callbacks,
    )
    trainer.fit(reader)
if __name__ == "__main__":
    cli = ArgumentParser()
    # Script-level options consumed directly by main().
    cli.add_argument("--output_dir", type=str, required=True)
    cli.add_argument("--random_seed", type=int, default=1)
    # Let the model and the Lightning Trainer register their own options;
    # both helpers return the (extended) parser.
    cli = ReaderLightningModule.add_model_specific_args(cli)
    cli = Trainer.add_argparse_args(cli)
    main(cli.parse_args())