From 4dedf6ebe9c275ffa16ddc05ec1f34957590a5d0 Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Mon, 3 Jun 2024 10:15:19 +0200 Subject: [PATCH] fix hasattr check for task-awareness --- avalanche/training/templates/base_sgd.py | 4 +--- .../training/templates/problem_type/supervised_problem.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/avalanche/training/templates/base_sgd.py b/avalanche/training/templates/base_sgd.py index 8cca4b47e..b51046edf 100644 --- a/avalanche/training/templates/base_sgd.py +++ b/avalanche/training/templates/base_sgd.py @@ -456,9 +456,7 @@ def make_train_dataloader( other_dataloader_args["ffcv_args"] = kwargs["ffcv_args"] # use task-balanced dataloader for task-aware benchmarks - if hasattr(self.experience, "task_label") or hasattr( - self.experience, "task_labels" - ): + if hasattr(self.experience, "task_labels"): self.dataloader = TaskBalancedDataLoader( self.adapted_dataset, oversample_small_groups=True, diff --git a/avalanche/training/templates/problem_type/supervised_problem.py b/avalanche/training/templates/problem_type/supervised_problem.py index 1093abc0e..d7c8b7d54 100644 --- a/avalanche/training/templates/problem_type/supervised_problem.py +++ b/avalanche/training/templates/problem_type/supervised_problem.py @@ -45,9 +45,7 @@ def criterion(self): def forward(self): """Compute the model's output given the current mini-batch.""" # use task-aware forward only for task-aware benchmarks - if hasattr(self.experience, "task_labels") or hasattr( - self.experience, "task_label" - ): + if hasattr(self.experience, "task_labels"): return avalanche_forward(self.model, self.mb_x, self.mb_task_id) else: return self.model(self.mb_x)