Skip to content

Commit

Permalink
fix: normalization for C2R
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Oct 18, 2024
1 parent de41209 commit 48e8072
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
16 changes: 16 additions & 0 deletions fft/src/KokkosFFT_Plans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,22 @@ class Plan {
auto const direction =
KokkosFFT::Impl::direction_type<execSpace>(m_direction);
KokkosFFT::Impl::exec_plan(*m_plan, idata, odata, direction, m_info);

if constexpr (KokkosFFT::Impl::is_complex_v<in_value_type> &&
KokkosFFT::Impl::is_real_v<out_value_type>) {
if (m_is_inplace) {
// For the in-place Complex to Real transform, the output must be
// reshaped to fit the original size (in.size() * 2) for correct
// normalization
using UnmanagedOutViewType =
Kokkos::View<out_value_type*, execSpace,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
UnmanagedOutViewType out_tmp(out.data(), in.size() * 2);
KokkosFFT::Impl::normalize(m_exec_space, out_tmp, m_direction, norm,
m_fft_size);
return;
}
}
KokkosFFT::Impl::normalize(m_exec_space, out, m_direction, norm,
m_fft_size);
}
Expand Down
18 changes: 9 additions & 9 deletions fft/unit_test/Test_Transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1695,18 +1695,20 @@ void test_fft2_2dfft_2dview_inplace([[maybe_unused]] T atol = 1.0e-12) {
// Unmanged views for in-place transforms
RealUView2DType xr(reinterpret_cast<T*>(xr_hat.data()), n0, n1),
inv_xr_hat(reinterpret_cast<T*>(xr_hat.data()), n0, n1);
RealUView2DType xr_unpadded(reinterpret_cast<T*>(xr_hat.data()), n0,
(n1 / 2 + 1) * 2);
RealUView2DType xr_padded(reinterpret_cast<T*>(xr_hat.data()), n0,
(n1 / 2 + 1) * 2),
inv_xr_hat_padded(reinterpret_cast<T*>(xr_hat.data()), n0,
(n1 / 2 + 1) * 2);

// Initialize xr_hat through xr_unpadded
auto sub_xr_unpadded = Kokkos::subview(xr_unpadded, Kokkos::ALL(),
Kokkos::pair<int, int>(0, n1));
// Initialize xr_hat through xr_padded
auto sub_xr_padded =
Kokkos::subview(xr_padded, Kokkos::ALL(), Kokkos::pair<int, int>(0, n1));

const Kokkos::complex<T> z(1.0, 1.0);
Kokkos::Random_XorShift64_Pool<> random_pool(12345);
Kokkos::fill_random(xr_ref, random_pool, 1.0);
Kokkos::fill_random(x, random_pool, z);
Kokkos::deep_copy(sub_xr_unpadded, xr_ref);
Kokkos::deep_copy(sub_xr_padded, xr_ref);
Kokkos::deep_copy(x_ref, x);

using axes_type = KokkosFFT::axis_type<2>;
Expand Down Expand Up @@ -1753,8 +1755,6 @@ void test_fft2_2dfft_2dview_inplace([[maybe_unused]] T atol = 1.0e-12) {
KokkosFFT::irfft2(execution_space(), xr_hat_ref, inv_xr_hat_ref,
KokkosFFT::Normalization::backward, axes);

RealUView2DType inv_xr_hat_padded(reinterpret_cast<T*>(xr_hat.data()), n0,
(n1 / 2 + 1) * 2);
auto sub_inv_xr_hat_padded = Kokkos::subview(
inv_xr_hat_padded, Kokkos::ALL(), Kokkos::pair<int, int>(0, n1));
Kokkos::deep_copy(inv_xr_hat_unpadded, sub_inv_xr_hat_padded);
Expand Down Expand Up @@ -3535,4 +3535,4 @@ TYPED_TEST(FFTND, 3DFFT_batched_8DView) {

float_type atol = std::is_same_v<float_type, float> ? 1.0e-5 : 1.0e-10;
test_fftn_3dfft_8dview<float_type, layout_type>(atol);
}
}

0 comments on commit 48e8072

Please sign in to comment.