Skip to content

Commit

Permalink
tests: make more resampling tests work
Browse files Browse the repository at this point in the history
  • Loading branch information
m-muecke committed Jan 24, 2025
1 parent 74a4882 commit 060073e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
22 changes: 13 additions & 9 deletions R/ResamplingFcstCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,26 +106,29 @@ ResamplingFcstCV = R6Class("ResamplingFcstCV",
setnames(tab, c("row_id", "order"))
setorderv(tab, "order")
train_end = tab[.N - horizon, row_id]
train_end = rev(seq.int(
train_end = seq.int(
from = train_end,
by = -pars$step_size,
length.out = pars$folds
))
)
if (!pars$fixed_window) {
train_ids = map(train_end, function(x) ids[1L]:x)
} else {
train_ids = map(train_end, function(x) (x - window_size + 1L):x)
}
test_ids = map(train_ids, function(x) (x[length(x)] + 1L):(x[length(x)] + horizon))
test_ids = map(train_ids, function(x) {
n = length(x)
(x[n] + 1L):(x[n] + horizon)
})
} else {
setnames(tab, "..row_id", "row_id")
setorderv(tab, c(key_cols, order_cols))
ids = tab[, {
train_end = rev(seq.int(
train_end = seq.int(
from = .N - horizon,
by = -pars$step_size,
length.out = pars$folds
))
)
if (pars$fixed_window) {
train_ids = map(train_end, function(x) .SD[(x - window_size + 1L):x, row_id])
} else {
Expand All @@ -137,9 +140,10 @@ ResamplingFcstCV = R6Class("ResamplingFcstCV",
})
list(train_ids = train_ids, test_ids = test_ids)
}, by = key_cols][, .(train_ids, test_ids)]

train_ids = ids$train_ids
test_ids = ids$test_ids
}
list(train = ids$train_ids, test = ids$test_ids)
list(train = train_ids, test = test_ids)
},

.sample_ids = function(ids, ...) {
Expand All @@ -149,11 +153,11 @@ ResamplingFcstCV = R6Class("ResamplingFcstCV",

ids = sort(ids)
train_end = ids[ids <= (max(ids) - horizon) & ids >= window_size]
train_end = rev(seq.int(
train_end = seq.int(
from = train_end[length(train_end)],
by = -pars$step_size,
length.out = pars$folds
))
)
if (pars$fixed_window) {
train_ids = map(train_end, function(x) (x - window_size + 1L):x)
} else {
Expand Down
10 changes: 5 additions & 5 deletions tests/testthat/test_ResamplingFcstCV.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
test_that("forecast_cv basic properties", {
task = tsk("penguins")
task = tsk("airpassengers")
resampling = rsmp("forecast_cv",
folds = 10L, horizon = 3L, window_size = 5L, fixed_window = FALSE
)
expect_resampling(resampling, task)
expect_resampling(resampling, task, strata = FALSE)
resampling$instantiate(task)
expect_resampling(resampling, task)
expect_resampling(resampling, task, strata = FALSE)
expect_identical(resampling$iters, 10L)
expect_equal(intersect(resampling$test_set(1L), resampling$train_set(1L)), integer())
expect_false(resampling$duplicated_ids)
Expand Down Expand Up @@ -33,7 +33,7 @@ test_that("forecast_cv works", {
})

test_that("forecast_cv fixed vs. expanding window", {
task = tsk("penguins")
task = tsk("airpassengers")
task$filter(1:30)

# fixed window
Expand All @@ -56,7 +56,7 @@ test_that("forecast_cv fixed vs. expanding window", {
})

test_that("forecast_cv with various parameter combinations", {
task = tsk("penguins")
task = tsk("airpassengers")
task$filter(1:30)

# small window, large step size
Expand Down

0 comments on commit 060073e

Please sign in to comment.