diff --git a/vision/dcgan_mnist/dcgan_mnist.jl b/vision/dcgan_mnist/dcgan_mnist.jl index 248ecbd94..2548615fa 100644 --- a/vision/dcgan_mnist/dcgan_mnist.jl +++ b/vision/dcgan_mnist/dcgan_mnist.jl @@ -8,6 +8,7 @@ using Statistics using Parameters: @with_kw using Printf using Random +using Zygote: @nograd @with_kw struct HyperParams batch_size::Int = 128 @@ -38,29 +39,35 @@ end generator_loss(fake_output) = mean(logitbinarycrossentropy.(fake_output, 1f0)) -function train_discriminator!(gen, dscr, x, opt_dscr, hparams) - noise = randn!(similar(x, (hparams.latent_dim, hparams.batch_size))) - fake_input = gen(noise) +function train_discriminator!(dscr, fake, x, opt_dscr, hparams) ps = Flux.params(dscr) # Taking gradient loss, back = Flux.pullback(ps) do - discriminator_loss(dscr(x), dscr(fake_input)) + discriminator_loss(dscr(x), dscr(fake)) end grad = back(1f0) update!(opt_dscr, ps, grad) return loss end -function train_generator!(gen, dscr, x, opt_gen, hparams) +@nograd train_discriminator! + +function train_gen_dscr!(gen, dscr, x, opt_gen, opt_dscr, hparams) noise = randn!(similar(x, (hparams.latent_dim, hparams.batch_size))) - ps = Flux.params(gen) + ps_gen = Flux.params(gen) # Taking gradient - loss, back = Flux.pullback(ps) do - generator_loss(dscr(gen(noise))) + + loss_dscr = nothing + loss_gen, back_gen = Flux.pullback(ps_gen) do + fake = gen(noise) + loss_dscr = train_discriminator!(dscr, fake, x, opt_dscr, hparams) + generator_loss(dscr(fake)) end - grad = back(1f0) - update!(opt_gen, ps, grad) - return loss + + grad_gen = back_gen(1f0) + update!(opt_gen, ps_gen, grad_gen) + + return loss_gen, loss_dscr end function train(; kws...) @@ -109,8 +116,7 @@ function train(; kws...) @info "Epoch $ep" for x in data # Update discriminator and generator - loss_dscr = train_discriminator!(gen, dscr, x, opt_dscr, hparams) - loss_gen = train_generator!(gen, dscr, x, opt_gen, hparams) + loss_gen, loss_dscr = train_gen_dscr!(gen, dscr, x, opt_gen, opt_dscr, hparams) if train_steps % hparams.verbose_freq == 0 @info("Train step $(train_steps), Discriminator loss = $(loss_dscr), Generator loss = $(loss_gen)")