Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix LSTM INT8 for AVX2 #2575

Merged
merged 1 commit into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/oneapi/dnnl/dnnl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1463,6 +1463,7 @@ struct memory : public handle<dnnl_memory_t> {
AB8a2b = dnnl_AB8a2b,
abDc16d = dnnl_abDc16d,
abDc32d = dnnl_abDc32d,
abDC16d4c = dnnl_abDC16d4c,
abDC32d4c = dnnl_abDC32d4c,
abCd32c = dnnl_abCd32c,
abdEc16e = dnnl_abdEc16e,
Expand Down Expand Up @@ -1959,6 +1960,7 @@ struct memory : public handle<dnnl_memory_t> {

ldOi16o = abDc16d,
ldOi32o = abDc32d,
ldOI16o4i = abDC16d4c,
ldOI32o4i = abDC32d4c,
ldgOi16o = abdEc16e,
ldgOI16o4i = abdEC16e4c,
Expand Down
2 changes: 2 additions & 0 deletions include/oneapi/dnnl/dnnl_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,7 @@ typedef enum {
dnnl_dabc,
dnnl_Ab32a,
dnnl_abdEC16e4c,
dnnl_abDC16d4c,

/// Just a sentinel, not real memory format tag. Must be changed after new
/// format tag is added.
Expand Down Expand Up @@ -1174,6 +1175,7 @@ typedef enum {
/// 5D LSTM projection tensor
dnnl_ldOi16o = dnnl_abDc16d,
dnnl_ldOi32o = dnnl_abDc32d,
dnnl_ldOI16o4i = dnnl_abDC16d4c,
dnnl_ldOI32o4i = dnnl_abDC32d4c,
dnnl_ldIo32i = dnnl_abCd32c,
/// 6D RNN weights tensor
Expand Down
2 changes: 2 additions & 0 deletions src/common/c_types_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,7 @@ const format_tag_t AB32a32b8a2b = dnnl_AB32a32b8a2b;
const format_tag_t AB8a2b = dnnl_AB8a2b;
const format_tag_t abDc16d = dnnl_abDc16d;
const format_tag_t abDc32d = dnnl_abDc32d;
const format_tag_t abDC16d4c = dnnl_abDC16d4c;
const format_tag_t abDC32d4c = dnnl_abDC32d4c;
const format_tag_t abCd4c = dnnl_abCd4c;
const format_tag_t abCde4c = dnnl_abCde4c;
Expand Down Expand Up @@ -1459,6 +1460,7 @@ const format_tag_t gOIhw4o8i2o = dnnl_gOIhw4o8i2o;
const format_tag_t gOIdhw4o8i2o = dnnl_gOIdhw4o8i2o;
const format_tag_t ldOi16o = dnnl_ldOi16o;
const format_tag_t ldOi32o = dnnl_ldOi32o;
const format_tag_t ldOI16o4i = dnnl_ldOI16o4i;
const format_tag_t ldOI32o4i = dnnl_ldOI32o4i;
const format_tag_t ldIo32i = dnnl_ldIo32i;
const format_tag_t ldgOi16o = dnnl_ldgOi16o;
Expand Down
2 changes: 2 additions & 0 deletions src/common/dnnl_debug_autogenerated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
if (v == dnnl_AB32a32b8a2b) return "AB32a32b8a2b";
if (v == dnnl_AB8a2b) return "AB8a2b";
if (v == dnnl_abDc32d) return "abDc32d";
if (v == dnnl_abDC16d4c) return "abDC16d4c";
if (v == dnnl_abDC32d4c) return "abDC32d4c";
if (v == dnnl_abdEc32e) return "abdEc32e";
if (v == dnnl_abdEC16e4c) return "abdEC16e4c";
Expand Down Expand Up @@ -996,6 +997,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
if (v == dnnl_ldgo) return "ldgo";
if (v == dnnl_ldOi16o) return "ldOi16o";
if (v == dnnl_ldOi32o) return "ldOi32o";
if (v == dnnl_ldOI16o4i) return "ldOI16o4i";
if (v == dnnl_ldOI32o4i) return "ldOI32o4i";
if (v == dnnl_ldIo32i) return "ldIo32i";
if (v == dnnl_ldgOi16o) return "ldgOi16o";
Expand Down
1 change: 1 addition & 0 deletions src/common/memory_desc_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,7 @@ status_t memory_desc_wrapper::compute_blocking(
C(AB8a2b, {0, 1}, {8, 2}, {0, 1});
C(abDc16d, {0, 1, 3, 2}, {16}, {3});
C(abDc32d, {0, 1, 3, 2}, {32}, {3});
C(abDC16d4c, {0, 1, 3, 2}, {16, 4}, {3, 2});
C(abDC32d4c, {0, 1, 3, 2}, {32, 4}, {3, 2});
C(abCd4c, {0, 1, 2, 3}, {4}, {2});
C(abCde4c, {0, 1, 2, 3, 4}, {4}, {2});
Expand Down
2 changes: 2 additions & 0 deletions src/common/tag_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ enum class inner_blk_t {
_24b4c,
_24c2b,
_24c4b,
_16d4c,
_32d4c,
_32e2c,
_32e4c,
Expand Down Expand Up @@ -816,6 +817,7 @@ DECL_TRAITS(aBCde4c8b2c, _BC, _4c8b2c, 5);
DECL_TRAITS(aBCdef4c8b2c, _BC, _4c8b2c, 6);
DECL_TRAITS(abDc16d, _D, _16d, 4);
DECL_TRAITS(abDc32d, _D, _32d, 4);
DECL_TRAITS(abDC16d4c, _CD, _16d4c, 4);
DECL_TRAITS(abDC32d4c, _CD, _32d4c, 4);
DECL_TRAITS(abCd32c, _C, _32c, 4);
DECL_TRAITS(abCde32c, _C, _32c, 5);
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/rnn/rnn_reorders.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ struct rnn_brgemm_weights_reorder_s8_t : public primitive_t {

itag = id.matches_one_of_tag(ldigo, ldio);
otag = od.matches_one_of_tag(
ldgOI64o4i, ldgOI32o4i, ldgOI16o4i, ldOI32o4i);
ldgOI64o4i, ldgOI32o4i, ldgOI16o4i, ldOI32o4i, ldOI16o4i);
if (itag != format_tag::undef && otag != format_tag::undef) {
_pd->itag_ = itag;
_pd->otag_ = otag;
Expand Down
7 changes: 4 additions & 3 deletions src/cpu/rnn/rnn_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ bool rnn_utils::is_ldgoi_blocked(const memory_desc_wrapper &mdw) {
}

bool rnn_utils::is_ldio_blocked(const memory_desc_wrapper &mdw) {
format_tag_t md_format_tag = mdw.matches_one_of_tag(
format_tag::ldOi32o, format_tag::ldOI32o4i, format_tag::ldOi16o);
format_tag_t md_format_tag = mdw.matches_one_of_tag(format_tag::ldOi32o,
format_tag::ldOI32o4i, ldOI16o4i, format_tag::ldOi16o);
return md_format_tag != format_tag::undef;
}

Expand Down Expand Up @@ -286,7 +286,8 @@ status_t rnn_utils::set_expected_desc(rnn_conf_t &rnn,

if (weights_type == weights_type_t::projection) {
if (rnn.is_int8_conf())
tag = format_tag::ldOI32o4i;
tag = utils::map(n_block, format_tag::undef, 32,
format_tag::ldOI32o4i, 16, format_tag::ldOI16o4i);
else
tag = utils::map(n_block, format_tag::undef, 32,
format_tag::ldOi32o, 16, format_tag::ldOi16o);
Expand Down
Loading