Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gmres for linear solver and matrix-free newton krylov solver #704

Merged
merged 2 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
/*
//@HEADER
// ************************************************************************
//
// solvers_linear_eigen_iterative_matrix_free_impl.hpp
// Pressio
// Copyright 2019
// National Technology & Engineering Solutions of Sandia, LLC (NTESS)
//
// Under the terms of Contract DE-NA0003525 with NTESS, the
// U.S. Government retains certain rights in this software.
//
// Pressio is licensed under BSD-3-Clause terms of use:
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
// COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
// IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact Francesco Rizzi ([email protected])
//
// ************************************************************************
//@HEADER
*/

#ifndef PRESSIO_SOLVERS_LINEAR_IMPL_SOLVERS_LINEAR_EIGEN_ITERATIVE_MATRIX_FREE_IMPL_HPP_
#define PRESSIO_SOLVERS_LINEAR_IMPL_SOLVERS_LINEAR_EIGEN_ITERATIVE_MATRIX_FREE_IMPL_HPP_

#include "solvers_linear_iterative_base.hpp"
#include <Eigen/Core>
#include <Eigen/Dense>

namespace pressio { namespace linearsolvers{

template<typename UserDefinedLinearOperatorType>
class OperatorWrapper;

}}

namespace Eigen {
namespace internal {
template<typename UserDefinedLinearOperatorType>
struct traits< pressio::linearsolvers::OperatorWrapper<UserDefinedLinearOperatorType> >
: public Eigen::internal::traits<
Eigen::Matrix<typename UserDefinedLinearOperatorType::scalar_type,-1,-1>
>
{};
}
}

namespace pressio { namespace linearsolvers{

template<typename UserDefinedLinearOperatorType>
class OperatorWrapper :
public Eigen::EigenBase<
OperatorWrapper<UserDefinedLinearOperatorType>
>
{
public:
using Scalar = typename UserDefinedLinearOperatorType::scalar_type;
using RealScalar = Scalar;
using StorageIndex = int;
enum {
ColsAtCompileTime = Eigen::Dynamic,
MaxColsAtCompileTime = Eigen::Dynamic,
};

OperatorWrapper() = default;

OperatorWrapper(UserDefinedLinearOperatorType const & valueIn)
: m_userOperator(&valueIn) {}


int rows() const { return m_userOperator->rows(); }
int cols() const { return m_userOperator->cols(); }

template<typename Rhs>
Eigen::Product<OperatorWrapper<UserDefinedLinearOperatorType>, Rhs, Eigen::AliasFreeProduct>
operator*(const Eigen::MatrixBase<Rhs>& x) const{
using r_t = Eigen::Product<
OperatorWrapper<UserDefinedLinearOperatorType>, Rhs, Eigen::AliasFreeProduct
>;
return r_t(*this, x.derived());
}

void replace(const UserDefinedLinearOperatorType & opIn) {
m_userOperator = &opIn;
}

template<class OperandT, class ResultT>
void applyAndAddTo(OperandT const & operand, ResultT & out) const {
// compute: out += operator * operand
m_userOperator->applyAndAddTo(operand, out);
}

private:
UserDefinedLinearOperatorType const *m_userOperator = nullptr;
};

}} // end namespace pressio::linearsolvers

namespace Eigen {
namespace internal {

template<typename Rhs, typename UserDefinedOpT>
struct generic_product_impl<
pressio::linearsolvers::OperatorWrapper<UserDefinedOpT>,
Rhs, DenseShape, DenseShape, GemvProduct
>
: generic_product_impl_base<
pressio::linearsolvers::OperatorWrapper<UserDefinedOpT>, Rhs,
generic_product_impl<pressio::linearsolvers::OperatorWrapper<UserDefinedOpT>, Rhs>
>
{
using Scalar = typename Product<
pressio::linearsolvers::OperatorWrapper<UserDefinedOpT>,
Rhs
>::Scalar;

template<typename Dest>
static void scaleAndAddTo(
Dest& dst,
const pressio::linearsolvers::OperatorWrapper<UserDefinedOpT> & lhs,
const Rhs& rhs,
const Scalar& alpha)
{
// This method should implement "dst += alpha * lhs * rhs" inplace,
// however, for iterative solvers, alpha is always equal to 1,
// so let's not bother about it.
assert(alpha==Scalar(1) && "scaling is not implemented");
EIGEN_ONLY_USED_FOR_DEBUG(alpha);

lhs.applyAndAddTo(rhs, dst);
}
};
}
}

namespace pressio { namespace linearsolvers{ namespace impl{

template<typename TagType, typename UserDefinedLinearOperatorType>
class EigenIterativeMatrixFree
: public IterativeBase<
EigenIterativeMatrixFree<TagType, UserDefinedLinearOperatorType>
>
{

public:
using this_type = EigenIterative<TagType, UserDefinedLinearOperatorType>;
using scalar_type = typename UserDefinedLinearOperatorType::scalar_type;
using solver_traits = ::pressio::linearsolvers::Traits<TagType>;
using op_wrapper_t = OperatorWrapper<UserDefinedLinearOperatorType>;
using native_solver_type = typename solver_traits::template eigen_solver_type<op_wrapper_t>;
using base_iterative_type = IterativeBase<this_type>;
using iteration_type = typename base_iterative_type::iteration_type;

static_assert(solver_traits::eigen_enabled == true,
"the native solver must be from Eigen to use in EigenIterativeMatrixFree");
static_assert(solver_traits::direct == false,
"The native eigen solver must be iterative to use in EigenIterativeMatrixFree");

public:
EigenIterativeMatrixFree() = default;

iteration_type numIterationsExecuted() const{
return mysolver_.iterations();
}

scalar_type finalError() const{
return mysolver_.error();
}

void resetLinearSystem(const UserDefinedLinearOperatorType& A)
{
mysolver_.setMaxIterations(this->maxIters_);
m_wrapper.replace(A);
mysolver_.compute(m_wrapper);
}

template <typename T>
void solve(const T& b, T & y){
mysolver_.setMaxIterations(this->maxIters_);
y = mysolver_.solve(b);
}

template <typename T>
void solve(const UserDefinedLinearOperatorType & A, const T& b, T & y){
this->resetLinearSystem(A);
this->solve(b, y);
}

private:
friend base_iterative_type;
native_solver_type mysolver_ = {};
op_wrapper_t m_wrapper;
};

}}} // end namespace pressio::solvers::iterarive::impl
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
#ifdef PRESSIO_ENABLE_TPL_EIGEN
#include "solvers_linear_eigen_direct_impl.hpp"
#include "solvers_linear_eigen_iterative_impl.hpp"
#include "solvers_linear_eigen_iterative_matrix_free_impl.hpp"
#endif
#ifdef PRESSIO_ENABLE_TPL_KOKKOS
#include "solvers_linear_kokkos_direct_geqrf_impl.hpp"
Expand All @@ -69,6 +70,16 @@ struct Selector{
};

#ifdef PRESSIO_ENABLE_TPL_EIGEN
template<typename UserDefinedLinearOperatorType>
struct Selector<
iterative::GMRES, UserDefinedLinearOperatorType, void
>
{
using tag_t = iterative::GMRES;
using solver_traits = ::pressio::linearsolvers::Traits<tag_t>;
using type = EigenIterativeMatrixFree<tag_t, UserDefinedLinearOperatorType>;
};

template<typename TagType, typename MatrixType>
struct Selector<
TagType, MatrixType,
Expand Down
18 changes: 18 additions & 0 deletions include/pressio/solvers_linear/impl/solvers_linear_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
#include <Eigen/Sparse>
#include <Eigen/SparseQR>
#include <Eigen/OrderingMethods>
#include <unsupported/Eigen/IterativeSolvers>
#endif

namespace pressio{ namespace linearsolvers{
Expand All @@ -73,6 +74,23 @@ struct Traits {
#endif
};

template <>
struct Traits<::pressio::linearsolvers::iterative::GMRES>
{
static constexpr bool direct = false;
static constexpr bool iterative = true;

#ifdef PRESSIO_ENABLE_TPL_EIGEN
template <
typename MatrixOrOperatorT,
typename PrecT = Eigen::IdentityPreconditioner
>
using eigen_solver_type = Eigen::GMRES<MatrixOrOperatorT, PrecT>;

static constexpr bool eigen_enabled = true;
#endif
};

template <>
struct Traits<::pressio::linearsolvers::iterative::CG>
{
Expand Down
1 change: 1 addition & 0 deletions include/pressio/solvers_linear/solvers_linear_tags.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ namespace iterative{
struct CG {};
struct LSCG {};
struct Bicgstab {};
struct GMRES{};
}

namespace direct{
Expand Down
17 changes: 17 additions & 0 deletions include/pressio/solvers_nonlinear/impl/functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,23 @@ void compute_residual(RegistryType & reg,
system.residual(state, r);
}

#ifdef PRESSIO_ENABLE_CXX20
template<class RegistryType, class SystemType>
requires NonlinearSystem<SystemType>
#else
template<
class RegistryType, class SystemType,
std::enable_if_t< NonlinearSystem<SystemType>::value, int> = 0
>
#endif
void compute_residual(RegistryType & reg,
const SystemType & system)
{
const auto & state = reg.template get<StateTag>();
auto & r = reg.template get<ResidualTag>();
system.residual(state, r);
}

template<class RegistryType, class SystemType>
void compute_residual_and_jacobian(RegistryType & reg,
const SystemType & system)
Expand Down
1 change: 1 addition & 0 deletions include/pressio/solvers_nonlinear/impl/internal_tags.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ struct QTransposeResidualTag{};


struct NewtonTag{};
struct MatrixFreeNewtonTag{};
struct GaussNewtonNormalEqTag{};
struct WeightedGaussNewtonNormalEqTag{};
struct LevenbergMarquardtNormalEqTag{};
Expand Down
38 changes: 38 additions & 0 deletions include/pressio/solvers_nonlinear/impl/registries.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,44 @@ class RegistryNewton
GETMETHOD(6)
};

template<class SystemType, class LinearSolverTag>
class RegistryMatrixFreeNewtonKrylov
{
using state_t = typename SystemType::state_type;
using r_t = typename SystemType::residual_type;

using Tag1 = nonlinearsolvers::CorrectionTag;
using Tag2 = nonlinearsolvers::InitialGuessTag;
using Tag3 = nonlinearsolvers::ResidualTag;
using Tag4 = nonlinearsolvers::impl::SystemTag;

state_t d1_;
state_t d2_;
r_t d3_;
SystemType const * d4_;

public:
using linear_solver_tag = LinearSolverTag;

RegistryMatrixFreeNewtonKrylov(const SystemType & system)
: d1_(system.createState()),
d2_(system.createState()),
d3_(system.createResidual()),
d4_(&system){}

template<class TagToFind>
static constexpr bool contains(){
return (mpl::variadic::find_if_binary_pred_t<TagToFind, std::is_same,
Tag1, Tag2, Tag3, Tag4>::value) < 4;
}

GETMETHOD(1)
GETMETHOD(2)
GETMETHOD(3)
GETMETHOD(4)
};


template<class SystemType, class InnSolverType>
class RegistryGaussNewtonNormalEqs
{
Expand Down
Loading