From d16b23ff95a02b2fb1f88f5b6faa4867fa0748ff Mon Sep 17 00:00:00 2001 From: Zilong-Li Date: Wed, 7 Feb 2024 12:54:31 +0100 Subject: [PATCH] fix post(z,g) for B>1 --- src/common.hpp | 28 ++++++++++++++++++++++++++-- src/fastphase.cpp | 12 +++++++----- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/src/common.hpp b/src/common.hpp index 613c84f..76cb3f7 100644 --- a/src/common.hpp +++ b/src/common.hpp @@ -348,11 +348,35 @@ inline MyArr2D get_emission_by_gl(const MyArr2D & gli, const MyArr2D & P, double } } } - emit = emit.colwise() / emit.rowwise().maxCoeff(); // normalize + // emit = emit.colwise() / emit.rowwise().maxCoeff(); // normalize emit = (emit < minEmission).select(minEmission, emit); return emit.transpose(); } +inline MyArr1D get_emission_by_site(const MyArr1D & gli, const MyArr1D & P, double minEmission = 1e-10) +{ + const int C = P.size(); + MyArr1D emit = MyArr1D::Zero(C * C); + int z1, z2, z12, g1, g2; + for(z1 = 0; z1 < C; z1++) + { + for(z2 = 0; z2 < C; z2++) + { + z12 = z1 * C + z2; + for(g1 = 0; g1 <= 1; g1++) + { + for(g2 = 0; g2 <= 1; g2++) + { + emit(z12) += + gli(g1 + g2) * (g1 * P(z1) + (1 - g1) * (1 - P(z1))) * (g2 * P(z2) + (1 - g2) * (1 - P(z2))); + } + } + } + } + emit = (emit < minEmission).select(minEmission, emit); + return emit; +} + /* ** @param gli genotype likelihoods of current individual i, (M, 3) ** @param P cluster-specific allele frequence (M, C) @@ -392,7 +416,7 @@ inline MyArr2D get_emission_by_grid(const MyArr2D & gli, } } // apply bounding - emit.col(g) /= emit.col(g).maxCoeff(); + emit.col(g) /= emit.col(g).sum(); emit.col(g) = (emit.col(g) < minEmission).select(minEmission, emit.col(g)); } return emit; diff --git a/src/fastphase.cpp b/src/fastphase.cpp index 560f3ea..2e468ad 100644 --- a/src/fastphase.cpp +++ b/src/fastphase.cpp @@ -183,27 +183,29 @@ double FastPhaseK2::hmmIterWithJumps(const MyFloat1D & GL, const int ic, const i start = grid_chunk[ic]; nsize = nGrids; } - MyArr2D emit = get_emission_by_grid(gli, P.middleRows(pos_chunk[ic], S), collapse.segment(pos_chunk[ic], S)); + MyArr2D emit_grid = get_emission_by_grid(gli, P.middleRows(pos_chunk[ic], S), collapse.segment(pos_chunk[ic], S)); const auto [alpha, beta, cs] = - forward_backwards_diploid(emit, R.middleCols(start, nsize), PI.middleCols(start, nsize)); + forward_backwards_diploid(emit_grid, R.middleCols(start, nsize), PI.middleCols(start, nsize)); if(!((1 - ((alpha * beta).colwise().sum())).abs() < 1e-9).all()) cao.error((alpha * beta).colwise().sum(), "\ngamma sum is not 1.0!\n"); // now get posterios MyArr2D ind_post_zg1(C, S), ind_post_zg2(C, S), ind_post_zj(C, nGrids), gammaC(C, nGrids); - MyArr1D gamma_div_emit(CC), beta_mult_emit(CC); + MyArr1D gamma_div_emit(CC), beta_mult_emit(CC), igamma(CC); MyArr1D alphatmp(C); int z1, m, s, g{0}, gg{0}; const auto se = find_grid_start_end(collapse.segment(pos_chunk[ic], S)); for(g = 0; g < nGrids; g++) { gg = g + grid_chunk[ic]; - gamma_div_emit = (alpha.col(g) * beta.col(g)) / emit.col(g); // C2 gammaC.col(g) = (alpha.col(g) * beta.col(g)).reshaped(C, C).colwise().sum(); + igamma = alpha.col(g) * beta.col(g); + gamma_div_emit = igamma / emit_grid.col(g); // C2 for(z1 = 0; z1 < C; z1++) { for(s = se[g][0]; s <= se[g][1]; s++) { m = s + pos_chunk[ic]; + if(B > 1) gamma_div_emit = igamma / get_emission_by_site(gli.row(s), P.row(m)); ind_post_zg1(z1, s) = (gamma_div_emit(Eigen::seqN(z1, C, C)) * (1 - P(m, z1)) * (gli(s, 0) * (1 - P.row(m)) + gli(s, 1) * P.row(m)).transpose()) .sum(); @@ -217,7 +219,7 @@ double FastPhaseK2::hmmIterWithJumps(const MyFloat1D & GL, const int ic, const i } if(g == 0) continue; alphatmp += PI.col(gg) * R(2, gg) * 1.0; // inner alpha.col(s-1).sum == 1 - beta_mult_emit = emit.col(g) * beta.col(g); // C2 + beta_mult_emit = emit_grid.col(g) * beta.col(g); // C2 for(z1 = 0; z1 < C; z1++) ind_post_zj(z1, g) = cs(g) * (PI(z1, gg) * alphatmp * beta_mult_emit(Eigen::seqN(z1, C, C))).sum(); }