Fixed the memory bugs when the input blob for InnerProduct has 3 dims (3D tensor innerproduct) #5158

Open · wants to merge 5 commits into master
src/layer/arm/innerproduct_arm.cpp (235 additions, 0 deletions)
@@ -406,6 +406,241 @@ int InnerProduct_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
return 0;
}

if (bottom_blob.dims == 3 && bottom_blob.c == num_input)
{
// 3d tensor input gemm
int w = bottom_blob.w;
int h = bottom_blob.h;
int c = bottom_blob.c; // num_input
size_t elemsize = bottom_blob.elemsize;
int elempack = bottom_blob.elempack;
// flatten (w, h, c) to a (num_input, w * h) matrix for the per-position gemm
ncnn::Mat bottom_blob_flattened = bottom_blob.reshape(c, w * h, opt.workspace_allocator);

top_blob.create(num_output, w * h, elemsize, elempack, opt.blob_allocator);
if (top_blob.empty())
return -100;

int num_output_elempack = 1;
#if __ARM_NEON
if (opt.use_packing_layout)
{
num_output_elempack = num_output % 4 == 0 ? 4 : 1;
}
#endif

#pragma omp parallel for num_threads(opt.num_threads)
for (int j = 0; j < w * h; j++) // one output row per flattened spatial position
{
#if __ARM_NEON
if (elempack == 4 && num_output_elempack == 4)
{
float* outptr = top_blob.row(j);

for (int p = 0; p < num_output / num_output_elempack; p++)
{
const float* kptr = weight_data_tm.row(p);
const float* m = bottom_blob_flattened.row(j);

float32x4_t _sum0 = vdupq_n_f32(0.f);
float32x4_t _sum1 = vdupq_n_f32(0.f);
float32x4_t _sum2 = vdupq_n_f32(0.f);
float32x4_t _sum3 = vdupq_n_f32(0.f);

if (bias_term)
{
_sum0 = vdupq_n_f32(bias_data[p * 4 + 0]);
_sum1 = vdupq_n_f32(bias_data[p * 4 + 1]);
_sum2 = vdupq_n_f32(bias_data[p * 4 + 2]);
_sum3 = vdupq_n_f32(bias_data[p * 4 + 3]);
}

for (int i = 0; i < num_input; i++)
{
float32x4_t _val = vld1q_f32(m);
float32x4_t _w = vld1q_f32(kptr);
#if __aarch64__
_sum0 = vfmaq_laneq_f32(_sum0, _val, _w, 0);
_sum1 = vfmaq_laneq_f32(_sum1, _val, _w, 1);
_sum2 = vfmaq_laneq_f32(_sum2, _val, _w, 2);
_sum3 = vfmaq_laneq_f32(_sum3, _val, _w, 3);
#else
_sum0 = vmlaq_lane_f32(_sum0, _val, vget_low_f32(_w), 0);
_sum1 = vmlaq_lane_f32(_sum1, _val, vget_low_f32(_w), 1);
_sum2 = vmlaq_lane_f32(_sum2, _val, vget_high_f32(_w), 0);
_sum3 = vmlaq_lane_f32(_sum3, _val, vget_high_f32(_w), 1);
#endif
m += 4;
kptr += 4;
}

_sum0 = activation_ps(_sum0, activation_type, activation_params);
_sum1 = activation_ps(_sum1, activation_type, activation_params);
_sum2 = activation_ps(_sum2, activation_type, activation_params);
_sum3 = activation_ps(_sum3, activation_type, activation_params);

vst1q_f32(outptr, _sum0);
vst1q_f32(outptr + 4, _sum1);
vst1q_f32(outptr + 8, _sum2);
vst1q_f32(outptr + 12, _sum3);
outptr += 16;
}
}

if (elempack == 1 && num_output_elempack == 4)
{
float* outptr = top_blob.row(j);

for (int p = 0; p < num_output / num_output_elempack; p++)
{
const float* kptr = weight_data_tm.row(p);
const float* m = bottom_blob_flattened.row(j);

float32x4_t _sum0 = vdupq_n_f32(0.f);
float32x4_t _sum1 = vdupq_n_f32(0.f);
float32x4_t _sum2 = vdupq_n_f32(0.f);
float32x4_t _sum3 = vdupq_n_f32(0.f);

if (bias_term)
{
_sum0 = vld1q_f32((const float*)bias_data + p * 4);
}

int i = 0;
for (; i + 3 < num_input; i += 4)
{
float32x4_t _val = vld1q_f32(m);

float32x4_t _w0 = vld1q_f32(kptr);
float32x4_t _w1 = vld1q_f32(kptr + 4);
float32x4_t _w2 = vld1q_f32(kptr + 8);
float32x4_t _w3 = vld1q_f32(kptr + 12);

#if __aarch64__
_sum0 = vfmaq_laneq_f32(_sum0, _w0, _val, 0);
_sum1 = vfmaq_laneq_f32(_sum1, _w1, _val, 1);
_sum2 = vfmaq_laneq_f32(_sum2, _w2, _val, 2);
_sum3 = vfmaq_laneq_f32(_sum3, _w3, _val, 3);
#else
_sum0 = vmlaq_lane_f32(_sum0, _w0, vget_low_f32(_val), 0);
_sum1 = vmlaq_lane_f32(_sum1, _w1, vget_low_f32(_val), 1);
_sum2 = vmlaq_lane_f32(_sum2, _w2, vget_high_f32(_val), 0);
_sum3 = vmlaq_lane_f32(_sum3, _w3, vget_high_f32(_val), 1);
#endif

m += 4;
kptr += 16;
}
for (; i < num_input; i++)
{
float32x4_t _val = vld1q_dup_f32(m);
float32x4_t _k = vld1q_f32(kptr);
_sum0 = vmlaq_f32(_sum0, _val, _k);

m += 1;
kptr += 4;
}

_sum0 = vaddq_f32(_sum0, _sum1);
_sum2 = vaddq_f32(_sum2, _sum3);
_sum0 = vaddq_f32(_sum0, _sum2);

_sum0 = activation_ps(_sum0, activation_type, activation_params);

vst1q_f32(outptr, _sum0);
outptr += 4;
}
}

if (elempack == 4 && num_output_elempack == 1)
{
float* outptr = top_blob.row(j);

for (int p = 0; p < num_output; p++)
{
const float* kptr = (const float*)weight_data_tm + num_input * p;
const float* m = bottom_blob_flattened.row(j);

float32x4_t _sum = vdupq_n_f32(0.f);

if (bias_term)
{
_sum = vdupq_n_f32(bias_data[p]);
}

for (int i = 0; i < num_input; i++)
{
float32x4_t _val = vld1q_f32(m);
float32x4_t _k = vdupq_n_f32(kptr[0]);
_sum = vmlaq_f32(_sum, _val, _k);

m += 4;
kptr += 1;
}

_sum = activation_ps(_sum, activation_type, activation_params);

vst1q_f32(outptr, _sum);
outptr += 4;
}
}
#endif // __ARM_NEON

if (elempack == 1 && num_output_elempack == 1)
{
float* outptr = top_blob.row(j);

for (int p = 0; p < num_output; p++)
{
const float* kptr = (const float*)weight_data_tm + num_input * p;
const float* m = bottom_blob_flattened.row(j);

float sum = 0.f;

if (bias_term)
{
sum = bias_data[p];
}

int i = 0;
#if __ARM_NEON
float32x4_t _sum = vdupq_n_f32(0.f);
for (; i + 3 < num_input; i += 4)
{
float32x4_t _val = vld1q_f32(m);
float32x4_t _k = vld1q_f32(kptr);
_sum = vmlaq_f32(_sum, _val, _k);

m += 4;
kptr += 4;
}
#if __aarch64__
sum += vaddvq_f32(_sum);
#else
float32x2_t _ss = vadd_f32(vget_low_f32(_sum), vget_high_f32(_sum));
_ss = vpadd_f32(_ss, _ss);
sum += vget_lane_f32(_ss, 0);
#endif
#endif // __ARM_NEON
for (; i < num_input; i++)
{
sum += *m * *kptr;

m += 1;
kptr += 1;
}

sum = activation_ss(sum, activation_type, activation_params);

outptr[0] = sum;
outptr += 1;
}
}
}
// restore the spatial layout: (num_output, w * h) -> (w, h, num_output)
top_blob = top_blob.reshape(w, h, num_output, opt.blob_allocator);
return 0;
}

// flatten
Mat bottom_blob_flattened = bottom_blob;
if (bottom_blob.dims != 1)
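For reference, the semantics this branch is meant to implement is the fully connected transform applied independently at each of the w * h spatial positions. A minimal scalar sketch in plain C++ follows (no NEON, no packing, activation omitted); the function name and the channel-major indexing are illustrative assumptions, not ncnn internals:

// Reference semantics for the 3D InnerProduct path: at every spatial
// position, gather the num_input channel values into a vector x and
// compute y = W * x + b, yielding a (w, h, num_output) result.
// Assumed layout (for illustration only): channel-major, i.e. element
// (channel q, position pos) lives at in[q * w * h + pos], and weights
// at weight[p * num_input + q].
static void innerproduct_3d_reference(const float* in, int w, int h, int num_input,
                                      const float* weight, const float* bias,
                                      int num_output, float* out)
{
    for (int pos = 0; pos < w * h; pos++)
    {
        for (int p = 0; p < num_output; p++)
        {
            float sum = bias ? bias[p] : 0.f;
            for (int q = 0; q < num_input; q++)
            {
                sum += in[q * w * h + pos] * weight[p * num_input + q];
            }
            out[p * w * h + pos] = sum;
        }
    }
}

This also shows why top_blob is first created as a (num_output, w * h) matrix and only reshaped to (w, h, num_output) at the end: each row of the intermediate matrix holds the full output vector of one flattened position.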
src/layer/x86/innerproduct_x86.cpp (23 additions, 0 deletions)
@@ -131,6 +131,29 @@ int InnerProduct_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
return 0;
}

if (bottom_blob.dims == 3 && bottom_blob.c == num_input)
{
// 3d tensor input gemm
int w = bottom_blob.w;
int h = bottom_blob.h;
int c = bottom_blob.c; // num_input
size_t elemsize = bottom_blob.elemsize;
int elempack = bottom_blob.elempack;

// flatten (w, h, c) to a (num_input, w * h) matrix for the per-position gemm
ncnn::Mat bottom_blob_flattened = bottom_blob.reshape(c, w * h, opt.workspace_allocator);

// Adjust the size of top_blob
top_blob.create(num_output, w * h, elemsize, elempack, opt.blob_allocator);
if (top_blob.empty())
return -100;

// Perform the matrix multiplication
innerproduct_gemm_sse(bottom_blob_flattened, top_blob, weight_data_tm, bias_data, activation_type, activation_params, opt);
top_blob = top_blob.reshape(w, h, num_output, opt.blob_allocator);
return 0;
}

// flatten
Mat bottom_blob_flattened = bottom_blob;
if (bottom_blob.dims != 1)
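A quick way to exercise the new 3D path is the standalone layer-construction pattern used in ncnn's own test suite. The sketch below is only a sketch: the dimensions and fill values are arbitrary, and it assumes ncnn's public create_layer / ParamDict / ModelBinFromMatArray API with the ncnn headers on the include path.

#include "layer.h"
#include "mat.h"
#include "modelbin.h"
#include "option.h"
#include "paramdict.h"
#include <cstdio>

int main()
{
    const int w = 5, h = 4, num_input = 3, num_output = 2;

    // 3D input blob (w, h, c = num_input) with some deterministic values
    ncnn::Mat bottom(w, h, num_input);
    for (int i = 0; i < (int)bottom.total(); i++)
        bottom[i] = (i % 7) * 0.1f;

    ncnn::Layer* op = ncnn::create_layer("InnerProduct");

    ncnn::ParamDict pd;
    pd.set(0, num_output);             // num_output
    pd.set(1, 1);                      // bias_term
    pd.set(2, num_input * num_output); // weight_data_size
    op->load_param(pd);

    ncnn::Mat weights[2];
    weights[0] = ncnn::Mat(num_input * num_output); // weight
    weights[1] = ncnn::Mat(num_output);             // bias
    weights[0].fill(0.5f);
    weights[1].fill(0.1f);
    op->load_model(ncnn::ModelBinFromMatArray(weights));

    ncnn::Option opt;
    opt.num_threads = 1;

    op->create_pipeline(opt);

    ncnn::Mat top;
    op->forward(bottom, top, opt);

    // with this patch the output should keep the spatial layout: 5 x 4 x 2
    printf("top: %d x %d x %d\n", top.w, top.h, top.c);

    op->destroy_pipeline(opt);
    delete op;
    return 0;
}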