Skip to content

Commit

Permalink
xe: ocl: fix gemm_with_po int accumulation type
Browse files Browse the repository at this point in the history
  • Loading branch information
rjoursler committed Jan 13, 2025
1 parent ddb4334 commit 797359f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/gpu/intel/ocl/gemm/gemm_with_post_ops.cl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ __kernel void gemm_post_ops(__global SRC_DATA_T *src,
const uint b_scale_dim = (NDIMS == 2) ? d1 : (NDIMS == 3) ? d2 : d3;
float b_scale = 1;
if (B_SCALES) load(&b_scale, b_scales + scale_stride * b_scale_dim);
acc *= a_scale * b_scale;
if (A_SCALES || B_SCALES) acc *= a_scale * b_scale;

if (bias) {
ACC_DATA_T b;
Expand Down
22 changes: 16 additions & 6 deletions src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,21 +155,31 @@ status_t gemm_with_post_ops_t::pd_t::init_kernel_ctx(
kernel_ctx, memory_desc_info_t::create(dst_md(0)), "DST", false);

int ndims = src_info.ndims;
bool is_int8 = src_md(1)->data_type == data_type::s8;
kernel_ctx.set_data_type(c_type);
//here SRC is output tensor of gemm call
def_data_type(kernel_ctx, is_int8 ? data_type::f32 : desc_.acc_type, "ACC");

kernel_ctx.define_int("NDIMS", ndims);
CHECK(def_attr_info(
kernel_ctx, attr_info_, attr()->post_ops_, *gemm_pd_->dst_md()));
const auto &attr_scales = attr()->scales_;
const bool with_src_scales
= !attr_scales.get(DNNL_ARG_SRC).has_default_values();
const bool with_wei_scales
= !attr_scales.get(DNNL_ARG_WEIGHTS).has_default_values();
const bool with_dst_scales
= !attr_scales.get(DNNL_ARG_DST).has_default_values();
auto is_int_type = [](data_type_t t) {
return utils::one_of(t, data_type::s8, data_type::u8, data_type::s32);
};
data_type_t acc_type = desc_.acc_type;
if (desc_.acc_type == data_type::s32) {
if (with_src_scales || with_wei_scales
|| !is_int_type(bias_info.data_type)
|| !is_int_type(dst_md(0)->data_type)) {
acc_type = data_type::f32;
}
}
def_data_type(kernel_ctx, acc_type, "ACC");

kernel_ctx.define_int("NDIMS", ndims);
CHECK(def_attr_info(
kernel_ctx, attr_info_, attr()->post_ops_, *gemm_pd_->dst_md()));
kernel_ctx.define_int("A_SCALES", with_src_scales);
kernel_ctx.define_int("B_SCALES", with_wei_scales);
kernel_ctx.define_int("C_SCALES", with_dst_scales);
Expand Down

0 comments on commit 797359f

Please sign in to comment.