diff --git a/qlora.py b/qlora.py
index 45cbe889..027a3378 100644
--- a/qlora.py
+++ b/qlora.py
@@ -675,7 +675,15 @@ def train():
     set_seed(args.seed)
 
     data_module = make_data_module(tokenizer=tokenizer, args=args)
-
+
+    # When using distributed training, the value of the flag find_unused_parameters passed to
+    # DistributedDataParallel. Will default to False if gradient checkpointing is used, True otherwise.
+    if os.environ.get('LOCAL_RANK') is not None:
+        if training_args.gradient_checkpointing:
+            training_args.ddp_find_unused_parameters = False
+        else:
+            training_args.ddp_find_unused_parameters = True
+
    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=tokenizer,
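For context, `torchrun` (and `torch.distributed.launch`) export `LOCAL_RANK` for every worker, so the environment check above acts as a proxy for "running under DDP". The snippet below is a minimal standalone sketch of the same rule applied to `transformers.Seq2SeqTrainingArguments`; the `output_dir` and `gradient_checkpointing` values are placeholders for illustration and are not taken from qlora.py's argument parser.

```python
import os

from transformers import Seq2SeqTrainingArguments

# Placeholder arguments for illustration; in qlora.py these come from HfArgumentParser.
training_args = Seq2SeqTrainingArguments(
    output_dir="./output",
    gradient_checkpointing=True,
)

# torchrun exports LOCAL_RANK for each worker, so its presence indicates a DDP launch.
if os.environ.get("LOCAL_RANK") is not None:
    # Per the Transformers docs, ddp_find_unused_parameters should be False when
    # gradient checkpointing is enabled and True otherwise; the two branches in the
    # patch collapse to this single assignment.
    training_args.ddp_find_unused_parameters = not training_args.gradient_checkpointing

print(training_args.ddp_find_unused_parameters)
```

Setting the flag explicitly, rather than leaving it at its `None` default, avoids relying on the Trainer's own fallback, which may not see the gradient-checkpointing state through the PEFT-wrapped model.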