diff --git a/.gitignore b/.gitignore index 78756acf1..9e6791bdc 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,3 @@ *.jl.*.cov *.jl.mem docs/build -Manifest.toml diff --git a/Manifest.toml b/Manifest.toml new file mode 100644 index 000000000..2c8d09ff1 --- /dev/null +++ b/Manifest.toml @@ -0,0 +1,224 @@ +# This file is machine-generated - editing it directly is not advised + +[[AbstractFFTs]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "8ed9de2f1b1a9b1dee48582ad477c6e67b83eb2c" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "1.0.0" + +[[Artifacts]] +deps = ["Pkg"] +git-tree-sha1 = "c30985d8821e0cd73870b17b0ed0ce6dc44cb744" +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" +version = "1.3.0" + +[[Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[Buffers]] +deps = ["MacroTools", "ZygoteRules"] +git-tree-sha1 = "ae8092bb50cb5e8fcbcbadace2703eeb2362cc2e" +repo-rev = "master" +repo-url = "https://github.com/sethaxen/Buffers.jl" +uuid = "f06dc053-dc01-43d6-9774-6ac660c29876" +version = "0.1.0" + +[[ChainRules]] +deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Reexport", "Requires", "Statistics"] +git-tree-sha1 = "31b28f5123afa5e5ca0c885e4051896032754578" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "0.7.45" + +[[ChainRulesCore]] +deps = ["LinearAlgebra", "MuladdMacro", "SparseArrays"] +git-tree-sha1 = "15081c431bb25848ad9b0d172a65794f3a3e197a" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "0.9.24" + +[[CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[Compat]] +deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] +git-tree-sha1 = "919c7f3151e79ff196add81d7f4e45d91bbf420b" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "3.25.0" + +[[CompilerSupportLibraries_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "8e695f735fca77e9708e795eda62afdb869cbb70" +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "0.3.4+0" + +[[Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[DelimitedFiles]] +deps = ["Mmap"] +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" + +[[DiffResults]] +deps = ["StaticArrays"] +git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.0.3" + +[[DiffRules]] +deps = ["NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "214c3fcac57755cfda163d91c58893a8723f93e9" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.0.2" + +[[Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays"] +git-tree-sha1 = "8bd8e47ff5d34b20f0aa9641988eb660590008bc" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "0.11.0" + +[[ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"] +git-tree-sha1 = "8de2519a83c6c1c2442c2f481dd9a8364855daf4" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.14" + +[[IRTools]] +deps = ["InteractiveUtils", "MacroTools", "Test"] +git-tree-sha1 = "c67e7515a11f726f44083e74f218d134396d6510" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.4.2" + +[[InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[JLLWrappers]] +git-tree-sha1 = "04b49c556240b62d5a799e94c63d5fc14d3c07cd" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.1.4" + +[[LibGit2]] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[LinearAlgebra]] +deps = ["Libdl"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.6" + +[[Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[MuladdMacro]] +git-tree-sha1 = "c6190f9a7fc5d9d5915ab29f2134421b12d24a68" +uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" +version = "0.2.2" + +[[NaNMath]] +git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "0.3.5" + +[[OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.3+4" + +[[Pkg]] +deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Test", "UUIDs"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[Random]] +deps = ["Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[Reexport]] +git-tree-sha1 = "57d8440b0c7d98fc4f889e478e80f268d534c9d5" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.0.0" + +[[Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "cfbac6c1ed70c002ec6361e7fd334f02820d6419" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.1.2" + +[[SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + +[[Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[SharedArrays]] +deps = ["Distributed", "Mmap", "Random", "Serialization"] +uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" + +[[Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[SpecialFunctions]] +deps = ["ChainRulesCore", "OpenSpecFun_jll"] +git-tree-sha1 = "75394dbe2bd346beeed750fb02baa6445487b862" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "1.2.1" + +[[StaticArrays]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "9da72ed50e94dbff92036da395275ed114e04d49" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.0.1" + +[[Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[Test]] +deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[ZygoteRules]] +deps = ["MacroTools"] +git-tree-sha1 = "9e7a1e8ca60b742e508a315c17eef5211e7fbfd7" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.1" diff --git a/Project.toml b/Project.toml index 461bf3710..0cf0be7e8 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "0.6.0" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" +Buffers = "f06dc053-dc01-43d6-9774-6ac660c29876" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" diff --git a/src/Zygote.jl b/src/Zygote.jl index a7b42cf2a..c3db61825 100644 --- a/src/Zygote.jl +++ b/src/Zygote.jl @@ -5,6 +5,7 @@ using LinearAlgebra: copytri!, AbstractTriangular import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, literal_getproperty, literal_getfield +using Buffers: Buffer, bufferfrom using ChainRules: ChainRules, rrule, unthunk using IRTools @@ -15,7 +16,6 @@ import Distributed: pmap, CachingPool, workers export Params, gradient, pullback, pushforward, @code_adjoint include("tools/idset.jl") -include("tools/buffer.jl") include("tools/builtins.jl") include("forward/Forward.jl") @@ -32,7 +32,6 @@ include("lib/lib.jl") include("lib/number.jl") include("lib/base.jl") include("lib/array.jl") -include("lib/buffer.jl") include("lib/broadcast.jl") include("lib/forward.jl") include("lib/utils.jl") diff --git a/src/lib/buffer.jl b/src/lib/buffer.jl deleted file mode 100644 index 4d332c16d..000000000 --- a/src/lib/buffer.jl +++ /dev/null @@ -1,52 +0,0 @@ -grad_mut(b::Buffer) = fill!(similar(b.data, Any), nothing) -grad_mut(b::Buffer{T}) where T<:Number = fill!(similar(b.data, float(T)), 0) - -@nograd Buffer - -@adjoint function getindex(b::Buffer, i...) - b[i...], function (Δ) - grad = grad_mut(__context__, b) - grad[i...] = accum(grad[i...], Δ) - return - end -end - -@adjoint! function setindex!(b::Buffer, v, i...) - setindex!(b, v, i...), function (_) - grad = grad_mut(__context__, b) - v̄ = grad[i...] - zero = eltype(grad) <: Number ? 0 : nothing - if i isa NTuple{N,Integer} where N - grad[i...] = zero - else - grad[i...] .= zero - end - (nothing, v̄, map(_->nothing, i)...) - end -end - -@adjoint! function copyto!(b::Buffer, xs) - copyto!(b, xs), function (_) - grad = grad_mut(__context__, b) - x̄s = copy(grad) - grad .= eltype(grad) <: Number ? 0 : nothing - return (nothing, x̄s) - end -end - -@adjoint! function push!(b::Buffer, x) - push!(b, x), function (y) - grad = grad_mut(__context__, b) - return (nothing, pop!(grad)) - end -end - -_pullback(cx::AContext, ::typeof(Broadcast.materialize!), b::Buffer, x::AbstractArray) = - _pullback(cx, copyto!, b, x) - -@adjoint function copy(b::Buffer) - copy(b), function (b̄) - grad_mut(__context__, b)[:] = b̄ - return - end -end diff --git a/src/tools/buffer.jl b/src/tools/buffer.jl deleted file mode 100644 index 9409a74bc..000000000 --- a/src/tools/buffer.jl +++ /dev/null @@ -1,86 +0,0 @@ -""" - Buffer(xs, ...) - -`Buffer` is an array-like type which is mutable when taking gradients. You can -construct a `Buffer` with the same syntax as `similar` (e.g. `Buffer(xs, 5)`) -and then use normal indexing. Finally, use `copy` to get back a normal array. - -For example: - -```julia -julia> function vstack(xs) - buf = Buffer(xs, length(xs), 5) - for i = 1:5 - buf[:, i] = xs - end - return copy(buf) - end -vstack (generic function with 1 method) - -julia> vstack([1, 2, 3]) -3×5 Array{Int64,2}: - 1 1 1 1 1 - 2 2 2 2 2 - 3 3 3 3 3 - -julia> gradient(x -> sum(vstack(x)), [1, 2, 3]) -([5.0, 5.0, 5.0],) -``` - -`Buffer` is not an `AbstractArray` and can't be used for linear algebra -operations like matrix multiplication. This prevents it from being captured by -pullbacks. - -`copy` is a semantic copy, but does not allocate memory. Instead the `Buffer` -is made immutable after copying. -""" -mutable struct Buffer{T,A<:AbstractArray{T}} - data::A - freeze::Bool -end - -Buffer(xs::AbstractArray, args...) = - Buffer(similar(xs, args...), false) - -bufferfrom(xs::AbstractArray) = Buffer(xs, false) - -Base.getindex(b::Buffer, i...) = b.data[i...] - -function Base.setindex!(b::Buffer, v, i...) - b.freeze && error("Buffer is frozen") - b.data[i...] = v -end - -function Base.copyto!(b::Buffer, data) - b.freeze && error("Buffer is frozen") - copyto!(b.data, data) -end - -function Base.push!(b::Buffer, data) - b.freeze && error("Buffer is frozen") - push!(b.data, data) -end - -function Base.copy(b::Buffer) - b.freeze = true - return b.data -end - -function Base.deleteat!(b::Buffer, i) - b.freeze && error("Buffer is frozen") - deleteat!(b.data, i) - return b -end - -@forward Buffer.data Base.eltype, Base.length, Base.ndims, Base.size, Base.axes, - Base.eachindex, Base.stride, Base.strides, Base.findfirst, - Base.keys - -Base.IteratorSize(::Type{<:Buffer{<:Any, A}}) where {A} = Base.IteratorSize(A) - -# Buffer iteration mirrors iteration for AbstractArray -function Base.iterate(b::Buffer, state=(eachindex(b),)) - y = iterate(state...) - y === nothing && return nothing - b[y[1]], (state[1], tail(y)...) -end