-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add in-place LU decomposition option for linear solve (#698)
* add in-place MOZART LU decomp * add in-place linear solver * allow in-place linear solve for Backward Euler
- Loading branch information
1 parent
19425fb
commit 706a34a
Showing
16 changed files
with
1,463 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
// Copyright (C) 2023-2024 National Center for Atmospheric Research | ||
// SPDX-License-Identifier: Apache-2.0 | ||
#pragma once | ||
|
||
#include <micm/profiler/instrumentation.hpp> | ||
#include <micm/solver/lu_decomposition.hpp> | ||
#include <micm/util/matrix.hpp> | ||
#include <micm/util/sparse_matrix.hpp> | ||
|
||
#include <cmath> | ||
#include <functional> | ||
|
||
namespace micm | ||
{ | ||
/// @brief A general-use block-diagonal sparse-matrix linear solver | ||
/// | ||
/// The sparsity pattern of each block in the block diagonal matrix is the same. | ||
/// The L and U matrices are decomposed in-place over the original A matrix. | ||
template<class SparseMatrixPolicy, class LuDecompositionPolicy = LuDecompositionInPlace> | ||
class LinearSolverInPlace | ||
{ | ||
protected: | ||
// Parameters needed to calculate L (U x) = b | ||
// | ||
// The calculation is split into calculation of L y = b where y = U x: | ||
// | ||
// y_1 = b_1 / L_11 | ||
// y_i = 1 / L_ii * [ b_i - sum( j = 1...i-1 ){ L_ij * y_j } ] i = 2...N | ||
// | ||
// ... and then U x = y: | ||
// | ||
// x_N = y_N / U_NN | ||
// x_i = 1 / U_ii * [ y_i - sum( j = i+1...N ){ U_ij * x_j } ] i = N-1...1 | ||
|
||
// Number of non-zero elements (excluding the diagonal) for each row in L | ||
std::vector<std::size_t> nLij_; | ||
// Indices of non-zero combinations of L_ij and y_j | ||
std::vector<std::pair<std::size_t, std::size_t>> Lij_yj_; | ||
// Number of non-zero elements (exluding the diagonal) and the index of the diagonal | ||
// element for each row in U (in reverse order) | ||
std::vector<std::pair<std::size_t, std::size_t>> nUij_Uii_; | ||
// Indices of non-zero combinations of U_ij and x_j | ||
std::vector<std::pair<std::size_t, std::size_t>> Uij_xj_; | ||
|
||
LuDecompositionPolicy lu_decomp_; | ||
|
||
public: | ||
/// @brief default constructor | ||
LinearSolverInPlace(){}; | ||
|
||
LinearSolverInPlace(const LinearSolverInPlace&) = delete; | ||
LinearSolverInPlace& operator=(const LinearSolverInPlace&) = delete; | ||
LinearSolverInPlace(LinearSolverInPlace&&) = default; | ||
LinearSolverInPlace& operator=(LinearSolverInPlace&&) = default; | ||
|
||
/// @brief Constructs a linear solver for the sparsity structure of the given matrix | ||
/// @param matrix Sparse matrix | ||
/// @param initial_value Initial value for matrix elements | ||
LinearSolverInPlace(const SparseMatrixPolicy& matrix, typename SparseMatrixPolicy::value_type initial_value); | ||
|
||
/// @brief Constructs a linear solver for the sparsity structure of the given matrix | ||
/// @param matrix Sparse matrix | ||
/// @param initial_value Initial value for matrix elements | ||
/// @param create_lu_decomp Function to create an LU Decomposition object that adheres to LuDecompositionPolicy | ||
LinearSolverInPlace( | ||
const SparseMatrixPolicy& matrix, | ||
typename SparseMatrixPolicy::value_type initial_value, | ||
const std::function<LuDecompositionPolicy(const SparseMatrixPolicy&)> create_lu_decomp); | ||
|
||
virtual ~LinearSolverInPlace() = default; | ||
|
||
/// @brief Decompose the matrix into upper and lower triangular matrices (matrix will be overwritten) | ||
/// @param matrix Matrix to decompose in-place into lower and upper triangular matrices | ||
void Factor(SparseMatrixPolicy& matrix) const; | ||
|
||
/// @brief Solve for x in Ax = b. x should be a copy of b and after Solve finishes x will contain the result | ||
/// @param x The solution vector | ||
/// @param LU The LU decomposition of the matrix as a square sparse matrix | ||
template<class MatrixPolicy> | ||
requires(!VectorizableDense<MatrixPolicy> || !VectorizableSparse<SparseMatrixPolicy>) void Solve( | ||
MatrixPolicy& x, | ||
const SparseMatrixPolicy& lu_matrix) const; | ||
template<class MatrixPolicy> | ||
requires(VectorizableDense<MatrixPolicy>&& VectorizableSparse<SparseMatrixPolicy>) void Solve( | ||
MatrixPolicy& x, | ||
const SparseMatrixPolicy& lu_matrix) const; | ||
}; | ||
|
||
} // namespace micm | ||
|
||
#include "linear_solver_in_place.inl" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
// Copyright (C) 2023-2024 National Center for Atmospheric Research | ||
// SPDX-License-Identifier: Apache-2.0 | ||
namespace micm | ||
{ | ||
template<class SparseMatrixPolicy, class LuDecompositionPolicy> | ||
inline LinearSolverInPlace<SparseMatrixPolicy, LuDecompositionPolicy>::LinearSolverInPlace( | ||
const SparseMatrixPolicy& matrix, | ||
typename SparseMatrixPolicy::value_type initial_value) | ||
: LinearSolverInPlace<SparseMatrixPolicy, LuDecompositionPolicy>( | ||
matrix, | ||
initial_value, | ||
[](const SparseMatrixPolicy& m) -> LuDecompositionPolicy | ||
{ return LuDecompositionPolicy::template Create<SparseMatrixPolicy>(m); }) | ||
{ | ||
} | ||
|
||
template<class SparseMatrixPolicy, class LuDecompositionPolicy> | ||
inline LinearSolverInPlace<SparseMatrixPolicy, LuDecompositionPolicy>::LinearSolverInPlace( | ||
const SparseMatrixPolicy& matrix, | ||
typename SparseMatrixPolicy::value_type initial_value, | ||
const std::function<LuDecompositionPolicy(const SparseMatrixPolicy&)> create_lu_decomp) | ||
: nLij_(), | ||
Lij_yj_(), | ||
nUij_Uii_(), | ||
Uij_xj_(), | ||
lu_decomp_(create_lu_decomp(matrix)) | ||
{ | ||
MICM_PROFILE_FUNCTION(); | ||
|
||
auto lu = lu_decomp_.template GetLUMatrix<SparseMatrixPolicy>(matrix, initial_value); | ||
for (std::size_t i = 0; i < lu.NumRows(); ++i) | ||
{ | ||
std::size_t nLij = 0; | ||
for (std::size_t j = 0; j < i; ++j) | ||
{ | ||
if (lu.IsZero(i, j)) | ||
continue; | ||
Lij_yj_.push_back(std::make_pair(lu.VectorIndex(0, i, j), j)); | ||
++nLij; | ||
} | ||
nLij_.push_back(nLij); | ||
} | ||
for (std::size_t i = lu.NumRows() - 1; i != static_cast<std::size_t>(-1); --i) | ||
{ | ||
std::size_t nUij = 0; | ||
for (std::size_t j = i + 1; j < lu.NumColumns(); ++j) | ||
{ | ||
if (lu.IsZero(i, j)) | ||
continue; | ||
Uij_xj_.push_back(std::make_pair(lu.VectorIndex(0, i, j), j)); | ||
++nUij; | ||
} | ||
// There must always be a non-zero element on the diagonal | ||
nUij_Uii_.push_back(std::make_pair(nUij, lu.VectorIndex(0, i, i))); | ||
} | ||
}; | ||
|
||
template<class SparseMatrixPolicy, class LuDecompositionPolicy> | ||
inline void LinearSolverInPlace<SparseMatrixPolicy, LuDecompositionPolicy>::Factor( | ||
SparseMatrixPolicy& matrix) const | ||
{ | ||
MICM_PROFILE_FUNCTION(); | ||
|
||
lu_decomp_.template Decompose<SparseMatrixPolicy>(matrix); | ||
} | ||
|
||
template<class SparseMatrixPolicy, class LuDecompositionPolicy> | ||
template<class MatrixPolicy> | ||
requires( | ||
!VectorizableDense<MatrixPolicy> || | ||
!VectorizableSparse<SparseMatrixPolicy>) inline void LinearSolverInPlace<SparseMatrixPolicy, LuDecompositionPolicy>:: | ||
Solve(MatrixPolicy& x, const SparseMatrixPolicy& lu_matrix) const | ||
{ | ||
MICM_PROFILE_FUNCTION(); | ||
|
||
for (std::size_t i_cell = 0; i_cell < x.NumRows(); ++i_cell) | ||
{ | ||
auto x_cell = x[i_cell]; | ||
const std::size_t grid_offset = i_cell * lu_matrix.FlatBlockSize(); | ||
auto& y_cell = x_cell; // Alias x for consistency with equations, but to reuse memory | ||
|
||
// Forward Substitution | ||
{ | ||
auto y_elem = y_cell.begin(); | ||
auto Lij_yj = Lij_yj_.begin(); | ||
for (auto& nLij : nLij_) | ||
{ | ||
for (std::size_t i = 0; i < nLij; ++i) | ||
{ | ||
*y_elem -= lu_matrix.AsVector()[grid_offset + (*Lij_yj).first] * y_cell[(*Lij_yj).second]; | ||
++Lij_yj; | ||
} | ||
++y_elem; | ||
} | ||
} | ||
|
||
// Backward Substitution | ||
{ | ||
auto x_elem = std::next(x_cell.end(), -1); | ||
auto Uij_xj = Uij_xj_.begin(); | ||
for (auto& nUij_Uii : nUij_Uii_) | ||
{ | ||
// x_elem starts out as y_elem from the previous loop | ||
for (std::size_t i = 0; i < nUij_Uii.first; ++i) | ||
{ | ||
*x_elem -= lu_matrix.AsVector()[grid_offset + (*Uij_xj).first] * x_cell[(*Uij_xj).second]; | ||
++Uij_xj; | ||
} | ||
|
||
*(x_elem) /= lu_matrix.AsVector()[grid_offset + nUij_Uii.second]; | ||
// don't iterate before the beginning of the vector | ||
if (x_elem != x_cell.begin()) | ||
{ | ||
--x_elem; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
template<class SparseMatrixPolicy, class LuDecompositionPolicy> | ||
template<class MatrixPolicy> | ||
requires(VectorizableDense<MatrixPolicy>&& | ||
VectorizableSparse<SparseMatrixPolicy>) inline void LinearSolverInPlace<SparseMatrixPolicy, LuDecompositionPolicy>:: | ||
Solve(MatrixPolicy& x, const SparseMatrixPolicy& lu_matrix) const | ||
{ | ||
MICM_PROFILE_FUNCTION(); | ||
constexpr std::size_t n_cells = MatrixPolicy::GroupVectorSize(); | ||
// Loop over groups of blocks | ||
for (std::size_t i_group = 0; i_group < x.NumberOfGroups(); ++i_group) | ||
{ | ||
auto x_group = std::next(x.AsVector().begin(), i_group * x.GroupSize()); | ||
auto LU_group = | ||
std::next(lu_matrix.AsVector().begin(), i_group * lu_matrix.GroupSize()); | ||
// Forward Substitution | ||
{ | ||
auto y_elem = x_group; | ||
auto Lij_yj = Lij_yj_.begin(); | ||
for (auto& nLij : nLij_) | ||
{ | ||
for (std::size_t i = 0; i < nLij; ++i) | ||
{ | ||
const std::size_t Lij_yj_first = (*Lij_yj).first; | ||
const std::size_t Lij_yj_second_times_n_cells = (*Lij_yj).second * n_cells; | ||
for (std::size_t i_cell = 0; i_cell < n_cells; ++i_cell) | ||
y_elem[i_cell] -= LU_group[Lij_yj_first + i_cell] * x_group[Lij_yj_second_times_n_cells + i_cell]; | ||
++Lij_yj; | ||
} | ||
y_elem += n_cells; | ||
} | ||
} | ||
|
||
// Backward Substitution | ||
{ | ||
auto x_elem = std::next(x_group, x.GroupSize() - n_cells); | ||
auto Uij_xj = Uij_xj_.begin(); | ||
for (auto& nUij_Uii : nUij_Uii_) | ||
{ | ||
// x_elem starts out as y_elem from the previous loop | ||
for (std::size_t i = 0; i < nUij_Uii.first; ++i) | ||
{ | ||
const std::size_t Uij_xj_first = (*Uij_xj).first; | ||
const std::size_t Uij_xj_second_times_n_cells = (*Uij_xj).second * n_cells; | ||
for (std::size_t i_cell = 0; i_cell < n_cells; ++i_cell) | ||
x_elem[i_cell] -= LU_group[Uij_xj_first + i_cell] * x_group[Uij_xj_second_times_n_cells + i_cell]; | ||
++Uij_xj; | ||
} | ||
const std::size_t nUij_Uii_second = nUij_Uii.second; | ||
for (std::size_t i_cell = 0; i_cell < n_cells; ++i_cell) | ||
x_elem[i_cell] /= LU_group[nUij_Uii_second + i_cell]; | ||
|
||
// don't iterate before the beginning of the vector | ||
const std::size_t x_elem_distance = std::distance(x.AsVector().begin(), x_elem); | ||
x_elem -= std::min(n_cells, x_elem_distance); | ||
} | ||
} | ||
} | ||
} | ||
} // namespace micm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.