Revert "add speech understanding" #15

Merged
merged 1 commit on Jan 24, 2025
56 changes: 0 additions & 56 deletions compute_similarity.py

This file was deleted.

1 change: 0 additions & 1 deletion dataset.py
@@ -79,7 +79,6 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
         ids_audio = [0] * self.config.max_speech_token_size
         tgt_audio = [IGNORE_TOKEN_ID] * len(ids_audio)
         chat = [{"role": "user", "content": "Transcribe the speech"}]
-        # chat = [{"role": "user", "content": "Translate the audio to English"}]
         if self.inference:
             kwargs = {'add_generation_prompt': True}
         else:
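For context on how this chat list is consumed: it is normally rendered through the tokenizer's chat template, with add_generation_prompt=True at inference so generation starts at the assistant turn. A minimal sketch — the model name is hypothetical, and forwarding kwargs to apply_chat_template is an assumption based on the code above:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")  # hypothetical model choice
    chat = [{"role": "user", "content": "Transcribe the speech"}]
    # At inference time, end the prompt with the assistant header so the model completes it.
    ids = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True)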
108 changes: 17 additions & 91 deletions speech_llm.py
@@ -8,14 +8,11 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torchaudio
 import transformers
 from transformers import AutoModelForCausalLM, PreTrainedModel
 import wenet
 import whisper
 
-from utils import ctc_reduce, ctc_peak_time, lists_to_tensor
-
 
 @dataclass
 class ModelArguments:
@@ -38,8 +35,6 @@ class ModelArguments:
     frames_per_second: int = 100
     # CTC related, if ctc_weight > 0, CTC loss is applied in training.
     ctc_weight: Optional[float] = field(default=0.0)
-    ctc_reduce: Optional[bool] = field(default=False)
-    reduced_speech_token_per_second: int = 6

     @property
     def ds_rate(self):
@@ -59,6 +54,18 @@ def max_mel_size(self):
         return self.max_speech_seconds * self.frames_per_second
 
 
+def ctc_reduce(hyp, blank_id: int = 0):
+    new_hyp = []
+    cur = 0
+    while cur < len(hyp):
+        if hyp[cur] != blank_id:
+            new_hyp.append(hyp[cur])
+        prev = cur
+        while cur < len(hyp) and hyp[cur] == hyp[prev]:
+            cur += 1
+    return new_hyp
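The restored ctc_reduce is the standard CTC collapse: merge each run of identical labels into one token, then drop blanks. A quick sanity check with made-up token IDs (blank_id = 0):

    assert ctc_reduce([0, 0, 7, 7, 0, 7, 5, 5, 0]) == [7, 7, 5]
    # The blank between the two runs of 7 keeps them as separate tokens.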


 class ProjectorCov1d(nn.Module):
 
     def __init__(self, config, encoder_dim, llm_dim):
@@ -136,29 +143,6 @@ def get_speech_embeddings(self, mel, mel_len):
         speech_proj = F.pad(speech_proj, (0, 0, 0, pad_size), value=0.0)
         return speech_proj

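The body of ProjectorCov1d is collapsed in this diff. For orientation only, a plausible shape for such a projector, sketched from the class name, the ds_rate property, and the encoder_dim/llm_dim arguments — every detail below is an assumption, not the repository's code:

    import torch.nn as nn

    class ProjectorConv1dSketch(nn.Module):  # hypothetical stand-in
        def __init__(self, ds_rate, encoder_dim, llm_dim):
            super().__init__()
            # A strided Conv1d downsamples encoder frames by ds_rate, then a
            # linear layer maps them into the LLM embedding space.
            self.conv = nn.Conv1d(encoder_dim, encoder_dim,
                                  kernel_size=ds_rate, stride=ds_rate)
            self.proj = nn.Linear(encoder_dim, llm_dim)

        def forward(self, x):  # x: (batch, time, encoder_dim)
            x = self.conv(x.transpose(1, 2)).transpose(1, 2)
            return self.proj(x)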
-    def select_speech_embeddings(
-        self,
-        ctc_prob,
-        ctc_ids,
-        prob_len,
-        ctc_ids_len,
-        speech_emb,
-    ):
-        """ Select speech embeddings by forced alignment
-        """
-        out_emb = torch.zeros_like(speech_emb)
-        batch_size = ctc_prob.size(0)
-        for i in range(batch_size):
-            # The current forced_align only supports batch_size == 1.
-            alignment, _ = torchaudio.functional.forced_align(
-                ctc_prob[i, :prob_len[i], :].contiguous().unsqueeze(0),
-                ctc_ids[i, :ctc_ids_len[i]].contiguous().unsqueeze(0),
-                blank=self.blank_id)
-            peak_t = ctc_peak_time(alignment[0].tolist(), self.blank_id)
-            assert ctc_ids_len[i].item() == len(peak_t)
-            out_emb[i, :ctc_ids_len[i], :] = speech_emb[i, peak_t, :]
-        return out_emb

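ctc_peak_time came from the utils.py deleted by this PR, so its exact body is gone; the call site above implies it returns one representative frame index per aligned token (the assert checks len(peak_t) == ctc_ids_len[i]). A plausible reconstruction under that assumption:

    def ctc_peak_time(alignment, blank_id=0):
        # alignment: frame-level token IDs from forced_align, blanks included.
        # Emit the first frame of each non-blank run; forced alignment puts a
        # blank between genuinely repeated labels, so this yields exactly one
        # index per target token.
        times, prev = [], blank_id
        for t, tok in enumerate(alignment):
            if tok != blank_id and tok != prev:
                times.append(t)
            prev = tok
        return times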
     @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
     def forward(
         self,
@@ -173,8 +157,12 @@ def forward(
         max_speech_size = self.model_args.max_speech_token_size
         text_emb = self.llm.get_input_embeddings()(input_ids)
         speech_emb = self.get_speech_embeddings(mel, mel_len)
+        inputs_embeds = torch.cat(
+            (speech_emb, text_emb[:, max_speech_size:, :]), dim=1)
+        out = self.llm(inputs_embeds=inputs_embeds,
+                       attention_mask=attention_mask,
+                       labels=labels)
         ctc_weight = self.model_args.ctc_weight
         # Optional, add CTC loss
         if ctc_weight > 0:
             # Tie the CTC linear transform and the input embedding weights
             ctc_linear = self.llm.get_input_embeddings().weight
@@ -185,17 +173,6 @@ def forward(
             with torch.cuda.amp.autocast(enabled=False):
                 closs = self.ctc_loss(ctc_prob.float(), ctc_ids, prob_len,
                                       ctc_ids_len)
-            # Optional, reduce sequence, rewrite speech embeddings
-            if self.model_args.ctc_reduce:
-                speech_emb = self.select_speech_embeddings(
-                    ctc_prob.transpose(0, 1), ctc_ids, prob_len, ctc_ids_len,
-                    speech_emb)
-        inputs_embeds = torch.cat(
-            (speech_emb, text_emb[:, max_speech_size:, :]), dim=1)
-        out = self.llm(inputs_embeds=inputs_embeds,
-                       attention_mask=attention_mask,
-                       labels=labels)
-        if ctc_weight > 0:
             out.loss = (1 - ctc_weight) * out.loss + ctc_weight * closs
         return out
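With the restored control flow, the CTC branch simply interpolates the two objectives, out.loss = (1 - ctc_weight) * CE + ctc_weight * CTC; for a hypothetical ctc_weight of 0.3, 70% of the gradient signal would still come from the LM cross-entropy.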

@@ -213,25 +190,6 @@ def generate(
         max_speech_size = self.model_args.max_speech_token_size
         text_emb = self.llm.get_input_embeddings()(input_ids)
         speech_emb = self.get_speech_embeddings(mel, mel_len)
-        # Optional, apply ctc_reduce
-        if self.model_args.ctc_reduce:
-            ctc_linear = self.llm.get_input_embeddings().weight
-            ctc_act = torch.matmul(speech_emb, ctc_linear.T)
-            ctc_probs = ctc_act.log_softmax(2)
-            prob_len = torch.ceil(mel_len / self.model_args.ds_rate).long()
-            batch_size = ctc_probs.size(0)
-            results = []
-            for i in range(batch_size):
-                top1 = ctc_probs[i][:prob_len[i], :].argmax(dim=1)
-                hyp = ctc_reduce(top1.tolist(), self.blank_id)
-                results.append(hyp)
-            results_len = torch.tensor([len(l) for l in results],
-                                       dtype=torch.long,
-                                       device=mel.device)
-            results_ids = lists_to_tensor(results, device=mel.device)
-            speech_emb = self.select_speech_embeddings(ctc_probs, results_ids,
-                                                       prob_len, results_len,
-                                                       speech_emb)
         inputs_embeds = torch.cat(
             (speech_emb, text_emb[:, max_speech_size:, :]), dim=1)
         model_outputs = self.llm.generate(
@@ -270,38 +228,6 @@ def decode_ctc(
             results.append(hyp)
         return results

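decode_ctc itself survives the revert (its body is mostly collapsed above). A possible call pattern, with argument names assumed from the surrounding code rather than taken from the actual signature:

    # Hypothetical usage: greedy CTC decoding to tied-embedding token IDs,
    # then detokenization with the LLM tokenizer.
    hyps = model.decode_ctc(mel=mel, mel_len=mel_len)
    texts = [tokenizer.decode(h) for h in hyps]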
-    @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
-    def compute_similarity(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        mel: torch.LongTensor = None,
-        mel_len: torch.LongTensor = None,
-        ctc_ids: torch.LongTensor = None,
-        ctc_ids_len: torch.LongTensor = None,
-    ):
-        """ Compute text and speech embedding similarity
-        """
-        max_speech_size = self.model_args.max_speech_token_size
-        text_emb = self.llm.get_input_embeddings()(ctc_ids)
-        speech_emb = self.get_speech_embeddings(mel, mel_len)
-        ctc_linear = self.llm.get_input_embeddings().weight
-        ctc_act = torch.matmul(speech_emb, ctc_linear.T)
-        ctc_probs = ctc_act.log_softmax(2)
-        prob_len = torch.ceil(mel_len / self.model_args.ds_rate).long()
-        speech_emb = self.select_speech_embeddings(ctc_probs, ctc_ids, prob_len,
-                                                   ctc_ids_len, speech_emb)
-        batch_size = ctc_ids.size(0)
-        results = []
-        for i in range(batch_size):
-            end = ctc_ids_len[i]
-            s = F.cosine_similarity(text_emb[i, :end, :],
-                                    speech_emb[i, :end, :],
-                                    dim=1)
-            results.append(s)
-        return results

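The removed compute_similarity returned one cosine-similarity vector per utterance, one value per force-aligned token; its driver script compute_similarity.py is deleted by this same PR. A sketch of how it was presumably invoked (tensor names are assumptions):

    sims = model.compute_similarity(mel=mel, mel_len=mel_len,
                                    ctc_ids=ctc_ids, ctc_ids_len=ctc_ids_len)
    # sims[i][j] close to 1.0 means the aligned speech embedding for token j
    # of utterance i lies near its text embedding under the tied projection.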
     def enable_input_require_grads(self):
         self.llm.enable_input_require_grads()

33 changes: 0 additions & 33 deletions utils.py

This file was deleted.
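Of the three helpers that lived in the deleted utils.py, ctc_reduce is re-inlined in speech_llm.py above, while ctc_peak_time and lists_to_tensor disappear with their call sites. For reference, a reconstruction of lists_to_tensor consistent with its call site lists_to_tensor(results, device=mel.device) — the padding value and exact signature are assumptions:

    import torch

    def lists_to_tensor(lists, device="cpu", pad_value=0):
        # Right-pad variable-length ID lists into one (batch, max_len) LongTensor.
        max_len = max(len(l) for l in lists)
        out = torch.full((len(lists), max_len), pad_value,
                         dtype=torch.long, device=device)
        for i, l in enumerate(lists):
            out[i, :len(l)] = torch.tensor(l, dtype=torch.long, device=device)
        return out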