"""FastAPI application entry point for the RAG file-upload / retrieval / text-generation service."""
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
||
from fastapi.responses import JSONResponse, HTMLResponse, FileResponse
|
||
from fastapi.staticfiles import StaticFiles
|
||
from typing import List
|
||
import os
|
||
import tempfile
|
||
import shutil
|
||
from datetime import datetime
|
||
from di_container import container
|
||
from config import settings
|
||
from tasks import process_files_task, generate_text_task
|
||
from config_manager import config_manager
|
||
from exceptions import handle_exception
|
||
from health_check import health_checker
|
||
from performance_monitor import perf_monitor
|
||
from model_manager import OpenAIProvider, OpenRouterProvider, SiliconFlowProvider
|
||
from knowledge_base_manager import kb_manager
|
||
|
||
app = FastAPI()

# Initialize components through DI container
# All heavyweight collaborators are resolved once at import time and shared
# by every request handler below.
file_processor = container.get('file_processor')
vector_store = container.get('vector_store')
# The hybrid retriever combines BM25 and vector search; it wraps the same
# vector store instance created above.
hybrid_retriever = container.get('hybrid_retriever', vector_store=vector_store)
text_generator = container.get('text_generator')
exporter = container.get('exporter')
# NOTE(review): this shadows a conventional logging.Logger name — it is a
# project generation-logger (has log_generation/get_logs/clear_logs).
logger = container.get('logger')

# Serve static files
# The HTML templates double as the static asset directory.
app.mount("/static", StaticFiles(directory="templates"), name="static")
|
||
|
||
@app.on_event("startup")
async def startup_event():
    """Warm the service up at boot.

    Starts performance monitoring, prints a warning banner when no OpenAI
    API key is configured, seeds the BM25 index from segments already in
    the vector store, and reports startup time / memory usage.
    """
    perf_monitor.start_monitoring()

    # Generation features need a key; surface a loud, actionable banner.
    if not settings.OPENAI_API_KEY:
        print("=" * 60)
        print("警告: 未检测到OpenAI API密钥!")
        print("要使用完整功能,请:")
        print("1. 在项目根目录创建 .env 文件")
        print("2. 在 .env 文件中添加: OPENAI_API_KEY=your_actual_api_key")
        print("3. 重启应用程序")
        print("=" * 60)

    # Rebuild the keyword index from any previously persisted segments.
    existing_segments = vector_store.get_all_segments()
    if existing_segments:
        hybrid_retriever.prepare_bm25(existing_segments)

    startup_metrics = perf_monitor.stop_monitoring()
    print(f"系统启动耗时: {startup_metrics['execution_time']:.4f}s")
    print(f"内存使用: {startup_metrics['memory_used'] / 1024 / 1024:.2f}MB")
|
||
|
||
@app.get("/", response_class=HTMLResponse)
async def root():
    """Serve the main single-page UI."""
    with open("templates/index.html", "r", encoding="utf-8") as page:
        markup = page.read()
    return HTMLResponse(content=markup, status_code=200)
|
||
|
||
@app.get("/settings", response_class=HTMLResponse)
async def settings_page():
    """Serve the settings UI page."""
    with open("templates/settings.html", "r", encoding="utf-8") as page:
        markup = page.read()
    return HTMLResponse(content=markup, status_code=200)
|
||
|
||
@app.get("/api/config")
async def get_config():
    """Return the full current configuration as JSON."""
    return config_manager.get_all_config()
|
||
|
||
@app.post("/api/config")
async def update_config(config: dict):
    """Persist a new configuration.

    Raises:
        HTTPException: 500 when the config manager fails to save.
    """
    if not config_manager.save_config(config):
        raise HTTPException(status_code=500, detail="配置更新失败")
    return {"message": "配置更新成功"}
|
||
|
||
@app.post("/upload")
async def upload_files(files: List[UploadFile] = File(...)):
    """
    Upload and process TXT files synchronously.

    Validates file count and extensions, spools each upload to a temp
    file, runs the segmentation pipeline, indexes the segments in the
    vector store and refreshes the BM25 index.

    Raises:
        HTTPException: 400 for too many files or a non-TXT upload,
                       500 for processing/storage failures.
    """
    if len(files) > settings.MAX_FILES:
        raise HTTPException(status_code=400, detail=f"最多只能上传{settings.MAX_FILES}个文件")

    # Save uploaded files temporarily; cleaned up in the finally block.
    temp_files = []
    try:
        for file in files:
            # Only plain-text uploads are supported by the pipeline.
            if file.filename is None or not file.filename.endswith(".txt"):
                raise HTTPException(status_code=400, detail=f"只支持TXT文件,发现文件: {file.filename}")

            with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file:
                content = await file.read()
                temp_file.write(content)
                temp_files.append(temp_file.name)

        # Segment, deduplicate and filter the uploaded documents.
        processed_segments, empty_count, duplicate_count = file_processor.process_files(temp_files)

        try:
            vector_store.add_documents(processed_segments)
        except ValueError as e:
            raise HTTPException(status_code=500, detail=f"向量存储错误: {str(e)}")

        # Rebuild BM25 over the full corpus so hybrid search sees new docs.
        all_segments = vector_store.get_all_segments()
        hybrid_retriever.prepare_bm25(all_segments)

        return {
            "message": "文件上传处理成功",
            "processed_segments": len(processed_segments),
            "empty_files_filtered": empty_count,
            "duplicate_files_filtered": duplicate_count
        }

    except HTTPException:
        # BUG FIX: previously the generic handler below swallowed the
        # HTTPExceptions raised above (e.g. the 400 for a bad extension)
        # and re-raised them as 500; propagate them unchanged instead.
        raise
    except Exception as e:
        # Unified exception handling for unexpected failures.
        error_info = handle_exception(e)
        raise HTTPException(status_code=500, detail=error_info["message"])

    finally:
        # Clean up temporary files regardless of outcome.
        for temp_file in temp_files:
            if os.path.exists(temp_file):
                os.unlink(temp_file)
|
||
|
||
@app.post("/upload_async")
async def upload_files_async(files: List[UploadFile] = File(...)):
    """
    Upload TXT files and process them asynchronously via Celery.

    Returns the Celery task id; poll /task_status/{task_id} for progress.

    Raises:
        HTTPException: 400 for too many files, a non-TXT upload, or a
                       file that is not valid UTF-8.
    """
    if len(files) > settings.MAX_FILES:
        raise HTTPException(status_code=400, detail=f"最多只能上传{settings.MAX_FILES}个文件")

    # Read every file into memory keyed by name; the Celery worker has no
    # access to the request's upload streams.
    file_contents = {}
    for file in files:
        if file.filename is None or not file.filename.endswith(".txt"):
            raise HTTPException(status_code=400, detail=f"只支持TXT文件,发现文件: {file.filename}")

        content = await file.read()
        try:
            # BUG FIX: a non-UTF-8 upload used to raise an unhandled
            # UnicodeDecodeError (500); report it as a client error.
            file_contents[file.filename] = content.decode('utf-8') if isinstance(content, bytes) else content
        except UnicodeDecodeError:
            raise HTTPException(status_code=400, detail=f"文件编码必须为UTF-8: {file.filename}")

    # Hand off to the background worker.
    task = process_files_task.delay(file_contents)

    return {
        "message": "文件上传任务已提交",
        "task_id": task.id
    }
|
||
|
||
@app.get("/task_status/{task_id}")
async def get_task_status(task_id: str):
    """
    Report the state of an asynchronous Celery task.

    Returns a dict whose shape depends on the task state: PENDING,
    PROGRESS (with counters), SUCCESS (with the result), or failure.
    """
    from celery.result import AsyncResult
    task_result = AsyncResult(task_id)
    state = task_result.state

    if state == 'PENDING':
        # Not started yet (or unknown id — Celery reports both as PENDING).
        return {'state': state, 'status': '任务等待中...'}

    if state == 'PROGRESS':
        # Progress metadata is published by the task via update_state().
        progress = task_result.info
        return {
            'state': state,
            'status': progress.get('status', ''),
            'current': progress.get('current', 0),
            'total': progress.get('total', 1),
            'percent': progress.get('percent', 0),
        }

    if state == 'SUCCESS':
        return {'state': state, 'result': task_result.info}

    # Any other state is treated as failure; info carries the exception.
    return {'state': state, 'error': str(task_result.info)}
|
||
|
||
@app.post("/search")
async def search(query: str = Form(...), k: int = Form(settings.TOP_K),
                 use_hybrid: bool = Form(True), bm25_weight: float = Form(0.5)):
    """
    Search for relevant segments.

    Args:
        query: Search query
        k: Number of results to return
        use_hybrid: Whether to use hybrid search (BM25 + vector)
        bm25_weight: Weight for BM25 scores in hybrid search

    Raises:
        HTTPException: 500 on any retrieval failure.
    """
    try:
        if use_hybrid:
            hits = hybrid_retriever.hybrid_search(query, k, bm25_weight)
        else:
            hits = vector_store.search(query, k)

        # Flatten (segment, score) pairs into JSON-friendly dicts; the
        # per-method sub-scores are only present for hybrid results.
        formatted_results = [
            {
                "content": segment["content"],
                "metadata": segment["metadata"],
                "score": score,
                "bm25_score": segment.get("bm25_score", None),
                "vector_score": segment.get("vector_score", None),
            }
            for segment, score in hits
        ]

        return {"query": query, "results": formatted_results}

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"检索失败: {str(e)}")
|
||
|
||
@app.post("/analyze_file/{file_name}")
async def analyze_file(file_name: str):
    """
    Analyze a processed file for keywords, summary and tags.

    Gathers every stored segment belonging to the file, concatenates
    them, and runs the content analyzer over the result.

    Args:
        file_name: Name of the file to analyze

    Raises:
        HTTPException: 404 when the file has no stored segments,
                       500 on analysis failure.
    """
    try:
        # Collect the segments that came from this file. Analysis is done
        # on demand here rather than cached at ingest time.
        all_segments = vector_store.get_all_segments()
        matching = [seg for seg in all_segments
                    if seg["metadata"]["file_name"] == file_name]

        if not matching:
            raise HTTPException(status_code=404, detail=f"未找到文件 {file_name}")

        # Reassemble the document text from its segments.
        full_text = "\n".join(seg["content"] for seg in matching)

        analyzer = container.get('analyzer')
        analysis = analyzer.analyze_content(full_text)

        return {
            "file_name": file_name,
            "analysis": analysis
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"文件分析失败: {str(e)}")
|
||
|
||
@app.post("/generate")
|
||
async def generate_text(
|
||
query: str = Form(...),
|
||
style: str = Form("通用文案"),
|
||
min_length: int = Form(50),
|
||
max_length: int = Form(200),
|
||
use_hybrid: bool = Form(True),
|
||
bm25_weight: float = Form(0.5),
|
||
include_context: bool = Form(True),
|
||
enable_scoring: bool = Form(False)
|
||
):
|
||
"""
|
||
Generate text based on query
|
||
"""
|
||
try:
|
||
# Search for relevant segments
|
||
if use_hybrid:
|
||
search_results = hybrid_retriever.search_with_context(query, settings.TOP_K, include_context)
|
||
else:
|
||
vector_results = vector_store.search(query, settings.TOP_K)
|
||
search_results = [(result, score) for result, score in vector_results]
|
||
|
||
if not search_results:
|
||
raise HTTPException(status_code=404, detail="未找到相关文档内容")
|
||
|
||
# Extract context from search results
|
||
context_parts = []
|
||
source_segments = []
|
||
for result, score in search_results:
|
||
# Add main segment
|
||
context_parts.append(f"[{result['metadata']['segment_id']}] {result['content']}")
|
||
source_segments.append(result)
|
||
|
||
# Add context if available and requested
|
||
if include_context and "context" in result:
|
||
context = result["context"]
|
||
if context["previous"]:
|
||
context_parts.append(f"[{context['previous']['metadata']['segment_id']}] (前文) {context['previous']['content']}")
|
||
if context["next"]:
|
||
context_parts.append(f"[{context['next']['metadata']['segment_id']}] (后文) {context['next']['content']}")
|
||
|
||
context = "\n\n".join(context_parts)
|
||
|
||
# Generate text
|
||
generated_text = text_generator.generate_text(
|
||
context=context,
|
||
style=style,
|
||
min_length=min_length,
|
||
max_length=max_length
|
||
)
|
||
|
||
# Comprehensive hallucination detection
|
||
hallucination_keywords, hallucination_entities = text_generator.comprehensive_hallucination_check(
|
||
generated_text, context
|
||
)
|
||
|
||
# Combine all hallucination warnings
|
||
hallucination_warnings = {
|
||
"keywords": hallucination_keywords,
|
||
"entities": hallucination_entities
|
||
}
|
||
|
||
# Prepare response
|
||
response_data = {
|
||
"query": query,
|
||
"generated_text": generated_text,
|
||
"hallucination_warnings": hallucination_warnings,
|
||
"source_segments": source_segments
|
||
}
|
||
|
||
# Initialize score_result
|
||
score_result = None
|
||
|
||
# Add scoring if enabled
|
||
if enable_scoring:
|
||
score_result = text_generator.score_generation(generated_text, context, query)
|
||
response_data["score"] = score_result
|
||
|
||
# Log generation
|
||
logger.log_generation(
|
||
query=query,
|
||
generated_text=generated_text,
|
||
style=style,
|
||
source_segments=source_segments,
|
||
score=score_result,
|
||
hallucination_warnings=hallucination_warnings
|
||
)
|
||
|
||
return response_data
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=f"文本生成失败: {str(e)}")
|
||
|
||
@app.post("/score_generation")
async def score_generation(
    generated_text: str = Form(...),
    context: str = Form(...),
    query: str = Form(...)
):
    """
    Score a generated text against its source context and query.

    Args:
        generated_text: The generated text to score
        context: The source context
        query: The original query

    Raises:
        HTTPException: 500 when scoring fails.
    """
    try:
        return text_generator.score_generation(generated_text, context, query)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"评分失败: {str(e)}")
|
||
|
||
@app.post("/generate_async")
async def generate_text_async(
    query: str = Form(...),
    style: str = Form("通用文案"),
    min_length: int = Form(50),
    max_length: int = Form(200)
):
    """
    Queue text generation on the Celery worker and return the task id.

    Poll /task_status/{task_id} for progress and the final result.
    """
    task = generate_text_task.delay(query, style, min_length, max_length)
    return {
        "message": "文本生成任务已提交",
        "task_id": task.id
    }
|
||
|
||
@app.post("/export_markdown")
async def export_to_markdown(
    query: str = Form(...),
    generated_text: str = Form(...),
    file_name: str = Form(None)
):
    """
    Export generated text to a downloadable Markdown file.

    Source segments are currently not threaded through from the
    generation step, so an empty list is passed to the exporter.
    """
    source_segments = []

    export_path = exporter.export_to_markdown(query, generated_text, source_segments, file_name)

    # Stream the exported file back as an attachment.
    return FileResponse(export_path, media_type='application/octet-stream', filename=os.path.basename(export_path))
|
||
|
||
@app.post("/export_docx")
async def export_to_docx(
    query: str = Form(...),
    generated_text: str = Form(...),
    file_name: str = Form(None)
):
    """
    Export generated text to a downloadable DOCX file.

    Source segments are currently not threaded through from the
    generation step, so an empty list is passed to the exporter.
    """
    source_segments = []

    export_path = exporter.export_to_docx(query, generated_text, source_segments, file_name)

    # Stream the exported document back with the proper DOCX MIME type.
    return FileResponse(export_path, media_type='application/vnd.openxmlformats-officedocument.wordprocessingml.document',
                        filename=os.path.basename(export_path))
|
||
|
||
@app.get("/logs")
async def get_logs(limit: int = 100):
    """Return up to *limit* recent generation log entries."""
    return {"logs": logger.get_logs(limit)}
|
||
|
||
@app.delete("/logs")
async def clear_logs():
    """Delete all generation log entries."""
    logger.clear_logs()
    return {"message": "日志已清空"}
|
||
|
||
@app.get("/health")
async def health_check():
    """Run and return the system health check report."""
    return health_checker.run_health_check()
|
||
|
||
@app.post("/test_model_connection")
async def test_model_connection(
    provider_type: str = Form(...),
    model_name: str = Form(...),
    api_key: str = Form(...),
    api_base: str = Form(...)
):
    """
    Probe a model provider with a minimal request.

    Embedding models (name contains "embedding") are tested with a tiny
    embedding call; all other models with a one-token chat completion.

    Raises:
        HTTPException: 400 for an unknown provider type, 500 when the
                       probe gets no usable response.
    """
    # Dispatch table instead of an if/elif chain; same three providers.
    provider_classes = {
        "openai": OpenAIProvider,
        "openrouter": OpenRouterProvider,
        "siliconflow": SiliconFlowProvider,
    }

    try:
        provider_cls = provider_classes.get(provider_type)
        if provider_cls is None:
            raise HTTPException(status_code=400, detail=f"不支持的提供商类型: {provider_type}")

        # Build a throwaway provider instance just for this probe.
        provider = provider_cls(provider_type, api_key, api_base, [model_name])

        if "embedding" in model_name:
            # Embedding model: request vectors for a single short text.
            embeddings = provider.get_embeddings(model_name, ["测试连接"])
            if embeddings and len(embeddings) > 0:
                return {"success": True, "message": "嵌入模型连接成功"}
            raise HTTPException(status_code=500, detail="嵌入模型连接失败")

        # Generation model: request a minimal completion.
        probe_messages = [{"role": "user", "content": "你好"}]
        reply = provider.generate_text(model_name, probe_messages, max_tokens=10)
        if reply:
            return {"success": True, "message": "生成模型连接成功"}
        raise HTTPException(status_code=500, detail="生成模型连接失败")

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"模型连接测试失败: {str(e)}")
|
||
|
||
@app.get("/knowledge_bases", response_class=HTMLResponse)
async def knowledge_bases_page():
    """Serve the knowledge-base management UI page."""
    with open("templates/knowledge_bases.html", "r", encoding="utf-8") as page:
        markup = page.read()
    return HTMLResponse(content=markup, status_code=200)
|
||
|
||
@app.get("/api/knowledge_bases")
async def list_knowledge_bases():
    """
    List all knowledge bases.

    Raises:
        HTTPException: 500 when the manager fails to enumerate them.
    """
    try:
        return {"knowledge_bases": kb_manager.list_knowledge_bases()}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"获取知识库列表失败: {str(e)}")
|
||
|
||
@app.post("/knowledge_bases")
async def create_knowledge_base(name: str = Form(...), description: str = Form("")):
    """
    Create a new, empty knowledge base.

    Raises:
        HTTPException: 500 when creation fails.
    """
    try:
        kb_manager.create_knowledge_base(name, description)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"创建知识库失败: {str(e)}")
    return {"message": f"知识库 {name} 创建成功"}
|
||
|
||
@app.delete("/knowledge_bases/{name}")
async def delete_knowledge_base(name: str):
    """
    Delete a knowledge base and its contents.

    Raises:
        HTTPException: 500 when deletion fails.
    """
    try:
        kb_manager.delete_knowledge_base(name)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"删除知识库失败: {str(e)}")
    return {"message": f"知识库 {name} 删除成功"}
|
||
|
||
@app.post("/knowledge_bases/{name}/upload")
async def upload_to_knowledge_base(name: str, files: List[UploadFile] = File(...)):
    """
    Upload TXT files into a specific knowledge base.

    Spools uploads to temp files, runs the segmentation pipeline, and
    adds the resulting segments to the named knowledge base.

    Raises:
        HTTPException: 404 for an unknown knowledge base, 400 for a
                       non-TXT upload, 500 for processing failures.
    """
    try:
        kb = kb_manager.get_knowledge_base(name)
        if not kb:
            raise HTTPException(status_code=404, detail=f"知识库 {name} 不存在")

        temp_files = []
        try:
            for file in files:
                if file.filename is None or not file.filename.endswith(".txt"):
                    raise HTTPException(status_code=400, detail=f"只支持TXT文件,发现文件: {file.filename}")

                with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as temp_file:
                    content = await file.read()
                    temp_file.write(content)
                    temp_files.append(temp_file.name)

            processed_segments, empty_count, duplicate_count = file_processor.process_files(temp_files)

            kb.add_documents(processed_segments)

            return {
                "message": f"文件上传到知识库 {name} 成功",
                "processed_segments": len(processed_segments),
                "empty_files_filtered": empty_count,
                "duplicate_files_filtered": duplicate_count
            }
        finally:
            # BUG FIX: cleanup used to run only on the success path,
            # leaking temp files whenever processing raised. Mirror the
            # try/finally pattern used by /upload.
            for temp_file in temp_files:
                if os.path.exists(temp_file):
                    os.unlink(temp_file)

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"上传文件到知识库失败: {str(e)}")
|
||
|
||
@app.post("/knowledge_bases/{name}/search")
async def search_knowledge_base(name: str, query: str = Form(...), k: int = Form(8)):
    """
    Search inside a specific knowledge base.

    Args:
        name: Knowledge base to search.
        query: Search query.
        k: Number of results to return.

    Raises:
        HTTPException: 404 for an unknown knowledge base, 500 otherwise.
    """
    try:
        kb = kb_manager.get_knowledge_base(name)
        if not kb:
            raise HTTPException(status_code=404, detail=f"知识库 {name} 不存在")

        hits = kb.search(query, k)

        # Flatten (segment, score) pairs into JSON-friendly dicts.
        formatted_results = [
            {
                "content": segment["content"],
                "metadata": segment["metadata"],
                "score": score,
            }
            for segment, score in hits
        ]

        return {
            "query": query,
            "knowledge_base": name,
            "results": formatted_results
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"知识库搜索失败: {str(e)}")
|
||
|
||
@app.post("/ask")
async def ask_question(
    query: str = Form(...),
    style: str = Form("通用文案"),
    min_length: int = Form(50),
    max_length: int = Form(500),
    knowledge_base: str = Form(...)
):
    """
    Answer a question grounded in a specific knowledge base.

    Searches the knowledge base, builds a context from the hits,
    generates a styled answer and runs hallucination checks on it.

    Raises:
        HTTPException: 404 for an unknown knowledge base or when no
                       relevant content is found, 500 otherwise.
    """
    try:
        kb = kb_manager.get_knowledge_base(knowledge_base)
        if not kb:
            raise HTTPException(status_code=404, detail=f"知识库 {knowledge_base} 不存在")

        # Retrieve supporting segments from the knowledge base.
        search_results = kb.search(query, settings.TOP_K)

        # BUG FIX: previously an empty result set silently produced an
        # answer from an empty context; fail fast like /generate does.
        if not search_results:
            raise HTTPException(status_code=404, detail="未找到相关文档内容")

        # Build the prompt context and the response's related-content list.
        context_parts = []
        related_content = []
        for result, score in search_results:
            context_parts.append(f"[{result['metadata']['segment_id']}] {result['content']}")
            related_content.append({
                "content": result["content"],
                "metadata": result["metadata"],
                "score": score
            })

        context = "\n\n".join(context_parts)

        # Generate the answer.
        generated_text = text_generator.generate_text(
            context=context,
            style=style,
            min_length=min_length,
            max_length=max_length
        )

        # Hallucination detection against the retrieved context.
        hallucination_keywords, hallucination_entities = text_generator.comprehensive_hallucination_check(
            generated_text, context
        )

        response_data = {
            "query": query,
            "answer": generated_text,
            "knowledge_base": knowledge_base,
            "related_content": related_content,
            "hallucination_warnings": {
                "keywords": hallucination_keywords,
                "entities": hallucination_entities
            }
        }

        return response_data

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"问答处理失败: {str(e)}")
|
||
|
||
@app.post("/export_markdown_qa")
async def export_qa_to_markdown(
    query: str = Form(...),
    generated_text: str = Form(...),
    knowledge_base: str = Form(None)
):
    """
    Export a QA result as a downloadable Markdown file.

    Raises:
        HTTPException: 500 when the export cannot be written.
    """
    try:
        # Assemble the Markdown document from its pieces.
        sections = [
            f"# 问答结果\n\n",
            f"**问题**: {query}\n\n",
            f"**知识库**: {knowledge_base or '默认'}\n\n",
            f"**回答**:\n\n{generated_text}\n\n",
            f"---\n\n",
            f"导出时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        ]
        export_content = "".join(sections)

        # Timestamped filename under the exports directory.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        file_name = f"问答结果_{timestamp}.md"
        file_path = os.path.join("exports", file_name)
        os.makedirs("exports", exist_ok=True)

        with open(file_path, "w", encoding="utf-8") as out:
            out.write(export_content)

        return FileResponse(file_path, media_type='application/octet-stream', filename=file_name)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"导出失败: {str(e)}")
|
||
|
||
@app.post("/export_docx_qa")
async def export_qa_to_docx(
    query: str = Form(...),
    generated_text: str = Form(...),
    knowledge_base: str = Form(None)
):
    """
    Export a QA result as a downloadable file.

    Despite the route name, plain text is written instead of DOCX to
    avoid a python-docx dependency.

    Raises:
        HTTPException: 500 when the export cannot be written.
    """
    try:
        # Assemble the plain-text document from its pieces.
        sections = [
            f"问答结果\n",
            f"========\n\n",
            f"问题: {query}\n\n",
            f"知识库: {knowledge_base or '默认'}\n\n",
            f"回答:\n{generated_text}\n\n",
            f"---\n\n",
            f"导出时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        ]
        export_content = "".join(sections)

        # Timestamped filename under the exports directory.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        file_name = f"问答结果_{timestamp}.txt"
        file_path = os.path.join("exports", file_name)
        os.makedirs("exports", exist_ok=True)

        with open(file_path, "w", encoding="utf-8") as out:
            out.write(export_content)

        return FileResponse(file_path, media_type='application/octet-stream', filename=file_name)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"TXT导出失败: {str(e)}")
|
||
|
||
@app.get("/system-info")
async def system_info():
    """Return host/system information from the health checker."""
    return health_checker.get_system_info()
|
||
|
||
if __name__ == "__main__":
    # Run the ASGI app directly when executed as a script.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)