Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generative Replay with weighted loss for replayed data #1596

Merged
merged 7 commits into from
May 31, 2024

Conversation

gogamid
Copy link
Contributor

@gogamid gogamid commented Feb 13, 2024

First attempt to add support for Generative Replay with weighted loss for replayed data.
Added three arguments:

is_weighted_replay=True,
weight_replay_loss_factor=2.0,
weight_replay_loss=0.001,

increasing_replay_size: bool = False,
is_weighted_replay: bool = False,
weight_replay_loss_factor: float = 1.0,
weight_replay_loss: float = 0.0001,
):
"""
Init.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add the documentation for the arguments?

def before_backward(self, strategy: Template, *args, **kwargs) -> Any:
super().before_backward(strategy, *args, **kwargs)
"""
Generate replay data and calculate the loss on the replay data.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring must be before super() call

replay_output = torch.zeros(replay_data.shape[0])

# make copy of mbatch
mbatch = deepcopy(strategy.mbatch)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you need a deepcopy here?

@@ -0,0 +1,97 @@
################################################################################
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please modify one of the existing examples instead of creating a new one

@AntonioCarta
Copy link
Collaborator

Thank you for your contribution. I left some comments.

It's not clear to me why you multiply weight_replay_loss_factor after each iteration. Is this really what you want?

FYI, you can fix the lint issues with by running black ..

@gogamid
Copy link
Contributor Author

gogamid commented Mar 17, 2024

Thank you for reviewing, I am trying to make it better.

It's not clear to me why you multiply weight_replay_loss_factor after each iteration. Is this really what you want?

According to this section of API Docu:
"Another way to implempent the algorithm is by weighting the loss function and give more importance to the replayed data as the number of experiences increases."
Therefore, I increase the weight of loss by multiplying it with a factor to give more importance to the replay data as the number of experiences increases. Is there better suggestion that you can make?

@gogamid gogamid closed this Mar 17, 2024
@gogamid gogamid reopened this Mar 17, 2024
@coveralls
Copy link

coveralls commented Mar 17, 2024

Pull Request Test Coverage Report for Build 8317913660

Details

  • 3 of 29 (10.34%) changed or added relevant lines in 1 file are covered.
  • 1 unchanged line in 1 file lost coverage.
  • Overall coverage decreased (-0.05%) to 51.761%

Changes Missing Coverage Covered Lines Changed/Added Lines %
avalanche/training/plugins/generative_replay.py 3 29 10.34%
Files with Coverage Reduction New Missed Lines %
avalanche/training/plugins/generative_replay.py 1 20.73%
Totals Coverage Status
Change from base Build 8098020118: -0.05%
Covered Lines: 14757
Relevant Lines: 28510

💛 - Coveralls

@AndreaCossu
Copy link
Collaborator

Thanks @gogamid , this looks good!

@AndreaCossu AndreaCossu merged commit bbd0778 into ContinualAI:master May 31, 2024
19 of 21 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants