Skip to content

Commit

Permalink
allow in-place linear solve for Backward Euler
Browse files Browse the repository at this point in the history
  • Loading branch information
mattldawson committed Dec 24, 2024
1 parent 4067531 commit b529224
Show file tree
Hide file tree
Showing 6 changed files with 95 additions and 12 deletions.
23 changes: 19 additions & 4 deletions include/micm/solver/backward_euler.inl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ namespace micm
// if the last attempt to reduce the timestep fails,
// accept the current H but do not update the Yn vector

using MatrixPolicy = decltype(state.variables_);
using DenseMatrixPolicy = decltype(state.variables_);
using SparseMatrixPolicy = decltype(state.jacobian_);

SolverResult result;

Expand All @@ -67,7 +68,7 @@ namespace micm
std::size_t n_convergence_failures = 0;

auto derived_class_temporary_variables =
static_cast<BackwardEulerTemporaryVariables<MatrixPolicy>*>(state.temporary_variables_.get());
static_cast<BackwardEulerTemporaryVariables<DenseMatrixPolicy>*>(state.temporary_variables_.get());
auto& Yn = derived_class_temporary_variables->Yn_;
auto& Yn1 = state.variables_; // Yn1 will hold the new solution at the end of the solve
auto& forcing = derived_class_temporary_variables->forcing_;
Expand Down Expand Up @@ -110,7 +111,14 @@ namespace micm
// (y_{n+1} - y_n) / H = f(t_{n+1}, y_{n+1})

// try to find the root by factoring and solving the linear system
linear_solver_.Factor(state.jacobian_, state.lower_matrix_, state.upper_matrix_);
if constexpr (LinearSolverInPlaceConcept<LinearSolverPolicy, DenseMatrixPolicy, SparseMatrixPolicy>)
{
linear_solver_.Factor(state.jacobian_);
}
else
{
linear_solver_.Factor(state.jacobian_, state.lower_matrix_, state.upper_matrix_);
}
result.stats_.decompositions_++;

// forcing_blk in camchem
Expand All @@ -120,7 +128,14 @@ namespace micm

// the result of the linear solver will be stored in forcing
// this represents the change in the solution
linear_solver_.Solve(forcing, state.lower_matrix_, state.upper_matrix_);
if constexpr (LinearSolverInPlaceConcept<LinearSolverPolicy, DenseMatrixPolicy, SparseMatrixPolicy>)
{
linear_solver_.Solve(forcing, state.jacobian_);
}
else
{
linear_solver_.Solve(forcing, state.lower_matrix_, state.upper_matrix_);
}
result.stats_.solves_++;

// solution_blk in camchem
Expand Down
18 changes: 18 additions & 0 deletions include/micm/solver/linear_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,31 @@
#include <micm/solver/linear_solver_in_place.hpp>
#include <micm/util/matrix.hpp>
#include <micm/util/sparse_matrix.hpp>
#include <micm/util/sparse_matrix_vector_ordering.hpp>

#include <cmath>
#include <functional>

namespace micm
{

/// @brief Concept for in-place linear solver algorithms
template<class T, class DenseMatrixPolicy, class SparseMatrixPolicy>
concept LinearSolverInPlaceConcept = requires(T t)
{
{ t.Factor(std::declval<SparseMatrixPolicy&>()) };
{ t.Solve(std::declval<DenseMatrixPolicy&>(), SparseMatrixPolicy{}) };
};
static_assert(LinearSolverInPlaceConcept<LinearSolverInPlace<StandardSparseMatrix>, StandardDenseMatrix, StandardSparseMatrix>, "LinearSolverInPlace does not meet the LinearSolverInPlaceConcept requirements");
static_assert(LinearSolverInPlaceConcept<
LinearSolverInPlace<
SparseMatrix<double, SparseMatrixVectorOrderingCompressedSparseRow<1>>,
LuDecompositionMozartInPlace
>,
VectorMatrix<double, 1>,
SparseMatrix<double, SparseMatrixVectorOrderingCompressedSparseRow<1>>
>, "LinearSolverInPlace for vector matrices does not meet the LinearSolverInPlaceConcept requirements");

/// @brief Reorders a set of state variables using Diagonal Markowitz algorithm
/// @param matrix Original matrix non-zero elements
/// @result Reordered mapping vector (reordered[i] = original[map[i]])
Expand Down
1 change: 0 additions & 1 deletion include/micm/solver/linear_solver_in_place.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

namespace micm
{

/// @brief A general-use block-diagonal sparse-matrix linear solver
///
/// The sparsity pattern of each block in the block diagonal matrix is the same.
Expand Down
13 changes: 13 additions & 0 deletions include/micm/solver/lu_decomposition.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,22 @@
#include "lu_decomposition_doolittle.hpp"
#include "lu_decomposition_mozart.hpp"
#include "lu_decomposition_mozart_in_place.hpp"
#include <micm/util/sparse_matrix.hpp>
#include <micm/util/sparse_matrix_vector_ordering.hpp>

namespace micm
{

/// @brief Concept for in-place LU decomposition algorithms
template<class T, class SparseMatrixPolicy>
concept LuDecompositionInPlaceConcept = requires(T t)
{
{ t.GetLUMatrix(SparseMatrixPolicy{}, 0.0) };
{ t.Decompose(std::declval<SparseMatrixPolicy&>()) };
};
static_assert(LuDecompositionInPlaceConcept<LuDecompositionMozartInPlace, StandardSparseMatrix>, "LuDecompositionMozartInPlace does not meet the LuDecompositionInPlaceConcept requirements");
static_assert(LuDecompositionInPlaceConcept<LuDecompositionMozartInPlace, SparseMatrix<double, SparseMatrixVectorOrderingCompressedSparseRow<1>>>, "LuDecompositionMozartInPlace for vector matrices does not meet the LuDecompositionInPlaceConcept requirements");

/// @brief Alias for the default LU decomposition algorithm
using LuDecomposition = LuDecompositionDoolittle;

Expand Down
20 changes: 14 additions & 6 deletions include/micm/solver/state.inl
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,20 @@ namespace micm

jacobian_ = BuildJacobian<SparseMatrixPolicy>(
parameters.nonzero_jacobian_elements_, parameters.number_of_grid_cells_, state_size_);

auto lu = LuDecompositionPolicy::template GetLUMatrices<SparseMatrixPolicy, LMatrixPolicy, UMatrixPolicy>(jacobian_, 0);
auto lower_matrix = std::move(lu.first);
auto upper_matrix = std::move(lu.second);
lower_matrix_ = lower_matrix;
upper_matrix_ = upper_matrix;

if constexpr (LuDecompositionInPlaceConcept<LuDecompositionPolicy, SparseMatrixPolicy>)
{
auto lu = LuDecompositionPolicy::template GetLUMatrix<SparseMatrixPolicy>(jacobian_, 0);
jacobian_ = std::move(lu);
}
else
{
auto lu = LuDecompositionPolicy::template GetLUMatrices<SparseMatrixPolicy, LMatrixPolicy, UMatrixPolicy>(jacobian_, 0);
auto lower_matrix = std::move(lu.first);
auto upper_matrix = std::move(lu.second);
lower_matrix_ = lower_matrix;
upper_matrix_ = upper_matrix;
}
}

template<class DenseMatrixPolicy, class SparseMatrixPolicy, class LuDecompositionPolicy, class LMatrixPolicy, class UMatrixPolicy>
Expand Down
32 changes: 31 additions & 1 deletion test/integration/test_analytical_backward_euler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,25 @@ template<std::size_t L>
using VectorStateTypeMozartCSC =
micm::State<micm::VectorMatrix<double, L>, micm::SparseMatrix<double, micm::SparseMatrixVectorOrderingCompressedSparseColumn<L>>, micm::LuDecompositionMozart>;

template<std::size_t L>
using VectorBackwardEulerMozartInPlace = micm::SolverBuilder<
micm::BackwardEulerSolverParameters,
micm::VectorMatrix<double, L>,
micm::SparseMatrix<double, micm::SparseMatrixVectorOrdering<L>>,
micm::ProcessSet,
micm::LuDecompositionMozartInPlace,
micm::LinearSolverInPlace<micm::SparseMatrix<double, micm::SparseMatrixVectorOrdering<L>>, micm::LuDecompositionMozartInPlace>,
micm::State<
micm::VectorMatrix<double, L>,
micm::SparseMatrix<double, micm::SparseMatrixVectorOrdering<L>>,
micm::LuDecompositionMozartInPlace>>;

template<std::size_t L>
using VectorStateTypeMozartInPlace = micm::State<
micm::VectorMatrix<double, L>,
micm::SparseMatrix<double, micm::SparseMatrixVectorOrdering<L>>,
micm::LuDecompositionMozartInPlace>;

auto backward_euler = micm::CpuSolverBuilder<micm::BackwardEulerSolverParameters>(micm::BackwardEulerSolverParameters());
auto backard_euler_vector_1 = VectorBackwardEuler<1>(micm::BackwardEulerSolverParameters());
auto backard_euler_vector_2 = VectorBackwardEuler<2>(micm::BackwardEulerSolverParameters());
Expand All @@ -87,7 +106,10 @@ auto backward_euler_vector_mozart_csc_1 = VectorBackwardEulerMozartCSC<1>(micm::
auto backward_euler_vector_mozart_csc_2 = VectorBackwardEulerMozartCSC<2>(micm::BackwardEulerSolverParameters());
auto backward_euler_vector_mozart_csc_3 = VectorBackwardEulerMozartCSC<3>(micm::BackwardEulerSolverParameters());
auto backward_euler_vector_mozart_csc_4 = VectorBackwardEulerMozartCSC<4>(micm::BackwardEulerSolverParameters());

auto backward_euler_vector_mozart_in_place_1 = VectorBackwardEulerMozartInPlace<1>(micm::BackwardEulerSolverParameters());
auto backward_euler_vector_mozart_in_place_2 = VectorBackwardEulerMozartInPlace<2>(micm::BackwardEulerSolverParameters());
auto backward_euler_vector_mozart_in_place_3 = VectorBackwardEulerMozartInPlace<3>(micm::BackwardEulerSolverParameters());
auto backward_euler_vector_mozart_in_place_4 = VectorBackwardEulerMozartInPlace<4>(micm::BackwardEulerSolverParameters());

TEST(AnalyticalExamples, Troe)
{
Expand Down Expand Up @@ -116,6 +138,14 @@ TEST(AnalyticalExamples, Troe)
test_analytical_troe<VectorBackwardEulerMozartCSC<2>, VectorStateTypeMozartCSC<2>>(backward_euler_vector_mozart_csc_2, 1e-6);
test_analytical_troe<VectorBackwardEulerMozartCSC<3>, VectorStateTypeMozartCSC<3>>(backward_euler_vector_mozart_csc_3, 1e-6);
test_analytical_troe<VectorBackwardEulerMozartCSC<4>, VectorStateTypeMozartCSC<4>>(backward_euler_vector_mozart_csc_4, 1e-6);
test_analytical_troe<VectorBackwardEulerMozartInPlace<1>, VectorStateTypeMozartInPlace<1>>(
backward_euler_vector_mozart_in_place_1, 1e-6);
test_analytical_troe<VectorBackwardEulerMozartInPlace<2>, VectorStateTypeMozartInPlace<2>>(
backward_euler_vector_mozart_in_place_2, 1e-6);
test_analytical_troe<VectorBackwardEulerMozartInPlace<3>, VectorStateTypeMozartInPlace<3>>(
backward_euler_vector_mozart_in_place_3, 1e-6);
test_analytical_troe<VectorBackwardEulerMozartInPlace<4>, VectorStateTypeMozartInPlace<4>>(
backward_euler_vector_mozart_in_place_4, 1e-6);
}

TEST(AnalyticalExamples, TroeSuperStiffButAnalytical)
Expand Down

0 comments on commit b529224

Please sign in to comment.