-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.R
124 lines (118 loc) · 3.85 KB
/
utils.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
library(extrafont)
library(HDInterval)
library(rstudioapi)
source("src/utils-caching.R")
#' Interval bounds plus the median of a sample
#'
#' @param x Numeric vector of values to summarise.
#' @param ... Forwarded unchanged to `get_interval2()`.
#' @return The result of `get_interval2(x, ...)` with an extra element named
#'   `median` appended.
hdi2 <- function(x, ...) {
  interval <- get_interval2(x, ...)
  c(interval, median = median(x))
}
#' Summarise a named posterior array as a long table of interval statistics
#'
#' Applies `hdi2()` across the concatenated chains of `obj[[obj_name]]` and
#' returns one row per time step and statistic ("lower", "median", "upper"),
#' with line-type and line-width columns alongside. Optionally attaches a
#' random subset of individual draws as the `"draws"` attribute.
#'
#' Besides its arguments, this function reads several objects from the
#' enclosing/global environment: `jags_data` (elements `x.hat.raw` and
#' `x.pred.raw`), `d` (columns `rel_year`, `cal_year`, `stratum_id`,
#' `stratum_index`), and the helpers `concat_chains()` and `hdi2()`.
#'
#' @param obj List-like object; `obj[[obj_name]]` is the array summarised.
#' @param obj_name Name of the element to summarise. Must contain "strat" or
#'   "park" ("strat" wins if both appear). Names containing "hat" use
#'   `jags_data$x.hat.raw` for the time increments, all others use
#'   `jags_data$x.pred.raw`.
#' @param ... Forwarded to `hdi2()`.
#' @param ci_lty Value written to the `lty` column for the non-median rows.
#' @param include_draws If `TRUE`, sample `n_draws` draws and attach them in
#'   long form as `attr(result, "draws")`. Park-level objects only.
#' @param n_draws Number of draws to sample when `include_draws` is `TRUE`.
#' @return A tibble with columns `stat`, `val`, `lty`, `lwd`, `obj` plus the
#'   time columns; stratum-level results also carry the stratum columns
#'   joined from `d`.
data_for_obj <- function(obj, obj_name, ..., ci_lty = "dashed",
                         include_draws = FALSE, n_draws = 1000) {
  is_strat <- grepl("strat", obj_name)
  is_park <- grepl("park", obj_name)

  # Fail early with a readable message instead of a downstream
  # "no applicable method" error on a NULL summary.
  if (!is_strat && !is_park) {
    stop(
      "obj_name must contain \"strat\" or \"park\"; got: ", obj_name,
      call. = FALSE
    )
  }
  if (include_draws && is_strat) {
    stop("Argument include_draws supported only for park-level objects...")
  }

  t_incr <- if (grepl("hat", obj_name)) {
    jags_data$x.hat.raw
  } else {
    jags_data$x.pred.raw
  }

  # Built once: to_tbl() is called once per stratum.
  year_lookup <- d %>%
    distinct(rel_year, cal_year) %>%
    arrange(rel_year)

  # Transpose so rows are time steps, label them, and keep only the steps
  # that map to a calendar year in `d`.
  to_tbl <- function(x) {
    as_tibble(t(x)) %>%
      mutate(rel_year = unique(t_incr)) %>%
      left_join(year_lookup, by = "rel_year") %>%
      filter(!is.na(cal_year))
  }

  samples <- obj[[obj_name]]

  out_raw <- if (is_strat) {
    post_stat <- apply(concat_chains(samples, axis = 4), 1:2, hdi2, ...)
    apply(post_stat, 3, to_tbl) %>%
      bind_rows(.id = "stratum_index") %>%
      mutate(stratum_index = as.integer(stratum_index))
  } else {
    post_stat <- apply(concat_chains(samples, axis = 3), 1, hdi2, ...)
    to_tbl(post_stat)
  }

  draws <- NULL
  if (include_draws) {
    sample_dim <- dim(samples)
    n_draws_total <- sample_dim["iteration"] * sample_dim["chain"]
    # Name-based lookup yields NA (or nothing) when the dims are unnamed.
    if (length(n_draws_total) != 1L || is.na(n_draws_total)) {
      stop(
        "dim(obj[[obj_name]]) must have elements named ",
        "\"iteration\" and \"chain\"",
        call. = FALSE
      )
    }
    which_samples <- sample(n_draws_total, n_draws)
    # drop = FALSE keeps a one-row matrix when n_draws == 1; otherwise the
    # subset collapses to a vector and to_tbl() transposes it the wrong way.
    draw_matrix <- t(concat_chains(samples, axis = 3))
    draws_raw <- to_tbl(draw_matrix[which_samples, , drop = FALSE])
    draws <- draws_raw %>%
      pivot_longer(-any_of(c("rel_year", "cal_year")),
        names_to = "iteration", values_to = "val"
      )
  }

  out <- out_raw %>%
    pivot_longer(all_of(c("lower", "median", "upper")),
      names_to = "stat", values_to = "val"
    ) %>%
    mutate(
      lty = ifelse(stat == "median", "solid", ci_lty),
      lwd = ifelse(stat == "median", 0.9, 0.5),
      obj = obj_name
    )
  attr(out, "draws") <- draws

  if (is_strat) {
    # Explicit key: avoids the implicit-join message and accidental matches.
    out %>%
      left_join(d %>% distinct(stratum_id, stratum_index),
        by = "stratum_index"
      )
  } else {
    out
  }
}
#' Median and interval of a (possibly exponentiated) coefficient sample
#'
#' @param x Values to summarise; flattened with `as.vector()`.
#' @param denom Divisor applied to `x` before any transformation.
#' @param digits If not `NULL`, the result is rounded to this many digits.
#' @param transform If `TRUE`, exponentiate the scaled values first.
#' @param ... Forwarded to `hdi()`.
#' @return A vector whose first element is named `median`, followed by
#'   whatever `hdi()` returns.
coef_stats <- function(x, denom = 1, digits = NULL, transform = TRUE, ...) {
  scaled <- as.vector(x) / denom
  z <- if (transform) exp(scaled) else scaled
  stats <- c(median = median(z), hdi(z, ...))
  if (is.null(digits)) {
    stats
  } else {
    round(stats, digits)
  }
}
#' Linearly map values from one interval onto another
#'
#' @param z Numeric values expressed on the `from` scale.
#' @param to Length-two vector giving the target interval.
#' @param from Length-two vector giving the source interval.
#' @return `z` rescaled so that `from[1]` maps to `to[1]` and the width of
#'   `from` maps to the width of `to`.
unscale <- function(z, to, from = c(0, 1)) {
  from_width <- diff(from)
  to_width <- diff(to)
  offset <- z - from[1]
  offset / from_width * to_width + to[1]
}
#' Import and register fonts whose file names match "lmroman*"
#'
#' Runs the extrafont import without prompting, then registers the imported
#' fonts. Returns whatever `loadfonts()` returns.
load_font <- function() {
  # The pattern is interpreted by font_import() as a regular expression.
  extrafont::font_import(pattern = "lmroman*", prompt = FALSE)
  extrafont::loadfonts()
}
#' Path of the script currently being run
#'
#' Looks for the `--file=` argument that Rscript places on the command line;
#' failing that, reads the `ofile` variable from the outermost call frame,
#' which `source()` sets when a file is sourced from the console.
#'
#' @return A single normalised path.
this_file <- function() {
  file_flag <- "^--file="
  args <- commandArgs(trailingOnly = FALSE)
  # Anchor the pattern and take the first hit, so a user-supplied trailing
  # argument containing "--file=" cannot add a second path to the result.
  flagged <- grep(file_flag, args, value = TRUE)
  if (length(flagged) > 0) {
    return(normalizePath(sub(file_flag, "", flagged[1]))) # Rscript
  }

  ofile <- sys.frames()[[1]]$ofile # 'source'd via R console
  if (is.null(ofile)) {
    stop(
      "cannot determine the current file: not run via Rscript or source()",
      call. = FALSE
    )
  }
  normalizePath(ofile)
}
#' Locate the running example's directory and source the shared R files
#'
#' In RStudio the directory is taken from the active document; otherwise it
#' is the directory of `this_file()`. Every `.R` file in `../../src`
#' relative to that directory is then sourced.
#'
#' @return The example directory path.
get_ex_path <- function() {
  # Plain if/else: the condition is a scalar, and only one branch can run.
  ex_path <- if (.Platform$GUI == "RStudio") {
    dirname(callFun("getActiveDocumentContext")$path)
  } else {
    dirname(this_file())
  }

  # list.files() takes a regular expression, not a glob. Anchoring on the
  # extension stops files such as .Rmd, .RData or .R.bak being sourced.
  src_files <- list.files(
    file.path(ex_path, "../../src"),
    pattern = "\\.R$",
    full.names = TRUE
  )
  for (src_file in src_files) {
    source(src_file)
  }

  ex_path
}
#' Create and return the output directory for an example
#'
#' The output directory is `base_dir` joined with the part of `ex_path` from
#' "example-<digits>" onwards. The directory and any requested
#' subdirectories are created, and an existing `00-input` folder inside
#' `ex_path` is copied into it.
#'
#' @param ex_path Path of the example directory.
#' @param base_dir Root directory under which output is written.
#' @param dirs Optional subdirectory names to create inside the output
#'   directory, e.g. `c("03-inference", "99-misc")`.
#' @return The output directory path.
get_output_path <- function(ex_path, base_dir = "output", dirs = NULL) {
  example_id <- stringr::str_match(ex_path, "example-\\d+.*")[, 1]
  if (anyNA(example_id)) {
    warning(
      "ex_path does not contain an 'example-<n>' component: ", ex_path,
      call. = FALSE
    )
  }
  output_path <- file.path(base_dir, example_id)
  message(sprintf("writing complete results to %s", output_path))

  # file.path() on one vector returns it unchanged, so each subdirectory
  # must be joined onto output_path explicitly to end up inside it.
  sub_dirs <- if (length(dirs) > 0) {
    file.path(output_path, dirs)
  } else {
    character(0)
  }
  for (target in c(output_path, sub_dirs)) {
    dir.create(target, showWarnings = FALSE, recursive = TRUE)
  }

  input_dir <- file.path(ex_path, "00-input")
  if (dir.exists(input_dir)) {
    file.copy(input_dir, output_path, recursive = TRUE)
  }
  output_path
}
#' Build the path of an output asset, creating its directory if needed
#'
#' @param output_dir Base output directory.
#' @param filename File name appended to the final directory.
#' @param ... Further path components, joined onto `output_dir` with "/".
#' @return The path of `filename` inside the joined directory.
get_asset_path <- function(output_dir, filename, ...) {
  components <- c(output_dir, ...)
  asset_dir <- paste(components, collapse = "/")
  already_there <- dir.exists(asset_dir)
  if (!already_there) {
    dir.create(asset_dir, showWarnings = FALSE, recursive = TRUE)
  }
  file.path(asset_dir, filename)
}