Methods for statistical inference on the generalization error.
Install the package from GitHub:
pak::pkg_install("mlr-org/mlr3inferr")
The main purpose of the package is to provide confidence intervals for
the generalization error for a number of resampling methods. Below, we
evaluate a decision tree on the sonar task using holdout resampling and
obtain a confidence interval for the generalization error. This is
achieved with the msr("ci.holdout") measure, to which we pass another
mlr3::Measure that determines the loss function.
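library(mlr3)
library(mlr3inferr)
# evaluate a decision tree on the sonar task with holdout resampling
rr = resample(tsk("sonar"), lrn("classif.rpart"), rsmp("holdout"))
# "classif.acc" determines the loss; alpha = 0.05 is also the default
ci = msr("ci.holdout", "classif.acc", alpha = 0.05)
rr$aggregate(ci)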
#> classif.acc classif.acc.lower classif.acc.upper
#> 0.7391304 0.6347628 0.8434981
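To make the construction concrete, here is a base R sketch of a normal-approximation (Wald) interval computed from the pointwise scores on the holdout test set. That ci.holdout works exactly this way is an assumption on our part, and the helper holdout_ci() is hypothetical. The printed accuracy 0.7391304 equals 51/69; assuming 69 test observations (sonar has 208 rows, and the default holdout ratio is 2/3), we feed in 51 correct and 18 incorrect predictions:

```r
# Hypothetical helper: Wald (normal-approximation) interval for the mean of
# pointwise scores, assumed here to mirror what ci.holdout computes.
holdout_ci = function(pointwise, alpha = 0.05) {
  n = length(pointwise)
  estimate = mean(pointwise)
  # standard error of the mean, based on the sample standard deviation
  se = sd(pointwise) / sqrt(n)
  z = qnorm(1 - alpha / 2)
  c(estimate = estimate, lower = estimate - z * se, upper = estimate + z * se)
}

# 0/1 indicators of correct predictions: 51 of 69 test observations correct
correct = c(rep(1, 51), rep(0, 18))
round(holdout_ci(correct), 7)
#>  estimate     lower     upper
#> 0.7391304 0.6347628 0.8434981
```

Because the standard error scales with 1/sqrt(n), halving the interval width requires roughly four times as many test observations.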
It is also possible to select the default inference method for the
resampling method that was used via msr("ci"):
ci_default = msr("ci", "classif.acc")
rr$aggregate(ci_default)
#> classif.acc classif.acc.lower classif.acc.upper
#> 0.7391304 0.6347628 0.8434981
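With mlr3viz, it is also possible to visualize multiple confidence
intervals. Below, we compare a random forest, a decision tree, and a
featureless learner on the sonar and german_credit tasks using
subsampling, and plot confidence intervals for the classification error: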
library(mlr3learners)
library(mlr3viz)
bmr = benchmark(benchmark_grid(
  tsks(c("sonar", "german_credit")),
  lrns(c("classif.rpart", "classif.ranger", "classif.featureless")),
  rsmp("subsampling")
))
autoplot(bmr, "ci", msr("ci", "classif.ce"))
Note that:
- Some methods require pointwise loss functions, i.e. measures that have an $obs_loss field.
- Not every resampling method has an inference method.
- For some combinations of datasets and learners, inference methods can fail.
In short, the package provides:
- Additional resampling methods
- Confidence intervals for the generalization error for some resampling methods
The following code lists the available inference methods, the resampling methods each one supports, and whether it requires a pointwise loss function:
library(mlr3misc)
library(data.table)
# collect all confidence-interval measures (keys starting with "ci.")
content = as.data.table(mlr3::mlr_measures, objects = TRUE)[startsWith(get("key"), "ci."), ]
# supported resampling classes, without the "Resampling" prefix
content$resamplings = map_chr(content$object, function(x) paste0(gsub("Resampling", "", x$resamplings), collapse = ", "))
# whether the method needs a measure with pointwise losses ($obs_loss)
content[["only pointwise loss"]] = map_chr(content$object, function(object) {
  if (get_private(object)$.requires_obs_loss) "yes" else "no"
})
content = content[, c("key", "label", "resamplings", "only pointwise loss")]
knitr::kable(content, format = "markdown", col.names = tools::toTitleCase(names(content)))

| Key | Label | Resamplings | Only Pointwise Loss |
|---|---|---|---|
| ci.con_z | Conservative-Z CI | PairedSubsampling | no |
| ci.cor_t | Corrected-T CI | Subsampling | no |
| ci.holdout | Holdout CI | Holdout | yes |
| ci.wald_cv | Naive CV CI | CV, LOO | yes |
| ci.ncv | Nested CV CI | NestedCV | yes |
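As an illustration of one entry in the table, here is a base R sketch of the corrected resampled t interval of Nadeau and Bengio, which we assume is the idea behind ci.cor_t; the package's actual implementation may differ in details. It works on one aggregated performance value per subsampling iteration, so it does not need pointwise losses. The helper corrected_t_ci() and the error values are hypothetical:

```r
# Hypothetical helper: corrected resampled t interval (Nadeau and Bengio),
# assumed here to correspond to ci.cor_t for subsampling.
# estimates: one performance value per subsampling iteration
corrected_t_ci = function(estimates, n_train, n_test, alpha = 0.05) {
  J = length(estimates)
  estimate = mean(estimates)
  # the naive variance var / J ignores that training sets overlap across
  # iterations; the correction adds n_test / n_train to the 1 / J factor
  se = sqrt((1 / J + n_test / n_train) * var(estimates))
  q = qt(1 - alpha / 2, df = J - 1)
  c(estimate = estimate, lower = estimate - q * se, upper = estimate + q * se)
}

# Made-up classification errors from 10 subsampling iterations,
# each with 139 training and 69 test observations
errors = c(0.26, 0.22, 0.29, 0.25, 0.23, 0.28, 0.24, 0.27, 0.21, 0.25)
corrected_t_ci(errors, n_train = 139, n_test = 69)
```

Compared with the naive t interval, which uses only the 1 / J factor, the correction keeps the interval from shrinking to zero width as the number of iterations grows.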
mlr3inferr is a free and open source software project that encourages participation and feedback. If you have any issues, questions, suggestions or feedback, please do not hesitate to open an “issue” about it on the GitHub page!
In case of problems / bugs, it is often helpful if you provide a “minimum working example” that showcases the behaviour (but don’t worry about this if the bug is obvious).
Please understand that the resources of the project are limited: response may sometimes be delayed by a few days, and some feature suggestions may be rejected if they are deemed too tangential to the vision behind the project.