Skip to content

Commit

Permalink
Merge pull request #2932 from boutproject/cvode-linear-solve
Browse files Browse the repository at this point in the history
CVODE solver: Pass linear flag to rhs()
  • Loading branch information
bendudson authored Jul 29, 2024
2 parents 20b47f4 + f009116 commit 78a8149
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 12 deletions.
2 changes: 2 additions & 0 deletions include/bout/solver.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,8 @@ protected:
bool has_constraints{false};
/// Has init been called yet?
bool initialised{false};
/// If calling user RHS for the first time
bool first_rhs_call{true};

/// Current simulation time
BoutReal simtime{0.0};
Expand Down
47 changes: 38 additions & 9 deletions src/solver/impls/cvode/cvode.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
*
*
**************************************************************************
* Copyright 2010 B.D.Dudson, S.Farley, M.V.Umansky, X.Q.Xu
* Copyright 2010-2024 BOUT++ contributors
*
* Contact: Ben Dudson, [email protected]
* Contact: Ben Dudson, [email protected]
*
* This file is part of BOUT++.
*
Expand Down Expand Up @@ -59,7 +59,8 @@ BOUT_ENUM_CLASS(positivity_constraint, none, positive, non_negative, negative,

// NOLINTBEGIN(readability-identifier-length)
namespace {
int cvode_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data);
int cvode_linear_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data);
int cvode_nonlinear_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data);
int cvode_bbd_rhs(sunindextype Nlocal, BoutReal t, N_Vector u, N_Vector du,
void* user_data);

Expand Down Expand Up @@ -216,9 +217,16 @@ int CvodeSolver::init() {
throw BoutException("CVodeSetUserData failed\n");
}

if (CVodeInit(cvode_mem, cvode_rhs, simtime, uvec) != CV_SUCCESS) {
#if SUNDIALS_VERSION_MAJOR >= 6
// Set the default RHS to linear, then pass nonlinear rhs to NL solver
if (CVodeInit(cvode_mem, cvode_linear_rhs, simtime, uvec) != CV_SUCCESS) {
throw BoutException("CVodeInit failed\n");
}
#else
if (CVodeInit(cvode_mem, cvode_nonlinear_rhs, simtime, uvec) != CV_SUCCESS) {
throw BoutException("CVodeInit failed\n");
}
#endif

if (max_order > 0) {
if (CVodeSetMaxOrd(cvode_mem, max_order) != CV_SUCCESS) {
Expand Down Expand Up @@ -385,6 +393,11 @@ int CvodeSolver::init() {
}
}

#if SUNDIALS_VERSION_MAJOR >= 6
// Set the RHS function to be used in the nonlinear solver
CVodeSetNlsRhsFn(cvode_mem, cvode_nonlinear_rhs);
#endif

// Set internal tolerance factors
if (CVodeSetNonlinConvCoef(cvode_mem, cvode_nonlinear_convergence_coef) != CV_SUCCESS) {
throw BoutException("CVodeSetNonlinConvCoef failed\n");
Expand Down Expand Up @@ -573,7 +586,7 @@ BoutReal CvodeSolver::run(BoutReal tout) {
* RHS function du = F(t, u)
**************************************************************************/

void CvodeSolver::rhs(BoutReal t, BoutReal* udata, BoutReal* dudata) {
void CvodeSolver::rhs(BoutReal t, BoutReal* udata, BoutReal* dudata, bool linear) {
TRACE("Running RHS: CvodeSolver::res({})", t);

// Load state from udata
Expand All @@ -584,7 +597,7 @@ void CvodeSolver::rhs(BoutReal t, BoutReal* udata, BoutReal* dudata) {
CVodeGetLastStep(cvode_mem, &hcur);

// Call RHS function
run_rhs(t);
run_rhs(t, linear);

// Save derivatives to dudata
save_derivs(dudata);
Expand Down Expand Up @@ -655,7 +668,23 @@ void CvodeSolver::jac(BoutReal t, BoutReal* ydata, BoutReal* vdata, BoutReal* Jv

// NOLINTBEGIN(readability-identifier-length)
namespace {
int cvode_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data) {
int cvode_linear_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data) {

BoutReal* udata = N_VGetArrayPointer(u);
BoutReal* dudata = N_VGetArrayPointer(du);

auto* s = static_cast<CvodeSolver*>(user_data);

// Calculate RHS function
try {
s->rhs(t, udata, dudata, true);
} catch (BoutRhsFail& error) {
return 1;
}
return 0;
}

int cvode_nonlinear_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data) {

BoutReal* udata = N_VGetArrayPointer(u);
BoutReal* dudata = N_VGetArrayPointer(du);
Expand All @@ -664,7 +693,7 @@ int cvode_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data) {

// Calculate RHS function
try {
s->rhs(t, udata, dudata);
s->rhs(t, udata, dudata, false);
} catch (BoutRhsFail& error) {
return 1;
}
Expand All @@ -674,7 +703,7 @@ int cvode_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data) {
/// RHS function for BBD preconditioner
int cvode_bbd_rhs(sunindextype UNUSED(Nlocal), BoutReal t, N_Vector u, N_Vector du,
void* user_data) {
return cvode_rhs(t, u, du, user_data);
return cvode_linear_rhs(t, u, du, user_data);
}

/// Preconditioner function
Expand Down
6 changes: 3 additions & 3 deletions src/solver/impls/cvode/cvode.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ public:

void resetInternalFields() override;

// These functions used internally (but need to be public)
void rhs(BoutReal t, BoutReal* udata, BoutReal* dudata);
// These functions are used internally (but need to be public)
void rhs(BoutReal t, BoutReal* udata, BoutReal* dudata, bool linear);
void pre(BoutReal t, BoutReal gamma, BoutReal delta, BoutReal* udata, BoutReal* rvec,
BoutReal* zvec);
void jac(BoutReal t, BoutReal* ydata, BoutReal* vdata, BoutReal* Jvdata);
Expand Down Expand Up @@ -138,7 +138,7 @@ private:
int nonlin_fails{0};
int stab_lims{0};

bool cvode_initialised = false;
bool cvode_initialised{false};

void set_vector_option_values(BoutReal* option_data, std::vector<BoutReal>& f2dtols,
std::vector<BoutReal>& f3dtols);
Expand Down
6 changes: 6 additions & 0 deletions src/solver/solver.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,12 @@ int Solver::run_rhs(BoutReal t, bool linear) {

Timer timer("rhs");

if (first_rhs_call) {
// Ensure that nonlinear terms are calculated on first call
linear = false;
first_rhs_call = false;
}

if (model->splitOperator()) {
// Run both parts

Expand Down

0 comments on commit 78a8149

Please sign in to comment.