Skip to content

Commit

Permalink
add test for is_last in OCL
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonioCarta committed Oct 13, 2023
1 parent 2051c09 commit d65bf03
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
5 changes: 5 additions & 0 deletions avalanche/benchmarks/scenarios/online.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,18 @@ def __iter__(self) -> Generator[OnlineCLExperience, None, None]:
exp_idx = 0
while init_idx < len(exp_indices):
final_idx = init_idx + self.experience_size # Exclusive

if final_idx > len(exp_indices):
if self.drop_last:
break

final_idx = len(exp_indices)
is_last = True

# check is_last when drop_last=True
if self.drop_last and (final_idx + self.experience_size > len(exp_indices)):
is_last = True

sub_exp_subset = exp_dataset.subset(exp_indices[init_idx:final_idx])
exp = OnlineCLExperience(
dataset=sub_exp_subset,
Expand Down
37 changes: 36 additions & 1 deletion tests/benchmarks/scenarios/test_online_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_ocl_scenario_experience(self):
for s in ocl_benchmark.streams.values():
print(s.name)

def test_split_online_stream(self):
def test_split_online_stream_drop_last(self):
num_exp, num_classes = 5, 10
d1, d2 = dummy_classification_datasets(n_classes=num_classes)
bm = class_incremental_benchmark(
Expand All @@ -31,9 +31,44 @@ def test_split_online_stream(self):
online_train_stream = split_online_stream(
bm.train_stream, experience_size=10, drop_last=True
)
cnt_is_first = 0
cnt_is_last = 0
for exp in online_train_stream:
if exp.is_first_subexp:
cnt_is_first += 1
if exp.is_last_subexp:
cnt_is_last += 1
assert len(exp.dataset) == 10

assert exp.is_last_subexp # final exp should have is_last_subexp == True
assert cnt_is_last == len(bm.train_stream)
assert cnt_is_first == len(bm.train_stream)

def test_split_online_stream_not_drop(self):
num_exp, num_classes = 5, 10
d1, d2 = dummy_classification_datasets(n_classes=num_classes)
bm = class_incremental_benchmark(
{"train": d1, "test": d2}, num_experiences=num_exp
)
online_train_stream = split_online_stream(
bm.train_stream, experience_size=10, drop_last=False
)
cnt_is_first = 0
cnt_is_last = 0
for exp in online_train_stream:
if exp.is_first_subexp:
cnt_is_first += 1
if exp.is_last_subexp:
cnt_is_last += 1

if not exp.is_last_subexp:
assert len(exp.dataset) == 10
else:
assert len(exp.dataset) <= 10

assert exp.is_last_subexp # final exp should have is_last_subexp == True
assert cnt_is_last == len(bm.train_stream)
assert cnt_is_first == len(bm.train_stream)

if __name__ == "__main__":
unittest.main()

0 comments on commit d65bf03

Please sign in to comment.