diff --git a/src/common/sdpa_internal.h b/src/common/sdpa_internal.h deleted file mode 100644 index 2a3c26128f7..00000000000 --- a/src/common/sdpa_internal.h +++ /dev/null @@ -1,56 +0,0 @@ -/******************************************************************************* -* Copyright 2024 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef COMMON_SDPA_INTERNAL_H -#define COMMON_SDPA_INTERNAL_H - -#include - -#define DNNL_ARG_QUERIES DNNL_ARG_SRC_0 -#define DNNL_ARG_KEYS DNNL_ARG_SRC_1 -#define DNNL_ARG_VALUES DNNL_ARG_SRC_2 -#define DNNL_ARG_ATTN_MASK DNNL_ARG_SHIFT - -/// @addtogroup dnnl_api_sdpa -/// @{ - -/// Creates a primitive descriptor for a scaled dot product attention primitive -/// -/// @param primitive_desc Output primitive descriptor. -/// @param engine Engine to use. -/// @param query_desc Query memory descriptor (tensor Q) -/// @param key_desc Key memory descriptor (tensor K) -/// @param value_desc Value memory descriptor (tensor V) -/// @param dst_desc Destination memory descriptor. -/// @param attn_mask_desc Attention mask memory descriptor. -/// @param attr Primitive attributes (can be NULL). -/// @param kq_attr Attribute for the Key/Query matmul operation(can be NULL). -/// @param vs_attr Attribute for the Value/Score matmul operation(can be NULL). -/// @returns #dnnl_success on success and a status describing the error -/// otherwise. 
- -dnnl_status_t DNNL_API dnnl_sdpa_primitive_desc_create( - dnnl_primitive_desc_t *primitive_desc_iface, dnnl_engine_t engine, - const_dnnl_memory_desc_t query_desc, const_dnnl_memory_desc_t key_desc, - const_dnnl_memory_desc_t value_desc, const_dnnl_memory_desc_t dst_desc, - const_dnnl_memory_desc_t mask_desc, dnnl_data_type_t scale_dt, - bool invert_scale, dnnl_dim_t kv_head_number, bool causal_mask, - const_dnnl_primitive_attr_t attr, const_dnnl_primitive_attr_t kq_attr, - const_dnnl_primitive_attr_t vs_attr); - -/// @} dnnl_api_sdpa - -#endif diff --git a/src/common/sdpa_internal.hpp b/src/common/sdpa_internal.hpp deleted file mode 100644 index c7f12231104..00000000000 --- a/src/common/sdpa_internal.hpp +++ /dev/null @@ -1,154 +0,0 @@ -/******************************************************************************* -* Copyright 2024 Intel Corporation -* -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*******************************************************************************/ - -#ifndef COMMON_SDPA_INTERNAL_HPP -#define COMMON_SDPA_INTERNAL_HPP - -#include "common/sdpa_internal.h" -#include "dnnl.hpp" - -namespace dnnl { -namespace impl { - -/// Scaled Dot Product Attention (sdpa) primitive. -struct sdpa : public dnnl::primitive { - /// Primitive descriptor for a sdpa primitive. - struct primitive_desc : public dnnl::primitive_desc { - /// Default constructor. Produces an empty object. 
- primitive_desc() = default; - - /// Constructs a primitive descriptor for a sdpa primitive - /// - /// @param aengine Engine to use. - /// @param query_desc Memory descriptor for query tensor. - /// @param key_desc Memory descriptor for key tensor. - /// @param value_desc Memory descriptor for value tensor. - /// @param output_desc Memory descriptor for output tensor. - /// @param attr Primitive attributes to use. Attributes are optional - /// and default to empty attributes. - /// @param kq_attr Primitive attributes to use. Attributes are optional - /// and default to empty attributes. - /// @param vs_attr Primitive attributes to use. Attributes are optional - /// and default to empty attributes. - primitive_desc(const engine &aengine, const memory::desc &query_desc, - const memory::desc &key_desc, const memory::desc &value_desc, - const memory::desc &output_desc, - const primitive_attr &attr = default_attr(), - const primitive_attr &kq_attr = default_attr(), - const primitive_attr &vs_attr = default_attr()) - : primitive_desc(aengine, query_desc, key_desc, value_desc, nullptr, - memory::data_type::undef, output_desc, false, 1, false, - attr, kq_attr, vs_attr) {} - - /// Constructs a primitive descriptor for a sdpa primitive - /// - /// @param aengine Engine to use. - /// @param query_desc Memory descriptor for query tensor. - /// @param key_desc Memory descriptor for key tensor. - /// @param value_desc Memory descriptor for value tensor. - /// @param output_desc Memory descriptor for output tensor. - /// @param attn_mask_desc Memory descriptor for attention mask tensor. - /// @param attr Primitive attributes to use. Attributes are optional - /// and default to empty attributes. 
- primitive_desc(const engine &aengine, const memory::desc &query_desc, - const memory::desc &key_desc, const memory::desc &value_desc, - const memory::desc &attn_mask_desc, - const memory::desc &output_desc, - const primitive_attr &attr = default_attr(), - const primitive_attr &kq_attr = default_attr(), - const primitive_attr &vs_attr = default_attr()) - : primitive_desc(aengine, query_desc, key_desc, value_desc, - &attn_mask_desc, memory::data_type::undef, output_desc, - false, 1, false, attr, kq_attr, vs_attr) {} - - /// Constructs a primitive descriptor for a sdpa primitive from a C - /// API primitive descriptor that must have a matching kind. - /// - /// @param pd C API primitive descriptor for a sdpa primitive. - primitive_desc(dnnl_primitive_desc_t pd) - : dnnl::primitive_desc(pd, dnnl::primitive::kind::undef) {} - - /// @copydoc dnnl::primitive_desc_base::src_desc()const - memory::desc query_desc() const { - return query_md((dnnl::query)query::src_md, 0); - } - - /// @copydoc dnnl::primitive_desc_base::weights_desc()const - memory::desc key_desc() const { - return query_md((dnnl::query)query::src_md, 1); - } - - /// @copydoc dnnl::primitive_desc_base::weights_desc()const - memory::desc value_desc() const { - return query_md((dnnl::query)query::src_md, 2); - } - - /// @copydoc dnnl::primitive_desc_base::weights_desc()const - memory::desc attn_mask_desc() const { - return query_md((dnnl::query)query::src_md, 3); - } - - /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const - memory::desc bias_desc() const { - return query_md((dnnl::query)query::weights_md, 1); - } - - ///// @copydoc dnnl::primitive_desc_base::dst_desc()const - memory::desc dst_desc() const { - return query_md((dnnl::query)query::dst_md, 0); - } - - primitive_desc(const engine &aengine, const memory::desc &query_desc, - const memory::desc &key_desc, const memory::desc &value_desc, - const memory::desc *attn_mask_desc, memory::data_type scale_dt, - const memory::desc 
&output_desc, bool invert_scale, - dnnl_dim_t kv_head_number, bool causal_mask, - const primitive_attr &attr, - const primitive_attr &kq_attr = default_attr(), - const primitive_attr &vs_attr = default_attr()) { - - dnnl_primitive_desc_t pd = nullptr; - dnnl_status_t status = dnnl_sdpa_primitive_desc_create(&pd, - aengine.get(), query_desc.get(), key_desc.get(), - value_desc.get(), output_desc.get(), - optional_arg(attn_mask_desc), (dnnl_data_type_t)scale_dt, - invert_scale, kv_head_number, causal_mask, attr.get(), - kq_attr.get(), vs_attr.get()); - - dnnl::error::wrap_c_api(status, - "could not create a primitive descriptor for a sdpa " - "primitive"); - reset(pd); - } - }; - - /// Default constructor. Produces an empty object. - sdpa() = default; - - /// Constructs a sdpa primitive. - /// @param pd Primitive descriptor for a sdpa primitive. - sdpa(const primitive_desc &pd) : primitive(pd) {} - - /// Constructs a sdpa primitive from a cache blob. - /// @param pd Primitive descriptor for a sdpa primitive. - /// @param cache_blob Cache blob. - sdpa(const primitive_desc &pd, const std::vector &cache_blob) - : primitive(pd, cache_blob) {} -}; -} // namespace impl -} // namespace dnnl - -#endif diff --git a/src/common/sdpa.cpp b/src/common/sdpa_test_iface.cpp similarity index 93% rename from src/common/sdpa.cpp rename to src/common/sdpa_test_iface.cpp index fcb3cee200c..ed7e4a60752 100644 --- a/src/common/sdpa.cpp +++ b/src/common/sdpa_test_iface.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2024 Intel Corporation +* Copyright 2025 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,8 @@ * limitations under the License. 
*******************************************************************************/ -#include "c_types_map.hpp" +#include "common/c_types_map.hpp" #include "common/primitive_desc_iface.hpp" -#include "common/sdpa_internal.hpp" #include "common/sdpa_pd.hpp" #include "common/sdpa_types.hpp" #include "common/sdpa_utils.hpp" @@ -25,7 +24,7 @@ using dnnl::impl::status_t; using namespace dnnl::impl; -dnnl_status_t dnnl_sdpa_primitive_desc_create( +dnnl_status_t DNNL_API sdpa_primitive_desc_create( dnnl_primitive_desc_t *primitive_desc_iface, dnnl_engine_t engine, const_dnnl_memory_desc_t query_desc, const_dnnl_memory_desc_t key_desc, const_dnnl_memory_desc_t value_desc, const_dnnl_memory_desc_t dst_desc, diff --git a/tests/gtests/internals/sdpa_internal.hpp b/tests/gtests/internals/sdpa_internal.hpp new file mode 100644 index 00000000000..609db91e15c --- /dev/null +++ b/tests/gtests/internals/sdpa_internal.hpp @@ -0,0 +1,91 @@ +/******************************************************************************* +* Copyright 2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef COMMON_SDPA_INTERNAL_HPP +#define COMMON_SDPA_INTERNAL_HPP + +#include "dnnl.hpp" + +/// Creates a primitive descriptor for a scaled dot product attention primitive +/// +/// @param primitive_desc_iface Output primitive descriptor. +/// @param engine Engine to use. 
+/// @param query_desc Query memory descriptor (tensor Q). +/// @param key_desc Key memory descriptor (tensor K). +/// @param value_desc Value memory descriptor (tensor V). +/// @param dst_desc Destination memory descriptor. +/// @param mask_desc Attention mask memory descriptor. +/// @param attr Primitive attributes (can be NULL). +/// @param kq_attr Attribute for the Key/Query matmul operation (can be NULL). +/// @param vs_attr Attribute for the Value/Score matmul operation (can be NULL). +/// @returns #dnnl_success on success and a status describing the error +/// otherwise. + +dnnl_status_t DNNL_API sdpa_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc_iface, dnnl_engine_t engine, + const_dnnl_memory_desc_t query_desc, const_dnnl_memory_desc_t key_desc, + const_dnnl_memory_desc_t value_desc, const_dnnl_memory_desc_t dst_desc, + const_dnnl_memory_desc_t mask_desc, dnnl_data_type_t scale_dt, + bool invert_scale, dnnl_dim_t kv_head_number, bool causal_mask, + const_dnnl_primitive_attr_t attr, const_dnnl_primitive_attr_t kq_attr, + const_dnnl_primitive_attr_t vs_attr); + +namespace dnnl { +namespace impl { + +/// Scaled Dot Product Attention (sdpa) internal primitive. +/// Implemented internally to allow more flexible validation. +struct sdpa : public dnnl::primitive { + /// Primitive descriptor for a sdpa primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. 
+ primitive_desc() = default; + + primitive_desc(const engine &aengine, const memory::desc &query_desc, + const memory::desc &key_desc, const memory::desc &value_desc, + const memory::desc *attn_mask_desc, memory::data_type scale_dt, + const memory::desc &output_desc, bool invert_scale, + memory::dim kv_head_number, bool causal_mask, + const primitive_attr &attr = default_attr(), + const primitive_attr &kq_attr = default_attr(), + const primitive_attr &vs_attr = default_attr()) { + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status = sdpa_primitive_desc_create(&pd, + aengine.get(), query_desc.get(), key_desc.get(), + value_desc.get(), output_desc.get(), + optional_arg(attn_mask_desc), (dnnl_data_type_t)scale_dt, + invert_scale, kv_head_number, causal_mask, attr.get(), + kq_attr.get(), vs_attr.get()); + + dnnl::error::wrap_c_api(status, + "could not create a primitive descriptor for a sdpa " + "primitive"); + reset(pd); + } + }; + + /// Default constructor. Produces an empty object. + sdpa() = default; + + /// Constructs a sdpa primitive. + /// @param pd Primitive descriptor for a sdpa primitive. + sdpa(const primitive_desc &pd) : primitive(pd) {} +}; +} // namespace impl +} // namespace dnnl + +#endif diff --git a/tests/gtests/internals/test_sdpa.cpp b/tests/gtests/internals/test_sdpa.cpp index 409bbdb498a..ff61c3a08c8 100644 --- a/tests/gtests/internals/test_sdpa.cpp +++ b/tests/gtests/internals/test_sdpa.cpp @@ -17,7 +17,7 @@ #include #include -#include +#include "sdpa_internal.hpp" #include