diff --git a/main.py b/main.py index 1d7a07f75..52680e7fc 100644 --- a/main.py +++ b/main.py @@ -15,17 +15,18 @@ from src.models import OpenAIModel from src.memory import Memory from src.logger import logger -from src.storage import Storage +from src.storage import Storage, FileStorage, MongoStorage from src.utils import get_role_and_content from src.service.youtube import Youtube, YoutubeTranscriptReader from src.service.website import Website, WebsiteReader +from src.mongodb import mongodb load_dotenv('.env') app = Flask(__name__) line_bot_api = LineBotApi(os.getenv('LINE_CHANNEL_ACCESS_TOKEN')) handler = WebhookHandler(os.getenv('LINE_CHANNEL_SECRET')) -storage = Storage('db.json') +storage = None youtube = Youtube(step=4) website = Website() @@ -62,8 +63,9 @@ def handle_text_message(event): if not is_successful: raise ValueError('Invalid API token') model_management[user_id] = model - api_keys[user_id] = api_key - storage.save(api_keys) + storage.save({ + user_id: api_key + }) msg = TextSendMessage(text='Token 有效,註冊成功') elif text.startswith('/指令說明'): @@ -180,6 +182,11 @@ def home(): if __name__ == "__main__": + if os.getenv('USE_MONGO'): + mongodb.connect_to_database() + storage = Storage(MongoStorage(mongodb.db)) + else: + storage = Storage(FileStorage('db.json')) try: data = storage.load() for user_id in data.keys(): diff --git a/src/mongodb.py b/src/mongodb.py new file mode 100644 index 000000000..de9e79400 --- /dev/null +++ b/src/mongodb.py @@ -0,0 +1,23 @@ +import os + +from pymongo import MongoClient + + +class MongoDB(): + """ + Environment Variables: + MONGODB__PATH + MONGODB__DBNAME + """ + client: None + db: None + + def connect_to_database(self, mongo_path=None, db_name=None): + mongo_path = mongo_path or os.getenv('MONGODB__PATH') + db_name = db_name or os.getenv('MONGODB__DBNAME') + self.client = MongoClient(mongo_path) + assert self.client.config.command('ping')['ok'] == 1.0 + self.db = self.client[db_name] + + +mongodb = MongoDB() diff --git a/src/storage.py b/src/storage.py index d4ac6b3e8..2739efd03 100644 --- a/src/storage.py +++ b/src/storage.py @@ -1,15 +1,54 @@ import json +import datetime -class Storage(): +class FileStorage: def __init__(self, file_name): self.fine_name = file_name + self.history = {} def save(self, data): + self.history.update(data) with open(self.fine_name, 'w', newline='') as f: - json.dump(data, f) + json.dump(self.history, f) def load(self): with open(self.fine_name, newline='') as jsonfile: data = json.load(jsonfile) - return data + self.history = data + return self.history + + +class MongoStorage: + def __init__(self, db): + self.db = db + + def save(self, data): + user_id, api_key = list(data.items())[0] + self.db['api_key'].update_one({ + 'user_id': user_id + }, { + '$set': { + 'user_id': user_id, + 'api_key': api_key, + 'created_at': datetime.datetime.utcnow() + } + }, upsert=True) + + def load(self): + data = list(self.db['api_key'].find()) + res = {} + for i in range(len(data)): + res[data[i]['user_id']] = data[i]['api_key'] + return res + + +class Storage: + def __init__(self, storage): + self.storage = storage + + def save(self, data): + self.storage.save(data) + + def load(self): + return self.storage.load()