diff --git a/README.md b/README.md index 5e4e4b9..4395aa4 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ pip install s3tokenizer ```py import s3tokenizer -tokenizer = s3tokenizer.load_model("speech_tokenizer_v1").cuda() # or "speech_tokenizer_v1_25hz" +tokenizer = s3tokenizer.load_model("speech_tokenizer_v1").cuda() # or "speech_tokenizer_v1_25hz speech_tokenizer_v2_25hz" mels = [] wav_paths = ["s3tokenizer/assets/BAC009S0764W0121.wav", "s3tokenizer/assets/BAC009S0764W0122.wav"] @@ -48,7 +48,7 @@ s3tokenizer --wav_scp xxx.scp \ --device "cpu" \ --output_dir "./" \ --batch_size 32 \ - --model "speech_tokenizer_v1" # or "speech_tokenizer_v1_25hz" + --model "speech_tokenizer_v1" # or "speech_tokenizer_v1_25hz speech_tokenizer_v2_25hz" ``` @@ -66,7 +66,7 @@ torchrun --nproc_per_node=8 --nnodes=1 \ --device "cuda" \ --output_dir "./" \ --batch_size 32 \ - --model "speech_tokenizer_v1" # or "speech_tokenizer_v1_25hz" + --model "speech_tokenizer_v1" # or "speech_tokenizer_v1_25hz speech_tokenizer_v2_25hz" ``` diff --git a/s3tokenizer/__init__.py b/s3tokenizer/__init__.py index facc6aa..9b4cb87 100644 --- a/s3tokenizer/__init__.py +++ b/s3tokenizer/__init__.py @@ -24,6 +24,8 @@ from tqdm import tqdm +from s3tokenizer.model_v2 import S3TokenizerV2 + from .model import S3Tokenizer from .utils import (load_audio, log_mel_spectrogram, make_non_pad_mask, mask_to_bias, onnx2torch, padding) @@ -39,6 +41,9 @@ "speech_tokenizer_v1_25hz": "https://www.modelscope.cn/models/iic/CosyVoice-300M-25Hz/" "resolve/master/speech_tokenizer_v1.onnx", + "speech_tokenizer_v2_25hz": + "https://www.modelscope.cn/models/iic/CosyVoice2-0.5B/" + "resolve/master/speech_tokenizer_v2.onnx", } _SHA256S = { @@ -46,6 +51,8 @@ "23b5a723ed9143aebfd9ffda14ac4c21231f31c35ef837b6a13bb9e5488abb1e", "speech_tokenizer_v1_25hz": "56285ddd4a83e883ee0cb9f8d69c1089b53a94b1f78ff7e4a0224a27eb4cb486", + "speech_tokenizer_v2_25hz": + "d43342aa12163a80bf07bffb94c9de2e120a8df2f9917cd2f642e7f4219c6f71", } @@ -137,8 +144,10 @@ def load_model( else: raise RuntimeError( f"Model {name} not found; available models = {available_models()}") - - model = S3Tokenizer(name) + if 'v2' in name: + model = S3TokenizerV2(name) + else: + model = S3Tokenizer(name) model.init_from_onnx(checkpoint_file) return model diff --git a/s3tokenizer/cli.py b/s3tokenizer/cli.py index 31c2bb9..da0c81d 100644 --- a/s3tokenizer/cli.py +++ b/s3tokenizer/cli.py @@ -91,12 +91,14 @@ def init_distributed(): def get_args(): parser = argparse.ArgumentParser(description='extract speech code') - parser.add_argument( - '--model', - required=True, - type=str, - choices=["speech_tokenizer_v1", "speech_tokenizer_v1_25hz"], - help='model version') + parser.add_argument('--model', + required=True, + type=str, + choices=[ + "speech_tokenizer_v1", "speech_tokenizer_v1_25hz", + "speech_tokenizer_v2_25hz" + ], + help='model version') parser.add_argument('--wav_scp', required=True, type=str, diff --git a/s3tokenizer/model_v2.py b/s3tokenizer/model_v2.py new file mode 100644 index 0000000..7928849 --- /dev/null +++ b/s3tokenizer/model_v2.py @@ -0,0 +1,381 @@ +# Copyright (c) (Mddct: Dinghao Zhou) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from einops import rearrange + +from s3tokenizer.model import Conv1d, LayerNorm, Linear, MultiHeadAttention +from s3tokenizer.utils import make_non_pad_mask, mask_to_bias, onnx2torch + + +@dataclass +class ModelDimensions: + n_mels: int = 128 + n_audio_ctx: int = 1500 + n_audio_state: int = 1280 + n_audio_head: int = 20 + n_audio_layer: int = 6 + n_codebook_size: int = 3**8 + + +def precompute_freqs_cis(dim: int, + end: int, + theta: float = 10000.0, + scaling=None): + freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + if scaling is not None: + t = t * scaling + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + + return torch.cat((freqs_cis, freqs_cis), dim=-1) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + real = torch.view_as_real(freqs_cis) + cos, sin = real[:, :, 0], real[:, :, 1] + cos = cos.unsqueeze(0).unsqueeze(2) + sin = sin.unsqueeze(0).unsqueeze(2) + + D = xq.shape[-1] + half_l, half_r = xq[:, :, :, :D // 2], xq[:, :, :, D // 2:] + xq_r = torch.cat((-half_r, half_l), dim=-1) + + D = xk.shape[-1] + + half_l, half_r = xk[:, :, :, :D // 2], xk[:, :, :, D // 2:] + xk_r = torch.cat((-half_r, half_l), dim=-1) + + return xq * cos + xq_r * sin, xk * cos + xk_r * sin + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [ + d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape) + ] + return freqs_cis.view(*shape) + + +class FSQCodebook(torch.nn.Module): + + def __init__(self, dim: int, level: int = 3): + super().__init__() + self.project_down = torch.nn.Linear(dim, 8) + self.level = level + self.embed = None + + @torch.inference_mode() + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + x = rearrange(x, "... d -> (...) d") + return x + + @torch.inference_mode() + def encode(self, x: torch.Tensor) -> torch.Tensor: + x_shape = x.shape + # pre-process + x = self.preprocess(x) + # quantize + + h = self.project_down(x) + h = h.tanh() + h = h * 0.9990000128746033 + h = h.round() + 1 + # h = ((self.level - 1) * h).round() # range [-k, k] + powers = torch.pow(self.level, + torch.arange(2**self.level, device=x.device)) + mu = torch.sum(h * powers.unsqueeze(0), dim=-1) + ind = mu.reshape(x_shape[0], x_shape[1]) + return ind + + @torch.inference_mode() + def decode(self, embed_ind: torch.Tensor) -> torch.Tensor: + raise NotImplementedError( + 'There is no official up project component provided') + + +class FSQVectorQuantization(torch.nn.Module): + """Vector quantization implementation (inference-only). 
+ Args: + dim (int): Dimension + codebook_size (int): Codebook size + """ + + def __init__( + self, + dim: int, + codebook_size: int, + ): + super().__init__() + assert 3**8 == codebook_size + self._codebook = FSQCodebook(dim=dim, level=3) + self.codebook_size = codebook_size + + @property + def codebook(self): + return self._codebook.embed + + @torch.inference_mode() + def encode(self, x: torch.Tensor) -> torch.Tensor: + return self._codebook.encode(x) + + @torch.inference_mode() + def decode(self, embed_ind: torch.Tensor) -> torch.Tensor: + quantize = self._codebook.decode(embed_ind) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + +class FSMNMultiHeadAttention(MultiHeadAttention): + + def __init__( + self, + n_state: int, + n_head: int, + kernel_size: int = 31, + ): + super().__init__(n_state, n_head) + + self.fsmn_block = torch.nn.Conv1d(n_state, + n_state, + kernel_size, + stride=1, + padding=0, + groups=n_state, + bias=False) + self.left_padding = (kernel_size - 1) // 2 + self.right_padding = kernel_size - 1 - self.left_padding + self.pad_fn = torch.nn.ConstantPad1d( + (self.left_padding, self.right_padding), 0.0) + + def forward_fsmn(self, + inputs: torch.Tensor, + mask: Optional[torch.Tensor] = None): + b, t, _, _ = inputs.size() + inputs = inputs.view(b, t, -1) + if mask is not None and mask.size(2) > 0: # time2 > 0 + inputs = inputs * mask + x = inputs.transpose(1, 2) + x = self.pad_fn(x) + x = self.fsmn_block(x) + x = x.transpose(1, 2) + x += inputs + return x * mask + + def qkv_attention(self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: Optional[torch.Tensor] = None, + mask_pad: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None): + _, _, D = q.shape + scale = (D // self.n_head)**-0.25 + q = q.view(*q.shape[:2], self.n_head, -1) + k = k.view(*k.shape[:2], self.n_head, -1) + v = v.view(*v.shape[:2], self.n_head, -1) + + if freqs_cis is not None: + q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) + + fsm_memory = self.forward_fsmn(v, mask_pad) + + q = q.permute(0, 2, 1, 3) * scale + k = k.permute(0, 2, 3, 1) * scale + v = v.permute(0, 2, 1, 3) + + qk = q @ k # (B, n_head, T, T) + if mask is not None: + qk = qk + mask + qk = qk.float() + w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype) + return (w @ v).permute(0, 2, 1, + 3).flatten(start_dim=2), qk.detach(), fsm_memory + + def forward(self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + mask_pad: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None): + + q = self.query(x) + k = self.key(x) + v = self.value(x) + + wv, qk, fsm_memory = self.qkv_attention(q, k, v, mask, mask_pad, + freqs_cis) + return self.out(wv) + fsm_memory, qk + + +class ResidualAttentionBlock(torch.nn.Module): + + def __init__( + self, + n_state: int, + n_head: int, + kernel_size: int = 31, + ): + super().__init__() + + self.attn = FSMNMultiHeadAttention(n_state, n_head, kernel_size) + self.attn_ln = LayerNorm(n_state, eps=1e-6) + + n_mlp = n_state * 4 + + self.mlp = torch.nn.Sequential(Linear(n_state, n_mlp), torch.nn.GELU(), + Linear(n_mlp, n_state)) + self.mlp_ln = LayerNorm(n_state) + + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + mask_pad: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None, + ): + x = x + self.attn( + self.attn_ln(x), mask=mask, mask_pad=mask_pad, + freqs_cis=freqs_cis)[0] + + x = x + self.mlp(self.mlp_ln(x)) + return x + + +class AudioEncoderV2(torch.nn.Module): + + def 
__init__( + self, + n_mels: int, + n_state: int, + n_head: int, + n_layer: int, + stride: int, + ): + super().__init__() + self.stride = stride + + self.conv1 = Conv1d(n_mels, + n_state, + kernel_size=3, + stride=stride, + padding=1) + self.conv2 = Conv1d(n_state, + n_state, + kernel_size=3, + stride=2, + padding=1) + self.freqs_cis = precompute_freqs_cis(64, 1024 * 2) + self.blocks = torch.nn.ModuleList( + [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]) + + def forward(self, x: torch.Tensor, + x_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + x : torch.Tensor, shape = (batch_size, n_mels, T) + the mel spectrogram of the audio + x_len: torch.Tensor, shape = (batch_size,) + length of each audio in x + """ + T = x.size(-1) + x = torch.nn.functional.gelu(self.conv1(x)) + x = torch.nn.functional.gelu(self.conv2(x)) + x = x.permute(0, 2, 1) # (B, T // 2, n_state) + freqs_cis = self.freqs_cis.to(x.device) + mask = make_non_pad_mask(x_len, T).unsqueeze(1) # (B, 1, T) + mask = mask[:, :, (T + 1) % 2::2] # (B, 1, T // 2) + mask_pad = None + if self.stride == 2: + _T = mask.size(-1) + mask = mask[:, :, (_T + 1) % 2::2] # (B, 1, T // 4) + mask_pad = mask.transpose(1, 2) + mask = mask_to_bias(mask, x.dtype) + + tmp = torch.view_as_real(freqs_cis) + cos, sin = tmp[:, :, 0], tmp[:, :, 1] + + cos = torch.cat((cos, cos), dim=-1) + sin = torch.cat((sin, sin), dim=-1) + cos = cos.unsqueeze(0).unsqueeze(2) + sin = sin.unsqueeze(0).unsqueeze(2) + + for block in self.blocks: + x = block(x, mask.unsqueeze(1), mask_pad, freqs_cis[:x.size(1)]) + + x_len = (x_len + 1) // 2 + if self.stride == 2: + x_len = (x_len + 1) // 2 + return x, x_len + + +class S3TokenizerV2(torch.nn.Module): + """S3 tokenizer v2 implementation (inference-only). + Args: + dims (ModelDimensions): Dimension + """ + + def __init__(self, name: str, dims: ModelDimensions = ModelDimensions()): + super().__init__() + if 'v1' not in name: + assert 'v2' in name + # TODO(Mddct): make it configureable + dims.n_codebook_size = 3**8 + self.dims = dims + self.encoder = AudioEncoderV2( + self.dims.n_mels, + self.dims.n_audio_state, + self.dims.n_audio_head, + self.dims.n_audio_layer, + 2, + ) + self.quantizer = FSQVectorQuantization( + self.dims.n_audio_state, + self.dims.n_codebook_size, + ) + + def forward(self, mel: torch.Tensor, + mel_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return self.quantize(mel, mel_len) + + @torch.inference_mode() + def quantize(self, mel: torch.Tensor, + mel_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + hidden, code_len = self.encoder(mel, mel_len) + code = self.quantizer.encode(hidden) + return code, code_len + + @property + def device(self): + return next(self.parameters()).device + + def init_from_onnx(self, onnx_path: str): + ckpt = onnx2torch(onnx_path, None, False) + self.load_state_dict(ckpt, strict=True) + + def init_from_pt(self, ckpt_path: str): + ckpt = torch.load(ckpt_path, map_location="cpu", mmap=True) + self.load_state_dict(ckpt, strict=True) + + def freeze(self): + for _, param in self.named_parameters(): + param.requires_grad = False diff --git a/s3tokenizer/utils.py b/s3tokenizer/utils.py index b1e96fb..99520da 100644 --- a/s3tokenizer/utils.py +++ b/s3tokenizer/utils.py @@ -43,9 +43,11 @@ def _rename_weights(weights_dict: dict): """ new_weight_dict = {} for k in weights_dict.keys(): - if "quantizer" in k: # vq + if "quantizer" in k: # vq or fsq if k == "/quantizer/rq/model/layers.0/_codebook/Pow_1": new_weight_dict["quantizer._codebook.embed"] 
= weights_dict[k] + elif 'project_down' in k: # v2 + new_weight_dict[k] = weights_dict[k] elif "positional_embedding" in k: # positional emb new_weight_dict[k] = weights_dict[k] elif "conv" in k: # 1/2 or 1/4 subsample @@ -54,8 +56,10 @@ def _rename_weights(weights_dict: dict): assert "blocks" in k new_k = (k[1:].replace('/', '.').replace( 'MatMul', 'weight').replace('Add_1', 'bias').replace( - 'Mul', 'weight').replace('Add', - 'bias').replace('mlp.mlp', 'mlp')) + 'Mul', 'weight').replace('Add', 'bias').replace( + 'mlp.mlp', 'mlp')).replace('fsmn_block.Conv', + 'fsmn_block.weight') + new_weight_dict[f"encoder.{new_k}"] = weights_dict[k] return new_weight_dict @@ -89,31 +93,74 @@ def onnx2torch(onnx_path: str, torch_path: str = None, verbose: bool = False): for node in onnx_model.graph.node: for input_name in node.input: if input_name in initializer_map: + ln_bias_name, ln_weight_name = None, None # for v2 ln initializer = initializer_map[input_name] - if (input_name == "onnx::Conv_1519" - or input_name == "encoders.conv1.weight"): + if input_name in [ + "onnx::Conv_1519", + "encoders.conv1.weight", + "onnx::Conv_2216", + ]: # v1_50hz, v1_25hz, v2_25hz weight_name = "encoder.conv1.weight" - elif (input_name == "onnx::Conv_1520" - or input_name == "encoders.conv1.bias"): + elif input_name in [ + "onnx::Conv_1520", + "encoders.conv1.bias", + "onnx::Conv_2217", + ]: # v1_50hz, v1_25hz, v2_25hz weight_name = "encoder.conv1.bias" - elif (input_name == "onnx::Conv_1521" - or input_name == "encoders.conv2.weight"): + elif input_name in [ + "onnx::Conv_1521", + "encoders.conv2.weight", + "onnx::Conv_2218", + ]: weight_name = "encoder.conv2.weight" - elif (input_name == "onnx::Conv_1522" - or input_name == "encoders.conv2.bias"): + elif input_name in [ + "onnx::Conv_1522", + "encoders.conv2.bias", + "onnx::Conv_2219", + ]: weight_name = "encoder.conv2.bias" elif input_name == "encoders.positional_embedding": weight_name = "encoder.positional_embedding" + elif input_name == 'quantizer.project_in.bias': + weight_name = "quantizer._codebook.project_down.bias" + elif input_name == 'onnx::MatMul_2536': + weight_name = "quantizer._codebook.project_down.weight" else: - weight_name = node.name - weight_array = onnx.numpy_helper.to_array(initializer).copy() - weight_array.flags.writeable = True - weight_tensor = torch.from_numpy(weight_array) - if len(weight_tensor.shape - ) > 2 or weight_name == "encoder.positional_embedding": - weights_dict[weight_name] = weight_tensor + if node.op_type == 'LayerNormalization': # in input_name: + ln_name = node.name.replace('/LayerNormalization', '') + ln_weight_name = ln_name + '.weight' + ln_bias_name = ln_name + '.bias' + else: + weight_name = node.name + if ln_weight_name is not None and ln_bias_name is not None: + ln_inputs = node.input + scale_name = ln_inputs[1] + bias_name = ln_inputs[2] + scale = onnx.numpy_helper.to_array( + initializer_map[scale_name]).copy( + ) if scale_name in initializer_map else None + bias = onnx.numpy_helper.to_array( + initializer_map[bias_name]).copy( + ) if bias_name in initializer_map else None + scale.flags.writeable = True + bias.flags.writeable = True + weight_tensor = torch.from_numpy(scale) + bias_tensor = torch.from_numpy(bias) + + weights_dict[ln_bias_name] = bias_tensor + weights_dict[ln_weight_name] = weight_tensor else: - weights_dict[weight_name] = weight_tensor.t() + weight_array = onnx.numpy_helper.to_array( + initializer).copy() + weight_array.flags.writeable = True + weight_tensor = torch.from_numpy(weight_array) + if 
len(weight_tensor.shape) > 2 or weight_name in [ + "encoder.positional_embedding" + ]: + weights_dict[weight_name] = weight_tensor + else: + weights_dict[weight_name] = weight_tensor.t() + new_weights_dict = _rename_weights(weights_dict) if verbose: for k, v in new_weights_dict.items(): diff --git a/test/test_onnx.py b/test/test_onnx.py index bd7cb0b..6e64eed 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -51,6 +51,8 @@ ort_session.get_inputs()[1].name: np.array([mels_lens[i].item()], dtype=np.int32) })[0] + if name == 'speech_tokenizer_v2_25hz': + speech_token = np.expand_dims(speech_token, 0) speech_token = torch.tensor(speech_token[0, 0, :]) print(f"wav[{i}]") print(speech_token) @@ -58,6 +60,6 @@ f"all equal: {torch.equal(speech_token, codes[i, :codes_lens[i].item()].cpu())}" # noqa ) miss_num = torch.sum( - (speech_token == codes[i, :codes_lens[i].item()].cpu()) is False) + ~(speech_token == codes[i, :codes_lens[i].item()].cpu())) total = speech_token.numel() print(f"miss rate: {miss_num * 100.0 / total}%")
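
The quantizer added in `s3tokenizer/model_v2.py` replaces the v1 vector-quantization codebook with finite scalar quantization (FSQ). A minimal sketch of the index math in `FSQCodebook.encode`, assuming level 3 and the 8-dim `project_down` output used by this model (the input values below are hypothetical):

```py
import torch

# Each projected dimension is squashed by tanh, rounded to {-1, 0, 1}, and
# shifted to {0, 1, 2}; the 8 digits are then read as a base-3 integer,
# which is why n_codebook_size == 3**8 == 6561.
h = torch.tensor([1., -1., 0., 1., 0., 0., -1., 1.])  # hypothetical rounded tanh outputs
digits = h + 1                                         # digits in {0, 1, 2}
powers = torch.pow(3, torch.arange(8))                 # 3**0 .. 3**7
index = int(torch.sum(digits * powers))                # token id in [0, 3**8)
print(index)
```

With that in place, the v2 model keeps the same public API as v1; `load_model` simply dispatches to `S3TokenizerV2` when the model name contains `v2`. A usage sketch mirroring the README snippet in this patch (the wav paths are the repo's bundled assets; a CUDA device is assumed to be available):

```py
import s3tokenizer

tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").cuda()

mels = []
wav_paths = ["s3tokenizer/assets/BAC009S0764W0121.wav",
             "s3tokenizer/assets/BAC009S0764W0122.wav"]
for wav_path in wav_paths:
    audio = s3tokenizer.load_audio(wav_path)
    mels.append(s3tokenizer.log_mel_spectrogram(audio))
mels, mels_lens = s3tokenizer.padding(mels)
codes, codes_lens = tokenizer.quantize(mels.cuda(), mels_lens.cuda())

for i in range(len(wav_paths)):
    print(codes[i, :codes_lens[i].item()])
```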