Skip to content

Commit

Permalink
Merge pull request #460 from yyhhyyyyyy/code-formatting
Browse files Browse the repository at this point in the history
Code Formatting
  • Loading branch information
yyhhyyyyyy authored Jul 25, 2024
2 parents bbd4e94 + 5c2db3a commit 84ae8e5
Show file tree
Hide file tree
Showing 20 changed files with 413 additions and 217 deletions.
9 changes: 7 additions & 2 deletions app/asgi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Application implementation - ASGI."""

import os

from fastapi import FastAPI, Request
Expand All @@ -24,7 +25,9 @@ def exception_handler(request: Request, e: HttpException):
def validation_exception_handler(request: Request, e: RequestValidationError):
return JSONResponse(
status_code=400,
content=utils.get_response(status=400, data=e.errors(), message='field required'),
content=utils.get_response(
status=400, data=e.errors(), message="field required"
),
)


Expand Down Expand Up @@ -61,7 +64,9 @@ def get_application() -> FastAPI:
)

task_dir = utils.task_dir()
app.mount("/tasks", StaticFiles(directory=task_dir, html=True, follow_symlink=True), name="")
app.mount(
"/tasks", StaticFiles(directory=task_dir, html=True, follow_symlink=True), name=""
)

public_dir = utils.public_dir()
app.mount("/", StaticFiles(directory=public_dir, html=True), name="")
Expand Down
15 changes: 10 additions & 5 deletions app/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
def __init_logger():
# _log_file = utils.storage_dir("logs/server.log")
_lvl = config.log_level
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
root_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
)

def format_record(record):
# 获取日志记录中的文件全路径
Expand All @@ -21,10 +23,13 @@ def format_record(record):
record["file"].path = f"./{relative_path}"
# 返回修改后的格式字符串
# 您可以根据需要调整这里的格式
_format = '<green>{time:%Y-%m-%d %H:%M:%S}</> | ' + \
'<level>{level}</> | ' + \
'"{file.path}:{line}":<blue> {function}</> ' + \
'- <level>{message}</>' + "\n"
_format = (
"<green>{time:%Y-%m-%d %H:%M:%S}</> | "
+ "<level>{level}</> | "
+ '"{file.path}:{line}":<blue> {function}</> '
+ "- <level>{message}</>"
+ "\n"
)
return _format

logger.remove()
Expand Down
8 changes: 5 additions & 3 deletions app/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def load_config():
_config_ = toml.load(config_file)
except Exception as e:
logger.warning(f"load config failed: {str(e)}, try to load as utf-8-sig")
with open(config_file, mode="r", encoding='utf-8-sig') as fp:
with open(config_file, mode="r", encoding="utf-8-sig") as fp:
_cfg_content = fp.read()
_config_ = toml.loads(_cfg_content)
return _config_
Expand All @@ -52,8 +52,10 @@ def save_config():
listen_host = _cfg.get("listen_host", "0.0.0.0")
listen_port = _cfg.get("listen_port", 8080)
project_name = _cfg.get("project_name", "MoneyPrinterTurbo")
project_description = _cfg.get("project_description",
"<a href='https://github.com/harry0703/MoneyPrinterTurbo'>https://github.com/harry0703/MoneyPrinterTurbo</a>")
project_description = _cfg.get(
"project_description",
"<a href='https://github.com/harry0703/MoneyPrinterTurbo'>https://github.com/harry0703/MoneyPrinterTurbo</a>",
)
project_version = _cfg.get("project_version", "1.1.9")
reload_debug = False

Expand Down
12 changes: 8 additions & 4 deletions app/controllers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@


def get_task_id(request: Request):
task_id = request.headers.get('x-task-id')
task_id = request.headers.get("x-task-id")
if not task_id:
task_id = uuid4()
return str(task_id)


def get_api_key(request: Request):
api_key = request.headers.get('x-api-key')
api_key = request.headers.get("x-api-key")
return api_key


Expand All @@ -23,5 +23,9 @@ def verify_token(request: Request):
if token != config.app.get("api_key", ""):
request_id = get_task_id(request)
request_url = request.url
user_agent = request.headers.get('user-agent')
raise HttpException(task_id=request_id, status_code=401, message=f"invalid token: {request_url}, {user_agent}")
user_agent = request.headers.get("user-agent")
raise HttpException(
task_id=request_id,
status_code=401,
message=f"invalid token: {request_url}, {user_agent}",
)
19 changes: 13 additions & 6 deletions app/controllers/manager/base_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@ def add_task(self, func: Callable, *args: Any, **kwargs: Any):
print(f"add task: {func.__name__}, current_tasks: {self.current_tasks}")
self.execute_task(func, *args, **kwargs)
else:
print(f"enqueue task: {func.__name__}, current_tasks: {self.current_tasks}")
print(
f"enqueue task: {func.__name__}, current_tasks: {self.current_tasks}"
)
self.enqueue({"func": func, "args": args, "kwargs": kwargs})

def execute_task(self, func: Callable, *args: Any, **kwargs: Any):
thread = threading.Thread(target=self.run_task, args=(func, *args), kwargs=kwargs)
thread = threading.Thread(
target=self.run_task, args=(func, *args), kwargs=kwargs
)
thread.start()

def run_task(self, func: Callable, *args: Any, **kwargs: Any):
Expand All @@ -35,11 +39,14 @@ def run_task(self, func: Callable, *args: Any, **kwargs: Any):

def check_queue(self):
with self.lock:
if self.current_tasks < self.max_concurrent_tasks and not self.is_queue_empty():
if (
self.current_tasks < self.max_concurrent_tasks
and not self.is_queue_empty()
):
task_info = self.dequeue()
func = task_info['func']
args = task_info.get('args', ())
kwargs = task_info.get('kwargs', {})
func = task_info["func"]
args = task_info.get("args", ())
kwargs = task_info.get("kwargs", {})
self.execute_task(func, *args, **kwargs)

def task_done(self):
Expand Down
24 changes: 16 additions & 8 deletions app/controllers/manager/redis_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from app.services import task as tm

FUNC_MAP = {
'start': tm.start,
"start": tm.start,
# 'start_test': tm.start_test
}

Expand All @@ -24,22 +24,30 @@ def create_queue(self):
def enqueue(self, task: Dict):
task_with_serializable_params = task.copy()

if 'params' in task['kwargs'] and isinstance(task['kwargs']['params'], VideoParams):
task_with_serializable_params['kwargs']['params'] = task['kwargs']['params'].dict()
if "params" in task["kwargs"] and isinstance(
task["kwargs"]["params"], VideoParams
):
task_with_serializable_params["kwargs"]["params"] = task["kwargs"][
"params"
].dict()

# 将函数对象转换为其名称
task_with_serializable_params['func'] = task['func'].__name__
task_with_serializable_params["func"] = task["func"].__name__
self.redis_client.rpush(self.queue, json.dumps(task_with_serializable_params))

def dequeue(self):
task_json = self.redis_client.lpop(self.queue)
if task_json:
task_info = json.loads(task_json)
# 将函数名称转换回函数对象
task_info['func'] = FUNC_MAP[task_info['func']]

if 'params' in task_info['kwargs'] and isinstance(task_info['kwargs']['params'], dict):
task_info['kwargs']['params'] = VideoParams(**task_info['kwargs']['params'])
task_info["func"] = FUNC_MAP[task_info["func"]]

if "params" in task_info["kwargs"] and isinstance(
task_info["kwargs"]["params"], dict
):
task_info["kwargs"]["params"] = VideoParams(
**task_info["kwargs"]["params"]
)

return task_info
return None
Expand Down
7 changes: 6 additions & 1 deletion app/controllers/ping.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
router = APIRouter()


@router.get("/ping", tags=["Health Check"], description="检查服务可用性", response_description="pong")
@router.get(
"/ping",
tags=["Health Check"],
description="检查服务可用性",
response_description="pong",
)
def ping(request: Request) -> str:
return "pong"
4 changes: 2 additions & 2 deletions app/controllers/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

def new_router(dependencies=None):
router = APIRouter()
router.tags = ['V1']
router.prefix = '/api/v1'
router.tags = ["V1"]
router.prefix = "/api/v1"
# 将认证依赖项应用于所有路由
if dependencies:
router.dependencies = dependencies
Expand Down
43 changes: 28 additions & 15 deletions app/controllers/v1/llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from fastapi import Request
from app.controllers.v1.base import new_router
from app.models.schema import VideoScriptResponse, VideoScriptRequest, VideoTermsResponse, VideoTermsRequest
from app.models.schema import (
VideoScriptResponse,
VideoScriptRequest,
VideoTermsResponse,
VideoTermsRequest,
)
from app.services import llm
from app.utils import utils

Expand All @@ -9,23 +14,31 @@
router = new_router()


@router.post("/scripts", response_model=VideoScriptResponse, summary="Create a script for the video")
@router.post(
"/scripts",
response_model=VideoScriptResponse,
summary="Create a script for the video",
)
def generate_video_script(request: Request, body: VideoScriptRequest):
video_script = llm.generate_script(video_subject=body.video_subject,
language=body.video_language,
paragraph_number=body.paragraph_number)
response = {
"video_script": video_script
}
video_script = llm.generate_script(
video_subject=body.video_subject,
language=body.video_language,
paragraph_number=body.paragraph_number,
)
response = {"video_script": video_script}
return utils.get_response(200, response)


@router.post("/terms", response_model=VideoTermsResponse, summary="Generate video terms based on the video script")
@router.post(
"/terms",
response_model=VideoTermsResponse,
summary="Generate video terms based on the video script",
)
def generate_video_terms(request: Request, body: VideoTermsRequest):
video_terms = llm.generate_terms(video_subject=body.video_subject,
video_script=body.video_script,
amount=body.amount)
response = {
"video_terms": video_terms
}
video_terms = llm.generate_terms(
video_subject=body.video_subject,
video_script=body.video_script,
amount=body.amount,
)
response = {"video_terms": video_terms}
return utils.get_response(200, response)
22 changes: 18 additions & 4 deletions app/models/const.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
PUNCTUATIONS = [
"?", ",", ".", "、", ";", ":", "!", "…",
"?", ",", "。", "、", ";", ":", "!", "...",
"?",
",",
".",
"、",
";",
":",
"!",
"…",
"?",
",",
"。",
"、",
";",
":",
"!",
"...",
]

TASK_STATE_FAILED = -1
TASK_STATE_COMPLETE = 1
TASK_STATE_PROCESSING = 4

FILE_TYPE_VIDEOS = ['mp4', 'mov', 'mkv', 'webm']
FILE_TYPE_IMAGES = ['jpg', 'jpeg', 'png', 'bmp']
FILE_TYPE_VIDEOS = ["mp4", "mov", "mkv", "webm"]
FILE_TYPE_IMAGES = ["jpg", "jpeg", "png", "bmp"]
8 changes: 5 additions & 3 deletions app/models/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@


class HttpException(Exception):
def __init__(self, task_id: str, status_code: int, message: str = '', data: Any = None):
def __init__(
self, task_id: str, status_code: int, message: str = "", data: Any = None
):
self.message = message
self.status_code = status_code
self.data = data
# 获取异常堆栈信息
tb_str = traceback.format_exc().strip()
if not tb_str or tb_str == "NoneType: None":
msg = f'HttpException: {status_code}, {task_id}, {message}'
msg = f"HttpException: {status_code}, {task_id}, {message}"
else:
msg = f'HttpException: {status_code}, {task_id}, {message}\n{tb_str}'
msg = f"HttpException: {status_code}, {task_id}, {message}\n{tb_str}"

if status_code == 400:
logger.warning(msg)
Expand Down
Loading

0 comments on commit 84ae8e5

Please sign in to comment.