Skip to content

Commit

Permalink
Add working state
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesYang007 committed Oct 17, 2023
1 parent 3e9e73a commit e4368f0
Show file tree
Hide file tree
Showing 16 changed files with 527 additions and 198 deletions.
6 changes: 5 additions & 1 deletion adelie/src/include/adelie_core/matrix/matrix_naive_dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@ class MatrixNaiveDense: public MatrixNaiveBase<typename DenseType::Scalar>
Eigen::Ref<colmat_t> out
) const override
{
out = _mat.middleCols(j, q);
dmmeq(
out,
_mat.middleCols(j, q),
_n_threads
);
}

int rows() const override
Expand Down
50 changes: 50 additions & 0 deletions adelie/src/include/adelie_core/matrix/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,56 @@ void dvsubi(
}
}

/**
 * In-place, row-blocked update `x1.rowwise() -= x2`.
 *
 * The rows of x1 are split into at most `n_threads` contiguous blocks whose
 * sizes differ by at most one row; each block is updated by one iteration of
 * an OpenMP parallel-for.
 *
 * @param x1         matrix-like object modified in place; must provide
 *                   rows() and middleRows(begin, size).
 * @param x2         operand subtracted from each row block via
 *                   `rowwise() -=` (presumably a row vector whose length
 *                   matches the number of columns of x1 -- confirm with callers).
 * @param n_threads  upper bound on the number of blocks/threads; must be > 0.
 */
template <class X1Type, class X2Type>
ADELIE_CORE_STRONG_INLINE
void dmvsubi(
    X1Type& x1,
    const X2Type& x2,
    size_t n_threads
)
{
    assert(n_threads > 0);
    const size_t n = x1.rows();
    // Nothing to do for an empty matrix. Returning early also avoids the
    // division by n_blocks == 0 below and a num_threads(0) request.
    if (n == 0) return;
    const int n_blocks = std::min(n_threads, n);
    const int block_size = n / n_blocks;
    const int remainder = n % n_blocks;

    #pragma omp parallel for schedule(static) num_threads(n_blocks)
    for (int t = 0; t < n_blocks; ++t) {
        // The first `remainder` blocks hold (block_size + 1) rows and the
        // rest hold block_size rows, so block t starts after that many rows.
        const auto begin = (
            std::min<int>(t, remainder) * (block_size + 1)
            + std::max<int>(t-remainder, 0) * block_size
        );
        const auto size = block_size + (t < remainder);
        x1.middleRows(begin, size).rowwise() -= x2;
    }
}

/**
 * Row-blocked assignment `x1 = x2`.
 *
 * The rows of x1 are split into at most `n_threads` contiguous blocks whose
 * sizes differ by at most one row; each block of x1 is assigned from the same
 * row range of x2 by one iteration of an OpenMP parallel-for.
 *
 * @param x1         destination, written in place; must provide rows() and
 *                   middleRows(begin, size).
 * @param x2         source; must provide middleRows(begin, size) over the
 *                   same row ranges (presumably the same shape as x1 --
 *                   confirm with callers, no resizing is performed here).
 * @param n_threads  upper bound on the number of blocks/threads; must be > 0.
 */
template <class X1Type, class X2Type>
ADELIE_CORE_STRONG_INLINE
void dmmeq(
    X1Type& x1,
    const X2Type& x2,
    size_t n_threads
)
{
    assert(n_threads > 0);
    const size_t n = x1.rows();
    // Nothing to copy for an empty destination. Returning early also avoids
    // the division by n_blocks == 0 below and a num_threads(0) request.
    if (n == 0) return;
    const int n_blocks = std::min(n_threads, n);
    const int block_size = n / n_blocks;
    const int remainder = n % n_blocks;

    #pragma omp parallel for schedule(static) num_threads(n_blocks)
    for (int t = 0; t < n_blocks; ++t) {
        // The first `remainder` blocks hold (block_size + 1) rows and the
        // rest hold block_size rows, so block t starts after that many rows.
        const auto begin = (
            std::min<int>(t, remainder) * (block_size + 1)
            + std::max<int>(t-remainder, 0) * block_size
        );
        const auto size = block_size + (t < remainder);
        x1.middleRows(begin, size) = x2.middleRows(begin, size);
    }
}

template <class X1Type, class X2Type>
ADELIE_CORE_STRONG_INLINE
auto ddot(
Expand Down
64 changes: 47 additions & 17 deletions adelie/src/include/adelie_core/solver/solve_basil_naive.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,18 +154,7 @@ void screen_strong(
return strong_hashset.find(i) != strong_hashset.end();
};

/* update strong_set */
// Use either the fixed-increment rule or strong rule to increase strong set.
if (strong_rule == state::strong_rule_type::_default) {
const auto strong_rule_lmda = (2 * lmda_next - lmda) * alpha;
// TODO: PARALLELIZE!
for (int i = 0; i < abs_grad.size(); ++i) {
if (is_strong(i)) continue;
if (abs_grad[i] > strong_rule_lmda * penalty[i]) {
strong_set.push_back(i);
}
}
} else if (strong_rule == state::strong_rule_type::_fixed_greedy) {
const auto do_fixed_greedy = [&]() {
size_t size_capped = std::min(delta_strong_size, new_safe_size);

strong_set.insert(strong_set.end(), size_capped, -1);
Expand All @@ -182,10 +171,32 @@ void screen_strong(
size_capped,
std::next(strong_set.begin(), old_strong_set_size)
);
};

/* update strong_set */
// Use either the fixed-increment rule or strong rule to increase strong set.
if (strong_rule == state::strong_rule_type::_default) {
const auto strong_rule_lmda = (2 * lmda_next - lmda) * alpha;

// TODO: PARALLELIZE!
for (int i = 0; i < abs_grad.size(); ++i) {
if (is_strong(i)) continue;
if (abs_grad[i] > strong_rule_lmda * penalty[i]) {
strong_set.push_back(i);
}
}

// If no new strong variables were added, need a fall-back.
// Use fixed-greedy method.
if (strong_set.size() == old_strong_set_size) {
do_fixed_greedy();
}
} else if (strong_rule == state::strong_rule_type::_fixed_greedy) {
do_fixed_greedy();
} else if (strong_rule == state::strong_rule_type::_safe) {
for (int i = 0; i < edpp_safe_set.size(); ++i) {
if (is_strong(edpp_safe_set[i])) continue;
strong_set.push_back(i);
strong_set.push_back(edpp_safe_set[i]);
}
}

Expand Down Expand Up @@ -243,6 +254,7 @@ auto fit(

auto& X = *state.X;
const auto y_mean = state.y_mean;
const auto y_var = state.y_var;
const auto& groups = state.groups;
const auto& group_sizes = state.group_sizes;
const auto alpha = state.alpha;
Expand All @@ -257,6 +269,9 @@ auto fit(
const auto intercept = state.intercept;
const auto max_iters = state.max_iters;
const auto tol = state.tol;
const auto rsq_tol = state.rsq_tol;
const auto rsq_slope_tol = state.rsq_slope_tol;
const auto rsq_curv_tol = state.rsq_curv_tol;
const auto newton_tol = state.newton_tol;
const auto newton_max_iters = state.newton_max_iters;
const auto n_threads = state.n_threads;
Expand All @@ -269,6 +284,7 @@ auto fit(
state_pin_naive_t state_pin_naive(
X,
y_mean,
y_var,
groups,
group_sizes,
alpha,
Expand All @@ -281,7 +297,8 @@ auto fit(
Eigen::Map<const vec_value_t>(strong_X_means.data(), strong_X_means.size()),
strong_transforms,
lmda_path,
intercept, max_iters, tol, 0, 0, newton_tol, newton_max_iters, n_threads,
intercept, max_iters, tol, rsq_tol, rsq_slope_tol, rsq_curv_tol,
newton_tol, newton_max_iters, n_threads,
rsq,
Eigen::Map<vec_value_t>(resid.data(), resid.size()),
resid_sum,
Expand Down Expand Up @@ -390,8 +407,8 @@ inline void solve_basil(
using state_t = std::decay_t<StateType>;
using value_t = typename state_t::value_t;
using vec_value_t = typename state_t::vec_value_t;
using vec_index_t = typename state_t::vec_index_t;
using vec_safe_bool_t = typename state_t::vec_safe_bool_t;
using sw_t = util::Stopwatch;

const auto& X_means = state.X_means;
const auto alpha = state.alpha;
Expand All @@ -401,6 +418,7 @@ inline void solve_basil(
const auto delta_lmda_path_size = state.delta_lmda_path_size;
const auto early_exit = state.early_exit;
const auto max_strong_size = state.max_strong_size;
const auto rsq_tol = state.rsq_tol;
const auto rsq_slope_tol = state.rsq_slope_tol;
const auto rsq_curv_tol = state.rsq_curv_tol;
const auto setup_edpp = state.setup_edpp;
Expand All @@ -424,6 +442,10 @@ inline void solve_basil(
auto& resid_prev_valid = state.resid_prev_valid;
auto& strong_beta_prev_valid = state.strong_beta_prev_valid;
auto& strong_is_active_prev_valid = state.strong_is_active_prev_valid;
auto& benchmark_screen = state.benchmark_screen;
auto& benchmark_fit = state.benchmark_fit;
auto& benchmark_kkt = state.benchmark_kkt;
auto& benchmark_invariance = state.benchmark_invariance;

const auto p = grad.size();

Expand Down Expand Up @@ -588,6 +610,7 @@ inline void solve_basil(
// We must go through BASIL iterations to solve each lambda.
vec_value_t lmda_batch;
std::vector<vec_value_t> grads;
sw_t sw;

while (1)
{
Expand All @@ -596,9 +619,8 @@ inline void solve_basil(
const auto rsq_u = rsqs[rsqs.size()-1];
const auto rsq_m = rsqs[rsqs.size()-2];
const auto rsq_l = rsqs[rsqs.size()-3];
// TODO: generalize 0.99
if (check_early_stop_rsq(rsq_l, rsq_m, rsq_u, rsq_slope_tol, rsq_curv_tol) ||
(rsqs.back() >= 0.99)) break;
(rsqs.back() >= rsq_tol)) break;
}

// check if any lambdas left to fit
Expand All @@ -613,10 +635,12 @@ inline void solve_basil(
// ====================================================================================
// Screening step
// ====================================================================================
sw.start();
naive::screen(
state,
lmda_batch[0]
);
benchmark_screen.push_back(sw.elapsed());

try {
// ====================================================================================
Expand All @@ -625,27 +649,32 @@ inline void solve_basil(
// Save all current valid quantities that will be modified in-place by fit.
// This is needed for the invariance step in case no valid solutions are found.
save_prev_valid();
sw.start();
auto&& state_pin_naive = naive::fit(
state,
lmda_batch,
update_coefficients_f,
check_user_interrupt
);
benchmark_fit.push_back(sw.elapsed());

// ====================================================================================
// KKT step
// ====================================================================================
grads.resize(lmda_batch.size());
for (auto& g : grads) g.resize(p);
sw.start();
const auto n_valid_solutions = naive::kkt(
state,
state_pin_naive,
grads
);
benchmark_kkt.push_back(sw.elapsed());

// ====================================================================================
// Invariance step
// ====================================================================================
sw.start();
lmda_path_idx += n_valid_solutions;
// If no valid solutions found, restore to the previous valid state,
// so that we can start over after screening more variables.
Expand Down Expand Up @@ -675,6 +704,7 @@ inline void solve_basil(
state_pin_naive,
n_valid_solutions
);
benchmark_invariance.push_back(sw.elapsed());
} catch (const std::exception& e) {
load_prev_valid();
throw util::propagator_error(e.what());
Expand Down
64 changes: 31 additions & 33 deletions adelie/src/include/adelie_core/solver/solve_pin_cov.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,24 +244,23 @@ void solve_pin_active(
ab_diff_view_curr = sb;
}

time_active_cd.push_back(0);
{
sw_t stopwatch(time_active_cd.back());
while (1) {
check_user_interrupt(iters);
++iters;
value_t convg_measure;
coordinate_descent(
state,
active_g1.data(), active_g1.data() + active_g1.size(),
active_g2.data(), active_g2.data() + active_g2.size(),
lmda_idx, convg_measure, buffer1, buffer2, buffer3, buffer4,
update_coefficients_f
);
if (convg_measure < tol) break;
if (iters >= max_iters) throw util::max_cds_error(lmda_idx);
}
sw_t stopwatch;
stopwatch.start();
while (1) {
check_user_interrupt(iters);
++iters;
value_t convg_measure;
coordinate_descent(
state,
active_g1.data(), active_g1.data() + active_g1.size(),
active_g2.data(), active_g2.data() + active_g2.size(),
lmda_idx, convg_measure, buffer1, buffer2, buffer3, buffer4,
update_coefficients_f
);
if (convg_measure < tol) break;
if (iters >= max_iters) throw util::max_cds_error(lmda_idx);
}
time_active_cd.push_back(stopwatch.elapsed());

// compute new active beta - old active beta
for (size_t i = 0; i < active_set.size(); ++i) {
Expand Down Expand Up @@ -422,22 +421,21 @@ inline void solve_pin(
++iters;
value_t convg_measure;
const auto old_active_size = active_set.size();
time_strong_cd.push_back(0);
{
sw_t stopwatch(time_strong_cd.back());
coordinate_descent(
state,
strong_g1.data(), strong_g1.data() + strong_g1.size(),
strong_g2.data(), strong_g2.data() + strong_g2.size(),
l, convg_measure,
buffer_pack.buffer1,
buffer_pack.buffer2,
buffer_pack.buffer3,
buffer_pack.buffer4,
update_coefficients_f,
add_active_set
);
}
sw_t stopwatch;
stopwatch.start();
coordinate_descent(
state,
strong_g1.data(), strong_g1.data() + strong_g1.size(),
strong_g2.data(), strong_g2.data() + strong_g2.size(),
l, convg_measure,
buffer_pack.buffer1,
buffer_pack.buffer2,
buffer_pack.buffer3,
buffer_pack.buffer4,
update_coefficients_f,
add_active_set
);
time_strong_cd.push_back(stopwatch.elapsed());
const bool new_active_added = (old_active_size < active_set.size());

if (new_active_added) {
Expand Down
Loading

0 comments on commit e4368f0

Please sign in to comment.