Skip to content

Commit

Permalink
Merge pull request #2916 from boutproject/laplace-flag-check-methods
Browse files Browse the repository at this point in the history
Add getters for Laplacian flags
  • Loading branch information
bendudson authored Jun 26, 2024
2 parents ba8dd2c + 505f2b3 commit 71d7858
Show file tree
Hide file tree
Showing 20 changed files with 824 additions and 1,068 deletions.
38 changes: 29 additions & 9 deletions include/bout/invert_laplace.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,10 @@ public:
virtual void setInnerBoundaryFlags(int f) { inner_boundary_flags = f; }
virtual void setOuterBoundaryFlags(int f) { outer_boundary_flags = f; }

virtual int getGlobalFlags() const { return global_flags; }
virtual int getInnerBoundaryFlags() const { return inner_boundary_flags; }
virtual int getOuterBoundaryFlags() const { return outer_boundary_flags; }

/// Does this solver use Field3D coefficients (true) or only their DC component (false)
virtual bool uses3DCoefs() const { return false; }

Expand Down Expand Up @@ -308,9 +312,23 @@ protected:
int extra_yguards_lower; ///< exclude some number of points at the lower boundary, useful for staggered grids or when boundary conditions make inversion redundant
int extra_yguards_upper; ///< exclude some number of points at the upper boundary, useful for staggered grids or when boundary conditions make inversion redundant

int global_flags; ///< Default flags
int inner_boundary_flags; ///< Flags to set inner boundary condition
int outer_boundary_flags; ///< Flags to set outer boundary condition
/// Return true if global/default \p flag is set
bool isGlobalFlagSet(int flag) const { return (global_flags & flag) != 0; }
/// Return true if \p flag is set for the inner boundary condition
bool isInnerBoundaryFlagSet(int flag) const {
return (inner_boundary_flags & flag) != 0;
}
/// Return true if \p flag is set for the outer boundary condition
bool isOuterBoundaryFlagSet(int flag) const {
return (outer_boundary_flags & flag) != 0;
}

/// Return true if \p flag is set for the inner boundary condition
/// and this is the first proc in X direction
bool isInnerBoundaryFlagSetOnFirstX(int flag) const;
/// Return true if \p flag is set for the outer boundary condition
/// and this the last proc in X direction
bool isOuterBoundaryFlagSetOnLastX(int flag) const;

void tridagCoefs(int jx, int jy, BoutReal kwave, dcomplex& a, dcomplex& b, dcomplex& c,
const Field2D* ccoef = nullptr, const Field2D* d = nullptr,
Expand All @@ -322,15 +340,13 @@ protected:
CELL_LOC loc = CELL_DEFAULT);

void tridagMatrix(dcomplex* avec, dcomplex* bvec, dcomplex* cvec, dcomplex* bk, int jy,
int kz, BoutReal kwave, int flags, int inner_boundary_flags,
int outer_boundary_flags, const Field2D* a, const Field2D* ccoef,
int kz, BoutReal kwave, const Field2D* a, const Field2D* ccoef,
const Field2D* d, bool includeguards = true, bool zperiodic = true) {
tridagMatrix(avec, bvec, cvec, bk, jy, kz, kwave, flags, inner_boundary_flags,
outer_boundary_flags, a, ccoef, ccoef, d, includeguards, zperiodic);
tridagMatrix(avec, bvec, cvec, bk, jy, kz, kwave, a, ccoef, ccoef, d, includeguards,
zperiodic);
}
void tridagMatrix(dcomplex* avec, dcomplex* bvec, dcomplex* cvec, dcomplex* bk, int jy,
int kz, BoutReal kwave, int flags, int inner_boundary_flags,
int outer_boundary_flags, const Field2D* a, const Field2D* c1coef,
int kz, BoutReal kwave, const Field2D* a, const Field2D* c1coef,
const Field2D* c2coef, const Field2D* d, bool includeguards = true,
bool zperiodic = true);
CELL_LOC location; ///< staggered grid location of this solver
Expand All @@ -339,6 +355,10 @@ protected:
/// localmesh->getCoordinates(location) once

private:
int global_flags; ///< Default flags
int inner_boundary_flags; ///< Flags to set inner boundary condition
int outer_boundary_flags; ///< Flags to set outer boundary condition

/// Singleton instance
static std::unique_ptr<Laplacian> instance;
/// Name for writing performance infomation; default taken from
Expand Down
65 changes: 32 additions & 33 deletions src/invert/laplace/impls/cyclic/cyclic_laplace.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@
*
*/

#include "cyclic_laplace.hxx"
#include "bout/build_config.hxx"
#include "bout/build_defines.hxx"

#if not BOUT_USE_METRIC_3D

#include "cyclic_laplace.hxx"
#include "bout/assert.hxx"
#include "bout/bout_types.hxx"
#include <bout/boutexception.hxx>
#include <bout/constants.hxx>
#include <bout/fft.hxx>
Expand All @@ -47,7 +49,7 @@
#include <bout/sys/timer.hxx>
#include <bout/utils.hxx>

#include "cyclic_laplace.hxx"
#include <vector>

LaplaceCyclic::LaplaceCyclic(Options* opt, const CELL_LOC loc, Mesh* mesh_in,
Solver* UNUSED(solver))
Expand Down Expand Up @@ -120,13 +122,13 @@ FieldPerp LaplaceCyclic::solve(const FieldPerp& rhs, const FieldPerp& x0) {

// If the flags to assign that only one guard cell should be used is set
int inbndry = localmesh->xstart, outbndry = localmesh->xstart;
if (((global_flags & INVERT_BOTH_BNDRY_ONE) != 0) || (localmesh->xstart < 2)) {
if (isGlobalFlagSet(INVERT_BOTH_BNDRY_ONE) || (localmesh->xstart < 2)) {
inbndry = outbndry = 1;
}
if ((inner_boundary_flags & INVERT_BNDRY_ONE) != 0) {
if (isInnerBoundaryFlagSet(INVERT_BNDRY_ONE)) {
inbndry = 1;
}
if ((outer_boundary_flags & INVERT_BNDRY_ONE) != 0) {
if (isOuterBoundaryFlagSet(INVERT_BNDRY_ONE)) {
outbndry = 1;
}

Expand All @@ -143,9 +145,9 @@ FieldPerp LaplaceCyclic::solve(const FieldPerp& rhs, const FieldPerp& x0) {
for (int ix = xs; ix <= xe; ix++) {
// Take DST in Z direction and put result in k1d

if (((ix < inbndry) && (inner_boundary_flags & INVERT_SET) && localmesh->firstX())
if (((ix < inbndry) && isInnerBoundaryFlagSetOnFirstX(INVERT_SET))
|| ((localmesh->LocalNx - ix - 1 < outbndry)
&& (outer_boundary_flags & INVERT_SET) && localmesh->lastX())) {
&& isOuterBoundaryFlagSetOnLastX(INVERT_SET))) {
// Use the values in x0 in the boundary
DST(x0[ix] + 1, localmesh->LocalNz - 2, std::begin(k1d));
} else {
Expand All @@ -169,8 +171,7 @@ FieldPerp LaplaceCyclic::solve(const FieldPerp& rhs, const FieldPerp& x0) {
tridagMatrix(&a(kz, 0), &b(kz, 0), &c(kz, 0), &bcmplx(kz, 0), jy,
kz, // wave number index
kwave, // kwave (inverse wave length)
global_flags, inner_boundary_flags, outer_boundary_flags, &Acoef,
&C1coef, &C2coef, &Dcoef,
&Acoef, &C1coef, &C2coef, &Dcoef,
false, // Don't include guard cells in arrays
false); // Z domain not periodic
}
Expand Down Expand Up @@ -218,9 +219,9 @@ FieldPerp LaplaceCyclic::solve(const FieldPerp& rhs, const FieldPerp& x0) {
for (int ix = xs; ix <= xe; ix++) {
// Take FFT in Z direction, apply shift, and put result in k1d

if (((ix < inbndry) && (inner_boundary_flags & INVERT_SET) && localmesh->firstX())
if (((ix < inbndry) && isInnerBoundaryFlagSetOnFirstX(INVERT_SET))
|| ((localmesh->LocalNx - ix - 1 < outbndry)
&& (outer_boundary_flags & INVERT_SET) && localmesh->lastX())) {
&& isOuterBoundaryFlagSetOnLastX(INVERT_SET))) {
// Use the values in x0 in the boundary
rfft(x0[ix], localmesh->LocalNz, std::begin(k1d));
} else {
Expand All @@ -241,8 +242,7 @@ FieldPerp LaplaceCyclic::solve(const FieldPerp& rhs, const FieldPerp& x0) {
tridagMatrix(&a(kz, 0), &b(kz, 0), &c(kz, 0), &bcmplx(kz, 0), jy,
kz, // True for the component constant (DC) in Z
kwave, // Z wave number
global_flags, inner_boundary_flags, outer_boundary_flags, &Acoef,
&C1coef, &C2coef, &Dcoef,
&Acoef, &C1coef, &C2coef, &Dcoef,
false); // Don't include guard cells in arrays
}
}
Expand Down Expand Up @@ -275,7 +275,7 @@ FieldPerp LaplaceCyclic::solve(const FieldPerp& rhs, const FieldPerp& x0) {
// ZFFT routine expects input of this length
auto k1d = Array<dcomplex>((localmesh->LocalNz) / 2 + 1);

const bool zero_DC = (global_flags & INVERT_ZERO_DC) != 0;
const bool zero_DC = isGlobalFlagSet(INVERT_ZERO_DC);

BOUT_OMP_PERF(for nowait)
for (int ix = xs; ix <= xe; ix++) {
Expand Down Expand Up @@ -316,13 +316,13 @@ Field3D LaplaceCyclic::solve(const Field3D& rhs, const Field3D& x0) {

// If the flags to assign that only one guard cell should be used is set
int inbndry = localmesh->xstart, outbndry = localmesh->xstart;
if (((global_flags & INVERT_BOTH_BNDRY_ONE) != 0) || (localmesh->xstart < 2)) {
if (isGlobalFlagSet(INVERT_BOTH_BNDRY_ONE) || (localmesh->xstart < 2)) {
inbndry = outbndry = 1;
}
if ((inner_boundary_flags & INVERT_BNDRY_ONE) != 0) {
if (isInnerBoundaryFlagSet(INVERT_BNDRY_ONE)) {
inbndry = 1;
}
if ((outer_boundary_flags & INVERT_BNDRY_ONE) != 0) {
if (isOuterBoundaryFlagSet(INVERT_BNDRY_ONE)) {
outbndry = 1;
}

Expand Down Expand Up @@ -350,6 +350,9 @@ Field3D LaplaceCyclic::solve(const Field3D& rhs, const Field3D& x0) {
const int nsys = nmode * ny; // Number of systems of equations to solve
const int nxny = nx * ny; // Number of points in X-Y

// This is just to silence static analysis
ASSERT0(ny > 0);

auto a3D = Matrix<dcomplex>(nsys, nx);
auto b3D = Matrix<dcomplex>(nsys, nx);
auto c3D = Matrix<dcomplex>(nsys, nx);
Expand All @@ -374,10 +377,9 @@ Field3D LaplaceCyclic::solve(const Field3D& rhs, const Field3D& x0) {

// Take DST in Z direction and put result in k1d

if (((ix < inbndry) && ((inner_boundary_flags & INVERT_SET) != 0)
&& localmesh->firstX())
if (((ix < inbndry) && isInnerBoundaryFlagSetOnFirstX(INVERT_SET))
|| ((localmesh->LocalNx - ix - 1 < outbndry)
&& ((outer_boundary_flags & INVERT_SET) != 0) && localmesh->lastX())) {
&& isOuterBoundaryFlagSetOnLastX(INVERT_SET))) {
// Use the values in x0 in the boundary
DST(x0(ix, iy) + 1, localmesh->LocalNz - 2, std::begin(k1d));
} else {
Expand Down Expand Up @@ -405,8 +407,7 @@ Field3D LaplaceCyclic::solve(const Field3D& rhs, const Field3D& x0) {
tridagMatrix(&a3D(ind, 0), &b3D(ind, 0), &c3D(ind, 0), &bcmplx3D(ind, 0), iy,
kz, // wave number index
kwave, // kwave (inverse wave length)
global_flags, inner_boundary_flags, outer_boundary_flags, &Acoef,
&C1coef, &C2coef, &Dcoef,
&Acoef, &C1coef, &C2coef, &Dcoef,
false, // Don't include guard cells in arrays
false); // Z domain not periodic
}
Expand Down Expand Up @@ -462,10 +463,9 @@ Field3D LaplaceCyclic::solve(const Field3D& rhs, const Field3D& x0) {

// Take FFT in Z direction, apply shift, and put result in k1d

if (((ix < inbndry) && ((inner_boundary_flags & INVERT_SET) != 0)
&& localmesh->firstX())
if (((ix < inbndry) && isInnerBoundaryFlagSetOnFirstX(INVERT_SET))
|| ((localmesh->LocalNx - ix - 1 < outbndry)
&& ((outer_boundary_flags & INVERT_SET) != 0) && localmesh->lastX())) {
&& isOuterBoundaryFlagSetOnLastX(INVERT_SET))) {
// Use the values in x0 in the boundary
rfft(x0(ix, iy), localmesh->LocalNz, std::begin(k1d));
} else {
Expand All @@ -490,8 +490,7 @@ Field3D LaplaceCyclic::solve(const Field3D& rhs, const Field3D& x0) {
tridagMatrix(&a3D(ind, 0), &b3D(ind, 0), &c3D(ind, 0), &bcmplx3D(ind, 0), iy,
kz, // True for the component constant (DC) in Z
kwave, // Z wave number
global_flags, inner_boundary_flags, outer_boundary_flags, &Acoef,
&C1coef, &C2coef, &Dcoef,
&Acoef, &C1coef, &C2coef, &Dcoef,
false); // Don't include guard cells in arrays
}
}
Expand All @@ -502,18 +501,18 @@ Field3D LaplaceCyclic::solve(const Field3D& rhs, const Field3D& x0) {

if (localmesh->periodicX) {
// Subtract X average of kz=0 mode
BoutReal local[ny + 1];
std::vector<BoutReal> local(ny + 1, 0.0);
for (int y = 0; y < ny; y++) {
local[y] = 0.0;
for (int ix = xs; ix <= xe; ix++) {
local[y] += xcmplx3D(y * nmode, ix - xs).real();
}
}
local[ny] = static_cast<BoutReal>(xe - xs + 1);

// Global reduce
BoutReal global[ny + 1];
MPI_Allreduce(local, global, ny + 1, MPI_DOUBLE, MPI_SUM, localmesh->getXcomm());
std::vector<BoutReal> global(ny + 1, 0.0);
MPI_Allreduce(local.data(), global.data(), ny + 1, MPI_DOUBLE, MPI_SUM,
localmesh->getXcomm());
// Subtract average from kz=0 modes
for (int y = 0; y < ny; y++) {
BoutReal avg = global[y] / global[ny];
Expand All @@ -530,7 +529,7 @@ Field3D LaplaceCyclic::solve(const Field3D& rhs, const Field3D& x0) {
auto k1d = Array<dcomplex>((localmesh->LocalNz) / 2
+ 1); // ZFFT routine expects input of this length

const bool zero_DC = (global_flags & INVERT_ZERO_DC) != 0;
const bool zero_DC = isGlobalFlagSet(INVERT_ZERO_DC);

BOUT_OMP_PERF(for nowait)
for (int ind = 0; ind < nxny; ++ind) { // Loop over X and Y
Expand Down
12 changes: 6 additions & 6 deletions src/invert/laplace/impls/hypre3d/hypre3d_laplace.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ LaplaceHypre3d::LaplaceHypre3d(Options* opt, const CELL_LOC loc, Mesh* mesh_in,

// Set up boundary conditions in operator
BOUT_FOR_SERIAL(i, indexer->getRegionInnerX()) {
if (inner_boundary_flags & INVERT_AC_GRAD) {
if (isInnerBoundaryFlagSet(INVERT_AC_GRAD)) {
// Neumann on inner X boundary
operator3D(i, i) = -1. / coords->dx[i] / sqrt(coords->g_11[i]);
operator3D(i, i.xp()) = 1. / coords->dx[i] / sqrt(coords->g_11[i]);
Expand All @@ -111,7 +111,7 @@ LaplaceHypre3d::LaplaceHypre3d(Options* opt, const CELL_LOC loc, Mesh* mesh_in,
}

BOUT_FOR_SERIAL(i, indexer->getRegionOuterX()) {
if (outer_boundary_flags & INVERT_AC_GRAD) {
if (isOuterBoundaryFlagSet(INVERT_AC_GRAD)) {
// Neumann on outer X boundary
operator3D(i, i) = 1. / coords->dx[i] / sqrt(coords->g_11[i]);
operator3D(i, i.xm()) = -1. / coords->dx[i] / sqrt(coords->g_11[i]);
Expand Down Expand Up @@ -180,19 +180,19 @@ Field3D LaplaceHypre3d::solve(const Field3D& b_in, const Field3D& x0) {
// Adjust vectors to represent boundary conditions and check that
// boundary cells are finite
BOUT_FOR_SERIAL(i, indexer->getRegionInnerX()) {
const BoutReal val = (inner_boundary_flags & INVERT_SET) ? x0[i] : 0.;
const BoutReal val = isInnerBoundaryFlagSet(INVERT_SET) ? x0[i] : 0.;
ASSERT1(std::isfinite(val));
if (!(inner_boundary_flags & INVERT_RHS)) {
if (!(isInnerBoundaryFlagSet(INVERT_RHS))) {
b[i] = val;
} else {
ASSERT1(std::isfinite(b[i]));
}
}

BOUT_FOR_SERIAL(i, indexer->getRegionOuterX()) {
const BoutReal val = (outer_boundary_flags & INVERT_SET) ? x0[i] : 0.;
const BoutReal val = (isOuterBoundaryFlagSet(INVERT_SET)) ? x0[i] : 0.;
ASSERT1(std::isfinite(val));
if (!(outer_boundary_flags & INVERT_RHS)) {
if (!(isOuterBoundaryFlagSet(INVERT_RHS))) {
b[i] = val;
} else {
ASSERT1(std::isfinite(b[i]));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,8 @@ FieldPerp LaplaceIPT::solve(const FieldPerp& b, const FieldPerp& x0) {
*/
auto bcmplx = Matrix<dcomplex>(nmode, ncx);

const bool invert_inner_boundary =
isInnerBoundaryFlagSet(INVERT_SET) and localmesh->firstX();
const bool invert_outer_boundary =
isOuterBoundaryFlagSet(INVERT_SET) and localmesh->lastX();
const bool invert_inner_boundary = isInnerBoundaryFlagSetOnFirstX(INVERT_SET);
const bool invert_outer_boundary = isOuterBoundaryFlagSetOnLastX(INVERT_SET);

BOUT_OMP_PERF(parallel for)
for (int ix = 0; ix < ncx; ix++) {
Expand Down Expand Up @@ -345,8 +343,7 @@ FieldPerp LaplaceIPT::solve(const FieldPerp& b, const FieldPerp& x0) {
kz,
// wave number (different from kz only if we are taking a part
// of the z-domain [and not from 0 to 2*pi])
kz * kwaveFactor, global_flags, inner_boundary_flags,
outer_boundary_flags, &A, &C, &D);
kz * kwaveFactor, &A, &C, &D);

// Patch up internal boundaries
if (not localmesh->lastX()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,14 +234,6 @@ private:

/// First and last interior points xstart, xend
int xs, xe;

bool isGlobalFlagSet(int flag) const { return (global_flags & flag) != 0; }
bool isInnerBoundaryFlagSet(int flag) const {
return (inner_boundary_flags & flag) != 0;
}
bool isOuterBoundaryFlagSet(int flag) const {
return (outer_boundary_flags & flag) != 0;
}
};

#endif // BOUT_USE_METRIC_3D
Expand Down
Loading

0 comments on commit 71d7858

Please sign in to comment.