166 lines
6.4 KiB
Python
166 lines
6.4 KiB
Python
import os
|
||
import json
|
||
import logging
|
||
from datetime import datetime
|
||
from typing import Dict, Any, List, Optional
|
||
import csv
|
||
|
||
|
||
class GenerationLogger:
    """Record text-generation events to JSON, CSV, and a plain-text log.

    Three files are maintained inside ``log_folder``:

    * ``generation_log.json`` -- a JSON array holding every log entry.
    * ``generation_log.csv``  -- one flattened row per entry, with a header.
    * ``generation.log``      -- standard ``logging`` output.

    NOTE(review): the standard-library logger is the process-wide
    ``'generation_logger'`` and its file handler is attached only when that
    logger has no handlers yet.  A second instance created with a different
    ``log_folder`` therefore still writes its plain-text messages to the
    first instance's ``generation.log``.
    """

    # Column order shared by the CSV header and every row in _write_csv_log.
    _CSV_HEADER = (
        'timestamp', 'query', 'generated_text', 'style',
        'total_score', 'relevance', 'accuracy', 'completeness', 'fluency',
        'hallucination_keywords', 'hallucination_entities',
    )

    def __init__(self, log_folder: str = "logs"):
        """Create the log folder (if needed) and prepare the log files.

        Args:
            log_folder: Directory that holds all log files; created if missing.
        """
        self.log_folder = log_folder
        # exist_ok avoids the race between an existence check and creation.
        os.makedirs(log_folder, exist_ok=True)

        # Paths of the structured log files.
        self.log_file = os.path.join(log_folder, "generation_log.json")
        self.csv_file = os.path.join(log_folder, "generation_log.csv")

        # Standard Python logging (a named, process-wide logger).
        self.logger = logging.getLogger('generation_logger')
        self.logger.setLevel(logging.INFO)

        # Attach the file handler only once per process (see class NOTE).
        if not self.logger.handlers:
            handler = logging.FileHandler(
                os.path.join(log_folder, 'generation.log'), encoding='utf-8'
            )
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)

        # Write the CSV header when the file is missing or still empty.
        if self._csv_needs_header():
            self._write_csv_header()

    def _csv_needs_header(self) -> bool:
        """Return True when the CSV file is missing or has no content yet."""
        return (not os.path.exists(self.csv_file)
                or os.path.getsize(self.csv_file) == 0)

    def _write_csv_header(self) -> None:
        """(Re)create the CSV file so that it contains only the header row."""
        with open(self.csv_file, 'w', newline='', encoding='utf-8') as f:
            csv.writer(f).writerow(self._CSV_HEADER)

    def log_generation(self, query: str, generated_text: str, style: str,
                       source_segments: List[Any], score: Optional[Dict[str, Any]] = None,
                       hallucination_warnings: Optional[Dict[str, List[Any]]] = None):
        """
        Log a generation event

        The event is written to the plain-text log (query truncated to 50
        characters), appended to the JSON log, and appended as a CSV row.

        Args:
            query: The query used for generation
            generated_text: The generated text
            style: The style used for generation
            source_segments: The source segments used; only their count is
                stored (``None`` is counted as 0)
            score: Optional score information; ``total_score`` and
                ``dimensions`` (relevance/accuracy/completeness/fluency) are
                copied into the CSV row
            hallucination_warnings: Optional hallucination warnings; the
                ``keywords`` and ``entities`` lists are copied into the CSV row

        Raises:
            TypeError: if ``score`` or ``hallucination_warnings`` contain values
                that are not JSON-serializable (the existing JSON log is left
                intact in that case).
        """
        # Plain-text log entry.
        self.logger.info(f"Generation request: {query[:50]}...")

        # Build the structured log entry.
        log_entry = {
            "timestamp": datetime.now().isoformat(),
            "query": query,
            "generated_text": generated_text,
            "style": style,
            # Only the number of segments is persisted, not their content.
            "source_segments_count": len(source_segments) if source_segments is not None else 0,
            "score": score,
            "hallucination_warnings": hallucination_warnings,
        }

        self._write_json_log(log_entry)
        self._write_csv_log(log_entry)

    def _load_json_logs(self, preserve_corrupt: bool = False) -> List[Dict[str, Any]]:
        """Return the list stored in the JSON log file, or [] if unavailable.

        Args:
            preserve_corrupt: When True, a file that cannot be decoded or does
                not contain a JSON array is renamed to
                ``<log_file>.corrupt-<timestamp>`` so that the next write does
                not overwrite (and lose) its contents.
        """
        if not os.path.exists(self.log_file):
            return []

        try:
            with open(self.log_file, 'r', encoding='utf-8') as f:
                logs = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            logs = None

        if isinstance(logs, list):
            return logs

        if preserve_corrupt:
            stamp = datetime.now().strftime('%Y%m%d%H%M%S%f')
            backup = f"{self.log_file}.corrupt-{stamp}"
            # Renamed only after the file handle is closed (required on Windows).
            os.replace(self.log_file, backup)
            self.logger.warning(
                "Unreadable JSON log moved to %s; starting a new log", backup
            )
        return []

    def _write_json_log(self, log_entry: Dict[str, Any]):
        """Append ``log_entry`` to the JSON array stored in ``self.log_file``.

        The whole array is serialized in memory before the disk is touched,
        then written to a temporary file that atomically replaces the log.
        If serialization fails, the exception propagates and the existing log
        file is left unchanged.
        """
        logs = self._load_json_logs(preserve_corrupt=True)
        logs.append(log_entry)

        # Serialize first: opening the target with 'w' and then failing inside
        # json.dump would truncate the file and lose every earlier entry.
        payload = json.dumps(logs, ensure_ascii=False, indent=2)

        tmp_path = self.log_file + ".tmp"
        with open(tmp_path, 'w', encoding='utf-8') as f:
            f.write(payload)
        os.replace(tmp_path, self.log_file)

    def _write_csv_log(self, log_entry: Dict[str, Any]):
        """Append one flattened row for ``log_entry`` to the CSV file.

        A missing or empty ``score`` / ``hallucination_warnings`` (or a missing
        or ``None`` key inside them) becomes an empty cell.  Warning items are
        converted with ``str`` before being joined with ", ".
        """
        score = log_entry.get("score") or {}
        dimensions = score.get("dimensions") or {}
        warnings = log_entry.get("hallucination_warnings") or {}

        def _joined(key: str) -> str:
            # str() so that non-string warning items do not break str.join.
            return ", ".join(str(item) for item in (warnings.get(key) or []))

        # Restore the header if the CSV file was deleted or emptied externally.
        if self._csv_needs_header():
            self._write_csv_header()

        with open(self.csv_file, 'a', newline='', encoding='utf-8') as f:
            csv.writer(f).writerow([
                log_entry["timestamp"],
                log_entry["query"],
                log_entry["generated_text"],
                log_entry["style"],
                score.get("total_score", ""),
                dimensions.get("relevance", ""),
                dimensions.get("accuracy", ""),
                dimensions.get("completeness", ""),
                dimensions.get("fluency", ""),
                _joined("keywords"),
                _joined("entities"),
            ])

    def get_logs(self, limit: int = 100) -> List[Dict[str, Any]]:
        """
        Get recent logs

        Args:
            limit: Maximum number of logs to return; values <= 0 return []

        Returns:
            List of log entries (most recent last); [] if the JSON log file is
            missing, unreadable, or does not contain a JSON array
        """
        # Guard explicitly: logs[-0:] would return the whole list.
        if limit <= 0:
            return []
        # Read-only access: never rename a corrupt file from here.
        return self._load_json_logs()[-limit:]

    def clear_logs(self):
        """Clear all logs

        The JSON log is reset to ``[]``, the CSV file keeps only its header
        row, and every ``FileHandler`` attached to the shared logger is
        truncated.
        """
        # Re-create the log folder if it was removed.
        os.makedirs(self.log_folder, exist_ok=True)

        # Empty the plain-text log files.
        for handler in self.logger.handlers:
            # A closed FileHandler has stream == None; skip it instead of crashing.
            if isinstance(handler, logging.FileHandler) and handler.stream is not None:
                handler.flush()
                handler.stream.truncate(0)

        # Opening with 'w' truncates, so no separate os.remove is needed.
        with open(self.log_file, 'w', encoding='utf-8') as f:
            json.dump([], f)

        # Re-create the CSV file with only its header.
        self._write_csv_header()

    def log_info(self, info_msg: str):
        """Write an informational message to the plain-text log."""
        self.logger.info(info_msg)
||
|
||
# Shared module-level instance.  NOTE: constructing it at import time creates
# the "logs" directory (relative to the current working directory) and its
# log files as a side effect of GenerationLogger.__init__.
logger = GenerationLogger(log_folder="logs")
|