Skip to content

Commit

Permalink
Minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jovoni committed Jun 28, 2024
1 parent 32cc753 commit baa8b88
Show file tree
Hide file tree
Showing 11 changed files with 349 additions and 16 deletions.
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ Imports:
ggplot2,
ggpubr,
glue,
lhs,
loo,
magrittr,
patchwork,
posterior,
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
S3method(print,bipod)
export(breakpoints_inference)
export(fit)
export(fit_breakpoints)
export(fit_two_pop_model)
export(init)
export(plot_bayes_factor)
Expand Down
224 changes: 224 additions & 0 deletions R/fit_task0.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
#' Fit growth model to bipod object
#'
#' @param x a bipod object
#' @param norm .
#' @param n_trials .
#' @param min_points .
#' @param available_breakpoints .
#' @param constrain_bp_on_x .
#'
#' @return the input bipod object with an added 'breakpoints_fit' slot containing the fitted model for the breakpoints
#' @export
fit_breakpoints <- function(
x,
norm=F,
n_trials=1000,
min_points=3,
available_breakpoints=c(1:5),
constrain_bp_on_x=F
) {
# Check input
if (!(inherits(x, "bipod"))) stop("Input must be a bipod object")
#if (!(factor_size > 0)) stop("factor_size must be positive")
d <- x$counts

res <- find_breakpoints_v3(
d,
norm=norm,
n_trials=n_trials,
min_points=min_points,
available_breakpoints=available_breakpoints,
constrain_bp_on_x=constrain_bp_on_x
)

best_bp <- res$best_bp
best_fit <- res$best_fit

# Store results
elbo_data <- c()
# if (variational) elbo_data <- elbo_d %>% stats::na.omit()
# fit <- fit_model

# Add results to bipod object
#x$breakpoints_elbo <- elbo_data
x$breakpoints_fit <- best_fit

# Write fit info
# x$metadata$sampling <- sampling
#x$metadata$factor_size <- factor_size
# x$metadata$prior_K <- input_data$prior_K

# Add median of breakpoints
# n_changepoints <- length(input_data$changing_times_prior)
# breakpoints_names <- lapply(1:n_changepoints, function(i) {
# paste0("changing_times[", i, "]")
# }) %>% unlist()

# if (best_res$J == 0) {
# median_breakpoints = NULL
# } else {
# median_breakpoints <- best_fit$draws(variables = 'b', format = 'matrix') %>%
# dplyr::as_tibble() %>%
# dplyr::summarise_all(stats::median) %>%
# as.numeric()
#
# median_breakpoints <- median_breakpoints + min(x$counts$time)
# }

x$metadata$breakpoints <- best_bp

if (!(is.null(best_bp))) {
x$counts$group <- bp_to_groups(x$counts, x$metadata$breakpoints)
}

if (!constrain_bp_on_x) {
cli::cli_alert_success("Breakpoints have been inferred. Inspect the results using the {.field plot_breakpoints_posterior} function.")
}
cli::cli_alert_info("Median of the inferred breakpoints have been succesfully stored.")

x
}

ind <- function(x, y) { return(as.numeric(x >= y)) }

# Function to calculate the expected mean
expected_mean <- function(x, q, s, b) {
G <- length(s)
res <- q + x * s[1]
for (g in 2:G) {
res <- res + (x - b[g-1]) * s[g] * ind(x, b[g-1])
}
return(res)
}

find_breakpoints_v3 <- function(d, norm=T, n_trials=1000, min_points=3, available_breakpoints=c(1:6), constrain_bp_on_x=F) {
x <- d$time
y <- log(d$count)

if (norm) {
x <- (x - mean(x)) / stats::sd(x)
y <- (y - mean(y)) / stats::sd(y)
}

available_breakpoints <- available_breakpoints[available_breakpoints != 0]

message("Initial proposals")
proposed_breakpoints <- lapply(available_breakpoints, function(n_breakpoints) {
if (constrain_bp_on_x) {
random_starts <- lapply(1:(n_trials * n_breakpoints), function(j) {sample(x, n_breakpoints, replace = F)}) %>% do.call("rbind", .)
} else {
random_starts <- lhs::randomLHS(n_trials * n_breakpoints, n_breakpoints)
random_starts <- random_starts * (max(x) - min(x)) + min(x)
}

res <- lapply(1:n_trials, function(j) {
bp <- sort(random_starts[j,])
n_per_window <- biPOD:::bp_to_groups(dplyr::tibble(time=x, count=y), bp) %>% table()
if (any(n_per_window < min_points) | length(n_per_window) != ncol(random_starts) + 1) {return(NULL)}

# build design matrix
n_params = n_breakpoints + 2
X = matrix(0, nrow = length(x), ncol = n_params)
X[,1] = 1
X[,2] = x
tmp <- lapply(1:ncol(random_starts), function(k) {
X[,k+2] <<- ifelse(x > bp[k], x - bp[k], 0)
})

params <- c(solve(t(X) %*% X) %*% t(X) %*% y)
ypred = expected_mean(x, params[1], params[2:length(params)], bp)
rmse = sqrt(mean((y - ypred)**2))
return(dplyr::tibble(j = j, rmse=rmse))
}) %>% do.call('bind_rows', .)

if (nrow(res) > 0) {
best <- res %>% dplyr::filter(.data$rmse == min(.data$rmse)) %>% dplyr::slice_head(n=1)
best_rmse <- best$rmse
best_j <- best$j
best_bp <- random_starts[best_j,]
dplyr::tibble(rmse = best_rmse, bp = list(best_bp), n_breakpoints=n_breakpoints)
} else {
return(NULL)
}
}) %>% do.call("bind_rows", .) %>% dplyr::distinct()

#tmp <- utils::capture.output(suppressMessages(m <- cmdstanr::cmdstan_model("piecewise_fixed_breakpoints.stan")))

if (constrain_bp_on_x == T) {
m <- biPOD:::get_model("pw_lin_fixed_b")
} else {
m <- biPOD:::get_model("piecewise_changepoints")
}

message("Proposals' optimization")
fits <- list()
j = 0
proposed_breakpoints$idx <- c(1:nrow(proposed_breakpoints)) + 1
loos <- lapply(0:nrow(proposed_breakpoints), function(j) {
if (j == 0) {
bp = array(0, dim = c(0))
} else {
bp <- sort(unlist(proposed_breakpoints[j,2]))
}

if (constrain_bp_on_x == T) {
input_data <- list(
S = length(x),
G = length(bp),
N = y,
T = x,
b = bp
)
} else {
input_data <- list(
S = length(x),
G = length(bp),
N = y,
T = x,
b_prior = bp,
sigma_changepoints = (max(x) - min(x)) / 10
)
}

tmp <- utils::capture.output(
suppressMessages(
f <- m$sample(input_data, parallel_chains = 4)
)
)

suppressWarnings(loo <- f$loo())
fits[[j+1]] <<- f
loo
})

if (length(loos) == 1) {
message("Zero models with breakpoints has been found")
return(NULL)
}

suppressWarnings(loo_comp <- loo::loo_compare(loos))
best_j <- as.numeric(stringr::str_replace(rownames(loo_comp)[1], pattern = "model", replacement = ""))

if (constrain_bp_on_x) {
if (best_j == 1) { return(NULL) }
best_bp <- proposed_breakpoints %>% dplyr::filter(idx == best_j) %>% pull(bp) %>% unlist() %>% sort()
best_fit <- NULL
} else {
best_fit <- fits[[best_j]]
if (best_j == 1) {
best_bp = NULL
} else {
best_bp <- best_fit$draws(variables = 'b', format = 'matrix') %>%
dplyr::as_tibble() %>%
dplyr::summarise_all(stats::median) %>%
as.numeric()
}
best_fit <- biPOD:::convert_mcmc_fit_to_biPOD(best_fit)
}

if (norm) {
x <- d$time
best_bp <- best_bp * stats::sd(x) + mean(x)
}
return(list(best_bp=best_bp, best_fit=best_fit))
}
20 changes: 12 additions & 8 deletions R/fit_task2.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ breakpoints_inference <- function(

# Get the model
model_name <- "piecewise_changepoints"
m <- get_model(model_name = model_name)
m <- biPOD:::get_model(model_name = model_name)

# Loop
res <- dplyr::tibble()
Expand All @@ -35,8 +35,8 @@ breakpoints_inference <- function(
G = J,
N = x$counts$count / factor_size,
T = x$counts$time - min(x$counts$time),
b_prior = array(find_equispaced_points(min(x$counts$time - min(x$counts$time)), max(x$counts$time - min(x$counts$time)), N = J), dim = c(J)),
sigma_changepoints = 1
b_prior = array(biPOD:::find_equispaced_points(min(x$counts$time - min(x$counts$time)), max(x$counts$time - min(x$counts$time)), N = J), dim = c(J)),
sigma_changepoints = max(x$counts$time) - min(x$counts$time)
)
} else {
input_data <- list(
Expand All @@ -45,7 +45,7 @@ breakpoints_inference <- function(
N = x$counts$count / factor_size,
T = x$counts$time - min(x$counts$time),
b_prior = array(0, dim = c(0)),
sigma_changepoints = 1
sigma_changepoints = max(x$counts$time) - min(x$counts$time)
)
}

Expand All @@ -57,32 +57,36 @@ breakpoints_inference <- function(

out <- tryCatch({
out <- utils::capture.output(suppressMessages(f_pf <- m$pathfinder(input_data)))

k <- ncol(f_pf$draws()) - 3 # number of parameters
n <- nrow(x$counts) # number of obs
lp <- f_pf$draws("lp__") %>% stats::median()
bic <- k * log(n) - 2 * lp
f_pf
}, error = function(cond) {
lp <- Inf
bic <- Inf
return(NA)
}, warning = function(warn) {
lp <- Inf
bic <- Inf
return(NA)
}
)

res <- dplyr::bind_rows(res, dplyr::tibble(J = J, iter = i, lp = lp))
res <- dplyr::bind_rows(res, dplyr::tibble(J = J, iter = i, lp = lp, bic = bic))
fits[[paste0(J, " _ ", i)]] <- out

cli::cli_progress_update()
}
}

res <- res %>%
dplyr::filter(lp != Inf) %>%
dplyr::filter(bic != Inf) %>%
stats::na.omit()

status <- 'FAIL'
while ((status == 'FAIL') & (nrow(res) > 0)) {
best_res <- res %>% dplyr::filter(lp == max(lp))
best_res <- res %>% dplyr::filter(bic == min(bic))

best_fit <- fits[[paste0(best_res$J, " _ ", best_res$iter)]]

Expand Down
3 changes: 2 additions & 1 deletion R/getter.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ get_model <- function(model_name) {
"infer_changepoints" = "infer_changepoints.stan",
"exp_log_mixture" = "exp_log_mixture.stan",
"two_pop" = "two_population.stan",
"piecewise_changepoints" = "piecewise_linear_regression.stan"
"piecewise_changepoints" = "piecewise_linear_regression.stan",
"pw_lin_fixed_b" = "pw_linear_b_fixed.stan"
)

if (!(model_name) %in% names(all_paths)) stop("model_name not recognized")
Expand Down
2 changes: 1 addition & 1 deletion R/plot_task2.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ plot_breakpoints_posterior <- function(x, with_histogram = F, alpha = .6, colors
}) %>% unlist()

samples <- get_parameters(x$breakpoints_fit, par_list = par_list)
samples$value <- samples$value + min(x$counts$time)
#samples$value <- samples$value + min(x$counts$time)

colors = rep('darkgray', length(x$metadata$breakpoints))

Expand Down
4 changes: 4 additions & 0 deletions biPOD.Rproj
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,7 @@ BuildType: Package
PackageUseDevtools: Yes
PackageInstallArgs: --no-multiarch --with-keep.source
PackageRoxygenize: rd,collate,namespace

PythonType: virtualenv
PythonVersion: 3.10.11
PythonPath: ~/.virtualenvs/r-reticulate/bin/python
15 changes: 11 additions & 4 deletions inst/cmdstan/piecewise_linear_regression.stan
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

functions {
real ind(real x, real y) {
if (x >= y) {
Expand All @@ -22,7 +21,7 @@ data {
int<lower=1> S; // Number of steps
int<lower=0> G; // Number of windows

array[S] real<lower=0> N; // observations
array[S] real N; // observations
array[S] real T; // observations

array[G] real b_prior;
Expand All @@ -32,12 +31,12 @@ data {
parameters {
real q; // intercept
vector[G+1] s; // slopes
positive_ordered[G] b;
ordered[G] b;
real<lower=0> sigma;
}

model {
target += normal_lpdf(q | log(N[1]), N[1]);
target += normal_lpdf(q | 0, 1);
target += inv_gamma_lpdf(sigma | .001, .001);

for (g in 1:(G+1)) {
Expand All @@ -54,3 +53,11 @@ model {
target += normal_lpdf(N[i] | expected_mean(T[i], q, s, b), sigma);
}
}

generated quantities {
vector[S] log_lik;
for (i in 1:S) {
log_lik[i] = normal_lpdf(N[i] | expected_mean(T[i], q, s, b), sigma);
}
}

Loading

0 comments on commit baa8b88

Please sign in to comment.