diff --git a/src/common/memory_tracking.hpp b/src/common/memory_tracking.hpp index 1eec393ad2f..b3f18bdca01 100644 --- a/src/common/memory_tracking.hpp +++ b/src/common/memory_tracking.hpp @@ -1,6 +1,6 @@ /******************************************************************************* * Copyright 2018-2024 Intel Corporation -* Copyright 2024 Arm Ltd. and affiliates +* Copyright 2024-2025 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -205,8 +205,6 @@ enum { key_conv_ncsp_matmul_dst, key_conv_ncsp_diff_sp_sum, key_conv_padded_bias, - key_conv_permuted_inputs, - key_conv_permuted_outputs, key_conv_permuted_weights, key_conv_rtus_space, key_conv_store_wsp, @@ -317,11 +315,9 @@ enum { key_softmax_interim_store, key_sum_reduction, key_sum_srcs_cvt, - key_wino_transformed_weights, key_wino_U, key_wino_V, key_wino_M, - key_wino_workspace, // These two keys should always be the last ones, // even though they are not in alphabetical order key_nested, diff --git a/src/cpu/aarch64/acl_convolution_utils.cpp b/src/cpu/aarch64/acl_convolution_utils.cpp index 15437746069..fdaa4aa63f7 100644 --- a/src/cpu/aarch64/acl_convolution_utils.cpp +++ b/src/cpu/aarch64/acl_convolution_utils.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Arm Ltd. and affiliates +* Copyright 2020-2025 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. *******************************************************************************/ -#include "acl_convolution_utils.hpp" +#include "cpu/aarch64/acl_convolution_utils.hpp" #include "common/convolution_pd.hpp" #include "common/utils.hpp" #include "oneapi/dnnl/dnnl.h" @@ -283,6 +283,56 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, return status::success; } + +status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const convolution_desc_t &cd, + const primitive_attr_t &attr) { + + // Under these conditions, fallback to faster GEMM-based convolution + // unless the user explicitly specifies Winograd algorithm + // clang-format off + if (one_of(true, src_md.dims[2] > 112, // ih + src_md.dims[3] > 112, // iw + src_md.dims[1] < 64, // ic + dst_md.dims[1] < 64, // oc + dnnl_get_max_threads() > 28) + && cd.alg_kind == alg_kind::convolution_auto) { + return status::unimplemented; + } + // clang-format on + + // General Compute Library checks, memory tags are also set there + acp.alg_winograd = true; + CHECK(acl_init_conf(acp, src_md, weights_md, dst_md, bias_md, cd, attr)); + + const bool shape_ok + // only unit strides allowed + = (acp.padstride_info.stride() == std::pair {1, 1}) + // Note: Compute Library supports arbitrary padding for wino kernels + // but we only allow small padding to be consistent with oneDNN + && (acp.padstride_info.pad().first <= 1) // padding left/right + && (acp.padstride_info.pad().second <= 1) // padding top/bottom + // only non-dilated convolutions allowed + && (acp.dilation_info == arm_compute::Size2D(1, 1)); + + ACL_CHECK_SUPPORT(!shape_ok, "shape not supported by winograd kernels"); + + // clang-format off + // Validate convolution manually to check for return status + ACL_CHECK_VALID(arm_compute::NEWinogradConvolutionLayer::validate( + &acp.src_tensor_info, + &acp.wei_tensor_info, + acp.with_bias ? &acp.bia_tensor_info : nullptr, + &acp.dst_tensor_info, + acp.padstride_info, + acp.act_info, + true)); // enable_fast_math flag in ACL Winograd + // clang-format on + + return status::success; +} + } // namespace acl_convolution_utils } // namespace aarch64 diff --git a/src/cpu/aarch64/acl_convolution_utils.hpp b/src/cpu/aarch64/acl_convolution_utils.hpp index 37a3d6c3d98..c438cf9574b 100644 --- a/src/cpu/aarch64/acl_convolution_utils.hpp +++ b/src/cpu/aarch64/acl_convolution_utils.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Arm Ltd. and affiliates +* Copyright 2020-2025 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,6 +30,10 @@ namespace aarch64 { template struct acl_obj_t { + arm_compute::Tensor src_tensor; + arm_compute::Tensor wei_tensor; + arm_compute::Tensor bia_tensor; + arm_compute::Tensor dst_tensor; ConvOp conv; arm_compute::experimental::MemoryRequirements aux_mem_req; }; @@ -64,6 +68,11 @@ status_t acl_init_conf(acl_conv_conf_t &acp, memory_desc_t &src_md, memory_desc_t &bias_md, const convolution_desc_t &cd, const primitive_attr_t &attr); +status_t init_conf_wino(acl_conv_conf_t &acp, memory_desc_t &src_md, + memory_desc_t &weights_md, memory_desc_t &dst_md, + memory_desc_t &bias_md, const convolution_desc_t &cd, + const primitive_attr_t &attr); + } // namespace acl_convolution_utils // Keys are anonymous with local linkage. So deduce the type automagically. @@ -175,6 +184,53 @@ status_t execute_forward_conv_acl(const exec_ctx_t &ctx, return status::success; } +template +status_t execute_forward_conv_acl( + const exec_ctx_t &ctx, conv_obj_t &acl_conv_obj, const conv_pd_t *pd) { + bool with_bias = pd->acp_.with_bias; + bool use_dst_acc_for_sum = pd->acp_.use_dst_acc_for_sum; + + auto src_base = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC); + auto wei_base = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS); + + // import_memory() and free() methods do not allocate/free any additional + // memory, only acquire/release pointers. + acl_conv_obj.src_tensor.allocator()->import_memory( + const_cast(src_base)); + acl_conv_obj.wei_tensor.allocator()->import_memory( + const_cast(wei_base)); + + const auto scratchpad = ctx.get_scratchpad_grantor(); + + // If we have an unfused sum post op, put the result in a scratchpad tensor. + // Result will be summed to the dst during acl_post_ops.execute + auto dst_base = use_dst_acc_for_sum + ? scratchpad.get(memory_tracking::names::key_generic_acc) + : CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST); + acl_conv_obj.dst_tensor.allocator()->import_memory(dst_base); + + if (with_bias) { + auto bia_base = CTX_IN_MEM(const bia_data_t *, DNNL_ARG_BIAS); + acl_conv_obj.bia_tensor.allocator()->import_memory( + const_cast(bia_base)); + } + + acl_conv_obj.conv.run(); + + acl_conv_obj.src_tensor.allocator()->free(); + acl_conv_obj.wei_tensor.allocator()->free(); + if (with_bias) { acl_conv_obj.bia_tensor.allocator()->free(); } + + void *dst = acl_conv_obj.dst_tensor.buffer(); + pd->post_ops.execute(ctx, dst); + + acl_conv_obj.dst_tensor.allocator()->free(); + + return status::success; +} + } // namespace aarch64 } // namespace cpu } // namespace impl diff --git a/src/cpu/aarch64/acl_gemm_convolution.cpp b/src/cpu/aarch64/acl_gemm_convolution.cpp index 5934fd24102..ed01f832e2b 100644 --- a/src/cpu/aarch64/acl_gemm_convolution.cpp +++ b/src/cpu/aarch64/acl_gemm_convolution.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Arm Ltd. and affiliates +* Copyright 2020-2025 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -89,30 +89,13 @@ template status_t acl_gemm_convolution_fwd_t::init( engine_t *engine) { - // commented due to hot fix solution for stateless API which should be replaced soon. - // auto acp_ = pd()->acp_; - // acl_obj_->conv.configure(&acp_.src_tensor_info, &acp_.wei_tensor_info, - // acp_.with_bias ? &acp_.bia_tensor_info : nullptr, - // &acp_.dst_tensor_info, acp_.padstride_info, acp_.weights_info, - // acp_.dilation_info, acp_.act_info, acp_.fast_math); - // acl_obj_->aux_mem_req = acl_obj_->conv.workspace(); - return status::success; -} - -template -std::unique_ptr::Op>> -acl_gemm_convolution_fwd_t::reinitialize_acl_obj() const { auto acp_ = pd()->acp_; - std::unique_ptr> acl_obj = std::make_unique>(); - acl_obj->conv.configure(&acp_.src_tensor_info, &acp_.wei_tensor_info, + acl_obj_->conv.configure(&acp_.src_tensor_info, &acp_.wei_tensor_info, acp_.with_bias ? &acp_.bia_tensor_info : nullptr, &acp_.dst_tensor_info, acp_.padstride_info, acp_.weights_info, acp_.dilation_info, acp_.act_info, acp_.fast_math); - acl_obj->aux_mem_req = acl_obj->conv.workspace(); - return acl_obj; + acl_obj_->aux_mem_req = acl_obj_->conv.workspace(); + return status::success; } template ::execute_forward( const exec_ctx_t &ctx) const { - // Temporary hotfix: We're using a local acl_obj instance in this method - // instead of the class member acl_obj_. This hotfix is to bypass persistent aux mem requirements but is not the ideal solution. - // It should be refactored or removed in the future when a more permanent fix is implemented. - auto acl_obj = reinitialize_acl_obj(); - return execute_forward_conv_acl, pd_t, src_data_t, wei_data_t, - dst_data_t, bia_data_t>(ctx, acl_obj.get(), pd(), gemm_conv_keys); + dst_data_t, bia_data_t>(ctx, acl_obj_.get(), pd(), gemm_conv_keys); } using namespace data_type; diff --git a/src/cpu/aarch64/acl_gemm_convolution.hpp b/src/cpu/aarch64/acl_gemm_convolution.hpp index 6b40f0efff4..f92606374e6 100644 --- a/src/cpu/aarch64/acl_gemm_convolution.hpp +++ b/src/cpu/aarch64/acl_gemm_convolution.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Arm Ltd. and affiliates +* Copyright 2020-2025 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,10 +47,8 @@ struct acl_gemm_convolution_fwd_t : public primitive_t { acl_post_ops_t post_ops; }; - // hot fix solution for stateless API which should be replaced soon. - // acl_gemm_convolution_fwd_t(const pd_t *apd) - // : primitive_t(apd), acl_obj_(std::make_unique>()) {} - acl_gemm_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {} + acl_gemm_convolution_fwd_t(const pd_t *apd) + : primitive_t(apd), acl_obj_(std::make_unique>()) {} status_t init(engine_t *engine) override; @@ -65,15 +63,8 @@ struct acl_gemm_convolution_fwd_t : public primitive_t { private: status_t execute_forward(const exec_ctx_t &ctx) const; - - // hot fix solution for stateless API which should be replaced soon. - std::unique_ptr> reinitialize_acl_obj() const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } - - // commented due to hot fix solution for stateless API which should be replaced soon. - // std::unique_ptr> acl_obj_; - + std::unique_ptr> acl_obj_; }; // acl_gemm_convolution_fwd_t } // namespace aarch64 diff --git a/src/cpu/aarch64/acl_winograd_convolution.cpp b/src/cpu/aarch64/acl_winograd_convolution.cpp index ebdb99f50ed..f23be5bd17d 100644 --- a/src/cpu/aarch64/acl_winograd_convolution.cpp +++ b/src/cpu/aarch64/acl_winograd_convolution.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Arm Ltd. and affiliates +* Copyright 2020-2023, 2025 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,144 +14,30 @@ * limitations under the License. *******************************************************************************/ -#include "acl_winograd_convolution.hpp" -#include "common/memory_tracking.hpp" -#include "common/utils.hpp" +#include "cpu/aarch64/acl_winograd_convolution.hpp" namespace dnnl { namespace impl { namespace cpu { namespace aarch64 { - -namespace { using data_t = prec_traits::type; -// Keys are anonymous. So deduce the type automagically. -using conv_key_t = decltype(memory_tracking::names::key_gemm_tmp_buffer); - -// Map: [slot , key] -const std::map wino_conv_keys - = {{0, conv_key_t::key_gemm_asm_tmp_buffer}, - {1, conv_key_t::key_gemm_pretranspose_b}, - {2, conv_key_t::key_gemm_pretranspose}, - {3, conv_key_t::key_gemm_interleaved_lhs}, - {4, conv_key_t::key_gemm_pretransposed_rhs}, - {5, conv_key_t::key_gemm_transposed_1xwrhs}, - {6, conv_key_t::key_gemm_tmp_buffer}, - {7, conv_key_t::key_conv_permuted_outputs}, - {8, conv_key_t::key_conv_permuted_inputs}, - {9, conv_key_t::key_wino_workspace}, - {10, conv_key_t::key_wino_transformed_weights}, - {11, conv_key_t::key_conv_permuted_weights}}; -} // namespace - -status_t acl_wino_convolution_fwd_t::pd_t::init(engine_t *engine) { - using namespace data_type; - const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef) - && attr()->has_default_values( - primitive_attr_t::skip_mask_t::post_ops, f16); - const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef) - && attr()->has_default_values( - primitive_attr_t::skip_mask_t::post_ops, f32); - bool ok = is_fwd() - && utils::one_of(desc()->alg_kind, alg_kind::convolution_auto, - alg_kind::convolution_winograd) - && utils::one_of(true, is_fp16_ok, is_fp32_ok) - && !has_zero_dim_memory(); - - ok = ok && DNNL_CPU_THREADING_RUNTIME != DNNL_RUNTIME_THREADPOOL; - if (!ok) return status::unimplemented; - - CHECK(init_conf()); - - set_default_alg_kind(alg_kind::convolution_winograd); - - Op conv; - conv.configure(&acp_.src_tensor_info, &acp_.wei_tensor_info, - acp_.with_bias ? &acp_.bia_tensor_info : nullptr, - &acp_.dst_tensor_info, acp_.padstride_info, acp_.act_info, - true); // to support 5x5, 7x7 filter shapes in addition to 3x3 - - auto scratchpad = scratchpad_registry().registrar(); - const auto aux_mem = conv.workspace(); - return init_scratchpad(conv, scratchpad, wino_conv_keys, engine, post_ops, - attr_.post_ops_, acp_.act_info, acp_.use_dst_acc_for_sum, dst_md_); -} - -status_t acl_wino_convolution_fwd_t::init(engine_t *engine) { - // commented due to hot fix solution for stateless API which should be replaced soon. - // auto acp = pd()->acp_; - // acl_obj_->conv.configure(&acp.src_tensor_info, &acp.wei_tensor_info, - // acp.with_bias ? &acp.bia_tensor_info : nullptr, - // &acp.dst_tensor_info, acp.padstride_info, acp.act_info, - // true); // to support 5x5, 7x7 filter shapes in addition to 3x3 - - // acl_obj_->aux_mem_req = acl_obj_->conv.workspace(); - return status::success; -} - -status_t acl_wino_convolution_fwd_t::pd_t::init_conf() { - - // Under these conditions, fallback to faster GEMM-based convolution - // unless the user explicitly specifies Winograd algorithm - if (utils::one_of(true, src_md_.dims[2] > 112, // ih - src_md_.dims[3] > 112, // iw - src_md_.dims[1] < 64, // ic - dst_md_.dims[1]<64, // oc - dnnl_get_max_threads()> 28) - && desc()->alg_kind == alg_kind::convolution_auto) { - return status::unimplemented; - } - - // General Compute Library checks, memory tags are also set there - acp_.alg_winograd = true; - CHECK(acl_convolution_utils::acl_init_conf( - acp_, src_md_, weights_md_, dst_md_, bias_md_, *desc(), *attr())); - - const bool shape_ok - // only unit strides allowed - = (acp_.padstride_info.stride() - == std::pair {1, 1}) - // Note: Compute Library supports arbitrary padding for wino kernels - // but we only allow small padding to be consistent with oneDNN - && (acp_.padstride_info.pad().first <= 1) // padding left/right - && (acp_.padstride_info.pad().second <= 1) // padding top/bottom - // only non-dilated convolutions allowed - && (acp_.dilation_info == arm_compute::Size2D(1, 1)); - - ACL_CHECK_SUPPORT(!shape_ok, "shape not supported by winograd kernels"); - - // Validate convolution manually to check for return status - ACL_CHECK_VALID(Op::validate(&acp_.src_tensor_info, &acp_.wei_tensor_info, - acp_.with_bias ? &acp_.bia_tensor_info : nullptr, - &acp_.dst_tensor_info, acp_.padstride_info, acp_.act_info, - true)); // enable_fast_math flag in ACL Winograd - - return status::success; -} - -std::unique_ptr> -acl_wino_convolution_fwd_t::reinitialize_acl_obj() const { - auto acp = pd()->acp_; - std::unique_ptr> acl_obj = std::make_unique>(); - acl_obj->conv.configure(&acp.src_tensor_info, &acp.wei_tensor_info, - acp.with_bias ? &acp.bia_tensor_info : nullptr, - &acp.dst_tensor_info, acp.padstride_info, acp.act_info, - true); // to support 5x5, 7x7 filter shapes in addition to 3x3 - - acl_obj->aux_mem_req = acl_obj->conv.workspace(); - return acl_obj; -} - status_t acl_wino_convolution_fwd_t::execute_forward( const exec_ctx_t &ctx) const { - // Temporary hotfix: We're using a local acl_obj instance in this method - // instead of the class member acl_obj_. This hotfix is to bypass persistent aux mem requirements but is not the ideal solution. - // It should be refactored or removed in the future when a more permanent fix is implemented. - const auto acl_obj = reinitialize_acl_obj(); - return execute_forward_conv_acl, pd_t, data_t>( - ctx, acl_obj.get(), pd(), wino_conv_keys); + // Lock here is needed because resource_mapper does not support + // concurrent multithreaded access. + std::lock_guard _lock {this->mtx}; + // Retrieve primitive resource and configured Compute Library objects + auto *acl_resource + = ctx.get_resource_mapper()->get(this); + acl_obj_t &acl_wino_obj + = acl_resource->get_acl_obj(); + + return execute_forward_conv_acl< + acl_obj_t, pd_t, data_t>( + ctx, acl_wino_obj, pd()); } + } // namespace aarch64 } // namespace cpu } // namespace impl diff --git a/src/cpu/aarch64/acl_winograd_convolution.hpp b/src/cpu/aarch64/acl_winograd_convolution.hpp index cfee93d3ffe..155517cdab2 100644 --- a/src/cpu/aarch64/acl_winograd_convolution.hpp +++ b/src/cpu/aarch64/acl_winograd_convolution.hpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2020-2024 Arm Ltd. and affiliates +* Copyright 2020-2025 Arm Ltd. and affiliates * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,52 +19,129 @@ #include "cpu/cpu_convolution_pd.hpp" -#include "acl_convolution_utils.hpp" -#include "arm_compute/runtime/experimental/operators/CpuWinogradConv2d.h" +#include "cpu/aarch64/acl_convolution_utils.hpp" namespace dnnl { namespace impl { namespace cpu { namespace aarch64 { -struct acl_wino_convolution_fwd_t : public primitive_t { - using Op = arm_compute::experimental::op::CpuWinogradConv2d; +struct acl_wino_resource_t : public resource_t { + acl_wino_resource_t() + : acl_wino_obj_(utils::make_unique< + acl_obj_t>()) {} + + status_t configure(const acl_conv_conf_t &acp) { + if (!acl_wino_obj_) return status::out_of_memory; + + // Init Compute Library tensors based on info from descriptor + acl_wino_obj_->src_tensor.allocator()->init(acp.src_tensor_info); + acl_wino_obj_->wei_tensor.allocator()->init(acp.wei_tensor_info); + acl_wino_obj_->dst_tensor.allocator()->init(acp.dst_tensor_info); + acl_wino_obj_->bia_tensor.allocator()->init(acp.bia_tensor_info); + + // clang-format off + acl_wino_obj_->conv.configure( + &acl_wino_obj_->src_tensor, + &acl_wino_obj_->wei_tensor, + acp.with_bias ? &acl_wino_obj_->bia_tensor : nullptr, + &acl_wino_obj_->dst_tensor, + acp.padstride_info, + acp.act_info, + true); // to support 5x5, 7x7 filter shapes in addition to 3x3 + // clang-format on + + return status::success; + } + + acl_obj_t &get_acl_obj() const { + return *acl_wino_obj_; + } + + DNNL_DISALLOW_COPY_AND_ASSIGN(acl_wino_resource_t); +private: + std::unique_ptr> + acl_wino_obj_; +}; // acl_wino_resource_t + +struct acl_wino_convolution_fwd_t : public primitive_t { struct pd_t : public cpu_convolution_fwd_pd_t { using cpu_convolution_fwd_pd_t::cpu_convolution_fwd_pd_t; DECLARE_COMMON_PD_T( "wino:acl", acl_wino_convolution_fwd_t, USE_GLOBAL_SCRATCHPAD); - status_t init(engine_t *engine); + status_t init(engine_t *engine) { + using namespace data_type; + const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef) + && attr()->has_default_values( + primitive_attr_t::skip_mask_t::post_ops, f16); + const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef) + && attr()->has_default_values( + primitive_attr_t::skip_mask_t::post_ops, f32); + bool ok = is_fwd() + && utils::one_of(desc()->alg_kind, + alg_kind::convolution_auto, + alg_kind::convolution_winograd) + && utils::one_of(true, is_fp16_ok, is_fp32_ok) + && !has_zero_dim_memory(); + + ok = ok && DNNL_CPU_THREADING_RUNTIME != DNNL_RUNTIME_THREADPOOL; + if (!ok) return status::unimplemented; + + CHECK(acl_convolution_utils::init_conf_wino(acp_, src_md_, + weights_md_, dst_md_, bias_md_, *desc(), *attr())); + + set_default_alg_kind(alg_kind::convolution_winograd); + + CHECK(post_ops.init( + engine, attr_.post_ops_, dst_md_, acp_.act_info)); + acp_.use_dst_acc_for_sum = post_ops.has_sum(); + + if (acp_.use_dst_acc_for_sum) { + const memory_desc_wrapper dst_d(&dst_md_); + auto scratchpad = scratchpad_registry().registrar(); + scratchpad.book(memory_tracking::names::key_generic_acc, + dst_d.nelems(), dst_d.data_type_size()); + } + + return status::success; + } acl_conv_conf_t acp_ = utils::zero(); acl_post_ops_t post_ops; - - private: - status_t init_conf(); }; - // hot fix solution for stateless API which should be replaced soon. - // acl_wino_convolution_fwd_t(const pd_t *apd) - // : primitive_t(apd), acl_obj_(std::make_unique>()) {} acl_wino_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {} - status_t init(engine_t *engine) override; + status_t create_resource( + engine_t *engine, resource_mapper_t &mapper) const override { + if (mapper.has_resource(this)) return status::success; + + auto r = utils::make_unique(); + if (!r) return status::out_of_memory; + + // Configure the resource based on information from primitive descriptor + CHECK(r->configure(pd()->acp_)); + mapper.add(this, std::move(r)); + + return status::success; + } + + ~acl_wino_convolution_fwd_t() override = default; + + using data_t = typename prec_traits::type; status_t execute(const exec_ctx_t &ctx) const override { return execute_forward(ctx); } private: + // To guard the const execute_forward(), the mutex must be 'mutable' + mutable std::mutex mtx; status_t execute_forward(const exec_ctx_t &ctx) const; - - // hot fix solution for stateless API which should be replaced soon. - std::unique_ptr> reinitialize_acl_obj() const; - const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } - // commented due to hot fix solution for stateless API which should be replaced soon. - // std::unique_ptr> acl_obj_; }; // acl_wino_convolution_fwd_t } // namespace aarch64