Some minor fixes in batch_observation and utils #1513

Merged (4 commits, Oct 10, 2023)
avalanche/training/templates/observation_type/batch_observation.py

@@ -8,7 +8,7 @@
 from avalanche.models.utils import avalanche_model_adaptation
 from avalanche.training.templates.strategy_mixin_protocol import SGDStrategyProtocol
 from avalanche.models.dynamic_optimizers import reset_optimizer, update_optimizer
-from avalanche.training.utils import at_task_boundary
+from avalanche.training.utils import _at_task_boundary


 class BatchObservation(SGDStrategyProtocol):
@@ -73,10 +73,7 @@ def check_model_and_optimizer(self, reset_optimizer_state=False, **kwargs):
         if self.optimized_param_id is None:
             self.make_optimizer(reset_optimizer_state=True, **kwargs)

-        if at_task_boundary(self.experience):
-            self.model = self.model_adaptation()
-            self.make_optimizer(reset_optimizer_state=reset_optimizer_state, **kwargs)
-        else:
+        if _at_task_boundary(self.experience, before=True):
             self.model = self.model_adaptation()
             self.make_optimizer(reset_optimizer_state=reset_optimizer_state, **kwargs)
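Editor's note: per the hunk above, both branches of the old conditional ran the same adaptation code, so the boundary check had no effect; after this change, model adaptation and the optimizer rebuild run only when _at_task_boundary(self.experience, before=True) reports the start of a task. A usage sketch follows the helper's own hunk below.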
avalanche/training/utils.py (14 additions & 9 deletions)
@@ -15,18 +15,18 @@

 """
 from collections import defaultdict
-from typing import Dict, NamedTuple, List, Optional, Tuple, Callable, Union
+from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union

 import torch
 from torch import Tensor
-from torch.nn import Module, Linear
-from torch.utils.data import Dataset, DataLoader
+from torch.nn import Linear, Module
+from torch.utils.data import DataLoader, Dataset

-from avalanche.models.batch_renorm import BatchRenorm2D
 from avalanche.benchmarks import OnlineCLExperience
+from avalanche.models.batch_renorm import BatchRenorm2D


-def at_task_boundary(training_experience) -> bool:
+def _at_task_boundary(training_experience, before=True) -> bool:
     """
     Given a training experience,
     returns true if the experience is at the task boundary
@@ -41,11 +41,17 @@ def at_task_boundary(training_experience) -> bool:

     - If the experience is not an online experience, returns True

+    :param before: If used in before_training_exp,
+        set to True, otherwise set
+        to False
+
     """

     if isinstance(training_experience, OnlineCLExperience):
         if training_experience.access_task_boundaries:
-            if training_experience.is_first_subexp:
+            if before and training_experience.is_first_subexp:
                 return True
+            elif (not before) and training_experience.is_last_subexp:
+                return True
         else:
             return True
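For orientation, here is a minimal sketch (an editor's addition, not code from the PR) of how the two modes of the renamed helper pair up inside a strategy. The method names follow the before_training_exp/after_training_exp convention mentioned in the docstring; the class name and the end-of-task body are hypothetical placeholders.

from avalanche.training.utils import _at_task_boundary

class BoundaryAwareStrategy:
    def before_training_exp(self, **kwargs):
        # True on the first sub-experience of a task
        # (and for any non-online experience).
        if _at_task_boundary(self.experience, before=True):
            self.model = self.model_adaptation()

    def after_training_exp(self, **kwargs):
        # True on the last sub-experience of a task
        # (and for any non-online experience).
        if _at_task_boundary(self.experience, before=False):
            pass  # end-of-task housekeeping would go here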
@@ -222,7 +228,7 @@ def replace_bn_with_brn(
 ):
     for attr_str in dir(m):
         target_attr = getattr(m, attr_str)
-        if type(target_attr) == torch.nn.BatchNorm2d:
+        if isinstance(target_attr, torch.nn.BatchNorm2d):
             # print('replaced: ', name, attr_str)
             setattr(
                 m,
@@ -253,7 +259,7 @@ def change_brn_pars(
 ):
     for attr_str in dir(m):
         target_attr = getattr(m, attr_str)
-        if type(target_attr) == BatchRenorm2D:
+        if isinstance(target_attr, BatchRenorm2D):
             target_attr.momentum = torch.tensor((momentum), requires_grad=False)
             target_attr.r_max = torch.tensor(r_max, requires_grad=False)
             target_attr.d_max = torch.tensor(d_max, requires_grad=False)
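The isinstance changes in these two hunks are more than style. A type() equality test matches only the exact class, so user-defined subclasses of a layer would be skipped silently, while isinstance also accepts them. A short editor's sketch (not from the PR) of the difference:

import torch

class MyBatchNorm2d(torch.nn.BatchNorm2d):
    pass  # a user-defined variant, e.g. with extra logging

bn = MyBatchNorm2d(16)
print(type(bn) == torch.nn.BatchNorm2d)      # False: exact type only
print(isinstance(bn, torch.nn.BatchNorm2d))  # True: subclasses also match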
@@ -481,5 +487,4 @@ def __str__(self):
     "examples_per_class",
     "ParamData",
     "cycle",
-    "at_task_boundary",
 ]
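Editor's note: dropping the entry from __all__ follows from the rename; the leading underscore marks _at_task_boundary as a private helper, so it is intentionally removed from the module's public interface.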