From 655039eb4240fa4c58a9477f9f68e270f803e00b Mon Sep 17 00:00:00 2001
From: Hongji Wang <jijijiang77@gmail.com>
Date: Mon, 19 Aug 2024 19:14:47 +0800
Subject: [PATCH] [feature] add a frontend module in wespeaker and support wavlm (#344)

* [feature] add a frontend module in wespeaker and support wavlm

* update .gitignore

* update wavlm configs

* update wespeaker/frontend/__init__.py

* [fix] remove trailing whitespaces

* [fix] fix lint errors

* [fix] fix lint errors

* [fix] fix lint errors

* [fix] fix spelling mistakes

* update run.sh

* update wavlm configs and add run_wavlm.sh

* update README.md
---
 .gitignore                                    |   1 +
 README.md                                     |   1 +
 examples/voxceleb/v2/README.md                |  22 +++
 examples/voxceleb/v2/conf/ecapa_tdnn.yaml     |   1 +
 .../v2/conf/ecapa_tdnn_WavLM_frozen.yaml      |  91 ++++++++++
 .../v2/conf/ecapa_tdnn_WavLM_joint_ft.yaml    |  92 ++++++++++
 .../v2/conf/ecapa_tdnn_WavLM_joint_lmft.yaml  |  92 ++++++++++
 examples/voxceleb/v2/conf/ecapa_tdnn_lm.yaml  |   1 +
 examples/voxceleb/v2/run.sh                   |   8 +-
 examples/voxceleb/v2/run_wavlm.sh             | 157 ++++++++++++++++++
 requirements.txt                              |   5 +-
 tools/extract_embedding.sh                    |   2 +-
 wespeaker/bin/extract.py                      |  37 ++++-
 wespeaker/bin/train.py                        |  35 +++-
 wespeaker/dataset/dataset.py                  |  26 ++-
 wespeaker/dataset/dataset_deprecated.py       |   6 +-
 wespeaker/dataset/dataset_utils.py            |  53 ++++++
 .../dataset_utils_deprecated.py               |   0
 wespeaker/frontend/__init__.py                |  18 ++
 wespeaker/frontend/s3prl.py                   |  93 +++++++++++
 wespeaker/utils/executor.py                   |  49 +++---
 21 files changed, 738 insertions(+), 52 deletions(-)
 create mode 100644 examples/voxceleb/v2/conf/ecapa_tdnn_WavLM_frozen.yaml
 create mode 100644 examples/voxceleb/v2/conf/ecapa_tdnn_WavLM_joint_ft.yaml
 create mode 100644 examples/voxceleb/v2/conf/ecapa_tdnn_WavLM_joint_lmft.yaml
 create mode 100755 examples/voxceleb/v2/run_wavlm.sh
 create mode 100644 wespeaker/dataset/dataset_utils.py
 rename wespeaker/{utils => dataset}/dataset_utils_deprecated.py (100%)
 create mode 100644 wespeaker/frontend/__init__.py
 create mode 100644 wespeaker/frontend/s3prl.py

diff --git a/.gitignore b/.gitignore
index acdfae5f..f1181632 100644
--- a/.gitignore
+++ b/.gitignore
@@ -42,3 +42,4 @@ tensorboard
 *.onnx
 external_tools
 pretrained_models
+s3prl_hub
diff --git a/README.md b/README.md
index 6f707d21..de8659d3 100644
--- a/README.md
+++ b/README.md
@@ -60,6 +60,7 @@ pre-commit install  # for clean and tidy code
 ```
 
 ## 🔥 News
+* 2024.08.18: Support using SSL pre-trained models as the frontend. The [WavLM recipe](https://github.com/wenet-e2e/wespeaker/blob/master/examples/voxceleb/v2/run_wavlm.sh) is also provided, see [#344](https://github.com/wenet-e2e/wespeaker/pull/344).
 * 2024.05.15: Add support for [quality-aware score calibration](https://arxiv.org/pdf/2211.00815), see [#320](https://github.com/wenet-e2e/wespeaker/pull/320).
 * 2024.04.25: Add support for the gemini-dfresnet model, see [#291](https://github.com/wenet-e2e/wespeaker/pull/291).
 * 2024.04.23: Support MNN inference engine in runtime, see [#310](https://github.com/wenet-e2e/wespeaker/pull/310).
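Before the per-file diffs, the shape of the change in brief: when a recipe sets `frontend: "s3prl"`, the SSL model becomes a submodule of the speaker model, and its output size (rather than the number of mel bins) fills in `model_args.feat_dim`. A minimal sketch of that wiring, mirroring the `wespeaker/bin/train.py` hunk further down; the config path and the standalone-script framing are illustrative, not part of the patch:

```python
# Sketch only: condensed from the train.py hunk in this patch.
import yaml

from wespeaker.frontend import frontend_class_dict
from wespeaker.models.speaker_model import get_speaker_model

with open('examples/voxceleb/v2/conf/ecapa_tdnn_WavLM_frozen.yaml') as f:
    configs = yaml.safe_load(f)

dataset_args = configs['dataset_args']
frontend_type = dataset_args.get('frontend', 'fbank')  # "fbank" or "s3prl"
if frontend_type == 's3prl':
    # WavLM (via s3prl) replaces the fbank feature extractor
    frontend = frontend_class_dict[frontend_type](
        **dataset_args['s3prl_args'],
        sample_rate=dataset_args['resample_rate'])
    # feat_dim is -1 in the YAML and is filled in from the frontend
    configs['model_args']['feat_dim'] = frontend.output_size()
    model = get_speaker_model(configs['model'])(**configs['model_args'])
    # registering the frontend as a submodule lets one checkpoint carry both
    model.add_module("frontend", frontend)
else:
    model = get_speaker_model(configs['model'])(**configs['model_args'])
```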
diff --git a/examples/voxceleb/v2/README.md b/examples/voxceleb/v2/README.md
index 72151e01..e4b2af31 100644
--- a/examples/voxceleb/v2/README.md
+++ b/examples/voxceleb/v2/README.md
@@ -70,3 +70,25 @@ The results on ResNet34 (large margin, no asnorm) are:
 |:--------------:|:------------:|:------------:|:------------:|
 | PLDA | 1.207 | 1.350 | 2.528 |
+
+## WavLM results
+
+* Pre-trained frontend: the [WavLM](https://arxiv.org/abs/2110.13900) Large model; multilayer features are used
+* Speaker model: ECAPA_TDNN_GLOB_c512-ASTP-emb192
+* Training strategy: Frozen => Joint ft => Joint lmft
+
+```bash
+bash run_wavlm.sh --stage 3 --stop_stage 9
+```
+
+| Training strategy | AS-Norm | QMF | vox1-O-clean | vox1-E-clean | vox1-H-clean |
+|:------------------|:-------:|:---:|:------------:|:------------:|:------------:|
+| Frozen | × | × | 0.595 | 0.719 | 1.501 |
+| | √ | × | 0.548 | 0.656 | 1.355 |
+| | √ | √ | 0.489 | 0.619 | 1.224 |
+| Frozen => Joint ft | × | × | 0.542 | 0.635 | 1.355 |
+| | √ | × | 0.521 | 0.594 | 1.237 |
+| | √ | √ | 0.494 | 0.576 | 1.205 |
+| Frozen => Joint ft => Joint lmft | × | × | 0.521 | 0.626 | 1.344 |
+| | √ | × | 0.495 | 0.588 | 1.247 |
+| | √ | √ | **0.415** | **0.551** | **1.118** |
diff --git a/examples/voxceleb/v2/conf/ecapa_tdnn.yaml b/examples/voxceleb/v2/conf/ecapa_tdnn.yaml
index 0a64d1d3..4602918f 100644
--- a/examples/voxceleb/v2/conf/ecapa_tdnn.yaml
+++ b/examples/voxceleb/v2/conf/ecapa_tdnn.yaml
@@ -32,6 +32,7 @@ dataset_args:
   speed_perturb: True
   num_frms: 200
   aug_prob: 0.6 # prob to add reverb & noise aug per sample
+  frontend: "fbank" # fbank, s3prl
   fbank_args:
     num_mel_bins: 80
     frame_shift: 10
diff --git a/examples/voxceleb/v2/conf/ecapa_tdnn_WavLM_frozen.yaml b/examples/voxceleb/v2/conf/ecapa_tdnn_WavLM_frozen.yaml
new file mode 100644
index 00000000..826b3d46
--- /dev/null
+++ b/examples/voxceleb/v2/conf/ecapa_tdnn_WavLM_frozen.yaml
@@ -0,0 +1,91 @@
+### train configuration
+
+exp_dir: exp/ECAPA_TDNN_GLOB_c512-ASTP-emb192-WavLM_Large_frozen-num_frms150-aug0.6-spTrue-saFalse-ArcMargin_intertopk_subcenter-SGD-epoch150
+gpus: "[0,1,2,3,4,5,6,7]"
+num_avg: 10
+enable_amp: True # whether to enable automatic mixed precision training
+
+seed: 42
+num_epochs: 150
+save_epoch_interval: 5 # save model every 5 epochs
+log_batch_interval: 100 # log every 100 batches
+
+dataloader_args:
+  batch_size: 256
+  num_workers: 16
+  pin_memory: False
+  prefetch_factor: 16
+  drop_last: True
+
+dataset_args:
+  # the sample number which will be traversed within one epoch, if the value equals to 0,
+  # the utterance number in the dataset will be used as the sample_num_per_epoch.
+  sample_num_per_epoch: 0
+  shuffle: True
+  shuffle_args:
+    shuffle_size: 2500
+  filter: True
+  filter_args:
+    min_num_frames: 50
+    max_num_frames: 400
+  resample_rate: 16000
+  speed_perturb: True
+  num_frms: 150
+  aug_prob: 0.6 # prob to add reverb & noise aug per sample
+  frontend: "s3prl" # fbank, s3prl
+  s3prl_args:
+    upstream_args:
+      name: "wavlm_large"
+    download_dir: ./s3prl_hub
+    multilayer_feature: True
+    layer: -1
+    frozen: True
+    frame_shift: 20
+    frame_length: 20
+  cmvn: True
+  cmvn_args:
+    norm_mean: True
+    norm_var: False
+  spec_aug: False
+  spec_aug_args:
+    num_t_mask: 1
+    num_f_mask: 1
+    max_t: 10
+    max_f: 8
+    prob: 0.6
+
+model: ECAPA_TDNN_GLOB_c512 # ECAPA_TDNN_GLOB_c512, ECAPA_TDNN_GLOB_c1024
+model_init: null
+model_args:
+  feat_dim: -1 # equals to the output_size of the frontend (will be initialized before training)
+  embed_dim: 192
+  pooling_func: "ASTP" # the default pooling_func in ECAPA_TDNN is ASTP
+projection_args:
+  project_type: "arc_margin_intertopk_subcenter" # add_margin, arc_margin, sphere, softmax, arc_margin_intertopk_subcenter
+  scale: 32.0
+  easy_margin: False
+
+margin_scheduler: MarginScheduler
+margin_update:
+  initial_margin: 0.0
+  final_margin: 0.2
+  increase_start_epoch: 20
+  fix_start_epoch: 40
+  update_margin: True
+  increase_type: "exp" # exp, linear
+
+loss: CrossEntropyLoss
+loss_args: {}
+
+optimizer: SGD
+optimizer_args:
+  momentum: 0.9
+  nesterov: True
+  weight_decay: 0.0001
+
+scheduler: ExponentialDecrease
+scheduler_args:
+  initial_lr: 0.1
+  final_lr: 0.00001
+  warm_up_epoch: 6
+  warm_from_zero: True
diff --git a/examples/voxceleb/v2/conf/ecapa_tdnn_WavLM_joint_ft.yaml b/examples/voxceleb/v2/conf/ecapa_tdnn_WavLM_joint_ft.yaml
new file mode 100644
index 00000000..69961309
--- /dev/null
+++ b/examples/voxceleb/v2/conf/ecapa_tdnn_WavLM_joint_ft.yaml
@@ -0,0 +1,92 @@
+### train configuration
+
+exp_dir: exp/ECAPA_TDNN_GLOB_c512-ASTP-emb192-WavLM_Large_joint_ft-num_frms150-aug0.6-spTrue-saFalse-ArcMargin_intertopk_subcenter-SGD-epoch20
+gpus: "[0,1,2,3,4,5,6,7]"
+num_avg: 3
+enable_amp: True # whether to enable automatic mixed precision training
+do_lm: False
+
+seed: 42
+num_epochs: 20
+save_epoch_interval: 1 # save model every epoch
+log_batch_interval: 100 # log every 100 batches
+
+dataloader_args:
+  batch_size: 64
+  num_workers: 8
+  pin_memory: False
+  prefetch_factor: 8
+  drop_last: True
+
+dataset_args:
+  # the sample number which will be traversed within one epoch, if the value equals to 0,
+  # the utterance number in the dataset will be used as the sample_num_per_epoch.
+  sample_num_per_epoch: 0
+  shuffle: True
+  shuffle_args:
+    shuffle_size: 2500
+  filter: True
+  filter_args:
+    min_num_frames: 50
+    max_num_frames: 400
+  resample_rate: 16000
+  speed_perturb: True
+  num_frms: 150
+  aug_prob: 0.6 # prob to add reverb & noise aug per sample
+  frontend: "s3prl" # fbank, s3prl
+  s3prl_args:
+    upstream_args:
+      name: "wavlm_large"
+    download_dir: ./s3prl_hub
+    multilayer_feature: True
+    layer: -1
+    frozen: False
+    frame_shift: 20
+    frame_length: 20
+  cmvn: True
+  cmvn_args:
+    norm_mean: True
+    norm_var: False
+  spec_aug: False
+  spec_aug_args:
+    num_t_mask: 1
+    num_f_mask: 1
+    max_t: 10
+    max_f: 8
+    prob: 0.6
+
+model: ECAPA_TDNN_GLOB_c512 # ECAPA_TDNN_GLOB_c512, ECAPA_TDNN_GLOB_c1024
+model_init: null
+model_args:
+  feat_dim: -1 # equals to the output_size of the frontend (will be initialized before training)
+  embed_dim: 192
+  pooling_func: "ASTP" # the default pooling_func in ECAPA_TDNN is ASTP
+projection_args:
+  project_type: "arc_margin_intertopk_subcenter" # add_margin, arc_margin, sphere, softmax, arc_margin_intertopk_subcenter
+  scale: 32.0
+  easy_margin: False
+
+margin_scheduler: MarginScheduler
+margin_update:
+  initial_margin: 0.2
+  final_margin: 0.2
+  increase_start_epoch: 1
+  fix_start_epoch: 1
+  update_margin: True
+  increase_type: "exp" # exp, linear
+
+loss: CrossEntropyLoss
+loss_args: {}
+
+optimizer: SGD
+optimizer_args:
+  momentum: 0.9
+  nesterov: True
+  weight_decay: 0.0001
+
+scheduler: ExponentialDecrease
+scheduler_args:
+  initial_lr: 0.001
+  final_lr: 0.00025
+  warm_up_epoch: 1
+  warm_from_zero: True
diff --git a/examples/voxceleb/v2/conf/ecapa_tdnn_WavLM_joint_lmft.yaml b/examples/voxceleb/v2/conf/ecapa_tdnn_WavLM_joint_lmft.yaml
new file mode 100644
index 00000000..8fd1b854
--- /dev/null
+++ b/examples/voxceleb/v2/conf/ecapa_tdnn_WavLM_joint_lmft.yaml
@@ -0,0 +1,92 @@
+### train configuration
+
+exp_dir: exp/ECAPA_TDNN_GLOB_c512-ASTP-emb192-WavLM_Large_joint_lmft-num_frms300-aug0.6-spTrue-saFalse-ArcMargin_intertopk_subcenter-SGD-epoch10
+gpus: "[0,1,2,3,4,5,6,7]"
+num_avg: 1
+enable_amp: True # whether to enable automatic mixed precision training
+do_lm: True
+
+seed: 42
+num_epochs: 10
+save_epoch_interval: 1 # save model every epoch
+log_batch_interval: 100 # log every 100 batches
+
+dataloader_args:
+  batch_size: 32
+  num_workers: 8
+  pin_memory: False
+  prefetch_factor: 8
+  drop_last: True
+
+dataset_args:
+  # the sample number which will be traversed within one epoch, if the value equals to 0,
+  # the utterance number in the dataset will be used as the sample_num_per_epoch.
+  sample_num_per_epoch: 0
+  shuffle: True
+  shuffle_args:
+    shuffle_size: 2500
+  filter: True
+  filter_args:
+    min_num_frames: 50
+    max_num_frames: 400
+  resample_rate: 16000
+  speed_perturb: True
+  num_frms: 300
+  aug_prob: 0.6 # prob to add reverb & noise aug per sample
+  frontend: "s3prl" # fbank, s3prl
+  s3prl_args:
+    upstream_args:
+      name: "wavlm_large"
+    download_dir: ./s3prl_hub
+    multilayer_feature: True
+    layer: -1
+    frozen: False
+    frame_shift: 20
+    frame_length: 20
+  cmvn: True
+  cmvn_args:
+    norm_mean: True
+    norm_var: False
+  spec_aug: False
+  spec_aug_args:
+    num_t_mask: 1
+    num_f_mask: 1
+    max_t: 10
+    max_f: 8
+    prob: 0.6
+
+model: ECAPA_TDNN_GLOB_c512 # ECAPA_TDNN_GLOB_c512, ECAPA_TDNN_GLOB_c1024
+model_init: null
+model_args:
+  feat_dim: -1 # equals to the output_size of the frontend (will be initialized before training)
+  embed_dim: 192
+  pooling_func: "ASTP" # the default pooling_func in ECAPA_TDNN is ASTP
+projection_args:
+  project_type: "arc_margin_intertopk_subcenter" # add_margin, arc_margin, sphere, softmax, arc_margin_intertopk_subcenter
+  scale: 32.0
+  easy_margin: False
+
+margin_scheduler: MarginScheduler
+margin_update:
+  initial_margin: 0.5
+  final_margin: 0.5
+  increase_start_epoch: 1
+  fix_start_epoch: 1
+  update_margin: True
+  increase_type: "exp" # exp, linear
+
+loss: CrossEntropyLoss
+loss_args: {}
+
+optimizer: SGD
+optimizer_args:
+  momentum: 0.9
+  nesterov: True
+  weight_decay: 0.0001
+
+scheduler: ExponentialDecrease
+scheduler_args:
+  initial_lr: 0.0001
+  final_lr: 0.000025
+  warm_up_epoch: 1
+  warm_from_zero: True
diff --git a/examples/voxceleb/v2/conf/ecapa_tdnn_lm.yaml b/examples/voxceleb/v2/conf/ecapa_tdnn_lm.yaml
index e8e243a8..04178cc9 100644
--- a/examples/voxceleb/v2/conf/ecapa_tdnn_lm.yaml
+++ b/examples/voxceleb/v2/conf/ecapa_tdnn_lm.yaml
@@ -38,6 +38,7 @@ dataset_args:
   speed_perturb: True
   num_frms: 600
   aug_prob: 0.6 # prob to add reverb & noise aug per sample
+  frontend: "fbank" # fbank, s3prl
   fbank_args:
     num_mel_bins: 80
     frame_shift: 10
diff --git a/examples/voxceleb/v2/run.sh b/examples/voxceleb/v2/run.sh
index 734bb8d4..955755a8 100755
--- a/examples/voxceleb/v2/run.sh
+++ b/examples/voxceleb/v2/run.sh
@@ -11,8 +11,8 @@ stop_stage=-1
 data=data
 data_type="shard" # shard/raw
 
-config=conf/campplus.yaml
-exp_dir=exp/CAMPPlus-TSTP-emb512-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150
+config=conf/resnet.yaml
+exp_dir=exp/ResNet34-TSTP-emb256-fbank80-num_frms200-aug0.6-spTrue-saFalse-ArcMargin-SGD-epoch150
 gpus="[0,1]"
 num_avg=10
 checkpoint=
@@ -22,7 +22,7 @@ score_norm_method="asnorm" # asnorm/snorm
 top_n=300
 
 # setup for large margin fine-tuning
-lm_config=conf/campplus_lm.yaml
+lm_config=conf/resnet_lm.yaml
 
 . tools/parse_options.sh || exit 1
 
@@ -55,7 +55,7 @@ fi
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
   echo "Start training ..."
   num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
-  torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
+  torchrun --master_addr=localhost --master_port=29401 --nnodes=1 --nproc_per_node=$num_gpus \
     wespeaker/bin/train.py --config $config \
       --exp_dir ${exp_dir} \
       --gpus $gpus \
diff --git a/examples/voxceleb/v2/run_wavlm.sh b/examples/voxceleb/v2/run_wavlm.sh
new file mode 100755
index 00000000..9494ea16
--- /dev/null
+++ b/examples/voxceleb/v2/run_wavlm.sh
@@ -0,0 +1,157 @@
+#!/bin/bash
+
+# Copyright 2024 Hongji Wang (jijijiang77@gmail.com)
+
+. ./path.sh || exit 1
+
+stage=-1
+stop_stage=-1
+
+data=data
+data_type="shard" # shard/raw
+
+config=conf/ecapa_tdnn_WavLM_frozen.yaml
+exp_dir=exp/ECAPA_TDNN_GLOB_c512-ASTP-emb192-WavLM_large_frozen_num_frms150-aug0.6-spTrue-saFalse-ArcMargin_intertopk_subcenter-SGD-epoch150
+gpus="[0,1,2,3]" #,4,5,6,7]"
+num_avg=10
+checkpoint=
+
+trials="vox1_O_cleaned.kaldi vox1_E_cleaned.kaldi vox1_H_cleaned.kaldi"
+score_norm_method="asnorm" # asnorm/snorm
+top_n=300
+
+# setup for joint ft and lmft
+joint_ft_config=conf/ecapa_tdnn_WavLM_joint_ft.yaml
+joint_ft_exp_dir=exp/ECAPA_TDNN_GLOB_c512-ASTP-emb192-WavLM_Large_joint_ft-num_frms150-aug0.6-spTrue-saFalse-ArcMargin_intertopk_subcenter-SGD-epoch20
+joint_lmft_config=conf/ecapa_tdnn_WavLM_joint_lmft.yaml
+joint_lmft_exp_dir=exp/ECAPA_TDNN_GLOB_c512-ASTP-emb192-WavLM_Large_joint_lmft-num_frms300-aug0.6-spTrue-saFalse-ArcMargin_intertopk_subcenter-SGD-epoch10
+
+. tools/parse_options.sh || exit 1
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+  echo "Prepare datasets ..."
+  ./local/prepare_data.sh --stage 2 --stop_stage 4 --data ${data}
+fi
+
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+  echo "Convert train and test data to ${data_type}..."
+  for dset in vox2_dev vox1; do
+    if [ $data_type == "shard" ]; then
+      python tools/make_shard_list.py --num_utts_per_shard 1000 \
+          --num_threads 16 \
+          --prefix shards \
+          --shuffle \
+          ${data}/$dset/wav.scp ${data}/$dset/utt2spk \
+          ${data}/$dset/shards ${data}/$dset/shard.list
+    else
+      python tools/make_raw_list.py ${data}/$dset/wav.scp \
+          ${data}/$dset/utt2spk ${data}/$dset/raw.list
+    fi
+  done
+  # Convert all musan data to LMDB
+  python tools/make_lmdb.py ${data}/musan/wav.scp ${data}/musan/lmdb
+  # Convert all rirs data to LMDB
+  python tools/make_lmdb.py ${data}/rirs/wav.scp ${data}/rirs/lmdb
+fi
+
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+  echo "Start training ..."
+  num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
+  torchrun --master_addr=localhost --master_port=29401 --nnodes=1 --nproc_per_node=$num_gpus \
+    wespeaker/bin/train.py --config $config \
+      --exp_dir ${exp_dir} \
+      --gpus $gpus \
+      --num_avg ${num_avg} \
+      --data_type "${data_type}" \
+      --train_data ${data}/vox2_dev/${data_type}.list \
+      --train_label ${data}/vox2_dev/utt2spk \
+      --reverb_data ${data}/rirs/lmdb \
+      --noise_data ${data}/musan/lmdb \
+      ${checkpoint:+--checkpoint $checkpoint}
+fi
+
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+  echo "Do model average ..."
+  avg_model=$exp_dir/models/avg_model.pt
+  python wespeaker/bin/average_model.py \
+    --dst_model $avg_model \
+    --src_path $exp_dir/models \
+    --num ${num_avg}
+
+  echo "Extract embeddings ..."
+  local/extract_vox.sh \
+    --exp_dir $exp_dir --model_path $avg_model \
+    --nj 8 --gpus $gpus --data_type $data_type --data ${data}
+fi
+
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+  echo "Score ..."
+  local/score.sh \
+    --stage 1 --stop-stage 2 \
+    --data ${data} \
+    --exp_dir $exp_dir \
+    --trials "$trials"
+fi
+
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+  echo "Score norm ..."
+  local/score_norm.sh \
+    --stage 1 --stop-stage 3 \
+    --score_norm_method $score_norm_method \
+    --cohort_set vox2_dev \
+    --top_n $top_n \
+    --data ${data} \
+    --exp_dir $exp_dir \
+    --trials "$trials"
+fi
+
+if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
+  echo "Score calibration ..."
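+  # QMF: learns the quality-aware score calibration model behind the
+  # "QMF" column of the results table in examples/voxceleb/v2/README.md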
+  local/score_calibration.sh \
+    --stage 1 --stop-stage 5 \
+    --score_norm_method $score_norm_method \
+    --calibration_trial "vox2_cali.kaldi" \
+    --cohort_set vox2_dev \
+    --top_n $top_n \
+    --data ${data} \
+    --exp_dir $exp_dir \
+    --trials "$trials"
+fi
+
+if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
+  echo "Joint fine-tuning ..."
+  mkdir -p ${joint_ft_exp_dir}/models
+  # Use the average frozen model to initialize the joint-ft training
+  cp ${exp_dir}/models/avg_model.pt ${joint_ft_exp_dir}/models/model_0.pt
+  bash run_wavlm.sh --stage 3 --stop_stage 7 \
+    --data ${data} \
+    --data_type ${data_type} \
+    --config ${joint_ft_config} \
+    --exp_dir ${joint_ft_exp_dir} \
+    --gpus $gpus \
+    --num_avg 3 \
+    --checkpoint ${joint_ft_exp_dir}/models/model_0.pt \
+    --trials "$trials" \
+    --score_norm_method ${score_norm_method} \
+    --top_n ${top_n}
+fi
+
+if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
+  echo "Joint LM fine-tuning ..."
+  [ ! -f ${joint_ft_exp_dir}/models/avg_model.pt ] &&
+    echo "Please do joint fine-tuning first" && exit 1
+  mkdir -p ${joint_lmft_exp_dir}/models
+  # Use the average joint_ft model to initialize the joint_lmft training
+  cp ${joint_ft_exp_dir}/models/avg_model.pt ${joint_lmft_exp_dir}/models/model_0.pt
+  bash run_wavlm.sh --stage 3 --stop_stage 7 \
+    --data ${data} \
+    --data_type ${data_type} \
+    --config ${joint_lmft_config} \
+    --exp_dir ${joint_lmft_exp_dir} \
+    --gpus $gpus \
+    --num_avg 1 \
+    --checkpoint ${joint_lmft_exp_dir}/models/model_0.pt \
+    --trials "$trials" \
+    --score_norm_method ${score_norm_method} \
+    --top_n ${top_n}
+fi
diff --git a/requirements.txt b/requirements.txt
index 535dc650..7850da2e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,7 @@ scipy==1.10.0
 tableprint==0.9.1
 torchnet==0.0.4
 tqdm==4.66.3
-scikit-learn==1.5.0
+scikit-learn
 matplotlib==3.5.1
 flake8==3.8.2
 flake8-bugbear
@@ -19,7 +19,8 @@ pycodestyle==2.6.0
 pyflakes==2.2.0
 lmdb==1.3.0
 onnxruntime
-soundfile==0.10.3.post1
+soundfile
 pypeln==0.4.9
 silero-vad
 pre-commit==3.5.0
+s3prl
diff --git a/tools/extract_embedding.sh b/tools/extract_embedding.sh
index a12704c5..1f3327d6 100755
--- a/tools/extract_embedding.sh
+++ b/tools/extract_embedding.sh
@@ -48,7 +48,7 @@ for suffix in $(seq 0 $(($nj - 1))); do
   suffix=$(printf '%03d' $suffix)
   data_list_subfile=${log_dir}/split_${suffix}
   embed_ark=${embed_dir}/xvector_${suffix}.ark
-  CUDA_VISIBLE_DEVICES=${gpus[$idx]} python wespeaker/bin/extract.py \
+  CUDA_VISIBLE_DEVICES=${gpus[$idx]} python -u wespeaker/bin/extract.py \
    --config ${exp_dir}/config.yaml \
    --model_path ${model_path} \
    --data_type ${data_type} \
diff --git a/wespeaker/bin/extract.py b/wespeaker/bin/extract.py
index 1f07ef5d..dfa5fdac 100644
--- a/wespeaker/bin/extract.py
+++ b/wespeaker/bin/extract.py
@@ -23,6 +23,8 @@ from tqdm import tqdm
 
 from wespeaker.dataset.dataset import Dataset
+from wespeaker.dataset.dataset_utils import apply_cmvn, spec_aug
+from wespeaker.frontend import *
 from wespeaker.models.speaker_model import get_speaker_model
 from wespeaker.utils.checkpoint import load_checkpoint
 from wespeaker.utils.utils import parse_config_or_kwargs, validate_path
@@ -41,18 +43,27 @@ def extract(config='conf/config.yaml', **kwargs):
     # auto-tuner to False
     torch.backends.cudnn.benchmark = False
 
+    test_conf = copy.deepcopy(configs['dataset_args'])
+    # model: frontend (optional) => speaker model
     model = get_speaker_model(configs['model'])(**configs['model_args'])
+    frontend_type = test_conf.get('frontend', 'fbank')
+    if frontend_type == 's3prl':
+        frontend_args = frontend_type + "_args"
+        print('Initializing frontend model (this could take some time) ...')
+        frontend = frontend_class_dict[frontend_type](
+            **test_conf[frontend_args], sample_rate=test_conf['resample_rate'])
+        model.add_module("frontend", frontend)
+    print('Loading checkpoint ...')
     load_checkpoint(model, model_path)
+    print('Finished !!! Start extracting ...')
     device = torch.device("cuda")
     model.to(device).eval()
 
     # test_configs
-    test_conf = copy.deepcopy(configs['dataset_args'])
+    # test_conf = copy.deepcopy(configs['dataset_args'])
     test_conf['speed_perturb'] = False
     if 'fbank_args' in test_conf:
         test_conf['fbank_args']['dither'] = 0.0
-    elif 'mfcc_args' in test_conf:
-        test_conf['mfcc_args']['dither'] = 0.0
     test_conf['spec_aug'] = False
     test_conf['shuffle'] = False
     test_conf['aug_prob'] = configs.get('aug_prob', 0.0)
@@ -81,8 +92,24 @@ def extract(config='conf/config.yaml', **kwargs):
                     embed_scp) as writer:
         for _, batch in tqdm(enumerate(dataloader)):
             utts = batch['key']
-            features = batch['feat']
-            features = features.float().to(device)  # (B,T,F)
+            if frontend_type == 'fbank':
+                features = batch['feat']
+                features = features.float().to(device)  # (B,T,F)
+            else:  # 's3prl'
+                wavs = batch['wav']  # (B,1,W)
+                wavs = wavs.squeeze(1).float().to(device)  # (B,W)
+                wavs_len = torch.LongTensor([wavs.shape[1]]).repeat(
+                    wavs.shape[0]).to(device)  # (B)
+                features, _ = model.frontend(wavs, wavs_len)
+
+            # apply cmvn
+            if test_conf.get('cmvn', True):
+                features = apply_cmvn(features,
+                                      **test_conf.get('cmvn_args', {}))
+            # spec augmentation
+            if test_conf.get('spec_aug', False):
+                features = spec_aug(features, **test_conf['spec_aug_args'])
+
             # Forward through model
             outputs = model(features)  # embed or (embed_a, embed_b)
             embeds = outputs[-1] if isinstance(outputs, tuple) else outputs
diff --git a/wespeaker/bin/train.py b/wespeaker/bin/train.py
index 46041761..3c534353 100644
--- a/wespeaker/bin/train.py
+++ b/wespeaker/bin/train.py
@@ -26,6 +26,7 @@ import wespeaker.utils.schedulers as schedulers
 
 from wespeaker.dataset.dataset import Dataset
+from wespeaker.frontend import *
 from wespeaker.models.projections import get_projection
 from wespeaker.models.speaker_model import get_speaker_model
 from wespeaker.utils.checkpoint import load_checkpoint, save_checkpoint
@@ -104,12 +105,26 @@ def train(config='conf/config.yaml', **kwargs):
     logger.info("train dataloaders created")
     logger.info('epoch iteration number: {}'.format(epoch_iter))
 
-    # model
+    # model: frontend (optional) => speaker model => projection layer
     logger.info("<== Model ==>")
-    model = get_speaker_model(configs['model'])(**configs['model_args'])
-    num_params = sum(param.numel() for param in model.parameters())
+    # frontend: fbank or s3prl
+    frontend_type = configs['dataset_args'].get('frontend', 'fbank')
+    if frontend_type == 's3prl':
+        frontend_args = frontend_type + "_args"
+        frontend = frontend_class_dict[frontend_type](
+            **configs['dataset_args'][frontend_args],
+            sample_rate=configs['dataset_args']['resample_rate'])
+        # speaker model
+        configs['model_args']['feat_dim'] = frontend.output_size()
+        model = get_speaker_model(configs['model'])(**configs['model_args'])
+        model.add_module("frontend", frontend)
+    else:  # == 'fbank'
+        # speaker model
+        model = get_speaker_model(configs['model'])(**configs['model_args'])
     if rank == 0:
+        num_params = sum(param.numel() for param in model.parameters())
         logger.info('speaker_model size: {}'.format(num_params))
+    # For model_init, only frontend and speaker model are needed !!!
     if configs['model_init'] is not None:
         logger.info('Load initial model from {}'.format(configs['model_init']))
         load_checkpoint(model, configs['model_init'])
@@ -137,10 +152,13 @@ def train(config='conf/config.yaml', **kwargs):
         # !!!IMPORTANT!!!
         # Try to export the model by script, if fails, we should refine
         # the code to satisfy the script export requirements
-        script_model = torch.jit.script(model)
-        script_model.save(os.path.join(model_dir, 'init.zip'))
+        if frontend_type == 'fbank':
+            script_model = torch.jit.script(model)
+            script_model.save(os.path.join(model_dir, 'init.zip'))
 
     # If specify checkpoint, load some info from checkpoint.
+    # For checkpoint, frontend, speaker model, and projection layer
+    # are all needed !!!
     if checkpoint is not None:
         load_checkpoint(model, checkpoint)
         start_epoch = int(re.findall(r"(?<=model_)\d*(?=.pt)",
@@ -219,12 +237,11 @@ def train(config='conf/config.yaml', **kwargs):
                   epoch,
                   logger,
                   scaler,
-                  enable_amp=configs['enable_amp'],
-                  log_batch_interval=configs['log_batch_interval'],
-                  device=device)
+                  device=device,
+                  configs=configs)
 
         if rank == 0:
-            if epoch % configs['save_epoch_interval'] == 0 or epoch >= configs[
+            if epoch % configs['save_epoch_interval'] == 0 or epoch > configs[
                     'num_epochs'] - configs['num_avg']:
                 save_checkpoint(
                     model, os.path.join(model_dir,
diff --git a/wespeaker/dataset/dataset.py b/wespeaker/dataset/dataset.py
index d4cf8a9d..4c8f54e8 100644
--- a/wespeaker/dataset/dataset.py
+++ b/wespeaker/dataset/dataset.py
@@ -155,8 +155,12 @@ def Dataset(data_type,
         reverb_lmdb_file: reverb data source lmdb file
         noise_lmdb_file: noise data source lmdb file
         whole_utt: use whole utt or random chunk
+        repeat_dataset: True for training while False for testing
     """
     assert data_type in ['shard', 'raw', 'feat']
+    frontend_type = configs.get('frontend', 'fbank')
+    frontend_args = frontend_type + "_args"
+
     lists = read_lists(data_list_file)
     shuffle = configs.get('shuffle', False)
     # Global shuffle
@@ -174,7 +178,7 @@ def Dataset(data_type,
     filter_conf = configs.get('filter_args', {})
     dataset = Processor(dataset,
                         processor.filter,
-                        frame_shift=configs['fbank_args'].get(
+                        frame_shift=configs[frontend_args].get(
                             'frame_shift', 10),
                         data_type=data_type,
                         **filter_conf)
@@ -205,8 +209,8 @@ def Dataset(data_type,
     if not whole_utt:
         # random chunk
         num_frms = configs.get('num_frms', 200)
-        frame_shift = configs['fbank_args'].get('frame_shift', 10)
-        frame_length = configs['fbank_args'].get('frame_length', 25)
+        frame_shift = configs[frontend_args].get('frame_shift', 10)
+        frame_length = configs[frontend_args].get('frame_length', 25)
         chunk_len = ((num_frms - 1) * frame_shift +
                      frame_length) * resample_rate // 1000
         dataset = Processor(dataset, processor.random_chunk, chunk_len,
@@ -220,9 +224,17 @@ def Dataset(data_type,
                             reverb_data, noise_data, resample_rate, aug_prob)
 
     # compute fbank
-    dataset = Processor(dataset, processor.compute_fbank,
-                        **configs['fbank_args'])
-
+    if frontend_type == 'fbank':
+        dataset = Processor(dataset, processor.compute_fbank,
+                            **configs['fbank_args'])
+
+    # !!!IMPORTANT NOTICE!!!
+    # To support different frontends (including SSL pre-trained models),
+    # we have to move apply_cmvn and spec_aug out of the dataset pipeline,
+    # which runs entirely on CPU.
+    # These two modules are now applied in wespeaker/utils/executor.py (train)
+    # and wespeaker/bin/extract.py (test), where they run on GPU.
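+    # (Hence apply_cmvn and spec_aug in wespeaker/dataset/dataset_utils.py
+    # take batched features of shape (B,T,F) rather than single utterances.)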
+    '''
     # apply cmvn
     dataset = Processor(dataset, processor.apply_cmvn)
 
@@ -231,5 +243,5 @@ def Dataset(data_type,
     if spec_aug_flag:
         dataset = Processor(dataset, processor.spec_aug,
                             **configs['spec_aug_args'])
-
+    '''
     return dataset
diff --git a/wespeaker/dataset/dataset_deprecated.py b/wespeaker/dataset/dataset_deprecated.py
index fda30ddc..9ccecdcf 100644
--- a/wespeaker/dataset/dataset_deprecated.py
+++ b/wespeaker/dataset/dataset_deprecated.py
@@ -23,9 +23,9 @@ import torchaudio.compliance.kaldi as kaldi
 
 from wespeaker.utils.file_utils import read_scp
-from wespeaker.utils.dataset_utils_deprecated import (get_random_chunk,
-                                                      speed_perturb,
-                                                      spec_augmentation)
+from wespeaker.dataset.dataset_utils_deprecated import (get_random_chunk,
+                                                        speed_perturb,
+                                                        spec_augmentation)
 
 
 class FeatList_LableDict_Dataset(Dataset):
diff --git a/wespeaker/dataset/dataset_utils.py b/wespeaker/dataset/dataset_utils.py
new file mode 100644
index 00000000..feda2feb
--- /dev/null
+++ b/wespeaker/dataset/dataset_utils.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2024 Hongji Wang (jijijiang77@gmail.com)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import torch
+
+
+def apply_cmvn(feats, norm_mean=True, norm_var=False):
+    # feats batch: (B,T,F)
+    if norm_mean:
+        feats = feats - torch.mean(feats, dim=1, keepdim=True)
+    if norm_var:
+        feats = feats / torch.sqrt(torch.var(feats, dim=1, keepdim=True) + 1e-7)
+
+    return feats
+
+
+def spec_aug(feats, num_t_mask=1, num_f_mask=1, max_t=10, max_f=8, prob=0.6):
+    # feats batch: (B,T,F)
+    # do spec_aug on all batch samples using the same group of params randomly
+    # TODO (hongji): do spec_aug on each sample separately
+    if random.random() < prob:
+        x = feats
+        assert isinstance(x, torch.Tensor)
+        # y = x.clone().detach()
+        y = x.detach()  # shares storage with x, so masking modifies feats in place
+        _, max_frames, max_freq = y.shape
+        # time mask
+        for i in range(num_t_mask):
+            start = random.randint(0, max_frames - 1)
+            length = random.randint(1, max_t)
+            end = min(max_frames, start + length)
+            y[:, start:end, :] = 0
+        # freq mask
+        for i in range(num_f_mask):
+            start = random.randint(0, max_freq - 1)
+            length = random.randint(1, max_f)
+            end = min(max_freq, start + length)
+            y[:, :, start:end] = 0
+        feats = y
+
+    return feats
diff --git a/wespeaker/utils/dataset_utils_deprecated.py b/wespeaker/dataset/dataset_utils_deprecated.py
similarity index 100%
rename from wespeaker/utils/dataset_utils_deprecated.py
rename to wespeaker/dataset/dataset_utils_deprecated.py
diff --git a/wespeaker/frontend/__init__.py b/wespeaker/frontend/__init__.py
new file mode 100644
index 00000000..9b9fd27b
--- /dev/null
+++ b/wespeaker/frontend/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2024 Hongji Wang (jijijiang77@gmail.com)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .s3prl import S3prlFrontend
+
+frontend_class_dict = {'s3prl': S3prlFrontend}
diff --git a/wespeaker/frontend/s3prl.py b/wespeaker/frontend/s3prl.py
new file mode 100644
index 00000000..d96168d5
--- /dev/null
+++ b/wespeaker/frontend/s3prl.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2024 Hongji Wang (jijijiang77@gmail.com)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import torch
+import torch.nn as nn
+
+import s3prl
+from s3prl.nn import Featurizer, S3PRLUpstream
+
+
+class S3prlFrontend(nn.Module):
+    """Speech Pretrained Representation Frontend."""
+
+    def __init__(self,
+                 upstream_args: dict,
+                 download_dir: str = "./s3prl_hub",
+                 multilayer_feature: bool = True,
+                 layer: int = -1,
+                 frozen: bool = False,
+                 frame_shift: int = 20,
+                 frame_length: int = 20,
+                 sample_rate: int = 16000):
+        super().__init__()
+
+        self.multilayer_feature = multilayer_feature
+        self.layer = layer
+        self.frozen = frozen
+
+        if download_dir is not None:
+            s3prl.util.download.set_dir(download_dir)
+
+        assert upstream_args.get("name",
+                                 None) in S3PRLUpstream.available_names()
+        self.upstream = S3PRLUpstream(
+            upstream_args.get("name"),
+            path_or_url=upstream_args.get("path_or_url", None),
+            normalize=upstream_args.get("normalize", False),
+            extra_conf=upstream_args.get("extra_conf", None),
+        )
+        if getattr(self.upstream.upstream, "model", None):
+            if getattr(self.upstream.upstream.model, "feature_grad_mult",
+                       None) is not None:
+                self.upstream.upstream.model.feature_grad_mult = 1.0
+        self.upstream.eval()
+
+        if layer != -1:
+            layer_selections = [layer]
+            assert not multilayer_feature, \
+                "multilayer_feature must be False if layer is specified"
+        else:
+            layer_selections = None
+        self.featurizer = Featurizer(self.upstream,
+                                     layer_selections=layer_selections)
+
+        assert self.featurizer.downsample_rate == sample_rate * frame_shift // 1000
+
+        if self.frozen:
+            for param in self.upstream.parameters():
+                param.requires_grad_(False)
+        else:
+            for name, param in self.upstream.named_parameters():
+                if "mask_emb" in name:
+                    param.requires_grad_(False)
+
+    def output_size(self):
+        return self.featurizer.output_size
+
+    def forward(self, input: torch.Tensor, input_lengths: torch.LongTensor):
+        with torch.no_grad() if self.frozen else contextlib.nullcontext():
+            feats, feats_lens = self.upstream(input, input_lengths)
+        if self.layer != -1:
+            layer = self.layer
+            feats, feats_lens = feats[layer], feats_lens[layer]
+            return feats, feats_lens
+
+        if self.multilayer_feature:
+            feats, feats_lens = self.featurizer(feats, feats_lens)
+        else:
+            feats, feats_lens = self.featurizer(feats[-1:], feats_lens[-1:])
+
+        return feats, feats_lens
diff --git a/wespeaker/utils/executor.py b/wespeaker/utils/executor.py
index 93e868c3..c9477ca5 100644
--- a/wespeaker/utils/executor.py
+++ b/wespeaker/utils/executor.py
@@ -17,39 +17,46 @@
 import torch
 import torchnet as tnt
 
+from wespeaker.dataset.dataset_utils import apply_cmvn, spec_aug
+
 
-def run_epoch(dataloader,
-              epoch_iter,
-              model,
-              criterion,
-              optimizer,
-              scheduler,
-              margin_scheduler,
-              epoch,
-              logger,
-              scaler,
-              enable_amp,
-              log_batch_interval=100,
-              device=torch.device('cuda')):
+def run_epoch(dataloader, epoch_iter, model, criterion, optimizer, scheduler,
+              margin_scheduler, epoch, logger, scaler, device, configs):
     model.train()
 
     # By default use average pooling
     loss_meter = tnt.meter.AverageValueMeter()
     acc_meter = tnt.meter.ClassErrorMeter(accuracy=True)
 
+    frontend_type = configs['dataset_args'].get('frontend', 'fbank')
+
     for i, batch in enumerate(dataloader):
-        utts = batch['key']
-        targets = batch['label']
-        features = batch['feat']
-
         cur_iter = (epoch - 1) * epoch_iter + i
         scheduler.step(cur_iter)
         margin_scheduler.step(cur_iter)
 
-        features = features.float().to(device)  # (B,T,F)
-        targets = targets.long().to(device)
+        utts = batch['key']
+        targets = batch['label']
+        targets = targets.long().to(device)  # (B)
+        if frontend_type == 'fbank':
+            features = batch['feat']  # (B,T,F)
+            features = features.float().to(device)
+        else:  # 's3prl'
+            wavs = batch['wav']  # (B,1,W)
+            wavs = wavs.squeeze(1).float().to(device)  # (B,W)
+            wavs_len = torch.LongTensor([wavs.shape[1]]).repeat(
+                wavs.shape[0]).to(device)  # (B)
+            with torch.cuda.amp.autocast(enabled=configs['enable_amp']):
+                features, _ = model.module.frontend(wavs, wavs_len)
+
+        with torch.cuda.amp.autocast(enabled=configs['enable_amp']):
+            # apply cmvn
+            if configs['dataset_args'].get('cmvn', True):
+                features = apply_cmvn(
+                    features, **configs['dataset_args'].get('cmvn_args', {}))
+            # spec augmentation
+            if configs['dataset_args'].get('spec_aug', False):
+                features = spec_aug(features,
+                                    **configs['dataset_args']['spec_aug_args'])
 
-        with torch.cuda.amp.autocast(enabled=enable_amp):
+        with torch.cuda.amp.autocast(enabled=configs['enable_amp']):
             outputs = model(features)  # (embed_a,embed_b) in most cases
             embeds = outputs[-1] if isinstance(outputs, tuple) else outputs
             outputs = model.module.projection(embeds, targets)
@@ -70,7 +77,7 @@ def run_epoch(dataloader,
         scaler.update()
 
         # log
-        if (i + 1) % log_batch_interval == 0:
+        if (i + 1) % configs['log_batch_interval'] == 0:
             logger.info(
                 tp.row((epoch, i + 1, scheduler.get_lr(),
                         margin_scheduler.get_margin()) +
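For readers who want to try the new frontend in isolation, a minimal standalone sketch of `S3prlFrontend` follows. Assumptions: `s3prl` from requirements.txt is installed, and the `wavlm_large` checkpoint can be fetched into `./s3prl_hub`; the tensor shapes follow the comments in the executor.py hunk above. This is an illustration, not code from the patch:

```python
# Standalone usage sketch of wespeaker.frontend.S3prlFrontend.
import torch

from wespeaker.frontend import S3prlFrontend

frontend = S3prlFrontend(
    upstream_args={"name": "wavlm_large"},
    download_dir="./s3prl_hub",
    multilayer_feature=True,  # weighted sum over all WavLM layers
    layer=-1,                 # -1 means "do not pick a single layer"
    frozen=True,              # frozen stage: forward runs under torch.no_grad()
    frame_shift=20,           # WavLM emits one frame per 20 ms
    frame_length=20,
    sample_rate=16000)

wavs = torch.randn(4, 3 * 16000)  # (B, W): four 3-second waveforms
wavs_len = torch.full((4,), wavs.shape[1], dtype=torch.long)  # (B)
feats, feats_len = frontend(wavs, wavs_len)  # feats: (B, T, F)
print(feats.shape, frontend.output_size())   # F == frontend.output_size()
```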