mirror of
https://github.com/open-webui/open-webui.git
synced 2026-04-25 17:15:16 +02:00
278 lines
10 KiB
Python
278 lines
10 KiB
Python
"""
|
|
NOTE: This vector database integration is community-supported and maintained on a best-effort basis.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Optional, Tuple, List, Dict, Any
|
|
|
|
from open_webui.config import (
|
|
MILVUS_URI,
|
|
MILVUS_TOKEN,
|
|
MILVUS_DB,
|
|
MILVUS_COLLECTION_PREFIX,
|
|
MILVUS_INDEX_TYPE,
|
|
MILVUS_METRIC_TYPE,
|
|
MILVUS_HNSW_M,
|
|
MILVUS_HNSW_EFCONSTRUCTION,
|
|
MILVUS_IVF_FLAT_NLIST,
|
|
)
|
|
from open_webui.retrieval.vector.main import (
|
|
GetResult,
|
|
SearchResult,
|
|
VectorDBBase,
|
|
VectorItem,
|
|
)
|
|
from pymilvus import (
|
|
connections,
|
|
utility,
|
|
Collection,
|
|
CollectionSchema,
|
|
FieldSchema,
|
|
DataType,
|
|
)
|
|
|
|
# Module-level logger, named after this module via __name__.
log = logging.getLogger(__name__)
|
|
|
|
# Name of the VARCHAR field that tags every row in a shared collection with the
# logical (per-tenant) collection name it belongs to; used in all filter expressions.
RESOURCE_ID_FIELD = 'resource_id'
|
|
|
|
|
|
class MilvusClient(VectorDBBase):
    """Milvus vector store that packs many logical collections into a few shared ones.

    Every logical collection name handed in by the caller is mapped to one of five
    shared Milvus collections (memories, knowledge, files, web search, hash based).
    The logical name itself is stored on each row in the ``RESOURCE_ID_FIELD`` column,
    and every read, write and delete is scoped with a filter on that column.
    """

    def __init__(self):
        """Connect to Milvus and derive the names of the shared collections."""
        # Milvus collection names can only contain numbers, letters, and underscores.
        self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace('-', '_')
        connections.connect(
            alias='default',
            uri=MILVUS_URI,
            token=MILVUS_TOKEN,
            db_name=MILVUS_DB,
        )

        # Main collection types for multi-tenancy
        self.MEMORY_COLLECTION = f'{self.collection_prefix}_memories'
        self.KNOWLEDGE_COLLECTION = f'{self.collection_prefix}_knowledge'
        self.FILE_COLLECTION = f'{self.collection_prefix}_files'
        self.WEB_SEARCH_COLLECTION = f'{self.collection_prefix}_web_search'
        self.HASH_BASED_COLLECTION = f'{self.collection_prefix}_hash_based'
        # Everything reset() is allowed to drop.
        self.shared_collections = [
            self.MEMORY_COLLECTION,
            self.KNOWLEDGE_COLLECTION,
            self.FILE_COLLECTION,
            self.WEB_SEARCH_COLLECTION,
            self.HASH_BASED_COLLECTION,
        ]

    @staticmethod
    def _quote(value: Any) -> str:
        """Render ``value`` as a single-quoted string literal for a Milvus expression.

        Backslashes and single quotes are backslash-escaped so the value cannot end
        the literal early and change the meaning of the surrounding expression.
        Values without those characters are rendered exactly as ``'value'``.
        """
        escaped = str(value).replace('\\', '\\\\').replace("'", "\\'")
        return f"'{escaped}'"

    def _resource_expr(self, resource_id: str) -> str:
        """Return the expression restricting rows to a single logical collection."""
        return f'{RESOURCE_ID_FIELD} == {self._quote(resource_id)}'

    def _metadata_conditions(self, filter: Optional[Dict[str, Any]]) -> List[str]:
        """Translate a ``{key: value}`` filter into metadata equality conditions.

        String values are emitted as quoted literals; every other value is emitted
        through its ``str()`` form so it is compared with its own type rather than
        as a string. A ``None`` or empty filter yields no conditions.
        """
        conditions = []
        for key, value in (filter or {}).items():
            field = f'metadata[{self._quote(key)}]'
            if isinstance(value, str):
                conditions.append(f'{field} == {self._quote(value)}')
            else:
                conditions.append(f'{field} == {value}')
        return conditions

    def _get_collection_and_resource_id(self, collection_name: str) -> Tuple[str, str]:
        """
        Maps the traditional collection name to multi-tenant collection and resource ID.

        The resource ID is always the unmodified ``collection_name``; only the target
        shared collection depends on the name's shape.

        WARNING: This mapping relies on current Open WebUI naming conventions for
        collection names. If Open WebUI changes how it generates collection names
        (e.g., "user-memory-" prefix, "file-" prefix, web search patterns, or hash
        formats), this mapping will break and route data to incorrect collections.
        POTENTIALLY CAUSING HUGE DATA CORRUPTION, DATA CONSISTENCY ISSUES AND INCORRECT
        DATA MAPPING INSIDE THE DATABASE.
        """
        resource_id = collection_name

        if collection_name.startswith('user-memory-'):
            return self.MEMORY_COLLECTION, resource_id
        if collection_name.startswith('file-'):
            return self.FILE_COLLECTION, resource_id
        if collection_name.startswith('web-search-'):
            return self.WEB_SEARCH_COLLECTION, resource_id
        # Exactly 63 lowercase hex characters is treated as a hash-style name.
        if len(collection_name) == 63 and all(c in '0123456789abcdef' for c in collection_name):
            return self.HASH_BASED_COLLECTION, resource_id
        # Anything else is routed to the knowledge collection.
        return self.KNOWLEDGE_COLLECTION, resource_id

    def _create_shared_collection(self, mt_collection_name: str, dimension: int):
        """Create a shared collection with the multi-tenant schema and its indexes.

        Args:
            mt_collection_name: Name of the shared Milvus collection to create.
            dimension: Dimension of the ``vector`` field.

        Returns:
            The newly created ``Collection``.
        """
        fields = [
            FieldSchema(
                name='id',
                dtype=DataType.VARCHAR,
                is_primary=True,
                auto_id=False,
                max_length=36,
            ),
            FieldSchema(name='vector', dtype=DataType.FLOAT_VECTOR, dim=dimension),
            FieldSchema(name='text', dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name='metadata', dtype=DataType.JSON),
            FieldSchema(name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255),
        ]
        schema = CollectionSchema(fields, 'Shared collection for multi-tenancy')
        collection = Collection(mt_collection_name, schema)

        # Only HNSW and IVF_FLAT get tuning parameters; other index types use {}.
        if MILVUS_INDEX_TYPE == 'HNSW':
            build_params = {
                'M': MILVUS_HNSW_M,
                'efConstruction': MILVUS_HNSW_EFCONSTRUCTION,
            }
        elif MILVUS_INDEX_TYPE == 'IVF_FLAT':
            build_params = {'nlist': MILVUS_IVF_FLAT_NLIST}
        else:
            build_params = {}
        index_params = {
            'metric_type': MILVUS_METRIC_TYPE,
            'index_type': MILVUS_INDEX_TYPE,
            'params': build_params,
        }

        collection.create_index('vector', index_params)
        # Every operation filters on the resource ID, so index that field as well.
        collection.create_index(RESOURCE_ID_FIELD)
        log.info('Created shared collection: %s', mt_collection_name)
        return collection

    def _ensure_collection(self, mt_collection_name: str, dimension: int):
        """Create the shared collection if Milvus does not have it yet."""
        if not utility.has_collection(mt_collection_name):
            self._create_shared_collection(mt_collection_name, dimension)

    def has_collection(self, collection_name: str) -> bool:
        """Return True when at least one row is stored for ``collection_name``.

        A logical collection "exists" only if its shared collection exists and holds
        a row tagged with this resource ID.
        """
        mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
        if not utility.has_collection(mt_collection):
            return False

        collection = Collection(mt_collection)
        collection.load()
        res = collection.query(expr=self._resource_expr(resource_id), limit=1)
        return len(res) > 0

    def _write_items(self, collection_name: str, items: List[VectorItem], overwrite: bool):
        """Write ``items`` into the shared collection backing ``collection_name``.

        The shared collection is created on first use, sized from the first item's
        vector. With ``overwrite`` the rows are written via ``Collection.upsert``
        (which, per the pymilvus documentation, replaces rows with the same primary
        key); otherwise via ``Collection.insert``. Does nothing for an empty list.
        """
        if not items:
            return
        mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
        dimension = len(items[0]['vector'])
        self._ensure_collection(mt_collection, dimension)
        collection = Collection(mt_collection)

        entities = [
            {
                'id': item['id'],
                'vector': item['vector'],
                'text': item['text'],
                'metadata': item['metadata'],
                RESOURCE_ID_FIELD: resource_id,
            }
            for item in items
        ]
        if overwrite:
            collection.upsert(entities)
        else:
            collection.insert(entities)

    def upsert(self, collection_name: str, items: List[VectorItem]):
        """Insert ``items``, replacing any existing rows that share an ``id``.

        Uses ``Collection.upsert`` rather than ``Collection.insert`` so that writing
        the same item twice does not leave duplicate rows behind.
        """
        self._write_items(collection_name, items, overwrite=True)

    def search(
        self,
        collection_name: str,
        vectors: List[List[float]],
        filter: Optional[Dict] = None,
        limit: int = 10,
    ) -> Optional[SearchResult]:
        """Run a vector search restricted to one logical collection.

        Args:
            collection_name: Logical collection to search in.
            vectors: Query vectors.
            filter: Accepted for interface compatibility.
                NOTE(review): this argument is currently not applied to the search;
                confirm whether callers expect metadata filtering here.
            limit: Maximum number of hits requested per query.

        Returns:
            ``None`` when ``vectors`` is empty or the shared collection does not exist;
            otherwise a ``SearchResult`` holding one inner list per entry of the Milvus
            response, with distances passed through exactly as Milvus reports them.
        """
        if not vectors:
            return None

        mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
        if not utility.has_collection(mt_collection):
            return None

        collection = Collection(mt_collection)
        collection.load()

        search_params = {'metric_type': MILVUS_METRIC_TYPE, 'params': {}}
        results = collection.search(
            data=vectors,
            anns_field='vector',
            param=search_params,
            limit=limit,
            expr=self._resource_expr(resource_id),
            output_fields=['id', 'text', 'metadata'],
        )

        ids, documents, metadatas, distances = [], [], [], []
        for hits in results:
            # Materialise once so all four columns are built from the same hits.
            hit_list = list(hits)
            ids.append([hit.entity.get('id') for hit in hit_list])
            documents.append([hit.entity.get('text') for hit in hit_list])
            metadatas.append([hit.entity.get('metadata') for hit in hit_list])
            distances.append([hit.distance for hit in hit_list])

        return SearchResult(ids=ids, documents=documents, metadatas=metadatas, distances=distances)

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict[str, Any]] = None,
    ):
        """Delete rows of one logical collection.

        Args:
            collection_name: Logical collection to delete from.
            ids: When given, only rows with these primary keys are deleted.
            filter: When given, only rows whose metadata equals every ``key: value``
                pair are deleted. Values are typed the same way as in ``query()``:
                strings are quoted, other values are compared unquoted.

        All supplied conditions are combined with ``and``. When neither ``ids`` nor
        ``filter`` is supplied, every row of the logical collection is deleted.
        Does nothing when the shared collection does not exist.
        """
        mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
        if not utility.has_collection(mt_collection):
            return

        collection = Collection(mt_collection)

        # Build expression
        expr = [self._resource_expr(resource_id)]
        if ids:
            # Milvus expects a string list for 'in' operator
            id_list_str = ', '.join(self._quote(id_val) for id_val in ids)
            expr.append(f'id in [{id_list_str}]')
        expr.extend(self._metadata_conditions(filter))

        collection.delete(' and '.join(expr))

    def reset(self):
        """Drop every shared collection that currently exists."""
        for collection_name in self.shared_collections:
            if utility.has_collection(collection_name):
                utility.drop_collection(collection_name)

    def delete_collection(self, collection_name: str):
        """Delete all rows of one logical collection; the shared collection is kept."""
        mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
        if not utility.has_collection(mt_collection):
            return

        collection = Collection(mt_collection)
        collection.delete(self._resource_expr(resource_id))

    def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
        """Fetch rows of one logical collection by metadata equality.

        Args:
            collection_name: Logical collection to read from.
            filter: ``key: value`` pairs that must all match the row's metadata;
                an empty filter matches every row of the logical collection.
            limit: Maximum number of rows; a falsy value is passed to Milvus as ``-1``.

        Returns:
            ``None`` when the shared collection does not exist; otherwise a
            ``GetResult`` whose ``ids``, ``documents`` and ``metadatas`` each hold a
            single inner list with one entry per fetched row.
        """
        mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
        if not utility.has_collection(mt_collection):
            return None

        collection = Collection(mt_collection)
        collection.load()

        expr = [self._resource_expr(resource_id)]
        expr.extend(self._metadata_conditions(filter))

        iterator = collection.query_iterator(
            expr=' and '.join(expr),
            output_fields=['id', 'text', 'metadata'],
            limit=limit if limit else -1,
        )

        all_results = []
        try:
            while True:
                batch = iterator.next()
                if not batch:
                    break
                all_results.extend(batch)
        finally:
            # Close the iterator even when fetching a batch fails.
            iterator.close()

        ids = [res['id'] for res in all_results]
        documents = [res['text'] for res in all_results]
        metadatas = [res['metadata'] for res in all_results]

        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Return every row of one logical collection (``query`` with no filter or limit)."""
        return self.query(collection_name, filter={}, limit=None)

    def insert(self, collection_name: str, items: List[VectorItem]):
        """Insert ``items`` via ``Collection.insert``; existing rows are not replaced."""
        return self._write_items(collection_name, items, overwrite=False)
|