Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Dec 11, 2024
1 parent 6d36e22 commit 3597ca4
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 16 deletions.
11 changes: 7 additions & 4 deletions R/CallbackEvaluation.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ CallbackEvaluation= R6Class("CallbackEvaluation",
#' Evaluation callbacks are called at different stages of the resampling process.
#' Each stage is called once per resampling iteration.
#' The stages are prefixed with `on_*`.
#' The text in brackets indicates what happens between the stages and which accesses to the [ContextEvaluation] `ctx` are typical for the stage.
#' The text in brackets indicates what happens between the stages and which accesses to the [ContextEvaluation] (`ctx`) are typical for the stage.
#'
#' ```
#' Start Resampling Iteration on Worker
Expand All @@ -64,9 +64,12 @@ CallbackEvaluation= R6Class("CallbackEvaluation",
#' @details
#' When implementing a callback, each function must have two arguments named `callback` and `context`.
#' A callback can write data to the state (`$state`), e.g. settings that affect the callback itself.
#' Evaluation callbacks access [ContextEvaluation].
#' Data can be stored in the [ResampleResult] and [BenchmarkResult] objects via `context$data_extra`.
#' Alternatively results can be stored in the learner state via `context$learner$state`.
#'
#' @section Parallelization:
#' Be careful when modifying `ctx$learner`, `ctx$task`, or `ctx$resampling` because callbacks can behave differently when parallelizing the resampling process.
#' When running the resampling process sequentially, the modifications are carried over to the next iteration.
#' When parallelizing the resampling process, modifying the [ContextEvaluation] will not be synchronized between workers.
#' This also applies to the `$state` of the callback.
#'
#' @param id (`character(1)`)\cr
#' Identifier for the new instance.
Expand Down
23 changes: 18 additions & 5 deletions R/ContextEvaluation.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@ ContextEvaluation = R6Class("ContextEvaluation",
#' The data is available on stage `on_evaluation_end`.
pdatas = NULL,

#' @field data_extra (list())\cr
#' Data saved in the [ResampleResult] or [BenchmarkResult].
#' Use this field to save results.
data_extra = NULL,

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
#'
Expand All @@ -59,5 +54,23 @@ ContextEvaluation = R6Class("ContextEvaluation",

super$initialize(id = "evaluate", label = "Evaluation")
}
),

active = list(

#' @field data_extra (list())\cr
#' Data saved in the [ResampleResult] or [BenchmarkResult].
#' Use this field to save results.
#' Must be a `list()`.
data_extra = function(rhs) {
if (missing(rhs)) {
return(private$.data_extra)
}
private$.data_extra = assert_list(rhs)
}
),

private = list(
.data_extra = NULL
)
)
2 changes: 1 addition & 1 deletion R/ResultData.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ ResultData = R6Class("ResultData",
# add extra data to fact table
if (!is.null(data_extra)) {
assert_list(data_extra, len = nrow(data))
set(data, j = "data_extra", value = data_extra)
set(data, j = "data_extra", value = list(data_extra))
}

if (!store_backends) {
Expand Down
10 changes: 8 additions & 2 deletions man/ContextEvaluation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 9 additions & 4 deletions man/callback_evaluation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

22 changes: 22 additions & 0 deletions tests/testthat/test_CallbackEvaluation.R
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,27 @@ test_that("writing to data_extra works", {
})
})

test_that("data_extra is a list column", {
task = tsk("pima")
learner = lrn("classif.rpart")
resampling = rsmp("holdout")

callback = callback_evaluation("test",
on_evaluation_end = function(callback, context) {
context$data_extra$test = 1
}
)

rr = resample(task, learner, resampling, callbacks = callback)
expect_list(as.data.table(rr, data_extra = TRUE)$data_extra, len = 1)
expect_list(as.data.table(rr, data_extra = TRUE)$data_extra[[1]], len = 1)

resampling = rsmp("cv", folds = 3)
rr = resample(task, learner, resampling, callbacks = callback)
expect_list(as.data.table(rr, data_extra = TRUE)$data_extra, len = 3)
expect_list(as.data.table(rr, data_extra = TRUE)$data_extra[[1]], len = 1)
})

test_that("data_extra is null", {
task = tsk("pima")
learner = lrn("classif.rpart")
Expand Down Expand Up @@ -161,3 +182,4 @@ test_that("data_extra is null", {
expect_data_table(tab)
expect_names(names(tab), disjunct.from = "data_extra")
})

0 comments on commit 3597ca4

Please sign in to comment.