296 lines
12 KiB
Python
296 lines
12 KiB
Python
import threading
|
||
import queue
|
||
import json # 导入 json 模块
|
||
|
||
from ai_studio import call_dify_workflow,call_coze_article_workflow,call_coze_all_article_workflow
|
||
from databases import *
|
||
|
||
from images_edit import download_and_process_images
|
||
from utils import *
|
||
from get_web_content import *
|
||
from config import *
|
||
|
||
|
||
# ==============================主程序===========================
|
||
def _extract_article(link):
    """Scrape *link* with the extractor matching its URL prefix.

    Returns a ``(title_text, article_text, img_urls)`` triple; URLs whose
    prefix is not recognised yield ``("", "", [])``.
    """
    if link.startswith("https://www.toutiao.com"):
        title_text, article_text, img_urls = toutiao_w_extract_content(link)
        if not title_text:
            # The first extractor found no title; retry with the alternative one.
            title_text, article_text, img_urls = toutiao_extract_content(link)
        return title_text, article_text, img_urls
    if link.startswith("https://mp.weixin.qq.co"):
        return wechat_extract_content(link)
    if link.startswith("https://www.163.com"):
        return wangyi_extract_content(link)
    return "", "", []


def _generate_with_coze(title_text, article_text, default_title,
                        current_template, generation_type):
    """Run the Coze workflow for *generation_type*; return ``(title, content)``.

    When *current_template* is given, its ``workflow_id`` / ``access_token`` /
    ``is_async`` values temporarily replace the ones in ``CONFIG['Coze']`` and
    the previous values are restored afterwards, even if the call fails.

    NOTE(review): ``CONFIG`` is a module-level object and this function runs
    in several worker threads at once, so one thread restoring the config can
    affect a call still in flight in another thread -- confirm whether the
    workflow helpers should receive the settings explicitly instead.
    """
    logger.info("coze正在处理")
    logger.info(f"正在处理的文章类型为:{generation_type}")

    original_config = None
    try:
        if current_template:
            original_config = {
                'workflow_id': CONFIG['Coze']['workflow_id'],
                'access_token': CONFIG['Coze']['access_token'],
                'is_async': CONFIG['Coze']['is_async'],
            }
            CONFIG['Coze']['workflow_id'] = current_template.get('workflow_id', '')
            CONFIG['Coze']['access_token'] = current_template.get('access_token', '')
            CONFIG['Coze']['is_async'] = current_template.get('is_async', 'true')

            logger.info(f"应用模板配置: {current_template.get('name')}")
            logger.info(f"Workflow ID: {CONFIG['Coze']['workflow_id']}")
            logger.info(f"Access Token: {'*' * len(CONFIG['Coze']['access_token'])}")
            logger.info(f"Is Async: {CONFIG['Coze']['is_async']}")

        # The parsed template is not used below; it is still parsed so that a
        # missing or malformed 'input_data_template' aborts before any call.
        json.loads(CONFIG['Coze'].get('input_data_template'))

        if generation_type == "短篇":
            input_data = {"article": article_text}
            print("coze中输入:", input_data)
            return default_title, call_coze_article_workflow(input_data)

        if generation_type == "文章":
            print("原文中标题为:", title_text)
            print("原文中内容为:", article_text)
            input_data = {"title": title_text, "article": article_text}
            print("发送的请求数据为:", input_data)
            try:
                result = call_coze_all_article_workflow(input_data)
                # The workflow helper reports failures as {'error': ...}.
                if isinstance(result, dict) and 'error' in result:
                    raise RuntimeError(result['error'])
                new_title, message_content = result
            except Exception as e:
                logger.error(f"调用 Coze 工作流时出错: {e}")
                raise
            return new_title, message_content

        # Previously this fell through and failed later with an
        # UnboundLocalError on message_content; fail with a clear message.
        raise ValueError(f"不支持的文章类型: {generation_type!r}")
    finally:
        if original_config is not None:
            CONFIG['Coze'].update(original_config)


def process_link(link_info, ai_service, current_template=None,generation_type=None):
    """Scrape one article, rewrite it with an AI workflow and save the result.

    Args:
        link_info: Either a ``(link, article_type)`` tuple/list or a bare link
            string.
        ai_service: ``"dify"`` or ``"coze"``; any other value raises
            ``ValueError``.
        current_template: Optional mapping with Coze overrides
            (``workflow_id``, ``access_token``, ``is_async``, ``name``).
        generation_type: ``"短篇"`` or ``"文章"`` for Coze; also used as the
            article type when *link_info* carries none.

    Returns:
        None. The function returns early (without saving) when no title could
        be extracted, the title is longer than 100 characters, the article is
        shorter than ``MIN_ARTICLE_LENGTH``, a banned keyword is found (dify
        only) or the compliance check fails.

    Raises:
        Exception: any error raised while processing is logged and re-raised.
    """
    if isinstance(link_info, (tuple, list)) and len(link_info) >= 2:
        link, article_type = link_info[0], link_info[1]
    else:
        # Not a pair: treat the value as the link itself.
        link, article_type = link_info, None
    # An empty/None type would break the directory path built from it below.
    article_type = article_type or generation_type or "未分类"

    try:
        title_text, article_text, img_urls = _extract_article(link)

        if not title_text or len(title_text) > 100:
            return

        # Skip articles below the minimum length threshold.
        if len(article_text) < MIN_ARTICLE_LENGTH:
            print(f"文章字数低于最小阈值 {MIN_ARTICLE_LENGTH},跳过处理")
            return

        # Database settings (used when recording links with banned keywords).
        host = CONFIG['Database']['host']
        user = CONFIG['Database']['user']
        password = CONFIG['Database']['password']
        database = CONFIG['Database']['database']

        # Banned-keyword check on the scraped title.
        check_keywords = check_keywords_in_text(title_text)

        title = extract_content_until_punctuation(article_text).replace("正文:", "")

        from datetime import datetime
        current_time = datetime.now().strftime("%H:%M:%S")
        print("当前时间:", current_time)

        if ai_service == "dify":
            if check_keywords:
                print("文章中有违禁词!")
                check_link_insert(host, user, password, database, link)
                return
            input_data_template_str = CONFIG['Dify'].get('input_data_template', '{"old_article": "{article_text}"}')
            try:
                input_data_template = json.loads(input_data_template_str)
                input_data = {k: v.format(article_text=article_text) for k, v in input_data_template.items()}
            except (json.JSONDecodeError, KeyError, AttributeError) as e:
                logger.error(f"处理 Dify input_data 模板时出错: {e}. 使用默认模板.")
                input_data = {"old_article": article_text}
            message_content = call_dify_workflow(input_data)
        elif ai_service == "coze":
            title, message_content = _generate_with_coze(
                title_text, article_text, title, current_template, generation_type
            )
        else:
            raise ValueError(f"不支持的 AI 服务: {ai_service!r}")

        # Trim surrounding whitespace from the scraped title.
        title_text = title_text.strip()

        # Per-type output directory.
        type_dir = os.path.join(ARTICLES_BASE_PATH, article_type)
        safe_open_directory(type_dir)

        # "文章" output is named after the generated title; everything else
        # (including runs without a generation type, which used to produce an
        # empty file name and overwrite ".txt") uses the scraped title.
        name_source = title if generation_type == "文章" else title_text
        file_name = handle_duplicate_files_advanced(type_dir, name_source.strip())[0]

        article_save_path = os.path.join(type_dir, f"{file_name}.txt")

        # Remove Markdown code fences; replacing only two backticks of a
        # three-backtick fence would leave a stray "`" in the text.
        message_content = message_content.replace("```", "")

        message_content = title + "\n" + message_content

        # Optional compliance check, controlled by configuration.
        enable_detection = CONFIG['Baidu'].get('enable_detection', 'false').lower() == 'true'
        if enable_detection:
            print("正在检测文章合规度")
            if text_detection(message_content) != "合规":
                print("文章不合规")
                return
            print("文章合规")
        else:
            print("违规检测已禁用,跳过检测")

        with open(article_save_path, 'w', encoding='utf-8') as f:
            f.write(message_content)
        logging.info('文本已经保存')

        if img_urls:
            # Images go into a per-type directory, named after the article file.
            type_picture_dir = os.path.join(IMGS_BASE_PATH, article_type)
            safe_open_directory(type_picture_dir)
            download_and_process_images(img_urls, file_name.strip(), type_picture_dir)

    except Exception as e:
        logging.error(f"处理链接 {link} 时出错: {e}")
        raise
|
||
|
||
|
||
def link_to_text(num_threads=None, ai_service="dify", current_template=None, generation_type=None):
    """Read links from the Excel sheet, process them in threads and log successes.

    Args:
        num_threads: Requested worker count, forwarded to
            ``process_links_with_threads``.
        ai_service: AI backend name forwarded to the workers.
        current_template: Optional template mapping forwarded to the workers.
        generation_type: Generation mode; also the article type for rows that
            carry no type of their own.

    Returns:
        The list of ``(link_info, success, error)`` tuples produced by
        ``process_links_with_threads``, or ``[]`` when there is nothing to do.
        Links whose result is marked successful are appended to
        ``use_link_path.txt``.
    """
    use_link_path = 'use_link_path.txt'

    links = read_excel(TITLE_BASE_PATH)
    # Log the total once instead of once per row.
    logging.info(f"总共{len(links)}个链接")

    # NOTE: the database-backed duplicate check (check_link_exists) is
    # currently disabled, so every row read from the sheet is queued.
    filtered_links = []
    for link_info in links:
        link = link_info[0].strip()
        # Prefer the type given in the sheet; fall back to generation_type.
        sheet_type = link_info[1].strip() if len(link_info) > 1 else ""
        filtered_links.append((link, sheet_type or generation_type))

    if not filtered_links:
        logger.info("没有新链接需要处理")
        return []

    results = process_links_with_threads(filtered_links, num_threads, ai_service, current_template,generation_type)

    # Record processed links.
    # NOTE(review): a link that process_link skipped via an early return is
    # still reported as successful by the worker and therefore recorded here.
    with open(use_link_path, 'a+', encoding='utf-8') as f:
        for result in results:
            # Exactly three fields are unpacked below, so require exactly three.
            if not (isinstance(result, tuple) and len(result) == 3):
                logger.warning(f"意外的结果格式: {result}")
                continue
            link_info, success, _ = result
            link = link_info[0] if isinstance(link_info, (tuple, list)) else link_info
            if success:
                f.write(str(link) + "\n")

    return results
|
||
|
||
|
||
# Module-level queues shared by worker() and process_links_with_threads():
# tasks flow in through task_queue, outcomes come back through result_queue.
task_queue: queue.Queue = queue.Queue()
result_queue: queue.Queue = queue.Queue()
|
||
|
||
|
||
# 工作线程函数
|
||
def worker(ai_service, current_template=None,generation_type=None):
    """Consume tasks from ``task_queue`` until a ``None`` sentinel arrives.

    Each task is either a ``(link, article_type)`` pair or a bare link; both
    forms are accepted by ``process_link``. For every task one
    ``(task, success, error_message)`` tuple is put on ``result_queue``;
    ``error_message`` is ``None`` on success.

    Args:
        ai_service: AI backend name passed through to ``process_link``.
        current_template: Optional template mapping passed through.
        generation_type: Generation mode passed through.
    """
    while True:
        link_info = task_queue.get()
        try:
            if link_info is None:  # shutdown sentinel
                break

            # Normalise the task so malformed tuples cannot abort it without
            # leaving a result behind.
            if isinstance(link_info, (tuple, list)) and len(link_info) >= 2:
                task = (link_info[0], link_info[1])
                link = task[0]
            else:
                task = link = link_info

            try:
                logger.info(f"开始处理链接:{link}")
                process_link(task, ai_service, current_template,generation_type)
            except Exception as e:
                result_queue.put((task, False, str(e)))
                logger.error(f"处理链接 {link} 时出错: {e}")
            else:
                result_queue.put((task, True, None))
        except Exception as e:
            # Keep the thread alive for the remaining tasks.
            logger.error(f"工作线程出错: {e}")
        finally:
            # Every get() is matched by task_done(), including the sentinel,
            # so the queue's unfinished-task count stays accurate.
            task_queue.task_done()
|
||
|
||
|
||
# 多线程处理链接
|
||
def _drain_queue(q):
    """Remove and return every item currently in *q* without blocking."""
    items = []
    while True:
        try:
            items.append(q.get_nowait())
        except queue.Empty:
            return items


def process_links_with_threads(links, num_threads=None, ai_service="dify", current_template=None,generation_type=None):
    """Process *links* with a pool of ``worker`` threads and collect results.

    Args:
        links: Sequence of tasks (``(link, article_type)`` pairs).
        num_threads: Requested thread count; capped by ``MAX_THREADS`` and by
            the number of links, and never less than one when there is work.
        ai_service: AI backend name handed to each worker.
        current_template: Optional template mapping handed to each worker.
        generation_type: Generation mode handed to each worker.

    Returns:
        A list with whatever the workers put on ``result_queue``; empty when
        *links* is empty.

    Note:
        The function uses the module-level ``task_queue`` / ``result_queue``
        and clears both on entry, so overlapping calls would interfere with
        each other.
    """
    if not links:
        return []

    limit = MAX_THREADS if num_threads is None else min(num_threads, MAX_THREADS)
    # A non-positive count would start no workers and leave every task
    # unprocessed in the shared queue; always run at least one.
    num_threads = max(1, min(limit, len(links)))

    # Discard leftovers from any previous run.
    _drain_queue(task_queue)
    _drain_queue(result_queue)

    threads = []
    for _ in range(num_threads):
        t = threading.Thread(
            target=worker,
            args=(ai_service, current_template, generation_type),
            daemon=True,
        )
        t.start()
        threads.append(t)

    for link in links:
        task_queue.put(link)

    # One shutdown sentinel per worker.
    for _ in range(num_threads):
        task_queue.put(None)

    for t in threads:
        t.join()

    # All workers have exited, so the result queue is complete.
    return _drain_queue(result_queue)
|