Skip to content

Commit

Permalink
[feature] add conversation template for hymba
Browse files Browse the repository at this point in the history
  • Loading branch information
YizhenJia committed Dec 8, 2024
1 parent eadd9de commit 9eaacaf
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
9 changes: 7 additions & 2 deletions src/lmflow/utils/conversation_template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class ConversationTemplate:

def __post_init__(self):
if self.separator:
if self.separator.type not in ['string', 'token']:
if self.separator.type not in ['string', 'token', 'token_id']:
raise NotImplementedError(f"Component type {self.separator.type} cannot be used as a separator.")

if self.special_starter:
Expand Down Expand Up @@ -335,6 +335,8 @@ def remove_last_separator(
separator_ids = tokenizer.encode(self.separator.content, add_special_tokens=False)
elif self.separator.type == 'token':
separator_ids = self._ensure_id_list(tokenizer.convert_tokens_to_ids(self.separator.content))
elif self.separator.type == 'token_id':
separator_ids = self._ensure_id_list(self.separator.content)
else:
raise ValueError(f"Component type {self.separator.type} cannot be used as a separator.")

Expand Down Expand Up @@ -471,7 +473,10 @@ def _encode(
res_all = []
# Concatenate the system and tools strings
system = system + tools
system_formatted = self.system_formatter.format(content=system) if system else []
if system:
system_formatted = self.system_formatter.format(content=system)
else:
system_formatted = self.system_formatter.format(content='') if self.force_system else []
system_encoded = self._encode_template(system_formatted, tokenizer)
ls_for_save = []
for i in range(0, len(messages), 1):
Expand Down
14 changes: 8 additions & 6 deletions src/lmflow/utils/conversation_template/hymba.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ def _encode(
system: Optional[str] = None,
tools: Optional[str] = None,
**kwargs
) -> Sequence[Tuple[List[int], List[int]]]:
system_and_tool = system + "<tool> " + tools + " </tool>"
return super()._encode(tokenizer=tokenizer,messages=messages, system=system_and_tool, tools='')
) -> Sequence[Tuple[List[int], List[int]]]:
if tools:
system = system + "<tool> " + tools + " </tool>\n"
return super()._encode(tokenizer=tokenizer,messages=messages, system=system, tools='')



Expand Down Expand Up @@ -82,10 +83,11 @@ def _encode(
),
system_formatter=StringFormatter(
template=[
TemplateComponent(type='string', content='<extra_id_0>System\n{{content}}\n\n')
TemplateComponent(type='string', content='<extra_id_0>System\n{{content}}\n')
]
),
separator=TemplateComponent(type='string', content='\n'),
separator=TemplateComponent(type='token_id', content=13),
remove_last_sep=True,
special_stopper=TemplateComponent(type='token', content='eos_token')
special_stopper=TemplateComponent(type='token', content='eos_token'),
force_system=True
)

0 comments on commit 9eaacaf

Please sign in to comment.