[feature] add conversation template for hymba
YizhenJia committed Dec 7, 2024
1 parent f6590b4 commit eadd9de
Showing 3 changed files with 127 additions and 47 deletions.
2 changes: 2 additions & 0 deletions src/lmflow/utils/conversation_template/__init__.py
@@ -6,6 +6,7 @@
from .chatml import CHATML_TEMPLATE
from .deepseek import DEEPSEEK_TEMPLATE
from .gemma import GEMMA_TEMPLATE
from .hymba import HYMBA_TEMPLATE
from .internlm import INTERNLM2_TEMPLATE
from .llama import LLAMA2_TEMPLATE, LLAMA3_TEMPLATE, LLAMA3_TEMPLATE_FOR_TOOL
from .phi import PHI3_TEMPLATE
@@ -22,6 +23,7 @@
'empty': EMPTY_TEMPLATE,
'empty_no_special_tokens': EMPTY_NO_SPECIAL_TOKENS_TEMPLATE,
'gemma': GEMMA_TEMPLATE,
'hymba': HYMBA_TEMPLATE,
'internlm2': INTERNLM2_TEMPLATE,
'llama2': LLAMA2_TEMPLATE,
'llama3': LLAMA3_TEMPLATE,
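With the two entries above, the Hymba template is importable from the package and registered under the preset name 'hymba'. A minimal lookup sketch (illustrative only; it assumes the registry dict shown above is exported as PRESET_TEMPLATES, a name not visible in this hunk):

from lmflow.utils.conversation_template import PRESET_TEMPLATES

# Fetch the newly registered Hymba conversation template by its preset name.
hymba_template = PRESET_TEMPLATES['hymba']
print(hymba_template.template_name)  # expected: 'hymba'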
81 changes: 34 additions & 47 deletions src/lmflow/utils/conversation_template/base.py
@@ -157,11 +157,13 @@ def format(self, **kwargs) -> list:
class ConversationTemplate:
user_formatter: Formatter
assistant_formatter: Formatter
function_formatter: Optional[Formatter] = None,
observation_formatter: Optional[Formatter] = None,
function_formatter: Optional[Formatter] = None
observation_formatter: Optional[Formatter] = None
system_formatter: Optional[Formatter] = None
force_system: bool = False
tools_formatter: Optional[Formatter] = None
separator: Optional[TemplateComponent] = None
remove_last_sep: bool = False
special_starter: Optional[TemplateComponent] = None
special_stopper: Optional[TemplateComponent] = None
template_name: Optional[str] = None
@@ -181,7 +183,6 @@ def encode_conversation(
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[List[str]] = None,
remove_last_sep: bool = False,
**kwargs
) -> Sequence[Tuple[List[int], List[int]]]:
r'''
@@ -219,27 +220,7 @@
system = None

encoded_pairs = self._encode(tokenizer, messages, system, tools, **kwargs)

if self.separator and remove_last_sep:
# For models that require a separator between messages,
# user can include the seperator at the end of each template
# and specify the separator. Auto formatting will remove the
# last separator once user specifies this option.
encoded_pairs = self.remove_last_separator(encoded_pairs, tokenizer)

if self.special_starter:
# For models that has ONLY ONE bos token at the beginning of
# a conversation session (not a conversation pair), user can
# specify a special starter to add that starter to the very
# beginning of the conversation session.
# eg:
# llama-2: <s> and </s> at every pair of conversation
# v.s.
# llama-3: <|begin_of_text|> only at the beginning of a session
encoded_pairs = self.add_special_starter(encoded_pairs, tokenizer)

if self.special_stopper:
encoded_pairs = self.add_special_stopper(encoded_pairs, tokenizer)
encoded_pairs = self.post_process_pairs(encoded_pairs=encoded_pairs, tokenizer=tokenizer)

return encoded_pairs

@@ -256,7 +237,10 @@ def _encode(

res_all = []

system_formatted = self.system_formatter.format(content=system) if system else []
if system:
system_formatted = self.system_formatter.format(content=system)
else:
system_formatted = self.system_formatter.format(content='') if self.force_system else []
system_encoded = self._encode_template(system_formatted, tokenizer)

for i in range(0, len(messages), 2):
@@ -317,6 +301,30 @@ def _encode_template(
raise NotImplementedError(f"Component type {component.type} is not supported yet.")
return encoded_ids

def post_process_pairs(self, encoded_pairs, tokenizer):
if self.separator and self.remove_last_sep:
# For models that require a separator between messages, users can
# include the separator at the end of each template and specify the
# separator here. Auto formatting will then remove the last separator
# once this option is enabled.
encoded_pairs = self.remove_last_separator(encoded_pairs, tokenizer)

if self.special_starter:
# For models that have ONLY ONE bos token at the beginning of
# a conversation session (not of every conversation pair), users can
# specify a special starter that is added to the very beginning
# of the conversation session.
# e.g.:
# llama-2: <s> and </s> around every conversation pair
# vs.
# llama-3: <|begin_of_text|> only at the beginning of a session
encoded_pairs = self.add_special_starter(encoded_pairs, tokenizer)

if self.special_stopper:
encoded_pairs = self.add_special_stopper(encoded_pairs, tokenizer)

return encoded_pairs

def remove_last_separator(
self,
encoded_pairs: Sequence[Tuple[List[int], List[int]]],
@@ -404,7 +412,6 @@ def encode_conversation(
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[List[str]] = None,
remove_last_sep: bool = False,
**kwargs
) -> Sequence[Tuple[List[int], List[int]]]:
r'''
@@ -446,27 +453,7 @@
else:
system = ""
encoded_pairs = self._encode(tokenizer, messages, system, tools, **kwargs)

if self.separator and remove_last_sep:
# For models that require a separator between messages,
# user can include the seperator at the end of each template
# and specify the separator. Auto formatting will remove the
# last separator once user specifies this option.
encoded_pairs = self.remove_last_separator(encoded_pairs, tokenizer)

if self.special_starter:
# For models that has ONLY ONE bos token at the beginning of
# a conversation session (not a conversation pair), user can
# specify a special starter to add that starter to the very
# beginning of the conversation session.
# eg:
# llama-2: <s> and </s> at every pair of conversation
# v.s.
# llama-3: <|begin_of_text|> only at the beginning of a session
encoded_pairs = self.add_special_starter(encoded_pairs, tokenizer)

if self.special_stopper:
encoded_pairs = self.add_special_stopper(encoded_pairs, tokenizer)
encoded_pairs = self.post_process_pairs(encoded_pairs=encoded_pairs, tokenizer=tokenizer)

return encoded_pairs

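Taken together, the base.py changes move remove_last_sep from a per-call argument of encode_conversation to a field on the template itself, and consolidate the separator / special-starter / special-stopper handling into the new post_process_pairs() helper that both encode_conversation variants now call. A rough caller-side sketch of the change (hypothetical caller code, not part of the commit):

# Before: separator removal was requested on every call.
# encoded = template.encode_conversation(
#     tokenizer=tokenizer, messages=messages, system=system, remove_last_sep=True)

# After: the behaviour is declared once when the template is defined
# (e.g. remove_last_sep=True, as the Hymba template below does), and
# encode_conversation is called without the extra argument.
encoded = template.encode_conversation(tokenizer=tokenizer, messages=messages, system=system)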
91 changes: 91 additions & 0 deletions src/lmflow/utils/conversation_template/hymba.py
@@ -0,0 +1,91 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 Statistics and Machine Learning Research Group. All rights reserved.
from .base import StringFormatter, TemplateComponent, ConversationTemplateForTool
from typing import Dict, Set, Sequence, Literal, Union, List, Optional, Tuple

from transformers import PreTrainedTokenizer


# {{'<extra_id_0>System'}}
# {% for message in messages %}
# {% if message['role'] == 'system' %}
# {{'\n' + message['content'].strip()}}
# {% if tools or contexts %}
# {{'\n'}}
# {% endif %}
# {% endif %}
# {% endfor %}
# {% if tools %}
# {% for tool in tools %}
# {{ '\n<tool> ' + tool|tojson + ' </tool>' }}
# {% endfor %}
# {% endif %}
# {% if contexts %}
# {% if tools %}
# {{'\n'}}
# {% endif %}
# {% for context in contexts %}
# {{ '\n<context> ' + context.strip() + ' </context>' }}
# {% endfor %}
# {% endif %}
# {{'\n\n'}}
# {% for message in messages %}
# {% if message['role'] == 'user' %}
# {{ '<extra_id_1>User\n' + message['content'].strip() + '\n' }}
# {% elif message['role'] == 'assistant' %}
# {{ '<extra_id_1>Assistant\n' + message['content'].strip() + '\n' }}
# {% elif message['role'] == 'tool' %}
# {{ '<extra_id_1>Tool\n' + message['content'].strip() + '\n' }}
# {% endif %}
# {% endfor %}
# {%- if add_generation_prompt %}
# {{'<extra_id_1>Assistant\n'}}
# {%- endif %}


class HymbaConversationTemplate(ConversationTemplateForTool):
def _encode(
self,
tokenizer: PreTrainedTokenizer,
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
**kwargs
) -> Sequence[Tuple[List[int], List[int]]]:
# Hymba folds tool definitions into the system turn, wrapped in
# <tool> ... </tool>, so merge them here and pass an empty tools
# string down to the parent encoder.
system_and_tool = system + "<tool> " + tools + " </tool>"
return super()._encode(tokenizer=tokenizer, messages=messages, system=system_and_tool, tools='')



HYMBA_TEMPLATE = HymbaConversationTemplate(
template_name='hymba',
user_formatter=StringFormatter(
template=[
TemplateComponent(type='string', content='<extra_id_1>User\n{{content}}\n')
]
),
assistant_formatter=StringFormatter(
template=[
TemplateComponent(type='string', content='<extra_id_1>Assistant\n{{content}}\n')
]
),
function_formatter=StringFormatter(
template=[
TemplateComponent(type='string', content='<extra_id_1>Assistant\n{{content}}\n')
]
),
observation_formatter=StringFormatter(
template=[
TemplateComponent(type='string', content='<extra_id_1>Tool\n{{content}}\n')
]
),
system_formatter=StringFormatter(
template=[
TemplateComponent(type='string', content='<extra_id_0>System\n{{content}}\n\n')
]
),
separator=TemplateComponent(type='string', content='\n'),
remove_last_sep=True,
special_stopper=TemplateComponent(type='token', content='eos_token')
)
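For reference, the formatters above render a short exchange roughly as follows (illustrative only: the tool definition is made up, the overridden _encode wraps the tools string in <tool> ... </tool> inside the system turn even when it is empty, and the encoded result additionally goes through last-separator removal and gets the tokenizer's eos token appended as the special stopper):

<extra_id_0>System
You are a helpful assistant.<tool> {"name": "get_weather"} </tool>

<extra_id_1>User
What is the weather like in Paris today?
<extra_id_1>Assistant
Let me look that up with the weather tool.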
