Change datatype to fp16 to run on all GPUs (#13)
* Add comment regarding float16 and bfloat16 data-type compatibility with GPUs
shub-kris authored Feb 19, 2024
1 parent 8991654 commit 1963e1b
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions examples/vertex-ai/notebooks/gemma-finetuning-clm-lora-sft.ipynb
@@ -133,7 +133,8 @@
 " load_in_4bit=True, # quantize the model to 4-bits when you load it\n",
 " bnb_4bit_quant_type=\"nf4\", # use a special 4-bit data type for weights initialized from a normal distribution\n",
 " bnb_4bit_use_double_quant=True, # use a nested quantization scheme to quantize the already quantized weights\n",
-" bnb_4bit_compute_dtype=torch.bfloat16, # for faster computation\n",
+" bnb_4bit_compute_dtype=torch.bfloat16, # conversion from bfloat16 to float16 may lead to overflow (and the opposite may lead to loss of precision)\n",
+" # Use float16 when running on a GPU (T4, V100) where bfloat16 is not supported\n",
 ")"
 ]
 },
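For context, a minimal sketch of how this compute dtype could be chosen at runtime instead of hard-coded, using `torch.cuda.is_bf16_supported()`; the `google/gemma-2b` checkpoint name is an assumption for illustration, not part of this commit:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# bfloat16 needs Ampere or newer (A100, H100); fall back to float16 on T4/V100.
compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # quantize the weights to 4 bits at load time
    bnb_4bit_quant_type="nf4",             # NF4 data type for normally distributed weights
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
    bnb_4bit_compute_dtype=compute_dtype,  # matmuls run in this dtype
)

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",                     # hypothetical model_id for illustration
    quantization_config=bnb_config,
    device_map="auto",
)
```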
@@ -178,7 +179,7 @@
 "source": [
 "## Load the tokenizer\n",
 "# tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
-"tokenizer = GemmaTokenizer.from_pretrained(model_id.replace(\"gemma\", \"golden-gate\")) # Until the tokenizer is available in the model hub"
+"tokenizer = GemmaTokenizer.from_pretrained(model_id)"
 ]
 },
 {
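A small usage sketch of the simplified load path: with the tokenizer published on the Hub, `AutoTokenizer` resolves it as well; the `google/gemma-2b` id is an assumption:

```python
from transformers import AutoTokenizer

# Equivalent to GemmaTokenizer.from_pretrained(model_id) once the files are on the Hub.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")  # hypothetical model_id
```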
@@ -198,7 +199,7 @@
 },
 "outputs": [],
 "source": [
-"from peft import get_peft_model, LoraConfig, TaskType\n",
+"from peft import LoraConfig\n",
 "\n",
 "peft_config = LoraConfig(\n",
 " task_type=\"CAUSAL_LM\",\n",
@@ -264,7 +265,8 @@
 " group_by_length=True,\n",
 " eval_steps=20, # Evaluate every n steps during training\n",
 " evaluation_strategy=\"steps\",\n",
-" bf16=True\n",
+" bf16=True # conversion from bfloat16 to float16 may lead to overflow (and the opposite may lead to loss of precision)\n",
+" # Use float16 when running on a GPU (T4, V100) where bfloat16 is not supported\n",
 ")\n",
 "\n",
 "trainer = SFTTrainer(\n",
