diff --git a/adrenaline/api/patients/data.py b/adrenaline/api/patients/data.py index 39c0505..24b8698 100644 --- a/adrenaline/api/patients/data.py +++ b/adrenaline/api/patients/data.py @@ -6,6 +6,45 @@ from pydantic import BaseModel, Field, field_validator +class CohortSearchQuery(BaseModel): + """Query for cohort search. + + Attributes + ---------- + query: str + The search query. + top_k: int + The number of top results to return. + """ + + query: str + top_k: int = 100 + + +class CohortSearchResult(BaseModel): + """Result for cohort search. + + Attributes + ---------- + patient_id: int + The patient ID. + note_type: str + The type of the note. + note_text: str + The text of the note. + timestamp: int + The timestamp of the note. + similarity_score: float + The similarity score. + """ + + patient_id: int + note_type: str + note_text: str + timestamp: int + similarity_score: float + + class QAPair(BaseModel): """ Represents a question-answer pair. diff --git a/adrenaline/api/patients/rag.py b/adrenaline/api/patients/rag.py index 3e0e1df..153b397 100644 --- a/adrenaline/api/patients/rag.py +++ b/adrenaline/api/patients/rag.py @@ -1,15 +1,18 @@ -"""RAG for patients.""" +"""RAG for patients and cohort search.""" import asyncio import logging -from datetime import datetime -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple import httpx from pymilvus import Collection, connections, utility COLLECTION_NAME = "patient_notes" +MILVUS_HOST = "localhost" +MILVUS_PORT = 19530 +EMBEDDING_SERVICE_URL = "http://localhost:8004/embeddings" +NER_SERVICE_URL = "http://clinical-ner-service-dev:8000/extract_entities" logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -19,12 +22,11 @@ class EmbeddingManager: """Manager for embedding service.""" def __init__(self, embedding_service_url: str): - """Initialize the embedding manager.""" self.embedding_service_url = embedding_service_url self.client = httpx.AsyncClient(timeout=60.0) async def get_embedding(self, text: str) -> List[float]: - """Get the embedding for a given text.""" + """Get the embedding for a text.""" response = await self.client.post( self.embedding_service_url, json={"texts": [text], "instruction": "Represent the query for retrieval:"}, @@ -33,12 +35,34 @@ async def get_embedding(self, text: str) -> List[float]: return response.json()["embeddings"][0] async def close(self): - """Close the embedding manager.""" + """Close the client.""" + await self.client.aclose() + + +class NERManager: + """Manager for NER service.""" + + def __init__(self, ner_service_url: str): + """Initialize the NER manager.""" + self.ner_service_url = ner_service_url + self.client = httpx.AsyncClient(timeout=300.0) + + async def extract_entities(self, text: str) -> Dict[str, Any]: + """Extract entities from a text.""" + response = await self.client.post( + self.ner_service_url, + json={"text": text}, + ) + response.raise_for_status() + return response.json() + + async def close(self): + """Close the client.""" await self.client.aclose() class MilvusManager: - """Manager for Milvus.""" + """Manager for Milvus operations.""" def __init__(self, host: str, port: int): """Initialize the Milvus manager.""" @@ -61,11 +85,6 @@ def get_collection(self) -> Collection: self.collection = Collection(self.collection_name) return self.collection - def load_collection(self): - """Load the collection.""" - collection = self.get_collection() - collection.load() - async def ensure_collection_loaded(self): """Ensure the collection is loaded.""" collection = self.get_collection() @@ -74,18 +93,18 @@ async def ensure_collection_loaded(self): async def search( self, query_vector: List[float], - patient_id: int, - top_k: int, - time_range: Dict[str, int] = None, + patient_id: int = None, + top_k: int = 5, ) -> List[Dict[str, Any]]: - """Search for the nearest neighbors.""" + """Retrieve the relevant notes directly from Milvus.""" await self.ensure_collection_loaded() collection = self.get_collection() - search_params = {"metric_type": "IP", "params": {"nprobe": 16, "ef": 64}} + search_params = { + "metric_type": "IP", + "params": {"nprobe": 16, "ef": 64}, + } - expr = f"patient_id == {patient_id}" - if time_range: - expr += f" && timestamp >= {time_range['start']} && timestamp <= {time_range['end']}" + expr = f"patient_id == {patient_id}" if patient_id else None results = collection.search( data=[query_vector], @@ -116,11 +135,119 @@ async def search( for hit in results[0] ] - # Sort by distance in descending order (higher IP score means more similar) filtered_results.sort(key=lambda x: x["distance"], reverse=True) - return filtered_results + async def cohort_search( + self, query_vector: List[float], top_k: int = 2 + ) -> List[Tuple[int, Dict[str, Any]]]: + """Retrieve the cohort search results from Milvus.""" + search_results = await self.search(query_vector, top_k=top_k) + + # Group results by patient_id and keep only the top result for each patient + patient_results = {} + for result in search_results: + patient_id = result["patient_id"] + if ( + patient_id not in patient_results + or result["distance"] > patient_results[patient_id]["distance"] + ): + patient_results[patient_id] = result + + cohort_results = list(patient_results.items()) + cohort_results.sort(key=lambda x: x[1]["distance"], reverse=True) + return cohort_results[:top_k] + + +class RAGManager: + """Manager for RAG operations.""" + + def __init__( + self, + embedding_manager: EmbeddingManager, + milvus_manager: MilvusManager, + ner_manager: NERManager, + ): + """Initialize the RAG manager.""" + self.embedding_manager = embedding_manager + self.milvus_manager = milvus_manager + self.ner_manager = ner_manager + + async def retrieve_relevant_notes( + self, + user_query: str, + patient_id: int, + top_k: int = 5, + ) -> List[Dict[str, Any]]: + """Retrieve the relevant notes directly from Milvus.""" + query_embedding = await self.embedding_manager.get_embedding(user_query) + search_results = await self.milvus_manager.search( + query_embedding, patient_id, top_k + ) + + # Extract entities from the query + query_entities = await self.ner_manager.extract_entities(user_query) + + # Extract entities from the retrieved notes and filter based on matched entities + filtered_results = [] + for result in search_results: + note_entities = await self.ner_manager.extract_entities(result["note_text"]) + matching_entities = set(query_entities.keys()) & set(note_entities.keys()) + if matching_entities: + result["matching_entities"] = list(matching_entities) + filtered_results.append(result) + + filtered_results.sort( + key=lambda x: len(x.get("matching_entities", [])), reverse=True + ) + + logger.info( + f"Retrieved {len(filtered_results)} relevant notes for patient {patient_id}" + ) + for i, result in enumerate(filtered_results): + logger.info( + f"Result {i+1}: Distance = {result['distance']}, Matching Entities = {result.get('matching_entities', [])}" + ) + + return filtered_results[:top_k] + + async def cohort_search( + self, user_query: str, top_k: int = 2 + ) -> List[Tuple[int, Dict[str, Any]]]: + """Retrieve the cohort search results from Milvus.""" + query_embedding = await self.embedding_manager.get_embedding(user_query) + cohort_results = await self.milvus_manager.cohort_search(query_embedding, top_k) + + # Extract entities from the query + query_entities = await self.ner_manager.extract_entities(user_query) + + # Filter and sort results based on matching entities + filtered_results = [] + for patient_id, note_details in cohort_results: + note_entities = await self.ner_manager.extract_entities( + note_details["note_text"] + ) + matching_entities = set(query_entities.keys()) & set(note_entities.keys()) + if matching_entities: + note_details["matching_entities"] = list(matching_entities) + filtered_results.append((patient_id, note_details)) + + filtered_results.sort( + key=lambda x: len(x[1].get("matching_entities", [])), reverse=True + ) + + logger.info( + f"Found {len(filtered_results)} patients matching the query: '{user_query}'" + ) + for _, (patient_id, note_details) in enumerate(filtered_results[:5]): + logger.info( + f"Patient {patient_id}: Distance = {note_details['distance']}, " + f"Note Type = {note_details['note_type']}, " + f"Matching Entities = {note_details.get('matching_entities', [])}" + ) + + return filtered_results[:top_k] + async def retrieve_relevant_notes( user_query: str, @@ -128,18 +255,11 @@ async def retrieve_relevant_notes( milvus_manager: MilvusManager, patient_id: int, top_k: int = 5, - time_range: Dict[str, int] = None, ) -> List[Dict[str, Any]]: """Retrieve the relevant notes directly from Milvus.""" query_embedding = await embedding_manager.get_embedding(user_query) - search_results = await milvus_manager.search( - query_embedding, patient_id, top_k, time_range - ) - for result in search_results: - print(result["note_type"]) + search_results = await milvus_manager.search(query_embedding, patient_id, top_k) logger.info(f"Retrieved {len(search_results)} relevant notes") for i, result in enumerate(search_results): - logger.info( - f"Result {i+1}: Distance = {result['distance']}, Timestamp = {datetime.fromtimestamp(result['timestamp'])}" - ) + logger.info(f"Result {i+1}: Distance = {result['distance']}") return search_results diff --git a/adrenaline/api/routes/answer.py b/adrenaline/api/routes/answer.py index b9e6693..8b2f123 100644 --- a/adrenaline/api/routes/answer.py +++ b/adrenaline/api/routes/answer.py @@ -1,17 +1,24 @@ -"""Routes for generating answers.""" +"""Routes for generating answers and performing cohort searches.""" import logging import os from datetime import datetime -from typing import Dict +from typing import Dict, List from fastapi import APIRouter, Body, Depends, HTTPException from motor.motor_asyncio import AsyncIOMotorDatabase from api.pages.data import Query from api.patients.answer import generate_answer +from api.patients.data import CohortSearchQuery, CohortSearchResult from api.patients.db import get_database -from api.patients.rag import EmbeddingManager, MilvusManager, retrieve_relevant_notes +from api.patients.rag import ( + EmbeddingManager, + MilvusManager, + NERManager, + RAGManager, + retrieve_relevant_notes, +) from api.users.auth import get_current_active_user from api.users.data import User @@ -36,13 +43,18 @@ EMBEDDING_SERVICE_URL = ( f"http://{EMBEDDING_SERVICE_HOST}:{EMBEDDING_SERVICE_PORT}/embeddings" ) +NER_SERVICE_PORT = os.getenv("NER_SERVICE_PORT", "8000") +NER_SERVICE_URL = f"http://clinical-ner-service-dev:{NER_SERVICE_PORT}/extract_entities" + EMBEDDING_MANAGER = EmbeddingManager(EMBEDDING_SERVICE_URL) MILVUS_MANAGER = MilvusManager(MILVUS_HOST, MILVUS_PORT) MILVUS_MANAGER.connect() +NER_MANAGER = NERManager(NER_SERVICE_URL) +RAG_MANAGER = RAGManager(EMBEDDING_MANAGER, MILVUS_MANAGER, NER_MANAGER) -@router.post("/generate_cot_answer") -async def generate_cot_answer_endpoint( +@router.post("/generate_answer") +async def generate_answer_endpoint( query: Query = Body(...), # noqa: B008 db: AsyncIOMotorDatabase = Depends(get_database), # noqa: B008 current_user: User = Depends(get_current_active_user), # noqa: B008 @@ -126,3 +138,43 @@ async def generate_cot_answer_endpoint( raise HTTPException( status_code=500, detail=f"An unexpected error occurred: {str(e)}" ) from e + + +@router.post("/cohort_search") +async def cohort_search_endpoint( + query: CohortSearchQuery = Body(...), # noqa: B008 + current_user: User = Depends(get_current_active_user), # noqa: B008 +) -> List[CohortSearchResult]: + """Perform a cohort search across all patients.""" + try: + logger.info(f"Received cohort search query: {query.query}") + + if not query.query: + raise ValueError("Query string is empty") + + cohort_results = await RAG_MANAGER.cohort_search(query.query, query.top_k) + logger.info(f"Found {len(cohort_results)} patients matching the query") + + return [ + CohortSearchResult( + patient_id=patient_id, + note_type=note_details["note_type"], + note_text=note_details["note_text"][ + :500 + ], # Limit to first 500 characters + timestamp=note_details["timestamp"], + similarity_score=note_details["distance"], + ) + for patient_id, note_details in cohort_results + ] + + except ValueError as ve: + logger.error(f"Validation error: {str(ve)}") + raise HTTPException(status_code=400, detail=str(ve)) from ve + except Exception as e: + logger.error( + f"Unexpected error in cohort_search_endpoint: {str(e)}", exc_info=True + ) + raise HTTPException( + status_code=500, detail=f"An unexpected error occurred: {str(e)}" + ) from e diff --git a/ui/src/app/answer/[id]/page.tsx b/ui/src/app/answer/[id]/page.tsx deleted file mode 100644 index 73886e9..0000000 --- a/ui/src/app/answer/[id]/page.tsx +++ /dev/null @@ -1,320 +0,0 @@ -'use client' - -import React, { useEffect, useState, useCallback, useMemo } from 'react' -import { useParams, useSearchParams } from 'next/navigation' -import { - Box, Flex, VStack, useColorModeValue, Container, Card, CardBody, - useToast, Skeleton, Text, Heading, Progress -} from '@chakra-ui/react' -import { motion, AnimatePresence } from 'framer-motion' -import Sidebar from '../../components/sidebar' -import { withAuth } from '../../components/with-auth' -import SearchBox from '../../components/search-box' -import AnswerCard from '../../components/answer-card' - -const MotionBox = motion(Box) - -interface Query { - query: string; - patient_id?: number; -} - -interface Answer { - answer: string; - reasoning: string; -} - -interface QueryAnswer { - query: Query; - answer?: Answer; - is_first: boolean; -} - -interface PageData { - id: string; - user_id: string; - query_answers: QueryAnswer[]; - created_at: string; - updated_at: string; -} - -interface SearchState { - isSearching: boolean; - answer: string | null; - reasoning: string | null; -} - -const AnswerPage: React.FC = () => { - const [pageData, setPageData] = useState(null) - const [isLoading, setIsLoading] = useState(true) - const [searchState, setSearchState] = useState({ - isSearching: false, - answer: null, - reasoning: null, - }) - const params = useParams() - const searchParams = useSearchParams() - const id = params?.id as string - const isNewQuery = searchParams?.get('new') === 'true' - const initialQuery = searchParams?.get('query') - const toast = useToast() - - const bgColor = useColorModeValue('gray.50', 'gray.900') - const cardBgColor = useColorModeValue('white', 'gray.800') - - const fetchPageData = useCallback(async (): Promise => { - setIsLoading(true) - try { - const token = localStorage.getItem('token') - if (!token) throw new Error('No token found') - - const response = await fetch(`/api/pages/${id}`, { - headers: { 'Authorization': `Bearer ${token}` }, - }) - - if (!response.ok) { - const errorData = await response.json() - throw new Error(`Failed to fetch page data: ${errorData.message}`) - } - - const data: PageData = await response.json() - setPageData(data) - return data - } catch (error) { - console.error('Error loading page data:', error) - toast({ - title: "Error", - description: error instanceof Error ? error.message : "An error occurred while loading page data", - status: "error", - duration: 3000, - isClosable: true, - }) - return null - } finally { - setIsLoading(false) - } - }, [id, toast]) - - const generateAnswer = useCallback(async (query: string, pageId: string, patientId?: number): Promise => { - try { - const token = localStorage.getItem('token') - if (!token) throw new Error('No token found') - - const answerResponse = await fetch('/api/generate_cot_answer', { - method: 'POST', - headers: { - 'Authorization': `Bearer ${token}`, - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ query, page_id: pageId, patient_id: patientId }), - signal: AbortSignal.timeout(180000) // 3 minutes timeout - }) - - if (!answerResponse.ok) { - const errorData = await answerResponse.json() - throw new Error(`Failed to generate answer: ${errorData.message || answerResponse.statusText}`) - } - - return await answerResponse.json() - } catch (error) { - console.error('Error generating answer:', error) - toast({ - title: "Error", - description: error instanceof Error ? error.message : "An error occurred while generating answer", - status: "error", - duration: 30000, - isClosable: true, - }) - return null - } - }, [toast]) - - const handleSearch = useCallback(async (query: string) => { - if (!query.trim()) { - toast({ - title: "Error", - description: "Please enter a query", - status: "error", - duration: 30000, - isClosable: true, - }) - return - } - - setSearchState(prev => ({ ...prev, isSearching: true, answer: null, reasoning: null })) - - try { - const answer = await generateAnswer(query, id) - - setSearchState(prev => ({ - ...prev, - answer: answer?.answer || null, - reasoning: answer?.reasoning || null, - isSearching: false - })) - - if (answer) { - setPageData(prevData => { - if (!prevData) return null - const updatedQueryAnswers = [ - ...prevData.query_answers, - { - query: { query }, - answer, - is_first: false - } - ] - return { ...prevData, query_answers: updatedQueryAnswers } - }) - } - } catch (error) { - console.error('Error:', error) - toast({ - title: "Error", - description: error instanceof Error ? error.message : "An error occurred", - status: "error", - duration: 30000, - isClosable: true, - }) - setSearchState(prev => ({ ...prev, isSearching: false })) - } - }, [generateAnswer, id, toast]) - - useEffect(() => { - const initializePage = async () => { - try { - const data = await fetchPageData(); - if (data) { - const firstQueryAnswer = data.query_answers[0]; - if (firstQueryAnswer && isNewQuery && initialQuery && !firstQueryAnswer.answer) { - setSearchState(prev => ({ ...prev, isSearching: true })); - const answer = await generateAnswer(initialQuery, id, firstQueryAnswer.query.patient_id); - setSearchState(prev => ({ - ...prev, - answer: answer?.answer || null, - reasoning: answer?.reasoning || null, - isSearching: false - })); - if (answer) { - setPageData(prevData => { - if (!prevData) return data; - const updatedQueryAnswers = prevData.query_answers.map((qa, index) => - index === 0 ? { ...qa, answer } : qa - ); - return { ...prevData, query_answers: updatedQueryAnswers }; - }); - } - } else if (firstQueryAnswer && firstQueryAnswer.answer) { - setSearchState(prev => ({ - ...prev, - answer: firstQueryAnswer.answer?.answer || null, - reasoning: firstQueryAnswer.answer?.reasoning || null, - isSearching: false - })); - } - } - } catch (error) { - console.error('Error initializing page:', error); - toast({ - title: "Error", - description: "Failed to initialize page. Please try refreshing.", - status: "error", - duration: 5000, - isClosable: true, - }); - setSearchState(prev => ({ ...prev, isSearching: false })); - } - }; - - initializePage(); - }, [fetchPageData, generateAnswer, isNewQuery, initialQuery, id, toast]); - - const firstQueryAnswer = useMemo(() => pageData?.query_answers[0], [pageData]); - const { isSearching, answer, reasoning } = searchState; - - return ( - - - - - - - {isLoading ? ( - - ) : firstQueryAnswer ? ( - - - Query - {firstQueryAnswer.query.query} - - - ) : ( - - - No page data found - - - )} - - - - {isSearching && ( - - - - - Generating Answer - - div': { - transitionDuration: '1.5s', - }, - }} - /> - - Analyzing query and formulating response... - - - - - )} - - - - {answer && ( - - - - )} - - - - - - - - - - ) -} - -export default withAuth(AnswerPage) diff --git a/ui/src/app/cohort/[id]/page.tsx b/ui/src/app/cohort/[id]/page.tsx new file mode 100644 index 0000000..86ac21c --- /dev/null +++ b/ui/src/app/cohort/[id]/page.tsx @@ -0,0 +1,178 @@ +'use client' + +import React, { useEffect, useState, useCallback } from 'react' +import { useParams, useSearchParams, useRouter } from 'next/navigation' +import { + Box, Flex, VStack, useColorModeValue, Container, Card, CardBody, + useToast, Text, Heading, SimpleGrid, Skeleton +} from '@chakra-ui/react' +import { motion, AnimatePresence } from 'framer-motion' +import Sidebar from '../../components/sidebar' +import { withAuth } from '../../components/with-auth' +import SearchBox from '../../components/search-box' +import PatientCohortCard from '../../components/patient-cohort-card' + +const MotionBox = motion(Box) + +interface CohortSearchResult { + patient_id: number; + note_text: string; + similarity_score: number; +} + +interface GroupedCohortResult { + patient_id: number; + total_notes: number; + notes_summary: CohortSearchResult[]; +} + +const CohortPage: React.FC = () => { + const [searchResults, setSearchResults] = useState([]) + const [isLoading, setIsLoading] = useState(false) + const params = useParams() + const searchParams = useSearchParams() + const router = useRouter() + const id = params?.id as string + const initialQuery = searchParams?.get('query') + const toast = useToast() + + const bgColor = useColorModeValue('gray.50', 'gray.900') + const cardBgColor = useColorModeValue('white', 'gray.800') + + const performCohortSearch = useCallback(async (query: string): Promise => { + setIsLoading(true) + try { + const token = localStorage.getItem('token') + if (!token) throw new Error('No token found') + + const response = await fetch('/api/cohort_search', { + method: 'POST', + headers: { + 'Authorization': `Bearer ${token}`, + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ query, top_k: 2 }), + }) + + if (!response.ok) { + const errorData = await response.json() + throw new Error(`Failed to perform cohort search: ${errorData.message}`) + } + + const results: CohortSearchResult[] = await response.json() + + const groupedResults = results.reduce((acc, result) => { + const existingPatient = acc.find(p => p.patient_id === result.patient_id) + if (existingPatient) { + existingPatient.total_notes += 1 + existingPatient.notes_summary.push(result) + } else { + acc.push({ + patient_id: result.patient_id, + total_notes: 1, + notes_summary: [result] + }) + } + return acc + }, [] as GroupedCohortResult[]) + + setSearchResults(groupedResults) + } catch (error) { + console.error('Error performing cohort search:', error) + toast({ + title: "Error", + description: error instanceof Error ? error.message : "An error occurred while performing cohort search", + status: "error", + duration: 5000, + isClosable: true, + }) + } finally { + setIsLoading(false) + } + }, [toast]) + + const handleSearch = useCallback(async (query: string) => { + if (!query.trim()) { + toast({ + title: "Error", + description: "Please enter a query", + status: "error", + duration: 3000, + isClosable: true, + }) + return + } + + await performCohortSearch(query) + }, [performCohortSearch, toast]) + + useEffect(() => { + if (initialQuery) { + performCohortSearch(initialQuery) + } + }, [initialQuery, performCohortSearch]) + + const handlePatientClick = useCallback((patientId: number) => { + router.push(`/patient/${patientId}`) + }, [router]) + + return ( + + + + + + + + + Cohort Search + {initialQuery || "Enter a query to search across all patients"} + + + + + + + {isLoading ? ( + + {[...Array(6)].map((_, index) => ( + + ))} + + ) : ( + + {searchResults.map((result) => ( + + ))} + + )} + + + + + + + + + + + ) +} + +export default withAuth(CohortPage) diff --git a/ui/src/app/components/entity-viz.tsx b/ui/src/app/components/entity-viz.tsx index 2044c9c..5346b7e 100644 --- a/ui/src/app/components/entity-viz.tsx +++ b/ui/src/app/components/entity-viz.tsx @@ -47,7 +47,7 @@ const EntityVisualization: React.FC = ({ text, entitie const hash = entityTypes.join('').split('').reduce((acc, char) => char.charCodeAt(0) + acc, 0); const colorIndex = hash % baseColors.length; - const shade = (hash % 3 + 1) * 100; // This will give us shades 100, 200, or 300 + const shade = (hash % 3 + 1) * 100; return `${baseColors[colorIndex]}.${shade}`; }, []); @@ -105,6 +105,12 @@ const EntityVisualization: React.FC = ({ text, entitie {entity.icd10.map(icd => `${icd.chapter}: ${icd.name}`).join(', ')} )} + {Object.entries(entity.meta_anns).map(([key, metaAnn]) => ( + + {metaAnn.name}: + {metaAnn.value} (Confidence: {metaAnn.confidence.toFixed(2)}) + + ))} diff --git a/ui/src/app/components/patient-cohort-card.tsx b/ui/src/app/components/patient-cohort-card.tsx new file mode 100644 index 0000000..e49f24f --- /dev/null +++ b/ui/src/app/components/patient-cohort-card.tsx @@ -0,0 +1,69 @@ +// ui/src/app/components/PatientCohortCard.tsx +import React from 'react'; +import { + Box, Card, CardBody, Heading, Badge, Text, useColorModeValue, + Accordion, AccordionItem, AccordionButton, AccordionPanel, AccordionIcon, + VStack, StackDivider +} from '@chakra-ui/react'; + +interface Note { + note_text: string; + similarity_score: number; +} + +interface PatientCohortCardProps { + patientId: number; + totalNotes: number; + notes: Note[]; + onCardClick: (patientId: number) => void; +} + +const PatientCohortCard: React.FC = ({ patientId, totalNotes, notes, onCardClick }) => { + const cardBgColor = useColorModeValue('white', 'gray.800'); + const noteBgColor = useColorModeValue('gray.50', 'gray.700'); + + return ( + onCardClick(patientId)} + _hover={{ transform: 'scale(1.02)', transition: 'transform 0.2s' }} + > + + + Patient ID: {patientId} + + Total Notes: {totalNotes} + + + e.stopPropagation()}> + + View Matching Notes + + + + + } + spacing={4} + align="stretch" + > + {notes.map((note, index) => ( + + {note.note_text} + + Similarity Score: {note.similarity_score.toFixed(4)} + + + ))} + + + + + + + ); +}; + +export default PatientCohortCard; diff --git a/ui/src/app/components/search-box.tsx b/ui/src/app/components/search-box.tsx index f2c7494..990a786 100644 --- a/ui/src/app/components/search-box.tsx +++ b/ui/src/app/components/search-box.tsx @@ -7,11 +7,17 @@ interface SearchBoxProps { onSearch: (query: string, isPatientMode: boolean) => void; isLoading: boolean; isPatientPage?: boolean; + isCohortPage?: boolean; } -const SearchBox: React.FC = ({ onSearch, isLoading, isPatientPage = false }) => { +const SearchBox: React.FC = ({ + onSearch, + isLoading, + isPatientPage = false, + isCohortPage = false +}) => { const [query, setQuery] = useState(''); - const [isPatientMode, setIsPatientMode] = useState(isPatientPage); + const [isPatientMode, setIsPatientMode] = useState(false); const handleQueryChange = useCallback((e: React.ChangeEvent) => { setQuery(e.target.value); @@ -36,7 +42,13 @@ const SearchBox: React.FC = ({ onSearch, isLoading, isPatientPag const bgColor = useColorModeValue('white', 'gray.800'); const buttonBgColor = useColorModeValue('#1f5280', '#3a7ab3'); - const patientTextColor = isPatientMode ? buttonBgColor : 'inherit'; + const modeTextColor = useColorModeValue('gray.600', 'gray.400'); + + const getPlaceholder = () => { + if (isPatientPage) return "Ask a question about this patient..."; + if (isCohortPage) return "Enter a query to search across all patients..."; + return isPatientMode ? "Enter patient ID" : "Ask a question about the cohort..."; + }; return ( @@ -52,7 +64,7 @@ const SearchBox: React.FC = ({ onSearch, isLoading, isPatientPag value={query} onChange={handleQueryChange} onKeyPress={handleKeyPress} - placeholder={isPatientPage ? "Ask a question about this patient..." : (isPatientMode ? "Enter patient ID" : "Ask a question...")} + placeholder={getPlaceholder()} minRows={3} maxRows={10} style={{ @@ -78,17 +90,16 @@ const SearchBox: React.FC = ({ onSearch, isLoading, isPatientPag align="center" flexDirection={{ base: 'column', sm: 'row' }} > - {!isPatientPage && ( + {!isPatientPage && !isCohortPage && ( + + {isPatientMode ? "Patient" : "Cohort"} + - - Patient - )}