diff --git a/src/portfft/committed_descriptor_impl.hpp b/src/portfft/committed_descriptor_impl.hpp
index 1366597c..d985e816 100644
--- a/src/portfft/committed_descriptor_impl.hpp
+++ b/src/portfft/committed_descriptor_impl.hpp
@@ -46,8 +46,7 @@ namespace detail {
 template <typename Scalar, domain Domain>
 class committed_descriptor_impl;

-template <typename Scalar, domain Domain, detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize,
-          typename TIn>
+template <typename Scalar, domain Domain, Idx SubgroupSize, typename TIn>
 std::vector<sycl::event> compute_level(
     const typename committed_descriptor_impl<Scalar, Domain>::kernel_data_struct& kd_struct, const TIn& input,
     Scalar* output, const TIn& input_imag, Scalar* output_imag, const Scalar* twiddles_ptr,
@@ -64,14 +63,13 @@ sycl::event transpose_level(const typename committed_descriptor_impl<Scalar, Domain>
-template <typename Scalar, domain Domain, detail::memory, detail::layout, detail::layout, Idx SubgroupSize>
+template <typename Scalar, domain Domain, detail::memory, Idx SubgroupSize>
 class workitem_kernel;
-template <typename Scalar, domain Domain, detail::memory, detail::layout, detail::layout, Idx SubgroupSize>
+template <typename Scalar, domain Domain, detail::memory, Idx SubgroupSize>
 class subgroup_kernel;
-template <typename Scalar, domain Domain, detail::memory, detail::layout, detail::layout, Idx SubgroupSize>
+template <typename Scalar, domain Domain, detail::memory, Idx SubgroupSize>
 class workgroup_kernel;
-template <typename Scalar, domain Domain, detail::memory, detail::layout, detail::layout, Idx SubgroupSize>
+template <typename Scalar, domain Domain, detail::memory, Idx SubgroupSize>
 class global_kernel;
 template <typename Scalar, detail::memory>
 class transpose_kernel;
@@ -85,8 +83,7 @@ class transpose_kernel;
 template <typename Scalar, domain Domain>
 class committed_descriptor_impl {
   friend struct descriptor<Scalar, Domain>;
-  template <typename Scalar1, domain Domain1, detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize,
-            typename TIn>
+  template <typename Scalar1, domain Domain1, Idx SubgroupSize, typename TIn>
   friend std::vector<sycl::event> detail::compute_level(
       const typename committed_descriptor_impl<Scalar1, Domain1>::kernel_data_struct& kd_struct, const TIn& input,
       Scalar1* output, const TIn& input_imag, Scalar1* output_imag, const Scalar1* twiddles_ptr,
@@ -184,38 +181,17 @@ class committed_descriptor_impl {
     }
   }

-  template <typename Impl, detail::layout LayoutIn, detail::layout LayoutOut, typename... Args>
+  template <typename Impl, typename... Args>
   auto dispatch(detail::level level, Args&&... args) {
     switch (level) {
       case detail::level::WORKITEM:
-        return Impl::template inner<detail::level::WORKITEM, LayoutIn, LayoutOut, Dummy>::execute(*this, args...);
+        return Impl::template inner<detail::level::WORKITEM, Dummy>::execute(*this, args...);
       case detail::level::SUBGROUP:
-        return Impl::template inner<detail::level::SUBGROUP, LayoutIn, LayoutOut, Dummy>::execute(*this, args...);
+        return Impl::template inner<detail::level::SUBGROUP, Dummy>::execute(*this, args...);
       case detail::level::WORKGROUP:
-        return Impl::template inner<detail::level::WORKGROUP, LayoutIn, LayoutOut, Dummy>::execute(*this, args...);
+        return Impl::template inner<detail::level::WORKGROUP, Dummy>::execute(*this, args...);
       case detail::level::GLOBAL:
-        return Impl::template inner<detail::level::GLOBAL, LayoutIn, LayoutOut, Dummy>::execute(*this, args...);
-      default:
-        // This should be unreachable
-        throw unsupported_configuration("Unimplemented");
-    }
-  }
-
-  template <typename Impl, typename... Args>
-  auto dispatch(detail::level level, Args&&... args) {
-    switch (level) {
-      case detail::level::WORKITEM:
-        return Impl::template inner<detail::level::WORKITEM, detail::layout::PACKED, detail::layout::PACKED, Dummy>::execute(*this,
-                                                                                                                              args...);
-      case detail::level::SUBGROUP:
-        return Impl::template inner<detail::level::SUBGROUP, detail::layout::PACKED, detail::layout::PACKED, Dummy>::execute(*this,
-                                                                                                                              args...);
-      case detail::level::WORKGROUP:
-        return Impl::template inner<detail::level::WORKGROUP, detail::layout::PACKED, detail::layout::PACKED, Dummy>::execute(
-            *this, args...);
-      case detail::level::GLOBAL:
-        return Impl::template inner<detail::level::GLOBAL, detail::layout::PACKED, detail::layout::PACKED, Dummy>::execute(*this,
-                                                                                                                            args...);
+        return Impl::template inner<detail::level::GLOBAL, Dummy>::execute(*this, args...);
       default:
         // This should be unreachable
         throw unsupported_configuration("Unimplemented");
@@ -271,10 +247,11 @@ class committed_descriptor_impl {
       Idx factor_sg_m = detail::factorize_sg(m, SubgroupSize);
       Idx factor_wi_m = m / factor_sg_m;
       Idx temp_num_sgs_in_wg;
-      std::size_t local_memory_usage = num_scalars_in_local_mem<detail::layout::PACKED>(
-                                           detail::level::WORKGROUP, static_cast<std::size_t>(fft_size), SubgroupSize,
-                                           {factor_sg_n, factor_wi_n, factor_sg_m, factor_wi_m}, temp_num_sgs_in_wg) *
-                                       sizeof(Scalar);
+      std::size_t local_memory_usage =
+          num_scalars_in_local_mem(detail::level::WORKGROUP, static_cast<std::size_t>(fft_size), SubgroupSize,
+                                   {factor_sg_n, factor_wi_n, factor_sg_m, factor_wi_m}, temp_num_sgs_in_wg,
+                                   layout::PACKED) *
+          sizeof(Scalar);
       // Checks for PACKED layout only at the moment, as the other layout will not be supported
       // by the global implementation. For such sizes, only PACKED layout will be supported
       if (detail::fits_in_wi<Scalar>(factor_wi_n) && detail::fits_in_wi<Scalar>(factor_wi_m) &&
@@ -307,20 +284,13 @@ class committed_descriptor_impl {
       IdxGlobal factor_sg = detail::factorize_sg(factor_size, SubgroupSize);
       IdxGlobal factor_wi = factor_size / factor_sg;
       if (detail::can_cast_safely<IdxGlobal, Idx>(factor_sg) && detail::can_cast_safely<IdxGlobal, Idx>(factor_wi)) {
-        if (batch_interleaved_layout) {
-          return (2 *
-                      num_scalars_in_local_mem<detail::layout::BATCH_INTERLEAVED>(
-                          detail::level::SUBGROUP, static_cast<std::size_t>(factor_size), SubgroupSize,
-                          {static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, temp_num_sgs_in_wg) *
-                      sizeof(Scalar) +
-                  2 * static_cast<std::size_t>(factor_size) * sizeof(Scalar)) <
-                 static_cast<std::size_t>(local_memory_size);
-        }
-        return (num_scalars_in_local_mem<detail::layout::PACKED>(
-                    detail::level::SUBGROUP, static_cast<std::size_t>(factor_size), SubgroupSize,
-                    {static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, temp_num_sgs_in_wg) *
-                    sizeof(Scalar) +
-                2 * static_cast<std::size_t>(factor_size) * sizeof(Scalar)) <
+        std::size_t input_scalars =
+            num_scalars_in_local_mem(detail::level::SUBGROUP, static_cast<std::size_t>(factor_size), SubgroupSize,
+                                     {static_cast<Idx>(factor_sg), static_cast<Idx>(factor_wi)}, temp_num_sgs_in_wg,
+                                     batch_interleaved_layout ? layout::BATCH_INTERLEAVED : layout::PACKED);
+        std::size_t store_modifiers = batch_interleaved_layout ? input_scalars : 0;
+        std::size_t twiddle_scalars = 2 * static_cast<std::size_t>(factor_size);
+        return (sizeof(Scalar) * (input_scalars + store_modifiers + twiddle_scalars)) <
               static_cast<std::size_t>(local_memory_size);
       }
       return false;
@@ -416,10 +386,10 @@ class committed_descriptor_impl {
    */
   struct num_scalars_in_local_mem_struct {
     // Dummy parameter is needed as only partial specializations are allowed without specializing the containing class
-    template <detail::level Lev, detail::layout LayoutIn, typename Dummy>
+    template <detail::level Lev, typename Dummy>
     struct inner {
       static std::size_t execute(committed_descriptor_impl& desc, std::size_t length, Idx used_sg_size,
-                                 const std::vector<Idx>& factors, Idx& num_sgs_per_wg);
+                                 const std::vector<Idx>& factors, Idx& num_sgs_per_wg, layout input_layout);
     };
   };

@@ -432,13 +402,14 @@ class committed_descriptor_impl {
    * @param used_sg_size subgroup size the kernel will use
    * @param factors factorization of the FFT size the kernel will use
    * @param[out] num_sgs_per_wg number of subgroups in a workgroup
+   * @param input_layout the layout of the input data of the transforms
    * @return the number of scalars
    */
-  template <detail::layout LayoutIn>
   std::size_t num_scalars_in_local_mem(detail::level level, std::size_t length, Idx used_sg_size,
-                                       const std::vector<Idx>& factors, Idx& num_sgs_per_wg) {
+                                       const std::vector<Idx>& factors, Idx& num_sgs_per_wg, layout input_layout) {
     PORTFFT_LOG_FUNCTION_ENTRY();
-    return dispatch<num_scalars_in_local_mem_struct, LayoutIn, detail::layout::PACKED>(level, length, used_sg_size,
-                                                                                       factors, num_sgs_per_wg);
+    return dispatch<num_scalars_in_local_mem_struct>(level, length, used_sg_size, factors, num_sgs_per_wg,
+                                                     input_layout);
   }

@@ -953,9 +924,9 @@ class committed_descriptor_impl {
     std::size_t outer_size = total_size / params.lengths.back();

     PORTFFT_LOG_TRACE("Dispatching the kernel for the last dimension");
-    sycl::event previous_event = dispatch_kernel_1d(
-        in, out, in_imag, out_imag, dependencies, params.number_of_transforms * outer_size, input_layout, output_layout,
-        input_offset, output_offset, dimensions.back(), compute_direction);
+    sycl::event previous_event =
+        dispatch_kernel_1d(in, out, in_imag, out_imag, dependencies, params.number_of_transforms * outer_size,
+                           input_layout, input_offset, output_offset, dimensions.back(), compute_direction);
     if (n_dimensions == 1) {
       return previous_event;
     }
@@ -971,8 +942,8 @@
       for (std::size_t j = 0; j < params.number_of_transforms *
outer_size; j++) { sycl::event e = dispatch_kernel_1d( out, out, out_imag, out_imag, previous_events, inner_size, layout::BATCH_INTERLEAVED, - layout::BATCH_INTERLEAVED, output_offset + j * stride_between_kernels, - output_offset + j * stride_between_kernels, dimensions[i], compute_direction); + output_offset + j * stride_between_kernels, output_offset + j * stride_between_kernels, dimensions[i], + compute_direction); next_events.push_back(e); } inner_size *= params.lengths[i]; @@ -998,7 +969,6 @@ class committed_descriptor_impl { * @param dependencies events that must complete before the computation * @param n_transforms number of FT transforms to do in one call * @param input_layout the layout of the input data of the transforms - * @param output_layout the layout of the output data of the transforms * @param input_offset offset into input allocation where the data for FFTs start * @param output_offset offset into output allocation where the data for FFTs start * @param dimension_data data for the dimension this call will work on @@ -1008,13 +978,12 @@ class committed_descriptor_impl { template sycl::event dispatch_kernel_1d(const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, const std::vector& dependencies, std::size_t n_transforms, - layout input_layout, layout output_layout, std::size_t input_offset, - std::size_t output_offset, dimension_struct& dimension_data, - direction compute_direction) { + layout input_layout, std::size_t input_offset, std::size_t output_offset, + dimension_struct& dimension_data, direction compute_direction) { PORTFFT_LOG_FUNCTION_ENTRY(); return dispatch_kernel_1d_helper( - in, out, in_imag, out_imag, dependencies, n_transforms, input_layout, output_layout, input_offset, - output_offset, dimension_data, compute_direction); + in, out, in_imag, out_imag, dependencies, n_transforms, input_layout, input_offset, output_offset, + dimension_data, compute_direction); } /** @@ -1035,7 +1004,6 @@ class committed_descriptor_impl { * @param dependencies events that must complete before the computation * @param n_transforms number of FT transforms to do in one call * @param input_layout the layout of the input data of the transforms - * @param output_layout the layout of the output data of the transforms * @param input_offset offset into input allocation where the data for FFTs start * @param output_offset offset into output allocation where the data for FFTs start * @param dimension_data data for the dimension this call will work on @@ -1045,21 +1013,18 @@ class committed_descriptor_impl { template sycl::event dispatch_kernel_1d_helper(const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, const std::vector& dependencies, std::size_t n_transforms, - layout input_layout, layout output_layout, std::size_t input_offset, - std::size_t output_offset, dimension_struct& dimension_data, - direction compute_direction) { + layout input_layout, std::size_t input_offset, std::size_t output_offset, + dimension_struct& dimension_data, direction compute_direction) { PORTFFT_LOG_FUNCTION_ENTRY(); if (SubgroupSize == dimension_data.used_sg_size) { const bool input_batch_interleaved = input_layout == layout::BATCH_INTERLEAVED; - const bool output_batch_interleaved = output_layout == layout::BATCH_INTERLEAVED; for (kernel_data_struct kernel_data : dimension_data.forward_kernels) { - std::size_t minimum_local_mem_required; if (input_batch_interleaved) { - minimum_local_mem_required = num_scalars_in_local_mem( - kernel_data.level, kernel_data.length, SubgroupSize, 
kernel_data.factors, - kernel_data.num_sgs_per_wg) * - sizeof(Scalar); + std::size_t minimum_local_mem_required = + num_scalars_in_local_mem(kernel_data.level, kernel_data.length, SubgroupSize, kernel_data.factors, + kernel_data.num_sgs_per_wg, layout::BATCH_INTERLEAVED) * + sizeof(Scalar); PORTFFT_LOG_TRACE("Local mem required:", minimum_local_mem_required, "B. Available: ", local_memory_size, "B."); if (static_cast(minimum_local_mem_required) > local_memory_size) { @@ -1070,49 +1035,26 @@ class committed_descriptor_impl { } } - // UNPACKED is also being dispatched as PACKED, but kernels that support UNPACKED don't use the layout template - // parameter. - if (!input_batch_interleaved && !output_batch_interleaved) { - return run_kernel( - in, out, in_imag, out_imag, dependencies, n_transforms, input_offset, output_offset, dimension_data, - compute_direction); - } - if (input_batch_interleaved && !output_batch_interleaved && in != out) { - return run_kernel( - in, out, in_imag, out_imag, dependencies, n_transforms, input_offset, output_offset, dimension_data, - compute_direction); - } - if (!input_batch_interleaved && output_batch_interleaved && in != out) { - return run_kernel( - in, out, in_imag, out_imag, dependencies, n_transforms, input_offset, output_offset, dimension_data, - compute_direction); - } - if (input_batch_interleaved && output_batch_interleaved) { - return run_kernel( - in, out, in_imag, out_imag, dependencies, n_transforms, input_offset, output_offset, dimension_data, - compute_direction); - } - throw internal_error("None of the run_kernel functions match the description."); + return run_kernel(in, out, in_imag, out_imag, dependencies, n_transforms, input_offset, + output_offset, dimension_data, compute_direction, input_layout); } if constexpr (sizeof...(OtherSGSizes) == 0) { throw invalid_configuration("None of the compiled subgroup sizes are supported by the device!"); } else { - return dispatch_kernel_1d_helper( - in, out, in_imag, out_imag, dependencies, n_transforms, input_layout, output_layout, input_offset, - output_offset, dimension_data, compute_direction); + return dispatch_kernel_1d_helper(in, out, in_imag, out_imag, dependencies, + n_transforms, input_layout, input_offset, + output_offset, dimension_data, compute_direction); } } /** * Struct for dispatching `run_kernel()` call. 
* - * @tparam LayoutIn Input Layout - * @tparam LayoutOut Output Layout * @tparam SubgroupSize size of the subgroup * @tparam TIn Type of the input USM pointer or buffer * @tparam TOut Type of the output USM pointer or buffer */ - template + template struct run_kernel_struct { // Dummy parameter is needed as only partial specializations are allowed without specializing the containing class template @@ -1120,15 +1062,13 @@ class committed_descriptor_impl { static sycl::event execute(committed_descriptor_impl& desc, const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, const std::vector& dependencies, std::size_t n_transforms, std::size_t forward_offset, std::size_t backward_offset, - dimension_struct& dimension_data, direction compute_direction); + dimension_struct& dimension_data, direction compute_direction, layout input_layout); }; }; /** * Common interface to run the kernel called by compute_forward and compute_backward * - * @tparam LayoutIn Input Layout - * @tparam LayoutOut Output Layout * @tparam SubgroupSize size of the subgroup * @tparam TIn Type of the input USM pointer or buffer * @tparam TOut Type of the output USM pointer or buffer @@ -1146,13 +1086,14 @@ class committed_descriptor_impl { * @param output_offset offset into output allocation where the data for FFTs start * @param dimension_data data for the dimension this call will work on * @param compute_direction direction of fft, forward / backward + * @param input_layout the layout of the input data of the transforms * @return sycl::event */ - template + template sycl::event run_kernel(const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, const std::vector& dependencies, std::size_t n_transforms, std::size_t input_offset, std::size_t output_offset, dimension_struct& dimension_data, - direction compute_direction) { + direction compute_direction, layout input_layout) { PORTFFT_LOG_FUNCTION_ENTRY(); // mixing const and non-const inputs leads to hard-to-debug linking errors, as both use the same kernel name, but // are called from different template instantiations. @@ -1166,11 +1107,11 @@ class committed_descriptor_impl { using TInReinterpret = decltype(detail::reinterpret(in)); using TOutReinterpret = decltype(detail::reinterpret(out)); std::size_t vec_multiplier = params.complex_storage == complex_storage::INTERLEAVED_COMPLEX ? 
2 : 1; - return dispatch>( + return dispatch>( dimension_data.level, detail::reinterpret(in), detail::reinterpret(out), detail::reinterpret(in_imag), detail::reinterpret(out_imag), dependencies, static_cast(n_transforms), static_cast(vec_multiplier * input_offset), - static_cast(vec_multiplier * output_offset), dimension_data, compute_direction); + static_cast(vec_multiplier * output_offset), dimension_data, compute_direction, input_layout); } }; diff --git a/src/portfft/common/global.hpp b/src/portfft/common/global.hpp index dc0b96ce..727d0bed 100644 --- a/src/portfft/common/global.hpp +++ b/src/portfft/common/global.hpp @@ -115,8 +115,6 @@ PORTFFT_INLINE inline IdxGlobal get_outer_batch_offset(const IdxGlobal* factors, * Device function responsible for calling the corresponding sub-implementation * * @tparam Scalar Scalar type - * @tparam LayoutIn Input layout - * @tparam LayoutOut Output layout * @tparam SubgroupSize Subgroup size * @param input input pointer * @param output output pointer @@ -134,7 +132,7 @@ PORTFFT_INLINE inline IdxGlobal get_outer_batch_offset(const IdxGlobal* factors, * @param global_data global data * @param kh kernel handler */ -template +template PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Scalar* input_imag, Scalar* output_imag, const Scalar* implementation_twiddles, const Scalar* store_modifier_data, Scalar* input_loc, Scalar* twiddles_loc, Scalar* store_modifier_loc, @@ -156,16 +154,16 @@ PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Sc batch_size, global_data, kh, static_cast(nullptr), store_modifier_data, static_cast(nullptr), store_modifier_loc); } else if (level == detail::level::SUBGROUP) { - subgroup_impl( - input + outer_batch_offset, output + outer_batch_offset, input_imag + outer_batch_offset, - output_imag + outer_batch_offset, input_loc, twiddles_loc, batch_size, implementation_twiddles, global_data, - kh, static_cast(nullptr), store_modifier_data, static_cast(nullptr), - store_modifier_loc); + subgroup_impl(input + outer_batch_offset, output + outer_batch_offset, + input_imag + outer_batch_offset, output_imag + outer_batch_offset, input_loc, + twiddles_loc, batch_size, implementation_twiddles, global_data, kh, + static_cast(nullptr), store_modifier_data, + static_cast(nullptr), store_modifier_loc); } else if (level == detail::level::WORKGROUP) { - workgroup_impl( - input + outer_batch_offset, output + outer_batch_offset, input_imag + outer_batch_offset, - output_imag + outer_batch_offset, input_loc, twiddles_loc, batch_size, implementation_twiddles, global_data, - kh, static_cast(nullptr), store_modifier_data); + workgroup_impl(input + outer_batch_offset, output + outer_batch_offset, + input_imag + outer_batch_offset, output_imag + outer_batch_offset, input_loc, + twiddles_loc, batch_size, implementation_twiddles, global_data, kh, + static_cast(nullptr), store_modifier_data); } sycl::group_barrier(global_data.it.get_group()); } @@ -277,8 +275,6 @@ sycl::event transpose_level(const typename committed_descriptor_impl +template std::vector compute_level( const typename committed_descriptor_impl::kernel_data_struct& kd_struct, const TIn& input, Scalar* output, const TIn& input_imag, Scalar* output_imag, const Scalar* twiddles_ptr, @@ -380,7 +375,7 @@ std::vector compute_level( #endif PORTFFT_LOG_TRACE("Launching kernel for global implementation with global_size", global_range, "local_size", local_range); - cgh.parallel_for>( + cgh.parallel_for>( 
sycl::nd_range<1>(sycl::range<1>(static_cast(global_range)), sycl::range<1>(static_cast(local_range))), [= @@ -394,11 +389,11 @@ std::vector compute_level( s, global_logging_config, #endif it}; - dispatch_level( - &in_acc_or_usm[0] + input_batch_offset, offset_output, &in_imag_acc_or_usm[0] + input_batch_offset, - offset_output_imag, subimpl_twiddles, multipliers_between_factors, &loc_for_input[0], - &loc_for_twiddles[0], &loc_for_modifier[0], factors_triple, inner_batches, inclusive_scan, batch_size, - global_data, kh); + dispatch_level(&in_acc_or_usm[0] + input_batch_offset, offset_output, + &in_imag_acc_or_usm[0] + input_batch_offset, offset_output_imag, + subimpl_twiddles, multipliers_between_factors, &loc_for_input[0], + &loc_for_twiddles[0], &loc_for_modifier[0], factors_triple, + inner_batches, inclusive_scan, batch_size, global_data, kh); }); })); } diff --git a/src/portfft/common/workgroup.hpp b/src/portfft/common/workgroup.hpp index 612029f0..25038527 100644 --- a/src/portfft/common/workgroup.hpp +++ b/src/portfft/common/workgroup.hpp @@ -56,7 +56,6 @@ namespace detail { /** * Calculate all dfts in one dimension of the data stored in local memory. * - * @tparam LayoutIn Input Layout * @tparam SubgroupSize Size of the subgroup * @tparam LocalT The type of the local view * @tparam T Scalar type @@ -73,7 +72,7 @@ namespace detail { * @param stride_within_dft Stride between elements of each DFT - also the number of the DFTs in the inner dimension * @param ndfts_in_outer_dimension Number of DFTs in outer dimension * @param storage complex storage: interleaved or split - * @param layout_in Input Layout + * @param input_layout the layout of the input data of the transforms * @param multiply_on_load Whether the input data is multiplied with some data array before fft computation. * @param MultiplyOnStore Whether the input data is multiplied with some data array after fft computation. 
* @param ApplyScaleFactor Whether or not the scale factor is applied @@ -86,7 +85,7 @@ __attribute__((always_inline)) inline void dimension_dft( LocalT loc, T* loc_twiddles, const T* wg_twiddles, T scaling_factor, Idx max_num_batches_in_local_mem, Idx batch_num_in_local, const T* load_modifier_data, const T* store_modifier_data, IdxGlobal batch_num_in_kernel, Idx dft_size, Idx stride_within_dft, Idx ndfts_in_outer_dimension, complex_storage storage, - detail::layout layout_in, detail::elementwise_multiply multiply_on_load, + detail::layout input_layout, detail::elementwise_multiply multiply_on_load, detail::elementwise_multiply multiply_on_store, detail::apply_scale_factor apply_scale_factor, detail::complex_conjugate conjugate_on_load, detail::complex_conjugate conjugate_on_store, global_data_struct<1> global_data) { @@ -149,7 +148,7 @@ __attribute__((always_inline)) inline void dimension_dft( working = working && static_cast(global_data.sg.get_local_linear_id()) < max_working_tid_in_sg; } if (working) { - if (layout_in == detail::layout::BATCH_INTERLEAVED) { + if (input_layout == detail::layout::BATCH_INTERLEAVED) { global_data.log_message_global(__func__, "loading transposed data from local to private memory"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { detail::strided_view local_view{ @@ -249,7 +248,7 @@ __attribute__((always_inline)) inline void dimension_dft( } } global_data.log_dump_private("data in registers after computation:", priv, 2 * fact_wi); - if (layout_in == detail::layout::BATCH_INTERLEAVED) { + if (input_layout == detail::layout::BATCH_INTERLEAVED) { global_data.log_message_global(__func__, "storing transposed data from private to local memory"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { detail::strided_view local_view{ @@ -313,7 +312,7 @@ __attribute__((always_inline)) inline void dimension_dft( * @param N Smaller factor of the Problem size * @param M Larger factor of the problem size * @param storage complex storage: interleaved or split - * @param layout_in Whether or not the input is transposed + * @param input_layout the layout of the input data of the transforms * @param multiply_on_load Whether the input data is multiplied with some data array before fft computation. * @param multiply_on_store Whether the input data is multiplied with some data array after fft computation. 
* @param apply_scale_factor Whether or not the scale factor is applied @@ -325,7 +324,7 @@ template PORTFFT_INLINE void wg_dft(LocalT loc, T* loc_twiddles, const T* wg_twiddles, T scaling_factor, Idx max_num_batches_in_local_mem, Idx batch_num_in_local, IdxGlobal batch_num_in_kernel, const T* load_modifier_data, const T* store_modifier_data, Idx fft_size, Idx N, Idx M, - complex_storage storage, detail::layout layout_in, + complex_storage storage, detail::layout input_layout, detail::elementwise_multiply multiply_on_load, detail::elementwise_multiply multiply_on_store, detail::apply_scale_factor apply_scale_factor, detail::complex_conjugate conjugate_on_load, @@ -336,14 +335,14 @@ PORTFFT_INLINE void wg_dft(LocalT loc, T* loc_twiddles, const T* wg_twiddles, T // column-wise DFTs detail::dimension_dft( loc, loc_twiddles + (2 * M), nullptr, 1, max_num_batches_in_local_mem, batch_num_in_local, load_modifier_data, - store_modifier_data, batch_num_in_kernel, N, M, 1, storage, layout_in, multiply_on_load, + store_modifier_data, batch_num_in_kernel, N, M, 1, storage, input_layout, multiply_on_load, detail::elementwise_multiply::NOT_APPLIED, detail::apply_scale_factor::NOT_APPLIED, conjugate_on_load, detail::complex_conjugate::NOT_APPLIED, global_data); sycl::group_barrier(global_data.it.get_group()); // row-wise DFTs, including twiddle multiplications and scaling detail::dimension_dft( loc, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, batch_num_in_local, - load_modifier_data, store_modifier_data, batch_num_in_kernel, M, 1, N, storage, layout_in, + load_modifier_data, store_modifier_data, batch_num_in_kernel, M, 1, N, storage, input_layout, detail::elementwise_multiply::NOT_APPLIED, multiply_on_store, apply_scale_factor, detail::complex_conjugate::NOT_APPLIED, conjugate_on_store, global_data); global_data.log_message_global(__func__, "exited"); diff --git a/src/portfft/dispatcher/global_dispatcher.hpp b/src/portfft/dispatcher/global_dispatcher.hpp index 5dd6f56a..8b67e55a 100644 --- a/src/portfft/dispatcher/global_dispatcher.hpp +++ b/src/portfft/dispatcher/global_dispatcher.hpp @@ -216,9 +216,10 @@ struct committed_descriptor_impl::calculate_twiddles_struct::inn if (counter < kernels.size() - 1) { kernel_data.local_mem_required = static_cast(1); } else { - kernel_data.local_mem_required = desc.num_scalars_in_local_mem( + kernel_data.local_mem_required = desc.num_scalars_in_local_mem( detail::level::WORKITEM, static_cast(factors_idx_global.at(counter)), - kernel_data.used_sg_size, {static_cast(factors_idx_global.at(counter))}, num_sgs_in_wg); + kernel_data.used_sg_size, {static_cast(factors_idx_global.at(counter))}, num_sgs_in_wg, + layout::PACKED); } auto [global_range, local_range] = detail::get_launch_params(factors_idx_global.at(counter), sub_batches.at(counter), detail::level::WORKITEM, @@ -231,13 +232,15 @@ struct committed_descriptor_impl::calculate_twiddles_struct::inn IdxGlobal factor_sg = detail::factorize_sg(factors_idx_global.at(counter), kernel_data.used_sg_size); IdxGlobal factor_wi = factors_idx_global.at(counter) / factor_sg; if (counter < kernels.size() - 1) { - kernel_data.local_mem_required = desc.num_scalars_in_local_mem( + kernel_data.local_mem_required = desc.num_scalars_in_local_mem( detail::level::SUBGROUP, static_cast(factors_idx_global.at(counter)), - kernel_data.used_sg_size, {static_cast(factor_wi), static_cast(factor_sg)}, num_sgs_in_wg); + kernel_data.used_sg_size, {static_cast(factor_wi), static_cast(factor_sg)}, num_sgs_in_wg, + 
layout::BATCH_INTERLEAVED); } else { - kernel_data.local_mem_required = desc.num_scalars_in_local_mem( + kernel_data.local_mem_required = desc.num_scalars_in_local_mem( detail::level::SUBGROUP, static_cast(factors_idx_global.at(counter)), - kernel_data.used_sg_size, {static_cast(factor_wi), static_cast(factor_sg)}, num_sgs_in_wg); + kernel_data.used_sg_size, {static_cast(factor_wi), static_cast(factor_sg)}, num_sgs_in_wg, + layout::PACKED); } auto [global_range, local_range] = detail::get_launch_params(factors_idx_global.at(counter), sub_batches.at(counter), detail::level::SUBGROUP, @@ -278,11 +281,10 @@ struct committed_descriptor_impl::set_spec_constants_struct::inn }; template -template -struct committed_descriptor_impl::num_scalars_in_local_mem_struct::inner { +template +struct committed_descriptor_impl::num_scalars_in_local_mem_struct::inner { static std::size_t execute(committed_descriptor_impl& /*desc*/, std::size_t /*length*/, Idx /*used_sg_size*/, - const std::vector& /*factors*/, Idx& /*num_sgs_per_wg*/) { + const std::vector& /*factors*/, Idx& /*num_sgs_per_wg*/, layout /*input_layout*/) { PORTFFT_LOG_FUNCTION_ENTRY(); // No work required as all work done in calculate_twiddles; return 0; @@ -290,14 +292,14 @@ struct committed_descriptor_impl::num_scalars_in_local_mem_struc }; template -template +template template -struct committed_descriptor_impl::run_kernel_struct::run_kernel_struct::inner { static sycl::event execute(committed_descriptor_impl& desc, const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, const std::vector& dependencies, IdxGlobal n_transforms, IdxGlobal input_offset, IdxGlobal output_offset, dimension_struct& dimension_data, - direction compute_direction) { + direction compute_direction, layout /*input_layout*/) { PORTFFT_LOG_FUNCTION_ENTRY(); complex_storage storage = desc.params.complex_storage; const IdxGlobal vec_size = storage == complex_storage::INTERLEAVED_COMPLEX ? 
2 : 1; @@ -327,8 +329,7 @@ struct committed_descriptor_impl::run_kernel_struct( + l2_events = detail::compute_level( kernel0, in, desc.scratch_ptr_1.get(), in_imag, desc.scratch_ptr_1.get() + imag_offset, twiddles_ptr, factors_and_scan, intermediate_twiddles_offset, impl_twiddle_offset, vec_size * static_cast(i) * committed_size + input_offset, committed_size, @@ -344,25 +345,16 @@ struct committed_descriptor_impl::run_kernel_struct(factor_num) == dimension_data.num_factors - 1) { PORTFFT_LOG_TRACE("This is the last kernel"); - l2_events = detail::compute_level( - current_kernel, desc.scratch_ptr_1.get(), desc.scratch_ptr_1.get(), - desc.scratch_ptr_1.get() + imag_offset, desc.scratch_ptr_1.get() + imag_offset, twiddles_ptr, - factors_and_scan, intermediate_twiddles_offset, impl_twiddle_offset, 0, committed_size, - static_cast(max_batches_in_l2), static_cast(num_batches), static_cast(i), - static_cast(factor_num), dimension_data.num_factors, storage, l2_events, desc.queue); - } else { - l2_events = detail::compute_level( - current_kernel, desc.scratch_ptr_1.get(), desc.scratch_ptr_1.get(), - desc.scratch_ptr_1.get() + imag_offset, desc.scratch_ptr_1.get() + imag_offset, twiddles_ptr, - factors_and_scan, intermediate_twiddles_offset, impl_twiddle_offset, 0, committed_size, - static_cast(max_batches_in_l2), static_cast(num_batches), static_cast(i), - static_cast(factor_num), dimension_data.num_factors, storage, l2_events, desc.queue); - intermediate_twiddles_offset += 2 * current_kernel.batch_size * static_cast(current_kernel.length); - impl_twiddle_offset += - detail::increment_twiddle_offset(current_kernel.level, static_cast(current_kernel.length)); } + l2_events = detail::compute_level( + current_kernel, desc.scratch_ptr_1.get(), desc.scratch_ptr_1.get(), desc.scratch_ptr_1.get() + imag_offset, + desc.scratch_ptr_1.get() + imag_offset, twiddles_ptr, factors_and_scan, intermediate_twiddles_offset, + impl_twiddle_offset, 0, committed_size, static_cast(max_batches_in_l2), + static_cast(num_batches), static_cast(i), static_cast(factor_num), + dimension_data.num_factors, storage, l2_events, desc.queue); + intermediate_twiddles_offset += 2 * current_kernel.batch_size * static_cast(current_kernel.length); + impl_twiddle_offset += + detail::increment_twiddle_offset(current_kernel.level, static_cast(current_kernel.length)); detail::dump_device(desc.queue, "after factor:", desc.scratch_ptr_1.get(), desc.params.number_of_transforms * dimension_data.length * 2, l2_events); } diff --git a/src/portfft/dispatcher/subgroup_dispatcher.hpp b/src/portfft/dispatcher/subgroup_dispatcher.hpp index 700408cf..aebc8629 100644 --- a/src/portfft/dispatcher/subgroup_dispatcher.hpp +++ b/src/portfft/dispatcher/subgroup_dispatcher.hpp @@ -60,8 +60,6 @@ IdxGlobal get_global_size_subgroup(IdxGlobal n_transforms, Idx factor_sg, Idx su /** * Implementation of FFT for sizes that can be done by a subgroup. * - * @tparam LayoutIn Input Layout - * @tparam LayoutOut Output Layout * @tparam SubgroupSize size of the subgroup * @tparam T type of the scalar used for computations * @param input pointer to global memory containing input data. 
If complex storage (from
@@ -84,22 +82,33 @@ IdxGlobal get_global_size_subgroup(IdxGlobal n_transforms, Idx factor_sg, Idx subgroup_size,
  * @param loc_load_modifier Pointer to load modifier data in local memory
  * @param loc_store_modifier Pointer to store modifier data in local memory
  */
-template <detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize, typename T>
+template <Idx SubgroupSize, typename T>
 PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag, T* output_imag, T* loc,
                                   T* loc_twiddles, IdxGlobal n_transforms, const T* twiddles,
                                   global_data_struct<1> global_data, sycl::kernel_handler& kh,
                                   const T* load_modifier_data = nullptr, const T* store_modifier_data = nullptr,
                                   T* loc_load_modifier = nullptr, T* loc_store_modifier = nullptr) {
-  complex_storage storage = kh.get_specialization_constant<detail::SpecConstComplexStorage>();
-  detail::elementwise_multiply multiply_on_load = kh.get_specialization_constant<detail::SpecConstMultiplyOnLoad>();
-  detail::elementwise_multiply multiply_on_store = kh.get_specialization_constant<detail::SpecConstMultiplyOnStore>();
-  detail::apply_scale_factor apply_scale_factor = kh.get_specialization_constant<detail::SpecConstApplyScaleFactor>();
-  detail::complex_conjugate conjugate_on_load = kh.get_specialization_constant<detail::SpecConstConjugateOnLoad>();
-  detail::complex_conjugate conjugate_on_store = kh.get_specialization_constant<detail::SpecConstConjugateOnStore>();
-  T scaling_factor = kh.get_specialization_constant<detail::get_spec_constant_scale<T>()>();
+  const complex_storage storage = kh.get_specialization_constant<detail::SpecConstComplexStorage>();
+  const detail::elementwise_multiply multiply_on_load =
+      kh.get_specialization_constant<detail::SpecConstMultiplyOnLoad>();
+  const detail::elementwise_multiply multiply_on_store =
+      kh.get_specialization_constant<detail::SpecConstMultiplyOnStore>();
+  const detail::apply_scale_factor apply_scale_factor =
+      kh.get_specialization_constant<detail::SpecConstApplyScaleFactor>();
+  const detail::complex_conjugate conjugate_on_load =
+      kh.get_specialization_constant<detail::SpecConstConjugateOnLoad>();
+  const detail::complex_conjugate conjugate_on_store =
+      kh.get_specialization_constant<detail::SpecConstConjugateOnStore>();
+  const T scaling_factor = kh.get_specialization_constant<detail::get_spec_constant_scale<T>()>();
   const Idx factor_wi = kh.get_specialization_constant<detail::SubgroupFactorWISpecConst>();
   const Idx factor_sg = kh.get_specialization_constant<detail::SubgroupFactorSGSpecConst>();
+  const IdxGlobal input_distance = kh.get_specialization_constant<detail::SpecConstInputDistance>();
+  const IdxGlobal output_distance = kh.get_specialization_constant<detail::SpecConstOutputDistance>();
+
+  const bool input_batch_interleaved = input_distance == 1;
+  const bool output_batch_interleaved = output_distance == 1;
+
   global_data.log_message_global(__func__, "entered", "FactorWI", factor_wi, "FactorSG", factor_sg, "n_transforms",
                                  n_transforms);
   const Idx n_reals_per_wi = 2 * factor_wi;
@@ -136,7 +145,7 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag, T* output_imag,
   IdxGlobal id_of_fft_in_kernel;
   IdxGlobal n_ffts_in_kernel;
-  if (LayoutIn == detail::layout::BATCH_INTERLEAVED) {
+  if (input_batch_interleaved) {
     id_of_fft_in_kernel = static_cast<IdxGlobal>(global_data.it.get_group(0) * global_data.it.get_local_range(0)) / 2;
     n_ffts_in_kernel = static_cast<IdxGlobal>(global_data.it.get_group_range(0)) * local_size / 2;
   } else {
@@ -159,7 +168,7 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag, T* output_imag,
     bool working = subgroup_local_id < max_wis_working && i < n_transforms;
     Idx n_ffts_worked_on_by_sg = sycl::min(static_cast<Idx>(n_transforms - i) + id_of_fft_in_sg, n_ffts_per_sg);

-    if (LayoutIn == detail::layout::BATCH_INTERLEAVED) {
+    if (input_batch_interleaved) {
       /**
        * Codepath taken if the input is transposed.
        * The number of batches that are loaded is equal to half of the workgroup size.
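The hunk above is the pattern the whole patch repeats: `subgroup_impl` drops its `LayoutIn`/`LayoutOut` template parameters and instead derives the layout at runtime from the input/output distance specialization constants, where a distance of 1 between consecutive transforms means the batches are interleaved. The four `run_kernel` instantiations that `dispatch_kernel_1d_helper` previously selected between collapse into one per subgroup size. A minimal sketch of the idea, with hypothetical names standing in for the portFFT types (illustration only, not part of the patch):

    // Before: one kernel instantiation per (LayoutIn, LayoutOut) pair.
    //   template <layout LayoutIn, layout LayoutOut, int SubgroupSize, typename T>
    //   void subgroup_impl(...);
    // After: one instantiation per subgroup size; the layout is a uniform runtime branch.
    template <int SubgroupSize, typename T>
    void subgroup_impl_sketch(const T* input, T* output, long input_distance, long output_distance) {
      const bool input_batch_interleaved = input_distance == 1;  // consecutive batches interleaved element-wise
      const bool output_batch_interleaved = output_distance == 1;
      if (input_batch_interleaved) {
        // strided loads: element i of consecutive batches sits adjacent in memory
      } else {
        // packed loads: each transform's data is contiguous
      }
      // ... compute the DFT, then branch the same way on output_batch_interleaved for the stores
      (void)input;
      (void)output;
      (void)output_batch_interleaved;
    }

Because every work-item in a launch takes the same side of the branch, the runtime check costs one uniform branch per kernel rather than a separate compiled kernel per layout combination.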
@@ -296,7 +305,7 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag if (working_inner) { global_data.log_dump_private("data in registers after scaling:", priv, n_reals_per_wi); } - if (SubgroupSize == factor_sg && LayoutOut == detail::layout::PACKED) { + if (SubgroupSize == factor_sg && !output_batch_interleaved) { if (working_inner) { global_data.log_message_global( __func__, "storing transposed data from private to global memory (SubgroupSize == FactorSG)"); @@ -326,7 +335,7 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag if (working_inner) { global_data.log_message_global(__func__, "storing transposed data from private to local memory (SubgroupSize != " - "FactorSG or LayoutOut == detail::layout::BATCH_INTERLEAVED)"); + "FactorSG or batch interleaved output layout)"); // Store back to local memory only if (storage == complex_storage::INTERLEAVED_COMPLEX) { detail::strided_view strided_local_view{loc_view, std::array{factor_sg, max_num_batches_local_mem}, @@ -346,13 +355,13 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag } } sycl::group_barrier(global_data.it.get_group()); - if (SubgroupSize != factor_sg || LayoutOut == detail::layout::BATCH_INTERLEAVED) { + if (SubgroupSize != factor_sg || output_batch_interleaved) { global_data.log_dump_local("computed data in local memory:", loc_view, n_reals_per_wi * factor_sg); // store back all loaded batches at once. - if (LayoutOut == detail::layout::PACKED) { + if (!output_batch_interleaved) { global_data.log_message_global(__func__, "storing transposed data from local to global memory (SubgroupSize != " - "FactorSG) with LayoutOut = detail::layout::PACKED"); + "FactorSG) with packed output layout"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { detail::md_view local_md_view2{loc_view, std::array{2 * max_num_batches_local_mem, 1, 2}}; detail::md_view output_view{output, std::array{2, 1, 2 * fft_size}, i * n_reals_per_fft}; @@ -369,9 +378,8 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag std::array{fft_size, num_batches_in_local_mem}); } } else { - global_data.log_message_global(__func__, - "storing transposed data from local memory to global memory with LayoutOut == " - "detail::layout::BATCH_INTERLEAVED"); + global_data.log_message_global( + __func__, "storing transposed data from local memory to global memory with batch interleaved layout"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { detail::md_view local_md_view2{loc_view, std::array{2 * max_num_batches_local_mem, 1}}; detail::md_view output_view{output, std::array{2 * n_transforms, static_cast(1)}, 2 * i}; @@ -490,13 +498,13 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag if (working) { global_data.log_dump_private("data in registers after scaling:", priv, n_reals_per_wi); } - if (factor_sg == SubgroupSize && LayoutOut == detail::layout::PACKED) { + if (factor_sg == SubgroupSize && !output_batch_interleaved) { // in this case we get fully coalesced memory access even without going through local memory // TODO we may want to tune maximal `FactorSG` for which we use direct stores. 
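          // Editorial illustration (not part of the patch): the direct-store guard above is now
          // `factor_sg == SubgroupSize && !output_batch_interleaved` instead of the compile-time
          // `factor_sg == SubgroupSize && LayoutOut == detail::layout::PACKED`. For example (assumed
          // values), with SubgroupSize == 16, a factorization with factor_sg == 16 and a packed output
          // stores results straight from registers to global memory, which the code notes is already
          // coalesced; with factor_sg == 8 or a batch-interleaved output, results are staged through
          // local memory first.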
if (working) { global_data.log_message_global(__func__, "storing transposed data from private to global memory (FactorSG == " - "SubgroupSize) and LayoutOut == detail::level::PACKED"); + "SubgroupSize) and packed layout"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { detail::strided_view output_view{output, static_cast(factor_sg), i * static_cast(n_reals_per_sg) + @@ -518,10 +526,9 @@ PORTFFT_INLINE void subgroup_impl(const T* input, T* output, const T* input_imag copy_wi(global_data, priv_imag_view, output_imag_view, factor_wi); } } - } else if (LayoutOut == detail::layout::BATCH_INTERLEAVED) { + } else if (output_batch_interleaved) { if (working) { - global_data.log_message_global( - __func__, "Storing data from private to Global with LayoutOut == detail::level::BATCH_INTERLEAVED"); + global_data.log_message_global(__func__, "Storing data from private to Global with batch interleaved layout"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { detail::strided_view output_view{output, std::array{static_cast(factor_sg), n_transforms}, std::array{static_cast(2 * id_of_wi_in_fft), 2 * i}}; @@ -613,14 +620,14 @@ struct committed_descriptor_impl::calculate_twiddles_struct::inn }; template -template +template template -struct committed_descriptor_impl::run_kernel_struct::run_kernel_struct::inner { static sycl::event execute(committed_descriptor_impl& desc, const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, const std::vector& dependencies, IdxGlobal n_transforms, IdxGlobal input_offset, IdxGlobal output_offset, dimension_struct& dimension_data, - direction compute_direction) { + direction compute_direction, layout input_layout) { PORTFFT_LOG_FUNCTION_ENTRY(); constexpr detail::memory Mem = std::is_pointer_v ? detail::memory::USM : detail::memory::BUFFER; auto& kernel_data = compute_direction == direction::FORWARD ? 
dimension_data.forward_kernels.at(0) @@ -628,8 +635,9 @@ struct committed_descriptor_impl::run_kernel_struct::execute( - desc, kernel_data.length, kernel_data.used_sg_size, kernel_data.factors, kernel_data.num_sgs_per_wg); + num_scalars_in_local_mem_struct::template inner::execute( + desc, kernel_data.length, kernel_data.used_sg_size, kernel_data.factors, kernel_data.num_sgs_per_wg, + input_layout); std::size_t global_size = static_cast(detail::get_global_size_subgroup( n_transforms, factor_sg, SubgroupSize, kernel_data.num_sgs_per_wg, desc.n_compute_units)); std::size_t twiddle_elements = 2 * kernel_data.length; @@ -648,7 +656,7 @@ struct committed_descriptor_impl::run_kernel_struct>( + cgh.parallel_for>( sycl::nd_range<1>{{global_size}, {static_cast(SubgroupSize * kernel_data.num_sgs_per_wg)}}, [= #ifdef PORTFFT_KERNEL_LOG @@ -662,10 +670,10 @@ struct committed_descriptor_impl::run_kernel_struct( - &in_acc_or_usm[0] + input_offset, &out_acc_or_usm[0] + output_offset, - &in_imag_acc_or_usm[0] + input_offset, &out_imag_acc_or_usm[0] + output_offset, &loc[0], - &loc_twiddles[0], n_transforms, twiddles, global_data, kh); + detail::subgroup_impl(&in_acc_or_usm[0] + input_offset, &out_acc_or_usm[0] + output_offset, + &in_imag_acc_or_usm[0] + input_offset, + &out_imag_acc_or_usm[0] + output_offset, &loc[0], &loc_twiddles[0], + n_transforms, twiddles, global_data, kh); global_data.log_message_global("Exiting subgroup kernel"); }); }); @@ -687,15 +695,15 @@ struct committed_descriptor_impl::set_spec_constants_struct::inn }; template -template +template struct committed_descriptor_impl::num_scalars_in_local_mem_struct::inner { + Dummy> { static std::size_t execute(committed_descriptor_impl& desc, std::size_t length, Idx used_sg_size, - const std::vector& factors, Idx& num_sgs_per_wg) { + const std::vector& factors, Idx& num_sgs_per_wg, layout input_layout) { PORTFFT_LOG_FUNCTION_ENTRY(); Idx dft_length = static_cast(length); Idx twiddle_bytes = 2 * dft_length * static_cast(sizeof(Scalar)); - if constexpr (LayoutIn == detail::layout::BATCH_INTERLEAVED) { + if (input_layout == detail::layout::BATCH_INTERLEAVED) { Idx padded_fft_bytes = detail::pad_local(2 * dft_length, Idx(1)) * static_cast(sizeof(Scalar)); Idx max_batches_in_local_mem = (desc.local_memory_size - twiddle_bytes) / padded_fft_bytes; Idx batches_per_sg = used_sg_size / 2; @@ -704,15 +712,15 @@ struct committed_descriptor_impl::num_scalars_in_local_mem_struc num_sgs_per_wg = num_sgs_required; Idx num_batches_in_local_mem = used_sg_size * num_sgs_per_wg / 2; return static_cast(detail::pad_local(2 * dft_length * num_batches_in_local_mem, 1)); - } else { - Idx factor_sg = factors[1]; - Idx n_ffts_per_sg = used_sg_size / factor_sg; - Idx num_scalars_per_sg = detail::pad_local(2 * dft_length * n_ffts_per_sg, 1); - Idx max_n_sgs = (desc.local_memory_size - twiddle_bytes) / static_cast(sizeof(Scalar)) / num_scalars_per_sg; - num_sgs_per_wg = std::min(Idx(PORTFFT_SGS_IN_WG), std::max(Idx(1), max_n_sgs)); - Idx res = num_scalars_per_sg * num_sgs_per_wg; - return static_cast(res); } + + Idx factor_sg = factors[1]; + Idx n_ffts_per_sg = used_sg_size / factor_sg; + Idx num_scalars_per_sg = detail::pad_local(2 * dft_length * n_ffts_per_sg, 1); + Idx max_n_sgs = (desc.local_memory_size - twiddle_bytes) / static_cast(sizeof(Scalar)) / num_scalars_per_sg; + num_sgs_per_wg = std::min(Idx(PORTFFT_SGS_IN_WG), std::max(Idx(1), max_n_sgs)); + Idx res = num_scalars_per_sg * num_sgs_per_wg; + return static_cast(res); } }; diff --git 
a/src/portfft/dispatcher/workgroup_dispatcher.hpp b/src/portfft/dispatcher/workgroup_dispatcher.hpp
index bca3ca6b..dbbca454 100644
--- a/src/portfft/dispatcher/workgroup_dispatcher.hpp
+++ b/src/portfft/dispatcher/workgroup_dispatcher.hpp
@@ -37,37 +37,32 @@ namespace detail {
  * Calculates the number of batches that will be loaded into local memory at any one time for the work-group
  * implementation.
  *
- * @tparam LayoutIn The input data layout
+ * @param is_batch_interleaved is the input data layout batch interleaved
  * @param workgroup_size The size of the work-group. Must be divisible by 2.
  */
-template <detail::layout LayoutIn>
-PORTFFT_INLINE constexpr Idx get_num_batches_in_local_mem_workgroup(Idx workgroup_size) noexcept {
-  if constexpr (LayoutIn == detail::layout::BATCH_INTERLEAVED) {
-    return workgroup_size / 2;
-  } else {
-    return 1;
-  }
+PORTFFT_INLINE constexpr Idx get_num_batches_in_local_mem_workgroup(bool is_batch_interleaved,
+                                                                    Idx workgroup_size) noexcept {
+  return is_batch_interleaved ? workgroup_size / 2 : 1;
 }

 /**
  * Calculates the global size needed for given problem.
  *
  * @tparam T type of the scalar used for computations
- * @tparam LayoutIn The input data layout
  * @param n_transforms number of transforms
  * @param subgroup_size size of subgroup used by the compute kernel
  * @param num_sgs_per_wg number of subgroups in a workgroup
  * @param n_compute_units number of compute units on target device
  * @return Number of elements of size T that need to fit into local memory
  */
-template <typename T, detail::layout LayoutIn>
-IdxGlobal get_global_size_workgroup(IdxGlobal n_transforms, Idx subgroup_size, Idx num_sgs_per_wg,
-                                    Idx n_compute_units) {
+template <typename T>
+IdxGlobal get_global_size_workgroup(IdxGlobal n_transforms, Idx subgroup_size, Idx num_sgs_per_wg, Idx n_compute_units,
+                                    layout input_layout) {
   PORTFFT_LOG_FUNCTION_ENTRY();
   Idx maximum_n_sgs = 8 * n_compute_units * 64;
   Idx maximum_n_wgs = maximum_n_sgs / num_sgs_per_wg;
   Idx wg_size = subgroup_size * num_sgs_per_wg;
-  Idx dfts_per_wg = get_num_batches_in_local_mem_workgroup<LayoutIn>(wg_size);
+  Idx dfts_per_wg = get_num_batches_in_local_mem_workgroup(input_layout == layout::BATCH_INTERLEAVED, wg_size);

   return static_cast<IdxGlobal>(wg_size) *
          sycl::min(static_cast<IdxGlobal>(maximum_n_wgs),
                    divide_ceil(n_transforms, static_cast<IdxGlobal>(dfts_per_wg)));
@@ -76,8 +71,6 @@ IdxGlobal get_global_size_workgroup(IdxGlobal n_transforms, Idx subgroup_size, Idx num_sgs_per_wg,
 /**
  * Implementation of FFT for sizes that can be done by a workgroup.
  *
- * @tparam LayoutIn Input Layout
- * @tparam LayoutOut Output Layout
  * @tparam SubgroupSize size of the subgroup
  * @tparam T Scalar type
  *
@@ -98,7 +91,7 @@ IdxGlobal get_global_size_workgroup(IdxGlobal n_transforms, Idx subgroup_size, Idx num_sgs_per_wg,
  * @param load_modifier_data Pointer to the load modifier data in global Memory
  * @param store_modifier_data Pointer to the store modifier data in global Memory
  */
-template <detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize, typename T>
+template <Idx SubgroupSize, typename T>
 PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_imag, T* output_imag, T* loc,
                                    T* loc_twiddles, IdxGlobal n_transforms, const T* twiddles,
                                    global_data_struct<1> global_data, sycl::kernel_handler& kh,
@@ -112,6 +105,11 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_imag, T* output_imag,
   T scaling_factor = kh.get_specialization_constant<detail::get_spec_constant_scale<T>()>();

   const Idx fft_size = kh.get_specialization_constant<detail::SpecConstFftSize>();
+  const IdxGlobal input_distance = kh.get_specialization_constant<detail::SpecConstInputDistance>();
+  const IdxGlobal output_distance = kh.get_specialization_constant<detail::SpecConstOutputDistance>();
+
+  const bool input_batch_interleaved = input_distance == 1;
+  const bool output_batch_interleaved = output_distance == 1;

   global_data.log_message_global(__func__, "entered", "fft_size", fft_size, "n_transforms", n_transforms);

   Idx num_workgroups = static_cast<Idx>(global_data.it.get_group_range(0));
@@ -128,8 +126,8 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_imag, T* output_imag,
   global2local<level::WORKGROUP, SubgroupSize>(global_data, twiddles, loc_twiddles, 2 * (factor_m + factor_n));
   global_data.log_dump_local("twiddles loaded to local memory:", loc_twiddles, 2 * (factor_m + factor_n));

-  Idx max_num_batches_in_local_mem =
-      get_num_batches_in_local_mem_workgroup<LayoutIn>(static_cast<Idx>(global_data.it.get_local_range(0)));
+  Idx max_num_batches_in_local_mem = get_num_batches_in_local_mem_workgroup(
+      input_batch_interleaved, static_cast<Idx>(global_data.it.get_local_range(0)));
   IdxGlobal first_batch_start = static_cast<IdxGlobal>(wg_id) * static_cast<IdxGlobal>(max_num_batches_in_local_mem);
   IdxGlobal num_batches_in_kernel =
@@ -139,7 +137,7 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_imag, T* output_imag,
   for (IdxGlobal batch_start_idx = first_batch_start; batch_start_idx < n_transforms;
        batch_start_idx += num_batches_in_kernel) {
     IdxGlobal offset = static_cast<IdxGlobal>(vec_size * fft_size) * batch_start_idx;
-    if (LayoutIn == detail::layout::BATCH_INTERLEAVED) {
+    if (input_batch_interleaved) {
       /**
        * In the transposed case, the data is laid out in the local memory column-wise, viewing it as a FFT_Size x
        * WG_SIZE / 2 matrix. Each column contains either the real or the complex component of the batch.
Loads WG_SIZE @@ -168,11 +166,11 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima for (Idx sub_batch = 0; sub_batch < num_batches_in_local_mem; sub_batch++) { wg_dft(loc_view, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, sub_batch, batch_start_idx, load_modifier_data, store_modifier_data, fft_size, factor_n, - factor_m, storage, LayoutIn, multiply_on_load, multiply_on_store, apply_scale_factor, - conjugate_on_load, conjugate_on_store, global_data); + factor_m, storage, layout::BATCH_INTERLEAVED, multiply_on_load, multiply_on_store, + apply_scale_factor, conjugate_on_load, conjugate_on_store, global_data); sycl::group_barrier(global_data.it.get_group()); } - if constexpr (LayoutOut == detail::layout::PACKED) { + if (!output_batch_interleaved) { global_data.log_message_global(__func__, "storing data from local to global memory (with 2 transposes)"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { detail::md_view loc_md_view2{loc_view, std::array{2, 1, 2 * max_num_batches_in_local_mem, @@ -193,7 +191,7 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima copy_group(global_data, loc_imag_view, output_imag_view, std::array{num_batches_in_local_mem, factor_m, factor_n}); } - } else { // LayoutOut == detail::layout::BATCH_INTERLEAVED + } else { // batch interleaved layout out if (storage == complex_storage::INTERLEAVED_COMPLEX) { detail::md_view loc_md_view2{ loc_view, std::array{2 * max_num_batches_in_local_mem, 2 * max_num_batches_in_local_mem * factor_m, 1}}; @@ -220,7 +218,7 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima } } sycl::group_barrier(global_data.it.get_group()); - } else { // LayoutIn == detail::layout::PACKED + } else { // packed input layout global_data.log_message_global(__func__, "loading non-transposed data from global to local memory"); if (storage == complex_storage::INTERLEAVED_COMPLEX) { global2local(global_data, input, loc_view, 2 * fft_size, offset); @@ -232,12 +230,12 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima sycl::group_barrier(global_data.it.get_group()); wg_dft(loc_view, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, 0, batch_start_idx, load_modifier_data, store_modifier_data, fft_size, factor_n, factor_m, - storage, LayoutIn, multiply_on_load, multiply_on_store, apply_scale_factor, + storage, layout::PACKED, multiply_on_load, multiply_on_store, apply_scale_factor, conjugate_on_load, conjugate_on_store, global_data); sycl::group_barrier(global_data.it.get_group()); global_data.log_message_global(__func__, "storing non-transposed data from local to global memory"); // transposition for WG CT - if (LayoutOut == detail::layout::PACKED) { + if (!output_batch_interleaved) { if (storage == complex_storage::INTERLEAVED_COMPLEX) { detail::md_view local_md_view2{loc_view, std::array{1, 2, 2 * factor_m}}; detail::md_view output_view{output, std::array{1, 2 * factor_n, 2}, offset}; @@ -274,31 +272,27 @@ PORTFFT_INLINE void workgroup_impl(const T* input, T* output, const T* input_ima } template -template +template template -struct committed_descriptor_impl::run_kernel_struct::run_kernel_struct::inner { static sycl::event execute(committed_descriptor_impl& desc, const TIn& in, TOut& out, const TIn& in_imag, TOut& out_imag, const std::vector& dependencies, IdxGlobal n_transforms, IdxGlobal input_offset, IdxGlobal output_offset, dimension_struct& dimension_data, - direction 
compute_direction) { + direction compute_direction, layout input_layout) { PORTFFT_LOG_FUNCTION_ENTRY(); auto& kernel_data = compute_direction == direction::FORWARD ? dimension_data.forward_kernels.at(0) : dimension_data.backward_kernels.at(0); - Idx num_batches_in_local_mem = [=]() { - if constexpr (LayoutIn == detail::layout::BATCH_INTERLEAVED) { - return kernel_data.used_sg_size * PORTFFT_SGS_IN_WG / 2; - } else { - return 1; - } - }(); + Idx num_batches_in_local_mem = + input_layout == layout::BATCH_INTERLEAVED ? kernel_data.used_sg_size * PORTFFT_SGS_IN_WG / 2 : 1; constexpr detail::memory Mem = std::is_pointer_v ? detail::memory::USM : detail::memory::BUFFER; Scalar* twiddles = kernel_data.twiddles_forward.get(); std::size_t local_elements = - num_scalars_in_local_mem_struct::template inner::execute( - desc, kernel_data.length, kernel_data.used_sg_size, kernel_data.factors, kernel_data.num_sgs_per_wg); - std::size_t global_size = static_cast(detail::get_global_size_workgroup( - n_transforms, SubgroupSize, kernel_data.num_sgs_per_wg, desc.n_compute_units)); + num_scalars_in_local_mem_struct::template inner::execute( + desc, kernel_data.length, kernel_data.used_sg_size, kernel_data.factors, kernel_data.num_sgs_per_wg, + input_layout); + std::size_t global_size = static_cast(detail::get_global_size_workgroup( + n_transforms, SubgroupSize, kernel_data.num_sgs_per_wg, desc.n_compute_units, input_layout)); const Idx bank_lines_per_pad = bank_lines_per_pad_wg(2 * static_cast(sizeof(Scalar)) * kernel_data.factors[2] * kernel_data.factors[3]); std::size_t sg_twiddles_offset = static_cast( @@ -316,7 +310,7 @@ struct committed_descriptor_impl::run_kernel_struct>( + cgh.parallel_for>( sycl::nd_range<1>{{global_size}, {static_cast(SubgroupSize * PORTFFT_SGS_IN_WG)}}, [= #ifdef PORTFFT_KERNEL_LOG @@ -330,10 +324,10 @@ struct committed_descriptor_impl::run_kernel_struct( - &in_acc_or_usm[0] + input_offset, &out_acc_or_usm[0] + output_offset, - &in_imag_acc_or_usm[0] + input_offset, &out_imag_acc_or_usm[0] + output_offset, &loc[0], - &loc[0] + sg_twiddles_offset, n_transforms, twiddles, global_data, kh); + detail::workgroup_impl(&in_acc_or_usm[0] + input_offset, &out_acc_or_usm[0] + output_offset, + &in_imag_acc_or_usm[0] + input_offset, + &out_imag_acc_or_usm[0] + output_offset, &loc[0], + &loc[0] + sg_twiddles_offset, n_transforms, twiddles, global_data, kh); global_data.log_message_global("Exiting workgroup kernel"); }); }); @@ -353,17 +347,17 @@ struct committed_descriptor_impl::set_spec_constants_struct::inn }; template -template +template struct committed_descriptor_impl::num_scalars_in_local_mem_struct::inner { + Dummy> { static std::size_t execute(committed_descriptor_impl& /*desc*/, std::size_t length, Idx used_sg_size, - const std::vector& factors, Idx& /*num_sgs_per_wg*/) { + const std::vector& factors, Idx& /*num_sgs_per_wg*/, layout input_layout) { PORTFFT_LOG_FUNCTION_ENTRY(); std::size_t n = static_cast(factors[0]) * static_cast(factors[1]); std::size_t m = static_cast(factors[2]) * static_cast(factors[3]); // working memory + twiddles for subgroup impl for the two sizes - Idx num_batches_in_local_mem = - detail::get_num_batches_in_local_mem_workgroup(used_sg_size * PORTFFT_SGS_IN_WG); + Idx num_batches_in_local_mem = detail::get_num_batches_in_local_mem_workgroup( + input_layout == layout::BATCH_INTERLEAVED, used_sg_size * PORTFFT_SGS_IN_WG); return detail::pad_local(static_cast(2 * num_batches_in_local_mem) * length, bank_lines_per_pad_wg(2 * static_cast(sizeof(Scalar)) * m)) + 2 * 
(m + n);
diff --git a/src/portfft/dispatcher/workitem_dispatcher.hpp b/src/portfft/dispatcher/workitem_dispatcher.hpp
index 9ab4145f..28b6962b 100644
--- a/src/portfft/dispatcher/workitem_dispatcher.hpp
+++ b/src/portfft/dispatcher/workitem_dispatcher.hpp
@@ -350,21 +350,22 @@ PORTFFT_INLINE void workitem_impl(const T* input, T* output, const T* input_imag, T* output_imag,
 }

 template <typename Scalar, domain Domain>
-template <detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize, typename TIn, typename TOut>
+template <Idx SubgroupSize, typename TIn, typename TOut>
 template <typename Dummy>
-struct committed_descriptor_impl<Scalar, Domain>::run_kernel_struct<LayoutIn, LayoutOut, SubgroupSize, TIn,
-                                                                    TOut>::inner<detail::level::WORKITEM, Dummy> {
+struct committed_descriptor_impl<Scalar, Domain>::run_kernel_struct<SubgroupSize, TIn,
+                                                                    TOut>::inner<detail::level::WORKITEM, Dummy> {
   static sycl::event execute(committed_descriptor_impl& desc, const TIn& in, TOut& out, const TIn& in_imag,
                              TOut& out_imag, const std::vector<sycl::event>& dependencies, IdxGlobal n_transforms,
                              IdxGlobal input_offset, IdxGlobal output_offset, dimension_struct& dimension_data,
-                             direction compute_direction) {
+                             direction compute_direction, layout input_layout) {
     PORTFFT_LOG_FUNCTION_ENTRY();
     constexpr detail::memory Mem = std::is_pointer_v<TOut> ? detail::memory::USM : detail::memory::BUFFER;
     auto& kernel_data = compute_direction == direction::FORWARD ? dimension_data.forward_kernels.at(0)
                                                                 : dimension_data.backward_kernels.at(0);
     std::size_t local_elements =
-        num_scalars_in_local_mem_struct::template inner<detail::level::WORKITEM, LayoutIn, Dummy>::execute(
-            desc, kernel_data.length, kernel_data.used_sg_size, kernel_data.factors, kernel_data.num_sgs_per_wg);
+        num_scalars_in_local_mem_struct::template inner<detail::level::WORKITEM, Dummy>::execute(
+            desc, kernel_data.length, kernel_data.used_sg_size, kernel_data.factors, kernel_data.num_sgs_per_wg,
+            input_layout);
     std::size_t global_size = static_cast<std::size_t>(detail::get_global_size_workitem<Scalar>(
         n_transforms, SubgroupSize, kernel_data.num_sgs_per_wg, desc.n_compute_units));
@@ -381,7 +382,7 @@ struct committed_descriptor_impl<Scalar, Domain>::run_kernel_struct<
-      cgh.parallel_for<detail::workitem_kernel<Scalar, Domain, Mem, LayoutIn, LayoutOut, SubgroupSize>>(
+      cgh.parallel_for<detail::workitem_kernel<Scalar, Domain, Mem, SubgroupSize>>(
          sycl::nd_range<1>{{global_size}, {static_cast<std::size_t>(SubgroupSize * kernel_data.num_sgs_per_wg)}},
          [=
 #ifdef PORTFFT_KERNEL_LOG
@@ -418,11 +419,11 @@ struct committed_descriptor_impl<Scalar, Domain>::set_spec_constants_struct::inner<
 };

 template <typename Scalar, domain Domain>
-template <detail::layout LayoutIn, typename Dummy>
+template <typename Dummy>
 struct committed_descriptor_impl<Scalar, Domain>::num_scalars_in_local_mem_struct::inner<detail::level::WORKITEM,
-                                                                                         LayoutIn, Dummy> {
+                                                                                         Dummy> {
   static std::size_t execute(committed_descriptor_impl& desc, std::size_t length, Idx used_sg_size,
-                             const std::vector<Idx>& /*factors*/, Idx& num_sgs_per_wg) {
+                             const std::vector<Idx>& /*factors*/, Idx& num_sgs_per_wg, layout /*input_layout*/) {
     PORTFFT_LOG_FUNCTION_ENTRY();
     Idx num_scalars_per_sg = detail::pad_local(2 * static_cast<Idx>(length) * used_sg_size, 1);
     Idx max_n_sgs = desc.local_memory_size / static_cast<Idx>(sizeof(Scalar)) / num_scalars_per_sg;
diff --git a/src/portfft/utils.hpp b/src/portfft/utils.hpp
index bfe518ab..db837e3e 100644
--- a/src/portfft/utils.hpp
+++ b/src/portfft/utils.hpp
@@ -43,35 +43,23 @@ class transpose_kernel;
  * @tparam SubgroupSize size of the subgroup
  * @return vector of kernel ids
  */
-template
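Taken together, the per-level `num_scalars_in_local_mem_struct::inner` and `run_kernel_struct::inner` specializations now receive the layout as a runtime `layout` argument rather than as a template parameter. A hedged usage sketch of the new call shape, mirroring the calls visible in the hunks above (`desc`, `kernel_data`, and the surrounding variables are stand-ins for the real members, not verbatim portFFT code):

    // Sketch: sizing local memory for one kernel under the runtime-layout interface.
    Idx num_sgs_per_wg = PORTFFT_SGS_IN_WG;  // [out] parameter, set by the callee
    std::size_t scalars = desc.num_scalars_in_local_mem(
        detail::level::SUBGROUP, kernel_data.length, kernel_data.used_sg_size, kernel_data.factors,
        num_sgs_per_wg, input_batch_interleaved ? layout::BATCH_INTERLEAVED : layout::PACKED);
    std::size_t local_mem_bytes = scalars * sizeof(Scalar);  // compared against the device limit

The same call sites previously had to name the layout as a template argument, e.g. `num_scalars_in_local_mem<detail::layout::PACKED>(...)`, which forced every caller that only learned the layout at runtime to branch before the call.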