Skip to content

Commit

Permalink
Update to create cohort search page
Browse files Browse the repository at this point in the history
  • Loading branch information
amrit110 committed Oct 21, 2024
1 parent a3cbbda commit f566ef2
Show file tree
Hide file tree
Showing 9 changed files with 523 additions and 395 deletions.
39 changes: 39 additions & 0 deletions adrenaline/api/patients/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,45 @@
from pydantic import BaseModel, Field, field_validator


class CohortSearchQuery(BaseModel):
    """Query for cohort search.

    Attributes
    ----------
    query: str
        The search query.
    top_k: int
        The maximum number of results to return. Must be at least 1;
        defaults to 100.
    """

    query: str
    # Reject zero/negative values at validation time: a non-positive top_k
    # is meaningless as a result limit and, used as a slice bound such as
    # ``results[:top_k]``, a negative value silently drops trailing items.
    top_k: int = Field(default=100, ge=1)


class CohortSearchResult(BaseModel):
    """A single cohort search hit: one note for one patient.

    Attributes
    ----------
    patient_id: int
        Identifier of the patient the note belongs to.
    note_type: str
        Type of the matched note.
    note_text: str
        Text of the matched note.
    timestamp: int
        Timestamp of the matched note.
    similarity_score: float
        Similarity score of the note for the query.
    """

    patient_id: int
    note_type: str
    note_text: str
    timestamp: int
    similarity_score: float


class QAPair(BaseModel):
"""
Represents a question-answer pair.
Expand Down
182 changes: 151 additions & 31 deletions adrenaline/api/patients/rag.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
"""RAG for patients."""
"""RAG for patients and cohort search."""

import asyncio
import logging
from datetime import datetime
from typing import Any, Dict, List
from typing import Any, Dict, List, Tuple

import httpx
from pymilvus import Collection, connections, utility


# Milvus collection name (presumably the collection holding patient notes).
COLLECTION_NAME = "patient_notes"
# Default Milvus connection settings (local instance, port 19530).
MILVUS_HOST = "localhost"
MILVUS_PORT = 19530
# Default endpoints for the embedding service and the clinical NER service.
EMBEDDING_SERVICE_URL = "http://localhost:8004/embeddings"
NER_SERVICE_URL = "http://clinical-ner-service-dev:8000/extract_entities"

# NOTE(review): basicConfig is called at import time, so importing this module
# configures the root logger at INFO level if it has no handlers yet.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
Expand All @@ -19,12 +22,11 @@ class EmbeddingManager:
"""Manager for embedding service."""

def __init__(self, embedding_service_url: str):
"""Initialize the embedding manager."""
self.embedding_service_url = embedding_service_url
self.client = httpx.AsyncClient(timeout=60.0)

async def get_embedding(self, text: str) -> List[float]:
"""Get the embedding for a given text."""
"""Get the embedding for a text."""
response = await self.client.post(
self.embedding_service_url,
json={"texts": [text], "instruction": "Represent the query for retrieval:"},
Expand All @@ -33,12 +35,34 @@ async def get_embedding(self, text: str) -> List[float]:
return response.json()["embeddings"][0]

async def close(self):
"""Close the embedding manager."""
"""Close the client."""
await self.client.aclose()


class NERManager:
    """Asynchronous client wrapper around the NER service endpoint."""

    def __init__(self, ner_service_url: str):
        """Remember the service URL and create the HTTP client.

        Parameters
        ----------
        ner_service_url: str
            URL that entity-extraction requests are posted to.
        """
        self.ner_service_url = ner_service_url
        # Generous 300-second timeout for the NER requests.
        self.client = httpx.AsyncClient(timeout=300.0)

    async def extract_entities(self, text: str) -> Dict[str, Any]:
        """Post ``text`` to the NER service and return the decoded JSON reply.

        Parameters
        ----------
        text: str
            The text to run entity extraction on.

        Returns
        -------
        Dict[str, Any]
            The JSON body of the service response.

        Raises
        ------
        Exception
            Whatever ``raise_for_status`` raises for an error status code.
        """
        payload = {"text": text}
        reply = await self.client.post(self.ner_service_url, json=payload)
        reply.raise_for_status()
        return reply.json()

    async def close(self):
        """Release the underlying HTTP client."""
        await self.client.aclose()


class MilvusManager:
"""Manager for Milvus."""
"""Manager for Milvus operations."""

def __init__(self, host: str, port: int):
"""Initialize the Milvus manager."""
Expand All @@ -61,11 +85,6 @@ def get_collection(self) -> Collection:
self.collection = Collection(self.collection_name)
return self.collection

def load_collection(self):
"""Load the collection."""
collection = self.get_collection()
collection.load()

async def ensure_collection_loaded(self):
"""Ensure the collection is loaded."""
collection = self.get_collection()
Expand All @@ -74,18 +93,18 @@ async def ensure_collection_loaded(self):
async def search(
self,
query_vector: List[float],
patient_id: int,
top_k: int,
time_range: Dict[str, int] = None,
patient_id: int = None,
top_k: int = 5,
) -> List[Dict[str, Any]]:
"""Search for the nearest neighbors."""
"""Retrieve the relevant notes directly from Milvus."""
await self.ensure_collection_loaded()
collection = self.get_collection()
search_params = {"metric_type": "IP", "params": {"nprobe": 16, "ef": 64}}
search_params = {
"metric_type": "IP",
"params": {"nprobe": 16, "ef": 64},
}

expr = f"patient_id == {patient_id}"
if time_range:
expr += f" && timestamp >= {time_range['start']} && timestamp <= {time_range['end']}"
expr = f"patient_id == {patient_id}" if patient_id else None

results = collection.search(
data=[query_vector],
Expand Down Expand Up @@ -116,30 +135,131 @@ async def search(
for hit in results[0]
]

# Sort by distance in descending order (higher IP score means more similar)
filtered_results.sort(key=lambda x: x["distance"], reverse=True)

return filtered_results

async def cohort_search(
    self,
    query_vector: List[float],
    top_k: int = 2,
    oversample: int = 5,
) -> List[Tuple[int, Dict[str, Any]]]:
    """Retrieve the cohort search results from Milvus.

    Searches notes across all patients, keeps the best-scoring note of each
    patient, and returns the highest-scoring patients.

    Parameters
    ----------
    query_vector: List[float]
        Embedding of the search query.
    top_k: int
        Maximum number of patients to return.
    oversample: int
        Multiplier applied to ``top_k`` when choosing how many notes to
        request from Milvus. Values below 1 are treated as 1.

    Returns
    -------
    List[Tuple[int, Dict[str, Any]]]
        ``(patient_id, note)`` pairs ordered by descending ``distance``,
        at most ``top_k`` of them. Empty when ``top_k`` is not positive.
    """
    if top_k <= 0:
        return []

    # Several of the best-scoring notes frequently belong to the same
    # patient. Requesting exactly top_k notes and then de-duplicating by
    # patient would therefore return fewer than top_k patients, so more
    # notes than patients are requested.
    # NOTE(review): 16384 is Milvus's documented upper bound for a search
    # limit - confirm against the deployed Milvus version.
    milvus_max_limit = 16384
    note_limit = max(top_k, min(top_k * max(oversample, 1), milvus_max_limit))
    search_results = await self.search(query_vector, top_k=note_limit)

    # Keep only the highest-scoring note per patient.
    best_by_patient: Dict[int, Dict[str, Any]] = {}
    for note in search_results:
        patient_id = note["patient_id"]
        current_best = best_by_patient.get(patient_id)
        if current_best is None or note["distance"] > current_best["distance"]:
            best_by_patient[patient_id] = note

    ranked = sorted(
        best_by_patient.items(),
        key=lambda item: item[1]["distance"],
        reverse=True,
    )
    return ranked[:top_k]


class RAGManager:
    """Manager for RAG operations.

    Combines the embedding, Milvus and NER managers: a query is embedded,
    candidate notes are fetched from Milvus, and only the notes whose NER
    output shares at least one key with the query's NER output are kept.
    """

    def __init__(
        self,
        embedding_manager: EmbeddingManager,
        milvus_manager: MilvusManager,
        ner_manager: NERManager,
    ):
        """Initialize the RAG manager.

        Parameters
        ----------
        embedding_manager: EmbeddingManager
            Used to embed the user query.
        milvus_manager: MilvusManager
            Used to run the vector searches.
        ner_manager: NERManager
            Used to extract entities from the query and from the notes.
        """
        self.embedding_manager = embedding_manager
        self.milvus_manager = milvus_manager
        self.ner_manager = ner_manager

    async def _query_entity_keys(self, user_query: str) -> set:
        """Return the set of keys of the NER output for ``user_query``."""
        query_entities = await self.ner_manager.extract_entities(user_query)
        return set(query_entities.keys())

    async def _tag_matching_entities(
        self, query_keys: set, note: Dict[str, Any]
    ) -> bool:
        """Tag ``note`` with the entity keys it shares with the query.

        Runs NER on ``note["note_text"]``. When the note's keys overlap with
        ``query_keys``, the overlap is stored under
        ``note["matching_entities"]`` and True is returned; otherwise the
        note is left untouched and False is returned.
        """
        # NOTE(review): matching compares the top-level keys of the two NER
        # responses; this presumes the service returns a mapping keyed by
        # entity - confirm against the NER service's response schema.
        note_entities = await self.ner_manager.extract_entities(note["note_text"])
        shared = query_keys & set(note_entities.keys())
        if not shared:
            return False
        # Sorted so the stored list is deterministic (set order is not).
        note["matching_entities"] = sorted(shared)
        return True

    async def retrieve_relevant_notes(
        self,
        user_query: str,
        patient_id: int,
        top_k: int = 5,
    ) -> List[Dict[str, Any]]:
        """Retrieve the notes of one patient that are relevant to a query.

        Parameters
        ----------
        user_query: str
            The user's question.
        patient_id: int
            The patient whose notes are searched.
        top_k: int
            Maximum number of notes to fetch and to return.

        Returns
        -------
        List[Dict[str, Any]]
            Notes sharing at least one entity key with the query, each with
            a ``matching_entities`` list, ordered by the number of matches
            (descending); at most ``top_k`` of them.
        """
        query_embedding = await self.embedding_manager.get_embedding(user_query)
        search_results = await self.milvus_manager.search(
            query_embedding, patient_id, top_k
        )

        # Extract entities from the query once, outside the loop.
        query_keys = await self._query_entity_keys(user_query)

        filtered_results = []
        # Without query keys no note can match, so the per-note NER calls
        # are skipped entirely.
        if query_keys:
            for result in search_results:
                if await self._tag_matching_entities(query_keys, result):
                    filtered_results.append(result)

        # Stable sort: notes with equally many matches keep their search order.
        filtered_results.sort(
            key=lambda note: len(note["matching_entities"]), reverse=True
        )

        logger.info(
            f"Retrieved {len(filtered_results)} relevant notes for patient {patient_id}"
        )
        for rank, result in enumerate(filtered_results, start=1):
            logger.info(
                f"Result {rank}: Distance = {result['distance']}, Matching Entities = {result.get('matching_entities', [])}"
            )

        return filtered_results[:top_k]

    async def cohort_search(
        self, user_query: str, top_k: int = 2
    ) -> List[Tuple[int, Dict[str, Any]]]:
        """Retrieve the patients whose notes are relevant to a query.

        Parameters
        ----------
        user_query: str
            The cohort search query.
        top_k: int
            Maximum number of patients to fetch and to return.

        Returns
        -------
        List[Tuple[int, Dict[str, Any]]]
            ``(patient_id, note)`` pairs whose note shares at least one
            entity key with the query, each note carrying a
            ``matching_entities`` list, ordered by the number of matches
            (descending); at most ``top_k`` of them.
        """
        query_embedding = await self.embedding_manager.get_embedding(user_query)
        cohort_results = await self.milvus_manager.cohort_search(query_embedding, top_k)

        # Extract entities from the query once, outside the loop.
        query_keys = await self._query_entity_keys(user_query)

        filtered_results = []
        # Without query keys no note can match, so the per-note NER calls
        # are skipped entirely.
        if query_keys:
            for patient_id, note_details in cohort_results:
                if await self._tag_matching_entities(query_keys, note_details):
                    filtered_results.append((patient_id, note_details))

        # Stable sort: patients with equally many matches keep their order.
        filtered_results.sort(
            key=lambda pair: len(pair[1]["matching_entities"]), reverse=True
        )

        logger.info(
            f"Found {len(filtered_results)} patients matching the query: '{user_query}'"
        )
        # Only the first five matches are logged to keep the log short.
        for patient_id, note_details in filtered_results[:5]:
            logger.info(
                f"Patient {patient_id}: Distance = {note_details['distance']}, "
                f"Note Type = {note_details['note_type']}, "
                f"Matching Entities = {note_details.get('matching_entities', [])}"
            )

        return filtered_results[:top_k]


async def retrieve_relevant_notes(
user_query: str,
embedding_manager: EmbeddingManager,
milvus_manager: MilvusManager,
patient_id: int,
top_k: int = 5,
time_range: Dict[str, int] = None,
) -> List[Dict[str, Any]]:
"""Retrieve the relevant notes directly from Milvus."""
query_embedding = await embedding_manager.get_embedding(user_query)
search_results = await milvus_manager.search(
query_embedding, patient_id, top_k, time_range
)
for result in search_results:
print(result["note_type"])
search_results = await milvus_manager.search(query_embedding, patient_id, top_k)
logger.info(f"Retrieved {len(search_results)} relevant notes")
for i, result in enumerate(search_results):
logger.info(
f"Result {i+1}: Distance = {result['distance']}, Timestamp = {datetime.fromtimestamp(result['timestamp'])}"
)
logger.info(f"Result {i+1}: Distance = {result['distance']}")
return search_results
62 changes: 57 additions & 5 deletions adrenaline/api/routes/answer.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
"""Routes for generating answers."""
"""Routes for generating answers and performing cohort searches."""

import logging
import os
from datetime import datetime
from typing import Dict
from typing import Dict, List

from fastapi import APIRouter, Body, Depends, HTTPException
from motor.motor_asyncio import AsyncIOMotorDatabase

from api.pages.data import Query
from api.patients.answer import generate_answer
from api.patients.data import CohortSearchQuery, CohortSearchResult
from api.patients.db import get_database
from api.patients.rag import EmbeddingManager, MilvusManager, retrieve_relevant_notes
from api.patients.rag import (
EmbeddingManager,
MilvusManager,
NERManager,
RAGManager,
retrieve_relevant_notes,
)
from api.users.auth import get_current_active_user
from api.users.data import User

Expand All @@ -36,13 +43,18 @@
# Embedding endpoint built from EMBEDDING_SERVICE_HOST / EMBEDDING_SERVICE_PORT,
# which are defined earlier in this module.
EMBEDDING_SERVICE_URL = (
f"http://{EMBEDDING_SERVICE_HOST}:{EMBEDDING_SERVICE_PORT}/embeddings"
)
# NER endpoint: only the port is read from the environment (default "8000");
# the host name is fixed to "clinical-ner-service-dev".
NER_SERVICE_PORT = os.getenv("NER_SERVICE_PORT", "8000")
NER_SERVICE_URL = f"http://clinical-ner-service-dev:{NER_SERVICE_PORT}/extract_entities"

# Module-level singletons shared by all routes in this module.
EMBEDDING_MANAGER = EmbeddingManager(EMBEDDING_SERVICE_URL)
MILVUS_MANAGER = MilvusManager(MILVUS_HOST, MILVUS_PORT)
# NOTE(review): connect() runs at import time, so importing this module
# presumably requires Milvus to be reachable - confirm.
MILVUS_MANAGER.connect()
NER_MANAGER = NERManager(NER_SERVICE_URL)
RAG_MANAGER = RAGManager(EMBEDDING_MANAGER, MILVUS_MANAGER, NER_MANAGER)


@router.post("/generate_cot_answer")
async def generate_cot_answer_endpoint(
@router.post("/generate_answer")
async def generate_answer_endpoint(
query: Query = Body(...), # noqa: B008
db: AsyncIOMotorDatabase = Depends(get_database), # noqa: B008
current_user: User = Depends(get_current_active_user), # noqa: B008
Expand Down Expand Up @@ -126,3 +138,43 @@ async def generate_cot_answer_endpoint(
raise HTTPException(
status_code=500, detail=f"An unexpected error occurred: {str(e)}"
) from e


@router.post("/cohort_search")
async def cohort_search_endpoint(
    query: CohortSearchQuery = Body(...),  # noqa: B008
    current_user: User = Depends(get_current_active_user),  # noqa: B008
) -> List[CohortSearchResult]:
    """Perform a cohort search across all patients.

    Parameters
    ----------
    query: CohortSearchQuery
        The search text and the maximum number of patients to return.
    current_user: User
        The authenticated user. Not used in the body; declaring the
        dependency is what makes the route require an active user.

    Returns
    -------
    List[CohortSearchResult]
        One result per matching patient, with the note text limited to its
        first 500 characters.

    Raises
    ------
    HTTPException
        400 when the query string is empty or only whitespace; 500 when the
        search fails unexpectedly.
    """
    logger.info(f"Received cohort search query: {query.query}")

    # Validate before entering the try block. Catching ValueError around the
    # whole search would also report internal failures that happen to be
    # ValueError subclasses (e.g. json.JSONDecodeError from a malformed
    # service response) as a client error (400) instead of a server error.
    if not query.query.strip():
        detail = "Query string is empty"
        logger.error(f"Validation error: {detail}")
        raise HTTPException(status_code=400, detail=detail)

    try:
        cohort_results = await RAG_MANAGER.cohort_search(query.query, query.top_k)
        logger.info(f"Found {len(cohort_results)} patients matching the query")

        return [
            CohortSearchResult(
                patient_id=patient_id,
                note_type=note_details["note_type"],
                # Limit to first 500 characters
                note_text=note_details["note_text"][:500],
                timestamp=note_details["timestamp"],
                similarity_score=note_details["distance"],
            )
            for patient_id, note_details in cohort_results
        ]
    except Exception as e:
        logger.error(
            f"Unexpected error in cohort_search_endpoint: {str(e)}", exc_info=True
        )
        raise HTTPException(
            status_code=500, detail=f"An unexpected error occurred: {str(e)}"
        ) from e
Loading

0 comments on commit f566ef2

Please sign in to comment.