diff --git a/README.md b/README.md
index 7132a031..ad85f529 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,19 @@ The full API is described in the documentation page [https://hyperion-ml.readthe
## Installation Instructions
+### If you use the CLSP grid, simply run the following in the root of your cloned repo (as of Aug 15, 2022):
+```
+./prepare_egs_paths.sh
+# Then, type /home/janto/usr/local/anaconda3 when prompted with "Introduce path to your conda base installation (e.g.:/usr/local/anaconda3):"
+# and type /home/jcho/.conda/envs/hyp_persephone_jj when prompted with "Introduce name/prefix_path for your conda environment (e.g.:hyperion)"
+
+# You may see the two lines below, but it seems okay to ignore them:
+# Hyperion is not installed in env
+# Adding hyperion directory to the PYTHONPATH variable in the recipes.
+
+# Also, with this, you can skip the "Prerequisites to run the recipes" section below
+```
+
### Prerequisites
We use anaconda or miniconda, though you should be able to make it work in other python distributions
diff --git a/egs/voxceleb/dinossl.v1/README.md b/egs/voxceleb/dinossl.v1/README.md
new file mode 100644
index 00000000..5b5b93e5
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/README.md
@@ -0,0 +1,205 @@
+# VoxCeleb V1.1
+
+Recipe for the VoxCeleb Speaker Verification Task
+
+## Differences w.r.t VoxCeleb V1 recipe
+
+In recipe version V1:
+ - We compute speech augmentations and acoustic features offline and dump them to disk.
+ - Augmentation is performed using Kaldi scripts and the wav-reverberate tool.
+ - Babble noise is created on-the-fly when computing features by mixing 3-7 single speaker files.
+
+In this recipe:
+ - Speech augmentations and acoustic features are always computed on-the-fly;
+   we don't dump any features to disk.
+ - Augmentation is performed using the Hyperion SpeechAugment class.
+   - The behavior of this class is controlled
+     by the configuration file `conf/reverb_noise_aug.yaml`,
+     which mimics the proportions of noise and RIR types, and the SNRs, used in the V1 recipe.
+ - Babble noise is created offline by mixing 3-10 single-speaker files.
+
+
+## Citing
+
+## Training Data
+
+ - x-Vector network is trained on Voxceleb2 dev + test with augmentations
+ - MUSAN noise
+ - RIR reverberation
+
+## Test data
+
+ - Test data is VoxCeleb 1
+ - We evaluate 6 conditions:
+ - VoxCeleb-O (Original): Original VoxCeleb test set with 40 speakers
+ - VoxCeleb-O-cleaned: VoxCeleb-O with some erroneous trials removed
+ - VoxCeleb-E (Entire): List using all utterances of VoxCeleb1
+ - VoxCeleb-E-cleaned: VoxCeleb-E with some erroneous trials removed
+ - VoxCeleb-H (Hard): List of hard trials between all utterances of VoxCeleb1, restricted to same-gender and same-nationality trials
+ - VoxCeleb-H-cleaned: VoxCeleb-H with some erroneous trials removed
+
+## Usage
+
+ - Run the run_0*.sh scripts in sequence (see the end-to-end sketch at the end of this section)
+ - By default it will use the Light ResNet (16 base channels)
+ - For better performance, use the full ResNet (64 base channels) by passing the `config_fbank80_stmn_resnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh` file as follows:
+```bash
+run_011_train_xvector.sh --config-file config_fbank80_stmn_resnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh
+run_030_extract_xvectors.sh --config-file config_fbank80_stmn_resnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh --use-gpu true
+run_040_eval_be.sh --config-file config_fbank80_stmn_resnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh
+```
+
+ - To train with mixed precision, use the config file `config_fbank80_stmn_lresnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh`
+
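+ - The sketch below shows one possible end-to-end run with the default config (symlinked at `default_config.sh`); the `--ngpu` and `--use-gpu` values are illustrative and should be adapted to your setup:
+```bash
+./run_001_prepare_data.sh
+./run_002_compute_evad.sh
+./run_003_prepare_noises_rirs.sh
+./run_010_prepare_xvec_train_data.sh
+./run_011_train_xvector.sh --ngpu 4
+./run_030_extract_xvectors.sh --use-gpu true
+./run_040_eval_be.sh
+```
+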
+## Recipe Steps:
+
+ - `run_001_prepare_data.sh`
+ - Data preparation script to generate Kaldi style data directories for
+ - VoxCeleb2 train+test
+ - VoxCeleb1 O/E/H eval sets
+
+ - `run_002_compute_evad.sh`
+ - Computes Energy VAD for all datasets
+
+ - `run_003_prepare_noises_rirs.sh`
+ - Prepares MUSAN noises, music to be used by SpeechAugment class.
+ - Creates Babble noise from MUSAN speech to be used by SpeechAugment class.
+   - Prepares RIRs by compacting them into HDF5 files, to be used by SpeechAugment class.
+
+ - `run_010_prepare_xvec_train_data.sh`
+ - Transforms all the audios that we are going to use to train the x-vector into a common format, e.g., .flac.
+ - Removes silence from the audios
+   - Removes utterances shorter than 4 secs and speakers with fewer than 8 utterances.
+ - Creates training and validation lists for x-vector training
+
+ - `run_011_train_xvector.sh`
+ - Trains the x-vector network
+
+ - `run_030_extract_xvectors.sh`
+ - Extracts x-vectors for VoxCeleb2 or VoxCeleb2+augmentation for PLDA training
+   - Extracts x-vectors for VoxCeleb1 test sets
+
+ - `run_040_eval_be.sh`
+ - Trains PLDA and evals PLDA and cosine scoring back-ends
+
+
+## Results
+
+### VoxCeleb 1 Original-Clean trial list
+
+| Config | Model Type | Model Details | Back-end | EER(%) | MinDCF(p=0.05) | MinDCF(p=0.01) |
+| ------ | ---------- | ------------- | -------- | :----: | :------------: | :------------: |
+| config_fbank80_stmn_lresnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh | LResNet34 | ArcFace s=30/m=0.3 | PLDA | 2.00 | 0.129 | 0.216 |
+| | | | Cosine | 2.04 | 0.138 | 0.210 |
+| config_fbank80_stmn_resnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh | ResNet34 | ArcFace s=30/m=0.3 | PLDA | 1.35 | 0.091 | 0.159 |
+| | | | Cosine | 1.22 | 0.082 | 0.129 |
+| config_fbank80_stmn_resnet34_arcs30m0.3_adam_lr0.05_amp_swa.v1.sh | ResNet34 | + SWA | Cosine | 1.19 | 0.074 | 0.124 |
+| config_fbank80_stmn_resnet50_arcs30m0.3_adam_lr0.05_amp.v1.sh | ResNet50 | ArcFace s=30/m=0.3 | PLDA | 1.30 | 0.090 | 0.160 |
+| | | | Cosine | 1.44 | 0.100 | 0.173 |
+| config_fbank80_stmn_tseresnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh | Time-SE-ResNet34 | ArcFace s=30/m=0.3 | PLDA | 1.23 | 0.091 | 0.143 |
+| | | | Cosine | 1.17 | 0.081 | 0.110 |
+| config_fbank80_stmn_effnetb4_v2_arcs30m0.3_adam_lr0.01_amp.v1.sh | EfficientNet-b4 v2 | EfficientNet-b4 with strides=1122121 <br> ArcFace s=30/m=0.3 | PLDA | 1.37 | 0.104 | 0.179 |
+| | | | Cosine | 1.31 | 0.080 | 0.139 |
+| config_fbank80_stmn_effnetb7_v2_eina_hln_arcs30m0.3_adam_lr0.01_amp.v1.sh | EfficientNet-b7 v2 | EfficientNet-b7 with strides=1122121 <br> Instance-Norm with affine transform in Encoder <br> Layer-Norm in head <br> ArcFace s=30/m=0.3 | PLDA | 1.29 | 0.088 | 0.129 |
+| | | | Cosine | 1.23 | 0.083 | 0.136 |
+| config_fbank80_stmn_res2net34w16s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net34 width=16x4 | ArcFace s=30/m=0.3 | PLDA | 1.20 | 0.095 | 0.156 |
+| | | | Cosine | 1.29 | 0.089 | 0.146 |
+| config_fbank80_stmn_res2net34w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net34 width=26x4 | ArcFace s=30/m=0.3 | PLDA | 1.20 | 0.084 | 0.136 |
+| | | | Cosine | 1.18 | 0.078 | 0.115 |
+| config_fbank80_stmn_res2net50w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net50 width=26x4 | ArcFace s=30/m=0.3 | PLDA | 1.11 | 0.084 | 0.145 |
+| | | | Cosine | 1.12 | 0.073 | 0.131 |
+| config_fbank80_stmn_seres2net50w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | SE-Res2Net50 | se-r=16 <br> ArcFace s=30/m=0.3 | PLDA | 1.53 | 0.104 | 0.189 |
+| | | | Cosine | 1.31 | 0.084 | 0.132 |
+| config_fbank80_stmn_tseres2net50w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Time-SE-Res2Net50 | se-r=256 <br> ArcFace s=30/m=0.3 | PLDA | 0.98 | 0.066 | 0.116 |
+| | | | Cosine | 1.12 | 0.071 | 0.103 |
+| config_fbank80_stmn_res2net50w13s8_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net50 width=13x8 | ArcFace s=30/m=0.3 | PLDA | 1.05 | 0.077 | 0.123 |
+| | | | Cosine | 0.96 | 0.065 | 0.110 |
+| config_fbank80_stmn_res2net50w26s8_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net50 width=26x8 | ArcFace s=30/m=0.3 | PLDA | 1.04 | 0.071 | 0.118 |
+| | | | Cosine | 0.93 | 0.067 | 0.108 |
+| config_fbank80_stmn_res2net50w26s8_arcs30m0.3_adam_lr0.05_amp.v1_swa.sh | Res2Net50 width=26x8 | + SWA | PLDA | 0.90 | 0.067 | 0.118 |
+| | | | Cosine | 0.85 | 0.060 | 0.094 |
+| config_fbank80_stmn_spinenet49s_arcs30m0.3_adam_lr0.05_amp.v1.sh | SpineNet49S | ArcFace s=30/m=0.3 | PLDA | 1.44 | 0.102 | 0.169 |
+| | | | Cosine | 1.29 | 0.084 | 0.140 |
+| config_fbank80_stmn_spinenet49_arcs30m0.3_adam_lr0.05_amp.v1.sh | SpineNet49 | ArcFace s=30/m=0.3 | Cosine | 1.12 | 0.071 | 0.116 |
+| config_fbank80_stmn_spine2net49_arcs30m0.3_adam_lr0.05_amp.v1.sh | Spine2Net49 | ArcFace s=30/m=0.3 | Cosine | 1.05 | 0.074 | 0.116 |
+| config_fbank80_stmn_tsespine2net49_arcs30m0.3_adam_lr0.05_amp.v1.sh | TSE-Spine2Net49 | ArcFace s=30/m=0.3 | Cosine | 1.09 | 0.081 | 0.150 |
+
+
+### VoxCeleb 1 Entire-Clean trial list
+
+| Config | Model Type | Model Details | Back-end | EER(%) | MinDCF(p=0.05) | MinDCF(p=0.01) |
+| ------ | ---------- | ------------- | -------- | :----: | :------------: | :------------: |
+| config_fbank80_stmn_lresnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh | LResNet34 | ArcFace s=30/m=0.3 | PLDA | 1.86 | 0.124 | 0.210 |
+| | | | Cosine | 1.93 | 0.122 | 0.201 |
+| config_fbank80_stmn_resnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh | ResNet34 | ArcFace s=30/m=0.3 | PLDA | 1.43 | 0.091 | 0.159 |
+| | | | Cosine | 1.24 | 0.080 | 0.136 |
+| config_fbank80_stmn_resnet34_arcs30m0.3_adam_lr0.05_amp_swa.v1.sh | ResNet34 | + SWA | Cosine | 1.19 | 0.077 | 0.132 |
+| config_fbank80_stmn_resnet50_arcs30m0.3_adam_lr0.05_amp.v1.sh | ResNet50 | ArcFace s=30/m=0.3 | PLDA | 1.27 | 0.084 | 0.150 |
+| | | | Cosine | 1.30 | 0.082 | 0.150 |
+| config_fbank80_stmn_tseresnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh | Time-SE-ResNet34 | ArcFace s=30/m=0.3 | PLDA | 1.30 | 0.083 | 0.146 |
+| | | | Cosine | 1.09 | 0.071 | 0.124 |
+| config_fbank80_stmn_effnetb4_v2_arcs30m0.3_adam_lr0.01_amp.v1.sh | EfficientNet-b4 v2 | EfficientNet-b4 with strides=1122121 <br> ArcFace s=30/m=0.3 | PLDA | 1.45 | 0.097 | 0.165 |
+| | | | Cosine | 1.15 | 0.076 | 0.132 |
+| config_fbank80_stmn_effnetb7_v2_eina_hln_arcs30m0.3_adam_lr0.01_amp.v1.sh | EfficientNet-b7 v2 | EfficientNet-b7 with strides=1122121 <br> Instance-Norm with affine transform in Encoder <br> Layer-Norm in head <br> ArcFace s=30/m=0.3 | PLDA | 1.47 | 0.094 | 0.165 |
+| | | | Cosine | 1.27 | 0.082 | 0.148 |
+| config_fbank80_stmn_res2net34w16s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net34 width=16x4 | ArcFace s=30/m=0.3 | PLDA | 1.31 | 0.086 | 0.149 |
+| | | | Cosine | 1.22 | 0.079 | 0.134 |
+| config_fbank80_stmn_res2net34w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net34 width=26x4 | ArcFace s=30/m=0.3 | PLDA | 1.27 | 0.082 | 0.145 |
+| | | | Cosine | 1.16 | 0.074 | 0.130 |
+| config_fbank80_stmn_res2net50w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net50 width=26x4 | ArcFace s=30/m=0.3 | PLDA | 1.23 | 0.077 | 0.136 |
+| | | | Cosine | 1.11 | 0.071 | 0.125 |
+| config_fbank80_stmn_seres2net50w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | SE-Res2Net50 | se-r=16 <br> ArcFace s=30/m=0.3 | PLDA | 1.46 | 0.097 | 0.173 |
+| | | | Cosine | 1.24 | 0.080 | 0.140 |
+| config_fbank80_stmn_tseres2net50w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Time-SE-Res2Net50 | se-r=256 <br> ArcFace s=30/m=0.3 | PLDA | 1.11 | 0.071 | 0.127 |
+| | | | Cosine | 1.05 | 0.067 | 0.117 |
+| config_fbank80_stmn_res2net50w13s8_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net50 width=13x8 | ArcFace s=30/m=0.3 | PLDA | 1.23 | 0.078 | 0.134 |
+| | | | Cosine | 1.05 | 0.069 | 0.121 |
+| config_fbank80_stmn_res2net50w26s8_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net50 width=26x8 | ArcFace s=30/m=0.3 | PLDA | 1.18 | 0.075 | 0.131 |
+| | | | Cosine | 0.98 | 0.063 | 0.110 |
+| config_fbank80_stmn_res2net50w26s8_arcs30m0.3_adam_lr0.05_amp_swa.v1.sh | Res2Net50 width=26x8 | + SWA | PLDA | 1.17 | 0.072 | 0.123 |
+| | | | Cosine | 0.94 | 0.061 | 0.107 |
+| config_fbank80_stmn_spinenet49s_arcs30m0.3_adam_lr0.05_amp.v1.sh | SpineNet49S | ArcFace s=30/m=0.3 | PLDA | 1.56 | 0.095 | 0.166 |
+| | | | Cosine | 1.27 | 0.079 | 0.142 |
+| config_fbank80_stmn_spinenet49_arcs30m0.3_adam_lr0.05_amp.v1.sh | SpineNet49 | ArcFace s=30/m=0.3 | Cosine | 1.19 | 0.077 | 0.137 |
+| config_fbank80_stmn_spine2net49_arcs30m0.3_adam_lr0.05_amp.v1.sh | Spine2Net49 | ArcFace s=30/m=0.3 | Cosine | 1.12 | 0.073 | 0.129 |
+| config_fbank80_stmn_tsespine2net49_arcs30m0.3_adam_lr0.05_amp.v1.sh | TSE-Spine2Net49 | ArcFace s=30/m=0.3 | Cosine | 1.05 | 0.068 | 0.120 |
+
+
+### VoxCeleb 1 Hard-Clean trial list
+
+| Config | Model Type | Model Details | Back-end | EER(%) | MinDCF(p=0.05) | MinDCF(p=0.01) |
+| ------ | ---------- | ------------- | -------- | :----: | :------------: | :------------: |
+| config_fbank80_stmn_lresnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh | LResNet34 | ArcFace s=30/m=0.3 | PLDA | 3.29 | 0.195 | 0.318 |
+| | | | Cosine | 3.27 | 0.188 | 0.303 |
+| config_fbank80_stmn_resnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh | ResNet34 | ArcFace s=30/m=0.3 | PLDA | 2.66 | 0.160 | 0.258 |
+| | | | Cosine | 2.32 | 0.139 | 0.232 |
+| config_fbank80_stmn_resnet34_arcs30m0.3_adam_lr0.05_amp_swa.v1.sh | ResNet34 | + SWA | Cosine | 2.19 | 0.133 | 0.215 |
+| config_fbank80_stmn_resnet50_arcs30m0.3_adam_lr0.05_amp.v1.sh | ResNet50 | ArcFace s=30/m=0.3 | PLDA | 2.33 | 0.139 | 0.227 |
+| | | | Cosine | 2.33 | 0.142 | 0.235 |
+| config_fbank80_stmn_tseresnet34_arcs30m0.3_adam_lr0.05_amp.v1.sh | Time-SE-ResNet34 | ArcFace s=30/m=0.3 | PLDA | 2.46 | 0.142 | 0.237 |
+| | | | Cosine | 2.14 | 0.126 | 0.203 |
+| config_fbank80_stmn_effnetb4_v2_arcs30m0.3_adam_lr0.01_amp.v1.sh | EfficientNet-b4 v2 | EfficientNet-b4 with strides=1122121 <br> ArcFace s=30/m=0.3 | PLDA | 2.57 | 0.153 | 0.255 |
+| | | | Cosine | 2.11 | 0.127 | 0.205 |
+| config_fbank80_stmn_effnetb7_v2_eina_hln_arcs30m0.3_adam_lr0.01_amp.v1.sh | EfficientNet-b7 v2 | EfficientNet-b7 with strides=1122121 <br> Instance-Norm with affine transform in Encoder <br> Layer-Norm in head <br> ArcFace s=30/m=0.3 | PLDA | 2.64 | 0.157 | 0.244 |
+| | | | Cosine | 2.33 | 0.141 | 0.232 |
+| config_fbank80_stmn_res2net34w16s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net34 width=16x4 | ArcFace s=30/m=0.3 | PLDA | 2.42 | 0.144 | 0.245 |
+| | | | Cosine | 2.26 | 0.133 | 0.224 |
+| config_fbank80_stmn_res2net34w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net34 width=26x4 | ArcFace s=30/m=0.3 | PLDA | 2.39 | 0.141 | 0.235 |
+| | | | Cosine | 2.17 | 0.128 | 0.215 |
+| config_fbank80_stmn_res2net50w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net50 width=26x4 | ArcFace s=30/m=0.3 | PLDA | 2.28 | 0.131 | 0.225 |
+| | | | Cosine | 2.11 | 0.124 | 0.204 |
+| config_fbank80_stmn_seres2net50w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | SE-Res2Net50 | se-r=16 <br> ArcFace s=30/m=0.3 | PLDA | 2.77 | 0.172 | 0.271 |
+| | | | Cosine | 2.45 | 0.141 | 0.225 |
+| config_fbank80_stmn_tseres2net50w26s4_arcs30m0.3_adam_lr0.05_amp.v1.sh | Time-SE-Res2Net50 | se-r=256 <br> ArcFace s=30/m=0.3 | PLDA | 2.07 | 0.124 | 0.201 |
+| | | | Cosine | 1.95 | 0.113 | 0.181 |
+| config_fbank80_stmn_res2net50w13s8_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net50 width=13x8 | ArcFace s=30/m=0.3 | PLDA | 2.34 | 0.136 | 0.230 |
+| | | | Cosine | 1.99 | 0.119 | 0.196 |
+| config_fbank80_stmn_res2net50w26s8_arcs30m0.3_adam_lr0.05_amp.v1.sh | Res2Net50 width=26x8 | ArcFace s=30/m=0.3 | PLDA | 2.18 | 0.127 | 0.211 |
+| | | | Cosine | 1.89 | 0.112 | 0.184 |
+| config_fbank80_stmn_res2net50w26s8_arcs30m0.3_adam_lr0.05_amp.v1_swa.sh | Res2Net50 width=26x8 | + SWA | PLDA | 2.14 | 0.125 | 0.209 |
+| | | | Cosine | 1.84 | 0.110 | 0.186 |
+| config_fbank80_stmn_spinenet49s_arcs30m0.3_adam_lr0.05_amp.v1.sh | SpineNet49S | ArcFace s=30/m=0.3 | PLDA | 2.78 | 0.156 | 0.252 |
+| | | | Cosine | 2.26 | 0.134 | 0.214 |
+| config_fbank80_stmn_spinenet49_arcs30m0.3_adam_lr0.05_amp.v1.sh | SpineNet49 | ArcFace s=30/m=0.3 | Cosine | 2.24 | 0.134 | 0.221 |
+| config_fbank80_stmn_spine2net49_arcs30m0.3_adam_lr0.05_amp.v1.sh | Spine2Net49 | ArcFace s=30/m=0.3 | Cosine | 2.20 | 0.132 | 0.219 |
+| config_fbank80_stmn_tsespine2net49_arcs30m0.3_adam_lr0.05_amp.v1.sh | TSE-Spine2Net49 | ArcFace s=30/m=0.3 | Cosine | 2.02 | 0.123 | 0.203 |
diff --git a/egs/voxceleb/dinossl.v1/cmd.sh b/egs/voxceleb/dinossl.v1/cmd.sh
new file mode 100755
index 00000000..040f458b
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/cmd.sh
@@ -0,0 +1,28 @@
+# you can change cmd.sh depending on what type of queue you are using.
+# If you have no queueing system and want to run on a local machine, you
+# can change all instances 'queue.pl' to run.pl (but be careful and run
+# commands one by one: most recipes will exhaust the memory on your
+# machine). queue.pl works with GridEngine (qsub). slurm.pl works
+# with slurm. Different queues are configured differently, with different
+# queue names and different ways of specifying things like memory;
+# to account for these differences you can create and edit the file
+# conf/queue.conf to match your queue's configuration. Search for
+# conf/queue.conf in http://kaldi-asr.org/doc/queue.html for more information,
+# or search for the string 'default_config' in utils/queue.pl or utils/slurm.pl.
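+#
+# For example, to run everything locally without a queue, you could point all of the
+# commands below to run.pl (a sketch; adjust memory/GPU limits to your machine):
+# export train_cmd=run.pl
+# export cuda_cmd=run.pl
+# export cuda_eval_cmd=run.pl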
+
+if [ "$(hostname -d)" == "cm.gemini" ];then
+ #export train_cmd="queue.pl --config conf/coe_gpu_short.conf --mem 4G"
+ export train_cmd="queue.pl --config conf/coe_gpu_long.conf --mem 4G"
+ export cuda_cmd="queue.pl --config conf/coe_gpu_long.conf --mem 20G"
+ #export cuda_cmd="queue.pl --config conf/coe_gpu_v100.conf --mem 20G"
+ export cuda_cmd="queue.pl --config conf/coe_gpu_rtx.conf --mem 40G"
+ export cuda_eval_cmd="queue.pl --config conf/coe_gpu_short.conf --mem 4G"
+ # export cuda_eval_cmd="queue.pl --config conf/coe_gpu_long.conf --mem 4G"
+else
+ export train_cmd="queue.pl --mem 4G -l hostname=\"[bc][01]*\" -V"
+ export cuda_cmd="queue.pl --mem 20G -l hostname=\"c[01]*\" -V"
+ export cuda_eval_cmd="$train_cmd"
+fi
+
+
+
diff --git a/egs/voxceleb/dinossl.v1/conf/clsp.conf b/egs/voxceleb/dinossl.v1/conf/clsp.conf
new file mode 100644
index 00000000..4ed38246
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/clsp.conf
@@ -0,0 +1,11 @@
+
+# Default configuration
+command qsub -v PATH -cwd -S /bin/bash -j y -l arch=*64* -V
+option mem=* -l mem_free=$0,ram_free=$0
+option mem=0 # Do not add anything to qsub_opts
+option num_threads=* -pe smp $0
+option num_threads=1 # Do not add anything to qsub_opts
+option max_jobs_run=* -tc $0
+default gpu=0
+option gpu=0 -l 'hostname=b[1]*|c0[123456789]*|c1[134679]*|c2[1357]*'
+option gpu=* -l 'hostname=c0[123456789]*|c1[1345679]*|c2[12357]*,gpu=$0'
diff --git a/egs/voxceleb/dinossl.v1/conf/coe_gpu_bigmem.conf b/egs/voxceleb/dinossl.v1/conf/coe_gpu_bigmem.conf
new file mode 100644
index 00000000..a7a2ce40
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/coe_gpu_bigmem.conf
@@ -0,0 +1,11 @@
+
+# Default configuration
+command qsub -v PATH -cwd -S /bin/bash -j y -sync y -l arch=*64* -V
+option mem=* -l mem_free=$0
+option mem=0 # Do not add anything to qsub_opts
+option num_threads=* -l num_proc=$0
+option num_threads=1 # Do not add anything to qsub_opts
+option max_jobs_run=* -tc $0
+default gpu=0
+option gpu=0 -q all.q -l h_rt=100:00:00 -l hostname=r[2-7]*
+option gpu=* -l gpu=$0,h_rt=500:00:00 -q gpu.q -l hostname=r[237]n[01][0123456789]*
diff --git a/egs/voxceleb/dinossl.v1/conf/coe_gpu_long.conf b/egs/voxceleb/dinossl.v1/conf/coe_gpu_long.conf
new file mode 100644
index 00000000..b31c167c
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/coe_gpu_long.conf
@@ -0,0 +1,13 @@
+
+# Default configuration
+command qsub -v PATH -cwd -S /bin/bash -j y -sync y -l arch=*64* -V
+option mem=* -l mem_free=$0
+option mem=0 # Do not add anything to qsub_opts
+option num_threads=* -l num_proc=$0
+option num_threads=1 # Do not add anything to qsub_opts
+option max_jobs_run=* -tc $0
+default gpu=0
+option gpu=0 -q all.q -l h_rt=100:00:00 -l hostname=r[1-9]*
+option gpu=* -l gpu=$0,h_rt=500:00:00 -q gpu.q -l hostname=r[1-9]*
+
+
diff --git a/egs/voxceleb/dinossl.v1/conf/coe_gpu_rtx.conf b/egs/voxceleb/dinossl.v1/conf/coe_gpu_rtx.conf
new file mode 100644
index 00000000..ba6d9e56
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/coe_gpu_rtx.conf
@@ -0,0 +1,11 @@
+
+# Default configuration
+command qsub -v PATH -cwd -S /bin/bash -j y -sync y -l arch=*64* -V
+option mem=* -l mem_free=$0
+option mem=0 # Do not add anything to qsub_opts
+option num_threads=* -l num_proc=$0
+option num_threads=1 # Do not add anything to qsub_opts
+option max_jobs_run=* -tc $0
+default gpu=0
+option gpu=0 -q all.q -l h_rt=100:00:00
+option gpu=* -l gpu=$0,h_rt=500:00:00 -q gpu.q@@rtx
diff --git a/egs/voxceleb/dinossl.v1/conf/coe_gpu_short.conf b/egs/voxceleb/dinossl.v1/conf/coe_gpu_short.conf
new file mode 100644
index 00000000..81de5cb7
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/coe_gpu_short.conf
@@ -0,0 +1,11 @@
+
+# Default configuration
+command qsub -v PATH -cwd -S /bin/bash -j y -sync y -l arch=*64* -V
+option mem=* -l mem_free=$0
+option mem=0 # Do not add anything to qsub_opts
+option num_threads=* -l num_proc=$0
+option num_threads=1 # Do not add anything to qsub_opts
+option max_jobs_run=* -tc $0
+default gpu=0
+option gpu=0 -q all.q -l h_rt=100:00:00 -l hostname=r[1-9]*
+option gpu=* -l gpu=$0,h_rt=00:59:00 -q gpu_short.q -l hostname=r[17]*
diff --git a/egs/voxceleb/dinossl.v1/conf/coe_gpu_v100.conf b/egs/voxceleb/dinossl.v1/conf/coe_gpu_v100.conf
new file mode 100644
index 00000000..69326b82
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/coe_gpu_v100.conf
@@ -0,0 +1,11 @@
+
+# Default configuration
+command qsub -v PATH -cwd -S /bin/bash -j y -sync y -l arch=*64* -V
+option mem=* -l mem_free=$0
+option mem=0 # Do not add anything to qsub_opts
+option num_threads=* -l num_proc=$0
+option num_threads=1 # Do not add anything to qsub_opts
+option max_jobs_run=* -tc $0
+default gpu=0
+option gpu=0 -q all.q -l h_rt=100:00:00
+option gpu=* -l gpu=$0,h_rt=500:00:00 -q gpu.q@@v100
diff --git a/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/fbank80_stmn_16k.yaml b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/fbank80_stmn_16k.yaml
new file mode 100644
index 00000000..f4091f5d
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/fbank80_stmn_16k.yaml
@@ -0,0 +1,12 @@
+audio_feats:
+ audio_feat: logfb
+ sample_frequency: 16000
+ frame_length: 25
+ low_freq: 20
+ high_freq: 7600
+ num_filters: 80
+ snip_edges: false
+ use_energy: false
+mvn:
+ context: 150
+ norm_var: false
diff --git a/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/lrsched_cos_default.yaml b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/lrsched_cos_default.yaml
new file mode 100644
index 00000000..6f1009c0
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/lrsched_cos_default.yaml
@@ -0,0 +1,7 @@
+lrsch_type: dinossl
+dinossl_lr: 0.005 # For now, the learning rate is scaled linearly with the batch size. The value specified here is for a batch size of 256 (i.e., if ngpu * batch_size_per_gpu = 512, the lr becomes 0.005 * 2)
+dinossl_min_lr: 1e-6
+dinossl_warmup_epochs: 10
+dinossl_weight_decay: 1e-4
+dinossl_weight_decay_end: 1e-4
+dinossl_momentum_teacher: 0.996
diff --git a/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/optim_adam_default.yaml b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/optim_adam_default.yaml
new file mode 100644
index 00000000..8362a9b5
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/optim_adam_default.yaml
@@ -0,0 +1,6 @@
+opt_type: adam
+amsgrad: true
+beta1: 0.9
+beta2: 0.95
+weight_decay: 1.0e-05
+dinossl_style: true # TODO: check whether the above arguments are actually used; if they are not, delete them all
diff --git a/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/resnet34.yaml b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/resnet34.yaml
new file mode 100644
index 00000000..4042ec33
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/resnet34.yaml
@@ -0,0 +1,8 @@
+resnet_type: resnet34
+in_channels: 1
+in_feats: 80
+in_kernel_size: 3
+in_stride: 1
+no_maxpool: true
+embed_dim: 256
+dropout_rate: 0.0
diff --git a/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/train_data_default.yaml b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/train_data_default.yaml
new file mode 100644
index 00000000..6a763f47
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/train_data_default.yaml
@@ -0,0 +1,10 @@
+dataset:
+ max_chunk_length: 2.0
+ min_chunk_length: 2.0
+ aug_cfg: conf/reverb_noise_aug.yaml
+sampler:
+ batch_size: 48 # 52:64:4 OOM w/ lresnet34 in this version of hyperion (It wasn't in the old version).
+ iters_per_epoch: 1
+data_loader:
+ num_workers: 8
+
diff --git a/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/train_resnet34_xvec_default.yaml b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/train_resnet34_xvec_default.yaml
new file mode 100644
index 00000000..1d387790
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/train_resnet34_xvec_default.yaml
@@ -0,0 +1,7 @@
+data:
+ train: train_data_default.yaml
+ val: val_data_default.yaml
+feats: fbank80_stmn_16k.yaml
+model: resnet34.yaml
+trainer: trainer_default.yaml
+
\ No newline at end of file
diff --git a/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/trainer_default.yaml b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/trainer_default.yaml
new file mode 100644
index 00000000..279b2829
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/trainer_default.yaml
@@ -0,0 +1,6 @@
+optim: optim_adam_default.yaml
+lrsched: lrsched_cos_default.yaml
+use_amp: true
+log_interval: 100
+epochs: 70
+eff_batch_size: 128 # For now, this won't be used in dinossl
diff --git a/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/val_data_default.yaml b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/val_data_default.yaml
new file mode 100644
index 00000000..127af82c
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/dinossl_tuning/val_data_default.yaml
@@ -0,0 +1,10 @@
+dataset:
+ max_chunk_length: 2.0
+ min_chunk_length: 2.0
+ aug_cfg: conf/reverb_noise_aug.yaml
+sampler:
+ batch_size: 32
+ iters_per_epoch: 1
+data_loader:
+ num_workers: 0 #8
+
diff --git a/egs/voxceleb/dinossl.v1/conf/ecapatdnn_small.yaml b/egs/voxceleb/dinossl.v1/conf/ecapatdnn_small.yaml
new file mode 100644
index 00000000..fd386500
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/ecapatdnn_small.yaml
@@ -0,0 +1,34 @@
+resnet_enc:
+ in_feats: 80
+ in_conv_channels: 512
+ in_kernel_size: 5
+ in_stride: 1
+ resb_type: seres2bn
+ resb_repeats:
+ - 1
+ - 1
+ - 1
+ resb_channels:
+ - 512
+ resb_kernel_sizes:
+ - 3
+ resb_dilations:
+ - 2
+ - 3
+ - 4
+ resb_strides:
+ - 1
+ res2net_width_factor: 1
+ res2net_scale: 8
+ se_r: 4
+ multilayer: true
+ multilayer_concat: true
+ endpoint_channels: 1536
+pool_net:
+ pool_type: ch-wise-att-mean+stddev
+ inner_feats: 128
+embed_dim: 256
+cos_scale: 30.0
+margin: 0.3
+margin_warmup_epochs: 20.0
+dropout_rate: 0.0
diff --git a/egs/voxceleb/dinossl.v1/conf/efficientnet_b4.yaml b/egs/voxceleb/dinossl.v1/conf/efficientnet_b4.yaml
new file mode 100644
index 00000000..f87c1e02
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/efficientnet_b4.yaml
@@ -0,0 +1,20 @@
+effnet_type: efficientnet-b4
+in_feats: 80
+in_channels: 1
+in_kernel_size: 3
+in_stride: 1
+se_r: 4
+fix_stem_head: true
+mbconv_strides:
+- 1
+- 1
+- 2
+- 2
+- 1
+- 2
+- 1
+embed_dim: 256
+cos_scale: 30.0
+margin: 0.3
+margin_warmup_epochs: 20.0
+dropout_rate: 0.0
diff --git a/egs/voxceleb/dinossl.v1/conf/efficientnet_b7.yaml b/egs/voxceleb/dinossl.v1/conf/efficientnet_b7.yaml
new file mode 100644
index 00000000..bae5c7cb
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/efficientnet_b7.yaml
@@ -0,0 +1,22 @@
+effnet_type: efficientnet-b7
+in_feats: 80
+in_channels: 1
+in_kernel_size: 3
+in_stride: 1
+se_r: 4
+fix_stem_head: true
+mbconv_strides:
+- 1
+- 1
+- 2
+- 2
+- 1
+- 2
+- 1
+embed_dim: 256
+cos_scale: 30.0
+margin: 0.3
+margin_warmup_epochs: 20.0
+dropout_rate: 0.0
+norm_layer: instance-norm-affine
+head_norm_layer: layer-norm
diff --git a/egs/voxceleb/dinossl.v1/conf/fbank64_8k.yaml b/egs/voxceleb/dinossl.v1/conf/fbank64_8k.yaml
new file mode 100644
index 00000000..a77eb899
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/fbank64_8k.yaml
@@ -0,0 +1,7 @@
+sample_frequency: 8000
+frame_length: 25
+low_freq: 20
+high_freq: 3700
+num_filters: 64
+snip_edges: false
+use_energy: false
diff --git a/egs/voxceleb/dinossl.v1/conf/fbank64_stmn_8k.yaml b/egs/voxceleb/dinossl.v1/conf/fbank64_stmn_8k.yaml
new file mode 100644
index 00000000..dfd0d3e5
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/fbank64_stmn_8k.yaml
@@ -0,0 +1,12 @@
+audio_feats:
+ audio_feat: logfb
+ sample_frequency: 8000
+ frame_length: 25
+ low_freq: 20
+ high_freq: 3700
+ num_filters: 64
+ snip_edges: false
+ use_energy: false
+mvn:
+ context: 150
+ norm_var: false
diff --git a/egs/voxceleb/dinossl.v1/conf/fbank80_16k.yaml b/egs/voxceleb/dinossl.v1/conf/fbank80_16k.yaml
new file mode 100644
index 00000000..88bae69e
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/fbank80_16k.yaml
@@ -0,0 +1,7 @@
+sample_frequency: 16000
+frame_length: 25
+low_freq: 20
+high_freq: 7600
+num_filters: 80
+snip_edges: false
+use_energy: false
diff --git a/egs/voxceleb/dinossl.v1/conf/fbank80_stmn_16k.yaml b/egs/voxceleb/dinossl.v1/conf/fbank80_stmn_16k.yaml
new file mode 100644
index 00000000..f4091f5d
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/fbank80_stmn_16k.yaml
@@ -0,0 +1,12 @@
+audio_feats:
+ audio_feat: logfb
+ sample_frequency: 16000
+ frame_length: 25
+ low_freq: 20
+ high_freq: 7600
+ num_filters: 80
+ snip_edges: false
+ use_energy: false
+mvn:
+ context: 150
+ norm_var: false
diff --git a/egs/voxceleb/dinossl.v1/conf/lrsched_exp_default.yaml b/egs/voxceleb/dinossl.v1/conf/lrsched_exp_default.yaml
new file mode 100644
index 00000000..fe08b704
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/lrsched_exp_default.yaml
@@ -0,0 +1,7 @@
+lrsch_type: exp_lr
+decay_rate: 0.5
+decay_steps: 8000
+hold_steps: 40000
+min_lr: 1.0e-05
+update_lr_on_opt_step: true
+warmup_steps: 1000
diff --git a/egs/voxceleb/dinossl.v1/conf/noise_aug.yaml b/egs/voxceleb/dinossl.v1/conf/noise_aug.yaml
new file mode 100644
index 00000000..7e575faf
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/noise_aug.yaml
@@ -0,0 +1,19 @@
+noise_aug:
+ noise_prob: 0.7
+ noise_types:
+ noise:
+ weight: 1
+ noise_path: data/musan_noise_proc_audio/wav.scp
+ min_snr: 0
+ max_snr: 18
+ music:
+ weight: 1
+ noise_path: data/musan_music_proc_audio/wav.scp
+ min_snr: 3
+ max_snr: 18
+ babble:
+ weight: 1
+ noise_path: data/musan_speech_babble/wav.scp
+ min_snr: 3
+ max_snr: 18
+
diff --git a/egs/voxceleb/dinossl.v1/conf/online_pitch.conf b/egs/voxceleb/dinossl.v1/conf/online_pitch.conf
new file mode 100644
index 00000000..926bcfca
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/online_pitch.conf
@@ -0,0 +1 @@
+--sample-frequency=8000
diff --git a/egs/voxceleb/dinossl.v1/conf/optim_adam_default.yaml b/egs/voxceleb/dinossl.v1/conf/optim_adam_default.yaml
new file mode 100644
index 00000000..b6620069
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/optim_adam_default.yaml
@@ -0,0 +1,6 @@
+opt_type: adam
+lr: 0.05
+amsgrad: true
+beta1: 0.9
+beta2: 0.95
+weight_decay: 1.0e-05
diff --git a/egs/voxceleb/dinossl.v1/conf/res2net50.yaml b/egs/voxceleb/dinossl.v1/conf/res2net50.yaml
new file mode 100644
index 00000000..48067a3d
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/res2net50.yaml
@@ -0,0 +1,13 @@
+resnet_type: res2net50
+in_channels: 1
+in_feats: 80
+in_kernel_size: 3
+in_stride: 1
+no_maxpool: true
+res2net_width_factor: 3.25
+res2net_scale: 8
+embed_dim: 256
+cos_scale: 30.0
+margin: 0.3
+margin_warmup_epochs: 20.0
+dropout_rate: 0.0
diff --git a/egs/voxceleb/dinossl.v1/conf/resnet34.yaml b/egs/voxceleb/dinossl.v1/conf/resnet34.yaml
new file mode 100644
index 00000000..98695823
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/resnet34.yaml
@@ -0,0 +1,11 @@
+resnet_type: resnet34
+in_channels: 1
+in_feats: 80
+in_kernel_size: 3
+in_stride: 1
+no_maxpool: true
+embed_dim: 256
+cos_scale: 30.0
+margin: 0.3
+margin_warmup_epochs: 20.0
+dropout_rate: 0.0
diff --git a/egs/voxceleb/dinossl.v1/conf/reverb_noise_aug.yaml b/egs/voxceleb/dinossl.v1/conf/reverb_noise_aug.yaml
new file mode 100644
index 00000000..4fdf8068
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/reverb_noise_aug.yaml
@@ -0,0 +1,35 @@
+reverb_aug:
+ reverb_prob: 0.45
+ max_reverb_context: 0.5
+ rir_types:
+ smallroom:
+ weight: 1
+ rir_path: scp:data/rirs_smallroom/rirs.scp
+ rir_norm: max
+ mediumroom:
+ weight: 1
+ rir_path: scp:data/rirs_mediumroom/rirs.scp
+ rir_norm: max
+ realroom:
+ weight: 1
+ rir_path: scp:data/rirs_real/rirs.scp
+ rir_norm: max
+noise_aug:
+ noise_prob: 0.7
+ noise_types:
+ noise:
+ weight: 1
+ noise_path: data/musan_noise_proc_audio/wav.scp
+ min_snr: 0
+ max_snr: 18
+ music:
+ weight: 1
+ noise_path: data/musan_music_proc_audio/wav.scp
+ min_snr: 3
+ max_snr: 18
+ babble:
+ weight: 1
+ noise_path: data/musan_speech_babble/wav.scp
+ min_snr: 3
+ max_snr: 18
+
diff --git a/egs/voxceleb/dinossl.v1/conf/spinenet49.yaml b/egs/voxceleb/dinossl.v1/conf/spinenet49.yaml
new file mode 100644
index 00000000..66b8d517
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/spinenet49.yaml
@@ -0,0 +1,11 @@
+spinenet_type: spinenet49
+in_channels: 1
+in_feats: 80
+in_kernel_size: 3
+in_stride: 1
+no_maxpool: true
+embed_dim: 256
+cos_scale: 30.0
+margin: 0.3
+margin_warmup_epochs: 20.0
+dropout_rate: 0.0
diff --git a/egs/voxceleb/dinossl.v1/conf/train_data_default.yaml b/egs/voxceleb/dinossl.v1/conf/train_data_default.yaml
new file mode 100644
index 00000000..451ffa35
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/train_data_default.yaml
@@ -0,0 +1,10 @@
+dataset:
+ max_chunk_length: 4.0
+ min_chunk_length: 4.0
+ aug_cfg: conf/reverb_noise_aug.yaml
+sampler:
+ batch_size: 32
+ iters_per_epoch: 6
+data_loader:
+ num_workers: 8
+
\ No newline at end of file
diff --git a/egs/voxceleb/dinossl.v1/conf/train_ecapatdnn_xvec_default.yaml b/egs/voxceleb/dinossl.v1/conf/train_ecapatdnn_xvec_default.yaml
new file mode 100644
index 00000000..46298946
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/train_ecapatdnn_xvec_default.yaml
@@ -0,0 +1,7 @@
+data:
+ train: train_data_default.yaml
+ val: val_data_default.yaml
+feats: fbank80_stmn_16k.yaml
+model: ecapatdnn_small.yaml
+trainer: trainer_default.yaml
+
\ No newline at end of file
diff --git a/egs/voxceleb/dinossl.v1/conf/train_effnetb4_xvec_default.yaml b/egs/voxceleb/dinossl.v1/conf/train_effnetb4_xvec_default.yaml
new file mode 100644
index 00000000..1bc74de6
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/train_effnetb4_xvec_default.yaml
@@ -0,0 +1,7 @@
+data:
+ train: train_data_default.yaml
+ val: val_data_default.yaml
+feats: fbank80_stmn_16k.yaml
+model: efficientnet_b4.yaml
+trainer: trainer_default.yaml
+
\ No newline at end of file
diff --git a/egs/voxceleb/dinossl.v1/conf/train_res2net50_xvec_default.yaml b/egs/voxceleb/dinossl.v1/conf/train_res2net50_xvec_default.yaml
new file mode 100644
index 00000000..1d387790
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/train_res2net50_xvec_default.yaml
@@ -0,0 +1,7 @@
+data:
+ train: train_data_default.yaml
+ val: val_data_default.yaml
+feats: fbank80_stmn_16k.yaml
+model: resnet34.yaml
+trainer: trainer_default.yaml
+
\ No newline at end of file
diff --git a/egs/voxceleb/dinossl.v1/conf/train_resnet34_xvec_default.yaml b/egs/voxceleb/dinossl.v1/conf/train_resnet34_xvec_default.yaml
new file mode 100644
index 00000000..1d387790
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/train_resnet34_xvec_default.yaml
@@ -0,0 +1,7 @@
+data:
+ train: train_data_default.yaml
+ val: val_data_default.yaml
+feats: fbank80_stmn_16k.yaml
+model: resnet34.yaml
+trainer: trainer_default.yaml
+
\ No newline at end of file
diff --git a/egs/voxceleb/dinossl.v1/conf/train_spinenet49_xvec_default.yaml b/egs/voxceleb/dinossl.v1/conf/train_spinenet49_xvec_default.yaml
new file mode 100644
index 00000000..07167987
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/train_spinenet49_xvec_default.yaml
@@ -0,0 +1,7 @@
+data:
+ train: train_data_default.yaml
+ val: val_data_default.yaml
+feats: fbank80_stmn_16k.yaml
+model: spinenet49.yaml
+trainer: trainer_default.yaml
+
\ No newline at end of file
diff --git a/egs/voxceleb/dinossl.v1/conf/trainer_default.yaml b/egs/voxceleb/dinossl.v1/conf/trainer_default.yaml
new file mode 100644
index 00000000..86dcc2e4
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/trainer_default.yaml
@@ -0,0 +1,6 @@
+optim: optim_adam_default.yaml
+lrsched: lrsched_exp_default.yaml
+use_amp: true
+log_interval: 1000
+epochs: 70
+eff_batch_size: 512
diff --git a/egs/voxceleb/dinossl.v1/conf/trainer_swa_default.yaml b/egs/voxceleb/dinossl.v1/conf/trainer_swa_default.yaml
new file mode 100644
index 00000000..0cafad01
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/trainer_swa_default.yaml
@@ -0,0 +1,9 @@
+optim: optim_adam_default.yaml
+lrsched: lrsched_exp_default.yaml
+use_amp: true
+log_interval: 1000
+epochs: 80
+eff_batch_size: 512
+swa_start: 60
+swa_lr: 1e-3
+swa_anneal_epochs: 5
diff --git a/egs/voxceleb/dinossl.v1/conf/vad_16k.yaml b/egs/voxceleb/dinossl.v1/conf/vad_16k.yaml
new file mode 100644
index 00000000..5fb0111c
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/vad_16k.yaml
@@ -0,0 +1,8 @@
+sample_frequency: 16000
+frame_shift: 10
+frame_length: 25
+snip_edges: false
+vad_energy_threshold: 5.5
+vad_energy_mean_scale: 0.5
+vad_proportion_threshold: 0.12
+vad_frames_context: 2
diff --git a/egs/voxceleb/dinossl.v1/conf/vad_8k.yaml b/egs/voxceleb/dinossl.v1/conf/vad_8k.yaml
new file mode 100644
index 00000000..7592c9d1
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/vad_8k.yaml
@@ -0,0 +1,8 @@
+sample_frequency: 8000
+frame_shift: 10
+frame_length: 25
+snip_edges: false
+vad_energy_threshold: 5.5
+vad_energy_mean_scale: 0.5
+vad_proportion_threshold: 0.12
+vad_frames_context: 2
diff --git a/egs/voxceleb/dinossl.v1/conf/val_data_default.yaml b/egs/voxceleb/dinossl.v1/conf/val_data_default.yaml
new file mode 100644
index 00000000..451ffa35
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/conf/val_data_default.yaml
@@ -0,0 +1,10 @@
+dataset:
+ max_chunk_length: 4.0
+ min_chunk_length: 4.0
+ aug_cfg: conf/reverb_noise_aug.yaml
+sampler:
+ batch_size: 32
+ iters_per_epoch: 6
+data_loader:
+ num_workers: 8
+
\ No newline at end of file
diff --git a/egs/voxceleb/dinossl.v1/datapath.sh b/egs/voxceleb/dinossl.v1/datapath.sh
new file mode 100644
index 00000000..a7d277f4
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/datapath.sh
@@ -0,0 +1,22 @@
+# Copyright
+# 2018 Johns Hopkins University (Author: Jesus Villalba)
+#
+# Paths to the databases used in the experiment
+
+
+if [ "$(hostname --domain)" == "clsp.jhu.edu" ];then
+ voxceleb1_root=/export/corpora5/VoxCeleb1_v1 #voxceleb1 v1
+ # voxceleb1_root=/export/corpora5/VoxCeleb1_v2 #voxceleb1 v2
+ voxceleb2_root=/export/corpora5/VoxCeleb2
+ musan_root=/export/corpora5/JHU/musan
+elif [ "$(hostname --domain)" == "cm.gemini" ];then
+ # voxceleb1_root=/expscratch/dsnyder/VoxCeleb1 #voxceleb1 v1
+ voxceleb1_root=/exp/jvillalba/corpora/voxceleb1 #voxceleb1 v2
+ voxceleb2_root=/expscratch/dgromero/corpora-open/vox2
+ musan_root=/expscratch/dgromero/corpora-open/musan
+else
+ echo "Put your database paths here"
+ exit 1
+fi
+
+
diff --git a/egs/voxceleb/dinossl.v1/default_config.sh b/egs/voxceleb/dinossl.v1/default_config.sh
new file mode 120000
index 00000000..a4876a2e
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/default_config.sh
@@ -0,0 +1 @@
+global_conf/dinossl_tuning/config_fbank80_stmn_lresnet34_e256_do0_b96_amp.dinossl.v1.sh
\ No newline at end of file
diff --git a/egs/voxceleb/dinossl.v1/global_conf/dinossl_tuning/config_fbank80_stmn_lresnet34_e256_do0_b96_amp.dinossl.v1.sh b/egs/voxceleb/dinossl.v1/global_conf/dinossl_tuning/config_fbank80_stmn_lresnet34_e256_do0_b96_amp.dinossl.v1.sh
new file mode 100644
index 00000000..6c86e285
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/global_conf/dinossl_tuning/config_fbank80_stmn_lresnet34_e256_do0_b96_amp.dinossl.v1.sh
@@ -0,0 +1,63 @@
+# LResNet34 x-vector with mixed precision training
+# Some variables defined here are only used to compose nnet_name, which would be better to fix
+
+nnet_name_tag="" # used to distinguish expdir file names, e.g., when running multiple experiments for hyperparameter tuning
+
+# acoustic features
+feat_config=conf/fbank80_stmn_16k.yaml
+feat_type=fbank80_stmn # only used to compose nnet_name. When changing this, also fix the corresponding part in ${xvec_train_base_cfg} so it is applied in the actual training setup
+
+#vad
+vad_config=conf/vad_16k.yaml
+
+# x-vector training
+nnet_data=voxceleb2_train
+
+# x-vector cfg
+
+nnet_type=resnet
+
+resnet_type=lresnet34
+batch_size_1gpu=48 # batch sizes from 52 to 64 (in steps of 4) ran out of GPU memory
+ngpu=2
+eff_batch_size=`expr $batch_size_1gpu \* $ngpu` # In dinossl, eff_batch_size equals ngpu * batch_size_1gpu since grad_acc_steps is always 1 in dinossl for now.
+# Thus, when eff_batch_size changes, instead of changing grad_acc_steps with a fixed lr, the lr (base_value in cosine_scheduler, to be exact) is scaled
+# linearly with eff_batch_size, with a default base value of 0.005 for eff_batch_size=256. For example, if eff_batch_size=128, the base value becomes 0.0025.
+# eff_batch_size is recalculated in the python scripts; the value here is only used to compose nnet_name below. When changing it, also fix the
+# corresponding part in ${xvec_train_base_cfg} so it is applied in the actual training setup.
+dropout=0 # only used to compose nnet_name. When changing this, also fix the corresponding part in ${xvec_train_base_cfg} so it is applied in the actual training setup
+embed_dim=256 # only used to compose nnet_name. When changing this, also fix the corresponding part in ${xvec_train_base_cfg} so it is applied in the actual training setup
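+
+# Illustration only (not used by the scripts): the linear lr scaling described above,
+# with a base value of 0.005 at eff_batch_size=256, amounts to
+#   scaled_lr = 0.005 * eff_batch_size / 256
+# e.g., eff_batch_size=96 (batch_size_1gpu=48 x ngpu=2) -> scaled_lr = 0.001875.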
+
+xvec_train_base_cfg=conf/dinossl_tuning/train_resnet34_xvec_default.yaml
+xvec_train_args="--data.train.sampler.batch-size $batch_size_1gpu --model.resnet-type $resnet_type"
+
+# dinossl related (in addition to ones defined in xvec_train_base_cfg): dataset/dataloader, model/loss
+## dino-head
+dinossl_out_dim=65536
+dinossl_use_bn_in_head=false
+dinossl_norm_last_layer=true
+## data-augmentation
+dinossl_local_crops_number=4
+## teacher temperature
+dinossl_warmup_teacher_temp=0.04
+dinossl_teacher_temp=0.04
+dinossl_warmup_teacher_temp_epochs=0
+## chunk sampling related
+dinossl_chunk_len_mult=2 # factor by which the long-chunk length is increased relative to the short-chunk length. The short-chunk length is determined randomly between the min_chunk and max_chunk values set above
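+# Illustration (assuming conf/dinossl_tuning/train_data_default.yaml, where min/max_chunk_length=2.0):
+# with dinossl_chunk_len_mult=2, short chunks are about 2 s and long chunks about 4 s.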
+
+nnet_name=${feat_type}_${resnet_type}_e${embed_dim}_do${dropout}_b${eff_batch_size}_amp.dinossl.v1
+if [[ -n ${nnet_name_tag} ]]; then
+ nnet_name=${nnet_name}_${nnet_name_tag}
+fi
+
+
+nnet_dir=exp/xvector_nnets/$nnet_name
+nnet=$nnet_dir/model_ep0070.pth
+
+
+# back-end
+state_dict_key=model_teacher_state_dict
+plda_num_augs=0
+plda_data=voxceleb1_test_train
+plda_type=splda
+lda_dim=200
+plda_y_dim=150
+plda_z_dim=200
+
+
diff --git a/egs/voxceleb/dinossl.v1/hyp_utils b/egs/voxceleb/dinossl.v1/hyp_utils
new file mode 120000
index 00000000..f6d1eb7a
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/hyp_utils
@@ -0,0 +1 @@
+../../../hyp_utils
\ No newline at end of file
diff --git a/egs/voxceleb/dinossl.v1/list_run.sh b/egs/voxceleb/dinossl.v1/list_run.sh
new file mode 100644
index 00000000..8fcdf759
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/list_run.sh
@@ -0,0 +1,11 @@
+# Why this directory is created: To run DINO-based SSL for utterance-level embedding learning
+
+# How this directory is created:
+## 1. Run below in the parent directory:
+## cp -r v1.1 dinossl.v1
+## 2. STOPGAP: To train a model in the CLSP grid W/O data prep. The data prep. part (before run_011*) will be updated after updating training-related codes first.
+## ln -s /export/c01/jcho/hyperion_DINO/egs/voxceleb/v2/dataa data (for dinossl)
+
+# (WIP) training script
+## (GOOD) bash run_511_train_xvector.sh --ngpu 1
+## (TODO: for multiple gpus)
diff --git a/egs/voxceleb/dinossl.v1/local b/egs/voxceleb/dinossl.v1/local
new file mode 120000
index 00000000..740b697d
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/local
@@ -0,0 +1 @@
+../v1/local/
\ No newline at end of file
diff --git a/egs/voxceleb/dinossl.v1/path.sh b/egs/voxceleb/dinossl.v1/path.sh
new file mode 100755
index 00000000..6994fdab
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/path.sh
@@ -0,0 +1,5 @@
+
+export HYP_ROOT=$(readlink -f `pwd -P`/../../..)
+export TOOLS_ROOT=$HYP_ROOT/tools
+
+. $TOOLS_ROOT/path.sh
diff --git a/egs/voxceleb/dinossl.v1/run_001_prepare_data.sh b/egs/voxceleb/dinossl.v1/run_001_prepare_data.sh
new file mode 100755
index 00000000..5ff5a9d6
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/run_001_prepare_data.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+# Copyright
+# 2018 Johns Hopkins University (Author: Jesus Villalba)
+# Apache 2.0.
+#
+. ./cmd.sh
+. ./path.sh
+set -e
+
+stage=1
+config_file=default_config.sh
+
+. parse_options.sh || exit 1;
+. datapath.sh
+
+
+if [ $stage -le 1 ];then
+ # Prepare the VoxCeleb2 dataset for training.
+ local/make_voxceleb2.pl $voxceleb2_root dev 16 data/voxceleb2_train
+fi
+
+if [ $stage -le 2 ];then
+ # prepare voxceleb1 for test
+  # This script is for the old version of the dataset and does NOT process LANG,
+  # GENDER, NAT
+ local/make_voxceleb1_oldversion_oeh.pl $voxceleb1_root data
+ # This script is for the old version of the dataset:
+ # local/make_voxceleb1_oeh.pl $voxceleb1_root data
+ # Use this for the newer version of voxceleb1:
+ # local/make_voxceleb1_v2_oeh.pl $voxceleb1_root data
+fi
diff --git a/egs/voxceleb/dinossl.v1/run_002_compute_evad.sh b/egs/voxceleb/dinossl.v1/run_002_compute_evad.sh
new file mode 100755
index 00000000..44153358
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/run_002_compute_evad.sh
@@ -0,0 +1,57 @@
+#!/bin/bash
+# Copyright
+# 2018 Johns Hopkins University (Author: Jesus Villalba)
+# Apache 2.0.
+#
+. ./cmd.sh
+. ./path.sh
+set -e
+nodes=fs01
+storage_name=$(date +'%m_%d_%H_%M')
+vaddir=`pwd`/exp/vad_e
+vad_config=conf/vad_16k.yaml
+
+stage=1
+config_file=default_config.sh
+
+. parse_options.sh || exit 1;
+. $config_file
+
+
+if [ $stage -le 1 ]; then
+ # Prepare to distribute data over multiple machines
+ if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $vaddir/storage ]; then
+ dir_name=$USER/hyp-data/voxceleb/dinossl.v1/$storage_name/vad/storage
+ if [ "$nodes" == "b0" ];then
+ utils/create_split_dir.pl \
+ /export/b{04,05,06,07}/$dir_name $vaddir/storage
+ elif [ "$nodes" == "b1" ];then
+ utils/create_split_dir.pl \
+ /export/b{14,15,16,17}/$dir_name $vaddir/storage
+ elif [ "$nodes" == "c0" ];then
+ utils/create_split_dir.pl \
+ /export/c{06,07,08,09}/$dir_name $vaddir/storage
+ elif [ "$nodes" == "fs01" ];then
+ utils/create_split_dir.pl \
+ /export/fs01/$dir_name $vaddir/storage
+ else
+ echo "we don't distribute data between multiple machines"
+ fi
+ fi
+fi
+
+#Train datasets
+if [ $stage -le 2 ];then
+ for name in voxceleb2_train voxceleb1_test
+ do
+ num_spk=$(wc -l data/$name/spk2utt | awk '{ print $1}')
+ nj=$(($num_spk < 40 ? $num_spk:40))
+ hyp_utils/feats/make_evad.sh --write-utt2num-frames true \
+ --vad-config $vad_config --nj $nj --cmd "$train_cmd" \
+ data/${name} exp/make_vad/$name $vaddir
+ utils/fix_data_dir.sh data/${name}
+ done
+fi
+
+
diff --git a/egs/voxceleb/dinossl.v1/run_003_prepare_noises_rirs.sh b/egs/voxceleb/dinossl.v1/run_003_prepare_noises_rirs.sh
new file mode 100755
index 00000000..4297f7fb
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/run_003_prepare_noises_rirs.sh
@@ -0,0 +1,65 @@
+#!/bin/bash
+# Copyright
+# 2020 Johns Hopkins University (Author: Jesus Villalba)
+# Apache 2.0.
+#
+. ./cmd.sh
+. ./path.sh
+set -e
+
+stage=1
+. parse_options.sh || exit 1;
+. datapath.sh
+
+# We prepare the noise files and RIR for online speech augmentation
+
+if [ $stage -le 1 ]; then
+
+ # Prepare the MUSAN corpus, which consists of music, speech, and noise
+ # suitable for augmentation.
+ local/make_musan.sh $musan_root 16 data
+
+ for name in musan_noise musan_music
+ do
+ steps_xvec/preprocess_audios_for_nnet_train.sh --nj 10 --cmd "$train_cmd" \
+ --storage_name voxceleb-dinossl.v1-$(date +'%m_%d_%H_%M') \
+ data/${name} data/${name}_proc_audio exp/${name}_proc_audio
+ utils/fix_data_dir.sh data/${name}_proc_audio
+ done
+
+fi
+
+if [ $stage -le 2 ]; then
+
+ # Create Babble noise from MUSAN speech files
+ for name in musan_speech
+ do
+ steps_xvec/make_babble_noise_for_nnet_train.sh --cmd "$train_cmd" \
+ --storage_name voxceleb-dinossl.v1-$(date +'%m_%d_%H_%M') \
+ data/${name} data/${name}_babble exp/${name}_babble
+ # utils/fix_data_dir.sh data/${name}_babble
+ done
+fi
+
+if [ $stage -le 3 ]; then
+ if [ ! -d "RIRS_NOISES" ]; then
+ if [ -d ../../sre19-cmn2/v1/RIRS_NOISES ];then
+ ln -s ../../sre19-cmn2/v1/RIRS_NOISES
+ else
+ # Download the package that includes the real RIRs, simulated RIRs, isotropic noises and point-source noises
+ wget --no-check-certificate http://www.openslr.org/resources/28/rirs_noises.zip
+ unzip rirs_noises.zip
+ fi
+ fi
+ local/make_rirs_data.sh RIRS_NOISES/simulated_rirs/smallroom 16 data/rirs_smallroom
+ local/make_rirs_data.sh RIRS_NOISES/simulated_rirs/mediumroom 16 data/rirs_mediumroom
+ local/make_rirs_data.sh RIRS_NOISES/real_rirs_isotropic_noises 16 data/rirs_real
+ for rirs in rirs_smallroom rirs_mediumroom rirs_real
+ do
+ #pack all rirs in h5 files
+ steps_xvec/pack_rirs_for_nnet_train.sh data/$rirs data/$rirs exp/rirs/$rirs
+ done
+
+fi
+
+
diff --git a/egs/voxceleb/dinossl.v1/run_010_prepare_xvec_train_data.sh b/egs/voxceleb/dinossl.v1/run_010_prepare_xvec_train_data.sh
new file mode 100755
index 00000000..fa643852
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/run_010_prepare_xvec_train_data.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+# Copyright
+# 2020 Johns Hopkins University (Author: Jesus Villalba)
+# Apache 2.0.
+#
+. ./cmd.sh
+. ./path.sh
+set -e
+
+stage=1
+config_file=default_config.sh
+
+. parse_options.sh || exit 1;
+. $config_file
+
+if [ $stage -le 2 ]; then
+ # This script preprocess audio for x-vector training
+ steps_xvec/preprocess_audios_for_nnet_train.sh --nj 40 --cmd "$train_cmd" \
+ --storage_name voxceleb-dinossl.v1-$(date +'%m_%d_%H_%M') --use-bin-vad true \
+ data/${nnet_data} data/${nnet_data}_proc_audio_no_sil exp/${nnet_data}_proc_audio_no_sil
+ hyp_utils/kaldi/utils/fix_data_dir.sh data/${nnet_data}_proc_audio_no_sil
+
+fi
+
+if [ $stage -le 3 ]; then
+ # Now, we remove files with less than 4s. This removes ~ 6.4% of the
+ # number of the original samples for voxceleb2_train.
+ hyp_utils/remove_short_audios.sh --min-len 4 data/${nnet_data}_proc_audio_no_sil
+fi
+
+if [ $stage -le 4 ]; then
+ # Prepare train and validation lists for x-vectors. JJ: This might use
+ # speaker labels but validation list won't be used in self-supervised
+ # learning. (In the future, it may be used w/o labels for better validation)
+ local/make_train_lists_sup_embed_with_augm.sh \
+ data/${nnet_data}_proc_audio_no_sil \
+ data/${nnet_data}_proc_audio_no_sil/lists_xvec
+fi
+
+exit
diff --git a/egs/voxceleb/dinossl.v1/run_011_train_xvector.sh b/egs/voxceleb/dinossl.v1/run_011_train_xvector.sh
new file mode 100755
index 00000000..17d50722
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/run_011_train_xvector.sh
@@ -0,0 +1,59 @@
+#!/bin/bash
+# Copyright
+# 2019 Johns Hopkins University (Author: Jesus Villalba)
+# Apache 2.0.
+#
+. ./cmd.sh
+. ./path.sh
+set -e
+
+stage=1
+ngpu=4
+config_file=default_config.sh
+interactive=false
+num_workers=""
+use_tb=false
+use_wandb=false
+
+. parse_options.sh || exit 1;
+. $config_file
+. datapath.sh
+
+list_dir=data/${nnet_data}_proc_audio_no_sil
+
+#add extra args from the command line arguments
+if [ -n "$num_workers" ];then
+ extra_args="--data.train.data_loader.num-workers $num_workers"
+fi
+if [ "$use_tb" == "true" ];then
+ extra_args="$extra_args --trainer.use-tensorboard"
+fi
+if [ "$use_wandb" == "true" ];then
+ extra_args="$extra_args --trainer.use-wandb --trainer.wandb.project voxceleb-v1.1 --trainer.wandb.name $nnet_name.$(date -Iminutes)"
+fi
+
+if [ "$interactive" == "true" ];then
+ export cuda_cmd=run.pl
+fi
+
+# Network Training
+if [ $stage -le 1 ]; then
+
+
+ mkdir -p $nnet_dir/log
+ $cuda_cmd \
+ --gpu $ngpu $nnet_dir/log/train.log \
+ hyp_utils/conda_env.sh --conda-env $HYP_ENV --num-gpus $ngpu \
+ train_xvector_from_wav.py $nnet_type --cfg $xvec_train_base_cfg $xvec_train_args $extra_args \
+ --data.train.dataset.audio-file $list_dir/wav.scp \
+ --data.train.dataset.time-durs-file $list_dir/utt2dur \
+ --data.train.dataset.key-file $list_dir/lists_xvec/train.scp \
+ --data.train.dataset.class-file $list_dir/lists_xvec/class2int \
+ --data.val.dataset.audio-file $list_dir/wav.scp \
+ --data.val.dataset.time-durs-file $list_dir/utt2dur \
+ --data.val.dataset.key-file $list_dir/lists_xvec/val.scp \
+ --trainer.exp-path $nnet_dir $args \
+ --num-gpus $ngpu \
+
+fi
+
diff --git a/egs/voxceleb/dinossl.v1/run_030_extract_xvectors.sh b/egs/voxceleb/dinossl.v1/run_030_extract_xvectors.sh
new file mode 100755
index 00000000..3abf2ff6
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/run_030_extract_xvectors.sh
@@ -0,0 +1,61 @@
+#!/bin/bash
+# Copyright
+# 2020 Johns Hopkins University (Author: Jesus Villalba)
+# Apache 2.0.
+#
+. ./cmd.sh
+. ./path.sh
+set -e
+
+stage=1
+config_file=default_config.sh
+use_gpu=false
+xvec_chunk_length=12800
+. parse_options.sh || exit 1;
+. $config_file
+
+if [ "$use_gpu" == "true" ];then
+ xvec_args="--use-gpu true --chunk-length $xvec_chunk_length"
+ xvec_cmd="$cuda_eval_cmd --mem 4G"
+else
+ xvec_cmd="$train_cmd --mem 12G"
+fi
+
+xvector_dir=exp/xvectors/$nnet_name
+
+if [ $stage -le 1 ]; then
+ # Extract xvectors for training LDA/PLDA
+ for name in voxceleb2cat_train
+ do
+ if [ $plda_num_augs -eq 0 ]; then
+ steps_xvec/extract_xvectors_from_wav.sh --cmd "$xvec_cmd" --nj 100 ${xvec_args} \
+ --random-utt-length true --min-utt-length 400 --max-utt-length 14000 \
+ --feat-config $feat_config \
+ $nnet data/${name} \
+ $xvector_dir/${name}
+ else
+ steps_xvec/extract_xvectors_from_wav.sh --cmd "$xvec_cmd" --nj 300 ${xvec_args} \
+ --random-utt-length true --min-utt-length 400 --max-utt-length 14000 \
+ --feat-config $feat_config --aug-config $plda_aug_config --num-augs $plda_num_augs \
+ $nnet data/${name} \
+ $xvector_dir/${name}_augx${plda_num_augs} \
+ data/${name}_augx${plda_num_augs}
+ fi
+ done
+fi
+
+
+if [ $stage -le 2 ]; then
+ # Extracts x-vectors for evaluation
+ for name in voxceleb1_test
+ do
+ num_spk=$(wc -l data/$name/spk2utt | awk '{ print $1}')
+ nj=$(($num_spk < 100 ? $num_spk:100))
+ steps_xvec/extract_xvectors_from_wav.sh --cmd "$xvec_cmd --mem 6G" --nj $nj ${xvec_args} \
+ --feat-config $feat_config \
+ $nnet data/$name \
+ $xvector_dir/$name
+ done
+fi
+
+exit
diff --git a/egs/voxceleb/dinossl.v1/run_040_eval_be.sh b/egs/voxceleb/dinossl.v1/run_040_eval_be.sh
new file mode 100755
index 00000000..cd168180
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/run_040_eval_be.sh
@@ -0,0 +1,125 @@
+#!/bin/bash
+# Copyright 2018 Johns Hopkins University (Author: Jesus Villalba)
+#
+# Apache 2.0.
+#
+. ./cmd.sh
+. ./path.sh
+set -e
+
+stage=1
+config_file=default_config.sh
+
+. parse_options.sh || exit 1;
+. $config_file
+. datapath.sh
+
+plda_label=${plda_type}y${plda_y_dim}_v1
+be_name=lda${lda_dim}_${plda_label}_${plda_data}
+
+xvector_dir=exp/xvectors/$nnet_name
+be_dir=exp/be/$nnet_name/$be_name
+score_dir=exp/scores/$nnet_name/${be_name}
+score_plda_dir=$score_dir/plda
+score_cosine_dir=exp/scores/$nnet_name/cosine
+
+if [ $stage -le 1 ]; then
+
+ echo "Train PLDA on Voxceleb2"
+ steps_be/train_be_v1.sh --cmd "$train_cmd" \
+ --lda_dim $lda_dim \
+ --plda_type $plda_type \
+ --y_dim $plda_y_dim --z_dim $plda_z_dim \
+ $xvector_dir/$plda_data/xvector.scp \
+ data/$plda_data \
+ $be_dir &
+
+
+ wait
+
+fi
+
+
+if [ $stage -le 2 ];then
+
+ echo "Eval Voxceleb 1 with LDA+CentWhiten+LNorm+PLDA"
+ steps_be/eval_be_v1.sh --cmd "$train_cmd" --plda_type $plda_type \
+ data/voxceleb1_test/trials \
+ data/voxceleb1_test/utt2model \
+ $xvector_dir/voxceleb1_test/xvector.scp \
+ $be_dir/lda_lnorm.h5 \
+ $be_dir/plda.h5 \
+ $score_plda_dir/voxceleb1_scores
+
+ $train_cmd --mem 10G --num-threads 6 $score_plda_dir/log/score_voxceleb1.log \
+ local/score_voxceleb1.sh data/voxceleb1_test $score_plda_dir
+
+ for f in $(ls $score_plda_dir/*_results);
+ do
+ echo $f
+ cat $f
+ echo ""
+ done
+
+fi
+
+
+score_plda_dir=$score_cosine_dir
+
+if [ $stage -le 3 ];then
+
+ echo "Eval Voxceleb 1 with Cosine scoring"
+ steps_be/eval_be_cos.sh --cmd "$train_cmd" \
+ data/voxceleb1_test/trials \
+ data/voxceleb1_test/utt2model \
+ $xvector_dir/voxceleb1_test/xvector.scp \
+ $score_plda_dir/voxceleb1_scores
+
+ $train_cmd --mem 10G --num-threads 6 $score_plda_dir/log/score_voxceleb1.log \
+ local/score_voxceleb1.sh data/voxceleb1_test $score_plda_dir
+
+ for f in $(ls $score_plda_dir/*_results);
+ do
+ echo $f
+ cat $f
+ echo ""
+ done
+
+fi
+
+be_dir=exp/be/$nnet_name/cw
+score_plda_dir=$score_dir/cw_cosine
+
+if [ $stage -le 4 ]; then
+ echo "Train centering+whitening on Voxceleb2"
+ steps_be/train_be_v2.sh --cmd "$train_cmd" \
+ $xvector_dir/$plda_data/xvector.scp \
+ data/$plda_data \
+ $be_dir
+fi
+
+
+if [ $stage -le 5 ];then
+
+ echo "Eval Voxceleb 1 with CentWhiten + Cosine scoring"
+ steps_be/eval_be_v2.sh --cmd "$train_cmd" \
+ data/voxceleb1_test/trials \
+ data/voxceleb1_test/utt2model \
+ $xvector_dir/voxceleb1_test/xvector.scp \
+ $be_dir/cw.h5 \
+ $score_plda_dir/voxceleb1_scores
+
+ $train_cmd --mem 10G --num-threads 6 $score_plda_dir/log/score_voxceleb1.log \
+ local/score_voxceleb1.sh data/voxceleb1_test $score_plda_dir
+
+ for f in $(ls $score_plda_dir/*_results);
+ do
+ echo $f
+ cat $f
+ echo ""
+ done
+
+fi
+
+exit
+
diff --git a/egs/voxceleb/dinossl.v1/run_511_train_xvector.sh b/egs/voxceleb/dinossl.v1/run_511_train_xvector.sh
new file mode 100755
index 00000000..19a37a75
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/run_511_train_xvector.sh
@@ -0,0 +1,85 @@
+#!/bin/bash
+# Copyright
+# 2019 Johns Hopkins University (Author: Jesus Villalba)
+# Apache 2.0.
+#
+. ./cmd.sh
+. ./path.sh
+set -e
+
+stage=1
+ngpu=1
+config_file=default_config.sh
+interactive=false
+num_workers=""
+use_tb=false
+use_wandb=false
+
+. parse_options.sh || exit 1;
+. $config_file
+. datapath.sh
+
+list_dir=data/${nnet_data}_proc_audio_no_sil
+
+#add extra args from the command line arguments
+if [ -n "$num_workers" ];then
+ extra_args="--data.train.data_loader.num-workers $num_workers"
+fi
+if [ "$use_tb" == "true" ];then
+ extra_args="$extra_args --trainer.use-tensorboard"
+fi
+if [ "$use_wandb" == "true" ];then
+ extra_args="$extra_args --trainer.use-wandb --trainer.wandb.project voxceleb-dinossl.v1 --trainer.wandb.name $nnet_name.$(date -Iminutes)"
+fi
+
+if [ "$interactive" == "true" ];then
+ export cuda_cmd=run.pl
+fi
+
+# Network Training
+if [ $stage -le 1 ]; then
+ # dino arguments
+ dinossl_args="--dinossl true "
+ for arg in dinossl_nlayers dinossl_out_dim dinossl_use_bn_in_head dinossl_norm_last_layer \
+ dinossl_local_crops_number dinossl_warmup_teacher_temp dinossl_teacher_temp \
+ dinossl_warmup_teacher_temp_epochs dinossl_chunk_len_mult dinossl_reduce_overlap_prob; do
+ if [ ! -z ${!arg} ]; then
+      dinossl_args+="--${arg} ${!arg} " # ${!arg} expands to the value of the variable whose name is stored in ${arg}
+ fi
+ done
+ echo "Dino arguments: ${dinossl_args}"
+
+  # Derive utt2utt.scp and utt2int lists from train.scp/val.scp so that batching ignores
+  # class balancing (to simulate an unsupervised scenario): simply make class_idx == utt_idx.
+ # train.utt2utt.scp
+ if [ ! -s ${list_dir}/lists_xvec/train.utt2utt.scp ]; then
+ awk '{print $1" "$1}' ${list_dir}/lists_xvec/train.scp > ${list_dir}/lists_xvec/train.utt2utt.scp
+ fi
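+  # For example, a (hypothetical) train.scp key "utt001" becomes the line "utt001 utt001"
+  # in train.utt2utt.scp, i.e., every utterance is treated as its own class.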
+  # val.utt2utt.scp (generated for completeness; it is not used in the end)
+ if [ ! -s ${list_dir}/lists_xvec/val.utt2utt.scp ]; then
+ awk '{print $1" "$1}' ${list_dir}/lists_xvec/val.scp > ${list_dir}/lists_xvec/val.utt2utt.scp
+ fi
+ # utt2int
+ if [ ! -s ${list_dir}/lists_xvec/utt2int ]; then
+ cat <(awk '{print $1}' ${list_dir}/lists_xvec/train.scp) <(awk '{print $1}' ${list_dir}/lists_xvec/val.scp) > ${list_dir}/lists_xvec/utt2int
+ fi
+
+
+
+ mkdir -p $nnet_dir/log
+ $cuda_cmd \
+ --gpu $ngpu $nnet_dir/log/train.log \
+ hyp_utils/conda_env.sh --conda-env $HYP_ENV --num-gpus $ngpu \
+ train_xvector_from_wav_dinossl.py $nnet_type --cfg $xvec_train_base_cfg $xvec_train_args $extra_args ${dinossl_args} \
+ --data.train.dataset.audio-file $list_dir/wav.scp \
+ --data.train.dataset.time-durs-file $list_dir/utt2dur \
+ --data.train.dataset.key-file $list_dir/lists_xvec/train.utt2utt.scp \
+ --data.train.dataset.class-file $list_dir/lists_xvec/utt2int \
+ --data.val.dataset.audio-file $list_dir/wav.scp \
+ --data.val.dataset.time-durs-file $list_dir/utt2dur \
+ --data.val.dataset.key-file $list_dir/lists_xvec/val.utt2utt.scp \
+ --trainer.exp-path $nnet_dir $args \
+    --num-gpus $ngpu
+
+fi
+
diff --git a/egs/voxceleb/dinossl.v1/steps b/egs/voxceleb/dinossl.v1/steps
new file mode 120000
index 00000000..aede39fe
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/steps
@@ -0,0 +1 @@
+hyp_utils/kaldi/steps
\ No newline at end of file
diff --git a/egs/voxceleb/dinossl.v1/steps_be b/egs/voxceleb/dinossl.v1/steps_be
new file mode 120000
index 00000000..b2098c2a
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/steps_be
@@ -0,0 +1 @@
+../v1/steps_be
\ No newline at end of file
diff --git a/egs/voxceleb/dinossl.v1/steps_fe b/egs/voxceleb/dinossl.v1/steps_fe
new file mode 120000
index 00000000..73ccc1eb
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/steps_fe
@@ -0,0 +1 @@
+hyp_utils/kaldi/vad
\ No newline at end of file
diff --git a/egs/voxceleb/dinossl.v1/steps_pyfe b/egs/voxceleb/dinossl.v1/steps_pyfe
new file mode 120000
index 00000000..7b9d122a
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/steps_pyfe
@@ -0,0 +1 @@
+hyp_utils/feats
\ No newline at end of file
diff --git a/egs/voxceleb/dinossl.v1/steps_xvec b/egs/voxceleb/dinossl.v1/steps_xvec
new file mode 120000
index 00000000..af66a94d
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/steps_xvec
@@ -0,0 +1 @@
+hyp_utils/xvectors
\ No newline at end of file
diff --git a/egs/voxceleb/dinossl.v1/utils b/egs/voxceleb/dinossl.v1/utils
new file mode 120000
index 00000000..3d590a1d
--- /dev/null
+++ b/egs/voxceleb/dinossl.v1/utils
@@ -0,0 +1 @@
+hyp_utils/kaldi/utils
\ No newline at end of file
diff --git a/egs/voxceleb/v1/local/make_voxceleb1_oldversion_oeh.pl b/egs/voxceleb/v1/local/make_voxceleb1_oldversion_oeh.pl
new file mode 100755
index 00000000..f6eb8f35
--- /dev/null
+++ b/egs/voxceleb/v1/local/make_voxceleb1_oldversion_oeh.pl
@@ -0,0 +1,130 @@
+#!/usr/bin/perl
+# Note: this is an old version of the script from commit 1eb12f2ed01801a50c3f7ba014809bf7c7212f28. It does NOT process the LANG, GENDER, or NATIONALITY metadata.
+# Copyright 2018 Ewald Enzinger
+# 2018 David Snyder
+# 2020 Jesus Villalba
+#
+# Usage: make_voxceleb1.pl /export/voxceleb1 data/
+# Create trial lists for Voxceleb1 original, Entire (E) and hard (H),
+# with cleaned and non-cleaned versions
+
+if (@ARGV != 2) {
+ print STDERR "Usage: $0 \n";
+ print STDERR "e.g. $0 /export/voxceleb1 data/\n";
+ exit(1);
+}
+
+($data_base, $out_dir) = @ARGV;
+my $out_dir = "$out_dir/voxceleb1_test";
+
+if (system("mkdir -p $out_dir") != 0) {
+ die "Error making directory $out_dir";
+}
+
+my $url_base="http://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta";
+my @trials_basename = ("very_test.txt", "very_test2.txt", "list_test_hard.txt", "list_test_hard2.txt", "list_test_all.txt", "list_test_all2.txt");
+my @trials_url = ("$url_base/veri_test.txt", "$url_base/veri_test2.txt", "$url_base/list_test_hard.txt", "$url_base/list_test_hard2.txt", "$url_base/list_test_all.txt", "$url_base/list_test_all2.txt");
+my @trials = ("trials_o", "trials_o_clean", "trials_h", "trials_h_clean", "trials_e", "trials_e_clean");
+
+open(META_IN, "<", "$data_base/vox1_meta.csv") or die "Could not open the meta data file $data_base/vox1_meta.csv";
+my %id2spkr = ();
+while (<META_IN>) {
+ chomp;
+ my ($vox_id, $spkr_id, $gender, $nation, $set) = split;
+ $id2spkr{$vox_id} = $spkr_id;
+
+}
+close(META_IN) or die;
+
+#download trials from voxceleb web page
+for($i = 0; $i <= $#trials; $i++) {
+
+ my $file_i = "$out_dir/$trials_basename[$i]";
+ my $url_i = $trials_url[$i];
+ my $trial_i = "$out_dir/$trials[$i]";
+ if (! -e $file_i) {
+ system("wget -O $file_i $url_i");
+ }
+ #mapping from new speaker ids and file-names to old ones
+ open(TRIAL_IN, "<", "$file_i") or die "Could not open the verification trials file $file_i";
+ open(TRIAL_OUT, ">", "$trial_i") or die "Could not open the output file $trial_i";
+  while (<TRIAL_IN>) {
+ chomp;
+ my ($tar_or_non, $path1, $path2) = split;
+
+ # Create entry for left-hand side of trial
+ my ($vox_id, $rec_id, $segment) = split('/', $path1);
+ $segment =~ s/\.wav$//;
+ my $spkr_id = $id2spkr{$vox_id};
+ my $utt_id1 = "$spkr_id-$rec_id-00$segment";
+
+ # Create entry for right-hand side of trial
+ my ($vox_id, $rec_id, $segment) = split('/', $path2);
+ $segment =~ s/\.wav$//;
+ my $spkr_id = $id2spkr{$vox_id};
+ my $utt_id2 = "$spkr_id-$rec_id-00$segment";
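+    # Illustration with a hypothetical trial line "1 id10270/abc123/00001.wav id10270/def456/00002.wav":
+    # both sides are rewritten as "<old-spkr-id>-<rec-id>-00<segment>", e.g. "<old-spkr-id>-abc123-0000001",
+    # where the old speaker id comes from the vox1_meta.csv mapping.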
+
+ my $target = "nontarget";
+ if ($tar_or_non eq "1") {
+ $target = "target";
+ }
+ print TRIAL_OUT "$utt_id1 $utt_id2 $target\n";
+ }
+
+ close(TRIAL_IN) or die;
+ close(TRIAL_OUT) or die;
+
+}
+
+
+opendir my $dh, "$data_base/voxceleb1_wav" or die "Cannot open directory: $!";
+my @spkr_dirs = grep {-d "$data_base/voxceleb1_wav/$_" && ! /^\.{1,2}$/} readdir($dh);
+closedir $dh;
+
+open(SPKR_TEST, ">", "$out_dir/utt2spk") or die "Could not open the output file $out_dir/utt2spk";
+open(WAV_TEST, ">", "$out_dir/wav.scp") or die "Could not open the output file $out_dir/wav.scp";
+
+foreach (@spkr_dirs) {
+ my $spkr_id = $_;
+ my $new_spkr_id = $spkr_id;
+ # If we're using a newer version of VoxCeleb1, we need to "deanonymize"
+ # the speaker labels.
+ if (exists $id2spkr{$spkr_id}) {
+ $new_spkr_id = $id2spkr{$spkr_id};
+ }
+ opendir my $dh, "$data_base/voxceleb1_wav/$spkr_id/" or die "Cannot open directory: $!";
+ my @files = map{s/\.[^.]+$//;$_}grep {/\.wav$/} readdir($dh);
+ closedir $dh;
+ foreach (@files) {
+ my $filename = $_;
+ my $rec_id = substr($filename, 0, 11);
+ my $segment = substr($filename, 12, 7);
+ my $wav = "$data_base/voxceleb1_wav/$spkr_id/$filename.wav";
+ my $utt_id = "$new_spkr_id-$rec_id-$segment";
+ print WAV_TEST "$utt_id", " $wav", "\n";
+ print SPKR_TEST "$utt_id", " $new_spkr_id", "\n";
+ }
+}
+
+close(SPKR_TEST) or die;
+close(WAV_TEST) or die;
+
+if (system(
+ "cat $out_dir/trials_* | sort -u > $out_dir/trials") != 0) {
+ die "Error creating trials file in directory $out_dir";
+}
+
+if (system(
+ "awk '{ print \$1,\$1 }' $out_dir/trials | sort -u > $out_dir/utt2model") != 0) {
+ die "Error creating utt2model file in directory $out_dir";
+}
+
+if (system(
+ "utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) {
+ die "Error creating spk2utt file in directory $out_dir";
+}
+system("env LC_COLLATE=C utils/fix_data_dir.sh $out_dir");
+if (system("env LC_COLLATE=C utils/validate_data_dir.sh --no-text --no-feats $out_dir") != 0) {
+ die "Error validating directory $out_dir";
+}
+
diff --git a/egs/voxceleb/v1/local/make_voxceleb2.pl b/egs/voxceleb/v1/local/make_voxceleb2.pl
index e0ebeb0f..b2bd0d71 100755
--- a/egs/voxceleb/v1/local/make_voxceleb2.pl
+++ b/egs/voxceleb/v1/local/make_voxceleb2.pl
@@ -1,5 +1,6 @@
#!/usr/bin/perl
-#
+# Note: Compared to local/make_voxceleb2cat.pl: 1) this does NOT concatenate recording turns of the same speaker in a conversation,
+# and 2) it skips the part that gets the LANG and GENDER metadata.
# Copyright 2018 Johns Hopkins University (Jesus Villalba)
# Copyright 2018 Ewald Enzinger
#
@@ -32,59 +33,20 @@
$dataset_path = "$data_base/$dataset"
}
+opendir my $dh, "$dataset_path" or die "Cannot open directory: $!";
+my @spkr_dirs = grep {-d "$dataset_path/$_" && ! /^\.{1,2}$/} readdir($dh);
+closedir $dh;
if (system("mkdir -p $out_dir") != 0) {
die "Error making directory $out_dir";
}
-
-my $meta_url = "https://www.openslr.org/resources/49/vox2_meta.csv";
-my $meta_path = "$data_base/vox2_meta.csv";
-if (! -e "$meta_path") {
- $meta_path = "$out_dir/vox2_meta.csv";
- system("wget -O $meta_path $meta_url");
-}
-open(META_IN, "<", "$meta_path") or die "Could not open the output file $meta_path";
-my %spkr2gender = ();
-while (<META_IN>) {
- chomp;
- my ($spkr, $vox_id, $vgg_id, $gender, $set) = split;
- $spkr2gender{$vox_id} = $gender;
-}
-close(META_IN) or die;
-
-my $lid_url = "https://www.robots.ox.ac.uk/~vgg/data/voxceleb/data_workshop_2021/lang_vox2_final.csv";
-my $lid_path = "$data_base/lang_vox2_final.csv";
-if (! -e "$lid_path") {
- $lid_path = "$out_dir/lang_vox2_final.csv";
- system("wget -O $lid_path $lid_url");
-}
-open(LID_IN, "<", "$lid_path") or die "Could not open the output file $lid_path";
-my %utt2lang = ();
-while (<LID_IN>) {
- chomp;
- my ($utt_id, $lang, $score) = split ',';
- $utt_id =~ s@/@-@g;
- $utt_id =~ s@\.wav$@@;
- $utt2lang{$utt_id} = $lang;
-}
-close(LID_IN) or die;
-
-
open(SPKR, ">", "$out_dir/utt2spk") or die "Could not open the output file $out_dir/utt2spk";
open(WAV, ">", "$out_dir/wav.scp") or die "Could not open the output file $out_dir/wav.scp";
-open(LANG, ">", "$out_dir/utt2lang") or die "Could not open the output file $out_dir/utt2lang";
-open(GENDER, ">", "$out_dir/spk2gender") or die "Could not open the output file $out_dir/spk2gender";
-
-opendir my $dh, "$dataset_path" or die "Cannot open directory: $!";
-my @spkr_dirs = grep {-d "$dataset_path/$_" && ! /^\.{1,2}$/} readdir($dh);
-closedir $dh;
foreach (@spkr_dirs) {
my $spkr_id = $_;
- print GENDER "$spkr_id $spkr2gender{$spkr_id}\n";
-
opendir my $dh, "$dataset_path/$spkr_id/" or die "Cannot open directory: $!";
my @rec_dirs = grep {-d "$dataset_path/$spkr_id/$_" && ! /^\.{1,2}$/} readdir($dh);
closedir $dh;
@@ -105,19 +67,11 @@
my $utt_id = "$spkr_id-$rec_id-$name";
print WAV "$utt_id", " $wav", "\n";
print SPKR "$utt_id", " $spkr_id", "\n";
- if (exists $utt2lang{$utt_id}) {
- print LANG "$utt_id", " $utt2lang{$utt_id}", "\n";
- }
- else {
- print LANG "$utt_id N/A\n";
- }
}
}
}
close(SPKR) or die;
close(WAV) or die;
-close(LANG) or die;
-close(GENDER) or die;
if (system(
"utils/utt2spk_to_spk2utt.pl $out_dir/utt2spk >$out_dir/spk2utt") != 0) {
diff --git a/hyp_utils/feats/make_evad.sh b/hyp_utils/feats/make_evad.sh
index 373fc4a6..377c6e7b 100755
--- a/hyp_utils/feats/make_evad.sh
+++ b/hyp_utils/feats/make_evad.sh
@@ -84,11 +84,42 @@ if [ -f $data/segments ]; then
opt_args="${opt_args} --segments $data/segments"
fi
+set +e
$cmd JOB=1:$nj $logdir/make_vad_${name}.JOB.log \
hyp_utils/conda_env.sh \
compute_energy_vad.py --cfg $vad_config $opt_args \
--input $scp --output ark,scp:$vaddir/vad_$name.JOB.ark,$vaddir/vad_$name.JOB.scp \
- --part-idx JOB --num-parts $nj || exit 1
+ --part-idx JOB --num-parts $nj
+set -e
+
+# Re-run unsuccessful jobs, retrying up to 3 times
+for tmp in {1..3};do
+ pids=""
+
+ for((i=1;i<=$nj;i++))
+ do
+ status=$(tail -n 1 $logdir/make_vad_${name}.$i.log | \
+ awk '/status 0/ { print 0}
+ !/status 0/ { print 1}')
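+    # A job is considered successful only if the last line of its log matches "status 0".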
+ if [ $status -eq 1 ];then
+ echo "JOB $i failed, resubmitting"
+ sleep 10
+ opt_args=`echo ${opt_args} | sed -e "s/utt2num_frames.JOB/utt2num_frames.$i/g"`
+ $cmd $logdir/make_vad_${name}.$i.log \
+ hyp_utils/conda_env.sh \
+ compute_energy_vad.py --cfg $vad_config $opt_args \
+ --input $scp --output ark,scp:$vaddir/vad_$name.$i.ark,$vaddir/vad_$name.$i.scp \
+ --part-idx $i --num-parts $nj &
+ opt_args=`echo ${opt_args} | sed -e "s/utt2num_frames.$i/utt2num_frames.JOB/g"`
+ pids="$pids $!"
+ fi
+ done
+
+ for pid in $pids;do
+ wait $pid
+ done
+done
+wait
# concatenate the .scp files together.
for n in $(seq $nj); do
diff --git a/hyp_utils/kaldi/utils/fix_data_dir.sh b/hyp_utils/kaldi/utils/fix_data_dir.sh
index bb18e07b..ed080eee 100755
--- a/hyp_utils/kaldi/utils/fix_data_dir.sh
+++ b/hyp_utils/kaldi/utils/fix_data_dir.sh
@@ -117,7 +117,7 @@ function filter_speakers {
${kaldi_utils}/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
- for s in cmvn.scp spk2gender spk2nation; do
+ for s in cmvn.scp spk2gender; do
f=$data/$s
if [ -f $f ]; then
filter_file $f $tmpdir/speakers
@@ -127,7 +127,7 @@ function filter_speakers {
filter_file $tmpdir/speakers $data/spk2utt
${kaldi_utils}/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk
- for s in cmvn.scp spk2gender spk2nation $spk_extra_files; do
+ for s in cmvn.scp spk2gender $spk_extra_files; do
f=$data/$s
if [ -f $f ]; then
filter_file $tmpdir/speakers $f
diff --git a/hyperion/bin/torch-extract-xvectors-from-wav.py b/hyperion/bin/torch-extract-xvectors-from-wav.py
new file mode 100644
index 00000000..90969722
--- /dev/null
+++ b/hyperion/bin/torch-extract-xvectors-from-wav.py
@@ -0,0 +1,342 @@
+#!/usr/bin/env python
+"""
+ Copyright 2019 Jesus Villalba (Johns Hopkins University)
+ Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+"""
+
+import sys
+import os
+from jsonargparse import (
+ ArgumentParser,
+ ActionConfigFile,
+ ActionParser,
+ namespace_to_dict,
+)
+import time
+import logging
+
+import numpy as np
+import pandas as pd
+
+import torch
+
+from hyperion.hyp_defs import config_logger, float_cpu, set_float_cpu
+from hyperion.utils import Utt2Info
+from hyperion.io import DataWriterFactory as DWF
+from hyperion.io import SequentialAudioReader as AR
+from hyperion.io import VADReaderFactory as VRF
+from hyperion.augment import SpeechAugment
+
+from hyperion.torch.utils import open_device
+from hyperion.torch.utils import dinossl
+from hyperion.torch.narchs import AudioFeatsMVN as AF
+from hyperion.torch import TorchModelLoader as TML
+
+
+def init_device(use_gpu):
+ set_float_cpu("float32")
+ num_gpus = 1 if use_gpu else 0
+ logging.info("initializing devices num_gpus={}".format(num_gpus))
+ device = open_device(num_gpus=num_gpus)
+ return device
+
+
+def init_feats(device, **kwargs):
+ feat_args = AF.filter_args(**kwargs["feats"])
+ logging.info("feat args={}".format(feat_args))
+ logging.info("initializing feature extractor")
+ feat_extractor = AF(trans=False, **feat_args)
+ logging.info("feat-extractor={}".format(feat_extractor))
+ feat_extractor.eval()
+ feat_extractor.to(device)
+ return feat_extractor
+
+
+def load_model(model_path, device, state_dict_key='model_state_dict', dinossl_kwargs=None):
+ logging.info("loading model {}".format(model_path))
+ model = TML.load(model_path, state_dict_key=state_dict_key ,dinossl_kwargs=dinossl_kwargs)
+ logging.info("xvector-model={}".format(model))
+ model.to(device)
+ model.eval()
+ return model
+
+
+def augment(key0, x0, augmenter, aug_df, aug_id):
+ if augmenter is None:
+ x = x0
+ key = key0
+ else:
+ x, aug_info = augmenter(x0)
+ key = "%s-aug-%02d" % (key0, aug_id)
+ aug_df_row = {
+ "key_aug": key,
+ "key_orig": key0,
+ "noise_type": aug_info["noise"]["noise_type"],
+ "snr": aug_info["noise"]["snr"],
+ "rir_type": aug_info["reverb"]["rir_type"],
+ "srr": aug_info["reverb"]["srr"],
+ "sdr": aug_info["sdr"],
+ }
+
+ aug_df.append(pd.DataFrame(aug_df_row, index=[0]))
+
+ return key, x
+
+
+def select_random_chunk(key, x, min_utt_length, max_utt_length, rng):
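+    # Pick a random target length in [min_utt_length, max_utt_length] frames and crop a random
+    # window of that length; utterances not longer than the sampled length are returned unchanged.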
+ utt_length = rng.randint(low=min_utt_length, high=max_utt_length + 1)
+ if utt_length < x.shape[1]:
+ first_frame = rng.randint(low=0, high=x.shape[1] - utt_length)
+ x = x[:, first_frame : first_frame + utt_length]
+ logging.info(
+ "extract-random-utt %s of length=%d first-frame=%d"
+ % (key, x.shape[1], first_frame)
+ )
+ return x
+
+
+def extract_xvectors(
+ input_spec,
+ output_spec,
+ vad_spec,
+ write_num_frames_spec,
+ scp_sep,
+ vad_path_prefix,
+ model_path,
+ chunk_length,
+ embed_layer,
+ random_utt_length,
+ min_utt_length,
+ max_utt_length,
+ aug_cfg,
+ num_augs,
+ aug_info_path,
+ use_gpu,
+ **kwargs
+):
+
+ rng = np.random.RandomState(seed=1123581321 + kwargs["part_idx"])
+ device = init_device(use_gpu)
+ feat_extractor = init_feats(device, **kwargs)
+ if kwargs['dinossl']:
+ dinossl_kwargs={k:kwargs[k] for k in kwargs if 'dinossl' in k}
+ model = load_model(model_path, device, state_dict_key=kwargs['state_dict_key'], dinossl_kwargs=dinossl_kwargs)
+ else:
+ model = load_model(model_path, device, state_dict_key=kwargs['state_dict_key'])
+
+ if write_num_frames_spec is not None:
+ keys = []
+ info = []
+
+ if aug_cfg is not None:
+ augmenter = SpeechAugment.create(aug_cfg, rng=rng)
+ aug_df = []
+ else:
+ augmenter = None
+ aug_df = None
+ num_augs = 1
+
+ ar_args = AR.filter_args(**kwargs)
+ logging.info("opening output stream: %s" % (output_spec))
+ with DWF.create(output_spec, scp_sep=scp_sep) as writer:
+
+ logging.info(
+ "opening input stream: {} with args={}".format(input_spec, ar_args)
+ )
+ with AR(input_spec, **ar_args) as reader:
+
+ if vad_spec is not None:
+ logging.info("opening VAD stream: %s" % (vad_spec))
+ v_reader = VRF.create(
+ vad_spec, path_prefix=vad_path_prefix, scp_sep=scp_sep
+ )
+
+ while not reader.eof():
+ t1 = time.time()
+ key, x0, fs = reader.read(1)
+ if len(key) == 0:
+ break
+
+ x0 = x0[0]
+ key0 = key[0]
+ t2 = time.time()
+
+ logging.info("processing utt %s" % (key0))
+ for aug_id in range(num_augs):
+ t3 = time.time()
+ key, x = augment(key0, x0, augmenter, aug_df, aug_id)
+ t4 = time.time()
+ with torch.no_grad():
+ x = torch.tensor(
+ x[None, :], dtype=torch.get_default_dtype()
+ ).to(device)
+
+ x = feat_extractor(x)
+ t5 = time.time()
+ tot_frames = x.shape[1]
+ if vad_spec is not None:
+ vad = v_reader.read(key0, num_frames=tot_frames)[0]
+ vad = torch.tensor(vad, dtype=torch.bool).to(device)
+ x = x[:, vad]
+
+ logging.info(
+ "utt %s detected %d/%d (%.2f %%) speech frames"
+ % (
+ key,
+ x.shape[1],
+ tot_frames,
+ x.shape[1] / tot_frames * 100,
+ )
+ )
+
+ if random_utt_length:
+ x = select_random_chunk(
+ key, x, min_utt_length, max_utt_length, rng
+ )
+
+ t6 = time.time()
+                    if x.shape[1] == 0:  # JJ: EXP - this case is not handled for dinossl models
+ y = np.zeros((model.embed_dim,), dtype=float_cpu())
+ else:
+ x = x.transpose(1, 2).contiguous()
+ y = (
+ model.extract_embed(
+ x,
+ chunk_length=chunk_length,
+ embed_layer=embed_layer,
+ )
+ .cpu()
+ .numpy()[0]
+ )
+
+ t7 = time.time()
+ writer.write([key], [y])
+ if write_num_frames_spec is not None:
+ keys.append(key)
+ info.append(str(x.shape[1]))
+
+ t8 = time.time()
+ read_time = t2 - t1
+ tot_time = read_time + t8 - t3
+ logging.info(
+ (
+ "utt %s total-time=%.3f read-time=%.3f "
+ "aug-time=%.3f feat-time=%.3f "
+ "vad-time=%.3f embed-time=%.3f write-time=%.3f "
+ "rt-factor=%.2f"
+ )
+ % (
+ key,
+ tot_time,
+ read_time,
+ t4 - t3,
+ t5 - t4,
+ t6 - t5,
+ t7 - t6,
+ t8 - t7,
+ x0.shape[0] / fs[0] / tot_time,
+ )
+ )
+
+ if write_num_frames_spec is not None:
+ logging.info("writing num-frames to %s" % (write_num_frames_spec))
+ u2nf = Utt2Info.create(keys, info)
+ u2nf.save(write_num_frames_spec)
+
+ if aug_info_path is not None:
+ aug_df = pd.concat(aug_df, ignore_index=True)
+ aug_df.to_csv(aug_info_path, index=False, na_rep="n/a")
+
+
+if __name__ == "__main__":
+
+ parser = ArgumentParser(
+ description=(
+ "Extracts x-vectors from waveform computing " "acoustic features on the fly"
+ )
+ )
+
+ parser.add_argument("--cfg", action=ActionConfigFile)
+ parser.add_argument("--dinossl_cfg", action=ActionConfigFile)
+ parser.add_argument("--input", dest="input_spec", required=True)
+ parser.add_argument("--vad", dest="vad_spec", default=None)
+ parser.add_argument(
+ "--write-num-frames", dest="write_num_frames_spec", default=None
+ )
+ parser.add_argument("--scp-sep", default=" ", help=("scp file field separator"))
+ parser.add_argument(
+ "--vad-path-prefix", default=None, help=("scp file_path prefix for vad")
+ )
+
+ AR.add_class_args(parser)
+
+ parser.add_argument("--aug-cfg", default=None)
+ parser.add_argument("--aug-info-path", default=None)
+ parser.add_argument(
+ "--num-augs", default=1, type=int, help="number of augmentations per utterance"
+ )
+
+ AF.add_class_args(parser, prefix="feats")
+
+ parser.add_argument("--model-path", required=True)
+ parser.add_argument(
+ "--chunk-length",
+ type=int,
+ default=0,
+ help=(
+ "number of frames used in each forward pass "
+ "of the x-vector encoder,"
+ "if 0 the full utterance is used"
+ ),
+ )
+ parser.add_argument(
+ "--embed-layer",
+ type=int,
+ default=None,
+ help=(
+ "classifier layer to get the embedding from, "
+ "if None, it uses layer set in training phase"
+ ),
+ )
+
+ parser.add_argument(
+ "--random-utt-length",
+ default=False,
+ action="store_true",
+ help="calculates x-vector from a random chunk",
+ )
+ parser.add_argument(
+ "--min-utt-length",
+ type=int,
+ default=500,
+ help=("minimum utterance length when using random utt length"),
+ )
+ parser.add_argument(
+ "--max-utt-length",
+ type=int,
+ default=12000,
+ help=("maximum utterance length when using random utt length"),
+ )
+
+ parser.add_argument("--output", dest="output_spec", required=True)
+ parser.add_argument(
+ "--use-gpu", default=False, action="store_true", help="extract xvectors in gpu"
+ )
+ parser.add_argument(
+ "-v", "--verbose", dest="verbose", default=1, choices=[0, 1, 2, 3], type=int
+ )
+ # dinossl related
+ parser.add_argument('--state_dict_key', type=str, default='model_state_dict',
+ choices=['model_state_dict','model_teacher_state_dict'],
+                        help=('key for the state_dict of the pre-trained model. Currently, model_teacher_state_dict is only available for dinossl models'))
+ parser.add_argument('--dinossl_xvec_loc', type=str, default='f',
+ choices=['f', 'dinohead_mlp','dinohead_l2norm','dinohead_linear'],
+ help=('Where to extract x-vectors from the dinossl model. The naming follows Figure 9 in the DINO paper'))
+ dinossl.add_dinossl_args(parser)
+
+ args = parser.parse_args()
+ config_logger(args.verbose)
+ del args.verbose
+ logging.debug(args)
+
+ extract_xvectors(**namespace_to_dict(args))
\ No newline at end of file
diff --git a/hyperion/bin/train_xvector_from_wav.py b/hyperion/bin/train_xvector_from_wav.py
index 0e074977..89a281af 100755
--- a/hyperion/bin/train_xvector_from_wav.py
+++ b/hyperion/bin/train_xvector_from_wav.py
@@ -213,5 +213,5 @@ def make_parser(xvec_class):
args_sc.xvec_class = xvec_dict[xvec_type]
# torch docs recommend using forkserver
- multiprocessing.set_start_method("forkserver")
+    multiprocessing.set_start_method("forkserver", force=True)
train_xvec(gpu_id, args_sc)
diff --git a/hyperion/bin/train_xvector_from_wav_dinossl.py b/hyperion/bin/train_xvector_from_wav_dinossl.py
new file mode 100755
index 00000000..f89496ad
--- /dev/null
+++ b/hyperion/bin/train_xvector_from_wav_dinossl.py
@@ -0,0 +1,284 @@
+#!/usr/bin/env python
+"""
+ Copyright 2020 Johns Hopkins University (Author: Jesus Villalba)
+ Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+"""
+import sys
+import os
+from pathlib import Path
+from jsonargparse import (
+ ArgumentParser,
+ ActionConfigFile,
+ ActionParser,
+ namespace_to_dict,
+)
+import time
+import logging
+import multiprocessing
+
+import torch
+
+from hyperion.hyp_defs import config_logger, set_float_cpu
+from hyperion.torch.utils import ddp
+from hyperion.torch.utils import dinossl
+from hyperion.torch.trainers import DINOSSLXVectorTrainerFromWav as Trainer
+from hyperion.torch.data import AudioDataset as AD
+from hyperion.torch.data import ClassWeightedSeqSampler as Sampler
+from hyperion.torch.narchs import AudioFeatsMVN as AF
+from hyperion.torch.models import ResNetXVector as RXVec
+from hyperion.torch.models import ResNet1dXVector as R1dXVec
+from hyperion.torch.models import EfficientNetXVector as EXVec
+from hyperion.torch.models import TDNNXVector as TDXVec
+from hyperion.torch.models import TransformerXVectorV1 as TFXVec
+from hyperion.torch.models import SpineNetXVector as SpineXVec
+
+xvec_dict = {
+ "resnet": RXVec,
+ "resnet1d": R1dXVec,
+ "efficientnet": EXVec,
+ "tdnn": TDXVec,
+ "transformer": TFXVec,
+ "spinenet": SpineXVec,
+}
+
+
+def init_data(partition, rank, num_gpus, **kwargs):
+ if kwargs["dinossl"]:
+ dinossl_args = dinossl.filter_args(**kwargs)
+ kwargs = kwargs["data"][partition]
+ ad_args = AD.filter_args(**kwargs["dataset"])
+ sampler_args = Sampler.filter_args(**kwargs["sampler"])
+ if rank == 0:
+ logging.info("{} audio dataset args={}".format(partition, ad_args))
+ logging.info("{} sampler args={}".format(partition, sampler_args))
+ logging.info("init %s dataset", partition)
+
+ ad_args["is_val"] = partition == "val"
+ if dinossl_args["dinossl"]:
+ dataset = AD(**ad_args,dinossl_chunk_len_mult=dinossl_args["dinossl_chunk_len_mult"], dinossl_n_chunks=dinossl_args["dinossl_local_crops_number"] + 2, dinossl_reduce_overlap_prob=dinossl_args["dinossl_reduce_overlap_prob"])
+ else:
+ dataset = AD(**ad_args)
+
+ if rank == 0:
+ logging.info("init %s samplers", partition)
+
+ sampler = Sampler(dataset, **sampler_args)
+
+ if rank == 0:
+ logging.info("init %s dataloader", partition)
+
+ num_workers = kwargs["data_loader"]["num_workers"]
+ num_workers_per_gpu = int((num_workers + num_gpus - 1) / num_gpus)
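+    # Split the requested number of workers across GPUs (ceiling division).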
+ largs = (
+ {"num_workers": num_workers_per_gpu, "pin_memory": True} if num_gpus > 0 else {}
+ )
+ data_loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler, **largs)
+ return data_loader
+
+
+# def init_data(
+# audio_path,
+# train_list,
+# val_list,
+# train_aug_cfg,
+# val_aug_cfg,
+# num_workers,
+# num_gpus,
+# rank,
+# **kwargs
+# ):
+
+# ad_args = AD.filter_args(**kwargs)
+# sampler_args = Sampler.filter_args(**kwargs)
+# if rank == 0:
+# logging.info("audio dataset args={}".format(ad_args))
+# logging.info("sampler args={}".format(sampler_args))
+# logging.info("init datasets")
+
+# train_data = AD(audio_path, train_list, aug_cfg=train_aug_cfg, **ad_args)
+# val_data = AD(audio_path, val_list, aug_cfg=val_aug_cfg, is_val=True, **ad_args)
+
+# if rank == 0:
+# logging.info("init samplers")
+# train_sampler = Sampler(train_data, **sampler_args)
+# val_sampler = Sampler(val_data, **sampler_args)
+
+# num_workers_per_gpu = int((num_workers + num_gpus - 1) / num_gpus)
+# largs = (
+# {"num_workers": num_workers_per_gpu, "pin_memory": True} if num_gpus > 0 else {}
+# )
+
+# train_loader = torch.utils.data.DataLoader(
+# train_data, batch_sampler=train_sampler, **largs
+# )
+
+# test_loader = torch.utils.data.DataLoader(
+# val_data, batch_sampler=val_sampler, **largs
+# )
+
+# return train_loader, test_loader
+
+
+def init_feats(rank, **kwargs):
+ feat_args = AF.filter_args(**kwargs["feats"])
+ if rank == 0:
+ logging.info("feat args={}".format(feat_args))
+ logging.info("initializing feature extractor")
+ feat_extractor = AF(trans=True, **feat_args)
+ if rank == 0:
+ logging.info("feat-extractor={}".format(feat_extractor))
+ return feat_extractor
+
+
+def init_xvector(num_classes, rank, xvec_class, **kwargs):
+ xvec_args = xvec_class.filter_args(**kwargs["model"])
+ if rank == 0:
+ logging.info("xvector network args={}".format(xvec_args))
+ xvec_args["num_classes"] = num_classes
+ model = xvec_class(**xvec_args)
+ if rank == 0:
+ logging.info("x-vector-model={}".format(model))
+ return model
+
+
+def train_xvec(gpu_id, args):
+
+ config_logger(args.verbose)
+ del args.verbose
+ logging.debug(args)
+
+ kwargs = namespace_to_dict(args)
+ torch.manual_seed(args.seed)
+ set_float_cpu("float32")
+
+ ddp_args = ddp.filter_ddp_args(**kwargs)
+ device, rank, world_size = ddp.ddp_init(gpu_id, **ddp_args)
+ kwargs["rank"] = rank
+
+ train_loader = init_data(partition="train", **kwargs)
+ val_loader = init_data(partition="val", **kwargs) if not kwargs["dinossl"] else None
+ feat_extractor = init_feats(**kwargs)
+ model = init_xvector(train_loader.dataset.num_classes, **kwargs)
+ loss = None
+ if kwargs["dinossl"]:
+ dinossl_args = dinossl.filter_args(**kwargs)
+ model, loss = dinossl.init_dino(model, dinossl_args, rank = rank)
+
+ trn_args = Trainer.filter_args(**kwargs["trainer"])
+ trn_args["niter_per_ep"] = len(train_loader) # will be used for DINO-related scheduling
+ trn_args["batch_size"] = kwargs["data"]["train"]["sampler"]["batch_size"] * kwargs["num_gpus"]
+ if rank == 0:
+ logging.info("trainer args={}".format(trn_args))
+ metrics = {}
+ trainer = Trainer(
+ model,
+ feat_extractor,
+ device=device,
+ metrics=metrics,
+ ddp=world_size > 1,
+ loss=loss,
+ **trn_args,
+ )
+ trainer.load_last_checkpoint()
+ trainer.fit(train_loader, val_loader)
+
+ ddp.ddp_cleanup()
+
+
+def make_parser(xvec_class):
+ parser = ArgumentParser()
+
+ parser.add_argument("--cfg", action=ActionConfigFile)
+
+ train_parser = ArgumentParser(prog="")
+ # parser.add_argument("--audio-path", required=True)
+ # parser.add_argument("--train-list", required=True)
+ # parser.add_argument("--val-list", required=True)
+
+ AD.add_class_args(train_parser, prefix="dataset", skip={})
+ Sampler.add_class_args(train_parser, prefix="sampler")
+ # parser.add_argument("--train-aug-cfg", default=None)
+ # parser.add_argument("--val-aug-cfg", default=None)
+ train_parser.add_argument(
+ "--data_loader.num-workers",
+ type=int,
+ default=5,
+ help="num_workers of data loader",
+ )
+
+ val_parser = ArgumentParser(prog="")
+ AD.add_class_args(val_parser, prefix="dataset", skip={})
+ Sampler.add_class_args(val_parser, prefix="sampler")
+ val_parser.add_argument(
+ "--data_loader.num-workers",
+ type=int,
+ default=5,
+ help="num_workers of data loader",
+ )
+ data_parser = ArgumentParser(prog="")
+ data_parser.add_argument("--train", action=ActionParser(parser=train_parser))
+ data_parser.add_argument("--val", action=ActionParser(parser=val_parser))
+ parser.add_argument("--data", action=ActionParser(parser=data_parser))
+ parser.link_arguments(
+ "data.train.dataset.class_file", "data.val.dataset.class_file"
+ )
+ parser.link_arguments(
+ "data.train.data_loader.num_workers", "data.val.data_loader.num_workers"
+ )
+ parser.link_arguments(
+ "data.train.sampler.batch_size", "data.val.sampler.batch_size"
+ )
+
+ AF.add_class_args(parser, prefix="feats")
+ xvec_class.add_class_args(parser, prefix="model")
+ Trainer.add_class_args(
+ parser, prefix="trainer", train_modes=xvec_class.valid_train_modes()
+ )
+ dinossl.add_dinossl_args(parser)
+ ddp.add_ddp_args(parser)
+ parser.add_argument("--seed", type=int, default=1123581321, help="random seed")
+ # parser.add_argument(
+ # "--resume",
+ # action="store_true",
+ # default=False,
+ # help="resume training from checkpoint",
+ # )
+ parser.add_argument(
+ "-v", "--verbose", dest="verbose", default=1, choices=[0, 1, 2, 3], type=int
+ )
+
+ return parser
+
+
+if __name__ == "__main__":
+
+ parser = ArgumentParser(description="Train XVector from audio files")
+
+ parser.add_argument("--cfg", action=ActionConfigFile)
+
+ subcommands = parser.add_subcommands()
+
+ for k, v in xvec_dict.items():
+ parser_k = make_parser(v)
+ subcommands.add_subcommand(k, parser_k)
+
+ args = parser.parse_args()
+ try:
+ gpu_id = int(os.environ["LOCAL_RANK"])
+ except:
+ gpu_id = 0
+
+ xvec_type = args.subcommand
+ args_sc = vars(args)[xvec_type]
+
+ if gpu_id == 0:
+ try:
+ config_file = Path(args_sc.trainer.exp_path) / "config.yaml"
+ parser.save(args, str(config_file), format="yaml", overwrite=True)
+ except:
+ pass
+
+ args_sc.xvec_class = xvec_dict[xvec_type]
+ # torch docs recommend using forkserver
+    multiprocessing.set_start_method("forkserver", force=True)
+ train_xvec(gpu_id, args_sc)
diff --git a/hyperion/torch/data/audio_dataset.py b/hyperion/torch/data/audio_dataset.py
index 439c00ba..cb02c971 100644
--- a/hyperion/torch/data/audio_dataset.py
+++ b/hyperion/torch/data/audio_dataset.py
@@ -468,6 +468,9 @@ def __init__(
target_sample_freq=None,
wav_scale=2 ** 15 - 1,
is_val=False,
+ dinossl_chunk_len_mult=None,
+ dinossl_n_chunks=None,
+ dinossl_reduce_overlap_prob=0
):
super().__init__()
@@ -514,6 +517,19 @@ def __init__(
)
self.return_orig = return_orig
+ # start dino-stuff persephone_dinossl
+ # dinossl related
+ # self.dinossl_chunk_len_mult = dinossl_chunk_len_mult
+ # self.dinossl_n_chunks = dinossl_n_chunks
+ # self.dinossl_reduce_overlap_prob = dinossl_reduce_overlap_prob
+
+ # self._prepare_class_info(class_file)
+
+ # if max_chunk_length is None:
+ # max_chunk_length = min_chunk_length
+ # self._min_chunk_length = min_chunk_length
+ # self._max_chunk_length = max_chunk_length
+        # end dino-stuff
self.num_augs = num_augs
self._create_augmenters(aug_cfgs)
@@ -589,6 +605,84 @@ def min_seq_length(self):
def max_seq_length(self):
return np.max(self.seq_lengths)
+    # start dino stuff (persephone_dinossl)
+ # def _prune_short_seqs(self, min_length):
+ # if self.rank == 0:
+ # logging.info("pruning short seqs")
+ # keep_idx = self.seq_lengths >= min_length
+ # self.u2c = self.u2c.filter_index(keep_idx)
+ # self._seq_lengths = self.seq_lengths[keep_idx]
+ # if self.rank == 0:
+ # logging.info(
+ # "pruned seqs with min_length < %f,"
+ # "keep %d/%d seqs" % (min_length, self.num_seqs, len(keep_idx))
+ # )
+
+ # def _prepare_class_info(self, class_file):
+ # class_weights = None
+ # if class_file is None:
+ # classes, class_idx = np.unique(self.u2c.info, return_inverse=True)
+ # class2idx = {k: i for i, k in enumerate(classes)}
+ # else:
+ # if self.rank == 0:
+ # logging.info("reading class-file %s" % (class_file))
+ # class_info = pd.read_csv(class_file, header=None, sep=" ")
+ # class2idx = {str(k): i for i, k in enumerate(class_info[0])}
+ # class_idx = np.array([class2idx[k] for k in self.u2c.info], dtype=int)
+ # if class_info.shape[1] == 2:
+ # class_weights = np.array(class_info[1]).astype(
+ # floatstr_torch(), copy=False
+ # )
+ #
+ # self.num_classes = len(class2idx)
+
+ # class2utt_idx = {}
+ # class2num_utt = np.zeros((self.num_classes,), dtype=int)
+
+ # for k in range(self.num_classes):
+ # idx = (class_idx == k).nonzero()[0]
+ # class2utt_idx[k] = idx
+ # class2num_utt[k] = len(idx)
+ # if class2num_utt[k] == 0:
+ # if (not self.is_val) and (self.dinossl_chunk_len_mult is None):
+ # logging.warning("class %d doesn't have any samples" % (k))
+ # if class_weights is None:
+ # class_weights = np.ones((self.num_classes,), dtype=floatstr_torch())
+ # class_weights[k] = 0
+
+ # count_empty = np.sum(class2num_utt == 0)
+ # if count_empty > 0:
+ # logging.warning("%d classes have 0 samples" % (count_empty))
+
+ # self.utt_idx2class = class_idx
+ # self.class2utt_idx = class2utt_idx
+ # self.class2num_utt = class2num_utt
+ # if class_weights is not None:
+ # class_weights /= np.sum(class_weights)
+ # class_weights = torch.Tensor(class_weights)
+ # self.class_weights = class_weights
+
+ # if self.short_seq_exist:
+ # # if there are seq shorter than max_chunk_lenght we need some extra variables
+ # # we will need class_weights to put to 0 classes that have all utts shorter than the batch chunk length
+ # if self.class_weights is None:
+ # self.class_weights = torch.ones((self.num_classes,))
+
+ # # we need the max length of the utterances of each class
+ # class2max_length = torch.zeros((self.num_classes,), dtype=torch.float)
+ # for c in range(self.num_classes):
+ # if class2num_utt[c] > 0:
+ # class2max_length[c] = np.max(
+ # self.seq_lengths[self.class2utt_idx[c]]
+ # )
+ #
+ # self.class2max_length = class2max_length
+
+
+ #def _seq_shorter_than_max_length_exists(self, max_length):
+ # return np.any(self.seq_lengths < max_length)
+ #end dinostuff
+
@property
def num_classes(self):
return {k: t.num_classes for k, t in self.class_info.items()}
@@ -602,6 +696,12 @@ def _parse_segment_item(self, segment):
f"chunk duration ({duration})"
)
else:
+ # start dino stuff persephone_dinossl
+ # if self.dinossl_n_chunks == None:
+ # return self._get_random_chunk(index)
+ # else: # multi-chunks for dinossl
+ # return self._get_random_chunks(index)
+ # end dino stuff
seg_id, start, duration = segment, 0, 0
if "start" in self.seg_set:
@@ -695,14 +795,177 @@ def __getitem__(self, segment):
r.append(x_orig)
else:
+ # start dinostuff persephone_dinossl
+ # chunk_length = self.max_chunk_length
+
+ #key = self.u2c.key[index]
+
+ #full_seq_length = self.seq_lengths[index]
+ #assert (
+ # chunk_length <= full_seq_length
+ #), "chunk_length(%d) <= full_seq_length(%d)" % (chunk_length, full_seq_length)
+
+ #time_offset = torch.rand(size=(1,)).item() * (full_seq_length - chunk_length)
+ #reverb_context = min(self.reverb_context, time_offset)
+ #time_offset -= reverb_context
+ #read_chunk_length = chunk_length + reverb_context
+
+ # logging.info('get-random-chunk {} {} {} {} {}'.format(index, key, time_offset, chunk_length, full_seq_length ))
+ #x, fs = self.r.read([key], time_offset=time_offset, time_durs=read_chunk_length)
+
+ #x = x[0]
+ #fs = fs[0]
+
+ #x_clean = x
+ #if self.augmenter is not None:
+ # chunk_length_samples = int(chunk_length * fs)
+ # end_idx = len(x)
+ # reverb_context_samples = end_idx - chunk_length_samples
+ # assert reverb_context_samples >= 0, (
+ # "key={} time-offset={}, read-chunk={} "
+ # "read-x-samples={}, chunk_samples={}, reverb_context_samples={}"
+ # ).format(
+ # key,
+ # time_offset,
+ # read_chunk_length,
+ # end_idx,
+ # chunk_length_samples,
+ # reverb_context_samples,
+ # )
+ # # end_idx = reverb_context_samples + chunk_length_samples
+ # x, aug_info = self.augmenter(x)
+ # x = x[reverb_context_samples:end_idx]
+ # if self.return_clean_aug_pair:
+ # x_clean = x_clean[reverb_context_samples:end_idx]
+ # x_clean = x_clean.astype(floatstr_torch(), copy=False)
+ # # x_clean = x_clean[reverb_context_samples:]
+ # # logging.info('augmentation x-clean={}, x={}, aug_info={}'.format(
+ # # x_clean.shape, x.shape, aug_info))
+ ## if len(x) != 64000:
+ ## logging.info('x!=4s, {} {} {} {} {} {} {} {}'.format(len(x),reverb_context, reverb_context_samples, chunk_length, chunk_length_samples, end_idx, fs, read_chunk_length))
+
+ ## if len(x) != 64000:
+ ## logging.info('x!=4s-2, {} {} {} {}'.format(len(x), chunk_length, fs, read_chunk_length))
+
+ #if self.transpose_input:
+ # x = x[None, :]
+ # if self.return_clean_aug_pair:
+ # x_clean = x_clean[None, :]
+
+ #x = x.astype(floatstr_torch(), copy=False)
+ #if self.return_clean_aug_pair:
+ # r = x, x_clean
+ #else:
+ # r = (x,)
+ # end dinostuff
r = [x]
+
# adds the segment labels
seg_info = self._get_segment_info(seg_id)
r.extend(seg_info)
return (*r,)
+ def _get_random_chunks(self, index):
+
+ if len(index) == 2:
+ index, chunk_length = index
+ else:
+ chunk_length = self.max_chunk_length
+ key = self.u2c.key[index]
+
+ full_seq_length = self.seq_lengths[index]
+ assert chunk_length <= full_seq_length, 'chunk_length(%d) <= full_seq_length(%d)' % (
+ chunk_length, full_seq_length)
+
+ chunk_length_list = []
+ # 2 long chunks
+ if chunk_length * self.dinossl_chunk_len_mult > full_seq_length:
+ chunk_length_list.extend([full_seq_length]*2)
+ else:
+ chunk_length_list.extend([chunk_length * self.dinossl_chunk_len_mult]*2)
+ # self.n_chunks - 2 short chunks
+ chunk_length_list.extend([chunk_length]*(self.dinossl_n_chunks-2))
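+        # Illustration with hypothetical values: chunk_length=2 (seconds), dinossl_chunk_len_mult=2 and
+        # dinossl_n_chunks=6 (2 long/global + 4 short/local crops) give chunk_length_list = [4, 4, 2, 2, 2, 2].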
+
+        r_list = [] # holds DINO's multiple augmented chunks (crops) of a given sample
+
+ # to reduce overlap between 2 long chunks
+ reduce_overlap = (self.dinossl_reduce_overlap_prob > torch.rand(size=(1,)))
+ if reduce_overlap:
+ long_chunk_proc_cnt = 0
+ tmp = torch.rand(size=(5,))*(full_seq_length - chunk_length_list[0])
+ time_offset_long_chunks = [torch.min(tmp), torch.max(tmp)]
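+            # Of the 5 candidate offsets, the two long chunks take the smallest and the largest,
+            # which pushes them apart and reduces their overlap on average.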
+
+ for chunk_length in chunk_length_list: # full_seq_length, self.reverb_context are fixed within this for loop
+ if reduce_overlap and (long_chunk_proc_cnt < 2):
+ time_offset = time_offset_long_chunks[long_chunk_proc_cnt]
+ long_chunk_proc_cnt += 1
+ else:
+ time_offset = torch.rand(size=(1,)).item()*(full_seq_length-chunk_length)
+ reverb_context = min(self.reverb_context, time_offset)
+ time_offset -= reverb_context
+ read_chunk_length = chunk_length + reverb_context
+
+ #logging.info('get-random-chunk {} {} {} {} {}'.format(index, key, time_offset, chunk_length, full_seq_length ))
+ x, fs = self.r.read([key], time_offset=time_offset,
+ time_durs=read_chunk_length)
+
+ x = x[0]
+ fs = fs[0]
+
+ x_clean = x
+ if self.augmenter is not None:
+ chunk_length_samples = int(chunk_length * fs)
+ end_idx = len(x)
+ reverb_context_samples = end_idx - chunk_length_samples
+ assert reverb_context_samples >= 0, (
+ ('key={} time-offset={}, read-chunk={} '
+ 'read-x-samples={}, chunk_samples={}, reverb_context_samples={}').format(
+ key, time_offset, read_chunk_length,
+ end_idx, chunk_length_samples, reverb_context_samples))
+ # end_idx = reverb_context_samples + chunk_length_samples
+ x, aug_info = self.augmenter(x)
+ x = x[reverb_context_samples:end_idx]
+ if self.return_clean_aug_pair:
+ x_clean = x_clean[reverb_context_samples:end_idx]
+ x_clean = x_clean.astype(floatstr_torch(), copy=False)
+ #x_clean = x_clean[reverb_context_samples:]
+ #logging.info('augmentation x-clean={}, x={}, aug_info={}'.format(
+ # x_clean.shape, x.shape, aug_info))
+ # if len(x) != 64000:
+ # logging.info('x!=4s, {} {} {} {} {} {} {} {}'.format(len(x),reverb_context, reverb_context_samples, chunk_length, chunk_length_samples, end_idx, fs, read_chunk_length))
+
+ # if len(x) != 64000:
+ # logging.info('x!=4s-2, {} {} {} {}'.format(len(x), chunk_length, fs, read_chunk_length))
+
+ if self.transpose_input:
+ x = x[None,:]
+ if self.return_clean_aug_pair:
+ x_clean = x_clean[None,:]
+
+ x = x.astype(floatstr_torch(), copy=False)
+ if self.return_clean_aug_pair:
+ r = x, x_clean
+ else:
+ r = (x,)
+ r_list.append(*r)
+
+ if len(r_list) == 1: del r_list
+
+ if not self.return_class:
+ try:
+ return r_list
+ except:
+ return r
+
+ class_idx = self.utt_idx2class[index]
+ try:
+ r = r_list, class_idx
+ except:
+ r = *r, class_idx
+
+ return r
@staticmethod
def filter_args(**kwargs):
diff --git a/hyperion/torch/lr_schedulers/factory.py b/hyperion/torch/lr_schedulers/factory.py
index 3fef6e93..24e85551 100644
--- a/hyperion/torch/lr_schedulers/factory.py
+++ b/hyperion/torch/lr_schedulers/factory.py
@@ -39,6 +39,7 @@ def create(
d_model=None,
lr_factor=1,
update_lr_on_opt_step=False,
+ **kwargs
):
"""Creates a learning rate scheduler object.
@@ -195,6 +196,12 @@ def filter_args(**kwargs):
"lr_factor",
"d_model",
"update_lr_on_opt_step",
+ "dinossl_lr",
+ "dinossl_min_lr",
+ "dinossl_warmup_epochs",
+ "dinossl_weight_decay",
+ "dinossl_weight_decay_end",
+ "dinossl_momentum_teacher"
)
return dict((k, kwargs[k]) for k in valid_args if k in kwargs)
@@ -211,6 +218,7 @@ def add_class_args(parser, prefix=None):
default="none",
choices=[
"none",
+ "dinossl",
"exp_lr",
"invpow_lr",
"cos_lr",
@@ -340,6 +348,29 @@ def add_class_args(parser, prefix=None):
action="store_true",
help=("Update lr based on batch number instead of epoch number"),
)
+ # dinossl related - start
+ parser.add_argument("--dinossl_lr", default=0.005, type=float,
+ help=("""Learning rate at the end of linear warmup (highest LR used during training).
+ The learning rate is linearly scaled with the batch size, and specified here for a
+ reference batch size of 256."""))
+ parser.add_argument("--dinossl_min_lr" ,
+ default=1e-6, type=float,
+ help=("Target LR at the end of optimization. We use a cosine LR schedule with linear warmup."))
+ parser.add_argument("--dinossl_warmup_epochs" ,
+ default=10, type=int,
+ help=("Number of epochs for the linear learning-rate warm up."))
+ parser.add_argument("--dinossl_weight_decay" ,
+ default=0.04, type=float,
+ help=("Initial value of the weight decay. With ViT, a smaller value at the beginning of training works well."))
+ parser.add_argument("--dinossl_weight_decay_end" ,
+ default=0.4, type=float,
+ help=("""Final value of the weight decay. We use a cosine schedule for WD and using a larger decay by
+ the end of training improves performance for ViTs."""))
+ parser.add_argument("--dinossl_momentum_teacher" ,
+ default=0, type=float,
+ help=("""Base EMA parameter for teacher update. The value is increased to 1 during training with cosine schedule.
+ We recommend setting a higher value with small batches: for example use 0.9995 with batch size of 256."""))
+ # dinossl related - end
if prefix is not None:
outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))
diff --git a/hyperion/torch/models/xvectors/resnet_xvector.py b/hyperion/torch/models/xvectors/resnet_xvector.py
index fe88ff57..ecc64773 100644
--- a/hyperion/torch/models/xvectors/resnet_xvector.py
+++ b/hyperion/torch/models/xvectors/resnet_xvector.py
@@ -187,6 +187,10 @@ def load(cls, file_path=None, cfg=None, state_dict=None):
model = cls(**cfg)
if state_dict is not None:
+            # (hacky; may be refactored later) If the checkpoint has no classification output layer,
+            # the model was trained with dinossl, so the output layer is replaced with an identity below.
+            is_classif_net_out = len([ True for k in state_dict.keys() if k[:18] == 'classif_net.output'])
+ if not is_classif_net_out:
+ model.classif_net.output = nn.Identity()
model.load_state_dict(state_dict)
return model
diff --git a/hyperion/torch/models/xvectors/xvector.py b/hyperion/torch/models/xvectors/xvector.py
index 15f0ce86..535cb835 100644
--- a/hyperion/torch/models/xvectors/xvector.py
+++ b/hyperion/torch/models/xvectors/xvector.py
@@ -222,7 +222,7 @@ def _pre_enc(self, x):
return x
def _post_enc(self, x, in_lengths=None, max_in_length=None):
- if self.encoder_net.out_dim() == 4:
+ if self.encoder_net.out_dim() == 4 and (not isinstance(self.classif_net,torch.nn.modules.linear.Linear)):
x = x.view(x.size(0), -1, x.size(-1))
if self.proj is not None:
@@ -286,7 +286,10 @@ class logits tensor with shape=(batch, num_classes).
x = self.encoder_net(x)
x, x_lengths = self._post_enc(x, x_lengths, max_in_length)
p = self.pool_net(x, x_lengths=x_lengths)
- y = self.classif_net(p, y)
+ if isinstance(self.classif_net.output,nn.modules.linear.Identity): # for dino
+ y = self.classif_net(p)
+ else:
+ y = self.classif_net(p, y)
return y
def forward_hid_feats(
@@ -572,6 +575,7 @@ def rebuild_output_layer(
intertop_margin=0.0,
num_subcenters=2,
):
+
if (
(self.num_classes is not None and self.num_classes != num_classes)
or (self.loss_type != loss_type)
diff --git a/hyperion/torch/narchs/classif_head.py b/hyperion/torch/narchs/classif_head.py
index adfeceb3..9416d022 100644
--- a/hyperion/torch/narchs/classif_head.py
+++ b/hyperion/torch/narchs/classif_head.py
@@ -272,7 +272,7 @@ def forward(self, x, y=None):
for l in range(self.num_embed_layers):
x = self.fc_blocks[l](x)
- if self.loss_type == "softmax":
+ if self.loss_type == "softmax" or isinstance(self.output,nn.modules.linear.Identity):
y = self.output(x)
else:
y = self.output(x, y)
diff --git a/hyperion/torch/optim/factory.py b/hyperion/torch/optim/factory.py
index ab350098..a7ebec1f 100644
--- a/hyperion/torch/optim/factory.py
+++ b/hyperion/torch/optim/factory.py
@@ -141,7 +141,7 @@ def create(
if base_opt is None:
raise Exception("unknown optimizer %s" % opt_type)
- if oss:
+        if oss:  # (JJ: the oss=True path is NOT adapted for dinossl_style parameter filtering, so the behavior with dinossl_style is not yet confirmed)
from fairscale.optim.oss import OSS
logging.info("Optimizer uses OSS")
@@ -171,6 +171,7 @@ def filter_args(**kwargs):
"init_acc_val",
"max_iter",
"oss",
+ "dinossl_style"
)
return filter_args(valid_args, kwargs)
@@ -320,6 +321,10 @@ def add_class_args(parser, prefix=None):
"--max-iter", default=20, type=int, help=("max iterations in LBGS")
)
+ parser.add_argument(
+ '--dinossl_style', default=False, type=bool,
+            help=('per-parameter-group updates following the FB DINO repo, which do NOT apply weight decay to biases or Norm parameters'))
+
if prefix is not None:
outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))
# help='optimizer options')
diff --git a/hyperion/torch/torch_model_loader.py b/hyperion/torch/torch_model_loader.py
index c173cd50..c52c32b5 100644
--- a/hyperion/torch/torch_model_loader.py
+++ b/hyperion/torch/torch_model_loader.py
@@ -7,6 +7,7 @@
import re
import torch
+from hyperion.torch.utils import dinossl
from .narchs import *
from .models import *
@@ -34,7 +35,13 @@ def _fix_compatibility(class_obj, cfg):
return cfg
@staticmethod
- def load(file_path, extra_objs={}, map_location=None):
+ def load(file_path, extra_objs={}, map_location=None, state_dict_key='model_state_dict', dinossl_kwargs=None):
+ """
+ Args:
+ state_dict_key (str): key for state_dict of a pre-trained model. Currently either
+            'model_state_dict' or 'model_teacher_state_dict' (the latter is only available in dinossl).
+            dinossl_kwargs (dict): DINOHead-related arguments to reconstruct the DINOHead module as it was during training, plus the location info for x-vector extraction.
+ """
if map_location is None:
map_location = torch.device("cpu")
@@ -50,7 +57,8 @@ def load(file_path, extra_objs={}, map_location=None):
else:
raise Exception("unknown object with class_name=%s" % (class_name))
- state_dict = model_data["model_state_dict"]
+ state_dict = model_data[state_dict_key]
+ logging.info('Using state_dict_key: {} of the pre-trained model'.format(state_dict_key))
if "n_averaged" in state_dict:
del state_dict["n_averaged"]
@@ -58,10 +66,18 @@ def load(file_path, extra_objs={}, map_location=None):
cfg = TorchModelLoader._fix_compatibility(class_obj, cfg)
p = re.compile("^module\.")
+ q = re.compile('^backbone\.') # for dinossl
num_tries = 3
for tries in range(num_tries):
try:
- return class_obj.load(cfg=cfg, state_dict=state_dict)
+ model = class_obj.load(cfg=cfg, state_dict=state_dict)
+                if (dinossl_kwargs is not None) and (dinossl_kwargs['dinossl_xvec_loc'] != 'f'):  # not needed when dinossl_kwargs['dinossl_xvec_loc'] == 'f', since that case does not require the DINOHead
+ embed_dim = state_dict_head['mlp.0.weight'].shape[1]
+ model = dinossl.MultiCropWrapper(model, dinossl.DINOHead(embed_dim, dinossl_kwargs['dinossl_out_dim'], use_bn=dinossl_kwargs['dinossl_use_bn_in_head'],
+ norm_last_layer=dinossl_kwargs['dinossl_norm_last_layer'], nlayers=dinossl_kwargs['dinossl_nlayers']))
+                    model.head.load_state_dict(state_dict_head)  # placing this inside the "try:" block assumes the pre-trained model was always trained with multiple GPUs
+ model.dinossl_xvec_loc = dinossl_kwargs['dinossl_xvec_loc']
+ return model
except RuntimeError as err:
# remove module prefix when is trained with dataparallel
if tries == num_tries - 1:
@@ -69,3 +85,8 @@ def load(file_path, extra_objs={}, map_location=None):
raise err
# remove module prefix when is trained with dataparallel
state_dict = ODict((p.sub("", k), v) for k, v in state_dict.items())
+ # below three are for dinossl
+ state_dict = ODict((q.sub('',k), v) for k,v in state_dict.items())
+ state_dict_head = ODict((k[5:], v) for k,v in state_dict.items() if (k[:4] == 'head'))
+ state_dict = ODict((k, v) for k,v in state_dict.items() if not (k[:4] == 'head'))
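+            # e.g., a (hypothetical) key "backbone.encoder_net.in_block.conv.weight" becomes
+            # "encoder_net.in_block.conv.weight", while "head.mlp.0.weight" is moved into state_dict_head as "mlp.0.weight".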
+
diff --git a/hyperion/torch/trainers/__init__.py b/hyperion/torch/trainers/__init__.py
index 8fef7df5..f4461b39 100644
--- a/hyperion/torch/trainers/__init__.py
+++ b/hyperion/torch/trainers/__init__.py
@@ -4,14 +4,17 @@
"""
from .torch_trainer import TorchTrainer
+from .torch_trainer_dinossl import DINOSSLTorchTrainer
from .xvector_trainer import XVectorTrainer
from .xvector_trainer_deep_feat_reg import XVectorTrainerDeepFeatReg
from .xvector_adv_trainer import XVectorAdvTrainer
+from .xvector_trainer_dinossl import DINOSSLXVectorTrainer
from .xvector_trainer_from_wav import XVectorTrainerFromWav
from .xvector_trainer_deep_feat_reg_from_wav import XVectorTrainerDeepFeatRegFromWav
from .xvector_adv_trainer_from_wav import XVectorAdvTrainerFromWav
+from .xvector_trainer_from_wav_dinossl import DINOSSLXVectorTrainerFromWav
from .vae_trainer import VAETrainer
from .dvae_trainer import DVAETrainer
diff --git a/hyperion/torch/trainers/torch_trainer.py b/hyperion/torch/trainers/torch_trainer.py
index 5f573904..94caf1d0 100644
--- a/hyperion/torch/trainers/torch_trainer.py
+++ b/hyperion/torch/trainers/torch_trainer.py
@@ -3,6 +3,7 @@
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
+from hyperion.torch.utils import dinossl
import os
import math
import contextlib
@@ -413,7 +414,13 @@ def _make_optimizer(self, optim, model, oss=False):
opt_args["oss"] = oss
if self.rank == 0:
logging.info("optimizer args={}".format(opt_args))
- optimizer = OF.create(model.parameters(), **opt_args)
+        if opt_args['dinossl_style']:  # dinossl_style: per-parameter-group updates following the FB DINO repo, which do NOT apply weight decay to biases or Norm parameters
+ params_groups = dinossl.get_params_groups(model)
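+            # get_params_groups (following the FB DINO repo) splits the parameters into a regularized
+            # group and a group of biases/Norm parameters with weight decay disabled.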
+ del opt_args['dinossl_style']
+ optimizer = OF.create(params_groups, **opt_args)
+ else:
+ del opt_args['dinossl_style']
+ optimizer = OF.create(model.parameters(), **opt_args)
return optimizer
def _make_lr_sched(self, lr_sched, optim):
diff --git a/hyperion/torch/trainers/torch_trainer_dinossl.py b/hyperion/torch/trainers/torch_trainer_dinossl.py
new file mode 100644
index 00000000..d1e5dc4b
--- /dev/null
+++ b/hyperion/torch/trainers/torch_trainer_dinossl.py
@@ -0,0 +1,810 @@
+"""
+ Copyright 2022 Johns Hopkins University (Author: Jaejin Cho) - variable names changed w.r.t. the original DINO repo: student --> model, teacher --> model_teacher
+ Copyright 2019 Johns Hopkins University (Author: Jesus Villalba)
+ Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+"""
+
+from hyperion.torch.utils import dinossl
+import os
+import math
+import contextlib
+from collections import OrderedDict as ODict
+from enum import Enum
+from jsonargparse import ArgumentParser, ActionParser
+import logging
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.cuda.amp as amp
+from torch.optim.swa_utils import AveragedModel, SWALR
+import torch.distributed as dist
+
+from fairscale.optim.grad_scaler import ShardedGradScaler
+
+from ..utils import MetricAcc, TorchDDP, FairShardedDDP, FairFullyShardedDDP
+from ..loggers import LoggerList, CSVLogger, ProgLogger, TensorBoardLogger, WAndBLogger
+from ..optim import OptimizerFactory as OF
+from ..lr_schedulers import LRSchedulerFactory as LRSF
+from ..lr_schedulers import LRScheduler as LRS
+
+
+class DDPType(str, Enum):
+ DDP = "ddp"
+ OSS_DDP = "oss_ddp"
+ OSS_SHARDED_DDP = "oss_sharded_ddp"
+ FULLY_SHARDED_DDP = "fully_sharded_ddp"
+
+
+ddp_choices = [o.value for o in DDPType]
+
+
+class DINOSSLTorchTrainer(object):
+ """Base Trainer class to train basic neural network models
+
+ Attributes:
+      model: tuple/list with the (student, teacher) model objects.
+ loss: nn.Module loss class
+ optim: pytorch optimizer object or optimizer options dict
+ epochs: max. number of epochs
+ exp_path: experiment output path
+ cur_epoch: current epoch
+ grad_acc_steps: gradient accumulation steps to simulate larger batch size.
+ device: cpu/gpu device
+ metrics: extra metrics to compute besides cxe.
+ lrsched: learning rate scheduler object
+ loggers: LoggerList object, loggers write training progress to std. output and file.
+ ddp: if True use distributed data parallel training
+ ddp_type: type of distributed data parallel in (ddp, oss_ddp, oss_sharded_ddp, fully_sharded_ddp)
+ train_mode: training mode in ['full', 'frozen']
+ use_amp: uses mixed precision training.
+ log_interval: number of optim. steps between log outputs
+ use_tensorboard: use tensorboard logger
+ use_wandb: use wandb logger
+ wandb: wandb dictionary of options
+ grad_clip: norm to clip gradients, if 0 there is no clipping
+ grad_clip_norm: norm type to clip gradients
+ swa_start: epoch to start doing swa
+ swa_lr: SWA learning rate
+ swa_anneal_epochs: SWA learning rate anneal epochs
+ cpu_offload: CPU offload of gradients when using fully sharded ddp
+ """
+
+ def __init__(
+ self,
+ model,
+ loss,
+ optim={},
+ epochs=100,
+ exp_path="./train",
+ cur_epoch=0,
+ grad_acc_steps=1,
+ eff_batch_size=None,
+ device=None,
+ metrics=None,
+ lrsched=None,
+ loggers=None,
+ ddp=False,
+ ddp_type="ddp",
+ train_mode="full",
+ use_amp=False,
+ log_interval=10,
+ use_tensorboard=False,
+ use_wandb=False,
+ wandb={},
+ grad_clip=0,
+ grad_clip_norm=2,
+ swa_start=0,
+ swa_lr=1e-3,
+ swa_anneal_epochs=10,
+ cpu_offload=False,
+ niter_per_ep=0,
+ batch_size=0
+ ):
+
+ self.model = model[0]
+ self.model_teacher = model[1]
+ self.loss = loss
+ self.epochs = epochs
+ self.niter_per_ep = niter_per_ep
+ self.batch_size = batch_size
+ self.cur_epoch = cur_epoch
+ self.grad_acc_steps = grad_acc_steps
+ self.eff_batch_size = eff_batch_size
+ self.exp_path = Path(exp_path)
+
+ if loggers is None:
+ self.loggers = self._default_loggers(
+ log_interval, use_tensorboard, use_wandb, wandb
+ )
+ elif isinstance(loggers, list):
+ self.loggers = LoggerList(loggers)
+ else:
+ self.loggers = loggers
+
+ self.metrics = metrics
+ self.device = device
+ self.train_mode = train_mode
+ self.use_amp = use_amp
+ self.grad_clip = grad_clip
+ self.grad_clip_norm = grad_clip_norm
+ self.swa_start = swa_start
+ self.do_swa = swa_start > 0
+ self.swa_lr = swa_lr
+ self.swa_anneal_epochs = swa_anneal_epochs
+ self.amp_args = {}
+
+ self.set_train_mode()
+
+ if device is not None:
+ self.model.to(device)
+ self.model_teacher.to(device)
+ if loss is not None:
+ self.loss.to(device)
+
+ self.ddp = ddp
+ self.ddp_type = ddp_type
+ self.rank = 0
+ self.world_size = 1
+ if ddp: # (JJ: EXP - for now, I will only use self.ddp_type = 'ddp', i.e., DDPType.DDP)
+ self.rank = dist.get_rank()
+ self.world_size = dist.get_world_size()
+ if ddp_type == DDPType.DDP or ddp_type == DDPType.OSS_DDP:
+ if dinossl.has_batchnorms(self.model):
+ self.model, self.model_teacher, self.model_teacher_without_ddp = self.convert_sync_batchnorm(device)
+ else: # (JJ: EXP - for ViT-like models (as in DINO) that have no batchnorms; not tested yet in hyperion)
+ self.model_teacher_without_ddp = self.model_teacher
+ if self.rank == 0:
+ logging.info(
+ "training in multiple gpus with distributed-data-parallel"
+ )
+ oss = False if ddp_type == DDPType.DDP else True
+ self.optimizer = self._make_optimizer(optim, self.model, oss=oss)
+ self.model = TorchDDP(
+ self.model,
+ device_ids=[device],
+ output_device=device,
+ )
+ elif ddp_type == DDPType.OSS_SHARDED_DDP:
+ if dinossl.has_batchnorms(self.model):
+ self.model, self.model_teacher, self.model_teacher_without_ddp = self.convert_sync_batchnorm(device)
+ else:
+ self.model_teacher_without_ddp = self.model_teacher
+ if self.rank == 0:
+ logging.info(
+ "training in multiple gpus with fair sharded-distributed-data-parallel"
+ )
+ self.optimizer = self._make_optimizer(optim, self.model, oss=True)
+ self.model = FairShardedDDP(self.model, self.optimizer)
+ else:
+ if self.rank == 0:
+ logging.info(
+ "training in multiple gpus with fair fully-sharded-distributed-data-parallel"
+ )
+ # sync batch-norm is not supported here, it raises an exception
+ self.model = FairFullyShardedDDP(
+ self.model,
+ mixed_precision=self.use_amp,
+ move_params_to_cpu=cpu_offload,
+ )
+ self.optimizer = self._make_optimizer(optim, self.model, oss=False)
+
+ else:
+ self.model_teacher_without_ddp = self.model_teacher
+ self.optimizer = self._make_optimizer(optim, self.model)
+
+ # NO backpropagation through model_teacher, which is instead updated by momentum (EMA).
+ # Do this after DDP is applied; otherwise an assertion error is raised (DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.)
+ for p in self.model_teacher.parameters():
+ p.requires_grad = False
+
+ # make the learning rate scheduler or schedules
+ if lrsched['lrsch_type'] == 'dinossl':
+ self.lr_scheduler = None
+ self.lr_schedule, self.wd_schedule, self.momentum_schedule = self._make_schedules(lrsched)
+ else:
+ self.lr_scheduler = self._make_lr_sched(lrsched, self.optimizer)
+
+ if self.use_amp:
+ if ddp and ddp_type != DDPType.DDP:
+ if self.rank == 0:
+ logging.info(
+ "using automatic mixed precision training with sharded-grad-scaler"
+ )
+ self.grad_scaler = ShardedGradScaler()
+ else:
+ if self.rank == 0:
+ logging.info(
+ "using automatic mixed precision training with grad-scaler"
+ )
+ self.grad_scaler = amp.GradScaler()
+ self.amp_autocast = amp.autocast
+ else:
+ self.amp_autocast = contextlib.nullcontext
+
+ self.in_swa = False
+ if self.do_swa:
+ if self.rank == 0:
+ logging.info("init SWA model")
+ self.swa_model = AveragedModel(self.model)
+ self.swa_scheduler = SWALR(
+ self.optimizer, swa_lr=self.swa_lr, anneal_epochs=self.swa_anneal_epochs
+ )
+
+ def fit(self, train_data, val_data=None):
+ """Training function, it performs the training and validation epochs
+
+ Args:
+ train_data: PyTorch data loader for the training loop
+ val_data: PyTorch data loader for the validation loop
+ """
+ self.exp_path.mkdir(parents=True, exist_ok=True)
+ #self._compute_grad_acc_steps(train_data) # Do not apply grad_acc_steps with the current dinossl (i.e., keep the default of 1). Instead, apply linear scaling of the lr as in the original DINO repo.
+
+ if self.do_swa and self.cur_epoch >= self.swa_start:
+ raise NotImplementedError("Using swa is not implemented for dinossl yet")
+ self.in_swa = True
+
+ val_logs = {}
+ self.loggers.on_train_begin(epochs=self.epochs)
+ for epoch in range(self.cur_epoch, self.epochs):
+
+ self.loggers.on_epoch_begin(epoch, batches=len(train_data))
+ if self.lr_scheduler is not None:
+ # this is needed by cosine scheduler
+ epoch_updates = int(len(train_data) / self.grad_acc_steps)
+ self.lr_scheduler.on_epoch_begin(epoch, epoch_updates=epoch_updates)
+
+ logs = self.train_epoch(train_data)
+ if val_data is not None:
+ val_logs = self.validation_epoch(val_data)
+ logs.update(val_logs)
+ else:
+ logging.info("NO validation phase ...")
+
+ self.cur_epoch += 1
+
+ self.loggers.on_epoch_end(logs)
+ if self.do_swa and self.cur_epoch >= self.swa_start:
+ self.in_swa = True
+ self.swa_model.update_parameters(self.model)
+ self.swa_scheduler.step()
+ else:
+ if self.lr_scheduler is not None:
+ self.lr_scheduler.on_epoch_end(logs)
+
+ self.save_checkpoint(logs)
+
+ if self.in_swa:
+ self.loggers.on_epoch_begin(self.cur_epoch, batches=len(train_data))
+ self.model = self.swa_model.module
+ logs = self.bn_update_epoch(train_data)
+
+ if val_data is not None:
+ val_logs = self.validation_epoch(val_data)
+ logs.update(val_logs)
+
+ self.cur_epoch += 1
+ self.loggers.on_epoch_end(logs)
+ self.save_swa_model(logs)
+
+ def set_train_mode(self):
+ # self.model.train_mode = self.train_mode
+ self.model.set_train_mode(self.train_mode)
+
+ def train_epoch(self, data_loader):
+ """Training epoch loop
+
+ Args:
+ data_loader: PyTorch data loader returning input/output pairs
+ """
+ metric_acc = MetricAcc(device=self.device)
+ batch_metrics = ODict()
+ self.model.train()
+ for batch, (data, target) in enumerate(data_loader):
+ self.loggers.on_batch_begin(batch)
+ if batch % self.grad_acc_steps == 0:
+ self.optimizer.zero_grad()
+
+ data, target = data.to(self.device), target.to(self.device)
+ batch_size = data.shape[0]
+ with self.amp_autocast():
+ output = self.model(data)
+ loss = self.loss(output, target).mean() / self.grad_acc_steps
+
+ if self.use_amp:
+ self.grad_scaler.scale(loss).backward()
+ else:
+ loss.backward()
+
+ if (batch + 1) % self.grad_acc_steps == 0:
+ if self.lr_scheduler is not None and not self.in_swa:
+ self.lr_scheduler.on_opt_step()
+ self.update_model()
+
+ self._reduce_metric(loss)
+ batch_metrics["loss"] = loss.item() * self.grad_acc_steps
+ for k, metric in self.metrics.items():
+ batch_metrics[k] = metric(output, target)
+
+ metric_acc.update(batch_metrics, batch_size)
+ logs = metric_acc.metrics
+ logs["lr"] = self._get_lr()
+ self.loggers.on_batch_end(logs=logs, batch_size=batch_size)
+ # total_batches += 1
+
+ logs = metric_acc.metrics
+ logs = ODict(("train_" + k, v) for k, v in logs.items())
+ logs["lr"] = self._get_lr()
+ return logs
+
+ def validation_epoch(self, data_loader, swa_update_bn=False):
+ """Validation epoch loop
+
+ Args:
+ data_loader: PyTorch data loader returning input/output pairs.
+ swa_update_bn: whether or not to update batch-norm layers in SWA.
+ """
+
+ metric_acc = MetricAcc(self.device)
+ batch_metrics = ODict()
+ with torch.no_grad():
+ if swa_update_bn:
+ log_tag = "train_"
+ self.model.train()
+ else:
+ log_tag = "val_"
+ self.model.eval()
+
+ for batch, (data, target) in enumerate(data_loader):
+ data, target = data.to(self.device), target.to(self.device)
+ batch_size = data.shape[0]
+
+ with self.amp_autocast():
+ output = self.model(data)
+ loss = self.loss(output, target)
+
+ batch_metrics["loss"] = loss.mean().item()
+ for k, metric in self.metrics.items():
+ batch_metrics[k] = metric(output, target)
+
+ metric_acc.update(batch_metrics, batch_size)
+
+ logs = metric_acc.metrics
+ logs = ODict((log_tag + k, v) for k, v in logs.items())
+ return logs
+
+ def bn_update_epoch(self, data_loader):
+ logs = self.validation_epoch(data_loader, swa_update_bn=True)
+ logs["lr"] = self._get_lr()
+ return logs
+
+ def _clip_grad_norm(self, model, optim, grad_clip, grad_clip_norm):
+ if self.ddp:
+ if self.ddp_type == DDPType.DDP:
+ nn.utils.clip_grad_norm_(
+ model.parameters(), grad_clip, norm_type=grad_clip_norm
+ )
+ return
+ if self.ddp_type == DDPType.FULLY_SHARDED_DDP:
+ # we have to use the member function in FullyShardedDDP class
+ model.clip_grad_norm_(grad_clip, norm_type=grad_clip_norm)
+ return
+ else:
+ # not sure about this but it looks like
+ # we have to use the member function in the OSS optimizer wrapper
+ optim.clip_grad_norm(grad_clip, norm_type=grad_clip_norm)
+
+ # if no DDP clip normally
+ nn.utils.clip_grad_norm_(
+ model.parameters(), grad_clip, norm_type=grad_clip_norm
+ )
+
+ def update_model(self):
+ """Updates the model and does gradding clipping."""
+ if self.use_amp:
+ if self.grad_clip > 0:
+ self.grad_scaler.unscale_(self.optimizer)
+ self._clip_grad_norm(
+ self.model, self.optimizer, self.grad_clip, self.grad_clip_norm
+ )
+
+ self.grad_scaler.step(self.optimizer)
+ self.grad_scaler.update()
+ else:
+ if self.grad_clip > 0:
+ self._clip_grad_norm(
+ self.model, self.optimizer, self.grad_clip, self.grad_clip_norm
+ )
+
+ self.optimizer.step()
+
+ def _make_optimizer(self, optim, model, oss=False):
+ """Makes an optimizer object."""
+ if isinstance(optim, torch.optim.Optimizer):
+ return optim
+
+ assert isinstance(optim, dict)
+ opt_args = OF.filter_args(**optim)
+ opt_args["oss"] = oss
+ if self.rank == 0:
+ logging.info("optimizer args={}".format(opt_args))
+ if opt_args['dinossl_style']: # dinossl_style: follow the FB DINO repo and build per-parameter groups so that biases and Norm parameters are not weight-decay regularized
+ params_groups = dinossl.get_params_groups(model)
+ del opt_args['dinossl_style']
+ optimizer = OF.create(params_groups, **opt_args)
+ else:
+ del opt_args['dinossl_style']
+ optimizer = OF.create(model.parameters(), **opt_args)
+ return optimizer
+
+
+ def _make_lr_sched(self, lr_sched, optim):
+ """Makes a Learning Rate scheduler object."""
+ if lr_sched is None or isinstance(lr_sched, LRS):
+ return lr_sched
+
+ assert isinstance(lr_sched, dict)
+ args = LRSF.filter_args(**lr_sched)
+ if self.rank == 0:
+ logging.info("lr scheduler args={}".format(args))
+ lr_sched = LRSF.create(optim, **args)
+ return lr_sched
+
+ def _make_schedules(self, lrsched):
+ assert (self.niter_per_ep != 0) and (self.batch_size != 0)
+ lr_schedule = dinossl.cosine_scheduler(
+ lrsched['dinossl_lr'] * self.batch_size / 256., # linear scaling rule (JJ: TODO - 256. might need to change)
+ lrsched['dinossl_min_lr'],
+ self.epochs, self.niter_per_ep,
+ warmup_epochs=lrsched['dinossl_warmup_epochs'],
+ )
+ wd_schedule = dinossl.cosine_scheduler(
+ lrsched['dinossl_weight_decay'],
+ lrsched['dinossl_weight_decay_end'],
+ self.epochs, self.niter_per_ep,
+ )
+ # the momentum parameter is increased to 1 during training following a cosine schedule
+ momentum_schedule = dinossl.cosine_scheduler(lrsched['dinossl_momentum_teacher'], 1,
+ self.epochs, self.niter_per_ep)
+
+ return lr_schedule, wd_schedule, momentum_schedule
+
+ def _default_loggers(self, log_interval, use_tensorboard, use_wandb, wandb):
+ """Creates the default data loaders"""
+ prog_log = ProgLogger(interval=log_interval)
+ csv_log = CSVLogger(self.exp_path / "train.log", append=True)
+ loggers = [prog_log, csv_log]
+ if use_tensorboard:
+ loggers.append(
+ TensorBoardLogger(self.exp_path / "tb", interval=log_interval)
+ )
+ if use_wandb:
+ loggers.append(
+ WAndBLogger(
+ **wandb, path=self.exp_path / "wandb", interval=log_interval
+ )
+ )
+ return LoggerList(loggers)
+
+ def _get_lr(self):
+ """Returns the current learning rate to show in the loggers"""
+ for param_group in self.optimizer.param_groups:
+ return param_group["lr"]
+
+ def _compute_grad_acc_steps(self, data_loader):
+ if self.eff_batch_size is None:
+ return
+
+ if data_loader.batch_sampler is not None:
+ try:
+ batch_size = data_loader.batch_sampler.avg_batch_size
+ except:
+ logging.warn(
+ "batch sampler doesn't have avg_batch_size property, "
+ "we cannot estimate grad_acc_steps, using grad_acc_steps=%d",
+ self.grad_acc_steps,
+ )
+ return
+
+ self.grad_acc_steps = int(
+ math.ceil(self.eff_batch_size / batch_size / self.world_size)
+ )
+ logging.info(
+ "Setting grad_acc_steps=%d for "
+ "eff_batch_size=%d, avg_batch_size=%d, world_size=%d",
+ self.grad_acc_steps,
+ self.eff_batch_size,
+ batch_size,
+ self.world_size,
+ )
+ return
+
+ logging.warn(
+ "We cannot determine the batch_size, "
+ "we cannot estimate grad_acc_steps, using grad_acc_steps=%d",
+ self.grad_acc_steps,
+ )
+
+ def checkpoint(self, logs=None):
+ """Creates a checkpoint of the training, to save and posterior recovery
+
+ Args:
+ logs: logs containing the current value of the metrics.
+ """
+ checkpoint = {
+ "epoch": self.cur_epoch,
+ "rng_state": torch.get_rng_state(),
+ "model_cfg": self.model.backbone.conf,
+ "model_state_dict": self.model.state_dict(),
+ 'model_teacher_state_dict': self.model_teacher.state_dict(),
+ "optimizer_state_dict": self.optimizer.state_dict(), # (JJ: EXP - "it" used for schedules could be found by EITHER "it = len(data_loader) * self.cur_epoch + batch" OR "lr" in the checkpoint['logs'] below)
+ "loss_state_dict": self.loss.state_dict()
+ if self.loss is not None
+ else None,
+ }
+ if self.lr_scheduler is not None:
+ checkpoint["lr_scheduler_state_dict"] = self.lr_scheduler.state_dict()
+
+ if logs is not None:
+ checkpoint["logs"] = logs
+
+ if self.in_swa:
+ checkpoint["swa_model_state_dict"] = self.swa_model.state_dict()
+ checkpoint["swa_scheduler_state_dict"] = self.swa_scheduler.state_dict()
+
+ return checkpoint
+
+ def convert_sync_batchnorm(self, device):
+ self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
+ self.model_teacher = nn.SyncBatchNorm.convert_sync_batchnorm(self.model_teacher)
+ # we need DDP wrapper to have synchro batch norms working...
+ if self.ddp_type == DDPType.DDP:
+ self.model_teacher = nn.parallel.DistributedDataParallel(self.model_teacher, device_ids=[device], output_device=device) # (JJ: TODO - for now, simply follow what Jesus did)
+ else:
+ raise NotImplementedError('Need implementation for other DDPType except DDPType.DDP')
+ self.model_teacher_without_ddp = self.model_teacher.module
+
+ return self.model, self.model_teacher, self.model_teacher_without_ddp
+
+ def save_checkpoint(self, logs=None):
+ """Saves a checkpoint of the training status
+
+ Args:
+ logs: logs containing the current value of the metrics.
+ """
+ if self.ddp and (
+ self.ddp_type == DDPType.OSS_DDP or self.ddp_type == DDPType.OSS_SHARDED_DDP
+ ):
+ # Not sure what this does, just copying from the example in
+ # https://github.com/facebookresearch/fairscale/blob/master/benchmarks/oss.py
+ # Check the checkpointing in the case of the OSS optimizer
+ # Memory usage could spill over from there
+ # optimizer = cast(OSS, optimizer)
+ self.optimizer.consolidate_state_dict()
+
+ if self.rank != 0:
+ return
+ checkpoint = self.checkpoint(logs)
+ file_path = "%s/model_ep%04d.pth" % (self.exp_path, self.cur_epoch)
+
+ torch.save(checkpoint, file_path)
+
+ def save_swa_model(self, logs=None):
+ """Saves a checkpoint of the training status
+
+ Args:
+ logs: logs containing the current value of the metrics.
+ """
+ if self.rank != 0:
+ return
+
+ checkpoint = self.checkpoint(logs)
+ checkpoint["model_state_dict"] = checkpoint["swa_model_state_dict"]
+ del checkpoint["swa_model_state_dict"]
+ file_path = "%s/swa_model_ep%04d.pth" % (self.exp_path, self.cur_epoch)
+
+ torch.save(checkpoint, file_path)
+
+ def load_checkpoint(self, file_path):
+ """Loads a training checkpoint from file.
+
+ Args:
+ file_path: checkpoint file path
+ """
+ checkpoint = torch.load(file_path, map_location=torch.device("cpu"))
+ rng_state = checkpoint["rng_state"]
+ torch.set_rng_state(rng_state)
+ if self.rank > 0:
+ # this will make sure that each process produces different data
+ # when using ddp
+ dummy = torch.rand(1000 * self.rank)
+ del dummy
+
+ self.cur_epoch = checkpoint["epoch"]
+ try:
+ self.model.load_state_dict(checkpoint["model_state_dict"])
+ except:
+ self.model.module.load_state_dict(checkpoint["model_state_dict"])
+ try:
+ self.model_teacher.load_state_dict(checkpoint['model_teacher_state_dict'])
+ except:
+ self.model_teacher.module.load_state_dict(checkpoint['model_teacher_state_dict'])
+ self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+ if self.loss is not None:
+ self.loss.load_state_dict(checkpoint["loss_state_dict"])
+ if self.lr_scheduler is not None:
+ self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
+
+ # if self.use_amp:
+ # amp.load_state_dict(checkpoint['amp'])
+ if self.do_swa:
+ if "swa_model_state_dict" in checkpoint:
+ self.swa_model.load_state_dict(checkpoint["swa_model_state_dict"])
+ self.swa_scheduler.load_state_dict(
+ checkpoint["swa_scheduler_state_dict"]
+ )
+ else:
+ self.swa_scheduler = SWALR(
+ self.optimizer,
+ swa_lr=self.swa_lr,
+ anneal_epochs=self.swa_anneal_epochs,
+ )
+
+ logs = None
+ if "logs" in checkpoint:
+ logs = checkpoint["logs"]
+
+ del checkpoint
+ # this was added before to try to release as much GPU memory as possible
+ # Recently has started to cause CUDA not available devices error
+ # Commenting for now.
+ # if self.device is not None:
+ # torch.cuda.empty_cache()
+
+ return logs
+
+ def load_last_checkpoint(self):
+ """Loads the last training checkpoint in the experiment dir."""
+ for epoch in range(self.epochs, 0, -1):
+ file_path = "%s/model_ep%04d.pth" % (self.exp_path, epoch)
+ if os.path.isfile(file_path):
+ return self.load_checkpoint(file_path)
+
+ return None
+
+ @staticmethod
+ def filter_args(**kwargs):
+ valid_args = (
+ "grad_acc_steps",
+ "eff_batch_size",
+ "epochs",
+ "log_interval",
+ "use_amp",
+ "ddp_type",
+ "grad_clip",
+ "grad_clip_norm",
+ "swa_start",
+ "swa_lr",
+ "swa_anneal_epochs",
+ "exp_path",
+ "optim",
+ "lrsched",
+ "cpu_offload",
+ "use_tensorboard",
+ "use_wandb",
+ "wandb",
+ "train_mode",
+ )
+ args = dict((k, kwargs[k]) for k in valid_args if k in kwargs)
+ return args
+
+ @staticmethod
+ def add_class_args(parser, prefix=None, train_modes=None, skip=[]):
+ if prefix is not None:
+ outer_parser = parser
+ parser = ArgumentParser(prog="")
+
+ if "optim" not in skip:
+ OF.add_class_args(parser, prefix="optim")
+
+ if "lrsched" not in skip:
+ LRSF.add_class_args(parser, prefix="lrsched")
+
+ parser.add_argument(
+ "--grad-acc-steps",
+ type=int,
+ default=1,
+ help="gradient accumulation batches before weigth update",
+ )
+ parser.add_argument(
+ "--eff-batch-size",
+ type=int,
+ default=None,
+ help="effective total batch size, if given, it overrides grad_acc_steps",
+ )
+ parser.add_argument("--epochs", type=int, default=200, help="number of epochs")
+ if train_modes is not None:
+ parser.add_argument(
+ "--train-mode",
+ default="full",
+ choices=train_modes,
+ help=f"Available train modes for the model in {train_modes}",
+ )
+ parser.add_argument(
+ "--log-interval",
+ type=int,
+ default=10,
+ help="how many batches to wait before logging training status",
+ )
+ parser.add_argument(
+ "--use-tensorboard",
+ action="store_true",
+ default=False,
+ help="use tensorboard logger",
+ )
+ parser.add_argument(
+ "--use-wandb", action="store_true", default=False, help="use wandb logger"
+ )
+ parser.add_argument("--wandb.project", default=None, help="wandb project name")
+ parser.add_argument("--wandb.group", default=None, help="wandb group name")
+ parser.add_argument("--wandb.name", default=None, help="wandb display name")
+ # parser.add_argument(
+ # '--wandb.path', default=None,
+ # help='wandb directory')
+ parser.add_argument(
+ "--wandb.mode",
+ default="online",
+ choices=["online", "offline"],
+ help="wandb mode (online, offline)",
+ )
+
+ parser.add_argument(
+ "--ddp-type",
+ default="ddp",
+ choices=ddp_choices,
+ help="DDP type in {}".format(ddp_choices),
+ )
+ parser.add_argument(
+ "--use-amp",
+ action="store_true",
+ default=False,
+ help="use mixed precision training",
+ )
+ parser.add_argument(
+ "--cpu-offload",
+ action="store_true",
+ default=False,
+ help="CPU offload of gradients when using fully_sharded_ddp",
+ )
+ parser.add_argument(
+ "--grad-clip", type=float, default=0, help="gradient clipping norm value"
+ )
+ parser.add_argument(
+ "--grad-clip-norm",
+ default=2,
+ choices=["inf", 1, 2],
+ help="gradient clipping norm type",
+ )
+ parser.add_argument(
+ "--swa-start",
+ type=int,
+ default=0,
+ help="start epoch for SWA, if 0 it does not use SWA",
+ )
+ parser.add_argument(
+ "--swa-lr", type=float, default=1e-3, help="learning rate for SWA phase"
+ )
+ parser.add_argument(
+ "--swa-anneal-epochs",
+ type=int,
+ default=10,
+ help="SWA learning rate anneal epochs",
+ )
+
+ parser.add_argument("--exp-path", help="experiment path")
+
+ if prefix is not None:
+ outer_parser.add_argument("--" + prefix, action=ActionParser(parser=parser))
+
+ add_argparse_args = add_class_args
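To make `_make_schedules` concrete, the sketch below builds the three per-iteration schedules (learning rate, weight decay, teacher momentum) with the same cosine rule as `dinossl.cosine_scheduler` and indexes them by the global iteration counter, as the training loop in `xvector_trainer_from_wav_dinossl.py` does further down. The numeric hyper-parameters are illustrative placeholders, not recommended values:

```python
import numpy as np

def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
                     warmup_epochs=0, start_warmup_value=0):
    # same rule as dinossl.cosine_scheduler: linear warmup followed by a cosine decay
    warmup_iters = warmup_epochs * niter_per_ep
    warmup = (np.linspace(start_warmup_value, base_value, warmup_iters)
              if warmup_epochs > 0 else np.array([]))
    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    cosine = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
    return np.concatenate((warmup, cosine))

epochs, niter_per_ep, batch_size = 70, 1000, 256  # placeholder values
lr_schedule = cosine_scheduler(0.005 * batch_size / 256., 1e-6, epochs, niter_per_ep, warmup_epochs=10)
wd_schedule = cosine_scheduler(0.04, 0.4, epochs, niter_per_ep)
momentum_schedule = cosine_scheduler(0.996, 1., epochs, niter_per_ep)

# global iteration index, as in the training loop: it = len(data_loader) * cur_epoch + batch
it = niter_per_ep * 5 + 123
print(lr_schedule[it], wd_schedule[it], momentum_schedule[it])
```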
diff --git a/hyperion/torch/trainers/xvector_trainer_dinossl.py b/hyperion/torch/trainers/xvector_trainer_dinossl.py
new file mode 100644
index 00000000..02cee3f7
--- /dev/null
+++ b/hyperion/torch/trainers/xvector_trainer_dinossl.py
@@ -0,0 +1,163 @@
+"""
+ Copyright 2019 Johns Hopkins University (Author: Jesus Villalba)
+ Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+"""
+import os
+from collections import OrderedDict as ODict
+
+import logging
+
+import torch
+import torch.nn as nn
+
+from ..utils import MetricAcc
+from .torch_trainer_dinossl import DINOSSLTorchTrainer
+from torch.distributed.elastic.multiprocessing.errors import record
+
+
+class DINOSSLXVectorTrainer(DINOSSLTorchTrainer):
+ """Trainer to train x-vector style models.
+
+ Attributes:
+ model: x-Vector model object.
+ optim: pytorch optimizer object or options dict
+ epochs: max. number of epochs
+ exp_path: experiment output path
+ cur_epoch: current epoch
+ grad_acc_steps: gradient accumulation steps to simulate larger batch size.
+ device: cpu/gpu device
+ metrics: extra metrics to compute besides cxe.
+ lrsched: learning rate scheduler object or options dict
+ loggers: LoggerList object, loggers write training progress to std. output and file.
+ If None, it uses default loggers.
+ ddp: if True use distributed data parallel training
+ ddp_type: type of distributed data parallel in (ddp, oss_ddp, oss_sharded_ddp, fully_sharded_ddp)
+ loss: if None, it uses cross-entropy
+ train_mode: training mode in ['train', 'ft-full', 'ft-last-layer']
+ use_amp: uses mixed precision training.
+ log_interval: number of optim. steps between log outputs
+ use_tensorboard: use tensorboard logger
+ use_wandb: use wandb logger
+ wandb: wandb dictionary of options
+ grad_clip: norm to clip gradients, if 0 there is no clipping
+ grad_clip_norm: norm type to clip gradients
+ swa_start: epoch to start doing swa
+ swa_lr: SWA learning rate
+ swa_anneal_epochs: SWA learning rate anneal epochs
+ cpu_offload: CPU offload of gradients when using fully sharded ddp
+ """
+
+ def __init__(
+ self,
+ model,
+ optim={},
+ epochs=100,
+ exp_path="./train",
+ cur_epoch=0,
+ grad_acc_steps=1,
+ eff_batch_size=None,
+ device=None,
+ metrics=None,
+ lrsched=None,
+ loggers=None,
+ ddp=False,
+ ddp_type="ddp",
+ loss=None,
+ train_mode="full",
+ use_amp=False,
+ log_interval=10,
+ use_tensorboard=False,
+ use_wandb=False,
+ wandb={},
+ grad_clip=0,
+ grad_clip_norm=2,
+ swa_start=0,
+ swa_lr=1e-3,
+ swa_anneal_epochs=10,
+ cpu_offload=False,
+ niter_per_ep=0,
+ batch_size=0
+ ):
+
+ if loss is None:
+ loss = nn.CrossEntropyLoss()
+ super().__init__(
+ model,
+ loss,
+ optim,
+ epochs,
+ exp_path,
+ cur_epoch=cur_epoch,
+ grad_acc_steps=grad_acc_steps,
+ eff_batch_size=eff_batch_size,
+ device=device,
+ metrics=metrics,
+ lrsched=lrsched,
+ loggers=loggers,
+ ddp=ddp,
+ ddp_type=ddp_type,
+ train_mode=train_mode,
+ use_amp=use_amp,
+ log_interval=log_interval,
+ use_tensorboard=use_tensorboard,
+ use_wandb=use_wandb,
+ wandb=wandb,
+ grad_clip=grad_clip,
+ grad_clip_norm=grad_clip_norm,
+ swa_start=swa_start,
+ swa_lr=swa_lr,
+ swa_anneal_epochs=swa_anneal_epochs,
+ cpu_offload=cpu_offload,
+ niter_per_ep=niter_per_ep,
+ batch_size=batch_size
+ )
+
+ @record
+ def train_epoch(self, data_loader):
+ """Training epoch loop
+
+ Args:
+ data_loader: pytorch data loader returning features and class labels.
+ """
+
+ self.model.update_loss_margin(self.cur_epoch)
+
+ metric_acc = MetricAcc(device=self.device)
+ batch_metrics = ODict()
+ self.model.train()
+ for batch, (data, target) in enumerate(data_loader):
+ self.loggers.on_batch_begin(batch)
+
+ if batch % self.grad_acc_steps == 0:
+ self.optimizer.zero_grad()
+
+ data, target = data.to(self.device), target.to(self.device)
+ batch_size = data.shape[0]
+
+ with self.amp_autocast():
+ output = self.model(data, y=target)
+ loss = self.loss(output, target).mean() / self.grad_acc_steps
+
+ if self.use_amp:
+ self.grad_scaler.scale(loss).backward()
+ else:
+ loss.backward()
+
+ if (batch + 1) % self.grad_acc_steps == 0:
+ if self.lr_scheduler is not None and not self.in_swa:
+ self.lr_scheduler.on_opt_step()
+ self.update_model()
+
+ batch_metrics["loss"] = loss.item() * self.grad_acc_steps
+ for k, metric in self.metrics.items():
+ batch_metrics[k] = metric(output, target)
+
+ metric_acc.update(batch_metrics, batch_size)
+ logs = metric_acc.metrics
+ logs["lr"] = self._get_lr()
+ self.loggers.on_batch_end(logs=logs, batch_size=batch_size)
+
+ logs = metric_acc.metrics
+ logs = ODict(("train_" + k, v) for k, v in logs.items())
+ logs["lr"] = self._get_lr()
+ return logs
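The trainers above expect `model` to be a two-element list `[student, teacher]`, as produced by `dinossl.init_dino` later in this patch. A toy sketch of that pairing with plain `torch.nn` modules (the real code wraps a hyperion x-vector backbone in `MultiCropWrapper` with a `DINOHead`; the layer sizes here are arbitrary):

```python
import copy
import torch.nn as nn

# toy backbone + projection head; the real code uses MultiCropWrapper(backbone, DINOHead(...))
backbone = nn.Sequential(nn.Linear(40, 256), nn.ReLU(), nn.Linear(256, 256))
head = nn.Sequential(nn.Linear(256, 2048), nn.GELU(), nn.Linear(2048, 1024))  # dinossl_out_dim is 65536 by default

student = nn.Sequential(backbone, head)
teacher = copy.deepcopy(student)
teacher.load_state_dict(student.state_dict())  # teacher starts from identical weights

# the teacher is never back-propagated through; it is updated by EMA of the student
for p in teacher.parameters():
    p.requires_grad = False

model = [student, teacher]  # what DINOSSLXVectorTrainer receives as `model`
```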
diff --git a/hyperion/torch/trainers/xvector_trainer_from_wav_dinossl.py b/hyperion/torch/trainers/xvector_trainer_from_wav_dinossl.py
new file mode 100644
index 00000000..a8afb572
--- /dev/null
+++ b/hyperion/torch/trainers/xvector_trainer_from_wav_dinossl.py
@@ -0,0 +1,235 @@
+"""
+ Copyright 2019 Johns Hopkins University (Author: Jesus Villalba)
+ Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+"""
+import os
+import sys
+from collections import OrderedDict as ODict
+
+import logging
+import copy
+import math
+
+import torch
+import torch.nn as nn
+
+from ..utils import MetricAcc, TorchDDP
+from ..utils import cancel_gradients_last_layer
+from .xvector_trainer_dinossl import DINOSSLXVectorTrainer
+
+
+class DINOSSLXVectorTrainerFromWav(DINOSSLXVectorTrainer):
+ """Trainer to train x-vector style models.
+
+ Attributes:
+ model: x-Vector model object.
+ feat_extractor: feature extractor nn.Module
+ optim: pytorch optimizer object or options dict
+ epochs: max. number of epochs
+ exp_path: experiment output path
+ cur_epoch: current epoch
+ grad_acc_steps: gradient accumulation steps to simulate larger batch size.
+ device: cpu/gpu device
+ metrics: extra metrics to compute besides cxe.
+ lrsched: learning rate scheduler object or options dict.
+ loggers: LoggerList object, loggers write training progress to std. output and file.
+ ddp: if True use distributed data parallel training
+ ddp_type: type of distributed data parallel in (ddp, oss_ddp, oss_sharded_ddp, fully_sharded_ddp)
+ loss: if None, it uses cross-entropy
+ train_mode: training mode in ['train', 'ft-full', 'ft-last-layer']
+ use_amp: uses mixed precision training.
+ log_interval: number of optim. steps between log outputs
+ use_tensorboard: use tensorboard logger
+ use_wandb: use wandb logger
+ wandb: wandb dictionary of options
+ grad_clip: norm to clip gradients, if 0 there is no clipping
+ grad_clip_norm: norm type to clip gradients
+ swa_start: epoch to start doing swa
+ swa_lr: SWA learning rate
+ swa_anneal_epochs: SWA learning rate anneal epochs
+ cpu_offload: CPU offload of gradients when using fully sharded ddp
+ """
+
+ def __init__(
+ self,
+ model,
+ feat_extractor,
+ optim={},
+ epochs=100,
+ exp_path="./train",
+ cur_epoch=0,
+ grad_acc_steps=1,
+ eff_batch_size=None,
+ device=None,
+ metrics=None,
+ lrsched=None,
+ loggers=None,
+ ddp=False,
+ ddp_type="ddp",
+ loss=None,
+ train_mode="full",
+ use_amp=False,
+ log_interval=10,
+ use_tensorboard=False,
+ use_wandb=False,
+ wandb={},
+ grad_clip=0,
+ grad_clip_norm=2,
+ swa_start=0,
+ swa_lr=1e-3,
+ swa_anneal_epochs=10,
+ cpu_offload=False,
+ niter_per_ep=0,
+ batch_size=0
+ ):
+
+ super().__init__(
+ model,
+ optim,
+ epochs,
+ exp_path,
+ cur_epoch=cur_epoch,
+ grad_acc_steps=grad_acc_steps,
+ eff_batch_size=eff_batch_size,
+ device=device,
+ metrics=metrics,
+ lrsched=lrsched,
+ loggers=loggers,
+ ddp=ddp,
+ ddp_type=ddp_type,
+ loss=loss,
+ train_mode=train_mode,
+ use_amp=use_amp,
+ log_interval=log_interval,
+ use_tensorboard=use_tensorboard,
+ use_wandb=use_wandb,
+ wandb=wandb,
+ grad_clip=grad_clip,
+ grad_clip_norm=grad_clip_norm,
+ swa_start=swa_start,
+ swa_lr=swa_lr,
+ swa_anneal_epochs=swa_anneal_epochs,
+ cpu_offload=cpu_offload,
+ niter_per_ep=niter_per_ep,
+ batch_size=batch_size
+ )
+
+ self.feat_extractor = feat_extractor
+ if device is not None:
+ self.feat_extractor.to(device)
+
+ # if ddp:
+ # self.feat_extractor = TorchDDP(self.feat_extractor)
+
+ def train_epoch(self, data_loader):
+ """Training epoch loop
+
+ Args:
+ data_loader: pytorch data loader returning features and class labels.
+ """
+
+ metric_acc = MetricAcc(device=self.device)
+ batch_metrics = ODict()
+ self.feat_extractor.train()
+ self.model.train()
+ for batch, (data, _) in enumerate(data_loader):
+ self.loggers.on_batch_begin(batch)
+ if batch % self.grad_acc_steps == 0:
+ self.optimizer.zero_grad()
+
+ data = [i.to(self.device, non_blocking=True) for i in data]
+ batch_size = data[0].shape[0]
+ with torch.no_grad():
+ feats = []
+ for i in data:
+ feats.append(self.feat_extractor(i))
+
+ with self.amp_autocast():
+ output = self.model(feats)
+ output_teacher = self.model_teacher(feats[:2]) # the teacher only sees the 2 global crops (this number is currently fixed)
+ loss = self.loss(output, output_teacher, self.cur_epoch)/self.grad_acc_steps
+
+ if not math.isfinite(loss.item()):
+ logging.warning('Loss is {}, stopping training'.format(loss.item()))
+ sys.exit(1)
+
+ if self.use_amp:
+ self.grad_scaler.scale(loss).backward()
+ else:
+ loss.backward()
+
+ freeze_last_layer=1
+ cancel_gradients_last_layer(self.cur_epoch, self.model, freeze_last_layer)
+
+ if (batch + 1) % self.grad_acc_steps == 0:
+ if self.lr_scheduler is not None and not self.in_swa:
+ self.lr_scheduler.on_opt_step()
+ # update learning rate and weight decay rate
+ it = len(data_loader) * self.cur_epoch + int(batch/self.grad_acc_steps) # it: global batch index, batch: local batch index in the current epoch
+ for i, param_group in enumerate(self.optimizer.param_groups):
+ param_group["lr"] = self.lr_schedule[it]
+ if i == 0: # only the first group is regularized
+ param_group["weight_decay"] = self.wd_schedule[it]
+ self.update_model()
+
+ # EMA update for the teacher
+ with torch.no_grad():
+ m = self.momentum_schedule[it] # momentum parameter
+ if hasattr(self.model,'module'): # train with ddp
+ for param_q, param_k in zip(self.model.module.parameters(), self.model_teacher_without_ddp.parameters()):
+ param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
+ else: # train with a single gpu (w/o ddp), which I (JJ) used in debugging
+ for param_q, param_k in zip(self.model.parameters(), self.model_teacher_without_ddp.parameters()):
+ param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
+ # (JJ: TODO - for now, the per-iteration and per-epoch logging that follows this point in the DINO repo is skipped)
+
+ batch_metrics["loss"] = loss.item() * self.grad_acc_steps
+ for k, metric in self.metrics.items():
+ batch_metrics[k] = metric(output)
+
+ metric_acc.update(batch_metrics, batch_size)
+ logs = metric_acc.metrics
+ logs["lr"] = self._get_lr() # (JJ: TODO - this may need to change later (NOT now) if lrs are applied differerently over parameter groups)
+ self.loggers.on_batch_end(logs=logs, batch_size=batch_size)
+
+ logs = metric_acc.metrics
+ logs = ODict(("train_" + k, v) for k, v in logs.items())
+ logs["lr"] = self._get_lr()
+ return logs
+
+ def validation_epoch(self, data_loader, swa_update_bn=False):
+ """Validation epoch loop
+
+ Args:
+ data_loader: PyTorch data loader returning input/output pairs.
+ swa_update_bn: whether or not to update batch-norm layers in SWA.
+ """
+ metric_acc = MetricAcc(device=self.device)
+ batch_metrics = ODict()
+ self.feat_extractor.eval()
+ with torch.no_grad():
+ if swa_update_bn:
+ log_tag = "train_"
+ self.model.train()
+ else:
+ log_tag = "val_"
+ self.model.eval()
+
+ for batch, (data, target) in enumerate(data_loader):
+ data, target = data.to(self.device), target.to(self.device)
+ batch_size = data.shape[0]
+
+ feats = self.feat_extractor(data)
+ with self.amp_autocast():
+ output = self.model(feats)
+ loss = self.loss(output, target)
+
+ batch_metrics["loss"] = loss.mean().item()
+ for k, metric in self.metrics.items():
+ batch_metrics[k] = metric(output, target)
+
+ metric_acc.update(batch_metrics, batch_size)
+
+ logs = metric_acc.metrics
+ logs = ODict((log_tag + k, v) for k, v in logs.items())
+ return logs
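The teacher update inside `train_epoch` above is a plain exponential moving average of the student weights. A stand-alone sketch of that update rule with toy linear layers (the momentum value is illustrative; in the trainer it comes from `momentum_schedule[it]` and approaches 1.0 over training):

```python
import torch
import torch.nn as nn

student = nn.Linear(8, 8)
teacher = nn.Linear(8, 8)
teacher.load_state_dict(student.state_dict())

m = 0.996  # illustrative; in the trainer m = momentum_schedule[it]
with torch.no_grad():
    for param_q, param_k in zip(student.parameters(), teacher.parameters()):
        # teacher <- m * teacher + (1 - m) * student
        param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
```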
diff --git a/hyperion/torch/utils/__init__.py b/hyperion/torch/utils/__init__.py
index 3a4692dc..79691bf8 100644
--- a/hyperion/torch/utils/__init__.py
+++ b/hyperion/torch/utils/__init__.py
@@ -11,3 +11,4 @@
from .vad_utils import remove_silence
from .data_parallel import TorchDataParallel
from .ddp import TorchDDP, FairShardedDDP, FairFullyShardedDDP
+from .dinossl import MultiCropWrapper, DINOHead, has_batchnorms, cancel_gradients_last_layer, add_dinossl_args, filter_args, get_params_groups, trunc_normal_, _no_grad_trunc_normal_
diff --git a/hyperion/torch/utils/dinossl.py b/hyperion/torch/utils/dinossl.py
new file mode 100644
index 00000000..c2b32aa3
--- /dev/null
+++ b/hyperion/torch/utils/dinossl.py
@@ -0,0 +1,581 @@
+"""
+ JJ: copied and edited from utils.py in the Facebook DINO repo (some parts come from vision_transformer.py and main_dino.py in the same repo)
+ Copyright 2021 Johns Hopkins University (Author: Jaejin Cho)
+ Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+"""
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Misc functions.
+
+Mostly copy-paste from torchvision references or other public repos like DETR:
+https://github.com/facebookresearch/detr/blob/master/util/misc.py
+"""
+import copy
+import logging
+import os
+import sys
+import time
+import math
+import random
+import datetime
+import subprocess
+import warnings # needed by _no_grad_trunc_normal_ below
+from collections import defaultdict, deque
+
+
+import numpy as np
+import torch
+from torch import nn
+import torch.distributed as dist
+from PIL import ImageFilter, ImageOps
+
+def add_dinossl_args(parser):
+ parser.add_argument('--dinossl', default=False, type=bool, # type=bool is handled correctly by jsonargparse; see the "Boolean arguments" section in https://jsonargparse.readthedocs.io/en/stable/
+ help='whether to run DINO self-supervised training')
+ parser.add_argument('--dinossl_nlayers', default=3, type=int, help="""number of layers in MLP
+ except the last optional weight normalized layer""")
+ parser.add_argument('--dinossl_out_dim', default=65536, type=int, help="""Dimensionality of
+ the DINO head output. For complex and large datasets large values (like 65k) work well.""")
+ parser.add_argument('--dinossl_use_bn_in_head', default=False, type=bool,
+ help="Whether to use batch normalizations in projection head (Default: False)")
+ parser.add_argument('--dinossl_norm_last_layer', default=True, type=bool,
+ help="""Whether or not to weight normalize the last layer of the DINO head.
+ Not normalizing leads to better performance but can make the training unstable.
+ In our experiments, we typically set this parameter to False with deit_small and True with vit_base.""")
+ parser.add_argument('--dinossl_local_crops_number', type=int, default=2, help="""Number of small
+ local views to generate. Set this parameter to 0 to disable multi-crop training.
+ When disabling multi-crop we recommend to use "--global_crops_scale 0.14 1." """)
+ # Teacher temperature parameters
+ parser.add_argument('--dinossl_warmup_teacher_temp', default=0.04, type=float,
+ help="""Initial value for the teacher temperature: 0.04 works well in most cases.
+ Try decreasing it if the training loss does not decrease.""")
+ parser.add_argument('--dinossl_teacher_temp', default=0.04, type=float, help="""Final value (after linear warmup)
+ of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend
+ starting with the default value of 0.04 and increase this slightly if needed.""")
+ parser.add_argument('--dinossl_warmup_teacher_temp_epochs', default=0, type=int,
+ help='Number of warmup epochs for the teacher temperature (Default: 0).')
+ # Audio chunking related
+ parser.add_argument('--dinossl_chunk_len_mult', type=int, default=2,
+ help=('value to multiply chunk_length with for long chunks'))
+ parser.add_argument('--dinossl_reduce_overlap_prob', type=float, default=0,
+ help=('probability of applying a function to reduce an overlap between two long chunks'))
+ # epochs
+ parser.add_argument('--epochs', type=int, default=70,
+ help=('training epochs'))
+
+def filter_args(**kwargs):
+ valid_args = ('dinossl', 'dinossl_nlayers', 'dinossl_out_dim', 'dinossl_use_bn_in_head',
+ 'dinossl_norm_last_layer','dinossl_local_crops_number',
+ 'dinossl_warmup_teacher_temp','dinossl_teacher_temp', 'dinossl_warmup_teacher_temp_epochs',
+ 'dinossl_chunk_len_mult', 'dinossl_reduce_overlap_prob', 'epochs')
+ args = dict((k, kwargs[k])
+ for k in valid_args if k in kwargs)
+
+ return args
+
+def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size):
+ if os.path.isfile(pretrained_weights):
+ state_dict = torch.load(pretrained_weights, map_location="cpu")
+ if checkpoint_key is not None and checkpoint_key in state_dict:
+ print(f"Take key {checkpoint_key} in provided checkpoint dict")
+ state_dict = state_dict[checkpoint_key]
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+ msg = model.load_state_dict(state_dict, strict=False)
+ print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
+ else:
+ print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
+ url = None
+ if model_name == "deit_small" and patch_size == 16:
+ url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
+ elif model_name == "deit_small" and patch_size == 8:
+ url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
+ elif model_name == "vit_base" and patch_size == 16:
+ url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
+ elif model_name == "vit_base" and patch_size == 8:
+ url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
+ if url is not None:
+ print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
+ state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
+ model.load_state_dict(state_dict, strict=True)
+ else:
+ print("There is no reference weights available for this model => We use random weights.")
+
+
+def clip_gradients(model, clip):
+ norms = []
+ for name, p in model.named_parameters():
+ if p.grad is not None:
+ param_norm = p.grad.data.norm(2)
+ norms.append(param_norm.item())
+ clip_coef = clip / (param_norm + 1e-6)
+ if clip_coef < 1:
+ p.grad.data.mul_(clip_coef)
+ return norms
+
+
+def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
+ if epoch >= freeze_last_layer:
+ return
+ for n, p in model.named_parameters():
+ if "last_layer" in n:
+ p.grad = None
+
+
+def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
+ """
+ Re-start from checkpoint
+ """
+ if not os.path.isfile(ckp_path):
+ return
+ print("Found checkpoint at {}".format(ckp_path))
+
+ # open checkpoint file
+ checkpoint = torch.load(ckp_path, map_location="cpu")
+
+ # key is what to look for in the checkpoint file
+ # value is the object to load
+ # example: {'state_dict': model}
+ for key, value in kwargs.items():
+ if key in checkpoint and value is not None:
+ try:
+ msg = value.load_state_dict(checkpoint[key], strict=False)
+ print("=> loaded {} from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
+ except TypeError:
+ try:
+ msg = value.load_state_dict(checkpoint[key])
+ print("=> loaded {} from checkpoint '{}'".format(key, ckp_path))
+ except ValueError:
+ print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path))
+ else:
+ print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path))
+
+ # re load variable important for the run
+ if run_variables is not None:
+ for var_name in run_variables:
+ if var_name in checkpoint:
+ run_variables[var_name] = checkpoint[var_name]
+
+
+def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
+ warmup_schedule = np.array([])
+ warmup_iters = warmup_epochs * niter_per_ep
+ if warmup_epochs > 0:
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
+
+ iters = np.arange(epochs * niter_per_ep - warmup_iters)
+ schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
+
+ schedule = np.concatenate((warmup_schedule, schedule))
+ assert len(schedule) == epochs * niter_per_ep
+ return schedule
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ import builtins as __builtin__
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def init_distributed_mode(args):
+ # launched with torch.distributed.launch
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ # launched with submitit on a slurm cluster
+ elif 'SLURM_PROCID' in os.environ:
+ args.rank = int(os.environ['SLURM_PROCID'])
+ args.gpu = args.rank % torch.cuda.device_count()
+ # launched naively with `python main_dino.py`
+ # we manually add MASTER_ADDR and MASTER_PORT to env variables
+ elif torch.cuda.is_available():
+ print('Will run the code on one GPU.')
+ args.rank, args.gpu, args.world_size = 0, 0, 1
+ os.environ['MASTER_ADDR'] = '127.0.0.1'
+ os.environ['MASTER_PORT'] = '29500'
+ else:
+ print('Does not support training without GPU.')
+ sys.exit(1)
+
+ dist.init_process_group(
+ backend="nccl",
+ init_method=args.dist_url,
+ world_size=args.world_size,
+ rank=args.rank,
+ )
+
+ torch.cuda.set_device(args.gpu)
+ print('| distributed init (rank {}): {}'.format(
+ args.rank, args.dist_url), flush=True)
+ dist.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
+def accuracy(output, target, topk=(1,)):
+ """Computes the accuracy over the k top predictions for the specified values of k"""
+ maxk = max(topk)
+ batch_size = target.size(0)
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
+ return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ # type: (Tensor, float, float, float, float) -> Tensor
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+class LARS(torch.optim.Optimizer):
+ """
+ Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
+ """
+ def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
+ weight_decay_filter=None, lars_adaptation_filter=None):
+ defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
+ eta=eta, weight_decay_filter=weight_decay_filter,
+ lars_adaptation_filter=lars_adaptation_filter)
+ super().__init__(params, defaults)
+
+ @torch.no_grad()
+ def step(self):
+ for g in self.param_groups:
+ for p in g['params']:
+ dp = p.grad
+
+ if dp is None:
+ continue
+
+ if p.ndim != 1:
+ dp = dp.add(p, alpha=g['weight_decay'])
+
+ if p.ndim != 1:
+ param_norm = torch.norm(p)
+ update_norm = torch.norm(dp)
+ one = torch.ones_like(param_norm)
+ q = torch.where(param_norm > 0.,
+ torch.where(update_norm > 0,
+ (g['eta'] * param_norm / update_norm), one), one)
+ dp = dp.mul(q)
+
+ param_state = self.state[p]
+ if 'mu' not in param_state:
+ param_state['mu'] = torch.zeros_like(p)
+ mu = param_state['mu']
+ mu.mul_(g['momentum']).add_(dp)
+
+ p.add_(mu, alpha=-g['lr'])
+
+
+class MultiCropWrapper(nn.Module):
+ """
+ Perform forward pass separately on each resolution input.
+ The inputs corresponding to a single resolution are clubbed and single
+ forward is run on the same resolution inputs. Hence we do several
+ forward passes = number of different resolutions used. We then
+ concatenate all the output features and run the head forward on these
+ concatenated features.
+ (JJ: EXP - Q: In spk-id we usually zero-pad to make sequence lengths equal; can we skip that, or at least zero-pad the local and global view groups separately? A: With adaptive pooling the input sequence length can vary, so this seems fine.)
+ """
+ def __init__(self, backbone, head):
+ super(MultiCropWrapper, self).__init__()
+ # disable layers dedicated to ImageNet labels classification
+ #backbone.fc, backbone.head = nn.Identity(), nn.Identity()
+ self.backbone = backbone
+ self.head = head
+ self._train_mode = "full"
+
+ def forward(self, x):
+ # convert to list
+ if not isinstance(x, list):
+ x = [x]
+ idx_crops = torch.cumsum(torch.unique_consecutive(
+ torch.tensor([inp.shape[-1] for inp in x]),
+ return_counts=True,
+ )[1], 0)
+ start_idx = 0
+ for end_idx in idx_crops:
+ _out = self.backbone(torch.cat(x[start_idx: end_idx]))
+ if start_idx == 0:
+ output = _out
+ else:
+ output = torch.cat((output, _out))
+ start_idx = end_idx
+ # Run the head forward on the concatenated features.
+ return self.head(output)
+
+ def extract_embed(self, x, chunk_length=0, embed_layer=None, detach_chunks=False):
+ # - This method is NOT used when extracting xvectors from f (right before the DINOHead)
+ # - self.dinossl_xvec_loc is registered (in TorchModelLoader) for xvector extraction NOT during training.
+ y = self.backbone.extract_embed(x, chunk_length=chunk_length,
+ embed_layer=embed_layer)
+ if self.dinossl_xvec_loc == "dinohead_mlp":
+ y = self.head.mlp(y)
+ elif self.dinossl_xvec_loc == "dinohead_l2norm":
+ y = self.head.mlp(y)
+ y = nn.functional.normalize(y, dim=-1, p=2)
+ elif self.dinossl_xvec_loc == "dinohead_linear":
+ y = self.head(y)
+ else:
+ raise NotImplementedError(f"invalid dinossl_xvec_loc={self.dinossl_xvec_loc}")
+ return y
+
+ def freeze(self):
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def unfreeze(self):
+ for param in self.parameters():
+ param.requires_grad = True
+
+ def set_train_mode(self, mode):
+ if mode == self._train_mode:
+ return
+
+ if mode == "full": # different from xvec train, in dinossl, only this is used (thus implemented)
+ self.unfreeze()
+ elif mode == "frozen":
+ raise NotImplementedError
+ # self.freeze()
+ elif mode == "ft-embed-affine":
+ raise NotImplementedError
+ # self.unfreeze()
+ # self.freeze_preembed_layers()
+ else:
+ raise ValueError(f"invalid train_mode={mode}")
+
+ self._train_mode = mode
+
+def get_params_groups(model):
+ regularized = []
+ not_regularized = []
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ continue
+ # we do not regularize biases nor Norm parameters
+ if name.endswith(".bias") or len(param.shape) == 1:
+ not_regularized.append(param)
+ else:
+ regularized.append(param)
+ return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]
+
+
+def has_batchnorms(model):
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
+ for name, module in model.named_modules():
+ if isinstance(module, bn_types):
+ return True
+ return False
+
+# Parts copied and edited from vision_transformer.py in facebook dino repo - start
+class DINOHead(nn.Module):
+ def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
+ super().__init__()
+ nlayers = max(nlayers, 1)
+ if nlayers == 1:
+ self.mlp = nn.Linear(in_dim, bottleneck_dim)
+ else:
+ layers = [nn.Linear(in_dim, hidden_dim)]
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ for _ in range(nlayers - 2):
+ layers.append(nn.Linear(hidden_dim, hidden_dim))
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim))
+ self.mlp = nn.Sequential(*layers)
+ self.apply(self._init_weights)
+ self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+ self.last_layer.weight_g.data.fill_(1)
+ if norm_last_layer:
+ self.last_layer.weight_g.requires_grad = False
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.mlp(x)
+ x = nn.functional.normalize(x, dim=-1, p=2)
+ x = self.last_layer(x)
+ return x
+# Parts copied and edited from vision_transformer.py in facebook dino repo - end
+
+# Parts copied and edited from main_dino.py in facebook dino repo - start
+class DINOLoss(nn.Module):
+ def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
+ warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
+ center_momentum=0.9):
+ super().__init__()
+ self.student_temp = student_temp
+ self.center_momentum = center_momentum
+ self.ncrops = ncrops
+ self.register_buffer("center", torch.zeros(1, out_dim))
+ # we apply a warm up for the teacher temperature because
+ # too high a temperature makes the training unstable at the beginning
+ self.teacher_temp_schedule = np.concatenate((
+ np.linspace(warmup_teacher_temp,
+ teacher_temp, warmup_teacher_temp_epochs),
+ np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
+ ))
+
+ def forward(self, student_output, teacher_output, epoch):
+ """
+ Cross-entropy between softmax outputs of the teacher and student networks.
+ """
+ student_out = student_output / self.student_temp
+ student_out = student_out.chunk(self.ncrops)
+
+ # teacher centering and sharpening
+ temp = self.teacher_temp_schedule[epoch]
+ teacher_out = nn.functional.softmax((teacher_output - self.center) / temp, dim=-1)
+ teacher_out = teacher_out.detach().chunk(2)
+
+ total_loss = 0
+ n_loss_terms = 0
+ for iq, q in enumerate(teacher_out):
+ for v in range(len(student_out)):
+ if v == iq:
+ # we skip cases where student and teacher operate on the same view
+ continue
+ loss = torch.sum(-q * nn.functional.log_softmax(student_out[v], dim=-1), dim=-1)
+ total_loss += loss.mean()
+ n_loss_terms += 1
+ total_loss /= n_loss_terms
+ self.update_center(teacher_output)
+ return total_loss
+
+ @torch.no_grad()
+ def update_center(self, teacher_output):
+ """
+ Update center used for teacher output.
+ """
+ batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
+ world_size = get_world_size()
+ if world_size == 1: # 1-gpu w/o ddp
+ batch_center = batch_center / len(teacher_output)
+ else: # ddp
+ dist.all_reduce(batch_center)
+ batch_center = batch_center / (len(teacher_output) * world_size)
+
+ # ema update
+ self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
+# Parts copied and edited from main_dino.py in facebook dino repo - end
+
+def init_dino(model, dinossl_args, rank = 0):
+ model.conf = model.get_config() # The classif_net part will NOT be correct in the dinossl case, but keep this conf around for compatibility with the current code-base
+ if rank == 0:
+ logging.info('dinossl args={}'.format(dinossl_args))
+
+ embed_dim = model.conf['embed_dim']
+ model.classif_net.output = nn.Identity()
+
+ # model
+ model_teacher = copy.deepcopy(model)
+ model = MultiCropWrapper(model, DINOHead(embed_dim, dinossl_args['dinossl_out_dim'], use_bn=dinossl_args['dinossl_use_bn_in_head'],
+ norm_last_layer=dinossl_args['dinossl_norm_last_layer'], nlayers=dinossl_args['dinossl_nlayers'])) # multi-crop wrapper handles forward with inputs of different chunk lengths
+ model_teacher = MultiCropWrapper(model_teacher, DINOHead(embed_dim, dinossl_args['dinossl_out_dim'], use_bn=dinossl_args['dinossl_use_bn_in_head'],
+ nlayers=dinossl_args['dinossl_nlayers']))
+ # teacher and student start with the same weights. "requires_grad = False" happens in torch_trainer_dinossl.py
+ model_teacher.load_state_dict(model.state_dict())
+ model = [model, model_teacher]
+ # loss
+ loss = DINOLoss(dinossl_args['dinossl_out_dim'],
+ dinossl_args['dinossl_local_crops_number'] + 2, # total number of crops = 2 global crops + local_crops_number
+ dinossl_args['dinossl_warmup_teacher_temp'],
+ dinossl_args['dinossl_teacher_temp'],
+ dinossl_args['dinossl_warmup_teacher_temp_epochs'],
+ dinossl_args['epochs'],) # to(device) will happen in torch_trainer_dinossl.py
+
+ if rank == 0:
+ logging.info('dinossl-model={}'.format(model))
+ return model, loss
\ No newline at end of file
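To make the crop bookkeeping in `DINOLoss.forward` concrete, here is a self-contained re-derivation with random tensors and toy sizes. The student forwards all crops while the teacher forwards only the two global crops, and pairs where both see the same view are skipped. The temperatures 0.1 and 0.04 match the defaults above; the tensor sizes are arbitrary:

```python
import torch
import torch.nn.functional as F

out_dim, ncrops, batch = 16, 4, 8                       # toy sizes: 2 global + 2 local crops
student_output = torch.randn(ncrops * batch, out_dim)   # student forwards all crops
teacher_output = torch.randn(2 * batch, out_dim)        # teacher forwards only the 2 global crops
center = torch.zeros(1, out_dim)

student_out = (student_output / 0.1).chunk(ncrops)                                    # student temperature
teacher_out = F.softmax((teacher_output - center) / 0.04, dim=-1).detach().chunk(2)   # centering + sharpening

total_loss, n_terms = 0.0, 0
for iq, q in enumerate(teacher_out):
    for v in range(len(student_out)):
        if v == iq:
            continue  # skip pairs where student and teacher see the same view
        total_loss += torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1).mean()
        n_terms += 1
total_loss /= n_terms
print(total_loss)
```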