Skip to content

Commit

Permalink
Merge pull request #34 from studio-recoding/feat/cicd
Browse files Browse the repository at this point in the history
[fix] 의존성 수정
  • Loading branch information
uommou authored Mar 24, 2024
2 parents b80a56f + c2215ac commit d45cd8d
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 24 deletions.
14 changes: 7 additions & 7 deletions app/database/chroma_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from chromadb.utils import embedding_functions
from chromadb.config import Settings

# FAST API
from fastapi import Depends

# ETC
import os
import datetime
Expand All @@ -11,12 +14,6 @@

load_dotenv()
CHROMA_DB_IP_ADDRESS = os.getenv("CHROMA_DB_IP_ADDRESS")
# # 로컬에 ChromaDB가 저장된 경우
# DIR = os.path.dirname(os.path.abspath(__file__))
# DB_PATH = os.path.join(DIR, 'data')
# # 로컬 디스크에서 ChromaDB 실행: allow_reset으로 데이터베이스 초기화, anonymized_telemetry를 False로 설정하여 텔레메트리 수집 비활성화
# chroma_client = chromadb.PersistentClient(path=DB_PATH,
# settings=Settings(allow_reset=True, anonymized_telemetry=False))

# description: 원격 EC2 인스턴스에서 ChromaDB에 연결
chroma_client = chromadb.HttpClient(host=CHROMA_DB_IP_ADDRESS, port=8000)
Expand Down Expand Up @@ -55,4 +52,7 @@ async def add_db_data(schedule_data: AddScheduleDTO):
ids=[str(schedule_data.schedule_id)],
metadatas=[{"datetime": schedule_data.schedule_datetime, "member": schedule_data.member_id, "category": schedule_data.category}]
)
return True
return True

def get_chroma_client():
    """FastAPI dependency provider for the ChromaDB connection.

    Returns the module-level HTTP client (connected to the remote EC2
    ChromaDB instance) so route handlers can receive it via
    ``Depends(get_chroma_client)`` instead of importing the global directly.
    """
    client = chroma_client
    return client
3 changes: 3 additions & 0 deletions app/dto/openai_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@

class PromptRequest(BaseModel):
    """Request body for the chat endpoints: a single user prompt string."""

    # Raw prompt text forwarded to the LangChain/OpenAI chat model.
    prompt: str

class ChatResponse(BaseModel):
    """Response body wrapping the chat model's reply text."""

    # Model-generated answer; "ness" appears to be the assistant's
    # project-specific name — confirm with the API consumers.
    ness: str
2 changes: 1 addition & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dotenv import load_dotenv

# BACKEND
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, Depends

# ETC
import os
Expand Down
20 changes: 10 additions & 10 deletions app/routers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import os

from dotenv import load_dotenv
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, status
from langchain_community.chat_models import ChatOpenAI
from langchain_core.prompts import PromptTemplate

from app.dto import openai_dto
from app.dto.openai_dto import PromptRequest, ChatResponse
from app.prompt import openai_prompt

import app.database.chroma_db as vectordb
Expand All @@ -26,8 +26,8 @@
config = configparser.ConfigParser()
config.read(CONFIG_FILE_PATH)

@router.post("/case")
async def get_langchain_case(data: openai_dto.PromptRequest):
@router.post("/case", status_code=status.HTTP_200_OK, response_model=ChatResponse)
async def get_langchain_case(data: PromptRequest) -> ChatResponse:
# description: use langchain

config_normal = config['NESS_NORMAL']
Expand Down Expand Up @@ -64,7 +64,7 @@ async def get_langchain_case(data: openai_dto.PromptRequest):

# case 1 : normal
#@router.post("/case/normal") # 테스트용 엔드포인트
async def get_langchain_normal(data: openai_dto.PromptRequest): # case 1 : normal
async def get_langchain_normal(data: PromptRequest) -> ChatResponse: # case 1 : normal
print("running case 1")
# description: use langchain
chat_model = ChatOpenAI(temperature=0, # 창의성 (0.0 ~ 2.0)
Expand All @@ -82,11 +82,11 @@ async def get_langchain_normal(data: openai_dto.PromptRequest): # case 1 : norma
prompt = PromptTemplate.from_template(my_template)
response = chat_model.predict(prompt.format(output_language="Korean", question=question))
print(response)
return response
return ChatResponse(ness=response)

# case 2 : 일정 생성
#@router.post("/case/make_schedule") # 테스트용 엔드포인트
async def get_langchain_schedule(data: openai_dto.PromptRequest):
async def get_langchain_schedule(data: PromptRequest) -> ChatResponse:
print("running case 2")
# description: use langchain
chat_model = ChatOpenAI(temperature=0, # 창의성 (0.0 ~ 2.0)
Expand All @@ -100,11 +100,11 @@ async def get_langchain_schedule(data: openai_dto.PromptRequest):
prompt = PromptTemplate.from_template(case2_template)
response = chat_model.predict(prompt.format(output_language="Korean", question=question))
print(response)
return response
return ChatResponse(ness=response)

# case 3 : rag
#@router.post("/case/rag") # 테스트용 엔드포인트
async def get_langchain_rag(data: openai_dto.PromptRequest):
async def get_langchain_rag(data: PromptRequest) -> ChatResponse:
print("running case 3")
# description: use langchain
chat_model = ChatOpenAI(temperature=0, # 창의성 (0.0 ~ 2.0)
Expand All @@ -126,4 +126,4 @@ async def get_langchain_rag(data: openai_dto.PromptRequest):
# 비동기 함수라면, 예: response = await chat_model.predict(...) 형태로 수정해야 합니다.
response = chat_model.predict(prompt.format(output_language="Korean", question=question, schedule=schedule))
print(response)
return response
return ChatResponse(ness=response)
12 changes: 6 additions & 6 deletions app/routers/chromadb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import os

from dotenv import load_dotenv
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Depends

from app.dto.db_dto import AddScheduleDTO
from app.database.chroma_db import add_db_data, get_chroma_client

router = APIRouter(
prefix="/chromadb",
Expand All @@ -22,11 +23,10 @@


@router.post("/add_schedule")
async def add_schedule_endpoint(schedule_data: AddScheduleDTO, vectordb=None):
async def add_schedule_endpoint(schedule_data: AddScheduleDTO, chroma_client=Depends(get_chroma_client)):
try:
# vectordb.add_db_data 함수를 비동기적으로 호출합니다.
await vectordb.add_db_data(schedule_data)
# 직접 `add_db_data` 함수를 비동기적으로 호출합니다.
await add_db_data(schedule_data)
return {"message": "Schedule added successfully"}
except Exception as e:
# 에러 처리: 에러가 발생하면 HTTP 500 응답을 반환합니다.
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e))

0 comments on commit d45cd8d

Please sign in to comment.