Skip to content

Commit

Permalink
[fix] hymba template more consistent with jinja
Browse files Browse the repository at this point in the history
  • Loading branch information
YizhenJia committed Dec 9, 2024
1 parent 0d898cd commit c31799f
Showing 1 changed file with 47 additions and 17 deletions.
64 changes: 47 additions & 17 deletions src/lmflow/utils/conversation_template/hymba.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@

from transformers import PreTrainedTokenizer


# NOTE: 'contexts' are not used in SFT
# {{'<extra_id_0>System'}}
# {% for message in messages %}
# {% if message['role'] == 'system' %}
# {{'\n' + message['content'].strip()}}
# {% if tools or contexts %}
# {% if tools %}
# {{'\n'}}
# {% endif %}
# {% endif %}
Expand All @@ -21,14 +21,6 @@
# {{ '\n<tool> ' + tool|tojson + ' </tool>' }}
# {% endfor %}
# {% endif %}
# {% if contexts %}
# {% if tools %}
# {{'\n'}}
# {% endif %}
# {% for context in contexts %}
# {{ '\n<context> ' + context.strip() + ' </context>' }}
# {% endfor %}
# {% endif %}
# {{'\n\n'}}
# {% for message in messages %}
# {% if message['role'] == 'user' %}
Expand All @@ -45,18 +37,56 @@


class HymbaConversationTemplate(ConversationTemplateForTool):
def _encode(
def encode_conversation(
self,
tokenizer: PreTrainedTokenizer,
messages: List[Dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
tools: Optional[List[str]] = None,
**kwargs
) -> Sequence[Tuple[List[int], List[int]]]:
if tools:
system = system + "<tool> " + tools + " </tool>\n"
return super()._encode(tokenizer=tokenizer,messages=messages, system=system, tools='')

r'''
Messages here should be guaranteed to be in pairs, with the first message being the user message and the second message being the assistant message.
Data example:
```json
{
"conversation_id": 2,
"system": "sysinfo1",
"tools": ["tool_1_desc"],
"messages": [
{
"role": "user",
"content": "hi"
},
{
"role": "assistant",
"content": "Hello!"
}
]
}
```
'''
assert isinstance(messages, list), "Messages must be a list."

tools_out = ''
if tools is not None:
for tool in tools:
tools_out += "\n<tool> " + tool + " </tool>"

if system is None:
system = ""
else:
if system.replace(" ",""): # has actual content
if not self.system_formatter:
raise ValueError("Your dataset contains system message but no system formatter is provided. "
"Consider either providing a system formatter or removing system prompt from your dataset.")
system = '\n' + system
else:
system = ""
encoded_pairs = self._encode(tokenizer, messages, system, tools_out, **kwargs)
encoded_pairs = self.post_process_pairs(encoded_pairs=encoded_pairs, tokenizer=tokenizer)

return encoded_pairs


HYMBA_TEMPLATE = HymbaConversationTemplate(
Expand All @@ -83,7 +113,7 @@ def _encode(
),
system_formatter=StringFormatter(
template=[
TemplateComponent(type='string', content='<extra_id_0>System\n{{content}}\n')
TemplateComponent(type='string', content='<extra_id_0>System{{content}}\n\n')
]
),
separator=TemplateComponent(type='token_id', content=13),
Expand Down

0 comments on commit c31799f

Please sign in to comment.