From 77dea16bf203b464619639eb0919b6d9459e9baf Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 25 Jan 2022 14:18:58 +1300 Subject: [PATCH 1/4] fix iteration_parameter logic to fix #40 --- src/constructors.jl | 33 +++++++++++++++++++++++++-------- src/core.jl | 6 +----- src/ic_model.jl | 2 +- test/constructors.jl | 21 ++++++++++++++++----- 4 files changed, 43 insertions(+), 19 deletions(-) diff --git a/src/constructors.jl b/src/constructors.jl index ff2ef8b..6ccf9e0 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -1,8 +1,3 @@ -const ERR_MISSING_TRAINING_CONTROL = - ArgumentError("At least one control must be a training control "* - "(have type `$TrainingControl`) or be a "* - "custom control that calls IterationControl.train!. ") - const IterationResamplingTypes = Union{Holdout,Nothing,MLJBase.TrainTestPairs} @@ -37,6 +32,11 @@ mutable struct ProbabilisticIteratedModel{M<:Probabilistic} <: MLJBase.Probabili cache::Bool end +const ERR_MISSING_TRAINING_CONTROL = + ArgumentError("At least one control must be a training control "* + "(have type `$TrainingControl`) or be a "* + "custom control that calls IterationControl.train!. ") + const EitherIteratedModel{M} = Union{DeterministicIteratedModel{M},ProbabilisticIteratedModel{M}} @@ -54,6 +54,9 @@ const ERR_NEED_PARAMETER = "must be a `Symbol` or, in the case of a nested parameter, "* "an `Expr` (as in `booster.nrounds`). ") +err_bad_iteration_parameter(p) = + ArgumentError("Model to be iterated does not have :($p) as an iteration parameter. ") + """ IteratedModel(model=nothing, controls=$CONTROLS_DEFAULT, @@ -220,6 +223,8 @@ function IteratedModel(; model=nothing, end + + function MLJBase.clean!(iterated_model::EitherIteratedModel) message = "" if iterated_model.measure === nothing && @@ -232,9 +237,21 @@ function MLJBase.clean!(iterated_model::EitherIteratedModel) "Setting measure=$(iterated_model.measure). " end end - iterated_model.iteration_parameter === nothing && - iteration_parameter(iterated_model.model) === nothing && - throw(ERR_NEED_PARAMETER) + if iterated_model.iteration_parameter === nothing + iterated_model.iteration_parameter = iteration_parameter(iterated_model.model) + if iterated_model.iteration_parameter === nothing + throw(ERR_NEED_PARAMETER) + else + message *= "No iteration parameter specified. "* + "Setting iteration_parameter=:($(iterated_model.iteration_parameter)). " + end + end + try + MLJBase.recursive_getproperty(iterated_model.model, + iterated_model.iteration_parameter) + catch + throw(err_bad_iteration_parameter(iterated_model.iteration_parameter)) + end resampling = iterated_model.resampling diff --git a/src/core.jl b/src/core.jl index 3888672..bd376bb 100644 --- a/src/core.jl +++ b/src/core.jl @@ -51,11 +51,7 @@ end function MLJBase.fit(iterated_model::EitherIteratedModel, verbosity, data...) model = deepcopy(iterated_model.model) - - # get name of iteration parameter: - _iter = MLJBase.iteration_parameter(model) - iteration_param = _iter === nothing ? - iterated_model.iteration_parameter : _iter + iteration_param = iterated_model.iteration_parameter # instantiate `train_mach`: mach = if iterated_model.resampling === nothing diff --git a/src/ic_model.jl b/src/ic_model.jl index c870553..e69cbce 100644 --- a/src/ic_model.jl +++ b/src/ic_model.jl @@ -52,7 +52,7 @@ end # overloading `expose`- for `resampling === nothing`: IterationControl.expose(ic_model::ICModel) = ic_model.mach -# overloading `expose`- for `resampling isa Holdout` or +# overloading `expose`- for `resampling isa Holdout` or # other resampling strategy: IterationControl.expose(ic_model::ICModel{<:Machine{<:Resampler}}) = MLJBase.fitted_params(ic_model.mach).machine diff --git a/test/constructors.jl b/test/constructors.jl index 0f25975..76e9967 100644 --- a/test/constructors.jl +++ b/test/constructors.jl @@ -7,6 +7,7 @@ using Test struct Foo <: MLJBase.Unsupervised end struct Bar <: MLJBase.Deterministic end +struct FooBar <: MLJBase.Deterministic end @testset "constructors" begin model = DummyIterativeModel() @@ -15,13 +16,17 @@ struct Bar <: MLJBase.Deterministic end @test_throws MLJIteration.ERR_NOT_SUPERVISED IteratedModel(model=Int) @test_throws MLJIteration.ERR_NEED_MEASURE IteratedModel(model=Bar()) @test_throws MLJIteration.ERR_NEED_PARAMETER IteratedModel(model=Bar(), - measure=rms) - iterated_model = @test_logs((:info, r"No measure"), + measure=rms) + iterated_model = @test_logs((:info, "No measure specified. Setting "* + "measure=RootMeanSquaredError(). No "* + "iteration parameter specified. "* + "Setting iteration_parameter=:(n). "), IteratedModel(model=model)) @test iterated_model.measure == RootMeanSquaredError() - @test_logs IteratedModel(model=model, measure=mae) + @test iterated_model.iteration_parameter == :n + @test_logs IteratedModel(model=model, measure=mae, iteration_parameter=:n) - @test_logs IteratedModel(model=model, resampling=nothing) + @test_logs IteratedModel(model=model, resampling=nothing, iteration_parameter=:n) @test_logs((:info, r"`resampling` must be"), IteratedModel(model=model, @@ -34,12 +39,18 @@ struct Bar <: MLJBase.Deterministic end measure=rms)) @test_logs IteratedModel(model=model, resampling=[([1, 2], [3, 4]),], - measure=rms) + measure=rms, + iteration_parameter=:n) @test_throws(MLJIteration.ERR_MISSING_TRAINING_CONTROL, IteratedModel(model=model, resampling=nothing, controls=[Patience(), InvalidValue()])) + + @test_throws(MLJIteration.err_bad_iteration_parameter(:goo), + IteratedModel(model=model, + measure=mae, + iteration_parameter=:goo)) end end From 7293d71d69811cdd2d4728782853572c7a455cd8 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 25 Jan 2022 14:35:57 +1300 Subject: [PATCH 2/4] address #41 --- src/constructors.jl | 30 +++++++++++++++++++++--------- test/constructors.jl | 4 +++- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/constructors.jl b/src/constructors.jl index 6ccf9e0..6c7638a 100644 --- a/src/constructors.jl +++ b/src/constructors.jl @@ -37,11 +37,10 @@ const ERR_MISSING_TRAINING_CONTROL = "(have type `$TrainingControl`) or be a "* "custom control that calls IterationControl.train!. ") +const ERR_TOO_MANY_ARGUMENTS = + ArgumentError("At most one non-keyword argument allowed. ") const EitherIteratedModel{M} = Union{DeterministicIteratedModel{M},ProbabilisticIteratedModel{M}} - -const ERR_NO_MODEL = - ArgumentError("You need to specify model=... ") const ERR_NOT_SUPERVISED = ArgumentError("Only `Deterministic` and `Probabilistic` "* "model types supported.") @@ -53,6 +52,10 @@ const ERR_NEED_PARAMETER = "parameter. Please specify `iteration_parameter=...`. This "* "must be a `Symbol` or, in the case of a nested parameter, "* "an `Expr` (as in `booster.nrounds`). ") +const ERR_MODEL_UNSPECIFIED = ArgumentError( +"Expecting atomic model as argument, or as keyword argument `model=...`, "* + "but neither detected. ") + err_bad_iteration_parameter(p) = ArgumentError("Model to be iterated does not have :($p) as an iteration parameter. ") @@ -172,7 +175,8 @@ updated to the last value used in the preceding `fit!(mach)` call. Then repeated application of the (updated) controls begin anew. """ -function IteratedModel(; model=nothing, +function IteratedModel(args...; + model=nothing, control=CONTROLS_DEFAULT, controls=control, resampling=MLJBase.Holdout(), @@ -186,10 +190,18 @@ function IteratedModel(; model=nothing, iteration_parameter=nothing, cache=true) - model == nothing && throw(ERR_NO_MODEL) + length(args) < 2 || throw(ArgumentError("At most one non-keyword argument allowed. ")) + if length(args) === 1 + atom = first(args) + model === nothing || + @warn "Using `model=$atom`. Ignoring specification `model=$model`. " + else + model === nothing && throw(ERR_MODEL_UNSPECIFIED) + atom = model + end - if model isa Deterministic - iterated_model = DeterministicIteratedModel(model, + if atom isa Deterministic + iterated_model = DeterministicIteratedModel(atom, controls, resampling, measure, @@ -200,8 +212,8 @@ function IteratedModel(; model=nothing, check_measure, iteration_parameter, cache) - elseif model isa Probabilistic - iterated_model = ProbabilisticIteratedModel(model, + elseif atom isa Probabilistic + iterated_model = ProbabilisticIteratedModel(atom, controls, resampling, measure, diff --git a/test/constructors.jl b/test/constructors.jl index 76e9967..d6e2a73 100644 --- a/test/constructors.jl +++ b/test/constructors.jl @@ -11,7 +11,8 @@ struct FooBar <: MLJBase.Deterministic end @testset "constructors" begin model = DummyIterativeModel() - @test_throws MLJIteration.ERR_NO_MODEL IteratedModel() + @test_throws MLJIteration.ERR_TOO_MANY_ARGUMENTS IteratedModel(1, 2) + @test_throws MLJIteration.ERR_MODEL_UNSPECIFIED IteratedModel() @test_throws MLJIteration.ERR_NOT_SUPERVISED IteratedModel(model=Foo()) @test_throws MLJIteration.ERR_NOT_SUPERVISED IteratedModel(model=Int) @test_throws MLJIteration.ERR_NEED_MEASURE IteratedModel(model=Bar()) @@ -25,6 +26,7 @@ struct FooBar <: MLJBase.Deterministic end @test iterated_model.measure == RootMeanSquaredError() @test iterated_model.iteration_parameter == :n @test_logs IteratedModel(model=model, measure=mae, iteration_parameter=:n) + @test_logs IteratedModel(model, measure=mae, iteration_parameter=:n) @test_logs IteratedModel(model=model, resampling=nothing, iteration_parameter=:n) From 3ac47eaa84ef41100d6d45f84ae13c452ebf945b Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 25 Jan 2022 15:03:24 +1300 Subject: [PATCH 3/4] bump julia compat (and ci) to julia = "1.6" --- .github/workflows/ci.yml | 2 +- Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 91fb6f7..9b5d87f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,7 +17,7 @@ jobs: fail-fast: false matrix: version: - - '1.0' + - '1.6' - '1' os: - ubuntu-latest diff --git a/Project.toml b/Project.toml index dfebd97..bd27837 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] IterationControl = "0.5" MLJBase = "0.18.8, 0.19" -julia = "1" +julia = "1.6" [extras] MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" From 8915445cad27ccaa71bc9064d4496043a38a797b Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Fri, 28 Jan 2022 08:19:22 +1300 Subject: [PATCH 4/4] bump 0.4.2 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index bd27837..6562309 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJIteration" uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55" authors = ["Anthony D. Blaom "] -version = "0.4.1" +version = "0.4.2" [deps] IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"