From 244572f44a07af9c0f984292ed09b0517aae48ec Mon Sep 17 00:00:00 2001 From: Michel Lang Date: Wed, 13 Dec 2023 13:10:32 +0100 Subject: [PATCH] Cache hashes (#977) --- .lintr | 3 +- DESCRIPTION | 2 +- NAMESPACE | 2 + R/DataBackend.R | 1 + R/Resampling.R | 25 +++++++++--- R/Task.R | 77 ++++++++++++++++++++++++++--------- man/Resampling.Rd | 8 ++-- man/Task.Rd | 10 ++--- tests/testthat/test_Measure.R | 2 +- 9 files changed, 93 insertions(+), 37 deletions(-) diff --git a/.lintr b/.lintr index 9c43293db..4656a8ddc 100644 --- a/.lintr +++ b/.lintr @@ -5,6 +5,7 @@ linters: with_defaults( object_name_linter = object_name_linter(c("snake_case", "CamelCase")), # only allow snake case and camel case object names cyclocomp_linter = NULL, # do not check function complexity commented_code_linter = NULL, # allow code in comments - line_length_linter = line_length_linter(180) + line_length_linter = line_length_linter(180), + indentation_linter(indent = 2L, hanging_indent_style = "never") ) diff --git a/DESCRIPTION b/DESCRIPTION index 1f1248aa7..6a209ec99 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -73,7 +73,7 @@ Config/testthat/edition: 3 Config/testthat/parallel: false NeedsCompilation: no Roxygen: list(markdown = TRUE, r6 = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.2.3.9000 Collate: 'mlr_reflections.R' 'BenchmarkResult.R' diff --git a/NAMESPACE b/NAMESPACE index 42587fb0d..1e4db46ba 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -234,6 +234,8 @@ import(palmerpenguins) import(paradox) importFrom(R6,R6Class) importFrom(R6,is.R6) +importFrom(data.table,as.data.table) +importFrom(data.table,data.table) importFrom(future,nbrOfWorkers) importFrom(future,plan) importFrom(graphics,plot) diff --git a/R/DataBackend.R b/R/DataBackend.R index 948a12fb0..c61c28110 100644 --- a/R/DataBackend.R +++ b/R/DataBackend.R @@ -98,6 +98,7 @@ DataBackend = R6Class("DataBackend", cloneable = FALSE, } private$.hash = assert_string(rhs) }, + #' @template field_col_hashes col_hashes = function() { cn = setdiff(self$colnames, self$primary_key) diff --git a/R/Resampling.R b/R/Resampling.R index 8e43d7119..e03d4c371 100644 --- a/R/Resampling.R +++ b/R/Resampling.R @@ -82,9 +82,6 @@ #' prop.table(table(task$truth(r$train_set(1)))) # roughly same proportion Resampling = R6Class("Resampling", public = list( - #' @template field_id - id = NULL, - #' @template field_label label = NULL, @@ -126,7 +123,7 @@ Resampling = R6Class("Resampling", #' #' Note that this object is typically constructed via a derived classes, e.g. [ResamplingCV] or [ResamplingHoldout]. initialize = function(id, param_set = ps(), duplicated_ids = FALSE, label = NA_character_, man = NA_character_) { - self$id = assert_string(id, min.chars = 1L) + private$.id = assert_string(id, min.chars = 1L) self$label = assert_string(label, na.ok = TRUE) self$param_set = assert_param_set(param_set) self$duplicated_ids = assert_flag(duplicated_ids) @@ -186,6 +183,7 @@ Resampling = R6Class("Resampling", instance = private$.combine(lapply(strata$row_id, private$.sample, task = task)) } + private$.hash = NULL self$instance = instance self$task_hash = task$hash self$task_nrow = task$nrow @@ -214,6 +212,16 @@ Resampling = R6Class("Resampling", ), active = list( + #' @template field_id + id = function(rhs) { + if (missing(rhs)) { + return(private$.id) + } + + private$.hash = NULL + private$.id = assert_string(rhs, min.chars = 1L) + }, + #' @field is_instantiated (`logical(1)`)\cr #' Is `TRUE` if the resampling has been instantiated. is_instantiated = function(rhs) { @@ -227,11 +235,18 @@ Resampling = R6Class("Resampling", if (!self$is_instantiated) { return(NA_character_) } - calculate_hash(list(class(self), self$id, self$param_set$values, self$instance)) + + if (is.null(private$.hash)) { + private$.hash = calculate_hash(list(class(self), self$id, self$param_set$values, self$instance)) + } + + private$.hash } ), private = list( + .id = NULL, + .hash = NULL, .groups = NULL, .get_set = function(getter, i) { diff --git a/R/Task.R b/R/Task.R index dd8f0a09c..cc2910607 100644 --- a/R/Task.R +++ b/R/Task.R @@ -48,7 +48,7 @@ #' Instead, the methods first create a new [DataBackendDataTable] from the provided new data, and then #' merge both backends into an abstract [DataBackend] which merges the results on-demand. #' * `rename()` wraps the [DataBackend] of the Task in an additional [DataBackend] which deals with the renaming. Also updates `$col_roles` and `$col_info`. -#' * `set_levels()` updates the field `col_info()`. +#' * `set_levels()` and `droplevels()` `update the field `col_info()`. #' #' @template seealso_task #' @concept Task @@ -73,9 +73,6 @@ #' head(task) Task = R6Class("Task", public = list( - #' @template field_id - id = NULL, - #' @template field_label label = NA_character_, @@ -114,7 +111,7 @@ Task = R6Class("Task", #' #' Note that this object is typically constructed via a derived classes, e.g. [TaskClassif] or [TaskRegr]. initialize = function(id, task_type, backend, label = NA_character_, extra_args = list()) { - self$id = assert_string(id, min.chars = 1L) + private$.id = assert_string(id, min.chars = 1L) self$label = assert_string(label, na.ok = TRUE) self$task_type = assert_choice(task_type, mlr_reflections$task_types$type) if (!inherits(backend, "DataBackend")) { @@ -175,7 +172,7 @@ Task = R6Class("Task", catf("%s (%i x %i)%s", format(self), self$nrow, self$ncol, if (is.null(self$label) || is.na(self$label)) "" else paste0(": ", self$label)) - roles = self$col_roles + roles = private$.col_roles roles = roles[lengths(roles) > 0L] # print additional columns as specified in reflections @@ -204,7 +201,7 @@ Task = R6Class("Task", catn(str_indent(sprintf("* %s:", str), roles[[role]])) }) - nrows = list(test = length(self$row_roles$test), holdout = length(self$row_roles$holdout)) + nrows = list(test = length(private$.row_roles$test), holdout = length(private$.row_roles$holdout)) if (nrows$test || nrows$holdout) { str = paste(c( if(nrows$test) sprintf("%i (test)", nrows$test), @@ -237,8 +234,8 @@ Task = R6Class("Task", assert_choice(data_format, self$data_formats) assert_flag(ordered) - row_roles = self$row_roles - col_roles = self$col_roles + row_roles = private$.row_roles + col_roles = private$.col_roles if (is.null(rows)) { rows = row_roles$use @@ -368,6 +365,7 @@ Task = R6Class("Task", filter = function(rows) { assert_has_backend(self) rows = assert_row_ids(rows) + private$.hash = NULL private$.row_roles$use = intersect(private$.row_roles$use, rows) invisible(self) }, @@ -387,6 +385,8 @@ Task = R6Class("Task", assert_has_backend(self) assert_character(cols) assert_subset(cols, private$.col_roles$feature) + private$.hash = NULL + private$.col_hashes = NULL private$.col_roles$feature = intersect(private$.col_roles$feature, cols) invisible(self) }, @@ -452,7 +452,7 @@ Task = R6Class("Task", # columns with these roles must be present in data mandatory_roles = c("target", "feature", "weight", "group", "stratum", "order") - mandatory_cols = unlist(self$col_roles[mandatory_roles], use.names = FALSE) + mandatory_cols = unlist(private$.col_roles[mandatory_roles], use.names = FALSE) missing_cols = setdiff(mandatory_cols, data$colnames) if (length(missing_cols)) { stopf("Cannot rbind data to task '%s', missing the following mandatory columns: %s", self$id, str_collapse(missing_cols)) @@ -484,9 +484,10 @@ Task = R6Class("Task", tab[, c("type_y", "levels_y") := list(NULL, NULL)] # everything looks good, modify task + private$.hash = NULL self$backend = DataBackendRbind$new(self$backend, data) self$col_info = tab[] - self$row_roles$use = c(self$row_roles$use, data$rownames) + private$.row_roles$use = c(private$.row_roles$use, data$rownames) invisible(self) }, @@ -535,7 +536,10 @@ Task = R6Class("Task", setkeyv(self$col_info, "id") # add new features - self$col_roles$feature = union(self$col_roles$feature, setdiff(data$colnames, c(pk, self$col_roles$target))) + private$.hash = NULL + private$.col_hashes = NULL + col_roles = private$.col_roles + private$.col_roles$feature = union(col_roles$feature, setdiff(data$colnames, c(pk, col_roles$target))) # update backend self$backend = DataBackendCbind$new(self$backend, data) @@ -562,9 +566,11 @@ Task = R6Class("Task", #' the object in its previous state. rename = function(old, new) { assert_has_backend(self) + private$.hash = NULL + private$.col_hashes = NULL self$backend = DataBackendRename$new(self$backend, old, new) setkeyv(self$col_info[old, ("id") := new, on = "id"], "id") - self$col_roles = map(self$col_roles, map_values, old = old, new = new) + private$.col_roles = map(private$.col_roles, map_values, old = old, new = new) invisible(self) }, @@ -593,7 +599,10 @@ Task = R6Class("Task", set_row_roles = function(rows, roles = NULL, add_to = NULL, remove_from = NULL) { assert_has_backend(self) assert_subset(rows, self$backend$rownames) + + private$.hash = NULL private$.row_roles = task_set_roles(private$.row_roles, rows, roles, add_to, remove_from) + invisible(self) }, @@ -622,8 +631,12 @@ Task = R6Class("Task", set_col_roles = function(cols, roles = NULL, add_to = NULL, remove_from = NULL) { assert_has_backend(self) assert_subset(cols, self$col_info$id) + + private$.hash = NULL + private$.col_hashes = NULL new_roles = task_set_roles(private$.col_roles, cols, roles, add_to, remove_from) private$.col_roles = task_check_col_roles(self, new_roles) + invisible(self) }, @@ -646,6 +659,8 @@ Task = R6Class("Task", tab = enframe(lapply(levels, unname), name = "id", value = "levels") tab$fix_factor_levels = TRUE + + private$.hash = NULL self$col_info = ujoin(self$col_info, tab, key = "id") invisible(self) @@ -670,6 +685,7 @@ Task = R6Class("Task", tab = tab[lengths(levels) > lengths(new_levels)] tab[, c("levels", "fix_factor_levels") := list(Map(intersect, levels, new_levels), TRUE)] + private$.hash = NULL self$col_info = ujoin(self$col_info, remove_named(tab, "new_levels"), key = "id") invisible(self) @@ -706,12 +722,27 @@ Task = R6Class("Task", ), active = list( + #' @template field_id + id = function(rhs) { + if (missing(rhs)) { + return(private$.id) + } + + private$.hash = NULL + private$.id = assert_string(rhs, min.chars = 1L) + }, + + #' @template field_hash hash = function(rhs) { - private$.hash %??% calculate_hash( - class(self), self$id, self$backend$hash, self$col_info, - remove_named(private$.row_roles, "test"), private$.col_roles, private$.properties - ) + if (is.null(private$.hash)) { + private$.hash = calculate_hash( + class(self), self$id, self$backend$hash, self$col_info, + remove_named(private$.row_roles, "test"), private$.col_roles, private$.properties + ) + } + + private$.hash }, #' @field row_ids (`integer()`)\cr @@ -728,7 +759,7 @@ Task = R6Class("Task", #' * `"row_name"` (`character()`). row_names = function(rhs) { assert_ro_binding(rhs) - nn = self$col_roles$name + nn = private$.col_roles$name if (length(nn) == 0L) { return(NULL) } @@ -804,6 +835,7 @@ Task = R6Class("Task", assert_names(names(rhs), "unique", permutation.of = mlr_reflections$task_row_roles, .var.name = "names of row_roles") rhs = map(rhs, assert_row_ids, .var.name = "elements of row_roles") + private$.hash = NULL private$.row_roles = rhs }, @@ -835,6 +867,8 @@ Task = R6Class("Task", assert_names(names(rhs), "unique", must.include = mlr_reflections$task_col_roles[[self$task_type]], .var.name = "names of col_roles") assert_subset(unlist(rhs, use.names = FALSE), setdiff(self$col_info$id, self$backend$primary_key), .var.name = "elements of col_roles") + private$.hash = NULL + private$.col_hashes = NULL private$.col_roles = task_check_col_roles(self, rhs) }, @@ -982,11 +1016,15 @@ Task = R6Class("Task", #' @template field_col_hashes col_hashes = function() { - private$.col_hashes %??% self$backend$col_hashes[setdiff(unlist(self$col_roles), self$backend$primary_key)] + if (is.null(private$.col_hashes)) { + private$.col_hashes = self$backend$col_hashes[setdiff(unlist(private$.col_roles), self$backend$primary_key)] + } + private$.col_hashes } ), private = list( + .id = NULL, .properties = NULL, .col_roles = NULL, .row_roles = NULL, @@ -995,7 +1033,6 @@ Task = R6Class("Task", deep_clone = function(name, value) { # NB: DataBackends are never copied! - # TODO: check if we can assume col_info to be read-only if (name == "col_info") copy(value) else value } ) diff --git a/man/Resampling.Rd b/man/Resampling.Rd index 134ab77b2..d48ab4aeb 100644 --- a/man/Resampling.Rd +++ b/man/Resampling.Rd @@ -105,10 +105,6 @@ Other Resampling: \section{Public fields}{ \if{html}{\out{
}} \describe{ -\item{\code{id}}{(\code{character(1)})\cr -Identifier of the object. -Used in tables, plot and text output.} - \item{\code{label}}{(\code{character(1)})\cr Label for this object. Can be used in tables, plot and text output instead of the ID.} @@ -144,6 +140,10 @@ Defaults to \code{NA}, but can be set by child classes.} \section{Active bindings}{ \if{html}{\out{
}} \describe{ +\item{\code{id}}{(\code{character(1)})\cr +Identifier of the object. +Used in tables, plot and text output.} + \item{\code{is_instantiated}}{(\code{logical(1)})\cr Is \code{TRUE} if the resampling has been instantiated.} diff --git a/man/Task.Rd b/man/Task.Rd index 187eaefe5..da405c0c2 100644 --- a/man/Task.Rd +++ b/man/Task.Rd @@ -46,7 +46,7 @@ This provides a different "view" on the data without altering the data itself. Instead, the methods first create a new \link{DataBackendDataTable} from the provided new data, and then merge both backends into an abstract \link{DataBackend} which merges the results on-demand. \item \code{rename()} wraps the \link{DataBackend} of the Task in an additional \link{DataBackend} which deals with the renaming. Also updates \verb{$col_roles} and \verb{$col_info}. -\item \code{set_levels()} updates the field \code{col_info()}. +\item \code{set_levels()} and \code{droplevels()} \verb{update the field }col_info()`. } } @@ -108,10 +108,6 @@ Other Task: \section{Public fields}{ \if{html}{\out{
}} \describe{ -\item{\code{id}}{(\code{character(1)})\cr -Identifier of the object. -Used in tables, plot and text output.} - \item{\code{label}}{(\code{character(1)})\cr Label for this object. Can be used in tables, plot and text output instead of the ID.} @@ -153,6 +149,10 @@ Package version of \code{mlr3} used to create the task.} \section{Active bindings}{ \if{html}{\out{
}} \describe{ +\item{\code{id}}{(\code{character(1)})\cr +Identifier of the object. +Used in tables, plot and text output.} + \item{\code{hash}}{(\code{character(1)})\cr Hash (unique identifier) for this object.} diff --git a/tests/testthat/test_Measure.R b/tests/testthat/test_Measure.R index 38bd54134..312f10711 100644 --- a/tests/testthat/test_Measure.R +++ b/tests/testthat/test_Measure.R @@ -128,5 +128,5 @@ test_that("time_train is > 0", { skip_on_cran() rr = resample(tsk("iris"), lrn("classif.debug"), rsmp("holdout")) res = rr$score(msr("time_train")) - expect_gt(res$time_train, 0) + expect_gte(res$time_train, 0) })