diff --git a/adelie/src/include/adelie_core/matrix/matrix_naive_dense.hpp b/adelie/src/include/adelie_core/matrix/matrix_naive_dense.hpp index 85df1d25..d58c4d9e 100644 --- a/adelie/src/include/adelie_core/matrix/matrix_naive_dense.hpp +++ b/adelie/src/include/adelie_core/matrix/matrix_naive_dense.hpp @@ -84,7 +84,11 @@ class MatrixNaiveDense: public MatrixNaiveBase Eigen::Ref out ) const override { - out = _mat.middleCols(j, q); + dmmeq( + out, + _mat.middleCols(j, q), + _n_threads + ); } int rows() const override diff --git a/adelie/src/include/adelie_core/matrix/utils.hpp b/adelie/src/include/adelie_core/matrix/utils.hpp index a3e517b3..98648200 100644 --- a/adelie/src/include/adelie_core/matrix/utils.hpp +++ b/adelie/src/include/adelie_core/matrix/utils.hpp @@ -30,6 +30,56 @@ void dvsubi( } } +template +ADELIE_CORE_STRONG_INLINE +void dmvsubi( + X1Type& x1, + const X2Type& x2, + size_t n_threads +) +{ + assert(n_threads > 0); + const size_t n = x1.rows(); + const int n_blocks = std::min(n_threads, n); + const int block_size = n / n_blocks; + const int remainder = n % n_blocks; + + #pragma omp parallel for schedule(static) num_threads(n_blocks) + for (int t = 0; t < n_blocks; ++t) { + const auto begin = ( + std::min(t, remainder) * (block_size + 1) + + std::max(t-remainder, 0) * block_size + ); + const auto size = block_size + (t < remainder); + x1.middleRows(begin, size).rowwise() -= x2; + } +} + +template +ADELIE_CORE_STRONG_INLINE +void dmmeq( + X1Type& x1, + const X2Type& x2, + size_t n_threads +) +{ + assert(n_threads > 0); + const size_t n = x1.rows(); + const int n_blocks = std::min(n_threads, n); + const int block_size = n / n_blocks; + const int remainder = n % n_blocks; + + #pragma omp parallel for schedule(static) num_threads(n_blocks) + for (int t = 0; t < n_blocks; ++t) { + const auto begin = ( + std::min(t, remainder) * (block_size + 1) + + std::max(t-remainder, 0) * block_size + ); + const auto size = block_size + (t < remainder); + x1.middleRows(begin, size) = x2.middleRows(begin, size); + } +} + template ADELIE_CORE_STRONG_INLINE auto ddot( diff --git a/adelie/src/include/adelie_core/solver/solve_basil_naive.hpp b/adelie/src/include/adelie_core/solver/solve_basil_naive.hpp index f40ea362..5d12d61e 100644 --- a/adelie/src/include/adelie_core/solver/solve_basil_naive.hpp +++ b/adelie/src/include/adelie_core/solver/solve_basil_naive.hpp @@ -154,18 +154,7 @@ void screen_strong( return strong_hashset.find(i) != strong_hashset.end(); }; - /* update strong_set */ - // Use either the fixed-increment rule or strong rule to increase strong set. - if (strong_rule == state::strong_rule_type::_default) { - const auto strong_rule_lmda = (2 * lmda_next - lmda) * alpha; - // TODO: PARALLELIZE! - for (int i = 0; i < abs_grad.size(); ++i) { - if (is_strong(i)) continue; - if (abs_grad[i] > strong_rule_lmda * penalty[i]) { - strong_set.push_back(i); - } - } - } else if (strong_rule == state::strong_rule_type::_fixed_greedy) { + const auto do_fixed_greedy = [&]() { size_t size_capped = std::min(delta_strong_size, new_safe_size); strong_set.insert(strong_set.end(), size_capped, -1); @@ -182,10 +171,32 @@ void screen_strong( size_capped, std::next(strong_set.begin(), old_strong_set_size) ); + }; + + /* update strong_set */ + // Use either the fixed-increment rule or strong rule to increase strong set. + if (strong_rule == state::strong_rule_type::_default) { + const auto strong_rule_lmda = (2 * lmda_next - lmda) * alpha; + + // TODO: PARALLELIZE! 
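Editor note on the `dmvsubi`/`dmmeq` helpers added above: they split the rows of `x1` into nearly equal contiguous blocks, giving the first `remainder` blocks one extra row so block sizes differ by at most one. A minimal standalone sketch of that partition arithmetic (names and data here are illustrative, not part of the patch; compile with `-fopenmp`, or the pragma is simply ignored and the loop runs serially):

```cpp
// Sketch of the block partition used by dmvsubi/dmmeq (illustrative only).
// Rows [0, n) are split into n_blocks contiguous chunks; the first
// `remainder` chunks get one extra row.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t n = 10;          // rows to process
    const std::size_t n_threads = 4;   // requested threads
    const std::size_t n_blocks = std::min(n_threads, n);
    const std::size_t block_size = n / n_blocks;
    const std::size_t remainder = n % n_blocks;

    std::size_t covered = 0;
    #pragma omp parallel for schedule(static) num_threads(n_blocks) reduction(+:covered)
    for (int t = 0; t < static_cast<int>(n_blocks); ++t) {
        const std::size_t begin =
            std::min<std::size_t>(t, remainder) * (block_size + 1) +
            std::max<int>(t - static_cast<int>(remainder), 0) * block_size;
        const std::size_t size = block_size + (t < static_cast<int>(remainder));
        std::printf("block %d: rows [%zu, %zu)\n", t, begin, begin + size);  // order is nondeterministic
        covered += size;
    }
    assert(covered == n);  // the blocks tile [0, n) exactly
    return 0;
}
```

For n = 10 and 4 blocks this yields rows [0,3), [3,6), [6,8), [8,10), which is exactly the begin/size arithmetic the new helpers use per thread.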
+ for (int i = 0; i < abs_grad.size(); ++i) { + if (is_strong(i)) continue; + if (abs_grad[i] > strong_rule_lmda * penalty[i]) { + strong_set.push_back(i); + } + } + + // If no new strong variables were added, need a fall-back. + // Use fixed-greedy method. + if (strong_set.size() == old_strong_set_size) { + do_fixed_greedy(); + } + } else if (strong_rule == state::strong_rule_type::_fixed_greedy) { + do_fixed_greedy(); } else if (strong_rule == state::strong_rule_type::_safe) { for (int i = 0; i < edpp_safe_set.size(); ++i) { if (is_strong(edpp_safe_set[i])) continue; - strong_set.push_back(i); + strong_set.push_back(edpp_safe_set[i]); } } @@ -243,6 +254,7 @@ auto fit( auto& X = *state.X; const auto y_mean = state.y_mean; + const auto y_var = state.y_var; const auto& groups = state.groups; const auto& group_sizes = state.group_sizes; const auto alpha = state.alpha; @@ -257,6 +269,9 @@ auto fit( const auto intercept = state.intercept; const auto max_iters = state.max_iters; const auto tol = state.tol; + const auto rsq_tol = state.rsq_tol; + const auto rsq_slope_tol = state.rsq_slope_tol; + const auto rsq_curv_tol = state.rsq_curv_tol; const auto newton_tol = state.newton_tol; const auto newton_max_iters = state.newton_max_iters; const auto n_threads = state.n_threads; @@ -269,6 +284,7 @@ auto fit( state_pin_naive_t state_pin_naive( X, y_mean, + y_var, groups, group_sizes, alpha, @@ -281,7 +297,8 @@ auto fit( Eigen::Map(strong_X_means.data(), strong_X_means.size()), strong_transforms, lmda_path, - intercept, max_iters, tol, 0, 0, newton_tol, newton_max_iters, n_threads, + intercept, max_iters, tol, rsq_tol, rsq_slope_tol, rsq_curv_tol, + newton_tol, newton_max_iters, n_threads, rsq, Eigen::Map(resid.data(), resid.size()), resid_sum, @@ -390,8 +407,8 @@ inline void solve_basil( using state_t = std::decay_t; using value_t = typename state_t::value_t; using vec_value_t = typename state_t::vec_value_t; - using vec_index_t = typename state_t::vec_index_t; using vec_safe_bool_t = typename state_t::vec_safe_bool_t; + using sw_t = util::Stopwatch; const auto& X_means = state.X_means; const auto alpha = state.alpha; @@ -401,6 +418,7 @@ inline void solve_basil( const auto delta_lmda_path_size = state.delta_lmda_path_size; const auto early_exit = state.early_exit; const auto max_strong_size = state.max_strong_size; + const auto rsq_tol = state.rsq_tol; const auto rsq_slope_tol = state.rsq_slope_tol; const auto rsq_curv_tol = state.rsq_curv_tol; const auto setup_edpp = state.setup_edpp; @@ -424,6 +442,10 @@ inline void solve_basil( auto& resid_prev_valid = state.resid_prev_valid; auto& strong_beta_prev_valid = state.strong_beta_prev_valid; auto& strong_is_active_prev_valid = state.strong_is_active_prev_valid; + auto& benchmark_screen = state.benchmark_screen; + auto& benchmark_fit = state.benchmark_fit; + auto& benchmark_kkt = state.benchmark_kkt; + auto& benchmark_invariance = state.benchmark_invariance; const auto p = grad.size(); @@ -588,6 +610,7 @@ inline void solve_basil( // We must go through BASIL iterations to solve each lambda. 
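Editor note on the screening change above: the default branch now falls back to the fixed-greedy rule whenever the strong-rule threshold admits no new groups. A hedged, simplified sketch of that control flow (the patch's `do_fixed_greedy` uses a size-capped partial sort into `strong_set`; the full sort below is only a stand-in, and all data here are invented):

```cpp
// Hedged sketch of the default screening rule with the fixed-greedy fallback.
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <unordered_set>
#include <vector>

int main()
{
    std::vector<double> abs_grad {0.9, 0.2, 0.7, 0.05};   // group-wise |gradient|
    std::vector<double> penalty  {1.0, 1.0, 1.0, 1.0};
    std::vector<int> strong_set {0};                       // group 0 already strong
    std::unordered_set<int> strong_hashset(strong_set.begin(), strong_set.end());

    const double lmda = 0.8, lmda_next = 0.75, alpha = 1.0;
    const std::size_t delta_strong_size = 1;
    const std::size_t old_size = strong_set.size();

    // strong rule: include i if |grad_i| > (2 * lmda_next - lmda) * alpha * penalty_i
    const double strong_rule_lmda = (2 * lmda_next - lmda) * alpha;
    for (int i = 0; i < static_cast<int>(abs_grad.size()); ++i) {
        if (strong_hashset.count(i)) continue;
        if (abs_grad[i] > strong_rule_lmda * penalty[i]) strong_set.push_back(i);
    }

    // fallback: nothing passed the threshold, so greedily take the largest |grad_i| / penalty_i
    if (strong_set.size() == old_size) {
        std::vector<int> order(abs_grad.size());
        std::iota(order.begin(), order.end(), 0);
        std::sort(order.begin(), order.end(), [&](int a, int b) {
            return abs_grad[a] / penalty[a] > abs_grad[b] / penalty[b];
        });
        std::size_t added = 0;
        for (int i : order) {
            if (added >= delta_strong_size) break;
            if (strong_hashset.count(i)) continue;
            strong_set.push_back(i);
            ++added;
        }
    }

    for (int i : strong_set) std::printf("strong group %d\n", i);
    return 0;
}
```

With these numbers the threshold is 0.7, no group passes strictly, and the fallback adds group 2 — the situation the new guard is designed to handle so the BASIL iteration can still make progress.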
vec_value_t lmda_batch; std::vector grads; + sw_t sw; while (1) { @@ -596,9 +619,8 @@ inline void solve_basil( const auto rsq_u = rsqs[rsqs.size()-1]; const auto rsq_m = rsqs[rsqs.size()-2]; const auto rsq_l = rsqs[rsqs.size()-3]; - // TODO: generalize 0.99 if (check_early_stop_rsq(rsq_l, rsq_m, rsq_u, rsq_slope_tol, rsq_curv_tol) || - (rsqs.back() >= 0.99)) break; + (rsqs.back() >= rsq_tol)) break; } // check if any lambdas left to fit @@ -613,10 +635,12 @@ inline void solve_basil( // ==================================================================================== // Screening step // ==================================================================================== + sw.start(); naive::screen( state, lmda_batch[0] ); + benchmark_screen.push_back(sw.elapsed()); try { // ==================================================================================== @@ -625,27 +649,32 @@ inline void solve_basil( // Save all current valid quantities that will be modified in-place by fit. // This is needed for the invariance step in case no valid solutions are found. save_prev_valid(); + sw.start(); auto&& state_pin_naive = naive::fit( state, lmda_batch, update_coefficients_f, check_user_interrupt ); + benchmark_fit.push_back(sw.elapsed()); // ==================================================================================== // KKT step // ==================================================================================== grads.resize(lmda_batch.size()); for (auto& g : grads) g.resize(p); + sw.start(); const auto n_valid_solutions = naive::kkt( state, state_pin_naive, grads ); + benchmark_kkt.push_back(sw.elapsed()); // ==================================================================================== // Invariance step // ==================================================================================== + sw.start(); lmda_path_idx += n_valid_solutions; // If no valid solutions found, restore to the previous valid state, // so that we can start over after screening more variables. 
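Editor note on the timing changes above: the RAII `Stopwatch(double&)` is replaced by an explicit `start()` / `elapsed()` pair, and `solve_basil` now records one measurement per phase (screen, fit, KKT, invariance) per BASIL iteration into the new `benchmark_*` vectors. A condensed standalone version of the pattern (illustrative; the real class lives in `util/stopwatch.hpp`):

```cpp
// Condensed sketch of the refactored stopwatch pattern.
#include <chrono>
#include <cstdio>
#include <thread>
#include <vector>

class Stopwatch {
    using clock_t = std::chrono::steady_clock;
    clock_t::time_point start_;
public:
    void start() { start_ = clock_t::now(); }
    double elapsed() const {  // seconds since the last start()
        const auto dur = clock_t::now() - start_;
        return std::chrono::duration_cast<std::chrono::nanoseconds>(dur).count() * 1e-9;
    }
};

int main()
{
    std::vector<double> benchmark_screen;  // one entry per (mock) BASIL iteration
    Stopwatch sw;
    for (int iter = 0; iter < 3; ++iter) {
        sw.start();
        std::this_thread::sleep_for(std::chrono::milliseconds(5));  // stand-in for screen()
        benchmark_screen.push_back(sw.elapsed());
    }
    for (double t : benchmark_screen) std::printf("screen: %.6f s\n", t);
    return 0;
}
```

The explicit interface avoids wrapping each timed region in an extra scope, which is why the `solve_pin_*` coordinate-descent loops above read flatter after this patch.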
@@ -675,6 +704,7 @@ inline void solve_basil( state_pin_naive, n_valid_solutions ); + benchmark_invariance.push_back(sw.elapsed()); } catch (const std::exception& e) { load_prev_valid(); throw util::propagator_error(e.what()); diff --git a/adelie/src/include/adelie_core/solver/solve_pin_cov.hpp b/adelie/src/include/adelie_core/solver/solve_pin_cov.hpp index 2cb67d74..07ee369e 100644 --- a/adelie/src/include/adelie_core/solver/solve_pin_cov.hpp +++ b/adelie/src/include/adelie_core/solver/solve_pin_cov.hpp @@ -244,24 +244,23 @@ void solve_pin_active( ab_diff_view_curr = sb; } - time_active_cd.push_back(0); - { - sw_t stopwatch(time_active_cd.back()); - while (1) { - check_user_interrupt(iters); - ++iters; - value_t convg_measure; - coordinate_descent( - state, - active_g1.data(), active_g1.data() + active_g1.size(), - active_g2.data(), active_g2.data() + active_g2.size(), - lmda_idx, convg_measure, buffer1, buffer2, buffer3, buffer4, - update_coefficients_f - ); - if (convg_measure < tol) break; - if (iters >= max_iters) throw util::max_cds_error(lmda_idx); - } + sw_t stopwatch; + stopwatch.start(); + while (1) { + check_user_interrupt(iters); + ++iters; + value_t convg_measure; + coordinate_descent( + state, + active_g1.data(), active_g1.data() + active_g1.size(), + active_g2.data(), active_g2.data() + active_g2.size(), + lmda_idx, convg_measure, buffer1, buffer2, buffer3, buffer4, + update_coefficients_f + ); + if (convg_measure < tol) break; + if (iters >= max_iters) throw util::max_cds_error(lmda_idx); } + time_active_cd.push_back(stopwatch.elapsed()); // compute new active beta - old active beta for (size_t i = 0; i < active_set.size(); ++i) { @@ -422,22 +421,21 @@ inline void solve_pin( ++iters; value_t convg_measure; const auto old_active_size = active_set.size(); - time_strong_cd.push_back(0); - { - sw_t stopwatch(time_strong_cd.back()); - coordinate_descent( - state, - strong_g1.data(), strong_g1.data() + strong_g1.size(), - strong_g2.data(), strong_g2.data() + strong_g2.size(), - l, convg_measure, - buffer_pack.buffer1, - buffer_pack.buffer2, - buffer_pack.buffer3, - buffer_pack.buffer4, - update_coefficients_f, - add_active_set - ); - } + sw_t stopwatch; + stopwatch.start(); + coordinate_descent( + state, + strong_g1.data(), strong_g1.data() + strong_g1.size(), + strong_g2.data(), strong_g2.data() + strong_g2.size(), + l, convg_measure, + buffer_pack.buffer1, + buffer_pack.buffer2, + buffer_pack.buffer3, + buffer_pack.buffer4, + update_coefficients_f, + add_active_set + ); + time_strong_cd.push_back(stopwatch.elapsed()); const bool new_active_added = (old_active_size < active_set.size()); if (new_active_added) { diff --git a/adelie/src/include/adelie_core/solver/solve_pin_naive.hpp b/adelie/src/include/adelie_core/solver/solve_pin_naive.hpp index 62e9109c..51122511 100644 --- a/adelie/src/include/adelie_core/solver/solve_pin_naive.hpp +++ b/adelie/src/include/adelie_core/solver/solve_pin_naive.hpp @@ -188,24 +188,23 @@ void solve_pin_active( auto& iters = state.iters; auto& time_active_cd = state.time_active_cd; - time_active_cd.push_back(0); - { - sw_t stopwatch(time_active_cd.back()); - while (1) { - check_user_interrupt(iters); - ++iters; - value_t convg_measure; - coordinate_descent( - state, - active_g1.data(), active_g1.data() + active_g1.size(), - active_g2.data(), active_g2.data() + active_g2.size(), - lmda_idx, convg_measure, buffer1, buffer2, buffer3, buffer4_n, - update_coefficients_f - ); - if (convg_measure < tol) break; - if (iters >= max_iters) throw 
util::max_cds_error(lmda_idx); - } + sw_t stopwatch; + stopwatch.start(); + while (1) { + check_user_interrupt(iters); + ++iters; + value_t convg_measure; + coordinate_descent( + state, + active_g1.data(), active_g1.data() + active_g1.size(), + active_g2.data(), active_g2.data() + active_g2.size(), + lmda_idx, convg_measure, buffer1, buffer2, buffer3, buffer4_n, + update_coefficients_f + ); + if (convg_measure < tol) break; + if (iters >= max_iters) throw util::max_cds_error(lmda_idx); } + time_active_cd.push_back(stopwatch.elapsed()); } template = rsq_tol * y_var) break; // early stop if R^2 criterion is fulfilled. - if (check_early_stop_rsq(rsqs[l-2], rsqs[l-1], rsqs[l], rsq_slope_tol, rsq_curv_tol)) break; + if ((l >= 2) && + check_early_stop_rsq( + rsqs[l-2], + rsqs[l-1], + rsqs[l], + rsq_slope_tol, + rsq_curv_tol + )) break; } } diff --git a/adelie/src/include/adelie_core/state/state_basil_base.hpp b/adelie/src/include/adelie_core/state/state_basil_base.hpp index b4b9e57a..08b9f14f 100644 --- a/adelie/src/include/adelie_core/state/state_basil_base.hpp +++ b/adelie/src/include/adelie_core/state/state_basil_base.hpp @@ -153,6 +153,7 @@ struct StateBasilBase // convergence configs const size_t max_iters; const value_t tol; + const value_t rsq_tol; const value_t rsq_slope_tol; const value_t rsq_curv_tol; const value_t newton_tol; @@ -168,6 +169,8 @@ struct StateBasilBase /* dynamic states */ value_t lmda_max; vec_value_t lmda_path; + + // invariants uset_index_t strong_hashset; dyn_vec_index_t strong_set; dyn_vec_index_t strong_g1; @@ -180,13 +183,18 @@ struct StateBasilBase value_t lmda; vec_value_t grad; vec_value_t abs_grad; + + // final results dyn_vec_sp_vec_t betas; dyn_vec_value_t intercepts; dyn_vec_value_t rsqs; dyn_vec_value_t lmdas; - /* diagnostics */ - // TODO: fill + // diagnostics + std::vector benchmark_screen; + std::vector benchmark_fit; + std::vector benchmark_kkt; + std::vector benchmark_invariance; virtual ~StateBasilBase() =default; @@ -205,6 +213,7 @@ struct StateBasilBase const std::string& strong_rule, size_t max_iters, value_t tol, + value_t rsq_tol, value_t rsq_slope_tol, value_t rsq_curv_tol, value_t newton_tol, @@ -233,6 +242,7 @@ struct StateBasilBase strong_rule(convert_strong_rule(strong_rule)), max_iters(max_iters), tol(tol), + rsq_tol(rsq_tol), rsq_slope_tol(rsq_slope_tol), rsq_curv_tol(rsq_curv_tol), newton_tol(newton_tol), @@ -284,6 +294,10 @@ struct StateBasilBase intercepts.reserve(n_lmdas); rsqs.reserve(n_lmdas); lmdas.reserve(n_lmdas); + benchmark_fit.reserve(n_lmdas); + benchmark_kkt.reserve(n_lmdas); + benchmark_screen.reserve(n_lmdas); + benchmark_invariance.reserve(n_lmdas); } }; diff --git a/adelie/src/include/adelie_core/state/state_basil_naive.hpp b/adelie/src/include/adelie_core/state/state_basil_naive.hpp index 294daf78..0cfc13fb 100644 --- a/adelie/src/include/adelie_core/state/state_basil_naive.hpp +++ b/adelie/src/include/adelie_core/state/state_basil_naive.hpp @@ -1,7 +1,7 @@ #pragma once #include #include -#include +#include #include #include @@ -50,7 +50,9 @@ void update_strong_derived_naive( strong_transforms.resize(new_strong_size); strong_vars.resize(new_strong_value_size, 0); - util::colmat_type Xi; // buffer + // buffers + util::colmat_type Xi; + util::colmat_type XiTXi; for (size_t i = old_strong_size; i < new_strong_size; ++i) { const auto g = groups[strong_set[i]]; @@ -70,23 +72,24 @@ void update_strong_derived_naive( // if intercept, must center first if (intercept) { - // TODO: PARALLELIZE!! 
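Editor note on the transform update that continues below: `update_strong_derived_naive` swaps the per-group `Eigen::BDCSVD` of the centered block X_i for a `SelfAdjointEigenSolver` of the much smaller Gram matrix X_iᵀX_i. Since X_i = U D Vᵀ implies X_iᵀX_i = V D² Vᵀ, the eigenvectors reproduce the old `matrixV()` and the eigenvalues already equal the squared singular values — which is why `strong_vars` now copies the eigenvalues directly instead of squaring. A hedged Eigen sketch of that equivalence on random data (ordering aside: eigenvalues come back ascending, singular values descending):

```cpp
// Sketch: eigendecomposition of X^T X matches the SVD-based transform (illustrative).
#include <Eigen/Dense>
#include <iostream>

int main()
{
    const int n = 200, p = 4;
    Eigen::MatrixXd X = Eigen::MatrixXd::Random(n, p);
    X.rowwise() -= X.colwise().mean();                    // center, as with an intercept

    Eigen::BDCSVD<Eigen::MatrixXd> svd(X, Eigen::ComputeThinV);
    Eigen::VectorXd sv2 = svd.singularValues().array().square();   // descending

    Eigen::MatrixXd XtX = X.transpose() * X;
    Eigen::SelfAdjointEigenSolver<Eigen::MatrixXd> eig(XtX);
    Eigen::VectorXd ev = eig.eigenvalues();                          // ascending

    std::cout << "squared singular values: " << sv2.transpose() << "\n";
    std::cout << "eigenvalues (reversed):  " << ev.reverse().transpose() << "\n";

    // columns of V and of the eigenvectors span the same directions (up to sign)
    Eigen::MatrixXd V = svd.matrixV();
    Eigen::MatrixXd Q = eig.eigenvectors().rowwise().reverse();      // reorder to descending
    std::cout << "max deviation of |V^T Q| from identity: "
              << ((V.transpose() * Q).cwiseAbs()
                  - Eigen::MatrixXd::Identity(p, p)).cwiseAbs().maxCoeff()
              << "\n";
    return 0;
}
```

The payoff is the cost profile measured by the new `BM_svd` / `BM_eigh` benchmarks later in this patch: the SVD works on the tall n-by-gs block, while the eigendecomposition only touches the gs-by-gs Gram matrix after a parallelizable GEMM.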
- Xi.rowwise() -= Xi_means.matrix(); + auto Xia = Xi.array(); + matrix::dmvsubi(Xia, Xi_means, n_threads); } // transform data - Eigen::BDCSVD> solver( - Xi, - Eigen::ComputeFullV - ); - const auto& D = solver.singularValues(); + Eigen::setNbThreads(n_threads); + XiTXi.noalias() = Xi.transpose() * Xi; + Eigen::setNbThreads(0); + + Eigen::SelfAdjointEigenSolver> solver(XiTXi); /* update strong_transforms */ - strong_transforms[i] = std::move(solver.matrixV()); + strong_transforms[i] = std::move(solver.eigenvectors()); /* update strong_vars */ + const auto& D = solver.eigenvalues(); Eigen::Map svars(strong_vars.data() + sb, gs); - svars.head(D.size()) = D.array().square(); + svars.head(D.size()) = D.array(); } } @@ -234,6 +237,7 @@ struct StateBasilNaive : StateBasilBase< const std::string& strong_rule, size_t max_iters, value_t tol, + value_t rsq_tol, value_t rsq_slope_tol, value_t rsq_curv_tol, value_t newton_tol, @@ -253,7 +257,7 @@ struct StateBasilNaive : StateBasilBase< base_t( groups, group_sizes, alpha, penalty, lmda_path, lmda_max, min_ratio, lmda_path_size, delta_lmda_path_size, delta_strong_size, max_strong_size, strong_rule, - max_iters, tol, rsq_slope_tol, rsq_curv_tol, + max_iters, tol, rsq_tol, rsq_slope_tol, rsq_curv_tol, newton_tol, newton_max_iters, early_exit, setup_lmda_max, setup_lmda_path, intercept, n_threads, strong_set, strong_beta, strong_is_active, rsq, lmda, grad ), diff --git a/adelie/src/include/adelie_core/state/state_pin_naive.hpp b/adelie/src/include/adelie_core/state/state_pin_naive.hpp index 9a092625..61b9ca51 100644 --- a/adelie/src/include/adelie_core/state/state_pin_naive.hpp +++ b/adelie/src/include/adelie_core/state/state_pin_naive.hpp @@ -40,8 +40,11 @@ struct StatePinNaive : StatePinBase< /* Static states */ const value_t y_mean; + const value_t y_var; const map_cvec_value_t strong_X_means; + const value_t rsq_tol; + /* Dynamic states */ matrix_t* X; map_vec_value_t resid; @@ -55,6 +58,7 @@ struct StatePinNaive : StatePinBase< explicit StatePinNaive( matrix_t& X, value_t y_mean, + value_t y_var, const Eigen::Ref& groups, const Eigen::Ref& group_sizes, value_t alpha, @@ -70,6 +74,7 @@ struct StatePinNaive : StatePinBase< bool intercept, size_t max_iters, value_t tol, + value_t rsq_tol, value_t rsq_slope_tol, value_t rsq_curv_tol, value_t newton_tol, @@ -88,7 +93,9 @@ struct StatePinNaive : StatePinBase< rsq, strong_beta, strong_is_active ), y_mean(y_mean), + y_var(y_var), strong_X_means(strong_X_means.data(), strong_X_means.size()), + rsq_tol(rsq_tol), X(&X), resid(resid.data(), resid.size()), resid_sum(resid_sum), diff --git a/adelie/src/include/adelie_core/util/stopwatch.hpp b/adelie/src/include/adelie_core/util/stopwatch.hpp index 6aaa696d..5a27ebaf 100644 --- a/adelie/src/include/adelie_core/util/stopwatch.hpp +++ b/adelie/src/include/adelie_core/util/stopwatch.hpp @@ -8,33 +8,21 @@ class Stopwatch { using sw_clock_t = std::chrono::steady_clock; using tpt_t = std::chrono::time_point; - double& store_; double elapsed_; tpt_t start_; public: - Stopwatch(double& store) - : store_(store) - { - start(); - } - - ~Stopwatch() - { - elapsed(); - store_ = elapsed_; - } - void start() { start_ = sw_clock_t::now(); } - void elapsed() + double elapsed() { const auto end = sw_clock_t::now(); const auto dur = (end - start_); elapsed_ = std::chrono::duration_cast(dur).count() * 1e-9; + return elapsed_; } }; diff --git a/adelie/src/include/adelie_core/util/type_traits.hpp b/adelie/src/include/adelie_core/util/type_traits.hpp deleted file mode 100644 index 
5051225e..00000000 --- a/adelie/src/include/adelie_core/util/type_traits.hpp +++ /dev/null @@ -1,16 +0,0 @@ -#pragma once -#include - -namespace adelie_core { -namespace util { - -template -struct is_dense -{ - using po_t = typename std::decay_t::PlainObject; - static constexpr bool value = - std::is_base_of, po_t>::value; -}; - -} // namespace util -} // namespace adelie_core diff --git a/adelie/src/state.cpp b/adelie/src/state.cpp index 8ec0fcda..b4907a42 100644 --- a/adelie/src/state.cpp +++ b/adelie/src/state.cpp @@ -309,6 +309,7 @@ void state_pin_naive(py::module_& m, const char* name) .def(py::init< matrix_t&, value_t, + value_t, const Eigen::Ref&, const Eigen::Ref&, value_t, @@ -327,6 +328,7 @@ void state_pin_naive(py::module_& m, const char* name) value_t, value_t, value_t, + value_t, size_t, size_t, value_t, @@ -337,6 +339,7 @@ void state_pin_naive(py::module_& m, const char* name) >(), py::arg("X"), py::arg("y_mean"), + py::arg("y_var"), py::arg("groups").noconvert(), py::arg("group_sizes").noconvert(), py::arg("alpha"), @@ -352,6 +355,7 @@ void state_pin_naive(py::module_& m, const char* name) py::arg("intercept"), py::arg("max_iters"), py::arg("tol"), + py::arg("rsq_tol"), py::arg("rsq_slope_tol"), py::arg("rsq_curv_tol"), py::arg("newton_tol"), @@ -367,6 +371,12 @@ void state_pin_naive(py::module_& m, const char* name) .def_readonly("y_mean", &state_t::y_mean, R"delimiter( Mean of :math:`y`. )delimiter") + .def_readonly("y_var", &state_t::y_var, R"delimiter( + :math:`\ell_2` norm squared of :math:`y_c`. + )delimiter") + .def_readonly("rsq_tol", &state_t::rsq_tol, R"delimiter( + Early stopping rule check on :math:`R^2`. + )delimiter") .def_readonly("strong_X_means", &state_t::strong_X_means, R"delimiter( Column means of :math:`X` for strong groups. )delimiter") @@ -522,6 +532,7 @@ void state_basil_base(py::module_& m, const char* name) value_t, value_t, value_t, + value_t, size_t, bool, bool, @@ -549,6 +560,7 @@ void state_basil_base(py::module_& m, const char* name) py::arg("strong_rule"), py::arg("max_iters"), py::arg("tol"), + py::arg("rsq_tol"), py::arg("rsq_slope_tol"), py::arg("rsq_curv_tol"), py::arg("newton_tol"), @@ -623,6 +635,9 @@ void state_basil_base(py::module_& m, const char* name) .def_readonly("tol", &state_t::tol, R"delimiter( Convergence tolerance. )delimiter") + .def_readonly("rsq_tol", &state_t::rsq_tol, R"delimiter( + Early stopping rule check on :math:`R^2`. + )delimiter") .def_readonly("rsq_slope_tol", &state_t::rsq_slope_tol, R"delimiter( Early stopping rule check on slope of :math:`R^2`. )delimiter") @@ -770,6 +785,38 @@ void state_basil_base(py::module_& m, const char* name) }, R"delimiter( ``intercepts[i]`` is the intercept solution corresponding to ``lmdas[i]``. )delimiter") + .def_property_readonly("benchmark_screen", [](const state_t& s) { + return Eigen::Map>( + s.benchmark_screen.data(), + s.benchmark_screen.size() + ); + }, R"delimiter( + Screen time for a given BASIL iteration. + )delimiter") + .def_property_readonly("benchmark_fit", [](const state_t& s) { + return Eigen::Map>( + s.benchmark_fit.data(), + s.benchmark_fit.size() + ); + }, R"delimiter( + Fit time for a given BASIL iteration. + )delimiter") + .def_property_readonly("benchmark_kkt", [](const state_t& s) { + return Eigen::Map>( + s.benchmark_kkt.data(), + s.benchmark_kkt.size() + ); + }, R"delimiter( + KKT time for a given BASIL iteration. 
+ )delimiter") + .def_property_readonly("benchmark_invariance", [](const state_t& s) { + return Eigen::Map>( + s.benchmark_invariance.data(), + s.benchmark_invariance.size() + ); + }, R"delimiter( + Invariance time for a given BASIL iteration. + )delimiter") ; } @@ -821,6 +868,7 @@ void state_basil_naive(py::module_& m, const char* name) value_t, value_t, value_t, + value_t, size_t, bool, bool, @@ -858,6 +906,7 @@ void state_basil_naive(py::module_& m, const char* name) py::arg("strong_rule"), py::arg("max_iters"), py::arg("tol"), + py::arg("rsq_tol"), py::arg("rsq_slope_tol"), py::arg("rsq_curv_tol"), py::arg("newton_tol"), diff --git a/adelie/state.py b/adelie/state.py index 543bc17f..1628de22 100644 --- a/adelie/state.py +++ b/adelie/state.py @@ -534,6 +534,7 @@ def default_init( *, X: matrix.base | matrix.MatrixNaiveBase64 | matrix.MatrixNaiveBase32, y_mean: float, + y_var: float, groups: np.ndarray, group_sizes: np.ndarray, alpha: float, @@ -547,6 +548,7 @@ def default_init( intercept: bool, max_iters: int, tol: float, + rsq_tol: float, rsq_slope_tol: float, rsq_curv_tol: float, newton_tol: float, @@ -619,6 +621,7 @@ def default_init( self, X=X, y_mean=y_mean, + y_var=y_var, groups=self._groups, group_sizes=self._group_sizes, alpha=alpha, @@ -634,6 +637,7 @@ def default_init( intercept=intercept, max_iters=max_iters, tol=tol, + rsq_tol=rsq_tol, rsq_slope_tol=rsq_slope_tol, rsq_curv_tol=rsq_curv_tol, newton_tol=newton_tol, @@ -715,6 +719,7 @@ def pin_naive( *, X: matrix.base | matrix.MatrixNaiveBase64 | matrix.MatrixNaiveBase32, y_mean: float, + y_var: float, groups: np.ndarray, group_sizes: np.ndarray, alpha: float, @@ -727,9 +732,10 @@ def pin_naive( strong_is_active: np.ndarray, intercept: bool =True, max_iters: int =int(1e5), - tol: float =1e-16, - rsq_slope_tol: float =1e-2, - rsq_curv_tol: float =1e-2, + tol: float =1e-12, + rsq_tol: float =0.9, + rsq_slope_tol: float =1e-3, + rsq_curv_tol: float =1e-3, newton_tol: float =1e-12, newton_max_iters: int =1000, n_threads: int =1, @@ -743,6 +749,8 @@ def pin_naive( It is typically one of the matrices defined in ``adelie.matrix`` sub-module. y_mean : float Mean of :math:`y`. + y_var : float + :math:`\\ell_2` norm squared of :math:`y_c`. groups : (G,) np.ndarray List of starting indices to each group where `G` is the number of groups. ``groups[i]`` is the starting index of the ``i`` th group. @@ -783,13 +791,16 @@ def pin_naive( Default is ``int(1e5)``. tol : float, optional Convergence tolerance. - Default is ``1e-16``. + Default is ``1e-12``. + rsq_tol : float, optional + Early stopping rule check on :math:`R^2`. + Default is ``0.9``. rsq_slope_tol : float, optional Early stopping rule check on slope of :math:`R^2`. - Default is ``1e-2``. + Default is ``1e-3``. rsq_curv_tol : float, optional Early stopping rule check on curvature of :math:`R^2`. - Default is ``1e-2``. + Default is ``1e-3``. newton_tol : float, optional Convergence tolerance for the BCD update. Default is ``1e-12``. 
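Editor note on these new defaults (`rsq_tol=0.9`, `rsq_slope_tol=rsq_curv_tol=1e-3`): they feed the two early-exit checks touched elsewhere in this patch — the pin solver now stops once R² reaches `rsq_tol * y_var` (and only consults the slope/curvature test once three R² values exist, hence the new `l >= 2` guard), while `solve_basil` compares `rsqs.back()` against `rsq_tol` instead of the hard-coded 0.99. A hedged sketch of the guard structure; the body of `check_early_stop_rsq` is not reproduced here, only a stand-in flatness test:

```cpp
// Hedged sketch of the R^2 early-exit guards (stand-in logic, not the patch's).
#include <cstdio>
#include <vector>

int main()
{
    const double y_var = 100.0;           // ||y_c||_2^2
    const double rsq_tol = 0.9;
    const double rsq_slope_tol = 1e-3;
    const double rsq_curv_tol = 1e-3;

    // stand-in for check_early_stop_rsq: stop when the last two R^2 increments
    // are both tiny (the real test also involves a curvature comparison).
    auto nearly_flat = [&](double rl, double rm, double ru) {
        const double d1 = rm - rl, d2 = ru - rm;
        return d1 < rsq_slope_tol * y_var && d2 < rsq_curv_tol * y_var;
    };

    std::vector<double> rsqs {0.0, 40.0, 70.0, 85.0, 91.0, 91.02, 91.03};  // unnormalized R^2 path
    for (std::size_t l = 0; l < rsqs.size(); ++l) {
        if (rsqs[l] >= rsq_tol * y_var) {                                  // absolute criterion
            std::printf("stop at l=%zu: R^2 criterion met\n", l);
            break;
        }
        if (l >= 2 && nearly_flat(rsqs[l-2], rsqs[l-1], rsqs[l])) {        // needs 3 points
            std::printf("stop at l=%zu: R^2 path flattened\n", l);
            break;
        }
    }
    return 0;
}
```

With this path the loop exits at l=4 via the absolute criterion (91 ≥ 0.9 · 100), which mirrors how raising `rsq_tol` lets the solver fit further down the lambda path before stopping.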
@@ -831,6 +842,7 @@ def pin_naive( return dispatcher[dtype]( X=X, y_mean=y_mean, + y_var=y_var, groups=groups, group_sizes=group_sizes, alpha=alpha, @@ -844,6 +856,7 @@ def pin_naive( intercept=intercept, max_iters=max_iters, tol=tol, + rsq_tol=rsq_tol, rsq_slope_tol=rsq_slope_tol, rsq_curv_tol=rsq_curv_tol, newton_tol=newton_tol, @@ -1023,9 +1036,9 @@ def pin_cov( strong_grad: np.ndarray, strong_is_active: np.ndarray, max_iters: int =int(1e5), - tol: float =1e-16, - rsq_slope_tol: float =1e-2, - rsq_curv_tol: float =1e-2, + tol: float =1e-12, + rsq_slope_tol: float =1e-3, + rsq_curv_tol: float =1e-3, newton_tol: float =1e-12, newton_max_iters: int =1000, n_threads: int =1, @@ -1081,13 +1094,13 @@ def pin_cov( Default is ``int(1e5)``. tol : float, optional Convergence tolerance. - Default is ``1e-16``. + Default is ``1e-12``. rsq_slope_tol : float, optional Early stopping rule check on slope of :math:`R^2`. - Default is ``1e-2``. + Default is ``1e-3``. rsq_curv_tol : float, optional Early stopping rule check on curvature of :math:`R^2`. - Default is ``1e-2``. + Default is ``1e-3``. newton_tol : float, optional Convergence tolerance for the BCD update. Default is ``1e-12``. @@ -1182,6 +1195,7 @@ def default_init( strong_rule: str, max_iters: int, tol: float, + rsq_tol: float, rsq_slope_tol: float, rsq_curv_tol: float, newton_tol: float, @@ -1252,6 +1266,7 @@ def default_init( strong_rule=strong_rule, max_iters=max_iters, tol=tol, + rsq_tol=rsq_tol, rsq_slope_tol=rsq_slope_tol, rsq_curv_tol=rsq_curv_tol, newton_tol=newton_tol, @@ -1602,6 +1617,47 @@ def check( method, logger, ) + def update_path(self, path): + return basil_naive( + X=self.X, + X_means=self.X_means, + X_group_norms=self.X_group_norms, + y_mean=self.y_mean, + y_var=self.y_var, + resid=self.resid, + groups=self.groups, + group_sizes=self.group_sizes, + alpha=self.alpha, + penalty=self.penalty, + strong_set=self.strong_set, + strong_beta=self.strong_beta, + strong_is_active=self.strong_is_active, + rsq=self.rsq, + lmda=self.lmda, + grad=self.grad, + lmda_path=path, + lmda_max=None if self.lmda_max == -1 else self.lmda_max, + edpp_safe_set=self.edpp_safe_set, + edpp_v1_0=None if len(self.edpp_v1_0) == 0 else self.edpp_v1_0, + edpp_resid_0=None if len(self.edpp_resid_0) == 0 else self.edpp_resid_0, + max_iters=self.max_iters, + tol=self.tol, + rsq_tol=self.rsq_tol, + rsq_slope_tol=self.rsq_slope_tol, + rsq_curv_tol=self.rsq_curv_tol, + newton_tol=self.newton_tol, + newton_max_iters=self.newton_max_iters, + n_threads=self.n_threads, + early_exit=self.early_exit, + intercept=self.intercept, + strong_rule=self.strong_rule, + min_ratio=self.min_ratio, + lmda_path_size=self.lmda_path_size, + delta_lmda_path_size=self.delta_lmda_path_size, + delta_strong_size=self.delta_strong_size, + max_strong_size=self.max_strong_size, + ) + class basil_naive_64(basil_naive_base, core.state.StateBasilNaive64): """State class for basil, naive method using 64-bit floating point.""" @@ -1669,9 +1725,10 @@ def basil_naive( edpp_v1_0: np.ndarray =None, edpp_resid_0: np.ndarray =None, max_iters: int =int(1e5), - tol: float =1e-16, - rsq_slope_tol: float =1e-2, - rsq_curv_tol: float =1e-2, + tol: float =1e-12, + rsq_tol: float =0.9, + rsq_slope_tol: float =1e-3, + rsq_curv_tol: float =1e-3, newton_tol: float =1e-12, newton_max_iters: int =1000, n_threads: int =1, @@ -1769,13 +1826,16 @@ def basil_naive( Default is ``int(1e5)``. tol : float, optional Convergence tolerance. - Default is ``1e-16``. + Default is ``1e-12``. 
+ rsq_tol : float, optional + Early stopping rule check on :math:`R^2`. + Default is ``0.9``. rsq_slope_tol : float, optional Early stopping rule check on slope of :math:`R^2`. - Default is ``1e-2``. + Default is ``1e-3``. rsq_curv_tol : float, optional Early stopping rule check on curvature of :math:`R^2`. - Default is ``1e-2``. + Default is ``1e-3``. newton_tol : float, optional Convergence tolerance for the BCD update. Default is ``1e-12``. @@ -1887,6 +1947,7 @@ def basil_naive( strong_rule=strong_rule, max_iters=max_iters, tol=tol, + rsq_tol=rsq_tol, rsq_slope_tol=rsq_slope_tol, rsq_curv_tol=rsq_curv_tol, newton_tol=newton_tol, diff --git a/benchmark/bench_omp.cpp b/benchmark/bench_omp.cpp index 778b88a2..d3e53db5 100644 --- a/benchmark/bench_omp.cpp +++ b/benchmark/bench_omp.cpp @@ -1,4 +1,6 @@ #include +#include +#include #include #include #include @@ -221,6 +223,40 @@ static void BM_btmul_par_c(benchmark::State& state) { } } +static void BM_svd(benchmark::State& state) +{ + const auto n = state.range(0); + const auto p = state.range(1); + ad::util::colmat_type X(n, p); + srand(0); + X.setRandom(); + + for (auto _ : state) { + Eigen::BDCSVD> solver( + X, + Eigen::ComputeFullV + ); + benchmark::DoNotOptimize(solver); + } +} + +static void BM_eigh(benchmark::State& state) +{ + const auto n = state.range(0); + const auto p = state.range(1); + ad::util::colmat_type X(n, p); + ad::util::colmat_type XTX(p, p); + srand(0); + X.setRandom(); + + Eigen::setNbThreads(8); + for (auto _ : state) { + XTX.noalias() = X.transpose() * X; + Eigen::SelfAdjointEigenSolver> solver(XTX); + benchmark::DoNotOptimize(solver); + } +} + BENCHMARK(BM_cmul_seq) -> Args({10000000}) ; @@ -266,4 +302,17 @@ BENCHMARK(BM_btmul_par_c) -> Args({1000000, 8, 2}) -> Args({1000000, 8, 4}) -> Args({1000000, 8, 8}) + ; + +BENCHMARK(BM_svd) + -> Args({100, 8}) + -> Args({1000, 8}) + -> Args({10000, 8}) + -> Args({100000, 8}) + ; +BENCHMARK(BM_eigh) + -> Args({100, 8}) + -> Args({1000, 8}) + -> Args({10000, 8}) + -> Args({100000, 8}) ; \ No newline at end of file diff --git a/research/test.ipynb b/research/test.ipynb index 4fda1b85..94bb9700 100644 --- a/research/test.ipynb +++ b/research/test.ipynb @@ -17,7 +17,6 @@ "outputs": [], "source": [ "import adelie as ad\n", - "import cvxpy as cp\n", "import numpy as np" ] }, @@ -52,6 +51,7 @@ "\n", " # generate raw data\n", " X = np.random.normal(0, 1, (n, p))\n", + " X = np.asfortranarray(X)\n", " beta = np.random.normal(0, 1, p)\n", " beta[np.random.choice(p, int(sparsity * p), replace=False)] = 0\n", " y = X @ beta + np.random.normal(0, 1, n)\n", @@ -97,15 +97,14 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 70, "metadata": {}, "outputs": [], "source": [ - "n = 10000\n", - "p = 10000\n", - "G = 1000\n", + "n = 1000\n", + "p = 1000000\n", + "G = 100000\n", "intercept = True\n", - "n_threads = 12\n", "\n", "test_data = create_test_data_basil(\n", " n, p, G, \n", @@ -115,31 +114,38 @@ ")\n", "test_data[\"penalty\"] = np.sqrt(test_data[\"group_sizes\"])\n", "X, y = test_data[\"X\"], test_data[\"y\"]\n", - "Xc = X - np.mean(X, axis=0)[None] * intercept\n", - "yc = y - np.mean(y) * intercept\n", - "\n", "test_data.pop(\"y\")\n", + "Xc = X - np.mean(X, axis=0)[None] * intercept\n", + "yc = y - np.mean(y) * intercept" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": {}, + "outputs": [], + "source": [ + "n_threads = 8\n", "test_data[\"X\"] = ad.matrix.naive_dense(X, n_threads=n_threads)\n", "state = ad.state.basil_naive(\n", " 
**test_data,\n", - " delta_lmda_path_size=3,\n", - " rsq_slope_tol=1e-4,\n", - " rsq_curv_tol=1e-2,\n", + " max_strong_size=30000,\n", + " strong_rule=\"default\",\n", " n_threads=n_threads,\n", ")" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 73, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 10min 46s, sys: 821 ms, total: 10min 47s\n", - "Wall time: 54.7 s\n" + "CPU times: user 2min 44s, sys: 953 ms, total: 2min 45s\n", + "Wall time: 20.7 s\n" ] } ], @@ -150,64 +156,142 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 74, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array([0.00000000e+00, 6.77120108e-04, 1.29408683e-03, 2.34312370e-03,\n", - " 3.73901847e-03, 5.14436296e-03, 6.65568234e-03, 8.58942724e-03,\n", - " 1.08114868e-02, 1.41029691e-02, 1.81845374e-02, 2.37259484e-02,\n", - " 3.30676467e-02, 4.36765023e-02, 5.63350007e-02, 6.96027444e-02,\n", - " 8.25781543e-02, 9.73405573e-02, 1.12673532e-01, 1.29154472e-01,\n", - " 1.49425054e-01, 1.72598117e-01, 1.96746550e-01, 2.24700485e-01,\n", - " 2.54688709e-01, 2.85351669e-01, 3.15517812e-01, 3.45869186e-01,\n", - " 3.76058096e-01, 4.06840833e-01, 4.37689371e-01, 4.67536139e-01,\n", - " 4.97114794e-01, 5.26530956e-01, 5.55346825e-01, 5.82553502e-01,\n", - " 6.08636739e-01, 6.33697154e-01, 6.57537341e-01, 6.79934929e-01,\n", - " 7.01094975e-01, 7.21371204e-01, 7.40568500e-01, 7.58654216e-01,\n", - " 7.75509717e-01, 7.91182109e-01, 8.05761361e-01, 8.19358666e-01,\n", - " 8.32079111e-01, 8.43957444e-01, 8.55052580e-01, 8.65378932e-01,\n", - " 8.74977709e-01, 8.83876982e-01, 8.92147675e-01, 8.99838638e-01,\n", - " 9.06978203e-01, 9.13642970e-01, 9.19883062e-01, 9.25699791e-01,\n", - " 9.31112860e-01, 9.36135949e-01, 9.40787265e-01, 9.45100544e-01,\n", - " 9.49091164e-01, 9.52799136e-01, 9.56235346e-01, 9.59435725e-01,\n", - " 9.62405844e-01, 9.65159597e-01, 9.67712617e-01, 9.70074728e-01,\n", - " 9.72262587e-01, 9.74290836e-01, 9.76169216e-01, 9.77906919e-01,\n", - " 9.79515257e-01, 9.81007900e-01, 9.82395143e-01, 9.83676780e-01,\n", - " 9.84861088e-01, 9.85955245e-01, 9.86966672e-01, 9.87901996e-01,\n", - " 9.88766472e-01, 9.89565700e-01, 9.90305513e-01, 9.90992033e-01,\n", - " 9.91626366e-01])" + "6.047451920000001" ] }, - "execution_count": 9, + "execution_count": 74, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "next_state.rsqs" + "np.sum(next_state.benchmark_fit)" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(6.047451920000001, 5.5086994350000005, 0.005143656, 8.708577768000001)" + ] + }, + "execution_count": 75, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(\n", + " np.sum(next_state.benchmark_fit),\n", + " np.sum(next_state.benchmark_screen),\n", + " np.sum(next_state.benchmark_invariance),\n", + " np.sum(next_state.benchmark_kkt),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.32461099, 0.50084198, 0.45982327, 0.46059465, 0.5058221 ,\n", + " 0.440893 , 0.441547 , 0.44808808, 0.47491148, 0.60637055,\n", + " 0.56224193, 0.66115996, 0.70434021, 0.63886951, 0.80449459,\n", + " 0.67396849])" + ] + }, + "execution_count": 76, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "next_state.benchmark_kkt" ] }, { "cell_type": "code", - "execution_count": 
14, + "execution_count": 79, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(1000, 568, 509)" + "(1183, 3003)" ] }, - "execution_count": 14, + "execution_count": 79, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "len(next_state.edpp_safe_set), len(next_state.strong_set), np.sum(next_state.strong_is_active)" + "(\n", + " len(next_state.strong_set),\n", + " len(next_state.edpp_safe_set),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<42x1000000 sparse matrix of type ''\n", + "\twith 18460 stored elements in Compressed Sparse Row format>" + ] + }, + "execution_count": 81, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "next_state.betas" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0. , 0.00303226, 0.00713723, 0.0120045 , 0.0201986 ,\n", + " 0.02999345, 0.04113448, 0.05565416, 0.07271441, 0.09116506,\n", + " 0.11233844, 0.13734477, 0.16723087, 0.20191022, 0.23896808,\n", + " 0.27788078, 0.31810521, 0.35797503, 0.3969806 , 0.43511581,\n", + " 0.47229851, 0.50771219, 0.54172233, 0.5739653 , 0.60440091,\n", + " 0.63319388, 0.66026084, 0.68597565, 0.71027314, 0.73277567,\n", + " 0.75365057, 0.77308266, 0.79115993, 0.80789267, 0.82339423,\n", + " 0.83775291, 0.85099826, 0.86331652, 0.87466545, 0.88508539,\n", + " 0.89464705, 0.90342317])" + ] + }, + "execution_count": 82, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "next_state.rsqs" ] } ], diff --git a/tests/test_solver.py b/tests/test_solver.py index 17517892..6968ef36 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -7,10 +7,6 @@ import cvxpy as cp import numpy as np -# ======================================================================== -# TEST helpers -# ======================================================================== - # ======================================================================== # TEST solve_pin # ======================================================================== @@ -190,6 +186,7 @@ def _test(n, p, G, S, intercept=True, alpha=1, sparsity=0.95, seed=0): ) y_mean = np.mean(y) resid = y - intercept * y_mean + y_var = np.sum(resid ** 2) Xs = [ ad.matrix.naive_dense(X, n_threads=2) ] @@ -197,6 +194,7 @@ def _test(n, p, G, S, intercept=True, alpha=1, sparsity=0.95, seed=0): state = ad.state.pin_naive( X=Xpy, y_mean=y_mean, + y_var=y_var, groups=groups, group_sizes=group_sizes, alpha=alpha, @@ -213,6 +211,7 @@ def _test(n, p, G, S, intercept=True, alpha=1, sparsity=0.95, seed=0): state = ad.state.pin_naive( X=Xpy, y_mean=y_mean, + y_var=y_var, groups=groups, group_sizes=group_sizes, alpha=alpha, diff --git a/tests/test_state.py b/tests/test_state.py index 716b15cf..55df5360 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -22,12 +22,14 @@ def test_state_pin_naive(): rsq = 0.0 resid = np.random.normal(0, 1, n) y_mean = 0 + y_var = 1 strong_beta = np.zeros(p) strong_is_active = np.zeros(strong_set.shape[0], dtype=bool) state = mod.pin_naive( X=X, y_mean=y_mean, + y_var=y_var, groups=groups, group_sizes=group_sizes, alpha=alpha,
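Closing editor note, on a binding pattern used in the `state.cpp` changes earlier in this patch: the new `benchmark_*` diagnostics are `std::vector<double>` members exposed to Python as NumPy arrays by returning an `Eigen::Map` from `def_property_readonly`. A stripped-down sketch of that pattern (module and class names here are invented for illustration):

```cpp
// Minimal sketch of exposing a std::vector<double> as a NumPy array
// through pybind11's Eigen support (illustrative names only).
#include <pybind11/pybind11.h>
#include <pybind11/eigen.h>
#include <Eigen/Core>
#include <vector>

namespace py = pybind11;

struct DemoState {
    std::vector<double> benchmark_fit {0.12, 0.34, 0.56};
};

PYBIND11_MODULE(demo_state, m) {
    py::class_<DemoState>(m, "DemoState")
        .def(py::init<>())
        .def_property_readonly("benchmark_fit", [](const DemoState& s) {
            // Map the vector's storage; pybind11's Eigen caster hands it to Python
            // as an array, so no manual copy into a NumPy buffer is needed.
            return Eigen::Map<const Eigen::VectorXd>(
                s.benchmark_fit.data(), s.benchmark_fit.size()
            );
        });
}
```

This is what lets the notebook above compute `np.sum(next_state.benchmark_fit)` and friends directly on the state object.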