-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdenoiser.py
75 lines (53 loc) · 2.11 KB
/
denoiser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import numpy as np
from librosa import stft, istft
from resampy.core import resample
stft_hop_length = 420
win_length = n_fft = 4 * stft_hop_length
def _stft(x):
s = stft(x, window='hann', win_length=win_length, n_fft=n_fft, hop_length=stft_hop_length,
center=True, pad_mode='reflect')
s = s[..., :-1]
mag = np.abs(s)
phi = np.angle(s)
cos = np.cos(phi)
sin = np.sin(phi)
return mag, cos, sin
def _istft(mag: np.array, cos: np.array, sin: np.array):
real = mag * cos
imag = mag * sin
s = real + imag * 1.0j
s = np.pad(s, ((0, 0), (0, 0), (0, 1)), mode='edge')
x = istft(s, window='hann', win_length=win_length, hop_length=stft_hop_length, n_fft=n_fft)
return x
def model(onnx_session, wav: np.array) -> np.array:
padded_wav = np.pad(wav, ((0,0), (0, 441)))
mag, cos, sin = _stft(padded_wav) # (b nfft/2 t)
ort_inputs = {
"mag": mag,
"cos": cos,
"sin": sin,
}
sep_mag, sep_cos, sep_sin = onnx_session.run(None, ort_inputs)
o = _istft(sep_mag, sep_cos, sep_sin)
o = o[:wav.shape[-1]]
return o
def run(onnx_session, wav: np.array, sample_rate: int, batch_process_chunks = False) -> np.array:
assert wav.ndim == 1, 'Input should be 1D (mono) wav'
if sample_rate != 44_100:
wav = resample(wav, sample_rate, 44_100, filter='kaiser_best', parallel=True)
chunk_length = int(sample_rate * 30)
#overlap_length = int(sr * overlap_seconds)
hop_length = chunk_length # - overlap_length
num_chunks = 1 + (wav.shape[-1] - 1) // hop_length
n_pad = (num_chunks - wav.shape[-1] % num_chunks) % num_chunks
wav = np.pad(wav, (0, n_pad))
chunks = np.reshape(wav, (num_chunks, -1))
abs_max = np.clip(np.max(np.abs(chunks), axis = -1, keepdims = True), a_min=1e-7, a_max=None)
chunks /= abs_max
if batch_process_chunks:
res_chunks = model(onnx_session, chunks)
else:
res_chunks = np.array([model(onnx_session, c[None]) for c in chunks]).squeeze(axis=1)
res_chunks *= abs_max
res = np.reshape(res_chunks, (-1))
return res[:wav.shape[-1]], 44_100