hyf-backend/th_agenter/core/context.py

142 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
HTTP request context management: provides access to the currently logged-in
user's information and the associated context token.
"""
from contextvars import ContextVar
from typing import Optional
import threading
from ..models.user import User
from loguru import logger
# Per-thread fallback slot. UserContext mirrors the current user here in
# addition to the ContextVar declared below.
_thread_local = threading.local()

# Primary storage for the current user: a plain dict of user fields, or None
# (the default) when nothing has been set in the active context.
current_user_context: ContextVar[Optional[dict]] = ContextVar(
    'current_user',
    default=None,
)
class UserContext:
    """Global accessor for the user attached to the current execution context.

    The user is kept as a plain dict with the keys ``id``, ``username``,
    ``email``, ``full_name`` and ``is_active`` instead of the model object
    itself. Every write goes to both ``current_user_context`` (a ContextVar)
    and ``_thread_local``; reads prefer the ContextVar and fall back to the
    thread-local copy.
    """

    @staticmethod
    def _snapshot(user: User) -> dict:
        """Copy the fields of *user* that are kept in the context into a dict."""
        return {
            'id': user.id,
            'username': user.username,
            'email': user.email,
            'full_name': user.full_name,
            'is_active': user.is_active,
        }

    @staticmethod
    def _store(user: User, canLog: bool):
        """Write *user* to both stores and return the ContextVar token."""
        if canLog:
            logger.info(f"[UserContext] - 设置用户上下文 {user.username} (ID: {user.id})")
        user_dict = UserContext._snapshot(user)
        token = current_user_context.set(user_dict)
        # Mirror into thread-local storage as the backup read path.
        _thread_local.current_user = user_dict
        if canLog:
            # Read the value back only when it is actually going to be logged.
            verify_user = current_user_context.get()
            logger.info(f"[UserContext] - 验证 - ContextVar 用户: {verify_user.get('username') if verify_user else None}")
        return token

    @staticmethod
    def _drop_thread_local() -> None:
        """Remove the thread-local copy of the user, if one is present."""
        try:
            del _thread_local.current_user
        except AttributeError:
            # Nothing was stored on this thread; clearing is a no-op.
            pass

    @staticmethod
    def set_current_user(user: User, canLog: bool = False) -> None:
        """Set the current user in the context.

        Args:
            user: Object whose ``id``, ``username``, ``email``, ``full_name``
                and ``is_active`` attributes are copied into the context.
            canLog: When true, log the assignment and a read-back check.
        """
        UserContext._store(user, canLog)

    @staticmethod
    def set_current_user_with_token(user: User, canLog: bool = False):
        """Set the current user and return the token needed to undo it.

        Args:
            user: Object whose ``id``, ``username``, ``email``, ``full_name``
                and ``is_active`` attributes are copied into the context.
            canLog: When true, log the assignment and a read-back check.

        Returns:
            The token produced by ``current_user_context.set``; pass it to
            ``reset_current_user_token`` to restore the previous value.
        """
        return UserContext._store(user, canLog)

    @staticmethod
    def reset_current_user_token(token):
        """Restore the ContextVar to its state before *token* was issued.

        The thread-local copy is removed even when the ContextVar reset
        raises, so a failed reset cannot leave a stale user behind for the
        fallback read in ``get_current_user``.

        Args:
            token: Token returned by ``set_current_user_with_token``.

        Raises:
            Whatever ``ContextVar.reset`` raises for an invalid token (per the
            contextvars documentation: ``ValueError`` for a token from another
            variable or context, ``RuntimeError`` for a token already used).
        """
        logger.info("[UserContext] - Resetting user context using token")
        try:
            current_user_context.reset(token)
        finally:
            UserContext._drop_thread_local()

    @staticmethod
    def get_current_user() -> Optional[dict]:
        """Return the current user dict, or ``None`` if no user is set.

        The ContextVar is consulted first; the thread-local copy is used only
        when the ContextVar holds no (truthy) value. An error is logged when
        neither store has a user.
        """
        user = current_user_context.get()
        if not user:
            # NOTE(review): thread-local state is shared by every task running
            # on the same thread. If this runs on an asyncio event loop, the
            # fallback may return a user set by a different request - confirm
            # whether the fallback is still required.
            user = getattr(_thread_local, 'current_user', None)
        if user:
            return user
        logger.error("[UserContext] - 上下文未找到当前用户 (neither ContextVar nor thread-local)")
        return None

    @staticmethod
    def get_current_user_id() -> Optional[int]:
        """Return the ``id`` of the current user, or ``None`` if unavailable.

        Any exception raised during the lookup is logged and turned into a
        ``None`` result (deliberate best-effort behaviour).
        """
        try:
            user = UserContext.get_current_user()
            return user.get('id') if user else None
        except Exception as e:
            logger.error(f"[UserContext] - Error getting current user ID: {e}")
            return None

    @staticmethod
    def clear_current_user(canLog: bool = False) -> None:
        """Clear the current user from both the ContextVar and thread-local.

        Args:
            canLog: When true, log that the context is being cleared.
        """
        if canLog:
            logger.info("[UserContext] - 清除当前用户上下文")
        current_user_context.set(None)
        UserContext._drop_thread_local()

    @staticmethod
    def require_current_user() -> dict:
        """Return the current user dict or fail with HTTP 401.

        Uses the same lookup as ``get_current_user`` (ContextVar first, then
        thread-local).

        Raises:
            fastapi.HTTPException: With status 401 when no user is found.
        """
        user = UserContext.get_current_user()
        if user is None:
            # Imported lazily, as in the original, so the module itself does
            # not import fastapi at load time.
            from fastapi import HTTPException, status
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="No authenticated user in context"
            )
        return user

    @staticmethod
    def require_current_user_id() -> int:
        """Return the ``id`` of the current user or fail with HTTP 401.

        Raises:
            fastapi.HTTPException: With status 401 when no user is found.
        """
        user = UserContext.require_current_user()
        return user.get('id')