Skip to content

Commit

Permalink
Add in-place LU decomposition option for linear solve (#698)
Browse files Browse the repository at this point in the history
* add in-place MOZART LU decomp

* add in-place linear solver

* allow in-place linear solve for Backward Euler
  • Loading branch information
mattldawson authored Jan 9, 2025
1 parent 19425fb commit 706a34a
Show file tree
Hide file tree
Showing 16 changed files with 1,463 additions and 14 deletions.
23 changes: 19 additions & 4 deletions include/micm/solver/backward_euler.inl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ namespace micm
// if the last attempt to reduce the timestep fails,
// accept the current H but do not update the Yn vector

using MatrixPolicy = decltype(state.variables_);
using DenseMatrixPolicy = decltype(state.variables_);
using SparseMatrixPolicy = decltype(state.jacobian_);

SolverResult result;

Expand All @@ -67,7 +68,7 @@ namespace micm
std::size_t n_convergence_failures = 0;

auto derived_class_temporary_variables =
static_cast<BackwardEulerTemporaryVariables<MatrixPolicy>*>(state.temporary_variables_.get());
static_cast<BackwardEulerTemporaryVariables<DenseMatrixPolicy>*>(state.temporary_variables_.get());
auto& Yn = derived_class_temporary_variables->Yn_;
auto& Yn1 = state.variables_; // Yn1 will hold the new solution at the end of the solve
auto& forcing = derived_class_temporary_variables->forcing_;
Expand Down Expand Up @@ -110,7 +111,14 @@ namespace micm
// (y_{n+1} - y_n) / H = f(t_{n+1}, y_{n+1})

// try to find the root by factoring and solving the linear system
linear_solver_.Factor(state.jacobian_, state.lower_matrix_, state.upper_matrix_);
if constexpr (LinearSolverInPlaceConcept<LinearSolverPolicy, DenseMatrixPolicy, SparseMatrixPolicy>)
{
linear_solver_.Factor(state.jacobian_);
}
else
{
linear_solver_.Factor(state.jacobian_, state.lower_matrix_, state.upper_matrix_);
}
result.stats_.decompositions_++;

// forcing_blk in camchem
Expand All @@ -120,7 +128,14 @@ namespace micm

// the result of the linear solver will be stored in forcing
// this represents the change in the solution
linear_solver_.Solve(forcing, state.lower_matrix_, state.upper_matrix_);
if constexpr (LinearSolverInPlaceConcept<LinearSolverPolicy, DenseMatrixPolicy, SparseMatrixPolicy>)
{
linear_solver_.Solve(forcing, state.jacobian_);
}
else
{
linear_solver_.Solve(forcing, state.lower_matrix_, state.upper_matrix_);
}
result.stats_.solves_++;

// solution_blk in camchem
Expand Down
19 changes: 19 additions & 0 deletions include/micm/solver/linear_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,34 @@

#include <micm/profiler/instrumentation.hpp>
#include <micm/solver/lu_decomposition.hpp>
#include <micm/solver/linear_solver_in_place.hpp>
#include <micm/util/matrix.hpp>
#include <micm/util/sparse_matrix.hpp>
#include <micm/util/sparse_matrix_vector_ordering.hpp>

#include <cmath>
#include <functional>

namespace micm
{

/// @brief Concept for in-place linear solver algorithms
///
/// A type T satisfies this concept when, for the given dense and sparse matrix policies,
/// the following calls are well-formed:
///   - t.Factor(matrix)        with matrix    of type SparseMatrixPolicy&
///   - t.Solve(x, lu_matrix)   with x         of type DenseMatrixPolicy&
///                             and  lu_matrix of type const SparseMatrixPolicy&
///
/// The arguments are introduced as requires-parameters so that the check does not
/// additionally demand that SparseMatrixPolicy be default-constructible.
template<class T, class DenseMatrixPolicy, class SparseMatrixPolicy>
concept LinearSolverInPlaceConcept =
    requires(T t, DenseMatrixPolicy& x, SparseMatrixPolicy& matrix, const SparseMatrixPolicy& lu_matrix)
{
  t.Factor(matrix);
  t.Solve(x, lu_matrix);
};
// Compile-time checks that LinearSolverInPlace satisfies LinearSolverInPlaceConcept:
//   1. with its default LU decomposition policy and the standard dense/sparse matrix types
//   2. with LuDecompositionMozartInPlace and the vector-matrix types named below
// A failure here is reported with the message attached to the corresponding assertion.
static_assert(LinearSolverInPlaceConcept<LinearSolverInPlace<StandardSparseMatrix>, StandardDenseMatrix, StandardSparseMatrix>, "LinearSolverInPlace does not meet the LinearSolverInPlaceConcept requirements");
static_assert(LinearSolverInPlaceConcept<
LinearSolverInPlace<
SparseMatrix<double, SparseMatrixVectorOrderingCompressedSparseRow<1>>,
LuDecompositionMozartInPlace
>,
VectorMatrix<double, 1>,
SparseMatrix<double, SparseMatrixVectorOrderingCompressedSparseRow<1>>
>, "LinearSolverInPlace for vector matrices does not meet the LinearSolverInPlaceConcept requirements");

/// @brief Reorders a set of state variables using Diagonal Markowitz algorithm
/// @param matrix Original matrix non-zero elements
/// @result Reordered mapping vector (reordered[i] = original[map[i]])
Expand Down
91 changes: 91 additions & 0 deletions include/micm/solver/linear_solver_in_place.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright (C) 2023-2024 National Center for Atmospheric Research
// SPDX-License-Identifier: Apache-2.0
#pragma once

#include <micm/profiler/instrumentation.hpp>
#include <micm/solver/lu_decomposition.hpp>
#include <micm/util/matrix.hpp>
#include <micm/util/sparse_matrix.hpp>

#include <cmath>
#include <functional>

namespace micm
{
/// @brief A general-use block-diagonal sparse-matrix linear solver
///
/// The sparsity pattern of each block in the block diagonal matrix is the same.
/// The L and U matrices are decomposed in-place over the original A matrix.
template<class SparseMatrixPolicy, class LuDecompositionPolicy = LuDecompositionInPlace>
class LinearSolverInPlace
{
 protected:
  // Parameters needed to solve (L U) x = b using the combined, in-place LU matrix
  //
  // The calculation is split into L y = b where y = U x (forward substitution):
  //
  //   y_i = b_i - sum( j = 1...i-1 ){ L_ij * y_j }                    i = 1...N
  //
  // (Solve() performs no division in this pass, i.e. L is treated as having a
  //  unit diagonal; the diagonal entries of the combined matrix are used as U_ii)
  //
  // ... and then U x = y (backward substitution):
  //
  //   x_N = y_N / U_NN
  //   x_i = 1 / U_ii * [ y_i - sum( j = i+1...N ){ U_ij * x_j } ]     i = N-1...1

  // Number of non-zero elements (excluding the diagonal) for each row in L
  std::vector<std::size_t> nLij_;
  // Indices of non-zero combinations of L_ij and y_j
  std::vector<std::pair<std::size_t, std::size_t>> Lij_yj_;
  // Number of non-zero elements (excluding the diagonal) and the index of the diagonal
  // element for each row in U (in reverse order)
  std::vector<std::pair<std::size_t, std::size_t>> nUij_Uii_;
  // Indices of non-zero combinations of U_ij and x_j
  std::vector<std::pair<std::size_t, std::size_t>> Uij_xj_;

  // LU decomposition algorithm used by Factor()
  LuDecompositionPolicy lu_decomp_;

 public:
  /// @brief default constructor
  LinearSolverInPlace() {}

  // Move-only: copying is disabled, moving uses the member-wise defaults
  LinearSolverInPlace(const LinearSolverInPlace&) = delete;
  LinearSolverInPlace& operator=(const LinearSolverInPlace&) = delete;
  LinearSolverInPlace(LinearSolverInPlace&&) = default;
  LinearSolverInPlace& operator=(LinearSolverInPlace&&) = default;

  /// @brief Constructs a linear solver for the sparsity structure of the given matrix
  /// @param matrix Sparse matrix
  /// @param initial_value Initial value for matrix elements
  LinearSolverInPlace(const SparseMatrixPolicy& matrix, typename SparseMatrixPolicy::value_type initial_value);

  /// @brief Constructs a linear solver for the sparsity structure of the given matrix
  /// @param matrix Sparse matrix
  /// @param initial_value Initial value for matrix elements
  /// @param create_lu_decomp Function to create an LU Decomposition object that adheres to LuDecompositionPolicy
  LinearSolverInPlace(
      const SparseMatrixPolicy& matrix,
      typename SparseMatrixPolicy::value_type initial_value,
      const std::function<LuDecompositionPolicy(const SparseMatrixPolicy&)> create_lu_decomp);

  virtual ~LinearSolverInPlace() = default;

  /// @brief Decompose the matrix into upper and lower triangular matrices (matrix will be overwritten)
  /// @param matrix Matrix to decompose in-place into lower and upper triangular matrices
  void Factor(SparseMatrixPolicy& matrix) const;

  /// @brief Solve for x in Ax = b. x should be a copy of b and after Solve finishes x will contain the result
  /// @param x The solution vector
  /// @param lu_matrix The LU decomposition of the matrix as a square sparse matrix
  ///
  /// Two overloads are provided; this one is selected unless both the dense and the
  /// sparse matrix policies are vectorizable.
  template<class MatrixPolicy>
    requires(!VectorizableDense<MatrixPolicy> || !VectorizableSparse<SparseMatrixPolicy>)
  void Solve(MatrixPolicy& x, const SparseMatrixPolicy& lu_matrix) const;

  /// @brief Solve for x in Ax = b when both the dense and sparse matrix policies are vectorizable
  /// @param x The solution vector (a copy of b on entry, the result on exit)
  /// @param lu_matrix The LU decomposition of the matrix as a square sparse matrix
  template<class MatrixPolicy>
    requires(VectorizableDense<MatrixPolicy> && VectorizableSparse<SparseMatrixPolicy>)
  void Solve(MatrixPolicy& x, const SparseMatrixPolicy& lu_matrix) const;
};

} // namespace micm

#include "linear_solver_in_place.inl"
179 changes: 179 additions & 0 deletions include/micm/solver/linear_solver_in_place.inl
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
// Copyright (C) 2023-2024 National Center for Atmospheric Research
// SPDX-License-Identifier: Apache-2.0
namespace micm
{
/// Convenience constructor: forwards to the three-argument constructor, supplying a
/// factory that builds the LU decomposition object through LuDecompositionPolicy::Create.
template<class SparseMatrixPolicy, class LuDecompositionPolicy>
inline LinearSolverInPlace<SparseMatrixPolicy, LuDecompositionPolicy>::LinearSolverInPlace(
    const SparseMatrixPolicy& matrix,
    typename SparseMatrixPolicy::value_type initial_value)
    : LinearSolverInPlace(
          matrix,
          initial_value,
          [](const SparseMatrixPolicy& sparsity) -> LuDecompositionPolicy
          {
            return LuDecompositionPolicy::template Create<SparseMatrixPolicy>(sparsity);
          })
{
}

/// Builds the index tables used by Solve() from the sparsity structure of the LU matrix.
///
/// The LU decomposition object is created with the supplied factory, and the combined
/// LU matrix obtained from it is scanned once:
///   - strictly-lower-triangular non-zeros (row by row, top to bottom) fill nLij_ / Lij_yj_
///   - strictly-upper-triangular non-zeros and the diagonal (row by row, bottom to top)
///     fill nUij_Uii_ / Uij_xj_
/// All element indices are recorded for block 0 (VectorIndex(0, i, j)); Solve() adds the
/// per-block or per-group offset.
template<class SparseMatrixPolicy, class LuDecompositionPolicy>
inline LinearSolverInPlace<SparseMatrixPolicy, LuDecompositionPolicy>::LinearSolverInPlace(
    const SparseMatrixPolicy& matrix,
    typename SparseMatrixPolicy::value_type initial_value,
    const std::function<LuDecompositionPolicy(const SparseMatrixPolicy&)> create_lu_decomp)
    : nLij_(),
      Lij_yj_(),
      nUij_Uii_(),
      Uij_xj_(),
      lu_decomp_(create_lu_decomp(matrix))
{
  MICM_PROFILE_FUNCTION();

  auto lu = lu_decomp_.template GetLUMatrix<SparseMatrixPolicy>(matrix, initial_value);
  const std::size_t n_rows = lu.NumRows();
  const std::size_t n_columns = lu.NumColumns();

  // Exactly one entry per row is appended to each of these tables
  nLij_.reserve(n_rows);
  nUij_Uii_.reserve(n_rows);

  // Lower triangle (excluding the diagonal), rows in ascending order
  for (std::size_t i = 0; i < n_rows; ++i)
  {
    std::size_t nLij = 0;
    for (std::size_t j = 0; j < i; ++j)
    {
      if (lu.IsZero(i, j))
        continue;
      Lij_yj_.emplace_back(lu.VectorIndex(0, i, j), j);
      ++nLij;
    }
    nLij_.push_back(nLij);
  }

  // Upper triangle, rows in descending order (the order used by backward substitution)
  for (std::size_t i = n_rows; i-- > 0;)
  {
    std::size_t nUij = 0;
    for (std::size_t j = i + 1; j < n_columns; ++j)
    {
      if (lu.IsZero(i, j))
        continue;
      Uij_xj_.emplace_back(lu.VectorIndex(0, i, j), j);
      ++nUij;
    }
    // There must always be a non-zero element on the diagonal
    nUij_Uii_.emplace_back(nUij, lu.VectorIndex(0, i, i));
  }
}

/// Performs the LU factorization by delegating to the LU decomposition policy object.
/// The matrix passed in is handed to Decompose() by non-const reference.
template<class SparseMatrixPolicy, class LuDecompositionPolicy>
inline void LinearSolverInPlace<SparseMatrixPolicy, LuDecompositionPolicy>::Factor(SparseMatrixPolicy& matrix) const
{
  MICM_PROFILE_FUNCTION();

  this->lu_decomp_.template Decompose<SparseMatrixPolicy>(matrix);
}

/// Solves (L U) x = b one block at a time (overload used unless both matrix policies are vectorizable).
///
/// On entry each row of x holds the right-hand side b for one block; on exit it holds
/// the solution. The substitution is done in place:
///   1. forward substitution with the strictly-lower entries of lu_matrix (no division),
///   2. backward substitution with the strictly-upper entries, dividing by the diagonal.
///
/// @param x         Right-hand side on entry, solution on exit (one row per block)
/// @param lu_matrix Combined LU factors; block i starts at i * FlatBlockSize() in AsVector()
template<class SparseMatrixPolicy, class LuDecompositionPolicy>
template<class MatrixPolicy>
  requires(!VectorizableDense<MatrixPolicy> || !VectorizableSparse<SparseMatrixPolicy>)
inline void LinearSolverInPlace<SparseMatrixPolicy, LuDecompositionPolicy>::Solve(
    MatrixPolicy& x,
    const SparseMatrixPolicy& lu_matrix) const
{
  MICM_PROFILE_FUNCTION();

  // Loop-invariant lookups, hoisted out of the substitution loops
  const auto& lu_values = lu_matrix.AsVector();
  const std::size_t block_size = lu_matrix.FlatBlockSize();

  for (std::size_t i_cell = 0; i_cell < x.NumRows(); ++i_cell)
  {
    auto x_cell = x[i_cell];
    const std::size_t grid_offset = i_cell * block_size;

    // Forward Substitution (y overwrites b in x_cell)
    {
      auto y_elem = x_cell.begin();
      auto Lij_yj = Lij_yj_.begin();
      for (const std::size_t nLij : nLij_)
      {
        for (std::size_t i = 0; i < nLij; ++i, ++Lij_yj)
        {
          *y_elem -= lu_values[grid_offset + Lij_yj->first] * x_cell[Lij_yj->second];
        }
        ++y_elem;
      }
    }

    // Backward Substitution (x overwrites y in x_cell), last row first
    {
      auto x_elem = std::next(x_cell.end(), -1);
      auto Uij_xj = Uij_xj_.begin();
      for (const auto& [nUij, Uii_index] : nUij_Uii_)
      {
        // x_elem starts out as y_elem from the previous loop
        for (std::size_t i = 0; i < nUij; ++i, ++Uij_xj)
        {
          *x_elem -= lu_values[grid_offset + Uij_xj->first] * x_cell[Uij_xj->second];
        }

        *x_elem /= lu_values[grid_offset + Uii_index];
        // don't iterate before the beginning of the vector
        if (x_elem != x_cell.begin())
        {
          --x_elem;
        }
      }
    }
  }
}

/// Solves (L U) x = b for groups of blocks (overload used when both matrix policies are vectorizable).
///
/// Within a group, the values belonging to one matrix/vector element are read for
/// n_cells consecutive cells starting at the element's recorded index, so every
/// arithmetic step is applied across all cells of the group in an inner loop.
///
/// @param x         Right-hand side on entry, solution on exit
/// @param lu_matrix Combined LU factors
template<class SparseMatrixPolicy, class LuDecompositionPolicy>
template<class MatrixPolicy>
  requires(VectorizableDense<MatrixPolicy> && VectorizableSparse<SparseMatrixPolicy>)
inline void LinearSolverInPlace<SparseMatrixPolicy, LuDecompositionPolicy>::Solve(
    MatrixPolicy& x,
    const SparseMatrixPolicy& lu_matrix) const
{
  MICM_PROFILE_FUNCTION();
  constexpr std::size_t n_cells = MatrixPolicy::GroupVectorSize();

  // Loop over groups of blocks
  for (std::size_t i_group = 0; i_group < x.NumberOfGroups(); ++i_group)
  {
    auto x_group = std::next(x.AsVector().begin(), i_group * x.GroupSize());
    auto LU_group = std::next(lu_matrix.AsVector().begin(), i_group * lu_matrix.GroupSize());

    // Forward Substitution
    {
      auto row = x_group;
      auto lower = Lij_yj_.begin();
      for (auto& n_lower : nLij_)
      {
        for (std::size_t k = 0; k < n_lower; ++k)
        {
          const std::size_t lu_index = lower->first;
          const std::size_t column_offset = lower->second * n_cells;
          for (std::size_t cell = 0; cell < n_cells; ++cell)
          {
            row[cell] -= LU_group[lu_index + cell] * x_group[column_offset + cell];
          }
          ++lower;
        }
        row += n_cells;
      }
    }

    // Backward Substitution
    {
      auto row = std::next(x_group, x.GroupSize() - n_cells);
      auto upper = Uij_xj_.begin();
      for (auto& row_info : nUij_Uii_)
      {
        // row starts out holding the forward-substitution result
        for (std::size_t k = 0; k < row_info.first; ++k)
        {
          const std::size_t lu_index = upper->first;
          const std::size_t column_offset = upper->second * n_cells;
          for (std::size_t cell = 0; cell < n_cells; ++cell)
          {
            row[cell] -= LU_group[lu_index + cell] * x_group[column_offset + cell];
          }
          ++upper;
        }

        const std::size_t diagonal_index = row_info.second;
        for (std::size_t cell = 0; cell < n_cells; ++cell)
        {
          row[cell] /= LU_group[diagonal_index + cell];
        }

        // don't iterate before the beginning of the vector
        const std::size_t distance_from_start = std::distance(x.AsVector().begin(), row);
        row -= std::min(n_cells, distance_from_start);
      }
    }
  }
}
} // namespace micm
17 changes: 17 additions & 0 deletions include/micm/solver/lu_decomposition.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,26 @@

#include "lu_decomposition_doolittle.hpp"
#include "lu_decomposition_mozart.hpp"
#include "lu_decomposition_mozart_in_place.hpp"
#include <micm/util/sparse_matrix.hpp>
#include <micm/util/sparse_matrix_vector_ordering.hpp>

namespace micm
{

/// @brief Concept for in-place LU decomposition algorithms
///
/// A type T satisfies this concept when, for the given sparse matrix policy,
/// the following calls are well-formed:
///   - t.GetLUMatrix(matrix, initial_value)  with matrix of type const SparseMatrixPolicy&
///                                           and initial_value of type SparseMatrixPolicy::value_type
///   - t.Decompose(lu_matrix)                with lu_matrix of type SparseMatrixPolicy&
///
/// The arguments are introduced as requires-parameters so that the check neither demands
/// a default-constructible SparseMatrixPolicy nor hard-codes a double initial value.
template<class T, class SparseMatrixPolicy>
concept LuDecompositionInPlaceConcept = requires(
    T t,
    const SparseMatrixPolicy& matrix,
    typename SparseMatrixPolicy::value_type initial_value,
    SparseMatrixPolicy& lu_matrix)
{
  t.GetLUMatrix(matrix, initial_value);
  t.Decompose(lu_matrix);
};
// Compile-time checks that LuDecompositionMozartInPlace satisfies LuDecompositionInPlaceConcept
// for the standard sparse matrix type and for the vector-ordered sparse matrix type named below.
static_assert(LuDecompositionInPlaceConcept<LuDecompositionMozartInPlace, StandardSparseMatrix>, "LuDecompositionMozartInPlace does not meet the LuDecompositionInPlaceConcept requirements");
static_assert(LuDecompositionInPlaceConcept<LuDecompositionMozartInPlace, SparseMatrix<double, SparseMatrixVectorOrderingCompressedSparseRow<1>>>, "LuDecompositionMozartInPlace for vector matrices does not meet the LuDecompositionInPlaceConcept requirements");

/// @brief Alias for the default LU decomposition algorithm
using LuDecomposition = LuDecompositionDoolittle;

/// @brief Alias for the default in-place LU decomposition algorithm
using LuDecompositionInPlace = LuDecompositionMozartInPlace;
} // namespace micm
Loading

0 comments on commit 706a34a

Please sign in to comment.