342 lines
12 KiB
Python
342 lines
12 KiB
Python
"""
|
||
WebSocket路由 - 实时通信
|
||
"""
|
||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Query
|
||
from typing import Dict, List
|
||
import json
|
||
import asyncio
|
||
from datetime import datetime
|
||
from ..core.security import verify_token
|
||
from ..utils.redis import get_pool_cache
|
||
from ..services.game_service import GameService
|
||
from ..services.countdown_service import countdown_manager
|
||
from ..core.database import SessionLocal
|
||
|
||
# Router holding the WebSocket endpoints declared in this module via
# @router.websocket(...); tagged "websocket".
router = APIRouter(tags=["websocket"])
|
||
|
||
# Connection manager
class ConnectionManager:
    """Registry of open WebSocket connections, indexed by streamer and by user.

    Sending is best-effort: a failure on one socket is ignored so the
    remaining sockets still receive the message.
    """

    def __init__(self):
        # streamer ID -> open connections for that streamer
        self.streamer_connections: Dict[int, List[WebSocket]] = {}
        # user ID -> open connections for that user
        self.user_connections: Dict[int, List[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, streamer_id: int = None, user_id: int = None):
        """Accept the WebSocket handshake and register the socket.

        Args:
            websocket: connection to accept and track.
            streamer_id: if not None, register the socket under this streamer.
            user_id: if not None, register the socket under this user.
        """
        await websocket.accept()
        if streamer_id is not None:
            self.streamer_connections.setdefault(streamer_id, []).append(websocket)
        if user_id is not None:
            self.user_connections.setdefault(user_id, []).append(websocket)

    def disconnect(self, websocket: WebSocket, streamer_id: int = None, user_id: int = None):
        """Unregister ``websocket``; unknown sockets or IDs are ignored.

        An ID whose connection list becomes empty is removed, so the indexes
        do not grow without bound as clients come and go.
        """
        if streamer_id is not None:
            self._discard(self.streamer_connections, streamer_id, websocket)
        if user_id is not None:
            self._discard(self.user_connections, user_id, websocket)

    @staticmethod
    def _discard(index: Dict[int, List[WebSocket]], key: int, websocket: WebSocket):
        """Remove ``websocket`` from ``index[key]`` and drop the key once empty."""
        connections = index.get(key)
        if connections is None:
            return
        if websocket in connections:
            connections.remove(websocket)
        if not connections:
            del index[key]

    @staticmethod
    async def _send_best_effort(connections: List[WebSocket], message: dict):
        """Send ``message`` as JSON text to each connection.

        Per-socket send failures are ignored. Raises TypeError if ``message``
        is not JSON-serialisable (a programming error, not a delivery failure).
        """
        if not connections:
            return
        payload = json.dumps(message, ensure_ascii=False)
        # Iterate over a snapshot: disconnect() may run while we are awaiting
        # a send and mutate the live list, which would skip sockets.
        for connection in list(connections):
            try:
                await connection.send_text(payload)
            except asyncio.CancelledError:
                # Never swallow cancellation, otherwise a cancelled sender
                # task (e.g. the pool broadcaster) keeps running forever.
                raise
            except Exception:
                # Best-effort delivery: a dead socket must not block the rest.
                pass

    async def send_to_streamer(self, streamer_id: int, message: dict):
        """Send ``message`` to every connection registered for ``streamer_id``."""
        await self._send_best_effort(self.streamer_connections.get(streamer_id, []), message)

    async def send_to_user(self, user_id: int, message: dict):
        """Send ``message`` to every connection registered for ``user_id``."""
        await self._send_best_effort(self.user_connections.get(user_id, []), message)

    async def broadcast_to_all(self, message: dict):
        """Broadcast ``message`` to every connection in both indexes."""
        all_connections: List[WebSocket] = []
        for index in (self.streamer_connections, self.user_connections):
            for connections in index.values():
                all_connections.extend(connections)
        await self._send_best_effort(all_connections, message)
|
||
|
||
|
||
# Single module-level ConnectionManager used by every endpoint and helper in
# this module, so they all share one connection registry.
manager = ConnectionManager()
|
||
|
||
|
||
@router.websocket("/ws/streamer/{streamer_id}")
|
||
async def websocket_endpoint_for_streamer(
|
||
websocket: WebSocket,
|
||
streamer_id: int,
|
||
token: str = Query(...)
|
||
):
|
||
# 验证令牌
|
||
user_id = verify_token(token)
|
||
if not user_id:
|
||
await websocket.close(code=4001)
|
||
return
|
||
|
||
await manager.connect(websocket, streamer_id=streamer_id)
|
||
|
||
print(f"Streamer {streamer_id} WebSocket connected")
|
||
|
||
try:
|
||
# 启动奖池广播任务(频率降低到5秒)
|
||
broadcast_task = asyncio.create_task(
|
||
broadcast_pool_updates(streamer_id)
|
||
)
|
||
print(f"Started pool broadcast task for streamer {streamer_id}")
|
||
|
||
while True:
|
||
# 保持连接
|
||
data = await websocket.receive_text()
|
||
# 可以处理来自客户端的消息
|
||
|
||
except WebSocketDisconnect:
|
||
manager.disconnect(websocket, streamer_id=streamer_id)
|
||
if 'broadcast_task' in locals():
|
||
broadcast_task.cancel()
|
||
except Exception as e:
|
||
print(f"WebSocket error: {e}")
|
||
manager.disconnect(websocket, streamer_id=streamer_id)
|
||
if 'broadcast_task' in locals():
|
||
broadcast_task.cancel()
|
||
|
||
|
||
@router.websocket("/ws/user/{user_id}")
|
||
async def websocket_endpoint_for_user(
|
||
websocket: WebSocket,
|
||
user_id: int,
|
||
token: str = Query(...)
|
||
):
|
||
# 验证令牌
|
||
token_user_id = verify_token(token)
|
||
if not token_user_id or int(token_user_id) != user_id:
|
||
await websocket.close(code=4001)
|
||
return
|
||
|
||
await manager.connect(websocket, user_id=user_id)
|
||
|
||
try:
|
||
while True:
|
||
data = await websocket.receive_text()
|
||
# 处理用户消息
|
||
|
||
except WebSocketDisconnect:
|
||
manager.disconnect(websocket, user_id=user_id)
|
||
except Exception as e:
|
||
print(f"WebSocket error: {e}")
|
||
manager.disconnect(websocket, user_id=user_id)
|
||
|
||
|
||
# Generic WebSocket endpoint that takes the role and id from query parameters
@router.websocket("/socket.io/")
async def websocket_endpoint(
    websocket: WebSocket,
    role: str = Query(...),
    id: int = Query(...),
    token: str = Query(...)
):
    """Query-parameter variant of the streamer/user WebSocket endpoints.

    Query parameters:
        role: "streamer" or "user"; anything else closes with code 4000.
        id: streamer id or user id, depending on ``role``. (The name shadows
            the ``id`` builtin but is the public query-parameter name.)
        token: auth token; invalid tokens close with code 4001.

    For ``role == "user"``, the token must belong to ``id``. This matches
    ``/ws/user/{user_id}``; previously any valid token could subscribe to
    another user's messages. For ``role == "streamer"``, a
    ``broadcast_pool_updates`` task runs for the lifetime of the socket.
    """
    # Validate the token
    token_user_id = verify_token(token)
    if not token_user_id:
        await websocket.close(code=4001)
        return

    if role == "streamer":
        streamer_id = id
        await manager.connect(websocket, streamer_id=streamer_id)
        print(f"Streamer {streamer_id} WebSocket (socket.io) connected")

        broadcast_task = None
        try:
            broadcast_task = asyncio.create_task(broadcast_pool_updates(streamer_id))
            print(f"Started pool broadcast task (socket.io) for streamer {streamer_id}")

            while True:
                # Client messages are not handled yet.
                await websocket.receive_text()

        except WebSocketDisconnect:
            pass
        except Exception as e:
            print(f"WebSocket error: {e}")
        finally:
            # Runs on every exit path, including cancellation at shutdown.
            manager.disconnect(websocket, streamer_id=streamer_id)
            if broadcast_task is not None:
                broadcast_task.cancel()

    elif role == "user":
        # A user may only subscribe to their own channel.
        try:
            token_matches = int(token_user_id) == id
        except (TypeError, ValueError):
            token_matches = False
        if not token_matches:
            await websocket.close(code=4001)
            return

        user_id = id
        await manager.connect(websocket, user_id=user_id)

        try:
            while True:
                # Client messages are not handled yet.
                await websocket.receive_text()
        except WebSocketDisconnect:
            pass
        except Exception as e:
            print(f"WebSocket error: {e}")
        finally:
            manager.disconnect(websocket, user_id=user_id)

    else:
        await websocket.close(code=4000)
|
||
|
||
|
||
def calculate_pool_odds(pool_a, pool_b, payout_ratio: float = 0.9) -> tuple:
    """Return ``(odds_a, odds_b)`` for a two-sided betting pool.

    Each side's odds are ``(own_pool + other_pool * payout_ratio) / own_pool``,
    rounded to 2 decimals. The commission is therefore taken only from the
    opposing pool that the winners collect. A side with an empty pool gets 0,
    and both sides get 0 when the combined pool is not positive.

    >>> calculate_pool_odds(100, 100)
    (1.9, 1.9)
    """
    if pool_a + pool_b <= 0:
        return 0, 0
    odds_a = round((pool_a + pool_b * payout_ratio) / pool_a, 2) if pool_a > 0 else 0
    odds_b = round((pool_b + pool_a * payout_ratio) / pool_b, 2) if pool_b > 0 else 0
    return odds_a, odds_b


async def broadcast_pool_updates(streamer_id: int):
    """Periodically send ``pool_update`` messages for a streamer's active chests.

    Every 5 seconds, one message per active chest is sent to the streamer's
    connections. Countdowns are not included. The loop ends when the task is
    cancelled, or when the streamer no longer has any open connection (a
    safety net against orphaned tasks). After an error it logs and retries
    after 1 second.
    """
    print(f"Broadcast pool updates task started for streamer {streamer_id}")

    while True:
        if not manager.streamer_connections.get(streamer_id):
            print(f"Broadcast pool updates task stopped for streamer {streamer_id}: no connections")
            return

        try:
            # Load the streamer's active chests. Copy out the fields we need
            # and close the session before awaiting any network sends, so the
            # session is released even when an error occurs.
            db = SessionLocal()
            try:
                chests = GameService.get_active_chests(db, streamer_id)
                snapshot = [(chest.id, chest.total_bets) for chest in chests]
            finally:
                db.close()

            for chest_id, total_bets in snapshot:
                # Latest pool totals come from the cache, not the DB.
                pool_data = get_pool_cache(chest_id)
                odds_a, odds_b = calculate_pool_odds(pool_data["pool_a"], pool_data["pool_b"])

                pool_message = {
                    "type": "pool_update",
                    "chest_id": chest_id,
                    "pool_a": pool_data["pool_a"],
                    "pool_b": pool_data["pool_b"],
                    "odds_a": odds_a,
                    "odds_b": odds_b,
                    "total_bets": total_bets
                }
                await manager.send_to_streamer(streamer_id, pool_message)

            # Broadcast every 5 seconds to keep load low.
            await asyncio.sleep(5)

        except asyncio.CancelledError:
            # Cancellation must end the task (CancelledError subclasses
            # Exception on Python < 3.8).
            raise
        except Exception as e:
            print(f"Broadcast error: {e}")
            await asyncio.sleep(1)
|
||
|
||
|
||
# ==================== 倒计时相关函数(已废弃) ====================
|
||
# 说明:前端已实现本地倒计时,后端通过SchedulerService自动封盘
|
||
# 以下函数不再使用,保留以供参考
|
||
|
||
# async def start_chests_countdown(streamer_id: int):
|
||
# """启动主播所有活跃宝箱的倒计时(已废弃)"""
|
||
# print(f"Starting countdown for streamer {streamer_id}")
|
||
# try:
|
||
# db = SessionLocal()
|
||
# chests = GameService.get_active_chests(db, streamer_id)
|
||
#
|
||
# print(f"Found {len(chests)} chests for streamer {streamer_id}")
|
||
#
|
||
# for chest in chests:
|
||
# print(f"Checking chest {chest.id}, status: {chest.status}")
|
||
# if chest.status == 0: # BETTING
|
||
# print(f"Starting countdown for chest {chest.id}")
|
||
# # 启动倒计时
|
||
# await countdown_manager.start_chest_countdown(
|
||
# chest,
|
||
# on_update=lambda cid, time_rem: send_countdown_update(cid, time_rem, streamer_id),
|
||
# on_expire=lambda cid: handle_countdown_expire(cid)
|
||
# )
|
||
#
|
||
# db.close()
|
||
# except Exception as e:
|
||
# print(f"Error starting countdown for streamer {streamer_id}: {e}")
|
||
# import traceback
|
||
# traceback.print_exc()
|
||
|
||
|
||
# async def stop_all_chests_countdown(streamer_id: int):
|
||
# """停止主播所有宝箱的倒计时(已废弃)"""
|
||
# try:
|
||
# db = SessionLocal()
|
||
# chests = GameService.get_active_chests(db, streamer_id)
|
||
#
|
||
# for chest in chests:
|
||
# await countdown_manager.stop_chest_countdown(chest.id)
|
||
#
|
||
# db.close()
|
||
# except Exception as e:
|
||
# print(f"Error stopping countdown for streamer {streamer_id}: {e}")
|
||
|
||
|
||
# async def send_countdown_update(chest_id: int, time_remaining: int, streamer_id: int):
|
||
# """发送倒计时更新消息(已废弃)"""
|
||
# print(f"Sending countdown update: chest {chest_id}, time: {time_remaining}")
|
||
# countdown_message = {
|
||
# "type": "countdown_update",
|
||
# "chest_id": chest_id,
|
||
# "time_remaining": time_remaining
|
||
# }
|
||
#
|
||
# # 向主播和所有用户广播倒计时更新
|
||
# await manager.send_to_streamer(streamer_id, countdown_message)
|
||
# await manager.broadcast_to_all(countdown_message)
|
||
|
||
|
||
# async def handle_countdown_expire(chest_id: int):
|
||
# """处理倒计时结束(已废弃)"""
|
||
# print(f"Chest {chest_id} countdown expired, locking...")
|
||
#
|
||
# try:
|
||
# # 更新数据库中的宝箱状态
|
||
# db = SessionLocal()
|
||
# updated_chest = GameService.lock_expired_chest(db, chest_id)
|
||
#
|
||
# if updated_chest:
|
||
# # 通知所有客户端宝箱状态变更
|
||
# status_message = {
|
||
# "type": "chest_status",
|
||
# "chest_id": chest_id,
|
||
# "status": updated_chest.status
|
||
# }
|
||
# await manager.broadcast_to_all(status_message)
|
||
#
|
||
# db.close()
|
||
# except Exception as e:
|
||
# print(f"Error locking expired chest {chest_id}: {e}")
|
||
# try:
|
||
# db.close()
|
||
# except:
|
||
# pass
|
||
|
||
|
||
async def notify_user_balance_update(user_id: int, new_balance: int):
    """Push a ``balance_update`` message carrying ``new_balance`` to every
    open connection registered for ``user_id``."""
    await manager.send_to_user(
        user_id,
        {"type": "balance_update", "balance": new_balance},
    )