""" 改进的config.py - 消除全局变量,使用ConfigManager """ import configparser import getpass import logging import os import sys from pathlib import Path from logging.handlers import RotatingFileHandler from typing import Optional try: from dotenv import load_dotenv load_dotenv() except ImportError: pass from config_manager import config_manager CONFIG_FILE = "config.ini" def _get_env_value(key: str, default: str = "") -> str: """从环境变量获取配置值""" value = os.getenv(key, default) return value or default DEFAULT_CONFIG = { "General": { "chrome_user_dir": f"C:\\Users\\{getpass.getuser()}\\AppData\\Local\\Google\\Chrome\\User Data", "articles_path": "articles", "images_path": "picture", "title_file": "文章链接.xlsx", "max_threads": "3", "min_article_length": "100", "enable_plagiarism_detection": "false" }, "Coze": { "workflow_id": "", "access_token": "", "is_async": "false", "input_data_template": "{\"article\": \"{article_text}\", \"link\":\"{link}\", \"weijin\":\"{weijin}\"}", "last_used_template": "", "last_used_template_type": "文章" }, "Database": { "host": _get_env_value("DB_HOST", ""), "user": _get_env_value("DB_USER", ""), "password": _get_env_value("DB_PASSWORD", ""), "database": _get_env_value("DB_NAME", "toutiao") }, "Dify": { "api_key": _get_env_value("DIFY_API_KEY", ""), "user_id": _get_env_value("DIFY_USER_ID", "toutiao"), "url": _get_env_value("DIFY_URL", "http://27.106.125.150/v1/workflows/run"), "input_data_template": "{\"old_article\": \"{article_text}\"}" }, "Baidu": { "api_key": _get_env_value("BAIDU_API_KEY", ""), "secret_key": _get_env_value("BAIDU_SECRET_KEY", ""), "enable_detection": _get_env_value("BAIDU_ENABLE_DETECTION", "false"), "save_violation_articles": "true" }, "ImageModify": { "crop_percent": "0.02", "min_rotation": "0.3", "max_rotation": "3.0", "min_brightness": "0.8", "max_brightness": "1.2", "watermark_text": "Qin Quan Shan Chu", "watermark_opacity": "128", "overlay_opacity": "30" }, "Keywords": { "banned_words": "珠海,落马,股票,股市,股民,爆炸,火灾,死亡,抢劫,诈骗,习大大,习近平,政府,官员,扫黑,警察,落网,嫌疑人,通报,暴力执法,执法,暴力,气象,天气,暴雨,大雨" } } def create_default_config() -> bool: """创建默认配置文件""" config = configparser.ConfigParser() config.read_dict(DEFAULT_CONFIG) directories = ["articles", "picture", "data", "logs", "archive", "backups"] for directory in directories: Path(directory).mkdir(exist_ok=True) with open(CONFIG_FILE, 'w', encoding='utf-8') as f: config.write(f) print(f"已创建默认配置文件: {CONFIG_FILE}", file=sys.stderr) print("请编辑 config.ini 文件配置您的参数", file=sys.stderr) return True def load_config() -> configparser.ConfigParser: """加载配置文件""" config = configparser.ConfigParser() if not os.path.exists(CONFIG_FILE): for section, options in DEFAULT_CONFIG.items(): config[section] = options with open(CONFIG_FILE, 'w', encoding='utf-8') as f: config.write(f) else: config.read(CONFIG_FILE, encoding='utf-8') for section, options in DEFAULT_CONFIG.items(): if not config.has_section(section): config[section] = {} for option, value in options.items(): if not config.has_option(section, option): config[section][option] = value with open(CONFIG_FILE, 'w', encoding='utf-8') as f: config.write(f) return config def save_config(config: configparser.ConfigParser) -> None: """保存配置文件""" with open(CONFIG_FILE, 'w', encoding='utf-8') as f: config.write(f) CONFIG = load_config() # 向后兼容的全局变量 ARTICLES_BASE_PATH = CONFIG['General']['articles_path'] IMGS_BASE_PATH = CONFIG['General']['images_path'] TITLE_BASE_PATH = CONFIG['General']['title_file'] MAX_THREADS = int(CONFIG['General']['max_threads']) MIN_ARTICLE_LENGTH = int(CONFIG['General']['min_article_length']) ENABLE_PLAGIARISM_DETECTION = CONFIG['General'].get('enable_plagiarism_detection', 'false').lower() == 'true' # 创建必要的目录 directories = ["articles", "picture", "data", "logs", "archive", "backups"] for directory in directories: dir_path = Path(directory) if not dir_path.exists(): dir_path.mkdir(parents=True, exist_ok=True) try: dir_path.chmod(0o777) except (AttributeError, PermissionError): pass # 日志配置 - 使用RotatingFileHandler实现日志轮转 LOG_DIR = Path("logs") LOG_DIR.mkdir(exist_ok=True) LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s" DATE_FORMAT = "%Y-%m-%d %H:%M:%S" log_file = LOG_DIR / "article_replace.log" file_handler = RotatingFileHandler( log_file, maxBytes=10 * 1024 * 1024, backupCount=5, encoding='utf-8' ) file_handler.setLevel(logging.INFO) file_handler.setFormatter(logging.Formatter(LOG_FORMAT, DATE_FORMAT)) console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) console_handler.setFormatter(logging.Formatter(LOG_FORMAT, DATE_FORMAT)) logging.basicConfig( level=logging.INFO, format=LOG_FORMAT, datefmt=DATE_FORMAT, handlers=[file_handler, console_handler] ) logger = logging.getLogger(__name__) LOG_FILE = str(log_file) BACKUP_DIR = Path("backups") BACKUP_DIR.mkdir(exist_ok=True) MAX_CONFIG_BACKUPS = 10 def backup_config() -> Optional[str]: """备份配置文件到backups目录""" import shutil from datetime import datetime config_path = Path(CONFIG_FILE) if not config_path.exists(): return None timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") backup_path = BACKUP_DIR / f"config_{timestamp}.ini" try: shutil.copy2(config_path, backup_path) logger.info(f"配置文件已备份到: {backup_path}") import glob as glob_module backup_files = sorted( glob_module.glob(str(BACKUP_DIR / "config_*.ini")), key=lambda x: Path(x).stat().st_mtime, reverse=True ) for old_file in backup_files[MAX_CONFIG_BACKUPS:]: try: Path(old_file).unlink() logger.debug(f"已删除旧备份文件: {old_file}") except Exception as e: logger.warning(f"删除旧备份文件失败: {old_file}, {e}") return str(backup_path) except Exception as e: logger.error(f"配置文件备份失败: {e}") return None def restore_config(backup_filename: str) -> bool: """从备份恢复配置文件""" import shutil backup_path = BACKUP_DIR / backup_filename if not backup_path.exists(): logger.error(f"备份文件不存在: {backup_path}") return False try: backup_config() shutil.copy2(backup_path, Path(CONFIG_FILE)) logger.info(f"配置文件已从 {backup_filename} 恢复") return True except Exception as e: logger.error(f"配置文件恢复失败: {e}") return False DB_BACKUP_DIR = Path("backups/database") DB_BACKUP_DIR.mkdir(parents=True, exist_ok=True) def backup_database() -> bool: """备份数据库(如果配置了数据库)""" import subprocess db_host = CONFIG.get('Database', 'host') db_user = CONFIG.get('Database', 'user') db_password = CONFIG.get('Database', 'password') db_name = CONFIG.get('Database', 'database') if not all([db_host, db_user, db_password, db_name]): logger.info("数据库配置不完整,跳过数据库备份") return False from datetime import datetime timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") backup_file = DB_BACKUP_DIR / f"{db_name}_{timestamp}.sql" try: cmd = [ "mysqldump", f"-h{db_host}", f"-u{db_user}", f"-p{db_password}", db_name ] with open(backup_file, 'w', encoding='utf-8') as f: subprocess.run(cmd, stdout=f, check=True) logger.info(f"数据库已备份到: {backup_file}") import glob as glob_module backup_files = sorted( glob_module.glob(str(DB_BACKUP_DIR / f"{db_name}_*.sql")), key=lambda x: Path(x).stat().st_mtime, reverse=True ) for old_file in backup_files[MAX_CONFIG_BACKUPS:]: try: Path(old_file).unlink() except Exception as e: logger.warning(f"删除旧备份文件失败: {old_file}, {e}") return True except Exception as e: logger.error(f"数据库备份失败: {e}") return False def backup_data() -> Optional[str]: """备份数据目录(文章和图片)""" import shutil from datetime import datetime data_dirs = [CONFIG['General']['articles_path'], CONFIG['General']['images_path']] timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") backup_name = f"data_backup_{timestamp}" backup_path = BACKUP_DIR / backup_name try: backup_path.mkdir(parents=True, exist_ok=True) for data_dir in data_dirs: if os.path.exists(data_dir): dest = backup_path / data_dir shutil.copytree(data_dir, dest, dirs_exist_ok=True) logger.info(f"数据已备份到: {backup_path}") return str(backup_path) except Exception as e: logger.error(f"数据备份失败: {e}") return None # 便捷的配置访问函数(向后兼容) def get_articles_path() -> str: """获取文章保存路径""" return CONFIG['General']['articles_path'] def get_images_path() -> str: """获取图片保存路径""" return CONFIG['General']['images_path'] def get_max_threads() -> int: """获取最大线程数""" return int(CONFIG['General']['max_threads']) def get_min_article_length() -> int: """获取最小文章字数""" return int(CONFIG['General'].get('min_article_length', '100')) def is_plagiarism_detection_enabled() -> bool: """检查是否启用原创度检测""" return CONFIG['General'].get('enable_plagiarism_detection', 'false').lower() == 'true' def get_coze_workflow_id() -> str: """获取Coze工作流ID""" return CONFIG['Coze']['workflow_id'] def get_coze_access_token() -> str: """获取Coze访问令牌""" return CONFIG['Coze']['access_token'] def is_coze_async() -> bool: """检查Coze是否异步调用""" return CONFIG['Coze'].get('is_async', 'false').lower() == 'true' def get_crop_percent() -> float: """获取图片裁剪比例""" return float(CONFIG['ImageModify']['crop_percent']) def get_rotation_range() -> tuple: """获取图片旋转范围""" return ( float(CONFIG['ImageModify']['min_rotation']), float(CONFIG['ImageModify']['max_rotation']) ) def get_brightness_range() -> tuple: """获取图片亮度范围""" return ( float(CONFIG['ImageModify']['min_brightness']), float(CONFIG['ImageModify']['max_brightness']) ) def get_watermark_text() -> str: """获取水印文字""" return CONFIG['ImageModify']['watermark_text'] def get_banned_words() -> str: """获取违禁词列表""" return CONFIG['Keywords']['banned_words']