diff --git a/.env.example b/.env.example
deleted file mode 100644
index 9c621033f..000000000
--- a/.env.example
+++ /dev/null
@@ -1,4 +0,0 @@
-OPENAI_MODEL_ENGINE = 'gpt-3.5-turbo'
-SYSTEM_MESSAGE = 'You are a helpful assistant.'
-LINE_CHANNEL_SECRET =
-LINE_CHANNEL_ACCESS_TOKEN =
\ No newline at end of file
diff --git a/.gcloudignore b/.gcloudignore
new file mode 100644
index 000000000..5e4d69bba
--- /dev/null
+++ b/.gcloudignore
@@ -0,0 +1,9 @@
+.git/
+
+!config/config.yml
+!config/ssl/ca-cert.crt
+!config/ssl/client-cert.pem
+!config/ssl/client-key.pem
+!config/ssl/server-ca.pem
+!config/ssl/ssl-cert.crt
+!config/ssl/ssl-key.key
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index bd456a4c3..cdae211cf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -74,6 +74,12 @@ web_modules/
 .env
 .env.test
 
+# config yml file
+config/config.yml
+
+# ssl setting files
+config/ssl/*
+
 # parcel-bundler cache (https://parceljs.org/)
 .cache
 .parcel-cache
@@ -92,12 +98,18 @@ dist
 # https://nextjs.org/blog/next-9-1#public-directory-support
 # public
 
+*.pyc
+
 # vuepress build output
 .vuepress/dist
 
 # Serverless directories
 .serverless/
 
+# Testing notebook
+.ipynb_checkpoints/
+*.ipynb
+
 # FuseBox cache
 .fusebox/
@@ -117,4 +129,3 @@ dist
 .yarn/install-state.gz
 
 .pnp.*
-
diff --git a/Dockerfile b/Dockerfile
index f19e12e1f..26ca37d08 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,9 +1,12 @@
-FROM python:3.9-alpine
+FROM python:3.12
+ENV PYTHONUNBUFFERED True
 
 COPY ./ /ChatGPT-Line-Bot
 WORKDIR /ChatGPT-Line-Bot
-
+RUN find . -name "*.ipynb" -exec rm {} +
+RUN apt-get update && apt-get install -y libpq-dev
 RUN pip3 install -r requirements.txt
 
-CMD ["python3", "main.py"]
\ No newline at end of file
+# CMD ["python3", "main.py"]
+CMD nohup gunicorn -w 4 main:app --access-logfile /var/log/gunicorn_access.txt --error-logfile /var/log/gunicorn_error.txt -b :8080 --timeout 120
\ No newline at end of file
diff --git a/config/config.yml.example b/config/config.yml.example
new file mode 100644
index 000000000..9aa90ac1f
--- /dev/null
+++ b/config/config.yml.example
@@ -0,0 +1,25 @@
+---
+line:
+  channel_access_token:
+  channel_secret:
+
+openai:
+  api_key:
+  assistant_id:
+
+db:
+  host:
+  port: 5432
+  db_name:
+  user:
+  password:
+  sslmode:
+  sslrootcert:
+  sslcert:
+  sslkey:
+
+logfile: ./logs/linebot.log
+
+commands:
+  help: "這裡是台南市議會聊天機器人,目前已經輸入了台南市議會第四屆公開議事錄中的會議逐字稿,請輸入您的問題,以便我檢索逐字稿內容來回應您。若您希望重設聊天內容,請輸入「/reset」以重置聊天。"
+...
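For reference, a minimal sketch (not part of the patch) of how config/config.yml can be sanity-checked before deploying; the key names follow config.yml.example above, while the script itself and its required-key list are assumptions:

    # check_config.py -- hypothetical helper; verifies config/config.yml has the
    # values main.py reads at startup (LINE credentials, OpenAI assistant, DB).
    import sys
    import yaml

    REQUIRED = {
        'line': ['channel_access_token', 'channel_secret'],
        'openai': ['api_key', 'assistant_id'],
        'db': ['host', 'db_name', 'user', 'password'],
    }

    def main(path='config/config.yml'):
        with open(path, 'r') as fh:
            config = yaml.safe_load(fh)
        missing = [
            f'{section}.{key}'
            for section, keys in REQUIRED.items()
            for key in keys
            if not (config.get(section) or {}).get(key)
        ]
        if missing:
            sys.exit(f'Missing config values: {", ".join(missing)}')
        print('config.yml looks complete')

    if __name__ == '__main__':
        main()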
diff --git a/main.py b/main.py
index 52680e7fc..2baa32dbd 100644
--- a/main.py
+++ b/main.py
@@ -1,40 +1,49 @@
-from dotenv import load_dotenv
-from flask import Flask, request, abort
+from flask import Flask, request, abort, render_template, jsonify
 from linebot import (
-    LineBotApi, WebhookHandler
+    LineBotApi
 )
-from linebot.exceptions import (
+from linebot.v3 import (
+    WebhookHandler
+)
+from linebot.v3.exceptions import (
     InvalidSignatureError
 )
-from linebot.models import (
-    MessageEvent, TextMessage, TextSendMessage, ImageSendMessage, AudioMessage
+
+from linebot.v3.messaging import (
+    Configuration,
+    ApiClient,
+    MessagingApi,
+    MessagingApiBlob,
+    ReplyMessageRequest,
+    TextMessage
+)
+
+from linebot.v3.webhooks import (
+    MessageEvent,
+    TextMessageContent,
+    AudioMessageContent
 )
+
+
 import os
 import uuid
+import time
 
 from src.models import OpenAIModel
-from src.memory import Memory
+from src.config import load_config
 from src.logger import logger
-from src.storage import Storage, FileStorage, MongoStorage
-from src.utils import get_role_and_content
-from src.service.youtube import Youtube, YoutubeTranscriptReader
-from src.service.website import Website, WebsiteReader
-from src.mongodb import mongodb
-
-load_dotenv('.env')
+from src.db import Database
+from src.utils import get_response_data, get_content_and_reference, replace_file_name, check_token_valid, get_file_dict, detect_none_references
 
 app = Flask(__name__)
-line_bot_api = LineBotApi(os.getenv('LINE_CHANNEL_ACCESS_TOKEN'))
-handler = WebhookHandler(os.getenv('LINE_CHANNEL_SECRET'))
-storage = None
-youtube = Youtube(step=4)
-website = Website()
-
-
-memory = Memory(system_message=os.getenv('SYSTEM_MESSAGE'), memory_message_count=2)
-model_management = {}
-api_keys = {}
-
+config = load_config()
+configuration = Configuration(access_token=config['line']['channel_access_token'])
+handler = WebhookHandler(config['line']['channel_secret'])
+openai_api_key = config['openai']['api_key']
+openai_assistant_id = config['openai']['assistant_id']
+database = Database(config['db'])
+model = OpenAIModel(api_key=openai_api_key, assistant_id=openai_assistant_id)
+file_dict = get_file_dict(model)
 
 @app.route("/callback", methods=['POST'])
 def callback():
@@ -48,131 +57,122 @@ def callback():
         abort(400)
     return 'OK'
 
-
-@handler.add(MessageEvent, message=TextMessage)
-def handle_text_message(event):
-    user_id = event.source.user_id
-    text = event.message.text.strip()
+def handle_assistant_message(user_id, text):
     logger.info(f'{user_id}: {text}')
-
+    database.check_connect()
+    logger.debug('database check done')
+    global file_dict
     try:
-        if text.startswith('/註冊'):
-            api_key = text[3:].strip()
-            model = OpenAIModel(api_key=api_key)
-            is_successful, _, _ = model.check_token_valid()
+        if text.startswith('/reset'):
+            thread_id = database.query_thread(user_id)
+            if thread_id:
+                model.delete_thread(thread_id)
+                database.delete_thread(user_id)
+                msg = TextMessage(text='Reset The Chatbot.')
+            else:
+                msg = TextMessage(text='Nothing to reset.')
+        elif text.startswith('/'):
+            command = text[1:].split()[0]
+            if command in config['commands']:
+                msg = TextMessage(text=config['commands'][command] + "\n\n")
+            else:
+                msg = TextMessage(text="Command not found.")
+        else:
+            thread_id = database.query_thread(user_id)
+            if thread_id:
+                is_successful, response, error_message = model.retrieve_thread(thread_id)
+                if not is_successful:
+                    database.delete_thread(user_id)
+                    thread_id = None
+            if not thread_id:
+                is_successful, response, error_message = model.create_thread()
+                if not is_successful:
+                    raise Exception(error_message)
+                else:
+                    thread_id = response['id']
+                    logger.debug('thread_id: ' + thread_id)
+                    database.save_thread(user_id, thread_id)
+            is_successful, response, error_message = model.create_thread_message(thread_id, text)
             if not is_successful:
-                raise ValueError('Invalid API token')
-            model_management[user_id] = model
-            storage.save({
-                user_id: api_key
-            })
-            msg = TextSendMessage(text='Token 有效,註冊成功')
-
-        elif text.startswith('/指令說明'):
-            msg = TextSendMessage(text="指令:\n/註冊 + API Token\n👉 API Token 請先到 https://platform.openai.com/ 註冊登入後取得\n\n/系統訊息 + Prompt\n👉 Prompt 可以命令機器人扮演某個角色,例如:請你扮演擅長做總結的人\n\n/清除\n👉 當前每一次都會紀錄最後兩筆歷史紀錄,這個指令能夠清除歷史訊息\n\n/圖像 + Prompt\n👉 會調用 DALL∙E 2 Model,以文字生成圖像\n\n語音輸入\n👉 會調用 Whisper 模型,先將語音轉換成文字,再調用 ChatGPT 以文字回覆\n\n其他文字輸入\n👉 調用 ChatGPT 以文字回覆")
-
-        elif text.startswith('/系統訊息'):
-            memory.change_system_message(user_id, text[5:].strip())
-            msg = TextSendMessage(text='輸入成功')
-
-        elif text.startswith('/清除'):
-            memory.remove(user_id)
-            msg = TextSendMessage(text='歷史訊息清除成功')
-
-        elif text.startswith('/圖像'):
-            prompt = text[3:].strip()
-            memory.append(user_id, 'user', prompt)
-            is_successful, response, error_message = model_management[user_id].image_generations(prompt)
+                raise Exception(error_message)
+            is_successful, response, error_message = model.create_thread_run(thread_id)
             if not is_successful:
                 raise Exception(error_message)
-            url = response['data'][0]['url']
-            msg = ImageSendMessage(
-                original_content_url=url,
-                preview_image_url=url
-            )
-            memory.append(user_id, 'assistant', url)
-
-        else:
-            user_model = model_management[user_id]
-            memory.append(user_id, 'user', text)
-            url = website.get_url_from_text(text)
-            if url:
-                if youtube.retrieve_video_id(text):
-                    is_successful, chunks, error_message = youtube.get_transcript_chunks(youtube.retrieve_video_id(text))
-                    if not is_successful:
-                        raise Exception(error_message)
-                    youtube_transcript_reader = YoutubeTranscriptReader(user_model, os.getenv('OPENAI_MODEL_ENGINE'))
-                    is_successful, response, error_message = youtube_transcript_reader.summarize(chunks)
-                    if not is_successful:
-                        raise Exception(error_message)
-                    role, response = get_role_and_content(response)
-                    msg = TextSendMessage(text=response)
+            while response['status'] != 'completed':
+                run_id = response['id']
+                if response['status'] == 'queued':
+                    time.sleep(10)
                 else:
-                    chunks = website.get_content_from_url(url)
-                    if len(chunks) == 0:
-                        raise Exception('無法撈取此網站文字')
-                    website_reader = WebsiteReader(user_model, os.getenv('OPENAI_MODEL_ENGINE'))
-                    is_successful, response, error_message = website_reader.summarize(chunks)
-                    if not is_successful:
-                        raise Exception(error_message)
-                    role, response = get_role_and_content(response)
-                    msg = TextSendMessage(text=response)
-            else:
-                is_successful, response, error_message = user_model.chat_completions(memory.get(user_id), os.getenv('OPENAI_MODEL_ENGINE'))
+                    time.sleep(3)
+                is_successful, response, error_message = model.retrieve_thread_run(thread_id, run_id)
+                logger.debug(run_id + ': ' + response['status'])
                 if not is_successful:
                     raise Exception(error_message)
-                role, response = get_role_and_content(response)
-                msg = TextSendMessage(text=response)
-            memory.append(user_id, role, response)
-    except ValueError:
-        msg = TextSendMessage(text='Token 無效,請重新註冊,格式為 /註冊 sk-xxxxx')
-    except KeyError:
-        msg = TextSendMessage(text='請先註冊 Token,格式為 /註冊 sk-xxxxx')
+            logger.debug(response)
+            is_successful, response, error_message = model.list_thread_messages(thread_id)
+            if not is_successful:
+                raise Exception(error_message)
+            logger.debug(response)
+            response_message = get_content_and_reference(response, file_dict)
+            if detect_none_references(response_message):
+                file_dict = get_file_dict(model)
+                response_message = get_content_and_reference(response, file_dict)
+            logger.debug(response_message)
+            msg = TextMessage(text=response_message)
     except Exception as e:
-        memory.remove(user_id)
         if str(e).startswith('Incorrect API key provided'):
-            msg = TextSendMessage(text='OpenAI API Token 有誤,請重新註冊。')
+            msg = TextMessage(text='OpenAI API Token 有誤,請重新註冊。')
         elif str(e).startswith('That model is currently overloaded with other requests.'):
-            msg = TextSendMessage(text='已超過負荷,請稍後再試')
+            msg = TextMessage(text='已超過負荷,請稍後再試')
         else:
-            msg = TextSendMessage(text=str(e))
-    line_bot_api.reply_message(event.reply_token, msg)
+            msg = TextMessage(text='發生錯誤:' + str(e))
+    return msg
 
-@handler.add(MessageEvent, message=AudioMessage)
-def handle_audio_message(event):
+@handler.add(MessageEvent, message=TextMessageContent)
+def handle_text_message(event):
     user_id = event.source.user_id
-    audio_content = line_bot_api.get_message_content(event.message.id)
-    input_audio_path = f'{str(uuid.uuid4())}.m4a'
-    with open(input_audio_path, 'wb') as fd:
-        for chunk in audio_content.iter_content():
-            fd.write(chunk)
+    text = event.message.text.strip()
+    msg = handle_assistant_message(user_id, text)
+    with ApiClient(configuration) as api_client:
+        line_bot_api = MessagingApi(api_client)
+        line_bot_api.reply_message_with_http_info(
+            ReplyMessageRequest(
+                reply_token=event.reply_token,
+                messages=[msg]
+            )
+        )
 
-    try:
-        if not model_management.get(user_id):
-            raise ValueError('Invalid API token')
-        else:
-            is_successful, response, error_message = model_management[user_id].audio_transcriptions(input_audio_path, 'whisper-1')
-            if not is_successful:
-                raise Exception(error_message)
-            memory.append(user_id, 'user', response['text'])
-            is_successful, response, error_message = model_management[user_id].chat_completions(memory.get(user_id), 'gpt-3.5-turbo')
+@handler.add(MessageEvent, message=AudioMessageContent)
+def handle_audio_message(event):
+    user_id = event.source.user_id
+    with ApiClient(configuration) as api_client:
+        line_bot_blob_api = MessagingApiBlob(api_client)
+        audio_content = line_bot_blob_api.get_message_content(message_id=event.message.id)
+        input_audio_path = f'{str(uuid.uuid4())}.m4a'
+        with open(input_audio_path, 'wb') as fd:
+            fd.write(audio_content)
+        try:
+            is_successful, response, error_message = model.audio_transcriptions(input_audio_path, 'whisper-1')
             if not is_successful:
                 raise Exception(error_message)
-            role, response = get_role_and_content(response)
-            memory.append(user_id, role, response)
-            msg = TextSendMessage(text=response)
-    except ValueError:
-        msg = TextSendMessage(text='請先註冊你的 API Token,格式為 /註冊 [API TOKEN]')
-    except KeyError:
-        msg = TextSendMessage(text='請先註冊 Token,格式為 /註冊 sk-xxxxx')
-    except Exception as e:
-        memory.remove(user_id)
-        if str(e).startswith('Incorrect API key provided'):
-            msg = TextSendMessage(text='OpenAI API Token 有誤,請重新註冊。')
-        else:
-            msg = TextSendMessage(text=str(e))
-    os.remove(input_audio_path)
+            text = response['text']
+            msg = handle_assistant_message(user_id, text)
+        except Exception as e:
+            if str(e).startswith('Incorrect API key provided'):
+                msg = TextMessage(text='OpenAI API Token 有誤,請重新註冊。')
+            elif str(e).startswith('That model is currently overloaded with other requests.'):
+                msg = TextMessage(text='已超過負荷,請稍後再試')
+            else:
+                msg = TextMessage(text='發生錯誤:' + str(e))
        os.remove(input_audio_path)
+        line_bot_api = MessagingApi(api_client)
+        line_bot_api.reply_message_with_http_info(
+            ReplyMessageRequest(
+                reply_token=event.reply_token,
+                messages=[msg]
+            )
+        )
-    line_bot_api.reply_message(event.reply_token, msg)
@@ -180,17 +180,21 @@ def handle_audio_message(event):
 def home():
     return 'Hello World'
 
+@app.route('/chat')
+def index():
+    return render_template('chat.html')
+
+@app.route('/ask', methods=['POST'])
+def ask():
+    user_message = request.json['message']
+    response_message = ask_api(user_message)
+    return jsonify({'message': response_message})
+
+def ask_api(message):
+    text_content = handle_assistant_message('test_user', message)
+    return text_content.text
+
 if __name__ == "__main__":
-    if os.getenv('USE_MONGO'):
-        mongodb.connect_to_database()
-        storage = Storage(MongoStorage(mongodb.db))
-    else:
-        storage = Storage(FileStorage('db.json'))
-    try:
-        data = storage.load()
-        for user_id in data.keys():
-            model_management[user_id] = OpenAIModel(api_key=data[user_id])
-    except FileNotFoundError:
-        pass
-    app.run(host='0.0.0.0', port=8080)
+    model.check_token_valid()
+    app.run(host='0.0.0.0', port=8080, debug=True)
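The new /ask route gives the assistant flow an HTTP test surface alongside the LINE webhook. A minimal usage sketch, assuming the app is running locally on port 8080 (the URL and sample question are illustrative, not part of the patch):

    # ask_demo.py -- hypothetical client for the new /ask endpoint.
    import requests

    # POST a question as JSON; main.py routes it through handle_assistant_message()
    # under the fixed 'test_user' thread and returns the assistant's reply text.
    resp = requests.post(
        'http://localhost:8080/ask',
        json={'message': '請問最近一次定期會討論了哪些議題?'},
        timeout=120,  # assistant runs are polled server-side and can take a while
    )
    resp.raise_for_status()
    print(resp.json()['message'])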
diff --git a/requirements.txt b/requirements.txt
index b8f2bf572..1b907194c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,9 @@
-line-bot-sdk==2.4.1
-python-dotenv==0.21.1
-Flask==2.2.2
-opencc-python-reimplemented==0.1.4
-beautifulsoup4==4.11.2
-youtube-transcript-api==0.5.0
-pymongo==4.3.3
\ No newline at end of file
+line-bot-sdk
+Flask
+Werkzeug
+opencc-python-reimplemented
+beautifulsoup4
+youtube-transcript-api
+pyyaml
+psycopg2
+gunicorn
\ No newline at end of file
diff --git a/src/config.py b/src/config.py
new file mode 100644
index 000000000..af09f9f5a
--- /dev/null
+++ b/src/config.py
@@ -0,0 +1,18 @@
+import yaml
+import os
+
+def load_config(file_path="config/config.yml"):
+    """Load YAML configuration file."""
+    try:
+        with open(file_path, 'r') as file:
+            config = yaml.safe_load(file)
+            return config
+    except FileNotFoundError:
+        print("Configuration file not found.")
+    except yaml.YAMLError as exc:
+        print("Error in configuration file:", exc)
+
+# Example usage
+if __name__ == "__main__":
+    config = load_config()
+    print(config)
\ No newline at end of file
diff --git a/src/db.py b/src/db.py
new file mode 100644
index 000000000..8197a0fb1
--- /dev/null
+++ b/src/db.py
@@ -0,0 +1,102 @@
+import os
+import psycopg2
+import datetime
+
+class Database:
+
+    def __init__(self, config):
+        self.config = config
+        self.conn = None
+        self.cursor = None
+        self.last_connected_time = 0
+        self.timeout = 30 * 60  # 30-minute timeout
+
+    def connect_to_database(self):
+        host = self.config['host']
+        port = self.config['port']
+        db_name = self.config['db_name']
+        user = self.config['user']
+        password = self.config['password']
+        sslmode = self.config['sslmode']  # verify-full or verify-ca
+        sslrootcert = self.config['sslrootcert']  # server-ca.pem, server CA certificate path
+        sslcert = self.config['sslcert']  # client-cert.pem, client certificate path
+        sslkey = self.config['sslkey']  # client-key.pem, client private key path
+
+        # Establish the connection
+        if sslmode in ['verify-full', 'verify-ca']:
+            self.conn = psycopg2.connect(
+                dbname=db_name,
+                user=user,
+                password=password,
+                host=host,
+                port=port,
+                sslmode=sslmode,
+                sslrootcert=sslrootcert,
+                sslcert=sslcert,
+                sslkey=sslkey
+            )
+        else:
+            self.conn = psycopg2.connect(
+                dbname=db_name,
+                user=user,
+                password=password,
+                host=host,
+                port=port
+            )
+        self.conn.autocommit = True
+        self.cursor = self.conn.cursor()
+        self.cursor.execute("SELECT 1")  # Simple query to confirm the connection works
+        assert self.cursor.fetchone()[0] == 1
+        self.last_connected_time = datetime.datetime.now()
+
+    def check_connect(self):
+        if self.conn:
+            if (datetime.datetime.now() - self.last_connected_time).total_seconds() > self.timeout:
+                self.close_connection()
+                self.connect_to_database()
+        else:
+            self.connect_to_database()
+
+    def close_connection(self):
+        if self.conn:
+            self.cursor.close()
+            self.conn.close()
+
+    def query_thread(self, user_id):
+        self.cursor.execute(
+            """
+            SELECT thread_id
+            FROM user_thread_table
+            WHERE user_id = %s;
+            """,
+            (user_id,)
+        )
+        result = self.cursor.fetchone()
+        self.last_connected_time = datetime.datetime.now()
+        return result[0] if result else None
+
+    def save_thread(self, user_id, thread_id):
+        now = datetime.datetime.utcnow()
+        self.cursor.execute(
+            """
+            INSERT INTO user_thread_table (user_id, thread_id, created_at)
+            VALUES (%s, %s, %s)
+            ON CONFLICT (user_id) DO UPDATE SET
+                thread_id = EXCLUDED.thread_id,
+                created_at = EXCLUDED.created_at;
+            """,
+            (user_id, thread_id, now)
+        )
+        self.conn.commit()
+        self.last_connected_time = datetime.datetime.now()
+
+    def delete_thread(self, user_id):
+        self.cursor.execute(
+            """
+            DELETE FROM user_thread_table
+            WHERE user_id = %s;
+            """,
+            (user_id,)
+        )
+        self.conn.commit()
+        self.last_connected_time = datetime.datetime.now()
diff --git a/src/logger.py b/src/logger.py
index 254539a59..ed675f9d2 100644
--- a/src/logger.py
+++ b/src/logger.py
@@ -1,7 +1,7 @@
 import os
 import logging
 import logging.handlers
-
+from src.config import load_config
 
 class CustomFormatter(logging.Formatter):
     __LEVEL_COLORS = [
@@ -42,7 +42,7 @@ class LoggerFactory:
     @staticmethod
     def create_logger(formatter, handlers):
         logger = logging.getLogger('chatgpt_logger')
-        logger.setLevel(logging.INFO)
+        logger.setLevel(logging.DEBUG)
         for handler in handlers:
             handler.setLevel(logging.DEBUG)
             handler.setFormatter(formatter)
@@ -51,7 +51,10 @@ def create_logger(formatter, handlers):
 
 
 class FileHandler(logging.FileHandler):
-    def __init__(self, log_file):
+    def __init__(self, log_file=None):
+        if not log_file:
+            config = load_config()
+            log_file = config['logfile']
         os.makedirs(os.path.dirname(log_file), exist_ok=True)
         super().__init__(log_file)
 
@@ -59,8 +62,7 @@ def __init__(self, log_file):
 class ConsoleHandler(logging.StreamHandler):
     pass
 
-
 formatter = CustomFormatter()
-file_handler = FileHandler('./logs')
+file_handler = FileHandler()
 console_handler = ConsoleHandler()
 logger = LoggerFactory.create_logger(formatter, [file_handler, console_handler])
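src/db.py assumes a user_thread_table already exists; the patch itself never creates it. A minimal bootstrap sketch under that assumption (column types are inferred from the queries above, and SSL options are omitted for brevity; this script is hypothetical):

    # init_db.py -- hypothetical one-off script to create the table src/db.py expects.
    import psycopg2
    from src.config import load_config

    DDL = """
    CREATE TABLE IF NOT EXISTS user_thread_table (
        user_id    TEXT PRIMARY KEY,   -- LINE user id; ON CONFLICT (user_id) needs a unique key
        thread_id  TEXT NOT NULL,      -- OpenAI Assistants thread id
        created_at TIMESTAMP NOT NULL
    );
    """

    if __name__ == '__main__':
        db = load_config()['db']
        conn = psycopg2.connect(
            dbname=db['db_name'], user=db['user'], password=db['password'],
            host=db['host'], port=db['port']
        )
        with conn, conn.cursor() as cur:  # commits on success
            cur.execute(DDL)
        conn.close()
        print('user_thread_table is ready')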
diff --git a/src/models.py b/src/models.py
index 87acc3d67..a0ecf764b 100644
--- a/src/models.py
+++ b/src/models.py
@@ -6,6 +6,39 @@ class ModelInterface:
     def check_token_valid(self) -> bool:
         pass
 
+    def list_files(self) -> str:
+        pass
+
+    def retrieve_assistant(self) -> str:
+        pass
+
+    def retrieve_vector_store(self, vector_store_id: str) -> str:
+        pass
+
+    def list_vector_store_files(self, vector_store_id: str) -> str:
+        pass
+
+    def create_thread(self) -> str:
+        pass
+
+    def retrieve_thread(self, thread_id: str) -> str:
+        pass
+
+    def delete_thread(self, thread_id: str) -> bool:
+        pass
+
+    def create_thread_message(self, thread_id: str, messages: List[Dict]) -> bool:
+        pass
+
+    def list_thread_messages(self, thread_id: str) -> str:
+        pass
+
+    def create_thread_run(self, thread_id: str) -> str:
+        pass
+
+    def retrieve_thread_run(self, thread_id: str, run_id: str) -> str:
+        pass
+
     def chat_completions(self, messages: List[Dict], model_engine: str) -> str:
         pass
 
@@ -17,20 +50,26 @@ def image_generations(self, prompt: str) -> str:
 
 
 class OpenAIModel(ModelInterface):
-    def __init__(self, api_key: str):
+    def __init__(self, api_key: str, assistant_id: str):
         self.api_key = api_key
+        self.assistant_id = assistant_id
         self.base_url = 'https://api.openai.com/v1'
 
-    def _request(self, method, endpoint, body=None, files=None):
+    def _request(self, method, endpoint, body=None, files=None, assistant=False):
         self.headers = {
             'Authorization': f'Bearer {self.api_key}'
         }
         try:
             if method == 'GET':
+                if assistant:
+                    self.headers['Content-Type'] = 'application/json'
+                    self.headers['OpenAI-Beta'] = 'assistants=v2'
                 r = requests.get(f'{self.base_url}{endpoint}', headers=self.headers)
             elif method == 'POST':
                 if body:
                     self.headers['Content-Type'] = 'application/json'
+                if assistant:
+                    self.headers['OpenAI-Beta'] = 'assistants=v2'
                 r = requests.post(f'{self.base_url}{endpoint}', headers=self.headers, json=body, files=files)
+            elif method == 'DELETE':
+                if assistant:
+                    self.headers['OpenAI-Beta'] = 'assistants=v2'
+                r = requests.delete(f'{self.base_url}{endpoint}', headers=self.headers)
             r = r.json()
             if r.get('error'):
@@ -39,9 +78,61 @@ def _request(self, method, endpoint, body=None, files=None):
             return False, None, 'OpenAI API 系統不穩定,請稍後再試'
         return True, r, None
 
-    def check_token_valid(self):
+    def check_token_valid(self) -> bool:
         return self._request('GET', '/models')
 
+    def list_files(self) -> str:
+        endpoint = '/files'
+        return self._request('GET', endpoint, assistant=True)
+
+    def retrieve_assistant(self) -> str:
+        endpoint = '/assistants/' + self.assistant_id
+        return self._request('GET', endpoint, assistant=True)
+
+    def retrieve_vector_store(self, vector_store_id) -> str:
+        endpoint = '/vector_stores/' + vector_store_id
+        return self._request('GET', endpoint, assistant=True)
+
+    def list_vector_store_files(self, vector_store_id) -> str:
+        endpoint = '/vector_stores/' + vector_store_id + '/files'
+        return self._request('GET', endpoint, assistant=True)
+
+    def create_thread(self) -> str:
+        endpoint = '/threads'
+        return self._request('POST', endpoint, assistant=True)
+
+    def retrieve_thread(self, thread_id) -> str:
+        endpoint = '/threads/' + thread_id
+        return self._request('GET', endpoint, assistant=True)
+
+    def delete_thread(self, thread_id) -> bool:
+        endpoint = '/threads/' + thread_id
+        return self._request('DELETE', endpoint, assistant=True)
+
+    def create_thread_message(self, thread_id, messages) -> bool:
+        endpoint = '/threads/' + thread_id + '/messages'
+        json_body = {
+            'role': 'user',
+            'content': messages
+        }
+        return self._request('POST', endpoint, body=json_body, assistant=True)
+
+    def create_thread_run(self, thread_id) -> str:
+        endpoint = '/threads/' + thread_id + '/runs'
+        json_body = {
+            'assistant_id': self.assistant_id,
+            'temperature': 0
+        }
+        return self._request('POST', endpoint, body=json_body, assistant=True)
+
+    def retrieve_thread_run(self, thread_id, run_id) -> str:
+        endpoint = '/threads/' + thread_id + '/runs/' + run_id
+        return self._request('GET', endpoint, assistant=True)
+
+    def list_thread_messages(self, thread_id) -> str:
+        endpoint = '/threads/' + thread_id + '/messages'
+        return self._request('GET', endpoint, assistant=True)
+
     def chat_completions(self, messages, model_engine) -> str:
         json_body = {
             'model': model_engine,
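A small smoke-test sketch (hypothetical, not part of the patch) for the assistant-aware request path in src/models.py; it stubs requests.get so nothing goes over the network and only checks that the OpenAI-Beta header gets attached:

    # test_models_headers.py -- assumes src/models.py imports requests at module level.
    from unittest.mock import patch
    from src.models import OpenAIModel

    def test_assistant_header_is_sent():
        model = OpenAIModel(api_key='sk-test', assistant_id='asst_test')
        with patch('src.models.requests.get') as mocked_get:
            mocked_get.return_value.json.return_value = {'id': 'thread_123'}
            is_successful, response, error = model.retrieve_thread('thread_123')
        assert is_successful and response['id'] == 'thread_123'
        sent_headers = mocked_get.call_args.kwargs['headers']
        assert sent_headers['OpenAI-Beta'] == 'assistants=v2'

    if __name__ == '__main__':
        test_assistant_header_is_sent()
        print('ok')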
diff --git a/src/mongodb.py b/src/mongodb.py
deleted file mode 100644
index de9e79400..000000000
--- a/src/mongodb.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import os
-
-from pymongo import MongoClient
-
-
-class MongoDB():
-    """
-    Environment Variables:
-        MONGODB__PATH
-        MONGODB__DBNAME
-    """
-    client: None
-    db: None
-
-    def connect_to_database(self, mongo_path=None, db_name=None):
-        mongo_path = mongo_path or os.getenv('MONGODB__PATH')
-        db_name = db_name or os.getenv('MONGODB__DBNAME')
-        self.client = MongoClient(mongo_path)
-        assert self.client.config.command('ping')['ok'] == 1.0
-        self.db = self.client[db_name]
-
-
-mongodb = MongoDB()
diff --git a/src/storage.py b/src/storage.py
deleted file mode 100644
index 2739efd03..000000000
--- a/src/storage.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import json
-import datetime
-
-
-class FileStorage:
-    def __init__(self, file_name):
-        self.fine_name = file_name
-        self.history = {}
-
-    def save(self, data):
-        self.history.update(data)
-        with open(self.fine_name, 'w', newline='') as f:
-            json.dump(self.history, f)
-
-    def load(self):
-        with open(self.fine_name, newline='') as jsonfile:
-            data = json.load(jsonfile)
-        self.history = data
-        return self.history
-
-
-class MongoStorage:
-    def __init__(self, db):
-        self.db = db
-
-    def save(self, data):
-        user_id, api_key = list(data.items())[0]
-        self.db['api_key'].update_one({
-            'user_id': user_id
-        }, {
-            '$set': {
-                'user_id': user_id,
-                'api_key': api_key,
-                'created_at': datetime.datetime.utcnow()
-            }
-        }, upsert=True)
-
-    def load(self):
-        data = list(self.db['api_key'].find())
-        res = {}
-        for i in range(len(data)):
-            res[data[i]['user_id']] = data[i]['api_key']
-        return res
-
-
-class Storage:
-    def __init__(self, storage):
-        self.storage = storage
-
-    def save(self, data):
-        self.storage.save(data)
-
-    def load(self):
-        return self.storage.load()
diff --git a/src/utils.py b/src/utils.py
index b4aa89d2f..e49ad4ef8 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -1,11 +1,62 @@
 import opencc
+import re
 
 s2t_converter = opencc.OpenCC('s2t')
 t2s_converter = opencc.OpenCC('t2s')
 
+def get_response_data(response) -> dict:
+    for item in response['data']:
+        if item['role'] == 'assistant' and item['content'] and item['content'][0]['type'] == 'text':
+            return item
+    return None
 
-def get_role_and_content(response: str):
-    role = response['choices'][0]['message']['role']
-    content = response['choices'][0]['message']['content'].strip()
-    content = s2t_converter.convert(content)
-    return role, content
+def get_content_and_reference(response, file_dict) -> str:
+    data = get_response_data(response)
+    if not data:
+        return ''
+    text = data['content'][0]['text']['value']
+    annotations = data['content'][0]['text']['annotations']
+    text = s2t_converter.convert(text)
+    # Replace each annotation marker in the text with a numbered reference
+    ref_mapping = {}
+    for i, annotation in enumerate(annotations, 1):
+        original_text = annotation['text']
+        file_id = annotation['file_citation']['file_id']
+        replacement_text = f"[{i}]"
+        text = text.replace(original_text, replacement_text)
+        ref_mapping[replacement_text] = f"{replacement_text}: {file_dict.get(file_id)}"
+
+    # Append the cited file names
+    reference_text = '\n'.join(ref_mapping.values())
+    final_text = f"{text}\n\n{reference_text}"
+
+    return final_text
+
+def replace_file_name(text, file_dict) -> str:
+    sorted_file_dict = sorted(file_dict.items(), key=lambda x: -len(x[0]))
+
+    # Replace each key, longest keys first
+    for key, value in sorted_file_dict:
+        # re.escape keeps regex special characters in the key from affecting the match
+        text = re.sub(re.escape(key), value, text)
+    return text
+
+def check_token_valid(model) -> bool:
+    is_successful, _, _ = model.check_token_valid()
+    if not is_successful:
+        raise ValueError('Invalid API token')
+    return is_successful
+
+def get_file_dict(model) -> dict:
+    is_successful, response, error_message = model.list_files()
+    if not is_successful:
+        raise Exception(error_message)
+    file_dict = { file['id']: file['filename'].replace('.txt', '').replace('.json', '') for file in response['data'] }
+    return file_dict
+
+def detect_none_references(text):
+    if re.search(r'\[\d+\]: None', text):
+        return True
+    else:
+        return False
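To make the annotation handling in get_content_and_reference concrete, here is a hypothetical sample of the list-messages payload it consumes (the field names follow the shapes the function reads; the ids, marker, and text are made up):

    # utils_demo.py -- illustrative input/output for get_content_and_reference().
    from src.utils import get_content_and_reference

    sample_response = {
        'data': [{
            'role': 'assistant',
            'content': [{
                'type': 'text',
                'text': {
                    'value': '議會在該次會議討論了預算案【4:0†source】。',
                    'annotations': [{
                        'text': '【4:0†source】',  # citation marker embedded in the value
                        'file_citation': {'file_id': 'file-abc123'},
                    }],
                },
            }],
        }]
    }
    file_dict = {'file-abc123': '第四屆第1次定期會逐字稿'}

    print(get_content_and_reference(sample_response, file_dict))
    # Expected shape: the marker becomes [1] and "[1]: 第四屆第1次定期會逐字稿" is appended.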
diff --git a/templates/chat.html b/templates/chat.html
new file mode 100644
index 000000000..e50969f2f
--- /dev/null
+++ b/templates/chat.html
@@ -0,0 +1,96 @@
[templates/chat.html: a new 96-line test page served by the /chat route, with the title 「台南市議會聊天機器人」 (Tainan City Council Chatbot) and the heading 「台南市議會聊天機器人測試聊天介面」 (test chat interface). Its HTML markup, inline styles, and the script that presumably posts user input to the /ask endpoint were not preserved in this extract.]