nodebookls/hybrid_retriever.py
2025-10-29 13:56:24 +08:00

199 lines
7.1 KiB
Python

import hashlib
from typing import Dict, List, Optional, Tuple

import numpy as np
from rank_bm25 import BM25Okapi

from cache_manager import cache_manager
from config import settings
from exceptions import RetrievalError
from vector_store import VectorStore
class HybridRetriever:
    """Combine sparse (BM25) and dense (vector) retrieval over document segments.

    Results from the two retrievers are merged with Reciprocal Rank Fusion
    (RRF) and the final ranked list is cached per (query, k, weight) key.
    """

    def __init__(self, vector_store: "VectorStore"):
        """
        Args:
            vector_store: Backend used for dense (embedding) similarity search.
        """
        self.vector_store = vector_store
        self.bm25 = None              # BM25Okapi index; built lazily by prepare_bm25()
        self.segments = []            # raw segment dicts with "content" and "metadata"
        self.segment_contents = []    # plain text per segment, parallel to self.segments

    def prepare_bm25(self, segments: List[Dict]):
        """
        Build the in-memory BM25 index over the given segments.

        Args:
            segments: Document segments; each must carry a "content" key and a
                "metadata" dict (used later at search time).
        """
        self.segments = segments
        self.segment_contents = [segment["content"] for segment in segments]
        # NOTE(review): whitespace tokenization — fine for space-delimited
        # languages, but CJK text would need a real tokenizer; confirm inputs.
        tokenized_corpus = [doc.split() for doc in self.segment_contents]
        self.bm25 = BM25Okapi(tokenized_corpus)

    def reciprocal_rank_fusion(self, bm25_results: List[Tuple[Dict, float]],
                               vector_results: List[Tuple[Dict, float]],
                               k: int = 60) -> List[Tuple[Dict, float]]:
        """
        Fuse BM25 and vector results with Reciprocal Rank Fusion (RRF).

        Both input lists MUST be sorted best-first: list position is used as
        the 1-based rank in the RRF formula.

        Args:
            bm25_results: (segment, score) pairs from BM25, best first.
            vector_results: (segment, score) pairs from vector search, best first.
            k: RRF smoothing constant (60 per the original RRF paper).

        Returns:
            (segment, rrf_score) pairs sorted by fused score, descending.
        """
        # Map segment_id -> (segment, score, 1-based rank) for each retriever.
        bm25_map = {res[0]["metadata"]["segment_id"]: (res[0], res[1], i + 1)
                    for i, res in enumerate(bm25_results)}
        vector_map = {res[0]["metadata"]["segment_id"]: (res[0], res[1], i + 1)
                      for i, res in enumerate(vector_results)}
        all_segment_ids = set(bm25_map) | set(vector_map)
        fused_scores = {}
        for segment_id in all_segment_ids:
            # A segment absent from one list is treated as ranked just past its end.
            bm25_rank = bm25_map[segment_id][2] if segment_id in bm25_map else len(bm25_results) + 1
            vector_rank = vector_map[segment_id][2] if segment_id in vector_map else len(vector_results) + 1
            # Standard RRF: sum of 1 / (rank + k) over the retrievers.
            fused_scores[segment_id] = 1.0 / (bm25_rank + k) + 1.0 / (vector_rank + k)
        sorted_segments = sorted(fused_scores.items(), key=lambda item: item[1], reverse=True)
        results = []
        for segment_id, score in sorted_segments:
            # Prefer the BM25 payload when present (same content either way).
            entry = bm25_map.get(segment_id) or vector_map.get(segment_id)
            results.append((entry[0], score))
        return results

    def hybrid_search(self, query: str, k: Optional[int] = None,
                      bm25_weight: float = 0.5) -> List[Tuple[Dict, float]]:
        """
        Perform hybrid search using BM25 and vector search, fused with RRF.

        Args:
            query: Search query.
            k: Number of results to return; defaults to settings.TOP_K,
                resolved at call time (the original froze it at import time).
            bm25_weight: Kept for interface compatibility. RRF is rank-based,
                so this weight is currently NOT applied — it only participates
                in the cache key.

        Returns:
            List of (segment, score) tuples sorted by combined score.
        """
        if k is None:
            k = settings.TOP_K
        cache_key = f"hybrid_search:{hashlib.md5(f'{query}:{k}:{bm25_weight}'.encode()).hexdigest()}"
        cached_result = cache_manager.get(cache_key)
        if cached_result is not None:
            return cached_result
        if self.bm25 is None:
            # No BM25 index yet: fall back to pure vector search.
            result = self.vector_store.search(query, k)
            cache_manager.put(cache_key, result)
            return result
        # Score every indexed segment against the tokenized query.
        bm25_scores = self.bm25.get_scores(query.split())
        bm25_results = []
        for i, segment in enumerate(self.segments):
            score = float(bm25_scores[i])
            bm25_results.append(({
                "metadata": segment["metadata"],
                "content": segment["content"],
                "score": score,
                "bm25_score": score,
                "vector_score": 0.0,
            }, score))
        # BUGFIX: reciprocal_rank_fusion derives ranks from list position, but
        # the scores above are in corpus order. Sort best-first before fusing,
        # otherwise BM25 ranks are effectively arbitrary.
        bm25_results.sort(key=lambda pair: pair[1], reverse=True)
        # Fetch extra vector hits so fusion has material to re-rank.
        vector_results = self.vector_store.search(query, k * 2)
        if bm25_results and vector_results:
            result = self.reciprocal_rank_fusion(bm25_results, vector_results, k=60)[:k]
        elif bm25_results:
            result = bm25_results[:k]
        elif vector_results:
            result = vector_results[:k]
        else:
            result = []
        cache_manager.put(cache_key, result)
        return result

    def search_with_context(self, query: str, k: Optional[int] = None,
                            include_context: bool = True) -> List[Tuple[Dict, float]]:
        """
        Search and attach neighboring segments (previous/next paragraph).

        Args:
            query: Search query.
            k: Number of results to return; defaults to settings.TOP_K.
            include_context: Whether to attach a "context" dict to each hit.

        Returns:
            List of (segment, score) tuples; each segment gains a "context"
            key with "previous" and "next" neighbors (or None) when
            include_context is True.
        """
        results = self.hybrid_search(query, k)
        if not include_context:
            return results
        # Index segments by (file_name, paragraph_id) once, instead of
        # scanning the whole corpus twice per result (was O(len(segments)*k)).
        by_position = {
            (seg["metadata"]["file_name"], seg["metadata"]["paragraph_id"]): seg
            for seg in self.segments
        }
        enriched_results = []
        for result, score in results:
            file_name = result["metadata"]["file_name"]
            paragraph_id = result["metadata"]["paragraph_id"]
            enriched_result = result.copy()  # shallow copy: metadata dict is shared
            enriched_result["context"] = {
                "previous": by_position.get((file_name, paragraph_id - 1)),
                "next": by_position.get((file_name, paragraph_id + 1)),
            }
            enriched_results.append((enriched_result, score))
        return enriched_results