Skip to content

Commit

Permalink
[feat] tag generation rough sketch
Browse files Browse the repository at this point in the history
  • Loading branch information
uommou committed Apr 7, 2024
1 parent 5f93d90 commit a2bab8b
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 2 deletions.
22 changes: 21 additions & 1 deletion app/database/chroma_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import os
import datetime
from dotenv import load_dotenv
from app.dto.db_dto import AddScheduleDTO, RecommendationMainRequestDTO
from app.dto.db_dto import AddScheduleDTO, RecommendationMainRequestDTO, ReportTagsRequestDTO

load_dotenv()
CHROMA_DB_IP_ADDRESS = os.getenv("CHROMA_DB_IP_ADDRESS")
Expand Down Expand Up @@ -76,5 +76,25 @@ async def db_daily_schedule(user_data: RecommendationMainRequestDTO):
)
return results['documents']

# 태그 생성용 스케쥴 반환 - 카테고리에 따라
async def db_monthly_tag_schedule(user_data: ReportTagsRequestDTO):
member = user_data.member_id
schedule_datetime_start = user_data.schedule_datetime_start
schedule_datetime_end = user_data.schedule_datetime_end
schedule_date = schedule_datetime_start.split("T")[0]
persona = user_data.user_persona or "hard working"
results = schedules.query(
query_texts=[persona],
n_results=5,
where={"$and":
[
{"member": {"$eq": int(member)}},
{"date": {"$eq": schedule_date}}
]
}
# where_document={"$contains":"search_string"} # optional filter
)
return results['documents']

def get_chroma_client():
return chroma_client
7 changes: 7 additions & 0 deletions app/dto/db_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,10 @@ class ReportMemoryEmojiRequestDTO(BaseModel):
user_persona: str
schedule_datetime_start: str
schedule_datetime_end: str

class ReportTagsRequestDTO(BaseModel):
member_id: int
user_persona: str
schedule_datetime_start: str
schedule_datetime_end: str

29 changes: 28 additions & 1 deletion app/routers/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from langchain_community.chat_models import ChatOpenAI
from langchain_core.prompts import PromptTemplate

from app.dto.db_dto import ReportMemoryEmojiRequestDTO
from app.dto.db_dto import ReportMemoryEmojiRequestDTO, ReportTagsRequestDTO
from app.dto.openai_dto import ChatResponse
from app.prompt import report_prompt
import app.database.chroma_db as vectordb
Expand Down Expand Up @@ -53,3 +53,30 @@ async def get_memory_emoji(user_data: ReportMemoryEmojiRequestDTO) -> ChatRespon
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

@router.post("/tags", status_code=status.HTTP_200_OK)
async def get_tags(user_data: ReportTagsRequestDTO) -> ChatResponse:
try:
# 모델
config_tags = config['NESS_RECOMMENDATION']

chat_model = ChatOpenAI(temperature=config_tags['TEMPERATURE'], # 창의성 (0.0 ~ 2.0)
max_tokens=config_tags['MAX_TOKENS'], # 최대 토큰수
model_name=config_tags['MODEL_NAME'], # 모델명
openai_api_key=OPENAI_API_KEY # API 키
)

# vectordb에서 유저의 정보를 가져온다.
schedule = await vectordb.db_monthly_tag_schedule(user_data)

print(schedule)

# 템플릿
memory_emoji_template = report_prompt.Template.memory_emoji_template

prompt = PromptTemplate.from_template(memory_emoji_template)
result = chat_model.predict(prompt.format(output_language="Korean", schedule=schedule))
print(result)
return ChatResponse(ness=result)

except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

0 comments on commit a2bab8b

Please sign in to comment.