diff --git a/generate.py b/generate.py index 404d483a..6814a51f 100644 --- a/generate.py +++ b/generate.py @@ -2,8 +2,9 @@ from argparse import ArgumentParser import os import sys - import torch + import numpy + import random script_dir = os.path.dirname(os.path.abspath(__file__)) code_model_dir = os.path.abspath(os.path.join(script_dir, 'model')) @@ -23,6 +24,8 @@ parser.add_argument('--gen_len', type=int, help='Length of generation') parser.add_argument('--temp', type=float, help='Generation temperature') parser.add_argument('--topk', type=int, help='Top-k sampling') + parser.add_argument('--seed', type=int, + help='Seed for the random number generators, allowing reproducibility') parser.set_defaults( model_dir=None, @@ -37,6 +40,13 @@ args = parser.parse_args() + if(args.seed is not None): + # https://pytorch.org/docs/stable/notes/randomness.html + random.seed(args.seed) + numpy.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.use_deterministic_algorithms(True) + model_fp = os.path.join(args.model_dir, 'model.pt') vocab_fp = os.path.join(args.model_dir, 'vocab.txt') if not os.path.isdir(args.out_dir):