159 lines
6.3 KiB
Python
159 lines
6.3 KiB
Python
import json
|
|
import os
|
|
from typing import Dict, Any
|
|
from exceptions import ConfigurationError
|
|
import time
|
|
|
|
|
|
class ConfigManager:
    """Load, validate, persist and hot-reload a JSON user configuration.

    The effective configuration is ``default_config`` overlaid with the keys
    found in ``config_file``. ``get_config`` / ``get_all_config`` reload the
    file first when its modification time is newer than the one recorded at
    the last load or save performed by this instance.
    """

    # Inclusive (low, high) bounds for numeric settings, checked in this order.
    _NUMERIC_RANGES = {
        "max_files": (1, 100),
        "top_k": (1, 50),
        "batch_size": (1, 100),
        "temperature": (0.0, 1.0),
        "bm25_weight": (0.0, 1.0),
    }

    def __init__(self, config_file: str = "user_config.json"):
        """Bind the manager to ``config_file`` and load it immediately.

        Args:
            config_file: Path of the JSON file holding user overrides. A
                missing or unreadable file is not an error; defaults are used.
        """
        self.config_file = config_file
        # mtime of config_file as of the last load/save; 0 means "never seen".
        self.last_modified = 0
        self.default_config = {
            # System settings
            "max_files": 20,
            "top_k": 8,
            "batch_size": 50,

            # Generation settings
            "default_style": "通用文案",
            "default_min_length": 50,
            "default_max_length": 200,
            "temperature": 0.25,

            # Retrieval settings
            "default_search_type": "hybrid",
            "bm25_weight": 0.5,
            "include_context": True,

            # Export settings
            "default_export_format": "markdown",

            # Logging settings
            "enable_logging": True,

            # API settings
            "embedding_model": "text-embedding-ada-002",
            "generation_model": "gpt-3.5-turbo",
            "embedding_provider": "openai",
            "generation_provider": "openai",
            "openai_api_base": "https://api.openai.com/v1",
            "anthropic_api_base": "https://api.anthropic.com/v1",
            "qwen_api_base": "https://dashscope.aliyuncs.com/api/v1",

            # Third-party model provider settings
            "openrouter_api_key": "",
            "openrouter_api_base": "https://openrouter.ai/api/v1",
            "siliconflow_api_key": "",
            "siliconflow_api_base": "https://api.siliconflow.cn/v1"
        }
        self.config = self.load_config()

    def _file_mtime(self) -> float:
        """Return config_file's mtime, or 0 if it is missing or cannot be stat'ed."""
        try:
            return os.path.getmtime(self.config_file)
        except OSError:
            # Also covers the file vanishing between an exists() check and stat.
            return 0

    def _write_atomically(self, text: str) -> None:
        """Write ``text`` to config_file via a sibling temp file and os.replace.

        Readers therefore see either the old file or the complete new one,
        never a partially written file.

        Raises:
            OSError: if the temp file cannot be written or moved into place.
        """
        tmp_path = f"{self.config_file}.tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                f.write(text)
            os.replace(tmp_path, self.config_file)
        except OSError:
            # Best-effort cleanup so a failed save doesn't leave a stray file.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def validate_config(self, config: Dict[str, Any]) -> bool:
        """Check the values of known keys in ``config``.

        Only keys present in ``config`` are checked; unknown keys are ignored.

        Returns:
            True when every checked value is valid.

        Raises:
            ConfigurationError: on a non-numeric or out-of-range numeric value,
                or an unsupported search type, export format or provider.
        """
        # Numeric ranges
        for key, (low, high) in self._NUMERIC_RANGES.items():
            if key not in config:
                continue
            value = config[key]
            # Without this check a string would raise TypeError from the
            # comparison, which callers handling ConfigurationError miss.
            # bool is an int subclass, so it is rejected explicitly.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                raise ConfigurationError(f"{key} must be a number")
            if not (low <= value <= high):
                raise ConfigurationError(f"{key} must be between {low} and {high}")

        # Enumerated values
        valid_search_types = ["hybrid", "vector"]
        if "default_search_type" in config and config["default_search_type"] not in valid_search_types:
            raise ConfigurationError(f"default_search_type must be one of {valid_search_types}")

        valid_export_formats = ["markdown", "docx"]
        if "default_export_format" in config and config["default_export_format"] not in valid_export_formats:
            raise ConfigurationError(f"default_export_format must be one of {valid_export_formats}")

        # Providers
        valid_providers = ["openai", "anthropic", "qwen", "openrouter", "siliconflow"]
        for key in ("embedding_provider", "generation_provider"):
            if key in config and config[key] not in valid_providers:
                raise ConfigurationError(f"{key} must be one of {valid_providers}")

        return True

    def check_config_update(self):
        """Reload the config if the file's mtime is newer than the last load/save.

        Returns:
            True if the configuration was reloaded, otherwise False.
        """
        if self._file_mtime() > self.last_modified:
            self.config = self.load_config()
            return True
        return False

    def load_config(self) -> Dict[str, Any]:
        """Read config_file and fill in any missing keys from the defaults.

        Side effect: records the file's current mtime in ``last_modified``.

        Returns:
            A new dict holding the file's keys plus every default key the file
            lacks. If the file is missing, unreadable, not valid UTF-8/JSON, or
            its top level is not a JSON object, returns a copy of the defaults.
        """
        # Record the mtime *before* reading: if the file is rewritten while we
        # read it, its newer mtime will trigger a reload on the next check.
        self.last_modified = self._file_mtime()
        if not os.path.exists(self.config_file):
            return self.default_config.copy()
        try:
            with open(self.config_file, 'r', encoding='utf-8') as f:
                config = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return self.default_config.copy()
        if not isinstance(config, dict):
            # e.g. a top-level JSON array or string: treat as a corrupt file.
            return self.default_config.copy()
        for key, value in self.default_config.items():
            config.setdefault(key, value)
        return config

    def save_config(self, config: Dict[str, Any]) -> bool:
        """Validate ``config``, merge it over the defaults and write it to disk.

        On success ``self.config`` becomes the merged dict. On failure an error
        is printed and ``self.config`` and the file on disk are left unchanged.

        Returns:
            True if the configuration was saved, False otherwise.
        """
        try:
            self.validate_config(config)
            # Merge so that no default key is ever lost from the saved file.
            merged_config = self.default_config.copy()
            merged_config.update(config)
            # Serialize before touching the file: an unserializable value
            # (TypeError) or circular structure (ValueError) must not leave a
            # truncated config behind.
            payload = json.dumps(merged_config, ensure_ascii=False, indent=2)
            self._write_atomically(payload)
        except (OSError, TypeError, ValueError, ConfigurationError) as e:
            print(f"保存配置失败: {e}")
            return False
        self.config = merged_config
        # Our own write shouldn't count as an external change to reload.
        self.last_modified = self._file_mtime()
        return True

    def get_config(self, key: str, default=None):
        """Return the value of ``key``, or ``default`` if absent.

        The file is reloaded first if it changed since the last load/save.
        """
        self.check_config_update()
        return self.config.get(key, default)

    def set_config(self, key: str, value: Any) -> bool:
        """Set ``key`` to ``value`` and persist the whole configuration.

        The in-memory config is only replaced if the save succeeds, so an
        invalid or unserializable value is never left in ``self.config``.

        Returns:
            True if saved, False if validation or writing failed.
        """
        # Pick up external edits first so they are not overwritten with stale data.
        self.check_config_update()
        candidate = dict(self.config)
        candidate[key] = value
        return self.save_config(candidate)

    def get_all_config(self) -> Dict[str, Any]:
        """Return a shallow copy of the full configuration, reloading if the file changed."""
        self.check_config_update()
        return self.config.copy()

    def reset_to_default(self) -> bool:
        """Overwrite the saved configuration with the defaults.

        Returns:
            True if saved; on failure the in-memory config is left unchanged.
        """
        return self.save_config(self.default_config.copy())
|
|
|
|
|
|
# Module-wide shared ConfigManager instance. Creating it loads the default
# "user_config.json" (a relative path, so resolved against the current working
# directory) at import time.
config_manager = ConfigManager()