134 lines
4.6 KiB
Python
134 lines
4.6 KiB
Python
import json
import os
import shutil
from datetime import datetime, timezone
from typing import List, Dict, Optional

from vector_store import VectorStore
|
|
|
|
|
|
class KnowledgeBase:
    """A single knowledge base stored in its own directory.

    The directory ``<base_path>/<name>`` holds ``config.json`` together with
    the three files whose paths are handed to ``VectorStore``:
    ``vector_index.faiss``, ``metadata.json`` and ``content.json``.
    """

    def __init__(self, name: str, base_path: str = "knowledge_bases"):
        """Open the knowledge base ``name``, creating its directory if needed.

        Args:
            name: Knowledge base name; also used as its directory name.
            base_path: Parent directory that contains all knowledge bases.
        """
        self.name = name
        self.base_path = os.path.join(base_path, name)
        self.config_file = os.path.join(self.base_path, "config.json")

        # The directory has to exist before config.json can be written.
        os.makedirs(self.base_path, exist_ok=True)

        # Paths of the files backing the vector store.
        index_file = os.path.join(self.base_path, "vector_index.faiss")
        metadata_file = os.path.join(self.base_path, "metadata.json")
        content_file = os.path.join(self.base_path, "content.json")

        self.vector_store = VectorStore(index_file, metadata_file, content_file)
        self.config = self._load_config()

    @staticmethod
    def _now() -> str:
        """Return the current UTC time as an ISO 8601 string."""
        return datetime.now(timezone.utc).isoformat(timespec="seconds")

    def _default_config(self) -> Dict:
        """Return a fresh configuration dict with every expected key."""
        return {
            "name": self.name,
            "description": "",
            "created_at": "",
            "updated_at": "",
            "document_count": 0,
        }

    def _load_config(self) -> Dict:
        """Load the configuration from disk, creating it on first use.

        Keys missing from an existing file are filled in from the defaults so
        the rest of the class can rely on them being present.

        Returns:
            The configuration dictionary.

        Raises:
            ValueError: If ``config.json`` exists but is not a JSON object
                (``json.JSONDecodeError`` for malformed JSON is a subclass).
        """
        config = self._default_config()
        try:
            with open(self.config_file, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except FileNotFoundError:
            # First use: stamp the creation time and persist the defaults.
            config["created_at"] = config["updated_at"] = self._now()
            self._save_config(config)
            return config

        if not isinstance(loaded, dict):
            raise ValueError(f"知识库配置文件格式无效: {self.config_file}")
        config.update(loaded)
        return config

    def _save_config(self, config: Dict):
        """Write ``config`` to ``config.json`` as indented UTF-8 JSON."""
        with open(self.config_file, "w", encoding="utf-8") as f:
            json.dump(config, f, ensure_ascii=False, indent=2)

    def add_documents(self, segments: List[Dict]):
        """Add ``segments`` to the vector store and update the configuration.

        ``document_count`` grows by ``len(segments)`` and ``updated_at`` is
        refreshed; the configuration is then written back to disk.

        Args:
            segments: Segment dictionaries, passed unchanged to the vector store.
        """
        self.vector_store.add_documents(segments)

        # .get() keeps this working for configs that predate the counter.
        self.config["document_count"] = (
            self.config.get("document_count", 0) + len(segments)
        )
        self.config["updated_at"] = self._now()
        self._save_config(self.config)

    def search(self, query: str, k: int = 8):
        """Search this knowledge base; delegates to ``VectorStore.search``.

        Args:
            query: Query text.
            k: Passed through to the vector store as the result limit.
        """
        return self.vector_store.search(query, k)

    def get_all_segments(self) -> List[Dict]:
        """Return every segment held by the vector store."""
        return self.vector_store.get_all_segments()

    def delete(self):
        """Remove the knowledge base directory and everything inside it."""
        if os.path.exists(self.base_path):
            shutil.rmtree(self.base_path)
|
|
|
|
|
|
class KnowledgeBaseManager:
    """Registry of the knowledge bases stored under one base directory.

    Every sub-directory of ``base_path`` is treated as one knowledge base and
    is kept in ``self.knowledge_bases``, keyed by its directory name.
    """

    def __init__(self, base_path: str = "knowledge_bases"):
        """Create the base directory if needed and load what it contains.

        Args:
            base_path: Directory that holds one sub-directory per knowledge base.
        """
        self.base_path = base_path
        os.makedirs(self.base_path, exist_ok=True)
        self.knowledge_bases: Dict[str, "KnowledgeBase"] = {}
        self._load_knowledge_bases()

    @staticmethod
    def _validate_name(name: str) -> None:
        """Reject names that cannot safely be used as a directory name.

        The name is joined onto ``base_path`` and that directory is later
        removed recursively by ``delete``; an empty name, ``.``/``..`` or a
        name containing a path separator would point at the wrong directory.

        Raises:
            ValueError: If ``name`` is not a usable single path component.
        """
        usable = (
            isinstance(name, str)
            and bool(name.strip())
            and name not in (".", "..")
            and not any(ch in name for ch in ("/", "\\", "\x00"))
            and os.path.basename(name) == name
        )
        if not usable:
            raise ValueError(f"无效的知识库名称: {name!r}")

    def _load_knowledge_bases(self):
        """Open every sub-directory of ``base_path`` as a knowledge base.

        A directory that fails to load is reported on stdout and skipped so
        that one broken knowledge base does not block the others.
        """
        if not os.path.exists(self.base_path):
            return

        for item in os.listdir(self.base_path):
            if not os.path.isdir(os.path.join(self.base_path, item)):
                continue
            try:
                self.knowledge_bases[item] = KnowledgeBase(item, self.base_path)
            except Exception as e:
                print(f"加载知识库 {item} 失败: {e}")

    def create_knowledge_base(self, name: str, description: str = "") -> "KnowledgeBase":
        """Create, register and return a new knowledge base.

        Args:
            name: Name of the new knowledge base; must be a single, non-empty
                path component.
            description: Free-text description stored in its configuration.

        Returns:
            The newly created knowledge base.

        Raises:
            ValueError: If ``name`` is invalid or already registered.
        """
        self._validate_name(name)
        if name in self.knowledge_bases:
            raise ValueError(f"知识库 {name} 已存在")

        kb = KnowledgeBase(name, self.base_path)
        kb.config["description"] = description
        kb._save_config(kb.config)

        self.knowledge_bases[name] = kb
        return kb

    def get_knowledge_base(self, name: str) -> Optional["KnowledgeBase"]:
        """Return the knowledge base registered as ``name``, or ``None``."""
        return self.knowledge_bases.get(name)

    def list_knowledge_bases(self) -> List[Dict]:
        """Return name, description and document count of every knowledge base."""
        return [
            {
                "name": name,
                "description": kb.config.get("description", ""),
                "document_count": kb.config.get("document_count", 0),
            }
            for name, kb in self.knowledge_bases.items()
        ]

    def delete_knowledge_base(self, name: str):
        """Delete the knowledge base ``name`` from disk and from the registry.

        Unknown names are ignored.
        """
        kb = self.knowledge_bases.get(name)
        if kb is None:
            return
        # Unregister only after the on-disk delete succeeded.
        kb.delete()
        del self.knowledge_bases[name]

    def search_all(self, query: str, k: int = 8) -> Dict[str, List]:
        """Run ``query`` against every knowledge base.

        Args:
            query: Query text.
            k: Passed to each knowledge base's ``search`` as the result limit.

        Returns:
            A mapping from knowledge base name to that knowledge base's results.
        """
        return {
            name: kb.search(query, k)
            for name, kb in self.knowledge_bases.items()
        }
|
|
|
|
|
|
# Module-level knowledge base manager shared by importers of this module.
# NOTE: it is constructed at import time with the default base path, which
# creates the "knowledge_bases" directory and loads every knowledge base in it.
kb_manager = KnowledgeBaseManager()