remove multitask tests
AlbinSou committed Dec 11, 2023
1 parent 59ad9c1 commit 235bc26
Showing 2 changed files with 38 additions and 88 deletions.
73 changes: 38 additions & 35 deletions avalanche/training/plugins/feature_replay.py
@@ -8,15 +8,52 @@
import tqdm
from torch import Tensor, nn

from avalanche.models.utils import avalanche_forward
from avalanche.training.plugins import SupervisedPlugin
from avalanche.training.storage_policy import ClassBalancedBuffer
from avalanche.training.utils import _at_task_boundary, cycle
from avalanche.models.utils import avalanche_forward


class FeatureExtractorModel(nn.Module):
"""
Feature extractor that additionally stores the features
"""

def __init__(self, feature_extractor, train_classifier):
super().__init__()
self.feature_extractor = feature_extractor
self.train_classifier = train_classifier
self.features = None

def forward(self, x):
self.features = self.feature_extractor(x)
x = self.train_classifier(self.features)
return x
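
A minimal usage sketch, assuming a torchvision resnet18 whose linear head is split off and re-attached through the wrapper (the backbone and head below are illustrative, not part of this commit):

import torch
import torch.nn as nn
from torchvision.models import resnet18

backbone = resnet18(num_classes=10)
head = backbone.fc                   # keep the original linear head
backbone.fc = nn.Identity()          # backbone now returns raw features
model = FeatureExtractorModel(backbone, head)

out = model(torch.randn(4, 3, 224, 224))
assert model.features is not None    # forward() cached the features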


class FeatureDataset(torch.utils.data.Dataset):
"""
Wrapper around a features tensor dataset.
Required for compatibility with the storage policy.
"""

def __init__(self, data, targets):
self.data = data
self.targets = targets

def __len__(self):
return len(self.data)

def __getitem__(self, index):
return self.data[index], self.targets[index]
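
A short sketch of filling such a dataset, assuming features were precomputed into plain tensors (the tensors below are random stand-ins):

import torch

feats = torch.randn(100, 512)             # hypothetical precomputed features
labels = torch.randint(0, 10, (100,))
feature_dataset = FeatureDataset(feats, labels)
x, y = feature_dataset[0]                 # behaves like a map-style dataset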


class FeatureDistillationPlugin(SupervisedPlugin):
def __init__(self, alpha=1):
"""
Adds a distillation loss term on the features of the model,
maximizing the cosine similarity between current and old features.
:param alpha: distillation hyperparameter. It can be either a float
or a list containing one alpha per experience.
"""
@@ -51,40 +88,6 @@ def after_training_exp(self, strategy, **kwargs):
self.prev_model = copy.deepcopy(strategy.model)
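
The loss term described in the docstring could look roughly like this sketch; feature_distillation_loss is a hypothetical helper, not the plugin's actual code:

import torch
import torch.nn.functional as F

def feature_distillation_loss(curr_feats: torch.Tensor,
                              old_feats: torch.Tensor,
                              alpha: float = 1.0) -> torch.Tensor:
    # Penalize low cosine similarity between current and frozen old features.
    cos = F.cosine_similarity(curr_feats, old_feats.detach(), dim=1)
    return alpha * (1.0 - cos).mean()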


class FeatureExtractorModel(nn.Module):
"""
Feature extractor that additionally stores the features
"""

def __init__(self, feature_extractor, train_classifier):
super().__init__()
self.feature_extractor = feature_extractor
self.train_classifier = train_classifier
self.features = None

def forward(self, x):
self.features = self.feature_extractor(x)
x = self.train_classifier(self.features)
return x


class FeatureDataset(torch.utils.data.Dataset):
"""
Wrapper around a features tensor dataset.
Required for compatibility with the storage policy.
"""

def __init__(self, data, targets):
self.data = data
self.targets = targets

def __len__(self):
return len(self.data)

def __getitem__(self, index):
return self.data[index], self.targets[index]


class FeatureReplayPlugin(SupervisedPlugin):
"""
Store some features and use them for replay
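
The replay idea sketched below is a rough reading of the docstring, assuming a FeatureExtractorModel as above and a buffer batch of stored (feature, label) pairs; replay_step is a hypothetical helper, and the plugin's real logic sits in the part of the diff collapsed here:

import torch.nn.functional as F

def replay_step(model, x, y, buffer_batch):
    # Forward current data; FeatureExtractorModel caches its features.
    out = model(x)
    loss = F.cross_entropy(out, y)
    # Replay stored features through the classifier head only.
    old_feats, old_y = buffer_batch
    replay_out = model.train_classifier(old_feats)
    return loss + F.cross_entropy(replay_out, old_y)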
53 changes: 0 additions & 53 deletions tests/training/test_strategies.py
@@ -1088,32 +1088,6 @@ def test_feature_distillation(self):
)
run_strategy(benchmark, strategy)

# Multitask

#model, optimizer, criterion, benchmark = self.init_scenario(multi_task=True)

## Modify model to make it compatible
#last_fc_name = "classifier"
#old_layer = getattr(model, last_fc_name)
#setattr(model, last_fc_name, torch.nn.Identity())
#model = FeatureExtractorModel(model, old_layer)

#feature_distillation = FeatureDistillationPlugin(alpha=10)

#plugins = [feature_distillation]

#strategy = Naive(
# model,
# optimizer,
# criterion,
# device=self.device,
# train_mb_size=10,
# eval_mb_size=50,
# train_epochs=2,
# plugins=plugins,
#)
#run_strategy(benchmark, strategy)

def test_feature_replay(self):
# SIT scenario
model, optimizer, criterion, benchmark = self.init_scenario(multi_task=False)
@@ -1140,33 +1114,6 @@
)
run_strategy(benchmark, strategy)

# Multitask

#model, optimizer, criterion, benchmark = self.init_scenario(multi_task=True)

## Modify model to make it compatible
#last_fc_name = "classifier"
#old_layer = getattr(model, last_fc_name)
#setattr(model, last_fc_name, torch.nn.Identity())
#model = FeatureExtractorModel(model, old_layer)

#feature_replay = FeatureReplayPlugin(mem_size=100)

#plugins = [feature_replay]

#strategy = Naive(
# model,
# optimizer,
# criterion,
# device=self.device,
# train_mb_size=10,
# eval_mb_size=50,
# train_epochs=2,
# plugins=plugins,
#)
#run_strategy(benchmark, strategy)


def load_benchmark(
self,
use_task_labels=False,
