import os
import json
from typing import List, Dict, Tuple

import numpy as np
import faiss

from config import settings
from model_manager import model_router
from exceptions import VectorStoreError


class VectorStore:
    def __init__(self,
                 index_file: str = settings.VECTOR_INDEX_FILE,
                 metadata_file: str = settings.METADATA_FILE,
                 content_file: str = "content.json"):
        self.index_file = index_file
        self.metadata_file = metadata_file
        self.content_file = content_file
        self.dimension = 1536  # text-embedding-ada-002 dimension
        self.index = None
        self.metadata = []
        self.content = []

        # Initialize client based on provider
        self.client = self._initialize_client()

        # Load existing index and metadata if they exist
        self.load()

    def _initialize_client(self):
        """Initialize API client. Embeddings go through the model router, so no dedicated client is needed."""
        return None

    def _initialize_index(self):
        """Initialize the vector index."""
        # Use an IVF index to speed up retrieval on larger collections.
        nlist = 100  # number of cluster centroids
        quantizer = faiss.IndexFlatIP(self.dimension)
        self.index = faiss.IndexIVFFlat(quantizer, self.dimension, nlist,
                                        faiss.METRIC_INNER_PRODUCT)

        # Train the index on a sample of existing content, if any is available.
        if len(self.content) > 0:
            sample = self.content[:min(1000, len(self.content))]
            embeddings = self._get_embeddings(sample)
            embeddings_array = np.array(embeddings, dtype=np.float32)
            self.index.train(embeddings_array)

    def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for texts using the model router."""
        try:
            return model_router.get_embeddings(settings.EMBEDDING_MODEL, texts)
        except Exception as e:
            raise VectorStoreError(f"Failed to get embeddings: {str(e)}") from e

    def add_documents(self, segments: List[Dict]):
        """Add document segments to the vector store."""
        if not segments:
            return

        # Extract texts from segments
        texts = [segment["content"] for segment in segments]
        contents = texts[:]  # keep a copy of the original content

        # Get embeddings and L2-normalize them so inner product equals cosine similarity
        embeddings = self._get_embeddings(texts)
        embeddings = [np.array(emb) / np.linalg.norm(emb) for emb in embeddings]
        embeddings_array = np.array(embeddings, dtype=np.float32)

        # Initialize index if needed
        if self.index is None:
            self._initialize_index()

        # Train the IVF index before the first add
        if not self.index.is_trained:
            self.index.train(embeddings_array)

        # Add vectors to the index
        self.index.add(embeddings_array)

        # Add metadata per segment, and extend content once for the whole batch
        for segment in segments:
            self.metadata.append(segment["metadata"])
        self.content.extend(contents)

        # Persist after adding
        self.save()

    def search(self, query: str, k: int = settings.TOP_K) -> List[Tuple[Dict, float]]:
        """Search for similar segments."""
        if self.index is None or len(self.metadata) == 0:
            return []

        # Get and normalize the query embedding
        query_embedding = self._get_embeddings([query])[0]
        query_embedding = np.array(query_embedding) / np.linalg.norm(query_embedding)
        query_embedding = np.array([query_embedding], dtype=np.float32)

        # Search. With METRIC_INNER_PRODUCT the returned values are already
        # similarities (higher is better), so use them directly as scores.
        scores, indices = self.index.search(query_embedding, k)

        # Prepare results
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if 0 <= idx < len(self.metadata):  # faiss returns -1 when fewer than k hits exist
                result = {
                    "metadata": self.metadata[idx],
                    "content": self.content[idx] if idx < len(self.content) else "",
                    "score": float(score),
                }
                results.append((result, float(score)))

        return results

    def get_all_segments(self) -> List[Dict]:
        """Get all segments with their content."""
        segments = []
        for i, metadata in enumerate(self.metadata):
            segment = {
                "content": self.content[i] if i < len(self.content) else "",
                "metadata": metadata,
            }
            segments.append(segment)
        return segments

    def save(self):
        """Save index, metadata, and content to disk."""
        if self.index is not None:
            faiss.write_index(self.index, self.index_file)
        with open(self.metadata_file, "w", encoding="utf-8") as f:
            json.dump(self.metadata, f, ensure_ascii=False, indent=2)
        with open(self.content_file, "w", encoding="utf-8") as f:
            json.dump(self.content, f, ensure_ascii=False, indent=2)

    def load(self):
        """Load index, metadata, and content from disk."""
        if os.path.exists(self.index_file):
            self.index = faiss.read_index(self.index_file)
        if os.path.exists(self.metadata_file):
            with open(self.metadata_file, "r", encoding="utf-8") as f:
                self.metadata = json.load(f)
        if os.path.exists(self.content_file):
            with open(self.content_file, "r", encoding="utf-8") as f:
                self.content = json.load(f)
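

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes `settings`, `model_router`, and the embedding backend are already
# configured, and that segments follow the {"content": ..., "metadata": ...}
# shape expected by add_documents(). The segment texts below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    store = VectorStore()

    # The IVF index needs at least `nlist` (100) training vectors, so index a
    # reasonably sized batch rather than one or two toy documents.
    segments = [
        {
            "content": f"Example passage number {i} about vector search.",
            "metadata": {"source": "example.txt", "chunk": i},
        }
        for i in range(200)
    ]
    store.add_documents(segments)

    # Query the store; scores are cosine similarities (higher is better).
    for hit, score in store.search("passage about vector search", k=3):
        print(f"{score:.3f}  {hit['metadata']}  {hit['content']}")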