From c31799f56b6793a462c1acde638ba1a2390b214f Mon Sep 17 00:00:00 2001 From: YizhenJia Date: Mon, 9 Dec 2024 10:05:46 +0800 Subject: [PATCH] [fix] hymba template more consistent with jinja --- .../utils/conversation_template/hymba.py | 64 ++++++++++++++----- 1 file changed, 47 insertions(+), 17 deletions(-) diff --git a/src/lmflow/utils/conversation_template/hymba.py b/src/lmflow/utils/conversation_template/hymba.py index fdd01a9a..81e3ec81 100644 --- a/src/lmflow/utils/conversation_template/hymba.py +++ b/src/lmflow/utils/conversation_template/hymba.py @@ -6,12 +6,12 @@ from transformers import PreTrainedTokenizer - +# NOTE: 'contexts' are not used in sft # {{'System'}} # {% for message in messages %} # {% if message['role'] == 'system' %} # {{'\n' + message['content'].strip()}} -# {% if tools or contexts %} +# {% if tools %} # {{'\n'}} # {% endif %} # {% endif %} @@ -21,14 +21,6 @@ # {{ '\n ' + tool|tojson + ' ' }} # {% endfor %} # {% endif %} -# {% if contexts %} -# {% if tools %} -# {{'\n'}} -# {% endif %} -# {% for context in contexts %} -# {{ '\n ' + context.strip() + ' ' }} -# {% endfor %} -# {% endif %} # {{'\n\n'}} # {% for message in messages %} # {% if message['role'] == 'user' %} @@ -45,18 +37,56 @@ class HymbaConversationTemplate(ConversationTemplateForTool): - def _encode( + def encode_conversation( self, tokenizer: PreTrainedTokenizer, messages: List[Dict[str, str]], system: Optional[str] = None, - tools: Optional[str] = None, + tools: Optional[List[str]] = None, **kwargs ) -> Sequence[Tuple[List[int], List[int]]]: - if tools: - system = system + " " + tools + " \n" - return super()._encode(tokenizer=tokenizer,messages=messages, system=system, tools='') - + r''' + Messages here should be guaranteed to be in pairs, with the first message being the user message and the second message being the system message. + Data example: + ```json + { + "conversation_id": 2, + "system": "sysinfo1", + "tools": ["tool_1_desc"], + "messages": [ + { + "role": "user", + "content": "hi" + }, + { + "role": "assistant", + "content": "Hello!" + } + ] + } + ``` + ''' + assert isinstance(messages, list), "Messages must be a list." + + tools_out = '' + if tools is not None: + for tool in tools: + tools_out += "\n " + tool + " " + + if system is None: + system = "" + else: + if system.replace(" ",""): # has actual content + if not self.system_formatter: + raise ValueError("Your dataset contains system message but no system formatter is provided. " + "Consider either providing a system formatter or removing system prompt from your dataset.") + system = '\n' + system + else: + system = "" + encoded_pairs = self._encode(tokenizer, messages, system, tools_out, **kwargs) + encoded_pairs = self.post_process_pairs(encoded_pairs=encoded_pairs, tokenizer=tokenizer) + + return encoded_pairs HYMBA_TEMPLATE = HymbaConversationTemplate( @@ -83,7 +113,7 @@ def _encode( ), system_formatter=StringFormatter( template=[ - TemplateComponent(type='string', content='System\n{{content}}\n') + TemplateComponent(type='string', content='System{{content}}\n\n') ] ), separator=TemplateComponent(type='token_id', content=13),