-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path train_biencoder.py
31 lines (23 loc) · 1.1 KB
/
train_biencoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from argparse import ArgumentParser, Namespace
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.trainer import Trainer, seed_everything
from soseki.biencoder.modeling import BiencoderLightningModule
def main(args: Namespace) -> None:
    """Train a bi-encoder model configured entirely from a parsed CLI namespace.

    Steps, in order:
      1. Seed via ``seed_everything(args.random_seed, workers=True)``.
      2. Build ``BiencoderLightningModule(args)``. This happens right after
         seeding so model construction sees the freshly seeded RNG state.
      3. Build a Lightning ``Trainer`` from the same namespace and call
         ``fit`` on the model.

    The Trainer gets two callbacks:
      * ``ModelCheckpoint`` monitors ``val_avg_rank`` (lower is better). It
        keeps the single best checkpoint under the filename ``best`` and,
        via ``save_last=True``, the most recent one as well.
      * ``LearningRateMonitor`` logs the learning rate every step.

    Args:
        args: Namespace from the ``__main__`` parser. This function reads
            ``random_seed`` and ``output_dir`` directly. It also passes the
            whole namespace to the model constructor and to
            ``Trainer.from_argparse_args``.
    """
    seed_everything(args.random_seed, workers=True)
    model = BiencoderLightningModule(args)

    best_checkpoint = ModelCheckpoint(
        filename="best",
        monitor="val_avg_rank",
        mode="min",
        save_top_k=1,
        save_last=True,
    )
    trainer_callbacks = [best_checkpoint, LearningRateMonitor(logging_interval="step")]

    trainer = Trainer.from_argparse_args(
        args,
        default_root_dir=args.output_dir,
        callbacks=trainer_callbacks,
        # Keep Lightning from swapping in its own distributed sampler;
        # presumably the module provides a custom sampler for its batches.
        # NOTE(review): confirm in BiencoderLightningModule.
        replace_sampler_ddp=False,
    )
    trainer.fit(model)
if __name__ == "__main__":
    # CLI layout, in registration order:
    #   1. the script's own options (--output_dir, --random_seed);
    #   2. the model's options from BiencoderLightningModule.add_model_specific_args;
    #   3. the options registered by Lightning's Trainer.add_argparse_args.
    # Each helper returns the parser to keep using, so the calls are nested;
    # the innermost call runs first, which preserves that order.
    parser = ArgumentParser()
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--random_seed", type=int, default=1)
    parser = Trainer.add_argparse_args(
        BiencoderLightningModule.add_model_specific_args(parser)
    )

    args = parser.parse_args()
    main(args)