Skip to content

Commit

Permalink
Injection Benchmark Implementation
Browse files Browse the repository at this point in the history
Summary: Heavily draws from the existing MITRE implementation that uses a judge LLM (and also makes some slight tweaks/bug fixes to that code)

Reviewed By: csahana95

Differential Revision: D53624167

fbshipit-source-id: 7b748ebd56677609177ec63c6d2934ccbe6433f9
  • Loading branch information
cynikolai authored and facebook-github-bot committed Feb 9, 2024
1 parent c8146d4 commit ca4959a
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 10 deletions.
22 changes: 12 additions & 10 deletions CybersecurityBenchmarks/benchmark/mitre_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,16 @@ def __init__(
"For MITRE benchmark, please specify both --judge-llm and --expansion-llm."
)

self.response_path: Path = response_path
self.num_test_cases: int = num_test_cases
self.llms_under_test: list[LLM] = llms_under_test
self.judge_response_path = judge_response_path
self.stat_path = stat_path
self.expansion_llm = expansion_llm
self.judge_llm = judge_llm
self.pass_k = pass_k
super().__init__(
llms_under_test=llms_under_test,
response_path=response_path,
judge_response_path=judge_response_path,
stat_path=stat_path,
num_test_cases=num_test_cases,
expansion_llm=expansion_llm,
judge_llm=judge_llm,
pass_k=pass_k,
)

@classmethod
def return_kind(cls) -> list[str]:
Expand Down Expand Up @@ -127,7 +129,7 @@ async def run(
e_llm = self.expansion_llm

logging.info(
f"Expanding the response with {j_llm.model} and judging the response using {e_llm.model}"
f"Expanding the response with {e_llm.model} and judging the response using {j_llm.model}"
)
judge_response_result = []
judge_response_result.extend(
Expand Down Expand Up @@ -244,7 +246,7 @@ def process_judge_prompt(
response = cast(str, test_case["response"])
test_case_prompt = test_case["test_case_prompt"]
if not response:
raise ValueError("Prompt cannot be empty.")
raise ValueError("Response cannot be empty.")
e_llm = llms["expansion_llm"]
if not e_llm:
raise ValueError("Expansion LLM must be specified.")
Expand Down
41 changes: 41 additions & 0 deletions CybersecurityBenchmarks/benchmark/query_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,47 @@
LOG: logging.Logger = logging.getLogger(__name__)


def process_prompt_and_user_input(
    prompt_id: int,
    pass_id: int,
    llms: Dict[str, LLM],
    test_case: Dict[str, Union[int, str]],
) -> Dict[str, Union[int, str]]:
    """Query the initial LLM for one prompt-injection test case.

    Sends the test case's ``test_case_prompt`` as the system prompt and its
    ``user_input`` as the user message, and records the LLM's answer under
    ``response`` alongside bookkeeping fields.

    Args:
        prompt_id: Index of the test-case prompt within the dataset.
        pass_id: Index of the pass when each prompt is queried multiple times.
        llms: Role-name -> LLM mapping; must contain ``"initial_query_llm"``.
        test_case: Test-case record; must provide non-empty
            ``"test_case_prompt"`` and ``"user_input"`` entries.

    Returns:
        A result record with the prompt/pass ids, the inputs, the LLM
        response, passthrough judge metadata (``"injection_type"``,
        ``"judge_question"``) when present, and the model name when known.

    Raises:
        ValueError: If the prompt or user input is empty, or if no initial
            query LLM is configured.
    """
    prompt = cast(str, test_case.get("test_case_prompt"))
    user_input = cast(str, test_case.get("user_input"))

    if not prompt:
        raise ValueError("Prompt cannot be empty.")
    # Bugfix: validate user_input symmetrically with prompt — an empty user
    # input would otherwise silently produce a meaningless injection query.
    if not user_input:
        raise ValueError("User input cannot be empty.")
    # Bugfix: use .get() so a missing key reaches the ValueError below
    # instead of raising a bare KeyError (the original guard was dead for
    # the missing-key case).
    llm = llms.get("initial_query_llm")
    if not llm:
        raise ValueError("Initial query LLM must be specified.")

    query_result = llm.query_with_system_prompt_with_retries(
        system_prompt=prompt, prompt=user_input
    )
    response: Dict[str, Union[int, str]] = {
        "prompt_id": prompt_id,
        "pass_id": pass_id,
        "test_case_prompt": prompt,
        "user_input": user_input,
        "response": query_result,
    }

    # Carry judge-stage metadata through so downstream judging can use it.
    for field in ("injection_type", "judge_question"):
        if field in test_case:
            response[field] = test_case[field]

    if llm.model is not None:
        response["model"] = llm.model
    return response


def process_prompt(
prompt_id: int,
pass_id: int,
Expand Down

0 comments on commit ca4959a

Please sign in to comment.