From 41b58c62cb2ed193d1bcd34ac3391e4bccbfcbd1 Mon Sep 17 00:00:00 2001
From: Binbin Zhang <binbzha@qq.com>
Date: Mon, 6 Jan 2025 06:48:54 +0000
Subject: [PATCH 1/2] add ctc

---
 dataset.py    | 18 +++++++++++++-
 recognize.py  |  3 +++
 speech_llm.py | 68 +++++++++++++++++++++++++++++++++++++++++++++++----
 3 files changed, 83 insertions(+), 6 deletions(-)

diff --git a/dataset.py b/dataset.py
index ebf4785..ccff8c1 100644
--- a/dataset.py
+++ b/dataset.py
@@ -1,5 +1,6 @@
 # Copyright (c) 2025 Binbin Zhang(binbzha@qq.com)
 
+import math
 import json
 from dataclasses import dataclass, field
 from typing import Dict
@@ -54,6 +55,8 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
         if sample_rate != 16000:
             audio = torchaudio.transforms.Resample(sample_rate, 16000)(audio)
         audio = audio[0]  # get the first channel
+        # 10 frames per second after downsampling
+        mel_len = math.ceil(float(audio.size(0)) / 16000 * 10)
         audio = whisper.pad_or_trim(audio)
         mel = whisper.log_mel_spectrogram(audio)
         ids_audio = [0] * int(mel.shape[1] / 10)  # 10x downsample
@@ -78,9 +81,22 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
         target_ids = torch.tensor(tgt, dtype=torch.int)
         target_ids[target_ids == self.tokenizer.pad_token_id] = IGNORE_TOKEN_ID
         attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
-        return {
+
+        ctc_tokens = self.tokenizer(msg['txt'],
+                                    padding='max_length',
+                                    max_length=100,
+                                    truncation=True,
+                                    return_tensors='pt')
+        ctc_ids = ctc_tokens['input_ids'][0]
+        ctc_ids_len = ctc_tokens['attention_mask'].sum().item()
+        ret = {
             'input_ids': input_ids,
             'labels': target_ids,
             'attention_mask': attention_mask,
             'mel': mel,
+            'mel_len': mel_len,
         }
+        if not self.inference:
+            ret['ctc_ids'] = ctc_ids
+            ret['ctc_ids_len'] = ctc_ids_len
+        return ret
diff --git a/recognize.py b/recognize.py
index 02a3aaf..5f251ef 100644
--- a/recognize.py
+++ b/recognize.py
@@ -51,6 +51,9 @@ def main():
             generated_ids = model.generate(**item,
                                            eos_token_id=eos_token_id,
                                            decode_config=decode_args)
+            # generated_ids = model.decode_ctc(**item,
+            #                                  eos_token_id=eos_token_id,
+            #                                  decode_config=decode_args)
             text = tokenizer.batch_decode(generated_ids,
                                           skip_special_tokens=True)
             print(text)
diff --git a/speech_llm.py b/speech_llm.py
index 3aa2f7f..944f04d 100644
--- a/speech_llm.py
+++ b/speech_llm.py
@@ -21,6 +21,18 @@ class ModelArguments:
     projector_model_path: Optional[str] = field(default=None)
 
 
+def ctc_reduce(hyp, blank_id: int = 0):
+    new_hyp = []
+    cur = 0
+    while cur < len(hyp):
+        if hyp[cur] != blank_id:
+            new_hyp.append(hyp[cur])
+        prev = cur
+        while cur < len(hyp) and hyp[cur] == hyp[prev]:
+            cur += 1
+    return new_hyp
+
+
 class ProjectorCov1d(nn.Module):
 
     def __init__(self, config, encoder_dim, llm_dim):
@@ -73,9 +85,15 @@ def __init__(
             self._keys_to_ignore_on_save.add('llm.' + k)
         for k in self.encoder.state_dict().keys():
             self._keys_to_ignore_on_save.add('encoder.' + k)
+        self.ctc_linear = nn.Linear(config.hidden_size, config.vocab_size)
+        # Use bos_token_id as the CTC blank id
+        self.ctc_loss = nn.CTCLoss(config.bos_token_id,
+                                   reduction='mean',
+                                   zero_infinity=True)
+        self.blank_id = config.bos_token_id
 
     def get_input_embedding(self, input_ids, mel):
-        # whisper, 30s, 2x downsample = 1500
+        # whisper + projector, 10x downsample: 300 outputs for 30s of audio
         speech_size = 300
         speech_emb = self.encoder.embed_audio(mel)  # (b, n_mel, 1500)
         # projector, x 5x downsample = 300
@@ -83,7 +101,7 @@ def get_input_embedding(self, input_ids, mel):
         text_emb = self.llm.get_input_embeddings()(input_ids)
         inputs_embeds = torch.cat((speech_proj, text_emb[:, speech_size:, :]),
                                   dim=1)
-        return inputs_embeds
+        return inputs_embeds, speech_proj
 
     @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
     def forward(
@@ -92,13 +110,24 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         labels: Optional[torch.LongTensor] = None,
         mel: torch.LongTensor = None,
+        mel_len: torch.LongTensor = None,
+        ctc_ids: torch.LongTensor = None,
+        ctc_ids_len: torch.LongTensor = None,
     ):
-        inputs_embeds = self.get_input_embedding(input_ids, mel)
-        return self.llm(
+        inputs_embeds, speech_proj = self.get_input_embedding(input_ids, mel)
+        ctc_act = self.ctc_linear(speech_proj)
+        ctc_act = ctc_act.transpose(0, 1)
+        ctc_act = ctc_act.log_softmax(2)
+        with torch.cuda.amp.autocast(enabled=False):
+            closs = self.ctc_loss(ctc_act.float(), ctc_ids, mel_len,
+                                  ctc_ids_len)
+        out = self.llm(
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             labels=labels,
         )
+        out.loss = 0.9 * out.loss + 0.1 * closs
+        return out
 
     @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
     def generate(
@@ -107,10 +136,11 @@ def generate(
         attention_mask: Optional[torch.Tensor] = None,
         labels: Optional[torch.LongTensor] = None,
         mel: torch.LongTensor = None,
+        mel_len: torch.LongTensor = None,
         eos_token_id=None,
         decode_config=None,
     ):
-        inputs_embeds = self.get_input_embedding(input_ids, mel)
+        inputs_embeds, _ = self.get_input_embedding(input_ids, mel)
         model_outputs = self.llm.generate(
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -122,12 +152,38 @@ def generate(
         )
         return model_outputs
 
+    @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+    def decode_ctc(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        mel: torch.LongTensor = None,
+        mel_len: torch.LongTensor = None,
+        eos_token_id=None,
+        decode_config=None,
+    ):
+        _, speech_proj = self.get_input_embedding(input_ids, mel)
+        ctc_act = self.ctc_linear(speech_proj)  # (B, T, D)
+        ctc_probs = ctc_act.log_softmax(2)
+        batch_size = ctc_probs.size(0)
+        results = []
+        for i in range(batch_size):
+            top1 = ctc_probs[i][:mel_len[i], :].argmax(dim=1)
+            hyp = ctc_reduce(top1.tolist(), self.blank_id)
+            results.append(hyp)
+        return results
+
     def enable_input_require_grads(self):
         self.llm.enable_input_require_grads()
 
     def freeze_encoder(self):
         freeze_model(self.encoder)
 
+    def copy_llm_embedding_weight(self):
+        embedding_weights = self.llm.get_input_embeddings().weight
+        self.ctc_linear.weight.copy_(embedding_weights)
+
     def freeze_llm(self):
         freeze_model(self.llm)
 
@@ -153,6 +209,8 @@ def init_model(model_args):
     total_params = sum(p.numel() for p in projector.parameters())
     print('Projector total params: {:.2f}M'.format(total_params / 1024 / 1024))
     model = SpeechLLM(config, llm_model, encoder, projector)
+    total_params = sum(p.numel() for p in model.ctc_linear.parameters())
+    print('CTC total params: {:.2f}M'.format(total_params / 1024 / 1024))
     if model_args.projector_model_path is not None:
         model.load_projector(model_args.projector_model_path)
     return model

From c4e444a08ac91dd666f79cf23f7d07639f291072 Mon Sep 17 00:00:00 2001
From: Binbin Zhang <binbzha@qq.com>
Date: Tue, 7 Jan 2025 02:21:11 +0000
Subject: [PATCH 2/2] use tie embedding

---
 recognize.py  | 14 ++++++++------
 speech_llm.py | 15 ++++++---------
 2 files changed, 14 insertions(+), 15 deletions(-)

diff --git a/recognize.py b/recognize.py
index 5f251ef..71645f8 100644
--- a/recognize.py
+++ b/recognize.py
@@ -16,6 +16,7 @@
 @dataclass
 class DecodeArguments:
     llm_type: str = 'qwen2'
+    decode_type: str = 'llm'
     max_new_tokens: int = 50
     num_beams: int = 1
     batch_size: int = 1
@@ -46,14 +47,15 @@ def main():
     model, data_loader = accelerator.prepare(model, data_loader)
     model.eval()
     fid = open(decode_args.result_path, 'w', encoding='utf8')
+    if decode_args.decode_type == 'llm':
+        decode_func = model.generate
+    else:
+        decode_func = model.decode_ctc
     with torch.no_grad():
         for item in tqdm(data_loader):
-            generated_ids = model.generate(**item,
-                                           eos_token_id=eos_token_id,
-                                           decode_config=decode_args)
-            # generated_ids = model.decode_ctc(**item,
-            #                                  eos_token_id=eos_token_id,
-            #                                  decode_config=decode_args)
+            generated_ids = decode_func(**item,
+                                        eos_token_id=eos_token_id,
+                                        decode_config=decode_args)
             text = tokenizer.batch_decode(generated_ids,
                                           skip_special_tokens=True)
             print(text)
diff --git a/speech_llm.py b/speech_llm.py
index 944f04d..1b5184b 100644
--- a/speech_llm.py
+++ b/speech_llm.py
@@ -85,7 +85,6 @@ def __init__(
             self._keys_to_ignore_on_save.add('llm.' + k)
         for k in self.encoder.state_dict().keys():
             self._keys_to_ignore_on_save.add('encoder.' + k)
-        self.ctc_linear = nn.Linear(config.hidden_size, config.vocab_size)
         # Use bos_token_id as the CTC blank id
         self.ctc_loss = nn.CTCLoss(config.bos_token_id,
                                    reduction='mean',
@@ -115,7 +114,9 @@ def forward(
         ctc_ids_len: torch.LongTensor = None,
     ):
         inputs_embeds, speech_proj = self.get_input_embedding(input_ids, mel)
-        ctc_act = self.ctc_linear(speech_proj)
+        # Tie the CTC linear transform to the input embedding weight
+        ctc_linear = self.llm.get_input_embeddings().weight
+        ctc_act = torch.matmul(speech_proj, ctc_linear.T)
         ctc_act = ctc_act.transpose(0, 1)
         ctc_act = ctc_act.log_softmax(2)
         with torch.cuda.amp.autocast(enabled=False):
@@ -164,7 +165,9 @@ def decode_ctc(
         decode_config=None,
     ):
         _, speech_proj = self.get_input_embedding(input_ids, mel)
-        ctc_act = self.ctc_linear(speech_proj)  # (B, T, D)
+        # Tie the CTC linear transform to the input embedding weight
+        ctc_linear = self.llm.get_input_embeddings().weight
+        ctc_act = torch.matmul(speech_proj, ctc_linear.T)
         ctc_probs = ctc_act.log_softmax(2)
         batch_size = ctc_probs.size(0)
         results = []
@@ -180,10 +183,6 @@ def enable_input_require_grads(self):
         self.llm.enable_input_require_grads()
 
     def freeze_encoder(self):
         freeze_model(self.encoder)
 
-    def copy_llm_embedding_weight(self):
-        embedding_weights = self.llm.get_input_embeddings().weight
-        self.ctc_linear.weight.copy_(embedding_weights)
-
     def freeze_llm(self):
         freeze_model(self.llm)
 
@@ -209,8 +208,6 @@ def init_model(model_args):
     total_params = sum(p.numel() for p in projector.parameters())
     print('Projector total params: {:.2f}M'.format(total_params / 1024 / 1024))
     model = SpeechLLM(config, llm_model, encoder, projector)
-    total_params = sum(p.numel() for p in model.ctc_linear.parameters())
-    print('CTC total params: {:.2f}M'.format(total_params / 1024 / 1024))
    if model_args.projector_model_path is not None:
         model.load_projector(model_args.projector_model_path)
     return model
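
Note (outside the patches above): a minimal, self-contained sketch of what the tied-embedding CTC head from PATCH 2/2 computes at decode time: project the speech features onto the LLM input embedding matrix, take the per-frame argmax, and collapse it with ctc_reduce. The toy sizes and the blank id of 0 below are illustrative assumptions, not values from the repository; the patches use config.bos_token_id as the blank and llm.get_input_embeddings().weight as the tied matrix.

    import torch

    def ctc_reduce(hyp, blank_id: int = 0):
        # Greedy CTC collapse: keep one token per run of repeats, drop blanks
        # (same helper as in speech_llm.py).
        new_hyp, cur = [], 0
        while cur < len(hyp):
            if hyp[cur] != blank_id:
                new_hyp.append(hyp[cur])
            prev = cur
            while cur < len(hyp) and hyp[cur] == hyp[prev]:
                cur += 1
        return new_hyp

    # Toy sizes (assumptions): 10-token vocab, hidden size 8, 6 speech frames, blank id 0.
    vocab_size, hidden, frames, blank_id = 10, 8, 6, 0
    embedding = torch.nn.Embedding(vocab_size, hidden)  # stands in for llm.get_input_embeddings()
    speech_proj = torch.randn(1, frames, hidden)        # stands in for the projector output

    # Tied CTC head: logits = speech_proj @ E^T, so no separate nn.Linear to train or save.
    logits = torch.matmul(speech_proj, embedding.weight.T)  # (1, frames, vocab_size)
    top1 = logits.log_softmax(-1).argmax(-1)[0].tolist()    # best token per frame
    print(ctc_reduce(top1, blank_id))                       # collapsed hypothesis (token ids)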