nodebookls/main.py
2025-10-29 13:56:24 +08:00

783 lines
26 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from typing import List
import os
import tempfile
import shutil
from datetime import datetime
from di_container import container
from config import settings
from tasks import process_files_task, generate_text_task
from config_manager import config_manager
from exceptions import handle_exception
from health_check import health_checker
from performance_monitor import perf_monitor
from model_manager import OpenAIProvider, OpenRouterProvider, SiliconFlowProvider
from knowledge_base_manager import kb_manager
# Application instance; all routes below are registered on it.
app = FastAPI()

# Initialize components through DI container.
# NOTE(review): these are module-level singletons shared by every request —
# presumably the container returns thread-safe objects; verify in di_container.
file_processor = container.get('file_processor')
vector_store = container.get('vector_store')
# The hybrid retriever wraps the vector store and adds a BM25 index on top.
hybrid_retriever = container.get('hybrid_retriever', vector_store=vector_store)
text_generator = container.get('text_generator')
exporter = container.get('exporter')
logger = container.get('logger')

# Serve static files (the templates directory doubles as the static root).
app.mount("/static", StaticFiles(directory="templates"), name="static")
@app.on_event("startup")
async def startup_event():
    """Initialize hybrid retriever with existing segments.

    Also starts performance monitoring for the startup phase, warns on a
    missing OpenAI API key, and prints startup time/memory metrics.
    """
    # Start performance monitoring so startup cost can be reported below.
    perf_monitor.start_monitoring()
    # Check if OpenAI API key is configured; warn (in Chinese) if not.
    if not settings.OPENAI_API_KEY:
        print("=" * 60)
        print("警告: 未检测到OpenAI API密钥!")
        print("要使用完整功能,请:")
        print("1. 在项目根目录创建 .env 文件")
        print("2. 在 .env 文件中添加: OPENAI_API_KEY=your_actual_api_key")
        print("3. 重启应用程序")
        print("=" * 60)
    # Rebuild the BM25 index from any segments persisted in the vector store
    # so hybrid search works immediately after a restart.
    segments = vector_store.get_all_segments()
    if segments:
        hybrid_retriever.prepare_bm25(segments)
    # Print performance report for the startup phase.
    metrics = perf_monitor.stop_monitoring()
    print(f"系统启动耗时: {metrics['execution_time']:.4f}s")
    print(f"内存使用: {metrics['memory_used'] / 1024 / 1024:.2f}MB")
@app.get("/", response_class=HTMLResponse)
async def root():
with open("templates/index.html", "r", encoding="utf-8") as f:
html_content = f.read()
return HTMLResponse(content=html_content, status_code=200)
@app.get("/settings", response_class=HTMLResponse)
async def settings_page():
with open("templates/settings.html", "r", encoding="utf-8") as f:
html_content = f.read()
return HTMLResponse(content=html_content, status_code=200)
@app.get("/api/config")
async def get_config():
"""获取当前配置"""
return config_manager.get_all_config()
@app.post("/api/config")
async def update_config(config: dict):
"""更新配置"""
if config_manager.save_config(config):
return {"message": "配置更新成功"}
else:
raise HTTPException(status_code=500, detail="配置更新失败")
@app.post("/upload")
async def upload_files(files: List[UploadFile] = File(...)):
"""
Upload and process TXT files
"""
if len(files) > settings.MAX_FILES:
raise HTTPException(status_code=400, detail=f"最多只能上传{settings.MAX_FILES}个文件")
# For immediate response, we'll process files synchronously in this endpoint
# In a full implementation, we would use Celery for async processing
# Save uploaded files temporarily
temp_files = []
try:
for file in files:
# Check file extension
if file.filename is None or not file.filename.endswith(".txt"):
raise HTTPException(status_code=400, detail=f"只支持TXT文件发现文件: {file.filename}")
# Save file temporarily
with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file:
# Read file content
content = await file.read()
# Write to temp file
temp_file.write(content)
temp_files.append(temp_file.name)
# Process files
processed_segments, empty_count, duplicate_count = file_processor.process_files(temp_files)
# Add to vector store
try:
vector_store.add_documents(processed_segments)
except ValueError as e:
raise HTTPException(status_code=500, detail=f"向量存储错误: {str(e)}")
# Update hybrid retriever
all_segments = vector_store.get_all_segments()
hybrid_retriever.prepare_bm25(all_segments)
return {
"message": "文件上传处理成功",
"processed_segments": len(processed_segments),
"empty_files_filtered": empty_count,
"duplicate_files_filtered": duplicate_count
}
except Exception as e:
# 统一异常处理
error_info = handle_exception(e)
raise HTTPException(status_code=500, detail=error_info["message"])
finally:
# Clean up temporary files
for temp_file in temp_files:
if os.path.exists(temp_file):
os.unlink(temp_file)
@app.post("/upload_async")
async def upload_files_async(files: List[UploadFile] = File(...)):
"""
Upload and process TXT files asynchronously using Celery
"""
if len(files) > settings.MAX_FILES:
raise HTTPException(status_code=400, detail=f"最多只能上传{settings.MAX_FILES}个文件")
# Prepare file contents for async processing
file_contents = {}
for file in files:
# Check file extension
if file.filename is None or not file.filename.endswith(".txt"):
raise HTTPException(status_code=400, detail=f"只支持TXT文件发现文件: {file.filename}")
# Read file content
content = await file.read()
file_contents[file.filename] = content.decode('utf-8') if isinstance(content, bytes) else content
# Start async task
task = process_files_task.delay(file_contents)
return {
"message": "文件上传任务已提交",
"task_id": task.id
}
@app.get("/task_status/{task_id}")
async def get_task_status(task_id: str):
"""
Get the status of an async task
"""
from celery.result import AsyncResult
task_result = AsyncResult(task_id)
if task_result.state == 'PENDING':
# Task is waiting to be processed
response = {
'state': task_result.state,
'status': '任务等待中...'
}
elif task_result.state == 'PROGRESS':
# Task is in progress
response = {
'state': task_result.state,
'status': task_result.info.get('status', ''),
'current': task_result.info.get('current', 0),
'total': task_result.info.get('total', 1),
'percent': task_result.info.get('percent', 0)
}
elif task_result.state == 'SUCCESS':
# Task completed successfully
response = {
'state': task_result.state,
'result': task_result.info
}
else:
# Task failed
response = {
'state': task_result.state,
'error': str(task_result.info)
}
return response
@app.post("/search")
async def search(query: str = Form(...), k: int = Form(settings.TOP_K),
use_hybrid: bool = Form(True), bm25_weight: float = Form(0.5)):
"""
Search for relevant segments
Args:
query: Search query
k: Number of results to return
use_hybrid: Whether to use hybrid search (BM25 + vector)
bm25_weight: Weight for BM25 scores in hybrid search
"""
try:
if use_hybrid:
# Use hybrid search
results = hybrid_retriever.hybrid_search(query, k, bm25_weight)
else:
# Use vector search only
results = vector_store.search(query, k)
# Format results
formatted_results = []
for result, score in results:
formatted_results.append({
"content": result["content"],
"metadata": result["metadata"],
"score": score,
"bm25_score": result.get("bm25_score", None),
"vector_score": result.get("vector_score", None)
})
return {
"query": query,
"results": formatted_results
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"检索失败: {str(e)}")
@app.post("/analyze_file/{file_name}")
async def analyze_file(file_name: str):
"""
Analyze a processed file for keywords, summary and tags
Args:
file_name: Name of the file to analyze
"""
try:
# 构造文件路径来查找缓存的分析结果
# 在实际实现中,我们会在文件处理时缓存分析结果
# 这里我们直接分析文件内容
# 查找文件的段落内容
all_segments = vector_store.get_all_segments()
file_segments = [s for s in all_segments if s["metadata"]["file_name"] == file_name]
if not file_segments:
raise HTTPException(status_code=404, detail=f"未找到文件 {file_name}")
# 合并文件内容
file_content = "\n".join([s["content"] for s in file_segments])
# 分析内容
analyzer = container.get('analyzer')
analysis = analyzer.analyze_content(file_content)
return {
"file_name": file_name,
"analysis": analysis
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"文件分析失败: {str(e)}")
@app.post("/generate")
async def generate_text(
query: str = Form(...),
style: str = Form("通用文案"),
min_length: int = Form(50),
max_length: int = Form(200),
use_hybrid: bool = Form(True),
bm25_weight: float = Form(0.5),
include_context: bool = Form(True),
enable_scoring: bool = Form(False)
):
"""
Generate text based on query
"""
try:
# Search for relevant segments
if use_hybrid:
search_results = hybrid_retriever.search_with_context(query, settings.TOP_K, include_context)
else:
vector_results = vector_store.search(query, settings.TOP_K)
search_results = [(result, score) for result, score in vector_results]
if not search_results:
raise HTTPException(status_code=404, detail="未找到相关文档内容")
# Extract context from search results
context_parts = []
source_segments = []
for result, score in search_results:
# Add main segment
context_parts.append(f"[{result['metadata']['segment_id']}] {result['content']}")
source_segments.append(result)
# Add context if available and requested
if include_context and "context" in result:
context = result["context"]
if context["previous"]:
context_parts.append(f"[{context['previous']['metadata']['segment_id']}] (前文) {context['previous']['content']}")
if context["next"]:
context_parts.append(f"[{context['next']['metadata']['segment_id']}] (后文) {context['next']['content']}")
context = "\n\n".join(context_parts)
# Generate text
generated_text = text_generator.generate_text(
context=context,
style=style,
min_length=min_length,
max_length=max_length
)
# Comprehensive hallucination detection
hallucination_keywords, hallucination_entities = text_generator.comprehensive_hallucination_check(
generated_text, context
)
# Combine all hallucination warnings
hallucination_warnings = {
"keywords": hallucination_keywords,
"entities": hallucination_entities
}
# Prepare response
response_data = {
"query": query,
"generated_text": generated_text,
"hallucination_warnings": hallucination_warnings,
"source_segments": source_segments
}
# Initialize score_result
score_result = None
# Add scoring if enabled
if enable_scoring:
score_result = text_generator.score_generation(generated_text, context, query)
response_data["score"] = score_result
# Log generation
logger.log_generation(
query=query,
generated_text=generated_text,
style=style,
source_segments=source_segments,
score=score_result,
hallucination_warnings=hallucination_warnings
)
return response_data
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"文本生成失败: {str(e)}")
@app.post("/score_generation")
async def score_generation(
generated_text: str = Form(...),
context: str = Form(...),
query: str = Form(...)
):
"""
Score a generated text
Args:
generated_text: The generated text to score
context: The source context
query: The original query
"""
try:
score_result = text_generator.score_generation(generated_text, context, query)
return score_result
except Exception as e:
raise HTTPException(status_code=500, detail=f"评分失败: {str(e)}")
@app.post("/generate_async")
async def generate_text_async(
query: str = Form(...),
style: str = Form("通用文案"),
min_length: int = Form(50),
max_length: int = Form(200)
):
"""
Generate text asynchronously using Celery
"""
# Start async task
task = generate_text_task.delay(query, style, min_length, max_length)
return {
"message": "文本生成任务已提交",
"task_id": task.id
}
@app.post("/export_markdown")
async def export_to_markdown(
query: str = Form(...),
generated_text: str = Form(...),
file_name: str = Form(None)
):
"""
Export generated text to Markdown format
"""
# For now, we'll create empty source segments
# In a full implementation, these would come from the generation process
source_segments = []
# Export to markdown
export_path = exporter.export_to_markdown(query, generated_text, source_segments, file_name)
# Return file download
return FileResponse(export_path, media_type='application/octet-stream', filename=os.path.basename(export_path))
@app.post("/export_docx")
async def export_to_docx(
query: str = Form(...),
generated_text: str = Form(...),
file_name: str = Form(None)
):
"""
Export generated text to DOCX format
"""
# For now, we'll create empty source segments
# In a full implementation, these would come from the generation process
source_segments = []
# Export to DOCX
export_path = exporter.export_to_docx(query, generated_text, source_segments, file_name)
# Return file download
return FileResponse(export_path, media_type='application/vnd.openxmlformats-officedocument.wordprocessingml.document',
filename=os.path.basename(export_path))
@app.get("/logs")
async def get_logs(limit: int = 100):
"""
Get generation logs
"""
logs = logger.get_logs(limit)
return {"logs": logs}
@app.delete("/logs")
async def clear_logs():
"""
Clear all generation logs
"""
logger.clear_logs()
return {"message": "日志已清空"}
@app.get("/health")
async def health_check():
"""
System health check
"""
return health_checker.run_health_check()
@app.post("/test_model_connection")
async def test_model_connection(
provider_type: str = Form(...),
model_name: str = Form(...),
api_key: str = Form(...),
api_base: str = Form(...)
):
"""
Test model connection
"""
try:
# Create a temporary provider for testing
if provider_type == "openai":
provider = OpenAIProvider(provider_type, api_key, api_base, [model_name])
elif provider_type == "openrouter":
provider = OpenRouterProvider(provider_type, api_key, api_base, [model_name])
elif provider_type == "siliconflow":
provider = SiliconFlowProvider(provider_type, api_key, api_base, [model_name])
else:
raise HTTPException(status_code=400, detail=f"不支持的提供商类型: {provider_type}")
# Test with a simple message
if "embedding" in model_name:
# Test embedding model
test_texts = ["测试连接"]
embeddings = provider.get_embeddings(model_name, test_texts)
if embeddings and len(embeddings) > 0:
return {"success": True, "message": "嵌入模型连接成功"}
else:
raise HTTPException(status_code=500, detail="嵌入模型连接失败")
else:
# Test generation model
test_messages = [{"role": "user", "content": "你好"}]
response = provider.generate_text(model_name, test_messages, max_tokens=10)
if response:
return {"success": True, "message": "生成模型连接成功"}
else:
raise HTTPException(status_code=500, detail="生成模型连接失败")
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"模型连接测试失败: {str(e)}")
@app.get("/knowledge_bases", response_class=HTMLResponse)
async def knowledge_bases_page():
with open("templates/knowledge_bases.html", "r", encoding="utf-8") as f:
html_content = f.read()
return HTMLResponse(content=html_content, status_code=200)
@app.get("/api/knowledge_bases")
async def list_knowledge_bases():
"""
List all knowledge bases
"""
try:
kbs = kb_manager.list_knowledge_bases()
return {"knowledge_bases": kbs}
except Exception as e:
raise HTTPException(status_code=500, detail=f"获取知识库列表失败: {str(e)}")
@app.post("/knowledge_bases")
async def create_knowledge_base(name: str = Form(...), description: str = Form("")):
"""
Create a new knowledge base
"""
try:
kb_manager.create_knowledge_base(name, description)
return {"message": f"知识库 {name} 创建成功"}
except Exception as e:
raise HTTPException(status_code=500, detail=f"创建知识库失败: {str(e)}")
@app.delete("/knowledge_bases/{name}")
async def delete_knowledge_base(name: str):
"""
Delete a knowledge base
"""
try:
kb_manager.delete_knowledge_base(name)
return {"message": f"知识库 {name} 删除成功"}
except Exception as e:
raise HTTPException(status_code=500, detail=f"删除知识库失败: {str(e)}")
@app.post("/knowledge_bases/{name}/upload")
async def upload_to_knowledge_base(name: str, files: List[UploadFile] = File(...)):
"""
Upload files to a specific knowledge base
"""
try:
# 获取知识库
kb = kb_manager.get_knowledge_base(name)
if not kb:
raise HTTPException(status_code=404, detail=f"知识库 {name} 不存在")
# 保存上传的文件
temp_files = []
for file in files:
# 检查文件扩展名
if file.filename is None or not file.filename.endswith(".txt"):
raise HTTPException(status_code=400, detail=f"只支持TXT文件发现文件: {file.filename}")
# 保存文件到临时位置
with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file:
content = await file.read()
temp_file.write(content)
temp_files.append(temp_file.name)
# 处理文件
processed_segments, empty_count, duplicate_count = file_processor.process_files(temp_files)
# 添加到知识库
kb.add_documents(processed_segments)
# 清理临时文件
for temp_file in temp_files:
if os.path.exists(temp_file):
os.unlink(temp_file)
return {
"message": f"文件上传到知识库 {name} 成功",
"processed_segments": len(processed_segments),
"empty_files_filtered": empty_count,
"duplicate_files_filtered": duplicate_count
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"上传文件到知识库失败: {str(e)}")
@app.post("/knowledge_bases/{name}/search")
async def search_knowledge_base(name: str, query: str = Form(...), k: int = Form(8)):
"""
Search in a specific knowledge base
"""
try:
# 获取知识库
kb = kb_manager.get_knowledge_base(name)
if not kb:
raise HTTPException(status_code=404, detail=f"知识库 {name} 不存在")
# 搜索
results = kb.search(query, k)
# 格式化结果
formatted_results = []
for result, score in results:
formatted_results.append({
"content": result["content"],
"metadata": result["metadata"],
"score": score
})
return {
"query": query,
"knowledge_base": name,
"results": formatted_results
}
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"知识库搜索失败: {str(e)}")
@app.post("/ask")
async def ask_question(
query: str = Form(...),
style: str = Form("通用文案"),
min_length: int = Form(50),
max_length: int = Form(500),
knowledge_base: str = Form(...)
):
"""
Ask a question based on a specific knowledge base
"""
try:
# 获取知识库
kb = kb_manager.get_knowledge_base(knowledge_base)
if not kb:
raise HTTPException(status_code=404, detail=f"知识库 {knowledge_base} 不存在")
# 在知识库中搜索相关内容
search_results = kb.search(query, settings.TOP_K)
# 提取相关内容作为上下文
context_parts = []
related_content = []
for result, score in search_results:
context_parts.append(f"[{result['metadata']['segment_id']}] {result['content']}")
related_content.append({
"content": result["content"],
"metadata": result["metadata"],
"score": score
})
context = "\n\n".join(context_parts)
# 生成回答
generated_text = text_generator.generate_text(
context=context,
style=style,
min_length=min_length,
max_length=max_length
)
# 幻觉检测
hallucination_keywords, hallucination_entities = text_generator.comprehensive_hallucination_check(
generated_text, context
)
# 准备响应数据
response_data = {
"query": query,
"answer": generated_text,
"knowledge_base": knowledge_base,
"related_content": related_content,
"hallucination_warnings": {
"keywords": hallucination_keywords,
"entities": hallucination_entities
}
}
return response_data
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"问答处理失败: {str(e)}")
@app.post("/export_markdown_qa")
async def export_qa_to_markdown(
query: str = Form(...),
generated_text: str = Form(...),
knowledge_base: str = Form(None)
):
"""
Export QA result to Markdown format
"""
try:
# 创建导出内容
export_content = f"# 问答结果\n\n"
export_content += f"**问题**: {query}\n\n"
export_content += f"**知识库**: {knowledge_base or '默认'}\n\n"
export_content += f"**回答**:\n\n{generated_text}\n\n"
export_content += f"---\n\n"
export_content += f"导出时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
# 生成文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
file_name = f"问答结果_{timestamp}.md"
file_path = os.path.join("exports", file_name)
# 确保导出目录存在
os.makedirs("exports", exist_ok=True)
# 写入文件
with open(file_path, "w", encoding="utf-8") as f:
f.write(export_content)
# 返回文件下载
return FileResponse(file_path, media_type='application/octet-stream', filename=file_name)
except Exception as e:
raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}")
@app.post("/export_docx_qa")
async def export_qa_to_docx(
query: str = Form(...),
generated_text: str = Form(...),
knowledge_base: str = Form(None)
):
"""
Export QA result to DOCX format
"""
try:
# 我们将使用纯文本格式替代DOCX避免依赖问题
export_content = f"问答结果\n"
export_content += f"========\n\n"
export_content += f"问题: {query}\n\n"
export_content += f"知识库: {knowledge_base or '默认'}\n\n"
export_content += f"回答:\n{generated_text}\n\n"
export_content += f"---\n\n"
export_content += f"导出时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
# 生成文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
file_name = f"问答结果_{timestamp}.txt"
file_path = os.path.join("exports", file_name)
# 确保导出目录存在
os.makedirs("exports", exist_ok=True)
# 写入文件
with open(file_path, "w", encoding="utf-8") as f:
f.write(export_content)
# 返回文件下载
return FileResponse(file_path, media_type='application/octet-stream', filename=file_name)
except Exception as e:
raise HTTPException(status_code=500, detail=f"TXT导出失败: {str(e)}")
@app.get("/system-info")
async def system_info():
"""
Get system information
"""
return health_checker.get_system_info()
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8001)