
Support for Reactant.jl #28 (Draft)

mcabbott wants to merge 9 commits into master
Conversation

@mcabbott (Member) commented Nov 26, 2024

This doesn't work yet, but is meant to work like this:

julia> using Flux, Fluxperimental, Reactant, Enzyme
julia> img = rand32(28, 28, 1, 128);
julia> loss(m, x) = sum(abs2, m(x));
julia> mlp = Chain(Flux.flatten, Dense(28^2 => 32, tanh), Dense(32 => 10));

julia> mlp(img)[1:3]  # plain Julia
3-element Vector{Float32}:
  0.22694848
 -0.72605485
  0.57976365

julia> re_mlp = Reactor(mlp);  # uses Reactant

julia> re_mlp(img)[1:3]
┌ Info: compiling...
└   summary(xr) = "28×28×1×128 ConcreteRArray{Float32, 4}"
3-element ConcreteRArray{Float32, 1}:
  0.22694828
 -0.72605544
  0.57976323

julia> re_mlp  # after forward but not yet gradient
Reactor(
  Chain(
    Flux.flatten,
    Dense(784 => 32, tanh),             # 25_120 parameters
    Dense(32 => 10),                    # 330 parameters
  ),
  # compiled for 28×28×1×128 ConcreteRArray{Float32, 4}
  # norm(∇) ≈ 0.0f0
  # ∇compiled for nothing
)         # Total: 4 trainable arrays, 25_450 parameters,
          # plus 4 non-trainable, 25_450 parameters, summarysize 689 bytes.

julia> Flux.gradient(loss, mlp, img)[1].layers[2].bias[1:3]  # uses Zygote
3-element Vector{Float32}:
   90.490005
 -208.77806
   28.711397

julia> Flux.gradient(loss, re_mlp, Const(img))[1].layers[2].bias[1:3]  # uses Reactant
[ Info: compiling gradient(loss, ::Reactor)
# assorted errors...

@jumerckx commented Dec 8, 2024

I had a look, and the issue stems from Reactant trying to trace the Active type passed to Enzyme.autodiff. Reactant is unable to create MLIR equivalents of the Active argument.

Details: the actual error is caused by this statement in the compiled thunk:

usbuf_1 = (traced_getfield(getindex(args, 3), $(Expr(:quote, 1)))).data

where, in this case, args[3] == Active.

I'll leave it to @wsmoses to say whether this should work in the future. But by `@compile`-ing a function that takes only the arguments that actually need to end up in the MLIR function, compilation works, i.e.:
_autodiff(f, dup, inputs, seed) = Enzyme.autodiff(Reverse, Const(_fun!), seed, Const(f), dup, inputs...)

# inside Flux.gradient:
fun = @compile _autodiff(f, dup, xrs, seed) 
fun(f, dup, xrs, seed)

I haven't verified correctness of the gradients, though.

@wsmoses commented Dec 8, 2024

So all the Reactant RArrays are Duplicated; RNumbers should be Active [in reverse mode]. Just like Vector{Float32} vs Float32. We haven't done many RNumber tests with AD yet, so that could be it? But I'd assume Flux just takes RArrays?

Either way, can you file an issue with the case that doesn't work and we can fix it!
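
As an aside for readers following along, here is a minimal Enzyme-only sketch of that Duplicated-vs-Active distinction, in plain Julia with no Reactant involved (the function names are made up for illustration):

using Enzyme

# Arrays are mutated in place, so they take Duplicated (value plus shadow);
# plain scalars are returned by value, so they take Active.
sumsq(v) = sum(abs2, v)

x  = Float32[1, 2, 3]
dx = zero(x)                                          # shadow that receives the gradient
Enzyme.autodiff(Reverse, sumsq, Active, Duplicated(x, dx))
dx                                                    # now == 2 .* x

scale(s) = 3f0 * s^2
Enzyme.autodiff(Reverse, scale, Active, Active(2f0))  # returns ((12.0f0,),)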

@wsmoses commented Dec 8, 2024

In this case, GitHub CI complains with:

ERROR: LoadError: SystemError: opening file "/home/runner/work/Fluxperimental.jl/Fluxperimental.jl/src/reactant.jl": No such file or directory

so it will be a bit hard to help debug that as is :P

@mcabbott (Member, Author) commented Dec 8, 2024

Sorry, the first commit left out one file! I've now pushed an updated version, without any changes to the code touching Reactant. Strangely, it gives errors on the first call but works on the second; maybe I didn't try that last time.

else
@info "compiling gradient($f, ::Fluxactor)..."
@info "compiling gradient($f, ::Reactor, ::Const...)"
# fun = @compile Enzyme.autodiff(Reverse, f, Active, dup, xrs...) # this gives ERROR: "Unhandled type Type" above
fun = @compile Enzyme.autodiff(Reverse, Const(_fun!), seed, Const(f), dup, xrs...) # this gives ERROR: type TypeVar has no field data
Yeah, this is a bit of an aside, but presently you can't directly compile autodiff; autodiff needs to be inside of another function.
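
As a hedged sketch of what "inside of another function" means here (the names loss_fn, wrapped_grad, dup, and xr are placeholders, not this PR's code):

# @compile traces a single function call, so put the autodiff call in a
# wrapper and compile the wrapper instead:
wrapped_grad(f, dup, x) = Enzyme.autodiff(Reverse, Const(f), Active, dup, Const(x))

grad! = @compile wrapped_grad(loss_fn, dup, xr)  # compile once...
grad!(loss_fn, dup, xr)                          # ...then call the compiled thunk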

@mcabbott (Member, Author) replied:

Ok. This I changed in 0665a55 and get different errors :)

(If @compile requires exactly one function call with Symbol arguments, ideally the macro would check that & reject more complicated syntax immediately?)
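
The kind of up-front check meant here, sketched with a hypothetical stand-in macro (not Reactant's actual @compile):

macro compile_checked(ex)
    # accept only a single call whose callee and arguments are plain symbols
    Meta.isexpr(ex, :call) && all(a -> a isa Symbol, ex.args) ||
        error("@compile_checked expects a plain call like f(x, y), got: ", ex)
    esc(ex)  # a real implementation would trace and compile the call here
end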

fwd_input
fwd_count::Int
gradient::M
grad_compiled

Longer term, it's actually better to compile the entire update step rather than just the gradient.

@mcabbott (Member, Author) replied Dec 8, 2024:

For now I hope that in real use this will be most of the time -- gradients are expensive, and this should capture the entire forward + backward pass. Optimisers.jl is one or two fused in-place broadcasts per parameter array.

The function train! does the whole update step, many of them, so perhaps that's the next step.


Yeah, the reason compiling the entire update is nice is that we can get rid of intermediate allocations for the gradient etc. and just generate kernels that update the model weights in place.
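
As a rough sketch of what compiling the whole step could look like (the helper _train_step! and all variable names here are illustrative, not the PR's actual _step!):

# Gradient and optimiser update traced as one program, so the compiler can
# fuse them and avoid materialising the gradient in between:
function _train_step!(loss, dup_model, opt_state, x, y)
    Enzyme.make_zero!(dup_model.dval)      # reset the shadow before accumulating
    Enzyme.autodiff(Reverse, Const(loss), Active, dup_model, Const(x), Const(y))
    Optimisers.update!(opt_state, dup_model.val, dup_model.dval)
    return nothing
end

step! = @compile _train_step!(loss, dup_model, opt_state, xr, yr)
step!(loss, dup_model, opt_state, xr, yr)  # one fused, reusable training step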

@mcabbott (Member, Author) replied:

Ah ok, that's another step. Then you want the allocation of shadows to also happen within @compile, rather than being stored in external Duplicated / Reactor structs.


Yeah, ideally, because then they can get optimized away on the inside of compile.
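
That is, something like this sketch, where the shadow is created inside the function being traced (names illustrative):

# Shadow allocated inside the traced function rather than stored in a
# long-lived Duplicated / Reactor struct, so the compiler may elide it:
function _grad_inside(f, model, x)
    dmodel = Enzyme.make_zero(model)       # shadow born inside the trace
    Enzyme.autodiff(Reverse, Const(f), Active, Duplicated(model, dmodel), Const(x))
    return dmodel
end

grad = @compile _grad_inside(loss_fn, model_r, xr)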

@wsmoses commented Dec 8, 2024

Looks like an issue on this side?

Fluxperimental.jl: Error During Test at /home/runner/work/Fluxperimental.jl/Fluxperimental.jl/test/runtests.jl:4
Got exception outside of a @test
LoadError: UndefVarError: @compact not defined
Stacktrace:

@mcabbott (Member, Author) commented Dec 8, 2024

There are no tests yet, just trying things at the REPL, as shown in docstrings. E.g. here is the present state, with error.

(The error is because I commented out half the imports to circumvent FluxML/Flux.jl#2545 for now.)

@mcabbott (Member, Author) commented:

Now with tests. CI shows different errors on 1.10 and 1.11.

@wsmoses commented Dec 16, 2024

Can you try on Reactant#main and open issues for anything pending?

things are quite in ….. flux

@mcabbott (Member, Author) commented Dec 16, 2024

Well that gives me different errors!

julia> @testset "simple train!" begin
           X = repeat(hcat(digits.(0:3, base=2, pad=2)...), 1, 32)
           Y = Flux.onehotbatch(xor.(eachrow(X)...), 0:1)
           # data = Flux.DataLoader((X, Y); batchsize=16, shuffle=true)
           data = Flux.DataLoader((X .+ 0f0, Y .+ 0f0); batchsize=16, shuffle=true)  # this avoids some errors from conversion

           model = Chain(Dense(2 => 3, sigmoid), BatchNorm(3), Dense(3 => 2)) |> Reactor
           state = Flux.setup(Adam(0.1, (0.7, 0.95)), model)  # Note that I'm doing this after |> Reactor, ideally before would work too?

           Flux.train!(model, data, state; epochs=100) do m, x, y
               Flux.logitcrossentropy(m(x), y)
           end

           @test all((softmax(model(X)) .> 0.5) .== Y)
       end
[ Info: compiling
simple train!: Error During Test at REPL[6]:1
  Got exception outside of a @test
  UndefVarError: `traced_getfield` not defined in `Reactant.TracedUtils`
  Suggestion: check for spelling errors or missing imports.
  Stacktrace:
    [1] push_val!(ad_inputs::Vector{Reactant.MLIR.IR.Value}, x::Chain{Tuple{Dense{typeof(σ), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}, BatchNorm{typeof(identity), Reactant.TracedRArray{Float32, 1}, Float32, Reactant.TracedRArray{Float32, 1}}, Dense{typeof(identity), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}}}, path::Tuple{Int64, Int64, Int64})
      @ Reactant.TracedUtils ~/.julia/packages/Reactant/BtZAf/src/TracedUtils.jl:326
    [2] push_acts!(ad_inputs::Vector{Reactant.MLIR.IR.Value}, x::Duplicated{Chain{Tuple{Dense{typeof(σ), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}, BatchNorm{typeof(identity), Reactant.TracedRArray{Float32, 1}, Float32, Reactant.TracedRArray{Float32, 1}}, Dense{typeof(identity), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}}}}, path::Tuple{Int64, Int64, Int64}, reverse::Bool)
      @ Reactant ~/.julia/packages/Reactant/BtZAf/src/Interpreter.jl:154
    [3] overload_autodiff(::ReverseMode{false, false, FFIABI, false, false}, ::Const{typeof(FluxReactantExt._applyloss!)}, ::Type{Const{Nothing}}, ::Duplicated{Reactant.TracedRArray{Float32, 1}}, ::Const{var"#5#6"}, ::Duplicated{Chain{Tuple{Dense{typeof(σ), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}, BatchNorm{typeof(identity), Reactant.TracedRArray{Float32, 1}, Float32, Reactant.TracedRArray{Float32, 1}}, Dense{typeof(identity), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}}}}, ::Const{Reactant.TracedRArray{Float32, 2}}, ::Const{Reactant.TracedRArray{Float32, 2}})
      @ Reactant ~/.julia/packages/Reactant/BtZAf/src/Interpreter.jl:255
    [4] autodiff(::ReverseMode{false, false, FFIABI, false, false}, ::Const{typeof(FluxReactantExt._applyloss!)}, ::Type{Const{Nothing}}, ::Duplicated{Reactant.TracedRArray{Float32, 1}}, ::Const{var"#5#6"}, ::Duplicated{Chain{Tuple{Dense{typeof(σ), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}, BatchNorm{typeof(identity), Reactant.TracedRArray{Float32, 1}, Float32, Reactant.TracedRArray{Float32, 1}}, Dense{typeof(identity), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}}}}, ::Const{Reactant.TracedRArray{Float32, 2}}, ::Const{Reactant.TracedRArray{Float32, 2}})
      @ Reactant ~/.julia/packages/Reactant/BtZAf/src/Interpreter.jl:492
    [5] autodiff
      @ ~/.julia/packages/Enzyme/haqjK/src/Enzyme.jl:544 [inlined]
    [6] _step!
      @ ~/.julia/dev/Fluxperimental/ext/FluxReactantExt.jl:354 [inlined]
    [7] _step!(loss::var"#5#6", seed::Duplicated{Reactant.TracedRArray{Float32, 1}}, model::Chain{Tuple{Dense{typeof(σ), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}, BatchNorm{typeof(identity), Reactant.TracedRArray{Float32, 1}, Float32, Reactant.TracedRArray{Float32, 1}}, Dense{typeof(identity), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}}}, d_splat::Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}}, opt_state::@NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}})
      @ Reactant ./<missing>:0
    [8] GenericMemory
      @ ./boot.jl:516 [inlined]
    [9] IdDict
      @ ./iddict.jl:31 [inlined]
   [10] IdDict
      @ ./iddict.jl:49 [inlined]
   [11] make_zero (repeats 2 times)
      @ ~/.julia/packages/EnzymeCore/15Zff/src/EnzymeCore.jl:529 [inlined]
   [12] _step!
      @ ~/.julia/dev/Fluxperimental/ext/FluxReactantExt.jl:353 [inlined]
   [13] call_with_reactant(::typeof(FluxReactantExt._step!), ::var"#5#6", ::Duplicated{Reactant.TracedRArray{Float32, 1}}, ::Chain{Tuple{Dense{typeof(σ), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}, BatchNorm{typeof(identity), Reactant.TracedRArray{Float32, 1}, Float32, Reactant.TracedRArray{Float32, 1}}, Dense{typeof(identity), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}}}, ::Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}}, ::@NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}})
      @ Reactant ~/.julia/packages/Reactant/BtZAf/src/utils.jl:0
   [14] (::Reactant.TracedUtils.var"#8#18"{Bool, Bool, typeof(FluxReactantExt._step!), Tuple{var"#5#6", Duplicated{ConcreteRArray{Float32, 1}}, Chain{Tuple{Dense{typeof(σ), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, BatchNorm{typeof(identity), ConcreteRArray{Float32, 1}, Float32, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}}, @NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}}, Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, Tuple{var"#5#6", Duplicated{Reactant.TracedRArray{Float32, 1}}, Chain{Tuple{Dense{typeof(σ), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}, BatchNorm{typeof(identity), Reactant.TracedRArray{Float32, 1}, Float32, Reactant.TracedRArray{Float32, 1}}, Dense{typeof(identity), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}}, @NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}}})()
      @ Reactant.TracedUtils ~/.julia/packages/Reactant/BtZAf/src/TracedUtils.jl:210
   [15] block!(f::Reactant.TracedUtils.var"#8#18"{Bool, Bool, typeof(FluxReactantExt._step!), Tuple{var"#5#6", Duplicated{ConcreteRArray{Float32, 1}}, Chain{Tuple{Dense{typeof(σ), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, BatchNorm{typeof(identity), ConcreteRArray{Float32, 1}, Float32, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}}, @NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}}, Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, Tuple{var"#5#6", Duplicated{Reactant.TracedRArray{Float32, 1}}, Chain{Tuple{Dense{typeof(σ), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}, BatchNorm{typeof(identity), Reactant.TracedRArray{Float32, 1}, Float32, Reactant.TracedRArray{Float32, 1}}, Dense{typeof(identity), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}}}, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}}, @NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}}}, blk::Reactant.MLIR.IR.Block)
      @ Reactant.MLIR.IR ~/.julia/packages/Reactant/BtZAf/src/mlir/IR/Block.jl:201
   [16] make_mlir_fn(f::Function, args::Tuple{var"#5#6", Duplicated{ConcreteRArray{Float32, 1}}, Chain{Tuple{Dense{typeof(σ), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, BatchNorm{typeof(identity), ConcreteRArray{Float32, 1}, Float32, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}}, @NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool, do_transpose::Bool)
      @ Reactant.TracedUtils ~/.julia/packages/Reactant/BtZAf/src/TracedUtils.jl:197
   [17] make_mlir_fn
      @ ~/.julia/packages/Reactant/BtZAf/src/TracedUtils.jl:117 [inlined]
   [18] #10
      @ ~/.julia/packages/Reactant/BtZAf/src/Compiler.jl:295 [inlined]
   [19] block!(f::Reactant.Compiler.var"#10#15"{typeof(FluxReactantExt._step!), Tuple{var"#5#6", Duplicated{ConcreteRArray{Float32, 1}}, Chain{Tuple{Dense{typeof(σ), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, BatchNorm{typeof(identity), ConcreteRArray{Float32, 1}, Float32, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}}, @NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}}}, blk::Reactant.MLIR.IR.Block)
      @ Reactant.MLIR.IR ~/.julia/packages/Reactant/BtZAf/src/mlir/IR/Block.jl:201
   [20] #9
      @ ~/.julia/packages/Reactant/BtZAf/src/Compiler.jl:294 [inlined]
   [21] mmodule!(f::Reactant.Compiler.var"#9#14"{Reactant.MLIR.IR.Module, typeof(FluxReactantExt._step!), Tuple{var"#5#6", Duplicated{ConcreteRArray{Float32, 1}}, Chain{Tuple{Dense{typeof(σ), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, BatchNorm{typeof(identity), ConcreteRArray{Float32, 1}, Float32, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}}, @NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}}}, blk::Reactant.MLIR.IR.Module)
      @ Reactant.MLIR.IR ~/.julia/packages/Reactant/BtZAf/src/mlir/IR/Module.jl:92
   [22] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{var"#5#6", Duplicated{ConcreteRArray{Float32, 1}}, Chain{Tuple{Dense{typeof(σ), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, BatchNorm{typeof(identity), ConcreteRArray{Float32, 1}, Float32, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}}, @NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}}; optimize::Bool)
      @ Reactant.Compiler ~/.julia/packages/Reactant/BtZAf/src/Compiler.jl:291
   [23] compile_mlir!
      @ ~/.julia/packages/Reactant/BtZAf/src/Compiler.jl:290 [inlined]
   [24] (::Reactant.Compiler.var"#34#36"{Bool, typeof(FluxReactantExt._step!), Tuple{var"#5#6", Duplicated{ConcreteRArray{Float32, 1}}, Chain{Tuple{Dense{typeof(σ), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, BatchNorm{typeof(identity), ConcreteRArray{Float32, 1}, Float32, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}}, @NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}}})()
      @ Reactant.Compiler ~/.julia/packages/Reactant/BtZAf/src/Compiler.jl:698
   [25] context!(f::Reactant.Compiler.var"#34#36"{Bool, typeof(FluxReactantExt._step!), Tuple{var"#5#6", Duplicated{ConcreteRArray{Float32, 1}}, Chain{Tuple{Dense{typeof(σ), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, BatchNorm{typeof(identity), ConcreteRArray{Float32, 1}, Float32, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}}, @NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}}}, ctx::Reactant.MLIR.IR.Context)
      @ Reactant.MLIR.IR ~/.julia/packages/Reactant/BtZAf/src/mlir/IR/Context.jl:76
   [26] compile_xla(f::Function, args::Tuple{var"#5#6", Duplicated{ConcreteRArray{Float32, 1}}, Chain{Tuple{Dense{typeof(σ), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, BatchNorm{typeof(identity), ConcreteRArray{Float32, 1}, Float32, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}}, @NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}}; client::Nothing, optimize::Bool)
      @ Reactant.Compiler ~/.julia/packages/Reactant/BtZAf/src/Compiler.jl:695
   [27] compile_xla
      @ ~/.julia/packages/Reactant/BtZAf/src/Compiler.jl:690 [inlined]
   [28] compile(f::Function, args::Tuple{var"#5#6", Duplicated{ConcreteRArray{Float32, 1}}, Chain{Tuple{Dense{typeof(σ), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, BatchNorm{typeof(identity), ConcreteRArray{Float32, 1}, Float32, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}}, @NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}}; client::Nothing, optimize::Bool, sync::Bool)
      @ Reactant.Compiler ~/.julia/packages/Reactant/BtZAf/src/Compiler.jl:722
   [29] macro expansion
      @ ~/.julia/packages/Reactant/BtZAf/src/Compiler.jl:475 [inlined]
   [30] macro expansion
      @ ~/.julia/dev/Fluxperimental/ext/FluxReactantExt.jl:329 [inlined]
   [31] macro expansion
      @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
   [32] train!(loss::Function, m::Reactor{Chain{Tuple{Dense{typeof(σ), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, BatchNorm{typeof(identity), ConcreteRArray{Float32, 1}, Float32, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}}, data::MLUtils.DataLoader{Tuple{Matrix{Float32}, Matrix{Float32}}, Random.TaskLocalRNG, Val{nothing}}, opt_state::@NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}; epochs::Int64)
      @ FluxReactantExt ~/.julia/dev/Fluxperimental/ext/FluxReactantExt.jl:324
   [33] macro expansion
      @ REPL[6]:10 [inlined]
   [34] macro expansion
      @ /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/Test/src/Test.jl:1700 [inlined]
   [35] top-level scope
      @ REPL[6]:2
   [36] eval
      @ ./boot.jl:430 [inlined]
   [37] eval_user_input(ast::Any, backend::REPL.REPLBackend, mod::Module)
      @ REPL /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/REPL/src/REPL.jl:226
   [38] repl_backend_loop(backend::REPL.REPLBackend, get_module::Function)
      @ REPL /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/REPL/src/REPL.jl:323
   [39] start_repl_backend(backend::REPL.REPLBackend, consumer::Any; get_module::Function)
      @ REPL /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/REPL/src/REPL.jl:308
   [40] run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool, backend::Any)
      @ REPL /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/REPL/src/REPL.jl:464
   [41] run_repl(repl::REPL.AbstractREPL, consumer::Any)
      @ REPL /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/REPL/src/REPL.jl:450
   [42] (::Base.var"#1138#1140"{Bool, Symbol, Bool})(REPL::Module)
      @ Base ./client.jl:446
   [43] #invokelatest#2
      @ ./essentials.jl:1054 [inlined]
   [44] invokelatest
      @ ./essentials.jl:1051 [inlined]
   [45] run_main_repl(interactive::Bool, quiet::Bool, banner::Symbol, history_file::Bool, color_set::Bool)
      @ Base ./client.jl:430
   [46] repl_main
      @ ./client.jl:567 [inlined]
   [47] _start()
      @ Base ./client.jl:541
Test Summary: | Error  Total   Time
simple train! |     1      1  42.7s
ERROR: Some tests did not pass: 0 passed, 0 failed, 1 errored, 0 broken.

(jl_rGG3mV) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_rGG3mV/Project.toml`
  [587475ba] Flux v0.16.1 `~/.julia/dev/Flux`
  [3102ee7a] Fluxperimental v0.2.3 `~/.julia/dev/Fluxperimental`
  [3c362404] Reactant v0.2.10 `https://github.com/EnzymeAD/Reactant.jl.git#main`
  [a3311ec8] ReactantCore v0.1.3 `https://github.com/EnzymeAD/Reactant.jl.git:lib/ReactantCore#main`

julia> versioninfo()
Julia Version 1.11.0
Commit 501a4f25c2b (2024-10-07 11:40 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 11 × Apple M3 Pro
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, apple-m3)
Threads: 4 default, 0 interactive, 2 GC (on 5 virtual cores)

@wsmoses commented Dec 16, 2024

try with EnzymeAD/Reactant.jl#385 by chance?

@mcabbott (Member, Author) commented:

With that PR merged, now it's back to saying "type Array has no field data".
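
A guess at the mechanism, judging from the thunk signature in the trace below: the compiled thunk is being called with plain Matrix{Float32} batches from the DataLoader (note the Tuple{Matrix{Float32}, Matrix{Float32}} argument), while the generated code reads a .data field that only the traced array types carry. A dependency-free analogue of the failure:

struct FakeConcrete                          # stand-in for ConcreteRArray
    data::Matrix{Float32}
end
use_data(x) = x.data                         # what the generated thunk effectively does

use_data(FakeConcrete(rand(Float32, 2, 2)))  # fine
use_data(rand(Float32, 2, 2))                # ERROR: type Array has no field data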

julia> @testset "simple train!" begin
           X = repeat(hcat(digits.(0:3, base=2, pad=2)...), 1, 32)
           Y = Flux.onehotbatch(xor.(eachrow(X)...), 0:1)
           # data = Flux.DataLoader((X, Y); batchsize=16, shuffle=true)
           data = Flux.DataLoader((X .+ 0f0, Y .+ 0f0); batchsize=16, shuffle=true)  # this avoids some errors from conversion

           model = Chain(Dense(2 => 3, sigmoid), BatchNorm(3), Dense(3 => 2)) |> Reactor
           state = Flux.setup(Adam(0.1, (0.7, 0.95)), model)  # Note that I'm doing this after |> Reactor, ideally before would work too?

           Flux.train!(model, data, state; epochs=100) do m, x, y
               Flux.logitcrossentropy(m(x), y)
           end

           @test all((softmax(model(X)) .> 0.5) .== Y)
       end

[ Info: compiling
simple train!: Error During Test at REPL[5]:1
  Got exception outside of a @test
  type Array has no field data
  Stacktrace:
    [1] getproperty
      @ ./Base.jl:49 [inlined]
    [2] macro expansion
      @ ~/.julia/packages/Reactant/i0Ypg/src/Compiler.jl:771 [inlined]
    [3] (::Reactant.Compiler.Thunk{Symbol("##_step!_reactant#475241")})(::var"#5#6", ::Duplicated{ConcreteRArray{Float32, 1}}, ::Chain{Tuple{Dense{typeof(σ), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, BatchNorm{typeof(identity), ConcreteRArray{Float32, 1}, Float32, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}, ::Tuple{Matrix{Float32}, Matrix{Float32}}, ::@NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}})
      @ Reactant.Compiler ~/.julia/packages/Reactant/i0Ypg/src/Compiler.jl:794
    [4] macro expansion
      @ ~/.julia/dev/Fluxperimental/ext/FluxReactantExt.jl:332 [inlined]
    [5] macro expansion
      @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
    [6] train!(loss::Function, m::Reactor{Chain{Tuple{Dense{typeof(σ), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, BatchNorm{typeof(identity), ConcreteRArray{Float32, 1}, Float32, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}}, data::MLUtils.DataLoader{Tuple{Matrix{Float32}, Matrix{Float32}}, Random.TaskLocalRNG, Val{nothing}}, opt_state::@NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}; epochs::Int64)
      @ FluxReactantExt ~/.julia/dev/Fluxperimental/ext/FluxReactantExt.jl:324
    [7] macro expansion
      @ REPL[5]:10 [inlined]
    [8] macro expansion
      @ /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/Test/src/Test.jl:1700 [inlined]
    [9] top-level scope
      @ REPL[5]:2
   [10] eval
      @ ./boot.jl:430 [inlined]
   [11] eval_user_input(ast::Any, backend::REPL.REPLBackend, mod::Module)
      @ REPL /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/REPL/src/REPL.jl:226
   [12] repl_backend_loop(backend::REPL.REPLBackend, get_module::Function)
      @ REPL /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/REPL/src/REPL.jl:323
   [13] start_repl_backend(backend::REPL.REPLBackend, consumer::Any; get_module::Function)
      @ REPL /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/REPL/src/REPL.jl:308
   [14] run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool, backend::Any)
      @ REPL /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/REPL/src/REPL.jl:464
   [15] run_repl(repl::REPL.AbstractREPL, consumer::Any)
      @ REPL /Applications/Julia-1.11.app/Contents/Resources/julia/share/julia/stdlib/v1.11/REPL/src/REPL.jl:450
   [16] (::Base.var"#1138#1140"{Bool, Symbol, Bool})(REPL::Module)
      @ Base ./client.jl:446
   [17] #invokelatest#2
      @ ./essentials.jl:1054 [inlined]
   [18] invokelatest
      @ ./essentials.jl:1051 [inlined]
   [19] run_main_repl(interactive::Bool, quiet::Bool, banner::Symbol, history_file::Bool, color_set::Bool)
      @ Base ./client.jl:430
   [20] repl_main
      @ ./client.jl:567 [inlined]
   [21] _start()
      @ Base ./client.jl:541
Test Summary: | Error  Total   Time
simple train! |     1      1  54.8s
ERROR: Some tests did not pass: 0 passed, 0 failed, 1 errored, 0 broken.

(jl_IPkmEO) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_IPkmEO/Project.toml`
  [587475ba] Flux v0.16.1 `~/.julia/dev/Flux`
  [3102ee7a] Fluxperimental v0.2.3 `~/.julia/dev/Fluxperimental`
  [3c362404] Reactant v0.2.10 `https://github.com/EnzymeAD/Reactant.jl.git#main`
  [a3311ec8] ReactantCore v0.1.3 `https://github.com/EnzymeAD/Reactant.jl.git:lib/ReactantCore#main`

julia> versioninfo()
Julia Version 1.11.0
Commit 501a4f25c2b (2024-10-07 11:40 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 11 × Apple M3 Pro
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, apple-m3)
Threads: 4 default, 0 interactive, 2 GC (on 5 virtual cores)

@wsmoses commented Dec 16, 2024

Hm, can you file an issue with an MWE? I presume this should be easy to fix once we have a nice reproducer.

@wsmoses commented Jan 2, 2025

@mcabbott what is the status here with the latest reactant?

@wsmoses commented Jan 5, 2025

@mcabbott @CarloLucibello one bug is due to FluxML/Optimisers.jl#206 which should be fixed by that PR.

The second one is a bug in this PR. However, I've made Reactant throw a better error message for it in EnzymeAD/Reactant.jl#474

[ Info: compiling gradient(loss1, ::Reactor, ::Const...)
simple gradient: Error During Test at /Users/wmoses/git/Fluxperimental.jl/test/reactant.jl:26
  Got exception outside of a @test
  
  The Reactant-compiled function `Reactant.Compiler.Thunk{FluxReactantExt.var"#_autodiff#3"{var"#loss1#127"}, Symbol("##_autodiff_reactant#99162"), Tuple{Duplicated{ConcreteRArray{Float32, 1}}, Duplicated{Chain{Tuple{typeof(Flux.flatten), Dense{typeof(tanh), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}}, Const{ConcreteRArray{Float32, 4}}}, false}` exists, but no method is defined for this combination of argument types.
  You passed in arguments with types (ReverseMode{false, false, FFIABI, false, false}, var"#loss1#127", Type{Active}, Duplicated{Chain{Tuple{typeof(Flux.flatten), Dense{typeof(tanh), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}}, Const{ConcreteRArray{Float32, 4}})
  However the method you are calling was compiled for arguments with types (Duplicated{ConcreteRArray{Float32, 1}}, Duplicated{Chain{Tuple{typeof(Flux.flatten), Dense{typeof(tanh), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}}, Const{ConcreteRArray{Float32, 4}})
  Stacktrace:
    [1] macro expansion
      @ ~/git/Reactant.jl/src/Compiler.jl:943 [inlined]
    [2] (::Reactant.Compiler.Thunk{FluxReactantExt.var"#_autodiff#3"{var"#loss1#127"}, Symbol("##_autodiff_reactant#99162"), Tuple{Duplicated{ConcreteRArray{Float32, 1}}, Duplicated{Chain{Tuple{typeof(Flux.flatten), Dense{typeof(tanh), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}}, Const{ConcreteRArray{Float32, 4}}}, false})(::ReverseMode{false, false, FFIABI, false, false}, ::var"#loss1#127", ::Type{Active}, ::Duplicated{Chain{Tuple{typeof(Flux.flatten), Dense{typeof(tanh), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}}, ::Const{ConcreteRArray{Float32, 4}})
      @ Reactant.Compiler ~/git/Reactant.jl/src/Compiler.jl:939
    [3] gradient(f::Function, m::Reactor{Chain{Tuple{typeof(Flux.flatten), Dense{typeof(tanh), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}}, xs::Const{Array{Float32, 4}})
      @ FluxReactantExt ~/git/Fluxperimental.jl/ext/FluxReactantExt.jl:209
    [4] macro expansion
      @ ~/git/Fluxperimental.jl/test/reactant.jl:41 [inlined]
    [5] macro expansion
      @ ~/.julia/juliaup/julia-1.10.7+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
    [6] macro expansion
      @ ~/git/Fluxperimental.jl/test/reactant.jl:27 [inlined]
    [7] macro expansion
      @ ~/.julia/juliaup/julia-1.10.7+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
    [8] top-level scope
      @ ~/git/Fluxperimental.jl/test/reactant.jl:5
    [9] include(fname::String)
      @ Base.MainInclude ./client.jl:494
   [10] macro expansion
      @ ~/git/Fluxperimental.jl/test/runtests.jl:17 [inlined]
   [11] macro expansion
      @ ~/.julia/juliaup/julia-1.10.7+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [12] top-level scope
      @ ~/git/Fluxperimental.jl/test/runtests.jl:5
   [13] include(fname::String)
      @ Base.MainInclude ./client.jl:494
   [14] top-level scope
      @ none:6
   [15] eval
      @ ./boot.jl:385 [inlined]
   [16] exec_options(opts::Base.JLOptions)
      @ Base ./client.jl:296
   [17] _start()
      @ Base ./client.jl:557

@wsmoses commented Jan 5, 2025

Okay, now it's presumably hitting the Optimisers.jl inner type issue (FluxML/Optimisers.jl#205):

[ Info: compiling gradient(loss1, ::Reactor, ::Const...)
simple gradient: Test Failed at /Users/wmoses/git/Fluxperimental.jl/test/reactant.jl:44
  Expression: Array(g4) ≈ g1
   Evaluated: Float32[-0.48188388, 0.5550047, 0.15954198, -0.62123954, 2.332253, -0.45462236, -3.659017, -1.5994651, 1.4772613, 1.2025286 … -0.34134686, -2.2225528, -0.38309634, -0.05193321, 1.1065944, -1.7070123, 0.8250545, -1.0351651, 1.2146482, -0.48555374] ≈ Float32[-0.48188365, 0.5550045, 0.15954155, -0.6212394, 2.332254, -0.45462242, -3.659018, -1.5994655, 1.4772608, 1.2025281 … -0.34134674, -2.2225525, -0.38309598, -0.05108498, 1.1065947, -1.7070134, 0.8250548, -1.0351648, 1.2146478, -0.48555297]

Stacktrace:
 [1] macro expansion
   @ ~/.julia/juliaup/julia-1.10.7+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:672 [inlined]
 [2] macro expansion
   @ ~/git/Fluxperimental.jl/test/reactant.jl:44 [inlined]
 [3] macro expansion
   @ ~/.julia/juliaup/julia-1.10.7+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
 [4] macro expansion
   @ ~/git/Fluxperimental.jl/test/reactant.jl:27 [inlined]
 [5] macro expansion
   @ ~/.julia/juliaup/julia-1.10.7+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
 [6] top-level scope
   @ ~/git/Fluxperimental.jl/test/reactant.jl:5
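
The second failure below looks like the #205 shape of the problem: Optimisers.Leaf stores its Adam state in a concretely typed Tuple{Float32, Float32}, so writing traced values back into it goes through convert and then Float32(::TracedRNumber), which has no method. A dependency-free analogue:

struct TracedLike <: Real end                # stand-in for Reactant.TracedRNumber{Float32}

mutable struct LeafLike
    state::Tuple{Float32, Float32}           # concrete field type, as in Optimisers.Leaf
end

leaf = LeafLike((0.9f0, 0.999f0))
leaf.state = (TracedLike(), TracedLike())    # MethodError: no method matching Float32(::TracedLike)
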
[ Info: compiling
simple train!: Error During Test at /Users/wmoses/git/Fluxperimental.jl/test/reactant.jl:152
  Got exception outside of a @test
  MethodError: no method matching Float32(::Reactant.TracedRNumber{Float32})
  
  Closest candidates are:
    (::Type{T})(::T) where T<:Number
     @ Core boot.jl:792
    Float32(::BigFloat, ::Base.MPFR.MPFRRoundingMode)
     @ Base mpfr.jl:390
    Float32(::BigFloat)
     @ Base mpfr.jl:390
    ...
  
  Stacktrace:
    [1] macro expansion
      @ ~/git/Reactant.jl/src/utils.jl:0 [inlined]
    [2] call_with_reactant(::Type{Float32}, ::Reactant.TracedRNumber{Float32})
      @ Reactant ~/git/Reactant.jl/src/utils.jl:583
    [3] convert
      @ ./number.jl:7 [inlined]
    [4] convert(none::Type{Float32}, none::Reactant.TracedRNumber{Float32})
      @ Reactant ./<missing>:0
    [5] convert
      @ ./number.jl:7 [inlined]
    [6] call_with_reactant(::typeof(convert), ::Type{Float32}, ::Reactant.TracedRNumber{Float32})
      @ Reactant ~/git/Reactant.jl/src/utils.jl:0
    [7] cvt1
      @ ./essentials.jl:468 [inlined]
    [8] ntuple
      @ ./ntuple.jl:49 [inlined]
    [9] convert
      @ ./essentials.jl:470 [inlined]
   [10] convert(none::Type{Tuple{Float32, Float32}}, none::Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}})
      @ Reactant ./<missing>:0
   [11] convert
      @ ./essentials.jl:460 [inlined]
   [12] call_with_reactant(::typeof(convert), ::Type{Tuple{Float32, Float32}}, ::Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}})
      @ Reactant ~/git/Reactant.jl/src/utils.jl:0
   [13] cvt1
      @ ./essentials.jl:468 [inlined]
   [14] ntuple
      @ ./ntuple.jl:50 [inlined]
   [15] convert
      @ ./essentials.jl:470 [inlined]
   [16] convert(none::Type{Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, none::Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}})
      @ Reactant ./<missing>:0
   [17] convert
      @ ./essentials.jl:460 [inlined]
   [18] call_with_reactant(::typeof(convert), ::Type{Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, ::Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}})
      @ Reactant ~/git/Reactant.jl/src/utils.jl:0
   [19] setproperty!
      @ ./Base.jl:40 [inlined]
   [20] setproperty!(none::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, none::Symbol, none::Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}})
      @ Reactant ./<missing>:0
   [21] setproperty!
      @ ./Base.jl:39 [inlined]
   [22] call_with_reactant(::typeof(setproperty!), ::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, ::Symbol, ::Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Reactant.TracedRNumber{Float32}, Reactant.TracedRNumber{Float32}}})
      @ Reactant ~/git/Reactant.jl/src/utils.jl:0
   [23] #_update!#10
      @ ~/git/Optimisers.jl/src/interface.jl:96 [inlined]
   [24] var"#_update!#10"(none::IdDict{Optimisers.Leaf, Any}, none::IdDict{Any, Any}, none::typeof(Optimisers._update!), none::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, none::Reactant.TracedRArray{Float32, 2})
      @ Reactant ./<missing>:0
   [25] #_update!#10
      @ ~/git/Optimisers.jl/src/interface.jl:93 [inlined]
   [26] call_with_reactant(::Optimisers.var"##_update!#10", ::IdDict{Optimisers.Leaf, Any}, ::IdDict{Any, Any}, ::typeof(Optimisers._update!), ::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, ::Reactant.TracedRArray{Float32, 2})
      @ Reactant ~/git/Reactant.jl/src/utils.jl:0
   [27] _update!
      @ ~/git/Optimisers.jl/src/interface.jl:92 [inlined]
   [28] #8
      @ ~/git/Optimisers.jl/src/interface.jl:85 [inlined]
   [29] map
      @ ./tuple.jl:322 [inlined]
   [30] map
      @ ./namedtuple.jl:265 [inlined]
   [31] mapvalue
      @ ~/git/Optimisers.jl/src/utils.jl:2 [inlined]
   [32] #_update!#7
      @ ~/git/Optimisers.jl/src/interface.jl:85 [inlined]
   [33] var"#_update!#7"(none::IdDict{Optimisers.Leaf, Any}, none::IdDict{Any, Any}, none::typeof(Optimisers._update!), none::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, none::Dense{typeof(σ), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}})
      @ Reactant ./<missing>:0
   [34] #_update!#7
      @ ~/git/Optimisers.jl/src/interface.jl:82 [inlined]
   [35] call_with_reactant(::Optimisers.var"##_update!#7", ::IdDict{Optimisers.Leaf, Any}, ::IdDict{Any, Any}, ::typeof(Optimisers._update!), ::@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, ::Dense{typeof(σ), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}})
      @ Reactant ~/git/Reactant.jl/src/utils.jl:0
   [36] _update!
      @ ~/git/Optimisers.jl/src/interface.jl:81 [inlined]
   [37] #8
      @ ~/git/Optimisers.jl/src/interface.jl:85 [inlined]
   [38] map
      @ ./tuple.jl:322 [inlined]
   [39] mapvalue
      @ ~/git/Optimisers.jl/src/utils.jl:2 [inlined]
   [40] #_update!#7
      @ ~/git/Optimisers.jl/src/interface.jl:85 [inlined]
   [41] var"#_update!#7"(none::IdDict{Optimisers.Leaf, Any}, none::IdDict{Any, Any}, none::typeof(Optimisers._update!), none::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}, none::Tuple{Dense{typeof(σ), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}, BatchNorm{typeof(identity), Reactant.TracedRArray{Float32, 1}, Float32, Reactant.TracedRArray{Float32, 1}}, Dense{typeof(identity), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}})
      @ Reactant ./<missing>:0
   [42] #_update!#7
      @ ~/git/Optimisers.jl/src/interface.jl:82 [inlined]
   [43] call_with_reactant(::Optimisers.var"##_update!#7", ::IdDict{Optimisers.Leaf, Any}, ::IdDict{Any, Any}, ::typeof(Optimisers._update!), ::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}, ::Tuple{Dense{typeof(σ), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}, BatchNorm{typeof(identity), Reactant.TracedRArray{Float32, 1}, Float32, Reactant.TracedRArray{Float32, 1}}, Dense{typeof(identity), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}})
      @ Reactant ~/git/Reactant.jl/src/utils.jl:0
   [44] _update!
      @ ~/git/Optimisers.jl/src/interface.jl:81 [inlined]
   [45] #8
      @ ~/git/Optimisers.jl/src/interface.jl:85 [inlined]
   [46] map
      @ ./tuple.jl:318 [inlined]
   [47] map
      @ ./namedtuple.jl:265 [inlined]
   [48] mapvalue
      @ ~/git/Optimisers.jl/src/utils.jl:2 [inlined]
   [49] #_update!#7
      @ ~/git/Optimisers.jl/src/interface.jl:85 [inlined]
   [50] _update!
      @ ~/git/Optimisers.jl/src/interface.jl:81 [inlined]
   [51] update!
      @ ~/git/Optimisers.jl/src/interface.jl:77 [inlined]
   [52] _step!
      @ ~/git/Fluxperimental.jl/ext/FluxReactantExt.jl:355 [inlined]
   [53] _step!(none::var"#126#128", none::Duplicated{Reactant.TracedRArray{Float32, 1}}, none::Chain{Tuple{Dense{typeof(σ), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}, BatchNorm{typeof(identity), Reactant.TracedRArray{Float32, 1}, Float32, Reactant.TracedRArray{Float32, 1}}, Dense{typeof(identity), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}}}, none::Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}}, none::@NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}})
      @ Reactant ./<missing>:0
   [54] Array
      @ ./boot.jl:477 [inlined]
   [55] IdDict
      @ ./iddict.jl:30 [inlined]
   [56] IdDict
      @ ./iddict.jl:48 [inlined]
   [57] make_zero (repeats 2 times)
      @ ~/.julia/packages/EnzymeCore/15Zff/src/EnzymeCore.jl:529 [inlined]
   [58] _step!
      @ ~/git/Fluxperimental.jl/ext/FluxReactantExt.jl:353 [inlined]
   [59] call_with_reactant(::typeof(FluxReactantExt._step!), ::var"#126#128", ::Duplicated{Reactant.TracedRArray{Float32, 1}}, ::Chain{Tuple{Dense{typeof(σ), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}, BatchNorm{typeof(identity), Reactant.TracedRArray{Float32, 1}, Float32, Reactant.TracedRArray{Float32, 1}}, Dense{typeof(identity), Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 1}}}}, ::Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}}, ::@NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 2}, Reactant.TracedRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{Reactant.TracedRArray{Float32, 1}, Reactant.TracedRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}})
      @ Reactant ~/git/Reactant.jl/src/utils.jl:0
   [60] make_mlir_fn(f::Function, args::Tuple{var"#126#128", Duplicated{ConcreteRArray{Float32, 1}}, Chain{Tuple{Dense{typeof(σ), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, BatchNorm{typeof(identity), ConcreteRArray{Float32, 1}, Float32, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}}, @NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool, do_transpose::Bool)
      @ Reactant.TracedUtils ~/git/Reactant.jl/src/TracedUtils.jl:184
   [61] make_mlir_fn
      @ ~/git/Reactant.jl/src/TracedUtils.jl:86 [inlined]
   [62] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{var"#126#128", Duplicated{ConcreteRArray{Float32, 1}}, Chain{Tuple{Dense{typeof(σ), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, BatchNorm{typeof(identity), ConcreteRArray{Float32, 1}, Float32, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}}, @NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}}; optimize::Bool, no_nan::Bool)
      @ Reactant.Compiler ~/git/Reactant.jl/src/Compiler.jl:348
   [63] compile_mlir!
      @ ~/git/Reactant.jl/src/Compiler.jl:339 [inlined]
   [64] compile_xla(f::Function, args::Tuple{var"#126#128", Duplicated{ConcreteRArray{Float32, 1}}, Chain{Tuple{Dense{typeof(σ), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, BatchNorm{typeof(identity), ConcreteRArray{Float32, 1}, Float32, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}}, @NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}}; client::Nothing, optimize::Bool, no_nan::Bool)
      @ Reactant.Compiler ~/git/Reactant.jl/src/Compiler.jl:844
   [65] compile_xla
      @ ~/git/Reactant.jl/src/Compiler.jl:835 [inlined]
   [66] compile(f::Function, args::Tuple{var"#126#128", Duplicated{ConcreteRArray{Float32, 1}}, Chain{Tuple{Dense{typeof(σ), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, BatchNorm{typeof(identity), ConcreteRArray{Float32, 1}, Float32, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}}, @NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}}; client::Nothing, optimize::Bool, sync::Bool, no_nan::Bool)
      @ Reactant.Compiler ~/git/Reactant.jl/src/Compiler.jl:870
   [67] macro expansion
      @ ~/git/Reactant.jl/src/Compiler.jl:580 [inlined]
   [68] macro expansion
      @ ~/git/Fluxperimental.jl/ext/FluxReactantExt.jl:329 [inlined]
   [69] macro expansion
      @ ~/.julia/packages/ProgressLogging/6KXlp/src/ProgressLogging.jl:328 [inlined]
   [70] train!(loss::Function, m::Reactor{Chain{Tuple{Dense{typeof(σ), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}, BatchNorm{typeof(identity), ConcreteRArray{Float32, 1}, Float32, ConcreteRArray{Float32, 1}}, Dense{typeof(identity), ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 1}}}}}, data::MLUtils.DataLoader{Tuple{Matrix{Float32}, Matrix{Float32}}, Random._GLOBAL_RNG, Val{nothing}}, opt_state::@NamedTuple{layers::Tuple{@NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}, @NamedTuple{λ::Tuple{}, β::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, γ::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, μ::Tuple{}, σ²::Tuple{}, ϵ::Tuple{}, momentum::Tuple{}, affine::Tuple{}, track_stats::Tuple{}, active::Tuple{}, chs::Tuple{}}, @NamedTuple{weight::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 2}, ConcreteRArray{Float32, 2}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam, Tuple{ConcreteRArray{Float32, 1}, ConcreteRArray{Float32, 1}, Tuple{Float32, Float32}}}, σ::Tuple{}}}}; epochs::Int64)
      @ FluxReactantExt ~/git/Fluxperimental.jl/ext/FluxReactantExt.jl:324
   [71] macro expansion
      @ ~/git/Fluxperimental.jl/test/reactant.jl:161 [inlined]
   [72] macro expansion
      @ ~/.julia/juliaup/julia-1.10.7+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [73] macro expansion
      @ ~/git/Fluxperimental.jl/test/reactant.jl:153 [inlined]
   [74] macro expansion
      @ ~/.julia/juliaup/julia-1.10.7+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [75] top-level scope
      @ ~/git/Fluxperimental.jl/test/reactant.jl:5
   [76] include(fname::String)
      @ Base.MainInclude ./client.jl:494
   [77] macro expansion
      @ ~/git/Fluxperimental.jl/test/runtests.jl:17 [inlined]
   [78] macro expansion
      @ ~/.julia/juliaup/julia-1.10.7+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
   [79] top-level scope
      @ ~/git/Fluxperimental.jl/test/runtests.jl:5
   [80] include(fname::String)
      @ Base.MainInclude ./client.jl:494
   [81] top-level scope
      @ none:6
   [82] eval
      @ ./boot.jl:385 [inlined]
   [83] exec_options(opts::Base.JLOptions)
      @ Base ./client.jl:296
Test Summary:                           | Pass  Fail  Error  Broken  Total     Time
Fluxperimental.jl                       |   74     1      1       4     80  2m43.2s
  Split + Join                          |    2                           2     1.9s
  Applying the Chain!                   |    9                           9     8.4s
  @compact                              |   20                    3     23     6.6s
  Custom naming of @compact with NoShow |    2                           2     0.3s
  NoShow                                |    7                           7     3.1s
  simple case                           |    3                           3     0.1s
  re-defined                            |    4                           4     0.1s
  new defn                              |    5                           5     0.1s
  no-function                           |    4                           4     0.1s
  gradient, withgradient, Moonduo       |    7                    1      8    40.7s
  Reactant + Flux                       |   11     1      1             13    57.3s
    simple forwards                     |    6                           6    13.7s
    simple gradient                     |    5     1                     6    27.4s
    simple train!                       |                 1              1    16.1s
ERROR: LoadError: Some tests did not pass: 74 passed, 1 failed, 1 errored, 4 broken.
in expression starting at /Users/wmoses/git/Fluxperimental.jl/test/runtests.jl:4
ERROR: Package Fluxperimental errored during testing
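
For anyone skimming the trace above: the failure surfaces while Reactant traces the optimiser update inside the compiled training step, not in the forward pass or the Enzyme call itself. A rough reconstruction of the step being traced, inferred only from frames [51]–[59] of the trace (the function body and argument names here are assumptions, not the actual FluxReactantExt.jl source):

function _step!(loss, seed, model, batch, opt_state)
    grads = Enzyme.make_zero(model)              # frame [57], EnzymeCore.jl:529
    # ... reverse-mode Enzyme.autodiff fills `grads` ...
    Optimisers.update!(opt_state, model, grads)  # frame [51], Optimisers interface.jl:77
end

From update! the tracer descends recursively through Optimisers._update! over each Leaf of the optimiser state tree (frames [26]–[50] above); the IdDict frames [54]–[56] are just make_zero's internal cache being constructed.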

@wsmoses

wsmoses commented Jan 5, 2025

My patch is

diff --git a/ext/FluxReactantExt.jl b/ext/FluxReactantExt.jl
index 0c1b801..34b043b 100644
--- a/ext/FluxReactantExt.jl
+++ b/ext/FluxReactantExt.jl
@@ -187,7 +187,7 @@ function Flux.gradient(f::Function, m::Reactor, xs::Const...)
     # _seed = Ref(0f0), Ref(1f0)  # MethodError: no method matching Float32(::Reactant.TracedRNumber{Float32})
     _seed = ([0f0], [1f0]) |> Reactant.to_rarray
     seed = Duplicated(_seed...)
-    function _autodiff(seed, dup, xrs...)
+    function _autodiff(seed, f, dup, xrs...)
         Enzyme.make_zero!(Ref(dup.dval))
         Enzyme.autodiff(Reverse, Const(_fun!), seed, Const(f), dup, xrs...)  # suggestion from @jumerckx to pass simpler arguments to the function seen by  @compile
     end
@@ -197,16 +197,16 @@ function Flux.gradient(f::Function, m::Reactor, xs::Const...)
     elseif input == m.grad_input
         # m.grad_compiled(Reverse, f, Active, dup, xrs...)
         # m.grad_compiled(Reverse, Const(_fun!), seed, Const(f), dup, xrs...)
-        m.grad_compiled(seed, dup, xrs...)
+        m.grad_compiled(seed, f, dup, xrs...)
         m.grad_count += 1
     else
         @info "compiling gradient($f, ::Reactor, ::Const...)"
         # fun = @compile Enzyme.autodiff(Reverse, f, Active, dup, xrs...)  # this gives ERROR: "Unhandled type Type" above
         # fun = @compile Enzyme.autodiff(Reverse, Const(_fun!), seed, Const(f), dup, xrs...)  # this gives ERROR: type TypeVar has no field data
-        fun = @compile _autodiff(seed, dup, xrs...)  # ERROR: BoundsError: attempt to access ReverseMode{false, false, FFIABI, false, false} at index [1]
+        fun = @compile _autodiff(seed, f, dup, xrs...)  # ERROR: BoundsError: attempt to access ReverseMode{false, false, FFIABI, false, false} at index [1]
         m.grad_compiled = fun
         m.grad_input = _input_summary(f, xrs...)
-        fun(Reverse, f, Active, dup, xrs...)
+        fun(seed, f, dup, xrs...)
         m.grad_count += 1
     end
     map(_grad_or_nothing, (dup, xrs...))

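The gist of the change, spelled out: a Reactant-compiled thunk is specialized for exactly the argument list handed to @compile, so it has to be invoked with that same list. The old code compiled _autodiff(seed, dup, xrs...) but then called fun(Reverse, f, Active, dup, xrs...), a different signature altogether; making f an explicit argument fixes the mismatch and also keeps f out of the closure that @compile would otherwise have to trace. A minimal sketch of the corrected pattern, assuming the surrounding definitions from FluxReactantExt.jl (_fun!, seed, dup and xrs as set up earlier in Flux.gradient):

function _autodiff(seed, f, dup, xrs...)
    Enzyme.make_zero!(Ref(dup.dval))  # reset the shadow before accumulating gradients
    Enzyme.autodiff(Reverse, Const(_fun!), seed, Const(f), dup, xrs...)
end

fun = @compile _autodiff(seed, f, dup, xrs...)  # compile for this exact signature
fun(seed, f, dup, xrs...)                       # ...and call it with the same arguments

On later calls with the same input summary, the cached m.grad_compiled is reused with that identical argument list, as in the hunk above.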