-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtest.py
75 lines (61 loc) · 2.45 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
import torchaudio
import torchvision
import librosa
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
from config import set_params
from kws.model import treasure_net
from kws.utils.transforms import SpectogramNormalize
from kws.utils.utils import exp_moving_average
def test() -> None:
    """Run the keyword-spotting model over one example audio file and plot the result.

    The function:
      1. reads the configuration via ``set_params()`` and selects a device,
      2. builds ``treasure_net`` and, if ``params['load_model']`` is set,
         restores weights from ``params['model_checkpoint']``,
      3. loads ``params['example_audio']``, keeps its first channel and
         resamples it to ``params['sample_rate']`` when necessary,
      4. slides a window of ``params['time_steps']`` spectrogram frames over
         the audio, feeding the model one window at a time while carrying
         the returned hidden state into the next call,
      5. smooths the per-keyword probabilities with ``exp_moving_average``
         and saves the plot to ``params['example_fig']``.

    Raises:
        ValueError: if the audio yields fewer spectrogram frames than
            ``params['time_steps']``, so no full window can be formed.
    """
    # set parameters
    params = set_params()
    params['device'] = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if params['verbose']:
        print('Using device', params['device'])

    # initialize model
    model = treasure_net(params).to(params['device'])
    if params['load_model']:
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only machine (and vice versa).
        checkpoint = torch.load(params['model_checkpoint'],
                                map_location=params['device'])
        model.load_state_dict(checkpoint['model_state_dict'])
    # Inference only: switch layers such as dropout / batch-norm (if the
    # model has any) to evaluation behaviour.
    model.eval()

    # prepare test audio
    waveform, sample_rate = torchaudio.load(params['example_audio'])
    waveform = waveform[:1]  # keep only the first channel
    if sample_rate != params['sample_rate']:
        mono = waveform.squeeze(0).numpy()
        # Keyword arguments are required by librosa >= 0.10 and are also
        # accepted by older releases.
        mono = librosa.resample(mono,
                                orig_sr=sample_rate,
                                target_sr=params['sample_rate'])
        waveform = torch.from_numpy(mono).unsqueeze(0)
    waveform = waveform.to(params['device'])

    spectrogramer = torchvision.transforms.Compose([
        torchaudio.transforms.MelSpectrogram(
            sample_rate=params['sample_rate'],
            n_mels=params['num_mels'],
        ).to(params['device']),
        SpectogramNormalize(),
    ])

    # calculate keyword probs
    # After the transpose, axis 1 is the one that is windowed below
    # (presumably time frames, with mel bins on axis 2 -- this assumes
    # SpectogramNormalize keeps the MelSpectrogram layout; TODO confirm).
    spec = spectrogramer(waveform).transpose(1, 2)
    num_predicts = spec.shape[1] - params['time_steps']
    if num_predicts < 0:
        raise ValueError(
            'Audio is too short: got {} spectrogram frames, but a window of '
            '{} frames is required'.format(spec.shape[1], params['time_steps'])
        )
    keyword_probs = np.zeros((num_predicts, len(params['keywords'])))

    hidden = None
    with torch.no_grad():
        for i in range(num_predicts):
            window = spec[:, i:i + params['time_steps']]
            logits, hidden = model(window, hidden)
            probs = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()
            # Class 0 is dropped; the remaining columns are stored as the
            # per-keyword probabilities (class 0 is presumably the
            # "no keyword" class -- TODO confirm against the model).
            keyword_probs[i] = probs[:, 1:]

    # plot results
    fig = plt.figure(figsize=(12, 5))
    plt.rcParams.update({'font.size': 14})

    # Spread the predictions evenly over the duration of the audio.
    duration = waveform.shape[1] / params['sample_rate']
    seconds_steps = np.linspace(0, duration, num_predicts)
    for i, keyword in enumerate(params['keywords']):
        ema_probs = exp_moving_average(keyword_probs[:, i], alpha=params['ema_alpha'])
        plt.plot(seconds_steps, ema_probs, label=keyword)

    plt.grid()
    plt.legend(title='keyword')
    plt.xlabel('time (s)')
    plt.ylabel('probability')
    plt.savefig(params['example_fig'])
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
# Script entry point: run the example inference only when this file is
# executed directly, not when it is imported as a module.
if __name__ == '__main__':
    test()