Skip to content

Commit

Permalink
Merge pull request #160 from AayushSabharwal/as/rat-stuff
Browse files Browse the repository at this point in the history
refactor: add DiffEqArray constructor
  • Loading branch information
ChrisRackauckas authored Apr 30, 2024
2 parents 52d8c3f + 2a8477b commit 652504a
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- Core
version:
- '1'
- '1.6'
- '1.10'
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand Down
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@ ChainRulesCore = "1"
ForwardDiff = "0.10.3"
MacroTools = "0.5"
PreallocationTools = "0.4"
RecursiveArrayTools = "2,3"
RecursiveArrayTools = "3"
StaticArrays = "1.0"
julia = "1.6"
julia = "1.10"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
Expand Down
1 change: 1 addition & 0 deletions src/LabelledArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import RecursiveArrayTools, PreallocationTools, ForwardDiff
include("slarray.jl")
include("larray.jl")
include("chainrules.jl")
include("diffeqarray.jl")

# Common
@generated function __getindex(x::Union{LArray, SLArray}, ::Val{s}) where {s}
Expand Down
7 changes: 7 additions & 0 deletions src/diffeqarray.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
for LArrayType in [LArray, SLArray]
@eval function RecursiveArrayTools.DiffEqArray(vec::AbstractVector{<:$LArrayType},
ts::AbstractVector,
p = nothing)
RecursiveArrayTools.DiffEqArray(vec, ts, p; variables = collect(symbols(vec[1])))
end
end
6 changes: 6 additions & 0 deletions test/recursivearraytools.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
using RecursiveArrayTools, LabelledArrays, Test

ABC = @SLVector (:a, :b, :c);
A = ABC(1, 2, 3);
B = RecursiveArrayTools.DiffEqArray([A, A], [0.0, 2.0]);
@test getindex(B, :a) == [1, 1]
32 changes: 21 additions & 11 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,27 @@ using StaticArrays
using InteractiveUtils
using ChainRulesTestUtils

@time begin
@time @testset "SLArrays" begin
include("slarrays.jl")
end
@time @testset "LArrays" begin
include("larrays.jl")
end
@time @testset "DiffEq" begin
include("diffeq.jl")
const GROUP = get(ENV, "GROUP", "All")

if GROUP == "All"
@time begin
@time @testset "SLArrays" begin
include("slarrays.jl")
end
@time @testset "LArrays" begin
include("larrays.jl")
end
@time @testset "DiffEq" begin
include("diffeq.jl")
end
@time @testset "ChainRules" begin
include("chainrules.jl")
end
end
@time @testset "ChainRules" begin
include("chainrules.jl")
end

if GROUP == "All" || GROUP == "RecursiveArrayTools"
@time @testset "RecursiveArrayTools" begin
include("recursivearraytools.jl")
end
end

0 comments on commit 652504a

Please sign in to comment.