Skip to content

Commit

Permalink
Fix celery service functions error (#462)
Browse files Browse the repository at this point in the history
  • Loading branch information
wu-clan authored Nov 16, 2024
1 parent 5e60d9c commit a15693c
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 41 deletions.
4 changes: 2 additions & 2 deletions backend/app/task/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@

你可以通过 `CELERY_BROKER` 控制消息代理选择,它支持 redis 和 rabbitmq

对于本地调试,我们建议使用 redis
对于本地调试,建议使用 redis

对于线上环境,我们强制使用 rabbitmq
对于线上环境,强制使用 rabbitmq
41 changes: 24 additions & 17 deletions backend/app/task/api/v1/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,38 +14,45 @@
router = APIRouter()


@router.get('', summary='获取所有可执行任务模块', dependencies=[DependsJwtAuth])
@router.get('', summary='获取可执行任务', dependencies=[DependsJwtAuth])
async def get_all_tasks() -> ResponseModel:
tasks = task_service.get_list()
tasks = await task_service.get_list()
return response_base.success(data=tasks)


@router.get('/running', summary='获取正在执行的任务', dependencies=[DependsJwtAuth])
async def get_current_task() -> ResponseModel:
task = task_service.get()
return response_base.success(data=task)


@router.get('/{tid}/status', summary='获取任务状态', dependencies=[DependsJwtAuth])
async def get_task_status(tid: Annotated[str, Path(description='任务ID')]) -> ResponseModel:
status = task_service.get_status(tid)
@router.get(
'/{tid}',
summary='获取任务详情',
deprecated=True,
description='此接口被视为作废,建议使用 flower 查看任务详情',
dependencies=[DependsJwtAuth],
)
async def get_task_detail(tid: Annotated[str, Path(description='任务ID')]) -> ResponseModel:
status = task_service.get_detail(tid=tid)
return response_base.success(data=status)


@router.get('/{tid}', summary='获取任务结果', dependencies=[DependsJwtAuth])
async def get_task_result(tid: Annotated[str, Path(description='任务ID')]) -> ResponseModel:
task = task_service.get_result(tid)
return response_base.success(data=task)
@router.post(
'/{tid}',
summary='撤销任务',
dependencies=[
Depends(RequestPermission('sys:task:revoke')),
DependsRBAC,
],
)
async def revoke_task(tid: Annotated[str, Path(description='任务ID')]) -> ResponseModel:
task_service.revoke(tid=tid)
return response_base.success()


@router.post(
'/{name}',
'',
summary='执行任务',
dependencies=[
Depends(RequestPermission('sys:task:run')),
DependsRBAC,
],
)
async def run_task(obj: RunParam) -> ResponseModel:
task = task_service.run(name=obj.name, args=obj.args, kwargs=obj.kwargs)
task = task_service.run(obj=obj)
return response_base.success(data=task)
2 changes: 1 addition & 1 deletion backend/app/task/celery_task/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@

@celery_app.task(name='task_demo_async')
async def task_demo_async() -> str:
await sleep(10)
await sleep(20)
return 'test async'
2 changes: 1 addition & 1 deletion backend/app/task/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class TaskSettings(BaseSettings):

# Celery
CELERY_BROKER: Literal['rabbitmq', 'redis'] = 'redis'
CELERY_BACKEND_REDIS_PREFIX: str = 'fba:celery_'
CELERY_BACKEND_REDIS_PREFIX: str = 'fba:celery:'
CELERY_BACKEND_REDIS_TIMEOUT: int = 5
CELERY_TASK_PACKAGES: list[str] = [
'app.task.celery_task',
Expand Down
47 changes: 27 additions & 20 deletions backend/app/task/service/task_service.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,53 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from celery.exceptions import NotRegistered
from celery.result import AsyncResult
from starlette.concurrency import run_in_threadpool

from backend.app.task.celery import celery_app
from backend.app.task.schema.task import RunParam
from backend.common.dataclasses import TaskResult
from backend.common.exception.errors import NotFoundError


class TaskService:
@staticmethod
def get_list():
filtered_tasks = []
tasks = celery_app.tasks
for key, value in tasks.items():
if not key.startswith('celery.'):
filtered_tasks.append({key, value})
return filtered_tasks

@staticmethod
def get():
return celery_app.current_worker_task
async def get_list():
registered_tasks = await run_in_threadpool(celery_app.control.inspect().registered)
tasks = list(registered_tasks.values())[0]
return tasks

@staticmethod
def get_status(uid: str):
def get_detail(*, tid: str):
try:
task_result = AsyncResult(id=uid, app=celery_app)
result = AsyncResult(id=tid, app=celery_app)
except NotRegistered:
raise NotFoundError(msg='任务不存在')
return task_result.status
return TaskResult(
result=result.result,
traceback=result.traceback,
status=result.state,
name=result.name,
args=result.args,
kwargs=result.kwargs,
worker=result.worker,
retries=result.retries,
queue=result.queue,
)

@staticmethod
def get_result(uid: str):
def revoke(*, tid: str):
try:
task_result = AsyncResult(id=uid, app=celery_app)
result = AsyncResult(id=tid, app=celery_app)
except NotRegistered:
raise NotFoundError(msg='任务不存在')
return task_result.result
result.revoke(terminate=True)

@staticmethod
def run(*, name: str, args: list | None = None, kwargs: dict | None = None):
task = celery_app.send_task(name=name, args=args, kwargs=kwargs)
return task
def run(*, obj: RunParam):
task: AsyncResult = celery_app.send_task(name=obj.name, args=obj.args, kwargs=obj.kwargs)
return task.task_id


task_service: TaskService = TaskService()
13 changes: 13 additions & 0 deletions backend/common/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,16 @@ class AccessToken:
class RefreshToken:
refresh_token: str
refresh_token_expire_time: datetime


@dataclasses.dataclass
class TaskResult:
result: str
traceback: str
status: str
name: str
args: list | None
kwargs: dict | None
worker: str
retries: int | None
queue: str | None

0 comments on commit a15693c

Please sign in to comment.