From 8b6a6003b667fc6b00ece4675853aff94a12ca88 Mon Sep 17 00:00:00 2001 From: Mddct Date: Thu, 19 Dec 2024 23:01:40 +0800 Subject: [PATCH 1/4] [v2-tokenizer] support cosyvoice-tokenizer-v2 --- s3tokenizer/__init__.py | 5 ++ s3tokenizer/cli.py | 14 ++-- s3tokenizer/model.py | 139 +++++++++++++++++++++++++++++++++++----- s3tokenizer/utils.py | 44 +++++++++---- 4 files changed, 169 insertions(+), 33 deletions(-) diff --git a/s3tokenizer/__init__.py b/s3tokenizer/__init__.py index facc6aa..15e0f3f 100644 --- a/s3tokenizer/__init__.py +++ b/s3tokenizer/__init__.py @@ -39,6 +39,9 @@ "speech_tokenizer_v1_25hz": "https://www.modelscope.cn/models/iic/CosyVoice-300M-25Hz/" "resolve/master/speech_tokenizer_v1.onnx", + "speech_tokenizer_v2_25hz": + "https://www.modelscope.cn/models/iic/CosyVoice2-0.5B/" + "resolve/master/speech_tokenizer_v2.onnx", } _SHA256S = { @@ -46,6 +49,8 @@ "23b5a723ed9143aebfd9ffda14ac4c21231f31c35ef837b6a13bb9e5488abb1e", "speech_tokenizer_v1_25hz": "56285ddd4a83e883ee0cb9f8d69c1089b53a94b1f78ff7e4a0224a27eb4cb486", + "speech_tokenizer_v2_25hz": + "d43342aa12163a80bf07bffb94c9de2e120a8df2f9917cd2f642e7f4219c6f71", } diff --git a/s3tokenizer/cli.py b/s3tokenizer/cli.py index 31c2bb9..da0c81d 100644 --- a/s3tokenizer/cli.py +++ b/s3tokenizer/cli.py @@ -91,12 +91,14 @@ def init_distributed(): def get_args(): parser = argparse.ArgumentParser(description='extract speech code') - parser.add_argument( - '--model', - required=True, - type=str, - choices=["speech_tokenizer_v1", "speech_tokenizer_v1_25hz"], - help='model version') + parser.add_argument('--model', + required=True, + type=str, + choices=[ + "speech_tokenizer_v1", "speech_tokenizer_v1_25hz", + "speech_tokenizer_v2_25hz" + ], + help='model version') parser.add_argument('--wav_scp', required=True, type=str, diff --git a/s3tokenizer/model.py b/s3tokenizer/model.py index 73b77be..002f68d 100644 --- a/s3tokenizer/model.py +++ b/s3tokenizer/model.py @@ -38,6 +38,43 @@ class ModelDimensions: n_codebook_size: int = 4096 +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [ + d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape) + ] + return freqs_cis.view(*shape) + + +def round_ste(z: Tensor) -> Tensor: + """Round with straight through gradients.""" + zhat = z.round() + return z + (zhat - z).detach() + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + class LayerNorm(nn.LayerNorm): def forward(self, x: Tensor) -> Tensor: @@ -87,23 +124,30 @@ def forward( self, x: Tensor, mask: Optional[Tensor] = None, + freqs_cis: Optional[Tensor] = None, ): q = self.query(x) k = self.key(x) v = self.value(x) - wv, qk = self.qkv_attention(q, k, 
v, mask) + wv, qk = self.qkv_attention(q, k, v, mask, freqs_cis) return self.out(wv), qk def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, - mask: Optional[Tensor] = None): + mask: Optional[Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None): _, T, D = q.shape scale = (D // self.n_head)**-0.25 - q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale - k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale + q = q.view(*q.shape[:2], self.n_head, -1) + k = k.view(*k.shape[:2], self.n_head, -1) + if freqs_cis is not None: + q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) + + q = q.permute(0, 2, 1, 3) * scale + k = k.permute(0, 2, 1, 3) * scale v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) qk = q @ k # (B, n_head, T, T) @@ -148,6 +192,7 @@ def __init__( n_head: int, n_layer: int, stride: int, + rope: bool = False, ): super().__init__() self.stride = stride @@ -161,7 +206,13 @@ def __init__( kernel_size=3, stride=2, padding=1) - self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) + self.rope = True + if not rope: + self.register_buffer("positional_embedding", + sinusoids(n_ctx, n_state)) + else: + freqs_cis = precompute_freqs_cis(n_mels, 1024 * 2) + self.register_buffer("positional_embedding", freqs_cis) self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]) @@ -184,10 +235,15 @@ def forward(self, x: Tensor, x_len: Tensor) -> Tuple[Tensor, Tensor]: mask = mask[:, :, (_T + 1) % 2::2] # (B, 1, T // 4) mask = mask_to_bias(mask, x.dtype) - x = (x + self.positional_embedding[:x.shape[1], :]).to(x.dtype) + if not self.rope: + x = (x + self.positional_embedding[:x.shape[1], :]).to(x.dtype) for block in self.blocks: - x = block(x, mask.unsqueeze(1)) + if self.rope: + x = block(x, mask.unsqueeze(1), + self.positional_embedding[:x.size(1)]) + else: + x = block(x, mask.unsqueeze(1)) x_len = (x_len + 1) // 2 if self.stride == 2: @@ -247,6 +303,40 @@ def decode(self, embed_ind: Tensor) -> Tensor: return quantize +class FSQCodebook(nn.Module): + + def __init__(self, dim: int, level: int = 3): + super().__init__() + self.project_down = torch.nn.Linear(dim, 8) + self.level = level + self.embed = None + + @torch.inference_mode() + def preprocess(self, x: Tensor) -> Tensor: + x = rearrange(x, "... d -> (...) d") + return x + + @torch.inference_mode() + def encode(self, x: Tensor) -> Tensor: + shape = x.shape + # pre-process + x = self.preprocess(x) + # quantize + + h = x.tanh() + h = (self.level * h).round() # range [-k, k] + base = 2 * self.level + 1 + powers = torch.pow(base, torch.arange(shape[-1], device=x.device)) + mu = torch.sum(h * powers.unsqueeze(0), dim=-1) + ind = mu.reshape(shape[0], shape[1]) + return ind + + @torch.inference_mode() + def decode(self, embed_ind: Tensor) -> Tensor: + raise NotImplementedError( + 'There is no official up project component provided') + + class VectorQuantization(nn.Module): """Vector quantization implementation (inference-only). 
Args: @@ -254,10 +344,16 @@ class VectorQuantization(nn.Module): codebook_size (int): Codebook size """ - def __init__(self, dim: int, codebook_size: int): + def __init__(self, dim: int, codebook_size: int, quantize_type='vq'): super().__init__() - self._codebook = EuclideanCodebook(dim=dim, - codebook_size=codebook_size) + if quantize_type == 'vq': + self._codebook = EuclideanCodebook(dim=dim, + codebook_size=codebook_size) + else: + assert quantize_type == 'fsq' + assert 3**8 == codebook_size + self._codebook = FSQCodebook(dim=dim, level=3) + self.quantize_type = quantize_type self.codebook_size = codebook_size @property @@ -266,9 +362,13 @@ def codebook(self): @torch.inference_mode() def encode(self, x: Tensor) -> Tensor: - x = F.normalize(x, p=2, dim=-1) - embed_in = self._codebook.encode(x) - return embed_in + if self.quantize_type == 'vq': + x = F.normalize(x, p=2, dim=-1) + embed_in = self._codebook.encode(x) + return embed_in + else: + assert self.quantize_type == 'fsq' + return self._codebook.encode(x) @torch.inference_mode() def decode(self, embed_ind: Tensor) -> Tensor: @@ -285,6 +385,10 @@ class S3Tokenizer(nn.Module): def __init__(self, name: str, dims: ModelDimensions = ModelDimensions()): super().__init__() + if 'v1' not in name: + assert 'v2' in name + # TODO(Mddct): make it configureable + dims.n_codebook_size = 3**8 self.dims = dims self.encoder = AudioEncoder( self.dims.n_mels, @@ -292,10 +396,13 @@ def __init__(self, name: str, dims: ModelDimensions = ModelDimensions()): self.dims.n_audio_state, self.dims.n_audio_head, self.dims.n_audio_layer, - 2 if name == "speech_tokenizer_v1_25hz" else 1, + 2 if name + in ["speech_tokenizer_v1_25hz", "speech_tokenizer_v2_25hz"] else 1, ) - self.quantizer = VectorQuantization(self.dims.n_audio_state, - self.dims.n_codebook_size) + self.quantizer = VectorQuantization( + self.dims.n_audio_state, + self.dims.n_codebook_size, + quantize_type='vq' if 'v1' in name else 'fsq') def forward(self, mel: Tensor, mel_len: Tensor) -> Tuple[Tensor, Tensor]: return self.quantize(mel, mel_len) diff --git a/s3tokenizer/utils.py b/s3tokenizer/utils.py index b1e96fb..489bc81 100644 --- a/s3tokenizer/utils.py +++ b/s3tokenizer/utils.py @@ -43,9 +43,17 @@ def _rename_weights(weights_dict: dict): """ new_weight_dict = {} for k in weights_dict.keys(): - if "quantizer" in k: # vq + if "quantizer" in k: # vq or fsq if k == "/quantizer/rq/model/layers.0/_codebook/Pow_1": new_weight_dict["quantizer._codebook.embed"] = weights_dict[k] + elif k == 'quantizer.project_in.bias': + new_weight_dict[ + "quantizer._codebook.project_down.bias"] = weights_dict[k] + elif k == 'onnx::MatMul_2536': + new_weight_dict[ + "quantizer._codebook.project_down.weight"] = weights_dict[ + k] + elif "positional_embedding" in k: # positional emb new_weight_dict[k] = weights_dict[k] elif "conv" in k: # 1/2 or 1/4 subsample @@ -90,17 +98,29 @@ def onnx2torch(onnx_path: str, torch_path: str = None, verbose: bool = False): for input_name in node.input: if input_name in initializer_map: initializer = initializer_map[input_name] - if (input_name == "onnx::Conv_1519" - or input_name == "encoders.conv1.weight"): + if input_name in [ + "onnx::Conv_1519", + "encoders.conv1.weight", + "onnx::Conv_2216", + ]: # v1_50hz, v1_25hz, v2_25hz weight_name = "encoder.conv1.weight" - elif (input_name == "onnx::Conv_1520" - or input_name == "encoders.conv1.bias"): + elif input_name in [ + "onnx::Conv_1520", + "encoders.conv1.bias", + "onnx::Conv_2217", + ]: # v1_50hz, v1_25hz, v2_25hz weight_name = 
"encoder.conv1.bias" - elif (input_name == "onnx::Conv_1521" - or input_name == "encoders.conv2.weight"): + elif input_name in [ + "onnx::Conv_1521", + "encoders.conv2.weight", + "onnx::Conv_2218", + ]: weight_name = "encoder.conv2.weight" - elif (input_name == "onnx::Conv_1522" - or input_name == "encoders.conv2.bias"): + elif input_name in [ + "onnx::Conv_1522", + "encoders.conv2.bias", + "onnx::Conv_2219", + ]: weight_name = "encoder.conv2.bias" elif input_name == "encoders.positional_embedding": weight_name = "encoder.positional_embedding" @@ -109,11 +129,13 @@ def onnx2torch(onnx_path: str, torch_path: str = None, verbose: bool = False): weight_array = onnx.numpy_helper.to_array(initializer).copy() weight_array.flags.writeable = True weight_tensor = torch.from_numpy(weight_array) - if len(weight_tensor.shape - ) > 2 or weight_name == "encoder.positional_embedding": + if len(weight_tensor.shape) > 2 or weight_name in [ + "encoder.positional_embedding" + ]: weights_dict[weight_name] = weight_tensor else: weights_dict[weight_name] = weight_tensor.t() + new_weights_dict = _rename_weights(weights_dict) if verbose: for k, v in new_weights_dict.items(): From 80c805b5f660c7be2cc423553e7f4ec0abbf4945 Mon Sep 17 00:00:00 2001 From: Mddct Date: Fri, 20 Dec 2024 12:40:07 +0800 Subject: [PATCH 2/4] fix fuse ln in v2 --- s3tokenizer/model.py | 6 ++--- s3tokenizer/utils.py | 58 +++++++++++++++++++++++++++++++------------- 2 files changed, 44 insertions(+), 20 deletions(-) diff --git a/s3tokenizer/model.py b/s3tokenizer/model.py index 002f68d..19a3ee7 100644 --- a/s3tokenizer/model.py +++ b/s3tokenizer/model.py @@ -147,7 +147,7 @@ def qkv_attention(self, q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) q = q.permute(0, 2, 1, 3) * scale - k = k.permute(0, 2, 1, 3) * scale + k = k.permute(0, 2, 3, 1) * scale v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) qk = q @ k # (B, n_head, T, T) @@ -206,7 +206,7 @@ def __init__( kernel_size=3, stride=2, padding=1) - self.rope = True + self.rope = rope if not rope: self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) @@ -398,7 +398,7 @@ def __init__(self, name: str, dims: ModelDimensions = ModelDimensions()): self.dims.n_audio_layer, 2 if name in ["speech_tokenizer_v1_25hz", "speech_tokenizer_v2_25hz"] else 1, - ) + rope=True if 'v2' in name else False) self.quantizer = VectorQuantization( self.dims.n_audio_state, self.dims.n_codebook_size, diff --git a/s3tokenizer/utils.py b/s3tokenizer/utils.py index 489bc81..7b94736 100644 --- a/s3tokenizer/utils.py +++ b/s3tokenizer/utils.py @@ -46,14 +46,8 @@ def _rename_weights(weights_dict: dict): if "quantizer" in k: # vq or fsq if k == "/quantizer/rq/model/layers.0/_codebook/Pow_1": new_weight_dict["quantizer._codebook.embed"] = weights_dict[k] - elif k == 'quantizer.project_in.bias': - new_weight_dict[ - "quantizer._codebook.project_down.bias"] = weights_dict[k] - elif k == 'onnx::MatMul_2536': - new_weight_dict[ - "quantizer._codebook.project_down.weight"] = weights_dict[ - k] - + elif 'project_down' in k: # v2 + new_weight_dict[k] = weights_dict[k] elif "positional_embedding" in k: # positional emb new_weight_dict[k] = weights_dict[k] elif "conv" in k: # 1/2 or 1/4 subsample @@ -64,6 +58,7 @@ def _rename_weights(weights_dict: dict): 'MatMul', 'weight').replace('Add_1', 'bias').replace( 'Mul', 'weight').replace('Add', 'bias').replace('mlp.mlp', 'mlp')) + new_weight_dict[f"encoder.{new_k}"] = weights_dict[k] return new_weight_dict @@ -97,6 +92,7 @@ def onnx2torch(onnx_path: str, 
torch_path: str = None, verbose: bool = False): for node in onnx_model.graph.node: for input_name in node.input: if input_name in initializer_map: + ln_bias_name, ln_weight_name = None, None # for v2 ln initializer = initializer_map[input_name] if input_name in [ "onnx::Conv_1519", @@ -124,17 +120,45 @@ def onnx2torch(onnx_path: str, torch_path: str = None, verbose: bool = False): weight_name = "encoder.conv2.bias" elif input_name == "encoders.positional_embedding": weight_name = "encoder.positional_embedding" + elif input_name == 'quantizer.project_in.bias': + weight_name = "quantizer._codebook.project_down.bias" + elif input_name == 'onnx::MatMul_2536': + weight_name = "quantizer._codebook.project_down.weight" else: - weight_name = node.name - weight_array = onnx.numpy_helper.to_array(initializer).copy() - weight_array.flags.writeable = True - weight_tensor = torch.from_numpy(weight_array) - if len(weight_tensor.shape) > 2 or weight_name in [ - "encoder.positional_embedding" - ]: - weights_dict[weight_name] = weight_tensor + if node.op_type == 'LayerNormalization': # in input_name: + ln_name = node.name.replace('/LayerNormalization', '') + ln_weight_name = ln_name + '.weight' + ln_bias_name = ln_name + '.bias' + else: + weight_name = node.name + if ln_weight_name is not None and ln_bias_name is not None: + ln_inputs = node.input + scale_name = ln_inputs[1] # Scale 的输入名称 + bias_name = ln_inputs[2] # B 的输入名称 + scale = onnx.numpy_helper.to_array( + initializer_map[scale_name]).copy( + ) if scale_name in initializer_map else None + bias = onnx.numpy_helper.to_array( + initializer_map[bias_name]).copy( + ) if bias_name in initializer_map else None + scale.flags.writeable = True + bias.flags.writeable = True + weight_tensor = torch.from_numpy(scale) + bias_tensor = torch.from_numpy(bias) + + weights_dict[ln_bias_name] = bias_tensor + weights_dict[ln_weight_name] = weight_tensor else: - weights_dict[weight_name] = weight_tensor.t() + weight_array = onnx.numpy_helper.to_array( + initializer).copy() + weight_array.flags.writeable = True + weight_tensor = torch.from_numpy(weight_array) + if len(weight_tensor.shape) > 2 or weight_name in [ + "encoder.positional_embedding" + ]: + weights_dict[weight_name] = weight_tensor + else: + weights_dict[weight_name] = weight_tensor.t() new_weights_dict = _rename_weights(weights_dict) if verbose: From 84e647ae56123dffde392cea68785c8c675d6206 Mon Sep 17 00:00:00 2001 From: Mddct Date: Fri, 20 Dec 2024 16:24:00 +0800 Subject: [PATCH 3/4] v2 works --- s3tokenizer/__init__.py | 8 +- s3tokenizer/model.py | 141 ++------------- s3tokenizer/model_v2.py | 381 ++++++++++++++++++++++++++++++++++++++++ s3tokenizer/utils.py | 9 +- test/test_onnx.py | 4 +- 5 files changed, 412 insertions(+), 131 deletions(-) create mode 100644 s3tokenizer/model_v2.py diff --git a/s3tokenizer/__init__.py b/s3tokenizer/__init__.py index 15e0f3f..9b4cb87 100644 --- a/s3tokenizer/__init__.py +++ b/s3tokenizer/__init__.py @@ -24,6 +24,8 @@ from tqdm import tqdm +from s3tokenizer.model_v2 import S3TokenizerV2 + from .model import S3Tokenizer from .utils import (load_audio, log_mel_spectrogram, make_non_pad_mask, mask_to_bias, onnx2torch, padding) @@ -142,8 +144,10 @@ def load_model( else: raise RuntimeError( f"Model {name} not found; available models = {available_models()}") - - model = S3Tokenizer(name) + if 'v2' in name: + model = S3TokenizerV2(name) + else: + model = S3Tokenizer(name) model.init_from_onnx(checkpoint_file) return model diff --git a/s3tokenizer/model.py b/s3tokenizer/model.py 
index 19a3ee7..73b77be 100644 --- a/s3tokenizer/model.py +++ b/s3tokenizer/model.py @@ -38,43 +38,6 @@ class ModelDimensions: n_codebook_size: int = 4096 -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): - ndim = x.ndim - assert 0 <= 1 < ndim - assert freqs_cis.shape == (x.shape[1], x.shape[-1]) - shape = [ - d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape) - ] - return freqs_cis.view(*shape) - - -def round_ste(z: Tensor) -> Tensor: - """Round with straight through gradients.""" - zhat = z.round() - return z + (zhat - z).detach() - - -def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cis: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) - xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) - - -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): - freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device, dtype=torch.float32) - freqs = torch.outer(t, freqs) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - - class LayerNorm(nn.LayerNorm): def forward(self, x: Tensor) -> Tensor: @@ -124,30 +87,23 @@ def forward( self, x: Tensor, mask: Optional[Tensor] = None, - freqs_cis: Optional[Tensor] = None, ): q = self.query(x) k = self.key(x) v = self.value(x) - wv, qk = self.qkv_attention(q, k, v, mask, freqs_cis) + wv, qk = self.qkv_attention(q, k, v, mask) return self.out(wv), qk def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, - mask: Optional[Tensor] = None, - freqs_cis: Optional[torch.Tensor] = None): + mask: Optional[Tensor] = None): _, T, D = q.shape scale = (D // self.n_head)**-0.25 - q = q.view(*q.shape[:2], self.n_head, -1) - k = k.view(*k.shape[:2], self.n_head, -1) - if freqs_cis is not None: - q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) - - q = q.permute(0, 2, 1, 3) * scale - k = k.permute(0, 2, 3, 1) * scale + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) qk = q @ k # (B, n_head, T, T) @@ -192,7 +148,6 @@ def __init__( n_head: int, n_layer: int, stride: int, - rope: bool = False, ): super().__init__() self.stride = stride @@ -206,13 +161,7 @@ def __init__( kernel_size=3, stride=2, padding=1) - self.rope = rope - if not rope: - self.register_buffer("positional_embedding", - sinusoids(n_ctx, n_state)) - else: - freqs_cis = precompute_freqs_cis(n_mels, 1024 * 2) - self.register_buffer("positional_embedding", freqs_cis) + self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]) @@ -235,15 +184,10 @@ def forward(self, x: Tensor, x_len: Tensor) -> Tuple[Tensor, Tensor]: mask = mask[:, :, (_T + 1) % 2::2] # (B, 1, T // 4) mask = mask_to_bias(mask, x.dtype) - if not self.rope: - x = (x + self.positional_embedding[:x.shape[1], :]).to(x.dtype) + x = (x + self.positional_embedding[:x.shape[1], :]).to(x.dtype) for block in self.blocks: - if self.rope: - x = block(x, mask.unsqueeze(1), - 
self.positional_embedding[:x.size(1)]) - else: - x = block(x, mask.unsqueeze(1)) + x = block(x, mask.unsqueeze(1)) x_len = (x_len + 1) // 2 if self.stride == 2: @@ -303,40 +247,6 @@ def decode(self, embed_ind: Tensor) -> Tensor: return quantize -class FSQCodebook(nn.Module): - - def __init__(self, dim: int, level: int = 3): - super().__init__() - self.project_down = torch.nn.Linear(dim, 8) - self.level = level - self.embed = None - - @torch.inference_mode() - def preprocess(self, x: Tensor) -> Tensor: - x = rearrange(x, "... d -> (...) d") - return x - - @torch.inference_mode() - def encode(self, x: Tensor) -> Tensor: - shape = x.shape - # pre-process - x = self.preprocess(x) - # quantize - - h = x.tanh() - h = (self.level * h).round() # range [-k, k] - base = 2 * self.level + 1 - powers = torch.pow(base, torch.arange(shape[-1], device=x.device)) - mu = torch.sum(h * powers.unsqueeze(0), dim=-1) - ind = mu.reshape(shape[0], shape[1]) - return ind - - @torch.inference_mode() - def decode(self, embed_ind: Tensor) -> Tensor: - raise NotImplementedError( - 'There is no official up project component provided') - - class VectorQuantization(nn.Module): """Vector quantization implementation (inference-only). Args: @@ -344,16 +254,10 @@ class VectorQuantization(nn.Module): codebook_size (int): Codebook size """ - def __init__(self, dim: int, codebook_size: int, quantize_type='vq'): + def __init__(self, dim: int, codebook_size: int): super().__init__() - if quantize_type == 'vq': - self._codebook = EuclideanCodebook(dim=dim, - codebook_size=codebook_size) - else: - assert quantize_type == 'fsq' - assert 3**8 == codebook_size - self._codebook = FSQCodebook(dim=dim, level=3) - self.quantize_type = quantize_type + self._codebook = EuclideanCodebook(dim=dim, + codebook_size=codebook_size) self.codebook_size = codebook_size @property @@ -362,13 +266,9 @@ def codebook(self): @torch.inference_mode() def encode(self, x: Tensor) -> Tensor: - if self.quantize_type == 'vq': - x = F.normalize(x, p=2, dim=-1) - embed_in = self._codebook.encode(x) - return embed_in - else: - assert self.quantize_type == 'fsq' - return self._codebook.encode(x) + x = F.normalize(x, p=2, dim=-1) + embed_in = self._codebook.encode(x) + return embed_in @torch.inference_mode() def decode(self, embed_ind: Tensor) -> Tensor: @@ -385,10 +285,6 @@ class S3Tokenizer(nn.Module): def __init__(self, name: str, dims: ModelDimensions = ModelDimensions()): super().__init__() - if 'v1' not in name: - assert 'v2' in name - # TODO(Mddct): make it configureable - dims.n_codebook_size = 3**8 self.dims = dims self.encoder = AudioEncoder( self.dims.n_mels, @@ -396,13 +292,10 @@ def __init__(self, name: str, dims: ModelDimensions = ModelDimensions()): self.dims.n_audio_state, self.dims.n_audio_head, self.dims.n_audio_layer, - 2 if name - in ["speech_tokenizer_v1_25hz", "speech_tokenizer_v2_25hz"] else 1, - rope=True if 'v2' in name else False) - self.quantizer = VectorQuantization( - self.dims.n_audio_state, - self.dims.n_codebook_size, - quantize_type='vq' if 'v1' in name else 'fsq') + 2 if name == "speech_tokenizer_v1_25hz" else 1, + ) + self.quantizer = VectorQuantization(self.dims.n_audio_state, + self.dims.n_codebook_size) def forward(self, mel: Tensor, mel_len: Tensor) -> Tuple[Tensor, Tensor]: return self.quantize(mel, mel_len) diff --git a/s3tokenizer/model_v2.py b/s3tokenizer/model_v2.py new file mode 100644 index 0000000..7928849 --- /dev/null +++ b/s3tokenizer/model_v2.py @@ -0,0 +1,381 @@ +# Copyright (c) (Mddct: Dinghao Zhou) +# +# Licensed 
under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from einops import rearrange + +from s3tokenizer.model import Conv1d, LayerNorm, Linear, MultiHeadAttention +from s3tokenizer.utils import make_non_pad_mask, mask_to_bias, onnx2torch + + +@dataclass +class ModelDimensions: + n_mels: int = 128 + n_audio_ctx: int = 1500 + n_audio_state: int = 1280 + n_audio_head: int = 20 + n_audio_layer: int = 6 + n_codebook_size: int = 3**8 + + +def precompute_freqs_cis(dim: int, + end: int, + theta: float = 10000.0, + scaling=None): + freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + if scaling is not None: + t = t * scaling + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + + return torch.cat((freqs_cis, freqs_cis), dim=-1) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + real = torch.view_as_real(freqs_cis) + cos, sin = real[:, :, 0], real[:, :, 1] + cos = cos.unsqueeze(0).unsqueeze(2) + sin = sin.unsqueeze(0).unsqueeze(2) + + D = xq.shape[-1] + half_l, half_r = xq[:, :, :, :D // 2], xq[:, :, :, D // 2:] + xq_r = torch.cat((-half_r, half_l), dim=-1) + + D = xk.shape[-1] + + half_l, half_r = xk[:, :, :, :D // 2], xk[:, :, :, D // 2:] + xk_r = torch.cat((-half_r, half_l), dim=-1) + + return xq * cos + xq_r * sin, xk * cos + xk_r * sin + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [ + d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape) + ] + return freqs_cis.view(*shape) + + +class FSQCodebook(torch.nn.Module): + + def __init__(self, dim: int, level: int = 3): + super().__init__() + self.project_down = torch.nn.Linear(dim, 8) + self.level = level + self.embed = None + + @torch.inference_mode() + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + x = rearrange(x, "... d -> (...) d") + return x + + @torch.inference_mode() + def encode(self, x: torch.Tensor) -> torch.Tensor: + x_shape = x.shape + # pre-process + x = self.preprocess(x) + # quantize + + h = self.project_down(x) + h = h.tanh() + h = h * 0.9990000128746033 + h = h.round() + 1 + # h = ((self.level - 1) * h).round() # range [-k, k] + powers = torch.pow(self.level, + torch.arange(2**self.level, device=x.device)) + mu = torch.sum(h * powers.unsqueeze(0), dim=-1) + ind = mu.reshape(x_shape[0], x_shape[1]) + return ind + + @torch.inference_mode() + def decode(self, embed_ind: torch.Tensor) -> torch.Tensor: + raise NotImplementedError( + 'There is no official up project component provided') + + +class FSQVectorQuantization(torch.nn.Module): + """Vector quantization implementation (inference-only). 
+ Args: + dim (int): Dimension + codebook_size (int): Codebook size + """ + + def __init__( + self, + dim: int, + codebook_size: int, + ): + super().__init__() + assert 3**8 == codebook_size + self._codebook = FSQCodebook(dim=dim, level=3) + self.codebook_size = codebook_size + + @property + def codebook(self): + return self._codebook.embed + + @torch.inference_mode() + def encode(self, x: torch.Tensor) -> torch.Tensor: + return self._codebook.encode(x) + + @torch.inference_mode() + def decode(self, embed_ind: torch.Tensor) -> torch.Tensor: + quantize = self._codebook.decode(embed_ind) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + +class FSMNMultiHeadAttention(MultiHeadAttention): + + def __init__( + self, + n_state: int, + n_head: int, + kernel_size: int = 31, + ): + super().__init__(n_state, n_head) + + self.fsmn_block = torch.nn.Conv1d(n_state, + n_state, + kernel_size, + stride=1, + padding=0, + groups=n_state, + bias=False) + self.left_padding = (kernel_size - 1) // 2 + self.right_padding = kernel_size - 1 - self.left_padding + self.pad_fn = torch.nn.ConstantPad1d( + (self.left_padding, self.right_padding), 0.0) + + def forward_fsmn(self, + inputs: torch.Tensor, + mask: Optional[torch.Tensor] = None): + b, t, _, _ = inputs.size() + inputs = inputs.view(b, t, -1) + if mask is not None and mask.size(2) > 0: # time2 > 0 + inputs = inputs * mask + x = inputs.transpose(1, 2) + x = self.pad_fn(x) + x = self.fsmn_block(x) + x = x.transpose(1, 2) + x += inputs + return x * mask + + def qkv_attention(self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: Optional[torch.Tensor] = None, + mask_pad: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None): + _, _, D = q.shape + scale = (D // self.n_head)**-0.25 + q = q.view(*q.shape[:2], self.n_head, -1) + k = k.view(*k.shape[:2], self.n_head, -1) + v = v.view(*v.shape[:2], self.n_head, -1) + + if freqs_cis is not None: + q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) + + fsm_memory = self.forward_fsmn(v, mask_pad) + + q = q.permute(0, 2, 1, 3) * scale + k = k.permute(0, 2, 3, 1) * scale + v = v.permute(0, 2, 1, 3) + + qk = q @ k # (B, n_head, T, T) + if mask is not None: + qk = qk + mask + qk = qk.float() + w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype) + return (w @ v).permute(0, 2, 1, + 3).flatten(start_dim=2), qk.detach(), fsm_memory + + def forward(self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + mask_pad: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None): + + q = self.query(x) + k = self.key(x) + v = self.value(x) + + wv, qk, fsm_memory = self.qkv_attention(q, k, v, mask, mask_pad, + freqs_cis) + return self.out(wv) + fsm_memory, qk + + +class ResidualAttentionBlock(torch.nn.Module): + + def __init__( + self, + n_state: int, + n_head: int, + kernel_size: int = 31, + ): + super().__init__() + + self.attn = FSMNMultiHeadAttention(n_state, n_head, kernel_size) + self.attn_ln = LayerNorm(n_state, eps=1e-6) + + n_mlp = n_state * 4 + + self.mlp = torch.nn.Sequential(Linear(n_state, n_mlp), torch.nn.GELU(), + Linear(n_mlp, n_state)) + self.mlp_ln = LayerNorm(n_state) + + def forward( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + mask_pad: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None, + ): + x = x + self.attn( + self.attn_ln(x), mask=mask, mask_pad=mask_pad, + freqs_cis=freqs_cis)[0] + + x = x + self.mlp(self.mlp_ln(x)) + return x + + +class AudioEncoderV2(torch.nn.Module): + + def 
__init__( + self, + n_mels: int, + n_state: int, + n_head: int, + n_layer: int, + stride: int, + ): + super().__init__() + self.stride = stride + + self.conv1 = Conv1d(n_mels, + n_state, + kernel_size=3, + stride=stride, + padding=1) + self.conv2 = Conv1d(n_state, + n_state, + kernel_size=3, + stride=2, + padding=1) + self.freqs_cis = precompute_freqs_cis(64, 1024 * 2) + self.blocks = torch.nn.ModuleList( + [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]) + + def forward(self, x: torch.Tensor, + x_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + x : torch.Tensor, shape = (batch_size, n_mels, T) + the mel spectrogram of the audio + x_len: torch.Tensor, shape = (batch_size,) + length of each audio in x + """ + T = x.size(-1) + x = torch.nn.functional.gelu(self.conv1(x)) + x = torch.nn.functional.gelu(self.conv2(x)) + x = x.permute(0, 2, 1) # (B, T // 2, n_state) + freqs_cis = self.freqs_cis.to(x.device) + mask = make_non_pad_mask(x_len, T).unsqueeze(1) # (B, 1, T) + mask = mask[:, :, (T + 1) % 2::2] # (B, 1, T // 2) + mask_pad = None + if self.stride == 2: + _T = mask.size(-1) + mask = mask[:, :, (_T + 1) % 2::2] # (B, 1, T // 4) + mask_pad = mask.transpose(1, 2) + mask = mask_to_bias(mask, x.dtype) + + tmp = torch.view_as_real(freqs_cis) + cos, sin = tmp[:, :, 0], tmp[:, :, 1] + + cos = torch.cat((cos, cos), dim=-1) + sin = torch.cat((sin, sin), dim=-1) + cos = cos.unsqueeze(0).unsqueeze(2) + sin = sin.unsqueeze(0).unsqueeze(2) + + for block in self.blocks: + x = block(x, mask.unsqueeze(1), mask_pad, freqs_cis[:x.size(1)]) + + x_len = (x_len + 1) // 2 + if self.stride == 2: + x_len = (x_len + 1) // 2 + return x, x_len + + +class S3TokenizerV2(torch.nn.Module): + """S3 tokenizer v2 implementation (inference-only). + Args: + dims (ModelDimensions): Dimension + """ + + def __init__(self, name: str, dims: ModelDimensions = ModelDimensions()): + super().__init__() + if 'v1' not in name: + assert 'v2' in name + # TODO(Mddct): make it configureable + dims.n_codebook_size = 3**8 + self.dims = dims + self.encoder = AudioEncoderV2( + self.dims.n_mels, + self.dims.n_audio_state, + self.dims.n_audio_head, + self.dims.n_audio_layer, + 2, + ) + self.quantizer = FSQVectorQuantization( + self.dims.n_audio_state, + self.dims.n_codebook_size, + ) + + def forward(self, mel: torch.Tensor, + mel_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return self.quantize(mel, mel_len) + + @torch.inference_mode() + def quantize(self, mel: torch.Tensor, + mel_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + hidden, code_len = self.encoder(mel, mel_len) + code = self.quantizer.encode(hidden) + return code, code_len + + @property + def device(self): + return next(self.parameters()).device + + def init_from_onnx(self, onnx_path: str): + ckpt = onnx2torch(onnx_path, None, False) + self.load_state_dict(ckpt, strict=True) + + def init_from_pt(self, ckpt_path: str): + ckpt = torch.load(ckpt_path, map_location="cpu", mmap=True) + self.load_state_dict(ckpt, strict=True) + + def freeze(self): + for _, param in self.named_parameters(): + param.requires_grad = False diff --git a/s3tokenizer/utils.py b/s3tokenizer/utils.py index 7b94736..99520da 100644 --- a/s3tokenizer/utils.py +++ b/s3tokenizer/utils.py @@ -56,8 +56,9 @@ def _rename_weights(weights_dict: dict): assert "blocks" in k new_k = (k[1:].replace('/', '.').replace( 'MatMul', 'weight').replace('Add_1', 'bias').replace( - 'Mul', 'weight').replace('Add', - 'bias').replace('mlp.mlp', 'mlp')) + 'Mul', 
'weight').replace('Add', 'bias').replace( + 'mlp.mlp', 'mlp')).replace('fsmn_block.Conv', + 'fsmn_block.weight') new_weight_dict[f"encoder.{new_k}"] = weights_dict[k] return new_weight_dict @@ -133,8 +134,8 @@ def onnx2torch(onnx_path: str, torch_path: str = None, verbose: bool = False): weight_name = node.name if ln_weight_name is not None and ln_bias_name is not None: ln_inputs = node.input - scale_name = ln_inputs[1] # Scale 的输入名称 - bias_name = ln_inputs[2] # B 的输入名称 + scale_name = ln_inputs[1] + bias_name = ln_inputs[2] scale = onnx.numpy_helper.to_array( initializer_map[scale_name]).copy( ) if scale_name in initializer_map else None diff --git a/test/test_onnx.py b/test/test_onnx.py index bd7cb0b..6e64eed 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -51,6 +51,8 @@ ort_session.get_inputs()[1].name: np.array([mels_lens[i].item()], dtype=np.int32) })[0] + if name == 'speech_tokenizer_v2_25hz': + speech_token = np.expand_dims(speech_token, 0) speech_token = torch.tensor(speech_token[0, 0, :]) print(f"wav[{i}]") print(speech_token) @@ -58,6 +60,6 @@ f"all equal: {torch.equal(speech_token, codes[i, :codes_lens[i].item()].cpu())}" # noqa ) miss_num = torch.sum( - (speech_token == codes[i, :codes_lens[i].item()].cpu()) is False) + ~(speech_token == codes[i, :codes_lens[i].item()].cpu())) total = speech_token.numel() print(f"miss rate: {miss_num * 100.0 / total}%") From 1f72c8fe70e7535a6bfb5cda7440dd29c2c38a19 Mon Sep 17 00:00:00 2001 From: Mddct Date: Sun, 22 Dec 2024 03:06:52 +0800 Subject: [PATCH 4/4] update readme --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 5e4e4b9..4395aa4 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ pip install s3tokenizer ```py import s3tokenizer -tokenizer = s3tokenizer.load_model("speech_tokenizer_v1").cuda() # or "speech_tokenizer_v1_25hz" +tokenizer = s3tokenizer.load_model("speech_tokenizer_v1").cuda() # or "speech_tokenizer_v1_25hz speech_tokenizer_v2_25hz" mels = [] wav_paths = ["s3tokenizer/assets/BAC009S0764W0121.wav", "s3tokenizer/assets/BAC009S0764W0122.wav"] @@ -48,7 +48,7 @@ s3tokenizer --wav_scp xxx.scp \ --device "cpu" \ --output_dir "./" \ --batch_size 32 \ - --model "speech_tokenizer_v1" # or "speech_tokenizer_v1_25hz" + --model "speech_tokenizer_v1" # or "speech_tokenizer_v1_25hz speech_tokenizer_v2_25hz" ``` @@ -66,7 +66,7 @@ torchrun --nproc_per_node=8 --nnodes=1 \ --device "cuda" \ --output_dir "./" \ --batch_size 32 \ - --model "speech_tokenizer_v1" # or "speech_tokenizer_v1_25hz" + --model "speech_tokenizer_v1" # or "speech_tokenizer_v1_25hz speech_tokenizer_v2_25hz" ```
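
---

For quick reference, below is a minimal sketch of how the new `speech_tokenizer_v2_25hz` model added in this series would be used end to end. It is assembled from the README example above plus the `load_model` dispatch and `S3TokenizerV2.quantize` code in these patches; the final print loop mirrors the `codes[i, :codes_lens[i].item()]` pattern from `test/test_onnx.py`. Treat the feature-extraction step as an assumption: `ModelDimensions` in `model_v2.py` declares 128 mel bins, and this series does not show whether `log_mel_spectrogram` needs an explicit setting to match.

```py
import s3tokenizer

# Load the CosyVoice2 tokenizer added in this series; the ONNX checkpoint is
# fetched and converted to PyTorch weights via onnx2torch() under the hood.
tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").cuda()

wav_paths = [
    "s3tokenizer/assets/BAC009S0764W0121.wav",
    "s3tokenizer/assets/BAC009S0764W0122.wav",
]

mels = []
for wav_path in wav_paths:
    audio = s3tokenizer.load_audio(wav_path)
    # Assumption: same mel front end as the v1 README example; model_v2.py
    # specifies n_mels=128, so this call may need the matching mel setting.
    mels.append(s3tokenizer.log_mel_spectrogram(audio))

mels, mels_lens = s3tokenizer.padding(mels)

# quantize() returns FSQ indices in [0, 3**8) and their per-utterance lengths
# (a 25 Hz token rate is implied by the model name).
codes, codes_lens = tokenizer.quantize(mels.cuda(), mels_lens.cuda())

for i in range(len(wav_paths)):
    print(codes[i, :codes_lens[i].item()])
```

The API is unchanged relative to v1; what differs internally, per `model_v2.py`, is the encoder (rotary position embeddings and FSMN-augmented attention blocks instead of sinusoidal embeddings) and the quantizer (an FSQ codebook with 3**8 entries via `project_down`, rather than the Euclidean VQ codebook used by v1).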