nodebookls/generator.py

import json
import re
from typing import List, Dict, Optional, Tuple

from config import settings
from model_manager import model_router
from style_templates import style_manager
from exceptions import GenerationError


class TextGenerator:
    def __init__(self):
        self.client = self._initialize_client()

    def _initialize_client(self):
        """Initialize the API client. The model router manages provider clients, so no separate client is needed here."""
        return None

    def generate_text(self, context: str, style: str = "通用文案",
                      min_length: int = 50, max_length: int = 200,
                      history: Optional[List[Dict]] = None) -> str:
        """
        Generate text based on context and style

        Args:
            context: Retrieved context information
            style: Writing style (e.g., 小红书种草风, 官方通告, 知乎科普)
            min_length: Minimum text length
            max_length: Maximum text length
            history: Conversation history

        Returns:
            Generated text
        """
        # Make sure at least one model provider is configured
        if not model_router.providers:
            raise GenerationError("未配置任何模型提供商")

        try:
            # Get style template
            template_info = style_manager.get_template(style)
            prompt_template = template_info["template"]
            temperature = template_info["temperature"]

            # Build prompt
            prompt = prompt_template.format(
                context=context,
                min_length=min_length,
                max_length=max_length
            )

            # Prepare messages
            messages = []

            # Add history if provided
            if history:
                # Only keep the last 2 rounds of conversation
                recent_history = history[-4:]  # 2 rounds = 4 messages (user/assistant pairs)
                messages.extend(recent_history)

            # Add current prompt
            messages.append({"role": "user", "content": prompt})

            # Calculate max tokens based on the requested maximum length
            max_tokens = max_length * settings.MAX_TOKENS_FACTOR

            # Generate text using the model router
            content = model_router.generate_text(
                model=settings.GENERATION_MODEL,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )

            return content
        except Exception as e:
            raise GenerationError(f"文本生成失败: {str(e)}") from e

    def score_generation(self, generated_text: str, context: str, query: str) -> Dict:
        """
        Score the quality of generated text using GPT-4

        Args:
            generated_text: The generated text to score
            context: The source context
            query: The original query

        Returns:
            Dictionary containing score and feedback
        """
        # Make sure at least one model provider is configured
        if not model_router.providers:
            return {
                "score": 0,
                "feedback": "未配置任何模型提供商"
            }

        try:
            # Create scoring prompt
            prompt = f"""请对以下生成的文本进行评分,评估其与原始查询和上下文的一致性:
原始查询: {query}
上下文信息: {context}
生成文本: {generated_text}
请从以下维度进行评分(满分100分):
1. 相关性(30分):生成内容与查询的相关程度
2. 准确性(30分):生成内容与上下文信息的一致程度
3. 完整性(20分):是否充分回答了查询
4. 流畅性(20分):语言表达是否自然流畅
请提供:
- 总分(0-100)
- 各维度得分
- 简要反馈意见
- 改进建议
请以以下JSON格式返回结果:
{{
    "total_score": 85,
    "dimensions": {{
        "relevance": 25,
        "accuracy": 28,
        "completeness": 18,
        "fluency": 14
    }},
    "feedback": "生成内容与查询相关,但可以更详细...",
    "suggestions": "建议增加更多具体示例..."
}}
"""
            # Generate the score using the model router with GPT-4
            content = model_router.generate_text(
                model="gpt-4",
                messages=[
                    {"role": "system", "content": "你是一个专业的文本质量评估专家。"},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.1,
                max_tokens=500
            )

            if content:
                try:
                    # Try to parse the model output as JSON
                    score_data = json.loads(content)
                    return score_data
                except json.JSONDecodeError:
                    # If JSON parsing fails, return the raw output as feedback
                    return {
                        "score": 0,
                        "feedback": content
                    }
            else:
                return {
                    "score": 0,
                    "feedback": "评分生成失败"
                }
        except Exception as e:
            return {
                "score": 0,
                "feedback": f"评分生成错误: {str(e)}"
            }

    def detect_hallucination_keywords(self, text: str) -> List[str]:
        """
        Detect hallucination keywords in generated text

        Args:
            text: Generated text

        Returns:
            List of detected hallucination keywords
        """
        hallucination_keywords = [
            "据悉", "据报道", "研究表明", "据专家称", "有消息称",
            "据了解", "据分析", "据预测", "据估计", "据透露",
            "可能", "也许", "大概", "似乎", "看起来",
            "普遍认为", "大多数人认为", "通常情况下"
        ]

        detected = []
        for keyword in hallucination_keywords:
            if keyword in text:
                detected.append(keyword)
        return detected

    def detect_hallucination_entities(self, text: str, context: str) -> List[str]:
        """
        Detect hallucinated entities that don't appear in the context

        Args:
            text: Generated text
            context: Source context

        Returns:
            List of potentially hallucinated entities
        """
        # Simplified implementation; a production system could use proper named-entity recognition.
        # Extract candidate entities (runs of alphanumeric/CJK characters) from the generated text
        generated_entities = re.findall(r'[A-Za-z0-9\u4e00-\u9fff]{2,}', text)
        # Extract candidate entities from the context
        context_entities = set(re.findall(r'[A-Za-z0-9\u4e00-\u9fff]{2,}', context))

        # Collect entities that appear in the generated text but not in the context
        hallucinated_entities = []
        for entity in generated_entities:
            if entity not in context_entities and len(entity) > 2:
                # Filter out common function words
                common_words = {"可以", "能够", "通过", "进行", "提供", "支持", "包括", "以及", "或者", "但是", "然而", "因此", "所以"}
                if entity not in common_words:
                    hallucinated_entities.append(entity)
        return hallucinated_entities

    def comprehensive_hallucination_check(self, text: str, context: str) -> Tuple[List[str], List[str]]:
        """
        Comprehensive hallucination detection

        Args:
            text: Generated text
            context: Source context

        Returns:
            Tuple of (hallucination_keywords, hallucinated_entities)
        """
        keywords = self.detect_hallucination_keywords(text)
        entities = self.detect_hallucination_entities(text, context)
        return keywords, entities
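

# A minimal usage sketch (not part of the original module): it assumes that
# config.settings, model_manager.model_router, and style_templates.style_manager
# are already configured with at least one provider and a template for the
# default style. The query and context values below are hypothetical
# placeholders for illustration only.
if __name__ == "__main__":
    generator = TextGenerator()
    query = "介绍一款便携咖啡机"  # hypothetical user query
    context = "该咖啡机重量约500克,支持USB-C充电,30秒完成萃取。"  # hypothetical retrieved context

    try:
        # Generate a draft in the default style within the requested length range
        text = generator.generate_text(context=context, style="通用文案",
                                       min_length=50, max_length=200)
        print("Generated text:\n", text)

        # Run the lightweight hallucination checks on the draft
        keywords, entities = generator.comprehensive_hallucination_check(text, context)
        print("Hedging/hallucination keywords:", keywords)
        print("Entities missing from context:", entities)

        # Ask the scoring model for a quality assessment (returns a dict even on failure)
        score = generator.score_generation(text, context, query)
        print("Quality score:", score)
    except GenerationError as exc:
        print(f"Generation failed: {exc}")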