Skip to content

Commit

Permalink
Make FFT safe for slabs (#4268)
Browse files Browse the repository at this point in the history
Support FFT on domains that have one cell in some dimensions.

It also supports Poisson solves on slab domains. However, for
FFT::PoissonHybrid that treats the z-direction in a special way, the
z-direction must have more than one cell.
  • Loading branch information
WeiqunZhang authored Dec 16, 2024
1 parent bdb4be3 commit b3f6738
Show file tree
Hide file tree
Showing 7 changed files with 926 additions and 155 deletions.
2 changes: 2 additions & 0 deletions Src/Base/AMReX_Periodicity.H
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ public:
//! Cell-centered domain Box "infinitely" long in non-periodic directions.
[[nodiscard]] Box Domain () const noexcept;

[[nodiscard]] IntVect const& intVect () const { return period; }

[[nodiscard]] std::vector<IntVect> shiftIntVect (IntVect const& nghost = IntVect(0)) const;

static const Periodicity& NonPeriodic () noexcept;
Expand Down
231 changes: 231 additions & 0 deletions Src/FFT/AMReX_FFT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,235 @@ void hip_execute (rocfft_plan plan, void **in, void **out)
}
#endif

SubHelper::SubHelper (Box const& domain)
{
#if (AMREX_SPACEDIM == 1)
amrex::ignore_unused(domain);
#elif (AMREX_SPACEDIM == 2)
if (domain.length(0) == 1) {
m_case = case_1n;
}
#else
if (domain.length(0) == 1 && domain.length(1) == 1) {
m_case = case_11n;
} else if (domain.length(0) == 1 && domain.length(2) == 1) {
m_case = case_1n1;
} else if (domain.length(0) == 1) {
m_case = case_1nn;
} else if (domain.length(1) == 1) {
m_case = case_n1n;
}
#endif
}

Box SubHelper::make_box (Box const& box) const
{
return Box(make_iv(box.smallEnd()), make_iv(box.bigEnd()), box.ixType());
}

Periodicity SubHelper::make_periodicity (Periodicity const& period) const
{
return Periodicity(make_iv(period.intVect()));
}

bool SubHelper::ghost_safe (IntVect const& ng) const
{
#if (AMREX_SPACEDIM == 1)
amrex::ignore_unused(ng,this);
return true;
#elif (AMREX_SPACEDIM == 2)
if (m_case == case_1n) {
return (ng[0] == 0);
} else {
return true;
}
#else
if (m_case == case_11n) {
return (ng[0] == 0) && (ng[1] == 0);
} else if (m_case == case_1n1) {
return (ng[0] == 0);
} else if (m_case == case_1nn) {
return (ng[0] == 0);
} else if (m_case == case_n1n) {
return (ng[1] == 0);
} else {
return true;
}
#endif
}

IntVect SubHelper::make_iv (IntVect const& iv) const
{
return this->make_array(iv);
}

IntVect SubHelper::make_safe_ghost (IntVect const& ng) const
{
#if (AMREX_SPACEDIM == 1)
amrex::ignore_unused(this);
return ng;
#elif (AMREX_SPACEDIM == 2)
if (m_case == case_1n) {
return IntVect{0,ng[1]};
} else {
return ng;
}
#else
if (m_case == case_11n) {
return IntVect{0,0,ng[2]};
} else if (m_case == case_1n1) {
return IntVect{0,ng[1],ng[2]};
} else if (m_case == case_1nn) {
return IntVect{0,ng[1],ng[2]};
} else if (m_case == case_n1n) {
return IntVect{ng[0],0,ng[2]};
} else {
return ng;
}
#endif
}

BoxArray SubHelper::inverse_boxarray (BoxArray const& ba) const
{ // sub domain order -> original domain order
#if (AMREX_SPACEDIM == 1)
amrex::ignore_unused(this);
return ba;
#elif (AMREX_SPACEDIM == 2)
AMREX_ALWAYS_ASSERT(m_case == case_1n);
BoxList bl = ba.boxList();
// sub domain order: y, x
for (auto& b : bl) {
auto const& lo = b.smallEnd();
auto const& hi = b.bigEnd();
b.setSmall(IntVect(lo[1],lo[0]));
b.setBig (IntVect(hi[1],hi[0]));
}
return BoxArray(std::move(bl));
#else
BoxList bl = ba.boxList();
if (m_case == case_11n) {
// sub domain order: z, x, y
for (auto& b : bl) {
auto const& lo = b.smallEnd();
auto const& hi = b.bigEnd();
b.setSmall(IntVect(lo[1],lo[2],lo[0]));
b.setBig (IntVect(hi[1],hi[2],hi[0]));
}
} else if (m_case == case_1n1) {
// sub domain order: y, x, z
for (auto& b : bl) {
auto const& lo = b.smallEnd();
auto const& hi = b.bigEnd();
b.setSmall(IntVect(lo[1],lo[0],lo[2]));
b.setBig (IntVect(hi[1],hi[0],hi[2]));
}
} else if (m_case == case_1nn) {
// sub domain order: y, z, x
for (auto& b : bl) {
auto const& lo = b.smallEnd();
auto const& hi = b.bigEnd();
b.setSmall(IntVect(lo[2],lo[0],lo[1]));
b.setBig (IntVect(hi[2],hi[0],hi[1]));
}
} else if (m_case == case_n1n) {
// sub domain order: x, z, y
for (auto& b : bl) {
auto const& lo = b.smallEnd();
auto const& hi = b.bigEnd();
b.setSmall(IntVect(lo[0],lo[2],lo[1]));
b.setBig (IntVect(hi[0],hi[2],hi[1]));
}
} else {
amrex::Abort("SubHelper::inverse_boxarray: how did this happen?");
}
return BoxArray(std::move(bl));
#endif
}

IntVect SubHelper::inverse_order (IntVect const& order) const
{
#if (AMREX_SPACEDIM == 1)
amrex::ignore_unused(this);
return order;
#elif (AMREX_SPACEDIM == 2)
amrex::ignore_unused(this);
return IntVect(order[1],order[0]);
#else
auto translate = [&] (int index) -> int
{
int r = index;
if (m_case == case_11n) {
// sub domain order: z, x, y
if (index == 0) {
r = 2;
} else if (index == 1) {
r = 0;
} else {
r = 1;
}
} else if (m_case == case_1n1) {
// sub domain order: y, x, z
if (index == 0) {
r = 1;
} else if (index == 1) {
r = 0;
} else {
r = 2;
}
} else if (m_case == case_1nn) {
// sub domain order: y, z, x
if (index == 0) {
r = 1;
} else if (index == 1) {
r = 2;
} else {
r = 0;
}
} else if (m_case == case_n1n) {
// sub domain order: x, z, y
if (index == 0) {
r = 0;
} else if (index == 1) {
r = 2;
} else {
r = 1;
}
}
return r;
};

IntVect iv;
for (int idim = 0; idim < AMREX_SPACEDIM; ++idim) {
iv[idim] = translate(order[idim]);
}
return iv;
#endif
}

GpuArray<int,3> SubHelper::xyz_order () const
{
#if (AMREX_SPACEDIM == 1)
amrex::ignore_unused(this);
return GpuArray<int,3>{0,1,2};
#elif (AMREX_SPACEDIM == 2)
if (m_case == case_1n) {
return GpuArray<int,3>{1,0,2};
} else {
return GpuArray<int,3>{0,1,2};
}
#else
if (m_case == case_11n) {
return GpuArray<int,3>{1,2,0};
} else if (m_case == case_1n1) {
return GpuArray<int,3>{1,0,2};
} else if (m_case == case_1nn) {
return GpuArray<int,3>{2,0,1};
} else if (m_case == case_n1n) {
return GpuArray<int,3>{0,2,1};
} else {
return GpuArray<int,3>{0,1,2};
}
#endif
}

}
79 changes: 79 additions & 0 deletions Src/FFT/AMReX_FFT_Helper.H
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
#include <AMReX_DataAllocator.H>
#include <AMReX_DistributionMapping.H>
#include <AMReX_Enum.H>
#include <AMReX_FabArray.H>
#include <AMReX_Gpu.H>
#include <AMReX_GpuComplex.H>
#include <AMReX_Math.H>
#include <AMReX_Periodicity.H>

#if defined(AMREX_USE_CUDA)
# include <cufft.h>
Expand Down Expand Up @@ -1447,6 +1449,83 @@ struct RotateBwd
}
};

namespace detail
{
struct SubHelper
{
explicit SubHelper (Box const& domain);

[[nodiscard]] Box make_box (Box const& box) const;

[[nodiscard]] Periodicity make_periodicity (Periodicity const& period) const;

[[nodiscard]] bool ghost_safe (IntVect const& ng) const;

// This rearranges the order.
[[nodiscard]] IntVect make_iv (IntVect const& iv) const;

// This keeps the order, but zero out the values in the hidden dimension.
[[nodiscard]] IntVect make_safe_ghost (IntVect const& ng) const;

[[nodiscard]] BoxArray inverse_boxarray (BoxArray const& ba) const;

[[nodiscard]] IntVect inverse_order (IntVect const& order) const;

template <typename T>
[[nodiscard]] T make_array (T const& a) const
{
#if (AMREX_SPACEDIM == 1)
amrex::ignore_unused(this);
return a;
#elif (AMREX_SPACEDIM == 2)
if (m_case == case_1n) {
return T{a[1],a[0]};
} else {
return a;
}
#else
if (m_case == case_11n) {
return T{a[2],a[0],a[1]};
} else if (m_case == case_1n1) {
return T{a[1],a[0],a[2]};
} else if (m_case == case_1nn) {
return T{a[1],a[2],a[0]};
} else if (m_case == case_n1n) {
return T{a[0],a[2],a[1]};
} else {
return a;
}
#endif
}

[[nodiscard]] GpuArray<int,3> xyz_order () const;

template <typename FA>
FA make_alias_mf (FA const& mf)
{
BoxList bl = mf.boxArray().boxList();
for (auto& b : bl) {
b = make_box(b);
}
auto const& ng = make_iv(mf.nGrowVect());
FA submf(BoxArray(std::move(bl)), mf.DistributionMap(), 1, ng, MFInfo{}.SetAlloc(false));
using FAB = typename FA::fab_type;
for (MFIter mfi(submf, MFItInfo().DisableDeviceSync()); mfi.isValid(); ++mfi) {
submf.setFab(mfi, FAB(mfi.fabbox(), 1, mf[mfi].dataPtr()));
}
return submf;
}

#if (AMREX_SPACEDIM == 2)
enum Case { case_1n, case_other };
int m_case = case_other;
#elif (AMREX_SPACEDIM == 3)
enum Case { case_11n, case_1n1, case_1nn, case_n1n, case_other };
int m_case = case_other;
#endif
};
}

}

#endif
Loading

0 comments on commit b3f6738

Please sign in to comment.