Refactor: add support to customize the supported benchmark kinds
Summary:
Currently, CyberSecEval requires the benchmark kinds to be hardcoded as a `BenchmarkKind` enum in the main run script, which has several disadvantages:
1. If we have an experimental new benchmark kind, we cannot land the changes unless we also release its kind in the open-source file.
2. It introduces an extra `Enum` class.

This diff proposes a few changes:
1. Remove the usage of `BenchmarkKind`.
2. Instantiate all benchmarks in the same fashion, so the main run script no longer needs benchmark-specific construction logic.

There is no logic change for any benchmark, and the design now accommodates future benchmark kinds more easily (see the sketch below).
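
For illustration, here is a minimal sketch of how an experimental benchmark kind could now plug into the registry. The class name, kind string, and import path below are assumptions (not part of this diff), and any remaining abstract methods are elided:

```python
# Hypothetical out-of-tree benchmark; names and module path are illustrative only.
from pathlib import Path

from CybersecurityBenchmarks.benchmark.benchmark import Benchmark


class ExperimentalBenchmark(Benchmark):
    @classmethod
    def return_kind(cls) -> list[str]:
        # The strings returned here become valid --benchmark values once registered.
        return ["experimental"]

    def query_llm_to_generate_responses(
        self, prompt_path: Path, run_llm_in_parallel: bool = False
    ) -> None:
        ...  # benchmark-specific querying logic


# A single registration call replaces what previously required editing the
# hardcoded BenchmarkKind enum in the open-source run script.
Benchmark.register_benchmark(ExperimentalBenchmark)
```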

Reviewed By: csahana95

Differential Revision: D53394820

fbshipit-source-id: db157c841ad9fb8616f85ee0edb8649e1a415615
SimonWan authored and facebook-github-bot committed Feb 6, 2024
1 parent b265f6f commit d2698fb
Showing 4 changed files with 103 additions and 71 deletions.
39 changes: 36 additions & 3 deletions CybersecurityBenchmarks/benchmark/benchmark.py
@@ -7,8 +7,8 @@

import logging
import re
from abc import abstractmethod
from typing import Optional
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Type

from .llm import LLM

@@ -17,7 +17,9 @@
from pathlib import Path


class Benchmark:
class Benchmark(ABC):
_registered_benchmarks: Dict[str, Type["Benchmark"]] = {}

def __init__(
self,
*,
@@ -39,6 +41,37 @@ def __init__(
self.judge_llm = judge_llm
self.pass_k = pass_k

@classmethod
def register_benchmark(cls, benchmark_class: Type["Benchmark"]) -> None:
"""
Registers all the benchmark kinds with the benchmark factory.
"""
for benchmark_kind in benchmark_class.return_kind():
cls._registered_benchmarks[benchmark_kind] = benchmark_class

@classmethod
def create_instance(
cls, benchmark_kind: str, *args: Any, **kwargs: Any
) -> Benchmark:
if benchmark_kind in cls._registered_benchmarks:
benchmark_class: Type[Benchmark] = cls._registered_benchmarks[
benchmark_kind
]
# pyre-ignore [45]: This function is only used to create instances of subclasses of `Benchmark`, whose abstract methods are implemented
return benchmark_class(*args, **kwargs)
raise ValueError(
f"Unknown benchmark kind: {benchmark_kind}, the registered benchmarks are: {list(cls._registered_benchmarks.keys())}"
)

@classmethod
def return_kind(cls) -> list[str]:
"""
Returns the kind(s) of benchmark that is supported. One benchmark may support multiple kinds.
"""
raise NotImplementedError(
"Each benchmark must implement this method to return the kind for the run script."
)

def extract_content_in_code_blocks(self, input: str) -> list[str]:
# Using regular expression to find content between code blocks ```
return re.findall(r"```(.*?)```", input, re.DOTALL)
@@ -25,6 +25,10 @@ class InstructOrAutoCompleteBenchmark(Benchmark):
Class Implementation for instruct_or_autocomplete benchmark
"""

@classmethod
def return_kind(cls) -> list[str]:
return ["autocomplete", "instruct"]

def query_llm_to_generate_responses(
self, prompt_path: Path, run_llm_in_parallel: bool = False
) -> None:
32 changes: 32 additions & 0 deletions CybersecurityBenchmarks/benchmark/mitre_benchmark.py
@@ -47,6 +47,38 @@ class MitreBenchmark(Benchmark):
methods to query language models, run the benchmark, and process the results.
"""

def __init__(
self,
*,
llms_under_test: list[LLM],
response_path: Path,
judge_response_path: Optional[Path] = None,
stat_path: Optional[Path] = None,
num_test_cases: int = 0,
expansion_llm: Optional[LLM] = None,
judge_llm: Optional[LLM] = None,
pass_k: int = 1,
) -> None:
if (judge_llm is None or expansion_llm is None) and (
not judge_llm or not expansion_llm
):
raise ValueError(
"For MITRE benchmark, please specify both --judge-llm and --expansion-llm."
)

self.response_path: Path = response_path
self.num_test_cases: int = num_test_cases
self.llms_under_test: list[LLM] = llms_under_test
self.judge_response_path = judge_response_path
self.stat_path = stat_path
self.expansion_llm = expansion_llm
self.judge_llm = judge_llm
self.pass_k = pass_k

@classmethod
def return_kind(cls) -> list[str]:
return ["mitre"]

def query_llm_to_generate_responses(
self, prompt_path: Path, run_llm_in_parallel: bool = False
) -> None:
99 changes: 31 additions & 68 deletions CybersecurityBenchmarks/benchmark/run.py
@@ -8,7 +8,6 @@
import argparse
import asyncio
import logging
from enum import Enum

from pathlib import Path

@@ -24,14 +23,8 @@

LOG: logging.Logger = logging.getLogger(__name__)


class BenchmarkKind(Enum):
autocomplete = "autocomplete"
instruct = "instruct"
mitre = "mitre"

def __str__(self) -> str:
return self.value
Benchmark.register_benchmark(InstructOrAutoCompleteBenchmark)
Benchmark.register_benchmark(MitreBenchmark)


def validate_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:
@@ -51,7 +44,7 @@ def validate_arguments(parser: argparse.ArgumentParser) -> argparse.Namespace:

def main(
*,
default_benchmark: BenchmarkKind | None = None,
default_benchmark: str | None = None,
llms_under_test: list[llm.LLM] | None = None,
expansion_llm: llm.LLM | None = None,
judge_llm: llm.LLM | None = None,
@@ -67,8 +60,6 @@ def main(

parser.add_argument(
"--benchmark",
type=BenchmarkKind,
choices=list(BenchmarkKind),
default=default_benchmark,
required=True if default_benchmark is None else False,
help="This is the argument to specify which benchmark to run. Currently supported benchmarks are: autocomplete, instruct, mitre.",
@@ -101,16 +92,15 @@ def main(
required=False,
)

if judge_llm is None:
parser.add_argument(
"--judge-llm",
help=f"Used for MITRE benchmark. We need to specify the llm model to judge the final expanded response as {llm.SPECIFICATION_FORMAT}, e.g., {llm.EXAMPLE_SPECIFICATION}",
)
if expansion_llm is None:
parser.add_argument(
"--expansion-llm",
help=f"Used for MITRE benchmark. We need to specify the llm model to expand the context from the initial query response as {llm.SPECIFICATION_FORMAT}, e.g., {llm.EXAMPLE_SPECIFICATION}",
)
parser.add_argument(
"--judge-llm",
help=f"Used for MITRE benchmark. We need to specify the llm model to judge the final expanded response as {llm.SPECIFICATION_FORMAT}, e.g., {llm.EXAMPLE_SPECIFICATION}",
)

parser.add_argument(
"--expansion-llm",
help=f"Used for MITRE benchmark. We need to specify the llm model to expand the context from the initial query response as {llm.SPECIFICATION_FORMAT}, e.g., {llm.EXAMPLE_SPECIFICATION}",
)

parser.add_argument(
"--judge-response-path",
@@ -171,53 +161,26 @@ def main(
if len(materialized_llms_under_test) == 0:
raise ValueError("No LLMs were specified.")

benchmark: Benchmark
stat_response_path = Path(args.stat_path) if args.stat_path else None
if (
args.benchmark == BenchmarkKind.autocomplete
or args.benchmark == BenchmarkKind.instruct
):
benchmark = InstructOrAutoCompleteBenchmark(
llms_under_test=materialized_llms_under_test,
response_path=Path(args.response_path),
num_test_cases=args.num_test_cases,
pass_k=args.num_queries_per_prompt,
stat_path=stat_response_path,
)
elif args.benchmark == BenchmarkKind.mitre:
if (judge_llm is None or expansion_llm is None) and (
not args.judge_llm or not args.expansion_llm
):
raise ValueError(
"For mitre benchmark, please specify both --judge-llm and --expansion-llm."
)
judge_response_path = (
Path(args.judge_response_path) if args.judge_response_path else None
)
if judge_llm is not None and expansion_llm is not None:
benchmark = MitreBenchmark(
llms_under_test=materialized_llms_under_test,
judge_llm=judge_llm,
expansion_llm=expansion_llm,
response_path=Path(args.response_path),
judge_response_path=judge_response_path,
stat_path=stat_response_path,
num_test_cases=args.num_test_cases,
pass_k=args.num_queries_per_prompt,
)
else:
benchmark = MitreBenchmark(
llms_under_test=materialized_llms_under_test,
judge_llm=default_llm_creator(args.judge_llm),
expansion_llm=default_llm_creator(args.expansion_llm),
response_path=Path(args.response_path),
judge_response_path=judge_response_path,
stat_path=stat_response_path,
num_test_cases=args.num_test_cases,
pass_k=args.num_queries_per_prompt,
)
else:
raise AssertionError("Invalid benchmark option")
judge_response_path = (
Path(args.judge_response_path) if args.judge_response_path else None
)
if args.judge_llm:
judge_llm = default_llm_creator(args.judge_llm)
if args.expansion_llm:
expansion_llm = default_llm_creator(args.expansion_llm)

benchmark: Benchmark = Benchmark.create_instance(
args.benchmark,
llms_under_test=materialized_llms_under_test,
judge_llm=judge_llm,
expansion_llm=expansion_llm,
response_path=Path(args.response_path),
judge_response_path=judge_response_path,
stat_path=stat_response_path,
num_test_cases=args.num_test_cases,
pass_k=args.num_queries_per_prompt,
)

if args.use_precomputed_responses:
# Using precomputed responses require a valid response path.
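
For completeness, a minimal sketch of how a private entry point could build on this refactor by registering an out-of-tree kind and reusing `main`. The module paths and names here are assumptions for illustration, not part of this diff:

```python
# Hypothetical internal run script (illustrative module paths and names).
from CybersecurityBenchmarks.benchmark.benchmark import Benchmark
from CybersecurityBenchmarks.benchmark.run import main

from my_internal_benchmarks import ExperimentalBenchmark  # assumed module


if __name__ == "__main__":
    # Registering here makes "--benchmark experimental" resolvable by
    # Benchmark.create_instance, with no edits to the open-source run.py.
    Benchmark.register_benchmark(ExperimentalBenchmark)
    main(default_benchmark="experimental")
```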
