# 199 lines, 7.1 KiB, Python
import numpy as np
|
|
from typing import List, Tuple, Dict
|
|
from rank_bm25 import BM25Okapi
|
|
from vector_store import VectorStore
|
|
from config import settings
|
|
from exceptions import RetrievalError
|
|
from cache_manager import cache_manager
|
|
import hashlib
|
|
|
|
class HybridRetriever:
    """Hybrid document retriever combining BM25 keyword search with vector search.

    BM25 scores are computed in-process over segments registered via
    prepare_bm25(); semantic similarity comes from the injected VectorStore.
    The two rankings are merged with Reciprocal Rank Fusion (RRF), and
    query results are memoized through the shared cache_manager.
    """

    def __init__(self, vector_store: VectorStore):
        self.vector_store = vector_store
        # BM25Okapi index; stays None until prepare_bm25() is called, in
        # which case hybrid_search() falls back to pure vector search.
        self.bm25 = None
        # Raw segment dicts (each with "content" and "metadata" keys).
        self.segments: List[Dict] = []
        # Plain-text bodies, parallel to self.segments.
        self.segment_contents: List[str] = []

    def prepare_bm25(self, segments: List[Dict]):
        """
        Build the in-memory BM25 index over the given segments.

        Args:
            segments: Document segments; each needs a "content" key
                (segment text) and a "metadata" dict.
        """
        self.segments = segments
        self.segment_contents = [segment["content"] for segment in segments]
        # Whitespace tokenization — adequate for space-delimited languages;
        # NOTE(review): CJK input would need a real tokenizer — confirm corpus.
        tokenized_corpus = [doc.split() for doc in self.segment_contents]
        self.bm25 = BM25Okapi(tokenized_corpus)

    def reciprocal_rank_fusion(self, bm25_results: List[Tuple[Dict, float]],
                               vector_results: List[Tuple[Dict, float]],
                               k: int = 60) -> List[Tuple[Dict, float]]:
        """
        Fuse BM25 and vector results with Reciprocal Rank Fusion (RRF).

        Both input lists must already be sorted best-first: a segment's
        fused score is the sum over both lists of 1 / (rank + k), where a
        segment absent from one list is penalized with rank len(list) + 1.

        Args:
            bm25_results: (segment, score) pairs from BM25, best first.
            vector_results: (segment, score) pairs from vector search, best first.
            k: RRF smoothing constant (larger flattens rank differences).

        Returns:
            (segment, fused_score) pairs sorted by fused score, descending.
        """
        # Map segment_id -> (segment, original_score, 1-based rank).
        bm25_map = {seg["metadata"]["segment_id"]: (seg, score, rank)
                    for rank, (seg, score) in enumerate(bm25_results, start=1)}
        vector_map = {seg["metadata"]["segment_id"]: (seg, score, rank)
                      for rank, (seg, score) in enumerate(vector_results, start=1)}

        # Penalty rank assigned when a segment is missing from one list.
        bm25_miss = len(bm25_results) + 1
        vector_miss = len(vector_results) + 1

        fused_scores = {}
        for segment_id in bm25_map.keys() | vector_map.keys():
            bm25_rank = bm25_map[segment_id][2] if segment_id in bm25_map else bm25_miss
            vector_rank = vector_map[segment_id][2] if segment_id in vector_map else vector_miss
            # RRF formula: reciprocal of (rank + k), summed across rankers.
            fused_scores[segment_id] = 1.0 / (bm25_rank + k) + 1.0 / (vector_rank + k)

        results = []
        for segment_id, score in sorted(fused_scores.items(),
                                        key=lambda item: item[1], reverse=True):
            # Prefer the BM25 copy of the segment when both rankers saw it.
            entry = bm25_map.get(segment_id) or vector_map[segment_id]
            results.append((entry[0], score))
        return results

    def hybrid_search(self, query: str, k: int = settings.TOP_K,
                      bm25_weight: float = 0.5) -> List[Tuple[Dict, float]]:
        """
        Run hybrid search (BM25 + vector) and fuse the rankings with RRF.

        Falls back to pure vector search when the BM25 index has not been
        built. Results are cached per (query, k, bm25_weight).

        Args:
            query: Search query text.
            k: Number of results to return.
            bm25_weight: Accepted for backward compatibility; RRF fusion is
                rank-based and currently ignores this weight.

        Returns:
            List of (segment, score) tuples sorted by fused score.
        """
        # Cache lookup keyed by the full parameter set (md5 keeps keys short).
        cache_key = f"hybrid_search:{hashlib.md5(f'{query}:{k}:{bm25_weight}'.encode()).hexdigest()}"
        cached_result = cache_manager.get(cache_key)
        if cached_result is not None:
            return cached_result

        if self.bm25 is None:
            # No BM25 index yet: vector search only.
            result = self.vector_store.search(query, k)
            cache_manager.put(cache_key, result)
            return result

        # BM25 scoring over every indexed segment.
        bm25_scores = self.bm25.get_scores(query.split())
        bm25_results = []
        for segment, raw_score in zip(self.segments, bm25_scores):
            score = float(raw_score)
            bm25_results.append((
                {
                    "metadata": segment["metadata"],
                    "content": segment["content"],
                    "score": score,
                    "bm25_score": score,
                    "vector_score": 0.0,
                },
                score,
            ))
        # BUG FIX: RRF derives ranks from list position, but the list above
        # is in corpus order, not score order. Sort by BM25 score descending
        # and keep the same candidate count as the vector side.
        bm25_results.sort(key=lambda item: item[1], reverse=True)
        bm25_results = bm25_results[:k * 2]

        # Over-fetch vector candidates to give the re-ranker headroom.
        vector_results = self.vector_store.search(query, k * 2)

        if bm25_results and vector_results:
            result = self.reciprocal_rank_fusion(bm25_results, vector_results, k=60)[:k]
        elif bm25_results:
            # Already sorted best-first above.
            result = bm25_results[:k]
        elif vector_results:
            result = vector_results[:k]
        else:
            result = []

        cache_manager.put(cache_key, result)
        return result

    def search_with_context(self, query: str, k: int = settings.TOP_K,
                            include_context: bool = True) -> List[Tuple[Dict, float]]:
        """
        Hybrid search that optionally attaches neighboring segments.

        Args:
            query: Search query text.
            k: Number of results to return.
            include_context: When True, each result gains a "context" key
                holding the previous/next segment from the same file
                (None when a neighbor does not exist).

        Returns:
            List of (segment, score) tuples.
        """
        results = self.hybrid_search(query, k)
        if not include_context:
            return results

        # PERF: index segments by (file_name, paragraph_id) once instead of
        # scanning the whole corpus twice per result (was O(n * k)).
        # Assumes (file_name, paragraph_id) uniquely identifies a segment —
        # TODO(review): confirm against the segmenter.
        by_position = {
            (seg["metadata"]["file_name"], seg["metadata"]["paragraph_id"]): seg
            for seg in self.segments
        }

        enriched_results = []
        for result, score in results:
            file_name = result["metadata"]["file_name"]
            paragraph_id = result["metadata"]["paragraph_id"]
            # Shallow copy so the cached/original result is not mutated.
            enriched = result.copy()
            enriched["context"] = {
                "previous": by_position.get((file_name, paragraph_id - 1)),
                "next": by_position.get((file_name, paragraph_id + 1)),
            }
            enriched_results.append((enriched, score))
        return enriched_results