-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkeyword_spotter.py
68 lines (47 loc) · 1.86 KB
/
keyword_spotter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Sebastian Thomas (datascience at sebastianthomas dot de)
from __future__ import annotations
# type hints
from typing import ClassVar, Union, Optional, NoReturn, BinaryIO
# operating system
from os import PathLike
# machine learning
import tensorflow as tf
from tensorflow.keras.models import load_model
# custom modules
from common.constants import CATEGORIES, CLASSIFIER_PATH
from common.preprocessing import to_features
# Public API of this module: the shared spotter instance and its type.
__all__ = [
    'KeywordSpotter',
    'KeywordSpotterType',
]
class KeywordSpotterType:
    """Type of the KeywordSpotter singleton.

    ``KeywordSpotterType()`` always returns the same shared instance. Python
    still runs ``__init__`` on every such call, so each call also reloads the
    classifier from CLASSIFIER_PATH.
    """

    # The one shared instance, created lazily by __new__.
    _instance: ClassVar[Optional[KeywordSpotterType]] = None
    # Class-level fallback so ``self._classifier`` is always readable (e.g. on
    # an instance obtained via __new__ without __init__). None means "not
    # loaded", which makes predict() retry the load.
    _classifier = None

    def __new__(cls) -> KeywordSpotterType:
        """Returns the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        """Initializes the instance by (re)loading the classifier."""
        self.update()

    def update(self) -> None:
        """Updates the instance by reloading the classifier field.

        If ``load_model`` raises OSError, the classifier field is set to None
        instead of propagating the error; predict() will then retry the load.
        """
        try:
            self._classifier = load_model(CLASSIFIER_PATH)
        except OSError:
            # NOTE(review): some Keras versions may report a missing model with
            # a different exception type (e.g. ValueError) — confirm for the
            # installed version.
            self._classifier = None

    def predict(self, file: Union[str, PathLike, BinaryIO]) -> str:
        """Predicts the keyword spoken in ``file``.

        If no classifier is currently loaded, loading from CLASSIFIER_PATH is
        attempted first.

        Args:
            file: Path or binary file object passed on to ``to_features``.

        Returns:
            The entry of CATEGORIES at the index of the highest model output.

        Raises:
            FileNotFoundError: If no classifier could be loaded from
                CLASSIFIER_PATH (``load_model`` raised OSError).
        """
        if self._classifier is None:
            try:
                self._classifier = load_model(CLASSIFIER_PATH)
            except OSError as err:
                raise FileNotFoundError('No classifier could be found.') from err
        # Index 2 of the model's input shape is used as the MFCC count;
        # presumably the input layout is (batch, frames, n_mfcc) — TODO confirm.
        n_mfcc = self._classifier.input_shape[2]
        features = to_features(file, n_mfcc=n_mfcc)
        # Prepend an axis of size 1 (a batch of one example), then take the
        # model output for that single example.
        logits = self._classifier(tf.expand_dims(features, 0))[0]
        return CATEGORIES[int(tf.argmax(logits))]
# Shared spotter instance created at import time; constructing it also
# attempts to load the classifier from CLASSIFIER_PATH.
KeywordSpotter: KeywordSpotterType = KeywordSpotterType()