Skip to content

Commit

Permalink
Merge pull request #38 from studio-recoding/feat/recommendation
Browse files Browse the repository at this point in the history
[feat] 메인페이지 추천 코드 작성 완료
  • Loading branch information
uommou authored Mar 31, 2024
2 parents 69396f9 + f292ce1 commit 55f776a
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 8 deletions.
23 changes: 20 additions & 3 deletions app/database/chroma_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import os
import datetime
from dotenv import load_dotenv
from app.dto.db_dto import AddScheduleDTO
from app.dto.db_dto import AddScheduleDTO, RecommendationMainRequestDTO

load_dotenv()
CHROMA_DB_IP_ADDRESS = os.getenv("CHROMA_DB_IP_ADDRESS")
Expand All @@ -34,13 +34,13 @@
def check_db_heartbeat():
chroma_client.heartbeat()

# description: DB에서 검색하는 함수
# description: DB에서 검색하는 함수 - chat case 3에 사용
async def search_db_query(query):
# 컬렉션 생성
# 컬렉션에 쿼리 전송
result = schedules.query(
query_texts=query,
n_results=2 # 결과에서 한 가지 문서만 반환하면 한강공원이, 두 가지 문서 반환하면 AI가 뜸->유사도가 이상하게 검사되는 것 같음
n_results=5 # 결과에서 한 가지 문서만 반환하면 한강공원이, 두 가지 문서 반환하면 AI가 뜸->유사도가 이상하게 검사되는 것 같음
)
return result

Expand All @@ -54,5 +54,22 @@ async def add_db_data(schedule_data: AddScheduleDTO):
)
return True

# 메인페이지 한 줄 추천 기능에 사용하는 함수
# 유저의 id, 해당 날짜로 필터링
async def db_recommendation_main(user_data: RecommendationMainRequestDTO):
results = schedules.query(
user_persona=["hard working"],
n_results=5,
where={"$and" :
[
{"member": {"$eq": int(user_data.member_id)}},
{"datetime_start": {
"$gte": user_data.schedule_datetime_start, # greater than or equal
"$lt": user_data.schedule_datetime_end # less than
}}
]}
# where_document={"$contains":"search_string"} # optional filter
)

def get_chroma_client():
return chroma_client
6 changes: 6 additions & 0 deletions app/dto/db_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,9 @@ class AddScheduleDTO(BaseModel):
location: str
person: str

class RecommendationMainRequestDTO(BaseModel):
member_id: int
user_persona: str
schedule_datetime_start: str
schedule_datetime_end: str

20 changes: 19 additions & 1 deletion app/prompt/openai_prompt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
class Template:
recommendation_template = """한 줄 추천 기능 템플릿"""
recommendation_template = """
You are an AI assistant designed to recommend daily activities based on a user's schedule. You will receive a day's worth of the user's schedule information. Your task is to understand that schedule and, based on it, recommend an activity for the user to perform that day. There are a few rules you must follow in your recommendations:
1. YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT.
2. Ensure your recommendation is encouraging, so the user doesn't feel compelled.
3. The recommendation must be concise, limited to one sentence without any additional commentary.
Example:
User schedule: 8:00 AM - 9:00 AM: Gym
9:30 AM - 12:00 PM: Work meetings
12:00 PM - 1:00 PM: Lunch break
1:00 PM - 5:00 PM: Work on project
5:30 PM - 7:00 PM: Free time
7:00 PM - 9:00 PM: Dinner with family
9:30 PM: Free time
AI Recommendation: "Since you've got some free time before dinner, how about taking a short walk in the park to relax and clear your mind?"
User schedule: {schedule}
AI Recommendation:
"""
# case 분류 잘 안됨 - 수정 필요
case_classify_template = """
Task: User Chat Classification
Expand Down
15 changes: 11 additions & 4 deletions app/routers/recommendation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
import os

from dotenv import load_dotenv
from fastapi import APIRouter
from fastapi import APIRouter, Depends, status
from langchain_community.chat_models import ChatOpenAI
from langchain_core.prompts import PromptTemplate

from app.dto.db_dto import RecommendationMainRequestDTO
from app.dto.openai_dto import ChatResponse
from app.prompt import openai_prompt
import app.database.chroma_db as vectordb

router = APIRouter(
prefix="/recommendation",
Expand All @@ -22,8 +25,8 @@
config = configparser.ConfigParser()
config.read(CONFIG_FILE_PATH)

@router.get("/main")
async def get_recommendation():
@router.get("/main", status_code=status.HTTP_200_OK)
async def get_recommendation(user_data: RecommendationMainRequestDTO) -> ChatResponse:

# 모델
chat_model = ChatOpenAI(temperature=0, # 창의성 (0.0 ~ 2.0)
Expand All @@ -32,9 +35,13 @@ async def get_recommendation():
openai_api_key=OPENAI_API_KEY # API 키
)

# vectordb에서 유저의 정보를 가져온다.
schedule = await vectordb.db_recommendation_main(user_data)

# 템플릿
recommendation_template = openai_prompt.Template.recommendation_template

prompt = PromptTemplate.from_template(recommendation_template)
return chat_model.predict(prompt.format())
result = await chat_model.predict(prompt.format(output_language="Korean", schedule=schedule))
print(result)
return ChatResponse(ness=result)

0 comments on commit 55f776a

Please sign in to comment.