Skip to content

Commit

Permalink
Rolling management (#78)
Browse files Browse the repository at this point in the history
  • Loading branch information
clefourrier authored Mar 8, 2024
1 parent 988959c commit bca2b1d
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/lighteval/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ def loglikelihood_rolling(
requests,
override_bs=override_bs,
return_bool_score=False,
rolling=True,
)
return results

Expand All @@ -568,6 +569,7 @@ def _loglikelihood_tokens(
requests: list[LoglikelihoodRequest],
override_bs: int = -1,
return_bool_score: bool = True,
rolling: bool = False,
) -> list[LoglikelihoodReturn]:
dataset = LoglikelihoodDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
starting_batch_size = STARTING_BATCH_SIZE
Expand All @@ -576,9 +578,12 @@ def _loglikelihood_tokens(
for split_start, split_end in tqdm(dataset.splits_start_end_iterator()):
context_enc = dataset[0].tokenized_context
continuation_enc = dataset[0].tokenized_continuation
max_context_continuation_size_allowed = len(
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
)
if rolling: # we take all the sequence in rolling mode
max_context_continuation_size_allowed = len(context_enc + continuation_enc)
else: # in normal mode, we left cut the context if needed
max_context_continuation_size_allowed = len(
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
)

batch_size = self._get_batch_size(
override_bs=override_bs,
Expand Down

0 comments on commit bca2b1d

Please sign in to comment.