# Mirror of https://github.com/open-webui/open-webui.git
# Synced 2026-04-26 01:25:34 +02:00 — 668 lines, 27 KiB, Python
from typing import Optional, List, Dict, Any, Tuple
|
|
import logging
|
|
import json
|
|
from sqlalchemy import (
|
|
func,
|
|
literal,
|
|
cast,
|
|
column,
|
|
create_engine,
|
|
Column,
|
|
Integer,
|
|
MetaData,
|
|
LargeBinary,
|
|
select,
|
|
text,
|
|
Text,
|
|
Table,
|
|
values,
|
|
)
|
|
from sqlalchemy.sql import true
|
|
from sqlalchemy.pool import NullPool, QueuePool
|
|
|
|
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
|
|
from sqlalchemy.dialects.postgresql import JSONB, array
|
|
from pgvector.sqlalchemy import Vector, HALFVEC
|
|
from sqlalchemy.ext.mutable import MutableDict
|
|
from sqlalchemy.exc import NoSuchTableError
|
|
|
|
|
|
from open_webui.retrieval.vector.utils import process_metadata
|
|
from open_webui.retrieval.vector.main import (
|
|
VectorDBBase,
|
|
VectorItem,
|
|
SearchResult,
|
|
GetResult,
|
|
)
|
|
from open_webui.config import (
|
|
PGVECTOR_DB_URL,
|
|
PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
|
|
PGVECTOR_CREATE_EXTENSION,
|
|
PGVECTOR_PGCRYPTO,
|
|
PGVECTOR_PGCRYPTO_KEY,
|
|
PGVECTOR_POOL_SIZE,
|
|
PGVECTOR_POOL_MAX_OVERFLOW,
|
|
PGVECTOR_POOL_TIMEOUT,
|
|
PGVECTOR_POOL_RECYCLE,
|
|
PGVECTOR_INDEX_METHOD,
|
|
PGVECTOR_HNSW_M,
|
|
PGVECTOR_HNSW_EF_CONSTRUCTION,
|
|
PGVECTOR_IVFFLAT_LISTS,
|
|
PGVECTOR_USE_HALFVEC,
|
|
)
|
|
|
|
log = logging.getLogger(__name__)

Base = declarative_base()

# Dimension that every stored and queried vector is padded/truncated to.
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
# When enabled, vectors are stored with the half-precision ``HALFVEC`` type.
USE_HALFVEC = PGVECTOR_USE_HALFVEC

# Column type for the vector column and the matching cosine operator class
# used when building the vector index.
if USE_HALFVEC:
    VECTOR_TYPE_FACTORY = HALFVEC
    VECTOR_OPCLASS = 'halfvec_cosine_ops'
else:
    VECTOR_TYPE_FACTORY = Vector
    VECTOR_OPCLASS = 'vector_cosine_ops'
|
|
|
|
|
|
def pgcrypto_encrypt(val, key):
    """Build a SQL ``pgp_sym_encrypt(val, key)`` expression.

    The key is wrapped in ``literal()`` so it is sent as a bound parameter
    rather than being rendered into the SQL text.
    """
    key_param = literal(key)
    return func.pgp_sym_encrypt(val, key_param)
|
|
|
|
|
|
def pgcrypto_decrypt(col, key, outtype=Text):
    """Build a SQL expression that decrypts ``col`` and casts the plaintext.

    :param col: column or expression holding ``pgp_sym_encrypt`` ciphertext.
    :param key: symmetric key; bound as a literal parameter.
    :param outtype: SQLAlchemy type (class or instance) for the result, e.g.
        ``Text`` or ``JSONB``. The previous default was the plain string
        ``'text'``, which SQLAlchemy's CAST cannot compile; ``Text`` renders
        the same SQL type and works.
    :returns: ``CAST(pgp_sym_decrypt(col, key) AS <outtype>)``.
    """
    decrypted = func.pgp_sym_decrypt(col, literal(key))
    return cast(decrypted, outtype)
|
|
|
|
|
|
class DocumentChunk(Base):
    """ORM mapping for the ``document_chunk`` table that stores every chunk.

    Chunks from all collections share this table and are told apart by
    ``collection_name``. When ``PGVECTOR_PGCRYPTO`` is enabled, ``text`` and
    ``vmetadata`` hold ``pgp_sym_encrypt`` ciphertext and are binary columns;
    otherwise they are plain TEXT and a mutable JSONB dict.
    """

    __tablename__ = 'document_chunk'

    id = Column(Text, primary_key=True)
    vector = Column(VECTOR_TYPE_FACTORY(dim=VECTOR_LENGTH), nullable=True)
    collection_name = Column(Text, nullable=False)

    # Encrypted payloads are opaque bytes; plaintext ones keep native types.
    text = Column(LargeBinary if PGVECTOR_PGCRYPTO else Text, nullable=True)
    vmetadata = Column(
        LargeBinary if PGVECTOR_PGCRYPTO else MutableDict.as_mutable(JSONB),
        nullable=True,
    )
|
|
|
|
|
|
class PgvectorClient(VectorDBBase):
    """Vector store backed by PostgreSQL with the pgvector extension.

    Every chunk lives in the single ``document_chunk`` table and is scoped by
    ``collection_name``. With ``PGVECTOR_PGCRYPTO`` enabled, chunk text and
    metadata are encrypted with ``pgp_sym_encrypt`` on write and decrypted in
    SQL on read.
    """

    def __init__(self) -> None:
        """Open a session, ensure extensions, table and indexes exist.

        Raises whatever the initialization SQL raises (after rolling back),
        ``ValueError`` if pgcrypto is enabled without a key, and the
        exceptions from ``check_vector_length`` / ``_ensure_vector_index``.
        """
        # if no pgvector uri, use the existing database connection
        if not PGVECTOR_DB_URL:
            from open_webui.internal.db import ScopedSession

            self.session = ScopedSession
        else:
            if isinstance(PGVECTOR_POOL_SIZE, int):
                if PGVECTOR_POOL_SIZE > 0:
                    engine = create_engine(
                        PGVECTOR_DB_URL,
                        pool_size=PGVECTOR_POOL_SIZE,
                        max_overflow=PGVECTOR_POOL_MAX_OVERFLOW,
                        pool_timeout=PGVECTOR_POOL_TIMEOUT,
                        pool_recycle=PGVECTOR_POOL_RECYCLE,
                        pool_pre_ping=True,
                        poolclass=QueuePool,
                    )
                else:
                    engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool)
            else:
                engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True)

            SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
            self.session = scoped_session(SessionLocal)

        try:
            # Ensure the pgvector extension is available
            # Use a conditional check to avoid permission issues on Azure PostgreSQL
            if PGVECTOR_CREATE_EXTENSION:
                self.session.execute(
                    text("""
                        DO $$
                        BEGIN
                            IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN
                                CREATE EXTENSION IF NOT EXISTS vector;
                            END IF;
                        END $$;
                    """)
                )

            if PGVECTOR_PGCRYPTO:
                # Ensure the pgcrypto extension is available for encryption
                # Use a conditional check to avoid permission issues on Azure PostgreSQL
                self.session.execute(
                    text("""
                        DO $$
                        BEGIN
                            IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgcrypto') THEN
                                CREATE EXTENSION IF NOT EXISTS pgcrypto;
                            END IF;
                        END $$;
                    """)
                )

                if not PGVECTOR_PGCRYPTO_KEY:
                    raise ValueError('PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled.')

            # Check vector length consistency
            self.check_vector_length()

            # Create the tables if they do not exist
            # Base.metadata.create_all requires a bind (engine or connection)
            # Get the connection from the session
            connection = self.session.connection()
            Base.metadata.create_all(bind=connection)

            index_method, index_options = self._vector_index_configuration()
            self._ensure_vector_index(index_method, index_options)

            self.session.execute(
                text(
                    'CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name ON document_chunk (collection_name);'
                )
            )
            self.session.commit()
            log.info('Initialization complete.')
        except Exception as e:
            self.session.rollback()
            log.exception(f'Error during initialization: {e}')
            raise

    @staticmethod
    def _extract_index_method(index_def: Optional[str]) -> Optional[str]:
        """Return the access method named after ``USING`` in an index definition.

        Returns ``None`` for an empty definition or one without ``USING``.
        """
        if not index_def:
            return None
        try:
            after_using = index_def.lower().split('using ', 1)[1]
            return after_using.split()[0]
        except (IndexError, AttributeError):
            return None

    def _vector_index_configuration(self) -> Tuple[str, str]:
        """Choose the vector index method and its ``WITH (...)`` options.

        ``PGVECTOR_INDEX_METHOD`` wins if set; otherwise halfvec columns use
        ``hnsw`` and regular vector columns use ``ivfflat``.
        """
        if PGVECTOR_INDEX_METHOD:
            index_method = PGVECTOR_INDEX_METHOD
            log.info(
                "Using vector index method '%s' from PGVECTOR_INDEX_METHOD.",
                index_method,
            )
        elif USE_HALFVEC:
            index_method = 'hnsw'
            log.info(
                'VECTOR_LENGTH=%s exceeds 2000; using halfvec column type with hnsw index.',
                VECTOR_LENGTH,
            )
        else:
            index_method = 'ivfflat'

        if index_method == 'hnsw':
            index_options = f'WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})'
        else:
            index_options = f'WITH (lists = {PGVECTOR_IVFFLAT_LISTS})'

        return index_method, index_options

    def _ensure_vector_index(self, index_method: str, index_options: str) -> None:
        """Create ``idx_document_chunk_vector`` if missing.

        Raises ``RuntimeError`` if the index already exists with a different
        access method; it is never rebuilt automatically.
        """
        index_name = 'idx_document_chunk_vector'
        existing_index_def = self.session.execute(
            text("""
                SELECT indexdef
                FROM pg_indexes
                WHERE schemaname = current_schema()
                  AND tablename = 'document_chunk'
                  AND indexname = :index_name
            """),
            {'index_name': index_name},
        ).scalar()

        existing_method = self._extract_index_method(existing_index_def)
        if existing_method and existing_method != index_method:
            raise RuntimeError(
                f"Existing pgvector index '{index_name}' uses method '{existing_method}' but configuration now "
                f"requires '{index_method}'. Automatic rebuild is disabled to prevent long-running maintenance. "
                'Drop the index manually (optionally after tuning maintenance_work_mem/max_parallel_maintenance_workers) '
                'and recreate it with the new method before restarting Open WebUI.'
            )

        if not existing_index_def:
            index_sql = (
                f'CREATE INDEX IF NOT EXISTS {index_name} '
                f'ON document_chunk USING {index_method} (vector {VECTOR_OPCLASS})'
            )
            if index_options:
                index_sql = f'{index_sql} {index_options}'
            self.session.execute(text(index_sql))
            log.info(
                "Ensured vector index '%s' using %s%s.",
                index_name,
                index_method,
                f' {index_options}' if index_options else '',
            )

    def check_vector_length(self) -> None:
        """
        Check if the VECTOR_LENGTH matches the existing vector column dimension in the database.
        Raises an exception if there is a mismatch.
        """
        metadata = MetaData()
        try:
            # Attempt to reflect the 'document_chunk' table
            document_chunk_table = Table('document_chunk', metadata, autoload_with=self.session.bind)
        except NoSuchTableError:
            # Table does not exist; no action needed
            return

        # Proceed to check the vector column
        if 'vector' in document_chunk_table.columns:
            vector_column = document_chunk_table.columns['vector']
            vector_type = vector_column.type
            expected_type = HALFVEC if USE_HALFVEC else Vector

            if not isinstance(vector_type, expected_type):
                raise Exception(
                    "The 'vector' column type does not match the expected type "
                    f"('{expected_type.__name__}') for VECTOR_LENGTH {VECTOR_LENGTH}."
                )

            db_vector_length = getattr(vector_type, 'dim', None)
            if db_vector_length is not None and db_vector_length != VECTOR_LENGTH:
                raise Exception(
                    f'VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. '
                    'Cannot change vector size after initialization without migrating the data.'
                )
        else:
            raise Exception("The 'vector' column does not exist in the 'document_chunk' table.")

    def adjust_vector_length(self, vector: List[float]) -> List[float]:
        """Return ``vector`` zero-padded or truncated to exactly VECTOR_LENGTH.

        The caller's sequence is never modified: the original implementation
        padded with ``+=``, which extended the caller's list in place.
        """
        current_length = len(vector)
        if current_length < VECTOR_LENGTH:
            # Pad with zeros into a new list.
            return list(vector) + [0.0] * (VECTOR_LENGTH - current_length)
        if current_length > VECTOR_LENGTH:
            # Truncate the vector to VECTOR_LENGTH
            return vector[:VECTOR_LENGTH]
        return vector

    @staticmethod
    def _metadata_filter_clauses(metadata_filter: Optional[Dict[str, Any]]) -> List[Any]:
        """Translate a metadata filter dict into SQL WHERE clauses.

        Supported forms (values are compared as text):
          - ``{"field": value}`` -> ``vmetadata->>'field' = str(value)``
          - ``{"field": {"$in": [v1, v2]}}`` -> ``vmetadata->>'field' IN (...)``
        With pgcrypto enabled, the metadata is decrypted to JSONB in SQL first.
        ``None`` or an empty dict yields no clauses.
        """
        clauses = []
        for key, value in (metadata_filter or {}).items():
            if PGVECTOR_PGCRYPTO:
                metadata = pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB)
            else:
                metadata = DocumentChunk.vmetadata
            field = metadata[key].astext
            if isinstance(value, dict) and '$in' in value:
                clauses.append(field.in_([str(v) for v in value['$in']]))
            else:
                clauses.append(field == str(value))
        return clauses

    def _fetch_chunks(self, where_clauses: List[Any], limit: Optional[int]) -> Optional[GetResult]:
        """Select id/text/metadata for matching chunks as a single-batch GetResult.

        Always ends the read-only transaction with a rollback. Returns ``None``
        when nothing matched.
        """
        if PGVECTOR_PGCRYPTO:
            stmt = select(
                DocumentChunk.id,
                pgcrypto_decrypt(DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text).label('text'),
                pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB).label('vmetadata'),
            ).where(*where_clauses)
            if limit is not None:
                stmt = stmt.limit(limit)
            rows = self.session.execute(stmt).all()
        else:
            query = self.session.query(DocumentChunk).filter(*where_clauses)
            if limit is not None:
                query = query.limit(limit)
            rows = query.all()

        # Copy values out BEFORE rolling back: rollback expires ORM instances,
        # and touching them afterwards would trigger a refresh query.
        ids = [row.id for row in rows]
        documents = [row.text for row in rows]
        metadatas = [row.vmetadata for row in rows]

        self.session.rollback()  # read-only transaction

        if not ids:
            return None
        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert ``items`` into ``collection_name`` and commit.

        On the pgcrypto path, rows whose id already exists are skipped
        (``ON CONFLICT (id) DO NOTHING``). Errors are logged, rolled back and
        re-raised.
        """
        try:
            if PGVECTOR_PGCRYPTO:
                for item in items:
                    vector = self.adjust_vector_length(item['vector'])
                    # Use raw SQL for BYTEA/pgcrypto
                    # Ensure metadata is converted to its JSON text representation
                    json_metadata = json.dumps(item['metadata'])
                    self.session.execute(
                        text("""
                            INSERT INTO document_chunk
                            (id, vector, collection_name, text, vmetadata)
                            VALUES (
                                :id, :vector, :collection_name,
                                pgp_sym_encrypt(:text, :key),
                                pgp_sym_encrypt(:metadata_text, :key)
                            )
                            ON CONFLICT (id) DO NOTHING
                        """),
                        {
                            'id': item['id'],
                            'vector': vector,
                            'collection_name': collection_name,
                            'text': item['text'],
                            'metadata_text': json_metadata,
                            'key': PGVECTOR_PGCRYPTO_KEY,
                        },
                    )
                self.session.commit()
                log.info(f"Encrypted & inserted {len(items)} into '{collection_name}'")
            else:
                new_items = []
                for item in items:
                    vector = self.adjust_vector_length(item['vector'])
                    new_chunk = DocumentChunk(
                        id=item['id'],
                        vector=vector,
                        collection_name=collection_name,
                        text=item['text'],
                        vmetadata=process_metadata(item['metadata']),
                    )
                    new_items.append(new_chunk)
                self.session.bulk_save_objects(new_items)
                self.session.commit()
                log.info(f"Inserted {len(new_items)} items into collection '{collection_name}'.")
        except Exception as e:
            self.session.rollback()
            log.exception(f'Error during insert: {e}')
            raise

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """Insert ``items`` or overwrite existing rows with the same id, then commit.

        Existing rows get their vector, text, metadata and collection_name
        replaced. Errors are logged, rolled back and re-raised.
        """
        try:
            if PGVECTOR_PGCRYPTO:
                for item in items:
                    vector = self.adjust_vector_length(item['vector'])
                    json_metadata = json.dumps(item['metadata'])
                    self.session.execute(
                        text("""
                            INSERT INTO document_chunk
                            (id, vector, collection_name, text, vmetadata)
                            VALUES (
                                :id, :vector, :collection_name,
                                pgp_sym_encrypt(:text, :key),
                                pgp_sym_encrypt(:metadata_text, :key)
                            )
                            ON CONFLICT (id) DO UPDATE SET
                              vector = EXCLUDED.vector,
                              collection_name = EXCLUDED.collection_name,
                              text = EXCLUDED.text,
                              vmetadata = EXCLUDED.vmetadata
                        """),
                        {
                            'id': item['id'],
                            'vector': vector,
                            'collection_name': collection_name,
                            'text': item['text'],
                            'metadata_text': json_metadata,
                            'key': PGVECTOR_PGCRYPTO_KEY,
                        },
                    )
                self.session.commit()
                log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
            else:
                for item in items:
                    vector = self.adjust_vector_length(item['vector'])
                    existing = self.session.query(DocumentChunk).filter(DocumentChunk.id == item['id']).first()
                    if existing:
                        existing.vector = vector
                        existing.text = item['text']
                        existing.vmetadata = process_metadata(item['metadata'])
                        existing.collection_name = collection_name  # Update collection_name if necessary
                    else:
                        new_chunk = DocumentChunk(
                            id=item['id'],
                            vector=vector,
                            collection_name=collection_name,
                            text=item['text'],
                            vmetadata=process_metadata(item['metadata']),
                        )
                        self.session.add(new_chunk)
                self.session.commit()
                log.info(f"Upserted {len(items)} items into collection '{collection_name}'.")
        except Exception as e:
            self.session.rollback()
            log.exception(f'Error during upsert: {e}')
            raise

    def search(
        self,
        collection_name: str,
        vectors: List[List[float]],
        filter: Optional[Dict[str, Any]] = None,
        limit: int = 10,
    ) -> Optional[SearchResult]:
        """Return the ``limit`` nearest chunks (cosine) for each query vector.

        All query vectors are searched in one statement via a VALUES list
        joined LATERAL to a per-query top-k subquery. Result lists are indexed
        by query position. Scores are ``(2 - cosine_distance) / 2``, so 1 is
        identical and 0 is opposite. Returns ``None`` for no vectors or on
        error (the error is logged).
        """
        try:
            if not vectors:
                return None

            # Adjust query vectors to VECTOR_LENGTH
            vectors = [self.adjust_vector_length(vector) for vector in vectors]
            num_queries = len(vectors)

            def vector_expr(vector):
                return cast(array(vector), VECTOR_TYPE_FACTORY(VECTOR_LENGTH))

            # Create the values for query vectors
            qid_col = column('qid', Integer)
            q_vector_col = column('q_vector', VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
            query_vectors = (
                values(qid_col, q_vector_col)
                .data([(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)])
                .alias('query_vectors')
            )

            result_fields = [
                DocumentChunk.id,
            ]
            if PGVECTOR_PGCRYPTO:
                result_fields.append(pgcrypto_decrypt(DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text).label('text'))
                result_fields.append(
                    pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB).label('vmetadata')
                )
            else:
                result_fields.append(DocumentChunk.text)
                result_fields.append(DocumentChunk.vmetadata)
            result_fields.append((DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label('distance'))

            # Build the lateral subquery for each query vector
            where_clauses = [DocumentChunk.collection_name == collection_name]
            # Apply metadata filter if provided (supports equality and $in)
            where_clauses.extend(self._metadata_filter_clauses(filter))

            subq = (
                select(*result_fields)
                .where(*where_clauses)
                .order_by((DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)))
            )
            if limit is not None:
                subq = subq.limit(limit)
            subq = subq.lateral('result')

            # Build the main query by joining query_vectors and the lateral subquery
            stmt = (
                select(
                    query_vectors.c.qid,
                    subq.c.id,
                    subq.c.text,
                    subq.c.vmetadata,
                    subq.c.distance,
                )
                .select_from(query_vectors)
                .join(subq, true())
                .order_by(query_vectors.c.qid, subq.c.distance)
            )

            results = self.session.execute(stmt).all()
            # Rows are fully materialized Core rows, so ending the read-only
            # transaction here is safe and also covers the empty-result case.
            self.session.rollback()

            ids = [[] for _ in range(num_queries)]
            distances = [[] for _ in range(num_queries)]
            documents = [[] for _ in range(num_queries)]
            metadatas = [[] for _ in range(num_queries)]

            for row in results:
                qid = int(row.qid)
                ids[qid].append(row.id)
                # Map pgvector cosine distance from [0, 2] to a [1, 0] score
                # https://github.com/pgvector/pgvector?tab=readme-ov-file#querying
                distances[qid].append((2.0 - row.distance) / 2.0)
                documents[qid].append(row.text)
                metadatas[qid].append(row.vmetadata)

            return SearchResult(ids=ids, distances=distances, documents=documents, metadatas=metadatas)
        except Exception as e:
            self.session.rollback()
            log.exception(f'Error during search: {e}')
            return None

    def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
        """Return chunks in ``collection_name`` whose metadata matches ``filter``.

        ``filter`` supports ``{"field": value}`` and ``{"field": {"$in": [...]}}``,
        the same forms as ``search``; ``None`` matches every chunk. Returns
        ``None`` if nothing matched or on error (the error is logged).
        """
        try:
            where_clauses = [DocumentChunk.collection_name == collection_name]
            where_clauses.extend(self._metadata_filter_clauses(filter))
            return self._fetch_chunks(where_clauses, limit)
        except Exception as e:
            self.session.rollback()
            log.exception(f'Error during query: {e}')
            return None

    def get(self, collection_name: str, limit: Optional[int] = None) -> Optional[GetResult]:
        """Return up to ``limit`` chunks from ``collection_name``.

        Returns ``None`` if the collection is empty, on both the pgcrypto and
        the plaintext path, or on error (the error is logged).
        """
        try:
            where_clauses = [DocumentChunk.collection_name == collection_name]
            return self._fetch_chunks(where_clauses, limit)
        except Exception as e:
            self.session.rollback()
            log.exception(f'Error during get: {e}')
            return None

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Delete chunks from ``collection_name`` and commit.

        ``ids`` restricts deletion to those ids. An explicitly empty list
        deletes nothing; previously it was ignored and wiped the whole
        collection. ``filter`` uses the same forms as ``search``. With neither
        argument, the whole collection is deleted. Errors are logged, rolled
        back and re-raised.
        """
        if ids is not None and not ids:
            log.info(f"Deleted 0 items from collection '{collection_name}'.")
            return

        try:
            where_clauses = [DocumentChunk.collection_name == collection_name]
            if ids is not None:
                where_clauses.append(DocumentChunk.id.in_(ids))
            where_clauses.extend(self._metadata_filter_clauses(filter))

            if PGVECTOR_PGCRYPTO:
                stmt = DocumentChunk.__table__.delete().where(*where_clauses)
                result = self.session.execute(stmt)
                deleted = result.rowcount
            else:
                query = self.session.query(DocumentChunk).filter(*where_clauses)
                deleted = query.delete(synchronize_session=False)
            self.session.commit()
            log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
        except Exception as e:
            self.session.rollback()
            log.exception(f'Error during delete: {e}')
            raise

    def reset(self) -> None:
        """Delete every row in ``document_chunk`` across all collections."""
        try:
            deleted = self.session.query(DocumentChunk).delete()
            self.session.commit()
            log.info(f"Reset complete. Deleted {deleted} items from 'document_chunk' table.")
        except Exception as e:
            self.session.rollback()
            log.exception(f'Error during reset: {e}')
            raise

    def close(self) -> None:
        """No-op; the session is not explicitly closed here."""
        pass

    def has_collection(self, collection_name: str) -> bool:
        """Return True if at least one chunk exists for ``collection_name``.

        Errors are logged and reported as False.
        """
        try:
            exists = (
                self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name).first()
                is not None
            )
            self.session.rollback()  # read-only transaction
            return exists
        except Exception as e:
            self.session.rollback()
            log.exception(f'Error checking collection existence: {e}')
            return False

    def delete_collection(self, collection_name: str) -> None:
        """Delete every chunk belonging to ``collection_name``."""
        self.delete(collection_name)
        log.info(f"Collection '{collection_name}' deleted.")