diff --git a/dataset.py b/dataset.py index 682990f..dc790cf 100644 --- a/dataset.py +++ b/dataset.py @@ -78,11 +78,21 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]: mel = mel[:, :self.config.max_mel_size] ids_audio = [0] * self.config.max_speech_token_size tgt_audio = [IGNORE_TOKEN_ID] * len(ids_audio) - chat = [{"role": "user", "content": "Transcribe the speech"}] + if 'instruction' in msg: + instruction = msg['instruction'] + elif self.inference and self.config.decode_instruction != '': + instruction = self.config.decode_instruction + else: + instruction = 'Transcribe the speech' + chat = [{"role": "user", "content": instruction}] + # `content`: the answer according to the audio and instruction + # `txt`: the transcription of the audio + # If there is no content, the default `content` is the same as `txt`. + content = msg['content'] if 'content' in msg else msg['txt'] if self.inference: kwargs = {'add_generation_prompt': True} else: - chat.append({"role": "assistant", "content": msg['txt']}) + chat.append({"role": "assistant", "content": content}) kwargs = { 'padding': 'max_length', 'max_length': self.config.model_max_length - diff --git a/speech_llm.py b/speech_llm.py index b3e3e05..88d6116 100644 --- a/speech_llm.py +++ b/speech_llm.py @@ -35,6 +35,8 @@ class ModelArguments: frames_per_second: int = 100 # CTC related, if ctc_weight > 0, CTC loss is applied in training. ctc_weight: Optional[float] = field(default=0.0) + # For decoding + decode_instruction: Optional[str] = field(default="") @property def ds_rate(self):