From 5b460c22a5d423e2731e24f8a30ae6beb1245ec2 Mon Sep 17 00:00:00 2001 From: be-marc Date: Wed, 20 Nov 2024 13:10:44 +0100 Subject: [PATCH 01/54] ... --- R/CallbackResample.R | 113 ++++++++++++++++++++++++++++++++ R/ContextResample.R | 17 +++++ R/resample.R | 22 ++++++- R/worker.R | 21 +++++- inst/testthat/helper_autotest.R | 1 + 5 files changed, 169 insertions(+), 5 deletions(-) create mode 100644 R/CallbackResample.R create mode 100644 R/ContextResample.R diff --git a/R/CallbackResample.R b/R/CallbackResample.R new file mode 100644 index 000000000..6ac7cc4bd --- /dev/null +++ b/R/CallbackResample.R @@ -0,0 +1,113 @@ +#' @title Create Batch Tuning Callback +#' +#' @description +#' Specialized [bbotk::CallbackBatch] for batch tuning. +#' Callbacks allow to customize the behavior of processes in mlr3tuning. +#' The [callback_batch_tuning()] function creates a [CallbackResample]. +#' Predefined callbacks are stored in the [dictionary][mlr3misc::Dictionary] [mlr_callbacks] and can be retrieved with [clbk()]. +#' For more information on tuning callbacks see [callback_batch_tuning()]. +#' +#' @export +CallbackResample= R6Class("CallbackResample", + inherit = Callback, + public = list( + + #' @field on_resample_before_result_data (`function()`)\cr + #' Stage called before the result data is created. + on_resample_before_result_data = NULL + ) +) + +#' @title Create Batch Tuning Callback +#' +#' @description +#' Function to create a [CallbackResample]. +#' Predefined callbacks are stored in the [dictionary][mlr3misc::Dictionary] [mlr_callbacks] and can be retrieved with [clbk()]. +#' +#' Tuning callbacks can be called from different stages of the tuning process. +#' The stages are prefixed with `on_*`. +#' +#' ``` +#' Start Tuning +#' - on_optimization_begin +#' Start Tuner Batch +#' - on_optimizer_before_eval +#' Start Evaluation +#' - on_eval_after_design +#' - on_eval_after_benchmark +#' - on_eval_before_archive +#' End Evaluation +#' - on_optimizer_after_eval +#' End Tuner Batch +#' - on_tuning_result_begin +#' - on_result_begin +#' - on_result_end +#' - on_optimization_end +#' End Tuning +#' ``` +#' +#' See also the section on parameters for more information on the stages. +#' A tuning callback works with [ContextBatchTuning]. +#' +#' @details +#' When implementing a callback, each function must have two arguments named `callback` and `context`. +#' A callback can write data to the state (`$state`), e.g. settings that affect the callback itself. +#' Tuning callbacks access [ContextBatchTuning]. +#' +#' @param id (`character(1)`)\cr +#' Identifier for the new instance. +#' @param label (`character(1)`)\cr +#' Label for the new instance. +#' @param man (`character(1)`)\cr +#' String in the format `[pkg]::[topic]` pointing to a manual page for this object. +#' The referenced help package can be opened via method `$help()`. +#' +#' @param on_optimization_begin (`function()`)\cr +#' Stage called at the beginning of the optimization. +#' Called in `Optimizer$optimize()`. +#' The functions must have two arguments named `callback` and `context`. 
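# Illustrative sketch (not part of the committed patch): a minimal use of the
# single stage introduced here. `callback_resample()` and the
# `on_resample_before_result_data` stage are defined in this file; the call
# mirrors the example script `inst/test.r` added in the next commit.
callback = callback_resample(
  id = "print_result_data",
  on_resample_before_result_data = function(callback, context) {
    # `context` is a ContextResample; `context$data` holds the result table
    # assembled by resample() before it is turned into a ResultData object.
    print(names(context$data))
  }
)
resample(tsk("iris"), lrn("classif.rpart"), rsmp("cv", folds = 3), callbacks = callback)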
+#' +#' @export +#' @inherit CallbackResample examples +callback_resample = function( + id, + label = NA_character_, + man = NA_character_, + on_resample_before_result_data = NULL + ) { + stages = discard(set_names(list( + on_resample_before_result_data), + c( + "on_resample_before_result_data" + )), is.null) + + walk(stages, function(stage) assert_function(stage, args = c("callback", "context"))) + callback = CallbackResample$new(id, label, man) + iwalk(stages, function(stage, name) callback[[name]] = stage) + callback +} + +#' @title Assertions for Callbacks +#' +#' @description +#' Assertions for [CallbackResample] class. +#' +#' @param callback ([CallbackResample]). +#' @param null_ok (`logical(1)`)\cr +#' If `TRUE`, `NULL` is allowed. +#' +#' @return [CallbackResample | List of [CallbackResample]s. +#' @export +assert_resample_callback = function(callback, null_ok = FALSE) { + if (null_ok && is.null(callback)) return(invisible(NULL)) + browser() + assert_class(callback, "CallbackResample") + invisible(callback) +} + +#' @export +#' @param callbacks (list of [CallbackResample]). +#' @rdname assert_resample_callback +assert_resample_callbacks = function(callbacks) { + invisible(lapply(callbacks, assert_resample_callback)) +} diff --git a/R/ContextResample.R b/R/ContextResample.R new file mode 100644 index 000000000..2f86aa8dd --- /dev/null +++ b/R/ContextResample.R @@ -0,0 +1,17 @@ +#' @title Batch Tuning Context +#' +#' @description +#' A [CallbackResample] accesses and modifies data during the optimization via the `ContextBatchTuning`. +#' See the section on active bindings for a list of modifiable objects. +#' See [callback_batch_tuning()] for a list of stages that access `ContextBatchTuning`. +#' +#' @template param_inst_batch +#' @template param_tuner +#' +#' @export +ContextResample = R6Class("ContextResample", + inherit = Context, + public = list( + data = NULL + ) +) diff --git a/R/resample.R b/R/resample.R index cc1bb88f2..a6a6523cb 100644 --- a/R/resample.R +++ b/R/resample.R @@ -55,7 +55,21 @@ #' bmr1 = as_benchmark_result(rr) #' bmr2 = as_benchmark_result(rr_featureless) #' print(bmr1$combine(bmr2)) -resample = function(task, learner, resampling, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), unmarshal = TRUE) { +resample = function( + task, + learner, + resampling, + store_models = FALSE, + store_backends = TRUE, + encapsulate = NA_character_, + allow_hotstart = FALSE, + clone = c("task", "learner", "resampling"), + unmarshal = TRUE, + callbacks = NULL + ) { + callbacks = assert_resample_callbacks(as_callbacks(callbacks)) + context = ContextResample$new("resample") + assert_subset(clone, c("task", "learner", "resampling")) task = assert_task(as_task(task, clone = "task" %in% clone)) learner = assert_learner(as_learner(learner, clone = "learner" %in% clone, discard_state = TRUE)) @@ -118,7 +132,7 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal) ) - data = data.table( + context$data = data.table( task = list(task), learner = grid$learner, learner_state = map(res, "learner_state"), @@ -130,7 +144,9 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe learner_hash = map_chr(res, "learner_hash") ) - result_data = ResultData$new(data, store_backends = store_backends) + 
call_back("on_resample_before_result_data", callbacks, context) + + result_data = ResultData$new(context$data, store_backends = store_backends) # the worker already ensures that models are sent back in marshaled form if unmarshal = FALSE, so we don't have # to do anything in this case. This allows us to minimize the amount of marshaling in those situtions where diff --git a/R/worker.R b/R/worker.R index e52daa23f..b06ffd39d 100644 --- a/R/worker.R +++ b/R/worker.R @@ -1,4 +1,4 @@ -learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NULL, mode = "train") { +learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NULL, mode = "train", callback, context) { # This wrapper calls learner$.train, and additionally performs some basic # checks that the training was successful. # Exceptions here are possibly encapsulated, so that they get captured @@ -251,7 +251,20 @@ learner_predict = function(learner, task, row_ids = NULL) { } -workhorse = function(iteration, task, learner, resampling, param_values = NULL, lgr_threshold, store_models = FALSE, pb = NULL, mode = "train", is_sequential = TRUE, unmarshal = TRUE) { +workhorse = function( + iteration, + task, + learner, + resampling, + param_values = NULL, + lgr_threshold, + store_models = FALSE, + pb = NULL, + mode = "train", + is_sequential = TRUE, + unmarshal = TRUE, + callback = NULL, + ) { if (!is.null(pb)) { pb(sprintf("%s|%s|i:%i", task$id, learner$id, iteration)) } @@ -332,6 +345,10 @@ workhorse = function(iteration, task, learner, resampling, param_values = NULL, } pdatas = discard(pdatas, is.null) + if (!is.null(callback)) { + learner_state = c(learner_state, assert_list(callback(learner$model))) + } + # set the model slot after prediction so it can be sent back to the main process process_model_after_predict( learner = learner, store_models = store_models, is_sequential = is_sequential, model_copy = model_copy_or_null, diff --git a/inst/testthat/helper_autotest.R b/inst/testthat/helper_autotest.R index 377e7d152..921e51fd6 100644 --- a/inst/testthat/helper_autotest.R +++ b/inst/testthat/helper_autotest.R @@ -490,6 +490,7 @@ run_autotest = function(learner, N = 30L, exclude = NULL, predict_types = learne if (predict_type == "quantiles") { learner$quantiles = 0.5 + browser() } run = run_experiment(task, learner) From af87ce70390e22dbb303538ea6647c78ce453c31 Mon Sep 17 00:00:00 2001 From: be-marc Date: Wed, 20 Nov 2024 13:53:51 +0100 Subject: [PATCH 02/54] ... --- inst/test.r | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 inst/test.r diff --git a/inst/test.r b/inst/test.r new file mode 100644 index 000000000..a4571198b --- /dev/null +++ b/inst/test.r @@ -0,0 +1,13 @@ +# Callback +callback = callback_resample( + id = "test", + on_resample_before_result_data = function(callback, context) { + print("on_resample_before_result_data") + } +) + +learner = lrn("classif.rpart") +task = tsk("iris") +resampling = rsmp("cv", folds = 3) + +resample(task, learner, resampling, callbacks = callback) From 6a97bf0b0379c50be0e38231e14f5a9949ff8634 Mon Sep 17 00:00:00 2001 From: be-marc Date: Thu, 21 Nov 2024 13:11:56 +0100 Subject: [PATCH 03/54] ... 
--- DESCRIPTION | 2 + NAMESPACE | 5 + R/CallbackResample.R | 113 -------------------- R/CallbackWorkhorse.R | 93 ++++++++++++++++ R/{ContextResample.R => ContextWorkhorse.R} | 7 +- R/resample.R | 11 +- R/worker.R | 15 ++- man/CallbackWorkhorse.Rd | 46 ++++++++ man/ContextWorkhorse.Rd | 46 ++++++++ man/Task.Rd | 6 +- man/assert_workhorse_callback.Rd | 25 +++++ man/callback_workhorse.Rd | 41 +++++++ man/resample.Rd | 3 +- 13 files changed, 280 insertions(+), 133 deletions(-) delete mode 100644 R/CallbackResample.R create mode 100644 R/CallbackWorkhorse.R rename R/{ContextResample.R => ContextWorkhorse.R} (75%) create mode 100644 man/CallbackWorkhorse.Rd create mode 100644 man/ContextWorkhorse.Rd create mode 100644 man/assert_workhorse_callback.Rd create mode 100644 man/callback_workhorse.Rd diff --git a/DESCRIPTION b/DESCRIPTION index 28236d721..e1742a859 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -79,6 +79,8 @@ RoxygenNote: 7.3.2 Collate: 'mlr_reflections.R' 'BenchmarkResult.R' + 'CallbackWorkhorse.R' + 'ContextWorkhorse.R' 'warn_deprecated.R' 'DataBackend.R' 'DataBackendCbind.R' diff --git a/NAMESPACE b/NAMESPACE index eadee06bf..54c5e2316 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -120,6 +120,8 @@ S3method(unmarshal_model,classif.debug_model_marshaled) S3method(unmarshal_model,default) S3method(unmarshal_model,learner_state_marshaled) export(BenchmarkResult) +export(CallbackWorkhorse) +export(ContextWorkhorse) export(DataBackend) export(DataBackendDataTable) export(DataBackendMatrix) @@ -214,9 +216,12 @@ export(assert_row_ids) export(assert_task) export(assert_tasks) export(assert_validate) +export(assert_workhorse_callback) +export(assert_workhorse_callbacks) export(auto_convert) export(benchmark) export(benchmark_grid) +export(callback_workhorse) export(check_prediction_data) export(col_info) export(convert_task) diff --git a/R/CallbackResample.R b/R/CallbackResample.R deleted file mode 100644 index 6ac7cc4bd..000000000 --- a/R/CallbackResample.R +++ /dev/null @@ -1,113 +0,0 @@ -#' @title Create Batch Tuning Callback -#' -#' @description -#' Specialized [bbotk::CallbackBatch] for batch tuning. -#' Callbacks allow to customize the behavior of processes in mlr3tuning. -#' The [callback_batch_tuning()] function creates a [CallbackResample]. -#' Predefined callbacks are stored in the [dictionary][mlr3misc::Dictionary] [mlr_callbacks] and can be retrieved with [clbk()]. -#' For more information on tuning callbacks see [callback_batch_tuning()]. -#' -#' @export -CallbackResample= R6Class("CallbackResample", - inherit = Callback, - public = list( - - #' @field on_resample_before_result_data (`function()`)\cr - #' Stage called before the result data is created. - on_resample_before_result_data = NULL - ) -) - -#' @title Create Batch Tuning Callback -#' -#' @description -#' Function to create a [CallbackResample]. -#' Predefined callbacks are stored in the [dictionary][mlr3misc::Dictionary] [mlr_callbacks] and can be retrieved with [clbk()]. -#' -#' Tuning callbacks can be called from different stages of the tuning process. -#' The stages are prefixed with `on_*`. 
-#' -#' ``` -#' Start Tuning -#' - on_optimization_begin -#' Start Tuner Batch -#' - on_optimizer_before_eval -#' Start Evaluation -#' - on_eval_after_design -#' - on_eval_after_benchmark -#' - on_eval_before_archive -#' End Evaluation -#' - on_optimizer_after_eval -#' End Tuner Batch -#' - on_tuning_result_begin -#' - on_result_begin -#' - on_result_end -#' - on_optimization_end -#' End Tuning -#' ``` -#' -#' See also the section on parameters for more information on the stages. -#' A tuning callback works with [ContextBatchTuning]. -#' -#' @details -#' When implementing a callback, each function must have two arguments named `callback` and `context`. -#' A callback can write data to the state (`$state`), e.g. settings that affect the callback itself. -#' Tuning callbacks access [ContextBatchTuning]. -#' -#' @param id (`character(1)`)\cr -#' Identifier for the new instance. -#' @param label (`character(1)`)\cr -#' Label for the new instance. -#' @param man (`character(1)`)\cr -#' String in the format `[pkg]::[topic]` pointing to a manual page for this object. -#' The referenced help package can be opened via method `$help()`. -#' -#' @param on_optimization_begin (`function()`)\cr -#' Stage called at the beginning of the optimization. -#' Called in `Optimizer$optimize()`. -#' The functions must have two arguments named `callback` and `context`. -#' -#' @export -#' @inherit CallbackResample examples -callback_resample = function( - id, - label = NA_character_, - man = NA_character_, - on_resample_before_result_data = NULL - ) { - stages = discard(set_names(list( - on_resample_before_result_data), - c( - "on_resample_before_result_data" - )), is.null) - - walk(stages, function(stage) assert_function(stage, args = c("callback", "context"))) - callback = CallbackResample$new(id, label, man) - iwalk(stages, function(stage, name) callback[[name]] = stage) - callback -} - -#' @title Assertions for Callbacks -#' -#' @description -#' Assertions for [CallbackResample] class. -#' -#' @param callback ([CallbackResample]). -#' @param null_ok (`logical(1)`)\cr -#' If `TRUE`, `NULL` is allowed. -#' -#' @return [CallbackResample | List of [CallbackResample]s. -#' @export -assert_resample_callback = function(callback, null_ok = FALSE) { - if (null_ok && is.null(callback)) return(invisible(NULL)) - browser() - assert_class(callback, "CallbackResample") - invisible(callback) -} - -#' @export -#' @param callbacks (list of [CallbackResample]). -#' @rdname assert_resample_callback -assert_resample_callbacks = function(callbacks) { - invisible(lapply(callbacks, assert_resample_callback)) -} diff --git a/R/CallbackWorkhorse.R b/R/CallbackWorkhorse.R new file mode 100644 index 000000000..7c00c10cd --- /dev/null +++ b/R/CallbackWorkhorse.R @@ -0,0 +1,93 @@ +#' @title Create Workhorse Callback +#' +#' @description +#' Callbacks allow to customize the behavior of processes in mlr3. +#' +#' @export +CallbackWorkhorse= R6Class("CallbackWorkhorse", + inherit = Callback, + public = list( + + on_workhorse_before_train = NULL, + + on_workhorse_before_predict = NULL, + + on_workhorse_before_result = NULL + ) +) + +#' @title Create Workhorse Callback +#' +#' @description +#' Function to create a [CallbackWorkhorse]. +#' +#' ``` +#' Start Workhorse +#' - on_workhorse_before_train +#' - on_workhorse_before_predict +#' - on_workhorse_before_result +#' End Tuning +#' ``` +# +#' @details +#' When implementing a callback, each function must have two arguments named `callback` and `context`. 
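# Illustrative sketch (not part of the committed patch): every stage function
# must have two arguments named `callback` and `context`, and a callback can
# keep its own settings or intermediate values in `callback$state`. The
# constructor and stage name are taken from this patch; the body is invented
# for illustration.
callback = callback_workhorse(
  id = "count_train_calls",
  on_workhorse_before_train = function(callback, context) {
    # `context` is a ContextWorkhorse; `context$env` exposes the worker environment.
    callback$state$n_calls = (callback$state$n_calls %||% 0L) + 1L
  }
)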
+#' A callback can write data to the state (`$state`), e.g. settings that affect the callback itself. +#' Workhorse callbacks access [ContextWorkhorse]. +#' +#' @param id (`character(1)`)\cr +#' Identifier for the new instance. +#' @param label (`character(1)`)\cr +#' Label for the new instance. +#' @param man (`character(1)`)\cr +#' String in the format `[pkg]::[topic]` pointing to a manual page for this object. +#' The referenced help package can be opened via method `$help()`. +#' +#' @export +#' @inherit CallbackWorkhorse examples +callback_workhorse = function( + id, + label = NA_character_, + man = NA_character_, + on_workhorse_before_train = NULL, + on_workhorse_before_predict = NULL, + on_workhorse_before_result = NULL + ) { + stages = discard(set_names(list( + on_workhorse_before_train, + on_workhorse_before_predict, + on_workhorse_before_result), + c( + "on_workhorse_before_train", + "on_workhorse_before_predict", + "on_workhorse_before_result" + )), is.null) + + walk(stages, function(stage) assert_function(stage, args = c("callback", "context"))) + callback = CallbackWorkhorse$new(id, label, man) + iwalk(stages, function(stage, name) callback[[name]] = stage) + callback +} + +#' @title Assertions for Callbacks +#' +#' @description +#' Assertions for [CallbackWorkhorse] class. +#' +#' @param callback ([CallbackWorkhorse]). +#' @param null_ok (`logical(1)`)\cr +#' If `TRUE`, `NULL` is allowed. +#' +#' @return [CallbackWorkhorse | List of [CallbackWorkhorse]s. +#' @export +assert_workhorse_callback = function(callback, null_ok = FALSE) { + if (null_ok && is.null(callback)) return(invisible(NULL)) + assert_class(callback, "CallbackWorkhorse") + invisible(callback) +} + +#' @export +#' @param callbacks (list of [CallbackWorkhorse]). +#' @rdname assert_workhorse_callback +assert_workhorse_callbacks = function(callbacks) { + invisible(lapply(callbacks, assert_workhorse_callback)) +} diff --git a/R/ContextResample.R b/R/ContextWorkhorse.R similarity index 75% rename from R/ContextResample.R rename to R/ContextWorkhorse.R index 2f86aa8dd..6ec1e9cbd 100644 --- a/R/ContextResample.R +++ b/R/ContextWorkhorse.R @@ -5,13 +5,10 @@ #' See the section on active bindings for a list of modifiable objects. #' See [callback_batch_tuning()] for a list of stages that access `ContextBatchTuning`. 
#' -#' @template param_inst_batch -#' @template param_tuner -#' #' @export -ContextResample = R6Class("ContextResample", +ContextWorkhorse = R6Class("ContextResample", inherit = Context, public = list( - data = NULL + env = NULL ) ) diff --git a/R/resample.R b/R/resample.R index a6a6523cb..6ecdae27b 100644 --- a/R/resample.R +++ b/R/resample.R @@ -67,8 +67,7 @@ resample = function( unmarshal = TRUE, callbacks = NULL ) { - callbacks = assert_resample_callbacks(as_callbacks(callbacks)) - context = ContextResample$new("resample") + callbacks = assert_workhorse_callbacks(as_callbacks(callbacks)) assert_subset(clone, c("task", "learner", "resampling")) task = assert_task(as_task(task, clone = "task" %in% clone)) @@ -129,10 +128,10 @@ resample = function( } res = future_map(n, workhorse, iteration = seq_len(n), learner = grid$learner, mode = grid$mode, - MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal) + MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal, callbacks = callbacks) ) - context$data = data.table( + data = data.table( task = list(task), learner = grid$learner, learner_state = map(res, "learner_state"), @@ -144,9 +143,7 @@ resample = function( learner_hash = map_chr(res, "learner_hash") ) - call_back("on_resample_before_result_data", callbacks, context) - - result_data = ResultData$new(context$data, store_backends = store_backends) + result_data = ResultData$new(data, store_backends = store_backends) # the worker already ensures that models are sent back in marshaled form if unmarshal = FALSE, so we don't have # to do anything in this case. This allows us to minimize the amount of marshaling in those situtions where diff --git a/R/worker.R b/R/worker.R index b06ffd39d..9a2fa938c 100644 --- a/R/worker.R +++ b/R/worker.R @@ -263,8 +263,11 @@ workhorse = function( mode = "train", is_sequential = TRUE, unmarshal = TRUE, - callback = NULL, + callbacks = NULL ) { + context = ContextWorkhorse$new("workhorse") + context$env = environment() + if (!is.null(pb)) { pb(sprintf("%s|%s|i:%i", task$id, learner$id, iteration)) } @@ -319,6 +322,9 @@ workhorse = function( validate = get0("validate", learner) test_set = if (identical(validate, "test")) sets$test + + call_back("on_workhorse_before_train", callbacks, context) + train_result = learner_train(learner, task, sets[["train"]], test_set, mode = mode) learner = train_result$learner @@ -337,6 +343,9 @@ workhorse = function( pdatas = Map(function(set, row_ids, task) { lg$debug("Creating Prediction for predict set '%s'", set) + + call_back("on_workhorse_before_predict", callbacks, context) + learner_predict(learner, task, row_ids) }, set = predict_sets, row_ids = pred_data$sets, task = pred_data$tasks) @@ -345,9 +354,7 @@ workhorse = function( } pdatas = discard(pdatas, is.null) - if (!is.null(callback)) { - learner_state = c(learner_state, assert_list(callback(learner$model))) - } + call_back("on_workhorse_before_result", callbacks, context) # set the model slot after prediction so it can be sent back to the main process process_model_after_predict( diff --git a/man/CallbackWorkhorse.Rd b/man/CallbackWorkhorse.Rd new file mode 100644 index 000000000..c07b2c9cf --- /dev/null +++ b/man/CallbackWorkhorse.Rd @@ -0,0 +1,46 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/CallbackWorkhorse.R +\name{CallbackWorkhorse} 
+\alias{CallbackWorkhorse} +\title{Create Workhorse Callback} +\description{ +Callbacks allow to customize the behavior of processes in mlr3. +} +\section{Super class}{ +\code{\link[mlr3misc:Callback]{mlr3misc::Callback}} -> \code{CallbackWorkhorse} +} +\section{Methods}{ +\subsection{Public methods}{ +\itemize{ +\item \href{#method-CallbackWorkhorse-clone}{\code{CallbackWorkhorse$clone()}} +} +} +\if{html}{\out{ +
Inherited methods + +
+}} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-CallbackWorkhorse-clone}{}}} +\subsection{Method \code{clone()}}{ +The objects of this class are cloneable with this method. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{CallbackWorkhorse$clone(deep = FALSE)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{deep}}{Whether to make a deep clone.} +} +\if{html}{\out{
}} +} +} +} diff --git a/man/ContextWorkhorse.Rd b/man/ContextWorkhorse.Rd new file mode 100644 index 000000000..5b0503263 --- /dev/null +++ b/man/ContextWorkhorse.Rd @@ -0,0 +1,46 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/ContextWorkhorse.R +\name{ContextWorkhorse} +\alias{ContextWorkhorse} +\title{Batch Tuning Context} +\description{ +A \link{CallbackResample} accesses and modifies data during the optimization via the \code{ContextBatchTuning}. +See the section on active bindings for a list of modifiable objects. +See \code{\link[=callback_batch_tuning]{callback_batch_tuning()}} for a list of stages that access \code{ContextBatchTuning}. +} +\section{Super class}{ +\code{\link[mlr3misc:Context]{mlr3misc::Context}} -> \code{ContextResample} +} +\section{Methods}{ +\subsection{Public methods}{ +\itemize{ +\item \href{#method-ContextResample-clone}{\code{ContextWorkhorse$clone()}} +} +} +\if{html}{\out{ +
Inherited methods + +
+}} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ContextResample-clone}{}}} +\subsection{Method \code{clone()}}{ +The objects of this class are cloneable with this method. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ContextWorkhorse$clone(deep = FALSE)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{deep}}{Whether to make a deep clone.} +} +\if{html}{\out{
}} +} +} +} diff --git a/man/Task.Rd b/man/Task.Rd index f45a41ddc..4e925c1ad 100644 --- a/man/Task.Rd +++ b/man/Task.Rd @@ -147,9 +147,6 @@ Required for \code{\link[=convert_task]{convert_task()}}.} \item{\code{mlr3_version}}{(\code{package_version})\cr Package version of \code{mlr3} used to create the task.} - -\item{\code{characteristics}}{(\code{list()})\cr -List of characteristics of the task, e.g. \code{list(n = 5, p = 7)}.} } \if{html}{\out{}} } @@ -303,6 +300,9 @@ Alternatively, you can provide a \code{\link[=data.frame]{data.frame()}} with th Hash (unique identifier) for all columns except the \code{primary_key}: A \code{character} vector, named by the columns that each element refers to.\cr Columns of different \code{\link{Task}}s or \code{\link{DataBackend}}s that have agreeing \code{col_hashes} always represent the same data, given that the same \code{row}s are selected. The reverse is not necessarily true: There can be columns with the same content that have different \code{col_hashes}.} + +\item{\code{characteristics}}{(\code{list()})\cr +List of characteristics of the task, e.g. \code{list(n = 5, p = 7)}.} } \if{html}{\out{}} } diff --git a/man/assert_workhorse_callback.Rd b/man/assert_workhorse_callback.Rd new file mode 100644 index 000000000..a5add6c9b --- /dev/null +++ b/man/assert_workhorse_callback.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/CallbackWorkhorse.R +\name{assert_workhorse_callback} +\alias{assert_workhorse_callback} +\alias{assert_workhorse_callbacks} +\title{Assertions for Callbacks} +\usage{ +assert_workhorse_callback(callback, null_ok = FALSE) + +assert_workhorse_callbacks(callbacks) +} +\arguments{ +\item{callback}{(\link{CallbackWorkhorse}).} + +\item{null_ok}{(\code{logical(1)})\cr +If \code{TRUE}, \code{NULL} is allowed.} + +\item{callbacks}{(list of \link{CallbackWorkhorse}).} +} +\value{ +[CallbackWorkhorse | List of \link{CallbackWorkhorse}s. +} +\description{ +Assertions for \link{CallbackWorkhorse} class. +} diff --git a/man/callback_workhorse.Rd b/man/callback_workhorse.Rd new file mode 100644 index 000000000..e258d0ab3 --- /dev/null +++ b/man/callback_workhorse.Rd @@ -0,0 +1,41 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/CallbackWorkhorse.R +\name{callback_workhorse} +\alias{callback_workhorse} +\title{Create Workhorse Callback} +\usage{ +callback_workhorse( + id, + label = NA_character_, + man = NA_character_, + on_workhorse_before_train = NULL, + on_workhorse_before_predict = NULL, + on_workhorse_before_result = NULL +) +} +\arguments{ +\item{id}{(\code{character(1)})\cr +Identifier for the new instance.} + +\item{label}{(\code{character(1)})\cr +Label for the new instance.} + +\item{man}{(\code{character(1)})\cr +String in the format \verb{[pkg]::[topic]} pointing to a manual page for this object. +The referenced help package can be opened via method \verb{$help()}.} +} +\description{ +Function to create a \link{CallbackWorkhorse}. + +\if{html}{\out{
}}\preformatted{Start Workhorse + - on_workhorse_before_train + - on_workhorse_before_predict + - on_workhorse_before_result +End Tuning +}\if{html}{\out{
}} +} +\details{ +When implementing a callback, each function must have two arguments named \code{callback} and \code{context}. +A callback can write data to the state (\verb{$state}), e.g. settings that affect the callback itself. +Workhorse callbacks access \link{ContextWorkhorse}. +} diff --git a/man/resample.Rd b/man/resample.Rd index b972108ef..00378067e 100644 --- a/man/resample.Rd +++ b/man/resample.Rd @@ -13,7 +13,8 @@ resample( encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), - unmarshal = TRUE + unmarshal = TRUE, + callbacks = NULL ) } \arguments{ From c6ae6bca28ed219c2f3f4e03154080b84d110646 Mon Sep 17 00:00:00 2001 From: be-marc Date: Thu, 21 Nov 2024 13:13:14 +0100 Subject: [PATCH 04/54] ... --- inst/testthat/helper_autotest.R | 1 - 1 file changed, 1 deletion(-) diff --git a/inst/testthat/helper_autotest.R b/inst/testthat/helper_autotest.R index 921e51fd6..377e7d152 100644 --- a/inst/testthat/helper_autotest.R +++ b/inst/testthat/helper_autotest.R @@ -490,7 +490,6 @@ run_autotest = function(learner, N = 30L, exclude = NULL, predict_types = learne if (predict_type == "quantiles") { learner$quantiles = 0.5 - browser() } run = run_experiment(task, learner) From 211b9cfa48c39d8f77ee2be754eb1c74dd480c32 Mon Sep 17 00:00:00 2001 From: be-marc Date: Tue, 26 Nov 2024 22:36:54 +0100 Subject: [PATCH 05/54] ... --- DESCRIPTION | 4 +- NAMESPACE | 10 +- R/CallbackEvaluation.R | 110 +++++++++++++++++ R/CallbackWorkhorse.R | 93 -------------- R/ContextEvaluation.R | 73 +++++++++++ R/ContextWorkhorse.R | 14 --- R/resample.R | 2 +- R/worker.R | 47 +++---- ...backWorkhorse.Rd => CallbackEvaluation.Rd} | 41 +++++-- man/ContextEvaluation.Rd | 110 +++++++++++++++++ man/ContextWorkhorse.Rd | 46 ------- man/assert_evaluation_callback.Rd | 25 ++++ man/assert_workhorse_callback.Rd | 25 ---- ...ck_workhorse.Rd => callback_evaluation.Rd} | 30 ++--- tests/testthat/test_CallbackEvaluation.R | 116 ++++++++++++++++++ 15 files changed, 514 insertions(+), 232 deletions(-) create mode 100644 R/CallbackEvaluation.R delete mode 100644 R/CallbackWorkhorse.R create mode 100644 R/ContextEvaluation.R delete mode 100644 R/ContextWorkhorse.R rename man/{CallbackWorkhorse.Rd => CallbackEvaluation.Rd} (56%) create mode 100644 man/ContextEvaluation.Rd delete mode 100644 man/ContextWorkhorse.Rd create mode 100644 man/assert_evaluation_callback.Rd delete mode 100644 man/assert_workhorse_callback.Rd rename man/{callback_workhorse.Rd => callback_evaluation.Rd} (56%) create mode 100644 tests/testthat/test_CallbackEvaluation.R diff --git a/DESCRIPTION b/DESCRIPTION index e1742a859..59cf400ef 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -79,8 +79,8 @@ RoxygenNote: 7.3.2 Collate: 'mlr_reflections.R' 'BenchmarkResult.R' - 'CallbackWorkhorse.R' - 'ContextWorkhorse.R' + 'CallbackEvaluation.R' + 'ContextEvaluation.R' 'warn_deprecated.R' 'DataBackend.R' 'DataBackendCbind.R' diff --git a/NAMESPACE b/NAMESPACE index 54c5e2316..20e3e3f8c 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -120,8 +120,8 @@ S3method(unmarshal_model,classif.debug_model_marshaled) S3method(unmarshal_model,default) S3method(unmarshal_model,learner_state_marshaled) export(BenchmarkResult) -export(CallbackWorkhorse) -export(ContextWorkhorse) +export(CallbackEvaluation) +export(ContextEvaluation) export(DataBackend) export(DataBackendDataTable) export(DataBackendMatrix) @@ -202,6 +202,8 @@ export(as_tasks) export(as_tasks_unsupervised) export(assert_backend) export(assert_benchmark_result) 
+export(assert_evaluation_callback)
+export(assert_evaluation_callbacks)
 export(assert_learnable)
 export(assert_learner)
 export(assert_learners)
@@ -216,12 +218,10 @@ export(assert_row_ids)
 export(assert_task)
 export(assert_tasks)
 export(assert_validate)
-export(assert_workhorse_callback)
-export(assert_workhorse_callbacks)
 export(auto_convert)
 export(benchmark)
 export(benchmark_grid)
-export(callback_workhorse)
+export(callback_evaluation)
 export(check_prediction_data)
 export(col_info)
 export(convert_task)
diff --git a/R/CallbackEvaluation.R b/R/CallbackEvaluation.R
new file mode 100644
index 000000000..3b8834096
--- /dev/null
+++ b/R/CallbackEvaluation.R
@@ -0,0 +1,110 @@
+#' @title Create Evaluation Callback
+#'
+#' @description
+#' Callbacks allow customizing the behavior of `resample()` and `benchmark()` in mlr3.
+#'
+#' @export
+CallbackEvaluation = R6Class("CallbackEvaluation",
+  inherit = Callback,
+  public = list(
+
+    #' @field on_evaluation_begin (`function()`)\cr
+    #' Stage called at the beginning of an evaluation.
+    #' Called in `workhorse()` (internal).
+    on_evaluation_begin = NULL,
+
+    #' @field on_evaluation_before_train (`function()`)\cr
+    #' Stage called before training the learner.
+    #' Called in `workhorse()` (internal).
+    on_evaluation_before_train = NULL,
+
+    #' @field on_evaluation_before_predict (`function()`)\cr
+    #' Stage called before predicting.
+    #' Called in `workhorse()` (internal).
+    on_evaluation_before_predict = NULL,
+
+    #' @field on_evaluation_end (`function()`)\cr
+    #' Stage called at the end of an evaluation.
+    #' Called in `workhorse()` (internal).
+    on_evaluation_end = NULL
+  )
+)
+
+#' @title Create Evaluation Callback
+#'
+#' @description
+#' Function to create a [CallbackEvaluation].
+#'
+#' ```
+#' Start Evaluation on Worker
+#'  - on_evaluation_begin
+#'  - on_evaluation_before_train
+#'  - on_evaluation_before_predict
+#'  - on_evaluation_end
+#' End Evaluation on Worker
+#' ```
+#'
+#' @details
+#' When implementing a callback, each function must have two arguments named `callback` and `context`.
+#' A callback can write data to the state (`$state`), e.g. settings that affect the callback itself.
+#' Evaluation callbacks access [ContextEvaluation].
+#'
+#' @param id (`character(1)`)\cr
+#' Identifier for the new instance.
+#' @param label (`character(1)`)\cr
+#' Label for the new instance.
+#' @param man (`character(1)`)\cr
+#' String in the format `[pkg]::[topic]` pointing to a manual page for this object.
+#' The referenced help package can be opened via method `$help()`.
+#'
+#' @export
+callback_evaluation = function(
+  id,
+  label = NA_character_,
+  man = NA_character_,
+  on_evaluation_begin = NULL,
+  on_evaluation_before_train = NULL,
+  on_evaluation_before_predict = NULL,
+  on_evaluation_end = NULL
+  ) {
+  stages = discard(set_names(list(
+    on_evaluation_begin,
+    on_evaluation_before_train,
+    on_evaluation_before_predict,
+    on_evaluation_end),
+    c(
+      "on_evaluation_begin",
+      "on_evaluation_before_train",
+      "on_evaluation_before_predict",
+      "on_evaluation_end"
+    )), is.null)
+
+  walk(stages, function(stage) assert_function(stage, args = c("callback", "context")))
+  callback = CallbackEvaluation$new(id, label, man)
+  iwalk(stages, function(stage, name) callback[[name]] = stage)
+  callback
+}
+
+#' @title Assertions for Callbacks
+#'
+#' @description
+#' Assertions for the [CallbackEvaluation] class.
+#'
+#' @param callback ([CallbackEvaluation]).
+#' @param null_ok (`logical(1)`)\cr
+#' If `TRUE`, `NULL` is allowed.
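# Illustrative sketch (not part of the committed patch): an evaluation callback
# that scores a measure on the worker and stores the result in the learner
# state, in the spirit of the `mlr3.score_measures` callback added in a later
# commit. The objects used here (`callback_evaluation()`, `context$pdatas`,
# `context$task`, `context$learner`) are the ones introduced in this patch.
callback = callback_evaluation(
  id = "score_on_worker",
  on_evaluation_end = function(callback, context) {
    # `context$pdatas$test` holds the PredictionData of the test set at this stage.
    pred = as_prediction(context$pdatas$test)
    context$learner$state$score = pred$score(msr("classif.ce"), context$task, context$learner)
  }
)
rr = resample(tsk("pima"), lrn("classif.rpart"), rsmp("cv", folds = 3), callbacks = callback)
rr$learners[[1]]$state$score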
+#' +#' @return [CallbackEvaluation | List of [CallbackEvaluation]s. +#' @export +assert_evaluation_callback = function(callback, null_ok = FALSE) { + if (null_ok && is.null(callback)) return(invisible(NULL)) + assert_class(callback, "CallbackEvaluation") + invisible(callback) +} + +#' @export +#' @param callbacks (list of [CallbackEvaluation]). +#' @rdname assert_evaluation_callback +assert_evaluation_callbacks = function(callbacks) { + invisible(lapply(callbacks, assert_evaluation_callback)) +} diff --git a/R/CallbackWorkhorse.R b/R/CallbackWorkhorse.R deleted file mode 100644 index 7c00c10cd..000000000 --- a/R/CallbackWorkhorse.R +++ /dev/null @@ -1,93 +0,0 @@ -#' @title Create Workhorse Callback -#' -#' @description -#' Callbacks allow to customize the behavior of processes in mlr3. -#' -#' @export -CallbackWorkhorse= R6Class("CallbackWorkhorse", - inherit = Callback, - public = list( - - on_workhorse_before_train = NULL, - - on_workhorse_before_predict = NULL, - - on_workhorse_before_result = NULL - ) -) - -#' @title Create Workhorse Callback -#' -#' @description -#' Function to create a [CallbackWorkhorse]. -#' -#' ``` -#' Start Workhorse -#' - on_workhorse_before_train -#' - on_workhorse_before_predict -#' - on_workhorse_before_result -#' End Tuning -#' ``` -# -#' @details -#' When implementing a callback, each function must have two arguments named `callback` and `context`. -#' A callback can write data to the state (`$state`), e.g. settings that affect the callback itself. -#' Workhorse callbacks access [ContextWorkhorse]. -#' -#' @param id (`character(1)`)\cr -#' Identifier for the new instance. -#' @param label (`character(1)`)\cr -#' Label for the new instance. -#' @param man (`character(1)`)\cr -#' String in the format `[pkg]::[topic]` pointing to a manual page for this object. -#' The referenced help package can be opened via method `$help()`. -#' -#' @export -#' @inherit CallbackWorkhorse examples -callback_workhorse = function( - id, - label = NA_character_, - man = NA_character_, - on_workhorse_before_train = NULL, - on_workhorse_before_predict = NULL, - on_workhorse_before_result = NULL - ) { - stages = discard(set_names(list( - on_workhorse_before_train, - on_workhorse_before_predict, - on_workhorse_before_result), - c( - "on_workhorse_before_train", - "on_workhorse_before_predict", - "on_workhorse_before_result" - )), is.null) - - walk(stages, function(stage) assert_function(stage, args = c("callback", "context"))) - callback = CallbackWorkhorse$new(id, label, man) - iwalk(stages, function(stage, name) callback[[name]] = stage) - callback -} - -#' @title Assertions for Callbacks -#' -#' @description -#' Assertions for [CallbackWorkhorse] class. -#' -#' @param callback ([CallbackWorkhorse]). -#' @param null_ok (`logical(1)`)\cr -#' If `TRUE`, `NULL` is allowed. -#' -#' @return [CallbackWorkhorse | List of [CallbackWorkhorse]s. -#' @export -assert_workhorse_callback = function(callback, null_ok = FALSE) { - if (null_ok && is.null(callback)) return(invisible(NULL)) - assert_class(callback, "CallbackWorkhorse") - invisible(callback) -} - -#' @export -#' @param callbacks (list of [CallbackWorkhorse]). 
-#' @rdname assert_workhorse_callback -assert_workhorse_callbacks = function(callbacks) { - invisible(lapply(callbacks, assert_workhorse_callback)) -} diff --git a/R/ContextEvaluation.R b/R/ContextEvaluation.R new file mode 100644 index 000000000..0aa7cc73e --- /dev/null +++ b/R/ContextEvaluation.R @@ -0,0 +1,73 @@ +#' @title Evaluation Context +#' +#' @description +#' A [CallbackEvaluation] accesses and modifies data during [resample()] and [benchmark()] via the `ContextEvaluation`. +#' See the section on fields for a list of modifiable objects. +#' See [callback_evaluation()] for a list of stages that access `ContextEvaluation`. +#' +#' @export +ContextEvaluation = R6Class("ContextEvaluation", + inherit = Context, + public = list( + + #' @field task ([Task])\cr + #' The task to be evaluated. + #' The task is unchanged during the evaluation. + task = NULL, + + #' @field learner ([Learner])\cr + #' The learner to be evaluated. + #' The learner contains the models after stage `on_evaluation_before_train`. + learner = NULL, + + #' @field resampling [Resampling]\cr + #' The resampling strategy to be used. + #' The resampling is unchanged during the evaluation. + resampling = NULL, + + #' @field param_values `list()`\cr + #' The parameter values to be used. + #' Is usually only set while tuning. + param_values = NULL, + + #' @field sets (`list()`)\cr + #' The train and test set. + #' The sets are available on stage `on_evaluation_before_train``. + sets = NULL, + + #' @field test_set (`integer()`)\cr + #' Validation test set. + #' The set is only available when using internal validation. + test_set = NULL, + + #' @field predict_sets (`list()`)\cr + #' The prediction sets stored in `learner$predict_sets`. + #' The sets are available on stage `on_evaluation_before_predict`. + predict_sets = NULL, + + #' @field pdatas (List of [PredictionData])\cr + #' The prediction data. + #' The data is available on stage `on_evaluation_end`. + pdatas = NULL, + + #' @description + #' Creates a new instance of this [R6][R6::R6Class] class. + #' + #' @param task ([Task])\cr + #' The task to be evaluated. + #' @param learner ([Learner])\cr + #' The learner to be evaluated. + #' @param resampling ([Resampling])\cr + #' The resampling strategy to be used. + #' @param param_values (`list()`)\cr + #' The parameter values to be used. + initialize = function(task, learner, resampling, param_values) { + # no assertions to avoid overhead + self$task = task + self$learner = learner + self$resampling = resampling + + super$initialize(id = "evaluate", label = "Evaluation") + } + ) +) diff --git a/R/ContextWorkhorse.R b/R/ContextWorkhorse.R deleted file mode 100644 index 6ec1e9cbd..000000000 --- a/R/ContextWorkhorse.R +++ /dev/null @@ -1,14 +0,0 @@ -#' @title Batch Tuning Context -#' -#' @description -#' A [CallbackResample] accesses and modifies data during the optimization via the `ContextBatchTuning`. -#' See the section on active bindings for a list of modifiable objects. -#' See [callback_batch_tuning()] for a list of stages that access `ContextBatchTuning`. 
-#' -#' @export -ContextWorkhorse = R6Class("ContextResample", - inherit = Context, - public = list( - env = NULL - ) -) diff --git a/R/resample.R b/R/resample.R index 6ecdae27b..fd241d936 100644 --- a/R/resample.R +++ b/R/resample.R @@ -67,7 +67,7 @@ resample = function( unmarshal = TRUE, callbacks = NULL ) { - callbacks = assert_workhorse_callbacks(as_callbacks(callbacks)) + callbacks = assert_callbacks(as_callbacks(callbacks)) assert_subset(clone, c("task", "learner", "resampling")) task = assert_task(as_task(task, clone = "task" %in% clone)) diff --git a/R/worker.R b/R/worker.R index 9a2fa938c..5ae271412 100644 --- a/R/worker.R +++ b/R/worker.R @@ -265,8 +265,9 @@ workhorse = function( unmarshal = TRUE, callbacks = NULL ) { - context = ContextWorkhorse$new("workhorse") - context$env = environment() + ctx = ContextEvaluation$new(task, learner, resampling, param_values) + + call_back("on_evaluation_begin", callbacks, ctx) if (!is.null(pb)) { pb(sprintf("%s|%s|i:%i", task$id, learner$id, iteration)) @@ -306,13 +307,13 @@ workhorse = function( lg$info("%s learner '%s' on task '%s' (iter %i/%i)", if (mode == "train") "Applying" else "Hotstarting", learner$id, task$id, iteration, resampling$iters) - sets = list( + ctx$sets = list( train = resampling$train_set(iteration), test = resampling$test_set(iteration) ) # train model - learner = learner$clone() + ctx$learner = learner = learner$clone() if (length(param_values)) { learner$param_set$values = list() learner$param_set$set_values(.values = param_values) @@ -321,12 +322,12 @@ workhorse = function( validate = get0("validate", learner) - test_set = if (identical(validate, "test")) sets$test + ctx$test_set = if (identical(validate, "test")) ctx$sets$test - call_back("on_workhorse_before_train", callbacks, context) + call_back("on_evaluation_before_train", callbacks, ctx) - train_result = learner_train(learner, task, sets[["train"]], test_set, mode = mode) - learner = train_result$learner + train_result = learner_train(learner, task, ctx$sets[["train"]], ctx$test_set, mode = mode) + ctx$learner = learner = train_result$learner # process the model so it can be used for prediction (e.g. 
marshal for callr prediction), but also # keep a copy of the model in current form in case this is the format that we want to send back to the main process @@ -336,25 +337,23 @@ workhorse = function( ) # predict for each set - predict_sets = learner$predict_sets + ctx$predict_sets = learner$predict_sets # creates the tasks and row_ids for all selected predict sets - pred_data = prediction_tasks_and_sets(task, train_result, validate, sets, predict_sets) + pred_data = prediction_tasks_and_sets(task, train_result, validate, ctx$sets, ctx$predict_sets) + + call_back("on_evaluation_before_predict", callbacks, ctx) pdatas = Map(function(set, row_ids, task) { lg$debug("Creating Prediction for predict set '%s'", set) - call_back("on_workhorse_before_predict", callbacks, context) - learner_predict(learner, task, row_ids) - }, set = predict_sets, row_ids = pred_data$sets, task = pred_data$tasks) + }, set = ctx$predict_sets, row_ids = pred_data$sets, task = pred_data$tasks) - if (!length(predict_sets)) { + if (!length(ctx$predict_sets)) { learner$state$predict_time = 0L } - pdatas = discard(pdatas, is.null) - - call_back("on_workhorse_before_result", callbacks, context) + ctx$pdatas = discard(pdatas, is.null) # set the model slot after prediction so it can be sent back to the main process process_model_after_predict( @@ -362,6 +361,13 @@ workhorse = function( unmarshal = unmarshal ) + call_back("on_evaluation_end", callbacks, ctx) + + if (!store_models) { + lg$debug("Erasing stored model for learner '%s'", learner$id) + learner$state$model = NULL + } + learner_state = set_class(learner$state, c("learner_state", "list")) list(learner_state = learner_state, prediction = pdatas, param_values = learner$param_set$values, learner_hash = learner_hash) @@ -438,13 +444,10 @@ process_model_before_predict = function(learner, store_models, is_sequential, un } process_model_after_predict = function(learner, store_models, is_sequential, unmarshal, model_copy) { - if (!store_models) { - lg$debug("Erasing stored model for learner '%s'", learner$id) - learner$state$model = NULL - } else if (!is.null(model_copy)) { + if (store_models && !is.null(model_copy)) { # we created a copy of the model to avoid additional marshaling cycles learner$model = model_copy - } else if (!is_sequential || !unmarshal) { + } else if (store_models && !is_sequential || !unmarshal) { # no copy was created, here we make sure that we return the model the way the user wants it learner$model = marshal_model(learner$model, inplace = TRUE) } diff --git a/man/CallbackWorkhorse.Rd b/man/CallbackEvaluation.Rd similarity index 56% rename from man/CallbackWorkhorse.Rd rename to man/CallbackEvaluation.Rd index c07b2c9cf..67dfd7124 100644 --- a/man/CallbackWorkhorse.Rd +++ b/man/CallbackEvaluation.Rd @@ -1,18 +1,39 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/CallbackWorkhorse.R -\name{CallbackWorkhorse} -\alias{CallbackWorkhorse} -\title{Create Workhorse Callback} +% Please edit documentation in R/CallbackEvaluation.R +\name{CallbackEvaluation} +\alias{CallbackEvaluation} +\title{Create Evaluation Callback} \description{ -Callbacks allow to customize the behavior of processes in mlr3. +Callbacks allow to customize the behavior of \code{resample()} and \code{benchmark()} in mlr3. } \section{Super class}{ -\code{\link[mlr3misc:Callback]{mlr3misc::Callback}} -> \code{CallbackWorkhorse} +\code{\link[mlr3misc:Callback]{mlr3misc::Callback}} -> \code{CallbackEvaluation} +} +\section{Public fields}{ +\if{html}{\out{
}} +\describe{ +\item{\code{on_evaluation_begin}}{(\verb{function()})\cr +Stage called at the beginning of an evaluation. +Called in \code{workhorse()} (internal).} + +\item{\code{on_evaluation_before_train}}{(\verb{function()})\cr +Stage called before training the learner. +Called in \code{workhorse()} (internal).} + +\item{\code{on_evaluation_before_predict}}{(\verb{function()})\cr +Stage called before predicting. +Called in \code{workhorse()} (internal).} + +\item{\code{on_evaluation_end}}{(\verb{function()})\cr +Stage called at the end of an evaluation. +Called in \code{workhorse()} (internal).} +} +\if{html}{\out{
}} } \section{Methods}{ \subsection{Public methods}{ \itemize{ -\item \href{#method-CallbackWorkhorse-clone}{\code{CallbackWorkhorse$clone()}} +\item \href{#method-CallbackEvaluation-clone}{\code{CallbackEvaluation$clone()}} } } \if{html}{\out{ @@ -27,12 +48,12 @@ Callbacks allow to customize the behavior of processes in mlr3. }} \if{html}{\out{
}} -\if{html}{\out{}} -\if{latex}{\out{\hypertarget{method-CallbackWorkhorse-clone}{}}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-CallbackEvaluation-clone}{}}} \subsection{Method \code{clone()}}{ The objects of this class are cloneable with this method. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{CallbackWorkhorse$clone(deep = FALSE)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{CallbackEvaluation$clone(deep = FALSE)}\if{html}{\out{
}} } \subsection{Arguments}{ diff --git a/man/ContextEvaluation.Rd b/man/ContextEvaluation.Rd new file mode 100644 index 000000000..6f50e39b2 --- /dev/null +++ b/man/ContextEvaluation.Rd @@ -0,0 +1,110 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/ContextEvaluation.R +\name{ContextEvaluation} +\alias{ContextEvaluation} +\title{Evaluation Context} +\description{ +A \link{CallbackEvaluation} accesses and modifies data during \code{\link[=resample]{resample()}} and \code{\link[=benchmark]{benchmark()}} via the \code{ContextEvaluation}. +See the section on fields for a list of modifiable objects. +See \code{\link[=callback_evaluation]{callback_evaluation()}} for a list of stages that access \code{ContextEvaluation}. +} +\section{Super class}{ +\code{\link[mlr3misc:Context]{mlr3misc::Context}} -> \code{ContextEvaluation} +} +\section{Public fields}{ +\if{html}{\out{
}}
+\describe{
+\item{\code{task}}{(\link{Task})\cr
+The task to be evaluated.
+The task is unchanged during the evaluation.}
+
+\item{\code{learner}}{(\link{Learner})\cr
+The learner to be evaluated.
+The learner contains the models after stage \code{on_evaluation_before_train}.}
+
+\item{\code{resampling}}{(\link{Resampling})\cr
+The resampling strategy to be used.
+The resampling is unchanged during the evaluation.}
+
+\item{\code{param_values}}{(\code{list()})\cr
+The parameter values to be used.
+It is usually only set while tuning.}
+
+\item{\code{sets}}{(\code{list()})\cr
+The train and test set.
+The sets are available on stage \code{on_evaluation_before_train}.}
+
+\item{\code{test_set}}{(\code{integer()})\cr
+Validation test set.
+The set is only available when using internal validation.}
+
+\item{\code{predict_sets}}{(\code{list()})\cr
+The prediction sets stored in \code{learner$predict_sets}.
+The sets are available on stage \code{on_evaluation_before_predict}.}
+
+\item{\code{pdatas}}{(List of \link{PredictionData})\cr
+The prediction data.
+The data is available on stage \code{on_evaluation_end}.}
+}
+\if{html}{\out{
}} +} +\section{Methods}{ +\subsection{Public methods}{ +\itemize{ +\item \href{#method-ContextEvaluation-new}{\code{ContextEvaluation$new()}} +\item \href{#method-ContextEvaluation-clone}{\code{ContextEvaluation$clone()}} +} +} +\if{html}{\out{ +
Inherited methods + +
+}} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ContextEvaluation-new}{}}} +\subsection{Method \code{new()}}{ +Creates a new instance of this \link[R6:R6Class]{R6} class. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ContextEvaluation$new(task, learner, resampling, param_values)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{task}}{(\link{Task})\cr +The task to be evaluated.} + +\item{\code{learner}}{(\link{Learner})\cr +The learner to be evaluated.} + +\item{\code{resampling}}{(\link{Resampling})\cr +The resampling strategy to be used.} + +\item{\code{param_values}}{(\code{list()})\cr +The parameter values to be used.} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ContextEvaluation-clone}{}}} +\subsection{Method \code{clone()}}{ +The objects of this class are cloneable with this method. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ContextEvaluation$clone(deep = FALSE)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{deep}}{Whether to make a deep clone.} +} +\if{html}{\out{
}} +} +} +} diff --git a/man/ContextWorkhorse.Rd b/man/ContextWorkhorse.Rd deleted file mode 100644 index 5b0503263..000000000 --- a/man/ContextWorkhorse.Rd +++ /dev/null @@ -1,46 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/ContextWorkhorse.R -\name{ContextWorkhorse} -\alias{ContextWorkhorse} -\title{Batch Tuning Context} -\description{ -A \link{CallbackResample} accesses and modifies data during the optimization via the \code{ContextBatchTuning}. -See the section on active bindings for a list of modifiable objects. -See \code{\link[=callback_batch_tuning]{callback_batch_tuning()}} for a list of stages that access \code{ContextBatchTuning}. -} -\section{Super class}{ -\code{\link[mlr3misc:Context]{mlr3misc::Context}} -> \code{ContextResample} -} -\section{Methods}{ -\subsection{Public methods}{ -\itemize{ -\item \href{#method-ContextResample-clone}{\code{ContextWorkhorse$clone()}} -} -} -\if{html}{\out{ -
Inherited methods - -
-}} -\if{html}{\out{
}} -\if{html}{\out{}} -\if{latex}{\out{\hypertarget{method-ContextResample-clone}{}}} -\subsection{Method \code{clone()}}{ -The objects of this class are cloneable with this method. -\subsection{Usage}{ -\if{html}{\out{
}}\preformatted{ContextWorkhorse$clone(deep = FALSE)}\if{html}{\out{
}} -} - -\subsection{Arguments}{ -\if{html}{\out{
}} -\describe{ -\item{\code{deep}}{Whether to make a deep clone.} -} -\if{html}{\out{
}} -} -} -} diff --git a/man/assert_evaluation_callback.Rd b/man/assert_evaluation_callback.Rd new file mode 100644 index 000000000..5f54d4fd5 --- /dev/null +++ b/man/assert_evaluation_callback.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/CallbackEvaluation.R +\name{assert_evaluation_callback} +\alias{assert_evaluation_callback} +\alias{assert_evaluation_callbacks} +\title{Assertions for Callbacks} +\usage{ +assert_evaluation_callback(callback, null_ok = FALSE) + +assert_evaluation_callbacks(callbacks) +} +\arguments{ +\item{callback}{(\link{CallbackEvaluation}).} + +\item{null_ok}{(\code{logical(1)})\cr +If \code{TRUE}, \code{NULL} is allowed.} + +\item{callbacks}{(list of \link{CallbackEvaluation}).} +} +\value{ +[CallbackEvaluation | List of \link{CallbackEvaluation}s. +} +\description{ +Assertions for \link{CallbackEvaluation} class. +} diff --git a/man/assert_workhorse_callback.Rd b/man/assert_workhorse_callback.Rd deleted file mode 100644 index a5add6c9b..000000000 --- a/man/assert_workhorse_callback.Rd +++ /dev/null @@ -1,25 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/CallbackWorkhorse.R -\name{assert_workhorse_callback} -\alias{assert_workhorse_callback} -\alias{assert_workhorse_callbacks} -\title{Assertions for Callbacks} -\usage{ -assert_workhorse_callback(callback, null_ok = FALSE) - -assert_workhorse_callbacks(callbacks) -} -\arguments{ -\item{callback}{(\link{CallbackWorkhorse}).} - -\item{null_ok}{(\code{logical(1)})\cr -If \code{TRUE}, \code{NULL} is allowed.} - -\item{callbacks}{(list of \link{CallbackWorkhorse}).} -} -\value{ -[CallbackWorkhorse | List of \link{CallbackWorkhorse}s. -} -\description{ -Assertions for \link{CallbackWorkhorse} class. -} diff --git a/man/callback_workhorse.Rd b/man/callback_evaluation.Rd similarity index 56% rename from man/callback_workhorse.Rd rename to man/callback_evaluation.Rd index e258d0ab3..04db4d220 100644 --- a/man/callback_workhorse.Rd +++ b/man/callback_evaluation.Rd @@ -1,16 +1,17 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/CallbackWorkhorse.R -\name{callback_workhorse} -\alias{callback_workhorse} +% Please edit documentation in R/CallbackEvaluation.R +\name{callback_evaluation} +\alias{callback_evaluation} \title{Create Workhorse Callback} \usage{ -callback_workhorse( +callback_evaluation( id, label = NA_character_, man = NA_character_, - on_workhorse_before_train = NULL, - on_workhorse_before_predict = NULL, - on_workhorse_before_result = NULL + on_evaluation_begin = NULL, + on_evaluation_before_train = NULL, + on_evaluation_before_predict = NULL, + on_evaluation_end = NULL ) } \arguments{ @@ -25,17 +26,18 @@ String in the format \verb{[pkg]::[topic]} pointing to a manual page for this ob The referenced help package can be opened via method \verb{$help()}.} } \description{ -Function to create a \link{CallbackWorkhorse}. +Function to create a \link{CallbackEvaluation}. -\if{html}{\out{
}}\preformatted{Start Workhorse - - on_workhorse_before_train - - on_workhorse_before_predict - - on_workhorse_before_result -End Tuning +\if{html}{\out{
}}\preformatted{Start Evaluation on Worker + - on_evaluation_begin + - on_evaluation_before_train + - on_evaluation_before_predict + - on_evaluation_end +End Evaluation on Worker }\if{html}{\out{
}} } \details{ When implementing a callback, each function must have two arguments named \code{callback} and \code{context}. A callback can write data to the state (\verb{$state}), e.g. settings that affect the callback itself. -Workhorse callbacks access \link{ContextWorkhorse}. +Evaluation callbacks access \link{ContextEvaluation}. } diff --git a/tests/testthat/test_CallbackEvaluation.R b/tests/testthat/test_CallbackEvaluation.R new file mode 100644 index 000000000..dccd455de --- /dev/null +++ b/tests/testthat/test_CallbackEvaluation.R @@ -0,0 +1,116 @@ +test_that("on_evaluation_begin works", { + task = tsk("pima") + learner = lrn("classif.rpart") + resampling = rsmp("cv", folds = 3) + + callback = callback_evaluation("test", + + on_evaluation_begin = function(callback, context) { + expect_task(context$task) + expect_learner(context$learner) + expect_resampling(context$resampling) + expect_null(context$param_values) + expect_null(context$sets) + expect_null(context$test_set) + expect_null(context$predict_sets) + expect_null(context$pdatas) + } + ) + + resample(task, learner, resampling, callbacks = callback) + +}) + +test_that("on_evaluation_before_train works", { + task = tsk("pima") + learner = lrn("classif.rpart") + resampling = rsmp("cv", folds = 3) + + callback = callback_evaluation("test", + + on_evaluation_before_train = function(callback, context) { + expect_task(context$task) + expect_learner(context$learner) + expect_resampling(context$resampling) + expect_null(context$param_values) + expect_list(context$sets, len = 2) + expect_equal(names(context$sets), c("train", "test")) + expect_integer(context$sets$train) + expect_integer(context$sets$test) + expect_null(context$test_set) + expect_null(context$predict_sets) + expect_null(context$pdatas) + } + ) + + resample(task, learner, resampling, callbacks = callback) + +}) + +test_that("on_evaluation_before_predict works", { + task = tsk("pima") + learner = lrn("classif.rpart") + resampling = rsmp("cv", folds = 3) + + callback = callback_evaluation("test", + + on_evaluation_before_predict = function(callback, context) { + expect_task(context$task) + expect_learner(context$learner) + expect_resampling(context$resampling) + expect_null(context$param_values) + expect_list(context$sets, len = 2) + expect_equal(names(context$sets), c("train", "test")) + expect_integer(context$sets$train) + expect_integer(context$sets$test) + expect_class(context$learner$model, "rpart") + expect_null(context$test_set) + expect_equal(context$predict_sets, "test") + expect_null(context$pdatas) + } + ) + + resample(task, learner, resampling, callbacks = callback) +}) + +test_that("on_evaluation_end works", { + task = tsk("pima") + learner = lrn("classif.rpart") + resampling = rsmp("cv", folds = 3) + + callback = callback_evaluation("test", + + on_evaluation_end = function(callback, context) { + expect_task(context$task) + expect_learner(context$learner) + expect_resampling(context$resampling) + expect_null(context$param_values) + expect_list(context$sets, len = 2) + expect_equal(names(context$sets), c("train", "test")) + expect_integer(context$sets$train) + expect_integer(context$sets$test) + expect_class(context$learner$model, "rpart") + expect_null(context$test_set) + expect_equal(context$predict_sets, "test") + expect_class(context$pdatas$test, "PredictionData") + } + ) + + resample(task, learner, resampling, callbacks = callback) +}) + +test_that("writing to learner$state works", { + task = tsk("pima") + learner = lrn("classif.rpart") + resampling = 
rsmp("cv", folds = 3) + + callback = callback_evaluation("test", + + on_evaluation_end = function(callback, context) { + context$learner$state$test = 1 + } + ) + + rr = resample(task, learner, resampling, callbacks = callback) + expect_equal(rr$learners[[1]]$state$test, 1) +}) From 465d7c91b6118832d6a6e99427a681238a34cffe Mon Sep 17 00:00:00 2001 From: be-marc Date: Wed, 27 Nov 2024 12:17:39 +0100 Subject: [PATCH 06/54] ... --- R/mlr_callbacks.R | 38 +++++++++++++++++++++++++++++ R/zzz.R | 4 +++ tests/testthat/test_mlr_callbacks.R | 13 ++++++++++ 3 files changed, 55 insertions(+) create mode 100644 R/mlr_callbacks.R create mode 100644 tests/testthat/test_mlr_callbacks.R diff --git a/R/mlr_callbacks.R b/R/mlr_callbacks.R new file mode 100644 index 000000000..c9ae5f504 --- /dev/null +++ b/R/mlr_callbacks.R @@ -0,0 +1,38 @@ +#' @title Score Measures Callback +#' +#' @include CallbackEvaluation.R +#' @name mlr3.score_measures +#' +#' @description +#' This `CallbackEvaluation` scores measures directly on the worker. +#' This way measures that require a model can be scores without saving the model. +#' +#' @examples +#' clbk("mlr3.score_measures", measures = msr("classif.ce")) +#' +#' task = tsk("pima") +#' learner = lrn("classif.rpart") +#' resampling = rsmp("cv", folds = 3) +#' +#' callback = clbk("mlr3.score_measures", measures = msr("selected_features")) +#' +#' rr = resample(task, learner, resampling = resampling, callbacks = callback) +#' +#' rr$learners[[1]]$state$selected_features +NULL + +load_callback_score_measures = function() { + callback_evaluation("mlr3.score_measures", + label = "Score Measures Callback", + man = "mlr3::mlr3.score_measures", + + on_evaluation_end = function(callback, context) { + measures = as_measures(callback$state$measures) + + # Score measures on the test set + pred = as_prediction(context$pdatas$test) + res = pred$score(measures, context$task, context$learner) + context$learner$state$selected_features = res + } + ) +} diff --git a/R/zzz.R b/R/zzz.R index b34c219ba..d671c6746 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -74,6 +74,10 @@ dummy_import = function() { # nocov start backports::import(pkgname) + # callbacks + x = utils::getFromNamespace("mlr_callbacks", ns = "mlr3misc") + x$add("mlr3.score_measures", load_callback_score_measures) + # setup logger lg = lgr::get_logger(pkgname) assign("lg", lg, envir = parent.env(environment())) diff --git a/tests/testthat/test_mlr_callbacks.R b/tests/testthat/test_mlr_callbacks.R new file mode 100644 index 000000000..c2669188b --- /dev/null +++ b/tests/testthat/test_mlr_callbacks.R @@ -0,0 +1,13 @@ +test_that("score_measure works", { + task = tsk("pima") + learner = lrn("classif.rpart") + resampling = rsmp("cv", folds = 3) + + callback = clbk("mlr3.score_measures", measures = msr("selected_features")) + + rr = resample(task, learner, resampling = resampling, callbacks = callback) + + walk(rr$learners, function(learner) { + expect_number(learner$state$selected_features) + }) +}) From a2ad28fb7b1f73b3d41906448f71433821b54a13 Mon Sep 17 00:00:00 2001 From: be-marc Date: Wed, 27 Nov 2024 12:23:29 +0100 Subject: [PATCH 07/54] ... 
--- .gitignore | 4 +++- DESCRIPTION | 1 + man/mlr3.score_measures.Rd | 22 ++++++++++++++++++++++ 3 files changed, 26 insertions(+), 1 deletion(-) create mode 100644 man/mlr3.score_measures.Rd diff --git a/.gitignore b/.gitignore index d48cbc176..093a40e42 100644 --- a/.gitignore +++ b/.gitignore @@ -24,7 +24,8 @@ .LSOverride # Icon must end with two \r -Icon +Icon + # Thumbnails ._* @@ -180,3 +181,4 @@ revdep/ # misc Meta/ +Rplots.pdf diff --git a/DESCRIPTION b/DESCRIPTION index 59cf400ef..ed4ce7203 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -191,6 +191,7 @@ Collate: 'helper_print.R' 'install_pkgs.R' 'marshal.R' + 'mlr_callbacks.R' 'mlr_sugar.R' 'mlr_test_helpers.R' 'partition.R' diff --git a/man/mlr3.score_measures.Rd b/man/mlr3.score_measures.Rd new file mode 100644 index 000000000..4cf9fa49f --- /dev/null +++ b/man/mlr3.score_measures.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/mlr_callbacks.R +\name{mlr3.score_measures} +\alias{mlr3.score_measures} +\title{Score Measures Callback} +\description{ +This \code{CallbackEvaluation} scores measures directly on the worker. +This way measures that require a model can be scores without saving the model. +} +\examples{ +clbk("mlr3.score_measures", measures = msr("classif.ce")) + +task = tsk("pima") +learner = lrn("classif.rpart") +resampling = rsmp("cv", folds = 3) + +callback = clbk("mlr3.score_measures", measures = msr("selected_features")) + +rr = resample(task, learner, resampling = resampling, callbacks = callback) + +rr$learners[[1]]$state$selected_features +} From 01e6a4b6d73ff6000d55ed5c2c4ffb47ebaaa0dc Mon Sep 17 00:00:00 2001 From: be-marc Date: Wed, 27 Nov 2024 14:03:34 +0100 Subject: [PATCH 08/54] ... --- R/ContextEvaluation.R | 5 + R/ResampleResult.R | 10 +- R/ResultData.R | 18 +++- R/resample.R | 3 +- R/worker.R | 7 +- man/ContextEvaluation.Rd | 4 + man/ResampleResult.Rd | 3 + man/ResultData.Rd | 22 +++++ pkgdown/_pkgdown.yml | 4 + tests/testthat/test_CallbackEvaluation.R | 120 ++++++++++++++--------- tests/testthat/test_mlr_callbacks.R | 2 + 11 files changed, 141 insertions(+), 57 deletions(-) diff --git a/R/ContextEvaluation.R b/R/ContextEvaluation.R index 0aa7cc73e..9c366aa5f 100644 --- a/R/ContextEvaluation.R +++ b/R/ContextEvaluation.R @@ -50,6 +50,11 @@ ContextEvaluation = R6Class("ContextEvaluation", #' The data is available on stage `on_evaluation_end`. pdatas = NULL, + #' @field data_extra (list())\cr + #' Data saved in the [ResampleResult] or [BenchmarkResult]. + #' Use this field to save results. + data_extra = NULL, + #' @description #' Creates a new instance of this [R6][R6::R6Class] class. #' diff --git a/R/ResampleResult.R b/R/ResampleResult.R index 83b79ada3..c31343ca7 100644 --- a/R/ResampleResult.R +++ b/R/ResampleResult.R @@ -335,6 +335,12 @@ ResampleResult = R6Class("ResampleResult", private$.data$learners(private$.view)$learner }, + #' @field data_extra (list())\cr + #' Additional data stored in the [ResampleResult]. + data_extra = function() { + private$.data$data_extra(private$.view) + }, + #' @field warnings ([data.table::data.table()])\cr #' A table with all warning messages. #' Column names are `"iteration"` and `"msg"`. 
@@ -370,10 +376,10 @@ ResampleResult = R6Class("ResampleResult", ) #' @export -as.data.table.ResampleResult = function(x, ..., predict_sets = "test") { # nolint +as.data.table.ResampleResult = function(x, ..., predict_sets = "test", data_extra = FALSE) { # nolint private = get_private(x) tab = private$.data$as_data_table(view = private$.view, predict_sets = predict_sets) - tab[, c("task", "learner", "resampling", "iteration", "prediction"), with = FALSE] + tab[, c("task", "learner", "resampling", "iteration", "prediction", if (data_extra) "data_extra"), with = FALSE] } # #' @export diff --git a/R/ResultData.R b/R/ResultData.R index a2c74feb0..c6df7245d 100644 --- a/R/ResultData.R +++ b/R/ResultData.R @@ -18,12 +18,12 @@ #' print(ResultData$new()$data) ResultData = R6Class("ResultData", public = list( + #' @field data (`list()`)\cr #' List of [data.table::data.table()], arranged in a star schema. #' Do not operate directly on this list. data = NULL, - #' @description #' Creates a new instance of this [R6][R6::R6Class] class. #' An alternative construction method is provided by [as_result_data()]. @@ -40,12 +40,12 @@ ResultData = R6Class("ResultData", self$data = star_init() } else { assert_names(names(data), - permutation.of = c("task", "learner", "learner_state", "resampling", "iteration", "param_values", "prediction", "uhash", "learner_hash")) + permutation.of = c("task", "learner", "learner_state", "resampling", "iteration", "param_values", "prediction", "uhash", "learner_hash", "data_extra")) if (nrow(data) == 0L) { self$data = star_init() } else { - setcolorder(data, c("uhash", "iteration", "learner_state", "prediction", "task", "learner", "resampling", "param_values", "learner_hash")) + setcolorder(data, c("uhash", "iteration", "learner_state", "prediction", "data_extra", "task", "learner", "resampling", "param_values", "learner_hash")) uhashes = data.table(uhash = unique(data$uhash)) setkeyv(data, c("uhash", "iteration")) @@ -189,6 +189,15 @@ ResultData = R6Class("ResultData", do.call(c, self$predictions(view = view, predict_sets = predict_sets)) }, + #' @description + #' Returns additional data stored. + #' + #' @return `list()`. + data_extra = function(view = NULL) { + .__ii__ = private$get_view_index(view) + self$data$fact[.__ii__, "data_extra", with = FALSE][[1L]] + }, + #' @description #' Combines multiple [ResultData] objects, modifying `self` in-place. 
#' @@ -315,7 +324,7 @@ ResultData = R6Class("ResultData", } cns = c("uhash", "task", "task_hash", "learner", "learner_hash", "learner_param_vals", "resampling", - "resampling_hash", "iteration", "prediction") + "resampling_hash", "iteration", "prediction", "data_extra") merge(self$data$uhashes, tab[, cns, with = FALSE], by = "uhash", sort = FALSE) }, @@ -375,6 +384,7 @@ star_init = function() { iteration = integer(), learner_state = list(), prediction = list(), + data_extra = list(), learner_hash = character(), task_hash = character(), diff --git a/R/resample.R b/R/resample.R index fd241d936..4cbcc2d64 100644 --- a/R/resample.R +++ b/R/resample.R @@ -140,7 +140,8 @@ resample = function( prediction = map(res, "prediction"), uhash = UUIDgenerate(), param_values = map(res, "param_values"), - learner_hash = map_chr(res, "learner_hash") + learner_hash = map_chr(res, "learner_hash"), + data_extra = map(res, "data_extra") ) result_data = ResultData$new(data, store_backends = store_backends) diff --git a/R/worker.R b/R/worker.R index 5ae271412..aced1724b 100644 --- a/R/worker.R +++ b/R/worker.R @@ -370,7 +370,12 @@ workhorse = function( learner_state = set_class(learner$state, c("learner_state", "list")) - list(learner_state = learner_state, prediction = pdatas, param_values = learner$param_set$values, learner_hash = learner_hash) + list( + learner_state = learner_state, + prediction = pdatas, + param_values = learner$param_set$values, + learner_hash = learner_hash, + data_extra = ctx$data_extra) } # creates the tasks and row ids for the selected predict sets diff --git a/man/ContextEvaluation.Rd b/man/ContextEvaluation.Rd index 6f50e39b2..78bb4fac8 100644 --- a/man/ContextEvaluation.Rd +++ b/man/ContextEvaluation.Rd @@ -45,6 +45,10 @@ The sets are available on stage \code{on_evaluation_before_predict}.} \item{\code{pdatas}}{(List of \link{PredictionData})\cr The prediction data. The data is available on stage \code{on_evaluation_end}.} + +\item{\code{data_extra}}{(list())\cr +Data saved in the \link{ResampleResult} or \link{BenchmarkResult}. +Use this field to save results.} } \if{html}{\out{
}} } diff --git a/man/ResampleResult.Rd b/man/ResampleResult.Rd index 8ea3c7d18..39b584d6f 100644 --- a/man/ResampleResult.Rd +++ b/man/ResampleResult.Rd @@ -82,6 +82,9 @@ Instantiated \link{Resampling} object which stores the splits into training and \item{\code{learners}}{(list of \link{Learner})\cr List of trained learners, sorted by resampling iteration.} +\item{\code{data_extra}}{(list())\cr +Additional data stored in the \link{ResampleResult}.} + \item{\code{warnings}}{(\code{\link[data.table:data.table]{data.table::data.table()}})\cr A table with all warning messages. Column names are \code{"iteration"} and \code{"msg"}. diff --git a/man/ResultData.Rd b/man/ResultData.Rd index 1c75c8998..2bb02bf3e 100644 --- a/man/ResultData.Rd +++ b/man/ResultData.Rd @@ -48,6 +48,7 @@ Returns \code{NULL} if the \link{ResultData} is empty.} \item \href{#method-ResultData-resamplings}{\code{ResultData$resamplings()}} \item \href{#method-ResultData-predictions}{\code{ResultData$predictions()}} \item \href{#method-ResultData-prediction}{\code{ResultData$prediction()}} +\item \href{#method-ResultData-data_extra}{\code{ResultData$data_extra()}} \item \href{#method-ResultData-combine}{\code{ResultData$combine()}} \item \href{#method-ResultData-sweep}{\code{ResultData$sweep()}} \item \href{#method-ResultData-marshal}{\code{ResultData$marshal()}} @@ -299,6 +300,27 @@ Default is \code{"test"}.} } } \if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ResultData-data_extra}{}}} +\subsection{Method \code{data_extra()}}{ +Returns additional data stored. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{ResultData$data_extra(view = NULL)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{view}}{\code{character(1)}\cr +Single \code{uhash} to restrict the results to.} +} +\if{html}{\out{
}} +} +\subsection{Returns}{ +\code{list()}. +} +} +\if{html}{\out{
}} \if{html}{\out{}} \if{latex}{\out{\hypertarget{method-ResultData-combine}{}}} \subsection{Method \code{combine()}}{ diff --git a/pkgdown/_pkgdown.yml b/pkgdown/_pkgdown.yml index 57c90b87b..dc9efd23e 100644 --- a/pkgdown/_pkgdown.yml +++ b/pkgdown/_pkgdown.yml @@ -111,6 +111,10 @@ reference: - mlr_sugar - mlr_reflections - set_threads + - title: Callbacks + contents: + - CallbackEvaluation + - ContextEvaluation - title: Internal Objects and Functions contents: - marshaling diff --git a/tests/testthat/test_CallbackEvaluation.R b/tests/testthat/test_CallbackEvaluation.R index dccd455de..2cec38903 100644 --- a/tests/testthat/test_CallbackEvaluation.R +++ b/tests/testthat/test_CallbackEvaluation.R @@ -6,19 +6,19 @@ test_that("on_evaluation_begin works", { callback = callback_evaluation("test", on_evaluation_begin = function(callback, context) { - expect_task(context$task) - expect_learner(context$learner) - expect_resampling(context$resampling) - expect_null(context$param_values) - expect_null(context$sets) - expect_null(context$test_set) - expect_null(context$predict_sets) - expect_null(context$pdatas) + # expect_* does not work + assert_task(context$task) + assert_learner(context$learner) + assert_resampling(context$resampling) + assert_null(context$param_values) + assert_null(context$sets) + assert_null(context$test_set) + assert_null(context$predict_sets) + assert_null(context$pdatas) } ) - resample(task, learner, resampling, callbacks = callback) - + expect_resample_result(resample(task, learner, resampling, callbacks = callback)) }) test_that("on_evaluation_before_train works", { @@ -29,21 +29,21 @@ test_that("on_evaluation_before_train works", { callback = callback_evaluation("test", on_evaluation_before_train = function(callback, context) { - expect_task(context$task) - expect_learner(context$learner) - expect_resampling(context$resampling) - expect_null(context$param_values) - expect_list(context$sets, len = 2) - expect_equal(names(context$sets), c("train", "test")) - expect_integer(context$sets$train) - expect_integer(context$sets$test) - expect_null(context$test_set) - expect_null(context$predict_sets) - expect_null(context$pdatas) + assert_task(context$task) + assert_learner(context$learner) + assert_resampling(context$resampling) + assert_null(context$param_values) + assert_list(context$sets, len = 2) + assert_names(names(context$sets), identical.to = c("train", "test")) + assert_integer(context$sets$train) + assert_integer(context$sets$test) + assert_null(context$test_set) + assert_null(context$predict_sets) + assert_null(context$pdatas) } ) - resample(task, learner, resampling, callbacks = callback) + expect_resample_result(resample(task, learner, resampling, callbacks = callback)) }) @@ -55,22 +55,22 @@ test_that("on_evaluation_before_predict works", { callback = callback_evaluation("test", on_evaluation_before_predict = function(callback, context) { - expect_task(context$task) - expect_learner(context$learner) - expect_resampling(context$resampling) - expect_null(context$param_values) - expect_list(context$sets, len = 2) - expect_equal(names(context$sets), c("train", "test")) - expect_integer(context$sets$train) - expect_integer(context$sets$test) - expect_class(context$learner$model, "rpart") - expect_null(context$test_set) - expect_equal(context$predict_sets, "test") - expect_null(context$pdatas) + assert_task(context$task) + assert_learner(context$learner) + assert_resampling(context$resampling) + assert_null(context$param_values) + assert_list(context$sets, len = 2) + 
assert_names(names(context$sets), identical.to = c("train", "test")) + assert_integer(context$sets$train) + assert_integer(context$sets$test) + assert_class(context$learner$model, "rpart") + assert_null(context$test_set) + assert_true(context$predict_sets == "test") + assert_null(context$pdatas) } ) - resample(task, learner, resampling, callbacks = callback) + expect_resample_result(resample(task, learner, resampling, callbacks = callback)) }) test_that("on_evaluation_end works", { @@ -81,22 +81,22 @@ test_that("on_evaluation_end works", { callback = callback_evaluation("test", on_evaluation_end = function(callback, context) { - expect_task(context$task) - expect_learner(context$learner) - expect_resampling(context$resampling) - expect_null(context$param_values) - expect_list(context$sets, len = 2) - expect_equal(names(context$sets), c("train", "test")) - expect_integer(context$sets$train) - expect_integer(context$sets$test) - expect_class(context$learner$model, "rpart") - expect_null(context$test_set) - expect_equal(context$predict_sets, "test") - expect_class(context$pdatas$test, "PredictionData") + assert_task(context$task) + assert_learner(context$learner) + assert_resampling(context$resampling) + assert_null(context$param_values) + assert_list(context$sets, len = 2) + assert_names(names(context$sets), identical.to = c("train", "test")) + assert_integer(context$sets$train) + assert_integer(context$sets$test) + assert_class(context$learner$model, "rpart") + assert_null(context$test_set) + assert_true(context$predict_sets == "test") + assert_class(context$pdatas$test, "PredictionData") } ) - resample(task, learner, resampling, callbacks = callback) + expect_resample_result(resample(task, learner, resampling, callbacks = callback)) }) test_that("writing to learner$state works", { @@ -112,5 +112,27 @@ test_that("writing to learner$state works", { ) rr = resample(task, learner, resampling, callbacks = callback) - expect_equal(rr$learners[[1]]$state$test, 1) + + walk(rr$learners, function(learner) { + expect_equal(learner$state$test, 1) + }) +}) + +test_that("writing to data_extra works", { + task = tsk("pima") + learner = lrn("classif.rpart") + resampling = rsmp("cv", folds = 3) + + callback = callback_evaluation("test", + + on_evaluation_end = function(callback, context) { + context$data_extra$test = 1 + } + ) + + rr = resample(task, learner, resampling, callbacks = callback) + + walk(rr$data_extra, function(x) { + expect_equal(x$test, 1) + }) }) diff --git a/tests/testthat/test_mlr_callbacks.R b/tests/testthat/test_mlr_callbacks.R index c2669188b..7beaa5de7 100644 --- a/tests/testthat/test_mlr_callbacks.R +++ b/tests/testthat/test_mlr_callbacks.R @@ -10,4 +10,6 @@ test_that("score_measure works", { walk(rr$learners, function(learner) { expect_number(learner$state$selected_features) }) + + expect_names(names(as.data.table(rr, data_extra = TRUE)), must.include = "data_extra") }) From ae85cc07840cf615cc22742419bdd7d25bee01ed Mon Sep 17 00:00:00 2001 From: be-marc Date: Wed, 27 Nov 2024 17:30:07 +0100 Subject: [PATCH 09/54] ... 
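This commit threads the `callbacks` argument through `benchmark()` as well. A hedged sketch of the intended call pattern, reusing the `mlr3.score_measures` callback registered earlier in the series (task, learners, and measure are only examples; note that the drafted callback stores its scores under `selected_features` regardless of the measure passed):

library(mlr3)

design = benchmark_grid(
  tasks = tsk("pima"),
  learners = lrns(c("classif.rpart", "classif.featureless")),
  resamplings = rsmp("cv", folds = 3)
)

callback = mlr3misc::clbk("mlr3.score_measures", measures = msr("classif.ce"))
bmr = benchmark(design, callbacks = callback)

# scores computed on the workers end up in the stored learner states
bmr$resample_result(1)$learners[[1]]$state$selected_features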
--- R/benchmark.R | 9 ++++++--- R/resample.R | 4 ++-- man-roxygen/param_callbacks.R | 3 +++ man/benchmark.Rd | 7 ++++++- man/resample.Rd | 4 ++++ 5 files changed, 21 insertions(+), 6 deletions(-) create mode 100644 man-roxygen/param_callbacks.R diff --git a/R/benchmark.R b/R/benchmark.R index 5f26c38a6..07e862d65 100644 --- a/R/benchmark.R +++ b/R/benchmark.R @@ -18,6 +18,7 @@ #' @template param_allow_hotstart #' @template param_clone #' @template param_unmarshal +#' @template param_callbacks #' #' @return [BenchmarkResult]. #' @@ -81,7 +82,7 @@ #' ## Get the training set of the 2nd iteration of the featureless learner on penguins #' rr = bmr$aggregate()[learner_id == "classif.featureless"]$resample_result[[1]] #' rr$resampling$train_set(2) -benchmark = function(design, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), unmarshal = TRUE) { +benchmark = function(design, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), unmarshal = TRUE, callbacks = NULL) { assert_subset(clone, c("task", "learner", "resampling")) assert_data_frame(design, min.rows = 1L) assert_names(names(design), must.include = c("task", "learner", "resampling")) @@ -96,6 +97,7 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps } assert_flag(store_models) assert_flag(store_backends) + callbacks = assert_callbacks(as_callbacks(callbacks)) # check for multiple task types task_types = unique(map_chr(design$task, "task_type")) @@ -187,14 +189,15 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps res = future_map(n, workhorse, task = grid$task, learner = grid$learner, resampling = grid$resampling, iteration = grid$iteration, param_values = grid$param_values, mode = grid$mode, - MoreArgs = list(store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal) + MoreArgs = list(store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal, callbacks = callbacks) ) grid = insert_named(grid, list( learner_state = map(res, "learner_state"), prediction = map(res, "prediction"), param_values = map(res, "param_values"), - learner_hash = map_chr(res, "learner_hash") + learner_hash = map_chr(res, "learner_hash"), + data_extra = map(res, "data_extra") )) lg$info("Finished benchmark") diff --git a/R/resample.R b/R/resample.R index 4cbcc2d64..da368373a 100644 --- a/R/resample.R +++ b/R/resample.R @@ -15,6 +15,7 @@ #' @template param_allow_hotstart #' @template param_clone #' @template param_unmarshal +#' @template param_callbacks #' @return [ResampleResult]. 
#' #' @template section_predict_sets @@ -67,8 +68,6 @@ resample = function( unmarshal = TRUE, callbacks = NULL ) { - callbacks = assert_callbacks(as_callbacks(callbacks)) - assert_subset(clone, c("task", "learner", "resampling")) task = assert_task(as_task(task, clone = "task" %in% clone)) learner = assert_learner(as_learner(learner, clone = "learner" %in% clone, discard_state = TRUE)) @@ -78,6 +77,7 @@ resample = function( # this does not check the internal validation task as it might not be set yet assert_learnable(task, learner) assert_flag(unmarshal) + callbacks = assert_callbacks(as_callbacks(callbacks)) set_encapsulation(list(learner), encapsulate) if (!resampling$is_instantiated) { diff --git a/man-roxygen/param_callbacks.R b/man-roxygen/param_callbacks.R new file mode 100644 index 000000000..cdb286953 --- /dev/null +++ b/man-roxygen/param_callbacks.R @@ -0,0 +1,3 @@ +#' @param callbacks (List of [mlr3misc::Callback])\cr +#' Callbacks to be executed during the resampling process. +#' See [CallbackEvaluation] and [ContextEvaluation] for details. diff --git a/man/benchmark.Rd b/man/benchmark.Rd index 9cfc995f7..9f53cecfd 100644 --- a/man/benchmark.Rd +++ b/man/benchmark.Rd @@ -11,7 +11,8 @@ benchmark( encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), - unmarshal = TRUE + unmarshal = TRUE, + callbacks = NULL ) } \arguments{ @@ -63,6 +64,10 @@ Per default, all input objects are cloned.} Whether to unmarshal learners that were marshaled during the execution. If \code{TRUE} all models are stored in unmarshaled form. If \code{FALSE}, all learners (that need marshaling) are stored in marshaled form.} + +\item{callbacks}{(List of \link[mlr3misc:Callback]{mlr3misc::Callback})\cr +Callbacks to be executed during the resampling process. +See \link{CallbackEvaluation} and \link{ContextEvaluation} for details.} } \value{ \link{BenchmarkResult}. diff --git a/man/resample.Rd b/man/resample.Rd index 00378067e..41340d299 100644 --- a/man/resample.Rd +++ b/man/resample.Rd @@ -65,6 +65,10 @@ Per default, all input objects are cloned.} Whether to unmarshal learners that were marshaled during the execution. If \code{TRUE} all models are stored in unmarshaled form. If \code{FALSE}, all learners (that need marshaling) are stored in marshaled form.} + +\item{callbacks}{(List of \link[mlr3misc:Callback]{mlr3misc::Callback})\cr +Callbacks to be executed during the resampling process. +See \link{CallbackEvaluation} and \link{ContextEvaluation} for details.} } \value{ \link{ResampleResult}. From 8dd2ba5eb725a26e55e74531c8babe1694843d0f Mon Sep 17 00:00:00 2001 From: be-marc Date: Thu, 28 Nov 2024 14:24:40 +0100 Subject: [PATCH 10/54] ... 
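Besides exporting `assert_scorable()`, this commit lets manually constructed results carry extra data. A sketch of `as_result_data()` with the new `data_extra` argument, one list element per iteration; the holdout setup and the `note` field are only illustrative:

library(mlr3)

task = tsk("pima")
learner = lrn("classif.rpart")
resampling = rsmp("holdout")
resampling$instantiate(task)

learner$train(task, row_ids = resampling$train_set(1))
prediction = learner$predict(task, row_ids = resampling$test_set(1))

rdata = as_result_data(
  task = task,
  learners = list(learner),
  resampling = resampling,
  iterations = 1L,
  predictions = list(list(test = prediction)),
  data_extra = list(list(note = "manually constructed iteration"))
)
ResampleResult$new(rdata)$data_extra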
--- NAMESPACE | 1 + R/as_result_data.R | 7 ++++--- R/resample.R | 8 ++++++++ man/mlr_assertions.Rd | 9 +++++++++ 4 files changed, 22 insertions(+), 3 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 20e3e3f8c..aa85fe9fa 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -215,6 +215,7 @@ export(assert_resample_result) export(assert_resampling) export(assert_resamplings) export(assert_row_ids) +export(assert_scorable) export(assert_task) export(assert_tasks) export(assert_validate) diff --git a/R/as_result_data.R b/R/as_result_data.R index e264a7d02..c307cc84c 100644 --- a/R/as_result_data.R +++ b/R/as_result_data.R @@ -39,18 +39,18 @@ #' #' rdata = as_result_data(task, learners, resampling, iterations, predictions) #' ResampleResult$new(rdata) -as_result_data = function(task, learners, resampling, iterations, predictions, learner_states = NULL, store_backends = TRUE) { +as_result_data = function(task, learners, resampling, iterations, predictions, learner_states = NULL, data_extra = NULL, store_backends = TRUE) { assert_task(task) assert_learners(learners, task = task) assert_resampling(resampling, instantiated = TRUE) assert_integer(iterations, any.missing = FALSE, lower = 1L, upper = resampling$iters, unique = TRUE) assert_list(predictions, types = "list") assert_list(learner_states, null.ok = TRUE) + assert_list(data_extra, null.ok = TRUE) predictions = map(predictions, function(x) map(x, as_prediction_data)) N = length(iterations) - if (length(learners) != N) { stopf("Number of learners (%i) must match the number of resampling iterations (%i)", length(learners), N) } @@ -78,6 +78,7 @@ as_result_data = function(task, learners, resampling, iterations, predictions, l resampling = list(resampling), iteration = iterations, prediction = predictions, - uhash = UUIDgenerate() + uhash = UUIDgenerate(), + data_extra = data_extra ), store_backends = store_backends) } diff --git a/R/resample.R b/R/resample.R index da368373a..574f7a621 100644 --- a/R/resample.R +++ b/R/resample.R @@ -131,6 +131,8 @@ resample = function( MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal, callbacks = callbacks) ) + lg$debug("Resampling finished") + data = data.table( task = list(task), learner = grid$learner, @@ -144,8 +146,12 @@ resample = function( data_extra = map(res, "data_extra") ) + lg$debug("Prepare result data") + result_data = ResultData$new(data, store_backends = store_backends) + lg$debug("Result data written") + # the worker already ensures that models are sent back in marshaled form if unmarshal = FALSE, so we don't have # to do anything in this case. 
This allows us to minimize the amount of marshaling in those situtions where # the model is available in both states on the worker @@ -153,5 +159,7 @@ resample = function( result_data$unmarshal() } + lg$debug("Prepare resample result") + ResampleResult$new(result_data) } diff --git a/man/mlr_assertions.Rd b/man/mlr_assertions.Rd index 5f7c380d1..21a497c43 100644 --- a/man/mlr_assertions.Rd +++ b/man/mlr_assertions.Rd @@ -10,6 +10,7 @@ \alias{assert_learnable} \alias{assert_predictable} \alias{assert_measure} +\alias{assert_scorable} \alias{assert_measures} \alias{assert_resampling} \alias{assert_resamplings} @@ -67,6 +68,14 @@ assert_measure( .var.name = vname(measure) ) +assert_scorable( + measure, + task, + learner, + prediction = NULL, + .var.name = vname(measure) +) + assert_measures( measures, task = NULL, From 4b2cf71f6c6007105c1dfb28ee697dace2e1f5c6 Mon Sep 17 00:00:00 2001 From: be-marc Date: Thu, 28 Nov 2024 14:32:50 +0100 Subject: [PATCH 11/54] ... --- R/resample.R | 3 +++ 1 file changed, 3 insertions(+) diff --git a/R/resample.R b/R/resample.R index 574f7a621..9afeec0ac 100644 --- a/R/resample.R +++ b/R/resample.R @@ -68,6 +68,9 @@ resample = function( unmarshal = TRUE, callbacks = NULL ) { + + lg$debug("Start resampling") + assert_subset(clone, c("task", "learner", "resampling")) task = assert_task(as_task(task, clone = "task" %in% clone)) learner = assert_learner(as_learner(learner, clone = "learner" %in% clone, discard_state = TRUE)) From ee83cbbd639ab1e014c7f1567982b9c4c76f03b7 Mon Sep 17 00:00:00 2001 From: be-marc Date: Thu, 28 Nov 2024 14:51:11 +0100 Subject: [PATCH 12/54] ... --- R/resample.R | 3 ++- R/worker.R | 10 +++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/R/resample.R b/R/resample.R index 9afeec0ac..30b4cd6ba 100644 --- a/R/resample.R +++ b/R/resample.R @@ -94,6 +94,7 @@ resample = function( } else { NULL } + browser() lgr_threshold = map_int(mlr_reflections$loggers, "threshold") grid = if (allow_hotstart) { @@ -108,7 +109,7 @@ resample = function( } if (is.null(learner$hotstart_stack) || is.null(start_learner)) { # no hotstart learners stored or no adaptable model found - lg$debug("Resampling with hotstarting not possible. No start learner found.") + #lg$debug("Resampling with hotstarting not possible. No start learner found.") mode = "train" } else { # hotstart learner found diff --git a/R/worker.R b/R/worker.R index aced1724b..04fbb7ad6 100644 --- a/R/worker.R +++ b/R/worker.R @@ -298,11 +298,11 @@ workhorse = function( } } # restore logger thresholds - for (package in names(lgr_threshold)) { - logger = lgr::get_logger(package) - threshold = lgr_threshold[package] - logger$set_threshold(threshold) - } + # for (package in names(lgr_threshold)) { + # logger = lgr::get_logger(package) + # threshold = lgr_threshold[package] + # logger$set_threshold(threshold) + # } lg$info("%s learner '%s' on task '%s' (iter %i/%i)", if (mode == "train") "Applying" else "Hotstarting", learner$id, task$id, iteration, resampling$iters) From c0424e279090a6b963d8ac9b6069cfb22b656928 Mon Sep 17 00:00:00 2001 From: be-marc Date: Thu, 28 Nov 2024 14:53:43 +0100 Subject: [PATCH 13/54] ... 
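The `browser()` call removed here was left over from debugging. The debug messages added in the previous commits (for example "Start resampling") only become visible once the threshold of the "mlr3" lgr logger is lowered; a minimal sketch:

library(mlr3)

# show debug-level messages from the resample() path
lgr::get_logger("mlr3")$set_threshold("debug")

rr = resample(tsk("pima"), lrn("classif.rpart"), rsmp("holdout"))

# restore the default level afterwards
lgr::get_logger("mlr3")$set_threshold("info")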
--- R/resample.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/resample.R b/R/resample.R index 30b4cd6ba..a2fb53a47 100644 --- a/R/resample.R +++ b/R/resample.R @@ -94,7 +94,7 @@ resample = function( } else { NULL } - browser() + lgr_threshold = map_int(mlr_reflections$loggers, "threshold") grid = if (allow_hotstart) { From 0789de082f943a81a6b9535b9a41b0169b69d15e Mon Sep 17 00:00:00 2001 From: be-marc Date: Thu, 28 Nov 2024 17:44:35 +0100 Subject: [PATCH 14/54] refactor: remove objekt logging --- R/worker.R | 36 +++++++++++++----------------------- 1 file changed, 13 insertions(+), 23 deletions(-) diff --git a/R/worker.R b/R/worker.R index e52daa23f..76a6ecea0 100644 --- a/R/worker.R +++ b/R/worker.R @@ -34,8 +34,7 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL # subset to train set w/o cloning if (!is.null(train_row_ids)) { - lg$debug("Subsetting task '%s' to %i rows", - task$id, length(train_row_ids), task = task$clone(), row_ids = train_row_ids) + lg$debug("Subsetting task '%s' to %i rows", task$id, length(train_row_ids)) task_private = get_private(task) prev_use = task_private$.row_roles$use @@ -64,7 +63,7 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL if (mode == "train") learner$state = list() lg$debug("Calling %s method of Learner '%s' on task '%s' with %i observations", - mode, learner$id, task$id, task$nrow, learner = learner$clone()) + mode, learner$id, task$id, task$nrow) # call train_wrapper with encapsulation result = encapsulate(learner$encapsulation["train"], @@ -101,26 +100,23 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL } if (is.null(result$result)) { - lg$info("Learner '%s' on task '%s' failed to %s a model", - learner$id, task$id, mode, learner = learner$clone(), messages = result$log$msg) + lg$info("Learner '%s' on task '%s' failed to %s a model", learner$id, task$id, mode) } else { - lg$debug("Learner '%s' on task '%s' succeeded to %s a model", - learner$id, task$id, mode, learner = learner$clone(), result = result$result, messages = result$log$msg) + lg$debug("Learner '%s' on task '%s' succeeded to %s a model", learner$id, task$id, mode) } # fit fallback learner fb = learner$fallback if (!is.null(fb)) { lg$info("Calling train method of fallback '%s' on task '%s' with %i observations", - fb$id, task$id, task$nrow, learner = fb$clone()) + fb$id, task$id, task$nrow) fb = assert_learner(as_learner(fb)) require_namespaces(fb$packages) fb$train(task) learner$state$fallback_state = fb$state - lg$debug("Fitted fallback learner '%s'", - fb$id, learner = fb$clone()) + lg$debug("Fitted fallback learner '%s'", fb$id) } @@ -164,8 +160,7 @@ learner_predict = function(learner, task, row_ids = NULL) { # subset to test set w/o cloning if (!is.null(row_ids)) { - lg$debug("Subsetting task '%s' to %i rows", - task$id, length(row_ids), task = task$clone(), row_ids = row_ids) + lg$debug("Subsetting task '%s' to %i rows", task$id, length(row_ids)) task_private = get_private(task) prev_use = task_private$.row_roles$use @@ -179,21 +174,19 @@ learner_predict = function(learner, task, row_ids = NULL) { if (task$nrow == 0L) { # return an empty prediction object, #421 - lg$debug("No observations in task, returning empty prediction data", task = task) + lg$debug("No observations in task, returning empty prediction data") learner$state$log = append_log(learner$state$log, "predict", "output", "No data to predict on, create empty prediction") 
return(create_empty_prediction_data(task, learner)) } if (is.null(learner$state$model)) { - lg$debug("Learner '%s' has no model stored", - learner$id, learner = learner$clone()) + lg$debug("Learner '%s' has no model stored", learner$id) pdata = NULL learner$state$predict_time = NA_real_ } else { # call predict with encapsulation - lg$debug("Calling predict method of Learner '%s' on task '%s' with %i observations", - learner$id, task$id, task$nrow, learner = learner$clone()) + lg$debug("Calling predict method of Learner '%s' on task '%s' with %i observations", learner$id, task$id, task$nrow) if (isTRUE(all.equal(learner$encapsulation[["predict"]], "callr"))) { learner$model = marshal_model(learner$model, inplace = TRUE) @@ -212,8 +205,7 @@ learner_predict = function(learner, task, row_ids = NULL) { learner$state$log = append_log(learner$state$log, "predict", result$log$class, result$log$msg) learner$state$predict_time = sum(learner$state$predict_time, result$elapsed) - lg$debug("Learner '%s' returned an object of class '%s'", - learner$id, class(pdata)[1L], learner = learner$clone(), prediction_data = pdata, messages = result$log$msg) + lg$debug("Learner '%s' returned an object of class '%s'", learner$id, class(pdata)[1L]) } @@ -228,16 +220,14 @@ learner_predict = function(learner, task, row_ids = NULL) { if (is.null(pdata)) { - lg$debug("Creating new Prediction using fallback '%s'", - fb$id, learner = fb$clone()) + lg$debug("Creating new Prediction using fallback '%s'", fb$id) learner$state$log = append_log(learner$state$log, "predict", "output", "Using fallback learner for predictions") pdata = predict_fb(task$row_ids) } else { miss_ids = is_missing_prediction_data(pdata) - lg$debug("Imputing %i/%i predictions using fallback '%s'", - length(miss_ids), length(pdata$row_ids), fb$id, learner = fb$clone()) + lg$debug("Imputing %i/%i predictions using fallback '%s'", length(miss_ids), length(pdata$row_ids), fb$id) if (length(miss_ids)) { learner$state$log = append_log(learner$state$log, "predict", "output", "Using fallback learner to impute predictions") From f6fc7dcfae8df5de22c5878a36a233aad851e71c Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 29 Nov 2024 12:26:07 +0100 Subject: [PATCH 15/54] ... --- R/resample.R | 8 -------- 1 file changed, 8 deletions(-) diff --git a/R/resample.R b/R/resample.R index a2fb53a47..4e1105298 100644 --- a/R/resample.R +++ b/R/resample.R @@ -135,8 +135,6 @@ resample = function( MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal, callbacks = callbacks) ) - lg$debug("Resampling finished") - data = data.table( task = list(task), learner = grid$learner, @@ -150,12 +148,8 @@ resample = function( data_extra = map(res, "data_extra") ) - lg$debug("Prepare result data") - result_data = ResultData$new(data, store_backends = store_backends) - lg$debug("Result data written") - # the worker already ensures that models are sent back in marshaled form if unmarshal = FALSE, so we don't have # to do anything in this case. This allows us to minimize the amount of marshaling in those situtions where # the model is available in both states on the worker @@ -163,7 +157,5 @@ resample = function( result_data$unmarshal() } - lg$debug("Prepare resample result") - ResampleResult$new(result_data) } From ad46a8ca69eef20e99877d3bb30fdb8c28340170 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 29 Nov 2024 13:47:33 +0100 Subject: [PATCH 16/54] ... 
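The preceding worker.R hunks drop the extra objects (cloned learners, tasks, prediction data) that were attached to the debug calls. In lgr, unnamed arguments are interpolated into the message via sprintf() while named arguments travel as custom fields on the log event; a small sketch of that mechanism, with made-up logger, message, and field names:

library(lgr)

lg = lgr::get_logger("mlr3/example")
lg$set_threshold("debug")

# "pima" and 100L fill the format string; row_ids is attached as a custom field
lg$debug("Subsetting task '%s' to %i rows", "pima", 100L, row_ids = 1:100)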
--- R/ResultData.R | 37 +++++++++++++++++++---------- R/as_result_data.R | 15 +++++++++++- R/benchmark.R | 7 +++--- R/resample.R | 8 ++++--- inst/testthat/helper_expectations.R | 2 +- man/ResultData.Rd | 9 ++++--- man/as_result_data.Rd | 1 + 7 files changed, 56 insertions(+), 23 deletions(-) diff --git a/R/ResultData.R b/R/ResultData.R index c6df7245d..6cc25def6 100644 --- a/R/ResultData.R +++ b/R/ResultData.R @@ -20,8 +20,8 @@ ResultData = R6Class("ResultData", public = list( #' @field data (`list()`)\cr - #' List of [data.table::data.table()], arranged in a star schema. - #' Do not operate directly on this list. + #' List of [data.table::data.table()], arranged in a star schema. + #' Do not operate directly on this list. data = NULL, #' @description @@ -29,23 +29,25 @@ ResultData = R6Class("ResultData", #' An alternative construction method is provided by [as_result_data()]. #' #' @param data ([data.table::data.table()]) | `NULL`)\cr - #' Do not initialize this object yourself, use [as_result_data()] instead. + #' Do not initialize this object yourself, use [as_result_data()] instead. + #' @param data_extra (`list()`)\cr + #' Additional data to store. + #' This can be used to store additional information for each iteration. + #' #' @param store_backends (`logical(1)`)\cr - #' If set to `FALSE`, the backends of the [Task]s provided in `data` are - #' removed. - initialize = function(data = NULL, store_backends = TRUE) { + #' If set to `FALSE`, the backends of the [Task]s provided in `data` are removed. + initialize = function(data = NULL, data_extra = NULL, store_backends = TRUE) { assert_flag(store_backends) if (is.null(data)) { self$data = star_init() } else { - assert_names(names(data), - permutation.of = c("task", "learner", "learner_state", "resampling", "iteration", "param_values", "prediction", "uhash", "learner_hash", "data_extra")) + assert_names(names(data), permutation.of = c("task", "learner", "learner_state", "resampling", "iteration", "param_values", "prediction", "uhash", "learner_hash")) if (nrow(data) == 0L) { self$data = star_init() } else { - setcolorder(data, c("uhash", "iteration", "learner_state", "prediction", "data_extra", "task", "learner", "resampling", "param_values", "learner_hash")) + setcolorder(data, c("uhash", "iteration", "learner_state", "prediction", "task", "learner", "resampling", "param_values", "learner_hash")) uhashes = data.table(uhash = unique(data$uhash)) setkeyv(data, c("uhash", "iteration")) @@ -68,6 +70,12 @@ ResultData = R6Class("ResultData", set(data, j = "resampling", value = NULL) set(data, j = "param_values", value = NULL) + # add extra data to fact table + if (!is.null(data_extra)) { + assert_list(data_extra, len = nrow(data)) + set(data, j = "data_extra", value = data_extra) + } + if (!store_backends) { set(tasks, j = "task", value = lapply(tasks$task, task_rm_backend)) } @@ -194,6 +202,9 @@ ResultData = R6Class("ResultData", #' #' @return `list()`. 
data_extra = function(view = NULL) { + if ("data_extra" %nin% names(self$data$fact)) { + return(NULL) + } .__ii__ = private$get_view_index(view) self$data$fact[.__ii__, "data_extra", with = FALSE][[1L]] }, @@ -324,7 +335,7 @@ ResultData = R6Class("ResultData", } cns = c("uhash", "task", "task_hash", "learner", "learner_hash", "learner_param_vals", "resampling", - "resampling_hash", "iteration", "prediction", "data_extra") + "resampling_hash", "iteration", "prediction", if ("data_extra" %in% names(self$data$fact)) "data_extra") merge(self$data$uhashes, tab[, cns, with = FALSE], by = "uhash", sort = FALSE) }, @@ -378,13 +389,13 @@ ResultData = R6Class("ResultData", ####################################################################################################################### ### constructor ####################################################################################################################### -star_init = function() { +star_init = function(data_extra = FALSE) { fact = data.table( uhash = character(), iteration = integer(), learner_state = list(), prediction = list(), - data_extra = list(), + learner_hash = character(), task_hash = character(), @@ -394,6 +405,8 @@ star_init = function() { key = c("uhash", "iteration") ) + if (data_extra) fact[, "data_extra" := list()] + uhashes = data.table( uhash = character() ) diff --git a/R/as_result_data.R b/R/as_result_data.R index c307cc84c..ed29f5251 100644 --- a/R/as_result_data.R +++ b/R/as_result_data.R @@ -39,7 +39,16 @@ #' #' rdata = as_result_data(task, learners, resampling, iterations, predictions) #' ResampleResult$new(rdata) -as_result_data = function(task, learners, resampling, iterations, predictions, learner_states = NULL, data_extra = NULL, store_backends = TRUE) { +as_result_data = function( + task, + learners, + resampling, + iterations, + predictions, + learner_states = NULL, + data_extra = NULL, + store_backends = TRUE + ) { assert_task(task) assert_learners(learners, task = task) assert_resampling(resampling, instantiated = TRUE) @@ -69,6 +78,10 @@ as_result_data = function(task, learners, resampling, iterations, predictions, l stopf("Resampling '%s' has not been trained on task '%s', hashes do not match", resampling$id, task$id) } + if (!is.null(data_extra) && length(data_extra) != N) { + stopf("Number of data_extra (%i) must match the number of resampling iterations (%i)", length(data_extra), N) + } + ResultData$new(data.table( task = list(task), learner = learners, diff --git a/R/benchmark.R b/R/benchmark.R index 07e862d65..e091c6772 100644 --- a/R/benchmark.R +++ b/R/benchmark.R @@ -196,15 +196,16 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps learner_state = map(res, "learner_state"), prediction = map(res, "prediction"), param_values = map(res, "param_values"), - learner_hash = map_chr(res, "learner_hash"), - data_extra = map(res, "data_extra") + learner_hash = map_chr(res, "learner_hash") )) lg$info("Finished benchmark") set(grid, j = "mode", value = NULL) - result_data = ResultData$new(grid, store_backends = store_backends) + data_extra = if (length(callbacks)) map(res, "data_extra") + + result_data = ResultData$new(grid, data_extra, store_backends = store_backends) if (unmarshal && store_models) { result_data$unmarshal() diff --git a/R/resample.R b/R/resample.R index 4e1105298..2f1fde3c4 100644 --- a/R/resample.R +++ b/R/resample.R @@ -144,11 +144,13 @@ resample = function( prediction = map(res, "prediction"), uhash = UUIDgenerate(), param_values = map(res, 
"param_values"), - learner_hash = map_chr(res, "learner_hash"), - data_extra = map(res, "data_extra") + learner_hash = map_chr(res, "learner_hash") ) - result_data = ResultData$new(data, store_backends = store_backends) + # save the extra data only if a callback could have generated some + data_extra = if (length(callbacks)) map(res, "data_extra") + + result_data = ResultData$new(data, data_extra, store_backends = store_backends) # the worker already ensures that models are sent back in marshaled form if unmarshal = FALSE, so we don't have # to do anything in this case. This allows us to minimize the amount of marshaling in those situtions where diff --git a/inst/testthat/helper_expectations.R b/inst/testthat/helper_expectations.R index c7656defb..01cdd622f 100644 --- a/inst/testthat/helper_expectations.R +++ b/inst/testthat/helper_expectations.R @@ -689,7 +689,7 @@ expect_resultdata = function(rdata, consistency = TRUE) { checkmate::expect_class(rdata, "ResultData") data = rdata$data - proto = mlr3:::star_init() + proto = mlr3:::star_init(data_extra = "data_extra" %in% names(data$fact)) checkmate::expect_set_equal(names(data), names(proto)) for (nn in names(proto)) { diff --git a/man/ResultData.Rd b/man/ResultData.Rd index 2bb02bf3e..355b194e1 100644 --- a/man/ResultData.Rd +++ b/man/ResultData.Rd @@ -66,7 +66,7 @@ Returns \code{NULL} if the \link{ResultData} is empty.} Creates a new instance of this \link[R6:R6Class]{R6} class. An alternative construction method is provided by \code{\link[=as_result_data]{as_result_data()}}. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{ResultData$new(data = NULL, store_backends = TRUE)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{ResultData$new(data = NULL, data_extra = NULL, store_backends = TRUE)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -75,9 +75,12 @@ An alternative construction method is provided by \code{\link[=as_result_data]{a \item{\code{data}}{(\code{\link[data.table:data.table]{data.table::data.table()}}) | \code{NULL})\cr Do not initialize this object yourself, use \code{\link[=as_result_data]{as_result_data()}} instead.} +\item{\code{data_extra}}{(\code{list()})\cr +Additional data to store. +This can be used to store additional information for each iteration.} + \item{\code{store_backends}}{(\code{logical(1)})\cr -If set to \code{FALSE}, the backends of the \link{Task}s provided in \code{data} are -removed.} +If set to \code{FALSE}, the backends of the \link{Task}s provided in \code{data} are removed.} } \if{html}{\out{}} } diff --git a/man/as_result_data.Rd b/man/as_result_data.Rd index e3b106c35..45a017f10 100644 --- a/man/as_result_data.Rd +++ b/man/as_result_data.Rd @@ -11,6 +11,7 @@ as_result_data( iterations, predictions, learner_states = NULL, + data_extra = NULL, store_backends = TRUE ) } From 911b3edb94a82748106182b6d23027b89e718d49 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 29 Nov 2024 13:50:20 +0100 Subject: [PATCH 17/54] ... --- R/resample.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/resample.R b/R/resample.R index 2f1fde3c4..cfe2205d5 100644 --- a/R/resample.R +++ b/R/resample.R @@ -109,7 +109,7 @@ resample = function( } if (is.null(learner$hotstart_stack) || is.null(start_learner)) { # no hotstart learners stored or no adaptable model found - #lg$debug("Resampling with hotstarting not possible. No start learner found.") + lg$debug("Resampling with hotstarting not possible. No start learner found.") mode = "train" } else { # hotstart learner found From fbd3b01a745b626843aa9841a1393a603e08a69e Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 29 Nov 2024 13:51:12 +0100 Subject: [PATCH 18/54] ... --- R/worker.R | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/R/worker.R b/R/worker.R index 5dc7942eb..c620a09c4 100644 --- a/R/worker.R +++ b/R/worker.R @@ -287,12 +287,13 @@ workhorse = function( }, add = TRUE) } } + # restore logger thresholds - # for (package in names(lgr_threshold)) { - # logger = lgr::get_logger(package) - # threshold = lgr_threshold[package] - # logger$set_threshold(threshold) - # } + for (package in names(lgr_threshold)) { + logger = lgr::get_logger(package) + threshold = lgr_threshold[package] + logger$set_threshold(threshold) + } lg$info("%s learner '%s' on task '%s' (iter %i/%i)", if (mode == "train") "Applying" else "Hotstarting", learner$id, task$id, iteration, resampling$iters) From 53b766d0fff1aa74efaaab982a50bf27cd3081c2 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 29 Nov 2024 13:54:04 +0100 Subject: [PATCH 19/54] ... 
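With this commit, `clbk()`, `clbks()`, and `mlr_callbacks` are re-exported from mlr3, so predefined callbacks can be retrieved without attaching mlr3misc. A short sketch; the key and measure are examples:

library(mlr3)

# list the registered callback keys
mlr_callbacks$keys()

# retrieve a single configured callback, or several at once
callback = clbk("mlr3.score_measures", measures = msr("classif.ce"))
callbacks = clbks("mlr3.score_measures")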
--- NAMESPACE | 5 +++-- R/CallbackEvaluation.R | 13 +++++++++++++ R/as_result_data.R | 2 ++ R/reexports.R | 9 +++++++++ man/as_result_data.Rd | 3 +++ man/callback_evaluation.Rd | 16 ++++++++++++++++ man/reexports.Rd | 5 +++++ 7 files changed, 51 insertions(+), 2 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index aa85fe9fa..99bc5a893 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -224,6 +224,8 @@ export(benchmark) export(benchmark_grid) export(callback_evaluation) export(check_prediction_data) +export(clbk) +export(clbks) export(col_info) export(convert_task) export(create_empty_prediction_data) @@ -241,6 +243,7 @@ export(learner_unmarshal) export(lrn) export(lrns) export(marshal_model) +export(mlr_callbacks) export(mlr_learners) export(mlr_measures) export(mlr_reflections) @@ -269,8 +272,6 @@ import(palmerpenguins) import(paradox) importFrom(R6,R6Class) importFrom(R6,is.R6) -importFrom(data.table,as.data.table) -importFrom(data.table,data.table) importFrom(future,nbrOfWorkers) importFrom(future,plan) importFrom(graphics,plot) diff --git a/R/CallbackEvaluation.R b/R/CallbackEvaluation.R index 3b8834096..8e3d9eeff 100644 --- a/R/CallbackEvaluation.R +++ b/R/CallbackEvaluation.R @@ -57,6 +57,19 @@ CallbackEvaluation= R6Class("CallbackEvaluation", #' String in the format `[pkg]::[topic]` pointing to a manual page for this object. #' The referenced help package can be opened via method `$help()`. #' +#' @param on_evaluation_begin (`function()`)\cr +#' Stage called at the beginning of an evaluation. +#' Called in `workhorse()` (internal). +#' @param on_evaluation_before_train (`function()`)\cr +#' Stage called before training the learner. +#' Called in `workhorse()` (internal). +#' @param on_evaluation_before_predict (`function()`)\cr +#' Stage called before predicting. +#' Called in `workhorse()` (internal). +#' @param on_evaluation_end (`function()`)\cr +#' Stage called at the end of an evaluation. +#' Called in `workhorse()` (internal). +#' #' @export callback_evaluation = function( id, diff --git a/R/as_result_data.R b/R/as_result_data.R index ed29f5251..14913e392 100644 --- a/R/as_result_data.R +++ b/R/as_result_data.R @@ -15,6 +15,8 @@ #' @param predictions (list of list of [Prediction]s). #' @param learner_states (`list()`)\cr #' Learner states. If not provided, the states of `learners` are automatically extracted. +#' @param data_extra (`list()`)\cr +#' Additional data for each iteration. #' @param store_backends (`logical(1)`)\cr #' If set to `FALSE`, the backends of the [Task]s provided in `data` are #' removed. diff --git a/R/reexports.R b/R/reexports.R index 7fc1b7d49..b80171a9b 100644 --- a/R/reexports.R +++ b/R/reexports.R @@ -3,3 +3,12 @@ data.table::as.data.table #' @export data.table::data.table + +#' @export +mlr3misc::mlr_callbacks + +#' @export +mlr3misc::clbk + +#' @export +mlr3misc::clbks diff --git a/man/as_result_data.Rd b/man/as_result_data.Rd index 45a017f10..06c1d60bf 100644 --- a/man/as_result_data.Rd +++ b/man/as_result_data.Rd @@ -29,6 +29,9 @@ as_result_data( \item{learner_states}{(\code{list()})\cr Learner states. 
If not provided, the states of \code{learners} are automatically extracted.} +\item{data_extra}{(\code{list()})\cr +Additional data for each iteration.} + \item{store_backends}{(\code{logical(1)})\cr If set to \code{FALSE}, the backends of the \link{Task}s provided in \code{data} are removed.} diff --git a/man/callback_evaluation.Rd b/man/callback_evaluation.Rd index 04db4d220..c828b2e4e 100644 --- a/man/callback_evaluation.Rd +++ b/man/callback_evaluation.Rd @@ -24,6 +24,22 @@ Label for the new instance.} \item{man}{(\code{character(1)})\cr String in the format \verb{[pkg]::[topic]} pointing to a manual page for this object. The referenced help package can be opened via method \verb{$help()}.} + +\item{on_evaluation_begin}{(\verb{function()})\cr +Stage called at the beginning of an evaluation. +Called in \code{workhorse()} (internal).} + +\item{on_evaluation_before_train}{(\verb{function()})\cr +Stage called before training the learner. +Called in \code{workhorse()} (internal).} + +\item{on_evaluation_before_predict}{(\verb{function()})\cr +Stage called before predicting. +Called in \code{workhorse()} (internal).} + +\item{on_evaluation_end}{(\verb{function()})\cr +Stage called at the end of an evaluation. +Called in \code{workhorse()} (internal).} } \description{ Function to create a \link{CallbackEvaluation}. diff --git a/man/reexports.Rd b/man/reexports.Rd index 27cc9198f..d23ec3886 100644 --- a/man/reexports.Rd +++ b/man/reexports.Rd @@ -5,6 +5,9 @@ \alias{reexports} \alias{as.data.table} \alias{data.table} +\alias{mlr_callbacks} +\alias{clbk} +\alias{clbks} \title{Objects exported from other packages} \keyword{internal} \description{ @@ -13,5 +16,7 @@ below to see their documentation. \describe{ \item{data.table}{\code{\link[data.table]{as.data.table}}, \code{\link[data.table]{data.table}}} + + \item{mlr3misc}{\code{\link[mlr3misc]{clbk}}, \code{\link[mlr3misc:clbk]{clbks}}, \code{\link[mlr3misc]{mlr_callbacks}}} }} From 81b880f7efc6b75c768bf487495bc857c1d02188 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 29 Nov 2024 14:00:08 +0100 Subject: [PATCH 20/54] ... --- NAMESPACE | 5 +++++ R/worker.R | 36 ++++++++++++++++++++++-------------- 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 99bc5a893..c536e1726 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -272,9 +272,14 @@ import(palmerpenguins) import(paradox) importFrom(R6,R6Class) importFrom(R6,is.R6) +importFrom(data.table,as.data.table) +importFrom(data.table,data.table) importFrom(future,nbrOfWorkers) importFrom(future,plan) importFrom(graphics,plot) +importFrom(mlr3misc,clbk) +importFrom(mlr3misc,clbks) +importFrom(mlr3misc,mlr_callbacks) importFrom(parallelly,availableCores) importFrom(stats,contr.treatment) importFrom(stats,model.frame) diff --git a/R/worker.R b/R/worker.R index c620a09c4..6363f7315 100644 --- a/R/worker.R +++ b/R/worker.R @@ -1,4 +1,4 @@ -learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NULL, mode = "train", callback, context) { +learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NULL, mode = "train") { # This wrapper calls learner$.train, and additionally performs some basic # checks that the training was successful. 
# Exceptions here are possibly encapsulated, so that they get captured @@ -34,7 +34,8 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL # subset to train set w/o cloning if (!is.null(train_row_ids)) { - lg$debug("Subsetting task '%s' to %i rows", task$id, length(train_row_ids)) + lg$debug("Subsetting task '%s' to %i rows", + task$id, length(train_row_ids), task = task$clone(), row_ids = train_row_ids) task_private = get_private(task) prev_use = task_private$.row_roles$use @@ -63,7 +64,7 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL if (mode == "train") learner$state = list() lg$debug("Calling %s method of Learner '%s' on task '%s' with %i observations", - mode, learner$id, task$id, task$nrow) + mode, learner$id, task$id, task$nrow, learner = learner$clone()) # call train_wrapper with encapsulation result = encapsulate(learner$encapsulation["train"], @@ -100,23 +101,26 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL } if (is.null(result$result)) { - lg$info("Learner '%s' on task '%s' failed to %s a model", learner$id, task$id, mode) + lg$info("Learner '%s' on task '%s' failed to %s a model", + learner$id, task$id, mode, learner = learner$clone(), messages = result$log$msg) } else { - lg$debug("Learner '%s' on task '%s' succeeded to %s a model", learner$id, task$id, mode) + lg$debug("Learner '%s' on task '%s' succeeded to %s a model", + learner$id, task$id, mode, learner = learner$clone(), result = result$result, messages = result$log$msg) } # fit fallback learner fb = learner$fallback if (!is.null(fb)) { lg$info("Calling train method of fallback '%s' on task '%s' with %i observations", - fb$id, task$id, task$nrow) + fb$id, task$id, task$nrow, learner = fb$clone()) fb = assert_learner(as_learner(fb)) require_namespaces(fb$packages) fb$train(task) learner$state$fallback_state = fb$state - lg$debug("Fitted fallback learner '%s'", fb$id) + lg$debug("Fitted fallback learner '%s'", + fb$id, learner = fb$clone()) } @@ -160,7 +164,8 @@ learner_predict = function(learner, task, row_ids = NULL) { # subset to test set w/o cloning if (!is.null(row_ids)) { - lg$debug("Subsetting task '%s' to %i rows", task$id, length(row_ids)) + lg$debug("Subsetting task '%s' to %i rows", + task$id, length(row_ids), task = task$clone(), row_ids = row_ids) task_private = get_private(task) prev_use = task_private$.row_roles$use @@ -174,19 +179,20 @@ learner_predict = function(learner, task, row_ids = NULL) { if (task$nrow == 0L) { # return an empty prediction object, #421 - lg$debug("No observations in task, returning empty prediction data") + lg$debug("No observations in task, returning empty prediction data", task = task) learner$state$log = append_log(learner$state$log, "predict", "output", "No data to predict on, create empty prediction") return(create_empty_prediction_data(task, learner)) } if (is.null(learner$state$model)) { - lg$debug("Learner '%s' has no model stored", learner$id) - + lg$debug("Learner '%s' has no model stored", + learner$id, learner = learner$clone()) pdata = NULL learner$state$predict_time = NA_real_ } else { # call predict with encapsulation - lg$debug("Calling predict method of Learner '%s' on task '%s' with %i observations", learner$id, task$id, task$nrow) + lg$debug("Calling predict method of Learner '%s' on task '%s' with %i observations", + learner$id, task$id, task$nrow, learner = learner$clone()) if (isTRUE(all.equal(learner$encapsulation[["predict"]], "callr"))) { learner$model = 
marshal_model(learner$model, inplace = TRUE) @@ -205,7 +211,8 @@ learner_predict = function(learner, task, row_ids = NULL) { learner$state$log = append_log(learner$state$log, "predict", result$log$class, result$log$msg) learner$state$predict_time = sum(learner$state$predict_time, result$elapsed) - lg$debug("Learner '%s' returned an object of class '%s'", learner$id, class(pdata)[1L]) + lg$debug("Learner '%s' returned an object of class '%s'", + learner$id, class(pdata)[1L], learner = learner$clone(), prediction_data = pdata, messages = result$log$msg) } @@ -227,7 +234,8 @@ learner_predict = function(learner, task, row_ids = NULL) { } else { miss_ids = is_missing_prediction_data(pdata) - lg$debug("Imputing %i/%i predictions using fallback '%s'", length(miss_ids), length(pdata$row_ids), fb$id) + lg$debug("Imputing %i/%i predictions using fallback '%s'", + length(miss_ids), length(pdata$row_ids), fb$id, learner = fb$clone()) if (length(miss_ids)) { learner$state$log = append_log(learner$state$log, "predict", "output", "Using fallback learner to impute predictions") From 2ad3d54a260b4496ae856e4b5d1977d7b95d2cb4 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 29 Nov 2024 14:01:29 +0100 Subject: [PATCH 21/54] ... --- R/worker.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/worker.R b/R/worker.R index 6363f7315..a810cf9ab 100644 --- a/R/worker.R +++ b/R/worker.R @@ -64,7 +64,7 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL if (mode == "train") learner$state = list() lg$debug("Calling %s method of Learner '%s' on task '%s' with %i observations", - mode, learner$id, task$id, task$nrow, learner = learner$clone()) + mode, learner$id, task$id, task$nrow, learner = learner$clone()) # call train_wrapper with encapsulation result = encapsulate(learner$encapsulation["train"], @@ -104,7 +104,7 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL lg$info("Learner '%s' on task '%s' failed to %s a model", learner$id, task$id, mode, learner = learner$clone(), messages = result$log$msg) } else { - lg$debug("Learner '%s' on task '%s' succeeded to %s a model", + lg$debug("Learner '%s' on task '%s' succeeded to %s a model", learner$id, task$id, mode, learner = learner$clone(), result = result$result, messages = result$log$msg) } From 017d511a921f8c7dc03b8a4e245d776efbebb496 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 29 Nov 2024 14:01:58 +0100 Subject: [PATCH 22/54] ... --- R/worker.R | 1 + 1 file changed, 1 insertion(+) diff --git a/R/worker.R b/R/worker.R index a810cf9ab..e5be181c1 100644 --- a/R/worker.R +++ b/R/worker.R @@ -187,6 +187,7 @@ learner_predict = function(learner, task, row_ids = NULL) { if (is.null(learner$state$model)) { lg$debug("Learner '%s' has no model stored", learner$id, learner = learner$clone()) + pdata = NULL learner$state$predict_time = NA_real_ } else { From 73996e7f21a543906a1f8e7e2de137c1d0169ebb Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 29 Nov 2024 14:02:28 +0100 Subject: [PATCH 23/54] ... 
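Editor's aside, not part of the patch series: the worker.R changes above attach cloned learner, task, and prediction objects to the lg$debug() and lg$info() events. These richer events are only visible when the mlr3 logger runs at a verbose threshold; a minimal sketch using the lgr package, assuming mlr3 keeps using the logger named "mlr3":

library(mlr3)
# emit the debug events added above, including the attached objects
lgr::get_logger("mlr3")$set_threshold("debug")
rr = resample(tsk("iris"), lrn("classif.rpart"), rsmp("holdout"))
# restore the default level afterwards
lgr::get_logger("mlr3")$set_threshold("info")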
--- R/worker.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/worker.R b/R/worker.R index e5be181c1..bf3ecb51e 100644 --- a/R/worker.R +++ b/R/worker.R @@ -228,7 +228,8 @@ learner_predict = function(learner, task, row_ids = NULL) { if (is.null(pdata)) { - lg$debug("Creating new Prediction using fallback '%s'", fb$id) + lg$debug("Creating new Prediction using fallback '%s'", + fb$id, learner = fb$clone()) learner$state$log = append_log(learner$state$log, "predict", "output", "Using fallback learner for predictions") pdata = predict_fb(task$row_ids) From 50c13b9aa4d51ec4795bc5b5c936612bf7701f14 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 29 Nov 2024 14:03:50 +0100 Subject: [PATCH 24/54] ... --- inst/test.r | 13 ------------- 1 file changed, 13 deletions(-) delete mode 100644 inst/test.r diff --git a/inst/test.r b/inst/test.r deleted file mode 100644 index a4571198b..000000000 --- a/inst/test.r +++ /dev/null @@ -1,13 +0,0 @@ -# Callback -callback = callback_resample( - id = "test", - on_resample_before_result_data = function(callback, context) { - print("on_resample_before_result_data") - } -) - -learner = lrn("classif.rpart") -task = tsk("iris") -resampling = rsmp("cv", folds = 3) - -resample(task, learner, resampling, callbacks = callback) From cf92e718b2fb027e1fdefb6b7239db979fa69561 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 29 Nov 2024 14:07:53 +0100 Subject: [PATCH 25/54] pkgdown --- pkgdown/_pkgdown.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pkgdown/_pkgdown.yml b/pkgdown/_pkgdown.yml index dc9efd23e..57e76ebe5 100644 --- a/pkgdown/_pkgdown.yml +++ b/pkgdown/_pkgdown.yml @@ -115,6 +115,9 @@ reference: contents: - CallbackEvaluation - ContextEvaluation + - callback_evaluation + - assert_evaluation_callback + - mlr3.score_measures - title: Internal Objects and Functions contents: - marshaling From 8f7f308396e6eadda41c4e2497d921ef72882827 Mon Sep 17 00:00:00 2001 From: be-marc Date: Tue, 3 Dec 2024 17:20:17 +0100 Subject: [PATCH 26/54] add iteration to context --- R/ContextEvaluation.R | 7 ++++++- R/worker.R | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/R/ContextEvaluation.R b/R/ContextEvaluation.R index 9c366aa5f..d12a1e76f 100644 --- a/R/ContextEvaluation.R +++ b/R/ContextEvaluation.R @@ -30,6 +30,10 @@ ContextEvaluation = R6Class("ContextEvaluation", #' Is usually only set while tuning. param_values = NULL, + #' @field iteration (`integer()`)\cr + #' The current iteration. + iteration = NULL, + #' @field sets (`list()`)\cr #' The train and test set. #' The sets are available on stage `on_evaluation_before_train``. @@ -66,11 +70,12 @@ ContextEvaluation = R6Class("ContextEvaluation", #' The resampling strategy to be used. #' @param param_values (`list()`)\cr #' The parameter values to be used. 
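# Editor's aside, not part of this diff: a sketch of a callback that reads the
# new `iteration` field, assuming the callback_evaluation() constructor and the
# on_evaluation_* stages introduced earlier in this series.
cb_iter = callback_evaluation("iteration_logger",
  on_evaluation_begin = function(callback, context) {
    message(sprintf("starting iteration %i of %i", context$iteration, context$resampling$iters))
  }
)
resample(tsk("iris"), lrn("classif.rpart"), rsmp("cv", folds = 3), callbacks = cb_iter)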
- initialize = function(task, learner, resampling, param_values) { + initialize = function(task, learner, resampling, param_values, iteration) { # no assertions to avoid overhead self$task = task self$learner = learner self$resampling = resampling + self$iteration = iteration super$initialize(id = "evaluate", label = "Evaluation") } diff --git a/R/worker.R b/R/worker.R index bf3ecb51e..957e223e2 100644 --- a/R/worker.R +++ b/R/worker.R @@ -265,7 +265,7 @@ workhorse = function( unmarshal = TRUE, callbacks = NULL ) { - ctx = ContextEvaluation$new(task, learner, resampling, param_values) + ctx = ContextEvaluation$new(task, learner, resampling, param_values, iteration) call_back("on_evaluation_begin", callbacks, ctx) From 1e00462d817ef2b433f54a25b19f0853723a4b87 Mon Sep 17 00:00:00 2001 From: be-marc Date: Wed, 4 Dec 2024 10:34:52 +0100 Subject: [PATCH 27/54] ... --- R/CallbackEvaluation.R | 15 +++++++++++---- R/ContextEvaluation.R | 26 +++----------------------- R/worker.R | 16 ++++++++-------- man/ContextEvaluation.Rd | 8 +++++++- man/callback_evaluation.Rd | 11 +++++++++-- 5 files changed, 38 insertions(+), 38 deletions(-) diff --git a/R/CallbackEvaluation.R b/R/CallbackEvaluation.R index 8e3d9eeff..1debd01b5 100644 --- a/R/CallbackEvaluation.R +++ b/R/CallbackEvaluation.R @@ -34,15 +34,22 @@ CallbackEvaluation= R6Class("CallbackEvaluation", #' #' @description #' Function to create a [CallbackEvaluation]. +#' Predefined callbacks are stored in the [dictionary][mlr3misc::Dictionary] [mlr_callbacks] and can be retrieved with [clbk()]. +#' +#' Evaluation callbacks are called at different stages of the resampling process. +#' The stages are prefixed with `on_*`. #' #' ``` -#' Start Evaluation on Worker +#' Start Resampling Iteration on Worker #' - on_evaluation_begin #' - on_evaluation_before_train #' - on_evaluation_before_predict #' - on_evaluation_end -#' End Evaluation on Worker +#' End Resampling Iteration on Worker #' ``` +#' +#' See also the section on parameters for more information on the stages. +#' A evaluation callback works with [ContextEvaluation]. # #' @details #' When implementing a callback, each function must have two arguments named `callback` and `context`. @@ -84,7 +91,7 @@ callback_evaluation = function( on_evaluation_begin, on_evaluation_before_train, on_evaluation_before_predict, - on_evaluation_end ), + on_evaluation_end), c( "on_evaluation_begin", "on_evaluation_before_train", @@ -92,7 +99,7 @@ callback_evaluation = function( "on_evaluation_end" )), is.null) - walk(stages, function(stage) assert_function(stage, args = c("callback", "context"))) + stages = map(stages, function(stage) crate(assert_function(stage, args = c("callback", "context")))) callback = CallbackEvaluation$new(id, label, man) iwalk(stages, function(stage, name) callback[[name]] = stage) callback diff --git a/R/ContextEvaluation.R b/R/ContextEvaluation.R index d12a1e76f..242dd29b3 100644 --- a/R/ContextEvaluation.R +++ b/R/ContextEvaluation.R @@ -25,30 +25,10 @@ ContextEvaluation = R6Class("ContextEvaluation", #' The resampling is unchanged during the evaluation. resampling = NULL, - #' @field param_values `list()`\cr - #' The parameter values to be used. - #' Is usually only set while tuning. - param_values = NULL, - #' @field iteration (`integer()`)\cr #' The current iteration. iteration = NULL, - #' @field sets (`list()`)\cr - #' The train and test set. - #' The sets are available on stage `on_evaluation_before_train``. 
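# Editor's aside, not part of this diff: with `sets`, `test_set`, and
# `predict_sets` being removed from the context here, a callback can rebuild the
# row ids it needs from `context$resampling` and `context$iteration`; a sketch:
cb_sets = callback_evaluation("set_sizes",
  on_evaluation_end = function(callback, context) {
    train_ids = context$resampling$train_set(context$iteration)
    test_ids = context$resampling$test_set(context$iteration)
    context$learner$state$set_sizes = c(train = length(train_ids), test = length(test_ids))
  }
)
rr = resample(tsk("pima"), lrn("classif.rpart"), rsmp("cv", folds = 3), callbacks = cb_sets)
rr$learners[[1]]$state$set_sizes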
- sets = NULL, - - #' @field test_set (`integer()`)\cr - #' Validation test set. - #' The set is only available when using internal validation. - test_set = NULL, - - #' @field predict_sets (`list()`)\cr - #' The prediction sets stored in `learner$predict_sets`. - #' The sets are available on stage `on_evaluation_before_predict`. - predict_sets = NULL, - #' @field pdatas (List of [PredictionData])\cr #' The prediction data. #' The data is available on stage `on_evaluation_end`. @@ -68,9 +48,9 @@ ContextEvaluation = R6Class("ContextEvaluation", #' The learner to be evaluated. #' @param resampling ([Resampling])\cr #' The resampling strategy to be used. - #' @param param_values (`list()`)\cr - #' The parameter values to be used. - initialize = function(task, learner, resampling, param_values, iteration) { + #' @param iteration (`integer()`)\cr + #' The current iteration. + initialize = function(task, learner, resampling, iteration) { # no assertions to avoid overhead self$task = task self$learner = learner diff --git a/R/worker.R b/R/worker.R index 957e223e2..0941786b5 100644 --- a/R/worker.R +++ b/R/worker.R @@ -265,7 +265,7 @@ workhorse = function( unmarshal = TRUE, callbacks = NULL ) { - ctx = ContextEvaluation$new(task, learner, resampling, param_values, iteration) + ctx = ContextEvaluation$new(task, learner, resampling, iteration) call_back("on_evaluation_begin", callbacks, ctx) @@ -308,7 +308,7 @@ workhorse = function( lg$info("%s learner '%s' on task '%s' (iter %i/%i)", if (mode == "train") "Applying" else "Hotstarting", learner$id, task$id, iteration, resampling$iters) - ctx$sets = list( + sets = list( train = resampling$train_set(iteration), test = resampling$test_set(iteration) ) @@ -323,11 +323,11 @@ workhorse = function( validate = get0("validate", learner) - ctx$test_set = if (identical(validate, "test")) ctx$sets$test + ctx$test_set = if (identical(validate, "test")) sets$test call_back("on_evaluation_before_train", callbacks, ctx) - train_result = learner_train(learner, task, ctx$sets[["train"]], ctx$test_set, mode = mode) + train_result = learner_train(learner, task, sets[["train"]], ctx$test_set, mode = mode) ctx$learner = learner = train_result$learner # process the model so it can be used for prediction (e.g. marshal for callr prediction), but also @@ -338,10 +338,10 @@ workhorse = function( ) # predict for each set - ctx$predict_sets = learner$predict_sets + predict_sets = learner$predict_sets # creates the tasks and row_ids for all selected predict sets - pred_data = prediction_tasks_and_sets(task, train_result, validate, ctx$sets, ctx$predict_sets) + pred_data = prediction_tasks_and_sets(task, train_result, validate, sets, predict_sets) call_back("on_evaluation_before_predict", callbacks, ctx) @@ -349,9 +349,9 @@ workhorse = function( lg$debug("Creating Prediction for predict set '%s'", set) learner_predict(learner, task, row_ids) - }, set = ctx$predict_sets, row_ids = pred_data$sets, task = pred_data$tasks) + }, set = predict_sets, row_ids = pred_data$sets, task = pred_data$tasks) - if (!length(ctx$predict_sets)) { + if (!length(predict_sets)) { learner$state$predict_time = 0L } ctx$pdatas = discard(pdatas, is.null) diff --git a/man/ContextEvaluation.Rd b/man/ContextEvaluation.Rd index 78bb4fac8..b0e95953a 100644 --- a/man/ContextEvaluation.Rd +++ b/man/ContextEvaluation.Rd @@ -30,6 +30,9 @@ The resampling is unchanged during the evaluation.} The parameter values to be used. 
Is usually only set while tuning.} +\item{\code{iteration}}{(\code{integer()})\cr +The current iteration.} + \item{\code{sets}}{(\code{list()})\cr The train and test set. The sets are available on stage `on_evaluation_before_train``.} @@ -73,7 +76,7 @@ Use this field to save results.} \subsection{Method \code{new()}}{ Creates a new instance of this \link[R6:R6Class]{R6} class. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{ContextEvaluation$new(task, learner, resampling, param_values)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{ContextEvaluation$new(task, learner, resampling, param_values, iteration)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -90,6 +93,9 @@ The resampling strategy to be used.} \item{\code{param_values}}{(\code{list()})\cr The parameter values to be used.} + +\item{\code{iteration}}{(\code{integer()})\cr +The current iteration.} } \if{html}{\out{}} } diff --git a/man/callback_evaluation.Rd b/man/callback_evaluation.Rd index c828b2e4e..4a74900bc 100644 --- a/man/callback_evaluation.Rd +++ b/man/callback_evaluation.Rd @@ -43,14 +43,21 @@ Called in \code{workhorse()} (internal).} } \description{ Function to create a \link{CallbackEvaluation}. +Predefined callbacks are stored in the \link[mlr3misc:Dictionary]{dictionary} \link{mlr_callbacks} and can be retrieved with \code{\link[=clbk]{clbk()}}. -\if{html}{\out{
}}\preformatted{Start Evaluation on Worker +Evaluation callbacks are called at different stages of the resampling process. +The stages are prefixed with \verb{on_*}. + +\if{html}{\out{
}}\preformatted{Start Resampling Iteration on Worker - on_evaluation_begin - on_evaluation_before_train - on_evaluation_before_predict - on_evaluation_end -End Evaluation on Worker +End Resampling Iteration on Worker }\if{html}{\out{
}} + +See also the section on parameters for more information on the stages. +A evaluation callback works with \link{ContextEvaluation}. } \details{ When implementing a callback, each function must have two arguments named \code{callback} and \code{context}. From 1559518f914dc8b82303835ded5963b9ce8dc3fa Mon Sep 17 00:00:00 2001 From: Marc Becker <33069354+be-marc@users.noreply.github.com> Date: Wed, 4 Dec 2024 10:35:10 +0100 Subject: [PATCH 28/54] Update R/as_result_data.R Co-authored-by: Sebastian Fischer --- R/as_result_data.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/as_result_data.R b/R/as_result_data.R index 14913e392..bb7f9b648 100644 --- a/R/as_result_data.R +++ b/R/as_result_data.R @@ -81,7 +81,7 @@ as_result_data = function( } if (!is.null(data_extra) && length(data_extra) != N) { - stopf("Number of data_extra (%i) must match the number of resampling iterations (%i)", length(data_extra), N) + stopf("Length of data_extra (%i) must match the number of resampling iterations (%i)", length(data_extra), N) } ResultData$new(data.table( From d046835f95fbbd4f6b4696a24d26b7c4ea9396d6 Mon Sep 17 00:00:00 2001 From: be-marc Date: Wed, 4 Dec 2024 10:51:46 +0100 Subject: [PATCH 29/54] tests --- tests/testthat/test_CallbackEvaluation.R | 30 +++--------------------- 1 file changed, 3 insertions(+), 27 deletions(-) diff --git a/tests/testthat/test_CallbackEvaluation.R b/tests/testthat/test_CallbackEvaluation.R index 2cec38903..e9e2f6beb 100644 --- a/tests/testthat/test_CallbackEvaluation.R +++ b/tests/testthat/test_CallbackEvaluation.R @@ -10,10 +10,7 @@ test_that("on_evaluation_begin works", { assert_task(context$task) assert_learner(context$learner) assert_resampling(context$resampling) - assert_null(context$param_values) - assert_null(context$sets) - assert_null(context$test_set) - assert_null(context$predict_sets) + assert_number(context$iteration) assert_null(context$pdatas) } ) @@ -32,13 +29,7 @@ test_that("on_evaluation_before_train works", { assert_task(context$task) assert_learner(context$learner) assert_resampling(context$resampling) - assert_null(context$param_values) - assert_list(context$sets, len = 2) - assert_names(names(context$sets), identical.to = c("train", "test")) - assert_integer(context$sets$train) - assert_integer(context$sets$test) - assert_null(context$test_set) - assert_null(context$predict_sets) + assert_number(context$iteration) assert_null(context$pdatas) } ) @@ -58,14 +49,6 @@ test_that("on_evaluation_before_predict works", { assert_task(context$task) assert_learner(context$learner) assert_resampling(context$resampling) - assert_null(context$param_values) - assert_list(context$sets, len = 2) - assert_names(names(context$sets), identical.to = c("train", "test")) - assert_integer(context$sets$train) - assert_integer(context$sets$test) - assert_class(context$learner$model, "rpart") - assert_null(context$test_set) - assert_true(context$predict_sets == "test") assert_null(context$pdatas) } ) @@ -84,14 +67,7 @@ test_that("on_evaluation_end works", { assert_task(context$task) assert_learner(context$learner) assert_resampling(context$resampling) - assert_null(context$param_values) - assert_list(context$sets, len = 2) - assert_names(names(context$sets), identical.to = c("train", "test")) - assert_integer(context$sets$train) - assert_integer(context$sets$test) - assert_class(context$learner$model, "rpart") - assert_null(context$test_set) - assert_true(context$predict_sets == "test") + assert_number(context$iteration) assert_class(context$pdatas$test, 
"PredictionData") } ) From 5c91d7579373e0ae01669d667a198eb7e6604916 Mon Sep 17 00:00:00 2001 From: be-marc Date: Wed, 4 Dec 2024 10:57:45 +0100 Subject: [PATCH 30/54] ... --- R/worker.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/worker.R b/R/worker.R index 0941786b5..948be5c7b 100644 --- a/R/worker.R +++ b/R/worker.R @@ -323,11 +323,11 @@ workhorse = function( validate = get0("validate", learner) - ctx$test_set = if (identical(validate, "test")) sets$test + test_set = if (identical(validate, "test")) sets$test call_back("on_evaluation_before_train", callbacks, ctx) - train_result = learner_train(learner, task, sets[["train"]], ctx$test_set, mode = mode) + train_result = learner_train(learner, task, sets[["train"]], test_set, mode = mode) ctx$learner = learner = train_result$learner # process the model so it can be used for prediction (e.g. marshal for callr prediction), but also From 9d14bb0a508595c0c8941322814acb0d22b1e5e0 Mon Sep 17 00:00:00 2001 From: be-marc Date: Tue, 10 Dec 2024 10:08:04 +0100 Subject: [PATCH 31/54] ... --- R/CallbackEvaluation.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/CallbackEvaluation.R b/R/CallbackEvaluation.R index 1debd01b5..732fa07f5 100644 --- a/R/CallbackEvaluation.R +++ b/R/CallbackEvaluation.R @@ -125,6 +125,7 @@ assert_evaluation_callback = function(callback, null_ok = FALSE) { #' @export #' @param callbacks (list of [CallbackEvaluation]). #' @rdname assert_evaluation_callback -assert_evaluation_callbacks = function(callbacks) { +assert_evaluation_callbacks = function(callbacks, null_ok = FALSE) { + if (null_ok && is.null(callbacks)) return(invisible(NULL)) invisible(lapply(callbacks, assert_evaluation_callback)) } From 15f9e76ebc6f8e0d091404a37dd34bd7167b2c90 Mon Sep 17 00:00:00 2001 From: be-marc Date: Tue, 10 Dec 2024 11:47:19 +0100 Subject: [PATCH 32/54] ... --- R/BenchmarkResult.R | 14 +++++- R/ResampleResult.R | 4 +- R/benchmark.R | 2 +- R/resample.R | 3 +- tests/testthat/test_CallbackEvaluation.R | 63 ++++++++++++++++++++++-- 5 files changed, 76 insertions(+), 10 deletions(-) diff --git a/R/BenchmarkResult.R b/R/BenchmarkResult.R index 9ee35c3ce..18cfd70a9 100644 --- a/R/BenchmarkResult.R +++ b/R/BenchmarkResult.R @@ -487,6 +487,14 @@ BenchmarkResult = R6Class("BenchmarkResult", setcolorder(tab, c("learner_hash", "learner_id", "learner"))[] }, + + #' @field data_extra (list())\cr + #' Additional data stored in the [ResampleResult]. 
+ data_extra = function() { + private$.data$data_extra(private$.view) + }, + + #' @field resamplings ([data.table::data.table()])\cr #' Table of included [Resampling]s with three columns: #' @@ -545,10 +553,12 @@ BenchmarkResult = R6Class("BenchmarkResult", ) #' @export -as.data.table.BenchmarkResult = function(x, ..., hashes = FALSE, predict_sets = "test", task_characteristics = FALSE) { # nolint +as.data.table.BenchmarkResult = function(x, ..., hashes = FALSE, predict_sets = "test", task_characteristics = FALSE, data_extra = FALSE) { # nolint assert_flag(task_characteristics) tab = get_private(x)$.data$as_data_table(view = NULL, predict_sets = predict_sets) - tab = tab[, c("uhash", "task", "learner", "resampling", "iteration", "prediction"), with = FALSE] + cns = c("uhash", "task", "learner", "resampling", "iteration", "prediction") + if (data_extra && "data_extra" %in% names(tab)) cns = c(cns, "data_extra") + tab = tab[, cns, with = FALSE] if (task_characteristics) { set(tab, j = "characteristics", value = map(tab$task, "characteristics")) diff --git a/R/ResampleResult.R b/R/ResampleResult.R index c31343ca7..79420c743 100644 --- a/R/ResampleResult.R +++ b/R/ResampleResult.R @@ -379,7 +379,9 @@ ResampleResult = R6Class("ResampleResult", as.data.table.ResampleResult = function(x, ..., predict_sets = "test", data_extra = FALSE) { # nolint private = get_private(x) tab = private$.data$as_data_table(view = private$.view, predict_sets = predict_sets) - tab[, c("task", "learner", "resampling", "iteration", "prediction", if (data_extra) "data_extra"), with = FALSE] + cns = c("task", "learner", "resampling", "iteration", "prediction") + if (data_extra && "data_extra" %in% names(tab)) cns = c(cns, "data_extra") + tab[, cns, with = FALSE] } # #' @export diff --git a/R/benchmark.R b/R/benchmark.R index e091c6772..48d65bccd 100644 --- a/R/benchmark.R +++ b/R/benchmark.R @@ -203,7 +203,7 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps set(grid, j = "mode", value = NULL) - data_extra = if (length(callbacks)) map(res, "data_extra") + data_extra = if (length(callbacks) && any(map_lgl(res, function(x) !is.null(x$data_extra)))) map(res, "data_extra") result_data = ResultData$new(grid, data_extra, store_backends = store_backends) diff --git a/R/resample.R b/R/resample.R index cfe2205d5..5e7614592 100644 --- a/R/resample.R +++ b/R/resample.R @@ -147,8 +147,7 @@ resample = function( learner_hash = map_chr(res, "learner_hash") ) - # save the extra data only if a callback could have generated some - data_extra = if (length(callbacks)) map(res, "data_extra") + data_extra = if (length(callbacks) && any(map_lgl(res, function(x) !is.null(x$data_extra)))) map(res, "data_extra") result_data = ResultData$new(data, data_extra, store_backends = store_backends) diff --git a/tests/testthat/test_CallbackEvaluation.R b/tests/testthat/test_CallbackEvaluation.R index e9e2f6beb..995e370b5 100644 --- a/tests/testthat/test_CallbackEvaluation.R +++ b/tests/testthat/test_CallbackEvaluation.R @@ -81,17 +81,24 @@ test_that("writing to learner$state works", { resampling = rsmp("cv", folds = 3) callback = callback_evaluation("test", - on_evaluation_end = function(callback, context) { context$learner$state$test = 1 } ) + # resample result rr = resample(task, learner, resampling, callbacks = callback) - walk(rr$learners, function(learner) { expect_equal(learner$state$test, 1) }) + + # benchmark result + design = benchmark_grid(task, learner, resampling) + bmr = benchmark(design, callbacks = callback) 
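# Editor's note (sketch, not part of this test): how values written to
# `context$data_extra` come back out of the result objects, assuming the
# data_extra accessors added in this patch.
extra_cb = callback_evaluation("timestamp",
  on_evaluation_end = function(callback, context) {
    context$data_extra$finished_at = Sys.time()
  }
)
rr2 = resample(tsk("pima"), lrn("classif.rpart"), rsmp("cv", folds = 3), callbacks = extra_cb)
rr2$data_extra                                    # one list element per iteration
as.data.table(rr2, data_extra = TRUE)$data_extra  # same data as a list column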
+ + walk(bmr$score()$learner, function(learner) { + expect_equal(learner$state$test, 1) + }) }) test_that("writing to data_extra works", { @@ -100,15 +107,63 @@ test_that("writing to data_extra works", { resampling = rsmp("cv", folds = 3) callback = callback_evaluation("test", - on_evaluation_end = function(callback, context) { context$data_extra$test = 1 } ) + # resample result rr = resample(task, learner, resampling, callbacks = callback) - walk(rr$data_extra, function(x) { expect_equal(x$test, 1) }) + + # resample result data.table + tab = as.data.table(rr, data_extra = TRUE) + expect_data_table(tab) + expect_names(names(tab), must.include = "data_extra") + + # benchmark result + design = benchmark_grid(task, learner, resampling) + bmr = benchmark(design, callbacks = callback) + expect_list(bmr$data_extra, len = 3) + + # benchmark data.table + tab = as.data.table(bmr, data_extra = TRUE) + expect_names(names(tab), must.include = "data_extra") + expect_list(tab$data_extra) + walk(tab$data_extra, function(x) { + expect_equal(x$test, 1) + }) +}) + +test_that("data_extra is null", { + task = tsk("pima") + learner = lrn("classif.rpart") + resampling = rsmp("cv", folds = 3) + + callback = callback_evaluation("test", + on_evaluation_end = function(callback, context) { + context$learner$state$test = 1 + } + ) + + # resample result + rr = resample(task, learner, resampling, callbacks = callback) + expect_null(rr$data_extra) + + # resample result data.table + tab = as.data.table(bmr, data_extra = TRUE) + expect_data_table(tab) + expect_names(names(tab), disjunct.from = "data_extra") + + # benchmark result + design = benchmark_grid(task, learner, resampling) + bmr = benchmark(design, callbacks = callback) + expect_null(bmr$data_extra) + + # benchmark data.table + tab = as.data.table(bmr, data_extra = TRUE) + expect_data_table(tab) + expect_names(names(tab), disjunct.from = "data_extra") }) From 50a26dfd36d0d17773e67416377e1c3dc1ba3fd7 Mon Sep 17 00:00:00 2001 From: be-marc Date: Tue, 10 Dec 2024 11:47:26 +0100 Subject: [PATCH 33/54] ... --- tests/testthat/test_CallbackEvaluation.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/testthat/test_CallbackEvaluation.R b/tests/testthat/test_CallbackEvaluation.R index 995e370b5..3a455f63d 100644 --- a/tests/testthat/test_CallbackEvaluation.R +++ b/tests/testthat/test_CallbackEvaluation.R @@ -91,14 +91,15 @@ test_that("writing to learner$state works", { walk(rr$learners, function(learner) { expect_equal(learner$state$test, 1) }) + expect_null(rr$data_extra) # benchmark result design = benchmark_grid(task, learner, resampling) bmr = benchmark(design, callbacks = callback) - walk(bmr$score()$learner, function(learner) { expect_equal(learner$state$test, 1) }) + expect_null(bmr$data_extra) }) test_that("writing to data_extra works", { From 65df92d8f58fdb24ef5090ff65f43e80c789bd40 Mon Sep 17 00:00:00 2001 From: be-marc Date: Tue, 10 Dec 2024 11:57:22 +0100 Subject: [PATCH 34/54] ... --- R/BenchmarkResult.R | 8 -------- tests/testthat/test_CallbackEvaluation.R | 13 +++---------- 2 files changed, 3 insertions(+), 18 deletions(-) diff --git a/R/BenchmarkResult.R b/R/BenchmarkResult.R index 18cfd70a9..9d63b4b28 100644 --- a/R/BenchmarkResult.R +++ b/R/BenchmarkResult.R @@ -487,14 +487,6 @@ BenchmarkResult = R6Class("BenchmarkResult", setcolorder(tab, c("learner_hash", "learner_id", "learner"))[] }, - - #' @field data_extra (list())\cr - #' Additional data stored in the [ResampleResult]. 
- data_extra = function() { - private$.data$data_extra(private$.view) - }, - - #' @field resamplings ([data.table::data.table()])\cr #' Table of included [Resampling]s with three columns: #' diff --git a/tests/testthat/test_CallbackEvaluation.R b/tests/testthat/test_CallbackEvaluation.R index 3a455f63d..1cd0bc85e 100644 --- a/tests/testthat/test_CallbackEvaluation.R +++ b/tests/testthat/test_CallbackEvaluation.R @@ -99,7 +99,6 @@ test_that("writing to learner$state works", { walk(bmr$score()$learner, function(learner) { expect_equal(learner$state$test, 1) }) - expect_null(bmr$data_extra) }) test_that("writing to data_extra works", { @@ -124,12 +123,9 @@ test_that("writing to data_extra works", { expect_data_table(tab) expect_names(names(tab), must.include = "data_extra") - # benchmark result + # benchmark data.table design = benchmark_grid(task, learner, resampling) bmr = benchmark(design, callbacks = callback) - expect_list(bmr$data_extra, len = 3) - - # benchmark data.table tab = as.data.table(bmr, data_extra = TRUE) expect_names(names(tab), must.include = "data_extra") expect_list(tab$data_extra) @@ -154,16 +150,13 @@ test_that("data_extra is null", { expect_null(rr$data_extra) # resample result data.table - tab = as.data.table(bmr, data_extra = TRUE) + tab = as.data.table(rr, data_extra = TRUE) expect_data_table(tab) expect_names(names(tab), disjunct.from = "data_extra") - # benchmark result + # benchmark data.table design = benchmark_grid(task, learner, resampling) bmr = benchmark(design, callbacks = callback) - expect_null(bmr$data_extra) - - # benchmark data.table tab = as.data.table(bmr, data_extra = TRUE) expect_data_table(tab) expect_names(names(tab), disjunct.from = "data_extra") From 7335ded6081808ec98663876abb49c8e6cd41c6b Mon Sep 17 00:00:00 2001 From: be-marc Date: Tue, 10 Dec 2024 11:58:24 +0100 Subject: [PATCH 35/54] ... --- R/BenchmarkResult.R | 2 +- R/ResampleResult.R | 2 +- man/BenchmarkResult.Rd | 2 +- man/ContextEvaluation.Rd | 21 +-------------------- man/ResampleResult.Rd | 2 +- man/assert_evaluation_callback.Rd | 2 +- 6 files changed, 6 insertions(+), 25 deletions(-) diff --git a/R/BenchmarkResult.R b/R/BenchmarkResult.R index 9d63b4b28..3ef7a7713 100644 --- a/R/BenchmarkResult.R +++ b/R/BenchmarkResult.R @@ -19,7 +19,7 @@ #' @template param_measures #' #' @section S3 Methods: -#' * `as.data.table(rr, ..., reassemble_learners = TRUE, convert_predictions = TRUE, predict_sets = "test", task_characteristics = FALSE)`\cr +#' * `as.data.table(rr, ..., reassemble_learners = TRUE, convert_predictions = TRUE, predict_sets = "test", task_characteristics = FALSE, data_extra = FALSE)`\cr #' [BenchmarkResult] -> [data.table::data.table()]\cr #' Returns a tabular view of the internal data. #' * `c(...)`\cr diff --git a/R/ResampleResult.R b/R/ResampleResult.R index 79420c743..770a0baa1 100644 --- a/R/ResampleResult.R +++ b/R/ResampleResult.R @@ -13,7 +13,7 @@ #' @template param_measures #' #' @section S3 Methods: -#' * `as.data.table(rr, reassemble_learners = TRUE, convert_predictions = TRUE, predict_sets = "test")`\cr +#' * `as.data.table(rr, reassemble_learners = TRUE, convert_predictions = TRUE, predict_sets = "test", data_extra = FALSE)`\cr #' [ResampleResult] -> [data.table::data.table()]\cr #' Returns a tabular view of the internal data. 
#' * `c(...)`\cr diff --git a/man/BenchmarkResult.Rd b/man/BenchmarkResult.Rd index 0e1f18d9a..32413532b 100644 --- a/man/BenchmarkResult.Rd +++ b/man/BenchmarkResult.Rd @@ -20,7 +20,7 @@ Do not modify any extracted object without cloning it first. \section{S3 Methods}{ \itemize{ -\item \code{as.data.table(rr, ..., reassemble_learners = TRUE, convert_predictions = TRUE, predict_sets = "test", task_characteristics = FALSE)}\cr +\item \code{as.data.table(rr, ..., reassemble_learners = TRUE, convert_predictions = TRUE, predict_sets = "test", task_characteristics = FALSE, data_extra = FALSE)}\cr \link{BenchmarkResult} -> \code{\link[data.table:data.table]{data.table::data.table()}}\cr Returns a tabular view of the internal data. \item \code{c(...)}\cr diff --git a/man/ContextEvaluation.Rd b/man/ContextEvaluation.Rd index b0e95953a..fe1ee1c94 100644 --- a/man/ContextEvaluation.Rd +++ b/man/ContextEvaluation.Rd @@ -26,25 +26,9 @@ The learner contains the models after stage \code{on_evaluation_before_train}.} The resampling strategy to be used. The resampling is unchanged during the evaluation.} -\item{\code{param_values}}{\code{list()}\cr -The parameter values to be used. -Is usually only set while tuning.} - \item{\code{iteration}}{(\code{integer()})\cr The current iteration.} -\item{\code{sets}}{(\code{list()})\cr -The train and test set. -The sets are available on stage `on_evaluation_before_train``.} - -\item{\code{test_set}}{(\code{integer()})\cr -Validation test set. -The set is only available when using internal validation.} - -\item{\code{predict_sets}}{(\code{list()})\cr -The prediction sets stored in \code{learner$predict_sets}. -The sets are available on stage \code{on_evaluation_before_predict}.} - \item{\code{pdatas}}{(List of \link{PredictionData})\cr The prediction data. The data is available on stage \code{on_evaluation_end}.} @@ -76,7 +60,7 @@ Use this field to save results.} \subsection{Method \code{new()}}{ Creates a new instance of this \link[R6:R6Class]{R6} class. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{ContextEvaluation$new(task, learner, resampling, param_values, iteration)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{ContextEvaluation$new(task, learner, resampling, iteration)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -91,9 +75,6 @@ The learner to be evaluated.} \item{\code{resampling}}{(\link{Resampling})\cr The resampling strategy to be used.} -\item{\code{param_values}}{(\code{list()})\cr -The parameter values to be used.} - \item{\code{iteration}}{(\code{integer()})\cr The current iteration.} } diff --git a/man/ResampleResult.Rd b/man/ResampleResult.Rd index 39b584d6f..d7061a6bb 100644 --- a/man/ResampleResult.Rd +++ b/man/ResampleResult.Rd @@ -14,7 +14,7 @@ Do not modify any object without cloning it first. \section{S3 Methods}{ \itemize{ -\item \code{as.data.table(rr, reassemble_learners = TRUE, convert_predictions = TRUE, predict_sets = "test")}\cr +\item \code{as.data.table(rr, reassemble_learners = TRUE, convert_predictions = TRUE, predict_sets = "test", data_extra = FALSE)}\cr \link{ResampleResult} -> \code{\link[data.table:data.table]{data.table::data.table()}}\cr Returns a tabular view of the internal data. \item \code{c(...)}\cr diff --git a/man/assert_evaluation_callback.Rd b/man/assert_evaluation_callback.Rd index 5f54d4fd5..31ffa4b43 100644 --- a/man/assert_evaluation_callback.Rd +++ b/man/assert_evaluation_callback.Rd @@ -7,7 +7,7 @@ \usage{ assert_evaluation_callback(callback, null_ok = FALSE) -assert_evaluation_callbacks(callbacks) +assert_evaluation_callbacks(callbacks, null_ok = FALSE) } \arguments{ \item{callback}{(\link{CallbackEvaluation}).} From 876edaf970580d70e7b31488893ab0ddae31146e Mon Sep 17 00:00:00 2001 From: be-marc Date: Tue, 10 Dec 2024 12:18:28 +0100 Subject: [PATCH 36/54] ... --- R/mlr_callbacks.R | 5 +++-- tests/testthat/test_mlr_callbacks.R | 16 +++++++++++++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/R/mlr_callbacks.R b/R/mlr_callbacks.R index c9ae5f504..cc90c1cf2 100644 --- a/R/mlr_callbacks.R +++ b/R/mlr_callbacks.R @@ -31,8 +31,9 @@ load_callback_score_measures = function() { # Score measures on the test set pred = as_prediction(context$pdatas$test) - res = pred$score(measures, context$task, context$learner) - context$learner$state$selected_features = res + context$data_extra = insert_named(context$data_extra, list( + score_measures = pred$score(measures, context$task, context$learner) + )) } ) } diff --git a/tests/testthat/test_mlr_callbacks.R b/tests/testthat/test_mlr_callbacks.R index 7beaa5de7..47681a451 100644 --- a/tests/testthat/test_mlr_callbacks.R +++ b/tests/testthat/test_mlr_callbacks.R @@ -7,9 +7,19 @@ test_that("score_measure works", { rr = resample(task, learner, resampling = resampling, callbacks = callback) - walk(rr$learners, function(learner) { - expect_number(learner$state$selected_features) + expect_list(rr$data_extra) + walk(rr$data_extra, function(data) { + expect_names(names(data), must.include = "score_measures") + expect_names(names(data[["score_measures"]]), must.include = "selected_features") }) - expect_names(names(as.data.table(rr, data_extra = TRUE)), must.include = "data_extra") + callback = clbk("mlr3.score_measures", measures = msrs(c("classif.ce", "selected_features"))) + + rr = resample(task, learner, resampling = resampling, callbacks = callback) + + expect_list(rr$data_extra) + walk(rr$data_extra, function(data) { + expect_names(names(data), must.include = "score_measures") + expect_names(names(data[["score_measures"]]), must.include = c("classif.ce", "selected_features")) + }) }) From e4ddf18b6a43e1236c67cdaa9bd906019400ab74 Mon Sep 17 00:00:00 2001 From: be-marc Date: Tue, 10 Dec 2024 12:19:30 +0100 Subject: [PATCH 37/54] ... 
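Editor's aside between these patches: the predefined "mlr3.score_measures" callback accepts several measures at once and, after the change above, writes the scores into data_extra, so the models themselves do not have to be stored. A usage sketch based on the tests in this series:

callback = clbk("mlr3.score_measures",
  measures = msrs(c("classif.ce", "selected_features")))
rr = resample(tsk("pima"), lrn("classif.rpart"), rsmp("cv", folds = 3),
  store_models = FALSE, callbacks = callback)
lapply(rr$data_extra, function(x) x$score_measures)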
--- R/mlr_callbacks.R | 8 ++++---- man/mlr3.score_measures.Rd | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/R/mlr_callbacks.R b/R/mlr_callbacks.R index cc90c1cf2..cad2bc0fa 100644 --- a/R/mlr_callbacks.R +++ b/R/mlr_callbacks.R @@ -4,8 +4,8 @@ #' @name mlr3.score_measures #' #' @description -#' This `CallbackEvaluation` scores measures directly on the worker. -#' This way measures that require a model can be scores without saving the model. +#' This [CallbackEvaluation] scores measures directly on the worker. +#' This way measures that require a model can be scored without saving the model. #' #' @examples #' clbk("mlr3.score_measures", measures = msr("classif.ce")) @@ -16,9 +16,9 @@ #' #' callback = clbk("mlr3.score_measures", measures = msr("selected_features")) #' -#' rr = resample(task, learner, resampling = resampling, callbacks = callback) +#' rr = resample(task, learner, resampling = resampling, store_models = FALSE, callbacks = callback) #' -#' rr$learners[[1]]$state$selected_features +#' rr$data_extra NULL load_callback_score_measures = function() { diff --git a/man/mlr3.score_measures.Rd b/man/mlr3.score_measures.Rd index 4cf9fa49f..a40e1686f 100644 --- a/man/mlr3.score_measures.Rd +++ b/man/mlr3.score_measures.Rd @@ -4,8 +4,8 @@ \alias{mlr3.score_measures} \title{Score Measures Callback} \description{ -This \code{CallbackEvaluation} scores measures directly on the worker. -This way measures that require a model can be scores without saving the model. +This \link{CallbackEvaluation} scores measures directly on the worker. +This way measures that require a model can be scored without saving the model. } \examples{ clbk("mlr3.score_measures", measures = msr("classif.ce")) @@ -16,7 +16,7 @@ resampling = rsmp("cv", folds = 3) callback = clbk("mlr3.score_measures", measures = msr("selected_features")) -rr = resample(task, learner, resampling = resampling, callbacks = callback) +rr = resample(task, learner, resampling = resampling, store_models = FALSE, callbacks = callback) -rr$learners[[1]]$state$selected_features +rr$data_extra } From 3599be05d0b053fc476fa8b68b4b53f5813b737a Mon Sep 17 00:00:00 2001 From: be-marc Date: Tue, 10 Dec 2024 14:04:57 +0100 Subject: [PATCH 38/54] ... --- R/CallbackEvaluation.R | 44 ++++++++++++++++++++++++++++++-------- man/CallbackEvaluation.Rd | 7 ++++-- man/callback_evaluation.Rd | 29 +++++++++++++++++++++++-- 3 files changed, 67 insertions(+), 13 deletions(-) diff --git a/R/CallbackEvaluation.R b/R/CallbackEvaluation.R index 732fa07f5..9f064dd6f 100644 --- a/R/CallbackEvaluation.R +++ b/R/CallbackEvaluation.R @@ -1,7 +1,10 @@ -#' @title Create Evaluation Callback +#' @title Evaluation Callback #' #' @description -#' Callbacks allow to customize the behavior of `resample()` and `benchmark()` in mlr3. +#' Specialized [mlr3misc::Callback] to customize the behavior of [resample()] and [benchmark()] in mlr3. +#' The [callback_evaluation()] function is used to create instances of this class. +#' Predefined callbacks are stored in the [dictionary][mlr3misc::Dictionary] [mlr_callbacks] and can be retrieved with [clbk()]. +#' For more information on callbacks, see the [callback_evaluation()] documentation. #' #' @export CallbackEvaluation= R6Class("CallbackEvaluation", @@ -30,13 +33,14 @@ CallbackEvaluation= R6Class("CallbackEvaluation", ) ) -#' @title Create Workhorse Callback +#' @title Create Evaluation Callback #' #' @description #' Function to create a [CallbackEvaluation]. 
#' Predefined callbacks are stored in the [dictionary][mlr3misc::Dictionary] [mlr_callbacks] and can be retrieved with [clbk()]. #' #' Evaluation callbacks are called at different stages of the resampling process. +#' Each stage is called once per resampling iteration. #' The stages are prefixed with `on_*`. #' #' ``` @@ -49,21 +53,22 @@ CallbackEvaluation= R6Class("CallbackEvaluation", #' ``` #' #' See also the section on parameters for more information on the stages. -#' A evaluation callback works with [ContextEvaluation]. +#' An evaluation callback works with [ContextEvaluation]. # #' @details #' When implementing a callback, each function must have two arguments named `callback` and `context`. #' A callback can write data to the state (`$state`), e.g. settings that affect the callback itself. #' Evaluation callbacks access [ContextEvaluation]. +#' Data can be stored in the [ResampleResult] and [BenchmarkResult] objects via `context$data_extra`. +#' Alternatively results can be stored in the learner state via `context$learner$state`. #' #' @param id (`character(1)`)\cr -#' Identifier for the new instance. +#' Identifier for the new instance. #' @param label (`character(1)`)\cr -#' Label for the new instance. +#' Label for the new instance. #' @param man (`character(1)`)\cr -#' String in the format `[pkg]::[topic]` pointing to a manual page for this object. -#' The referenced help package can be opened via method `$help()`. -#' +#' String in the format `[pkg]::[topic]` pointing to a manual page for this object. +#' The referenced help package can be opened via method `$help()`. #' @param on_evaluation_begin (`function()`)\cr #' Stage called at the beginning of an evaluation. #' Called in `workhorse()` (internal). @@ -78,6 +83,27 @@ CallbackEvaluation= R6Class("CallbackEvaluation", #' Called in `workhorse()` (internal). #' #' @export +#' @examples +#' callback = callback_evaluation("selected_features", +#' label = "Selected Features", +#' +#' on_evaluation_end = function(callback, context) { +#' pred = as_prediction(context$pdatas$test) +#' selected_features = pred$score( +#' measure = msr("selected_features"), +#' learner = context$learner, +#' task = context$task) +#' context$learner$state$selected_features = selected_features +#' } +#' ) +#' +#' task = tsk("pima") +#' learner = lrn("classif.rpart") +#' resampling = rsmp("cv", folds = 3) +#' +#' rr = resample(task, learner, resampling, callbacks = callback) +#' +#' rr$learners[[1]]$state$selected_features callback_evaluation = function( id, label = NA_character_, diff --git a/man/CallbackEvaluation.Rd b/man/CallbackEvaluation.Rd index 67dfd7124..bd075c53c 100644 --- a/man/CallbackEvaluation.Rd +++ b/man/CallbackEvaluation.Rd @@ -2,9 +2,12 @@ % Please edit documentation in R/CallbackEvaluation.R \name{CallbackEvaluation} \alias{CallbackEvaluation} -\title{Create Evaluation Callback} +\title{Evaluation Callback} \description{ -Callbacks allow to customize the behavior of \code{resample()} and \code{benchmark()} in mlr3. +Specialized \link[mlr3misc:Callback]{mlr3misc::Callback} to customize the behavior of \code{\link[=resample]{resample()}} and \code{\link[=benchmark]{benchmark()}} in mlr3. +The \code{\link[=callback_evaluation]{callback_evaluation()}} function is used to create instances of this class. +Predefined callbacks are stored in the \link[mlr3misc:Dictionary]{dictionary} \link{mlr_callbacks} and can be retrieved with \code{\link[=clbk]{clbk()}}. 
+For more information on callbacks, see the \code{\link[=callback_evaluation]{callback_evaluation()}} documentation. } \section{Super class}{ \code{\link[mlr3misc:Callback]{mlr3misc::Callback}} -> \code{CallbackEvaluation} diff --git a/man/callback_evaluation.Rd b/man/callback_evaluation.Rd index 4a74900bc..bf76a2944 100644 --- a/man/callback_evaluation.Rd +++ b/man/callback_evaluation.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/CallbackEvaluation.R \name{callback_evaluation} \alias{callback_evaluation} -\title{Create Workhorse Callback} +\title{Create Evaluation Callback} \usage{ callback_evaluation( id, @@ -46,6 +46,7 @@ Function to create a \link{CallbackEvaluation}. Predefined callbacks are stored in the \link[mlr3misc:Dictionary]{dictionary} \link{mlr_callbacks} and can be retrieved with \code{\link[=clbk]{clbk()}}. Evaluation callbacks are called at different stages of the resampling process. +Each stage is called once per resampling iteration. The stages are prefixed with \verb{on_*}. \if{html}{\out{
}}\preformatted{Start Resampling Iteration on Worker @@ -57,10 +58,34 @@ End Resampling Iteration on Worker }\if{html}{\out{
}} See also the section on parameters for more information on the stages. -A evaluation callback works with \link{ContextEvaluation}. +An evaluation callback works with \link{ContextEvaluation}. } \details{ When implementing a callback, each function must have two arguments named \code{callback} and \code{context}. A callback can write data to the state (\verb{$state}), e.g. settings that affect the callback itself. Evaluation callbacks access \link{ContextEvaluation}. +Data can be stored in the \link{ResampleResult} and \link{BenchmarkResult} objects via \code{context$data_extra}. +Alternatively results can be stored in the learner state via \code{context$learner$state}. +} +\examples{ +callback = callback_evaluation("selected_features", + label = "Selected Features", + + on_evaluation_end = function(callback, context) { + pred = as_prediction(context$pdatas$test) + selected_features = pred$score( + measure = msr("selected_features"), + learner = context$learner, + task = context$task) + context$learner$state$selected_features = selected_features + } +) + +task = tsk("pima") +learner = lrn("classif.rpart") +resampling = rsmp("cv", folds = 3) + +rr = resample(task, learner, resampling, callbacks = callback) + +rr$learners[[1]]$state$selected_features } From fe0ecf622d502f8a0c0891f43b58761c4ec829a6 Mon Sep 17 00:00:00 2001 From: be-marc Date: Tue, 10 Dec 2024 14:06:47 +0100 Subject: [PATCH 39/54] ... --- R/ResampleResult.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/ResampleResult.R b/R/ResampleResult.R index 770a0baa1..d6bd48f41 100644 --- a/R/ResampleResult.R +++ b/R/ResampleResult.R @@ -337,7 +337,8 @@ ResampleResult = R6Class("ResampleResult", #' @field data_extra (list())\cr #' Additional data stored in the [ResampleResult]. - data_extra = function() { + data_extra = function(rhs) { + assert_ro_binding(rhs) private$.data$data_extra(private$.view) }, From 6d36e2235caf3d36d7d2458de3be7bad61bcc51e Mon Sep 17 00:00:00 2001 From: be-marc Date: Wed, 11 Dec 2024 09:05:46 +0100 Subject: [PATCH 40/54] ... --- R/CallbackEvaluation.R | 8 +++++++- man/callback_evaluation.Rd | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/R/CallbackEvaluation.R b/R/CallbackEvaluation.R index 9f064dd6f..b3924bc39 100644 --- a/R/CallbackEvaluation.R +++ b/R/CallbackEvaluation.R @@ -42,18 +42,24 @@ CallbackEvaluation= R6Class("CallbackEvaluation", #' Evaluation callbacks are called at different stages of the resampling process. #' Each stage is called once per resampling iteration. #' The stages are prefixed with `on_*`. +#' The text in brackets indicates what happens between the stages and which accesses to the [ContextEvaluation] `ctx` are typical for the stage. #' #' ``` #' Start Resampling Iteration on Worker #' - on_evaluation_begin +#' (Split `ctx$task` into training and test set with `ctx$resampling` and `ctx$iteration`) #' - on_evaluation_before_train +#' (Train the learner `ctx$learner` on training data) #' - on_evaluation_before_predict +#' (Predict on predict sets and store prediction data `ctx$pdatas`) #' - on_evaluation_end +#' (Erase model `ctx$learner$model` if requested and return results) #' End Resampling Iteration on Worker #' ``` #' +#' The callback can store data in `ctx$learner$state` or `ctx$data_extra`. +#' The data in `ctx$data_extra` is stored in the [ResampleResult] or [BenchmarkResult]. #' See also the section on parameters for more information on the stages. -#' An evaluation callback works with [ContextEvaluation]. 
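# Editor's sketch, assuming the stage names and context fields documented above:
# a callback that uses several stages to time the train and predict phases of
# each iteration and stores the result in data_extra.
timing_cb = callback_evaluation("phase_timer",
  on_evaluation_before_train = function(callback, context) {
    callback$state$train_start = Sys.time()
  },
  on_evaluation_before_predict = function(callback, context) {
    callback$state$predict_start = Sys.time()
  },
  on_evaluation_end = function(callback, context) {
    context$data_extra$train_secs = as.numeric(
      difftime(callback$state$predict_start, callback$state$train_start, units = "secs"))
  }
)
rr = resample(tsk("pima"), lrn("classif.rpart"), rsmp("cv", folds = 3), callbacks = timing_cb)
rr$data_extra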
# #' @details #' When implementing a callback, each function must have two arguments named `callback` and `context`. diff --git a/man/callback_evaluation.Rd b/man/callback_evaluation.Rd index bf76a2944..8163f1055 100644 --- a/man/callback_evaluation.Rd +++ b/man/callback_evaluation.Rd @@ -48,17 +48,23 @@ Predefined callbacks are stored in the \link[mlr3misc:Dictionary]{dictionary} \l Evaluation callbacks are called at different stages of the resampling process. Each stage is called once per resampling iteration. The stages are prefixed with \verb{on_*}. +The text in brackets indicates what happens between the stages and which accesses to the \link{ContextEvaluation} \code{ctx} are typical for the stage. \if{html}{\out{
}}\preformatted{Start Resampling Iteration on Worker - on_evaluation_begin + (Split `ctx$task` into training and test set with `ctx$resampling` and `ctx$iteration`) - on_evaluation_before_train + (Train the learner `ctx$learner` on training data) - on_evaluation_before_predict + (Predict on predict sets and store prediction data `ctx$pdatas`) - on_evaluation_end + (Erase model `ctx$learner$model` if requested and return results) End Resampling Iteration on Worker }\if{html}{\out{
}} +The callback can store data in \code{ctx$learner$state} or \code{ctx$data_extra}. +The data in \code{ctx$data_extra} is stored in the \link{ResampleResult} or \link{BenchmarkResult}. See also the section on parameters for more information on the stages. -An evaluation callback works with \link{ContextEvaluation}. } \details{ When implementing a callback, each function must have two arguments named \code{callback} and \code{context}. From 3597ca4681dbbaa0538b1cf18df29b7d8eb1c63f Mon Sep 17 00:00:00 2001 From: be-marc Date: Wed, 11 Dec 2024 10:20:58 +0100 Subject: [PATCH 41/54] ... --- R/CallbackEvaluation.R | 11 +++++++---- R/ContextEvaluation.R | 23 ++++++++++++++++++----- R/ResultData.R | 2 +- man/ContextEvaluation.Rd | 10 ++++++++-- man/callback_evaluation.Rd | 13 +++++++++---- tests/testthat/test_CallbackEvaluation.R | 22 ++++++++++++++++++++++ 6 files changed, 65 insertions(+), 16 deletions(-) diff --git a/R/CallbackEvaluation.R b/R/CallbackEvaluation.R index b3924bc39..f3416d1ed 100644 --- a/R/CallbackEvaluation.R +++ b/R/CallbackEvaluation.R @@ -42,7 +42,7 @@ CallbackEvaluation= R6Class("CallbackEvaluation", #' Evaluation callbacks are called at different stages of the resampling process. #' Each stage is called once per resampling iteration. #' The stages are prefixed with `on_*`. -#' The text in brackets indicates what happens between the stages and which accesses to the [ContextEvaluation] `ctx` are typical for the stage. +#' The text in brackets indicates what happens between the stages and which accesses to the [ContextEvaluation] (`ctx`) are typical for the stage. #' #' ``` #' Start Resampling Iteration on Worker @@ -64,9 +64,12 @@ CallbackEvaluation= R6Class("CallbackEvaluation", #' @details #' When implementing a callback, each function must have two arguments named `callback` and `context`. #' A callback can write data to the state (`$state`), e.g. settings that affect the callback itself. -#' Evaluation callbacks access [ContextEvaluation]. -#' Data can be stored in the [ResampleResult] and [BenchmarkResult] objects via `context$data_extra`. -#' Alternatively results can be stored in the learner state via `context$learner$state`. +#' +#' @section Parallelization: +#' Be careful when modifying `ctx$learner`, `ctx$task`, or `ctx$resampling` because callbacks can behave differently when parallelizing the resampling process. +#' When running the resampling process sequentially, the modifications are carried over to the next iteration. +#' When parallelizing the resampling process, modifying the [ContextEvaluation] will not be synchronized between workers. +#' This also applies to the `$state` of the callback. #' #' @param id (`character(1)`)\cr #' Identifier for the new instance. diff --git a/R/ContextEvaluation.R b/R/ContextEvaluation.R index 242dd29b3..ab9293aef 100644 --- a/R/ContextEvaluation.R +++ b/R/ContextEvaluation.R @@ -34,11 +34,6 @@ ContextEvaluation = R6Class("ContextEvaluation", #' The data is available on stage `on_evaluation_end`. pdatas = NULL, - #' @field data_extra (list())\cr - #' Data saved in the [ResampleResult] or [BenchmarkResult]. - #' Use this field to save results. - data_extra = NULL, - #' @description #' Creates a new instance of this [R6][R6::R6Class] class. #' @@ -59,5 +54,23 @@ ContextEvaluation = R6Class("ContextEvaluation", super$initialize(id = "evaluate", label = "Evaluation") } + ), + + active = list( + + #' @field data_extra (list())\cr + #' Data saved in the [ResampleResult] or [BenchmarkResult]. + #' Use this field to save results. 
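# Editor's aside, not part of this diff: the active binding below validates
# assignments with assert_list(), so callbacks must store a list; a sketch:
cb_extra = callback_evaluation("store_extra",
  on_evaluation_end = function(callback, context) {
    context$data_extra = list(n_test_rows = length(context$resampling$test_set(context$iteration)))
    # context$data_extra = 42  # would be rejected by assert_list()
  }
)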
+ #' Must be a `list()`. + data_extra = function(rhs) { + if (missing(rhs)) { + return(private$.data_extra) + } + private$.data_extra = assert_list(rhs) + } + ), + + private = list( + .data_extra = NULL ) ) diff --git a/R/ResultData.R b/R/ResultData.R index 6cc25def6..21c3681e5 100644 --- a/R/ResultData.R +++ b/R/ResultData.R @@ -73,7 +73,7 @@ ResultData = R6Class("ResultData", # add extra data to fact table if (!is.null(data_extra)) { assert_list(data_extra, len = nrow(data)) - set(data, j = "data_extra", value = data_extra) + set(data, j = "data_extra", value = list(data_extra)) } if (!store_backends) { diff --git a/man/ContextEvaluation.Rd b/man/ContextEvaluation.Rd index fe1ee1c94..1091599e2 100644 --- a/man/ContextEvaluation.Rd +++ b/man/ContextEvaluation.Rd @@ -32,10 +32,16 @@ The current iteration.} \item{\code{pdatas}}{(List of \link{PredictionData})\cr The prediction data. The data is available on stage \code{on_evaluation_end}.} - +} +\if{html}{\out{
}} +} +\section{Active bindings}{ +\if{html}{\out{
}} +\describe{ \item{\code{data_extra}}{(list())\cr Data saved in the \link{ResampleResult} or \link{BenchmarkResult}. -Use this field to save results.} +Use this field to save results. +Must be a \code{list()}.} } \if{html}{\out{
}} } diff --git a/man/callback_evaluation.Rd b/man/callback_evaluation.Rd index 8163f1055..c7959ce00 100644 --- a/man/callback_evaluation.Rd +++ b/man/callback_evaluation.Rd @@ -48,7 +48,7 @@ Predefined callbacks are stored in the \link[mlr3misc:Dictionary]{dictionary} \l Evaluation callbacks are called at different stages of the resampling process. Each stage is called once per resampling iteration. The stages are prefixed with \verb{on_*}. -The text in brackets indicates what happens between the stages and which accesses to the \link{ContextEvaluation} \code{ctx} are typical for the stage. +The text in brackets indicates what happens between the stages and which accesses to the \link{ContextEvaluation} (\code{ctx}) are typical for the stage. \if{html}{\out{
}}\preformatted{Start Resampling Iteration on Worker - on_evaluation_begin @@ -69,10 +69,15 @@ See also the section on parameters for more information on the stages. \details{ When implementing a callback, each function must have two arguments named \code{callback} and \code{context}. A callback can write data to the state (\verb{$state}), e.g. settings that affect the callback itself. -Evaluation callbacks access \link{ContextEvaluation}. -Data can be stored in the \link{ResampleResult} and \link{BenchmarkResult} objects via \code{context$data_extra}. -Alternatively results can be stored in the learner state via \code{context$learner$state}. } +\section{Parallelization}{ + +Be careful when modifying \code{ctx$learner}, \code{ctx$task}, or \code{ctx$resampling} because callbacks can behave differently when parallelizing the resampling process. +When running the resampling process sequentially, the modifications are carried over to the next iteration. +When parallelizing the resampling process, modifying the \link{ContextEvaluation} will not be synchronized between workers. +This also applies to the \verb{$state} of the callback. +} + \examples{ callback = callback_evaluation("selected_features", label = "Selected Features", diff --git a/tests/testthat/test_CallbackEvaluation.R b/tests/testthat/test_CallbackEvaluation.R index 1cd0bc85e..13ad6eb60 100644 --- a/tests/testthat/test_CallbackEvaluation.R +++ b/tests/testthat/test_CallbackEvaluation.R @@ -134,6 +134,27 @@ test_that("writing to data_extra works", { }) }) +test_that("data_extra is a list column", { + task = tsk("pima") + learner = lrn("classif.rpart") + resampling = rsmp("holdout") + + callback = callback_evaluation("test", + on_evaluation_end = function(callback, context) { + context$data_extra$test = 1 + } + ) + + rr = resample(task, learner, resampling, callbacks = callback) + expect_list(as.data.table(rr, data_extra = TRUE)$data_extra, len = 1) + expect_list(as.data.table(rr, data_extra = TRUE)$data_extra[[1]], len = 1) + + resampling = rsmp("cv", folds = 3) + rr = resample(task, learner, resampling, callbacks = callback) + expect_list(as.data.table(rr, data_extra = TRUE)$data_extra, len = 3) + expect_list(as.data.table(rr, data_extra = TRUE)$data_extra[[1]], len = 1) +}) + test_that("data_extra is null", { task = tsk("pima") learner = lrn("classif.rpart") @@ -161,3 +182,4 @@ test_that("data_extra is null", { expect_data_table(tab) expect_names(names(tab), disjunct.from = "data_extra") }) + From d751dac2794d91229426acd0f9408ca9f04c63db Mon Sep 17 00:00:00 2001 From: be-marc Date: Wed, 11 Dec 2024 10:22:32 +0100 Subject: [PATCH 42/54] ... --- R/CallbackEvaluation.R | 24 ++++++++++++------------ R/ContextEvaluation.R | 8 ++++---- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/R/CallbackEvaluation.R b/R/CallbackEvaluation.R index f3416d1ed..10ecfdb92 100644 --- a/R/CallbackEvaluation.R +++ b/R/CallbackEvaluation.R @@ -72,24 +72,24 @@ CallbackEvaluation= R6Class("CallbackEvaluation", #' This also applies to the `$state` of the callback. #' #' @param id (`character(1)`)\cr -#' Identifier for the new instance. +#' Identifier for the new instance. #' @param label (`character(1)`)\cr -#' Label for the new instance. +#' Label for the new instance. #' @param man (`character(1)`)\cr -#' String in the format `[pkg]::[topic]` pointing to a manual page for this object. -#' The referenced help package can be opened via method `$help()`. +#' String in the format `[pkg]::[topic]` pointing to a manual page for this object. 
+#' The referenced help package can be opened via method `$help()`. #' @param on_evaluation_begin (`function()`)\cr -#' Stage called at the beginning of an evaluation. -#' Called in `workhorse()` (internal). +#' Stage called at the beginning of an evaluation. +#' Called in `workhorse()` (internal). #' @param on_evaluation_before_train (`function()`)\cr -#' Stage called before training the learner. -#' Called in `workhorse()` (internal). +#' Stage called before training the learner. +#' Called in `workhorse()` (internal). #' @param on_evaluation_before_predict (`function()`)\cr -#' Stage called before predicting. -#' Called in `workhorse()` (internal). +#' Stage called before predicting. +#' Called in `workhorse()` (internal). #' @param on_evaluation_end (`function()`)\cr -#' Stage called at the end of an evaluation. -#' Called in `workhorse()` (internal). +#' Stage called at the end of an evaluation. +#' Called in `workhorse()` (internal). #' #' @export #' @examples diff --git a/R/ContextEvaluation.R b/R/ContextEvaluation.R index ab9293aef..a0baf9862 100644 --- a/R/ContextEvaluation.R +++ b/R/ContextEvaluation.R @@ -38,13 +38,13 @@ ContextEvaluation = R6Class("ContextEvaluation", #' Creates a new instance of this [R6][R6::R6Class] class. #' #' @param task ([Task])\cr - #' The task to be evaluated. + #' The task to be evaluated. #' @param learner ([Learner])\cr - #' The learner to be evaluated. + #' The learner to be evaluated. #' @param resampling ([Resampling])\cr - #' The resampling strategy to be used. + #' The resampling strategy to be used. #' @param iteration (`integer()`)\cr - #' The current iteration. + #' The current iteration. initialize = function(task, learner, resampling, iteration) { # no assertions to avoid overhead self$task = task From ab2cf9bb5fc5b0e73693ab27a3e1e46f3b74c5c4 Mon Sep 17 00:00:00 2001 From: be-marc Date: Wed, 11 Dec 2024 10:23:57 +0100 Subject: [PATCH 43/54] ... --- R/CallbackEvaluation.R | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/R/CallbackEvaluation.R b/R/CallbackEvaluation.R index 10ecfdb92..cd17e15cc 100644 --- a/R/CallbackEvaluation.R +++ b/R/CallbackEvaluation.R @@ -152,8 +152,7 @@ callback_evaluation = function( #' @return [CallbackEvaluation | List of [CallbackEvaluation]s. #' @export assert_evaluation_callback = function(callback, null_ok = FALSE) { - if (null_ok && is.null(callback)) return(invisible(NULL)) - assert_class(callback, "CallbackEvaluation") + assert_class(callback, "CallbackEvaluation", null.ok = null_ok) invisible(callback) } From 7fa85e04a49ada90b6694ef95b77eafb02839312 Mon Sep 17 00:00:00 2001 From: be-marc Date: Wed, 11 Dec 2024 10:33:01 +0100 Subject: [PATCH 44/54] ... --- R/CallbackEvaluation.R | 1 + 1 file changed, 1 insertion(+) diff --git a/R/CallbackEvaluation.R b/R/CallbackEvaluation.R index cd17e15cc..e4b6d5e85 100644 --- a/R/CallbackEvaluation.R +++ b/R/CallbackEvaluation.R @@ -160,6 +160,7 @@ assert_evaluation_callback = function(callback, null_ok = FALSE) { #' @param callbacks (list of [CallbackEvaluation]). #' @rdname assert_evaluation_callback assert_evaluation_callbacks = function(callbacks, null_ok = FALSE) { + assert_list(callbacks, null.ok = null_ok) if (null_ok && is.null(callbacks)) return(invisible(NULL)) invisible(lapply(callbacks, assert_evaluation_callback)) } From 9be497fa609e247461de684ca6117bcbd82b5ed4 Mon Sep 17 00:00:00 2001 From: be-marc Date: Thu, 19 Dec 2024 15:26:54 +0100 Subject: [PATCH 45/54] ... 
--- R/worker.R | 46 +++++++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/R/worker.R b/R/worker.R index 948be5c7b..bdd6dd5c1 100644 --- a/R/worker.R +++ b/R/worker.R @@ -270,9 +270,9 @@ workhorse = function( call_back("on_evaluation_begin", callbacks, ctx) if (!is.null(pb)) { - pb(sprintf("%s|%s|i:%i", task$id, learner$id, iteration)) + pb(sprintf("%s|%s|i:%i", ctx$task$id, ctx$learner$id, ctx$iteration)) } - if ("internal_valid" %in% learner$predict_sets && is.null(task$internal_valid_task) && is.null(get0("validate", learner))) { + if ("internal_valid" %in% ctx$learner$predict_sets && is.null(ctx$task$internal_valid_task) && is.null(get0("validate", ctx$learner))) { stopf("Cannot set the predict_type field of learner '%s' to 'internal_valid' if there is no internal validation task configured", learner$id) } @@ -306,75 +306,75 @@ workhorse = function( } lg$info("%s learner '%s' on task '%s' (iter %i/%i)", - if (mode == "train") "Applying" else "Hotstarting", learner$id, task$id, iteration, resampling$iters) + if (mode == "train") "Applying" else "Hotstarting", ctx$learner$id, ctx$task$id, ctx$iteration, ctx$resampling$iters) sets = list( - train = resampling$train_set(iteration), - test = resampling$test_set(iteration) + train = ctx$resampling$train_set(ctx$iteration), + test = ctx$resampling$test_set(ctx$iteration) ) # train model - ctx$learner = learner = learner$clone() + ctx$learner = ctx$learner$clone() if (length(param_values)) { - learner$param_set$values = list() - learner$param_set$set_values(.values = param_values) + ctx$learner$param_set$values = list() + ctx$learner$param_set$set_values(.values = param_values) } - learner_hash = learner$hash + learner_hash = ctx$learner$hash - validate = get0("validate", learner) + validate = get0("validate", ctx$learner) test_set = if (identical(validate, "test")) sets$test call_back("on_evaluation_before_train", callbacks, ctx) - train_result = learner_train(learner, task, sets[["train"]], test_set, mode = mode) - ctx$learner = learner = train_result$learner + train_result = learner_train(ctx$learner, ctx$task, sets[["train"]], test_set, mode = mode) + ctx$learner = train_result$learner # process the model so it can be used for prediction (e.g. 
marshal for callr prediction), but also # keep a copy of the model in current form in case this is the format that we want to send back to the main process # and not the format that we need for prediction model_copy_or_null = process_model_before_predict( - learner = learner, store_models = store_models, is_sequential = is_sequential, unmarshal = unmarshal + learner = ctx$learner, store_models = store_models, is_sequential = is_sequential, unmarshal = unmarshal ) # predict for each set - predict_sets = learner$predict_sets + predict_sets = ctx$learner$predict_sets # creates the tasks and row_ids for all selected predict sets - pred_data = prediction_tasks_and_sets(task, train_result, validate, sets, predict_sets) + pred_data = prediction_tasks_and_sets(ctx$task, train_result, validate, sets, predict_sets) call_back("on_evaluation_before_predict", callbacks, ctx) pdatas = Map(function(set, row_ids, task) { lg$debug("Creating Prediction for predict set '%s'", set) - learner_predict(learner, task, row_ids) + learner_predict(ctx$learner, task, row_ids) }, set = predict_sets, row_ids = pred_data$sets, task = pred_data$tasks) if (!length(predict_sets)) { - learner$state$predict_time = 0L + ctx$learner$state$predict_time = 0L } ctx$pdatas = discard(pdatas, is.null) # set the model slot after prediction so it can be sent back to the main process process_model_after_predict( - learner = learner, store_models = store_models, is_sequential = is_sequential, model_copy = model_copy_or_null, + learner = ctx$learner, store_models = store_models, is_sequential = is_sequential, model_copy = model_copy_or_null, unmarshal = unmarshal ) call_back("on_evaluation_end", callbacks, ctx) if (!store_models) { - lg$debug("Erasing stored model for learner '%s'", learner$id) - learner$state$model = NULL + lg$debug("Erasing stored model for learner '%s'", ctx$learner$id) + ctx$learner$state$model = NULL } - learner_state = set_class(learner$state, c("learner_state", "list")) + learner_state = set_class(ctx$learner$state, c("learner_state", "list")) list( learner_state = learner_state, - prediction = pdatas, - param_values = learner$param_set$values, + prediction = ctx$pdatas, + param_values = ctx$learner$param_set$values, learner_hash = learner_hash, data_extra = ctx$data_extra) } From f392d045ddf795e6719ebf5e62ebe897a42e1087 Mon Sep 17 00:00:00 2001 From: be-marc Date: Thu, 19 Dec 2024 16:04:00 +0100 Subject: [PATCH 46/54] ... 
--- R/CallbackEvaluation.R | 7 +------ man/callback_evaluation.Rd | 7 +------ 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/R/CallbackEvaluation.R b/R/CallbackEvaluation.R index e4b6d5e85..1cc925886 100644 --- a/R/CallbackEvaluation.R +++ b/R/CallbackEvaluation.R @@ -97,12 +97,7 @@ CallbackEvaluation= R6Class("CallbackEvaluation", #' label = "Selected Features", #' #' on_evaluation_end = function(callback, context) { -#' pred = as_prediction(context$pdatas$test) -#' selected_features = pred$score( -#' measure = msr("selected_features"), -#' learner = context$learner, -#' task = context$task) -#' context$learner$state$selected_features = selected_features +#' context$learner$state$selected_features = context$learner$selected_features() #' } #' ) #' diff --git a/man/callback_evaluation.Rd b/man/callback_evaluation.Rd index c7959ce00..1f18d6ac6 100644 --- a/man/callback_evaluation.Rd +++ b/man/callback_evaluation.Rd @@ -83,12 +83,7 @@ callback = callback_evaluation("selected_features", label = "Selected Features", on_evaluation_end = function(callback, context) { - pred = as_prediction(context$pdatas$test) - selected_features = pred$score( - measure = msr("selected_features"), - learner = context$learner, - task = context$task) - context$learner$state$selected_features = selected_features + context$learner$state$selected_features = context$learner$selected_features() } ) From b1117cc612d5a6297a521046bf0d2ac46cf8bbb8 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 20 Dec 2024 11:25:19 +0100 Subject: [PATCH 47/54] ... --- DESCRIPTION | 2 +- NAMESPACE | 10 +- ...allbackEvaluation.R => CallbackResample.R} | 105 +++++++++--------- R/ContextEvaluation.R | 89 +++++++++------ R/mlr_callbacks.R | 84 ++++++++++---- R/worker.R | 11 +- R/zzz.R | 5 +- man-roxygen/param_callbacks.R | 2 +- ...lbackEvaluation.Rd => CallbackResample.Rd} | 32 +++--- ...ontextEvaluation.Rd => ContextResample.Rd} | 50 ++++----- man/assert_evaluation_callback.Rd | 25 ----- man/assert_resample_callback.Rd | 25 +++++ man/benchmark.Rd | 2 +- ...ack_evaluation.Rd => callback_resample.Rd} | 51 ++++----- man/mlr3.holdout_set.Rd | 28 +++++ man/mlr3.model_extractor.Rd | 32 ++++++ man/mlr3.score_measures.Rd | 22 ---- man/resample.Rd | 2 +- pkgdown/_pkgdown.yml | 8 +- ...ckEvaluation.R => test_CallbackResample.R} | 40 +++---- tests/testthat/test_ContextEvaluation.R | 17 +++ 21 files changed, 378 insertions(+), 264 deletions(-) rename R/{CallbackEvaluation.R => CallbackResample.R} (56%) rename man/{CallbackEvaluation.Rd => CallbackResample.Rd} (71%) rename man/{ContextEvaluation.Rd => ContextResample.Rd} (62%) delete mode 100644 man/assert_evaluation_callback.Rd create mode 100644 man/assert_resample_callback.Rd rename man/{callback_evaluation.Rd => callback_resample.Rd} (64%) create mode 100644 man/mlr3.holdout_set.Rd create mode 100644 man/mlr3.model_extractor.Rd delete mode 100644 man/mlr3.score_measures.Rd rename tests/testthat/{test_CallbackEvaluation.R => test_CallbackResample.R} (81%) create mode 100644 tests/testthat/test_ContextEvaluation.R diff --git a/DESCRIPTION b/DESCRIPTION index 3a6ca41d7..b44d54702 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -79,7 +79,7 @@ RoxygenNote: 7.3.2 Collate: 'mlr_reflections.R' 'BenchmarkResult.R' - 'CallbackEvaluation.R' + 'CallbackResample.R' 'ContextEvaluation.R' 'warn_deprecated.R' 'DataBackend.R' diff --git a/NAMESPACE b/NAMESPACE index c536e1726..00f462a51 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -120,8 +120,8 @@ S3method(unmarshal_model,classif.debug_model_marshaled) 
S3method(unmarshal_model,default) S3method(unmarshal_model,learner_state_marshaled) export(BenchmarkResult) -export(CallbackEvaluation) -export(ContextEvaluation) +export(CallbackResample) +export(ContextResample) export(DataBackend) export(DataBackendDataTable) export(DataBackendMatrix) @@ -202,8 +202,6 @@ export(as_tasks) export(as_tasks_unsupervised) export(assert_backend) export(assert_benchmark_result) -export(assert_evaluation_callback) -export(assert_evaluation_callbacks) export(assert_learnable) export(assert_learner) export(assert_learners) @@ -211,6 +209,8 @@ export(assert_measure) export(assert_measures) export(assert_predictable) export(assert_prediction) +export(assert_resample_callback) +export(assert_resample_callbacks) export(assert_resample_result) export(assert_resampling) export(assert_resamplings) @@ -222,7 +222,7 @@ export(assert_validate) export(auto_convert) export(benchmark) export(benchmark_grid) -export(callback_evaluation) +export(callback_resample) export(check_prediction_data) export(clbk) export(clbks) diff --git a/R/CallbackEvaluation.R b/R/CallbackResample.R similarity index 56% rename from R/CallbackEvaluation.R rename to R/CallbackResample.R index 1cc925886..a7376cbac 100644 --- a/R/CallbackEvaluation.R +++ b/R/CallbackResample.R @@ -2,57 +2,57 @@ #' #' @description #' Specialized [mlr3misc::Callback] to customize the behavior of [resample()] and [benchmark()] in mlr3. -#' The [callback_evaluation()] function is used to create instances of this class. +#' The [callback_resample()] function is used to create instances of this class. #' Predefined callbacks are stored in the [dictionary][mlr3misc::Dictionary] [mlr_callbacks] and can be retrieved with [clbk()]. -#' For more information on callbacks, see the [callback_evaluation()] documentation. +#' For more information on callbacks, see the [callback_resample()] documentation. #' #' @export -CallbackEvaluation= R6Class("CallbackEvaluation", +CallbackResample = R6Class("CallbackResample", inherit = Callback, public = list( - #' @field on_evaluation_begin (`function()`)\cr - #' Stage called at the beginning of an evaluation. + #' @field on_resample_begin (`function()`)\cr + #' Stage called at the beginning of the resampling iteration. #' Called in `workhorse()` (internal). - on_evaluation_begin = NULL, + on_resample_begin = NULL, - #' @field on_evaluation_before_train (`function()`)\cr + #' @field on_resample_before_train (`function()`)\cr #' Stage called before training the learner. #' Called in `workhorse()` (internal). - on_evaluation_before_train = NULL, + on_resample_before_train = NULL, - #' @field on_evaluation_before_predict (`function()`)\cr + #' @field on_resample_before_predict (`function()`)\cr #' Stage called before predicting. #' Called in `workhorse()` (internal). - on_evaluation_before_predict = NULL, + on_resample_before_predict = NULL, - #' @field on_evaluation_end (`function()`)\cr - #' Stage called at the end of an evaluation. + #' @field on_resample_end (`function()`)\cr + #' Stage called at the end of the resample iteration. #' Called in `workhorse()` (internal). - on_evaluation_end = NULL + on_resample_end = NULL ) ) #' @title Create Evaluation Callback #' #' @description -#' Function to create a [CallbackEvaluation]. +#' Function to create a [CallbackResample]. #' Predefined callbacks are stored in the [dictionary][mlr3misc::Dictionary] [mlr_callbacks] and can be retrieved with [clbk()]. #' #' Evaluation callbacks are called at different stages of the resampling process. 
#' Each stage is called once per resampling iteration. -#' The stages are prefixed with `on_*`. -#' The text in brackets indicates what happens between the stages and which accesses to the [ContextEvaluation] (`ctx`) are typical for the stage. +#' The stages are prefixed with `on_resample_*`. +#' The text in brackets indicates what happens between the stages and which accesses to the [ContextResample] (`ctx`) are typical for the stage. #' #' ``` #' Start Resampling Iteration on Worker -#' - on_evaluation_begin +#' - on_resample_begin #' (Split `ctx$task` into training and test set with `ctx$resampling` and `ctx$iteration`) -#' - on_evaluation_before_train +#' - on_resample_before_train #' (Train the learner `ctx$learner` on training data) -#' - on_evaluation_before_predict +#' - on_resample_before_predict #' (Predict on predict sets and store prediction data `ctx$pdatas`) -#' - on_evaluation_end +#' - on_resample_end #' (Erase model `ctx$learner$model` if requested and return results) #' End Resampling Iteration on Worker #' ``` @@ -64,12 +64,7 @@ CallbackEvaluation= R6Class("CallbackEvaluation", #' @details #' When implementing a callback, each function must have two arguments named `callback` and `context`. #' A callback can write data to the state (`$state`), e.g. settings that affect the callback itself. -#' -#' @section Parallelization: -#' Be careful when modifying `ctx$learner`, `ctx$task`, or `ctx$resampling` because callbacks can behave differently when parallelizing the resampling process. -#' When running the resampling process sequentially, the modifications are carried over to the next iteration. -#' When parallelizing the resampling process, modifying the [ContextEvaluation] will not be synchronized between workers. -#' This also applies to the `$state` of the callback. +#' We highly discourage changing the task, learner and resampling objects via the callback. #' #' @param id (`character(1)`)\cr #' Identifier for the new instance. @@ -78,25 +73,25 @@ CallbackEvaluation= R6Class("CallbackEvaluation", #' @param man (`character(1)`)\cr #' String in the format `[pkg]::[topic]` pointing to a manual page for this object. #' The referenced help package can be opened via method `$help()`. -#' @param on_evaluation_begin (`function()`)\cr +#' @param on_resample_begin (`function()`)\cr #' Stage called at the beginning of an evaluation. #' Called in `workhorse()` (internal). -#' @param on_evaluation_before_train (`function()`)\cr +#' @param on_resample_before_train (`function()`)\cr #' Stage called before training the learner. #' Called in `workhorse()` (internal). -#' @param on_evaluation_before_predict (`function()`)\cr +#' @param on_resample_before_predict (`function()`)\cr #' Stage called before predicting. #' Called in `workhorse()` (internal). -#' @param on_evaluation_end (`function()`)\cr +#' @param on_resample_end (`function()`)\cr #' Stage called at the end of an evaluation. #' Called in `workhorse()` (internal). 
#' #' @export #' @examples -#' callback = callback_evaluation("selected_features", +#' callback = callback_resample("selected_features", #' label = "Selected Features", #' -#' on_evaluation_end = function(callback, context) { +#' on_resample_end = function(callback, context) { #' context$learner$state$selected_features = context$learner$selected_features() #' } #' ) @@ -108,29 +103,29 @@ CallbackEvaluation= R6Class("CallbackEvaluation", #' rr = resample(task, learner, resampling, callbacks = callback) #' #' rr$learners[[1]]$state$selected_features -callback_evaluation = function( +callback_resample = function( id, label = NA_character_, man = NA_character_, - on_evaluation_begin = NULL, - on_evaluation_before_train = NULL, - on_evaluation_before_predict = NULL, - on_evaluation_end = NULL + on_resample_begin = NULL, + on_resample_before_train = NULL, + on_resample_before_predict = NULL, + on_resample_end = NULL ) { stages = discard(set_names(list( - on_evaluation_begin, - on_evaluation_before_train, - on_evaluation_before_predict, - on_evaluation_end), + on_resample_begin, + on_resample_before_train, + on_resample_before_predict, + on_resample_end), c( - "on_evaluation_begin", - "on_evaluation_before_train", - "on_evaluation_before_predict", - "on_evaluation_end" + "on_resample_begin", + "on_resample_before_train", + "on_resample_before_predict", + "on_resample_end" )), is.null) stages = map(stages, function(stage) crate(assert_function(stage, args = c("callback", "context")))) - callback = CallbackEvaluation$new(id, label, man) + callback = CallbackResample$new(id, label, man) iwalk(stages, function(stage, name) callback[[name]] = stage) callback } @@ -138,24 +133,24 @@ callback_evaluation = function( #' @title Assertions for Callbacks #' #' @description -#' Assertions for [CallbackEvaluation] class. +#' Assertions for [CallbackResample] class. #' -#' @param callback ([CallbackEvaluation]). +#' @param callback ([CallbackResample]). #' @param null_ok (`logical(1)`)\cr #' If `TRUE`, `NULL` is allowed. #' -#' @return [CallbackEvaluation | List of [CallbackEvaluation]s. +#' @return [CallbackResample | List of [CallbackResample]s. #' @export -assert_evaluation_callback = function(callback, null_ok = FALSE) { - assert_class(callback, "CallbackEvaluation", null.ok = null_ok) +assert_resample_callback = function(callback, null_ok = FALSE) { + assert_class(callback, "CallbackResample", null.ok = null_ok) invisible(callback) } #' @export -#' @param callbacks (list of [CallbackEvaluation]). -#' @rdname assert_evaluation_callback -assert_evaluation_callbacks = function(callbacks, null_ok = FALSE) { +#' @param callbacks (list of [CallbackResample]). +#' @rdname assert_resample_callback +assert_resample_callbacks = function(callbacks, null_ok = FALSE) { assert_list(callbacks, null.ok = null_ok) if (null_ok && is.null(callbacks)) return(invisible(NULL)) - invisible(lapply(callbacks, assert_evaluation_callback)) + invisible(lapply(callbacks, assert_resample_callback)) } diff --git a/R/ContextEvaluation.R b/R/ContextEvaluation.R index a0baf9862..0028b6700 100644 --- a/R/ContextEvaluation.R +++ b/R/ContextEvaluation.R @@ -1,39 +1,15 @@ #' @title Evaluation Context #' #' @description -#' A [CallbackEvaluation] accesses and modifies data during [resample()] and [benchmark()] via the `ContextEvaluation`. +#' A [CallbackResample] accesses and modifies data during [resample()] and [benchmark()] via the `ContextResample`. #' See the section on fields for a list of modifiable objects. 
-#' See [callback_evaluation()] for a list of stages that access `ContextEvaluation`. +#' See [callback_resample()] for a list of stages that access `ContextResample`. #' #' @export -ContextEvaluation = R6Class("ContextEvaluation", +ContextResample = R6Class("ContextResample", inherit = Context, public = list( - #' @field task ([Task])\cr - #' The task to be evaluated. - #' The task is unchanged during the evaluation. - task = NULL, - - #' @field learner ([Learner])\cr - #' The learner to be evaluated. - #' The learner contains the models after stage `on_evaluation_before_train`. - learner = NULL, - - #' @field resampling [Resampling]\cr - #' The resampling strategy to be used. - #' The resampling is unchanged during the evaluation. - resampling = NULL, - - #' @field iteration (`integer()`)\cr - #' The current iteration. - iteration = NULL, - - #' @field pdatas (List of [PredictionData])\cr - #' The prediction data. - #' The data is available on stage `on_evaluation_end`. - pdatas = NULL, - #' @description #' Creates a new instance of this [R6][R6::R6Class] class. #' @@ -47,10 +23,10 @@ ContextEvaluation = R6Class("ContextEvaluation", #' The current iteration. initialize = function(task, learner, resampling, iteration) { # no assertions to avoid overhead - self$task = task - self$learner = learner - self$resampling = resampling - self$iteration = iteration + private$.task = task + private$.learner = learner + private$.resampling = resampling + private$.iteration = iteration super$initialize(id = "evaluate", label = "Evaluation") } @@ -58,6 +34,52 @@ ContextEvaluation = R6Class("ContextEvaluation", active = list( + #' @field task ([Task])\cr + #' The task to be evaluated. + #' The task is unchanged during the evaluation. + #' The task is read-only. + task = function(rhs) { + assert_ro_binding(rhs) + private$.task + }, + + #' @field learner ([Learner])\cr + #' The learner to be evaluated. + #' The learner contains the models after stage `on_resample_before_train`. + learner = function(rhs) { + if (missing(rhs)) { + return(private$.learner) + } + private$.learner = assert_learner(rhs) + }, + + #' @field resampling [Resampling]\cr + #' The resampling strategy to be used. + #' The resampling is unchanged during the evaluation. + #' The resampling is read-only. + resampling = function(rhs) { + assert_ro_binding(rhs) + private$.resampling + }, + + #' @field iteration (`integer()`)\cr + #' The current iteration. + #' The iteration is read-only. + iteration = function(rhs) { + assert_ro_binding(rhs) + private$.iteration + }, + + #' @field pdatas (List of [PredictionData])\cr + #' The prediction data. + #' The data is available on stage `on_resample_end`. + pdatas = function(rhs) { + if (missing(rhs)) { + return(private$.pdatas) + } + private$.pdatas = assert_list(rhs, "PredictionData") + }, + #' @field data_extra (list())\cr #' Data saved in the [ResampleResult] or [BenchmarkResult]. #' Use this field to save results. 
@@ -71,6 +93,11 @@ ContextEvaluation = R6Class("ContextEvaluation", ), private = list( + .task = NULL, + .learner = NULL, + .resampling = NULL, + .iteration = NULL, + .pdatas = NULL, .data_extra = NULL ) ) diff --git a/R/mlr_callbacks.R b/R/mlr_callbacks.R index cad2bc0fa..912ffba81 100644 --- a/R/mlr_callbacks.R +++ b/R/mlr_callbacks.R @@ -1,39 +1,85 @@ -#' @title Score Measures Callback +#' @title Model Extractor Callback #' -#' @include CallbackEvaluation.R -#' @name mlr3.score_measures +#' @include CallbackResample.R +#' @name mlr3.model_extractor #' #' @description -#' This [CallbackEvaluation] scores measures directly on the worker. -#' This way measures that require a model can be scored without saving the model. +#' This [CallbackResample] extracts information from the model after training with a user-defined function. +#' This way information can be extracted from the model without saving the model (`store_models = FALSE`). +#' The `fun` must be a function that takes a learner as input and returns the extracted information as named list (see example). +#' The callback is very helpful to call `$selected_features()`, `$importance()`, `$oob_error()` on the learner. #' -#' @examples -#' clbk("mlr3.score_measures", measures = msr("classif.ce")) +#' @param fun (`function(learner)`)\cr +#' Function to extract information from the learner. +#' The function must have the argument `learner`. +#' The function must return a named list. #' +#' @examples #' task = tsk("pima") #' learner = lrn("classif.rpart") #' resampling = rsmp("cv", folds = 3) #' -#' callback = clbk("mlr3.score_measures", measures = msr("selected_features")) +#' # define function to extract selected features +#' selected_features = function(learner) list(selected_features = learner$selected_features()) +#' +#' # create callback +#' callback = clbk("mlr3.model_extractor", fun = selected_features) #' #' rr = resample(task, learner, resampling = resampling, store_models = FALSE, callbacks = callback) #' #' rr$data_extra NULL -load_callback_score_measures = function() { - callback_evaluation("mlr3.score_measures", - label = "Score Measures Callback", - man = "mlr3::mlr3.score_measures", +load_callback_model_extractor = function() { + callback_resample("mlr3.model_extractor", + label = "Model Extractor Callback", + man = "mlr3::mlr3.model_extractor", + + on_resample_end = function(callback, context) { + assert_function(callback$state$fun, args = "learner") + context$data_extra = invoke(callback$state$fun, learner = context$learner) + } + ) +} + +#' @title Callback Holdout Task +#' +#' @include CallbackResample.R +#' @name mlr3.holdout_set +#' +#' @description +#' This [CallbackResample] predicts on an additional holdout task after training. +#' +#' @param task ([Task])\cr +#' The holdout task. 
+#' +#' @examples +#' task = tsk("pima") +#' task_holdout = task$clone() +#' learner = lrn("classif.rpart") +#' resampling = rsmp("cv", folds = 3) +#' splits = partition(task, 0.7) +#' +#' task$filter(splits$train) +#' task_holdout$filter(splits$test) +#' +#' callback = clbk("mlr3.holdout_task", task = task_holdout) +#' +#' rr = resample(task, learner, resampling = resampling, callbacks = callback) +#' +#' rr$data_extra +NULL + +load_callback_holdout_task = function() { + callback_resample("mlr3.holdout_task", + label = "Callback Holdout Task", + man = "mlr3::mlr3.holdout_task", - on_evaluation_end = function(callback, context) { - measures = as_measures(callback$state$measures) + on_resample_before_predict = function(callback, context) { + assert_task(callback$state$task) - # Score measures on the test set - pred = as_prediction(context$pdatas$test) - context$data_extra = insert_named(context$data_extra, list( - score_measures = pred$score(measures, context$task, context$learner) - )) + pred = context$learner$predict(callback$state$task) + context$data_extra = list(prediction_holdout = pred) } ) } diff --git a/R/worker.R b/R/worker.R index bdd6dd5c1..bfb985c06 100644 --- a/R/worker.R +++ b/R/worker.R @@ -265,9 +265,9 @@ workhorse = function( unmarshal = TRUE, callbacks = NULL ) { - ctx = ContextEvaluation$new(task, learner, resampling, iteration) + ctx = ContextResample$new(task, learner, resampling, iteration) - call_back("on_evaluation_begin", callbacks, ctx) + call_back("on_resample_begin", callbacks, ctx) if (!is.null(pb)) { pb(sprintf("%s|%s|i:%i", ctx$task$id, ctx$learner$id, ctx$iteration)) @@ -325,7 +325,7 @@ workhorse = function( test_set = if (identical(validate, "test")) sets$test - call_back("on_evaluation_before_train", callbacks, ctx) + call_back("on_resample_before_train", callbacks, ctx) train_result = learner_train(ctx$learner, ctx$task, sets[["train"]], test_set, mode = mode) ctx$learner = train_result$learner @@ -342,8 +342,7 @@ workhorse = function( # creates the tasks and row_ids for all selected predict sets pred_data = prediction_tasks_and_sets(ctx$task, train_result, validate, sets, predict_sets) - - call_back("on_evaluation_before_predict", callbacks, ctx) + call_back("on_resample_before_predict", callbacks, ctx) pdatas = Map(function(set, row_ids, task) { lg$debug("Creating Prediction for predict set '%s'", set) @@ -362,7 +361,7 @@ workhorse = function( unmarshal = unmarshal ) - call_back("on_evaluation_end", callbacks, ctx) + call_back("on_resample_end", callbacks, ctx) if (!store_models) { lg$debug("Erasing stored model for learner '%s'", ctx$learner$id) diff --git a/R/zzz.R b/R/zzz.R index d671c6746..b48ff6eef 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -74,9 +74,10 @@ dummy_import = function() { # nocov start backports::import(pkgname) - # callbacks + # callbacks x = utils::getFromNamespace("mlr_callbacks", ns = "mlr3misc") - x$add("mlr3.score_measures", load_callback_score_measures) + x$add("mlr3.model_extractor", load_callback_model_extractor) + x$add("mlr3.holdout_task", load_callback_holdout_task) # setup logger lg = lgr::get_logger(pkgname) diff --git a/man-roxygen/param_callbacks.R b/man-roxygen/param_callbacks.R index cdb286953..85bb82891 100644 --- a/man-roxygen/param_callbacks.R +++ b/man-roxygen/param_callbacks.R @@ -1,3 +1,3 @@ #' @param callbacks (List of [mlr3misc::Callback])\cr #' Callbacks to be executed during the resampling process. -#' See [CallbackEvaluation] and [ContextEvaluation] for details. 
+#' See [CallbackResample] and [ContextResample] for details. diff --git a/man/CallbackEvaluation.Rd b/man/CallbackResample.Rd similarity index 71% rename from man/CallbackEvaluation.Rd rename to man/CallbackResample.Rd index bd075c53c..845fb8035 100644 --- a/man/CallbackEvaluation.Rd +++ b/man/CallbackResample.Rd @@ -1,34 +1,34 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/CallbackEvaluation.R -\name{CallbackEvaluation} -\alias{CallbackEvaluation} +% Please edit documentation in R/CallbackResample.R +\name{CallbackResample} +\alias{CallbackResample} \title{Evaluation Callback} \description{ Specialized \link[mlr3misc:Callback]{mlr3misc::Callback} to customize the behavior of \code{\link[=resample]{resample()}} and \code{\link[=benchmark]{benchmark()}} in mlr3. -The \code{\link[=callback_evaluation]{callback_evaluation()}} function is used to create instances of this class. +The \code{\link[=callback_resample]{callback_resample()}} function is used to create instances of this class. Predefined callbacks are stored in the \link[mlr3misc:Dictionary]{dictionary} \link{mlr_callbacks} and can be retrieved with \code{\link[=clbk]{clbk()}}. -For more information on callbacks, see the \code{\link[=callback_evaluation]{callback_evaluation()}} documentation. +For more information on callbacks, see the \code{\link[=callback_resample]{callback_resample()}} documentation. } \section{Super class}{ -\code{\link[mlr3misc:Callback]{mlr3misc::Callback}} -> \code{CallbackEvaluation} +\code{\link[mlr3misc:Callback]{mlr3misc::Callback}} -> \code{CallbackResample} } \section{Public fields}{ \if{html}{\out{
}} \describe{ -\item{\code{on_evaluation_begin}}{(\verb{function()})\cr -Stage called at the beginning of an evaluation. +\item{\code{on_resample_begin}}{(\verb{function()})\cr +Stage called at the beginning of the resampling iteration. Called in \code{workhorse()} (internal).} -\item{\code{on_evaluation_before_train}}{(\verb{function()})\cr +\item{\code{on_resample_before_train}}{(\verb{function()})\cr Stage called before training the learner. Called in \code{workhorse()} (internal).} -\item{\code{on_evaluation_before_predict}}{(\verb{function()})\cr +\item{\code{on_resample_before_predict}}{(\verb{function()})\cr Stage called before predicting. Called in \code{workhorse()} (internal).} -\item{\code{on_evaluation_end}}{(\verb{function()})\cr -Stage called at the end of an evaluation. +\item{\code{on_resample_end}}{(\verb{function()})\cr +Stage called at the end of the resample iteration. Called in \code{workhorse()} (internal).} } \if{html}{\out{
}} @@ -36,7 +36,7 @@ Called in \code{workhorse()} (internal).} \section{Methods}{ \subsection{Public methods}{ \itemize{ -\item \href{#method-CallbackEvaluation-clone}{\code{CallbackEvaluation$clone()}} +\item \href{#method-CallbackResample-clone}{\code{CallbackResample$clone()}} } } \if{html}{\out{ @@ -51,12 +51,12 @@ Called in \code{workhorse()} (internal).} }} \if{html}{\out{
}} -\if{html}{\out{}} -\if{latex}{\out{\hypertarget{method-CallbackEvaluation-clone}{}}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-CallbackResample-clone}{}}} \subsection{Method \code{clone()}}{ The objects of this class are cloneable with this method. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{CallbackEvaluation$clone(deep = FALSE)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{CallbackResample$clone(deep = FALSE)}\if{html}{\out{
}} } \subsection{Arguments}{ diff --git a/man/ContextEvaluation.Rd b/man/ContextResample.Rd similarity index 62% rename from man/ContextEvaluation.Rd rename to man/ContextResample.Rd index 1091599e2..e679fc68f 100644 --- a/man/ContextEvaluation.Rd +++ b/man/ContextResample.Rd @@ -1,43 +1,41 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/ContextEvaluation.R -\name{ContextEvaluation} -\alias{ContextEvaluation} +\name{ContextResample} +\alias{ContextResample} \title{Evaluation Context} \description{ -A \link{CallbackEvaluation} accesses and modifies data during \code{\link[=resample]{resample()}} and \code{\link[=benchmark]{benchmark()}} via the \code{ContextEvaluation}. +A \link{CallbackResample} accesses and modifies data during \code{\link[=resample]{resample()}} and \code{\link[=benchmark]{benchmark()}} via the \code{ContextResample}. See the section on fields for a list of modifiable objects. -See \code{\link[=callback_evaluation]{callback_evaluation()}} for a list of stages that access \code{ContextEvaluation}. +See \code{\link[=callback_resample]{callback_resample()}} for a list of stages that access \code{ContextResample}. } \section{Super class}{ -\code{\link[mlr3misc:Context]{mlr3misc::Context}} -> \code{ContextEvaluation} +\code{\link[mlr3misc:Context]{mlr3misc::Context}} -> \code{ContextResample} } -\section{Public fields}{ -\if{html}{\out{
}} +\section{Active bindings}{ +\if{html}{\out{
}} \describe{ \item{\code{task}}{(\link{Task})\cr The task to be evaluated. -The task is unchanged during the evaluation.} +The task is unchanged during the evaluation. +The task is read-only.} \item{\code{learner}}{(\link{Learner})\cr The learner to be evaluated. -The learner contains the models after stage \code{on_evaluation_before_train}.} +The learner contains the models after stage \code{on_resample_before_train}.} \item{\code{resampling}}{\link{Resampling}\cr The resampling strategy to be used. -The resampling is unchanged during the evaluation.} +The resampling is unchanged during the evaluation. +The resampling is read-only.} \item{\code{iteration}}{(\code{integer()})\cr -The current iteration.} +The current iteration. +The iteration is read-only.} \item{\code{pdatas}}{(List of \link{PredictionData})\cr The prediction data. -The data is available on stage \code{on_evaluation_end}.} -} -\if{html}{\out{
}} -} -\section{Active bindings}{ -\if{html}{\out{
}} -\describe{ +The data is available on stage \code{on_resample_end}.} + \item{\code{data_extra}}{(list())\cr Data saved in the \link{ResampleResult} or \link{BenchmarkResult}. Use this field to save results. @@ -48,8 +46,8 @@ Must be a \code{list()}.} \section{Methods}{ \subsection{Public methods}{ \itemize{ -\item \href{#method-ContextEvaluation-new}{\code{ContextEvaluation$new()}} -\item \href{#method-ContextEvaluation-clone}{\code{ContextEvaluation$clone()}} +\item \href{#method-ContextResample-new}{\code{ContextResample$new()}} +\item \href{#method-ContextResample-clone}{\code{ContextResample$clone()}} } } \if{html}{\out{ @@ -61,12 +59,12 @@ Must be a \code{list()}.} }} \if{html}{\out{
}} -\if{html}{\out{}} -\if{latex}{\out{\hypertarget{method-ContextEvaluation-new}{}}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ContextResample-new}{}}} \subsection{Method \code{new()}}{ Creates a new instance of this \link[R6:R6Class]{R6} class. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{ContextEvaluation$new(task, learner, resampling, iteration)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{ContextResample$new(task, learner, resampling, iteration)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -88,12 +86,12 @@ The current iteration.} } } \if{html}{\out{
}} -\if{html}{\out{}} -\if{latex}{\out{\hypertarget{method-ContextEvaluation-clone}{}}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-ContextResample-clone}{}}} \subsection{Method \code{clone()}}{ The objects of this class are cloneable with this method. \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{ContextEvaluation$clone(deep = FALSE)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{ContextResample$clone(deep = FALSE)}\if{html}{\out{
}} } \subsection{Arguments}{ diff --git a/man/assert_evaluation_callback.Rd b/man/assert_evaluation_callback.Rd deleted file mode 100644 index 31ffa4b43..000000000 --- a/man/assert_evaluation_callback.Rd +++ /dev/null @@ -1,25 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/CallbackEvaluation.R -\name{assert_evaluation_callback} -\alias{assert_evaluation_callback} -\alias{assert_evaluation_callbacks} -\title{Assertions for Callbacks} -\usage{ -assert_evaluation_callback(callback, null_ok = FALSE) - -assert_evaluation_callbacks(callbacks, null_ok = FALSE) -} -\arguments{ -\item{callback}{(\link{CallbackEvaluation}).} - -\item{null_ok}{(\code{logical(1)})\cr -If \code{TRUE}, \code{NULL} is allowed.} - -\item{callbacks}{(list of \link{CallbackEvaluation}).} -} -\value{ -[CallbackEvaluation | List of \link{CallbackEvaluation}s. -} -\description{ -Assertions for \link{CallbackEvaluation} class. -} diff --git a/man/assert_resample_callback.Rd b/man/assert_resample_callback.Rd new file mode 100644 index 000000000..89589bc98 --- /dev/null +++ b/man/assert_resample_callback.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/CallbackResample.R +\name{assert_resample_callback} +\alias{assert_resample_callback} +\alias{assert_resample_callbacks} +\title{Assertions for Callbacks} +\usage{ +assert_resample_callback(callback, null_ok = FALSE) + +assert_resample_callbacks(callbacks, null_ok = FALSE) +} +\arguments{ +\item{callback}{(\link{CallbackResample}).} + +\item{null_ok}{(\code{logical(1)})\cr +If \code{TRUE}, \code{NULL} is allowed.} + +\item{callbacks}{(list of \link{CallbackResample}).} +} +\value{ +[CallbackResample | List of \link{CallbackResample}s. +} +\description{ +Assertions for \link{CallbackResample} class. +} diff --git a/man/benchmark.Rd b/man/benchmark.Rd index 9f53cecfd..23e9063c7 100644 --- a/man/benchmark.Rd +++ b/man/benchmark.Rd @@ -67,7 +67,7 @@ If \code{FALSE}, all learners (that need marshaling) are stored in marshaled for \item{callbacks}{(List of \link[mlr3misc:Callback]{mlr3misc::Callback})\cr Callbacks to be executed during the resampling process. -See \link{CallbackEvaluation} and \link{ContextEvaluation} for details.} +See \link{CallbackResample} and \link{ContextResample} for details.} } \value{ \link{BenchmarkResult}. diff --git a/man/callback_evaluation.Rd b/man/callback_resample.Rd similarity index 64% rename from man/callback_evaluation.Rd rename to man/callback_resample.Rd index 1f18d6ac6..16a55fdf2 100644 --- a/man/callback_evaluation.Rd +++ b/man/callback_resample.Rd @@ -1,17 +1,17 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/CallbackEvaluation.R -\name{callback_evaluation} -\alias{callback_evaluation} +% Please edit documentation in R/CallbackResample.R +\name{callback_resample} +\alias{callback_resample} \title{Create Evaluation Callback} \usage{ -callback_evaluation( +callback_resample( id, label = NA_character_, man = NA_character_, - on_evaluation_begin = NULL, - on_evaluation_before_train = NULL, - on_evaluation_before_predict = NULL, - on_evaluation_end = NULL + on_resample_begin = NULL, + on_resample_before_train = NULL, + on_resample_before_predict = NULL, + on_resample_end = NULL ) } \arguments{ @@ -25,39 +25,39 @@ Label for the new instance.} String in the format \verb{[pkg]::[topic]} pointing to a manual page for this object. 
The referenced help package can be opened via method \verb{$help()}.} -\item{on_evaluation_begin}{(\verb{function()})\cr +\item{on_resample_begin}{(\verb{function()})\cr Stage called at the beginning of an evaluation. Called in \code{workhorse()} (internal).} -\item{on_evaluation_before_train}{(\verb{function()})\cr +\item{on_resample_before_train}{(\verb{function()})\cr Stage called before training the learner. Called in \code{workhorse()} (internal).} -\item{on_evaluation_before_predict}{(\verb{function()})\cr +\item{on_resample_before_predict}{(\verb{function()})\cr Stage called before predicting. Called in \code{workhorse()} (internal).} -\item{on_evaluation_end}{(\verb{function()})\cr +\item{on_resample_end}{(\verb{function()})\cr Stage called at the end of an evaluation. Called in \code{workhorse()} (internal).} } \description{ -Function to create a \link{CallbackEvaluation}. +Function to create a \link{CallbackResample}. Predefined callbacks are stored in the \link[mlr3misc:Dictionary]{dictionary} \link{mlr_callbacks} and can be retrieved with \code{\link[=clbk]{clbk()}}. Evaluation callbacks are called at different stages of the resampling process. Each stage is called once per resampling iteration. -The stages are prefixed with \verb{on_*}. -The text in brackets indicates what happens between the stages and which accesses to the \link{ContextEvaluation} (\code{ctx}) are typical for the stage. +The stages are prefixed with \verb{on_resample_*}. +The text in brackets indicates what happens between the stages and which accesses to the \link{ContextResample} (\code{ctx}) are typical for the stage. \if{html}{\out{
}}\preformatted{Start Resampling Iteration on Worker - - on_evaluation_begin + - on_resample_begin (Split `ctx$task` into training and test set with `ctx$resampling` and `ctx$iteration`) - - on_evaluation_before_train + - on_resample_before_train (Train the learner `ctx$learner` on training data) - - on_evaluation_before_predict + - on_resample_before_predict (Predict on predict sets and store prediction data `ctx$pdatas`) - - on_evaluation_end + - on_resample_end (Erase model `ctx$learner$model` if requested and return results) End Resampling Iteration on Worker }\if{html}{\out{
}} @@ -69,20 +69,13 @@ See also the section on parameters for more information on the stages. \details{ When implementing a callback, each function must have two arguments named \code{callback} and \code{context}. A callback can write data to the state (\verb{$state}), e.g. settings that affect the callback itself. +We highly discourage changing the task, learner and resampling objects via the callback. } -\section{Parallelization}{ - -Be careful when modifying \code{ctx$learner}, \code{ctx$task}, or \code{ctx$resampling} because callbacks can behave differently when parallelizing the resampling process. -When running the resampling process sequentially, the modifications are carried over to the next iteration. -When parallelizing the resampling process, modifying the \link{ContextEvaluation} will not be synchronized between workers. -This also applies to the \verb{$state} of the callback. -} - \examples{ -callback = callback_evaluation("selected_features", +callback = callback_resample("selected_features", label = "Selected Features", - on_evaluation_end = function(callback, context) { + on_resample_end = function(callback, context) { context$learner$state$selected_features = context$learner$selected_features() } ) diff --git a/man/mlr3.holdout_set.Rd b/man/mlr3.holdout_set.Rd new file mode 100644 index 000000000..8a61993b5 --- /dev/null +++ b/man/mlr3.holdout_set.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/mlr_callbacks.R +\name{mlr3.holdout_set} +\alias{mlr3.holdout_set} +\title{Callback Holdout Task} +\arguments{ +\item{task}{(\link{Task})\cr +The holdout task.} +} +\description{ +This \link{CallbackResample} predicts on an additional holdout task after training. +} +\examples{ +task = tsk("pima") +task_holdout = task$clone() +learner = lrn("classif.rpart") +resampling = rsmp("cv", folds = 3) +splits = partition(task, 0.7) + +task$filter(splits$train) +task_holdout$filter(splits$test) + +callback = clbk("mlr3.holdout_task", task = task_holdout) + +rr = resample(task, learner, resampling = resampling, callbacks = callback) + +rr$data_extra +} diff --git a/man/mlr3.model_extractor.Rd b/man/mlr3.model_extractor.Rd new file mode 100644 index 000000000..23806fdc1 --- /dev/null +++ b/man/mlr3.model_extractor.Rd @@ -0,0 +1,32 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/mlr_callbacks.R +\name{mlr3.model_extractor} +\alias{mlr3.model_extractor} +\title{Model Extractor Callback} +\arguments{ +\item{fun}{(\verb{function(learner)})\cr +Function to extract information from the learner. +The function must have the argument \code{learner}. +The function must return a named list.} +} +\description{ +This \link{CallbackResample} extracts information from the model after training with a user-defined function. +This way information can be extracted from the model without saving the model (\code{store_models = FALSE}). +The \code{fun} must be a function that takes a learner as input and returns the extracted information as named list (see example). +The callback is very helpful to call \verb{$selected_features()}, \verb{$importance()}, \verb{$oob_error()} on the learner. 
+} +\examples{ +task = tsk("pima") +learner = lrn("classif.rpart") +resampling = rsmp("cv", folds = 3) + +# define function to extract selected features +selected_features = function(learner) list(selected_features = learner$selected_features()) + +# create callback +callback = clbk("mlr3.model_extractor", fun = selected_features) + +rr = resample(task, learner, resampling = resampling, store_models = FALSE, callbacks = callback) + +rr$data_extra +} diff --git a/man/mlr3.score_measures.Rd b/man/mlr3.score_measures.Rd deleted file mode 100644 index a40e1686f..000000000 --- a/man/mlr3.score_measures.Rd +++ /dev/null @@ -1,22 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/mlr_callbacks.R -\name{mlr3.score_measures} -\alias{mlr3.score_measures} -\title{Score Measures Callback} -\description{ -This \link{CallbackEvaluation} scores measures directly on the worker. -This way measures that require a model can be scored without saving the model. -} -\examples{ -clbk("mlr3.score_measures", measures = msr("classif.ce")) - -task = tsk("pima") -learner = lrn("classif.rpart") -resampling = rsmp("cv", folds = 3) - -callback = clbk("mlr3.score_measures", measures = msr("selected_features")) - -rr = resample(task, learner, resampling = resampling, store_models = FALSE, callbacks = callback) - -rr$data_extra -} diff --git a/man/resample.Rd b/man/resample.Rd index 41340d299..d5e10af0d 100644 --- a/man/resample.Rd +++ b/man/resample.Rd @@ -68,7 +68,7 @@ If \code{FALSE}, all learners (that need marshaling) are stored in marshaled for \item{callbacks}{(List of \link[mlr3misc:Callback]{mlr3misc::Callback})\cr Callbacks to be executed during the resampling process. -See \link{CallbackEvaluation} and \link{ContextEvaluation} for details.} +See \link{CallbackResample} and \link{ContextResample} for details.} } \value{ \link{ResampleResult}. 
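The resample.Rd hunk above is where the new `callbacks` argument becomes visible to users. A minimal usage sketch of the API as it stands after this patch (the callback id `n_selected` and the use of `length()` are illustrative choices, not part of the patch, and later commits in the series may still adjust the interface):

```r
library(mlr3)

task = tsk("pima")
learner = lrn("classif.rpart")
resampling = rsmp("cv", folds = 3)

# hand-written callback: record the number of selected features per iteration
# in data_extra, so the fitted models never have to be stored
callback = callback_resample("n_selected",
  on_resample_end = function(callback, context) {
    context$data_extra$n_selected = length(context$learner$selected_features())
  }
)

rr = resample(task, learner, resampling, store_models = FALSE, callbacks = callback)
rr$data_extra  # one list entry per resampling iteration

# the predefined extractor callback covers the common case without a custom stage
callback = clbk("mlr3.model_extractor",
  fun = function(learner) list(selected_features = learner$selected_features()))
rr = resample(task, learner, resampling, callbacks = callback)
rr$data_extra
```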
diff --git a/pkgdown/_pkgdown.yml b/pkgdown/_pkgdown.yml index 57e76ebe5..b086f8365 100644 --- a/pkgdown/_pkgdown.yml +++ b/pkgdown/_pkgdown.yml @@ -113,10 +113,10 @@ reference: - set_threads - title: Callbacks contents: - - CallbackEvaluation - - ContextEvaluation - - callback_evaluation - - assert_evaluation_callback + - CallbackResample + - ContextResample + - callback_resample + - assert_resample_callback - mlr3.score_measures - title: Internal Objects and Functions contents: diff --git a/tests/testthat/test_CallbackEvaluation.R b/tests/testthat/test_CallbackResample.R similarity index 81% rename from tests/testthat/test_CallbackEvaluation.R rename to tests/testthat/test_CallbackResample.R index 13ad6eb60..248a7a3b1 100644 --- a/tests/testthat/test_CallbackEvaluation.R +++ b/tests/testthat/test_CallbackResample.R @@ -1,11 +1,11 @@ -test_that("on_evaluation_begin works", { +test_that("on_resample_begin works", { task = tsk("pima") learner = lrn("classif.rpart") resampling = rsmp("cv", folds = 3) - callback = callback_evaluation("test", + callback = callback_resample("test", - on_evaluation_begin = function(callback, context) { + on_resample_begin = function(callback, context) { # expect_* does not work assert_task(context$task) assert_learner(context$learner) @@ -18,14 +18,14 @@ test_that("on_evaluation_begin works", { expect_resample_result(resample(task, learner, resampling, callbacks = callback)) }) -test_that("on_evaluation_before_train works", { +test_that("on_resample_before_train works", { task = tsk("pima") learner = lrn("classif.rpart") resampling = rsmp("cv", folds = 3) - callback = callback_evaluation("test", + callback = callback_resample("test", - on_evaluation_before_train = function(callback, context) { + on_resample_before_train = function(callback, context) { assert_task(context$task) assert_learner(context$learner) assert_resampling(context$resampling) @@ -38,14 +38,14 @@ test_that("on_evaluation_before_train works", { }) -test_that("on_evaluation_before_predict works", { +test_that("on_resample_before_predict works", { task = tsk("pima") learner = lrn("classif.rpart") resampling = rsmp("cv", folds = 3) - callback = callback_evaluation("test", + callback = callback_resample("test", - on_evaluation_before_predict = function(callback, context) { + on_resample_before_predict = function(callback, context) { assert_task(context$task) assert_learner(context$learner) assert_resampling(context$resampling) @@ -56,14 +56,14 @@ test_that("on_evaluation_before_predict works", { expect_resample_result(resample(task, learner, resampling, callbacks = callback)) }) -test_that("on_evaluation_end works", { +test_that("on_resample_end works", { task = tsk("pima") learner = lrn("classif.rpart") resampling = rsmp("cv", folds = 3) - callback = callback_evaluation("test", + callback = callback_resample("test", - on_evaluation_end = function(callback, context) { + on_resample_end = function(callback, context) { assert_task(context$task) assert_learner(context$learner) assert_resampling(context$resampling) @@ -80,8 +80,8 @@ test_that("writing to learner$state works", { learner = lrn("classif.rpart") resampling = rsmp("cv", folds = 3) - callback = callback_evaluation("test", - on_evaluation_end = function(callback, context) { + callback = callback_resample("test", + on_resample_end = function(callback, context) { context$learner$state$test = 1 } ) @@ -106,8 +106,8 @@ test_that("writing to data_extra works", { learner = lrn("classif.rpart") resampling = rsmp("cv", folds = 3) - callback = 
callback_evaluation("test", - on_evaluation_end = function(callback, context) { + callback = callback_resample("test", + on_resample_end = function(callback, context) { context$data_extra$test = 1 } ) @@ -139,8 +139,8 @@ test_that("data_extra is a list column", { learner = lrn("classif.rpart") resampling = rsmp("holdout") - callback = callback_evaluation("test", - on_evaluation_end = function(callback, context) { + callback = callback_resample("test", + on_resample_end = function(callback, context) { context$data_extra$test = 1 } ) @@ -160,8 +160,8 @@ test_that("data_extra is null", { learner = lrn("classif.rpart") resampling = rsmp("cv", folds = 3) - callback = callback_evaluation("test", - on_evaluation_end = function(callback, context) { + callback = callback_resample("test", + on_resample_end = function(callback, context) { context$learner$state$test = 1 } ) diff --git a/tests/testthat/test_ContextEvaluation.R b/tests/testthat/test_ContextEvaluation.R new file mode 100644 index 000000000..4166cbf8d --- /dev/null +++ b/tests/testthat/test_ContextEvaluation.R @@ -0,0 +1,17 @@ +test_that("ContextResample works", { + task = tsk("pima") + learner = lrn("classif.rpart") + resampling = rsmp("cv", folds = 3) + iteration = 1 + + ctx = ContextResample$new(task, learner, resampling, iteration) + + expect_task(ctx$task) + expect_learner(ctx$learner) + expect_resampling(ctx$resampling) + expect_equal(ctx$iteration, iteration) + + expect_error({ctx$task = tsk("spam")}, "read-only") + expect_error({ctx$resampling = rsmp("cv", folds = 5)}, "read-only") + expect_error({ctx$iteration = 2}, "read-only") +}) From de8d9ad1e337ba637d990bd9daac54c27a4ff1bf Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 20 Dec 2024 11:31:32 +0100 Subject: [PATCH 48/54] ... --- R/BenchmarkResult.R | 6 +++--- R/ResampleResult.R | 6 +++--- man/BenchmarkResult.Rd | 2 +- man/ResampleResult.Rd | 2 +- tests/testthat/test_CallbackResample.R | 16 ++++++++-------- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/R/BenchmarkResult.R b/R/BenchmarkResult.R index 3ef7a7713..e60d3e204 100644 --- a/R/BenchmarkResult.R +++ b/R/BenchmarkResult.R @@ -19,7 +19,7 @@ #' @template param_measures #' #' @section S3 Methods: -#' * `as.data.table(rr, ..., reassemble_learners = TRUE, convert_predictions = TRUE, predict_sets = "test", task_characteristics = FALSE, data_extra = FALSE)`\cr +#' * `as.data.table(rr, ..., reassemble_learners = TRUE, convert_predictions = TRUE, predict_sets = "test", task_characteristics = FALSE)`\cr #' [BenchmarkResult] -> [data.table::data.table()]\cr #' Returns a tabular view of the internal data. 
#' * `c(...)`\cr @@ -545,11 +545,11 @@ BenchmarkResult = R6Class("BenchmarkResult", ) #' @export -as.data.table.BenchmarkResult = function(x, ..., hashes = FALSE, predict_sets = "test", task_characteristics = FALSE, data_extra = FALSE) { # nolint +as.data.table.BenchmarkResult = function(x, ..., hashes = FALSE, predict_sets = "test", task_characteristics = FALSE) { # nolint assert_flag(task_characteristics) tab = get_private(x)$.data$as_data_table(view = NULL, predict_sets = predict_sets) cns = c("uhash", "task", "learner", "resampling", "iteration", "prediction") - if (data_extra && "data_extra" %in% names(tab)) cns = c(cns, "data_extra") + if ("data_extra" %in% names(tab)) cns = c(cns, "data_extra") tab = tab[, cns, with = FALSE] if (task_characteristics) { diff --git a/R/ResampleResult.R b/R/ResampleResult.R index d6bd48f41..92a2a16d3 100644 --- a/R/ResampleResult.R +++ b/R/ResampleResult.R @@ -13,7 +13,7 @@ #' @template param_measures #' #' @section S3 Methods: -#' * `as.data.table(rr, reassemble_learners = TRUE, convert_predictions = TRUE, predict_sets = "test", data_extra = FALSE)`\cr +#' * `as.data.table(rr, reassemble_learners = TRUE, convert_predictions = TRUE, predict_sets = "test")`\cr #' [ResampleResult] -> [data.table::data.table()]\cr #' Returns a tabular view of the internal data. #' * `c(...)`\cr @@ -377,11 +377,11 @@ ResampleResult = R6Class("ResampleResult", ) #' @export -as.data.table.ResampleResult = function(x, ..., predict_sets = "test", data_extra = FALSE) { # nolint +as.data.table.ResampleResult = function(x, ..., predict_sets = "test") { # nolint private = get_private(x) tab = private$.data$as_data_table(view = private$.view, predict_sets = predict_sets) cns = c("task", "learner", "resampling", "iteration", "prediction") - if (data_extra && "data_extra" %in% names(tab)) cns = c(cns, "data_extra") + if ("data_extra" %in% names(tab)) cns = c(cns, "data_extra") tab[, cns, with = FALSE] } diff --git a/man/BenchmarkResult.Rd b/man/BenchmarkResult.Rd index 32413532b..0e1f18d9a 100644 --- a/man/BenchmarkResult.Rd +++ b/man/BenchmarkResult.Rd @@ -20,7 +20,7 @@ Do not modify any extracted object without cloning it first. \section{S3 Methods}{ \itemize{ -\item \code{as.data.table(rr, ..., reassemble_learners = TRUE, convert_predictions = TRUE, predict_sets = "test", task_characteristics = FALSE, data_extra = FALSE)}\cr +\item \code{as.data.table(rr, ..., reassemble_learners = TRUE, convert_predictions = TRUE, predict_sets = "test", task_characteristics = FALSE)}\cr \link{BenchmarkResult} -> \code{\link[data.table:data.table]{data.table::data.table()}}\cr Returns a tabular view of the internal data. \item \code{c(...)}\cr diff --git a/man/ResampleResult.Rd b/man/ResampleResult.Rd index d7061a6bb..39b584d6f 100644 --- a/man/ResampleResult.Rd +++ b/man/ResampleResult.Rd @@ -14,7 +14,7 @@ Do not modify any object without cloning it first. \section{S3 Methods}{ \itemize{ -\item \code{as.data.table(rr, reassemble_learners = TRUE, convert_predictions = TRUE, predict_sets = "test", data_extra = FALSE)}\cr +\item \code{as.data.table(rr, reassemble_learners = TRUE, convert_predictions = TRUE, predict_sets = "test")}\cr \link{ResampleResult} -> \code{\link[data.table:data.table]{data.table::data.table()}}\cr Returns a tabular view of the internal data. 
\item \code{c(...)}\cr diff --git a/tests/testthat/test_CallbackResample.R b/tests/testthat/test_CallbackResample.R index 248a7a3b1..29f7da144 100644 --- a/tests/testthat/test_CallbackResample.R +++ b/tests/testthat/test_CallbackResample.R @@ -119,14 +119,14 @@ test_that("writing to data_extra works", { }) # resample result data.table - tab = as.data.table(rr, data_extra = TRUE) + tab = as.data.table(rr) expect_data_table(tab) expect_names(names(tab), must.include = "data_extra") # benchmark data.table design = benchmark_grid(task, learner, resampling) bmr = benchmark(design, callbacks = callback) - tab = as.data.table(bmr, data_extra = TRUE) + tab = as.data.table(bmr) expect_names(names(tab), must.include = "data_extra") expect_list(tab$data_extra) walk(tab$data_extra, function(x) { @@ -146,13 +146,13 @@ test_that("data_extra is a list column", { ) rr = resample(task, learner, resampling, callbacks = callback) - expect_list(as.data.table(rr, data_extra = TRUE)$data_extra, len = 1) - expect_list(as.data.table(rr, data_extra = TRUE)$data_extra[[1]], len = 1) + expect_list(as.data.table(rr)$data_extra, len = 1) + expect_list(as.data.table(rr)$data_extra[[1]], len = 1) resampling = rsmp("cv", folds = 3) rr = resample(task, learner, resampling, callbacks = callback) - expect_list(as.data.table(rr, data_extra = TRUE)$data_extra, len = 3) - expect_list(as.data.table(rr, data_extra = TRUE)$data_extra[[1]], len = 1) + expect_list(as.data.table(rr)$data_extra, len = 3) + expect_list(as.data.table(rr)$data_extra[[1]], len = 1) }) test_that("data_extra is null", { @@ -171,14 +171,14 @@ test_that("data_extra is null", { expect_null(rr$data_extra) # resample result data.table - tab = as.data.table(rr, data_extra = TRUE) + tab = as.data.table(rr) expect_data_table(tab) expect_names(names(tab), disjunct.from = "data_extra") # benchmark data.table design = benchmark_grid(task, learner, resampling) bmr = benchmark(design, callbacks = callback) - tab = as.data.table(bmr, data_extra = TRUE) + tab = as.data.table(bmr) expect_data_table(tab) expect_names(names(tab), disjunct.from = "data_extra") }) From e6ffb33c927d1c88c7823b325077fbeef2661fad Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 20 Dec 2024 11:36:23 +0100 Subject: [PATCH 49/54] ... --- R/CallbackResample.R | 13 ++++++------- man/CallbackResample.Rd | 2 +- man/callback_resample.Rd | 11 +++++------ pkgdown/_pkgdown.yml | 4 +++- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/R/CallbackResample.R b/R/CallbackResample.R index a7376cbac..0f8def9b9 100644 --- a/R/CallbackResample.R +++ b/R/CallbackResample.R @@ -1,4 +1,4 @@ -#' @title Evaluation Callback +#' @title Resample Callback #' #' @description #' Specialized [mlr3misc::Callback] to customize the behavior of [resample()] and [benchmark()] in mlr3. 
@@ -88,18 +88,17 @@ CallbackResample = R6Class("CallbackResample", #' #' @export #' @examples -#' callback = callback_resample("selected_features", -#' label = "Selected Features", +#' task = tsk("pima") +#' learner = lrn("classif.rpart") +#' resampling = rsmp("cv", folds = 3) #' +#' # save selected features callback +#' callback = callback_resample("selected_features", #' on_resample_end = function(callback, context) { #' context$learner$state$selected_features = context$learner$selected_features() #' } #' ) #' -#' task = tsk("pima") -#' learner = lrn("classif.rpart") -#' resampling = rsmp("cv", folds = 3) -#' #' rr = resample(task, learner, resampling, callbacks = callback) #' #' rr$learners[[1]]$state$selected_features diff --git a/man/CallbackResample.Rd b/man/CallbackResample.Rd index 845fb8035..567bfe5db 100644 --- a/man/CallbackResample.Rd +++ b/man/CallbackResample.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/CallbackResample.R \name{CallbackResample} \alias{CallbackResample} -\title{Evaluation Callback} +\title{Resample Callback} \description{ Specialized \link[mlr3misc:Callback]{mlr3misc::Callback} to customize the behavior of \code{\link[=resample]{resample()}} and \code{\link[=benchmark]{benchmark()}} in mlr3. The \code{\link[=callback_resample]{callback_resample()}} function is used to create instances of this class. diff --git a/man/callback_resample.Rd b/man/callback_resample.Rd index 16a55fdf2..1819a6a72 100644 --- a/man/callback_resample.Rd +++ b/man/callback_resample.Rd @@ -72,18 +72,17 @@ A callback can write data to the state (\verb{$state}), e.g. settings that affec We highly discourage changing the task, learner and resampling objects via the callback. } \examples{ -callback = callback_resample("selected_features", - label = "Selected Features", +task = tsk("pima") +learner = lrn("classif.rpart") +resampling = rsmp("cv", folds = 3) +# save selected features callback +callback = callback_resample("selected_features", on_resample_end = function(callback, context) { context$learner$state$selected_features = context$learner$selected_features() } ) -task = tsk("pima") -learner = lrn("classif.rpart") -resampling = rsmp("cv", folds = 3) - rr = resample(task, learner, resampling, callbacks = callback) rr$learners[[1]]$state$selected_features diff --git a/pkgdown/_pkgdown.yml b/pkgdown/_pkgdown.yml index b086f8365..ebc36febf 100644 --- a/pkgdown/_pkgdown.yml +++ b/pkgdown/_pkgdown.yml @@ -117,7 +117,9 @@ reference: - ContextResample - callback_resample - assert_resample_callback - - mlr3.score_measures + - assert_resample_callbacks + - mlr3.model_extractor + - mlr3.holdout_task - title: Internal Objects and Functions contents: - marshaling From 288da828509fa079172228bd944bd8f6248836ec Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 20 Dec 2024 11:42:00 +0100 Subject: [PATCH 50/54] ... 
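A rough sketch of how the two predefined callbacks exercised by the new tests can be used; it mirrors the tests below, and the 70/30 split, task and learner are illustrative only:

    library(mlr3)

    task = tsk("pima")
    task_holdout = task$clone()
    splits = partition(task, 0.7)
    task$filter(splits$train)
    task_holdout$filter(splits$test)

    # collect the selected features of every fitted model in data_extra
    extractor = clbk("mlr3.model_extractor",
      fun = function(learner) list(selected_features = learner$selected_features()))

    # predict on the held-out task after each training step
    holdout = clbk("mlr3.holdout_task", task = task_holdout)

    rr = resample(task, lrn("classif.rpart"), rsmp("cv", folds = 3), callbacks = extractor)
    rr$data_extra[[1]]$selected_features

    rr = resample(task, lrn("classif.rpart"), rsmp("cv", folds = 3), callbacks = holdout)
    rr$data_extra[[1]]$prediction_holdout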
--- tests/testthat/test_mlr_callbacks.R | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/tests/testthat/test_mlr_callbacks.R b/tests/testthat/test_mlr_callbacks.R index 47681a451..b51711fb2 100644 --- a/tests/testthat/test_mlr_callbacks.R +++ b/tests/testthat/test_mlr_callbacks.R @@ -1,25 +1,37 @@ -test_that("score_measure works", { +test_that("model extractor works", { task = tsk("pima") learner = lrn("classif.rpart") resampling = rsmp("cv", folds = 3) - callback = clbk("mlr3.score_measures", measures = msr("selected_features")) + selected_features = function(learner) list(selected_features = learner$selected_features()) + callback = clbk("mlr3.model_extractor", fun = selected_features) rr = resample(task, learner, resampling = resampling, callbacks = callback) expect_list(rr$data_extra) walk(rr$data_extra, function(data) { - expect_names(names(data), must.include = "score_measures") - expect_names(names(data[["score_measures"]]), must.include = "selected_features") + expect_names(names(data), must.include = "selected_features") + expect_subset(data[["selected_features"]], task$feature_names) }) +}) + +test_that("holdout task works", { + task = tsk("pima") + task_holdout = task$clone() + learner = lrn("classif.rpart") + resampling = rsmp("cv", folds = 3) + splits = partition(task, 0.7) + + task$filter(splits$train) + task_holdout$filter(splits$test) - callback = clbk("mlr3.score_measures", measures = msrs(c("classif.ce", "selected_features"))) + callback = clbk("mlr3.holdout_task", task = task_holdout) rr = resample(task, learner, resampling = resampling, callbacks = callback) expect_list(rr$data_extra) walk(rr$data_extra, function(data) { - expect_names(names(data), must.include = "score_measures") - expect_names(names(data[["score_measures"]]), must.include = c("classif.ce", "selected_features")) + expect_names(names(data), must.include = "prediction_holdout") + expect_prediction(data[["prediction_holdout"]]) }) }) From aa959983d632b74b838c6cc08aac1fc4c759a795 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 20 Dec 2024 11:43:10 +0100 Subject: [PATCH 51/54] ... --- R/mlr_callbacks.R | 2 +- man/{mlr3.holdout_set.Rd => mlr3.holdout_task.Rd} | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) rename man/{mlr3.holdout_set.Rd => mlr3.holdout_task.Rd} (92%) diff --git a/R/mlr_callbacks.R b/R/mlr_callbacks.R index 912ffba81..9dd015e06 100644 --- a/R/mlr_callbacks.R +++ b/R/mlr_callbacks.R @@ -45,7 +45,7 @@ load_callback_model_extractor = function() { #' @title Callback Holdout Task #' #' @include CallbackResample.R -#' @name mlr3.holdout_set +#' @name mlr3.holdout_task #' #' @description #' This [CallbackResample] predicts on an additional holdout task after training. diff --git a/man/mlr3.holdout_set.Rd b/man/mlr3.holdout_task.Rd similarity index 92% rename from man/mlr3.holdout_set.Rd rename to man/mlr3.holdout_task.Rd index 8a61993b5..067ee04e8 100644 --- a/man/mlr3.holdout_set.Rd +++ b/man/mlr3.holdout_task.Rd @@ -1,7 +1,7 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/mlr_callbacks.R -\name{mlr3.holdout_set} -\alias{mlr3.holdout_set} +\name{mlr3.holdout_task} +\alias{mlr3.holdout_task} \title{Callback Holdout Task} \arguments{ \item{task}{(\link{Task})\cr From 92b9ddde4a596e3f6aae5341f151d9e609cf774b Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 20 Dec 2024 12:26:23 +0100 Subject: [PATCH 52/54] ... 
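ContextResample is the object every stage of a callback_resample() callback receives. A small sketch of a custom callback reading from it; the callback id and the extracted fields are illustrative, the context fields are the ones covered by the tests:

    callback = callback_resample("fold_info",
      on_resample_begin = function(callback, context) {
        # task, resampling and iteration are read-only views of the current split
        callback$state$n_train = length(context$resampling$train_set(context$iteration))
      },
      on_resample_end = function(callback, context) {
        # the fitted model is still available at this stage
        context$data_extra = list(
          n_train = callback$state$n_train,
          model_class = class(context$learner$model)[1L]
        )
      }
    )

    rr = resample(tsk("pima"), lrn("classif.rpart"), rsmp("cv", folds = 3), callbacks = callback)
    rr$data_extra[[1]]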
--- DESCRIPTION | 2 +- R/{ContextEvaluation.R => ContextResample.R} | 2 +- man/ContextResample.Rd | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) rename R/{ContextEvaluation.R => ContextResample.R} (99%) diff --git a/DESCRIPTION b/DESCRIPTION index b44d54702..9823e439f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -80,7 +80,7 @@ Collate: 'mlr_reflections.R' 'BenchmarkResult.R' 'CallbackResample.R' - 'ContextEvaluation.R' + 'ContextResample.R' 'warn_deprecated.R' 'DataBackend.R' 'DataBackendCbind.R' diff --git a/R/ContextEvaluation.R b/R/ContextResample.R similarity index 99% rename from R/ContextEvaluation.R rename to R/ContextResample.R index 0028b6700..45f48f931 100644 --- a/R/ContextEvaluation.R +++ b/R/ContextResample.R @@ -1,4 +1,4 @@ -#' @title Evaluation Context +#' @title Resample Context #' #' @description #' A [CallbackResample] accesses and modifies data during [resample()] and [benchmark()] via the `ContextResample`. diff --git a/man/ContextResample.Rd b/man/ContextResample.Rd index e679fc68f..d0831b177 100644 --- a/man/ContextResample.Rd +++ b/man/ContextResample.Rd @@ -1,8 +1,8 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/ContextEvaluation.R +% Please edit documentation in R/ContextResample.R \name{ContextResample} \alias{ContextResample} -\title{Evaluation Context} +\title{Resample Context} \description{ A \link{CallbackResample} accesses and modifies data during \code{\link[=resample]{resample()}} and \code{\link[=benchmark]{benchmark()}} via the \code{ContextResample}. See the section on fields for a list of modifiable objects. From 41028cd346976c748332a19b4d076151ed76efc6 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 20 Dec 2024 12:35:43 +0100 Subject: [PATCH 53/54] ... --- R/ResultData.R | 9 ++++----- R/worker.R | 45 +++++++++++++++++++++++---------------------- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/R/ResultData.R b/R/ResultData.R index 21c3681e5..5d1322e3b 100644 --- a/R/ResultData.R +++ b/R/ResultData.R @@ -29,13 +29,12 @@ ResultData = R6Class("ResultData", #' An alternative construction method is provided by [as_result_data()]. #' #' @param data ([data.table::data.table()]) | `NULL`)\cr - #' Do not initialize this object yourself, use [as_result_data()] instead. + #' Do not initialize this object yourself, use [as_result_data()] instead. #' @param data_extra (`list()`)\cr - #' Additional data to store. - #' This can be used to store additional information for each iteration. - #' + #' Additional data to store. + #' This can be used to store additional information for each iteration. #' @param store_backends (`logical(1)`)\cr - #' If set to `FALSE`, the backends of the [Task]s provided in `data` are removed. + #' If set to `FALSE`, the backends of the [Task]s provided in `data` are removed. 
initialize = function(data = NULL, data_extra = NULL, store_backends = TRUE) { assert_flag(store_backends) diff --git a/R/worker.R b/R/worker.R index bfb985c06..bef53b538 100644 --- a/R/worker.R +++ b/R/worker.R @@ -270,9 +270,9 @@ workhorse = function( call_back("on_resample_begin", callbacks, ctx) if (!is.null(pb)) { - pb(sprintf("%s|%s|i:%i", ctx$task$id, ctx$learner$id, ctx$iteration)) + pb(sprintf("%s|%s|i:%i", task$id, learner$id, iteration)) } - if ("internal_valid" %in% ctx$learner$predict_sets && is.null(ctx$task$internal_valid_task) && is.null(get0("validate", ctx$learner))) { + if ("internal_valid" %in% learner$predict_sets && is.null(task$internal_valid_task) && is.null(get0("validate", learner))) { stopf("Cannot set the predict_type field of learner '%s' to 'internal_valid' if there is no internal validation task configured", learner$id) } @@ -306,74 +306,75 @@ workhorse = function( } lg$info("%s learner '%s' on task '%s' (iter %i/%i)", - if (mode == "train") "Applying" else "Hotstarting", ctx$learner$id, ctx$task$id, ctx$iteration, ctx$resampling$iters) + if (mode == "train") "Applying" else "Hotstarting", learner$id, task$id, iteration, resampling$iters) sets = list( - train = ctx$resampling$train_set(ctx$iteration), - test = ctx$resampling$test_set(ctx$iteration) + train = resampling$train_set(iteration), + test = resampling$test_set(iteration) ) # train model - ctx$learner = ctx$learner$clone() + # use `learner` reference instead of `ctx$learner` to avoid going through the active binding + ctx$learner = learner = ctx$learner$clone() if (length(param_values)) { - ctx$learner$param_set$values = list() - ctx$learner$param_set$set_values(.values = param_values) + learner$param_set$values = list() + learner$param_set$set_values(.values = param_values) } - learner_hash = ctx$learner$hash + learner_hash = learner$hash - validate = get0("validate", ctx$learner) + validate = get0("validate", learner) test_set = if (identical(validate, "test")) sets$test call_back("on_resample_before_train", callbacks, ctx) - train_result = learner_train(ctx$learner, ctx$task, sets[["train"]], test_set, mode = mode) - ctx$learner = train_result$learner + train_result = learner_train(learner, task, sets[["train"]], test_set, mode = mode) + ctx$learner = learner = train_result$learner # process the model so it can be used for prediction (e.g. 
marshal for callr prediction), but also # keep a copy of the model in current form in case this is the format that we want to send back to the main process # and not the format that we need for prediction model_copy_or_null = process_model_before_predict( - learner = ctx$learner, store_models = store_models, is_sequential = is_sequential, unmarshal = unmarshal + learner = learner, store_models = store_models, is_sequential = is_sequential, unmarshal = unmarshal ) # predict for each set - predict_sets = ctx$learner$predict_sets + predict_sets = learner$predict_sets # creates the tasks and row_ids for all selected predict sets - pred_data = prediction_tasks_and_sets(ctx$task, train_result, validate, sets, predict_sets) + pred_data = prediction_tasks_and_sets(task, train_result, validate, sets, predict_sets) call_back("on_resample_before_predict", callbacks, ctx) pdatas = Map(function(set, row_ids, task) { lg$debug("Creating Prediction for predict set '%s'", set) - learner_predict(ctx$learner, task, row_ids) + learner_predict(learner, task, row_ids) }, set = predict_sets, row_ids = pred_data$sets, task = pred_data$tasks) if (!length(predict_sets)) { - ctx$learner$state$predict_time = 0L + learner$state$predict_time = 0L } ctx$pdatas = discard(pdatas, is.null) # set the model slot after prediction so it can be sent back to the main process process_model_after_predict( - learner = ctx$learner, store_models = store_models, is_sequential = is_sequential, model_copy = model_copy_or_null, + learner = learner, store_models = store_models, is_sequential = is_sequential, model_copy = model_copy_or_null, unmarshal = unmarshal ) call_back("on_resample_end", callbacks, ctx) if (!store_models) { - lg$debug("Erasing stored model for learner '%s'", ctx$learner$id) - ctx$learner$state$model = NULL + lg$debug("Erasing stored model for learner '%s'", learner$id) + learner$state$model = NULL } - learner_state = set_class(ctx$learner$state, c("learner_state", "list")) + learner_state = set_class(learner$state, c("learner_state", "list")) list( learner_state = learner_state, prediction = ctx$pdatas, - param_values = ctx$learner$param_set$values, + param_values = learner$param_set$values, learner_hash = learner_hash, data_extra = ctx$data_extra) } From adc95f84e50ba279060962625fa053e1d67e7406 Mon Sep 17 00:00:00 2001 From: be-marc Date: Fri, 20 Dec 2024 12:46:50 +0100 Subject: [PATCH 54/54] ... 
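The new test pins down that the learner exposed through the context is cloned between on_resample_begin and on_resample_before_train (the object addresses differ), so callback$state, as used in the test, is a convenient place to carry data from the first stage to later ones. A minimal sketch of that pattern; the callback id and the timing field are illustrative:

    callback = callback_resample("timer",
      on_resample_begin = function(callback, context) {
        callback$state$start = Sys.time()
      },
      on_resample_end = function(callback, context) {
        # report the per-iteration wall-clock time via data_extra
        context$data_extra = list(
          elapsed = as.numeric(difftime(Sys.time(), callback$state$start, units = "secs"))
        )
      }
    )

    rr = resample(tsk("pima"), lrn("classif.rpart"), rsmp("holdout"), callbacks = callback)
    rr$data_extra[[1]]$elapsed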
--- R/mlr_callbacks.R | 2 -- tests/testthat/test_CallbackResample.R | 26 ++++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/R/mlr_callbacks.R b/R/mlr_callbacks.R index 9dd015e06..4fa05a40c 100644 --- a/R/mlr_callbacks.R +++ b/R/mlr_callbacks.R @@ -76,8 +76,6 @@ load_callback_holdout_task = function() { man = "mlr3::mlr3.holdout_task", on_resample_before_predict = function(callback, context) { - assert_task(callback$state$task) - pred = context$learner$predict(callback$state$task) context$data_extra = list(prediction_holdout = pred) } diff --git a/tests/testthat/test_CallbackResample.R b/tests/testthat/test_CallbackResample.R index 29f7da144..21c2caa82 100644 --- a/tests/testthat/test_CallbackResample.R +++ b/tests/testthat/test_CallbackResample.R @@ -183,3 +183,29 @@ test_that("data_extra is null", { expect_names(names(tab), disjunct.from = "data_extra") }) +test_that("learner cloning in workhorse is passed to context", { + task = tsk("pima") + learner = lrn("classif.rpart") + resampling = rsmp("holdout") + + callback = callback_resample("test", + on_resample_begin = function(callback, context) { + callback$state$address_1 = data.table::address(context$learner) + }, + + on_resample_before_train = function(callback, context) { + callback$state$address_2 = data.table::address(context$learner) + }, + + on_resample_end = function(callback, context) { + context$data_extra = list( + address_1 = callback$state$address_1, + address_2 = callback$state$address_2 + ) + } + ) + + rr = resample(task, learner, resampling, callbacks = callback) + + expect_true(rr$data_extra[[1]]$address_1 != rr$data_extra[[1]]$address_2) +})