From 870f8cf5e3b56c0e6c293415cd14431e60ef5a50 Mon Sep 17 00:00:00 2001 From: James Yang Date: Fri, 7 Jun 2024 09:24:56 -0700 Subject: [PATCH] Generalize tqdm pb to not depend on std::cerr automatically since R bindings forbids dealing with std::cerr. --- .../adelie_core/solver/solver_base.hpp | 9 ++++--- .../solver/solver_gaussian_cov.hpp | 7 +++--- .../solver/solver_gaussian_naive.hpp | 12 +++++---- .../adelie_core/solver/solver_glm_naive.hpp | 12 +++++---- .../solver/solver_multigaussian_naive.hpp | 5 ++-- .../solver/solver_multiglm_naive.hpp | 5 ++-- adelie/src/include/adelie_core/util/tqdm.hpp | 11 +++++++- adelie/src/solver.cpp | 25 +++++++++++++++---- 8 files changed, 60 insertions(+), 26 deletions(-) diff --git a/adelie/src/include/adelie_core/solver/solver_base.hpp b/adelie/src/include/adelie_core/solver/solver_base.hpp index 4fa439bb..941dafbe 100644 --- a/adelie/src/include/adelie_core/solver/solver_base.hpp +++ b/adelie/src/include/adelie_core/solver/solver_base.hpp @@ -222,6 +222,7 @@ bool kkt( } template inline void solve_core( StateType&& state, - bool display, + PBType&& pb, PBAddSuffixType pb_add_suffix_f, UpdateLossNullType update_loss_null_f, UpdateInvarianceType update_invariance_f, @@ -320,8 +321,10 @@ inline void solve_core( // All solutions to lambda > lambda_max are saved. // initialize progress bar - auto pb = util::tq::trange(lmda_path.size()); - pb.set_display(display); + pb.set_range( + util::tq::int_iterator(0), + util::tq::int_iterator(lmda_path.size()) + ); auto pb_it = pb.begin(); // slice lambda_path up to lmda_max diff --git a/adelie/src/include/adelie_core/solver/solver_gaussian_cov.hpp b/adelie/src/include/adelie_core/solver/solver_gaussian_cov.hpp index c40ed54f..d4b6cd62 100644 --- a/adelie/src/include/adelie_core/solver/solver_gaussian_cov.hpp +++ b/adelie/src/include/adelie_core/solver/solver_gaussian_cov.hpp @@ -226,12 +226,13 @@ auto fit( } template inline void solve( StateType&& state, - bool display, + PBType&& pb, ExitCondType exit_cond_f, UpdateCoefficientsType update_coefficients_f, CUIType check_user_interrupt = CUIType() @@ -247,7 +248,7 @@ inline void solve( GaussianCovBufferPack buffer_pack(p); const auto pb_add_suffix_f = [&](const auto& state, auto& pb) { - if (display) cov::pb_add_suffix(state, pb); + cov::pb_add_suffix(state, pb); }; const auto update_loss_null_f = [](const auto&) {}; const auto update_invariance_f = [&]( @@ -314,7 +315,7 @@ inline void solve( solver::solve_core( state, - display, + pb, pb_add_suffix_f, update_loss_null_f, update_invariance_f, diff --git a/adelie/src/include/adelie_core/solver/solver_gaussian_naive.hpp b/adelie/src/include/adelie_core/solver/solver_gaussian_naive.hpp index 2d54bc4a..cfd26845 100644 --- a/adelie/src/include/adelie_core/solver/solver_gaussian_naive.hpp +++ b/adelie/src/include/adelie_core/solver/solver_gaussian_naive.hpp @@ -195,13 +195,14 @@ auto fit( } template inline void solve( StateType&& state, - bool display, + PBType&& pb, ExitCondType exit_cond_f, UpdateCoefficientsType update_coefficients_f, TidyType tidy_f, @@ -216,7 +217,7 @@ inline void solve( GaussianNaiveBufferPack buffer_pack(n); const auto pb_add_suffix_f = [&](const auto& state, auto& pb) { - if (display) solver::pb_add_suffix(state, pb); + solver::pb_add_suffix(state, pb); }; const auto update_loss_null_f = [](const auto&) {}; const auto update_invariance_f = [&](auto& state, const auto&, auto lmda) { @@ -268,7 +269,7 @@ inline void solve( solver::solve_core( state, - display, + pb, pb_add_suffix_f, update_loss_null_f, update_invariance_f, @@ -280,12 +281,13 @@ inline void solve( } template inline void solve( StateType&& state, - bool display, + PBType&& pb, ExitCondType exit_cond_f, UpdateCoefficientsType update_coefficients_f, CUIType check_user_interrupt = CUIType() @@ -293,7 +295,7 @@ inline void solve( { solve( state, - display, + pb, exit_cond_f, update_coefficients_f, [](){}, diff --git a/adelie/src/include/adelie_core/solver/solver_glm_naive.hpp b/adelie/src/include/adelie_core/solver/solver_glm_naive.hpp index 2b2f9283..e48a6acb 100644 --- a/adelie/src/include/adelie_core/solver/solver_glm_naive.hpp +++ b/adelie/src/include/adelie_core/solver/solver_glm_naive.hpp @@ -394,6 +394,7 @@ auto fit( template buffer_pack(n, p); const auto pb_add_suffix_f = [&](const auto& state, auto& pb) { - if (display) solver::pb_add_suffix(state, pb); + solver::pb_add_suffix(state, pb); }; const auto update_loss_null_wrap_f = [&](auto& state) { const auto setup_loss_null = state.setup_loss_null; @@ -468,7 +469,7 @@ inline void solve( solver::solve_core( state, - display, + pb, pb_add_suffix_f, update_loss_null_wrap_f, update_invariance_f, @@ -481,13 +482,14 @@ inline void solve( template inline void solve( StateType&& state, GlmType&& glm, - bool display, + PBType&& pb, ExitCondType exit_cond_f, UpdateCoefficientsType update_coefficients_f, CUIType check_user_interrupt = CUIType() @@ -496,7 +498,7 @@ inline void solve( solve( std::forward(state), std::forward(glm), - display, + std::forward(pb), exit_cond_f, [](auto& state, auto& glm, auto& buffer_pack) { update_loss_null(state, glm, buffer_pack); diff --git a/adelie/src/include/adelie_core/solver/solver_multigaussian_naive.hpp b/adelie/src/include/adelie_core/solver/solver_multigaussian_naive.hpp index a880fcb7..52f74396 100644 --- a/adelie/src/include/adelie_core/solver/solver_multigaussian_naive.hpp +++ b/adelie/src/include/adelie_core/solver/solver_multigaussian_naive.hpp @@ -11,12 +11,13 @@ namespace multigaussian { namespace naive { template inline void solve( StateType&& state, - bool display, + PBType&& pb, ExitCondType exit_cond_f, UpdateCoefficientsType update_coefficients_f, CUIType check_user_interrupt = CUIType() @@ -47,7 +48,7 @@ inline void solve( gaussian::naive::solve( static_cast(state), - display, + pb, exit_cond_f, update_coefficients_f, tidy, diff --git a/adelie/src/include/adelie_core/solver/solver_multiglm_naive.hpp b/adelie/src/include/adelie_core/solver/solver_multiglm_naive.hpp index c924b8cf..4c5a8acc 100644 --- a/adelie/src/include/adelie_core/solver/solver_multiglm_naive.hpp +++ b/adelie/src/include/adelie_core/solver/solver_multiglm_naive.hpp @@ -194,13 +194,14 @@ void update_loss_null( template inline void solve( StateType&& state, GlmType&& glm, - bool display, + PBType&& pb, ExitCondType exit_cond_f, UpdateCoefficientsType update_coefficients_f, CUIType check_user_interrupt = CUIType() @@ -235,7 +236,7 @@ inline void solve( glm::naive::solve( static_cast(state), glm_wrap, - display, + pb, exit_cond_f, [&](auto&, auto& glm, auto& buffer_pack) { // ignore casted down state and use derived state diff --git a/adelie/src/include/adelie_core/util/tqdm.hpp b/adelie/src/include/adelie_core/util/tqdm.hpp index 51183bbe..2aa73f2a 100644 --- a/adelie/src/include/adelie_core/util/tqdm.hpp +++ b/adelie/src/include/adelie_core/util/tqdm.hpp @@ -218,7 +218,7 @@ class progress_bar double min_time_per_update_{1e-1}; // found experimentally bool display_{true}; - std::ostream* os_{&std::cerr}; + std::ostream* os_ = nullptr; index bar_size_{10}; index max_chars_{0}; @@ -329,6 +329,14 @@ class tqdm_for_lvalues ++iters_done_; } + // NOTE: possibly invalidates other members. + // Safe usage is to call .begin(), .end() afterwards to reset everything. + void set_range(ForwardIter begin, EndIter end) + { + first_ = iterator(begin, this); + last_ = end; + num_iters_ = std::distance(begin, end); + } void set_ostream(std::ostream& os) { bar_.set_ostream(os); } void set_prefix(std::string s) { bar_.set_prefix(std::move(s)); } void set_bar_size(int size) { bar_.set_bar_size(size); } @@ -384,6 +392,7 @@ class tqdm_for_rvalues void update() { return tqdm_.update(); } + void set_range(iterator begin, iterator end) { tqdm_.set_range(begin, end); } void set_ostream(std::ostream& os) { tqdm_.set_ostream(os); } void set_prefix(std::string s) { tqdm_.set_prefix(std::move(s)); } void set_bar_size(int size) { tqdm_.set_bar_size(size); } diff --git a/adelie/src/solver.cpp b/adelie/src/solver.cpp index 1ab91021..439b7e90 100644 --- a/adelie/src/solver.cpp +++ b/adelie/src/solver.cpp @@ -200,8 +200,11 @@ py::dict solve_gaussian_cov( const auto exit_cond_f = [&]() { return exit_cond && exit_cond(state); }; + auto pb = ad::util::tq::trange(0); + pb.set_display(display_progress_bar); + pb.set_ostream(std::cerr); ad::solver::gaussian::cov::solve( - state, display_progress_bar, exit_cond_f, u, c + state, pb, exit_cond_f, u, c ); } ); @@ -220,8 +223,11 @@ py::dict solve_gaussian_naive( const auto exit_cond_f = [&]() { return exit_cond && exit_cond(state); }; + auto pb = ad::util::tq::trange(0); + pb.set_display(display_progress_bar); + pb.set_ostream(std::cerr); return ad::solver::gaussian::naive::solve( - state, display_progress_bar, exit_cond_f, u, c + state, pb, exit_cond_f, u, c ); } ); @@ -240,8 +246,11 @@ py::dict solve_multigaussian_naive( const auto exit_cond_f = [&]() { return exit_cond && exit_cond(state); }; + auto pb = ad::util::tq::trange(0); + pb.set_display(display_progress_bar); + pb.set_ostream(std::cerr); ad::solver::multigaussian::naive::solve( - state, display_progress_bar, exit_cond_f, u, c + state, pb, exit_cond_f, u, c ); } ); @@ -265,8 +274,11 @@ py::dict solve_glm_naive( const auto exit_cond_f = [&]() { return exit_cond && exit_cond(state); }; + auto pb = ad::util::tq::trange(0); + pb.set_display(display_progress_bar); + pb.set_ostream(std::cerr); ad::solver::glm::naive::solve( - state, glm, display_progress_bar, exit_cond_f, u, c + state, glm, pb, exit_cond_f, u, c ); } ); @@ -286,8 +298,11 @@ py::dict solve_multiglm_naive( const auto exit_cond_f = [&]() { return exit_cond && exit_cond(state); }; + auto pb = ad::util::tq::trange(0); + pb.set_display(display_progress_bar); + pb.set_ostream(std::cerr); ad::solver::multiglm::naive::solve( - state, glm, display_progress_bar, exit_cond_f, u, c + state, glm, pb, exit_cond_f, u, c ); } );