From 25596d25116d3fd523f1ac5e32e44cb5e8295a9e Mon Sep 17 00:00:00 2001 From: Peter Caday Date: Fri, 31 May 2024 10:30:55 -0700 Subject: [PATCH] common: sdpa: enable ONEDNN_ENABLE_PRIMITIVE selection --- cmake/configuring_primitive_list.cmake | 2 +- cmake/options.cmake | 4 ++-- doc/build/build_options.md | 10 +++++----- include/oneapi/dnnl/dnnl_config.h.in | 1 + src/common/impl_registration.hpp | 7 +++++++ src/gpu/gpu_sdpa_list.cpp | 4 ++-- 6 files changed, 18 insertions(+), 10 deletions(-) diff --git a/cmake/configuring_primitive_list.cmake b/cmake/configuring_primitive_list.cmake index 7cdf065e691..3524f171070 100644 --- a/cmake/configuring_primitive_list.cmake +++ b/cmake/configuring_primitive_list.cmake @@ -32,7 +32,7 @@ else() foreach(impl ${DNNL_ENABLE_PRIMITIVE}) string(TOUPPER ${impl} uimpl) if(NOT "${uimpl}" MATCHES - "^(BATCH_NORMALIZATION|BINARY|CONCAT|CONVOLUTION|DECONVOLUTION|ELTWISE|INNER_PRODUCT|LAYER_NORMALIZATION|LRN|MATMUL|POOLING|PRELU|REDUCTION|REORDER|RESAMPLING|RNN|SHUFFLE|SOFTMAX|SUM)$") + "^(BATCH_NORMALIZATION|BINARY|CONCAT|CONVOLUTION|DECONVOLUTION|ELTWISE|INNER_PRODUCT|LAYER_NORMALIZATION|LRN|MATMUL|POOLING|PRELU|REDUCTION|REORDER|RESAMPLING|RNN|SDPA|SHUFFLE|SOFTMAX|SUM)$") message(FATAL_ERROR "Unsupported primitive: ${uimpl}") endif() set(BUILD_${uimpl} TRUE) diff --git a/cmake/options.cmake b/cmake/options.cmake index 8c314507cc9..4e22f782052 100644 --- a/cmake/options.cmake +++ b/cmake/options.cmake @@ -124,8 +124,8 @@ set(DNNL_ENABLE_PRIMITIVE "ALL" CACHE STRING - . Includes only the selected primitive to be enabled. Possible values are: BATCH_NORMALIZATION, BINARY, CONCAT, CONVOLUTION, DECONVOLUTION, ELTWISE, INNER_PRODUCT, LAYER_NORMALIZATION, LRN, MATMUL, - POOLING, PRELU, REDUCTION, REORDER, RESAMPLING, RNN, SHUFFLE, SOFTMAX, - SUM. + POOLING, PRELU, REDUCTION, REORDER, RESAMPLING, RNN, SDPA, SHUFFLE, + SOFTMAX, SUM. - ;;... Includes only selected primitives to be enabled at build time. This is treated as CMake string, thus, semicolon is a mandatory delimiter between names. This is the way to specify several diff --git a/doc/build/build_options.md b/doc/build/build_options.md index 7632809c16d..3ffe850d7d0 100644 --- a/doc/build/build_options.md +++ b/doc/build/build_options.md @@ -89,11 +89,11 @@ This option supports several values: `ALL` (the default) which enables all primitives implementations or a set of `BATCH_NORMALIZATION`, `BINARY`, `CONCAT`, `CONVOLUTION`, `DECONVOLUTION`, `ELTWISE`, `INNER_PRODUCT`, `LAYER_NORMALIZATION`, `LRN`, `MATMUL`, `POOLING`, `PRELU`, `REDUCTION`, -`REORDER`, `RESAMPLING`, `RNN`, `SHUFFLE`, `SOFTMAX`, `SUM`. When a set is used, -only those selected primitives implementations will be available. Attempting to -use other primitive implementations will end up returning an unimplemented -status when creating primitive descriptor. In order to specify a set, a -CMake-style string should be used, with semicolon delimiters, as in this +`REORDER`, `RESAMPLING`, `RNN`, `SDPA`, `SHUFFLE`, `SOFTMAX`, `SUM`. When a set +is used, only those selected primitives implementations will be available. +Attempting to use other primitive implementations will end up returning an +unimplemented status when creating primitive descriptor. In order to specify a +set, a CMake-style string should be used, with semicolon delimiters, as in this example: ``` -DONEDNN_ENABLE_PRIMITIVE=CONVOLUTION;MATMUL;REORDER diff --git a/include/oneapi/dnnl/dnnl_config.h.in b/include/oneapi/dnnl/dnnl_config.h.in index 18fa5cffdec..2cf99534aca 100644 --- a/include/oneapi/dnnl/dnnl_config.h.in +++ b/include/oneapi/dnnl/dnnl_config.h.in @@ -192,6 +192,7 @@ #cmakedefine01 BUILD_REORDER #cmakedefine01 BUILD_RESAMPLING #cmakedefine01 BUILD_RNN +#cmakedefine01 BUILD_SDPA #cmakedefine01 BUILD_SHUFFLE #cmakedefine01 BUILD_SOFTMAX #cmakedefine01 BUILD_SUM diff --git a/src/common/impl_registration.hpp b/src/common/impl_registration.hpp index ac270224087..3ab993cb353 100644 --- a/src/common/impl_registration.hpp +++ b/src/common/impl_registration.hpp @@ -173,6 +173,13 @@ {} #endif +#if BUILD_PRIMITIVE_ALL || BUILD_SDPA +#define REG_SDPA_P(...) __VA_ARGS__ +#else +#define REG_SDPA_P(...) \ + {} +#endif + #if BUILD_PRIMITIVE_ALL || BUILD_SHUFFLE #define REG_SHUFFLE_P(...) __VA_ARGS__ #else diff --git a/src/gpu/gpu_sdpa_list.cpp b/src/gpu/gpu_sdpa_list.cpp index 8195fbb10d8..1c9230f95c1 100644 --- a/src/gpu/gpu_sdpa_list.cpp +++ b/src/gpu/gpu_sdpa_list.cpp @@ -28,11 +28,11 @@ namespace gpu { namespace { // clang-format off -constexpr impl_list_item_t impl_list[] = { +constexpr impl_list_item_t impl_list[] = REG_SDPA_P({ GPU_INSTANCE_INTEL(intel::ocl::micro_sdpa_t) GPU_INSTANCE_INTEL_DEVMODE(intel::ocl::ref_sdpa_t) nullptr, -}; +}); // clang-format on } // namespace