
Commit

check++
nihui committed Jun 5, 2024
1 parent 9c7b801 commit bf79f48
Showing 3 changed files with 50 additions and 8 deletions.
6 changes: 6 additions & 0 deletions src/layer/x86/crop_x86.cpp
@@ -526,6 +526,8 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
         opt_pack1.blob_allocator = opt.workspace_allocator;
 
         convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack1);
+        if (bottom_blob_unpacked.empty())
+            return -100;
     }
 
     return Crop::forward(bottom_blob_unpacked, top_blob, opt);
@@ -980,6 +982,8 @@ int Crop_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
         opt_pack1.blob_allocator = opt.workspace_allocator;
 
         convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack1);
+        if (bottom_blob_unpacked.empty())
+            return -100;
     }
 
     Mat reference_blob_unpacked = reference_blob;
@@ -989,6 +993,8 @@ int Crop_x86::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& to
         opt_pack1.blob_allocator = opt.workspace_allocator;
 
         convert_packing(reference_blob, reference_blob_unpacked, 1, opt_pack1);
+        if (reference_blob_unpacked.empty())
+            return -100;
     }
 
     std::vector<Mat> bottom_blobs_unpacked(2);
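
The three crop_x86.cpp hunks apply the same guard: convert_packing signals failure only by leaving its destination Mat empty rather than through a return value, and ncnn layers conventionally return -100 when a blob allocation fails, so each call site now tests empty() before continuing. A minimal sketch of the pattern, where unpack_to_pack1 is a hypothetical helper name used for illustration, not code from this commit:

#include "mat.h"
#include "option.h"

// Sketch only: repack a blob to elempack 1 and report allocation failure.
static int unpack_to_pack1(const ncnn::Mat& in, ncnn::Mat& out, const ncnn::Option& opt)
{
    if (in.elempack == 1)
    {
        out = in;
        return 0;
    }

    ncnn::Option opt_pack1 = opt;
    opt_pack1.blob_allocator = opt.workspace_allocator;

    ncnn::convert_packing(in, out, 1, opt_pack1);

    // convert_packing allocates the output; an empty Mat means the allocation failed
    if (out.empty())
        return -100;

    return 0;
}
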
48 changes: 40 additions & 8 deletions src/layer/x86/multiheadattention_x86.cpp
@@ -273,9 +273,11 @@ int MultiHeadAttention_x86::forward(const std::vector<Mat>& bottom_blobs, std::v
     const Mat& attn_mask_blob = attn_mask ? bottom_blobs[bottom_blobs.size() - 1] : Mat();
 
     Mat attn_mask_blob_unpacked;
-    if (attn_mask_blob.elempack != 1)
+    if (attn_mask && attn_mask_blob.elempack != 1)
     {
         convert_packing(attn_mask_blob, attn_mask_blob_unpacked, 1, opt);
+        if (attn_mask_blob_unpacked.empty())
+            return -100;
     }
     else
     {
@@ -287,12 +289,21 @@ int MultiHeadAttention_x86::forward(const std::vector<Mat>& bottom_blobs, std::v
     const int dst_seqlen = k_blob.h * k_blob.elempack;
 
     Mat q_affine;
-    q_gemm->forward(q_blob, q_affine, opt);
+    int retq = q_gemm->forward(q_blob, q_affine, opt);
+    if (retq != 0)
+        return retq;
 
     Mat k_affine;
-    k_gemm->forward(k_blob, k_affine, opt);
+    int retk = k_gemm->forward(k_blob, k_affine, opt);
+    if (retk != 0)
+        return retk;
 
     Mat qk_cross(dst_seqlen, src_seqlen * num_heads, 4u, opt.blob_allocator);
+    if (qk_cross.empty())
+        return -100;
 
+    std::vector<int> retqks;
+    retqks.resize(num_heads);
     #pragma omp parallel for num_threads(opt.num_threads)
     for (int i = 0; i < num_heads; i++)
     {
@@ -308,18 +319,32 @@ int MultiHeadAttention_x86::forward(const std::vector<Mat>& bottom_blobs, std::v
         qk_top_blobs[0] = qk_cross.row_range(i * src_seqlen, src_seqlen);
         Option opt1 = opt;
         opt1.num_threads = 1;
-        qk_gemm->forward(qk_bottom_blobs, qk_top_blobs, opt1);
+        retqks[i] = qk_gemm->forward(qk_bottom_blobs, qk_top_blobs, opt1);
     }
+    for (int i = 0; i < num_heads; i++)
+    {
+        if (retqks[i] != 0)
+            return retqks[i];
+    }
 
     q_affine.release();
     k_affine.release();
 
-    qk_softmax->forward_inplace(qk_cross, opt);
+    int retqk = qk_softmax->forward_inplace(qk_cross, opt);
+    if (retqk != 0)
+        return retqk;
 
     Mat v_affine;
-    v_gemm->forward(v_blob, v_affine, opt);
+    int retv = v_gemm->forward(v_blob, v_affine, opt);
+    if (retv != 0)
+        return retv;
 
     Mat qkv_cross(src_seqlen, embed_dim_per_head * num_heads, 4u, opt.blob_allocator);
+    if (qkv_cross.empty())
+        return -100;
 
+    std::vector<int> retqkvs;
+    retqkvs.resize(num_heads);
     #pragma omp parallel for num_threads(opt.num_threads)
     for (int i = 0; i < num_heads; i++)
     {
@@ -330,12 +355,19 @@ int MultiHeadAttention_x86::forward(const std::vector<Mat>& bottom_blobs, std::v
         qkv_top_blobs[0] = qkv_cross.row_range(i * embed_dim_per_head, embed_dim_per_head);
         Option opt1 = opt;
         opt1.num_threads = 1;
-        qkv_gemm->forward(qkv_bottom_blobs, qkv_top_blobs, opt1);
+        retqkvs[i] = qkv_gemm->forward(qkv_bottom_blobs, qkv_top_blobs, opt1);
     }
+    for (int i = 0; i < num_heads; i++)
+    {
+        if (retqkvs[i] != 0)
+            return retqkvs[i];
+    }
 
     v_affine.release();
 
-    o_gemm->forward(qkv_cross, top_blobs[0], opt);
+    int reto = o_gemm->forward(qkv_cross, top_blobs[0], opt);
+    if (reto != 0)
+        return reto;
 
     return 0;
 }
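
The per-head loops above cannot return early, because a return statement is not allowed to branch out of an OpenMP parallel-for region. The commit therefore records each iteration's status in a pre-sized vector and scans it after the loop. A standalone sketch of that pattern, with do_head standing in for the per-head GEMM call (it is not an ncnn API):

#include <vector>

// stand-in for one head's qk/qkv GEMM; returns 0 on success
static int do_head(int /*i*/)
{
    return 0;
}

static int run_heads(int num_heads, int num_threads)
{
    std::vector<int> rets(num_heads, 0);

    #pragma omp parallel for num_threads(num_threads)
    for (int i = 0; i < num_heads; i++)
    {
        rets[i] = do_head(i); // each iteration records its own status
    }

    // check after the parallel region; returning from inside it is not allowed
    for (int i = 0; i < num_heads; i++)
    {
        if (rets[i] != 0)
            return rets[i];
    }

    return 0;
}
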
4 changes: 4 additions & 0 deletions src/layer/x86/shufflechannel_x86.cpp
@@ -399,13 +399,17 @@ int ShuffleChannel_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Opt
 
     Mat bottom_blob_unpacked;
     convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack);
+    if (bottom_blob_unpacked.empty())
+        return -100;
 
     Mat top_blob_unpacked;
     int ret = ShuffleChannel::forward(bottom_blob_unpacked, top_blob_unpacked, opt_pack);
     if (ret != 0)
         return ret;
 
     convert_packing(top_blob_unpacked, top_blob, elempack, opt);
+    if (top_blob.empty())
+        return -100;
 
     return 0;
 }
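
With these checks in place, a failure in the pack-1 fallback surfaces to the caller as a non-zero return code instead of an empty top blob. A rough usage sketch, assuming the standalone layer API (create_layer, load_param, create_pipeline, forward) and ShuffleChannel's documented parameter ids (0 = group, 1 = reverse); this is illustrative only, not code from the repository:

#include "layer.h"
#include "paramdict.h"

// Run a ShuffleChannel layer directly and propagate its error code.
int run_shufflechannel(const ncnn::Mat& in, ncnn::Mat& out, int group)
{
    ncnn::Layer* sc = ncnn::create_layer("ShuffleChannel");
    if (!sc)
        return -1;

    ncnn::ParamDict pd;
    pd.set(0, group); // 0 = group
    pd.set(1, 0);     // 1 = reverse

    sc->load_param(pd);

    ncnn::Option opt;
    sc->create_pipeline(opt);

    int ret = sc->forward(in, out, opt);

    sc->destroy_pipeline(opt);
    delete sc;

    return ret; // -100 now indicates a failed allocation in the repacking path
}
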
