Changes in Dynamic Modules #1600

Merged
9 commits merged on Feb 28, 2024
86 changes: 72 additions & 14 deletions avalanche/models/dynamic_modules.py
@@ -12,12 +12,41 @@
to allow architectural modifications (multi-head classifiers, progressive
networks, ...).
"""
from typing import List, Optional

import torch
from torch.nn import Module
from typing import Optional

from avalanche.benchmarks.utils.flat_data import ConstantSequence
from avalanche.benchmarks.scenarios import CLExperience
from avalanche.benchmarks.utils.flat_data import ConstantSequence


def avalanche_model_adaptation(
module: Module,
experience: CLExperience,
_visited=None,
_initial_call: bool = True,
):
if _visited is None:
_visited = []

if module in _visited:
return

_visited.append(module)

if isinstance(module, DynamicModule):
if (not _initial_call) and (not module._auto_adapt):
# Some modules don't want to be auto-adapted
return
else:
module.adaptation(experience)

# Iterate over children
for name, submodule in module.named_children():
avalanche_model_adaptation(
submodule, experience, _visited=_visited, _initial_call=False
)
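
A hedged usage sketch of the new helper (not part of this diff); the backbone, feature sizes, and the commented-out adaptation call are illustrative assumptions:

# Illustrative only: adapt a backbone that embeds a DynamicModule head.
# avalanche_model_adaptation walks named_children() recursively, skips modules
# already in _visited, and leaves heads created with auto_adapt=False to the
# module in charge of them.
import torch.nn as nn

from avalanche.models import IncrementalClassifier
from avalanche.models.dynamic_modules import avalanche_model_adaptation

model = nn.Sequential(
    nn.Linear(32, 16),
    nn.ReLU(),
    IncrementalClassifier(in_features=16, initial_out_features=2),
)

# `experience` would be a CLExperience taken from a benchmark stream,
# e.g. benchmark.train_stream[0] (omitted here):
# avalanche_model_adaptation(model, experience)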


class DynamicModule(Module):
@@ -29,6 +58,22 @@ class DynamicModule(Module):
`model_adaptation`, which adapts the model given the current experience.
"""

def __init__(self, auto_adapt=True):
"""
        :param auto_adapt: If True, the module is adapted by the recursive
            adaptation loop; otherwise, it is adapted by a module in charge of it
            (e.g. an IncrementalClassifier inside a MultiHeadClassifier).
"""
super().__init__()
self._auto_adapt = auto_adapt

def adapt(self, experience):
"""
        Calls ``self.adaptation`` recursively across
        the hierarchy of module children.
"""
avalanche_model_adaptation(self, experience)

def adaptation(self, experience: CLExperience):
"""Adapt the module (freeze units, add units...) using the current
data. Optimizers must be updated after the model adaptation.
@@ -43,6 +88,10 @@ def adaptation(self, experience: CLExperience):
require the model's adaptation, such as the discovery of new
classes or tasks.

.. warning::
            This function only adapts the current module; to recursively adapt
            all submodules, use ``self.adapt()`` instead.

:param experience: the current experience.
:return:
"""
@@ -97,8 +146,8 @@ class MultiTaskModule(DynamicModule):
the output is computed in parallel for each task.
"""

def __init__(self):
super().__init__()
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.max_class_label = 0
self.known_train_tasks_labels = set()
""" Set of task labels encountered up to now. """
@@ -122,10 +171,7 @@ def adaptation(self, experience: CLExperience):
"""
curr_classes = experience.classes_in_this_experience
self.max_class_label = max(self.max_class_label, max(curr_classes) + 1)
if self.training:
self.train_adaptation(experience)
else:
self.eval_adaptation(experience)
super().adaptation(experience)

def eval_adaptation(self, experience: CLExperience):
pass
@@ -207,6 +253,7 @@ def __init__(
initial_out_features=2,
masking=True,
mask_value=-1000,
**kwargs,
):
"""
:param in_features: number of input features.
@@ -215,7 +262,7 @@ def __init__(
:param masking: whether unused units should be masked (default=True).
:param mask_value: the value used for masked units (default=-1000).
"""
super().__init__()
super().__init__(**kwargs)
self.masking = masking
self.mask_value = mask_value

@@ -224,7 +271,7 @@ def __init__(
self.register_buffer("active_units", au_init)

@torch.no_grad()
def adaptation(self, experience: CLExperience):
def train_adaptation(self, experience: CLExperience):
"""If `dataset` contains unseen classes the classifier is expanded.

:param experience: data from the current experience.
@@ -256,6 +303,9 @@ def adaptation(self, experience: CLExperience):
self.classifier.weight[:old_nclasses] = old_w
self.classifier.bias[:old_nclasses] = old_b

def eval_adaptation(self, experience):
self.train_adaptation(experience)
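
A hedged sketch of the expansion behaviour documented above; the adaptation call is commented out because it needs an experience taken from a benchmark stream:

from avalanche.models import IncrementalClassifier

clf = IncrementalClassifier(in_features=10, initial_out_features=2)
print(clf.classifier.out_features)    # 2
# After an experience whose classes_in_this_experience includes class index 4:
# clf.adapt(experience)
# print(clf.classifier.out_features)  # 5, previously trained rows are preserved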

def forward(self, x, **kwargs):
"""compute the output given the input `x`. This module does not use
the task label.
@@ -321,7 +371,10 @@ def __init__(
# masking in IncrementalClassifier is unaware of task labels
# so we do masking here instead.
first_head = IncrementalClassifier(
self.in_features, self.starting_out_features, masking=False
self.in_features,
self.starting_out_features,
masking=False,
auto_adapt=False,
)
self.classifiers["0"] = first_head
self.max_class_label = max(self.max_class_label, initial_out_features)
@@ -345,13 +398,12 @@ def task_masks(self):
res[tid] = getattr(self, f"active_units_T{tid}").to(torch.bool)
return res

def adaptation(self, experience: CLExperience):
def train_adaptation(self, experience: CLExperience):
"""If `dataset` contains new tasks, a new head is initialized.

:param experience: data from the current experience.
:return:
"""
super().adaptation(experience)
device = self._adaptation_device
curr_classes = experience.classes_in_this_experience
task_labels = experience.task_labels
@@ -364,7 +416,10 @@ def adaptation(self, experience: CLExperience):
# head adaptation
if tid not in self.classifiers: # create new head
new_head = IncrementalClassifier(
self.in_features, self.starting_out_features, masking=False
self.in_features,
self.starting_out_features,
masking=False,
auto_adapt=False,
).to(device)
self.classifiers[tid] = new_head

@@ -404,6 +459,9 @@ def adaptation(self, experience: CLExperience):
if self.training:
self._buffers[au_name][curr_classes] = 1

def eval_adaptation(self, experience):
self.train_adaptation(experience)
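
A hedged sketch of the per-task behaviour (not part of the diff); the adaptation call and the forward inputs are assumptions:

from avalanche.models import MultiHeadClassifier

mhc = MultiHeadClassifier(in_features=16, initial_out_features=2)
print(list(mhc.classifiers.keys()))   # ['0'] -- one head until new tasks arrive
# mhc.adapt(experience_of_task_1)     # would add head '1', built with auto_adapt=False
# y = mhc(x, task_labels)             # per-sample head selection at forward time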

def forward_single_task(self, x, task_label):
"""compute the output given the input `x`. This module uses the task
label to activate the correct head.
18 changes: 6 additions & 12 deletions avalanche/models/utils.py
@@ -3,9 +3,14 @@
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

from avalanche._annotations import deprecated
from avalanche.benchmarks.scenarios import CLExperience
from avalanche.benchmarks.utils import _make_taskaware_classification_dataset
from avalanche.models.dynamic_modules import DynamicModule, MultiTaskModule
from avalanche.models.dynamic_modules import (
DynamicModule,
MultiTaskModule,
avalanche_model_adaptation,
)


def is_multi_task_module(model: nn.Module) -> bool:
@@ -22,17 +27,6 @@ def avalanche_forward(model, x, task_labels):
return model(x)
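
For context, a hedged sketch of how avalanche_forward dispatches; the tensors and module below are illustrative:

import torch
import torch.nn as nn

from avalanche.models.utils import avalanche_forward

x = torch.randn(8, 16)
task_labels = torch.zeros(8, dtype=torch.long)

plain = nn.Linear(16, 4)
y = avalanche_forward(plain, x, task_labels)  # plain module: task labels are ignored
# A MultiTaskModule would instead be called as model(x, task_labels),
# routing each sample to the head matching its task label.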


def avalanche_model_adaptation(model: nn.Module, experience: CLExperience):
if isinstance(model, DistributedDataParallel):
raise RuntimeError(
"The model is wrapped in DistributedDataParallel. "
"Please unwrap it before calling this method."
)
for module in model.modules():
if isinstance(module, DynamicModule):
module.adaptation(experience)


class FeatureExtractorBackbone(nn.Module):
"""
This PyTorch module allows us to extract features from a backbone network
15 changes: 13 additions & 2 deletions tests/models/test_models.py
@@ -12,6 +12,7 @@
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.nn.functional as F

from avalanche.logging import TextLogger
from avalanche.models import (
@@ -416,8 +417,8 @@ def test_multihead_head_creation(self):
w.data_ptr() for group in optimizer.param_groups for w in group["params"]
]

assert w_ptr not in opt_params_ptrs # head0 has been updated
assert b_ptr not in opt_params_ptrs # head0 has been updated
# assert w_ptr not in opt_params_ptrs # head0 has NOT been updated
# assert b_ptr not in opt_params_ptrs # head0 has NOT been updated
assert w_ptr_t0 in opt_params_ptrs
assert b_ptr_t0 in opt_params_ptrs
assert w_ptr_new in opt_params_ptrs
@@ -455,13 +456,23 @@ def test_multihead_head_selection(self):
for x, y, t in DataLoader(benchmark.train_stream[0].dataset):
y_mh = model(x, t)
y_t = model_t0(x)

            # Pad y_t with zeros up to the size of y_mh, because y_mh has the
            # max output dim across all heads
y_t = F.pad(y_t, (0, y_mh.size(1) - y_t.size(1)))

assert ((y_mh - y_t) ** 2).sum() < 1.0e-7
break

# check head task4
for x, y, t in DataLoader(benchmark.train_stream[4].dataset):
y_mh = model(x, t)
y_t = model_t4(x)

            # Pad y_t with zeros up to the size of y_mh, because y_mh has the
            # max output dim across all heads
y_t = F.pad(y_t, (0, y_mh.size(1) - y_t.size(1)))

assert ((y_mh - y_t) ** 2).sum() < 1.0e-7
break
