diff --git a/.DS_Store b/.DS_Store index 6ef3173..4abeee4 100644 Binary files a/.DS_Store and b/.DS_Store differ diff --git a/R/fit_task0.R b/R/fit_task0.R index 2289c46..9d24544 100644 --- a/R/fit_task0.R +++ b/R/fit_task0.R @@ -54,7 +54,7 @@ fit_breakpoints = function( x$metadata$breakpoints <- best_bp if (!(is.null(best_bp))) { - x$counts$group <- biPOD:::bp_to_groups(x$counts, x$metadata$breakpoints) + x$counts$group <- bp_to_groups(x$counts, x$metadata$breakpoints) } cli::cli_alert_info("Median of the inferred breakpoints have been succesfully stored.") @@ -129,7 +129,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t fit_with_breaks_opt <- function(breaks) { # Ensure necessary attributes are initialized - if(!all(((biPOD:::bp_to_groups(dplyr::tibble(time=x, count=y), break_points = breaks) %>% table()) >= avg_points_per_window))) return(Inf) + if(!all(((bp_to_groups(dplyr::tibble(time=x, count=y), break_points = breaks) %>% table()) >= avg_points_per_window))) return(Inf) A <- assemble_regression_matrix(c(min(x), breaks, max(x))) # Try to solve the regression problem @@ -206,7 +206,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t dplyr::distinct() for (j in 1:nrow(proposed_breakpoints)) { - if (!all((biPOD:::bp_to_groups(dplyr::tibble(time=x, count=y), unlist(proposed_breakpoints[j,]$best_bp)) %>% table()) >= avg_points_per_window)) { + if (!all((bp_to_groups(dplyr::tibble(time=x, count=y), unlist(proposed_breakpoints[j,]$best_bp)) %>% table()) >= avg_points_per_window)) { proposed_breakpoints[j,]$convergence <- F } } @@ -220,7 +220,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t return(list(best_bp=NULL, best_fit=NULL)) } - m <- biPOD:::get_model("fit_breakpoints") + m <- get_model("fit_breakpoints") cli::cli_alert_info("Breakpoints optimization...") fits <- list() plots <- list() @@ -297,7 +297,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t suppressWarnings(loo_comp <- loo::loo_compare(criterion)) loo_comp[,1] <- round(loo_comp[,1],1) - loo_comp <- loo_comp %>% as_tibble() %>% + loo_comp <- loo_comp %>% dplyr::as_tibble() %>% dplyr::mutate(model = rownames(loo_comp)) %>% dplyr::mutate(j = as.numeric(stringr::str_replace(rownames(loo_comp), pattern = "model", replacement = ""))) %>% dplyr::mutate(convergence = TRUE) @@ -308,12 +308,12 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t }) %>% unlist() #loo_comp <- loo_comp %>% dplyr::filter(convergence) %>% dplyr::filter(elpd_diff == max(elpd_diff)) - loo_comp <- loo_comp %>% dplyr::filter(convergence) %>% dplyr::arrange(-as.numeric(elpd_diff), -as.numeric(se_diff)) + loo_comp <- loo_comp %>% dplyr::filter(convergence) %>% dplyr::arrange(-as.numeric(.data$elpd_diff), -as.numeric(.data$se_diff)) best_js <- loo_comp$j } else if (model_selection %in% c("AIC", "BIC")) { comp <- dplyr::tibble(value = unlist(criterion), n_breakpoints = c(0, proposed_breakpoints$n_breakpoints)) %>% - dplyr::mutate(idx = row_number()) + dplyr::mutate(idx = dplyr::row_number()) best_js <- comp %>% dplyr::arrange(.data$value) %>% @@ -343,7 +343,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t } # Extra fit - m <- biPOD:::get_model("piecewise_changepoints") + m <- get_model("piecewise_changepoints") if (is.null(best_bp) | length(best_bp) == 0) { bp = array(0, dim = c(0)) @@ -366,7 +366,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t ) ) - final_fit <- biPOD:::convert_mcmc_fit_to_biPOD(f) + final_fit <- convert_mcmc_fit_to_biPOD(f) if (norm) { x <- d$time @@ -376,7 +376,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t if (length(bp)) { final_bp <- f$draws(variables = 'b', format = 'matrix') %>% dplyr::as_tibble() %>% - dplyr::summarise_all(median) %>% + dplyr::summarise_all(stats::median) %>% as.numeric() if (norm) { @@ -396,7 +396,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t } }) %>% unlist() - all_correct <- all((biPOD:::bp_to_groups(dplyr::tibble(time=x, count=y), unlist(final_bp)) %>% table()) >= avg_points_per_window) & mean(unlist(final_fit$rhat)) <= 1.1 + all_correct <- all((bp_to_groups(dplyr::tibble(time=x, count=y), unlist(final_bp)) %>% table()) >= avg_points_per_window) & mean(unlist(final_fit$rhat)) <= 1.1 j_idx <- j_idx + 1 } diff --git a/R/fit_task3.R b/R/fit_task3.R index efe11be..e047a12 100644 --- a/R/fit_task3.R +++ b/R/fit_task3.R @@ -44,9 +44,9 @@ fit_two_pop_model <- function( # Add results to bipod object x$two_pop_fit_elbo <- res$elbo_data - x$two_pop_fit <- convert_mcmc_fit_to_biPOD(res$fit) + x$metadata$chains = chains - if (sampling_type == "mcmc") { + if (sampling_type == "MCMC sampling") { x$metadata$status <- diagnose_mcmc_fit(res$fit) } else { x$metadata$status <- diagnose_variational_fit(res$fit, res$elbo_data) @@ -56,6 +56,124 @@ fit_two_pop_model <- function( x$metadata$sampling <- res$fit_info$sampling x$metadata$factor_size <- res$fit_info$factor_size + # Produce plots #### + ## Produce evo plot #### + best_fit <- res$fit + draws <- best_fit$draws(format = "df", variables = "mu") + + mu <- draws %>% + dplyr::as_tibble() %>% + dplyr::select(!c(.data$.chain, .data$.iteration, .data$.draw)) %>% + #select(starts_with("mu")) %>% + apply(2, stats::quantile, c(00.05, 0.5, 0.95)) %>% + t() %>% + data.frame(x = x$counts$time) %>% + tidyr::gather(pct, y, -x) + mu$y <- mu$y * factor_size + + evo_plot <- ggplot2::ggplot() + + ggplot2::geom_point(ggplot2::aes(x=.data$time, y=.data$count), data = x$counts, size = 1) + + ggplot2::geom_line(ggplot2::aes(x=.data$x, y=.data$y, linetype = .data$pct), data = mu, color = 'darkgreen') + + ggplot2::scale_linetype_manual(values = c(2,1,2)) + + ggplot2::guides(linetype = "none") + + ggplot2::theme_bw() + + ## Produce evo separated plot #### + draws <- best_fit$draws(format = "df", variables = "ns") + ns <- draws %>% + dplyr::as_tibble() %>% + dplyr::select(!c(.data$.chain, .data$.iteration, .data$.draw)) %>% + apply(2, stats::quantile, c(00.05, 0.5, 0.95)) %>% + t() %>% + data.frame(x = x$counts$time) %>% + tidyr::gather(pct, y, -x) %>% + dplyr::mutate(y = .data$y * factor_size) + + draws <- best_fit$draws(format = "df", variables = "nr") + nr <- draws %>% + dplyr::as_tibble() %>% + dplyr::select(!c(.data$.chain, .data$.iteration, .data$.draw)) %>% + apply(2, stats::quantile, c(00.05, 0.5, 0.95)) %>% + t() %>% + data.frame(x = x$counts$time) %>% + tidyr::gather(pct, y, -x) %>% + dplyr::mutate(y = .data$y * factor_size) + + evo_plot_separated <- ggplot2::ggplot() + + ggplot2::geom_point(ggplot2::aes(x=.data$time, y=.data$count), data = x$counts, size = 1) + + ggplot2::geom_line(ggplot2::aes(x=.data$x, y=.data$y, linetype = .data$pct), data = nr, color = 'steelblue3') + + ggplot2::geom_line(ggplot2::aes(x=.data$x, y=.data$y, linetype = .data$pct), data = ns, color = 'indianred') + + ggplot2::scale_linetype_manual(values = c(2,1,2)) + + ggplot2::guides(linetype = "none") + + ggplot2::theme_bw() + + ## Produce times plot #### + variables <- best_fit$summary()$variable + + times <- dplyr::tibble(x = best_fit$draws(format = "df", variables = "t0_r")[,1] %>% unlist() %>% as.numeric(), par = 't0_r') + t0_r_draws <- times$x + if ("t_end" %in% variables) { + times <- dplyr::bind_rows( + times, + dplyr::tibble(x = best_fit$draws(format = "df", variables = "t_end")[,1] %>% unlist() %>% as.numeric(), par = 't_end') + ) + } + + times_plot <- times %>% + dplyr::group_by(.data$par) %>% + dplyr::mutate(q_low = stats::quantile(x, .05), q_high = stats::quantile(x, .95)) %>% + dplyr::filter(x >= .data$q_low, x <= .data$q_high) %>% + ggplot2::ggplot(mapping = ggplot2::aes(x=.data$x, col=.data$par)) + + ggplot2::geom_density() + + ggplot2::labs(x = "Time", y="Density") + + ggplot2::geom_vline(xintercept = x$counts$time[1], linetype = "dashed") + + ggplot2::theme_bw() + + ## Produce plot of F_R #### + nr_first_obs <- best_fit$draws(format = "df", variables = "nr")[,1] %>% unlist() %>% as.numeric() + ns_first_obs <- best_fit$draws(format = "df", variables = "ns")[,1] %>% unlist() %>% as.numeric() + + nr_first_obs <- nr_first_obs[t0_r_draws <= 0] + ns_first_obs <- ns_first_obs[t0_r_draws <= 0] + + f_r = nr_first_obs / (nr_first_obs + ns_first_obs) + fr_plot <- dplyr::tibble(x = f_r) %>% + ggplot2::ggplot(mapping = ggplot2::aes(x=x)) + + ggplot2::geom_density() + + ggplot2::labs(x = "Resistant population fraction", y="Density") + + ggplot2::lims(x = c(0,1)) + + ggplot2::theme_bw() + + ## Rates plot #### + rho_r_draws <- best_fit$draws(format = "df", variables = "rho_r")[,1] %>% unlist() %>% as.numeric() + rates <- dplyr::tibble(x=best_fit$draws(format = "df", variables = "rho_r")[,1] %>% unlist() %>% as.numeric(), par = "rho_r") + if ("rho_s" %in% variables) { + rates <- dplyr::bind_rows( + rates, + dplyr::tibble(x=-(best_fit$draws(format = "df", variables = "rho_s")[,1] %>% unlist() %>% as.numeric()), par = "rho_s") + ) + } + + rates_plot <- rates %>% + dplyr::group_by(.data$par) %>% + dplyr::mutate(q_low = stats::quantile(x, .05), q_high = stats::quantile(x, .95)) %>% + dplyr::filter(x >= .data$q_low, x <= .data$q_high) %>% + ggplot2::ggplot(mapping = ggplot2::aes(x=.data$x, col=.data$par)) + + ggplot2::geom_density() + + ggplot2::geom_vline(xintercept = 0, linetype = "dashed") + + ggplot2::labs(x = "Growth rate", y="Density") + + ggplot2::theme_bw() + + x$two_pop_plots <- list( + evo_plot = evo_plot, + evo_plot_separated = evo_plot_separated, + rates_plot = rates_plot, + times_plot = times_plot, + fr_plot = fr_plot + ) + + x$two_pop_fit <- convert_mcmc_fit_to_biPOD(res$fit) + #x$two_pop_fit <- res$fit return(x) } @@ -73,22 +191,38 @@ prep_data_two_pop_fit <- function(x, factor_size) { fit_two_pop_data <- function(x, factor_size, variational, chains, iter, cores) { input_data <- prep_data_two_pop_fit(x = x, factor_size = factor_size) - model <- get_model(model_name = "two_pop") + #model <- get_model(model_name = "two_pop") + + m1 <- get_model("two_pop_single") + m2 <- get_model("two_pop_both") + tmp <- utils::capture.output(f1 <- m1$sample(input_data, parallel_chains = chains, iter_warmup = iter, iter_sampling = iter, chains = chains)) + tmp <- utils::capture.output(f2 <- m2$sample(input_data, parallel_chains = chains, iter_warmup = iter, iter_sampling = iter, chains = chains)) + + fits <- list(f1, f2) + + loo1 <- f1$loo() + loo2 <- f2$loo() + + loos <- loo::loo_compare(list(loo1, loo2)) + fit_model <- fits[[as.numeric(stringr::str_replace(rownames(loos)[1], "model", ""))]] + + sampling <- "mcmc" + elbo_d <- NULL # Fit with either MCMC or Variational - if (variational) { - sampling <- "variational" - res <- variational_fit(model = model, data = input_data, iter = iter) - fit_model <- res$fit_model - elbo_d <- res$elbo_d - } else { - sampling <- "mcmc" - tmp <- utils::capture.output( - suppressMessages( - fit_model <- model$sample(data = input_data, chains = chains, parallel_chains = cores, iter_warmup = iter, iter_sampling = iter, refresh = iter) - ) - ) - } + # if (variational) { + # sampling <- "variational" + # res <- variational_fit(model = model, data = input_data, iter = iter) + # fit_model <- res$fit_model + # elbo_d <- res$elbo_d + # } else { + # sampling <- "mcmc" + # tmp <- utils::capture.output( + # suppressMessages( + # fit_model <- model$sample(data = input_data, chains = chains, parallel_chains = cores, iter_warmup = iter, iter_sampling = iter, refresh = iter) + # ) + # ) + # } elbo_data <- c() if (variational) elbo_data <- elbo_d diff --git a/R/getter.R b/R/getter.R index 03375ec..931c823 100644 --- a/R/getter.R +++ b/R/getter.R @@ -9,7 +9,9 @@ get_model <- function(model_name) { "two_pop" = "two_population.stan", "piecewise_changepoints" = "piecewise_linear_regression.stan", "pw_lin_fixed_b" = "pw_linear_b_fixed.stan", - "fit_breakpoints" = "fit_breakpoints.stan" + "fit_breakpoints" = "fit_breakpoints.stan", + "two_pop_both" = 'two_pop_both_v2.stan', + "two_pop_single" = 'two_pop_single.stan' ) if (!(model_name) %in% names(all_paths)) stop("model_name not recognized") diff --git a/R/plot_task3.R b/R/plot_task3.R index f15a808..b7510fa 100644 --- a/R/plot_task3.R +++ b/R/plot_task3.R @@ -29,7 +29,7 @@ plot_two_pop_fit <- function( #if (add_posteriors & !(split_process)) stop("Posteriors can be only added if 'split_process' = TRUE") alpha <- 1 - CI - fitted_data <- biPOD:::get_data_for_two_pop_plot(x, alpha = alpha) + fitted_data <- get_data_for_two_pop_plot(x, alpha = alpha) fitted_data <- dplyr::bind_rows(fitted_data, fitted_data %>% @@ -38,14 +38,14 @@ plot_two_pop_fit <- function( dplyr::summarise_all(sum) %>% dplyr::mutate(group = 'total')) - times <- x$counts$time + times <- x$counts$time p <- ggplot2::ggplot() + ggplot2::geom_point(x$counts, mapping = ggplot2::aes(x = .data$time, y = .data$count)) + # original points ggplot2::geom_line(fitted_data %>% dplyr::filter(.data$group == "total", .data$x >= min(times), .data$x <= max(times)), mapping = ggplot2::aes(x = .data$x, y = .data$y), col = "black") + - biPOD:::my_ggplot_theme() + my_ggplot_theme() if (split_process) { p <- p + @@ -58,56 +58,56 @@ plot_two_pop_fit <- function( plots <- list(p = p) if (t_posteriors) { - t_data <- biPOD:::get_parameters(x$two_pop_fit, par_list = c('t0_r', "t_end")) %>% - dplyr::mutate(parameter = if_else(parameter == "t0_r", "t_r", 't_e')) - + t_data <- get_parameters(x$two_pop_fit, par_list = c('t0_r', "t_end")) %>% + dplyr::mutate(parameter = dplyr::if_else(.data$parameter == "t0_r", "t_r", 't_e')) time_limits <- c(min(t_data$value, min(fitted_data$x)), max(t_data$value, max(fitted_data$x))) - p <- p + lims(x = time_limits) + p <- p + ggplot2::lims(x = time_limits) t_plot <- t_data%>% - ggplot(mapping = aes(x=value, fill=parameter)) + - geom_histogram(alpha = .7, binwidth = 0.005) + + ggplot2::ggplot(mapping = ggplot2::aes(x=.data$value, fill=.data$parameter)) + + ggplot2::geom_histogram(alpha = .7, binwidth = 0.005) + #scale_color_manual(values = c("t_e"=sensitive_color, "t_r"=resistant_color)) + - scale_fill_manual(values = c("t_e"=sensitive_color, "t_r"=resistant_color)) + - biPOD:::my_ggplot_theme() + - labs(fill = "", col="", x="time") + - lims(x = time_limits) + ggplot2::scale_fill_manual(values = c("t_e"=sensitive_color, "t_r"=resistant_color)) + + my_ggplot_theme() + + ggplot2::labs(fill = "", col="", x="time") + + ggplot2::lims(x = time_limits) plots = list(p = p, t_plot=t_plot) } if (r_posteriors) { - d <- biPOD:::get_parameters(x$two_pop_fit, par_list = c('rho_r', "rho_s")) %>% - dplyr::mutate(value = ifelse(parameter == "rho_s", -value, value)) + d <- get_parameters(x$two_pop_fit, par_list = c('rho_r', "rho_s")) %>% + dplyr::mutate(value = ifelse(.data$parameter == "rho_s", -.data$value, .data$value)) rho_plot <- d %>% - ggplot(mapping = aes(x=value,fill=parameter)) + + ggplot2::ggplot(mapping = ggplot2::aes(x=.data$value,fill=.data$parameter)) + #geom_density(alpha = .3) + - geom_histogram(alpha = .7, binwidth = 0.005) + + ggplot2::geom_histogram(alpha = .7, binwidth = 0.005) + #scale_color_manual(values = c("rho_s"=sensitive_color, "rho_r"=resistant_color)) + - scale_fill_manual(values = c("rho_s"=sensitive_color, "rho_r"=resistant_color)) + - biPOD:::my_ggplot_theme() + - labs(fill = "", col="", x="Growth rate") + - geom_vline(xintercept = 0, linetype='dashed') + ggplot2::scale_fill_manual(values = c("rho_s"=sensitive_color, "rho_r"=resistant_color)) + + my_ggplot_theme() + + ggplot2::labs(fill = "", col="", x="Growth rate") + + ggplot2::geom_vline(xintercept = 0, linetype='dashed') plots$rho_plot = rho_plot } if (f_posteriors) { - d <- biPOD:::get_parameters(x$two_pop_fit, par_list = c('f_s')) %>% - dplyr::mutate(value = 1 - value, parameter = "f_r") + ns <- get_parameters(x$two_pop_fit, par_list = c('ns[1]')) %>% dplyr::pull(.data$value) + nr <- get_parameters(x$two_pop_fit, par_list = c('nr[1]')) %>% dplyr::pull(.data$value) + d <- dplyr::tibble(value = nr / (nr + ns), parameter = "f_r") f_plot <- d %>% - ggplot(mapping = aes(x=value,fill=parameter)) + + ggplot2::ggplot(mapping = ggplot2::aes(x=.data$value,fill=.data$parameter)) + #geom_density(alpha = .3) + - geom_histogram(alpha = .7, binwidth = 0.01) + + ggplot2::geom_histogram(alpha = .7, binwidth = 0.01) + #scale_color_manual(values = c("rho_s"=sensitive_color, "rho_r"=resistant_color)) + - scale_fill_manual(values = c("f_r"=resistant_color)) + - biPOD:::my_ggplot_theme() + - labs(fill = "", col="", x="Resistant population fraction") + - lims(x = c(-0.05,1.05)) + ggplot2::scale_fill_manual(values = c("f_r"=resistant_color)) + + my_ggplot_theme() + + ggplot2::labs(fill = "", col="", x="Resistant population fraction") + + ggplot2::lims(x = c(-0.05,1.05)) plots$f_plot = f_plot } @@ -129,7 +129,7 @@ get_data_for_two_pop_plot <- function(x, alpha) { rho_quantiles <- rho_samples %>% dplyr::group_by(.data$parameter) %>% - dplyr::summarise(low = stats::quantile(.data$value, alpha / 2), mid = stats::quantile(.data$value, .5), high = stats::quantile(.data$value, 1 - alpha / 2)) + dplyr::summarise(low = stats::quantile(.data$value, alpha / 2, na.rm=TRUE), mid = stats::quantile(.data$value, .5, na.rm=TRUE), high = stats::quantile(.data$value, 1 - alpha / 2, na.rm=TRUE)) rho_s_quantiles <- rho_quantiles %>% dplyr::filter(.data$parameter == "rho_s") rho_r_quantiles <- rho_quantiles %>% dplyr::filter(.data$parameter == "rho_r") @@ -146,15 +146,15 @@ get_data_for_two_pop_plot <- function(x, alpha) { xs <- seq(min_t, max_t, length = 2000) - median_ns <- get_parameter(fit, "ns") %>% + median_ns <- get_parameter(fit, "ns[1]") %>% dplyr::pull(.data$value) %>% stats::median() func <- two_pops_evo - ylow <- lapply(xs, two_pops_evo, ns = median_ns, t0_r = median_t0_r, rho_s = -rho_s_quantiles$low, rho_r = rho_r_quantiles$low) - ymid <- lapply(xs, two_pops_evo, ns = median_ns, t0_r = median_t0_r, rho_s = -rho_s_quantiles$mid, rho_r = rho_r_quantiles$mid) - yhigh <- lapply(xs, two_pops_evo, ns = median_ns, t0_r = median_t0_r, rho_s = -rho_s_quantiles$high, rho_r = rho_r_quantiles$high) + ylow <- lapply(xs, func, ns = median_ns, t0_r = median_t0_r, rho_s = -rho_s_quantiles$low, rho_r = rho_r_quantiles$low) + ymid <- lapply(xs, func, ns = median_ns, t0_r = median_t0_r, rho_s = -rho_s_quantiles$mid, rho_r = rho_r_quantiles$mid) + yhigh <- lapply(xs, func, ns = median_ns, t0_r = median_t0_r, rho_s = -rho_s_quantiles$high, rho_r = rho_r_quantiles$high) ylow_r <- lapply(ylow, function(y) { y$r_pop diff --git a/R/print.bipod.R b/R/print.bipod.R index 722f44d..e9c9cee 100644 --- a/R/print.bipod.R +++ b/R/print.bipod.R @@ -103,7 +103,7 @@ print.bipod = function(x, ...) { ) ) - x$two_pop_fit + #x$two_pop_fit if (x$metadata$factor_size != 1) { cli::cli_alert_info(paste0(" Scale factor : ", crayon::blue(x$metadata$factor_size), ". Instant of birth might be incorrect.")) @@ -113,7 +113,9 @@ print.bipod = function(x, ...) { cat("\n") cli::cli_alert_info(" Inferred parameters") - par_tibble <- lapply(x$two_pop_fit$parameters, function(p) { + par_list <- c("lp__", "rho_r", "rho_s", "t0_r", "t_end") + + par_tibble <- lapply(par_list, function(p) { draws <- x$two_pop_fit$draws[,grepl(p, colnames(x$two_pop_fit$draws), fixed = T)] %>% as.vector() %>% unlist() dplyr::tibble(Parameter = p, Mean = mean(draws), Sd = stats::sd(draws), p05 = stats::quantile(draws, .05), p50 = stats::quantile(draws, .50), p95 = stats::quantile(draws, .95)) }) %>% do.call(dplyr::bind_rows, .) diff --git a/R/utils.R b/R/utils.R index 393523b..0245443 100644 --- a/R/utils.R +++ b/R/utils.R @@ -80,13 +80,14 @@ two_pops_evo <- function(t, ns, t0_r, rho_s, rho_r) { diagnose_mcmc_fit <- function(fit) { all_pars <- colnames(fit$draws(format = "matrix")) - n_chains <- ncol(fit$draws("lp__")) status <- "PASS" for (par in all_pars) { rhat <- posterior::rhat(fit$draws(par)) - if (rhat > 1.01) { + if (is.na(rhat)) { + status <- "PASS" + } else if (rhat > 1.01) { status <- "FAIL" break() } diff --git a/inst/.DS_Store b/inst/.DS_Store new file mode 100644 index 0000000..9315ccb Binary files /dev/null and b/inst/.DS_Store differ diff --git a/inst/cmdstan/two_pop_both_v2.stan b/inst/cmdstan/two_pop_both_v2.stan new file mode 100644 index 0000000..5960406 --- /dev/null +++ b/inst/cmdstan/two_pop_both_v2.stan @@ -0,0 +1,70 @@ +data { + int S; // Number of steps + array[S] int N; // observations + array[S] real T; // observations +} + +parameters { + real rho_r; // Parameter rho_r (rate for recoverN) + real rho_s; // Parameter rho_s (rate for signal decaN) + real t0_r; // Parameter t_r (time shift) + real t_end; + // real f_s; +} + +model { + vector[S] mu; // Expected values for N given x + vector[S] ns; + vector[S] nr; + + // Priors + // n0 ~ normal(N[1], 0.1 * N[1]); + rho_r ~ normal(0, 1); // Prior for rho_r + rho_s ~ normal(0, 1); // Prior for rho_s + t0_r ~ normal(T[1], 1); // Prior for t_r + t_end ~ normal(T[S], 1); + + for (i in 1:S) { + if (T[i] >= t_end) { + ns[i] = 1e-6; + } else { + ns[i] = exp(-rho_s * (T[i] - t_end)); + } + + if (T[i] >= t0_r) { + nr[i] = exp(rho_r * (T[i] - t0_r)); + } else { + nr[i] = 1e-6; + } + + mu[i] = nr[i] + ns[i]; + } + + // Likelihood (assuming normallN distributed noise) + N ~ poisson(mu); +} + +generated quantities { + vector[S] log_lik; // Log-likelihood for each observation + vector[S] ns; + vector[S] nr; + vector[S] mu; // Expected values for N given x + + for (i in 1:S) { + if (T[i] >= t_end) { + ns[i] = 1e-6; + } else { + ns[i] = exp(-rho_s * (T[i] - t_end)); + } + + if (T[i] >= t0_r) { + nr[i] = exp(rho_r * (T[i] - t0_r)); + } else { + nr[i] = 1e-6; + } + mu[i] = nr[i] + ns[i]; + log_lik[i] = poisson_lpmf(N[i] | mu[i]); // Log-likelihood calculation + } +} + + diff --git a/inst/cmdstan/two_pop_single.stan b/inst/cmdstan/two_pop_single.stan new file mode 100644 index 0000000..cf8c4a3 --- /dev/null +++ b/inst/cmdstan/two_pop_single.stan @@ -0,0 +1,44 @@ +data { + int S; // Number of steps + array[S] int N; // observations + array[S] real T; // observations +} + +parameters { + real rho_r; // Parameter rho_r (rate for recover) + real t0_r; // Parameter t_r (time shift) +} + +model { + vector[S] mu; // Expected values for N given x + vector[S] ns; + vector[S] nr; + + // Define the expected value based on the given equation + for (i in 1:S) { + nr[i] = exp(rho_r * (T[i] - t0_r)); + ns[i] = 0.0; + mu[i] = nr[i] + ns[i]; + } + + // Priors + rho_r ~ normal(0, 10); // Prior for rho_r + t0_r ~ normal(0, 10); // Prior for t_r + + // Likelihood (assuming normallN distributed noise) + N ~ poisson(mu); +} + +generated quantities { + vector[S] log_lik; // Log-likelihood for each observation + vector[S] mu; // Expected values for N given x + vector[S] ns; + vector[S] nr; + + for (i in 1:S) { + nr[i] = exp(rho_r * (T[i] - t0_r)); + ns[i] = 0.0; + mu[i] = nr[i] + ns[i]; + log_lik[i] = poisson_lpmf(N[i] | mu[i]); // Log-likelihood calculation + } +} diff --git a/vignettes/a4_task3.Rmd b/vignettes/a4_task3.Rmd index 101f074..2fa73d8 100644 --- a/vignettes/a4_task3.Rmd +++ b/vignettes/a4_task3.Rmd @@ -63,11 +63,11 @@ x <- biPOD::fit_two_pop_model(x, variational = F, factor_size = 1) and the fit can be visualized as a single process ```{r} -biPOD::plot_two_pop_fit(x) +biPOD::plot_two_pop_fit(x, split_process = F, f_posteriors = F, t_posteriors = F, r_posteriors = F) ``` or splitting it depending on the two different populations ```{r} -biPOD::plot_two_pop_fit(x, split_process = T) +biPOD::plot_two_pop_fit(x, split_process = T, f_posteriors = F, t_posteriors = F, r_posteriors = F) ```