onlinedpo error when use deepspeed zero3 #2532

Open
5 of 9 tasks
yiyepiaoling0715 opened this issue Dec 30, 2024 · 2 comments
Labels: 🐛 bug, 🚀 deepspeed, ⏳ needs more info, 🏋 Online DPO

@yiyepiaoling0715

System Info

```
transformers 4.47.0
triton 3.0.0
trl 0.12.1
trove-classifiers 2024.10.21.16
truststore 0.8.0
typer 0.14.0
types-dataclasses 0.6.6
typing_extensions 4.12.2
typing-inspect 0.9.0
tzdata 2024.2
tzlocal 5.2
ujson 5.10.0
urllib3 2.2.2
utils 1.0.2
uvicorn 0.32.1
uvloop 0.21.0
virtualenv 20.28.0
vllm 0.6.3
vllm-flash-attn 2.6.1
```

`trl env`

```
Copy-paste the following information when reporting an issue:

- Platform: Linux-5.4.143-2-velinux1-amd64-x86_64-with-glibc2.35
- Python version: 3.11.9
- PyTorch version: 2.4.0
- CUDA device(s): NVIDIA A100-SXM4-80GB, NVIDIA A100-SXM4-80GB, NVIDIA A100-SXM4-80GB, NVIDIA A100-SXM4-80GB, NVIDIA A100-SXM4-80GB, NVIDIA A100-SXM4-80GB, NVIDIA A100-SXM4-80GB, NVIDIA A100-SXM4-80GB
- Transformers version: 4.47.0
- Accelerate version: 1.1.1
- Accelerate config: not found
- Datasets version: 3.1.0
- HF Hub version: 0.26.3
- TRL version: 0.12.1
- bitsandbytes version: 0.45.0
- DeepSpeed version: 0.16.1
- Diffusers version: not installed
- Liger-Kernel version: not installed
- LLM-Blender version: 0.0.2
- OpenAI version: 1.57.0
- PEFT version: 0.13.2
```

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction


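import json

import numpy as np
from loguru import logger
from torch.utils.data import Dataset  # assumed base class; the original snippet does not show this import


class UnifiedDPODataset(Dataset):
    """
    Unified DPO dataset
    """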
    def __init__(self, file, tokenizer, max_seq_length, max_prompt_length, template,
                 maximum_es_score,minimum_es_score,bool_training:bool):
        self.tokenizer = tokenizer
        self.template_name = template.template_name
        #==None
        self.system_format = template.system_format
        self.user_format = template.user_format
        self.assistant_format = template.assistant_format
        self.system = template.system

        self.max_seq_length = max_seq_length
        self.max_prompt_length = max_prompt_length
        logger.info('Loading data: {}'.format(file))
        with open(file, 'r', encoding='utf-8') as f:
            raw_data_list = f.readlines()
            # filter records by their es_score key
            for check_data in raw_data_list:
                try:
                    json.loads(check_data)
                except json.decoder.JSONDecodeError as e:
                    print(f'JSONDecodeError={e.args},check_data={check_data}')
            # data_list = [json.loads(data) for data in raw_data_list 
            #     if float(json.loads(data)['es_score']) >= minimum_es_score 
            #     and float(json.loads(data)['es_score']) <= maximum_es_score]
            if bool_training:
                data_list=[]
                for data_str_iter in raw_data_list:
                    data_json_iter=json.loads(data_str_iter)
                    if isinstance(data_json_iter['es_score'],dict):
                        es_score=max([float(elem) for elem in list(data_json_iter['es_score'].values())])
                    else:
                        es_score=float(data_json_iter['es_score'])
                    if es_score >= minimum_es_score and es_score <= maximum_es_score:
                        data_list.append(data_json_iter)
            else:
                data_list=[json.loads(data_str_iter) for data_str_iter in raw_data_list]
        logger.info(f"Use template {self.template_name} for training, bool_training={bool_training}, there are {len(data_list)} records in the dataset, raw record count={len(raw_data_list)}")
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def build_prompt_input_ids(self, system, history):
        """
        chatglm2: [gMASK]sop [Round 1]\n\n问:{input1}\n\n答:{target1}</s>[Round 2]\n\n问:{input2}\n\n答:{target2}</s>...
        chatglm3: [gMASK]sop <|system|>xxx<|user|>xxx<|assistant|>xxx<eos>
        others: {system_format}{user_format}{assistant_format}{user_format}{assistant_format}...
        """
        # chatglm models have special prefix tokens
        if self.template_name in ['chatglm2', 'chatglm3']:
            prompt_input_ids = self.tokenizer.get_prefix_tokens()
        else:
            prompt_input_ids = []
        prompt=''
        # collect system information
        if self.system_format is not None:
            system = system if system is not None else self.system
            # the system message is not empty
            if system is not None:
                if self.template_name == 'chatglm3':
                    prompt_input_ids += [self.tokenizer.get_command(f"<|system|>")] + self.tokenizer.encode(system, add_special_tokens=False)
                    # plain text only; the <|system|> command token exists only in prompt_input_ids
                    system_text = system
                else:
                    system_text = self.system_format.format(content=system)
                    prompt_input_ids += self.tokenizer.encode(system_text, add_special_tokens=False)
                prompt+=system_text
        # collect history
        ## concatenate the multi-turn user/assistant prompt text and input_ids
        for i, conv in enumerate(history):
            role = conv['role'].strip()
            content = conv['content'].strip()

            assert role != 'system', 'there should not be more than one system message'
            text_iter=''
            if role == 'user':
                if self.template_name == 'chatglm2':
                    human = self.user_format.format(content=content, idx=i//2 + 1)
                    input_ids = self.tokenizer.encode(human, add_special_tokens=False)
                elif self.template_name == 'chatglm3':
                    # plain text only; the <|user|>/<|assistant|> command tokens exist only in input_ids
                    human = content
                    input_ids = [self.tokenizer.get_command(f"<|user|>")] + \
                                self.tokenizer.encode(content, add_special_tokens=False) + \
                                [self.tokenizer.get_command(f"<|assistant|>")]
                else:
                    human = self.user_format.format(content=content, stop_token=self.tokenizer.eos_token)
                    input_ids = self.tokenizer.encode(human, add_special_tokens=False)
                text_iter=human
            elif role == 'assistant':
                if self.template_name in ['chatglm2', 'chatglm3']:
                    assistant = content
                    input_ids = self.tokenizer.encode(content, add_special_tokens=False) + [self.tokenizer.eos_token_id]
                else:
                    assistant = self.assistant_format.format(content=content, stop_token=self.tokenizer.eos_token)
                    input_ids = self.tokenizer.encode(assistant, add_special_tokens=False)
                text_iter=assistant
            else:
                raise Exception('role error')
            prompt_input_ids += input_ids
            prompt += text_iter

        return prompt_input_ids,prompt

    def __getitem__(self, index):
        data = self.data_list[index]
        # data = json.loads(data)
        chosen = data['chosen']
        rejected = data['rejected']
        assert len(chosen) == len(rejected)

        # check whether the first message is the system message
        if chosen[0]['role'] == 'system':
            system = chosen[0]['content'].strip()
            history = chosen[1:-1]  # dialogue context
            chosen, rejected = chosen[-1], rejected[-1]
        else:
            # user/assistant turns only; for a single-turn sample, history is empty
            system = None
            history = chosen[:-1]  # dialogue context
            ## chosen/rejected: the assistant reply in the last turn
            chosen, rejected = chosen[-1], rejected[-1]

        # build prompt
        # build the system and history part
        prompt_input_ids,prompt = self.build_prompt_input_ids(system, history)

        # build response
        if self.template_name in ['chatglm2', 'chatglm3']:
            chosen_input_ids = self.tokenizer.encode(chosen['content'], add_special_tokens=False) + [self.tokenizer.eos_token_id]
            rejected_input_ids = self.tokenizer.encode(rejected['content'], add_special_tokens=False) + [self.tokenizer.eos_token_id]
        else:
            # formatted text for the chosen content
            chosen = self.assistant_format.format(content=chosen['content'], stop_token=self.tokenizer.eos_token)
            # formatted text for the rejected content
            rejected = self.assistant_format.format(content=rejected['content'], stop_token=self.tokenizer.eos_token)

            chosen_input_ids = self.tokenizer.encode(chosen, add_special_tokens=False)
            rejected_input_ids = self.tokenizer.encode(rejected, add_special_tokens=False)

        # truncate by max_seq_length
        ## todo: truncate/filter over-long generations when building the corpus
        longer_response_length = max(len(chosen_input_ids), len(rejected_input_ids))
        # if combined sequence is too long, truncate the prompt
        if len(prompt_input_ids) + longer_response_length > self.max_seq_length:
            # keep the larger of the static max_prompt_length and max_seq_length - longer_response_length
            max_prompt_length = max(self.max_prompt_length, self.max_seq_length - longer_response_length)
            # truncate from the left, keeping the end of the prompt
            prompt_input_ids = prompt_input_ids[-max_prompt_length:]
        # if that's still too long, truncate the response
        ## this happens when self.max_prompt_length > max_seq_length - longer_response_length, because max() above keeps the longer prompt
        if len(prompt_input_ids) + longer_response_length > self.max_seq_length:
            chosen_input_ids = chosen_input_ids[: self.max_seq_length - len(prompt_input_ids)]
            rejected_input_ids = rejected_input_ids[: self.max_seq_length - len(prompt_input_ids)]
        chosen_content_of_assist_len=len(chosen_input_ids)
        reject_content_of_assist_len=len(rejected_input_ids)
        chosen_labels = [-100] * len(prompt_input_ids) + chosen_input_ids
        chosen_input_ids = prompt_input_ids + chosen_input_ids
        rejected_labels = [-100] * len(prompt_input_ids) + rejected_input_ids
        rejected_input_ids = prompt_input_ids + rejected_input_ids
        assert len(chosen_labels) == len(chosen_input_ids)
        assert len(rejected_labels) == len(rejected_input_ids)
        if np.random.random()<0.01:
            info_msg=f'longer_response_length={longer_response_length},prompt_input_ids len={len(prompt_input_ids)},'+ \
            f'chosen_answer_len={chosen_content_of_assist_len},rejected_answer_len={reject_content_of_assist_len},'+\
            f'chosen_len_with_prompt={len(chosen_input_ids)},rejected_len_with_prompt={len(rejected_input_ids)},'+\
            f'prompt_input_ids={prompt_input_ids},\n chosen_input_ids={chosen_input_ids},\n rejected_input_ids={rejected_input_ids},\n'+\
            f'chosen_labels={chosen_labels},\n rejected_labels={rejected_labels}'
            print(info_msg) 
        inputs = dict(
            prompt_input_ids=prompt_input_ids,
            prompt_attention_mask=[1]*len(prompt_input_ids),
            chosen_input_ids=chosen_input_ids,
            chosen_attention_mask=[1]*len(chosen_input_ids),
            chosen_labels=chosen_labels,
            rejected_input_ids=rejected_input_ids,
            rejected_attention_mask=[1]*len(rejected_input_ids),
            rejected_labels=rejected_labels,
            prompt=prompt
        )
        return inputs

    # to match the DPOTrainer interface
    def map(self, func, **kwargs):
        return self
    def select(self,index_list):
        select_data_lsit=[]
        for index in index_list:
            data_iter=self.data_list[index]
            select_data_lsit.append(data_iter)
        return select_data_lsit
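The two-step truncation in `__getitem__` above is easier to check in isolation. Below is a minimal sketch under the assumption that token ids are plain Python lists; `truncate_dpo_pair` is a hypothetical helper, and the `max(0, ...)` guards are additions not present in the original (they stop `[-0:]` or a negative slice from keeping the wrong tokens).

```python
from typing import List, Tuple


def truncate_dpo_pair(
    prompt_ids: List[int],
    chosen_ids: List[int],
    rejected_ids: List[int],
    max_seq_length: int,
    max_prompt_length: int,
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper modelled on UnifiedDPODataset.__getitem__."""
    longer = max(len(chosen_ids), len(rejected_ids))
    # Step 1: shorten the prompt from the left, keeping its last `keep` tokens.
    if len(prompt_ids) + longer > max_seq_length:
        keep = max(0, max_prompt_length, max_seq_length - longer)  # max(0, ...) guard added here
        prompt_ids = prompt_ids[max(0, len(prompt_ids) - keep):]
    # Step 2: still too long, which happens when max_prompt_length won the max() above
    # (or the responses alone exceed max_seq_length); cut both responses from the right.
    if len(prompt_ids) + longer > max_seq_length:
        budget = max(0, max_seq_length - len(prompt_ids))
        chosen_ids = chosen_ids[:budget]
        rejected_ids = rejected_ids[:budget]
    return prompt_ids, chosen_ids, rejected_ids


# max_prompt_length=8 wins the max(): the prompt keeps 8 tokens and both responses are cut to at most 4.
p, c, r = truncate_dpo_pair(list(range(10)), [100] * 6, [200] * 4, max_seq_length=12, max_prompt_length=8)
print(len(p), len(c), len(r))
```

With a smaller `max_prompt_length` (say 4), step 1 keeps `max_seq_length - longer` prompt tokens and step 2 never fires.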
        
class UnifiedOnlineDPODataset(UnifiedDPODataset):
    def __init__(self, file, tokenizer, max_seq_length,template,
                 maximum_es_score,minimum_es_score,bool_training:bool):
        max_prompt_length=max_seq_length
        super(UnifiedOnlineDPODataset, self).__init__(file=file, tokenizer=tokenizer, max_seq_length=max_seq_length, 
                max_prompt_length=max_prompt_length, template=template,maximum_es_score=maximum_es_score,minimum_es_score=minimum_es_score,
                bool_training=bool_training)
    def __getitem__(self, index):
        data = self.data_list[index]
        # build prompt
        # build the system and history part
        # check whether the first message is the system message
        # chosen = data['chosen']
        # if chosen[0]['role'] == 'system':
        #     system = chosen[0]['content'].strip()
        #     history = chosen[1:-1]  # dialogue context
        #     chosen = chosen[-1]
        # else:
        #     # user/assistant turns only; for a single-turn sample, history is empty
        #     system = None
        #     history = chosen[:-1]  # dialogue context
        #     ## chosen/rejected: the assistant reply in the last turn
        #     chosen = chosen[-1]
        # prompt_input_ids,prompt = self.build_prompt_input_ids(system, history)

        prompt=data['prompt']
        groundtruth=data['groundtruth']
        # self.build_prompt_input_ids(system, history)       
        # prompt_input_ids = self.tokenizer.encode(prompt, add_special_tokens=False) + [self.tokenizer.eos_token_id]
        prompt_input_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        ##todo assert fim_end
        assert groundtruth.endswith(TC.DS_EOS_TOKEN)
        assert not prompt.endswith(TC.DS_EOS_TOKEN)
        # system = None

        # truncate by max_seq_length
        ## todo: truncate/filter over-long generations when building the corpus
        # if the prompt is too long, truncate it
        if len(prompt_input_ids) > self.max_prompt_length:
            # keep the last max_prompt_length tokens
            prompt_input_ids = prompt_input_ids[-self.max_prompt_length:]
            decoded_prompt=self.tokenizer.decode(prompt_input_ids,skip_special_tokens=False)
            double_decoded_prompt_ids=self.tokenizer.encode(decoded_prompt,add_special_tokens=False)
            #for check 
            try:
                zipo_decode_tuple_list=list(zip(prompt[-self.max_prompt_length:][::-1],decoded_prompt[-self.max_prompt_length:][::-1]))[::-1]
                zipo_decode_id_tuple_list=list(zip(prompt_input_ids[::-1],double_decoded_prompt_ids[::-1]))[::-1]
                assert decoded_prompt[-self.max_prompt_length:]==prompt[-self.max_prompt_length:]
            except AssertionError as e:
                # print(f'decoded_prompt[-self.max_prompt_length:]=\n{decoded_prompt[-self.max_prompt_length:]},\nprompt[-self.max_prompt_length:]={prompt[-self.max_prompt_length:]},')        
                print(f'decoded_prompt[-self.max_prompt_length:]={decoded_prompt[-self.max_prompt_length:]},\n'+\
                    f'prompt[-self.max_prompt_length:]={prompt[-self.max_prompt_length:]},\n'+\
                    f'zipo_decode_tuple_list={zipo_decode_tuple_list},\nzipo_decode_id_tuple_list={zipo_decode_id_tuple_list}')
            prompt=decoded_prompt
            # prompt=self.tokenizer.convert_ids_to_tokens(prompt_input_ids)
        inputs = dict(
            # prompt_input_ids=prompt_input_ids,
            prompt_input_ids=prompt_input_ids,
            prompt_attention_mask=[1]*len(prompt_input_ids),
            prompt=prompt,
            groundtruth=groundtruth
        )
        return inputs
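Both dataset classes load their records through the `es_score` filter in `UnifiedDPODataset.__init__`, which parses every line twice and, after printing a `JSONDecodeError`, still fails on the same line in the second `json.loads`. A minimal sketch of a single-pass loader, assuming that malformed and blank lines should be skipped rather than abort loading (a design choice, not the original behaviour); `es_score_of` and `load_records` are hypothetical names.

```python
import json
from typing import Any, Dict, Iterable, List


def es_score_of(record: Dict[str, Any]) -> float:
    """Scalar es_score, or the max over a dict of scores (the rule used in __init__ above)."""
    score = record["es_score"]
    if isinstance(score, dict):
        return max(float(v) for v in score.values())
    return float(score)


def load_records(lines: Iterable[str], minimum: float, maximum: float, training: bool) -> List[Dict[str, Any]]:
    """Hypothetical loader: parse each JSONL line once, filter by es_score only when training."""
    records: List[Dict[str, Any]] = []
    for line_no, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # blank lines are skipped (assumption)
        try:
            record = json.loads(line)
        except json.JSONDecodeError as err:
            # report and skip malformed lines instead of parsing them again later
            print(f"skipping line {line_no}: {err}")
            continue
        # evaluation data (training=False) is kept without touching es_score
        if not training or minimum <= es_score_of(record) <= maximum:
            records.append(record)
    return records
```

In the dataset this would be called as `load_records(f, minimum_es_score, maximum_es_score, bool_training)` on the open file.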



from typing import Any, Dict, List
import torch
from loguru import logger
import copy
import numpy as np


def general_pad_and_truncate(direction,inputs:Dict,params:Dict):
    batch_max_len=params['batch_max_len']
    # max_seq_length=params['max_seq_length']
    pad_token_id=params['pad_token_id']
    logger.info(f'general_pad_and_truncate params={params}')
    key_endswith_input_ids=[k for k,v in inputs.items() if k.endswith('input_ids')]
    # exactly one *input_ids key is expected per example; check before indexing
    assert len(key_endswith_input_ids)==1, f'expected one *input_ids key, got {key_endswith_input_ids}'
    input_ids=inputs[key_endswith_input_ids[0]]
    padding_len = batch_max_len - len(input_ids)
    if padding_len<0:
        err_msg=f'needs truncation: padding_len={padding_len},batch_max_len={batch_max_len},len(input_ids)={len(input_ids)}'
        raise ValueError(err_msg)
    if direction not in ('left','right'):
        raise ValueError(f'direction={direction} not supported')
    new_inputs={}
    for k,v in inputs.items():
        if k.endswith('input_ids') or k.endswith('mask'):
            if isinstance(v,torch.Tensor):
                v=v.tolist()
            pad_value=pad_token_id if k.endswith('input_ids') else 0
            fill=[pad_value]*padding_len
            v=fill+v if direction=='left' else v+fill
        # other keys (e.g. prompt, groundtruth) are passed through unchanged in both directions
        new_inputs[k]=v
    # truncate
    for k,v in new_inputs.items():
        if k.endswith('input_ids') or k.endswith('mask'):
            new_inputs[k]=v[:batch_max_len]
    # input_ids = input_ids[:max_seq_length]
    # attention_mask = attention_mask[:max_seq_length]
    # target_mask = target_mask[:max_seq_length]
    ##modify
    # assert input_ids[0]!=pad_token_id
    return new_inputs

def pad_and_truncate(direction,inputs:Dict,params:Dict):
    batch_max_len=params['batch_max_len']
    max_seq_length=params['max_seq_length']
    pad_token_id=params['pad_token_id']
    input_ids=inputs['input_ids']
    attention_mask=inputs['attention_mask']
    target_mask=inputs['target_mask']
    padding_len = batch_max_len - len(input_ids)
    ## batch_max_len = min(max(lengths), self.max_seq_length): when max_seq_length is the smaller value, batch_max_len < len(input_ids) is possible
    ## in that case padding_len is negative and the padding has no effect
    try:
        assert padding_len>=0
    except AssertionError as e:
        err_msg=f'AssertionError, needs truncation: padding_len={padding_len},batch_max_len={batch_max_len},len(input_ids)={len(input_ids)}'
        # logger.info()
        raise ValueError(err_msg)

    if direction=='left':
        input_ids=[pad_token_id]*padding_len+input_ids
        attention_mask=[0]*padding_len+attention_mask
        target_mask=[0]*padding_len+target_mask

        input_ids=input_ids[-max_seq_length:]
        attention_mask=attention_mask[-max_seq_length:]        
        target_mask=target_mask[-max_seq_length:]
        assert input_ids[-2]!=pad_token_id, f'assert failed: position -2 is pad_token_id={pad_token_id}, input_ids =\n{input_ids}'
    elif direction=='right':
        input_ids = input_ids + [pad_token_id] * padding_len
        attention_mask = attention_mask + [0] * padding_len
        target_mask = target_mask + [0] * padding_len
        # truncate
        input_ids = input_ids[:max_seq_length]
        attention_mask = attention_mask[:max_seq_length]
        target_mask = target_mask[:max_seq_length]
        assert input_ids[0]!=pad_token_id
    else:
        raise ValueError('error')
    assert len(input_ids)==len(attention_mask)==len(target_mask)
    return {
        'input_ids':input_ids,
        'attention_mask':attention_mask,
        'target_mask':target_mask
    }
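In `general_pad_and_truncate`, keys that end in neither `input_ids` nor `mask` are passed through by the left branch but trigger `ValueError('error')` in the right branch, and the OnlineDPO examples always carry `prompt` and `groundtruth`. Since `padding_len >= 0` is checked first, the trailing truncate loop only matters if a mask is longer than its `input_ids`. A minimal sketch that treats both directions the same, assuming plain-list inputs; `pad_example` is a hypothetical name.

```python
from typing import Any, Dict


def pad_example(inputs: Dict[str, Any], batch_max_len: int, pad_token_id: int, direction: str = "left") -> Dict[str, Any]:
    """Hypothetical helper: pad the single *input_ids field and every *mask field to batch_max_len."""
    if direction not in ("left", "right"):
        raise ValueError(f"unsupported direction: {direction}")
    ids_keys = [k for k in inputs if k.endswith("input_ids")]
    if len(ids_keys) != 1:
        raise ValueError(f"expected exactly one *input_ids key, got {ids_keys}")
    padding_len = batch_max_len - len(inputs[ids_keys[0]])
    if padding_len < 0:
        raise ValueError(f"example is {-padding_len} tokens longer than batch_max_len")
    out: Dict[str, Any] = {}
    for key, value in inputs.items():
        if key.endswith("input_ids"):
            fill = [pad_token_id] * padding_len
        elif key.endswith("mask"):
            fill = [0] * padding_len
        else:
            out[key] = value  # non-sequence fields such as prompt / groundtruth pass through
            continue
        value = list(value)  # expects lists/tuples; call .tolist() on tensors first
        out[key] = fill + value if direction == "left" else value + fill
    return out
```

Left padding is the usual choice for prompts fed to `generate()` in decoder-only models, because every prompt then ends at the same position.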
class SFTDataCollator(object):
    def __init__(self, tokenizer, max_seq_length,pad_direction='right',max_len_strategy='batch_max_len'):
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.pad_token_id = tokenizer.pad_token_id
        self.pad_direction=pad_direction
        self.max_len_strategy=max_len_strategy

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Pad/truncate every example in the batch and build labels from target_mask.
        """
        # find the max length in the batch
        lengths = [len(x['input_ids']) for x in batch if x['input_ids'] is not None]
        # take the batch max length, capped at max_seq_length
        if self.max_len_strategy=='batch_max_len':
            batch_max_len = min(max(lengths), self.max_seq_length)
        elif self.max_len_strategy=='max_seq_length':
            batch_max_len=self.max_seq_length
        else:
            raise ValueError(f'max_len_strategy={self.max_len_strategy} not support')
        # batch_max_len = self.max_seq_length
        params={
                'batch_max_len':batch_max_len,
                'max_seq_length':self.max_seq_length,
                'pad_token_id':self.pad_token_id
                }
        input_ids_batch, attention_mask_batch, target_mask_batch = [], [], []
        # truncate and padding
        for x in batch:
            input_ids = x['input_ids']
            attention_mask = x['attention_mask']
            target_mask = x['target_mask']
            if input_ids is None:
                logger.info('some input_ids is None')
                continue
            # padding_len = batch_max_len - len(input_ids)
            x_of_pad_and_truncate=pad_and_truncate(direction=self.pad_direction,inputs=x,params=params)
            input_ids = x_of_pad_and_truncate['input_ids']
            attention_mask = x_of_pad_and_truncate['attention_mask']
            target_mask = x_of_pad_and_truncate['target_mask']

            input_ids_batch.append(input_ids)
            attention_mask_batch.append(attention_mask)
            target_mask_batch.append(target_mask)

        # convert the lists to tensors to get the final model inputs
        input_ids_batch = torch.tensor(input_ids_batch, dtype=torch.long)
        attention_mask_batch = torch.tensor(attention_mask_batch, dtype=torch.long)
        target_mask_batch = torch.tensor(target_mask_batch, dtype=torch.long)
        ## important: positions where target_mask == 0 are set to -100 so they are excluded from the loss (all fim_xx tokens end up as -100)
        labels = torch.where(target_mask_batch == 1, input_ids_batch, -100)
        for check_input_ids_iter in input_ids_batch:
            assert len(check_input_ids_iter)<=self.max_seq_length
        inputs = {
            'input_ids': input_ids_batch,
            'attention_mask': attention_mask_batch,
            'labels': labels
        }
        return inputs
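`SFTDataCollator` pads every example to `batch_max_len` and then derives `labels` with `torch.where(target_mask_batch == 1, input_ids_batch, -100)`. A plain-list sketch of the same rule for the `batch_max_len` strategy with right padding; unlike `pad_and_truncate`, which raises `ValueError` when an example is longer than the capped `batch_max_len`, this sketch truncates first (a design choice). `collate_sft` is a hypothetical name, and matching lengths for `input_ids`, `attention_mask` and `target_mask` are assumed.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def collate_sft(batch: List[Dict[str, List[int]]], max_seq_length: int, pad_token_id: int) -> Dict[str, List[List[int]]]:
    """Hypothetical collator: right-pad a batch and derive labels from target_mask."""
    # 'batch_max_len' strategy: the longest example, capped at max_seq_length
    batch_max_len = min(max(len(x["input_ids"]) for x in batch), max_seq_length)
    out: Dict[str, List[List[int]]] = {"input_ids": [], "attention_mask": [], "labels": []}
    for x in batch:
        # truncate first, so over-long examples are cut instead of raising
        ids = x["input_ids"][:batch_max_len]
        mask = x["attention_mask"][:batch_max_len]
        target = x["target_mask"][:batch_max_len]
        pad = batch_max_len - len(ids)
        ids = ids + [pad_token_id] * pad
        mask = mask + [0] * pad
        target = target + [0] * pad
        out["input_ids"].append(ids)
        out["attention_mask"].append(mask)
        # same rule as torch.where(target_mask == 1, input_ids, -100)
        out["labels"].append([tok if t == 1 else IGNORE_INDEX for tok, t in zip(ids, target)])
    return out
```

Wrapping each list in `torch.tensor(..., dtype=torch.long)` gives the same tensors the collator returns.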


def input_ids_mask_unneed_ids(input_ids,unneed_ids_tuple_list,ignore_index=-100,tokenizer=None):
    """
        Mask unneeded spans such as #Path (from the start id to the end id, inclusive) with ignore_index
    """
    # start_pos_list=[tuple_iter[0] for tuple_iter in unneed_ids_tuple_list]
    # end_pos_list=[tuple_iter[1] for tuple_iter in unneed_ids_tuple_list]

    new_input_ids=copy.deepcopy(input_ids)
    bool_mask_counter=0
    for ids_tuple_iter in unneed_ids_tuple_list:
        start_id=ids_tuple_iter[0]
        end_id=ids_tuple_iter[1]
        start_index=-10000
        end_index  =-10000
        for index,input_id_iter in enumerate(new_input_ids):
            if input_id_iter==start_id:
                start_index=index
            elif input_id_iter==end_id:
                end_index=index
                try:
                    assert start_index!=-10000
                    new_input_ids[start_index:end_index+1]=[ignore_index]*(end_index+1-start_index)
                    start_index=-10000
                    end_index=-10000
                    bool_mask_counter=1
                except AssertionError as e:
                    check_start_id_list=[(check_index,check_input_id) for check_index,check_input_id in enumerate(new_input_ids) if check_input_id==start_id]
                    check_end_id_list=[(check_index,check_input_id) for check_index,check_input_id in enumerate(new_input_ids) if check_input_id==end_id]
                    logger.info(f'AssertionError input_ids_mask_unneed_ids AssertionError \ncur_index={index},check_start_id_list={check_start_id_list},check_end_id_list={check_end_id_list}')
                    try:
                        new_text=tokenizer.decode(new_input_ids)
                        logger.info(f'AssertionError new_text={new_text}')
                    except OverflowError as e:
                        logger.info(f'AssertionError OverflowError new_input_ids={new_input_ids}')
                    # if tokenizer is not None:

            else:
                pass
    return new_input_ids,bool_mask_counter
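`input_ids_mask_unneed_ids` scans the ids once per `(start_id, end_id)` pair and overwrites each closed span, markers included, with `ignore_index`. A compact sketch of the same scan, assuming that a later start marker replaces an earlier unmatched one (as `start_index` is overwritten in the original) and that an end marker with no open start is silently ignored instead of logged; `mask_spans` is a hypothetical name and returns a bool rather than the original 0/1 counter.

```python
from typing import List, Sequence, Tuple


def mask_spans(ids: Sequence[int], spans: Sequence[Tuple[int, int]], ignore_index: int = -100) -> Tuple[List[int], bool]:
    """Hypothetical helper: replace every start..end marker span (inclusive) with ignore_index."""
    out = list(ids)
    masked_any = False
    for start_id, end_id in spans:
        start = None
        for i, tok in enumerate(out):
            if tok == start_id:
                start = i  # a later start marker overrides an earlier unmatched one
            elif tok == end_id and start is not None:
                # same-length slice assignment, so enumerate() keeps walking correctly
                out[start:i + 1] = [ignore_index] * (i + 1 - start)
                start = None
                masked_any = True
    return out, masked_any
```

Called on `labels` (as `PretrainCollator` does), this keeps the `#Path` span visible to the model through `input_ids` while excluding it from the loss.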


class PretrainCollator(object):

    def __init__(self, tokenizer, max_seq_length,args):
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.pad_token_id = tokenizer.pad_token_id
        #add
        self.args=args
        self.unneed_ids_tuple_list=[(tokenizer.path_start_token_id,tokenizer.path_end_token_id)]
        self.bool_mask_counter=0
        self.bool_mask_counter_all=0
    def mask_input_ids_of_path(self,input_ids,ignore_index=-100):
        # new_input_ids=input_ids_mask_unneed_ids(input_ids=input_ids,unneed_ids_tuple_list=self.unneed_ids_tuple_list)
        # target_mask=[ignore_id if input_id==-1 else 1 for input_id in new_input_ids]
        return input_ids_mask_unneed_ids(input_ids=input_ids,unneed_ids_tuple_list=self.unneed_ids_tuple_list,ignore_index=ignore_index,tokenizer=self.tokenizer)

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        # for x in batch:
            # print('PretrainCollator x.keys()=',x.keys())
        ignore_index=-100

        batch = [x['input_ids'] for x in batch if x['input_ids'] is not None]
        # find the max length in the batch
        lengths = [len(x) for x in batch]
        # take the batch max length, capped at max_seq_length
        batch_max_len = min(max(lengths), self.max_seq_length)
        
        padding_len_list=[]
        ignore_index_len_tuple_list=[]
        # batch_max_len = self.max_seq_length
        
        input_ids_batch, attention_mask_batch, labels_batch = [], [], []
        # only padding is masked in attention_mask; in labels it is set to ignore_index
        for x in batch:
            input_ids = x
            attention_mask = [1] * len(input_ids)
            padding_len = batch_max_len - len(input_ids)
            #for check
            padding_len_list.append(padding_len)
            # padding
            labels = input_ids + [ignore_index] * padding_len
            input_ids = input_ids + [self.pad_token_id] * padding_len
            attention_mask = attention_mask + [0] * padding_len
            # truncate
            input_ids = input_ids[:self.max_seq_length]
            labels = labels[:self.max_seq_length]
            attention_mask = attention_mask[:self.max_seq_length]

            if self.args.bool_target_mask:
                ignore_index_num_before_mask_unneed=labels.count(ignore_index)
                labels_of_mask_unneed,bool_mask_counter_iter=self.mask_input_ids_of_path(labels,ignore_index=ignore_index)
                labels=labels_of_mask_unneed
                ignore_index_num_after_mask_unneed=labels.count(ignore_index)
                ignore_index_len_tuple_list.append((padding_len,ignore_index_num_before_mask_unneed,ignore_index_num_after_mask_unneed))
                self.bool_mask_counter+=bool_mask_counter_iter
                self.bool_mask_counter_all+=1
            input_ids_batch.append(input_ids)
            labels_batch.append(labels)
            attention_mask_batch.append(attention_mask)
        
        if self.args.bool_target_mask and np.random.rand() < 0.01:
            #for check
            print(f'PretrainCollator lengths={lengths} max(lengths)={max(lengths)},min(lengths)={min(lengths)}')
            check_batch_index=np.random.choice(len(input_ids_batch))
            check_padding_len,check_ignore_index_num_before_mask_unneed,check_ignore_index_num_after_mask_unneed=ignore_index_len_tuple_list[check_batch_index]
            print_string=f'padding_len={check_padding_len},ignore_index_num_before_mask_unneed={check_ignore_index_num_before_mask_unneed},ignore_index_num_after_mask_unneed={check_ignore_index_num_after_mask_unneed}'
            print(print_string)
            prefix_100=input_ids_batch[check_batch_index][:100]
            # np.random.randint(low, high) needs high > low, which fails when batch_max_len <= 300
            random_mid_start=np.random.randint(100,max(101,batch_max_len-200))
            mid_random_100=input_ids_batch[check_batch_index][random_mid_start:random_mid_start+100]
            suffix_100=input_ids_batch[check_batch_index][-100:]
            prefix_decode_text=self.tokenizer.decode(prefix_100)
            mid_suffix_rand_decode_text=self.tokenizer.decode(mid_random_100)
            suffix_decode_text=self.tokenizer.decode(suffix_100)
            print(f'PretrainCollator prefix_100={prefix_100},suffix_100={suffix_100},mid_random_100={mid_random_100}')
            logger.info(f'bool_mask_counter_all={self.bool_mask_counter_all},bool_mask_counter={self.bool_mask_counter}')
        # convert the lists to tensors to get the final model inputs
        input_ids_batch = torch.tensor(input_ids_batch, dtype=torch.long)
        # labels are not shifted here
        labels_batch = torch.tensor(labels_batch, dtype=torch.long)
        attention_mask_batch = torch.tensor(attention_mask_batch, dtype=torch.long)

        inputs = {
            'input_ids': input_ids_batch,
            'attention_mask': attention_mask_batch,
            'labels': labels_batch
        }

        return inputs



class OnlineDPOCollator(object):
    def __init__(self, tokenizer, max_seq_length,pad_direction='left',max_len_strategy='max_seq_length'):
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.pad_token_id = tokenizer.pad_token_id
        self.pad_direction=pad_direction
        self.max_len_strategy=max_len_strategy

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
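        """
        Left-pad prompt_input_ids/prompt_attention_mask and collect the prompt/groundtruth strings.
        """
        # find the max length in the batch
        lengths = [len(x['prompt_input_ids']) for x in batch if x['prompt_input_ids'] is not None]
        # take the batch max length, capped at max_seq_length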
        if self.max_len_strategy=='batch_max_len':
            batch_max_len = min(max(lengths), self.max_seq_length)
        elif self.max_len_strategy=='max_seq_length':
            batch_max_len=self.max_seq_length
        else:
            raise ValueError(f'max_len_strategy={self.max_len_strategy} not support')
        # batch_max_len = self.max_seq_length
        params={
                'batch_max_len':batch_max_len,
                # 'max_seq_length':self.max_seq_length,
                'pad_token_id':self.pad_token_id
                }
        prompt_input_ids_batch, prompt_attention_mask_batch = [], []
        prompt_list,groundtruth_list=[],[]
        # truncate and padding
        for x in batch:
            prompt_input_ids = x['prompt_input_ids']
            prompt_attention_mask = x['prompt_attention_mask']
            # target_mask = x['target_mask']
            if prompt_input_ids is None:
                logger.info(f'some prompt_input_ids is None,x={x}')
                continue
            ##todo check
            if 'prompt' not in x:
                logger.info(f'some prompt is None,x={x}')
                continue
            prompt=x['prompt']
            groundtruth=x['groundtruth']
            prompt_list.append(prompt)
            groundtruth_list.append(groundtruth)
            # padding_len = batch_max_len - len(input_ids)
            # x_of_pad_and_truncate=pad_and_truncate(direction=self.pad_direction,inputs=x,params=params)
            x_of_pad_and_truncate=general_pad_and_truncate(direction='left',inputs=x,params=params)
            prompt_input_ids = x_of_pad_and_truncate['prompt_input_ids']
            prompt_attention_mask = x_of_pad_and_truncate['prompt_attention_mask']
            # target_mask = x_of_pad_and_truncate['prompt_target_mask']

            prompt_input_ids_batch.append(prompt_input_ids)
            prompt_attention_mask_batch.append(prompt_attention_mask)
            # target_mask_batch.append(target_mask)

        # Convert the lists to tensors to build the final model inputs
        prompt_input_ids_batch = torch.tensor(prompt_input_ids_batch, dtype=torch.long)
        prompt_attention_mask_batch = torch.tensor(prompt_attention_mask_batch, dtype=torch.long)
        # target_mask_batch = torch.tensor(target_mask_batch, dtype=torch.long)
        ## Important: set positions where target_mask == 0 to -100 so they are excluded from the loss => all fim_xx tokens become -100
        # labels = torch.where(target_mask_batch == 1, input_ids_batch, -100)
        for check_input_ids_iter in prompt_input_ids_batch:
            assert len(check_input_ids_iter)<=self.max_seq_length
        
        inputs = {
            'prompt_input_ids': prompt_input_ids_batch,
            'prompt_attention_mask': prompt_attention_mask_batch,
            'prompt':prompt_list,
            'groundtruth':groundtruth_list,
        }
        # logger.info(f'inputs.keys()={inputs.keys()},prompt_input_ids_batch.shape={prompt_input_ids_batch.shape},prompt_attention_mask={prompt_attention_mask.shape},prompt_list-len={len(prompt_list)},groundtruth_list-len={len(groundtruth_list)}')
        for check_k,check_v in inputs.items():
            check_v_shape= check_v.shape if isinstance(check_v,torch.Tensor) else len(check_v)
            logger.info(f'OnlineDPOCollator check_k={check_k},check_v={check_v_shape}')
        return inputs



def calc_edit_sim_per_sample(pred,gt):
    """
        Edit-distance-based similarity between pred and gt, scored 0-100 by fuzz.ratio.
    """
    return fuzz.ratio(pred, gt)

# class RandomPairwiseJudge(BasePairwiseJudge):
#     """
#     Random pairwise judge, for testing purposes.
#     """

#     def judge(self, prompts, completions, shuffle_order=True):
#         return [random.randint(0, len(completion) - 1) for completion in completions]



class CodeRulePairRMJudge(BasePairwiseJudge):
    def __init__(self,tokenizer,judge_text_type):
        self.tokenizer=tokenizer
        self.judge_text_type=judge_text_type
        # self.tokenizer=tokenizer
    # def judge(self,prompts, completions,groundtruths,return_scores=False, disable_tqdm=True):
    #     """
    #         completions = [["Bonjour", "Salut"], ["Kyoto", "Tokyo"]]
    #     """
    #     # return np.array(list([range(len(completions))]))
    #     rank_index=[[0,1],[1,0]]
    #     return np.array(rank_index)

    # def rank(self,prompts, completions,groundtruths,return_scores=False, disable_tqdm=True):
    #     """
    #         completions = [["Bonjour", "Salut"], ["Kyoto", "Tokyo"]]
    #     """
    #     # return np.array(list([range(len(completions))]))
    #     # rank_index=[[0,1],[1,0]]
    #     # return np.array(rank_index)

    def judge(self,prompts, completions:List[Tuple],groundtruths,return_scores=False, disable_tqdm=True):
        """
            completions = [["Bonjour", "Salut"], ["Kyoto", "Tokyo"]]
            Returns one index (0 or 1) per prompt: the completion closer to its groundtruth.
        """
        es_list=[]
        rank_list=[]
        index=0
        for completion_tuple,groundtruth,prompt in zip(completions,groundtruths,prompts):
            id_str=''
            if self.judge_text_type=='text':
                es_r0=calc_edit_sim_per_sample(pred=completion_tuple[0],gt=groundtruth)
                es_r1=calc_edit_sim_per_sample(pred=completion_tuple[1],gt=groundtruth)
            elif self.judge_text_type=='id':
                completion_ids_0=self.tokenizer.encode(completion_tuple[0], add_special_tokens=False)
                completion_ids_1=self.tokenizer.encode(completion_tuple[1], add_special_tokens=False)
                groundtruth_ids=self.tokenizer.encode(groundtruth, add_special_tokens=False)
                # es=calc_edit_sim_per_sample(pred=completion_ids,gt=groundtruth_ids)
                es_r0=calc_edit_sim_per_sample(pred=completion_ids_0,gt=groundtruth_ids)
                es_r1=calc_edit_sim_per_sample(pred=completion_ids_1,gt=groundtruth_ids)
                # es_r0=calc_edit_sim_per_sample_by_edit_dist(pred=completion_ids_0,gt=groundtruth_ids)
                # es_r1=calc_edit_sim_per_sample_by_edit_dist(pred=completion_ids_1,gt=groundtruth_ids)
                id_str=f'completion_ids_0={completion_ids_0},\ncompletion_ids_1={completion_ids_1},\ngroundtruth_ids={groundtruth_ids},\n'
            else:
                raise ValueError(f'judge_text_type={self.judge_text_type}')
            # One-hot marker of the better completion ([0, 0] on a tie), recorded for both judge_text_type modes
            rank=[int(es_r0>es_r1),int(es_r1>es_r0)]
            rank_list.append(rank)
            es_list.append([es_r0,es_r1])
            index+=1
            info_msg=f'CodeRulePairRMJudge.judge index={index}, es_r0={es_r0},es_r1={es_r1},rank={rank},\ncompletion[0]={completion_tuple[0]},\ncompletion[1]={completion_tuple[1]},\ngroundtruth={groundtruth},\n'+id_str+f'\nprompt={prompt},\n'
            logger.info(info_msg)
        # argmax of each one-hot row is the index of the preferred completion (ties -> 0)
        rank_index=torch.argmax(torch.tensor(rank_list),dim=-1).tolist()
        logger.info(f'rank_index={rank_index}',main_process_only=True)
        return rank_index
        # return np.array(es_list)


        # return np.array(list([range(len(completions))]))
        # rank_index=[[0,1],[1,0]]
        # return np.array(rank_index)


def init_components(args, training_args):
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    # if accelerator.is_main_process:
    #   judge = PairRMJudge()
    # accelerator.wait_for_everyone()
    # judge = PairRMJudge()
    # if accelerator.is_main_process:
    #   judge=RandomPairwiseJudge()
    # accelerator.wait_for_everyone()
    # judge=RandomPairwiseJudge()
    judge=CodeRulePairRMJudge(tokenizer,judge_text_type=args.judge_text_type)
    """
    >>> translation_generation_config = GenerationConfig(
         num_beams=4,
         early_stopping=True,
         decoder_start_token_id=0,
         eos_token_id=model.config.eos_token_id,
         pad_token=model.config.pad_token_id,       
    """
    template=template_dict['deepseek-coder-base']
    print(f'tokenizer={tokenizer},pad_token_id={tokenizer.pad_token_id},eos_token_id={tokenizer.eos_token_id},bos_token_id={tokenizer.bos_token_id}')
    # train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
    train_dataset = UnifiedOnlineDPODataset(file=args.train_file, tokenizer=tokenizer, 
                max_seq_length=args.max_seq_length, template=template,
                maximum_es_score=args.maximum_es_score,minimum_es_score=args.minimum_es_score,bool_training=True)
    eval_dataset = UnifiedOnlineDPODataset(file=args.eval_file, tokenizer=tokenizer, 
                max_seq_length=args.max_seq_length, template=template,
                maximum_es_score=args.maximum_es_score,minimum_es_score=args.minimum_es_score,bool_training=False)
    torch_dtype = torch.float16 if training_args.fp16 else torch.bfloat16
    model_kwargs = dict(
        trust_remote_code=True,
        # attn_implementation=attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=False,
    )
    # trainer.ref_model = Accelerator().prepare(dpo_trainer.ref_model) # add this
    # ref_model=AutoCausalLM.from_pretrained(model_name_or_path).to('cuda:7')
    # trainer.ref_model=ref_model
    if False:
        bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        )

        model = AutoModelForCausalLM.from_pretrained(
            # script_args.model_name,        # "meta-llama/Llama-2-7b-hf"
            pretrained_model_name_or_path=args.model_name_or_path,
            quantization_config=bnb_config,
            # device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch_dtype
            # use_auth_token=True,
        )
        model.config.use_cache = False
        #     "lora_rank": 128,
        #     "lora_alpha": 16,
        #     "lora_dropout": 0.05,
        # add LoRA layers on top of the quantized base model
        peft_config = LoraConfig(
            r=128,
            lora_alpha=16,
            lora_dropout=0.05,
            target_modules=["q_proj", "v_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )
    else:
        peft_config=None
        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,**model_kwargs)
    ref_model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,**model_kwargs)
    
    data_collator = OnlineDPOCollator(tokenizer=tokenizer, max_seq_length=args.max_seq_length,pad_direction='left',max_len_strategy='max_seq_length')  

    trainer = OnlineDPOTrainer(
        model=model,ref_model=ref_model,judge=judge, args=training_args,custom_args=args,processing_class=tokenizer, 
        train_dataset=train_dataset,eval_dataset=eval_dataset,
        peft_config=peft_config,
        data_collator=data_collator

    )
    completion_callback=LogCompletionsCallback(trainer=trainer,num_prompts=20)
    trainer.add_callback(completion_callback)
    return trainer

def main():
    cuda_msg=f'torch.cuda.get_device_name()={torch.cuda.get_device_name()},torch.cuda.current_device()={torch.cuda.current_device()}'
    logger.info(cuda_msg)
    args, training_args = setup_everything()
    trainer=init_components(args, training_args)
    # dpo_trainer.train()
    train_result=trainer.train()
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    if training_args.do_eval:
        metrics = trainer.evaluate()
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)
        if trainer.is_world_process_zero():
            logger.debug(f"Eval metrics: {metrics}")

    final_save_path = os.path.join(training_args.output_dir, "final_checkpoint")
    os.makedirs(final_save_path, exist_ok=True)
    ##?? No is_world_process_zero check here?
    trainer.save_model(final_save_path)  # Saves the tokenizer too  
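The collator above delegates the actual padding to `general_pad_and_truncate`, whose source is not shown. The plain-Python sketch below illustrates what left padding plus the two `max_len_strategy` options amount to. It assumes over-long prompts are truncated from the left, keeping the tokens closest to the completion. The helper names `left_pad_and_truncate` and `collate_prompt_ids` are hypothetical, and the sketch uses lists rather than tensors.

```python
from typing import List, Sequence, Tuple


def left_pad_and_truncate(
    ids: Sequence[int], max_len: int, pad_id: int
) -> Tuple[List[int], List[int]]:
    """Return (input_ids, attention_mask), both exactly max_len long.

    Assumption: over-long prompts are cut from the left, so the most recent
    tokens (those right before the completion) are kept.
    """
    kept = list(ids[-max_len:]) if max_len > 0 else []
    n_pad = max_len - len(kept)
    # Left padding: pad tokens first, real tokens last; the mask mirrors that.
    return [pad_id] * n_pad + kept, [0] * n_pad + [1] * len(kept)


def collate_prompt_ids(
    batch: Sequence[Sequence[int]],
    max_seq_length: int,
    pad_id: int,
    max_len_strategy: str = "batch_max_len",
) -> Tuple[List[List[int]], List[List[int]]]:
    if max_len_strategy == "batch_max_len":
        # Longest prompt in the batch, capped at max_seq_length.
        target = min(max(len(x) for x in batch), max_seq_length)
    elif max_len_strategy == "max_seq_length":
        # Every prompt is padded to the full max_seq_length.
        target = max_seq_length
    else:
        raise ValueError(f"max_len_strategy={max_len_strategy} not supported")
    ids_batch, mask_batch = [], []
    for ids in batch:
        padded, mask = left_pad_and_truncate(ids, target, pad_id)
        ids_batch.append(padded)
        mask_batch.append(mask)
    return ids_batch, mask_batch
```

Consider prompts of 3 and 1 tokens with `max_seq_length=4` and pad id 0. With `max_len_strategy='max_seq_length'` (the setting used in `init_components`), they come out as `[[0, 5, 6, 7], [0, 0, 0, 8]]`. With `batch_max_len`, they would only be padded to 3 tokens.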

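`CodeRulePairRMJudge` scores each of the two completions against the groundtruth with `fuzz.ratio` and prefers the closer one. Assuming `fuzz` comes from rapidfuzz, `fuzz.ratio` is a normalized Indel similarity, 100 · 2·LCS(a, b) / (|a| + |b|). It also accepts token-id lists, which the `'id'` mode relies on. The dependency-free sketch below reimplements that score. For each prompt it returns the index (0 or 1) of the preferred completion, which is how we understand TRL's pairwise-judge contract. `lcs_length`, `indel_ratio` and `judge_pairs` are hypothetical names. Sending ties to index 0 is an arbitrary choice of this sketch.

```python
from typing import Hashable, List, Sequence, Tuple


def lcs_length(a: Sequence[Hashable], b: Sequence[Hashable]) -> int:
    """Length of the longest common subsequence (row-by-row dynamic program)."""
    prev = [0] * (len(b) + 1)
    for x in a:
        cur = [0]
        for j, y in enumerate(b, start=1):
            cur.append(prev[j - 1] + 1 if x == y else max(prev[j], cur[j - 1]))
        prev = cur
    return prev[-1]


def indel_ratio(a: Sequence[Hashable], b: Sequence[Hashable]) -> float:
    """100 * 2*LCS / (len(a) + len(b)); works on strings or token-id lists."""
    total = len(a) + len(b)
    if total == 0:
        return 100.0  # convention chosen here for two empty inputs
    return 100.0 * 2 * lcs_length(a, b) / total


def judge_pairs(
    completions: Sequence[Tuple[Sequence[Hashable], Sequence[Hashable]]],
    groundtruths: Sequence[Sequence[Hashable]],
) -> List[int]:
    """For each prompt, the index (0 or 1) of the completion closer to its groundtruth."""
    preferred = []
    for (c0, c1), gt in zip(completions, groundtruths):
        s0, s1 = indel_ratio(c0, gt), indel_ratio(c1, gt)
        preferred.append(0 if s0 >= s1 else 1)  # ties go to index 0
    return preferred
```

For example, `judge_pairs([("abd", "abc")], ["abc"])` returns `[1]`, because the second completion matches the groundtruth exactly.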

The launch script is as follows:

. base.sh
export TRAIN_MODE=full
export TASK_TYPE=online_dpo
export DEEPSPEED=3
export DEEPSPEED_CONFIG_FILE=../train_args/ds_z3_config_new.json
if [ "${BOOL_DEBUG}" = "true" ];then
    SHUJUJUAN_MODEL_DIR=$LPAI_INPUT_DATA_0/data/multitask_best_rl_debug
elif [ "${BOOL_DEBUG}" = "false" ];then
    SHUJUJUAN_MODEL_DIR=$LPAI_INPUT_DATA_0/data/multitask_best_rl
else
    echo "BOOL_DEBUG is not set"
    exit 1
fi

# NOTE: this unconditional assignment overrides the BOOL_DEBUG branch above
SHUJUJUAN_MODEL_DIR=$LPAI_INPUT_DATA_0/data/multitask_best_rl
DATE_STR="${data-$(date +%Y%m%d%H%M)}"
export DEPARTMENT=full
export SHUJUJUAN_BEST_OUTPUT_DIR=$SHUJUJUAN_MODEL_DIR/${DATE_STR}/best_model/${DEPARTMENT} 
echo "SHUJUJUAN_BEST_OUTPUT_DIR=${SHUJUJUAN_BEST_OUTPUT_DIR}"


# export ZERO_STAGE2_CONFIG_JSONPATH=${WORKDIR}/train_args/ds_z3_config_new.json
# DS_CONFIG_FILE=../train_args/ds_z3_config_new.yaml


readarray -t elements < <(get_train_args_file  $TRAIN_MODE $TASK_TYPE)
train_args_file=${elements[0]}
echo "train_args_file=${train_args_file}"
readarray -t elements1 < <(get_ds_config_file  $DEEPSPEED $TASK_TYPE)
export DS_CONFIG_FILE=${elements1[0]}
export LOW_CPU_MEM_USAGE=${elements1[1]}
echo "DS_CONFIG_FILE=${DS_CONFIG_FILE}"


if [ "${DDP}" = "accelerate" ];then
    accelerate launch \
        --config_file ${DS_CONFIG_FILE} \
        --num_machines ${NODE_NUM} \
        --num_processes ${WORLD_SIZE} \
        --main_process_ip ${MASTER_ADDR} \
        --main_process_port ${MASTER_PORT} \
        ../train_onlinedpo.py \
        --train_args_file $train_args_file
elif [ "${DDP}" = "torchrun" ];then
    torchrun --nproc_per_node=$GPU_NUM --nnodes=$NODE_NUM --node_rank=$RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \
        ../train_onlinedpo.py \
        --train_args_file $train_args_file
fi


The environment variables are as follows:

{
  "MODEL_NAME_OR_PATH": "xxx",
  "TRAIN_FILE": "xxx",
  "EVAL_FILE": "xxx",
  "PER_DEVICE_TRAIN_BATCH_SIZE": "16",
  "MAX_SEQ_LENGTH": "4096",
  "MAX_PROMPT_LENGTH": "4096",
  "LEARNING_RATE": "5e-7",
  "GRADIENT_ACCUMULATION_STEPS": "1",
  "RPO_ALPHA": "0.5",
  "DEEPSPEED": "2",
  "MAXIMUM_ES_SCORE": "80",
  "MINIMUM_ES_SCORE": "10",
  "BOOL_DEBUG": "true",
  "MAX_GRAD_NORM": "1.0",
  "LOSS_TYPE": "ipo_norm",
  "BETA": "0.9",
  "MAX_NEW_TOKENS": "128",
  "JUDGE_TEXT_TYPE": "id",
  "NUM_TRAIN_EPOCHS": "1",
  "DDP": "torchrun",
  "DOLA_LAYERS": "none",
  "REPETITION_PENALTY": "1.0"
}
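The collator is built with `max_len_strategy='max_seq_length'` and `MAX_SEQ_LENGTH=4096`, so every prompt is left-padded to 4096 tokens before `generate` is called. Assuming `MAX_NEW_TOKENS` is passed through as `generate`'s `max_new_tokens`, each sequence can then grow to 4096 + 128 = 4224 tokens during sampling. The 4137/4138 sizes in the traceback below fall within that range. Here is the arithmetic as a small helper (`generation_length_budget` is a hypothetical name):

```python
def generation_length_budget(
    max_seq_length: int,
    max_new_tokens: int,
    longest_prompt: int,
    pad_to_max_seq_length: bool,
) -> int:
    """Upper bound on the sequence length seen by generate() for one batch."""
    if pad_to_max_seq_length:
        # max_len_strategy='max_seq_length': prompts always fill max_seq_length
        prompt_len = max_seq_length
    else:
        # max_len_strategy='batch_max_len': longest prompt, capped at max_seq_length
        prompt_len = min(longest_prompt, max_seq_length)
    return prompt_len + max_new_tokens


# Values from the ENV block above.
print(generation_length_budget(4096, 128, longest_prompt=1000, pad_to_max_seq_length=True))  # -> 4224
```

Switching to `batch_max_len` would cap the length at the longest prompt in the batch plus 128. For example, a batch whose longest prompt is 1000 tokens would reach at most 1128.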

Expected behavior

Error output:


2024-12-30 10:53:44.559 [rank4]: Traceback (most recent call last):
[rank4]:   File "/lpai-running/code/firefly-zyy-dev/339ecc/shells/../train_onlinedpo.py", line 251, in <module>
[rank4]:     main()
[rank4]:   File "/lpai-running/code/firefly-zyy-dev/339ecc/shells/../train_onlinedpo.py", line 195, in main
[rank4]:     train_result=trainer.train()
[rank4]:                  ^^^^^^^^^^^^^^^
[rank4]:   File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 2164, in train
[rank4]:     return inner_training_loop(
[rank4]:            ^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/opt/conda/lib/python3.11/site-packages/transformers/trainer.py", line 2522, in _inner_training_loop
[rank4]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank4]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/lpai-running/code/firefly-zyy-dev/339ecc/models/online_dpo_trainer.py", line 480, in training_step
[rank4]:     output = unwrapped_model.generate(
[rank4]:              ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank4]:     return func(*args, **kwargs)
[rank4]:            ^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/opt/conda/lib/python3.11/site-packages/transformers/generation/utils.py", line 2252, in generate
[rank4]:     result = self._sample(
[rank4]:              ^^^^^^^^^^^^^
[rank4]:   File "/opt/conda/lib/python3.11/site-packages/transformers/generation/utils.py", line 3254, in _sample
[rank4]:     outputs = model_forward(**model_inputs, return_dict=True)
[rank4]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank4]:     return self._call_impl(*args, **kwargs)
[rank4]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1603, in _call_impl
[rank4]:     result = forward_call(*args, **kwargs)
[rank4]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/opt/conda/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1163, in forward
[rank4]:     outputs = self.model(
[rank4]:               ^^^^^^^^^^^
[rank4]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank4]:     return self._call_impl(*args, **kwargs)
[rank4]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank4]:     return forward_call(*args, **kwargs)
[rank4]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/opt/conda/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 883, in forward
[rank4]:     causal_mask = self._update_causal_mask(
[rank4]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/opt/conda/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 993, in _update_causal_mask
[rank4]:     causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
[rank4]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank4]:   File "/opt/conda/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1060, in _prepare_4d_causal_attention_mask_with_cache_position
[rank4]:     causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
[rank4]: RuntimeError: The size of tensor a (4137) must match the size of tensor b (4138) at non-singleton dimension 0
2024-12-30 10:53:44.559 [rank4]: Exception raised from infer_size_impl at /opt/conda/conda-bld/pytorch_1720538435607/work/aten/src/ATen/ExpandUtils.cpp:31 (most recent call first):
[rank4]: C++ CapturedTraceback:
[rank4]: #4 std::_Function_handler<std::shared_ptr<c10::LazyValue<std::string> const> (), c10::SetStackTraceFetcher(std::function<std::string ()>)::{lambda()#1}>::_M_invoke(std::_Any_data const&) from Logging.cpp:0

[rank4]: #40 do_call_core from /usr/local/src/conda/python-3.11.9/Python/ceval.c:7349
[rank4]: #41 _PyEval_EvalFrame from /usr/local/src/conda/python-3.11.9/Include/internal/pycore_ceval.h:73
[rank4]: #42 method_vectorcall from /usr/local/src/conda/python-3.11.9/Objects/classobject.c:59
[rank4]: #43 _PyVectorcall_Call from /usr/local/src/conda/python-3.11.9/Objects/call.c:257
[rank4]: #44 do_call_core from /usr/local/src/conda/python-3.11.9/Python/ceval.c:7349
[rank4]: #45 _PyEval_EvalFrame from /usr/local/src/conda/python-3.11.9/Include/internal/pycore_ceval.h:73
[rank4]: #46 method_vectorcall from /usr/local/src/conda/python-3.11.9/Objects/classobject.c:59
[rank4]: #47 _PyVectorcall_Call from /usr/local/src/conda/python-3.11.9/Objects/call.c:257
[rank4]: #48 do_call_core from /usr/local/src/conda/python-3.11.9/Python/ceval.c:7349
[rank4]: #49 _PyEval_EvalFrame from /usr/local/src/conda/python-3.11.9/Include/internal/pycore_ceval.h:73
[rank4]: #50 method_vectorcall from /usr/local/src/conda/python-3.11.9/Objects/classobject.c:59
[rank4]: #51 _PyVectorcall_Call from /usr/local/src/conda/python-3.11.9/Objects/call.c:257
[rank4]: #52 do_call_core from /usr/local/src/conda/python-3.11.9/Python/ceval.c:7349
[rank4]: #53 _PyEval_EvalFrame from /usr/local/src/conda/python-3.11.9/Include/internal/pycore_ceval.h:73
[rank4]: #54 _PyVectorcall_Call from /usr/local/src/conda/python-3.11.9/Objects/call.c:257
[rank4]: #55 do_call_core from /usr/local/src/conda/python-3.11.9/Python/ceval.c:7349
[rank4]: #56 _PyEval_EvalFrame from /usr/local/src/conda/python-3.11.9/Include/internal/pycore_ceval.h:73
[rank4]: #57 method_vectorcall from /usr/local/src/conda/python-3.11.9/Objects/classobject.c:59
[rank4]: #58 _PyVectorcall_Call from /usr/local/src/conda/python-3.11.9/Objects/call.c:257
[rank4]: #59 partial_call from /usr/local/src/conda/python-3.11.9/Modules/_functoolsmodule.c:324
[rank4]: #60 _PyObject_MakeTpCall from /usr/local/src/conda/python-3.11.9/Objects/call.c:214
[rank4]: #61 _PyObject_VectorcallTstate from /usr/local/src/conda/python-3.11.9/Include/internal/pycore_call.h:92
[rank4]: #62 _PyEval_EvalFrameDefault from /usr/local/src/conda/python-3.11.9/Python/ceval.c:4769
[rank4]: #63 _PyEval_EvalFrame from /usr/local/src/conda/python-3.11.9/Include/internal/pycore_ceval.h:73
[rank4]: #64 PyEval_EvalCode from /usr/local/src/conda/python-3.11.9/Python/ceval.c:1148
[rank4]: #65 run_eval_code_obj from /usr/local/src/conda/python-3.11.9/Python/pythonrun.c:1741
[rank4]: #66 run_mod from /usr/local/src/conda/python-3.11.9/Python/pythonrun.c:1762
[rank4]: #67 pyrun_file from /usr/local/src/conda/python-3.11.9/Python/pythonrun.c:1657
[rank4]: #68 _PyRun_SimpleFileObject from /usr/local/src/conda/python-3.11.9/Python/pythonrun.c:440
[rank4]: #69 _PyRun_AnyFileObject from /usr/local/src/conda/python-3.11.9/Python/pythonrun.c:79
[rank4]: #70 pymain_run_file_obj from /usr/local/src/conda/python-3.11.9/Modules/main.c:360
[rank4]: #71 Py_BytesMain from /usr/local/src/conda/python-3.11.9/Modules/main.c:734
[rank4]: #72 __libc_start_call_main from ./csu/../sysdeps/nptl/libc_start_call_main.h:58
[rank4]: #73 __libc_start_main_impl from ./csu/../csu/libc-start.c:392
[rank4]: #74 _start from ??:0
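The failing statement is the last line of `_prepare_4d_causal_attention_mask_with_cache_position` shown in the Python traceback. As we read the transformers 4.47 Llama code, `causal_mask` there has one row per input position (`sequence_length`), while the right-hand side has one row per entry of `cache_position`. An in-place `*=` cannot broadcast its destination, so the two counts must agree. The error therefore reads as a 4137-row mask meeting a 4138-entry `cache_position`. Below is a NumPy stand-in (not the library's code; `apply_cache_position_mask` is a hypothetical name) that follows the same shape rule:

```python
import numpy as np


def apply_cache_position_mask(
    sequence_length: int, target_length: int, cache_position: np.ndarray
) -> np.ndarray:
    """NumPy mirror of the failing line; large negative = masked, 0 = visible."""
    min_value = np.finfo(np.float32).min
    causal_mask = np.full((sequence_length, target_length), min_value, dtype=np.float32)
    if sequence_length != 1:
        causal_mask = np.triu(causal_mask, k=1)
    # In-place multiply: the right-hand side has len(cache_position) rows and
    # must broadcast to the (sequence_length, target_length) destination.
    causal_mask *= np.arange(target_length) > cache_position.reshape(-1, 1)
    return causal_mask
```

With matching sizes (e.g. `sequence_length=3`, `cache_position=np.arange(3)`) it returns a 3×3 causal mask. Calling it with `sequence_length=4137` and `cache_position=np.arange(4138)` instead raises a `ValueError` about a non-broadcastable output operand, the NumPy counterpart of the `RuntimeError` above.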

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks, (no screenshot, more on code blocks)
  • Any traceback provided is complete
@yiyepiaoling0715
Author

image

@August-murr August-murr added 🐛 bug Something isn't working 🚀 deepspeed Related to deepspeed 🏋 Online DPO Related to Online DPO labels Dec 30, 2024
@qgallouedec
Member

qgallouedec commented Dec 30, 2024

  • Any code provided is minimal, complete, and reproducible (more on MREs)

Can you please minimise your code? It seems like the error occurs at generation; what is the input to the model here?

[rank4]:   File "/opt/conda/lib/python3.11/site-packages/transformers/generation/utils.py", line 3254, in _sample
[rank4]:     outputs = model_forward(**model_inputs, return_dict=True)

Can you reproduce the error without all the training logic?

@qgallouedec qgallouedec added the ⏳ needs more info Additional information or clarification is required to proceed label Jan 8, 2025