From 57797c2c8cab60b22444d6db3ac86e8a221856f4 Mon Sep 17 00:00:00 2001 From: Wu Clan Date: Mon, 19 Feb 2024 21:58:32 +0800 Subject: [PATCH] Update the response status code in exception handlers --- .../app/common/exception/exception_handler.py | 30 +++++++++++++++++-- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/backend/app/common/exception/exception_handler.py b/backend/app/common/exception/exception_handler.py index 2c2f5345..99f086c8 100644 --- a/backend/app/common/exception/exception_handler.py +++ b/backend/app/common/exception/exception_handler.py @@ -1,11 +1,13 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- +from asgiref.sync import sync_to_async from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError from pydantic import ValidationError from pydantic.errors import PydanticUserError from starlette.exceptions import HTTPException from starlette.middleware.cors import CORSMiddleware +from uvicorn.protocols.http.h11_impl import STATUS_PHRASES from backend.app.common.exception.errors import BaseExceptionMixin from backend.app.common.log import log @@ -19,6 +21,28 @@ from backend.app.utils.serializers import MsgSpecJSONResponse +@sync_to_async +def _get_exception_code(status_code: int): + """ + 获取返回状态码, OpenAPI, Uvicorn... 可用状态码基于 RFC 定义, 详细代码见下方链接 + + `python 状态码标准支持 `__ + + `IANA 状态码注册表 `__ + + :param status_code: + :return: + """ + try: + STATUS_PHRASES[status_code] + except Exception: # noqa: ignore + code = StandardResponseCode.HTTP_400 + else: + code = status_code + return code + + async def _validation_exception_handler(request: Request, e: RequestValidationError | ValidationError): """ 数据验证异常处理 @@ -83,7 +107,7 @@ async def http_exception_handler(request: Request, exc: HTTPException): content = res.model_dump() request.state.__request_http_exception__ = content # 用于在中间件中获取异常信息 return MsgSpecJSONResponse( - status_code=StandardResponseCode.HTTP_400, + status_code=await _get_exception_code(exc.status_code), content=content, headers=exc.headers, ) @@ -162,7 +186,7 @@ async def all_exception_handler(request: Request, exc: Exception): """ if isinstance(exc, BaseExceptionMixin): return MsgSpecJSONResponse( - status_code=StandardResponseCode.HTTP_400, + status_code=await _get_exception_code(exc.code), content={ 'code': exc.code, 'msg': str(exc.msg), @@ -177,7 +201,7 @@ async def all_exception_handler(request: Request, exc: Exception): log.error(traceback.format_exc()) if settings.ENVIRONMENT == 'dev': content = { - 'code': 500, + 'code': StandardResponseCode.HTTP_500, 'msg': str(exc), 'data': None, }