forked from YUCHEN005/DPSL-ASR
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathst_trans.py
183 lines (165 loc) · 5.72 KB
/
st_trans.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
#!/usr/bin/env python3
# encoding: utf-8
# Copyright 2019 Kyoto University (Hirofumi Inaguma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""End-to-end speech translation model decoding script."""
import logging
import os
import random
import sys
import configargparse
import numpy as np
# NOTE: you need this func to generate our sphinx doc
def get_parser():
    """Build the command-line parser for speech-translation decoding.

    Options can be given on the command line or through up to three YAML
    config files (``--config``, ``--config2``, ``--config3``).

    Returns:
        configargparse.ArgumentParser: Parser with the general, task, model,
            search and multilingual options registered.

    """
    parser = configargparse.ArgumentParser(
        description="Translate text from speech using a speech translation "
        "model on one CPU or GPU",
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
    )

    # general configuration
    parser.add("--config", is_config_file=True, help="Config file path")
    parser.add(
        "--config2",
        is_config_file=True,
        help="Second config file path that overwrites the settings in `--config`",
    )
    parser.add(
        "--config3",
        is_config_file=True,
        help="Third config file path that overwrites "
        "the settings in `--config` and `--config2`",
    )
    parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs")
    parser.add_argument(
        "--dtype",
        choices=("float16", "float32", "float64"),
        default="float32",
        help="Float precision (only available in --api v2)",
    )
    # NOTE(review): the default stays "chainer" for CLI compatibility, but
    # main() in this file accepts only "pytorch", so callers must pass
    # `--backend pytorch` explicitly.
    parser.add_argument(
        "--backend",
        type=str,
        default="chainer",
        choices=["chainer", "pytorch"],
        help="Backend library",
    )
    parser.add_argument("--debugmode", type=int, default=1, help="Debugmode")
    parser.add_argument("--seed", type=int, default=1, help="Random seed")
    parser.add_argument("--verbose", "-V", type=int, default=1, help="Verbose option")
    parser.add_argument(
        "--batchsize",
        type=int,
        default=1,
        help="Batch size for beam search (0: means no batch processing)",
    )
    parser.add_argument(
        "--preprocess-conf",
        type=str,
        default=None,
        help="The configuration file for the pre-processing",
    )
    parser.add_argument(
        "--api",
        default="v1",
        choices=["v1", "v2"],
        help="Beam search APIs "
        "v1: Default API. "
        "It only supports the ASRInterface.recognize method and DefaultRNNLM. "
        "v2: Experimental API. "
        "It supports any models that implements ScorerInterface.",
    )

    # task related
    parser.add_argument(
        "--trans-json", type=str, help="Filename of translation data (json)"
    )
    parser.add_argument(
        "--result-label",
        type=str,
        required=True,
        help="Filename of result label data (json)",
    )

    # model (parameter) related
    parser.add_argument(
        "--model", type=str, required=True, help="Model file parameters to read"
    )

    # search related
    parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
    parser.add_argument("--beam-size", type=int, default=1, help="Beam size")
    parser.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
    # Implicit string concatenation (instead of a triple-quoted literal) keeps
    # source indentation and newlines out of the help text.
    parser.add_argument(
        "--maxlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain max output length. "
        "If maxlenratio=0.0 (default), it uses an end-detect function "
        "to automatically find maximum hypothesis lengths",
    )
    parser.add_argument(
        "--minlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain min output length",
    )

    # multilingual related
    # NOTE(review): the default is the boolean False although type=str; it is
    # kept as-is because downstream code presumably tests this option for
    # truthiness -- confirm before changing it to None.
    parser.add_argument(
        "--tgt-lang",
        default=False,
        type=str,
        help="target language ID (e.g., <en>, <de>, and <fr> etc.)",
    )
    return parser
def _configure_logging(verbose):
    """Set up the root logger: 1 -> INFO, 2 -> DEBUG, anything else -> WARNING."""
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    levels = {1: logging.INFO, 2: logging.DEBUG}
    logging.basicConfig(
        level=levels.get(verbose, logging.WARNING), format=log_format
    )
    if verbose not in levels:
        logging.warning("Skip DEBUG/INFO messages")


def _check_gpu_request(ngpu):
    """Exit with status 1 when the requested GPU count cannot be honoured."""
    if ngpu <= 0:
        return

    cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
    if cvd is None:
        logging.warning("CUDA_VISIBLE_DEVICES is not set.")
    else:
        # Ignore empty entries so that an empty variable (or a trailing
        # comma) is not counted as a visible device.
        visible = [dev for dev in cvd.split(",") if dev.strip()]
        if ngpu != len(visible):
            logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
            sys.exit(1)

    # TODO(mn5k): support of multiple GPUs
    if ngpu > 1:
        logging.error("The program only supports ngpu=1.")
        sys.exit(1)


def main(args):
    """Run the main decoding function.

    Args:
        args (list of str): Command-line arguments without the program name
            (e.g. ``sys.argv[1:]``).

    Raises:
        NotImplementedError: If ``--dtype`` is anything other than float32.
        ValueError: If ``--backend`` is not "pytorch".
        SystemExit: If the GPU request is inconsistent with
            ``CUDA_VISIBLE_DEVICES`` or more than one GPU is requested
            (also raised by the parser on invalid arguments).

    """
    parser = get_parser()
    args = parser.parse_args(args)

    # logging info
    _configure_logging(args.verbose)

    # check CUDA_VISIBLE_DEVICES
    _check_gpu_request(args.ngpu)

    # display PYTHONPATH
    logging.info("python path = %s", os.environ.get("PYTHONPATH", "(None)"))

    # seed setting
    random.seed(args.seed)
    np.random.seed(args.seed)
    logging.info("set random seed = %d", args.seed)

    # trans
    logging.info("backend = %s", args.backend)
    if args.backend != "pytorch":
        raise ValueError("Only the pytorch backend is supported.")

    # Validate cheap options before paying for the heavy backend import.
    if args.dtype != "float32":
        raise NotImplementedError(
            f"`--dtype {args.dtype}` is only available with `--api v2`"
        )

    # Imported lazily so that parsing arguments (e.g. for the docs) does not
    # require the backend to be importable.
    from espnet.st.pytorch_backend.st import trans

    trans(args)
# Script entry point: forward the command-line arguments (without the
# program name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])