ctc tie #8

Merged
merged 2 commits on Jan 13, 2025
18 changes: 17 additions & 1 deletion dataset.py
@@ -1,5 +1,6 @@
# Copyright (c) 2025 Binbin Zhang([email protected])

import math
import json
from dataclasses import dataclass, field
from typing import Dict
@@ -54,6 +55,8 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if sample_rate != 16000:
            audio = torchaudio.transforms.Resample(sample_rate, 16000)(audio)
        audio = audio[0]  # get the first channel
        # 10 frames per second after downsample
        mel_len = math.ceil(float(audio.size(0)) / 16000 * 10)
        audio = whisper.pad_or_trim(audio)
        mel = whisper.log_mel_spectrogram(audio)
        ids_audio = [0] * int(mel.shape[1] / 10)  # 10x downsample
@@ -78,9 +81,22 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        target_ids = torch.tensor(tgt, dtype=torch.int)
        target_ids[target_ids == self.tokenizer.pad_token_id] = IGNORE_TOKEN_ID
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        return {

        ctc_tokens = self.tokenizer(msg['txt'],
                                    padding='max_length',
                                    max_length=100,
                                    truncation=True,
                                    return_tensors='pt')
        ctc_ids = ctc_tokens['input_ids'][0]
        ctc_ids_len = ctc_tokens['attention_mask'].sum().item()
        ret = {
            'input_ids': input_ids,
            'labels': target_ids,
            'attention_mask': attention_mask,
            'mel': mel,
            'mel_len': mel_len,
        }
        if not self.inference:
            ret['ctc_ids'] = ctc_ids
            ret['ctc_ids_len'] = ctc_ids_len
        return ret
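
The CTC targets are simply the tokenizer ids of the raw transcript, padded to a fixed length of 100, with the true length read off the attention mask. A minimal standalone sketch of that preparation, assuming a Hugging Face tokenizer with a pad token (the checkpoint name is only illustrative, not part of this PR):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2-1.5B-Instruct')  # example checkpoint
ctc_tokens = tokenizer('hello world',
                       padding='max_length',
                       max_length=100,
                       truncation=True,
                       return_tensors='pt')
ctc_ids = ctc_tokens['input_ids'][0]                      # (100,) padded token ids
ctc_ids_len = ctc_tokens['attention_mask'].sum().item()   # number of real tokens
print(ctc_ids.shape, ctc_ids_len)                         # torch.Size([100]) and the unpadded length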
11 changes: 8 additions & 3 deletions recognize.py
@@ -16,6 +16,7 @@
@dataclass
class DecodeArguments:
    llm_type: str = 'qwen2'
    decode_type: str = 'llm'
    max_new_tokens: int = 50
    num_beams: int = 1
    batch_size: int = 1
@@ -46,11 +47,15 @@ def main():
    model, data_loader = accelerator.prepare(model, data_loader)
    model.eval()
    fid = open(decode_args.result_path, 'w', encoding='utf8')
    if decode_args.decode_type == 'llm':
        decode_func = model.generate
    else:
        decode_func = model.decode_ctc
    with torch.no_grad():
        for item in tqdm(data_loader):
            generated_ids = model.generate(**item,
                                           eos_token_id=eos_token_id,
                                           decode_config=decode_args)
            generated_ids = decode_func(**item,
                                        eos_token_id=eos_token_id,
                                        decode_config=decode_args)
            text = tokenizer.batch_decode(generated_ids,
                                          skip_special_tokens=True)
            print(text)
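
When decode_type is not 'llm', decode_ctc returns plain Python lists of token ids per utterance instead of the tensor produced by generate; tokenizer.batch_decode accepts both forms, so the rest of the loop is unchanged. A tiny sketch of that property (the checkpoint name and the token ids below are made up for illustration):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2-1.5B-Instruct')  # example checkpoint
fake_ctc_hyps = [[14990, 1879], [9707]]   # two utterances, arbitrary token ids
text = tokenizer.batch_decode(fake_ctc_hyps, skip_special_tokens=True)
print(text)                               # two decoded strings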
65 changes: 60 additions & 5 deletions speech_llm.py
@@ -21,6 +21,18 @@ class ModelArguments:
    projector_model_path: Optional[str] = field(default=None)


def ctc_reduce(hyp, blank_id: int = 0):
    new_hyp = []
    cur = 0
    while cur < len(hyp):
        if hyp[cur] != blank_id:
            new_hyp.append(hyp[cur])
        prev = cur
        while cur < len(hyp) and hyp[cur] == hyp[prev]:
            cur += 1
    return new_hyp
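
ctc_reduce collapses runs of identical frame labels and drops blanks, so repeats separated by a blank survive while adjacent repeats merge. A quick illustrative check with blank_id = 0:

# [0, 0, 7, 7, 0, 7, 5, 5, 0]  ->  collapse runs and drop blanks  ->  [7, 7, 5]
assert ctc_reduce([0, 0, 7, 7, 0, 7, 5, 5, 0], blank_id=0) == [7, 7, 5]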


class ProjectorCov1d(nn.Module):

    def __init__(self, config, encoder_dim, llm_dim):
@@ -73,17 +85,22 @@ def __init__(
            self._keys_to_ignore_on_save.add('llm.' + k)
        for k in self.encoder.state_dict().keys():
            self._keys_to_ignore_on_save.add('encoder.' + k)
        # Use bos_token_id as CTC blank id
        self.ctc_loss = nn.CTCLoss(config.bos_token_id,
                                   reduction='mean',
                                   zero_infinity=True)
        self.blank_id = config.bos_token_id

    def get_input_embedding(self, input_ids, mel):
        # whisper encoder: 30 s of audio, 2x downsample -> 1500 frames
        # whisper encoder + projector: 10x downsample overall, so 300 outputs for 30 s
        speech_size = 300
        speech_emb = self.encoder.embed_audio(mel)  # (b, n_mel, 1500)
        # projector: a further 5x downsample -> 300
        speech_proj = self.projector(speech_emb)  # (b, x, 300)
        text_emb = self.llm.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat((speech_proj, text_emb[:, speech_size:, :]),
                                  dim=1)
        return inputs_embeds
        return inputs_embeds, speech_proj

    @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    def forward(
@@ -92,13 +109,26 @@ def forward(
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        mel: torch.LongTensor = None,
        mel_len: torch.LongTensor = None,
        ctc_ids: torch.LongTensor = None,
        ctc_ids_len: torch.LongTensor = None,
    ):
        inputs_embeds = self.get_input_embedding(input_ids, mel)
        return self.llm(
        inputs_embeds, speech_proj = self.get_input_embedding(input_ids, mel)
        # Tie the CTC linear transform to the input embedding weight
        ctc_linear = self.llm.get_input_embeddings().weight
        ctc_act = torch.matmul(speech_proj, ctc_linear.T)
        ctc_act = ctc_act.transpose(0, 1)
        ctc_act = ctc_act.log_softmax(2)
        with torch.cuda.amp.autocast(enabled=False):
            closs = self.ctc_loss(ctc_act.float(), ctc_ids, mel_len,
                                  ctc_ids_len)
        out = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
        )
        out.loss = 0.9 * out.loss + 0.1 * closs
        return out
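
The CTC head here is not a separate nn.Linear: the logits are speech_proj multiplied by the transpose of the LLM input embedding matrix, giving (batch, frames, vocab), then transposed to (frames, batch, vocab) as nn.CTCLoss expects, and the resulting loss is mixed with the LM loss at a fixed 0.9/0.1 weighting. A self-contained shape sketch of the tied head (all sizes and the blank id below are illustrative, not the model's real dimensions):

import torch
import torch.nn as nn

vocab, llm_dim, B, T, U = 1000, 64, 2, 300, 15            # toy sizes
embedding = nn.Embedding(vocab, llm_dim)                   # stands in for llm.get_input_embeddings()
speech_proj = torch.randn(B, T, llm_dim)                   # stands in for the projector output
ctc_act = torch.matmul(speech_proj, embedding.weight.T)    # (B, T, vocab): tied CTC logits
ctc_act = ctc_act.transpose(0, 1).log_softmax(2)           # (T, B, vocab) for nn.CTCLoss
ctc_loss = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
targets = torch.randint(1, vocab, (B, U))                  # toy label sequences, no blanks
closs = ctc_loss(ctc_act, targets,
                 torch.full((B,), T, dtype=torch.long),    # input lengths
                 torch.full((B,), U, dtype=torch.long))    # target lengths
print(closs.item())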

    @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    def generate(
@@ -107,10 +137,11 @@
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        mel: torch.LongTensor = None,
        mel_len: torch.LongTensor = None,
        eos_token_id=None,
        decode_config=None,
    ):
        inputs_embeds = self.get_input_embedding(input_ids, mel)
        inputs_embeds, _ = self.get_input_embedding(input_ids, mel)
        model_outputs = self.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
@@ -122,6 +153,30 @@
        )
        return model_outputs

    @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    def decode_ctc(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        mel: torch.LongTensor = None,
        mel_len: torch.LongTensor = None,
        eos_token_id=None,
        decode_config=None,
    ):
        _, speech_proj = self.get_input_embedding(input_ids, mel)
        # Tie the CTC linear transform to the input embedding weight
        ctc_linear = self.llm.get_input_embeddings().weight
        ctc_act = torch.matmul(speech_proj, ctc_linear.T)
        ctc_probs = ctc_act.log_softmax(2)
        batch_size = ctc_probs.size(0)
        results = []
        for i in range(batch_size):
            top1 = ctc_probs[i][:mel_len[i], :].argmax(dim=1)
            hyp = ctc_reduce(top1.tolist(), self.blank_id)
            results.append(hyp)
        return results
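
Greedy CTC decoding here is just a per-frame argmax over the tied logits, restricted to the first mel_len frames so the whisper padding up to 30 s is ignored, followed by ctc_reduce; nothing autoregressive is involved. A toy sketch of the per-utterance step (the values are random, only the shapes matter):

import torch

blank_id = 0
ctc_probs = torch.randn(1, 6, 10).log_softmax(2)   # (batch=1, frames=6, vocab=10)
mel_len = torch.tensor([4])                        # only the first 4 frames carry speech
top1 = ctc_probs[0][:mel_len[0], :].argmax(dim=1)  # best token id per frame
hyp = ctc_reduce(top1.tolist(), blank_id)          # collapse repeats, drop blanks
print(hyp)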

    def enable_input_require_grads(self):
        self.llm.enable_input_require_grads()
