diff --git a/fft/src/KokkosFFT_Plans.hpp b/fft/src/KokkosFFT_Plans.hpp index 55dd9333..0b9bbbfc 100644 --- a/fft/src/KokkosFFT_Plans.hpp +++ b/fft/src/KokkosFFT_Plans.hpp @@ -355,6 +355,22 @@ class Plan { auto const direction = KokkosFFT::Impl::direction_type(m_direction); KokkosFFT::Impl::exec_plan(*m_plan, idata, odata, direction, m_info); + + if constexpr (KokkosFFT::Impl::is_complex_v && + KokkosFFT::Impl::is_real_v) { + if (m_is_inplace) { + // For the in-place Complex to Real transform, the output must be + // reshaped to fit the original size (in.size() * 2) for correct + // normalization + using UnmanagedOutViewType = + Kokkos::View>; + UnmanagedOutViewType out_tmp(out.data(), in.size() * 2); + KokkosFFT::Impl::normalize(m_exec_space, out_tmp, m_direction, norm, + m_fft_size); + return; + } + } KokkosFFT::Impl::normalize(m_exec_space, out, m_direction, norm, m_fft_size); } diff --git a/fft/unit_test/Test_Transform.cpp b/fft/unit_test/Test_Transform.cpp index a10d3e3e..cedcf79f 100644 --- a/fft/unit_test/Test_Transform.cpp +++ b/fft/unit_test/Test_Transform.cpp @@ -1695,18 +1695,20 @@ void test_fft2_2dfft_2dview_inplace([[maybe_unused]] T atol = 1.0e-12) { // Unmanged views for in-place transforms RealUView2DType xr(reinterpret_cast(xr_hat.data()), n0, n1), inv_xr_hat(reinterpret_cast(xr_hat.data()), n0, n1); - RealUView2DType xr_unpadded(reinterpret_cast(xr_hat.data()), n0, - (n1 / 2 + 1) * 2); + RealUView2DType xr_padded(reinterpret_cast(xr_hat.data()), n0, + (n1 / 2 + 1) * 2), + inv_xr_hat_padded(reinterpret_cast(xr_hat.data()), n0, + (n1 / 2 + 1) * 2); - // Initialize xr_hat through xr_unpadded - auto sub_xr_unpadded = Kokkos::subview(xr_unpadded, Kokkos::ALL(), - Kokkos::pair(0, n1)); + // Initialize xr_hat through xr_padded + auto sub_xr_padded = + Kokkos::subview(xr_padded, Kokkos::ALL(), Kokkos::pair(0, n1)); const Kokkos::complex z(1.0, 1.0); Kokkos::Random_XorShift64_Pool<> random_pool(12345); Kokkos::fill_random(xr_ref, random_pool, 1.0); Kokkos::fill_random(x, random_pool, z); - Kokkos::deep_copy(sub_xr_unpadded, xr_ref); + Kokkos::deep_copy(sub_xr_padded, xr_ref); Kokkos::deep_copy(x_ref, x); using axes_type = KokkosFFT::axis_type<2>; @@ -1753,8 +1755,6 @@ void test_fft2_2dfft_2dview_inplace([[maybe_unused]] T atol = 1.0e-12) { KokkosFFT::irfft2(execution_space(), xr_hat_ref, inv_xr_hat_ref, KokkosFFT::Normalization::backward, axes); - RealUView2DType inv_xr_hat_padded(reinterpret_cast(xr_hat.data()), n0, - (n1 / 2 + 1) * 2); auto sub_inv_xr_hat_padded = Kokkos::subview( inv_xr_hat_padded, Kokkos::ALL(), Kokkos::pair(0, n1)); Kokkos::deep_copy(inv_xr_hat_unpadded, sub_inv_xr_hat_padded); @@ -3535,4 +3535,4 @@ TYPED_TEST(FFTND, 3DFFT_batched_8DView) { float_type atol = std::is_same_v ? 1.0e-5 : 1.0e-10; test_fftn_3dfft_8dview(atol); -} \ No newline at end of file +}