Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Improve JSON parsing with fallback support #164

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 71 additions & 26 deletions ai_scientist/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,61 +278,106 @@ def get_response_from_llm(
return content, new_msg_history


# Fenced ```json block. ``\s*`` (instead of a literal "\n") tolerates CRLF line
# endings, trailing spaces after the tag, and single-line fences.
_JSON_FENCE_RE = re.compile(r"```json\s*(.*?)```", re.DOTALL)
# ASCII control characters; raw ones inside a JSON string make json.loads fail.
_CONTROL_CHARS_RE = re.compile(r"[\x00-\x1F\x7F]")


def _try_parse_json(text):
    """Parse *text* as JSON, retrying once with control characters removed.

    Returns the parsed value, or None when both attempts fail. A literal JSON
    ``null`` therefore also counts as "nothing parsed".
    """
    try:
        return json.loads(text.strip())
    except json.JSONDecodeError:
        pass
    # LLMs often emit raw newlines/tabs inside string values; dropping every
    # control character is lossy but frequently turns the text into valid JSON.
    try:
        return json.loads(_CONTROL_CHARS_RE.sub("", text).strip())
    except json.JSONDecodeError:
        return None


def _find_embedded_json_object(text):
    """Return the first JSON object decodable from some "{" in *text*, else None.

    ``JSONDecoder.raw_decode`` is used instead of a ``\\{.*?\\}`` regex because a
    non-greedy regex stops at the first "}" and so can never capture a nested
    object. A broken outer object may still yield a valid inner object.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            value, _ = decoder.raw_decode(text, start)
            return value
        except json.JSONDecodeError:
            pass
        # Same control-character retry as _try_parse_json, applied to the tail.
        try:
            value, _ = decoder.raw_decode(_CONTROL_CHARS_RE.sub("", text[start:]))
            return value
        except json.JSONDecodeError:
            pass
        start = text.find("{", start + 1)
    return None


def extract_json_between_markers(llm_output, strict_mode=True):
    """Extract JSON from LLM output, optionally falling back to unfenced content.

    Fenced ```json blocks are tried first, in order of appearance. In
    non-strict mode the function then looks for the first JSON object embedded
    anywhere in the text, and finally tries to parse the whole input.

    Note: with the default ``strict_mode=True`` a failure raises ``ValueError``.
    The pre-``strict_mode`` version of this function returned ``None`` instead,
    so callers that tested for ``None`` must catch ``ValueError``.

    Args:
        llm_output (str): Raw LLM output.
        strict_mode (bool): If True, only accept JSON inside ```json markers
            (default: True).

    Returns:
        The parsed JSON value. Normally a dict; the whole-input last resort in
        non-strict mode can yield any JSON type.

    Raises:
        ValueError: If no valid JSON could be parsed, or if the strict_mode
            requirements are not met.
    """
    fenced_blocks = _JSON_FENCE_RE.findall(llm_output)
    for candidate in fenced_blocks:
        parsed = _try_parse_json(candidate)
        if parsed is not None:
            return parsed

    if strict_mode:
        if fenced_blocks:
            raise ValueError("Found JSON markers but content was invalid")
        raise ValueError("No JSON markers found in output")

    # Non-strict mode: any embedded JSON object will do.
    parsed = _find_embedded_json_object(llm_output)
    if parsed is not None:
        return parsed

    # Last resort: the entire input may itself be JSON (e.g. a bare array).
    parsed = _try_parse_json(llm_output)
    if parsed is not None:
        return parsed

    raise ValueError("Failed to parse any JSON content")


def create_client(model):
if model.startswith("claude-"):
print(f"Using Anthropic API with model {model}.")
return anthropic.Anthropic(), model
client = anthropic.Anthropic()
client.strict_json = True
return client, model
elif model.startswith("bedrock") and "claude" in model:
client_model = model.split("/")[-1]
print(f"Using Amazon Bedrock with model {client_model}.")
return anthropic.AnthropicBedrock(), client_model
client = anthropic.AnthropicBedrock()
client.strict_json = True
return client, client_model
elif model.startswith("vertex_ai") and "claude" in model:
client_model = model.split("/")[-1]
print(f"Using Vertex AI with model {client_model}.")
return anthropic.AnthropicVertex(), client_model
client = anthropic.AnthropicVertex()
client.strict_json = True
return client, client_model
elif 'gpt' in model:
print(f"Using OpenAI API with model {model}.")
return openai.OpenAI(), model
client = openai.OpenAI()
client.strict_json = True
return client, model
elif model in ["o1-preview-2024-09-12", "o1-mini-2024-09-12"]:
print(f"Using OpenAI API with model {model}.")
client = openai.OpenAI()
client.strict_json = True
return client, model
elif model == "deepseek-coder-v2-0724":
return openai.OpenAI(), model
elif model in ["deepseek-chat", "deepseek-reasoner"]:
print(f"Using OpenAI API with {model}.")
return openai.OpenAI(
client = openai.OpenAI(
api_key=os.environ["DEEPSEEK_API_KEY"],
base_url="https://api.deepseek.com"
), model
)
client.strict_json = True
return client, model
elif model == "llama3.1-405b":
print(f"Using OpenAI API with {model}.")
return openai.OpenAI(
client = openai.OpenAI(
api_key=os.environ["OPENROUTER_API_KEY"],
base_url="https://openrouter.ai/api/v1"
), "meta-llama/llama-3.1-405b-instruct"
Expand Down
39 changes: 39 additions & 0 deletions tests/test_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest
from ai_scientist.llm import extract_json_between_markers

def test_extract_json_unmarked():
    """Bare JSON with no code fence is parsed when strict mode is off."""
    bare_payload = '{"key": "value"}'
    expected = {"key": "value"}
    assert extract_json_between_markers(bare_payload, strict_mode=False) == expected

def test_extract_json_marked():
    """JSON wrapped in ```json markers is parsed with the default arguments."""
    fenced_payload = '```json\n{"key": "value"}\n```'
    expected = {"key": "value"}
    assert extract_json_between_markers(fenced_payload) == expected

def test_extract_json_fallback():
    """An unparsable fenced block falls through to later bare JSON in non-strict mode."""
    broken_fence = '```json\nInvalid JSON\n```'
    bare_object = '{"key": "value"}'
    llm_text = "\n".join([broken_fence, bare_object])
    parsed = extract_json_between_markers(llm_text, strict_mode=False)
    assert parsed == {"key": "value"}

def test_extract_json_strict_mode_error():
    """Strict mode rejects valid JSON that is not wrapped in markers."""
    unfenced_payload = '{"key": "value"}'
    expected_message = "No JSON markers found in output"
    with pytest.raises(ValueError, match=expected_message):
        extract_json_between_markers(unfenced_payload, strict_mode=True)

def test_extract_json_invalid_content():
    """Malformed JSON raises ValueError even in non-strict mode."""
    # The value is missing its quotes, so no parsing attempt can succeed.
    malformed_payload = '{"key": value}'
    expected_message = "Failed to parse any JSON content"
    with pytest.raises(ValueError, match=expected_message):
        extract_json_between_markers(malformed_payload, strict_mode=False)

def test_extract_json_no_content():
    """Plain text containing no JSON at all raises ValueError in non-strict mode."""
    plain_text = "No JSON here"
    expected_message = "Failed to parse any JSON content"
    with pytest.raises(ValueError, match=expected_message):
        extract_json_between_markers(plain_text, strict_mode=False)