diff --git a/README.md b/README.md
index 7132a031..ad85f529 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,19 @@ The full API is described in the documentation page [https://hyperion-ml.readthe
 
 ## Installation Instructions
 
+### If you use the CLSP grid, simply run the following in the root of your cloned repo (as of Aug 15 2022):
+```
+./prepare_egs_paths.sh
+# Then, type /home/janto/usr/local/anaconda3 when "Introduce path to your conda base installation (e.g.:/usr/local/anaconda3):" is prompted
+# type /home/jcho/.conda/envs/hyp_persephone_jj when "Introduce name/prefix_path for your conda environment (e.g.:hyperion)" is prompted
+
+# You may see the two lines below, but it seems okay to ignore them:
+# Hyperion is not installed in env
+# Adding hyperion directory to the PYTHONPATH variable in the recipes.
+
+# Also, with this, you can skip "Prerequisites to run the recipes" below
+```
+
 ### Prerequisites
 
 We use anaconda or miniconda, though you should be able to make it work in other python distributions
diff --git a/egs/voxceleb/dinossl.v1/README.md b/egs/voxceleb/dinossl.v1/README.md
new file mode 100644
index 00000000..5b5b93e5
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/README.md
@@ -0,0 +1,205 @@
+# VoxCeleb V1.1
+
+Recipe for the VoxCeleb Speaker Verification Task
+
+## Differences w.r.t VoxCeleb V1 recipe
+
+In recipe version V1:
+ - We compute speech augmentations and acoustic features offline and dump them to disk.
+ - Augmentation is performed using Kaldi scripts and the wav-reverberate tool.
+ - Babble noise is created on-the-fly when computing features by mixing 3-7 single-speaker files.
+
+In this recipe:
+ - Speech augmentations and acoustic features are always computed on-the-fly;
+   we don't dump any features to disk.
+ - Augmentation is performed using the Hyperion SpeechAugment class.
+ - The behavior of this class is controlled
+   by the configuration file `conf/reverb_noise_aug.yaml`,
+   which mimics the proportions of noise and RIR types, and the SNRs, used in V1 of the recipe.
+ - Babble noise is created offline by mixing 3-10 single-speaker files.
+
+
+## Citing
+
+## Training Data
+
+ - x-Vector network is trained on VoxCeleb2 dev + test with augmentations
+   - MUSAN noise
+   - RIR reverberation
+
+## Test data
+
+ - Test data is VoxCeleb 1
+ - We evaluate 6 conditions:
+   - VoxCeleb-O (Original): original VoxCeleb test set with 40 speakers
+   - VoxCeleb-O-cleaned: VoxCeleb-O cleaned-up of some errors
+   - VoxCeleb-E (Entire): list using all utterances of VoxCeleb1
+   - VoxCeleb-E-cleaned: VoxCeleb-E cleaned-up of some errors
+   - VoxCeleb-H (Hard): list of hard trials between all utterances of VoxCeleb1, restricted to same-gender and same-nationality trials.
+   - VoxCeleb-H-cleaned: VoxCeleb-H cleaned-up of some errors
+
+## Usage
+
+ - Run the run_0*.sh scripts in sequence
+ - By default it will use the Light ResNet (16 base channels)
+ - For better performance, use the full ResNet (64 base channels) config `config_fbank80_stmn_resnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh` as follows:
+```bash
+run_011_train_xvector.sh --config-file config_fbank80_stmn_resnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh
+run_030_extract_xvectors.sh --config-file config_fbank80_stmn_resnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh --use-gpu true
+run_040_eval_be.sh --config-file config_fbank80_stmn_resnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh
+```
+
+ - To train with mixed precision, use the config file `config_fbank80_stmn_lresnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh`
+
+## Recipe Steps:
+
+ - `run_001_prepare_data.sh`
+   - Data preparation script to generate Kaldi style data directories for
+     - VoxCeleb2 train+test
+     - VoxCeleb1 O/E/H eval sets
+
+ - `run_002_compute_evad.sh`
+   - Computes Energy VAD for all datasets
+
+ - `run_003_prepare_noises_rirs.sh`
+   - Prepares MUSAN noises and music to be used by the SpeechAugment class.
+   - Creates Babble noise from MUSAN speech to be used by the SpeechAugment class.
+   - Prepares RIRs by compacting them into HDF5 files, to be used by the SpeechAugment class.
+
+ - `run_010_prepare_xvec_train_data.sh`
+   - Transforms all the audios that we are going to use to train the x-vector into a common format, e.g., .flac.
+   - Removes silence from the audios
+   - Removes utterances shorter than 4 secs and speakers with fewer than 8 utterances.
+   - Creates training and validation lists for x-vector training
+
+ - `run_011_train_xvector.sh`
+   - Trains the x-vector network
+
+ - `run_030_extract_xvectors.sh`
+   - Extracts x-vectors for VoxCeleb2 or VoxCeleb2+augmentation for PLDA training
+   - Extracts x-vectors for VoxCeleb1 test sets
+
+ - `run_040_eval_be.sh`
+   - Trains PLDA and evaluates PLDA and cosine scoring back-ends
+
+
+## Results
+
+### VoxCeleb 1 Original-Clean trial list
+
+| Config | Model Type | Model Details | Back-end | EER(%) | MinDCF(p=0.05) | MinDCF(p=0.01) |
+| ------ | ---------- | ------------- | -------- | :----: | :------------: | :------------: |
+| config_fbank80_stmn_lresnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh | LResNet34 | ArcFace s=30/m=0.3 | PLDA | 2.00 | 0.129 | 0.216 |
+| | | | Cosine | 2.04 | 0.138 | 0.210 |
+| config_fbank80_stmn_resnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh | ResNet34 | ArcFace s=30/m=0.3 | PLDA | 1.35 | 0.091 | 0.159 |
+| | | | Cosine | 1.22 | 0.082 | 0.129 |
+| config_fbank80_stmn_resnet34_arcs30m0.3_adam_lr0.05_amp_swa.v1.sh | ResNet34 | + SWA | Cosine | 1.19 | 0.074 | 0.124 |
+| config_fbank80_stmn_resnet50_arcs30m0.3_adam_lr0.05_amp.v1.sh | ResNet50 | ArcFace s=30/m=0.3 | PLDA | 1.30 | 0.090 | 0.160 |
+| | | | Cosine | 1.44 | 0.100 | 0.173 |
+| config_fbank80_stmn_tseresnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh | Time-SE-ResNet34 | ArcFace s=30/m=0.3 | PLDA | 1.23 | 0.091 | 0.143 |
+| | | | Cosine | 1.17 | 0.081 | 0.110 |
+| config_fbank80_stmn_effnetb4_v2_arcs30m0.3_adam_lr0.01_amp.v1.sh | EfficientNet-b4 v2 | EfficientNet-b4 with strides=1122121
ArcFace s=30/m=0.3 | 1.37 | 0.104 | 0.179 | +| | | | Cosine | 1.31 | 0.080 | 0.139 | +| config_fbank80_stmn_effnetb7_v2_eina_hln_arcs30m0.3_adam_lr0.01_amp.v1.sh | EfficientNet-b7 v2 | EfficientNet-b7 with strides=1122121
Instance-Norm with affine transform in Encoder
Layer-Norm in head
ArcFace s=30/m=0.3 | 1.29 | 0.088 | 0.129 | +| | | | Cosine | 1.23 | 0.083 | 0.136 | +| config_fbank80_stmn_res2net34w16s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net34 width=16x4 | ArcFace s=30/m=0.3 | PLDA | 1.20 | 0.095 | 0.156 | +| | | | Cosine | 1.29 | 0.089 | 0.146 | +| config_fbank80_stmn_res2net34w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net34 width=26x4 | ArcFace s=30/m=0.3 | PLDA | 1.20 | 0.084 | 0.136 | +| | | | Cosine | 1.18 | 0.078 | 0.115 | +| config_fbank80_stmn_res2net50w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net50 width=26x4 | ArcFace s=30/m=0.3 | PLDA | 1.11 | 0.084 | 0.145 | +| | | | Cosine | 1.12 | 0.073 | 0.131 | +| config_fbank80_stmn_seres2net50w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | SE-Res2Net50 | se-r=16
ArcFace s=30/m=0.3 | PLDA | 1.53 | 0.104 | 0.189 | +| | | | Cosine | 1.31 | 0.084 | 0.132 | +| config_fbank80_stmn_tseres2net50w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Time-SE-Res2Net50 | se-r=256
ArcFace s=30/m=0.3 | PLDA | 0.98 | 0.066 | 0.116 | +| | | | Cosine | 1.12 | 0.071 | 0.103 | +| config_fbank80_stmn_res2net50w13s8_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net50 width=13x8 | ArcFace s=30/m=0.3 | PLDA | 1.05 | 0.077 | 0.123 | +| | | | Cosine | 0.96 | 0.065 | 0.110 | +| config_fbank80_stmn_res2net50w26s8_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net50 width=26x8 | ArcFace s=30/m=0.3 | PLDA | 1.04 | 0.071 | 0.118 | +| | | | Cosine | 0.93 | 0.067 | 0.108 | +| config_fbank80_stmn_res2net50w26s8_arcs30m0.3_adam_lr0.05_amp.v1_swa.sh | Res2Net50 width=26x8 | + SWA | PLDA | 0.90 | 0.067 | 0.118 | +| | | | Cosine | 0.85 | 0.060 | 0.094 | +| config_fbank80_stmn_spinenet49s_arcs30m0.3_adam_lr0.05_amp.v1.sh | SpineNet49S | ArcFace s=30/m=0.3 | PLDA | 1.44 | 0.102 | 0.169 | +| | | | Cosine | 1.29 | 0.084 | 0.140 | +| config_fbank80_stmn_spinenet49_arcs30m0.3_adam_lr0.05_amp.v1.sh | SpineNet49 | ArcFace s=30/m=0.3 | Cosine | 1.12 | 0.071 | 0.116 | +| config_fbank80_stmn_spine2net49_arcs30m0.3_adam_lr0.05_amp.v1.sh | Spine2Net49 | ArcFace s=30/m=0.3 | Cosine | 1.05 | 0.074 | 0.116 | +| config_fbank80_stmn_tsespine2net49_arcs30m0.3_adam_lr0.05_amp.v1.sh | Spine2Net49 | ArcFace s=30/m=0.3 | Cosine | 1.09 | 0.081 | 0.150 | + + +### VoxCeleb 1 Entire-Clean trial list + +| Config | Model Type | Model Details | Back-end | EER(%) | MinDCF(p=0.05) | MinDCF(p=0.01) | +| ------ | ---------- | ------------- | -------- | :----: | :------------: | :------------: | +| config_fbank80_stmn_lresnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh | LResNet34 | ArcFace s=30/m=0.3 | PLDA | 1.86 | 0.124 | 0.210 | +| | | | Cosine | 1.93 | 0.122 | 0.201 | +| config_fbank80_stmn_resnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh | ResNet34 | ArcFace s=30/m=0.3 | PLDA | 1.43 | 0.091 | 0.159 | +| | | | Cosine | 1.24 | 0.080 | 0.136 | +| config_fbank80_stmn_resnet34_arcs30m0.3_adam_lr0.05_amp_swa.v1.sh | ResNet34 | + SWA | Cosine | 1.19 | 0.077 | 0.132 | +| config_fbank80_stmn_resnet50_arcs30m0.3_adam_lr0.05_amp.v1.sh | ResNet50 | ArcFace s=30/m=0.3 | PLDA | 1.27 | 0.084 | 0.150 | +| | | | Cosine | 1.30 | 0.082 | 0.150 | +| config_fbank80_stmn_tseresnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh | Time-SE-ResNet34 | ArcFace s=30/m=0.3 | PLDA | 1.30 | 0.083 | 0.146 | +| | | | Cosine | 1.09 | 0.071 | 0.124 | +| config_fbank80_stmn_effnetb4_v2_arcs30m0.3_adam_lr0.01_amp.v1.sh | EfficientNet-b4 v2 | EfficientNet-b4 with strides=1122121
ArcFace s=30/m=0.3 | 1.45 | 0.097 | 0.165 | +| | | | Cosine | 1.15 | 0.076 | 0.132 | +| config_fbank80_stmn_effnetb7_v2_eina_hln_arcs30m0.3_adam_lr0.01_amp.v1.sh | EfficientNet-b7 v2 | EfficientNet-b7 with strides=1122121
Instance-Norm with affine transform in Encoder
Layer-Norm in head
ArcFace s=30/m=0.3 | 1.47 | 0.094 | 0.165 | +| | | | Cosine | 1.27 | 0.082 | 0.148 | +| config_fbank80_stmn_res2net34w16s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net34 width=16x4 | ArcFace s=30/m=0.3 | PLDA | 1.31 | 0.086 | 0.149 | +| | | | Cosine | 1.22 | 0.079 | 0.134 | +| config_fbank80_stmn_res2net34w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net34 width=26x4 | ArcFace s=30/m=0.3 | PLDA | 1.27 | 0.082 | 0.145 | +| | | | Cosine | 1.16 | 0.074 | 0.130 | +| config_fbank80_stmn_res2net50w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net50 width=26x4 | ArcFace s=30/m=0.3 | PLDA | 1.23 | 0.077 | 0.136 | +| | | | Cosine | 1.11 | 0.071 | 0.125 | +| config_fbank80_stmn_seres2net50w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | SE-Res2Net50 | se-r=16
ArcFace s=30/m=0.3 | PLDA | 1.46 | 0.097 | 0.173 | +| | | | Cosine | 1.24 | 0.080 | 0.140 | +| config_fbank80_stmn_tseres2net50w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Time-SE-Res2Net50 | se-r=256
ArcFace s=30/m=0.3 | PLDA | 1.11 | 0.071 | 0.127 | +| | | | Cosine | 1.05 | 0.067 | 0.117 | +| config_fbank80_stmn_res2net50w13s8_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net50 width=13x8 | ArcFace s=30/m=0.3 | PLDA | 1.23 | 0.078 | 0.134 | +| | | | Cosine | 1.05 | 0.069 | 0.121 | +| config_fbank80_stmn_res2net50w26s8_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net50 width=26x8 | ArcFace s=30/m=0.3 | PLDA | 1.18 | 0.075 | 0.131 | +| | | | Cosine | 0.98 | 0.063 | 0.110 | +| config_fbank80_stmn_res2net50w26s8_arcs30m0.3_adam_lr0.05_amp_swa.v1.sh | Res2Net50 width=26x8 | + SWA | PLDA | 1.17 | 0.072 | 0.123 | +| | | | Cosine | 0.94 | 0.061 | 0.107 | +| config_fbank80_stmn_spinenet49s_arcs30m0.3_adam_lr0.05_amp.v1.sh | SpineNet49S | ArcFace s=30/m=0.3 | PLDA | 1.56 | 0.095 | 0.166 | +| | | | Cosine | 1.27 | 0.079 | 0.142 | +| config_fbank80_stmn_spinenet49_arcs30m0.3_adam_lr0.05_amp.v1.sh | SpineNet49 | ArcFace s=30/m=0.3 | Cosine | 1.19 | 0.077 | 0.137 | +| config_fbank80_stmn_spine2net49_arcs30m0.3_adam_lr0.05_amp.v1.sh | Spine2Net49 | ArcFace s=30/m=0.3 | Cosine | 1.12 | 0.073 | 0.129 | +| config_fbank80_stmn_tsespine2net49_arcs30m0.3_adam_lr0.05_amp.v1.sh | TSE-Spine2Net49 | ArcFace s=30/m=0.3 | Cosine | 1.05 | 0.068 | 0.120 | + + +### VoxCeleb 1 Hard-Clean trial list + +| Config | Model Type | Model Details | Back-end | EER(%) | MinDCF(p=0.05) | MinDCF(p=0.01) | +| ------ | ---------- | ------------- | -------- | :----: | :------------: | :------------: | +| config_fbank80_stmn_lresnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh | LResNet34 | ArcFace s=30/m=0.3 | PLDA | 3.29 | 0.195 | 0.318 | +| | | | Cosine | 3.27 | 0.188 | 0.303 | +| config_fbank80_stmn_resnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh | ResNet34 | ArcFace s=30/m=0.3 | PLDA | 2.66 | 0.160 | 0.258 | +| | | | Cosine | 2.32 | 0.139 | 0.232 | +| config_fbank80_stmn_resnet34_arcs30m0.3_adam_lr0.05_amp_swa.v1.sh | ResNet34 | + SWA | Cosine | 2.19 | 0.133 | 0.215 | +| config_fbank80_stmn_resnet50_arcs30m0.3_adam_lr0.05_amp.v1.sh | ResNet50 | ArcFace s=30/m=0.3 | PLDA | 2.33 | 0.139 | 0.227 | +| | | | Cosine | 2.33 | 0.142 | 0.235 | +| config_fbank80_stmn_tseresnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh | Time-SE-ResNet34 | ArcFace s=30/m=0.3 | PLDA | 2.46 | 0.142 | 0.237 | +| | | | Cosine | 2.14 | 0.126 | 0.203 | +| config_fbank80_stmn_effnetb4_v2_arcs30m0.3_adam_lr0.01_amp.v1.sh | EfficientNet-b4 v2 | EfficientNet-b4 with strides=1122121
ArcFace s=30/m=0.3 | 2.57 | 0.153 | 0.255 | +| | | | Cosine | 2.11 | 0.127 | 0.205 | +| config_fbank80_stmn_effnetb7_v2_eina_hln_arcs30m0.3_adam_lr0.01_amp.v1.sh | EfficientNet-b7 v2 | EfficientNet-b7 with strides=1122121
Instance-Norm with affine transform in Encoder
Layer-Norm in head
ArcFace s=30/m=0.3 | 2.64 | 0.157 | 0.244 | +| | | | Cosine | 2.33 | 0.141 | 0.232 | +| config_fbank80_stmn_res2net34w16s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net34 width=16x4 | ArcFace s=30/m=0.3 | PLDA | 2.42 | 0.144 | 0.245 | +| | | | Cosine | 2.26 | 0.133 | 0.224 +| config_fbank80_stmn_res2net34w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net34 width=26x4 | ArcFace s=30/m=0.3 | PLDA | 2.39 | 0.141 | 0.235 | +| | | | Cosine | 2.17 | 0.128 | 0.215 +| config_fbank80_stmn_res2net50w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net50 width=26x4 | ArcFace s=30/m=0.3 | PLDA | 2.28 | 0.131 | 0.225 | +| | | | Cosine | 2.11 | 0.124 | 0.204 | +| config_fbank80_stmn_seres2net50w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | SE-Res2Net50 | se-r=16
ArcFace s=30/m=0.3 | PLDA | 2.77 | 0.172 | 0.271 | +| | | | Cosine | 2.45 | 0.141 | 0.225 | +| config_fbank80_stmn_tseres2net50w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Time-SE-Res2Net50 | se-r=256
ArcFace s=30/m=0.3 | PLDA | 2.07 | 0.124 | 0.201 | +| | | | Cosine | 1.95 | 0.113 | 0.181 | +| config_fbank80_stmn_res2net50w13s8_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net50 width=13x8 | ArcFace s=30/m=0.3 | PLDA | 2.34 | 0.136 | 0.230 | +| | | | Cosine | 1.99 | 0.119 | 0.196 | +| config_fbank80_stmn_res2net50w26s8_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net50 width=26x8 | ArcFace s=30/m=0.3 | PLDA | 2.18 | 0.127 | 0.211 | +| | | | Cosine | 1.89 | 0.112 | 0.184 | +| config_fbank80_stmn_res2net50w26s8_arcs30m0.3_adam_lr0.05_amp.v1_swa.sh | Res2Net50 width=26x8 | + SWA | PLDA | 2.14 | 0.125 | 0.209 | +| | | | Cosine | 1.84 | 0.110 | 0.186 | +| config_fbank80_stmn_spinenet49s_arcs30m0.3_adam_lr0.05_amp.v1.sh | SpineNet49S | ArcFace s=30/m=0.3 | PLDA | 2.78 | 0.156 | 0.252 | +| | | | Cosine | 2.26 | 0.134 | 0.214 | +| config_fbank80_stmn_spinenet49_arcs30m0.3_adam_lr0.05_amp.v1.sh | SpineNet49 | ArcFace s=30/m=0.3 | Cosine | 2.24 | 0.134 | 0.221 | +| config_fbank80_stmn_spine2net49_arcs30m0.3_adam_lr0.05_amp.v1.sh | Spine2Net49 | ArcFace s=30/m=0.3 | Cosine | 2.20 | 0.132 | 0.219 | +| config_fbank80_stmn_tsespine2net49_arcs30m0.3_adam_lr0.05_amp.v1.sh | Spine2Net49 | ArcFace s=30/m=0.3 | Cosine | 2.02 | 0.123 | 0.203 | diff --git a/egs/voxceleb/dinossl.v1/cmd.sh b/egs/voxceleb/dinossl.v1/cmd.sh new file mode 100755 index 00000000..040f458b --- /dev/null +++ b/egs/voxceleb/dinossl.v1/cmd.sh @@ -0,0 +1,28 @@ +# you can change cmd.sh depending on what type of queue you are using. +# If you have no queueing system and want to run on a local machine, you +# can change all instances 'queue.pl' to run.pl (but be careful and run +# commands one by one: most recipes will exhaust the memory on your +# machine). queue.pl works with GridEngine (qsub). slurm.pl works +# with slurm. Different queues are configured differently, with different +# queue names and different ways of specifying things like memory; +# to account for these differences you can create and edit the file +# conf/queue.conf to match your queue's configuration. Search for +# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information, +# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl. 
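+# A minimal local-machine fallback (a sketch only, not what this recipe ships with):
+# if you have no queueing system, all three commands can simply point at run.pl,
+# which executes jobs on the local machine:
+#
+#   export train_cmd="run.pl"
+#   export cuda_cmd="run.pl"
+#   export cuda_eval_cmd="run.pl"
+#
+# As the comment above warns, run stages one at a time in that case to avoid
+# exhausting memory.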
+ +if [ "$(hostname -d)" == "cm.gemini" ];then + #export train_cmd="queue.pl --config conf/coe_gpu_short.conf --mem 4G" + export train_cmd="queue.pl --config conf/coe_gpu_long.conf --mem 4G" + export cuda_cmd="queue.pl --config conf/coe_gpu_long.conf --mem 20G" + #export cuda_cmd="queue.pl --config conf/coe_gpu_v100.conf --mem 20G" + export cuda_cmd="queue.pl --config conf/coe_gpu_rtx.conf --mem 40G" + export cuda_eval_cmd="queue.pl --config conf/coe_gpu_short.conf --mem 4G" + # export cuda_eval_cmd="queue.pl --config conf/coe_gpu_long.conf --mem 4G" +else + export train_cmd="queue.pl --mem 4G -l hostname=\"[bc][01]*\" -V" + export cuda_cmd="queue.pl --mem 20G -l hostname=\"c[01]*\" -V" + export cuda_eval_cmd="$train_cmd" +fi + + + diff --git a/egs/voxceleb/dinossl.v1/conf/clsp.conf b/egs/voxceleb/dinossl.v1/conf/clsp.conf new file mode 100644 index 00000000..4ed38246 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/clsp.conf @@ -0,0 +1,11 @@ + +# Default configuration +command qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64* -V +option mem=* -l mem_free=$0,ram_free=$0 +option mem=0 # Do not add anything to qsub_opts +option num_threads=* -pe smp $0 +option num_threads=1 # Do not add anything to qsub_opts +option max_jobs_run=* -tc $0 +default gpu=0 +option gpu=0 -l 'hostname=b[1]*|c0[123456789]*|c1[134679]*|c2[1357]*' +option gpu=* -l 'hostname=c0[123456789]*|c1[1345679]*|c2[12357]*,gpu=$0' diff --git a/egs/voxceleb/dinossl.v1/conf/coe_gpu_bigmem.conf b/egs/voxceleb/dinossl.v1/conf/coe_gpu_bigmem.conf new file mode 100644 index 00000000..a7a2ce40 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/coe_gpu_bigmem.conf @@ -0,0 +1,11 @@ + +# Default configuration +command qsub -v PATH -cwd -S /bin/bash -j y -sync y -l arch=*64* -V +option mem=* -l mem_free=$0 +option mem=0 # Do not add anything to qsub_opts +option num_threads=* -l num_proc=$0 +option num_threads=1 # Do not add anything to qsub_opts +option max_jobs_run=* -tc $0 +default gpu=0 +option gpu=0 -q all.q -l h_rt=100:00:00 -l hostname=r[2-7]* +option gpu=* -l gpu=$0,h_rt=500:00:00 -q gpu.q -l hostname=r[237]n[01][0123456789]* diff --git a/egs/voxceleb/dinossl.v1/conf/coe_gpu_long.conf b/egs/voxceleb/dinossl.v1/conf/coe_gpu_long.conf new file mode 100644 index 00000000..b31c167c --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/coe_gpu_long.conf @@ -0,0 +1,13 @@ + +# Default configuration +command qsub -v PATH -cwd -S /bin/bash -j y -sync y -l arch=*64* -V +option mem=* -l mem_free=$0 +option mem=0 # Do not add anything to qsub_opts +option num_threads=* -l num_proc=$0 +option num_threads=1 # Do not add anything to qsub_opts +option max_jobs_run=* -tc $0 +default gpu=0 +option gpu=0 -q all.q -l h_rt=100:00:00 -l hostname=r[1-9]* +option gpu=* -l gpu=$0,h_rt=500:00:00 -q gpu.q -l hostname=r[1-9]* + + diff --git a/egs/voxceleb/dinossl.v1/conf/coe_gpu_rtx.conf b/egs/voxceleb/dinossl.v1/conf/coe_gpu_rtx.conf new file mode 100644 index 00000000..ba6d9e56 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/coe_gpu_rtx.conf @@ -0,0 +1,11 @@ + +# Default configuration +command qsub -v PATH -cwd -S /bin/bash -j y -sync y -l arch=*64* -V +option mem=* -l mem_free=$0 +option mem=0 # Do not add anything to qsub_opts +option num_threads=* -l num_proc=$0 +option num_threads=1 # Do not add anything to qsub_opts +option max_jobs_run=* -tc $0 +default gpu=0 +option gpu=0 -q all.q -l h_rt=100:00:00 +option gpu=* -l gpu=$0,h_rt=500:00:00 -q gpu.q@@rtx diff --git a/egs/voxceleb/dinossl.v1/conf/coe_gpu_short.conf 
b/egs/voxceleb/dinossl.v1/conf/coe_gpu_short.conf new file mode 100644 index 00000000..81de5cb7 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/coe_gpu_short.conf @@ -0,0 +1,11 @@ + +# Default configuration +command qsub -v PATH -cwd -S /bin/bash -j y -sync y -l arch=*64* -V +option mem=* -l mem_free=$0 +option mem=0 # Do not add anything to qsub_opts +option num_threads=* -l num_proc=$0 +option num_threads=1 # Do not add anything to qsub_opts +option max_jobs_run=* -tc $0 +default gpu=0 +option gpu=0 -q all.q -l h_rt=100:00:00 -l hostname=r[1-9]* +option gpu=* -l gpu=$0,h_rt=00:59:00 -q gpu_short.q -l hostname=r[17]* diff --git a/egs/voxceleb/dinossl.v1/conf/coe_gpu_v100.conf b/egs/voxceleb/dinossl.v1/conf/coe_gpu_v100.conf new file mode 100644 index 00000000..69326b82 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/coe_gpu_v100.conf @@ -0,0 +1,11 @@ + +# Default configuration +command qsub -v PATH -cwd -S /bin/bash -j y -sync y -l arch=*64* -V +option mem=* -l mem_free=$0 +option mem=0 # Do not add anything to qsub_opts +option num_threads=* -l num_proc=$0 +option num_threads=1 # Do not add anything to qsub_opts +option max_jobs_run=* -tc $0 +default gpu=0 +option gpu=0 -q all.q -l h_rt=100:00:00 +option gpu=* -l gpu=$0,h_rt=500:00:00 -q gpu.q@@v100 diff --git a/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/fbank80_stmn_16k.yaml b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/fbank80_stmn_16k.yaml new file mode 100644 index 00000000..f4091f5d --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/fbank80_stmn_16k.yaml @@ -0,0 +1,12 @@ +audio_feats: + audio_feat: logfb + sample_frequency: 16000 + frame_length: 25 + low_freq: 20 + high_freq: 7600 + num_filters: 80 + snip_edges: false + use_energy: false +mvn: + context: 150 + norm_var: false diff --git a/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/lrsched_cos_default.yaml b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/lrsched_cos_default.yaml new file mode 100644 index 00000000..6f1009c0 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/lrsched_cos_default.yaml @@ -0,0 +1,7 @@ +lrsch_type: dinossl +dinossl_lr: 0.005 # For now, the learning rate is linearly scaled with the batch size. What's specified here is for the batch size of 256. (i.e., if ngpu * batch_size_per_gpu = 512, the lr becomes 0.005 * 2) +dinossl_min_lr: 1e-6 +dinossl_warmup_epochs: 10 +dinossl_weight_decay: 1e-4 +dinossl_weight_decay_end: 1e-4 +dinossl_momentum_teacher: 0.996 diff --git a/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/optim_adam_default.yaml b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/optim_adam_default.yaml new file mode 100644 index 00000000..8362a9b5 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/optim_adam_default.yaml @@ -0,0 +1,6 @@ +opt_type: adam +amsgrad: true +beta1: 0.9 +beta2: 0.95 +weight_decay: 1.0e-05 +dinossl_style: true # (TODO: check) I think the above arguments are NOT used. 
If not, delete them all diff --git a/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/resnet34.yaml b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/resnet34.yaml new file mode 100644 index 00000000..4042ec33 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/resnet34.yaml @@ -0,0 +1,8 @@ +resnet_type: resnet34 +in_channels: 1 +in_feats: 80 +in_kernel_size: 3 +in_stride: 1 +no_maxpool: true +embed_dim: 256 +dropout_rate: 0.0 diff --git a/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/train_data_default.yaml b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/train_data_default.yaml new file mode 100644 index 00000000..6a763f47 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/train_data_default.yaml @@ -0,0 +1,10 @@ +dataset: + max_chunk_length: 2.0 + min_chunk_length: 2.0 + aug_cfg: conf/reverb_noise_aug.yaml +sampler: + batch_size: 48 # 52:64:4 OOM w/ lresnet34 in this version of hyperion (It wasn't in the old version). + iters_per_epoch: 1 +data_loader: + num_workers: 8 + diff --git a/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/train_resnet34_xvec_default.yaml b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/train_resnet34_xvec_default.yaml new file mode 100644 index 00000000..1d387790 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/train_resnet34_xvec_default.yaml @@ -0,0 +1,7 @@ +data: + train: train_data_default.yaml + val: val_data_default.yaml +feats: fbank80_stmn_16k.yaml +model: resnet34.yaml +trainer: trainer_default.yaml + \ No newline at end of file diff --git a/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/trainer_default.yaml b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/trainer_default.yaml new file mode 100644 index 00000000..279b2829 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/trainer_default.yaml @@ -0,0 +1,6 @@ +optim: optim_adam_default.yaml +lrsched: lrsched_cos_default.yaml +use_amp: true +log_interval: 100 +epochs: 70 +eff_batch_size: 128 # For now, this won't be used in dinossl diff --git a/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/val_data_default.yaml b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/val_data_default.yaml new file mode 100644 index 00000000..127af82c --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/val_data_default.yaml @@ -0,0 +1,10 @@ +dataset: + max_chunk_length: 2.0 + min_chunk_length: 2.0 + aug_cfg: conf/reverb_noise_aug.yaml +sampler: + batch_size: 32 + iters_per_epoch: 1 +data_loader: + num_workers: 0 #8 + diff --git a/egs/voxceleb/dinossl.v1/conf/ecapatdnn_small.yaml b/egs/voxceleb/dinossl.v1/conf/ecapatdnn_small.yaml new file mode 100644 index 00000000..fd386500 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/ecapatdnn_small.yaml @@ -0,0 +1,34 @@ +resnet_enc: + in_feats: 80 + in_conv_channels: 512 + in_kernel_size: 5 + in_stride: 1 + resb_type: seres2bn + resb_repeats: + - 1 + - 1 + - 1 + resb_channels: + - 512 + resb_kernel_sizes: + - 3 + resb_dilations: + - 2 + - 3 + - 4 + resb_strides: + - 1 + res2net_width_factor: 1 + res2net_scale: 8 + se_r: 4 + multilayer: true + multilayer_concat: true + endpoint_channels: 1536 +pool_net: + pool_type: ch-wise-att-mean+stddev + inner_feats: 128 +embed_dim: 256 +cos_scale: 30.0 +margin: 0.3 +margin_warmup_epochs: 20.0 +dropout_rate: 0.0 diff --git a/egs/voxceleb/dinossl.v1/conf/efficientnet_b4.yaml b/egs/voxceleb/dinossl.v1/conf/efficientnet_b4.yaml new file mode 100644 index 00000000..f87c1e02 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/efficientnet_b4.yaml @@ -0,0 +1,20 @@ +effnet_type: efficientnet-b4 +in_feats: 80 +in_channels: 1 
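+# Note: the mbconv_strides list further down in this file (1,1,2,2,1,2,1) is the
+# "strides=1122121" pattern quoted for EfficientNet-b4 in the README results tables.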
+in_kernel_size: 3 +in_stride: 1 +se_r: 4 +fix_stem_head: true +mbconv_strides: +- 1 +- 1 +- 2 +- 2 +- 1 +- 2 +- 1 +embed_dim: 256 +cos_scale: 30.0 +margin: 0.3 +margin_warmup_epochs: 20.0 +dropout_rate: 0.0 diff --git a/egs/voxceleb/dinossl.v1/conf/efficientnet_b7.yaml b/egs/voxceleb/dinossl.v1/conf/efficientnet_b7.yaml new file mode 100644 index 00000000..bae5c7cb --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/efficientnet_b7.yaml @@ -0,0 +1,22 @@ +effnet_type: efficientnet-b7 +in_feats: 80 +in_channels: 1 +in_kernel_size: 3 +in_stride: 1 +se_r: 4 +fix_stem_head: true +mbconv_strides: +- 1 +- 1 +- 2 +- 2 +- 1 +- 2 +- 1 +embed_dim: 256 +cos_scale: 30.0 +margin: 0.3 +margin_warmup_epochs: 20.0 +dropout_rate: 0.0 +norm_layer: instance-norm-affine +head_norm_layer: layer-norm diff --git a/egs/voxceleb/dinossl.v1/conf/fbank64_8k.yaml b/egs/voxceleb/dinossl.v1/conf/fbank64_8k.yaml new file mode 100644 index 00000000..a77eb899 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/fbank64_8k.yaml @@ -0,0 +1,7 @@ +sample_frequency: 8000 +frame_length: 25 +low_freq: 20 +high_freq: 3700 +num_filters: 64 +snip_edges: false +use_energy: false diff --git a/egs/voxceleb/dinossl.v1/conf/fbank64_stmn_8k.yaml b/egs/voxceleb/dinossl.v1/conf/fbank64_stmn_8k.yaml new file mode 100644 index 00000000..dfd0d3e5 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/fbank64_stmn_8k.yaml @@ -0,0 +1,12 @@ +audio_feats: + audio_feat: logfb + sample_frequency: 8000 + frame_length: 25 + low_freq: 20 + high_freq: 3700 + num_filters: 64 + snip_edges: false + use_energy: false +mvn: + context: 150 + norm_var: false diff --git a/egs/voxceleb/dinossl.v1/conf/fbank80_16k.yaml b/egs/voxceleb/dinossl.v1/conf/fbank80_16k.yaml new file mode 100644 index 00000000..88bae69e --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/fbank80_16k.yaml @@ -0,0 +1,7 @@ +sample_frequency: 16000 +frame_length: 25 +low_freq: 20 +high_freq: 7600 +num_filters: 80 +snip_edges: false +use_energy: false diff --git a/egs/voxceleb/dinossl.v1/conf/fbank80_stmn_16k.yaml b/egs/voxceleb/dinossl.v1/conf/fbank80_stmn_16k.yaml new file mode 100644 index 00000000..f4091f5d --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/fbank80_stmn_16k.yaml @@ -0,0 +1,12 @@ +audio_feats: + audio_feat: logfb + sample_frequency: 16000 + frame_length: 25 + low_freq: 20 + high_freq: 7600 + num_filters: 80 + snip_edges: false + use_energy: false +mvn: + context: 150 + norm_var: false diff --git a/egs/voxceleb/dinossl.v1/conf/lrsched_exp_default.yaml b/egs/voxceleb/dinossl.v1/conf/lrsched_exp_default.yaml new file mode 100644 index 00000000..fe08b704 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/lrsched_exp_default.yaml @@ -0,0 +1,7 @@ +lrsch_type: exp_lr +decay_rate: 0.5 +decay_steps: 8000 +hold_steps: 40000 +min_lr: 1.0e-05 +update_lr_on_opt_step: true +warmup_steps: 1000 diff --git a/egs/voxceleb/dinossl.v1/conf/noise_aug.yaml b/egs/voxceleb/dinossl.v1/conf/noise_aug.yaml new file mode 100644 index 00000000..7e575faf --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/noise_aug.yaml @@ -0,0 +1,19 @@ +noise_aug: + noise_prob: 0.7 + noise_types: + noise: + weight: 1 + noise_path: data/musan_noise_proc_audio/wav.scp + min_snr: 0 + max_snr: 18 + music: + weight: 1 + noise_path: data/musan_music_proc_audio/wav.scp + min_snr: 3 + max_snr: 18 + babble: + weight: 1 + noise_path: data/musan_speech_babble/wav.scp + min_snr: 3 + max_snr: 18 + diff --git a/egs/voxceleb/dinossl.v1/conf/online_pitch.conf b/egs/voxceleb/dinossl.v1/conf/online_pitch.conf new file mode 100644 index 00000000..926bcfca --- 
/dev/null +++ b/egs/voxceleb/dinossl.v1/conf/online_pitch.conf @@ -0,0 +1 @@ +--sample-frequency=8000 diff --git a/egs/voxceleb/dinossl.v1/conf/optim_adam_default.yaml b/egs/voxceleb/dinossl.v1/conf/optim_adam_default.yaml new file mode 100644 index 00000000..b6620069 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/optim_adam_default.yaml @@ -0,0 +1,6 @@ +opt_type: adam +lr: 0.05 +amsgrad: true +beta1: 0.9 +beta2: 0.95 +weight_decay: 1.0e-05 diff --git a/egs/voxceleb/dinossl.v1/conf/res2net50.yaml b/egs/voxceleb/dinossl.v1/conf/res2net50.yaml new file mode 100644 index 00000000..48067a3d --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/res2net50.yaml @@ -0,0 +1,13 @@ +resnet_type: res2net50 +in_channels: 1 +in_feats: 80 +in_kernel_size: 3 +in_stride: 1 +no_maxpool: true +res2net_width_factor: 3.25 +res2net_scale: 8 +embed_dim: 256 +cos_scale: 30.0 +margin: 0.3 +margin_warmup_epochs: 20.0 +dropout_rate: 0.0 diff --git a/egs/voxceleb/dinossl.v1/conf/resnet34.yaml b/egs/voxceleb/dinossl.v1/conf/resnet34.yaml new file mode 100644 index 00000000..98695823 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/resnet34.yaml @@ -0,0 +1,11 @@ +resnet_type: resnet34 +in_channels: 1 +in_feats: 80 +in_kernel_size: 3 +in_stride: 1 +no_maxpool: true +embed_dim: 256 +cos_scale: 30.0 +margin: 0.3 +margin_warmup_epochs: 20.0 +dropout_rate: 0.0 diff --git a/egs/voxceleb/dinossl.v1/conf/reverb_noise_aug.yaml b/egs/voxceleb/dinossl.v1/conf/reverb_noise_aug.yaml new file mode 100644 index 00000000..4fdf8068 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/reverb_noise_aug.yaml @@ -0,0 +1,35 @@ +reverb_aug: + reverb_prob: 0.45 + max_reverb_context: 0.5 + rir_types: + smallroom: + weight: 1 + rir_path: scp:data/rirs_smallroom/rirs.scp + rir_norm: max + mediumroom: + weight: 1 + rir_path: scp:data/rirs_mediumroom/rirs.scp + rir_norm: max + realroom: + weight: 1 + rir_path: scp:data/rirs_real/rirs.scp + rir_norm: max +noise_aug: + noise_prob: 0.7 + noise_types: + noise: + weight: 1 + noise_path: data/musan_noise_proc_audio/wav.scp + min_snr: 0 + max_snr: 18 + music: + weight: 1 + noise_path: data/musan_music_proc_audio/wav.scp + min_snr: 3 + max_snr: 18 + babble: + weight: 1 + noise_path: data/musan_speech_babble/wav.scp + min_snr: 3 + max_snr: 18 + diff --git a/egs/voxceleb/dinossl.v1/conf/spinenet49.yaml b/egs/voxceleb/dinossl.v1/conf/spinenet49.yaml new file mode 100644 index 00000000..66b8d517 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/spinenet49.yaml @@ -0,0 +1,11 @@ +spinenet_type: spinenet49 +in_channels: 1 +in_feats: 80 +in_kernel_size: 3 +in_stride: 1 +no_maxpool: true +embed_dim: 256 +cos_scale: 30.0 +margin: 0.3 +margin_warmup_epochs: 20.0 +dropout_rate: 0.0 diff --git a/egs/voxceleb/dinossl.v1/conf/train_data_default.yaml b/egs/voxceleb/dinossl.v1/conf/train_data_default.yaml new file mode 100644 index 00000000..451ffa35 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/train_data_default.yaml @@ -0,0 +1,10 @@ +dataset: + max_chunk_length: 4.0 + min_chunk_length: 4.0 + aug_cfg: conf/reverb_noise_aug.yaml +sampler: + batch_size: 32 + iters_per_epoch: 6 +data_loader: + num_workers: 8 + \ No newline at end of file diff --git a/egs/voxceleb/dinossl.v1/conf/train_ecapatdnn_xvec_default.yaml b/egs/voxceleb/dinossl.v1/conf/train_ecapatdnn_xvec_default.yaml new file mode 100644 index 00000000..46298946 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/train_ecapatdnn_xvec_default.yaml @@ -0,0 +1,7 @@ +data: + train: train_data_default.yaml + val: val_data_default.yaml +feats: fbank80_stmn_16k.yaml +model: 
ecapatdnn_small.yaml +trainer: trainer_default.yaml + \ No newline at end of file diff --git a/egs/voxceleb/dinossl.v1/conf/train_effnetb4_xvec_default.yaml b/egs/voxceleb/dinossl.v1/conf/train_effnetb4_xvec_default.yaml new file mode 100644 index 00000000..1bc74de6 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/train_effnetb4_xvec_default.yaml @@ -0,0 +1,7 @@ +data: + train: train_data_default.yaml + val: val_data_default.yaml +feats: fbank80_stmn_16k.yaml +model: efficientnet_b4.yaml +trainer: trainer_default.yaml + \ No newline at end of file diff --git a/egs/voxceleb/dinossl.v1/conf/train_res2net50_xvec_default.yaml b/egs/voxceleb/dinossl.v1/conf/train_res2net50_xvec_default.yaml new file mode 100644 index 00000000..1d387790 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/train_res2net50_xvec_default.yaml @@ -0,0 +1,7 @@ +data: + train: train_data_default.yaml + val: val_data_default.yaml +feats: fbank80_stmn_16k.yaml +model: resnet34.yaml +trainer: trainer_default.yaml + \ No newline at end of file diff --git a/egs/voxceleb/dinossl.v1/conf/train_resnet34_xvec_default.yaml b/egs/voxceleb/dinossl.v1/conf/train_resnet34_xvec_default.yaml new file mode 100644 index 00000000..1d387790 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/train_resnet34_xvec_default.yaml @@ -0,0 +1,7 @@ +data: + train: train_data_default.yaml + val: val_data_default.yaml +feats: fbank80_stmn_16k.yaml +model: resnet34.yaml +trainer: trainer_default.yaml + \ No newline at end of file diff --git a/egs/voxceleb/dinossl.v1/conf/train_spinenet49_xvec_default.yaml b/egs/voxceleb/dinossl.v1/conf/train_spinenet49_xvec_default.yaml new file mode 100644 index 00000000..07167987 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/train_spinenet49_xvec_default.yaml @@ -0,0 +1,7 @@ +data: + train: train_data_default.yaml + val: val_data_default.yaml +feats: fbank80_stmn_16k.yaml +model: spinenet49.yaml +trainer: trainer_default.yaml + \ No newline at end of file diff --git a/egs/voxceleb/dinossl.v1/conf/trainer_default.yaml b/egs/voxceleb/dinossl.v1/conf/trainer_default.yaml new file mode 100644 index 00000000..86dcc2e4 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/trainer_default.yaml @@ -0,0 +1,6 @@ +optim: optim_adam_default.yaml +lrsched: lrsched_exp_default.yaml +use_amp: true +log_interval: 1000 +epochs: 70 +eff_batch_size: 512 diff --git a/egs/voxceleb/dinossl.v1/conf/trainer_swa_default.yaml b/egs/voxceleb/dinossl.v1/conf/trainer_swa_default.yaml new file mode 100644 index 00000000..0cafad01 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/trainer_swa_default.yaml @@ -0,0 +1,9 @@ +optim: optim_adam_default.yaml +lrsched: lrsched_exp_default.yaml +use_amp: true +log_interval: 1000 +epochs: 80 +eff_batch_size: 512 +swa_start: 60 +swa_lr: 1e-3 +swa_anneal_epochs: 5 diff --git a/egs/voxceleb/dinossl.v1/conf/vad_16k.yaml b/egs/voxceleb/dinossl.v1/conf/vad_16k.yaml new file mode 100644 index 00000000..5fb0111c --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/vad_16k.yaml @@ -0,0 +1,8 @@ +sample_frequency: 16000 +frame_shift: 10 +frame_length: 25 +snip_edges: false +vad_energy_threshold: 5.5 +vad_energy_mean_scale: 0.5 +vad_proportion_threshold: 0.12 +vad_frames_context: 2 diff --git a/egs/voxceleb/dinossl.v1/conf/vad_8k.yaml b/egs/voxceleb/dinossl.v1/conf/vad_8k.yaml new file mode 100644 index 00000000..7592c9d1 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/vad_8k.yaml @@ -0,0 +1,8 @@ +sample_frequency: 8000 +frame_shift: 10 +frame_length: 25 +snip_edges: false +vad_energy_threshold: 5.5 +vad_energy_mean_scale: 0.5 
+vad_proportion_threshold: 0.12 +vad_frames_context: 2 diff --git a/egs/voxceleb/dinossl.v1/conf/val_data_default.yaml b/egs/voxceleb/dinossl.v1/conf/val_data_default.yaml new file mode 100644 index 00000000..451ffa35 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/conf/val_data_default.yaml @@ -0,0 +1,10 @@ +dataset: + max_chunk_length: 4.0 + min_chunk_length: 4.0 + aug_cfg: conf/reverb_noise_aug.yaml +sampler: + batch_size: 32 + iters_per_epoch: 6 +data_loader: + num_workers: 8 + \ No newline at end of file diff --git a/egs/voxceleb/dinossl.v1/datapath.sh b/egs/voxceleb/dinossl.v1/datapath.sh new file mode 100644 index 00000000..a7d277f4 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/datapath.sh @@ -0,0 +1,22 @@ +# Copyright +# 2018 Johns Hopkins University (Author: Jesus Villalba) +# +# Paths to the databases used in the experiment + + +if [ "$(hostname --domain)" == "clsp.jhu.edu" ];then + voxceleb1_root=/export/corpora5/VoxCeleb1_v1 #voxceleb1 v1 + # voxceleb1_root=/export/corpora5/VoxCeleb1_v2 #voxceleb1 v2 + voxceleb2_root=/export/corpora5/VoxCeleb2 + musan_root=/export/corpora5/JHU/musan +elif [ "$(hostname --domain)" == "cm.gemini" ];then + # voxceleb1_root=/expscratch/dsnyder/VoxCeleb1 #voxceleb1 v1 + voxceleb1_root=/exp/jvillalba/corpora/voxceleb1 #voxceleb1 v2 + voxceleb2_root=/expscratch/dgromero/corpora-open/vox2 + musan_root=/expscratch/dgromero/corpora-open/musan +else + echo "Put your database paths here" + exit 1 +fi + + diff --git a/egs/voxceleb/dinossl.v1/default_config.sh b/egs/voxceleb/dinossl.v1/default_config.sh new file mode 120000 index 00000000..a4876a2e --- /dev/null +++ b/egs/voxceleb/dinossl.v1/default_config.sh @@ -0,0 +1 @@ +global_conf/dinossl_tuning/config_fbank80_stmn_lresnet34_e256_do0_b96_amp.dinossl.v1.sh \ No newline at end of file diff --git a/egs/voxceleb/dinossl.v1/global_conf/dinossl_tuning/config_fbank80_stmn_lresnet34_e256_do0_b96_amp.dinossl.v1.sh b/egs/voxceleb/dinossl.v1/global_conf/dinossl_tuning/config_fbank80_stmn_lresnet34_e256_do0_b96_amp.dinossl.v1.sh new file mode 100644 index 00000000..6c86e285 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/global_conf/dinossl_tuning/config_fbank80_stmn_lresnet34_e256_do0_b96_amp.dinossl.v1.sh @@ -0,0 +1,63 @@ +# LResNet34 x-vector with mixed precision training +# Some variables defined here are just used to define nnet_name, which is better to be fixed + +nnet_name_tag="" # to manage file names for expdir. For example, utilize this when running multiple exps for hyp. para. tuning + +# acoustic features +feat_config=conf/fbank80_stmn_16k.yaml +feat_type=fbank80_stmn # just to define nnet_name. When changing this, fix the part in ${xvec_train_base_cfg} too to be applied in the actual training setup + +#vad +vad_config=conf/vad_16k.yaml + +# x-vector training +nnet_data=voxceleb2_train + +# x-vector cfg + +nnet_type=resnet + +resnet_type=lresnet34 +batch_size_1gpu=48 # 52:64:4 OOM +ngpu=2 +eff_batch_size=`expr $batch_size_1gpu \* $ngpu` # In dinossl, eff_batch_size is the same as ngpu * batch_size_1gpu since grad_acc_steps is always 1 in dinossl for now. Thus, when eff_batch_size changes, instead of changing grad_acc_steps w/ a fixed lr, lr (base_value in cosine_scheduler, to be exact) is adjusted linearly proportional to eff_batch_size where the base value is 0.005 as as a default w/ eff_batch_size=256. For example, if eff_batch_size=128, the base value is 0.0025 in dinossl. eff_batch_size is calculated in python scripts but one here is to compose nnet_name below. # just to define nnet_name. 
When changing this, fix the part in ${xvec_train_base_cfg} too to be applied in the actual training setup +dropout=0 # just to define nnet_name. When changing this, fix the part in ${xvec_train_base_cfg} too to be applied in the actual training setup +embed_dim=256 # just to define nnet_name. When changing this, fix the part in ${xvec_train_base_cfg} too to be applied in the actual training setup + +xvec_train_base_cfg=conf/dinossl_tuning/train_resnet34_xvec_default.yaml +xvec_train_args="--data.train.sampler.batch-size $batch_size_1gpu --model.resnet-type $resnet_type" + +# dinossl related (in addition to ones defined in xvec_train_base_cfg): dataset/dataloader, model/loss +## dino-head +dinossl_out_dim=65536 +dinossl_use_bn_in_head=false +dinossl_norm_last_layer=true +## data-augmentation +dinossl_local_crops_number=4 +## teacher temperature +dinossl_warmup_teacher_temp=0.04 +dinossl_teacher_temp=0.04 +dinossl_warmup_teacher_temp_epochs=0 +## chunk sampling related +dinossl_chunk_len_mult=2 # a factor by which long chunk length increases from short chunk length. The short chunk length is determined randomly between min_chunk and max_chunk set above + +nnet_name=${feat_type}_${resnet_type}_e${embed_dim}_do${dropout}_b${eff_batch_size}_amp.dinossl.v1 +if [[ -n ${nnet_name_tag} ]]; then + nnet_name=${nnet_name}_${nnet_name_tag} +fi + + +nnet_dir=exp/xvector_nnets/$nnet_name +nnet=$nnet_dir/model_ep0070.pth + + +# back-end +state_dict_key=model_teacher_state_dict +plda_num_augs=0 +plda_data=voxceleb1_test_train +plda_type=splda +lda_dim=200 +plda_y_dim=150 +plda_z_dim=200 + + diff --git a/egs/voxceleb/dinossl.v1/hyp_utils b/egs/voxceleb/dinossl.v1/hyp_utils new file mode 120000 index 00000000..f6d1eb7a --- /dev/null +++ b/egs/voxceleb/dinossl.v1/hyp_utils @@ -0,0 +1 @@ +../../../hyp_utils \ No newline at end of file diff --git a/egs/voxceleb/dinossl.v1/list_run.sh b/egs/voxceleb/dinossl.v1/list_run.sh new file mode 100644 index 00000000..8fcdf759 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/list_run.sh @@ -0,0 +1,11 @@ +# Why this is directory is created: To run DINO-based SSL for utterance-level embedding learning + +# How this directory is created: +## 1. Run below in the parent directory: +## cp -r v1.1 dinossl.v1 +## 2. STOPGAP: To train a model in the CLSP grid W/O data prep. The data prep. part (before run_011*) will be updated after updating training-related codes first. +## ln -s /export/c01/jcho/hyperion_DINO/egs/voxceleb/v2/dataa data (for dinossl) + +# (WIP) training script +## (GOOD) bash run_511_train_xvector.sh --ngpu 1 +## (TODO: for multiple gpus) diff --git a/egs/voxceleb/dinossl.v1/local b/egs/voxceleb/dinossl.v1/local new file mode 120000 index 00000000..740b697d --- /dev/null +++ b/egs/voxceleb/dinossl.v1/local @@ -0,0 +1 @@ +../v1/local/ \ No newline at end of file diff --git a/egs/voxceleb/dinossl.v1/path.sh b/egs/voxceleb/dinossl.v1/path.sh new file mode 100755 index 00000000..6994fdab --- /dev/null +++ b/egs/voxceleb/dinossl.v1/path.sh @@ -0,0 +1,5 @@ + +export HYP_ROOT=$(readlink -f `pwd -P`/../../..) +export TOOLS_ROOT=$HYP_ROOT/tools + +. $TOOLS_ROOT/path.sh diff --git a/egs/voxceleb/dinossl.v1/run_001_prepare_data.sh b/egs/voxceleb/dinossl.v1/run_001_prepare_data.sh new file mode 100755 index 00000000..5ff5a9d6 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/run_001_prepare_data.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Copyright +# 2018 Johns Hopkins University (Author: Jesus Villalba) +# Apache 2.0. +# +. ./cmd.sh +. 
./path.sh +set -e + +stage=1 +config_file=default_config.sh + +. parse_options.sh || exit 1; +. datapath.sh + + +if [ $stage -le 1 ];then + # Prepare the VoxCeleb2 dataset for training. + local/make_voxceleb2.pl $voxceleb2_root dev 16 data/voxceleb2_train +fi + +if [ $stage -le 2 ];then + # prepare voxceleb1 for test + # This script is for the old version of the dataset and NOT processing LANG, + # GENDER, NAT + local/make_voxceleb1_oldversion_oeh.pl $voxceleb1_root data + # This script is for the old version of the dataset: + # local/make_voxceleb1_oeh.pl $voxceleb1_root data + # Use this for the newer version of voxceleb1: + # local/make_voxceleb1_v2_oeh.pl $voxceleb1_root data +fi diff --git a/egs/voxceleb/dinossl.v1/run_002_compute_evad.sh b/egs/voxceleb/dinossl.v1/run_002_compute_evad.sh new file mode 100755 index 00000000..44153358 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/run_002_compute_evad.sh @@ -0,0 +1,57 @@ +#!/bin/bash +# Copyright +# 2018 Johns Hopkins University (Author: Jesus Villalba) +# Apache 2.0. +# +. ./cmd.sh +. ./path.sh +set -e +nodes=fs01 +storage_name=$(date +'%m_%d_%H_%M') +vaddir=`pwd`/exp/vad_e +vad_config=conf/vad_16k.yaml + +stage=1 +config_file=default_config.sh + +. parse_options.sh || exit 1; +. $config_file + + +if [ $stage -le 1 ]; then + # Prepare to distribute data over multiple machines + if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $vaddir/storage ]; then + dir_name=$USER/hyp-data/voxceleb/dinossl.v1/$storage_name/vad/storage + if [ "$nodes" == "b0" ];then + utils/create_split_dir.pl \ + utils/create_split_dir.pl \ + /export/b{04,05,06,07}/$dir_name $vaddir/storage + elif [ "$nodes" == "b1" ];then + utils/create_split_dir.pl \ + /export/b{14,15,16,17}/$dir_name $vaddir/storage + elif [ "$nodes" == "c0" ];then + utils/create_split_dir.pl \ + /export/c{06,07,08,09}/$dir_name $vaddir/storage + elif [ "$nodes" == "fs01" ];then + utils/create_split_dir.pl \ + /export/fs01/$dir_name $vaddir/storage + else + echo "we don't distribute data between multiple machines" + fi + fi +fi + +#Train datasets +if [ $stage -le 2 ];then + for name in voxceleb2_train voxceleb1_test + do + num_spk=$(wc -l data/$name/spk2utt | awk '{ print $1}') + nj=$(($num_spk < 40 ? $num_spk:40)) + hyp_utils/feats/make_evad.sh --write-utt2num-frames true \ + --vad-config $vad_config --nj $nj --cmd "$train_cmd" \ + data/${name} exp/make_vad/$name $vaddir + utils/fix_data_dir.sh data/${name} + done +fi + + diff --git a/egs/voxceleb/dinossl.v1/run_003_prepare_noises_rirs.sh b/egs/voxceleb/dinossl.v1/run_003_prepare_noises_rirs.sh new file mode 100755 index 00000000..4297f7fb --- /dev/null +++ b/egs/voxceleb/dinossl.v1/run_003_prepare_noises_rirs.sh @@ -0,0 +1,65 @@ +#!/bin/bash +# Copyright +# 2020 Johns Hopkins University (Author: Jesus Villalba) +# Apache 2.0. +# +. ./cmd.sh +. ./path.sh +set -e + +stage=1 +. parse_options.sh || exit 1; +. datapath.sh + +# We prepare the noise files and RIR for online speech augmentation + +if [ $stage -le 1 ]; then + + # Prepare the MUSAN corpus, which consists of music, speech, and noise + # suitable for augmentation. 
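+  # The processed lists written below (data/musan_noise_proc_audio/wav.scp,
+  # data/musan_music_proc_audio/wav.scp and, from the next stage,
+  # data/musan_speech_babble/wav.scp) are the noise_path entries expected by
+  # conf/reverb_noise_aug.yaml and conf/noise_aug.yaml for on-the-fly augmentation,
+  # so keep these directory names in sync if you change anything.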
+ local/make_musan.sh $musan_root 16 data + + for name in musan_noise musan_music + do + steps_xvec/preprocess_audios_for_nnet_train.sh --nj 10 --cmd "$train_cmd" \ + --storage_name voxceleb-dinossl.v1-$(date +'%m_%d_%H_%M') \ + data/${name} data/${name}_proc_audio exp/${name}_proc_audio + utils/fix_data_dir.sh data/${name}_proc_audio + done + +fi + +if [ $stage -le 2 ]; then + + # Create Babble noise from MUSAN speech files + for name in musan_speech + do + steps_xvec/make_babble_noise_for_nnet_train.sh --cmd "$train_cmd" \ + --storage_name voxceleb-dinossl.v1-$(date +'%m_%d_%H_%M') \ + data/${name} data/${name}_babble exp/${name}_babble + # utils/fix_data_dir.sh data/${name}_babble + done +fi + +if [ $stage -le 3 ]; then + if [ ! -d "RIRS_NOISES" ]; then + if [ -d ../../sre19-cmn2/v1/RIRS_NOISES ];then + ln -s ../../sre19-cmn2/v1/RIRS_NOISES + else + # Download the package that includes the real RIRs, simulated RIRs, isotropic noises and point-source noises + wget --no-check-certificate http://www.openslr.org/resources/28/rirs_noises.zip + unzip rirs_noises.zip + fi + fi + local/make_rirs_data.sh RIRS_NOISES/simulated_rirs/smallroom 16 data/rirs_smallroom + local/make_rirs_data.sh RIRS_NOISES/simulated_rirs/mediumroom 16 data/rirs_mediumroom + local/make_rirs_data.sh RIRS_NOISES/real_rirs_isotropic_noises 16 data/rirs_real + for rirs in rirs_smallroom rirs_mediumroom rirs_real + do + #pack all rirs in h5 files + steps_xvec/pack_rirs_for_nnet_train.sh data/$rirs data/$rirs exp/rirs/$rirs + done + +fi + + diff --git a/egs/voxceleb/dinossl.v1/run_010_prepare_xvec_train_data.sh b/egs/voxceleb/dinossl.v1/run_010_prepare_xvec_train_data.sh new file mode 100755 index 00000000..fa643852 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/run_010_prepare_xvec_train_data.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# Copyright +# 2020 Johns Hopkins University (Author: Jesus Villalba) +# Apache 2.0. +# +. ./cmd.sh +. ./path.sh +set -e + +stage=1 +config_file=default_config.sh + +. parse_options.sh || exit 1; +. $config_file + +if [ $stage -le 2 ]; then + # This script preprocess audio for x-vector training + steps_xvec/preprocess_audios_for_nnet_train.sh --nj 40 --cmd "$train_cmd" \ + --storage_name voxceleb-dinossl.v1-$(date +'%m_%d_%H_%M') --use-bin-vad true \ + data/${nnet_data} data/${nnet_data}_proc_audio_no_sil exp/${nnet_data}_proc_audio_no_sil + hyp_utils/kaldi/utils/fix_data_dir.sh data/${nnet_data}_proc_audio_no_sil + +fi + +if [ $stage -le 3 ]; then + # Now, we remove files with less than 4s. This removes ~ 6.4% of the + # number of the original samples for voxceleb2_train. + hyp_utils/remove_short_audios.sh --min-len 4 data/${nnet_data}_proc_audio_no_sil +fi + +if [ $stage -le 4 ]; then + # Prepare train and validation lists for x-vectors. JJ: This might use + # speaker labels but validation list won't be used in self-supervised + # learning. (In the future, it may be used w/o labels for better validation) + local/make_train_lists_sup_embed_with_augm.sh \ + data/${nnet_data}_proc_audio_no_sil \ + data/${nnet_data}_proc_audio_no_sil/lists_xvec +fi + +exit diff --git a/egs/voxceleb/dinossl.v1/run_011_train_xvector.sh b/egs/voxceleb/dinossl.v1/run_011_train_xvector.sh new file mode 100755 index 00000000..17d50722 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/run_011_train_xvector.sh @@ -0,0 +1,59 @@ +#!/bin/bash +# Copyright +# 2019 Johns Hopkins University (Author: Jesus Villalba) +# Apache 2.0. +# +. ./cmd.sh +. 
./path.sh +set -e + +stage=1 +ngpu=4 +config_file=default_config.sh +interactive=false +num_workers="" +use_tb=false +use_wandb=false + +. parse_options.sh || exit 1; +. $config_file +. datapath.sh + +list_dir=data/${nnet_data}_proc_audio_no_sil + +#add extra args from the command line arguments +if [ -n "$num_workers" ];then + extra_args="--data.train.data_loader.num-workers $num_workers" +fi +if [ "$use_tb" == "true" ];then + extra_args="$extra_args --trainer.use-tensorboard" +fi +if [ "$use_wandb" == "true" ];then + extra_args="$extra_args --trainer.use-wandb --trainer.wandb.project voxceleb-v1.1 --trainer.wandb.name $nnet_name.$(date -Iminutes)" +fi + +if [ "$interactive" == "true" ];then + export cuda_cmd=run.pl +fi + +# Network Training +if [ $stage -le 1 ]; then + + + mkdir -p $nnet_dir/log + $cuda_cmd \ + --gpu $ngpu $nnet_dir/log/train.log \ + hyp_utils/conda_env.sh --conda-env $HYP_ENV --num-gpus $ngpu \ + train_xvector_from_wav.py $nnet_type --cfg $xvec_train_base_cfg $xvec_train_args $extra_args \ + --data.train.dataset.audio-file $list_dir/wav.scp \ + --data.train.dataset.time-durs-file $list_dir/utt2dur \ + --data.train.dataset.key-file $list_dir/lists_xvec/train.scp \ + --data.train.dataset.class-file $list_dir/lists_xvec/class2int \ + --data.val.dataset.audio-file $list_dir/wav.scp \ + --data.val.dataset.time-durs-file $list_dir/utt2dur \ + --data.val.dataset.key-file $list_dir/lists_xvec/val.scp \ + --trainer.exp-path $nnet_dir $args \ + --num-gpus $ngpu \ + +fi + diff --git a/egs/voxceleb/dinossl.v1/run_030_extract_xvectors.sh b/egs/voxceleb/dinossl.v1/run_030_extract_xvectors.sh new file mode 100755 index 00000000..3abf2ff6 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/run_030_extract_xvectors.sh @@ -0,0 +1,61 @@ +#!/bin/bash +# Copyright +# 2020 Johns Hopkins University (Author: Jesus Villalba) +# Apache 2.0. +# +. ./cmd.sh +. ./path.sh +set -e + +stage=1 +config_file=default_config.sh +use_gpu=false +xvec_chunk_length=12800 +. parse_options.sh || exit 1; +. $config_file + +if [ "$use_gpu" == "true" ];then + xvec_args="--use-gpu true --chunk-length $xvec_chunk_length" + xvec_cmd="$cuda_eval_cmd --mem 4G" +else + xvec_cmd="$train_cmd --mem 12G" +fi + +xvector_dir=exp/xvectors/$nnet_name + +if [ $stage -le 1 ]; then + # Extract xvectors for training LDA/PLDA + for name in voxceleb2cat_train + do + if [ $plda_num_augs -eq 0 ]; then + steps_xvec/extract_xvectors_from_wav.sh --cmd "$xvec_cmd" --nj 100 ${xvec_args} \ + --random-utt-length true --min-utt-length 400 --max-utt-length 14000 \ + --feat-config $feat_config \ + $nnet data/${name} \ + $xvector_dir/${name} + else + steps_xvec/extract_xvectors_from_wav.sh --cmd "$xvec_cmd" --nj 300 ${xvec_args} \ + --random-utt-length true --min-utt-length 400 --max-utt-length 14000 \ + --feat-config $feat_config --aug-config $plda_aug_config --num-augs $plda_num_augs \ + $nnet data/${name} \ + $xvector_dir/${name}_augx${plda_num_augs} \ + data/${name}_augx${plda_num_augs} + fi + done +fi + + +if [ $stage -le 2 ]; then + # Extracts x-vectors for evaluation + for name in voxceleb1_test + do + num_spk=$(wc -l data/$name/spk2utt | awk '{ print $1}') + nj=$(($num_spk < 100 ? 
$num_spk:100)) + steps_xvec/extract_xvectors_from_wav.sh --cmd "$xvec_cmd --mem 6G" --nj $nj ${xvec_args} \ + --feat-config $feat_config \ + $nnet data/$name \ + $xvector_dir/$name + done +fi + +exit diff --git a/egs/voxceleb/dinossl.v1/run_040_eval_be.sh b/egs/voxceleb/dinossl.v1/run_040_eval_be.sh new file mode 100755 index 00000000..cd168180 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/run_040_eval_be.sh @@ -0,0 +1,125 @@ +#!/bin/bash +# Copyright 2018 Johns Hopkins University (Author: Jesus Villalba) +# +# Apache 2.0. +# +. ./cmd.sh +. ./path.sh +set -e + +stage=1 +config_file=default_config.sh + +. parse_options.sh || exit 1; +. $config_file +. datapath.sh + +plda_label=${plda_type}y${plda_y_dim}_v1 +be_name=lda${lda_dim}_${plda_label}_${plda_data} + +xvector_dir=exp/xvectors/$nnet_name +be_dir=exp/be/$nnet_name/$be_name +score_dir=exp/scores/$nnet_name/${be_name} +score_plda_dir=$score_dir/plda +score_cosine_dir=exp/scores/$nnet_name/cosine + +if [ $stage -le 1 ]; then + + echo "Train PLDA on Voxceleb2" + steps_be/train_be_v1.sh --cmd "$train_cmd" \ + --lda_dim $lda_dim \ + --plda_type $plda_type \ + --y_dim $plda_y_dim --z_dim $plda_z_dim \ + $xvector_dir/$plda_data/xvector.scp \ + data/$plda_data \ + $be_dir & + + + wait + +fi + + +if [ $stage -le 2 ];then + + echo "Eval Voxceleb 1 with LDA+CentWhiten+LNorm+PLDA" + steps_be/eval_be_v1.sh --cmd "$train_cmd" --plda_type $plda_type \ + data/voxceleb1_test/trials \ + data/voxceleb1_test/utt2model \ + $xvector_dir/voxceleb1_test/xvector.scp \ + $be_dir/lda_lnorm.h5 \ + $be_dir/plda.h5 \ + $score_plda_dir/voxceleb1_scores + + $train_cmd --mem 10G --num-threads 6 $score_plda_dir/log/score_voxceleb1.log \ + local/score_voxceleb1.sh data/voxceleb1_test $score_plda_dir + + for f in $(ls $score_plda_dir/*_results); + do + echo $f + cat $f + echo "" + done + +fi + + +score_plda_dir=$score_cosine_dir + +if [ $stage -le 3 ];then + + echo "Eval Voxceleb 1 with Cosine scoring" + steps_be/eval_be_cos.sh --cmd "$train_cmd" \ + data/voxceleb1_test/trials \ + data/voxceleb1_test/utt2model \ + $xvector_dir/voxceleb1_test/xvector.scp \ + $score_plda_dir/voxceleb1_scores + + $train_cmd --mem 10G --num-threads 6 $score_plda_dir/log/score_voxceleb1.log \ + local/score_voxceleb1.sh data/voxceleb1_test $score_plda_dir + + for f in $(ls $score_plda_dir/*_results); + do + echo $f + cat $f + echo "" + done + +fi + +be_dir=exp/be/$nnet_name/cw +score_plda_dir=$score_dir/cw_cosine + +if [ $stage -le 4 ]; then + echo "Train centering+whitening on Voxceleb2" + steps_be/train_be_v2.sh --cmd "$train_cmd" \ + $xvector_dir/$plda_data/xvector.scp \ + data/$plda_data \ + $be_dir +fi + + +if [ $stage -le 5 ];then + + echo "Eval Voxceleb 1 with CentWhiten + Cosine scoring" + steps_be/eval_be_v2.sh --cmd "$train_cmd" \ + data/voxceleb1_test/trials \ + data/voxceleb1_test/utt2model \ + $xvector_dir/voxceleb1_test/xvector.scp \ + $be_dir/cw.h5 \ + $score_plda_dir/voxceleb1_scores + + $train_cmd --mem 10G --num-threads 6 $score_plda_dir/log/score_voxceleb1.log \ + local/score_voxceleb1.sh data/voxceleb1_test $score_plda_dir + + for f in $(ls $score_plda_dir/*_results); + do + echo $f + cat $f + echo "" + done + +fi + +exit + diff --git a/egs/voxceleb/dinossl.v1/run_511_train_xvector.sh b/egs/voxceleb/dinossl.v1/run_511_train_xvector.sh new file mode 100755 index 00000000..19a37a75 --- /dev/null +++ b/egs/voxceleb/dinossl.v1/run_511_train_xvector.sh @@ -0,0 +1,85 @@ +#!/bin/bash +# Copyright +# 2019 Johns Hopkins University (Author: Jesus Villalba) +# Apache 2.0. +# +. 
./cmd.sh +. ./path.sh +set -e + +stage=1 +ngpu=1 +config_file=default_config.sh +interactive=false +num_workers="" +use_tb=false +use_wandb=false + +. parse_options.sh || exit 1; +. $config_file +. datapath.sh + +list_dir=data/${nnet_data}_proc_audio_no_sil + +#add extra args from the command line arguments +if [ -n "$num_workers" ];then + extra_args="--data.train.data_loader.num-workers $num_workers" +fi +if [ "$use_tb" == "true" ];then + extra_args="$extra_args --trainer.use-tensorboard" +fi +if [ "$use_wandb" == "true" ];then + extra_args="$extra_args --trainer.use-wandb --trainer.wandb.project voxceleb-dinossl.v1 --trainer.wandb.name $nnet_name.$(date -Iminutes)" +fi + +if [ "$interactive" == "true" ];then + export cuda_cmd=run.pl +fi + +# Network Training +if [ $stage -le 1 ]; then + # dino arguments + dinossl_args="--dinossl true " + for arg in dinossl_nlayers dinossl_out_dim dinossl_use_bn_in_head dinossl_norm_last_layer \ + dinossl_local_crops_number dinossl_warmup_teacher_temp dinossl_teacher_temp \ + dinossl_warmup_teacher_temp_epochs dinossl_chunk_len_mult dinossl_reduce_overlap_prob; do + if [ ! -z ${!arg} ]; then + dinossl_args+="--${arg} ${!arg} " # ${!arg} return a value in the var, "${arg}" + fi + done + echo "Dino arguments: ${dinossl_args}" + + # Edit train.scp and class2int files to ignore class balancing in batching (to + # simulate a unsupervised scenario). Simply make class_idx == utt_idx + # train.utt2utt.scp + if [ ! -s ${list_dir}/lists_xvec/train.utt2utt.scp ]; then + awk '{print $1" "$1}' ${list_dir}/lists_xvec/train.scp > ${list_dir}/lists_xvec/train.utt2utt.scp + fi + # (This block can be ignored) val.utt2utt.scp although it is not used in the end + if [ ! -s ${list_dir}/lists_xvec/val.utt2utt.scp ]; then + awk '{print $1" "$1}' ${list_dir}/lists_xvec/val.scp > ${list_dir}/lists_xvec/val.utt2utt.scp + fi + # utt2int + if [ ! 
-s ${list_dir}/lists_xvec/utt2int ]; then + cat <(awk '{print $1}' ${list_dir}/lists_xvec/train.scp) <(awk '{print $1}' ${list_dir}/lists_xvec/val.scp) > ${list_dir}/lists_xvec/utt2int + fi + + + + mkdir -p $nnet_dir/log + $cuda_cmd \ + --gpu $ngpu $nnet_dir/log/train.log \ + hyp_utils/conda_env.sh --conda-env $HYP_ENV --num-gpus $ngpu \ + train_xvector_from_wav_dinossl.py $nnet_type --cfg $xvec_train_base_cfg $xvec_train_args $extra_args ${dinossl_args} \ + --data.train.dataset.audio-file $list_dir/wav.scp \ + --data.train.dataset.time-durs-file $list_dir/utt2dur \ + --data.train.dataset.key-file $list_dir/lists_xvec/train.utt2utt.scp \ + --data.train.dataset.class-file $list_dir/lists_xvec/utt2int \ + --data.val.dataset.audio-file $list_dir/wav.scp \ + --data.val.dataset.time-durs-file $list_dir/utt2dur \ + --data.val.dataset.key-file $list_dir/lists_xvec/val.utt2utt.scp \ + --trainer.exp-path $nnet_dir $args \ + --num-gpus $ngpu \ + +fi + diff --git a/egs/voxceleb/dinossl.v1/steps b/egs/voxceleb/dinossl.v1/steps new file mode 120000 index 00000000..aede39fe --- /dev/null +++ b/egs/voxceleb/dinossl.v1/steps @@ -0,0 +1 @@ +hyp_utils/kaldi/steps \ No newline at end of file diff --git a/egs/voxceleb/dinossl.v1/steps_be b/egs/voxceleb/dinossl.v1/steps_be new file mode 120000 index 00000000..b2098c2a --- /dev/null +++ b/egs/voxceleb/dinossl.v1/steps_be @@ -0,0 +1 @@ +../v1/steps_be \ No newline at end of file diff --git a/egs/voxceleb/dinossl.v1/steps_fe b/egs/voxceleb/dinossl.v1/steps_fe new file mode 120000 index 00000000..73ccc1eb --- /dev/null +++ b/egs/voxceleb/dinossl.v1/steps_fe @@ -0,0 +1 @@ +hyp_utils/kaldi/vad \ No newline at end of file diff --git a/egs/voxceleb/dinossl.v1/steps_pyfe b/egs/voxceleb/dinossl.v1/steps_pyfe new file mode 120000 index 00000000..7b9d122a --- /dev/null +++ b/egs/voxceleb/dinossl.v1/steps_pyfe @@ -0,0 +1 @@ +hyp_utils/feats \ No newline at end of file diff --git a/egs/voxceleb/dinossl.v1/steps_xvec b/egs/voxceleb/dinossl.v1/steps_xvec new file mode 120000 index 00000000..af66a94d --- /dev/null +++ b/egs/voxceleb/dinossl.v1/steps_xvec @@ -0,0 +1 @@ +hyp_utils/xvectors \ No newline at end of file diff --git a/egs/voxceleb/dinossl.v1/utils b/egs/voxceleb/dinossl.v1/utils new file mode 120000 index 00000000..3d590a1d --- /dev/null +++ b/egs/voxceleb/dinossl.v1/utils @@ -0,0 +1 @@ +hyp_utils/kaldi/utils \ No newline at end of file diff --git a/egs/voxceleb/v1/local/make_voxceleb1_oldversion_oeh.pl b/egs/voxceleb/v1/local/make_voxceleb1_oldversion_oeh.pl new file mode 100755 index 00000000..f6eb8f35 --- /dev/null +++ b/egs/voxceleb/v1/local/make_voxceleb1_oldversion_oeh.pl @@ -0,0 +1,130 @@ +#!/usr/bin/perl +# Note: this is an old version script in the commit id 1eb12f2ed01801a50c3f7ba014809bf7c7212f28. NOT process LANG, GENDER, NAT +# Copyright 2018 Ewald Enzinger +# 2018 David Snyder +# 2020 Jesus Villalba +# +# Usage: make_voxceleb1.pl /export/voxceleb1 data/ +# Create trial lists for Voxceleb1 original, Entire (E) and hard (H), +# with cleaned and non-cleaned versions + +if (@ARGV != 2) { + print STDERR "Usage: $0 \n"; + print STDERR "e.g. 
$0 /export/voxceleb1 data/\n"; + exit(1); +} + +($data_base, $out_dir) = @ARGV; +my $out_dir = "$out_dir/voxceleb1_test"; + +if (system("mkdir -p $out_dir") != 0) { + die "Error making directory $out_dir"; +} + +my $url_base="http://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta"; +my @trials_basename = ("very_test.txt", "very_test2.txt", "list_test_hard.txt", "list_test_hard2.txt", "list_test_all.txt", "list_test_all2.txt"); +my @trials_url = ("$url_base/veri_test.txt", "$url_base/veri_test2.txt", "$url_base/list_test_hard.txt", "$url_base/list_test_hard2.txt", "$url_base/list_test_all.txt", "$url_base/list_test_all2.txt"); +my @trials = ("trials_o", "trials_o_clean", "trials_h", "trials_h_clean", "trials_e", "trials_e_clean"); + +open(META_IN, "<", "$data_base/vox1_meta.csv") or die "Could not open the meta data file $data_base/vox1_meta.csv"; +my %id2spkr = (); +while () { + chomp; + my ($vox_id, $spkr_id, $gender, $nation, $set) = split; + $id2spkr{$vox_id} = $spkr_id; + +} +close(META_IN) or die; + +#download trials from voxceleb web page +for($i = 0; $i <= $#trials; $i++) { + + my $file_i = "$out_dir/$trials_basename[$i]"; + my $url_i = $trials_url[$i]; + my $trial_i = "$out_dir/$trials[$i]"; + if (! -e $file_i) { + system("wget -O $file_i $url_i"); + } + #mapping from new speaker ids and file-names to old ones + open(TRIAL_IN, "<", "$file_i") or die "Could not open the verification trials file $file_i"; + open(TRIAL_OUT, ">", "$trial_i") or die "Could not open the output file $trial_i"; + while () { + chomp; + my ($tar_or_non, $path1, $path2) = split; + + # Create entry for left-hand side of trial + my ($vox_id, $rec_id, $segment) = split('/', $path1); + $segment =~ s/\.wav$//; + my $spkr_id = $id2spkr{$vox_id}; + my $utt_id1 = "$spkr_id-$rec_id-00$segment"; + + # Create entry for right-hand side of trial + my ($vox_id, $rec_id, $segment) = split('/', $path2); + $segment =~ s/\.wav$//; + my $spkr_id = $id2spkr{$vox_id}; + my $utt_id2 = "$spkr_id-$rec_id-00$segment"; + + my $target = "nontarget"; + if ($tar_or_non eq "1") { + $target = "target"; + } + print TRIAL_OUT "$utt_id1 $utt_id2 $target\n"; + } + + close(TRIAL_IN) or die; + close(TRIAL_OUT) or die; + +} + + +opendir my $dh, "$data_base/voxceleb1_wav" or die "Cannot open directory: $!"; +my @spkr_dirs = grep {-d "$data_base/voxceleb1_wav/$_" && ! /^\.{1,2}$/} readdir($dh); +closedir $dh; + +open(SPKR_TEST, ">", "$out_dir/utt2spk") or die "Could not open the output file $out_dir/utt2spk"; +open(WAV_TEST, ">", "$out_dir/wav.scp") or die "Could not open the output file $out_dir/wav.scp"; + +foreach (@spkr_dirs) { + my $spkr_id = $_; + my $new_spkr_id = $spkr_id; + # If we're using a newer version of VoxCeleb1, we need to "deanonymize" + # the speaker labels. 
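+  # %id2spkr was filled above from vox1_meta.csv and maps id1xxxx VoxCeleb ids to the original
+  # speaker names, so the lookup below only rewrites the label when the newer id-based naming is in use.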
+ if (exists $id2spkr{$spkr_id}) { + $new_spkr_id = $id2spkr{$spkr_id}; + } + opendir my $dh, "$data_base/voxceleb1_wav/$spkr_id/" or die "Cannot open directory: $!"; + my @files = map{s/\.[^.]+$//;$_}grep {/\.wav$/} readdir($dh); + closedir $dh; + foreach (@files) { + my $filename = $_; + my $rec_id = substr($filename, 0, 11); + my $segment = substr($filename, 12, 7); + my $wav = "$data_base/voxceleb1_wav/$spkr_id/$filename.wav"; + my $utt_id = "$new_spkr_id-$rec_id-$segment"; + print WAV_TEST "$utt_id", " $wav", "\n"; + print SPKR_TEST "$utt_id", " $new_spkr_id", "\n"; + } +} + +close(SPKR_TEST) or die; +close(WAV_TEST) or die; + +if (system( + "cat $out_dir/trials_* | sort -u > $out_dir/trials") != 0) { + die "Error creating trials file in directory $out_dir"; +} + +if (system( + "awk '{ print \$1,\$1 }' $out_dir/trials | sort -u > $out_dir/utt2model") != 0) { + die "Error creating utt2model file in directory $out_dir"; +} + +if (system( + "utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { + die "Error creating spk2utt file in directory $out_dir"; +} +system("env LC_COLLATE=C utils/fix_data_dir.sh $out_dir"); +if (system("env LC_COLLATE=C utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) { + die "Error validating directory $out_dir"; +} + diff --git a/egs/voxceleb/v1/local/make_voxceleb2.pl b/egs/voxceleb/v1/local/make_voxceleb2.pl index e0ebeb0f..b2bd0d71 100755 --- a/egs/voxceleb/v1/local/make_voxceleb2.pl +++ b/egs/voxceleb/v1/local/make_voxceleb2.pl @@ -1,5 +1,6 @@ #!/usr/bin/perl -# +# Note: Compared to local/make_voxceleb2cat.pl 1) This does NOT concatenate same speaker recording turns in the conversation +# 2) skippied the part to get LANG, GENDER metadata # Copyright 2018 Johns Hopkins University (Jesus Villalba) # Copyright 2018 Ewald Enzinger # @@ -32,59 +33,20 @@ $dataset_path = "$data_base/$dataset" } +opendir my $dh, "$dataset_path" or die "Cannot open directory: $!"; +my @spkr_dirs = grep {-d "$dataset_path/$_" && ! /^\.{1,2}$/} readdir($dh); +closedir $dh; if (system("mkdir -p $out_dir") != 0) { die "Error making directory $out_dir"; } - -my $meta_url = "https://www.openslr.org/resources/49/vox2_meta.csv"; -my $meta_path = "$data_base/vox2_meta.csv"; -if (! -e "$meta_path") { - $meta_path = "$out_dir/vox2_meta.csv"; - system("wget -O $meta_path $meta_url"); -} -open(META_IN, "<", "$meta_path") or die "Could not open the output file $meta_path"; -my %spkr2gender = (); -while () { - chomp; - my ($spkr, $vox_id, $vgg_id, $gender, $set) = split; - $spkr2gender{$vox_id} = $gender; -} -close(META_IN) or die; - -my $lid_url = "https://www.robots.ox.ac.uk/~vgg/data/voxceleb/data_workshop_2021/lang_vox2_final.csv"; -my $lid_path = "$data_base/lang_vox2_final.csv"; -if (! 
-e "$lid_path") { - $lid_path = "$out_dir/lang_vox2_final.csv"; - system("wget -O $lid_path $lid_url"); -} -open(LID_IN, "<", "$lid_path") or die "Could not open the output file $lid_path"; -my %utt2lang = (); -while () { - chomp; - my ($utt_id, $lang, $score) = split ','; - $utt_id =~ s@/@-@g; - $utt_id =~ s@\.wav$@@; - $utt2lang{$utt_id} = $lang; -} -close(LID_IN) or die; - - open(SPKR, ">", "$out_dir/utt2spk") or die "Could not open the output file $out_dir/utt2spk"; open(WAV, ">", "$out_dir/wav.scp") or die "Could not open the output file $out_dir/wav.scp"; -open(LANG, ">", "$out_dir/utt2lang") or die "Could not open the output file $out_dir/utt2lang"; -open(GENDER, ">", "$out_dir/spk2gender") or die "Could not open the output file $out_dir/spk2gender"; - -opendir my $dh, "$dataset_path" or die "Cannot open directory: $!"; -my @spkr_dirs = grep {-d "$dataset_path/$_" && ! /^\.{1,2}$/} readdir($dh); -closedir $dh; foreach (@spkr_dirs) { my $spkr_id = $_; - print GENDER "$spkr_id $spkr2gender{$spkr_id}\n"; - opendir my $dh, "$dataset_path/$spkr_id/" or die "Cannot open directory: $!"; my @rec_dirs = grep {-d "$dataset_path/$spkr_id/$_" && ! /^\.{1,2}$/} readdir($dh); closedir $dh; @@ -105,19 +67,11 @@ my $utt_id = "$spkr_id-$rec_id-$name"; print WAV "$utt_id", " $wav", "\n"; print SPKR "$utt_id", " $spkr_id", "\n"; - if (exists $utt2lang{$utt_id}) { - print LANG "$utt_id", " $utt2lang{$utt_id}", "\n"; - } - else { - print LANG "$utt_id N/A\n"; - } } } } close(SPKR) or die; close(WAV) or die; -close(LANG) or die; -close(GENDER) or die; if (system( "utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) { diff --git a/hyp_utils/feats/make_evad.sh b/hyp_utils/feats/make_evad.sh index 373fc4a6..377c6e7b 100755 --- a/hyp_utils/feats/make_evad.sh +++ b/hyp_utils/feats/make_evad.sh @@ -84,11 +84,42 @@ if [ -f $data/segments ]; then opt_args="${opt_args} --segments $data/segments" fi +set +e $cmd JOB=1:$nj $logdir/make_vad_${name}.JOB.log \ hyp_utils/conda_env.sh \ compute_energy_vad.py --cfg $vad_config $opt_args \ --input $scp --output ark,scp:$vaddir/vad_$name.JOB.ark,$vaddir/vad_$name.JOB.scp \ - --part-idx JOB --num-parts $nj || exit 1 + --part-idx JOB --num-parts $nj +set -e + +# rerun not successful jobs +for tmp in {1..3};do + pids="" + + for((i=1;i<=$nj;i++)) + do + status=$(tail -n 1 $logdir/make_vad_${name}.$i.log | \ + awk '/status 0/ { print 0} + !/status 0/ { print 1}') + if [ $status -eq 1 ];then + echo "JOB $i failed, resubmitting" + sleep 10 + opt_args=`echo ${opt_args} | sed -e "s/utt2num_frames.JOB/utt2num_frames.$i/g"` + $cmd $logdir/make_vad_${name}.$i.log \ + hyp_utils/conda_env.sh \ + compute_energy_vad.py --cfg $vad_config $opt_args \ + --input $scp --output ark,scp:$vaddir/vad_$name.$i.ark,$vaddir/vad_$name.$i.scp \ + --part-idx $i --num-parts $nj & + opt_args=`echo ${opt_args} | sed -e "s/utt2num_frames.$i/utt2num_frames.JOB/g"` + pids="$pids $!" + fi + done + + for pid in $pids;do + wait $pid + done +done +wait # concatenate the .scp files together. 
for n in $(seq $nj); do diff --git a/hyp_utils/kaldi/utils/fix_data_dir.sh b/hyp_utils/kaldi/utils/fix_data_dir.sh index bb18e07b..ed080eee 100755 --- a/hyp_utils/kaldi/utils/fix_data_dir.sh +++ b/hyp_utils/kaldi/utils/fix_data_dir.sh @@ -117,7 +117,7 @@ function filter_speakers { ${kaldi_utils}/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers - for s in cmvn.scp spk2gender spk2nation; do + for s in cmvn.scp spk2gender; do f=$data/$s if [ -f $f ]; then filter_file $f $tmpdir/speakers @@ -127,7 +127,7 @@ function filter_speakers { filter_file $tmpdir/speakers $data/spk2utt ${kaldi_utils}/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk - for s in cmvn.scp spk2gender spk2nation $spk_extra_files; do + for s in cmvn.scp spk2gender $spk_extra_files; do f=$data/$s if [ -f $f ]; then filter_file $tmpdir/speakers $f diff --git a/hyperion/bin/torch-extract-xvectors-from-wav.py b/hyperion/bin/torch-extract-xvectors-from-wav.py new file mode 100644 index 00000000..90969722 --- /dev/null +++ b/hyperion/bin/torch-extract-xvectors-from-wav.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python +""" + Copyright 2019 Jesus Villalba (Johns Hopkins University) + Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +""" + +import sys +import os +from jsonargparse import ( + ArgumentParser, + ActionConfigFile, + ActionParser, + namespace_to_dict, +) +import time +import logging + +import numpy as np +import pandas as pd + +import torch + +from hyperion.hyp_defs import config_logger, float_cpu, set_float_cpu +from hyperion.utils import Utt2Info +from hyperion.io import DataWriterFactory as DWF +from hyperion.io import SequentialAudioReader as AR +from hyperion.io import VADReaderFactory as VRF +from hyperion.augment import SpeechAugment + +from hyperion.torch.utils import open_device +from hyperion.torch.utils import dinossl +from hyperion.torch.narchs import AudioFeatsMVN as AF +from hyperion.torch import TorchModelLoader as TML + + +def init_device(use_gpu): + set_float_cpu("float32") + num_gpus = 1 if use_gpu else 0 + logging.info("initializing devices num_gpus={}".format(num_gpus)) + device = open_device(num_gpus=num_gpus) + return device + + +def init_feats(device, **kwargs): + feat_args = AF.filter_args(**kwargs["feats"]) + logging.info("feat args={}".format(feat_args)) + logging.info("initializing feature extractor") + feat_extractor = AF(trans=False, **feat_args) + logging.info("feat-extractor={}".format(feat_extractor)) + feat_extractor.eval() + feat_extractor.to(device) + return feat_extractor + + +def load_model(model_path, device, state_dict_key='model_state_dict', dinossl_kwargs=None): + logging.info("loading model {}".format(model_path)) + model = TML.load(model_path, state_dict_key=state_dict_key ,dinossl_kwargs=dinossl_kwargs) + logging.info("xvector-model={}".format(model)) + model.to(device) + model.eval() + return model + + +def augment(key0, x0, augmenter, aug_df, aug_id): + if augmenter is None: + x = x0 + key = key0 + else: + x, aug_info = augmenter(x0) + key = "%s-aug-%02d" % (key0, aug_id) + aug_df_row = { + "key_aug": key, + "key_orig": key0, + "noise_type": aug_info["noise"]["noise_type"], + "snr": aug_info["noise"]["snr"], + "rir_type": aug_info["reverb"]["rir_type"], + "srr": aug_info["reverb"]["srr"], + "sdr": aug_info["sdr"], + } + + aug_df.append(pd.DataFrame(aug_df_row, index=[0])) + + return key, x + + +def select_random_chunk(key, x, min_utt_length, max_utt_length, rng): + utt_length = rng.randint(low=min_utt_length, 
high=max_utt_length + 1) + if utt_length < x.shape[1]: + first_frame = rng.randint(low=0, high=x.shape[1] - utt_length) + x = x[:, first_frame : first_frame + utt_length] + logging.info( + "extract-random-utt %s of length=%d first-frame=%d" + % (key, x.shape[1], first_frame) + ) + return x + + +def extract_xvectors( + input_spec, + output_spec, + vad_spec, + write_num_frames_spec, + scp_sep, + vad_path_prefix, + model_path, + chunk_length, + embed_layer, + random_utt_length, + min_utt_length, + max_utt_length, + aug_cfg, + num_augs, + aug_info_path, + use_gpu, + **kwargs +): + + rng = np.random.RandomState(seed=1123581321 + kwargs["part_idx"]) + device = init_device(use_gpu) + feat_extractor = init_feats(device, **kwargs) + if kwargs['dinossl']: + dinossl_kwargs={k:kwargs[k] for k in kwargs if 'dinossl' in k} + model = load_model(model_path, device, state_dict_key=kwargs['state_dict_key'], dinossl_kwargs=dinossl_kwargs) + else: + model = load_model(model_path, device, state_dict_key=kwargs['state_dict_key']) + + if write_num_frames_spec is not None: + keys = [] + info = [] + + if aug_cfg is not None: + augmenter = SpeechAugment.create(aug_cfg, rng=rng) + aug_df = [] + else: + augmenter = None + aug_df = None + num_augs = 1 + + ar_args = AR.filter_args(**kwargs) + logging.info("opening output stream: %s" % (output_spec)) + with DWF.create(output_spec, scp_sep=scp_sep) as writer: + + logging.info( + "opening input stream: {} with args={}".format(input_spec, ar_args) + ) + with AR(input_spec, **ar_args) as reader: + + if vad_spec is not None: + logging.info("opening VAD stream: %s" % (vad_spec)) + v_reader = VRF.create( + vad_spec, path_prefix=vad_path_prefix, scp_sep=scp_sep + ) + + while not reader.eof(): + t1 = time.time() + key, x0, fs = reader.read(1) + if len(key) == 0: + break + + x0 = x0[0] + key0 = key[0] + t2 = time.time() + + logging.info("processing utt %s" % (key0)) + for aug_id in range(num_augs): + t3 = time.time() + key, x = augment(key0, x0, augmenter, aug_df, aug_id) + t4 = time.time() + with torch.no_grad(): + x = torch.tensor( + x[None, :], dtype=torch.get_default_dtype() + ).to(device) + + x = feat_extractor(x) + t5 = time.time() + tot_frames = x.shape[1] + if vad_spec is not None: + vad = v_reader.read(key0, num_frames=tot_frames)[0] + vad = torch.tensor(vad, dtype=torch.bool).to(device) + x = x[:, vad] + + logging.info( + "utt %s detected %d/%d (%.2f %%) speech frames" + % ( + key, + x.shape[1], + tot_frames, + x.shape[1] / tot_frames * 100, + ) + ) + + if random_utt_length: + x = select_random_chunk( + key, x, min_utt_length, max_utt_length, rng + ) + + t6 = time.time() + if x.shape[1] == 0: # JJ: EXP - this case is not taken care of for the dinossl case + y = np.zeros((model.embed_dim,), dtype=float_cpu()) + else: + x = x.transpose(1, 2).contiguous() + y = ( + model.extract_embed( + x, + chunk_length=chunk_length, + embed_layer=embed_layer, + ) + .cpu() + .numpy()[0] + ) + + t7 = time.time() + writer.write([key], [y]) + if write_num_frames_spec is not None: + keys.append(key) + info.append(str(x.shape[1])) + + t8 = time.time() + read_time = t2 - t1 + tot_time = read_time + t8 - t3 + logging.info( + ( + "utt %s total-time=%.3f read-time=%.3f " + "aug-time=%.3f feat-time=%.3f " + "vad-time=%.3f embed-time=%.3f write-time=%.3f " + "rt-factor=%.2f" + ) + % ( + key, + tot_time, + read_time, + t4 - t3, + t5 - t4, + t6 - t5, + t7 - t6, + t8 - t7, + x0.shape[0] / fs[0] / tot_time, + ) + ) + + if write_num_frames_spec is not None: + logging.info("writing num-frames to %s" % 
(write_num_frames_spec)) + u2nf = Utt2Info.create(keys, info) + u2nf.save(write_num_frames_spec) + + if aug_info_path is not None: + aug_df = pd.concat(aug_df, ignore_index=True) + aug_df.to_csv(aug_info_path, index=False, na_rep="n/a") + + +if __name__ == "__main__": + + parser = ArgumentParser( + description=( + "Extracts x-vectors from waveform computing " "acoustic features on the fly" + ) + ) + + parser.add_argument("--cfg", action=ActionConfigFile) + parser.add_argument("--dinossl_cfg", action=ActionConfigFile) + parser.add_argument("--input", dest="input_spec", required=True) + parser.add_argument("--vad", dest="vad_spec", default=None) + parser.add_argument( + "--write-num-frames", dest="write_num_frames_spec", default=None + ) + parser.add_argument("--scp-sep", default=" ", help=("scp file field separator")) + parser.add_argument( + "--vad-path-prefix", default=None, help=("scp file_path prefix for vad") + ) + + AR.add_class_args(parser) + + parser.add_argument("--aug-cfg", default=None) + parser.add_argument("--aug-info-path", default=None) + parser.add_argument( + "--num-augs", default=1, type=int, help="number of augmentations per utterance" + ) + + AF.add_class_args(parser, prefix="feats") + + parser.add_argument("--model-path", required=True) + parser.add_argument( + "--chunk-length", + type=int, + default=0, + help=( + "number of frames used in each forward pass " + "of the x-vector encoder," + "if 0 the full utterance is used" + ), + ) + parser.add_argument( + "--embed-layer", + type=int, + default=None, + help=( + "classifier layer to get the embedding from, " + "if None, it uses layer set in training phase" + ), + ) + + parser.add_argument( + "--random-utt-length", + default=False, + action="store_true", + help="calculates x-vector from a random chunk", + ) + parser.add_argument( + "--min-utt-length", + type=int, + default=500, + help=("minimum utterance length when using random utt length"), + ) + parser.add_argument( + "--max-utt-length", + type=int, + default=12000, + help=("maximum utterance length when using random utt length"), + ) + + parser.add_argument("--output", dest="output_spec", required=True) + parser.add_argument( + "--use-gpu", default=False, action="store_true", help="extract xvectors in gpu" + ) + parser.add_argument( + "-v", "--verbose", dest="verbose", default=1, choices=[0, 1, 2, 3], type=int + ) + # dinossl related + parser.add_argument('--state_dict_key', type=str, default='model_state_dict', + choices=['model_state_dict','model_teacher_state_dict'], + help=('key for state_dict of a pre-trained model. Currently model_teacher_state_dict is only possible for dinossl')) + parser.add_argument('--dinossl_xvec_loc', type=str, default='f', + choices=['f', 'dinohead_mlp','dinohead_l2norm','dinohead_linear'], + help=('Where to extract x-vectors from the dinossl model. 
The naming follows Figure 9 in the DINO paper')) + dinossl.add_dinossl_args(parser) + + args = parser.parse_args() + config_logger(args.verbose) + del args.verbose + logging.debug(args) + + extract_xvectors(**namespace_to_dict(args)) \ No newline at end of file diff --git a/hyperion/bin/train_xvector_from_wav.py b/hyperion/bin/train_xvector_from_wav.py index 0e074977..89a281af 100755 --- a/hyperion/bin/train_xvector_from_wav.py +++ b/hyperion/bin/train_xvector_from_wav.py @@ -213,5 +213,5 @@ def make_parser(xvec_class): args_sc.xvec_class = xvec_dict[xvec_type] # torch docs recommend using forkserver - multiprocessing.set_start_method("forkserver") + multiprocessing.set_start_method("forkserver",force=True) train_xvec(gpu_id, args_sc) diff --git a/hyperion/bin/train_xvector_from_wav_dinossl.py b/hyperion/bin/train_xvector_from_wav_dinossl.py new file mode 100755 index 00000000..f89496ad --- /dev/null +++ b/hyperion/bin/train_xvector_from_wav_dinossl.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python +""" + Copyright 2020 Johns Hopkins University (Author: Jesus Villalba) + Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +""" +import sys +import os +from pathlib import Path +from jsonargparse import ( + ArgumentParser, + ActionConfigFile, + ActionParser, + namespace_to_dict, +) +import time +import logging +import multiprocessing + +import torch + +from hyperion.hyp_defs import config_logger, set_float_cpu +from hyperion.torch.utils import ddp +from hyperion.torch.utils import dinossl +from hyperion.torch.trainers import DINOSSLXVectorTrainerFromWav as Trainer +from hyperion.torch.data import AudioDataset as AD +from hyperion.torch.data import ClassWeightedSeqSampler as Sampler +from hyperion.torch.narchs import AudioFeatsMVN as AF +from hyperion.torch.models import ResNetXVector as RXVec +from hyperion.torch.models import ResNet1dXVector as R1dXVec +from hyperion.torch.models import EfficientNetXVector as EXVec +from hyperion.torch.models import TDNNXVector as TDXVec +from hyperion.torch.models import TransformerXVectorV1 as TFXVec +from hyperion.torch.models import SpineNetXVector as SpineXVec + +xvec_dict = { + "resnet": RXVec, + "resnet1d": R1dXVec, + "efficientnet": EXVec, + "tdnn": TDXVec, + "transformer": TFXVec, + "spinenet": SpineXVec, +} + + +def init_data(partition, rank, num_gpus, **kwargs): + if kwargs["dinossl"]: + dinossl_args = dinossl.filter_args(**kwargs) + kwargs = kwargs["data"][partition] + ad_args = AD.filter_args(**kwargs["dataset"]) + sampler_args = Sampler.filter_args(**kwargs["sampler"]) + if rank == 0: + logging.info("{} audio dataset args={}".format(partition, ad_args)) + logging.info("{} sampler args={}".format(partition, sampler_args)) + logging.info("init %s dataset", partition) + + ad_args["is_val"] = partition == "val" + if dinossl_args["dinossl"]: + dataset = AD(**ad_args,dinossl_chunk_len_mult=dinossl_args["dinossl_chunk_len_mult"], dinossl_n_chunks=dinossl_args["dinossl_local_crops_number"] + 2, dinossl_reduce_overlap_prob=dinossl_args["dinossl_reduce_overlap_prob"]) + else: + dataset = AD(**ad_args) + + if rank == 0: + logging.info("init %s samplers", partition) + + sampler = Sampler(dataset, **sampler_args) + + if rank == 0: + logging.info("init %s dataloader", partition) + + num_workers = kwargs["data_loader"]["num_workers"] + num_workers_per_gpu = int((num_workers + num_gpus - 1) / num_gpus) + largs = ( + {"num_workers": num_workers_per_gpu, "pin_memory": True} if num_gpus > 0 else {} + ) + data_loader = torch.utils.data.DataLoader(dataset, 
batch_sampler=sampler, **largs) + return data_loader + + +# def init_data( +# audio_path, +# train_list, +# val_list, +# train_aug_cfg, +# val_aug_cfg, +# num_workers, +# num_gpus, +# rank, +# **kwargs +# ): + +# ad_args = AD.filter_args(**kwargs) +# sampler_args = Sampler.filter_args(**kwargs) +# if rank == 0: +# logging.info("audio dataset args={}".format(ad_args)) +# logging.info("sampler args={}".format(sampler_args)) +# logging.info("init datasets") + +# train_data = AD(audio_path, train_list, aug_cfg=train_aug_cfg, **ad_args) +# val_data = AD(audio_path, val_list, aug_cfg=val_aug_cfg, is_val=True, **ad_args) + +# if rank == 0: +# logging.info("init samplers") +# train_sampler = Sampler(train_data, **sampler_args) +# val_sampler = Sampler(val_data, **sampler_args) + +# num_workers_per_gpu = int((num_workers + num_gpus - 1) / num_gpus) +# largs = ( +# {"num_workers": num_workers_per_gpu, "pin_memory": True} if num_gpus > 0 else {} +# ) + +# train_loader = torch.utils.data.DataLoader( +# train_data, batch_sampler=train_sampler, **largs +# ) + +# test_loader = torch.utils.data.DataLoader( +# val_data, batch_sampler=val_sampler, **largs +# ) + +# return train_loader, test_loader + + +def init_feats(rank, **kwargs): + feat_args = AF.filter_args(**kwargs["feats"]) + if rank == 0: + logging.info("feat args={}".format(feat_args)) + logging.info("initializing feature extractor") + feat_extractor = AF(trans=True, **feat_args) + if rank == 0: + logging.info("feat-extractor={}".format(feat_extractor)) + return feat_extractor + + +def init_xvector(num_classes, rank, xvec_class, **kwargs): + xvec_args = xvec_class.filter_args(**kwargs["model"]) + if rank == 0: + logging.info("xvector network args={}".format(xvec_args)) + xvec_args["num_classes"] = num_classes + model = xvec_class(**xvec_args) + if rank == 0: + logging.info("x-vector-model={}".format(model)) + return model + + +def train_xvec(gpu_id, args): + + config_logger(args.verbose) + del args.verbose + logging.debug(args) + + kwargs = namespace_to_dict(args) + torch.manual_seed(args.seed) + set_float_cpu("float32") + + ddp_args = ddp.filter_ddp_args(**kwargs) + device, rank, world_size = ddp.ddp_init(gpu_id, **ddp_args) + kwargs["rank"] = rank + + train_loader = init_data(partition="train", **kwargs) + val_loader = init_data(partition="val", **kwargs) if not kwargs["dinossl"] else None + feat_extractor = init_feats(**kwargs) + model = init_xvector(train_loader.dataset.num_classes, **kwargs) + loss = None + if kwargs["dinossl"]: + dinossl_args = dinossl.filter_args(**kwargs) + model, loss = dinossl.init_dino(model, dinossl_args, rank = rank) + + trn_args = Trainer.filter_args(**kwargs["trainer"]) + trn_args["niter_per_ep"] = len(train_loader) # will be used for DINO-related scheduling + trn_args["batch_size"] = kwargs["data"]["train"]["sampler"]["batch_size"] * kwargs["num_gpus"] + if rank == 0: + logging.info("trainer args={}".format(trn_args)) + metrics = {} + trainer = Trainer( + model, + feat_extractor, + device=device, + metrics=metrics, + ddp=world_size > 1, + loss=loss, + **trn_args, + ) + trainer.load_last_checkpoint() + trainer.fit(train_loader, val_loader) + + ddp.ddp_cleanup() + + +def make_parser(xvec_class): + parser = ArgumentParser() + + parser.add_argument("--cfg", action=ActionConfigFile) + + train_parser = ArgumentParser(prog="") + # parser.add_argument("--audio-path", required=True) + # parser.add_argument("--train-list", required=True) + # parser.add_argument("--val-list", required=True) + + AD.add_class_args(train_parser, 
prefix="dataset", skip={}) + Sampler.add_class_args(train_parser, prefix="sampler") + # parser.add_argument("--train-aug-cfg", default=None) + # parser.add_argument("--val-aug-cfg", default=None) + train_parser.add_argument( + "--data_loader.num-workers", + type=int, + default=5, + help="num_workers of data loader", + ) + + val_parser = ArgumentParser(prog="") + AD.add_class_args(val_parser, prefix="dataset", skip={}) + Sampler.add_class_args(val_parser, prefix="sampler") + val_parser.add_argument( + "--data_loader.num-workers", + type=int, + default=5, + help="num_workers of data loader", + ) + data_parser = ArgumentParser(prog="") + data_parser.add_argument("--train", action=ActionParser(parser=train_parser)) + data_parser.add_argument("--val", action=ActionParser(parser=val_parser)) + parser.add_argument("--data", action=ActionParser(parser=data_parser)) + parser.link_arguments( + "data.train.dataset.class_file", "data.val.dataset.class_file" + ) + parser.link_arguments( + "data.train.data_loader.num_workers", "data.val.data_loader.num_workers" + ) + parser.link_arguments( + "data.train.sampler.batch_size", "data.val.sampler.batch_size" + ) + + AF.add_class_args(parser, prefix="feats") + xvec_class.add_class_args(parser, prefix="model") + Trainer.add_class_args( + parser, prefix="trainer", train_modes=xvec_class.valid_train_modes() + ) + dinossl.add_dinossl_args(parser) + ddp.add_ddp_args(parser) + parser.add_argument("--seed", type=int, default=1123581321, help="random seed") + # parser.add_argument( + # "--resume", + # action="store_true", + # default=False, + # help="resume training from checkpoint", + # ) + parser.add_argument( + "-v", "--verbose", dest="verbose", default=1, choices=[0, 1, 2, 3], type=int + ) + + return parser + + +if __name__ == "__main__": + + parser = ArgumentParser(description="Train XVector from audio files") + + parser.add_argument("--cfg", action=ActionConfigFile) + + subcommands = parser.add_subcommands() + + for k, v in xvec_dict.items(): + parser_k = make_parser(v) + subcommands.add_subcommand(k, parser_k) + + args = parser.parse_args() + try: + gpu_id = int(os.environ["LOCAL_RANK"]) + except: + gpu_id = 0 + + xvec_type = args.subcommand + args_sc = vars(args)[xvec_type] + + if gpu_id == 0: + try: + config_file = Path(args_sc.trainer.exp_path) / "config.yaml" + parser.save(args, str(config_file), format="yaml", overwrite=True) + except: + pass + + args_sc.xvec_class = xvec_dict[xvec_type] + # torch docs recommend using forkserver + multiprocessing.set_start_method("forkserver",force=True) + train_xvec(gpu_id, args_sc) diff --git a/hyperion/torch/data/audio_dataset.py b/hyperion/torch/data/audio_dataset.py index 439c00ba..cb02c971 100644 --- a/hyperion/torch/data/audio_dataset.py +++ b/hyperion/torch/data/audio_dataset.py @@ -468,6 +468,9 @@ def __init__( target_sample_freq=None, wav_scale=2 ** 15 - 1, is_val=False, + dinossl_chunk_len_mult=None, + dinossl_n_chunks=None, + dinossl_reduce_overlap_prob=0 ): super().__init__() @@ -514,6 +517,19 @@ def __init__( ) self.return_orig = return_orig + # start dino-stuff persephone_dinossl + # dinossl related + # self.dinossl_chunk_len_mult = dinossl_chunk_len_mult + # self.dinossl_n_chunks = dinossl_n_chunks + # self.dinossl_reduce_overlap_prob = dinossl_reduce_overlap_prob + + # self._prepare_class_info(class_file) + + # if max_chunk_length is None: + # max_chunk_length = min_chunk_length + # self._min_chunk_length = min_chunk_length + # self._max_chunk_length = max_chunk_length + # end dinostuff ======= 
self.num_augs = num_augs self._create_augmenters(aug_cfgs) @@ -589,6 +605,84 @@ def min_seq_length(self): def max_seq_length(self): return np.max(self.seq_lengths) + # start dino stuff <<<<<<< persephone_dinossl + # def _prune_short_seqs(self, min_length): + # if self.rank == 0: + # logging.info("pruning short seqs") + # keep_idx = self.seq_lengths >= min_length + # self.u2c = self.u2c.filter_index(keep_idx) + # self._seq_lengths = self.seq_lengths[keep_idx] + # if self.rank == 0: + # logging.info( + # "pruned seqs with min_length < %f," + # "keep %d/%d seqs" % (min_length, self.num_seqs, len(keep_idx)) + # ) + + # def _prepare_class_info(self, class_file): + # class_weights = None + # if class_file is None: + # classes, class_idx = np.unique(self.u2c.info, return_inverse=True) + # class2idx = {k: i for i, k in enumerate(classes)} + # else: + # if self.rank == 0: + # logging.info("reading class-file %s" % (class_file)) + # class_info = pd.read_csv(class_file, header=None, sep=" ") + # class2idx = {str(k): i for i, k in enumerate(class_info[0])} + # class_idx = np.array([class2idx[k] for k in self.u2c.info], dtype=int) + # if class_info.shape[1] == 2: + # class_weights = np.array(class_info[1]).astype( + # floatstr_torch(), copy=False + # ) + # + # self.num_classes = len(class2idx) + + # class2utt_idx = {} + # class2num_utt = np.zeros((self.num_classes,), dtype=int) + + # for k in range(self.num_classes): + # idx = (class_idx == k).nonzero()[0] + # class2utt_idx[k] = idx + # class2num_utt[k] = len(idx) + # if class2num_utt[k] == 0: + # if (not self.is_val) and (self.dinossl_chunk_len_mult is None): + # logging.warning("class %d doesn't have any samples" % (k)) + # if class_weights is None: + # class_weights = np.ones((self.num_classes,), dtype=floatstr_torch()) + # class_weights[k] = 0 + + # count_empty = np.sum(class2num_utt == 0) + # if count_empty > 0: + # logging.warning("%d classes have 0 samples" % (count_empty)) + + # self.utt_idx2class = class_idx + # self.class2utt_idx = class2utt_idx + # self.class2num_utt = class2num_utt + # if class_weights is not None: + # class_weights /= np.sum(class_weights) + # class_weights = torch.Tensor(class_weights) + # self.class_weights = class_weights + + # if self.short_seq_exist: + # # if there are seq shorter than max_chunk_lenght we need some extra variables + # # we will need class_weights to put to 0 classes that have all utts shorter than the batch chunk length + # if self.class_weights is None: + # self.class_weights = torch.ones((self.num_classes,)) + + # # we need the max length of the utterances of each class + # class2max_length = torch.zeros((self.num_classes,), dtype=torch.float) + # for c in range(self.num_classes): + # if class2num_utt[c] > 0: + # class2max_length[c] = np.max( + # self.seq_lengths[self.class2utt_idx[c]] + # ) + # + # self.class2max_length = class2max_length + + + #def _seq_shorter_than_max_length_exists(self, max_length): + # return np.any(self.seq_lengths < max_length) + #end dinostuff + @property def num_classes(self): return {k: t.num_classes for k, t in self.class_info.items()} @@ -602,6 +696,12 @@ def _parse_segment_item(self, segment): f"chunk duration ({duration})" ) else: + # start dino stuff persephone_dinossl + # if self.dinossl_n_chunks == None: + # return self._get_random_chunk(index) + # else: # multi-chunks for dinossl + # return self._get_random_chunks(index) + # end dino stuff seg_id, start, duration = segment, 0, 0 if "start" in self.seg_set: @@ -695,14 +795,177 @@ def __getitem__(self, segment): 
r.append(x_orig) else: + # start dinostuff persephone_dinossl + # chunk_length = self.max_chunk_length + + #key = self.u2c.key[index] + + #full_seq_length = self.seq_lengths[index] + #assert ( + # chunk_length <= full_seq_length + #), "chunk_length(%d) <= full_seq_length(%d)" % (chunk_length, full_seq_length) + + #time_offset = torch.rand(size=(1,)).item() * (full_seq_length - chunk_length) + #reverb_context = min(self.reverb_context, time_offset) + #time_offset -= reverb_context + #read_chunk_length = chunk_length + reverb_context + + # logging.info('get-random-chunk {} {} {} {} {}'.format(index, key, time_offset, chunk_length, full_seq_length )) + #x, fs = self.r.read([key], time_offset=time_offset, time_durs=read_chunk_length) + + #x = x[0] + #fs = fs[0] + + #x_clean = x + #if self.augmenter is not None: + # chunk_length_samples = int(chunk_length * fs) + # end_idx = len(x) + # reverb_context_samples = end_idx - chunk_length_samples + # assert reverb_context_samples >= 0, ( + # "key={} time-offset={}, read-chunk={} " + # "read-x-samples={}, chunk_samples={}, reverb_context_samples={}" + # ).format( + # key, + # time_offset, + # read_chunk_length, + # end_idx, + # chunk_length_samples, + # reverb_context_samples, + # ) + # # end_idx = reverb_context_samples + chunk_length_samples + # x, aug_info = self.augmenter(x) + # x = x[reverb_context_samples:end_idx] + # if self.return_clean_aug_pair: + # x_clean = x_clean[reverb_context_samples:end_idx] + # x_clean = x_clean.astype(floatstr_torch(), copy=False) + # # x_clean = x_clean[reverb_context_samples:] + # # logging.info('augmentation x-clean={}, x={}, aug_info={}'.format( + # # x_clean.shape, x.shape, aug_info)) + ## if len(x) != 64000: + ## logging.info('x!=4s, {} {} {} {} {} {} {} {}'.format(len(x),reverb_context, reverb_context_samples, chunk_length, chunk_length_samples, end_idx, fs, read_chunk_length)) + + ## if len(x) != 64000: + ## logging.info('x!=4s-2, {} {} {} {}'.format(len(x), chunk_length, fs, read_chunk_length)) + + #if self.transpose_input: + # x = x[None, :] + # if self.return_clean_aug_pair: + # x_clean = x_clean[None, :] + + #x = x.astype(floatstr_torch(), copy=False) + #if self.return_clean_aug_pair: + # r = x, x_clean + #else: + # r = (x,) + # end dinostuff r = [x] + # adds the segment labels seg_info = self._get_segment_info(seg_id) r.extend(seg_info) return (*r,) + def _get_random_chunks(self, index): + + if len(index) == 2: + index, chunk_length = index + else: + chunk_length = self.max_chunk_length + key = self.u2c.key[index] + + full_seq_length = self.seq_lengths[index] + assert chunk_length <= full_seq_length, 'chunk_length(%d) <= full_seq_length(%d)' % ( + chunk_length, full_seq_length) + + chunk_length_list = [] + # 2 long chunks + if chunk_length * self.dinossl_chunk_len_mult > full_seq_length: + chunk_length_list.extend([full_seq_length]*2) + else: + chunk_length_list.extend([chunk_length * self.dinossl_chunk_len_mult]*2) + # self.n_chunks - 2 short chunks + chunk_length_list.extend([chunk_length]*(self.dinossl_n_chunks-2)) + + r_list = [] # this is for dino's multiple augmentations (more than once) of a given sample + + # to reduce overlap between 2 long chunks + reduce_overlap = (self.dinossl_reduce_overlap_prob > torch.rand(size=(1,))) + if reduce_overlap: + long_chunk_proc_cnt = 0 + tmp = torch.rand(size=(5,))*(full_seq_length - chunk_length_list[0]) + time_offset_long_chunks = [torch.min(tmp), torch.max(tmp)] + + for chunk_length in chunk_length_list: # full_seq_length, self.reverb_context are fixed 
within this for loop + if reduce_overlap and (long_chunk_proc_cnt < 2): + time_offset = time_offset_long_chunks[long_chunk_proc_cnt] + long_chunk_proc_cnt += 1 + else: + time_offset = torch.rand(size=(1,)).item()*(full_seq_length-chunk_length) + reverb_context = min(self.reverb_context, time_offset) + time_offset -= reverb_context + read_chunk_length = chunk_length + reverb_context + + #logging.info('get-random-chunk {} {} {} {} {}'.format(index, key, time_offset, chunk_length, full_seq_length )) + x, fs = self.r.read([key], time_offset=time_offset, + time_durs=read_chunk_length) + + x = x[0] + fs = fs[0] + + x_clean = x + if self.augmenter is not None: + chunk_length_samples = int(chunk_length * fs) + end_idx = len(x) + reverb_context_samples = end_idx - chunk_length_samples + assert reverb_context_samples >= 0, ( + ('key={} time-offset={}, read-chunk={} ' + 'read-x-samples={}, chunk_samples={}, reverb_context_samples={}').format( + key, time_offset, read_chunk_length, + end_idx, chunk_length_samples, reverb_context_samples)) + # end_idx = reverb_context_samples + chunk_length_samples + x, aug_info = self.augmenter(x) + x = x[reverb_context_samples:end_idx] + if self.return_clean_aug_pair: + x_clean = x_clean[reverb_context_samples:end_idx] + x_clean = x_clean.astype(floatstr_torch(), copy=False) + #x_clean = x_clean[reverb_context_samples:] + #logging.info('augmentation x-clean={}, x={}, aug_info={}'.format( + # x_clean.shape, x.shape, aug_info)) + # if len(x) != 64000: + # logging.info('x!=4s, {} {} {} {} {} {} {} {}'.format(len(x),reverb_context, reverb_context_samples, chunk_length, chunk_length_samples, end_idx, fs, read_chunk_length)) + + # if len(x) != 64000: + # logging.info('x!=4s-2, {} {} {} {}'.format(len(x), chunk_length, fs, read_chunk_length)) + + if self.transpose_input: + x = x[None,:] + if self.return_clean_aug_pair: + x_clean = x_clean[None,:] + + x = x.astype(floatstr_torch(), copy=False) + if self.return_clean_aug_pair: + r = x, x_clean + else: + r = (x,) + r_list.append(*r) + + if len(r_list) == 1: del r_list + + if not self.return_class: + try: + return r_list + except: + return r + + class_idx = self.utt_idx2class[index] + try: + r = r_list, class_idx + except: + r = *r, class_idx + + return r @staticmethod def filter_args(**kwargs): diff --git a/hyperion/torch/lr_schedulers/factory.py b/hyperion/torch/lr_schedulers/factory.py index 3fef6e93..24e85551 100644 --- a/hyperion/torch/lr_schedulers/factory.py +++ b/hyperion/torch/lr_schedulers/factory.py @@ -39,6 +39,7 @@ def create( d_model=None, lr_factor=1, update_lr_on_opt_step=False, + **kwargs ): """Creates a learning rate scheduler object. @@ -195,6 +196,12 @@ def filter_args(**kwargs): "lr_factor", "d_model", "update_lr_on_opt_step", + "dinossl_lr", + "dinossl_min_lr", + "dinossl_warmup_epochs", + "dinossl_weight_decay", + "dinossl_weight_decay_end", + "dinossl_momentum_teacher" ) return dict((k, kwargs[k]) for k in valid_args if k in kwargs) @@ -211,6 +218,7 @@ def add_class_args(parser, prefix=None): default="none", choices=[ "none", + "dinossl", "exp_lr", "invpow_lr", "cos_lr", @@ -340,6 +348,29 @@ def add_class_args(parser, prefix=None): action="store_true", help=("Update lr based on batch number instead of epoch number"), ) + # dinossl related - start + parser.add_argument("--dinossl_lr", default=0.005, type=float, + help=("""Learning rate at the end of linear warmup (highest LR used during training). 
+ The learning rate is linearly scaled with the batch size, and specified here for a + reference batch size of 256.""")) + parser.add_argument("--dinossl_min_lr" , + default=1e-6, type=float, + help=("Target LR at the end of optimization. We use a cosine LR schedule with linear warmup.")) + parser.add_argument("--dinossl_warmup_epochs" , + default=10, type=int, + help=("Number of epochs for the linear learning-rate warm up.")) + parser.add_argument("--dinossl_weight_decay" , + default=0.04, type=float, + help=("Initial value of the weight decay. With ViT, a smaller value at the beginning of training works well.")) + parser.add_argument("--dinossl_weight_decay_end" , + default=0.4, type=float, + help=("""Final value of the weight decay. We use a cosine schedule for WD and using a larger decay by + the end of training improves performance for ViTs.""")) + parser.add_argument("--dinossl_momentum_teacher" , + default=0, type=float, + help=("""Base EMA parameter for teacher update. The value is increased to 1 during training with cosine schedule. + We recommend setting a higher value with small batches: for example use 0.9995 with batch size of 256.""")) + # dinossl related - end if prefix is not None: outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser)) diff --git a/hyperion/torch/models/xvectors/resnet_xvector.py b/hyperion/torch/models/xvectors/resnet_xvector.py index fe88ff57..ecc64773 100644 --- a/hyperion/torch/models/xvectors/resnet_xvector.py +++ b/hyperion/torch/models/xvectors/resnet_xvector.py @@ -187,6 +187,10 @@ def load(cls, file_path=None, cfg=None, state_dict=None): model = cls(**cfg) if state_dict is not None: + # (hacky coding. may change later for neater codes) + is_classif_net_out = len([ True for k in state_dict.keys() if k[:18] == 'classif_net.output']) # to check this is to know if the training was dinossl or not + if not is_classif_net_out: + model.classif_net.output = nn.Identity() model.load_state_dict(state_dict) return model diff --git a/hyperion/torch/models/xvectors/xvector.py b/hyperion/torch/models/xvectors/xvector.py index 15f0ce86..535cb835 100644 --- a/hyperion/torch/models/xvectors/xvector.py +++ b/hyperion/torch/models/xvectors/xvector.py @@ -222,7 +222,7 @@ def _pre_enc(self, x): return x def _post_enc(self, x, in_lengths=None, max_in_length=None): - if self.encoder_net.out_dim() == 4: + if self.encoder_net.out_dim() == 4 and (not isinstance(self.classif_net,torch.nn.modules.linear.Linear)): x = x.view(x.size(0), -1, x.size(-1)) if self.proj is not None: @@ -286,7 +286,10 @@ class logits tensor with shape=(batch, num_classes). 
x = self.encoder_net(x) x, x_lengths = self._post_enc(x, x_lengths, max_in_length) p = self.pool_net(x, x_lengths=x_lengths) - y = self.classif_net(p, y) + if isinstance(self.classif_net.output,nn.modules.linear.Identity): # for dino + y = self.classif_net(p) + else: + y = self.classif_net(p, y) return y def forward_hid_feats( @@ -572,6 +575,7 @@ def rebuild_output_layer( intertop_margin=0.0, num_subcenters=2, ): + if ( (self.num_classes is not None and self.num_classes != num_classes) or (self.loss_type != loss_type) diff --git a/hyperion/torch/narchs/classif_head.py b/hyperion/torch/narchs/classif_head.py index adfeceb3..9416d022 100644 --- a/hyperion/torch/narchs/classif_head.py +++ b/hyperion/torch/narchs/classif_head.py @@ -272,7 +272,7 @@ def forward(self, x, y=None): for l in range(self.num_embed_layers): x = self.fc_blocks[l](x) - if self.loss_type == "softmax": + if self.loss_type == "softmax" or isinstance(self.output,nn.modules.linear.Identity): y = self.output(x) else: y = self.output(x, y) diff --git a/hyperion/torch/optim/factory.py b/hyperion/torch/optim/factory.py index ab350098..a7ebec1f 100644 --- a/hyperion/torch/optim/factory.py +++ b/hyperion/torch/optim/factory.py @@ -141,7 +141,7 @@ def create( if base_opt is None: raise Exception("unknown optimizer %s" % opt_type) - if oss: + if oss: # (JJ: this (oss=True) is NOT touched for dinossl_style param filtering so with dinossl_style, the behavior is not yet confirmed) from fairscale.optim.oss import OSS logging.info("Optimizer uses OSS") @@ -171,6 +171,7 @@ def filter_args(**kwargs): "init_acc_val", "max_iter", "oss", + "dinossl_style" ) return filter_args(valid_args, kwargs) @@ -320,6 +321,10 @@ def add_class_args(parser, prefix=None): "--max-iter", default=20, type=int, help=("max iterations in LBGS") ) + parser.add_argument( + '--dinossl_style', default=False, type=bool, + help=('per-parameter updates following FB dino repo to NOT regularize biases nor Norm parameters')) + if prefix is not None: outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser)) # help='optimizer options') diff --git a/hyperion/torch/torch_model_loader.py b/hyperion/torch/torch_model_loader.py index c173cd50..c52c32b5 100644 --- a/hyperion/torch/torch_model_loader.py +++ b/hyperion/torch/torch_model_loader.py @@ -7,6 +7,7 @@ import re import torch +from hyperion.torch.utils import dinossl from .narchs import * from .models import * @@ -34,7 +35,13 @@ def _fix_compatibility(class_obj, cfg): return cfg @staticmethod - def load(file_path, extra_objs={}, map_location=None): + def load(file_path, extra_objs={}, map_location=None, state_dict_key='model_state_dict', dinossl_kwargs=None): + """ + Args: + state_dict_key (str): key for state_dict of a pre-trained model. Currently either + 'model_state_dict' or 'model_teacher_state_dict' (possible option in dinossl) + dinossl_kwargs (dict): DINOHead related arguments to reconstruct the DINOHead module as it was in the traiing + location info for xvector extraction. 
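+        Returns:
+            the loaded TorchModel; when dinossl_kwargs requests an x-vector location other than 'f',
+            the model is wrapped in dinossl.MultiCropWrapper with its DINOHead re-attached.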
+ """ if map_location is None: map_location = torch.device("cpu") @@ -50,7 +57,8 @@ def load(file_path, extra_objs={}, map_location=None): else: raise Exception("unknown object with class_name=%s" % (class_name)) - state_dict = model_data["model_state_dict"] + state_dict = model_data[state_dict_key] + logging.info('Using state_dict_key: {} of the pre-trained model'.format(state_dict_key)) if "n_averaged" in state_dict: del state_dict["n_averaged"] @@ -58,10 +66,18 @@ def load(file_path, extra_objs={}, map_location=None): cfg = TorchModelLoader._fix_compatibility(class_obj, cfg) p = re.compile("^module\.") + q = re.compile('^backbone\.') # for dinossl num_tries = 3 for tries in range(num_tries): try: - return class_obj.load(cfg=cfg, state_dict=state_dict) + model = class_obj.load(cfg=cfg, state_dict=state_dict) + if (dinossl_kwargs is not None) and (dinossl_kwargs['dinossl_xvec_loc'] != 'f'): # no need when dinossl_kwargs['dinossl_xvec_loc'] == 'f' since it does not requires DINOHead + embed_dim = state_dict_head['mlp.0.weight'].shape[1] + model = dinossl.MultiCropWrapper(model, dinossl.DINOHead(embed_dim, dinossl_kwargs['dinossl_out_dim'], use_bn=dinossl_kwargs['dinossl_use_bn_in_head'], + norm_last_layer=dinossl_kwargs['dinossl_norm_last_layer'], nlayers=dinossl_kwargs['dinossl_nlayers'])) + model.head.load_state_dict(state_dict_head) # putting this into this "try:" block assumes the pre-trained model is always trained with multi-gpus. + model.dinossl_xvec_loc = dinossl_kwargs['dinossl_xvec_loc'] + return model except RuntimeError as err: # remove module prefix when is trained with dataparallel if tries == num_tries - 1: @@ -69,3 +85,8 @@ def load(file_path, extra_objs={}, map_location=None): raise err # remove module prefix when is trained with dataparallel state_dict = ODict((p.sub("", k), v) for k, v in state_dict.items()) + # below three are for dinossl + state_dict = ODict((q.sub('',k), v) for k,v in state_dict.items()) + state_dict_head = ODict((k[5:], v) for k,v in state_dict.items() if (k[:4] == 'head')) + state_dict = ODict((k, v) for k,v in state_dict.items() if not (k[:4] == 'head')) + diff --git a/hyperion/torch/trainers/__init__.py b/hyperion/torch/trainers/__init__.py index 8fef7df5..f4461b39 100644 --- a/hyperion/torch/trainers/__init__.py +++ b/hyperion/torch/trainers/__init__.py @@ -4,14 +4,17 @@ """ from .torch_trainer import TorchTrainer +from .torch_trainer_dinossl import DINOSSLTorchTrainer from .xvector_trainer import XVectorTrainer from .xvector_trainer_deep_feat_reg import XVectorTrainerDeepFeatReg from .xvector_adv_trainer import XVectorAdvTrainer +from .xvector_trainer_dinossl import DINOSSLXVectorTrainer from .xvector_trainer_from_wav import XVectorTrainerFromWav from .xvector_trainer_deep_feat_reg_from_wav import XVectorTrainerDeepFeatRegFromWav from .xvector_adv_trainer_from_wav import XVectorAdvTrainerFromWav +from .xvector_trainer_from_wav_dinossl import DINOSSLXVectorTrainerFromWav from .vae_trainer import VAETrainer from .dvae_trainer import DVAETrainer diff --git a/hyperion/torch/trainers/torch_trainer.py b/hyperion/torch/trainers/torch_trainer.py index 5f573904..94caf1d0 100644 --- a/hyperion/torch/trainers/torch_trainer.py +++ b/hyperion/torch/trainers/torch_trainer.py @@ -3,6 +3,7 @@ Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) """ +from hyperion.torch.utils import dinossl import os import math import contextlib @@ -413,7 +414,13 @@ def _make_optimizer(self, optim, model, oss=False): opt_args["oss"] = oss if self.rank == 0: 
logging.info("optimizer args={}".format(opt_args)) - optimizer = OF.create(model.parameters(), **opt_args) + if opt_args['dinossl_style']: # dinossl_style means per-parameter updates following FB dino repo to NOT regularize biases nor Norm parameters + params_groups = dinossl.get_params_groups(model) + del opt_args['dinossl_style'] + optimizer = OF.create(params_groups, **opt_args) + else: + del opt_args['dinossl_style'] + optimizer = OF.create(model.parameters(), **opt_args) return optimizer def _make_lr_sched(self, lr_sched, optim): diff --git a/hyperion/torch/trainers/torch_trainer_dinossl.py b/hyperion/torch/trainers/torch_trainer_dinossl.py new file mode 100644 index 00000000..d1e5dc4b --- /dev/null +++ b/hyperion/torch/trainers/torch_trainer_dinossl.py @@ -0,0 +1,810 @@ +""" + Copyright 2022 Johns Hopkins University (Author: Jaejin Cho) - changes in var names from original dino repo to this: student --> model & teacher --> model_teacher + Copyright 2019 Johns Hopkins University (Author: Jesus Villalba) + Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +""" + +from hyperion.torch.utils import dinossl +import os +import math +import contextlib +from collections import OrderedDict as ODict +from enum import Enum +from jsonargparse import ArgumentParser, ActionParser +import logging +from pathlib import Path + +import torch +import torch.nn as nn +import torch.cuda.amp as amp +from torch.optim.swa_utils import AveragedModel, SWALR +import torch.distributed as dist + +from fairscale.optim.grad_scaler import ShardedGradScaler + +from ..utils import MetricAcc, TorchDDP, FairShardedDDP, FairFullyShardedDDP +from ..loggers import LoggerList, CSVLogger, ProgLogger, TensorBoardLogger, WAndBLogger +from ..optim import OptimizerFactory as OF +from ..lr_schedulers import LRSchedulerFactory as LRSF +from ..lr_schedulers import LRScheduler as LRS + + +class DDPType(str, Enum): + DDP = "ddp" + OSS_DDP = "oss_ddp" + OSS_SHARDED_DDP = "oss_sharded_ddp" + FULLY_SHARDED_DDP = "fully_sharded_ddp" + + +ddp_choices = [o.value for o in DDPType] + + +class DINOSSLTorchTrainer(object): + """Base Trainer class to train basic neural network models + + Attributes: + model: model object. + loss: nn.Module loss class + optim: pytorch optimizer object or optimizer options dict + epochs: max. number of epochs + exp_path: experiment output path + cur_epoch: current epoch + grad_acc_steps: gradient accumulation steps to simulate larger batch size. + device: cpu/gpu device + metrics: extra metrics to compute besides cxe. + lrsched: learning rate scheduler object + loggers: LoggerList object, loggers write training progress to std. output and file. + ddp: if True use distributed data parallel training + ddp_type: type of distributed data parallel in (ddp, oss_ddp, oss_shared_ddp) + train_mode: training mode in ['full', 'frozen'] + use_amp: uses mixed precision training. + log_interval: number of optim. 
steps between log outputs + use_tensorboard: use tensorboard logger + use_wandb: use wandb logger + wandb: wandb dictionary of options + grad_clip: norm to clip gradients, if 0 there is no clipping + grad_clip_norm: norm type to clip gradients + swa_start: epoch to start doing swa + swa_lr: SWA learning rate + swa_anneal_epochs: SWA learning rate anneal epochs + cpu_offload: CPU offload of gradients when using fully sharded ddp + """ + + def __init__( + self, + model, + loss, + optim={}, + epochs=100, + exp_path="./train", + cur_epoch=0, + grad_acc_steps=1, + eff_batch_size=None, + device=None, + metrics=None, + lrsched=None, + loggers=None, + ddp=False, + ddp_type="ddp", + train_mode="full", + use_amp=False, + log_interval=10, + use_tensorboard=False, + use_wandb=False, + wandb={}, + grad_clip=0, + grad_clip_norm=2, + swa_start=0, + swa_lr=1e-3, + swa_anneal_epochs=10, + cpu_offload=False, + niter_per_ep=0, + batch_size=0 + ): + + self.model = model[0] + self.model_teacher = model[1] + self.loss = loss + self.epochs = epochs + self.niter_per_ep = niter_per_ep + self.batch_size = batch_size + self.cur_epoch = cur_epoch + self.grad_acc_steps = grad_acc_steps + self.eff_batch_size = eff_batch_size + self.exp_path = Path(exp_path) + + if loggers is None: + self.loggers = self._default_loggers( + log_interval, use_tensorboard, use_wandb, wandb + ) + elif isinstance(loggers, list): + self.loggers = LoggerList(loggers) + else: + self.loggers = loggers + + self.metrics = metrics + self.device = device + self.train_mode = train_mode + self.use_amp = use_amp + self.grad_clip = grad_clip + self.grad_clip_norm = grad_clip_norm + self.swa_start = swa_start + self.do_swa = swa_start > 0 + self.swa_lr = swa_lr + self.swa_anneal_epochs = swa_anneal_epochs + self.amp_args = {} + + self.set_train_mode() + + if device is not None: + self.model.to(device) + self.model_teacher.to(device) + if loss is not None: + self.loss.to(device) + + self.ddp = ddp + self.ddp_type = ddp_type + self.rank = 0 + self.world_size = 1 + if ddp: # (JJ: EXP - for now, I will only use self.ddp_type = 'ddp', i.e., DDPType.DDP) + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + if ddp_type == DDPType.DDP or ddp_type == DDPType.OSS_DDP: + if dinossl.has_batchnorms(self.model): + self.model, self.model_teacher, self.model_teacher_without_ddp = self.convert_sync_batchnorm(device) + else: # (JJ: EXP - when we use ViT-like model (in DINO) that does not have batchnorms: This is not tested yet in hyperion) + self.model_teacher_without_ddp = self.model_teacher + if self.rank == 0: + logging.info( + "training in multiple gpus with distributed-data-parallel" + ) + oss = False if ddp_type == DDPType.DDP else True + self.optimizer = self._make_optimizer(optim, self.model, oss=oss) + self.model = TorchDDP( + self.model, + device_ids=[device], + output_device=device, + ) + elif ddp_type == DDPType.OSS_SHARDED_DDP: + if dinossl.has_batchnorms(self.model): + self.model, self.model_teacher, self.model_teacher_without_ddp = self.convert_sync_batchnorm(device) + else: + self.model_teacher_without_ddp = self.model_teacher + if self.rank == 0: + logging.info( + "training in multiple gpus with fair sharded-distributed-data-parallel" + ) + self.optimizer = self._make_optimizer(optim, self.model, oss=True) + self.model = FairShardedDDP(self.model, self.optimizer) + else: + if self.rank == 0: + logging.info( + "training in multiple gpus with fair fully-sharded-distributed-data-parallel" + ) + # syncbathcnorm is not supported here, it 
raises exception + self.model = FairFullyShardedDDP( + self.model, + mixed_precision=self.use_amp, + move_params_to_cpu=cpu_offload, + ) + self.optimizer = self._make_optimizer(optim, self.model, oss=False) + + else: + self.model_teacher_without_ddp = self.model_teacher + self.optimizer = self._make_optimizer(optim, self.model) + + # NO backpropagation through model_teacher, which instead is updated by momentum. + # Do the step here after ddp applied. Otherwise, an assertion error raises (DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.) + for p in self.model_teacher.parameters(): + p.requires_grad = False + + # make the learning rate scheduler or schedules + if lrsched['lrsch_type'] == 'dinossl': + self.lr_scheduler = None + self.lr_schedule, self.wd_schedule, self.momentum_schedule = self._make_schedules(lrsched) + else: + self.lr_scheduler = self._make_lr_sched(lrsched, self.optimizer) + + if self.use_amp: + if ddp and ddp_type != DDPType.DDP: + if self.rank == 0: + logging.info( + "using automatic mixed precision training with sharded-grad-scaler" + ) + self.grad_scaler = ShardedGradScaler() + else: + if self.rank == 0: + logging.info( + "using automatic mixed precision training with grad-scaler" + ) + self.grad_scaler = amp.GradScaler() + self.amp_autocast = amp.autocast + else: + self.amp_autocast = contextlib.nullcontext + + self.in_swa = False + if self.do_swa: + if self.rank == 0: + logging.info("init SWA model") + self.swa_model = AveragedModel(self.model) + self.swa_scheduler = SWALR( + self.optimizer, swa_lr=self.swa_lr, anneal_epochs=self.swa_anneal_epochs + ) + + def fit(self, train_data, val_data=None): + """Training function, it performs the training and validation epochs + + Args: + train_data: PyTorch data loader for the training loop + val_data: PyTorch data loader for the validation loop + """ + self.exp_path.mkdir(parents=True, exist_ok=True) + #self._compute_grad_acc_steps(train_data) # Do not apply grad_acc_steps with the current dinossl (i.e., set this to 1 as a default). Instead, apply linear scaling of the lr as is done in original dino repo. 
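+        # Illustrative note: with the linear scaling rule in _make_schedules, the peak lr is
+        # dinossl_lr * batch_size / 256, so e.g. a total batch size of 256 keeps the configured
+        # lr and a batch size of 128 halves it (example numbers for illustration only).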
+ + if self.do_swa and self.cur_epoch >= self.swa_start: + raise NotImplementedError("Using swa is not implemented for dinossl yet") + self.in_swa = True + + val_logs = {} + self.loggers.on_train_begin(epochs=self.epochs) + for epoch in range(self.cur_epoch, self.epochs): + + self.loggers.on_epoch_begin(epoch, batches=len(train_data)) + if self.lr_scheduler is not None: + # this is needed by cosine scheduler + epoch_updates = int(len(train_data) / self.grad_acc_steps) + self.lr_scheduler.on_epoch_begin(epoch, epoch_updates=epoch_updates) + + logs = self.train_epoch(train_data) + if val_data is not None: + val_logs = self.validation_epoch(val_data) + logs.update(val_logs) + else: + logging.info("NO validation phase ...") + + self.cur_epoch += 1 + + self.loggers.on_epoch_end(logs) + if self.do_swa and self.cur_epoch >= self.swa_start: + self.in_swa = True + self.swa_model.update_parameters(self.model) + self.swa_scheduler.step() + else: + if self.lr_scheduler is not None: + self.lr_scheduler.on_epoch_end(logs) + + self.save_checkpoint(logs) + + if self.in_swa: + self.loggers.on_epoch_begin(self.cur_epoch, batches=len(train_data)) + self.model = self.swa_model.module + logs = self.bn_update_epoch(train_data) + + if val_data is not None: + val_logs = self.validation_epoch(val_data) + logs.update(val_logs) + + self.cur_epoch += 1 + self.loggers.on_epoch_end(logs) + self.save_swa_model(logs) + + def set_train_mode(self): + # self.model.train_mode = self.train_mode + self.model.set_train_mode(self.train_mode) + + def train_epoch(self, data_loader): + """Training epoch loop + + Args: + data_loader: PyTorch data loader return input/output pairs + """ + metric_acc = MetricAcc(device=self.device) + batch_metrics = ODict() + self.model.train() + for batch, (data, target) in enumerate(data_loader): + self.loggers.on_batch_begin(batch) + if batch % self.grad_acc_steps == 0: + self.optimizer.zero_grad() + + data, target = data.to(self.device), target.to(self.device) + batch_size = data.shape[0] + with self.amp_autocast(): + output = self.model(data) + loss = self.loss(output, target).mean() / self.grad_acc_steps + + if self.use_amp: + self.grad_scaler.scale(loss).backward() + else: + loss.backward() + + if (batch + 1) % self.grad_acc_steps == 0: + if self.lr_scheduler is not None and not self.in_swa: + self.lr_scheduler.on_opt_step() + self.update_model() + + self._reduce_metric(loss) + batch_metrics["loss"] = loss.item() * self.grad_acc_steps + for k, metric in self.metrics.items(): + batch_metrics[k] = metric(output, target) + + metric_acc.update(batch_metrics, batch_size) + logs = metric_acc.metrics + logs["lr"] = self._get_lr() + self.loggers.on_batch_end(logs=logs, batch_size=batch_size) + # total_batches += 1 + + logs = metric_acc.metrics + logs = ODict(("train_" + k, v) for k, v in logs.items()) + logs["lr"] = self._get_lr() + return logs + + def validation_epoch(self, data_loader, swa_update_bn=False): + """Validation epoch loop + + Args: + data_loader: PyTorch data loader return input/output pairs. + sw_update_bn: wheter or not, update batch-norm layers in SWA. 
+ """ + + metric_acc = MetricAcc(self.device) + batch_metrics = ODict() + with torch.no_grad(): + if swa_update_bn: + log_tag = "train_" + self.train() + else: + log_tag = "val_" + self.model.eval() + + for batch, (data, target) in enumerate(data_loader): + data, target = data.to(self.device), target.to(self.device) + batch_size = data.shape[0] + + with self.amp_autocast(): + output = self.model(data) + loss = self.loss(output, target) + + batch_metrics["loss"] = loss.mean().item() + for k, metric in self.metrics.items(): + batch_metrics[k] = metric(output, target) + + metric_acc.update(batch_metrics, batch_size) + + logs = metric_acc.metrics + logs = ODict((log_tag + k, v) for k, v in logs.items()) + return logs + + def bn_update_epoch(self, data_loader): + logs = self.validation_epoch(data_loader, swa_update_bn=True) + logs["lr"] = self._get_lr() + return logs + + def _clip_grad_norm(self, model, optim, grad_clip, grad_clip_norm): + if self.ddp: + if self.ddp_type == DDPType.DDP: + nn.utils.clip_grad_norm_( + model.parameters(), grad_clip, norm_type=grad_clip_norm + ) + return + if self.ddp_type == DDPType.FULLY_SHARDED_DDP: + # we have to use the member function in FullyShardedDDP class + model.clip_grad_norm_(grad_clip, norm_type=grad_clip_norm) + return + else: + # not sure about this but it looks like + # we have to use the member function in the OSS optimizer wrapper + optim.clip_grad_norm(grad_clip, norm_type=grad_clip_norm) + + # if no DDP clip normally + nn.utils.clip_grad_norm_( + model.parameters(), grad_clip, norm_type=grad_clip_norm + ) + + def update_model(self): + """Updates the model and does gradding clipping.""" + if self.use_amp: + if self.grad_clip > 0: + self.grad_scaler.unscale_(self.optimizer) + self._clip_grad_norm( + self.model, self.optimizer, self.grad_clip, self.grad_clip_norm + ) + + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + else: + if self.grad_clip > 0: + self._clip_grad_norm( + self.model, self.optimizer, self.grad_clip, self.grad_clip_norm + ) + + self.optimizer.step() + + def _make_optimizer(self, optim, model, oss=False): + """Makes an optimizer object.""" + if isinstance(optim, torch.optim.Optimizer): + return optim + + assert isinstance(optim, dict) + opt_args = OF.filter_args(**optim) + opt_args["oss"] = oss + if self.rank == 0: + logging.info("optimizer args={}".format(opt_args)) + if opt_args['dinossl_style']: # dinossl_style means per-parameter updates following FB dino repo to NOT regularize biases nor Norm parameters + params_groups = dinossl.get_params_groups(model) + del opt_args['dinossl_style'] + optimizer = OF.create(params_groups, **opt_args) + else: + del opt_args['dinossl_style'] + optimizer = OF.create(model.parameters(), **opt_args) + return optimizer + + + def _make_lr_sched(self, lr_sched, optim): + """Makes a Learning Rate scheduler object.""" + if lr_sched is None or isinstance(lr_sched, LRS): + return lr_sched + + assert isinstance(lr_sched, dict) + args = LRSF.filter_args(**lr_sched) + if self.rank == 0: + logging.info("lr scheduler args={}".format(args)) + lr_sched = LRSF.create(optim, **args) + return lr_sched + + def _make_schedules(self, lrsched): + assert (self.niter_per_ep != 0) and (self.batch_size != 0) + lr_schedule = dinossl.cosine_scheduler( + lrsched['dinossl_lr'] * self.batch_size / 256., # linear scaling rule (JJ: TODO - 256. 
might need to change) + lrsched['dinossl_min_lr'], + self.epochs, self.niter_per_ep, + warmup_epochs=lrsched['dinossl_warmup_epochs'], + ) + wd_schedule = dinossl.cosine_scheduler( + lrsched['dinossl_weight_decay'], + lrsched['dinossl_weight_decay_end'], + self.epochs, self.niter_per_ep, + ) + # momentum parameter is getting increased to 1. during training with a cosine schedule + momentum_schedule = dinossl.cosine_scheduler(lrsched['dinossl_momentum_teacher'], 1, + self.epochs, self.niter_per_ep) + + return lr_schedule, wd_schedule, momentum_schedule + + def _default_loggers(self, log_interval, use_tensorboard, use_wandb, wandb): + """Creates the default data loaders""" + prog_log = ProgLogger(interval=log_interval) + csv_log = CSVLogger(self.exp_path / "train.log", append=True) + loggers = [prog_log, csv_log] + if use_tensorboard: + loggers.append( + TensorBoardLogger(self.exp_path / "tb", interval=log_interval) + ) + if use_wandb: + loggers.append( + WAndBLogger( + **wandb, path=self.exp_path / "wandb", interval=log_interval + ) + ) + return LoggerList(loggers) + + def _get_lr(self): + """Returns the current learning rate to show in the loggers""" + for param_group in self.optimizer.param_groups: + return param_group["lr"] + + def _compute_grad_acc_steps(self, data_loader): + if self.eff_batch_size is None: + return + + if data_loader.batch_sampler is not None: + try: + batch_size = data_loader.batch_sampler.avg_batch_size + except: + logging.warn( + "batch sampler doesn't have avg_batch_size property, " + "we cannot estimate grad_acc_steps, using grad_acc_steps=%d", + self.grad_acc_steps, + ) + return + + self.grad_acc_steps = int( + math.ceil(self.eff_batch_size / batch_size / self.world_size) + ) + logging.info( + "Setting grad_acc_steps=%d for " + "eff_batch_size=%d, avg_batch_size=%d, world_size=%d", + self.grad_acc_steps, + self.eff_batch_size, + batch_size, + self.world_size, + ) + return + + logging.warn( + "We cannot determine the batch_size, " + "we cannot estimate grad_acc_steps, using grad_acc_steps=%d", + self.grad_acc_steps, + ) + + def checkpoint(self, logs=None): + """Creates a checkpoint of the training, to save and posterior recovery + + Args: + logs: logs containing the current value of the metrics. + """ + checkpoint = { + "epoch": self.cur_epoch, + "rng_state": torch.get_rng_state(), + "model_cfg": self.model.backbone.conf, + "model_state_dict": self.model.state_dict(), + 'model_teacher_state_dict': self.model_teacher.state_dict(), + "optimizer_state_dict": self.optimizer.state_dict(), # (JJ: EXP - "it" used for schedules could be found by EITHER "it = len(data_loader) * self.cur_epoch + batch" OR "lr" in the checkpoint['logs'] below) + "loss_state_dict": self.loss.state_dict() + if self.loss is not None + else None, + } + if self.lr_scheduler is not None: + checkpoint["lr_scheduler_state_dict"] = self.lr_scheduler.state_dict() + + if logs is not None: + checkpoint["logs"] = logs + + if self.in_swa: + checkpoint["swa_model_state_dict"] = self.swa_model.state_dict() + checkpoint["swa_scheduler_state_dict"] = self.swa_scheduler.state_dict() + + return checkpoint + + def convert_sync_batchnorm(self, device): + self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model) + self.model_teacher = nn.SyncBatchNorm.convert_sync_batchnorm(self.model_teacher) + # we need DDP wrapper to have synchro batch norms working... 
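+        # Note: the teacher is wrapped in DDP below only so that its SyncBatchNorm layers share
+        # batch statistics across processes (it receives no gradients); model_teacher_without_ddp
+        # keeps a handle to the unwrapped module, which the EMA update in train_epoch uses.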
+ if self.ddp_type == DDPType.DDP: + self.model_teacher = nn.parallel.DistributedDataParallel(self.model_teacher, device_ids=[device], output_device=device) # (JJ: TODO - for now, simply follow what Jesus did) + else: + raise NotImplementedError('Need implementation for other DDPType except DDPType.DDP') + self.model_teacher_without_ddp = self.model_teacher.module + + return self.model, self.model_teacher, self.model_teacher_without_ddp + + def save_checkpoint(self, logs=None): + """Saves a checkpoint of the training status + + Args: + logs: logs containing the current value of the metrics. + """ + if self.ddp and ( + self.ddp_type == DDPType.OSS_DDP or self.ddp_type == DDPType.OSS_SHARDED_DDP + ): + # Not sure what this does, just copying from the example in + # https://github.com/facebookresearch/fairscale/blob/master/benchmarks/oss.py + # Check the checkpointing in the case of the OSS optimizer + # Memory usage could spill over from there + # optimizer = cast(OSS, optimizer) + self.optimizer.consolidate_state_dict() + + if self.rank != 0: + return + checkpoint = self.checkpoint(logs) + file_path = "%s/model_ep%04d.pth" % (self.exp_path, self.cur_epoch) + + torch.save(checkpoint, file_path) + + def save_swa_model(self, logs=None): + """Saves a checkpoint of the training status + + Args: + logs: logs containing the current value of the metrics. + """ + if self.rank != 0: + return + + checkpoint = self.checkpoint(logs) + checkpoint["model_state_dict"] = checkpoint["swa_model_state_dict"] + del checkpoint["swa_model_state_dict"] + file_path = "%s/swa_model_ep%04d.pth" % (self.exp_path, self.cur_epoch) + + torch.save(checkpoint, file_path) + + def load_checkpoint(self, file_path): + """Loads a training checkpoint from file. + + Args: + file_path: checkpoint file path + """ + checkpoint = torch.load(file_path, map_location=torch.device("cpu")) + rng_state = checkpoint["rng_state"] + torch.set_rng_state(rng_state) + if self.rank > 0: + # this will make sure that each process produces different data + # when using ddp + dummy = torch.rand(1000 * self.rank) + del dummy + + self.cur_epoch = checkpoint["epoch"] + try: + self.model.load_state_dict(checkpoint["model_state_dict"]) + except: + self.model.module.load_state_dict(checkpoint["model_state_dict"]) + try: + self.model_teacher.load_state_dict(checkpoint['model_teacher_state_dict']) + except: + self.model_teacher.module.load_state_dict(checkpoint['model_teacher_state_dict']) + self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + if self.loss is not None: + self.loss.load_state_dict(checkpoint["loss_state_dict"]) + if self.lr_scheduler is not None: + self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"]) + + # if self.use_amp: + # amp.load_state_dict(checkpoint['amp']) + if self.do_swa: + if "swa_model_state_dict" in checkpoint: + self.swa_model.load_state_dict(checkpoint["swa_model_state_dict"]) + self.swa_scheduler.load_state_dict( + checkpoint["swa_scheduler_state_dict"] + ) + else: + self.swa_scheduler = SWALR( + self.optimizer, + swa_lr=self.swa_lr, + anneal_epochs=self.swa_anneal_epochs, + ) + + logs = None + if "logs" in checkpoint: + logs = checkpoint["logs"] + + del checkpoint + # this was added before to try to release as much GPU memory as possible + # Recently has started to cause CUDA not available devices error + # Commenting for now. 
+ # if self.device is not None: + # torch.cuda.empty_cache() + + return logs + + def load_last_checkpoint(self): + """Loads the last training checkpoint in the experiment dir.""" + for epoch in range(self.epochs, 0, -1): + file_path = "%s/model_ep%04d.pth" % (self.exp_path, epoch) + if os.path.isfile(file_path): + return self.load_checkpoint(file_path) + + return None + + @staticmethod + def filter_args(**kwargs): + valid_args = ( + "grad_acc_steps", + "eff_batch_size", + "epochs", + "log_interval", + "use_amp", + "ddp_type", + "grad_clip", + "grad_clip_norm", + "swa_start", + "swa_lr", + "swa_anneal_epochs", + "exp_path", + "optim", + "lrsched", + "cpu_offload", + "use_tensorboard", + "use_wandb", + "wandb", + "train_mode", + ) + args = dict((k, kwargs[k]) for k in valid_args if k in kwargs) + return args + + @staticmethod + def add_class_args(parser, prefix=None, train_modes=None, skip=[]): + if prefix is not None: + outer_parser = parser + parser = ArgumentParser(prog="") + + if "optim" not in skip: + OF.add_class_args(parser, prefix="optim") + + if "lrsched" not in skip: + LRSF.add_class_args(parser, prefix="lrsched") + + parser.add_argument( + "--grad-acc-steps", + type=int, + default=1, + help="gradient accumulation batches before weigth update", + ) + parser.add_argument( + "--eff-batch-size", + type=int, + default=None, + help="effective total batch size, if given, it overrides grad_acc_steps", + ) + parser.add_argument("--epochs", type=int, default=200, help="number of epochs") + if train_modes is not None: + parser.add_argument( + "--train-mode", + default="full", + choices=train_modes, + help=f"Available train modes for the model in {train_modes}", + ) + parser.add_argument( + "--log-interval", + type=int, + default=10, + help="how many batches to wait before logging training status", + ) + parser.add_argument( + "--use-tensorboard", + action="store_true", + default=False, + help="use tensorboard logger", + ) + parser.add_argument( + "--use-wandb", action="store_true", default=False, help="use wandb logger" + ) + parser.add_argument("--wandb.project", default=None, help="wandb project name") + parser.add_argument("--wandb.group", default=None, help="wandb group name") + parser.add_argument("--wandb.name", default=None, help="wandb display name") + # parser.add_argument( + # '--wandb.path', default=None, + # help='wandb directory') + parser.add_argument( + "--wandb.mode", + default="online", + choices=["online", "offline"], + help="wandb mode (online, offline)", + ) + + parser.add_argument( + "--ddp-type", + default="ddp", + choices=ddp_choices, + help="DDP type in {}".format(ddp_choices), + ) + parser.add_argument( + "--use-amp", + action="store_true", + default=False, + help="use mixed precision training", + ) + parser.add_argument( + "--cpu-offload", + action="store_true", + default=False, + help="CPU offload of gradients when using fully_sharded_ddp", + ) + parser.add_argument( + "--grad-clip", type=float, default=0, help="gradient clipping norm value" + ) + parser.add_argument( + "--grad-clip-norm", + default=2, + choices=["inf", 1, 2], + help="gradient clipping norm type", + ) + parser.add_argument( + "--swa-start", + type=int, + default=0, + help="start epoch for SWA, if 0 it does not use SWA", + ) + parser.add_argument( + "--swa-lr", type=float, default=1e-3, help="learning rate for SWA phase" + ) + parser.add_argument( + "--swa-anneal-epochs", + type=int, + default=10, + help="SWA learning rate anneal epochs", + ) + + parser.add_argument("--exp-path", help="experiment 
path") + + if prefix is not None: + outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser)) + + add_argparse_args = add_class_args diff --git a/hyperion/torch/trainers/xvector_trainer_dinossl.py b/hyperion/torch/trainers/xvector_trainer_dinossl.py new file mode 100644 index 00000000..02cee3f7 --- /dev/null +++ b/hyperion/torch/trainers/xvector_trainer_dinossl.py @@ -0,0 +1,163 @@ +""" + Copyright 2019 Johns Hopkins University (Author: Jesus Villalba) + Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +""" +import os +from collections import OrderedDict as ODict + +import logging + +import torch +import torch.nn as nn + +from ..utils import MetricAcc +from .torch_trainer_dinossl import DINOSSLTorchTrainer +from torch.distributed.elastic.multiprocessing.errors import record + + +class DINOSSLXVectorTrainer(DINOSSLTorchTrainer): + """Trainer to train x-vector style models. + + Attributes: + model: x-Vector model object. + optim: pytorch optimizer object or options dict + epochs: max. number of epochs + exp_path: experiment output path + cur_epoch: current epoch + grad_acc_steps: gradient accumulation steps to simulate larger batch size. + device: cpu/gpu device + metrics: extra metrics to compute besides cxe. + lrsched: learning rate scheduler object or options dict + loggers: LoggerList object, loggers write training progress to std. output and file. + If None, it uses default loggers. + ddp: if True use distributed data parallel training + ddp_type: type of distributed data parallel in (ddp, oss_ddp, oss_shared_ddp) + loss: if None, it uses cross-entropy + train_mode: training mode in ['train', 'ft-full', 'ft-last-layer'] + use_amp: uses mixed precision training. + log_interval: number of optim. steps between log outputs + use_tensorboard: use tensorboard logger + use_wandb: use wandb logger + wandb: wandb dictionary of options + grad_clip: norm to clip gradients, if 0 there is no clipping + grad_clip_norm: norm type to clip gradients + swa_start: epoch to start doing swa + swa_lr: SWA learning rate + swa_anneal_epochs: SWA learning rate anneal epochs + cpu_offload: CPU offload of gradients when using fully sharded ddp + """ + + def __init__( + self, + model, + optim={}, + epochs=100, + exp_path="./train", + cur_epoch=0, + grad_acc_steps=1, + eff_batch_size=None, + device=None, + metrics=None, + lrsched=None, + loggers=None, + ddp=False, + ddp_type="ddp", + loss=None, + train_mode="full", + use_amp=False, + log_interval=10, + use_tensorboard=False, + use_wandb=False, + wandb={}, + grad_clip=0, + grad_clip_norm=2, + swa_start=0, + swa_lr=1e-3, + swa_anneal_epochs=10, + cpu_offload=False, + niter_per_ep=0, + batch_size=0 + ): + + if loss is None: + loss = nn.CrossEntropyLoss() + super().__init__( + model, + loss, + optim, + epochs, + exp_path, + cur_epoch=cur_epoch, + grad_acc_steps=grad_acc_steps, + eff_batch_size=eff_batch_size, + device=device, + metrics=metrics, + lrsched=lrsched, + loggers=loggers, + ddp=ddp, + ddp_type=ddp_type, + train_mode=train_mode, + use_amp=use_amp, + log_interval=log_interval, + use_tensorboard=use_tensorboard, + use_wandb=use_wandb, + wandb=wandb, + grad_clip=grad_clip, + grad_clip_norm=grad_clip_norm, + swa_start=swa_start, + swa_lr=swa_lr, + swa_anneal_epochs=swa_anneal_epochs, + cpu_offload=cpu_offload, + niter_per_ep=niter_per_ep, + batch_size=batch_size + ) + + @record + def train_epoch(self, data_loader): + """Training epoch loop + + Args: + data_loader: pytorch data loader returning features and class labels. 
+ """ + + self.model.update_loss_margin(self.cur_epoch) + + metric_acc = MetricAcc(device=self.device) + batch_metrics = ODict() + self.model.train() + for batch, (data, target) in enumerate(data_loader): + self.loggers.on_batch_begin(batch) + + if batch % self.grad_acc_steps == 0: + self.optimizer.zero_grad() + + data, target = data.to(self.device), target.to(self.device) + batch_size = data.shape[0] + + with self.amp_autocast(): + output = self.model(data, y=target) + loss = self.loss(output, target).mean() / self.grad_acc_steps + + if self.use_amp: + self.grad_scaler.scale(loss).backward() + else: + loss.backward() + + if (batch + 1) % self.grad_acc_steps == 0: + if self.lr_scheduler is not None and not self.in_swa: + self.lr_scheduler.on_opt_step() + self.update_model() + + batch_metrics["loss"] = loss.item() * self.grad_acc_steps + for k, metric in self.metrics.items(): + batch_metrics[k] = metric(output, target) + + metric_acc.update(batch_metrics, batch_size) + logs = metric_acc.metrics + logs["lr"] = self._get_lr() + self.loggers.on_batch_end(logs=logs, batch_size=batch_size) + + logs = metric_acc.metrics + logs = ODict(("train_" + k, v) for k, v in logs.items()) + logs["lr"] = self._get_lr() + return logs diff --git a/hyperion/torch/trainers/xvector_trainer_from_wav_dinossl.py b/hyperion/torch/trainers/xvector_trainer_from_wav_dinossl.py new file mode 100644 index 00000000..a8afb572 --- /dev/null +++ b/hyperion/torch/trainers/xvector_trainer_from_wav_dinossl.py @@ -0,0 +1,235 @@ +""" + Copyright 2019 Johns Hopkins University (Author: Jesus Villalba) + Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +""" +import os +import sys +from collections import OrderedDict as ODict + +import logging +import copy +import math + +import torch +import torch.nn as nn + +from ..utils import MetricAcc, TorchDDP +from ..utils import cancel_gradients_last_layer +from .xvector_trainer_dinossl import DINOSSLXVectorTrainer + + +class DINOSSLXVectorTrainerFromWav(DINOSSLXVectorTrainer): + """Trainer to train x-vector style models. + + Attributes: + model: x-Vector model object. + feat_extractor: feature extractor nn.Module + optim: pytorch optimizer object or options dict + epochs: max. number of epochs + exp_path: experiment output path + cur_epoch: current epoch + grad_acc_steps: gradient accumulation steps to simulate larger batch size. + device: cpu/gpu device + metrics: extra metrics to compute besides cxe. + lrsched: learning rate scheduler object or options dict. + loggers: LoggerList object, loggers write training progress to std. output and file. + ddp: if True use distributed data parallel training + ddp_type: type of distributed data parallel in (ddp, oss_ddp, oss_shared_ddp) + loss: if None, it uses cross-entropy + train_mode: training mode in ['train', 'ft-full', 'ft-last-layer'] + use_amp: uses mixed precision training. + log_interval: number of optim. 
steps between log outputs + use_tensorboard: use tensorboard logger + use_wandb: use wandb logger + wandb: wandb dictionary of options + grad_clip: norm to clip gradients, if 0 there is no clipping + grad_clip_norm: norm type to clip gradients + swa_start: epoch to start doing swa + swa_lr: SWA learning rate + swa_anneal_epochs: SWA learning rate anneal epochs + cpu_offload: CPU offload of gradients when using fully sharded ddp + """ + + def __init__( + self, + model, + feat_extractor, + optim={}, + epochs=100, + exp_path="./train", + cur_epoch=0, + grad_acc_steps=1, + eff_batch_size=None, + device=None, + metrics=None, + lrsched=None, + loggers=None, + ddp=False, + ddp_type="ddp", + loss=None, + train_mode="full", + use_amp=False, + log_interval=10, + use_tensorboard=False, + use_wandb=False, + wandb={}, + grad_clip=0, + grad_clip_norm=2, + swa_start=0, + swa_lr=1e-3, + swa_anneal_epochs=10, + cpu_offload=False, + niter_per_ep=0, + batch_size=0 + ): + + super().__init__( + model, + optim, + epochs, + exp_path, + cur_epoch=cur_epoch, + grad_acc_steps=grad_acc_steps, + eff_batch_size=eff_batch_size, + device=device, + metrics=metrics, + lrsched=lrsched, + loggers=loggers, + ddp=ddp, + ddp_type=ddp_type, + loss=loss, + train_mode=train_mode, + use_amp=use_amp, + log_interval=log_interval, + use_tensorboard=use_tensorboard, + use_wandb=use_wandb, + wandb=wandb, + grad_clip=grad_clip, + grad_clip_norm=grad_clip_norm, + swa_start=swa_start, + swa_lr=swa_lr, + swa_anneal_epochs=swa_anneal_epochs, + cpu_offload=cpu_offload, + niter_per_ep=niter_per_ep, + batch_size=batch_size + ) + + self.feat_extractor = feat_extractor + if device is not None: + self.feat_extractor.to(device) + + # if ddp: + # self.feat_extractor = TorchDDP(self.feat_extractor) + + def train_epoch(self, data_loader): + """Training epoch loop + + Args: + data_loader: pytorch data loader returning features and class labels. 
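+                         (for dinossl, each batch is a list of augmented crops per utterance,
+                         with the two global crops first and the local crops after;
+                         the class labels are ignored)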
+ """ + + metric_acc = MetricAcc(device=self.device) + batch_metrics = ODict() + self.feat_extractor.train() + self.model.train() + for batch, (data, _) in enumerate(data_loader): + self.loggers.on_batch_begin(batch) + if batch % self.grad_acc_steps == 0: + self.optimizer.zero_grad() + + data = [i.to(self.device, non_blocking=True) for i in data] + batch_size = data[0].shape[0] + with torch.no_grad(): + feats = [] + for i in data: + feats.append(self.feat_extractor(i)) + + with self.amp_autocast(): + output = self.model(feats) + output_teacher = self.model_teacher(feats[:2]) # 2 (currently this number is fixed) global crops + loss = self.loss(output, output_teacher, self.cur_epoch)/self.grad_acc_steps + + if not math.isfinite(loss.item()): + logging.warning('Loss is {}, stopping training'.format(loss.item())) + sys.exit(1) + + if self.use_amp: + self.grad_scaler.scale(loss).backward() + else: + loss.backward() + + freeze_last_layer=1 + cancel_gradients_last_layer(self.cur_epoch, self.model, freeze_last_layer) + + if (batch + 1) % self.grad_acc_steps == 0: + if self.lr_scheduler is not None and not self.in_swa: + self.lr_scheduler.on_opt_step() + # update learning rate and weight decay rate + it = len(data_loader) * self.cur_epoch + int(batch/self.grad_acc_steps) # it: global batch index, batch: local batch index in the current epoch + for i, param_group in enumerate(self.optimizer.param_groups): + param_group["lr"] = self.lr_schedule[it] + if i == 0: # only the first group is regularized + param_group["weight_decay"] = self.wd_schedule[it] + self.update_model() + + # EMA update for the teacher + with torch.no_grad(): + m = self.momentum_schedule[it] # momentum parameter + if hasattr(self.model,'module'): # train with ddp + for param_q, param_k in zip(self.model.module.parameters(), self.model_teacher_without_ddp.parameters()): + param_k.data.mul_(m).add_((1 - m) * param_q.detach().data) + else: # train with a single gpu (w/o ddp), which I (JJ) used in debugging + for param_q, param_k in zip(self.model.parameters(), self.model_teacher_without_ddp.parameters()): + param_k.data.mul_(m).add_((1 - m) * param_q.detach().data) + # (JJ: TODO - for now, I skipped the logging related parts (per iter and epoch) after the above from dino) + + batch_metrics["loss"] = loss.item() * self.grad_acc_steps + for k, metric in self.metrics.items(): + batch_metrics[k] = metric(output) + + metric_acc.update(batch_metrics, batch_size) + logs = metric_acc.metrics + logs["lr"] = self._get_lr() # (JJ: TODO - this may need to change later (NOT now) if lrs are applied differerently over parameter groups) + self.loggers.on_batch_end(logs=logs, batch_size=batch_size) + + logs = metric_acc.metrics + logs = ODict(("train_" + k, v) for k, v in logs.items()) + logs["lr"] = self._get_lr() + return logs + + def validation_epoch(self, data_loader, swa_update_bn=False): + """Validation epoch loop + + Args: + data_loader: PyTorch data loader return input/output pairs. + sw_update_bn: wheter or not, update batch-norm layers in SWA. 
+ """ + metric_acc = MetricAcc(device=self.device) + batch_metrics = ODict() + self.feat_extractor.eval() + with torch.no_grad(): + if swa_update_bn: + log_tag = "train_" + self.model.train() + else: + log_tag = "val_" + self.model.eval() + + for batch, (data, target) in enumerate(data_loader): + data, target = data.to(self.device), target.to(self.device) + batch_size = data.shape[0] + + feats = self.feat_extractor(data) + with self.amp_autocast(): + output = self.model(feats) + loss = self.loss(output, target) + + batch_metrics["loss"] = loss.mean().item() + for k, metric in self.metrics.items(): + batch_metrics[k] = metric(output, target) + + metric_acc.update(batch_metrics, batch_size) + + logs = metric_acc.metrics + logs = ODict((log_tag + k, v) for k, v in logs.items()) + return logs diff --git a/hyperion/torch/utils/__init__.py b/hyperion/torch/utils/__init__.py index 3a4692dc..79691bf8 100644 --- a/hyperion/torch/utils/__init__.py +++ b/hyperion/torch/utils/__init__.py @@ -11,3 +11,4 @@ from .vad_utils import remove_silence from .data_parallel import TorchDataParallel from .ddp import TorchDDP, FairShardedDDP, FairFullyShardedDDP +from .dinossl import MultiCropWrapper, DINOHead, has_batchnorms, cancel_gradients_last_layer, add_dinossl_args, filter_args, get_params_groups, trunc_normal_, _no_grad_trunc_normal_ diff --git a/hyperion/torch/utils/dinossl.py b/hyperion/torch/utils/dinossl.py new file mode 100644 index 00000000..c2b32aa3 --- /dev/null +++ b/hyperion/torch/utils/dinossl.py @@ -0,0 +1,581 @@ +""" + JJ: I copied and edited from utils.py script in facebook dino repo (some parts are from vision_transformer.py, main_dino.py in the same repo) + Copyright 2021 Johns Hopkins University (Author: Jaejin Cho) + Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +""" +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Misc functions. + +Mostly copy-paste from torchvision references or other public repos like DETR: +https://github.com/facebookresearch/detr/blob/master/util/misc.py +""" +import copy +import logging +import os +import sys +import time +import math +import random +import datetime +import subprocess +from collections import defaultdict, deque + + +import numpy as np +import torch +from torch import nn +import torch.distributed as dist +from PIL import ImageFilter, ImageOps + +def add_dinossl_args(parser): + parser.add_argument('--dinossl', default=False, type=bool, # I can simply type=bool with jsonargparse module. Refer to the "Boolean arguments" section in https://jsonargparse.readthedocs.io/en/stable/ + help='whether to run DINO self-supervised training') + parser.add_argument('--dinossl_nlayers', default=3, type=int, help="""number of layers in MLP + except the last optional weight normalized layer""") + parser.add_argument('--dinossl_out_dim', default=65536, type=int, help="""Dimensionality of + the DINO head output. 
For complex and large datasets large values (like 65k) work well.""") + parser.add_argument('--dinossl_use_bn_in_head', default=False, type=bool, + help="Whether to use batch normalizations in projection head (Default: False)") + parser.add_argument('--dinossl_norm_last_layer', default=True, type=bool, + help="""Whether or not to weight normalize the last layer of the DINO head. + Not normalizing leads to better performance but can make the training unstable. + In our experiments, we typically set this paramater to False with deit_small and True with vit_base.""") + parser.add_argument('--dinossl_local_crops_number', type=int, default=2, help="""Number of small + local views to generate. Set this parameter to 0 to disable multi-crop training. + When disabling multi-crop we recommend to use "--global_crops_scale 0.14 1." """) + # Teacher temperature parameters + parser.add_argument('--dinossl_warmup_teacher_temp', default=0.04, type=float, + help="""Initial value for the teacher temperature: 0.04 works well in most cases. + Try decreasing it if the training loss does not decrease.""") + parser.add_argument('--dinossl_teacher_temp', default=0.04, type=float, help="""Final value (after linear warmup) + of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend + starting with the default value of 0.04 and increase this slightly if needed.""") + parser.add_argument('--dinossl_warmup_teacher_temp_epochs', default=0, type=int, + help='Number of warmup epochs for the teacher temperature (Default: 30).') + # Audio chunking related + parser.add_argument('--dinossl_chunk_len_mult', type=int, default=2, + help=('value to multiply chunk_length with for long chunks')) + parser.add_argument('--dinossl_reduce_overlap_prob', type=float, default=0, + help=('probability of applying a function to reduce an overlap between two long chunks')) + # epochs + parser.add_argument('--epochs', type=int, default=70, + help=('training epochs')) + +def filter_args(**kwargs): + valid_args = ('dinossl', 'dinossl_nlayers', 'dinossl_out_dim', 'dinossl_use_bn_in_head', + 'dinossl_norm_last_layer','dinossl_local_crops_number', + 'dinossl_warmup_teacher_temp','dinossl_teacher_temp', 'dinossl_warmup_teacher_temp_epochs', + 'dinossl_chunk_len_mult', 'dinossl_reduce_overlap_prob', 'epochs') + args = dict((k, kwargs[k]) + for k in valid_args if k in kwargs) + + return args + +def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size): + if os.path.isfile(pretrained_weights): + state_dict = torch.load(pretrained_weights, map_location="cpu") + if checkpoint_key is not None and checkpoint_key in state_dict: + print(f"Take key {checkpoint_key} in provided checkpoint dict") + state_dict = state_dict[checkpoint_key] + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + msg = model.load_state_dict(state_dict, strict=False) + print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg)) + else: + print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.") + url = None + if model_name == "deit_small" and patch_size == 16: + url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" + elif model_name == "deit_small" and patch_size == 8: + url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth" + elif model_name == "vit_base" and patch_size == 16: + url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth" + elif model_name == "vit_base" and patch_size 
== 8: + url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth" + if url is not None: + print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.") + state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url) + model.load_state_dict(state_dict, strict=True) + else: + print("There is no reference weights available for this model => We use random weights.") + + +def clip_gradients(model, clip): + norms = [] + for name, p in model.named_parameters(): + if p.grad is not None: + param_norm = p.grad.data.norm(2) + norms.append(param_norm.item()) + clip_coef = clip / (param_norm + 1e-6) + if clip_coef < 1: + p.grad.data.mul_(clip_coef) + return norms + + +def cancel_gradients_last_layer(epoch, model, freeze_last_layer): + if epoch >= freeze_last_layer: + return + for n, p in model.named_parameters(): + if "last_layer" in n: + p.grad = None + + +def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs): + """ + Re-start from checkpoint + """ + if not os.path.isfile(ckp_path): + return + print("Found checkpoint at {}".format(ckp_path)) + + # open checkpoint file + checkpoint = torch.load(ckp_path, map_location="cpu") + + # key is what to look for in the checkpoint file + # value is the object to load + # example: {'state_dict': model} + for key, value in kwargs.items(): + if key in checkpoint and value is not None: + try: + msg = value.load_state_dict(checkpoint[key], strict=False) + print("=> loaded {} from checkpoint '{}' with msg {}".format(key, ckp_path, msg)) + except TypeError: + try: + msg = value.load_state_dict(checkpoint[key]) + print("=> loaded {} from checkpoint '{}'".format(key, ckp_path)) + except ValueError: + print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path)) + else: + print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path)) + + # re load variable important for the run + if run_variables is not None: + for var_name in run_variables: + if var_name in checkpoint: + run_variables[var_name] = checkpoint[var_name] + + +def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): + warmup_schedule = np.array([]) + warmup_iters = warmup_epochs * niter_per_ep + if warmup_epochs > 0: + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(epochs * niter_per_ep - warmup_iters) + schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) + + schedule = np.concatenate((warmup_schedule, schedule)) + assert len(schedule) == epochs * niter_per_ep + return schedule + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def init_distributed_mode(args): + 
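+    """Initializes torch.distributed, reading rank/world_size/gpu from the environment:
+    torch.distributed.launch variables, SLURM variables, or a single-GPU fallback."""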
# launched with torch.distributed.launch + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + # launched with submitit on a slurm cluster + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + # launched naively with `python main_dino.py` + # we manually add MASTER_ADDR and MASTER_PORT to env variables + elif torch.cuda.is_available(): + print('Will run the code on one GPU.') + args.rank, args.gpu, args.world_size = 0, 0, 1 + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '29500' + else: + print('Does not support training without GPU.') + sys.exit(1) + + dist.init_process_group( + backend="nccl", + init_method=args.dist_url, + world_size=args.world_size, + rank=args.rank, + ) + + torch.cuda.set_device(args.gpu) + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + dist.barrier() + setup_for_distributed(args.rank == 0) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk] + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. 
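+        # (since 2 * norm_cdf(x) - 1 = erf(x / sqrt(2)), the erfinv_ below maps the uniform
+        # sample back to x / sqrt(2) of a truncated standard normal)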
+ tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +class LARS(torch.optim.Optimizer): + """ + Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py + """ + def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001, + weight_decay_filter=None, lars_adaptation_filter=None): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, + eta=eta, weight_decay_filter=weight_decay_filter, + lars_adaptation_filter=lars_adaptation_filter) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self): + for g in self.param_groups: + for p in g['params']: + dp = p.grad + + if dp is None: + continue + + if p.ndim != 1: + dp = dp.add(p, alpha=g['weight_decay']) + + if p.ndim != 1: + param_norm = torch.norm(p) + update_norm = torch.norm(dp) + one = torch.ones_like(param_norm) + q = torch.where(param_norm > 0., + torch.where(update_norm > 0, + (g['eta'] * param_norm / update_norm), one), one) + dp = dp.mul(q) + + param_state = self.state[p] + if 'mu' not in param_state: + param_state['mu'] = torch.zeros_like(p) + mu = param_state['mu'] + mu.mul_(g['momentum']).add_(dp) + + p.add_(mu, alpha=-g['lr']) + + +class MultiCropWrapper(nn.Module): + """ + Perform forward pass separately on each resolution input. + The inputs corresponding to a single resolution are clubbed and single + forward is run on the same resolution inputs. Hence we do several + forward passes = number of different resolutions used. We then + concatenate all the output features and run the head forward on these + concatenated features. + (JJ: EXP - Q: In spkid, I think we do zero-padding at least to make the seq len same. Can we skip this or at least zero-padding separately for local and global view groups? A: It seems with adaptive pooling, input seq. len. can vary) + """ + def __init__(self, backbone, head): + super(MultiCropWrapper, self).__init__() + # disable layers dedicated to ImageNet labels classification + #backbone.fc, backbone.head = nn.Identity(), nn.Identity() + self.backbone = backbone + self.head = head + self._train_mode = "full" + + def forward(self, x): + # convert to list + if not isinstance(x, list): + x = [x] + idx_crops = torch.cumsum(torch.unique_consecutive( + torch.tensor([inp.shape[-1] for inp in x]), + return_counts=True, + )[1], 0) + start_idx = 0 + for end_idx in idx_crops: + _out = self.backbone(torch.cat(x[start_idx: end_idx])) + if start_idx == 0: + output = _out + else: + output = torch.cat((output, _out)) + start_idx = end_idx + # Run the head forward on the concatenated features. + return self.head(output) + + def extract_embed(self, x, chunk_length=0, embed_layer=None, detach_chunks=False): + # - This method is NOT used when extracting xvectors from f (right before the DINOHead) + # - self.dinossl_xvec_loc is registered (in TorchModelLoader) for xvector extraction NOT during training. 
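+        # - dinossl_xvec_loc selects how deep into the DINOHead the embedding is taken:
+        #   after the head MLP, after l2-normalization of the MLP output, or after the
+        #   final (weight-normalized) linear layer.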
+ y = self.backbone.extract_embed(x, chunk_length=chunk_length, + embed_layer=embed_layer) + if self.dinossl_xvec_loc == "dinohead_mlp": + y = self.head.mlp(y) + elif self.dinossl_xvec_loc == "dinohead_l2norm": + y = self.head.mlp(y) + y = nn.functional.normalize(y, dim=-1, p=2) + elif self.dinossl_xvec_loc == "dinohead_linear": + y = self.head(y) + else: + NotImplementedError + return y + + def freeze(self): + for param in self.parameters(): + param.requires_grad = False + + def unfreeze(self): + for param in self.parameters(): + param.requires_grad = True + + def set_train_mode(self, mode): + if mode == self._train_mode: + return + + if mode == "full": # different from xvec train, in dinossl, only this is used (thus implemented) + self.unfreeze() + elif mode == "frozen": + raise NotImplementedError + # self.freeze() + elif mode == "ft-embed-affine": + raise NotImplementedError + # self.unfreeze() + # self.freeze_preembed_layers() + else: + raise ValueError(f"invalid train_mode={mode}") + + self._train_mode = mode + +def get_params_groups(model): + regularized = [] + not_regularized = [] + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + # we do not regularize biases nor Norm parameters + if name.endswith(".bias") or len(param.shape) == 1: + not_regularized.append(param) + else: + regularized.append(param) + return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}] + + +def has_batchnorms(model): + bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) + for name, module in model.named_modules(): + if isinstance(module, bn_types): + return True + return False + +# Parts copied and edited from vision_transformer.py in facebook dino repo - start +class DINOHead(nn.Module): + def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256): + super().__init__() + nlayers = max(nlayers, 1) + if nlayers == 1: + self.mlp = nn.Linear(in_dim, bottleneck_dim) + else: + layers = [nn.Linear(in_dim, hidden_dim)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim)) + self.mlp = nn.Sequential(*layers) + self.apply(self._init_weights) + self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + if norm_last_layer: + self.last_layer.weight_g.requires_grad = False + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + x = nn.functional.normalize(x, dim=-1, p=2) + x = self.last_layer(x) + return x +# Parts copied and edited from vision_transformer.py in facebook dino repo - end + +# Parts copied and edited from main_dino.py in facebook dino repo - start +class DINOLoss(nn.Module): + def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp, + warmup_teacher_temp_epochs, nepochs, student_temp=0.1, + center_momentum=0.9): + super().__init__() + self.student_temp = student_temp + self.center_momentum = center_momentum + self.ncrops = ncrops + self.register_buffer("center", torch.zeros(1, out_dim)) + # we apply a warm up for the teacher temperature because + # a too high 
temperature makes the training instable at the beginning + self.teacher_temp_schedule = np.concatenate(( + np.linspace(warmup_teacher_temp, + teacher_temp, warmup_teacher_temp_epochs), + np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp + )) + + def forward(self, student_output, teacher_output, epoch): + """ + Cross-entropy between softmax outputs of the teacher and student networks. + """ + student_out = student_output / self.student_temp + student_out = student_out.chunk(self.ncrops) + + # teacher centering and sharpening + temp = self.teacher_temp_schedule[epoch] + teacher_out = nn.functional.softmax((teacher_output - self.center) / temp, dim=-1) + teacher_out = teacher_out.detach().chunk(2) + + total_loss = 0 + n_loss_terms = 0 + for iq, q in enumerate(teacher_out): + for v in range(len(student_out)): + if v == iq: + # we skip cases where student and teacher operate on the same view + continue + loss = torch.sum(-q * nn.functional.log_softmax(student_out[v], dim=-1), dim=-1) + total_loss += loss.mean() + n_loss_terms += 1 + total_loss /= n_loss_terms + self.update_center(teacher_output) + return total_loss + + @torch.no_grad() + def update_center(self, teacher_output): + """ + Update center used for teacher output. + """ + batch_center = torch.sum(teacher_output, dim=0, keepdim=True) + world_size = get_world_size() + if world_size == 1: # 1-gpu w/o ddp + batch_center = batch_center / len(teacher_output) + else: # ddp + dist.all_reduce(batch_center) + batch_center = batch_center / (len(teacher_output) * world_size) + + # ema update + self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum) +# Parts copied and edited from main_dino.py in facebook dino repo - end + +def init_dino(model, dinossl_args, rank = 0): + model.conf = model.get_config() # The classif_net part will NOT be correct in the dinssl case but keep around this conf for compatibility with the current code-base + if rank == 0: + logging.info('dinossl args={}'.format(dinossl_args)) + + embed_dim = model.conf['embed_dim'] + model.classif_net.output = nn.Identity() + + # model + model_teacher = copy.deepcopy(model) + model = MultiCropWrapper(model, DINOHead(embed_dim, dinossl_args['dinossl_out_dim'], use_bn=dinossl_args['dinossl_use_bn_in_head'], + norm_last_layer=dinossl_args['dinossl_norm_last_layer'], nlayers=dinossl_args['dinossl_nlayers'])) # multi-crop wrapper handles forward with inputs of different chunk lengths + model_teacher = MultiCropWrapper(model_teacher, DINOHead(embed_dim, dinossl_args['dinossl_out_dim'], use_bn=dinossl_args['dinossl_use_bn_in_head'], + nlayers=dinossl_args['dinossl_nlayers'])) + # teacher and student start with the same weights. "requires_grad = False" happens in torch_trainer_dinossl.py + model_teacher.load_state_dict(model.state_dict()) + model = [model, model_teacher] + # loss + loss = DINOLoss(dinossl_args['dinossl_out_dim'], + dinossl_args['dinossl_local_crops_number'] + 2, # total number of crops = 2 global crops + local_crops_number + dinossl_args['dinossl_warmup_teacher_temp'], + dinossl_args['dinossl_teacher_temp'], + dinossl_args['dinossl_warmup_teacher_temp_epochs'], + dinossl_args['epochs'],) # to(device) will happen in torch_trainer_dinossl.py + + if rank == 0: + logging.info('dinossl-model={}'.format(model)) + return model, loss \ No newline at end of file
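
For readers of this diff, the sketch below (not part of the patch) shows how the main pieces added above fit together outside the trainers: a `MultiCropWrapper` student/teacher pair with `DINOHead`, the `DINOLoss`, the `cosine_scheduler` schedules, and a single optimization step with the teacher EMA update. The toy backbone, the tiny dimensions and hyperparameter values, and the use of plain `AdamW` are illustrative assumptions; the actual recipes build hyperion x-vector models through `init_dino` and train them with `DINOSSLXVectorTrainerFromWav` using the configs in `egs/voxceleb/dinossl.v1`.

```python
import torch
import torch.nn as nn

# These names come from the new hyperion/torch/utils/dinossl.py added in this diff.
from hyperion.torch.utils.dinossl import (
    MultiCropWrapper, DINOHead, DINOLoss, cosine_scheduler,
    get_params_groups, cancel_gradients_last_layer)


class ToyBackbone(nn.Module):
    """Stand-in for an x-vector encoder: mean-pool over time, then a linear projection."""

    def __init__(self, feat_dim=80, embed_dim=256):
        super().__init__()
        self.proj = nn.Linear(feat_dim, embed_dim)

    def forward(self, x):                      # x: (batch, feat_dim, time)
        return self.proj(x.mean(dim=-1))       # -> (batch, embed_dim)


embed_dim, out_dim = 256, 4096                 # out_dim plays the role of dinossl_out_dim
ncrops = 2 + 2                                 # 2 global crops + dinossl_local_crops_number=2
epochs, niter_per_ep, batch_size = 2, 5, 8     # tiny values just for illustration

student = MultiCropWrapper(ToyBackbone(embed_dim=embed_dim), DINOHead(embed_dim, out_dim))
teacher = MultiCropWrapper(ToyBackbone(embed_dim=embed_dim), DINOHead(embed_dim, out_dim))
teacher.load_state_dict(student.state_dict())  # teacher starts as a copy of the student
for p in teacher.parameters():
    p.requires_grad = False                    # the teacher is only updated by EMA

loss_fn = DINOLoss(out_dim, ncrops, warmup_teacher_temp=0.04, teacher_temp=0.04,
                   warmup_teacher_temp_epochs=0, nepochs=epochs)

# Cosine schedules for lr (with linear scaling and warmup), weight decay and teacher momentum.
lr_sched = cosine_scheduler(0.005 * batch_size / 256., 1e-6, epochs, niter_per_ep,
                            warmup_epochs=1)
wd_sched = cosine_scheduler(0.04, 0.4, epochs, niter_per_ep)
mom_sched = cosine_scheduler(0.996, 1.0, epochs, niter_per_ep)

# Biases and norm parameters go into a second group with weight_decay=0.
optimizer = torch.optim.AdamW(get_params_groups(student))

# One step on a fake multi-crop batch: two long global crops, then two short local crops.
crops = [torch.randn(batch_size, 80, 400), torch.randn(batch_size, 80, 400),
         torch.randn(batch_size, 80, 200), torch.randn(batch_size, 80, 200)]
it, epoch = 0, 0

for i, group in enumerate(optimizer.param_groups):
    group["lr"] = lr_sched[it]
    if i == 0:                                 # only the regularized group gets weight decay
        group["weight_decay"] = wd_sched[it]

student_out = student(crops)                   # (ncrops * batch_size, out_dim)
with torch.no_grad():
    teacher_out = teacher(crops[:2])           # the teacher only sees the global crops
loss = loss_fn(student_out, teacher_out, epoch)

optimizer.zero_grad()
loss.backward()
cancel_gradients_last_layer(epoch, student, 1) # freeze the last head layer in early epochs
optimizer.step()

with torch.no_grad():                          # EMA update of the teacher from the student
    m = mom_sched[it]
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.data.mul_(m).add_((1 - m) * p_s.detach().data)

print("dino loss:", float(loss))
```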