diff --git a/DESCRIPTION b/DESCRIPTION index 16060dae1..9823e439f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -79,6 +79,8 @@ RoxygenNote: 7.3.2 Collate: 'mlr_reflections.R' 'BenchmarkResult.R' + 'CallbackResample.R' + 'ContextResample.R' 'warn_deprecated.R' 'DataBackend.R' 'DataBackendCbind.R' @@ -189,6 +191,7 @@ Collate: 'helper_print.R' 'install_pkgs.R' 'marshal.R' + 'mlr_callbacks.R' 'mlr_sugar.R' 'mlr_test_helpers.R' 'partition.R' diff --git a/NAMESPACE b/NAMESPACE index 8946cdd5e..00f462a51 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -120,6 +120,8 @@ S3method(unmarshal_model,classif.debug_model_marshaled) S3method(unmarshal_model,default) S3method(unmarshal_model,learner_state_marshaled) export(BenchmarkResult) +export(CallbackResample) +export(ContextResample) export(DataBackend) export(DataBackendDataTable) export(DataBackendMatrix) @@ -207,6 +209,8 @@ export(assert_measure) export(assert_measures) export(assert_predictable) export(assert_prediction) +export(assert_resample_callback) +export(assert_resample_callbacks) export(assert_resample_result) export(assert_resampling) export(assert_resamplings) @@ -218,7 +222,10 @@ export(assert_validate) export(auto_convert) export(benchmark) export(benchmark_grid) +export(callback_resample) export(check_prediction_data) +export(clbk) +export(clbks) export(col_info) export(convert_task) export(create_empty_prediction_data) @@ -236,6 +243,7 @@ export(learner_unmarshal) export(lrn) export(lrns) export(marshal_model) +export(mlr_callbacks) export(mlr_learners) export(mlr_measures) export(mlr_reflections) @@ -269,6 +277,9 @@ importFrom(data.table,data.table) importFrom(future,nbrOfWorkers) importFrom(future,plan) importFrom(graphics,plot) +importFrom(mlr3misc,clbk) +importFrom(mlr3misc,clbks) +importFrom(mlr3misc,mlr_callbacks) importFrom(parallelly,availableCores) importFrom(stats,contr.treatment) importFrom(stats,model.frame) diff --git a/R/BenchmarkResult.R b/R/BenchmarkResult.R index 9ee35c3ce..e60d3e204 100644 --- 
a/R/BenchmarkResult.R +++ b/R/BenchmarkResult.R @@ -548,7 +548,9 @@ BenchmarkResult = R6Class("BenchmarkResult", as.data.table.BenchmarkResult = function(x, ..., hashes = FALSE, predict_sets = "test", task_characteristics = FALSE) { # nolint assert_flag(task_characteristics) tab = get_private(x)$.data$as_data_table(view = NULL, predict_sets = predict_sets) - tab = tab[, c("uhash", "task", "learner", "resampling", "iteration", "prediction"), with = FALSE] + cns = c("uhash", "task", "learner", "resampling", "iteration", "prediction") + if ("data_extra" %in% names(tab)) cns = c(cns, "data_extra") + tab = tab[, cns, with = FALSE] if (task_characteristics) { set(tab, j = "characteristics", value = map(tab$task, "characteristics")) diff --git a/R/CallbackResample.R b/R/CallbackResample.R new file mode 100644 index 000000000..0f8def9b9 --- /dev/null +++ b/R/CallbackResample.R @@ -0,0 +1,155 @@ +#' @title Resample Callback +#' +#' @description +#' Specialized [mlr3misc::Callback] to customize the behavior of [resample()] and [benchmark()] in mlr3. +#' The [callback_resample()] function is used to create instances of this class. +#' Predefined callbacks are stored in the [dictionary][mlr3misc::Dictionary] [mlr_callbacks] and can be retrieved with [clbk()]. +#' For more information on callbacks, see the [callback_resample()] documentation. +#' +#' @export +CallbackResample = R6Class("CallbackResample", + inherit = Callback, + public = list( + + #' @field on_resample_begin (`function()`)\cr + #' Stage called at the beginning of the resampling iteration. + #' Called in `workhorse()` (internal). + on_resample_begin = NULL, + + #' @field on_resample_before_train (`function()`)\cr + #' Stage called before training the learner. + #' Called in `workhorse()` (internal). + on_resample_before_train = NULL, + + #' @field on_resample_before_predict (`function()`)\cr + #' Stage called before predicting. + #' Called in `workhorse()` (internal). 
+ on_resample_before_predict = NULL, + + #' @field on_resample_end (`function()`)\cr + #' Stage called at the end of the resample iteration. + #' Called in `workhorse()` (internal). + on_resample_end = NULL + ) +) + +#' @title Create Evaluation Callback +#' +#' @description +#' Function to create a [CallbackResample]. +#' Predefined callbacks are stored in the [dictionary][mlr3misc::Dictionary] [mlr_callbacks] and can be retrieved with [clbk()]. +#' +#' Evaluation callbacks are called at different stages of the resampling process. +#' Each stage is called once per resampling iteration. +#' The stages are prefixed with `on_resample_*`. +#' The text in brackets indicates what happens between the stages and which accesses to the [ContextResample] (`ctx`) are typical for the stage. +#' +#' ``` +#' Start Resampling Iteration on Worker +#' - on_resample_begin +#' (Split `ctx$task` into training and test set with `ctx$resampling` and `ctx$iteration`) +#' - on_resample_before_train +#' (Train the learner `ctx$learner` on training data) +#' - on_resample_before_predict +#' (Predict on predict sets and store prediction data `ctx$pdatas`) +#' - on_resample_end +#' (Erase model `ctx$learner$model` if requested and return results) +#' End Resampling Iteration on Worker +#' ``` +#' +#' The callback can store data in `ctx$learner$state` or `ctx$data_extra`. +#' The data in `ctx$data_extra` is stored in the [ResampleResult] or [BenchmarkResult]. +#' See also the section on parameters for more information on the stages. +# +#' @details +#' When implementing a callback, each function must have two arguments named `callback` and `context`. +#' A callback can write data to the state (`$state`), e.g. settings that affect the callback itself. +#' We highly discourage changing the task, learner and resampling objects via the callback. +#' +#' @param id (`character(1)`)\cr +#' Identifier for the new instance. +#' @param label (`character(1)`)\cr +#' Label for the new instance. 
+#' @param man (`character(1)`)\cr +#' String in the format `[pkg]::[topic]` pointing to a manual page for this object. +#' The referenced help package can be opened via method `$help()`. +#' @param on_resample_begin (`function()`)\cr +#' Stage called at the beginning of an evaluation. +#' Called in `workhorse()` (internal). +#' @param on_resample_before_train (`function()`)\cr +#' Stage called before training the learner. +#' Called in `workhorse()` (internal). +#' @param on_resample_before_predict (`function()`)\cr +#' Stage called before predicting. +#' Called in `workhorse()` (internal). +#' @param on_resample_end (`function()`)\cr +#' Stage called at the end of an evaluation. +#' Called in `workhorse()` (internal). +#' +#' @export +#' @examples +#' task = tsk("pima") +#' learner = lrn("classif.rpart") +#' resampling = rsmp("cv", folds = 3) +#' +#' # save selected features callback +#' callback = callback_resample("selected_features", +#' on_resample_end = function(callback, context) { +#' context$learner$state$selected_features = context$learner$selected_features() +#' } +#' ) +#' +#' rr = resample(task, learner, resampling, callbacks = callback) +#' +#' rr$learners[[1]]$state$selected_features +callback_resample = function( + id, + label = NA_character_, + man = NA_character_, + on_resample_begin = NULL, + on_resample_before_train = NULL, + on_resample_before_predict = NULL, + on_resample_end = NULL + ) { + stages = discard(set_names(list( + on_resample_begin, + on_resample_before_train, + on_resample_before_predict, + on_resample_end), + c( + "on_resample_begin", + "on_resample_before_train", + "on_resample_before_predict", + "on_resample_end" + )), is.null) + + stages = map(stages, function(stage) crate(assert_function(stage, args = c("callback", "context")))) + callback = CallbackResample$new(id, label, man) + iwalk(stages, function(stage, name) callback[[name]] = stage) + callback +} + +#' @title Assertions for Callbacks +#' +#' @description +#' 
Assertions for [CallbackResample] class. +#' +#' @param callback ([CallbackResample]). +#' @param null_ok (`logical(1)`)\cr +#' If `TRUE`, `NULL` is allowed. +#' +#' @return [CallbackResample] | List of [CallbackResample]s. +#' @export +assert_resample_callback = function(callback, null_ok = FALSE) { + assert_class(callback, "CallbackResample", null.ok = null_ok) + invisible(callback) +} + +#' @export +#' @param callbacks (list of [CallbackResample]). +#' @rdname assert_resample_callback +assert_resample_callbacks = function(callbacks, null_ok = FALSE) { + assert_list(callbacks, null.ok = null_ok) + if (null_ok && is.null(callbacks)) return(invisible(NULL)) + invisible(lapply(callbacks, assert_resample_callback)) +} diff --git a/R/ContextResample.R b/R/ContextResample.R new file mode 100644 index 000000000..45f48f931 --- /dev/null +++ b/R/ContextResample.R @@ -0,0 +1,103 @@ +#' @title Resample Context +#' +#' @description +#' A [CallbackResample] accesses and modifies data during [resample()] and [benchmark()] via the `ContextResample`. +#' See the section on fields for a list of modifiable objects. +#' See [callback_resample()] for a list of stages that access `ContextResample`. +#' +#' @export +ContextResample = R6Class("ContextResample", + inherit = Context, + public = list( + + #' @description + #' Creates a new instance of this [R6][R6::R6Class] class. + #' + #' @param task ([Task])\cr + #' The task to be evaluated. + #' @param learner ([Learner])\cr + #' The learner to be evaluated. + #' @param resampling ([Resampling])\cr + #' The resampling strategy to be used. + #' @param iteration (`integer()`)\cr + #' The current iteration. 
+ initialize = function(task, learner, resampling, iteration) { + # no assertions to avoid overhead + private$.task = task + private$.learner = learner + private$.resampling = resampling + private$.iteration = iteration + + super$initialize(id = "evaluate", label = "Evaluation") + } + ), + + active = list( + + #' @field task ([Task])\cr + #' The task to be evaluated. + #' The task is unchanged during the evaluation. + #' The task is read-only. + task = function(rhs) { + assert_ro_binding(rhs) + private$.task + }, + + #' @field learner ([Learner])\cr + #' The learner to be evaluated. + #' The learner contains the models after stage `on_resample_before_train`. + learner = function(rhs) { + if (missing(rhs)) { + return(private$.learner) + } + private$.learner = assert_learner(rhs) + }, + + #' @field resampling [Resampling]\cr + #' The resampling strategy to be used. + #' The resampling is unchanged during the evaluation. + #' The resampling is read-only. + resampling = function(rhs) { + assert_ro_binding(rhs) + private$.resampling + }, + + #' @field iteration (`integer()`)\cr + #' The current iteration. + #' The iteration is read-only. + iteration = function(rhs) { + assert_ro_binding(rhs) + private$.iteration + }, + + #' @field pdatas (List of [PredictionData])\cr + #' The prediction data. + #' The data is available on stage `on_resample_end`. + pdatas = function(rhs) { + if (missing(rhs)) { + return(private$.pdatas) + } + private$.pdatas = assert_list(rhs, "PredictionData") + }, + + #' @field data_extra (list())\cr + #' Data saved in the [ResampleResult] or [BenchmarkResult]. + #' Use this field to save results. + #' Must be a `list()`. 
+ data_extra = function(rhs) { + if (missing(rhs)) { + return(private$.data_extra) + } + private$.data_extra = assert_list(rhs) + } + ), + + private = list( + .task = NULL, + .learner = NULL, + .resampling = NULL, + .iteration = NULL, + .pdatas = NULL, + .data_extra = NULL + ) +) diff --git a/R/ResampleResult.R b/R/ResampleResult.R index 83b79ada3..92a2a16d3 100644 --- a/R/ResampleResult.R +++ b/R/ResampleResult.R @@ -335,6 +335,13 @@ ResampleResult = R6Class("ResampleResult", private$.data$learners(private$.view)$learner }, + #' @field data_extra (list())\cr + #' Additional data stored in the [ResampleResult]. + data_extra = function(rhs) { + assert_ro_binding(rhs) + private$.data$data_extra(private$.view) + }, + #' @field warnings ([data.table::data.table()])\cr #' A table with all warning messages. #' Column names are `"iteration"` and `"msg"`. @@ -373,7 +380,9 @@ ResampleResult = R6Class("ResampleResult", as.data.table.ResampleResult = function(x, ..., predict_sets = "test") { # nolint private = get_private(x) tab = private$.data$as_data_table(view = private$.view, predict_sets = predict_sets) - tab[, c("task", "learner", "resampling", "iteration", "prediction"), with = FALSE] + cns = c("task", "learner", "resampling", "iteration", "prediction") + if ("data_extra" %in% names(tab)) cns = c(cns, "data_extra") + tab[, cns, with = FALSE] } # #' @export diff --git a/R/ResultData.R b/R/ResultData.R index a2c74feb0..5d1322e3b 100644 --- a/R/ResultData.R +++ b/R/ResultData.R @@ -18,29 +18,30 @@ #' print(ResultData$new()$data) ResultData = R6Class("ResultData", public = list( + #' @field data (`list()`)\cr - #' List of [data.table::data.table()], arranged in a star schema. - #' Do not operate directly on this list. + #' List of [data.table::data.table()], arranged in a star schema. + #' Do not operate directly on this list. data = NULL, - #' @description #' Creates a new instance of this [R6][R6::R6Class] class. 
#' An alternative construction method is provided by [as_result_data()]. #' #' @param data ([data.table::data.table()]) | `NULL`)\cr #' Do not initialize this object yourself, use [as_result_data()] instead. + #' @param data_extra (`list()`)\cr + #' Additional data to store. + #' This can be used to store additional information for each iteration. #' @param store_backends (`logical(1)`)\cr - #' If set to `FALSE`, the backends of the [Task]s provided in `data` are - #' removed. - initialize = function(data = NULL, store_backends = TRUE) { + #' If set to `FALSE`, the backends of the [Task]s provided in `data` are removed. + initialize = function(data = NULL, data_extra = NULL, store_backends = TRUE) { assert_flag(store_backends) if (is.null(data)) { self$data = star_init() } else { - assert_names(names(data), - permutation.of = c("task", "learner", "learner_state", "resampling", "iteration", "param_values", "prediction", "uhash", "learner_hash")) + assert_names(names(data), permutation.of = c("task", "learner", "learner_state", "resampling", "iteration", "param_values", "prediction", "uhash", "learner_hash")) if (nrow(data) == 0L) { self$data = star_init() @@ -68,6 +69,12 @@ ResultData = R6Class("ResultData", set(data, j = "resampling", value = NULL) set(data, j = "param_values", value = NULL) + # add extra data to fact table + if (!is.null(data_extra)) { + assert_list(data_extra, len = nrow(data)) + set(data, j = "data_extra", value = list(data_extra)) + } + if (!store_backends) { set(tasks, j = "task", value = lapply(tasks$task, task_rm_backend)) } @@ -189,6 +196,18 @@ ResultData = R6Class("ResultData", do.call(c, self$predictions(view = view, predict_sets = predict_sets)) }, + #' @description + #' Returns additional data stored. + #' + #' @return `list()`. 
+ data_extra = function(view = NULL) { + if ("data_extra" %nin% names(self$data$fact)) { + return(NULL) + } + .__ii__ = private$get_view_index(view) + self$data$fact[.__ii__, "data_extra", with = FALSE][[1L]] + }, + #' @description #' Combines multiple [ResultData] objects, modifying `self` in-place. #' @@ -315,7 +334,7 @@ ResultData = R6Class("ResultData", } cns = c("uhash", "task", "task_hash", "learner", "learner_hash", "learner_param_vals", "resampling", - "resampling_hash", "iteration", "prediction") + "resampling_hash", "iteration", "prediction", if ("data_extra" %in% names(self$data$fact)) "data_extra") merge(self$data$uhashes, tab[, cns, with = FALSE], by = "uhash", sort = FALSE) }, @@ -369,13 +388,14 @@ ResultData = R6Class("ResultData", ####################################################################################################################### ### constructor ####################################################################################################################### -star_init = function() { +star_init = function(data_extra = FALSE) { fact = data.table( uhash = character(), iteration = integer(), learner_state = list(), prediction = list(), + learner_hash = character(), task_hash = character(), learner_phash = character(), @@ -384,6 +404,8 @@ star_init = function() { key = c("uhash", "iteration") ) + if (data_extra) fact[, "data_extra" := list()] + uhashes = data.table( uhash = character() ) diff --git a/R/as_result_data.R b/R/as_result_data.R index e264a7d02..bb7f9b648 100644 --- a/R/as_result_data.R +++ b/R/as_result_data.R @@ -15,6 +15,8 @@ #' @param predictions (list of list of [Prediction]s). #' @param learner_states (`list()`)\cr #' Learner states. If not provided, the states of `learners` are automatically extracted. +#' @param data_extra (`list()`)\cr +#' Additional data for each iteration. #' @param store_backends (`logical(1)`)\cr #' If set to `FALSE`, the backends of the [Task]s provided in `data` are #' removed. 
@@ -39,18 +41,27 @@ #' #' rdata = as_result_data(task, learners, resampling, iterations, predictions) #' ResampleResult$new(rdata) -as_result_data = function(task, learners, resampling, iterations, predictions, learner_states = NULL, store_backends = TRUE) { +as_result_data = function( + task, + learners, + resampling, + iterations, + predictions, + learner_states = NULL, + data_extra = NULL, + store_backends = TRUE + ) { assert_task(task) assert_learners(learners, task = task) assert_resampling(resampling, instantiated = TRUE) assert_integer(iterations, any.missing = FALSE, lower = 1L, upper = resampling$iters, unique = TRUE) assert_list(predictions, types = "list") assert_list(learner_states, null.ok = TRUE) + assert_list(data_extra, null.ok = TRUE) predictions = map(predictions, function(x) map(x, as_prediction_data)) N = length(iterations) - if (length(learners) != N) { stopf("Number of learners (%i) must match the number of resampling iterations (%i)", length(learners), N) } @@ -69,6 +80,10 @@ as_result_data = function(task, learners, resampling, iterations, predictions, l stopf("Resampling '%s' has not been trained on task '%s', hashes do not match", resampling$id, task$id) } + if (!is.null(data_extra) && length(data_extra) != N) { + stopf("Length of data_extra (%i) must match the number of resampling iterations (%i)", length(data_extra), N) + } + ResultData$new(data.table( task = list(task), learner = learners, @@ -78,6 +93,7 @@ as_result_data = function(task, learners, resampling, iterations, predictions, l resampling = list(resampling), iteration = iterations, prediction = predictions, - uhash = UUIDgenerate() + uhash = UUIDgenerate(), + data_extra = data_extra ), store_backends = store_backends) } diff --git a/R/benchmark.R b/R/benchmark.R index 5f26c38a6..48d65bccd 100644 --- a/R/benchmark.R +++ b/R/benchmark.R @@ -18,6 +18,7 @@ #' @template param_allow_hotstart #' @template param_clone #' @template param_unmarshal +#' @template param_callbacks #' #' 
@return [BenchmarkResult]. #' @@ -81,7 +82,7 @@ #' ## Get the training set of the 2nd iteration of the featureless learner on penguins #' rr = bmr$aggregate()[learner_id == "classif.featureless"]$resample_result[[1]] #' rr$resampling$train_set(2) -benchmark = function(design, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), unmarshal = TRUE) { +benchmark = function(design, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), unmarshal = TRUE, callbacks = NULL) { assert_subset(clone, c("task", "learner", "resampling")) assert_data_frame(design, min.rows = 1L) assert_names(names(design), must.include = c("task", "learner", "resampling")) @@ -96,6 +97,7 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps } assert_flag(store_models) assert_flag(store_backends) + callbacks = assert_callbacks(as_callbacks(callbacks)) # check for multiple task types task_types = unique(map_chr(design$task, "task_type")) @@ -187,7 +189,7 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps res = future_map(n, workhorse, task = grid$task, learner = grid$learner, resampling = grid$resampling, iteration = grid$iteration, param_values = grid$param_values, mode = grid$mode, - MoreArgs = list(store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal) + MoreArgs = list(store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal, callbacks = callbacks) ) grid = insert_named(grid, list( @@ -201,7 +203,9 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps set(grid, j = "mode", value = NULL) - result_data = ResultData$new(grid, store_backends = store_backends) + data_extra = if (length(callbacks) && any(map_lgl(res, function(x) !is.null(x$data_extra)))) map(res, 
"data_extra") + + result_data = ResultData$new(grid, data_extra, store_backends = store_backends) if (unmarshal && store_models) { result_data$unmarshal() diff --git a/R/mlr_callbacks.R b/R/mlr_callbacks.R new file mode 100644 index 000000000..4fa05a40c --- /dev/null +++ b/R/mlr_callbacks.R @@ -0,0 +1,83 @@ +#' @title Model Extractor Callback +#' +#' @include CallbackResample.R +#' @name mlr3.model_extractor +#' +#' @description +#' This [CallbackResample] extracts information from the model after training with a user-defined function. +#' This way information can be extracted from the model without saving the model (`store_models = FALSE`). +#' The `fun` must be a function that takes a learner as input and returns the extracted information as named list (see example). +#' The callback is very helpful to call `$selected_features()`, `$importance()`, `$oob_error()` on the learner. +#' +#' @param fun (`function(learner)`)\cr +#' Function to extract information from the learner. +#' The function must have the argument `learner`. +#' The function must return a named list. 
+#' +#' @examples +#' task = tsk("pima") +#' learner = lrn("classif.rpart") +#' resampling = rsmp("cv", folds = 3) +#' +#' # define function to extract selected features +#' selected_features = function(learner) list(selected_features = learner$selected_features()) +#' +#' # create callback +#' callback = clbk("mlr3.model_extractor", fun = selected_features) +#' +#' rr = resample(task, learner, resampling = resampling, store_models = FALSE, callbacks = callback) +#' +#' rr$data_extra +NULL + +load_callback_model_extractor = function() { + callback_resample("mlr3.model_extractor", + label = "Model Extractor Callback", + man = "mlr3::mlr3.model_extractor", + + on_resample_end = function(callback, context) { + assert_function(callback$state$fun, args = "learner") + context$data_extra = invoke(callback$state$fun, learner = context$learner) + } + ) +} + +#' @title Callback Holdout Task +#' +#' @include CallbackResample.R +#' @name mlr3.holdout_task +#' +#' @description +#' This [CallbackResample] predicts on an additional holdout task after training. +#' +#' @param task ([Task])\cr +#' The holdout task. 
+#' +#' @examples +#' task = tsk("pima") +#' task_holdout = task$clone() +#' learner = lrn("classif.rpart") +#' resampling = rsmp("cv", folds = 3) +#' splits = partition(task, 0.7) +#' +#' task$filter(splits$train) +#' task_holdout$filter(splits$test) +#' +#' callback = clbk("mlr3.holdout_task", task = task_holdout) +#' +#' rr = resample(task, learner, resampling = resampling, callbacks = callback) +#' +#' rr$data_extra +NULL + +load_callback_holdout_task = function() { + callback_resample("mlr3.holdout_task", + label = "Callback Holdout Task", + man = "mlr3::mlr3.holdout_task", + + on_resample_before_predict = function(callback, context) { + pred = context$learner$predict(callback$state$task) + context$data_extra = list(prediction_holdout = pred) + } + ) +} diff --git a/R/reexports.R b/R/reexports.R index 7fc1b7d49..b80171a9b 100644 --- a/R/reexports.R +++ b/R/reexports.R @@ -3,3 +3,12 @@ data.table::as.data.table #' @export data.table::data.table + +#' @export +mlr3misc::mlr_callbacks + +#' @export +mlr3misc::clbk + +#' @export +mlr3misc::clbks diff --git a/R/resample.R b/R/resample.R index cc1bb88f2..5e7614592 100644 --- a/R/resample.R +++ b/R/resample.R @@ -15,6 +15,7 @@ #' @template param_allow_hotstart #' @template param_clone #' @template param_unmarshal +#' @template param_callbacks #' @return [ResampleResult]. 
#' #' @template section_predict_sets @@ -55,7 +56,21 @@ #' bmr1 = as_benchmark_result(rr) #' bmr2 = as_benchmark_result(rr_featureless) #' print(bmr1$combine(bmr2)) -resample = function(task, learner, resampling, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), unmarshal = TRUE) { +resample = function( + task, + learner, + resampling, + store_models = FALSE, + store_backends = TRUE, + encapsulate = NA_character_, + allow_hotstart = FALSE, + clone = c("task", "learner", "resampling"), + unmarshal = TRUE, + callbacks = NULL + ) { + + lg$debug("Start resampling") + assert_subset(clone, c("task", "learner", "resampling")) task = assert_task(as_task(task, clone = "task" %in% clone)) learner = assert_learner(as_learner(learner, clone = "learner" %in% clone, discard_state = TRUE)) @@ -65,6 +80,7 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe # this does not check the internal validation task as it might not be set yet assert_learnable(task, learner) assert_flag(unmarshal) + callbacks = assert_callbacks(as_callbacks(callbacks)) set_encapsulation(list(learner), encapsulate) if (!resampling$is_instantiated) { @@ -78,6 +94,7 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe } else { NULL } + lgr_threshold = map_int(mlr_reflections$loggers, "threshold") grid = if (allow_hotstart) { @@ -115,7 +132,7 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe } res = future_map(n, workhorse, iteration = seq_len(n), learner = grid$learner, mode = grid$mode, - MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal) + MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal, callbacks = callbacks) ) data = data.table( @@ 
-130,7 +147,9 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe learner_hash = map_chr(res, "learner_hash") ) - result_data = ResultData$new(data, store_backends = store_backends) + data_extra = if (length(callbacks) && any(map_lgl(res, function(x) !is.null(x$data_extra)))) map(res, "data_extra") + + result_data = ResultData$new(data, data_extra, store_backends = store_backends) # the worker already ensures that models are sent back in marshaled form if unmarshal = FALSE, so we don't have # to do anything in this case. This allows us to minimize the amount of marshaling in those situtions where diff --git a/R/worker.R b/R/worker.R index e52daa23f..bef53b538 100644 --- a/R/worker.R +++ b/R/worker.R @@ -251,7 +251,24 @@ learner_predict = function(learner, task, row_ids = NULL) { } -workhorse = function(iteration, task, learner, resampling, param_values = NULL, lgr_threshold, store_models = FALSE, pb = NULL, mode = "train", is_sequential = TRUE, unmarshal = TRUE) { +workhorse = function( + iteration, + task, + learner, + resampling, + param_values = NULL, + lgr_threshold, + store_models = FALSE, + pb = NULL, + mode = "train", + is_sequential = TRUE, + unmarshal = TRUE, + callbacks = NULL + ) { + ctx = ContextResample$new(task, learner, resampling, iteration) + + call_back("on_resample_begin", callbacks, ctx) + if (!is.null(pb)) { pb(sprintf("%s|%s|i:%i", task$id, learner$id, iteration)) } @@ -280,6 +297,7 @@ workhorse = function(iteration, task, learner, resampling, param_values = NULL, }, add = TRUE) } } + # restore logger thresholds for (package in names(lgr_threshold)) { logger = lgr::get_logger(package) @@ -296,7 +314,8 @@ workhorse = function(iteration, task, learner, resampling, param_values = NULL, ) # train model - learner = learner$clone() + # use `learner` reference instead of `ctx$learner` to avoid going through the active binding + ctx$learner = learner = ctx$learner$clone() if (length(param_values)) { 
learner$param_set$values = list() learner$param_set$set_values(.values = param_values) @@ -306,8 +325,11 @@ workhorse = function(iteration, task, learner, resampling, param_values = NULL, validate = get0("validate", learner) test_set = if (identical(validate, "test")) sets$test + + call_back("on_resample_before_train", callbacks, ctx) + train_result = learner_train(learner, task, sets[["train"]], test_set, mode = mode) - learner = train_result$learner + ctx$learner = learner = train_result$learner # process the model so it can be used for prediction (e.g. marshal for callr prediction), but also # keep a copy of the model in current form in case this is the format that we want to send back to the main process @@ -321,16 +343,18 @@ workhorse = function(iteration, task, learner, resampling, param_values = NULL, # creates the tasks and row_ids for all selected predict sets pred_data = prediction_tasks_and_sets(task, train_result, validate, sets, predict_sets) + call_back("on_resample_before_predict", callbacks, ctx) pdatas = Map(function(set, row_ids, task) { lg$debug("Creating Prediction for predict set '%s'", set) + learner_predict(learner, task, row_ids) }, set = predict_sets, row_ids = pred_data$sets, task = pred_data$tasks) if (!length(predict_sets)) { learner$state$predict_time = 0L } - pdatas = discard(pdatas, is.null) + ctx$pdatas = discard(pdatas, is.null) # set the model slot after prediction so it can be sent back to the main process process_model_after_predict( @@ -338,9 +362,21 @@ workhorse = function(iteration, task, learner, resampling, param_values = NULL, unmarshal = unmarshal ) + call_back("on_resample_end", callbacks, ctx) + + if (!store_models) { + lg$debug("Erasing stored model for learner '%s'", learner$id) + learner$state$model = NULL + } + learner_state = set_class(learner$state, c("learner_state", "list")) - list(learner_state = learner_state, prediction = pdatas, param_values = learner$param_set$values, learner_hash = learner_hash) + list( + 
learner_state = learner_state, + prediction = ctx$pdatas, + param_values = learner$param_set$values, + learner_hash = learner_hash, + data_extra = ctx$data_extra) } # creates the tasks and row ids for the selected predict sets @@ -414,13 +450,10 @@ process_model_before_predict = function(learner, store_models, is_sequential, un } process_model_after_predict = function(learner, store_models, is_sequential, unmarshal, model_copy) { - if (!store_models) { - lg$debug("Erasing stored model for learner '%s'", learner$id) - learner$state$model = NULL - } else if (!is.null(model_copy)) { + if (store_models && !is.null(model_copy)) { # we created a copy of the model to avoid additional marshaling cycles learner$model = model_copy - } else if (!is_sequential || !unmarshal) { + } else if (store_models && !is_sequential || !unmarshal) { # no copy was created, here we make sure that we return the model the way the user wants it learner$model = marshal_model(learner$model, inplace = TRUE) } diff --git a/R/zzz.R b/R/zzz.R index b34c219ba..b48ff6eef 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -74,6 +74,11 @@ dummy_import = function() { # nocov start backports::import(pkgname) + # callbacks + x = utils::getFromNamespace("mlr_callbacks", ns = "mlr3misc") + x$add("mlr3.model_extractor", load_callback_model_extractor) + x$add("mlr3.holdout_task", load_callback_holdout_task) + # setup logger lg = lgr::get_logger(pkgname) assign("lg", lg, envir = parent.env(environment())) diff --git a/inst/testthat/helper_expectations.R b/inst/testthat/helper_expectations.R index d6c2ca1b0..55ac9835c 100644 --- a/inst/testthat/helper_expectations.R +++ b/inst/testthat/helper_expectations.R @@ -689,7 +689,7 @@ expect_resultdata = function(rdata, consistency = TRUE) { checkmate::expect_class(rdata, "ResultData") data = rdata$data - proto = mlr3:::star_init() + proto = mlr3:::star_init(data_extra = "data_extra" %in% names(data$fact)) checkmate::expect_set_equal(names(data), names(proto)) for (nn in 
names(proto)) { diff --git a/man-roxygen/param_callbacks.R b/man-roxygen/param_callbacks.R new file mode 100644 index 000000000..85bb82891 --- /dev/null +++ b/man-roxygen/param_callbacks.R @@ -0,0 +1,3 @@ +#' @param callbacks (List of [mlr3misc::Callback])\cr +#' Callbacks to be executed during the resampling process. +#' See [CallbackResample] and [ContextResample] for details. diff --git a/man/CallbackResample.Rd b/man/CallbackResample.Rd new file mode 100644 index 000000000..567bfe5db --- /dev/null +++ b/man/CallbackResample.Rd @@ -0,0 +1,70 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/CallbackResample.R +\name{CallbackResample} +\alias{CallbackResample} +\title{Resample Callback} +\description{ +Specialized \link[mlr3misc:Callback]{mlr3misc::Callback} to customize the behavior of \code{\link[=resample]{resample()}} and \code{\link[=benchmark]{benchmark()}} in mlr3. +The \code{\link[=callback_resample]{callback_resample()}} function is used to create instances of this class. +Predefined callbacks are stored in the \link[mlr3misc:Dictionary]{dictionary} \link{mlr_callbacks} and can be retrieved with \code{\link[=clbk]{clbk()}}. +For more information on callbacks, see the \code{\link[=callback_resample]{callback_resample()}} documentation. +} +\section{Super class}{ +\code{\link[mlr3misc:Callback]{mlr3misc::Callback}} -> \code{CallbackResample} +} +\section{Public fields}{ +\if{html}{\out{