[BUGFIX] inference feature extraction same as at the training time #397

Open

wants to merge 1 commit into master
Conversation

taras-sereda

This PR fixes a discrepancy between the feature-extraction logic used at training time and at inference time.

I observed that embeddings extracted with the training pipeline and with the inference pipeline have different numeric values: the cosine similarity between them was around 0.9, when it should be close to 1.0.

With the suggested fix, the extracted embeddings have numerically close values and their cosine similarities are 1.0. The vector values are still not exactly equal, though, and I haven't found out why yet.

I took the compute_fbank logic from the processors used at training time and made sure the same filter-bank hyperparameters are applied at inference.
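For illustration, here is a minimal sketch of the idea (the helper name and its wiring are hypothetical, not the exact diff in this PR): at inference, build the features with the same processor.compute_fbank used by the training pipeline, driven by the fbank_args stored in the model's config.yaml.

# Hypothetical helper sketching the idea of this PR: reuse the training-time
# processor.compute_fbank at inference, with fbank_args taken from config.yaml.
import torchaudio
import wespeaker.dataset.processor as processor

def extract_fbank_like_training(wav_path, fbank_args):
    wav, sample_rate = torchaudio.load(wav_path)
    # compute_fbank consumes sample dicts like the ones built in the
    # validation snippet below and yields dicts with a "feat" entry.
    sample = {"sample_rate": sample_rate, "wav": wav, "key": None, "label": None}
    feats = [x["feat"] for x in processor.compute_fbank([sample], **fbank_args)]
    return feats[0].unsqueeze(0)  # add a batch dimension for model.model(...)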

The code snippet at the end of this description validates the correctness of the extracted embeddings.

The issue emerged with a fine-tuned ReDimNet model, which also requires passing the correct number of Mel features (72 for ReDimNet). Therefore, I suggest storing the dataset params as class attributes.
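For context, the fbank_args read from config.yaml in the snippet below carry the Mel dimension among the filter-bank hyperparameters; for a fine-tuned ReDimNet that entry has to be 72. The other values shown here are common Kaldi-fbank defaults, used purely for illustration:

# Illustrative contents of configs["dataset_args"]["fbank_args"] for ReDimNet;
# only num_mel_bins=72 is the point here, the rest are typical defaults.
fbank_args = {
    "num_mel_bins": 72,   # ReDimNet expects 72 Mel features
    "frame_shift": 10,
    "frame_length": 25,
    "dither": 0.0,
}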

from pathlib import Path

import torch
import torchaudio
import wespeaker
import wespeaker.dataset.processor as processor
import yaml
from torchmetrics.functional.pairwise import pairwise_cosine_similarity


AUDIO_DIR = Path("/path/to/audio/dir")


model_dir = "/path/to/model/dir"
model = wespeaker.load_model_local(model_dir)

config_path = Path(model_dir) / "config.yaml"
with open(config_path, "r") as fin:
    configs = yaml.load(fin, Loader=yaml.FullLoader)

# Embeddings via the inference API (model.extract_embedding)
wavs = []
embs_1 = []
for audio_path in list(AUDIO_DIR.glob("*.wav")):
    audio_path = str(audio_path)
    wav, sample_rate = torchaudio.load(audio_path)
    wavs.append(
        {"sample_rate": sample_rate, "wav": wav, "key": None, "label": None}
    )
    emb = model.extract_embedding(audio_path)
    embs_1.append(emb)

# Features via the training-time processor, using the fbank hyperparameters from config.yaml
feats = []
for fbank in processor.compute_fbank(wavs, **configs["dataset_args"]["fbank_args"]):
    feats.append(fbank["feat"].unsqueeze(0))

# Embeddings from the training-time features, via a direct forward pass
embs_2 = []
with torch.no_grad():
    for f in feats:
        _, emb = model.model(f)
        embs_2.append(emb)

# Compare mean values, elementwise closeness, and cosine similarity
for e1, e2 in zip(embs_1, embs_2):
    print(
        e1.mean(),
        e2[0].mean(),
        torch.allclose(e1, e2, atol=1e-3),
        pairwise_cosine_similarity(e1.unsqueeze(0), e2),
    )

Sample output:

tensor(0.0049) tensor(0.0048) False tensor([[1.0000]])
tensor(0.0185) tensor(0.0185) False tensor([[1.0000]])
tensor(0.0280) tensor(0.0279) False tensor([[1.0000]])
