170 lines
5.3 KiB
Python
170 lines
5.3 KiB
Python
"""
|
||
请求频率限制中间件
|
||
防止API被滥用和攻击
|
||
"""
|
||
import hashlib
import logging
import os
import time
from functools import wraps
from typing import Dict, Optional, Tuple

import redis
from flask import request, jsonify, current_app, make_response
|
||
|
||
|
||
class RateLimiter:
    """Fixed-window request rate limiter.

    Counters live in Redis when the ``REDIS_URL`` environment variable is set.
    Otherwise, or when a Redis command fails at runtime, they fall back to a
    per-process dict stored on the Flask app. That fallback is intended for
    development only.
    """

    def __init__(self):
        # Redis client, or None when only the in-memory fallback is available.
        self.redis_client = None
        self._init_redis()

    def _init_redis(self):
        """Create the Redis client from ``REDIS_URL``, if it is configured.

        Uses a module logger instead of ``current_app.logger``: the global
        instance is constructed at import time, outside any app context.
        ``redis.from_url`` does not open a connection, so connection failures
        only show up on the first command (handled in ``_incr_redis``).
        """
        try:
            redis_url = os.environ.get('REDIS_URL')
            if redis_url:
                self.redis_client = redis.from_url(redis_url)
        except Exception as e:
            logging.getLogger(__name__).warning(f"Redis连接失败,将使用内存存储: {str(e)}")
            self.redis_client = None

    def _get_client_ip(self) -> str:
        """Return the client IP address of the current request.

        The first entry of ``X-Forwarded-For`` is preferred when present.
        Clients can set that header themselves unless a trusted proxy
        overwrites it, which lets them dodge IP-based limits. Set
        ``RATE_LIMIT_TRUST_FORWARDED_FOR = False`` in the app config when the
        app is not behind such a proxy.
        """
        trust_forwarded = current_app.config.get('RATE_LIMIT_TRUST_FORWARDED_FOR', True)
        forwarded_for = request.headers.get('X-Forwarded-For')
        if trust_forwarded and forwarded_for:
            first_hop = forwarded_for.split(',')[0].strip()
            # An empty first hop (e.g. ", 1.2.3.4") would otherwise put every
            # such client under the same "" key.
            if first_hop:
                return first_hop
        return request.remote_addr or '0.0.0.0'

    def _make_key(self, identifier: str, window: int) -> str:
        """Build the storage key for *identifier* in window number *window*."""
        return f"rate_limit:{identifier}:{window}"

    def _incr_redis(self, key: str, window: int) -> Optional[int]:
        """Increment *key* in Redis and return the new count.

        Returns None when Redis raises, so the caller can fall back to the
        in-memory store instead of failing the request.
        """
        try:
            pipe = self.redis_client.pipeline()
            pipe.incr(key)
            # Let the counter disappear on its own once its window is over.
            pipe.expire(key, window)
            results = pipe.execute()
        except redis.RedisError as e:
            logging.getLogger(__name__).warning(
                "Redis rate-limit update failed, using in-memory store: %s", e
            )
            return None
        return int(results[0])

    def _incr_memory(self, key: str, current_time: int, expire_at: int) -> int:
        """Increment *key* in the per-process dict kept on the current Flask app.

        This is a development-only fallback: counts are not shared with other
        processes.
        """
        store = getattr(current_app, '_rate_limits', None)
        if store is None:
            store = {}
            current_app._rate_limits = store

        entry = store.get(key)
        if entry is not None and entry['expire'] >= current_time:
            entry['count'] += 1
            return entry['count']

        # Starting a new counter: first drop counters whose window has ended.
        # Keys embed the window number, so those entries are never read again
        # and would otherwise accumulate forever.
        stale_keys = [k for k, e in store.items() if e['expire'] < current_time]
        for stale_key in stale_keys:
            del store[stale_key]
        store[key] = {'count': 1, 'expire': expire_at}
        return 1

    def check_rate_limit(
        self,
        identifier: Optional[str] = None,
        limit: int = 100,
        window: int = 3600
    ) -> Tuple[bool, Dict]:
        """Count one request for *identifier* and check it against the limit.

        :param identifier: key to count against (IP or user id); when falsy,
            the client IP of the current request is used
        :param limit: maximum number of requests allowed per window
        :param window: window length in seconds
        :return: ``(allowed, info)``, where ``info`` contains ``limit``,
            ``remaining``, ``reset`` (epoch seconds when the window ends) and
            ``current`` (requests counted so far in this window)
        :raises ValueError: if *window* is not positive
        """
        if window <= 0:
            raise ValueError("window must be a positive number of seconds")

        if not identifier:
            identifier = self._get_client_ip()

        current_time = int(time.time())
        # Fixed windows aligned to the epoch: window N covers [N*w, (N+1)*w).
        current_window = current_time // window
        reset_time = (current_window + 1) * window
        key = self._make_key(identifier, current_window)

        current_requests = None
        if self.redis_client is not None:
            current_requests = self._incr_redis(key, window)
        if current_requests is None:
            current_requests = self._incr_memory(key, current_time, reset_time)

        remaining = max(0, limit - current_requests)
        allowed = current_requests <= limit

        return allowed, {
            'limit': limit,
            'remaining': remaining,
            'reset': reset_time,
            'current': current_requests
        }
|
||
|
||
|
||
# Process-wide limiter instance, used by the `rate_limit` decorator and `ip_key`.
rate_limiter: RateLimiter = RateLimiter()
|
||
|
||
|
||
def _apply_rate_limit_headers(response, info: Dict) -> None:
    """Copy the limit, remaining and reset values from *info* onto *response*'s headers."""
    response.headers['X-RateLimit-Limit'] = str(info['limit'])
    response.headers['X-RateLimit-Remaining'] = str(info['remaining'])
    response.headers['X-RateLimit-Reset'] = str(info['reset'])


def rate_limit(limit: int = 100, window: int = 3600, key_func=None):
    """Decorator that applies a fixed-window rate limit to a Flask view.

    :param limit: maximum number of requests allowed per window
    :param window: window length in seconds
    :param key_func: optional zero-argument callable returning the identifier
        to count against; when omitted (or when it returns a falsy value) the
        client IP is used
    """
    def decorator(f):
        @wraps(f)
        def decorated_function(*args, **kwargs):
            identifier = key_func() if key_func else None
            allowed, info = rate_limiter.check_rate_limit(identifier, limit, window)

            if not allowed:
                response = jsonify({
                    'success': False,
                    'message': '请求过于频繁,请稍后再试',
                    'rate_limit': info
                })
                response.status_code = 429
                _apply_rate_limit_headers(response, info)
                # RFC 6585: tell the client how many seconds to wait before retrying.
                response.headers['Retry-After'] = str(max(0, info['reset'] - int(time.time())))
                return response

            # make_response turns tuple/dict/str view results into a Response,
            # so the rate-limit headers are attached whatever the view returns.
            response = make_response(f(*args, **kwargs))
            _apply_rate_limit_headers(response, info)
            return response

        return decorated_function

    return decorator
|
||
|
||
|
||
def ip_key() -> str:
    """Use the requesting client's IP address as the rate-limit identifier."""
    client_ip = rate_limiter._get_client_ip()
    return client_ip
|
||
|
||
|
||
def user_key() -> str:
    """Return a per-user rate-limit identifier, or the client IP for anonymous users.

    Authenticated users are keyed as ``user:<id>``; anyone else falls back to
    ``ip_key()``.
    """
    from flask_login import current_user

    authenticated = getattr(current_user, 'is_authenticated', False)
    if not authenticated:
        return ip_key()
    return f"user:{current_user.get_id()}"
|