"""WebSocket routes - real-time communication.

Endpoints:

* ``/ws/streamer/{streamer_id}`` - streamer channel; receives periodic pool updates.
* ``/ws/user/{user_id}`` - per-user channel (e.g. balance updates).
* ``/socket.io/`` - generic endpoint that picks one of the above via the
  ``role`` / ``id`` query parameters.
"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Query
from typing import Dict, List, Optional
import json
import asyncio
from datetime import datetime

from ..core.security import verify_token
from ..utils.redis import get_pool_cache
from ..services.game_service import GameService
from ..services.countdown_service import countdown_manager
from ..core.database import SessionLocal

router = APIRouter(tags=["websocket"])

# Seconds between two pool broadcasts for a streamer.
POOL_BROADCAST_INTERVAL = 5
# Seconds to wait before retrying after a failed broadcast iteration.
POOL_BROADCAST_RETRY_DELAY = 1
# Fraction of the combined pool treated as distributable when computing odds.
POOL_PAYOUT_RATIO = 0.9


class ConnectionManager:
    """Track open WebSocket connections by streamer ID and by user ID."""

    def __init__(self):
        # streamer ID -> connections subscribed to that streamer
        self.streamer_connections: Dict[int, List[WebSocket]] = {}
        # user ID -> connections belonging to that user
        self.user_connections: Dict[int, List[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, streamer_id: Optional[int] = None,
                      user_id: Optional[int] = None):
        """Accept *websocket* and register it under the given streamer and/or user ID.

        Falsy IDs (``None``/``0``) are not registered.
        """
        await websocket.accept()
        if streamer_id:
            self.streamer_connections.setdefault(streamer_id, []).append(websocket)
        if user_id:
            self.user_connections.setdefault(user_id, []).append(websocket)

    @staticmethod
    def _remove(registry: Dict[int, List[WebSocket]], key: int, websocket: WebSocket) -> None:
        """Remove *websocket* from ``registry[key]``, dropping the key once it is empty."""
        connections = registry.get(key)
        if connections is None:
            return
        if websocket in connections:
            connections.remove(websocket)
        if not connections:
            # Without this the dict keeps an empty list for every ID ever seen.
            del registry[key]

    def disconnect(self, websocket: WebSocket, streamer_id: Optional[int] = None,
                   user_id: Optional[int] = None):
        """Unregister *websocket* from the given streamer and/or user ID (no-op if absent)."""
        if streamer_id:
            self._remove(self.streamer_connections, streamer_id, websocket)
        if user_id:
            self._remove(self.user_connections, user_id, websocket)

    @staticmethod
    async def _send_all(connections: List[WebSocket], message: dict) -> None:
        """Best-effort send of *message* as JSON text to every connection in *connections*.

        Per-connection send failures are reported and skipped. Only ``Exception``
        is caught (not a bare ``except``), so ``asyncio.CancelledError`` still
        propagates and a broadcast task can actually be cancelled mid-send.
        """
        if not connections:
            return
        try:
            # Serialize once instead of once per connection.
            payload = json.dumps(message, ensure_ascii=False)
        except (TypeError, ValueError) as e:
            print(f"WebSocket message not serializable: {e}")
            return
        # Iterate over a snapshot: disconnect() may mutate the live list while we await.
        for connection in list(connections):
            try:
                await connection.send_text(payload)
            except Exception as e:
                # The socket's own endpoint removes it once its receive loop ends.
                print(f"WebSocket send error: {e}")

    async def send_to_streamer(self, streamer_id: int, message: dict):
        """Send *message* to every connection subscribed to *streamer_id*."""
        await self._send_all(self.streamer_connections.get(streamer_id, []), message)

    async def send_to_user(self, user_id: int, message: dict):
        """Send *message* to every connection belonging to *user_id*."""
        await self._send_all(self.user_connections.get(user_id, []), message)

    async def broadcast_to_all(self, message: dict):
        """Send *message* to every registered connection (streamer connections first)."""
        all_connections: List[WebSocket] = []
        for registry in (self.streamer_connections, self.user_connections):
            for connections in registry.values():
                all_connections.extend(connections)
        await self._send_all(all_connections, message)


manager = ConnectionManager()

# streamer ID -> the single pool-broadcast task shared by all of that streamer's
# connections. Spawning one task per connection would send every pool_update
# N times (and poll the DB N times) for a streamer with N open sockets.
_pool_broadcast_tasks: Dict[int, "asyncio.Task"] = {}


def _ensure_pool_broadcast(streamer_id: int) -> None:
    """Start the pool-broadcast task for *streamer_id* unless one is already running."""
    task = _pool_broadcast_tasks.get(streamer_id)
    if task is None or task.done():
        _pool_broadcast_tasks[streamer_id] = asyncio.create_task(
            broadcast_pool_updates(streamer_id)
        )
        print(f"Started pool broadcast task for streamer {streamer_id}")


def _release_pool_broadcast(streamer_id: int) -> None:
    """Cancel the pool-broadcast task once *streamer_id* has no open connections left."""
    if manager.streamer_connections.get(streamer_id):
        return  # other connections still need updates
    task = _pool_broadcast_tasks.pop(streamer_id, None)
    if task is not None:
        task.cancel()


def _is_token_owner(token_user_id, expected_id: int) -> bool:
    """Return True if the ID returned by ``verify_token`` equals *expected_id*.

    Falsy or non-numeric token subjects are treated as a mismatch instead of
    raising before the connection can be closed with an auth error code.
    """
    if not token_user_id:
        return False
    try:
        return int(token_user_id) == expected_id
    except (TypeError, ValueError):
        return False


async def _drain_until_disconnect(websocket: WebSocket) -> None:
    """Read and discard client messages until the connection ends.

    Incoming messages are currently ignored; the loop keeps the connection
    open and detects the client going away.
    """
    try:
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"WebSocket error: {e}")


async def _serve_streamer(websocket: WebSocket, streamer_id: int, label: str = "") -> None:
    """Run an authenticated streamer connection and always clean up afterwards.

    *label* is inserted into the connect log line (e.g. ``" (socket.io)"``).
    """
    await manager.connect(websocket, streamer_id=streamer_id)
    print(f"Streamer {streamer_id} WebSocket{label} connected")
    try:
        _ensure_pool_broadcast(streamer_id)
        await _drain_until_disconnect(websocket)
    finally:
        # Runs on disconnect, error, and cancellation alike.
        manager.disconnect(websocket, streamer_id=streamer_id)
        _release_pool_broadcast(streamer_id)


async def _serve_user(websocket: WebSocket, user_id: int) -> None:
    """Run an authenticated user connection and always unregister it afterwards."""
    await manager.connect(websocket, user_id=user_id)
    try:
        await _drain_until_disconnect(websocket)
    finally:
        manager.disconnect(websocket, user_id=user_id)


@router.websocket("/ws/streamer/{streamer_id}")
async def websocket_endpoint_for_streamer(
    websocket: WebSocket,
    streamer_id: int,
    token: str = Query(...)
):
    """Streamer channel: authenticate, then push pool updates until the client leaves.

    Closes with code 4001 if the token is invalid.
    NOTE(review): any valid token is accepted; the token's user is not checked
    against *streamer_id*. Confirm pool data is meant to be visible to everyone.
    """
    if not verify_token(token):
        await websocket.close(code=4001)
        return
    await _serve_streamer(websocket, streamer_id)


@router.websocket("/ws/user/{user_id}")
async def websocket_endpoint_for_user(
    websocket: WebSocket,
    user_id: int,
    token: str = Query(...)
):
    """Per-user channel; the token must belong to *user_id* (otherwise close 4001)."""
    if not _is_token_owner(verify_token(token), user_id):
        await websocket.close(code=4001)
        return
    await _serve_user(websocket, user_id)


# Generic endpoint that selects the channel via query parameters.
@router.websocket("/socket.io/")
async def websocket_endpoint(
    websocket: WebSocket,
    role: str = Query(...),
    id: int = Query(...),
    token: str = Query(...)
):
    """Dispatch to the streamer or user channel according to ``role``.

    Close codes: 4001 for an invalid token, or a user token that does not own
    ``id``; 4000 for an unknown role.
    """
    token_user_id = verify_token(token)
    if not token_user_id:
        await websocket.close(code=4001)
        return

    if role == "streamer":
        await _serve_streamer(websocket, id, " (socket.io)")
    elif role == "user":
        # Same ownership rule as /ws/user/{user_id}; without it any logged-in
        # user could subscribe to another user's balance updates.
        if not _is_token_owner(token_user_id, id):
            await websocket.close(code=4001)
            return
        await _serve_user(websocket, id)
    else:
        await websocket.close(code=4000)


def _compute_odds(pool_a, pool_b):
    """Return ``(odds_a, odds_b)``: distributable pool divided by each side's pool.

    A side with an empty pool, or an empty total, gets odds ``0``. Values are
    rounded to 2 decimals.
    """
    total = pool_a + pool_b
    if total <= 0:
        return 0, 0
    distributable = total * POOL_PAYOUT_RATIO
    odds_a = round(distributable / pool_a, 2) if pool_a > 0 else 0
    odds_b = round(distributable / pool_b, 2) if pool_b > 0 else 0
    return odds_a, odds_b


def _build_pool_messages(streamer_id: int) -> List[dict]:
    """Snapshot ``pool_update`` messages for all of the streamer's active chests.

    The DB session is always closed before returning, so it is never held
    across the awaits that send the messages, nor leaked on an error.
    """
    db = SessionLocal()
    try:
        messages = []
        for chest in GameService.get_active_chests(db, streamer_id):
            # Latest pool figures for this chest.
            pool_data = get_pool_cache(chest.id)
            odds_a, odds_b = _compute_odds(pool_data["pool_a"], pool_data["pool_b"])
            messages.append({
                "type": "pool_update",
                "chest_id": chest.id,
                "pool_a": pool_data["pool_a"],
                "pool_b": pool_data["pool_b"],
                "odds_a": odds_a,
                "odds_b": odds_b,
                "total_bets": chest.total_bets,
            })
        return messages
    finally:
        db.close()


async def broadcast_pool_updates(streamer_id: int):
    """Periodically push pool/odds updates for a streamer's active chests (no countdown).

    Runs until cancelled. A failing iteration is reported and retried after
    ``POOL_BROADCAST_RETRY_DELAY`` seconds. Cancellation is not swallowed:
    ``CancelledError`` is not an ``Exception`` subclass on Python 3.8+.
    NOTE(review): the DB and pool-cache lookups are synchronous calls inside
    the event loop.
    """
    print(f"Broadcast pool updates task started for streamer {streamer_id}")
    while True:
        try:
            for pool_message in _build_pool_messages(streamer_id):
                await manager.send_to_streamer(streamer_id, pool_message)
            await asyncio.sleep(POOL_BROADCAST_INTERVAL)
        except Exception as e:
            print(f"Broadcast error: {e}")
            await asyncio.sleep(POOL_BROADCAST_RETRY_DELAY)


# ==================== Countdown helpers (deprecated) ====================
# The frontend now runs the countdown locally, and the backend closes betting
# automatically via SchedulerService. The functions below are no longer used
# and are kept (commented out) for reference only.

# async def start_chests_countdown(streamer_id: int):
#     """Start countdowns for all of a streamer's active chests (deprecated)."""
#     print(f"Starting countdown for streamer {streamer_id}")
#     try:
#         db = SessionLocal()
#         chests = GameService.get_active_chests(db, streamer_id)
#
#         print(f"Found {len(chests)} chests for streamer {streamer_id}")
#
#         for chest in chests:
#             print(f"Checking chest {chest.id}, status: {chest.status}")
#             if chest.status == 0:  # BETTING
#                 print(f"Starting countdown for chest {chest.id}")
#                 # Start the countdown
#                 await countdown_manager.start_chest_countdown(
#                     chest,
#                     on_update=lambda cid, time_rem: send_countdown_update(cid, time_rem, streamer_id),
#                     on_expire=lambda cid: handle_countdown_expire(cid)
#                 )
#
#         db.close()
#     except Exception as e:
#         print(f"Error starting countdown for streamer {streamer_id}: {e}")
#         import traceback
#         traceback.print_exc()

# async def stop_all_chests_countdown(streamer_id: int):
#     """Stop the countdowns of all of a streamer's chests (deprecated)."""
#     try:
#         db = SessionLocal()
#         chests = GameService.get_active_chests(db, streamer_id)
#
#         for chest in chests:
#             await countdown_manager.stop_chest_countdown(chest.id)
#
#         db.close()
#     except Exception as e:
#         print(f"Error stopping countdown for streamer {streamer_id}: {e}")

# async def send_countdown_update(chest_id: int, time_remaining: int, streamer_id: int):
#     """Send a countdown update message (deprecated)."""
#     print(f"Sending countdown update: chest {chest_id}, time: {time_remaining}")
#     countdown_message = {
#         "type": "countdown_update",
#         "chest_id": chest_id,
#         "time_remaining": time_remaining
#     }
#
#     # Broadcast the countdown update to the streamer and to all users
#     await manager.send_to_streamer(streamer_id, countdown_message)
#     await manager.broadcast_to_all(countdown_message)

# async def handle_countdown_expire(chest_id: int):
#     """Handle an expired countdown (deprecated)."""
#     print(f"Chest {chest_id} countdown expired, locking...")
#
#     try:
#         # Update the chest status in the database
#         db = SessionLocal()
#         updated_chest = GameService.lock_expired_chest(db, chest_id)
#
#         if updated_chest:
#             # Notify all clients of the chest status change
#             status_message = {
#                 "type": "chest_status",
#                 "chest_id": chest_id,
#                 "status": updated_chest.status
#             }
#             await manager.broadcast_to_all(status_message)
#
#         db.close()
#     except Exception as e:
#         print(f"Error locking expired chest {chest_id}: {e}")
#         try:
#             db.close()
#         except:
#             pass


async def notify_user_balance_update(user_id: int, new_balance: int):
    """Push a ``balance_update`` message with *new_balance* to all of *user_id*'s connections."""
    message = {
        "type": "balance_update",
        "balance": new_balance
    }
    await manager.send_to_user(user_id, message)