"""
Request rate-limiting middleware.

Protects the API against abuse and attacks by capping how many requests one
identifier (client IP or user ID) may make within a fixed time window.
"""
import hashlib
import logging
import os
import time
from functools import wraps
from typing import Dict, Optional, Tuple

import redis
from flask import current_app, jsonify, make_response, request

logger = logging.getLogger(__name__)


class RateLimiter:
    """Fixed-window request rate limiter.

    Counters are kept in Redis when the ``REDIS_URL`` environment variable is
    set. Otherwise, or whenever a Redis call fails, they are kept in a
    per-app in-memory dict that is only suitable for development.
    """

    def __init__(self):
        self.redis_client = None
        self._init_redis()

    def _init_redis(self):
        """Create the Redis client from ``REDIS_URL`` if it is configured.

        This runs at import time, because the module-level ``rate_limiter``
        is built outside any Flask app context. Failures are therefore logged
        through the module logger instead of ``current_app.logger``, which
        would itself raise there.
        """
        try:
            redis_url = os.environ.get('REDIS_URL')
            if redis_url:
                self.redis_client = redis.from_url(redis_url)
        except Exception as e:
            # Best effort: a bad URL must not prevent the app from starting.
            logger.warning("Redis连接失败,将使用内存存储: %s", e)
            self.redis_client = None

    def _get_client_ip(self) -> str:
        """Return the client IP address of the current request.

        By default the first hop of ``X-Forwarded-For`` is trusted. That
        header is client-controlled. An app that is NOT behind a trusted
        reverse proxy should therefore set the Flask config key
        ``RATE_LIMIT_TRUST_FORWARDED_FOR = False`` (or use Werkzeug's
        ``ProxyFix``). Otherwise a client can evade the limit by forging a new
        value on every request.
        """
        if current_app.config.get('RATE_LIMIT_TRUST_FORWARDED_FOR', True):
            forwarded = request.headers.get('X-Forwarded-For', '')
            first_hop = forwarded.split(',')[0].strip()
            # An empty first hop would lump unrelated clients into one bucket.
            if first_hop:
                return first_hop
        return request.remote_addr or '0.0.0.0'

    def _make_key(self, identifier: str, window: int) -> str:
        """Build the storage key for *identifier* in window number *window*."""
        return f"rate_limit:{identifier}:{window}"

    def _incr_redis(self, key: str, window_size: int) -> int:
        """Increment *key* in Redis, (re)arm its TTL, and return the new count."""
        pipe = self.redis_client.pipeline()
        pipe.incr(key)
        pipe.expire(key, window_size)
        results = pipe.execute()
        return results[0]

    @staticmethod
    def _purge_expired(store: Dict[str, Dict], now: int) -> None:
        """Drop every entry of the in-memory *store* whose expiry is past *now*."""
        # Iterate over a snapshot so keys inserted by other threads (threaded
        # dev server) do not interfere with the loop.
        for k, v in list(store.items()):
            if v['expire'] < now:
                store.pop(k, None)

    def _incr_memory(self, key: str, current_time: int, window_size: int) -> int:
        """Increment *key* in the per-app in-memory store and return the new count.

        Intended for development only: the counts are per process and not
        shared between workers.
        """
        store = getattr(current_app, '_rate_limits', None)
        if store is None:
            store = {}
            current_app._rate_limits = store

        entry = store.get(key)
        if entry is None:
            # Keys embed the window number, so a finished window's key is never
            # touched again. Sweep expired keys whenever a new key is created,
            # otherwise the dict grows forever.
            self._purge_expired(store, current_time)

        if entry is None or entry['expire'] < current_time:
            store[key] = {
                'count': 1,
                'expire': current_time + window_size
            }
            return 1

        entry['count'] += 1
        return entry['count']

    def check_rate_limit(
        self,
        identifier: Optional[str] = None,
        limit: int = 100,
        window: int = 3600
    ) -> Tuple[bool, Dict]:
        """
        Count one request and check it against the rate limit.

        :param identifier: bucket identifier (IP or user ID); defaults to the
            client IP of the current request
        :param limit: maximum number of requests allowed per window
        :param window: window length in seconds; must be positive
        :return: ``(allowed, info)``, where ``info`` has the keys ``limit``,
            ``remaining``, ``reset`` (epoch seconds when the window ends) and
            ``current`` (requests counted so far in this window)
        :raises ValueError: if *window* is not positive
        """
        if window <= 0:
            raise ValueError(f"window must be a positive number of seconds, got {window!r}")
        if not identifier:
            identifier = self._get_client_ip()

        current_time = int(time.time())
        window_size = window

        # Number of the fixed window the current time falls into.
        current_window = current_time // window_size
        key = self._make_key(identifier, current_window)

        if self.redis_client is not None:
            try:
                current_requests = self._incr_redis(key, window_size)
            except redis.RedisError as e:
                # redis.from_url connects lazily, so an outage only shows up
                # here. Degrade to the memory store instead of failing the
                # request.
                logger.warning("Redis请求失败,回退到内存存储: %s", e)
                current_requests = self._incr_memory(key, current_time, window_size)
        else:
            current_requests = self._incr_memory(key, current_time, window_size)

        remaining = max(0, limit - current_requests)
        reset_time = (current_window + 1) * window_size
        allowed = current_requests <= limit

        return allowed, {
            'limit': limit,
            'remaining': remaining,
            'reset': reset_time,
            'current': current_requests
        }


# Global instance, created at import time and shared by every decorated view.
rate_limiter = RateLimiter()


def _set_rate_limit_headers(response, info: Dict) -> None:
    """Copy the rate-limit *info* into the standard ``X-RateLimit-*`` headers."""
    response.headers['X-RateLimit-Limit'] = str(info['limit'])
    response.headers['X-RateLimit-Remaining'] = str(info['remaining'])
    response.headers['X-RateLimit-Reset'] = str(info['reset'])


def rate_limit(limit: int = 100, window: int = 3600, key_func=None):
    """
    Decorator that applies a rate limit to a Flask view.

    :param limit: maximum number of requests per window
    :param window: window length in seconds; must be positive
    :param key_func: optional zero-argument callable returning the bucket
        identifier; the client IP is used when omitted or when it returns a
        falsy value
    :raises ValueError: at decoration time if *window* is not positive
    """
    if window <= 0:
        raise ValueError(f"window must be a positive number of seconds, got {window!r}")

    def decorator(f):
        @wraps(f)
        def decorated_function(*args, **kwargs):
            identifier = key_func() if key_func else None

            allowed, info = rate_limiter.check_rate_limit(identifier, limit, window)

            if not allowed:
                response = jsonify({
                    'success': False,
                    'message': '请求过于频繁,请稍后再试',
                    'rate_limit': info
                })
                response.status_code = 429
                _set_rate_limit_headers(response, info)
                # Tell clients how many seconds to wait before retrying.
                response.headers['Retry-After'] = str(max(0, int(info['reset']) - int(time.time())))
                return response

            result = f(*args, **kwargs)

            # Views often return tuples such as (body, status) or plain
            # strings. Normalize them to a Response so the headers are always
            # attached.
            if result is not None and not hasattr(result, 'headers'):
                result = make_response(result)
            if hasattr(result, 'headers'):
                _set_rate_limit_headers(result, info)

            return result

        return decorated_function
    return decorator


def ip_key() -> str:
    """Key function that buckets requests by client IP address."""
    return rate_limiter._get_client_ip()


def user_key() -> str:
    """Key function that buckets by logged-in user ID, falling back to client IP."""
    from flask_login import current_user
    if hasattr(current_user, 'is_authenticated') and current_user.is_authenticated:
        return f"user:{current_user.get_id()}"
    return ip_key()