Skip to content

Commit

Permalink
refactor: some renaming
Browse files Browse the repository at this point in the history
  • Loading branch information
m-muecke committed Jan 23, 2025
1 parent 497e34d commit 5982488
Show file tree
Hide file tree
Showing 11 changed files with 63 additions and 70 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ Collate:
'MeasureDirectional.R'
'PipeOpFcstLag.R'
'PipeOpTargetTrafo.R'
'ResamplingForecastCV.R'
'ResamplingForecastHoldout.R'
'ResamplingFcstCV.R'
'ResamplingFcstHoldout.R'
'TaskFcst.R'
'TaskFcstAirpassengers.R'
'TaskFcstElectricty.R'
Expand Down
4 changes: 2 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ export(LearnerFcstBats)
export(LearnerFcstEts)
export(LearnerFcstTbats)
export(PipeOpFcstLag)
export(ResamplingForecastCV)
export(ResamplingForecastHoldout)
export(ResamplingFcstCV)
export(ResamplingFcstHoldout)
export(TaskFcst)
export(as_task_fcst)
import(R6)
Expand Down
4 changes: 2 additions & 2 deletions R/ResamplingForecastCV.R → R/ResamplingFcstCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
#'
#' # Internal storage:
#' cv$instance # list
ResamplingForecastCV = R6Class("ResamplingForecastCV",
ResamplingFcstCV = R6Class("ResamplingFcstCV",
inherit = Resampling,
public = list(
#' @description
Expand Down Expand Up @@ -160,4 +160,4 @@ ResamplingForecastCV = R6Class("ResamplingForecastCV",
)

#' @include zzz.R
register_resampling("forecast_cv", ResamplingForecastCV)
register_resampling("forecast_cv", ResamplingFcstCV)
60 changes: 30 additions & 30 deletions R/ResamplingForecastHoldout.R → R/ResamplingFcstHoldout.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
#'
#' # Internal storage:
#' holdout$instance # simple list
ResamplingForecastHoldout = R6Class("ResamplingForecastHoldout",
ResamplingFcstHoldout = R6Class("ResamplingFcstHoldout",
inherit = Resampling,
public = list(
#' @description
Expand Down Expand Up @@ -67,7 +67,7 @@ ResamplingForecastHoldout = R6Class("ResamplingForecastHoldout",
),

private = list(
.sample_old = function(ids, ...) {
.sample = function(ids, task, ...) {
if ("ordered" %nin% task$properties) {
stopf(
"Resampling '%s' requires an ordered task, but Task '%s' has no order.",
Expand All @@ -78,7 +78,7 @@ ResamplingForecastHoldout = R6Class("ResamplingForecastHoldout",
pars = self$param_set$get_values()
ratio = pars$ratio
n = pars$n
n_obs = length(ids)
n_obs = task$nrow

has_ratio = !is.null(ratio)
if (!xor(!has_ratio, is.null(n))) {
Expand All @@ -92,12 +92,30 @@ ResamplingForecastHoldout = R6Class("ResamplingForecastHoldout",
nr = max(n_obs + n, 0L)
}

ids = sort(ids)
ii = ids[1:nr]
list(train = ii, test = ids[(nr + 1L):n_obs])
order_cols = task$col_roles$order
key_cols = task$col_roles$key
has_key_cols = length(key_cols) > 0L
tab = task$backend$data(rows = ids, cols = c(task$backend$primary_key, order_cols, key_cols))
if (has_key_cols) {
setnames(tab, c("row_id", "order", "key"))
setorderv(tab, c("key", "order"))
n_groups = length(unique(tab$key))
nr = if (has_ratio) nr %/% n_groups else nr
list(
train = tab[, .SD[1:nr], by = key][, row_id],
test = tab[, .SD[(nr + 1L):.N], by = key][, row_id]
)
} else {
setnames(tab, c("row_id", "order"))
setorderv(tab, c("order"))
list(
train = tab[1:nr, row_id],
test = tab[(nr + 1L):.N, row_id]
)
}
},

.sample = function(ids, task, ...) {
.sample_ids = function(ids, ...) {
if ("ordered" %nin% task$properties) {
stopf(
"Resampling '%s' requires an ordered task, but Task '%s' has no order.",
Expand All @@ -108,7 +126,7 @@ ResamplingForecastHoldout = R6Class("ResamplingForecastHoldout",
pars = self$param_set$get_values()
ratio = pars$ratio
n = pars$n
n_obs = task$nrow
n_obs = length(ids)

has_ratio = !is.null(ratio)
if (!xor(!has_ratio, is.null(n))) {
Expand All @@ -122,27 +140,9 @@ ResamplingForecastHoldout = R6Class("ResamplingForecastHoldout",
nr = max(n_obs + n, 0L)
}

order_cols = task$col_roles$order
key_cols = task$col_roles$key
has_key_cols = length(key_cols) > 0L
tab = task$backend$data(rows = ids, cols = c(task$backend$primary_key, order_cols, key_cols))
if (has_key_cols) {
setnames(tab, c("row_id", "order", "key"))
setorderv(tab, c("key", "order"))
n_groups = length(unique(tab$key))
nr = if (has_ratio) nr %/% n_groups else nr
list(
train = tab[, .SD[1:nr], by = key][, row_id],
test = tab[, .SD[(nr + 1L):.N], by = key][, row_id]
)
} else {
setnames(tab, c("row_id", "order"))
setorderv(tab, c("order"))
list(
train = tab[1:nr, row_id],
test = tab[(nr + 1L):.N, row_id]
)
}
ids = sort(ids)
ii = ids[1:nr]
list(train = ii, test = ids[(nr + 1L):n_obs])
},

.get_train = function(i) {
Expand All @@ -160,4 +160,4 @@ ResamplingForecastHoldout = R6Class("ResamplingForecastHoldout",
)

#' @include zzz.R
register_resampling("forecast_holdout", ResamplingForecastHoldout)
register_resampling("forecast_holdout", ResamplingFcstHoldout)
4 changes: 1 addition & 3 deletions R/TaskFcstAirpassengers.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
NULL

load_task_airpassengers = function(id = "airpassengers") {
if (!requireNamespace("tsbox", quietly = TRUE)) {
stopf("Package 'tsbox' is required to load the 'AirPassengers' dataset.")
}
require_namespaces("tsbox")
dt = tsbox::ts_dt(load_dataset("AirPassengers", "datasets"))
setnames(dt, c("date", "passengers"))
b = as_data_backend(dt)
Expand Down
4 changes: 1 addition & 3 deletions R/TaskFcstElectricty.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
NULL

load_task_electricty = function(id = "electricity") {
if (!requireNamespace("tsibbledata", quietly = TRUE)) {
stopf("Package 'tsibbledata' is required to load the 'vic_elec' dataset.")
}
require_namespaces("tsibbledata")
dt = as.data.table(load_dataset("vic_elec", "tsibbledata"))
setnames(dt, tolower)
demand = temperature = holiday = NULL
Expand Down
4 changes: 1 addition & 3 deletions R/TaskFcstLivestock.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
NULL

load_task_livestock = function(id = "livestock") {
if (!requireNamespace("tsibbledata", quietly = TRUE)) {
stopf("Package 'tsibbledata' is required to load the 'aus_livestock' dataset.")
}
require_namespaces(c("tsibbledata", "tsibble"))
dt = as.data.table(load_dataset("aus_livestock", "tsibbledata"))
setnames(dt, tolower)
dt[, month := as.Date(month)]
Expand Down
22 changes: 11 additions & 11 deletions man/mlr_resamplings_forecast_cv.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 11 additions & 11 deletions man/mlr_resamplings_forecast_holdout.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
test_that("forecast_holdout basic properties", {
skip("currently require datetime column, i.e. don't sort based on ids")
task = tsk("penguins")
resampling = rsmp("forecast_holdout", ratio = 0.7)
task = tsk("airpassengers")
resampling = rsmp("forecast_holdout", ratio = 0.8)
expect_resampling(resampling, task)
resampling$instantiate(task)
expect_resampling(resampling, task)
Expand Down

0 comments on commit 5982488

Please sign in to comment.