Skip to content

Commit

Permalink
add training and decode instruction (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
robin1001 authored Jan 24, 2025
1 parent cc791ca commit b186105
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
14 changes: 12 additions & 2 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,21 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
mel = mel[:, :self.config.max_mel_size]
ids_audio = [0] * self.config.max_speech_token_size
tgt_audio = [IGNORE_TOKEN_ID] * len(ids_audio)
chat = [{"role": "user", "content": "Transcribe the speech"}]
if 'instruction' in msg:
instruction = msg['instruction']
elif self.inference and self.config.decode_instruction != '':
instruction = self.config.decode_instruction
else:
instruction = 'Transcribe the speech'
chat = [{"role": "user", "content": instruction}]
# `content`: the answer according to the audio and instruction
# `txt`: the transcription of the audio
# If there is no content, the default `content` is the same as `txt`.
content = msg['content'] if 'content' in msg else msg['txt']
if self.inference:
kwargs = {'add_generation_prompt': True}
else:
chat.append({"role": "assistant", "content": msg['txt']})
chat.append({"role": "assistant", "content": content})
kwargs = {
'padding': 'max_length',
'max_length': self.config.model_max_length -
Expand Down
2 changes: 2 additions & 0 deletions speech_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class ModelArguments:
frames_per_second: int = 100
# CTC related, if ctc_weight > 0, CTC loss is applied in training.
ctc_weight: Optional[float] = field(default=0.0)
# Instruction used at decode (inference) time; empty means use the default.
decode_instruction: Optional[str] = field(default="")

@property
def ds_rate(self):
Expand Down

0 comments on commit b186105

Please sign in to comment.