-
Notifications
You must be signed in to change notification settings - Fork 96
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2932 from boutproject/cvode-linear-solve
CVODE solver: Pass linear flag to rhs()
- Loading branch information
Showing
4 changed files
with
49 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,9 +3,9 @@ | |
* | ||
* | ||
************************************************************************** | ||
* Copyright 2010 B.D.Dudson, S.Farley, M.V.Umansky, X.Q.Xu | ||
* Copyright 2010-2024 BOUT++ contributors | ||
* | ||
* Contact: Ben Dudson, [email protected] | ||
* Contact: Ben Dudson, [email protected] | ||
* | ||
* This file is part of BOUT++. | ||
* | ||
|
@@ -59,7 +59,8 @@ BOUT_ENUM_CLASS(positivity_constraint, none, positive, non_negative, negative, | |
|
||
// NOLINTBEGIN(readability-identifier-length) | ||
namespace { | ||
int cvode_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data); | ||
int cvode_linear_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data); | ||
int cvode_nonlinear_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data); | ||
int cvode_bbd_rhs(sunindextype Nlocal, BoutReal t, N_Vector u, N_Vector du, | ||
void* user_data); | ||
|
||
|
@@ -216,9 +217,16 @@ int CvodeSolver::init() { | |
throw BoutException("CVodeSetUserData failed\n"); | ||
} | ||
|
||
if (CVodeInit(cvode_mem, cvode_rhs, simtime, uvec) != CV_SUCCESS) { | ||
#if SUNDIALS_VERSION_MAJOR >= 6 | ||
// Set the default RHS to linear, then pass nonlinear rhs to NL solver | ||
if (CVodeInit(cvode_mem, cvode_linear_rhs, simtime, uvec) != CV_SUCCESS) { | ||
throw BoutException("CVodeInit failed\n"); | ||
} | ||
#else | ||
if (CVodeInit(cvode_mem, cvode_nonlinear_rhs, simtime, uvec) != CV_SUCCESS) { | ||
throw BoutException("CVodeInit failed\n"); | ||
} | ||
#endif | ||
|
||
if (max_order > 0) { | ||
if (CVodeSetMaxOrd(cvode_mem, max_order) != CV_SUCCESS) { | ||
|
@@ -385,6 +393,11 @@ int CvodeSolver::init() { | |
} | ||
} | ||
|
||
#if SUNDIALS_VERSION_MAJOR >= 6 | ||
// Set the RHS function to be used in the nonlinear solver | ||
CVodeSetNlsRhsFn(cvode_mem, cvode_nonlinear_rhs); | ||
#endif | ||
|
||
// Set internal tolerance factors | ||
if (CVodeSetNonlinConvCoef(cvode_mem, cvode_nonlinear_convergence_coef) != CV_SUCCESS) { | ||
throw BoutException("CVodeSetNonlinConvCoef failed\n"); | ||
|
@@ -573,7 +586,7 @@ BoutReal CvodeSolver::run(BoutReal tout) { | |
* RHS function du = F(t, u) | ||
**************************************************************************/ | ||
|
||
void CvodeSolver::rhs(BoutReal t, BoutReal* udata, BoutReal* dudata) { | ||
void CvodeSolver::rhs(BoutReal t, BoutReal* udata, BoutReal* dudata, bool linear) { | ||
TRACE("Running RHS: CvodeSolver::res({})", t); | ||
|
||
// Load state from udata | ||
|
@@ -584,7 +597,7 @@ void CvodeSolver::rhs(BoutReal t, BoutReal* udata, BoutReal* dudata) { | |
CVodeGetLastStep(cvode_mem, &hcur); | ||
|
||
// Call RHS function | ||
run_rhs(t); | ||
run_rhs(t, linear); | ||
|
||
// Save derivatives to dudata | ||
save_derivs(dudata); | ||
|
@@ -655,7 +668,23 @@ void CvodeSolver::jac(BoutReal t, BoutReal* ydata, BoutReal* vdata, BoutReal* Jv | |
|
||
// NOLINTBEGIN(readability-identifier-length) | ||
namespace { | ||
int cvode_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data) { | ||
int cvode_linear_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data) { | ||
|
||
BoutReal* udata = N_VGetArrayPointer(u); | ||
BoutReal* dudata = N_VGetArrayPointer(du); | ||
|
||
auto* s = static_cast<CvodeSolver*>(user_data); | ||
|
||
// Calculate RHS function | ||
try { | ||
s->rhs(t, udata, dudata, true); | ||
} catch (BoutRhsFail& error) { | ||
return 1; | ||
} | ||
return 0; | ||
} | ||
|
||
int cvode_nonlinear_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data) { | ||
|
||
BoutReal* udata = N_VGetArrayPointer(u); | ||
BoutReal* dudata = N_VGetArrayPointer(du); | ||
|
@@ -664,7 +693,7 @@ int cvode_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data) { | |
|
||
// Calculate RHS function | ||
try { | ||
s->rhs(t, udata, dudata); | ||
s->rhs(t, udata, dudata, false); | ||
} catch (BoutRhsFail& error) { | ||
return 1; | ||
} | ||
|
@@ -674,7 +703,7 @@ int cvode_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data) { | |
/// RHS function for BBD preconditioner | ||
int cvode_bbd_rhs(sunindextype UNUSED(Nlocal), BoutReal t, N_Vector u, N_Vector du, | ||
void* user_data) { | ||
return cvode_rhs(t, u, du, user_data); | ||
return cvode_linear_rhs(t, u, du, user_data); | ||
} | ||
|
||
/// Preconditioner function | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters