From 03c35de32467e0ad6a82395af3a4b65cc54ac86e Mon Sep 17 00:00:00 2001 From: Luca Peric Date: Fri, 13 Sep 2024 10:00:30 +0000 Subject: [PATCH] Bugfixing completion stats break with new reasoning tokens release --- evals/cli/oaieval.py | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/evals/cli/oaieval.py b/evals/cli/oaieval.py index a8927dda4c..6bc4304a89 100644 --- a/evals/cli/oaieval.py +++ b/evals/cli/oaieval.py @@ -5,7 +5,7 @@ import logging import shlex import sys -from typing import Any, Mapping, Optional, Union, cast +from typing import Any, List, Mapping, Optional, TypedDict, Union, cast import evals import evals.api @@ -263,6 +263,33 @@ def build_recorder( ) +CompletionTokenDetails = TypedDict("CompletionTokenDetails", {"reasoning_tokens": Optional[int]}) +UsageEvent = TypedDict("UsageEvent", { + "completion_token_details": Optional[CompletionTokenDetails], + "total_tokens": Optional[int], + "prompt_tokens": Optional[int], + "completion_tokens": Optional[int], +}) + + +def get_val(val: Union[Optional[int], CompletionTokenDetails]) -> int: + if val is None: + return 0 + if isinstance(val, dict): + return val["reasoning_tokens"] if val["reasoning_tokens"] is not None else 0 + return val + + +def sum_usage(usage_events: List[UsageEvent]) -> Mapping[str, int]: + total_usages_by_key = {} + for key in usage_events[0]: + total_usages_by_key[key] = sum( + get_val(u[key]) + for u in usage_events + ) + return total_usages_by_key + + def add_token_usage_to_result(result: dict[str, Any], recorder: RecorderBase) -> None: """ Add token usage from logged sampling events to the result dictionary from the recorder. @@ -275,10 +302,7 @@ def add_token_usage_to_result(result: dict[str, Any], recorder: RecorderBase) -> logger.info(f"Found {len(usage_events)}/{len(sampling_events)} sampling events with usage data") if usage_events: # Sum up the usage of all samples (assumes the usage is the same for all samples) - total_usage = { - key: sum(u[key] if u[key] is not None else 0 for u in usage_events) - for key in usage_events[0] - } + total_usage = sum_usage(usage_events) total_usage_str = "\n".join(f"{key}: {value:,}" for key, value in total_usage.items()) logger.info(f"Token usage from {len(usage_events)} sampling events:\n{total_usage_str}") for key, value in total_usage.items():