232 lines
8.0 KiB
Python
232 lines
8.0 KiB
Python
|
|
import json
import re
from typing import List, Dict, Optional, Tuple

from config import settings
from exceptions import GenerationError
from model_manager import model_router
from style_templates import style_manager
|
|||
|
|
|
|||
|
|
class TextGenerator:
    """Generate styled text via the model router, with quality scoring and
    lightweight hallucination detection.

    All LLM calls are dispatched through ``model_router``; no per-instance
    API client is held (``self.client`` is always ``None`` and kept only for
    backward compatibility).
    """

    def __init__(self):
        # Kept for backward compatibility with callers that inspect
        # ``generator.client``; always None (see _initialize_client).
        self.client = self._initialize_client()

    def _initialize_client(self):
        """Initialize API client — the model router is used instead, so no
        dedicated client is needed. Always returns None."""
        return None

    def generate_text(self, context: str, style: str = "通用文案",
                      min_length: int = 50, max_length: int = 200,
                      history: Optional[List[Dict]] = None) -> str:
        """
        Generate text based on context and style.

        Args:
            context: Retrieved context information
            style: Writing style (e.g., 小红书种草风, 官方通告, 知乎科普)
            min_length: Minimum text length
            max_length: Maximum text length
            history: Conversation history as {"role", "content"} dicts
                (presumably OpenAI chat format — TODO confirm against caller)

        Returns:
            Generated text

        Raises:
            GenerationError: If no model provider is configured, or if any
                step of prompt building / generation fails.
        """
        # Fail fast when no model provider has been configured.
        if not model_router.providers:
            raise GenerationError("未配置任何模型提供商")

        try:
            # Resolve the prompt template and sampling temperature
            # associated with the requested style.
            template_info = style_manager.get_template(style)
            temperature = template_info["temperature"]

            # Build the concrete prompt from the style template.
            prompt = template_info["template"].format(
                context=context,
                min_length=min_length,
                max_length=max_length,
            )

            messages = []
            if history:
                # Only keep the last 2 rounds of conversation
                # (2 rounds = 4 messages: user/assistant pairs).
                messages.extend(history[-4:])
            messages.append({"role": "user", "content": prompt})

            # Token budget scales with the requested maximum text length.
            max_tokens = max_length * settings.MAX_TOKENS_FACTOR

            content = model_router.generate_text(
                model=settings.GENERATION_MODEL,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            return content
        except Exception as e:
            # Chain the original exception so the root cause is preserved
            # in tracebacks (the original swallowed it).
            raise GenerationError(f"文本生成失败: {str(e)}") from e

    def score_generation(self, generated_text: str, context: str, query: str) -> Dict:
        """
        Score the quality of generated text using GPT-4.

        Args:
            generated_text: The generated text to score
            context: The source context
            query: The original query

        Returns:
            On success, the model's scoring JSON (keys "total_score",
            "dimensions", "feedback", "suggestions"). On any failure, a
            fallback dict with "total_score"/"score" of 0 and a "feedback"
            message ("score" is kept for backward compatibility with the
            original failure shape). This method never raises.
        """
        # Degrade gracefully (no exception) when no provider is configured,
        # mirroring this method's best-effort contract.
        if not model_router.providers:
            return {
                "total_score": 0,
                "score": 0,
                "feedback": "未配置任何模型提供商",
            }

        try:
            # Scoring rubric prompt; asks the model to reply in JSON.
            prompt = f"""请对以下生成的文本进行评分,评估其与原始查询和上下文的一致性:

原始查询: {query}

上下文信息: {context}

生成文本: {generated_text}

请从以下维度进行评分(满分100分):
1. 相关性(30分):生成内容与查询的相关程度
2. 准确性(30分):生成内容与上下文信息的一致程度
3. 完整性(20分):是否充分回答了查询
4. 流畅性(20分):语言表达是否自然流畅

请提供:
- 总分(0-100)
- 各维度得分
- 简要反馈意见
- 改进建议

请以以下JSON格式返回结果:
{{
    "total_score": 85,
    "dimensions": {{
        "relevance": 25,
        "accuracy": 28,
        "completeness": 18,
        "fluency": 14
    }},
    "feedback": "生成内容与查询相关,但可以更详细...",
    "suggestions": "建议增加更多具体示例..."
}}
"""

            content = model_router.generate_text(
                model="gpt-4",
                messages=[
                    {"role": "system", "content": "你是一个专业的文本质量评估专家。"},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.1,  # low temperature for reproducible scoring
                max_tokens=500,
            )

            if not content:
                return {
                    "total_score": 0,
                    "score": 0,
                    "feedback": "评分生成失败",
                }

            try:
                return json.loads(content)
            except json.JSONDecodeError:
                # Model did not return valid JSON; surface the raw text
                # as feedback instead of losing it.
                return {
                    "total_score": 0,
                    "score": 0,
                    "feedback": content,
                }
        except Exception as e:
            return {
                "total_score": 0,
                "score": 0,
                "feedback": f"评分生成错误: {str(e)}",
            }

    def detect_hallucination_keywords(self, text: str) -> List[str]:
        """
        Detect hallucination-indicator keywords in generated text.

        These are hedging / unattributed-source phrases (e.g. "据悉",
        "研究表明") that typically signal ungrounded claims.

        Args:
            text: Generated text

        Returns:
            List of detected keywords, in rubric order.
        """
        hallucination_keywords = [
            "据悉", "据报道", "研究表明", "据专家称", "有消息称",
            "据了解", "据分析", "据预测", "据估计", "据透露",
            "可能", "也许", "大概", "似乎", "看起来",
            "普遍认为", "大多数人认为", "通常情况下"
        ]
        # Substring scan preserves the rubric ordering of the keyword list.
        return [kw for kw in hallucination_keywords if kw in text]

    def detect_hallucination_entities(self, text: str, context: str) -> List[str]:
        """
        Detect candidate entities in the generated text that do not appear
        in the context.

        This is a simplified implementation (runs of 2+ alphanumeric/CJK
        characters stand in for entities); a real system would use proper
        NER.

        Args:
            text: Generated text
            context: Source context

        Returns:
            List of potentially hallucinated entities (duplicates kept,
            in order of appearance).
        """
        entity_pattern = r'[A-Za-z0-9\u4e00-\u9fff]{2,}'

        # Candidate entities from the generated text.
        generated_entities = re.findall(entity_pattern, text)

        # Build a set once for O(1) membership tests (the original scanned
        # a list per entity, O(n*m)).
        context_entities = set(re.findall(entity_pattern, context))

        # Generic connective words that should never be flagged; hoisted out
        # of the loop (the original rebuilt this literal every iteration).
        common_words = {
            "可以", "能够", "通过", "进行", "提供", "支持", "包括",
            "以及", "或者", "但是", "然而", "因此", "所以",
        }

        hallucinated_entities = []
        for entity in generated_entities:
            # NOTE: len > 2 means 2-character entities are never flagged,
            # even though the pattern matches them — preserved from the
            # original behavior.
            if (entity not in context_entities
                    and len(entity) > 2
                    and entity not in common_words):
                hallucinated_entities.append(entity)

        return hallucinated_entities

    def comprehensive_hallucination_check(self, text: str, context: str) -> Tuple[List[str], List[str]]:
        """
        Run both hallucination detectors.

        Args:
            text: Generated text
            context: Source context

        Returns:
            Tuple of (hallucination_keywords, hallucinated_entities).
        """
        keywords = self.detect_hallucination_keywords(text)
        entities = self.detect_hallucination_entities(text, context)
        return keywords, entities
|