Skip to content

Commit

Permalink
Avoid stack overflows with non-standard float types; closes JuliaMath#76
Browse files Browse the repository at this point in the history
  • Loading branch information
moble authored and ChrisRackauckas committed Jan 22, 2025
1 parent 6bf9c1b commit 9f6ac27
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 17 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ OpenLibm_jll = "05823500-19ac-5b8b-9628-191a04bc5112"
julia = "1.6"

[extras]
DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["DoubleFloats", "Test"]
45 changes: 31 additions & 14 deletions src/NaNMath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,55 @@ module NaNMath
using OpenLibm_jll
const libm = OpenLibm_jll.libopenlibm


for f in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10,
:lgamma, :log1p)
@eval begin
($f)(x::Float64) = ccall(($(string(f)),libm), Float64, (Float64,), x)
($f)(x::Float32) = ccall(($(string(f,"f")),libm), Float32, (Float32,), x)
($f)(x::Real) = ($f)(float(x))
if $f !== :lgamma
($f)(x) = (Base.$f)(x)
($f)(x::Float16) = Float16(($f)(Float32(x)))
function ($f)(x::Real)
xf = float(x)
x === xf && throw(MethodError($f, (x,)))
return ($f)(xf)
end
end
end

for f in (:sqrt,)
@eval ($f)(x) = (Base.$f)(x)
end

for f in (:max, :min)
@eval ($f)(x, y) = (Base.$f)(x, y)
end
sin(x::T) where {T<:AbstractFloat} = isfinite(x) ? Base.sin(x) : T(NaN)
cos(x::T) where {T<:AbstractFloat} = isfinite(x) ? Base.cos(x) : T(NaN)
tan(x::T) where {T<:AbstractFloat} = isfinite(x) ? Base.tan(x) : T(NaN)
asin(x::T) where {T<:AbstractFloat} = abs(x) > 1 ? T(NaN) : Base.asin(x)
acos(x::T) where {T<:AbstractFloat} = abs(x) > 1 ? T(NaN) : Base.acos(x)
acosh(x::T) where {T<:AbstractFloat} = x < 1 ? T(NaN) : Base.acosh(x)
atanh(x::T) where {T<:AbstractFloat} = abs(x) > 1 ? T(NaN) : Base.atanh(x)
log(x::T) where {T<:AbstractFloat} = x < 0 ? T(NaN) : Base.log(x)
log2(x::T) where {T<:AbstractFloat} = x < 0 ? T(NaN) : Base.log2(x)
log10(x::T) where {T<:AbstractFloat} = x < 0 ? T(NaN) : Base.log10(x)
# lgamma does not have a Base version; the MethodError above will suffice
log1p(x::T) where {T<:AbstractFloat} = x < -1 ? T(NaN) : Base.log1p(x)

# Would be more efficient to remove the domain check in Base.sqrt(),
# but this doesn't seem easy to do.
sqrt(x::T) where {T<:AbstractFloat} = x < 0.0 ? T(NaN) : Base.sqrt(x)
sqrt(x::Real) = sqrt(float(x))
function sqrt(x::Real)
xf = float(x)
x === xf && throw(MethodError(sqrt, (x,)))
return sqrt(xf)
end

# Don't override built-in ^ operator
pow(x::Float64, y::Float64) = ccall((:pow,libm), Float64, (Float64,Float64), x, y)
pow(x::Float32, y::Float32) = ccall((:powf,libm), Float32, (Float32,Float32), x, y)
pow(x::Float16, y::Float16) = Float16(pow(Float32(x), Float32(y)))
# We `promote` first before converting to floating pointing numbers to ensure that
# e.g. `pow(::Float32, ::Int)` ends up calling `pow(::Float32, ::Float32)`
pow(x::Real, y::Real) = pow(promote(x, y)...)
pow(x::T, y::T) where {T<:Real} = pow(float(x), float(y))
pow(x::Number, y::Number) = pow(promote(x, y)...)
yf = float(y)
xf = float(x)
function pow(x::T, y::T) where {T<:Number}
x === xf && y === yf && throw(MethodError(pow, (x,y)))
return pow(xf, yf)
end
pow(x, y) = ^(x, y)

# The following combinations are safe, so we can fall back to ^
Expand Down
67 changes: 65 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,57 @@
using NaNMath
using Test
using DoubleFloats


# https://github.com/JuliaMath/NaNMath.jl/issues/76
@test_throws MethodError NaNMath.pow(1.0, 1.0+im)


for T in (Float64, Float32, Float16, BigFloat)
for f in (:sin, :cos, :tan, :asin, :acos, :acosh, :atanh, :log, :log2, :log10,
:log1p) # Note: do :lgamma separately because it can't handle BigFloat
@eval begin
@test NaNMath.$f($T(2//3)) isa $T
@test NaNMath.$f($T(3//2)) isa $T
@test NaNMath.$f($T(-2//3)) isa $T
@test NaNMath.$f($T(-3//2)) isa $T
@test NaNMath.$f($T(Inf)) isa $T
@test NaNMath.$f($T(-Inf)) isa $T
end
end
end
for T in (Float64, Float32, Float16)
@test NaNMath.lgamma(T(2//3)) isa T
@test NaNMath.lgamma(T(3//2)) isa T
@test NaNMath.lgamma(T(-2//3)) isa T
@test NaNMath.lgamma(T(-3//2)) isa T
@test NaNMath.lgamma(T(Inf)) isa T
@test NaNMath.lgamma(T(-Inf)) isa T
end
@test_throws MethodError NaNMath.lgamma(BigFloat(2//3))

@test isnan(NaNMath.log(-10))
@test isnan(NaNMath.log(-10f0))
@test isnan(NaNMath.log(Float16(-10)))
@test isnan(NaNMath.log1p(-100))
@test isnan(NaNMath.log1p(-100f0))
@test isnan(NaNMath.log1p(Float16(-100)))
@test isnan(NaNMath.pow(-1.5,2.3))
@test isnan(NaNMath.pow(-1.5f0,2.3f0))
@test isnan(NaNMath.pow(-1.5,2.3f0))
@test isnan(NaNMath.pow(-1.5f0,2.3))
@test isnan(NaNMath.pow(Float16(-1.5),Float16(2.3)))
@test isnan(NaNMath.pow(Float16(-1.5),2.3))
@test isnan(NaNMath.pow(-1.5,Float16(2.3)))
@test isnan(NaNMath.pow(Float16(-1.5),2.3f0))
@test isnan(NaNMath.pow(-1.5f0,Float16(2.3)))
@test isnan(NaNMath.pow(-1.5f0,BigFloat(2.3)))
@test isnan(NaNMath.pow(BigFloat(-1.5),BigFloat(2.3)))
@test isnan(NaNMath.pow(BigFloat(-1.5),2.3f0))
@test isnan(NaNMath.pow(-1.5f0,Double64(2.3)))
@test isnan(NaNMath.pow(Double64(-1.5),Double64(2.3)))
@test isnan(NaNMath.pow(Double64(-1.5),2.3f0))
@test NaNMath.pow(-1,2) isa Float64
@test NaNMath.pow(-1.5f0,2) isa Float32
@test NaNMath.pow(-1.5f0,2//1) isa Float32
@test NaNMath.pow(-1.5f0,2.3f0) isa Float32
Expand All @@ -15,16 +60,34 @@ using Test
@test NaNMath.pow(-1.5,2//1) isa Float64
@test NaNMath.pow(-1.5,2.3f0) isa Float64
@test NaNMath.pow(-1.5,2.3) isa Float64
@test NaNMath.pow(Float16(-1.5),2.3) isa Float64
@test NaNMath.pow(Float16(-1.5),Float16(2.3)) isa Float16
@test NaNMath.pow(-1.5,Float16(2.3)) isa Float64
@test NaNMath.pow(Float16(-1.5),2.3f0) isa Float32
@test NaNMath.pow(-1.5f0,Float16(2.3)) isa Float32
@test NaNMath.pow(-1.5f0,BigFloat(2.3)) isa BigFloat
@test NaNMath.pow(BigFloat(-1.5),BigFloat(2.3)) isa BigFloat
@test NaNMath.pow(BigFloat(-1.5),2.3f0) isa BigFloat
@test NaNMath.pow(-1.5f0,Double64(2.3)) isa Double64
@test NaNMath.pow(Double64(-1.5),Double64(2.3)) isa Double64
@test NaNMath.pow(Double64(-1.5),2.3f0) isa Double64
@test NaNMath.sqrt(-5) isa Float64
@test NaNMath.pow(-1,2) === 1
@test NaNMath.pow(2,2) === 4
@test NaNMath.pow(1.0, 1.0+im) === 1.0 + 0.0im
@test NaNMath.pow(1.0+im, 1) === 1.0 + 1.0im
@test NaNMath.pow(1.0+im, 1.0) === 1.0 + 1.0im
@test isnan(NaNMath.sqrt(-5))
@test NaNMath.sqrt(5) == Base.sqrt(5)
@test NaNMath.sqrt(-5f0) isa Float32
@test NaNMath.sqrt(5f0) == Base.sqrt(5f0)
@test NaNMath.sqrt(Float16(-5)) isa Float16
@test NaNMath.sqrt(Float16(5)) == Base.sqrt(Float16(5))
@test NaNMath.sqrt(BigFloat(-5)) isa BigFloat
@test NaNMath.sqrt(BigFloat(5)) == Base.sqrt(BigFloat(5))
@test isnan(NaNMath.sqrt(-3.2f0)) && NaNMath.sqrt(-3.2f0) isa Float32
@test isnan(NaNMath.sqrt(-BigFloat(7.0))) && NaNMath.sqrt(-BigFloat(7.0)) isa BigFloat
@test isnan(NaNMath.sqrt(-7)) && NaNMath.sqrt(-7) isa Float64
@test isnan(NaNMath.sqrt(-BigFloat(7.0))) && NaNMath.sqrt(-BigFloat(7.0)) isa BigFloat
@test isnan(NaNMath.sqrt(-7)) && NaNMath.sqrt(-7) isa Float64
@inferred NaNMath.sqrt(5)
@inferred NaNMath.sqrt(5.0)
@inferred NaNMath.sqrt(5.0f0)
Expand Down

0 comments on commit 9f6ac27

Please sign in to comment.