Skip to content

Commit

Permalink
Cache the TextFeaturizer instance for infer speed improvement. (#1260)
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamBear authored Jan 4, 2022
1 parent 50752f8 commit 36c9eaa
Showing 1 changed file with 10 additions and 15 deletions.
25 changes: 10 additions & 15 deletions paddlespeech/cli/asr/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,25 +174,25 @@ def _init_from_path(self,
self.config.collator.mean_std_filepath = os.path.join(
res_path, self.config.collator.cmvn_path)
self.collate_fn_test = SpeechCollator.from_config(self.config)
text_feature = TextFeaturizer(
self.text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
self.config.model.input_dim = self.collate_fn_test.feature_size
self.config.model.output_dim = text_feature.vocab_size
self.config.model.output_dim = self.text_feature.vocab_size
elif "conformer" in model_type or "transformer" in model_type or "wenetspeech" in model_type:
self.config.collator.vocab_filepath = os.path.join(
res_path, self.config.collator.vocab_filepath)
self.config.collator.augmentation_config = os.path.join(
res_path, self.config.collator.augmentation_config)
self.config.collator.spm_model_prefix = os.path.join(
res_path, self.config.collator.spm_model_prefix)
text_feature = TextFeaturizer(
self.text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)
self.config.model.input_dim = self.config.collator.feat_dim
self.config.model.output_dim = text_feature.vocab_size
self.config.model.output_dim = self.text_feature.vocab_size

else:
raise Exception("wrong type")
Expand All @@ -211,6 +211,7 @@ def _init_from_path(self,
model_dict = paddle.load(self.ckpt_path)
self.model.set_state_dict(model_dict)


def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
"""
Input preprocess and return paddle.Tensor stored in self.input.
Expand All @@ -228,7 +229,7 @@ def preprocess(self, model_type: str, input: Union[str, os.PathLike]):
audio = paddle.to_tensor(audio, dtype='float32')
audio_len = paddle.to_tensor(audio_len)
audio = paddle.unsqueeze(audio, axis=0)
vocab_list = collate_fn_test.vocab_list
# vocab_list = collate_fn_test.vocab_list
self._inputs["audio"] = audio
self._inputs["audio_len"] = audio_len
logger.info(f"audio feat shape: {audio.shape}")
Expand Down Expand Up @@ -274,10 +275,7 @@ def preprocess(self, model_type: str, input: Union[str, os.PathLike]):

audio_len = paddle.to_tensor(audio.shape[0])
audio = paddle.to_tensor(audio, dtype='float32').unsqueeze(axis=0)
text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)

self._inputs["audio"] = audio
self._inputs["audio_len"] = audio_len
logger.info(f"audio feat shape: {audio.shape}")
Expand All @@ -290,18 +288,15 @@ def infer(self, model_type: str):
"""
Model inference and result stored in self.output.
"""
text_feature = TextFeaturizer(
unit_type=self.config.collator.unit_type,
vocab=self.config.collator.vocab_filepath,
spm_model_prefix=self.config.collator.spm_model_prefix)

cfg = self.config.decoding
audio = self._inputs["audio"]
audio_len = self._inputs["audio_len"]
if "ds2_online" in model_type or "ds2_offline" in model_type:
result_transcripts = self.model.decode(
audio,
audio_len,
text_feature.vocab_list,
self.text_feature.vocab_list,
decoding_method=cfg.decoding_method,
lang_model_path=cfg.lang_model_path,
beam_alpha=cfg.alpha,
Expand All @@ -316,7 +311,7 @@ def infer(self, model_type: str):
result_transcripts = self.model.decode(
audio,
audio_len,
text_feature=text_feature,
text_feature=self.text_feature,
decoding_method=cfg.decoding_method,
beam_size=cfg.beam_size,
ctc_weight=cfg.ctc_weight,
Expand Down

0 comments on commit 36c9eaa

Please sign in to comment.