Skip to content

Commit

Permalink
add custom callback
Browse files Browse the repository at this point in the history
  • Loading branch information
Francesco Rizzi committed Jun 17, 2024
1 parent 7f4c931 commit 89710f1
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 35 deletions.
165 changes: 131 additions & 34 deletions include/pressio/rom/impl/lspg_unsteady_reconstructor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,22 +118,27 @@ void write_vector_to_binary(const std::string filename,
}

#ifdef PRESSIO_ENABLE_TPL_TRILINOS
template<class Rt, class Jt>
void write_res_jac_to_binary(const std::string & prepend, std::size_t i, const Rt & R, const Jt & JPhi)
struct DefaultWriter
{
const auto myRank = R.getMap()->getComm()->getRank();
const std::string finalPart = "rank_" + std::to_string(myRank) + "_step_" + std::to_string(i) + ".bin";

auto r_view = R.getLocalViewHost(Tpetra::Access::ReadOnly);
auto r_stdv = _rank1_view_to_stdvector(r_view);
const std::string R_f = prepend + "residual_" + finalPart;
write_vector_to_binary(R_f, r_stdv.data(), r_stdv.size());

auto jphi_view = JPhi.getLocalViewHost(Tpetra::Access::ReadOnly);
auto jphi_stdv = _rank2_view_to_stdvector(jphi_view);
std::string Jphi_f = prepend + "jacobian_action_" + finalPart;
write_vector_to_binary(Jphi_f, jphi_stdv.data(), jphi_stdv.size());
}
std::string prependToFileOut_;

template<class Rt, class Jt>
void operator()(std::size_t i, const Rt & R, const Jt & JPhi) const
{
const auto myRank = R.getMap()->getComm()->getRank();
const std::string finalPart = "rank_" + std::to_string(myRank) + "_step_" + std::to_string(i) + ".bin";

auto r_view = R.getLocalViewHost(Tpetra::Access::ReadOnly);
auto r_stdv = _rank1_view_to_stdvector(r_view);
const std::string R_f = prependToFileOut_ + "residual_" + finalPart;
write_vector_to_binary(R_f, r_stdv.data(), r_stdv.size());

auto jphi_view = JPhi.getLocalViewHost(Tpetra::Access::ReadOnly);
auto jphi_stdv = _rank2_view_to_stdvector(jphi_view);
std::string Jphi_f = prependToFileOut_ + "jacobian_action_" + finalPart;
write_vector_to_binary(Jphi_f, jphi_stdv.data(), jphi_stdv.size());
}
};

template<class MapType>
void write_map_to_file(MapType const & map)
Expand All @@ -143,6 +148,26 @@ void write_map_to_file(MapType const & map)
map.describe(*out, Teuchos::EVerbosityLevel::VERB_EXTREME);
outMapFile.close();
}

template<class T, typename ResidualType, typename TrialSubspaceType, typename enable = void>
struct is_writer : std::false_type {};

template<typename T, typename ResidualType, typename TrialSubspaceType>
struct is_writer<
T, ResidualType, TrialSubspaceType,
std::enable_if_t<
std::is_void<
decltype(
std::declval<T const>()(
std::size_t{},
std::declval<ResidualType const>(),
std::declval<typename TrialSubspaceType::basis_matrix_type const>()
)
)
>::value
>
> : std::true_type{};

#endif

template <class TrialSubspaceType>
Expand All @@ -158,6 +183,9 @@ class LspgReconstructor{
: trialSubspace_(trialSubspace){}

#ifdef PRESSIO_ENABLE_TPL_TRILINOS
//
// constrained for fully discrete system
//
template <
std::size_t n,
class FomSystemType,
Expand All @@ -167,13 +195,89 @@ class LspgReconstructor{
RealValuedFullyDiscreteSystemWithJacobianAction<
FomSystemType, n, typename _TrialSubspaceType::basis_matrix_type>::value
>
execute(const FomSystemType & fomSystem,
const std::string & filename,
execute(
const FomSystemType & fomSystem,
const std::string & romStateFilename,
std::optional<std::string> filenamePrepend = {}) const
{
static_assert(n==2,
"lspg reconstructor for a fully discrete system currently supports TotalNumberOfDesiredStates==2");

DefaultWriter writer{ filenamePrepend.value_or("") };
execute_for_fully_discrete_time_impl(fomSystem, romStateFilename, writer);
}

template <
std::size_t n,
class FomSystemType,
class CustomWriterType,
class _TrialSubspaceType = TrialSubspaceType
>
std::enable_if_t<
RealValuedFullyDiscreteSystemWithJacobianAction<
FomSystemType, n, typename _TrialSubspaceType::basis_matrix_type>::value
&& is_writer<CustomWriterType, typename FomSystemType::discrete_residual_type, _TrialSubspaceType>::value
>
execute(
const FomSystemType & fomSystem,
const std::string & romStateFilename,
const CustomWriterType & writer) const
{
static_assert(n==2,
"lspg reconstructor for a fully discrete system currently supports TotalNumberOfDesiredStates==2");

execute_for_fully_discrete_time_impl(fomSystem, romStateFilename, writer);
}

//
// constrained for semi-discrete system
//
template <
class FomSystemType,
class _TrialSubspaceType = TrialSubspaceType
>
std::enable_if_t<
RealValuedSemiDiscreteFomWithJacobianAction<
FomSystemType, typename _TrialSubspaceType::basis_matrix_type
>::value
>
execute(
const FomSystemType & fomSystem,
std::string const & romStateFilename,
::pressio::ode::StepScheme schemeName,
std::optional<std::string> filenamePrepend = {}) const
{
DefaultWriter writer{ filenamePrepend.value_or("") };
execute_for_semi_discrete_time_impl(fomSystem, romStateFilename, schemeName, writer);
}

template <
class FomSystemType,
class CustomWriterType,
class _TrialSubspaceType = TrialSubspaceType
>
std::enable_if_t<
RealValuedSemiDiscreteFomWithJacobianAction<
FomSystemType, typename _TrialSubspaceType::basis_matrix_type>::value
&& is_writer<CustomWriterType, typename FomSystemType::rhs_type, _TrialSubspaceType>::value
>
execute(
const FomSystemType & fomSystem,
std::string const & romStateFilename,
::pressio::ode::StepScheme schemeName,
const CustomWriterType & writer) const
{
execute_for_semi_discrete_time_impl(fomSystem, romStateFilename, schemeName, writer);
}


private:
template <class FomSystemType, class WriterType>
void execute_for_fully_discrete_time_impl(
const FomSystemType & fomSystem,
const std::string & romStateFilename,
const WriterType & writer) const
{
const auto & trialSub = trialSubspace_.get();
const auto & phi = trialSub.basisOfTranslatedSpace();

Expand All @@ -189,7 +293,7 @@ class LspgReconstructor{

// 3. read states
const std::size_t numModes = trialSubspace_.get().dimension();
auto [times, reducedStates] = read_rom_states_and_times_from_ascii<rom_state_type>(filename, numModes);
auto [times, reducedStates] = read_rom_states_and_times_from_ascii<rom_state_type>(romStateFilename, numModes);

trialSub.mapFromReducedState(reducedStates[0], state_n);
for (std::size_t i = 1; i < times.size(); i++){
Expand All @@ -200,24 +304,17 @@ class LspgReconstructor{
trialSub.mapFromReducedState(reducedStates[i], state_np1);
fomSystem.discreteTimeResidualAndJacobianAction(i, t_np1, dt, R,
phi, &JTimesPhi, state_np1, state_n);
write_res_jac_to_binary(filenamePrepend.value_or(""), i, R, JTimesPhi);
writer(i, R, JTimesPhi);
pressio::ops::deep_copy(state_n, state_np1);
}
}

template <
class FomSystemType,
class _TrialSubspaceType = TrialSubspaceType
>
std::enable_if_t<
RealValuedSemiDiscreteFomWithJacobianAction<
FomSystemType, typename _TrialSubspaceType::basis_matrix_type
>::value
>
execute(const FomSystemType & fomSystem,
std::string const & filename,
::pressio::ode::StepScheme schemeName,
std::optional<std::string> filenamePrepend = {}) const
template <class FomSystemType, class WriterType>
void execute_for_semi_discrete_time_impl(
const FomSystemType & fomSystem,
const std::string & romStateFilename,
::pressio::ode::StepScheme schemeName,
const WriterType & writer) const
{
assert(schemeName == pressio::ode::StepScheme::BDF1);

Expand All @@ -238,7 +335,7 @@ class LspgReconstructor{

// 3. read states
const std::size_t numModes = trialSubspace_.get().dimension();
auto [times, reducedStates] = read_rom_states_and_times_from_ascii<rom_state_type>(filename, numModes);
auto [times, reducedStates] = read_rom_states_and_times_from_ascii<rom_state_type>(romStateFilename, numModes);

auto & state_n = fomStencilStates(::pressio::ode::n());
const auto one = ::pressio::utils::Constants<scalar_type>::one();
Expand All @@ -256,7 +353,7 @@ class LspgReconstructor{
const auto factor = dt*::pressio::ode::constants::bdf1<scalar_type>::c_f_;
::pressio::ops::update(JTimesPhi, factor, phi, one);

write_res_jac_to_binary(filenamePrepend.value_or(""), i, R, JTimesPhi);
writer(i, R, JTimesPhi);
pressio::ops::deep_copy(state_n, state_np1);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ class MyFom
for (int i=0; i<5; ++i){
ASSERT_DOUBLE_EQ( R_view(i,0), goldVal );
}

}
};

Expand All @@ -96,3 +95,38 @@ TEST_F(fixture_t, lspg_residual_jacaction_reconstructor)
auto o = lspg::create_reconstructor(space);
o.execute<2>(app, romStatesStr, "main1_");
}


struct MyCustomWriter{
void operator()(
std::size_t i,
const typename fixture_t::vec_t & R,
const typename fixture_t::mvec_t & JPhi
) const
{
ASSERT_TRUE(i >= 1 && i<6);
}
};

TEST_F(fixture_t, lspg_residual_jacaction_reconstructor_custom_writer)
{
using namespace pressio::rom;

auto phi = ::pressio::ops::clone(*myMv_);
for (int i=0; i<numVecs_; ++i){
auto col = pressio::column(phi, i);
pressio::ops::fill(col, i);
}

vec_t shift(contigMap_);
pressio::ops::fill(shift, 0);
using reduced_state_type = Eigen::VectorXd;
auto space = create_trial_column_subspace<reduced_state_type>(std::move(phi), shift, false);

const std::string romStatesStr = "./lspg_residual_jacaction_reconstructor/rom_states.txt";
MyFom app(*this);
auto o = lspg::create_reconstructor(space);
o.execute<2>(app, romStatesStr, MyCustomWriter{});
}


Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,35 @@ TEST_F(fixture_t, lspg_residual_jacaction_reconstructor_bdf1)
auto o = rom::lspg::create_reconstructor(space);
o.execute(app, romStatesStr, ode::StepScheme::BDF1, "main2_");
}

struct MyCustomWriter{
void operator()(
std::size_t i,
const typename fixture_t::vec_t & R,
const typename fixture_t::mvec_t & JPhi
) const
{
ASSERT_TRUE(i >= 1 && i<6);
}
};

TEST_F(fixture_t, lspg_residual_jacaction_reconstructor_bdf1_custom_writer)
{
using namespace pressio;

auto phi = ops::clone(*myMv_);
for (int i=0; i<numVecs_; ++i){
auto col = pressio::column(phi, i);
ops::fill(col, i);
}

vec_t shift(contigMap_);
ops::fill(shift, 0);
using reduced_state_type = Eigen::VectorXd;
auto space = rom::create_trial_column_subspace<reduced_state_type>(std::move(phi), shift, false);

const std::string romStatesStr = "./lspg_residual_jacaction_reconstructor/rom_states.txt";
MyFom app(*this);
auto o = rom::lspg::create_reconstructor(space);
o.execute(app, romStatesStr, ode::StepScheme::BDF1, MyCustomWriter{});
}

0 comments on commit 89710f1

Please sign in to comment.