Kamixitong/app/middleware/rate_limit.py
2025-12-12 11:35:14 +08:00

170 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
请求频率限制中间件
防止API被滥用和攻击
"""
import hashlib
import logging
import os
import time
from functools import wraps
from typing import Dict, Optional, Tuple

import redis
from flask import current_app, jsonify, request
class RateLimiter:
    """Fixed-window request rate limiter.

    Each (identifier, window index) pair gets its own counter. Counters are
    stored in Redis when ``REDIS_URL`` is set and a client could be created;
    otherwise -- and whenever a Redis call fails at request time -- they are
    kept in a dict attached to the current Flask app. That fallback lives in
    process memory only, so it is meant for development / single-process use.
    """

    # Minimum number of seconds between sweeps of expired in-memory entries.
    _PRUNE_INTERVAL = 60

    def __init__(self):
        self.redis_client = None
        # Unix time (seconds) at or after which the in-memory store is swept next.
        self._next_prune = 0
        self._init_redis()

    def _init_redis(self):
        """Create a Redis client from the ``REDIS_URL`` environment variable.

        Leaves ``self.redis_client`` as ``None`` when the variable is unset or
        client creation raises. The module-level instance is constructed at
        import time, normally outside any Flask application context, so the
        failure is logged with the module logger: ``current_app.logger`` would
        itself raise ``RuntimeError`` there and mask the fallback.
        """
        try:
            redis_url = os.environ.get('REDIS_URL')
            if redis_url:
                self.redis_client = redis.from_url(redis_url)
        except Exception as e:
            # Deliberate best effort: run with the in-memory store instead of
            # preventing the application from starting.
            logging.getLogger(__name__).warning(f"Redis连接失败将使用内存存储: {str(e)}")
            self.redis_client = None

    def _get_client_ip(self) -> str:
        """Return the client IP address of the current request.

        Uses the first entry of ``X-Forwarded-For`` when that header is present
        and its first entry is non-empty; otherwise ``request.remote_addr``, or
        ``'0.0.0.0'`` when that is unavailable.

        NOTE(security): ``X-Forwarded-For`` is client-controlled unless a
        trusted reverse proxy overwrites it. If the app is reachable directly,
        a client can send a different value per request and evade IP-based
        limits; consider werkzeug's ``ProxyFix`` plus ``remote_addr`` instead.
        """
        forwarded_for = request.headers.get('X-Forwarded-For')
        if forwarded_for:
            first_hop = forwarded_for.split(',')[0].strip()
            if first_hop:
                return first_hop
        return request.remote_addr or '0.0.0.0'

    def _make_key(self, identifier: str, window: int) -> str:
        """Build the storage key for *identifier* in window number *window*.

        *window* is the window index (``timestamp // window_seconds``), not the
        window length in seconds.
        """
        return f"rate_limit:{identifier}:{window}"

    def check_rate_limit(
        self,
        identifier: Optional[str] = None,
        limit: int = 100,
        window: int = 3600
    ) -> Tuple[bool, Dict]:
        """Count one request for *identifier* and report whether it is allowed.

        :param identifier: key to limit on (IP or user ID); when empty, the
            client IP of the current request is used.
        :param limit: maximum number of requests allowed per window.
        :param window: window length in seconds; must be positive.
        :return: ``(allowed, info)``. ``info`` has ``limit``, ``remaining``
            (never negative), ``reset`` (Unix time at which the current window
            ends) and ``current`` (requests counted in this window so far,
            including this one).
        :raises ValueError: if *window* is not positive.
        """
        if window <= 0:
            raise ValueError(f"window must be a positive number of seconds, got {window!r}")
        if not identifier:
            identifier = self._get_client_ip()

        current_time = int(time.time())
        current_window = current_time // window
        window_end = (current_window + 1) * window
        key = self._make_key(identifier, current_window)

        current_requests = None
        if self.redis_client is not None:
            current_requests = self._incr_redis(key, window)
        if current_requests is None:
            # No Redis configured, or the Redis call just failed.
            current_requests = self._incr_in_app_memory(key, current_time, window_end)

        return current_requests <= limit, {
            'limit': limit,
            'remaining': max(0, limit - current_requests),
            'reset': window_end,
            'current': current_requests,
        }

    def _incr_redis(self, key: str, window: int) -> Optional[int]:
        """Increment *key* in Redis and (re)set its TTL to *window* seconds.

        Returns the new count, or ``None`` if the Redis call raised a
        ``redis.RedisError`` (logged). redis-py normally connects lazily, so an
        unreachable server typically surfaces here rather than in
        ``_init_redis``.
        """
        try:
            pipe = self.redis_client.pipeline()
            pipe.incr(key)
            pipe.expire(key, window)
            return pipe.execute()[0]
        except redis.RedisError as e:
            logging.getLogger(__name__).warning("Redis请求失败,改用内存存储: %s", e)
            return None

    def _incr_in_app_memory(self, key: str, current_time: int, window_end: int) -> int:
        """Count one request for *key* in the per-app in-memory store.

        The store is a dict kept on the Flask app object, so separate app
        instances do not share counters. Expired entries are swept at most
        once every ``_PRUNE_INTERVAL`` seconds so the dict cannot grow without
        bound. Counts are per process; increments are not synchronized between
        threads.
        """
        store = getattr(current_app, '_rate_limits', None)
        if store is None:
            store = {}
            current_app._rate_limits = store
        if current_time >= self._next_prune:
            self._prune_expired(store, current_time)
            self._next_prune = current_time + self._PRUNE_INTERVAL
        return self._incr_memory(store, key, current_time, window_end)

    @staticmethod
    def _incr_memory(
        store: Dict[str, Dict[str, int]],
        key: str,
        current_time: int,
        expire_at: int
    ) -> int:
        """Increment ``store[key]['count']`` and return the new value.

        A fresh entry expiring at *expire_at* is started when *key* is missing
        or its entry has already expired at *current_time*.
        """
        entry = store.get(key)
        if entry is None or entry['expire'] <= current_time:
            entry = {'count': 0, 'expire': expire_at}
            store[key] = entry
        entry['count'] += 1
        return entry['count']

    @staticmethod
    def _prune_expired(store: Dict[str, Dict[str, int]], current_time: int) -> None:
        """Remove every entry of *store* whose ``expire`` time is <= *current_time*."""
        # Iterate over a snapshot: other request threads may insert while we sweep.
        for stale_key, entry in list(store.items()):
            if entry['expire'] <= current_time:
                store.pop(stale_key, None)
# Process-wide limiter shared by every @rate_limit-decorated view; it is
# constructed once, when this module is imported.
rate_limiter: RateLimiter = RateLimiter()
def _set_rate_limit_headers(response, info: Dict) -> None:
    """Copy ``limit``/``remaining``/``reset`` from *info* onto *response* headers."""
    response.headers['X-RateLimit-Limit'] = str(info['limit'])
    response.headers['X-RateLimit-Remaining'] = str(info['remaining'])
    response.headers['X-RateLimit-Reset'] = str(info['reset'])


def rate_limit(limit: int = 100, window: int = 3600, key_func=None):
    """Decorator that applies a fixed-window rate limit to a Flask view.

    Over-limit requests get a JSON 429 response carrying the X-RateLimit-*
    headers and ``Retry-After``. Allowed requests run the view. The
    X-RateLimit-* headers are added to its result only when that result has
    a ``headers`` attribute (e.g. a Response object).

    :param limit: maximum number of requests allowed per window.
    :param window: window length in seconds; must be positive.
    :param key_func: optional zero-argument callable returning the identifier
        to limit on (e.g. ``ip_key`` or ``user_key``). When it is omitted, or
        returns an empty value, the client IP is used.
    :raises ValueError: at decoration time if *window* is not positive.
    """
    if window <= 0:
        raise ValueError(f"window must be a positive number of seconds, got {window!r}")

    def decorator(f):
        @wraps(f)
        def decorated_function(*args, **kwargs):
            identifier = key_func() if key_func else None
            allowed, info = rate_limiter.check_rate_limit(identifier, limit, window)

            if not allowed:
                response = jsonify({
                    'success': False,
                    'message': '请求过于频繁,请稍后再试',
                    'rate_limit': info
                })
                response.status_code = 429
                _set_rate_limit_headers(response, info)
                # RFC 6585 §4: tell clients how many seconds to wait before retrying.
                response.headers['Retry-After'] = str(max(0, info['reset'] - int(time.time())))
                return response

            result = f(*args, **kwargs)
            # Plain strings/tuples/dicts have no headers to annotate; pass them through.
            if hasattr(result, 'headers'):
                _set_rate_limit_headers(result, info)
            return result

        return decorated_function

    return decorator
def ip_key() -> str:
    """Return a rate-limit identifier derived from the client's IP address alone."""
    client_ip = rate_limiter._get_client_ip()
    return client_ip
def user_key() -> str:
    """Return a per-user rate-limit identifier, falling back to the client IP.

    Authenticated flask_login users are keyed as ``user:<id>``; anonymous
    requests (or a ``current_user`` without ``is_authenticated``) use
    ``ip_key()``.
    """
    from flask_login import current_user

    if not getattr(current_user, 'is_authenticated', False):
        return ip_key()
    return f"user:{current_user.get_id()}"