nodebookls/vector_store.py
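"""FAISS-backed vector store: embeds text segments via the model router and
persists the index, metadata, and content to disk."""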

import os
import json
import numpy as np
import faiss
from typing import List, Dict, Tuple
from config import settings
from model_manager import model_router
from exceptions import VectorStoreError

class VectorStore:
    def __init__(self, index_file: str = settings.VECTOR_INDEX_FILE,
                 metadata_file: str = settings.METADATA_FILE,
                 content_file: str = "content.json"):
        self.index_file = index_file
        self.metadata_file = metadata_file
        self.content_file = content_file
        self.dimension = 1536  # text-embedding-ada-002 dimension
        self.index = None
        self.metadata = []
        self.content = []
        # Initialize client based on provider
        self.client = self._initialize_client()
        # Load existing index and metadata if they exist
        self.load()

    def _initialize_client(self):
        """Initialize API client: the model router manages clients, so no separate client is needed."""
        return None

    def _initialize_index(self):
        """Initialize the vector index."""
        # Use an IVF index to speed up retrieval
        nlist = 100  # number of cluster centroids
        quantizer = faiss.IndexFlatIP(self.dimension)
        self.index = faiss.IndexIVFFlat(quantizer, self.dimension, nlist, faiss.METRIC_INNER_PRODUCT)
        # Train the index if data already exists
        if len(self.content) > 0:
            # Embed a sample of the existing content for training
            embeddings = self._get_embeddings(self.content[:min(1000, len(self.content))])
            embeddings_array = np.array(embeddings, dtype=np.float32)
            # Normalize so the training data matches the normalized vectors added later
            embeddings_array /= np.linalg.norm(embeddings_array, axis=1, keepdims=True)
            self.index.train(embeddings_array)

    def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for texts using the model router."""
        try:
            return model_router.get_embeddings(settings.EMBEDDING_MODEL, texts)
        except Exception as e:
            raise VectorStoreError(f"Failed to get embeddings: {e}") from e

    def add_documents(self, segments: List[Dict]):
        """Add document segments to the vector store."""
        if not segments:
            return
        # Extract texts from segments
        texts = [segment["content"] for segment in segments]
        # Get embeddings and L2-normalize them so inner product equals cosine similarity
        embeddings = self._get_embeddings(texts)
        embeddings = [np.array(emb) / np.linalg.norm(emb) for emb in embeddings]
        embeddings_array = np.array(embeddings, dtype=np.float32)
        # Initialize index on first use
        if self.index is None:
            self._initialize_index()
        # Train the index if it has not been trained yet
        # (note: IVF training needs at least nlist vectors to succeed)
        if not self.index.is_trained:
            self.index.train(embeddings_array)
        # Add vectors to the index
        self.index.add(embeddings_array)
        # Record metadata and content in the same order as the vectors
        for segment in segments:
            self.metadata.append(segment["metadata"])
        self.content.extend(texts)
        # Persist after adding
        self.save()

    def search(self, query: str, k: int = settings.TOP_K) -> List[Tuple[Dict, float]]:
        """Search for similar segments."""
        if self.index is None or len(self.metadata) == 0:
            return []
        # Get and L2-normalize the query embedding
        query_embedding = self._get_embeddings([query])[0]
        query_embedding = np.array(query_embedding) / np.linalg.norm(query_embedding)
        query_embedding = np.array([query_embedding], dtype=np.float32)
        # Search; with METRIC_INNER_PRODUCT on normalized vectors the returned
        # scores are already cosine similarities (higher is better)
        scores, indices = self.index.search(query_embedding, k)
        # Prepare results, skipping the -1 padding FAISS returns when there
        # are fewer than k matches
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if 0 <= idx < len(self.metadata):
                result = {
                    "metadata": self.metadata[idx],
                    "content": self.content[idx] if idx < len(self.content) else "",
                    "score": float(score)
                }
                results.append((result, float(score)))
        return results

    def get_all_segments(self) -> List[Dict]:
        """Get all segments with their content."""
        segments = []
        for i, metadata in enumerate(self.metadata):
            segment = {
                "content": self.content[i] if i < len(self.content) else "",
                "metadata": metadata
            }
            segments.append(segment)
        return segments

    def save(self):
        """Save index, metadata, and content to disk."""
        if self.index is not None:
            faiss.write_index(self.index, self.index_file)
        with open(self.metadata_file, "w", encoding="utf-8") as f:
            json.dump(self.metadata, f, ensure_ascii=False, indent=2)
        with open(self.content_file, "w", encoding="utf-8") as f:
            json.dump(self.content, f, ensure_ascii=False, indent=2)

    def load(self):
        """Load index, metadata, and content from disk."""
        if os.path.exists(self.index_file):
            self.index = faiss.read_index(self.index_file)
        if os.path.exists(self.metadata_file):
            with open(self.metadata_file, "r", encoding="utf-8") as f:
                self.metadata = json.load(f)
        if os.path.exists(self.content_file):
            with open(self.content_file, "r", encoding="utf-8") as f:
                self.content = json.load(f)
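
# Minimal usage sketch, assuming config.settings and model_manager.model_router
# are configured for a working embedding backend; the segment shape below
# mirrors what add_documents expects. Note that the IVF index needs at least
# nlist (100) vectors before training succeeds, so real usage adds many
# segments before the first search.
if __name__ == "__main__":
    store = VectorStore()
    store.add_documents([
        {"content": "FAISS is a library for similarity search.",
         "metadata": {"source": "notes.md", "chunk": 0}},
    ])
    for result, score in store.search("what is FAISS?", k=3):
        print(f"{score:.3f}  {result['content'][:60]}")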