Skip to content

Commit

Permalink
Merge branch 'main' into dt_threads
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Dec 13, 2023
2 parents 0a49890 + 318748b commit e0c6c3f
Show file tree
Hide file tree
Showing 10 changed files with 110 additions and 54 deletions.
3 changes: 2 additions & 1 deletion .lintr
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ linters: with_defaults(
object_name_linter = object_name_linter(c("snake_case", "CamelCase")), # only allow snake case and camel case object names
cyclocomp_linter = NULL, # do not check function complexity
commented_code_linter = NULL, # allow code in comments
line_length_linter = line_length_linter(180)
line_length_linter = line_length_linter(180),
indentation_linter(indent = 2L, hanging_indent_style = "never")
)

2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ Config/testthat/edition: 3
Config/testthat/parallel: false
NeedsCompilation: no
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.2.3.9000
Collate:
'mlr_reflections.R'
'BenchmarkResult.R'
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# mlr3 (development version)

* Optimize runtime of `resample()` and `benchmark()` by reducing the number of hashing operations.

# mlr3 0.17.0

* Learners cannot be added to the `HotstartStack` anymore when the model is missing.
Expand Down
1 change: 1 addition & 0 deletions R/DataBackend.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ DataBackend = R6Class("DataBackend", cloneable = FALSE,
}
private$.hash = assert_string(rhs)
},

#' @template field_col_hashes
col_hashes = function() {
cn = setdiff(self$colnames, self$primary_key)
Expand Down
25 changes: 20 additions & 5 deletions R/Resampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@
#' prop.table(table(task$truth(r$train_set(1)))) # roughly same proportion
Resampling = R6Class("Resampling",
public = list(
#' @template field_id
id = NULL,

#' @template field_label
label = NULL,

Expand Down Expand Up @@ -126,7 +123,7 @@ Resampling = R6Class("Resampling",
#'
#' Note that this object is typically constructed via a derived classes, e.g. [ResamplingCV] or [ResamplingHoldout].
initialize = function(id, param_set = ps(), duplicated_ids = FALSE, label = NA_character_, man = NA_character_) {
self$id = assert_string(id, min.chars = 1L)
private$.id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
self$param_set = assert_param_set(param_set)
self$duplicated_ids = assert_flag(duplicated_ids)
Expand Down Expand Up @@ -186,6 +183,7 @@ Resampling = R6Class("Resampling",
instance = private$.combine(lapply(strata$row_id, private$.sample, task = task))
}

private$.hash = NULL
self$instance = instance
self$task_hash = task$hash
self$task_nrow = task$nrow
Expand Down Expand Up @@ -214,6 +212,16 @@ Resampling = R6Class("Resampling",
),

active = list(
#' @template field_id
id = function(rhs) {
if (missing(rhs)) {
return(private$.id)
}

private$.hash = NULL
private$.id = assert_string(rhs, min.chars = 1L)
},

#' @field is_instantiated (`logical(1)`)\cr
#' Is `TRUE` if the resampling has been instantiated.
is_instantiated = function(rhs) {
Expand All @@ -227,11 +235,18 @@ Resampling = R6Class("Resampling",
if (!self$is_instantiated) {
return(NA_character_)
}
calculate_hash(list(class(self), self$id, self$param_set$values, self$instance))

if (is.null(private$.hash)) {
private$.hash = calculate_hash(list(class(self), self$id, self$param_set$values, self$instance))
}

private$.hash
}
),

private = list(
.id = NULL,
.hash = NULL,
.groups = NULL,

.get_set = function(getter, i) {
Expand Down
34 changes: 17 additions & 17 deletions R/ResultData.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,33 +45,33 @@ ResultData = R6Class("ResultData",
if (nrow(data) == 0L) {
self$data = star_init()
} else {
fact = data[, c("uhash", "iteration", "learner_state", "prediction", "task", "learner", "resampling", "param_values", "learner_hash"),
with = FALSE]
set(fact, j = "task_hash", value = hashes(fact$task))
set(fact, j = "learner_phash", value = phashes(fact$learner))
set(fact, j = "resampling_hash", value = hashes(fact$resampling))

uhashes = data.table(uhash = unique(fact$uhash))
tasks = fact[, list(task = .SD$task[1L]),
setcolorder(data, c("uhash", "iteration", "learner_state", "prediction", "task", "learner", "resampling", "param_values", "learner_hash"))
uhashes = data.table(uhash = unique(data$uhash))
setkeyv(data, c("uhash", "iteration"))

data[, task_hash := task[[1]]$hash, by = "uhash"]
data[, learner_phash := learner[[1]]$phash, by = "uhash"]
data[, resampling_hash := resampling[[1]]$hash, by = "uhash"]

tasks = data[, list(task = .SD$task[1L]),
keyby = "task_hash"]
learners = fact[, list(learner = list(.SD$learner[[1L]]$reset())),
learners = data[, list(learner = list(.SD$learner[[1L]]$reset())),
keyby = "learner_phash"]
resamplings = fact[, list(resampling = .SD$resampling[1L]),
resamplings = data[, list(resampling = .SD$resampling[1L]),
keyby = "resampling_hash"]
learner_components = fact[, list(learner_param_vals = list(.SD$param_values[[1]])),
learner_components = data[, list(learner_param_vals = list(.SD$param_values[[1]])),
keyby = "learner_hash"]

set(fact, j = "task", value = NULL)
set(fact, j = "learner", value = NULL)
set(fact, j = "resampling", value = NULL)
set(fact, j = "param_values", value = NULL)
setkeyv(fact, c("uhash", "iteration"))
set(data, j = "task", value = NULL)
set(data, j = "learner", value = NULL)
set(data, j = "resampling", value = NULL)
set(data, j = "param_values", value = NULL)

if (!store_backends) {
set(tasks, j = "task", value = lapply(tasks$task, task_rm_backend))
}

self$data = list(fact = fact, uhashes = uhashes, tasks = tasks, learners = learners,
self$data = list(fact = data, uhashes = uhashes, tasks = tasks, learners = learners,
resamplings = resamplings, learner_components = learner_components)
}
}
Expand Down
77 changes: 57 additions & 20 deletions R/Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
#' Instead, the methods first create a new [DataBackendDataTable] from the provided new data, and then
#' merge both backends into an abstract [DataBackend] which merges the results on-demand.
#' * `rename()` wraps the [DataBackend] of the Task in an additional [DataBackend] which deals with the renaming. Also updates `$col_roles` and `$col_info`.
#' * `set_levels()` updates the field `col_info()`.
#' * `set_levels()` and `droplevels()` `update the field `col_info()`.
#'
#' @template seealso_task
#' @concept Task
Expand All @@ -73,9 +73,6 @@
#' head(task)
Task = R6Class("Task",
public = list(
#' @template field_id
id = NULL,

#' @template field_label
label = NA_character_,

Expand Down Expand Up @@ -114,7 +111,7 @@ Task = R6Class("Task",
#'
#' Note that this object is typically constructed via a derived classes, e.g. [TaskClassif] or [TaskRegr].
initialize = function(id, task_type, backend, label = NA_character_, extra_args = list()) {
self$id = assert_string(id, min.chars = 1L)
private$.id = assert_string(id, min.chars = 1L)
self$label = assert_string(label, na.ok = TRUE)
self$task_type = assert_choice(task_type, mlr_reflections$task_types$type)
if (!inherits(backend, "DataBackend")) {
Expand Down Expand Up @@ -175,7 +172,7 @@ Task = R6Class("Task",
catf("%s (%i x %i)%s", format(self), self$nrow, self$ncol,
if (is.null(self$label) || is.na(self$label)) "" else paste0(": ", self$label))

roles = self$col_roles
roles = private$.col_roles
roles = roles[lengths(roles) > 0L]

# print additional columns as specified in reflections
Expand Down Expand Up @@ -204,7 +201,7 @@ Task = R6Class("Task",
catn(str_indent(sprintf("* %s:", str), roles[[role]]))
})

nrows = list(test = length(self$row_roles$test), holdout = length(self$row_roles$holdout))
nrows = list(test = length(private$.row_roles$test), holdout = length(private$.row_roles$holdout))
if (nrows$test || nrows$holdout) {
str = paste(c(
if(nrows$test) sprintf("%i (test)", nrows$test),
Expand Down Expand Up @@ -237,8 +234,8 @@ Task = R6Class("Task",
assert_choice(data_format, self$data_formats)
assert_flag(ordered)

row_roles = self$row_roles
col_roles = self$col_roles
row_roles = private$.row_roles
col_roles = private$.col_roles

if (is.null(rows)) {
rows = row_roles$use
Expand Down Expand Up @@ -368,6 +365,7 @@ Task = R6Class("Task",
filter = function(rows) {
assert_has_backend(self)
rows = assert_row_ids(rows)
private$.hash = NULL
private$.row_roles$use = intersect(private$.row_roles$use, rows)
invisible(self)
},
Expand All @@ -387,6 +385,8 @@ Task = R6Class("Task",
assert_has_backend(self)
assert_character(cols)
assert_subset(cols, private$.col_roles$feature)
private$.hash = NULL
private$.col_hashes = NULL
private$.col_roles$feature = intersect(private$.col_roles$feature, cols)
invisible(self)
},
Expand Down Expand Up @@ -452,7 +452,7 @@ Task = R6Class("Task",

# columns with these roles must be present in data
mandatory_roles = c("target", "feature", "weight", "group", "stratum", "order")
mandatory_cols = unlist(self$col_roles[mandatory_roles], use.names = FALSE)
mandatory_cols = unlist(private$.col_roles[mandatory_roles], use.names = FALSE)
missing_cols = setdiff(mandatory_cols, data$colnames)
if (length(missing_cols)) {
stopf("Cannot rbind data to task '%s', missing the following mandatory columns: %s", self$id, str_collapse(missing_cols))
Expand Down Expand Up @@ -484,9 +484,10 @@ Task = R6Class("Task",
tab[, c("type_y", "levels_y") := list(NULL, NULL)]

# everything looks good, modify task
private$.hash = NULL
self$backend = DataBackendRbind$new(self$backend, data)
self$col_info = tab[]
self$row_roles$use = c(self$row_roles$use, data$rownames)
private$.row_roles$use = c(private$.row_roles$use, data$rownames)

invisible(self)
},
Expand Down Expand Up @@ -535,7 +536,10 @@ Task = R6Class("Task",
setkeyv(self$col_info, "id")

# add new features
self$col_roles$feature = union(self$col_roles$feature, setdiff(data$colnames, c(pk, self$col_roles$target)))
private$.hash = NULL
private$.col_hashes = NULL
col_roles = private$.col_roles
private$.col_roles$feature = union(col_roles$feature, setdiff(data$colnames, c(pk, col_roles$target)))

# update backend
self$backend = DataBackendCbind$new(self$backend, data)
Expand All @@ -562,9 +566,11 @@ Task = R6Class("Task",
#' the object in its previous state.
rename = function(old, new) {
assert_has_backend(self)
private$.hash = NULL
private$.col_hashes = NULL
self$backend = DataBackendRename$new(self$backend, old, new)
setkeyv(self$col_info[old, ("id") := new, on = "id"], "id")
self$col_roles = map(self$col_roles, map_values, old = old, new = new)
private$.col_roles = map(private$.col_roles, map_values, old = old, new = new)
invisible(self)
},

Expand Down Expand Up @@ -593,7 +599,10 @@ Task = R6Class("Task",
set_row_roles = function(rows, roles = NULL, add_to = NULL, remove_from = NULL) {
assert_has_backend(self)
assert_subset(rows, self$backend$rownames)

private$.hash = NULL
private$.row_roles = task_set_roles(private$.row_roles, rows, roles, add_to, remove_from)

invisible(self)
},

Expand Down Expand Up @@ -622,8 +631,12 @@ Task = R6Class("Task",
set_col_roles = function(cols, roles = NULL, add_to = NULL, remove_from = NULL) {
assert_has_backend(self)
assert_subset(cols, self$col_info$id)

private$.hash = NULL
private$.col_hashes = NULL
new_roles = task_set_roles(private$.col_roles, cols, roles, add_to, remove_from)
private$.col_roles = task_check_col_roles(self, new_roles)

invisible(self)
},

Expand All @@ -646,6 +659,8 @@ Task = R6Class("Task",

tab = enframe(lapply(levels, unname), name = "id", value = "levels")
tab$fix_factor_levels = TRUE

private$.hash = NULL
self$col_info = ujoin(self$col_info, tab, key = "id")

invisible(self)
Expand All @@ -670,6 +685,7 @@ Task = R6Class("Task",
tab = tab[lengths(levels) > lengths(new_levels)]
tab[, c("levels", "fix_factor_levels") := list(Map(intersect, levels, new_levels), TRUE)]

private$.hash = NULL
self$col_info = ujoin(self$col_info, remove_named(tab, "new_levels"), key = "id")

invisible(self)
Expand Down Expand Up @@ -706,12 +722,27 @@ Task = R6Class("Task",
),

active = list(
#' @template field_id
id = function(rhs) {
if (missing(rhs)) {
return(private$.id)
}

private$.hash = NULL
private$.id = assert_string(rhs, min.chars = 1L)
},


#' @template field_hash
hash = function(rhs) {
private$.hash %??% calculate_hash(
class(self), self$id, self$backend$hash, self$col_info,
remove_named(private$.row_roles, "test"), private$.col_roles, private$.properties
)
if (is.null(private$.hash)) {
private$.hash = calculate_hash(
class(self), self$id, self$backend$hash, self$col_info,
remove_named(private$.row_roles, "test"), private$.col_roles, private$.properties
)
}

private$.hash
},

#' @field row_ids (`integer()`)\cr
Expand All @@ -728,7 +759,7 @@ Task = R6Class("Task",
#' * `"row_name"` (`character()`).
row_names = function(rhs) {
assert_ro_binding(rhs)
nn = self$col_roles$name
nn = private$.col_roles$name
if (length(nn) == 0L) {
return(NULL)
}
Expand Down Expand Up @@ -804,6 +835,7 @@ Task = R6Class("Task",
assert_names(names(rhs), "unique", permutation.of = mlr_reflections$task_row_roles, .var.name = "names of row_roles")
rhs = map(rhs, assert_row_ids, .var.name = "elements of row_roles")

private$.hash = NULL
private$.row_roles = rhs
},

Expand Down Expand Up @@ -835,6 +867,8 @@ Task = R6Class("Task",
assert_names(names(rhs), "unique", must.include = mlr_reflections$task_col_roles[[self$task_type]], .var.name = "names of col_roles")
assert_subset(unlist(rhs, use.names = FALSE), setdiff(self$col_info$id, self$backend$primary_key), .var.name = "elements of col_roles")

private$.hash = NULL
private$.col_hashes = NULL
private$.col_roles = task_check_col_roles(self, rhs)
},

Expand Down Expand Up @@ -982,11 +1016,15 @@ Task = R6Class("Task",

#' @template field_col_hashes
col_hashes = function() {
private$.col_hashes %??% self$backend$col_hashes[setdiff(unlist(self$col_roles), self$backend$primary_key)]
if (is.null(private$.col_hashes)) {
private$.col_hashes = self$backend$col_hashes[setdiff(unlist(private$.col_roles), self$backend$primary_key)]
}
private$.col_hashes
}
),

private = list(
.id = NULL,
.properties = NULL,
.col_roles = NULL,
.row_roles = NULL,
Expand All @@ -995,7 +1033,6 @@ Task = R6Class("Task",

deep_clone = function(name, value) {
# NB: DataBackends are never copied!
# TODO: check if we can assume col_info to be read-only
if (name == "col_info") copy(value) else value
}
)
Expand Down
Loading

0 comments on commit e0c6c3f

Please sign in to comment.