Skip to content

Commit

Permalink
fix(learner): column info and type conversion in predict_newdata
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Jan 9, 2025
1 parent e21782a commit d753a24
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 19 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# mlr3 (development version)

* fix: the `$predict_newdata()` method of `Learner` now automatically conducts type conversions if the input is a `data.frame` (#685)
* fix: the `$predict_newdata()` method of `Learner` now automatically conducts type conversions (#685)
* BREAKING_CHANGE: Predicting on a `task` with the wrong column information is now an error and not a warning.
* Column names with UTF-8 characters are now allowed by default.
The option `mlr3.allow_utf8_names` is removed.
Expand Down
33 changes: 20 additions & 13 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -370,15 +370,15 @@ Learner = R6Class("Learner",
#' of the training task stored in the learner.
#' If the learner has been fitted via [resample()] or [benchmark()], you need to pass the corresponding task stored
#' in the [ResampleResult] or [BenchmarkResult], respectively.
#' Further, [`auto_convert`] is used for type-conversions to ensure compatibility
#' of features between `$train()` and `$predict()`.
#'
#' @param newdata (any object supported by [as_data_backend()])\cr
#' New data to predict on.
#' All data formats convertible by [as_data_backend()] are supported, e.g.
#' `data.frame()` or [DataBackend].
#' If a [DataBackend] is provided as `newdata`, the row ids are preserved,
#' otherwise they are set to the sequence `1:nrow(newdata)`.
#' If the input is a `data.frame`, [`auto_convert`] is used for type-conversions to ensure compatibility
#' of features between `$train()` and `$predict()`.
#'
#' @param task ([Task]).
#'
Expand All @@ -395,31 +395,38 @@ Learner = R6Class("Learner",
task = task_rm_backend(task)
}

if (is.data.frame(newdata)) {
keep_cols = intersect(names(newdata), task$col_info$id)
ci = task$col_info[list(keep_cols), on = "id"]
newdata = do.call(data.table, Map(auto_convert,
value = as.list(newdata)[ci$id],
id = ci$id, type = ci$type, levels = ci$levels))
}

newdata = as_data_backend(newdata)
assert_names(newdata$colnames, must.include = task$feature_names)

# the following columns are automatically set to NA if missing
impute = unlist(task$col_roles[c("target", "name", "order", "stratum", "group", "weight")], use.names = FALSE)
impute = setdiff(impute, newdata$colnames)
if (length(impute)) {
tab1 = if (length(impute)) {
# create list with correct NA types and cbind it to the backend
ci = insert_named(task$col_info[list(impute), c("id", "type", "levels"), on = "id", with = FALSE], list(value = NA))
na_cols = set_names(pmap(ci, function(..., nrow) rep(auto_convert(...), nrow), nrow = newdata$nrow), ci$id)
tab = invoke(data.table, .args = insert_named(na_cols, set_names(list(newdata$rownames), newdata$primary_key)))
invoke(data.table, .args = insert_named(na_cols, set_names(list(newdata$rownames), newdata$primary_key)))
}

# Perform type conversion where necessary
keep_cols = intersect(newdata$colnames, task$col_info$id)
ci = task$col_info[list(keep_cols), ][
get("type") != col_info(newdata)[list(keep_cols), on = "id"]$type]
tab2 = do.call(data.table, Map(auto_convert,
value = as.list(newdata$data(rows = newdata$rownames, cols = ci$id)),
id = ci$id, type = ci$type, levels = ci$levels))

tab = cbind(tab1, tab2)
if (ncol(tab)) {
tab[[newdata$primary_key]] = newdata$rownames
newdata = DataBackendCbind$new(newdata, DataBackendDataTable$new(tab, primary_key = newdata$primary_key))
}

# do some type conversions if necessary
prevci = task$col_info
task$backend = newdata
task$col_info = col_info(task$backend)
task$col_info[, c("label", "fix_factor_levels")] = prevci[list(task$col_info$id), on = "id", c("label", "fix_factor_levels")]
task$col_info$fix_factor_levels[is.na(task$col_info$fix_factor_levels)] = FALSE
task$row_roles$use = task$backend$rownames
self$predict(task)
},
Expand Down
6 changes: 3 additions & 3 deletions man/Learner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/mlr_learners_classif.featureless.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/mlr_learners_regr.featureless.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

39 changes: 39 additions & 0 deletions tests/testthat/test_Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -677,3 +677,42 @@ test_that("predict_newdata auto conversion (#685)", {

expect_equal(p1, p2)
})

test_that("predict_newdata creates column info correctly", {
  # Train a debug learner on an iris task whose column info carries custom
  # labels and fix_factor_levels flags. `save_tasks = TRUE` lets the test
  # inspect the prediction task through `$model$task_predict` afterwards.
  debug_learner = lrn("classif.debug", save_tasks = TRUE)
  trained_task = tsk("iris")
  trained_task$col_info$label = letters[1:6]
  trained_task$col_info$fix_factor_levels = c(TRUE, TRUE, FALSE, TRUE, FALSE, TRUE)
  debug_learner$train(trained_task)

  # Column info of the task used by the most recent prediction call.
  stored_col_info = function() {
    debug_learner$model$task_predict$col_info
  }

  # Two iris rows shared by all scenarios below.
  new_rows = iris[10:11, ]

  ## data.frame is passed without task:
  ## column info of the training task is kept, row ids restart at 1
  pred_plain = debug_learner$predict_newdata(new_rows)
  expect_equal(stored_col_info(), trained_task$col_info)
  expect_equal(pred_plain$row_ids, 1:2)

  ## backend is passed without task:
  ## row ids are taken from the backend's primary key
  keyed_rows = new_rows
  keyed_rows$..row_id = 10:11
  keyed_backend = as_data_backend(keyed_rows, primary_key = "..row_id")
  pred_backend = debug_learner$predict_newdata(keyed_backend)
  expect_equal(pred_backend$row_ids, 10:11)
  expect_equal(stored_col_info(), trained_task$col_info)

  ## data.frame is passed with task:
  ## column info of the explicitly supplied task is used
  other_task = tsk("iris")
  debug_learner$predict_newdata(new_rows, other_task)
  expect_equal(stored_col_info(), other_task$col_info)

  ## backend is passed with task
  debug_learner$predict_newdata(keyed_backend, other_task)
  expect_equal(stored_col_info(), other_task$col_info)

  ## backend with different name for primary key:
  ## the custom key column must show up in the stored column info
  renamed_rows = new_rows
  renamed_rows$row_id = 10:11
  renamed_backend = as_data_backend(renamed_rows, primary_key = "row_id")
  pred_renamed = debug_learner$predict_newdata(renamed_backend, other_task)
  expect_equal(pred_renamed$row_ids, 10:11)
  expect_true("row_id" %in% stored_col_info()$id)
})

0 comments on commit d753a24

Please sign in to comment.