第一次提交
This commit is contained in:
8
app/utils/__init__.py
Normal file
8
app/utils/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from .crypto import AESCipher, generate_hash, verify_hash
|
||||
from .machine_code import MachineCodeGenerator
|
||||
from .validators import LicenseValidator, format_license_key
|
||||
|
||||
__all__ = [
|
||||
'AESCipher', 'generate_hash', 'verify_hash',
|
||||
'MachineCodeGenerator', 'LicenseValidator', 'format_license_key'
|
||||
]
|
||||
201
app/utils/alipay.py
Normal file
201
app/utils/alipay.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
支付宝支付工具类
|
||||
"""
|
||||
|
||||
from alipay import AliPay
|
||||
from flask import current_app
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class AlipayHelper:
|
||||
"""支付宝支付助手类"""
|
||||
|
||||
def __init__(self, app=None):
|
||||
"""
|
||||
初始化支付宝客户端
|
||||
:param app: Flask应用实例
|
||||
"""
|
||||
self.alipay = None
|
||||
if app:
|
||||
self.init_app(app)
|
||||
|
||||
def init_app(self, app):
|
||||
"""
|
||||
初始化支付宝配置
|
||||
:param app: Flask应用实例
|
||||
"""
|
||||
config = app.config
|
||||
|
||||
# 检查必要的配置
|
||||
required_configs = [
|
||||
'ALIPAY_APP_ID',
|
||||
'ALIPAY_PRIVATE_KEY',
|
||||
'ALIPAY_PUBLIC_KEY',
|
||||
'ALIPAY_ALIPAY_PUBLIC_KEY'
|
||||
]
|
||||
|
||||
missing_configs = [cfg for cfg in required_configs if not config.get(cfg)]
|
||||
if missing_configs:
|
||||
raise ValueError(f"缺少必要的支付宝配置: {', '.join(missing_configs)}")
|
||||
|
||||
# 初始化支付宝客户端
|
||||
self.alipay = AliPay(
|
||||
appid=config['ALIPAY_APP_ID'],
|
||||
app_private_key=config['ALIPAY_PRIVATE_KEY'],
|
||||
alipay_public_key=config['ALIPAY_ALIPAY_PUBLIC_KEY'],
|
||||
sign_type=config.get('ALIPAY_SIGN_TYPE', 'RSA2'),
|
||||
charset=config.get('ALIPAY_CHARSET', 'utf-8'),
|
||||
gateway=config.get('ALIPAY_GATEWAY', 'https://openapi.alipay.com/gateway.do')
|
||||
)
|
||||
|
||||
def create_payment_url(self, order_number, amount, subject, notify_url, return_url):
|
||||
"""
|
||||
创建支付宝支付链接
|
||||
:param order_number: 订单号
|
||||
:param amount: 支付金额
|
||||
:param subject: 支付主题
|
||||
:param notify_url: 异步通知URL
|
||||
:param return_url: 同步返回URL
|
||||
:return: 支付链接
|
||||
"""
|
||||
if not self.alipay:
|
||||
raise ValueError("支付宝客户端未初始化")
|
||||
|
||||
# 构建订单参数
|
||||
order_params = {
|
||||
'out_trade_no': order_number, # 商户订单号
|
||||
'total_amount': str(amount), # 订单总金额
|
||||
'subject': subject, # 订单标题
|
||||
'body': f'订单{order_number}的支付', # 订单描述
|
||||
'product_code': 'FAST_INSTANT_TRADE_PAY', # 产品代码
|
||||
}
|
||||
|
||||
# 生成支付链接
|
||||
payment_url = self.alipay.trade_page_pay(
|
||||
total_amount=str(amount),
|
||||
subject=subject,
|
||||
out_trade_no=order_number,
|
||||
return_url=return_url,
|
||||
notify_url=notify_url
|
||||
)
|
||||
|
||||
# 构建完整的支付URL
|
||||
gateway = self.alipay.gateway
|
||||
full_payment_url = f"{gateway}?{payment_url}"
|
||||
|
||||
return full_payment_url
|
||||
|
||||
def create_wap_payment_url(self, order_number, amount, subject, notify_url, return_url):
|
||||
"""
|
||||
创建手机网站支付链接
|
||||
:param order_number: 订单号
|
||||
:param amount: 支付金额
|
||||
:param subject: 支付主题
|
||||
:param notify_url: 异步通知URL
|
||||
:param return_url: 同步返回URL
|
||||
:return: 支付链接
|
||||
"""
|
||||
if not self.alipay:
|
||||
raise ValueError("支付宝客户端未初始化")
|
||||
|
||||
# 生成支付链接
|
||||
payment_url = self.alipay.trade_wap_pay(
|
||||
total_amount=str(amount),
|
||||
subject=subject,
|
||||
out_trade_no=order_number,
|
||||
return_url=return_url,
|
||||
notify_url=notify_url
|
||||
)
|
||||
|
||||
# 构建完整的支付URL
|
||||
gateway = self.alipay.gateway
|
||||
full_payment_url = f"{gateway}?{payment_url}"
|
||||
|
||||
return full_payment_url
|
||||
|
||||
def query_order_status(self, order_number):
|
||||
"""
|
||||
查询订单支付状态
|
||||
:param order_number: 订单号
|
||||
:return: 订单状态信息
|
||||
"""
|
||||
if not self.alipay:
|
||||
raise ValueError("支付宝客户端未初始化")
|
||||
|
||||
try:
|
||||
result = self.alipay.alipay_trade_query(
|
||||
out_trade_no=order_number
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"查询支付宝订单状态失败: {str(e)}")
|
||||
return None
|
||||
|
||||
def verify_notification(self, data, signature):
|
||||
"""
|
||||
验证支付宝异步通知签名
|
||||
:param data: 通知数据
|
||||
:param signature: 签名
|
||||
:return: 验证结果
|
||||
"""
|
||||
if not self.alipay:
|
||||
raise ValueError("支付宝客户端未初始化")
|
||||
|
||||
try:
|
||||
# 验证签名
|
||||
return self.alipay.verify(data, signature)
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"验证支付宝通知签名失败: {str(e)}")
|
||||
return False
|
||||
|
||||
def verify_trade_status(self, notify_data):
|
||||
"""
|
||||
验证交易状态
|
||||
:param notify_data: 通知数据
|
||||
:return: 验证结果和交易状态
|
||||
"""
|
||||
if not self.alipay:
|
||||
raise ValueError("支付宝客户端未初始化")
|
||||
|
||||
try:
|
||||
# 验证签名
|
||||
if not self.alipay.verify(notify_data, notify_data.get('sign')):
|
||||
return False, None
|
||||
|
||||
# 检查交易状态
|
||||
trade_status = notify_data.get('trade_status')
|
||||
if trade_status in ['TRADE_SUCCESS', 'TRADE_FINISHED']:
|
||||
return True, trade_status
|
||||
else:
|
||||
return False, trade_status
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"验证支付宝交易状态失败: {str(e)}")
|
||||
return False, None
|
||||
|
||||
def get_trade_state(self, order_number):
|
||||
"""
|
||||
获取交易状态
|
||||
:param order_number: 订单号
|
||||
:return: 交易状态
|
||||
"""
|
||||
result = self.query_order_status(order_number)
|
||||
if not result:
|
||||
return None
|
||||
|
||||
# 根据支付宝返回的状态码判断
|
||||
code = result.get('code')
|
||||
if code == '10000': # 接口调用成功
|
||||
trade_status = result.get('trade_status')
|
||||
if trade_status == 'TRADE_SUCCESS':
|
||||
return 'SUCCESS'
|
||||
elif trade_status == 'TRADE_FINISHED':
|
||||
return 'FINISHED'
|
||||
elif trade_status == 'TRADE_CLOSED':
|
||||
return 'CLOSED'
|
||||
elif trade_status == 'WAIT_BUYER_PAY':
|
||||
return 'WAIT_PAY'
|
||||
elif code == '40004': # 业务处理失败
|
||||
return 'NOT_EXIST'
|
||||
|
||||
return None
|
||||
178
app/utils/api_response.py
Normal file
178
app/utils/api_response.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
统一的API响应工具
|
||||
提供标准化的API响应格式
|
||||
"""
|
||||
|
||||
from flask import jsonify, Response
|
||||
from typing import Any, Optional, Dict
|
||||
from werkzeug.http import HTTP_STATUS_CODES
|
||||
|
||||
|
||||
class APIResponse:
|
||||
"""API响应类"""
|
||||
|
||||
@staticmethod
|
||||
def success(data: Any = None, message: str = "操作成功", code: int = 200) -> Response:
|
||||
"""
|
||||
成功响应
|
||||
|
||||
Args:
|
||||
data: 响应数据
|
||||
message: 响应消息
|
||||
code: HTTP状态码
|
||||
|
||||
Returns:
|
||||
Response: Flask响应对象
|
||||
"""
|
||||
response_data = {
|
||||
'success': True,
|
||||
'message': message
|
||||
}
|
||||
if data is not None:
|
||||
response_data['data'] = data
|
||||
return jsonify(response_data), code
|
||||
|
||||
@staticmethod
|
||||
def error(message: str = "操作失败", code: int = 400, data: Any = None) -> Response:
|
||||
"""
|
||||
错误响应
|
||||
|
||||
Args:
|
||||
message: 错误消息
|
||||
code: HTTP状态码
|
||||
data: 额外数据
|
||||
|
||||
Returns:
|
||||
Response: Flask响应对象
|
||||
"""
|
||||
response_data = {
|
||||
'success': False,
|
||||
'message': message
|
||||
}
|
||||
if data is not None:
|
||||
response_data['data'] = data
|
||||
return jsonify(response_data), code
|
||||
|
||||
@staticmethod
|
||||
def validation_error(errors: Dict[str, str], message: str = "参数验证失败") -> Response:
|
||||
"""
|
||||
参数验证错误响应
|
||||
|
||||
Args:
|
||||
errors: 错误详情字典
|
||||
message: 错误消息
|
||||
|
||||
Returns:
|
||||
Response: Flask响应对象
|
||||
"""
|
||||
return APIResponse.error(
|
||||
message=message,
|
||||
code=400,
|
||||
data={'errors': errors}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def unauthorized(message: str = "未授权,请先登录") -> Response:
|
||||
"""
|
||||
未授权响应
|
||||
|
||||
Args:
|
||||
message: 错误消息
|
||||
|
||||
Returns:
|
||||
Response: Flask响应对象
|
||||
"""
|
||||
return APIResponse.error(message=message, code=401)
|
||||
|
||||
@staticmethod
|
||||
def forbidden(message: str = "权限不足") -> Response:
|
||||
"""
|
||||
禁止访问响应
|
||||
|
||||
Args:
|
||||
message: 错误消息
|
||||
|
||||
Returns:
|
||||
Response: Flask响应对象
|
||||
"""
|
||||
return APIResponse.error(message=message, code=403)
|
||||
|
||||
@staticmethod
|
||||
def not_found(message: str = "资源不存在") -> Response:
|
||||
"""
|
||||
资源不存在响应
|
||||
|
||||
Args:
|
||||
message: 错误消息
|
||||
|
||||
Returns:
|
||||
Response: Flask响应对象
|
||||
"""
|
||||
return APIResponse.error(message=message, code=404)
|
||||
|
||||
@staticmethod
|
||||
def server_error(message: str = "服务器内部错误") -> Response:
|
||||
"""
|
||||
服务器错误响应
|
||||
|
||||
Args:
|
||||
message: 错误消息
|
||||
|
||||
Returns:
|
||||
Response: Flask响应对象
|
||||
"""
|
||||
return APIResponse.error(message=message, code=500)
|
||||
|
||||
@staticmethod
|
||||
def paginated(items: list, total: int, page: int, per_page: int,
|
||||
message: str = "查询成功") -> Response:
|
||||
"""
|
||||
分页响应
|
||||
|
||||
Args:
|
||||
items: 数据列表
|
||||
total: 总数
|
||||
page: 当前页
|
||||
per_page: 每页数量
|
||||
message: 响应消息
|
||||
|
||||
Returns:
|
||||
Response: Flask响应对象
|
||||
"""
|
||||
import math
|
||||
total_pages = math.ceil(total / per_page) if total > 0 else 0
|
||||
|
||||
return APIResponse.success(
|
||||
data={
|
||||
'items': items,
|
||||
'pagination': {
|
||||
'current_page': page,
|
||||
'per_page': per_page,
|
||||
'total': total,
|
||||
'total_pages': total_pages,
|
||||
'has_prev': page > 1,
|
||||
'has_next': page < total_pages
|
||||
}
|
||||
},
|
||||
message=message
|
||||
)
|
||||
|
||||
|
||||
def success_response(data: Any = None, message: str = "操作成功") -> Response:
|
||||
"""成功响应的便捷函数"""
|
||||
return APIResponse.success(data=data, message=message)
|
||||
|
||||
|
||||
def error_response(message: str = "操作失败", code: int = 400) -> Response:
|
||||
"""错误响应的便捷函数"""
|
||||
return APIResponse.error(message=message, code=code)
|
||||
|
||||
|
||||
def not_found_response(message: str = "资源不存在") -> Response:
|
||||
"""资源不存在响应的便捷函数"""
|
||||
return APIResponse.not_found(message=message)
|
||||
|
||||
|
||||
def validation_error_response(errors: Dict[str, str]) -> Response:
|
||||
"""参数验证错误响应的便捷函数"""
|
||||
return APIResponse.validation_error(errors=errors)
|
||||
1001
app/utils/auth_validator.py
Normal file
1001
app/utils/auth_validator.py
Normal file
File diff suppressed because it is too large
Load Diff
181
app/utils/background_tasks.py
Normal file
181
app/utils/background_tasks.py
Normal file
@@ -0,0 +1,181 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
后台任务模块
|
||||
用于执行定时任务,如更新过期卡密状态
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from app import db
|
||||
from app.models import License
|
||||
from flask import current_app
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def update_expired_licenses():
|
||||
"""
|
||||
更新过期卡密状态
|
||||
将所有已激活但已过期的卡密状态更新为2(已过期)
|
||||
"""
|
||||
try:
|
||||
logger.info("开始检查过期卡密...")
|
||||
|
||||
# 查找所有已激活但已过期的卡密
|
||||
# 条件:status=1(已激活)且expire_time < 当前时间
|
||||
expired_licenses = License.query.filter(
|
||||
License.status == 1,
|
||||
License.expire_time.isnot(None),
|
||||
License.expire_time < datetime.utcnow()
|
||||
).all()
|
||||
|
||||
if not expired_licenses:
|
||||
logger.info("没有发现过期卡密")
|
||||
return {
|
||||
'success': True,
|
||||
'message': '没有发现过期卡密',
|
||||
'updated_count': 0
|
||||
}
|
||||
|
||||
# 更新过期卡密状态
|
||||
updated_count = 0
|
||||
for license_obj in expired_licenses:
|
||||
old_status = license_obj.status
|
||||
license_obj.status = 2 # 更新为已过期
|
||||
updated_count += 1
|
||||
logger.info(
|
||||
f"更新卡密 {license_obj.license_key} 状态: {old_status} -> {license_obj.status}"
|
||||
)
|
||||
|
||||
# 提交事务
|
||||
db.session.commit()
|
||||
|
||||
logger.info(f"成功更新 {updated_count} 个过期卡密状态")
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': f'成功更新 {updated_count} 个过期卡密状态',
|
||||
'updated_count': updated_count
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# 回滚事务
|
||||
db.session.rollback()
|
||||
logger.error(f"更新过期卡密失败: {str(e)}", exc_info=True)
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'更新过期卡密失败: {str(e)}',
|
||||
'updated_count': 0
|
||||
}
|
||||
|
||||
|
||||
def check_licenses_batch():
|
||||
"""
|
||||
批量检查所有卡密状态
|
||||
包括:
|
||||
1. 已激活但过期的 -> 更新为已过期
|
||||
2. 已过期但状态未更新的 -> 更新状态
|
||||
"""
|
||||
try:
|
||||
logger.info("开始批量检查卡密状态...")
|
||||
|
||||
# 检查1:已激活但过期的卡密
|
||||
active_but_expired = License.query.filter(
|
||||
License.status == 1, # 已激活
|
||||
License.expire_time.isnot(None),
|
||||
License.expire_time < datetime.utcnow()
|
||||
).count()
|
||||
|
||||
# 检查2:已过期但状态正确的卡密
|
||||
expired_and_marked = License.query.filter(
|
||||
License.status == 2, # 已过期
|
||||
License.expire_time.isnot(None),
|
||||
License.expire_time < datetime.utcnow()
|
||||
).count()
|
||||
|
||||
# 检查3:已激活且未过期的卡密
|
||||
active_and_valid = License.query.filter(
|
||||
License.status == 1, # 已激活
|
||||
db.or_(
|
||||
License.expire_time.is_(None), # 永久卡
|
||||
License.expire_time >= datetime.utcnow() # 未过期
|
||||
)
|
||||
).count()
|
||||
|
||||
# 检查4:未激活的卡密
|
||||
inactive = License.query.filter(
|
||||
License.status == 0 # 未激活
|
||||
).count()
|
||||
|
||||
# 检查5:已禁用的卡密
|
||||
disabled = License.query.filter(
|
||||
License.status == 3 # 已禁用
|
||||
).count()
|
||||
|
||||
logger.info(
|
||||
f"卡密状态统计:\n"
|
||||
f" 已激活但过期: {active_but_expired}\n"
|
||||
f" 已过期且已标记: {expired_and_marked}\n"
|
||||
f" 已激活且有效: {active_and_valid}\n"
|
||||
f" 未激活: {inactive}\n"
|
||||
f" 已禁用: {disabled}"
|
||||
)
|
||||
|
||||
# 执行更新
|
||||
update_result = update_expired_licenses()
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': '批量检查完成',
|
||||
'statistics': {
|
||||
'active_but_expired': active_but_expired,
|
||||
'expired_and_marked': expired_and_marked,
|
||||
'active_and_valid': active_and_valid,
|
||||
'inactive': inactive,
|
||||
'disabled': disabled
|
||||
},
|
||||
'update_result': update_result
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量检查卡密失败: {str(e)}", exc_info=True)
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'批量检查卡密失败: {str(e)}'
|
||||
}
|
||||
|
||||
|
||||
def cleanup_old_license_logs():
|
||||
"""
|
||||
清理旧的卡密验证日志
|
||||
保留最近30天的记录
|
||||
"""
|
||||
try:
|
||||
from app.models import AuditLog
|
||||
|
||||
# 保留最近30天的记录
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=30)
|
||||
|
||||
# 删除过期的审计日志
|
||||
deleted_count = AuditLog.query.filter(
|
||||
AuditLog.create_time < cutoff_date
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
logger.info(f"卡密日志清理完成,删除了 {deleted_count} 条记录")
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': f'日志清理完成,删除了 {deleted_count} 条记录',
|
||||
'deleted_count': deleted_count,
|
||||
'cutoff_date': cutoff_date
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logger.error(f"清理日志失败: {str(e)}", exc_info=True)
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'清理日志失败: {str(e)}'
|
||||
}
|
||||
23
app/utils/cors_middleware.py
Normal file
23
app/utils/cors_middleware.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from flask import request, jsonify, make_response
|
||||
from functools import wraps
|
||||
|
||||
def add_cors_headers(response):
|
||||
"""为响应添加CORS头部"""
|
||||
# 允许特定源访问
|
||||
response.headers['Access-Control-Allow-Origin'] = '*'
|
||||
response.headers['Access-Control-Allow-Methods'] = 'GET, POST, PUT, DELETE, OPTIONS'
|
||||
response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, X-Requested-With'
|
||||
response.headers['Access-Control-Max-Age'] = '86400' # 24小时
|
||||
return response
|
||||
|
||||
def cors_after(response):
|
||||
"""在每个请求后添加CORS头部"""
|
||||
return add_cors_headers(response)
|
||||
|
||||
def handle_preflight():
|
||||
"""处理预检请求"""
|
||||
if request.method == "OPTIONS":
|
||||
response = make_response()
|
||||
response = add_cors_headers(response)
|
||||
return response
|
||||
return None
|
||||
11
app/utils/crypto.py
Normal file
11
app/utils/crypto.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# 为了兼容性,导入简化的加密工具
|
||||
from .simple_crypto import (
|
||||
SimpleCrypto as AESCipher,
|
||||
generate_hash,
|
||||
verify_hash,
|
||||
generate_token,
|
||||
generate_uuid,
|
||||
encrypt_password,
|
||||
verify_password,
|
||||
generate_signature
|
||||
)
|
||||
185
app/utils/file_security.py
Normal file
185
app/utils/file_security.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
文件上传安全检查工具
|
||||
防止恶意文件上传和攻击
|
||||
"""
|
||||
import os
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Optional
|
||||
import hashlib
|
||||
import uuid
|
||||
from flask import current_app
|
||||
|
||||
|
||||
ALLOWED_EXTENSIONS = {
|
||||
# 图片
|
||||
'png', 'jpg', 'jpeg', 'gif', 'bmp', 'webp',
|
||||
# 文档
|
||||
'pdf', 'doc', 'docx', 'xls', 'xlsx', 'ppt', 'pptx', 'txt', 'csv',
|
||||
# 压缩包
|
||||
'zip', 'rar', '7z', 'tar', 'gz',
|
||||
# 其他
|
||||
'json', 'xml'
|
||||
}
|
||||
|
||||
BLOCKED_EXTENSIONS = {
|
||||
# 可执行文件
|
||||
'exe', 'bat', 'cmd', 'com', 'pif', 'scr', 'vbs', 'js', 'jar',
|
||||
# 脚本
|
||||
'sh', 'py', 'php', 'asp', 'aspx', 'jsp',
|
||||
# 系统文件
|
||||
'sys', 'dll', 'so', 'dylib',
|
||||
# 其他危险文件
|
||||
'html', 'htm', 'php', 'asp', 'aspx', 'jsp'
|
||||
}
|
||||
|
||||
|
||||
def get_file_extension(filename: str) -> str:
|
||||
"""获取文件扩展名(小写)"""
|
||||
return Path(filename).suffix.lower().lstrip('.')
|
||||
|
||||
|
||||
def check_file_extension(filename: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
检查文件扩展名
|
||||
:param filename: 文件名
|
||||
:return: (是否允许, 错误消息)
|
||||
"""
|
||||
ext = get_file_extension(filename)
|
||||
|
||||
# 检查是否在阻止列表中
|
||||
if ext in BLOCKED_EXTENSIONS:
|
||||
return False, f"不允许上传.{ext}类型的文件"
|
||||
|
||||
# 检查是否在允许列表中
|
||||
if ext not in ALLOWED_EXTENSIONS:
|
||||
return False, f"不支持的文件类型: .{ext}"
|
||||
|
||||
return True, "文件类型允许"
|
||||
|
||||
|
||||
def check_file_signature(file_path: str, allowed_mimetypes: set) -> Tuple[bool, str]:
|
||||
"""
|
||||
检查文件签名(魔数)
|
||||
:param file_path: 文件路径
|
||||
:param allowed_mimetypes: 允许的MIME类型集合
|
||||
:return: (是否通过, 错误消息)
|
||||
"""
|
||||
try:
|
||||
# 读取文件头部(通常前20字节足够识别文件类型)
|
||||
with open(file_path, 'rb') as f:
|
||||
header = f.read(20)
|
||||
|
||||
# 根据文件头部判断文件类型
|
||||
# 这里只是简单示例,实际应该根据具体文件类型实现
|
||||
if header.startswith(b'\x89PNG\r\n\x1a\n'):
|
||||
mimetype = 'image/png'
|
||||
elif header.startswith(b'\xff\xd8\xff'):
|
||||
mimetype = 'image/jpeg'
|
||||
elif header.startswith(b'GIF87a') or header.startswith(b'GIF89a'):
|
||||
mimetype = 'image/gif'
|
||||
elif header.startswith(b'%PDF'):
|
||||
mimetype = 'application/pdf'
|
||||
elif header.startswith(b'PK'):
|
||||
mimetype = 'application/zip'
|
||||
else:
|
||||
mimetype = mimetypes.guess_type(file_path)[0] or 'application/octet-stream'
|
||||
|
||||
# 检查MIME类型是否允许
|
||||
if mimetype not in allowed_mimetypes:
|
||||
return False, f"文件签名验证失败: 检测到{mimetype},但不在允许列表中"
|
||||
|
||||
return True, "文件签名验证通过"
|
||||
|
||||
except Exception as e:
|
||||
return False, f"文件签名检查失败: {str(e)}"
|
||||
|
||||
|
||||
def generate_safe_filename(original_filename: str) -> str:
|
||||
"""
|
||||
生成安全的文件名
|
||||
:param original_filename: 原始文件名
|
||||
:return: 安全的文件名
|
||||
"""
|
||||
# 获取文件扩展名
|
||||
ext = get_file_extension(original_filename)
|
||||
|
||||
# 生成唯一文件名(使用UUID)
|
||||
unique_name = str(uuid.uuid4())
|
||||
|
||||
# 组合新的文件名
|
||||
safe_filename = f"{unique_name}.{ext}" if ext else unique_name
|
||||
|
||||
return safe_filename
|
||||
|
||||
|
||||
def check_file_size(file_path: str, max_size: int = None) -> Tuple[bool, str]:
|
||||
"""
|
||||
检查文件大小
|
||||
:param file_path: 文件路径
|
||||
:param max_size: 最大大小(字节),None表示使用配置中的值
|
||||
:return: (是否通过, 错误消息)
|
||||
"""
|
||||
if max_size is None:
|
||||
max_size = current_app.config.get('MAX_CONTENT_LENGTH', 50 * 1024 * 1024) # 默认50MB
|
||||
|
||||
try:
|
||||
file_size = os.path.getsize(file_path)
|
||||
if file_size > max_size:
|
||||
size_mb = file_size / (1024 * 1024)
|
||||
max_size_mb = max_size / (1024 * 1024)
|
||||
return False, f"文件大小超出限制: {size_mb:.2f}MB > {max_size_mb:.2f}MB"
|
||||
|
||||
return True, f"文件大小检查通过: {file_size / (1024 * 1024):.2f}MB"
|
||||
|
||||
except Exception as e:
|
||||
return False, f"文件大小检查失败: {str(e)}"
|
||||
|
||||
|
||||
def secure_file_upload(file_storage, upload_folder: str, allowed_mimetypes: set) -> Tuple[bool, str, str]:
|
||||
"""
|
||||
安全的文件上传
|
||||
:param file_storage: Flask的FileStorage对象
|
||||
:param upload_folder: 上传目录
|
||||
:param allowed_mimetypes: 允许的MIME类型集合
|
||||
:return: (是否成功, 消息, 文件路径)
|
||||
"""
|
||||
try:
|
||||
# 检查文件名
|
||||
if not file_storage.filename:
|
||||
return False, "文件名不能为空", ""
|
||||
|
||||
# 检查文件扩展名
|
||||
allowed, msg = check_file_extension(file_storage.filename)
|
||||
if not allowed:
|
||||
return False, msg, ""
|
||||
|
||||
# 生成安全的文件名
|
||||
safe_filename = generate_safe_filename(file_storage.filename)
|
||||
file_path = os.path.join(upload_folder, safe_filename)
|
||||
|
||||
# 确保上传目录存在
|
||||
os.makedirs(upload_folder, exist_ok=True)
|
||||
|
||||
# 保存文件
|
||||
file_storage.save(file_path)
|
||||
|
||||
# 检查文件大小
|
||||
allowed, msg = check_file_size(file_path)
|
||||
if not allowed:
|
||||
# 删除已保存的文件
|
||||
os.remove(file_path)
|
||||
return False, msg, ""
|
||||
|
||||
# 检查文件签名
|
||||
allowed, msg = check_file_signature(file_path, allowed_mimetypes)
|
||||
if not allowed:
|
||||
# 删除已保存的文件
|
||||
os.remove(file_path)
|
||||
return False, msg, ""
|
||||
|
||||
return True, "文件上传成功", file_path
|
||||
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"文件上传失败: {str(e)}")
|
||||
return False, f"文件上传失败: {str(e)}", ""
|
||||
121
app/utils/license_generator.py
Normal file
121
app/utils/license_generator.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
许可证生成工具类
|
||||
"""
|
||||
|
||||
from app import db
|
||||
from app.models import License, Product, Package
|
||||
from datetime import datetime, timedelta
|
||||
import secrets
|
||||
import string
|
||||
|
||||
|
||||
class LicenseGenerator:
|
||||
"""许可证生成器"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化"""
|
||||
pass
|
||||
|
||||
def generate_license_key(self, length=32, prefix=''):
|
||||
"""
|
||||
生成卡密(格式:XXXX-XXXX-XXXX-XXXX)
|
||||
:param length: 卡密长度
|
||||
:param prefix: 卡密前缀
|
||||
:return: 生成的卡密
|
||||
"""
|
||||
# 生成随机字符串
|
||||
chars = string.ascii_uppercase + string.digits
|
||||
# 生成32位字符(4组,每组8位)
|
||||
random_chars = ''.join(secrets.choice(chars) for _ in range(32 - len(prefix)))
|
||||
|
||||
# 格式化为XXXX-XXXX-XXXX-XXXX格式
|
||||
formatted_key = '-'.join([
|
||||
random_chars[i:i+8] for i in range(0, len(random_chars), 8)
|
||||
])
|
||||
|
||||
# 组合前缀和格式化后的密钥
|
||||
license_key = prefix + formatted_key if prefix else formatted_key
|
||||
|
||||
# 确保唯一性
|
||||
while License.query.filter_by(license_key=license_key).first():
|
||||
random_chars = ''.join(secrets.choice(chars) for _ in range(32 - len(prefix)))
|
||||
formatted_key = '-'.join([
|
||||
random_chars[i:i+8] for i in range(0, len(random_chars), 8)
|
||||
])
|
||||
license_key = prefix + formatted_key if prefix else formatted_key
|
||||
|
||||
return license_key
|
||||
|
||||
def generate_license(self, product_id, package_id, contact_person, phone, quantity=1, license_type=1):
|
||||
"""
|
||||
生成许可证
|
||||
:param product_id: 产品ID
|
||||
:param package_id: 套餐ID
|
||||
:param contact_person: 联系人
|
||||
:param phone: 手机号
|
||||
:param quantity: 数量
|
||||
:param license_type: 许可证类型 (0=试用, 1=正式)
|
||||
:return: 生成的许可证密钥
|
||||
"""
|
||||
# 查询产品和套餐
|
||||
product = Product.query.filter_by(product_id=product_id).first()
|
||||
if not product:
|
||||
raise ValueError(f"产品不存在: {product_id}")
|
||||
|
||||
package = Package.query.filter_by(package_id=package_id).first()
|
||||
if not package:
|
||||
raise ValueError(f"套餐不存在: {package_id}")
|
||||
|
||||
# 计算有效期
|
||||
if package.duration == -1: # 永久卡
|
||||
expire_time = None
|
||||
else:
|
||||
expire_time = datetime.utcnow() + timedelta(days=package.duration)
|
||||
|
||||
# 生成许可证密钥
|
||||
license_key = self.generate_license_key()
|
||||
|
||||
# 创建许可证记录
|
||||
license_obj = License(
|
||||
license_key=license_key,
|
||||
product_id=product_id,
|
||||
package_id=package_id,
|
||||
contact_person=contact_person,
|
||||
phone=phone,
|
||||
license_type=license_type,
|
||||
expire_time=expire_time,
|
||||
max_devices=package.max_devices,
|
||||
status=1 # 启用
|
||||
)
|
||||
|
||||
# 保存到数据库
|
||||
db.session.add(license_obj)
|
||||
db.session.commit()
|
||||
|
||||
return license_key
|
||||
|
||||
def generate_batch(self, product_id, package_id, contact_person, phone, count=1, license_type=1):
|
||||
"""
|
||||
批量生成许可证
|
||||
:param product_id: 产品ID
|
||||
:param package_id: 套餐ID
|
||||
:param contact_person: 联系人
|
||||
:param phone: 手机号
|
||||
:param count: 数量
|
||||
:param license_type: 许可证类型
|
||||
:return: 生成的许可证密钥列表
|
||||
"""
|
||||
license_keys = []
|
||||
for _ in range(count):
|
||||
license_key = self.generate_license(
|
||||
product_id=product_id,
|
||||
package_id=package_id,
|
||||
contact_person=contact_person,
|
||||
phone=phone,
|
||||
quantity=1,
|
||||
license_type=license_type
|
||||
)
|
||||
license_keys.append(license_key)
|
||||
|
||||
return license_keys
|
||||
51
app/utils/logger.py
Normal file
51
app/utils/logger.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from flask import request, current_app
|
||||
from flask_login import current_user
|
||||
from app.models import AuditLog
|
||||
from functools import wraps
|
||||
import json
|
||||
|
||||
def log_operation(action, target_type, target_id=None, details=None):
|
||||
"""记录操作日志的工具函数"""
|
||||
try:
|
||||
# 获取当前用户信息
|
||||
admin_id = getattr(current_user, 'admin_id', None) if hasattr(current_user, 'is_authenticated') and current_user.is_authenticated else None
|
||||
|
||||
# 获取客户端IP
|
||||
ip_address = request.headers.get('X-Forwarded-For', request.remote_addr)
|
||||
|
||||
# 获取用户代理
|
||||
user_agent = request.headers.get('User-Agent', '')
|
||||
|
||||
# 记录审计日志
|
||||
AuditLog.log_action(
|
||||
admin_id=admin_id,
|
||||
action=action,
|
||||
target_type=target_type,
|
||||
target_id=target_id,
|
||||
details=details,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent
|
||||
)
|
||||
except Exception as e:
|
||||
if hasattr(current_app, 'logger'):
|
||||
current_app.logger.error(f"记录操作日志失败: {str(e)}")
|
||||
|
||||
def log_operations(action, target_type):
|
||||
"""记录操作日志的装饰器"""
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
def decorated_function(*args, **kwargs):
|
||||
try:
|
||||
# 执行原函数
|
||||
result = f(*args, **kwargs)
|
||||
|
||||
# 记录成功日志
|
||||
log_operation(action, target_type)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
# 记录错误日志
|
||||
log_operation(f"{action}_ERROR", target_type, details={'error': str(e)})
|
||||
raise e
|
||||
return decorated_function
|
||||
return decorator
|
||||
296
app/utils/machine_code.py
Normal file
296
app/utils/machine_code.py
Normal file
@@ -0,0 +1,296 @@
|
||||
import hashlib
|
||||
import platform
|
||||
import subprocess
|
||||
import uuid
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
class MachineCodeGenerator:
|
||||
"""机器码生成器"""
|
||||
|
||||
def __init__(self):
|
||||
self._machine_code_cache = None
|
||||
|
||||
def generate(self) -> str:
|
||||
"""
|
||||
生成机器码
|
||||
:return: 32位哈希字符串
|
||||
"""
|
||||
if self._machine_code_cache:
|
||||
return self._machine_code_cache
|
||||
|
||||
# 收集硬件信息
|
||||
hw_info = []
|
||||
|
||||
try:
|
||||
# 1. 获取主板序列号
|
||||
board_serial = self._get_board_serial()
|
||||
if board_serial:
|
||||
hw_info.append(board_serial)
|
||||
|
||||
# 2. 获取CPU ID
|
||||
cpu_id = self._get_cpu_id()
|
||||
if cpu_id:
|
||||
hw_info.append(cpu_id)
|
||||
|
||||
# 3. 获取硬盘序列号
|
||||
disk_serial = self._get_disk_serial()
|
||||
if disk_serial:
|
||||
hw_info.append(disk_serial)
|
||||
|
||||
# 4. 获取系统UUID
|
||||
system_uuid = self._get_system_uuid()
|
||||
if system_uuid:
|
||||
hw_info.append(system_uuid)
|
||||
|
||||
# 5. 获取MAC地址
|
||||
mac_address = self._get_mac_address()
|
||||
if mac_address:
|
||||
hw_info.append(mac_address)
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取硬件信息时出错: {e}")
|
||||
|
||||
# 如果没有获取到任何硬件信息,使用备用方案
|
||||
if not hw_info:
|
||||
hw_info = [str(uuid.uuid4()), platform.node()]
|
||||
|
||||
# 组合所有硬件信息并生成哈希
|
||||
combined_info = '|'.join(hw_info)
|
||||
hash_obj = hashlib.sha256(combined_info.encode('utf-8'))
|
||||
machine_code = hash_obj.hexdigest()[:32].upper()
|
||||
|
||||
self._machine_code_cache = machine_code
|
||||
return machine_code
|
||||
|
||||
def _get_board_serial(self) -> Optional[str]:
|
||||
"""获取主板序列号"""
|
||||
try:
|
||||
system = platform.system().lower()
|
||||
|
||||
if system == 'windows':
|
||||
# Windows系统
|
||||
result = subprocess.run(
|
||||
'wmic baseboard get serialnumber',
|
||||
shell=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10
|
||||
)
|
||||
if result.returncode == 0:
|
||||
lines = result.stdout.strip().split('\n')
|
||||
for line in lines:
|
||||
if line.strip() and 'SerialNumber' not in line:
|
||||
return line.strip()
|
||||
|
||||
elif system == 'linux':
|
||||
# Linux系统
|
||||
for path in ['/sys/class/dmi/id/board_serial', '/sys/class/dmi/id/product_uuid']:
|
||||
if os.path.exists(path):
|
||||
with open(path, 'r') as f:
|
||||
content = f.read().strip()
|
||||
if content:
|
||||
return content
|
||||
|
||||
elif system == 'darwin':
|
||||
# macOS系统
|
||||
result = subprocess.run(
|
||||
'system_profiler SPHardwareDataType | grep "Serial Number"',
|
||||
shell=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10
|
||||
)
|
||||
if result.returncode == 0:
|
||||
match = re.search(r'Serial Number:\s*(.+)', result.stdout)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def _get_cpu_id(self) -> Optional[str]:
|
||||
"""获取CPU ID"""
|
||||
try:
|
||||
system = platform.system().lower()
|
||||
|
||||
if system == 'windows':
|
||||
# Windows系统
|
||||
result = subprocess.run(
|
||||
'wmic cpu get processorid',
|
||||
shell=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10
|
||||
)
|
||||
if result.returncode == 0:
|
||||
lines = result.stdout.strip().split('\n')
|
||||
for line in lines:
|
||||
if line.strip() and 'ProcessorId' not in line:
|
||||
return line.strip()
|
||||
|
||||
elif system == 'linux':
|
||||
# Linux系统
|
||||
result = subprocess.run(
|
||||
'cat /proc/cpuinfo | grep -i "processor id" | head -1',
|
||||
shell=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10
|
||||
)
|
||||
if result.returncode == 0:
|
||||
match = re.search(r':\s*(.+)', result.stdout)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
|
||||
elif system == 'darwin':
|
||||
# macOS系统
|
||||
result = subprocess.run(
|
||||
'system_profiler SPHardwareDataType | grep "Processor Name"',
|
||||
shell=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10
|
||||
)
|
||||
if result.returncode == 0:
|
||||
match = re.search(r'Processor Name:\s*(.+)', result.stdout)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def _get_disk_serial(self) -> Optional[str]:
|
||||
"""获取硬盘序列号"""
|
||||
try:
|
||||
system = platform.system().lower()
|
||||
|
||||
if system == 'windows':
|
||||
# Windows系统
|
||||
result = subprocess.run(
|
||||
'wmic diskdrive get serialnumber',
|
||||
shell=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10
|
||||
)
|
||||
if result.returncode == 0:
|
||||
lines = result.stdout.strip().split('\n')
|
||||
for line in lines:
|
||||
if line.strip() and 'SerialNumber' not in line:
|
||||
return line.strip()
|
||||
|
||||
elif system == 'linux':
|
||||
# Linux系统
|
||||
result = subprocess.run(
|
||||
'lsblk -d -o serial | head -2 | tail -1',
|
||||
shell=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10
|
||||
)
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
return result.stdout.strip()
|
||||
|
||||
elif system == 'darwin':
|
||||
# macOS系统
|
||||
result = subprocess.run(
|
||||
'diskutil info / | grep "Serial Number"',
|
||||
shell=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10
|
||||
)
|
||||
if result.returncode == 0:
|
||||
match = re.search(r'Serial Number:\s*(.+)', result.stdout)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def _get_system_uuid(self) -> Optional[str]:
|
||||
"""获取系统UUID"""
|
||||
try:
|
||||
# 尝试使用Python的uuid模块获取
|
||||
system_uuid = str(uuid.getnode())
|
||||
if system_uuid and system_uuid != '0':
|
||||
return system_uuid
|
||||
|
||||
# 备用方案:读取系统UUID文件
|
||||
system = platform.system().lower()
|
||||
|
||||
if system == 'linux':
|
||||
for path in ['/etc/machine-id', '/var/lib/dbus/machine-id']:
|
||||
if os.path.exists(path):
|
||||
with open(path, 'r') as f:
|
||||
content = f.read().strip()
|
||||
if content:
|
||||
return content
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def _get_mac_address(self) -> Optional[str]:
|
||||
"""获取MAC地址"""
|
||||
try:
|
||||
# 获取第一个非回环网络接口的MAC地址
|
||||
import uuid
|
||||
mac = uuid.getnode()
|
||||
if mac != 0:
|
||||
# 格式化MAC地址
|
||||
mac_str = ':'.join([f'{(mac >> 8*i) & 0xff:02x}' for i in range(6)])
|
||||
return mac_str.upper()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def save_to_cache(self, file_path: str = '.machine_code') -> bool:
|
||||
"""
|
||||
保存机器码到缓存文件
|
||||
:param file_path: 缓存文件路径
|
||||
:return: 是否保存成功
|
||||
"""
|
||||
try:
|
||||
machine_code = self.generate()
|
||||
with open(file_path, 'w') as f:
|
||||
f.write(machine_code)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def load_from_cache(self, file_path: str = '.machine_code') -> Optional[str]:
|
||||
"""
|
||||
从缓存文件加载机器码
|
||||
:param file_path: 缓存文件路径
|
||||
:return: 机器码,如果加载失败返回None
|
||||
"""
|
||||
try:
|
||||
if os.path.exists(file_path):
|
||||
with open(file_path, 'r') as f:
|
||||
machine_code = f.read().strip()
|
||||
if machine_code and len(machine_code) == 32:
|
||||
self._machine_code_cache = machine_code
|
||||
return machine_code
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def generate_machine_code() -> str:
|
||||
"""
|
||||
生成机器码的便捷函数
|
||||
:return: 32位机器码字符串
|
||||
"""
|
||||
generator = MachineCodeGenerator()
|
||||
return generator.generate()
|
||||
222
app/utils/scheduler.py
Normal file
222
app/utils/scheduler.py
Normal file
@@ -0,0 +1,222 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
定时任务调度器
|
||||
管理所有后台定时任务的启动、停止和配置
|
||||
"""
|
||||
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from flask import Flask
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 全局调度器实例
|
||||
scheduler = None
|
||||
|
||||
|
||||
def init_scheduler(app: Flask):
|
||||
"""
|
||||
初始化定时任务调度器
|
||||
|
||||
Args:
|
||||
app: Flask应用实例
|
||||
"""
|
||||
global scheduler
|
||||
|
||||
if scheduler is not None:
|
||||
logger.warning("调度器已初始化,跳过重复初始化")
|
||||
return scheduler
|
||||
|
||||
# 创建后台调度器
|
||||
scheduler = BackgroundScheduler(
|
||||
timezone='UTC',
|
||||
# 如果使用数据库存储作业,可以配置:
|
||||
# jobstores={'default': SQLAlchemyJobStore(url=app.config['SQLALCHEMY_DATABASE_URI'])},
|
||||
job_defaults={
|
||||
'coalesce': False, # 当错过多个执行时机时,不合并执行
|
||||
'max_instances': 1, # 同一作业的最大并发实例数
|
||||
}
|
||||
)
|
||||
|
||||
# 添加定时任务
|
||||
|
||||
# 1. 每小时检查一次过期卡密
|
||||
scheduler.add_job(
|
||||
func=check_and_update_expired_licenses,
|
||||
trigger=IntervalTrigger(hours=1),
|
||||
id='update_expired_licenses_hourly',
|
||||
name='每小时更新过期卡密状态',
|
||||
replace_existing=True,
|
||||
max_instances=1
|
||||
)
|
||||
|
||||
# 2. 每天凌晨2点执行全面检查
|
||||
scheduler.add_job(
|
||||
func=daily_license_health_check,
|
||||
trigger=CronTrigger(hour=2, minute=0),
|
||||
id='daily_license_health_check',
|
||||
name='每日卡密健康检查',
|
||||
replace_existing=True,
|
||||
max_instances=1
|
||||
)
|
||||
|
||||
# 3. 每周清理一次旧日志
|
||||
scheduler.add_job(
|
||||
func=weekly_cleanup_logs,
|
||||
trigger=CronTrigger(day_of_week='sun', hour=3, minute=0),
|
||||
id='weekly_cleanup_logs',
|
||||
name='每周清理日志',
|
||||
replace_existing=True,
|
||||
max_instances=1
|
||||
)
|
||||
|
||||
logger.info("定时任务调度器初始化完成")
|
||||
logger.info("已添加以下定时任务:")
|
||||
logger.info(" 1. 每小时更新过期卡密状态")
|
||||
logger.info(" 2. 每天凌晨2点卡密健康检查")
|
||||
logger.info(" 3. 每周日凌晨3点清理日志")
|
||||
|
||||
return scheduler
|
||||
|
||||
|
||||
def start_scheduler():
|
||||
"""启动调度器"""
|
||||
global scheduler
|
||||
|
||||
if scheduler is None:
|
||||
raise RuntimeError("调度器未初始化,请先调用 init_scheduler()")
|
||||
|
||||
if scheduler.running:
|
||||
logger.warning("调度器已在运行中")
|
||||
return
|
||||
|
||||
scheduler.start()
|
||||
logger.info("定时任务调度器已启动")
|
||||
|
||||
|
||||
def stop_scheduler():
|
||||
"""停止调度器"""
|
||||
global scheduler
|
||||
|
||||
if scheduler is None:
|
||||
logger.warning("调度器未初始化")
|
||||
return
|
||||
|
||||
if not scheduler.running:
|
||||
logger.warning("调度器未运行")
|
||||
return
|
||||
|
||||
scheduler.shutdown()
|
||||
logger.info("定时任务调度器已停止")
|
||||
|
||||
|
||||
def check_and_update_expired_licenses():
|
||||
"""
|
||||
检查并更新过期卡密状态
|
||||
这是定时任务的包装函数,导入在函数内部以避免循环导入
|
||||
"""
|
||||
try:
|
||||
from .background_tasks import update_expired_licenses
|
||||
result = update_expired_licenses()
|
||||
|
||||
if result['success']:
|
||||
logger.info(f"过期卡密检查完成: {result['message']}")
|
||||
else:
|
||||
logger.error(f"过期卡密检查失败: {result['message']}")
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"执行过期卡密检查时发生错误: {str(e)}", exc_info=True)
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'执行过期卡密检查时发生错误: {str(e)}'
|
||||
}
|
||||
|
||||
|
||||
def daily_license_health_check():
|
||||
"""
|
||||
每日卡密健康检查
|
||||
执行全面的卡密状态检查和统计
|
||||
"""
|
||||
try:
|
||||
from .background_tasks import check_licenses_batch
|
||||
result = check_licenses_batch()
|
||||
|
||||
if result['success']:
|
||||
stats = result.get('statistics', {})
|
||||
logger.info(
|
||||
f"每日卡密健康检查完成:\n"
|
||||
f" 已激活但过期: {stats.get('active_but_expired', 0)}\n"
|
||||
f" 已过期且已标记: {stats.get('expired_and_marked', 0)}\n"
|
||||
f" 已激活且有效: {stats.get('active_and_valid', 0)}\n"
|
||||
f" 未激活: {stats.get('inactive', 0)}\n"
|
||||
f" 已禁用: {stats.get('disabled', 0)}"
|
||||
)
|
||||
else:
|
||||
logger.error(f"每日卡密健康检查失败: {result['message']}")
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"执行每日卡密健康检查时发生错误: {str(e)}", exc_info=True)
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'执行每日卡密健康检查时发生错误: {str(e)}'
|
||||
}
|
||||
|
||||
|
||||
def weekly_cleanup_logs():
|
||||
"""
|
||||
每周清理日志
|
||||
清理过期的审计日志和验证记录
|
||||
"""
|
||||
try:
|
||||
from .background_tasks import cleanup_old_license_logs
|
||||
result = cleanup_old_license_logs()
|
||||
|
||||
if result['success']:
|
||||
logger.info(f"每周日志清理完成: {result['message']}")
|
||||
else:
|
||||
logger.error(f"每周日志清理失败: {result['message']}")
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"执行每周日志清理时发生错误: {str(e)}", exc_info=True)
|
||||
return {
|
||||
'success': False,
|
||||
'message': f'执行每周日志清理时发生错误: {str(e)}'
|
||||
}
|
||||
|
||||
|
||||
def get_scheduler():
|
||||
"""获取全局调度器实例"""
|
||||
return scheduler
|
||||
|
||||
|
||||
def get_job_status():
|
||||
"""
|
||||
获取所有定时任务的状态
|
||||
|
||||
Returns:
|
||||
dict: 包含所有任务状态的字典
|
||||
"""
|
||||
if scheduler is None:
|
||||
return {
|
||||
'running': False,
|
||||
'jobs': []
|
||||
}
|
||||
|
||||
jobs_info = []
|
||||
for job in scheduler.get_jobs():
|
||||
jobs_info.append({
|
||||
'id': job.id,
|
||||
'name': job.name,
|
||||
'next_run_time': job.next_run_time.isoformat() if job.next_run_time else None,
|
||||
'trigger': str(job.trigger)
|
||||
})
|
||||
|
||||
return {
|
||||
'running': scheduler.running,
|
||||
'jobs': jobs_info
|
||||
}
|
||||
153
app/utils/simple_crypto.py
Normal file
153
app/utils/simple_crypto.py
Normal file
@@ -0,0 +1,153 @@
|
||||
import hashlib
|
||||
import base64
|
||||
import secrets
|
||||
from cryptography.fernet import Fernet
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
from flask import current_app
|
||||
import os
|
||||
|
||||
class SimpleCrypto:
|
||||
"""简化的加密工具类,使用 cryptography 包"""
|
||||
|
||||
def __init__(self, key=None):
|
||||
"""
|
||||
初始化加密器
|
||||
:param key: 加密密钥,必须从应用配置中获取
|
||||
"""
|
||||
if key:
|
||||
self.key = key.encode() if isinstance(key, str) else key
|
||||
else:
|
||||
# 从应用配置获取密钥,生产环境必须设置
|
||||
key_str = current_app.config.get('AUTH_SECRET_KEY')
|
||||
if not key_str:
|
||||
raise ValueError("AUTH_SECRET_KEY未设置!生产环境必须设置AUTH_SECRET_KEY环境变量!")
|
||||
self.key = key_str.encode('utf-8')
|
||||
|
||||
# 使用固定盐值(从密钥派生),确保同一密钥加密的数据可以解密
|
||||
# 使用SHA256哈希密钥,取前16字节作为盐值
|
||||
salt = hashlib.sha256(b'kamixitong_salt_v1:' + self.key).digest()[:16]
|
||||
|
||||
kdf = PBKDF2HMAC(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=32,
|
||||
salt=salt,
|
||||
iterations=100000,
|
||||
)
|
||||
self.fernet_key = base64.urlsafe_b64encode(kdf.derive(self.key))
|
||||
self.cipher = Fernet(self.fernet_key)
|
||||
|
||||
def encrypt(self, data):
|
||||
"""
|
||||
加密数据
|
||||
:param data: 要加密的数据(字符串)
|
||||
:return: base64编码的加密结果
|
||||
"""
|
||||
try:
|
||||
if isinstance(data, str):
|
||||
data = data.encode('utf-8')
|
||||
|
||||
encrypted_data = self.cipher.encrypt(data)
|
||||
result = base64.b64encode(encrypted_data)
|
||||
return result.decode('utf-8')
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"加密失败: {str(e)}")
|
||||
|
||||
def decrypt(self, encrypted_data):
|
||||
"""
|
||||
解密数据
|
||||
:param encrypted_data: base64编码的加密数据
|
||||
:return: 解密后的原始数据
|
||||
"""
|
||||
try:
|
||||
if isinstance(encrypted_data, str):
|
||||
encrypted_data = encrypted_data.encode('utf-8')
|
||||
|
||||
data = base64.b64decode(encrypted_data)
|
||||
decrypted_data = self.cipher.decrypt(data)
|
||||
return decrypted_data.decode('utf-8')
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"解密失败: {str(e)}")
|
||||
|
||||
def generate_hash(data, salt=None):
|
||||
"""
|
||||
生成哈希值
|
||||
:param data: 要哈希的数据
|
||||
:param salt: 盐值,如果不提供则随机生成
|
||||
:return: (哈希值, 盐值) 元组
|
||||
"""
|
||||
if salt is None:
|
||||
salt = secrets.token_hex(16)
|
||||
|
||||
# 组合数据和盐值
|
||||
combined = f"{data}{salt}".encode('utf-8')
|
||||
|
||||
# 生成SHA256哈希
|
||||
hash_obj = hashlib.sha256(combined)
|
||||
hash_value = hash_obj.hexdigest()
|
||||
|
||||
return hash_value, salt
|
||||
|
||||
def generate_signature(data, secret_key):
|
||||
"""
|
||||
生成签名
|
||||
:param data: 要签名的数据
|
||||
:param secret_key: 密钥
|
||||
:return: 签名
|
||||
"""
|
||||
combined = f"{data}{secret_key}".encode('utf-8')
|
||||
hash_obj = hashlib.sha256(combined)
|
||||
return hash_obj.hexdigest()
|
||||
|
||||
def verify_hash(data, hash_value, salt):
|
||||
"""
|
||||
验证哈希值
|
||||
:param data: 原始数据
|
||||
:param hash_value: 要验证的哈希值
|
||||
:param salt: 盐值
|
||||
:return: 验证结果
|
||||
"""
|
||||
computed_hash, _ = generate_hash(data, salt)
|
||||
return computed_hash == hash_value
|
||||
|
||||
def generate_token(length=32):
|
||||
"""
|
||||
生成随机令牌
|
||||
:param length: 令牌长度
|
||||
:return: 随机令牌字符串
|
||||
"""
|
||||
return secrets.token_urlsafe(length)
|
||||
|
||||
def generate_uuid():
|
||||
"""
|
||||
生成UUID
|
||||
:return: UUID字符串
|
||||
"""
|
||||
import uuid
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def encrypt_password(password, salt=None):
|
||||
"""
|
||||
加密密码
|
||||
:param password: 原始密码
|
||||
:param salt: 盐值
|
||||
:return: (加密后的密码, 盐值)
|
||||
"""
|
||||
return generate_hash(password, salt)
|
||||
|
||||
def verify_password(password, hashed_password, salt):
|
||||
"""
|
||||
验证密码
|
||||
:param password: 原始密码
|
||||
:param hashed_password: 加密后的密码
|
||||
:param salt: 盐值
|
||||
:return: 验证结果
|
||||
"""
|
||||
return verify_hash(password, hashed_password, salt)
|
||||
|
||||
# 为了兼容性,保留原来的函数名
|
||||
def AESCipher(key=None):
|
||||
"""兼容性包装器"""
|
||||
return SimpleCrypto(key)
|
||||
110
app/utils/transaction.py
Normal file
110
app/utils/transaction.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
数据库事务管理工具
|
||||
提供统一的事务处理和异常处理机制
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from flask import current_app
|
||||
from app import db
|
||||
import traceback
|
||||
|
||||
|
||||
class TransactionError(Exception):
|
||||
"""事务错误基类"""
|
||||
pass
|
||||
|
||||
|
||||
class TransactionRollbackError(TransactionError):
|
||||
"""事务回滚错误"""
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def transaction(auto_commit=True, auto_rollback=True):
|
||||
"""
|
||||
事务上下文管理器
|
||||
|
||||
Args:
|
||||
auto_commit: 是否在成功时自动提交
|
||||
auto_rollback: 是否在失败时自动回滚
|
||||
|
||||
Usage:
|
||||
with transaction() as session:
|
||||
# 执行数据库操作
|
||||
pass
|
||||
"""
|
||||
try:
|
||||
yield db.session
|
||||
if auto_commit:
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
if auto_rollback:
|
||||
db.session.rollback()
|
||||
|
||||
# 记录错误日志
|
||||
error_msg = f"事务执行失败: {str(e)}"
|
||||
current_app.logger.error(f"{error_msg}\n{traceback.format_exc()}")
|
||||
|
||||
# 抛出事务错误
|
||||
raise TransactionRollbackError(error_msg) from e
|
||||
|
||||
|
||||
def safe_commit():
|
||||
"""
|
||||
安全提交事务
|
||||
|
||||
Returns:
|
||||
tuple: (success: bool, error: str or None)
|
||||
"""
|
||||
try:
|
||||
db.session.commit()
|
||||
return True, None
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
error_msg = f"事务提交失败: {str(e)}"
|
||||
current_app.logger.error(f"{error_msg}\n{traceback.format_exc()}")
|
||||
return False, error_msg
|
||||
|
||||
|
||||
def safe_rollback():
|
||||
"""
|
||||
安全回滚事务
|
||||
|
||||
Returns:
|
||||
bool: 是否回滚成功
|
||||
"""
|
||||
try:
|
||||
db.session.rollback()
|
||||
return True
|
||||
except Exception as e:
|
||||
current_app.logger.error(f"事务回滚失败: {str(e)}\n{traceback.format_exc()}")
|
||||
return False
|
||||
|
||||
|
||||
def execute_in_transaction(func):
|
||||
"""
|
||||
装饰器:在事务中执行函数
|
||||
|
||||
Args:
|
||||
func: 要执行的函数
|
||||
|
||||
Returns:
|
||||
函数执行结果或TransactionError
|
||||
|
||||
Usage:
|
||||
@execute_in_transaction
|
||||
def create_license():
|
||||
# 创建卡密逻辑
|
||||
pass
|
||||
"""
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
db.session.commit()
|
||||
return result
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
error_msg = f"事务执行失败: {str(e)}"
|
||||
current_app.logger.error(f"{error_msg}\n{traceback.format_exc()}")
|
||||
raise TransactionError(error_msg) from e
|
||||
return wrapper
|
||||
400
app/utils/validators.py
Normal file
400
app/utils/validators.py
Normal file
@@ -0,0 +1,400 @@
|
||||
import re
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Tuple, Optional, Dict, Any, List
|
||||
|
||||
class LicenseValidator:
|
||||
"""许可证验证器"""
|
||||
|
||||
def __init__(self, config=None):
|
||||
"""
|
||||
初始化验证器
|
||||
:param config: 配置字典
|
||||
"""
|
||||
self.config = config or {}
|
||||
self.max_failed_attempts = self.config.get('MAX_FAILED_ATTEMPTS', 5)
|
||||
self.lockout_minutes = self.config.get('LOCKOUT_MINUTES', 10)
|
||||
|
||||
def validate_license_key(self, license_key: str) -> bool:
|
||||
"""
|
||||
验证卡密格式(支持XXXX-XXXX-XXXX-XXXX格式)
|
||||
:param license_key: 卡密字符串
|
||||
:return: 是否有效
|
||||
"""
|
||||
if not license_key:
|
||||
return False
|
||||
|
||||
# 去除空格和制表符,并转为大写
|
||||
license_key = license_key.strip().replace(' ', '').replace('\t', '').upper()
|
||||
|
||||
# 检查是否为XXXX-XXXX-XXXX-XXXX格式
|
||||
if '-' in license_key:
|
||||
parts = license_key.split('-')
|
||||
# 应该有4部分,每部分8个字符
|
||||
if len(parts) == 4 and all(len(part) == 8 for part in parts):
|
||||
# 检查所有字符是否为大写字母或数字
|
||||
combined = ''.join(parts)
|
||||
if len(combined) == 32:
|
||||
pattern = r'^[A-Z0-9]+$'
|
||||
import re
|
||||
return bool(re.match(pattern, combined))
|
||||
return False
|
||||
else:
|
||||
# 兼容旧格式:检查长度(16-32位)
|
||||
if len(license_key) < 16 or len(license_key) > 32:
|
||||
return False
|
||||
|
||||
# 检查字符(只允许大写字母和数字)
|
||||
pattern = r'^[A-Z0-9_]+$'
|
||||
import re
|
||||
return bool(re.match(pattern, license_key))
|
||||
|
||||
def format_license_key(self, license_key: str) -> str:
|
||||
"""
|
||||
格式化卡密为XXXX-XXXX-XXXX-XXXX格式
|
||||
:param license_key: 原始卡密
|
||||
:return: 格式化后的卡密
|
||||
"""
|
||||
if not license_key:
|
||||
return ''
|
||||
|
||||
# 去除空格、制表符和连字符,并转为大写
|
||||
clean_key = license_key.strip().replace(' ', '').replace('\t', '').replace('-', '').upper()
|
||||
|
||||
# 如果长度不足32位,右补0
|
||||
if len(clean_key) < 32:
|
||||
clean_key = clean_key.ljust(32, '0')
|
||||
# 如果长度超过32位,截取前32位
|
||||
elif len(clean_key) > 32:
|
||||
clean_key = clean_key[:32]
|
||||
|
||||
# 格式化为XXXX-XXXX-XXXX-XXXX格式
|
||||
formatted_key = '-'.join([
|
||||
clean_key[i:i+8] for i in range(0, len(clean_key), 8)
|
||||
])
|
||||
|
||||
return formatted_key
|
||||
|
||||
def check_failed_attempts(self, failed_attempts: int, last_attempt_time: datetime) -> Tuple[bool, int]:
|
||||
"""
|
||||
检查失败尝试次数和时间
|
||||
:param failed_attempts: 失败次数
|
||||
:param last_attempt_time: 最后尝试时间
|
||||
:return: (是否允许尝试, 剩余锁定时间(秒))
|
||||
"""
|
||||
if failed_attempts < self.max_failed_attempts:
|
||||
return True, 0
|
||||
|
||||
# 检查锁定时间是否已过
|
||||
lock_time = timedelta(minutes=self.lockout_minutes)
|
||||
time_passed = datetime.utcnow() - last_attempt_time
|
||||
|
||||
if time_passed >= lock_time:
|
||||
return True, 0
|
||||
|
||||
remaining_seconds = int((lock_time - time_passed).total_seconds())
|
||||
return False, remaining_seconds
|
||||
|
||||
def validate_software_version(self, version: str) -> bool:
|
||||
"""
|
||||
验证软件版本格式
|
||||
:param version: 版本字符串
|
||||
:return: 是否有效
|
||||
"""
|
||||
if not version:
|
||||
return False
|
||||
|
||||
# 语义化版本格式:主版本号.次版本号.修订号
|
||||
pattern = r'^\d+\.\d+\.\d+$'
|
||||
return bool(re.match(pattern, version))
|
||||
|
||||
def compare_versions(self, version1: str, version2: str) -> int:
|
||||
"""
|
||||
比较版本号
|
||||
:param version1: 版本1
|
||||
:param version2: 版本2
|
||||
:return: -1(version1<version2), 0(version1==version2), 1(version1>version2)
|
||||
"""
|
||||
try:
|
||||
v1_parts = [int(x) for x in version1.split('.')]
|
||||
v2_parts = [int(x) for x in version2.split('.')]
|
||||
|
||||
# 补齐版本号长度
|
||||
max_len = max(len(v1_parts), len(v2_parts))
|
||||
v1_parts.extend([0] * (max_len - len(v1_parts)))
|
||||
v2_parts.extend([0] * (max_len - len(v2_parts)))
|
||||
|
||||
for v1, v2 in zip(v1_parts, v2_parts):
|
||||
if v1 < v2:
|
||||
return -1
|
||||
elif v1 > v2:
|
||||
return 1
|
||||
|
||||
return 0
|
||||
except (ValueError, AttributeError):
|
||||
return -1
|
||||
|
||||
def validate_machine_code(self, machine_code: str) -> bool:
|
||||
"""
|
||||
验证机器码格式
|
||||
:param machine_code: 机器码字符串
|
||||
:return: 是否有效
|
||||
"""
|
||||
if not machine_code:
|
||||
return False
|
||||
|
||||
# 机器码应该是32位大写字母和数字的组合
|
||||
if len(machine_code) != 32:
|
||||
return False
|
||||
|
||||
pattern = r'^[A-F0-9]+$'
|
||||
return bool(re.match(pattern, machine_code))
|
||||
|
||||
def create_verification_hash(self, data: Dict[str, Any], secret_key: str) -> str:
|
||||
"""
|
||||
创建验证哈希
|
||||
:param data: 要验证的数据字典
|
||||
:param secret_key: 密钥
|
||||
:return: 哈希值
|
||||
"""
|
||||
# 按键排序确保一致性
|
||||
sorted_data = sorted(data.items())
|
||||
combined = '&'.join([f"{k}={v}" for k, v in sorted_data])
|
||||
combined += f"&key={secret_key}"
|
||||
|
||||
hash_obj = hashlib.sha256(combined.encode('utf-8'))
|
||||
return hash_obj.hexdigest()
|
||||
|
||||
def verify_hash(self, data: Dict[str, Any], hash_value: str, secret_key: str) -> bool:
|
||||
"""
|
||||
验证哈希值
|
||||
:param data: 原始数据字典
|
||||
:param hash_value: 要验证的哈希值
|
||||
:param secret_key: 密钥
|
||||
:return: 验证结果
|
||||
"""
|
||||
computed_hash = self.create_verification_hash(data, secret_key)
|
||||
return computed_hash == hash_value
|
||||
|
||||
def is_url_safe(self, url: str) -> bool:
|
||||
"""
|
||||
检查URL是否安全
|
||||
:param url: URL字符串
|
||||
:return: 是否安全
|
||||
"""
|
||||
if not url:
|
||||
return False
|
||||
|
||||
# 基本URL格式检查
|
||||
pattern = r'^https?://[^\s/$.?#].[^\s]*$'
|
||||
if not re.match(pattern, url):
|
||||
return False
|
||||
|
||||
# 检查协议
|
||||
if not url.startswith(('http://', 'https://')):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def sanitize_input(self, input_str: str) -> str:
|
||||
"""
|
||||
清理输入字符串
|
||||
:param input_str: 输入字符串
|
||||
:return: 清理后的字符串
|
||||
"""
|
||||
if not input_str:
|
||||
return ''
|
||||
|
||||
# 移除特殊字符
|
||||
dangerous_chars = ['<', '>', '"', "'", '&', '\x00']
|
||||
for char in dangerous_chars:
|
||||
input_str = input_str.replace(char, '')
|
||||
|
||||
# 限制长度
|
||||
return input_str[:1000]
|
||||
|
||||
def format_license_key(license_key: str) -> str:
|
||||
"""
|
||||
格式化卡密的便捷函数
|
||||
:param license_key: 原始卡密
|
||||
:return: 格式化后的卡密
|
||||
"""
|
||||
validator = LicenseValidator()
|
||||
return validator.format_license_key(license_key)
|
||||
|
||||
def validate_license_key(license_key: str) -> bool:
|
||||
"""
|
||||
验证卡密格式的便捷函数
|
||||
:param license_key: 卡密字符串
|
||||
:return: 是否有效
|
||||
"""
|
||||
validator = LicenseValidator()
|
||||
return validator.validate_license_key(license_key)
|
||||
|
||||
|
||||
# ==================== 通用验证工具 ====================
|
||||
|
||||
class ValidationError(Exception):
|
||||
"""验证错误"""
|
||||
pass
|
||||
|
||||
|
||||
class Validator:
|
||||
"""通用验证器类,提供链式验证"""
|
||||
|
||||
def __init__(self, value: Any, field_name: str = "字段"):
|
||||
self.value = value
|
||||
self.field_name = field_name
|
||||
self.errors = []
|
||||
|
||||
def required(self) -> 'Validator':
|
||||
"""验证必填"""
|
||||
if self.value is None or (isinstance(self.value, str) and not self.value.strip()):
|
||||
self.errors.append(f"{self.field_name}不能为空")
|
||||
return self
|
||||
|
||||
def min_length(self, min_len: int) -> 'Validator':
|
||||
"""验证最小长度"""
|
||||
if self.value and len(self.value) < min_len:
|
||||
self.errors.append(f"{self.field_name}长度不能少于{min_len}个字符")
|
||||
return self
|
||||
|
||||
def max_length(self, max_len: int) -> 'Validator':
|
||||
"""验证最大长度"""
|
||||
if self.value and len(self.value) > max_len:
|
||||
self.errors.append(f"{self.field_name}长度不能超过{max_len}个字符")
|
||||
return self
|
||||
|
||||
def email(self) -> 'Validator':
|
||||
"""验证邮箱"""
|
||||
if self.value and not re.match(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$', self.value):
|
||||
self.errors.append(f"{self.field_name}格式不正确")
|
||||
return self
|
||||
|
||||
def phone(self) -> 'Validator':
|
||||
"""验证手机号"""
|
||||
if self.value and not re.match(r'^1[3-9]\d{9}$', str(self.value)):
|
||||
self.errors.append(f"{self.field_name}格式不正确")
|
||||
return self
|
||||
|
||||
def range(self, min_val: int, max_val: int) -> 'Validator':
|
||||
"""验证范围"""
|
||||
if self.value is not None and (self.value < min_val or self.value > max_val):
|
||||
self.errors.append(f"{self.field_name}必须在{min_val}-{max_val}之间")
|
||||
return self
|
||||
|
||||
def choice(self, choices: List[Any]) -> 'Validator':
|
||||
"""验证选项"""
|
||||
if self.value not in choices:
|
||||
self.errors.append(f"{self.field_name}必须是以下之一: {', '.join(map(str, choices))}")
|
||||
return self
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""验证是否通过"""
|
||||
return len(self.errors) == 0
|
||||
|
||||
def get_errors(self) -> List[str]:
|
||||
"""获取错误列表"""
|
||||
return self.errors
|
||||
|
||||
def raise_if_invalid(self) -> None:
|
||||
"""如果验证失败则抛出异常"""
|
||||
if not self.is_valid():
|
||||
raise ValidationError('; '.join(self.errors))
|
||||
|
||||
|
||||
def validate_timestamp(timestamp: int, max_seconds: int = 300) -> bool:
|
||||
"""
|
||||
验证时间戳有效性
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳
|
||||
max_seconds: 最大允许的时间差(秒)
|
||||
|
||||
Returns:
|
||||
bool: 是否有效
|
||||
"""
|
||||
try:
|
||||
request_time = datetime.fromtimestamp(timestamp)
|
||||
current_time = datetime.utcnow()
|
||||
time_diff = abs((current_time - request_time).total_seconds())
|
||||
return time_diff <= max_seconds
|
||||
except (ValueError, TypeError, OSError):
|
||||
return False
|
||||
|
||||
|
||||
def validate_product_id(product_id: str) -> bool:
|
||||
"""
|
||||
验证产品ID格式
|
||||
|
||||
Args:
|
||||
product_id: 产品ID
|
||||
|
||||
Returns:
|
||||
bool: 是否有效
|
||||
"""
|
||||
pattern = r'^PROD_[A-F0-9]{8}$|^[A-Za-z0-9_]{1,32}$'
|
||||
return re.match(pattern, product_id) is not None
|
||||
|
||||
|
||||
def sanitize_string(value: str, max_length: int = 255) -> str:
|
||||
"""
|
||||
清理字符串(移除危险字符)
|
||||
|
||||
Args:
|
||||
value: 原始字符串
|
||||
max_length: 最大长度
|
||||
|
||||
Returns:
|
||||
str: 清理后的字符串
|
||||
"""
|
||||
if not value:
|
||||
return ''
|
||||
# 移除潜在的XSS攻击字符
|
||||
value = value.strip()
|
||||
# 截断到指定长度
|
||||
if len(value) > max_length:
|
||||
value = value[:max_length]
|
||||
return value
|
||||
|
||||
|
||||
def validate_filename(filename: str, allowed_extensions: Optional[List[str]] = None) -> None:
|
||||
"""
|
||||
验证文件名
|
||||
|
||||
Args:
|
||||
filename: 文件名
|
||||
allowed_extensions: 允许的扩展名列表
|
||||
|
||||
Raises:
|
||||
ValidationError: 如果文件名无效
|
||||
"""
|
||||
if not filename:
|
||||
raise ValidationError("文件名不能为空")
|
||||
|
||||
# 防止路径遍历攻击
|
||||
if '..' in filename or '/' in filename or '\\' in filename:
|
||||
raise ValidationError("文件名包含非法字符")
|
||||
|
||||
# 验证扩展名
|
||||
if allowed_extensions:
|
||||
ext = filename.rsplit('.', 1)[-1].lower() if '.' in filename else ''
|
||||
if ext not in allowed_extensions:
|
||||
raise ValidationError(f"文件扩展名必须是以下之一: {', '.join(allowed_extensions)}")
|
||||
|
||||
|
||||
def validate_pagination(page: int = 1, per_page: int = 20, max_per_page: int = 100) -> tuple:
|
||||
"""
|
||||
验证分页参数
|
||||
|
||||
Args:
|
||||
page: 页码
|
||||
per_page: 每页数量
|
||||
max_per_page: 最大每页数量
|
||||
|
||||
Returns:
|
||||
tuple: (page, per_page) 修正后的值
|
||||
"""
|
||||
page = max(1, page)
|
||||
per_page = min(max(1, per_page), max_per_page)
|
||||
return page, per_page
|
||||
Reference in New Issue
Block a user