Optimising encoder twice during CURL? #20
Hi @wassname
Hi, it seems that the encoder in the actor is never updated, either by a loss or by a soft (EMA) update, except at initialisation.
Only the encoder in critic/critic_target is updated, via `critic_loss` and the CPC loss. Is there a reason for not updating the encoder in the actor?
Hi @IpadLi, I wondered about this a while back and emailed @MishaLaskin about it.
@MishaLaskin's reply:
Hi @tejassp2002, thanks a lot.
Hi, could we merge the `update_critic` and `update_cpc` functions by adding `critic_loss` and `cpc_loss` together?
`self.cpc_optimizer = torch.optim.Adam([self.CURL.W], lr=encoder_lr)`
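A minimal sketch of the single-update variant asked about above, assuming toy linear stand-ins for the encoder and critic head, and a detached forward pass in place of CURL's momentum encoder (all names here are hypothetical, not the repo's classes). Each tensor is listed exactly once in one optimizer, with `W` in its own parameter group. Note that one Adam step on the summed loss is not numerically identical to two separate Adam steps, because Adam rescales each step by its own running gradient statistics, so learning rates may need retuning.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy stand-ins for the CURL pieces (not the repo's classes).
encoder = nn.Linear(8, 4)            # shared encoder
critic_head = nn.Linear(4, 1)        # Q-value head on top of the encoder
W = nn.Parameter(torch.rand(4, 4))   # bilinear similarity matrix for the contrastive loss

# A single optimizer in which every tensor is listed exactly once.
optimizer = torch.optim.Adam(
    [
        {"params": list(encoder.parameters()) + list(critic_head.parameters())},
        {"params": [W]},
    ],
    lr=1e-3,
)


def combined_step(obs, obs_pos, target_q):
    """One update on critic_loss + cpc_loss; returns both loss values."""
    z_anchor = encoder(obs)
    with torch.no_grad():
        # Assumption: a detached pass stands in for CURL's momentum (target) encoder.
        z_pos = encoder(obs_pos)

    critic_loss = F.mse_loss(critic_head(z_anchor), target_q)

    # Bilinear logits z_a^T W z_pos; the positive for row i sits in column i.
    logits = z_anchor @ W @ z_pos.T
    logits = logits - logits.max(dim=1, keepdim=True).values  # numerical stability
    labels = torch.arange(logits.shape[0])
    cpc_loss = F.cross_entropy(logits, labels)

    loss = critic_loss + cpc_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return critic_loss.item(), cpc_loss.item()
```

If the two losses need different learning rates, each parameter group can carry its own `lr`, which keeps the tensors disjoint instead of adding a second optimizer over overlapping parameters.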
The SAC+AE paper (https://arxiv.org/pdf/1910.01741.pdf) suggests updating the encoder with the critic's gradient only (not the actor's). Since this repo is based on the SAC+AE implementation (as stated in the README), I think CURL simply follows that choice.
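One way to keep the actor's gradient out of the encoder, sketched here with hypothetical `Encoder`/`Actor` classes and a linear `trunk` standing in for the convolutional layers, is a `detach` flag that cuts the autograd graph after the trunk. The actor loss still trains the layers after the cut, while the trunk is trained only by losses that call the encoder without detaching (for example the critic loss).

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Hypothetical encoder: `trunk` stands in for the shared conv layers."""

    def __init__(self, in_dim: int = 8, hidden: int = 16, feat_dim: int = 4):
        super().__init__()
        self.trunk = nn.Linear(in_dim, hidden)
        self.fc = nn.Linear(hidden, feat_dim)

    def forward(self, obs: torch.Tensor, detach: bool = False) -> torch.Tensor:
        h = torch.relu(self.trunk(obs))
        if detach:
            # Cut the graph here: no gradient computed from h reaches the trunk.
            h = h.detach()
        return self.fc(h)


class Actor(nn.Module):
    """Hypothetical actor that reuses an encoder and adds a policy head."""

    def __init__(self, encoder: Encoder, feat_dim: int = 4, act_dim: int = 2):
        super().__init__()
        self.encoder = encoder
        self.head = nn.Linear(feat_dim, act_dim)

    def forward(self, obs: torch.Tensor, detach_encoder: bool = False) -> torch.Tensor:
        return torch.tanh(self.head(self.encoder(obs, detach=detach_encoder)))
```

In this sketch the actor update would call `actor(obs, detach_encoder=True)`, while the critic calls `encoder(obs)` with the default `detach=False`, so only the critic's loss reaches the trunk.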
Hi, thanks for posting the author's reply!
Update: sorry, the `tie_weight` function actually makes the actor encoder and the critic encoder share the same weights.
Hello! Does that mean the actor encoder's weights are still the same as the critic encoder's after the critic encoder is updated?
Yes, the actor and critic indeed share the same encoder.
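Sharing by reference is what makes the actor's encoder follow the critic's. A minimal sketch of a weight-tying helper, assuming it works by assigning the source layer's `Parameter` objects to the target layer (the helper name and the demo layers are illustrative; the repo's own function may differ in detail): after tying, both layers hold the same tensor, so an optimizer step on the critic's layer is immediately visible in the actor's.

```python
import torch
import torch.nn as nn


def tie_weights(src: nn.Module, trg: nn.Module) -> None:
    """Make trg reuse src's weight and bias Parameters (sharing, not copying)."""
    assert type(src) is type(trg)
    trg.weight = src.weight
    trg.bias = src.bias


torch.manual_seed(0)
critic_conv = nn.Conv2d(3, 4, kernel_size=3)
actor_conv = nn.Conv2d(3, 4, kernel_size=3)
tie_weights(src=critic_conv, trg=actor_conv)

# Update only the critic's layer.
before = critic_conv.weight.detach().clone()
opt = torch.optim.SGD(critic_conv.parameters(), lr=0.1)
critic_conv(torch.randn(2, 3, 8, 8)).pow(2).mean().backward()
opt.step()
```

Because there is only one tensor, there is no second copy that could drift out of sync: `actor_conv.weight` is literally the same object as `critic_conv.weight` after tying.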
Thanks for sharing your code; it's great to be able to go through the implementation.
Maybe I'm misunderstanding this, but it seems that if you intend `self.cpc_optimizer` to optimise only W, then
should be
or
The code I'm referring to is here, and the torch docs for parameter are here. I'm comparing it to section 4.7 of your paper.
As it stands, it seems that the encoder is optimised twice: once in `encoder_optimizer` and again in `cpc_optimizer`. Or am I missing something?
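The double update described above can be reproduced in isolation. In this sketch (the module and function names are hypothetical, and the loss is a toy bilinear score rather than CURL's InfoNCE), a module that stores the encoder as a submodule exposes the encoder's tensors through `parameters()`, so an optimizer built from it overlaps with `encoder_optimizer`. PyTorch raises a `ValueError` when a tensor appears in two parameter groups of the same optimizer, but nothing checks across two separate optimizers, so both step the encoder from the same gradient.

```python
import torch
import torch.nn as nn


class ContrastiveModule(nn.Module):
    """Hypothetical CURL-like module that holds the encoder and the bilinear matrix W."""

    def __init__(self, encoder: nn.Module, feat_dim: int):
        super().__init__()
        self.encoder = encoder  # registered as a submodule
        self.W = nn.Parameter(torch.rand(feat_dim, feat_dim))


def max_encoder_change(restrict_to_w: bool, lr: float = 1e-3) -> float:
    """One backward pass, then step both optimizers; return the largest change in encoder.weight."""
    torch.manual_seed(0)
    encoder = nn.Linear(8, 4)
    contrastive = ContrastiveModule(encoder, feat_dim=4)

    encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=lr)
    if restrict_to_w:
        cpc_optimizer = torch.optim.Adam([contrastive.W], lr=lr)
    else:
        # parameters() recurses into submodules, so this also owns encoder.weight and encoder.bias.
        cpc_optimizer = torch.optim.Adam(contrastive.parameters(), lr=lr)

    before = encoder.weight.detach().clone()
    z = encoder(torch.randn(16, 8))
    loss = (z @ contrastive.W @ z.T).mean()  # toy bilinear loss, not CURL's InfoNCE

    encoder_optimizer.zero_grad()
    cpc_optimizer.zero_grad()
    loss.backward()
    encoder_optimizer.step()
    cpc_optimizer.step()  # a second update of the encoder when the parameter sets overlap
    return (encoder.weight.detach() - before).abs().max().item()
```

On Adam's first step each per-element update has magnitude at most about `lr`, so with overlapping optimizers `max_encoder_change(False)` comes out close to `2 * lr`, while `max_encoder_change(True)` stays at or below `lr`.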