diff --git a/R/resample.R b/R/resample.R index cc1bb88f2..ea75b228e 100644 --- a/R/resample.R +++ b/R/resample.R @@ -15,6 +15,7 @@ #' @template param_allow_hotstart #' @template param_clone #' @template param_unmarshal +#' @template param_extractor #' @return [ResampleResult]. #' #' @template section_predict_sets @@ -55,7 +56,18 @@ #' bmr1 = as_benchmark_result(rr) #' bmr2 = as_benchmark_result(rr_featureless) #' print(bmr1$combine(bmr2)) -resample = function(task, learner, resampling, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), unmarshal = TRUE) { +resample = function( + task, + learner, + resampling, + store_models = FALSE, + store_backends = TRUE, + encapsulate = NA_character_, + allow_hotstart = FALSE, + clone = c("task", "learner", "resampling"), + unmarshal = TRUE, + extractor = NULL + ) { assert_subset(clone, c("task", "learner", "resampling")) task = assert_task(as_task(task, clone = "task" %in% clone)) learner = assert_learner(as_learner(learner, clone = "learner" %in% clone, discard_state = TRUE)) @@ -115,7 +127,7 @@ resample = function(task, learner, resampling, store_models = FALSE, store_backe } res = future_map(n, workhorse, iteration = seq_len(n), learner = grid$learner, mode = grid$mode, - MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal) + MoreArgs = list(task = task, resampling = resampling, store_models = store_models, lgr_threshold = lgr_threshold, pb = pb, unmarshal = unmarshal, extractor = extractor) ) data = data.table( diff --git a/R/worker.R b/R/worker.R index e52daa23f..0cf16610b 100644 --- a/R/worker.R +++ b/R/worker.R @@ -251,7 +251,20 @@ learner_predict = function(learner, task, row_ids = NULL) { } -workhorse = function(iteration, task, learner, resampling, param_values = NULL, lgr_threshold, store_models = FALSE, pb = NULL, mode = "train", 
is_sequential = TRUE, unmarshal = TRUE) { +workhorse = function( + iteration, + task, + learner, + resampling, + param_values = NULL, + lgr_threshold, + store_models = FALSE, + pb = NULL, + mode = "train", + is_sequential = TRUE, + unmarshal = TRUE, + extractor = NULL + ) { if (!is.null(pb)) { pb(sprintf("%s|%s|i:%i", task$id, learner$id, iteration)) } @@ -332,6 +345,10 @@ workhorse = function(iteration, task, learner, resampling, param_values = NULL, } pdatas = discard(pdatas, is.null) + if (!is.null(extractor)) { + learner$state = insert_named(learner$state, extractor(learner$model)) + } + # set the model slot after prediction so it can be sent back to the main process process_model_after_predict( learner = learner, store_models = store_models, is_sequential = is_sequential, model_copy = model_copy_or_null, diff --git a/man-roxygen/param_extractor.R b/man-roxygen/param_extractor.R new file mode 100644 index 000000000..846298bc1 --- /dev/null +++ b/man-roxygen/param_extractor.R @@ -0,0 +1,3 @@ +#' @param extractor (`function()`)\cr +#' Function to extract information from the learner model on the worker. +#' The function takes `model` as input and must return a named list. diff --git a/man/Task.Rd b/man/Task.Rd index f45a41ddc..4e925c1ad 100644 --- a/man/Task.Rd +++ b/man/Task.Rd @@ -147,9 +147,6 @@ Required for \code{\link[=convert_task]{convert_task()}}.} \item{\code{mlr3_version}}{(\code{package_version})\cr Package version of \code{mlr3} used to create the task.} - -\item{\code{characteristics}}{(\code{list()})\cr -List of characteristics of the task, e.g. 
\code{list(n = 5, p = 7)}.} } \if{html}{\out{</div>}} } @@ -303,6 +300,9 @@ Alternatively, you can provide a \code{\link[=data.frame]{data.frame()}} with th Hash (unique identifier) for all columns except the \code{primary_key}: A \code{character} vector, named by the columns that each element refers to.\cr Columns of different \code{\link{Task}}s or \code{\link{DataBackend}}s that have agreeing \code{col_hashes} always represent the same data, given that the same \code{row}s are selected. The reverse is not necessarily true: There can be columns with the same content that have different \code{col_hashes}.} + +\item{\code{characteristics}}{(\code{list()})\cr +List of characteristics of the task, e.g. \code{list(n = 5, p = 7)}.} } \if{html}{\out{</div>}} } diff --git a/man/resample.Rd b/man/resample.Rd index b972108ef..e1c895822 100644 --- a/man/resample.Rd +++ b/man/resample.Rd @@ -13,7 +13,8 @@ resample( encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), - unmarshal = TRUE + unmarshal = TRUE, + extractor = NULL ) } \arguments{ @@ -64,6 +65,10 @@ Per default, all input objects are cloned.} Whether to unmarshal learners that were marshaled during the execution. If \code{TRUE} all models are stored in unmarshaled form. If \code{FALSE}, all learners (that need marshaling) are stored in marshaled form.} + +\item{extractor}{(\verb{function()})\cr +Function to extract information from the learner model on the worker. +The function takes \code{model} as input and must return a named list.} } \value{ \link{ResampleResult}.