Skip to content

Commit

Permalink
cpu: rnn: use correct format for wei_proj
Browse files Browse the repository at this point in the history
  • Loading branch information
densamoilov committed Jan 31, 2025
1 parent efe1c13 commit 43a2dbb
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 4 deletions.
2 changes: 2 additions & 0 deletions include/oneapi/dnnl/dnnl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1463,6 +1463,7 @@ struct memory : public handle<dnnl_memory_t> {
AB8a2b = dnnl_AB8a2b,
abDc16d = dnnl_abDc16d,
abDc32d = dnnl_abDc32d,
abDC16d4c = dnnl_abDC16d4c,
abDC32d4c = dnnl_abDC32d4c,
abCd32c = dnnl_abCd32c,
abdEc16e = dnnl_abdEc16e,
Expand Down Expand Up @@ -1959,6 +1960,7 @@ struct memory : public handle<dnnl_memory_t> {

ldOi16o = abDc16d,
ldOi32o = abDc32d,
ldOI16o4i = abDC16d4c,
ldOI32o4i = abDC32d4c,
ldgOi16o = abdEc16e,
ldgOI16o4i = abdEC16e4c,
Expand Down
2 changes: 2 additions & 0 deletions include/oneapi/dnnl/dnnl_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,7 @@ typedef enum {
dnnl_dabc,
dnnl_Ab32a,
dnnl_abdEC16e4c,
dnnl_abDC16d4c,

/// Just a sentinel, not real memory format tag. Must be changed after new
/// format tag is added.
Expand Down Expand Up @@ -1174,6 +1175,7 @@ typedef enum {
/// 5D LSTM projection tensor
dnnl_ldOi16o = dnnl_abDc16d,
dnnl_ldOi32o = dnnl_abDc32d,
dnnl_ldOI16o4i = dnnl_abDC16d4c,
dnnl_ldOI32o4i = dnnl_abDC32d4c,
dnnl_ldIo32i = dnnl_abCd32c,
/// 6D RNN weights tensor
Expand Down
2 changes: 2 additions & 0 deletions src/common/c_types_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,7 @@ const format_tag_t AB32a32b8a2b = dnnl_AB32a32b8a2b;
const format_tag_t AB8a2b = dnnl_AB8a2b;
const format_tag_t abDc16d = dnnl_abDc16d;
const format_tag_t abDc32d = dnnl_abDc32d;
const format_tag_t abDC16d4c = dnnl_abDC16d4c;
const format_tag_t abDC32d4c = dnnl_abDC32d4c;
const format_tag_t abCd4c = dnnl_abCd4c;
const format_tag_t abCde4c = dnnl_abCde4c;
Expand Down Expand Up @@ -1459,6 +1460,7 @@ const format_tag_t gOIhw4o8i2o = dnnl_gOIhw4o8i2o;
const format_tag_t gOIdhw4o8i2o = dnnl_gOIdhw4o8i2o;
const format_tag_t ldOi16o = dnnl_ldOi16o;
const format_tag_t ldOi32o = dnnl_ldOi32o;
const format_tag_t ldOI16o4i = dnnl_ldOI16o4i;
const format_tag_t ldOI32o4i = dnnl_ldOI32o4i;
const format_tag_t ldIo32i = dnnl_ldIo32i;
const format_tag_t ldgOi16o = dnnl_ldgOi16o;
Expand Down
2 changes: 2 additions & 0 deletions src/common/dnnl_debug_autogenerated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
if (v == dnnl_AB32a32b8a2b) return "AB32a32b8a2b";
if (v == dnnl_AB8a2b) return "AB8a2b";
if (v == dnnl_abDc32d) return "abDc32d";
if (v == dnnl_abDC16d4c) return "abDC16d4c";
if (v == dnnl_abDC32d4c) return "abDC32d4c";
if (v == dnnl_abdEc32e) return "abdEc32e";
if (v == dnnl_abdEC16e4c) return "abdEC16e4c";
Expand Down Expand Up @@ -996,6 +997,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
if (v == dnnl_ldgo) return "ldgo";
if (v == dnnl_ldOi16o) return "ldOi16o";
if (v == dnnl_ldOi32o) return "ldOi32o";
if (v == dnnl_ldOI16o4i) return "ldOI16o4i";
if (v == dnnl_ldOI32o4i) return "ldOI32o4i";
if (v == dnnl_ldIo32i) return "ldIo32i";
if (v == dnnl_ldgOi16o) return "ldgOi16o";
Expand Down
1 change: 1 addition & 0 deletions src/common/memory_desc_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,7 @@ status_t memory_desc_wrapper::compute_blocking(
C(AB8a2b, {0, 1}, {8, 2}, {0, 1});
C(abDc16d, {0, 1, 3, 2}, {16}, {3});
C(abDc32d, {0, 1, 3, 2}, {32}, {3});
C(abDC16d4c, {0, 1, 3, 2}, {16, 4}, {3, 2});
C(abDC32d4c, {0, 1, 3, 2}, {32, 4}, {3, 2});
C(abCd4c, {0, 1, 2, 3}, {4}, {2});
C(abCde4c, {0, 1, 2, 3, 4}, {4}, {2});
Expand Down
2 changes: 2 additions & 0 deletions src/common/tag_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ enum class inner_blk_t {
_24b4c,
_24c2b,
_24c4b,
_16d4c,
_32d4c,
_32e2c,
_32e4c,
Expand Down Expand Up @@ -816,6 +817,7 @@ DECL_TRAITS(aBCde4c8b2c, _BC, _4c8b2c, 5);
DECL_TRAITS(aBCdef4c8b2c, _BC, _4c8b2c, 6);
DECL_TRAITS(abDc16d, _D, _16d, 4);
DECL_TRAITS(abDc32d, _D, _32d, 4);
DECL_TRAITS(abDC16d4c, _CD, _16d4c, 4);
DECL_TRAITS(abDC32d4c, _CD, _32d4c, 4);
DECL_TRAITS(abCd32c, _C, _32c, 4);
DECL_TRAITS(abCde32c, _C, _32c, 5);
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/rnn/rnn_reorders.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,7 @@ struct rnn_brgemm_weights_reorder_s8_t : public primitive_t {

itag = id.matches_one_of_tag(ldigo, ldio);
otag = od.matches_one_of_tag(
ldgOI64o4i, ldgOI32o4i, ldgOI16o4i, ldOI32o4i);
ldgOI64o4i, ldgOI32o4i, ldgOI16o4i, ldOI32o4i, ldOI16o4i);
if (itag != format_tag::undef && otag != format_tag::undef) {
_pd->itag_ = itag;
_pd->otag_ = otag;
Expand Down
7 changes: 4 additions & 3 deletions src/cpu/rnn/rnn_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ bool rnn_utils::is_ldgoi_blocked(const memory_desc_wrapper &mdw) {
}

bool rnn_utils::is_ldio_blocked(const memory_desc_wrapper &mdw) {
format_tag_t md_format_tag = mdw.matches_one_of_tag(
format_tag::ldOi32o, format_tag::ldOI32o4i, format_tag::ldOi16o);
format_tag_t md_format_tag = mdw.matches_one_of_tag(format_tag::ldOi32o,
format_tag::ldOI32o4i, ldOI16o4i, format_tag::ldOi16o);
return md_format_tag != format_tag::undef;
}

Expand Down Expand Up @@ -286,7 +286,8 @@ status_t rnn_utils::set_expected_desc(rnn_conf_t &rnn,

if (weights_type == weights_type_t::projection) {
if (rnn.is_int8_conf())
tag = format_tag::ldOI32o4i;
tag = utils::map(n_block, format_tag::undef, 32,
format_tag::ldOI32o4i, 16, format_tag::ldOI16o4i);
else
tag = utils::map(n_block, format_tag::undef, 32,
format_tag::ldOi32o, 16, format_tag::ldOi16o);
Expand Down

0 comments on commit 43a2dbb

Please sign in to comment.