Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add findmin, findmax, argmin, and argmax #53

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name = "NaNMath"
uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
repo = "https://github.com/mlubin/NaNMath.jl.git"
authors = ["Miles Lubin"]
version = "0.3.7"
version = "0.3.8"

[deps]

Expand Down
160 changes: 160 additions & 0 deletions src/NaNMath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,4 +354,164 @@ for f in (:min, :max)
@eval ($f)(a, b, c, xs...) = Base.afoldl($f, ($f)(($f)(a, b), c), xs...)
end

"""
NaNMath.findmin([f,] domain) -> (f(x), index)

##### Args:
* `f`: a function applied to the values in `domain`
* `domain`: A non-empty iterable of floating point numbers or `Missing`.

##### Returns:
* Returns a pair of a value in the codomain (outputs of `f`, defaulting to `identity`) and
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
the index of the corresponding value in the `domain` (inputs to `f`) such that `f(x)` is
minimized. If there are multiple minimal points, then the first one will be returned. `NaN`s
are treated as greater than all other values.

##### Examples:
```julia
julia> NaNMath.findmin([1., 1., 2., 2., NaN])
(1.0, 1)

julia> NaNMath.findmin(-, [1., 1., 2., 2., NaN])
(-2.0, 3)
```
"""
function findmin end
findmin(f, x) = _findminmax(Base.isgreater, f, x)
findmin(x) = findmin(identity, x)

"""
NaNMath.findmax([f,] domain) -> (f(x), index)

##### Args:
* `f`: a function applied to the values in `domain`
* `domain`: A non-empty iterable of floating point numbers or `Missing`.
sethaxen marked this conversation as resolved.
Show resolved Hide resolved

##### Returns:
* Returns a pair of a value in the codomain (outputs of `f`, defaulting to `identity`) and
the index of the corresponding value in the `domain` (inputs to `f`) such that `f(x)` is
maximized. If there are multiple minimal points, then the first one will be returned. `NaN`s
are treated as less than all other values.

##### Examples:
```julia
julia> NaNMath.findmax([1., 1., 2., 2., NaN])
(2.0, 3)

julia> NaNMath.findmax(-, [1., 1., 2., 2., NaN])
(-1.0, 1)
```
"""
function findmax end
findmax(f, x) = _findminmax(Base.isless, f, x)
findmax(x) = findmax(identity, x)

function _findminmax_op(cmp)
return (x1_i1, x2_i2) -> begin
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
x1 = first(x1_i1)
x1 === missing && return x1_i1
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
x2 = first(x2_i2)
x2 === missing && return x2_i2
return ifelse((x1 isa Number && isnan(x2)) || !cmp(x1, x2), x1_i1, x2_i2)
end
end

function _findminmax(cmp, f, x)
return mapfoldl(_findminmax_op(cmp), pairs(x)) do (k, xk)
return f(xk), k
end
end

"""
NaNMath.argmin(f, domain) -> x

##### Args:
* `f`: A function applied to the values of `domain`
* `domain`: A non-empty iterable of floating point numbers or `Missing`.
sethaxen marked this conversation as resolved.
Show resolved Hide resolved

##### Returns:
* Returns a value `x` in the domain of `f` for which `f(x)` is minimised. If there are
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
multiple minimal values for `f(x)`, then the first one will be found. `NaN`s are treated as
less than all other values.

##### Examples:
```julia
julia> NaNMath.argmin(abs, [1., -1., -2., 2., NaN])
1.0

julia> NaNMath.argmin(identity, [7, 1, 1, NaN])
1.0
```

NaNMath.argmin(itr) -> key

##### Args:
* `itr`: A non-empty iterable of floating point numbers or `Missing`.

##### Returns:
* Returns the index or key of the minimal element in `itr`. If there are multiple
minimal elements, then the first one will be returned

##### Examples:
```julia
julia> NaNMath.argmin([7, 1, 1, NaN])
2

julia> NaNMath.argmin([1.0 2; 3 NaN])
CartesianIndex(1, 1)

julia> NaNMath.argmin(Dict("x" => 1.0, "y" => -1, "z" => NaN))
"y"
```
"""
function argmin end
argmin(x) = findmin(identity, x)[2]
argmin(f, x) = mapfoldl(x -> (f(x), x), _findminmax_op(Base.isgreater), x)[2]

"""
NaNMath.argmax(f, domain) -> x

##### Args:
* `f`: A function applied to the values of `domain`
* `domain`: A non-empty iterable of floating point numbers or `Missing`.
sethaxen marked this conversation as resolved.
Show resolved Hide resolved

##### Returns:
* Returns a value `x` in the domain of `f` for which `f(x)` is maximised. If there are
multiple minimal values for `f(x)`, then the first one will be found. `NaN`s are treated as
greater than all other values.

##### Examples:
```julia
julia> NaNMath.argmax(abs, [1., -1., -2., NaN])
2.0

julia> NaNMath.argmax(identity, [7, 1, 1, NaN])
7.0
```

NaNMath.argmax(itr) -> key

##### Args:
* `itr`: A non-empty iterable of floating point numbers or `Missing`.

##### Returns:
* Returns the index or key of the maximal element in `itr`. If there are multiple
maximal elements, then the first one will be returned

##### Examples:
```julia
julia> NaNMath.argmax([7, 1, 1, NaN])
1

julia> NaNMath.argmax([1.0 2; 3 NaN])
CartesianIndex(2, 1)

julia> NaNMath.argmax(Dict("x" => 1.0, "y" => -1, "z" => NaN))
"x"
```
"""
function argmax end
argmax(x) = findmax(identity, x)[2]
argmax(f, x) = mapfoldl(x -> (f(x), x), _findminmax_op(Base.isless), x)[2]

end
100 changes: 100 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,103 @@ using Test
@test isnan(NaNMath.max(NaN, NaN))
@test isnan(NaNMath.max(NaN))
@test NaNMath.max(NaN, NaN, 0.0, 1.0) == 1.0

@testset "findmin/findmax" begin
xvals = [
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
[1., 2., 3., 3., 1.],
[missing, missing],
[missing, 1.0],
[1.0, missing],
(1., 2, 3., 3, 1),
(x=1, y=3, z=-4, w=-2),
Dict(:a => 1.0, :b => 1.0, :d => 3.0, :c => 2.0),
]
@testset for x in xvals
@test NaNMath.findmin(x) === findmin(x)
@test NaNMath.findmax(x) === findmax(x)
@test NaNMath.findmin(identity, x) === findmin(identity, x)
@test NaNMath.findmax(identity, x) === findmax(identity, x)
@test NaNMath.findmin(sin, x) === findmin(sin, x)
@test NaNMath.findmax(sin, x) === findmax(sin, x)
end
x = [7, 7, NaN, 1, 1, NaN]
@test NaNMath.findmin(x) === (1.0, 4)
@test NaNMath.findmax(x) === (7.0, 1)
@test NaNMath.findmin(identity, x) === (1.0, 4)
@test NaNMath.findmax(identity, x) === (7.0, 1)
@test NaNMath.findmin(-, x) === (-7.0, 1)
@test NaNMath.findmax(-, x) === (-1.0, 4)

x = [NaN, NaN]
@test NaNMath.findmin(x) === (NaN, 1)
@test NaNMath.findmax(x) === (NaN, 1)
@test NaNMath.findmin(identity, x) === (NaN, 1)
@test NaNMath.findmax(identity, x) === (NaN, 1)
@test NaNMath.findmin(sin, x) === (NaN, 1)
@test NaNMath.findmax(sin, x) === (NaN, 1)

x = [3, missing, NaN, -1]
@test NaNMath.findmin(x) === (missing, 2)
@test NaNMath.findmax(x) === (missing, 2)
@test NaNMath.findmin(identity, x) === (missing, 2)
@test NaNMath.findmax(identity, x) === (missing, 2)
@test NaNMath.findmin(sin, x) === (missing, 2)
@test NaNMath.findmax(sin, x) === (missing, 2)

x = Dict(:x => 3, :w => 2, :y => -1.0, :z => NaN)
@test NaNMath.findmin(x) === (-1.0, :y)
@test NaNMath.findmax(x) === (3, :x)
@test NaNMath.findmin(identity, x) === (-1.0, :y)
@test NaNMath.findmax(identity, x) === (3, :x)
@test NaNMath.findmin(-, x) === (-3, :x)
@test NaNMath.findmax(-, x) === (1.0, :y)
end

@testset "argmin/argmax" begin
xvals = [
[1., 2., 3., 3., 1.],
[missing, missing],
[missing, 1.0],
[1.0, missing],
(1., 2, 3., 3, 1),
(x=1, y=3, z=-4, w=-2),
Dict(:a => 1.0, :b => 1.0, :d => 3.0, :c => 2.0),
]
@testset for x in xvals
@test NaNMath.argmin(x) === argmin(x)
@test NaNMath.argmax(x) === argmax(x)
@test NaNMath.argmin(identity, x) === argmin(identity, x)
@test NaNMath.argmax(identity, x) === argmax(identity, x)
x isa Dict || @test NaNMath.argmin(sin, x) === argmin(sin, x)
x isa Dict || @test NaNMath.argmax(sin, x) === argmax(sin, x)
end
x = [7, 7, NaN, 1, 1, NaN]
@test NaNMath.argmin(x) === 4
@test NaNMath.argmax(x) === 1
@test NaNMath.argmin(identity, x) === 1.0
@test NaNMath.argmax(identity, x) === 7.0
@test NaNMath.argmin(-, x) === 7.0
@test NaNMath.argmax(-, x) === 1.0

x = [NaN, NaN]
@test NaNMath.argmin(x) === 1
@test NaNMath.argmax(x) === 1
@test NaNMath.argmin(identity, x) === NaN
@test NaNMath.argmax(identity, x) === NaN
@test NaNMath.argmin(-, x) === NaN
@test NaNMath.argmax(-, x) === NaN

x = [3, missing, NaN, -1]
@test NaNMath.argmin(x) === 2
@test NaNMath.argmax(x) === 2
@test NaNMath.argmin(identity, x) === missing
@test NaNMath.argmax(identity, x) === missing
@test NaNMath.argmin(-, x) === missing
@test NaNMath.argmax(-, x) === missing

x = Dict(:x => 3, :w => 2, :z => -1.0, :y => NaN)
@test NaNMath.argmin(x) === :z
@test NaNMath.argmax(x) === :x
@test NaNMath.argmin(identity, x) === argmin(identity, x)
@test NaNMath.argmax(identity, x) === argmax(identity, x)
end