Skip to content

Commit

Permalink
tests: gtests: graph: add mqa decompose tests
Browse files Browse the repository at this point in the history
  • Loading branch information
xiang1guo authored and TaoLv committed Apr 30, 2024
1 parent 2a67ffe commit 64f6f9c
Show file tree
Hide file tree
Showing 3 changed files with 433 additions and 1 deletion.
1 change: 1 addition & 0 deletions tests/gtests/graph/unit/backend/dnnl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ FILE(GLOB DNNL_OP_EXECUTION_TEST_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/test_large_partition.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_layer_norm.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_mqa_decomp.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_pool.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_prelu.cpp
${CMAKE_CURRENT_SOURCE_DIR}/test_quantize.cpp
Expand Down
16 changes: 15 additions & 1 deletion tests/gtests/graph/unit/backend/dnnl/test_large_partition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,15 @@
namespace graph = dnnl::impl::graph;
namespace utils = dnnl::graph::tests::unit::utils;

static inline void custom_setenv(
const char *name, const char *value, int overwrite) {
#ifdef _WIN32
SetEnvironmentVariable(name, value);
#else
::setenv(name, value, overwrite);
#endif
}

static void fill_data(
std::vector<float> &buffer, dnnl::impl::data_type_t dtype) {
if (dtype == dnnl::impl::data_type::u8) {
Expand Down Expand Up @@ -905,9 +914,14 @@ TEST(test_large_partition_execute, F32JaxMqa) {
outputs.emplace_back(&lt);
}

// Enable large partition test
custom_setenv("_ONEDNN_ENABLE_SDP_DECOMP", "0", 1);
graph::compiled_partition_t cp(p);
ASSERT_EQ(p.compile(&cp, inputs, outputs, eng), graph::status::success);

// Set back to avoid affecting other tests
custom_setenv("_ONEDNN_ENABLE_SDP_DECOMP", "1", 1);

std::vector<test_tensor> inputs_ts, outputs_ts;

for (auto &lt : inputs) {
Expand Down Expand Up @@ -1323,4 +1337,4 @@ TEST(test_large_partition_execute, Int8Bf16GptMha_CPU) {
test_tensor::to_graph_tensor(outputs_ts)),
graph::status::success);
strm->wait();
}
}
Loading

0 comments on commit 64f6f9c

Please sign in to comment.