Skip to content

Commit

Permalink
Merge pull request #61 from studio-recoding/feat/category
Browse files Browse the repository at this point in the history
feat: 카테고리 분류 추가
  • Loading branch information
uommou authored May 25, 2024
2 parents cdb0c5a + 62163fb commit c933d4c
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 45 deletions.
31 changes: 31 additions & 0 deletions app/database/connect_rds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import os
import pymysql
from pymysql.cursors import DictCursor
from dotenv import load_dotenv

# .env 파일에서 환경 변수 로드
load_dotenv()

def get_rds_connection():
return pymysql.connect(
host=os.getenv('RDS_HOST'),
user=os.getenv('RDS_USER'),
password=os.getenv('RDS_PASSWORD'),
database=os.getenv('RDS_DATABASE'),
cursorclass=DictCursor
)

def fetch_category_classification_data(member_id):
connection = get_rds_connection()
try:
with connection.cursor() as cursor:
sql = """
SELECT c.*
FROM category c
WHERE c.member_id = %s
"""
cursor.execute(sql, (member_id,))
result = cursor.fetchall()
return result
finally:
connection.close()
57 changes: 31 additions & 26 deletions app/prompt/openai_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,37 +86,42 @@ class Template:
YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. Current time is {current_time}. Respond to the user considering the current time.
User input: {question}
"""

case2_template = """
{persona}
{chat_type}
The user's input contains information about a new event they want to add to their schedule. You have two tasks to perform:
{persona}
{chat_type}
The user's input contains information about a new event they want to add to their schedule. You have two tasks to perform:
1. Respond kindly to the user's input. YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT.
2. Organize the event the user wants to add into a json format for saving in a database. The returned json will have keys for info, location, person, and date.
- info: Summarizes what the user wants to do. This value must always be present.
- location: If the user's event information includes a place, save that place as the value.
- person: If th e user's event mentions a person they want to include, save that person as the value.
- date: If the user's event information includes a specific date and time, save that date and time in datetime format. Dates should be organized based on the current time at the user's location. Current time is {current_time}.
Separate the outputs for tasks 1 and 2 with a special token <separate>.
1. Respond kindly to the user's input. YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT.
2. Organize the event the user wants to add into a json format for saving in a database. The returned json will have keys for info, location, person, start_time, end_time, and category.
- info: Summarizes what the user wants to do. This value must always be present.
- location: If the user's event information includes a place, save that place as the value.
- person: If the user's event mentions a person they want to include, save that person as the value.
- start_time: If the user's event information includes a specific date and time, save that date and time in datetime format. Dates should be organized based on the current time at the user's location. Current time is {current_time}.
- end_time: If the user's event information includes an end time, save that date and time in datetime format.
- category: Choose the most appropriate category for the event from the following list: {categories}.
Separate the outputs for tasks 1 and 2 with a special token <separate>.
Example for one-shot learning:
Example for one-shot learning:
User input: I have a meeting with Dr. Smith at her office on March 3rd at 10am.
User input: I have a meeting with Dr. Smith at her office on March 3rd from 10am to 11am.
Response to user:
Shall I add your meeting with Dr. Smith at her office on March 3rd at 10am to your schedule?
<separate>
{{
"info": "meeting with Dr. Smith",
"location": "Dr. Smith's office",
"person": "Dr. Smith",
"date": "2023-03-03T10:00:00"
}}
User input: {question}
Response to user:
"""
Response to user:
Shall I add your meeting with Dr. Smith at her office on March 3rd from 10am to 11am to your schedule?
<separate>
{{
"info": "meeting with Dr. Smith",
"location": "Dr. Smith's office",
"person": "Dr. Smith",
"start_time": "2023-03-03T10:00:00",
"end_time": "2023-03-03T11:00:00",
"category": "Work"
}}
User input: {question}
Response to user:
"""

case3_template = """
{persona}
Expand Down
62 changes: 43 additions & 19 deletions app/routers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from langchain_core.prompts import PromptTemplate
from datetime import datetime

from app.database.connect_rds import fetch_category_classification_data
from app.dto.openai_dto import PromptRequest, ChatResponse, ChatCaseResponse
from app.prompt import openai_prompt, persona_prompt

Expand Down Expand Up @@ -102,25 +103,48 @@ async def get_langchain_normal(data: PromptRequest, chat_type_prompt): # case 1
# case 2 : 일정 생성
#@router.post("/case/make_schedule") # 테스트용 엔드포인트
async def get_langchain_schedule(data: PromptRequest, chat_type_prompt):
print("running case 2")
# description: use langchain
config_normal = config['NESS_NORMAL']

chat_model = ChatOpenAI(temperature=config_normal['TEMPERATURE'], # 창의성 (0.0 ~ 2.0)
max_tokens=config_normal['MAX_TOKENS'], # 최대 토큰수
model_name=config_normal['MODEL_NAME'], # 모델명
openai_api_key=OPENAI_API_KEY # API 키
)
question = data.prompt
persona = data.persona
user_persona_prompt = persona_prompt.Template.from_persona(persona)
case2_template = openai_prompt.Template.case2_template

prompt = PromptTemplate.from_template(case2_template)
current_time = datetime.now()
response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, current_time=current_time, chat_type=chat_type_prompt))
print(response)
return response
try:
print("running case 2")
member_id = data.member_id
categories = fetch_category_classification_data(member_id)
# 카테고리 데이터를 텍스트로 변환 (JSON 형식으로 변환)
categories_text = ", ".join([category['name'] for category in categories])
print(categories_text)

# description: use langchain
config_normal = config['NESS_NORMAL']

chat_model = ChatOpenAI(temperature=config_normal['TEMPERATURE'], # 창의성 (0.0 ~ 2.0)
max_tokens=config_normal['MAX_TOKENS'], # 최대 토큰수
model_name=config_normal['MODEL_NAME'], # 모델명
openai_api_key=OPENAI_API_KEY # API 키
)

question = data.prompt
persona = data.persona
user_persona_prompt = persona_prompt.Template.from_persona(persona)
case2_template = openai_prompt.Template.case2_template

prompt = PromptTemplate.from_template(case2_template)
current_time = datetime.now()

# OpenAI 프롬프트에 데이터 통합
response = chat_model.predict(
prompt.format(
persona=user_persona_prompt,
output_language="Korean",
question=question,
current_time=current_time,
chat_type=chat_type_prompt,
categories=categories_text # 카테고리 데이터를 프롬프트에 포함
)
)

print(response)
return response

except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

# case 3 : rag
#@router.post("/case/rag") # 테스트용 엔드포인트
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ python-dotenv==1.0.0
starlette==0.35.1
pydantic==2.5.3
sentence-transformers==2.5.1
pymysql

0 comments on commit c933d4c

Please sign in to comment.