Skip to content

Commit

Permalink
Make execute a method of Plan class
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Oct 11, 2024
1 parent fae1e28 commit 19fd10d
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 219 deletions.
8 changes: 4 additions & 4 deletions examples/06_1DFFT_reuse_plans/06_1DFFT_reuse_plans.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ int main(int argc, char* argv[]) {
int axis = -1;
KokkosFFT::Plan fft_plan(exec, xc2c, xc2c_hat,
KokkosFFT::Direction::forward, axis);
KokkosFFT::execute(fft_plan, xc2c, xc2c_hat);
fft_plan.execute(xc2c, xc2c_hat);

KokkosFFT::Plan ifft_plan(exec, xc2c_hat, xc2c_inv,
KokkosFFT::Direction::backward, axis);
KokkosFFT::execute(ifft_plan, xc2c_hat, xc2c_inv);
ifft_plan.execute(xc2c_hat, xc2c_inv);

// 1D R2C FFT
View1D<double> xr2c("xr2c", n0);
Expand All @@ -42,7 +42,7 @@ int main(int argc, char* argv[]) {

KokkosFFT::Plan rfft_plan(exec, xr2c, xr2c_hat,
KokkosFFT::Direction::forward, axis);
KokkosFFT::execute(rfft_plan, xr2c, xr2c_hat);
rfft_plan.execute(xr2c, xr2c_hat);

// 1D C2R FFT
View1D<Kokkos::complex<double> > xc2r("xc2r_hat", n0 / 2 + 1);
Expand All @@ -51,7 +51,7 @@ int main(int argc, char* argv[]) {

KokkosFFT::Plan irfft_plan(exec, xc2r, xc2r_hat,
KokkosFFT::Direction::backward, axis);
KokkosFFT::execute(irfft_plan, xc2r, xc2r_hat);
irfft_plan.execute(xc2r, xc2r_hat);
exec.fence();
}
Kokkos::finalize();
Expand Down
110 changes: 86 additions & 24 deletions fft/src/KokkosFFT_Plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,45 @@
#include "KokkosFFT_default_types.hpp"
#include "KokkosFFT_traits.hpp"
#include "KokkosFFT_transpose.hpp"
#include "KokkosFFT_normalization.hpp"
#include "KokkosFFT_padding.hpp"
#include "KokkosFFT_utils.hpp"

#if defined(KOKKOS_ENABLE_CUDA)
#include "KokkosFFT_Cuda_plans.hpp"
#include "KokkosFFT_Cuda_transform.hpp"
#ifdef ENABLE_HOST_AND_DEVICE
#include "KokkosFFT_Host_plans.hpp"
#include "KokkosFFT_Host_transform.hpp"
#endif
#elif defined(KOKKOS_ENABLE_HIP)
#if defined(KOKKOSFFT_ENABLE_TPL_ROCFFT)
#include "KokkosFFT_ROCM_plans.hpp"
#include "KokkosFFT_ROCM_transform.hpp"
#else
#include "KokkosFFT_HIP_plans.hpp"
#include "KokkosFFT_HIP_transform.hpp"
#endif
#ifdef ENABLE_HOST_AND_DEVICE
#include "KokkosFFT_Host_plans.hpp"
#include "KokkosFFT_Host_transform.hpp"
#endif
#elif defined(KOKKOS_ENABLE_SYCL)
#include "KokkosFFT_SYCL_plans.hpp"
#include "KokkosFFT_SYCL_transform.hpp"
#ifdef ENABLE_HOST_AND_DEVICE
#include "KokkosFFT_Host_plans.hpp"
#include "KokkosFFT_Host_transform.hpp"
#endif
#elif defined(KOKKOS_ENABLE_OPENMP)
#include "KokkosFFT_Host_plans.hpp"
#include "KokkosFFT_Host_transform.hpp"
#elif defined(KOKKOS_ENABLE_THREADS)
#include "KokkosFFT_Host_plans.hpp"
#include "KokkosFFT_Host_transform.hpp"
#else
#include "KokkosFFT_Host_plans.hpp"
#include "KokkosFFT_Host_transform.hpp"
#endif

namespace KokkosFFT {
Expand Down Expand Up @@ -135,12 +146,6 @@ class Plan {
extents_type m_in_extents, m_out_extents;
///@}

//! @{
//! Internal buffers used for transpose
nonConstInViewType m_in_T;
nonConstOutViewType m_out_T;
//! @}

//! Internal work buffer (for rocfft)
BufferViewType m_buffer;

Expand Down Expand Up @@ -270,6 +275,81 @@ class Plan {
Plan& operator=(Plan&&) = delete;
Plan(Plan&&) = delete;

/// \brief Execute FFT on input and output Views with normalization
///
/// \param in [in] Input data
/// \param out [out] Ouput data
template <typename InViewType2, typename OutViewType2>
void execute(
const InViewType2& in, const OutViewType2& out,
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward) {
static_assert(
KokkosFFT::Impl::are_operatable_views_v<execSpace, InViewType,
OutViewType>,
"Plan::execute: InViewType and OutViewType must have the same base "
"floating point "
"type (float/double), the same layout (LayoutLeft/LayoutRight), and "
"the "
"same rank. ExecutionSpace must be accessible to the data in "
"InViewType "
"and OutViewType.");

// sanity check that the plan is consistent with the input/output views
good(in, out);

using ManagableInViewType =
typename KokkosFFT::Impl::manageable_view_type<InViewType2>::type;
using ManagableOutViewType =
typename KokkosFFT::Impl::manageable_view_type<OutViewType2>::type;
ManagableInViewType in_s;
InViewType2 in_tmp;
if (m_is_crop_or_pad_needed) {
KokkosFFT::Impl::crop_or_pad(m_exec_space, in, in_s, m_shape);
in_tmp = in_s;
} else {
in_tmp = in;
}

if (m_is_transpose_needed) {
using LayoutType = typename ManagableInViewType::array_layout;
ManagableInViewType const in_T(
"in_T",
KokkosFFT::Impl::create_layout<LayoutType>(
KokkosFFT::Impl::compute_transpose_extents(in_tmp, m_map)));
ManagableOutViewType const out_T(
"out_T", KokkosFFT::Impl::create_layout<LayoutType>(
KokkosFFT::Impl::compute_transpose_extents(out, m_map)));

KokkosFFT::Impl::transpose(m_exec_space, in_tmp, in_T, m_map);
KokkosFFT::Impl::transpose(m_exec_space, out, out_T, m_map);

execute_fft(in_T, out_T, norm);

KokkosFFT::Impl::transpose(m_exec_space, out_T, out, m_map_inv);
} else {
execute_fft(in_tmp, out, norm);
}
}

private:
template <typename InViewType2, typename OutViewType2>
void execute_fft(const InViewType2& in, OutViewType2& out,
KokkosFFT::Normalization norm) {
using in_value_type = typename InViewType2::non_const_value_type;
using out_value_type = typename OutViewType2::non_const_value_type;

auto* idata = reinterpret_cast<typename KokkosFFT::Impl::fft_data_type<
execSpace, in_value_type>::type*>(in.data());
auto* odata = reinterpret_cast<typename KokkosFFT::Impl::fft_data_type<
execSpace, out_value_type>::type*>(out.data());

auto const direction =
KokkosFFT::Impl::direction_type<execSpace>(m_direction);
KokkosFFT::Impl::exec_plan(*m_plan, idata, odata, direction, m_info);
KokkosFFT::Impl::normalize(m_exec_space, out, m_direction, norm,
m_fft_size);
}

/// \brief Sanity check of the plan used to call FFT interface with
/// pre-defined FFT plan. This raises an error if there is an
/// incosistency between FFT function and plan
Expand Down Expand Up @@ -298,24 +378,6 @@ class Plan {
out_extents != m_out_extents,
"extents of output View for plan and execution are not identical.");
}

/// \brief Return the execution space
execSpace const& exec_space() const noexcept { return m_exec_space; }

/// \brief Return the FFT plan
fft_plan_type& plan() const { return *m_plan; }

/// \brief Return the FFT info
fft_info_type const& info() const { return m_info; }

/// \brief Return the FFT size
fft_size_type fft_size() const { return m_fft_size; }
KokkosFFT::Direction direction() const { return m_direction; }
bool is_transpose_needed() const { return m_is_transpose_needed; }
bool is_crop_or_pad_needed() const { return m_is_crop_or_pad_needed; }
extents_type shape() const { return m_shape; }
map_type map() const { return m_map; }
map_type map_inv() const { return m_map_inv; }
};
} // namespace KokkosFFT

Expand Down
125 changes: 7 additions & 118 deletions fft/src/KokkosFFT_Transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,123 +6,12 @@
#define KOKKOSFFT_TRANSFORM_HPP

#include <Kokkos_Core.hpp>
#include "KokkosFFT_default_types.hpp"
#include "KokkosFFT_traits.hpp"
#include "KokkosFFT_utils.hpp"
#include "KokkosFFT_normalization.hpp"
#include "KokkosFFT_transpose.hpp"
#include "KokkosFFT_padding.hpp"
#include "KokkosFFT_utils.hpp"
#include "KokkosFFT_Plans.hpp"

#if defined(KOKKOS_ENABLE_CUDA)
#include "KokkosFFT_Cuda_transform.hpp"
#ifdef ENABLE_HOST_AND_DEVICE
#include "KokkosFFT_Host_transform.hpp"
#endif
#elif defined(KOKKOS_ENABLE_HIP)
#if defined(KOKKOSFFT_ENABLE_TPL_ROCFFT)
#include "KokkosFFT_ROCM_transform.hpp"
#else
#include "KokkosFFT_HIP_transform.hpp"
#endif
#ifdef ENABLE_HOST_AND_DEVICE
#include "KokkosFFT_Host_transform.hpp"
#endif
#elif defined(KOKKOS_ENABLE_SYCL)
#include "KokkosFFT_SYCL_transform.hpp"
#ifdef ENABLE_HOST_AND_DEVICE
#include "KokkosFFT_Host_transform.hpp"
#endif
#elif defined(KOKKOS_ENABLE_OPENMP)
#include "KokkosFFT_Host_transform.hpp"
#elif defined(KOKKOS_ENABLE_THREADS)
#include "KokkosFFT_Host_transform.hpp"
#else
#include "KokkosFFT_Host_transform.hpp"
#endif

#include <type_traits>

// General Transform Interface
namespace KokkosFFT {
namespace Impl {

template <typename PlanType, typename InViewType, typename OutViewType>
void exec_impl(
const PlanType& plan, const InViewType& in, OutViewType& out,
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward) {
using in_value_type = typename InViewType::non_const_value_type;
using out_value_type = typename OutViewType::non_const_value_type;
using ExecutionSpace = typename PlanType::execSpace;

auto* idata = reinterpret_cast<typename KokkosFFT::Impl::fft_data_type<
ExecutionSpace, in_value_type>::type*>(in.data());
auto* odata = reinterpret_cast<typename KokkosFFT::Impl::fft_data_type<
ExecutionSpace, out_value_type>::type*>(out.data());

auto const exec_space = plan.exec_space();
auto const direction = direction_type<ExecutionSpace>(plan.direction());
KokkosFFT::Impl::exec_plan(plan.plan(), idata, odata, direction, plan.info());
KokkosFFT::Impl::normalize(exec_space, out, plan.direction(), norm,
plan.fft_size());
}

} // namespace Impl
} // namespace KokkosFFT

namespace KokkosFFT {
template <typename PlanType, typename InViewType, typename OutViewType>
void execute(
const PlanType& plan, const InViewType& in, OutViewType& out,
KokkosFFT::Normalization norm = KokkosFFT::Normalization::backward) {
using ExecutionSpace = typename PlanType::execSpace;
static_assert(
KokkosFFT::Impl::are_operatable_views_v<ExecutionSpace, InViewType,
OutViewType>,
"execute: InViewType and OutViewType must have the same base "
"floating point "
"type (float/double), the same layout (LayoutLeft/LayoutRight), and the "
"same rank. ExecutionSpace must be accessible to the data in InViewType "
"and OutViewType.");

plan.template good<InViewType, OutViewType>(in, out);

const auto exec_space = plan.exec_space();
using ManagableInViewType =
typename KokkosFFT::Impl::manageable_view_type<InViewType>::type;
using ManagableOutViewType =
typename KokkosFFT::Impl::manageable_view_type<OutViewType>::type;
ManagableInViewType _in_s;
InViewType _in;
if (plan.is_crop_or_pad_needed()) {
auto new_shape = plan.shape();
KokkosFFT::Impl::crop_or_pad(exec_space, in, _in_s, new_shape);
_in = _in_s;
} else {
_in = in;
}

if (plan.is_transpose_needed()) {
using LayoutType = typename ManagableInViewType::array_layout;
ManagableInViewType const in_T(
"in_T",
create_layout<LayoutType>(compute_transpose_extents(_in, plan.map())));
ManagableOutViewType const out_T(
"out_T",
create_layout<LayoutType>(compute_transpose_extents(out, plan.map())));

KokkosFFT::Impl::transpose(exec_space, _in, in_T, plan.map());
KokkosFFT::Impl::transpose(exec_space, out, out_T, plan.map());

KokkosFFT::Impl::exec_impl(plan, in_T, out_T, norm);

KokkosFFT::Impl::transpose(exec_space, out_T, out, plan.map_inv());

} else {
KokkosFFT::Impl::exec_impl(plan, _in, out, norm);
}
}

/// \brief One dimensional FFT in forward direction
///
/// \param exec_space [in] Kokkos execution space
Expand Down Expand Up @@ -150,7 +39,7 @@ void fft(const ExecutionSpace& exec_space, const InViewType& in,
"axes are invalid for in/out views");
KokkosFFT::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward, axis,
n);
KokkosFFT::execute(plan, in, out, norm);
plan.execute(in, out, norm);
}

/// \brief One dimensional FFT in backward direction
Expand Down Expand Up @@ -180,7 +69,7 @@ void ifft(const ExecutionSpace& exec_space, const InViewType& in,
"axes are invalid for in/out views");
KokkosFFT::Plan plan(exec_space, in, out, KokkosFFT::Direction::backward,
axis, n);
KokkosFFT::execute(plan, in, out, norm);
plan.execute(in, out, norm);
}

/// \brief One dimensional FFT for real input
Expand Down Expand Up @@ -368,7 +257,7 @@ void fft2(const ExecutionSpace& exec_space, const InViewType& in,
"axes are invalid for in/out views");
KokkosFFT::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward, axes,
s);
KokkosFFT::execute(plan, in, out, norm);
plan.execute(in, out, norm);
}

/// \brief Two dimensional FFT in backward direction
Expand Down Expand Up @@ -398,7 +287,7 @@ void ifft2(const ExecutionSpace& exec_space, const InViewType& in,
"axes are invalid for in/out views");
KokkosFFT::Plan plan(exec_space, in, out, KokkosFFT::Direction::backward,
axes, s);
KokkosFFT::execute(plan, in, out, norm);
plan.execute(in, out, norm);
}

/// \brief Two dimensional FFT for real input
Expand Down Expand Up @@ -513,7 +402,7 @@ void fftn(
"axes are invalid for in/out views");
KokkosFFT::Plan plan(exec_space, in, out, KokkosFFT::Direction::forward, axes,
s);
KokkosFFT::execute(plan, in, out, norm);
plan.execute(in, out, norm);
}

/// \brief N-dimensional FFT in backward direction with a given plan
Expand Down Expand Up @@ -556,7 +445,7 @@ void ifftn(
"axes are invalid for in/out views");
KokkosFFT::Plan plan(exec_space, in, out, KokkosFFT::Direction::backward,
axes, s);
KokkosFFT::execute(plan, in, out, norm);
plan.execute(in, out, norm);
}

/// \brief N-dimensional FFT for real input
Expand Down
Loading

0 comments on commit 19fd10d

Please sign in to comment.