diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7a0a1f0..101d9c9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: version: - - '1.6' + - '1.0' - '1' - 'nightly' os: diff --git a/Project.toml b/Project.toml index 0831a68..2f10e18 100644 --- a/Project.toml +++ b/Project.toml @@ -1,18 +1,23 @@ name = "DiffResults" uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.1.0" +version = "1.0.3" [deps] -StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[weakdeps] +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[extensions] +StaticArraysExt = "StaticArrays" [compat] -StaticArrays = "1.5.8" -StaticArraysCore = "1.4.0" -julia = "1.6" +StaticArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 1.0" +julia = "1" [extras] StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "StaticArrays"] +test = ["StaticArrays", "Test"] diff --git a/ext/StaticArraysExt.jl b/ext/StaticArraysExt.jl new file mode 100644 index 0000000..c96bd45 --- /dev/null +++ b/ext/StaticArraysExt.jl @@ -0,0 +1,25 @@ +module StaticArraysExt + +using DiffResults, StaticArrays + +import DiffResults: DiffResult, ImmutableDiffResult, GradientResult, JacobianResult, HessianResult, derivative! +using DiffResults: value, tuple_setindex + +DiffResult(value::Number, derivs::Tuple{Vararg{StaticArray}}) = ImmutableDiffResult(value, derivs) +DiffResult(value::StaticArray, derivs::Tuple{Vararg{StaticArray}}) = ImmutableDiffResult(value, derivs) + +GradientResult(x::StaticArray) = DiffResult(first(x), x) + +JacobianResult(x::StaticArray) = DiffResult(x, zeros(StaticArrays.similar_type(typeof(x), Size(length(x),length(x))))) +JacobianResult(y::StaticArray, x::StaticArray) = DiffResult(y, zeros(StaticArrays.similar_type(typeof(x), Size(length(y),length(x))))) + +HessianResult(x::StaticArray) = DiffResult(first(x), x, zeros(StaticArrays.similar_type(typeof(x), Size(length(x),length(x))))) + +function derivative!(r::ImmutableDiffResult, x::Union{Number,StaticArray}, ::Type{Val{i}} = Val{1}) where {i} + return ImmutableDiffResult(value(r), tuple_setindex(r.derivs, x, Val{i})) +end +function derivative!(f, r::ImmutableDiffResult, x::StaticArray, ::Type{Val{i}} = Val{1}) where {i} + return derivative!(r, map(f, x), Val{i}) +end + +end diff --git a/src/DiffResults.jl b/src/DiffResults.jl index c440d50..8f24396 100644 --- a/src/DiffResults.jl +++ b/src/DiffResults.jl @@ -1,7 +1,5 @@ module DiffResults -using StaticArraysCore: StaticArray, similar_type, Size - ######### # Types # ######### @@ -45,8 +43,6 @@ Note that `derivs` can be provide in splatted form, i.e. `DiffResult(value, deri DiffResult DiffResult(value::Number, derivs::Tuple{Vararg{Number}}) = ImmutableDiffResult(value, derivs) -DiffResult(value::Number, derivs::Tuple{Vararg{StaticArray}}) = ImmutableDiffResult(value, derivs) -DiffResult(value::StaticArray, derivs::Tuple{Vararg{StaticArray}}) = ImmutableDiffResult(value, derivs) DiffResult(value::Number, derivs::Tuple{Vararg{AbstractArray}}) = MutableDiffResult(value, derivs) DiffResult(value::AbstractArray, derivs::Tuple{Vararg{AbstractArray}}) = MutableDiffResult(value, derivs) DiffResult(value::Union{Number,AbstractArray}, derivs::Union{Number,AbstractArray}...) = DiffResult(value, derivs) @@ -62,7 +58,6 @@ shape information. If you want to allocate storage yourself, use the `DiffResult constructor instead. """ GradientResult(x::AbstractArray) = DiffResult(first(x), similar(x)) -GradientResult(x::StaticArray) = DiffResult(first(x), x) """ JacobianResult(x::AbstractArray) @@ -76,7 +71,6 @@ shape information. If you want to allocate storage yourself, use the `DiffResult constructor instead. """ JacobianResult(x::AbstractArray) = DiffResult(similar(x), similar(x, length(x), length(x))) -JacobianResult(x::StaticArray) = DiffResult(x, zeros(similar_type(typeof(x), Size(length(x),length(x))))) """ JacobianResult(y::AbstractArray, x::AbstractArray) @@ -89,7 +83,6 @@ Like the single argument version, `y` and `x` are only used for type and shape information and are not stored in the returned `DiffResult`. """ JacobianResult(y::AbstractArray, x::AbstractArray) = DiffResult(similar(y), similar(y, length(y), length(x))) -JacobianResult(y::StaticArray, x::StaticArray) = DiffResult(y, zeros(similar_type(typeof(x), Size(length(y),length(x))))) """ HessianResult(x::AbstractArray) @@ -102,7 +95,6 @@ shape information. If you want to allocate storage yourself, use the `DiffResult constructor instead. """ HessianResult(x::AbstractArray) = DiffResult(first(x), zeros(eltype(x), size(x)), similar(x, length(x), length(x))) -HessianResult(x::StaticArray) = DiffResult(first(x), x, zeros(similar_type(typeof(x), Size(length(x),length(x))))) ############# # Interface # @@ -200,10 +192,6 @@ function derivative!(r::MutableDiffResult, x::AbstractArray, ::Type{Val{i}} = Va return r end -function derivative!(r::ImmutableDiffResult, x::Union{Number,StaticArray}, ::Type{Val{i}} = Val{1}) where {i} - return ImmutableDiffResult(value(r), tuple_setindex(r.derivs, x, Val{i})) -end - function derivative!(r::ImmutableDiffResult, x::AbstractArray, ::Type{Val{i}} = Val{1}) where {i} T = tuple_eltype(r.derivs, Val{i}) return ImmutableDiffResult(value(r), tuple_setindex(r.derivs, T(x), Val{i})) @@ -229,10 +217,6 @@ function derivative!(f, r::ImmutableDiffResult, x::Number, ::Type{Val{i}} = Val{ return derivative!(r, f(x), Val{i}) end -function derivative!(f, r::ImmutableDiffResult, x::StaticArray, ::Type{Val{i}} = Val{1}) where {i} - return derivative!(r, map(f, x), Val{i}) -end - function derivative!(f, r::ImmutableDiffResult, x::AbstractArray, ::Type{Val{i}} = Val{1}) where {i} T = tuple_eltype(r.derivs, Val{i}) return derivative!(r, map(f, T(x)), Val{i}) @@ -333,4 +317,8 @@ Base.show(io::IO, r::ImmutableDiffResult) = print(io, "ImmutableDiffResult($(r.v Base.show(io::IO, r::MutableDiffResult) = print(io, "MutableDiffResult($(r.value), $(r.derivs))") +if !isdefined(Base, :get_extension) + include("../ext/StaticArraysExt.jl") +end + end # module