
Commit

trying to use llama-2 template
Jasonqi146 committed Oct 27, 2023
1 parent 9a01411 commit b76fc0f
Showing 4 changed files with 36 additions and 32 deletions.
9 changes: 9 additions & 0 deletions llm_ft/fastchat/conversation.py
@@ -980,6 +980,15 @@ def get_conv_template(name: str) -> Conversation:
    )
)

register_conv_template(
    Conversation(
        name="sotopia-llama-2",
        roles=("Agent1", "Agent2"),
        sep_style=SeparatorStyle.LLAMA2,
        sep=" ",
        sep2=" </s><s>",
    )
)

if __name__ == "__main__":
    print("Vicuna template:")
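The newly registered template can be fetched by name through FastChat's get_conv_template and used to render a prompt. A minimal sketch, not part of the commit: the messages are illustrative, and the exact rendered string depends on how FastChat's LLAMA2 separator style handles a conversation without a system message.

from fastchat.conversation import get_conv_template

conv = get_conv_template("sotopia-llama-2")  # returns a copy of the registered template
conv.append_message(conv.roles[0], "Hi! Shall we coordinate on the next move?")  # Agent1
conv.append_message(conv.roles[1], None)  # leave Agent2's turn open for generation
print(conv.get_prompt())  # rendered with the LLAMA2 separators (" " / " </s><s>")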
53 changes: 25 additions & 28 deletions llm_ft/fastchat/train/train.py
@@ -29,8 +29,8 @@
from transformers.trainer_pt_utils import LabelSmoother

from fastchat.conversation import SeparatorStyle
# from fastchat.model.model_adapter import get_conversation_template
from fastchat.conversation import get_conv_template
from fastchat.model.model_adapter import get_conversation_template

IGNORE_TOKEN_ID = LabelSmoother.ignore_index

@@ -51,7 +51,7 @@ class DataArguments:
)
lazy_preprocess: bool = False
shuffle: bool = True
drop_long_seq: bool = False
template: str|None = None


@dataclass
@@ -88,9 +88,10 @@ def trainer_save_model_safe(trainer: transformers.Trainer):
def preprocess(
sources,
tokenizer: transformers.PreTrainedTokenizer,
drop_long_seq: bool = False,
template: str | None
) -> Dict:
conv = get_conv_template("sotopia")
print("Template: ", template)
conv = get_conv_template(template) if template else get_conversation_template("vicuna")
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

# Apply prompt templates
@@ -107,14 +108,6 @@ def preprocess(
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())

if drop_long_seq:
new_conversation = []
for temp_conv in conversations:
token_len = tokenizer(temp_conv, return_tensors="pt", padding=False, truncation=False).input_ids.size()[1]
if token_len <= tokenizer.model_max_length: new_conversation.append(temp_conv)
conversation = new_conversation
print(f"Dropping conversations longer than {tokenizer.model_max_length}; Now have {len(conversation)} conversations")

# Tokenize conversations
input_ids = tokenizer(
conversations,
@@ -124,11 +117,11 @@ def preprocess(
truncation=True,
).input_ids
targets = input_ids.clone()

assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO


print("Conv Sep Style", conv.sep_style)
sep = conv.sep + conv.roles[1] + ": " if conv.sep_style == SeparatorStyle.ADD_COLON_TWO else conv.sep

# Mask targets. Only compute loss on the assistant outputs.
sep = conv.sep + conv.roles[1] + ": "
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())

@@ -145,19 +138,24 @@ def preprocess(
break
parts[0] += sep
# "-2" is hardcoded for the LLaMA tokenizer to make the offset correct.
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
instruction_len = len(tokenizer(parts[0]).input_ids) - 2 if conv.sep_style == SeparatorStyle.ADD_COLON_TWO else len(tokenizer(parts[0]).input_ids) - 1

# Ignore the user instructions
target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
cur_len += turn_len

target[cur_len:] = IGNORE_TOKEN_ID

if conv.sep_style == SeparatorStyle.ADD_COLON_TWO:
target[cur_len:] = IGNORE_TOKEN_ID
else:
target[cur_len+1:] = IGNORE_TOKEN_ID

if False: # Inspect and check the correctness of masking
z = target.clone()
z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
rank0_print(tokenizer.decode(z))


if not conv.sep_style == SeparatorStyle.ADD_COLON_TWO: cur_len += 2

if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_TOKEN_ID
@@ -175,12 +173,12 @@ def preprocess(
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""

def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, drop_long_seq: bool = False):
def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, template: None):
super(SupervisedDataset, self).__init__()

rank0_print("Formatting inputs...")
sources = [example["conversations"] for example in raw_data]
data_dict = preprocess(sources, tokenizer, drop_long_seq=drop_long_seq)
data_dict = preprocess(sources, tokenizer, template)

self.input_ids = data_dict["input_ids"]
self.labels = data_dict["labels"]
@@ -200,15 +198,15 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""

def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, drop_long_seq: bool = False):
def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, template: None):
super(LazySupervisedDataset, self).__init__()
self.tokenizer = tokenizer
self.drop_long_seq = drop_long_seq

rank0_print("Formatting inputs...Skip in lazy mode")
self.tokenizer = tokenizer
self.raw_data = raw_data
self.cached_data_dict = {}
self.template = template

def __len__(self):
return len(self.raw_data)
@@ -217,7 +215,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
if i in self.cached_data_dict:
return self.cached_data_dict[i]

ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer)
ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer, self.template)
ret = dict(
input_ids=ret["input_ids"][0],
labels=ret["labels"][0],
@@ -240,17 +238,16 @@ def make_supervised_data_module(
train_json = json.load(open(data_args.data_path, "r"))
if data_args.shuffle: random.shuffle(train_json)

train_dataset = dataset_cls(train_json, tokenizer=tokenizer, drop_long_seq = data_args.drop_long_seq)
train_dataset = dataset_cls(train_json, tokenizer=tokenizer, template=data_args.template)

if data_args.eval_data_path:
eval_json = json.load(open(data_args.eval_data_path, "r"))
if data_args.shuffle: random.shuffle(train_json)

eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer, drop_long_seq = data_args.drop_long_seq)
eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer, template=data_args.template)
else:
eval_dataset = None


return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)


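The template-dependent parts of the new preprocess() can be exercised on their own. A small sketch, assuming a FastChat install that provides both import paths used above; it mirrors the selection and separator logic from the diff, where ADD_COLON_TWO templates such as vicuna mask up to "<sep>ASSISTANT: " with the hard-coded -2 token offset, while other styles such as LLAMA2 fall back to conv.sep and the -1 offset.

from fastchat.conversation import SeparatorStyle, get_conv_template
from fastchat.model.model_adapter import get_conversation_template

for template in (None, "llama-2"):
    # Same selection logic as the new preprocess(): a named template wins,
    # otherwise fall back to the default vicuna conversation template.
    conv = get_conv_template(template) if template else get_conversation_template("vicuna")
    sep = conv.sep + conv.roles[1] + ": " if conv.sep_style == SeparatorStyle.ADD_COLON_TWO else conv.sep
    print(template, conv.sep_style, repr(sep))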
2 changes: 0 additions & 2 deletions llm_ft/fastchat/train/train_lora.py
@@ -112,8 +112,6 @@ def train():
training_args,
lora_args,
) = parser.parse_args_into_dataclasses()

print(data_args)

device_map = None
world_size = int(os.environ.get("WORLD_SIZE", 1))
4 changes: 2 additions & 2 deletions llm_ft/llama2-13b_qlora_train.sh
@@ -6,7 +6,7 @@ deepspeed --num_gpus=1 fastchat/train/train_lora.py \
--data_path ./data/fastchat-ft-gpt4-gpt4-easy-2-side-partial.json \
--shuffle True \
--bf16 True \
--output_dir ./checkpoint-shuffle-drop-long \
--output_dir ./checkpoint-shuffle \
--num_train_epochs 20 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
@@ -25,7 +25,7 @@ deepspeed --num_gpus=1 fastchat/train/train_lora.py \
--hf_access_token "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" \
--tf32 True \
--flash_attn True \
--drop_long_seq True \
--template "llama-2"

# Possible other options
# --flash_attn True \
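End to end, the new --template flag is parsed into DataArguments.template and handed to the dataset classes, which forward it to preprocess(). Note that the script passes "llama-2" (FastChat's built-in Llama-2 template) rather than the "sotopia-llama-2" template registered in this commit. A rough sketch of that flow, illustrative only: the tokenizer checkpoint assumes Hugging Face access, and the data path from the script must exist locally.

import transformers
from fastchat.train.train import DataArguments, make_supervised_data_module

tokenizer = transformers.AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-13b-hf",
    model_max_length=2048,
    padding_side="right",
    use_fast=False,
)
tokenizer.pad_token = tokenizer.unk_token  # LLaMA tokenizers ship without a pad token

data_args = DataArguments(
    data_path="./data/fastchat-ft-gpt4-gpt4-easy-2-side-partial.json",
    shuffle=True,
    template="llama-2",  # same value as --template in the script above
)
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
train_dataset = data_module["train_dataset"]  # built via preprocess(..., template="llama-2")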

