diff --git a/Project.toml b/Project.toml
index 2a9333bc62..72123e07ac 100644
--- a/Project.toml
+++ b/Project.toml
@@ -30,9 +30,9 @@ ChainRulesCore = "1.12"
 Functors = "0.3, 0.4"
 MLUtils = "0.2, 0.3.1"
 MacroTools = "0.5"
-NNlib = "0.8.9"
+NNlib = "0.8.12"
 NNlibCUDA = "0.2.4"
-OneHotArrays = "0.1, 0.2"
+OneHotArrays = "0.2.2"
 Optimisers = "0.2.12"
 ProgressLogging = "0.1"
 Reexport = "0.2, 1.0"
diff --git a/perf.jl b/perf.jl
new file mode 100644
index 0000000000..163d28bc00
--- /dev/null
+++ b/perf.jl
@@ -0,0 +1,115 @@
+using Flux, BenchmarkTools, Flux.Losses, CUDA
+using Flux: OneHotMatrix
+
+crossv2(ŷ::AbstractMatrix, y::AbstractVector{<:Integer}) = logitcrossentropy(ŷ, Flux.onehotbatch(y, 1:size(ŷ, 1)))
+cross_unsafe(ŷ::AbstractMatrix, y::AbstractVector{<:Integer}) = logitcrossentropy(ŷ, OneHotMatrix(y, size(ŷ, 1)))
+
+function perf(c, n)
+    labels = rand(1:c, n)
+    y = Flux.onehotbatch(labels, 1:c)
+    ŷ = randn(Float32, c, n)
+
+    labelsgpu = labels |> gpu
+    ygpu = y |> gpu
+    ŷgpu = ŷ |> gpu
+
+    # println("with ŷ")
+    # @btime logitcrossentropy($ŷ, $y);
+    # @btime gradient(ŷ -> logitcrossentropy(ŷ, $y), $ŷ);
+
+    # println("with labels")
+    # @btime logitcrossentropy($ŷ, $labels);
+    # @btime gradient(ŷ -> logitcrossentropy(ŷ, $labels), $ŷ);
+
+    # println("crossv2")
+    # @btime crossv2($ŷ, $labels);
+    # @btime gradient(ŷ -> crossv2(ŷ, $labels), $ŷ);
+
+    # println("with ŷ - gpu")
+    # @assert size(ŷgpu) == (c, n)
+    # @btime CUDA.@sync logitcrossentropy($ŷgpu, $ygpu);
+    # @btime CUDA.@sync gradient(ŷ -> logitcrossentropy(ŷ, $ygpu), $ŷgpu);
+
+    # println("with labels - gpu")
+    # @btime CUDA.@sync logitcrossentropy($ŷgpu, $labelsgpu);
+    # @btime CUDA.@sync gradient(ŷ -> logitcrossentropy(ŷ, $labelsgpu), $ŷgpu);
+
+    # println("crossv2 - gpu")
+    # @btime CUDA.@sync crossv2($ŷgpu, $labelsgpu);
+    # @btime CUDA.@sync gradient(ŷ -> crossv2(ŷ, $labelsgpu), $ŷgpu);
+
+    println("cross_unsafe - gpu")
+    @btime CUDA.@sync cross_unsafe($ŷgpu, $labelsgpu);
+    @btime CUDA.@sync gradient(ŷ -> cross_unsafe(ŷ, $labelsgpu), $ŷgpu);
+
+    return nothing
+end
+
+perf(10, 128)
+# with ŷ
+# 14.648 μs (10 allocations: 13.17 KiB)
+# 27.381 μs (19 allocations: 35.39 KiB)
+# with labels
+# 13.716 μs (16 allocations: 9.88 KiB)
+# 41.338 μs (119 allocations: 25.22 KiB)
+# crossv2
+# 14.838 μs (11 allocations: 13.73 KiB)
+# 27.501 μs (20 allocations: 35.95 KiB)
+# with ŷ - gpu
+# 46.107 μs (163 allocations: 8.52 KiB)
+# 109.656 μs (414 allocations: 24.17 KiB)
+# with labels - gpu
+# 42.620 μs (125 allocations: 6.23 KiB)
+# 117.972 μs (375 allocations: 19.61 KiB)
+# crossv2 - gpu
+# 107.913 μs (284 allocations: 14.45 KiB)
+# 177.093 μs (535 allocations: 30.11 KiB)
+# cross_unsafe - gpu
+# 46.647 μs (163 allocations: 8.52 KiB)
+# 110.759 μs (414 allocations: 24.17 KiB)
+
+perf(100, 128)
+# with ŷ
+# 121.148 μs (12 allocations: 103.02 KiB)
+# 212.059 μs (25 allocations: 304.92 KiB)
+# with labels
+# 113.914 μs (17 allocations: 54.80 KiB)
+# 215.665 μs (122 allocations: 159.98 KiB)
+# crossv2
+# 122.620 μs (13 allocations: 103.58 KiB)
+# 215.615 μs (26 allocations: 305.48 KiB)
+# with ŷ - gpu
+# 47.880 μs (163 allocations: 8.52 KiB)
+# 110.307 μs (414 allocations: 24.17 KiB)
+# with labels - gpu
+# 40.567 μs (125 allocations: 6.23 KiB)
+# 122.961 μs (375 allocations: 19.61 KiB)
+# crossv2 - gpu
+# 104.917 μs (284 allocations: 14.45 KiB)
+# 171.141 μs (535 allocations: 30.11 KiB)
+# cross_unsafe - gpu
+# 46.137 μs (163 allocations: 8.52 KiB)
+# 109.084 μs (414 allocations: 24.17 KiB)
+
+perf(100, 1280)
+# with ŷ
+# 1.378 ms (12 allocations: 1.00 MiB)
+# 2.320 ms (25 allocations: 2.97 MiB)
+# with labels
+# 1.321 ms (18 allocations: 540.97 KiB)
+# 2.169 ms (123 allocations: 1.52 MiB)
+# crossv2
+# 1.384 ms (13 allocations: 1.01 MiB)
+# 2.317 ms (26 allocations: 2.98 MiB)
+# with ŷ - gpu
+# 60.885 μs (210 allocations: 10.77 KiB)
+# 121.919 μs (464 allocations: 26.47 KiB)
+# with labels - gpu
+# 52.679 μs (174 allocations: 8.52 KiB)
+# 128.602 μs (426 allocations: 21.92 KiB)
+# crossv2 - gpu
+# 137.448 μs (422 allocations: 20.91 KiB)
+# 208.361 μs (676 allocations: 36.61 KiB)
+# cross_unsafe - gpu
+# 58.479 μs (210 allocations: 10.77 KiB)
+# 121.839 μs (464 allocations: 26.47 KiB)
diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl
index 3d8f6f8149..693494bcb1 100644
--- a/src/losses/Losses.jl
+++ b/src/losses/Losses.jl
@@ -6,7 +6,7 @@ using Zygote: @adjoint
 using ChainRulesCore
 using ..Flux: ofeltype, epseltype
 using CUDA
-using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss
+using NNlib: NNlib, logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss
 import Base.Broadcast: broadcasted

 export mse, mae, msle,
diff --git a/src/losses/functions.jl b/src/losses/functions.jl
index ffda2ff99a..45d52215d9 100644
--- a/src/losses/functions.jl
+++ b/src/losses/functions.jl
@@ -183,6 +183,10 @@ as would be the case with the output of a [softmax](@ref Softmax) operation.
 For numerical stability, it is recommended to use [`logitcrossentropy`](@ref)
 rather than `softmax` followed by `crossentropy` .

+The target array `y` must have the same shape as the prediction array `ŷ`
+and represent one-hot encoded labels or probabilities.
+Alternatively, `y` can be a vector of integers containing the class labels.
+
 Use [`label_smoothing`](@ref) to smooth the true labels as preprocessing before
 computing the loss.

@@ -227,6 +231,21 @@ function crossentropy(ŷ, y; dims = 1, agg = mean, ϵ = epseltype(ŷ))
     agg(.-sum(xlogy.(y, ŷ .+ ϵ); dims = dims))
 end

+function crossentropy(ŷ::AbstractMatrix, y::AbstractVector{<:Integer}; dims = 1, agg = mean, ϵ = epseltype(ŷ))
+    n = length(y)
+    dims ∈ (1, 2) || throw(ArgumentError("Only dims = 1 and 2 are supported at the moment. Pass a one-hot encoded vector y for generic dims."))
+    n == (dims == 1 ? size(ŷ, 2) : size(ŷ, 1)) || throw(ArgumentError("The length of y should be the same as the batch dimension of ŷ."))
+    logits = log.(ŷ .+ ϵ)
+    if dims == 1
+        ŷgold = NNlib.gather(logits, y, 1:n)
+        ŷgold = reshape(ŷgold, 1, :)
+    else
+        ŷgold = NNlib.gather(logits, 1:n, y)
+        ŷgold = reshape(ŷgold, :, 1)
+    end
+    return -agg(ŷgold)
+end
+
 """
     logitcrossentropy(ŷ, y; dims = 1, agg = mean)

@@ -234,13 +253,18 @@ Return the cross entropy calculated by

     agg(-sum(y .* logsoftmax(ŷ; dims); dims))

-This is mathematically equivalent to `crossentropy(softmax(ŷ), y)`,
+This is mathematically equivalent to `crossentropy(softmax(ŷ; dims), y; dims)`,
 but is more numerically stable than using functions [`crossentropy`](@ref)
 and [softmax](@ref Softmax) separately.

+The target array `y` must have the same shape as the prediction array `ŷ`
+and represent one-hot encoded labels or probabilities.
+Alternatively, `y` can be a vector of integers containing the class labels.
+
 See also: [`binarycrossentropy`](@ref), [`logitbinarycrossentropy`](@ref), [`label_smoothing`](@ref).

-# Example
+# Examples
+
 ```jldoctest
 julia> y_label = Flux.onehotbatch(collect("abcabaa"), 'a':'c')
 3×7 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
@@ -263,7 +287,22 @@ julia> Flux.crossentropy(softmax(y_model), y_label)
 """
 function logitcrossentropy(ŷ, y; dims = 1, agg = mean)
     _check_sizes(ŷ, y)
-    agg(.-sum(y .* logsoftmax(ŷ; dims = dims); dims = dims))
+    agg(.-sum(y .* logsoftmax(ŷ; dims); dims))
+end
+
+function logitcrossentropy(ŷ::AbstractMatrix, y::AbstractVector{<:Integer}; dims = 1, agg = mean)
+    n = length(y)
+    dims ∈ (1, 2) || throw(ArgumentError("Only dims = 1 and 2 are supported at the moment. Pass a one-hot encoded vector y for generic dims."))
+    n == (dims == 1 ? size(ŷ, 2) : size(ŷ, 1)) || throw(ArgumentError("The length of y should be the same as the batch dimension of ŷ."))
+    logits = logsoftmax(ŷ; dims)
+    if dims == 1
+        ŷgold = NNlib.gather(logits, y, 1:n)
+        ŷgold = reshape(ŷgold, 1, :)
+    else
+        ŷgold = NNlib.gather(logits, 1:n, y)
+        ŷgold = reshape(ŷgold, :, 1)
+    end
+    return -agg(ŷgold)
 end

 """
@@ -324,6 +363,7 @@ Mathematically equivalent to
 See also: [`crossentropy`](@ref), [`logitcrossentropy`](@ref).

 # Examples
+
 ```jldoctest
 julia> y_bin = Bool[1,0,1];

diff --git a/test/losses.jl b/test/losses.jl
index 2ca697a657..11929bf9da 100644
--- a/test/losses.jl
+++ b/test/losses.jl
@@ -59,14 +59,14 @@ y = [123.0,456.0,789.0]
 end

 # Now onehot y's
-y = onehotbatch([1, 1, 0, 0], 0:1)
+y = Flux.onehotbatch([1, 1, 0, 0], 0:1)
 y_smoothed = label_smoothing(y, 0.1)
 ŷ = [.1 .9; .9 .1; .9 .1; .1 .9]'
 v = log(.1 / .9)
 logŷ = [v 0.0; 0.0 v; 0.0 v; v 0.0]'
 lossvalue = 1.203972804325936
 lossvalue_smoothed = 1.2039728043259348
-yl = onehotbatch([1], 0:1)
+yl = Flux.onehotbatch([1], 0:1)
 sf = 0.1
 yls = [sf (1-sf)]' # Effective y after label smoothing
 ylp = [0.9 0.1]'
@@ -75,7 +75,7 @@ logylp = [0.0 v]'
 # Construct `sim`ilar and `dis`imilar versions of the dataset so we can test effect of smoothing
 # smoothing should decrease loss on disimilar and increase the loss on similar, compared to
 # the loss without smoothing
-ya = onehotbatch([1, 1, 1, 0, 0], 0:1)
+ya = Flux.onehotbatch([1, 1, 1, 0, 0], 0:1)
 ya_smoothed = label_smoothing(ya, 2sf)
 y_same = Float32.(ya)
 y_sim = y_same .* (1-2*sf) .+ sf
@@ -92,12 +92,51 @@ y_dis[1,:], y_dis[2,:] = y_dis[2,:], y_dis[1,:]
   @test iszero(crossentropy(ya, ya, ϵ=0))
   @test crossentropy(y_sim, ya) < crossentropy(y_sim, ya_smoothed)
   @test crossentropy(y_dis, ya) > crossentropy(y_dis, ya_smoothed)
+
+
+  labels = rand(1:10, 20)
+  yc = Flux.onehotbatch(labels, 1:10)
+  ŷc = softmax(randn(Float32, 10, 20), dims=1)
+  l1 = crossentropy(ŷc, yc)
+  l2 = crossentropy(ŷc, labels)
+  @test l1 ≈ l2
+  l1 = crossentropy(ŷc, yc, agg=identity)
+  l2 = crossentropy(ŷc, labels, agg=identity)
+  @test size(l1) == size(l2) == (1, 20)
+  @test l1 ≈ l2
+
+  labels = rand(1:20, 10)
+  yd = Flux.onehotbatch(labels, 1:20)
+  ŷd = softmax(randn(Float32, 10, 20), dims=2)
+  l1 = crossentropy(ŷd, yd', dims=2, agg=identity)
+  l2 = crossentropy(ŷd, labels, dims=2, agg=identity)
+  @test size(l1) == size(l2) == (10, 1)
+  @test l1 ≈ l2
 end

 @testset "logitcrossentropy" begin
   @test logitcrossentropy(logŷ, y) ≈ lossvalue
   @test logitcrossentropy(logylp, yl) ≈ -sum(yl.*logsoftmax(logylp))
   @test logitcrossentropy(logylp, label_smoothing(yl, 2sf)) ≈ -sum(yls.*logsoftmax(logylp))
+
+  labels = rand(1:10, 20)
+  yc = Flux.onehotbatch(labels, 1:10)
+  ŷc = randn(Float32, 10, 20)
+  l1 = logitcrossentropy(ŷc, yc)
+  l2 = logitcrossentropy(ŷc, labels)
+  @test l1 ≈ l2
+  l1 = logitcrossentropy(ŷc, yc, agg=identity)
+  l2 = logitcrossentropy(ŷc, labels, agg=identity)
+  @test size(l1) == size(l2) == (1, 20)
+  @test l1 ≈ l2
+
+  labels = rand(1:20, 10)
+  yd = Flux.onehotbatch(labels, 1:20)
+  ŷd = randn(Float32, 10, 20)
+  l1 = logitcrossentropy(ŷd, yd', dims=2, agg=identity)
+  l2 = logitcrossentropy(ŷd, labels, dims=2, agg=identity)
+  @test size(l1) == size(l2) == (10, 1)
+  @test l1 ≈ l2
 end

 logŷ, y = randn(3), rand(3)
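Note: the following is a minimal usage sketch of the integer-label methods introduced in `src/losses/functions.jl` above, not part of the diff. The class count, batch size, and variable names are arbitrary, and the imports mirror the style of `perf.jl`; the asserted equivalence is the same one the new tests in `test/losses.jl` check.

```julia
using Flux, Flux.Losses

ŷ = randn(Float32, 3, 5)   # 3 classes along dims = 1, batch of 5 columns
labels = rand(1:3, 5)      # one integer class label per column of ŷ

l_int = logitcrossentropy(ŷ, labels)                         # new integer-label method
l_hot = logitcrossentropy(ŷ, Flux.onehotbatch(labels, 1:3))  # existing one-hot path
@assert l_int ≈ l_hot
```

Internally the new methods select the per-sample gold logit with `NNlib.gather(logits, y, 1:n)` (or `NNlib.gather(logits, 1:n, y)` for `dims = 2`) before negating and aggregating; the `with labels` rows in `perf.jl` benchmark this path on GPU against first building a one-hot target (`crossv2`).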