Skip to content

Commit

Permalink
sdpa: refactor internal primitive into sdpa_test_iface
Browse files Browse the repository at this point in the history
  • Loading branch information
umar456 committed Jan 29, 2025
1 parent 2625459 commit 21e4251
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 215 deletions.
56 changes: 0 additions & 56 deletions src/common/sdpa_internal.h

This file was deleted.

154 changes: 0 additions & 154 deletions src/common/sdpa_internal.hpp

This file was deleted.

7 changes: 3 additions & 4 deletions src/common/sdpa.cpp → src/common/sdpa_test_iface.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2024 Intel Corporation
* Copyright 2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -14,9 +14,8 @@
* limitations under the License.
*******************************************************************************/

#include "c_types_map.hpp"
#include "common/c_types_map.hpp"
#include "common/primitive_desc_iface.hpp"
#include "common/sdpa_internal.hpp"
#include "common/sdpa_pd.hpp"
#include "common/sdpa_types.hpp"
#include "common/sdpa_utils.hpp"
Expand All @@ -25,7 +24,7 @@
using dnnl::impl::status_t;
using namespace dnnl::impl;

dnnl_status_t dnnl_sdpa_primitive_desc_create(
dnnl_status_t DNNL_API sdpa_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc_iface, dnnl_engine_t engine,
const_dnnl_memory_desc_t query_desc, const_dnnl_memory_desc_t key_desc,
const_dnnl_memory_desc_t value_desc, const_dnnl_memory_desc_t dst_desc,
Expand Down
91 changes: 91 additions & 0 deletions tests/gtests/internals/sdpa_internal.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*******************************************************************************
* Copyright 2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef COMMON_SDPA_INTERNAL_HPP
#define COMMON_SDPA_INTERNAL_HPP

#include "dnnl.hpp"

/// Creates a primitive descriptor for a scaled dot product attention primitive
///
/// @param primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param query_desc Query memory descriptor (tensor Q)
/// @param key_desc Key memory descriptor (tensor K)
/// @param value_desc Value memory descriptor (tensor V)
/// @param dst_desc Destination memory descriptor.
/// @param attn_mask_desc Attention mask memory descriptor.
/// @param attr Primitive attributes (can be NULL).
/// @param kq_attr Attribute for the Key/Query matmul operation(can be NULL).
/// @param vs_attr Attribute for the Value/Score matmul operation(can be NULL).
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.

dnnl_status_t DNNL_API sdpa_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc_iface, dnnl_engine_t engine,
const_dnnl_memory_desc_t query_desc, const_dnnl_memory_desc_t key_desc,
const_dnnl_memory_desc_t value_desc, const_dnnl_memory_desc_t dst_desc,
const_dnnl_memory_desc_t mask_desc, dnnl_data_type_t scale_dt,
bool invert_scale, dnnl_dim_t kv_head_number, bool causal_mask,
const_dnnl_primitive_attr_t attr, const_dnnl_primitive_attr_t kq_attr,
const_dnnl_primitive_attr_t vs_attr);

namespace dnnl {
namespace impl {

/// Scaled Dot Product Attention (sdpa) internal primitive.
/// Implementing internally for more flexible validation
struct sdpa : public dnnl::primitive {
/// Primitive descriptor for a sdpa primitive.
struct primitive_desc : public dnnl::primitive_desc {
/// Default constructor. Produces an empty object.
primitive_desc() = default;

primitive_desc(const engine &aengine, const memory::desc &query_desc,
const memory::desc &key_desc, const memory::desc &value_desc,
const memory::desc *attn_mask_desc, memory::data_type scale_dt,
const memory::desc &output_desc, bool invert_scale,
memory::dim kv_head_number, bool causal_mask,
const primitive_attr &attr = default_attr(),
const primitive_attr &kq_attr = default_attr(),
const primitive_attr &vs_attr = default_attr()) {

dnnl_primitive_desc_t pd = nullptr;
dnnl_status_t status = sdpa_primitive_desc_create(&pd,
aengine.get(), query_desc.get(), key_desc.get(),
value_desc.get(), output_desc.get(),
optional_arg(attn_mask_desc), (dnnl_data_type_t)scale_dt,
invert_scale, kv_head_number, causal_mask, attr.get(),
kq_attr.get(), vs_attr.get());

dnnl::error::wrap_c_api(status,
"could not create a primitive descriptor for a sdpa "
"primitive");
reset(pd);
}
};

/// Default constructor. Produces an empty object.
sdpa() = default;

/// Constructs a sdpa primitive.
/// @param pd Primitive descriptor for a sdpa primitive.
sdpa(const primitive_desc &pd) : primitive(pd) {}
};
} // namespace impl
} // namespace dnnl

#endif
2 changes: 1 addition & 1 deletion tests/gtests/internals/test_sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include <dnnl_test_common.hpp>
#include <gtest/gtest.h>

#include <common/sdpa_internal.hpp>
#include "sdpa_internal.hpp"

#include <oneapi/dnnl/dnnl.hpp>

Expand Down

0 comments on commit 21e4251

Please sign in to comment.