pretrain.py
import argparse
import multiprocessing
import os
import sys
from os.path import join as join_path

import torch


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--fairseq_path", default=None, type=str, required=True,
                        help="Path to the installed fairseq library")
    parser.add_argument("--audio_path", default=None, type=str, required=True,
                        help="Path to the unlabeled audio")
    parser.add_argument("--init_model", default=None, type=str, required=False,
                        help="Path to an English pretrained wav2vec model")
    parser.add_argument("--batch_size", default=1200000, type=int, required=False,
                        help="Batch size in tokens; decrease this number if any CUDA memory problems occur")
    args = parser.parse_args()

    # Prepare the manifest (train/valid .tsv files) from the raw audio directory.
    MANIFEST_PATH = join_path(args.fairseq_path, 'examples/wav2vec/wav2vec_manifest.py')
    temp_dir = os.path.abspath('./temp')
    os.makedirs(temp_dir, exist_ok=True)
    cmd = ('python3 ' + MANIFEST_PATH + ' ' + args.audio_path
           + ' --dest ' + temp_dir + ' --ext wav --valid-percent 0.05')
    os.system(cmd)

    # Pretrain the model.
    NUM_GPU = torch.cuda.device_count()
    NUM_CPU = multiprocessing.cpu_count()
    if NUM_GPU == 0:
        print("pytorch cannot find any GPUs!")
        sys.exit(1)
    cmd = ["fairseq-hydra-train"]
    cmd.append("task.data=" + temp_dir)
    cmd.append("distributed_training.distributed_world_size=" + str(NUM_GPU))
    # Scale gradient accumulation so the effective batch size stays constant
    # across machines (update_freq * world_size == 64).
    cmd.append("+optimization.update_freq='[" + str(64 // NUM_GPU) + "]'")
    if args.init_model is not None:
        # Warm-start from an existing checkpoint but reset all training state.
        cmd.append("checkpoint.restore_file=" + os.path.abspath(args.init_model))
        cmd.append("checkpoint.reset_optimizer=True")
        cmd.append("checkpoint.reset_lr_scheduler=True")
        cmd.append("checkpoint.reset_dataloader=True")
        cmd.append("checkpoint.reset_meters=True")
    #cmd.append("optimization.max_update=2000000")
    cmd.append("dataset.num_workers=" + str(NUM_CPU))
    cmd.append("dataset.max_tokens=" + str(args.batch_size))
    cmd.append("--config-dir config/pretraining")
    cmd.append("--config-name wav2vec2_base_librispeech")
    cmd = ' '.join(cmd)
    print(cmd)
    os.system(cmd)


if __name__ == "__main__":
    main()
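
A minimal sketch of invoking the script, assuming fairseq is installed at /path/to/fairseq and the unlabeled wavs live in /path/to/audio (both paths are placeholders; --init_model and --batch_size are optional, and wav2vec_small.pt stands in for whatever English wav2vec 2.0 checkpoint you warm-start from):

    python3 pretrain.py --fairseq_path /path/to/fairseq --audio_path /path/to/audio --init_model wav2vec_small.pt --batch_size 1200000

Because --config-dir is passed as the relative path config/pretraining, the script appears to expect a wav2vec2_base_librispeech.yaml config under that directory relative to where you launch it.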