Skip to content

Commit

Permalink
Remove layout templates from src (#141)
Browse files Browse the repository at this point in the history
  • Loading branch information
FMarno authored Feb 13, 2024
1 parent 324a88f commit 330f37b
Show file tree
Hide file tree
Showing 9 changed files with 230 additions and 312 deletions.
171 changes: 56 additions & 115 deletions src/portfft/committed_descriptor_impl.hpp

Large diffs are not rendered by default.

39 changes: 17 additions & 22 deletions src/portfft/common/global.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,6 @@ PORTFFT_INLINE inline IdxGlobal get_outer_batch_offset(const IdxGlobal* factors,
* Device function responsible for calling the corresponding sub-implementation
*
* @tparam Scalar Scalar type
* @tparam LayoutIn Input layout
* @tparam LayoutOut Output layout
* @tparam SubgroupSize Subgroup size
* @param input input pointer
* @param output output pointer
Expand All @@ -134,7 +132,7 @@ PORTFFT_INLINE inline IdxGlobal get_outer_batch_offset(const IdxGlobal* factors,
* @param global_data global data
* @param kh kernel handler
*/
template <typename Scalar, detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize>
template <typename Scalar, Idx SubgroupSize>
PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Scalar* input_imag, Scalar* output_imag,
const Scalar* implementation_twiddles, const Scalar* store_modifier_data,
Scalar* input_loc, Scalar* twiddles_loc, Scalar* store_modifier_loc,
Expand All @@ -156,16 +154,16 @@ PORTFFT_INLINE void dispatch_level(const Scalar* input, Scalar* output, const Sc
batch_size, global_data, kh, static_cast<const Scalar*>(nullptr),
store_modifier_data, static_cast<Scalar*>(nullptr), store_modifier_loc);
} else if (level == detail::level::SUBGROUP) {
subgroup_impl<SubgroupSize, LayoutIn, LayoutOut, Scalar>(
input + outer_batch_offset, output + outer_batch_offset, input_imag + outer_batch_offset,
output_imag + outer_batch_offset, input_loc, twiddles_loc, batch_size, implementation_twiddles, global_data,
kh, static_cast<const Scalar*>(nullptr), store_modifier_data, static_cast<Scalar*>(nullptr),
store_modifier_loc);
subgroup_impl<SubgroupSize, Scalar>(input + outer_batch_offset, output + outer_batch_offset,
input_imag + outer_batch_offset, output_imag + outer_batch_offset, input_loc,
twiddles_loc, batch_size, implementation_twiddles, global_data, kh,
static_cast<const Scalar*>(nullptr), store_modifier_data,
static_cast<Scalar*>(nullptr), store_modifier_loc);
} else if (level == detail::level::WORKGROUP) {
workgroup_impl<SubgroupSize, LayoutIn, LayoutOut, Scalar>(
input + outer_batch_offset, output + outer_batch_offset, input_imag + outer_batch_offset,
output_imag + outer_batch_offset, input_loc, twiddles_loc, batch_size, implementation_twiddles, global_data,
kh, static_cast<Scalar*>(nullptr), store_modifier_data);
workgroup_impl<SubgroupSize, Scalar>(input + outer_batch_offset, output + outer_batch_offset,
input_imag + outer_batch_offset, output_imag + outer_batch_offset, input_loc,
twiddles_loc, batch_size, implementation_twiddles, global_data, kh,
static_cast<Scalar*>(nullptr), store_modifier_data);
}
sycl::group_barrier(global_data.it.get_group());
}
Expand Down Expand Up @@ -277,8 +275,6 @@ sycl::event transpose_level(const typename committed_descriptor_impl<Scalar, Dom
* Prepares the launch of fft compute at a particular level
* @tparam Scalar Scalar type
* @tparam Domain Domain of FFT
* @tparam LayoutIn Input layout
* @tparam LayoutOut output layout
* @tparam SubgroupSize subgroup size
* @tparam TIn input type
* @param kd_struct associated kernel data struct with the factor
Expand All @@ -304,8 +300,7 @@ sycl::event transpose_level(const typename committed_descriptor_impl<Scalar, Dom
* @param queue queue
* @return vector events, one for each batch in l2
*/
template <typename Scalar, domain Domain, detail::layout LayoutIn, detail::layout LayoutOut, Idx SubgroupSize,
typename TIn>
template <typename Scalar, domain Domain, Idx SubgroupSize, typename TIn>
std::vector<sycl::event> compute_level(
const typename committed_descriptor_impl<Scalar, Domain>::kernel_data_struct& kd_struct, const TIn& input,
Scalar* output, const TIn& input_imag, Scalar* output_imag, const Scalar* twiddles_ptr,
Expand Down Expand Up @@ -380,7 +375,7 @@ std::vector<sycl::event> compute_level(
#endif
PORTFFT_LOG_TRACE("Launching kernel for global implementation with global_size", global_range, "local_size",
local_range);
cgh.parallel_for<global_kernel<Scalar, Domain, Mem, LayoutIn, LayoutOut, SubgroupSize>>(
cgh.parallel_for<global_kernel<Scalar, Domain, Mem, SubgroupSize>>(
sycl::nd_range<1>(sycl::range<1>(static_cast<std::size_t>(global_range)),
sycl::range<1>(static_cast<std::size_t>(local_range))),
[=
Expand All @@ -394,11 +389,11 @@ std::vector<sycl::event> compute_level(
s, global_logging_config,
#endif
it};
dispatch_level<Scalar, LayoutIn, LayoutOut, SubgroupSize>(
&in_acc_or_usm[0] + input_batch_offset, offset_output, &in_imag_acc_or_usm[0] + input_batch_offset,
offset_output_imag, subimpl_twiddles, multipliers_between_factors, &loc_for_input[0],
&loc_for_twiddles[0], &loc_for_modifier[0], factors_triple, inner_batches, inclusive_scan, batch_size,
global_data, kh);
dispatch_level<Scalar, SubgroupSize>(&in_acc_or_usm[0] + input_batch_offset, offset_output,
&in_imag_acc_or_usm[0] + input_batch_offset, offset_output_imag,
subimpl_twiddles, multipliers_between_factors, &loc_for_input[0],
&loc_for_twiddles[0], &loc_for_modifier[0], factors_triple,
inner_batches, inclusive_scan, batch_size, global_data, kh);
});
}));
}
Expand Down
17 changes: 8 additions & 9 deletions src/portfft/common/workgroup.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ namespace detail {
/**
* Calculate all dfts in one dimension of the data stored in local memory.
*
* @tparam LayoutIn Input Layout
* @tparam SubgroupSize Size of the subgroup
* @tparam LocalT The type of the local view
* @tparam T Scalar type
Expand All @@ -73,7 +72,7 @@ namespace detail {
* @param stride_within_dft Stride between elements of each DFT - also the number of the DFTs in the inner dimension
* @param ndfts_in_outer_dimension Number of DFTs in outer dimension
* @param storage complex storage: interleaved or split
* @param layout_in Input Layout
* @param input_layout the layout of the input data of the transforms
* @param multiply_on_load Whether the input data is multiplied with some data array before fft computation.
 * @param multiply_on_store Whether the input data is multiplied with some data array after fft computation.
 * @param apply_scale_factor Whether or not the scale factor is applied
Expand All @@ -86,7 +85,7 @@ __attribute__((always_inline)) inline void dimension_dft(
LocalT loc, T* loc_twiddles, const T* wg_twiddles, T scaling_factor, Idx max_num_batches_in_local_mem,
Idx batch_num_in_local, const T* load_modifier_data, const T* store_modifier_data, IdxGlobal batch_num_in_kernel,
Idx dft_size, Idx stride_within_dft, Idx ndfts_in_outer_dimension, complex_storage storage,
detail::layout layout_in, detail::elementwise_multiply multiply_on_load,
detail::layout input_layout, detail::elementwise_multiply multiply_on_load,
detail::elementwise_multiply multiply_on_store, detail::apply_scale_factor apply_scale_factor,
detail::complex_conjugate conjugate_on_load, detail::complex_conjugate conjugate_on_store,
global_data_struct<1> global_data) {
Expand Down Expand Up @@ -149,7 +148,7 @@ __attribute__((always_inline)) inline void dimension_dft(
working = working && static_cast<Idx>(global_data.sg.get_local_linear_id()) < max_working_tid_in_sg;
}
if (working) {
if (layout_in == detail::layout::BATCH_INTERLEAVED) {
if (input_layout == detail::layout::BATCH_INTERLEAVED) {
global_data.log_message_global(__func__, "loading transposed data from local to private memory");
if (storage == complex_storage::INTERLEAVED_COMPLEX) {
detail::strided_view local_view{
Expand Down Expand Up @@ -249,7 +248,7 @@ __attribute__((always_inline)) inline void dimension_dft(
}
}
global_data.log_dump_private("data in registers after computation:", priv, 2 * fact_wi);
if (layout_in == detail::layout::BATCH_INTERLEAVED) {
if (input_layout == detail::layout::BATCH_INTERLEAVED) {
global_data.log_message_global(__func__, "storing transposed data from private to local memory");
if (storage == complex_storage::INTERLEAVED_COMPLEX) {
detail::strided_view local_view{
Expand Down Expand Up @@ -313,7 +312,7 @@ __attribute__((always_inline)) inline void dimension_dft(
* @param N Smaller factor of the Problem size
* @param M Larger factor of the problem size
* @param storage complex storage: interleaved or split
* @param layout_in Whether or not the input is transposed
* @param input_layout the layout of the input data of the transforms
* @param multiply_on_load Whether the input data is multiplied with some data array before fft computation.
* @param multiply_on_store Whether the input data is multiplied with some data array after fft computation.
* @param apply_scale_factor Whether or not the scale factor is applied
Expand All @@ -325,7 +324,7 @@ template <Idx SubgroupSize, typename LocalT, typename T>
PORTFFT_INLINE void wg_dft(LocalT loc, T* loc_twiddles, const T* wg_twiddles, T scaling_factor,
Idx max_num_batches_in_local_mem, Idx batch_num_in_local, IdxGlobal batch_num_in_kernel,
const T* load_modifier_data, const T* store_modifier_data, Idx fft_size, Idx N, Idx M,
complex_storage storage, detail::layout layout_in,
complex_storage storage, detail::layout input_layout,
detail::elementwise_multiply multiply_on_load,
detail::elementwise_multiply multiply_on_store,
detail::apply_scale_factor apply_scale_factor, detail::complex_conjugate conjugate_on_load,
Expand All @@ -336,14 +335,14 @@ PORTFFT_INLINE void wg_dft(LocalT loc, T* loc_twiddles, const T* wg_twiddles, T
// column-wise DFTs
detail::dimension_dft<SubgroupSize, LocalT, T>(
loc, loc_twiddles + (2 * M), nullptr, 1, max_num_batches_in_local_mem, batch_num_in_local, load_modifier_data,
store_modifier_data, batch_num_in_kernel, N, M, 1, storage, layout_in, multiply_on_load,
store_modifier_data, batch_num_in_kernel, N, M, 1, storage, input_layout, multiply_on_load,
detail::elementwise_multiply::NOT_APPLIED, detail::apply_scale_factor::NOT_APPLIED, conjugate_on_load,
detail::complex_conjugate::NOT_APPLIED, global_data);
sycl::group_barrier(global_data.it.get_group());
// row-wise DFTs, including twiddle multiplications and scaling
detail::dimension_dft<SubgroupSize, LocalT, T>(
loc, loc_twiddles, wg_twiddles, scaling_factor, max_num_batches_in_local_mem, batch_num_in_local,
load_modifier_data, store_modifier_data, batch_num_in_kernel, M, 1, N, storage, layout_in,
load_modifier_data, store_modifier_data, batch_num_in_kernel, M, 1, N, storage, input_layout,
detail::elementwise_multiply::NOT_APPLIED, multiply_on_store, apply_scale_factor,
detail::complex_conjugate::NOT_APPLIED, conjugate_on_store, global_data);
global_data.log_message_global(__func__, "exited");
Expand Down
Loading

0 comments on commit 330f37b

Please sign in to comment.