Add the abDC16d4c blocked format tag (LSTM projection alias ldOI16o4i).

The tag is appended to the C enum after dnnl_abdEC16e4c, directly before the
sentinel comment, and mirrored in the C++ API enum, the internal format_tag
constants, dnnl_fmt_tag2str(), memory_desc_wrapper::compute_blocking() and
the tag traits (new inner block _16d4c).

On the CPU RNN side the tag is accepted by rnn_brgemm_weights_reorder_s8_t
and by rnn_utils::is_ldio_blocked(). rnn_utils::set_expected_desc() no longer
hard-codes ldOI32o4i for int8 projection weights; it selects the tag with
utils::map(n_block, format_tag::undef, 32, ldOI32o4i, 16, ldOI16o4i), the
same shape as the existing non-int8 branch.

Review notes:
  * is_ldio_blocked(): the new tag is written format_tag::ldOI16o4i so that
    it matches the other three arguments of the same call instead of
    depending on a using-directive that this patch does not show.
  * NOTE(review): leading indentation of the hunk lines below was restored
    by convention, and the "index" line of rnn_utils.cpp still carries the
    original post-image hash -- confirm against the tree before applying.

diff --git a/include/oneapi/dnnl/dnnl.hpp b/include/oneapi/dnnl/dnnl.hpp
index aaaf3c95cc7..f20bc045599 100644
--- a/include/oneapi/dnnl/dnnl.hpp
+++ b/include/oneapi/dnnl/dnnl.hpp
@@ -1463,6 +1463,7 @@ struct memory : public handle {
         AB8a2b = dnnl_AB8a2b,
         abDc16d = dnnl_abDc16d,
         abDc32d = dnnl_abDc32d,
+        abDC16d4c = dnnl_abDC16d4c,
         abDC32d4c = dnnl_abDC32d4c,
         abCd32c = dnnl_abCd32c,
         abdEc16e = dnnl_abdEc16e,
@@ -1959,6 +1960,7 @@ struct memory : public handle {
 
         ldOi16o = abDc16d,
         ldOi32o = abDc32d,
+        ldOI16o4i = abDC16d4c,
         ldOI32o4i = abDC32d4c,
         ldgOi16o = abdEc16e,
         ldgOI16o4i = abdEC16e4c,
diff --git a/include/oneapi/dnnl/dnnl_types.h b/include/oneapi/dnnl/dnnl_types.h
index c8364f18010..40f4e269ec5 100644
--- a/include/oneapi/dnnl/dnnl_types.h
+++ b/include/oneapi/dnnl/dnnl_types.h
@@ -1039,6 +1039,7 @@ typedef enum {
     dnnl_dabc,
     dnnl_Ab32a,
     dnnl_abdEC16e4c,
+    dnnl_abDC16d4c,
 
     /// Just a sentinel, not real memory format tag. Must be changed after new
     /// format tag is added.
@@ -1174,6 +1175,7 @@ typedef enum {
     /// 5D LSTM projection tensor
     dnnl_ldOi16o = dnnl_abDc16d,
     dnnl_ldOi32o = dnnl_abDc32d,
+    dnnl_ldOI16o4i = dnnl_abDC16d4c,
     dnnl_ldOI32o4i = dnnl_abDC32d4c,
     dnnl_ldIo32i = dnnl_abCd32c,
     /// 6D RNN weights tensor
diff --git a/src/common/c_types_map.hpp b/src/common/c_types_map.hpp
index f3ccc494c8d..13e25f9c563 100644
--- a/src/common/c_types_map.hpp
+++ b/src/common/c_types_map.hpp
@@ -693,6 +693,7 @@ const format_tag_t AB32a32b8a2b = dnnl_AB32a32b8a2b;
 const format_tag_t AB8a2b = dnnl_AB8a2b;
 const format_tag_t abDc16d = dnnl_abDc16d;
 const format_tag_t abDc32d = dnnl_abDc32d;
+const format_tag_t abDC16d4c = dnnl_abDC16d4c;
 const format_tag_t abDC32d4c = dnnl_abDC32d4c;
 const format_tag_t abCd4c = dnnl_abCd4c;
 const format_tag_t abCde4c = dnnl_abCde4c;
@@ -1459,6 +1460,7 @@ const format_tag_t gOIhw4o8i2o = dnnl_gOIhw4o8i2o;
 const format_tag_t gOIdhw4o8i2o = dnnl_gOIdhw4o8i2o;
 const format_tag_t ldOi16o = dnnl_ldOi16o;
 const format_tag_t ldOi32o = dnnl_ldOi32o;
+const format_tag_t ldOI16o4i = dnnl_ldOI16o4i;
 const format_tag_t ldOI32o4i = dnnl_ldOI32o4i;
 const format_tag_t ldIo32i = dnnl_ldIo32i;
 const format_tag_t ldgOi16o = dnnl_ldgOi16o;
diff --git a/src/common/dnnl_debug_autogenerated.cpp b/src/common/dnnl_debug_autogenerated.cpp
index 0f2ba6bd440..8de18f70cf1 100644
--- a/src/common/dnnl_debug_autogenerated.cpp
+++ b/src/common/dnnl_debug_autogenerated.cpp
@@ -371,6 +371,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
     if (v == dnnl_AB32a32b8a2b) return "AB32a32b8a2b";
     if (v == dnnl_AB8a2b) return "AB8a2b";
     if (v == dnnl_abDc32d) return "abDc32d";
+    if (v == dnnl_abDC16d4c) return "abDC16d4c";
     if (v == dnnl_abDC32d4c) return "abDC32d4c";
     if (v == dnnl_abdEc32e) return "abdEc32e";
     if (v == dnnl_abdEC16e4c) return "abdEC16e4c";
@@ -996,6 +997,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
     if (v == dnnl_ldgo) return "ldgo";
     if (v == dnnl_ldOi16o) return "ldOi16o";
     if (v == dnnl_ldOi32o) return "ldOi32o";
+    if (v == dnnl_ldOI16o4i) return "ldOI16o4i";
     if (v == dnnl_ldOI32o4i) return "ldOI32o4i";
     if (v == dnnl_ldIo32i) return "ldIo32i";
     if (v == dnnl_ldgOi16o) return "ldgOi16o";
diff --git a/src/common/memory_desc_wrapper.cpp b/src/common/memory_desc_wrapper.cpp
index bcd40b89474..e1a8c6f1d98 100644
--- a/src/common/memory_desc_wrapper.cpp
+++ b/src/common/memory_desc_wrapper.cpp
@@ -612,6 +612,7 @@ status_t memory_desc_wrapper::compute_blocking(
         C(AB8a2b, {0, 1}, {8, 2}, {0, 1});
         C(abDc16d, {0, 1, 3, 2}, {16}, {3});
         C(abDc32d, {0, 1, 3, 2}, {32}, {3});
+        C(abDC16d4c, {0, 1, 3, 2}, {16, 4}, {3, 2});
         C(abDC32d4c, {0, 1, 3, 2}, {32, 4}, {3, 2});
         C(abCd4c, {0, 1, 2, 3}, {4}, {2});
         C(abCde4c, {0, 1, 2, 3, 4}, {4}, {2});
diff --git a/src/common/tag_traits.hpp b/src/common/tag_traits.hpp
index f9857c1a1bf..28f1b64a776 100644
--- a/src/common/tag_traits.hpp
+++ b/src/common/tag_traits.hpp
@@ -121,6 +121,7 @@ enum class inner_blk_t {
     _24b4c,
     _24c2b,
     _24c4b,
+    _16d4c,
     _32d4c,
     _32e2c,
     _32e4c,
@@ -816,6 +817,7 @@ DECL_TRAITS(aBCde4c8b2c, _BC, _4c8b2c, 5);
 DECL_TRAITS(aBCdef4c8b2c, _BC, _4c8b2c, 6);
 DECL_TRAITS(abDc16d, _D, _16d, 4);
 DECL_TRAITS(abDc32d, _D, _32d, 4);
+DECL_TRAITS(abDC16d4c, _CD, _16d4c, 4);
 DECL_TRAITS(abDC32d4c, _CD, _32d4c, 4);
 DECL_TRAITS(abCd32c, _C, _32c, 4);
 DECL_TRAITS(abCde32c, _C, _32c, 5);
diff --git a/src/cpu/rnn/rnn_reorders.hpp b/src/cpu/rnn/rnn_reorders.hpp
index 78858f85c49..63ab1a997d1 100644
--- a/src/cpu/rnn/rnn_reorders.hpp
+++ b/src/cpu/rnn/rnn_reorders.hpp
@@ -803,7 +803,7 @@ struct rnn_brgemm_weights_reorder_s8_t : public primitive_t {
 
             itag = id.matches_one_of_tag(ldigo, ldio);
             otag = od.matches_one_of_tag(
-                    ldgOI64o4i, ldgOI32o4i, ldgOI16o4i, ldOI32o4i);
+                    ldgOI64o4i, ldgOI32o4i, ldgOI16o4i, ldOI32o4i, ldOI16o4i);
             if (itag != format_tag::undef && otag != format_tag::undef) {
                 _pd->itag_ = itag;
                 _pd->otag_ = otag;
diff --git a/src/cpu/rnn/rnn_utils.cpp b/src/cpu/rnn/rnn_utils.cpp
index 2f364d9ce98..a4a51608ad1 100644
--- a/src/cpu/rnn/rnn_utils.cpp
+++ b/src/cpu/rnn/rnn_utils.cpp
@@ -88,8 +88,8 @@ bool rnn_utils::is_ldgoi_blocked(const memory_desc_wrapper &mdw) {
 }
 
 bool rnn_utils::is_ldio_blocked(const memory_desc_wrapper &mdw) {
-    format_tag_t md_format_tag = mdw.matches_one_of_tag(
-            format_tag::ldOi32o, format_tag::ldOI32o4i, format_tag::ldOi16o);
+    format_tag_t md_format_tag = mdw.matches_one_of_tag(format_tag::ldOi32o,
+            format_tag::ldOI32o4i, format_tag::ldOI16o4i, format_tag::ldOi16o);
     return md_format_tag != format_tag::undef;
 }
 
@@ -286,7 +286,8 @@ status_t rnn_utils::set_expected_desc(rnn_conf_t &rnn,
 
         if (weights_type == weights_type_t::projection) {
             if (rnn.is_int8_conf())
-                tag = format_tag::ldOI32o4i;
+                tag = utils::map(n_block, format_tag::undef, 32,
+                        format_tag::ldOI32o4i, 16, format_tag::ldOI16o4i);
             else
                 tag = utils::map(n_block, format_tag::undef, 32,
                         format_tag::ldOi32o, 16, format_tag::ldOi16o);