Skip to content

Commit

Permalink
benchdnn: rnn: fix skip_unimplemented conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
densamoilov committed Jan 28, 2025
1 parent 73118a6 commit 00ba394
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions tests/benchdnn/rnn/rnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -789,9 +789,10 @@ void skip_unimplemented_prb(const prb_t *prb_, res_t *res) {
return;
}
#endif
// cpu backward only supports `any` or `abx` layouts for weights
if (prb.prop == dnnl_backward && prb.tag[1] != tag::abx
&& prb.tag[1] != tag::any) {
const auto wei_tag
= normalize_tag(prb.tag[1], prb.ndims(WEIGHTS_LAYER));
// cpu backward only supports `any` layout for weights.
if (prb.prop == dnnl_backward && wei_tag != tag::any) {
res->state = SKIPPED;
res->reason = skip_reason::case_not_supported;
return;
Expand Down Expand Up @@ -835,13 +836,24 @@ void skip_unimplemented_prb(const prb_t *prb_, res_t *res) {
res->reason = skip_reason::case_not_supported;
return;
}
if (is_cpu()
&& (prb.tag[0] != tag::abx || prb.tag[1] != tag::any
|| prb.tag[2] != tag::abx)) {
res->state = SKIPPED;
res->reason = skip_reason::case_not_supported;
return;

if (is_cpu()) {
const auto src_tag
= normalize_tag(prb.tag[0], prb.ndims(SRC_LAYER));
const auto wei_tag
= normalize_tag(prb.tag[1], prb.ndims(WEIGHTS_LAYER));
const auto dst_tag
= normalize_tag(prb.tag[2], prb.ndims(DST_LAYER));

const bool tags_not_ok = src_tag != "abc" || wei_tag != tag::any
|| dst_tag != "abc";
if (tags_not_ok) {
res->state = SKIPPED;
res->reason = skip_reason::case_not_supported;
return;
}
}

if (is_gpu() && prb.tag[1] != tag::any) {
res->state = SKIPPED;
res->reason = skip_reason::case_not_supported;
Expand Down

0 comments on commit 00ba394

Please sign in to comment.