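"""FAISS-backed vector store for document segments.

Embeds text through the project's model router, keeps vectors in an IVF
inner-product index, and persists the index, metadata, and raw content
to disk.
"""
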
import json
import os
from typing import Dict, List, Tuple

import faiss
import numpy as np

from config import settings
from exceptions import VectorStoreError
from model_manager import model_router


class VectorStore:
    def __init__(self, index_file: str = settings.VECTOR_INDEX_FILE,
                 metadata_file: str = settings.METADATA_FILE,
                 content_file: str = "content.json"):
        self.index_file = index_file
        self.metadata_file = metadata_file
        self.content_file = content_file
        self.dimension = 1536  # text-embedding-ada-002 dimension
        self.index = None
        self.metadata = []
        self.content = []

        # Initialize client based on provider
        self.client = self._initialize_client()

        # Load existing index and metadata if they exist
        self.load()

    def _initialize_client(self):
        """Initialize the API client. Embeddings go through the model
        router, so no standalone client needs to be created here."""
        return None

    def _initialize_index(self):
        """Initialize the vector index."""
        # Use an IVF index to speed up retrieval on larger collections.
        nlist = 100  # number of cluster centroids
        quantizer = faiss.IndexFlatIP(self.dimension)
        self.index = faiss.IndexIVFFlat(quantizer, self.dimension, nlist,
                                        faiss.METRIC_INNER_PRODUCT)

        # Train the index if data already exists. IVF training needs at
        # least `nlist` vectors to build its centroids.
        if len(self.content) > 0:
            # Sample existing content for training
            embeddings = self._get_embeddings(self.content[:1000])
            embeddings_array = np.array(embeddings, dtype=np.float32)
            # Normalize so the training distribution matches the
            # normalized vectors that add_documents() inserts.
            faiss.normalize_L2(embeddings_array)
            self.index.train(embeddings_array)

    def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for texts using the model router."""
        try:
            return model_router.get_embeddings(settings.EMBEDDING_MODEL, texts)
        except Exception as e:
            raise VectorStoreError(f"Failed to get embeddings: {e}") from e

    def add_documents(self, segments: List[Dict]):
        """Add document segments to the vector store."""
        if not segments:
            return

        # Extract texts from segments
        texts = [segment["content"] for segment in segments]

        # Embed and L2-normalize, so that inner product equals cosine
        # similarity
        embeddings = self._get_embeddings(texts)
        embeddings_array = np.array(embeddings, dtype=np.float32)
        faiss.normalize_L2(embeddings_array)

        # Initialize the index on first use
        if self.index is None:
            self._initialize_index()

        # Train the index if it has not been trained yet
        if not self.index.is_trained:
            self.index.train(embeddings_array)

        # Add vectors, then the matching metadata and content
        self.index.add(embeddings_array)
        for segment in segments:
            self.metadata.append(segment["metadata"])
        self.content.extend(texts)

        # Persist after every addition
        self.save()

    def search(self, query: str, k: int = settings.TOP_K) -> List[Tuple[Dict, float]]:
        """Search for similar segments."""
        if self.index is None or len(self.metadata) == 0:
            return []

        # Get and normalize the query embedding
        query_embedding = self._get_embeddings([query])[0]
        query_embedding = np.array([query_embedding], dtype=np.float32)
        faiss.normalize_L2(query_embedding)

        # Search. With METRIC_INNER_PRODUCT and normalized vectors, the
        # returned "distances" are cosine similarities (higher is better),
        # so they can be used directly as scores.
        scores, indices = self.index.search(query_embedding, k)

        # Prepare results. FAISS pads missing results with index -1, so
        # guard against negative indices as well as out-of-range ones.
        results = []
        for score, idx in zip(scores[0], indices[0]):
            if 0 <= idx < len(self.metadata):
                result = {
                    "metadata": self.metadata[idx],
                    "content": self.content[idx] if idx < len(self.content) else "",
                    "score": float(score),
                }
                results.append((result, float(score)))

        return results

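    # Recall tuning (a sketch, not wired in here): IndexIVFFlat scans only
    # `nprobe` clusters per query, and FAISS defaults to nprobe=1, which can
    # miss neighbors that fall in other clusters. If recall looks low, raise
    # it after the index is loaded, e.g.:
    #
    #     store.index.nprobe = 10  # probe 10 of the nlist=100 clusters
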
    def get_all_segments(self) -> List[Dict]:
        """Get all segments with their content."""
        segments = []
        for i, metadata in enumerate(self.metadata):
            segment = {
                "content": self.content[i] if i < len(self.content) else "",
                "metadata": metadata,
            }
            segments.append(segment)
        return segments

    def save(self):
        """Save index, metadata, and content to disk."""
        if self.index is not None:
            faiss.write_index(self.index, self.index_file)

        with open(self.metadata_file, "w", encoding="utf-8") as f:
            json.dump(self.metadata, f, ensure_ascii=False, indent=2)

        with open(self.content_file, "w", encoding="utf-8") as f:
            json.dump(self.content, f, ensure_ascii=False, indent=2)

    def load(self):
        """Load index, metadata, and content from disk."""
        if os.path.exists(self.index_file):
            self.index = faiss.read_index(self.index_file)

        if os.path.exists(self.metadata_file):
            with open(self.metadata_file, "r", encoding="utf-8") as f:
                self.metadata = json.load(f)

        if os.path.exists(self.content_file):
            with open(self.content_file, "r", encoding="utf-8") as f:
                self.content = json.load(f)
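

# Minimal usage sketch. Assumes `config.settings` supplies VECTOR_INDEX_FILE,
# METADATA_FILE, EMBEDDING_MODEL, and TOP_K, and that `model_router` can reach
# a working embedding backend; the demo segments below are illustrative only.
if __name__ == "__main__":
    store = VectorStore()

    # IVF training needs at least `nlist` (100) vectors, so index a batch
    # that is large enough before querying.
    docs = [{"content": f"Sample passage {i} about vector search.",
             "metadata": {"source": "demo", "chunk": i}}
            for i in range(120)]
    store.add_documents(docs)

    for result, score in store.search("vector search", k=3):
        print(f"{score:.3f}  {result['content']}")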