diff --git a/examples/09_derivative/09_derivative.cpp b/examples/09_derivative/09_derivative.cpp index d9da9104..8c40835a 100644 --- a/examples/09_derivative/09_derivative.cpp +++ b/examples/09_derivative/09_derivative.cpp @@ -105,7 +105,7 @@ void analytical_solution(RealView1DType& x, RealView1DType& y, auto h_x = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), x); auto h_y = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), y); - // Initialize field + // Compute the analytical solution on host const int nx = dudxy.extent(2), ny = dudxy.extent(1), nz = dudxy.extent(0); auto h_dudxy = Kokkos::create_mirror_view(dudxy); for (int iz = 0; iz < nz; iz++) { @@ -134,12 +134,6 @@ void compute_derivative(const int nx, const int ny, const int nz, using ComplexView2D = View2D>; using ComplexView3D = View3D>; - // KokkosFFT plan types - using ForwardPlanType = - KokkosFFT::Plan; - using BackwardPlanType = - KokkosFFT::Plan; - // Declare grids RealView1D x("x", nx), y("y", ny); ComplexView2D ikx("ikx", ny, nx / 2 + 1), iky("iky", ny, nx / 2 + 1); @@ -164,10 +158,11 @@ void compute_derivative(const int nx, const int ny, const int nz, point2D_type{{ny, nx / 2 + 1}}, tile2D_type{{TILE0, TILE1}}); - ForwardPlanType r2c_plan(exec, u, u_hat, KokkosFFT::Direction::forward, + // kokkos-fft plans + KokkosFFT::Plan r2c_plan(exec, u, u_hat, KokkosFFT::Direction::forward, + KokkosFFT::axis_type<2>({-2, -1})); + KokkosFFT::Plan c2r_plan(exec, u_hat, u, KokkosFFT::Direction::backward, KokkosFFT::axis_type<2>({-2, -1})); - BackwardPlanType c2r_plan(exec, u_hat, u, KokkosFFT::Direction::backward, - KokkosFFT::axis_type<2>({-2, -1})); // Start computation Kokkos::Timer timer;