hyf-backend/th_agenter/services/auth.py

135 lines
4.9 KiB
Python
Raw Normal View History

2026-01-21 13:45:39 +08:00
"""Authentication service."""
from loguru import logger
from typing import Optional
from datetime import datetime, timedelta, timezone
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
from sqlalchemy import select
import bcrypt
import jwt
from ..core.config import settings
from ..db.database import get_session
from ..models.user import User
# FastAPI security scheme that extracts the "Authorization: Bearer <token>"
# header; injected into get_current_user below via Depends(security).
# With HTTPBearer's default settings (auto_error enabled), a request lacking a
# bearer credential is rejected by FastAPI before get_current_user runs.
security = HTTPBearer()
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    session: Session = Depends(get_session)
) -> User:
    """Resolve the authenticated ``User`` from the bearer token (FastAPI dependency).

    Steps:
        1. Decode and validate the JWT via ``AuthService.verify_token``.
        2. Read the username from the ``sub`` claim.
        3. Load the matching ``User`` row by username.
        4. Register the user with ``UserContext`` (``canLog=True``).

    On any failure, a diagnostic message is stored in ``session.desc`` and a
    401 ``HTTPException`` carrying a ``WWW-Authenticate: Bearer`` header is
    raised.

    Returns:
        The ``User`` whose username matches the token's ``sub`` claim.

    Raises:
        HTTPException: 401 if the token is invalid, has no ``sub`` claim, or
            names a user that does not exist.
    """
    # Imported lazily, presumably to avoid a circular import with core.context
    # (TODO confirm).
    from ..core.context import UserContext

    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    raw_token = credentials.credentials
    claims = AuthService.verify_token(raw_token)
    if claims is None:
        session.desc = f"ERROR: 令牌验证失败 - 令牌: {raw_token[:50]}..."
        raise unauthorized

    subject: str = claims.get("sub")
    if subject is None:
        session.desc = "ERROR: 令牌中没有用户名"
        raise unauthorized

    query = select(User).where(User.username == subject)
    result = await session.execute(query)
    user = result.scalar_one_or_none()
    if user is None:
        session.desc = f"ERROR: 数据库中未找到用户 {subject}"
        raise unauthorized

    UserContext.set_current_user(user, canLog=True)
    return user
def get_current_active_user(current_user: User = Depends(get_current_user)) -> User:
    """Return the authenticated user, refusing accounts that are not active (FastAPI dependency).

    Args:
        current_user: The user resolved by ``get_current_user``.

    Returns:
        ``current_user`` unchanged when its ``is_active`` flag is truthy.

    Raises:
        HTTPException: 400 "Inactive user" when ``is_active`` is falsy.
    """
    if current_user.is_active:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="Inactive user",
    )
class AuthService:
    """Authentication service: credential checks, password hashing and JWT handling.

    Every method is a ``staticmethod``, so the class is used as a namespace.
    The module-level dependencies ``get_current_user`` and
    ``get_current_active_user`` are also exposed as class attributes,
    presumably so callers can reach them through ``AuthService``.
    """

    get_current_user = get_current_user
    get_current_active_user = get_current_active_user

    @staticmethod
    async def _authenticate(session: Session, stmt, password: str) -> Optional[User]:
        """Load a single user with ``stmt`` and check ``password`` against it.

        Shared by the username- and email-based authenticators.

        Returns:
            The matching ``User`` when a row is found and the password
            verifies; ``None`` otherwise. (``scalar_one_or_none`` raises if
            the query matches more than one row.)
        """
        user = (await session.execute(stmt)).scalar_one_or_none()
        if user is None:
            return None
        if not AuthService.verify_password(password, user.hashed_password):
            return None
        return user

    @staticmethod
    async def authenticate_user_by_email(session: Session, email: str, password: str) -> Optional[User]:
        """Authenticate a user by email address and plain-text password.

        Returns:
            The ``User`` on success; ``None`` if no user has this email or the
            password does not match.
        """
        session.desc = f"根据邮箱 {email} 验证用户密码"
        stmt = select(User).where(User.email == email)
        return await AuthService._authenticate(session, stmt, password)

    @staticmethod
    async def authenticate_user(session: Session, username: str, password: str) -> Optional[User]:
        """Authenticate a user by username and plain-text password.

        Returns:
            The ``User`` on success; ``None`` if no user has this username or
            the password does not match.
        """
        session.desc = f"根据用户名 {username} 验证用户密码"
        stmt = select(User).where(User.username == username)
        return await AuthService._authenticate(session, stmt, password)

    @staticmethod
    async def create_access_token(session: Session, data: dict, expires_delta: Optional[timedelta] = None) -> str:
        """Create a signed JWT access token.

        Declared ``async`` although it awaits nothing; kept that way so
        existing callers that ``await`` it keep working.

        Args:
            session: Only used to record a description in ``session.desc``;
                no query is issued.
            data: Claims to encode. The dict is copied, so the caller's object
                is not mutated. An ``exp`` claim is added, overriding any
                ``exp`` already present in ``data``.
            expires_delta: Token lifetime. When ``None``, the lifetime is
                ``settings.security.access_token_expire_minutes`` minutes.
                Any explicit value, including ``timedelta(0)``, is honored.

        Returns:
            The encoded token, signed with ``settings.security.secret_key``
            using ``settings.security.algorithm``.
        """
        session.desc = f"创建 JWT 访问 token - 数据: {data}"
        to_encode = data.copy()
        # Compare against None explicitly: timedelta(0) is falsy, and a plain
        # truthiness check would silently replace it with the default lifetime.
        if expires_delta is None:
            expires_delta = timedelta(minutes=settings.security.access_token_expire_minutes)
        to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
        return jwt.encode(
            to_encode,
            settings.security.secret_key,
            algorithm=settings.security.algorithm,
        )

    @staticmethod
    def verify_password(plain_password: str, hashed_password: Optional[str]) -> bool:
        """Check a plain-text password against a stored bcrypt hash.

        Returns:
            ``True`` only if bcrypt confirms the match. ``False`` when the
            password is wrong, when no hash is stored (``None``/empty), or
            when bcrypt rejects the stored value. The last two cases return
            ``False`` instead of raising, so callers treat them as a failed
            login.
        """
        if not hashed_password:
            return False
        try:
            # Use the bcrypt library directly to verify the password.
            return bcrypt.checkpw(plain_password.encode('utf-8'), hashed_password.encode('utf-8'))
        except ValueError as exc:
            # bcrypt raises ValueError for inputs it cannot process,
            # e.g. "Invalid salt" when the stored hash is not a bcrypt hash.
            logger.warning(f"Password verification error: {exc}")
            return False

    @staticmethod
    def get_password_hash(password: str) -> str:
        """Hash ``password`` with bcrypt using a freshly generated salt.

        Returns:
            The bcrypt hash decoded to a UTF-8 ``str``, suitable for storing
            in ``User.hashed_password``.
        """
        # Use the bcrypt library directly for hashing.
        hashed_bytes = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt())
        return hashed_bytes.decode('utf-8')

    @staticmethod
    def verify_token(token: str) -> Optional[dict]:
        """Decode and validate a JWT.

        The token is checked against ``settings.security.secret_key``, and
        only ``settings.security.algorithm`` is accepted.

        Returns:
            The decoded claims dict, or ``None`` if PyJWT rejects the token
            (any ``jwt.PyJWTError``, e.g. a bad signature or expiry).
        """
        try:
            return jwt.decode(
                token,
                settings.security.secret_key,
                algorithms=[settings.security.algorithm],
            )
        except jwt.PyJWTError as e:
            # NOTE(review): the first 50 characters of the token are logged;
            # consider whether that is acceptable for your log retention.
            logger.error(f"Token verification failed: {e}")
            logger.error(f"Token: {token[:50]}...")
            return None