400 lines
12 KiB
Python
400 lines
12 KiB
Python
import re
|
||
import hashlib
|
||
from datetime import datetime, timedelta
|
||
from typing import Tuple, Optional, Dict, Any, List
|
||
|
||
class LicenseValidator:
|
||
"""许可证验证器"""
|
||
|
||
def __init__(self, config=None):
|
||
"""
|
||
初始化验证器
|
||
:param config: 配置字典
|
||
"""
|
||
self.config = config or {}
|
||
self.max_failed_attempts = self.config.get('MAX_FAILED_ATTEMPTS', 5)
|
||
self.lockout_minutes = self.config.get('LOCKOUT_MINUTES', 10)
|
||
|
||
def validate_license_key(self, license_key: str) -> bool:
|
||
"""
|
||
验证卡密格式(支持XXXX-XXXX-XXXX-XXXX格式)
|
||
:param license_key: 卡密字符串
|
||
:return: 是否有效
|
||
"""
|
||
if not license_key:
|
||
return False
|
||
|
||
# 去除空格和制表符,并转为大写
|
||
license_key = license_key.strip().replace(' ', '').replace('\t', '').upper()
|
||
|
||
# 检查是否为XXXX-XXXX-XXXX-XXXX格式
|
||
if '-' in license_key:
|
||
parts = license_key.split('-')
|
||
# 应该有4部分,每部分8个字符
|
||
if len(parts) == 4 and all(len(part) == 8 for part in parts):
|
||
# 检查所有字符是否为大写字母或数字
|
||
combined = ''.join(parts)
|
||
if len(combined) == 32:
|
||
pattern = r'^[A-Z0-9]+$'
|
||
import re
|
||
return bool(re.match(pattern, combined))
|
||
return False
|
||
else:
|
||
# 兼容旧格式:检查长度(16-32位)
|
||
if len(license_key) < 16 or len(license_key) > 32:
|
||
return False
|
||
|
||
# 检查字符(只允许大写字母和数字)
|
||
pattern = r'^[A-Z0-9_]+$'
|
||
import re
|
||
return bool(re.match(pattern, license_key))
|
||
|
||
def format_license_key(self, license_key: str) -> str:
|
||
"""
|
||
格式化卡密为XXXX-XXXX-XXXX-XXXX格式
|
||
:param license_key: 原始卡密
|
||
:return: 格式化后的卡密
|
||
"""
|
||
if not license_key:
|
||
return ''
|
||
|
||
# 去除空格、制表符和连字符,并转为大写
|
||
clean_key = license_key.strip().replace(' ', '').replace('\t', '').replace('-', '').upper()
|
||
|
||
# 如果长度不足32位,右补0
|
||
if len(clean_key) < 32:
|
||
clean_key = clean_key.ljust(32, '0')
|
||
# 如果长度超过32位,截取前32位
|
||
elif len(clean_key) > 32:
|
||
clean_key = clean_key[:32]
|
||
|
||
# 格式化为XXXX-XXXX-XXXX-XXXX格式
|
||
formatted_key = '-'.join([
|
||
clean_key[i:i+8] for i in range(0, len(clean_key), 8)
|
||
])
|
||
|
||
return formatted_key
|
||
|
||
def check_failed_attempts(self, failed_attempts: int, last_attempt_time: datetime) -> Tuple[bool, int]:
|
||
"""
|
||
检查失败尝试次数和时间
|
||
:param failed_attempts: 失败次数
|
||
:param last_attempt_time: 最后尝试时间
|
||
:return: (是否允许尝试, 剩余锁定时间(秒))
|
||
"""
|
||
if failed_attempts < self.max_failed_attempts:
|
||
return True, 0
|
||
|
||
# 检查锁定时间是否已过
|
||
lock_time = timedelta(minutes=self.lockout_minutes)
|
||
time_passed = datetime.utcnow() - last_attempt_time
|
||
|
||
if time_passed >= lock_time:
|
||
return True, 0
|
||
|
||
remaining_seconds = int((lock_time - time_passed).total_seconds())
|
||
return False, remaining_seconds
|
||
|
||
def validate_software_version(self, version: str) -> bool:
|
||
"""
|
||
验证软件版本格式
|
||
:param version: 版本字符串
|
||
:return: 是否有效
|
||
"""
|
||
if not version:
|
||
return False
|
||
|
||
# 语义化版本格式:主版本号.次版本号.修订号
|
||
pattern = r'^\d+\.\d+\.\d+$'
|
||
return bool(re.match(pattern, version))
|
||
|
||
def compare_versions(self, version1: str, version2: str) -> int:
|
||
"""
|
||
比较版本号
|
||
:param version1: 版本1
|
||
:param version2: 版本2
|
||
:return: -1(version1<version2), 0(version1==version2), 1(version1>version2)
|
||
"""
|
||
try:
|
||
v1_parts = [int(x) for x in version1.split('.')]
|
||
v2_parts = [int(x) for x in version2.split('.')]
|
||
|
||
# 补齐版本号长度
|
||
max_len = max(len(v1_parts), len(v2_parts))
|
||
v1_parts.extend([0] * (max_len - len(v1_parts)))
|
||
v2_parts.extend([0] * (max_len - len(v2_parts)))
|
||
|
||
for v1, v2 in zip(v1_parts, v2_parts):
|
||
if v1 < v2:
|
||
return -1
|
||
elif v1 > v2:
|
||
return 1
|
||
|
||
return 0
|
||
except (ValueError, AttributeError):
|
||
return -1
|
||
|
||
def validate_machine_code(self, machine_code: str) -> bool:
|
||
"""
|
||
验证机器码格式
|
||
:param machine_code: 机器码字符串
|
||
:return: 是否有效
|
||
"""
|
||
if not machine_code:
|
||
return False
|
||
|
||
# 机器码应该是32位大写字母和数字的组合
|
||
if len(machine_code) != 32:
|
||
return False
|
||
|
||
pattern = r'^[A-F0-9]+$'
|
||
return bool(re.match(pattern, machine_code))
|
||
|
||
def create_verification_hash(self, data: Dict[str, Any], secret_key: str) -> str:
|
||
"""
|
||
创建验证哈希
|
||
:param data: 要验证的数据字典
|
||
:param secret_key: 密钥
|
||
:return: 哈希值
|
||
"""
|
||
# 按键排序确保一致性
|
||
sorted_data = sorted(data.items())
|
||
combined = '&'.join([f"{k}={v}" for k, v in sorted_data])
|
||
combined += f"&key={secret_key}"
|
||
|
||
hash_obj = hashlib.sha256(combined.encode('utf-8'))
|
||
return hash_obj.hexdigest()
|
||
|
||
def verify_hash(self, data: Dict[str, Any], hash_value: str, secret_key: str) -> bool:
|
||
"""
|
||
验证哈希值
|
||
:param data: 原始数据字典
|
||
:param hash_value: 要验证的哈希值
|
||
:param secret_key: 密钥
|
||
:return: 验证结果
|
||
"""
|
||
computed_hash = self.create_verification_hash(data, secret_key)
|
||
return computed_hash == hash_value
|
||
|
||
def is_url_safe(self, url: str) -> bool:
|
||
"""
|
||
检查URL是否安全
|
||
:param url: URL字符串
|
||
:return: 是否安全
|
||
"""
|
||
if not url:
|
||
return False
|
||
|
||
# 基本URL格式检查
|
||
pattern = r'^https?://[^\s/$.?#].[^\s]*$'
|
||
if not re.match(pattern, url):
|
||
return False
|
||
|
||
# 检查协议
|
||
if not url.startswith(('http://', 'https://')):
|
||
return False
|
||
|
||
return True
|
||
|
||
def sanitize_input(self, input_str: str) -> str:
|
||
"""
|
||
清理输入字符串
|
||
:param input_str: 输入字符串
|
||
:return: 清理后的字符串
|
||
"""
|
||
if not input_str:
|
||
return ''
|
||
|
||
# 移除特殊字符
|
||
dangerous_chars = ['<', '>', '"', "'", '&', '\x00']
|
||
for char in dangerous_chars:
|
||
input_str = input_str.replace(char, '')
|
||
|
||
# 限制长度
|
||
return input_str[:1000]
|
||
|
||
def format_license_key(license_key: str) -> str:
|
||
"""
|
||
格式化卡密的便捷函数
|
||
:param license_key: 原始卡密
|
||
:return: 格式化后的卡密
|
||
"""
|
||
validator = LicenseValidator()
|
||
return validator.format_license_key(license_key)
|
||
|
||
def validate_license_key(license_key: str) -> bool:
|
||
"""
|
||
验证卡密格式的便捷函数
|
||
:param license_key: 卡密字符串
|
||
:return: 是否有效
|
||
"""
|
||
validator = LicenseValidator()
|
||
return validator.validate_license_key(license_key)
|
||
|
||
|
||
# ==================== 通用验证工具 ====================
|
||
|
||
class ValidationError(Exception):
|
||
"""验证错误"""
|
||
pass
|
||
|
||
|
||
class Validator:
|
||
"""通用验证器类,提供链式验证"""
|
||
|
||
def __init__(self, value: Any, field_name: str = "字段"):
|
||
self.value = value
|
||
self.field_name = field_name
|
||
self.errors = []
|
||
|
||
def required(self) -> 'Validator':
|
||
"""验证必填"""
|
||
if self.value is None or (isinstance(self.value, str) and not self.value.strip()):
|
||
self.errors.append(f"{self.field_name}不能为空")
|
||
return self
|
||
|
||
def min_length(self, min_len: int) -> 'Validator':
|
||
"""验证最小长度"""
|
||
if self.value and len(self.value) < min_len:
|
||
self.errors.append(f"{self.field_name}长度不能少于{min_len}个字符")
|
||
return self
|
||
|
||
def max_length(self, max_len: int) -> 'Validator':
|
||
"""验证最大长度"""
|
||
if self.value and len(self.value) > max_len:
|
||
self.errors.append(f"{self.field_name}长度不能超过{max_len}个字符")
|
||
return self
|
||
|
||
def email(self) -> 'Validator':
|
||
"""验证邮箱"""
|
||
if self.value and not re.match(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$', self.value):
|
||
self.errors.append(f"{self.field_name}格式不正确")
|
||
return self
|
||
|
||
def phone(self) -> 'Validator':
|
||
"""验证手机号"""
|
||
if self.value and not re.match(r'^1[3-9]\d{9}$', str(self.value)):
|
||
self.errors.append(f"{self.field_name}格式不正确")
|
||
return self
|
||
|
||
def range(self, min_val: int, max_val: int) -> 'Validator':
|
||
"""验证范围"""
|
||
if self.value is not None and (self.value < min_val or self.value > max_val):
|
||
self.errors.append(f"{self.field_name}必须在{min_val}-{max_val}之间")
|
||
return self
|
||
|
||
def choice(self, choices: List[Any]) -> 'Validator':
|
||
"""验证选项"""
|
||
if self.value not in choices:
|
||
self.errors.append(f"{self.field_name}必须是以下之一: {', '.join(map(str, choices))}")
|
||
return self
|
||
|
||
def is_valid(self) -> bool:
|
||
"""验证是否通过"""
|
||
return len(self.errors) == 0
|
||
|
||
def get_errors(self) -> List[str]:
|
||
"""获取错误列表"""
|
||
return self.errors
|
||
|
||
def raise_if_invalid(self) -> None:
|
||
"""如果验证失败则抛出异常"""
|
||
if not self.is_valid():
|
||
raise ValidationError('; '.join(self.errors))
|
||
|
||
|
||
def validate_timestamp(timestamp: int, max_seconds: int = 300) -> bool:
|
||
"""
|
||
验证时间戳有效性
|
||
|
||
Args:
|
||
timestamp: 时间戳
|
||
max_seconds: 最大允许的时间差(秒)
|
||
|
||
Returns:
|
||
bool: 是否有效
|
||
"""
|
||
try:
|
||
request_time = datetime.fromtimestamp(timestamp)
|
||
current_time = datetime.utcnow()
|
||
time_diff = abs((current_time - request_time).total_seconds())
|
||
return time_diff <= max_seconds
|
||
except (ValueError, TypeError, OSError):
|
||
return False
|
||
|
||
|
||
def validate_product_id(product_id: str) -> bool:
|
||
"""
|
||
验证产品ID格式
|
||
|
||
Args:
|
||
product_id: 产品ID
|
||
|
||
Returns:
|
||
bool: 是否有效
|
||
"""
|
||
pattern = r'^PROD_[A-F0-9]{8}$|^[A-Za-z0-9_]{1,32}$'
|
||
return re.match(pattern, product_id) is not None
|
||
|
||
|
||
def sanitize_string(value: str, max_length: int = 255) -> str:
|
||
"""
|
||
清理字符串(移除危险字符)
|
||
|
||
Args:
|
||
value: 原始字符串
|
||
max_length: 最大长度
|
||
|
||
Returns:
|
||
str: 清理后的字符串
|
||
"""
|
||
if not value:
|
||
return ''
|
||
# 移除潜在的XSS攻击字符
|
||
value = value.strip()
|
||
# 截断到指定长度
|
||
if len(value) > max_length:
|
||
value = value[:max_length]
|
||
return value
|
||
|
||
|
||
def validate_filename(filename: str, allowed_extensions: Optional[List[str]] = None) -> None:
|
||
"""
|
||
验证文件名
|
||
|
||
Args:
|
||
filename: 文件名
|
||
allowed_extensions: 允许的扩展名列表
|
||
|
||
Raises:
|
||
ValidationError: 如果文件名无效
|
||
"""
|
||
if not filename:
|
||
raise ValidationError("文件名不能为空")
|
||
|
||
# 防止路径遍历攻击
|
||
if '..' in filename or '/' in filename or '\\' in filename:
|
||
raise ValidationError("文件名包含非法字符")
|
||
|
||
# 验证扩展名
|
||
if allowed_extensions:
|
||
ext = filename.rsplit('.', 1)[-1].lower() if '.' in filename else ''
|
||
if ext not in allowed_extensions:
|
||
raise ValidationError(f"文件扩展名必须是以下之一: {', '.join(allowed_extensions)}")
|
||
|
||
|
||
def validate_pagination(page: int = 1, per_page: int = 20, max_per_page: int = 100) -> tuple:
|
||
"""
|
||
验证分页参数
|
||
|
||
Args:
|
||
page: 页码
|
||
per_page: 每页数量
|
||
max_per_page: 最大每页数量
|
||
|
||
Returns:
|
||
tuple: (page, per_page) 修正后的值
|
||
"""
|
||
page = max(1, page)
|
||
per_page = min(max(1, per_page), max_per_page)
|
||
return page, per_page |