Skip to content

Commit

Permalink
fix hasattr check for task-awareness
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonioCarta committed Jun 3, 2024
1 parent 91841ac commit 4dedf6e
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 6 deletions.
4 changes: 1 addition & 3 deletions avalanche/training/templates/base_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,9 +456,7 @@ def make_train_dataloader(
other_dataloader_args["ffcv_args"] = kwargs["ffcv_args"]

# use task-balanced dataloader for task-aware benchmarks
if hasattr(self.experience, "task_label") or hasattr(
self.experience, "task_labels"
):
if hasattr(self.experience, "task_labels"):
self.dataloader = TaskBalancedDataLoader(
self.adapted_dataset,
oversample_small_groups=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ def criterion(self):
def forward(self):
"""Compute the model's output given the current mini-batch."""
# use task-aware forward only for task-aware benchmarks
if hasattr(self.experience, "task_labels") or hasattr(
self.experience, "task_label"
):
if hasattr(self.experience, "task_labels"):
return avalanche_forward(self.model, self.mb_x, self.mb_task_id)
else:
return self.model(self.mb_x)
Expand Down

0 comments on commit 4dedf6e

Please sign in to comment.