Optimize remote db (#41)
* optimize HierarchyAnnotator for remote db

* make db.get and db.update batch calls (see the batching sketch below)
* add document to the graph for local use

* optimize ChunkerAnnotator for remote db

* optimize DiffAnnotator for remote db

* optimize CallGraphAnnotator for remote db

* remove db from ContextBuilder entirely

* optimize SummarizerAnnotator for remote db

* optimize database.query for remote db

* minor version bump

* fixes from testing

* fix chunk parent resolution

* fix chunker issues

* switch DEFAULT_COMPLETION_MODEL to gpt-4o

* don't use chroma.upsert; it duplicates embeddings

* format fixes
granawkins authored May 17, 2024
1 parent 8ce7dca commit 781e8fb
Showing 22 changed files with 410 additions and 451 deletions.
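
Most of the annotator diffs below follow the same pattern: per-node db.get / db.update round trips are dropped, results are accumulated on the knowledge graph, and the database is written once per annotator with a batched call. A minimal before/after sketch of that pattern, assuming a db wrapper with the get/update signatures used in this diff (the helper function names here are illustrative only, not part of the codebase):

import json

# Before: one db round trip per file node (slow against a remote database).
def save_call_data_per_node(graph, db, nodes, field_id="calls"):
    for node in nodes:
        data = graph.nodes[node]
        record = db.get(data["checksum"])
        metadatas = record["metadatas"][0]
        metadatas[field_id] = json.dumps(data[field_id])
        db.update(data["checksum"], metadatas=metadatas)

# After: accumulate metadata on the graph, then flush in one batched call.
def save_call_data_batched(graph, db, nodes, field_id="calls"):
    update_db = {"ids": [], "metadatas": []}
    for node in nodes:
        data = graph.nodes[node]
        update_db["ids"].append(data["checksum"])
        update_db["metadatas"].append({field_id: json.dumps(data[field_id])})
    if update_db["ids"]:
        db.update(**update_db)  # a single round trip for all nodes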
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -7,7 +7,7 @@ packages=["ragdaemon"]

[project]
name = "ragdaemon"
version = "0.4.7"
version = "0.5.0"
description = "Generate and render a call graph for a Python project."
readme = "README.md"
dependencies = [
2 changes: 1 addition & 1 deletion ragdaemon/__init__.py
@@ -1 +1 @@
__version__ = "0.4.7"
__version__ = "0.5.0"
27 changes: 15 additions & 12 deletions ragdaemon/annotators/call_graph.py
@@ -9,7 +9,7 @@
from spice.models import TextModel

from ragdaemon.annotators.base_annotator import Annotator
from ragdaemon.database import Database
from ragdaemon.database import Database, remove_update_db_duplicates
from ragdaemon.graph import KnowledgeGraph
from ragdaemon.errors import RagdaemonError
from ragdaemon.utils import (
@@ -155,13 +155,11 @@ async def get_file_call_data(
node: str,
data: dict,
graph: KnowledgeGraph,
db: Database,
retries: int = 1,
):
"""Generate and save call data for a file node to graph and db"""
"""Generate and save call data for a file node to graph"""
calls = {}
record = db.get(data["checksum"])
document = record["documents"][0]
document = data["document"]

# Insert line numbers
lines = document.split("\n")
@@ -184,10 +182,6 @@
else "Skipping."
)

# Save to db and graph
metadatas = record["metadatas"][0]
metadatas[self.call_field_id] = json.dumps(calls)
db.update(data["checksum"], metadatas=metadatas)
data[self.call_field_id] = calls

async def annotate(
@@ -212,17 +206,27 @@ async def annotate(
files_with_calls.append((node, data))
# Generate/add call data for nodes that don't have it
tasks = []
files_just_updated = set()
for node, data in files_with_calls:
if refresh or data.get(self.call_field_id, None) is None:
checksum = data.get("checksum")
if checksum is None:
raise RagdaemonError(f"Node {node} has no checksum.")
tasks.append(self.get_file_call_data(node, data, graph, db))
tasks.append(self.get_file_call_data(node, data, graph))
files_just_updated.add(node)
if len(tasks) > 0:
if self.verbose:
await tqdm.gather(*tasks, desc="Generating call graph")
else:
await asyncio.gather(*tasks)
update_db = {"ids": [], "metadatas": []}
for node in files_just_updated:
data = graph.nodes[node]
update_db["ids"].append(data["checksum"])
metadatas = {self.call_field_id: json.dumps(data[self.call_field_id])}
update_db["metadatas"].append(metadatas)
update_db = remove_update_db_duplicates(**update_db)
db.update(**update_db)

# Add call edges to graph. Each call should have only ONE source; if there are
# chunks, the source is the matching chunk, otherwise it's the file.
@@ -244,8 +248,7 @@ async def annotate(
checksum = data.get("checksum")
if checksum is None:
raise RagdaemonError(f"File node {file} is missing checksum field.")
record = db.get(checksum)
document = record["documents"][0]
document = data["document"]
for i in range(1, len(document.split("\n")) + 1):
line_index[i] = file
else:
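
The batched write above depends on remove_update_db_duplicates, which is imported from ragdaemon.database but not shown in this diff. A plausible sketch of such a helper, assuming it only needs to drop repeated ids before the batch update (the actual implementation may differ):

def remove_update_db_duplicates(ids: list[str], metadatas: list[dict]) -> dict:
    """Hypothetical sketch: keep the first occurrence of each id so a batched
    update never sends the same id twice."""
    deduped = {"ids": [], "metadatas": []}
    seen: set[str] = set()
    for id_, metadata in zip(ids, metadatas):
        if id_ not in seen:
            seen.add(id_)
            deduped["ids"].append(id_)
            deduped["metadatas"].append(metadata)
    return deduped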
192 changes: 92 additions & 100 deletions ragdaemon/annotators/chunker.py
@@ -17,13 +17,18 @@

import asyncio
import json
from copy import deepcopy
from pathlib import Path
from typing import Any, Coroutine, Optional
from typing import Any, Optional

from tqdm.asyncio import tqdm

from ragdaemon.annotators.base_annotator import Annotator
from ragdaemon.database import Database, remove_add_to_db_duplicates
from ragdaemon.database import (
Database,
remove_add_to_db_duplicates,
remove_update_db_duplicates,
)
from ragdaemon.errors import RagdaemonError
from ragdaemon.graph import KnowledgeGraph
from ragdaemon.utils import DEFAULT_CODE_EXTENSIONS, get_document, hash_str, truncate
@@ -64,92 +69,18 @@ async def chunk_document(self, document: str) -> list[dict[str, Any]]:
"""Return a list of {id, ref} chunks for the given document."""
raise NotImplementedError()

async def get_file_chunk_data(self, node, data, db):
async def get_file_chunk_data(self, node, data):
"""Generate and save chunk data for a file node to graph and db"""
record = db.get(data["checksum"])
document = record["documents"][0]
document = data["document"]
try:
chunks = await self.chunk_document(document)
except RagdaemonError:
if self.verbose:
print(f"Error chunking {node}; skipping.")
chunks = []
# Save to db and graph
metadatas = record["metadatas"][0]
metadatas[self.chunk_field_id] = json.dumps(chunks)
db.update(data["checksum"], metadatas=metadatas)
chunks = sorted(chunks, key=lambda x: len(x["id"]))
data[self.chunk_field_id] = chunks

def add_file_chunks_to_graph(
self,
file: str,
data: dict,
graph: KnowledgeGraph,
db: Database,
refresh: bool = False,
) -> dict[str, list[Any]]:
"""Load chunks from file data into db/graph"""

# Grab and validate chunks for given file
chunks = data.get(self.chunk_field_id)
if chunks is None:
raise RagdaemonError(f"Node {file} missing {self.chunk_field_id}")
if isinstance(chunks, str):
chunks = json.loads(chunks)
data[self.chunk_field_id] = chunks

add_to_db = {"ids": [], "documents": [], "metadatas": []}
if len(chunks) == 0:
return add_to_db
base_id = f"{file}:BASE"
if not any(chunk["id"] == base_id for chunk in chunks):
raise RagdaemonError(f"Node {file} missing base chunk")
edges_to_add = {(file, base_id)}
for chunk in chunks:
# Locate or create record for chunk
id, ref = chunk["id"], chunk["ref"]
document = get_document(ref, Path(graph.graph["cwd"]))
checksum = hash_str(document)
records = db.get(checksum)["metadatas"]
if not refresh and len(records) > 0:
record = records[0]
else:
record = {
"id": id,
"type": "chunk",
"ref": chunk["ref"],
"checksum": checksum,
"active": False,
}
document, truncate_ratio = truncate(document, db.embedding_model)
if truncate_ratio > 0 and self.verbose:
print(f"Truncated {id} by {truncate_ratio:.2%}")
add_to_db["ids"].append(checksum)
add_to_db["documents"].append(document)
add_to_db["metadatas"].append(record)

# Add chunk to graph and connect hierarchy edges
graph.add_node(record["id"], **record)

def _link_to_base_chunk(_id):
"""Recursively create links from _id to base chunk."""
path_str, chunk_str = _id.split(":")
chunk_list = chunk_str.split(".")
_parent = (
f"{path_str}:{'.'.join(chunk_list[:-1])}"
if len(chunk_list) > 1
else base_id
)
edges_to_add.add((_parent, _id))
if _parent != base_id:
_link_to_base_chunk(_parent)

if id != base_id:
_link_to_base_chunk(id)
for source, target in edges_to_add:
graph.add_edge(source, target, type="hierarchy")
return add_to_db

async def annotate(
self, graph: KnowledgeGraph, db: Database, refresh: bool = False
) -> KnowledgeGraph:
@@ -174,37 +105,98 @@ async def annotate(
files_just_chunked = set()
for node, data in files_with_chunks:
if refresh or data.get(self.chunk_field_id, None) is None:
tasks.append(self.get_file_chunk_data(node, data, db))
tasks.append(self.get_file_chunk_data(node, data))
files_just_chunked.add(node)
elif isinstance(data[self.chunk_field_id], str):
data[self.chunk_field_id] = json.loads(data[self.chunk_field_id])
if len(tasks) > 0:
if self.verbose:
await tqdm.gather(*tasks, desc="Chunking files...")
else:
await asyncio.gather(*tasks)
update_db = {"ids": [], "metadatas": []}
for node in files_just_chunked:
data = graph.nodes[node]
update_db["ids"].append(data["checksum"])
metadatas = {self.chunk_field_id: json.dumps(data[self.chunk_field_id])}
update_db["metadatas"].append(metadatas)
update_db = remove_update_db_duplicates(**update_db)
db.update(**update_db)

# Process chunks
add_to_db = {"ids": [], "documents": [], "metadatas": []}
remove_from_db = set()
# 1. Add all chunks to graph
all_chunk_ids = set()
for file, data in files_with_chunks:
try:
refresh = refresh or file in files_just_chunked
_add_to_db = self.add_file_chunks_to_graph(
file, data, graph, db, refresh
)
for field, values in _add_to_db.items():
add_to_db[field].extend(values)
except RagdaemonError as e:
# If there's a problem with the chunks, remove the file from the db.
# This, along with 'files_just_chunked', prevents invalid database
# records perpetuating.
if self.verbose:
print(f"Error adding chunks for {file}:\n{e}. Removing db record.")
remove_from_db.add(data["checksum"])
if len(data[self.chunk_field_id]) == 0:
continue
if len(remove_from_db) > 0:
db.delete(list(remove_from_db))
raise RagdaemonError(f"Chunking error, try again.")
# Sort such that "parents" are added before "children"
base_id = f"{file}:BASE"
chunks = [c for c in data[self.chunk_field_id] if c["id"] != base_id]
chunks.sort(key=lambda x: len(x["id"]))
base_chunk = [c for c in data[self.chunk_field_id] if c["id"] == base_id]
if len(base_chunk) != 1:
raise RagdaemonError(f"Node {file} missing base chunk")
chunks = base_chunk + chunks
# Load chunks into graph
for chunk in chunks:
id, ref = chunk["id"], chunk["ref"]
document = get_document(ref, Path(graph.graph["cwd"]))
chunk_data = {
"id": id,
"ref": ref,
"type": "chunk",
"document": document,
"checksum": hash_str(document),
"active": False,
}
graph.add_node(id, **chunk_data)
all_chunk_ids.add(id)
# Locate the parent and add hierarchy edge
chunk_str = id.split(":")[1]
if chunk_str == "BASE":
parent = file
elif "." not in chunk_str:
parent = base_id
else:
parts = chunk_str.split(".")
while True:
parent = f"{file}:{'.'.join(parts[:-1])}"
if parent in graph:
break
parent_str = parent.split(":")[1]
if "." not in parent_str:
# If we can't find a parent, use the base node.
if self.verbose:
print(f"No parent node found for {id}")
parent = base_id
break
# If intermediate parents are missing, skip them
parts = parent_str.split(".")
graph.add_edge(parent, id, type="hierarchy")

# 2. Get metadata for all chunks from db
all_chunk_checksums = [
graph.nodes[chunk]["checksum"] for chunk in all_chunk_ids
]
response = db.get(ids=all_chunk_checksums, include=["metadatas"])
db_data = {data["id"]: data for data in response["metadatas"]}
add_to_db = {"ids": [], "documents": [], "metadatas": []}
for chunk in all_chunk_ids:
if chunk in db_data:
# 3. Add db metadata for nodes that have it
graph.nodes[chunk].update(db_data[chunk])
else:
# 4. Add to db nodes that don't
data = deepcopy(graph.nodes[chunk])
document = data.pop("document")
document, truncate_ratio = truncate(document, db.embedding_model)
if truncate_ratio > 0 and self.verbose:
print(f"Truncated {chunk} by {truncate_ratio:.2%}")
add_to_db["ids"].append(data["checksum"])
add_to_db["documents"].append(document)
add_to_db["metadatas"].append(data)
if len(add_to_db["ids"]) > 0:
add_to_db = remove_add_to_db_duplicates(**add_to_db)
db.upsert(**add_to_db)
db.add(**add_to_db)

return graph
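
The "fix chunk parent resolution" item in the commit message corresponds to the parent-lookup loop added above: a dotted chunk id walks up its dotted path until an existing ancestor node is found, skipping missing intermediate levels, and falls back to the file's BASE chunk when no ancestor exists. A standalone sketch of that resolution rule (the function name and example ids are hypothetical):

def resolve_chunk_parent(chunk_id: str, existing_ids: set[str]) -> str:
    """Return the nearest existing ancestor for a dotted chunk id, mirroring
    the lookup in ChunkerAnnotator.annotate above."""
    path_str, chunk_str = chunk_id.split(":")
    base_id = f"{path_str}:BASE"
    if chunk_str == "BASE":
        return path_str  # the BASE chunk hangs off the file node itself
    if "." not in chunk_str:
        return base_id  # top-level chunks attach directly to BASE
    parts = chunk_str.split(".")
    while True:
        parent = f"{path_str}:{'.'.join(parts[:-1])}"
        if parent in existing_ids:
            return parent
        parent_str = parent.split(":")[1]
        if "." not in parent_str:
            return base_id  # no ancestor found; fall back to BASE
        parts = parent_str.split(".")  # skip the missing intermediate level

# With only the class-level chunk present, a nested method chunk attaches to
# its nearest existing ancestor rather than to BASE.
existing = {"src/app.py:BASE", "src/app.py:Foo"}
assert resolve_chunk_parent("src/app.py:Foo.bar.baz", existing) == "src/app.py:Foo"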
8 changes: 5 additions & 3 deletions ragdaemon/annotators/chunker_llm.py
@@ -37,10 +37,12 @@ class ChunkerLLM(Chunker):
def __init__(
self,
*args,
batch_size: int = 800,
model: Optional[TextModel | str] = DEFAULT_COMPLETION_MODEL,
**kwargs,
):
super().__init__(*args, **kwargs)
self.batch_size = batch_size
self.model = model

async def get_llm_response(
@@ -88,7 +90,7 @@ async def get_llm_response(
return chunks

async def chunk_document(
self, document: str, batch_size: int = 1000, retries: int = 1
self, document: str, retries: int = 1
) -> list[dict[str, Any]]:
"""Parse file_lines into a list of {id, ref} chunks."""
lines = document.split("\n")
Expand All @@ -100,9 +102,9 @@ async def chunk_document(

# Get raw llm output: {id, start_line, end_line}
chunks = list[dict[str, Any]]()
n_batches = (len(file_lines) + batch_size - 1) // batch_size
n_batches = (len(file_lines) + self.batch_size - 1) // self.batch_size
for i in range(n_batches):
batch_lines = file_lines[i * batch_size : (i + 1) * batch_size]
batch_lines = file_lines[i * self.batch_size : (i + 1) * self.batch_size]
last_chunk = chunks.pop() if chunks else None
for j in range(retries + 1, 0, -1):
try:
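
With batch_size now a constructor argument, the batching arithmetic in chunk_document is unchanged: the file is processed in fixed-size windows of lines, and the previous batch's last chunk is popped and re-derived together with the next batch (presumably so chunks that straddle a batch boundary can be merged). A simplified sketch of that loop, with a hypothetical chunk_lines stand-in for the LLM request:

from typing import Any, Optional

def chunk_in_batches(file_lines: list[str], batch_size: int = 800) -> list[dict[str, Any]]:
    """Process a file in fixed-size line batches, carrying the last chunk forward."""
    chunks: list[dict[str, Any]] = []
    n_batches = (len(file_lines) + batch_size - 1) // batch_size  # ceiling division
    for i in range(n_batches):
        batch_lines = file_lines[i * batch_size : (i + 1) * batch_size]
        # The last chunk may continue into this batch, so re-derive it together
        # with the new lines instead of keeping it as-is.
        last_chunk = chunks.pop() if chunks else None
        chunks.extend(chunk_lines(batch_lines, last_chunk))
    return chunks

def chunk_lines(batch_lines: list[str], last_chunk: Optional[dict[str, Any]]) -> list[dict[str, Any]]:
    """Hypothetical stand-in for the LLM call; groups the whole batch into one chunk."""
    new_chunk = {"id": "example_chunk", "n_lines": len(batch_lines)}
    return ([last_chunk] if last_chunk else []) + [new_chunk]

# A 2000-line file with batch_size=800 is processed in 3 batches (800, 800, 400 lines).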