Skip to content

Commit

Permalink
chore: move generate_newdata() helper into package
Browse files Browse the repository at this point in the history
  • Loading branch information
m-muecke committed Jan 24, 2025
1 parent b1657c8 commit 60dfebe
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 106 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export(ResamplingFcstCV)
export(ResamplingFcstHoldout)
export(TaskFcst)
export(as_task_fcst)
export(generate_newdata)
import(R6)
import(checkmate)
import(data.table)
Expand Down
21 changes: 21 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,24 @@ as.ts.TaskFcst = function(x, ...) { # nolint
)
stats::ts(x$truth(), freq = freq)
}

#' Generate new data for a forecast task
#'
#' Builds a `data.frame` with the `n` time points that follow the latest value
#' of the task's order column. The target column is filled with `NA`. The step
#' between consecutive time points is derived from `task$frequency`.
#'
#' @param task A task object providing `col_roles$order`, `data()`,
#'   `frequency` and `target_names` (presumably a `TaskFcst`; only these
#'   members are accessed here).
#' @param n (`integer(1)`)\cr
#'   Number of future rows to generate. Checked with `assert_count()`.
#'   Default is `1L`.
#'
#' @return A `data.frame` with `n` rows, named after the task's order
#'   column(s) and target name(s).
#'
#' @export
generate_newdata = function(task, n = 1L) {
  assert_count(n)
  order_cols = task$col_roles$order
  max_index = max(task$data(cols = order_cols)[[1L]])

  # translate the task frequency into a step unit understood by `seq()`
  unit = switch(task$frequency,
    daily = "day",
    weekly = "week",
    monthly = "month",
    quarterly = "quarter",
    yearly = "year",
    stop(
      sprintf("Unsupported task frequency: '%s'", task$frequency),
      call. = FALSE
    )
  )
  unit = sprintf("1 %s", unit)
  # the sequence starts at the last observed index, so request one extra
  # element and drop the first one again
  index = seq(max_index, length.out = n + 1L, by = unit)
  # negative indexing stays correct for `n = 0`, whereas `2:length(index)`
  # would count down (2:1)
  index = index[-1L]

  newdata = data.frame(index = index, target = rep(NA_real_, n), check.names = FALSE)
  set_names(newdata, c(order_cols, task$target_names))
}
45 changes: 4 additions & 41 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -57,54 +57,17 @@ still in flux and may change.

### Example: forecasting with forecast learner

First, let's create a helper function to generate new data for forecasting tasks.

```{r}
library(mlr3forecast)
# Build a data.frame with the `n` time points following the latest value of
# the task's order column; the target column is filled with NA.
# `resolution` selects the step between consecutive time points.
generate_newdata = function(task, n = 1L, resolution = "day") {
  assert_count(n)
  assert_string(resolution)
  assert_choice(
    resolution, c("second", "minute", "hour", "day", "week", "month", "quarter", "year")
  )
  order_cols = task$col_roles$order
  max_index = max(task$data(cols = order_cols)[[1L]])
  # "second" and "minute" are abbreviated; all other resolutions are passed
  # through unchanged
  unit = switch(resolution,
    second = "sec",
    minute = "min",
    hour = ,
    day = ,
    week = ,
    month = ,
    quarter = ,
    year = identity(resolution),
    stopf("Invalid resolution")
  )
  unit = sprintf("1 %s", unit)
  # the sequence starts at the last observed index, so one extra element is
  # requested and the first one is dropped again
  index = seq(max_index, length.out = n + 1L, by = unit)
  # negative indexing also handles `n = 0`, where `2:length(index)` is 2:1
  index = index[-1L]
  newdata = data.frame(index = index, target = rep(NA_real_, n), check.names = FALSE)
  setNames(newdata, c(order_cols, task$target_names))
}
task = tsk("airpassengers")
newdata = generate_newdata(task, 12L, "month")
newdata
```

Currently, we support native forecasting learners from the forecast package.
In the future, we plan to support more forecasting learners.

```{r}
library(mlr3forecast)
task = tsk("airpassengers")
learner = lrn("fcst.auto_arima")$train(task)
prediction = learner$predict(task, 140:144)
prediction$score(msr("regr.rmse"))
newdata = generate_newdata(task, 12L, "month")
newdata = generate_newdata(task, 12L)
learner$predict_newdata(newdata, task)
# works with quantile response
Expand Down Expand Up @@ -318,7 +281,7 @@ glrn = as_learner(graph %>>% flrn)$train(task)
prediction = glrn$predict(task, 142:144)
prediction$score(msr("regr.rmse"))
newdata = generate_newdata(task, 12L, "month")
newdata = generate_newdata(task, 12L)
glrn$predict_newdata(newdata, task)
```

Expand Down
79 changes: 14 additions & 65 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,64 +43,13 @@ of the latter is still in flux and may change.

### Example: forecasting with forecast learner

First, let's create a helper function to generate new data for forecasting
tasks.
Currently, we support native forecasting learners from the forecast
package. In the future, we plan to support more forecasting learners.

``` r
library(mlr3forecast)
#> Loading required package: mlr3

# Build a data.frame with the `n` time points following the latest value of
# the task's order column; the target column is filled with NA.
# `resolution` selects the step between consecutive time points.
generate_newdata = function(task, n = 1L, resolution = "day") {
  assert_count(n)
  assert_string(resolution)
  assert_choice(
    resolution, c("second", "minute", "hour", "day", "week", "month", "quarter", "year")
  )

  order_cols = task$col_roles$order
  max_index = max(task$data(cols = order_cols)[[1L]])

  # "second" and "minute" are abbreviated; all other resolutions are passed
  # through unchanged
  unit = switch(resolution,
    second = "sec",
    minute = "min",
    hour = ,
    day = ,
    week = ,
    month = ,
    quarter = ,
    year = identity(resolution),
    stopf("Invalid resolution")
  )
  unit = sprintf("1 %s", unit)
  # the sequence starts at the last observed index, so one extra element is
  # requested and the first one is dropped again
  index = seq(max_index, length.out = n + 1L, by = unit)
  # negative indexing also handles `n = 0`, where `2:length(index)` is 2:1
  index = index[-1L]

  newdata = data.frame(index = index, target = rep(NA_real_, n), check.names = FALSE)
  setNames(newdata, c(order_cols, task$target_names))
}

task = tsk("airpassengers")
newdata = generate_newdata(task, 12L, "month")
newdata
#> date passengers
#> 1 1961-01-01 NA
#> 2 1961-02-01 NA
#> 3 1961-03-01 NA
#> 4 1961-04-01 NA
#> 5 1961-05-01 NA
#> 6 1961-06-01 NA
#> 7 1961-07-01 NA
#> 8 1961-08-01 NA
#> 9 1961-09-01 NA
#> 10 1961-10-01 NA
#> 11 1961-11-01 NA
#> 12 1961-12-01 NA
```

Currently, we support native forecasting learners from the forecast
package. In the future, we plan to support more forecasting learners.

``` r
task = tsk("airpassengers")
learner = lrn("fcst.auto_arima")$train(task)
#> Registered S3 method overwritten by 'quantmod':
Expand All @@ -110,7 +59,7 @@ prediction = learner$predict(task, 140:144)
prediction$score(msr("regr.rmse"))
#> regr.rmse
#> 13.85493
newdata = generate_newdata(task, 12L, "month")
newdata = generate_newdata(task, 12L)
learner$predict_newdata(newdata, task)
#> <PredictionRegr> for 12 observations:
#> row_ids truth response
Expand Down Expand Up @@ -154,32 +103,32 @@ prediction = flrn$predict_newdata(newdata, task)
prediction
#> <PredictionRegr> for 3 observations:
#> row_ids truth response
#> 1 NA 433.4103
#> 2 NA 435.9582
#> 3 NA 456.0257
#> 1 NA 435.1630
#> 2 NA 436.5908
#> 3 NA 456.2188
prediction = flrn$predict(task, 142:144)
prediction
#> <PredictionRegr> for 3 observations:
#> row_ids truth response
#> 1 461 457.2672
#> 2 390 413.4132
#> 3 432 430.3775
#> 1 461 459.4903
#> 2 390 414.8749
#> 3 432 433.6170
prediction$score(msr("regr.rmse"))
#> regr.rmse
#> 13.72037
#> 14.41823

flrn = ForecastLearner$new(lrn("regr.ranger"), 1:12)
resampling = rsmp("forecast_holdout", ratio = 0.9)
rr = resample(task, flrn, resampling)
rr$aggregate(msr("regr.rmse"))
#> regr.rmse
#> 47.75969
#> 48.34306

resampling = rsmp("forecast_cv")
rr = resample(task, flrn, resampling)
rr$aggregate(msr("regr.rmse"))
#> regr.rmse
#> 25.95714
#> 26.41356
```

Or with some feature engineering using mlr3pipelines:
Expand All @@ -200,7 +149,7 @@ glrn = as_learner(graph %>>% flrn)$train(task)
prediction = glrn$predict(task, 142:144)
prediction$score(msr("regr.rmse"))
#> regr.rmse
#> 16.00005
#> 15.49621
```

### Example: Forecasting electricity demand
Expand Down Expand Up @@ -361,7 +310,7 @@ glrn = as_learner(graph %>>% flrn)$train(task)
prediction = glrn$predict(task, 142:144)
prediction$score(msr("regr.rmse"))

newdata = generate_newdata(task, 12L, "month")
newdata = generate_newdata(task, 12L)
glrn$predict_newdata(newdata, task)
```

Expand Down

0 comments on commit 60dfebe

Please sign in to comment.