diff --git a/community_tasks/swiss_legal_evals.py b/community_tasks/swiss_legal_evals.py
index 9414e1e1..6c1483dd 100644
--- a/community_tasks/swiss_legal_evals.py
+++ b/community_tasks/swiss_legal_evals.py
@@ -45,6 +45,7 @@
 from nltk.translate import meteor_score
 from packaging import version
 from sacrebleu import sentence_bleu, sentence_chrf, sentence_ter
+from tqdm import tqdm
 from transformers import AutoModelForSequenceClassification, AutoTokenizer

 from lighteval.metrics.imports.bert_scorer import BERTScorer
@@ -520,7 +521,10 @@ def compute(

         # Process in batches
         all_scores = []
-        for i in range(0, len(references), self.batch_size):
+        for i in tqdm(
+            range(0, len(references), self.batch_size),
+            desc=f"Processing batches of size {self.batch_size} with {self.metric_name}",
+        ):
             batch_refs = references[i : i + self.batch_size]
             batch_preds = predictions[i : i + self.batch_size]
             try:
@@ -589,15 +593,7 @@ def compute(
         predictions = [response[0].result for response in responses]
         sources = [kwargs["formatted_doc"].specific["source"] for kwargs["formatted_doc"] in formatted_docs]

-        def unpack(x):
-            if isinstance(x, str):
-                return x
-            elif isinstance(x, (list, tuple)):
-                return unpack(x[0])
-            else:
-                raise ValueError(f"Unknown type {type(x)} of prediction {x}")
-
-        data = [{"src": src, "mt": unpack(pred), "ref": gold} for src, pred, gold in zip(sources, predictions, golds)]
+        data = [{"src": src, "mt": pred, "ref": gold} for src, pred, gold in zip(sources, predictions, golds)]
         model_output = self.model.predict(
             data,
             batch_size=self.batch_size,