Skip to content

Commit

Permalink
Minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jovoni committed Aug 21, 2024
1 parent f97ab15 commit e5e199d
Showing 1 changed file with 19 additions and 24 deletions.
43 changes: 19 additions & 24 deletions R/fit_task0.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t
return(ssr)
}


fit <- function(n_segments) {
# Define the function to minimize
min_function <- fit_with_breaks_opt
Expand All @@ -103,7 +104,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t
control = DEoptim::DEoptim.control(
VTR = 0,
NP = 100,
itermax = 1000,
itermax = n_trials,
reltol = 1e-3,
CR = 0.7,
strategy = 2,
Expand Down Expand Up @@ -204,6 +205,7 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t
fits <- list()
plots <- list()
proposed_breakpoints$idx <- c(1:nrow(proposed_breakpoints)) + 1
j <- 2
loos <- lapply(0:nrow(proposed_breakpoints), function(j) {
print(j)
if (j == 0) {
Expand Down Expand Up @@ -236,9 +238,8 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t
means <- lapply(1:ncol(repetitions), function(j) {mean(repetitions[,j])}) %>% unlist()
sds <- lapply(1:ncol(repetitions), function(j) {sd(repetitions[,j])}) %>% unlist()

plots[[j+1]] <<- dplyr::tibble(x =x, means = means, sds = sds) %>%
ggplot2::ggplot(mapping = ggplot2::aes(x=.data$x, y=.data$means, ymin=.data$means-.data$sds, ymax=.data$means+.data$sds)) +
ggplot2::geom_pointrange() +
plots[[j+1]] <<- ggplot2::ggplot() +
ggplot2::geom_pointrange(dplyr::tibble(x=x, means=means, sds = sds), mapping = ggplot2::aes(x=.data$x, y=.data$means, ymin=.data$means-.data$sds, ymax=.data$means+.data$sds)) +
ggplot2::geom_point(dplyr::tibble(x=x, y=y), mapping=ggplot2::aes(x=.data$x, y=.data$y), col="red") +
ggplot2::ggtitle(max(f$lp()))

Expand All @@ -251,38 +252,32 @@ find_breakpoints = function(d, avg_points_per_window, max_breakpoints, norm, n_t
# bic
})

# loo::loo_compare(loos)
# loos %>% unlist()

#best_j <- loos %>% unlist() %>% which.min()
#
if (length(loos) == 1) {
message("Zero models with breakpoints has been found")
return(list(best_bp=NULL, best_fit=NULL))
}

suppressWarnings(loo_comp <- loo::loo_compare(loos))
loo_comp[,1] <- round(loo_comp[,1],1)
if (sum(loo_comp[,1] == max(loo_comp[,1])) == 1) {
best_j <- min(as.numeric(stringr::str_replace(rownames(loo_comp)[1], pattern = "model", replacement = "")))
} else {
loo_comp <- loo_comp[loo_comp[,1] == max(loo_comp[,1]),]
best_j <- min(as.numeric(stringr::str_replace(rownames(loo_comp), pattern = "model", replacement = "")))
}

loo_comp <- loo_comp %>% as_tibble() %>%
dplyr::mutate(model = rownames(loo_comp)) %>%
dplyr::mutate(j = as.numeric(stringr::str_replace(rownames(loo_comp), pattern = "model", replacement = "")) - 1) %>%
dplyr::mutate(convergence = TRUE)

loo_comp$convergence <- lapply(loo_comp$j, function(j) {
if (j == 0) return(TRUE)
all(loos[[j]]$diagnostics$pareto_k <= .7)
}) %>% unlist()

loo_comp <- loo_comp %>% dplyr::filter(convergence) %>% dplyr::filter(elpd_diff == max(elpd_diff))
best_j <- min(loo_comp$j)

best_fit <- fits[[best_j]]
# if (best_j == 1) {
# best_bp = NULL
# } else {
# best_bp <- best_fit$draws(variables = 'b', format = 'matrix') %>%
# dplyr::as_tibble() %>%
# dplyr::summarise_all(stats::median) %>%
# as.numeric()
# }
best_fit <- biPOD:::convert_mcmc_fit_to_biPOD(best_fit)

best_bp <- proposed_breakpoints %>%
dplyr::filter(idx == best_j) %>%
dplyr::filter(idx == best_j + 1) %>%
dplyr::pull(best_bp) %>% unlist()

if (norm) {
Expand Down

0 comments on commit e5e199d

Please sign in to comment.