155 lines
5.4 KiB
Python
155 lines
5.4 KiB
Python
"""Database connection and session management."""
|
||
|
||
import uuid, re
|
||
from loguru import logger
|
||
import traceback
|
||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||
from sqlalchemy.orm import sessionmaker
|
||
from typing import Optional
|
||
|
||
from utils.general import gradient_text
|
||
|
||
from ..core.config import get_settings
|
||
from .base import Base
|
||
from utils.util_exceptions import DatabaseError
|
||
|
||
# Custom Session class with desc property and unique ID
|
||
class DrSession(AsyncSession):
    """AsyncSession subclass carrying a short unique ID and step-logging helpers.

    Per-session metadata (``session_id``, ``title``, ``desc``) is kept in the
    session's ``info`` dict. Every assignment to :attr:`desc` increments
    :attr:`stepIndex`, stores the value, and emits one log line tagged with the
    session ID and the source position of the code that made the assignment.
    """

    # Matches the header line of one formatted stack entry:
    #   File "<path>", line <lineno>, in <function>
    # ``\S+`` (not ``\w+``) so names such as ``<module>`` or ``<lambda>`` match too.
    _FRAME_HEADER_RE = re.compile(r'File "(.+?)", line (\d+), in (\S+)')

    # Local checkout prefix stripped from file paths to keep log lines short.
    _SOURCE_ROOT_PREFIX = "F:\\DrGraph_Python\\FastAPI\\"

    def __init__(self, **kwargs):
        """Initialize the session and assign it a short random ID.

        All keyword arguments are forwarded unchanged to ``AsyncSession``.
        """
        super().__init__(**kwargs)
        # Make sure an ``info`` dict exists before anything (including the
        # ``title`` setter below) writes to it.
        if not hasattr(self, 'info'):
            self.info = {}
        # Goes through the ``title`` setter, which seeds info['title'] with "".
        self.title = ""
        # NOTE(review): not read anywhere in this class — verify it is used elsewhere.
        self.descs = []
        # First group of a UUID4 (8 hex chars) — short enough for log prefixes.
        self.info['session_id'] = str(uuid.uuid4()).split('-')[0]
        # Number of ``desc`` assignments so far; shown as STEP[n] in logs.
        self.stepIndex = 0

    @property
    def title(self) -> Optional[str]:
        """Return the session title stored in ``info['title']`` (None if unset)."""
        return self.info.get('title')

    @title.setter
    def title(self, value: str) -> None:
        """Set or extend the session title.

        If no non-blank title is stored yet, ``value`` becomes the title;
        otherwise it is prepended as ``"<value> >>> <previous title>"``.
        """
        current = self.info.get('title')
        if current is None or current.strip() == "":
            self.info['title'] = value
        else:
            self.info['title'] = value + " >>> " + current

    @property
    def desc(self) -> Optional[str]:
        """Return the most recent step description (None if never set)."""
        return self.info.get('desc')

    @desc.setter
    def desc(self, value: str) -> None:
        """Record a step description and log it as the next numbered step.

        Stores ``value`` in ``info['desc']`` so the getter reflects it,
        increments ``stepIndex``, and logs one INFO line containing the session
        prefix, the step number, ``value`` and the position of the assigning code.
        """
        self.stepIndex += 1
        self.info['desc'] = value
        # Log centrally here so every step carries the session ID and caller position.
        try:
            # -3 skips parse_source_pos (-1) and this setter (-2), landing on
            # the code that assigned ``desc``.
            pos = self.parse_source_pos(-3)
        except Exception:
            pos = "unknown"
        logger.info(f"{self.log_prefix()} STEP[{self.stepIndex}] {value} >>> @ {pos}")

    def log_prefix(self) -> str:
        """Return the log prefix identifying this session, e.g. 〖Session1a2b3c4d〗."""
        return f"〖Session{self.info['session_id']}〗"

    def parse_source_pos(self, level: int) -> str:
        """Describe one frame of the current call stack as ``"file:line in func"``.

        Args:
            level: Negative index into ``traceback.format_stack()`` as seen from
                inside this method: -1 is this method, -2 its caller, -3 the
                caller's caller, and so on.

        Returns:
            ``"<file>:<line> in <func>"`` with ``_SOURCE_ROOT_PREFIX`` removed
            from the path. If the frame text does not have the expected format,
            its stripped first line is returned instead. If ``level`` is out of
            range, ``"unknown"`` is returned.
        """
        # format_stack() must be called directly in this method so the level
        # offsets documented above stay valid.
        stack = traceback.format_stack()
        try:
            frame_text = stack[level]
        except IndexError:
            return "unknown"
        pos = frame_text.strip().split('\n')[0]
        match = self._FRAME_HEADER_RE.search(pos)
        if match:
            file = match.group(1).replace(self._SOURCE_ROOT_PREFIX, "")
            pos = f"{file}:{match.group(2)} in {match.group(3)}"
        return pos

    # NOTE(review): with the default level=-2 the reported position is the
    # log_* method's own line (stack[-2] is the log_* frame), not its caller;
    # pass level=-3 to report the calling code, as the ``desc`` setter does.
    # The default is kept unchanged for compatibility with existing callers.

    def log_info(self, msg: str, level: int = -2):
        """Log ``msg`` at INFO level with the session prefix and a stack position."""
        pos = self.parse_source_pos(level)
        logger.info(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_success(self, msg: str, level: int = -2):
        """Log ``msg`` at SUCCESS level with the session prefix and a stack position."""
        pos = self.parse_source_pos(level)
        logger.success(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_warning(self, msg: str, level: int = -2):
        """Log ``msg`` at WARNING level with the session prefix and a stack position."""
        pos = self.parse_source_pos(level)
        logger.warning(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_error(self, msg: str, level: int = -2):
        """Log ``msg`` at ERROR level with the session prefix and a stack position."""
        pos = self.parse_source_pos(level)
        logger.error(f"{self.log_prefix()} {msg} >>> @ {pos}")

    def log_exception(self, msg: str, level: int = -2):
        """Log ``msg`` with the active exception's traceback (call inside ``except``)."""
        pos = self.parse_source_pos(level)
        logger.exception(f"{self.log_prefix()} {msg} >>> @ {pos}")
|
||
|
||
# Application-wide async engine. The connection URL and pool sizing come from
# settings; SQL echo is deliberately forced off (settings.database.echo is ignored).
engine_async = create_async_engine(
    get_settings().database.url,
    future=True,
    echo=False,
    pool_size=get_settings().database.pool_size,
    max_overflow=get_settings().database.max_overflow,
    # Check each pooled connection before handing it out, and recycle
    # connections older than one hour.
    pool_pre_ping=True,
    pool_recycle=3600,
)
|
||
from fastapi import HTTPException, Request
|
||
|
||
# Factory producing DrSession instances bound to engine_async.
# expire_on_commit=False keeps loaded attributes readable after commit, so
# attribute access does not trigger implicit IO (MissingGreenlet in async code).
AsyncSessionFactory = sessionmaker(
    class_=DrSession,
    bind=engine_async,
    autoflush=True,
    expire_on_commit=False,
)
|
||
|
||
async def get_session(request: Request = None):
    """FastAPI dependency that yields a DrSession for the duration of a request.

    The session title is set to ``"<METHOD> <path> - <client host>"``. When
    there is no request, the ``"无request"`` placeholders are used instead.
    If the code using the session raises, the error is logged, the session is
    rolled back, and the original exception is re-raised unchanged. The session
    is always closed.

    Args:
        request: The incoming request. FastAPI injects it because of the
            ``Request`` annotation. May be None when called directly.

    Yields:
        DrSession: A session created by ``AsyncSessionFactory``.
    """
    url = "无request"
    client_host = "无request"
    if request is not None:
        url = f"{request.method} {request.url.path}"
        # request.client is None when the ASGI server reports no peer address.
        client = request.client
        client_host = client.host if client is not None else "unknown"
    logger.debug(url)

    # Create the session through AsyncSessionFactory so its configuration
    # (notably expire_on_commit=False) applies. This avoids implicit IO on
    # attribute access, which in async code surfaces as MissingGreenlet errors.
    session: DrSession = AsyncSessionFactory()
    session.title = f"{url} - {client_host}"
    if request is not None:
        session.request = request

    try:
        yield session
    except Exception as e:
        err_msg = f"数据库 session 异常 >>> {e}"
        # Log with traceback first. level=-3 reports this function's position
        # instead of the log_exception helper's own line.
        session.log_exception(err_msg, level=-3)
        # Then a structured step line (step number + caller position) via desc.
        session.desc = f"EXCEPTION: {err_msg}"
        try:
            await session.rollback()
        except Exception as rollback_err:
            # A failed rollback must not mask the original error.
            session.log_error(f"数据库 session 回滚失败 >>> {rollback_err}", level=-3)
        # Re-raise the original exception unchanged (not converted to
        # HTTPException); main.py is expected to catch it.
        raise
    finally:
        session.desc = ""
        await session.close()
|