Skip to content

Commit

Permalink
refactor: use more set_values() (#1239)
Browse files Browse the repository at this point in the history
* refactor: use more set_values()

* docs: adjust text to reference set_values for changing param set
  • Loading branch information
m-muecke authored Jan 6, 2025
1 parent 47a3a51 commit 9c95317
Show file tree
Hide file tree
Showing 19 changed files with 21 additions and 21 deletions.
4 changes: 2 additions & 2 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@
#' All information about hyperparameters is stored in the slot `param_set` which is a [paradox::ParamSet].
#' The printer gives an overview about the ids of available hyperparameters, their storage type, lower and upper bounds,
#' possible levels (for factors), default values and assigned values.
#' To set hyperparameters, assign a named list to the subslot `values`:
#' To set hyperparameters, call the `set_values()` method on the `param_set`:
#' ```
#' lrn = lrn("classif.rpart")
#' lrn$param_set$values = list(minsplit = 3, cp = 0.01)
#' lrn$param_set$set_values(minsplit = 3, cp = 0.01)
#' ```
#' Note that this operation replaces all previously set hyperparameter values.
#' If you only intend to change one specific hyperparameter value and leave the others as-is, you can use the helper function [mlr3misc::insert_named()]:
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClassifDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
#' @export
#' @examples
#' learner = lrn("classif.debug")
#' learner$param_set$values = list(message_train = 1, save_tasks = TRUE)
#' learner$param_set$set_values(message_train = 1, save_tasks = TRUE)
#'
#' # this should signal a message
#' task = tsk("penguins")
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClassifFeatureless.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ LearnerClassifFeatureless = R6Class("LearnerClassifFeatureless", inherit = Learn
ps = ps(
method = p_fct(c("mode", "sample", "weighted.sample"), default = "mode", tags = "predict")
)
ps$values = list(method = "mode")
ps$set_values(method = "mode")
super$initialize(
id = "classif.featureless",
feature_types = mlr_reflections$task_feature_types,
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerClassifRpart.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ LearnerClassifRpart = R6Class("LearnerClassifRpart", inherit = LearnerClassif,
usesurrogate = p_int(0L, 2L, default = 2L, tags = "train"),
xval = p_int(0L, default = 10L, tags = "train")
)
ps$values = list(xval = 0L)
ps$set_values(xval = 0L)

super$initialize(
id = "classif.rpart",
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerRegrFeatureless.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ LearnerRegrFeatureless = R6Class("LearnerRegrFeatureless", inherit = LearnerRegr
ps = ps(
robust = p_lgl(default = TRUE, tags = "train")
)
ps$values = list(robust = FALSE)
ps$set_values(robust = FALSE)

super$initialize(
id = "regr.featureless",
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerRegrRpart.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ LearnerRegrRpart = R6Class("LearnerRegrRpart", inherit = LearnerRegr,
usesurrogate = p_int(0L, 2L, default = 2L, tags = "train"),
xval = p_int(0L, default = 10L, tags = "train")
)
ps$values = list(xval = 0L)
ps$set_values(xval = 0L)

super$initialize(
id = "regr.rpart",
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureClassifCosts.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ MeasureClassifCosts = R6Class("MeasureClassifCosts",
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
param_set = ps(normalize = p_lgl(tags = "required"))
param_set$values = list(normalize = TRUE)
param_set$set_values(normalize = TRUE)

super$initialize(
id = "classif.costs",
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ MeasureDebugClassif = R6Class("MeasureDebugClassif",
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
param_set = ps(na_ratio = p_dbl(0, 1, tags = "required"))
param_set$values = list(na_ratio = 0)
param_set$set_values(na_ratio = 0)
super$initialize(
id = "debug_classif",
param_set = param_set,
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureSelectedFeatures.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ MeasureSelectedFeatures = R6Class("MeasureSelectedFeatures",
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
param_set = ps(normalize = p_lgl(tags = "required"))
param_set$values = list(normalize = FALSE)
param_set$set_values(normalize = FALSE)

super$initialize(
id = "selected_features",
Expand Down
2 changes: 1 addition & 1 deletion R/Resampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
#' r$param_set$values
#'
#' # Do only 3 repeats on 10% of the data
#' r$param_set$values = list(ratio = 0.1, repeats = 3)
#' r$param_set$set_values(ratio = 0.1, repeats = 3)
#' r$param_set$values
#'
#' # Instantiate on penguins task
Expand Down
2 changes: 1 addition & 1 deletion R/ResamplingBootstrap.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ ResamplingBootstrap = R6Class("ResamplingBootstrap", inherit = Resampling,
ratio = p_dbl(0, upper = 1, tags = "required"),
repeats = p_int(1L, tags = "required")
)
ps$values = list(ratio = 1, repeats = 30L)
ps$set_values(ratio = 1, repeats = 30L)

super$initialize(id = "bootstrap", param_set = ps, duplicated_ids = TRUE,
label = "Bootstrap", man = "mlr3::mlr_resamplings_bootstrap")
Expand Down
2 changes: 1 addition & 1 deletion R/ResamplingCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ ResamplingCV = R6Class("ResamplingCV", inherit = Resampling,
ps = ps(
folds = p_int(2L, tags = "required")
)
ps$values = list(folds = 10L)
ps$set_values(folds = 10L)

super$initialize(id = "cv", param_set = ps,
label = "Cross-Validation", man = "mlr3::mlr_resamplings_cv")
Expand Down
2 changes: 1 addition & 1 deletion R/ResamplingHoldout.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ ResamplingHoldout = R6Class("ResamplingHoldout", inherit = Resampling,
ps = ps(
ratio = p_dbl(0, 1, tags = "required")
)
ps$values = list(ratio = 2 / 3)
ps$set_values(ratio = 2 / 3)

super$initialize(id = "holdout", param_set = ps,
label = "Holdout", man = "mlr3::mlr_resamplings_holdout")
Expand Down
2 changes: 1 addition & 1 deletion R/ResamplingRepeatedCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ ResamplingRepeatedCV = R6Class("ResamplingRepeatedCV", inherit = Resampling,
folds = p_int(2L, tags = "required"),
repeats = p_int(1L)
)
ps$values = list(repeats = 10L, folds = 10L)
ps$set_values(repeats = 10L, folds = 10L)
super$initialize(id = "repeated_cv", param_set = ps,
label = "Repeated Cross-Validation", man = "mlr3::mlr_resamplings_repeated_cv")
},
Expand Down
2 changes: 1 addition & 1 deletion R/ResamplingSubsampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ ResamplingSubsampling = R6Class("ResamplingSubsampling", inherit = Resampling,
ratio = p_dbl(0, 1, tags = "required"),
repeats = p_int(1, tags = "required")
)
ps$values = list(repeats = 30L, ratio = 2 / 3)
ps$set_values(repeats = 30L, ratio = 2 / 3)

super$initialize(id = "subsampling", param_set = ps,
label = "Subsampling", man = "mlr3::mlr_resamplings_subsampling")
Expand Down
2 changes: 1 addition & 1 deletion R/TaskGeneratorMoons.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ TaskGeneratorMoons = R6Class("TaskGeneratorMoons",
ps = ps(
sigma = p_dbl(0, tags = "required")
)
ps$values = list(sigma = 1)
ps$set_values(sigma = 1)

super$initialize(id = "moons", task_type = "classif", param_set = ps,
label = "Moons Classification", man = "mlr3::mlr_task_generators_moons")
Expand Down
4 changes: 2 additions & 2 deletions man/Learner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/Resampling.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/mlr_learners_classif.debug.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 9c95317

Please sign in to comment.