Skip to content

Commit

Permalink
xe: ocl: gemm: fixup postop normalization
Browse files: browse the repository at this point in the history
  • Loading branch information
Simonsays095 committed Jan 10, 2025
1 parent 69bd2f9 commit df5f802
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions src/gpu/intel/ocl/gemm_matmul.hpp
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2020-2024 Intel Corporation
* Copyright 2020-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -88,13 +88,13 @@ struct gemm_matmul_t : public gpu_primitive_t {
CHECK(map_gemm_zp(DNNL_ARG_DST, DNNL_ARG_C));
}

auto maybe_reshape = [&](dims_t &orig_a_dims, dims_t &orig_b_dims,
dims_t &orig_c_dims,
dims_t &orig_bias_dims,
const int orig_dims) {
auto maybe_reshape
= [&](dims_t &orig_a_dims, dims_t &orig_b_dims,
dims_t &orig_c_dims, dims_t &orig_bias_dims,
const int orig_dims) -> status_t {
int batch_b_dims = 1;
for (int i = b_md->ndims; i > 2; i--) {
batch_b_dims *= b_md->dims[b_md->ndims - i];
for (int i = 0; i < b_md->ndims - 2; i++) {
batch_b_dims *= b_md->dims[i];
}
for (int i = 0; i < orig_dims; i++) {
orig_a_dims[i] = a_md->dims[i];
@@ -161,7 +161,7 @@ struct gemm_matmul_t : public gpu_primitive_t {
for (int i = 0; i < attr()->post_ops_.len(); i++) {
auto &po = post_ops.entry_[i];
if (po.is_binary()) {
auto &po_desc = po.binary.src1_desc;
const auto &po_desc = po.binary.src1_desc;
auto a_dim = po_desc.dims[po_desc.ndims
- reshape_size];
for (int i = po_desc.ndims; i > reshape_size; i--) {
@@ -187,9 +187,11 @@ struct gemm_matmul_t : public gpu_primitive_t {
? po_desc.dims[po_desc.ndims - 1]
: 1;
}
CHECK(memory_desc_reshape(
po_desc, po_desc, reshape_size, po_dims));
tmp_post_ops.entry_[i].binary.src1_desc = po_desc;
memory_desc_t tmp_po_desc;
CHECK(memory_desc_reshape(tmp_po_desc, po_desc,
reshape_size, po_dims));
tmp_post_ops.entry_[i].binary.src1_desc
= tmp_po_desc;
} else if (po.is_prelu()) {
auto mask = po.prelu.mask;
int new_mask = 0;
Expand Down

0 comments on commit df5f802

Please sign in to comment.