-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathWaveRNNVocoder.cpp
executable file
·80 lines (57 loc) · 2.13 KB
/
WaveRNNVocoder.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#define EIGEN_USE_MKL_ALL
#include <pybind11/pybind11.h>
#include <pybind11/eigen.h>
#include <Eigen/Dense>
#include <Eigen/Core>
#include <stdio.h>
#include <cstdio>
#include <memory>
#include <stdexcept>
#include <string>
#include "net_impl.h"
namespace py = pybind11;

// Eigen matrix type exchanged with Python. `Matrixf` is presumably declared in
// net_impl.h (the only project header included) — confirm there.
using MatrixPy = Matrixf;

// Coefficient type of MatrixPy; the buffer constructor requires this format.
using Scalar = MatrixPy::Scalar;

// True when MatrixPy stores coefficients row by row. Selects which Python
// buffer stride is treated as Eigen's outer stride.
constexpr bool rowMajor = (MatrixPy::Flags & Eigen::RowMajorBit) != 0;
/// Python-facing wrapper around a `Model`.
///
/// Call loadWeights() before melToWav(); melToWav() throws until a load has
/// completed without an exception.
class Vocoder {
    /// Deleter that closes a C `FILE*` owned by std::unique_ptr
    /// (unique_ptr never invokes it with a null pointer).
    struct FileCloser {
        void operator()(std::FILE* f) const noexcept { std::fclose(f); }
    };

    Model model;
    // True only after Model::loadNext() has returned without throwing.
    bool isLoaded = false;
public:
    Vocoder() = default;

    /// @brief Read model weights from a file via Model::loadNext.
    /// @param fileName path of the weight file; opened in binary mode ("rb").
    /// @throws std::runtime_error if the file cannot be opened. Exceptions
    ///         thrown by Model::loadNext propagate unchanged.
    void loadWeights( const std::string& fileName ){
        // Hard-coded MKL thread count (process-wide MKL setting).
        mkl_set_num_threads(2);

        // Own the handle so it is closed on every exit path, including when
        // Model::loadNext throws. The original code never closed it.
        std::unique_ptr<std::FILE, FileCloser> fd( std::fopen(fileName.c_str(), "rb") );
        if( !fd ){
            throw std::runtime_error("Cannot open file.");
        }

        // Clear the flag first so a (re)load that throws part-way does not
        // leave a partially overwritten model callable through melToWav().
        isLoaded = false;
        model.loadNext(fd.get());
        isLoaded = true;
    }

    /// @brief Run the loaded model on a mel matrix.
    /// @param mels matrix forwarded unchanged to Model::apply.
    ///        NOTE(review): the expected orientation (mel bins x frames or the
    ///        reverse) is defined by Model::apply — confirm there.
    /// @return the Vectorf produced by Model::apply.
    /// @throws std::runtime_error if no load has completed successfully.
    Vectorf melToWav( Eigen::Ref<const MatrixPy> mels ){
        if( !isLoaded ){
            throw std::runtime_error("Model hasn't been loaded. Call loadWeights first.");
        }
        return model.apply(mels);
    }
};
// Python module `WaveRNNVocoder`: exposes `Matrix` (built from a 2-D buffer)
// and `Vocoder` (weight loading and mel-to-waveform inference).
PYBIND11_MODULE(WaveRNNVocoder, m){
    m.doc() = "WaveRNN Vocoder";

    // `Matrix(buffer)`: copies a 2-D buffer whose element format matches
    // Scalar into a new MatrixPy.
    // NOTE(review): no def_buffer is registered, so despite
    // py::buffer_protocol() this type does not export its data as a buffer.
    py::class_<MatrixPy>( m, "Matrix", py::buffer_protocol() )
        // Factory-style constructor. It replaces the placement-new
        // `__init__(MatrixPy&, py::buffer)` form, which is deprecated since
        // pybind11 2.2. With pybind11/eigen.h included, a `MatrixPy&` self
        // argument would be routed through the Eigen type caster rather than
        // bound to the instance being constructed.
        .def(py::init([](py::buffer b) {
            typedef Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic> Strides;
            /* Request a buffer descriptor from Python */
            py::buffer_info info = b.request();
            /* Some sanity checks ... */
            if (info.format != py::format_descriptor<Scalar>::format())
                throw std::runtime_error("Incompatible format: expected a float32 array!");
            if (info.ndim != 2)
                throw std::runtime_error("Incompatible buffer dimension!");
            // Buffer strides are in bytes; Eigen strides are in elements.
            // The outer stride is the row stride for row-major storage and
            // the column stride for column-major storage.
            auto strides = Strides(
                info.strides[rowMajor ? 0 : 1] / (py::ssize_t)sizeof(Scalar),
                info.strides[rowMajor ? 1 : 0] / (py::ssize_t)sizeof(Scalar));
            auto map = Eigen::Map<MatrixPy, 0, Strides>(
                static_cast<Scalar *>(info.ptr), info.shape[0], info.shape[1], strides);
            // Return an owning copy; the mapped Python buffer is released
            // when `info` goes out of scope.
            return MatrixPy(map);
        }));

    py::class_<Vocoder>( m, "Vocoder")
        .def(py::init<>())
        .def("loadWeights", &Vocoder::loadWeights, py::arg("fileName"),
             "Load model weights from a binary file. Raises RuntimeError if the file cannot be opened.")
        .def("melToWav", &Vocoder::melToWav, py::arg("mels"),
             "Run the loaded model on a mel matrix. Raises RuntimeError if loadWeights has not succeeded.")
        ;
}