Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

xe: Add internal tests for SDPA #2549

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions src/common/sdpa_test_iface.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*******************************************************************************
* Copyright 2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/primitive_desc_iface.hpp"
#include "common/sdpa_pd.hpp"
#include "common/sdpa_types.hpp"
#include "common/sdpa_utils.hpp"
#include "opdesc.hpp"

using dnnl::impl::status_t;
using namespace dnnl::impl;

dnnl_status_t DNNL_API sdpa_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc_iface, dnnl_engine_t engine,
const_dnnl_memory_desc_t query_desc, const_dnnl_memory_desc_t key_desc,
const_dnnl_memory_desc_t value_desc, const_dnnl_memory_desc_t dst_desc,
const_dnnl_memory_desc_t mask_desc, dnnl_data_type_t scale_dt,
bool invert_scale, dnnl_dim_t kv_head_number, bool causal_mask,
const_dnnl_primitive_attr_t attr, const_dnnl_primitive_attr_t kq_attr,
const_dnnl_primitive_attr_t vs_attr) {
if (auto err = sdpa_attr_check(query_desc, key_desc, value_desc, engine,
attr, kq_attr, vs_attr)) {
return err;
}
dnnl::impl::sdpa_desc_t sdpa_desc = dnnl::impl::create_sdpa_desc(query_desc,
key_desc, value_desc, dst_desc, mask_desc,
(dnnl::impl::data_type_t)scale_dt, invert_scale, kv_head_number,
causal_mask, kq_attr, vs_attr);
return dnnl::impl::primitive_desc_create(primitive_desc_iface, engine,
(const dnnl::impl::op_desc_t *)&sdpa_desc, nullptr, attr);
}
36 changes: 20 additions & 16 deletions src/gpu/intel/ocl/micro_sdpa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,16 @@ struct micro_sdpa_t : public gpu_primitive_t {
"kq scales mask(%d) must equal kq zero point(%d) "
"mask",
kq_scales_mask, kq_zp_mask);
VDISPATCH_SDPA(utils::one_of(kq_scales_mask, 0, 1, 3, 11, 15),
"unsupported mask for kq matmul(%d). must be 0, 1, 3, 11, "
"or 15",
kq_scales_mask);
VDISPATCH_SDPA(utils::one_of(kq_zp_mask, 0, 1, 3, 11, 15),
"unsupported mask for kq matmul(%d). must be 0, 1, 3, 11, "
"or 15",
kq_zp_mask);
if (!desc()->kq_scales.has_default_values())
VDISPATCH_SDPA(utils::one_of(kq_scales_mask, 0, 1, 3, 11, 15),
"unsupported mask for kq matmul(%d). must be 0, 1, 3, "
"11, or 15",
kq_scales_mask);
if (!desc()->kq_zero_points.has_default_values())
VDISPATCH_SDPA(utils::one_of(kq_zp_mask, 0, 1, 3, 11, 15),
"unsupported mask for kq matmul(%d). must be 0, 1, 3, "
"11, or 15",
kq_zp_mask);

/// NOTE: Limitation of microkernels
if (utils::one_of(
Expand All @@ -127,14 +129,16 @@ struct micro_sdpa_t : public gpu_primitive_t {
"vs scales mask(%d) must equal vs zero point(%d) "
"mask",
vs_scales_mask, vs_zp_mask);
VDISPATCH_SDPA(utils::one_of(vs_scales_mask, 0, 1, 3, 7, 15),
"unsupported mask for vs matmul(%d). must be 0, 1, 3, 7, "
"or 15",
vs_scales_mask);
VDISPATCH_SDPA(utils::one_of(vs_zp_mask, 0, 1, 3, 7, 15),
"unsupported mask for vs matmul(%d). must be 0, 1, 3, 7, "
"or 15",
vs_zp_mask);
if (!desc()->vs_zero_points.has_default_values())
VDISPATCH_SDPA(utils::one_of(vs_scales_mask, 0, 1, 3, 7, 15),
"unsupported mask for vs matmul(%d). must be 0, 1, 3, "
"7, or 15",
vs_scales_mask);
if (!desc()->vs_zero_points.has_default_values())
VDISPATCH_SDPA(utils::one_of(vs_zp_mask, 0, 1, 3, 7, 15),
"unsupported mask for vs matmul(%d). must be 0, 1, 3, "
"7, or 15",
vs_zp_mask);

/// NOTE: Limitation of microkernels
if (utils::one_of(
Expand Down
91 changes: 91 additions & 0 deletions tests/gtests/internals/sdpa_internal.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*******************************************************************************
* Copyright 2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef COMMON_SDPA_INTERNAL_HPP
#define COMMON_SDPA_INTERNAL_HPP

#include "dnnl.hpp"

/// Creates a primitive descriptor for a scaled dot product attention primitive
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param query_desc Query memory descriptor (tensor Q)
/// @param key_desc Key memory descriptor (tensor K)
/// @param value_desc Value memory descriptor (tensor V)
/// @param dst_desc Destination memory descriptor.
/// @param attn_mask_desc Attention mask memory descriptor.
/// @param attr Primitive attributes (can be NULL).
/// @param kq_attr Attribute for the Key/Query matmul operation(can be NULL).
/// @param vs_attr Attribute for the Value/Score matmul operation(can be NULL).
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.

dnnl_status_t DNNL_API sdpa_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc_iface, dnnl_engine_t engine,
const_dnnl_memory_desc_t query_desc, const_dnnl_memory_desc_t key_desc,
const_dnnl_memory_desc_t value_desc, const_dnnl_memory_desc_t dst_desc,
const_dnnl_memory_desc_t mask_desc, dnnl_data_type_t scale_dt,
bool invert_scale, dnnl_dim_t kv_head_number, bool causal_mask,
const_dnnl_primitive_attr_t attr, const_dnnl_primitive_attr_t kq_attr,
const_dnnl_primitive_attr_t vs_attr);

namespace dnnl {
namespace impl {

/// Scaled Dot Product Attention (sdpa) internal primitive.
/// Implementing internally for more flexible validation
struct sdpa : public dnnl::primitive {
/// Primitive descriptor for a sdpa primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;

primitive_desc(const engine &aengine, const memory::desc &query_desc,
const memory::desc &key_desc, const memory::desc &value_desc,
const memory::desc *attn_mask_desc, memory::data_type scale_dt,
const memory::desc &output_desc, bool invert_scale,
memory::dim kv_head_number, bool causal_mask,
const primitive_attr &attr = default_attr(),
const primitive_attr &kq_attr = default_attr(),
const primitive_attr &vs_attr = default_attr()) {

dnnl_primitive_desc_t pd = nullptr;
dnnl_status_t status = sdpa_primitive_desc_create(&pd,
aengine.get(), query_desc.get(), key_desc.get(),
value_desc.get(), output_desc.get(),
optional_arg(attn_mask_desc), (dnnl_data_type_t)scale_dt,
invert_scale, kv_head_number, causal_mask, attr.get(),
kq_attr.get(), vs_attr.get());

dnnl::error::wrap_c_api(status,
"could not create a primitive descriptor for a sdpa "
"primitive");
reset(pd);
}
};

/// Default constructor. Produces an empty object.
sdpa() = default;

/// Constructs a sdpa primitive.
/// @param pd Primitive descriptor for a sdpa primitive.
sdpa(const primitive_desc &pd) : primitive(pd) {}
};
} // namespace impl
} // namespace dnnl

#endif
Loading
Loading