make bm25_weight a regular parameter of query_doc.. / get_sources_from_files functions

This commit is contained in:
Jan Kessler
2025-05-20 11:21:14 +02:00
parent b5ddaf6417
commit 308d8ac04a
3 changed files with 19 additions and 4 deletions

View File

@@ -29,7 +29,6 @@ from open_webui.config import (
RAG_EMBEDDING_QUERY_PREFIX,
RAG_EMBEDDING_CONTENT_PREFIX,
RAG_EMBEDDING_PREFIX_FIELD_NAME,
RAG_BM25_WEIGHT,
)
log = logging.getLogger(__name__)
@@ -117,6 +116,7 @@ def query_doc_with_hybrid_search(
reranking_function,
k_reranker: int,
r: float,
bm25_weight: float,
) -> dict:
try:
log.debug(f"query_doc_with_hybrid_search:doc {collection_name}")
@@ -132,18 +132,18 @@ def query_doc_with_hybrid_search(
top_k=k,
)
if RAG_BM25_WEIGHT <= 0:
if bm25_weight <= 0:
ensemble_retriever = EnsembleRetriever(
retrievers=[vector_search_retriever], weights=[1.]
)
elif RAG_BM25_WEIGHT >= 1:
elif bm25_weight >= 1:
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever], weights=[1.]
)
else:
ensemble_retriever = EnsembleRetriever(
retrievers=[bm25_retriever, vector_search_retriever],
weights=[RAG_BM25_WEIGHT, 1. - RAG_BM25_WEIGHT]
weights=[bm25_weight, 1. - bm25_weight]
)
compressor = RerankCompressor(
@@ -325,6 +325,7 @@ def query_collection_with_hybrid_search(
reranking_function,
k_reranker: int,
r: float,
bm25_weight: float,
) -> dict:
results = []
error = False
@@ -358,6 +359,7 @@ def query_collection_with_hybrid_search(
reranking_function=reranking_function,
k_reranker=k_reranker,
r=r,
bm25_weight=bm25_weight,
)
return result, None
except Exception as e:
@@ -445,6 +447,7 @@ def get_sources_from_files(
reranking_function,
k_reranker,
r,
bm25_weight,
hybrid_search,
full_context=False,
):
@@ -562,6 +565,7 @@ def get_sources_from_files(
reranking_function=reranking_function,
k_reranker=k_reranker,
r=r,
bm25_weight=bm25_weight,
)
except Exception as e:
log.debug(