diff --git a/fft/src/KokkosFFT_ROCM_plans.hpp b/fft/src/KokkosFFT_ROCM_plans.hpp index 25853b58..d21ffb63 100644 --- a/fft/src/KokkosFFT_ROCM_plans.hpp +++ b/fft/src/KokkosFFT_ROCM_plans.hpp @@ -5,84 +5,65 @@ #ifndef KOKKOSFFT_ROCM_PLANS_HPP #define KOKKOSFFT_ROCM_PLANS_HPP -<<<<<<< HEAD #include #include #include - ======= ->>>>>>> main #include "KokkosFFT_ROCM_types.hpp" #include "KokkosFFT_Extents.hpp" #include "KokkosFFT_traits.hpp" #include "KokkosFFT_asserts.hpp" #include "KokkosFFT_utils.hpp" - namespace KokkosFFT { - namespace Impl { +namespace KokkosFFT { +namespace Impl { - // batched transform, over ND Views - template , - std::nullptr_t> = nullptr> - auto create_plan(const ExecutionSpace& exec_space, - std::unique_ptr& plan, const InViewType& in, - const OutViewType& out, Direction direction, - axis_type axes, shape_type s, - bool is_inplace) { - static_assert( - KokkosFFT::Impl::are_operatable_views_v, - "create_plan: InViewType and OutViewType must have the same base " - "floating point type (float/double), the same layout " - "(LayoutLeft/LayoutRight), " - "and the same rank. ExecutionSpace must be accessible to the data in " - "InViewType and OutViewType."); +// batched transform, over ND Views +template , + std::nullptr_t> = nullptr> +auto create_plan(const ExecutionSpace& exec_space, + std::unique_ptr& plan, const InViewType& in, + const OutViewType& out, Direction direction, + axis_type axes, shape_type s, + bool is_inplace) { + static_assert( + KokkosFFT::Impl::are_operatable_views_v, + "create_plan: InViewType and OutViewType must have the same base " + "floating point type (float/double), the same layout " + "(LayoutLeft/LayoutRight), " + "and the same rank. ExecutionSpace must be accessible to the data in " + "InViewType and OutViewType."); - static_assert(InViewType::rank() >= fft_rank, - "KokkosFFT::create_plan: Rank of View must be larger than " - "Rank of FFT."); + static_assert(InViewType::rank() >= fft_rank, + "KokkosFFT::create_plan: Rank of View must be larger than " + "Rank of FFT."); - using in_value_type = typename InViewType::non_const_value_type; - using out_value_type = typename OutViewType::non_const_value_type; + using in_value_type = typename InViewType::non_const_value_type; + using out_value_type = typename OutViewType::non_const_value_type; - Kokkos::Profiling::ScopedRegion region( - "KokkosFFT::create_plan[TPL_rocfft]"); + Kokkos::Profiling::ScopedRegion region("KokkosFFT::create_plan[TPL_rocfft]"); - constexpr auto type = - KokkosFFT::Impl::transform_type::type(); - auto [in_extents, out_extents, fft_extents, howmany] = - KokkosFFT::Impl::get_extents(in, out, axes, s, is_inplace); + constexpr auto type = + KokkosFFT::Impl::transform_type::type(); + auto [in_extents, out_extents, fft_extents, howmany] = + KokkosFFT::Impl::get_extents(in, out, axes, s, is_inplace); - // Create a plan - plan = - std::make_unique(type, in_extents, out_extents, fft_extents, - howmany, direction, is_inplace); - plan->commit(exec_space); + // Create a plan + plan = std::make_unique(type, in_extents, out_extents, fft_extents, + howmany, direction, is_inplace); + plan->commit(exec_space); - // Calculate the total size of the FFT - int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, - std::multiplies<>()); + // Calculate the total size of the FFT + int fft_size = std::accumulate(fft_extents.begin(), fft_extents.end(), 1, + std::multiplies<>()); - return fft_size; - } + return fft_size; +} -<<<<<<< HEAD - template , - std::nullptr_t> = nullptr> - void destroy_plan_and_info(std::unique_ptr& plan, - InfoType& execution_info) { - Kokkos::Profiling::ScopedRegion region( - "KokkosFFT::destroy_plan[TPL_rocfft]"); - - rocfft_execution_info_destroy(execution_info); - rocfft_plan_destroy(*plan); - } -======= ->>>>>>> main - } // namespace Impl +} // namespace Impl } // namespace KokkosFFT #endif