From 3437933b399aaaec22baef0259e62b7cc3249bb0 Mon Sep 17 00:00:00 2001 From: Aman Karmani Date: Mon, 31 Jul 2023 19:01:07 +0000 Subject: [PATCH] use torch_dtype=fp16 when specified --- qlora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/qlora.py b/qlora.py index 23e675ee..d085e5a1 100644 --- a/qlora.py +++ b/qlora.py @@ -324,7 +324,7 @@ def get_accelerate_model(args, checkpoint_dir): bnb_4bit_use_double_quant=args.double_quant, bnb_4bit_quant_type=args.quant_type, ), - torch_dtype=(torch.float32 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)), + torch_dtype=(torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)), trust_remote_code=args.trust_remote_code, use_auth_token=args.use_auth_token ) @@ -341,7 +341,7 @@ def get_accelerate_model(args, checkpoint_dir): setattr(model, 'model_parallel', True) setattr(model, 'is_parallelizable', True) - model.config.torch_dtype=(torch.float32 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)) + model.config.torch_dtype=(torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)) # Tokenizer tokenizer = AutoTokenizer.from_pretrained(