diff --git a/include/bout/solver.hxx b/include/bout/solver.hxx index 896ce62965..47fef7ce73 100644 --- a/include/bout/solver.hxx +++ b/include/bout/solver.hxx @@ -429,6 +429,8 @@ protected: bool has_constraints{false}; /// Has init been called yet? bool initialised{false}; + /// If calling user RHS for the first time + bool first_rhs_call{true}; /// Current simulation time BoutReal simtime{0.0}; diff --git a/src/solver/impls/cvode/cvode.cxx b/src/solver/impls/cvode/cvode.cxx index 22f7f154f7..1fca765687 100644 --- a/src/solver/impls/cvode/cvode.cxx +++ b/src/solver/impls/cvode/cvode.cxx @@ -3,9 +3,9 @@ * * ************************************************************************** - * Copyright 2010 B.D.Dudson, S.Farley, M.V.Umansky, X.Q.Xu + * Copyright 2010-2024 BOUT++ contributors * - * Contact: Ben Dudson, bd512@york.ac.uk + * Contact: Ben Dudson, dudson2@llnl.gov * * This file is part of BOUT++. * @@ -59,7 +59,8 @@ BOUT_ENUM_CLASS(positivity_constraint, none, positive, non_negative, negative, // NOLINTBEGIN(readability-identifier-length) namespace { -int cvode_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data); +int cvode_linear_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data); +int cvode_nonlinear_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data); int cvode_bbd_rhs(sunindextype Nlocal, BoutReal t, N_Vector u, N_Vector du, void* user_data); @@ -216,9 +217,16 @@ int CvodeSolver::init() { throw BoutException("CVodeSetUserData failed\n"); } - if (CVodeInit(cvode_mem, cvode_rhs, simtime, uvec) != CV_SUCCESS) { +#if SUNDIALS_VERSION_MAJOR >= 6 + // Set the default RHS to linear, then pass nonlinear rhs to NL solver + if (CVodeInit(cvode_mem, cvode_linear_rhs, simtime, uvec) != CV_SUCCESS) { throw BoutException("CVodeInit failed\n"); } +#else + if (CVodeInit(cvode_mem, cvode_nonlinear_rhs, simtime, uvec) != CV_SUCCESS) { + throw BoutException("CVodeInit failed\n"); + } +#endif if (max_order > 0) { if (CVodeSetMaxOrd(cvode_mem, max_order) != CV_SUCCESS) { @@ -385,6 +393,11 @@ int CvodeSolver::init() { } } +#if SUNDIALS_VERSION_MAJOR >= 6 + // Set the RHS function to be used in the nonlinear solver + CVodeSetNlsRhsFn(cvode_mem, cvode_nonlinear_rhs); +#endif + // Set internal tolerance factors if (CVodeSetNonlinConvCoef(cvode_mem, cvode_nonlinear_convergence_coef) != CV_SUCCESS) { throw BoutException("CVodeSetNonlinConvCoef failed\n"); @@ -573,7 +586,7 @@ BoutReal CvodeSolver::run(BoutReal tout) { * RHS function du = F(t, u) **************************************************************************/ -void CvodeSolver::rhs(BoutReal t, BoutReal* udata, BoutReal* dudata) { +void CvodeSolver::rhs(BoutReal t, BoutReal* udata, BoutReal* dudata, bool linear) { TRACE("Running RHS: CvodeSolver::res({})", t); // Load state from udata @@ -584,7 +597,7 @@ void CvodeSolver::rhs(BoutReal t, BoutReal* udata, BoutReal* dudata) { CVodeGetLastStep(cvode_mem, &hcur); // Call RHS function - run_rhs(t); + run_rhs(t, linear); // Save derivatives to dudata save_derivs(dudata); @@ -655,7 +668,23 @@ void CvodeSolver::jac(BoutReal t, BoutReal* ydata, BoutReal* vdata, BoutReal* Jv // NOLINTBEGIN(readability-identifier-length) namespace { -int cvode_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data) { +int cvode_linear_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data) { + + BoutReal* udata = N_VGetArrayPointer(u); + BoutReal* dudata = N_VGetArrayPointer(du); + + auto* s = static_cast(user_data); + + // Calculate RHS function + try { + s->rhs(t, udata, dudata, true); + } catch (BoutRhsFail& error) { + return 1; + } + return 0; +} + +int cvode_nonlinear_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data) { BoutReal* udata = N_VGetArrayPointer(u); BoutReal* dudata = N_VGetArrayPointer(du); @@ -664,7 +693,7 @@ int cvode_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data) { // Calculate RHS function try { - s->rhs(t, udata, dudata); + s->rhs(t, udata, dudata, false); } catch (BoutRhsFail& error) { return 1; } @@ -674,7 +703,7 @@ int cvode_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data) { /// RHS function for BBD preconditioner int cvode_bbd_rhs(sunindextype UNUSED(Nlocal), BoutReal t, N_Vector u, N_Vector du, void* user_data) { - return cvode_rhs(t, u, du, user_data); + return cvode_linear_rhs(t, u, du, user_data); } /// Preconditioner function diff --git a/src/solver/impls/cvode/cvode.hxx b/src/solver/impls/cvode/cvode.hxx index 89c3a613a8..d44fcf2335 100644 --- a/src/solver/impls/cvode/cvode.hxx +++ b/src/solver/impls/cvode/cvode.hxx @@ -68,8 +68,8 @@ public: void resetInternalFields() override; - // These functions used internally (but need to be public) - void rhs(BoutReal t, BoutReal* udata, BoutReal* dudata); + // These functions are used internally (but need to be public) + void rhs(BoutReal t, BoutReal* udata, BoutReal* dudata, bool linear); void pre(BoutReal t, BoutReal gamma, BoutReal delta, BoutReal* udata, BoutReal* rvec, BoutReal* zvec); void jac(BoutReal t, BoutReal* ydata, BoutReal* vdata, BoutReal* Jvdata); @@ -138,7 +138,7 @@ private: int nonlin_fails{0}; int stab_lims{0}; - bool cvode_initialised = false; + bool cvode_initialised{false}; void set_vector_option_values(BoutReal* option_data, std::vector& f2dtols, std::vector& f3dtols); diff --git a/src/solver/solver.cxx b/src/solver/solver.cxx index 1b7ec1fd74..8a75ff43a4 100644 --- a/src/solver/solver.cxx +++ b/src/solver/solver.cxx @@ -1364,6 +1364,12 @@ int Solver::run_rhs(BoutReal t, bool linear) { Timer timer("rhs"); + if (first_rhs_call) { + // Ensure that nonlinear terms are calculated on first call + linear = false; + first_rhs_call = false; + } + if (model->splitOperator()) { // Run both parts