diff --git a/backend/open_webui/retrieval/vector/dbs/chroma.py b/backend/open_webui/retrieval/vector/dbs/chroma.py index 69d894afde..67268e9dca 100755 --- a/backend/open_webui/retrieval/vector/dbs/chroma.py +++ b/backend/open_webui/retrieval/vector/dbs/chroma.py @@ -69,7 +69,7 @@ class ChromaClient(VectorDBBase): return self.client.delete_collection(name=collection_name) def search( - self, collection_name: str, vectors: list[list[float | int]], limit: int + self, collection_name: str, vectors: list[list[float | int]], filter: Optional[dict] = None, limit: int = 10 ) -> Optional[SearchResult]: # Search for the nearest neighbor items based on the vectors and return 'limit' number of results. try: @@ -78,6 +78,7 @@ class ChromaClient(VectorDBBase): result = collection.query( query_embeddings=vectors, n_results=limit, + where=filter, ) # chromadb has cosine distance, 2 (worst) -> 0 (best). Re-odering to 0 -> 1 diff --git a/backend/open_webui/retrieval/vector/dbs/elasticsearch.py b/backend/open_webui/retrieval/vector/dbs/elasticsearch.py index 6de0d859f8..46cd1ad920 100644 --- a/backend/open_webui/retrieval/vector/dbs/elasticsearch.py +++ b/backend/open_webui/retrieval/vector/dbs/elasticsearch.py @@ -153,7 +153,7 @@ class ElasticsearchClient(VectorDBBase): # Status: works def search( - self, collection_name: str, vectors: list[list[float]], limit: int + self, collection_name: str, vectors: list[list[float]], filter: Optional[dict] = None, limit: int = 10 ) -> Optional[SearchResult]: query = { "size": limit, diff --git a/backend/open_webui/retrieval/vector/dbs/milvus.py b/backend/open_webui/retrieval/vector/dbs/milvus.py index 23e4bbd03e..69ca3988d9 100644 --- a/backend/open_webui/retrieval/vector/dbs/milvus.py +++ b/backend/open_webui/retrieval/vector/dbs/milvus.py @@ -179,7 +179,7 @@ class MilvusClient(VectorDBBase): ) def search( - self, collection_name: str, vectors: list[list[float | int]], limit: int + self, collection_name: str, vectors: list[list[float | int]], filter: Optional[dict] = None, limit: int = 10 ) -> Optional[SearchResult]: # Search for the nearest neighbor items based on the vectors and return 'limit' number of results. collection_name = collection_name.replace("-", "_") diff --git a/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py b/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py index 203a36141e..5dfad33de5 100644 --- a/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py +++ b/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py @@ -157,7 +157,7 @@ class MilvusClient(VectorDBBase): collection.insert(entities) def search( - self, collection_name: str, vectors: List[List[float]], limit: int + self, collection_name: str, vectors: List[List[float]], filter: Optional[Dict] = None, limit: int = 10 ) -> Optional[SearchResult]: if not vectors: return None diff --git a/backend/open_webui/retrieval/vector/dbs/opengauss.py b/backend/open_webui/retrieval/vector/dbs/opengauss.py index 056a8a61cc..7d4f9ea092 100644 --- a/backend/open_webui/retrieval/vector/dbs/opengauss.py +++ b/backend/open_webui/retrieval/vector/dbs/opengauss.py @@ -233,7 +233,8 @@ class OpenGaussClient(VectorDBBase): self, collection_name: str, vectors: List[List[float]], - limit: Optional[int] = None, + filter: Optional[Dict[str, Any]] = None, + limit: int = 10, ) -> Optional[SearchResult]: try: if not vectors: diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py index 2e946710e2..e6bd99b976 100644 --- a/backend/open_webui/retrieval/vector/dbs/opensearch.py +++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py @@ -113,7 +113,7 @@ class OpenSearchClient(VectorDBBase): self.client.indices.delete(index=self._get_index_name(collection_name)) def search( - self, collection_name: str, vectors: list[list[float | int]], limit: int + self, collection_name: str, vectors: list[list[float | int]], filter: Optional[dict] = None, limit: int = 10 ) -> Optional[SearchResult]: try: if not self.has_collection(collection_name): diff --git a/backend/open_webui/retrieval/vector/dbs/oracle23ai.py b/backend/open_webui/retrieval/vector/dbs/oracle23ai.py index 3f5c3463f0..db7b869439 100644 --- a/backend/open_webui/retrieval/vector/dbs/oracle23ai.py +++ b/backend/open_webui/retrieval/vector/dbs/oracle23ai.py @@ -521,7 +521,7 @@ class Oracle23aiClient(VectorDBBase): raise def search( - self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int + self, collection_name: str, vectors: List[List[Union[float, int]]], filter: Optional[dict] = None, limit: int = 10 ) -> Optional[SearchResult]: """ Search for similar vectors in the database. diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py index b9b1dba07b..6ae09cf05f 100644 --- a/backend/open_webui/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -427,7 +427,8 @@ class PgvectorClient(VectorDBBase): self, collection_name: str, vectors: List[List[float]], - limit: Optional[int] = None, + filter: Optional[Dict[str, Any]] = None, + limit: int = 10, ) -> Optional[SearchResult]: try: if not vectors: @@ -475,9 +476,40 @@ class PgvectorClient(VectorDBBase): ) # Build the lateral subquery for each query vector + where_clauses = [DocumentChunk.collection_name == collection_name] + + # Apply metadata filter if provided + if filter: + for key, value in filter.items(): + if isinstance(value, dict) and "$in" in value: + # Handle $in operator: {"field": {"$in": [values]}} + in_values = value["$in"] + if PGVECTOR_PGCRYPTO: + where_clauses.append( + pgcrypto_decrypt( + DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB + )[key].astext.in_([str(v) for v in in_values]) + ) + else: + where_clauses.append( + DocumentChunk.vmetadata[key].astext.in_([str(v) for v in in_values]) + ) + else: + # Handle simple equality: {"field": "value"} + if PGVECTOR_PGCRYPTO: + where_clauses.append( + pgcrypto_decrypt( + DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB + )[key].astext == str(value) + ) + else: + where_clauses.append( + DocumentChunk.vmetadata[key].astext == str(value) + ) + subq = ( select(*result_fields) - .where(DocumentChunk.collection_name == collection_name) + .where(*where_clauses) .order_by( (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)) ) diff --git a/backend/open_webui/retrieval/vector/dbs/pinecone.py b/backend/open_webui/retrieval/vector/dbs/pinecone.py index 94d09dabf5..22b9cb98b1 100644 --- a/backend/open_webui/retrieval/vector/dbs/pinecone.py +++ b/backend/open_webui/retrieval/vector/dbs/pinecone.py @@ -391,7 +391,7 @@ class PineconeClient(VectorDBBase): ) def search( - self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int + self, collection_name: str, vectors: List[List[Union[float, int]]], filter: Optional[dict] = None, limit: int = 10 ) -> Optional[SearchResult]: """Search for similar vectors in a collection.""" if not vectors or not vectors[0]: diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant.py b/backend/open_webui/retrieval/vector/dbs/qdrant.py index ce7095bea2..efa33a681d 100644 --- a/backend/open_webui/retrieval/vector/dbs/qdrant.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant.py @@ -145,7 +145,7 @@ class QdrantClient(VectorDBBase): ) def search( - self, collection_name: str, vectors: list[list[float | int]], limit: int + self, collection_name: str, vectors: list[list[float | int]], filter: Optional[dict] = None, limit: int = 10 ) -> Optional[SearchResult]: # Search for the nearest neighbor items based on the vectors and return 'limit' number of results. if limit is None: diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py index fdc8f9d897..70ec6d6068 100644 --- a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py @@ -254,7 +254,7 @@ class QdrantClient(VectorDBBase): ) def search( - self, collection_name: str, vectors: List[List[float | int]], limit: int + self, collection_name: str, vectors: List[List[float | int]], filter: Optional[Dict] = None, limit: int = 10 ) -> Optional[SearchResult]: """ Search for the nearest neighbor items based on the vectors with tenant isolation. diff --git a/backend/open_webui/retrieval/vector/dbs/s3vector.py b/backend/open_webui/retrieval/vector/dbs/s3vector.py index 95fc5d3f24..d6784f9f26 100644 --- a/backend/open_webui/retrieval/vector/dbs/s3vector.py +++ b/backend/open_webui/retrieval/vector/dbs/s3vector.py @@ -295,7 +295,7 @@ class S3VectorClient(VectorDBBase): raise def search( - self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int + self, collection_name: str, vectors: List[List[Union[float, int]]], filter: Optional[dict] = None, limit: int = 10 ) -> Optional[SearchResult]: """ Search for similar vectors in a collection using multiple query vectors. diff --git a/backend/open_webui/retrieval/vector/dbs/weaviate.py b/backend/open_webui/retrieval/vector/dbs/weaviate.py index 6bb8a1ecb4..680a6c9730 100644 --- a/backend/open_webui/retrieval/vector/dbs/weaviate.py +++ b/backend/open_webui/retrieval/vector/dbs/weaviate.py @@ -159,7 +159,7 @@ class WeaviateClient(VectorDBBase): ) def search( - self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int + self, collection_name: str, vectors: List[List[Union[float, int]]], filter: Optional[dict] = None, limit: int = 10 ) -> Optional[SearchResult]: sane_collection_name = self._sanitize_collection_name(collection_name) if not self.client.collections.exists(sane_collection_name): diff --git a/backend/open_webui/retrieval/vector/main.py b/backend/open_webui/retrieval/vector/main.py index 53f752f579..a76fec9956 100644 --- a/backend/open_webui/retrieval/vector/main.py +++ b/backend/open_webui/retrieval/vector/main.py @@ -53,7 +53,11 @@ class VectorDBBase(ABC): @abstractmethod def search( - self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int + self, + collection_name: str, + vectors: List[List[Union[float, int]]], + filter: Optional[Dict] = None, + limit: int = 10, ) -> Optional[SearchResult]: """Search for similar vectors in a collection.""" pass diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index c671759303..9fc30424ca 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -46,6 +46,54 @@ router = APIRouter() PAGE_ITEM_COUNT = 30 +############################ +# Knowledge Base Embedding +############################ + +KNOWLEDGE_BASES_COLLECTION = "knowledge-bases" + + +async def embed_knowledge_base_metadata( + request: Request, + knowledge_base_id: str, + name: str, + description: str, +) -> bool: + """Generate and store embedding for knowledge base.""" + try: + content = f"{name}\n\n{description}" if description else name + embedding = await request.app.state.EMBEDDING_FUNCTION(content) + VECTOR_DB_CLIENT.upsert( + collection_name=KNOWLEDGE_BASES_COLLECTION, + items=[ + { + "id": knowledge_base_id, + "text": content, + "vector": embedding, + "metadata": { + "knowledge_base_id": knowledge_base_id, + }, + } + ], + ) + return True + except Exception as e: + log.error(f"Failed to embed knowledge base {knowledge_base_id}: {e}") + return False + + +def remove_knowledge_base_metadata_embedding(knowledge_base_id: str) -> bool: + """Remove knowledge base embedding.""" + try: + VECTOR_DB_CLIENT.delete( + collection_name=KNOWLEDGE_BASES_COLLECTION, + ids=[knowledge_base_id], + ) + return True + except Exception as e: + log.debug(f"Failed to remove embedding for {knowledge_base_id}: {e}") + return False + class KnowledgeAccessResponse(KnowledgeUserResponse): write_access: Optional[bool] = False @@ -205,6 +253,13 @@ async def create_new_knowledge( knowledge = Knowledges.insert_new_knowledge(user.id, form_data, db=db) if knowledge: + # Embed knowledge base for semantic search + await embed_knowledge_base_metadata( + request, + knowledge.id, + knowledge.name, + knowledge.description, + ) return knowledge else: raise HTTPException( @@ -281,6 +336,30 @@ async def reindex_knowledge_files( return True +############################ +# ReindexKnowledgeBases +############################ + + +@router.post("/metadata/reindex", response_model=dict) +async def reindex_knowledge_base_metadata_embeddings( + request: Request, + user=Depends(get_admin_user), + db: Session = Depends(get_session), +): + """Batch embed all existing knowledge bases. Admin only.""" + knowledge_bases = Knowledges.get_knowledge_bases(db=db) + log.info(f"Reindexing embeddings for {len(knowledge_bases)} knowledge bases") + + success_count = 0 + for kb in knowledge_bases: + if await embed_knowledge_base_metadata(request, kb.id, kb.name, kb.description): + success_count += 1 + + log.info(f"Embedding reindex complete: {success_count}/{len(knowledge_bases)}") + return {"total": len(knowledge_bases), "success": success_count} + + ############################ # GetKnowledgeById ############################ @@ -369,6 +448,13 @@ async def update_knowledge_by_id( knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data, db=db) if knowledge: + # Re-embed knowledge base for semantic search + await embed_knowledge_base_metadata( + request, + knowledge.id, + knowledge.name, + knowledge.description, + ) return KnowledgeFilesResponse( **knowledge.model_dump(), files=Knowledges.get_file_metadatas_by_id(knowledge.id, db=db), @@ -718,6 +804,10 @@ async def delete_knowledge_by_id( except Exception as e: log.debug(e) pass + + # Remove knowledge base embedding + remove_knowledge_base_metadata_embedding(id) + result = Knowledges.delete_knowledge_by_id(id=id, db=db) return result diff --git a/backend/open_webui/tools/builtin.py b/backend/open_webui/tools/builtin.py index 2f57986184..7c691f2e29 100644 --- a/backend/open_webui/tools/builtin.py +++ b/backend/open_webui/tools/builtin.py @@ -39,6 +39,8 @@ from open_webui.models.groups import Groups log = logging.getLogger(__name__) +MAX_KNOWLEDGE_BASE_SEARCH_ITEMS = 10_000 + # ============================================================================= # TIME UTILITIES # ============================================================================= @@ -1413,7 +1415,7 @@ async def view_knowledge_file( return json.dumps({"error": str(e)}) -async def query_knowledge_bases( +async def query_knowledge_files( query: str, knowledge_ids: Optional[list[str]] = None, count: int = 5, @@ -1422,7 +1424,7 @@ async def query_knowledge_bases( __model_knowledge__: list[dict] = None, ) -> str: """ - Search internal knowledge bases using semantic/vector search. This should be your first + Search knowledge base files using semantic/vector search. This should be your first choice for finding information before searching the web. Searches across collections (KBs), individual files, and notes that the user has access to. @@ -1558,6 +1560,105 @@ async def query_knowledge_bases( chunks = chunks[:count] return json.dumps(chunks, ensure_ascii=False) + except Exception as e: + log.exception(f"query_knowledge_files error: {e}") + return json.dumps({"error": str(e)}) + + +async def query_knowledge_bases( + query: str, + count: int = 5, + __request__: Request = None, + __user__: dict = None, +) -> str: + """ + Search knowledge bases by semantic similarity to query. + Finds KBs whose name/description match the meaning of your query. + Use this to discover relevant knowledge bases before querying their files. + + :param query: Natural language query describing what you're looking for + :param count: Maximum results (default: 5) + :return: JSON with matching KBs (id, name, description, similarity) + """ + if __request__ is None: + return json.dumps({"error": "Request context not available"}) + + if not __user__: + return json.dumps({"error": "User context not available"}) + + try: + import heapq + from open_webui.models.knowledge import Knowledges + from open_webui.routers.knowledge import KNOWLEDGE_BASES_COLLECTION + from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT + + user_id = __user__.get("id") + user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id)] + query_embedding = await __request__.app.state.EMBEDDING_FUNCTION(query) + + # Min-heap of (distance, knowledge_base_id) - only holds top `count` results + top_results_heap = [] + seen_ids = set() + page_offset = 0 + page_size = 100 + + while True: + accessible_knowledge_bases = Knowledges.search_knowledge_bases( + user_id, + filter={"user_id": user_id, "group_ids": user_group_ids}, + skip=page_offset, + limit=page_size, + ) + + if not accessible_knowledge_bases.items: + break + + accessible_ids = [kb.id for kb in accessible_knowledge_bases.items] + + search_results = VECTOR_DB_CLIENT.search( + collection_name=KNOWLEDGE_BASES_COLLECTION, + vectors=[query_embedding], + filter={"knowledge_base_id": {"$in": accessible_ids}}, + limit=count, + ) + + if search_results and search_results.ids and search_results.ids[0]: + result_ids = search_results.ids[0] + result_distances = search_results.distances[0] if search_results.distances else [0] * len(result_ids) + + for knowledge_base_id, distance in zip(result_ids, result_distances): + if knowledge_base_id in seen_ids: + continue + seen_ids.add(knowledge_base_id) + + if len(top_results_heap) < count: + heapq.heappush(top_results_heap, (distance, knowledge_base_id)) + elif distance > top_results_heap[0][0]: + heapq.heapreplace(top_results_heap, (distance, knowledge_base_id)) + + page_offset += page_size + if len(accessible_knowledge_bases.items) < page_size: + break + if page_offset >= MAX_KNOWLEDGE_BASE_SEARCH_ITEMS: + break + + # Sort by distance descending (best first) and fetch KB details + sorted_results = sorted(top_results_heap, key=lambda x: x[0], reverse=True) + + matching_knowledge_bases = [] + for distance, knowledge_base_id in sorted_results: + knowledge_base = Knowledges.get_knowledge_by_id(knowledge_base_id) + if knowledge_base: + matching_knowledge_bases.append({ + "id": knowledge_base.id, + "name": knowledge_base.name, + "description": knowledge_base.description or "", + "similarity": round(distance, 4), + }) + + return json.dumps(matching_knowledge_bases, ensure_ascii=False) + except Exception as e: log.exception(f"query_knowledge_bases error: {e}") return json.dumps({"error": str(e)}) + diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 35a0e42b5c..fe2d7e5dc1 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -157,7 +157,7 @@ def get_citation_source_from_tool_result( - document: list of document contents - metadata: list of metadata objects with source, file_id, name fields - Returns a list of sources (usually one, but query_knowledge_bases may return multiple). + Returns a list of sources (usually one, but query_knowledge_files may return multiple). """ try: if tool_name == "search_web": @@ -217,7 +217,7 @@ def get_citation_source_from_tool_result( } ] - elif tool_name == "query_knowledge_bases": + elif tool_name == "query_knowledge_files": chunks = json.loads(tool_result) # Group chunks by source for better citation display @@ -3343,7 +3343,7 @@ async def process_chat_response( in [ "search_web", "view_knowledge_file", - "query_knowledge_bases", + "query_knowledge_files", ] and tool_result ): diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 700b4c6765..6cb6c4b856 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -68,9 +68,10 @@ from open_webui.tools.builtin import ( write_note, list_knowledge_bases, search_knowledge_bases, - search_knowledge_files, - view_knowledge_file, query_knowledge_bases, + search_knowledge_files, + query_knowledge_files, + view_knowledge_file, ) import copy @@ -406,21 +407,22 @@ def get_builtin_tools( builtin_functions.extend([get_current_timestamp, calculate_timestamp]) # Knowledge base tools - conditional injection based on model knowledge - # If model has attached knowledge (any type), only provide query_knowledge_bases + # If model has attached knowledge (any type), only provide query_knowledge_files # Otherwise, provide all KB browsing tools model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", []) if model_knowledge: # Model has attached knowledge - only allow semantic search within it - builtin_functions.append(query_knowledge_bases) + builtin_functions.append(query_knowledge_files) else: # No model knowledge - allow full KB browsing builtin_functions.extend( [ list_knowledge_bases, search_knowledge_bases, - search_knowledge_files, - view_knowledge_file, query_knowledge_bases, + search_knowledge_files, + query_knowledge_files, + view_knowledge_file, ] )