-
-
Notifications
You must be signed in to change notification settings - Fork 295
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generative Replay with weighted loss for replayed data #1596
Conversation
increasing_replay_size: bool = False, | ||
is_weighted_replay: bool = False, | ||
weight_replay_loss_factor: float = 1.0, | ||
weight_replay_loss: float = 0.0001, | ||
): | ||
""" | ||
Init. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you add the documentation for the arguments?
def before_backward(self, strategy: Template, *args, **kwargs) -> Any: | ||
super().before_backward(strategy, *args, **kwargs) | ||
""" | ||
Generate replay data and calculate the loss on the replay data. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
docstring must be before super() call
replay_output = torch.zeros(replay_data.shape[0]) | ||
|
||
# make copy of mbatch | ||
mbatch = deepcopy(strategy.mbatch) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you need a deepcopy here?
@@ -0,0 +1,97 @@ | |||
################################################################################ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please modify one of the existing examples instead of creating a new one
Thank you for your contribution. I left some comments. It's not clear to me why you multiply FYI, you can fix the lint issues with by running |
Thank you for reviewing, I am trying to make it better.
According to this section of API Docu: |
Pull Request Test Coverage Report for Build 8317913660Details
💛 - Coveralls |
Thanks @gogamid , this looks good! |
First attempt to add support for Generative Replay with weighted loss for replayed data.
Added three arguments: