Skip to content

Commit

Permalink
[feat] report tags complete: added db metadatas, changed db query logic
Browse files Browse the repository at this point in the history
  • Loading branch information
uommou committed Apr 7, 2024
1 parent a2bab8b commit 68d36ef
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 12 deletions.
25 changes: 21 additions & 4 deletions app/database/chroma_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,13 @@ async def search_db_query(query):
# 스프링 백엔드로부터 chroma DB에 저장할 데이터를 받아 DB에 추가한다.
async def add_db_data(schedule_data: AddScheduleDTO):
schedule_date = schedule_data.schedule_datetime_start.split("T")[0]
year = int(schedule_date.split("-")[0])
month = int(schedule_date.split("-")[1])
date = int(schedule_date.split("-")[2])
schedules.add(
documents=[schedule_data.data],
ids=[str(schedule_data.schedule_id)],
metadatas=[{"date": schedule_date, "datetime_start": schedule_data.schedule_datetime_start, "datetime_end": schedule_data.schedule_datetime_end, "member": schedule_data.member_id, "category": schedule_data.category, "location": schedule_data.location, "person": schedule_data.person}]
metadatas=[{"year": year, "month": month, "date": date, "datetime_start": schedule_data.schedule_datetime_start, "datetime_end": schedule_data.schedule_datetime_end, "member": schedule_data.member_id, "category": schedule_data.category, "location": schedule_data.location, "person": schedule_data.person}]
)
return True

Expand All @@ -62,14 +65,19 @@ async def db_daily_schedule(user_data: RecommendationMainRequestDTO):
schedule_datetime_start = user_data.schedule_datetime_start
schedule_datetime_end = user_data.schedule_datetime_end
schedule_date = schedule_datetime_start.split("T")[0]
year = int(schedule_date.split("-")[0])
month = int(schedule_date.split("-")[1])
date = int(schedule_date.split("-")[2])
persona = user_data.user_persona or "hard working"
results = schedules.query(
query_texts=[persona],
n_results=5,
where={"$and":
[
{"member": {"$eq": int(member)}},
{"date": {"$eq": schedule_date}}
{"year": {"$eq": year}},
{"month": {"$eq": month}},
{"date": {"$eq": date}}
]
}
# where_document={"$contains":"search_string"} # optional filter
Expand All @@ -82,14 +90,23 @@ async def db_monthly_tag_schedule(user_data: ReportTagsRequestDTO):
schedule_datetime_start = user_data.schedule_datetime_start
schedule_datetime_end = user_data.schedule_datetime_end
schedule_date = schedule_datetime_start.split("T")[0]
year = int(schedule_date.split("-")[0])
month = int(schedule_date.split("-")[1])
date = int(schedule_date.split("-")[2])
persona = user_data.user_persona or "hard working"
results = schedules.query(
query_texts=[persona],
n_results=5,
n_results=15,
where={"$and":
[
{"member": {"$eq": int(member)}},
{"date": {"$eq": schedule_date}}
{"year": {"$eq": year}},
{"$or":
[{"$and":
[{"month": {"$eq": month-1}}, {"date": {"$gte": 10}}]},
{"$and":
[{"month": {"$eq": month}}, {"date": {"$lt": 10}}]}
]}
]
}
# where_document={"$contains":"search_string"} # optional filter
Expand Down
5 changes: 4 additions & 1 deletion app/dto/openai_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,7 @@ class ChatResponse(BaseModel):

class ChatCaseResponse(BaseModel):
ness: str
case: int
case: int

class TagsResponse(BaseModel):
tags: list
7 changes: 6 additions & 1 deletion app/prompt/openai_config.ini
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,9 @@ MODEL_NAME = gpt-4
[NESS_RECOMMENDATION]
TEMPERATURE = 0
MAX_TOKENS = 2048
MODEL_NAME = gpt-3.5-turbo-1106
MODEL_NAME = gpt-3.5-turbo-1106

[NESS_TAGS]
TEMPERATURE = 0
MAX_TOKENS = 2048
MODEL_NAME = gpt-4
18 changes: 18 additions & 0 deletions app/prompt/report_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,22 @@ class Template:
User schedule: {schedule}
AI Recommendation:
"""

report_tags_template = """
You are an AI assistant tasked with analyzing a user's schedule over the span of a month. From this detailed schedule, you will distill three keywords that best encapsulate the user's activities, interests, or achievements throughout the month. These keywords should not only reflect the user's endeavors but also convey a sense of accomplishment and enjoyment. Your output should be engaging, showcasing your wit and unique perspective. Here are the rules for your analysis:
YOU MUST USE {output_language} TO RESPOND TO THE INPUT.
YOU MUST PROVIDE THREE KEYWORDS in your response. Each keyword should be a single word or a concise phrase.
The keywords must capture the essence of the user's monthly activities, highlighting aspects that are both rewarding and enjoyable.
Your selections should be creative and personalized, aiming to reflect the user's unique experiences over the month.
Example:
User's monthly schedule: [Attended a programming bootcamp, Completed a marathon, Read three novels, Volunteered at the local food bank, Started a blog about sustainability]
AI Recommendation: "공부 매니아, 환경 지킴이, 자기 계발 홀릭"
User's monthly schedule: [Took photography classes, Explored three new hiking trails, Organized a neighborhood clean-up, Experimented with vegan recipes]
AI Recommendation: "모험가, 미식가, 도파민 중독자"
User's monthly schedule: {schedule}
AI Recommendation:
"""
15 changes: 9 additions & 6 deletions app/routers/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from langchain_core.prompts import PromptTemplate

from app.dto.db_dto import ReportMemoryEmojiRequestDTO, ReportTagsRequestDTO
from app.dto.openai_dto import ChatResponse
from app.dto.openai_dto import ChatResponse, TagsResponse
from app.prompt import report_prompt
import app.database.chroma_db as vectordb

Expand Down Expand Up @@ -54,10 +54,10 @@ async def get_memory_emoji(user_data: ReportMemoryEmojiRequestDTO) -> ChatRespon
raise HTTPException(status_code=500, detail=str(e))

@router.post("/tags", status_code=status.HTTP_200_OK)
async def get_tags(user_data: ReportTagsRequestDTO) -> ChatResponse:
async def get_tags(user_data: ReportTagsRequestDTO) -> TagsResponse:
try:
# 모델
config_tags = config['NESS_RECOMMENDATION']
config_tags = config['NESS_TAGS']

chat_model = ChatOpenAI(temperature=config_tags['TEMPERATURE'], # 창의성 (0.0 ~ 2.0)
max_tokens=config_tags['MAX_TOKENS'], # 최대 토큰수
Expand All @@ -71,12 +71,15 @@ async def get_tags(user_data: ReportTagsRequestDTO) -> ChatResponse:
print(schedule)

# 템플릿
memory_emoji_template = report_prompt.Template.memory_emoji_template
report_tags_template = report_prompt.Template.report_tags_template

prompt = PromptTemplate.from_template(memory_emoji_template)
prompt = PromptTemplate.from_template(report_tags_template)
result = chat_model.predict(prompt.format(output_language="Korean", schedule=schedule))
print(result)
return ChatResponse(ness=result)
tags = result.split("\"")[1].split(",")
tags = [tag.strip() for tag in tags]
print(tags)
return TagsResponse(tags=tags)

except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

0 comments on commit 68d36ef

Please sign in to comment.