diff --git a/kura/cli/cli.py b/kura/cli/cli.py
index e8015b9..83c2ac5 100644
--- a/kura/cli/cli.py
+++ b/kura/cli/cli.py
@@ -2,14 +2,20 @@
 import uvicorn
 from kura.cli.server import api
 from rich import print
+import os
 
 app = typer.Typer()
 
 
 @app.command()
-def start_app():
+def start_app(
+    dir: str = typer.Option(
+        "./checkpoints",
+        help="Directory to use for checkpoints, relative to the current directory",
+    ),
+):
     """Start the FastAPI server"""
-
+    os.environ["KURA_CHECKPOINT_DIR"] = dir
     uvicorn.run(api, host="0.0.0.0", port=8000)
     print(
         "\n[bold green]🚀 Access website at[/bold green] [bold blue][http://localhost:8000](http://localhost:8000)[/bold blue]\n"
diff --git a/kura/cli/server.py b/kura/cli/server.py
index ca88b9d..d606d2b 100644
--- a/kura/cli/server.py
+++ b/kura/cli/server.py
@@ -11,6 +11,7 @@
     generate_new_chats_per_week_data,
 )
 import json
+import os
 
 api = FastAPI()
 
@@ -57,16 +58,21 @@ async def analyse_conversations(conversation_data: ConversationData):
         for conversation in conversation_data.data
     ]
 
-    # Load clusters from checkpoint file if it exists
     clusters_file = (
-        Path(__file__).parent.parent.parent
-        / "checkpoints/dimensionality_checkpoints.json"
+        Path(os.path.abspath(os.environ["KURA_CHECKPOINT_DIR"]))
+        / "dimensionality_checkpoints.json"
     )
     clusters = []
+    print(clusters_file)
+
+    # Load clusters from checkpoint file if it exists
+
 
     if not clusters_file.exists():
-        kura = Kura()
-        kura.conversations = conversations
+        kura = Kura(
+            checkpoint_dir=Path(os.path.abspath(os.environ["KURA_CHECKPOINT_DIR"])),
+            conversations=conversations[:100],
+        )
         await kura.cluster_conversations()
 
     with open(clusters_file) as f:
diff --git a/kura/kura.py b/kura/kura.py
index b7f1e87..639af8b 100644
--- a/kura/kura.py
+++ b/kura/kura.py
@@ -25,10 +25,17 @@ def __init__(
         meta_cluster_model: BaseMetaClusterModel = MetaClusterModel(),
         dimensionality_reduction: BaseDimensionalityReduction = HDBUMAP(),
         max_clusters: int = 10,
-        checkpoint_dir: str = "checkpoints",
+        checkpoint_dir: str = "./checkpoints",
         cluster_checkpoint_name: str = "clusters.json",
         meta_cluster_checkpoint_name: str = "meta_clusters.json",
     ):
+        # Override checkpoint dirs so that they're the same for the models
+        summarisation_model.checkpoint_dir = checkpoint_dir
+        cluster_model.checkpoint_dir = checkpoint_dir
+        meta_cluster_model.checkpoint_dir = checkpoint_dir
+        dimensionality_reduction.checkpoint_dir = checkpoint_dir
+
+        self.embedding_model = embedding_model
         self.embedding_model = embedding_model
         self.summarisation_model = summarisation_model
         self.conversations = conversations
diff --git a/main.py b/main.py
deleted file mode 100644
index 84faa82..0000000
--- a/main.py
+++ /dev/null
@@ -1,123 +0,0 @@
-from helpers.conversation import load_conversations, summarise_conversation
-from helpers.clusters import cluster_summaries, reduce_clusters, embed_summaries
-from asyncio import run, Semaphore
-import instructor
-import google.generativeai as genai
-from tqdm.asyncio import tqdm_asyncio as asyncio
-from rich import print
-import json
-import os
-from openai import AsyncOpenAI
-
-from helpers.types import Cluster
-
-
-async def generate_clusters(
-    SUMMARIES_PER_CLUSTER=20,
-    CHILD_CLUSTERS_PER_CLUSTER=10,
-    MAX_FINAL_CLUSTERS=10,
-    start_step=None,
-):
-    conversations = load_conversations("conversations.json")
-
-    # Step 1: Generate or load summaries
-    if start_step == "summarize" or not os.path.exists("checkpoints/summaries.json"):
-        client = instructor.from_gemini(
-            genai.GenerativeModel("gemini-1.5-flash-latest"),
-            use_async=True,
-        )
-        sem = Semaphore(50)
-
-        summaries = await asyncio.gather(
-            *[
-                summarise_conversation(client, sem, conversation)
-                for conversation in conversations
-            ]
-        )
-
-        # Save summaries
-        with open("checkpoints/summaries.json", "w") as f:
-            for summary in summaries:
-                f.write(json.dumps(summary) + "\n")
-        print("Generated and saved summaries")
-    else:
-        with open("checkpoints/summaries.json", "r") as f:
-            summaries = [json.loads(line) for line in f]
-        print("Loaded existing summaries")
-
-    # Step 2: Generate or load embeddings
-    if start_step in ["summarize", "embed"] or not os.path.exists(
-        "checkpoints/embedded_summaries.json"
-    ):
-        oai_client = AsyncOpenAI()
-        sem = Semaphore(10)
-        await embed_summaries(oai_client, sem, summaries)
-        print("Generated and saved embeddings")
-
-    # Step 3: Generate or load base clusters
-    if start_step in ["summarize", "embed", "cluster"] or not os.path.exists(
-        "checkpoints/base_clusters.json"
-    ):
-        client = instructor.from_gemini(
-            genai.GenerativeModel("gemini-1.5-flash-latest"),
-            use_async=True,
-        )
-        sem = Semaphore(10)
-
-        base_clusters = await cluster_summaries(
-            client, sem, summaries, SUMMARIES_PER_CLUSTER
-        )
-
-        # Save base clusters
-        with open("checkpoints/base_clusters.json", "w") as f:
-            for cluster in base_clusters:
-                f.write(cluster.model_dump_json() + "\n")
-        print(f"Generated and saved {len(base_clusters)} base clusters")
-    else:
-        with open("checkpoints/base_clusters.json", "r") as f:
-            base_clusters = [Cluster.model_validate_json(line) for line in f]
-        print("Loaded existing base clusters")
-
-    # Create checkpoints directory if it doesn't exist
-    os.makedirs("checkpoints", exist_ok=True)
-
-    clusters = base_clusters
-    root_clusters = base_clusters
-    iteration = 0
-
-    while len(root_clusters) >= MAX_FINAL_CLUSTERS:
-        print(f"\nIteration {iteration}")
-        print(f"Starting with {len(root_clusters)} root clusters")
-
-        # Get both updated original clusters and new parent clusters
-        reduced_clusters = await reduce_clusters(
-            client, sem, root_clusters, CHILD_CLUSTERS_PER_CLUSTER
-        )
-
-        # Recalculate root clusters for next iteration
-        root_clusters = [c for c in reduced_clusters if not c.parent_id]
-
-        # Remove the outdated versions of clusters that were just reduced
-        old_cluster_ids = {rc.id for rc in root_clusters}
-        clusters = [c for c in clusters if c.id not in old_cluster_ids]
-
-        # Add both the updated original clusters and new parent clusters
-        clusters.extend(reduced_clusters)
-
-        print(f"Reduced to {len(root_clusters)} root clusters")
-        print(f"Total clusters after reduction: {len(clusters)}")
-
-        iteration += 1
-
-    print(f"\nFinal reduction complete:")
-    print(f"Root clusters: {len(root_clusters)}")
-    print(f"Total clusters (including children): {len(clusters)}")
-
-    # Save final clusters
-    with open("clusters.json", "w") as f:
-        for cluster in clusters:
-            f.write(cluster.model_dump_json() + "\n")
-
-
-if __name__ == "__main__":
-    run(generate_clusters())