194 lines
5.6 KiB
Python
194 lines
5.6 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
CSRF问题修复验证脚本
|
||
验证API和Web表单的CSRF配置是否正确
|
||
"""
|
||
import requests
|
||
import sys
|
||
from urllib.parse import urljoin
|
||
|
||
|
||
def test_api_without_csrf():
|
||
"""测试API是否豁免CSRF"""
|
||
print("=" * 60)
|
||
print("🔍 测试API CSRF豁免")
|
||
print("=" * 60)
|
||
|
||
base_url = "http://localhost:5000"
|
||
|
||
# 测试健康检查端点
|
||
print("\n1. 测试健康检查端点...")
|
||
try:
|
||
response = requests.get(f"{base_url}/api/v1/health", timeout=5)
|
||
if response.status_code == 200:
|
||
print(" ✅ 健康检查端点正常")
|
||
print(f" 响应: {response.json()}")
|
||
else:
|
||
print(f" ⚠️ 健康检查返回状态码: {response.status_code}")
|
||
except Exception as e:
|
||
print(f" ❌ 健康检查失败: {str(e)}")
|
||
return False
|
||
|
||
# 测试需要认证的API(应该返回401而不是400)
|
||
print("\n2. 测试需要认证的API...")
|
||
test_endpoints = [
|
||
"/api/v1/licenses",
|
||
"/api/v1/products"
|
||
]
|
||
|
||
for endpoint in test_endpoints:
|
||
try:
|
||
# GET请求(不应该需要CSRF)
|
||
response = requests.get(urljoin(base_url, endpoint), timeout=5)
|
||
|
||
if response.status_code == 400:
|
||
print(f" ❌ {endpoint} 返回400 - 可能是CSRF错误")
|
||
return False
|
||
elif response.status_code == 401:
|
||
print(f" ✅ {endpoint} 返回401 - 未认证(正确)")
|
||
elif response.status_code == 200:
|
||
print(f" ✅ {endpoint} 返回200 - 成功")
|
||
else:
|
||
print(f" ⚠️ {endpoint} 返回状态码: {response.status_code}")
|
||
|
||
except Exception as e:
|
||
print(f" ❌ {endpoint} 请求失败: {str(e)}")
|
||
return False
|
||
|
||
return True
|
||
|
||
|
||
def check_csrf_configuration():
|
||
"""检查CSRF配置"""
|
||
print("\n" + "=" * 60)
|
||
print("🔧 检查CSRF配置")
|
||
print("=" * 60)
|
||
|
||
try:
|
||
# 检查app/__init__.py中的CSRF配置
|
||
with open('app/__init__.py', 'r', encoding='utf-8') as f:
|
||
content = f.read()
|
||
|
||
if 'csrf.exempt(api_bp)' in content:
|
||
print(" ✅ 已配置API CSRF豁免")
|
||
else:
|
||
print(" ❌ 未找到API CSRF豁免配置")
|
||
return False
|
||
|
||
if 'csrf.init_app(app)' in content:
|
||
print(" ✅ CSRF已初始化")
|
||
else:
|
||
print(" ❌ CSRF未初始化")
|
||
return False
|
||
|
||
return True
|
||
|
||
except Exception as e:
|
||
print(f" ❌ 检查配置文件失败: {str(e)}")
|
||
return False
|
||
|
||
|
||
def test_web_form_csrf():
|
||
"""测试Web表单CSRF保护"""
|
||
print("\n" + "=" * 60)
|
||
print("🔐 测试Web表单CSRF保护")
|
||
print("=" * 60)
|
||
|
||
base_url = "http://localhost:5000"
|
||
|
||
try:
|
||
# 获取登录页面
|
||
response = requests.get(f"{base_url}/login", timeout=5)
|
||
|
||
if response.status_code == 200:
|
||
print(" ✅ 登录页面可访问")
|
||
|
||
# 检查是否包含CSRF token
|
||
if 'csrf_token' in response.text or 'name="csrf_token"' in response.text:
|
||
print(" ✅ 登录页面包含CSRF token")
|
||
else:
|
||
print(" ⚠️ 登录页面未找到CSRF token")
|
||
else:
|
||
print(f" ⚠️ 登录页面返回状态码: {response.status_code}")
|
||
|
||
except Exception as e:
|
||
print(f" ❌ 测试Web表单失败: {str(e)}")
|
||
return False
|
||
|
||
return True
|
||
|
||
|
||
def generate_report():
|
||
"""生成检查报告"""
|
||
print("\n" + "=" * 60)
|
||
print("📊 CSRF问题修复验证报告")
|
||
print("=" * 60)
|
||
|
||
checks = [
|
||
("API CSRF豁免配置", check_csrf_configuration),
|
||
("API豁免CSRF测试", test_api_without_csrf),
|
||
("Web表单CSRF保护", test_web_form_csrf)
|
||
]
|
||
|
||
results = []
|
||
for name, check_func in checks:
|
||
try:
|
||
result = check_func()
|
||
results.append((name, result))
|
||
except Exception as e:
|
||
print(f"❌ {name}检查失败: {str(e)}")
|
||
results.append((name, False))
|
||
|
||
# 总结
|
||
print("\n" + "=" * 60)
|
||
print("📋 检查结果汇总")
|
||
print("=" * 60)
|
||
|
||
all_passed = True
|
||
for name, passed in results:
|
||
status = "✅ 通过" if passed else "❌ 失败"
|
||
print(f"{name:25s}: {status}")
|
||
if not passed:
|
||
all_passed = False
|
||
|
||
print("\n" + "=" * 60)
|
||
if all_passed:
|
||
print("✅ 所有检查通过!CSRF问题已修复")
|
||
print("\n💡 后续建议:")
|
||
print(" 1. 重启应用以应用修复")
|
||
print(" 2. 测试所有API端点")
|
||
print(" 3. 确认Web表单仍受CSRF保护")
|
||
else:
|
||
print("⚠️ 发现问题,请检查上述详细信息")
|
||
print("=" * 60)
|
||
|
||
return all_passed
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
print("🔧 KaMiXiTong CSRF问题修复验证工具")
|
||
print("检查API和Web表单的CSRF配置是否正确")
|
||
|
||
# 检查应用是否运行
|
||
try:
|
||
response = requests.get("http://localhost:5000/api/v1/health", timeout=3)
|
||
if response.status_code != 200:
|
||
print("\n⚠️ 应用似乎未运行或无法访问")
|
||
print("请先启动应用: flask run --host=0.0.0.0 --port=5000")
|
||
print("然后重新运行此脚本")
|
||
sys.exit(1)
|
||
except Exception as e:
|
||
print("\n⚠️ 无法连接到应用")
|
||
print(f"错误: {str(e)}")
|
||
print("请先启动应用: flask run --host=0.0.0.0 --port=5000")
|
||
print("然后重新运行此脚本")
|
||
sys.exit(1)
|
||
|
||
passed = generate_report()
|
||
sys.exit(0 if passed else 1)
|
||
|
||
|
||
if __name__ == '__main__':
|
||
main()
|