Skip to content

Commit

Permalink
Added model selection to third task
Browse files Browse the repository at this point in the history
  • Loading branch information
jovoni committed Sep 18, 2024
1 parent 0bc1c03 commit 5995095
Show file tree
Hide file tree
Showing 11 changed files with 321 additions and 68 deletions.
Binary file modified .DS_Store
Binary file not shown.
22 changes: 11 additions & 11 deletions R/fit_task0.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ fit_breakpoints = function(
x$metadata$breakpoints <- best_bp

if (!(is.null(best_bp))) {
x$counts$group <- biPOD:::bp_to_groups(x$counts, x$metadata$breakpoints)
x$counts$group <- bp_to_groups(x$counts, x$metadata$breakpoints)
}
cli::cli_alert_info("Median of the inferred breakpoints have been succesfully stored.")

Expand Down Expand Up @@ -129,7 +129,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t

fit_with_breaks_opt <- function(breaks) {
# Ensure necessary attributes are initialized
if(!all(((biPOD:::bp_to_groups(dplyr::tibble(time=x, count=y), break_points = breaks) %>% table()) >= avg_points_per_window))) return(Inf)
if(!all(((bp_to_groups(dplyr::tibble(time=x, count=y), break_points = breaks) %>% table()) >= avg_points_per_window))) return(Inf)
A <- assemble_regression_matrix(c(min(x), breaks, max(x)))

# Try to solve the regression problem
Expand Down Expand Up @@ -206,7 +206,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t
dplyr::distinct()

for (j in 1:nrow(proposed_breakpoints)) {
if (!all((biPOD:::bp_to_groups(dplyr::tibble(time=x, count=y), unlist(proposed_breakpoints[j,]$best_bp)) %>% table()) >= avg_points_per_window)) {
if (!all((bp_to_groups(dplyr::tibble(time=x, count=y), unlist(proposed_breakpoints[j,]$best_bp)) %>% table()) >= avg_points_per_window)) {
proposed_breakpoints[j,]$convergence <- F
}
}
Expand All @@ -220,7 +220,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t
return(list(best_bp=NULL, best_fit=NULL))
}

m <- biPOD:::get_model("fit_breakpoints")
m <- get_model("fit_breakpoints")
cli::cli_alert_info("Breakpoints optimization...")
fits <- list()
plots <- list()
Expand Down Expand Up @@ -297,7 +297,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t
suppressWarnings(loo_comp <- loo::loo_compare(criterion))
loo_comp[,1] <- round(loo_comp[,1],1)

loo_comp <- loo_comp %>% as_tibble() %>%
loo_comp <- loo_comp %>% dplyr::as_tibble() %>%
dplyr::mutate(model = rownames(loo_comp)) %>%
dplyr::mutate(j = as.numeric(stringr::str_replace(rownames(loo_comp), pattern = "model", replacement = ""))) %>%
dplyr::mutate(convergence = TRUE)
Expand All @@ -308,12 +308,12 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t
}) %>% unlist()

#loo_comp <- loo_comp %>% dplyr::filter(convergence) %>% dplyr::filter(elpd_diff == max(elpd_diff))
loo_comp <- loo_comp %>% dplyr::filter(convergence) %>% dplyr::arrange(-as.numeric(elpd_diff), -as.numeric(se_diff))
loo_comp <- loo_comp %>% dplyr::filter(convergence) %>% dplyr::arrange(-as.numeric(.data$elpd_diff), -as.numeric(.data$se_diff))
best_js <- loo_comp$j
} else if (model_selection %in% c("AIC", "BIC")) {

comp <- dplyr::tibble(value = unlist(criterion), n_breakpoints = c(0, proposed_breakpoints$n_breakpoints)) %>%
dplyr::mutate(idx = row_number())
dplyr::mutate(idx = dplyr::row_number())

best_js <- comp %>%
dplyr::arrange(.data$value) %>%
Expand Down Expand Up @@ -343,7 +343,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t
}

# Extra fit
m <- biPOD:::get_model("piecewise_changepoints")
m <- get_model("piecewise_changepoints")

if (is.null(best_bp) | length(best_bp) == 0) {
bp = array(0, dim = c(0))
Expand All @@ -366,7 +366,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t
)
)

final_fit <- biPOD:::convert_mcmc_fit_to_biPOD(f)
final_fit <- convert_mcmc_fit_to_biPOD(f)

if (norm) {
x <- d$time
Expand All @@ -376,7 +376,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t
if (length(bp)) {
final_bp <- f$draws(variables = 'b', format = 'matrix') %>%
dplyr::as_tibble() %>%
dplyr::summarise_all(median) %>%
dplyr::summarise_all(stats::median) %>%
as.numeric()

if (norm) {
Expand All @@ -396,7 +396,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t
}
}) %>% unlist()

all_correct <- all((biPOD:::bp_to_groups(dplyr::tibble(time=x, count=y), unlist(final_bp)) %>% table()) >= avg_points_per_window) & mean(unlist(final_fit$rhat)) <= 1.1
all_correct <- all((bp_to_groups(dplyr::tibble(time=x, count=y), unlist(final_bp)) %>% table()) >= avg_points_per_window) & mean(unlist(final_fit$rhat)) <= 1.1
j_idx <- j_idx + 1
}

Expand Down
166 changes: 150 additions & 16 deletions R/fit_task3.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ fit_two_pop_model <- function(

# Add results to bipod object
x$two_pop_fit_elbo <- res$elbo_data
x$two_pop_fit <- convert_mcmc_fit_to_biPOD(res$fit)
x$metadata$chains = chains

if (sampling_type == "mcmc") {
if (sampling_type == "MCMC sampling") {
x$metadata$status <- diagnose_mcmc_fit(res$fit)
} else {
x$metadata$status <- diagnose_variational_fit(res$fit, res$elbo_data)
Expand All @@ -56,6 +56,124 @@ fit_two_pop_model <- function(
x$metadata$sampling <- res$fit_info$sampling
x$metadata$factor_size <- res$fit_info$factor_size

# Produce plots ####
## Produce evo plot ####
best_fit <- res$fit
draws <- best_fit$draws(format = "df", variables = "mu")

mu <- draws %>%
dplyr::as_tibble() %>%
dplyr::select(!c(.data$.chain, .data$.iteration, .data$.draw)) %>%
#select(starts_with("mu")) %>%
apply(2, stats::quantile, c(00.05, 0.5, 0.95)) %>%
t() %>%
data.frame(x = x$counts$time) %>%
tidyr::gather(pct, y, -x)
mu$y <- mu$y * factor_size

evo_plot <- ggplot2::ggplot() +
ggplot2::geom_point(ggplot2::aes(x=.data$time, y=.data$count), data = x$counts, size = 1) +
ggplot2::geom_line(ggplot2::aes(x=.data$x, y=.data$y, linetype = .data$pct), data = mu, color = 'darkgreen') +
ggplot2::scale_linetype_manual(values = c(2,1,2)) +
ggplot2::guides(linetype = "none") +
ggplot2::theme_bw()

## Produce evo separated plot ####
draws <- best_fit$draws(format = "df", variables = "ns")
ns <- draws %>%
dplyr::as_tibble() %>%
dplyr::select(!c(.data$.chain, .data$.iteration, .data$.draw)) %>%
apply(2, stats::quantile, c(00.05, 0.5, 0.95)) %>%
t() %>%
data.frame(x = x$counts$time) %>%
tidyr::gather(pct, y, -x) %>%
dplyr::mutate(y = .data$y * factor_size)

draws <- best_fit$draws(format = "df", variables = "nr")
nr <- draws %>%
dplyr::as_tibble() %>%
dplyr::select(!c(.data$.chain, .data$.iteration, .data$.draw)) %>%
apply(2, stats::quantile, c(00.05, 0.5, 0.95)) %>%
t() %>%
data.frame(x = x$counts$time) %>%
tidyr::gather(pct, y, -x) %>%
dplyr::mutate(y = .data$y * factor_size)

evo_plot_separated <- ggplot2::ggplot() +
ggplot2::geom_point(ggplot2::aes(x=.data$time, y=.data$count), data = x$counts, size = 1) +
ggplot2::geom_line(ggplot2::aes(x=.data$x, y=.data$y, linetype = .data$pct), data = nr, color = 'steelblue3') +
ggplot2::geom_line(ggplot2::aes(x=.data$x, y=.data$y, linetype = .data$pct), data = ns, color = 'indianred') +
ggplot2::scale_linetype_manual(values = c(2,1,2)) +
ggplot2::guides(linetype = "none") +
ggplot2::theme_bw()

## Produce times plot ####
variables <- best_fit$summary()$variable

times <- dplyr::tibble(x = best_fit$draws(format = "df", variables = "t0_r")[,1] %>% unlist() %>% as.numeric(), par = 't0_r')
t0_r_draws <- times$x
if ("t_end" %in% variables) {
times <- dplyr::bind_rows(
times,
dplyr::tibble(x = best_fit$draws(format = "df", variables = "t_end")[,1] %>% unlist() %>% as.numeric(), par = 't_end')
)
}

times_plot <- times %>%
dplyr::group_by(.data$par) %>%
dplyr::mutate(q_low = stats::quantile(x, .05), q_high = stats::quantile(x, .95)) %>%
dplyr::filter(x >= .data$q_low, x <= .data$q_high) %>%
ggplot2::ggplot(mapping = ggplot2::aes(x=.data$x, col=.data$par)) +
ggplot2::geom_density() +
ggplot2::labs(x = "Time", y="Density") +
ggplot2::geom_vline(xintercept = x$counts$time[1], linetype = "dashed") +
ggplot2::theme_bw()

## Produce plot of F_R ####
nr_first_obs <- best_fit$draws(format = "df", variables = "nr")[,1] %>% unlist() %>% as.numeric()
ns_first_obs <- best_fit$draws(format = "df", variables = "ns")[,1] %>% unlist() %>% as.numeric()

nr_first_obs <- nr_first_obs[t0_r_draws <= 0]
ns_first_obs <- ns_first_obs[t0_r_draws <= 0]

f_r = nr_first_obs / (nr_first_obs + ns_first_obs)
fr_plot <- dplyr::tibble(x = f_r) %>%
ggplot2::ggplot(mapping = ggplot2::aes(x=x)) +
ggplot2::geom_density() +
ggplot2::labs(x = "Resistant population fraction", y="Density") +
ggplot2::lims(x = c(0,1)) +
ggplot2::theme_bw()

## Rates plot ####
rho_r_draws <- best_fit$draws(format = "df", variables = "rho_r")[,1] %>% unlist() %>% as.numeric()
rates <- dplyr::tibble(x=best_fit$draws(format = "df", variables = "rho_r")[,1] %>% unlist() %>% as.numeric(), par = "rho_r")
if ("rho_s" %in% variables) {
rates <- dplyr::bind_rows(
rates,
dplyr::tibble(x=-(best_fit$draws(format = "df", variables = "rho_s")[,1] %>% unlist() %>% as.numeric()), par = "rho_s")
)
}

rates_plot <- rates %>%
dplyr::group_by(.data$par) %>%
dplyr::mutate(q_low = stats::quantile(x, .05), q_high = stats::quantile(x, .95)) %>%
dplyr::filter(x >= .data$q_low, x <= .data$q_high) %>%
ggplot2::ggplot(mapping = ggplot2::aes(x=.data$x, col=.data$par)) +
ggplot2::geom_density() +
ggplot2::geom_vline(xintercept = 0, linetype = "dashed") +
ggplot2::labs(x = "Growth rate", y="Density") +
ggplot2::theme_bw()

x$two_pop_plots <- list(
evo_plot = evo_plot,
evo_plot_separated = evo_plot_separated,
rates_plot = rates_plot,
times_plot = times_plot,
fr_plot = fr_plot
)

x$two_pop_fit <- convert_mcmc_fit_to_biPOD(res$fit)
#x$two_pop_fit <- res$fit
return(x)
}

Expand All @@ -73,22 +191,38 @@ prep_data_two_pop_fit <- function(x, factor_size) {

fit_two_pop_data <- function(x, factor_size, variational, chains, iter, cores) {
input_data <- prep_data_two_pop_fit(x = x, factor_size = factor_size)
model <- get_model(model_name = "two_pop")
#model <- get_model(model_name = "two_pop")

m1 <- get_model("two_pop_single")
m2 <- get_model("two_pop_both")

tmp <- utils::capture.output(f1 <- m1$sample(input_data, parallel_chains = chains, iter_warmup = iter, iter_sampling = iter, chains = chains))
tmp <- utils::capture.output(f2 <- m2$sample(input_data, parallel_chains = chains, iter_warmup = iter, iter_sampling = iter, chains = chains))

fits <- list(f1, f2)

loo1 <- f1$loo()
loo2 <- f2$loo()

loos <- loo::loo_compare(list(loo1, loo2))
fit_model <- fits[[as.numeric(stringr::str_replace(rownames(loos)[1], "model", ""))]]

sampling <- "mcmc"
elbo_d <- NULL
# Fit with either MCMC or Variational
if (variational) {
sampling <- "variational"
res <- variational_fit(model = model, data = input_data, iter = iter)
fit_model <- res$fit_model
elbo_d <- res$elbo_d
} else {
sampling <- "mcmc"
tmp <- utils::capture.output(
suppressMessages(
fit_model <- model$sample(data = input_data, chains = chains, parallel_chains = cores, iter_warmup = iter, iter_sampling = iter, refresh = iter)
)
)
}
# if (variational) {
# sampling <- "variational"
# res <- variational_fit(model = model, data = input_data, iter = iter)
# fit_model <- res$fit_model
# elbo_d <- res$elbo_d
# } else {
# sampling <- "mcmc"
# tmp <- utils::capture.output(
# suppressMessages(
# fit_model <- model$sample(data = input_data, chains = chains, parallel_chains = cores, iter_warmup = iter, iter_sampling = iter, refresh = iter)
# )
# )
# }

elbo_data <- c()
if (variational) elbo_data <- elbo_d
Expand Down
4 changes: 3 additions & 1 deletion R/getter.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ get_model <- function(model_name) {
"two_pop" = "two_population.stan",
"piecewise_changepoints" = "piecewise_linear_regression.stan",
"pw_lin_fixed_b" = "pw_linear_b_fixed.stan",
"fit_breakpoints" = "fit_breakpoints.stan"
"fit_breakpoints" = "fit_breakpoints.stan",
"two_pop_both" = 'two_pop_both_v2.stan',
"two_pop_single" = 'two_pop_single.stan'
)

if (!(model_name) %in% names(all_paths)) stop("model_name not recognized")
Expand Down
Loading

0 comments on commit 5995095

Please sign in to comment.