第一次提交

This commit is contained in:
2026-03-25 15:24:22 +08:00
commit 0f8ac68d4d
156 changed files with 42365 additions and 0 deletions

8
app/utils/__init__.py Normal file
View File

@@ -0,0 +1,8 @@
from .crypto import AESCipher, generate_hash, verify_hash
from .machine_code import MachineCodeGenerator
from .validators import LicenseValidator, format_license_key
__all__ = [
'AESCipher', 'generate_hash', 'verify_hash',
'MachineCodeGenerator', 'LicenseValidator', 'format_license_key'
]

201
app/utils/alipay.py Normal file
View File

@@ -0,0 +1,201 @@
# -*- coding: utf-8 -*-
"""
支付宝支付工具类
"""
from alipay import AliPay
from flask import current_app
from datetime import datetime
class AlipayHelper:
"""支付宝支付助手类"""
def __init__(self, app=None):
"""
初始化支付宝客户端
:param app: Flask应用实例
"""
self.alipay = None
if app:
self.init_app(app)
def init_app(self, app):
"""
初始化支付宝配置
:param app: Flask应用实例
"""
config = app.config
# 检查必要的配置
required_configs = [
'ALIPAY_APP_ID',
'ALIPAY_PRIVATE_KEY',
'ALIPAY_PUBLIC_KEY',
'ALIPAY_ALIPAY_PUBLIC_KEY'
]
missing_configs = [cfg for cfg in required_configs if not config.get(cfg)]
if missing_configs:
raise ValueError(f"缺少必要的支付宝配置: {', '.join(missing_configs)}")
# 初始化支付宝客户端
self.alipay = AliPay(
appid=config['ALIPAY_APP_ID'],
app_private_key=config['ALIPAY_PRIVATE_KEY'],
alipay_public_key=config['ALIPAY_ALIPAY_PUBLIC_KEY'],
sign_type=config.get('ALIPAY_SIGN_TYPE', 'RSA2'),
charset=config.get('ALIPAY_CHARSET', 'utf-8'),
gateway=config.get('ALIPAY_GATEWAY', 'https://openapi.alipay.com/gateway.do')
)
def create_payment_url(self, order_number, amount, subject, notify_url, return_url):
"""
创建支付宝支付链接
:param order_number: 订单号
:param amount: 支付金额
:param subject: 支付主题
:param notify_url: 异步通知URL
:param return_url: 同步返回URL
:return: 支付链接
"""
if not self.alipay:
raise ValueError("支付宝客户端未初始化")
# 构建订单参数
order_params = {
'out_trade_no': order_number, # 商户订单号
'total_amount': str(amount), # 订单总金额
'subject': subject, # 订单标题
'body': f'订单{order_number}的支付', # 订单描述
'product_code': 'FAST_INSTANT_TRADE_PAY', # 产品代码
}
# 生成支付链接
payment_url = self.alipay.trade_page_pay(
total_amount=str(amount),
subject=subject,
out_trade_no=order_number,
return_url=return_url,
notify_url=notify_url
)
# 构建完整的支付URL
gateway = self.alipay.gateway
full_payment_url = f"{gateway}?{payment_url}"
return full_payment_url
def create_wap_payment_url(self, order_number, amount, subject, notify_url, return_url):
"""
创建手机网站支付链接
:param order_number: 订单号
:param amount: 支付金额
:param subject: 支付主题
:param notify_url: 异步通知URL
:param return_url: 同步返回URL
:return: 支付链接
"""
if not self.alipay:
raise ValueError("支付宝客户端未初始化")
# 生成支付链接
payment_url = self.alipay.trade_wap_pay(
total_amount=str(amount),
subject=subject,
out_trade_no=order_number,
return_url=return_url,
notify_url=notify_url
)
# 构建完整的支付URL
gateway = self.alipay.gateway
full_payment_url = f"{gateway}?{payment_url}"
return full_payment_url
def query_order_status(self, order_number):
"""
查询订单支付状态
:param order_number: 订单号
:return: 订单状态信息
"""
if not self.alipay:
raise ValueError("支付宝客户端未初始化")
try:
result = self.alipay.alipay_trade_query(
out_trade_no=order_number
)
return result
except Exception as e:
current_app.logger.error(f"查询支付宝订单状态失败: {str(e)}")
return None
def verify_notification(self, data, signature):
"""
验证支付宝异步通知签名
:param data: 通知数据
:param signature: 签名
:return: 验证结果
"""
if not self.alipay:
raise ValueError("支付宝客户端未初始化")
try:
# 验证签名
return self.alipay.verify(data, signature)
except Exception as e:
current_app.logger.error(f"验证支付宝通知签名失败: {str(e)}")
return False
def verify_trade_status(self, notify_data):
"""
验证交易状态
:param notify_data: 通知数据
:return: 验证结果和交易状态
"""
if not self.alipay:
raise ValueError("支付宝客户端未初始化")
try:
# 验证签名
if not self.alipay.verify(notify_data, notify_data.get('sign')):
return False, None
# 检查交易状态
trade_status = notify_data.get('trade_status')
if trade_status in ['TRADE_SUCCESS', 'TRADE_FINISHED']:
return True, trade_status
else:
return False, trade_status
except Exception as e:
current_app.logger.error(f"验证支付宝交易状态失败: {str(e)}")
return False, None
def get_trade_state(self, order_number):
"""
获取交易状态
:param order_number: 订单号
:return: 交易状态
"""
result = self.query_order_status(order_number)
if not result:
return None
# 根据支付宝返回的状态码判断
code = result.get('code')
if code == '10000': # 接口调用成功
trade_status = result.get('trade_status')
if trade_status == 'TRADE_SUCCESS':
return 'SUCCESS'
elif trade_status == 'TRADE_FINISHED':
return 'FINISHED'
elif trade_status == 'TRADE_CLOSED':
return 'CLOSED'
elif trade_status == 'WAIT_BUYER_PAY':
return 'WAIT_PAY'
elif code == '40004': # 业务处理失败
return 'NOT_EXIST'
return None

178
app/utils/api_response.py Normal file
View File

@@ -0,0 +1,178 @@
"""
统一的API响应工具
提供标准化的API响应格式
"""
from flask import jsonify, Response
from typing import Any, Optional, Dict
from werkzeug.http import HTTP_STATUS_CODES
class APIResponse:
"""API响应类"""
@staticmethod
def success(data: Any = None, message: str = "操作成功", code: int = 200) -> Response:
"""
成功响应
Args:
data: 响应数据
message: 响应消息
code: HTTP状态码
Returns:
Response: Flask响应对象
"""
response_data = {
'success': True,
'message': message
}
if data is not None:
response_data['data'] = data
return jsonify(response_data), code
@staticmethod
def error(message: str = "操作失败", code: int = 400, data: Any = None) -> Response:
"""
错误响应
Args:
message: 错误消息
code: HTTP状态码
data: 额外数据
Returns:
Response: Flask响应对象
"""
response_data = {
'success': False,
'message': message
}
if data is not None:
response_data['data'] = data
return jsonify(response_data), code
@staticmethod
def validation_error(errors: Dict[str, str], message: str = "参数验证失败") -> Response:
"""
参数验证错误响应
Args:
errors: 错误详情字典
message: 错误消息
Returns:
Response: Flask响应对象
"""
return APIResponse.error(
message=message,
code=400,
data={'errors': errors}
)
@staticmethod
def unauthorized(message: str = "未授权,请先登录") -> Response:
"""
未授权响应
Args:
message: 错误消息
Returns:
Response: Flask响应对象
"""
return APIResponse.error(message=message, code=401)
@staticmethod
def forbidden(message: str = "权限不足") -> Response:
"""
禁止访问响应
Args:
message: 错误消息
Returns:
Response: Flask响应对象
"""
return APIResponse.error(message=message, code=403)
@staticmethod
def not_found(message: str = "资源不存在") -> Response:
"""
资源不存在响应
Args:
message: 错误消息
Returns:
Response: Flask响应对象
"""
return APIResponse.error(message=message, code=404)
@staticmethod
def server_error(message: str = "服务器内部错误") -> Response:
"""
服务器错误响应
Args:
message: 错误消息
Returns:
Response: Flask响应对象
"""
return APIResponse.error(message=message, code=500)
@staticmethod
def paginated(items: list, total: int, page: int, per_page: int,
message: str = "查询成功") -> Response:
"""
分页响应
Args:
items: 数据列表
total: 总数
page: 当前页
per_page: 每页数量
message: 响应消息
Returns:
Response: Flask响应对象
"""
import math
total_pages = math.ceil(total / per_page) if total > 0 else 0
return APIResponse.success(
data={
'items': items,
'pagination': {
'current_page': page,
'per_page': per_page,
'total': total,
'total_pages': total_pages,
'has_prev': page > 1,
'has_next': page < total_pages
}
},
message=message
)
def success_response(data: Any = None, message: str = "操作成功") -> Response:
"""成功响应的便捷函数"""
return APIResponse.success(data=data, message=message)
def error_response(message: str = "操作失败", code: int = 400) -> Response:
"""错误响应的便捷函数"""
return APIResponse.error(message=message, code=code)
def not_found_response(message: str = "资源不存在") -> Response:
"""资源不存在响应的便捷函数"""
return APIResponse.not_found(message=message)
def validation_error_response(errors: Dict[str, str]) -> Response:
"""参数验证错误响应的便捷函数"""
return APIResponse.validation_error(errors=errors)

1001
app/utils/auth_validator.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,181 @@
# -*- coding: utf-8 -*-
"""
后台任务模块
用于执行定时任务,如更新过期卡密状态
"""
from datetime import datetime, timedelta
from app import db
from app.models import License
from flask import current_app
import logging
logger = logging.getLogger(__name__)
def update_expired_licenses():
"""
更新过期卡密状态
将所有已激活但已过期的卡密状态更新为2已过期
"""
try:
logger.info("开始检查过期卡密...")
# 查找所有已激活但已过期的卡密
# 条件status=1已激活且expire_time < 当前时间
expired_licenses = License.query.filter(
License.status == 1,
License.expire_time.isnot(None),
License.expire_time < datetime.utcnow()
).all()
if not expired_licenses:
logger.info("没有发现过期卡密")
return {
'success': True,
'message': '没有发现过期卡密',
'updated_count': 0
}
# 更新过期卡密状态
updated_count = 0
for license_obj in expired_licenses:
old_status = license_obj.status
license_obj.status = 2 # 更新为已过期
updated_count += 1
logger.info(
f"更新卡密 {license_obj.license_key} 状态: {old_status} -> {license_obj.status}"
)
# 提交事务
db.session.commit()
logger.info(f"成功更新 {updated_count} 个过期卡密状态")
return {
'success': True,
'message': f'成功更新 {updated_count} 个过期卡密状态',
'updated_count': updated_count
}
except Exception as e:
# 回滚事务
db.session.rollback()
logger.error(f"更新过期卡密失败: {str(e)}", exc_info=True)
return {
'success': False,
'message': f'更新过期卡密失败: {str(e)}',
'updated_count': 0
}
def check_licenses_batch():
"""
批量检查所有卡密状态
包括:
1. 已激活但过期的 -> 更新为已过期
2. 已过期但状态未更新的 -> 更新状态
"""
try:
logger.info("开始批量检查卡密状态...")
# 检查1已激活但过期的卡密
active_but_expired = License.query.filter(
License.status == 1, # 已激活
License.expire_time.isnot(None),
License.expire_time < datetime.utcnow()
).count()
# 检查2已过期但状态正确的卡密
expired_and_marked = License.query.filter(
License.status == 2, # 已过期
License.expire_time.isnot(None),
License.expire_time < datetime.utcnow()
).count()
# 检查3已激活且未过期的卡密
active_and_valid = License.query.filter(
License.status == 1, # 已激活
db.or_(
License.expire_time.is_(None), # 永久卡
License.expire_time >= datetime.utcnow() # 未过期
)
).count()
# 检查4未激活的卡密
inactive = License.query.filter(
License.status == 0 # 未激活
).count()
# 检查5已禁用的卡密
disabled = License.query.filter(
License.status == 3 # 已禁用
).count()
logger.info(
f"卡密状态统计:\n"
f" 已激活但过期: {active_but_expired}\n"
f" 已过期且已标记: {expired_and_marked}\n"
f" 已激活且有效: {active_and_valid}\n"
f" 未激活: {inactive}\n"
f" 已禁用: {disabled}"
)
# 执行更新
update_result = update_expired_licenses()
return {
'success': True,
'message': '批量检查完成',
'statistics': {
'active_but_expired': active_but_expired,
'expired_and_marked': expired_and_marked,
'active_and_valid': active_and_valid,
'inactive': inactive,
'disabled': disabled
},
'update_result': update_result
}
except Exception as e:
logger.error(f"批量检查卡密失败: {str(e)}", exc_info=True)
return {
'success': False,
'message': f'批量检查卡密失败: {str(e)}'
}
def cleanup_old_license_logs():
"""
清理旧的卡密验证日志
保留最近30天的记录
"""
try:
from app.models import AuditLog
# 保留最近30天的记录
cutoff_date = datetime.utcnow() - timedelta(days=30)
# 删除过期的审计日志
deleted_count = AuditLog.query.filter(
AuditLog.create_time < cutoff_date
).delete(synchronize_session=False)
db.session.commit()
logger.info(f"卡密日志清理完成,删除了 {deleted_count} 条记录")
return {
'success': True,
'message': f'日志清理完成,删除了 {deleted_count} 条记录',
'deleted_count': deleted_count,
'cutoff_date': cutoff_date
}
except Exception as e:
db.session.rollback()
logger.error(f"清理日志失败: {str(e)}", exc_info=True)
return {
'success': False,
'message': f'清理日志失败: {str(e)}'
}

View File

@@ -0,0 +1,23 @@
from flask import request, jsonify, make_response
from functools import wraps
def add_cors_headers(response):
"""为响应添加CORS头部"""
# 允许特定源访问
response.headers['Access-Control-Allow-Origin'] = '*'
response.headers['Access-Control-Allow-Methods'] = 'GET, POST, PUT, DELETE, OPTIONS'
response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, X-Requested-With'
response.headers['Access-Control-Max-Age'] = '86400' # 24小时
return response
def cors_after(response):
"""在每个请求后添加CORS头部"""
return add_cors_headers(response)
def handle_preflight():
"""处理预检请求"""
if request.method == "OPTIONS":
response = make_response()
response = add_cors_headers(response)
return response
return None

11
app/utils/crypto.py Normal file
View File

@@ -0,0 +1,11 @@
# 为了兼容性,导入简化的加密工具
from .simple_crypto import (
SimpleCrypto as AESCipher,
generate_hash,
verify_hash,
generate_token,
generate_uuid,
encrypt_password,
verify_password,
generate_signature
)

185
app/utils/file_security.py Normal file
View File

@@ -0,0 +1,185 @@
"""
文件上传安全检查工具
防止恶意文件上传和攻击
"""
import os
import mimetypes
from pathlib import Path
from typing import Tuple, Optional
import hashlib
import uuid
from flask import current_app
ALLOWED_EXTENSIONS = {
# 图片
'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp',
# 文档
'pdf', 'doc', 'docx', 'xls', 'xlsx', 'ppt', 'pptx', 'txt', 'csv',
# 压缩包
'zip', 'rar', '7z', 'tar', 'gz',
# 其他
'json', 'xml'
}
BLOCKED_EXTENSIONS = {
# 可执行文件
'exe', 'bat', 'cmd', 'com', 'pif', 'scr', 'vbs', 'js', 'jar',
# 脚本
'sh', 'py', 'php', 'asp', 'aspx', 'jsp',
# 系统文件
'sys', 'dll', 'so', 'dylib',
# 其他危险文件
'html', 'htm', 'php', 'asp', 'aspx', 'jsp'
}
def get_file_extension(filename: str) -> str:
"""获取文件扩展名(小写)"""
return Path(filename).suffix.lower().lstrip('.')
def check_file_extension(filename: str) -> Tuple[bool, str]:
"""
检查文件扩展名
:param filename: 文件名
:return: (是否允许, 错误消息)
"""
ext = get_file_extension(filename)
# 检查是否在阻止列表中
if ext in BLOCKED_EXTENSIONS:
return False, f"不允许上传.{ext}类型的文件"
# 检查是否在允许列表中
if ext not in ALLOWED_EXTENSIONS:
return False, f"不支持的文件类型: .{ext}"
return True, "文件类型允许"
def check_file_signature(file_path: str, allowed_mimetypes: set) -> Tuple[bool, str]:
"""
检查文件签名(魔数)
:param file_path: 文件路径
:param allowed_mimetypes: 允许的MIME类型集合
:return: (是否通过, 错误消息)
"""
try:
# 读取文件头部通常前20字节足够识别文件类型
with open(file_path, 'rb') as f:
header = f.read(20)
# 根据文件头部判断文件类型
# 这里只是简单示例,实际应该根据具体文件类型实现
if header.startswith(b'\x89PNG\r\n\x1a\n'):
mimetype = 'image/png'
elif header.startswith(b'\xff\xd8\xff'):
mimetype = 'image/jpeg'
elif header.startswith(b'GIF87a') or header.startswith(b'GIF89a'):
mimetype = 'image/gif'
elif header.startswith(b'%PDF'):
mimetype = 'application/pdf'
elif header.startswith(b'PK'):
mimetype = 'application/zip'
else:
mimetype = mimetypes.guess_type(file_path)[0] or 'application/octet-stream'
# 检查MIME类型是否允许
if mimetype not in allowed_mimetypes:
return False, f"文件签名验证失败: 检测到{mimetype},但不在允许列表中"
return True, "文件签名验证通过"
except Exception as e:
return False, f"文件签名检查失败: {str(e)}"
def generate_safe_filename(original_filename: str) -> str:
"""
生成安全的文件名
:param original_filename: 原始文件名
:return: 安全的文件名
"""
# 获取文件扩展名
ext = get_file_extension(original_filename)
# 生成唯一文件名使用UUID
unique_name = str(uuid.uuid4())
# 组合新的文件名
safe_filename = f"{unique_name}.{ext}" if ext else unique_name
return safe_filename
def check_file_size(file_path: str, max_size: int = None) -> Tuple[bool, str]:
"""
检查文件大小
:param file_path: 文件路径
:param max_size: 最大大小字节None表示使用配置中的值
:return: (是否通过, 错误消息)
"""
if max_size is None:
max_size = current_app.config.get('MAX_CONTENT_LENGTH', 50 * 1024 * 1024) # 默认50MB
try:
file_size = os.path.getsize(file_path)
if file_size > max_size:
size_mb = file_size / (1024 * 1024)
max_size_mb = max_size / (1024 * 1024)
return False, f"文件大小超出限制: {size_mb:.2f}MB > {max_size_mb:.2f}MB"
return True, f"文件大小检查通过: {file_size / (1024 * 1024):.2f}MB"
except Exception as e:
return False, f"文件大小检查失败: {str(e)}"
def secure_file_upload(file_storage, upload_folder: str, allowed_mimetypes: set) -> Tuple[bool, str, str]:
"""
安全的文件上传
:param file_storage: Flask的FileStorage对象
:param upload_folder: 上传目录
:param allowed_mimetypes: 允许的MIME类型集合
:return: (是否成功, 消息, 文件路径)
"""
try:
# 检查文件名
if not file_storage.filename:
return False, "文件名不能为空", ""
# 检查文件扩展名
allowed, msg = check_file_extension(file_storage.filename)
if not allowed:
return False, msg, ""
# 生成安全的文件名
safe_filename = generate_safe_filename(file_storage.filename)
file_path = os.path.join(upload_folder, safe_filename)
# 确保上传目录存在
os.makedirs(upload_folder, exist_ok=True)
# 保存文件
file_storage.save(file_path)
# 检查文件大小
allowed, msg = check_file_size(file_path)
if not allowed:
# 删除已保存的文件
os.remove(file_path)
return False, msg, ""
# 检查文件签名
allowed, msg = check_file_signature(file_path, allowed_mimetypes)
if not allowed:
# 删除已保存的文件
os.remove(file_path)
return False, msg, ""
return True, "文件上传成功", file_path
except Exception as e:
current_app.logger.error(f"文件上传失败: {str(e)}")
return False, f"文件上传失败: {str(e)}", ""

View File

@@ -0,0 +1,121 @@
# -*- coding: utf-8 -*-
"""
许可证生成工具类
"""
from app import db
from app.models import License, Product, Package
from datetime import datetime, timedelta
import secrets
import string
class LicenseGenerator:
"""许可证生成器"""
def __init__(self):
"""初始化"""
pass
def generate_license_key(self, length=32, prefix=''):
"""
生成卡密格式XXXX-XXXX-XXXX-XXXX
:param length: 卡密长度
:param prefix: 卡密前缀
:return: 生成的卡密
"""
# 生成随机字符串
chars = string.ascii_uppercase + string.digits
# 生成32位字符4组每组8位
random_chars = ''.join(secrets.choice(chars) for _ in range(32 - len(prefix)))
# 格式化为XXXX-XXXX-XXXX-XXXX格式
formatted_key = '-'.join([
random_chars[i:i+8] for i in range(0, len(random_chars), 8)
])
# 组合前缀和格式化后的密钥
license_key = prefix + formatted_key if prefix else formatted_key
# 确保唯一性
while License.query.filter_by(license_key=license_key).first():
random_chars = ''.join(secrets.choice(chars) for _ in range(32 - len(prefix)))
formatted_key = '-'.join([
random_chars[i:i+8] for i in range(0, len(random_chars), 8)
])
license_key = prefix + formatted_key if prefix else formatted_key
return license_key
def generate_license(self, product_id, package_id, contact_person, phone, quantity=1, license_type=1):
"""
生成许可证
:param product_id: 产品ID
:param package_id: 套餐ID
:param contact_person: 联系人
:param phone: 手机号
:param quantity: 数量
:param license_type: 许可证类型 (0=试用, 1=正式)
:return: 生成的许可证密钥
"""
# 查询产品和套餐
product = Product.query.filter_by(product_id=product_id).first()
if not product:
raise ValueError(f"产品不存在: {product_id}")
package = Package.query.filter_by(package_id=package_id).first()
if not package:
raise ValueError(f"套餐不存在: {package_id}")
# 计算有效期
if package.duration == -1: # 永久卡
expire_time = None
else:
expire_time = datetime.utcnow() + timedelta(days=package.duration)
# 生成许可证密钥
license_key = self.generate_license_key()
# 创建许可证记录
license_obj = License(
license_key=license_key,
product_id=product_id,
package_id=package_id,
contact_person=contact_person,
phone=phone,
license_type=license_type,
expire_time=expire_time,
max_devices=package.max_devices,
status=1 # 启用
)
# 保存到数据库
db.session.add(license_obj)
db.session.commit()
return license_key
def generate_batch(self, product_id, package_id, contact_person, phone, count=1, license_type=1):
"""
批量生成许可证
:param product_id: 产品ID
:param package_id: 套餐ID
:param contact_person: 联系人
:param phone: 手机号
:param count: 数量
:param license_type: 许可证类型
:return: 生成的许可证密钥列表
"""
license_keys = []
for _ in range(count):
license_key = self.generate_license(
product_id=product_id,
package_id=package_id,
contact_person=contact_person,
phone=phone,
quantity=1,
license_type=license_type
)
license_keys.append(license_key)
return license_keys

51
app/utils/logger.py Normal file
View File

@@ -0,0 +1,51 @@
from flask import request, current_app
from flask_login import current_user
from app.models import AuditLog
from functools import wraps
import json
def log_operation(action, target_type, target_id=None, details=None):
"""记录操作日志的工具函数"""
try:
# 获取当前用户信息
admin_id = getattr(current_user, 'admin_id', None) if hasattr(current_user, 'is_authenticated') and current_user.is_authenticated else None
# 获取客户端IP
ip_address = request.headers.get('X-Forwarded-For', request.remote_addr)
# 获取用户代理
user_agent = request.headers.get('User-Agent', '')
# 记录审计日志
AuditLog.log_action(
admin_id=admin_id,
action=action,
target_type=target_type,
target_id=target_id,
details=details,
ip_address=ip_address,
user_agent=user_agent
)
except Exception as e:
if hasattr(current_app, 'logger'):
current_app.logger.error(f"记录操作日志失败: {str(e)}")
def log_operations(action, target_type):
"""记录操作日志的装饰器"""
def decorator(f):
@wraps(f)
def decorated_function(*args, **kwargs):
try:
# 执行原函数
result = f(*args, **kwargs)
# 记录成功日志
log_operation(action, target_type)
return result
except Exception as e:
# 记录错误日志
log_operation(f"{action}_ERROR", target_type, details={'error': str(e)})
raise e
return decorated_function
return decorator

296
app/utils/machine_code.py Normal file
View File

@@ -0,0 +1,296 @@
import hashlib
import platform
import subprocess
import uuid
import os
import re
from typing import Optional
class MachineCodeGenerator:
"""机器码生成器"""
def __init__(self):
self._machine_code_cache = None
def generate(self) -> str:
"""
生成机器码
:return: 32位哈希字符串
"""
if self._machine_code_cache:
return self._machine_code_cache
# 收集硬件信息
hw_info = []
try:
# 1. 获取主板序列号
board_serial = self._get_board_serial()
if board_serial:
hw_info.append(board_serial)
# 2. 获取CPU ID
cpu_id = self._get_cpu_id()
if cpu_id:
hw_info.append(cpu_id)
# 3. 获取硬盘序列号
disk_serial = self._get_disk_serial()
if disk_serial:
hw_info.append(disk_serial)
# 4. 获取系统UUID
system_uuid = self._get_system_uuid()
if system_uuid:
hw_info.append(system_uuid)
# 5. 获取MAC地址
mac_address = self._get_mac_address()
if mac_address:
hw_info.append(mac_address)
except Exception as e:
print(f"获取硬件信息时出错: {e}")
# 如果没有获取到任何硬件信息,使用备用方案
if not hw_info:
hw_info = [str(uuid.uuid4()), platform.node()]
# 组合所有硬件信息并生成哈希
combined_info = '|'.join(hw_info)
hash_obj = hashlib.sha256(combined_info.encode('utf-8'))
machine_code = hash_obj.hexdigest()[:32].upper()
self._machine_code_cache = machine_code
return machine_code
def _get_board_serial(self) -> Optional[str]:
"""获取主板序列号"""
try:
system = platform.system().lower()
if system == 'windows':
# Windows系统
result = subprocess.run(
'wmic baseboard get serialnumber',
shell=True,
capture_output=True,
text=True,
timeout=10
)
if result.returncode == 0:
lines = result.stdout.strip().split('\n')
for line in lines:
if line.strip() and 'SerialNumber' not in line:
return line.strip()
elif system == 'linux':
# Linux系统
for path in ['/sys/class/dmi/id/board_serial', '/sys/class/dmi/id/product_uuid']:
if os.path.exists(path):
with open(path, 'r') as f:
content = f.read().strip()
if content:
return content
elif system == 'darwin':
# macOS系统
result = subprocess.run(
'system_profiler SPHardwareDataType | grep "Serial Number"',
shell=True,
capture_output=True,
text=True,
timeout=10
)
if result.returncode == 0:
match = re.search(r'Serial Number:\s*(.+)', result.stdout)
if match:
return match.group(1).strip()
except Exception:
pass
return None
def _get_cpu_id(self) -> Optional[str]:
"""获取CPU ID"""
try:
system = platform.system().lower()
if system == 'windows':
# Windows系统
result = subprocess.run(
'wmic cpu get processorid',
shell=True,
capture_output=True,
text=True,
timeout=10
)
if result.returncode == 0:
lines = result.stdout.strip().split('\n')
for line in lines:
if line.strip() and 'ProcessorId' not in line:
return line.strip()
elif system == 'linux':
# Linux系统
result = subprocess.run(
'cat /proc/cpuinfo | grep -i "processor id" | head -1',
shell=True,
capture_output=True,
text=True,
timeout=10
)
if result.returncode == 0:
match = re.search(r':\s*(.+)', result.stdout)
if match:
return match.group(1).strip()
elif system == 'darwin':
# macOS系统
result = subprocess.run(
'system_profiler SPHardwareDataType | grep "Processor Name"',
shell=True,
capture_output=True,
text=True,
timeout=10
)
if result.returncode == 0:
match = re.search(r'Processor Name:\s*(.+)', result.stdout)
if match:
return match.group(1).strip()
except Exception:
pass
return None
def _get_disk_serial(self) -> Optional[str]:
"""获取硬盘序列号"""
try:
system = platform.system().lower()
if system == 'windows':
# Windows系统
result = subprocess.run(
'wmic diskdrive get serialnumber',
shell=True,
capture_output=True,
text=True,
timeout=10
)
if result.returncode == 0:
lines = result.stdout.strip().split('\n')
for line in lines:
if line.strip() and 'SerialNumber' not in line:
return line.strip()
elif system == 'linux':
# Linux系统
result = subprocess.run(
'lsblk -d -o serial | head -2 | tail -1',
shell=True,
capture_output=True,
text=True,
timeout=10
)
if result.returncode == 0 and result.stdout.strip():
return result.stdout.strip()
elif system == 'darwin':
# macOS系统
result = subprocess.run(
'diskutil info / | grep "Serial Number"',
shell=True,
capture_output=True,
text=True,
timeout=10
)
if result.returncode == 0:
match = re.search(r'Serial Number:\s*(.+)', result.stdout)
if match:
return match.group(1).strip()
except Exception:
pass
return None
def _get_system_uuid(self) -> Optional[str]:
"""获取系统UUID"""
try:
# 尝试使用Python的uuid模块获取
system_uuid = str(uuid.getnode())
if system_uuid and system_uuid != '0':
return system_uuid
# 备用方案读取系统UUID文件
system = platform.system().lower()
if system == 'linux':
for path in ['/etc/machine-id', '/var/lib/dbus/machine-id']:
if os.path.exists(path):
with open(path, 'r') as f:
content = f.read().strip()
if content:
return content
except Exception:
pass
return None
def _get_mac_address(self) -> Optional[str]:
"""获取MAC地址"""
try:
# 获取第一个非回环网络接口的MAC地址
import uuid
mac = uuid.getnode()
if mac != 0:
# 格式化MAC地址
mac_str = ':'.join([f'{(mac >> 8*i) & 0xff:02x}' for i in range(6)])
return mac_str.upper()
except Exception:
pass
return None
def save_to_cache(self, file_path: str = '.machine_code') -> bool:
"""
保存机器码到缓存文件
:param file_path: 缓存文件路径
:return: 是否保存成功
"""
try:
machine_code = self.generate()
with open(file_path, 'w') as f:
f.write(machine_code)
return True
except Exception:
return False
def load_from_cache(self, file_path: str = '.machine_code') -> Optional[str]:
"""
从缓存文件加载机器码
:param file_path: 缓存文件路径
:return: 机器码如果加载失败返回None
"""
try:
if os.path.exists(file_path):
with open(file_path, 'r') as f:
machine_code = f.read().strip()
if machine_code and len(machine_code) == 32:
self._machine_code_cache = machine_code
return machine_code
except Exception:
pass
return None
def generate_machine_code() -> str:
"""
生成机器码的便捷函数
:return: 32位机器码字符串
"""
generator = MachineCodeGenerator()
return generator.generate()

222
app/utils/scheduler.py Normal file
View File

@@ -0,0 +1,222 @@
# -*- coding: utf-8 -*-
"""
定时任务调度器
管理所有后台定时任务的启动、停止和配置
"""
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from flask import Flask
import logging
logger = logging.getLogger(__name__)
# 全局调度器实例
scheduler = None
def init_scheduler(app: Flask):
"""
初始化定时任务调度器
Args:
app: Flask应用实例
"""
global scheduler
if scheduler is not None:
logger.warning("调度器已初始化,跳过重复初始化")
return scheduler
# 创建后台调度器
scheduler = BackgroundScheduler(
timezone='UTC',
# 如果使用数据库存储作业,可以配置:
# jobstores={'default': SQLAlchemyJobStore(url=app.config['SQLALCHEMY_DATABASE_URI'])},
job_defaults={
'coalesce': False, # 当错过多个执行时机时,不合并执行
'max_instances': 1, # 同一作业的最大并发实例数
}
)
# 添加定时任务
# 1. 每小时检查一次过期卡密
scheduler.add_job(
func=check_and_update_expired_licenses,
trigger=IntervalTrigger(hours=1),
id='update_expired_licenses_hourly',
name='每小时更新过期卡密状态',
replace_existing=True,
max_instances=1
)
# 2. 每天凌晨2点执行全面检查
scheduler.add_job(
func=daily_license_health_check,
trigger=CronTrigger(hour=2, minute=0),
id='daily_license_health_check',
name='每日卡密健康检查',
replace_existing=True,
max_instances=1
)
# 3. 每周清理一次旧日志
scheduler.add_job(
func=weekly_cleanup_logs,
trigger=CronTrigger(day_of_week='sun', hour=3, minute=0),
id='weekly_cleanup_logs',
name='每周清理日志',
replace_existing=True,
max_instances=1
)
logger.info("定时任务调度器初始化完成")
logger.info("已添加以下定时任务:")
logger.info(" 1. 每小时更新过期卡密状态")
logger.info(" 2. 每天凌晨2点卡密健康检查")
logger.info(" 3. 每周日凌晨3点清理日志")
return scheduler
def start_scheduler():
"""启动调度器"""
global scheduler
if scheduler is None:
raise RuntimeError("调度器未初始化,请先调用 init_scheduler()")
if scheduler.running:
logger.warning("调度器已在运行中")
return
scheduler.start()
logger.info("定时任务调度器已启动")
def stop_scheduler():
"""停止调度器"""
global scheduler
if scheduler is None:
logger.warning("调度器未初始化")
return
if not scheduler.running:
logger.warning("调度器未运行")
return
scheduler.shutdown()
logger.info("定时任务调度器已停止")
def check_and_update_expired_licenses():
"""
检查并更新过期卡密状态
这是定时任务的包装函数,导入在函数内部以避免循环导入
"""
try:
from .background_tasks import update_expired_licenses
result = update_expired_licenses()
if result['success']:
logger.info(f"过期卡密检查完成: {result['message']}")
else:
logger.error(f"过期卡密检查失败: {result['message']}")
return result
except Exception as e:
logger.error(f"执行过期卡密检查时发生错误: {str(e)}", exc_info=True)
return {
'success': False,
'message': f'执行过期卡密检查时发生错误: {str(e)}'
}
def daily_license_health_check():
"""
每日卡密健康检查
执行全面的卡密状态检查和统计
"""
try:
from .background_tasks import check_licenses_batch
result = check_licenses_batch()
if result['success']:
stats = result.get('statistics', {})
logger.info(
f"每日卡密健康检查完成:\n"
f" 已激活但过期: {stats.get('active_but_expired', 0)}\n"
f" 已过期且已标记: {stats.get('expired_and_marked', 0)}\n"
f" 已激活且有效: {stats.get('active_and_valid', 0)}\n"
f" 未激活: {stats.get('inactive', 0)}\n"
f" 已禁用: {stats.get('disabled', 0)}"
)
else:
logger.error(f"每日卡密健康检查失败: {result['message']}")
return result
except Exception as e:
logger.error(f"执行每日卡密健康检查时发生错误: {str(e)}", exc_info=True)
return {
'success': False,
'message': f'执行每日卡密健康检查时发生错误: {str(e)}'
}
def weekly_cleanup_logs():
"""
每周清理日志
清理过期的审计日志和验证记录
"""
try:
from .background_tasks import cleanup_old_license_logs
result = cleanup_old_license_logs()
if result['success']:
logger.info(f"每周日志清理完成: {result['message']}")
else:
logger.error(f"每周日志清理失败: {result['message']}")
return result
except Exception as e:
logger.error(f"执行每周日志清理时发生错误: {str(e)}", exc_info=True)
return {
'success': False,
'message': f'执行每周日志清理时发生错误: {str(e)}'
}
def get_scheduler():
"""获取全局调度器实例"""
return scheduler
def get_job_status():
"""
获取所有定时任务的状态
Returns:
dict: 包含所有任务状态的字典
"""
if scheduler is None:
return {
'running': False,
'jobs': []
}
jobs_info = []
for job in scheduler.get_jobs():
jobs_info.append({
'id': job.id,
'name': job.name,
'next_run_time': job.next_run_time.isoformat() if job.next_run_time else None,
'trigger': str(job.trigger)
})
return {
'running': scheduler.running,
'jobs': jobs_info
}

153
app/utils/simple_crypto.py Normal file
View File

@@ -0,0 +1,153 @@
import hashlib
import base64
import secrets
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from flask import current_app
import os
class SimpleCrypto:
"""简化的加密工具类,使用 cryptography 包"""
def __init__(self, key=None):
"""
初始化加密器
:param key: 加密密钥,必须从应用配置中获取
"""
if key:
self.key = key.encode() if isinstance(key, str) else key
else:
# 从应用配置获取密钥,生产环境必须设置
key_str = current_app.config.get('AUTH_SECRET_KEY')
if not key_str:
raise ValueError("AUTH_SECRET_KEY未设置生产环境必须设置AUTH_SECRET_KEY环境变量")
self.key = key_str.encode('utf-8')
# 使用固定盐值(从密钥派生),确保同一密钥加密的数据可以解密
# 使用SHA256哈希密钥取前16字节作为盐值
salt = hashlib.sha256(b'kamixitong_salt_v1:' + self.key).digest()[:16]
kdf = PBKDF2HMAC(
algorithm=hashes.SHA256(),
length=32,
salt=salt,
iterations=100000,
)
self.fernet_key = base64.urlsafe_b64encode(kdf.derive(self.key))
self.cipher = Fernet(self.fernet_key)
def encrypt(self, data):
"""
加密数据
:param data: 要加密的数据(字符串)
:return: base64编码的加密结果
"""
try:
if isinstance(data, str):
data = data.encode('utf-8')
encrypted_data = self.cipher.encrypt(data)
result = base64.b64encode(encrypted_data)
return result.decode('utf-8')
except Exception as e:
raise ValueError(f"加密失败: {str(e)}")
def decrypt(self, encrypted_data):
"""
解密数据
:param encrypted_data: base64编码的加密数据
:return: 解密后的原始数据
"""
try:
if isinstance(encrypted_data, str):
encrypted_data = encrypted_data.encode('utf-8')
data = base64.b64decode(encrypted_data)
decrypted_data = self.cipher.decrypt(data)
return decrypted_data.decode('utf-8')
except Exception as e:
raise ValueError(f"解密失败: {str(e)}")
def generate_hash(data, salt=None):
"""
生成哈希值
:param data: 要哈希的数据
:param salt: 盐值,如果不提供则随机生成
:return: (哈希值, 盐值) 元组
"""
if salt is None:
salt = secrets.token_hex(16)
# 组合数据和盐值
combined = f"{data}{salt}".encode('utf-8')
# 生成SHA256哈希
hash_obj = hashlib.sha256(combined)
hash_value = hash_obj.hexdigest()
return hash_value, salt
def generate_signature(data, secret_key):
"""
生成签名
:param data: 要签名的数据
:param secret_key: 密钥
:return: 签名
"""
combined = f"{data}{secret_key}".encode('utf-8')
hash_obj = hashlib.sha256(combined)
return hash_obj.hexdigest()
def verify_hash(data, hash_value, salt):
"""
验证哈希值
:param data: 原始数据
:param hash_value: 要验证的哈希值
:param salt: 盐值
:return: 验证结果
"""
computed_hash, _ = generate_hash(data, salt)
return computed_hash == hash_value
def generate_token(length=32):
"""
生成随机令牌
:param length: 令牌长度
:return: 随机令牌字符串
"""
return secrets.token_urlsafe(length)
def generate_uuid():
"""
生成UUID
:return: UUID字符串
"""
import uuid
return str(uuid.uuid4())
def encrypt_password(password, salt=None):
"""
加密密码
:param password: 原始密码
:param salt: 盐值
:return: (加密后的密码, 盐值)
"""
return generate_hash(password, salt)
def verify_password(password, hashed_password, salt):
"""
验证密码
:param password: 原始密码
:param hashed_password: 加密后的密码
:param salt: 盐值
:return: 验证结果
"""
return verify_hash(password, hashed_password, salt)
# 为了兼容性,保留原来的函数名
def AESCipher(key=None):
"""兼容性包装器"""
return SimpleCrypto(key)

110
app/utils/transaction.py Normal file
View File

@@ -0,0 +1,110 @@
"""
数据库事务管理工具
提供统一的事务处理和异常处理机制
"""
from contextlib import contextmanager
from flask import current_app
from app import db
import traceback
class TransactionError(Exception):
"""事务错误基类"""
pass
class TransactionRollbackError(TransactionError):
"""事务回滚错误"""
pass
@contextmanager
def transaction(auto_commit=True, auto_rollback=True):
"""
事务上下文管理器
Args:
auto_commit: 是否在成功时自动提交
auto_rollback: 是否在失败时自动回滚
Usage:
with transaction() as session:
# 执行数据库操作
pass
"""
try:
yield db.session
if auto_commit:
db.session.commit()
except Exception as e:
if auto_rollback:
db.session.rollback()
# 记录错误日志
error_msg = f"事务执行失败: {str(e)}"
current_app.logger.error(f"{error_msg}\n{traceback.format_exc()}")
# 抛出事务错误
raise TransactionRollbackError(error_msg) from e
def safe_commit():
"""
安全提交事务
Returns:
tuple: (success: bool, error: str or None)
"""
try:
db.session.commit()
return True, None
except Exception as e:
db.session.rollback()
error_msg = f"事务提交失败: {str(e)}"
current_app.logger.error(f"{error_msg}\n{traceback.format_exc()}")
return False, error_msg
def safe_rollback():
"""
安全回滚事务
Returns:
bool: 是否回滚成功
"""
try:
db.session.rollback()
return True
except Exception as e:
current_app.logger.error(f"事务回滚失败: {str(e)}\n{traceback.format_exc()}")
return False
def execute_in_transaction(func):
"""
装饰器:在事务中执行函数
Args:
func: 要执行的函数
Returns:
函数执行结果或TransactionError
Usage:
@execute_in_transaction
def create_license():
# 创建卡密逻辑
pass
"""
def wrapper(*args, **kwargs):
try:
result = func(*args, **kwargs)
db.session.commit()
return result
except Exception as e:
db.session.rollback()
error_msg = f"事务执行失败: {str(e)}"
current_app.logger.error(f"{error_msg}\n{traceback.format_exc()}")
raise TransactionError(error_msg) from e
return wrapper

400
app/utils/validators.py Normal file
View File

@@ -0,0 +1,400 @@
import re
import hashlib
from datetime import datetime, timedelta
from typing import Tuple, Optional, Dict, Any, List
class LicenseValidator:
"""许可证验证器"""
def __init__(self, config=None):
"""
初始化验证器
:param config: 配置字典
"""
self.config = config or {}
self.max_failed_attempts = self.config.get('MAX_FAILED_ATTEMPTS', 5)
self.lockout_minutes = self.config.get('LOCKOUT_MINUTES', 10)
def validate_license_key(self, license_key: str) -> bool:
"""
验证卡密格式支持XXXX-XXXX-XXXX-XXXX格式
:param license_key: 卡密字符串
:return: 是否有效
"""
if not license_key:
return False
# 去除空格和制表符,并转为大写
license_key = license_key.strip().replace(' ', '').replace('\t', '').upper()
# 检查是否为XXXX-XXXX-XXXX-XXXX格式
if '-' in license_key:
parts = license_key.split('-')
# 应该有4部分每部分8个字符
if len(parts) == 4 and all(len(part) == 8 for part in parts):
# 检查所有字符是否为大写字母或数字
combined = ''.join(parts)
if len(combined) == 32:
pattern = r'^[A-Z0-9]+$'
import re
return bool(re.match(pattern, combined))
return False
else:
# 兼容旧格式检查长度16-32位
if len(license_key) < 16 or len(license_key) > 32:
return False
# 检查字符(只允许大写字母和数字)
pattern = r'^[A-Z0-9_]+$'
import re
return bool(re.match(pattern, license_key))
def format_license_key(self, license_key: str) -> str:
"""
格式化卡密为XXXX-XXXX-XXXX-XXXX格式
:param license_key: 原始卡密
:return: 格式化后的卡密
"""
if not license_key:
return ''
# 去除空格、制表符和连字符,并转为大写
clean_key = license_key.strip().replace(' ', '').replace('\t', '').replace('-', '').upper()
# 如果长度不足32位右补0
if len(clean_key) < 32:
clean_key = clean_key.ljust(32, '0')
# 如果长度超过32位截取前32位
elif len(clean_key) > 32:
clean_key = clean_key[:32]
# 格式化为XXXX-XXXX-XXXX-XXXX格式
formatted_key = '-'.join([
clean_key[i:i+8] for i in range(0, len(clean_key), 8)
])
return formatted_key
def check_failed_attempts(self, failed_attempts: int, last_attempt_time: datetime) -> Tuple[bool, int]:
"""
检查失败尝试次数和时间
:param failed_attempts: 失败次数
:param last_attempt_time: 最后尝试时间
:return: (是否允许尝试, 剩余锁定时间(秒))
"""
if failed_attempts < self.max_failed_attempts:
return True, 0
# 检查锁定时间是否已过
lock_time = timedelta(minutes=self.lockout_minutes)
time_passed = datetime.utcnow() - last_attempt_time
if time_passed >= lock_time:
return True, 0
remaining_seconds = int((lock_time - time_passed).total_seconds())
return False, remaining_seconds
def validate_software_version(self, version: str) -> bool:
"""
验证软件版本格式
:param version: 版本字符串
:return: 是否有效
"""
if not version:
return False
# 语义化版本格式:主版本号.次版本号.修订号
pattern = r'^\d+\.\d+\.\d+$'
return bool(re.match(pattern, version))
def compare_versions(self, version1: str, version2: str) -> int:
"""
比较版本号
:param version1: 版本1
:param version2: 版本2
:return: -1(version1<version2), 0(version1==version2), 1(version1>version2)
"""
try:
v1_parts = [int(x) for x in version1.split('.')]
v2_parts = [int(x) for x in version2.split('.')]
# 补齐版本号长度
max_len = max(len(v1_parts), len(v2_parts))
v1_parts.extend([0] * (max_len - len(v1_parts)))
v2_parts.extend([0] * (max_len - len(v2_parts)))
for v1, v2 in zip(v1_parts, v2_parts):
if v1 < v2:
return -1
elif v1 > v2:
return 1
return 0
except (ValueError, AttributeError):
return -1
def validate_machine_code(self, machine_code: str) -> bool:
"""
验证机器码格式
:param machine_code: 机器码字符串
:return: 是否有效
"""
if not machine_code:
return False
# 机器码应该是32位大写字母和数字的组合
if len(machine_code) != 32:
return False
pattern = r'^[A-F0-9]+$'
return bool(re.match(pattern, machine_code))
def create_verification_hash(self, data: Dict[str, Any], secret_key: str) -> str:
"""
创建验证哈希
:param data: 要验证的数据字典
:param secret_key: 密钥
:return: 哈希值
"""
# 按键排序确保一致性
sorted_data = sorted(data.items())
combined = '&'.join([f"{k}={v}" for k, v in sorted_data])
combined += f"&key={secret_key}"
hash_obj = hashlib.sha256(combined.encode('utf-8'))
return hash_obj.hexdigest()
def verify_hash(self, data: Dict[str, Any], hash_value: str, secret_key: str) -> bool:
"""
验证哈希值
:param data: 原始数据字典
:param hash_value: 要验证的哈希值
:param secret_key: 密钥
:return: 验证结果
"""
computed_hash = self.create_verification_hash(data, secret_key)
return computed_hash == hash_value
def is_url_safe(self, url: str) -> bool:
"""
检查URL是否安全
:param url: URL字符串
:return: 是否安全
"""
if not url:
return False
# 基本URL格式检查
pattern = r'^https?://[^\s/$.?#].[^\s]*$'
if not re.match(pattern, url):
return False
# 检查协议
if not url.startswith(('http://', 'https://')):
return False
return True
def sanitize_input(self, input_str: str) -> str:
"""
清理输入字符串
:param input_str: 输入字符串
:return: 清理后的字符串
"""
if not input_str:
return ''
# 移除特殊字符
dangerous_chars = ['<', '>', '"', "'", '&', '\x00']
for char in dangerous_chars:
input_str = input_str.replace(char, '')
# 限制长度
return input_str[:1000]
def format_license_key(license_key: str) -> str:
"""
格式化卡密的便捷函数
:param license_key: 原始卡密
:return: 格式化后的卡密
"""
validator = LicenseValidator()
return validator.format_license_key(license_key)
def validate_license_key(license_key: str) -> bool:
"""
验证卡密格式的便捷函数
:param license_key: 卡密字符串
:return: 是否有效
"""
validator = LicenseValidator()
return validator.validate_license_key(license_key)
# ==================== 通用验证工具 ====================
class ValidationError(Exception):
"""验证错误"""
pass
class Validator:
"""通用验证器类,提供链式验证"""
def __init__(self, value: Any, field_name: str = "字段"):
self.value = value
self.field_name = field_name
self.errors = []
def required(self) -> 'Validator':
"""验证必填"""
if self.value is None or (isinstance(self.value, str) and not self.value.strip()):
self.errors.append(f"{self.field_name}不能为空")
return self
def min_length(self, min_len: int) -> 'Validator':
"""验证最小长度"""
if self.value and len(self.value) < min_len:
self.errors.append(f"{self.field_name}长度不能少于{min_len}个字符")
return self
def max_length(self, max_len: int) -> 'Validator':
"""验证最大长度"""
if self.value and len(self.value) > max_len:
self.errors.append(f"{self.field_name}长度不能超过{max_len}个字符")
return self
def email(self) -> 'Validator':
"""验证邮箱"""
if self.value and not re.match(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$', self.value):
self.errors.append(f"{self.field_name}格式不正确")
return self
def phone(self) -> 'Validator':
"""验证手机号"""
if self.value and not re.match(r'^1[3-9]\d{9}$', str(self.value)):
self.errors.append(f"{self.field_name}格式不正确")
return self
def range(self, min_val: int, max_val: int) -> 'Validator':
"""验证范围"""
if self.value is not None and (self.value < min_val or self.value > max_val):
self.errors.append(f"{self.field_name}必须在{min_val}-{max_val}之间")
return self
def choice(self, choices: List[Any]) -> 'Validator':
"""验证选项"""
if self.value not in choices:
self.errors.append(f"{self.field_name}必须是以下之一: {', '.join(map(str, choices))}")
return self
def is_valid(self) -> bool:
"""验证是否通过"""
return len(self.errors) == 0
def get_errors(self) -> List[str]:
"""获取错误列表"""
return self.errors
def raise_if_invalid(self) -> None:
"""如果验证失败则抛出异常"""
if not self.is_valid():
raise ValidationError('; '.join(self.errors))
def validate_timestamp(timestamp: int, max_seconds: int = 300) -> bool:
"""
验证时间戳有效性
Args:
timestamp: 时间戳
max_seconds: 最大允许的时间差(秒)
Returns:
bool: 是否有效
"""
try:
request_time = datetime.fromtimestamp(timestamp)
current_time = datetime.utcnow()
time_diff = abs((current_time - request_time).total_seconds())
return time_diff <= max_seconds
except (ValueError, TypeError, OSError):
return False
def validate_product_id(product_id: str) -> bool:
"""
验证产品ID格式
Args:
product_id: 产品ID
Returns:
bool: 是否有效
"""
pattern = r'^PROD_[A-F0-9]{8}$|^[A-Za-z0-9_]{1,32}$'
return re.match(pattern, product_id) is not None
def sanitize_string(value: str, max_length: int = 255) -> str:
"""
清理字符串(移除危险字符)
Args:
value: 原始字符串
max_length: 最大长度
Returns:
str: 清理后的字符串
"""
if not value:
return ''
# 移除潜在的XSS攻击字符
value = value.strip()
# 截断到指定长度
if len(value) > max_length:
value = value[:max_length]
return value
def validate_filename(filename: str, allowed_extensions: Optional[List[str]] = None) -> None:
"""
验证文件名
Args:
filename: 文件名
allowed_extensions: 允许的扩展名列表
Raises:
ValidationError: 如果文件名无效
"""
if not filename:
raise ValidationError("文件名不能为空")
# 防止路径遍历攻击
if '..' in filename or '/' in filename or '\\' in filename:
raise ValidationError("文件名包含非法字符")
# 验证扩展名
if allowed_extensions:
ext = filename.rsplit('.', 1)[-1].lower() if '.' in filename else ''
if ext not in allowed_extensions:
raise ValidationError(f"文件扩展名必须是以下之一: {', '.join(allowed_extensions)}")
def validate_pagination(page: int = 1, per_page: int = 20, max_per_page: int = 100) -> tuple:
"""
验证分页参数
Args:
page: 页码
per_page: 每页数量
max_per_page: 最大每页数量
Returns:
tuple: (page, per_page) 修正后的值
"""
page = max(1, page)
per_page = min(max(1, per_page), max_per_page)
return page, per_page