Skip to content

Commit

Permalink
Merge pull request #4 from ivanleomk/fix-path
Browse files Browse the repository at this point in the history
fix: correct issue with file path
  • Loading branch information
ivanleomk authored Jan 14, 2025
2 parents af13a9a + b2d540d commit f9d8622
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 131 deletions.
10 changes: 8 additions & 2 deletions kura/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@
import uvicorn
from kura.cli.server import api
from rich import print
import os

app = typer.Typer()


@app.command()
def start_app():
def start_app(
dir: str = typer.Option(
"./checkpoints",
help="Directory to use for checkpoints, relative to the current directory",
),
):
"""Start the FastAPI server"""

os.environ["KURA_CHECKPOINT_DIR"] = dir
uvicorn.run(api, host="0.0.0.0", port=8000)
print(
"\n[bold green]🚀 Access website at[/bold green] [bold blue][http://localhost:8000](http://localhost:8000)[/bold blue]\n"
Expand Down
16 changes: 11 additions & 5 deletions kura/cli/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
generate_new_chats_per_week_data,
)
import json
import os

api = FastAPI()

Expand Down Expand Up @@ -57,16 +58,21 @@ async def analyse_conversations(conversation_data: ConversationData):
for conversation in conversation_data.data
]

# Load clusters from checkpoint file if it exists
clusters_file = (
Path(__file__).parent.parent.parent
/ "checkpoints/dimensionality_checkpoints.json"
Path(os.path.abspath(os.environ["KURA_CHECKPOINT_DIR"]))
/ "dimensionality_checkpoints.json"
)
clusters = []

print(clusters_file)

# Load clusters from checkpoint file if it exists

if not clusters_file.exists():
kura = Kura()
kura.conversations = conversations
kura = Kura(
checkpoint_dir=Path(os.path.abspath(os.environ["KURA_CHECKPOINT_DIR"])),
conversations=conversations[:100],
)
await kura.cluster_conversations()

with open(clusters_file) as f:
Expand Down
9 changes: 8 additions & 1 deletion kura/kura.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,17 @@ def __init__(
meta_cluster_model: BaseMetaClusterModel = MetaClusterModel(),
dimensionality_reduction: BaseDimensionalityReduction = HDBUMAP(),
max_clusters: int = 10,
checkpoint_dir: str = "checkpoints",
checkpoint_dir: str = "./checkpoints",
cluster_checkpoint_name: str = "clusters.json",
meta_cluster_checkpoint_name: str = "meta_clusters.json",
):
# Override checkpoint dirs so that they're the same for the models
summarisation_model.checkpoint_dir = checkpoint_dir
cluster_model.checkpoint_dir = checkpoint_dir
meta_cluster_model.checkpoint_dir = checkpoint_dir
dimensionality_reduction.checkpoint_dir = checkpoint_dir

self.embedding_model = embedding_model
self.embedding_model = embedding_model
self.summarisation_model = summarisation_model
self.conversations = conversations
Expand Down
123 changes: 0 additions & 123 deletions main.py

This file was deleted.

0 comments on commit f9d8622

Please sign in to comment.