nodebookls/config_manager.py
2025-10-29 13:56:24 +08:00

159 lines
6.3 KiB
Python

import json
import os
from typing import Dict, Any
from exceptions import ConfigurationError
import time
class ConfigManager:
    """Load, validate, persist and hot-reload user settings stored as JSON.

    Settings live in ``config_file``. Any key missing from the file is filled
    in from ``default_config``. Reads through :meth:`get_config` and
    :meth:`get_all_config` first reload the file if its modification time has
    changed since it was last seen.
    """

    # Inclusive (min, max) bounds for numeric settings. Bounds are written as
    # int or float on purpose so error messages render as "1 and 100" or
    # "0.0 and 1.0".
    _NUMERIC_RANGES = {
        "max_files": (1, 100),
        "top_k": (1, 50),
        "batch_size": (1, 100),
        "temperature": (0.0, 1.0),
        "bm25_weight": (0.0, 1.0),
    }

    _VALID_PROVIDERS = ["openai", "anthropic", "qwen", "openrouter", "siliconflow"]

    # Allowed values for enumerated settings. These are kept as lists so
    # error messages show the list repr, e.g. "['hybrid', 'vector']".
    _CHOICES = {
        "default_search_type": ["hybrid", "vector"],
        "default_export_format": ["markdown", "docx"],
        "embedding_provider": _VALID_PROVIDERS,
        "generation_provider": _VALID_PROVIDERS,
    }

    def __init__(self, config_file: str = "user_config.json"):
        """Create a manager for *config_file* and load it (or the defaults)."""
        self.config_file = config_file
        self.default_config = {
            # System settings
            "max_files": 20,
            "top_k": 8,
            "batch_size": 50,
            # Generation settings
            "default_style": "通用文案",
            "default_min_length": 50,
            "default_max_length": 200,
            "temperature": 0.25,
            # Retrieval settings
            "default_search_type": "hybrid",
            "bm25_weight": 0.5,
            "include_context": True,
            # Export settings
            "default_export_format": "markdown",
            # Logging settings
            "enable_logging": True,
            # API settings
            "embedding_model": "text-embedding-ada-002",
            "generation_model": "gpt-3.5-turbo",
            "embedding_provider": "openai",
            "generation_provider": "openai",
            "openai_api_base": "https://api.openai.com/v1",
            "anthropic_api_base": "https://api.anthropic.com/v1",
            "qwen_api_base": "https://dashscope.aliyuncs.com/api/v1",
            # Third-party model provider settings
            "openrouter_api_key": "",
            "openrouter_api_base": "https://openrouter.ai/api/v1",
            "siliconflow_api_key": "",
            "siliconflow_api_base": "https://api.siliconflow.cn/v1"
        }
        # Record the mtime *before* reading. If the file changes between the
        # stat and the read, the next update check sees a different mtime and
        # reloads, instead of missing the change.
        mtime = self._get_mtime()
        self.last_modified = mtime if mtime is not None else 0
        self.config = self.load_config()

    def _get_mtime(self):
        """Return the config file's modification time, or None if it cannot be stat'ed."""
        try:
            return os.path.getmtime(self.config_file)
        except OSError:
            return None

    def validate_config(self, config: Dict[str, Any]) -> bool:
        """Check that every recognised setting in *config* is valid.

        Keys not listed in the numeric-range or choice tables are ignored.

        Returns:
            True if all checks pass.

        Raises:
            ConfigurationError: a numeric setting is not a number or is out of
                range, or an enumerated setting has an unsupported value.
        """
        for key, (low, high) in self._NUMERIC_RANGES.items():
            if key not in config:
                continue
            value = config[key]
            # bool is a subclass of int, so reject it explicitly. Any other
            # non-number would make the comparison below raise TypeError.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                raise ConfigurationError(f"{key} must be a number")
            if not (low <= value <= high):
                raise ConfigurationError(f"{key} must be between {low} and {high}")
        for key, allowed in self._CHOICES.items():
            if key in config and config[key] not in allowed:
                raise ConfigurationError(f"{key} must be one of {allowed}")
        return True

    def check_config_update(self) -> bool:
        """Reload the config file if its mtime differs from the last one seen.

        Returns:
            True if the file was reloaded. False otherwise, including when the
            file does not exist or cannot be stat'ed.
        """
        mtime = self._get_mtime()
        # Use != rather than >: a file restored from a backup, or copied with
        # preserved timestamps, can have an *older* mtime.
        if mtime is None or mtime == self.last_modified:
            return False
        self.last_modified = mtime
        self.config = self.load_config()
        return True

    def load_config(self) -> Dict[str, Any]:
        """Read the config file and return it merged over the defaults.

        Returns a fresh copy of ``default_config`` when the file is missing,
        unreadable, not valid UTF-8 JSON, or its top level is not an object.
        Individual known values that fail :meth:`validate_config` are replaced
        by their defaults (and reported on stdout), so one bad entry does not
        block every later save.
        """
        try:
            with open(self.config_file, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except (OSError, ValueError):
            # OSError covers a missing or unreadable file. ValueError covers
            # both json.JSONDecodeError and UnicodeDecodeError.
            return self.default_config.copy()
        if not isinstance(loaded, dict):
            return self.default_config.copy()

        config = loaded
        for key, default_value in self.default_config.items():
            if key not in config:
                config[key] = default_value
                continue
            try:
                self.validate_config({key: config[key]})
            except ConfigurationError as e:
                print(f"配置项 {key} 无效, 已使用默认值: {e}")
                config[key] = default_value
        return config

    def save_config(self, config: Dict[str, Any]) -> bool:
        """Validate *config*, merge it over the defaults and write it to disk.

        On success the merged dict becomes the in-memory config and True is
        returned. On failure (validation error, unserialisable value, or I/O
        error) a message is printed, the in-memory config is left untouched,
        and False is returned.
        """
        try:
            self.validate_config(config)
            # Merge over the defaults so no known setting is ever dropped.
            merged_config = self.default_config.copy()
            merged_config.update(config)
            # Serialise before opening: opening with 'w' truncates the file, so
            # a serialisation error mid-dump would leave it corrupt.
            payload = json.dumps(merged_config, ensure_ascii=False, indent=2)
            with open(self.config_file, 'w', encoding='utf-8') as f:
                f.write(payload)
        except (OSError, TypeError, ValueError, ConfigurationError) as e:
            print(f"保存配置失败: {e}")
            return False
        self.config = merged_config
        # Remember our own write so the next read does not reload it.
        mtime = self._get_mtime()
        if mtime is not None:
            self.last_modified = mtime
        return True

    def get_config(self, key: str, default=None):
        """Return setting *key* (or *default* if absent), reloading the file first if it changed."""
        self.check_config_update()
        return self.config.get(key, default)

    def set_config(self, key: str, value: Any) -> bool:
        """Set one setting and persist the whole config.

        The change is applied to a copy, so a failed save (e.g. an invalid
        value) leaves the in-memory config unchanged.
        """
        # Pick up external edits first so saving does not overwrite them with
        # a stale in-memory copy.
        self.check_config_update()
        candidate = self.config.copy()
        candidate[key] = value
        return self.save_config(candidate)

    def get_all_config(self) -> Dict[str, Any]:
        """Return a shallow copy of all settings, reloading the file first if it changed."""
        self.check_config_update()
        return self.config.copy()

    def reset_to_default(self) -> bool:
        """Overwrite the config file with the defaults.

        The in-memory config switches to the defaults only if the save succeeds.
        """
        return self.save_config(self.default_config.copy())
# Shared, module-wide ConfigManager instance. It is constructed at import
# time with the default config_file ("user_config.json"), so that file is
# loaded on import if it exists.
config_manager: ConfigManager = ConfigManager()