Skip to content

Commit

Permalink
Merge pull request #49 from weaviate/framework-abstraction
Browse files Browse the repository at this point in the history
Abstract Function Calling as a framework
  • Loading branch information
CShorten authored Jan 11, 2025
2 parents 453ec86 + de7ab03 commit b9546d3
Show file tree
Hide file tree
Showing 35 changed files with 729 additions and 10 deletions.
9 changes: 7 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@ jobs:

- name: Install dependencies
run: pip install -r requirements.txt

- name: Install dependencies
run: |
pip install -r requirements.txt
pip install pytest requests
# Step 4: Run tests
- name: Run pytest
run: pytest
run: |
PYTHONPATH=$PYTHONPATH:$(pwd) pytest
File renamed without changes.
File renamed without changes.
183 changes: 177 additions & 6 deletions notebooks/pydantic-ai.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 2,
"metadata": {},
"outputs": [
{
Expand All @@ -14,7 +14,7 @@
"| Span Processor: SimpleSpanProcessor\n",
"| Collector Endpoint: https://app.phoenix.arize.com/v1/traces\n",
"| Transport: HTTP\n",
"| Transport Headers: {'api_key': '****'}\n",
"| Transport Headers: {'api_key': '****', 'authorization': '****'}\n",
"| \n",
"| Using a default SpanProcessor. `add_span_processor` will overwrite this default.\n",
"| \n",
Expand All @@ -40,7 +40,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -51,7 +51,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -83,7 +83,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand All @@ -97,7 +97,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"greeting=\"Hello, Alice! It's great to see you!\"\n"
"greeting=\"Hello, Alice! Hope you're having a wonderful day!\"\n"
]
}
],
Expand Down Expand Up @@ -160,6 +160,177 @@
"{\"messages\": [{\"role\": \"system\", \"content\": \"You are a personalized greeter AI. Return a short greeting for the user.\"}, {\"role\": \"system\", \"content\": \"The user's name is 'Alice'.\"}, {\"role\": \"user\", \"content\": \"Hi, can you greet me?\"}], \"model\": \"gpt-4o\", \"n\": 1, \"parallel_tool_calls\": true, \"stream\": false, \"tool_choice\": \"required\", \"tools\": [{\"type\": \"function\", \"function\": {\"name\": \"final_result\", \"description\": \"Structured output from the AI.\", \"parameters\": {\"properties\": {\"greeting\": {\"description\": \"A short greeting to the user\", \"title\": \"Greeting\", \"type\": \"string\"}}, \"required\": [\"greeting\"], \"title\": \"GreetingResult\", \"type\": \"object\"}}}]}\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Step 1] Analyzing the user's problem to find relevant database queries...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Failed to export batch code: 204, reason: \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Analysis result (queries):\n",
"{\n",
" \"queries\": [\n",
" \"SELECT name, rating, distance FROM restaurants WHERE location = 'current location' ORDER BY rating DESC, distance ASC;\"\n",
" ]\n",
"}\n",
"\n",
"[Step 2] Formatting the queries into Pydantic API requests...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Failed to export batch code: 204, reason: \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Formatting result (API requests):\n",
"{\n",
" \"requests\": [\n",
" {\n",
" \"query_text\": \"SELECT name, rating, distance FROM restaurants WHERE location = 'current location' ORDER BY rating DESC, distance ASC;\",\n",
" \"endpoint\": \"/execute-query\"\n",
" }\n",
" ]\n",
"}\n"
]
}
],
"source": [
"import nest_asyncio\n",
"nest_asyncio.apply()\n",
"\n",
"import asyncio\n",
"import json\n",
"from dataclasses import dataclass\n",
"from typing import List\n",
"\n",
"from pydantic import BaseModel, Field\n",
"from pydantic_ai import Agent, RunContext\n",
"\n",
"\n",
"@dataclass\n",
"class ProblemContext:\n",
" user_name: str\n",
"\n",
"\n",
"class DatabaseQueryAnalysis(BaseModel):\n",
" queries: List[str] = Field(default_factory=list)\n",
"\n",
"\n",
"analysis_agent = Agent(\n",
" model=\"openai:gpt-4o\",\n",
" deps_type=ProblemContext,\n",
" result_type=DatabaseQueryAnalysis,\n",
" system_prompt=(\n",
" \"You are an AI that, given a user's problem, identifies what database queries \"\n",
" \"would be needed to retrieve information that solves the user's problem.\"\n",
" ),\n",
")\n",
"\n",
"\n",
"@analysis_agent.system_prompt\n",
"async def analysis_agent_system_prompt(ctx: RunContext[ProblemContext]) -> str:\n",
" return (\n",
" f\"The user's name is {ctx.deps.user_name!r}. \"\n",
" \"Analyze the user's input and suggest relevant database queries.\"\n",
" )\n",
"\n",
"\n",
"class QueryAPIRequest(BaseModel):\n",
" query_text: str = Field(description=\"The raw query to be executed.\")\n",
" endpoint: str = Field(\n",
" default=\"/execute-query\",\n",
" description=\"The endpoint where the query should be sent.\"\n",
" )\n",
"\n",
"\n",
"class QueryAPIRequests(BaseModel):\n",
" requests: List[QueryAPIRequest] = Field(default_factory=list)\n",
"\n",
"\n",
"formatting_agent = Agent(\n",
" model=\"openai:gpt-4o\",\n",
" deps_type=ProblemContext,\n",
" result_type=QueryAPIRequests,\n",
" system_prompt=(\n",
" \"You are an AI that formats database queries into the provided Pydantic BaseModels for API requests.\"\n",
" ),\n",
")\n",
"\n",
"\n",
"@formatting_agent.system_prompt\n",
"async def formatting_agent_prompt(ctx: RunContext[ProblemContext]) -> str:\n",
" return (\n",
" f\"The user's name is {ctx.deps.user_name!r}. \"\n",
" \"Convert the list of database queries into `QueryAPIRequest` objects, \"\n",
" \"wrapped in the `QueryAPIRequests` model.\"\n",
" )\n",
"\n",
"\n",
"async def main():\n",
" deps = ProblemContext(user_name=\"Alice\")\n",
"\n",
" user_problem = (\n",
" \"I need to find the top-rated restaurants near me. \"\n",
" \"Show me some options sorted by rating and distance.\"\n",
" )\n",
"\n",
" print(\"[Step 1] Analyzing the user's problem to find relevant database queries...\")\n",
" analysis_result = await analysis_agent.run(user_problem, deps=deps)\n",
"\n",
" # Convert to dict and JSON-serialize with indentation\n",
" print(\"Analysis result (queries):\")\n",
" print(json.dumps(analysis_result.data.model_dump(), indent=2))\n",
"\n",
" queries_to_format = analysis_result.data.queries\n",
" second_input = (\n",
" \"Here are the queries to format:\\n\" + \"\\n\".join(queries_to_format)\n",
" )\n",
"\n",
" print(\"\\n[Step 2] Formatting the queries into Pydantic API requests...\")\n",
" formatting_result = await formatting_agent.run(second_input, deps=deps)\n",
"\n",
" # Again, convert to dict and JSON-serialize with indentation\n",
" print(\"Formatting result (API requests):\")\n",
" print(json.dumps(formatting_result.data.model_dump(), indent=2))\n",
"\n",
"\n",
"# If you're in an environment that supports top-level await (e.g., Jupyter):\n",
"await main()\n",
"\n",
"# If you're in a standard Python script:\n",
"# if __name__ == \"__main__\":\n",
"# asyncio.run(main())\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pytest==8.3.4
pytest==8.3.4
requests>=2.31.0
Empty file added src/__init__.py
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
3 changes: 3 additions & 0 deletions src/lm/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
# together models are accessed through the openai SDK with a different base URL
LMModelProvider = Literal["ollama", "openai", "anthropic", "cohere", "together"]

# need to add models to this...!
# fixes test_lm.py

class LMService():
def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions src/lm/pydantic_agent/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Tie together the nodes and executors here
1 change: 1 addition & 0 deletions src/lm/pydantic_agent/executors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Weaviate Query Executor
1 change: 1 addition & 0 deletions src/lm/pydantic_agent/nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Nodes for prompts
Loading

0 comments on commit b9546d3

Please sign in to comment.