baoxiang/backend/app/routers/websocket.py

342 lines
12 KiB
Python
Raw Normal View History

2025-12-16 18:06:50 +08:00
"""
WebSocket路由 - 实时通信
"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Query
from typing import Dict, List
import json
import asyncio
from datetime import datetime
from ..core.security import verify_token
from ..utils.redis import get_pool_cache
from ..services.game_service import GameService
from ..services.countdown_service import countdown_manager
from ..core.database import SessionLocal
# Router on which this module's WebSocket endpoints are registered, tagged "websocket".
router = APIRouter(tags=["websocket"])
# Connection manager
class ConnectionManager:
    """Registry of accepted WebSocket connections, keyed by streamer ID and by user ID.

    Each ID maps to a list of sockets, so several connections can be
    registered under the same streamer or user. Sending is best-effort: a
    socket whose ``send_text`` raises is dropped from the registry instead of
    being retried on every later send.
    """

    def __init__(self):
        # streamer ID -> list of connections
        self.streamer_connections: Dict[int, List[WebSocket]] = {}
        # user ID -> list of connections
        self.user_connections: Dict[int, List[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, streamer_id: int = None, user_id: int = None):
        """Accept ``websocket`` and register it under the given ID(s).

        Args:
            websocket: Connection to accept and track.
            streamer_id: If not ``None``, register under this streamer ID.
            user_id: If not ``None``, register under this user ID.
        """
        await websocket.accept()
        # Compare against None so that an ID of 0 is still registered.
        if streamer_id is not None:
            self.streamer_connections.setdefault(streamer_id, []).append(websocket)
        if user_id is not None:
            self.user_connections.setdefault(user_id, []).append(websocket)

    def disconnect(self, websocket: WebSocket, streamer_id: int = None, user_id: int = None):
        """Forget ``websocket`` under the given ID(s); unknown sockets/IDs are ignored."""
        if streamer_id is not None:
            self._remove(self.streamer_connections, streamer_id, websocket)
        if user_id is not None:
            self._remove(self.user_connections, user_id, websocket)

    async def send_to_streamer(self, streamer_id: int, message: dict):
        """Send ``message`` as JSON text to every connection of ``streamer_id``."""
        await self._send_to_key(self.streamer_connections, streamer_id, message)

    async def send_to_user(self, user_id: int, message: dict):
        """Send ``message`` as JSON text to every connection of ``user_id``."""
        await self._send_to_key(self.user_connections, user_id, message)

    async def broadcast_to_all(self, message: dict):
        """Broadcast ``message`` to all connections, streamers first, then users."""
        # Snapshot (registry, key, socket) triples up front: the registries
        # may be mutated by connect()/disconnect() while we await each send.
        targets = []
        for registry in (self.streamer_connections, self.user_connections):
            for key, connections in registry.items():
                for connection in connections:
                    targets.append((registry, key, connection))
        if not targets:
            return

        payload = self._encode(message)
        if payload is None:
            return

        for registry, key, connection in targets:
            try:
                await connection.send_text(payload)
            except Exception:
                # A failed send means the socket is unusable; stop tracking it.
                self._remove(registry, key, connection)

    async def _send_to_key(self, registry: Dict[int, List[WebSocket]], key: int, message: dict):
        """Send ``message`` to every connection stored under ``registry[key]``."""
        connections = registry.get(key)
        if not connections:
            return

        payload = self._encode(message)
        if payload is None:
            return

        # Iterate over a copy: disconnect() can shrink the live list while a
        # send is awaited, which would make a direct iteration skip sockets.
        for connection in list(connections):
            try:
                await connection.send_text(payload)
            except Exception:
                # A failed send means the socket is unusable; stop tracking it.
                self._remove(registry, key, connection)

    @staticmethod
    def _encode(message: dict):
        """Return ``message`` as JSON text, or ``None`` if it cannot be serialized.

        Serialization problems are reported and swallowed so that sending
        stays best-effort and never raises into the caller.
        """
        try:
            return json.dumps(message, ensure_ascii=False)
        except (TypeError, ValueError) as exc:
            print(f"WebSocket message serialization error: {exc}")
            return None

    @staticmethod
    def _remove(registry: Dict[int, List[WebSocket]], key: int, websocket: WebSocket):
        """Remove ``websocket`` from ``registry[key]`` and drop the key once its list is empty."""
        connections = registry.get(key)
        if connections is None:
            return
        if websocket in connections:
            connections.remove(websocket)
        if not connections:
            # Avoid accumulating empty lists for IDs that no longer have connections.
            del registry[key]


# Module-level manager instance used by the endpoints and helpers in this module.
manager = ConnectionManager()
@router.websocket("/ws/streamer/{streamer_id}")
async def websocket_endpoint_for_streamer(
    websocket: WebSocket,
    streamer_id: int,
    token: str = Query(...)
):
    """WebSocket endpoint for a streamer's connection.

    Closes the socket with code 4001 when ``verify_token`` returns a falsy
    value. Otherwise registers the connection under ``streamer_id``, starts a
    ``broadcast_pool_updates`` task for that streamer, and keeps reading
    client messages until the connection ends.

    Args:
        websocket: The incoming connection.
        streamer_id: Streamer ID taken from the URL path.
        token: Auth token passed as a query parameter.
    """
    # Verify the token.
    # NOTE(review): the verified identity is not compared with streamer_id,
    # so any valid token can subscribe to any streamer — confirm this is intended.
    user_id = verify_token(token)
    if not user_id:
        await websocket.close(code=4001)
        return

    await manager.connect(websocket, streamer_id=streamer_id)
    print(f"Streamer {streamer_id} WebSocket connected")

    broadcast_task = None
    try:
        # Start the pool broadcast task (it publishes every 5 seconds).
        # NOTE(review): one task is started per connection, and each task sends
        # to all of the streamer's connections, so several simultaneous
        # connections for one streamer receive duplicate updates.
        broadcast_task = asyncio.create_task(
            broadcast_pool_updates(streamer_id)
        )
        print(f"Started pool broadcast task for streamer {streamer_id}")
        while True:
            # Keep the connection open; incoming messages are currently ignored.
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        # Runs on every exit path, including task cancellation
        # (asyncio.CancelledError is not an Exception subclass), so the
        # registry entry and the broadcast task are never leaked.
        manager.disconnect(websocket, streamer_id=streamer_id)
        if broadcast_task is not None:
            broadcast_task.cancel()
@router.websocket("/ws/user/{user_id}")
async def websocket_endpoint_for_user(
    websocket: WebSocket,
    user_id: int,
    token: str = Query(...)
):
    """WebSocket endpoint for a user's connection.

    The connection is accepted only when ``verify_token`` yields an identity
    that equals ``user_id``; otherwise the socket is closed with code 4001.
    Once registered, the handler keeps reading client messages until the
    connection ends.

    Args:
        websocket: The incoming connection.
        user_id: User ID taken from the URL path.
        token: Auth token passed as a query parameter.
    """
    # Verify the token and make sure it belongs to the requested user.
    token_user_id = verify_token(token)
    try:
        authorized = bool(token_user_id) and int(token_user_id) == user_id
    except (TypeError, ValueError):
        # A token identity that is not numeric cannot match an integer user ID.
        authorized = False
    if not authorized:
        await websocket.close(code=4001)
        return

    await manager.connect(websocket, user_id=user_id)
    try:
        while True:
            # Incoming user messages are currently ignored.
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        # Runs on every exit path, including task cancellation
        # (asyncio.CancelledError is not an Exception subclass).
        manager.disconnect(websocket, user_id=user_id)
# Generic WebSocket endpoint that selects the role through query parameters
@router.websocket("/socket.io/")
async def websocket_endpoint(
    websocket: WebSocket,
    role: str = Query(...),
    id: int = Query(...),
    token: str = Query(...)
):
    """Role-dispatching WebSocket endpoint driven by query parameters.

    Registered as a plain WebSocket route at ``/socket.io/``. The socket is
    closed with code 4001 when the token is rejected (or, for ``role=user``,
    when the token identity does not equal ``id``) and with code 4000 when
    ``role`` is neither ``"streamer"`` nor ``"user"``.

    Args:
        websocket: The incoming connection.
        role: ``"streamer"`` or ``"user"``.
        id: Streamer ID or user ID, depending on ``role``.
        token: Auth token passed as a query parameter.
    """
    # Verify the token.
    user_id = verify_token(token)
    if not user_id:
        await websocket.close(code=4001)
        return

    # Register the connection according to the requested role.
    if role == "streamer":
        streamer_id = int(id)
        await manager.connect(websocket, streamer_id=streamer_id)
        print(f"Streamer {streamer_id} WebSocket (socket.io) connected")

        broadcast_task = None
        try:
            # Start the pool broadcast task (it publishes every 5 seconds).
            broadcast_task = asyncio.create_task(
                broadcast_pool_updates(streamer_id)
            )
            print(f"Started pool broadcast task (socket.io) for streamer {streamer_id}")
            while True:
                # Incoming client messages are currently ignored.
                await websocket.receive_text()
        except WebSocketDisconnect:
            pass
        except Exception as e:
            print(f"WebSocket error: {e}")
        finally:
            # Runs on every exit path, including task cancellation
            # (asyncio.CancelledError is not an Exception subclass).
            manager.disconnect(websocket, streamer_id=streamer_id)
            if broadcast_task is not None:
                broadcast_task.cancel()
    elif role == "user":
        # The client-supplied id must match the verified token identity, the
        # same rule the /ws/user/{user_id} endpoint enforces; otherwise any
        # authenticated client could receive another user's messages.
        try:
            owns_id = int(user_id) == int(id)
        except (TypeError, ValueError):
            owns_id = False
        if not owns_id:
            await websocket.close(code=4001)
            return

        user_id = int(id)
        await manager.connect(websocket, user_id=user_id)
        try:
            while True:
                # Incoming user messages are currently ignored.
                await websocket.receive_text()
        except WebSocketDisconnect:
            pass
        except Exception as e:
            print(f"WebSocket error: {e}")
        finally:
            manager.disconnect(websocket, user_id=user_id)
    else:
        await websocket.close(code=4000)
async def broadcast_pool_updates(streamer_id: int):
    """Periodically push pool updates (without countdown data) to a streamer.

    Runs until the task is cancelled. Each cycle opens a database session,
    loads the streamer's active chests, reads each chest's pool values via
    ``get_pool_cache`` and sends one ``pool_update`` message per chest to the
    streamer's connections, then sleeps 5 seconds. If a cycle raises, the
    error is printed and the loop retries after 1 second.

    Args:
        streamer_id: Streamer whose chests are broadcast.
    """
    print(f"Broadcast pool updates task started for streamer {streamer_id}")
    while True:
        try:
            # Fetch all of the streamer's chests.
            db = SessionLocal()
            try:
                chests = GameService.get_active_chests(db, streamer_id)
                for chest in chests:
                    # Read the latest pool data.
                    pool_data = get_pool_cache(chest.id)
                    pool_a = pool_data["pool_a"]
                    pool_b = pool_data["pool_b"]
                    odds_a, odds_b = _calculate_pool_odds(pool_a, pool_b)

                    # Send the pool update message.
                    pool_message = {
                        "type": "pool_update",
                        "chest_id": chest.id,
                        "pool_a": pool_a,
                        "pool_b": pool_b,
                        "odds_a": odds_a,
                        "odds_b": odds_b,
                        "total_bets": chest.total_bets
                    }
                    await manager.send_to_streamer(streamer_id, pool_message)
            finally:
                # Close the session on every path; closing only after a
                # successful cycle leaked one session per failed cycle.
                db.close()

            # Broadcast every 5 seconds; the low frequency saves resources.
            await asyncio.sleep(5)
        except Exception as e:
            print(f"Broadcast error: {e}")
            await asyncio.sleep(1)


def _calculate_pool_odds(pool_a, pool_b, payout_ratio=0.9):
    """Return ``(odds_a, odds_b)`` for the two pools, rounded to 2 decimals.

    ``payout_ratio`` of the combined pool is treated as distributable; each
    side's odds are the distributable amount divided by that side's pool.
    A side with a non-positive pool, or a non-positive total, yields 0.
    """
    total = pool_a + pool_b
    if not total > 0:
        return 0, 0
    distributable = total * payout_ratio
    odds_a = round(distributable / pool_a, 2) if pool_a > 0 else 0
    odds_b = round(distributable / pool_b, 2) if pool_b > 0 else 0
    return odds_a, odds_b
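# ==================== Countdown helpers (deprecated) ====================
# Note: the frontend now runs the countdown locally, and the backend closes betting automatically via SchedulerService.
# The functions below are no longer used; they are kept for reference only.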
# async def start_chests_countdown(streamer_id: int):
# """启动主播所有活跃宝箱的倒计时(已废弃)"""
# print(f"Starting countdown for streamer {streamer_id}")
# try:
# db = SessionLocal()
# chests = GameService.get_active_chests(db, streamer_id)
#
# print(f"Found {len(chests)} chests for streamer {streamer_id}")
#
# for chest in chests:
# print(f"Checking chest {chest.id}, status: {chest.status}")
# if chest.status == 0: # BETTING
# print(f"Starting countdown for chest {chest.id}")
# # 启动倒计时
# await countdown_manager.start_chest_countdown(
# chest,
# on_update=lambda cid, time_rem: send_countdown_update(cid, time_rem, streamer_id),
# on_expire=lambda cid: handle_countdown_expire(cid)
# )
#
# db.close()
# except Exception as e:
# print(f"Error starting countdown for streamer {streamer_id}: {e}")
# import traceback
# traceback.print_exc()
# async def stop_all_chests_countdown(streamer_id: int):
# """停止主播所有宝箱的倒计时(已废弃)"""
# try:
# db = SessionLocal()
# chests = GameService.get_active_chests(db, streamer_id)
#
# for chest in chests:
# await countdown_manager.stop_chest_countdown(chest.id)
#
# db.close()
# except Exception as e:
# print(f"Error stopping countdown for streamer {streamer_id}: {e}")
# async def send_countdown_update(chest_id: int, time_remaining: int, streamer_id: int):
# """发送倒计时更新消息(已废弃)"""
# print(f"Sending countdown update: chest {chest_id}, time: {time_remaining}")
# countdown_message = {
# "type": "countdown_update",
# "chest_id": chest_id,
# "time_remaining": time_remaining
# }
#
# # 向主播和所有用户广播倒计时更新
# await manager.send_to_streamer(streamer_id, countdown_message)
# await manager.broadcast_to_all(countdown_message)
# async def handle_countdown_expire(chest_id: int):
# """处理倒计时结束(已废弃)"""
# print(f"Chest {chest_id} countdown expired, locking...")
#
# try:
# # 更新数据库中的宝箱状态
# db = SessionLocal()
# updated_chest = GameService.lock_expired_chest(db, chest_id)
#
# if updated_chest:
# # 通知所有客户端宝箱状态变更
# status_message = {
# "type": "chest_status",
# "chest_id": chest_id,
# "status": updated_chest.status
# }
# await manager.broadcast_to_all(status_message)
#
# db.close()
# except Exception as e:
# print(f"Error locking expired chest {chest_id}: {e}")
# try:
# db.close()
# except:
# pass
async def notify_user_balance_update(user_id: int, new_balance: int):
    """Push a ``balance_update`` message carrying ``new_balance`` to the connections of ``user_id``."""
    await manager.send_to_user(
        user_id,
        {"type": "balance_update", "balance": new_balance},
    )