Skip to content

Commit

Permalink
x64: enable brgemm int8 rnn for avx2
Browse files Browse the repository at this point in the history
  • Loading branch information
densamoilov committed Jan 22, 2025
1 parent 1a28e50 commit 188ddeb
Show file tree
Hide file tree
Showing 10 changed files with 28 additions and 16 deletions.
2 changes: 2 additions & 0 deletions include/oneapi/dnnl/dnnl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1481,6 +1481,7 @@ struct memory : public handle<dnnl_memory_t> {
abCd32c = dnnl_abCd32c,
abdEc16e = dnnl_abdEc16e,
abdEc32e = dnnl_abdEc32e,
abdEC16e4c = dnnl_abdEC16e4c,
abdEC32e2c = dnnl_abdEC32e2c,
abdEC32e4c = dnnl_abdEC32e4c,
abdCe16c = dnnl_abdCe16c,
Expand Down Expand Up @@ -1980,6 +1981,7 @@ struct memory : public handle<dnnl_memory_t> {
ldOi32o = abDc32d,
ldOI32o4i = abDC32d4c,
ldgOi16o = abdEc16e,
ldgOI16o4i = abdEC16e4c,
ldgOi32o = abdEc32e,
ldgOI32o2i = abdEC32e2c,
ldgOI32o4i = abdEC32e4c,
Expand Down
2 changes: 2 additions & 0 deletions include/oneapi/dnnl/dnnl_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,7 @@ typedef enum {
dnnl_BAcd8a8b,
dnnl_BAcde8a8b,
dnnl_aCBdef8b8c,
dnnl_abdEC16e4c,

/// Just a sentinel, not a real memory format tag. Must be changed after a
/// new format tag is added.
Expand Down Expand Up @@ -1184,6 +1185,7 @@ typedef enum {
dnnl_ldIo32i = dnnl_abCd32c,
/// 6D RNN weights tensor
dnnl_ldgOi16o = dnnl_abdEc16e,
dnnl_ldgOI16o4i = dnnl_abdEC16e4c,
dnnl_ldgOi32o = dnnl_abdEc32e,
dnnl_ldgOI32o2i = dnnl_abdEC32e2c,
dnnl_ldgOI32o4i = dnnl_abdEC32e4c,
Expand Down
2 changes: 2 additions & 0 deletions src/common/c_types_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,7 @@ const format_tag_t abCde32c = dnnl_abCde32c;
const format_tag_t abCdef32c = dnnl_abCdef32c;
const format_tag_t abdEc16e = dnnl_abdEc16e;
const format_tag_t abdEc32e = dnnl_abdEc32e;
const format_tag_t abdEC16e4c = dnnl_abdEC16e4c;
const format_tag_t abdEC32e2c = dnnl_abdEC32e2c;
const format_tag_t abdEC32e4c = dnnl_abdEC32e4c;
const format_tag_t abdEC64e2c = dnnl_abdEC64e2c;
Expand Down Expand Up @@ -1485,6 +1486,7 @@ const format_tag_t ldOI32o4i = dnnl_ldOI32o4i;
const format_tag_t ldIo32i = dnnl_ldIo32i;
const format_tag_t ldgOi16o = dnnl_ldgOi16o;
const format_tag_t ldgOi32o = dnnl_ldgOi32o;
const format_tag_t ldgOI16o4i = dnnl_ldgOI16o4i;
const format_tag_t ldgOI32o2i = dnnl_ldgOI32o2i;
const format_tag_t ldgOI32o4i = dnnl_ldgOI32o4i;
const format_tag_t ldgOI64o2i = dnnl_ldgOI64o2i;
Expand Down
4 changes: 3 additions & 1 deletion src/common/dnnl_debug_autogenerated.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2018-2024 Intel Corporation
* Copyright 2018-2025 Intel Corporation
* Copyright 2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -376,6 +376,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
if (v == dnnl_abDc32d) return "abDc32d";
if (v == dnnl_abDC32d4c) return "abDC32d4c";
if (v == dnnl_abdEc32e) return "abdEc32e";
if (v == dnnl_abdEC16e4c) return "abdEC16e4c";
if (v == dnnl_abdEC32e2c) return "abdEC32e2c";
if (v == dnnl_abdEC32e4c) return "abdEC32e4c";
if (v == dnnl_aBdefC16b4c) return "aBdefC16b4c";
Expand Down Expand Up @@ -1008,6 +1009,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
if (v == dnnl_ldIo32i) return "ldIo32i";
if (v == dnnl_ldgOi16o) return "ldgOi16o";
if (v == dnnl_ldgOi32o) return "ldgOi32o";
if (v == dnnl_ldgOI16o4i) return "ldgOI16o4i";
if (v == dnnl_ldgOI32o2i) return "ldgOI32o2i";
if (v == dnnl_ldgOI32o4i) return "ldgOI32o4i";
if (v == dnnl_ldgOI64o2i) return "ldgOI64o2i";
Expand Down
3 changes: 2 additions & 1 deletion src/common/memory_desc_wrapper.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2016-2024 Intel Corporation
* Copyright 2016-2025 Intel Corporation
* Copyright 2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -628,6 +628,7 @@ status_t memory_desc_wrapper::compute_blocking(
C(abCdef32c, {0, 1, 2, 3, 4, 5}, {32}, {2});
C(abdEc16e, {0, 1, 3, 4, 2}, {16}, {4});
C(abdEc32e, {0, 1, 3, 4, 2}, {32}, {4});
C(abdEC16e4c, {0, 1, 3, 4, 2}, {16, 4}, {4, 2});
C(abdEC32e2c, {0, 1, 3, 4, 2}, {32, 2}, {4, 2});
C(abdEC32e4c, {0, 1, 3, 4, 2}, {32, 4}, {4, 2});
C(abdEC64e2c, {0, 1, 3, 4, 2}, {64, 2}, {4, 2});
Expand Down
4 changes: 3 additions & 1 deletion src/common/tag_traits.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2018-2024 Intel Corporation
* Copyright 2018-2025 Intel Corporation
* Copyright 2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -142,6 +142,7 @@ enum class inner_blk_t {
_16b4c,
_16c2b,
_16c4b,
_16e4c,
_24a2b,
_24a4b,
_24b2a,
Expand Down Expand Up @@ -860,6 +861,7 @@ DECL_TRAITS(abCde4c, _C, _4c, 5);
DECL_TRAITS(abCdef4c, _C, _4c, 6);
DECL_TRAITS(abdEc16e, _E, _16e, 5);
DECL_TRAITS(abdEc32e, _E, _32e, 5);
DECL_TRAITS(abdEC16e4c, _CE, _16e4c, 5);
DECL_TRAITS(abdEC32e2c, _CE, _32e2c, 5);
DECL_TRAITS(abdEC32e4c, _CE, _32e4c, 5);
DECL_TRAITS(abdEC64e2c, _CE, _64e2c, 5);
Expand Down
4 changes: 2 additions & 2 deletions src/cpu/rnn/ref_rnn.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2018-2024 Intel Corporation
* Copyright 2018-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -414,7 +414,7 @@ _ref_rnn_common_t<aprop, src_type, weights_type, acc_type>::pd_t::init_brgemm(
VDISPATCH_RNN(
!(rnn_.is_signed_int8_conf() && !is_superset(isa, avx512_core_amx)),
VERBOSE_ISA_DT_MISMATCH);
VDISPATCH_RNN(!(rnn_.is_int8_conf() && !is_superset(isa, avx512_core_vnni)),
VDISPATCH_RNN(!(rnn_.is_int8_conf() && !is_superset(isa, avx2)),
VERBOSE_ISA_DT_MISMATCH);
VDISPATCH_RNN(!(rnn_.is_f32_conf() && !is_superset(isa, avx2)),
VERBOSE_ISA_DT_MISMATCH);
Expand Down
9 changes: 4 additions & 5 deletions src/cpu/rnn/rnn_reorders.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,8 @@ struct rnn_brgemm_weights_reorder_s8_t : public primitive_t {
format_tag_t otag, itag;

itag = id.matches_one_of_tag(ldigo, ldio);
otag = od.matches_one_of_tag(ldgOI64o4i, ldgOI32o4i, ldOI32o4i);
otag = od.matches_one_of_tag(
ldgOI64o4i, ldgOI32o4i, ldgOI16o4i, ldOI32o4i);
if (itag != format_tag::undef && otag != format_tag::undef) {
_pd->itag_ = itag;
_pd->otag_ = otag;
Expand Down Expand Up @@ -855,15 +856,13 @@ struct rnn_brgemm_weights_reorder_s8_t : public primitive_t {
return status::success;
}

const auto &blocked_d = dst_d;
const auto &pdims = blocked_d.padded_dims();

const int o_block = pd()->otag_ == ldgOI64o4i ? 64 : 32;
const int o_block = dst_d.blocking_desc().inner_blks[0];
static constexpr int i_block = 4;

dim_t L, D, I, G, O;
init_dims(L, D, I, G, O, src_d);

const auto &pdims = dst_d.padded_dims();
const dim_t pI = pdims[2];
const dim_t pO = (src_d.ndims() == 5) ? pdims[4] : pdims[3];
const dim_t IB = pI / i_block;
Expand Down
9 changes: 5 additions & 4 deletions src/cpu/rnn/rnn_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2018-2023 Intel Corporation
* Copyright 2018-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -76,8 +76,8 @@ bool rnn_utils::is_ldoi(const memory_desc_wrapper &mdw) {
bool rnn_utils::is_ldigo_blocked(const memory_desc_wrapper &mdw) {
format_tag_t md_format_tag = mdw.matches_one_of_tag(format_tag::ldgOi32o,
format_tag::ldgOI32o2i, format_tag::ldgOI32o4i,
format_tag::ldgOI64o2i, format_tag::ldgOI64o4i,
format_tag::ldgOi16o);
format_tag::ldgOI16o4i, format_tag::ldgOI64o2i,
format_tag::ldgOI64o4i, format_tag::ldgOi16o);
return md_format_tag != format_tag::undef;
}

Expand Down Expand Up @@ -293,7 +293,8 @@ status_t rnn_utils::set_expected_desc(rnn_conf_t &rnn,
} else if (rnn.is_fwd) {
if (rnn.is_int8_conf())
tag = utils::map(n_block, format_tag::undef, 64,
format_tag::ldgOI64o4i, 32, ldgOI32o4i);
format_tag::ldgOI64o4i, 32, ldgOI32o4i, 16,
ldgOI16o4i);
else if (rnn.is_xf16_conf())
tag = utils::map(n_block, format_tag::undef, 64,
format_tag::ldgOI64o2i, 32, ldgOI32o2i);
Expand Down
5 changes: 3 additions & 2 deletions src/cpu/x64/rnn/rnn_brgemm_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2021-2024 Intel Corporation
* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -85,7 +85,8 @@ x64::cpu_isa_t brgemm_calc_isa(

if (rnn.is_cell_dt_int8()) {
return utils::map(true, x64::isa_undef, mayiuse(avx512_core_vnni),
avx512_core_vnni, mayiuse(avx512_core), avx512_core);
avx512_core_vnni, mayiuse(avx512_core), avx512_core,
mayiuse(avx2), avx2);
} else if (rnn.is_cell_dt_bf16()) {
return x64::avx512_core_bf16;
} else if (rnn.is_cell_dt_f16()) {
Expand Down

0 comments on commit 188ddeb

Please sign in to comment.