-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
172 lines (153 loc) · 6.21 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import os
import sys
from dataclasses import dataclass, field
from typing import Optional
from transformers import HfArgumentParser, TrainingArguments, set_seed
from trl import SFTTrainer
from utils import create_and_prepare_model, create_datasets
# Define and parse arguments.
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Besides the model identifier, this groups the chat-template choice, the LoRA
    hyperparameters, the bitsandbytes quantization options and the attention /
    gradient-checkpointing / Unsloth switches. Every field except
    `model_name_or_path` has a default, so only the model has to be supplied.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    chat_template_format: Optional[str] = field(
        default="none",
        metadata={
            "help": "chatml|zephyr|none. Pass `none` if the dataset is already formatted with the chat template."
        },
    )
    # LoRA hyperparameters.
    lora_alpha: Optional[int] = field(
        default=16,
        metadata={"help": "LoRA alpha (scaling) parameter."},
    )
    lora_dropout: Optional[float] = field(
        default=0.1,
        metadata={"help": "Dropout probability applied to the LoRA layers."},
    )
    lora_r: Optional[int] = field(
        default=64,
        metadata={"help": "LoRA rank (r)."},
    )
    lora_target_modules: Optional[str] = field(
        default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
        metadata={"help": "comma separated list of target modules to apply LoRA layers to"},
    )
    # Quantization options for 4bit base models.
    use_nested_quant: Optional[bool] = field(
        default=True,
        metadata={"help": "Activate nested quantization for 4bit base models"},
    )
    bnb_4bit_compute_dtype: Optional[str] = field(
        default="bfloat16",
        metadata={"help": "Compute dtype for 4bit base models"},
    )
    bnb_4bit_quant_storage_dtype: Optional[str] = field(
        default="bfloat16",
        metadata={"help": "Quantization storage dtype for 4bit base models"},
    )
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4",
        metadata={"help": "Quantization type fp4 or nf4"},
    )
    # Feature switches; all are off by default.
    use_flash_attn: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Flash attention for training."},
    )
    use_peft_lora: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables PEFT LoRA for training."},
    )
    use_8bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading model in 8bit."},
    )
    use_4bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading model in 4bit."},
    )
    use_reentrant: Optional[bool] = field(
        default=False,
        metadata={"help": "Gradient Checkpointing param. Refer the related docs"},
    )
    use_unsloth: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables UnSloth for training."},
    )
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to the data used for supervised fine-tuning.

    Covers which dataset and splits to load, which text column to train on, and
    the packing / sequence-length options that are forwarded to the trainer.
    All fields have defaults.
    """

    dataset_name: Optional[str] = field(
        default="timdettmers/openassistant-guanaco",
        # This script trains with SFTTrainer, so this is the SFT dataset,
        # not a preference dataset.
        metadata={"help": "Name of the dataset to use for supervised fine-tuning."},
    )
    packing: Optional[bool] = field(
        default=False,
        metadata={"help": "Use packing dataset creating."},
    )
    dataset_text_field: str = field(default="text", metadata={"help": "Dataset field to use as input text."})
    max_seq_length: Optional[int] = field(
        default=512,
        metadata={"help": "Maximum sequence length passed to the trainer."},
    )
    append_concat_token: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, appends `eos_token_id` at the end of each sample being packed."},
    )
    add_special_tokens: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, tokenizers adds special tokens to each sample being packed."},
    )
    splits: Optional[str] = field(
        default="train,test",
        metadata={"help": "Comma separate list of the splits to use from the dataset."},
    )
def main(model_args, data_args, training_args):
    """
    Run one supervised fine-tuning job: build model, datasets and trainer, train, then save.

    Args:
        model_args: parsed `ModelArguments` (model, LoRA, quantization and feature switches).
        data_args: parsed `DataTrainingArguments` (dataset, packing and sequence-length options).
        training_args: parsed `TrainingArguments`. Mutated in place:
            `gradient_checkpointing` is turned off when `model_args.use_unsloth` is set, and
            `gradient_checkpointing_kwargs` is filled in when checkpointing stays enabled.
    """
    # Set seed for reproducibility
    set_seed(training_args.seed)

    # model
    model, peft_config, tokenizer = create_and_prepare_model(model_args, data_args, training_args)

    # gradient ckpt
    # `use_cache` follows the user-requested checkpointing flag, read *before* the
    # Unsloth override below changes `training_args.gradient_checkpointing`.
    model.config.use_cache = not training_args.gradient_checkpointing
    training_args.gradient_checkpointing = training_args.gradient_checkpointing and not model_args.use_unsloth
    if training_args.gradient_checkpointing:
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": model_args.use_reentrant}

    # datasets
    train_dataset, eval_dataset = create_datasets(
        tokenizer,
        data_args,
        training_args,
        apply_chat_template=model_args.chat_template_format != "none",
    )

    # trainer
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=peft_config,
        packing=data_args.packing,
        dataset_kwargs={
            "append_concat_token": data_args.append_concat_token,
            "add_special_tokens": data_args.add_special_tokens,
        },
        dataset_text_field=data_args.dataset_text_field,
        max_seq_length=data_args.max_seq_length,
    )
    trainer.accelerator.print(f"{trainer.model}")

    # Print the trainable-parameter summary once, and only when the model offers it:
    # a model that is not PEFT-wrapped may not define `print_trainable_parameters`,
    # so an unconditional call can fail when `use_peft_lora` is off.
    if hasattr(trainer.model, "print_trainable_parameters"):
        trainer.model.print_trainable_parameters()

    if model_args.use_peft_lora:
        # handle PEFT+FSDP case
        fsdp_plugin = getattr(trainer.accelerator.state, "fsdp_plugin", None)
        if fsdp_plugin:
            from peft.utils.other import fsdp_auto_wrap_policy

            fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(trainer.model)

    # train
    # `resume_from_checkpoint` is None unless the user asked to resume, so it can be
    # forwarded as-is.
    trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)

    # saving final model
    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    trainer.save_model()
if __name__ == "__main__":
    # Script entry point: build the three argument dataclasses, then hand them to `main`.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    cli_args = sys.argv[1:]
    single_json_config = len(cli_args) == 1 and cli_args[0].endswith(".json")

    if not single_json_config:
        # Regular command-line flags.
        parsed = parser.parse_args_into_dataclasses()
    else:
        # The only argument is a path to a json file: read every setting from it.
        config_path = os.path.abspath(cli_args[0])
        parsed = parser.parse_json_file(json_file=config_path)

    model_args, data_args, training_args = parsed
    main(model_args, data_args, training_args)