From 024654cb37b77ee8016b23927508a8896c209c00 Mon Sep 17 00:00:00 2001
From: "clementine@huggingface.co"
Date: Fri, 1 Mar 2024 14:56:15 +0000
Subject: [PATCH] wip on rolling management

---
 src/lighteval/models/base_model.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/src/lighteval/models/base_model.py b/src/lighteval/models/base_model.py
index 8ebe90de4..7685cc1b5 100644
--- a/src/lighteval/models/base_model.py
+++ b/src/lighteval/models/base_model.py
@@ -514,6 +514,7 @@ def loglikelihood_rolling(
             requests,
             override_bs=override_bs,
             return_bool_score=False,
+            rolling=True,
         )
 
         return results
@@ -522,6 +523,7 @@ def _loglikelihood_tokens(
         requests: list[LoglikelihoodRequest],
         override_bs: int = -1,
         return_bool_score: bool = True,
+        rolling: bool = False,
     ) -> list[LoglikelihoodReturn]:
         dataset = LoglikelihoodDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
         starting_batch_size = STARTING_BATCH_SIZE
@@ -530,9 +532,12 @@ def _loglikelihood_tokens(
         for split_start, split_end in tqdm(dataset.splits_start_end_iterator()):
             context_enc = dataset[0].tokenized_context
             continuation_enc = dataset[0].tokenized_continuation
-            max_context_continuation_size_allowed = len(
-                (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
-            )
+            if rolling:  # we take all the sequence in rolling mode
+                max_context_continuation_size_allowed = len(context_enc + continuation_enc)
+            else:  # in normal mode, we left cut the context if needed
+                max_context_continuation_size_allowed = len(
+                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
+                )
 
             batch_size = self._get_batch_size(
                 override_bs=override_bs,