395 lines
12 KiB
Python
395 lines
12 KiB
Python
"""
|
||
用户服务
|
||
"""
|
||
from sqlalchemy.orm import Session
|
||
from sqlalchemy import and_, or_, desc
|
||
from typing import List, Optional, Tuple
|
||
from ..models.user import User, Transaction, UserRole, UserStatus
|
||
from ..schemas.user import UserCreate, UserUpdate, ChangePasswordRequest
|
||
from ..core.security import get_password_hash, verify_password
|
||
from ..core.config import settings
|
||
|
||
# Game configuration namespace (settings.game); supplies fallback values used
# by UserService, e.g. NEW_USER_REWARD (sign-up reward) and DAILY_ALLOWANCE.
game_settings = settings.game
|
||
from ..utils.redis import redis_client
|
||
from datetime import datetime
|
||
from ..services.system_service import SystemService
|
||
from ..models.system import ConfigType
|
||
|
||
|
||
class UserService:
    """User service: registration, authentication, profile updates, admin
    queries, balance adjustments and the daily allowance ("低保").

    Every method is a static method that takes an explicit SQLAlchemy Session.
    """

    # Strong references to fire-and-forget notification tasks. The event loop
    # keeps only weak references to tasks, so without this set a pending task
    # could be garbage-collected before it finishes.
    _background_tasks: set = set()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _notify_balance_update(user_id: int, balance: int) -> None:
        """Push a best-effort balance notification to the user's websocket.

        Receives plain values instead of the ORM ``User`` so that no SQLAlchemy
        session is touched from another thread: after a commit the instance's
        attributes are expired, and reading them would trigger a lazy load.

        Failures are logged, never raised. Callers invoke this after the
        balance change has been committed.
        """
        import asyncio
        import logging
        import threading

        logger = logging.getLogger(__name__)

        try:
            # Local import avoids a circular import with the routers package.
            from ..routers.websocket import notify_user_balance_update
        except ImportError:
            logger.exception("Balance notification unavailable for user %s", user_id)
            return

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop is not None:
            # Called from async code: schedule on the current loop. The task
            # runs once the caller yields control.
            task = loop.create_task(notify_user_balance_update(user_id, balance))
            UserService._background_tasks.add(task)
            task.add_done_callback(UserService._background_tasks.discard)
            return

        # No running loop in this thread (e.g. a sync endpoint executed in a
        # worker thread): run the coroutine on a private loop in a daemon thread.
        # NOTE(review): websocket connections are normally bound to the server's
        # event loop; sending on them from a different loop may fail. Consider
        # capturing the server loop and using asyncio.run_coroutine_threadsafe.
        def _run() -> None:
            try:
                asyncio.run(notify_user_balance_update(user_id, balance))
            except Exception:
                logger.exception("Failed to notify balance update for user %s", user_id)

        threading.Thread(target=_run, daemon=True).start()

    @staticmethod
    def _apply_balance_delta(
        db: Session,
        user: User,
        amount: int,
        expected_version: Optional[int] = None,
    ) -> int:
        """Atomically add ``amount`` to ``user``'s balance and bump its version.

        Issues a single ``UPDATE ... SET balance = balance + :amount,
        version = version + 1``, so concurrent writers cannot overwrite each
        other. When ``expected_version`` is given, the UPDATE only matches
        while the stored version still equals it (optimistic locking).

        Does not commit; the caller owns the transaction.

        Returns:
            The balance stored in the database after the update.

        Raises:
            ValueError: if no row matched (version changed, or the user row
                no longer exists).
        """
        query = db.query(User).filter(User.id == user.id)
        if expected_version is not None:
            query = query.filter(User.version == expected_version)

        matched = query.update(
            {
                User.balance: User.balance + amount,
                User.version: User.version + 1,
            },
            synchronize_session=False,
        )
        if not matched:
            raise ValueError("用户数据已更新,请刷新后重试")

        # A bulk UPDATE bypasses the identity map; reload so `user` reflects
        # the row as it now stands in this transaction.
        db.refresh(user)
        return user.balance

    @staticmethod
    def _allowance_key(user_id: int, day) -> str:
        """Build the Redis key marking that ``user_id`` claimed the allowance on ``day``."""
        return f"daily_allowance:{user_id}:{day.isoformat()}"

    @staticmethod
    def _get_int_config(db: Session, key: str, default: int) -> int:
        """Read integer system config ``key``.

        Returns ``default`` when the config row is missing or its value cannot
        be converted to ``int`` (the bad value is logged).
        """
        config = SystemService.get_config(db, key)
        if not config:
            return default
        try:
            return int(SystemService.get_typed_value(config))
        except (TypeError, ValueError):
            import logging

            logging.getLogger(__name__).warning(
                "Invalid integer value for config %s; using default %s", key, default
            )
            return default

    # ------------------------------------------------------------------
    # Users
    # ------------------------------------------------------------------

    @staticmethod
    def create_user(db: Session, user_data: UserCreate) -> User:
        """Register a new user and record the sign-up reward transaction.

        The user row and its "注册奖励" transaction are committed together, so
        a failure can no longer leave a user without its reward record.

        Raises:
            ValueError: if the username or e-mail is already taken, or the
                password exceeds the 72-byte limit.
        """
        from sqlalchemy.exc import IntegrityError

        existing_user = db.query(User).filter(
            (User.username == user_data.username) | (User.email == user_data.email)
        ).first()
        if existing_user:
            raise ValueError("用户名或邮箱已存在")

        # The 72 limit matches bcrypt's input limit, which counts *bytes*
        # (presumably get_password_hash uses bcrypt - confirm). Measure the
        # UTF-8 encoding: 30 CJK characters are already 90 bytes. The front
        # end validates this too, but the back end must not rely on it.
        if len(user_data.password.encode("utf-8")) > 72:
            raise ValueError("密码长度不能超过72个字符")

        db_user = User(
            username=user_data.username,
            email=user_data.email,
            hashed_password=get_password_hash(user_data.password),
        )
        db.add(db_user)
        try:
            # Flush (not commit) to obtain the auto-increment id while keeping
            # the user and its reward transaction in one DB transaction.
            db.flush()
        except IntegrityError as exc:
            # A concurrent registration got past the existence check above.
            db.rollback()
            raise ValueError("用户名或邮箱已存在") from exc

        db.add(Transaction(
            user_id=db_user.id,
            type="注册奖励",
            amount=game_settings.NEW_USER_REWARD,
            # NOTE(review): balance is not set explicitly here; presumably the
            # model's default equals NEW_USER_REWARD - confirm.
            balance_after=db_user.balance,
            description="新用户注册奖励",
        ))
        db.commit()
        db.refresh(db_user)
        return db_user

    @staticmethod
    def authenticate_user(db: Session, username: str, password: str) -> Optional[User]:
        """Return the user whose username and password match, otherwise ``None``."""
        user = db.query(User).filter(User.username == username).first()
        if not user or not verify_password(password, user.hashed_password):
            return None
        return user

    @staticmethod
    def get_user_by_id(db: Session, user_id: int) -> Optional[User]:
        """Return the user with primary key ``user_id``, or ``None``."""
        return db.query(User).filter(User.id == user_id).first()

    @staticmethod
    def get_user_by_username(db: Session, username: str) -> Optional[User]:
        """Return the user with the given username, or ``None``."""
        return db.query(User).filter(User.username == username).first()

    @staticmethod
    def update_user(db: Session, user: User, user_data: UserUpdate) -> User:
        """Apply the explicitly-set fields of ``user_data`` to ``user`` and commit.

        NOTE(review): every field present in UserUpdate is written verbatim
        via setattr; keep that schema free of privileged fields.

        Raises:
            ValueError: if the new e-mail already belongs to another user.
        """
        update_data = user_data.dict(exclude_unset=True)

        # Changing the e-mail: make sure no other account already uses it.
        if "email" in update_data and update_data["email"] != user.email:
            existing = db.query(User).filter(
                User.email == update_data["email"],
                User.id != user.id,
            ).first()
            if existing:
                raise ValueError("邮箱已存在")

        for field, value in update_data.items():
            setattr(user, field, value)

        db.commit()
        db.refresh(user)
        return user

    @staticmethod
    def get_user_list(
        db: Session,
        skip: int = 0,
        limit: int = 20,
        search: Optional[str] = None,
        role: Optional[str] = None,
        status: Optional[str] = None,
        sort_by: str = "created_at",
        order: str = "desc"
    ) -> List[User]:
        """Return a filtered, sorted page of users (admin view).

        Args:
            search: substring matched (literally) against username, e-mail
                and nickname.
            role / status: exact-match filters, skipped when falsy.
            sort_by: name of a mapped User column; anything else falls back
                to ``created_at``.
            order: "desc" (case-insensitive) for descending, anything else
                ascending.
        """
        from sqlalchemy import inspect as sa_inspect

        query = db.query(User)

        if search:
            # autoescape=True escapes the LIKE wildcards % and _ in the search
            # text so they match literally instead of acting as patterns.
            query = query.filter(
                or_(
                    User.username.contains(search, autoescape=True),
                    User.email.contains(search, autoescape=True),
                    User.nickname.contains(search, autoescape=True),
                )
            )

        if role:
            query = query.filter(User.role == role)

        if status:
            query = query.filter(User.status == status)

        # Only allow ordering by mapped columns; a bare getattr() would also
        # accept relationships, methods or any other class attribute name.
        if sort_by in sa_inspect(User).column_attrs:
            sort_field = getattr(User, sort_by)
        else:
            sort_field = User.created_at

        if order.lower() == "desc":
            query = query.order_by(desc(sort_field))
        else:
            query = query.order_by(sort_field)

        return query.offset(skip).limit(limit).all()

    # ------------------------------------------------------------------
    # Balance & transactions
    # ------------------------------------------------------------------

    @staticmethod
    def adjust_balance_with_version(
        db: Session,
        user: User,
        amount: int,
        description: str,
        admin_user: User,
        expected_version: Optional[int] = None
    ) -> bool:
        """Adjust a user's balance on behalf of an admin (optimistic locking).

        The balance change and its "管理员调整" transaction are committed
        together. The user is then notified, and the balance monitor is
        invoked with the old and new balances.

        Args:
            amount: signed delta added to the balance.
            expected_version: when given, the adjustment only applies if the
                stored version still equals it.

        Returns:
            Always True on success.

        Raises:
            ValueError: if the stored version no longer matches.
        """
        user_id = user.id
        new_balance = UserService._apply_balance_delta(db, user, amount, expected_version)
        # Derived from the committed row rather than a possibly stale in-memory value.
        old_balance = new_balance - amount

        UserService.create_transaction(
            db=db,
            user_id=user_id,
            transaction_type="管理员调整",
            amount=amount,
            balance_after=new_balance,
            description=f"{description} (操作人: {admin_user.username})",
            commit=False,
        )
        db.commit()

        # Notify only after the change is durable.
        UserService._notify_balance_update(user_id, new_balance)

        # Local import avoids a circular import between service modules.
        from ..services.balance_monitor_service import BalanceMonitorService
        BalanceMonitorService.on_balance_changed(db, user, old_balance, new_balance)

        return True

    @staticmethod
    def create_transaction(
        db: Session,
        user_id: int,
        transaction_type: str,
        amount: int,
        balance_after: int,
        description: str,
        related_id: Optional[int] = None,
        *,
        commit: bool = True,
    ) -> Transaction:
        """Insert a transaction record.

        Args:
            commit: when True (default), commit and refresh the record. When
                False, only flush it (so it gets an id) and leave the commit
                to the caller, allowing it to be batched with other changes.
        """
        transaction = Transaction(
            user_id=user_id,
            type=transaction_type,
            amount=amount,
            balance_after=balance_after,
            description=description,
            related_id=related_id
        )
        db.add(transaction)
        if commit:
            db.commit()
            db.refresh(transaction)
        else:
            db.flush()
        return transaction

    @staticmethod
    def get_user_transactions(
        db: Session,
        user_id: int,
        skip: int = 0,
        limit: int = 50
    ) -> List[Transaction]:
        """Return a page of a user's transactions, newest first."""
        return db.query(Transaction).filter(
            Transaction.user_id == user_id
        ).order_by(desc(Transaction.created_at)).offset(skip).limit(limit).all()

    @staticmethod
    def get_user_transactions_paginated(
        db: Session,
        user_id: int,
        skip: int = 0,
        limit: int = 20
    ) -> List[Transaction]:
        """Return a page of a user's transactions, newest first (admin view).

        Same query as ``get_user_transactions``, with a smaller default page size.
        """
        return UserService.get_user_transactions(db, user_id, skip, limit)

    # ------------------------------------------------------------------
    # Daily allowance
    # ------------------------------------------------------------------

    @staticmethod
    def claim_daily_allowance(db: Session, user: User) -> bool:
        """Grant today's allowance ("低保") to ``user``, at most once per local calendar day.

        Returns:
            True if the allowance was granted, False if it was already
            claimed today.
        """
        user_id = user.id
        allowance_key = UserService._allowance_key(user_id, datetime.now().date())

        # Reserve the claim atomically (SET NX) before paying out: a separate
        # EXISTS check followed by SETEX lets two concurrent requests both pay.
        # The 24h TTL only bounds the key's lifetime; the key is date-scoped.
        # NOTE(review): assumes redis_client exposes the redis-py
        # set(name, value, ex=..., nx=...) and delete() API - confirm.
        if not redis_client.set(allowance_key, "claimed", ex=86400, nx=True):
            return False

        try:
            daily_allowance = UserService.get_daily_allowance(db)
            new_balance = UserService._apply_balance_delta(db, user, daily_allowance)
            UserService.create_transaction(
                db=db,
                user_id=user_id,
                transaction_type="低保",
                amount=daily_allowance,
                balance_after=new_balance,
                description="每日低保",
                commit=False,
            )
            db.commit()
        except Exception:
            # Payout failed: undo the DB changes and release the reservation
            # so the user can retry.
            db.rollback()
            redis_client.delete(allowance_key)
            raise

        UserService._notify_balance_update(user_id, new_balance)
        return True

    @staticmethod
    def get_rich_ranking(db: Session, limit: int = 10) -> List[User]:
        """Return up to ``limit`` active users ordered by balance, richest first."""
        return db.query(User).filter(
            User.is_active == True  # noqa: E712 - builds a SQL expression, not a Python comparison
        ).order_by(desc(User.balance)).limit(limit).all()

    @staticmethod
    def get_next_allowance_time(db: Session, user: User) -> dict:
        """Describe when ``user`` can next claim the daily allowance.

        Returns:
            dict with ``can_claim`` (bool), ``next_claim_time`` (ISO-8601
            string: now if claimable, otherwise next local midnight) and
            ``daily_allowance`` (int).
        """
        from datetime import timedelta

        now = datetime.now()
        today = now.date()
        allowance_key = UserService._allowance_key(user.id, today)

        daily_allowance = UserService.get_daily_allowance(db)

        if redis_client.exists(allowance_key):
            # Already claimed today: next chance is tomorrow at midnight.
            next_claim_time = datetime.combine(today + timedelta(days=1), datetime.min.time())
            return {
                "can_claim": False,
                "next_claim_time": next_claim_time.isoformat(),
                "daily_allowance": daily_allowance
            }

        return {
            "can_claim": True,
            "next_claim_time": now.isoformat(),
            "daily_allowance": daily_allowance
        }

    # ------------------------------------------------------------------
    # Configuration lookups
    # ------------------------------------------------------------------

    @staticmethod
    def get_daily_allowance(db: Session) -> int:
        """Daily allowance amount.

        Read from system config "GAME_DAILY_ALLOWANCE"; falls back to
        ``game_settings.DAILY_ALLOWANCE``.
        """
        return UserService._get_int_config(
            db, "GAME_DAILY_ALLOWANCE", game_settings.DAILY_ALLOWANCE
        )

    @staticmethod
    def get_balance_zero_reward(db: Session) -> int:
        """Amount paid automatically when a balance reaches zero.

        Read from system config "BALANCE_ZERO_REWARD_AMOUNT"; falls back to 10000.
        """
        return UserService._get_int_config(db, "BALANCE_ZERO_REWARD_AMOUNT", 10000)

    @staticmethod
    def get_allowance_reset_time(db: Session) -> str:
        """Daily allowance reset time as stored in config "ALLOWANCE_RESET_TIME" (default "00:00").

        NOTE(review): claim_daily_allowance keys claims by local calendar date,
        so this value is not applied there - confirm that is intended.
        """
        config = SystemService.get_config(db, "ALLOWANCE_RESET_TIME")
        if config:
            return config.config_value
        return "00:00"
|