diff --git a/src/admixture.cpp b/src/admixture.cpp index d8e2ec9..80d1de7 100644 --- a/src/admixture.cpp +++ b/src/admixture.cpp @@ -245,7 +245,7 @@ int run_admix_main(Options & opts) if(!opts.noaccel) { MyArr2D F0, Q0, F1, Q1, F2, Q2, Ft, Qt; - const int istep{4}; + const double istep{4}; double alpha{std::numeric_limits::lowest()}, qdiff, ldiff, stepMax{4}, alphaMax{1280}; double prevlike{std::numeric_limits::lowest()}, logcheck{0}, loglike{0}; for(int it = 0; SIG_COND && (it < opts.nadmix / 4); it++) @@ -282,9 +282,12 @@ int run_admix_main(Options & opts) opts.ltol); break; } - // save for later comparison - Ft = admixer.F; - Qt = admixer.Q; + if(!opts.force) + { + // save for later comparison + Ft = admixer.F; + Qt = admixer.Q; + } // accel iteration with steplen alpha = ((F1 - F0).square().sum() + (Q1 - Q0).square().sum()) / ((admixer.F - 2 * F1 + F0).square().sum() + (admixer.Q - 2 * Q1 + Q0).square().sum()); @@ -307,29 +310,32 @@ int run_admix_main(Options & opts) for(auto && ll : llike) loglike += ll.get(); llike.clear(); // clear future and renew admixer.updateIteration(); - // save current pars - F2 = admixer.F; - Q2 = admixer.Q; - // check if normal third iter is better - admixer.Q = Qt; - admixer.F = Ft; - admixer.initIteration(); - for(int i = 0; i < genome->nsamples; i++) - llike.emplace_back(poolit.enqueue(&Admixture::runOptimalWithBigAss, &admixer, i, std::ref(genome))); - logcheck = 0; - for(auto && ll : llike) logcheck += ll.get(); - llike.clear(); // clear future and renew - admixer.updateIteration(); - if(logcheck - loglike > 0.1) - { - stepMax = istep; - cao.warn(tim.date(), "reset stepMax to 4, normal EM yields better likelihoods than the accelerated EM.", - logcheck, " -", loglike, " > 0.1"); - } - else - { - admixer.Q = Q2; - admixer.F = F2; + if(!opts.force) + { // save current pars + F2 = admixer.F; + Q2 = admixer.Q; + // check if normal third iter is better + admixer.Q = Qt; + admixer.F = Ft; + admixer.initIteration(); + for(int i = 0; i < genome->nsamples; i++) + llike.emplace_back(poolit.enqueue(&Admixture::runOptimalWithBigAss, &admixer, i, std::ref(genome))); + logcheck = 0; + for(auto && ll : llike) logcheck += ll.get(); + llike.clear(); // clear future and renew + admixer.updateIteration(); + if(logcheck - loglike > 0.1) + { + stepMax = istep; + cao.warn(tim.date(), + "reset stepMax to 4, normal EM yields better likelihoods than the accelerated EM.", + logcheck, " -", loglike, " > 0.1"); + } + else + { + admixer.Q = Q2; + admixer.F = F2; + } } } } diff --git a/src/common.hpp b/src/common.hpp index 76cb3f7..c9749be 100644 --- a/src/common.hpp +++ b/src/common.hpp @@ -103,7 +103,7 @@ struct Options double ftol{1e-6}; // threshold for F double qtol{1e-6}; // threshold for Q bool noaccel{0}, noscreen{0}, single_chunk{0}, debug{0}, collapse{0}; - bool nQ{0}, nP{0}, nF{0}, nR{0}, aQ{0}, oVCF{0}, eHap{0}, oF{0}, cF{0}; + bool nQ{0}, nP{0}, nF{0}, nR{0}, aQ{0}, oVCF{0}, eHap{0}, oF{0}, cF{0}, force{0}; std::string out, in_beagle, in_vcf, in_bin, in_impute, in_joint; std::string samples{""}, region{""}, in_plink{""}, in_qfile{""}, in_pfile{""}, in_rfile{""}; std::string opts_in_effect{"Options in effect:\n "}; @@ -274,7 +274,7 @@ inline Int2D split_pos_into_grid(const Int1D & pos, const Bool1D & collapse) inline Int1D calc_grid_distance(const Int1D & pos, const Bool1D & collapse) { - assert(pos.size() == collapse.size()); + assert((int)pos.size() == (int)collapse.size()); // B = 1 if((collapse == true).count() == 0) return calc_position_distance(pos); // B > 1, split pos into grids diff --git a/src/main.cpp b/src/main.cpp index 3b0f6a1..28dc050 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -211,6 +211,9 @@ int main(int argc, char * argv[]) .help("seed for reproducibility") .default_value(999) .scan<'i', int>(); + cmd_admix.add_argument("-f", "--force-accept") + .help("always accept the acceleration solution") + .flag(); cmd_admix.add_argument("-F", "--constrain-F") .help("apply constraint on F so that it is not smaller than cluster frequency in fastphase model") .flag(); @@ -312,6 +315,7 @@ int main(int argc, char * argv[]) opts.nthreads = cmd_admix.get("--threads"); opts.nadmix = cmd_admix.get("--iterations"); opts.cF = cmd_admix.get("--constrain-F"); + opts.force = cmd_admix.get("--force-accept"); if(opts.in_bin.empty() || cmd_admix.get("--help")) throw std::runtime_error(cmd_admix.help().str()); run_admix_main(opts); }