import json
import os
import tempfile
import time
from typing import Any, Dict, Optional

from exceptions import ConfigurationError


class ConfigManager:
    """Load, validate, persist and hot-reload the user's JSON configuration.

    The in-memory ``config`` dict always contains every key of
    ``default_config``; extra keys found in the user's file are preserved.
    """

    # Numeric settings and their inclusive (low, high) bounds, checked in order.
    _RANGE_RULES = (
        ("max_files", 1, 100),
        ("top_k", 1, 50),
        ("batch_size", 1, 100),
        ("temperature", 0.0, 1.0),
        ("bm25_weight", 0.0, 1.0),
    )

    _VALID_PROVIDERS = ["openai", "anthropic", "qwen", "openrouter", "siliconflow"]

    # Enumerated settings and their allowed values, checked in order.
    _CHOICE_RULES = (
        ("default_search_type", ["hybrid", "vector"]),
        ("default_export_format", ["markdown", "docx"]),
        ("embedding_provider", _VALID_PROVIDERS),
        ("generation_provider", _VALID_PROVIDERS),
    )

    def __init__(self, config_file: str = "user_config.json"):
        """Create a manager backed by *config_file* and load it immediately.

        Args:
            config_file: Path of the JSON file used for persistence.
        """
        self.config_file = config_file
        # mtime of the file contents currently held in self.config.
        self.last_modified = 0
        self.default_config = {
            # System settings
            "max_files": 20,
            "top_k": 8,
            "batch_size": 50,
            # Generation settings
            "default_style": "通用文案",
            "default_min_length": 50,
            "default_max_length": 200,
            "temperature": 0.25,
            # Retrieval settings
            "default_search_type": "hybrid",
            "bm25_weight": 0.5,
            "include_context": True,
            # Export settings
            "default_export_format": "markdown",
            # Logging settings
            "enable_logging": True,
            # API settings
            "embedding_model": "text-embedding-ada-002",
            "generation_model": "gpt-3.5-turbo",
            "embedding_provider": "openai",
            "generation_provider": "openai",
            "openai_api_base": "https://api.openai.com/v1",
            "anthropic_api_base": "https://api.anthropic.com/v1",
            "qwen_api_base": "https://dashscope.aliyuncs.com/api/v1",
            # Third-party model provider settings
            "openrouter_api_key": "",
            "openrouter_api_base": "https://openrouter.ai/api/v1",
            "siliconflow_api_key": "",
            "siliconflow_api_base": "https://api.siliconflow.cn/v1",
        }
        # Read the mtime *before* loading so a concurrent edit made during the
        # load is still detected by the next check_config_update().
        mtime = self._get_mtime()
        self.config = self.load_config()
        if mtime is not None:
            self.last_modified = mtime

    def validate_config(self, config: Dict[str, Any]) -> bool:
        """Validate the recognised keys present in *config*.

        Keys that are absent, or not covered by a rule, are not checked.

        Returns:
            True when every checked key is valid.

        Raises:
            ConfigurationError: a numeric key is not a number in its range, or
                an enumerated key is not one of its allowed values.
        """
        for key, low, high in self._RANGE_RULES:
            if key not in config:
                continue
            value = config[key]
            # bool is a subclass of int; a flag is never a valid count/weight.
            # Non-numbers are rejected here instead of raising TypeError below.
            is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
            if not is_number or not (low <= value <= high):
                raise ConfigurationError(f"{key} must be between {low} and {high}")

        for key, choices in self._CHOICE_RULES:
            if key in config and config[key] not in choices:
                raise ConfigurationError(f"{key} must be one of {choices}")

        return True

    def _get_mtime(self) -> Optional[float]:
        """Return the config file's modification time, or None if it cannot be read."""
        try:
            return os.path.getmtime(self.config_file)
        except OSError:
            # Missing file, or it was removed between checks.
            return None

    def check_config_update(self) -> bool:
        """Reload the config if the file changed since it was last read.

        Returns:
            True if the file was newer and has been reloaded, otherwise False.
        """
        mtime = self._get_mtime()
        if mtime is None or mtime <= self.last_modified:
            return False
        self.last_modified = mtime
        self.config = self.load_config()
        return True

    def load_config(self) -> Dict[str, Any]:
        """Read the user config from disk, filling gaps with defaults.

        Falls back to a copy of ``default_config`` when the file is missing,
        unreadable, not valid UTF-8 JSON, or not a JSON object. Any recognised
        key whose stored value fails validation is replaced by its default, so
        a bad hand edit cannot block later ``set_config``/``save_config`` calls.
        """
        try:
            with open(self.config_file, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return self.default_config.copy()

        if not isinstance(loaded, dict):
            return self.default_config.copy()

        config = dict(loaded)
        # Make sure every default key is present.
        for key, value in self.default_config.items():
            config.setdefault(key, value)

        # Every validated key is a default key, so checking those is enough.
        for key in self.default_config:
            try:
                self.validate_config({key: config[key]})
            except ConfigurationError as e:
                print(f"配置项 {key} 无效，已使用默认值: {e}")
                config[key] = self.default_config[key]
        return config

    def _write_atomically(self, data: Dict[str, Any]) -> None:
        """Write *data* as JSON next to the target file, then swap it in.

        A crash mid-write leaves the previous file intact rather than a
        truncated one. Note that tempfile.mkstemp creates the file readable and
        writable only by the current user, so the saved config gets those
        permissions.
        """
        directory = os.path.dirname(os.path.abspath(self.config_file))
        fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".config-", suffix=".tmp")
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, self.config_file)
        except BaseException:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def save_config(self, config: Dict[str, Any]) -> bool:
        """Validate *config*, merge it over the defaults and persist it.

        ``self.config`` is only replaced after the file has been written
        successfully.

        Returns:
            True on success; False (after printing the reason) on validation
            failure, I/O error, or a value that cannot be serialised to JSON.
        """
        try:
            self.validate_config(config)
            # Merge so that no default key is ever lost.
            merged_config = self.default_config.copy()
            merged_config.update(config)
            self._write_atomically(merged_config)
        except (OSError, TypeError, ValueError, ConfigurationError) as e:
            print(f"保存配置失败: {e}")
            return False

        self.config = merged_config
        # Record our own write so it is not re-read as an external change.
        mtime = self._get_mtime()
        if mtime is not None:
            self.last_modified = mtime
        return True

    def get_config(self, key: str, default=None):
        """Return the value of *key*, reloading the file first if it changed."""
        self.check_config_update()
        return self.config.get(key, default)

    def set_config(self, key: str, value: Any) -> bool:
        """Set *key* to *value* and persist the result.

        Works on a copy, so a rejected value never reaches ``self.config``.

        Returns:
            The result of ``save_config``.
        """
        # Pick up external edits first so they are not overwritten.
        self.check_config_update()
        updated = dict(self.config)
        updated[key] = value
        return self.save_config(updated)

    def get_all_config(self) -> Dict[str, Any]:
        """Return a shallow copy of the whole config, reloading it first if it changed."""
        self.check_config_update()
        return self.config.copy()

    def reset_to_default(self) -> bool:
        """Persist the defaults; ``self.config`` changes only if the save succeeds."""
        return self.save_config(self.default_config.copy())


# Global shared configuration manager instance.
config_manager = ConfigManager()