
Commit f361aef

wip
nihui committed Apr 22, 2024
1 parent 45f9161 commit f361aef
Showing 5 changed files with 1,458 additions and 157 deletions.
41 changes: 40 additions & 1 deletion src/layer/arm/lstm_arm.cpp
@@ -40,6 +40,45 @@ LSTM_arm::LSTM_arm()

int LSTM_arm::create_pipeline(const Option& opt)
{
#if NCNN_INT8
if (int8_scale_term)
{
const int num_directions = direction == 2 ? 2 : 1;
const int size = weight_data_size / num_directions / hidden_size / 4;

// TODO fuse weight de-scale into kernel
Mat weight_xc_data_fp32(size, hidden_size * 4, num_directions);
Mat weight_hc_data_fp32(num_output, hidden_size * 4, num_directions);
for (int d = 0; d < num_directions; d++)
{
for (int q = 0; q < hidden_size * 4; q++)
{
const signed char* weight_xc_ptr = weight_xc_data.channel(d).row<const signed char>(q);
const signed char* weight_hc_ptr = weight_hc_data.channel(d).row<const signed char>(q);

float* weight_xc_fp32_ptr = weight_xc_data_fp32.channel(d).row(q);
float* weight_hc_fp32_ptr = weight_hc_data_fp32.channel(d).row(q);

const float descale_xc = 1.f / weight_xc_data_int8_scales.row(d)[q];
const float descale_hc = 1.f / weight_hc_data_int8_scales.row(d)[q];

for (int i = 0; i < size; i++)
{
weight_xc_fp32_ptr[i] = weight_xc_ptr[i] * descale_xc;
}

for (int i = 0; i < num_output; i++)
{
weight_hc_fp32_ptr[i] = weight_hc_ptr[i] * descale_hc;
}
}
}

weight_xc_data = weight_xc_data_fp32;
weight_hc_data = weight_hc_data_fp32;
}
#endif // NCNN_INT8

#if NCNN_ARM82
if (support_fp16_storage && opt.use_fp16_storage)
{
@@ -308,7 +347,7 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w
float* tmp_hidden_ptr = tmp_hidden_state;

int remain_hidden_size_start = 0;
#if __ARM_NEON
#if 0 //__ARM_NEON TODO test_lstm failed for precision loss
int nn_hidden_size = hidden_size >> 2;
remain_hidden_size_start = nn_hidden_size << 2;

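Note: the NCNN_INT8 block added to create_pipeline above de-quantizes the int8 weights back to fp32 once, at pipeline creation, multiplying each gate row by the reciprocal of its per-row scale. A minimal standalone sketch of that per-row dequantization, assuming the usual symmetric int8 scheme where scale = 127 / absmax; dequantize_row is an illustrative helper name, not an ncnn function:

#include <cstdint>
#include <vector>

// Hypothetical helper: recover fp32 weights from a symmetrically
// quantized int8 row. `scale` is the per-row quantization scale
// (e.g. 127 / absmax), so dequantization multiplies by 1 / scale.
static std::vector<float> dequantize_row(const int8_t* w_int8, int n, float scale)
{
    const float descale = 1.f / scale;
    std::vector<float> w_fp32(n);
    for (int i = 0; i < n; i++)
        w_fp32[i] = w_int8[i] * descale; // w_fp32 approximates the original weight
    return w_fp32;
}
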
290 changes: 241 additions & 49 deletions src/layer/lstm.cpp
@@ -72,37 +72,6 @@ int LSTM::load_model(const ModelBin& mb)
{
weight_xc_data_int8_scales = mb.load(hidden_size * 4, num_directions, 1);
weight_hc_data_int8_scales = mb.load(hidden_size * 4, num_directions, 1);

// TODO fuse weight de-scale into kernel
Mat weight_xc_data_fp32(size, hidden_size * 4, num_directions);
Mat weight_hc_data_fp32(num_output, hidden_size * 4, num_directions);
for (int d = 0; d < num_directions; d++)
{
for (int q = 0; q < hidden_size * 4; q++)
{
const signed char* weight_xc_ptr = weight_xc_data.channel(d).row<const signed char>(q);
const signed char* weight_hc_ptr = weight_hc_data.channel(d).row<const signed char>(q);

float* weight_xc_fp32_ptr = weight_xc_data_fp32.channel(d).row(q);
float* weight_hc_fp32_ptr = weight_hc_data_fp32.channel(d).row(q);

const float descale_xc = 1.f / weight_xc_data_int8_scales.row(d)[q];
const float descale_hc = 1.f / weight_hc_data_int8_scales.row(d)[q];

for (int i = 0; i < size; i++)
{
weight_xc_fp32_ptr[i] = weight_xc_ptr[i] * descale_xc;
}

for (int i = 0; i < num_output; i++)
{
weight_hc_fp32_ptr[i] = weight_hc_ptr[i] * descale_hc;
}
}
}

weight_xc_data = weight_xc_data_fp32;
weight_hc_data = weight_hc_data_fp32;
}
#endif // NCNN_INT8

@@ -255,6 +224,163 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w
return 0;
}

#if NCNN_INT8
static int lstm_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const float* weight_xc_int8_scales, const Mat& bias_c, const Mat& weight_hc_int8, const float* weight_hc_int8_scales, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt)
{
int size = bottom_blob.w;
int T = bottom_blob.h;

int num_output = top_blob.w;
int hidden_size = cell_state.w;

// 4 x hidden_size
Mat gates(4, hidden_size, 4u, opt.workspace_allocator);
if (gates.empty())
return -100;

Mat tmp_hidden_state;
if (num_output != hidden_size)
{
tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator);
if (tmp_hidden_state.empty())
return -100;
}

// unroll
for (int t = 0; t < T; t++)
{
// clip hidden by continuation indicator
// h_cont_{t-1} = cont_t * h_{t-1}
// h_cont_{t-1} = h_{t-1} if cont_t == 1
// 0 otherwise
// calculate hidden
// gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c

int ti = reverse ? T - 1 - t : t;

const float* x = bottom_blob.row(ti);
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < hidden_size; q++)
{
const float* bias_c_I = bias_c.row(0);
const float* bias_c_F = bias_c.row(1);
const float* bias_c_O = bias_c.row(2);
const float* bias_c_G = bias_c.row(3);

float* gates_data = gates.row(q);

// gate I F O G
const signed char* weight_xc_int8_I = weight_xc_int8.row<const signed char>(hidden_size * 0 + q);
const signed char* weight_xc_int8_F = weight_xc_int8.row<const signed char>(hidden_size * 1 + q);
const signed char* weight_xc_int8_O = weight_xc_int8.row<const signed char>(hidden_size * 2 + q);
const signed char* weight_xc_int8_G = weight_xc_int8.row<const signed char>(hidden_size * 3 + q);

const signed char* weight_hc_int8_I = weight_hc_int8.row<const signed char>(hidden_size * 0 + q);
const signed char* weight_hc_int8_F = weight_hc_int8.row<const signed char>(hidden_size * 1 + q);
const signed char* weight_hc_int8_O = weight_hc_int8.row<const signed char>(hidden_size * 2 + q);
const signed char* weight_hc_int8_G = weight_hc_int8.row<const signed char>(hidden_size * 3 + q);

const float descale_xc_I = 1.f / weight_xc_int8_scales[hidden_size * 0 + q];
const float descale_xc_F = 1.f / weight_xc_int8_scales[hidden_size * 1 + q];
const float descale_xc_O = 1.f / weight_xc_int8_scales[hidden_size * 2 + q];
const float descale_xc_G = 1.f / weight_xc_int8_scales[hidden_size * 3 + q];
const float descale_hc_I = 1.f / weight_hc_int8_scales[hidden_size * 0 + q];
const float descale_hc_F = 1.f / weight_hc_int8_scales[hidden_size * 1 + q];
const float descale_hc_O = 1.f / weight_hc_int8_scales[hidden_size * 2 + q];
const float descale_hc_G = 1.f / weight_hc_int8_scales[hidden_size * 3 + q];

float I = bias_c_I[q];
float F = bias_c_F[q];
float O = bias_c_O[q];
float G = bias_c_G[q];

for (int i = 0; i < size; i++)
{
float xi = x[i];

I += weight_xc_int8_I[i] * descale_xc_I * xi;
F += weight_xc_int8_F[i] * descale_xc_F * xi;
O += weight_xc_int8_O[i] * descale_xc_O * xi;
G += weight_xc_int8_G[i] * descale_xc_G * xi;
}

for (int i = 0; i < num_output; i++)
{
float h_cont = hidden_state[i];

I += weight_hc_int8_I[i] * descale_hc_I * h_cont;
F += weight_hc_int8_F[i] * descale_hc_F * h_cont;
O += weight_hc_int8_O[i] * descale_hc_O * h_cont;
G += weight_hc_int8_G[i] * descale_hc_G * h_cont;
}

gates_data[0] = I;
gates_data[1] = F;
gates_data[2] = O;
gates_data[3] = G;
}

// lstm unit
// sigmoid(I)
// sigmoid(F)
// sigmoid(O)
// tanh(G)
// c_t := f_t .* c_{t-1} + i_t .* g_t
// h_t := o_t .* tanh[c_t]
float* output_data = top_blob.row(ti);
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < hidden_size; q++)
{
const float* gates_data = gates.row(q);

float I = gates_data[0];
float F = gates_data[1];
float O = gates_data[2];
float G = gates_data[3];

I = 1.f / (1.f + expf(-I));
F = 1.f / (1.f + expf(-F));
O = 1.f / (1.f + expf(-O));
G = tanhf(G);

float cell2 = F * cell_state[q] + I * G;
float H = O * tanhf(cell2);
cell_state[q] = cell2;

if (num_output == hidden_size)
{
hidden_state[q] = H;
output_data[q] = H;
}
else
{
tmp_hidden_state[q] = H;
}
}

if (num_output != hidden_size)
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < num_output; q++)
{
const float* hr = weight_hr.row(q);

float H = 0;
for (int i = 0; i < hidden_size; i++)
{
H += tmp_hidden_state[i] * hr[i];
}

hidden_state[q] = H;
output_data[q] = H;
}
}
}

return 0;
}
#endif // NCNN_INT8
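
The gate arithmetic in lstm_int8 above is the standard LSTM cell update, with the weight de-scale fused into the int8 accumulation (each int8 product is multiplied by 1/scale), which is what the "fuse weight de-scale into kernel" TODO in lstm_arm.cpp refers to. For reference, a minimal scalar sketch of the cell update for a single hidden unit; lstm_cell_step is an illustrative name, not part of ncnn:

#include <math.h>

// One LSTM cell update for a single hidden unit, given the four
// pre-activation gate sums (bias + W_xc * x_t + W_hc * h_{t-1}).
static void lstm_cell_step(float I, float F, float O, float G, float* cell, float* hidden)
{
    const float i = 1.f / (1.f + expf(-I)); // input gate,  sigmoid
    const float f = 1.f / (1.f + expf(-F)); // forget gate, sigmoid
    const float o = 1.f / (1.f + expf(-O)); // output gate, sigmoid
    const float g = tanhf(G);               // candidate,   tanh

    *cell = f * *cell + i * g;  // c_t = f_t .* c_{t-1} + i_t .* g_t
    *hidden = o * tanhf(*cell); // h_t = o_t .* tanh(c_t)
}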

int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
int T = bottom_blob.h;
@@ -279,9 +405,20 @@ int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
// Uni directional
if (direction == 0 || direction == 1)
{
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
#if NCNN_INT8
if (int8_scale_term)
{
int ret = lstm_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
else
#endif
{
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
}

if (direction == 2)
Expand All @@ -294,16 +431,38 @@ int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons
if (top_blob_reverse.empty())
return -100;

int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret0 != 0)
return ret0;
#if NCNN_INT8
if (int8_scale_term)
{
int ret = lstm_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
else
#endif
{
int ret = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}

hidden.fill(0.0f);
cell.fill(0.0f);

int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt);
if (ret1 != 0)
return ret1;
#if NCNN_INT8
if (int8_scale_term)
{
int ret = lstm_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt);
if (ret != 0)
return ret;
}
else
#endif
{
int ret = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt);
if (ret != 0)
return ret;
}

// concat w
for (int i = 0; i < T; i++)
@@ -355,9 +514,20 @@ int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
// Uni directional
if (direction == 0 || direction == 1)
{
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
#if NCNN_INT8
if (int8_scale_term)
{
int ret = lstm_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
else
#endif
{
int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt);
if (ret != 0)
return ret;
}
}

if (direction == 2)
@@ -372,15 +542,37 @@ int LSTM::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const

Mat hidden0 = hidden.row_range(0, 1);
Mat cell0 = cell.row_range(0, 1);
int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt);
if (ret0 != 0)
return ret0;
#if NCNN_INT8
if (int8_scale_term)
{
int ret = lstm_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt);
if (ret != 0)
return ret;
}
else
#endif
{
int ret = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt);
if (ret != 0)
return ret;
}

Mat hidden1 = hidden.row_range(1, 1);
Mat cell1 = cell.row_range(1, 1);
int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt);
if (ret1 != 0)
return ret1;
#if NCNN_INT8
if (int8_scale_term)
{
int ret = lstm_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt);
if (ret != 0)
return ret;
}
else
#endif
{
int ret = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt);
if (ret != 0)
return ret;
}

// concat w
for (int i = 0; i < T; i++)
(3 more changed files not shown)
