第一次提交
This commit is contained in:
169
app/middleware/rate_limit.py
Normal file
169
app/middleware/rate_limit.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
请求频率限制中间件
|
||||
防止API被滥用和攻击
|
||||
"""
|
||||
from flask import request, jsonify, current_app
|
||||
from functools import wraps
|
||||
import time
|
||||
import hashlib
|
||||
from typing import Dict, Optional, Tuple
|
||||
import redis
|
||||
import os
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""频率限制器"""
|
||||
|
||||
def __init__(self):
|
||||
self.redis_client = None
|
||||
self._init_redis()
|
||||
|
||||
def _init_redis(self):
|
||||
"""初始化Redis连接"""
|
||||
try:
|
||||
redis_url = os.environ.get('REDIS_URL')
|
||||
if redis_url:
|
||||
self.redis_client = redis.from_url(redis_url)
|
||||
except Exception as e:
|
||||
current_app.logger.warning(f"Redis连接失败,将使用内存存储: {str(e)}")
|
||||
self.redis_client = None
|
||||
|
||||
def _get_client_ip(self) -> str:
|
||||
"""获取客户端IP地址"""
|
||||
# 优先使用X-Forwarded-For头部(如果存在)
|
||||
if request.headers.get('X-Forwarded-For'):
|
||||
return request.headers.get('X-Forwarded-For').split(',')[0].strip()
|
||||
return request.remote_addr or '0.0.0.0'
|
||||
|
||||
def _make_key(self, identifier: str, window: int) -> str:
|
||||
"""生成Redis键"""
|
||||
return f"rate_limit:{identifier}:{window}"
|
||||
|
||||
def check_rate_limit(
|
||||
self,
|
||||
identifier: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
window: int = 3600
|
||||
) -> Tuple[bool, Dict]:
|
||||
"""
|
||||
检查频率限制
|
||||
:param identifier: 标识符(IP或用户ID)
|
||||
:param limit: 限制次数
|
||||
:param window: 时间窗口(秒)
|
||||
:return: (是否允许, 限制信息)
|
||||
"""
|
||||
if not identifier:
|
||||
identifier = self._get_client_ip()
|
||||
|
||||
current_time = int(time.time())
|
||||
window_size = window
|
||||
|
||||
# 计算当前窗口
|
||||
current_window = current_time // window_size
|
||||
|
||||
if self.redis_client:
|
||||
# 使用Redis存储
|
||||
key = self._make_key(identifier, current_window)
|
||||
pipe = self.redis_client.pipeline()
|
||||
pipe.incr(key)
|
||||
pipe.expire(key, window_size)
|
||||
results = pipe.execute()
|
||||
|
||||
current_requests = results[0]
|
||||
else:
|
||||
# 使用内存存储(仅用于开发环境)
|
||||
key = self._make_key(identifier, current_window)
|
||||
if not hasattr(current_app, '_rate_limits'):
|
||||
current_app._rate_limits = {}
|
||||
|
||||
if key not in current_app._rate_limits:
|
||||
current_app._rate_limits[key] = {
|
||||
'count': 1,
|
||||
'expire': current_time + window_size
|
||||
}
|
||||
else:
|
||||
# 检查是否过期
|
||||
if current_app._rate_limits[key]['expire'] < current_time:
|
||||
current_app._rate_limits[key] = {
|
||||
'count': 1,
|
||||
'expire': current_time + window_size
|
||||
}
|
||||
else:
|
||||
current_app._rate_limits[key]['count'] += 1
|
||||
|
||||
current_requests = current_app._rate_limits[key]['count']
|
||||
|
||||
# 检查是否超过限制
|
||||
remaining = max(0, limit - current_requests)
|
||||
reset_time = (current_window + 1) * window_size
|
||||
|
||||
allowed = current_requests <= limit
|
||||
|
||||
return allowed, {
|
||||
'limit': limit,
|
||||
'remaining': remaining,
|
||||
'reset': reset_time,
|
||||
'current': current_requests
|
||||
}
|
||||
|
||||
|
||||
# 全局实例
|
||||
rate_limiter = RateLimiter()
|
||||
|
||||
|
||||
def rate_limit(limit: int = 100, window: int = 3600, key_func=None):
|
||||
"""
|
||||
频率限制装饰器
|
||||
:param limit: 限制次数
|
||||
:param window: 时间窗口(秒)
|
||||
:param key_func: 自定义键生成函数
|
||||
"""
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
def decorated_function(*args, **kwargs):
|
||||
# 生成标识符
|
||||
identifier = None
|
||||
if key_func:
|
||||
identifier = key_func()
|
||||
|
||||
# 检查频率限制
|
||||
allowed, info = rate_limiter.check_rate_limit(identifier, limit, window)
|
||||
|
||||
if not allowed:
|
||||
# 添加限制信息到响应头
|
||||
response = jsonify({
|
||||
'success': False,
|
||||
'message': '请求过于频繁,请稍后再试',
|
||||
'rate_limit': info
|
||||
})
|
||||
response.status_code = 429
|
||||
response.headers['X-RateLimit-Limit'] = str(info['limit'])
|
||||
response.headers['X-RateLimit-Remaining'] = str(info['remaining'])
|
||||
response.headers['X-RateLimit-Reset'] = str(info['reset'])
|
||||
return response
|
||||
|
||||
# 如果允许,执行函数
|
||||
result = f(*args, **kwargs)
|
||||
|
||||
# 添加限制信息到响应头
|
||||
if hasattr(result, 'headers'):
|
||||
result.headers['X-RateLimit-Limit'] = str(info['limit'])
|
||||
result.headers['X-RateLimit-Remaining'] = str(info['remaining'])
|
||||
result.headers['X-RateLimit-Reset'] = str(info['reset'])
|
||||
|
||||
return result
|
||||
return decorated_function
|
||||
return decorator
|
||||
|
||||
|
||||
def ip_key() -> str:
|
||||
"""基于IP的键"""
|
||||
return rate_limiter._get_client_ip()
|
||||
|
||||
|
||||
def user_key() -> str:
|
||||
"""基于用户ID的键"""
|
||||
from flask_login import current_user
|
||||
if hasattr(current_user, 'is_authenticated') and current_user.is_authenticated:
|
||||
return f"user:{current_user.get_id()}"
|
||||
return ip_key()
|
||||
Reference in New Issue
Block a user