diff --git a/avalanche/benchmarks/scenarios/online.py b/avalanche/benchmarks/scenarios/online.py index cb164e576..928f43b9f 100644 --- a/avalanche/benchmarks/scenarios/online.py +++ b/avalanche/benchmarks/scenarios/online.py @@ -183,6 +183,7 @@ def __iter__(self) -> Generator[OnlineCLExperience, None, None]: exp_idx = 0 while init_idx < len(exp_indices): final_idx = init_idx + self.experience_size # Exclusive + if final_idx > len(exp_indices): if self.drop_last: break @@ -190,6 +191,10 @@ def __iter__(self) -> Generator[OnlineCLExperience, None, None]: final_idx = len(exp_indices) is_last = True + # check is_last when drop_last=True + if self.drop_last and (final_idx + self.experience_size > len(exp_indices)): + is_last = True + sub_exp_subset = exp_dataset.subset(exp_indices[init_idx:final_idx]) exp = OnlineCLExperience( dataset=sub_exp_subset, diff --git a/tests/benchmarks/scenarios/test_online_scenario.py b/tests/benchmarks/scenarios/test_online_scenario.py index 9ce0d7e62..4b3322b06 100644 --- a/tests/benchmarks/scenarios/test_online_scenario.py +++ b/tests/benchmarks/scenarios/test_online_scenario.py @@ -22,7 +22,7 @@ def test_ocl_scenario_experience(self): for s in ocl_benchmark.streams.values(): print(s.name) - def test_split_online_stream(self): + def test_split_online_stream_drop_last(self): num_exp, num_classes = 5, 10 d1, d2 = dummy_classification_datasets(n_classes=num_classes) bm = class_incremental_benchmark( @@ -31,9 +31,44 @@ def test_split_online_stream(self): online_train_stream = split_online_stream( bm.train_stream, experience_size=10, drop_last=True ) + cnt_is_first = 0 + cnt_is_last = 0 for exp in online_train_stream: + if exp.is_first_subexp: + cnt_is_first += 1 + if exp.is_last_subexp: + cnt_is_last += 1 assert len(exp.dataset) == 10 + + assert exp.is_last_subexp # final exp should have is_last_subexp == True + assert cnt_is_last == len(bm.train_stream) + assert cnt_is_first == len(bm.train_stream) + def test_split_online_stream_not_drop(self): + num_exp, num_classes = 5, 10 + d1, d2 = dummy_classification_datasets(n_classes=num_classes) + bm = class_incremental_benchmark( + {"train": d1, "test": d2}, num_experiences=num_exp + ) + online_train_stream = split_online_stream( + bm.train_stream, experience_size=10, drop_last=False + ) + cnt_is_first = 0 + cnt_is_last = 0 + for exp in online_train_stream: + if exp.is_first_subexp: + cnt_is_first += 1 + if exp.is_last_subexp: + cnt_is_last += 1 + + if not exp.is_last_subexp: + assert len(exp.dataset) == 10 + else: + assert len(exp.dataset) <= 10 + + assert exp.is_last_subexp # final exp should have is_last_subexp == True + assert cnt_is_last == len(bm.train_stream) + assert cnt_is_first == len(bm.train_stream) if __name__ == "__main__": unittest.main()