diff --git a/community_tasks/arabic_evals.py b/community_tasks/arabic_evals.py
index 86ab69e2..a68abbe6 100644
--- a/community_tasks/arabic_evals.py
+++ b/community_tasks/arabic_evals.py
@@ -28,8 +28,11 @@
 """
 import random
 import re
+from typing import Any, Dict, List, Optional, Union

-from lighteval.metrics.metrics import Metrics
+from lighteval.metrics.llm_as_judge import JudgeLM
+from lighteval.metrics.metrics import Metric, MetricCategory, Metrics
+from lighteval.metrics.utils.metric_utils import MetricUseCase
 from lighteval.tasks.default_prompts import LETTER_INDICES
 from lighteval.tasks.lighteval_task import LightevalTaskConfig
 from lighteval.tasks.requests import Doc
@@ -832,6 +835,215 @@ def __init__(
 ]


+class JudgeMetricWrapper(Metric):
+    """Wrapper class for LLM-based judge metric implementation."""
+
+    def __init__(self, judge: JudgeLM):
+        """
+        Initializes the judge metric wrapper.
+
+        Args:
+            judge (JudgeLM): The LLM judge instance to use for evaluation.
+        """
+        self.judge = judge
+        self.metric_name = "llm_as_judge"
+        self.category = MetricCategory.LLM_AS_JUDGE
+        self.corpus_level_fn = self.aggregate_scores
+        self.sample_level_fn = self._sample_level_fn
+        self.higher_is_better = True
+        self.use_case = MetricUseCase.NONE
+
+    def compute(self, responses: list[str], formatted_docs: list[Doc], **kwargs) -> list[dict[str, float]]:
+        """
+        Computes evaluation scores using the judge's evaluate_answer method.
+
+        Args:
+            responses (list[str]): The predicted answers
+            formatted_docs (list[Doc]): Documents containing questions and gold answers
+
+        Returns:
+            list[dict[str, float]]: Per-sample dictionaries containing the judge score
+        """
+        results = []
+        for i, doc in enumerate(formatted_docs):
+            question = doc.query
+            gold = doc.choices[doc.gold_index] if doc.gold_index is not None else None
+            answer = responses[i][0].result[0]
+
+            score, _, _ = self.judge.evaluate_answer(question=question, answer=answer, options=None, gold=gold)
+            results.append({self.metric_name: score})
+
+        return results
+
+    def aggregate_scores(self, scores: list[float]) -> float:
+        return sum(scores) / len(scores) if scores else 0.0
+
+    def _sample_level_fn(self):
+        return None
+
+
+def parse_candidates(candidates: Union[List[str], str]) -> List[str]:
+    """
+    Parses and validates candidate answers from either list or string format.
+
+    Args:
+        candidates: Either a list of candidate answers or a newline-separated string
+
+    Returns:
+        List[str]: List of validated candidate answers
+
+    Raises:
+        ValueError: If candidates cannot be parsed or are empty
+    """
+    try:
+        if isinstance(candidates, list):
+            parsed_candidates = [str(c).strip() for c in candidates if c]
+        else:
+            parsed_candidates = [c.strip() for c in str(candidates).split("\n") if c.strip()]
+
+        if not parsed_candidates:
+            raise ValueError("No valid candidates found after parsing")
+
+        return parsed_candidates
+    except Exception as e:
+        raise ValueError(f"Failed to parse candidates: {str(e)}")
+
+
+def qa_prompt_arabic(line: Dict[str, Any], task_name: str = None) -> Doc:
+    """
+    Formats the prompt for Arabic question answering with candidates.
+
+    Args:
+        line: Dictionary containing question and candidate information
+        task_name: Optional name for the task
+
+    Returns:
+        Doc: Formatted document for evaluation
+
+    Raises:
+        ValueError: If required fields are missing or invalid
+    """
+    try:
+        # Validates and extracts the question
+        if not isinstance(line.get("question"), str):
+            raise ValueError("Question must be a string")
+        question = line["question"]
+
+        # Processes candidate answers
+        candidates = parse_candidates(line["candidates"])
+
+        # Validates gold answer
+        if "gold_answer" not in line:
+            raise ValueError("Gold answer is required")
+        gold_answer = str(line["gold_answer"])
+
+        # Constructs the prompt
+        instruction = "بناءً على السياقات المقترحة التالية، اجب عن السؤال التالي"
+        query = f"{instruction}\n\nالسؤال:\n{question}\n\nالسياقات المقترحة:\n{', '.join(candidates)}\n"
+
+        return Doc(
+            task_name=task_name or "alrage",
+            query=query,
+            instruction=instruction,
+            choices=[gold_answer],  # Gold answer is used as the only valid choice
+            gold_index=0,  # Index of the correct answer in choices
+        )
+    except Exception as e:
+        raise ValueError(f"Failed to create QA prompt: {str(e)}")
+
+
+def judge_template(question: str, answer: str, gold: str, options: Optional[List[str]] = None) -> List[Dict[str, str]]:
+    """
+    Template for the Arabic judge prompt.
+
+    System prompt translation:
+    You are a neutral expert evaluator. Your tasks are:
+    1. Evaluate the answer's accuracy compared to the correct answer
+    2. Verify that the answer is supported by the provided context
+    3. Evaluate the quality and comprehensiveness of the answer
+    Rate the answer on a scale from 0 to 10.
+
+    Args:
+        question: The question being evaluated
+        answer: The provided answer
+        gold: The correct answer
+        options: Optional list of answer choices
+
+    Returns:
+        List[Dict[str, str]]: Formatted messages for the judge
+    """
+    messages = [
+        {
+            "role": "system",
+            "content": """أنت مقيّم محايد خبير باللغة العربية. يجب عليك:
+1. تقييم دقة الإجابة مقارنة بالإجابة الصحيحة
+2. التحقق من أن الإجابة مدعومة بالسياق المقدم
+3. تقييم جودة وشمولية الإجابة
+
+مهم جداً: يجب أن يكون ردك رقماً فقط من 0 إلى 10. لا تضف أي نص أو تفسير.""",
+        },
+        {
+            "role": "user",
+            "content": f"""السؤال: {question}
+
+الإجابة المقدمة: {answer}
+
+الإجابة الصحيحة: {gold}
+
+أعط تقييماً من 0 إلى 10:
+0-2: إجابة خاطئة تماماً
+3-4: إجابة جزئية مع أخطاء
+5-6: إجابة متوسطة
+7-8: إجابة جيدة
+9-10: إجابة ممتازة
+
+اكتب رقماً فقط من 0 إلى 10 بدون أي نص إضافي:""",
+        },
+    ]
+    return messages
+
+
+def process_judge_response(response) -> float:
+    """Process the judge's response to extract the score"""
+    # If response is a list, extract the content from the user role
+    if isinstance(response, list):
+        response_content = " ".join(item["content"] for item in response if item["role"] == "user")
+    else:
+        response_content = response  # If it's not a list, use it directly
+
+    try:
+        # Extract the score from the response content
+        score = float(next(num for num in response_content.split() if num.replace(".", "", 1).isdigit()))
+        return min(max(score / 10.0, 0.0), 1.0)
+    except (StopIteration, ValueError):
+        return 0.0
+
+
+judge = JudgeLM(
+    model="Qwen/Qwen2.5-72B-Instruct",
+    templates=judge_template,
+    process_judge_response=process_judge_response,
+    judge_backend="vllm",
+)
+
+wrapped_judge = JudgeMetricWrapper(judge)
+
+# Task configuration
+alrage_qa_task = LightevalTaskConfig(
+    name="alrage_qa",
+    prompt_function=qa_prompt_arabic,
+    suite=["community"],
+    hf_repo="OALL/ALRAGE",
+    hf_subset=None,
+    hf_avail_splits=["train"],
+    evaluation_splits=["train"],
+    metric=[wrapped_judge],
+    trust_dataset=True,
+    generation_size=200,
+    stop_sequence=[],
+    version=0,
+)
+
 TASKS_TABLE = (
     ARABIC_MMLU_TASKS
     + ARABIC_MMLU_HT_TASKS
@@ -852,4 +1064,5 @@ def __init__(
     + [hellaswag_okapi_ar_task]
     + [toxigen_ar_task]
     + [sciq_ar_task]
+    + [alrage_qa_task]
 )
diff --git a/examples/tasks/OALL_v2_tasks.txt b/examples/tasks/OALL_v2_tasks.txt
index fc1b4f7e..176b662d 100644
--- a/examples/tasks/OALL_v2_tasks.txt
+++ b/examples/tasks/OALL_v2_tasks.txt
@@ -115,3 +115,4 @@ community|arabic_mmlu_ht:sociology|0|0
 community|arabic_mmlu_ht:us_foreign_policy|0|0
 community|arabic_mmlu_ht:virology|0|0
 community|arabic_mmlu_ht:world_religions|0|0
+community|alrage_qa|0|0
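
Sanity-check sketch (not part of the patch): the snippet below exercises the new helpers qa_prompt_arabic, parse_candidates, and process_judge_response on a fabricated record. It assumes the lighteval repository root is on PYTHONPATH so that community_tasks.arabic_evals is importable (note that importing the module also instantiates the module-level JudgeLM object); the sample question, candidates, and gold answer are invented for illustration only.

# Illustrative sketch only -- not part of the diff above. Assumes the lighteval
# repo root is on PYTHONPATH and that importing community_tasks.arabic_evals
# (which instantiates the module-level JudgeLM) is acceptable in your setup.
from community_tasks.arabic_evals import (
    parse_candidates,
    process_judge_response,
    qa_prompt_arabic,
)

# Fabricated ALRAGE-style record; field names follow qa_prompt_arabic's checks.
sample_line = {
    "question": "ما هي عاصمة فرنسا؟",       # "What is the capital of France?"
    "candidates": "باريس\nليون\nمرسيليا",    # newline-separated candidate contexts
    "gold_answer": "باريس",
}

doc = qa_prompt_arabic(sample_line, task_name="alrage_qa")
print(doc.query)                    # Arabic instruction + question + joined candidates
print(doc.choices, doc.gold_index)  # ['باريس'] 0 -> gold answer is the only choice

# parse_candidates accepts either a list or a newline-separated string.
assert parse_candidates(["باريس", "ليون"]) == parse_candidates("باريس\nليون")

# The judge is prompted for a bare 0-10 rating; scores are normalized to [0, 1].
print(process_judge_response("8"))         # 0.8
print(process_judge_response("غير واضح"))   # 0.0 -- no numeric token found

With the patch applied, the task itself is selected through the community|alrage_qa|0|0 string appended to examples/tasks/OALL_v2_tasks.txt, alongside the other OALL v2 task entries.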