diff --git a/app/database/connect_rds.py b/app/database/connect_rds.py new file mode 100644 index 0000000..c1a9ee4 --- /dev/null +++ b/app/database/connect_rds.py @@ -0,0 +1,31 @@ +import os +import pymysql +from pymysql.cursors import DictCursor +from dotenv import load_dotenv + +# .env 파일에서 환경 변수 로드 +load_dotenv() + +def get_rds_connection(): + return pymysql.connect( + host=os.getenv('RDS_HOST'), + user=os.getenv('RDS_USER'), + password=os.getenv('RDS_PASSWORD'), + database=os.getenv('RDS_DATABASE'), + cursorclass=DictCursor + ) + +def fetch_category_classification_data(member_id): + connection = get_rds_connection() + try: + with connection.cursor() as cursor: + sql = """ + SELECT c.* + FROM category c + WHERE c.member_id = %s + """ + cursor.execute(sql, (member_id,)) + result = cursor.fetchall() + return result + finally: + connection.close() diff --git a/app/prompt/openai_prompt.py b/app/prompt/openai_prompt.py index 9ee23ca..2fff551 100644 --- a/app/prompt/openai_prompt.py +++ b/app/prompt/openai_prompt.py @@ -86,37 +86,42 @@ class Template: YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. Current time is {current_time}. Respond to the user considering the current time. User input: {question} """ + case2_template = """ - {persona} - {chat_type} - The user's input contains information about a new event they want to add to their schedule. You have two tasks to perform: + {persona} + {chat_type} + The user's input contains information about a new event they want to add to their schedule. You have two tasks to perform: - 1. Respond kindly to the user's input. YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. - 2. Organize the event the user wants to add into a json format for saving in a database. The returned json will have keys for info, location, person, and date. - - info: Summarizes what the user wants to do. This value must always be present. - - location: If the user's event information includes a place, save that place as the value. - - person: If th e user's event mentions a person they want to include, save that person as the value. - - date: If the user's event information includes a specific date and time, save that date and time in datetime format. Dates should be organized based on the current time at the user's location. Current time is {current_time}. - Separate the outputs for tasks 1 and 2 with a special token . + 1. Respond kindly to the user's input. YOU MUST USE {output_language} TO RESPOND TO THE USER INPUT. + 2. Organize the event the user wants to add into a json format for saving in a database. The returned json will have keys for info, location, person, start_time, end_time, and category. + - info: Summarizes what the user wants to do. This value must always be present. + - location: If the user's event information includes a place, save that place as the value. + - person: If the user's event mentions a person they want to include, save that person as the value. + - start_time: If the user's event information includes a specific date and time, save that date and time in datetime format. Dates should be organized based on the current time at the user's location. Current time is {current_time}. + - end_time: If the user's event information includes an end time, save that date and time in datetime format. + - category: Choose the most appropriate category for the event from the following list: {categories}. + Separate the outputs for tasks 1 and 2 with a special token . - Example for one-shot learning: + Example for one-shot learning: - User input: I have a meeting with Dr. Smith at her office on March 3rd at 10am. + User input: I have a meeting with Dr. Smith at her office on March 3rd from 10am to 11am. - Response to user: - Shall I add your meeting with Dr. Smith at her office on March 3rd at 10am to your schedule? - - {{ - "info": "meeting with Dr. Smith", - "location": "Dr. Smith's office", - "person": "Dr. Smith", - "date": "2023-03-03T10:00:00" - }} - - User input: {question} - - Response to user: - """ + Response to user: + Shall I add your meeting with Dr. Smith at her office on March 3rd from 10am to 11am to your schedule? + + {{ + "info": "meeting with Dr. Smith", + "location": "Dr. Smith's office", + "person": "Dr. Smith", + "start_time": "2023-03-03T10:00:00", + "end_time": "2023-03-03T11:00:00", + "category": "Work" + }} + + User input: {question} + + Response to user: + """ case3_template = """ {persona} diff --git a/app/routers/chat.py b/app/routers/chat.py index 8828d4a..70fb275 100644 --- a/app/routers/chat.py +++ b/app/routers/chat.py @@ -8,6 +8,7 @@ from langchain_core.prompts import PromptTemplate from datetime import datetime +from app.database.connect_rds import fetch_category_classification_data from app.dto.openai_dto import PromptRequest, ChatResponse, ChatCaseResponse from app.prompt import openai_prompt, persona_prompt @@ -102,25 +103,48 @@ async def get_langchain_normal(data: PromptRequest, chat_type_prompt): # case 1 # case 2 : 일정 생성 #@router.post("/case/make_schedule") # 테스트용 엔드포인트 async def get_langchain_schedule(data: PromptRequest, chat_type_prompt): - print("running case 2") - # description: use langchain - config_normal = config['NESS_NORMAL'] - - chat_model = ChatOpenAI(temperature=config_normal['TEMPERATURE'], # 창의성 (0.0 ~ 2.0) - max_tokens=config_normal['MAX_TOKENS'], # 최대 토큰수 - model_name=config_normal['MODEL_NAME'], # 모델명 - openai_api_key=OPENAI_API_KEY # API 키 - ) - question = data.prompt - persona = data.persona - user_persona_prompt = persona_prompt.Template.from_persona(persona) - case2_template = openai_prompt.Template.case2_template - - prompt = PromptTemplate.from_template(case2_template) - current_time = datetime.now() - response = chat_model.predict(prompt.format(persona=user_persona_prompt, output_language="Korean", question=question, current_time=current_time, chat_type=chat_type_prompt)) - print(response) - return response + try: + print("running case 2") + member_id = data.member_id + categories = fetch_category_classification_data(member_id) + # 카테고리 데이터를 텍스트로 변환 (JSON 형식으로 변환) + categories_text = ", ".join([category['name'] for category in categories]) + print(categories_text) + + # description: use langchain + config_normal = config['NESS_NORMAL'] + + chat_model = ChatOpenAI(temperature=config_normal['TEMPERATURE'], # 창의성 (0.0 ~ 2.0) + max_tokens=config_normal['MAX_TOKENS'], # 최대 토큰수 + model_name=config_normal['MODEL_NAME'], # 모델명 + openai_api_key=OPENAI_API_KEY # API 키 + ) + + question = data.prompt + persona = data.persona + user_persona_prompt = persona_prompt.Template.from_persona(persona) + case2_template = openai_prompt.Template.case2_template + + prompt = PromptTemplate.from_template(case2_template) + current_time = datetime.now() + + # OpenAI 프롬프트에 데이터 통합 + response = chat_model.predict( + prompt.format( + persona=user_persona_prompt, + output_language="Korean", + question=question, + current_time=current_time, + chat_type=chat_type_prompt, + categories=categories_text # 카테고리 데이터를 프롬프트에 포함 + ) + ) + + print(response) + return response + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) # case 3 : rag #@router.post("/case/rag") # 테스트용 엔드포인트 diff --git a/requirements.txt b/requirements.txt index e6ace11..4f4da86 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ python-dotenv==1.0.0 starlette==0.35.1 pydantic==2.5.3 sentence-transformers==2.5.1 +pymysql \ No newline at end of file