Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Workgroup strided transforms #143

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions src/portfft/committed_descriptor_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

#include "common/exceptions.hpp"
#include "common/subgroup.hpp"
#include "common/workgroup.hpp"
#include "defines.hpp"
#include "enums.hpp"
#include "specialization_constant.hpp"
Expand Down Expand Up @@ -234,18 +235,8 @@ class committed_descriptor_impl {
PORTFFT_LOG_TRACE("Prepared subgroup impl with factor_wi:", factor_wi, "and factor_sg:", factor_sg);
return {detail::level::SUBGROUP, {{detail::level::SUBGROUP, ids, factors}}};
}
IdxGlobal n_idx_global = detail::factorize(fft_size);
if (detail::can_cast_safely<IdxGlobal, Idx>(n_idx_global) &&
detail::can_cast_safely<IdxGlobal, Idx>(fft_size / n_idx_global)) {
if (n_idx_global == 1) {
throw unsupported_configuration("FFT size ", fft_size, " : Large Prime sized FFT currently is unsupported");
}
Idx n = static_cast<Idx>(n_idx_global);
Idx m = static_cast<Idx>(fft_size / n_idx_global);
Idx factor_sg_n = detail::factorize_sg(n, SubgroupSize);
Idx factor_wi_n = n / factor_sg_n;
Idx factor_sg_m = detail::factorize_sg(m, SubgroupSize);
Idx factor_wi_m = m / factor_sg_m;
if (auto wg_factorization = detail::factorize_for_wg<Scalar>(fft_size, SubgroupSize); wg_factorization) {
hjabird marked this conversation as resolved.
Show resolved Hide resolved
auto [factor_wi_n, factor_sg_n, factor_wi_m, factor_sg_m] = wg_factorization.value();
Idx temp_num_sgs_in_wg;
std::size_t local_memory_usage =
num_scalars_in_local_mem(detail::level::WORKGROUP, static_cast<std::size_t>(fft_size), SubgroupSize,
Expand All @@ -254,8 +245,7 @@ class committed_descriptor_impl {
sizeof(Scalar);
// Checks for PACKED layout only at the moment, as the other layout will not be supported
// by the global implementation. For such sizes, only PACKED layout will be supported
if (detail::fits_in_wi<Scalar>(factor_wi_n) && detail::fits_in_wi<Scalar>(factor_wi_m) &&
(local_memory_usage <= static_cast<std::size_t>(local_memory_size))) {
if (local_memory_usage <= static_cast<std::size_t>(local_memory_size)) {
factors.push_back(factor_wi_n);
factors.push_back(factor_sg_n);
factors.push_back(factor_wi_m);
Expand Down
47 changes: 47 additions & 0 deletions src/portfft/common/workgroup.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,17 @@
#ifndef PORTFFT_COMMON_WORKGROUP_HPP
#define PORTFFT_COMMON_WORKGROUP_HPP

#include <optional>

#include "helpers.hpp"
#include "logging.hpp"
#include "memory_views.hpp"
#include "portfft/defines.hpp"
#include "portfft/enums.hpp"
#include "portfft/traits.hpp"
#include "portfft/utils.hpp"
#include "subgroup.hpp"
#include "transfers.hpp"

namespace portfft {

Expand All @@ -53,6 +58,48 @@ constexpr T bank_lines_per_pad_wg(T row_size) {
}

namespace detail {

// struct for the result of factorize_for_wg
struct wg_factorization {
Idx factor_wi_n;
Idx factor_sg_n;
Idx factor_wi_m;
Idx factor_sg_m;
};

/**
*
* Calculate a valid factorization for workgroup dfts, assuming there is sufficient local memory.
*
* @tparam Scalar scalar type of the transform data
* @param fft_size the number of elements in the transforms
* @param subgroup_size the size of subgroup used for the transform
*
* @return a factorization for workgroup dft or null if the size won't work with the implemenation of workgroup dfts.
*/
FMarno marked this conversation as resolved.
Show resolved Hide resolved
template <typename Scalar>
inline std::optional<wg_factorization> factorize_for_wg(IdxGlobal fft_size, Idx subgroup_size) {
hjabird marked this conversation as resolved.
Show resolved Hide resolved
IdxGlobal n_idx_global = detail::factorize(fft_size);
hjabird marked this conversation as resolved.
Show resolved Hide resolved
if (n_idx_global == 1) {
throw unsupported_configuration("FFT size ", fft_size, " : Large Prime sized FFT currently is unsupported");
}
hjabird marked this conversation as resolved.
Show resolved Hide resolved
IdxGlobal m_idx_global = fft_size / n_idx_global;
if (detail::can_cast_safely<IdxGlobal, Idx>(n_idx_global) && detail::can_cast_safely<IdxGlobal, Idx>(m_idx_global)) {
Idx n = static_cast<Idx>(n_idx_global);
Idx m = static_cast<Idx>(m_idx_global);
Idx factor_sg_n = detail::factorize_sg(n, subgroup_size);
Idx factor_wi_n = n / factor_sg_n;
Idx factor_sg_m = detail::factorize_sg(m, subgroup_size);
Idx factor_wi_m = m / factor_sg_m;

if (fits_in_wi<Scalar>(factor_wi_n) && fits_in_wi<Scalar>(factor_wi_m)) {
return wg_factorization{factor_wi_n, factor_sg_n, factor_wi_m, factor_sg_m};
}
}

return std::nullopt;
}

/**
* Calculate all dfts in one dimension of the data stored in local memory.
*
Expand Down
4 changes: 2 additions & 2 deletions src/portfft/descriptor_validation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include <string_view>

#include "common/exceptions.hpp"
#include "common/subgroup.hpp"
#include "common/workgroup.hpp"
#include "enums.hpp"
#include "utils.hpp"

Expand Down Expand Up @@ -68,7 +68,7 @@ inline void validate_layout(const std::vector<std::size_t>& lengths, portfft::de
bool fits_subgroup = false;
for (auto sg_size : {PORTFFT_SUBGROUP_SIZES}) {
fits_subgroup =
fits_subgroup || portfft::detail::fits_in_sg<Scalar>(static_cast<IdxGlobal>(lengths.back()), sg_size);
fits_subgroup || portfft::detail::factorize_for_wg<Scalar>(static_cast<IdxGlobal>(lengths.back()), sg_size);
if (fits_subgroup) {
break;
}
Expand Down
Loading
Loading