from typing import List, Dict, Optional, Tuple
import json
import re

from config import settings
from model_manager import model_router
from style_templates import style_manager
from exceptions import GenerationError


class TextGenerator:
    def __init__(self):
        self.client = self._initialize_client()

    def _initialize_client(self):
        """Initialize the API client. The model router manages provider clients, so no separate client is needed here."""
        return None

    def generate_text(self, context: str, style: str = "通用文案",
                      min_length: int = 50, max_length: int = 200,
                      history: Optional[List[Dict]] = None) -> str:
        """
        Generate text based on context and style.

        Args:
            context: Retrieved context information
            style: Writing style (e.g., 小红书种草风, 官方通告, 知乎科普)
            min_length: Minimum text length
            max_length: Maximum text length
            history: Conversation history

        Returns:
            Generated text
        """
        # Make sure at least one model provider is configured
        if not model_router.providers:
            raise GenerationError("未配置任何模型提供商")

        try:
            # Get style template
            template_info = style_manager.get_template(style)
            prompt_template = template_info["template"]
            temperature = template_info["temperature"]

            # Build prompt
            prompt = prompt_template.format(
                context=context,
                min_length=min_length,
                max_length=max_length
            )

            # Prepare messages
            messages = []

            # Add history if provided
            if history:
                # Only keep the last 2 rounds of conversation
                recent_history = history[-4:]  # 2 rounds = 4 messages (user/assistant pairs)
                messages.extend(recent_history)

            # Add current prompt
            messages.append({"role": "user", "content": prompt})

            # Calculate max tokens based on max length
            max_tokens = max_length * settings.MAX_TOKENS_FACTOR

            # Generate text using the model router
            content = model_router.generate_text(
                model=settings.GENERATION_MODEL,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )
            return content
        except Exception as e:
            raise GenerationError(f"文本生成失败: {str(e)}")

    def score_generation(self, generated_text: str, context: str, query: str) -> Dict:
        """
        Score the quality of generated text using GPT-4.

        Args:
            generated_text: The generated text to score
            context: The source context
            query: The original query

        Returns:
            Dictionary containing score and feedback
        """
        # Make sure at least one model provider is configured
        if not model_router.providers:
            return {
                "score": 0,
                "feedback": "未配置任何模型提供商"
            }

        try:
            # Create scoring prompt
            prompt = f"""请对以下生成的文本进行评分,评估其与原始查询和上下文的一致性:

原始查询: {query}

上下文信息: {context}

生成文本: {generated_text}

请从以下维度进行评分(满分100分):
1. 相关性(30分):生成内容与查询的相关程度
2. 准确性(30分):生成内容与上下文信息的一致程度
3. 完整性(20分):是否充分回答了查询
4. 流畅性(20分):语言表达是否自然流畅

请提供:
- 总分(0-100)
- 各维度得分
- 简要反馈意见
- 改进建议

请以以下JSON格式返回结果:
{{
    "total_score": 85,
    "dimensions": {{
        "relevance": 25,
        "accuracy": 28,
        "completeness": 18,
        "fluency": 14
    }},
    "feedback": "生成内容与查询相关,但可以更详细...",
    "suggestions": "建议增加更多具体示例..."
}}
"""

            # Generate the score using the model router with GPT-4
            content = model_router.generate_text(
                model="gpt-4",
                messages=[
                    {"role": "system", "content": "你是一个专业的文本质量评估专家。"},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.1,
                max_tokens=500
            )
            if content:
                try:
                    # Try to parse the response as JSON
                    score_data = json.loads(content)
                    return score_data
                except json.JSONDecodeError:
                    # If JSON parsing fails, return the raw response as feedback
                    return {
                        "score": 0,
                        "feedback": content
                    }
            else:
                return {
                    "score": 0,
                    "feedback": "评分生成失败"
                }

        except Exception as e:
            return {
                "score": 0,
                "feedback": f"评分生成错误: {str(e)}"
            }

    def detect_hallucination_keywords(self, text: str) -> List[str]:
        """
        Detect hallucination keywords in generated text.

        Args:
            text: Generated text

        Returns:
            List of detected hallucination keywords
        """
        hallucination_keywords = [
            "据悉", "据报道", "研究表明", "据专家称", "有消息称",
            "据了解", "据分析", "据预测", "据估计", "据透露",
            "可能", "也许", "大概", "似乎", "看起来",
            "普遍认为", "大多数人认为", "通常情况下"
        ]

        detected = []
        for keyword in hallucination_keywords:
            if keyword in text:
                detected.append(keyword)

        return detected

    def detect_hallucination_entities(self, text: str, context: str) -> List[str]:
        """
        Detect hallucinated entities that don't appear in the context.

        Args:
            text: Generated text
            context: Source context

        Returns:
            List of potentially hallucinated entities
        """
        # This is a simplified implementation; a real application could use proper
        # named-entity recognition instead of plain token extraction.

        # Extract candidate entities from the generated text (naive implementation)
        generated_entities = re.findall(r'[A-Za-z0-9\u4e00-\u9fff]{2,}', text)

        # Extract entities from the context
        context_entities = set(re.findall(r'[A-Za-z0-9\u4e00-\u9fff]{2,}', context))

        # Common function words that should not be treated as entities
        common_words = {"可以", "能够", "通过", "进行", "提供", "支持", "包括",
                        "以及", "或者", "但是", "然而", "因此", "所以"}

        # Collect entities that appear in the generated text but not in the context
        hallucinated_entities = []
        for entity in generated_entities:
            if entity not in context_entities and len(entity) > 2 and entity not in common_words:
                hallucinated_entities.append(entity)

        return hallucinated_entities

    def comprehensive_hallucination_check(self, text: str, context: str) -> Tuple[List[str], List[str]]:
        """
        Comprehensive hallucination detection.

        Args:
            text: Generated text
            context: Source context

        Returns:
            Tuple of (hallucination_keywords, hallucinated_entities)
        """
        keywords = self.detect_hallucination_keywords(text)
        entities = self.detect_hallucination_entities(text, context)
        return keywords, entities
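

# Minimal usage sketch (not part of the original module): it assumes config,
# model_manager, style_templates and exceptions are importable, that at least
# one model provider is configured in settings, and that "通用文案" exists as a
# style template. The context/query strings below are placeholders.
if __name__ == "__main__":
    generator = TextGenerator()
    try:
        sample_context = "示例上下文"  # placeholder retrieved context
        text = generator.generate_text(
            context=sample_context,
            style="通用文案",
            min_length=50,
            max_length=200
        )
        keywords, entities = generator.comprehensive_hallucination_check(text, sample_context)
        print(text)
        print("Hallucination keywords:", keywords)
        print("Potentially hallucinated entities:", entities)
    except GenerationError as err:
        print(f"Generation failed: {err}")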