Kamixitong/app/middleware/rate_limit.py

170 lines
5.3 KiB
Python
Raw Normal View History

2025-12-12 11:35:14 +08:00
"""
请求频率限制中间件
防止API被滥用和攻击
"""
import hashlib
import logging
import os
import time
from functools import wraps
from typing import Dict, Optional, Tuple

import redis
from flask import request, jsonify, current_app
class RateLimiter:
    """Fixed-window request rate limiter.

    Counters live in Redis when the ``REDIS_URL`` environment variable is
    set; otherwise they are kept in a dict attached to the current Flask
    application (intended for development only).
    """

    # Minimum number of seconds between sweeps of expired in-memory counters.
    _PURGE_INTERVAL = 60
    # Epoch second of the last sweep. Kept as a class-level default so that
    # instances created without running __init__ still have a value.
    _last_purge = 0

    def __init__(self):
        self.redis_client = None
        self._init_redis()

    def _init_redis(self):
        """Create the Redis client from ``REDIS_URL`` when it is configured."""
        try:
            redis_url = os.environ.get('REDIS_URL')
            if redis_url:
                self.redis_client = redis.from_url(redis_url)
        except Exception as e:
            # This runs while the module is being imported (the shared
            # instance is created at module level), where no Flask
            # application context exists, so the stdlib logger is used
            # instead of current_app.logger.
            logging.getLogger(__name__).warning(f"Redis连接失败将使用内存存储: {str(e)}")
            self.redis_client = None

    def _get_client_ip(self) -> str:
        """Return the client IP address for the current request.

        The first entry of ``X-Forwarded-For`` wins when the header is
        present and non-empty; otherwise ``request.remote_addr`` is used,
        falling back to ``'0.0.0.0'``.
        """
        # NOTE(review): X-Forwarded-For is client-controlled unless a trusted
        # proxy overwrites it; a caller can vary it to dodge per-IP limits.
        forwarded = request.headers.get('X-Forwarded-For')
        if forwarded:
            first_hop = forwarded.split(',')[0].strip()
            # A header such as ", 10.0.0.1" has an empty first element; do
            # not use an empty string as the identifier.
            if first_hop:
                return first_hop
        return request.remote_addr or '0.0.0.0'

    def _make_key(self, identifier: str, window: int) -> str:
        """Build the storage key for *identifier* within window number *window*."""
        return f"rate_limit:{identifier}:{window}"

    def _count_in_redis(self, key: str, window: int) -> int:
        """Increment *key* in Redis, refresh its TTL, and return the new count."""
        pipe = self.redis_client.pipeline()
        pipe.incr(key)
        pipe.expire(key, window)
        return pipe.execute()[0]

    def _count_in_memory(self, key: str, now: int, window: int) -> int:
        """Increment *key* in the per-application dict and return the new count."""
        if not hasattr(current_app, '_rate_limits'):
            current_app._rate_limits = {}
        store = current_app._rate_limits

        # Keys embed the window number, so entries from past windows are
        # never touched again; sweep them periodically to bound memory use.
        if now - self._last_purge >= self._PURGE_INTERVAL:
            for stale_key in [k for k, v in store.items() if v['expire'] < now]:
                del store[stale_key]
            self._last_purge = now

        entry = store.get(key)
        if entry is None or entry['expire'] < now:
            entry = store[key] = {'count': 0, 'expire': now + window}
        entry['count'] += 1
        return entry['count']

    def check_rate_limit(
        self,
        identifier: Optional[str] = None,
        limit: int = 100,
        window: int = 3600
    ) -> Tuple[bool, Dict]:
        """
        Record one request and report whether it is within the limit.

        :param identifier: caller identifier (IP or user id); when falsy the
            client IP of the current request is used
        :param limit: maximum number of requests allowed per window
        :param window: window length in seconds
        :return: ``(allowed, info)`` where ``info`` has the keys ``limit``,
            ``remaining``, ``reset`` (epoch second at which the current
            window ends) and ``current`` (requests counted so far)
        """
        if not identifier:
            identifier = self._get_client_ip()

        current_time = int(time.time())
        # Fixed windows: all requests sharing the same quotient are counted together.
        current_window = current_time // window
        key = self._make_key(identifier, current_window)

        current_requests = None
        if self.redis_client:
            try:
                current_requests = self._count_in_redis(key, window)
            except redis.RedisError as e:
                # redis.from_url() does not open a connection, so an
                # unreachable server only shows up here. Fall back to the
                # in-memory counter instead of failing the request.
                logging.getLogger(__name__).warning(f"Redis连接失败将使用内存存储: {str(e)}")
        if current_requests is None:
            current_requests = self._count_in_memory(key, current_time, window)

        return current_requests <= limit, {
            'limit': limit,
            'remaining': max(0, limit - current_requests),
            'reset': (current_window + 1) * window,
            'current': current_requests
        }
# Global instance: created once at import time and shared by the
# rate_limit decorator and the ip_key helper defined in this module.
rate_limiter = RateLimiter()
def rate_limit(limit: int = 100, window: int = 3600, key_func=None):
    """
    Build a decorator that applies the shared rate limiter to a view.

    :param limit: maximum number of requests per window
    :param window: window length, forwarded to the limiter
    :param key_func: optional zero-argument callable returning the
        identifier to count against; when omitted, ``None`` is forwarded
        and the limiter picks the identifier itself
    """
    header_fields = (
        ('X-RateLimit-Limit', 'limit'),
        ('X-RateLimit-Remaining', 'remaining'),
        ('X-RateLimit-Reset', 'reset'),
    )

    def decorator(f):
        @wraps(f)
        def decorated_function(*args, **kwargs):
            identifier = key_func() if key_func else None
            allowed, info = rate_limiter.check_rate_limit(identifier, limit, window)

            def attach_headers(resp):
                # Expose the limiter state on the outgoing response.
                for header_name, info_key in header_fields:
                    resp.headers[header_name] = str(info[info_key])

            if allowed:
                # Within the limit: run the view and annotate its result
                # only when it exposes a headers mapping.
                result = f(*args, **kwargs)
                if hasattr(result, 'headers'):
                    attach_headers(result)
                return result

            # Over the limit: short-circuit with a 429 JSON response.
            rejection = jsonify({
                'success': False,
                'message': '请求过于频繁,请稍后再试',
                'rate_limit': info
            })
            rejection.status_code = 429
            attach_headers(rejection)
            return rejection
        return decorated_function
    return decorator
def ip_key() -> str:
    """Return the client IP, as resolved by the shared limiter, for use as a rate-limit key."""
    client_ip = rate_limiter._get_client_ip()
    return client_ip
def user_key() -> str:
    """Return ``user:<id>`` for an authenticated user, otherwise the IP-based key."""
    from flask_login import current_user

    # Anonymous or missing user information: fall back to the IP-based key.
    if not getattr(current_user, 'is_authenticated', False):
        return ip_key()
    return f"user:{current_user.get_id()}"