commit de3317e26b (parent fcf7208352)
Author: Timothy Jaeryang Baek
Date: 2026-03-17 17:58:01 -05:00

220 changed files with 17200 additions and 22836 deletions

View File

@@ -40,78 +40,76 @@ class DatalabMarkerLoader:
self.output_format = output_format
def _get_mime_type(self, filename: str) -> str:
ext = filename.rsplit(".", 1)[-1].lower()
ext = filename.rsplit('.', 1)[-1].lower()
mime_map = {
"pdf": "application/pdf",
"xls": "application/vnd.ms-excel",
"xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"ods": "application/vnd.oasis.opendocument.spreadsheet",
"doc": "application/msword",
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"odt": "application/vnd.oasis.opendocument.text",
"ppt": "application/vnd.ms-powerpoint",
"pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
"odp": "application/vnd.oasis.opendocument.presentation",
"html": "text/html",
"epub": "application/epub+zip",
"png": "image/png",
"jpeg": "image/jpeg",
"jpg": "image/jpeg",
"webp": "image/webp",
"gif": "image/gif",
"tiff": "image/tiff",
'pdf': 'application/pdf',
'xls': 'application/vnd.ms-excel',
'xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
'ods': 'application/vnd.oasis.opendocument.spreadsheet',
'doc': 'application/msword',
'docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
'odt': 'application/vnd.oasis.opendocument.text',
'ppt': 'application/vnd.ms-powerpoint',
'pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
'odp': 'application/vnd.oasis.opendocument.presentation',
'html': 'text/html',
'epub': 'application/epub+zip',
'png': 'image/png',
'jpeg': 'image/jpeg',
'jpg': 'image/jpeg',
'webp': 'image/webp',
'gif': 'image/gif',
'tiff': 'image/tiff',
}
return mime_map.get(ext, "application/octet-stream")
return mime_map.get(ext, 'application/octet-stream')
def check_marker_request_status(self, request_id: str) -> dict:
url = f"{self.api_base_url}/{request_id}"
headers = {"X-Api-Key": self.api_key}
url = f'{self.api_base_url}/{request_id}'
headers = {'X-Api-Key': self.api_key}
try:
response = requests.get(url, headers=headers)
response.raise_for_status()
result = response.json()
log.info(f"Marker API status check for request {request_id}: {result}")
log.info(f'Marker API status check for request {request_id}: {result}')
return result
except requests.HTTPError as e:
log.error(f"Error checking Marker request status: {e}")
log.error(f'Error checking Marker request status: {e}')
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Failed to check Marker request: {e}",
detail=f'Failed to check Marker request: {e}',
)
except ValueError as e:
log.error(f"Invalid JSON checking Marker request: {e}")
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON: {e}"
)
log.error(f'Invalid JSON checking Marker request: {e}')
raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Invalid JSON: {e}')
def load(self) -> List[Document]:
filename = os.path.basename(self.file_path)
mime_type = self._get_mime_type(filename)
headers = {"X-Api-Key": self.api_key}
headers = {'X-Api-Key': self.api_key}
form_data = {
"use_llm": str(self.use_llm).lower(),
"skip_cache": str(self.skip_cache).lower(),
"force_ocr": str(self.force_ocr).lower(),
"paginate": str(self.paginate).lower(),
"strip_existing_ocr": str(self.strip_existing_ocr).lower(),
"disable_image_extraction": str(self.disable_image_extraction).lower(),
"format_lines": str(self.format_lines).lower(),
"output_format": self.output_format,
'use_llm': str(self.use_llm).lower(),
'skip_cache': str(self.skip_cache).lower(),
'force_ocr': str(self.force_ocr).lower(),
'paginate': str(self.paginate).lower(),
'strip_existing_ocr': str(self.strip_existing_ocr).lower(),
'disable_image_extraction': str(self.disable_image_extraction).lower(),
'format_lines': str(self.format_lines).lower(),
'output_format': self.output_format,
}
if self.additional_config and self.additional_config.strip():
form_data["additional_config"] = self.additional_config
form_data['additional_config'] = self.additional_config
log.info(
f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}"
)
try:
with open(self.file_path, "rb") as f:
files = {"file": (filename, f, mime_type)}
with open(self.file_path, 'rb') as f:
files = {'file': (filename, f, mime_type)}
response = requests.post(
f"{self.api_base_url}",
f'{self.api_base_url}',
data=form_data,
files=files,
headers=headers,
@@ -119,29 +117,25 @@ class DatalabMarkerLoader:
response.raise_for_status()
result = response.json()
except FileNotFoundError:
raise HTTPException(
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
)
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}')
except requests.HTTPError as e:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Datalab Marker request failed: {e}",
detail=f'Datalab Marker request failed: {e}',
)
except ValueError as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON response: {e}"
)
raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Invalid JSON response: {e}')
except Exception as e:
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
if not result.get("success"):
if not result.get('success'):
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Datalab Marker request failed: {result.get('error', 'Unknown error')}",
detail=f'Datalab Marker request failed: {result.get("error", "Unknown error")}',
)
check_url = result.get("request_check_url")
request_id = result.get("request_id")
check_url = result.get('request_check_url')
request_id = result.get('request_id')
# Check if this is a direct response (self-hosted) or polling response (DataLab)
if check_url:
@@ -154,54 +148,45 @@ class DatalabMarkerLoader:
poll_result = poll_response.json()
except (requests.HTTPError, ValueError) as e:
raw_body = poll_response.text
log.error(f"Polling error: {e}, response body: {raw_body}")
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}"
)
log.error(f'Polling error: {e}, response body: {raw_body}')
raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Polling failed: {e}')
status_val = poll_result.get("status")
success_val = poll_result.get("success")
status_val = poll_result.get('status')
success_val = poll_result.get('success')
if status_val == "complete":
if status_val == 'complete':
summary = {
k: poll_result.get(k)
for k in (
"status",
"output_format",
"success",
"error",
"page_count",
"total_cost",
'status',
'output_format',
'success',
'error',
'page_count',
'total_cost',
)
}
log.info(
f"Marker processing completed successfully: {json.dumps(summary, indent=2)}"
)
log.info(f'Marker processing completed successfully: {json.dumps(summary, indent=2)}')
break
if status_val == "failed" or success_val is False:
log.error(
f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}"
)
error_msg = (
poll_result.get("error")
or "Marker returned failure without error message"
)
if status_val == 'failed' or success_val is False:
log.error(f'Marker poll failed full response: {json.dumps(poll_result, indent=2)}')
error_msg = poll_result.get('error') or 'Marker returned failure without error message'
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Marker processing failed: {error_msg}",
detail=f'Marker processing failed: {error_msg}',
)
else:
raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT,
detail="Marker processing timed out",
detail='Marker processing timed out',
)
if not poll_result.get("success", False):
error_msg = poll_result.get("error") or "Unknown processing error"
if not poll_result.get('success', False):
error_msg = poll_result.get('error') or 'Unknown processing error'
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Final processing failed: {error_msg}",
detail=f'Final processing failed: {error_msg}',
)
# DataLab format - content in format-specific fields
@@ -210,69 +195,65 @@ class DatalabMarkerLoader:
final_result = poll_result
else:
# Self-hosted direct response - content in "output" field
if "output" in result:
log.info("Self-hosted Marker returned direct response without polling")
raw_content = result.get("output")
if 'output' in result:
log.info('Self-hosted Marker returned direct response without polling')
raw_content = result.get('output')
final_result = result
else:
available_fields = (
list(result.keys())
if isinstance(result, dict)
else "non-dict response"
)
available_fields = list(result.keys()) if isinstance(result, dict) else 'non-dict response'
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Custom Marker endpoint returned success but no 'output' field found. Available fields: {available_fields}. Expected either 'request_check_url' for polling or 'output' field for direct response.",
)
if self.output_format.lower() == "json":
if self.output_format.lower() == 'json':
full_text = json.dumps(raw_content, indent=2)
elif self.output_format.lower() in {"markdown", "html"}:
elif self.output_format.lower() in {'markdown', 'html'}:
full_text = str(raw_content).strip()
else:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported output format: {self.output_format}",
detail=f'Unsupported output format: {self.output_format}',
)
if not full_text:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail="Marker returned empty content",
detail='Marker returned empty content',
)
marker_output_dir = os.path.join("/app/backend/data/uploads", "marker_output")
marker_output_dir = os.path.join('/app/backend/data/uploads', 'marker_output')
os.makedirs(marker_output_dir, exist_ok=True)
file_ext_map = {"markdown": "md", "json": "json", "html": "html"}
file_ext = file_ext_map.get(self.output_format.lower(), "txt")
output_filename = f"{os.path.splitext(filename)[0]}.{file_ext}"
file_ext_map = {'markdown': 'md', 'json': 'json', 'html': 'html'}
file_ext = file_ext_map.get(self.output_format.lower(), 'txt')
output_filename = f'{os.path.splitext(filename)[0]}.{file_ext}'
output_path = os.path.join(marker_output_dir, output_filename)
try:
with open(output_path, "w", encoding="utf-8") as f:
with open(output_path, 'w', encoding='utf-8') as f:
f.write(full_text)
log.info(f"Saved Marker output to: {output_path}")
log.info(f'Saved Marker output to: {output_path}')
except Exception as e:
log.warning(f"Failed to write marker output to disk: {e}")
log.warning(f'Failed to write marker output to disk: {e}')
metadata = {
"source": filename,
"output_format": final_result.get("output_format", self.output_format),
"page_count": final_result.get("page_count", 0),
"processed_with_llm": self.use_llm,
"request_id": request_id or "",
'source': filename,
'output_format': final_result.get('output_format', self.output_format),
'page_count': final_result.get('page_count', 0),
'processed_with_llm': self.use_llm,
'request_id': request_id or '',
}
images = final_result.get("images", {})
images = final_result.get('images', {})
if images:
metadata["image_count"] = len(images)
metadata["images"] = json.dumps(list(images.keys()))
metadata['image_count'] = len(images)
metadata['images'] = json.dumps(list(images.keys()))
for k, v in metadata.items():
if isinstance(v, (dict, list)):
metadata[k] = json.dumps(v)
elif v is None:
metadata[k] = ""
metadata[k] = ''
return [Document(page_content=full_text, metadata=metadata)]
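For orientation, a minimal sketch (not part of this commit) of the response-shape branching the loader above performs: the hosted DataLab API returns a 'request_check_url' to poll, while a self-hosted Marker instance returns the converted content directly in an 'output' field. Field names mirror the code above; the helper itself is hypothetical.

def _classify_marker_response(result: dict) -> str:
    # Hypothetical helper; same fields as the load() method above.
    if not result.get('success'):
        raise ValueError(result.get('error', 'Unknown error'))
    if result.get('request_check_url'):
        return 'poll'    # hosted DataLab: poll until status == 'complete'
    if 'output' in result:
        return 'direct'  # self-hosted Marker: content is already in the response
    raise ValueError(f'Unexpected response fields: {list(result.keys())}')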

View File

@@ -29,18 +29,18 @@ class ExternalDocumentLoader(BaseLoader):
self.user = user
def load(self) -> List[Document]:
with open(self.file_path, "rb") as f:
with open(self.file_path, 'rb') as f:
data = f.read()
headers = {}
if self.mime_type is not None:
headers["Content-Type"] = self.mime_type
headers['Content-Type'] = self.mime_type
if self.api_key is not None:
headers["Authorization"] = f"Bearer {self.api_key}"
headers['Authorization'] = f'Bearer {self.api_key}'
try:
headers["X-Filename"] = quote(os.path.basename(self.file_path))
headers['X-Filename'] = quote(os.path.basename(self.file_path))
except Exception:
pass
@@ -48,24 +48,23 @@ class ExternalDocumentLoader(BaseLoader):
headers = include_user_info_headers(headers, self.user)
url = self.url
if url.endswith("/"):
if url.endswith('/'):
url = url[:-1]
try:
response = requests.put(f"{url}/process", data=data, headers=headers)
response = requests.put(f'{url}/process', data=data, headers=headers)
except Exception as e:
log.error(f"Error connecting to endpoint: {e}")
raise Exception(f"Error connecting to endpoint: {e}")
log.error(f'Error connecting to endpoint: {e}')
raise Exception(f'Error connecting to endpoint: {e}')
if response.ok:
response_data = response.json()
if response_data:
if isinstance(response_data, dict):
return [
Document(
page_content=response_data.get("page_content"),
metadata=response_data.get("metadata"),
page_content=response_data.get('page_content'),
metadata=response_data.get('metadata'),
)
]
elif isinstance(response_data, list):
@@ -73,17 +72,15 @@ class ExternalDocumentLoader(BaseLoader):
for document in response_data:
documents.append(
Document(
page_content=document.get("page_content"),
metadata=document.get("metadata"),
page_content=document.get('page_content'),
metadata=document.get('metadata'),
)
)
return documents
else:
raise Exception("Error loading document: Unable to parse content")
raise Exception('Error loading document: Unable to parse content')
else:
raise Exception("Error loading document: No content returned")
raise Exception('Error loading document: No content returned')
else:
raise Exception(
f"Error loading document: {response.status_code} {response.text}"
)
raise Exception(f'Error loading document: {response.status_code} {response.text}')

View File

@@ -30,22 +30,22 @@ class ExternalWebLoader(BaseLoader):
response = requests.post(
self.external_url,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) External Web Loader",
"Authorization": f"Bearer {self.external_api_key}",
'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) External Web Loader',
'Authorization': f'Bearer {self.external_api_key}',
},
json={
"urls": urls,
'urls': urls,
},
)
response.raise_for_status()
results = response.json()
for result in results:
yield Document(
page_content=result.get("page_content", ""),
metadata=result.get("metadata", {}),
page_content=result.get('page_content', ''),
metadata=result.get('metadata', {}),
)
except Exception as e:
if self.continue_on_failure:
log.error(f"Error extracting content from batch {urls}: {e}")
log.error(f'Error extracting content from batch {urls}: {e}')
else:
raise e

View File

@@ -30,59 +30,59 @@ logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
known_source_ext = [
"go",
"py",
"java",
"sh",
"bat",
"ps1",
"cmd",
"js",
"ts",
"css",
"cpp",
"hpp",
"h",
"c",
"cs",
"sql",
"log",
"ini",
"pl",
"pm",
"r",
"dart",
"dockerfile",
"env",
"php",
"hs",
"hsc",
"lua",
"nginxconf",
"conf",
"m",
"mm",
"plsql",
"perl",
"rb",
"rs",
"db2",
"scala",
"bash",
"swift",
"vue",
"svelte",
"ex",
"exs",
"erl",
"tsx",
"jsx",
"hs",
"lhs",
"json",
"yaml",
"yml",
"toml",
'go',
'py',
'java',
'sh',
'bat',
'ps1',
'cmd',
'js',
'ts',
'css',
'cpp',
'hpp',
'h',
'c',
'cs',
'sql',
'log',
'ini',
'pl',
'pm',
'r',
'dart',
'dockerfile',
'env',
'php',
'hs',
'hsc',
'lua',
'nginxconf',
'conf',
'm',
'mm',
'plsql',
'perl',
'rb',
'rs',
'db2',
'scala',
'bash',
'swift',
'vue',
'svelte',
'ex',
'exs',
'erl',
'tsx',
'jsx',
'hs',
'lhs',
'json',
'yaml',
'yml',
'toml',
]
@@ -99,11 +99,11 @@ class ExcelLoader:
xls = pd.ExcelFile(self.file_path)
for sheet_name in xls.sheet_names:
df = pd.read_excel(xls, sheet_name=sheet_name)
text_parts.append(f"Sheet: {sheet_name}\n{df.to_string(index=False)}")
text_parts.append(f'Sheet: {sheet_name}\n{df.to_string(index=False)}')
return [
Document(
page_content="\n\n".join(text_parts),
metadata={"source": self.file_path},
page_content='\n\n'.join(text_parts),
metadata={'source': self.file_path},
)
]
@@ -125,11 +125,11 @@ class PptxLoader:
if shape.has_text_frame:
slide_texts.append(shape.text_frame.text)
if slide_texts:
text_parts.append(f"Slide {i}:\n" + "\n".join(slide_texts))
text_parts.append(f'Slide {i}:\n' + '\n'.join(slide_texts))
return [
Document(
page_content="\n\n".join(text_parts),
metadata={"source": self.file_path},
page_content='\n\n'.join(text_parts),
metadata={'source': self.file_path},
)
]
@@ -143,41 +143,41 @@ class TikaLoader:
self.extract_images = extract_images
def load(self) -> list[Document]:
with open(self.file_path, "rb") as f:
with open(self.file_path, 'rb') as f:
data = f.read()
if self.mime_type is not None:
headers = {"Content-Type": self.mime_type}
headers = {'Content-Type': self.mime_type}
else:
headers = {}
if self.extract_images == True:
headers["X-Tika-PDFextractInlineImages"] = "true"
headers['X-Tika-PDFextractInlineImages'] = 'true'
endpoint = self.url
if not endpoint.endswith("/"):
endpoint += "/"
endpoint += "tika/text"
if not endpoint.endswith('/'):
endpoint += '/'
endpoint += 'tika/text'
r = requests.put(endpoint, data=data, headers=headers, verify=REQUESTS_VERIFY)
if r.ok:
raw_metadata = r.json()
text = raw_metadata.get("X-TIKA:content", "<No text content found>").strip()
text = raw_metadata.get('X-TIKA:content', '<No text content found>').strip()
if "Content-Type" in raw_metadata:
headers["Content-Type"] = raw_metadata["Content-Type"]
if 'Content-Type' in raw_metadata:
headers['Content-Type'] = raw_metadata['Content-Type']
log.debug("Tika extracted text: %s", text)
log.debug('Tika extracted text: %s', text)
return [Document(page_content=text, metadata=headers)]
else:
raise Exception(f"Error calling Tika: {r.reason}")
raise Exception(f'Error calling Tika: {r.reason}')
class DoclingLoader:
def __init__(self, url, api_key=None, file_path=None, mime_type=None, params=None):
self.url = url.rstrip("/")
self.url = url.rstrip('/')
self.api_key = api_key
self.file_path = file_path
self.mime_type = mime_type
@@ -185,199 +185,183 @@ class DoclingLoader:
self.params = params or {}
def load(self) -> list[Document]:
with open(self.file_path, "rb") as f:
with open(self.file_path, 'rb') as f:
headers = {}
if self.api_key:
headers["X-Api-Key"] = f"{self.api_key}"
headers['X-Api-Key'] = f'{self.api_key}'
r = requests.post(
f"{self.url}/v1/convert/file",
f'{self.url}/v1/convert/file',
files={
"files": (
'files': (
self.file_path,
f,
self.mime_type or "application/octet-stream",
self.mime_type or 'application/octet-stream',
)
},
data={
"image_export_mode": "placeholder",
'image_export_mode': 'placeholder',
**self.params,
},
headers=headers,
)
if r.ok:
result = r.json()
document_data = result.get("document", {})
text = document_data.get("md_content", "<No text content found>")
document_data = result.get('document', {})
text = document_data.get('md_content', '<No text content found>')
metadata = {"Content-Type": self.mime_type} if self.mime_type else {}
metadata = {'Content-Type': self.mime_type} if self.mime_type else {}
log.debug("Docling extracted text: %s", text)
log.debug('Docling extracted text: %s', text)
return [Document(page_content=text, metadata=metadata)]
else:
error_msg = f"Error calling Docling API: {r.reason}"
error_msg = f'Error calling Docling API: {r.reason}'
if r.text:
try:
error_data = r.json()
if "detail" in error_data:
error_msg += f" - {error_data['detail']}"
if 'detail' in error_data:
error_msg += f' - {error_data["detail"]}'
except Exception:
error_msg += f" - {r.text}"
raise Exception(f"Error calling Docling: {error_msg}")
error_msg += f' - {r.text}'
raise Exception(f'Error calling Docling: {error_msg}')
class Loader:
def __init__(self, engine: str = "", **kwargs):
def __init__(self, engine: str = '', **kwargs):
self.engine = engine
self.user = kwargs.get("user", None)
self.user = kwargs.get('user', None)
self.kwargs = kwargs
def load(
self, filename: str, file_content_type: str, file_path: str
) -> list[Document]:
def load(self, filename: str, file_content_type: str, file_path: str) -> list[Document]:
loader = self._get_loader(filename, file_content_type, file_path)
docs = loader.load()
return [
Document(
page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata
)
for doc in docs
]
return [Document(page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata) for doc in docs]
def _is_text_file(self, file_ext: str, file_content_type: str) -> bool:
return file_ext in known_source_ext or (
file_content_type
and file_content_type.find("text/") >= 0
and file_content_type.find('text/') >= 0
# Avoid text/html files being detected as text
and not file_content_type.find("html") >= 0
and not file_content_type.find('html') >= 0
)
def _get_loader(self, filename: str, file_content_type: str, file_path: str):
file_ext = filename.split(".")[-1].lower()
file_ext = filename.split('.')[-1].lower()
if (
self.engine == "external"
and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL")
and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY")
self.engine == 'external'
and self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_URL')
and self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_API_KEY')
):
loader = ExternalDocumentLoader(
file_path=file_path,
url=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL"),
api_key=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY"),
url=self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_URL'),
api_key=self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_API_KEY'),
mime_type=file_content_type,
user=self.user,
)
elif self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
elif self.engine == 'tika' and self.kwargs.get('TIKA_SERVER_URL'):
if self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
loader = TikaLoader(
url=self.kwargs.get("TIKA_SERVER_URL"),
url=self.kwargs.get('TIKA_SERVER_URL'),
file_path=file_path,
extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"),
extract_images=self.kwargs.get('PDF_EXTRACT_IMAGES'),
)
elif (
self.engine == "datalab_marker"
and self.kwargs.get("DATALAB_MARKER_API_KEY")
self.engine == 'datalab_marker'
and self.kwargs.get('DATALAB_MARKER_API_KEY')
and file_ext
in [
"pdf",
"xls",
"xlsx",
"ods",
"doc",
"docx",
"odt",
"ppt",
"pptx",
"odp",
"html",
"epub",
"png",
"jpeg",
"jpg",
"webp",
"gif",
"tiff",
'pdf',
'xls',
'xlsx',
'ods',
'doc',
'docx',
'odt',
'ppt',
'pptx',
'odp',
'html',
'epub',
'png',
'jpeg',
'jpg',
'webp',
'gif',
'tiff',
]
):
api_base_url = self.kwargs.get("DATALAB_MARKER_API_BASE_URL", "")
if not api_base_url or api_base_url.strip() == "":
api_base_url = "https://www.datalab.to/api/v1/marker" # https://github.com/open-webui/open-webui/pull/16867#issuecomment-3218424349
api_base_url = self.kwargs.get('DATALAB_MARKER_API_BASE_URL', '')
if not api_base_url or api_base_url.strip() == '':
api_base_url = 'https://www.datalab.to/api/v1/marker' # https://github.com/open-webui/open-webui/pull/16867#issuecomment-3218424349
loader = DatalabMarkerLoader(
file_path=file_path,
api_key=self.kwargs["DATALAB_MARKER_API_KEY"],
api_key=self.kwargs['DATALAB_MARKER_API_KEY'],
api_base_url=api_base_url,
additional_config=self.kwargs.get("DATALAB_MARKER_ADDITIONAL_CONFIG"),
use_llm=self.kwargs.get("DATALAB_MARKER_USE_LLM", False),
skip_cache=self.kwargs.get("DATALAB_MARKER_SKIP_CACHE", False),
force_ocr=self.kwargs.get("DATALAB_MARKER_FORCE_OCR", False),
paginate=self.kwargs.get("DATALAB_MARKER_PAGINATE", False),
strip_existing_ocr=self.kwargs.get(
"DATALAB_MARKER_STRIP_EXISTING_OCR", False
),
disable_image_extraction=self.kwargs.get(
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", False
),
format_lines=self.kwargs.get("DATALAB_MARKER_FORMAT_LINES", False),
output_format=self.kwargs.get(
"DATALAB_MARKER_OUTPUT_FORMAT", "markdown"
),
additional_config=self.kwargs.get('DATALAB_MARKER_ADDITIONAL_CONFIG'),
use_llm=self.kwargs.get('DATALAB_MARKER_USE_LLM', False),
skip_cache=self.kwargs.get('DATALAB_MARKER_SKIP_CACHE', False),
force_ocr=self.kwargs.get('DATALAB_MARKER_FORCE_OCR', False),
paginate=self.kwargs.get('DATALAB_MARKER_PAGINATE', False),
strip_existing_ocr=self.kwargs.get('DATALAB_MARKER_STRIP_EXISTING_OCR', False),
disable_image_extraction=self.kwargs.get('DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION', False),
format_lines=self.kwargs.get('DATALAB_MARKER_FORMAT_LINES', False),
output_format=self.kwargs.get('DATALAB_MARKER_OUTPUT_FORMAT', 'markdown'),
)
elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"):
elif self.engine == 'docling' and self.kwargs.get('DOCLING_SERVER_URL'):
if self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
# Build params for DoclingLoader
params = self.kwargs.get("DOCLING_PARAMS", {})
params = self.kwargs.get('DOCLING_PARAMS', {})
if not isinstance(params, dict):
try:
params = json.loads(params)
except json.JSONDecodeError:
log.error("Invalid DOCLING_PARAMS format, expected JSON object")
log.error('Invalid DOCLING_PARAMS format, expected JSON object')
params = {}
loader = DoclingLoader(
url=self.kwargs.get("DOCLING_SERVER_URL"),
api_key=self.kwargs.get("DOCLING_API_KEY", None),
url=self.kwargs.get('DOCLING_SERVER_URL'),
api_key=self.kwargs.get('DOCLING_API_KEY', None),
file_path=file_path,
mime_type=file_content_type,
params=params,
)
elif (
self.engine == "document_intelligence"
and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
self.engine == 'document_intelligence'
and self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT') != ''
and (
file_ext in ["pdf", "docx", "ppt", "pptx"]
file_ext in ['pdf', 'docx', 'ppt', 'pptx']
or file_content_type
in [
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
'application/vnd.ms-powerpoint',
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
]
)
):
if self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != "":
if self.kwargs.get('DOCUMENT_INTELLIGENCE_KEY') != '':
loader = AzureAIDocumentIntelligenceLoader(
file_path=file_path,
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"),
api_endpoint=self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT'),
api_key=self.kwargs.get('DOCUMENT_INTELLIGENCE_KEY'),
api_model=self.kwargs.get('DOCUMENT_INTELLIGENCE_MODEL'),
)
else:
loader = AzureAIDocumentIntelligenceLoader(
file_path=file_path,
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
api_endpoint=self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT'),
azure_credential=DefaultAzureCredential(),
api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"),
api_model=self.kwargs.get('DOCUMENT_INTELLIGENCE_MODEL'),
)
elif self.engine == "mineru" and file_ext in [
"pdf"
]: # MinerU currently only supports PDF
mineru_timeout = self.kwargs.get("MINERU_API_TIMEOUT", 300)
elif self.engine == 'mineru' and file_ext in ['pdf']: # MinerU currently only supports PDF
mineru_timeout = self.kwargs.get('MINERU_API_TIMEOUT', 300)
if mineru_timeout:
try:
mineru_timeout = int(mineru_timeout)
@@ -386,111 +370,115 @@ class Loader:
loader = MinerULoader(
file_path=file_path,
api_mode=self.kwargs.get("MINERU_API_MODE", "local"),
api_url=self.kwargs.get("MINERU_API_URL", "http://localhost:8000"),
api_key=self.kwargs.get("MINERU_API_KEY", ""),
params=self.kwargs.get("MINERU_PARAMS", {}),
api_mode=self.kwargs.get('MINERU_API_MODE', 'local'),
api_url=self.kwargs.get('MINERU_API_URL', 'http://localhost:8000'),
api_key=self.kwargs.get('MINERU_API_KEY', ''),
params=self.kwargs.get('MINERU_PARAMS', {}),
timeout=mineru_timeout,
)
elif (
self.engine == "mistral_ocr"
and self.kwargs.get("MISTRAL_OCR_API_KEY") != ""
and file_ext
in ["pdf"] # Mistral OCR currently only supports PDF and images
self.engine == 'mistral_ocr'
and self.kwargs.get('MISTRAL_OCR_API_KEY') != ''
and file_ext in ['pdf'] # Mistral OCR currently only supports PDF and images
):
loader = MistralLoader(
base_url=self.kwargs.get("MISTRAL_OCR_API_BASE_URL"),
api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"),
base_url=self.kwargs.get('MISTRAL_OCR_API_BASE_URL'),
api_key=self.kwargs.get('MISTRAL_OCR_API_KEY'),
file_path=file_path,
)
else:
if file_ext == "pdf":
if file_ext == 'pdf':
loader = PyPDFLoader(
file_path,
extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"),
mode=self.kwargs.get("PDF_LOADER_MODE", "page"),
extract_images=self.kwargs.get('PDF_EXTRACT_IMAGES'),
mode=self.kwargs.get('PDF_LOADER_MODE', 'page'),
)
elif file_ext == "csv":
elif file_ext == 'csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
elif file_ext == "rst":
elif file_ext == 'rst':
try:
from langchain_community.document_loaders import UnstructuredRSTLoader
loader = UnstructuredRSTLoader(file_path, mode="elements")
loader = UnstructuredRSTLoader(file_path, mode='elements')
except ImportError:
log.warning(
"The 'unstructured' package is not installed. "
"Falling back to plain text loading for .rst file. "
"Install it with: pip install unstructured"
'Falling back to plain text loading for .rst file. '
'Install it with: pip install unstructured'
)
loader = TextLoader(file_path, autodetect_encoding=True)
elif file_ext == "xml":
elif file_ext == 'xml':
try:
from langchain_community.document_loaders import UnstructuredXMLLoader
loader = UnstructuredXMLLoader(file_path)
except ImportError:
log.warning(
"The 'unstructured' package is not installed. "
"Falling back to plain text loading for .xml file. "
"Install it with: pip install unstructured"
'Falling back to plain text loading for .xml file. '
'Install it with: pip install unstructured'
)
loader = TextLoader(file_path, autodetect_encoding=True)
elif file_ext in ["htm", "html"]:
loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
elif file_ext == "md":
elif file_ext in ['htm', 'html']:
loader = BSHTMLLoader(file_path, open_encoding='unicode_escape')
elif file_ext == 'md':
loader = TextLoader(file_path, autodetect_encoding=True)
elif file_content_type == "application/epub+zip":
elif file_content_type == 'application/epub+zip':
try:
from langchain_community.document_loaders import UnstructuredEPubLoader
loader = UnstructuredEPubLoader(file_path)
except ImportError:
raise ValueError(
"Processing .epub files requires the 'unstructured' package. "
"Install it with: pip install unstructured"
'Install it with: pip install unstructured'
)
elif (
file_content_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
or file_ext == "docx"
file_content_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
or file_ext == 'docx'
):
loader = Docx2txtLoader(file_path)
elif file_content_type in [
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
] or file_ext in ["xls", "xlsx"]:
'application/vnd.ms-excel',
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
] or file_ext in ['xls', 'xlsx']:
try:
from langchain_community.document_loaders import UnstructuredExcelLoader
loader = UnstructuredExcelLoader(file_path)
except ImportError:
log.warning(
"The 'unstructured' package is not installed. "
"Falling back to pandas for Excel file loading. "
"Install unstructured for better results: pip install unstructured"
'Falling back to pandas for Excel file loading. '
'Install unstructured for better results: pip install unstructured'
)
loader = ExcelLoader(file_path)
elif file_content_type in [
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
] or file_ext in ["ppt", "pptx"]:
'application/vnd.ms-powerpoint',
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
] or file_ext in ['ppt', 'pptx']:
try:
from langchain_community.document_loaders import UnstructuredPowerPointLoader
loader = UnstructuredPowerPointLoader(file_path)
except ImportError:
log.warning(
"The 'unstructured' package is not installed. "
"Falling back to python-pptx for PowerPoint file loading. "
"Install unstructured for better results: pip install unstructured"
'Falling back to python-pptx for PowerPoint file loading. '
'Install unstructured for better results: pip install unstructured'
)
loader = PptxLoader(file_path)
elif file_ext == "msg":
elif file_ext == 'msg':
loader = OutlookMessageLoader(file_path)
elif file_ext == "odt":
elif file_ext == 'odt':
try:
from langchain_community.document_loaders import UnstructuredODTLoader
loader = UnstructuredODTLoader(file_path)
except ImportError:
raise ValueError(
"Processing .odt files requires the 'unstructured' package. "
"Install it with: pip install unstructured"
'Install it with: pip install unstructured'
)
elif self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True)
@@ -498,4 +486,3 @@ class Loader:
loader = TextLoader(file_path, autodetect_encoding=True)
return loader
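A hedged usage sketch (not from this commit) of how the engine selection above is typically driven: the engine name and its settings arrive as keyword arguments, and unmatched extensions fall through to TextLoader. The server URL below is an assumption for illustration only.

# Hypothetical call; the kwargs mirror the keys read in _get_loader above.
loader = Loader(
    engine='tika',
    TIKA_SERVER_URL='http://localhost:9998',  # assumed local Tika instance
    PDF_EXTRACT_IMAGES=False,
)
docs = loader.load('report.pdf', 'application/pdf', '/tmp/report.pdf')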

View File

@@ -22,37 +22,35 @@ class MinerULoader:
def __init__(
self,
file_path: str,
api_mode: str = "local",
api_url: str = "http://localhost:8000",
api_key: str = "",
api_mode: str = 'local',
api_url: str = 'http://localhost:8000',
api_key: str = '',
params: dict = None,
timeout: Optional[int] = 300,
):
self.file_path = file_path
self.api_mode = api_mode.lower()
self.api_url = api_url.rstrip("/")
self.api_url = api_url.rstrip('/')
self.api_key = api_key
self.timeout = timeout
# Parse params dict with defaults
self.params = params or {}
self.enable_ocr = params.get("enable_ocr", False)
self.enable_formula = params.get("enable_formula", True)
self.enable_table = params.get("enable_table", True)
self.language = params.get("language", "en")
self.model_version = params.get("model_version", "pipeline")
self.enable_ocr = params.get('enable_ocr', False)
self.enable_formula = params.get('enable_formula', True)
self.enable_table = params.get('enable_table', True)
self.language = params.get('language', 'en')
self.model_version = params.get('model_version', 'pipeline')
self.page_ranges = self.params.pop("page_ranges", "")
self.page_ranges = self.params.pop('page_ranges', '')
# Validate API mode
if self.api_mode not in ["local", "cloud"]:
raise ValueError(
f"Invalid API mode: {self.api_mode}. Must be 'local' or 'cloud'"
)
if self.api_mode not in ['local', 'cloud']:
raise ValueError(f"Invalid API mode: {self.api_mode}. Must be 'local' or 'cloud'")
# Validate Cloud API requirements
if self.api_mode == "cloud" and not self.api_key:
raise ValueError("API key is required for Cloud API mode")
if self.api_mode == 'cloud' and not self.api_key:
raise ValueError('API key is required for Cloud API mode')
def load(self) -> List[Document]:
"""
@@ -60,12 +58,12 @@ class MinerULoader:
Routes to Cloud or Local API based on api_mode.
"""
try:
if self.api_mode == "cloud":
if self.api_mode == 'cloud':
return self._load_cloud_api()
else:
return self._load_local_api()
except Exception as e:
log.error(f"Error loading document with MinerU: {e}")
log.error(f'Error loading document with MinerU: {e}')
raise
def _load_local_api(self) -> List[Document]:
@@ -73,14 +71,14 @@ class MinerULoader:
Load document using Local API (synchronous).
Posts file to /file_parse endpoint and gets immediate response.
"""
log.info(f"Using MinerU Local API at {self.api_url}")
log.info(f'Using MinerU Local API at {self.api_url}')
filename = os.path.basename(self.file_path)
# Build form data for Local API
form_data = {
**self.params,
"return_md": "true",
'return_md': 'true',
}
# Page ranges (Local API uses start_page_id and end_page_id)
@@ -89,18 +87,18 @@ class MinerULoader:
# Full page range parsing would require parsing the string
log.warning(
f"Page ranges '{self.page_ranges}' specified but Local API uses different format. "
"Consider using start_page_id/end_page_id parameters if needed."
'Consider using start_page_id/end_page_id parameters if needed.'
)
try:
with open(self.file_path, "rb") as f:
files = {"files": (filename, f, "application/octet-stream")}
with open(self.file_path, 'rb') as f:
files = {'files': (filename, f, 'application/octet-stream')}
log.info(f"Sending file to MinerU Local API: {filename}")
log.debug(f"Local API parameters: {form_data}")
log.info(f'Sending file to MinerU Local API: {filename}')
log.debug(f'Local API parameters: {form_data}')
response = requests.post(
f"{self.api_url}/file_parse",
f'{self.api_url}/file_parse',
data=form_data,
files=files,
timeout=self.timeout,
@@ -108,27 +106,25 @@ class MinerULoader:
response.raise_for_status()
except FileNotFoundError:
raise HTTPException(
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
)
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}')
except requests.Timeout:
raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT,
detail="MinerU Local API request timed out",
detail='MinerU Local API request timed out',
)
except requests.HTTPError as e:
error_detail = f"MinerU Local API request failed: {e}"
error_detail = f'MinerU Local API request failed: {e}'
if e.response is not None:
try:
error_data = e.response.json()
error_detail += f" - {error_data}"
error_detail += f' - {error_data}'
except Exception:
error_detail += f" - {e.response.text}"
error_detail += f' - {e.response.text}'
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error calling MinerU Local API: {str(e)}",
detail=f'Error calling MinerU Local API: {str(e)}',
)
# Parse response
@@ -137,41 +133,41 @@ class MinerULoader:
except ValueError as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Invalid JSON response from MinerU Local API: {e}",
detail=f'Invalid JSON response from MinerU Local API: {e}',
)
# Extract markdown content from response
if "results" not in result:
if 'results' not in result:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail="MinerU Local API response missing 'results' field",
)
results = result["results"]
results = result['results']
if not results:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail="MinerU returned empty results",
detail='MinerU returned empty results',
)
# Get the first (and typically only) result
file_result = list(results.values())[0]
markdown_content = file_result.get("md_content", "")
markdown_content = file_result.get('md_content', '')
if not markdown_content:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail="MinerU returned empty markdown content",
detail='MinerU returned empty markdown content',
)
log.info(f"Successfully parsed document with MinerU Local API: {filename}")
log.info(f'Successfully parsed document with MinerU Local API: {filename}')
# Create metadata
metadata = {
"source": filename,
"api_mode": "local",
"backend": result.get("backend", "unknown"),
"version": result.get("version", "unknown"),
'source': filename,
'api_mode': 'local',
'backend': result.get('backend', 'unknown'),
'version': result.get('version', 'unknown'),
}
return [Document(page_content=markdown_content, metadata=metadata)]
@@ -181,7 +177,7 @@ class MinerULoader:
Load document using Cloud API (asynchronous).
Uses batch upload endpoint to avoid need for public file URLs.
"""
log.info(f"Using MinerU Cloud API at {self.api_url}")
log.info(f'Using MinerU Cloud API at {self.api_url}')
filename = os.path.basename(self.file_path)
@@ -195,17 +191,15 @@ class MinerULoader:
result = self._poll_batch_status(batch_id, filename)
# Step 4: Download and extract markdown from ZIP
markdown_content = self._download_and_extract_zip(
result["full_zip_url"], filename
)
markdown_content = self._download_and_extract_zip(result['full_zip_url'], filename)
log.info(f"Successfully parsed document with MinerU Cloud API: {filename}")
log.info(f'Successfully parsed document with MinerU Cloud API: {filename}')
# Create metadata
metadata = {
"source": filename,
"api_mode": "cloud",
"batch_id": batch_id,
'source': filename,
'api_mode': 'cloud',
'batch_id': batch_id,
}
return [Document(page_content=markdown_content, metadata=metadata)]
@@ -216,49 +210,49 @@ class MinerULoader:
Returns (batch_id, upload_url).
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
'Authorization': f'Bearer {self.api_key}',
'Content-Type': 'application/json',
}
# Build request body
request_body = {
**self.params,
"files": [
'files': [
{
"name": filename,
"is_ocr": self.enable_ocr,
'name': filename,
'is_ocr': self.enable_ocr,
}
],
}
# Add page ranges if specified
if self.page_ranges:
request_body["files"][0]["page_ranges"] = self.page_ranges
request_body['files'][0]['page_ranges'] = self.page_ranges
log.info(f"Requesting upload URL for: {filename}")
log.debug(f"Cloud API request body: {request_body}")
log.info(f'Requesting upload URL for: {filename}')
log.debug(f'Cloud API request body: {request_body}')
try:
response = requests.post(
f"{self.api_url}/file-urls/batch",
f'{self.api_url}/file-urls/batch',
headers=headers,
json=request_body,
timeout=30,
)
response.raise_for_status()
except requests.HTTPError as e:
error_detail = f"Failed to request upload URL: {e}"
error_detail = f'Failed to request upload URL: {e}'
if e.response is not None:
try:
error_data = e.response.json()
error_detail += f" - {error_data.get('msg', error_data)}"
error_detail += f' - {error_data.get("msg", error_data)}'
except Exception:
error_detail += f" - {e.response.text}"
error_detail += f' - {e.response.text}'
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error requesting upload URL: {str(e)}",
detail=f'Error requesting upload URL: {str(e)}',
)
try:
@@ -266,28 +260,28 @@ class MinerULoader:
except ValueError as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Invalid JSON response: {e}",
detail=f'Invalid JSON response: {e}',
)
# Check for API error response
if result.get("code") != 0:
if result.get('code') != 0:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}",
detail=f'MinerU Cloud API error: {result.get("msg", "Unknown error")}',
)
data = result.get("data", {})
batch_id = data.get("batch_id")
file_urls = data.get("file_urls", [])
data = result.get('data', {})
batch_id = data.get('batch_id')
file_urls = data.get('file_urls', [])
if not batch_id or not file_urls:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail="MinerU Cloud API response missing batch_id or file_urls",
detail='MinerU Cloud API response missing batch_id or file_urls',
)
upload_url = file_urls[0]
log.info(f"Received upload URL for batch: {batch_id}")
log.info(f'Received upload URL for batch: {batch_id}')
return batch_id, upload_url
@@ -295,10 +289,10 @@ class MinerULoader:
"""
Upload file to presigned URL (no authentication needed).
"""
log.info(f"Uploading file to presigned URL")
log.info(f'Uploading file to presigned URL')
try:
with open(self.file_path, "rb") as f:
with open(self.file_path, 'rb') as f:
response = requests.put(
upload_url,
data=f,
@@ -306,26 +300,24 @@ class MinerULoader:
)
response.raise_for_status()
except FileNotFoundError:
raise HTTPException(
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
)
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}')
except requests.Timeout:
raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT,
detail="File upload to presigned URL timed out",
detail='File upload to presigned URL timed out',
)
except requests.HTTPError as e:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Failed to upload file to presigned URL: {e}",
detail=f'Failed to upload file to presigned URL: {e}',
)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error uploading file: {str(e)}",
detail=f'Error uploading file: {str(e)}',
)
log.info("File uploaded successfully")
log.info('File uploaded successfully')
def _poll_batch_status(self, batch_id: str, filename: str) -> dict:
"""
@@ -333,35 +325,35 @@ class MinerULoader:
Returns the result dict for the file.
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
'Authorization': f'Bearer {self.api_key}',
}
max_iterations = 300 # 10 minutes max (2 seconds per iteration)
poll_interval = 2 # seconds
log.info(f"Polling batch status: {batch_id}")
log.info(f'Polling batch status: {batch_id}')
for iteration in range(max_iterations):
try:
response = requests.get(
f"{self.api_url}/extract-results/batch/{batch_id}",
f'{self.api_url}/extract-results/batch/{batch_id}',
headers=headers,
timeout=30,
)
response.raise_for_status()
except requests.HTTPError as e:
error_detail = f"Failed to poll batch status: {e}"
error_detail = f'Failed to poll batch status: {e}'
if e.response is not None:
try:
error_data = e.response.json()
error_detail += f" - {error_data.get('msg', error_data)}"
error_detail += f' - {error_data.get("msg", error_data)}'
except Exception:
error_detail += f" - {e.response.text}"
error_detail += f' - {e.response.text}'
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error polling batch status: {str(e)}",
detail=f'Error polling batch status: {str(e)}',
)
try:
@@ -369,58 +361,56 @@ class MinerULoader:
except ValueError as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Invalid JSON response while polling: {e}",
detail=f'Invalid JSON response while polling: {e}',
)
# Check for API error response
if result.get("code") != 0:
if result.get('code') != 0:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}",
detail=f'MinerU Cloud API error: {result.get("msg", "Unknown error")}',
)
data = result.get("data", {})
extract_result = data.get("extract_result", [])
data = result.get('data', {})
extract_result = data.get('extract_result', [])
# Find our file in the batch results
file_result = None
for item in extract_result:
if item.get("file_name") == filename:
if item.get('file_name') == filename:
file_result = item
break
if not file_result:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"File {filename} not found in batch results",
detail=f'File {filename} not found in batch results',
)
state = file_result.get("state")
state = file_result.get('state')
if state == "done":
log.info(f"Processing complete for {filename}")
if state == 'done':
log.info(f'Processing complete for {filename}')
return file_result
elif state == "failed":
error_msg = file_result.get("err_msg", "Unknown error")
elif state == 'failed':
error_msg = file_result.get('err_msg', 'Unknown error')
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"MinerU processing failed: {error_msg}",
detail=f'MinerU processing failed: {error_msg}',
)
elif state in ["waiting-file", "pending", "running", "converting"]:
elif state in ['waiting-file', 'pending', 'running', 'converting']:
# Still processing
if iteration % 10 == 0: # Log every 20 seconds
log.info(
f"Processing status: {state} (iteration {iteration + 1}/{max_iterations})"
)
log.info(f'Processing status: {state} (iteration {iteration + 1}/{max_iterations})')
time.sleep(poll_interval)
else:
log.warning(f"Unknown state: {state}")
log.warning(f'Unknown state: {state}')
time.sleep(poll_interval)
# Timeout
raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT,
detail="MinerU processing timed out after 10 minutes",
detail='MinerU processing timed out after 10 minutes',
)
def _download_and_extract_zip(self, zip_url: str, filename: str) -> str:
@@ -428,7 +418,7 @@ class MinerULoader:
Download ZIP file from CDN and extract markdown content.
Returns the markdown content as a string.
"""
log.info(f"Downloading results from: {zip_url}")
log.info(f'Downloading results from: {zip_url}')
try:
response = requests.get(zip_url, timeout=60)
@@ -436,23 +426,23 @@ class MinerULoader:
except requests.HTTPError as e:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Failed to download results ZIP: {e}",
detail=f'Failed to download results ZIP: {e}',
)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error downloading results: {str(e)}",
detail=f'Error downloading results: {str(e)}',
)
# Save ZIP to temporary file and extract
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as tmp_zip:
tmp_zip.write(response.content)
tmp_zip_path = tmp_zip.name
with tempfile.TemporaryDirectory() as tmp_dir:
# Extract ZIP
with zipfile.ZipFile(tmp_zip_path, "r") as zip_ref:
with zipfile.ZipFile(tmp_zip_path, 'r') as zip_ref:
zip_ref.extractall(tmp_dir)
# Find markdown file - search recursively for any .md file
@@ -466,33 +456,27 @@ class MinerULoader:
full_path = os.path.join(root, file)
all_files.append(full_path)
# Look for any .md file
if file.endswith(".md"):
if file.endswith('.md'):
found_md_path = full_path
log.info(f"Found markdown file at: {full_path}")
log.info(f'Found markdown file at: {full_path}')
try:
with open(full_path, "r", encoding="utf-8") as f:
with open(full_path, 'r', encoding='utf-8') as f:
markdown_content = f.read()
if (
markdown_content
): # Use the first non-empty markdown file
if markdown_content: # Use the first non-empty markdown file
break
except Exception as e:
log.warning(f"Failed to read {full_path}: {e}")
log.warning(f'Failed to read {full_path}: {e}')
if markdown_content:
break
if markdown_content is None:
log.error(f"Available files in ZIP: {all_files}")
log.error(f'Available files in ZIP: {all_files}')
# Try to provide more helpful error message
md_files = [f for f in all_files if f.endswith(".md")]
md_files = [f for f in all_files if f.endswith('.md')]
if md_files:
error_msg = (
f"Found .md files but couldn't read them: {md_files}"
)
error_msg = f"Found .md files but couldn't read them: {md_files}"
else:
error_msg = (
f"No .md files found in ZIP. Available files: {all_files}"
)
error_msg = f'No .md files found in ZIP. Available files: {all_files}'
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=error_msg,
@@ -504,21 +488,19 @@ class MinerULoader:
except zipfile.BadZipFile as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Invalid ZIP file received: {e}",
detail=f'Invalid ZIP file received: {e}',
)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error extracting ZIP: {str(e)}",
detail=f'Error extracting ZIP: {str(e)}',
)
if not markdown_content:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail="Extracted markdown content is empty",
detail='Extracted markdown content is empty',
)
log.info(
f"Successfully extracted markdown content ({len(markdown_content)} characters)"
)
log.info(f'Successfully extracted markdown content ({len(markdown_content)} characters)')
return markdown_content
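The Cloud API polling above caps at 300 iterations of 2 seconds (about 10 minutes). A generic sketch of that wait pattern, with hypothetical helper names and assuming the same 'done'/'failed' terminal states:

import time

def wait_for_state(fetch_state, timeout_s=600, interval_s=2):
    # Mirrors the 300 x 2 s loop used above; fetch_state is a caller-supplied callable.
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        state = fetch_state()
        if state == 'done':
            return True
        if state == 'failed':
            return False
        time.sleep(interval_s)  # 'pending', 'running', etc. keep waiting
    raise TimeoutError('MinerU processing timed out')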

View File

@@ -49,13 +49,11 @@ class MistralLoader:
enable_debug_logging: Enable detailed debug logs.
"""
if not api_key:
raise ValueError("API key cannot be empty.")
raise ValueError('API key cannot be empty.')
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found at {file_path}")
raise FileNotFoundError(f'File not found at {file_path}')
self.base_url = (
base_url.rstrip("/") if base_url else "https://api.mistral.ai/v1"
)
self.base_url = base_url.rstrip('/') if base_url else 'https://api.mistral.ai/v1'
self.api_key = api_key
self.file_path = file_path
self.timeout = timeout
@@ -65,18 +63,10 @@ class MistralLoader:
# PERFORMANCE OPTIMIZATION: Differentiated timeouts for different operations
# This prevents long-running OCR operations from affecting quick operations
# and improves user experience by failing fast on operations that should be quick
self.upload_timeout = min(
timeout, 120
) # Cap upload at 2 minutes - prevents hanging on large files
self.url_timeout = (
30 # URL requests should be fast - fail quickly if API is slow
)
self.ocr_timeout = (
timeout # OCR can take the full timeout - this is the heavy operation
)
self.cleanup_timeout = (
30 # Cleanup should be quick - don't hang on file deletion
)
self.upload_timeout = min(timeout, 120) # Cap upload at 2 minutes - prevents hanging on large files
self.url_timeout = 30 # URL requests should be fast - fail quickly if API is slow
self.ocr_timeout = timeout # OCR can take the full timeout - this is the heavy operation
self.cleanup_timeout = 30 # Cleanup should be quick - don't hang on file deletion
# PERFORMANCE OPTIMIZATION: Pre-compute file info to avoid repeated filesystem calls
# This avoids multiple os.path.basename() and os.path.getsize() calls during processing
@@ -85,8 +75,8 @@ class MistralLoader:
# ENHANCEMENT: Added User-Agent for better API tracking and debugging
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"User-Agent": "OpenWebUI-MistralLoader/2.0", # Helps API provider track usage
'Authorization': f'Bearer {self.api_key}',
'User-Agent': 'OpenWebUI-MistralLoader/2.0', # Helps API provider track usage
}
def _debug_log(self, message: str, *args) -> None:
@@ -108,43 +98,39 @@ class MistralLoader:
return {} # Return empty dict if no content
return response.json()
except requests.exceptions.HTTPError as http_err:
log.error(f"HTTP error occurred: {http_err} - Response: {response.text}")
log.error(f'HTTP error occurred: {http_err} - Response: {response.text}')
raise
except requests.exceptions.RequestException as req_err:
log.error(f"Request exception occurred: {req_err}")
log.error(f'Request exception occurred: {req_err}')
raise
except ValueError as json_err: # Includes JSONDecodeError
log.error(f"JSON decode error: {json_err} - Response: {response.text}")
log.error(f'JSON decode error: {json_err} - Response: {response.text}')
raise # Re-raise after logging
async def _handle_response_async(
self, response: aiohttp.ClientResponse
) -> Dict[str, Any]:
async def _handle_response_async(self, response: aiohttp.ClientResponse) -> Dict[str, Any]:
"""Async version of response handling with better error info."""
try:
response.raise_for_status()
# Check content type
content_type = response.headers.get("content-type", "")
if "application/json" not in content_type:
content_type = response.headers.get('content-type', '')
if 'application/json' not in content_type:
if response.status == 204:
return {}
text = await response.text()
raise ValueError(
f"Unexpected content type: {content_type}, body: {text[:200]}..."
)
raise ValueError(f'Unexpected content type: {content_type}, body: {text[:200]}...')
return await response.json()
except aiohttp.ClientResponseError as e:
error_text = await response.text() if response else "No response"
log.error(f"HTTP {e.status}: {e.message} - Response: {error_text[:500]}")
error_text = await response.text() if response else 'No response'
log.error(f'HTTP {e.status}: {e.message} - Response: {error_text[:500]}')
raise
except aiohttp.ClientError as e:
log.error(f"Client error: {e}")
log.error(f'Client error: {e}')
raise
except Exception as e:
log.error(f"Unexpected error processing response: {e}")
log.error(f'Unexpected error processing response: {e}')
raise
def _is_retryable_error(self, error: Exception) -> bool:
@@ -172,13 +158,11 @@ class MistralLoader:
return True # Timeouts might resolve on retry
if isinstance(error, requests.exceptions.HTTPError):
# Only retry on server errors (5xx) or rate limits (429)
if hasattr(error, "response") and error.response is not None:
if hasattr(error, 'response') and error.response is not None:
status_code = error.response.status_code
return status_code >= 500 or status_code == 429
return False
if isinstance(
error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)
):
if isinstance(error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)):
return True # Async network/timeout errors are retryable
if isinstance(error, aiohttp.ClientResponseError):
return error.status >= 500 or error.status == 429
@@ -204,8 +188,7 @@ class MistralLoader:
# Prevents overwhelming the server while ensuring reasonable retry delays
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
log.warning(
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
f"Retrying in {wait_time}s..."
f'Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...'
)
time.sleep(wait_time)
@@ -226,8 +209,7 @@ class MistralLoader:
# PERFORMANCE OPTIMIZATION: Non-blocking exponential backoff
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
log.warning(
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
f"Retrying in {wait_time}s..."
f'Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...'
)
await asyncio.sleep(wait_time) # Non-blocking wait
@@ -240,15 +222,15 @@ class MistralLoader:
Although streaming is not enabled for this endpoint, the file is opened
in a context manager to minimize memory usage duration.
"""
log.info("Uploading file to Mistral API")
url = f"{self.base_url}/files"
log.info('Uploading file to Mistral API')
url = f'{self.base_url}/files'
def upload_request():
# MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime
# This ensures the file is closed immediately after reading, reducing memory usage
with open(self.file_path, "rb") as f:
files = {"file": (self.file_name, f, "application/pdf")}
data = {"purpose": "ocr"}
with open(self.file_path, 'rb') as f:
files = {'file': (self.file_name, f, 'application/pdf')}
data = {'purpose': 'ocr'}
# NOTE: stream=False is required for this endpoint
# The Mistral API doesn't support chunked uploads for this endpoint
@@ -265,42 +247,38 @@ class MistralLoader:
try:
response_data = self._retry_request_sync(upload_request)
file_id = response_data.get("id")
file_id = response_data.get('id')
if not file_id:
raise ValueError("File ID not found in upload response.")
log.info(f"File uploaded successfully. File ID: {file_id}")
raise ValueError('File ID not found in upload response.')
log.info(f'File uploaded successfully. File ID: {file_id}')
return file_id
except Exception as e:
log.error(f"Failed to upload file: {e}")
log.error(f'Failed to upload file: {e}')
raise
async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
"""Async file upload with streaming for better memory efficiency."""
url = f"{self.base_url}/files"
url = f'{self.base_url}/files'
async def upload_request():
# Create multipart writer for streaming upload
writer = aiohttp.MultipartWriter("form-data")
writer = aiohttp.MultipartWriter('form-data')
# Add purpose field
purpose_part = writer.append("ocr")
purpose_part.set_content_disposition("form-data", name="purpose")
purpose_part = writer.append('ocr')
purpose_part.set_content_disposition('form-data', name='purpose')
# Add file part with streaming
file_part = writer.append_payload(
aiohttp.streams.FilePayload(
self.file_path,
filename=self.file_name,
content_type="application/pdf",
content_type='application/pdf',
)
)
file_part.set_content_disposition(
"form-data", name="file", filename=self.file_name
)
file_part.set_content_disposition('form-data', name='file', filename=self.file_name)
self._debug_log(
f"Uploading file: {self.file_name} ({self.file_size:,} bytes)"
)
self._debug_log(f'Uploading file: {self.file_name} ({self.file_size:,} bytes)')
async with session.post(
url,
@@ -312,48 +290,44 @@ class MistralLoader:
response_data = await self._retry_request_async(upload_request)
file_id = response_data.get("id")
file_id = response_data.get('id')
if not file_id:
raise ValueError("File ID not found in upload response.")
raise ValueError('File ID not found in upload response.')
log.info(f"File uploaded successfully. File ID: {file_id}")
log.info(f'File uploaded successfully. File ID: {file_id}')
return file_id
def _get_signed_url(self, file_id: str) -> str:
"""Retrieves a temporary signed URL for the uploaded file (sync version)."""
log.info(f"Getting signed URL for file ID: {file_id}")
url = f"{self.base_url}/files/{file_id}/url"
params = {"expiry": 1}
signed_url_headers = {**self.headers, "Accept": "application/json"}
log.info(f'Getting signed URL for file ID: {file_id}')
url = f'{self.base_url}/files/{file_id}/url'
params = {'expiry': 1}
signed_url_headers = {**self.headers, 'Accept': 'application/json'}
def url_request():
response = requests.get(
url, headers=signed_url_headers, params=params, timeout=self.url_timeout
)
response = requests.get(url, headers=signed_url_headers, params=params, timeout=self.url_timeout)
return self._handle_response(response)
try:
response_data = self._retry_request_sync(url_request)
signed_url = response_data.get("url")
signed_url = response_data.get('url')
if not signed_url:
raise ValueError("Signed URL not found in response.")
log.info("Signed URL received.")
raise ValueError('Signed URL not found in response.')
log.info('Signed URL received.')
return signed_url
except Exception as e:
log.error(f"Failed to get signed URL: {e}")
log.error(f'Failed to get signed URL: {e}')
raise
async def _get_signed_url_async(
self, session: aiohttp.ClientSession, file_id: str
) -> str:
async def _get_signed_url_async(self, session: aiohttp.ClientSession, file_id: str) -> str:
"""Async signed URL retrieval."""
url = f"{self.base_url}/files/{file_id}/url"
params = {"expiry": 1}
url = f'{self.base_url}/files/{file_id}/url'
params = {'expiry': 1}
headers = {**self.headers, "Accept": "application/json"}
headers = {**self.headers, 'Accept': 'application/json'}
async def url_request():
self._debug_log(f"Getting signed URL for file ID: {file_id}")
self._debug_log(f'Getting signed URL for file ID: {file_id}')
async with session.get(
url,
headers=headers,
@@ -364,69 +338,65 @@ class MistralLoader:
response_data = await self._retry_request_async(url_request)
signed_url = response_data.get("url")
signed_url = response_data.get('url')
if not signed_url:
raise ValueError("Signed URL not found in response.")
raise ValueError('Signed URL not found in response.')
self._debug_log("Signed URL received successfully")
self._debug_log('Signed URL received successfully')
return signed_url
def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
"""Sends the signed URL to the OCR endpoint for processing (sync version)."""
log.info("Processing OCR via Mistral API")
url = f"{self.base_url}/ocr"
log.info('Processing OCR via Mistral API')
url = f'{self.base_url}/ocr'
ocr_headers = {
**self.headers,
"Content-Type": "application/json",
"Accept": "application/json",
'Content-Type': 'application/json',
'Accept': 'application/json',
}
payload = {
"model": "mistral-ocr-latest",
"document": {
"type": "document_url",
"document_url": signed_url,
'model': 'mistral-ocr-latest',
'document': {
'type': 'document_url',
'document_url': signed_url,
},
"include_image_base64": False,
'include_image_base64': False,
}
def ocr_request():
response = requests.post(
url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout
)
response = requests.post(url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout)
return self._handle_response(response)
try:
ocr_response = self._retry_request_sync(ocr_request)
log.info("OCR processing done.")
self._debug_log("OCR response: %s", ocr_response)
log.info('OCR processing done.')
self._debug_log('OCR response: %s', ocr_response)
return ocr_response
except Exception as e:
log.error(f"Failed during OCR processing: {e}")
log.error(f'Failed during OCR processing: {e}')
raise
async def _process_ocr_async(
self, session: aiohttp.ClientSession, signed_url: str
) -> Dict[str, Any]:
async def _process_ocr_async(self, session: aiohttp.ClientSession, signed_url: str) -> Dict[str, Any]:
"""Async OCR processing with timing metrics."""
url = f"{self.base_url}/ocr"
url = f'{self.base_url}/ocr'
headers = {
**self.headers,
"Content-Type": "application/json",
"Accept": "application/json",
'Content-Type': 'application/json',
'Accept': 'application/json',
}
payload = {
"model": "mistral-ocr-latest",
"document": {
"type": "document_url",
"document_url": signed_url,
'model': 'mistral-ocr-latest',
'document': {
'type': 'document_url',
'document_url': signed_url,
},
"include_image_base64": False,
'include_image_base64': False,
}
async def ocr_request():
log.info("Starting OCR processing via Mistral API")
log.info('Starting OCR processing via Mistral API')
start_time = time.time()
async with session.post(
@@ -438,7 +408,7 @@ class MistralLoader:
ocr_response = await self._handle_response_async(response)
processing_time = time.time() - start_time
log.info(f"OCR processing completed in {processing_time:.2f}s")
log.info(f'OCR processing completed in {processing_time:.2f}s')
return ocr_response
@@ -446,42 +416,36 @@ class MistralLoader:
def _delete_file(self, file_id: str) -> None:
"""Deletes the file from Mistral storage (sync version)."""
log.info(f"Deleting uploaded file ID: {file_id}")
url = f"{self.base_url}/files/{file_id}"
log.info(f'Deleting uploaded file ID: {file_id}')
url = f'{self.base_url}/files/{file_id}'
try:
response = requests.delete(
url, headers=self.headers, timeout=self.cleanup_timeout
)
response = requests.delete(url, headers=self.headers, timeout=self.cleanup_timeout)
delete_response = self._handle_response(response)
log.info(f"File deleted successfully: {delete_response}")
log.info(f'File deleted successfully: {delete_response}')
except Exception as e:
# Log error but don't necessarily halt execution if deletion fails
log.error(f"Failed to delete file ID {file_id}: {e}")
log.error(f'Failed to delete file ID {file_id}: {e}')
async def _delete_file_async(
self, session: aiohttp.ClientSession, file_id: str
) -> None:
async def _delete_file_async(self, session: aiohttp.ClientSession, file_id: str) -> None:
"""Async file deletion with error tolerance."""
try:
async def delete_request():
self._debug_log(f"Deleting file ID: {file_id}")
self._debug_log(f'Deleting file ID: {file_id}')
async with session.delete(
url=f"{self.base_url}/files/{file_id}",
url=f'{self.base_url}/files/{file_id}',
headers=self.headers,
timeout=aiohttp.ClientTimeout(
total=self.cleanup_timeout
), # Shorter timeout for cleanup
timeout=aiohttp.ClientTimeout(total=self.cleanup_timeout), # Shorter timeout for cleanup
) as response:
return await self._handle_response_async(response)
await self._retry_request_async(delete_request)
self._debug_log(f"File {file_id} deleted successfully")
self._debug_log(f'File {file_id} deleted successfully')
except Exception as e:
# Don't fail the entire process if cleanup fails
log.warning(f"Failed to delete file ID {file_id}: {e}")
log.warning(f'Failed to delete file ID {file_id}: {e}')
@asynccontextmanager
async def _get_session(self):
@@ -506,7 +470,7 @@ class MistralLoader:
async with aiohttp.ClientSession(
connector=connector,
timeout=timeout,
headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
headers={'User-Agent': 'OpenWebUI-MistralLoader/2.0'},
raise_for_status=False, # We handle status codes manually
trust_env=True,
) as session:
@@ -514,13 +478,13 @@ class MistralLoader:
def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
"""Process OCR results into Document objects with enhanced metadata and memory efficiency."""
pages_data = ocr_response.get("pages")
pages_data = ocr_response.get('pages')
if not pages_data:
log.warning("No pages found in OCR response.")
log.warning('No pages found in OCR response.')
return [
Document(
page_content="No text content found",
metadata={"error": "no_pages", "file_name": self.file_name},
page_content='No text content found',
metadata={'error': 'no_pages', 'file_name': self.file_name},
)
]
@@ -530,8 +494,8 @@ class MistralLoader:
# Process pages in a memory-efficient way
for page_data in pages_data:
page_content = page_data.get("markdown")
page_index = page_data.get("index") # API uses 0-based index
page_content = page_data.get('markdown')
page_index = page_data.get('index') # API uses 0-based index
if page_content is None or page_index is None:
skipped_pages += 1
@@ -548,7 +512,7 @@ class MistralLoader:
if not cleaned_content:
skipped_pages += 1
self._debug_log(f"Skipping empty page {page_index}")
self._debug_log(f'Skipping empty page {page_index}')
continue
# Create document with optimized metadata
@@ -556,34 +520,30 @@ class MistralLoader:
Document(
page_content=cleaned_content,
metadata={
"page": page_index, # 0-based index from API
"page_label": page_index + 1, # 1-based label for convenience
"total_pages": total_pages,
"file_name": self.file_name,
"file_size": self.file_size,
"processing_engine": "mistral-ocr",
"content_length": len(cleaned_content),
'page': page_index, # 0-based index from API
'page_label': page_index + 1, # 1-based label for convenience
'total_pages': total_pages,
'file_name': self.file_name,
'file_size': self.file_size,
'processing_engine': 'mistral-ocr',
'content_length': len(cleaned_content),
},
)
)
if skipped_pages > 0:
log.info(
f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages"
)
log.info(f'Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages')
if not documents:
# Case where pages existed but none had valid markdown/index
log.warning(
"OCR response contained pages, but none had valid content/index."
)
log.warning('OCR response contained pages, but none had valid content/index.')
return [
Document(
page_content="No valid text content found in document",
page_content='No valid text content found in document',
metadata={
"error": "no_valid_pages",
"total_pages": total_pages,
"file_name": self.file_name,
'error': 'no_valid_pages',
'total_pages': total_pages,
'file_name': self.file_name,
},
)
]
@@ -615,24 +575,20 @@ class MistralLoader:
documents = self._process_results(ocr_response)
total_time = time.time() - start_time
log.info(
f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
)
log.info(f'Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents')
return documents
except Exception as e:
total_time = time.time() - start_time
log.error(
f"An error occurred during the loading process after {total_time:.2f}s: {e}"
)
log.error(f'An error occurred during the loading process after {total_time:.2f}s: {e}')
# Return an error document on failure
return [
Document(
page_content=f"Error during processing: {e}",
page_content=f'Error during processing: {e}',
metadata={
"error": "processing_failed",
"file_name": self.file_name,
'error': 'processing_failed',
'file_name': self.file_name,
},
)
]
@@ -643,9 +599,7 @@ class MistralLoader:
self._delete_file(file_id)
except Exception as del_e:
# Log deletion error, but don't overwrite original error if one occurred
log.error(
f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}"
)
log.error(f'Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}')
async def load_async(self) -> List[Document]:
"""
@@ -672,21 +626,19 @@ class MistralLoader:
documents = self._process_results(ocr_response)
total_time = time.time() - start_time
log.info(
f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
)
log.info(f'Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents')
return documents
except Exception as e:
total_time = time.time() - start_time
log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}")
log.error(f'Async OCR workflow failed after {total_time:.2f}s: {e}')
return [
Document(
page_content=f"Error during OCR processing: {e}",
page_content=f'Error during OCR processing: {e}',
metadata={
"error": "processing_failed",
"file_name": self.file_name,
'error': 'processing_failed',
'file_name': self.file_name,
},
)
]
@@ -697,11 +649,11 @@ class MistralLoader:
async with self._get_session() as session:
await self._delete_file_async(session, file_id)
except Exception as cleanup_error:
log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}")
log.error(f'Cleanup failed for file ID {file_id}: {cleanup_error}')
@staticmethod
async def load_multiple_async(
loaders: List["MistralLoader"],
loaders: List['MistralLoader'],
max_concurrent: int = 5, # Limit concurrent requests
) -> List[List[Document]]:
"""
@@ -717,15 +669,13 @@ class MistralLoader:
if not loaders:
return []
log.info(
f"Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent"
)
log.info(f'Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent')
start_time = time.time()
# Use semaphore to control concurrency
semaphore = asyncio.Semaphore(max_concurrent)
async def process_with_semaphore(loader: "MistralLoader") -> List[Document]:
async def process_with_semaphore(loader: 'MistralLoader') -> List[Document]:
async with semaphore:
return await loader.load_async()
@@ -737,14 +687,14 @@ class MistralLoader:
processed_results = []
for i, result in enumerate(results):
if isinstance(result, Exception):
log.error(f"File {i} failed: {result}")
log.error(f'File {i} failed: {result}')
processed_results.append(
[
Document(
page_content=f"Error processing file: {result}",
page_content=f'Error processing file: {result}',
metadata={
"error": "batch_processing_failed",
"file_index": i,
'error': 'batch_processing_failed',
'file_index': i,
},
)
]
@@ -755,15 +705,13 @@ class MistralLoader:
# MONITORING: Log comprehensive batch processing statistics
total_time = time.time() - start_time
total_docs = sum(len(docs) for docs in processed_results)
success_count = sum(
1 for result in results if not isinstance(result, Exception)
)
success_count = sum(1 for result in results if not isinstance(result, Exception))
failure_count = len(results) - success_count
log.info(
f"Batch processing completed in {total_time:.2f}s: "
f"{success_count} files succeeded, {failure_count} files failed, "
f"produced {total_docs} total documents"
f'Batch processing completed in {total_time:.2f}s: '
f'{success_count} files succeeded, {failure_count} files failed, '
f'produced {total_docs} total documents'
)
return processed_results
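As a side note, a minimal sketch of the capped exponential backoff that both retry helpers above rely on; the max_retries value of 5 is an assumption used only for illustration.

# Sketch: reproduce the retry delay schedule used above (assumed max_retries=5).
def backoff_schedule(max_retries: int = 5) -> list[float]:
    # Delay after each failed attempt: 2**attempt + 0.5 seconds, capped at 30 seconds.
    return [min((2**attempt) + 0.5, 30) for attempt in range(max_retries)]

print(backoff_schedule())  # [1.5, 2.5, 4.5, 8.5, 16.5]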

View File

@@ -25,7 +25,7 @@ class TavilyLoader(BaseLoader):
self,
urls: Union[str, List[str]],
api_key: str,
extract_depth: Literal["basic", "advanced"] = "basic",
extract_depth: Literal['basic', 'advanced'] = 'basic',
continue_on_failure: bool = True,
) -> None:
"""Initialize Tavily Extract client.
@@ -42,13 +42,13 @@ class TavilyLoader(BaseLoader):
continue_on_failure: Whether to continue if extraction of a URL fails.
"""
if not urls:
raise ValueError("At least one URL must be provided.")
raise ValueError('At least one URL must be provided.')
self.api_key = api_key
self.urls = urls if isinstance(urls, list) else [urls]
self.extract_depth = extract_depth
self.continue_on_failure = continue_on_failure
self.api_url = "https://api.tavily.com/extract"
self.api_url = 'https://api.tavily.com/extract'
def lazy_load(self) -> Iterator[Document]:
"""Extract and yield documents from the URLs using Tavily Extract API."""
@@ -57,35 +57,35 @@ class TavilyLoader(BaseLoader):
batch_urls = self.urls[i : i + batch_size]
try:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}',
}
# Use string for single URL, array for multiple URLs
urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls
payload = {"urls": urls_param, "extract_depth": self.extract_depth}
payload = {'urls': urls_param, 'extract_depth': self.extract_depth}
# Make the API call
response = requests.post(self.api_url, headers=headers, json=payload)
response.raise_for_status()
response_data = response.json()
# Process successful results
for result in response_data.get("results", []):
url = result.get("url", "")
content = result.get("raw_content", "")
for result in response_data.get('results', []):
url = result.get('url', '')
content = result.get('raw_content', '')
if not content:
log.warning(f"No content extracted from {url}")
log.warning(f'No content extracted from {url}')
continue
# Add the source URL as metadata
metadata = {"source": url}
metadata = {'source': url}
yield Document(
page_content=content,
metadata=metadata,
)
for failed in response_data.get("failed_results", []):
url = failed.get("url", "")
error = failed.get("error", "Unknown error")
log.error(f"Failed to extract content from {url}: {error}")
for failed in response_data.get('failed_results', []):
url = failed.get('url', '')
error = failed.get('error', 'Unknown error')
log.error(f'Failed to extract content from {url}: {error}')
except Exception as e:
if self.continue_on_failure:
log.error(f"Error extracting content from batch {batch_urls}: {e}")
log.error(f'Error extracting content from batch {batch_urls}: {e}')
else:
raise e
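For illustration, a small sketch of how lazy_load() batches URLs and shapes the Tavily payload; the batch_size of 20 and the example URLs are assumptions, since the real value is defined outside this hunk.

# Sketch: URL batching and payload shape for the Tavily Extract API (assumed batch_size=20).
urls = ['https://example.com/a', 'https://example.com/b']
batch_size = 20

for i in range(0, len(urls), batch_size):
    batch_urls = urls[i : i + batch_size]
    # A single URL is sent as a string, multiple URLs as an array.
    urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls
    payload = {'urls': urls_param, 'extract_depth': 'basic'}
    print(payload)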

View File

@@ -7,14 +7,14 @@ from langchain_core.documents import Document
log = logging.getLogger(__name__)
ALLOWED_SCHEMES = {"http", "https"}
ALLOWED_SCHEMES = {'http', 'https'}
ALLOWED_NETLOCS = {
"youtu.be",
"m.youtube.com",
"youtube.com",
"www.youtube.com",
"www.youtube-nocookie.com",
"vid.plus",
'youtu.be',
'm.youtube.com',
'youtube.com',
'www.youtube.com',
'www.youtube-nocookie.com',
'vid.plus',
}
@@ -30,17 +30,17 @@ def _parse_video_id(url: str) -> Optional[str]:
path = parsed_url.path
if path.endswith("/watch"):
if path.endswith('/watch'):
query = parsed_url.query
parsed_query = parse_qs(query)
if "v" in parsed_query:
ids = parsed_query["v"]
if 'v' in parsed_query:
ids = parsed_query['v']
video_id = ids if isinstance(ids, str) else ids[0]
else:
return None
else:
path = parsed_url.path.lstrip("/")
video_id = path.split("/")[-1]
path = parsed_url.path.lstrip('/')
video_id = path.split('/')[-1]
if len(video_id) != 11: # Video IDs are 11 characters long
return None
@@ -54,13 +54,13 @@ class YoutubeLoader:
def __init__(
self,
video_id: str,
language: Union[str, Sequence[str]] = "en",
language: Union[str, Sequence[str]] = 'en',
proxy_url: Optional[str] = None,
):
"""Initialize with YouTube video ID."""
_video_id = _parse_video_id(video_id)
self.video_id = _video_id if _video_id is not None else video_id
self._metadata = {"source": video_id}
self._metadata = {'source': video_id}
self.proxy_url = proxy_url
# Ensure language is a list
@@ -70,8 +70,8 @@ class YoutubeLoader:
self.language = list(language)
# Add English as fallback if not already in the list
if "en" not in self.language:
self.language.append("en")
if 'en' not in self.language:
self.language.append('en')
def load(self) -> List[Document]:
"""Load YouTube transcripts into `Document` objects."""
@@ -85,14 +85,12 @@ class YoutubeLoader:
except ImportError:
raise ImportError(
'Could not import "youtube_transcript_api" Python package. '
"Please install it with `pip install youtube-transcript-api`."
'Please install it with `pip install youtube-transcript-api`.'
)
if self.proxy_url:
youtube_proxies = GenericProxyConfig(
http_url=self.proxy_url, https_url=self.proxy_url
)
log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
youtube_proxies = GenericProxyConfig(http_url=self.proxy_url, https_url=self.proxy_url)
log.debug(f'Using proxy URL: {self.proxy_url[:14]}...')
else:
youtube_proxies = None
@@ -100,7 +98,7 @@ class YoutubeLoader:
try:
transcript_list = transcript_api.list(self.video_id)
except Exception as e:
log.exception("Loading YouTube transcript failed")
log.exception('Loading YouTube transcript failed')
return []
# Try each language in order of priority
@@ -110,14 +108,10 @@ class YoutubeLoader:
if transcript.is_generated:
log.debug(f"Found generated transcript for language '{lang}'")
try:
transcript = transcript_list.find_manually_created_transcript(
[lang]
)
transcript = transcript_list.find_manually_created_transcript([lang])
log.debug(f"Found manual transcript for language '{lang}'")
except NoTranscriptFound:
log.debug(
f"No manual transcript found for language '{lang}', using generated"
)
log.debug(f"No manual transcript found for language '{lang}', using generated")
pass
log.debug(f"Found transcript for language '{lang}'")
@@ -131,12 +125,10 @@ class YoutubeLoader:
log.debug(f"Empty transcript for language '{lang}'")
continue
transcript_text = " ".join(
transcript_text = ' '.join(
map(
lambda transcript_piece: (
transcript_piece.text.strip(" ")
if hasattr(transcript_piece, "text")
else ""
transcript_piece.text.strip(' ') if hasattr(transcript_piece, 'text') else ''
),
transcript_pieces,
)
@@ -150,9 +142,9 @@ class YoutubeLoader:
raise e
# If we get here, all languages failed
languages_tried = ", ".join(self.language)
languages_tried = ', '.join(self.language)
log.warning(
f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
f'No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed.'
)
raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))
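A hypothetical usage sketch, assuming YoutubeLoader is importable from this module; the video ID below is an 11-character placeholder, not a real video.

# Sketch: hypothetical use of YoutubeLoader with a placeholder video ID.
loader = YoutubeLoader('https://www.youtube.com/watch?v=VIDEO_ID_11', language=['de', 'fr'])
# 'en' is appended automatically as a fallback, so the effective order is ['de', 'fr', 'en'].
docs = loader.load()  # [] if listing transcripts fails; raises NoTranscriptFound if no language matches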

View File

@@ -13,19 +13,17 @@ log = logging.getLogger(__name__)
class ColBERT(BaseReranker):
def __init__(self, name, **kwargs) -> None:
log.info("ColBERT: Loading model", name)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
log.info('ColBERT: Loading model', name)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
DOCKER = kwargs.get("env") == "docker"
DOCKER = kwargs.get('env') == 'docker'
if DOCKER:
# This is a workaround for the issue with the docker container
# where the torch extension is not loaded properly
# and the following error is thrown:
# /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory
lock_file = (
"/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock"
)
lock_file = '/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock'
if os.path.exists(lock_file):
os.remove(lock_file)
@@ -36,23 +34,16 @@ class ColBERT(BaseReranker):
pass
def calculate_similarity_scores(self, query_embeddings, document_embeddings):
query_embeddings = query_embeddings.to(self.device)
document_embeddings = document_embeddings.to(self.device)
# Validate dimensions to ensure compatibility
if query_embeddings.dim() != 3:
raise ValueError(
f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
)
raise ValueError(f'Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}.')
if document_embeddings.dim() != 3:
raise ValueError(
f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
)
raise ValueError(f'Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}.')
if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
raise ValueError(
"There should be either one query or queries equal to the number of documents."
)
raise ValueError('There should be either one query or queries equal to the number of documents.')
# Transpose the query embeddings to align for matrix multiplication
transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
@@ -69,7 +60,6 @@ class ColBERT(BaseReranker):
return normalized_scores.detach().cpu().numpy().astype(np.float32)
def predict(self, sentences):
query = sentences[0][0]
docs = [i[1] for i in sentences]
@@ -80,8 +70,6 @@ class ColBERT(BaseReranker):
embedded_query = embedded_queries[0]
# Calculate retrieval scores for the query against all documents
scores = self.calculate_similarity_scores(
embedded_query.unsqueeze(0), embedded_docs
)
scores = self.calculate_similarity_scores(embedded_query.unsqueeze(0), embedded_docs)
return scores
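A brief sketch of the input format predict() expects: a list of (query, document) pairs that all share the same query. The query and documents below are made up, and the reranker construction is omitted since it requires a model name.

# Sketch: building the (query, document) pairs that ColBERT.predict() consumes.
query = 'what is retrieval augmented generation?'
docs = ['RAG combines retrieval with generation.', 'Unrelated text about cooking.']

pairs = [(query, d) for d in docs]
# scores = reranker.predict(pairs)  # NumPy float32 scores, one per document, computed on self.device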

View File

@@ -15,8 +15,8 @@ class ExternalReranker(BaseReranker):
def __init__(
self,
api_key: str,
url: str = "http://localhost:8080/v1/rerank",
model: str = "reranker",
url: str = 'http://localhost:8080/v1/rerank',
model: str = 'reranker',
timeout: Optional[int] = None,
):
self.api_key = api_key
@@ -24,33 +24,31 @@ class ExternalReranker(BaseReranker):
self.model = model
self.timeout = timeout
def predict(
self, sentences: List[Tuple[str, str]], user=None
) -> Optional[List[float]]:
def predict(self, sentences: List[Tuple[str, str]], user=None) -> Optional[List[float]]:
query = sentences[0][0]
docs = [i[1] for i in sentences]
payload = {
"model": self.model,
"query": query,
"documents": docs,
"top_n": len(docs),
'model': self.model,
'query': query,
'documents': docs,
'top_n': len(docs),
}
try:
log.info(f"ExternalReranker:predict:model {self.model}")
log.info(f"ExternalReranker:predict:query {query}")
log.info(f'ExternalReranker:predict:model {self.model}')
log.info(f'ExternalReranker:predict:query {query}')
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}',
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
r = requests.post(
f"{self.url}",
f'{self.url}',
headers=headers,
json=payload,
timeout=self.timeout,
@@ -60,13 +58,13 @@ class ExternalReranker(BaseReranker):
r.raise_for_status()
data = r.json()
if "results" in data:
sorted_results = sorted(data["results"], key=lambda x: x["index"])
return [result["relevance_score"] for result in sorted_results]
if 'results' in data:
sorted_results = sorted(data['results'], key=lambda x: x['index'])
return [result['relevance_score'] for result in sorted_results]
else:
log.error("No results found in external reranking response")
log.error('No results found in external reranking response')
return None
except Exception as e:
log.exception(f"Error in external reranking: {e}")
log.exception(f'Error in external reranking: {e}')
return None
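A minimal sketch of how the 'results' payload from the external rerank endpoint is mapped back to per-document scores; the response below is hypothetical.

# Sketch: re-ordering a hypothetical /v1/rerank response into document order.
data = {
    'results': [
        {'index': 1, 'relevance_score': 0.12},
        {'index': 0, 'relevance_score': 0.87},
    ]
}

sorted_results = sorted(data['results'], key=lambda x: x['index'])
scores = [result['relevance_score'] for result in sorted_results]
print(scores)  # [0.87, 0.12], back in the original document order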

File diff suppressed because it is too large

View File

@@ -31,17 +31,15 @@ log = logging.getLogger(__name__)
class ChromaClient(VectorDBBase):
def __init__(self):
settings_dict = {
"allow_reset": True,
"anonymized_telemetry": False,
'allow_reset': True,
'anonymized_telemetry': False,
}
if CHROMA_CLIENT_AUTH_PROVIDER is not None:
settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
settings_dict['chroma_client_auth_provider'] = CHROMA_CLIENT_AUTH_PROVIDER
if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
settings_dict["chroma_client_auth_credentials"] = (
CHROMA_CLIENT_AUTH_CREDENTIALS
)
settings_dict['chroma_client_auth_credentials'] = CHROMA_CLIENT_AUTH_CREDENTIALS
if CHROMA_HTTP_HOST != "":
if CHROMA_HTTP_HOST != '':
self.client = chromadb.HttpClient(
host=CHROMA_HTTP_HOST,
port=CHROMA_HTTP_PORT,
@@ -87,25 +85,23 @@ class ChromaClient(VectorDBBase):
# chromadb has cosine distance, 2 (worst) -> 0 (best). Re-ordering to 0 -> 1
# https://docs.trychroma.com/docs/collections/configure cosine equation
distances: list = result["distances"][0]
distances: list = result['distances'][0]
distances = [2 - dist for dist in distances]
distances = [[dist / 2 for dist in distances]]
return SearchResult(
**{
"ids": result["ids"],
"distances": distances,
"documents": result["documents"],
"metadatas": result["metadatas"],
'ids': result['ids'],
'distances': distances,
'documents': result['documents'],
'metadatas': result['metadatas'],
}
)
return None
except Exception as e:
return None
def query(
self, collection_name: str, filter: dict, limit: Optional[int] = None
) -> Optional[GetResult]:
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) -> Optional[GetResult]:
# Query the items from the collection based on the filter.
try:
collection = self.client.get_collection(name=collection_name)
@@ -117,9 +113,9 @@ class ChromaClient(VectorDBBase):
return GetResult(
**{
"ids": [result["ids"]],
"documents": [result["documents"]],
"metadatas": [result["metadatas"]],
'ids': [result['ids']],
'documents': [result['documents']],
'metadatas': [result['metadatas']],
}
)
return None
@@ -133,23 +129,21 @@ class ChromaClient(VectorDBBase):
result = collection.get()
return GetResult(
**{
"ids": [result["ids"]],
"documents": [result["documents"]],
"metadatas": [result["metadatas"]],
'ids': [result['ids']],
'documents': [result['documents']],
'metadatas': [result['metadatas']],
}
)
return None
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection; if the collection does not exist, it will be created.
collection = self.client.get_or_create_collection(
name=collection_name, metadata={"hnsw:space": "cosine"}
)
collection = self.client.get_or_create_collection(name=collection_name, metadata={'hnsw:space': 'cosine'})
ids = [item["id"] for item in items]
documents = [item["text"] for item in items]
embeddings = [item["vector"] for item in items]
metadatas = [process_metadata(item["metadata"]) for item in items]
ids = [item['id'] for item in items]
documents = [item['text'] for item in items]
embeddings = [item['vector'] for item in items]
metadatas = [process_metadata(item['metadata']) for item in items]
for batch in create_batches(
api=self.client,
@@ -162,18 +156,14 @@ class ChromaClient(VectorDBBase):
def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection; if an item is not present, insert it. If the collection does not exist, it will be created.
collection = self.client.get_or_create_collection(
name=collection_name, metadata={"hnsw:space": "cosine"}
)
collection = self.client.get_or_create_collection(name=collection_name, metadata={'hnsw:space': 'cosine'})
ids = [item["id"] for item in items]
documents = [item["text"] for item in items]
embeddings = [item["vector"] for item in items]
metadatas = [process_metadata(item["metadata"]) for item in items]
ids = [item['id'] for item in items]
documents = [item['text'] for item in items]
embeddings = [item['vector'] for item in items]
metadatas = [process_metadata(item['metadata']) for item in items]
collection.upsert(
ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
)
collection.upsert(ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas)
def delete(
self,
@@ -191,9 +181,7 @@ class ChromaClient(VectorDBBase):
collection.delete(where=filter)
except Exception as e:
# If collection doesn't exist, that's fine - nothing to delete
log.debug(
f"Attempted to delete from non-existent collection {collection_name}. Ignoring."
)
log.debug(f'Attempted to delete from non-existent collection {collection_name}. Ignoring.')
pass
def reset(self):
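A small worked example of the cosine-distance rescaling performed in search() above, where chromadb distances in [0, 2] become similarities in [0, 1].

# Sketch: the cosine-distance rescaling used in ChromaClient.search().
distances = [0.0, 0.4, 2.0]                     # chromadb cosine distance: 0 (best) .. 2 (worst)
distances = [2 - dist for dist in distances]    # -> 2 (best) .. 0 (worst)
distances = [[dist / 2 for dist in distances]]  # -> 1.0 (best) .. 0.0 (worst), nested like the result
print(distances)  # [[1.0, 0.8, 0.0]]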

View File

@@ -51,7 +51,7 @@ class ElasticsearchClient(VectorDBBase):
# Status: works
def _get_index_name(self, dimension: int) -> str:
return f"{self.index_prefix}_d{str(dimension)}"
return f'{self.index_prefix}_d{str(dimension)}'
# Status: works
def _scan_result_to_get_result(self, result) -> GetResult:
@@ -62,24 +62,24 @@ class ElasticsearchClient(VectorDBBase):
metadatas = []
for hit in result:
ids.append(hit["_id"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
ids.append(hit['_id'])
documents.append(hit['_source'].get('text'))
metadatas.append(hit['_source'].get('metadata'))
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
# Status: works
def _result_to_get_result(self, result) -> GetResult:
if not result["hits"]["hits"]:
if not result['hits']['hits']:
return None
ids = []
documents = []
metadatas = []
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
for hit in result['hits']['hits']:
ids.append(hit['_id'])
documents.append(hit['_source'].get('text'))
metadatas.append(hit['_source'].get('metadata'))
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
@@ -90,11 +90,11 @@ class ElasticsearchClient(VectorDBBase):
documents = []
metadatas = []
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])
distances.append(hit["_score"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
for hit in result['hits']['hits']:
ids.append(hit['_id'])
distances.append(hit['_score'])
documents.append(hit['_source'].get('text'))
metadatas.append(hit['_source'].get('metadata'))
return SearchResult(
ids=[ids],
@@ -106,26 +106,26 @@ class ElasticsearchClient(VectorDBBase):
# Status: works
def _create_index(self, dimension: int):
body = {
"mappings": {
"dynamic_templates": [
'mappings': {
'dynamic_templates': [
{
"strings": {
"match_mapping_type": "string",
"mapping": {"type": "keyword"},
'strings': {
'match_mapping_type': 'string',
'mapping': {'type': 'keyword'},
}
}
],
"properties": {
"collection": {"type": "keyword"},
"id": {"type": "keyword"},
"vector": {
"type": "dense_vector",
"dims": dimension, # Adjust based on your vector dimensions
"index": True,
"similarity": "cosine",
'properties': {
'collection': {'type': 'keyword'},
'id': {'type': 'keyword'},
'vector': {
'type': 'dense_vector',
'dims': dimension, # Adjust based on your vector dimensions
'index': True,
'similarity': 'cosine',
},
"text": {"type": "text"},
"metadata": {"type": "object"},
'text': {'type': 'text'},
'metadata': {'type': 'object'},
},
}
}
@@ -139,21 +139,19 @@ class ElasticsearchClient(VectorDBBase):
# Status: works
def has_collection(self, collection_name) -> bool:
query_body = {"query": {"bool": {"filter": []}}}
query_body["query"]["bool"]["filter"].append(
{"term": {"collection": collection_name}}
)
query_body = {'query': {'bool': {'filter': []}}}
query_body['query']['bool']['filter'].append({'term': {'collection': collection_name}})
try:
result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
result = self.client.count(index=f'{self.index_prefix}*', body=query_body)
return result.body["count"] > 0
return result.body['count'] > 0
except Exception as e:
return None
def delete_collection(self, collection_name: str):
query = {"query": {"term": {"collection": collection_name}}}
self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
query = {'query': {'term': {'collection': collection_name}}}
self.client.delete_by_query(index=f'{self.index_prefix}*', body=query)
# Status: works
def search(
@@ -164,51 +162,41 @@ class ElasticsearchClient(VectorDBBase):
limit: int = 10,
) -> Optional[SearchResult]:
query = {
"size": limit,
"_source": ["text", "metadata"],
"query": {
"script_score": {
"query": {
"bool": {"filter": [{"term": {"collection": collection_name}}]}
},
"script": {
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
"params": {
"vector": vectors[0]
}, # Assuming single query vector
'size': limit,
'_source': ['text', 'metadata'],
'query': {
'script_score': {
'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}},
'script': {
'source': "cosineSimilarity(params.vector, 'vector') + 1.0",
'params': {'vector': vectors[0]}, # Assuming single query vector
},
}
},
}
result = self.client.search(
index=self._get_index_name(len(vectors[0])), body=query
)
result = self.client.search(index=self._get_index_name(len(vectors[0])), body=query)
return self._result_to_search_result(result)
# Status: only tested halfway
def query(
self, collection_name: str, filter: dict, limit: Optional[int] = None
) -> Optional[GetResult]:
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) -> Optional[GetResult]:
if not self.has_collection(collection_name):
return None
query_body = {
"query": {"bool": {"filter": []}},
"_source": ["text", "metadata"],
'query': {'bool': {'filter': []}},
'_source': ['text', 'metadata'],
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append({"term": {field: value}})
query_body["query"]["bool"]["filter"].append(
{"term": {"collection": collection_name}}
)
query_body['query']['bool']['filter'].append({'term': {field: value}})
query_body['query']['bool']['filter'].append({'term': {'collection': collection_name}})
size = limit if limit else 10
try:
result = self.client.search(
index=f"{self.index_prefix}*",
index=f'{self.index_prefix}*',
body=query_body,
size=size,
)
@@ -220,9 +208,7 @@ class ElasticsearchClient(VectorDBBase):
# Status: works
def _has_index(self, dimension: int):
return self.client.indices.exists(
index=self._get_index_name(dimension=dimension)
)
return self.client.indices.exists(index=self._get_index_name(dimension=dimension))
def get_or_create_index(self, dimension: int):
if not self._has_index(dimension=dimension):
@@ -232,28 +218,28 @@ class ElasticsearchClient(VectorDBBase):
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
query = {
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
"_source": ["text", "metadata"],
'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}},
'_source': ['text', 'metadata'],
}
results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
results = list(scan(self.client, index=f'{self.index_prefix}*', query=query))
return self._scan_result_to_get_result(results)
# Status: works
def insert(self, collection_name: str, items: list[VectorItem]):
if not self._has_index(dimension=len(items[0]["vector"])):
self._create_index(dimension=len(items[0]["vector"]))
if not self._has_index(dimension=len(items[0]['vector'])):
self._create_index(dimension=len(items[0]['vector']))
for batch in self._create_batches(items):
actions = [
{
"_index": self._get_index_name(dimension=len(items[0]["vector"])),
"_id": item["id"],
"_source": {
"collection": collection_name,
"vector": item["vector"],
"text": item["text"],
"metadata": process_metadata(item["metadata"]),
'_index': self._get_index_name(dimension=len(items[0]['vector'])),
'_id': item['id'],
'_source': {
'collection': collection_name,
'vector': item['vector'],
'text': item['text'],
'metadata': process_metadata(item['metadata']),
},
}
for item in batch
@@ -262,21 +248,21 @@ class ElasticsearchClient(VectorDBBase):
# Upsert documents using the update API with doc_as_upsert=True.
def upsert(self, collection_name: str, items: list[VectorItem]):
if not self._has_index(dimension=len(items[0]["vector"])):
self._create_index(dimension=len(items[0]["vector"]))
if not self._has_index(dimension=len(items[0]['vector'])):
self._create_index(dimension=len(items[0]['vector']))
for batch in self._create_batches(items):
actions = [
{
"_op_type": "update",
"_index": self._get_index_name(dimension=len(item["vector"])),
"_id": item["id"],
"doc": {
"collection": collection_name,
"vector": item["vector"],
"text": item["text"],
"metadata": process_metadata(item["metadata"]),
'_op_type': 'update',
'_index': self._get_index_name(dimension=len(item['vector'])),
'_id': item['id'],
'doc': {
'collection': collection_name,
'vector': item['vector'],
'text': item['text'],
'metadata': process_metadata(item['metadata']),
},
"doc_as_upsert": True,
'doc_as_upsert': True,
}
for item in batch
]
@@ -289,22 +275,17 @@ class ElasticsearchClient(VectorDBBase):
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
query = {
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}
}
query = {'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}}}
# logic based on chromaDB
if ids:
query["query"]["bool"]["filter"].append({"terms": {"_id": ids}})
query['query']['bool']['filter'].append({'terms': {'_id': ids}})
elif filter:
for field, value in filter.items():
query["query"]["bool"]["filter"].append(
{"term": {f"metadata.{field}": value}}
)
query['query']['bool']['filter'].append({'term': {f'metadata.{field}': value}})
self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
self.client.delete_by_query(index=f'{self.index_prefix}*', body=query)
def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}*")
indices = self.client.indices.get(index=f'{self.index_prefix}*')
for index in indices:
self.client.indices.delete(index=index)
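For reference, the shape of the script_score query body that search() builds, shown with a toy 3-dimensional vector; real queries use vectors matching the index dimension, and the index name comes from the configured prefix.

# Sketch: the script_score body assembled by ElasticsearchClient.search(), toy 3-dim vector.
collection_name = 'my_collection'
vector = [0.1, 0.2, 0.3]  # real queries use the embedding dimension of the target index

query = {
    'size': 10,
    '_source': ['text', 'metadata'],
    'query': {
        'script_score': {
            'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}},
            'script': {
                # cosineSimilarity() returns [-1, 1]; the +1.0 keeps the score non-negative
                'source': "cosineSimilarity(params.vector, 'vector') + 1.0",
                'params': {'vector': vector},
            },
        }
    },
}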

View File

@@ -47,8 +47,8 @@ def _embedding_to_f32_bytes(vec: List[float]) -> bytes:
byte sequence. We use array('f') to avoid a numpy dependency and byteswap on
big-endian platforms for portability.
"""
a = array.array("f", [float(x) for x in vec]) # float32
if sys.byteorder != "little":
a = array.array('f', [float(x) for x in vec]) # float32
if sys.byteorder != 'little':
a.byteswap()
return a.tobytes()
@@ -68,7 +68,7 @@ def _safe_json(v: Any) -> Dict[str, Any]:
return v
if isinstance(v, (bytes, bytearray)):
try:
v = v.decode("utf-8")
v = v.decode('utf-8')
except Exception:
return {}
if isinstance(v, str):
@@ -105,16 +105,16 @@ class MariaDBVectorClient(VectorDBBase):
"""
self.db_url = (db_url or MARIADB_VECTOR_DB_URL).strip()
self.vector_length = int(vector_length)
self.distance_strategy = (distance_strategy or "cosine").strip().lower()
self.distance_strategy = (distance_strategy or 'cosine').strip().lower()
self.index_m = int(index_m)
if self.distance_strategy not in {"cosine", "euclidean"}:
if self.distance_strategy not in {'cosine', 'euclidean'}:
raise ValueError("distance_strategy must be 'cosine' or 'euclidean'")
if not self.db_url.lower().startswith("mariadb+mariadbconnector://"):
if not self.db_url.lower().startswith('mariadb+mariadbconnector://'):
raise ValueError(
"MariaDBVectorClient requires mariadb+mariadbconnector:// (official MariaDB driver) "
"to ensure qmark paramstyle and correct VECTOR binding."
'MariaDBVectorClient requires mariadb+mariadbconnector:// (official MariaDB driver) '
'to ensure qmark paramstyle and correct VECTOR binding.'
)
if isinstance(MARIADB_VECTOR_POOL_SIZE, int):
@@ -129,9 +129,7 @@ class MariaDBVectorClient(VectorDBBase):
poolclass=QueuePool,
)
else:
self.engine = create_engine(
self.db_url, pool_pre_ping=True, poolclass=NullPool
)
self.engine = create_engine(self.db_url, pool_pre_ping=True, poolclass=NullPool)
else:
self.engine = create_engine(self.db_url, pool_pre_ping=True)
self._init_schema()
@@ -185,7 +183,7 @@ class MariaDBVectorClient(VectorDBBase):
conn.commit()
except Exception as e:
conn.rollback()
log.exception(f"Error during database initialization: {e}")
log.exception(f'Error during database initialization: {e}')
raise
def _check_vector_length(self) -> None:
@@ -197,19 +195,19 @@ class MariaDBVectorClient(VectorDBBase):
"""
with self._connect() as conn:
with conn.cursor() as cur:
cur.execute("SHOW CREATE TABLE document_chunk")
cur.execute('SHOW CREATE TABLE document_chunk')
row = cur.fetchone()
if not row or len(row) < 2:
return
ddl = row[1]
m = re.search(r"vector\\((\\d+)\\)", ddl, flags=re.IGNORECASE)
m = re.search(r'vector\\((\\d+)\\)', ddl, flags=re.IGNORECASE)
if not m:
return
existing = int(m.group(1))
if existing != int(self.vector_length):
raise Exception(
f"VECTOR_LENGTH {self.vector_length} does not match existing vector column dimension {existing}. "
"Cannot change vector size after initialization without migrating the data."
f'VECTOR_LENGTH {self.vector_length} does not match existing vector column dimension {existing}. '
'Cannot change vector size after initialization without migrating the data.'
)
def adjust_vector_length(self, vector: List[float]) -> List[float]:
@@ -227,11 +225,7 @@ class MariaDBVectorClient(VectorDBBase):
"""
Return the MariaDB Vector distance function name for the configured strategy.
"""
return (
"vec_distance_cosine"
if self.distance_strategy == "cosine"
else "vec_distance_euclidean"
)
return 'vec_distance_cosine' if self.distance_strategy == 'cosine' else 'vec_distance_euclidean'
def _score_from_dist(self, dist: float) -> float:
"""
@@ -240,7 +234,7 @@ class MariaDBVectorClient(VectorDBBase):
- cosine: score ~= 1 - cosine_distance, clamped to [0, 1]
- euclidean: score = 1 / (1 + dist)
"""
if self.distance_strategy == "cosine":
if self.distance_strategy == 'cosine':
score = 1.0 - dist
if score < 0.0:
score = 0.0
@@ -260,48 +254,48 @@ class MariaDBVectorClient(VectorDBBase):
- {"$or": [ ... ]}
"""
if not expr or not isinstance(expr, dict):
return "", []
return '', []
if "$and" in expr:
if '$and' in expr:
parts: List[str] = []
params: List[Any] = []
for e in expr.get("$and") or []:
for e in expr.get('$and') or []:
s, p = self._build_filter_sql_qmark(e)
if s:
parts.append(s)
params.extend(p)
return ("(" + " AND ".join(parts) + ")") if parts else "", params
return ('(' + ' AND '.join(parts) + ')') if parts else '', params
if "$or" in expr:
if '$or' in expr:
parts: List[str] = []
params: List[Any] = []
for e in expr.get("$or") or []:
for e in expr.get('$or') or []:
s, p = self._build_filter_sql_qmark(e)
if s:
parts.append(s)
params.extend(p)
return ("(" + " OR ".join(parts) + ")") if parts else "", params
return ('(' + ' OR '.join(parts) + ')') if parts else '', params
clauses: List[str] = []
params: List[Any] = []
for key, value in expr.items():
if key.startswith("$"):
if key.startswith('$'):
continue
json_expr = f"JSON_UNQUOTE(JSON_EXTRACT(vmetadata, '$.{key}'))"
if isinstance(value, dict) and "$in" in value:
vals = [str(v) for v in (value.get("$in") or [])]
if isinstance(value, dict) and '$in' in value:
vals = [str(v) for v in (value.get('$in') or [])]
if not vals:
clauses.append("0=1")
clauses.append('0=1')
continue
ors = []
for v in vals:
ors.append(f"{json_expr} = ?")
ors.append(f'{json_expr} = ?')
params.append(v)
clauses.append("(" + " OR ".join(ors) + ")")
clauses.append('(' + ' OR '.join(ors) + ')')
else:
clauses.append(f"{json_expr} = ?")
clauses.append(f'{json_expr} = ?')
params.append(str(value))
return ("(" + " AND ".join(clauses) + ")") if clauses else "", params
return ('(' + ' AND '.join(clauses) + ')') if clauses else '', params
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
"""
@@ -322,15 +316,15 @@ class MariaDBVectorClient(VectorDBBase):
"""
params: List[Tuple[Any, ...]] = []
for item in items:
v = self.adjust_vector_length(item["vector"])
v = self.adjust_vector_length(item['vector'])
emb = _embedding_to_f32_bytes(v)
meta = process_metadata(item.get("metadata") or {})
meta = process_metadata(item.get('metadata') or {})
params.append(
(
item["id"],
item['id'],
emb,
collection_name,
item.get("text"),
item.get('text'),
json.dumps(meta),
)
)
@@ -338,7 +332,7 @@ class MariaDBVectorClient(VectorDBBase):
conn.commit()
except Exception as e:
conn.rollback()
log.exception(f"Error during insert: {e}")
log.exception(f'Error during insert: {e}')
raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -365,15 +359,15 @@ class MariaDBVectorClient(VectorDBBase):
"""
params: List[Tuple[Any, ...]] = []
for item in items:
v = self.adjust_vector_length(item["vector"])
v = self.adjust_vector_length(item['vector'])
emb = _embedding_to_f32_bytes(v)
meta = process_metadata(item.get("metadata") or {})
meta = process_metadata(item.get('metadata') or {})
params.append(
(
item["id"],
item['id'],
emb,
collection_name,
item.get("text"),
item.get('text'),
json.dumps(meta),
)
)
@@ -381,7 +375,7 @@ class MariaDBVectorClient(VectorDBBase):
conn.commit()
except Exception as e:
conn.rollback()
log.exception(f"Error during upsert: {e}")
log.exception(f'Error during upsert: {e}')
raise
def search(
@@ -415,10 +409,10 @@ class MariaDBVectorClient(VectorDBBase):
with self._connect() as conn:
with conn.cursor() as cur:
fsql, fparams = self._build_filter_sql_qmark(filter or {})
where = "collection_name = ?"
where = 'collection_name = ?'
base_params: List[Any] = [collection_name]
if fsql:
where = where + " AND " + fsql
where = where + ' AND ' + fsql
base_params.extend(fparams)
sql = f"""
@@ -460,26 +454,24 @@ class MariaDBVectorClient(VectorDBBase):
metadatas=metadatas,
)
except Exception as e:
log.exception(f"[MARIADB_VECTOR] search() failed: {e}")
log.exception(f'[MARIADB_VECTOR] search() failed: {e}')
return None
def query(
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
) -> Optional[GetResult]:
def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
"""
Retrieve documents by metadata filter (non-vector query).
"""
with self._connect() as conn:
with conn.cursor() as cur:
fsql, fparams = self._build_filter_sql_qmark(filter or {})
where = "collection_name = ?"
where = 'collection_name = ?'
params: List[Any] = [collection_name]
if fsql:
where = where + " AND " + fsql
where = where + ' AND ' + fsql
params.extend(fparams)
sql = f"SELECT id, text, vmetadata FROM document_chunk WHERE {where}"
sql = f'SELECT id, text, vmetadata FROM document_chunk WHERE {where}'
if limit is not None:
sql += " LIMIT ?"
sql += ' LIMIT ?'
params.append(int(limit))
cur.execute(sql, params)
rows = cur.fetchall()
@@ -490,18 +482,16 @@ class MariaDBVectorClient(VectorDBBase):
metadatas = [[_safe_json(r[2]) for r in rows]]
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
def get(
self, collection_name: str, limit: Optional[int] = None
) -> Optional[GetResult]:
def get(self, collection_name: str, limit: Optional[int] = None) -> Optional[GetResult]:
"""
Retrieve documents in a collection without filtering (optionally limited).
"""
with self._connect() as conn:
with conn.cursor() as cur:
sql = "SELECT id, text, vmetadata FROM document_chunk WHERE collection_name = ?"
sql = 'SELECT id, text, vmetadata FROM document_chunk WHERE collection_name = ?'
params: List[Any] = [collection_name]
if limit is not None:
sql += " LIMIT ?"
sql += ' LIMIT ?'
params.append(int(limit))
cur.execute(sql, params)
rows = cur.fetchall()
@@ -526,12 +516,12 @@ class MariaDBVectorClient(VectorDBBase):
with self._connect() as conn:
with conn.cursor() as cur:
try:
where = ["collection_name = ?"]
where = ['collection_name = ?']
params: List[Any] = [collection_name]
if ids:
ph = ", ".join(["?"] * len(ids))
where.append(f"id IN ({ph})")
ph = ', '.join(['?'] * len(ids))
where.append(f'id IN ({ph})')
params.extend(ids)
if filter:
@@ -540,12 +530,12 @@ class MariaDBVectorClient(VectorDBBase):
where.append(fsql)
params.extend(fparams)
sql = "DELETE FROM document_chunk WHERE " + " AND ".join(where)
sql = 'DELETE FROM document_chunk WHERE ' + ' AND '.join(where)
cur.execute(sql, params)
conn.commit()
except Exception as e:
conn.rollback()
log.exception(f"Error during delete: {e}")
log.exception(f'Error during delete: {e}')
raise
def reset(self) -> None:
@@ -555,11 +545,11 @@ class MariaDBVectorClient(VectorDBBase):
with self._connect() as conn:
with conn.cursor() as cur:
try:
cur.execute("TRUNCATE TABLE document_chunk")
cur.execute('TRUNCATE TABLE document_chunk')
conn.commit()
except Exception as e:
conn.rollback()
log.exception(f"Error during reset: {e}")
log.exception(f'Error during reset: {e}')
raise
def has_collection(self, collection_name: str) -> bool:
@@ -570,7 +560,7 @@ class MariaDBVectorClient(VectorDBBase):
with self._connect() as conn:
with conn.cursor() as cur:
cur.execute(
"SELECT 1 FROM document_chunk WHERE collection_name = ? LIMIT 1",
'SELECT 1 FROM document_chunk WHERE collection_name = ? LIMIT 1',
(collection_name,),
)
return cur.fetchone() is not None
@@ -590,4 +580,4 @@ class MariaDBVectorClient(VectorDBBase):
try:
self.engine.dispose()
except Exception as e:
log.exception(f"Error during dispose the underlying SQLAlchemy engine: {e}")
log.exception(f'Error during dispose the underlying SQLAlchemy engine: {e}')
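A hand-traced sketch of what _build_filter_sql_qmark() produces for a hypothetical metadata filter; the SQL fragment and params shown in the comments follow from the code above rather than from running it.

# Sketch: hypothetical metadata filter translated into a qmark WHERE fragment.
filter = {'$and': [{'file_id': 'abc'}, {'tag': {'$in': ['a', 'b']}}]}

# sql, params = client._build_filter_sql_qmark(filter)
# sql    -> "((JSON_UNQUOTE(JSON_EXTRACT(vmetadata, '$.file_id')) = ?) AND
#             ((JSON_UNQUOTE(JSON_EXTRACT(vmetadata, '$.tag')) = ? OR
#               JSON_UNQUOTE(JSON_EXTRACT(vmetadata, '$.tag')) = ?)))"
# params -> ['abc', 'a', 'b']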

View File

@@ -35,7 +35,7 @@ log = logging.getLogger(__name__)
class MilvusClient(VectorDBBase):
def __init__(self):
self.collection_prefix = "open_webui"
self.collection_prefix = 'open_webui'
if MILVUS_TOKEN is None:
self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
else:
@@ -50,17 +50,17 @@ class MilvusClient(VectorDBBase):
_documents = []
_metadatas = []
for item in match:
_ids.append(item.get("id"))
_documents.append(item.get("data", {}).get("text"))
_metadatas.append(item.get("metadata"))
_ids.append(item.get('id'))
_documents.append(item.get('data', {}).get('text'))
_metadatas.append(item.get('metadata'))
ids.append(_ids)
documents.append(_documents)
metadatas.append(_metadatas)
return GetResult(
**{
"ids": ids,
"documents": documents,
"metadatas": metadatas,
'ids': ids,
'documents': documents,
'metadatas': metadatas,
}
)
@@ -75,23 +75,23 @@ class MilvusClient(VectorDBBase):
_documents = []
_metadatas = []
for item in match:
_ids.append(item.get("id"))
_ids.append(item.get('id'))
# normalize milvus score from [-1, 1] to [0, 1] range
# https://milvus.io/docs/de/metric.md
_dist = (item.get("distance") + 1.0) / 2.0
_dist = (item.get('distance') + 1.0) / 2.0
_distances.append(_dist)
_documents.append(item.get("entity", {}).get("data", {}).get("text"))
_metadatas.append(item.get("entity", {}).get("metadata"))
_documents.append(item.get('entity', {}).get('data', {}).get('text'))
_metadatas.append(item.get('entity', {}).get('metadata'))
ids.append(_ids)
distances.append(_distances)
documents.append(_documents)
metadatas.append(_metadatas)
return SearchResult(
**{
"ids": ids,
"distances": distances,
"documents": documents,
"metadatas": metadatas,
'ids': ids,
'distances': distances,
'documents': documents,
'metadatas': metadatas,
}
)
@@ -101,21 +101,19 @@ class MilvusClient(VectorDBBase):
enable_dynamic_field=True,
)
schema.add_field(
field_name="id",
field_name='id',
datatype=DataType.VARCHAR,
is_primary=True,
max_length=65535,
)
schema.add_field(
field_name="vector",
field_name='vector',
datatype=DataType.FLOAT_VECTOR,
dim=dimension,
description="vector",
)
schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
schema.add_field(
field_name="metadata", datatype=DataType.JSON, description="metadata"
description='vector',
)
schema.add_field(field_name='data', datatype=DataType.JSON, description='data')
schema.add_field(field_name='metadata', datatype=DataType.JSON, description='metadata')
index_params = self.client.prepare_index_params()
@@ -123,44 +121,44 @@ class MilvusClient(VectorDBBase):
index_type = MILVUS_INDEX_TYPE.upper()
metric_type = MILVUS_METRIC_TYPE.upper()
log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}")
log.info(f'Using Milvus index type: {index_type}, metric type: {metric_type}')
index_creation_params = {}
if index_type == "HNSW":
if index_type == 'HNSW':
index_creation_params = {
"M": MILVUS_HNSW_M,
"efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
'M': MILVUS_HNSW_M,
'efConstruction': MILVUS_HNSW_EFCONSTRUCTION,
}
log.info(f"HNSW params: {index_creation_params}")
elif index_type == "IVF_FLAT":
index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
log.info(f"IVF_FLAT params: {index_creation_params}")
elif index_type == "DISKANN":
log.info(f'HNSW params: {index_creation_params}')
elif index_type == 'IVF_FLAT':
index_creation_params = {'nlist': MILVUS_IVF_FLAT_NLIST}
log.info(f'IVF_FLAT params: {index_creation_params}')
elif index_type == 'DISKANN':
index_creation_params = {
"max_degree": MILVUS_DISKANN_MAX_DEGREE,
"search_list_size": MILVUS_DISKANN_SEARCH_LIST_SIZE,
'max_degree': MILVUS_DISKANN_MAX_DEGREE,
'search_list_size': MILVUS_DISKANN_SEARCH_LIST_SIZE,
}
log.info(f"DISKANN params: {index_creation_params}")
elif index_type in ["FLAT", "AUTOINDEX"]:
log.info(f"Using {index_type} index with no specific build-time params.")
log.info(f'DISKANN params: {index_creation_params}')
elif index_type in ['FLAT', 'AUTOINDEX']:
log.info(f'Using {index_type} index with no specific build-time params.')
else:
log.warning(
f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
f"Supported types: HNSW, IVF_FLAT, DISKANN, FLAT, AUTOINDEX. "
f"Milvus will use its default for the collection if this type is not directly supported for index creation."
f'Supported types: HNSW, IVF_FLAT, DISKANN, FLAT, AUTOINDEX. '
f'Milvus will use its default for the collection if this type is not directly supported for index creation.'
)
# For unsupported types, pass the type directly to Milvus; it might handle it or use a default.
# If Milvus errors out, the user needs to correct the MILVUS_INDEX_TYPE env var.
index_params.add_index(
field_name="vector",
field_name='vector',
index_type=index_type,
metric_type=metric_type,
params=index_creation_params,
)
self.client.create_collection(
collection_name=f"{self.collection_prefix}_{collection_name}",
collection_name=f'{self.collection_prefix}_{collection_name}',
schema=schema,
index_params=index_params,
)
@@ -170,17 +168,13 @@ class MilvusClient(VectorDBBase):
def has_collection(self, collection_name: str) -> bool:
# Check if the collection exists based on the collection name.
collection_name = collection_name.replace("-", "_")
return self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
)
collection_name = collection_name.replace('-', '_')
return self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}')
def delete_collection(self, collection_name: str):
# Delete the collection based on the collection name.
collection_name = collection_name.replace("-", "_")
return self.client.drop_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
)
collection_name = collection_name.replace('-', '_')
return self.client.drop_collection(collection_name=f'{self.collection_prefix}_{collection_name}')
def search(
self,
@@ -190,15 +184,15 @@ class MilvusClient(VectorDBBase):
limit: int = 10,
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
collection_name = collection_name.replace("-", "_")
collection_name = collection_name.replace('-', '_')
# For some index types like IVF_FLAT, search params like nprobe can be set.
# Example: search_params = {"nprobe": 10} if using IVF_FLAT
# For simplicity, not adding configurable search_params here, but could be extended.
result = self.client.search(
collection_name=f"{self.collection_prefix}_{collection_name}",
collection_name=f'{self.collection_prefix}_{collection_name}',
data=vectors,
limit=limit,
output_fields=["data", "metadata"],
output_fields=['data', 'metadata'],
# search_params=search_params # Potentially add later if needed
)
return self._result_to_search_result(result)
@@ -206,11 +200,9 @@ class MilvusClient(VectorDBBase):
def query(self, collection_name: str, filter: dict, limit: int = -1):
connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB)
collection_name = collection_name.replace("-", "_")
collection_name = collection_name.replace('-', '_')
if not self.has_collection(collection_name):
log.warning(
f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
)
log.warning(f'Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}')
return None
filter_expressions = []
@@ -220,9 +212,9 @@ class MilvusClient(VectorDBBase):
else:
filter_expressions.append(f'metadata["{key}"] == {value}')
filter_string = " && ".join(filter_expressions)
filter_string = ' && '.join(filter_expressions)
collection = Collection(f"{self.collection_prefix}_{collection_name}")
collection = Collection(f'{self.collection_prefix}_{collection_name}')
collection.load()
try:
@@ -233,9 +225,9 @@ class MilvusClient(VectorDBBase):
iterator = collection.query_iterator(
expr=filter_string,
output_fields=[
"id",
"data",
"metadata",
'id',
'data',
'metadata',
],
limit=limit if limit > 0 else -1,
)
@@ -248,7 +240,7 @@ class MilvusClient(VectorDBBase):
break
all_results.extend(batch)
log.debug(f"Total results from query: {len(all_results)}")
log.debug(f'Total results from query: {len(all_results)}')
return self._result_to_get_result([all_results] if all_results else [[]])
except Exception as e:
@@ -259,7 +251,7 @@ class MilvusClient(VectorDBBase):
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection. This can be very resource-intensive for large collections.
collection_name = collection_name.replace("-", "_")
collection_name = collection_name.replace('-', '_')
log.warning(
f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections."
)
@@ -269,35 +261,25 @@ class MilvusClient(VectorDBBase):
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created.
collection_name = collection_name.replace("-", "_")
if not self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
):
log.info(
f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now."
)
collection_name = collection_name.replace('-', '_')
if not self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}'):
log.info(f'Collection {self.collection_prefix}_{collection_name} does not exist. Creating now.')
if not items:
log.error(
f"Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension."
f'Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension.'
)
raise ValueError(
"Cannot create Milvus collection without items to determine vector dimension."
)
self._create_collection(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
raise ValueError('Cannot create Milvus collection without items to determine vector dimension.')
self._create_collection(collection_name=collection_name, dimension=len(items[0]['vector']))
log.info(
f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
)
log.info(f'Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.')
return self.client.insert(
collection_name=f"{self.collection_prefix}_{collection_name}",
collection_name=f'{self.collection_prefix}_{collection_name}',
data=[
{
"id": item["id"],
"vector": item["vector"],
"data": {"text": item["text"]},
"metadata": process_metadata(item["metadata"]),
'id': item['id'],
'vector': item['vector'],
'data': {'text': item['text']},
'metadata': process_metadata(item['metadata']),
}
for item in items
],
@@ -305,35 +287,27 @@ class MilvusClient(VectorDBBase):
def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
collection_name = collection_name.replace("-", "_")
if not self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
):
log.info(
f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now."
)
collection_name = collection_name.replace('-', '_')
if not self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}'):
log.info(f'Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now.')
if not items:
log.error(
f"Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension."
f'Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension.'
)
raise ValueError(
"Cannot create Milvus collection for upsert without items to determine vector dimension."
'Cannot create Milvus collection for upsert without items to determine vector dimension.'
)
self._create_collection(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
self._create_collection(collection_name=collection_name, dimension=len(items[0]['vector']))
log.info(
f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
)
log.info(f'Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.')
return self.client.upsert(
collection_name=f"{self.collection_prefix}_{collection_name}",
collection_name=f'{self.collection_prefix}_{collection_name}',
data=[
{
"id": item["id"],
"vector": item["vector"],
"data": {"text": item["text"]},
"metadata": process_metadata(item["metadata"]),
'id': item['id'],
'vector': item['vector'],
'data': {'text': item['text']},
'metadata': process_metadata(item['metadata']),
}
for item in items
],
@@ -346,46 +320,35 @@ class MilvusClient(VectorDBBase):
filter: Optional[dict] = None,
):
# Delete the items from the collection based on the ids or filter.
collection_name = collection_name.replace("-", "_")
collection_name = collection_name.replace('-', '_')
if not self.has_collection(collection_name):
log.warning(
f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
)
log.warning(f'Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}')
return None
if ids:
log.info(
f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}"
)
log.info(f'Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}')
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
collection_name=f'{self.collection_prefix}_{collection_name}',
ids=ids,
)
elif filter:
filter_string = " && ".join(
[
f'metadata["{key}"] == {json.dumps(value)}'
for key, value in filter.items()
]
)
filter_string = ' && '.join([f'metadata["{key}"] == {json.dumps(value)}' for key, value in filter.items()])
log.info(
f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}"
f'Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}'
)
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
collection_name=f'{self.collection_prefix}_{collection_name}',
filter=filter_string,
)
else:
log.warning(
f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken."
f'Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken.'
)
return None
def reset(self):
# Resets the database. This will delete all collections and item entries that match the prefix.
log.warning(
f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'."
)
log.warning(f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'.")
collection_names = self.client.list_collections()
deleted_collections = []
for collection_name_full in collection_names:
@@ -393,7 +356,7 @@ class MilvusClient(VectorDBBase):
try:
self.client.drop_collection(collection_name=collection_name_full)
deleted_collections.append(collection_name_full)
log.info(f"Deleted collection: {collection_name_full}")
log.info(f'Deleted collection: {collection_name_full}')
except Exception as e:
log.error(f"Error deleting collection {collection_name_full}: {e}")
log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")
log.error(f'Error deleting collection {collection_name_full}: {e}')
log.info(f'Milvus reset complete. Deleted collections: {deleted_collections}')
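_create_collection above chooses build-time index parameters from the configured index type; a condensed standalone sketch of that selection follows (the numeric values are illustrative stand-ins for the MILVUS_* settings, not part of this commit):

```python
# Condensed sketch of the index-type -> build-params mapping in _create_collection.
# The numbers stand in for MILVUS_HNSW_M, MILVUS_HNSW_EFCONSTRUCTION, MILVUS_IVF_FLAT_NLIST, etc.
def index_build_params(index_type: str) -> dict:
    index_type = index_type.upper()
    if index_type == 'HNSW':
        return {'M': 16, 'efConstruction': 100}
    elif index_type == 'IVF_FLAT':
        return {'nlist': 128}
    elif index_type == 'DISKANN':
        return {'max_degree': 56, 'search_list_size': 100}
    # FLAT / AUTOINDEX (and anything unsupported) get no build-time params here;
    # the real client logs a warning for unknown types and lets Milvus pick a default.
    return {}

print(index_build_params('hnsw'))  # {'M': 16, 'efConstruction': 100}
```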

View File

@@ -33,26 +33,26 @@ from pymilvus import (
log = logging.getLogger(__name__)
RESOURCE_ID_FIELD = "resource_id"
RESOURCE_ID_FIELD = 'resource_id'
class MilvusClient(VectorDBBase):
def __init__(self):
# Milvus collection names can only contain numbers, letters, and underscores.
self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_")
self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace('-', '_')
connections.connect(
alias="default",
alias='default',
uri=MILVUS_URI,
token=MILVUS_TOKEN,
db_name=MILVUS_DB,
)
# Main collection types for multi-tenancy
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
self.FILE_COLLECTION = f"{self.collection_prefix}_files"
self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search"
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based"
self.MEMORY_COLLECTION = f'{self.collection_prefix}_memories'
self.KNOWLEDGE_COLLECTION = f'{self.collection_prefix}_knowledge'
self.FILE_COLLECTION = f'{self.collection_prefix}_files'
self.WEB_SEARCH_COLLECTION = f'{self.collection_prefix}_web_search'
self.HASH_BASED_COLLECTION = f'{self.collection_prefix}_hash_based'
self.shared_collections = [
self.MEMORY_COLLECTION,
self.KNOWLEDGE_COLLECTION,
@@ -74,15 +74,13 @@ class MilvusClient(VectorDBBase):
"""
resource_id = collection_name
if collection_name.startswith("user-memory-"):
if collection_name.startswith('user-memory-'):
return self.MEMORY_COLLECTION, resource_id
elif collection_name.startswith("file-"):
elif collection_name.startswith('file-'):
return self.FILE_COLLECTION, resource_id
elif collection_name.startswith("web-search-"):
elif collection_name.startswith('web-search-'):
return self.WEB_SEARCH_COLLECTION, resource_id
elif len(collection_name) == 63 and all(
c in "0123456789abcdef" for c in collection_name
):
elif len(collection_name) == 63 and all(c in '0123456789abcdef' for c in collection_name):
return self.HASH_BASED_COLLECTION, resource_id
else:
return self.KNOWLEDGE_COLLECTION, resource_id
@@ -90,36 +88,36 @@ class MilvusClient(VectorDBBase):
def _create_shared_collection(self, mt_collection_name: str, dimension: int):
fields = [
FieldSchema(
name="id",
name='id',
dtype=DataType.VARCHAR,
is_primary=True,
auto_id=False,
max_length=36,
),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="metadata", dtype=DataType.JSON),
FieldSchema(name='vector', dtype=DataType.FLOAT_VECTOR, dim=dimension),
FieldSchema(name='text', dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name='metadata', dtype=DataType.JSON),
FieldSchema(name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255),
]
schema = CollectionSchema(fields, "Shared collection for multi-tenancy")
schema = CollectionSchema(fields, 'Shared collection for multi-tenancy')
collection = Collection(mt_collection_name, schema)
index_params = {
"metric_type": MILVUS_METRIC_TYPE,
"index_type": MILVUS_INDEX_TYPE,
"params": {},
'metric_type': MILVUS_METRIC_TYPE,
'index_type': MILVUS_INDEX_TYPE,
'params': {},
}
if MILVUS_INDEX_TYPE == "HNSW":
index_params["params"] = {
"M": MILVUS_HNSW_M,
"efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
if MILVUS_INDEX_TYPE == 'HNSW':
index_params['params'] = {
'M': MILVUS_HNSW_M,
'efConstruction': MILVUS_HNSW_EFCONSTRUCTION,
}
elif MILVUS_INDEX_TYPE == "IVF_FLAT":
index_params["params"] = {"nlist": MILVUS_IVF_FLAT_NLIST}
elif MILVUS_INDEX_TYPE == 'IVF_FLAT':
index_params['params'] = {'nlist': MILVUS_IVF_FLAT_NLIST}
collection.create_index("vector", index_params)
collection.create_index('vector', index_params)
collection.create_index(RESOURCE_ID_FIELD)
log.info(f"Created shared collection: {mt_collection_name}")
log.info(f'Created shared collection: {mt_collection_name}')
return collection
def _ensure_collection(self, mt_collection_name: str, dimension: int):
@@ -127,9 +125,7 @@ class MilvusClient(VectorDBBase):
self._create_shared_collection(mt_collection_name, dimension)
def has_collection(self, collection_name: str) -> bool:
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
if not utility.has_collection(mt_collection):
return False
@@ -141,19 +137,17 @@ class MilvusClient(VectorDBBase):
def upsert(self, collection_name: str, items: List[VectorItem]):
if not items:
return
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
dimension = len(items[0]["vector"])
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
dimension = len(items[0]['vector'])
self._ensure_collection(mt_collection, dimension)
collection = Collection(mt_collection)
entities = [
{
"id": item["id"],
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
'id': item['id'],
'vector': item['vector'],
'text': item['text'],
'metadata': item['metadata'],
RESOURCE_ID_FIELD: resource_id,
}
for item in items
@@ -170,41 +164,37 @@ class MilvusClient(VectorDBBase):
if not vectors:
return None
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
if not utility.has_collection(mt_collection):
return None
collection = Collection(mt_collection)
collection.load()
search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}}
search_params = {'metric_type': MILVUS_METRIC_TYPE, 'params': {}}
results = collection.search(
data=vectors,
anns_field="vector",
anns_field='vector',
param=search_params,
limit=limit,
expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'",
output_fields=["id", "text", "metadata"],
output_fields=['id', 'text', 'metadata'],
)
ids, documents, metadatas, distances = [], [], [], []
for hits in results:
batch_ids, batch_docs, batch_metadatas, batch_dists = [], [], [], []
for hit in hits:
batch_ids.append(hit.entity.get("id"))
batch_docs.append(hit.entity.get("text"))
batch_metadatas.append(hit.entity.get("metadata"))
batch_ids.append(hit.entity.get('id'))
batch_docs.append(hit.entity.get('text'))
batch_metadatas.append(hit.entity.get('metadata'))
batch_dists.append(hit.distance)
ids.append(batch_ids)
documents.append(batch_docs)
metadatas.append(batch_metadatas)
distances.append(batch_dists)
return SearchResult(
ids=ids, documents=documents, metadatas=metadatas, distances=distances
)
return SearchResult(ids=ids, documents=documents, metadatas=metadatas, distances=distances)
def delete(
self,
@@ -212,9 +202,7 @@ class MilvusClient(VectorDBBase):
ids: Optional[List[str]] = None,
filter: Optional[Dict[str, Any]] = None,
):
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
if not utility.has_collection(mt_collection):
return
@@ -224,14 +212,14 @@ class MilvusClient(VectorDBBase):
expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
if ids:
# Milvus expects a string list for 'in' operator
id_list_str = ", ".join([f"'{id_val}'" for id_val in ids])
expr.append(f"id in [{id_list_str}]")
id_list_str = ', '.join([f"'{id_val}'" for id_val in ids])
expr.append(f'id in [{id_list_str}]')
if filter:
for key, value in filter.items():
expr.append(f"metadata['{key}'] == '{value}'")
collection.delete(" and ".join(expr))
collection.delete(' and '.join(expr))
def reset(self):
for collection_name in self.shared_collections:
@@ -239,21 +227,15 @@ class MilvusClient(VectorDBBase):
utility.drop_collection(collection_name)
def delete_collection(self, collection_name: str):
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
if not utility.has_collection(mt_collection):
return
collection = Collection(mt_collection)
collection.delete(f"{RESOURCE_ID_FIELD} == '{resource_id}'")
def query(
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
) -> Optional[GetResult]:
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
if not utility.has_collection(mt_collection):
return None
@@ -269,8 +251,8 @@ class MilvusClient(VectorDBBase):
expr.append(f"metadata['{key}'] == {value}")
iterator = collection.query_iterator(
expr=" and ".join(expr),
output_fields=["id", "text", "metadata"],
expr=' and '.join(expr),
output_fields=['id', 'text', 'metadata'],
limit=limit if limit else -1,
)
@@ -282,9 +264,9 @@ class MilvusClient(VectorDBBase):
break
all_results.extend(batch)
ids = [res["id"] for res in all_results]
documents = [res["text"] for res in all_results]
metadatas = [res["metadata"] for res in all_results]
ids = [res['id'] for res in all_results]
documents = [res['text'] for res in all_results]
metadatas = [res['metadata'] for res in all_results]
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
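The multi-tenant client above scopes every operation with a resource_id guard and appends optional id and metadata clauses on top of it; a rough standalone sketch of that expression building (values are illustrative, not part of this commit):

```python
# Rough sketch of the boolean expression composed by delete() above: a resource_id
# guard plus optional id-list and metadata clauses, joined with ' and '.
RESOURCE_ID_FIELD = 'resource_id'

def build_delete_expr(resource_id: str, ids=None, filter=None) -> str:
    expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
    if ids:
        # Milvus expects a quoted string list for the 'in' operator
        id_list_str = ', '.join([f"'{id_val}'" for id_val in ids])
        expr.append(f'id in [{id_list_str}]')
    if filter:
        for key, value in filter.items():
            expr.append(f"metadata['{key}'] == '{value}'")
    return ' and '.join(expr)

print(build_delete_expr('file-123', ids=['a', 'b'], filter={'source': 'upload'}))
# resource_id == 'file-123' and id in ['a', 'b'] and metadata['source'] == 'upload'
```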

View File

@@ -36,17 +36,15 @@ from sqlalchemy.dialects import registry
class OpenGaussDialect(PGDialect_psycopg2):
name = "opengauss"
name = 'opengauss'
def _get_server_version_info(self, connection):
try:
version = connection.exec_driver_sql("SELECT version()").scalar()
version = connection.exec_driver_sql('SELECT version()').scalar()
if not version:
return (9, 0, 0)
match = re.search(
r"openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?", version, re.IGNORECASE
)
match = re.search(r'openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?', version, re.IGNORECASE)
if match:
return (int(match.group(1)), int(match.group(2)), int(match.group(3)))
@@ -56,7 +54,7 @@ class OpenGaussDialect(PGDialect_psycopg2):
# Register dialect
registry.register("opengauss", __name__, "OpenGaussDialect")
registry.register('opengauss', __name__, 'OpenGaussDialect')
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.retrieval.vector.main import (
@@ -80,11 +78,11 @@ VECTOR_LENGTH = OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH
Base = declarative_base()
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
log.setLevel(SRC_LOG_LEVELS['RAG'])
class DocumentChunk(Base):
__tablename__ = "document_chunk"
__tablename__ = 'document_chunk'
id = Column(Text, primary_key=True)
vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
@@ -100,26 +98,24 @@ class OpenGaussClient(VectorDBBase):
self.session = ScopedSession
else:
engine_kwargs = {"pool_pre_ping": True, "dialect": OpenGaussDialect()}
engine_kwargs = {'pool_pre_ping': True, 'dialect': OpenGaussDialect()}
if isinstance(OPENGAUSS_POOL_SIZE, int) and OPENGAUSS_POOL_SIZE > 0:
engine_kwargs.update(
{
"pool_size": OPENGAUSS_POOL_SIZE,
"max_overflow": OPENGAUSS_POOL_MAX_OVERFLOW,
"pool_timeout": OPENGAUSS_POOL_TIMEOUT,
"pool_recycle": OPENGAUSS_POOL_RECYCLE,
"poolclass": QueuePool,
'pool_size': OPENGAUSS_POOL_SIZE,
'max_overflow': OPENGAUSS_POOL_MAX_OVERFLOW,
'pool_timeout': OPENGAUSS_POOL_TIMEOUT,
'pool_recycle': OPENGAUSS_POOL_RECYCLE,
'poolclass': QueuePool,
}
)
else:
engine_kwargs["poolclass"] = NullPool
engine_kwargs['poolclass'] = NullPool
engine = create_engine(OPENGAUSS_DB_URL, **engine_kwargs)
SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
self.session = scoped_session(SessionLocal)
try:
@@ -128,47 +124,42 @@ class OpenGaussClient(VectorDBBase):
self.session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
"ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
'CREATE INDEX IF NOT EXISTS idx_document_chunk_vector '
'ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);'
)
)
self.session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
"ON document_chunk (collection_name);"
'CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name ON document_chunk (collection_name);'
)
)
self.session.commit()
log.info("OpenGauss vector database initialization completed.")
log.info('OpenGauss vector database initialization completed.')
except Exception as e:
self.session.rollback()
log.exception(f"OpenGauss Initialization failed.: {e}")
log.exception(f'OpenGauss Initialization failed.: {e}')
raise
def check_vector_length(self) -> None:
metadata = MetaData()
try:
document_chunk_table = Table(
"document_chunk", metadata, autoload_with=self.session.bind
)
document_chunk_table = Table('document_chunk', metadata, autoload_with=self.session.bind)
except NoSuchTableError:
return
if "vector" in document_chunk_table.columns:
vector_column = document_chunk_table.columns["vector"]
if 'vector' in document_chunk_table.columns:
vector_column = document_chunk_table.columns['vector']
vector_type = vector_column.type
if isinstance(vector_type, Vector):
db_vector_length = vector_type.dim
if db_vector_length != VECTOR_LENGTH:
raise Exception(
f"Vector dimension mismatch: configured {VECTOR_LENGTH} vs. {db_vector_length} in the database."
f'Vector dimension mismatch: configured {VECTOR_LENGTH} vs. {db_vector_length} in the database.'
)
else:
raise Exception("The 'vector' column type is not Vector.")
else:
raise Exception(
"The 'vector' column does not exist in the 'document_chunk' table."
)
raise Exception("The 'vector' column does not exist in the 'document_chunk' table.")
def adjust_vector_length(self, vector: List[float]) -> List[float]:
current_length = len(vector)
@@ -182,55 +173,47 @@ class OpenGaussClient(VectorDBBase):
try:
new_items = []
for item in items:
vector = self.adjust_vector_length(item["vector"])
vector = self.adjust_vector_length(item['vector'])
new_chunk = DocumentChunk(
id=item["id"],
id=item['id'],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=process_metadata(item["metadata"]),
text=item['text'],
vmetadata=process_metadata(item['metadata']),
)
new_items.append(new_chunk)
self.session.bulk_save_objects(new_items)
self.session.commit()
log.info(
f"Inserting {len(new_items)} items into collection '{collection_name}'."
)
log.info(f"Inserting {len(new_items)} items into collection '{collection_name}'.")
except Exception as e:
self.session.rollback()
log.exception(f"Failed to insert data: {e}")
log.exception(f'Failed to insert data: {e}')
raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
try:
for item in items:
vector = self.adjust_vector_length(item["vector"])
existing = (
self.session.query(DocumentChunk)
.filter(DocumentChunk.id == item["id"])
.first()
)
vector = self.adjust_vector_length(item['vector'])
existing = self.session.query(DocumentChunk).filter(DocumentChunk.id == item['id']).first()
if existing:
existing.vector = vector
existing.text = item["text"]
existing.vmetadata = process_metadata(item["metadata"])
existing.text = item['text']
existing.vmetadata = process_metadata(item['metadata'])
existing.collection_name = collection_name
else:
new_chunk = DocumentChunk(
id=item["id"],
id=item['id'],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=process_metadata(item["metadata"]),
text=item['text'],
vmetadata=process_metadata(item['metadata']),
)
self.session.add(new_chunk)
self.session.commit()
log.info(
f"Inserting/updating {len(items)} items in collection '{collection_name}'."
)
log.info(f"Inserting/updating {len(items)} items in collection '{collection_name}'.")
except Exception as e:
self.session.rollback()
log.exception(f"Failed to insert or update data.: {e}")
log.exception(f'Failed to insert or update data.: {e}')
raise
def search(
@@ -250,35 +233,29 @@ class OpenGaussClient(VectorDBBase):
def vector_expr(vector):
return cast(array(vector), Vector(VECTOR_LENGTH))
qid_col = column("qid", Integer)
q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
qid_col = column('qid', Integer)
q_vector_col = column('q_vector', Vector(VECTOR_LENGTH))
query_vectors = (
values(qid_col, q_vector_col)
.data(
[(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
)
.alias("query_vectors")
.data([(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)])
.alias('query_vectors')
)
result_fields = [
DocumentChunk.id,
DocumentChunk.text,
DocumentChunk.vmetadata,
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
"distance"
),
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label('distance'),
]
subq = (
select(*result_fields)
.where(DocumentChunk.collection_name == collection_name)
.order_by(
DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
)
.order_by(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
)
if limit is not None:
subq = subq.limit(limit)
subq = subq.lateral("result")
subq = subq.lateral('result')
stmt = (
select(
@@ -309,21 +286,15 @@ class OpenGaussClient(VectorDBBase):
metadatas[qid].append(row.vmetadata)
self.session.rollback()
return SearchResult(
ids=ids, distances=distances, documents=documents, metadatas=metadatas
)
return SearchResult(ids=ids, distances=distances, documents=documents, metadatas=metadatas)
except Exception as e:
self.session.rollback()
log.exception(f"Vector search failed: {e}")
log.exception(f'Vector search failed: {e}')
return None
def query(
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
) -> Optional[GetResult]:
def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
try:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
for key, value in filter.items():
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
@@ -344,16 +315,12 @@ class OpenGaussClient(VectorDBBase):
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
self.session.rollback()
log.exception(f"Conditional query failed: {e}")
log.exception(f'Conditional query failed: {e}')
return None
def get(
self, collection_name: str, limit: Optional[int] = None
) -> Optional[GetResult]:
def get(self, collection_name: str, limit: Optional[int] = None) -> Optional[GetResult]:
try:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
if limit is not None:
query = query.limit(limit)
@@ -370,7 +337,7 @@ class OpenGaussClient(VectorDBBase):
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
self.session.rollback()
log.exception(f"Failed to retrieve data: {e}")
log.exception(f'Failed to retrieve data: {e}')
return None
def delete(
@@ -380,32 +347,28 @@ class OpenGaussClient(VectorDBBase):
filter: Optional[Dict[str, Any]] = None,
) -> None:
try:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
if ids:
query = query.filter(DocumentChunk.id.in_(ids))
if filter:
for key, value in filter.items():
query = query.filter(
DocumentChunk.vmetadata[key].astext == str(value)
)
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
deleted = query.delete(synchronize_session=False)
self.session.commit()
log.info(f"Deleted {deleted} items from collection '{collection_name}'")
except Exception as e:
self.session.rollback()
log.exception(f"Failed to delete data: {e}")
log.exception(f'Failed to delete data: {e}')
raise
def reset(self) -> None:
try:
deleted = self.session.query(DocumentChunk).delete()
self.session.commit()
log.info(f"Reset completed. Deleted {deleted} items")
log.info(f'Reset completed. Deleted {deleted} items')
except Exception as e:
self.session.rollback()
log.exception(f"Reset failed: {e}")
log.exception(f'Reset failed: {e}')
raise
def close(self) -> None:
@@ -414,16 +377,14 @@ class OpenGaussClient(VectorDBBase):
def has_collection(self, collection_name: str) -> bool:
try:
exists = (
self.session.query(DocumentChunk)
.filter(DocumentChunk.collection_name == collection_name)
.first()
self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name).first()
is not None
)
self.session.rollback()
return exists
except Exception as e:
self.session.rollback()
log.exception(f"Failed to check collection existence: {e}")
log.exception(f'Failed to check collection existence: {e}')
return False
def delete_collection(self, collection_name: str) -> None:

View File

@@ -24,7 +24,7 @@ from open_webui.config import (
class OpenSearchClient(VectorDBBase):
def __init__(self):
self.index_prefix = "open_webui"
self.index_prefix = 'open_webui'
self.client = OpenSearch(
hosts=[OPENSEARCH_URI],
use_ssl=OPENSEARCH_SSL,
@@ -33,25 +33,25 @@ class OpenSearchClient(VectorDBBase):
)
def _get_index_name(self, collection_name: str) -> str:
return f"{self.index_prefix}_{collection_name}"
return f'{self.index_prefix}_{collection_name}'
def _result_to_get_result(self, result) -> GetResult:
if not result["hits"]["hits"]:
if not result['hits']['hits']:
return None
ids = []
documents = []
metadatas = []
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
for hit in result['hits']['hits']:
ids.append(hit['_id'])
documents.append(hit['_source'].get('text'))
metadatas.append(hit['_source'].get('metadata'))
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
def _result_to_search_result(self, result) -> SearchResult:
if not result["hits"]["hits"]:
if not result['hits']['hits']:
return None
ids = []
@@ -59,11 +59,11 @@ class OpenSearchClient(VectorDBBase):
documents = []
metadatas = []
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])
distances.append(hit["_score"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
for hit in result['hits']['hits']:
ids.append(hit['_id'])
distances.append(hit['_score'])
documents.append(hit['_source'].get('text'))
metadatas.append(hit['_source'].get('metadata'))
return SearchResult(
ids=[ids],
@@ -74,33 +74,31 @@ class OpenSearchClient(VectorDBBase):
def _create_index(self, collection_name: str, dimension: int):
body = {
"settings": {"index": {"knn": True}},
"mappings": {
"properties": {
"id": {"type": "keyword"},
"vector": {
"type": "knn_vector",
"dimension": dimension, # Adjust based on your vector dimensions
"index": True,
"similarity": "faiss",
"method": {
"name": "hnsw",
"space_type": "innerproduct", # Use inner product to approximate cosine similarity
"engine": "faiss",
"parameters": {
"ef_construction": 128,
"m": 16,
'settings': {'index': {'knn': True}},
'mappings': {
'properties': {
'id': {'type': 'keyword'},
'vector': {
'type': 'knn_vector',
'dimension': dimension, # Adjust based on your vector dimensions
'index': True,
'similarity': 'faiss',
'method': {
'name': 'hnsw',
'space_type': 'innerproduct', # Use inner product to approximate cosine similarity
'engine': 'faiss',
'parameters': {
'ef_construction': 128,
'm': 16,
},
},
},
"text": {"type": "text"},
"metadata": {"type": "object"},
'text': {'type': 'text'},
'metadata': {'type': 'object'},
}
},
}
self.client.indices.create(
index=self._get_index_name(collection_name), body=body
)
self.client.indices.create(index=self._get_index_name(collection_name), body=body)
def _create_batches(self, items: list[VectorItem], batch_size=100):
for i in range(0, len(items), batch_size):
@@ -128,46 +126,40 @@ class OpenSearchClient(VectorDBBase):
return None
query = {
"size": limit,
"_source": ["text", "metadata"],
"query": {
"script_score": {
"query": {"match_all": {}},
"script": {
"source": "(cosineSimilarity(params.query_value, doc[params.field]) + 1.0) / 2.0",
"params": {
"field": "vector",
"query_value": vectors[0],
'size': limit,
'_source': ['text', 'metadata'],
'query': {
'script_score': {
'query': {'match_all': {}},
'script': {
'source': '(cosineSimilarity(params.query_value, doc[params.field]) + 1.0) / 2.0',
'params': {
'field': 'vector',
'query_value': vectors[0],
}, # Assuming single query vector
},
}
},
}
result = self.client.search(
index=self._get_index_name(collection_name), body=query
)
result = self.client.search(index=self._get_index_name(collection_name), body=query)
return self._result_to_search_result(result)
except Exception as e:
return None
def query(
self, collection_name: str, filter: dict, limit: Optional[int] = None
) -> Optional[GetResult]:
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) -> Optional[GetResult]:
if not self.has_collection(collection_name):
return None
query_body = {
"query": {"bool": {"filter": []}},
"_source": ["text", "metadata"],
'query': {'bool': {'filter': []}},
'_source': ['text', 'metadata'],
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append(
{"term": {"metadata." + str(field) + ".keyword": value}}
)
query_body['query']['bool']['filter'].append({'term': {'metadata.' + str(field) + '.keyword': value}})
size = limit if limit else 10000
@@ -188,28 +180,24 @@ class OpenSearchClient(VectorDBBase):
self._create_index(collection_name, dimension)
def get(self, collection_name: str) -> Optional[GetResult]:
query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
query = {'query': {'match_all': {}}, '_source': ['text', 'metadata']}
result = self.client.search(
index=self._get_index_name(collection_name), body=query
)
result = self.client.search(index=self._get_index_name(collection_name), body=query)
return self._result_to_get_result(result)
def insert(self, collection_name: str, items: list[VectorItem]):
self._create_index_if_not_exists(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
self._create_index_if_not_exists(collection_name=collection_name, dimension=len(items[0]['vector']))
for batch in self._create_batches(items):
actions = [
{
"_op_type": "index",
"_index": self._get_index_name(collection_name),
"_id": item["id"],
"_source": {
"vector": item["vector"],
"text": item["text"],
"metadata": process_metadata(item["metadata"]),
'_op_type': 'index',
'_index': self._get_index_name(collection_name),
'_id': item['id'],
'_source': {
'vector': item['vector'],
'text': item['text'],
'metadata': process_metadata(item['metadata']),
},
}
for item in batch
@@ -218,22 +206,20 @@ class OpenSearchClient(VectorDBBase):
self.client.indices.refresh(index=self._get_index_name(collection_name))
def upsert(self, collection_name: str, items: list[VectorItem]):
self._create_index_if_not_exists(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
self._create_index_if_not_exists(collection_name=collection_name, dimension=len(items[0]['vector']))
for batch in self._create_batches(items):
actions = [
{
"_op_type": "update",
"_index": self._get_index_name(collection_name),
"_id": item["id"],
"doc": {
"vector": item["vector"],
"text": item["text"],
"metadata": process_metadata(item["metadata"]),
'_op_type': 'update',
'_index': self._get_index_name(collection_name),
'_id': item['id'],
'doc': {
'vector': item['vector'],
'text': item['text'],
'metadata': process_metadata(item['metadata']),
},
"doc_as_upsert": True,
'doc_as_upsert': True,
}
for item in batch
]
@@ -249,27 +235,23 @@ class OpenSearchClient(VectorDBBase):
if ids:
actions = [
{
"_op_type": "delete",
"_index": self._get_index_name(collection_name),
"_id": id,
'_op_type': 'delete',
'_index': self._get_index_name(collection_name),
'_id': id,
}
for id in ids
]
bulk(self.client, actions)
elif filter:
query_body = {
"query": {"bool": {"filter": []}},
'query': {'bool': {'filter': []}},
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append(
{"term": {"metadata." + str(field) + ".keyword": value}}
)
self.client.delete_by_query(
index=self._get_index_name(collection_name), body=query_body
)
query_body['query']['bool']['filter'].append({'term': {'metadata.' + str(field) + '.keyword': value}})
self.client.delete_by_query(index=self._get_index_name(collection_name), body=query_body)
self.client.indices.refresh(index=self._get_index_name(collection_name))
def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}_*")
indices = self.client.indices.get(index=f'{self.index_prefix}_*')
for index in indices:
self.client.indices.delete(index=index)
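The OpenSearch search path above approximates cosine similarity with a script_score query and rescales it into [0, 1]; a hedged sketch of composing that request body on its own (the vector and any index name are placeholders, not part of this commit):

```python
# Sketch of the script_score body used by search() above: cosineSimilarity returns
# values in [-1, 1], so the script rescales them to [0, 1] to match other backends.
def build_knn_query(query_vector: list[float], limit: int = 10) -> dict:
    return {
        'size': limit,
        '_source': ['text', 'metadata'],
        'query': {
            'script_score': {
                'query': {'match_all': {}},
                'script': {
                    'source': '(cosineSimilarity(params.query_value, doc[params.field]) + 1.0) / 2.0',
                    'params': {'field': 'vector', 'query_value': query_vector},
                },
            }
        },
    }

body = build_knn_query([0.1, 0.2, 0.3], limit=5)
# Would be passed as the body of client.search(index=..., body=body).
```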

View File

@@ -94,15 +94,15 @@ class Oracle23aiClient(VectorDBBase):
self._create_dbcs_pool()
dsn = ORACLE_DB_DSN
log.info(f"Creating Connection Pool [{ORACLE_DB_USER}:**@{dsn}]")
log.info(f'Creating Connection Pool [{ORACLE_DB_USER}:**@{dsn}]')
with self.get_connection() as connection:
log.info(f"Connection version: {connection.version}")
log.info(f'Connection version: {connection.version}')
self._initialize_database(connection)
log.info("Oracle Vector Search initialization complete.")
log.info('Oracle Vector Search initialization complete.')
except Exception as e:
log.exception(f"Error during Oracle Vector Search initialization: {e}")
log.exception(f'Error during Oracle Vector Search initialization: {e}')
raise
def _create_adb_pool(self) -> None:
@@ -122,7 +122,7 @@ class Oracle23aiClient(VectorDBBase):
wallet_location=ORACLE_WALLET_DIR,
wallet_password=ORACLE_WALLET_PASSWORD,
)
log.info("Created ADB connection pool with wallet authentication.")
log.info('Created ADB connection pool with wallet authentication.')
def _create_dbcs_pool(self) -> None:
"""
@@ -138,7 +138,7 @@ class Oracle23aiClient(VectorDBBase):
max=ORACLE_DB_POOL_MAX,
increment=ORACLE_DB_POOL_INCREMENT,
)
log.info("Created DB connection pool with basic authentication.")
log.info('Created DB connection pool with basic authentication.')
def get_connection(self):
"""
@@ -155,13 +155,11 @@ class Oracle23aiClient(VectorDBBase):
return connection
except oracledb.DatabaseError as e:
(error_obj,) = e.args
log.exception(
f"Connection attempt {attempt + 1} failed: {error_obj.message}"
)
log.exception(f'Connection attempt {attempt + 1} failed: {error_obj.message}')
if attempt < max_retries - 1:
wait_time = 2**attempt
log.info(f"Retrying in {wait_time} seconds...")
log.info(f'Retrying in {wait_time} seconds...')
time.sleep(wait_time)
else:
raise
@@ -177,30 +175,30 @@ class Oracle23aiClient(VectorDBBase):
def _monitor():
while True:
try:
log.info("[HealthCheck] Running periodic DB health check...")
log.info('[HealthCheck] Running periodic DB health check...')
self.ensure_connection()
log.info("[HealthCheck] Connection is healthy.")
log.info('[HealthCheck] Connection is healthy.')
except Exception as e:
log.exception(f"[HealthCheck] Connection health check failed: {e}")
log.exception(f'[HealthCheck] Connection health check failed: {e}')
time.sleep(interval_seconds)
thread = threading.Thread(target=_monitor, daemon=True)
thread.start()
log.info(f"Started DB health monitor every {interval_seconds} seconds.")
log.info(f'Started DB health monitor every {interval_seconds} seconds.')
def _reconnect_pool(self):
"""
Attempt to reinitialize the connection pool if it's been closed or broken.
"""
try:
log.info("Attempting to reinitialize the Oracle connection pool...")
log.info('Attempting to reinitialize the Oracle connection pool...')
# Close existing pool if it exists
if self.pool:
try:
self.pool.close()
except Exception as close_error:
log.warning(f"Error closing existing pool: {close_error}")
log.warning(f'Error closing existing pool: {close_error}')
# Re-create the appropriate connection pool based on DB type
if ORACLE_DB_USE_WALLET:
@@ -208,9 +206,9 @@ class Oracle23aiClient(VectorDBBase):
else: # DBCS
self._create_dbcs_pool()
log.info("Connection pool reinitialized.")
log.info('Connection pool reinitialized.')
except Exception as e:
log.exception(f"Failed to reinitialize the connection pool: {e}")
log.exception(f'Failed to reinitialize the connection pool: {e}')
raise
def ensure_connection(self):
@@ -220,11 +218,9 @@ class Oracle23aiClient(VectorDBBase):
try:
with self.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute("SELECT 1 FROM dual")
cursor.execute('SELECT 1 FROM dual')
except Exception as e:
log.exception(
f"Connection check failed: {e}, attempting to reconnect pool..."
)
log.exception(f'Connection check failed: {e}, attempting to reconnect pool...')
self._reconnect_pool()
def _output_type_handler(self, cursor, metadata):
@@ -239,9 +235,7 @@ class Oracle23aiClient(VectorDBBase):
A variable with appropriate conversion for vector types
"""
if metadata.type_code is oracledb.DB_TYPE_VECTOR:
return cursor.var(
metadata.type_code, arraysize=cursor.arraysize, outconverter=list
)
return cursor.var(metadata.type_code, arraysize=cursor.arraysize, outconverter=list)
def _initialize_database(self, connection) -> None:
"""
@@ -257,7 +251,7 @@ class Oracle23aiClient(VectorDBBase):
"""
with connection.cursor() as cursor:
try:
log.info("Creating Table document_chunk")
log.info('Creating Table document_chunk')
cursor.execute(
"""
BEGIN
@@ -279,7 +273,7 @@ class Oracle23aiClient(VectorDBBase):
"""
)
log.info("Creating Index document_chunk_collection_name_idx")
log.info('Creating Index document_chunk_collection_name_idx')
cursor.execute(
"""
BEGIN
@@ -296,7 +290,7 @@ class Oracle23aiClient(VectorDBBase):
"""
)
log.info("Creating VECTOR INDEX document_chunk_vector_ivf_idx")
log.info('Creating VECTOR INDEX document_chunk_vector_ivf_idx')
cursor.execute(
"""
BEGIN
@@ -318,11 +312,11 @@ class Oracle23aiClient(VectorDBBase):
)
connection.commit()
log.info("Database initialization completed successfully.")
log.info('Database initialization completed successfully.')
except Exception as e:
connection.rollback()
log.exception(f"Error during database initialization: {e}")
log.exception(f'Error during database initialization: {e}')
raise
def check_vector_length(self) -> None:
@@ -344,7 +338,7 @@ class Oracle23aiClient(VectorDBBase):
Returns:
bytes: The vector in Oracle BLOB format
"""
return array.array("f", vector)
return array.array('f', vector)
def adjust_vector_length(self, vector: List[float]) -> List[float]:
"""
@@ -373,7 +367,7 @@ class Oracle23aiClient(VectorDBBase):
"""
if isinstance(obj, Decimal):
return float(obj)
raise TypeError(f"{obj} is not JSON serializable")
raise TypeError(f'{obj} is not JSON serializable')
def _metadata_to_json(self, metadata: Dict) -> str:
"""
@@ -385,7 +379,7 @@ class Oracle23aiClient(VectorDBBase):
Returns:
str: JSON representation of metadata
"""
return json.dumps(metadata, default=self._decimal_handler) if metadata else "{}"
return json.dumps(metadata, default=self._decimal_handler) if metadata else '{}'
def _json_to_metadata(self, json_str: str) -> Dict:
"""
@@ -424,8 +418,8 @@ class Oracle23aiClient(VectorDBBase):
try:
with connection.cursor() as cursor:
for item in items:
vector_blob = self._vector_to_blob(item["vector"])
metadata_json = self._metadata_to_json(item["metadata"])
vector_blob = self._vector_to_blob(item['vector'])
metadata_json = self._metadata_to_json(item['metadata'])
cursor.execute(
"""
@@ -434,22 +428,20 @@ class Oracle23aiClient(VectorDBBase):
VALUES (:id, :collection_name, :text, :metadata, :vector)
""",
{
"id": item["id"],
"collection_name": collection_name,
"text": item["text"],
"metadata": metadata_json,
"vector": vector_blob,
'id': item['id'],
'collection_name': collection_name,
'text': item['text'],
'metadata': metadata_json,
'vector': vector_blob,
},
)
connection.commit()
log.info(
f"Successfully inserted {len(items)} items into collection '{collection_name}'."
)
log.info(f"Successfully inserted {len(items)} items into collection '{collection_name}'.")
except Exception as e:
connection.rollback()
log.exception(f"Error during insert: {e}")
log.exception(f'Error during insert: {e}')
raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -480,8 +472,8 @@ class Oracle23aiClient(VectorDBBase):
try:
with connection.cursor() as cursor:
for item in items:
vector_blob = self._vector_to_blob(item["vector"])
metadata_json = self._metadata_to_json(item["metadata"])
vector_blob = self._vector_to_blob(item['vector'])
metadata_json = self._metadata_to_json(item['metadata'])
cursor.execute(
"""
@@ -499,27 +491,25 @@ class Oracle23aiClient(VectorDBBase):
VALUES (:ins_id, :ins_collection_name, :ins_text, :ins_metadata, :ins_vector)
""",
{
"merge_id": item["id"],
"upd_collection_name": collection_name,
"upd_text": item["text"],
"upd_metadata": metadata_json,
"upd_vector": vector_blob,
"ins_id": item["id"],
"ins_collection_name": collection_name,
"ins_text": item["text"],
"ins_metadata": metadata_json,
"ins_vector": vector_blob,
'merge_id': item['id'],
'upd_collection_name': collection_name,
'upd_text': item['text'],
'upd_metadata': metadata_json,
'upd_vector': vector_blob,
'ins_id': item['id'],
'ins_collection_name': collection_name,
'ins_text': item['text'],
'ins_metadata': metadata_json,
'ins_vector': vector_blob,
},
)
connection.commit()
log.info(
f"Successfully upserted {len(items)} items into collection '{collection_name}'."
)
log.info(f"Successfully upserted {len(items)} items into collection '{collection_name}'.")
except Exception as e:
connection.rollback()
log.exception(f"Error during upsert: {e}")
log.exception(f'Error during upsert: {e}')
raise
def search(
@@ -551,13 +541,11 @@ class Oracle23aiClient(VectorDBBase):
... for i, (id, dist) in enumerate(zip(results.ids[0], results.distances[0])):
... log.info(f"Match {i+1}: id={id}, distance={dist}")
"""
log.info(
f"Searching items from collection '{collection_name}' with limit {limit}."
)
log.info(f"Searching items from collection '{collection_name}' with limit {limit}.")
try:
if not vectors:
log.warning("No vectors provided for search.")
log.warning('No vectors provided for search.')
return None
num_queries = len(vectors)
@@ -583,9 +571,9 @@ class Oracle23aiClient(VectorDBBase):
FETCH APPROX FIRST :limit ROWS ONLY
""",
{
"query_vector": vector_blob,
"collection_name": collection_name,
"limit": limit,
'query_vector': vector_blob,
'collection_name': collection_name,
'limit': limit,
},
)
@@ -593,35 +581,21 @@ class Oracle23aiClient(VectorDBBase):
for row in results:
ids[qid].append(row[0])
documents[qid].append(
row[1].read()
if isinstance(row[1], oracledb.LOB)
else str(row[1])
)
documents[qid].append(row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1]))
# 🔧 FIXED: Parse JSON metadata properly
metadata_str = (
row[2].read()
if isinstance(row[2], oracledb.LOB)
else row[2]
)
metadata_str = row[2].read() if isinstance(row[2], oracledb.LOB) else row[2]
metadatas[qid].append(self._json_to_metadata(metadata_str))
distances[qid].append(float(row[3]))
log.info(
f"Search completed. Found {sum(len(ids[i]) for i in range(num_queries))} total results."
)
log.info(f'Search completed. Found {sum(len(ids[i]) for i in range(num_queries))} total results.')
return SearchResult(
ids=ids, distances=distances, documents=documents, metadatas=metadatas
)
return SearchResult(ids=ids, distances=distances, documents=documents, metadatas=metadatas)
except Exception as e:
log.exception(f"Error during search: {e}")
log.exception(f'Error during search: {e}')
return None
def query(
self, collection_name: str, filter: Dict, limit: Optional[int] = None
) -> Optional[GetResult]:
def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]:
"""
Query items based on metadata filters.
@@ -653,15 +627,15 @@ class Oracle23aiClient(VectorDBBase):
WHERE collection_name = :collection_name
"""
params = {"collection_name": collection_name}
params = {'collection_name': collection_name}
for i, (key, value) in enumerate(filter.items()):
param_name = f"value_{i}"
param_name = f'value_{i}'
query += f" AND JSON_VALUE(vmetadata, '$.{key}' RETURNING VARCHAR2(4096)) = :{param_name}"
params[param_name] = str(value)
query += " FETCH FIRST :limit ROWS ONLY"
params["limit"] = limit
query += ' FETCH FIRST :limit ROWS ONLY'
params['limit'] = limit
with self.get_connection() as connection:
with connection.cursor() as cursor:
@@ -669,32 +643,25 @@ class Oracle23aiClient(VectorDBBase):
results = cursor.fetchall()
if not results:
log.info("No results found for query.")
log.info('No results found for query.')
return None
ids = [[row[0] for row in results]]
documents = [
[
row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1])
for row in results
]
]
documents = [[row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1]) for row in results]]
# 🔧 FIXED: Parse JSON metadata properly
metadatas = [
[
self._json_to_metadata(
row[2].read() if isinstance(row[2], oracledb.LOB) else row[2]
)
self._json_to_metadata(row[2].read() if isinstance(row[2], oracledb.LOB) else row[2])
for row in results
]
]
log.info(f"Query completed. Found {len(results)} results.")
log.info(f'Query completed. Found {len(results)} results.')
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
log.exception(f"Error during query: {e}")
log.exception(f'Error during query: {e}')
return None
def get(self, collection_name: str) -> Optional[GetResult]:
@@ -729,28 +696,21 @@ class Oracle23aiClient(VectorDBBase):
WHERE collection_name = :collection_name
FETCH FIRST :limit ROWS ONLY
""",
{"collection_name": collection_name, "limit": limit},
{'collection_name': collection_name, 'limit': limit},
)
results = cursor.fetchall()
if not results:
log.info("No results found.")
log.info('No results found.')
return None
ids = [[row[0] for row in results]]
documents = [
[
row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1])
for row in results
]
]
documents = [[row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1]) for row in results]]
# 🔧 FIXED: Parse JSON metadata properly
metadatas = [
[
self._json_to_metadata(
row[2].read() if isinstance(row[2], oracledb.LOB) else row[2]
)
self._json_to_metadata(row[2].read() if isinstance(row[2], oracledb.LOB) else row[2])
for row in results
]
]
@@ -758,7 +718,7 @@ class Oracle23aiClient(VectorDBBase):
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
log.exception(f"Error during get: {e}")
log.exception(f'Error during get: {e}')
return None
def delete(
@@ -790,21 +750,19 @@ class Oracle23aiClient(VectorDBBase):
log.info(f"Deleting items from collection '{collection_name}'.")
try:
query = (
"DELETE FROM document_chunk WHERE collection_name = :collection_name"
)
params = {"collection_name": collection_name}
query = 'DELETE FROM document_chunk WHERE collection_name = :collection_name'
params = {'collection_name': collection_name}
if ids:
# 🔧 FIXED: Use proper parameterized query to prevent SQL injection
placeholders = ",".join([f":id_{i}" for i in range(len(ids))])
query += f" AND id IN ({placeholders})"
placeholders = ','.join([f':id_{i}' for i in range(len(ids))])
query += f' AND id IN ({placeholders})'
for i, id_val in enumerate(ids):
params[f"id_{i}"] = id_val
params[f'id_{i}'] = id_val
if filter:
for i, (key, value) in enumerate(filter.items()):
param_name = f"value_{i}"
param_name = f'value_{i}'
query += f" AND JSON_VALUE(vmetadata, '$.{key}' RETURNING VARCHAR2(4096)) = :{param_name}"
params[param_name] = str(value)
@@ -817,7 +775,7 @@ class Oracle23aiClient(VectorDBBase):
log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
except Exception as e:
log.exception(f"Error during delete: {e}")
log.exception(f'Error during delete: {e}')
raise
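    # Illustrative expansion of the parameterized DELETE above for a hypothetical
    # call with ids=['a1', 'b2'] and filter={'file_id': 'f-123'} (values made up):
    #   DELETE FROM document_chunk WHERE collection_name = :collection_name
    #     AND id IN (:id_0,:id_1)
    #     AND JSON_VALUE(vmetadata, '$.file_id' RETURNING VARCHAR2(4096)) = :value_0
    # with params = {'collection_name': ..., 'id_0': 'a1', 'id_1': 'b2', 'value_0': 'f-123'},
    # so ids and filter values are bound as parameters rather than interpolated into the SQL text.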
def reset(self) -> None:
@@ -833,21 +791,19 @@ class Oracle23aiClient(VectorDBBase):
>>> client = Oracle23aiClient()
>>> client.reset() # Warning: Removes all data!
"""
log.info("Resetting database - deleting all items.")
log.info('Resetting database - deleting all items.')
try:
with self.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute("DELETE FROM document_chunk")
cursor.execute('DELETE FROM document_chunk')
deleted = cursor.rowcount
connection.commit()
log.info(
f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
)
log.info(f"Reset complete. Deleted {deleted} items from 'document_chunk' table.")
except Exception as e:
log.exception(f"Error during reset: {e}")
log.exception(f'Error during reset: {e}')
raise
def close(self) -> None:
@@ -862,11 +818,11 @@ class Oracle23aiClient(VectorDBBase):
>>> client.close()
"""
try:
if hasattr(self, "pool") and self.pool:
if hasattr(self, 'pool') and self.pool:
self.pool.close()
log.info("Oracle Vector Search connection pool closed.")
log.info('Oracle Vector Search connection pool closed.')
except Exception as e:
log.exception(f"Error closing connection pool: {e}")
log.exception(f'Error closing connection pool: {e}')
def has_collection(self, collection_name: str) -> bool:
"""
@@ -895,7 +851,7 @@ class Oracle23aiClient(VectorDBBase):
WHERE collection_name = :collection_name
FETCH FIRST 1 ROWS ONLY
""",
{"collection_name": collection_name},
{'collection_name': collection_name},
)
count = cursor.fetchone()[0]
@@ -903,7 +859,7 @@ class Oracle23aiClient(VectorDBBase):
return count > 0
except Exception as e:
log.exception(f"Error checking collection existence: {e}")
log.exception(f'Error checking collection existence: {e}')
return False
def delete_collection(self, collection_name: str) -> None:
@@ -929,15 +885,13 @@ class Oracle23aiClient(VectorDBBase):
DELETE FROM document_chunk
WHERE collection_name = :collection_name
""",
{"collection_name": collection_name},
{'collection_name': collection_name},
)
deleted = cursor.rowcount
connection.commit()
log.info(
f"Collection '{collection_name}' deleted. Removed {deleted} items."
)
log.info(f"Collection '{collection_name}' deleted. Removed {deleted} items.")
except Exception as e:
log.exception(f"Error deleting collection '{collection_name}': {e}")

View File

@@ -55,7 +55,7 @@ VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
USE_HALFVEC = PGVECTOR_USE_HALFVEC
VECTOR_TYPE_FACTORY = HALFVEC if USE_HALFVEC else Vector
VECTOR_OPCLASS = "halfvec_cosine_ops" if USE_HALFVEC else "vector_cosine_ops"
VECTOR_OPCLASS = 'halfvec_cosine_ops' if USE_HALFVEC else 'vector_cosine_ops'
Base = declarative_base()
log = logging.getLogger(__name__)
@@ -65,12 +65,12 @@ def pgcrypto_encrypt(val, key):
return func.pgp_sym_encrypt(val, literal(key))
def pgcrypto_decrypt(col, key, outtype="text"):
def pgcrypto_decrypt(col, key, outtype='text'):
return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype)
class DocumentChunk(Base):
__tablename__ = "document_chunk"
__tablename__ = 'document_chunk'
id = Column(Text, primary_key=True)
vector = Column(VECTOR_TYPE_FACTORY(dim=VECTOR_LENGTH), nullable=True)
@@ -86,7 +86,6 @@ class DocumentChunk(Base):
class PgvectorClient(VectorDBBase):
def __init__(self) -> None:
# if no pgvector uri, use the existing database connection
if not PGVECTOR_DB_URL:
from open_webui.internal.db import ScopedSession
@@ -105,46 +104,44 @@ class PgvectorClient(VectorDBBase):
poolclass=QueuePool,
)
else:
engine = create_engine(
PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
)
engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool)
else:
engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
self.session = scoped_session(SessionLocal)
try:
# Ensure the pgvector extension is available
# Use a conditional check to avoid permission issues on Azure PostgreSQL
if PGVECTOR_CREATE_EXTENSION:
self.session.execute(text("""
self.session.execute(
text("""
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN
CREATE EXTENSION IF NOT EXISTS vector;
END IF;
END $$;
"""))
""")
)
if PGVECTOR_PGCRYPTO:
# Ensure the pgcrypto extension is available for encryption
# Use a conditional check to avoid permission issues on Azure PostgreSQL
self.session.execute(text("""
self.session.execute(
text("""
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgcrypto') THEN
CREATE EXTENSION IF NOT EXISTS pgcrypto;
END IF;
END $$;
"""))
""")
)
if not PGVECTOR_PGCRYPTO_KEY:
raise ValueError(
"PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled."
)
raise ValueError('PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled.')
# Check vector length consistency
self.check_vector_length()
@@ -160,15 +157,14 @@ class PgvectorClient(VectorDBBase):
self.session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
"ON document_chunk (collection_name);"
'CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name ON document_chunk (collection_name);'
)
)
self.session.commit()
log.info("Initialization complete.")
log.info('Initialization complete.')
except Exception as e:
self.session.rollback()
log.exception(f"Error during initialization: {e}")
log.exception(f'Error during initialization: {e}')
raise
@staticmethod
@@ -176,7 +172,7 @@ class PgvectorClient(VectorDBBase):
if not index_def:
return None
try:
after_using = index_def.lower().split("using ", 1)[1]
after_using = index_def.lower().split('using ', 1)[1]
return after_using.split()[0]
except (IndexError, AttributeError):
return None
@@ -189,23 +185,23 @@ class PgvectorClient(VectorDBBase):
index_method,
)
elif USE_HALFVEC:
index_method = "hnsw"
index_method = 'hnsw'
log.info(
"VECTOR_LENGTH=%s exceeds 2000; using halfvec column type with hnsw index.",
'VECTOR_LENGTH=%s exceeds 2000; using halfvec column type with hnsw index.',
VECTOR_LENGTH,
)
else:
index_method = "ivfflat"
index_method = 'ivfflat'
if index_method == "hnsw":
index_options = f"WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})"
if index_method == 'hnsw':
index_options = f'WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})'
else:
index_options = f"WITH (lists = {PGVECTOR_IVFFLAT_LISTS})"
index_options = f'WITH (lists = {PGVECTOR_IVFFLAT_LISTS})'
return index_method, index_options
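    # Rough illustration of the returned pair, assuming example settings
    # PGVECTOR_HNSW_M=16, PGVECTOR_HNSW_EF_CONSTRUCTION=64 and PGVECTOR_IVFFLAT_LISTS=100
    # (the real values come from configuration):
    #   ('hnsw', 'WITH (m = 16, ef_construction = 64)')
    #   ('ivfflat', 'WITH (lists = 100)')
    # _ensure_vector_index() appends whichever applies to its CREATE INDEX statement.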
def _ensure_vector_index(self, index_method: str, index_options: str) -> None:
index_name = "idx_document_chunk_vector"
index_name = 'idx_document_chunk_vector'
existing_index_def = self.session.execute(
text("""
SELECT indexdef
@@ -214,7 +210,7 @@ class PgvectorClient(VectorDBBase):
AND tablename = 'document_chunk'
AND indexname = :index_name
"""),
{"index_name": index_name},
{'index_name': index_name},
).scalar()
existing_method = self._extract_index_method(existing_index_def)
@@ -222,23 +218,23 @@ class PgvectorClient(VectorDBBase):
raise RuntimeError(
f"Existing pgvector index '{index_name}' uses method '{existing_method}' but configuration now "
f"requires '{index_method}'. Automatic rebuild is disabled to prevent long-running maintenance. "
"Drop the index manually (optionally after tuning maintenance_work_mem/max_parallel_maintenance_workers) "
"and recreate it with the new method before restarting Open WebUI."
'Drop the index manually (optionally after tuning maintenance_work_mem/max_parallel_maintenance_workers) '
'and recreate it with the new method before restarting Open WebUI.'
)
if not existing_index_def:
index_sql = (
f"CREATE INDEX IF NOT EXISTS {index_name} "
f"ON document_chunk USING {index_method} (vector {VECTOR_OPCLASS})"
f'CREATE INDEX IF NOT EXISTS {index_name} '
f'ON document_chunk USING {index_method} (vector {VECTOR_OPCLASS})'
)
if index_options:
index_sql = f"{index_sql} {index_options}"
index_sql = f'{index_sql} {index_options}'
self.session.execute(text(index_sql))
log.info(
"Ensured vector index '%s' using %s%s.",
index_name,
index_method,
f" {index_options}" if index_options else "",
f' {index_options}' if index_options else '',
)
def check_vector_length(self) -> None:
@@ -249,16 +245,14 @@ class PgvectorClient(VectorDBBase):
metadata = MetaData()
try:
# Attempt to reflect the 'document_chunk' table
document_chunk_table = Table(
"document_chunk", metadata, autoload_with=self.session.bind
)
document_chunk_table = Table('document_chunk', metadata, autoload_with=self.session.bind)
except NoSuchTableError:
# Table does not exist; no action needed
return
# Proceed to check the vector column
if "vector" in document_chunk_table.columns:
vector_column = document_chunk_table.columns["vector"]
if 'vector' in document_chunk_table.columns:
vector_column = document_chunk_table.columns['vector']
vector_type = vector_column.type
expected_type = HALFVEC if USE_HALFVEC else Vector
@@ -268,16 +262,14 @@ class PgvectorClient(VectorDBBase):
f"('{expected_type.__name__}') for VECTOR_LENGTH {VECTOR_LENGTH}."
)
db_vector_length = getattr(vector_type, "dim", None)
db_vector_length = getattr(vector_type, 'dim', None)
if db_vector_length is not None and db_vector_length != VECTOR_LENGTH:
raise Exception(
f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
"Cannot change vector size after initialization without migrating the data."
f'VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. '
'Cannot change vector size after initialization without migrating the data.'
)
else:
raise Exception(
"The 'vector' column does not exist in the 'document_chunk' table."
)
raise Exception("The 'vector' column does not exist in the 'document_chunk' table.")
def adjust_vector_length(self, vector: List[float]) -> List[float]:
# Adjust vector to have length VECTOR_LENGTH
@@ -294,10 +286,10 @@ class PgvectorClient(VectorDBBase):
try:
if PGVECTOR_PGCRYPTO:
for item in items:
vector = self.adjust_vector_length(item["vector"])
vector = self.adjust_vector_length(item['vector'])
# Use raw SQL for BYTEA/pgcrypto
# Ensure metadata is converted to its JSON text representation
json_metadata = json.dumps(item["metadata"])
json_metadata = json.dumps(item['metadata'])
self.session.execute(
text("""
INSERT INTO document_chunk
@@ -310,12 +302,12 @@ class PgvectorClient(VectorDBBase):
ON CONFLICT (id) DO NOTHING
"""),
{
"id": item["id"],
"vector": vector,
"collection_name": collection_name,
"text": item["text"],
"metadata_text": json_metadata,
"key": PGVECTOR_PGCRYPTO_KEY,
'id': item['id'],
'vector': vector,
'collection_name': collection_name,
'text': item['text'],
'metadata_text': json_metadata,
'key': PGVECTOR_PGCRYPTO_KEY,
},
)
self.session.commit()
@@ -324,31 +316,29 @@ class PgvectorClient(VectorDBBase):
else:
new_items = []
for item in items:
vector = self.adjust_vector_length(item["vector"])
vector = self.adjust_vector_length(item['vector'])
new_chunk = DocumentChunk(
id=item["id"],
id=item['id'],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=process_metadata(item["metadata"]),
text=item['text'],
vmetadata=process_metadata(item['metadata']),
)
new_items.append(new_chunk)
self.session.bulk_save_objects(new_items)
self.session.commit()
log.info(
f"Inserted {len(new_items)} items into collection '{collection_name}'."
)
log.info(f"Inserted {len(new_items)} items into collection '{collection_name}'.")
except Exception as e:
self.session.rollback()
log.exception(f"Error during insert: {e}")
log.exception(f'Error during insert: {e}')
raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
try:
if PGVECTOR_PGCRYPTO:
for item in items:
vector = self.adjust_vector_length(item["vector"])
json_metadata = json.dumps(item["metadata"])
vector = self.adjust_vector_length(item['vector'])
json_metadata = json.dumps(item['metadata'])
self.session.execute(
text("""
INSERT INTO document_chunk
@@ -365,47 +355,39 @@ class PgvectorClient(VectorDBBase):
vmetadata = EXCLUDED.vmetadata
"""),
{
"id": item["id"],
"vector": vector,
"collection_name": collection_name,
"text": item["text"],
"metadata_text": json_metadata,
"key": PGVECTOR_PGCRYPTO_KEY,
'id': item['id'],
'vector': vector,
'collection_name': collection_name,
'text': item['text'],
'metadata_text': json_metadata,
'key': PGVECTOR_PGCRYPTO_KEY,
},
)
self.session.commit()
log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
else:
for item in items:
vector = self.adjust_vector_length(item["vector"])
existing = (
self.session.query(DocumentChunk)
.filter(DocumentChunk.id == item["id"])
.first()
)
vector = self.adjust_vector_length(item['vector'])
existing = self.session.query(DocumentChunk).filter(DocumentChunk.id == item['id']).first()
if existing:
existing.vector = vector
existing.text = item["text"]
existing.vmetadata = process_metadata(item["metadata"])
existing.collection_name = (
collection_name # Update collection_name if necessary
)
existing.text = item['text']
existing.vmetadata = process_metadata(item['metadata'])
existing.collection_name = collection_name # Update collection_name if necessary
else:
new_chunk = DocumentChunk(
id=item["id"],
id=item['id'],
vector=vector,
collection_name=collection_name,
text=item["text"],
vmetadata=process_metadata(item["metadata"]),
text=item['text'],
vmetadata=process_metadata(item['metadata']),
)
self.session.add(new_chunk)
self.session.commit()
log.info(
f"Upserted {len(items)} items into collection '{collection_name}'."
)
log.info(f"Upserted {len(items)} items into collection '{collection_name}'.")
except Exception as e:
self.session.rollback()
log.exception(f"Error during upsert: {e}")
log.exception(f'Error during upsert: {e}')
raise
def search(
@@ -427,38 +409,26 @@ class PgvectorClient(VectorDBBase):
return cast(array(vector), VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
# Create the values for query vectors
qid_col = column("qid", Integer)
q_vector_col = column("q_vector", VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
qid_col = column('qid', Integer)
q_vector_col = column('q_vector', VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
query_vectors = (
values(qid_col, q_vector_col)
.data(
[(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
)
.alias("query_vectors")
.data([(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)])
.alias('query_vectors')
)
result_fields = [
DocumentChunk.id,
]
if PGVECTOR_PGCRYPTO:
result_fields.append(pgcrypto_decrypt(DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text).label('text'))
result_fields.append(
pgcrypto_decrypt(
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
).label("text")
)
result_fields.append(
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
).label("vmetadata")
pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB).label('vmetadata')
)
else:
result_fields.append(DocumentChunk.text)
result_fields.append(DocumentChunk.vmetadata)
result_fields.append(
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
"distance"
)
)
result_fields.append((DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label('distance'))
# Build the lateral subquery for each query vector
where_clauses = [DocumentChunk.collection_name == collection_name]
@@ -466,9 +436,9 @@ class PgvectorClient(VectorDBBase):
# Apply metadata filter if provided
if filter:
for key, value in filter.items():
if isinstance(value, dict) and "$in" in value:
if isinstance(value, dict) and '$in' in value:
# Handle $in operator: {"field": {"$in": [values]}}
in_values = value["$in"]
in_values = value['$in']
if PGVECTOR_PGCRYPTO:
where_clauses.append(
pgcrypto_decrypt(
@@ -478,11 +448,7 @@ class PgvectorClient(VectorDBBase):
)[key].astext.in_([str(v) for v in in_values])
)
else:
where_clauses.append(
DocumentChunk.vmetadata[key].astext.in_(
[str(v) for v in in_values]
)
)
where_clauses.append(DocumentChunk.vmetadata[key].astext.in_([str(v) for v in in_values]))
else:
# Handle simple equality: {"field": "value"}
if PGVECTOR_PGCRYPTO:
@@ -495,20 +461,16 @@ class PgvectorClient(VectorDBBase):
== str(value)
)
else:
where_clauses.append(
DocumentChunk.vmetadata[key].astext == str(value)
)
where_clauses.append(DocumentChunk.vmetadata[key].astext == str(value))
subq = (
select(*result_fields)
.where(*where_clauses)
.order_by(
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
)
.order_by((DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)))
)
if limit is not None:
subq = subq.limit(limit)
subq = subq.lateral("result")
subq = subq.lateral('result')
# Build the main query by joining query_vectors and the lateral subquery
stmt = (
@@ -550,17 +512,13 @@ class PgvectorClient(VectorDBBase):
metadatas[qid].append(row.vmetadata)
self.session.rollback() # read-only transaction
return SearchResult(
ids=ids, distances=distances, documents=documents, metadatas=metadatas
)
return SearchResult(ids=ids, distances=distances, documents=documents, metadatas=metadatas)
except Exception as e:
self.session.rollback()
log.exception(f"Error during search: {e}")
log.exception(f'Error during search: {e}')
return None
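    # Minimal usage sketch (assumes a constructed PgvectorClient and placeholder
    # vector values): one query vector against a collection, reading the per-query
    # result lists back out.
    #   result = client.search('docs', vectors=[[0.1] * VECTOR_LENGTH], limit=5)
    #   if result:
    #       top_ids, top_distances = result.ids[0], result.distances[0]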
def query(
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
) -> Optional[GetResult]:
def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
try:
if PGVECTOR_PGCRYPTO:
# Build where clause for vmetadata filter
@@ -568,32 +526,22 @@ class PgvectorClient(VectorDBBase):
for key, value in filter.items():
# decrypt then check key: JSON filter after decryption
where_clauses.append(
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
)[key].astext
pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB)[key].astext
== str(value)
)
stmt = select(
DocumentChunk.id,
pgcrypto_decrypt(
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
).label("text"),
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
).label("vmetadata"),
pgcrypto_decrypt(DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text).label('text'),
pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB).label('vmetadata'),
).where(*where_clauses)
if limit is not None:
stmt = stmt.limit(limit)
results = self.session.execute(stmt).all()
else:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
for key, value in filter.items():
query = query.filter(
DocumentChunk.vmetadata[key].astext == str(value)
)
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
if limit is not None:
query = query.limit(limit)
@@ -615,22 +563,16 @@ class PgvectorClient(VectorDBBase):
)
except Exception as e:
self.session.rollback()
log.exception(f"Error during query: {e}")
log.exception(f'Error during query: {e}')
return None
def get(
self, collection_name: str, limit: Optional[int] = None
) -> Optional[GetResult]:
def get(self, collection_name: str, limit: Optional[int] = None) -> Optional[GetResult]:
try:
if PGVECTOR_PGCRYPTO:
stmt = select(
DocumentChunk.id,
pgcrypto_decrypt(
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
).label("text"),
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
).label("vmetadata"),
pgcrypto_decrypt(DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text).label('text'),
pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB).label('vmetadata'),
).where(DocumentChunk.collection_name == collection_name)
if limit is not None:
stmt = stmt.limit(limit)
@@ -639,10 +581,7 @@ class PgvectorClient(VectorDBBase):
documents = [[row.text for row in results]]
metadatas = [[row.vmetadata for row in results]]
else:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
if limit is not None:
query = query.limit(limit)
@@ -659,7 +598,7 @@ class PgvectorClient(VectorDBBase):
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
except Exception as e:
self.session.rollback()
log.exception(f"Error during get: {e}")
log.exception(f'Error during get: {e}')
return None
def delete(
@@ -676,43 +615,35 @@ class PgvectorClient(VectorDBBase):
if filter:
for key, value in filter.items():
wheres.append(
pgcrypto_decrypt(
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
)[key].astext
pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB)[key].astext
== str(value)
)
stmt = DocumentChunk.__table__.delete().where(*wheres)
result = self.session.execute(stmt)
deleted = result.rowcount
else:
query = self.session.query(DocumentChunk).filter(
DocumentChunk.collection_name == collection_name
)
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
if ids:
query = query.filter(DocumentChunk.id.in_(ids))
if filter:
for key, value in filter.items():
query = query.filter(
DocumentChunk.vmetadata[key].astext == str(value)
)
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
deleted = query.delete(synchronize_session=False)
self.session.commit()
log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
except Exception as e:
self.session.rollback()
log.exception(f"Error during delete: {e}")
log.exception(f'Error during delete: {e}')
raise
def reset(self) -> None:
try:
deleted = self.session.query(DocumentChunk).delete()
self.session.commit()
log.info(
f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
)
log.info(f"Reset complete. Deleted {deleted} items from 'document_chunk' table.")
except Exception as e:
self.session.rollback()
log.exception(f"Error during reset: {e}")
log.exception(f'Error during reset: {e}')
raise
def close(self) -> None:
@@ -721,16 +652,14 @@ class PgvectorClient(VectorDBBase):
def has_collection(self, collection_name: str) -> bool:
try:
exists = (
self.session.query(DocumentChunk)
.filter(DocumentChunk.collection_name == collection_name)
.first()
self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name).first()
is not None
)
self.session.rollback() # read-only transaction
return exists
except Exception as e:
self.session.rollback()
log.exception(f"Error checking collection existence: {e}")
log.exception(f'Error checking collection existence: {e}')
return False
def delete_collection(self, collection_name: str) -> None:

View File

@@ -45,7 +45,7 @@ log = logging.getLogger(__name__)
class PineconeClient(VectorDBBase):
def __init__(self):
self.collection_prefix = "open-webui"
self.collection_prefix = 'open-webui'
# Validate required configuration
self._validate_config()
@@ -67,7 +67,7 @@ class PineconeClient(VectorDBBase):
timeout=30, # Reasonable timeout for operations
)
self.using_grpc = True
log.info("Using Pinecone gRPC client for optimal performance")
log.info('Using Pinecone gRPC client for optimal performance')
else:
# Fallback to HTTP client with enhanced connection pooling
self.client = Pinecone(
@@ -76,7 +76,7 @@ class PineconeClient(VectorDBBase):
timeout=30, # Reasonable timeout for operations
)
self.using_grpc = False
log.info("Using Pinecone HTTP client (gRPC not available)")
log.info('Using Pinecone HTTP client (gRPC not available)')
# Persistent executor for batch operations
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
@@ -88,20 +88,18 @@ class PineconeClient(VectorDBBase):
"""Validate that all required configuration variables are set."""
missing_vars = []
if not PINECONE_API_KEY:
missing_vars.append("PINECONE_API_KEY")
missing_vars.append('PINECONE_API_KEY')
if not PINECONE_ENVIRONMENT:
missing_vars.append("PINECONE_ENVIRONMENT")
missing_vars.append('PINECONE_ENVIRONMENT')
if not PINECONE_INDEX_NAME:
missing_vars.append("PINECONE_INDEX_NAME")
missing_vars.append('PINECONE_INDEX_NAME')
if not PINECONE_DIMENSION:
missing_vars.append("PINECONE_DIMENSION")
missing_vars.append('PINECONE_DIMENSION')
if not PINECONE_CLOUD:
missing_vars.append("PINECONE_CLOUD")
missing_vars.append('PINECONE_CLOUD')
if missing_vars:
raise ValueError(
f"Required configuration missing: {', '.join(missing_vars)}"
)
raise ValueError(f'Required configuration missing: {", ".join(missing_vars)}')
def _initialize_index(self) -> None:
"""Initialize the Pinecone index."""
@@ -126,8 +124,8 @@ class PineconeClient(VectorDBBase):
)
except Exception as e:
log.error(f"Failed to initialize Pinecone index: {e}")
raise RuntimeError(f"Failed to initialize Pinecone index: {e}")
log.error(f'Failed to initialize Pinecone index: {e}')
raise RuntimeError(f'Failed to initialize Pinecone index: {e}')
def _retry_pinecone_operation(self, operation_func, max_retries=3):
"""Retry Pinecone operations with exponential backoff for rate limits and network issues."""
@@ -140,18 +138,18 @@ class PineconeClient(VectorDBBase):
is_retryable = any(
keyword in error_str
for keyword in [
"rate limit",
"quota",
"timeout",
"network",
"connection",
"unavailable",
"internal error",
"429",
"500",
"502",
"503",
"504",
'rate limit',
'quota',
'timeout',
'network',
'connection',
'unavailable',
'internal error',
'429',
'500',
'502',
'503',
'504',
]
)
@@ -162,45 +160,42 @@ class PineconeClient(VectorDBBase):
# Exponential backoff with jitter
delay = (2**attempt) + random.uniform(0, 1)
log.warning(
f"Pinecone operation failed (attempt {attempt + 1}/{max_retries}), "
f"retrying in {delay:.2f}s: {e}"
f'Pinecone operation failed (attempt {attempt + 1}/{max_retries}), retrying in {delay:.2f}s: {e}'
)
time.sleep(delay)
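    # Worked example of the backoff above with max_retries=3: after the first failed
    # attempt the sleep is 2**0 + jitter (roughly 1-2s), after the second 2**1 + jitter
    # (roughly 2-3s); the random.uniform(0, 1) jitter keeps concurrent clients from
    # retrying in lockstep.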
def _create_points(
self, items: List[VectorItem], collection_name_with_prefix: str
) -> List[Dict[str, Any]]:
def _create_points(self, items: List[VectorItem], collection_name_with_prefix: str) -> List[Dict[str, Any]]:
"""Convert VectorItem objects to Pinecone point format."""
points = []
for item in items:
# Start with any existing metadata or an empty dict
metadata = item.get("metadata", {}).copy() if item.get("metadata") else {}
metadata = item.get('metadata', {}).copy() if item.get('metadata') else {}
# Add text to metadata if available
if "text" in item:
metadata["text"] = item["text"]
if 'text' in item:
metadata['text'] = item['text']
# Always add collection_name to metadata for filtering
metadata["collection_name"] = collection_name_with_prefix
metadata['collection_name'] = collection_name_with_prefix
point = {
"id": item["id"],
"values": item["vector"],
"metadata": process_metadata(metadata),
'id': item['id'],
'values': item['vector'],
'metadata': process_metadata(metadata),
}
points.append(point)
return points
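    # Shape of one resulting point, with abbreviated values and a made-up prefixed
    # collection name:
    #   {
    #       'id': 'chunk-1',
    #       'values': [0.12, -0.03, ...],
    #       'metadata': {'text': '...', 'collection_name': 'open-webui_docs', ...},
    #   }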
def _get_collection_name_with_prefix(self, collection_name: str) -> str:
"""Get the collection name with prefix."""
return f"{self.collection_prefix}_{collection_name}"
return f'{self.collection_prefix}_{collection_name}'
def _normalize_distance(self, score: float) -> float:
"""Normalize distance score based on the metric used."""
if self.metric.lower() == "cosine":
if self.metric.lower() == 'cosine':
# Cosine similarity ranges from -1 to 1, normalize to 0 to 1
return (score + 1.0) / 2.0
elif self.metric.lower() in ["euclidean", "dotproduct"]:
elif self.metric.lower() in ['euclidean', 'dotproduct']:
# These are already suitable for ranking (smaller is better for Euclidean)
return score
else:
@@ -214,68 +209,56 @@ class PineconeClient(VectorDBBase):
metadatas = []
for match in matches:
metadata = getattr(match, "metadata", {}) or {}
ids.append(match.id if hasattr(match, "id") else match["id"])
documents.append(metadata.get("text", ""))
metadata = getattr(match, 'metadata', {}) or {}
ids.append(match.id if hasattr(match, 'id') else match['id'])
documents.append(metadata.get('text', ''))
metadatas.append(metadata)
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
'ids': [ids],
'documents': [documents],
'metadatas': [metadatas],
}
)
def has_collection(self, collection_name: str) -> bool:
"""Check if a collection exists by searching for at least one item."""
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
try:
# Search for at least 1 item with this collection name in metadata
response = self.index.query(
vector=[0.0] * self.dimension, # dummy vector
top_k=1,
filter={"collection_name": collection_name_with_prefix},
filter={'collection_name': collection_name_with_prefix},
include_metadata=False,
)
matches = getattr(response, "matches", []) or []
matches = getattr(response, 'matches', []) or []
return len(matches) > 0
except Exception as e:
log.exception(
f"Error checking collection '{collection_name_with_prefix}': {e}"
)
log.exception(f"Error checking collection '{collection_name_with_prefix}': {e}")
return False
def delete_collection(self, collection_name: str) -> None:
"""Delete a collection by removing all vectors with the collection name in metadata."""
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
try:
self.index.delete(filter={"collection_name": collection_name_with_prefix})
log.info(
f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)."
)
self.index.delete(filter={'collection_name': collection_name_with_prefix})
log.info(f"Collection '{collection_name_with_prefix}' deleted (all vectors removed).")
except Exception as e:
log.warning(
f"Failed to delete collection '{collection_name_with_prefix}': {e}"
)
log.warning(f"Failed to delete collection '{collection_name_with_prefix}': {e}")
raise
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
"""Insert vectors into a collection."""
if not items:
log.warning("No items to insert")
log.warning('No items to insert')
return
start_time = time.time()
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
points = self._create_points(items, collection_name_with_prefix)
# Parallelize batch inserts for performance
@@ -288,26 +271,23 @@ class PineconeClient(VectorDBBase):
try:
future.result()
except Exception as e:
log.error(f"Error inserting batch: {e}")
log.error(f'Error inserting batch: {e}')
raise
elapsed = time.time() - start_time
log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds")
log.debug(f'Insert of {len(points)} vectors took {elapsed:.2f} seconds')
log.info(
f"Successfully inserted {len(points)} vectors in parallel batches "
f"into '{collection_name_with_prefix}'"
f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
)
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
"""Upsert (insert or update) vectors into a collection."""
if not items:
log.warning("No items to upsert")
log.warning('No items to upsert')
return
start_time = time.time()
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
points = self._create_points(items, collection_name_with_prefix)
# Parallelize batch upserts for performance
@@ -320,78 +300,53 @@ class PineconeClient(VectorDBBase):
try:
future.result()
except Exception as e:
log.error(f"Error upserting batch: {e}")
log.error(f'Error upserting batch: {e}')
raise
elapsed = time.time() - start_time
log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds")
log.debug(f'Upsert of {len(points)} vectors took {elapsed:.2f} seconds')
log.info(
f"Successfully upserted {len(points)} vectors in parallel batches "
f"into '{collection_name_with_prefix}'"
f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
)
async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None:
"""Async version of insert using asyncio and run_in_executor for improved performance."""
if not items:
log.warning("No items to insert")
log.warning('No items to insert')
return
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
points = self._create_points(items, collection_name_with_prefix)
# Create batches
batches = [
points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
]
batches = [points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)]
loop = asyncio.get_event_loop()
tasks = [
loop.run_in_executor(
None, functools.partial(self.index.upsert, vectors=batch)
)
for batch in batches
]
tasks = [loop.run_in_executor(None, functools.partial(self.index.upsert, vectors=batch)) for batch in batches]
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
log.error(f"Error in async insert batch: {result}")
log.error(f'Error in async insert batch: {result}')
raise result
log.info(
f"Successfully async inserted {len(points)} vectors in batches "
f"into '{collection_name_with_prefix}'"
)
log.info(f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'")
async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None:
"""Async version of upsert using asyncio and run_in_executor for improved performance."""
if not items:
log.warning("No items to upsert")
log.warning('No items to upsert')
return
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
points = self._create_points(items, collection_name_with_prefix)
# Create batches
batches = [
points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
]
batches = [points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)]
loop = asyncio.get_event_loop()
tasks = [
loop.run_in_executor(
None, functools.partial(self.index.upsert, vectors=batch)
)
for batch in batches
]
tasks = [loop.run_in_executor(None, functools.partial(self.index.upsert, vectors=batch)) for batch in batches]
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
log.error(f"Error in async upsert batch: {result}")
log.error(f'Error in async upsert batch: {result}')
raise result
log.info(
f"Successfully async upserted {len(points)} vectors in batches "
f"into '{collection_name_with_prefix}'"
)
log.info(f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'")
def search(
self,
@@ -402,12 +357,10 @@ class PineconeClient(VectorDBBase):
) -> Optional[SearchResult]:
"""Search for similar vectors in a collection."""
if not vectors or not vectors[0]:
log.warning("No vectors provided for search")
log.warning('No vectors provided for search')
return None
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
if limit is None or limit <= 0:
limit = NO_LIMIT
@@ -421,10 +374,10 @@ class PineconeClient(VectorDBBase):
vector=query_vector,
top_k=limit,
include_metadata=True,
filter={"collection_name": collection_name_with_prefix},
filter={'collection_name': collection_name_with_prefix},
)
matches = getattr(query_response, "matches", []) or []
matches = getattr(query_response, 'matches', []) or []
if not matches:
# Return empty result if no matches
return SearchResult(
@@ -438,12 +391,7 @@ class PineconeClient(VectorDBBase):
get_result = self._result_to_get_result(matches)
# Calculate normalized distances based on metric
distances = [
[
self._normalize_distance(getattr(match, "score", 0.0))
for match in matches
]
]
distances = [[self._normalize_distance(getattr(match, 'score', 0.0)) for match in matches]]
return SearchResult(
ids=get_result.ids,
@@ -455,13 +403,9 @@ class PineconeClient(VectorDBBase):
log.error(f"Error searching in '{collection_name_with_prefix}': {e}")
return None
def query(
self, collection_name: str, filter: Dict, limit: Optional[int] = None
) -> Optional[GetResult]:
def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]:
"""Query vectors by metadata filter."""
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
if limit is None or limit <= 0:
limit = NO_LIMIT
@@ -471,7 +415,7 @@ class PineconeClient(VectorDBBase):
zero_vector = [0.0] * self.dimension
# Combine user filter with collection_name
pinecone_filter = {"collection_name": collection_name_with_prefix}
pinecone_filter = {'collection_name': collection_name_with_prefix}
if filter:
pinecone_filter.update(filter)
@@ -483,7 +427,7 @@ class PineconeClient(VectorDBBase):
include_metadata=True,
)
matches = getattr(query_response, "matches", []) or []
matches = getattr(query_response, 'matches', []) or []
return self._result_to_get_result(matches)
except Exception as e:
@@ -492,9 +436,7 @@ class PineconeClient(VectorDBBase):
def get(self, collection_name: str) -> Optional[GetResult]:
"""Get all vectors in a collection."""
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
try:
# Use a zero vector for fetching all entries
@@ -505,10 +447,10 @@ class PineconeClient(VectorDBBase):
vector=zero_vector,
top_k=NO_LIMIT,
include_metadata=True,
filter={"collection_name": collection_name_with_prefix},
filter={'collection_name': collection_name_with_prefix},
)
matches = getattr(query_response, "matches", []) or []
matches = getattr(query_response, 'matches', []) or []
return self._result_to_get_result(matches)
except Exception as e:
@@ -522,9 +464,7 @@ class PineconeClient(VectorDBBase):
filter: Optional[Dict] = None,
) -> None:
"""Delete vectors by IDs or filter."""
collection_name_with_prefix = self._get_collection_name_with_prefix(
collection_name
)
collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
try:
if ids:
@@ -534,28 +474,20 @@ class PineconeClient(VectorDBBase):
# Note: When deleting by ID, we can't filter by collection_name
# This is a limitation of Pinecone - be careful with ID uniqueness
self.index.delete(ids=batch_ids)
log.debug(
f"Deleted batch of {len(batch_ids)} vectors by ID "
f"from '{collection_name_with_prefix}'"
)
log.info(
f"Successfully deleted {len(ids)} vectors by ID "
f"from '{collection_name_with_prefix}'"
)
log.debug(f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'")
log.info(f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'")
elif filter:
# Combine user filter with collection_name
pinecone_filter = {"collection_name": collection_name_with_prefix}
pinecone_filter = {'collection_name': collection_name_with_prefix}
if filter:
pinecone_filter.update(filter)
# Delete by metadata filter
self.index.delete(filter=pinecone_filter)
log.info(
f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'"
)
log.info(f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'")
else:
log.warning("No ids or filter provided for delete operation")
log.warning('No ids or filter provided for delete operation')
except Exception as e:
log.error(f"Error deleting from collection '{collection_name}': {e}")
@@ -565,9 +497,9 @@ class PineconeClient(VectorDBBase):
"""Reset the database by deleting all collections."""
try:
self.index.delete(delete_all=True)
log.info("All vectors successfully deleted from the index.")
log.info('All vectors successfully deleted from the index.')
except Exception as e:
log.error(f"Failed to reset Pinecone index: {e}")
log.error(f'Failed to reset Pinecone index: {e}')
raise
def close(self):
@@ -576,7 +508,7 @@ class PineconeClient(VectorDBBase):
# The new Pinecone client doesn't need explicit closing
pass
except Exception as e:
log.warning(f"Failed to clean up Pinecone resources: {e}")
log.warning(f'Failed to clean up Pinecone resources: {e}')
self._executor.shutdown(wait=True)
def __enter__(self):

View File

@@ -76,19 +76,19 @@ class QdrantClient(VectorDBBase):
for point in points:
payload = point.payload
ids.append(point.id)
documents.append(payload["text"])
metadatas.append(payload["metadata"])
documents.append(payload['text'])
metadatas.append(payload['metadata'])
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
'ids': [ids],
'documents': [documents],
'metadatas': [metadatas],
}
)
def _create_collection(self, collection_name: str, dimension: int):
collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
collection_name_with_prefix = f'{self.collection_prefix}_{collection_name}'
self.client.create_collection(
collection_name=collection_name_with_prefix,
vectors_config=models.VectorParams(
@@ -104,7 +104,7 @@ class QdrantClient(VectorDBBase):
# Create payload indexes for efficient filtering
self.client.create_payload_index(
collection_name=collection_name_with_prefix,
field_name="metadata.hash",
field_name='metadata.hash',
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
is_tenant=False,
@@ -113,40 +113,34 @@ class QdrantClient(VectorDBBase):
)
self.client.create_payload_index(
collection_name=collection_name_with_prefix,
field_name="metadata.file_id",
field_name='metadata.file_id',
field_schema=models.KeywordIndexParams(
type=models.KeywordIndexType.KEYWORD,
is_tenant=False,
on_disk=self.QDRANT_ON_DISK,
),
)
log.info(f"collection {collection_name_with_prefix} successfully created!")
log.info(f'collection {collection_name_with_prefix} successfully created!')
def _create_collection_if_not_exists(self, collection_name, dimension):
if not self.has_collection(collection_name=collection_name):
self._create_collection(
collection_name=collection_name, dimension=dimension
)
self._create_collection(collection_name=collection_name, dimension=dimension)
def _create_points(self, items: list[VectorItem]):
return [
PointStruct(
id=item["id"],
vector=item["vector"],
payload={"text": item["text"], "metadata": item["metadata"]},
id=item['id'],
vector=item['vector'],
payload={'text': item['text'], 'metadata': item['metadata']},
)
for item in items
]
def has_collection(self, collection_name: str) -> bool:
return self.client.collection_exists(
f"{self.collection_prefix}_{collection_name}"
)
return self.client.collection_exists(f'{self.collection_prefix}_{collection_name}')
def delete_collection(self, collection_name: str):
return self.client.delete_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
)
return self.client.delete_collection(collection_name=f'{self.collection_prefix}_{collection_name}')
def search(
self,
@@ -160,7 +154,7 @@ class QdrantClient(VectorDBBase):
limit = NO_LIMIT # otherwise qdrant would set limit to 10!
query_response = self.client.query_points(
collection_name=f"{self.collection_prefix}_{collection_name}",
collection_name=f'{self.collection_prefix}_{collection_name}',
query=vectors[0],
limit=limit,
)
@@ -184,13 +178,11 @@ class QdrantClient(VectorDBBase):
field_conditions = []
for key, value in filter.items():
field_conditions.append(
models.FieldCondition(
key=f"metadata.{key}", match=models.MatchValue(value=value)
)
models.FieldCondition(key=f'metadata.{key}', match=models.MatchValue(value=value))
)
points = self.client.scroll(
collection_name=f"{self.collection_prefix}_{collection_name}",
collection_name=f'{self.collection_prefix}_{collection_name}',
scroll_filter=models.Filter(should=field_conditions),
limit=limit,
)
@@ -202,22 +194,22 @@ class QdrantClient(VectorDBBase):
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
points = self.client.scroll(
collection_name=f"{self.collection_prefix}_{collection_name}",
collection_name=f'{self.collection_prefix}_{collection_name}',
limit=NO_LIMIT, # otherwise qdrant would set limit to 10!
)
return self._result_to_get_result(points[0])
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created.
self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
self._create_collection_if_not_exists(collection_name, len(items[0]['vector']))
points = self._create_points(items)
self.client.upload_points(f"{self.collection_prefix}_{collection_name}", points)
self.client.upload_points(f'{self.collection_prefix}_{collection_name}', points)
def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
self._create_collection_if_not_exists(collection_name, len(items[0]['vector']))
points = self._create_points(items)
return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points)
return self.client.upsert(f'{self.collection_prefix}_{collection_name}', points)
def delete(
self,
@@ -230,26 +222,28 @@ class QdrantClient(VectorDBBase):
if ids:
for id_value in ids:
field_conditions.append(
models.FieldCondition(
key="metadata.id",
match=models.MatchValue(value=id_value),
                    models.FieldCondition(
                        key='metadata.id',
                        match=models.MatchValue(value=id_value),
),
)
elif filter:
for key, value in filter.items():
field_conditions.append(
models.FieldCondition(
key=f"metadata.{key}",
match=models.MatchValue(value=value),
                    models.FieldCondition(
                        key=f'metadata.{key}',
                        match=models.MatchValue(value=value),
),
)
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
points_selector=models.FilterSelector(
filter=models.Filter(must=field_conditions)
),
collection_name=f'{self.collection_prefix}_{collection_name}',
points_selector=models.FilterSelector(filter=models.Filter(must=field_conditions)),
)
def reset(self):

View File

@@ -29,22 +29,18 @@ from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
NO_LIMIT = 999999999
TENANT_ID_FIELD = "tenant_id"
TENANT_ID_FIELD = 'tenant_id'
DEFAULT_DIMENSION = 384
log = logging.getLogger(__name__)
def _tenant_filter(tenant_id: str) -> models.FieldCondition:
return models.FieldCondition(
key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
)
return models.FieldCondition(key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id))
def _metadata_filter(key: str, value: Any) -> models.FieldCondition:
return models.FieldCondition(
key=f"metadata.{key}", match=models.MatchValue(value=value)
)
return models.FieldCondition(key=f'metadata.{key}', match=models.MatchValue(value=value))
class QdrantClient(VectorDBBase):
@@ -59,9 +55,7 @@ class QdrantClient(VectorDBBase):
self.QDRANT_HNSW_M = QDRANT_HNSW_M
if not self.QDRANT_URI:
raise ValueError(
"QDRANT_URI is not set. Please configure it in the environment variables."
)
raise ValueError('QDRANT_URI is not set. Please configure it in the environment variables.')
# Unified handling for either scheme
parsed = urlparse(self.QDRANT_URI)
@@ -86,19 +80,19 @@ class QdrantClient(VectorDBBase):
)
# Main collection types for multi-tenancy
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
self.FILE_COLLECTION = f"{self.collection_prefix}_files"
self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web-search"
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
self.MEMORY_COLLECTION = f'{self.collection_prefix}_memories'
self.KNOWLEDGE_COLLECTION = f'{self.collection_prefix}_knowledge'
self.FILE_COLLECTION = f'{self.collection_prefix}_files'
self.WEB_SEARCH_COLLECTION = f'{self.collection_prefix}_web-search'
self.HASH_BASED_COLLECTION = f'{self.collection_prefix}_hash-based'
def _result_to_get_result(self, points) -> GetResult:
ids, documents, metadatas = [], [], []
for point in points:
payload = point.payload
ids.append(point.id)
documents.append(payload["text"])
metadatas.append(payload["metadata"])
documents.append(payload['text'])
metadatas.append(payload['metadata'])
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
@@ -118,29 +112,25 @@ class QdrantClient(VectorDBBase):
# Check for user memory collections
tenant_id = collection_name
if collection_name.startswith("user-memory-"):
if collection_name.startswith('user-memory-'):
return self.MEMORY_COLLECTION, tenant_id
# Check for file collections
elif collection_name.startswith("file-"):
elif collection_name.startswith('file-'):
return self.FILE_COLLECTION, tenant_id
# Check for web search collections
elif collection_name.startswith("web-search-"):
elif collection_name.startswith('web-search-'):
return self.WEB_SEARCH_COLLECTION, tenant_id
# Handle hash-based collections (YouTube and web URLs)
elif len(collection_name) == 63 and all(
c in "0123456789abcdef" for c in collection_name
):
elif len(collection_name) == 63 and all(c in '0123456789abcdef' for c in collection_name):
return self.HASH_BASED_COLLECTION, tenant_id
else:
return self.KNOWLEDGE_COLLECTION, tenant_id
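    # Examples of the routing above (names are made up):
    #   'user-memory-42'        -> (self.MEMORY_COLLECTION, 'user-memory-42')
    #   'file-abc123'           -> (self.FILE_COLLECTION, 'file-abc123')
    #   'web-search-xyz'        -> (self.WEB_SEARCH_COLLECTION, 'web-search-xyz')
    #   63-char lowercase hex   -> (self.HASH_BASED_COLLECTION, that same hash)
    #   anything else           -> (self.KNOWLEDGE_COLLECTION, collection_name)
    # In every case the original collection_name is carried along as the tenant_id.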
def _create_multi_tenant_collection(
self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
):
def _create_multi_tenant_collection(self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION):
"""
Creates a collection with multi-tenancy configuration and payload indexes for tenant_id and metadata fields.
"""
@@ -158,9 +148,7 @@ class QdrantClient(VectorDBBase):
m=0,
),
)
log.info(
f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
)
log.info(f'Multi-tenant collection {mt_collection_name} created with dimension {dimension}!')
self.client.create_payload_index(
collection_name=mt_collection_name,
@@ -172,7 +160,7 @@ class QdrantClient(VectorDBBase):
),
)
for field in ("metadata.hash", "metadata.file_id"):
for field in ('metadata.hash', 'metadata.file_id'):
self.client.create_payload_index(
collection_name=mt_collection_name,
field_name=field,
@@ -182,28 +170,24 @@ class QdrantClient(VectorDBBase):
),
)
def _create_points(
self, items: List[VectorItem], tenant_id: str
) -> List[PointStruct]:
def _create_points(self, items: List[VectorItem], tenant_id: str) -> List[PointStruct]:
"""
Create point structs from vector items with tenant ID.
"""
return [
PointStruct(
id=item["id"],
vector=item["vector"],
id=item['id'],
vector=item['vector'],
payload={
"text": item["text"],
"metadata": item["metadata"],
'text': item['text'],
'metadata': item['metadata'],
TENANT_ID_FIELD: tenant_id,
},
)
for item in items
]
def _ensure_collection(
self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
):
def _ensure_collection(self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION):
"""
Ensure the collection exists and payload indexes are created for tenant_id and metadata fields.
"""
@@ -246,15 +230,13 @@ class QdrantClient(VectorDBBase):
must_conditions = [_tenant_filter(tenant_id)]
should_conditions = []
if ids:
should_conditions = [_metadata_filter("id", id_value) for id_value in ids]
should_conditions = [_metadata_filter('id', id_value) for id_value in ids]
elif filter:
must_conditions += [_metadata_filter(k, v) for k, v in filter.items()]
return self.client.delete(
collection_name=mt_collection,
points_selector=models.FilterSelector(
filter=models.Filter(must=must_conditions, should=should_conditions)
),
points_selector=models.FilterSelector(filter=models.Filter(must=must_conditions, should=should_conditions)),
)
def search(
@@ -289,9 +271,7 @@ class QdrantClient(VectorDBBase):
distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]],
)
def query(
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
):
def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None):
"""
Query points with filters and tenant isolation.
"""
@@ -338,7 +318,7 @@ class QdrantClient(VectorDBBase):
if not self.client or not items:
return None
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
dimension = len(items[0]["vector"])
dimension = len(items[0]['vector'])
self._ensure_collection(mt_collection, dimension)
points = self._create_points(items, tenant_id)
self.client.upload_points(mt_collection, points)
@@ -372,7 +352,5 @@ class QdrantClient(VectorDBBase):
return None
self.client.delete(
collection_name=mt_collection,
points_selector=models.FilterSelector(
filter=models.Filter(must=[_tenant_filter(tenant_id)])
),
points_selector=models.FilterSelector(filter=models.Filter(must=[_tenant_filter(tenant_id)])),
)

View File

@@ -28,18 +28,16 @@ class S3VectorClient(VectorDBBase):
# Simple validation - log warnings instead of raising exceptions
if not self.bucket_name:
log.warning("S3_VECTOR_BUCKET_NAME not set - S3Vector will not work")
log.warning('S3_VECTOR_BUCKET_NAME not set - S3Vector will not work')
if not self.region:
log.warning("S3_VECTOR_REGION not set - S3Vector will not work")
log.warning('S3_VECTOR_REGION not set - S3Vector will not work')
if self.bucket_name and self.region:
try:
self.client = boto3.client("s3vectors", region_name=self.region)
log.info(
f"S3Vector client initialized for bucket '{self.bucket_name}' in region '{self.region}'"
)
self.client = boto3.client('s3vectors', region_name=self.region)
log.info(f"S3Vector client initialized for bucket '{self.bucket_name}' in region '{self.region}'")
except Exception as e:
log.error(f"Failed to initialize S3Vector client: {e}")
log.error(f'Failed to initialize S3Vector client: {e}')
self.client = None
else:
self.client = None
@@ -48,8 +46,8 @@ class S3VectorClient(VectorDBBase):
self,
index_name: str,
dimension: int,
data_type: str = "float32",
distance_metric: str = "cosine",
data_type: str = 'float32',
distance_metric: str = 'cosine',
) -> None:
"""
Create a new index in the S3 vector bucket for the given collection if it does not exist.
@@ -66,21 +64,17 @@ class S3VectorClient(VectorDBBase):
dimension=dimension,
distanceMetric=distance_metric,
metadataConfiguration={
"nonFilterableMetadataKeys": [
"text",
'nonFilterableMetadataKeys': [
'text',
]
},
)
log.info(
f"Created S3 index: {index_name} (dim={dimension}, type={data_type}, metric={distance_metric})"
)
log.info(f'Created S3 index: {index_name} (dim={dimension}, type={data_type}, metric={distance_metric})')
except Exception as e:
log.error(f"Error creating S3 index '{index_name}': {e}")
raise
def _filter_metadata(
self, metadata: Dict[str, Any], item_id: str
) -> Dict[str, Any]:
def _filter_metadata(self, metadata: Dict[str, Any], item_id: str) -> Dict[str, Any]:
"""
Filter vector metadata keys to comply with S3 Vector API limit of 10 keys maximum.
"""
@@ -89,16 +83,16 @@ class S3VectorClient(VectorDBBase):
# Keep only the first 10 keys, prioritizing important ones based on actual Open WebUI metadata
important_keys = [
"text", # The actual document content
"file_id", # File ID
"source", # Document source file
"title", # Document title
"page", # Page number
"total_pages", # Total pages in document
"embedding_config", # Embedding configuration
"created_by", # User who created it
"name", # Document name
"hash", # Content hash
'text', # The actual document content
'file_id', # File ID
'source', # Document source file
'title', # Document title
'page', # Page number
'total_pages', # Total pages in document
'embedding_config', # Embedding configuration
'created_by', # User who created it
'name', # Document name
'hash', # Content hash
]
filtered_metadata = {}
@@ -117,9 +111,7 @@ class S3VectorClient(VectorDBBase):
if len(filtered_metadata) >= 10:
break
log.warning(
f"Metadata for key '{item_id}' had {len(metadata)} keys, limited to 10 keys"
)
log.warning(f"Metadata for key '{item_id}' had {len(metadata)} keys, limited to 10 keys")
return filtered_metadata
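    # Rough illustration: given a hypothetical metadata dict with 12 keys, the keys
    # from important_keys that are actually present ('text', 'file_id', 'source', ...)
    # are copied first, the remaining slots up to the 10-key limit are topped up from
    # the leftover keys, and anything beyond that is dropped with the warning above.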
def has_collection(self, collection_name: str) -> bool:
@@ -128,9 +120,7 @@ class S3VectorClient(VectorDBBase):
This avoids pagination issues with list_indexes() and is significantly faster.
"""
try:
self.client.get_index(
vectorBucketName=self.bucket_name, indexName=collection_name
)
self.client.get_index(vectorBucketName=self.bucket_name, indexName=collection_name)
return True
except Exception as e:
log.error(f"Error checking if index '{collection_name}' exists: {e}")
@@ -142,16 +132,12 @@ class S3VectorClient(VectorDBBase):
"""
if not self.has_collection(collection_name):
log.warning(
f"Collection '{collection_name}' does not exist, nothing to delete"
)
log.warning(f"Collection '{collection_name}' does not exist, nothing to delete")
return
try:
log.info(f"Deleting collection '{collection_name}'")
self.client.delete_index(
vectorBucketName=self.bucket_name, indexName=collection_name
)
self.client.delete_index(vectorBucketName=self.bucket_name, indexName=collection_name)
log.info(f"Successfully deleted collection '{collection_name}'")
except Exception as e:
log.error(f"Error deleting collection '{collection_name}': {e}")
@@ -162,10 +148,10 @@ class S3VectorClient(VectorDBBase):
Insert vector items into the S3 Vector index. Create index if it does not exist.
"""
if not items:
log.warning("No items to insert")
log.warning('No items to insert')
return
dimension = len(items[0]["vector"])
dimension = len(items[0]['vector'])
try:
if not self.has_collection(collection_name):
@@ -173,36 +159,36 @@ class S3VectorClient(VectorDBBase):
self._create_index(
index_name=collection_name,
dimension=dimension,
data_type="float32",
distance_metric="cosine",
data_type='float32',
distance_metric='cosine',
)
# Prepare vectors for insertion
vectors = []
for item in items:
# Ensure vector data is in the correct format for S3 Vector API
vector_data = item["vector"]
vector_data = item['vector']
if isinstance(vector_data, list):
# Convert list to float32 values as required by S3 Vector API
vector_data = [float(x) for x in vector_data]
# Prepare metadata, ensuring the text field is preserved
metadata = item.get("metadata", {}).copy()
metadata = item.get('metadata', {}).copy()
# Add the text field to metadata so it's available for retrieval
metadata["text"] = item["text"]
metadata['text'] = item['text']
# Convert metadata to string format for consistency
metadata = process_metadata(metadata)
# Filter metadata to comply with S3 Vector API limit of 10 keys
metadata = self._filter_metadata(metadata, item["id"])
metadata = self._filter_metadata(metadata, item['id'])
vectors.append(
{
"key": item["id"],
"data": {"float32": vector_data},
"metadata": metadata,
'key': item['id'],
'data': {'float32': vector_data},
'metadata': metadata,
}
)
@@ -215,15 +201,11 @@ class S3VectorClient(VectorDBBase):
indexName=collection_name,
vectors=batch,
)
log.info(
f"Inserted batch {i//batch_size + 1}: {len(batch)} vectors into index '{collection_name}'."
)
log.info(f"Inserted batch {i // batch_size + 1}: {len(batch)} vectors into index '{collection_name}'.")
log.info(
f"Completed insertion of {len(vectors)} vectors into index '{collection_name}'."
)
log.info(f"Completed insertion of {len(vectors)} vectors into index '{collection_name}'.")
except Exception as e:
log.error(f"Error inserting vectors: {e}")
log.error(f'Error inserting vectors: {e}')
raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -231,49 +213,47 @@ class S3VectorClient(VectorDBBase):
Insert or update vector items in the S3 Vector index. Create index if it does not exist.
"""
if not items:
log.warning("No items to upsert")
log.warning('No items to upsert')
return
dimension = len(items[0]["vector"])
log.info(f"Upsert dimension: {dimension}")
dimension = len(items[0]['vector'])
log.info(f'Upsert dimension: {dimension}')
try:
if not self.has_collection(collection_name):
log.info(
f"Index '{collection_name}' does not exist. Creating index for upsert."
)
log.info(f"Index '{collection_name}' does not exist. Creating index for upsert.")
self._create_index(
index_name=collection_name,
dimension=dimension,
data_type="float32",
distance_metric="cosine",
data_type='float32',
distance_metric='cosine',
)
# Prepare vectors for upsert
vectors = []
for item in items:
# Ensure vector data is in the correct format for S3 Vector API
vector_data = item["vector"]
vector_data = item['vector']
if isinstance(vector_data, list):
# Convert list to float32 values as required by S3 Vector API
vector_data = [float(x) for x in vector_data]
# Prepare metadata, ensuring the text field is preserved
metadata = item.get("metadata", {}).copy()
metadata = item.get('metadata', {}).copy()
# Add the text field to metadata so it's available for retrieval
metadata["text"] = item["text"]
metadata['text'] = item['text']
# Convert metadata to string format for consistency
metadata = process_metadata(metadata)
# Filter metadata to comply with S3 Vector API limit of 10 keys
metadata = self._filter_metadata(metadata, item["id"])
metadata = self._filter_metadata(metadata, item['id'])
vectors.append(
{
"key": item["id"],
"data": {"float32": vector_data},
"metadata": metadata,
'key': item['id'],
'data': {'float32': vector_data},
'metadata': metadata,
}
)
@@ -283,12 +263,10 @@ class S3VectorClient(VectorDBBase):
batch = vectors[i : i + batch_size]
if i == 0: # Log sample info for first batch only
log.info(
f"Upserting batch 1: {len(batch)} vectors. First vector sample: key={batch[0]['key']}, data_type={type(batch[0]['data']['float32'])}, data_len={len(batch[0]['data']['float32'])}"
f'Upserting batch 1: {len(batch)} vectors. First vector sample: key={batch[0]["key"]}, data_type={type(batch[0]["data"]["float32"])}, data_len={len(batch[0]["data"]["float32"])}'
)
else:
log.info(
f"Upserting batch {i//batch_size + 1}: {len(batch)} vectors."
)
log.info(f'Upserting batch {i // batch_size + 1}: {len(batch)} vectors.')
self.client.put_vectors(
vectorBucketName=self.bucket_name,
@@ -296,11 +274,9 @@ class S3VectorClient(VectorDBBase):
vectors=batch,
)
log.info(
f"Completed upsert of {len(vectors)} vectors into index '{collection_name}'."
)
log.info(f"Completed upsert of {len(vectors)} vectors into index '{collection_name}'.")
except Exception as e:
log.error(f"Error upserting vectors: {e}")
log.error(f'Error upserting vectors: {e}')
raise
def search(
@@ -319,13 +295,11 @@ class S3VectorClient(VectorDBBase):
return None
if not vectors:
log.warning("No query vectors provided")
log.warning('No query vectors provided')
return None
try:
log.info(
f"Searching collection '{collection_name}' with {len(vectors)} query vectors, limit={limit}"
)
log.info(f"Searching collection '{collection_name}' with {len(vectors)} query vectors, limit={limit}")
# Initialize result lists
all_ids = []
@@ -335,10 +309,10 @@ class S3VectorClient(VectorDBBase):
# Process each query vector
for i, query_vector in enumerate(vectors):
log.debug(f"Processing query vector {i+1}/{len(vectors)}")
log.debug(f'Processing query vector {i + 1}/{len(vectors)}')
# Prepare the query vector in S3 Vector format
query_vector_dict = {"float32": [float(x) for x in query_vector]}
query_vector_dict = {'float32': [float(x) for x in query_vector]}
# Call S3 Vector query API
response = self.client.query_vectors(
@@ -356,24 +330,22 @@ class S3VectorClient(VectorDBBase):
query_metadatas = []
query_distances = []
result_vectors = response.get("vectors", [])
result_vectors = response.get('vectors', [])
for vector in result_vectors:
vector_id = vector.get("key")
vector_metadata = vector.get("metadata", {})
vector_distance = vector.get("distance", 0.0)
vector_id = vector.get('key')
vector_metadata = vector.get('metadata', {})
vector_distance = vector.get('distance', 0.0)
# Extract document text from metadata
document_text = ""
document_text = ''
if isinstance(vector_metadata, dict):
# Get the text field first (highest priority)
document_text = vector_metadata.get("text")
document_text = vector_metadata.get('text')
if not document_text:
# Fallback to other possible text fields
document_text = (
vector_metadata.get("content")
or vector_metadata.get("document")
or vector_id
vector_metadata.get('content') or vector_metadata.get('document') or vector_id
)
else:
document_text = vector_id
@@ -389,7 +361,7 @@ class S3VectorClient(VectorDBBase):
all_metadatas.append(query_metadatas)
all_distances.append(query_distances)
log.info(f"Search completed. Found results for {len(all_ids)} queries")
log.info(f'Search completed. Found results for {len(all_ids)} queries')
# Return SearchResult format
return SearchResult(
@@ -402,24 +374,20 @@ class S3VectorClient(VectorDBBase):
except Exception as e:
log.error(f"Error searching collection '{collection_name}': {str(e)}")
# Handle specific AWS exceptions
if hasattr(e, "response") and "Error" in e.response:
error_code = e.response["Error"]["Code"]
if error_code == "NotFoundException":
if hasattr(e, 'response') and 'Error' in e.response:
error_code = e.response['Error']['Code']
if error_code == 'NotFoundException':
log.warning(f"Collection '{collection_name}' not found")
return None
elif error_code == "ValidationException":
log.error(f"Invalid query vector dimensions or parameters")
elif error_code == 'ValidationException':
log.error(f'Invalid query vector dimensions or parameters')
return None
elif error_code == "AccessDeniedException":
log.error(
f"Access denied for collection '{collection_name}'. Check permissions."
)
elif error_code == 'AccessDeniedException':
log.error(f"Access denied for collection '{collection_name}'. Check permissions.")
return None
raise
def query(
self, collection_name: str, filter: Dict, limit: Optional[int] = None
) -> Optional[GetResult]:
def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]:
"""
Query vectors from a collection using metadata filter.
"""
@@ -429,7 +397,7 @@ class S3VectorClient(VectorDBBase):
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
if not filter:
log.warning("No filter provided, returning all vectors")
log.warning('No filter provided, returning all vectors')
return self.get(collection_name)
try:
@@ -443,17 +411,13 @@ class S3VectorClient(VectorDBBase):
all_vectors_result = self.get(collection_name)
if not all_vectors_result or not all_vectors_result.ids:
log.warning("No vectors found in collection")
log.warning('No vectors found in collection')
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
# Extract the lists from the result
all_ids = all_vectors_result.ids[0] if all_vectors_result.ids else []
all_documents = (
all_vectors_result.documents[0] if all_vectors_result.documents else []
)
all_metadatas = (
all_vectors_result.metadatas[0] if all_vectors_result.metadatas else []
)
all_documents = all_vectors_result.documents[0] if all_vectors_result.documents else []
all_metadatas = all_vectors_result.metadatas[0] if all_vectors_result.metadatas else []
# Apply client-side filtering
filtered_ids = []
@@ -472,9 +436,7 @@ class S3VectorClient(VectorDBBase):
if limit and len(filtered_ids) >= limit:
break
log.info(
f"Filter applied: {len(filtered_ids)} vectors match out of {len(all_ids)} total"
)
log.info(f'Filter applied: {len(filtered_ids)} vectors match out of {len(all_ids)} total')
# Return GetResult format
if filtered_ids:
@@ -489,15 +451,13 @@ class S3VectorClient(VectorDBBase):
except Exception as e:
log.error(f"Error querying collection '{collection_name}': {str(e)}")
# Handle specific AWS exceptions
if hasattr(e, "response") and "Error" in e.response:
error_code = e.response["Error"]["Code"]
if error_code == "NotFoundException":
if hasattr(e, 'response') and 'Error' in e.response:
error_code = e.response['Error']['Code']
if error_code == 'NotFoundException':
log.warning(f"Collection '{collection_name}' not found")
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
elif error_code == "AccessDeniedException":
log.error(
f"Access denied for collection '{collection_name}'. Check permissions."
)
elif error_code == 'AccessDeniedException':
log.error(f"Access denied for collection '{collection_name}'. Check permissions.")
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
raise
@@ -524,47 +484,43 @@ class S3VectorClient(VectorDBBase):
while True:
# Prepare request parameters
request_params = {
"vectorBucketName": self.bucket_name,
"indexName": collection_name,
"returnData": False, # Don't include vector data (not needed for get)
"returnMetadata": True, # Include metadata
"maxResults": 500, # Use reasonable page size
'vectorBucketName': self.bucket_name,
'indexName': collection_name,
'returnData': False, # Don't include vector data (not needed for get)
'returnMetadata': True, # Include metadata
'maxResults': 500, # Use reasonable page size
}
if next_token:
request_params["nextToken"] = next_token
request_params['nextToken'] = next_token
# Call S3 Vector API
response = self.client.list_vectors(**request_params)
# Process vectors in this page
vectors = response.get("vectors", [])
vectors = response.get('vectors', [])
for vector in vectors:
vector_id = vector.get("key")
vector_data = vector.get("data", {})
vector_metadata = vector.get("metadata", {})
vector_id = vector.get('key')
vector_data = vector.get('data', {})
vector_metadata = vector.get('metadata', {})
# Extract the actual vector array
vector_array = vector_data.get("float32", [])
vector_array = vector_data.get('float32', [])
# For documents, we try to extract text from metadata or use the vector ID
document_text = ""
document_text = ''
if isinstance(vector_metadata, dict):
# Get the text field first (highest priority)
document_text = vector_metadata.get("text")
document_text = vector_metadata.get('text')
if not document_text:
# Fallback to other possible text fields
document_text = (
vector_metadata.get("content")
or vector_metadata.get("document")
or vector_id
vector_metadata.get('content') or vector_metadata.get('document') or vector_id
)
# Log the actual content for debugging
log.debug(
f"Document text preview (first 200 chars): {str(document_text)[:200]}"
)
log.debug(f'Document text preview (first 200 chars): {str(document_text)[:200]}')
else:
document_text = vector_id
@@ -573,37 +529,29 @@ class S3VectorClient(VectorDBBase):
all_metadatas.append(vector_metadata)
# Check if there are more pages
next_token = response.get("nextToken")
next_token = response.get('nextToken')
if not next_token:
break
log.info(
f"Retrieved {len(all_ids)} vectors from collection '{collection_name}'"
)
log.info(f"Retrieved {len(all_ids)} vectors from collection '{collection_name}'")
# Return in GetResult format
# The Open WebUI GetResult expects lists of lists, so we wrap each list
if all_ids:
return GetResult(
ids=[all_ids], documents=[all_documents], metadatas=[all_metadatas]
)
return GetResult(ids=[all_ids], documents=[all_documents], metadatas=[all_metadatas])
else:
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
except Exception as e:
log.error(
f"Error retrieving vectors from collection '{collection_name}': {str(e)}"
)
log.error(f"Error retrieving vectors from collection '{collection_name}': {str(e)}")
# Handle specific AWS exceptions
if hasattr(e, "response") and "Error" in e.response:
error_code = e.response["Error"]["Code"]
if error_code == "NotFoundException":
if hasattr(e, 'response') and 'Error' in e.response:
error_code = e.response['Error']['Code']
if error_code == 'NotFoundException':
log.warning(f"Collection '{collection_name}' not found")
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
elif error_code == "AccessDeniedException":
log.error(
f"Access denied for collection '{collection_name}'. Check permissions."
)
elif error_code == 'AccessDeniedException':
log.error(f"Access denied for collection '{collection_name}'. Check permissions.")
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
raise
@@ -618,20 +566,16 @@ class S3VectorClient(VectorDBBase):
"""
if not self.has_collection(collection_name):
log.warning(
f"Collection '{collection_name}' does not exist, nothing to delete"
)
log.warning(f"Collection '{collection_name}' does not exist, nothing to delete")
return
# Check if this is a knowledge collection (not file-specific)
is_knowledge_collection = not collection_name.startswith("file-")
is_knowledge_collection = not collection_name.startswith('file-')
try:
if ids:
# Delete by specific vector IDs/keys
log.info(
f"Deleting {len(ids)} vectors by IDs from collection '{collection_name}'"
)
log.info(f"Deleting {len(ids)} vectors by IDs from collection '{collection_name}'")
self.client.delete_vectors(
vectorBucketName=self.bucket_name,
indexName=collection_name,
@@ -641,15 +585,13 @@ class S3VectorClient(VectorDBBase):
elif filter:
# Handle filter-based deletion
log.info(
f"Deleting vectors by filter from collection '{collection_name}': {filter}"
)
log.info(f"Deleting vectors by filter from collection '{collection_name}': {filter}")
# If this is a knowledge collection and we have a file_id filter,
# also clean up the corresponding file-specific collection
if is_knowledge_collection and "file_id" in filter:
file_id = filter["file_id"]
file_collection_name = f"file-{file_id}"
if is_knowledge_collection and 'file_id' in filter:
file_id = filter['file_id']
file_collection_name = f'file-{file_id}'
if self.has_collection(file_collection_name):
log.info(
f"Found related file-specific collection '{file_collection_name}', deleting it to prevent duplicates"
@@ -661,9 +603,7 @@ class S3VectorClient(VectorDBBase):
query_result = self.query(collection_name, filter)
if query_result and query_result.ids and query_result.ids[0]:
matching_ids = query_result.ids[0]
log.info(
f"Found {len(matching_ids)} vectors matching filter, deleting them"
)
log.info(f'Found {len(matching_ids)} vectors matching filter, deleting them')
# Delete the matching vectors by ID
self.client.delete_vectors(
@@ -671,17 +611,13 @@ class S3VectorClient(VectorDBBase):
indexName=collection_name,
keys=matching_ids,
)
log.info(
f"Deleted {len(matching_ids)} vectors from index '{collection_name}' using filter"
)
log.info(f"Deleted {len(matching_ids)} vectors from index '{collection_name}' using filter")
else:
log.warning("No vectors found matching the filter criteria")
log.warning('No vectors found matching the filter criteria')
else:
log.warning("No IDs or filter provided for deletion")
log.warning('No IDs or filter provided for deletion')
except Exception as e:
log.error(
f"Error deleting vectors from collection '{collection_name}': {e}"
)
log.error(f"Error deleting vectors from collection '{collection_name}': {e}")
raise
def reset(self) -> None:
@@ -690,36 +626,32 @@ class S3VectorClient(VectorDBBase):
"""
try:
log.warning(
"Reset called - this will delete all vector indexes in the S3 bucket"
)
log.warning('Reset called - this will delete all vector indexes in the S3 bucket')
# List all indexes
response = self.client.list_indexes(vectorBucketName=self.bucket_name)
indexes = response.get("indexes", [])
indexes = response.get('indexes', [])
if not indexes:
log.warning("No indexes found to delete")
log.warning('No indexes found to delete')
return
# Delete all indexes
deleted_count = 0
for index in indexes:
index_name = index.get("indexName")
index_name = index.get('indexName')
if index_name:
try:
self.client.delete_index(
vectorBucketName=self.bucket_name, indexName=index_name
)
self.client.delete_index(vectorBucketName=self.bucket_name, indexName=index_name)
deleted_count += 1
log.info(f"Deleted index: {index_name}")
log.info(f'Deleted index: {index_name}')
except Exception as e:
log.error(f"Error deleting index '{index_name}': {e}")
log.info(f"Reset completed: deleted {deleted_count} indexes")
log.info(f'Reset completed: deleted {deleted_count} indexes')
except Exception as e:
log.error(f"Error during reset: {e}")
log.error(f'Error during reset: {e}')
raise
def _matches_filter(self, metadata: Dict[str, Any], filter: Dict[str, Any]) -> bool:
@@ -732,15 +664,15 @@ class S3VectorClient(VectorDBBase):
# Check each filter condition
for key, expected_value in filter.items():
# Handle special operators
if key.startswith("$"):
if key == "$and":
if key.startswith('$'):
if key == '$and':
# All conditions must match
if not isinstance(expected_value, list):
continue
for condition in expected_value:
if not self._matches_filter(metadata, condition):
return False
elif key == "$or":
elif key == '$or':
# At least one condition must match
if not isinstance(expected_value, list):
continue
@@ -760,22 +692,19 @@ class S3VectorClient(VectorDBBase):
if isinstance(expected_value, dict):
# Handle comparison operators
for op, op_value in expected_value.items():
if op == "$eq":
if op == '$eq':
if actual_value != op_value:
return False
elif op == "$ne":
elif op == '$ne':
if actual_value == op_value:
return False
elif op == "$in":
if (
not isinstance(op_value, list)
or actual_value not in op_value
):
elif op == '$in':
if not isinstance(op_value, list) or actual_value not in op_value:
return False
elif op == "$nin":
elif op == '$nin':
if isinstance(op_value, list) and actual_value in op_value:
return False
elif op == "$exists":
elif op == '$exists':
if bool(op_value) != (key in metadata):
return False
# Add more operators as needed


@@ -60,47 +60,43 @@ class WeaviateClient(VectorDBBase):
try:
# Build connection parameters
connection_params = {
"http_host": WEAVIATE_HTTP_HOST,
"http_port": WEAVIATE_HTTP_PORT,
"http_secure": WEAVIATE_HTTP_SECURE,
"grpc_host": WEAVIATE_GRPC_HOST,
"grpc_port": WEAVIATE_GRPC_PORT,
"grpc_secure": WEAVIATE_GRPC_SECURE,
"skip_init_checks": WEAVIATE_SKIP_INIT_CHECKS,
'http_host': WEAVIATE_HTTP_HOST,
'http_port': WEAVIATE_HTTP_PORT,
'http_secure': WEAVIATE_HTTP_SECURE,
'grpc_host': WEAVIATE_GRPC_HOST,
'grpc_port': WEAVIATE_GRPC_PORT,
'grpc_secure': WEAVIATE_GRPC_SECURE,
'skip_init_checks': WEAVIATE_SKIP_INIT_CHECKS,
}
# Only add auth_credentials if WEAVIATE_API_KEY exists and is not empty
if WEAVIATE_API_KEY:
connection_params["auth_credentials"] = (
weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY)
)
connection_params['auth_credentials'] = weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY)
self.client = weaviate.connect_to_custom(**connection_params)
self.client.connect()
except Exception as e:
raise ConnectionError(f"Failed to connect to Weaviate: {e}") from e
raise ConnectionError(f'Failed to connect to Weaviate: {e}') from e
def _sanitize_collection_name(self, collection_name: str) -> str:
"""Sanitize collection name to be a valid Weaviate class name."""
if not isinstance(collection_name, str) or not collection_name.strip():
raise ValueError("Collection name must be a non-empty string")
raise ValueError('Collection name must be a non-empty string')
# Requirements for a valid Weaviate class name:
# The collection name must begin with a capital letter.
# The name can only contain letters, numbers, and the underscore (_) character. Spaces are not allowed.
# Replace hyphens with underscores and keep only alphanumeric characters
name = re.sub(r"[^a-zA-Z0-9_]", "", collection_name.replace("-", "_"))
name = name.strip("_")
name = re.sub(r'[^a-zA-Z0-9_]', '', collection_name.replace('-', '_'))
name = name.strip('_')
if not name:
raise ValueError(
"Could not sanitize collection name to be a valid Weaviate class name"
)
raise ValueError('Could not sanitize collection name to be a valid Weaviate class name')
# Ensure it starts with a letter and is capitalized
if not name[0].isalpha():
name = "C" + name
name = 'C' + name
return name[0].upper() + name[1:]
@@ -118,9 +114,7 @@ class WeaviateClient(VectorDBBase):
name=collection_name,
vector_config=weaviate.classes.config.Configure.Vectors.self_provided(),
properties=[
weaviate.classes.config.Property(
name="text", data_type=weaviate.classes.config.DataType.TEXT
),
weaviate.classes.config.Property(name='text', data_type=weaviate.classes.config.DataType.TEXT),
],
)
@@ -133,19 +127,15 @@ class WeaviateClient(VectorDBBase):
with collection.batch.fixed_size(batch_size=100) as batch:
for item in items:
item_uuid = str(uuid.uuid4()) if not item["id"] else str(item["id"])
item_uuid = str(uuid.uuid4()) if not item['id'] else str(item['id'])
properties = {"text": item["text"]}
if item["metadata"]:
clean_metadata = _convert_uuids_to_strings(
process_metadata(item["metadata"])
)
clean_metadata.pop("text", None)
properties = {'text': item['text']}
if item['metadata']:
clean_metadata = _convert_uuids_to_strings(process_metadata(item['metadata']))
clean_metadata.pop('text', None)
properties.update(clean_metadata)
batch.add_object(
properties=properties, uuid=item_uuid, vector=item["vector"]
)
batch.add_object(properties=properties, uuid=item_uuid, vector=item['vector'])
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
sane_collection_name = self._sanitize_collection_name(collection_name)
@@ -156,19 +146,15 @@ class WeaviateClient(VectorDBBase):
with collection.batch.fixed_size(batch_size=100) as batch:
for item in items:
item_uuid = str(item["id"]) if item["id"] else None
item_uuid = str(item['id']) if item['id'] else None
properties = {"text": item["text"]}
if item["metadata"]:
clean_metadata = _convert_uuids_to_strings(
process_metadata(item["metadata"])
)
clean_metadata.pop("text", None)
properties = {'text': item['text']}
if item['metadata']:
clean_metadata = _convert_uuids_to_strings(process_metadata(item['metadata']))
clean_metadata.pop('text', None)
properties.update(clean_metadata)
batch.add_object(
properties=properties, uuid=item_uuid, vector=item["vector"]
)
batch.add_object(properties=properties, uuid=item_uuid, vector=item['vector'])
def search(
self,
@@ -205,16 +191,12 @@ class WeaviateClient(VectorDBBase):
for obj in response.objects:
properties = dict(obj.properties) if obj.properties else {}
documents.append(properties.pop("text", ""))
documents.append(properties.pop('text', ''))
metadatas.append(_convert_uuids_to_strings(properties))
# Weaviate has cosine distance, 2 (worst) -> 0 (best). Re-ordering to 0 -> 1
raw_distances = [
(
obj.metadata.distance
if obj.metadata and obj.metadata.distance
else 2.0
)
(obj.metadata.distance if obj.metadata and obj.metadata.distance else 2.0)
for obj in response.objects
]
distances = [(2 - dist) / 2 for dist in raw_distances]
@@ -231,16 +213,14 @@ class WeaviateClient(VectorDBBase):
return SearchResult(
**{
"ids": result_ids,
"documents": result_documents,
"metadatas": result_metadatas,
"distances": result_distances,
'ids': result_ids,
'documents': result_documents,
'metadatas': result_metadatas,
'distances': result_distances,
}
)
def query(
self, collection_name: str, filter: Dict, limit: Optional[int] = None
) -> Optional[GetResult]:
def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]:
sane_collection_name = self._sanitize_collection_name(collection_name)
if not self.client.collections.exists(sane_collection_name):
return None
@@ -250,21 +230,15 @@ class WeaviateClient(VectorDBBase):
weaviate_filter = None
if filter:
for key, value in filter.items():
prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(
value
)
prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(value)
weaviate_filter = (
prop_filter
if weaviate_filter is None
else weaviate.classes.query.Filter.all_of(
[weaviate_filter, prop_filter]
)
else weaviate.classes.query.Filter.all_of([weaviate_filter, prop_filter])
)
try:
response = collection.query.fetch_objects(
filters=weaviate_filter, limit=limit
)
response = collection.query.fetch_objects(filters=weaviate_filter, limit=limit)
ids = [str(obj.uuid) for obj in response.objects]
documents = []
@@ -272,14 +246,14 @@ class WeaviateClient(VectorDBBase):
for obj in response.objects:
properties = dict(obj.properties) if obj.properties else {}
documents.append(properties.pop("text", ""))
documents.append(properties.pop('text', ''))
metadatas.append(_convert_uuids_to_strings(properties))
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
'ids': [ids],
'documents': [documents],
'metadatas': [metadatas],
}
)
except Exception:
@@ -297,7 +271,7 @@ class WeaviateClient(VectorDBBase):
for item in collection.iterator():
ids.append(str(item.uuid))
properties = dict(item.properties) if item.properties else {}
documents.append(properties.pop("text", ""))
documents.append(properties.pop('text', ''))
metadatas.append(_convert_uuids_to_strings(properties))
if not ids:
@@ -305,9 +279,9 @@ class WeaviateClient(VectorDBBase):
return GetResult(
**{
"ids": [ids],
"documents": [documents],
"metadatas": [metadatas],
'ids': [ids],
'documents': [documents],
'metadatas': [metadatas],
}
)
except Exception:
@@ -332,15 +306,11 @@ class WeaviateClient(VectorDBBase):
elif filter:
weaviate_filter = None
for key, value in filter.items():
prop_filter = weaviate.classes.query.Filter.by_property(
name=key
).equal(value)
prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(value)
weaviate_filter = (
prop_filter
if weaviate_filter is None
else weaviate.classes.query.Filter.all_of(
[weaviate_filter, prop_filter]
)
else weaviate.classes.query.Filter.all_of([weaviate_filter, prop_filter])
)
if weaviate_filter:


@@ -8,7 +8,6 @@ from open_webui.config import (
class Vector:
@staticmethod
def get_vector(vector_type: str) -> VectorDBBase:
"""
@@ -82,7 +81,7 @@ class Vector:
return WeaviateClient()
case _:
raise ValueError(f"Unsupported vector type: {vector_type}")
raise ValueError(f'Unsupported vector type: {vector_type}')
VECTOR_DB_CLIENT = Vector.get_vector(VECTOR_DB)


@@ -63,9 +63,7 @@ class VectorDBBase(ABC):
pass
@abstractmethod
def query(
self, collection_name: str, filter: Dict, limit: Optional[int] = None
) -> Optional[GetResult]:
def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]:
"""Query vectors from a collection using metadata filter."""
pass


@@ -2,15 +2,15 @@ from enum import StrEnum
class VectorType(StrEnum):
MILVUS = "milvus"
MARIADB_VECTOR = "mariadb-vector"
QDRANT = "qdrant"
CHROMA = "chroma"
PINECONE = "pinecone"
ELASTICSEARCH = "elasticsearch"
OPENSEARCH = "opensearch"
PGVECTOR = "pgvector"
ORACLE23AI = "oracle23ai"
S3VECTOR = "s3vector"
WEAVIATE = "weaviate"
OPENGAUSS = "opengauss"
MILVUS = 'milvus'
MARIADB_VECTOR = 'mariadb-vector'
QDRANT = 'qdrant'
CHROMA = 'chroma'
PINECONE = 'pinecone'
ELASTICSEARCH = 'elasticsearch'
OPENSEARCH = 'opensearch'
PGVECTOR = 'pgvector'
ORACLE23AI = 'oracle23ai'
S3VECTOR = 's3vector'
WEAVIATE = 'weaviate'
OPENGAUSS = 'opengauss'


@@ -1,13 +1,11 @@
from datetime import datetime
KEYS_TO_EXCLUDE = ["content", "pages", "tables", "paragraphs", "sections", "figures"]
KEYS_TO_EXCLUDE = ['content', 'pages', 'tables', 'paragraphs', 'sections', 'figures']
def filter_metadata(metadata: dict[str, any]) -> dict[str, any]:
# Removes large/redundant fields from metadata dict.
metadata = {
key: value for key, value in metadata.items() if key not in KEYS_TO_EXCLUDE
}
metadata = {key: value for key, value in metadata.items() if key not in KEYS_TO_EXCLUDE}
return metadata


@@ -40,20 +40,17 @@ def search_azure(
from azure.search.documents import SearchClient
except ImportError:
log.error(
"azure-search-documents package is not installed. "
"Install it with: pip install azure-search-documents"
'azure-search-documents package is not installed. Install it with: pip install azure-search-documents'
)
raise ImportError(
"azure-search-documents is required for Azure AI Search. "
"Install it with: pip install azure-search-documents"
'azure-search-documents is required for Azure AI Search. '
'Install it with: pip install azure-search-documents'
)
try:
# Create search client with API key authentication
credential = AzureKeyCredential(api_key)
search_client = SearchClient(
endpoint=endpoint, index_name=index_name, credential=credential
)
search_client = SearchClient(endpoint=endpoint, index_name=index_name, credential=credential)
# Perform the search
results = search_client.search(search_text=query, top=count)
@@ -68,42 +65,42 @@ def search_azure(
# Try to find URL field (common names)
link = (
result_dict.get("url")
or result_dict.get("link")
or result_dict.get("uri")
or result_dict.get("metadata_storage_path")
or ""
result_dict.get('url')
or result_dict.get('link')
or result_dict.get('uri')
or result_dict.get('metadata_storage_path')
or ''
)
# Try to find title field (common names)
title = (
result_dict.get("title")
or result_dict.get("name")
or result_dict.get("metadata_title")
or result_dict.get("metadata_storage_name")
result_dict.get('title')
or result_dict.get('name')
or result_dict.get('metadata_title')
or result_dict.get('metadata_storage_name')
or None
)
# Try to find content/snippet field (common names)
snippet = (
result_dict.get("content")
or result_dict.get("snippet")
or result_dict.get("description")
or result_dict.get("summary")
or result_dict.get("text")
result_dict.get('content')
or result_dict.get('snippet')
or result_dict.get('description')
or result_dict.get('summary')
or result_dict.get('text')
or None
)
# Truncate snippet if too long
if snippet and len(snippet) > 500:
snippet = snippet[:497] + "..."
snippet = snippet[:497] + '...'
if link: # Only add if we found a valid link
search_results.append(
{
"link": link,
"title": title,
"snippet": snippet,
'link': link,
'title': title,
'snippet': snippet,
}
)
@@ -114,13 +111,13 @@ def search_azure(
# Convert to SearchResult objects
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("snippet"),
link=result['link'],
title=result.get('title'),
snippet=result.get('snippet'),
)
for result in search_results
]
except Exception as ex:
log.error(f"Azure AI Search error: {ex}")
log.error(f'Azure AI Search error: {ex}')
raise ex


@@ -21,48 +21,44 @@ def search_bing(
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
mkt = locale
params = {"q": query, "mkt": mkt, "count": count}
headers = {"Ocp-Apim-Subscription-Key": subscription_key}
params = {'q': query, 'mkt': mkt, 'count': count}
headers = {'Ocp-Apim-Subscription-Key': subscription_key}
try:
response = requests.get(endpoint, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
results = json_response.get("webPages", {}).get("value", [])
results = json_response.get('webPages', {}).get('value', [])
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"],
title=result.get("name"),
snippet=result.get("snippet"),
link=result['url'],
title=result.get('name'),
snippet=result.get('snippet'),
)
for result in results
]
except Exception as ex:
log.error(f"Error: {ex}")
log.error(f'Error: {ex}')
raise ex
def main():
parser = argparse.ArgumentParser(description="Search Bing from the command line.")
parser = argparse.ArgumentParser(description='Search Bing from the command line.')
parser.add_argument(
"query",
'query',
type=str,
default="Top 10 international news today",
help="The search query.",
default='Top 10 international news today',
help='The search query.',
)
    parser.add_argument(
        "--count", type=int, default=10, help="Number of search results to return."
    )
    parser.add_argument(
        "--filter", nargs="*", help="List of filters to apply to the search results."
    )
    parser.add_argument('--count', type=int, default=10, help='Number of search results to return.')
    parser.add_argument('--filter', nargs='*', help='List of filters to apply to the search results.')
parser.add_argument(
"--locale",
'--locale',
type=str,
default="en-US",
help="The locale to use for the search, maps to market in api",
default='en-US',
help='The locale to use for the search, maps to market in api',
)
args = parser.parse_args()


@@ -10,43 +10,38 @@ log = logging.getLogger(__name__)
def _parse_response(response):
results = []
if "data" in response:
data = response["data"]
if "webPages" in data:
webPages = data["webPages"]
if "value" in webPages:
if 'data' in response:
data = response['data']
if 'webPages' in data:
webPages = data['webPages']
if 'value' in webPages:
results = [
{
"id": item.get("id", ""),
"name": item.get("name", ""),
"url": item.get("url", ""),
"snippet": item.get("snippet", ""),
"summary": item.get("summary", ""),
"siteName": item.get("siteName", ""),
"siteIcon": item.get("siteIcon", ""),
"datePublished": item.get("datePublished", "")
or item.get("dateLastCrawled", ""),
'id': item.get('id', ''),
'name': item.get('name', ''),
'url': item.get('url', ''),
'snippet': item.get('snippet', ''),
'summary': item.get('summary', ''),
'siteName': item.get('siteName', ''),
'siteIcon': item.get('siteIcon', ''),
'datePublished': item.get('datePublished', '') or item.get('dateLastCrawled', ''),
}
for item in webPages["value"]
for item in webPages['value']
]
return results
def search_bocha(
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
def search_bocha(api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None) -> list[SearchResult]:
"""Search using Bocha's Search API and return the results as a list of SearchResult objects.
Args:
api_key (str): A Bocha Search API key
query (str): The query to search for
"""
url = "https://api.bochaai.com/v1/web-search?utm_source=ollama"
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
url = 'https://api.bochaai.com/v1/web-search?utm_source=ollama'
headers = {'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'}
payload = json.dumps(
{"query": query, "summary": True, "freshness": "noLimit", "count": count}
)
payload = json.dumps({'query': query, 'summary': True, 'freshness': 'noLimit', 'count': count})
response = requests.post(url, headers=headers, data=payload, timeout=5)
response.raise_for_status()
@@ -56,8 +51,6 @@ def search_bocha(
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"], title=result.get("name"), snippet=result.get("summary")
)
SearchResult(link=result['url'], title=result.get('name'), snippet=result.get('summary'))
for result in results[:count]
]


@@ -8,44 +8,42 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results
log = logging.getLogger(__name__)
def search_brave(
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
def search_brave(api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None) -> list[SearchResult]:
"""Search using Brave's Search API and return the results as a list of SearchResult objects.
Args:
api_key (str): A Brave Search API key
query (str): The query to search for
"""
url = "https://api.search.brave.com/res/v1/web/search"
url = 'https://api.search.brave.com/res/v1/web/search'
headers = {
"Accept": "application/json",
"Accept-Encoding": "gzip",
"X-Subscription-Token": api_key,
'Accept': 'application/json',
'Accept-Encoding': 'gzip',
'X-Subscription-Token': api_key,
}
params = {"q": query, "count": count}
params = {'q': query, 'count': count}
response = requests.get(url, headers=headers, params=params)
# Handle 429 rate limiting - Brave free tier allows 1 request/second
# If rate limited, wait 1 second and retry once before failing
if response.status_code == 429:
log.info("Brave Search API rate limited (429), retrying after 1 second...")
log.info('Brave Search API rate limited (429), retrying after 1 second...')
time.sleep(1)
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
results = json_response.get("web", {}).get("results", [])
results = json_response.get('web', {}).get('results', [])
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"],
title=result.get("title"),
snippet=result.get("description"),
link=result['url'],
title=result.get('title'),
snippet=result.get('description'),
)
for result in results[:count]
]


@@ -13,7 +13,7 @@ def search_duckduckgo(
count: int,
filter_list: Optional[list[str]] = None,
concurrent_requests: Optional[int] = None,
backend: Optional[str] = "auto",
backend: Optional[str] = 'auto',
) -> list[SearchResult]:
"""
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
@@ -33,20 +33,18 @@ def search_duckduckgo(
# Use the ddgs.text() method to perform the search
try:
search_results = ddgs.text(
query, safesearch="moderate", max_results=count, backend=backend
)
search_results = ddgs.text(query, safesearch='moderate', max_results=count, backend=backend)
except RatelimitException as e:
log.error(f"RatelimitException: {e}")
log.error(f'RatelimitException: {e}')
if filter_list:
search_results = get_filtered_results(search_results, filter_list)
# Return the list of search results
return [
SearchResult(
link=result["href"],
title=result.get("title"),
snippet=result.get("body"),
link=result['href'],
title=result.get('title'),
snippet=result.get('body'),
)
for result in search_results
]


@@ -7,7 +7,7 @@ from open_webui.retrieval.web.main import SearchResult
log = logging.getLogger(__name__)
EXA_API_BASE = "https://api.exa.ai"
EXA_API_BASE = 'https://api.exa.ai'
@dataclass
@@ -31,36 +31,34 @@ def search_exa(
count (int): Number of results to return
filter_list (Optional[list[str]]): List of domains to filter results by
"""
log.info(f"Searching with Exa for query: {query}")
log.info(f'Searching with Exa for query: {query}')
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
headers = {'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'}
payload = {
"query": query,
"numResults": count or 5,
"includeDomains": filter_list,
"contents": {"text": True, "highlights": True},
"type": "auto", # Use the auto search type (keyword or neural)
'query': query,
'numResults': count or 5,
'includeDomains': filter_list,
'contents': {'text': True, 'highlights': True},
'type': 'auto', # Use the auto search type (keyword or neural)
}
try:
response = requests.post(
f"{EXA_API_BASE}/search", headers=headers, json=payload
)
response = requests.post(f'{EXA_API_BASE}/search', headers=headers, json=payload)
response.raise_for_status()
data = response.json()
results = []
for result in data["results"]:
for result in data['results']:
results.append(
ExaResult(
url=result["url"],
title=result["title"],
text=result["text"],
url=result['url'],
title=result['title'],
text=result['text'],
)
)
log.info(f"Found {len(results)} results")
log.info(f'Found {len(results)} results')
return [
SearchResult(
link=result.url,
@@ -70,5 +68,5 @@ def search_exa(
for result in results
]
except Exception as e:
log.error(f"Error searching Exa: {e}")
log.error(f'Error searching Exa: {e}')
return []


@@ -24,12 +24,12 @@ def search_external(
) -> List[SearchResult]:
try:
headers = {
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Authorization": f"Bearer {external_api_key}",
'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) RAG Bot',
'Authorization': f'Bearer {external_api_key}',
}
headers = include_user_info_headers(headers, user)
chat_id = getattr(request.state, "chat_id", None)
chat_id = getattr(request.state, 'chat_id', None)
if chat_id:
headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = str(chat_id)
@@ -37,8 +37,8 @@ def search_external(
external_url,
headers=headers,
json={
"query": query,
"count": count,
'query': query,
'count': count,
},
)
response.raise_for_status()
@@ -47,14 +47,14 @@ def search_external(
results = get_filtered_results(results, filter_list)
results = [
SearchResult(
link=result.get("link"),
title=result.get("title"),
snippet=result.get("snippet"),
link=result.get('link'),
title=result.get('title'),
snippet=result.get('snippet'),
)
for result in results[:count]
]
log.info(f"External search results: {results}")
log.info(f'External search results: {results}')
return results
except Exception as e:
log.error(f"Error in External search: {e}")
log.error(f'Error in External search: {e}')
return []


@@ -17,9 +17,7 @@ def search_firecrawl(
from firecrawl import FirecrawlApp
firecrawl = FirecrawlApp(api_key=firecrawl_api_key, api_url=firecrawl_url)
response = firecrawl.search(
query=query, limit=count, ignore_invalid_urls=True, timeout=count * 3
)
response = firecrawl.search(query=query, limit=count, ignore_invalid_urls=True, timeout=count * 3)
results = response.web
if filter_list:
results = get_filtered_results(results, filter_list)
@@ -31,8 +29,8 @@ def search_firecrawl(
)
for result in results[:count]
]
log.info(f"External search results: {results}")
log.info(f'External search results: {results}')
return results
except Exception as e:
log.error(f"Error in External search: {e}")
log.error(f'Error in External search: {e}')
return []


@@ -28,11 +28,11 @@ def search_google_pse(
Returns:
list[SearchResult]: A list of SearchResult objects.
"""
url = "https://www.googleapis.com/customsearch/v1"
url = 'https://www.googleapis.com/customsearch/v1'
headers = {"Content-Type": "application/json"}
headers = {'Content-Type': 'application/json'}
if referer:
headers["Referer"] = referer
headers['Referer'] = referer
all_results = []
start_index = 1 # Google PSE start parameter is 1-based
@@ -40,21 +40,19 @@ def search_google_pse(
while count > 0:
num_results_this_page = min(count, 10) # Google PSE max results per page is 10
params = {
"cx": search_engine_id,
"q": query,
"key": api_key,
"num": num_results_this_page,
"start": start_index,
'cx': search_engine_id,
'q': query,
'key': api_key,
'num': num_results_this_page,
'start': start_index,
}
response = requests.request("GET", url, headers=headers, params=params)
response = requests.request('GET', url, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
results = json_response.get("items", [])
results = json_response.get('items', [])
if results: # check if results are returned. If not, no more pages to fetch.
all_results.extend(results)
count -= len(
results
) # Decrement count by the number of results fetched in this page.
count -= len(results) # Decrement count by the number of results fetched in this page.
start_index += 10 # Increment start index for the next page
else:
break # No more results from Google PSE, break the loop
@@ -64,9 +62,9 @@ def search_google_pse(
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("snippet"),
link=result['link'],
title=result.get('title'),
snippet=result.get('snippet'),
)
for result in all_results
]


@@ -7,9 +7,7 @@ from yarl import URL
log = logging.getLogger(__name__)
def search_jina(
api_key: str, query: str, count: int, base_url: str = ""
) -> list[SearchResult]:
def search_jina(api_key: str, query: str, count: int, base_url: str = '') -> list[SearchResult]:
"""
Search using Jina's Search API and return the results as a list of SearchResult objects.
Args:
@@ -21,16 +19,16 @@ def search_jina(
Returns:
list[SearchResult]: A list of search results
"""
jina_search_endpoint = base_url if base_url else "https://s.jina.ai/"
jina_search_endpoint = base_url if base_url else 'https://s.jina.ai/'
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
"Authorization": api_key,
"X-Retain-Images": "none",
'Accept': 'application/json',
'Content-Type': 'application/json',
'Authorization': api_key,
'X-Retain-Images': 'none',
}
payload = {"q": query, "count": count if count <= 10 else 10}
payload = {'q': query, 'count': count if count <= 10 else 10}
url = str(URL(jina_search_endpoint))
response = requests.post(url, headers=headers, json=payload)
@@ -38,12 +36,12 @@ def search_jina(
data = response.json()
results = []
for result in data["data"]:
for result in data['data']:
results.append(
SearchResult(
link=result["url"],
title=result.get("title"),
snippet=result.get("content"),
link=result['url'],
title=result.get('title'),
snippet=result.get('content'),
)
)


@@ -7,9 +7,7 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results
log = logging.getLogger(__name__)
def search_kagi(
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
def search_kagi(api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None) -> list[SearchResult]:
"""Search using Kagi's Search API and return the results as a list of SearchResult objects.
The Search API will inherit the settings in your account, including results personalization and snippet length.
@@ -19,23 +17,21 @@ def search_kagi(
query (str): The query to search for
count (int): The number of results to return
"""
url = "https://kagi.com/api/v0/search"
url = 'https://kagi.com/api/v0/search'
headers = {
"Authorization": f"Bot {api_key}",
'Authorization': f'Bot {api_key}',
}
params = {"q": query, "limit": count}
params = {'q': query, 'limit': count}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
search_results = json_response.get("data", [])
search_results = json_response.get('data', [])
results = [
SearchResult(
link=result["url"], title=result["title"], snippet=result.get("snippet")
)
SearchResult(link=result['url'], title=result['title'], snippet=result.get('snippet'))
for result in search_results
if result["t"] == 0
if result['t'] == 0
]
print(results)


@@ -16,7 +16,7 @@ def get_filtered_results(results, filter_list):
filtered_results = []
for result in results:
url = result.get("url") or result.get("link", "") or result.get("href", "")
url = result.get('url') or result.get('link', '') or result.get('href', '')
if not validators.url(url):
continue


@@ -7,32 +7,27 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results
log = logging.getLogger(__name__)
def search_mojeek(
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
def search_mojeek(api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None) -> list[SearchResult]:
"""Search using Mojeek's Search API and return the results as a list of SearchResult objects.
Args:
api_key (str): A Mojeek Search API key
query (str): The query to search for
"""
url = "https://api.mojeek.com/search"
url = 'https://api.mojeek.com/search'
headers = {
"Accept": "application/json",
'Accept': 'application/json',
}
params = {"q": query, "api_key": api_key, "fmt": "json", "t": count}
params = {'q': query, 'api_key': api_key, 'fmt': 'json', 't': count}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
results = json_response.get("response", {}).get("results", [])
results = json_response.get('response', {}).get('results', [])
print(results)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("desc")
)
for result in results
SearchResult(link=result['url'], title=result.get('title'), snippet=result.get('desc')) for result in results
]


@@ -23,30 +23,30 @@ def search_ollama_cloud(
count (int): Number of results to return
filter_list (Optional[list[str]]): List of domains to filter results by
"""
log.info(f"Searching with Ollama for query: {query}")
log.info(f'Searching with Ollama for query: {query}')
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
payload = {"query": query, "max_results": count}
headers = {'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'}
payload = {'query': query, 'max_results': count}
try:
response = requests.post(f"{url}/api/web_search", headers=headers, json=payload)
response = requests.post(f'{url}/api/web_search', headers=headers, json=payload)
response.raise_for_status()
data = response.json()
results = data.get("results", [])
log.info(f"Found {len(results)} results")
results = data.get('results', [])
log.info(f'Found {len(results)} results')
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result.get("url", ""),
title=result.get("title", ""),
snippet=result.get("content", ""),
link=result.get('url', ''),
title=result.get('title', ''),
snippet=result.get('content', ''),
)
for result in results
]
except Exception as e:
log.error(f"Error searching Ollama: {e}")
log.error(f'Error searching Ollama: {e}')
return []


@@ -5,13 +5,13 @@ import requests
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
MODELS = Literal[
"sonar",
"sonar-pro",
"sonar-reasoning",
"sonar-reasoning-pro",
"sonar-deep-research",
'sonar',
'sonar-pro',
'sonar-reasoning',
'sonar-reasoning-pro',
'sonar-deep-research',
]
SEARCH_CONTEXT_USAGE_LEVELS = Literal["low", "medium", "high"]
SEARCH_CONTEXT_USAGE_LEVELS = Literal['low', 'medium', 'high']
log = logging.getLogger(__name__)
@@ -22,8 +22,8 @@ def search_perplexity(
query: str,
count: int,
filter_list: Optional[list[str]] = None,
model: MODELS = "sonar",
search_context_usage: SEARCH_CONTEXT_USAGE_LEVELS = "medium",
model: MODELS = 'sonar',
search_context_usage: SEARCH_CONTEXT_USAGE_LEVELS = 'medium',
) -> list[SearchResult]:
"""Search using Perplexity API and return the results as a list of SearchResult objects.
@@ -38,66 +38,63 @@ def search_perplexity(
"""
# Handle PersistentConfig object
if hasattr(api_key, "__str__"):
if hasattr(api_key, '__str__'):
api_key = str(api_key)
try:
url = "https://api.perplexity.ai/chat/completions"
url = 'https://api.perplexity.ai/chat/completions'
# Create payload for the API call
payload = {
"model": model,
"messages": [
'model': model,
'messages': [
{
"role": "system",
"content": "You are a search assistant. Provide factual information with citations.",
'role': 'system',
'content': 'You are a search assistant. Provide factual information with citations.',
},
{"role": "user", "content": query},
{'role': 'user', 'content': query},
],
"temperature": 0.2, # Lower temperature for more factual responses
"stream": False,
"web_search_options": {
"search_context_usage": search_context_usage,
'temperature': 0.2, # Lower temperature for more factual responses
'stream': False,
'web_search_options': {
'search_context_usage': search_context_usage,
},
}
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
'Authorization': f'Bearer {api_key}',
'Content-Type': 'application/json',
}
# Make the API request
response = requests.request("POST", url, json=payload, headers=headers)
response = requests.request('POST', url, json=payload, headers=headers)
# Parse the JSON response
json_response = response.json()
# Extract citations from the response
citations = json_response.get("citations", [])
citations = json_response.get('citations', [])
# Create search results from citations
results = []
for i, citation in enumerate(citations[:count]):
# Extract content from the response to use as snippet
content = ""
if "choices" in json_response and json_response["choices"]:
content = ''
if 'choices' in json_response and json_response['choices']:
if i == 0:
content = json_response["choices"][0]["message"]["content"]
content = json_response['choices'][0]['message']['content']
result = {"link": citation, "title": f"Source {i+1}", "snippet": content}
result = {'link': citation, 'title': f'Source {i + 1}', 'snippet': content}
results.append(result)
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"], title=result["title"], snippet=result["snippet"]
)
SearchResult(link=result['link'], title=result['title'], snippet=result['snippet'])
for result in results[:count]
]
except Exception as e:
log.error(f"Error searching with Perplexity API: {e}")
log.error(f'Error searching with Perplexity API: {e}')
return []


@@ -13,7 +13,7 @@ def search_perplexity_search(
query: str,
count: int,
filter_list: Optional[list[str]] = None,
api_url: str = "https://api.perplexity.ai/search",
api_url: str = 'https://api.perplexity.ai/search',
user=None,
) -> list[SearchResult]:
"""Search using Perplexity API and return the results as a list of SearchResult objects.
@@ -29,10 +29,10 @@ def search_perplexity_search(
"""
# Handle PersistentConfig object
if hasattr(api_key, "__str__"):
if hasattr(api_key, '__str__'):
api_key = str(api_key)
if hasattr(api_url, "__str__"):
if hasattr(api_url, '__str__'):
api_url = str(api_url)
try:
@@ -40,13 +40,13 @@ def search_perplexity_search(
# Create payload for the API call
payload = {
"query": query,
"max_results": count,
'query': query,
'max_results': count,
}
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
'Authorization': f'Bearer {api_key}',
'Content-Type': 'application/json',
}
# Forward user info headers if user is provided
@@ -54,20 +54,17 @@ def search_perplexity_search(
headers = include_user_info_headers(headers, user)
# Make the API request
response = requests.request("POST", url, json=payload, headers=headers)
response = requests.request('POST', url, json=payload, headers=headers)
# Parse the JSON response
json_response = response.json()
# Extract results from the response
results = json_response.get("results", [])
results = json_response.get('results', [])
return [
SearchResult(
link=result["url"], title=result["title"], snippet=result["snippet"]
)
for result in results
SearchResult(link=result['url'], title=result['title'], snippet=result['snippet']) for result in results
]
except Exception as e:
log.error(f"Error searching with Perplexity Search API: {e}")
log.error(f'Error searching with Perplexity Search API: {e}')
return []
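
Since this variant can forward user information, a brief hedged sketch of a call site follows; the shape of the user object and the key value are illustrative and not taken from this diff.

# Illustrative only; keyword names mirror the signature shown above.
results = search_perplexity_search(
    api_key='pplx-placeholder',
    query='open source observability stacks',
    count=5,
    user=current_user,   # hypothetical user object, forwarded via include_user_info_headers()
)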

View File

@@ -21,28 +21,26 @@ def search_searchapi(
api_key (str): A searchapi.io API key
query (str): The query to search for
"""
url = "https://www.searchapi.io/api/v1/search"
url = 'https://www.searchapi.io/api/v1/search'
engine = engine or "google"
engine = engine or 'google'
payload = {"engine": engine, "q": query, "api_key": api_key}
payload = {'engine': engine, 'q': query, 'api_key': api_key}
url = f"{url}?{urlencode(payload)}"
response = requests.request("GET", url)
url = f'{url}?{urlencode(payload)}'
response = requests.request('GET', url)
json_response = response.json()
log.info(f"results from searchapi search: {json_response}")
log.info(f'results from searchapi search: {json_response}')
results = sorted(
json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
)
results = sorted(json_response.get('organic_results', []), key=lambda x: x.get('position', 0))
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("snippet"),
link=result['link'],
title=result.get('title'),
snippet=result.get('snippet'),
)
for result in results[:count]
]
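
The request URL here is assembled purely by urlencoding the payload into the query string; a quick trace of that expression:

from urllib.parse import urlencode

payload = {'engine': 'google', 'q': 'open webui', 'api_key': 'sk-placeholder'}
url = f'https://www.searchapi.io/api/v1/search?{urlencode(payload)}'
# -> https://www.searchapi.io/api/v1/search?engine=google&q=open+webui&api_key=sk-placeholder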

View File

@@ -38,38 +38,38 @@ def search_searxng(
"""
# Default values for optional parameters are provided as empty strings or None when not specified.
language = kwargs.get("language", "all")
safesearch = kwargs.get("safesearch", "1")
time_range = kwargs.get("time_range", "")
categories = "".join(kwargs.get("categories", []))
language = kwargs.get('language', 'all')
safesearch = kwargs.get('safesearch', '1')
time_range = kwargs.get('time_range', '')
categories = ''.join(kwargs.get('categories', []))
params = {
"q": query,
"format": "json",
"pageno": 1,
"safesearch": safesearch,
"language": language,
"time_range": time_range,
"categories": categories,
"theme": "simple",
"image_proxy": 0,
'q': query,
'format': 'json',
'pageno': 1,
'safesearch': safesearch,
'language': language,
'time_range': time_range,
'categories': categories,
'theme': 'simple',
'image_proxy': 0,
}
# Legacy query format
if "<query>" in query_url:
if '<query>' in query_url:
# Strip all query parameters from the URL
query_url = query_url.split("?")[0]
query_url = query_url.split('?')[0]
log.debug(f"searching {query_url}")
log.debug(f'searching {query_url}')
response = requests.get(
query_url,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Accept": "text/html",
"Accept-Encoding": "gzip, deflate",
"Accept-Language": "en-US,en;q=0.5",
"Connection": "keep-alive",
'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) RAG Bot',
'Accept': 'text/html',
'Accept-Encoding': 'gzip, deflate',
'Accept-Language': 'en-US,en;q=0.5',
'Connection': 'keep-alive',
},
params=params,
)
@@ -77,13 +77,11 @@ def search_searxng(
response.raise_for_status() # Raise an exception for HTTP errors.
json_response = response.json()
results = json_response.get("results", [])
sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
results = json_response.get('results', [])
sorted_results = sorted(results, key=lambda x: x.get('score', 0), reverse=True)
if filter_list:
sorted_results = get_filtered_results(sorted_results, filter_list)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("content")
)
SearchResult(link=result['url'], title=result.get('title'), snippet=result.get('content'))
for result in sorted_results[:count]
]
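
The legacy branch above only normalises a templated instance URL; the real parameters still travel in `params`. A small illustration with a made-up instance:

query_url = 'https://searx.example.org/search?q=<query>'   # hypothetical instance URL
if '<query>' in query_url:
    query_url = query_url.split('?')[0]
# query_url is now 'https://searx.example.org/search'; q, format, safesearch, etc. go via params=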

View File

@@ -21,28 +21,26 @@ def search_serpapi(
api_key (str): A serpapi.com API key
query (str): The query to search for
"""
url = "https://serpapi.com/search"
url = 'https://serpapi.com/search'
engine = engine or "google"
engine = engine or 'google'
payload = {"engine": engine, "q": query, "api_key": api_key}
payload = {'engine': engine, 'q': query, 'api_key': api_key}
url = f"{url}?{urlencode(payload)}"
response = requests.request("GET", url)
url = f'{url}?{urlencode(payload)}'
response = requests.request('GET', url)
json_response = response.json()
log.info(f"results from serpapi search: {json_response}")
log.info(f'results from serpapi search: {json_response}')
results = sorted(
json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
)
results = sorted(json_response.get('organic_results', []), key=lambda x: x.get('position', 0))
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("snippet"),
link=result['link'],
title=result.get('title'),
snippet=result.get('snippet'),
)
for result in results[:count]
]

View File

@@ -8,34 +8,30 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results
log = logging.getLogger(__name__)
def search_serper(
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
) -> list[SearchResult]:
def search_serper(api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
Args:
api_key (str): A serper.dev API key
query (str): The query to search for
"""
url = "https://google.serper.dev/search"
url = 'https://google.serper.dev/search'
payload = json.dumps({"q": query})
headers = {"X-API-KEY": api_key, "Content-Type": "application/json"}
payload = json.dumps({'q': query})
headers = {'X-API-KEY': api_key, 'Content-Type': 'application/json'}
response = requests.request("POST", url, headers=headers, data=payload)
response = requests.request('POST', url, headers=headers, data=payload)
response.raise_for_status()
json_response = response.json()
results = sorted(
json_response.get("organic", []), key=lambda x: x.get("position", 0)
)
results = sorted(json_response.get('organic', []), key=lambda x: x.get('position', 0))
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("description"),
link=result['link'],
title=result.get('title'),
snippet=result.get('description'),
)
for result in results[:count]
]
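
A hedged usage sketch for the wrapper above; only the function and argument names come from this diff, and the key and domain filter are placeholders.

results = search_serper(
    api_key='placeholder-serper-key',
    query='vector database benchmarks',
    count=5,
    filter_list=['arxiv.org'],   # domain filtering handled by get_filtered_results()
)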

View File

@@ -12,10 +12,10 @@ def search_serply(
api_key: str,
query: str,
count: int,
hl: str = "us",
hl: str = 'us',
limit: int = 10,
device_type: str = "desktop",
proxy_location: str = "US",
device_type: str = 'desktop',
proxy_location: str = 'US',
filter_list: Optional[list[str]] = None,
) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
@@ -26,42 +26,40 @@ def search_serply(
hl (str): Host Language code to display results in (reference https://developers.google.com/custom-search/docs/xml_results?hl=en#wsInterfaceLanguages)
limit (int): The maximum number of results to return [10-100, defaults to 10]
"""
log.info("Searching with Serply")
log.info('Searching with Serply')
url = "https://api.serply.io/v1/search/"
url = 'https://api.serply.io/v1/search/'
query_payload = {
"q": query,
"language": "en",
"num": limit,
"gl": proxy_location.upper(),
"hl": hl.lower(),
'q': query,
'language': 'en',
'num': limit,
'gl': proxy_location.upper(),
'hl': hl.lower(),
}
url = f"{url}{urlencode(query_payload)}"
url = f'{url}{urlencode(query_payload)}'
headers = {
"X-API-KEY": api_key,
"X-User-Agent": device_type,
"User-Agent": "open-webui",
"X-Proxy-Location": proxy_location,
'X-API-KEY': api_key,
'X-User-Agent': device_type,
'User-Agent': 'open-webui',
'X-Proxy-Location': proxy_location,
}
response = requests.request("GET", url, headers=headers)
response = requests.request('GET', url, headers=headers)
response.raise_for_status()
json_response = response.json()
log.info(f"results from serply search: {json_response}")
log.info(f'results from serply search: {json_response}')
results = sorted(
json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
)
results = sorted(json_response.get('results', []), key=lambda x: x.get('realPosition', 0))
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("description"),
link=result['link'],
title=result.get('title'),
snippet=result.get('description'),
)
for result in results[:count]
]
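
To show how the localisation arguments surface in the request (values are illustrative):

results = search_serply(
    api_key='placeholder-serply-key',
    query='meilleurs frameworks web',
    count=5,
    hl='fr',               # lowercased into the hl query parameter
    proxy_location='FR',   # uppercased into gl and sent as the X-Proxy-Location header
)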

View File

@@ -21,26 +21,22 @@ def search_serpstack(
query (str): The query to search for
https_enabled (bool): Whether to use HTTPS or HTTP for the API request
"""
url = f"{'https' if https_enabled else 'http'}://api.serpstack.com/search"
url = f'{"https" if https_enabled else "http"}://api.serpstack.com/search'
headers = {"Content-Type": "application/json"}
headers = {'Content-Type': 'application/json'}
params = {
"access_key": api_key,
"query": query,
'access_key': api_key,
'query': query,
}
response = requests.request("POST", url, headers=headers, params=params)
response = requests.request('POST', url, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
results = sorted(
json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
)
results = sorted(json_response.get('organic_results', []), key=lambda x: x.get('position', 0))
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
)
SearchResult(link=result['url'], title=result.get('title'), snippet=result.get('snippet'))
for result in results[:count]
]
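
The scheme toggle above lets the caller fall back to plain HTTP; tracing the f-string:

https_enabled = False   # e.g. a plan without HTTPS access
url = f'{"https" if https_enabled else "http"}://api.serpstack.com/search'
# -> 'http://api.serpstack.com/search'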

View File

@@ -26,33 +26,26 @@ def search_sougou(
try:
cred = credential.Credential(sougou_api_sid, sougou_api_sk)
http_profile = HttpProfile()
http_profile.endpoint = "tms.tencentcloudapi.com"
http_profile.endpoint = 'tms.tencentcloudapi.com'
client_profile = ClientProfile()
client_profile.http_profile = http_profile
params = json.dumps({"Query": query, "Cnt": 20})
common_client = CommonClient(
"tms", "2020-12-29", cred, "", profile=client_profile
)
params = json.dumps({'Query': query, 'Cnt': 20})
common_client = CommonClient('tms', '2020-12-29', cred, '', profile=client_profile)
results = [
json.loads(page)
for page in common_client.call_json("SearchPro", json.loads(params))[
"Response"
]["Pages"]
json.loads(page) for page in common_client.call_json('SearchPro', json.loads(params))['Response']['Pages']
]
sorted_results = sorted(
results, key=lambda x: x.get("scour", 0.0), reverse=True
)
sorted_results = sorted(results, key=lambda x: x.get('scour', 0.0), reverse=True)
if filter_list:
sorted_results = get_filtered_results(sorted_results, filter_list)
return [
SearchResult(
link=result.get("url"),
title=result.get("title"),
snippet=result.get("passage"),
link=result.get('url'),
title=result.get('title'),
snippet=result.get('passage'),
)
for result in sorted_results[:count]
]
except TencentCloudSDKException as err:
log.error(f"Error in Sougou search: {err}")
log.error(f'Error in Sougou search: {err}')
return []

View File

@@ -24,26 +24,26 @@ def search_tavily(
Returns:
list[SearchResult]: A list of search results
"""
url = "https://api.tavily.com/search"
url = 'https://api.tavily.com/search'
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
'Content-Type': 'application/json',
'Authorization': f'Bearer {api_key}',
}
data = {"query": query, "max_results": count}
data = {'query': query, 'max_results': count}
response = requests.post(url, headers=headers, json=data)
response.raise_for_status()
json_response = response.json()
results = json_response.get("results", [])
results = json_response.get('results', [])
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"],
title=result.get("title", ""),
snippet=result.get("content"),
link=result['url'],
title=result.get('title', ''),
snippet=result.get('content'),
)
for result in results
]
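
A minimal hedged sketch of calling the Tavily wrapper; the argument names are inferred from the body above and the key is a placeholder.

results = search_tavily(
    api_key='tvly-placeholder',
    query='site reliability engineering basics',
    count=3,
)
for r in results:
    print(r.link, (r.snippet or '')[:80])   # snippet may be None (result.get('content'))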

View File

@@ -67,16 +67,14 @@ def validate_url(url: Union[str, Sequence[str]]):
parsed_url = urllib.parse.urlparse(url)
# Protocol validation - only allow http/https
if parsed_url.scheme not in ["http", "https"]:
log.warning(
f"Blocked non-HTTP(S) protocol: {parsed_url.scheme} in URL: {url}"
)
if parsed_url.scheme not in ['http', 'https']:
log.warning(f'Blocked non-HTTP(S) protocol: {parsed_url.scheme} in URL: {url}')
raise ValueError(ERROR_MESSAGES.INVALID_URL)
# Blocklist check using unified filtering logic
if WEB_FETCH_FILTER_LIST:
if not is_string_allowed(url, WEB_FETCH_FILTER_LIST):
log.warning(f"URL blocked by filter list: {url}")
log.warning(f'URL blocked by filter list: {url}')
raise ValueError(ERROR_MESSAGES.INVALID_URL)
if not ENABLE_RAG_LOCAL_WEB_FETCH:
@@ -106,29 +104,29 @@ def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
if validate_url(u):
valid_urls.append(u)
except Exception as e:
log.debug(f"Invalid URL {u}: {str(e)}")
log.debug(f'Invalid URL {u}: {str(e)}')
continue
return valid_urls
def extract_metadata(soup, url):
metadata = {"source": url}
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get("content", "No description found.")
if html := soup.find("html"):
metadata["language"] = html.get("lang", "No language found.")
metadata = {'source': url}
if title := soup.find('title'):
metadata['title'] = title.get_text()
if description := soup.find('meta', attrs={'name': 'description'}):
metadata['description'] = description.get('content', 'No description found.')
if html := soup.find('html'):
metadata['language'] = html.get('lang', 'No language found.')
return metadata
def verify_ssl_cert(url: str) -> bool:
"""Verify SSL certificate for the given URL."""
if not url.startswith("https://"):
if not url.startswith('https://'):
return True
try:
hostname = url.split("://")[-1].split("/")[0]
hostname = url.split('://')[-1].split('/')[0]
context = ssl.create_default_context(cafile=certifi.where())
with context.wrap_socket(ssl.socket(), server_hostname=hostname) as s:
s.connect((hostname, 443))
@@ -136,7 +134,7 @@ def verify_ssl_cert(url: str) -> bool:
except ssl.SSLError:
return False
except Exception as e:
log.warning(f"SSL verification failed for {url}: {str(e)}")
log.warning(f'SSL verification failed for {url}: {str(e)}')
return False
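
The hostname is recovered by plain string splitting before the TLS handshake on port 443; for a typical URL:

url = 'https://example.com/robots.txt'
hostname = url.split('://')[-1].split('/')[0]   # -> 'example.com'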
@@ -168,14 +166,14 @@ class URLProcessingMixin:
async def _safe_process_url(self, url: str) -> bool:
"""Perform safety checks before processing a URL."""
if self.verify_ssl and not await self._verify_ssl_cert(url):
raise ValueError(f"SSL certificate verification failed for {url}")
raise ValueError(f'SSL certificate verification failed for {url}')
await self._wait_for_rate_limit()
return True
def _safe_process_url_sync(self, url: str) -> bool:
"""Synchronous version of safety checks."""
if self.verify_ssl and not verify_ssl_cert(url):
raise ValueError(f"SSL certificate verification failed for {url}")
raise ValueError(f'SSL certificate verification failed for {url}')
self._sync_wait_for_rate_limit()
return True
@@ -191,7 +189,7 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
api_key: Optional[str] = None,
api_url: Optional[str] = None,
timeout: Optional[int] = None,
mode: Literal["crawl", "scrape", "map"] = "scrape",
mode: Literal['crawl', 'scrape', 'map'] = 'scrape',
proxy: Optional[Dict[str, str]] = None,
params: Optional[Dict] = None,
):
@@ -216,15 +214,15 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
params: The parameters to pass to the Firecrawl API.
For more details, visit: https://docs.firecrawl.dev/sdks/python#batch-scrape
"""
proxy_server = proxy.get("server") if proxy else None
proxy_server = proxy.get('server') if proxy else None
if trust_env and not proxy_server:
env_proxies = urllib.request.getproxies()
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
env_proxy_server = env_proxies.get('https') or env_proxies.get('http')
if env_proxy_server:
if proxy:
proxy["server"] = env_proxy_server
proxy['server'] = env_proxy_server
else:
proxy = {"server": env_proxy_server}
proxy = {'server': env_proxy_server}
self.web_paths = web_paths
self.verify_ssl = verify_ssl
self.requests_per_second = requests_per_second
@@ -240,7 +238,7 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
def lazy_load(self) -> Iterator[Document]:
"""Load documents using FireCrawl batch_scrape."""
log.debug(
"Starting FireCrawl batch scrape for %d URLs, mode: %s, params: %s",
'Starting FireCrawl batch scrape for %d URLs, mode: %s, params: %s',
len(self.web_paths),
self.mode,
self.params,
@@ -251,7 +249,7 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
firecrawl = FirecrawlApp(api_key=self.api_key, api_url=self.api_url)
result = firecrawl.batch_scrape(
self.web_paths,
formats=["markdown"],
formats=['markdown'],
skip_tls_verification=not self.verify_ssl,
ignore_invalid_urls=True,
remove_base64_images=True,
@@ -260,28 +258,26 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
**self.params,
)
if result.status != "completed":
raise RuntimeError(
f"FireCrawl batch scrape did not complete successfully. result: {result}"
)
if result.status != 'completed':
raise RuntimeError(f'FireCrawl batch scrape did not complete successfully. result: {result}')
for data in result.data:
metadata = data.metadata or {}
yield Document(
page_content=data.markdown or "",
metadata={"source": metadata.url or metadata.source_url or ""},
page_content=data.markdown or '',
metadata={'source': metadata.url or metadata.source_url or ''},
)
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error extracting content from URLs: {e}")
log.exception(f'Error extracting content from URLs: {e}')
else:
raise e
async def alazy_load(self):
"""Async version of lazy_load."""
log.debug(
"Starting FireCrawl batch scrape for %d URLs, mode: %s, params: %s",
'Starting FireCrawl batch scrape for %d URLs, mode: %s, params: %s',
len(self.web_paths),
self.mode,
self.params,
@@ -292,7 +288,7 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
firecrawl = FirecrawlApp(api_key=self.api_key, api_url=self.api_url)
result = firecrawl.batch_scrape(
self.web_paths,
formats=["markdown"],
formats=['markdown'],
skip_tls_verification=not self.verify_ssl,
ignore_invalid_urls=True,
remove_base64_images=True,
@@ -301,21 +297,19 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
**self.params,
)
if result.status != "completed":
raise RuntimeError(
f"FireCrawl batch scrape did not complete successfully. result: {result}"
)
if result.status != 'completed':
raise RuntimeError(f'FireCrawl batch scrape did not complete successfully. result: {result}')
for data in result.data:
metadata = data.metadata or {}
yield Document(
page_content=data.markdown or "",
metadata={"source": metadata.url or metadata.source_url or ""},
page_content=data.markdown or '',
metadata={'source': metadata.url or metadata.source_url or ''},
)
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error extracting content from URLs: {e}")
log.exception(f'Error extracting content from URLs: {e}')
else:
raise e
@@ -325,7 +319,7 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
self,
web_paths: Union[str, List[str]],
api_key: str,
extract_depth: Literal["basic", "advanced"] = "basic",
extract_depth: Literal['basic', 'advanced'] = 'basic',
continue_on_failure: bool = True,
requests_per_second: Optional[float] = None,
verify_ssl: bool = True,
@@ -345,15 +339,15 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
proxy: Optional proxy configuration.
"""
# Initialize proxy configuration if using environment variables
proxy_server = proxy.get("server") if proxy else None
proxy_server = proxy.get('server') if proxy else None
if trust_env and not proxy_server:
env_proxies = urllib.request.getproxies()
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
env_proxy_server = env_proxies.get('https') or env_proxies.get('http')
if env_proxy_server:
if proxy:
proxy["server"] = env_proxy_server
proxy['server'] = env_proxy_server
else:
proxy = {"server": env_proxy_server}
proxy = {'server': env_proxy_server}
# Store parameters for creating TavilyLoader instances
self.web_paths = web_paths if isinstance(web_paths, list) else [web_paths]
@@ -376,14 +370,14 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
self._safe_process_url_sync(url)
valid_urls.append(url)
except Exception as e:
log.warning(f"SSL verification failed for {url}: {str(e)}")
log.warning(f'SSL verification failed for {url}: {str(e)}')
if not self.continue_on_failure:
raise e
if not valid_urls:
if self.continue_on_failure:
log.warning("No valid URLs to process after SSL verification")
log.warning('No valid URLs to process after SSL verification')
return
raise ValueError("No valid URLs to process after SSL verification")
raise ValueError('No valid URLs to process after SSL verification')
try:
loader = TavilyLoader(
urls=valid_urls,
@@ -394,7 +388,7 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
yield from loader.lazy_load()
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error extracting content from URLs: {e}")
log.exception(f'Error extracting content from URLs: {e}')
else:
raise e
@@ -406,15 +400,15 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
await self._safe_process_url(url)
valid_urls.append(url)
except Exception as e:
log.warning(f"SSL verification failed for {url}: {str(e)}")
log.warning(f'SSL verification failed for {url}: {str(e)}')
if not self.continue_on_failure:
raise e
if not valid_urls:
if self.continue_on_failure:
log.warning("No valid URLs to process after SSL verification")
log.warning('No valid URLs to process after SSL verification')
return
raise ValueError("No valid URLs to process after SSL verification")
raise ValueError('No valid URLs to process after SSL verification')
try:
loader = TavilyLoader(
@@ -427,7 +421,7 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
yield document
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error loading URLs: {e}")
log.exception(f'Error loading URLs: {e}')
else:
raise e
@@ -462,15 +456,15 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessing
):
"""Initialize with additional safety parameters and remote browser support."""
proxy_server = proxy.get("server") if proxy else None
proxy_server = proxy.get('server') if proxy else None
if trust_env and not proxy_server:
env_proxies = urllib.request.getproxies()
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
env_proxy_server = env_proxies.get('https') or env_proxies.get('http')
if env_proxy_server:
if proxy:
proxy["server"] = env_proxy_server
proxy['server'] = env_proxy_server
else:
proxy = {"server": env_proxy_server}
proxy = {'server': env_proxy_server}
# We'll set headless to False if using playwright_ws_url since it's handled by the remote browser
super().__init__(
@@ -504,14 +498,14 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessing
page = browser.new_page()
response = page.goto(url, timeout=self.playwright_timeout)
if response is None:
raise ValueError(f"page.goto() returned None for url {url}")
raise ValueError(f'page.goto() returned None for url {url}')
text = self.evaluator.evaluate(page, browser, response)
metadata = {"source": url}
metadata = {'source': url}
yield Document(page_content=text, metadata=metadata)
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error loading {url}: {e}")
log.exception(f'Error loading {url}: {e}')
continue
raise e
browser.close()
@@ -525,9 +519,7 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessing
if self.playwright_ws_url:
browser = await p.chromium.connect(self.playwright_ws_url)
else:
browser = await p.chromium.launch(
headless=self.headless, proxy=self.proxy
)
browser = await p.chromium.launch(headless=self.headless, proxy=self.proxy)
for url in self.urls:
try:
@@ -535,14 +527,14 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessing
page = await browser.new_page()
response = await page.goto(url, timeout=self.playwright_timeout)
if response is None:
raise ValueError(f"page.goto() returned None for url {url}")
raise ValueError(f'page.goto() returned None for url {url}')
text = await self.evaluator.evaluate_async(page, browser, response)
metadata = {"source": url}
metadata = {'source': url}
yield Document(page_content=text, metadata=metadata)
except Exception as e:
if self.continue_on_failure:
log.exception(f"Error loading {url}: {e}")
log.exception(f'Error loading {url}: {e}')
continue
raise e
await browser.close()
@@ -560,9 +552,7 @@ class SafeWebBaseLoader(WebBaseLoader):
super().__init__(*args, **kwargs)
self.trust_env = trust_env
async def _fetch(
self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
) -> str:
async def _fetch(self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5) -> str:
async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
for i in range(retries):
try:
@@ -571,7 +561,7 @@ class SafeWebBaseLoader(WebBaseLoader):
cookies=self.session.cookies.get_dict(),
)
if not self.session.verify:
kwargs["ssl"] = False
kwargs['ssl'] = False
async with session.get(
url,
@@ -585,16 +575,11 @@ class SafeWebBaseLoader(WebBaseLoader):
if i == retries - 1:
raise
else:
log.warning(
f"Error fetching {url} with attempt "
f"{i + 1}/{retries}: {e}. Retrying..."
)
log.warning(f'Error fetching {url} with attempt {i + 1}/{retries}: {e}. Retrying...')
await asyncio.sleep(cooldown * backoff**i)
raise ValueError("retry count exceeded")
raise ValueError('retry count exceeded')
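
With the defaults shown in the signature (retries=3, cooldown=2, backoff=1.5), the waits between attempts grow geometrically:

cooldown, backoff = 2, 1.5
waits = [cooldown * backoff ** i for i in range(2)]   # sleeps after attempts 1 and 2
# -> [2.0, 3.0]; the third failure re-raises instead of sleeping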
def _unpack_fetch_results(
self, results: Any, urls: List[str], parser: Union[str, None] = None
) -> List[Any]:
def _unpack_fetch_results(self, results: Any, urls: List[str], parser: Union[str, None] = None) -> List[Any]:
"""Unpack fetch results into BeautifulSoup objects."""
from bs4 import BeautifulSoup
@@ -602,17 +587,15 @@ class SafeWebBaseLoader(WebBaseLoader):
for i, result in enumerate(results):
url = urls[i]
if parser is None:
if url.endswith(".xml"):
parser = "xml"
if url.endswith('.xml'):
parser = 'xml'
else:
parser = self.default_parser
self._check_parser(parser)
final_results.append(BeautifulSoup(result, parser, **self.bs_kwargs))
return final_results
async def ascrape_all(
self, urls: List[str], parser: Union[str, None] = None
) -> List[Any]:
async def ascrape_all(self, urls: List[str], parser: Union[str, None] = None) -> List[Any]:
"""Async fetch all urls, then return soups for all results."""
results = await self.fetch_all(urls)
return self._unpack_fetch_results(results, urls, parser=parser)
@@ -630,22 +613,20 @@ class SafeWebBaseLoader(WebBaseLoader):
yield Document(page_content=text, metadata=metadata)
except Exception as e:
# Log the error and continue with the next URL
log.exception(f"Error loading {path}: {e}")
log.exception(f'Error loading {path}: {e}')
async def alazy_load(self) -> AsyncIterator[Document]:
"""Async lazy load text from the url(s) in web_path."""
results = await self.ascrape_all(self.web_paths)
for path, soup in zip(self.web_paths, results):
text = soup.get_text(**self.bs_get_text_kwargs)
metadata = {"source": path}
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get(
"content", "No description found."
)
if html := soup.find("html"):
metadata["language"] = html.get("lang", "No language found.")
metadata = {'source': path}
if title := soup.find('title'):
metadata['title'] = title.get_text()
if description := soup.find('meta', attrs={'name': 'description'}):
metadata['description'] = description.get('content', 'No description found.')
if html := soup.find('html'):
metadata['language'] = html.get('lang', 'No language found.')
yield Document(page_content=text, metadata=metadata)
async def aload(self) -> list[Document]:
@@ -663,18 +644,18 @@ def get_web_loader(
safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
if not safe_urls:
log.warning(f"All provided URLs were blocked or invalid: {urls}")
log.warning(f'All provided URLs were blocked or invalid: {urls}')
raise ValueError(ERROR_MESSAGES.INVALID_URL)
web_loader_args = {
"web_paths": safe_urls,
"verify_ssl": verify_ssl,
"requests_per_second": requests_per_second,
"continue_on_failure": True,
"trust_env": trust_env,
'web_paths': safe_urls,
'verify_ssl': verify_ssl,
'requests_per_second': requests_per_second,
'continue_on_failure': True,
'trust_env': trust_env,
}
if WEB_LOADER_ENGINE.value == "" or WEB_LOADER_ENGINE.value == "safe_web":
if WEB_LOADER_ENGINE.value == '' or WEB_LOADER_ENGINE.value == 'safe_web':
WebLoaderClass = SafeWebBaseLoader
request_kwargs = {}
@@ -685,42 +666,42 @@ def get_web_loader(
timeout_value = None
if timeout_value:
request_kwargs["timeout"] = timeout_value
request_kwargs['timeout'] = timeout_value
if request_kwargs:
web_loader_args["requests_kwargs"] = request_kwargs
web_loader_args['requests_kwargs'] = request_kwargs
if WEB_LOADER_ENGINE.value == "playwright":
if WEB_LOADER_ENGINE.value == 'playwright':
WebLoaderClass = SafePlaywrightURLLoader
web_loader_args["playwright_timeout"] = PLAYWRIGHT_TIMEOUT.value
web_loader_args['playwright_timeout'] = PLAYWRIGHT_TIMEOUT.value
if PLAYWRIGHT_WS_URL.value:
web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URL.value
web_loader_args['playwright_ws_url'] = PLAYWRIGHT_WS_URL.value
if WEB_LOADER_ENGINE.value == "firecrawl":
if WEB_LOADER_ENGINE.value == 'firecrawl':
WebLoaderClass = SafeFireCrawlLoader
web_loader_args["api_key"] = FIRECRAWL_API_KEY.value
web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value
web_loader_args['api_key'] = FIRECRAWL_API_KEY.value
web_loader_args['api_url'] = FIRECRAWL_API_BASE_URL.value
if FIRECRAWL_TIMEOUT.value:
try:
web_loader_args["timeout"] = int(FIRECRAWL_TIMEOUT.value)
web_loader_args['timeout'] = int(FIRECRAWL_TIMEOUT.value)
except ValueError:
pass
if WEB_LOADER_ENGINE.value == "tavily":
if WEB_LOADER_ENGINE.value == 'tavily':
WebLoaderClass = SafeTavilyLoader
web_loader_args["api_key"] = TAVILY_API_KEY.value
web_loader_args["extract_depth"] = TAVILY_EXTRACT_DEPTH.value
web_loader_args['api_key'] = TAVILY_API_KEY.value
web_loader_args['extract_depth'] = TAVILY_EXTRACT_DEPTH.value
if WEB_LOADER_ENGINE.value == "external":
if WEB_LOADER_ENGINE.value == 'external':
WebLoaderClass = ExternalWebLoader
web_loader_args["external_url"] = EXTERNAL_WEB_LOADER_URL.value
web_loader_args["external_api_key"] = EXTERNAL_WEB_LOADER_API_KEY.value
web_loader_args['external_url'] = EXTERNAL_WEB_LOADER_URL.value
web_loader_args['external_api_key'] = EXTERNAL_WEB_LOADER_API_KEY.value
if WebLoaderClass:
web_loader = WebLoaderClass(**web_loader_args)
log.debug(
"Using WEB_LOADER_ENGINE %s for %s URLs",
'Using WEB_LOADER_ENGINE %s for %s URLs',
web_loader.__class__.__name__,
len(safe_urls),
)
@@ -728,6 +709,6 @@ def get_web_loader(
return web_loader
else:
raise ValueError(
f"Invalid WEB_LOADER_ENGINE: {WEB_LOADER_ENGINE.value}. "
f'Invalid WEB_LOADER_ENGINE: {WEB_LOADER_ENGINE.value}. '
"Please set it to 'safe_web', 'playwright', 'firecrawl', or 'tavily'."
)
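
A hedged sketch of driving the factory above; get_web_loader's full signature is not visible in this hunk, so any argument name beyond those used in the body is an assumption.

loader = get_web_loader(
    ['https://docs.openwebui.com'],   # run through safe_validate_urls() internally
    verify_ssl=True,
    requests_per_second=2,
)
docs = loader.load()                  # or `await loader.aload()` on the async path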

View File

@@ -41,29 +41,29 @@ def search_yacy(
yacy_auth = HTTPDigestAuth(username, password)
params = {
"query": query,
"contentdom": "text",
"resource": "global",
"maximumRecords": count,
"nav": "none",
'query': query,
'contentdom': 'text',
'resource': 'global',
'maximumRecords': count,
'nav': 'none',
}
# Check whether a JSON API URL was provided
if not query_url.endswith("yacysearch.json"):
if not query_url.endswith('yacysearch.json'):
# Append the yacysearch.json endpoint to the base URL
query_url = query_url.rstrip("/") + "/yacysearch.json"
query_url = query_url.rstrip('/') + '/yacysearch.json'
log.debug(f"searching {query_url}")
log.debug(f'searching {query_url}')
response = requests.get(
query_url,
auth=yacy_auth,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Accept": "text/html",
"Accept-Encoding": "gzip, deflate",
"Accept-Language": "en-US,en;q=0.5",
"Connection": "keep-alive",
'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) RAG Bot',
'Accept': 'text/html',
'Accept-Encoding': 'gzip, deflate',
'Accept-Language': 'en-US,en;q=0.5',
'Connection': 'keep-alive',
},
params=params,
)
@@ -71,15 +71,15 @@ def search_yacy(
response.raise_for_status() # Raise an exception for HTTP errors.
json_response = response.json()
results = json_response.get("channels", [{}])[0].get("items", [])
sorted_results = sorted(results, key=lambda x: x.get("ranking", 0), reverse=True)
results = json_response.get('channels', [{}])[0].get('items', [])
sorted_results = sorted(results, key=lambda x: x.get('ranking', 0), reverse=True)
if filter_list:
sorted_results = get_filtered_results(sorted_results, filter_list)
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("description"),
link=result['link'],
title=result.get('title'),
snippet=result.get('description'),
)
for result in sorted_results[:count]
]

View File

@@ -20,14 +20,14 @@ log = logging.getLogger(__name__)
def xml_element_contents_to_string(element: Element) -> str:
buffer = [element.text if element.text else ""]
buffer = [element.text if element.text else '']
for child in element:
buffer.append(xml_element_contents_to_string(child))
buffer.append(element.tail if element.tail else "")
buffer.append(element.tail if element.tail else '')
return "".join(buffer)
return ''.join(buffer)
def search_yandex(
@@ -42,42 +42,38 @@ def search_yandex(
) -> List[SearchResult]:
try:
headers = {
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Authorization": f"Api-Key {yandex_search_api_key}",
'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) RAG Bot',
'Authorization': f'Api-Key {yandex_search_api_key}',
}
if user is not None:
headers = include_user_info_headers(headers, user)
chat_id = getattr(request.state, "chat_id", None)
chat_id = getattr(request.state, 'chat_id', None)
if chat_id:
headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = str(chat_id)
payload = {} if yandex_search_config == "" else json.loads(yandex_search_config)
payload = {} if yandex_search_config == '' else json.loads(yandex_search_config)
if type(payload.get("query", None)) != dict:
payload["query"] = {}
if type(payload.get('query', None)) != dict:
payload['query'] = {}
if "searchType" not in payload["query"]:
payload["query"]["searchType"] = "SEARCH_TYPE_RU"
if 'searchType' not in payload['query']:
payload['query']['searchType'] = 'SEARCH_TYPE_RU'
payload["query"]["queryText"] = query
payload['query']['queryText'] = query
if type(payload.get("groupSpec", None)) != dict:
payload["groupSpec"] = {}
if type(payload.get('groupSpec', None)) != dict:
payload['groupSpec'] = {}
if "groupMode" not in payload["groupSpec"]:
payload["groupSpec"]["groupMode"] = "GROUP_MODE_DEEP"
if 'groupMode' not in payload['groupSpec']:
payload['groupSpec']['groupMode'] = 'GROUP_MODE_DEEP'
payload["groupSpec"]["groupsOnPage"] = count
payload["groupSpec"]["docsInGroup"] = 1
payload['groupSpec']['groupsOnPage'] = count
payload['groupSpec']['docsInGroup'] = 1
response = requests.post(
(
"https://searchapi.api.cloud.yandex.net/v2/web/search"
if yandex_search_url == ""
else yandex_search_url
),
('https://searchapi.api.cloud.yandex.net/v2/web/search' if yandex_search_url == '' else yandex_search_url),
headers=headers,
json=payload,
)
@@ -85,29 +81,21 @@ def search_yandex(
response.raise_for_status()
response_body = response.json()
if "rawData" not in response_body:
raise Exception(f"No `rawData` in response body: {response_body}")
if 'rawData' not in response_body:
raise Exception(f'No `rawData` in response body: {response_body}')
search_result_body_bytes = base64.decodebytes(
bytes(response_body["rawData"], "utf-8")
)
search_result_body_bytes = base64.decodebytes(bytes(response_body['rawData'], 'utf-8'))
doc_root = ET.parse(io.BytesIO(search_result_body_bytes))
results = []
for group in doc_root.findall("response/results/grouping/group"):
for group in doc_root.findall('response/results/grouping/group'):
results.append(
{
"url": xml_element_contents_to_string(group.find("doc/url")).strip(
"\n"
),
"title": xml_element_contents_to_string(
group.find("doc/title")
).strip("\n"),
"snippet": xml_element_contents_to_string(
group.find("doc/passages/passage")
),
'url': xml_element_contents_to_string(group.find('doc/url')).strip('\n'),
'title': xml_element_contents_to_string(group.find('doc/title')).strip('\n'),
'snippet': xml_element_contents_to_string(group.find('doc/passages/passage')),
}
)
@@ -115,49 +103,47 @@ def search_yandex(
results = [
SearchResult(
link=result.get("url"),
title=result.get("title"),
snippet=result.get("snippet"),
link=result.get('url'),
title=result.get('title'),
snippet=result.get('snippet'),
)
for result in results[:count]
]
log.info(f"Yandex search results: {results}")
log.info(f'Yandex search results: {results}')
return results
except Exception as e:
log.error(f"Error in search: {e}")
log.error(f'Error in search: {e}')
return []
if __name__ == "__main__":
if __name__ == '__main__':
from starlette.datastructures import Headers
from fastapi import FastAPI
result = search_yandex(
Request(
{
"type": "http",
"asgi.version": "3.0",
"asgi.spec_version": "2.0",
"method": "GET",
"path": "/internal",
"query_string": b"",
"headers": Headers({}).raw,
"client": ("127.0.0.1", 12345),
"server": ("127.0.0.1", 80),
"scheme": "http",
"app": FastAPI(),
'type': 'http',
'asgi.version': '3.0',
'asgi.spec_version': '2.0',
'method': 'GET',
'path': '/internal',
'query_string': b'',
'headers': Headers({}).raw,
'client': ('127.0.0.1', 12345),
'server': ('127.0.0.1', 80),
'scheme': 'http',
'app': FastAPI(),
},
None,
),
os.environ.get("YANDEX_WEB_SEARCH_URL", ""),
os.environ.get("YANDEX_WEB_SEARCH_API_KEY", ""),
os.environ.get(
"YANDEX_WEB_SEARCH_CONFIG", '{"query": {"searchType": "SEARCH_TYPE_COM"}}'
),
"TOP movies of the past year",
os.environ.get('YANDEX_WEB_SEARCH_URL', ''),
os.environ.get('YANDEX_WEB_SEARCH_API_KEY', ''),
os.environ.get('YANDEX_WEB_SEARCH_CONFIG', '{"query": {"searchType": "SEARCH_TYPE_COM"}}'),
'TOP movies of the past year',
3,
)

View File

@@ -12,7 +12,7 @@ def search_youcom(
query: str,
count: int,
filter_list: Optional[List[str]] = None,
language: str = "EN",
language: str = 'EN',
) -> List[SearchResult]:
"""Search using You.com's YDC Index API and return the results as a list of SearchResult objects.
@@ -23,30 +23,30 @@ def search_youcom(
filter_list (list[str], optional): Domain filter list
language (str): Language code for search results (default: "EN")
"""
url = "https://ydc-index.io/v1/search"
url = 'https://ydc-index.io/v1/search'
headers = {
"Accept": "application/json",
"X-API-KEY": api_key,
'Accept': 'application/json',
'X-API-KEY': api_key,
}
params = {
"query": query,
"count": count,
"language": language,
'query': query,
'count': count,
'language': language,
}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
results = json_response.get("results", {}).get("web", [])
results = json_response.get('results', {}).get('web', [])
if filter_list:
results = get_filtered_results(results, filter_list)
return [
SearchResult(
link=result["url"],
title=result.get("title"),
link=result['url'],
title=result.get('title'),
snippet=_build_snippet(result),
)
for result in results[:count]
@@ -62,12 +62,12 @@ def _build_snippet(result: dict) -> str:
"""
parts: list[str] = []
description = result.get("description")
description = result.get('description')
if description:
parts.append(description)
snippets = result.get("snippets")
snippets = result.get('snippets')
if snippets and isinstance(snippets, list):
parts.extend(snippets)
return "\n\n".join(parts)
return '\n\n'.join(parts)
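
_build_snippet simply concatenates the description with any extra snippets, separated by blank lines; a worked example:

result = {
    'description': 'Open WebUI is a self-hosted AI interface.',
    'snippets': ['It supports RAG.', 'It integrates web search.'],
}
_build_snippet(result)
# -> 'Open WebUI is a self-hosted AI interface.\n\nIt supports RAG.\n\nIt integrates web search.'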