Skip to content

Commit

Permalink
FIX tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonioCarta committed Oct 13, 2023
1 parent ab2dfb9 commit 2827d75
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion avalanche/benchmarks/scenarios/dataset_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def split_validation_class_balanced(
# shuffle exp_indices
exp_indices_t = torch.as_tensor(exp_indices)[torch.randperm(len(exp_indices))]
# shuffle the targets as well
- exp_targets = targets_as_tensor[exp_indices]
+ exp_targets = targets_as_tensor[exp_indices_t]

train_exp_indices: list[int] = []
valid_exp_indices: list[int] = []
Expand Down
1 change: 1 addition & 0 deletions tests/benchmarks/scenarios/test_dataset_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def test_split_dataset_class_balanced(self):
for cid in exp.classes_in_this_experience:
train_cnt = (torch.as_tensor(train_d.targets) == cid).sum()
valid_cnt = (torch.as_tensor(valid_d.targets) == cid).sum()
+ # print(train_cnt, valid_cnt)
assert abs(train_cnt - valid_cnt) <= 1

ratio = 0.123
Expand Down
2 changes: 1 addition & 1 deletion tests/benchmarks/utils/test_avalanche_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_subset_subset_merge(self):
self.assertTrue(torch.equal(x_curr, x_true))

t_curr = torch.tensor(
- [curr_dataset.task_labels[idx] for idx in range(d_sz)]
+ [curr_dataset.targets_task_labels[idx] for idx in range(d_sz)]
)
t_true = torch.stack([dadata[idx] for idx in true_indices], dim=0)
self.assertTrue(torch.equal(t_curr, t_true))
Expand Down
4 changes: 2 additions & 2 deletions tests/training/test_online_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_naive(self):
benchmark_streams = benchmark.streams.values()

# With task boundaries
- model, optimizer, criterion, my_nc_benchmark = self.init_sit()
+ model, optimizer, criterion, _ = self.init_sit()
strategy = OnlineNaive(
model,
optimizer,
Expand Down Expand Up @@ -113,7 +113,7 @@ def run_strategy_no_boundaries(self, benchmark, cl_strategy):
cl_strategy.evaluator.loggers = [TextLogger(sys.stdout)]
results = []

- cl_strategy.train(benchmark.train_stream, num_workers=0)
+ cl_strategy.train(benchmark.train_online_stream, num_workers=0)
print("Training completed")

assert cl_strategy.clock.train_exp_counter > 0
Expand Down

0 comments on commit 2827d75

Please sign in to comment.