Skip to content

Commit

Permalink
Generalize tqdm pb to not depend on std::cerr automatically since R b…
Browse files Browse the repository at this point in the history
…indings forbids dealing with std::cerr.
  • Loading branch information
JamesYang007 committed Jun 7, 2024
1 parent c7b7677 commit 870f8cf
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 26 deletions.
9 changes: 6 additions & 3 deletions adelie/src/include/adelie_core/solver/solver_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ bool kkt(
}

template <class StateType,
class PBType,
class PBAddSuffixType,
class UpdateLossNullType,
class UpdateInvarianceType,
Expand All @@ -231,7 +232,7 @@ template <class StateType,
class FitType>
inline void solve_core(
StateType&& state,
bool display,
PBType&& pb,
PBAddSuffixType pb_add_suffix_f,
UpdateLossNullType update_loss_null_f,
UpdateInvarianceType update_invariance_f,
Expand Down Expand Up @@ -320,8 +321,10 @@ inline void solve_core(
// All solutions to lambda > lambda_max are saved.

// initialize progress bar
auto pb = util::tq::trange(lmda_path.size());
pb.set_display(display);
pb.set_range(
util::tq::int_iterator<int>(0),
util::tq::int_iterator<int>(lmda_path.size())
);
auto pb_it = pb.begin();

// slice lambda_path up to lmda_max
Expand Down
7 changes: 4 additions & 3 deletions adelie/src/include/adelie_core/solver/solver_gaussian_cov.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,13 @@ auto fit(
}

template <class StateType,
class PBType,
class ExitCondType,
class UpdateCoefficientsType,
class CUIType=util::no_op>
inline void solve(
StateType&& state,
bool display,
PBType&& pb,
ExitCondType exit_cond_f,
UpdateCoefficientsType update_coefficients_f,
CUIType check_user_interrupt = CUIType()
Expand All @@ -247,7 +248,7 @@ inline void solve(
GaussianCovBufferPack<value_t, safe_bool_t> buffer_pack(p);

const auto pb_add_suffix_f = [&](const auto& state, auto& pb) {
if (display) cov::pb_add_suffix(state, pb);
cov::pb_add_suffix(state, pb);
};
const auto update_loss_null_f = [](const auto&) {};
const auto update_invariance_f = [&](
Expand Down Expand Up @@ -314,7 +315,7 @@ inline void solve(

solver::solve_core(
state,
display,
pb,
pb_add_suffix_f,
update_loss_null_f,
update_invariance_f,
Expand Down
12 changes: 7 additions & 5 deletions adelie/src/include/adelie_core/solver/solver_gaussian_naive.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,14 @@ auto fit(
}

template <class StateType,
class PBType,
class ExitCondType,
class UpdateCoefficientsType,
class TidyType,
class CUIType>
inline void solve(
StateType&& state,
bool display,
PBType&& pb,
ExitCondType exit_cond_f,
UpdateCoefficientsType update_coefficients_f,
TidyType tidy_f,
Expand All @@ -216,7 +217,7 @@ inline void solve(
GaussianNaiveBufferPack<value_t, safe_bool_t> buffer_pack(n);

const auto pb_add_suffix_f = [&](const auto& state, auto& pb) {
if (display) solver::pb_add_suffix(state, pb);
solver::pb_add_suffix(state, pb);
};
const auto update_loss_null_f = [](const auto&) {};
const auto update_invariance_f = [&](auto& state, const auto&, auto lmda) {
Expand Down Expand Up @@ -268,7 +269,7 @@ inline void solve(

solver::solve_core(
state,
display,
pb,
pb_add_suffix_f,
update_loss_null_f,
update_invariance_f,
Expand All @@ -280,20 +281,21 @@ inline void solve(
}

template <class StateType,
class PBType,
class ExitCondType,
class UpdateCoefficientsType,
class CUIType=util::no_op>
inline void solve(
StateType&& state,
bool display,
PBType&& pb,
ExitCondType exit_cond_f,
UpdateCoefficientsType update_coefficients_f,
CUIType check_user_interrupt = CUIType()
)
{
solve(
state,
display,
pb,
exit_cond_f,
update_coefficients_f,
[](){},
Expand Down
12 changes: 7 additions & 5 deletions adelie/src/include/adelie_core/solver/solver_glm_naive.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ auto fit(

template <class StateType,
class GlmType,
class PBType,
class ExitCondType,
class UpdateLossNullType,
class UpdateCoefficientsType,
Expand All @@ -402,7 +403,7 @@ template <class StateType,
inline void solve(
StateType&& state,
GlmType&& glm,
bool display,
PBType&& pb,
ExitCondType exit_cond_f,
UpdateLossNullType update_loss_null_f,
UpdateCoefficientsType update_coefficients_f,
Expand All @@ -419,7 +420,7 @@ inline void solve(
GlmNaiveBufferPack<value_t, safe_bool_t> buffer_pack(n, p);

const auto pb_add_suffix_f = [&](const auto& state, auto& pb) {
if (display) solver::pb_add_suffix(state, pb);
solver::pb_add_suffix(state, pb);
};
const auto update_loss_null_wrap_f = [&](auto& state) {
const auto setup_loss_null = state.setup_loss_null;
Expand Down Expand Up @@ -468,7 +469,7 @@ inline void solve(

solver::solve_core(
state,
display,
pb,
pb_add_suffix_f,
update_loss_null_wrap_f,
update_invariance_f,
Expand All @@ -481,13 +482,14 @@ inline void solve(

template <class StateType,
class GlmType,
class PBType,
class ExitCondType,
class UpdateCoefficientsType,
class CUIType=util::no_op>
inline void solve(
StateType&& state,
GlmType&& glm,
bool display,
PBType&& pb,
ExitCondType exit_cond_f,
UpdateCoefficientsType update_coefficients_f,
CUIType check_user_interrupt = CUIType()
Expand All @@ -496,7 +498,7 @@ inline void solve(
solve(
std::forward<StateType>(state),
std::forward<GlmType>(glm),
display,
std::forward<PBType>(pb),
exit_cond_f,
[](auto& state, auto& glm, auto& buffer_pack) {
update_loss_null(state, glm, buffer_pack);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ namespace multigaussian {
namespace naive {

template <class StateType,
class PBType,
class ExitCondType,
class UpdateCoefficientsType,
class CUIType=util::no_op>
inline void solve(
StateType&& state,
bool display,
PBType&& pb,
ExitCondType exit_cond_f,
UpdateCoefficientsType update_coefficients_f,
CUIType check_user_interrupt = CUIType()
Expand Down Expand Up @@ -47,7 +48,7 @@ inline void solve(

gaussian::naive::solve(
static_cast<state_gaussian_naive_t&>(state),
display,
pb,
exit_cond_f,
update_coefficients_f,
tidy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,14 @@ void update_loss_null(

template <class StateType,
class GlmType,
class PBType,
class ExitCondType,
class UpdateCoefficientsType,
class CUIType=util::no_op>
inline void solve(
StateType&& state,
GlmType&& glm,
bool display,
PBType&& pb,
ExitCondType exit_cond_f,
UpdateCoefficientsType update_coefficients_f,
CUIType check_user_interrupt = CUIType()
Expand Down Expand Up @@ -235,7 +236,7 @@ inline void solve(
glm::naive::solve(
static_cast<state_glm_naive_t&>(state),
glm_wrap,
display,
pb,
exit_cond_f,
[&](auto&, auto& glm, auto& buffer_pack) {
// ignore casted down state and use derived state
Expand Down
11 changes: 10 additions & 1 deletion adelie/src/include/adelie_core/util/tqdm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ class progress_bar
double min_time_per_update_{1e-1}; // found experimentally
bool display_{true};

std::ostream* os_{&std::cerr};
std::ostream* os_ = nullptr;

index bar_size_{10};
index max_chars_{0};
Expand Down Expand Up @@ -329,6 +329,14 @@ class tqdm_for_lvalues
++iters_done_;
}

// NOTE: possibly invalidates other members.
// Safe usage is to call .begin(), .end() afterwards to reset everything.
void set_range(ForwardIter begin, EndIter end)
{
first_ = iterator(begin, this);
last_ = end;
num_iters_ = std::distance(begin, end);
}
void set_ostream(std::ostream& os) { bar_.set_ostream(os); }
void set_prefix(std::string s) { bar_.set_prefix(std::move(s)); }
void set_bar_size(int size) { bar_.set_bar_size(size); }
Expand Down Expand Up @@ -384,6 +392,7 @@ class tqdm_for_rvalues

void update() { return tqdm_.update(); }

void set_range(iterator begin, iterator end) { tqdm_.set_range(begin, end); }
void set_ostream(std::ostream& os) { tqdm_.set_ostream(os); }
void set_prefix(std::string s) { tqdm_.set_prefix(std::move(s)); }
void set_bar_size(int size) { tqdm_.set_bar_size(size); }
Expand Down
25 changes: 20 additions & 5 deletions adelie/src/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,11 @@ py::dict solve_gaussian_cov(
const auto exit_cond_f = [&]() {
return exit_cond && exit_cond(state);
};
auto pb = ad::util::tq::trange(0);
pb.set_display(display_progress_bar);
pb.set_ostream(std::cerr);
ad::solver::gaussian::cov::solve(
state, display_progress_bar, exit_cond_f, u, c
state, pb, exit_cond_f, u, c
);
}
);
Expand All @@ -220,8 +223,11 @@ py::dict solve_gaussian_naive(
const auto exit_cond_f = [&]() {
return exit_cond && exit_cond(state);
};
auto pb = ad::util::tq::trange(0);
pb.set_display(display_progress_bar);
pb.set_ostream(std::cerr);
return ad::solver::gaussian::naive::solve(
state, display_progress_bar, exit_cond_f, u, c
state, pb, exit_cond_f, u, c
);
}
);
Expand All @@ -240,8 +246,11 @@ py::dict solve_multigaussian_naive(
const auto exit_cond_f = [&]() {
return exit_cond && exit_cond(state);
};
auto pb = ad::util::tq::trange(0);
pb.set_display(display_progress_bar);
pb.set_ostream(std::cerr);
ad::solver::multigaussian::naive::solve(
state, display_progress_bar, exit_cond_f, u, c
state, pb, exit_cond_f, u, c
);
}
);
Expand All @@ -265,8 +274,11 @@ py::dict solve_glm_naive(
const auto exit_cond_f = [&]() {
return exit_cond && exit_cond(state);
};
auto pb = ad::util::tq::trange(0);
pb.set_display(display_progress_bar);
pb.set_ostream(std::cerr);
ad::solver::glm::naive::solve(
state, glm, display_progress_bar, exit_cond_f, u, c
state, glm, pb, exit_cond_f, u, c
);
}
);
Expand All @@ -286,8 +298,11 @@ py::dict solve_multiglm_naive(
const auto exit_cond_f = [&]() {
return exit_cond && exit_cond(state);
};
auto pb = ad::util::tq::trange(0);
pb.set_display(display_progress_bar);
pb.set_ostream(std::cerr);
ad::solver::multiglm::naive::solve(
state, glm, display_progress_bar, exit_cond_f, u, c
state, glm, pb, exit_cond_f, u, c
);
}
);
Expand Down

0 comments on commit 870f8cf

Please sign in to comment.