Skip to content

Commit

Permalink
Add trace ID to exception handlers (#411)
Browse files Browse the repository at this point in the history
  • Loading branch information
wu-clan authored Sep 9, 2024
1 parent b13410c commit 1f95a77
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 35 deletions.
2 changes: 1 addition & 1 deletion backend/common/exception/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
全局业务异常类
业务代码执行异常时,可以使用 raise xxxError 触发内部错误,它尽可能实现带有后台任务的异常,但它不适用于**自定义响应状态码**
如果要求使用**自定义响应状态码**,则可以通过 return await response_base.fail(res=CustomResponseCode.xxx) 直接返回
如果要求使用**自定义响应状态码**,则可以通过 return response_base.fail(res=CustomResponseCode.xxx) 直接返回
""" # noqa: E501

from typing import Any
Expand Down
71 changes: 43 additions & 28 deletions backend/common/exception/exception_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from backend.core.conf import settings
from backend.utils.serializers import MsgSpecJSONResponse
from backend.utils.trace_id import get_request_trace_id


def _get_exception_code(status_code: int):
Expand Down Expand Up @@ -70,13 +71,14 @@ async def _validation_exception_handler(request: Request, e: RequestValidationEr
error_input = error.get('input')
field = str(error.get('loc')[-1])
error_msg = error.get('msg')
message = f'{error_msg}{field},输入:{error_input}' if settings.ENVIRONMENT == 'dev' else error_msg
message = f'{field} {error_msg},输入:{error_input}' if settings.ENVIRONMENT == 'dev' else error_msg
msg = f'请求参数非法: {message}'
data = {'errors': errors} if settings.ENVIRONMENT == 'dev' else None
content = {
'code': StandardResponseCode.HTTP_422,
'msg': msg,
'data': data,
'trace_id': get_request_trace_id(request),
}
request.state.__request_validation_exception__ = content # 用于在中间件中获取异常信息
return MsgSpecJSONResponse(status_code=422, content=content)
Expand Down Expand Up @@ -104,7 +106,7 @@ async def http_exception_handler(request: Request, exc: HTTPException):
request.state.__request_http_exception__ = content # 用于在中间件中获取异常信息
return MsgSpecJSONResponse(
status_code=_get_exception_code(exc.status_code),
content=content,
content=content.update(trace_id=get_request_trace_id(request)),
headers=exc.headers,
)

Expand Down Expand Up @@ -145,6 +147,7 @@ async def pydantic_user_error_handler(request: Request, exc: PydanticUserError):
'code': StandardResponseCode.HTTP_500,
'msg': CUSTOM_USAGE_ERROR_MESSAGES.get(exc.code),
'data': None,
'trace_id': get_request_trace_id(request),
},
)

Expand All @@ -168,7 +171,27 @@ async def assertion_error_handler(request: Request, exc: AssertionError):
content = res.model_dump()
return MsgSpecJSONResponse(
status_code=StandardResponseCode.HTTP_500,
content=content,
content=content.update(trace_id=get_request_trace_id(request)),
)

@app.exception_handler(BaseExceptionMixin)
async def custom_exception_handler(request: Request, exc: BaseExceptionMixin):
"""
全局异常处理
:param request:
:param exc:
:return:
"""
return MsgSpecJSONResponse(
status_code=_get_exception_code(exc.code),
content={
'code': exc.code,
'msg': str(exc.msg),
'data': exc.data if exc.data else None,
'trace_id': get_request_trace_id(request),
},
background=exc.background,
)

@app.exception_handler(Exception)
Expand All @@ -180,31 +203,23 @@ async def all_exception_handler(request: Request, exc: Exception):
:param exc:
:return:
"""
if isinstance(exc, BaseExceptionMixin):
return MsgSpecJSONResponse(
status_code=_get_exception_code(exc.code),
content={
'code': exc.code,
'msg': str(exc.msg),
'data': exc.data if exc.data else None,
},
background=exc.background,
)
else:
import traceback
import traceback

log.error(f'未知异常: {exc}')
log.error(traceback.format_exc())
if settings.ENVIRONMENT == 'dev':
content = {
'code': StandardResponseCode.HTTP_500,
'msg': str(exc),
'data': None,
}
else:
res = response_base.fail(res=CustomResponseCode.HTTP_500)
content = res.model_dump()
return MsgSpecJSONResponse(status_code=StandardResponseCode.HTTP_500, content=content)
log.error(f'未知异常: {exc}')
log.error(traceback.format_exc())
if settings.ENVIRONMENT == 'dev':
content = {
'code': StandardResponseCode.HTTP_500,
'msg': str(exc),
'data': None,
}
else:
res = response_base.fail(res=CustomResponseCode.HTTP_500)
content = res.model_dump()
return MsgSpecJSONResponse(
status_code=StandardResponseCode.HTTP_500,
content=content.update(trace_id=get_request_trace_id(request)),
)

if settings.MIDDLEWARE_CORS:

Expand Down Expand Up @@ -238,7 +253,7 @@ async def cors_custom_code_500_exception_handler(request, exc):
content = res.model_dump()
response = MsgSpecJSONResponse(
status_code=exc.code if isinstance(exc, BaseExceptionMixin) else StandardResponseCode.HTTP_500,
content=content,
content=content.update(trace_id=get_request_trace_id(request)),
background=exc.background if isinstance(exc, BaseExceptionMixin) else None,
)
origin = request.headers.get('origin')
Expand Down
8 changes: 4 additions & 4 deletions backend/middleware/opera_log_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from backend.utils.encrypt import AESCipher, ItsDCipher, Md5Cipher
from backend.utils.request_parse import parse_ip_info, parse_user_agent_info
from backend.utils.timezone import timezone
from backend.utils.trace_id import get_request_trace_id


class OperaLogMiddleware(BaseHTTPMiddleware):
Expand Down Expand Up @@ -62,7 +63,7 @@ async def dispatch(self, request: Request, call_next) -> Response:

# 日志创建
opera_log_in = CreateOperaLogParam(
trace_id=request.headers.get(settings.TRACE_ID_REQUEST_HEADER_KEY) or '-',
trace_id=get_request_trace_id(request),
username=username,
method=method,
title=summary,
Expand Down Expand Up @@ -100,9 +101,9 @@ async def execute_request(self, request: Request, call_next) -> RequestCallNext:
response = None
try:
response = await call_next(request)
code, msg = self.validation_exception_handler(request, code, msg)
except Exception as e:
log.exception(e)
code, msg = await self.request_exception_handler(request, code, msg)
# code 处理包含 SQLAlchemy 和 Pydantic
code = getattr(e, 'code', None) or code
msg = getattr(e, 'msg', None) or msg
Expand All @@ -112,8 +113,7 @@ async def execute_request(self, request: Request, call_next) -> RequestCallNext:
return RequestCallNext(code=str(code), msg=msg, status=status, err=err, response=response)

@staticmethod
@sync_to_async
def request_exception_handler(request: Request, code: int, msg: str) -> tuple[str, str]:
def validation_exception_handler(request: Request, code: int, msg: str) -> tuple[str, str]:
"""请求异常处理器"""
try:
http_exception = request.state.__request_http_exception__
Expand Down
3 changes: 1 addition & 2 deletions backend/utils/request_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from backend.database.db_redis import redis_client


@sync_to_async
def get_request_ip(request: Request) -> str:
"""获取请求的 ip 地址"""
real = request.headers.get('X-Real-IP')
Expand Down Expand Up @@ -78,7 +77,7 @@ def get_location_offline(ip: str) -> dict | None:

async def parse_ip_info(request: Request) -> IpInfo:
country, region, city = None, None, None
ip = await get_request_ip(request)
ip = get_request_ip(request)
location = await redis_client.get(f'{settings.IP_LOCATION_REDIS_PREFIX}:{ip}')
if location:
country, region, city = location.split(' ')
Expand Down
9 changes: 9 additions & 0 deletions backend/utils/trace_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from fastapi import Request

from backend.core.conf import settings


def get_request_trace_id(request: Request) -> str:
return request.headers.get(settings.TRACE_ID_REQUEST_HEADER_KEY) or '-'

0 comments on commit 1f95a77

Please sign in to comment.