import hashlib
from typing import Dict, List, Tuple

import numpy as np
from rank_bm25 import BM25Okapi

from vector_store import VectorStore
from config import settings
from exceptions import RetrievalError
from cache_manager import cache_manager


class HybridRetriever:
    """Retrieve document segments by fusing BM25 keyword search with vector search.

    The BM25 index is built in memory by :meth:`prepare_bm25`; until that has
    been called with a non-empty corpus, :meth:`hybrid_search` falls back to
    the vector store alone.
    """

    def __init__(self, vector_store: VectorStore):
        """Store the vector store and start with an empty BM25 index.

        Args:
            vector_store: Backend used for the vector half of the search.
        """
        self.vector_store = vector_store
        self.bm25 = None
        self.segments = []
        self.segment_contents = []

    @staticmethod
    def _tokenize(text: str) -> List[str]:
        """Split *text* on whitespace.

        Indexing and querying must tokenize identically, so both go through
        this single helper.
        """
        return text.split()

    def prepare_bm25(self, segments: List[Dict]):
        """Build the BM25 index over *segments*.

        Args:
            segments: Document segments; each must have a ``"content"`` string.
                Other methods additionally read ``segment["metadata"]``.

        If the corpus is empty or contains no tokens at all, no index is built
        and ``self.bm25`` is left as ``None`` so searches use the vector store
        only.
        """
        self.segments = segments
        self.segment_contents = [segment["content"] for segment in segments]

        tokenized_corpus = [self._tokenize(doc) for doc in self.segment_contents]

        # BM25Okapi divides by the corpus size and by the vocabulary size
        # while initialising, so a corpus without any token raises
        # ZeroDivisionError. Skip indexing instead of crashing.
        if not any(tokenized_corpus):
            self.bm25 = None
            return

        self.bm25 = BM25Okapi(tokenized_corpus)

    def reciprocal_rank_fusion(
        self,
        bm25_results: List[Tuple[Dict, float]],
        vector_results: List[Tuple[Dict, float]],
        k: int = 60,
    ) -> List[Tuple[Dict, float]]:
        """Fuse two ranked result lists with Reciprocal Rank Fusion (RRF).

        Rank is taken from list position (first item has rank 1), so both
        inputs must already be ordered best-first.

        Args:
            bm25_results: BM25 results as ``(segment, score)`` tuples; each
                segment needs ``segment["metadata"]["segment_id"]``.
            vector_results: Vector-search results in the same shape.
            k: RRF smoothing constant added to every rank.

        Returns:
            ``(segment, rrf_score)`` tuples sorted by fused score, highest
            first. When a segment occurs in both lists, the BM25 copy of the
            segment dict is returned. Ties keep first-seen order (BM25 list
            first, then the vector list).
        """
        # Map segment id -> (segment, 1-based rank). A duplicated id keeps its
        # last occurrence, as dict construction overwrites earlier entries.
        bm25_map = {
            entry[0]["metadata"]["segment_id"]: (entry[0], rank)
            for rank, entry in enumerate(bm25_results, start=1)
        }
        vector_map = {
            entry[0]["metadata"]["segment_id"]: (entry[0], rank)
            for rank, entry in enumerate(vector_results, start=1)
        }

        # A segment absent from one list is ranked just past that list's end.
        bm25_missing_rank = len(bm25_results) + 1
        vector_missing_rank = len(vector_results) + 1

        # Insertion-ordered union of ids (instead of a set) so that segments
        # with equal fused scores come out in a deterministic order.
        all_segment_ids = dict.fromkeys([*bm25_map, *vector_map])

        fused_scores = {}
        for segment_id in all_segment_ids:
            bm25_rank = (
                bm25_map[segment_id][1] if segment_id in bm25_map else bm25_missing_rank
            )
            vector_rank = (
                vector_map[segment_id][1]
                if segment_id in vector_map
                else vector_missing_rank
            )
            # RRF formula: sum of reciprocal (rank + k) over both rankings.
            fused_scores[segment_id] = 1.0 / (bm25_rank + k) + 1.0 / (vector_rank + k)

        # sorted() is stable, also with reverse=True, so ties keep id order.
        ranked = sorted(fused_scores.items(), key=lambda item: item[1], reverse=True)

        results = []
        for segment_id, score in ranked:
            source = bm25_map if segment_id in bm25_map else vector_map
            results.append((source[segment_id][0], score))
        return results

    def hybrid_search(
        self,
        query: str,
        k: int = settings.TOP_K,
        bm25_weight: float = 0.5,
    ) -> List[Tuple[Dict, float]]:
        """Search with BM25 and the vector store, fusing the rankings with RRF.

        Args:
            query: Search query.
            k: Number of results to return.
            bm25_weight: Currently only part of the cache key; the RRF fusion
                below does not weight the two rankings.

        Returns:
            Up to *k* ``(segment, score)`` tuples, best first. The score is
            the RRF score when both searches produced results, otherwise the
            score of whichever search did.
        """
        # NOTE(review): the key does not reflect the indexed corpus, so entries
        # cached before a later prepare_bm25() call may be stale unless
        # cache_manager expires them - confirm.
        cache_key = f"hybrid_search:{hashlib.md5(f'{query}:{k}:{bm25_weight}'.encode()).hexdigest()}"

        cached_result = cache_manager.get(cache_key)
        if cached_result is not None:
            return cached_result

        if self.bm25 is None:
            # No BM25 index available: fall back to vector search only.
            result = self.vector_store.search(query, k)
            cache_manager.put(cache_key, result)
            return result

        bm25_scores = self.bm25.get_scores(self._tokenize(query))

        # One BM25 candidate per indexed segment, in the same shape as the
        # tuples consumed by reciprocal_rank_fusion.
        bm25_results = []
        for segment, raw_score in zip(self.segments, bm25_scores):
            score = float(raw_score)
            candidate = {
                "metadata": segment["metadata"],
                "content": segment["content"],
                "score": score,
                "bm25_score": score,
                "vector_score": 0.0,
            }
            bm25_results.append((candidate, score))

        # RRF reads the rank from the list position, so the candidates must be
        # ordered by score before fusing; corpus order is not a ranking.
        bm25_results.sort(key=lambda pair: pair[1], reverse=True)

        # Over-fetch so the fusion has more vector candidates to re-rank.
        # presumably returns (segment, score) tuples, best first - confirm
        # against VectorStore.search.
        vector_results = self.vector_store.search(query, k * 2)

        if bm25_results and vector_results:
            result = self.reciprocal_rank_fusion(bm25_results, vector_results, k=60)[:k]
        elif bm25_results:
            # Only BM25 produced candidates (already sorted above).
            result = bm25_results[:k]
        elif vector_results:
            # Only the vector store produced candidates.
            result = vector_results[:k]
        else:
            result = []

        cache_manager.put(cache_key, result)
        return result

    def search_with_context(
        self,
        query: str,
        k: int = settings.TOP_K,
        include_context: bool = True,
    ) -> List[Tuple[Dict, float]]:
        """Run :meth:`hybrid_search` and attach neighbouring segments to each hit.

        Args:
            query: Search query.
            k: Number of results to return.
            include_context: When false, the hybrid-search results are
                returned unchanged.

        Returns:
            ``(segment, score)`` tuples. With context enabled each segment is
            a shallow copy carrying a ``"context"`` dict whose ``"previous"``
            and ``"next"`` entries hold the segment of the same file with
            ``paragraph_id`` one lower / one higher, or ``None`` if there is
            no such segment.
        """
        results = self.hybrid_search(query, k)
        if not include_context or not results:
            return results

        # Index the segments once per call rather than scanning the full list
        # twice for every hit. setdefault keeps the first segment seen for a
        # key, matching a first-match linear search.
        by_position = {}
        for segment in self.segments:
            metadata = segment["metadata"]
            position = (metadata.get("file_name"), metadata.get("paragraph_id"))
            by_position.setdefault(position, segment)

        enriched_results = []
        for result, score in results:
            file_name = result["metadata"]["file_name"]
            paragraph_id = result["metadata"]["paragraph_id"]

            # Shallow copy so the "context" key is not written into the dict
            # returned (and possibly cached) by hybrid_search.
            enriched_result = result.copy()
            enriched_result["context"] = {
                "previous": by_position.get((file_name, paragraph_id - 1)),
                "next": by_position.get((file_name, paragraph_id + 1)),
            }
            enriched_results.append((enriched_result, score))

        return enriched_results