baoxiang/backend/app/routers/websocket.py
2025-12-17 13:19:55 +08:00

342 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
WebSocket路由 - 实时通信
"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Query
from typing import Dict, List
import json
import asyncio
from datetime import datetime
from ..core.security import verify_token
from ..utils.redis import get_pool_cache
from ..services.game_service import GameService
from ..services.countdown_service import countdown_manager
from ..core.database import SessionLocal
# FastAPI router holding the WebSocket endpoints registered below via @router.websocket.
router = APIRouter(tags=["websocket"])
# Connection manager
class ConnectionManager:
    """Registry of live WebSocket connections, keyed by streamer ID and user ID.

    A socket can be registered under a streamer ID, a user ID, or both.
    Sends are best-effort: a socket whose send raises is treated as dead and
    is dropped from the registry, so later sends stop retrying it.
    """

    def __init__(self):
        # streamer ID -> sockets subscribed to that streamer
        self.streamer_connections: Dict[int, List["WebSocket"]] = {}
        # user ID -> sockets opened for that user
        self.user_connections: Dict[int, List["WebSocket"]] = {}

    async def connect(self, websocket: "WebSocket", streamer_id: int = None, user_id: int = None):
        """Accept the WebSocket handshake and register the socket.

        Args:
            websocket: The incoming connection; ``accept()`` is awaited here.
            streamer_id: If given, register the socket under this streamer.
            user_id: If given, register the socket under this user.
        """
        await websocket.accept()
        # `is not None` so that an ID of 0 is still registered.
        if streamer_id is not None:
            self.streamer_connections.setdefault(streamer_id, []).append(websocket)
        if user_id is not None:
            self.user_connections.setdefault(user_id, []).append(websocket)

    def disconnect(self, websocket: "WebSocket", streamer_id: int = None, user_id: int = None):
        """Unregister the socket; unknown IDs or already-removed sockets are ignored."""
        if streamer_id is not None:
            self._discard(self.streamer_connections, streamer_id, websocket)
        if user_id is not None:
            self._discard(self.user_connections, user_id, websocket)

    async def send_to_streamer(self, streamer_id: int, message: dict):
        """Send ``message`` as JSON to every socket of ``streamer_id`` (best-effort)."""
        connections = self.streamer_connections.get(streamer_id)
        if connections:
            for dead in await self._send_all(connections, message):
                self._discard(self.streamer_connections, streamer_id, dead)

    async def send_to_user(self, user_id: int, message: dict):
        """Send ``message`` as JSON to every socket of ``user_id`` (best-effort)."""
        connections = self.user_connections.get(user_id)
        if connections:
            for dead in await self._send_all(connections, message):
                self._discard(self.user_connections, user_id, dead)

    async def broadcast_to_all(self, message: dict):
        """Broadcast ``message`` once to every registered socket."""
        targets = []
        seen = set()
        for registry in (self.streamer_connections, self.user_connections):
            for connections in registry.values():
                for connection in connections:
                    # A socket registered under several IDs still gets a single copy.
                    if id(connection) not in seen:
                        seen.add(id(connection))
                        targets.append(connection)
        for dead in await self._send_all(targets, message):
            self._discard_everywhere(dead)

    @staticmethod
    def _discard(registry, key, websocket):
        """Remove ``websocket`` from ``registry[key]``; drop the key once its list is empty."""
        connections = registry.get(key)
        if connections is None:
            return
        if websocket in connections:
            connections.remove(websocket)
        if not connections:
            # Avoid accumulating empty lists for every ID that ever connected.
            del registry[key]

    def _discard_everywhere(self, websocket):
        """Remove ``websocket`` from every key of both registries."""
        for registry in (self.streamer_connections, self.user_connections):
            for key in [k for k, conns in registry.items() if websocket in conns]:
                self._discard(registry, key, websocket)

    @staticmethod
    async def _send_all(connections, message):
        """Send ``message`` to each socket and return the sockets whose send raised.

        The message is serialized once. Iteration runs over a snapshot because
        other coroutines may connect/disconnect while a send is being awaited.
        An unserializable message is logged and nothing is sent.
        """
        try:
            payload = json.dumps(message, ensure_ascii=False)
        except (TypeError, ValueError) as e:
            print(f"WebSocket message is not JSON-serializable: {e}")
            return []
        failed = []
        for connection in list(connections):
            try:
                await connection.send_text(payload)
            except Exception as e:
                # Best-effort delivery: remember the socket so the caller can prune it.
                print(f"WebSocket send failed, dropping connection: {e}")
                failed.append(connection)
        return failed


manager = ConnectionManager()
@router.websocket("/ws/streamer/{streamer_id}")
async def websocket_endpoint_for_streamer(
    websocket: WebSocket,
    streamer_id: int,
    token: str = Query(...)
):
    """Streamer-side WebSocket that receives periodic ``pool_update`` messages.

    The connection is registered under ``streamer_id`` and a
    ``broadcast_pool_updates`` task is started for it. Incoming client
    messages are read (so disconnects are noticed) but otherwise ignored.
    On every exit path -- client disconnect, unexpected error, or cancellation
    of this handler -- the socket is unregistered and the task is cancelled.

    Args:
        websocket: The incoming WebSocket connection.
        streamer_id: Streamer whose chest pools are pushed to this socket.
        token: Auth token taken from the ``token`` query parameter.
    """
    # Reject invalid tokens before the handshake is accepted.
    # NOTE(review): per the ASGI spec a close-before-accept becomes an HTTP 403
    # rejection, so clients may never observe close code 4001 -- confirm the
    # frontend does not rely on it.
    # NOTE(review): the token is only checked for validity, not matched against
    # streamer_id, so any authenticated user can subscribe here -- confirm intended.
    user_id = verify_token(token)
    if not user_id:
        await websocket.close(code=4001)
        return

    await manager.connect(websocket, streamer_id=streamer_id)
    print(f"Streamer {streamer_id} WebSocket connected")
    broadcast_task = None
    try:
        # Start the periodic pool broadcast for this streamer.
        # NOTE(review): each connection starts its own task and every task sends
        # to all of the streamer's sockets, so N open sockets receive N copies of
        # each pool_update -- consider one shared task per streamer.
        broadcast_task = asyncio.create_task(broadcast_pool_updates(streamer_id))
        print(f"Started pool broadcast task for streamer {streamer_id}")
        while True:
            # Keep the connection open; client messages are currently ignored.
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        # `finally` rather than per-except cleanup so this also runs when the
        # handler itself is cancelled (e.g. on server shutdown).
        manager.disconnect(websocket, streamer_id=streamer_id)
        if broadcast_task is not None:
            broadcast_task.cancel()
@router.websocket("/ws/user/{user_id}")
async def websocket_endpoint_for_user(
    websocket: WebSocket,
    user_id: int,
    token: str = Query(...)
):
    """Per-user WebSocket; messages are pushed via ``manager.send_to_user``.

    The token's subject must equal the ``user_id`` path parameter, otherwise
    the connection is closed with code 4001. Incoming client messages are
    read (so disconnects are noticed) but otherwise ignored. The socket is
    unregistered on every exit path, including cancellation of this handler.

    Args:
        websocket: The incoming WebSocket connection.
        user_id: User whose updates this socket receives.
        token: Auth token taken from the ``token`` query parameter.
    """
    # Authenticate and authorize: the token must belong to this very user.
    token_user_id = verify_token(token)
    try:
        authorized = bool(token_user_id) and int(token_user_id) == user_id
    except (TypeError, ValueError):
        # A non-numeric token subject can never match an int path parameter.
        authorized = False
    if not authorized:
        await websocket.close(code=4001)
        return

    await manager.connect(websocket, user_id=user_id)
    try:
        while True:
            # Keep the connection open; client messages are currently ignored.
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        # Runs on disconnect, error, and cancellation alike.
        manager.disconnect(websocket, user_id=user_id)
async def _run_socketio_streamer_session(websocket: WebSocket, streamer_id: int) -> None:
    """Serve a streamer connection opened through ``/socket.io/`` until it ends."""
    await manager.connect(websocket, streamer_id=streamer_id)
    print(f"Streamer {streamer_id} WebSocket (socket.io) connected")
    broadcast_task = None
    try:
        # Start the periodic pool broadcast for this streamer.
        broadcast_task = asyncio.create_task(broadcast_pool_updates(streamer_id))
        print(f"Started pool broadcast task (socket.io) for streamer {streamer_id}")
        while True:
            # Keep the connection open; client messages are currently ignored.
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        # `finally` so cleanup also happens when this handler is cancelled.
        manager.disconnect(websocket, streamer_id=streamer_id)
        if broadcast_task is not None:
            broadcast_task.cancel()


async def _run_socketio_user_session(websocket: WebSocket, user_id: int) -> None:
    """Serve a user connection opened through ``/socket.io/`` until it ends."""
    await manager.connect(websocket, user_id=user_id)
    try:
        while True:
            # Keep the connection open; client messages are currently ignored.
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"WebSocket error: {e}")
    finally:
        manager.disconnect(websocket, user_id=user_id)


# Generic WebSocket endpoint that selects the channel from query parameters.
@router.websocket("/socket.io/")
async def websocket_endpoint(
    websocket: WebSocket,
    role: str = Query(...),
    id: int = Query(...),  # name is the public query parameter; kept despite shadowing builtin
    token: str = Query(...)
):
    """Query-parameter variant of the streamer/user WebSocket endpoints.

    Args:
        websocket: The incoming WebSocket connection.
        role: ``"streamer"`` or ``"user"``; anything else closes with code 4000.
        id: Streamer ID or user ID, depending on ``role``.
        token: Auth token. An invalid token closes with code 4001; for
            ``role="user"`` the token must also belong to ``id`` (same rule as
            ``/ws/user/{user_id}``), so one user cannot receive another user's
            updates.
    """
    token_user_id = verify_token(token)
    if not token_user_id:
        await websocket.close(code=4001)
        return

    if role == "streamer":
        await _run_socketio_streamer_session(websocket, id)
    elif role == "user":
        try:
            is_owner = int(token_user_id) == id
        except (TypeError, ValueError):
            is_owner = False
        if not is_owner:
            await websocket.close(code=4001)
            return
        await _run_socketio_user_session(websocket, id)
    else:
        await websocket.close(code=4000)
async def broadcast_pool_updates(streamer_id: int):
"""定期广播奖池更新(不包含倒计时)"""
print(f"Broadcast pool updates task started for streamer {streamer_id}")
while True:
try:
# 获取主播的所有宝箱
db = SessionLocal()
chests = GameService.get_active_chests(db, streamer_id)
for chest in chests:
# 获取最新奖池数据
pool_data = get_pool_cache(chest.id)
# 计算赔率 - 修改为只对获胜方抽水
total = pool_data["pool_a"] + pool_data["pool_b"]
if total > 0:
# 赔率计算改为:(自己奖池 + 对方奖池 * 0.9) / 自己奖池
odds_a = round((pool_data["pool_a"] + pool_data["pool_b"] * 0.9) / pool_data["pool_a"], 2) if pool_data["pool_a"] > 0 else 0
odds_b = round((pool_data["pool_b"] + pool_data["pool_a"] * 0.9) / pool_data["pool_b"], 2) if pool_data["pool_b"] > 0 else 0
else:
odds_a = odds_b = 0
# 发送奖池更新消息
pool_message = {
"type": "pool_update",
"chest_id": chest.id,
"pool_a": pool_data["pool_a"],
"pool_b": pool_data["pool_b"],
"odds_a": odds_a,
"odds_b": odds_b,
"total_bets": chest.total_bets
}
await manager.send_to_streamer(streamer_id, pool_message)
db.close()
# 每5秒广播一次奖池更新大幅降低频率节省资源
await asyncio.sleep(5)
except Exception as e:
print(f"Broadcast error: {e}")
await asyncio.sleep(1)
# ==================== 倒计时相关函数(已废弃) ====================
# 说明：前端已实现本地倒计时，后端通过SchedulerService自动封盘
# 以下函数不再使用,保留以供参考
# async def start_chests_countdown(streamer_id: int):
# """启动主播所有活跃宝箱的倒计时(已废弃)"""
# print(f"Starting countdown for streamer {streamer_id}")
# try:
# db = SessionLocal()
# chests = GameService.get_active_chests(db, streamer_id)
#
# print(f"Found {len(chests)} chests for streamer {streamer_id}")
#
# for chest in chests:
# print(f"Checking chest {chest.id}, status: {chest.status}")
# if chest.status == 0: # BETTING
# print(f"Starting countdown for chest {chest.id}")
# # 启动倒计时
# await countdown_manager.start_chest_countdown(
# chest,
# on_update=lambda cid, time_rem: send_countdown_update(cid, time_rem, streamer_id),
# on_expire=lambda cid: handle_countdown_expire(cid)
# )
#
# db.close()
# except Exception as e:
# print(f"Error starting countdown for streamer {streamer_id}: {e}")
# import traceback
# traceback.print_exc()
# async def stop_all_chests_countdown(streamer_id: int):
# """停止主播所有宝箱的倒计时(已废弃)"""
# try:
# db = SessionLocal()
# chests = GameService.get_active_chests(db, streamer_id)
#
# for chest in chests:
# await countdown_manager.stop_chest_countdown(chest.id)
#
# db.close()
# except Exception as e:
# print(f"Error stopping countdown for streamer {streamer_id}: {e}")
# async def send_countdown_update(chest_id: int, time_remaining: int, streamer_id: int):
# """发送倒计时更新消息(已废弃)"""
# print(f"Sending countdown update: chest {chest_id}, time: {time_remaining}")
# countdown_message = {
# "type": "countdown_update",
# "chest_id": chest_id,
# "time_remaining": time_remaining
# }
#
# # 向主播和所有用户广播倒计时更新
# await manager.send_to_streamer(streamer_id, countdown_message)
# await manager.broadcast_to_all(countdown_message)
# async def handle_countdown_expire(chest_id: int):
# """处理倒计时结束(已废弃)"""
# print(f"Chest {chest_id} countdown expired, locking...")
#
# try:
# # 更新数据库中的宝箱状态
# db = SessionLocal()
# updated_chest = GameService.lock_expired_chest(db, chest_id)
#
# if updated_chest:
# # 通知所有客户端宝箱状态变更
# status_message = {
# "type": "chest_status",
# "chest_id": chest_id,
# "status": updated_chest.status
# }
# await manager.broadcast_to_all(status_message)
#
# db.close()
# except Exception as e:
# print(f"Error locking expired chest {chest_id}: {e}")
# try:
# db.close()
# except:
# pass
async def notify_user_balance_update(user_id: int, new_balance: int):
    """Push a ``balance_update`` message to every open socket of a user.

    Args:
        user_id: Recipient user; sockets are looked up through ``manager``.
        new_balance: Value placed in the message's ``balance`` field.
    """
    await manager.send_to_user(
        user_id,
        {"type": "balance_update", "balance": new_balance},
    )