Merge pull request #63 from JuliaAI/dev
For a 0.6.2 release
ablaom authored Jun 3, 2024
2 parents 4de6d32 + 99a5dd2 commit 387c7c2
Showing 7 changed files with 177 additions and 127 deletions.
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "MLJIteration"
uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
authors = ["Anthony D. Blaom <[email protected]>"]
version = "0.6.1"
version = "0.6.2"

[deps]
IterationControl = "b3c1a2ee-3fec-4384-bf48-272ea71de57c"
@@ -11,7 +11,7 @@ Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

[compat]
IterationControl = "0.5"
MLJBase = "1"
MLJBase = "1.4"
julia = "1.6"

[extras]
10 changes: 5 additions & 5 deletions README.md
@@ -2,11 +2,11 @@

| Linux | Coverage | Documentation |
| :-----------: | :------: | :-------:|
| [![Build status](https://github.com/JuliaAI/MLJIteration.jl/workflows/CI/badge.svg)](https://github.com/JuliaAI/MLJIteration.jl/actions)| [![codecov.io](http://codecov.io/github/JuliaAI/MLJIteration.jl/coverage.svg?branch=master)](http://codecov.io/github/JuliaAI/MLJIteration.jl?branch=master) | [![docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/)|
| [![Build status](https://github.com/JuliaAI/MLJIteration.jl/workflows/CI/badge.svg)](https://github.com/JuliaAI/MLJIteration.jl/actions)| [![codecov.io](http://codecov.io/github/JuliaAI/MLJIteration.jl/coverage.svg?branch=master)](http://codecov.io/github/JuliaAI/MLJIteration.jl?branch=master) | [![docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://JuliaAI.github.io/MLJ.jl/dev/controlling_iterative_models/)|


A package for wrapping iterative models provided by the
[MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) machine
[MLJ](https://JuliaAI.github.io/MLJ.jl/dev/) machine
learning framework in a control strategy.

Builds on the generic iteration control tool
@@ -16,7 +16,7 @@ Builds on the generic iteration control tool
## Installation

Included as part of
[MLJ installation](https://alan-turing-institute.github.io/MLJ.jl/dev/#Installation-1).
[MLJ installation](https://JuliaAI.github.io/MLJ.jl/dev/#Installation-1).

Alternatively, for a "minimal" installation:
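
The command itself is collapsed in this diff view. Presumably it is the standard Pkg invocation; a sketch, assuming the registered package name `MLJIteration`:

```julia
using Pkg
Pkg.add("MLJIteration")
```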

@@ -54,7 +54,7 @@ mach = machine(iterated_model, X, y) |> fit!;
## Documentation

See the [Controlling Iterative
Models](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/)
Models](https://JuliaAI.github.io/MLJ.jl/dev/controlling_iterative_models/)
section of the [MLJ
manual](https://alan-turing-institute.github.io/MLJ.jl/dev/).
manual](https://JuliaAI.github.io/MLJ.jl/dev/).

6 changes: 2 additions & 4 deletions src/MLJIteration.jl
@@ -17,15 +17,15 @@ const CONTROLS = vcat(IterationControl.CONTROLS,
:WithModelDo,
:CycleLearningRate,
:Save])

const CONTROLS_LIST = join(map(c->"$c()", CONTROLS), ", ", " and ")
const TRAINING_CONTROLS = [:Step, ]

# export all control types:
for control in CONTROLS
eval(:(export $control))
end

const CONTROLS_DEFAULT = [Step(1),
const DEFAULT_CONTROLS = [Step(1),
Patience(5),
GL(),
TimeLimit(0.03), # about 2 mins
@@ -42,6 +42,4 @@ include("traits.jl")
include("ic_model.jl")
include("core.jl")



end # module
251 changes: 141 additions & 110 deletions src/constructors.jl
@@ -1,5 +1,5 @@
const IterationResamplingTypes =
Union{Holdout,Nothing,MLJBase.TrainTestPairs}
Union{Holdout,InSample,Nothing,MLJBase.TrainTestPairs}


## TYPES AND CONSTRUCTOR
@@ -72,96 +72,119 @@ err_bad_iteration_parameter(p) =
ArgumentError("Model to be iterated does not have :($p) as an iteration parameter. ")

"""
IteratedModel(model=nothing,
controls=$CONTROLS_DEFAULT,
retrain=false,
resampling=Holdout(),
measure=nothing,
weights=nothing,
class_weights=nothing,
operation=predict,
verbosity=1,
check_measure=true,
iteration_parameter=nothing,
cache=true)
Wrap the specified `model <: Supervised` in the specified iteration
`controls`. Training a machine bound to the wrapper iterates a
corresponding machine bound to `model`. Here `model` should support
iteration.
To list all controls, do `MLJIteration.CONTROLS`. Controls are
summarized at
[https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/](https://alan-turing-institute.github.io/MLJ.jl/dev/controlling_iterative_models/)
but query individual doc-strings for details and advanced options. For
creating your own controls, refer to the documentation just cited.
To make out-of-sample losses available to the controls, the machine
bound to `model` is only trained on part of the data, as iteration
proceeds. See details on training below. Specify `retrain=true`
to ensure the model is retrained on *all* available data, using the
same number of iterations, once controlled iteration has stopped.
Specify `resampling=nothing` if all data is to be used for controlled
iteration, with each out-of-sample loss replaced by the most recent
training loss, assuming this is made available by the model
(`supports_training_losses(model) == true`). Otherwise, `resampling`
must have type `Holdout` (eg, `Holdout(fraction_train=0.8, rng=123)`).
Assuming `retrain=true` or `resampling=nothing`,
`iterated_model` behaves exactly like the original `model` but with
the iteration parameter automatically selected. If
`retrain=false` (default) and `resampling` is not `nothing`, then
`iterated_model` behaves like the original model trained on a subset
of the provided data.
Controlled iteration can be continued with new `fit!` calls (warm
restart) by mutating a control, or by mutating the iteration parameter
of `model`, which is otherwise ignored.
### Training
Given an instance `iterated_model` of `IteratedModel`, calling
`fit!(mach)` on a machine `mach = machine(iterated_model, data...)`
performs the following actions:
- Assuming `resampling !== nothing`, the `data` is split into *train* and
*test* sets, according to the specified `resampling` strategy, which
must have type `Holdout`.
- A clone of the wrapped model, `iterated_model.model`, is bound to
the train data in an internal machine, `train_mach`. If `resampling
=== nothing`, all data is used instead. This machine is the object
to which controls are applied. For example, `Callback(fitted_params
|> print)` will print the value of `fitted_params(train_mach)`.
IteratedModel(model;
controls=MLJIteration.DEFAULT_CONTROLS,
resampling=Holdout(),
measure=nothing,
retrain=false,
advanced_options...,
)
Wrap the specified supervised `model` in the specified iteration `controls`. Here `model`
should support iteration, which is true if `iteration_parameter(model)` is different from
`nothing`.
Available controls: $CONTROLS_LIST.
!!! important
To make out-of-sample losses available to the controls, the wrapped `model` is only
trained on part of the data, as iteration proceeds. The user may want to force
retraining on all data after controlled iteration has finished by specifying
`retrain=true`. See also "Training", and the `retrain` option, under "Extended help"
below.
# Extended help
# Options
- `controls=$DEFAULT_CONTROLS`: Controls are summarized at
[https://JuliaAI.github.io/MLJ.jl/dev/controlling_iterative_models/](https://JuliaAI.github.io/MLJ.jl/dev/controlling_iterative_models/)
but query individual doc-strings for details and advanced options. For creating your own
controls, refer to the documentation just cited. A combined sketch appears after this list.
- `resampling=Holdout(fraction_train=0.7)`: The default resampling holds back 30% of data
for computing an out-of-sample estimate of performance (the "loss") for loss-based
controls such as `WithLossDo`. Specify `resampling=nothing` if all data is to be used
for controlled iteration, with each out-of-sample loss replaced by the most recent
training loss, assuming this is made available by the model
(`supports_training_losses(model) == true`). If the model does not report a training
loss, you can use `resampling=InSample()` instead. Otherwise, `resampling` must have
type `Holdout` or be a vector with one element of the form `(train_indices,
test_indices)`.
- `measure=nothing`: StatisticalMeasures.jl compatible measure for estimating model
performance (the "loss", but the orientation is immaterial - i.e., this could be a
score). Inferred by default. Ignored if `resampling=nothing`.
- `retrain=false`: If `retrain=true` or `resampling=nothing`, `iterated_model` behaves
exactly like the original `model` but with the iteration parameter automatically
selected ("learned"). That is, the model is retrained on *all* available data, using the
same number of iterations, once controlled iteration has stopped. This is typically
desired if wrapping the iterated model further, or when inserting in a pipeline or other
composite model. If `retrain=false` (default) and `resampling isa Holdout`, then
`iterated_model` behaves like the original model trained on a subset of the provided
data.
- `weights=nothing`: per-observation weights to be passed to `measure` where supported; if
unspecified, these are understood to be uniform.
- `class_weights=nothing`: class-weights to be passed to `measure` where supported; if
unspecified, these are understood to be uniform.
- `operation=nothing`: Operation, such as `predict` or `predict_mode`, for computing
target values, or proxy target values, for consumption by `measure`; automatically
inferred by default.
- `check_measure=true`: Specify `false` to override checks on `measure` for compatibility
with the training data.
- `iteration_parameter=nothing`: A symbol, such as `:epochs`, naming the iteration
parameter of `model`; inferred by default. Note that the actual value of the iteration
parameter in the supplied `model` is ignored; only the value of an internal clone is
mutated as the wrapped model is trained.
- `cache=true`: Whether or not model-specific representations of data are cached in
between iteration parameter increments; specify `cache=false` to prioritize memory over
speed.
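
To make these options concrete, here is a sketch combining several of them. It assumes MLJ and EvoTrees are installed; `EvoTreeRegressor` (iteration parameter `nrounds`) is not part of the original text and stands in for any model supporting iteration:

```julia
using MLJ

# hypothetical iterative atom; any model with a non-`nothing`
# `iteration_parameter` would serve:
EvoTreeRegressor = @load EvoTreeRegressor pkg=EvoTrees verbosity=0
model = EvoTreeRegressor()

iterated_model = IteratedModel(
    model;
    controls=[Step(2), Patience(4), NumberLimit(200)],
    resampling=Holdout(fraction_train=0.8),
    measure=rms,   # stated explicitly, though it would be inferred by default
    retrain=true,  # retrain on all data once a control triggers a stop
)
```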
# Training
Training an instance `iterated_model` of `IteratedModel` on some `data` (by binding to a
machine and calling `fit!`, for example) performs the following actions:
- Assuming `resampling !== nothing`, the `data` is split into *train* and *test* sets,
according to the specified `resampling` strategy.
- A clone of the wrapped model, `model`, is bound to the train data in an internal machine,
`train_mach`. If `resampling === nothing`, all data is used instead. This machine is the
object to which controls are applied. For example, `Callback(fitted_params |> print)`
will print the value of `fitted_params(train_mach)`.
- The iteration parameter of the clone is set to `0`.
- The specified `controls` are repeatedly applied to `train_mach` in
sequence, until one of the controls triggers a stop. Loss-based
controls (eg, `Patience()`, `GL()`, `Threshold(0.001)`) use an
out-of-sample loss, obtained by applying `measure` to predictions
and the test target values. (Specifically, these predictions are
those returned by `operation(train_mach)`.) If `resampling ===
nothing` then the most recent training loss is used instead. Some
controls require *both* out-of-sample and training losses (eg,
`PQ()`).
- The specified `controls` are repeatedly applied to `train_mach` in sequence, until one
of the controls triggers a stop. Loss-based controls (eg, `Patience()`, `GL()`,
`Threshold(0.001)`) use an out-of-sample loss, obtained by applying `measure` to
predictions and the test target values. (Specifically, these predictions are those
returned by `operation(train_mach)`.) If `resampling === nothing` then the most recent
training loss is used instead. Some controls require *both* out-of-sample and training
losses (eg, `PQ()`).
- Once a stop has been triggered, a clone of `model` is bound to all
`data` in a machine called `mach_production` below, unless
`retrain == false` or `resampling === nothing`, in which case
`mach_production` coincides with `train_mach`.
- Once a stop has been triggered, a clone of `model` is bound to all `data` in a machine
called `mach_production` below, unless `retrain == false` (the default) or
`resampling === nothing`, in which case `mach_production` coincides with `train_mach`.
### Prediction
# Prediction
Calling `predict(mach, Xnew)` returns `predict(mach_production,
Xnew)`. Similar statements hold for `predict_mean`,
`predict_mode`, `predict_median`.
Calling `predict(mach, Xnew)` in the example above returns `predict(mach_production,
Xnew)`. Similar statements hold for `predict_mean`, `predict_mode`,
`predict_median`.
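
For instance (a sketch; `X`, `y`, and `Xnew` are hypothetical data not defined in the original text):

```julia
mach = machine(iterated_model, X, y) |> fit!  # controlled iteration runs here
yhat = predict(mach, Xnew)                    # delegates to `mach_production`
```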
### Controls
# Controls that mutate parameters
A control is permitted to mutate the fields (hyper-parameters) of
`train_mach.model` (the clone of `model`). For example, to mutate a
@@ -174,11 +197,25 @@ in that parameter, this will trigger retraining of `train_mach` from
scratch, with a different training outcome, which is not recommended.
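
The collapsed lines above presumably complete the example; one plausible such control, assuming (hypothetically) that the atomic model has a learning-rate field `eta`, is:

```julia
# decay the clone's learning rate on each application of the control:
WithModelDo(m -> m.eta = 0.9 * m.eta)
```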
### Warm restarts
# Warm restarts
If `iterated_model` is mutated and `fit!(mach)` is called again, then
a warm restart is attempted if the only parameters to change are
`model` or `controls` or both.
In the following example, the second `fit!` call will not retrain the internal
`train_mach` from scratch, assuming `model` supports warm restarts:
```julia
iterated_model = IteratedModel(
model,
controls = [Step(1), NumberLimit(100)],
)
mach = machine(iterated_model, X, y)
fit!(mach) # train for 100 iterations
iterated_model.controls = [Step(1), NumberLimit(50)]
fit!(mach) # train for an *extra* 50 iterations
```
More generally, if `iterated_model` is mutated and `fit!(mach)` is called again, then a
warm restart is attempted if the only parameters to change are `model` or `controls` or
both.
Specifically, `train_mach.model` is mutated to match the current value
of `iterated_model.model` and the iteration parameter of the latter is
Expand All @@ -188,14 +225,14 @@ repeated application of the (updated) controls begin anew.
"""
function IteratedModel(args...;
model=nothing,
control=CONTROLS_DEFAULT,
control=DEFAULT_CONTROLS,
controls=control,
resampling=MLJBase.Holdout(),
measures=nothing,
measure=measures,
weights=nothing,
class_weights=nothing,
operation=predict,
operation=nothing,
retrain=false,
check_measure=true,
iteration_parameter=nothing,
@@ -211,30 +248,24 @@ function IteratedModel(args...;
atom = model
end

options = (
atom,
controls,
resampling,
measure,
weights,
class_weights,
operation,
retrain,
check_measure,
iteration_parameter,
cache,
)

if atom isa Deterministic
iterated_model = DeterministicIteratedModel(atom,
controls,
resampling,
measure,
weights,
class_weights,
operation,
retrain,
check_measure,
iteration_parameter,
cache)
iterated_model = DeterministicIteratedModel(options...)
elseif atom isa Probabilistic
iterated_model = ProbabilisticIteratedModel(atom,
controls,
resampling,
measure,
weights,
class_weights,
operation,
retrain,
check_measure,
iteration_parameter,
cache)
iterated_model = ProbabilisticIteratedModel(options...)
else
throw(ERR_NOT_SUPERVISED)
end
6 changes: 2 additions & 4 deletions src/traits.jl
@@ -1,9 +1,7 @@
MLJBase.is_wrapper(::Type{<:EitherIteratedModel}) = true
MLJBase.caches_data_by_default(::Type{<:EitherIteratedModel}) = false
MLJBase.load_path(::Type{<:DeterministicIteratedModel}) =
"MLJIteration.DeterministicIteratedModel"
MLJBase.load_path(::Type{<:ProbabilisticIteratedModel}) =
"MLJIteration.ProbabilisticIteratedModel"
MLJBase.load_path(::Type{<:EitherIteratedModel}) = "MLJIteration.IteratedModel"
MLJBase.constructor(::Type{<:EitherIteratedModel}) = IteratedModel
MLJBase.package_name(::Type{<:EitherIteratedModel}) = "MLJIteration"
MLJBase.package_uuid(::Type{<:EitherIteratedModel}) =
"614be32b-d00c-4edb-bd02-1eb411ab5e55"