Files
ArticleReplaceBatch/config.py

392 lines
12 KiB
Python
Raw Normal View History

2026-03-25 15:17:18 +08:00
"""
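Application configuration for ArticleReplaceBatch (improved config.py).

Loads config.ini (creating it, or filling in missing options, from DEFAULT_CONFIG)
and exposes the values through module-level constants and getter functions. It also
sets up logging and provides backup/restore helpers for the config file, the MySQL
database and the article/image directories.

Note: the original goal was to eliminate global variables in favour of ConfigManager
(config_manager is imported for that purpose). The module still defines module-level
globals such as CONFIG for backward compatibility.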
"""
import configparser
import getpass
import logging
import os
import sys
from pathlib import Path
from logging.handlers import RotatingFileHandler
from typing import Optional
try:
from dotenv import load_dotenv
load_dotenv()
except ImportError:
pass
from config_manager import config_manager
# INI file holding the application settings. The path is relative, so it is
# resolved against the process's current working directory.
CONFIG_FILE = "config.ini"
def _get_env_value(key: str, default: str = "") -> str:
"""从环境变量获取配置值"""
value = os.getenv(key, default)
return value or default
# Built-in defaults, keyed by INI section and then option name. Every value is a
# string because configparser stores strings. Readers convert types themselves,
# e.g. int(...) for max_threads and .lower() == 'true' for boolean flags.
# The Database / Dify / Baidu entries are taken from environment variables once,
# at import time, via _get_env_value.
DEFAULT_CONFIG = {
    "General": {
        # Windows-specific Chrome profile location for the current OS user.
        "chrome_user_dir": f"C:\\Users\\{getpass.getuser()}\\AppData\\Local\\Google\\Chrome\\User Data",
        "articles_path": "articles",
        "images_path": "picture",
        "title_file": "文章链接.xlsx",
        "max_threads": "3",
        "min_article_length": "100",
        "enable_plagiarism_detection": "false"
    },
    "Coze": {
        "workflow_id": "",
        "access_token": "",
        "is_async": "false",
        # Template text with {article_text}/{link}/{weijin} placeholders;
        # presumably substituted by the caller before sending - verify.
        "input_data_template": "{\"article\": \"{article_text}\", \"link\":\"{link}\", \"weijin\":\"{weijin}\"}",
        "last_used_template": "",
        "last_used_template_type": "文章"
    },
    "Database": {
        # NOTE(review): load_config writes these defaults into config.ini, so a
        # DB_PASSWORD taken from the environment ends up on disk in plain text.
        "host": _get_env_value("DB_HOST", ""),
        "user": _get_env_value("DB_USER", ""),
        "password": _get_env_value("DB_PASSWORD", ""),
        "database": _get_env_value("DB_NAME", "toutiao")
    },
    "Dify": {
        "api_key": _get_env_value("DIFY_API_KEY", ""),
        "user_id": _get_env_value("DIFY_USER_ID", "toutiao"),
        # NOTE(review): the fallback is a hard-coded IP over plain HTTP - confirm
        # this should ship as a default.
        "url": _get_env_value("DIFY_URL", "http://27.106.125.150/v1/workflows/run"),
        "input_data_template": "{\"old_article\": \"{article_text}\"}"
    },
    "Baidu": {
        "api_key": _get_env_value("BAIDU_API_KEY", ""),
        "secret_key": _get_env_value("BAIDU_SECRET_KEY", ""),
        "enable_detection": _get_env_value("BAIDU_ENABLE_DETECTION", "false"),
        "save_violation_articles": "true"
    },
    # Parameters for image modification; the getters below parse them as floats.
    "ImageModify": {
        "crop_percent": "0.02",
        "min_rotation": "0.3",
        "max_rotation": "3.0",
        "min_brightness": "0.8",
        "max_brightness": "1.2",
        "watermark_text": "Qin Quan Shan Chu",
        "watermark_opacity": "128",
        "overlay_opacity": "30"
    },
    "Keywords": {
        # Comma-separated banned words; get_banned_words returns the raw string.
        "banned_words": "珠海,落马,股票,股市,股民,爆炸,火灾,死亡,抢劫,诈骗,习大大,习近平,政府,官员,扫黑,警察,落网,嫌疑人,通报,暴力执法,执法,暴力,气象,天气,暴雨,大雨"
    }
}
def create_default_config() -> bool:
    """Write DEFAULT_CONFIG to CONFIG_FILE and create the working directories.

    An existing CONFIG_FILE is overwritten. Two hints are printed to stderr
    afterwards. Errors (e.g. permission problems) propagate to the caller.

    Returns:
        Always True.
    """
    parser = configparser.ConfigParser()
    parser.read_dict(DEFAULT_CONFIG)

    for folder in ("articles", "picture", "data", "logs", "archive", "backups"):
        Path(folder).mkdir(exist_ok=True)

    with open(CONFIG_FILE, 'w', encoding='utf-8') as handle:
        parser.write(handle)

    for message in (f"已创建默认配置文件: {CONFIG_FILE}", "请编辑 config.ini 文件配置您的参数"):
        print(message, file=sys.stderr)
    return True
def load_config(
    config_file: Optional[str] = None,
    defaults: Optional[dict] = None,
) -> configparser.ConfigParser:
    """Load the INI configuration and fill in any missing default options.

    If the file does not exist, it is created from the defaults. If it
    exists, every section/option present in the defaults but missing from the
    file is added. The file is rewritten *only when something was added*; an
    up-to-date file is left untouched. ConfigParser.write() drops comments,
    so rewriting unconditionally on every import would keep destroying the
    user's comments and formatting.

    Args:
        config_file: INI file to read/write. Defaults to CONFIG_FILE.
        defaults: ``{section: {option: str_value}}`` mapping used to fill
            gaps. Defaults to DEFAULT_CONFIG.

    Returns:
        The populated ConfigParser.
    """
    path = CONFIG_FILE if config_file is None else config_file
    template = DEFAULT_CONFIG if defaults is None else defaults

    config = configparser.ConfigParser()
    file_exists = os.path.exists(path)
    if file_exists:
        config.read(path, encoding='utf-8')

    # A missing file always has to be written out.
    changed = not file_exists
    for section, options in template.items():
        if not config.has_section(section):
            config[section] = {}
            changed = True
        for option, value in options.items():
            if not config.has_option(section, option):
                config[section][option] = value
                changed = True

    if changed:
        with open(path, 'w', encoding='utf-8') as f:
            config.write(f)
    return config
def save_config(config: configparser.ConfigParser, config_file: Optional[str] = None) -> None:
    """Write *config* to disk without ever leaving a truncated file behind.

    The INI text is first written to a sibling file named after the target
    plus a ``.tmp`` suffix. That file then replaces the target via
    os.replace. If writing fails, the partial temp file is removed, the
    original config file is left intact, and the error is re-raised.

    Args:
        config: Parser whose contents are saved.
        config_file: Destination path. Defaults to CONFIG_FILE.
    """
    target = Path(CONFIG_FILE if config_file is None else config_file)
    temp_path = target.with_name(target.name + ".tmp")
    try:
        with open(temp_path, 'w', encoding='utf-8') as f:
            config.write(f)
        os.replace(temp_path, target)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the original error wins.
        try:
            temp_path.unlink()
        except OSError:
            pass
        raise
CONFIG = load_config()

# Backward-compatible module-level aliases for callers that import these names
# directly. They are snapshots taken at import time.
ARTICLES_BASE_PATH = CONFIG['General']['articles_path']
IMGS_BASE_PATH = CONFIG['General']['images_path']
TITLE_BASE_PATH = CONFIG['General']['title_file']
MAX_THREADS = int(CONFIG['General']['max_threads'])
MIN_ARTICLE_LENGTH = int(CONFIG['General']['min_article_length'])
ENABLE_PLAGIARISM_DETECTION = CONFIG['General'].get('enable_plagiarism_detection', 'false').lower() == 'true'

# Create the working directories. The configured article/image paths are
# included because they may differ from the default "articles"/"picture"
# folders. dict.fromkeys removes duplicates while keeping the order.
directories = list(dict.fromkeys([
    "articles", "picture", "data", "logs", "archive", "backups",
    ARTICLES_BASE_PATH, IMGS_BASE_PATH,
]))
for directory in directories:
    dir_path = Path(directory)
    if not dir_path.exists():
        dir_path.mkdir(parents=True, exist_ok=True)
        try:
            # NOTE(review): 0o777 makes the directory world-writable; kept
            # from the original behaviour - confirm this is really intended.
            dir_path.chmod(0o777)
        except (AttributeError, OSError):
            # Best effort: e.g. read-only filesystems or restricted ACLs.
            pass
# Logging setup - RotatingFileHandler provides log-file rotation.
# NOTE: this runs at import time and configures the *root* logger, so every
# module that imports this one inherits these handlers.
LOG_DIR = Path("logs")
LOG_DIR.mkdir(exist_ok=True)
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s"
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
log_file = LOG_DIR / "article_replace.log"
# Roll the file over at 10 MiB, keeping up to 5 rotated files.
file_handler = RotatingFileHandler(
    log_file,
    maxBytes=10 * 1024 * 1024,
    backupCount=5,
    encoding='utf-8'
)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(logging.Formatter(LOG_FORMAT, DATE_FORMAT))
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(logging.Formatter(LOG_FORMAT, DATE_FORMAT))
# NOTE(review): per the logging docs, basicConfig() does nothing if the root
# logger already has handlers (unless force=True). In that case file_handler is
# created (opening the log file) but never attached - check import order.
logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMAT,
    datefmt=DATE_FORMAT,
    handlers=[file_handler, console_handler]
)
logger = logging.getLogger(__name__)
# Same log path as a plain string (presumably for callers expecting str).
LOG_FILE = str(log_file)
# Number of timestamped backups of each kind that are kept (the DB dump
# rotation reuses this limit too).
MAX_CONFIG_BACKUPS = 10
# All backups live under ./backups, relative to the working directory.
BACKUP_DIR = Path("backups")
BACKUP_DIR.mkdir(exist_ok=True)
def backup_config() -> Optional[str]:
    """Copy CONFIG_FILE into BACKUP_DIR and prune old config backups.

    The copy is named ``config_YYYYmmdd_HHMMSS.ini``. If that name is already
    taken (two backups within the same second), a ``_N`` counter is appended
    instead of overwriting the earlier backup. Afterwards only the newest
    MAX_CONFIG_BACKUPS ``config_*.ini`` files are kept.

    Returns:
        The new backup's path as a string, or None if there is no config file
        or the copy failed (the error is logged).
    """
    import shutil
    from datetime import datetime

    config_path = Path(CONFIG_FILE)
    if not config_path.exists():
        return None

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_path = BACKUP_DIR / f"config_{timestamp}.ini"
    counter = 1
    while backup_path.exists():
        backup_path = BACKUP_DIR / f"config_{timestamp}_{counter}.ini"
        counter += 1

    try:
        shutil.copy2(config_path, backup_path)
        logger.info(f"配置文件已备份到: {backup_path}")
        # Order newest-first by file name (the embedded timestamp), NOT by mtime.
        # copy2 carries over the *source* file's mtime, so every backup of an
        # unchanged config.ini has the same mtime, and mtime ordering could
        # delete the newest backups.
        backup_files = sorted(
            BACKUP_DIR.glob("config_*.ini"),
            key=lambda p: p.name,
            reverse=True,
        )
        for old_file in backup_files[MAX_CONFIG_BACKUPS:]:
            try:
                old_file.unlink()
                logger.debug(f"已删除旧备份文件: {old_file}")
            except OSError as e:
                logger.warning(f"删除旧备份文件失败: {old_file}, {e}")
        return str(backup_path)
    except Exception as e:
        logger.error(f"配置文件备份失败: {e}")
        return None
def restore_config(backup_filename: str) -> bool:
    """Replace CONFIG_FILE with the contents of a backup stored in BACKUP_DIR.

    The current config (if any) is first saved via backup_config(). If that
    safety backup fails, the restore is aborted so the current config is not
    lost. The in-memory CONFIG object is not reloaded by this function.

    Args:
        backup_filename: Backup name relative to BACKUP_DIR. Names that resolve
            outside BACKUP_DIR (``..`` segments, absolute paths) are rejected.

    Returns:
        True on success, False otherwise (the reason is logged).
    """
    backup_path = BACKUP_DIR / backup_filename
    try:
        # Guard against path traversal: the resolved file must live in BACKUP_DIR.
        backup_path.resolve().relative_to(BACKUP_DIR.resolve())
    except (ValueError, OSError):
        logger.error(f"非法的备份文件路径: {backup_filename}")
        return False
    if not backup_path.is_file():
        logger.error(f"备份文件不存在: {backup_path}")
        return False

    config_path = Path(CONFIG_FILE)
    try:
        # Read the backup BEFORE taking the safety backup: both use
        # second-resolution names, so the safety copy could otherwise land on
        # (and overwrite) the very file being restored.
        content = backup_path.read_bytes()
        if config_path.exists() and backup_config() is None:
            logger.error("恢复前备份当前配置失败,已取消恢复")
            return False
        config_path.write_bytes(content)
        logger.info(f"配置文件已从 {backup_filename} 恢复")
        return True
    except Exception as e:
        logger.error(f"配置文件恢复失败: {e}")
        return False
# Database dumps get their own sub-folder inside the backups directory.
DB_BACKUP_DIR = Path("backups", "database")
DB_BACKUP_DIR.mkdir(parents=True, exist_ok=True)
def backup_database() -> bool:
    """Dump the configured MySQL database into DB_BACKUP_DIR using mysqldump.

    The dump is written to ``DB_BACKUP_DIR/DBNAME_YYYYmmdd_HHMMSS.sql``.
    The password is passed through a private temporary option file
    (``--defaults-extra-file``) rather than ``-p<password>``. Command lines
    are visible to other local users, and the original command line also ended
    up in logged CalledProcessError messages. A failed dump is deleted so it
    cannot count toward the MAX_CONFIG_BACKUPS retention limit.

    Returns:
        True when a dump was written. False when the [Database] settings are
        incomplete or the dump failed (logged).
    """
    import subprocess
    import tempfile
    from datetime import datetime

    db_host = CONFIG.get('Database', 'host')
    db_user = CONFIG.get('Database', 'user')
    db_password = CONFIG.get('Database', 'password')
    db_name = CONFIG.get('Database', 'database')
    if not all([db_host, db_user, db_password, db_name]):
        logger.info("数据库配置不完整,跳过数据库备份")
        return False

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_file = DB_BACKUP_DIR / f"{db_name}_{timestamp}.sql"
    option_file = None
    try:
        # NamedTemporaryFile creates the file readable/writable by the owner only.
        with tempfile.NamedTemporaryFile(
            'w', suffix='.cnf', delete=False, encoding='utf-8'
        ) as option_fp:
            option_file = option_fp.name
            # Backslash is the escape character inside MySQL option files.
            escaped_password = db_password.replace('\\', '\\\\')
            option_fp.write(f'[client]\npassword="{escaped_password}"\n')
        cmd = [
            "mysqldump",
            f"--defaults-extra-file={option_file}",  # must be the first option
            f"-h{db_host}",
            f"-u{db_user}",
            db_name,
        ]
        # mysqldump writes straight to the file descriptor, so open in binary mode.
        with open(backup_file, 'wb') as f:
            subprocess.run(cmd, stdout=f, check=True)
    except Exception as e:
        logger.error(f"数据库备份失败: {e}")
        # Remove the truncated/empty dump so it is never mistaken for a backup.
        try:
            backup_file.unlink(missing_ok=True)
        except OSError:
            pass
        return False
    finally:
        if option_file is not None:
            try:
                os.unlink(option_file)
            except OSError:
                pass

    logger.info(f"数据库已备份到: {backup_file}")
    # Rotation is best effort: a failure here must not mark the dump as failed.
    try:
        backup_files = sorted(
            DB_BACKUP_DIR.glob(f"{db_name}_*.sql"),
            key=lambda p: p.stat().st_mtime,
            reverse=True,
        )
    except OSError as e:
        logger.warning(f"删除旧备份文件失败: {DB_BACKUP_DIR}, {e}")
        backup_files = []
    for old_file in backup_files[MAX_CONFIG_BACKUPS:]:
        try:
            old_file.unlink()
        except Exception as e:
            logger.warning(f"删除旧备份文件失败: {old_file}, {e}")
    return True
def backup_data() -> Optional[str]:
    """Copy the configured article and image directories into a timestamped folder.

    The copy goes to ``BACKUP_DIR/data_backup_YYYYmmdd_HHMMSS/``. Plain
    relative data paths keep their relative layout inside that folder.
    Absolute, drive-qualified or ``..``-containing paths are stored under
    their final directory name. Joining such a path onto the backup folder
    would otherwise escape it (for absolute paths, copytree would copy the
    directory onto itself). Paths that do not exist are skipped.

    Returns:
        The backup folder path as a string, or None if copying failed (logged).
    """
    import shutil
    from datetime import datetime

    data_dirs = [CONFIG['General']['articles_path'], CONFIG['General']['images_path']]
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    backup_name = f"data_backup_{timestamp}"
    backup_path = BACKUP_DIR / backup_name
    try:
        backup_path.mkdir(parents=True, exist_ok=True)
        for data_dir in data_dirs:
            if not os.path.exists(data_dir):
                continue
            source = Path(data_dir)
            if source.anchor or '..' in source.parts:
                relative = Path(source.resolve().name)
            else:
                relative = source
            shutil.copytree(source, backup_path / relative, dirs_exist_ok=True)
        logger.info(f"数据已备份到: {backup_path}")
        return str(backup_path)
    except Exception as e:
        logger.error(f"数据备份失败: {e}")
        return None
# 便捷的配置访问函数(向后兼容)
def get_articles_path() -> str:
    """Return the configured article output path ([General] articles_path)."""
    general = CONFIG['General']
    return general['articles_path']
def get_images_path() -> str:
    """Return the configured image output path ([General] images_path)."""
    general = CONFIG['General']
    return general['images_path']
def get_max_threads() -> int:
    """Return [General] max_threads as an int (ValueError if it is not numeric)."""
    raw_value = CONFIG['General']['max_threads']
    return int(raw_value)
def get_min_article_length() -> int:
    """Return [General] min_article_length as an int, using 100 when the option is absent."""
    raw_value = CONFIG['General'].get('min_article_length', fallback='100')
    return int(raw_value)
def is_plagiarism_detection_enabled() -> bool:
    """Return True only when [General] enable_plagiarism_detection is 'true' (any case)."""
    flag = CONFIG['General'].get('enable_plagiarism_detection', fallback='false')
    return flag.lower() == 'true'
def get_coze_workflow_id() -> str:
    """Return the Coze workflow ID ([Coze] workflow_id)."""
    coze = CONFIG['Coze']
    return coze['workflow_id']
def get_coze_access_token() -> str:
    """Return the Coze access token ([Coze] access_token)."""
    coze = CONFIG['Coze']
    return coze['access_token']
def is_coze_async() -> bool:
    """Return True only when [Coze] is_async is 'true' (any case); absent means False."""
    flag = CONFIG['Coze'].get('is_async', fallback='false')
    return flag.lower() == 'true'
def get_crop_percent() -> float:
    """Return the image crop ratio ([ImageModify] crop_percent) as a float."""
    raw_value = CONFIG['ImageModify']['crop_percent']
    return float(raw_value)
def get_rotation_range() -> tuple:
    """Return ``(min_rotation, max_rotation)`` from [ImageModify] as floats."""
    section = CONFIG['ImageModify']
    low = float(section['min_rotation'])
    high = float(section['max_rotation'])
    return low, high
def get_brightness_range() -> tuple:
    """Return ``(min_brightness, max_brightness)`` from [ImageModify] as floats."""
    section = CONFIG['ImageModify']
    low = float(section['min_brightness'])
    high = float(section['max_brightness'])
    return low, high
def get_watermark_text() -> str:
    """Return the watermark text ([ImageModify] watermark_text)."""
    section = CONFIG['ImageModify']
    return section['watermark_text']
def get_banned_words() -> str:
    """Return the raw comma-separated banned-word list ([Keywords] banned_words)."""
    keywords = CONFIG['Keywords']
    return keywords['banned_words']