Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add callbacks for resample and benchmark #1214

Open
wants to merge 58 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
5b460c2
...
be-marc Nov 20, 2024
af87ce7
...
be-marc Nov 20, 2024
6a97bf0
...
be-marc Nov 21, 2024
c6ae6bc
...
be-marc Nov 21, 2024
211b9cf
...
be-marc Nov 26, 2024
465d7c9
...
be-marc Nov 27, 2024
a2ad28f
...
be-marc Nov 27, 2024
01e6a4b
...
be-marc Nov 27, 2024
ae85cc0
...
be-marc Nov 27, 2024
a610614
Merge branch 'main' into callback
be-marc Nov 27, 2024
8dd2ba5
...
be-marc Nov 28, 2024
4b2cf71
...
be-marc Nov 28, 2024
ee83cbb
...
be-marc Nov 28, 2024
c0424e2
...
be-marc Nov 28, 2024
0789de0
refactor: remove objekt logging
be-marc Nov 28, 2024
bf85752
Merge branch 'reduce_logger' into callback
be-marc Nov 28, 2024
f6fc7dc
...
be-marc Nov 29, 2024
ad46a8c
...
be-marc Nov 29, 2024
911b3ed
...
be-marc Nov 29, 2024
fbd3b01
...
be-marc Nov 29, 2024
53b766d
...
be-marc Nov 29, 2024
81b880f
...
be-marc Nov 29, 2024
2ad3d54
...
be-marc Nov 29, 2024
017d511
...
be-marc Nov 29, 2024
73996e7
...
be-marc Nov 29, 2024
50c13b9
...
be-marc Nov 29, 2024
cf92e71
pkgdown
be-marc Nov 29, 2024
8f7f308
add iteration to context
be-marc Dec 3, 2024
1e00462
...
be-marc Dec 4, 2024
1559518
Update R/as_result_data.R
be-marc Dec 4, 2024
d046835
tests
be-marc Dec 4, 2024
5c91d75
...
be-marc Dec 4, 2024
9d14bb0
...
be-marc Dec 10, 2024
15f9e76
...
be-marc Dec 10, 2024
50a26df
...
be-marc Dec 10, 2024
65df92d
...
be-marc Dec 10, 2024
7335ded
...
be-marc Dec 10, 2024
4c2e7f2
Merge branch 'main' into callback
be-marc Dec 10, 2024
876edaf
...
be-marc Dec 10, 2024
e4ddf18
...
be-marc Dec 10, 2024
3599be0
...
be-marc Dec 10, 2024
fe0ecf6
...
be-marc Dec 10, 2024
6d36e22
...
be-marc Dec 11, 2024
3597ca4
...
be-marc Dec 11, 2024
d751dac
...
be-marc Dec 11, 2024
ab2cf9b
...
be-marc Dec 11, 2024
7fa85e0
...
be-marc Dec 11, 2024
9be497f
...
be-marc Dec 19, 2024
2b3dc85
Merge branch 'main' into callback
be-marc Dec 19, 2024
f392d04
...
be-marc Dec 19, 2024
b1117cc
...
be-marc Dec 20, 2024
de8d9ad
...
be-marc Dec 20, 2024
e6ffb33
...
be-marc Dec 20, 2024
288da82
...
be-marc Dec 20, 2024
aa95998
...
be-marc Dec 20, 2024
92b9ddd
...
be-marc Dec 20, 2024
41028cd
...
be-marc Dec 20, 2024
adc95f8
...
be-marc Dec 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ RoxygenNote: 7.3.2
Collate:
'mlr_reflections.R'
'BenchmarkResult.R'
'CallbackResample.R'
'ContextResample.R'
'warn_deprecated.R'
'DataBackend.R'
'DataBackendCbind.R'
Expand Down Expand Up @@ -189,6 +191,7 @@ Collate:
'helper_print.R'
'install_pkgs.R'
'marshal.R'
'mlr_callbacks.R'
'mlr_sugar.R'
'mlr_test_helpers.R'
'partition.R'
Expand Down
11 changes: 11 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ S3method(unmarshal_model,classif.debug_model_marshaled)
S3method(unmarshal_model,default)
S3method(unmarshal_model,learner_state_marshaled)
export(BenchmarkResult)
export(CallbackResample)
export(ContextResample)
export(DataBackend)
export(DataBackendDataTable)
export(DataBackendMatrix)
Expand Down Expand Up @@ -207,6 +209,8 @@ export(assert_measure)
export(assert_measures)
export(assert_predictable)
export(assert_prediction)
export(assert_resample_callback)
export(assert_resample_callbacks)
export(assert_resample_result)
export(assert_resampling)
export(assert_resamplings)
Expand All @@ -218,7 +222,10 @@ export(assert_validate)
export(auto_convert)
export(benchmark)
export(benchmark_grid)
export(callback_resample)
export(check_prediction_data)
export(clbk)
export(clbks)
export(col_info)
export(convert_task)
export(create_empty_prediction_data)
Expand All @@ -236,6 +243,7 @@ export(learner_unmarshal)
export(lrn)
export(lrns)
export(marshal_model)
export(mlr_callbacks)
export(mlr_learners)
export(mlr_measures)
export(mlr_reflections)
Expand Down Expand Up @@ -269,6 +277,9 @@ importFrom(data.table,data.table)
importFrom(future,nbrOfWorkers)
importFrom(future,plan)
importFrom(graphics,plot)
importFrom(mlr3misc,clbk)
importFrom(mlr3misc,clbks)
importFrom(mlr3misc,mlr_callbacks)
importFrom(parallelly,availableCores)
importFrom(stats,contr.treatment)
importFrom(stats,model.frame)
Expand Down
4 changes: 3 additions & 1 deletion R/BenchmarkResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,9 @@ BenchmarkResult = R6Class("BenchmarkResult",
as.data.table.BenchmarkResult = function(x, ..., hashes = FALSE, predict_sets = "test", task_characteristics = FALSE) { # nolint
assert_flag(task_characteristics)
tab = get_private(x)$.data$as_data_table(view = NULL, predict_sets = predict_sets)
tab = tab[, c("uhash", "task", "learner", "resampling", "iteration", "prediction"), with = FALSE]
cns = c("uhash", "task", "learner", "resampling", "iteration", "prediction")
if ("data_extra" %in% names(tab)) cns = c(cns, "data_extra")
tab = tab[, cns, with = FALSE]

if (task_characteristics) {
set(tab, j = "characteristics", value = map(tab$task, "characteristics"))
Expand Down
155 changes: 155 additions & 0 deletions R/CallbackResample.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
#' @title Resample Callback
#'
#' @description
#' Specialized [mlr3misc::Callback] to customize the behavior of [resample()] and [benchmark()] in mlr3.
#' The [callback_resample()] function is used to create instances of this class.
#' Predefined callbacks are stored in the [dictionary][mlr3misc::Dictionary] [mlr_callbacks] and can be retrieved with [clbk()].
#' For more information on callbacks, see the [callback_resample()] documentation.
#'
#' @export
CallbackResample = R6Class(classname = "CallbackResample",
  inherit = Callback,
  public = list(

    # All stage slots default to `NULL`, i.e. no function is attached.

    #' @field on_resample_begin (`function()`)\cr
    #' Stage function invoked when a resampling iteration starts.
    #' Called in `workhorse()` (internal).
    on_resample_begin = NULL,

    #' @field on_resample_before_train (`function()`)\cr
    #' Stage function invoked right before the learner is trained.
    #' Called in `workhorse()` (internal).
    on_resample_before_train = NULL,

    #' @field on_resample_before_predict (`function()`)\cr
    #' Stage function invoked right before predictions are made.
    #' Called in `workhorse()` (internal).
    on_resample_before_predict = NULL,

    #' @field on_resample_end (`function()`)\cr
    #' Stage function invoked when a resampling iteration finishes.
    #' Called in `workhorse()` (internal).
    on_resample_end = NULL
  )
)

#' @title Create Evaluation Callback
#'
#' @description
#' Function to create a [CallbackResample].
#' Predefined callbacks are stored in the [dictionary][mlr3misc::Dictionary] [mlr_callbacks] and can be retrieved with [clbk()].
#'
#' Evaluation callbacks are called at different stages of the resampling process.
#' Each stage is called once per resampling iteration.
#' The stages are prefixed with `on_resample_*`.
#' The text in brackets indicates what happens between the stages and which accesses to the [ContextResample] (`ctx`) are typical for the stage.
#'
#' ```
#' Start Resampling Iteration on Worker
#' - on_resample_begin
#' (Split `ctx$task` into training and test set with `ctx$resampling` and `ctx$iteration`)
#' - on_resample_before_train
#' (Train the learner `ctx$learner` on training data)
#' - on_resample_before_predict
#' (Predict on predict sets and store prediction data `ctx$pdatas`)
#' - on_resample_end
#' (Erase model `ctx$learner$model` if requested and return results)
#' End Resampling Iteration on Worker
#' ```
#'
#' The callback can store data in `ctx$learner$state` or `ctx$data_extra`.
#' The data in `ctx$data_extra` is stored in the [ResampleResult] or [BenchmarkResult].
#' See also the section on parameters for more information on the stages.
#'
#' @details
#' When implementing a callback, each function must have two arguments named `callback` and `context`.
#' A callback can write data to the state (`$state`), e.g. settings that affect the callback itself.
#' We highly discourage changing the task, learner and resampling objects via the callback.
#'
#' @param id (`character(1)`)\cr
#' Identifier for the new instance.
#' @param label (`character(1)`)\cr
#' Label for the new instance.
#' @param man (`character(1)`)\cr
#' String in the format `[pkg]::[topic]` pointing to a manual page for this object.
#' The referenced help package can be opened via method `$help()`.
#' @param on_resample_begin (`function()`)\cr
#' Stage called at the beginning of an evaluation.
#' Called in `workhorse()` (internal).
#' @param on_resample_before_train (`function()`)\cr
#' Stage called before training the learner.
#' Called in `workhorse()` (internal).
#' @param on_resample_before_predict (`function()`)\cr
#' Stage called before predicting.
#' Called in `workhorse()` (internal).
#' @param on_resample_end (`function()`)\cr
#' Stage called at the end of an evaluation.
#' Called in `workhorse()` (internal).
#'
#' @export
#' @examples
#' task = tsk("pima")
#' learner = lrn("classif.rpart")
#' resampling = rsmp("cv", folds = 3)
#'
#' # save selected features callback
#' callback = callback_resample("selected_features",
#' on_resample_end = function(callback, context) {
#' context$learner$state$selected_features = context$learner$selected_features()
#' }
#' )
#'
#' rr = resample(task, learner, resampling, callbacks = callback)
#'
#' rr$learners[[1]]$state$selected_features
callback_resample = function(
  id,
  label = NA_character_,
  man = NA_character_,
  on_resample_begin = NULL,
  on_resample_before_train = NULL,
  on_resample_before_predict = NULL,
  on_resample_end = NULL
  ) {
  # Gather every stage argument under its stage name, then keep only those
  # that were actually supplied.
  candidates = list(
    on_resample_begin = on_resample_begin,
    on_resample_before_train = on_resample_before_train,
    on_resample_before_predict = on_resample_before_predict,
    on_resample_end = on_resample_end
  )
  stages = discard(candidates, is.null)

  # Each stage must be a function with arguments `callback` and `context`;
  # the validated function is then passed through `crate()`.
  stages = map(stages, function(stage) {
    crate(assert_function(stage, args = c("callback", "context")))
  })

  callback = CallbackResample$new(id, label, man)

  # Attach the stage functions to the matching fields of the new callback.
  for (stage_name in names(stages)) {
    callback[[stage_name]] = stages[[stage_name]]
  }

  callback
}

#' @title Assertions for Callbacks
#'
#' @description
#' Assertions for [CallbackResample] class.
#'
#' @param callback ([CallbackResample]).
#' @param null_ok (`logical(1)`)\cr
#' If `TRUE`, `NULL` is allowed.
#'
#' @return [CallbackResample] | List of [CallbackResample]s.
#' @export
assert_resample_callback = function(callback, null_ok = FALSE) {
  # Errors unless `callback` inherits from "CallbackResample"
  # (or is `NULL` while `null_ok` is `TRUE`); the input is returned invisibly.
  invisible(assert_class(callback, classes = "CallbackResample", null.ok = null_ok))
}

#' @export
#' @param callbacks (list of [CallbackResample]).
#' @rdname assert_resample_callback
assert_resample_callbacks = function(callbacks, null_ok = FALSE) {
  # The container itself must be a list (or `NULL` when `null_ok` is `TRUE`).
  assert_list(callbacks, null.ok = null_ok)

  # Nothing to check for an accepted `NULL`.
  if (null_ok && is.null(callbacks)) {
    return(invisible(NULL))
  }

  # Validate every element individually and hand the checked list back invisibly.
  checked = lapply(callbacks, assert_resample_callback)
  invisible(checked)
}
103 changes: 103 additions & 0 deletions R/ContextResample.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#' @title Resample Context
#'
#' @description
#' A [CallbackResample] accesses and modifies data during [resample()] and [benchmark()] via the `ContextResample`.
#' See the section on fields for a list of modifiable objects.
#' See [callback_resample()] for a list of stages that access `ContextResample`.
#'
#' @export
ContextResample = R6Class(classname = "ContextResample",
  inherit = Context,
  public = list(

    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    #'
    #' @param task ([Task])\cr
    #'   The task to be evaluated.
    #' @param learner ([Learner])\cr
    #'   The learner to be evaluated.
    #' @param resampling ([Resampling])\cr
    #'   The resampling strategy to be used.
    #' @param iteration (`integer()`)\cr
    #'   The current iteration.
    initialize = function(task, learner, resampling, iteration) {
      # inputs are stored as-is; assertions are skipped on purpose to avoid overhead
      private$.task = task
      private$.learner = learner
      private$.resampling = resampling
      private$.iteration = iteration

      super$initialize(id = "evaluate", label = "Evaluation")
    }
  ),

  active = list(

    #' @field task ([Task])\cr
    #' The task to be evaluated.
    #' The task is unchanged during the evaluation.
    #' The task is read-only.
    task = function(rhs) {
      assert_ro_binding(rhs)
      private$.task
    },

    #' @field learner ([Learner])\cr
    #' The learner to be evaluated.
    #' The learner contains the models after stage `on_resample_before_train`.
    #' Assigned values are checked with `assert_learner()`.
    learner = function(rhs) {
      if (!missing(rhs)) {
        private$.learner = assert_learner(rhs)
      } else {
        private$.learner
      }
    },

    #' @field resampling ([Resampling])\cr
    #' The resampling strategy to be used.
    #' The resampling is unchanged during the evaluation.
    #' The resampling is read-only.
    resampling = function(rhs) {
      assert_ro_binding(rhs)
      private$.resampling
    },

    #' @field iteration (`integer()`)\cr
    #' The current iteration.
    #' The iteration is read-only.
    iteration = function(rhs) {
      assert_ro_binding(rhs)
      private$.iteration
    },

    #' @field pdatas (List of [PredictionData])\cr
    #' The prediction data.
    #' The data is available on stage `on_resample_end`.
    #' Assigned values must be a list of `PredictionData` objects.
    pdatas = function(rhs) {
      if (!missing(rhs)) {
        private$.pdatas = assert_list(rhs, "PredictionData")
      } else {
        private$.pdatas
      }
    },

    #' @field data_extra (`list()`)\cr
    #' Data saved in the [ResampleResult] or [BenchmarkResult].
    #' Use this field to save results.
    #' Must be a `list()`.
    data_extra = function(rhs) {
      if (!missing(rhs)) {
        private$.data_extra = assert_list(rhs)
      } else {
        private$.data_extra
      }
    }
  ),

  private = list(
    # backing storage for the active bindings above
    .task = NULL,
    .learner = NULL,
    .resampling = NULL,
    .iteration = NULL,
    .pdatas = NULL,
    .data_extra = NULL
  )
)
11 changes: 10 additions & 1 deletion R/ResampleResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,13 @@ ResampleResult = R6Class("ResampleResult",
private$.data$learners(private$.view)$learner
},

#' @field data_extra (`list()`)\cr
#' Additional data stored in the [ResampleResult].
#' The field is read-only.
data_extra = function(rhs) {
  assert_ro_binding(rhs)
  view = private$.view
  private$.data$data_extra(view)
},

#' @field warnings ([data.table::data.table()])\cr
#' A table with all warning messages.
#' Column names are `"iteration"` and `"msg"`.
Expand Down Expand Up @@ -373,7 +380,9 @@ ResampleResult = R6Class("ResampleResult",
as.data.table.ResampleResult = function(x, ..., predict_sets = "test") { # nolint
  # Converts a ResampleResult into a table with one column each for task,
  # learner, resampling, iteration and prediction, plus `data_extra` if present.
  #
  # x: the object to convert; its private `.data` and `.view` are read.
  # ...: not used in this function body.
  # predict_sets: forwarded unchanged to `$as_data_table()`.
  private = get_private(x)
  tab = private$.data$as_data_table(view = private$.view, predict_sets = predict_sets)

  # `data_extra` is an optional column, so it is only selected when `tab` has it.
  cns = c("task", "learner", "resampling", "iteration", "prediction")
  if ("data_extra" %in% names(tab)) cns = c(cns, "data_extra")

  # Single column selection; the result of this expression is the return value.
  tab[, cns, with = FALSE]
}

# #' @export
Expand Down
Loading
Loading