Skip to content

Commit

Permalink
add better validation to call_graph
Browse files Browse the repository at this point in the history
  • Loading branch information
granawkins committed May 10, 2024
1 parent 2c5b275 commit 566f6de
Showing 1 changed file with 39 additions and 24 deletions.
63 changes: 39 additions & 24 deletions ragdaemon/annotators/call_graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from functools import partial
import json
from pathlib import Path
from typing import Any, Optional
Expand All @@ -19,18 +20,6 @@
)


def is_calls_valid(calls: dict[str, list[dict[str, str | list[int]]]]) -> bool:
"""Expected structure: {path/to/file:class.method: [1, 2, 3]}"""
for target, lines in calls.items():
if not target or not isinstance(target, str):
return False
if not lines or not isinstance(lines, list):
return False
if not all(isinstance(line, int) for line in lines):
return False
return True


class CallGraph(Annotator):
name = "call_graph"
call_field_id = "calls"
Expand Down Expand Up @@ -89,26 +78,52 @@ def is_complete(self, graph: KnowledgeGraph, db: Database) -> bool:
return False
return True

async def get_llm_response(self, document: str, graph: KnowledgeGraph) -> str:
async def get_llm_response(
self, document: str, graph: KnowledgeGraph
) -> dict[str, list[int]]:
if self.spice_client is None:
raise RagdaemonError("Spice client is not initialized.")

messages = SpiceMessages(self.spice_client)
messages.add_system_prompt(name="call_graph.base")
messages.add_user_prompt("call_graph.user", document=document)

def validate(response: str, max_line: int) -> bool:
"""Expected structure: {path/to/file:class.method: [1, 2, 3]}"""
try:
calls = json.loads(response)
except json.JSONDecodeError:
return False
for target, lines in calls.items():
if not target or not isinstance(target, str):
return False
if not lines or not isinstance(lines, list):
return False
if not all(isinstance(line, int) for line in lines):
return False
if any(line > max_line for line in lines):
return False
return True

validator = partial(validate, max_line=len(document.split("\n")))
async with semaphore:
response = await self.spice_client.get_response(
messages=messages,
model=self.model,
response_format={"type": "json_object"},
)
try:
calls = json.loads(response.text)
except json.JSONDecodeError:
raise RagdaemonError("Failed to parse JSON response.")
if not is_calls_valid(calls):
raise RagdaemonError(f"Model returned malformed calls: {calls}")
try:
response = await self.spice_client.get_response(
messages=messages,
model=self.model,
response_format={"type": "json_object"},
validator=validator,
retries=2,
)
except ValueError: # Raised after all retries fail
if self.verbose:
file = document.split("\n")[0]
print(
f"Failed to generate call graph for {file} after 3 tries, Skipping."
)
return {}

calls = json.loads(response.text)

# Resolve library calls
targets = set(calls.keys())
Expand Down

0 comments on commit 566f6de

Please sign in to comment.