Skip to content

Commit

Permalink
Fixes and improvements to the answer page
Browse files Browse the repository at this point in the history
  • Loading branch information
amrit110 committed Oct 3, 2024
1 parent 91ad599 commit 5240834
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 163 deletions.
61 changes: 41 additions & 20 deletions adrenaline/api/routes/answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

import logging
import os
from typing import Any, Dict, List

from fastapi import APIRouter, Body, Depends, HTTPException
from motor.motor_asyncio import AsyncIOMotorDatabase

from api.pages.data import CoTStep, Query
from api.pages.data import Query
from api.patients.cot import generate_cot_answer, generate_cot_steps
from api.patients.db import get_database
from api.patients.rag import EmbeddingManager, MilvusManager, retrieve_relevant_notes
Expand Down Expand Up @@ -50,8 +49,23 @@ async def generate_cot_steps_endpoint(
query: Query = Body(...), # noqa: B008
db: AsyncIOMotorDatabase = Depends(get_database), # noqa: B008
current_user: User = Depends(get_current_active_user), # noqa: B008
) -> Dict[str, List[Dict[str, str]]]:
"""Generate a chain of thought steps for a given query."""
):
"""Generate COT steps.
Parameters
----------
query : Query
The query to generate steps for.
db : AsyncIOMotorDatabase
The database to use.
current_user : User
The current user.
Returns
-------
List[Step]
The generated steps.
"""
try:
mode = "patient" if query.patient_id else "general"
context = ""
Expand Down Expand Up @@ -90,8 +104,6 @@ async def generate_cot_steps_endpoint(

return {"cot_steps": [step.dict() for step in steps]}

except ValueError as ve:
raise HTTPException(status_code=400, detail=str(ve)) from ve
except Exception as e:
raise HTTPException(
status_code=500, detail=f"An error occurred: {str(e)}"
Expand All @@ -103,8 +115,23 @@ async def generate_cot_answer_endpoint(
query: Query = Body(...), # noqa: B008
db: AsyncIOMotorDatabase = Depends(get_database), # noqa: B008
current_user: User = Depends(get_current_active_user), # noqa: B008
) -> Dict[str, Any]:
"""Generate a chain of thought answer for a given query."""
):
"""Generate a COT answer.
Parameters
----------
query : Query
The query to generate an answer for.
db : AsyncIOMotorDatabase
The database to use.
current_user : User
The current user.
Returns
-------
Dict[str, str]
The generated answer and reasoning.
"""
try:
mode = "patient" if query.patient_id else "general"
context = ""
Expand All @@ -120,14 +147,14 @@ async def generate_cot_answer_endpoint(
)
context = "\n".join([note.text for note in relevant_notes])

cot_steps = (
query.steps
if query.steps
else await generate_cot_steps(query.query, mode, context)
)
if not query.steps:
raise HTTPException(
status_code=400, detail="Steps are required to generate an answer"
)

answer, reasoning = await generate_cot_answer(
user_query=query.query,
steps=cot_steps,
steps=query.steps,
mode=mode,
context=context,
)
Expand All @@ -151,15 +178,9 @@ async def generate_cot_answer_endpoint(
return {
"answer": answer,
"reasoning": reasoning,
"steps": [step.model_dump() for step in cot_steps]
if isinstance(cot_steps[0], CoTStep)
else cot_steps,
}

except ValueError as ve:
raise HTTPException(status_code=400, detail=str(ve)) from ve
except Exception as e:
logger.exception("An error occurred in generate_cot_answer_endpoint")
raise HTTPException(
status_code=500, detail=f"An error occurred: {str(e)}"
) from e
Loading

0 comments on commit 5240834

Please sign in to comment.