nodebookls/generator.py
2025-10-29 13:56:24 +08:00

232 lines
8.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import re
from typing import List, Dict, Optional, Tuple

from config import settings
from exceptions import GenerationError
from model_manager import model_router
from style_templates import style_manager
class TextGenerator:
    """Generate styled text through the shared model router, score its
    quality, and run lightweight hallucination checks.

    All LLM calls go through ``model_router``; no per-instance API client
    is required.
    """

    # Hedging phrases that often signal unsourced claims in generated text.
    # Runtime data: the Chinese keywords are matched verbatim against output.
    _HALLUCINATION_KEYWORDS = [
        "据悉", "据报道", "研究表明", "据专家称", "有消息称",
        "据了解", "据分析", "据预测", "据估计", "据透露",
        "可能", "也许", "大概", "似乎", "看起来",
        "普遍认为", "大多数人认为", "通常情况下"
    ]

    # Generic function words excluded from entity-level hallucination checks.
    # Hoisted to a class constant: the original rebuilt this set on every
    # loop iteration inside detect_hallucination_entities.
    _COMMON_WORDS = frozenset({
        "可以", "能够", "通过", "进行", "提供", "支持", "包括",
        "以及", "或者", "但是", "然而", "因此", "所以"
    })

    # Runs of 2+ ASCII alphanumerics or CJK ideographs — a crude stand-in
    # for real entity recognition. Compiled once instead of per call.
    _ENTITY_RE = re.compile(r'[A-Za-z0-9\u4e00-\u9fff]{2,}')

    def __init__(self):
        # Kept for backward compatibility; always None (see below).
        self.client = self._initialize_client()

    def _initialize_client(self):
        """Return None — the model router handles provider clients, so no
        dedicated API client is initialized here."""
        return None

    def generate_text(self, context: str, style: str = "通用文案",
                      min_length: int = 50, max_length: int = 200,
                      history: Optional[List[Dict]] = None) -> str:
        """Generate text based on retrieved context and a writing style.

        Args:
            context: Retrieved context information.
            style: Writing style template name (e.g. 小红书种草风, 官方通告,
                知乎科普); defaults to the generic template.
            min_length: Minimum desired text length.
            max_length: Maximum desired text length.
            history: Optional prior conversation messages.

        Returns:
            The generated text.

        Raises:
            GenerationError: If no model provider is configured or the
                underlying generation call fails.
        """
        # Fail fast when no model provider is configured.
        if not model_router.providers:
            raise GenerationError("未配置任何模型提供商")
        try:
            template_info = style_manager.get_template(style)
            prompt = template_info["template"].format(
                context=context,
                min_length=min_length,
                max_length=max_length,
            )

            messages: List[Dict] = []
            if history:
                # Keep only the last 2 rounds (4 messages: user/assistant pairs)
                # to bound prompt size.
                messages.extend(history[-4:])
            messages.append({"role": "user", "content": prompt})

            # Token budget scales with the requested maximum length.
            max_tokens = max_length * settings.MAX_TOKENS_FACTOR

            return model_router.generate_text(
                model=settings.GENERATION_MODEL,
                messages=messages,
                temperature=template_info["temperature"],
                max_tokens=max_tokens,
            )
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise GenerationError(f"文本生成失败: {str(e)}") from e

    def score_generation(self, generated_text: str, context: str, query: str) -> Dict:
        """Score the quality of generated text using GPT-4.

        Args:
            generated_text: The generated text to score.
            context: The source context.
            query: The original query.

        Returns:
            A dict with the parsed score payload on success, otherwise
            ``{"score": 0, "feedback": <reason>}``.
        """
        # Without a provider, return a zero score rather than raising —
        # scoring is best-effort.
        if not model_router.providers:
            return {
                "score": 0,
                "feedback": "未配置任何模型提供商"
            }
        try:
            # Scoring rubric prompt (runtime string — kept verbatim, flush
            # left so the model sees no stray indentation).
            prompt = f"""请对以下生成的文本进行评分,评估其与原始查询和上下文的一致性:
原始查询: {query}
上下文信息: {context}
生成文本: {generated_text}
请从以下维度进行评分满分100分
1. 相关性30分生成内容与查询的相关程度
2. 准确性30分生成内容与上下文信息的一致程度
3. 完整性20分是否充分回答了查询
4. 流畅性20分语言表达是否自然流畅
请提供:
- 总分0-100
- 各维度得分
- 简要反馈意见
- 改进建议
请以以下JSON格式返回结果
{{
"total_score": 85,
"dimensions": {{
"relevance": 25,
"accuracy": 28,
"completeness": 18,
"fluency": 14
}},
"feedback": "生成内容与查询相关,但可以更详细...",
"suggestions": "建议增加更多具体示例..."
}}
"""
            content = model_router.generate_text(
                model="gpt-4",
                messages=[
                    {"role": "system", "content": "你是一个专业的文本质量评估专家。"},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.1,
                max_tokens=500
            )
            if not content:
                return {
                    "score": 0,
                    "feedback": "评分生成失败"
                }

            # Models frequently wrap JSON in markdown code fences; strip
            # them so the common case parses instead of falling through.
            cleaned = content.strip()
            if cleaned.startswith("```"):
                cleaned = cleaned.strip("`").strip()
                if cleaned[:4].lower() == "json":
                    cleaned = cleaned[4:].strip()
            try:
                return json.loads(cleaned)
            except json.JSONDecodeError:
                # Unparseable output: surface the raw text as feedback.
                return {
                    "score": 0,
                    "feedback": content
                }
        except Exception as e:
            return {
                "score": 0,
                "feedback": f"评分生成错误: {str(e)}"
            }

    def detect_hallucination_keywords(self, text: str) -> List[str]:
        """Detect hallucination keywords in generated text.

        Args:
            text: Generated text.

        Returns:
            The hedging keywords found in *text*, in canonical list order
            (duplicates impossible — each keyword reported at most once).
        """
        return [kw for kw in self._HALLUCINATION_KEYWORDS if kw in text]

    def detect_hallucination_entities(self, text: str, context: str) -> List[str]:
        """Detect entities in *text* that do not appear in *context*.

        A simplified heuristic: "entities" are maximal alphanumeric/CJK
        runs of length > 2, minus common function words.

        Args:
            text: Generated text.
            context: Source context.

        Returns:
            Potentially hallucinated entities, in order of appearance
            (repeated occurrences are repeated in the result, matching the
            original behaviour).
        """
        # A set gives O(1) membership; the original scanned a list per entity.
        context_entities = set(self._ENTITY_RE.findall(context))
        return [
            entity
            for entity in self._ENTITY_RE.findall(text)
            if len(entity) > 2
            and entity not in context_entities
            and entity not in self._COMMON_WORDS
        ]

    def comprehensive_hallucination_check(self, text: str, context: str) -> Tuple[List[str], List[str]]:
        """Run both hallucination checks.

        Args:
            text: Generated text.
            context: Source context.

        Returns:
            Tuple of (hallucination_keywords, hallucinated_entities).
        """
        keywords = self.detect_hallucination_keywords(text)
        entities = self.detect_hallucination_entities(text, context)
        return keywords, entities