mirror of
https://github.com/open-webui/open-webui.git
synced 2026-04-26 01:25:34 +02:00
refac
This commit is contained in:
@@ -40,78 +40,76 @@ class DatalabMarkerLoader:
|
||||
self.output_format = output_format
|
||||
|
||||
def _get_mime_type(self, filename: str) -> str:
|
||||
ext = filename.rsplit(".", 1)[-1].lower()
|
||||
ext = filename.rsplit('.', 1)[-1].lower()
|
||||
mime_map = {
|
||||
"pdf": "application/pdf",
|
||||
"xls": "application/vnd.ms-excel",
|
||||
"xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
"ods": "application/vnd.oasis.opendocument.spreadsheet",
|
||||
"doc": "application/msword",
|
||||
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"odt": "application/vnd.oasis.opendocument.text",
|
||||
"ppt": "application/vnd.ms-powerpoint",
|
||||
"pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
"odp": "application/vnd.oasis.opendocument.presentation",
|
||||
"html": "text/html",
|
||||
"epub": "application/epub+zip",
|
||||
"png": "image/png",
|
||||
"jpeg": "image/jpeg",
|
||||
"jpg": "image/jpeg",
|
||||
"webp": "image/webp",
|
||||
"gif": "image/gif",
|
||||
"tiff": "image/tiff",
|
||||
'pdf': 'application/pdf',
|
||||
'xls': 'application/vnd.ms-excel',
|
||||
'xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||
'ods': 'application/vnd.oasis.opendocument.spreadsheet',
|
||||
'doc': 'application/msword',
|
||||
'docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'odt': 'application/vnd.oasis.opendocument.text',
|
||||
'ppt': 'application/vnd.ms-powerpoint',
|
||||
'pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||
'odp': 'application/vnd.oasis.opendocument.presentation',
|
||||
'html': 'text/html',
|
||||
'epub': 'application/epub+zip',
|
||||
'png': 'image/png',
|
||||
'jpeg': 'image/jpeg',
|
||||
'jpg': 'image/jpeg',
|
||||
'webp': 'image/webp',
|
||||
'gif': 'image/gif',
|
||||
'tiff': 'image/tiff',
|
||||
}
|
||||
return mime_map.get(ext, "application/octet-stream")
|
||||
return mime_map.get(ext, 'application/octet-stream')
|
||||
|
||||
def check_marker_request_status(self, request_id: str) -> dict:
|
||||
url = f"{self.api_base_url}/{request_id}"
|
||||
headers = {"X-Api-Key": self.api_key}
|
||||
url = f'{self.api_base_url}/{request_id}'
|
||||
headers = {'X-Api-Key': self.api_key}
|
||||
try:
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
log.info(f"Marker API status check for request {request_id}: {result}")
|
||||
log.info(f'Marker API status check for request {request_id}: {result}')
|
||||
return result
|
||||
except requests.HTTPError as e:
|
||||
log.error(f"Error checking Marker request status: {e}")
|
||||
log.error(f'Error checking Marker request status: {e}')
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Failed to check Marker request: {e}",
|
||||
detail=f'Failed to check Marker request: {e}',
|
||||
)
|
||||
except ValueError as e:
|
||||
log.error(f"Invalid JSON checking Marker request: {e}")
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON: {e}"
|
||||
)
|
||||
log.error(f'Invalid JSON checking Marker request: {e}')
|
||||
raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Invalid JSON: {e}')
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
filename = os.path.basename(self.file_path)
|
||||
mime_type = self._get_mime_type(filename)
|
||||
headers = {"X-Api-Key": self.api_key}
|
||||
headers = {'X-Api-Key': self.api_key}
|
||||
|
||||
form_data = {
|
||||
"use_llm": str(self.use_llm).lower(),
|
||||
"skip_cache": str(self.skip_cache).lower(),
|
||||
"force_ocr": str(self.force_ocr).lower(),
|
||||
"paginate": str(self.paginate).lower(),
|
||||
"strip_existing_ocr": str(self.strip_existing_ocr).lower(),
|
||||
"disable_image_extraction": str(self.disable_image_extraction).lower(),
|
||||
"format_lines": str(self.format_lines).lower(),
|
||||
"output_format": self.output_format,
|
||||
'use_llm': str(self.use_llm).lower(),
|
||||
'skip_cache': str(self.skip_cache).lower(),
|
||||
'force_ocr': str(self.force_ocr).lower(),
|
||||
'paginate': str(self.paginate).lower(),
|
||||
'strip_existing_ocr': str(self.strip_existing_ocr).lower(),
|
||||
'disable_image_extraction': str(self.disable_image_extraction).lower(),
|
||||
'format_lines': str(self.format_lines).lower(),
|
||||
'output_format': self.output_format,
|
||||
}
|
||||
|
||||
if self.additional_config and self.additional_config.strip():
|
||||
form_data["additional_config"] = self.additional_config
|
||||
form_data['additional_config'] = self.additional_config
|
||||
|
||||
log.info(
|
||||
f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}"
|
||||
)
|
||||
|
||||
try:
|
||||
with open(self.file_path, "rb") as f:
|
||||
files = {"file": (filename, f, mime_type)}
|
||||
with open(self.file_path, 'rb') as f:
|
||||
files = {'file': (filename, f, mime_type)}
|
||||
response = requests.post(
|
||||
f"{self.api_base_url}",
|
||||
f'{self.api_base_url}',
|
||||
data=form_data,
|
||||
files=files,
|
||||
headers=headers,
|
||||
@@ -119,29 +117,25 @@ class DatalabMarkerLoader:
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(
|
||||
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
|
||||
)
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}')
|
||||
except requests.HTTPError as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Datalab Marker request failed: {e}",
|
||||
detail=f'Datalab Marker request failed: {e}',
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON response: {e}"
|
||||
)
|
||||
raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Invalid JSON response: {e}')
|
||||
except Exception as e:
|
||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
|
||||
|
||||
if not result.get("success"):
|
||||
if not result.get('success'):
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Datalab Marker request failed: {result.get('error', 'Unknown error')}",
|
||||
detail=f'Datalab Marker request failed: {result.get("error", "Unknown error")}',
|
||||
)
|
||||
|
||||
check_url = result.get("request_check_url")
|
||||
request_id = result.get("request_id")
|
||||
check_url = result.get('request_check_url')
|
||||
request_id = result.get('request_id')
|
||||
|
||||
# Check if this is a direct response (self-hosted) or polling response (DataLab)
|
||||
if check_url:
|
||||
@@ -154,54 +148,45 @@ class DatalabMarkerLoader:
|
||||
poll_result = poll_response.json()
|
||||
except (requests.HTTPError, ValueError) as e:
|
||||
raw_body = poll_response.text
|
||||
log.error(f"Polling error: {e}, response body: {raw_body}")
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}"
|
||||
)
|
||||
log.error(f'Polling error: {e}, response body: {raw_body}')
|
||||
raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Polling failed: {e}')
|
||||
|
||||
status_val = poll_result.get("status")
|
||||
success_val = poll_result.get("success")
|
||||
status_val = poll_result.get('status')
|
||||
success_val = poll_result.get('success')
|
||||
|
||||
if status_val == "complete":
|
||||
if status_val == 'complete':
|
||||
summary = {
|
||||
k: poll_result.get(k)
|
||||
for k in (
|
||||
"status",
|
||||
"output_format",
|
||||
"success",
|
||||
"error",
|
||||
"page_count",
|
||||
"total_cost",
|
||||
'status',
|
||||
'output_format',
|
||||
'success',
|
||||
'error',
|
||||
'page_count',
|
||||
'total_cost',
|
||||
)
|
||||
}
|
||||
log.info(
|
||||
f"Marker processing completed successfully: {json.dumps(summary, indent=2)}"
|
||||
)
|
||||
log.info(f'Marker processing completed successfully: {json.dumps(summary, indent=2)}')
|
||||
break
|
||||
|
||||
if status_val == "failed" or success_val is False:
|
||||
log.error(
|
||||
f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}"
|
||||
)
|
||||
error_msg = (
|
||||
poll_result.get("error")
|
||||
or "Marker returned failure without error message"
|
||||
)
|
||||
if status_val == 'failed' or success_val is False:
|
||||
log.error(f'Marker poll failed full response: {json.dumps(poll_result, indent=2)}')
|
||||
error_msg = poll_result.get('error') or 'Marker returned failure without error message'
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Marker processing failed: {error_msg}",
|
||||
detail=f'Marker processing failed: {error_msg}',
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
detail="Marker processing timed out",
|
||||
detail='Marker processing timed out',
|
||||
)
|
||||
|
||||
if not poll_result.get("success", False):
|
||||
error_msg = poll_result.get("error") or "Unknown processing error"
|
||||
if not poll_result.get('success', False):
|
||||
error_msg = poll_result.get('error') or 'Unknown processing error'
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Final processing failed: {error_msg}",
|
||||
detail=f'Final processing failed: {error_msg}',
|
||||
)
|
||||
|
||||
# DataLab format - content in format-specific fields
|
||||
@@ -210,69 +195,65 @@ class DatalabMarkerLoader:
|
||||
final_result = poll_result
|
||||
else:
|
||||
# Self-hosted direct response - content in "output" field
|
||||
if "output" in result:
|
||||
log.info("Self-hosted Marker returned direct response without polling")
|
||||
raw_content = result.get("output")
|
||||
if 'output' in result:
|
||||
log.info('Self-hosted Marker returned direct response without polling')
|
||||
raw_content = result.get('output')
|
||||
final_result = result
|
||||
else:
|
||||
available_fields = (
|
||||
list(result.keys())
|
||||
if isinstance(result, dict)
|
||||
else "non-dict response"
|
||||
)
|
||||
available_fields = list(result.keys()) if isinstance(result, dict) else 'non-dict response'
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Custom Marker endpoint returned success but no 'output' field found. Available fields: {available_fields}. Expected either 'request_check_url' for polling or 'output' field for direct response.",
|
||||
)
|
||||
|
||||
if self.output_format.lower() == "json":
|
||||
if self.output_format.lower() == 'json':
|
||||
full_text = json.dumps(raw_content, indent=2)
|
||||
elif self.output_format.lower() in {"markdown", "html"}:
|
||||
elif self.output_format.lower() in {'markdown', 'html'}:
|
||||
full_text = str(raw_content).strip()
|
||||
else:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported output format: {self.output_format}",
|
||||
detail=f'Unsupported output format: {self.output_format}',
|
||||
)
|
||||
|
||||
if not full_text:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail="Marker returned empty content",
|
||||
detail='Marker returned empty content',
|
||||
)
|
||||
|
||||
marker_output_dir = os.path.join("/app/backend/data/uploads", "marker_output")
|
||||
marker_output_dir = os.path.join('/app/backend/data/uploads', 'marker_output')
|
||||
os.makedirs(marker_output_dir, exist_ok=True)
|
||||
|
||||
file_ext_map = {"markdown": "md", "json": "json", "html": "html"}
|
||||
file_ext = file_ext_map.get(self.output_format.lower(), "txt")
|
||||
output_filename = f"{os.path.splitext(filename)[0]}.{file_ext}"
|
||||
file_ext_map = {'markdown': 'md', 'json': 'json', 'html': 'html'}
|
||||
file_ext = file_ext_map.get(self.output_format.lower(), 'txt')
|
||||
output_filename = f'{os.path.splitext(filename)[0]}.{file_ext}'
|
||||
output_path = os.path.join(marker_output_dir, output_filename)
|
||||
|
||||
try:
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(full_text)
|
||||
log.info(f"Saved Marker output to: {output_path}")
|
||||
log.info(f'Saved Marker output to: {output_path}')
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to write marker output to disk: {e}")
|
||||
log.warning(f'Failed to write marker output to disk: {e}')
|
||||
|
||||
metadata = {
|
||||
"source": filename,
|
||||
"output_format": final_result.get("output_format", self.output_format),
|
||||
"page_count": final_result.get("page_count", 0),
|
||||
"processed_with_llm": self.use_llm,
|
||||
"request_id": request_id or "",
|
||||
'source': filename,
|
||||
'output_format': final_result.get('output_format', self.output_format),
|
||||
'page_count': final_result.get('page_count', 0),
|
||||
'processed_with_llm': self.use_llm,
|
||||
'request_id': request_id or '',
|
||||
}
|
||||
|
||||
images = final_result.get("images", {})
|
||||
images = final_result.get('images', {})
|
||||
if images:
|
||||
metadata["image_count"] = len(images)
|
||||
metadata["images"] = json.dumps(list(images.keys()))
|
||||
metadata['image_count'] = len(images)
|
||||
metadata['images'] = json.dumps(list(images.keys()))
|
||||
|
||||
for k, v in metadata.items():
|
||||
if isinstance(v, (dict, list)):
|
||||
metadata[k] = json.dumps(v)
|
||||
elif v is None:
|
||||
metadata[k] = ""
|
||||
metadata[k] = ''
|
||||
|
||||
return [Document(page_content=full_text, metadata=metadata)]
|
||||
|
||||
@@ -29,18 +29,18 @@ class ExternalDocumentLoader(BaseLoader):
|
||||
self.user = user
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
with open(self.file_path, "rb") as f:
|
||||
with open(self.file_path, 'rb') as f:
|
||||
data = f.read()
|
||||
|
||||
headers = {}
|
||||
if self.mime_type is not None:
|
||||
headers["Content-Type"] = self.mime_type
|
||||
headers['Content-Type'] = self.mime_type
|
||||
|
||||
if self.api_key is not None:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
headers['Authorization'] = f'Bearer {self.api_key}'
|
||||
|
||||
try:
|
||||
headers["X-Filename"] = quote(os.path.basename(self.file_path))
|
||||
headers['X-Filename'] = quote(os.path.basename(self.file_path))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -48,24 +48,23 @@ class ExternalDocumentLoader(BaseLoader):
|
||||
headers = include_user_info_headers(headers, self.user)
|
||||
|
||||
url = self.url
|
||||
if url.endswith("/"):
|
||||
if url.endswith('/'):
|
||||
url = url[:-1]
|
||||
|
||||
try:
|
||||
response = requests.put(f"{url}/process", data=data, headers=headers)
|
||||
response = requests.put(f'{url}/process', data=data, headers=headers)
|
||||
except Exception as e:
|
||||
log.error(f"Error connecting to endpoint: {e}")
|
||||
raise Exception(f"Error connecting to endpoint: {e}")
|
||||
log.error(f'Error connecting to endpoint: {e}')
|
||||
raise Exception(f'Error connecting to endpoint: {e}')
|
||||
|
||||
if response.ok:
|
||||
|
||||
response_data = response.json()
|
||||
if response_data:
|
||||
if isinstance(response_data, dict):
|
||||
return [
|
||||
Document(
|
||||
page_content=response_data.get("page_content"),
|
||||
metadata=response_data.get("metadata"),
|
||||
page_content=response_data.get('page_content'),
|
||||
metadata=response_data.get('metadata'),
|
||||
)
|
||||
]
|
||||
elif isinstance(response_data, list):
|
||||
@@ -73,17 +72,15 @@ class ExternalDocumentLoader(BaseLoader):
|
||||
for document in response_data:
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=document.get("page_content"),
|
||||
metadata=document.get("metadata"),
|
||||
page_content=document.get('page_content'),
|
||||
metadata=document.get('metadata'),
|
||||
)
|
||||
)
|
||||
return documents
|
||||
else:
|
||||
raise Exception("Error loading document: Unable to parse content")
|
||||
raise Exception('Error loading document: Unable to parse content')
|
||||
|
||||
else:
|
||||
raise Exception("Error loading document: No content returned")
|
||||
raise Exception('Error loading document: No content returned')
|
||||
else:
|
||||
raise Exception(
|
||||
f"Error loading document: {response.status_code} {response.text}"
|
||||
)
|
||||
raise Exception(f'Error loading document: {response.status_code} {response.text}')
|
||||
|
||||
@@ -30,22 +30,22 @@ class ExternalWebLoader(BaseLoader):
|
||||
response = requests.post(
|
||||
self.external_url,
|
||||
headers={
|
||||
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) External Web Loader",
|
||||
"Authorization": f"Bearer {self.external_api_key}",
|
||||
'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) External Web Loader',
|
||||
'Authorization': f'Bearer {self.external_api_key}',
|
||||
},
|
||||
json={
|
||||
"urls": urls,
|
||||
'urls': urls,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
results = response.json()
|
||||
for result in results:
|
||||
yield Document(
|
||||
page_content=result.get("page_content", ""),
|
||||
metadata=result.get("metadata", {}),
|
||||
page_content=result.get('page_content', ''),
|
||||
metadata=result.get('metadata', {}),
|
||||
)
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.error(f"Error extracting content from batch {urls}: {e}")
|
||||
log.error(f'Error extracting content from batch {urls}: {e}')
|
||||
else:
|
||||
raise e
|
||||
|
||||
@@ -30,59 +30,59 @@ logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
known_source_ext = [
|
||||
"go",
|
||||
"py",
|
||||
"java",
|
||||
"sh",
|
||||
"bat",
|
||||
"ps1",
|
||||
"cmd",
|
||||
"js",
|
||||
"ts",
|
||||
"css",
|
||||
"cpp",
|
||||
"hpp",
|
||||
"h",
|
||||
"c",
|
||||
"cs",
|
||||
"sql",
|
||||
"log",
|
||||
"ini",
|
||||
"pl",
|
||||
"pm",
|
||||
"r",
|
||||
"dart",
|
||||
"dockerfile",
|
||||
"env",
|
||||
"php",
|
||||
"hs",
|
||||
"hsc",
|
||||
"lua",
|
||||
"nginxconf",
|
||||
"conf",
|
||||
"m",
|
||||
"mm",
|
||||
"plsql",
|
||||
"perl",
|
||||
"rb",
|
||||
"rs",
|
||||
"db2",
|
||||
"scala",
|
||||
"bash",
|
||||
"swift",
|
||||
"vue",
|
||||
"svelte",
|
||||
"ex",
|
||||
"exs",
|
||||
"erl",
|
||||
"tsx",
|
||||
"jsx",
|
||||
"hs",
|
||||
"lhs",
|
||||
"json",
|
||||
"yaml",
|
||||
"yml",
|
||||
"toml",
|
||||
'go',
|
||||
'py',
|
||||
'java',
|
||||
'sh',
|
||||
'bat',
|
||||
'ps1',
|
||||
'cmd',
|
||||
'js',
|
||||
'ts',
|
||||
'css',
|
||||
'cpp',
|
||||
'hpp',
|
||||
'h',
|
||||
'c',
|
||||
'cs',
|
||||
'sql',
|
||||
'log',
|
||||
'ini',
|
||||
'pl',
|
||||
'pm',
|
||||
'r',
|
||||
'dart',
|
||||
'dockerfile',
|
||||
'env',
|
||||
'php',
|
||||
'hs',
|
||||
'hsc',
|
||||
'lua',
|
||||
'nginxconf',
|
||||
'conf',
|
||||
'm',
|
||||
'mm',
|
||||
'plsql',
|
||||
'perl',
|
||||
'rb',
|
||||
'rs',
|
||||
'db2',
|
||||
'scala',
|
||||
'bash',
|
||||
'swift',
|
||||
'vue',
|
||||
'svelte',
|
||||
'ex',
|
||||
'exs',
|
||||
'erl',
|
||||
'tsx',
|
||||
'jsx',
|
||||
'hs',
|
||||
'lhs',
|
||||
'json',
|
||||
'yaml',
|
||||
'yml',
|
||||
'toml',
|
||||
]
|
||||
|
||||
|
||||
@@ -99,11 +99,11 @@ class ExcelLoader:
|
||||
xls = pd.ExcelFile(self.file_path)
|
||||
for sheet_name in xls.sheet_names:
|
||||
df = pd.read_excel(xls, sheet_name=sheet_name)
|
||||
text_parts.append(f"Sheet: {sheet_name}\n{df.to_string(index=False)}")
|
||||
text_parts.append(f'Sheet: {sheet_name}\n{df.to_string(index=False)}')
|
||||
return [
|
||||
Document(
|
||||
page_content="\n\n".join(text_parts),
|
||||
metadata={"source": self.file_path},
|
||||
page_content='\n\n'.join(text_parts),
|
||||
metadata={'source': self.file_path},
|
||||
)
|
||||
]
|
||||
|
||||
@@ -125,11 +125,11 @@ class PptxLoader:
|
||||
if shape.has_text_frame:
|
||||
slide_texts.append(shape.text_frame.text)
|
||||
if slide_texts:
|
||||
text_parts.append(f"Slide {i}:\n" + "\n".join(slide_texts))
|
||||
text_parts.append(f'Slide {i}:\n' + '\n'.join(slide_texts))
|
||||
return [
|
||||
Document(
|
||||
page_content="\n\n".join(text_parts),
|
||||
metadata={"source": self.file_path},
|
||||
page_content='\n\n'.join(text_parts),
|
||||
metadata={'source': self.file_path},
|
||||
)
|
||||
]
|
||||
|
||||
@@ -143,41 +143,41 @@ class TikaLoader:
|
||||
self.extract_images = extract_images
|
||||
|
||||
def load(self) -> list[Document]:
|
||||
with open(self.file_path, "rb") as f:
|
||||
with open(self.file_path, 'rb') as f:
|
||||
data = f.read()
|
||||
|
||||
if self.mime_type is not None:
|
||||
headers = {"Content-Type": self.mime_type}
|
||||
headers = {'Content-Type': self.mime_type}
|
||||
else:
|
||||
headers = {}
|
||||
|
||||
if self.extract_images == True:
|
||||
headers["X-Tika-PDFextractInlineImages"] = "true"
|
||||
headers['X-Tika-PDFextractInlineImages'] = 'true'
|
||||
|
||||
endpoint = self.url
|
||||
if not endpoint.endswith("/"):
|
||||
endpoint += "/"
|
||||
endpoint += "tika/text"
|
||||
if not endpoint.endswith('/'):
|
||||
endpoint += '/'
|
||||
endpoint += 'tika/text'
|
||||
|
||||
r = requests.put(endpoint, data=data, headers=headers, verify=REQUESTS_VERIFY)
|
||||
|
||||
if r.ok:
|
||||
raw_metadata = r.json()
|
||||
text = raw_metadata.get("X-TIKA:content", "<No text content found>").strip()
|
||||
text = raw_metadata.get('X-TIKA:content', '<No text content found>').strip()
|
||||
|
||||
if "Content-Type" in raw_metadata:
|
||||
headers["Content-Type"] = raw_metadata["Content-Type"]
|
||||
if 'Content-Type' in raw_metadata:
|
||||
headers['Content-Type'] = raw_metadata['Content-Type']
|
||||
|
||||
log.debug("Tika extracted text: %s", text)
|
||||
log.debug('Tika extracted text: %s', text)
|
||||
|
||||
return [Document(page_content=text, metadata=headers)]
|
||||
else:
|
||||
raise Exception(f"Error calling Tika: {r.reason}")
|
||||
raise Exception(f'Error calling Tika: {r.reason}')
|
||||
|
||||
|
||||
class DoclingLoader:
|
||||
def __init__(self, url, api_key=None, file_path=None, mime_type=None, params=None):
|
||||
self.url = url.rstrip("/")
|
||||
self.url = url.rstrip('/')
|
||||
self.api_key = api_key
|
||||
self.file_path = file_path
|
||||
self.mime_type = mime_type
|
||||
@@ -185,199 +185,183 @@ class DoclingLoader:
|
||||
self.params = params or {}
|
||||
|
||||
def load(self) -> list[Document]:
|
||||
with open(self.file_path, "rb") as f:
|
||||
with open(self.file_path, 'rb') as f:
|
||||
headers = {}
|
||||
if self.api_key:
|
||||
headers["X-Api-Key"] = f"{self.api_key}"
|
||||
headers['X-Api-Key'] = f'{self.api_key}'
|
||||
|
||||
r = requests.post(
|
||||
f"{self.url}/v1/convert/file",
|
||||
f'{self.url}/v1/convert/file',
|
||||
files={
|
||||
"files": (
|
||||
'files': (
|
||||
self.file_path,
|
||||
f,
|
||||
self.mime_type or "application/octet-stream",
|
||||
self.mime_type or 'application/octet-stream',
|
||||
)
|
||||
},
|
||||
data={
|
||||
"image_export_mode": "placeholder",
|
||||
'image_export_mode': 'placeholder',
|
||||
**self.params,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
if r.ok:
|
||||
result = r.json()
|
||||
document_data = result.get("document", {})
|
||||
text = document_data.get("md_content", "<No text content found>")
|
||||
document_data = result.get('document', {})
|
||||
text = document_data.get('md_content', '<No text content found>')
|
||||
|
||||
metadata = {"Content-Type": self.mime_type} if self.mime_type else {}
|
||||
metadata = {'Content-Type': self.mime_type} if self.mime_type else {}
|
||||
|
||||
log.debug("Docling extracted text: %s", text)
|
||||
log.debug('Docling extracted text: %s', text)
|
||||
return [Document(page_content=text, metadata=metadata)]
|
||||
else:
|
||||
error_msg = f"Error calling Docling API: {r.reason}"
|
||||
error_msg = f'Error calling Docling API: {r.reason}'
|
||||
if r.text:
|
||||
try:
|
||||
error_data = r.json()
|
||||
if "detail" in error_data:
|
||||
error_msg += f" - {error_data['detail']}"
|
||||
if 'detail' in error_data:
|
||||
error_msg += f' - {error_data["detail"]}'
|
||||
except Exception:
|
||||
error_msg += f" - {r.text}"
|
||||
raise Exception(f"Error calling Docling: {error_msg}")
|
||||
error_msg += f' - {r.text}'
|
||||
raise Exception(f'Error calling Docling: {error_msg}')
|
||||
|
||||
|
||||
class Loader:
|
||||
def __init__(self, engine: str = "", **kwargs):
|
||||
def __init__(self, engine: str = '', **kwargs):
|
||||
self.engine = engine
|
||||
self.user = kwargs.get("user", None)
|
||||
self.user = kwargs.get('user', None)
|
||||
self.kwargs = kwargs
|
||||
|
||||
def load(
|
||||
self, filename: str, file_content_type: str, file_path: str
|
||||
) -> list[Document]:
|
||||
def load(self, filename: str, file_content_type: str, file_path: str) -> list[Document]:
|
||||
loader = self._get_loader(filename, file_content_type, file_path)
|
||||
docs = loader.load()
|
||||
|
||||
return [
|
||||
Document(
|
||||
page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata
|
||||
)
|
||||
for doc in docs
|
||||
]
|
||||
return [Document(page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata) for doc in docs]
|
||||
|
||||
def _is_text_file(self, file_ext: str, file_content_type: str) -> bool:
|
||||
return file_ext in known_source_ext or (
|
||||
file_content_type
|
||||
and file_content_type.find("text/") >= 0
|
||||
and file_content_type.find('text/') >= 0
|
||||
# Avoid text/html files being detected as text
|
||||
and not file_content_type.find("html") >= 0
|
||||
and not file_content_type.find('html') >= 0
|
||||
)
|
||||
|
||||
def _get_loader(self, filename: str, file_content_type: str, file_path: str):
|
||||
file_ext = filename.split(".")[-1].lower()
|
||||
file_ext = filename.split('.')[-1].lower()
|
||||
|
||||
if (
|
||||
self.engine == "external"
|
||||
and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL")
|
||||
and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY")
|
||||
self.engine == 'external'
|
||||
and self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_URL')
|
||||
and self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_API_KEY')
|
||||
):
|
||||
loader = ExternalDocumentLoader(
|
||||
file_path=file_path,
|
||||
url=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL"),
|
||||
api_key=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY"),
|
||||
url=self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_URL'),
|
||||
api_key=self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_API_KEY'),
|
||||
mime_type=file_content_type,
|
||||
user=self.user,
|
||||
)
|
||||
elif self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
|
||||
elif self.engine == 'tika' and self.kwargs.get('TIKA_SERVER_URL'):
|
||||
if self._is_text_file(file_ext, file_content_type):
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
else:
|
||||
loader = TikaLoader(
|
||||
url=self.kwargs.get("TIKA_SERVER_URL"),
|
||||
url=self.kwargs.get('TIKA_SERVER_URL'),
|
||||
file_path=file_path,
|
||||
extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"),
|
||||
extract_images=self.kwargs.get('PDF_EXTRACT_IMAGES'),
|
||||
)
|
||||
elif (
|
||||
self.engine == "datalab_marker"
|
||||
and self.kwargs.get("DATALAB_MARKER_API_KEY")
|
||||
self.engine == 'datalab_marker'
|
||||
and self.kwargs.get('DATALAB_MARKER_API_KEY')
|
||||
and file_ext
|
||||
in [
|
||||
"pdf",
|
||||
"xls",
|
||||
"xlsx",
|
||||
"ods",
|
||||
"doc",
|
||||
"docx",
|
||||
"odt",
|
||||
"ppt",
|
||||
"pptx",
|
||||
"odp",
|
||||
"html",
|
||||
"epub",
|
||||
"png",
|
||||
"jpeg",
|
||||
"jpg",
|
||||
"webp",
|
||||
"gif",
|
||||
"tiff",
|
||||
'pdf',
|
||||
'xls',
|
||||
'xlsx',
|
||||
'ods',
|
||||
'doc',
|
||||
'docx',
|
||||
'odt',
|
||||
'ppt',
|
||||
'pptx',
|
||||
'odp',
|
||||
'html',
|
||||
'epub',
|
||||
'png',
|
||||
'jpeg',
|
||||
'jpg',
|
||||
'webp',
|
||||
'gif',
|
||||
'tiff',
|
||||
]
|
||||
):
|
||||
api_base_url = self.kwargs.get("DATALAB_MARKER_API_BASE_URL", "")
|
||||
if not api_base_url or api_base_url.strip() == "":
|
||||
api_base_url = "https://www.datalab.to/api/v1/marker" # https://github.com/open-webui/open-webui/pull/16867#issuecomment-3218424349
|
||||
api_base_url = self.kwargs.get('DATALAB_MARKER_API_BASE_URL', '')
|
||||
if not api_base_url or api_base_url.strip() == '':
|
||||
api_base_url = 'https://www.datalab.to/api/v1/marker' # https://github.com/open-webui/open-webui/pull/16867#issuecomment-3218424349
|
||||
|
||||
loader = DatalabMarkerLoader(
|
||||
file_path=file_path,
|
||||
api_key=self.kwargs["DATALAB_MARKER_API_KEY"],
|
||||
api_key=self.kwargs['DATALAB_MARKER_API_KEY'],
|
||||
api_base_url=api_base_url,
|
||||
additional_config=self.kwargs.get("DATALAB_MARKER_ADDITIONAL_CONFIG"),
|
||||
use_llm=self.kwargs.get("DATALAB_MARKER_USE_LLM", False),
|
||||
skip_cache=self.kwargs.get("DATALAB_MARKER_SKIP_CACHE", False),
|
||||
force_ocr=self.kwargs.get("DATALAB_MARKER_FORCE_OCR", False),
|
||||
paginate=self.kwargs.get("DATALAB_MARKER_PAGINATE", False),
|
||||
strip_existing_ocr=self.kwargs.get(
|
||||
"DATALAB_MARKER_STRIP_EXISTING_OCR", False
|
||||
),
|
||||
disable_image_extraction=self.kwargs.get(
|
||||
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", False
|
||||
),
|
||||
format_lines=self.kwargs.get("DATALAB_MARKER_FORMAT_LINES", False),
|
||||
output_format=self.kwargs.get(
|
||||
"DATALAB_MARKER_OUTPUT_FORMAT", "markdown"
|
||||
),
|
||||
additional_config=self.kwargs.get('DATALAB_MARKER_ADDITIONAL_CONFIG'),
|
||||
use_llm=self.kwargs.get('DATALAB_MARKER_USE_LLM', False),
|
||||
skip_cache=self.kwargs.get('DATALAB_MARKER_SKIP_CACHE', False),
|
||||
force_ocr=self.kwargs.get('DATALAB_MARKER_FORCE_OCR', False),
|
||||
paginate=self.kwargs.get('DATALAB_MARKER_PAGINATE', False),
|
||||
strip_existing_ocr=self.kwargs.get('DATALAB_MARKER_STRIP_EXISTING_OCR', False),
|
||||
disable_image_extraction=self.kwargs.get('DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION', False),
|
||||
format_lines=self.kwargs.get('DATALAB_MARKER_FORMAT_LINES', False),
|
||||
output_format=self.kwargs.get('DATALAB_MARKER_OUTPUT_FORMAT', 'markdown'),
|
||||
)
|
||||
elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"):
|
||||
elif self.engine == 'docling' and self.kwargs.get('DOCLING_SERVER_URL'):
|
||||
if self._is_text_file(file_ext, file_content_type):
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
else:
|
||||
# Build params for DoclingLoader
|
||||
params = self.kwargs.get("DOCLING_PARAMS", {})
|
||||
params = self.kwargs.get('DOCLING_PARAMS', {})
|
||||
if not isinstance(params, dict):
|
||||
try:
|
||||
params = json.loads(params)
|
||||
except json.JSONDecodeError:
|
||||
log.error("Invalid DOCLING_PARAMS format, expected JSON object")
|
||||
log.error('Invalid DOCLING_PARAMS format, expected JSON object')
|
||||
params = {}
|
||||
|
||||
loader = DoclingLoader(
|
||||
url=self.kwargs.get("DOCLING_SERVER_URL"),
|
||||
api_key=self.kwargs.get("DOCLING_API_KEY", None),
|
||||
url=self.kwargs.get('DOCLING_SERVER_URL'),
|
||||
api_key=self.kwargs.get('DOCLING_API_KEY', None),
|
||||
file_path=file_path,
|
||||
mime_type=file_content_type,
|
||||
params=params,
|
||||
)
|
||||
elif (
|
||||
self.engine == "document_intelligence"
|
||||
and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
|
||||
self.engine == 'document_intelligence'
|
||||
and self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT') != ''
|
||||
and (
|
||||
file_ext in ["pdf", "docx", "ppt", "pptx"]
|
||||
file_ext in ['pdf', 'docx', 'ppt', 'pptx']
|
||||
or file_content_type
|
||||
in [
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"application/vnd.ms-powerpoint",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'application/vnd.ms-powerpoint',
|
||||
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||
]
|
||||
)
|
||||
):
|
||||
if self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != "":
|
||||
if self.kwargs.get('DOCUMENT_INTELLIGENCE_KEY') != '':
|
||||
loader = AzureAIDocumentIntelligenceLoader(
|
||||
file_path=file_path,
|
||||
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
|
||||
api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
|
||||
api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"),
|
||||
api_endpoint=self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT'),
|
||||
api_key=self.kwargs.get('DOCUMENT_INTELLIGENCE_KEY'),
|
||||
api_model=self.kwargs.get('DOCUMENT_INTELLIGENCE_MODEL'),
|
||||
)
|
||||
else:
|
||||
loader = AzureAIDocumentIntelligenceLoader(
|
||||
file_path=file_path,
|
||||
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
|
||||
api_endpoint=self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT'),
|
||||
azure_credential=DefaultAzureCredential(),
|
||||
api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"),
|
||||
api_model=self.kwargs.get('DOCUMENT_INTELLIGENCE_MODEL'),
|
||||
)
|
||||
elif self.engine == "mineru" and file_ext in [
|
||||
"pdf"
|
||||
]: # MinerU currently only supports PDF
|
||||
|
||||
mineru_timeout = self.kwargs.get("MINERU_API_TIMEOUT", 300)
|
||||
elif self.engine == 'mineru' and file_ext in ['pdf']: # MinerU currently only supports PDF
|
||||
mineru_timeout = self.kwargs.get('MINERU_API_TIMEOUT', 300)
|
||||
if mineru_timeout:
|
||||
try:
|
||||
mineru_timeout = int(mineru_timeout)
|
||||
@@ -386,111 +370,115 @@ class Loader:
|
||||
|
||||
loader = MinerULoader(
|
||||
file_path=file_path,
|
||||
api_mode=self.kwargs.get("MINERU_API_MODE", "local"),
|
||||
api_url=self.kwargs.get("MINERU_API_URL", "http://localhost:8000"),
|
||||
api_key=self.kwargs.get("MINERU_API_KEY", ""),
|
||||
params=self.kwargs.get("MINERU_PARAMS", {}),
|
||||
api_mode=self.kwargs.get('MINERU_API_MODE', 'local'),
|
||||
api_url=self.kwargs.get('MINERU_API_URL', 'http://localhost:8000'),
|
||||
api_key=self.kwargs.get('MINERU_API_KEY', ''),
|
||||
params=self.kwargs.get('MINERU_PARAMS', {}),
|
||||
timeout=mineru_timeout,
|
||||
)
|
||||
elif (
|
||||
self.engine == "mistral_ocr"
|
||||
and self.kwargs.get("MISTRAL_OCR_API_KEY") != ""
|
||||
and file_ext
|
||||
in ["pdf"] # Mistral OCR currently only supports PDF and images
|
||||
self.engine == 'mistral_ocr'
|
||||
and self.kwargs.get('MISTRAL_OCR_API_KEY') != ''
|
||||
and file_ext in ['pdf'] # Mistral OCR currently only supports PDF and images
|
||||
):
|
||||
loader = MistralLoader(
|
||||
base_url=self.kwargs.get("MISTRAL_OCR_API_BASE_URL"),
|
||||
api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"),
|
||||
base_url=self.kwargs.get('MISTRAL_OCR_API_BASE_URL'),
|
||||
api_key=self.kwargs.get('MISTRAL_OCR_API_KEY'),
|
||||
file_path=file_path,
|
||||
)
|
||||
else:
|
||||
if file_ext == "pdf":
|
||||
if file_ext == 'pdf':
|
||||
loader = PyPDFLoader(
|
||||
file_path,
|
||||
extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"),
|
||||
mode=self.kwargs.get("PDF_LOADER_MODE", "page"),
|
||||
extract_images=self.kwargs.get('PDF_EXTRACT_IMAGES'),
|
||||
mode=self.kwargs.get('PDF_LOADER_MODE', 'page'),
|
||||
)
|
||||
elif file_ext == "csv":
|
||||
elif file_ext == 'csv':
|
||||
loader = CSVLoader(file_path, autodetect_encoding=True)
|
||||
elif file_ext == "rst":
|
||||
elif file_ext == 'rst':
|
||||
try:
|
||||
from langchain_community.document_loaders import UnstructuredRSTLoader
|
||||
loader = UnstructuredRSTLoader(file_path, mode="elements")
|
||||
|
||||
loader = UnstructuredRSTLoader(file_path, mode='elements')
|
||||
except ImportError:
|
||||
log.warning(
|
||||
"The 'unstructured' package is not installed. "
|
||||
"Falling back to plain text loading for .rst file. "
|
||||
"Install it with: pip install unstructured"
|
||||
'Falling back to plain text loading for .rst file. '
|
||||
'Install it with: pip install unstructured'
|
||||
)
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
elif file_ext == "xml":
|
||||
elif file_ext == 'xml':
|
||||
try:
|
||||
from langchain_community.document_loaders import UnstructuredXMLLoader
|
||||
|
||||
loader = UnstructuredXMLLoader(file_path)
|
||||
except ImportError:
|
||||
log.warning(
|
||||
"The 'unstructured' package is not installed. "
|
||||
"Falling back to plain text loading for .xml file. "
|
||||
"Install it with: pip install unstructured"
|
||||
'Falling back to plain text loading for .xml file. '
|
||||
'Install it with: pip install unstructured'
|
||||
)
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
elif file_ext in ["htm", "html"]:
|
||||
loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
|
||||
elif file_ext == "md":
|
||||
elif file_ext in ['htm', 'html']:
|
||||
loader = BSHTMLLoader(file_path, open_encoding='unicode_escape')
|
||||
elif file_ext == 'md':
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
elif file_content_type == "application/epub+zip":
|
||||
elif file_content_type == 'application/epub+zip':
|
||||
try:
|
||||
from langchain_community.document_loaders import UnstructuredEPubLoader
|
||||
|
||||
loader = UnstructuredEPubLoader(file_path)
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Processing .epub files requires the 'unstructured' package. "
|
||||
"Install it with: pip install unstructured"
|
||||
'Install it with: pip install unstructured'
|
||||
)
|
||||
elif (
|
||||
file_content_type
|
||||
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
or file_ext == "docx"
|
||||
file_content_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
|
||||
or file_ext == 'docx'
|
||||
):
|
||||
loader = Docx2txtLoader(file_path)
|
||||
elif file_content_type in [
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
] or file_ext in ["xls", "xlsx"]:
|
||||
'application/vnd.ms-excel',
|
||||
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||
] or file_ext in ['xls', 'xlsx']:
|
||||
try:
|
||||
from langchain_community.document_loaders import UnstructuredExcelLoader
|
||||
|
||||
loader = UnstructuredExcelLoader(file_path)
|
||||
except ImportError:
|
||||
log.warning(
|
||||
"The 'unstructured' package is not installed. "
|
||||
"Falling back to pandas for Excel file loading. "
|
||||
"Install unstructured for better results: pip install unstructured"
|
||||
'Falling back to pandas for Excel file loading. '
|
||||
'Install unstructured for better results: pip install unstructured'
|
||||
)
|
||||
loader = ExcelLoader(file_path)
|
||||
elif file_content_type in [
|
||||
"application/vnd.ms-powerpoint",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
] or file_ext in ["ppt", "pptx"]:
|
||||
'application/vnd.ms-powerpoint',
|
||||
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||
] or file_ext in ['ppt', 'pptx']:
|
||||
try:
|
||||
from langchain_community.document_loaders import UnstructuredPowerPointLoader
|
||||
|
||||
loader = UnstructuredPowerPointLoader(file_path)
|
||||
except ImportError:
|
||||
log.warning(
|
||||
"The 'unstructured' package is not installed. "
|
||||
"Falling back to python-pptx for PowerPoint file loading. "
|
||||
"Install unstructured for better results: pip install unstructured"
|
||||
'Falling back to python-pptx for PowerPoint file loading. '
|
||||
'Install unstructured for better results: pip install unstructured'
|
||||
)
|
||||
loader = PptxLoader(file_path)
|
||||
elif file_ext == "msg":
|
||||
elif file_ext == 'msg':
|
||||
loader = OutlookMessageLoader(file_path)
|
||||
elif file_ext == "odt":
|
||||
elif file_ext == 'odt':
|
||||
try:
|
||||
from langchain_community.document_loaders import UnstructuredODTLoader
|
||||
|
||||
loader = UnstructuredODTLoader(file_path)
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Processing .odt files requires the 'unstructured' package. "
|
||||
"Install it with: pip install unstructured"
|
||||
'Install it with: pip install unstructured'
|
||||
)
|
||||
elif self._is_text_file(file_ext, file_content_type):
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
@@ -498,4 +486,3 @@ class Loader:
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
|
||||
return loader
|
||||
|
||||
|
||||
@@ -22,37 +22,35 @@ class MinerULoader:
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_mode: str = "local",
|
||||
api_url: str = "http://localhost:8000",
|
||||
api_key: str = "",
|
||||
api_mode: str = 'local',
|
||||
api_url: str = 'http://localhost:8000',
|
||||
api_key: str = '',
|
||||
params: dict = None,
|
||||
timeout: Optional[int] = 300,
|
||||
):
|
||||
self.file_path = file_path
|
||||
self.api_mode = api_mode.lower()
|
||||
self.api_url = api_url.rstrip("/")
|
||||
self.api_url = api_url.rstrip('/')
|
||||
self.api_key = api_key
|
||||
self.timeout = timeout
|
||||
|
||||
# Parse params dict with defaults
|
||||
self.params = params or {}
|
||||
self.enable_ocr = params.get("enable_ocr", False)
|
||||
self.enable_formula = params.get("enable_formula", True)
|
||||
self.enable_table = params.get("enable_table", True)
|
||||
self.language = params.get("language", "en")
|
||||
self.model_version = params.get("model_version", "pipeline")
|
||||
self.enable_ocr = params.get('enable_ocr', False)
|
||||
self.enable_formula = params.get('enable_formula', True)
|
||||
self.enable_table = params.get('enable_table', True)
|
||||
self.language = params.get('language', 'en')
|
||||
self.model_version = params.get('model_version', 'pipeline')
|
||||
|
||||
self.page_ranges = self.params.pop("page_ranges", "")
|
||||
self.page_ranges = self.params.pop('page_ranges', '')
|
||||
|
||||
# Validate API mode
|
||||
if self.api_mode not in ["local", "cloud"]:
|
||||
raise ValueError(
|
||||
f"Invalid API mode: {self.api_mode}. Must be 'local' or 'cloud'"
|
||||
)
|
||||
if self.api_mode not in ['local', 'cloud']:
|
||||
raise ValueError(f"Invalid API mode: {self.api_mode}. Must be 'local' or 'cloud'")
|
||||
|
||||
# Validate Cloud API requirements
|
||||
if self.api_mode == "cloud" and not self.api_key:
|
||||
raise ValueError("API key is required for Cloud API mode")
|
||||
if self.api_mode == 'cloud' and not self.api_key:
|
||||
raise ValueError('API key is required for Cloud API mode')
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""
|
||||
@@ -60,12 +58,12 @@ class MinerULoader:
|
||||
Routes to Cloud or Local API based on api_mode.
|
||||
"""
|
||||
try:
|
||||
if self.api_mode == "cloud":
|
||||
if self.api_mode == 'cloud':
|
||||
return self._load_cloud_api()
|
||||
else:
|
||||
return self._load_local_api()
|
||||
except Exception as e:
|
||||
log.error(f"Error loading document with MinerU: {e}")
|
||||
log.error(f'Error loading document with MinerU: {e}')
|
||||
raise
|
||||
|
||||
def _load_local_api(self) -> List[Document]:
|
||||
@@ -73,14 +71,14 @@ class MinerULoader:
|
||||
Load document using Local API (synchronous).
|
||||
Posts file to /file_parse endpoint and gets immediate response.
|
||||
"""
|
||||
log.info(f"Using MinerU Local API at {self.api_url}")
|
||||
log.info(f'Using MinerU Local API at {self.api_url}')
|
||||
|
||||
filename = os.path.basename(self.file_path)
|
||||
|
||||
# Build form data for Local API
|
||||
form_data = {
|
||||
**self.params,
|
||||
"return_md": "true",
|
||||
'return_md': 'true',
|
||||
}
|
||||
|
||||
# Page ranges (Local API uses start_page_id and end_page_id)
|
||||
@@ -89,18 +87,18 @@ class MinerULoader:
|
||||
# Full page range parsing would require parsing the string
|
||||
log.warning(
|
||||
f"Page ranges '{self.page_ranges}' specified but Local API uses different format. "
|
||||
"Consider using start_page_id/end_page_id parameters if needed."
|
||||
'Consider using start_page_id/end_page_id parameters if needed.'
|
||||
)
|
||||
|
||||
try:
|
||||
with open(self.file_path, "rb") as f:
|
||||
files = {"files": (filename, f, "application/octet-stream")}
|
||||
with open(self.file_path, 'rb') as f:
|
||||
files = {'files': (filename, f, 'application/octet-stream')}
|
||||
|
||||
log.info(f"Sending file to MinerU Local API: {filename}")
|
||||
log.debug(f"Local API parameters: {form_data}")
|
||||
log.info(f'Sending file to MinerU Local API: {filename}')
|
||||
log.debug(f'Local API parameters: {form_data}')
|
||||
|
||||
response = requests.post(
|
||||
f"{self.api_url}/file_parse",
|
||||
f'{self.api_url}/file_parse',
|
||||
data=form_data,
|
||||
files=files,
|
||||
timeout=self.timeout,
|
||||
@@ -108,27 +106,25 @@ class MinerULoader:
|
||||
response.raise_for_status()
|
||||
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(
|
||||
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
|
||||
)
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}')
|
||||
except requests.Timeout:
|
||||
raise HTTPException(
|
||||
status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
detail="MinerU Local API request timed out",
|
||||
detail='MinerU Local API request timed out',
|
||||
)
|
||||
except requests.HTTPError as e:
|
||||
error_detail = f"MinerU Local API request failed: {e}"
|
||||
error_detail = f'MinerU Local API request failed: {e}'
|
||||
if e.response is not None:
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
error_detail += f" - {error_data}"
|
||||
error_detail += f' - {error_data}'
|
||||
except Exception:
|
||||
error_detail += f" - {e.response.text}"
|
||||
error_detail += f' - {e.response.text}'
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error calling MinerU Local API: {str(e)}",
|
||||
detail=f'Error calling MinerU Local API: {str(e)}',
|
||||
)
|
||||
|
||||
# Parse response
|
||||
@@ -137,41 +133,41 @@ class MinerULoader:
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Invalid JSON response from MinerU Local API: {e}",
|
||||
detail=f'Invalid JSON response from MinerU Local API: {e}',
|
||||
)
|
||||
|
||||
# Extract markdown content from response
|
||||
if "results" not in result:
|
||||
if 'results' not in result:
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail="MinerU Local API response missing 'results' field",
|
||||
)
|
||||
|
||||
results = result["results"]
|
||||
results = result['results']
|
||||
if not results:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail="MinerU returned empty results",
|
||||
detail='MinerU returned empty results',
|
||||
)
|
||||
|
||||
# Get the first (and typically only) result
|
||||
file_result = list(results.values())[0]
|
||||
markdown_content = file_result.get("md_content", "")
|
||||
markdown_content = file_result.get('md_content', '')
|
||||
|
||||
if not markdown_content:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail="MinerU returned empty markdown content",
|
||||
detail='MinerU returned empty markdown content',
|
||||
)
|
||||
|
||||
log.info(f"Successfully parsed document with MinerU Local API: {filename}")
|
||||
log.info(f'Successfully parsed document with MinerU Local API: {filename}')
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"source": filename,
|
||||
"api_mode": "local",
|
||||
"backend": result.get("backend", "unknown"),
|
||||
"version": result.get("version", "unknown"),
|
||||
'source': filename,
|
||||
'api_mode': 'local',
|
||||
'backend': result.get('backend', 'unknown'),
|
||||
'version': result.get('version', 'unknown'),
|
||||
}
|
||||
|
||||
return [Document(page_content=markdown_content, metadata=metadata)]
|
||||
@@ -181,7 +177,7 @@ class MinerULoader:
|
||||
Load document using Cloud API (asynchronous).
|
||||
Uses batch upload endpoint to avoid need for public file URLs.
|
||||
"""
|
||||
log.info(f"Using MinerU Cloud API at {self.api_url}")
|
||||
log.info(f'Using MinerU Cloud API at {self.api_url}')
|
||||
|
||||
filename = os.path.basename(self.file_path)
|
||||
|
||||
@@ -195,17 +191,15 @@ class MinerULoader:
|
||||
result = self._poll_batch_status(batch_id, filename)
|
||||
|
||||
# Step 4: Download and extract markdown from ZIP
|
||||
markdown_content = self._download_and_extract_zip(
|
||||
result["full_zip_url"], filename
|
||||
)
|
||||
markdown_content = self._download_and_extract_zip(result['full_zip_url'], filename)
|
||||
|
||||
log.info(f"Successfully parsed document with MinerU Cloud API: {filename}")
|
||||
log.info(f'Successfully parsed document with MinerU Cloud API: {filename}')
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"source": filename,
|
||||
"api_mode": "cloud",
|
||||
"batch_id": batch_id,
|
||||
'source': filename,
|
||||
'api_mode': 'cloud',
|
||||
'batch_id': batch_id,
|
||||
}
|
||||
|
||||
return [Document(page_content=markdown_content, metadata=metadata)]
|
||||
@@ -216,49 +210,49 @@ class MinerULoader:
|
||||
Returns (batch_id, upload_url).
|
||||
"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
'Authorization': f'Bearer {self.api_key}',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
# Build request body
|
||||
request_body = {
|
||||
**self.params,
|
||||
"files": [
|
||||
'files': [
|
||||
{
|
||||
"name": filename,
|
||||
"is_ocr": self.enable_ocr,
|
||||
'name': filename,
|
||||
'is_ocr': self.enable_ocr,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# Add page ranges if specified
|
||||
if self.page_ranges:
|
||||
request_body["files"][0]["page_ranges"] = self.page_ranges
|
||||
request_body['files'][0]['page_ranges'] = self.page_ranges
|
||||
|
||||
log.info(f"Requesting upload URL for: {filename}")
|
||||
log.debug(f"Cloud API request body: {request_body}")
|
||||
log.info(f'Requesting upload URL for: {filename}')
|
||||
log.debug(f'Cloud API request body: {request_body}')
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.api_url}/file-urls/batch",
|
||||
f'{self.api_url}/file-urls/batch',
|
||||
headers=headers,
|
||||
json=request_body,
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
error_detail = f"Failed to request upload URL: {e}"
|
||||
error_detail = f'Failed to request upload URL: {e}'
|
||||
if e.response is not None:
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
error_detail += f" - {error_data.get('msg', error_data)}"
|
||||
error_detail += f' - {error_data.get("msg", error_data)}'
|
||||
except Exception:
|
||||
error_detail += f" - {e.response.text}"
|
||||
error_detail += f' - {e.response.text}'
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error requesting upload URL: {str(e)}",
|
||||
detail=f'Error requesting upload URL: {str(e)}',
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -266,28 +260,28 @@ class MinerULoader:
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Invalid JSON response: {e}",
|
||||
detail=f'Invalid JSON response: {e}',
|
||||
)
|
||||
|
||||
# Check for API error response
|
||||
if result.get("code") != 0:
|
||||
if result.get('code') != 0:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}",
|
||||
detail=f'MinerU Cloud API error: {result.get("msg", "Unknown error")}',
|
||||
)
|
||||
|
||||
data = result.get("data", {})
|
||||
batch_id = data.get("batch_id")
|
||||
file_urls = data.get("file_urls", [])
|
||||
data = result.get('data', {})
|
||||
batch_id = data.get('batch_id')
|
||||
file_urls = data.get('file_urls', [])
|
||||
|
||||
if not batch_id or not file_urls:
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail="MinerU Cloud API response missing batch_id or file_urls",
|
||||
detail='MinerU Cloud API response missing batch_id or file_urls',
|
||||
)
|
||||
|
||||
upload_url = file_urls[0]
|
||||
log.info(f"Received upload URL for batch: {batch_id}")
|
||||
log.info(f'Received upload URL for batch: {batch_id}')
|
||||
|
||||
return batch_id, upload_url
|
||||
|
||||
@@ -295,10 +289,10 @@ class MinerULoader:
|
||||
"""
|
||||
Upload file to presigned URL (no authentication needed).
|
||||
"""
|
||||
log.info(f"Uploading file to presigned URL")
|
||||
log.info(f'Uploading file to presigned URL')
|
||||
|
||||
try:
|
||||
with open(self.file_path, "rb") as f:
|
||||
with open(self.file_path, 'rb') as f:
|
||||
response = requests.put(
|
||||
upload_url,
|
||||
data=f,
|
||||
@@ -306,26 +300,24 @@ class MinerULoader:
|
||||
)
|
||||
response.raise_for_status()
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(
|
||||
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
|
||||
)
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}')
|
||||
except requests.Timeout:
|
||||
raise HTTPException(
|
||||
status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
detail="File upload to presigned URL timed out",
|
||||
detail='File upload to presigned URL timed out',
|
||||
)
|
||||
except requests.HTTPError as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Failed to upload file to presigned URL: {e}",
|
||||
detail=f'Failed to upload file to presigned URL: {e}',
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error uploading file: {str(e)}",
|
||||
detail=f'Error uploading file: {str(e)}',
|
||||
)
|
||||
|
||||
log.info("File uploaded successfully")
|
||||
log.info('File uploaded successfully')
|
||||
|
||||
def _poll_batch_status(self, batch_id: str, filename: str) -> dict:
|
||||
"""
|
||||
@@ -333,35 +325,35 @@ class MinerULoader:
|
||||
Returns the result dict for the file.
|
||||
"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
'Authorization': f'Bearer {self.api_key}',
|
||||
}
|
||||
|
||||
max_iterations = 300 # 10 minutes max (2 seconds per iteration)
|
||||
poll_interval = 2 # seconds
|
||||
|
||||
log.info(f"Polling batch status: {batch_id}")
|
||||
log.info(f'Polling batch status: {batch_id}')
|
||||
|
||||
for iteration in range(max_iterations):
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{self.api_url}/extract-results/batch/{batch_id}",
|
||||
f'{self.api_url}/extract-results/batch/{batch_id}',
|
||||
headers=headers,
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
error_detail = f"Failed to poll batch status: {e}"
|
||||
error_detail = f'Failed to poll batch status: {e}'
|
||||
if e.response is not None:
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
error_detail += f" - {error_data.get('msg', error_data)}"
|
||||
error_detail += f' - {error_data.get("msg", error_data)}'
|
||||
except Exception:
|
||||
error_detail += f" - {e.response.text}"
|
||||
error_detail += f' - {e.response.text}'
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error polling batch status: {str(e)}",
|
||||
detail=f'Error polling batch status: {str(e)}',
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -369,58 +361,56 @@ class MinerULoader:
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Invalid JSON response while polling: {e}",
|
||||
detail=f'Invalid JSON response while polling: {e}',
|
||||
)
|
||||
|
||||
# Check for API error response
|
||||
if result.get("code") != 0:
|
||||
if result.get('code') != 0:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}",
|
||||
detail=f'MinerU Cloud API error: {result.get("msg", "Unknown error")}',
|
||||
)
|
||||
|
||||
data = result.get("data", {})
|
||||
extract_result = data.get("extract_result", [])
|
||||
data = result.get('data', {})
|
||||
extract_result = data.get('extract_result', [])
|
||||
|
||||
# Find our file in the batch results
|
||||
file_result = None
|
||||
for item in extract_result:
|
||||
if item.get("file_name") == filename:
|
||||
if item.get('file_name') == filename:
|
||||
file_result = item
|
||||
break
|
||||
|
||||
if not file_result:
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"File {filename} not found in batch results",
|
||||
detail=f'File {filename} not found in batch results',
|
||||
)
|
||||
|
||||
state = file_result.get("state")
|
||||
state = file_result.get('state')
|
||||
|
||||
if state == "done":
|
||||
log.info(f"Processing complete for {filename}")
|
||||
if state == 'done':
|
||||
log.info(f'Processing complete for {filename}')
|
||||
return file_result
|
||||
elif state == "failed":
|
||||
error_msg = file_result.get("err_msg", "Unknown error")
|
||||
elif state == 'failed':
|
||||
error_msg = file_result.get('err_msg', 'Unknown error')
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"MinerU processing failed: {error_msg}",
|
||||
detail=f'MinerU processing failed: {error_msg}',
|
||||
)
|
||||
elif state in ["waiting-file", "pending", "running", "converting"]:
|
||||
elif state in ['waiting-file', 'pending', 'running', 'converting']:
|
||||
# Still processing
|
||||
if iteration % 10 == 0: # Log every 20 seconds
|
||||
log.info(
|
||||
f"Processing status: {state} (iteration {iteration + 1}/{max_iterations})"
|
||||
)
|
||||
log.info(f'Processing status: {state} (iteration {iteration + 1}/{max_iterations})')
|
||||
time.sleep(poll_interval)
|
||||
else:
|
||||
log.warning(f"Unknown state: {state}")
|
||||
log.warning(f'Unknown state: {state}')
|
||||
time.sleep(poll_interval)
|
||||
|
||||
# Timeout
|
||||
raise HTTPException(
|
||||
status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
detail="MinerU processing timed out after 10 minutes",
|
||||
detail='MinerU processing timed out after 10 minutes',
|
||||
)
|
||||
|
||||
def _download_and_extract_zip(self, zip_url: str, filename: str) -> str:
|
||||
@@ -428,7 +418,7 @@ class MinerULoader:
|
||||
Download ZIP file from CDN and extract markdown content.
|
||||
Returns the markdown content as a string.
|
||||
"""
|
||||
log.info(f"Downloading results from: {zip_url}")
|
||||
log.info(f'Downloading results from: {zip_url}')
|
||||
|
||||
try:
|
||||
response = requests.get(zip_url, timeout=60)
|
||||
@@ -436,23 +426,23 @@ class MinerULoader:
|
||||
except requests.HTTPError as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Failed to download results ZIP: {e}",
|
||||
detail=f'Failed to download results ZIP: {e}',
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error downloading results: {str(e)}",
|
||||
detail=f'Error downloading results: {str(e)}',
|
||||
)
|
||||
|
||||
# Save ZIP to temporary file and extract
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as tmp_zip:
|
||||
tmp_zip.write(response.content)
|
||||
tmp_zip_path = tmp_zip.name
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Extract ZIP
|
||||
with zipfile.ZipFile(tmp_zip_path, "r") as zip_ref:
|
||||
with zipfile.ZipFile(tmp_zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(tmp_dir)
|
||||
|
||||
# Find markdown file - search recursively for any .md file
|
||||
@@ -466,33 +456,27 @@ class MinerULoader:
|
||||
full_path = os.path.join(root, file)
|
||||
all_files.append(full_path)
|
||||
# Look for any .md file
|
||||
if file.endswith(".md"):
|
||||
if file.endswith('.md'):
|
||||
found_md_path = full_path
|
||||
log.info(f"Found markdown file at: {full_path}")
|
||||
log.info(f'Found markdown file at: {full_path}')
|
||||
try:
|
||||
with open(full_path, "r", encoding="utf-8") as f:
|
||||
with open(full_path, 'r', encoding='utf-8') as f:
|
||||
markdown_content = f.read()
|
||||
if (
|
||||
markdown_content
|
||||
): # Use the first non-empty markdown file
|
||||
if markdown_content: # Use the first non-empty markdown file
|
||||
break
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to read {full_path}: {e}")
|
||||
log.warning(f'Failed to read {full_path}: {e}')
|
||||
if markdown_content:
|
||||
break
|
||||
|
||||
if markdown_content is None:
|
||||
log.error(f"Available files in ZIP: {all_files}")
|
||||
log.error(f'Available files in ZIP: {all_files}')
|
||||
# Try to provide more helpful error message
|
||||
md_files = [f for f in all_files if f.endswith(".md")]
|
||||
md_files = [f for f in all_files if f.endswith('.md')]
|
||||
if md_files:
|
||||
error_msg = (
|
||||
f"Found .md files but couldn't read them: {md_files}"
|
||||
)
|
||||
error_msg = f"Found .md files but couldn't read them: {md_files}"
|
||||
else:
|
||||
error_msg = (
|
||||
f"No .md files found in ZIP. Available files: {all_files}"
|
||||
)
|
||||
error_msg = f'No .md files found in ZIP. Available files: {all_files}'
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail=error_msg,
|
||||
@@ -504,21 +488,19 @@ class MinerULoader:
|
||||
except zipfile.BadZipFile as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Invalid ZIP file received: {e}",
|
||||
detail=f'Invalid ZIP file received: {e}',
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error extracting ZIP: {str(e)}",
|
||||
detail=f'Error extracting ZIP: {str(e)}',
|
||||
)
|
||||
|
||||
if not markdown_content:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail="Extracted markdown content is empty",
|
||||
detail='Extracted markdown content is empty',
|
||||
)
|
||||
|
||||
log.info(
|
||||
f"Successfully extracted markdown content ({len(markdown_content)} characters)"
|
||||
)
|
||||
log.info(f'Successfully extracted markdown content ({len(markdown_content)} characters)')
|
||||
return markdown_content
|
||||
|
||||
@@ -49,13 +49,11 @@ class MistralLoader:
|
||||
enable_debug_logging: Enable detailed debug logs.
|
||||
"""
|
||||
if not api_key:
|
||||
raise ValueError("API key cannot be empty.")
|
||||
raise ValueError('API key cannot be empty.')
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found at {file_path}")
|
||||
raise FileNotFoundError(f'File not found at {file_path}')
|
||||
|
||||
self.base_url = (
|
||||
base_url.rstrip("/") if base_url else "https://api.mistral.ai/v1"
|
||||
)
|
||||
self.base_url = base_url.rstrip('/') if base_url else 'https://api.mistral.ai/v1'
|
||||
self.api_key = api_key
|
||||
self.file_path = file_path
|
||||
self.timeout = timeout
|
||||
@@ -65,18 +63,10 @@ class MistralLoader:
|
||||
# PERFORMANCE OPTIMIZATION: Differentiated timeouts for different operations
|
||||
# This prevents long-running OCR operations from affecting quick operations
|
||||
# and improves user experience by failing fast on operations that should be quick
|
||||
self.upload_timeout = min(
|
||||
timeout, 120
|
||||
) # Cap upload at 2 minutes - prevents hanging on large files
|
||||
self.url_timeout = (
|
||||
30 # URL requests should be fast - fail quickly if API is slow
|
||||
)
|
||||
self.ocr_timeout = (
|
||||
timeout # OCR can take the full timeout - this is the heavy operation
|
||||
)
|
||||
self.cleanup_timeout = (
|
||||
30 # Cleanup should be quick - don't hang on file deletion
|
||||
)
|
||||
self.upload_timeout = min(timeout, 120) # Cap upload at 2 minutes - prevents hanging on large files
|
||||
self.url_timeout = 30 # URL requests should be fast - fail quickly if API is slow
|
||||
self.ocr_timeout = timeout # OCR can take the full timeout - this is the heavy operation
|
||||
self.cleanup_timeout = 30 # Cleanup should be quick - don't hang on file deletion
|
||||
|
||||
# PERFORMANCE OPTIMIZATION: Pre-compute file info to avoid repeated filesystem calls
|
||||
# This avoids multiple os.path.basename() and os.path.getsize() calls during processing
|
||||
@@ -85,8 +75,8 @@ class MistralLoader:
|
||||
|
||||
# ENHANCEMENT: Added User-Agent for better API tracking and debugging
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"User-Agent": "OpenWebUI-MistralLoader/2.0", # Helps API provider track usage
|
||||
'Authorization': f'Bearer {self.api_key}',
|
||||
'User-Agent': 'OpenWebUI-MistralLoader/2.0', # Helps API provider track usage
|
||||
}
|
||||
|
||||
def _debug_log(self, message: str, *args) -> None:
|
||||
@@ -108,43 +98,39 @@ class MistralLoader:
|
||||
return {} # Return empty dict if no content
|
||||
return response.json()
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
log.error(f"HTTP error occurred: {http_err} - Response: {response.text}")
|
||||
log.error(f'HTTP error occurred: {http_err} - Response: {response.text}')
|
||||
raise
|
||||
except requests.exceptions.RequestException as req_err:
|
||||
log.error(f"Request exception occurred: {req_err}")
|
||||
log.error(f'Request exception occurred: {req_err}')
|
||||
raise
|
||||
except ValueError as json_err: # Includes JSONDecodeError
|
||||
log.error(f"JSON decode error: {json_err} - Response: {response.text}")
|
||||
log.error(f'JSON decode error: {json_err} - Response: {response.text}')
|
||||
raise # Re-raise after logging
|
||||
|
||||
async def _handle_response_async(
|
||||
self, response: aiohttp.ClientResponse
|
||||
) -> Dict[str, Any]:
|
||||
async def _handle_response_async(self, response: aiohttp.ClientResponse) -> Dict[str, Any]:
|
||||
"""Async version of response handling with better error info."""
|
||||
try:
|
||||
response.raise_for_status()
|
||||
|
||||
# Check content type
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "application/json" not in content_type:
|
||||
content_type = response.headers.get('content-type', '')
|
||||
if 'application/json' not in content_type:
|
||||
if response.status == 204:
|
||||
return {}
|
||||
text = await response.text()
|
||||
raise ValueError(
|
||||
f"Unexpected content type: {content_type}, body: {text[:200]}..."
|
||||
)
|
||||
raise ValueError(f'Unexpected content type: {content_type}, body: {text[:200]}...')
|
||||
|
||||
return await response.json()
|
||||
|
||||
except aiohttp.ClientResponseError as e:
|
||||
error_text = await response.text() if response else "No response"
|
||||
log.error(f"HTTP {e.status}: {e.message} - Response: {error_text[:500]}")
|
||||
error_text = await response.text() if response else 'No response'
|
||||
log.error(f'HTTP {e.status}: {e.message} - Response: {error_text[:500]}')
|
||||
raise
|
||||
except aiohttp.ClientError as e:
|
||||
log.error(f"Client error: {e}")
|
||||
log.error(f'Client error: {e}')
|
||||
raise
|
||||
except Exception as e:
|
||||
log.error(f"Unexpected error processing response: {e}")
|
||||
log.error(f'Unexpected error processing response: {e}')
|
||||
raise
|
||||
|
||||
def _is_retryable_error(self, error: Exception) -> bool:
|
||||
@@ -172,13 +158,11 @@ class MistralLoader:
|
||||
return True # Timeouts might resolve on retry
|
||||
if isinstance(error, requests.exceptions.HTTPError):
|
||||
# Only retry on server errors (5xx) or rate limits (429)
|
||||
if hasattr(error, "response") and error.response is not None:
|
||||
if hasattr(error, 'response') and error.response is not None:
|
||||
status_code = error.response.status_code
|
||||
return status_code >= 500 or status_code == 429
|
||||
return False
|
||||
if isinstance(
|
||||
error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)
|
||||
):
|
||||
if isinstance(error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)):
|
||||
return True # Async network/timeout errors are retryable
|
||||
if isinstance(error, aiohttp.ClientResponseError):
|
||||
return error.status >= 500 or error.status == 429
|
||||
@@ -204,8 +188,7 @@ class MistralLoader:
|
||||
# Prevents overwhelming the server while ensuring reasonable retry delays
|
||||
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
|
||||
log.warning(
|
||||
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
|
||||
f"Retrying in {wait_time}s..."
|
||||
f'Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...'
|
||||
)
|
||||
time.sleep(wait_time)
|
||||
|
||||
@@ -226,8 +209,7 @@ class MistralLoader:
|
||||
# PERFORMANCE OPTIMIZATION: Non-blocking exponential backoff
|
||||
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
|
||||
log.warning(
|
||||
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
|
||||
f"Retrying in {wait_time}s..."
|
||||
f'Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...'
|
||||
)
|
||||
await asyncio.sleep(wait_time) # Non-blocking wait
|
||||
|
||||
@@ -240,15 +222,15 @@ class MistralLoader:
|
||||
Although streaming is not enabled for this endpoint, the file is opened
|
||||
in a context manager to minimize memory usage duration.
|
||||
"""
|
||||
log.info("Uploading file to Mistral API")
|
||||
url = f"{self.base_url}/files"
|
||||
log.info('Uploading file to Mistral API')
|
||||
url = f'{self.base_url}/files'
|
||||
|
||||
def upload_request():
|
||||
# MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime
|
||||
# This ensures the file is closed immediately after reading, reducing memory usage
|
||||
with open(self.file_path, "rb") as f:
|
||||
files = {"file": (self.file_name, f, "application/pdf")}
|
||||
data = {"purpose": "ocr"}
|
||||
with open(self.file_path, 'rb') as f:
|
||||
files = {'file': (self.file_name, f, 'application/pdf')}
|
||||
data = {'purpose': 'ocr'}
|
||||
|
||||
# NOTE: stream=False is required for this endpoint
|
||||
# The Mistral API doesn't support chunked uploads for this endpoint
|
||||
@@ -265,42 +247,38 @@ class MistralLoader:
|
||||
|
||||
try:
|
||||
response_data = self._retry_request_sync(upload_request)
|
||||
file_id = response_data.get("id")
|
||||
file_id = response_data.get('id')
|
||||
if not file_id:
|
||||
raise ValueError("File ID not found in upload response.")
|
||||
log.info(f"File uploaded successfully. File ID: {file_id}")
|
||||
raise ValueError('File ID not found in upload response.')
|
||||
log.info(f'File uploaded successfully. File ID: {file_id}')
|
||||
return file_id
|
||||
except Exception as e:
|
||||
log.error(f"Failed to upload file: {e}")
|
||||
log.error(f'Failed to upload file: {e}')
|
||||
raise
|
||||
|
||||
async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
|
||||
"""Async file upload with streaming for better memory efficiency."""
|
||||
url = f"{self.base_url}/files"
|
||||
url = f'{self.base_url}/files'
|
||||
|
||||
async def upload_request():
|
||||
# Create multipart writer for streaming upload
|
||||
writer = aiohttp.MultipartWriter("form-data")
|
||||
writer = aiohttp.MultipartWriter('form-data')
|
||||
|
||||
# Add purpose field
|
||||
purpose_part = writer.append("ocr")
|
||||
purpose_part.set_content_disposition("form-data", name="purpose")
|
||||
purpose_part = writer.append('ocr')
|
||||
purpose_part.set_content_disposition('form-data', name='purpose')
|
||||
|
||||
# Add file part with streaming
|
||||
file_part = writer.append_payload(
|
||||
aiohttp.streams.FilePayload(
|
||||
self.file_path,
|
||||
filename=self.file_name,
|
||||
content_type="application/pdf",
|
||||
content_type='application/pdf',
|
||||
)
|
||||
)
|
||||
file_part.set_content_disposition(
|
||||
"form-data", name="file", filename=self.file_name
|
||||
)
|
||||
file_part.set_content_disposition('form-data', name='file', filename=self.file_name)
|
||||
|
||||
self._debug_log(
|
||||
f"Uploading file: {self.file_name} ({self.file_size:,} bytes)"
|
||||
)
|
||||
self._debug_log(f'Uploading file: {self.file_name} ({self.file_size:,} bytes)')
|
||||
|
||||
async with session.post(
|
||||
url,
|
||||
@@ -312,48 +290,44 @@ class MistralLoader:
|
||||
|
||||
response_data = await self._retry_request_async(upload_request)
|
||||
|
||||
file_id = response_data.get("id")
|
||||
file_id = response_data.get('id')
|
||||
if not file_id:
|
||||
raise ValueError("File ID not found in upload response.")
|
||||
raise ValueError('File ID not found in upload response.')
|
||||
|
||||
log.info(f"File uploaded successfully. File ID: {file_id}")
|
||||
log.info(f'File uploaded successfully. File ID: {file_id}')
|
||||
return file_id
|
||||
|
||||
def _get_signed_url(self, file_id: str) -> str:
|
||||
"""Retrieves a temporary signed URL for the uploaded file (sync version)."""
|
||||
log.info(f"Getting signed URL for file ID: {file_id}")
|
||||
url = f"{self.base_url}/files/{file_id}/url"
|
||||
params = {"expiry": 1}
|
||||
signed_url_headers = {**self.headers, "Accept": "application/json"}
|
||||
log.info(f'Getting signed URL for file ID: {file_id}')
|
||||
url = f'{self.base_url}/files/{file_id}/url'
|
||||
params = {'expiry': 1}
|
||||
signed_url_headers = {**self.headers, 'Accept': 'application/json'}
|
||||
|
||||
def url_request():
|
||||
response = requests.get(
|
||||
url, headers=signed_url_headers, params=params, timeout=self.url_timeout
|
||||
)
|
||||
response = requests.get(url, headers=signed_url_headers, params=params, timeout=self.url_timeout)
|
||||
return self._handle_response(response)
|
||||
|
||||
try:
|
||||
response_data = self._retry_request_sync(url_request)
|
||||
signed_url = response_data.get("url")
|
||||
signed_url = response_data.get('url')
|
||||
if not signed_url:
|
||||
raise ValueError("Signed URL not found in response.")
|
||||
log.info("Signed URL received.")
|
||||
raise ValueError('Signed URL not found in response.')
|
||||
log.info('Signed URL received.')
|
||||
return signed_url
|
||||
except Exception as e:
|
||||
log.error(f"Failed to get signed URL: {e}")
|
||||
log.error(f'Failed to get signed URL: {e}')
|
||||
raise
|
||||
|
||||
async def _get_signed_url_async(
|
||||
self, session: aiohttp.ClientSession, file_id: str
|
||||
) -> str:
|
||||
async def _get_signed_url_async(self, session: aiohttp.ClientSession, file_id: str) -> str:
|
||||
"""Async signed URL retrieval."""
|
||||
url = f"{self.base_url}/files/{file_id}/url"
|
||||
params = {"expiry": 1}
|
||||
url = f'{self.base_url}/files/{file_id}/url'
|
||||
params = {'expiry': 1}
|
||||
|
||||
headers = {**self.headers, "Accept": "application/json"}
|
||||
headers = {**self.headers, 'Accept': 'application/json'}
|
||||
|
||||
async def url_request():
|
||||
self._debug_log(f"Getting signed URL for file ID: {file_id}")
|
||||
self._debug_log(f'Getting signed URL for file ID: {file_id}')
|
||||
async with session.get(
|
||||
url,
|
||||
headers=headers,
|
||||
@@ -364,69 +338,65 @@ class MistralLoader:
|
||||
|
||||
response_data = await self._retry_request_async(url_request)
|
||||
|
||||
signed_url = response_data.get("url")
|
||||
signed_url = response_data.get('url')
|
||||
if not signed_url:
|
||||
raise ValueError("Signed URL not found in response.")
|
||||
raise ValueError('Signed URL not found in response.')
|
||||
|
||||
self._debug_log("Signed URL received successfully")
|
||||
self._debug_log('Signed URL received successfully')
|
||||
return signed_url
|
||||
|
||||
def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
|
||||
"""Sends the signed URL to the OCR endpoint for processing (sync version)."""
|
||||
log.info("Processing OCR via Mistral API")
|
||||
url = f"{self.base_url}/ocr"
|
||||
log.info('Processing OCR via Mistral API')
|
||||
url = f'{self.base_url}/ocr'
|
||||
ocr_headers = {
|
||||
**self.headers,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
}
|
||||
payload = {
|
||||
"model": "mistral-ocr-latest",
|
||||
"document": {
|
||||
"type": "document_url",
|
||||
"document_url": signed_url,
|
||||
'model': 'mistral-ocr-latest',
|
||||
'document': {
|
||||
'type': 'document_url',
|
||||
'document_url': signed_url,
|
||||
},
|
||||
"include_image_base64": False,
|
||||
'include_image_base64': False,
|
||||
}
|
||||
|
||||
def ocr_request():
|
||||
response = requests.post(
|
||||
url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout
|
||||
)
|
||||
response = requests.post(url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout)
|
||||
return self._handle_response(response)
|
||||
|
||||
try:
|
||||
ocr_response = self._retry_request_sync(ocr_request)
|
||||
log.info("OCR processing done.")
|
||||
self._debug_log("OCR response: %s", ocr_response)
|
||||
log.info('OCR processing done.')
|
||||
self._debug_log('OCR response: %s', ocr_response)
|
||||
return ocr_response
|
||||
except Exception as e:
|
||||
log.error(f"Failed during OCR processing: {e}")
|
||||
log.error(f'Failed during OCR processing: {e}')
|
||||
raise
|
||||
|
||||
async def _process_ocr_async(
|
||||
self, session: aiohttp.ClientSession, signed_url: str
|
||||
) -> Dict[str, Any]:
|
||||
async def _process_ocr_async(self, session: aiohttp.ClientSession, signed_url: str) -> Dict[str, Any]:
|
||||
"""Async OCR processing with timing metrics."""
|
||||
url = f"{self.base_url}/ocr"
|
||||
url = f'{self.base_url}/ocr'
|
||||
|
||||
headers = {
|
||||
**self.headers,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": "mistral-ocr-latest",
|
||||
"document": {
|
||||
"type": "document_url",
|
||||
"document_url": signed_url,
|
||||
'model': 'mistral-ocr-latest',
|
||||
'document': {
|
||||
'type': 'document_url',
|
||||
'document_url': signed_url,
|
||||
},
|
||||
"include_image_base64": False,
|
||||
'include_image_base64': False,
|
||||
}
|
||||
|
||||
async def ocr_request():
|
||||
log.info("Starting OCR processing via Mistral API")
|
||||
log.info('Starting OCR processing via Mistral API')
|
||||
start_time = time.time()
|
||||
|
||||
async with session.post(
|
||||
@@ -438,7 +408,7 @@ class MistralLoader:
|
||||
ocr_response = await self._handle_response_async(response)
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
log.info(f"OCR processing completed in {processing_time:.2f}s")
|
||||
log.info(f'OCR processing completed in {processing_time:.2f}s')
|
||||
|
||||
return ocr_response
|
||||
|
||||
@@ -446,42 +416,36 @@ class MistralLoader:
|
||||
|
||||
def _delete_file(self, file_id: str) -> None:
|
||||
"""Deletes the file from Mistral storage (sync version)."""
|
||||
log.info(f"Deleting uploaded file ID: {file_id}")
|
||||
url = f"{self.base_url}/files/{file_id}"
|
||||
log.info(f'Deleting uploaded file ID: {file_id}')
|
||||
url = f'{self.base_url}/files/{file_id}'
|
||||
|
||||
try:
|
||||
response = requests.delete(
|
||||
url, headers=self.headers, timeout=self.cleanup_timeout
|
||||
)
|
||||
response = requests.delete(url, headers=self.headers, timeout=self.cleanup_timeout)
|
||||
delete_response = self._handle_response(response)
|
||||
log.info(f"File deleted successfully: {delete_response}")
|
||||
log.info(f'File deleted successfully: {delete_response}')
|
||||
except Exception as e:
|
||||
# Log error but don't necessarily halt execution if deletion fails
|
||||
log.error(f"Failed to delete file ID {file_id}: {e}")
|
||||
log.error(f'Failed to delete file ID {file_id}: {e}')
|
||||
|
||||
async def _delete_file_async(
|
||||
self, session: aiohttp.ClientSession, file_id: str
|
||||
) -> None:
|
||||
async def _delete_file_async(self, session: aiohttp.ClientSession, file_id: str) -> None:
|
||||
"""Async file deletion with error tolerance."""
|
||||
try:
|
||||
|
||||
async def delete_request():
|
||||
self._debug_log(f"Deleting file ID: {file_id}")
|
||||
self._debug_log(f'Deleting file ID: {file_id}')
|
||||
async with session.delete(
|
||||
url=f"{self.base_url}/files/{file_id}",
|
||||
url=f'{self.base_url}/files/{file_id}',
|
||||
headers=self.headers,
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=self.cleanup_timeout
|
||||
), # Shorter timeout for cleanup
|
||||
timeout=aiohttp.ClientTimeout(total=self.cleanup_timeout), # Shorter timeout for cleanup
|
||||
) as response:
|
||||
return await self._handle_response_async(response)
|
||||
|
||||
await self._retry_request_async(delete_request)
|
||||
self._debug_log(f"File {file_id} deleted successfully")
|
||||
self._debug_log(f'File {file_id} deleted successfully')
|
||||
|
||||
except Exception as e:
|
||||
# Don't fail the entire process if cleanup fails
|
||||
log.warning(f"Failed to delete file ID {file_id}: {e}")
|
||||
log.warning(f'Failed to delete file ID {file_id}: {e}')
|
||||
|
||||
@asynccontextmanager
|
||||
async def _get_session(self):
|
||||
@@ -506,7 +470,7 @@ class MistralLoader:
|
||||
async with aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
timeout=timeout,
|
||||
headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
|
||||
headers={'User-Agent': 'OpenWebUI-MistralLoader/2.0'},
|
||||
raise_for_status=False, # We handle status codes manually
|
||||
trust_env=True,
|
||||
) as session:
|
||||
@@ -514,13 +478,13 @@ class MistralLoader:
|
||||
|
||||
def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
|
||||
"""Process OCR results into Document objects with enhanced metadata and memory efficiency."""
|
||||
pages_data = ocr_response.get("pages")
|
||||
pages_data = ocr_response.get('pages')
|
||||
if not pages_data:
|
||||
log.warning("No pages found in OCR response.")
|
||||
log.warning('No pages found in OCR response.')
|
||||
return [
|
||||
Document(
|
||||
page_content="No text content found",
|
||||
metadata={"error": "no_pages", "file_name": self.file_name},
|
||||
page_content='No text content found',
|
||||
metadata={'error': 'no_pages', 'file_name': self.file_name},
|
||||
)
|
||||
]
|
||||
|
||||
@@ -530,8 +494,8 @@ class MistralLoader:
|
||||
|
||||
# Process pages in a memory-efficient way
|
||||
for page_data in pages_data:
|
||||
page_content = page_data.get("markdown")
|
||||
page_index = page_data.get("index") # API uses 0-based index
|
||||
page_content = page_data.get('markdown')
|
||||
page_index = page_data.get('index') # API uses 0-based index
|
||||
|
||||
if page_content is None or page_index is None:
|
||||
skipped_pages += 1
|
||||
@@ -548,7 +512,7 @@ class MistralLoader:
|
||||
|
||||
if not cleaned_content:
|
||||
skipped_pages += 1
|
||||
self._debug_log(f"Skipping empty page {page_index}")
|
||||
self._debug_log(f'Skipping empty page {page_index}')
|
||||
continue
|
||||
|
||||
# Create document with optimized metadata
|
||||
@@ -556,34 +520,30 @@ class MistralLoader:
|
||||
Document(
|
||||
page_content=cleaned_content,
|
||||
metadata={
|
||||
"page": page_index, # 0-based index from API
|
||||
"page_label": page_index + 1, # 1-based label for convenience
|
||||
"total_pages": total_pages,
|
||||
"file_name": self.file_name,
|
||||
"file_size": self.file_size,
|
||||
"processing_engine": "mistral-ocr",
|
||||
"content_length": len(cleaned_content),
|
||||
'page': page_index, # 0-based index from API
|
||||
'page_label': page_index + 1, # 1-based label for convenience
|
||||
'total_pages': total_pages,
|
||||
'file_name': self.file_name,
|
||||
'file_size': self.file_size,
|
||||
'processing_engine': 'mistral-ocr',
|
||||
'content_length': len(cleaned_content),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
if skipped_pages > 0:
|
||||
log.info(
|
||||
f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages"
|
||||
)
|
||||
log.info(f'Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages')
|
||||
|
||||
if not documents:
|
||||
# Case where pages existed but none had valid markdown/index
|
||||
log.warning(
|
||||
"OCR response contained pages, but none had valid content/index."
|
||||
)
|
||||
log.warning('OCR response contained pages, but none had valid content/index.')
|
||||
return [
|
||||
Document(
|
||||
page_content="No valid text content found in document",
|
||||
page_content='No valid text content found in document',
|
||||
metadata={
|
||||
"error": "no_valid_pages",
|
||||
"total_pages": total_pages,
|
||||
"file_name": self.file_name,
|
||||
'error': 'no_valid_pages',
|
||||
'total_pages': total_pages,
|
||||
'file_name': self.file_name,
|
||||
},
|
||||
)
|
||||
]
|
||||
@@ -615,24 +575,20 @@ class MistralLoader:
|
||||
documents = self._process_results(ocr_response)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
log.info(
|
||||
f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
|
||||
)
|
||||
log.info(f'Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents')
|
||||
|
||||
return documents
|
||||
|
||||
except Exception as e:
|
||||
total_time = time.time() - start_time
|
||||
log.error(
|
||||
f"An error occurred during the loading process after {total_time:.2f}s: {e}"
|
||||
)
|
||||
log.error(f'An error occurred during the loading process after {total_time:.2f}s: {e}')
|
||||
# Return an error document on failure
|
||||
return [
|
||||
Document(
|
||||
page_content=f"Error during processing: {e}",
|
||||
page_content=f'Error during processing: {e}',
|
||||
metadata={
|
||||
"error": "processing_failed",
|
||||
"file_name": self.file_name,
|
||||
'error': 'processing_failed',
|
||||
'file_name': self.file_name,
|
||||
},
|
||||
)
|
||||
]
|
||||
@@ -643,9 +599,7 @@ class MistralLoader:
|
||||
self._delete_file(file_id)
|
||||
except Exception as del_e:
|
||||
# Log deletion error, but don't overwrite original error if one occurred
|
||||
log.error(
|
||||
f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}"
|
||||
)
|
||||
log.error(f'Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}')
|
||||
|
||||
async def load_async(self) -> List[Document]:
|
||||
"""
|
||||
@@ -672,21 +626,19 @@ class MistralLoader:
|
||||
documents = self._process_results(ocr_response)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
log.info(
|
||||
f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
|
||||
)
|
||||
log.info(f'Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents')
|
||||
|
||||
return documents
|
||||
|
||||
except Exception as e:
|
||||
total_time = time.time() - start_time
|
||||
log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}")
|
||||
log.error(f'Async OCR workflow failed after {total_time:.2f}s: {e}')
|
||||
return [
|
||||
Document(
|
||||
page_content=f"Error during OCR processing: {e}",
|
||||
page_content=f'Error during OCR processing: {e}',
|
||||
metadata={
|
||||
"error": "processing_failed",
|
||||
"file_name": self.file_name,
|
||||
'error': 'processing_failed',
|
||||
'file_name': self.file_name,
|
||||
},
|
||||
)
|
||||
]
|
||||
@@ -697,11 +649,11 @@ class MistralLoader:
|
||||
async with self._get_session() as session:
|
||||
await self._delete_file_async(session, file_id)
|
||||
except Exception as cleanup_error:
|
||||
log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}")
|
||||
log.error(f'Cleanup failed for file ID {file_id}: {cleanup_error}')
|
||||
|
||||
@staticmethod
|
||||
async def load_multiple_async(
|
||||
loaders: List["MistralLoader"],
|
||||
loaders: List['MistralLoader'],
|
||||
max_concurrent: int = 5, # Limit concurrent requests
|
||||
) -> List[List[Document]]:
|
||||
"""
|
||||
@@ -717,15 +669,13 @@ class MistralLoader:
|
||||
if not loaders:
|
||||
return []
|
||||
|
||||
log.info(
|
||||
f"Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent"
|
||||
)
|
||||
log.info(f'Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent')
|
||||
start_time = time.time()
|
||||
|
||||
# Use semaphore to control concurrency
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
|
||||
async def process_with_semaphore(loader: "MistralLoader") -> List[Document]:
|
||||
async def process_with_semaphore(loader: 'MistralLoader') -> List[Document]:
|
||||
async with semaphore:
|
||||
return await loader.load_async()
|
||||
|
||||
@@ -737,14 +687,14 @@ class MistralLoader:
|
||||
processed_results = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
log.error(f"File {i} failed: {result}")
|
||||
log.error(f'File {i} failed: {result}')
|
||||
processed_results.append(
|
||||
[
|
||||
Document(
|
||||
page_content=f"Error processing file: {result}",
|
||||
page_content=f'Error processing file: {result}',
|
||||
metadata={
|
||||
"error": "batch_processing_failed",
|
||||
"file_index": i,
|
||||
'error': 'batch_processing_failed',
|
||||
'file_index': i,
|
||||
},
|
||||
)
|
||||
]
|
||||
@@ -755,15 +705,13 @@ class MistralLoader:
|
||||
# MONITORING: Log comprehensive batch processing statistics
|
||||
total_time = time.time() - start_time
|
||||
total_docs = sum(len(docs) for docs in processed_results)
|
||||
success_count = sum(
|
||||
1 for result in results if not isinstance(result, Exception)
|
||||
)
|
||||
success_count = sum(1 for result in results if not isinstance(result, Exception))
|
||||
failure_count = len(results) - success_count
|
||||
|
||||
log.info(
|
||||
f"Batch processing completed in {total_time:.2f}s: "
|
||||
f"{success_count} files succeeded, {failure_count} files failed, "
|
||||
f"produced {total_docs} total documents"
|
||||
f'Batch processing completed in {total_time:.2f}s: '
|
||||
f'{success_count} files succeeded, {failure_count} files failed, '
|
||||
f'produced {total_docs} total documents'
|
||||
)
|
||||
|
||||
return processed_results
|
||||
|
||||
@@ -25,7 +25,7 @@ class TavilyLoader(BaseLoader):
|
||||
self,
|
||||
urls: Union[str, List[str]],
|
||||
api_key: str,
|
||||
extract_depth: Literal["basic", "advanced"] = "basic",
|
||||
extract_depth: Literal['basic', 'advanced'] = 'basic',
|
||||
continue_on_failure: bool = True,
|
||||
) -> None:
|
||||
"""Initialize Tavily Extract client.
|
||||
@@ -42,13 +42,13 @@ class TavilyLoader(BaseLoader):
|
||||
continue_on_failure: Whether to continue if extraction of a URL fails.
|
||||
"""
|
||||
if not urls:
|
||||
raise ValueError("At least one URL must be provided.")
|
||||
raise ValueError('At least one URL must be provided.')
|
||||
|
||||
self.api_key = api_key
|
||||
self.urls = urls if isinstance(urls, list) else [urls]
|
||||
self.extract_depth = extract_depth
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self.api_url = "https://api.tavily.com/extract"
|
||||
self.api_url = 'https://api.tavily.com/extract'
|
||||
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
"""Extract and yield documents from the URLs using Tavily Extract API."""
|
||||
@@ -57,35 +57,35 @@ class TavilyLoader(BaseLoader):
|
||||
batch_urls = self.urls[i : i + batch_size]
|
||||
try:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}',
|
||||
}
|
||||
# Use string for single URL, array for multiple URLs
|
||||
urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls
|
||||
payload = {"urls": urls_param, "extract_depth": self.extract_depth}
|
||||
payload = {'urls': urls_param, 'extract_depth': self.extract_depth}
|
||||
# Make the API call
|
||||
response = requests.post(self.api_url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
response_data = response.json()
|
||||
# Process successful results
|
||||
for result in response_data.get("results", []):
|
||||
url = result.get("url", "")
|
||||
content = result.get("raw_content", "")
|
||||
for result in response_data.get('results', []):
|
||||
url = result.get('url', '')
|
||||
content = result.get('raw_content', '')
|
||||
if not content:
|
||||
log.warning(f"No content extracted from {url}")
|
||||
log.warning(f'No content extracted from {url}')
|
||||
continue
|
||||
# Add URLs as metadata
|
||||
metadata = {"source": url}
|
||||
metadata = {'source': url}
|
||||
yield Document(
|
||||
page_content=content,
|
||||
metadata=metadata,
|
||||
)
|
||||
for failed in response_data.get("failed_results", []):
|
||||
url = failed.get("url", "")
|
||||
error = failed.get("error", "Unknown error")
|
||||
log.error(f"Failed to extract content from {url}: {error}")
|
||||
for failed in response_data.get('failed_results', []):
|
||||
url = failed.get('url', '')
|
||||
error = failed.get('error', 'Unknown error')
|
||||
log.error(f'Failed to extract content from {url}: {error}')
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.error(f"Error extracting content from batch {batch_urls}: {e}")
|
||||
log.error(f'Error extracting content from batch {batch_urls}: {e}')
|
||||
else:
|
||||
raise e
|
||||
|
||||
@@ -7,14 +7,14 @@ from langchain_core.documents import Document
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
ALLOWED_SCHEMES = {"http", "https"}
|
||||
ALLOWED_SCHEMES = {'http', 'https'}
|
||||
ALLOWED_NETLOCS = {
|
||||
"youtu.be",
|
||||
"m.youtube.com",
|
||||
"youtube.com",
|
||||
"www.youtube.com",
|
||||
"www.youtube-nocookie.com",
|
||||
"vid.plus",
|
||||
'youtu.be',
|
||||
'm.youtube.com',
|
||||
'youtube.com',
|
||||
'www.youtube.com',
|
||||
'www.youtube-nocookie.com',
|
||||
'vid.plus',
|
||||
}
|
||||
|
||||
|
||||
@@ -30,17 +30,17 @@ def _parse_video_id(url: str) -> Optional[str]:
|
||||
|
||||
path = parsed_url.path
|
||||
|
||||
if path.endswith("/watch"):
|
||||
if path.endswith('/watch'):
|
||||
query = parsed_url.query
|
||||
parsed_query = parse_qs(query)
|
||||
if "v" in parsed_query:
|
||||
ids = parsed_query["v"]
|
||||
if 'v' in parsed_query:
|
||||
ids = parsed_query['v']
|
||||
video_id = ids if isinstance(ids, str) else ids[0]
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
path = parsed_url.path.lstrip("/")
|
||||
video_id = path.split("/")[-1]
|
||||
path = parsed_url.path.lstrip('/')
|
||||
video_id = path.split('/')[-1]
|
||||
|
||||
if len(video_id) != 11: # Video IDs are 11 characters long
|
||||
return None
|
||||
@@ -54,13 +54,13 @@ class YoutubeLoader:
|
||||
def __init__(
|
||||
self,
|
||||
video_id: str,
|
||||
language: Union[str, Sequence[str]] = "en",
|
||||
language: Union[str, Sequence[str]] = 'en',
|
||||
proxy_url: Optional[str] = None,
|
||||
):
|
||||
"""Initialize with YouTube video ID."""
|
||||
_video_id = _parse_video_id(video_id)
|
||||
self.video_id = _video_id if _video_id is not None else video_id
|
||||
self._metadata = {"source": video_id}
|
||||
self._metadata = {'source': video_id}
|
||||
self.proxy_url = proxy_url
|
||||
|
||||
# Ensure language is a list
|
||||
@@ -70,8 +70,8 @@ class YoutubeLoader:
|
||||
self.language = list(language)
|
||||
|
||||
# Add English as fallback if not already in the list
|
||||
if "en" not in self.language:
|
||||
self.language.append("en")
|
||||
if 'en' not in self.language:
|
||||
self.language.append('en')
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Load YouTube transcripts into `Document` objects."""
|
||||
@@ -85,14 +85,12 @@ class YoutubeLoader:
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Could not import "youtube_transcript_api" Python package. '
|
||||
"Please install it with `pip install youtube-transcript-api`."
|
||||
'Please install it with `pip install youtube-transcript-api`.'
|
||||
)
|
||||
|
||||
if self.proxy_url:
|
||||
youtube_proxies = GenericProxyConfig(
|
||||
http_url=self.proxy_url, https_url=self.proxy_url
|
||||
)
|
||||
log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
|
||||
youtube_proxies = GenericProxyConfig(http_url=self.proxy_url, https_url=self.proxy_url)
|
||||
log.debug(f'Using proxy URL: {self.proxy_url[:14]}...')
|
||||
else:
|
||||
youtube_proxies = None
|
||||
|
||||
@@ -100,7 +98,7 @@ class YoutubeLoader:
|
||||
try:
|
||||
transcript_list = transcript_api.list(self.video_id)
|
||||
except Exception as e:
|
||||
log.exception("Loading YouTube transcript failed")
|
||||
log.exception('Loading YouTube transcript failed')
|
||||
return []
|
||||
|
||||
# Try each language in order of priority
|
||||
@@ -110,14 +108,10 @@ class YoutubeLoader:
|
||||
if transcript.is_generated:
|
||||
log.debug(f"Found generated transcript for language '{lang}'")
|
||||
try:
|
||||
transcript = transcript_list.find_manually_created_transcript(
|
||||
[lang]
|
||||
)
|
||||
transcript = transcript_list.find_manually_created_transcript([lang])
|
||||
log.debug(f"Found manual transcript for language '{lang}'")
|
||||
except NoTranscriptFound:
|
||||
log.debug(
|
||||
f"No manual transcript found for language '{lang}', using generated"
|
||||
)
|
||||
log.debug(f"No manual transcript found for language '{lang}', using generated")
|
||||
pass
|
||||
|
||||
log.debug(f"Found transcript for language '{lang}'")
|
||||
@@ -131,12 +125,10 @@ class YoutubeLoader:
|
||||
log.debug(f"Empty transcript for language '{lang}'")
|
||||
continue
|
||||
|
||||
transcript_text = " ".join(
|
||||
transcript_text = ' '.join(
|
||||
map(
|
||||
lambda transcript_piece: (
|
||||
transcript_piece.text.strip(" ")
|
||||
if hasattr(transcript_piece, "text")
|
||||
else ""
|
||||
transcript_piece.text.strip(' ') if hasattr(transcript_piece, 'text') else ''
|
||||
),
|
||||
transcript_pieces,
|
||||
)
|
||||
@@ -150,9 +142,9 @@ class YoutubeLoader:
|
||||
raise e
|
||||
|
||||
# If we get here, all languages failed
|
||||
languages_tried = ", ".join(self.language)
|
||||
languages_tried = ', '.join(self.language)
|
||||
log.warning(
|
||||
f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
|
||||
f'No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed.'
|
||||
)
|
||||
raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))
|
||||
|
||||
|
||||
@@ -13,19 +13,17 @@ log = logging.getLogger(__name__)
|
||||
|
||||
class ColBERT(BaseReranker):
|
||||
def __init__(self, name, **kwargs) -> None:
|
||||
log.info("ColBERT: Loading model", name)
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
log.info('ColBERT: Loading model', name)
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
DOCKER = kwargs.get("env") == "docker"
|
||||
DOCKER = kwargs.get('env') == 'docker'
|
||||
if DOCKER:
|
||||
# This is a workaround for the issue with the docker container
|
||||
# where the torch extension is not loaded properly
|
||||
# and the following error is thrown:
|
||||
# /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory
|
||||
|
||||
lock_file = (
|
||||
"/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock"
|
||||
)
|
||||
lock_file = '/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock'
|
||||
if os.path.exists(lock_file):
|
||||
os.remove(lock_file)
|
||||
|
||||
@@ -36,23 +34,16 @@ class ColBERT(BaseReranker):
|
||||
pass
|
||||
|
||||
def calculate_similarity_scores(self, query_embeddings, document_embeddings):
|
||||
|
||||
query_embeddings = query_embeddings.to(self.device)
|
||||
document_embeddings = document_embeddings.to(self.device)
|
||||
|
||||
# Validate dimensions to ensure compatibility
|
||||
if query_embeddings.dim() != 3:
|
||||
raise ValueError(
|
||||
f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
|
||||
)
|
||||
raise ValueError(f'Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}.')
|
||||
if document_embeddings.dim() != 3:
|
||||
raise ValueError(
|
||||
f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
|
||||
)
|
||||
raise ValueError(f'Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}.')
|
||||
if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
|
||||
raise ValueError(
|
||||
"There should be either one query or queries equal to the number of documents."
|
||||
)
|
||||
raise ValueError('There should be either one query or queries equal to the number of documents.')
|
||||
|
||||
# Transpose the query embeddings to align for matrix multiplication
|
||||
transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
|
||||
@@ -69,7 +60,6 @@ class ColBERT(BaseReranker):
|
||||
return normalized_scores.detach().cpu().numpy().astype(np.float32)
|
||||
|
||||
def predict(self, sentences):
|
||||
|
||||
query = sentences[0][0]
|
||||
docs = [i[1] for i in sentences]
|
||||
|
||||
@@ -80,8 +70,6 @@ class ColBERT(BaseReranker):
|
||||
embedded_query = embedded_queries[0]
|
||||
|
||||
# Calculate retrieval scores for the query against all documents
|
||||
scores = self.calculate_similarity_scores(
|
||||
embedded_query.unsqueeze(0), embedded_docs
|
||||
)
|
||||
scores = self.calculate_similarity_scores(embedded_query.unsqueeze(0), embedded_docs)
|
||||
|
||||
return scores
|
||||
|
||||
@@ -15,8 +15,8 @@ class ExternalReranker(BaseReranker):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
url: str = "http://localhost:8080/v1/rerank",
|
||||
model: str = "reranker",
|
||||
url: str = 'http://localhost:8080/v1/rerank',
|
||||
model: str = 'reranker',
|
||||
timeout: Optional[int] = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
@@ -24,33 +24,31 @@ class ExternalReranker(BaseReranker):
|
||||
self.model = model
|
||||
self.timeout = timeout
|
||||
|
||||
def predict(
|
||||
self, sentences: List[Tuple[str, str]], user=None
|
||||
) -> Optional[List[float]]:
|
||||
def predict(self, sentences: List[Tuple[str, str]], user=None) -> Optional[List[float]]:
|
||||
query = sentences[0][0]
|
||||
docs = [i[1] for i in sentences]
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"query": query,
|
||||
"documents": docs,
|
||||
"top_n": len(docs),
|
||||
'model': self.model,
|
||||
'query': query,
|
||||
'documents': docs,
|
||||
'top_n': len(docs),
|
||||
}
|
||||
|
||||
try:
|
||||
log.info(f"ExternalReranker:predict:model {self.model}")
|
||||
log.info(f"ExternalReranker:predict:query {query}")
|
||||
log.info(f'ExternalReranker:predict:model {self.model}')
|
||||
log.info(f'ExternalReranker:predict:query {query}')
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}',
|
||||
}
|
||||
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
||||
headers = include_user_info_headers(headers, user)
|
||||
|
||||
r = requests.post(
|
||||
f"{self.url}",
|
||||
f'{self.url}',
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=self.timeout,
|
||||
@@ -60,13 +58,13 @@ class ExternalReranker(BaseReranker):
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
|
||||
if "results" in data:
|
||||
sorted_results = sorted(data["results"], key=lambda x: x["index"])
|
||||
return [result["relevance_score"] for result in sorted_results]
|
||||
if 'results' in data:
|
||||
sorted_results = sorted(data['results'], key=lambda x: x['index'])
|
||||
return [result['relevance_score'] for result in sorted_results]
|
||||
else:
|
||||
log.error("No results found in external reranking response")
|
||||
log.error('No results found in external reranking response')
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
log.exception(f"Error in external reranking: {e}")
|
||||
log.exception(f'Error in external reranking: {e}')
|
||||
return None
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -31,17 +31,15 @@ log = logging.getLogger(__name__)
|
||||
class ChromaClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
settings_dict = {
|
||||
"allow_reset": True,
|
||||
"anonymized_telemetry": False,
|
||||
'allow_reset': True,
|
||||
'anonymized_telemetry': False,
|
||||
}
|
||||
if CHROMA_CLIENT_AUTH_PROVIDER is not None:
|
||||
settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
|
||||
settings_dict['chroma_client_auth_provider'] = CHROMA_CLIENT_AUTH_PROVIDER
|
||||
if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
|
||||
settings_dict["chroma_client_auth_credentials"] = (
|
||||
CHROMA_CLIENT_AUTH_CREDENTIALS
|
||||
)
|
||||
settings_dict['chroma_client_auth_credentials'] = CHROMA_CLIENT_AUTH_CREDENTIALS
|
||||
|
||||
if CHROMA_HTTP_HOST != "":
|
||||
if CHROMA_HTTP_HOST != '':
|
||||
self.client = chromadb.HttpClient(
|
||||
host=CHROMA_HTTP_HOST,
|
||||
port=CHROMA_HTTP_PORT,
|
||||
@@ -87,25 +85,23 @@ class ChromaClient(VectorDBBase):
|
||||
|
||||
# chromadb has cosine distance, 2 (worst) -> 0 (best). Re-odering to 0 -> 1
|
||||
# https://docs.trychroma.com/docs/collections/configure cosine equation
|
||||
distances: list = result["distances"][0]
|
||||
distances: list = result['distances'][0]
|
||||
distances = [2 - dist for dist in distances]
|
||||
distances = [[dist / 2 for dist in distances]]
|
||||
|
||||
return SearchResult(
|
||||
**{
|
||||
"ids": result["ids"],
|
||||
"distances": distances,
|
||||
"documents": result["documents"],
|
||||
"metadatas": result["metadatas"],
|
||||
'ids': result['ids'],
|
||||
'distances': distances,
|
||||
'documents': result['documents'],
|
||||
'metadatas': result['metadatas'],
|
||||
}
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: dict, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
# Query the items from the collection based on the filter.
|
||||
try:
|
||||
collection = self.client.get_collection(name=collection_name)
|
||||
@@ -117,9 +113,9 @@ class ChromaClient(VectorDBBase):
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [result["ids"]],
|
||||
"documents": [result["documents"]],
|
||||
"metadatas": [result["metadatas"]],
|
||||
'ids': [result['ids']],
|
||||
'documents': [result['documents']],
|
||||
'metadatas': [result['metadatas']],
|
||||
}
|
||||
)
|
||||
return None
|
||||
@@ -133,23 +129,21 @@ class ChromaClient(VectorDBBase):
|
||||
result = collection.get()
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [result["ids"]],
|
||||
"documents": [result["documents"]],
|
||||
"metadatas": [result["metadatas"]],
|
||||
'ids': [result['ids']],
|
||||
'documents': [result['documents']],
|
||||
'metadatas': [result['metadatas']],
|
||||
}
|
||||
)
|
||||
return None
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Insert the items into the collection, if the collection does not exist, it will be created.
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=collection_name, metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
collection = self.client.get_or_create_collection(name=collection_name, metadata={'hnsw:space': 'cosine'})
|
||||
|
||||
ids = [item["id"] for item in items]
|
||||
documents = [item["text"] for item in items]
|
||||
embeddings = [item["vector"] for item in items]
|
||||
metadatas = [process_metadata(item["metadata"]) for item in items]
|
||||
ids = [item['id'] for item in items]
|
||||
documents = [item['text'] for item in items]
|
||||
embeddings = [item['vector'] for item in items]
|
||||
metadatas = [process_metadata(item['metadata']) for item in items]
|
||||
|
||||
for batch in create_batches(
|
||||
api=self.client,
|
||||
@@ -162,18 +156,14 @@ class ChromaClient(VectorDBBase):
|
||||
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=collection_name, metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
collection = self.client.get_or_create_collection(name=collection_name, metadata={'hnsw:space': 'cosine'})
|
||||
|
||||
ids = [item["id"] for item in items]
|
||||
documents = [item["text"] for item in items]
|
||||
embeddings = [item["vector"] for item in items]
|
||||
metadatas = [process_metadata(item["metadata"]) for item in items]
|
||||
ids = [item['id'] for item in items]
|
||||
documents = [item['text'] for item in items]
|
||||
embeddings = [item['vector'] for item in items]
|
||||
metadatas = [process_metadata(item['metadata']) for item in items]
|
||||
|
||||
collection.upsert(
|
||||
ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
|
||||
)
|
||||
collection.upsert(ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
@@ -191,9 +181,7 @@ class ChromaClient(VectorDBBase):
|
||||
collection.delete(where=filter)
|
||||
except Exception as e:
|
||||
# If collection doesn't exist, that's fine - nothing to delete
|
||||
log.debug(
|
||||
f"Attempted to delete from non-existent collection {collection_name}. Ignoring."
|
||||
)
|
||||
log.debug(f'Attempted to delete from non-existent collection {collection_name}. Ignoring.')
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
|
||||
@@ -51,7 +51,7 @@ class ElasticsearchClient(VectorDBBase):
|
||||
|
||||
# Status: works
|
||||
def _get_index_name(self, dimension: int) -> str:
|
||||
return f"{self.index_prefix}_d{str(dimension)}"
|
||||
return f'{self.index_prefix}_d{str(dimension)}'
|
||||
|
||||
# Status: works
|
||||
def _scan_result_to_get_result(self, result) -> GetResult:
|
||||
@@ -62,24 +62,24 @@ class ElasticsearchClient(VectorDBBase):
|
||||
metadatas = []
|
||||
|
||||
for hit in result:
|
||||
ids.append(hit["_id"])
|
||||
documents.append(hit["_source"].get("text"))
|
||||
metadatas.append(hit["_source"].get("metadata"))
|
||||
ids.append(hit['_id'])
|
||||
documents.append(hit['_source'].get('text'))
|
||||
metadatas.append(hit['_source'].get('metadata'))
|
||||
|
||||
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
||||
|
||||
# Status: works
|
||||
def _result_to_get_result(self, result) -> GetResult:
|
||||
if not result["hits"]["hits"]:
|
||||
if not result['hits']['hits']:
|
||||
return None
|
||||
ids = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for hit in result["hits"]["hits"]:
|
||||
ids.append(hit["_id"])
|
||||
documents.append(hit["_source"].get("text"))
|
||||
metadatas.append(hit["_source"].get("metadata"))
|
||||
for hit in result['hits']['hits']:
|
||||
ids.append(hit['_id'])
|
||||
documents.append(hit['_source'].get('text'))
|
||||
metadatas.append(hit['_source'].get('metadata'))
|
||||
|
||||
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
||||
|
||||
@@ -90,11 +90,11 @@ class ElasticsearchClient(VectorDBBase):
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for hit in result["hits"]["hits"]:
|
||||
ids.append(hit["_id"])
|
||||
distances.append(hit["_score"])
|
||||
documents.append(hit["_source"].get("text"))
|
||||
metadatas.append(hit["_source"].get("metadata"))
|
||||
for hit in result['hits']['hits']:
|
||||
ids.append(hit['_id'])
|
||||
distances.append(hit['_score'])
|
||||
documents.append(hit['_source'].get('text'))
|
||||
metadatas.append(hit['_source'].get('metadata'))
|
||||
|
||||
return SearchResult(
|
||||
ids=[ids],
|
||||
@@ -106,26 +106,26 @@ class ElasticsearchClient(VectorDBBase):
|
||||
# Status: works
|
||||
def _create_index(self, dimension: int):
|
||||
body = {
|
||||
"mappings": {
|
||||
"dynamic_templates": [
|
||||
'mappings': {
|
||||
'dynamic_templates': [
|
||||
{
|
||||
"strings": {
|
||||
"match_mapping_type": "string",
|
||||
"mapping": {"type": "keyword"},
|
||||
'strings': {
|
||||
'match_mapping_type': 'string',
|
||||
'mapping': {'type': 'keyword'},
|
||||
}
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"collection": {"type": "keyword"},
|
||||
"id": {"type": "keyword"},
|
||||
"vector": {
|
||||
"type": "dense_vector",
|
||||
"dims": dimension, # Adjust based on your vector dimensions
|
||||
"index": True,
|
||||
"similarity": "cosine",
|
||||
'properties': {
|
||||
'collection': {'type': 'keyword'},
|
||||
'id': {'type': 'keyword'},
|
||||
'vector': {
|
||||
'type': 'dense_vector',
|
||||
'dims': dimension, # Adjust based on your vector dimensions
|
||||
'index': True,
|
||||
'similarity': 'cosine',
|
||||
},
|
||||
"text": {"type": "text"},
|
||||
"metadata": {"type": "object"},
|
||||
'text': {'type': 'text'},
|
||||
'metadata': {'type': 'object'},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -139,21 +139,19 @@ class ElasticsearchClient(VectorDBBase):
|
||||
|
||||
# Status: works
|
||||
def has_collection(self, collection_name) -> bool:
|
||||
query_body = {"query": {"bool": {"filter": []}}}
|
||||
query_body["query"]["bool"]["filter"].append(
|
||||
{"term": {"collection": collection_name}}
|
||||
)
|
||||
query_body = {'query': {'bool': {'filter': []}}}
|
||||
query_body['query']['bool']['filter'].append({'term': {'collection': collection_name}})
|
||||
|
||||
try:
|
||||
result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
|
||||
result = self.client.count(index=f'{self.index_prefix}*', body=query_body)
|
||||
|
||||
return result.body["count"] > 0
|
||||
return result.body['count'] > 0
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def delete_collection(self, collection_name: str):
|
||||
query = {"query": {"term": {"collection": collection_name}}}
|
||||
self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
|
||||
query = {'query': {'term': {'collection': collection_name}}}
|
||||
self.client.delete_by_query(index=f'{self.index_prefix}*', body=query)
|
||||
|
||||
# Status: works
|
||||
def search(
|
||||
@@ -164,51 +162,41 @@ class ElasticsearchClient(VectorDBBase):
|
||||
limit: int = 10,
|
||||
) -> Optional[SearchResult]:
|
||||
query = {
|
||||
"size": limit,
|
||||
"_source": ["text", "metadata"],
|
||||
"query": {
|
||||
"script_score": {
|
||||
"query": {
|
||||
"bool": {"filter": [{"term": {"collection": collection_name}}]}
|
||||
},
|
||||
"script": {
|
||||
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
|
||||
"params": {
|
||||
"vector": vectors[0]
|
||||
}, # Assuming single query vector
|
||||
'size': limit,
|
||||
'_source': ['text', 'metadata'],
|
||||
'query': {
|
||||
'script_score': {
|
||||
'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}},
|
||||
'script': {
|
||||
'source': "cosineSimilarity(params.vector, 'vector') + 1.0",
|
||||
'params': {'vector': vectors[0]}, # Assuming single query vector
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result = self.client.search(
|
||||
index=self._get_index_name(len(vectors[0])), body=query
|
||||
)
|
||||
result = self.client.search(index=self._get_index_name(len(vectors[0])), body=query)
|
||||
|
||||
return self._result_to_search_result(result)
|
||||
|
||||
# Status: only tested halfwat
|
||||
def query(
|
||||
self, collection_name: str, filter: dict, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
if not self.has_collection(collection_name):
|
||||
return None
|
||||
|
||||
query_body = {
|
||||
"query": {"bool": {"filter": []}},
|
||||
"_source": ["text", "metadata"],
|
||||
'query': {'bool': {'filter': []}},
|
||||
'_source': ['text', 'metadata'],
|
||||
}
|
||||
|
||||
for field, value in filter.items():
|
||||
query_body["query"]["bool"]["filter"].append({"term": {field: value}})
|
||||
query_body["query"]["bool"]["filter"].append(
|
||||
{"term": {"collection": collection_name}}
|
||||
)
|
||||
query_body['query']['bool']['filter'].append({'term': {field: value}})
|
||||
query_body['query']['bool']['filter'].append({'term': {'collection': collection_name}})
|
||||
size = limit if limit else 10
|
||||
|
||||
try:
|
||||
result = self.client.search(
|
||||
index=f"{self.index_prefix}*",
|
||||
index=f'{self.index_prefix}*',
|
||||
body=query_body,
|
||||
size=size,
|
||||
)
|
||||
@@ -220,9 +208,7 @@ class ElasticsearchClient(VectorDBBase):
|
||||
|
||||
# Status: works
|
||||
def _has_index(self, dimension: int):
|
||||
return self.client.indices.exists(
|
||||
index=self._get_index_name(dimension=dimension)
|
||||
)
|
||||
return self.client.indices.exists(index=self._get_index_name(dimension=dimension))
|
||||
|
||||
def get_or_create_index(self, dimension: int):
|
||||
if not self._has_index(dimension=dimension):
|
||||
@@ -232,28 +218,28 @@ class ElasticsearchClient(VectorDBBase):
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
# Get all the items in the collection.
|
||||
query = {
|
||||
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
|
||||
"_source": ["text", "metadata"],
|
||||
'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}},
|
||||
'_source': ['text', 'metadata'],
|
||||
}
|
||||
results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
|
||||
results = list(scan(self.client, index=f'{self.index_prefix}*', query=query))
|
||||
|
||||
return self._scan_result_to_get_result(results)
|
||||
|
||||
# Status: works
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
if not self._has_index(dimension=len(items[0]["vector"])):
|
||||
self._create_index(dimension=len(items[0]["vector"]))
|
||||
if not self._has_index(dimension=len(items[0]['vector'])):
|
||||
self._create_index(dimension=len(items[0]['vector']))
|
||||
|
||||
for batch in self._create_batches(items):
|
||||
actions = [
|
||||
{
|
||||
"_index": self._get_index_name(dimension=len(items[0]["vector"])),
|
||||
"_id": item["id"],
|
||||
"_source": {
|
||||
"collection": collection_name,
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
"metadata": process_metadata(item["metadata"]),
|
||||
'_index': self._get_index_name(dimension=len(items[0]['vector'])),
|
||||
'_id': item['id'],
|
||||
'_source': {
|
||||
'collection': collection_name,
|
||||
'vector': item['vector'],
|
||||
'text': item['text'],
|
||||
'metadata': process_metadata(item['metadata']),
|
||||
},
|
||||
}
|
||||
for item in batch
|
||||
@@ -262,21 +248,21 @@ class ElasticsearchClient(VectorDBBase):
|
||||
|
||||
# Upsert documents using the update API with doc_as_upsert=True.
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
if not self._has_index(dimension=len(items[0]["vector"])):
|
||||
self._create_index(dimension=len(items[0]["vector"]))
|
||||
if not self._has_index(dimension=len(items[0]['vector'])):
|
||||
self._create_index(dimension=len(items[0]['vector']))
|
||||
for batch in self._create_batches(items):
|
||||
actions = [
|
||||
{
|
||||
"_op_type": "update",
|
||||
"_index": self._get_index_name(dimension=len(item["vector"])),
|
||||
"_id": item["id"],
|
||||
"doc": {
|
||||
"collection": collection_name,
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
"metadata": process_metadata(item["metadata"]),
|
||||
'_op_type': 'update',
|
||||
'_index': self._get_index_name(dimension=len(item['vector'])),
|
||||
'_id': item['id'],
|
||||
'doc': {
|
||||
'collection': collection_name,
|
||||
'vector': item['vector'],
|
||||
'text': item['text'],
|
||||
'metadata': process_metadata(item['metadata']),
|
||||
},
|
||||
"doc_as_upsert": True,
|
||||
'doc_as_upsert': True,
|
||||
}
|
||||
for item in batch
|
||||
]
|
||||
@@ -289,22 +275,17 @@ class ElasticsearchClient(VectorDBBase):
|
||||
ids: Optional[list[str]] = None,
|
||||
filter: Optional[dict] = None,
|
||||
):
|
||||
|
||||
query = {
|
||||
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}
|
||||
}
|
||||
query = {'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}}}
|
||||
# logic based on chromaDB
|
||||
if ids:
|
||||
query["query"]["bool"]["filter"].append({"terms": {"_id": ids}})
|
||||
query['query']['bool']['filter'].append({'terms': {'_id': ids}})
|
||||
elif filter:
|
||||
for field, value in filter.items():
|
||||
query["query"]["bool"]["filter"].append(
|
||||
{"term": {f"metadata.{field}": value}}
|
||||
)
|
||||
query['query']['bool']['filter'].append({'term': {f'metadata.{field}': value}})
|
||||
|
||||
self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
|
||||
self.client.delete_by_query(index=f'{self.index_prefix}*', body=query)
|
||||
|
||||
def reset(self):
|
||||
indices = self.client.indices.get(index=f"{self.index_prefix}*")
|
||||
indices = self.client.indices.get(index=f'{self.index_prefix}*')
|
||||
for index in indices:
|
||||
self.client.indices.delete(index=index)
|
||||
|
||||
@@ -47,8 +47,8 @@ def _embedding_to_f32_bytes(vec: List[float]) -> bytes:
|
||||
byte sequence. We use array('f') to avoid a numpy dependency and byteswap on
|
||||
big-endian platforms for portability.
|
||||
"""
|
||||
a = array.array("f", [float(x) for x in vec]) # float32
|
||||
if sys.byteorder != "little":
|
||||
a = array.array('f', [float(x) for x in vec]) # float32
|
||||
if sys.byteorder != 'little':
|
||||
a.byteswap()
|
||||
return a.tobytes()
|
||||
|
||||
@@ -68,7 +68,7 @@ def _safe_json(v: Any) -> Dict[str, Any]:
|
||||
return v
|
||||
if isinstance(v, (bytes, bytearray)):
|
||||
try:
|
||||
v = v.decode("utf-8")
|
||||
v = v.decode('utf-8')
|
||||
except Exception:
|
||||
return {}
|
||||
if isinstance(v, str):
|
||||
@@ -105,16 +105,16 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
"""
|
||||
self.db_url = (db_url or MARIADB_VECTOR_DB_URL).strip()
|
||||
self.vector_length = int(vector_length)
|
||||
self.distance_strategy = (distance_strategy or "cosine").strip().lower()
|
||||
self.distance_strategy = (distance_strategy or 'cosine').strip().lower()
|
||||
self.index_m = int(index_m)
|
||||
|
||||
if self.distance_strategy not in {"cosine", "euclidean"}:
|
||||
if self.distance_strategy not in {'cosine', 'euclidean'}:
|
||||
raise ValueError("distance_strategy must be 'cosine' or 'euclidean'")
|
||||
|
||||
if not self.db_url.lower().startswith("mariadb+mariadbconnector://"):
|
||||
if not self.db_url.lower().startswith('mariadb+mariadbconnector://'):
|
||||
raise ValueError(
|
||||
"MariaDBVectorClient requires mariadb+mariadbconnector:// (official MariaDB driver) "
|
||||
"to ensure qmark paramstyle and correct VECTOR binding."
|
||||
'MariaDBVectorClient requires mariadb+mariadbconnector:// (official MariaDB driver) '
|
||||
'to ensure qmark paramstyle and correct VECTOR binding.'
|
||||
)
|
||||
|
||||
if isinstance(MARIADB_VECTOR_POOL_SIZE, int):
|
||||
@@ -129,9 +129,7 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
poolclass=QueuePool,
|
||||
)
|
||||
else:
|
||||
self.engine = create_engine(
|
||||
self.db_url, pool_pre_ping=True, poolclass=NullPool
|
||||
)
|
||||
self.engine = create_engine(self.db_url, pool_pre_ping=True, poolclass=NullPool)
|
||||
else:
|
||||
self.engine = create_engine(self.db_url, pool_pre_ping=True)
|
||||
self._init_schema()
|
||||
@@ -185,7 +183,7 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
log.exception(f"Error during database initialization: {e}")
|
||||
log.exception(f'Error during database initialization: {e}')
|
||||
raise
|
||||
|
||||
def _check_vector_length(self) -> None:
|
||||
@@ -197,19 +195,19 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
"""
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SHOW CREATE TABLE document_chunk")
|
||||
cur.execute('SHOW CREATE TABLE document_chunk')
|
||||
row = cur.fetchone()
|
||||
if not row or len(row) < 2:
|
||||
return
|
||||
ddl = row[1]
|
||||
m = re.search(r"vector\\((\\d+)\\)", ddl, flags=re.IGNORECASE)
|
||||
m = re.search(r'vector\\((\\d+)\\)', ddl, flags=re.IGNORECASE)
|
||||
if not m:
|
||||
return
|
||||
existing = int(m.group(1))
|
||||
if existing != int(self.vector_length):
|
||||
raise Exception(
|
||||
f"VECTOR_LENGTH {self.vector_length} does not match existing vector column dimension {existing}. "
|
||||
"Cannot change vector size after initialization without migrating the data."
|
||||
f'VECTOR_LENGTH {self.vector_length} does not match existing vector column dimension {existing}. '
|
||||
'Cannot change vector size after initialization without migrating the data.'
|
||||
)
|
||||
|
||||
def adjust_vector_length(self, vector: List[float]) -> List[float]:
|
||||
@@ -227,11 +225,7 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
"""
|
||||
Return the MariaDB Vector distance function name for the configured strategy.
|
||||
"""
|
||||
return (
|
||||
"vec_distance_cosine"
|
||||
if self.distance_strategy == "cosine"
|
||||
else "vec_distance_euclidean"
|
||||
)
|
||||
return 'vec_distance_cosine' if self.distance_strategy == 'cosine' else 'vec_distance_euclidean'
|
||||
|
||||
def _score_from_dist(self, dist: float) -> float:
|
||||
"""
|
||||
@@ -240,7 +234,7 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
- cosine: score ~= 1 - cosine_distance, clamped to [0, 1]
|
||||
- euclidean: score = 1 / (1 + dist)
|
||||
"""
|
||||
if self.distance_strategy == "cosine":
|
||||
if self.distance_strategy == 'cosine':
|
||||
score = 1.0 - dist
|
||||
if score < 0.0:
|
||||
score = 0.0
|
||||
@@ -260,48 +254,48 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
- {"$or": [ ... ]}
|
||||
"""
|
||||
if not expr or not isinstance(expr, dict):
|
||||
return "", []
|
||||
return '', []
|
||||
|
||||
if "$and" in expr:
|
||||
if '$and' in expr:
|
||||
parts: List[str] = []
|
||||
params: List[Any] = []
|
||||
for e in expr.get("$and") or []:
|
||||
for e in expr.get('$and') or []:
|
||||
s, p = self._build_filter_sql_qmark(e)
|
||||
if s:
|
||||
parts.append(s)
|
||||
params.extend(p)
|
||||
return ("(" + " AND ".join(parts) + ")") if parts else "", params
|
||||
return ('(' + ' AND '.join(parts) + ')') if parts else '', params
|
||||
|
||||
if "$or" in expr:
|
||||
if '$or' in expr:
|
||||
parts: List[str] = []
|
||||
params: List[Any] = []
|
||||
for e in expr.get("$or") or []:
|
||||
for e in expr.get('$or') or []:
|
||||
s, p = self._build_filter_sql_qmark(e)
|
||||
if s:
|
||||
parts.append(s)
|
||||
params.extend(p)
|
||||
return ("(" + " OR ".join(parts) + ")") if parts else "", params
|
||||
return ('(' + ' OR '.join(parts) + ')') if parts else '', params
|
||||
|
||||
clauses: List[str] = []
|
||||
params: List[Any] = []
|
||||
for key, value in expr.items():
|
||||
if key.startswith("$"):
|
||||
if key.startswith('$'):
|
||||
continue
|
||||
json_expr = f"JSON_UNQUOTE(JSON_EXTRACT(vmetadata, '$.{key}'))"
|
||||
if isinstance(value, dict) and "$in" in value:
|
||||
vals = [str(v) for v in (value.get("$in") or [])]
|
||||
if isinstance(value, dict) and '$in' in value:
|
||||
vals = [str(v) for v in (value.get('$in') or [])]
|
||||
if not vals:
|
||||
clauses.append("0=1")
|
||||
clauses.append('0=1')
|
||||
continue
|
||||
ors = []
|
||||
for v in vals:
|
||||
ors.append(f"{json_expr} = ?")
|
||||
ors.append(f'{json_expr} = ?')
|
||||
params.append(v)
|
||||
clauses.append("(" + " OR ".join(ors) + ")")
|
||||
clauses.append('(' + ' OR '.join(ors) + ')')
|
||||
else:
|
||||
clauses.append(f"{json_expr} = ?")
|
||||
clauses.append(f'{json_expr} = ?')
|
||||
params.append(str(value))
|
||||
return ("(" + " AND ".join(clauses) + ")") if clauses else "", params
|
||||
return ('(' + ' AND '.join(clauses) + ')') if clauses else '', params
|
||||
|
||||
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
"""
|
||||
@@ -322,15 +316,15 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
"""
|
||||
params: List[Tuple[Any, ...]] = []
|
||||
for item in items:
|
||||
v = self.adjust_vector_length(item["vector"])
|
||||
v = self.adjust_vector_length(item['vector'])
|
||||
emb = _embedding_to_f32_bytes(v)
|
||||
meta = process_metadata(item.get("metadata") or {})
|
||||
meta = process_metadata(item.get('metadata') or {})
|
||||
params.append(
|
||||
(
|
||||
item["id"],
|
||||
item['id'],
|
||||
emb,
|
||||
collection_name,
|
||||
item.get("text"),
|
||||
item.get('text'),
|
||||
json.dumps(meta),
|
||||
)
|
||||
)
|
||||
@@ -338,7 +332,7 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
log.exception(f"Error during insert: {e}")
|
||||
log.exception(f'Error during insert: {e}')
|
||||
raise
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
@@ -365,15 +359,15 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
"""
|
||||
params: List[Tuple[Any, ...]] = []
|
||||
for item in items:
|
||||
v = self.adjust_vector_length(item["vector"])
|
||||
v = self.adjust_vector_length(item['vector'])
|
||||
emb = _embedding_to_f32_bytes(v)
|
||||
meta = process_metadata(item.get("metadata") or {})
|
||||
meta = process_metadata(item.get('metadata') or {})
|
||||
params.append(
|
||||
(
|
||||
item["id"],
|
||||
item['id'],
|
||||
emb,
|
||||
collection_name,
|
||||
item.get("text"),
|
||||
item.get('text'),
|
||||
json.dumps(meta),
|
||||
)
|
||||
)
|
||||
@@ -381,7 +375,7 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
log.exception(f"Error during upsert: {e}")
|
||||
log.exception(f'Error during upsert: {e}')
|
||||
raise
|
||||
|
||||
def search(
|
||||
@@ -415,10 +409,10 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
fsql, fparams = self._build_filter_sql_qmark(filter or {})
|
||||
where = "collection_name = ?"
|
||||
where = 'collection_name = ?'
|
||||
base_params: List[Any] = [collection_name]
|
||||
if fsql:
|
||||
where = where + " AND " + fsql
|
||||
where = where + ' AND ' + fsql
|
||||
base_params.extend(fparams)
|
||||
|
||||
sql = f"""
|
||||
@@ -460,26 +454,24 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
metadatas=metadatas,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"[MARIADB_VECTOR] search() failed: {e}")
|
||||
log.exception(f'[MARIADB_VECTOR] search() failed: {e}')
|
||||
return None
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
"""
|
||||
Retrieve documents by metadata filter (non-vector query).
|
||||
"""
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
fsql, fparams = self._build_filter_sql_qmark(filter or {})
|
||||
where = "collection_name = ?"
|
||||
where = 'collection_name = ?'
|
||||
params: List[Any] = [collection_name]
|
||||
if fsql:
|
||||
where = where + " AND " + fsql
|
||||
where = where + ' AND ' + fsql
|
||||
params.extend(fparams)
|
||||
sql = f"SELECT id, text, vmetadata FROM document_chunk WHERE {where}"
|
||||
sql = f'SELECT id, text, vmetadata FROM document_chunk WHERE {where}'
|
||||
if limit is not None:
|
||||
sql += " LIMIT ?"
|
||||
sql += ' LIMIT ?'
|
||||
params.append(int(limit))
|
||||
cur.execute(sql, params)
|
||||
rows = cur.fetchall()
|
||||
@@ -490,18 +482,16 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
metadatas = [[_safe_json(r[2]) for r in rows]]
|
||||
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
||||
|
||||
def get(
|
||||
self, collection_name: str, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def get(self, collection_name: str, limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
"""
|
||||
Retrieve documents in a collection without filtering (optionally limited).
|
||||
"""
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
sql = "SELECT id, text, vmetadata FROM document_chunk WHERE collection_name = ?"
|
||||
sql = 'SELECT id, text, vmetadata FROM document_chunk WHERE collection_name = ?'
|
||||
params: List[Any] = [collection_name]
|
||||
if limit is not None:
|
||||
sql += " LIMIT ?"
|
||||
sql += ' LIMIT ?'
|
||||
params.append(int(limit))
|
||||
cur.execute(sql, params)
|
||||
rows = cur.fetchall()
|
||||
@@ -526,12 +516,12 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
try:
|
||||
where = ["collection_name = ?"]
|
||||
where = ['collection_name = ?']
|
||||
params: List[Any] = [collection_name]
|
||||
|
||||
if ids:
|
||||
ph = ", ".join(["?"] * len(ids))
|
||||
where.append(f"id IN ({ph})")
|
||||
ph = ', '.join(['?'] * len(ids))
|
||||
where.append(f'id IN ({ph})')
|
||||
params.extend(ids)
|
||||
|
||||
if filter:
|
||||
@@ -540,12 +530,12 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
where.append(fsql)
|
||||
params.extend(fparams)
|
||||
|
||||
sql = "DELETE FROM document_chunk WHERE " + " AND ".join(where)
|
||||
sql = 'DELETE FROM document_chunk WHERE ' + ' AND '.join(where)
|
||||
cur.execute(sql, params)
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
log.exception(f"Error during delete: {e}")
|
||||
log.exception(f'Error during delete: {e}')
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
@@ -555,11 +545,11 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
try:
|
||||
cur.execute("TRUNCATE TABLE document_chunk")
|
||||
cur.execute('TRUNCATE TABLE document_chunk')
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
log.exception(f"Error during reset: {e}")
|
||||
log.exception(f'Error during reset: {e}')
|
||||
raise
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
@@ -570,7 +560,7 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT 1 FROM document_chunk WHERE collection_name = ? LIMIT 1",
|
||||
'SELECT 1 FROM document_chunk WHERE collection_name = ? LIMIT 1',
|
||||
(collection_name,),
|
||||
)
|
||||
return cur.fetchone() is not None
|
||||
@@ -590,4 +580,4 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
try:
|
||||
self.engine.dispose()
|
||||
except Exception as e:
|
||||
log.exception(f"Error during dispose the underlying SQLAlchemy engine: {e}")
|
||||
log.exception(f'Error during dispose the underlying SQLAlchemy engine: {e}')
|
||||
|
||||
@@ -35,7 +35,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
class MilvusClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
self.collection_prefix = "open_webui"
|
||||
self.collection_prefix = 'open_webui'
|
||||
if MILVUS_TOKEN is None:
|
||||
self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
|
||||
else:
|
||||
@@ -50,17 +50,17 @@ class MilvusClient(VectorDBBase):
|
||||
_documents = []
|
||||
_metadatas = []
|
||||
for item in match:
|
||||
_ids.append(item.get("id"))
|
||||
_documents.append(item.get("data", {}).get("text"))
|
||||
_metadatas.append(item.get("metadata"))
|
||||
_ids.append(item.get('id'))
|
||||
_documents.append(item.get('data', {}).get('text'))
|
||||
_metadatas.append(item.get('metadata'))
|
||||
ids.append(_ids)
|
||||
documents.append(_documents)
|
||||
metadatas.append(_metadatas)
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": ids,
|
||||
"documents": documents,
|
||||
"metadatas": metadatas,
|
||||
'ids': ids,
|
||||
'documents': documents,
|
||||
'metadatas': metadatas,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -75,23 +75,23 @@ class MilvusClient(VectorDBBase):
|
||||
_documents = []
|
||||
_metadatas = []
|
||||
for item in match:
|
||||
_ids.append(item.get("id"))
|
||||
_ids.append(item.get('id'))
|
||||
# normalize milvus score from [-1, 1] to [0, 1] range
|
||||
# https://milvus.io/docs/de/metric.md
|
||||
_dist = (item.get("distance") + 1.0) / 2.0
|
||||
_dist = (item.get('distance') + 1.0) / 2.0
|
||||
_distances.append(_dist)
|
||||
_documents.append(item.get("entity", {}).get("data", {}).get("text"))
|
||||
_metadatas.append(item.get("entity", {}).get("metadata"))
|
||||
_documents.append(item.get('entity', {}).get('data', {}).get('text'))
|
||||
_metadatas.append(item.get('entity', {}).get('metadata'))
|
||||
ids.append(_ids)
|
||||
distances.append(_distances)
|
||||
documents.append(_documents)
|
||||
metadatas.append(_metadatas)
|
||||
return SearchResult(
|
||||
**{
|
||||
"ids": ids,
|
||||
"distances": distances,
|
||||
"documents": documents,
|
||||
"metadatas": metadatas,
|
||||
'ids': ids,
|
||||
'distances': distances,
|
||||
'documents': documents,
|
||||
'metadatas': metadatas,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -101,21 +101,19 @@ class MilvusClient(VectorDBBase):
|
||||
enable_dynamic_field=True,
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="id",
|
||||
field_name='id',
|
||||
datatype=DataType.VARCHAR,
|
||||
is_primary=True,
|
||||
max_length=65535,
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="vector",
|
||||
field_name='vector',
|
||||
datatype=DataType.FLOAT_VECTOR,
|
||||
dim=dimension,
|
||||
description="vector",
|
||||
)
|
||||
schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
|
||||
schema.add_field(
|
||||
field_name="metadata", datatype=DataType.JSON, description="metadata"
|
||||
description='vector',
|
||||
)
|
||||
schema.add_field(field_name='data', datatype=DataType.JSON, description='data')
|
||||
schema.add_field(field_name='metadata', datatype=DataType.JSON, description='metadata')
|
||||
|
||||
index_params = self.client.prepare_index_params()
|
||||
|
||||
@@ -123,44 +121,44 @@ class MilvusClient(VectorDBBase):
|
||||
index_type = MILVUS_INDEX_TYPE.upper()
|
||||
metric_type = MILVUS_METRIC_TYPE.upper()
|
||||
|
||||
log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}")
|
||||
log.info(f'Using Milvus index type: {index_type}, metric type: {metric_type}')
|
||||
|
||||
index_creation_params = {}
|
||||
if index_type == "HNSW":
|
||||
if index_type == 'HNSW':
|
||||
index_creation_params = {
|
||||
"M": MILVUS_HNSW_M,
|
||||
"efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
|
||||
'M': MILVUS_HNSW_M,
|
||||
'efConstruction': MILVUS_HNSW_EFCONSTRUCTION,
|
||||
}
|
||||
log.info(f"HNSW params: {index_creation_params}")
|
||||
elif index_type == "IVF_FLAT":
|
||||
index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
|
||||
log.info(f"IVF_FLAT params: {index_creation_params}")
|
||||
elif index_type == "DISKANN":
|
||||
log.info(f'HNSW params: {index_creation_params}')
|
||||
elif index_type == 'IVF_FLAT':
|
||||
index_creation_params = {'nlist': MILVUS_IVF_FLAT_NLIST}
|
||||
log.info(f'IVF_FLAT params: {index_creation_params}')
|
||||
elif index_type == 'DISKANN':
|
||||
index_creation_params = {
|
||||
"max_degree": MILVUS_DISKANN_MAX_DEGREE,
|
||||
"search_list_size": MILVUS_DISKANN_SEARCH_LIST_SIZE,
|
||||
'max_degree': MILVUS_DISKANN_MAX_DEGREE,
|
||||
'search_list_size': MILVUS_DISKANN_SEARCH_LIST_SIZE,
|
||||
}
|
||||
log.info(f"DISKANN params: {index_creation_params}")
|
||||
elif index_type in ["FLAT", "AUTOINDEX"]:
|
||||
log.info(f"Using {index_type} index with no specific build-time params.")
|
||||
log.info(f'DISKANN params: {index_creation_params}')
|
||||
elif index_type in ['FLAT', 'AUTOINDEX']:
|
||||
log.info(f'Using {index_type} index with no specific build-time params.')
|
||||
else:
|
||||
log.warning(
|
||||
f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
|
||||
f"Supported types: HNSW, IVF_FLAT, DISKANN, FLAT, AUTOINDEX. "
|
||||
f"Milvus will use its default for the collection if this type is not directly supported for index creation."
|
||||
f'Supported types: HNSW, IVF_FLAT, DISKANN, FLAT, AUTOINDEX. '
|
||||
f'Milvus will use its default for the collection if this type is not directly supported for index creation.'
|
||||
)
|
||||
# For unsupported types, pass the type directly to Milvus; it might handle it or use a default.
|
||||
# If Milvus errors out, the user needs to correct the MILVUS_INDEX_TYPE env var.
|
||||
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
field_name='vector',
|
||||
index_type=index_type,
|
||||
metric_type=metric_type,
|
||||
params=index_creation_params,
|
||||
)
|
||||
|
||||
self.client.create_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||
schema=schema,
|
||||
index_params=index_params,
|
||||
)
|
||||
@@ -170,17 +168,13 @@ class MilvusClient(VectorDBBase):
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
# Check if the collection exists based on the collection name.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
return self.client.has_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
)
|
||||
collection_name = collection_name.replace('-', '_')
|
||||
return self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}')
|
||||
|
||||
def delete_collection(self, collection_name: str):
|
||||
# Delete the collection based on the collection name.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
return self.client.drop_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
)
|
||||
collection_name = collection_name.replace('-', '_')
|
||||
return self.client.drop_collection(collection_name=f'{self.collection_prefix}_{collection_name}')
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -190,15 +184,15 @@ class MilvusClient(VectorDBBase):
|
||||
limit: int = 10,
|
||||
) -> Optional[SearchResult]:
|
||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
collection_name = collection_name.replace('-', '_')
|
||||
# For some index types like IVF_FLAT, search params like nprobe can be set.
|
||||
# Example: search_params = {"nprobe": 10} if using IVF_FLAT
|
||||
# For simplicity, not adding configurable search_params here, but could be extended.
|
||||
result = self.client.search(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||
data=vectors,
|
||||
limit=limit,
|
||||
output_fields=["data", "metadata"],
|
||||
output_fields=['data', 'metadata'],
|
||||
# search_params=search_params # Potentially add later if needed
|
||||
)
|
||||
return self._result_to_search_result(result)
|
||||
@@ -206,11 +200,9 @@ class MilvusClient(VectorDBBase):
|
||||
def query(self, collection_name: str, filter: dict, limit: int = -1):
|
||||
connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB)
|
||||
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
collection_name = collection_name.replace('-', '_')
|
||||
if not self.has_collection(collection_name):
|
||||
log.warning(
|
||||
f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
|
||||
)
|
||||
log.warning(f'Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}')
|
||||
return None
|
||||
|
||||
filter_expressions = []
|
||||
@@ -220,9 +212,9 @@ class MilvusClient(VectorDBBase):
|
||||
else:
|
||||
filter_expressions.append(f'metadata["{key}"] == {value}')
|
||||
|
||||
filter_string = " && ".join(filter_expressions)
|
||||
filter_string = ' && '.join(filter_expressions)
|
||||
|
||||
collection = Collection(f"{self.collection_prefix}_{collection_name}")
|
||||
collection = Collection(f'{self.collection_prefix}_{collection_name}')
|
||||
collection.load()
|
||||
|
||||
try:
|
||||
@@ -233,9 +225,9 @@ class MilvusClient(VectorDBBase):
|
||||
iterator = collection.query_iterator(
|
||||
expr=filter_string,
|
||||
output_fields=[
|
||||
"id",
|
||||
"data",
|
||||
"metadata",
|
||||
'id',
|
||||
'data',
|
||||
'metadata',
|
||||
],
|
||||
limit=limit if limit > 0 else -1,
|
||||
)
|
||||
@@ -248,7 +240,7 @@ class MilvusClient(VectorDBBase):
|
||||
break
|
||||
all_results.extend(batch)
|
||||
|
||||
log.debug(f"Total results from query: {len(all_results)}")
|
||||
log.debug(f'Total results from query: {len(all_results)}')
|
||||
return self._result_to_get_result([all_results] if all_results else [[]])
|
||||
|
||||
except Exception as e:
|
||||
@@ -259,7 +251,7 @@ class MilvusClient(VectorDBBase):
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
# Get all the items in the collection. This can be very resource-intensive for large collections.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
collection_name = collection_name.replace('-', '_')
|
||||
log.warning(
|
||||
f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections."
|
||||
)
|
||||
@@ -269,35 +261,25 @@ class MilvusClient(VectorDBBase):
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Insert the items into the collection, if the collection does not exist, it will be created.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
if not self.client.has_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
):
|
||||
log.info(
|
||||
f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now."
|
||||
)
|
||||
collection_name = collection_name.replace('-', '_')
|
||||
if not self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}'):
|
||||
log.info(f'Collection {self.collection_prefix}_{collection_name} does not exist. Creating now.')
|
||||
if not items:
|
||||
log.error(
|
||||
f"Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension."
|
||||
f'Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension.'
|
||||
)
|
||||
raise ValueError(
|
||||
"Cannot create Milvus collection without items to determine vector dimension."
|
||||
)
|
||||
self._create_collection(
|
||||
collection_name=collection_name, dimension=len(items[0]["vector"])
|
||||
)
|
||||
raise ValueError('Cannot create Milvus collection without items to determine vector dimension.')
|
||||
self._create_collection(collection_name=collection_name, dimension=len(items[0]['vector']))
|
||||
|
||||
log.info(
|
||||
f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
|
||||
)
|
||||
log.info(f'Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.')
|
||||
return self.client.insert(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||
data=[
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": item["vector"],
|
||||
"data": {"text": item["text"]},
|
||||
"metadata": process_metadata(item["metadata"]),
|
||||
'id': item['id'],
|
||||
'vector': item['vector'],
|
||||
'data': {'text': item['text']},
|
||||
'metadata': process_metadata(item['metadata']),
|
||||
}
|
||||
for item in items
|
||||
],
|
||||
@@ -305,35 +287,27 @@ class MilvusClient(VectorDBBase):
|
||||
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
if not self.client.has_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
):
|
||||
log.info(
|
||||
f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now."
|
||||
)
|
||||
collection_name = collection_name.replace('-', '_')
|
||||
if not self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}'):
|
||||
log.info(f'Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now.')
|
||||
if not items:
|
||||
log.error(
|
||||
f"Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension."
|
||||
f'Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension.'
|
||||
)
|
||||
raise ValueError(
|
||||
"Cannot create Milvus collection for upsert without items to determine vector dimension."
|
||||
'Cannot create Milvus collection for upsert without items to determine vector dimension.'
|
||||
)
|
||||
self._create_collection(
|
||||
collection_name=collection_name, dimension=len(items[0]["vector"])
|
||||
)
|
||||
self._create_collection(collection_name=collection_name, dimension=len(items[0]['vector']))
|
||||
|
||||
log.info(
|
||||
f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
|
||||
)
|
||||
log.info(f'Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.')
|
||||
return self.client.upsert(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||
data=[
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": item["vector"],
|
||||
"data": {"text": item["text"]},
|
||||
"metadata": process_metadata(item["metadata"]),
|
||||
'id': item['id'],
|
||||
'vector': item['vector'],
|
||||
'data': {'text': item['text']},
|
||||
'metadata': process_metadata(item['metadata']),
|
||||
}
|
||||
for item in items
|
||||
],
|
||||
@@ -346,46 +320,35 @@ class MilvusClient(VectorDBBase):
|
||||
filter: Optional[dict] = None,
|
||||
):
|
||||
# Delete the items from the collection based on the ids or filter.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
collection_name = collection_name.replace('-', '_')
|
||||
if not self.has_collection(collection_name):
|
||||
log.warning(
|
||||
f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
|
||||
)
|
||||
log.warning(f'Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}')
|
||||
return None
|
||||
|
||||
if ids:
|
||||
log.info(
|
||||
f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}"
|
||||
)
|
||||
log.info(f'Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}')
|
||||
return self.client.delete(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||
ids=ids,
|
||||
)
|
||||
elif filter:
|
||||
filter_string = " && ".join(
|
||||
[
|
||||
f'metadata["{key}"] == {json.dumps(value)}'
|
||||
for key, value in filter.items()
|
||||
]
|
||||
)
|
||||
filter_string = ' && '.join([f'metadata["{key}"] == {json.dumps(value)}' for key, value in filter.items()])
|
||||
log.info(
|
||||
f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}"
|
||||
f'Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}'
|
||||
)
|
||||
return self.client.delete(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||
filter=filter_string,
|
||||
)
|
||||
else:
|
||||
log.warning(
|
||||
f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken."
|
||||
f'Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken.'
|
||||
)
|
||||
return None
|
||||
|
||||
def reset(self):
|
||||
# Resets the database. This will delete all collections and item entries that match the prefix.
|
||||
log.warning(
|
||||
f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'."
|
||||
)
|
||||
log.warning(f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'.")
|
||||
collection_names = self.client.list_collections()
|
||||
deleted_collections = []
|
||||
for collection_name_full in collection_names:
|
||||
@@ -393,7 +356,7 @@ class MilvusClient(VectorDBBase):
|
||||
try:
|
||||
self.client.drop_collection(collection_name=collection_name_full)
|
||||
deleted_collections.append(collection_name_full)
|
||||
log.info(f"Deleted collection: {collection_name_full}")
|
||||
log.info(f'Deleted collection: {collection_name_full}')
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting collection {collection_name_full}: {e}")
|
||||
log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")
|
||||
log.error(f'Error deleting collection {collection_name_full}: {e}')
|
||||
log.info(f'Milvus reset complete. Deleted collections: {deleted_collections}')
|
||||
|
||||
@@ -33,26 +33,26 @@ from pymilvus import (
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
RESOURCE_ID_FIELD = "resource_id"
|
||||
RESOURCE_ID_FIELD = 'resource_id'
|
||||
|
||||
|
||||
class MilvusClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
# Milvus collection names can only contain numbers, letters, and underscores.
|
||||
self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_")
|
||||
self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace('-', '_')
|
||||
connections.connect(
|
||||
alias="default",
|
||||
alias='default',
|
||||
uri=MILVUS_URI,
|
||||
token=MILVUS_TOKEN,
|
||||
db_name=MILVUS_DB,
|
||||
)
|
||||
|
||||
# Main collection types for multi-tenancy
|
||||
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
|
||||
self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
|
||||
self.FILE_COLLECTION = f"{self.collection_prefix}_files"
|
||||
self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search"
|
||||
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based"
|
||||
self.MEMORY_COLLECTION = f'{self.collection_prefix}_memories'
|
||||
self.KNOWLEDGE_COLLECTION = f'{self.collection_prefix}_knowledge'
|
||||
self.FILE_COLLECTION = f'{self.collection_prefix}_files'
|
||||
self.WEB_SEARCH_COLLECTION = f'{self.collection_prefix}_web_search'
|
||||
self.HASH_BASED_COLLECTION = f'{self.collection_prefix}_hash_based'
|
||||
self.shared_collections = [
|
||||
self.MEMORY_COLLECTION,
|
||||
self.KNOWLEDGE_COLLECTION,
|
||||
@@ -74,15 +74,13 @@ class MilvusClient(VectorDBBase):
|
||||
"""
|
||||
resource_id = collection_name
|
||||
|
||||
if collection_name.startswith("user-memory-"):
|
||||
if collection_name.startswith('user-memory-'):
|
||||
return self.MEMORY_COLLECTION, resource_id
|
||||
elif collection_name.startswith("file-"):
|
||||
elif collection_name.startswith('file-'):
|
||||
return self.FILE_COLLECTION, resource_id
|
||||
elif collection_name.startswith("web-search-"):
|
||||
elif collection_name.startswith('web-search-'):
|
||||
return self.WEB_SEARCH_COLLECTION, resource_id
|
||||
elif len(collection_name) == 63 and all(
|
||||
c in "0123456789abcdef" for c in collection_name
|
||||
):
|
||||
elif len(collection_name) == 63 and all(c in '0123456789abcdef' for c in collection_name):
|
||||
return self.HASH_BASED_COLLECTION, resource_id
|
||||
else:
|
||||
return self.KNOWLEDGE_COLLECTION, resource_id
|
||||
@@ -90,36 +88,36 @@ class MilvusClient(VectorDBBase):
|
||||
def _create_shared_collection(self, mt_collection_name: str, dimension: int):
|
||||
fields = [
|
||||
FieldSchema(
|
||||
name="id",
|
||||
name='id',
|
||||
dtype=DataType.VARCHAR,
|
||||
is_primary=True,
|
||||
auto_id=False,
|
||||
max_length=36,
|
||||
),
|
||||
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
|
||||
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="metadata", dtype=DataType.JSON),
|
||||
FieldSchema(name='vector', dtype=DataType.FLOAT_VECTOR, dim=dimension),
|
||||
FieldSchema(name='text', dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name='metadata', dtype=DataType.JSON),
|
||||
FieldSchema(name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255),
|
||||
]
|
||||
schema = CollectionSchema(fields, "Shared collection for multi-tenancy")
|
||||
schema = CollectionSchema(fields, 'Shared collection for multi-tenancy')
|
||||
collection = Collection(mt_collection_name, schema)
|
||||
|
||||
index_params = {
|
||||
"metric_type": MILVUS_METRIC_TYPE,
|
||||
"index_type": MILVUS_INDEX_TYPE,
|
||||
"params": {},
|
||||
'metric_type': MILVUS_METRIC_TYPE,
|
||||
'index_type': MILVUS_INDEX_TYPE,
|
||||
'params': {},
|
||||
}
|
||||
if MILVUS_INDEX_TYPE == "HNSW":
|
||||
index_params["params"] = {
|
||||
"M": MILVUS_HNSW_M,
|
||||
"efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
|
||||
if MILVUS_INDEX_TYPE == 'HNSW':
|
||||
index_params['params'] = {
|
||||
'M': MILVUS_HNSW_M,
|
||||
'efConstruction': MILVUS_HNSW_EFCONSTRUCTION,
|
||||
}
|
||||
elif MILVUS_INDEX_TYPE == "IVF_FLAT":
|
||||
index_params["params"] = {"nlist": MILVUS_IVF_FLAT_NLIST}
|
||||
elif MILVUS_INDEX_TYPE == 'IVF_FLAT':
|
||||
index_params['params'] = {'nlist': MILVUS_IVF_FLAT_NLIST}
|
||||
|
||||
collection.create_index("vector", index_params)
|
||||
collection.create_index('vector', index_params)
|
||||
collection.create_index(RESOURCE_ID_FIELD)
|
||||
log.info(f"Created shared collection: {mt_collection_name}")
|
||||
log.info(f'Created shared collection: {mt_collection_name}')
|
||||
return collection
|
||||
|
||||
def _ensure_collection(self, mt_collection_name: str, dimension: int):
|
||||
@@ -127,9 +125,7 @@ class MilvusClient(VectorDBBase):
|
||||
self._create_shared_collection(mt_collection_name, dimension)
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(
|
||||
collection_name
|
||||
)
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
|
||||
if not utility.has_collection(mt_collection):
|
||||
return False
|
||||
|
||||
@@ -141,19 +137,17 @@ class MilvusClient(VectorDBBase):
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]):
|
||||
if not items:
|
||||
return
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(
|
||||
collection_name
|
||||
)
|
||||
dimension = len(items[0]["vector"])
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
|
||||
dimension = len(items[0]['vector'])
|
||||
self._ensure_collection(mt_collection, dimension)
|
||||
collection = Collection(mt_collection)
|
||||
|
||||
entities = [
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
"metadata": item["metadata"],
|
||||
'id': item['id'],
|
||||
'vector': item['vector'],
|
||||
'text': item['text'],
|
||||
'metadata': item['metadata'],
|
||||
RESOURCE_ID_FIELD: resource_id,
|
||||
}
|
||||
for item in items
|
||||
@@ -170,41 +164,37 @@ class MilvusClient(VectorDBBase):
|
||||
if not vectors:
|
||||
return None
|
||||
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(
|
||||
collection_name
|
||||
)
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
|
||||
if not utility.has_collection(mt_collection):
|
||||
return None
|
||||
|
||||
collection = Collection(mt_collection)
|
||||
collection.load()
|
||||
|
||||
search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}}
|
||||
search_params = {'metric_type': MILVUS_METRIC_TYPE, 'params': {}}
|
||||
results = collection.search(
|
||||
data=vectors,
|
||||
anns_field="vector",
|
||||
anns_field='vector',
|
||||
param=search_params,
|
||||
limit=limit,
|
||||
expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'",
|
||||
output_fields=["id", "text", "metadata"],
|
||||
output_fields=['id', 'text', 'metadata'],
|
||||
)
|
||||
|
||||
ids, documents, metadatas, distances = [], [], [], []
|
||||
for hits in results:
|
||||
batch_ids, batch_docs, batch_metadatas, batch_dists = [], [], [], []
|
||||
for hit in hits:
|
||||
batch_ids.append(hit.entity.get("id"))
|
||||
batch_docs.append(hit.entity.get("text"))
|
||||
batch_metadatas.append(hit.entity.get("metadata"))
|
||||
batch_ids.append(hit.entity.get('id'))
|
||||
batch_docs.append(hit.entity.get('text'))
|
||||
batch_metadatas.append(hit.entity.get('metadata'))
|
||||
batch_dists.append(hit.distance)
|
||||
ids.append(batch_ids)
|
||||
documents.append(batch_docs)
|
||||
metadatas.append(batch_metadatas)
|
||||
distances.append(batch_dists)
|
||||
|
||||
return SearchResult(
|
||||
ids=ids, documents=documents, metadatas=metadatas, distances=distances
|
||||
)
|
||||
return SearchResult(ids=ids, documents=documents, metadatas=metadatas, distances=distances)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
@@ -212,9 +202,7 @@ class MilvusClient(VectorDBBase):
|
||||
ids: Optional[List[str]] = None,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(
|
||||
collection_name
|
||||
)
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
|
||||
if not utility.has_collection(mt_collection):
|
||||
return
|
||||
|
||||
@@ -224,14 +212,14 @@ class MilvusClient(VectorDBBase):
|
||||
expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
|
||||
if ids:
|
||||
# Milvus expects a string list for 'in' operator
|
||||
id_list_str = ", ".join([f"'{id_val}'" for id_val in ids])
|
||||
expr.append(f"id in [{id_list_str}]")
|
||||
id_list_str = ', '.join([f"'{id_val}'" for id_val in ids])
|
||||
expr.append(f'id in [{id_list_str}]')
|
||||
|
||||
if filter:
|
||||
for key, value in filter.items():
|
||||
expr.append(f"metadata['{key}'] == '{value}'")
|
||||
|
||||
collection.delete(" and ".join(expr))
|
||||
collection.delete(' and '.join(expr))
|
||||
|
||||
def reset(self):
|
||||
for collection_name in self.shared_collections:
|
||||
@@ -239,21 +227,15 @@ class MilvusClient(VectorDBBase):
|
||||
utility.drop_collection(collection_name)
|
||||
|
||||
def delete_collection(self, collection_name: str):
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(
|
||||
collection_name
|
||||
)
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
|
||||
if not utility.has_collection(mt_collection):
|
||||
return
|
||||
|
||||
collection = Collection(mt_collection)
|
||||
collection.delete(f"{RESOURCE_ID_FIELD} == '{resource_id}'")
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(
|
||||
collection_name
|
||||
)
|
||||
def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
|
||||
if not utility.has_collection(mt_collection):
|
||||
return None
|
||||
|
||||
@@ -269,8 +251,8 @@ class MilvusClient(VectorDBBase):
|
||||
expr.append(f"metadata['{key}'] == {value}")
|
||||
|
||||
iterator = collection.query_iterator(
|
||||
expr=" and ".join(expr),
|
||||
output_fields=["id", "text", "metadata"],
|
||||
expr=' and '.join(expr),
|
||||
output_fields=['id', 'text', 'metadata'],
|
||||
limit=limit if limit else -1,
|
||||
)
|
||||
|
||||
@@ -282,9 +264,9 @@ class MilvusClient(VectorDBBase):
|
||||
break
|
||||
all_results.extend(batch)
|
||||
|
||||
ids = [res["id"] for res in all_results]
|
||||
documents = [res["text"] for res in all_results]
|
||||
metadatas = [res["metadata"] for res in all_results]
|
||||
ids = [res['id'] for res in all_results]
|
||||
documents = [res['text'] for res in all_results]
|
||||
metadatas = [res['metadata'] for res in all_results]
|
||||
|
||||
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
||||
|
||||
|
||||
@@ -36,17 +36,15 @@ from sqlalchemy.dialects import registry
|
||||
|
||||
|
||||
class OpenGaussDialect(PGDialect_psycopg2):
|
||||
name = "opengauss"
|
||||
name = 'opengauss'
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
try:
|
||||
version = connection.exec_driver_sql("SELECT version()").scalar()
|
||||
version = connection.exec_driver_sql('SELECT version()').scalar()
|
||||
if not version:
|
||||
return (9, 0, 0)
|
||||
|
||||
match = re.search(
|
||||
r"openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?", version, re.IGNORECASE
|
||||
)
|
||||
match = re.search(r'openGauss\s+(\d+)\.(\d+)\.(\d+)(?:-\w+)?', version, re.IGNORECASE)
|
||||
if match:
|
||||
return (int(match.group(1)), int(match.group(2)), int(match.group(3)))
|
||||
|
||||
@@ -56,7 +54,7 @@ class OpenGaussDialect(PGDialect_psycopg2):
|
||||
|
||||
|
||||
# Register dialect
|
||||
registry.register("opengauss", __name__, "OpenGaussDialect")
|
||||
registry.register('opengauss', __name__, 'OpenGaussDialect')
|
||||
|
||||
from open_webui.retrieval.vector.utils import process_metadata
|
||||
from open_webui.retrieval.vector.main import (
|
||||
@@ -80,11 +78,11 @@ VECTOR_LENGTH = OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH
|
||||
Base = declarative_base()
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
log.setLevel(SRC_LOG_LEVELS['RAG'])
|
||||
|
||||
|
||||
class DocumentChunk(Base):
|
||||
__tablename__ = "document_chunk"
|
||||
__tablename__ = 'document_chunk'
|
||||
|
||||
id = Column(Text, primary_key=True)
|
||||
vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True)
|
||||
@@ -100,26 +98,24 @@ class OpenGaussClient(VectorDBBase):
|
||||
|
||||
self.session = ScopedSession
|
||||
else:
|
||||
engine_kwargs = {"pool_pre_ping": True, "dialect": OpenGaussDialect()}
|
||||
engine_kwargs = {'pool_pre_ping': True, 'dialect': OpenGaussDialect()}
|
||||
|
||||
if isinstance(OPENGAUSS_POOL_SIZE, int) and OPENGAUSS_POOL_SIZE > 0:
|
||||
engine_kwargs.update(
|
||||
{
|
||||
"pool_size": OPENGAUSS_POOL_SIZE,
|
||||
"max_overflow": OPENGAUSS_POOL_MAX_OVERFLOW,
|
||||
"pool_timeout": OPENGAUSS_POOL_TIMEOUT,
|
||||
"pool_recycle": OPENGAUSS_POOL_RECYCLE,
|
||||
"poolclass": QueuePool,
|
||||
'pool_size': OPENGAUSS_POOL_SIZE,
|
||||
'max_overflow': OPENGAUSS_POOL_MAX_OVERFLOW,
|
||||
'pool_timeout': OPENGAUSS_POOL_TIMEOUT,
|
||||
'pool_recycle': OPENGAUSS_POOL_RECYCLE,
|
||||
'poolclass': QueuePool,
|
||||
}
|
||||
)
|
||||
else:
|
||||
engine_kwargs["poolclass"] = NullPool
|
||||
engine_kwargs['poolclass'] = NullPool
|
||||
|
||||
engine = create_engine(OPENGAUSS_DB_URL, **engine_kwargs)
|
||||
|
||||
SessionLocal = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
|
||||
)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
|
||||
self.session = scoped_session(SessionLocal)
|
||||
|
||||
try:
|
||||
@@ -128,47 +124,42 @@ class OpenGaussClient(VectorDBBase):
|
||||
|
||||
self.session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_document_chunk_vector "
|
||||
"ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);"
|
||||
'CREATE INDEX IF NOT EXISTS idx_document_chunk_vector '
|
||||
'ON document_chunk USING ivfflat (vector vector_cosine_ops) WITH (lists = 100);'
|
||||
)
|
||||
)
|
||||
self.session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
|
||||
"ON document_chunk (collection_name);"
|
||||
'CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name ON document_chunk (collection_name);'
|
||||
)
|
||||
)
|
||||
self.session.commit()
|
||||
log.info("OpenGauss vector database initialization completed.")
|
||||
log.info('OpenGauss vector database initialization completed.')
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"OpenGauss Initialization failed.: {e}")
|
||||
log.exception(f'OpenGauss Initialization failed.: {e}')
|
||||
raise
|
||||
|
||||
def check_vector_length(self) -> None:
|
||||
metadata = MetaData()
|
||||
try:
|
||||
document_chunk_table = Table(
|
||||
"document_chunk", metadata, autoload_with=self.session.bind
|
||||
)
|
||||
document_chunk_table = Table('document_chunk', metadata, autoload_with=self.session.bind)
|
||||
except NoSuchTableError:
|
||||
return
|
||||
|
||||
if "vector" in document_chunk_table.columns:
|
||||
vector_column = document_chunk_table.columns["vector"]
|
||||
if 'vector' in document_chunk_table.columns:
|
||||
vector_column = document_chunk_table.columns['vector']
|
||||
vector_type = vector_column.type
|
||||
if isinstance(vector_type, Vector):
|
||||
db_vector_length = vector_type.dim
|
||||
if db_vector_length != VECTOR_LENGTH:
|
||||
raise Exception(
|
||||
f"Vector dimension mismatch: configured {VECTOR_LENGTH} vs. {db_vector_length} in the database."
|
||||
f'Vector dimension mismatch: configured {VECTOR_LENGTH} vs. {db_vector_length} in the database.'
|
||||
)
|
||||
else:
|
||||
raise Exception("The 'vector' column type is not Vector.")
|
||||
else:
|
||||
raise Exception(
|
||||
"The 'vector' column does not exist in the 'document_chunk' table."
|
||||
)
|
||||
raise Exception("The 'vector' column does not exist in the 'document_chunk' table.")
|
||||
|
||||
def adjust_vector_length(self, vector: List[float]) -> List[float]:
|
||||
current_length = len(vector)
|
||||
@@ -182,55 +173,47 @@ class OpenGaussClient(VectorDBBase):
|
||||
try:
|
||||
new_items = []
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
vector = self.adjust_vector_length(item['vector'])
|
||||
new_chunk = DocumentChunk(
|
||||
id=item["id"],
|
||||
id=item['id'],
|
||||
vector=vector,
|
||||
collection_name=collection_name,
|
||||
text=item["text"],
|
||||
vmetadata=process_metadata(item["metadata"]),
|
||||
text=item['text'],
|
||||
vmetadata=process_metadata(item['metadata']),
|
||||
)
|
||||
new_items.append(new_chunk)
|
||||
self.session.bulk_save_objects(new_items)
|
||||
self.session.commit()
|
||||
log.info(
|
||||
f"Inserting {len(new_items)} items into collection '{collection_name}'."
|
||||
)
|
||||
log.info(f"Inserting {len(new_items)} items into collection '{collection_name}'.")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Failed to insert data: {e}")
|
||||
log.exception(f'Failed to insert data: {e}')
|
||||
raise
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
try:
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
existing = (
|
||||
self.session.query(DocumentChunk)
|
||||
.filter(DocumentChunk.id == item["id"])
|
||||
.first()
|
||||
)
|
||||
vector = self.adjust_vector_length(item['vector'])
|
||||
existing = self.session.query(DocumentChunk).filter(DocumentChunk.id == item['id']).first()
|
||||
if existing:
|
||||
existing.vector = vector
|
||||
existing.text = item["text"]
|
||||
existing.vmetadata = process_metadata(item["metadata"])
|
||||
existing.text = item['text']
|
||||
existing.vmetadata = process_metadata(item['metadata'])
|
||||
existing.collection_name = collection_name
|
||||
else:
|
||||
new_chunk = DocumentChunk(
|
||||
id=item["id"],
|
||||
id=item['id'],
|
||||
vector=vector,
|
||||
collection_name=collection_name,
|
||||
text=item["text"],
|
||||
vmetadata=process_metadata(item["metadata"]),
|
||||
text=item['text'],
|
||||
vmetadata=process_metadata(item['metadata']),
|
||||
)
|
||||
self.session.add(new_chunk)
|
||||
self.session.commit()
|
||||
log.info(
|
||||
f"Inserting/updating {len(items)} items in collection '{collection_name}'."
|
||||
)
|
||||
log.info(f"Inserting/updating {len(items)} items in collection '{collection_name}'.")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Failed to insert or update data.: {e}")
|
||||
log.exception(f'Failed to insert or update data.: {e}')
|
||||
raise
|
||||
|
||||
def search(
|
||||
@@ -250,35 +233,29 @@ class OpenGaussClient(VectorDBBase):
|
||||
def vector_expr(vector):
|
||||
return cast(array(vector), Vector(VECTOR_LENGTH))
|
||||
|
||||
qid_col = column("qid", Integer)
|
||||
q_vector_col = column("q_vector", Vector(VECTOR_LENGTH))
|
||||
qid_col = column('qid', Integer)
|
||||
q_vector_col = column('q_vector', Vector(VECTOR_LENGTH))
|
||||
query_vectors = (
|
||||
values(qid_col, q_vector_col)
|
||||
.data(
|
||||
[(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
|
||||
)
|
||||
.alias("query_vectors")
|
||||
.data([(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)])
|
||||
.alias('query_vectors')
|
||||
)
|
||||
|
||||
result_fields = [
|
||||
DocumentChunk.id,
|
||||
DocumentChunk.text,
|
||||
DocumentChunk.vmetadata,
|
||||
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
|
||||
"distance"
|
||||
),
|
||||
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label('distance'),
|
||||
]
|
||||
|
||||
subq = (
|
||||
select(*result_fields)
|
||||
.where(DocumentChunk.collection_name == collection_name)
|
||||
.order_by(
|
||||
DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)
|
||||
)
|
||||
.order_by(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
|
||||
)
|
||||
if limit is not None:
|
||||
subq = subq.limit(limit)
|
||||
subq = subq.lateral("result")
|
||||
subq = subq.lateral('result')
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
@@ -309,21 +286,15 @@ class OpenGaussClient(VectorDBBase):
|
||||
metadatas[qid].append(row.vmetadata)
|
||||
|
||||
self.session.rollback()
|
||||
return SearchResult(
|
||||
ids=ids, distances=distances, documents=documents, metadatas=metadatas
|
||||
)
|
||||
return SearchResult(ids=ids, distances=distances, documents=documents, metadatas=metadatas)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Vector search failed: {e}")
|
||||
log.exception(f'Vector search failed: {e}')
|
||||
return None
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
try:
|
||||
query = self.session.query(DocumentChunk).filter(
|
||||
DocumentChunk.collection_name == collection_name
|
||||
)
|
||||
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
|
||||
|
||||
for key, value in filter.items():
|
||||
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
|
||||
@@ -344,16 +315,12 @@ class OpenGaussClient(VectorDBBase):
|
||||
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Conditional query failed: {e}")
|
||||
log.exception(f'Conditional query failed: {e}')
|
||||
return None
|
||||
|
||||
def get(
|
||||
self, collection_name: str, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def get(self, collection_name: str, limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
try:
|
||||
query = self.session.query(DocumentChunk).filter(
|
||||
DocumentChunk.collection_name == collection_name
|
||||
)
|
||||
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
||||
@@ -370,7 +337,7 @@ class OpenGaussClient(VectorDBBase):
|
||||
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Failed to retrieve data: {e}")
|
||||
log.exception(f'Failed to retrieve data: {e}')
|
||||
return None
|
||||
|
||||
def delete(
|
||||
@@ -380,32 +347,28 @@ class OpenGaussClient(VectorDBBase):
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
try:
|
||||
query = self.session.query(DocumentChunk).filter(
|
||||
DocumentChunk.collection_name == collection_name
|
||||
)
|
||||
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
|
||||
if ids:
|
||||
query = query.filter(DocumentChunk.id.in_(ids))
|
||||
if filter:
|
||||
for key, value in filter.items():
|
||||
query = query.filter(
|
||||
DocumentChunk.vmetadata[key].astext == str(value)
|
||||
)
|
||||
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
|
||||
deleted = query.delete(synchronize_session=False)
|
||||
self.session.commit()
|
||||
log.info(f"Deleted {deleted} items from collection '{collection_name}'")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Failed to delete data: {e}")
|
||||
log.exception(f'Failed to delete data: {e}')
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
deleted = self.session.query(DocumentChunk).delete()
|
||||
self.session.commit()
|
||||
log.info(f"Reset completed. Deleted {deleted} items")
|
||||
log.info(f'Reset completed. Deleted {deleted} items')
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Reset failed: {e}")
|
||||
log.exception(f'Reset failed: {e}')
|
||||
raise
|
||||
|
||||
def close(self) -> None:
|
||||
@@ -414,16 +377,14 @@ class OpenGaussClient(VectorDBBase):
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
try:
|
||||
exists = (
|
||||
self.session.query(DocumentChunk)
|
||||
.filter(DocumentChunk.collection_name == collection_name)
|
||||
.first()
|
||||
self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name).first()
|
||||
is not None
|
||||
)
|
||||
self.session.rollback()
|
||||
return exists
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Failed to check collection existence: {e}")
|
||||
log.exception(f'Failed to check collection existence: {e}')
|
||||
return False
|
||||
|
||||
def delete_collection(self, collection_name: str) -> None:
|
||||
|
||||
@@ -24,7 +24,7 @@ from open_webui.config import (
|
||||
|
||||
class OpenSearchClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
self.index_prefix = "open_webui"
|
||||
self.index_prefix = 'open_webui'
|
||||
self.client = OpenSearch(
|
||||
hosts=[OPENSEARCH_URI],
|
||||
use_ssl=OPENSEARCH_SSL,
|
||||
@@ -33,25 +33,25 @@ class OpenSearchClient(VectorDBBase):
|
||||
)
|
||||
|
||||
def _get_index_name(self, collection_name: str) -> str:
|
||||
return f"{self.index_prefix}_{collection_name}"
|
||||
return f'{self.index_prefix}_{collection_name}'
|
||||
|
||||
def _result_to_get_result(self, result) -> GetResult:
|
||||
if not result["hits"]["hits"]:
|
||||
if not result['hits']['hits']:
|
||||
return None
|
||||
|
||||
ids = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for hit in result["hits"]["hits"]:
|
||||
ids.append(hit["_id"])
|
||||
documents.append(hit["_source"].get("text"))
|
||||
metadatas.append(hit["_source"].get("metadata"))
|
||||
for hit in result['hits']['hits']:
|
||||
ids.append(hit['_id'])
|
||||
documents.append(hit['_source'].get('text'))
|
||||
metadatas.append(hit['_source'].get('metadata'))
|
||||
|
||||
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
||||
|
||||
def _result_to_search_result(self, result) -> SearchResult:
|
||||
if not result["hits"]["hits"]:
|
||||
if not result['hits']['hits']:
|
||||
return None
|
||||
|
||||
ids = []
|
||||
@@ -59,11 +59,11 @@ class OpenSearchClient(VectorDBBase):
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for hit in result["hits"]["hits"]:
|
||||
ids.append(hit["_id"])
|
||||
distances.append(hit["_score"])
|
||||
documents.append(hit["_source"].get("text"))
|
||||
metadatas.append(hit["_source"].get("metadata"))
|
||||
for hit in result['hits']['hits']:
|
||||
ids.append(hit['_id'])
|
||||
distances.append(hit['_score'])
|
||||
documents.append(hit['_source'].get('text'))
|
||||
metadatas.append(hit['_source'].get('metadata'))
|
||||
|
||||
return SearchResult(
|
||||
ids=[ids],
|
||||
@@ -74,33 +74,31 @@ class OpenSearchClient(VectorDBBase):
|
||||
|
||||
def _create_index(self, collection_name: str, dimension: int):
|
||||
body = {
|
||||
"settings": {"index": {"knn": True}},
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"id": {"type": "keyword"},
|
||||
"vector": {
|
||||
"type": "knn_vector",
|
||||
"dimension": dimension, # Adjust based on your vector dimensions
|
||||
"index": True,
|
||||
"similarity": "faiss",
|
||||
"method": {
|
||||
"name": "hnsw",
|
||||
"space_type": "innerproduct", # Use inner product to approximate cosine similarity
|
||||
"engine": "faiss",
|
||||
"parameters": {
|
||||
"ef_construction": 128,
|
||||
"m": 16,
|
||||
'settings': {'index': {'knn': True}},
|
||||
'mappings': {
|
||||
'properties': {
|
||||
'id': {'type': 'keyword'},
|
||||
'vector': {
|
||||
'type': 'knn_vector',
|
||||
'dimension': dimension, # Adjust based on your vector dimensions
|
||||
'index': True,
|
||||
'similarity': 'faiss',
|
||||
'method': {
|
||||
'name': 'hnsw',
|
||||
'space_type': 'innerproduct', # Use inner product to approximate cosine similarity
|
||||
'engine': 'faiss',
|
||||
'parameters': {
|
||||
'ef_construction': 128,
|
||||
'm': 16,
|
||||
},
|
||||
},
|
||||
},
|
||||
"text": {"type": "text"},
|
||||
"metadata": {"type": "object"},
|
||||
'text': {'type': 'text'},
|
||||
'metadata': {'type': 'object'},
|
||||
}
|
||||
},
|
||||
}
|
||||
self.client.indices.create(
|
||||
index=self._get_index_name(collection_name), body=body
|
||||
)
|
||||
self.client.indices.create(index=self._get_index_name(collection_name), body=body)
|
||||
|
||||
def _create_batches(self, items: list[VectorItem], batch_size=100):
|
||||
for i in range(0, len(items), batch_size):
|
||||
@@ -128,46 +126,40 @@ class OpenSearchClient(VectorDBBase):
|
||||
return None
|
||||
|
||||
query = {
|
||||
"size": limit,
|
||||
"_source": ["text", "metadata"],
|
||||
"query": {
|
||||
"script_score": {
|
||||
"query": {"match_all": {}},
|
||||
"script": {
|
||||
"source": "(cosineSimilarity(params.query_value, doc[params.field]) + 1.0) / 2.0",
|
||||
"params": {
|
||||
"field": "vector",
|
||||
"query_value": vectors[0],
|
||||
'size': limit,
|
||||
'_source': ['text', 'metadata'],
|
||||
'query': {
|
||||
'script_score': {
|
||||
'query': {'match_all': {}},
|
||||
'script': {
|
||||
'source': '(cosineSimilarity(params.query_value, doc[params.field]) + 1.0) / 2.0',
|
||||
'params': {
|
||||
'field': 'vector',
|
||||
'query_value': vectors[0],
|
||||
}, # Assuming single query vector
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result = self.client.search(
|
||||
index=self._get_index_name(collection_name), body=query
|
||||
)
|
||||
result = self.client.search(index=self._get_index_name(collection_name), body=query)
|
||||
|
||||
return self._result_to_search_result(result)
|
||||
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: dict, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
if not self.has_collection(collection_name):
|
||||
return None
|
||||
|
||||
query_body = {
|
||||
"query": {"bool": {"filter": []}},
|
||||
"_source": ["text", "metadata"],
|
||||
'query': {'bool': {'filter': []}},
|
||||
'_source': ['text', 'metadata'],
|
||||
}
|
||||
|
||||
for field, value in filter.items():
|
||||
query_body["query"]["bool"]["filter"].append(
|
||||
{"term": {"metadata." + str(field) + ".keyword": value}}
|
||||
)
|
||||
query_body['query']['bool']['filter'].append({'term': {'metadata.' + str(field) + '.keyword': value}})
|
||||
|
||||
size = limit if limit else 10000
|
||||
|
||||
@@ -188,28 +180,24 @@ class OpenSearchClient(VectorDBBase):
|
||||
self._create_index(collection_name, dimension)
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
query = {"query": {"match_all": {}}, "_source": ["text", "metadata"]}
|
||||
query = {'query': {'match_all': {}}, '_source': ['text', 'metadata']}
|
||||
|
||||
result = self.client.search(
|
||||
index=self._get_index_name(collection_name), body=query
|
||||
)
|
||||
result = self.client.search(index=self._get_index_name(collection_name), body=query)
|
||||
return self._result_to_get_result(result)
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
self._create_index_if_not_exists(
|
||||
collection_name=collection_name, dimension=len(items[0]["vector"])
|
||||
)
|
||||
self._create_index_if_not_exists(collection_name=collection_name, dimension=len(items[0]['vector']))
|
||||
|
||||
for batch in self._create_batches(items):
|
||||
actions = [
|
||||
{
|
||||
"_op_type": "index",
|
||||
"_index": self._get_index_name(collection_name),
|
||||
"_id": item["id"],
|
||||
"_source": {
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
"metadata": process_metadata(item["metadata"]),
|
||||
'_op_type': 'index',
|
||||
'_index': self._get_index_name(collection_name),
|
||||
'_id': item['id'],
|
||||
'_source': {
|
||||
'vector': item['vector'],
|
||||
'text': item['text'],
|
||||
'metadata': process_metadata(item['metadata']),
|
||||
},
|
||||
}
|
||||
for item in batch
|
||||
@@ -218,22 +206,20 @@ class OpenSearchClient(VectorDBBase):
|
||||
self.client.indices.refresh(index=self._get_index_name(collection_name))
|
||||
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
self._create_index_if_not_exists(
|
||||
collection_name=collection_name, dimension=len(items[0]["vector"])
|
||||
)
|
||||
self._create_index_if_not_exists(collection_name=collection_name, dimension=len(items[0]['vector']))
|
||||
|
||||
for batch in self._create_batches(items):
|
||||
actions = [
|
||||
{
|
||||
"_op_type": "update",
|
||||
"_index": self._get_index_name(collection_name),
|
||||
"_id": item["id"],
|
||||
"doc": {
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
"metadata": process_metadata(item["metadata"]),
|
||||
'_op_type': 'update',
|
||||
'_index': self._get_index_name(collection_name),
|
||||
'_id': item['id'],
|
||||
'doc': {
|
||||
'vector': item['vector'],
|
||||
'text': item['text'],
|
||||
'metadata': process_metadata(item['metadata']),
|
||||
},
|
||||
"doc_as_upsert": True,
|
||||
'doc_as_upsert': True,
|
||||
}
|
||||
for item in batch
|
||||
]
|
||||
@@ -249,27 +235,23 @@ class OpenSearchClient(VectorDBBase):
|
||||
if ids:
|
||||
actions = [
|
||||
{
|
||||
"_op_type": "delete",
|
||||
"_index": self._get_index_name(collection_name),
|
||||
"_id": id,
|
||||
'_op_type': 'delete',
|
||||
'_index': self._get_index_name(collection_name),
|
||||
'_id': id,
|
||||
}
|
||||
for id in ids
|
||||
]
|
||||
bulk(self.client, actions)
|
||||
elif filter:
|
||||
query_body = {
|
||||
"query": {"bool": {"filter": []}},
|
||||
'query': {'bool': {'filter': []}},
|
||||
}
|
||||
for field, value in filter.items():
|
||||
query_body["query"]["bool"]["filter"].append(
|
||||
{"term": {"metadata." + str(field) + ".keyword": value}}
|
||||
)
|
||||
self.client.delete_by_query(
|
||||
index=self._get_index_name(collection_name), body=query_body
|
||||
)
|
||||
query_body['query']['bool']['filter'].append({'term': {'metadata.' + str(field) + '.keyword': value}})
|
||||
self.client.delete_by_query(index=self._get_index_name(collection_name), body=query_body)
|
||||
self.client.indices.refresh(index=self._get_index_name(collection_name))
|
||||
|
||||
def reset(self):
|
||||
indices = self.client.indices.get(index=f"{self.index_prefix}_*")
|
||||
indices = self.client.indices.get(index=f'{self.index_prefix}_*')
|
||||
for index in indices:
|
||||
self.client.indices.delete(index=index)
|
||||
|
||||
@@ -94,15 +94,15 @@ class Oracle23aiClient(VectorDBBase):
|
||||
self._create_dbcs_pool()
|
||||
|
||||
dsn = ORACLE_DB_DSN
|
||||
log.info(f"Creating Connection Pool [{ORACLE_DB_USER}:**@{dsn}]")
|
||||
log.info(f'Creating Connection Pool [{ORACLE_DB_USER}:**@{dsn}]')
|
||||
|
||||
with self.get_connection() as connection:
|
||||
log.info(f"Connection version: {connection.version}")
|
||||
log.info(f'Connection version: {connection.version}')
|
||||
self._initialize_database(connection)
|
||||
|
||||
log.info("Oracle Vector Search initialization complete.")
|
||||
log.info('Oracle Vector Search initialization complete.')
|
||||
except Exception as e:
|
||||
log.exception(f"Error during Oracle Vector Search initialization: {e}")
|
||||
log.exception(f'Error during Oracle Vector Search initialization: {e}')
|
||||
raise
|
||||
|
||||
def _create_adb_pool(self) -> None:
|
||||
@@ -122,7 +122,7 @@ class Oracle23aiClient(VectorDBBase):
|
||||
wallet_location=ORACLE_WALLET_DIR,
|
||||
wallet_password=ORACLE_WALLET_PASSWORD,
|
||||
)
|
||||
log.info("Created ADB connection pool with wallet authentication.")
|
||||
log.info('Created ADB connection pool with wallet authentication.')
|
||||
|
||||
def _create_dbcs_pool(self) -> None:
|
||||
"""
|
||||
@@ -138,7 +138,7 @@ class Oracle23aiClient(VectorDBBase):
|
||||
max=ORACLE_DB_POOL_MAX,
|
||||
increment=ORACLE_DB_POOL_INCREMENT,
|
||||
)
|
||||
log.info("Created DB connection pool with basic authentication.")
|
||||
log.info('Created DB connection pool with basic authentication.')
|
||||
|
||||
def get_connection(self):
|
||||
"""
|
||||
@@ -155,13 +155,11 @@ class Oracle23aiClient(VectorDBBase):
|
||||
return connection
|
||||
except oracledb.DatabaseError as e:
|
||||
(error_obj,) = e.args
|
||||
log.exception(
|
||||
f"Connection attempt {attempt + 1} failed: {error_obj.message}"
|
||||
)
|
||||
log.exception(f'Connection attempt {attempt + 1} failed: {error_obj.message}')
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = 2**attempt
|
||||
log.info(f"Retrying in {wait_time} seconds...")
|
||||
log.info(f'Retrying in {wait_time} seconds...')
|
||||
time.sleep(wait_time)
|
||||
else:
|
||||
raise
|
||||
@@ -177,30 +175,30 @@ class Oracle23aiClient(VectorDBBase):
|
||||
def _monitor():
|
||||
while True:
|
||||
try:
|
||||
log.info("[HealthCheck] Running periodic DB health check...")
|
||||
log.info('[HealthCheck] Running periodic DB health check...')
|
||||
self.ensure_connection()
|
||||
log.info("[HealthCheck] Connection is healthy.")
|
||||
log.info('[HealthCheck] Connection is healthy.')
|
||||
except Exception as e:
|
||||
log.exception(f"[HealthCheck] Connection health check failed: {e}")
|
||||
log.exception(f'[HealthCheck] Connection health check failed: {e}')
|
||||
time.sleep(interval_seconds)
|
||||
|
||||
thread = threading.Thread(target=_monitor, daemon=True)
|
||||
thread.start()
|
||||
log.info(f"Started DB health monitor every {interval_seconds} seconds.")
|
||||
log.info(f'Started DB health monitor every {interval_seconds} seconds.')
|
||||
|
||||
def _reconnect_pool(self):
|
||||
"""
|
||||
Attempt to reinitialize the connection pool if it's been closed or broken.
|
||||
"""
|
||||
try:
|
||||
log.info("Attempting to reinitialize the Oracle connection pool...")
|
||||
log.info('Attempting to reinitialize the Oracle connection pool...')
|
||||
|
||||
# Close existing pool if it exists
|
||||
if self.pool:
|
||||
try:
|
||||
self.pool.close()
|
||||
except Exception as close_error:
|
||||
log.warning(f"Error closing existing pool: {close_error}")
|
||||
log.warning(f'Error closing existing pool: {close_error}')
|
||||
|
||||
# Re-create the appropriate connection pool based on DB type
|
||||
if ORACLE_DB_USE_WALLET:
|
||||
@@ -208,9 +206,9 @@ class Oracle23aiClient(VectorDBBase):
|
||||
else: # DBCS
|
||||
self._create_dbcs_pool()
|
||||
|
||||
log.info("Connection pool reinitialized.")
|
||||
log.info('Connection pool reinitialized.')
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to reinitialize the connection pool: {e}")
|
||||
log.exception(f'Failed to reinitialize the connection pool: {e}')
|
||||
raise
|
||||
|
||||
def ensure_connection(self):
|
||||
@@ -220,11 +218,9 @@ class Oracle23aiClient(VectorDBBase):
|
||||
try:
|
||||
with self.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute("SELECT 1 FROM dual")
|
||||
cursor.execute('SELECT 1 FROM dual')
|
||||
except Exception as e:
|
||||
log.exception(
|
||||
f"Connection check failed: {e}, attempting to reconnect pool..."
|
||||
)
|
||||
log.exception(f'Connection check failed: {e}, attempting to reconnect pool...')
|
||||
self._reconnect_pool()
|
||||
|
||||
def _output_type_handler(self, cursor, metadata):
|
||||
@@ -239,9 +235,7 @@ class Oracle23aiClient(VectorDBBase):
|
||||
A variable with appropriate conversion for vector types
|
||||
"""
|
||||
if metadata.type_code is oracledb.DB_TYPE_VECTOR:
|
||||
return cursor.var(
|
||||
metadata.type_code, arraysize=cursor.arraysize, outconverter=list
|
||||
)
|
||||
return cursor.var(metadata.type_code, arraysize=cursor.arraysize, outconverter=list)
|
||||
|
||||
def _initialize_database(self, connection) -> None:
|
||||
"""
|
||||
@@ -257,7 +251,7 @@ class Oracle23aiClient(VectorDBBase):
|
||||
"""
|
||||
with connection.cursor() as cursor:
|
||||
try:
|
||||
log.info("Creating Table document_chunk")
|
||||
log.info('Creating Table document_chunk')
|
||||
cursor.execute(
|
||||
"""
|
||||
BEGIN
|
||||
@@ -279,7 +273,7 @@ class Oracle23aiClient(VectorDBBase):
|
||||
"""
|
||||
)
|
||||
|
||||
log.info("Creating Index document_chunk_collection_name_idx")
|
||||
log.info('Creating Index document_chunk_collection_name_idx')
|
||||
cursor.execute(
|
||||
"""
|
||||
BEGIN
|
||||
@@ -296,7 +290,7 @@ class Oracle23aiClient(VectorDBBase):
|
||||
"""
|
||||
)
|
||||
|
||||
log.info("Creating VECTOR INDEX document_chunk_vector_ivf_idx")
|
||||
log.info('Creating VECTOR INDEX document_chunk_vector_ivf_idx')
|
||||
cursor.execute(
|
||||
"""
|
||||
BEGIN
|
||||
@@ -318,11 +312,11 @@ class Oracle23aiClient(VectorDBBase):
|
||||
)
|
||||
|
||||
connection.commit()
|
||||
log.info("Database initialization completed successfully.")
|
||||
log.info('Database initialization completed successfully.')
|
||||
|
||||
except Exception as e:
|
||||
connection.rollback()
|
||||
log.exception(f"Error during database initialization: {e}")
|
||||
log.exception(f'Error during database initialization: {e}')
|
||||
raise
|
||||
|
||||
def check_vector_length(self) -> None:
|
||||
@@ -344,7 +338,7 @@ class Oracle23aiClient(VectorDBBase):
|
||||
Returns:
|
||||
bytes: The vector in Oracle BLOB format
|
||||
"""
|
||||
return array.array("f", vector)
|
||||
return array.array('f', vector)
|
||||
|
||||
def adjust_vector_length(self, vector: List[float]) -> List[float]:
|
||||
"""
|
||||
@@ -373,7 +367,7 @@ class Oracle23aiClient(VectorDBBase):
|
||||
"""
|
||||
if isinstance(obj, Decimal):
|
||||
return float(obj)
|
||||
raise TypeError(f"{obj} is not JSON serializable")
|
||||
raise TypeError(f'{obj} is not JSON serializable')
|
||||
|
||||
def _metadata_to_json(self, metadata: Dict) -> str:
|
||||
"""
|
||||
@@ -385,7 +379,7 @@ class Oracle23aiClient(VectorDBBase):
|
||||
Returns:
|
||||
str: JSON representation of metadata
|
||||
"""
|
||||
return json.dumps(metadata, default=self._decimal_handler) if metadata else "{}"
|
||||
return json.dumps(metadata, default=self._decimal_handler) if metadata else '{}'
|
||||
|
||||
def _json_to_metadata(self, json_str: str) -> Dict:
|
||||
"""
|
||||
@@ -424,8 +418,8 @@ class Oracle23aiClient(VectorDBBase):
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
for item in items:
|
||||
vector_blob = self._vector_to_blob(item["vector"])
|
||||
metadata_json = self._metadata_to_json(item["metadata"])
|
||||
vector_blob = self._vector_to_blob(item['vector'])
|
||||
metadata_json = self._metadata_to_json(item['metadata'])
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
@@ -434,22 +428,20 @@ class Oracle23aiClient(VectorDBBase):
|
||||
VALUES (:id, :collection_name, :text, :metadata, :vector)
|
||||
""",
|
||||
{
|
||||
"id": item["id"],
|
||||
"collection_name": collection_name,
|
||||
"text": item["text"],
|
||||
"metadata": metadata_json,
|
||||
"vector": vector_blob,
|
||||
'id': item['id'],
|
||||
'collection_name': collection_name,
|
||||
'text': item['text'],
|
||||
'metadata': metadata_json,
|
||||
'vector': vector_blob,
|
||||
},
|
||||
)
|
||||
|
||||
connection.commit()
|
||||
log.info(
|
||||
f"Successfully inserted {len(items)} items into collection '{collection_name}'."
|
||||
)
|
||||
log.info(f"Successfully inserted {len(items)} items into collection '{collection_name}'.")
|
||||
|
||||
except Exception as e:
|
||||
connection.rollback()
|
||||
log.exception(f"Error during insert: {e}")
|
||||
log.exception(f'Error during insert: {e}')
|
||||
raise
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
@@ -480,8 +472,8 @@ class Oracle23aiClient(VectorDBBase):
|
||||
try:
|
||||
with connection.cursor() as cursor:
|
||||
for item in items:
|
||||
vector_blob = self._vector_to_blob(item["vector"])
|
||||
metadata_json = self._metadata_to_json(item["metadata"])
|
||||
vector_blob = self._vector_to_blob(item['vector'])
|
||||
metadata_json = self._metadata_to_json(item['metadata'])
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
@@ -499,27 +491,25 @@ class Oracle23aiClient(VectorDBBase):
|
||||
VALUES (:ins_id, :ins_collection_name, :ins_text, :ins_metadata, :ins_vector)
|
||||
""",
|
||||
{
|
||||
"merge_id": item["id"],
|
||||
"upd_collection_name": collection_name,
|
||||
"upd_text": item["text"],
|
||||
"upd_metadata": metadata_json,
|
||||
"upd_vector": vector_blob,
|
||||
"ins_id": item["id"],
|
||||
"ins_collection_name": collection_name,
|
||||
"ins_text": item["text"],
|
||||
"ins_metadata": metadata_json,
|
||||
"ins_vector": vector_blob,
|
||||
'merge_id': item['id'],
|
||||
'upd_collection_name': collection_name,
|
||||
'upd_text': item['text'],
|
||||
'upd_metadata': metadata_json,
|
||||
'upd_vector': vector_blob,
|
||||
'ins_id': item['id'],
|
||||
'ins_collection_name': collection_name,
|
||||
'ins_text': item['text'],
|
||||
'ins_metadata': metadata_json,
|
||||
'ins_vector': vector_blob,
|
||||
},
|
||||
)
|
||||
|
||||
connection.commit()
|
||||
log.info(
|
||||
f"Successfully upserted {len(items)} items into collection '{collection_name}'."
|
||||
)
|
||||
log.info(f"Successfully upserted {len(items)} items into collection '{collection_name}'.")
|
||||
|
||||
except Exception as e:
|
||||
connection.rollback()
|
||||
log.exception(f"Error during upsert: {e}")
|
||||
log.exception(f'Error during upsert: {e}')
|
||||
raise
|
||||
|
||||
def search(
|
||||
@@ -551,13 +541,11 @@ class Oracle23aiClient(VectorDBBase):
|
||||
... for i, (id, dist) in enumerate(zip(results.ids[0], results.distances[0])):
|
||||
... log.info(f"Match {i+1}: id={id}, distance={dist}")
|
||||
"""
|
||||
log.info(
|
||||
f"Searching items from collection '{collection_name}' with limit {limit}."
|
||||
)
|
||||
log.info(f"Searching items from collection '{collection_name}' with limit {limit}.")
|
||||
|
||||
try:
|
||||
if not vectors:
|
||||
log.warning("No vectors provided for search.")
|
||||
log.warning('No vectors provided for search.')
|
||||
return None
|
||||
|
||||
num_queries = len(vectors)
|
||||
@@ -583,9 +571,9 @@ class Oracle23aiClient(VectorDBBase):
|
||||
FETCH APPROX FIRST :limit ROWS ONLY
|
||||
""",
|
||||
{
|
||||
"query_vector": vector_blob,
|
||||
"collection_name": collection_name,
|
||||
"limit": limit,
|
||||
'query_vector': vector_blob,
|
||||
'collection_name': collection_name,
|
||||
'limit': limit,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -593,35 +581,21 @@ class Oracle23aiClient(VectorDBBase):
|
||||
|
||||
for row in results:
|
||||
ids[qid].append(row[0])
|
||||
documents[qid].append(
|
||||
row[1].read()
|
||||
if isinstance(row[1], oracledb.LOB)
|
||||
else str(row[1])
|
||||
)
|
||||
documents[qid].append(row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1]))
|
||||
# 🔧 FIXED: Parse JSON metadata properly
|
||||
metadata_str = (
|
||||
row[2].read()
|
||||
if isinstance(row[2], oracledb.LOB)
|
||||
else row[2]
|
||||
)
|
||||
metadata_str = row[2].read() if isinstance(row[2], oracledb.LOB) else row[2]
|
||||
metadatas[qid].append(self._json_to_metadata(metadata_str))
|
||||
distances[qid].append(float(row[3]))
|
||||
|
||||
log.info(
|
||||
f"Search completed. Found {sum(len(ids[i]) for i in range(num_queries))} total results."
|
||||
)
|
||||
log.info(f'Search completed. Found {sum(len(ids[i]) for i in range(num_queries))} total results.')
|
||||
|
||||
return SearchResult(
|
||||
ids=ids, distances=distances, documents=documents, metadatas=metadatas
|
||||
)
|
||||
return SearchResult(ids=ids, distances=distances, documents=documents, metadatas=metadatas)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(f"Error during search: {e}")
|
||||
log.exception(f'Error during search: {e}')
|
||||
return None
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: Dict, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
"""
|
||||
Query items based on metadata filters.
|
||||
|
||||
@@ -653,15 +627,15 @@ class Oracle23aiClient(VectorDBBase):
|
||||
WHERE collection_name = :collection_name
|
||||
"""
|
||||
|
||||
params = {"collection_name": collection_name}
|
||||
params = {'collection_name': collection_name}
|
||||
|
||||
for i, (key, value) in enumerate(filter.items()):
|
||||
param_name = f"value_{i}"
|
||||
param_name = f'value_{i}'
|
||||
query += f" AND JSON_VALUE(vmetadata, '$.{key}' RETURNING VARCHAR2(4096)) = :{param_name}"
|
||||
params[param_name] = str(value)
|
||||
|
||||
query += " FETCH FIRST :limit ROWS ONLY"
|
||||
params["limit"] = limit
|
||||
query += ' FETCH FIRST :limit ROWS ONLY'
|
||||
params['limit'] = limit
|
||||
|
||||
with self.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
@@ -669,32 +643,25 @@ class Oracle23aiClient(VectorDBBase):
|
||||
results = cursor.fetchall()
|
||||
|
||||
if not results:
|
||||
log.info("No results found for query.")
|
||||
log.info('No results found for query.')
|
||||
return None
|
||||
|
||||
ids = [[row[0] for row in results]]
|
||||
documents = [
|
||||
[
|
||||
row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1])
|
||||
for row in results
|
||||
]
|
||||
]
|
||||
documents = [[row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1]) for row in results]]
|
||||
# 🔧 FIXED: Parse JSON metadata properly
|
||||
metadatas = [
|
||||
[
|
||||
self._json_to_metadata(
|
||||
row[2].read() if isinstance(row[2], oracledb.LOB) else row[2]
|
||||
)
|
||||
self._json_to_metadata(row[2].read() if isinstance(row[2], oracledb.LOB) else row[2])
|
||||
for row in results
|
||||
]
|
||||
]
|
||||
|
||||
log.info(f"Query completed. Found {len(results)} results.")
|
||||
log.info(f'Query completed. Found {len(results)} results.')
|
||||
|
||||
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(f"Error during query: {e}")
|
||||
log.exception(f'Error during query: {e}')
|
||||
return None
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
@@ -729,28 +696,21 @@ class Oracle23aiClient(VectorDBBase):
|
||||
WHERE collection_name = :collection_name
|
||||
FETCH FIRST :limit ROWS ONLY
|
||||
""",
|
||||
{"collection_name": collection_name, "limit": limit},
|
||||
{'collection_name': collection_name, 'limit': limit},
|
||||
)
|
||||
|
||||
results = cursor.fetchall()
|
||||
|
||||
if not results:
|
||||
log.info("No results found.")
|
||||
log.info('No results found.')
|
||||
return None
|
||||
|
||||
ids = [[row[0] for row in results]]
|
||||
documents = [
|
||||
[
|
||||
row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1])
|
||||
for row in results
|
||||
]
|
||||
]
|
||||
documents = [[row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1]) for row in results]]
|
||||
# 🔧 FIXED: Parse JSON metadata properly
|
||||
metadatas = [
|
||||
[
|
||||
self._json_to_metadata(
|
||||
row[2].read() if isinstance(row[2], oracledb.LOB) else row[2]
|
||||
)
|
||||
self._json_to_metadata(row[2].read() if isinstance(row[2], oracledb.LOB) else row[2])
|
||||
for row in results
|
||||
]
|
||||
]
|
||||
@@ -758,7 +718,7 @@ class Oracle23aiClient(VectorDBBase):
|
||||
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(f"Error during get: {e}")
|
||||
log.exception(f'Error during get: {e}')
|
||||
return None
|
||||
|
||||
def delete(
|
||||
@@ -790,21 +750,19 @@ class Oracle23aiClient(VectorDBBase):
|
||||
log.info(f"Deleting items from collection '{collection_name}'.")
|
||||
|
||||
try:
|
||||
query = (
|
||||
"DELETE FROM document_chunk WHERE collection_name = :collection_name"
|
||||
)
|
||||
params = {"collection_name": collection_name}
|
||||
query = 'DELETE FROM document_chunk WHERE collection_name = :collection_name'
|
||||
params = {'collection_name': collection_name}
|
||||
|
||||
if ids:
|
||||
# 🔧 FIXED: Use proper parameterized query to prevent SQL injection
|
||||
placeholders = ",".join([f":id_{i}" for i in range(len(ids))])
|
||||
query += f" AND id IN ({placeholders})"
|
||||
placeholders = ','.join([f':id_{i}' for i in range(len(ids))])
|
||||
query += f' AND id IN ({placeholders})'
|
||||
for i, id_val in enumerate(ids):
|
||||
params[f"id_{i}"] = id_val
|
||||
params[f'id_{i}'] = id_val
|
||||
|
||||
if filter:
|
||||
for i, (key, value) in enumerate(filter.items()):
|
||||
param_name = f"value_{i}"
|
||||
param_name = f'value_{i}'
|
||||
query += f" AND JSON_VALUE(vmetadata, '$.{key}' RETURNING VARCHAR2(4096)) = :{param_name}"
|
||||
params[param_name] = str(value)
|
||||
|
||||
@@ -817,7 +775,7 @@ class Oracle23aiClient(VectorDBBase):
|
||||
log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
|
||||
|
||||
except Exception as e:
|
||||
log.exception(f"Error during delete: {e}")
|
||||
log.exception(f'Error during delete: {e}')
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
@@ -833,21 +791,19 @@ class Oracle23aiClient(VectorDBBase):
|
||||
>>> client = Oracle23aiClient()
|
||||
>>> client.reset() # Warning: Removes all data!
|
||||
"""
|
||||
log.info("Resetting database - deleting all items.")
|
||||
log.info('Resetting database - deleting all items.')
|
||||
|
||||
try:
|
||||
with self.get_connection() as connection:
|
||||
with connection.cursor() as cursor:
|
||||
cursor.execute("DELETE FROM document_chunk")
|
||||
cursor.execute('DELETE FROM document_chunk')
|
||||
deleted = cursor.rowcount
|
||||
connection.commit()
|
||||
|
||||
log.info(
|
||||
f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
|
||||
)
|
||||
log.info(f"Reset complete. Deleted {deleted} items from 'document_chunk' table.")
|
||||
|
||||
except Exception as e:
|
||||
log.exception(f"Error during reset: {e}")
|
||||
log.exception(f'Error during reset: {e}')
|
||||
raise
|
||||
|
||||
def close(self) -> None:
|
||||
@@ -862,11 +818,11 @@ class Oracle23aiClient(VectorDBBase):
|
||||
>>> client.close()
|
||||
"""
|
||||
try:
|
||||
if hasattr(self, "pool") and self.pool:
|
||||
if hasattr(self, 'pool') and self.pool:
|
||||
self.pool.close()
|
||||
log.info("Oracle Vector Search connection pool closed.")
|
||||
log.info('Oracle Vector Search connection pool closed.')
|
||||
except Exception as e:
|
||||
log.exception(f"Error closing connection pool: {e}")
|
||||
log.exception(f'Error closing connection pool: {e}')
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
"""
|
||||
@@ -895,7 +851,7 @@ class Oracle23aiClient(VectorDBBase):
|
||||
WHERE collection_name = :collection_name
|
||||
FETCH FIRST 1 ROWS ONLY
|
||||
""",
|
||||
{"collection_name": collection_name},
|
||||
{'collection_name': collection_name},
|
||||
)
|
||||
|
||||
count = cursor.fetchone()[0]
|
||||
@@ -903,7 +859,7 @@ class Oracle23aiClient(VectorDBBase):
|
||||
return count > 0
|
||||
|
||||
except Exception as e:
|
||||
log.exception(f"Error checking collection existence: {e}")
|
||||
log.exception(f'Error checking collection existence: {e}')
|
||||
return False
|
||||
|
||||
def delete_collection(self, collection_name: str) -> None:
|
||||
@@ -929,15 +885,13 @@ class Oracle23aiClient(VectorDBBase):
|
||||
DELETE FROM document_chunk
|
||||
WHERE collection_name = :collection_name
|
||||
""",
|
||||
{"collection_name": collection_name},
|
||||
{'collection_name': collection_name},
|
||||
)
|
||||
|
||||
deleted = cursor.rowcount
|
||||
connection.commit()
|
||||
|
||||
log.info(
|
||||
f"Collection '{collection_name}' deleted. Removed {deleted} items."
|
||||
)
|
||||
log.info(f"Collection '{collection_name}' deleted. Removed {deleted} items.")
|
||||
|
||||
except Exception as e:
|
||||
log.exception(f"Error deleting collection '{collection_name}': {e}")
|
||||
|
||||
@@ -55,7 +55,7 @@ VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
||||
USE_HALFVEC = PGVECTOR_USE_HALFVEC
|
||||
|
||||
VECTOR_TYPE_FACTORY = HALFVEC if USE_HALFVEC else Vector
|
||||
VECTOR_OPCLASS = "halfvec_cosine_ops" if USE_HALFVEC else "vector_cosine_ops"
|
||||
VECTOR_OPCLASS = 'halfvec_cosine_ops' if USE_HALFVEC else 'vector_cosine_ops'
|
||||
Base = declarative_base()
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -65,12 +65,12 @@ def pgcrypto_encrypt(val, key):
|
||||
return func.pgp_sym_encrypt(val, literal(key))
|
||||
|
||||
|
||||
def pgcrypto_decrypt(col, key, outtype="text"):
|
||||
def pgcrypto_decrypt(col, key, outtype='text'):
|
||||
return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype)
|
||||
|
||||
|
||||
class DocumentChunk(Base):
|
||||
__tablename__ = "document_chunk"
|
||||
__tablename__ = 'document_chunk'
|
||||
|
||||
id = Column(Text, primary_key=True)
|
||||
vector = Column(VECTOR_TYPE_FACTORY(dim=VECTOR_LENGTH), nullable=True)
|
||||
@@ -86,7 +86,6 @@ class DocumentChunk(Base):
|
||||
|
||||
class PgvectorClient(VectorDBBase):
|
||||
def __init__(self) -> None:
|
||||
|
||||
# if no pgvector uri, use the existing database connection
|
||||
if not PGVECTOR_DB_URL:
|
||||
from open_webui.internal.db import ScopedSession
|
||||
@@ -105,46 +104,44 @@ class PgvectorClient(VectorDBBase):
|
||||
poolclass=QueuePool,
|
||||
)
|
||||
else:
|
||||
engine = create_engine(
|
||||
PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool
|
||||
)
|
||||
engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool)
|
||||
else:
|
||||
engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True)
|
||||
|
||||
SessionLocal = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
|
||||
)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
|
||||
self.session = scoped_session(SessionLocal)
|
||||
|
||||
try:
|
||||
# Ensure the pgvector extension is available
|
||||
# Use a conditional check to avoid permission issues on Azure PostgreSQL
|
||||
if PGVECTOR_CREATE_EXTENSION:
|
||||
self.session.execute(text("""
|
||||
self.session.execute(
|
||||
text("""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN
|
||||
CREATE EXTENSION IF NOT EXISTS vector;
|
||||
END IF;
|
||||
END $$;
|
||||
"""))
|
||||
""")
|
||||
)
|
||||
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
# Ensure the pgcrypto extension is available for encryption
|
||||
# Use a conditional check to avoid permission issues on Azure PostgreSQL
|
||||
self.session.execute(text("""
|
||||
self.session.execute(
|
||||
text("""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgcrypto') THEN
|
||||
CREATE EXTENSION IF NOT EXISTS pgcrypto;
|
||||
END IF;
|
||||
END $$;
|
||||
"""))
|
||||
""")
|
||||
)
|
||||
|
||||
if not PGVECTOR_PGCRYPTO_KEY:
|
||||
raise ValueError(
|
||||
"PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled."
|
||||
)
|
||||
raise ValueError('PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled.')
|
||||
|
||||
# Check vector length consistency
|
||||
self.check_vector_length()
|
||||
@@ -160,15 +157,14 @@ class PgvectorClient(VectorDBBase):
|
||||
|
||||
self.session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name "
|
||||
"ON document_chunk (collection_name);"
|
||||
'CREATE INDEX IF NOT EXISTS idx_document_chunk_collection_name ON document_chunk (collection_name);'
|
||||
)
|
||||
)
|
||||
self.session.commit()
|
||||
log.info("Initialization complete.")
|
||||
log.info('Initialization complete.')
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during initialization: {e}")
|
||||
log.exception(f'Error during initialization: {e}')
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
@@ -176,7 +172,7 @@ class PgvectorClient(VectorDBBase):
|
||||
if not index_def:
|
||||
return None
|
||||
try:
|
||||
after_using = index_def.lower().split("using ", 1)[1]
|
||||
after_using = index_def.lower().split('using ', 1)[1]
|
||||
return after_using.split()[0]
|
||||
except (IndexError, AttributeError):
|
||||
return None
|
||||
@@ -189,23 +185,23 @@ class PgvectorClient(VectorDBBase):
|
||||
index_method,
|
||||
)
|
||||
elif USE_HALFVEC:
|
||||
index_method = "hnsw"
|
||||
index_method = 'hnsw'
|
||||
log.info(
|
||||
"VECTOR_LENGTH=%s exceeds 2000; using halfvec column type with hnsw index.",
|
||||
'VECTOR_LENGTH=%s exceeds 2000; using halfvec column type with hnsw index.',
|
||||
VECTOR_LENGTH,
|
||||
)
|
||||
else:
|
||||
index_method = "ivfflat"
|
||||
index_method = 'ivfflat'
|
||||
|
||||
if index_method == "hnsw":
|
||||
index_options = f"WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})"
|
||||
if index_method == 'hnsw':
|
||||
index_options = f'WITH (m = {PGVECTOR_HNSW_M}, ef_construction = {PGVECTOR_HNSW_EF_CONSTRUCTION})'
|
||||
else:
|
||||
index_options = f"WITH (lists = {PGVECTOR_IVFFLAT_LISTS})"
|
||||
index_options = f'WITH (lists = {PGVECTOR_IVFFLAT_LISTS})'
|
||||
|
||||
return index_method, index_options
|
||||
|
||||
def _ensure_vector_index(self, index_method: str, index_options: str) -> None:
|
||||
index_name = "idx_document_chunk_vector"
|
||||
index_name = 'idx_document_chunk_vector'
|
||||
existing_index_def = self.session.execute(
|
||||
text("""
|
||||
SELECT indexdef
|
||||
@@ -214,7 +210,7 @@ class PgvectorClient(VectorDBBase):
|
||||
AND tablename = 'document_chunk'
|
||||
AND indexname = :index_name
|
||||
"""),
|
||||
{"index_name": index_name},
|
||||
{'index_name': index_name},
|
||||
).scalar()
|
||||
|
||||
existing_method = self._extract_index_method(existing_index_def)
|
||||
@@ -222,23 +218,23 @@ class PgvectorClient(VectorDBBase):
|
||||
raise RuntimeError(
|
||||
f"Existing pgvector index '{index_name}' uses method '{existing_method}' but configuration now "
|
||||
f"requires '{index_method}'. Automatic rebuild is disabled to prevent long-running maintenance. "
|
||||
"Drop the index manually (optionally after tuning maintenance_work_mem/max_parallel_maintenance_workers) "
|
||||
"and recreate it with the new method before restarting Open WebUI."
|
||||
'Drop the index manually (optionally after tuning maintenance_work_mem/max_parallel_maintenance_workers) '
|
||||
'and recreate it with the new method before restarting Open WebUI.'
|
||||
)
|
||||
|
||||
if not existing_index_def:
|
||||
index_sql = (
|
||||
f"CREATE INDEX IF NOT EXISTS {index_name} "
|
||||
f"ON document_chunk USING {index_method} (vector {VECTOR_OPCLASS})"
|
||||
f'CREATE INDEX IF NOT EXISTS {index_name} '
|
||||
f'ON document_chunk USING {index_method} (vector {VECTOR_OPCLASS})'
|
||||
)
|
||||
if index_options:
|
||||
index_sql = f"{index_sql} {index_options}"
|
||||
index_sql = f'{index_sql} {index_options}'
|
||||
self.session.execute(text(index_sql))
|
||||
log.info(
|
||||
"Ensured vector index '%s' using %s%s.",
|
||||
index_name,
|
||||
index_method,
|
||||
f" {index_options}" if index_options else "",
|
||||
f' {index_options}' if index_options else '',
|
||||
)
|
||||
|
||||
def check_vector_length(self) -> None:
|
||||
@@ -249,16 +245,14 @@ class PgvectorClient(VectorDBBase):
|
||||
metadata = MetaData()
|
||||
try:
|
||||
# Attempt to reflect the 'document_chunk' table
|
||||
document_chunk_table = Table(
|
||||
"document_chunk", metadata, autoload_with=self.session.bind
|
||||
)
|
||||
document_chunk_table = Table('document_chunk', metadata, autoload_with=self.session.bind)
|
||||
except NoSuchTableError:
|
||||
# Table does not exist; no action needed
|
||||
return
|
||||
|
||||
# Proceed to check the vector column
|
||||
if "vector" in document_chunk_table.columns:
|
||||
vector_column = document_chunk_table.columns["vector"]
|
||||
if 'vector' in document_chunk_table.columns:
|
||||
vector_column = document_chunk_table.columns['vector']
|
||||
vector_type = vector_column.type
|
||||
expected_type = HALFVEC if USE_HALFVEC else Vector
|
||||
|
||||
@@ -268,16 +262,14 @@ class PgvectorClient(VectorDBBase):
|
||||
f"('{expected_type.__name__}') for VECTOR_LENGTH {VECTOR_LENGTH}."
|
||||
)
|
||||
|
||||
db_vector_length = getattr(vector_type, "dim", None)
|
||||
db_vector_length = getattr(vector_type, 'dim', None)
|
||||
if db_vector_length is not None and db_vector_length != VECTOR_LENGTH:
|
||||
raise Exception(
|
||||
f"VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. "
|
||||
"Cannot change vector size after initialization without migrating the data."
|
||||
f'VECTOR_LENGTH {VECTOR_LENGTH} does not match existing vector column dimension {db_vector_length}. '
|
||||
'Cannot change vector size after initialization without migrating the data.'
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
"The 'vector' column does not exist in the 'document_chunk' table."
|
||||
)
|
||||
raise Exception("The 'vector' column does not exist in the 'document_chunk' table.")
|
||||
|
||||
def adjust_vector_length(self, vector: List[float]) -> List[float]:
|
||||
# Adjust vector to have length VECTOR_LENGTH
|
||||
@@ -294,10 +286,10 @@ class PgvectorClient(VectorDBBase):
|
||||
try:
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
vector = self.adjust_vector_length(item['vector'])
|
||||
# Use raw SQL for BYTEA/pgcrypto
|
||||
# Ensure metadata is converted to its JSON text representation
|
||||
json_metadata = json.dumps(item["metadata"])
|
||||
json_metadata = json.dumps(item['metadata'])
|
||||
self.session.execute(
|
||||
text("""
|
||||
INSERT INTO document_chunk
|
||||
@@ -310,12 +302,12 @@ class PgvectorClient(VectorDBBase):
|
||||
ON CONFLICT (id) DO NOTHING
|
||||
"""),
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": vector,
|
||||
"collection_name": collection_name,
|
||||
"text": item["text"],
|
||||
"metadata_text": json_metadata,
|
||||
"key": PGVECTOR_PGCRYPTO_KEY,
|
||||
'id': item['id'],
|
||||
'vector': vector,
|
||||
'collection_name': collection_name,
|
||||
'text': item['text'],
|
||||
'metadata_text': json_metadata,
|
||||
'key': PGVECTOR_PGCRYPTO_KEY,
|
||||
},
|
||||
)
|
||||
self.session.commit()
|
||||
@@ -324,31 +316,29 @@ class PgvectorClient(VectorDBBase):
|
||||
else:
|
||||
new_items = []
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
vector = self.adjust_vector_length(item['vector'])
|
||||
new_chunk = DocumentChunk(
|
||||
id=item["id"],
|
||||
id=item['id'],
|
||||
vector=vector,
|
||||
collection_name=collection_name,
|
||||
text=item["text"],
|
||||
vmetadata=process_metadata(item["metadata"]),
|
||||
text=item['text'],
|
||||
vmetadata=process_metadata(item['metadata']),
|
||||
)
|
||||
new_items.append(new_chunk)
|
||||
self.session.bulk_save_objects(new_items)
|
||||
self.session.commit()
|
||||
log.info(
|
||||
f"Inserted {len(new_items)} items into collection '{collection_name}'."
|
||||
)
|
||||
log.info(f"Inserted {len(new_items)} items into collection '{collection_name}'.")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during insert: {e}")
|
||||
log.exception(f'Error during insert: {e}')
|
||||
raise
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
try:
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
json_metadata = json.dumps(item["metadata"])
|
||||
vector = self.adjust_vector_length(item['vector'])
|
||||
json_metadata = json.dumps(item['metadata'])
|
||||
self.session.execute(
|
||||
text("""
|
||||
INSERT INTO document_chunk
|
||||
@@ -365,47 +355,39 @@ class PgvectorClient(VectorDBBase):
|
||||
vmetadata = EXCLUDED.vmetadata
|
||||
"""),
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": vector,
|
||||
"collection_name": collection_name,
|
||||
"text": item["text"],
|
||||
"metadata_text": json_metadata,
|
||||
"key": PGVECTOR_PGCRYPTO_KEY,
|
||||
'id': item['id'],
|
||||
'vector': vector,
|
||||
'collection_name': collection_name,
|
||||
'text': item['text'],
|
||||
'metadata_text': json_metadata,
|
||||
'key': PGVECTOR_PGCRYPTO_KEY,
|
||||
},
|
||||
)
|
||||
self.session.commit()
|
||||
log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'")
|
||||
else:
|
||||
for item in items:
|
||||
vector = self.adjust_vector_length(item["vector"])
|
||||
existing = (
|
||||
self.session.query(DocumentChunk)
|
||||
.filter(DocumentChunk.id == item["id"])
|
||||
.first()
|
||||
)
|
||||
vector = self.adjust_vector_length(item['vector'])
|
||||
existing = self.session.query(DocumentChunk).filter(DocumentChunk.id == item['id']).first()
|
||||
if existing:
|
||||
existing.vector = vector
|
||||
existing.text = item["text"]
|
||||
existing.vmetadata = process_metadata(item["metadata"])
|
||||
existing.collection_name = (
|
||||
collection_name # Update collection_name if necessary
|
||||
)
|
||||
existing.text = item['text']
|
||||
existing.vmetadata = process_metadata(item['metadata'])
|
||||
existing.collection_name = collection_name # Update collection_name if necessary
|
||||
else:
|
||||
new_chunk = DocumentChunk(
|
||||
id=item["id"],
|
||||
id=item['id'],
|
||||
vector=vector,
|
||||
collection_name=collection_name,
|
||||
text=item["text"],
|
||||
vmetadata=process_metadata(item["metadata"]),
|
||||
text=item['text'],
|
||||
vmetadata=process_metadata(item['metadata']),
|
||||
)
|
||||
self.session.add(new_chunk)
|
||||
self.session.commit()
|
||||
log.info(
|
||||
f"Upserted {len(items)} items into collection '{collection_name}'."
|
||||
)
|
||||
log.info(f"Upserted {len(items)} items into collection '{collection_name}'.")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during upsert: {e}")
|
||||
log.exception(f'Error during upsert: {e}')
|
||||
raise
|
||||
|
||||
def search(
|
||||
@@ -427,38 +409,26 @@ class PgvectorClient(VectorDBBase):
|
||||
return cast(array(vector), VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
|
||||
|
||||
# Create the values for query vectors
|
||||
qid_col = column("qid", Integer)
|
||||
q_vector_col = column("q_vector", VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
|
||||
qid_col = column('qid', Integer)
|
||||
q_vector_col = column('q_vector', VECTOR_TYPE_FACTORY(VECTOR_LENGTH))
|
||||
query_vectors = (
|
||||
values(qid_col, q_vector_col)
|
||||
.data(
|
||||
[(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)]
|
||||
)
|
||||
.alias("query_vectors")
|
||||
.data([(idx, vector_expr(vector)) for idx, vector in enumerate(vectors)])
|
||||
.alias('query_vectors')
|
||||
)
|
||||
|
||||
result_fields = [
|
||||
DocumentChunk.id,
|
||||
]
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
result_fields.append(pgcrypto_decrypt(DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text).label('text'))
|
||||
result_fields.append(
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
|
||||
).label("text")
|
||||
)
|
||||
result_fields.append(
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||
).label("vmetadata")
|
||||
pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB).label('vmetadata')
|
||||
)
|
||||
else:
|
||||
result_fields.append(DocumentChunk.text)
|
||||
result_fields.append(DocumentChunk.vmetadata)
|
||||
result_fields.append(
|
||||
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label(
|
||||
"distance"
|
||||
)
|
||||
)
|
||||
result_fields.append((DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label('distance'))
|
||||
|
||||
# Build the lateral subquery for each query vector
|
||||
where_clauses = [DocumentChunk.collection_name == collection_name]
|
||||
@@ -466,9 +436,9 @@ class PgvectorClient(VectorDBBase):
|
||||
# Apply metadata filter if provided
|
||||
if filter:
|
||||
for key, value in filter.items():
|
||||
if isinstance(value, dict) and "$in" in value:
|
||||
if isinstance(value, dict) and '$in' in value:
|
||||
# Handle $in operator: {"field": {"$in": [values]}}
|
||||
in_values = value["$in"]
|
||||
in_values = value['$in']
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
where_clauses.append(
|
||||
pgcrypto_decrypt(
|
||||
@@ -478,11 +448,7 @@ class PgvectorClient(VectorDBBase):
|
||||
)[key].astext.in_([str(v) for v in in_values])
|
||||
)
|
||||
else:
|
||||
where_clauses.append(
|
||||
DocumentChunk.vmetadata[key].astext.in_(
|
||||
[str(v) for v in in_values]
|
||||
)
|
||||
)
|
||||
where_clauses.append(DocumentChunk.vmetadata[key].astext.in_([str(v) for v in in_values]))
|
||||
else:
|
||||
# Handle simple equality: {"field": "value"}
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
@@ -495,20 +461,16 @@ class PgvectorClient(VectorDBBase):
|
||||
== str(value)
|
||||
)
|
||||
else:
|
||||
where_clauses.append(
|
||||
DocumentChunk.vmetadata[key].astext == str(value)
|
||||
)
|
||||
where_clauses.append(DocumentChunk.vmetadata[key].astext == str(value))
|
||||
|
||||
subq = (
|
||||
select(*result_fields)
|
||||
.where(*where_clauses)
|
||||
.order_by(
|
||||
(DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector))
|
||||
)
|
||||
.order_by((DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)))
|
||||
)
|
||||
if limit is not None:
|
||||
subq = subq.limit(limit)
|
||||
subq = subq.lateral("result")
|
||||
subq = subq.lateral('result')
|
||||
|
||||
# Build the main query by joining query_vectors and the lateral subquery
|
||||
stmt = (
|
||||
@@ -550,17 +512,13 @@ class PgvectorClient(VectorDBBase):
|
||||
metadatas[qid].append(row.vmetadata)
|
||||
|
||||
self.session.rollback() # read-only transaction
|
||||
return SearchResult(
|
||||
ids=ids, distances=distances, documents=documents, metadatas=metadatas
|
||||
)
|
||||
return SearchResult(ids=ids, distances=distances, documents=documents, metadatas=metadatas)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during search: {e}")
|
||||
log.exception(f'Error during search: {e}')
|
||||
return None
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
try:
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
# Build where clause for vmetadata filter
|
||||
@@ -568,32 +526,22 @@ class PgvectorClient(VectorDBBase):
|
||||
for key, value in filter.items():
|
||||
# decrypt then check key: JSON filter after decryption
|
||||
where_clauses.append(
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||
)[key].astext
|
||||
pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB)[key].astext
|
||||
== str(value)
|
||||
)
|
||||
stmt = select(
|
||||
DocumentChunk.id,
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
|
||||
).label("text"),
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||
).label("vmetadata"),
|
||||
pgcrypto_decrypt(DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text).label('text'),
|
||||
pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB).label('vmetadata'),
|
||||
).where(*where_clauses)
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
results = self.session.execute(stmt).all()
|
||||
else:
|
||||
query = self.session.query(DocumentChunk).filter(
|
||||
DocumentChunk.collection_name == collection_name
|
||||
)
|
||||
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
|
||||
|
||||
for key, value in filter.items():
|
||||
query = query.filter(
|
||||
DocumentChunk.vmetadata[key].astext == str(value)
|
||||
)
|
||||
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
|
||||
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
@@ -615,22 +563,16 @@ class PgvectorClient(VectorDBBase):
|
||||
)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during query: {e}")
|
||||
log.exception(f'Error during query: {e}')
|
||||
return None
|
||||
|
||||
def get(
|
||||
self, collection_name: str, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def get(self, collection_name: str, limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
try:
|
||||
if PGVECTOR_PGCRYPTO:
|
||||
stmt = select(
|
||||
DocumentChunk.id,
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text
|
||||
).label("text"),
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||
).label("vmetadata"),
|
||||
pgcrypto_decrypt(DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text).label('text'),
|
||||
pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB).label('vmetadata'),
|
||||
).where(DocumentChunk.collection_name == collection_name)
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
@@ -639,10 +581,7 @@ class PgvectorClient(VectorDBBase):
|
||||
documents = [[row.text for row in results]]
|
||||
metadatas = [[row.vmetadata for row in results]]
|
||||
else:
|
||||
|
||||
query = self.session.query(DocumentChunk).filter(
|
||||
DocumentChunk.collection_name == collection_name
|
||||
)
|
||||
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
||||
@@ -659,7 +598,7 @@ class PgvectorClient(VectorDBBase):
|
||||
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during get: {e}")
|
||||
log.exception(f'Error during get: {e}')
|
||||
return None
|
||||
|
||||
def delete(
|
||||
@@ -676,43 +615,35 @@ class PgvectorClient(VectorDBBase):
|
||||
if filter:
|
||||
for key, value in filter.items():
|
||||
wheres.append(
|
||||
pgcrypto_decrypt(
|
||||
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB
|
||||
)[key].astext
|
||||
pgcrypto_decrypt(DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB)[key].astext
|
||||
== str(value)
|
||||
)
|
||||
stmt = DocumentChunk.__table__.delete().where(*wheres)
|
||||
result = self.session.execute(stmt)
|
||||
deleted = result.rowcount
|
||||
else:
|
||||
query = self.session.query(DocumentChunk).filter(
|
||||
DocumentChunk.collection_name == collection_name
|
||||
)
|
||||
query = self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name)
|
||||
if ids:
|
||||
query = query.filter(DocumentChunk.id.in_(ids))
|
||||
if filter:
|
||||
for key, value in filter.items():
|
||||
query = query.filter(
|
||||
DocumentChunk.vmetadata[key].astext == str(value)
|
||||
)
|
||||
query = query.filter(DocumentChunk.vmetadata[key].astext == str(value))
|
||||
deleted = query.delete(synchronize_session=False)
|
||||
self.session.commit()
|
||||
log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during delete: {e}")
|
||||
log.exception(f'Error during delete: {e}')
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
deleted = self.session.query(DocumentChunk).delete()
|
||||
self.session.commit()
|
||||
log.info(
|
||||
f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
|
||||
)
|
||||
log.info(f"Reset complete. Deleted {deleted} items from 'document_chunk' table.")
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error during reset: {e}")
|
||||
log.exception(f'Error during reset: {e}')
|
||||
raise
|
||||
|
||||
def close(self) -> None:
|
||||
@@ -721,16 +652,14 @@ class PgvectorClient(VectorDBBase):
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
try:
|
||||
exists = (
|
||||
self.session.query(DocumentChunk)
|
||||
.filter(DocumentChunk.collection_name == collection_name)
|
||||
.first()
|
||||
self.session.query(DocumentChunk).filter(DocumentChunk.collection_name == collection_name).first()
|
||||
is not None
|
||||
)
|
||||
self.session.rollback() # read-only transaction
|
||||
return exists
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
log.exception(f"Error checking collection existence: {e}")
|
||||
log.exception(f'Error checking collection existence: {e}')
|
||||
return False
|
||||
|
||||
def delete_collection(self, collection_name: str) -> None:
|
||||
|
||||
@@ -45,7 +45,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
class PineconeClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
self.collection_prefix = "open-webui"
|
||||
self.collection_prefix = 'open-webui'
|
||||
|
||||
# Validate required configuration
|
||||
self._validate_config()
|
||||
@@ -67,7 +67,7 @@ class PineconeClient(VectorDBBase):
|
||||
timeout=30, # Reasonable timeout for operations
|
||||
)
|
||||
self.using_grpc = True
|
||||
log.info("Using Pinecone gRPC client for optimal performance")
|
||||
log.info('Using Pinecone gRPC client for optimal performance')
|
||||
else:
|
||||
# Fallback to HTTP client with enhanced connection pooling
|
||||
self.client = Pinecone(
|
||||
@@ -76,7 +76,7 @@ class PineconeClient(VectorDBBase):
|
||||
timeout=30, # Reasonable timeout for operations
|
||||
)
|
||||
self.using_grpc = False
|
||||
log.info("Using Pinecone HTTP client (gRPC not available)")
|
||||
log.info('Using Pinecone HTTP client (gRPC not available)')
|
||||
|
||||
# Persistent executor for batch operations
|
||||
self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
|
||||
@@ -88,20 +88,18 @@ class PineconeClient(VectorDBBase):
|
||||
"""Validate that all required configuration variables are set."""
|
||||
missing_vars = []
|
||||
if not PINECONE_API_KEY:
|
||||
missing_vars.append("PINECONE_API_KEY")
|
||||
missing_vars.append('PINECONE_API_KEY')
|
||||
if not PINECONE_ENVIRONMENT:
|
||||
missing_vars.append("PINECONE_ENVIRONMENT")
|
||||
missing_vars.append('PINECONE_ENVIRONMENT')
|
||||
if not PINECONE_INDEX_NAME:
|
||||
missing_vars.append("PINECONE_INDEX_NAME")
|
||||
missing_vars.append('PINECONE_INDEX_NAME')
|
||||
if not PINECONE_DIMENSION:
|
||||
missing_vars.append("PINECONE_DIMENSION")
|
||||
missing_vars.append('PINECONE_DIMENSION')
|
||||
if not PINECONE_CLOUD:
|
||||
missing_vars.append("PINECONE_CLOUD")
|
||||
missing_vars.append('PINECONE_CLOUD')
|
||||
|
||||
if missing_vars:
|
||||
raise ValueError(
|
||||
f"Required configuration missing: {', '.join(missing_vars)}"
|
||||
)
|
||||
raise ValueError(f'Required configuration missing: {", ".join(missing_vars)}')
|
||||
|
||||
def _initialize_index(self) -> None:
|
||||
"""Initialize the Pinecone index."""
|
||||
@@ -126,8 +124,8 @@ class PineconeClient(VectorDBBase):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Failed to initialize Pinecone index: {e}")
|
||||
raise RuntimeError(f"Failed to initialize Pinecone index: {e}")
|
||||
log.error(f'Failed to initialize Pinecone index: {e}')
|
||||
raise RuntimeError(f'Failed to initialize Pinecone index: {e}')
|
||||
|
||||
def _retry_pinecone_operation(self, operation_func, max_retries=3):
|
||||
"""Retry Pinecone operations with exponential backoff for rate limits and network issues."""
|
||||
@@ -140,18 +138,18 @@ class PineconeClient(VectorDBBase):
|
||||
is_retryable = any(
|
||||
keyword in error_str
|
||||
for keyword in [
|
||||
"rate limit",
|
||||
"quota",
|
||||
"timeout",
|
||||
"network",
|
||||
"connection",
|
||||
"unavailable",
|
||||
"internal error",
|
||||
"429",
|
||||
"500",
|
||||
"502",
|
||||
"503",
|
||||
"504",
|
||||
'rate limit',
|
||||
'quota',
|
||||
'timeout',
|
||||
'network',
|
||||
'connection',
|
||||
'unavailable',
|
||||
'internal error',
|
||||
'429',
|
||||
'500',
|
||||
'502',
|
||||
'503',
|
||||
'504',
|
||||
]
|
||||
)
|
||||
|
||||
@@ -162,45 +160,42 @@ class PineconeClient(VectorDBBase):
|
||||
# Exponential backoff with jitter
|
||||
delay = (2**attempt) + random.uniform(0, 1)
|
||||
log.warning(
|
||||
f"Pinecone operation failed (attempt {attempt + 1}/{max_retries}), "
|
||||
f"retrying in {delay:.2f}s: {e}"
|
||||
f'Pinecone operation failed (attempt {attempt + 1}/{max_retries}), retrying in {delay:.2f}s: {e}'
|
||||
)
|
||||
time.sleep(delay)
|
||||
|
||||
def _create_points(
|
||||
self, items: List[VectorItem], collection_name_with_prefix: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
def _create_points(self, items: List[VectorItem], collection_name_with_prefix: str) -> List[Dict[str, Any]]:
|
||||
"""Convert VectorItem objects to Pinecone point format."""
|
||||
points = []
|
||||
for item in items:
|
||||
# Start with any existing metadata or an empty dict
|
||||
metadata = item.get("metadata", {}).copy() if item.get("metadata") else {}
|
||||
metadata = item.get('metadata', {}).copy() if item.get('metadata') else {}
|
||||
|
||||
# Add text to metadata if available
|
||||
if "text" in item:
|
||||
metadata["text"] = item["text"]
|
||||
if 'text' in item:
|
||||
metadata['text'] = item['text']
|
||||
|
||||
# Always add collection_name to metadata for filtering
|
||||
metadata["collection_name"] = collection_name_with_prefix
|
||||
metadata['collection_name'] = collection_name_with_prefix
|
||||
|
||||
point = {
|
||||
"id": item["id"],
|
||||
"values": item["vector"],
|
||||
"metadata": process_metadata(metadata),
|
||||
'id': item['id'],
|
||||
'values': item['vector'],
|
||||
'metadata': process_metadata(metadata),
|
||||
}
|
||||
points.append(point)
|
||||
return points
|
||||
|
||||
def _get_collection_name_with_prefix(self, collection_name: str) -> str:
|
||||
"""Get the collection name with prefix."""
|
||||
return f"{self.collection_prefix}_{collection_name}"
|
||||
return f'{self.collection_prefix}_{collection_name}'
|
||||
|
||||
def _normalize_distance(self, score: float) -> float:
|
||||
"""Normalize distance score based on the metric used."""
|
||||
if self.metric.lower() == "cosine":
|
||||
if self.metric.lower() == 'cosine':
|
||||
# Cosine similarity ranges from -1 to 1, normalize to 0 to 1
|
||||
return (score + 1.0) / 2.0
|
||||
elif self.metric.lower() in ["euclidean", "dotproduct"]:
|
||||
elif self.metric.lower() in ['euclidean', 'dotproduct']:
|
||||
# These are already suitable for ranking (smaller is better for Euclidean)
|
||||
return score
|
||||
else:
|
||||
@@ -214,68 +209,56 @@ class PineconeClient(VectorDBBase):
|
||||
metadatas = []
|
||||
|
||||
for match in matches:
|
||||
metadata = getattr(match, "metadata", {}) or {}
|
||||
ids.append(match.id if hasattr(match, "id") else match["id"])
|
||||
documents.append(metadata.get("text", ""))
|
||||
metadata = getattr(match, 'metadata', {}) or {}
|
||||
ids.append(match.id if hasattr(match, 'id') else match['id'])
|
||||
documents.append(metadata.get('text', ''))
|
||||
metadatas.append(metadata)
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [ids],
|
||||
"documents": [documents],
|
||||
"metadatas": [metadatas],
|
||||
'ids': [ids],
|
||||
'documents': [documents],
|
||||
'metadatas': [metadatas],
|
||||
}
|
||||
)
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
"""Check if a collection exists by searching for at least one item."""
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(
|
||||
collection_name
|
||||
)
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
|
||||
|
||||
try:
|
||||
# Search for at least 1 item with this collection name in metadata
|
||||
response = self.index.query(
|
||||
vector=[0.0] * self.dimension, # dummy vector
|
||||
top_k=1,
|
||||
filter={"collection_name": collection_name_with_prefix},
|
||||
filter={'collection_name': collection_name_with_prefix},
|
||||
include_metadata=False,
|
||||
)
|
||||
matches = getattr(response, "matches", []) or []
|
||||
matches = getattr(response, 'matches', []) or []
|
||||
return len(matches) > 0
|
||||
except Exception as e:
|
||||
log.exception(
|
||||
f"Error checking collection '{collection_name_with_prefix}': {e}"
|
||||
)
|
||||
log.exception(f"Error checking collection '{collection_name_with_prefix}': {e}")
|
||||
return False
|
||||
|
||||
def delete_collection(self, collection_name: str) -> None:
|
||||
"""Delete a collection by removing all vectors with the collection name in metadata."""
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(
|
||||
collection_name
|
||||
)
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
|
||||
try:
|
||||
self.index.delete(filter={"collection_name": collection_name_with_prefix})
|
||||
log.info(
|
||||
f"Collection '{collection_name_with_prefix}' deleted (all vectors removed)."
|
||||
)
|
||||
self.index.delete(filter={'collection_name': collection_name_with_prefix})
|
||||
log.info(f"Collection '{collection_name_with_prefix}' deleted (all vectors removed).")
|
||||
except Exception as e:
|
||||
log.warning(
|
||||
f"Failed to delete collection '{collection_name_with_prefix}': {e}"
|
||||
)
|
||||
log.warning(f"Failed to delete collection '{collection_name_with_prefix}': {e}")
|
||||
raise
|
||||
|
||||
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
"""Insert vectors into a collection."""
|
||||
if not items:
|
||||
log.warning("No items to insert")
|
||||
log.warning('No items to insert')
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(
|
||||
collection_name
|
||||
)
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
|
||||
points = self._create_points(items, collection_name_with_prefix)
|
||||
|
||||
# Parallelize batch inserts for performance
|
||||
@@ -288,26 +271,23 @@ class PineconeClient(VectorDBBase):
|
||||
try:
|
||||
future.result()
|
||||
except Exception as e:
|
||||
log.error(f"Error inserting batch: {e}")
|
||||
log.error(f'Error inserting batch: {e}')
|
||||
raise
|
||||
elapsed = time.time() - start_time
|
||||
log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds")
|
||||
log.debug(f'Insert of {len(points)} vectors took {elapsed:.2f} seconds')
|
||||
log.info(
|
||||
f"Successfully inserted {len(points)} vectors in parallel batches "
|
||||
f"into '{collection_name_with_prefix}'"
|
||||
f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
|
||||
)
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
"""Upsert (insert or update) vectors into a collection."""
|
||||
if not items:
|
||||
log.warning("No items to upsert")
|
||||
log.warning('No items to upsert')
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(
|
||||
collection_name
|
||||
)
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
|
||||
points = self._create_points(items, collection_name_with_prefix)
|
||||
|
||||
# Parallelize batch upserts for performance
|
||||
@@ -320,78 +300,53 @@ class PineconeClient(VectorDBBase):
|
||||
try:
|
||||
future.result()
|
||||
except Exception as e:
|
||||
log.error(f"Error upserting batch: {e}")
|
||||
log.error(f'Error upserting batch: {e}')
|
||||
raise
|
||||
elapsed = time.time() - start_time
|
||||
log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds")
|
||||
log.debug(f'Upsert of {len(points)} vectors took {elapsed:.2f} seconds')
|
||||
log.info(
|
||||
f"Successfully upserted {len(points)} vectors in parallel batches "
|
||||
f"into '{collection_name_with_prefix}'"
|
||||
f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'"
|
||||
)
|
||||
|
||||
async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
"""Async version of insert using asyncio and run_in_executor for improved performance."""
|
||||
if not items:
|
||||
log.warning("No items to insert")
|
||||
log.warning('No items to insert')
|
||||
return
|
||||
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(
|
||||
collection_name
|
||||
)
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
|
||||
points = self._create_points(items, collection_name_with_prefix)
|
||||
|
||||
# Create batches
|
||||
batches = [
|
||||
points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
|
||||
]
|
||||
batches = [points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)]
|
||||
loop = asyncio.get_event_loop()
|
||||
tasks = [
|
||||
loop.run_in_executor(
|
||||
None, functools.partial(self.index.upsert, vectors=batch)
|
||||
)
|
||||
for batch in batches
|
||||
]
|
||||
tasks = [loop.run_in_executor(None, functools.partial(self.index.upsert, vectors=batch)) for batch in batches]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
log.error(f"Error in async insert batch: {result}")
|
||||
log.error(f'Error in async insert batch: {result}')
|
||||
raise result
|
||||
log.info(
|
||||
f"Successfully async inserted {len(points)} vectors in batches "
|
||||
f"into '{collection_name_with_prefix}'"
|
||||
)
|
||||
log.info(f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'")
|
||||
|
||||
async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
"""Async version of upsert using asyncio and run_in_executor for improved performance."""
|
||||
if not items:
|
||||
log.warning("No items to upsert")
|
||||
log.warning('No items to upsert')
|
||||
return
|
||||
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(
|
||||
collection_name
|
||||
)
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
|
||||
points = self._create_points(items, collection_name_with_prefix)
|
||||
|
||||
# Create batches
|
||||
batches = [
|
||||
points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)
|
||||
]
|
||||
batches = [points[i : i + BATCH_SIZE] for i in range(0, len(points), BATCH_SIZE)]
|
||||
loop = asyncio.get_event_loop()
|
||||
tasks = [
|
||||
loop.run_in_executor(
|
||||
None, functools.partial(self.index.upsert, vectors=batch)
|
||||
)
|
||||
for batch in batches
|
||||
]
|
||||
tasks = [loop.run_in_executor(None, functools.partial(self.index.upsert, vectors=batch)) for batch in batches]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
log.error(f"Error in async upsert batch: {result}")
|
||||
log.error(f'Error in async upsert batch: {result}')
|
||||
raise result
|
||||
log.info(
|
||||
f"Successfully async upserted {len(points)} vectors in batches "
|
||||
f"into '{collection_name_with_prefix}'"
|
||||
)
|
||||
log.info(f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'")
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -402,12 +357,10 @@ class PineconeClient(VectorDBBase):
|
||||
) -> Optional[SearchResult]:
|
||||
"""Search for similar vectors in a collection."""
|
||||
if not vectors or not vectors[0]:
|
||||
log.warning("No vectors provided for search")
|
||||
log.warning('No vectors provided for search')
|
||||
return None
|
||||
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(
|
||||
collection_name
|
||||
)
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
|
||||
|
||||
if limit is None or limit <= 0:
|
||||
limit = NO_LIMIT
|
||||
@@ -421,10 +374,10 @@ class PineconeClient(VectorDBBase):
|
||||
vector=query_vector,
|
||||
top_k=limit,
|
||||
include_metadata=True,
|
||||
filter={"collection_name": collection_name_with_prefix},
|
||||
filter={'collection_name': collection_name_with_prefix},
|
||||
)
|
||||
|
||||
matches = getattr(query_response, "matches", []) or []
|
||||
matches = getattr(query_response, 'matches', []) or []
|
||||
if not matches:
|
||||
# Return empty result if no matches
|
||||
return SearchResult(
|
||||
@@ -438,12 +391,7 @@ class PineconeClient(VectorDBBase):
|
||||
get_result = self._result_to_get_result(matches)
|
||||
|
||||
# Calculate normalized distances based on metric
|
||||
distances = [
|
||||
[
|
||||
self._normalize_distance(getattr(match, "score", 0.0))
|
||||
for match in matches
|
||||
]
|
||||
]
|
||||
distances = [[self._normalize_distance(getattr(match, 'score', 0.0)) for match in matches]]
|
||||
|
||||
return SearchResult(
|
||||
ids=get_result.ids,
|
||||
@@ -455,13 +403,9 @@ class PineconeClient(VectorDBBase):
|
||||
log.error(f"Error searching in '{collection_name_with_prefix}': {e}")
|
||||
return None
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: Dict, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
"""Query vectors by metadata filter."""
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(
|
||||
collection_name
|
||||
)
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
|
||||
|
||||
if limit is None or limit <= 0:
|
||||
limit = NO_LIMIT
|
||||
@@ -471,7 +415,7 @@ class PineconeClient(VectorDBBase):
|
||||
zero_vector = [0.0] * self.dimension
|
||||
|
||||
# Combine user filter with collection_name
|
||||
pinecone_filter = {"collection_name": collection_name_with_prefix}
|
||||
pinecone_filter = {'collection_name': collection_name_with_prefix}
|
||||
if filter:
|
||||
pinecone_filter.update(filter)
|
||||
|
||||
@@ -483,7 +427,7 @@ class PineconeClient(VectorDBBase):
|
||||
include_metadata=True,
|
||||
)
|
||||
|
||||
matches = getattr(query_response, "matches", []) or []
|
||||
matches = getattr(query_response, 'matches', []) or []
|
||||
return self._result_to_get_result(matches)
|
||||
|
||||
except Exception as e:
|
||||
@@ -492,9 +436,7 @@ class PineconeClient(VectorDBBase):
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
"""Get all vectors in a collection."""
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(
|
||||
collection_name
|
||||
)
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
|
||||
|
||||
try:
|
||||
# Use a zero vector for fetching all entries
|
||||
@@ -505,10 +447,10 @@ class PineconeClient(VectorDBBase):
|
||||
vector=zero_vector,
|
||||
top_k=NO_LIMIT,
|
||||
include_metadata=True,
|
||||
filter={"collection_name": collection_name_with_prefix},
|
||||
filter={'collection_name': collection_name_with_prefix},
|
||||
)
|
||||
|
||||
matches = getattr(query_response, "matches", []) or []
|
||||
matches = getattr(query_response, 'matches', []) or []
|
||||
return self._result_to_get_result(matches)
|
||||
|
||||
except Exception as e:
|
||||
@@ -522,9 +464,7 @@ class PineconeClient(VectorDBBase):
|
||||
filter: Optional[Dict] = None,
|
||||
) -> None:
|
||||
"""Delete vectors by IDs or filter."""
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(
|
||||
collection_name
|
||||
)
|
||||
collection_name_with_prefix = self._get_collection_name_with_prefix(collection_name)
|
||||
|
||||
try:
|
||||
if ids:
|
||||
@@ -534,28 +474,20 @@ class PineconeClient(VectorDBBase):
|
||||
# Note: When deleting by ID, we can't filter by collection_name
|
||||
# This is a limitation of Pinecone - be careful with ID uniqueness
|
||||
self.index.delete(ids=batch_ids)
|
||||
log.debug(
|
||||
f"Deleted batch of {len(batch_ids)} vectors by ID "
|
||||
f"from '{collection_name_with_prefix}'"
|
||||
)
|
||||
log.info(
|
||||
f"Successfully deleted {len(ids)} vectors by ID "
|
||||
f"from '{collection_name_with_prefix}'"
|
||||
)
|
||||
log.debug(f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'")
|
||||
log.info(f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'")
|
||||
|
||||
elif filter:
|
||||
# Combine user filter with collection_name
|
||||
pinecone_filter = {"collection_name": collection_name_with_prefix}
|
||||
pinecone_filter = {'collection_name': collection_name_with_prefix}
|
||||
if filter:
|
||||
pinecone_filter.update(filter)
|
||||
# Delete by metadata filter
|
||||
self.index.delete(filter=pinecone_filter)
|
||||
log.info(
|
||||
f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'"
|
||||
)
|
||||
log.info(f"Successfully deleted vectors by filter from '{collection_name_with_prefix}'")
|
||||
|
||||
else:
|
||||
log.warning("No ids or filter provided for delete operation")
|
||||
log.warning('No ids or filter provided for delete operation')
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting from collection '{collection_name}': {e}")
|
||||
@@ -565,9 +497,9 @@ class PineconeClient(VectorDBBase):
|
||||
"""Reset the database by deleting all collections."""
|
||||
try:
|
||||
self.index.delete(delete_all=True)
|
||||
log.info("All vectors successfully deleted from the index.")
|
||||
log.info('All vectors successfully deleted from the index.')
|
||||
except Exception as e:
|
||||
log.error(f"Failed to reset Pinecone index: {e}")
|
||||
log.error(f'Failed to reset Pinecone index: {e}')
|
||||
raise
|
||||
|
||||
def close(self):
|
||||
@@ -576,7 +508,7 @@ class PineconeClient(VectorDBBase):
|
||||
# The new Pinecone client doesn't need explicit closing
|
||||
pass
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to clean up Pinecone resources: {e}")
|
||||
log.warning(f'Failed to clean up Pinecone resources: {e}')
|
||||
self._executor.shutdown(wait=True)
|
||||
|
||||
def __enter__(self):
|
||||
|
||||
@@ -76,19 +76,19 @@ class QdrantClient(VectorDBBase):
|
||||
for point in points:
|
||||
payload = point.payload
|
||||
ids.append(point.id)
|
||||
documents.append(payload["text"])
|
||||
metadatas.append(payload["metadata"])
|
||||
documents.append(payload['text'])
|
||||
metadatas.append(payload['metadata'])
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [ids],
|
||||
"documents": [documents],
|
||||
"metadatas": [metadatas],
|
||||
'ids': [ids],
|
||||
'documents': [documents],
|
||||
'metadatas': [metadatas],
|
||||
}
|
||||
)
|
||||
|
||||
def _create_collection(self, collection_name: str, dimension: int):
|
||||
collection_name_with_prefix = f"{self.collection_prefix}_{collection_name}"
|
||||
collection_name_with_prefix = f'{self.collection_prefix}_{collection_name}'
|
||||
self.client.create_collection(
|
||||
collection_name=collection_name_with_prefix,
|
||||
vectors_config=models.VectorParams(
|
||||
@@ -104,7 +104,7 @@ class QdrantClient(VectorDBBase):
|
||||
# Create payload indexes for efficient filtering
|
||||
self.client.create_payload_index(
|
||||
collection_name=collection_name_with_prefix,
|
||||
field_name="metadata.hash",
|
||||
field_name='metadata.hash',
|
||||
field_schema=models.KeywordIndexParams(
|
||||
type=models.KeywordIndexType.KEYWORD,
|
||||
is_tenant=False,
|
||||
@@ -113,40 +113,34 @@ class QdrantClient(VectorDBBase):
|
||||
)
|
||||
self.client.create_payload_index(
|
||||
collection_name=collection_name_with_prefix,
|
||||
field_name="metadata.file_id",
|
||||
field_name='metadata.file_id',
|
||||
field_schema=models.KeywordIndexParams(
|
||||
type=models.KeywordIndexType.KEYWORD,
|
||||
is_tenant=False,
|
||||
on_disk=self.QDRANT_ON_DISK,
|
||||
),
|
||||
)
|
||||
log.info(f"collection {collection_name_with_prefix} successfully created!")
|
||||
log.info(f'collection {collection_name_with_prefix} successfully created!')
|
||||
|
||||
def _create_collection_if_not_exists(self, collection_name, dimension):
|
||||
if not self.has_collection(collection_name=collection_name):
|
||||
self._create_collection(
|
||||
collection_name=collection_name, dimension=dimension
|
||||
)
|
||||
self._create_collection(collection_name=collection_name, dimension=dimension)
|
||||
|
||||
def _create_points(self, items: list[VectorItem]):
|
||||
return [
|
||||
PointStruct(
|
||||
id=item["id"],
|
||||
vector=item["vector"],
|
||||
payload={"text": item["text"], "metadata": item["metadata"]},
|
||||
id=item['id'],
|
||||
vector=item['vector'],
|
||||
payload={'text': item['text'], 'metadata': item['metadata']},
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
return self.client.collection_exists(
|
||||
f"{self.collection_prefix}_{collection_name}"
|
||||
)
|
||||
return self.client.collection_exists(f'{self.collection_prefix}_{collection_name}')
|
||||
|
||||
def delete_collection(self, collection_name: str):
|
||||
return self.client.delete_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
)
|
||||
return self.client.delete_collection(collection_name=f'{self.collection_prefix}_{collection_name}')
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -160,7 +154,7 @@ class QdrantClient(VectorDBBase):
|
||||
limit = NO_LIMIT # otherwise qdrant would set limit to 10!
|
||||
|
||||
query_response = self.client.query_points(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||
query=vectors[0],
|
||||
limit=limit,
|
||||
)
|
||||
@@ -184,13 +178,11 @@ class QdrantClient(VectorDBBase):
|
||||
field_conditions = []
|
||||
for key, value in filter.items():
|
||||
field_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}", match=models.MatchValue(value=value)
|
||||
)
|
||||
models.FieldCondition(key=f'metadata.{key}', match=models.MatchValue(value=value))
|
||||
)
|
||||
|
||||
points = self.client.scroll(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||
scroll_filter=models.Filter(should=field_conditions),
|
||||
limit=limit,
|
||||
)
|
||||
@@ -202,22 +194,22 @@ class QdrantClient(VectorDBBase):
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
# Get all the items in the collection.
|
||||
points = self.client.scroll(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||
limit=NO_LIMIT, # otherwise qdrant would set limit to 10!
|
||||
)
|
||||
return self._result_to_get_result(points[0])
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Insert the items into the collection, if the collection does not exist, it will be created.
|
||||
self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
|
||||
self._create_collection_if_not_exists(collection_name, len(items[0]['vector']))
|
||||
points = self._create_points(items)
|
||||
self.client.upload_points(f"{self.collection_prefix}_{collection_name}", points)
|
||||
self.client.upload_points(f'{self.collection_prefix}_{collection_name}', points)
|
||||
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
|
||||
self._create_collection_if_not_exists(collection_name, len(items[0]["vector"]))
|
||||
self._create_collection_if_not_exists(collection_name, len(items[0]['vector']))
|
||||
points = self._create_points(items)
|
||||
return self.client.upsert(f"{self.collection_prefix}_{collection_name}", points)
|
||||
return self.client.upsert(f'{self.collection_prefix}_{collection_name}', points)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
@@ -230,26 +222,28 @@ class QdrantClient(VectorDBBase):
|
||||
|
||||
if ids:
|
||||
for id_value in ids:
|
||||
field_conditions.append(
|
||||
models.FieldCondition(
|
||||
key="metadata.id",
|
||||
match=models.MatchValue(value=id_value),
|
||||
(
|
||||
field_conditions.append(
|
||||
models.FieldCondition(
|
||||
key='metadata.id',
|
||||
match=models.MatchValue(value=id_value),
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
elif filter:
|
||||
for key, value in filter.items():
|
||||
field_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=f"metadata.{key}",
|
||||
match=models.MatchValue(value=value),
|
||||
(
|
||||
field_conditions.append(
|
||||
models.FieldCondition(
|
||||
key=f'metadata.{key}',
|
||||
match=models.MatchValue(value=value),
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
return self.client.delete(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
points_selector=models.FilterSelector(
|
||||
filter=models.Filter(must=field_conditions)
|
||||
),
|
||||
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||
points_selector=models.FilterSelector(filter=models.Filter(must=field_conditions)),
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
|
||||
@@ -29,22 +29,18 @@ from qdrant_client.http.models import PointStruct
|
||||
from qdrant_client.models import models
|
||||
|
||||
NO_LIMIT = 999999999
|
||||
TENANT_ID_FIELD = "tenant_id"
|
||||
TENANT_ID_FIELD = 'tenant_id'
|
||||
DEFAULT_DIMENSION = 384
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _tenant_filter(tenant_id: str) -> models.FieldCondition:
|
||||
return models.FieldCondition(
|
||||
key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id)
|
||||
)
|
||||
return models.FieldCondition(key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id))
|
||||
|
||||
|
||||
def _metadata_filter(key: str, value: Any) -> models.FieldCondition:
|
||||
return models.FieldCondition(
|
||||
key=f"metadata.{key}", match=models.MatchValue(value=value)
|
||||
)
|
||||
return models.FieldCondition(key=f'metadata.{key}', match=models.MatchValue(value=value))
|
||||
|
||||
|
||||
class QdrantClient(VectorDBBase):
|
||||
@@ -59,9 +55,7 @@ class QdrantClient(VectorDBBase):
|
||||
self.QDRANT_HNSW_M = QDRANT_HNSW_M
|
||||
|
||||
if not self.QDRANT_URI:
|
||||
raise ValueError(
|
||||
"QDRANT_URI is not set. Please configure it in the environment variables."
|
||||
)
|
||||
raise ValueError('QDRANT_URI is not set. Please configure it in the environment variables.')
|
||||
|
||||
# Unified handling for either scheme
|
||||
parsed = urlparse(self.QDRANT_URI)
|
||||
@@ -86,19 +80,19 @@ class QdrantClient(VectorDBBase):
|
||||
)
|
||||
|
||||
# Main collection types for multi-tenancy
|
||||
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
|
||||
self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
|
||||
self.FILE_COLLECTION = f"{self.collection_prefix}_files"
|
||||
self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web-search"
|
||||
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"
|
||||
self.MEMORY_COLLECTION = f'{self.collection_prefix}_memories'
|
||||
self.KNOWLEDGE_COLLECTION = f'{self.collection_prefix}_knowledge'
|
||||
self.FILE_COLLECTION = f'{self.collection_prefix}_files'
|
||||
self.WEB_SEARCH_COLLECTION = f'{self.collection_prefix}_web-search'
|
||||
self.HASH_BASED_COLLECTION = f'{self.collection_prefix}_hash-based'
|
||||
|
||||
def _result_to_get_result(self, points) -> GetResult:
|
||||
ids, documents, metadatas = [], [], []
|
||||
for point in points:
|
||||
payload = point.payload
|
||||
ids.append(point.id)
|
||||
documents.append(payload["text"])
|
||||
metadatas.append(payload["metadata"])
|
||||
documents.append(payload['text'])
|
||||
metadatas.append(payload['metadata'])
|
||||
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
||||
|
||||
def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
|
||||
@@ -118,29 +112,25 @@ class QdrantClient(VectorDBBase):
|
||||
# Check for user memory collections
|
||||
tenant_id = collection_name
|
||||
|
||||
if collection_name.startswith("user-memory-"):
|
||||
if collection_name.startswith('user-memory-'):
|
||||
return self.MEMORY_COLLECTION, tenant_id
|
||||
|
||||
# Check for file collections
|
||||
elif collection_name.startswith("file-"):
|
||||
elif collection_name.startswith('file-'):
|
||||
return self.FILE_COLLECTION, tenant_id
|
||||
|
||||
# Check for web search collections
|
||||
elif collection_name.startswith("web-search-"):
|
||||
elif collection_name.startswith('web-search-'):
|
||||
return self.WEB_SEARCH_COLLECTION, tenant_id
|
||||
|
||||
# Handle hash-based collections (YouTube and web URLs)
|
||||
elif len(collection_name) == 63 and all(
|
||||
c in "0123456789abcdef" for c in collection_name
|
||||
):
|
||||
elif len(collection_name) == 63 and all(c in '0123456789abcdef' for c in collection_name):
|
||||
return self.HASH_BASED_COLLECTION, tenant_id
|
||||
|
||||
else:
|
||||
return self.KNOWLEDGE_COLLECTION, tenant_id
|
||||
|
||||
def _create_multi_tenant_collection(
|
||||
self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
|
||||
):
|
||||
def _create_multi_tenant_collection(self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION):
|
||||
"""
|
||||
Creates a collection with multi-tenancy configuration and payload indexes for tenant_id and metadata fields.
|
||||
"""
|
||||
@@ -158,9 +148,7 @@ class QdrantClient(VectorDBBase):
|
||||
m=0,
|
||||
),
|
||||
)
|
||||
log.info(
|
||||
f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
|
||||
)
|
||||
log.info(f'Multi-tenant collection {mt_collection_name} created with dimension {dimension}!')
|
||||
|
||||
self.client.create_payload_index(
|
||||
collection_name=mt_collection_name,
|
||||
@@ -172,7 +160,7 @@ class QdrantClient(VectorDBBase):
|
||||
),
|
||||
)
|
||||
|
||||
for field in ("metadata.hash", "metadata.file_id"):
|
||||
for field in ('metadata.hash', 'metadata.file_id'):
|
||||
self.client.create_payload_index(
|
||||
collection_name=mt_collection_name,
|
||||
field_name=field,
|
||||
@@ -182,28 +170,24 @@ class QdrantClient(VectorDBBase):
|
||||
),
|
||||
)
|
||||
|
||||
def _create_points(
|
||||
self, items: List[VectorItem], tenant_id: str
|
||||
) -> List[PointStruct]:
|
||||
def _create_points(self, items: List[VectorItem], tenant_id: str) -> List[PointStruct]:
|
||||
"""
|
||||
Create point structs from vector items with tenant ID.
|
||||
"""
|
||||
return [
|
||||
PointStruct(
|
||||
id=item["id"],
|
||||
vector=item["vector"],
|
||||
id=item['id'],
|
||||
vector=item['vector'],
|
||||
payload={
|
||||
"text": item["text"],
|
||||
"metadata": item["metadata"],
|
||||
'text': item['text'],
|
||||
'metadata': item['metadata'],
|
||||
TENANT_ID_FIELD: tenant_id,
|
||||
},
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
|
||||
def _ensure_collection(
|
||||
self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION
|
||||
):
|
||||
def _ensure_collection(self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION):
|
||||
"""
|
||||
Ensure the collection exists and payload indexes are created for tenant_id and metadata fields.
|
||||
"""
|
||||
@@ -246,15 +230,13 @@ class QdrantClient(VectorDBBase):
|
||||
must_conditions = [_tenant_filter(tenant_id)]
|
||||
should_conditions = []
|
||||
if ids:
|
||||
should_conditions = [_metadata_filter("id", id_value) for id_value in ids]
|
||||
should_conditions = [_metadata_filter('id', id_value) for id_value in ids]
|
||||
elif filter:
|
||||
must_conditions += [_metadata_filter(k, v) for k, v in filter.items()]
|
||||
|
||||
return self.client.delete(
|
||||
collection_name=mt_collection,
|
||||
points_selector=models.FilterSelector(
|
||||
filter=models.Filter(must=must_conditions, should=should_conditions)
|
||||
),
|
||||
points_selector=models.FilterSelector(filter=models.Filter(must=must_conditions, should=should_conditions)),
|
||||
)
|
||||
|
||||
def search(
|
||||
@@ -289,9 +271,7 @@ class QdrantClient(VectorDBBase):
|
||||
distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]],
|
||||
)
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
|
||||
):
|
||||
def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None):
|
||||
"""
|
||||
Query points with filters and tenant isolation.
|
||||
"""
|
||||
@@ -338,7 +318,7 @@ class QdrantClient(VectorDBBase):
|
||||
if not self.client or not items:
|
||||
return None
|
||||
mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
|
||||
dimension = len(items[0]["vector"])
|
||||
dimension = len(items[0]['vector'])
|
||||
self._ensure_collection(mt_collection, dimension)
|
||||
points = self._create_points(items, tenant_id)
|
||||
self.client.upload_points(mt_collection, points)
|
||||
@@ -372,7 +352,5 @@ class QdrantClient(VectorDBBase):
|
||||
return None
|
||||
self.client.delete(
|
||||
collection_name=mt_collection,
|
||||
points_selector=models.FilterSelector(
|
||||
filter=models.Filter(must=[_tenant_filter(tenant_id)])
|
||||
),
|
||||
points_selector=models.FilterSelector(filter=models.Filter(must=[_tenant_filter(tenant_id)])),
|
||||
)
|
||||
|
||||
@@ -28,18 +28,16 @@ class S3VectorClient(VectorDBBase):
|
||||
|
||||
# Simple validation - log warnings instead of raising exceptions
|
||||
if not self.bucket_name:
|
||||
log.warning("S3_VECTOR_BUCKET_NAME not set - S3Vector will not work")
|
||||
log.warning('S3_VECTOR_BUCKET_NAME not set - S3Vector will not work')
|
||||
if not self.region:
|
||||
log.warning("S3_VECTOR_REGION not set - S3Vector will not work")
|
||||
log.warning('S3_VECTOR_REGION not set - S3Vector will not work')
|
||||
|
||||
if self.bucket_name and self.region:
|
||||
try:
|
||||
self.client = boto3.client("s3vectors", region_name=self.region)
|
||||
log.info(
|
||||
f"S3Vector client initialized for bucket '{self.bucket_name}' in region '{self.region}'"
|
||||
)
|
||||
self.client = boto3.client('s3vectors', region_name=self.region)
|
||||
log.info(f"S3Vector client initialized for bucket '{self.bucket_name}' in region '{self.region}'")
|
||||
except Exception as e:
|
||||
log.error(f"Failed to initialize S3Vector client: {e}")
|
||||
log.error(f'Failed to initialize S3Vector client: {e}')
|
||||
self.client = None
|
||||
else:
|
||||
self.client = None
|
||||
@@ -48,8 +46,8 @@ class S3VectorClient(VectorDBBase):
|
||||
self,
|
||||
index_name: str,
|
||||
dimension: int,
|
||||
data_type: str = "float32",
|
||||
distance_metric: str = "cosine",
|
||||
data_type: str = 'float32',
|
||||
distance_metric: str = 'cosine',
|
||||
) -> None:
|
||||
"""
|
||||
Create a new index in the S3 vector bucket for the given collection if it does not exist.
|
||||
@@ -66,21 +64,17 @@ class S3VectorClient(VectorDBBase):
|
||||
dimension=dimension,
|
||||
distanceMetric=distance_metric,
|
||||
metadataConfiguration={
|
||||
"nonFilterableMetadataKeys": [
|
||||
"text",
|
||||
'nonFilterableMetadataKeys': [
|
||||
'text',
|
||||
]
|
||||
},
|
||||
)
|
||||
log.info(
|
||||
f"Created S3 index: {index_name} (dim={dimension}, type={data_type}, metric={distance_metric})"
|
||||
)
|
||||
log.info(f'Created S3 index: {index_name} (dim={dimension}, type={data_type}, metric={distance_metric})')
|
||||
except Exception as e:
|
||||
log.error(f"Error creating S3 index '{index_name}': {e}")
|
||||
raise
|
||||
|
||||
def _filter_metadata(
|
||||
self, metadata: Dict[str, Any], item_id: str
|
||||
) -> Dict[str, Any]:
|
||||
def _filter_metadata(self, metadata: Dict[str, Any], item_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Filter vector metadata keys to comply with S3 Vector API limit of 10 keys maximum.
|
||||
"""
|
||||
@@ -89,16 +83,16 @@ class S3VectorClient(VectorDBBase):
|
||||
|
||||
# Keep only the first 10 keys, prioritizing important ones based on actual Open WebUI metadata
|
||||
important_keys = [
|
||||
"text", # The actual document content
|
||||
"file_id", # File ID
|
||||
"source", # Document source file
|
||||
"title", # Document title
|
||||
"page", # Page number
|
||||
"total_pages", # Total pages in document
|
||||
"embedding_config", # Embedding configuration
|
||||
"created_by", # User who created it
|
||||
"name", # Document name
|
||||
"hash", # Content hash
|
||||
'text', # The actual document content
|
||||
'file_id', # File ID
|
||||
'source', # Document source file
|
||||
'title', # Document title
|
||||
'page', # Page number
|
||||
'total_pages', # Total pages in document
|
||||
'embedding_config', # Embedding configuration
|
||||
'created_by', # User who created it
|
||||
'name', # Document name
|
||||
'hash', # Content hash
|
||||
]
|
||||
filtered_metadata = {}
|
||||
|
||||
@@ -117,9 +111,7 @@ class S3VectorClient(VectorDBBase):
|
||||
if len(filtered_metadata) >= 10:
|
||||
break
|
||||
|
||||
log.warning(
|
||||
f"Metadata for key '{item_id}' had {len(metadata)} keys, limited to 10 keys"
|
||||
)
|
||||
log.warning(f"Metadata for key '{item_id}' had {len(metadata)} keys, limited to 10 keys")
|
||||
return filtered_metadata
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
@@ -128,9 +120,7 @@ class S3VectorClient(VectorDBBase):
|
||||
This avoids pagination issues with list_indexes() and is significantly faster.
|
||||
"""
|
||||
try:
|
||||
self.client.get_index(
|
||||
vectorBucketName=self.bucket_name, indexName=collection_name
|
||||
)
|
||||
self.client.get_index(vectorBucketName=self.bucket_name, indexName=collection_name)
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"Error checking if index '{collection_name}' exists: {e}")
|
||||
@@ -142,16 +132,12 @@ class S3VectorClient(VectorDBBase):
|
||||
"""
|
||||
|
||||
if not self.has_collection(collection_name):
|
||||
log.warning(
|
||||
f"Collection '{collection_name}' does not exist, nothing to delete"
|
||||
)
|
||||
log.warning(f"Collection '{collection_name}' does not exist, nothing to delete")
|
||||
return
|
||||
|
||||
try:
|
||||
log.info(f"Deleting collection '{collection_name}'")
|
||||
self.client.delete_index(
|
||||
vectorBucketName=self.bucket_name, indexName=collection_name
|
||||
)
|
||||
self.client.delete_index(vectorBucketName=self.bucket_name, indexName=collection_name)
|
||||
log.info(f"Successfully deleted collection '{collection_name}'")
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting collection '{collection_name}': {e}")
|
||||
@@ -162,10 +148,10 @@ class S3VectorClient(VectorDBBase):
|
||||
Insert vector items into the S3 Vector index. Create index if it does not exist.
|
||||
"""
|
||||
if not items:
|
||||
log.warning("No items to insert")
|
||||
log.warning('No items to insert')
|
||||
return
|
||||
|
||||
dimension = len(items[0]["vector"])
|
||||
dimension = len(items[0]['vector'])
|
||||
|
||||
try:
|
||||
if not self.has_collection(collection_name):
|
||||
@@ -173,36 +159,36 @@ class S3VectorClient(VectorDBBase):
|
||||
self._create_index(
|
||||
index_name=collection_name,
|
||||
dimension=dimension,
|
||||
data_type="float32",
|
||||
distance_metric="cosine",
|
||||
data_type='float32',
|
||||
distance_metric='cosine',
|
||||
)
|
||||
|
||||
# Prepare vectors for insertion
|
||||
vectors = []
|
||||
for item in items:
|
||||
# Ensure vector data is in the correct format for S3 Vector API
|
||||
vector_data = item["vector"]
|
||||
vector_data = item['vector']
|
||||
if isinstance(vector_data, list):
|
||||
# Convert list to float32 values as required by S3 Vector API
|
||||
vector_data = [float(x) for x in vector_data]
|
||||
|
||||
# Prepare metadata, ensuring the text field is preserved
|
||||
metadata = item.get("metadata", {}).copy()
|
||||
metadata = item.get('metadata', {}).copy()
|
||||
|
||||
# Add the text field to metadata so it's available for retrieval
|
||||
metadata["text"] = item["text"]
|
||||
metadata['text'] = item['text']
|
||||
|
||||
# Convert metadata to string format for consistency
|
||||
metadata = process_metadata(metadata)
|
||||
|
||||
# Filter metadata to comply with S3 Vector API limit of 10 keys
|
||||
metadata = self._filter_metadata(metadata, item["id"])
|
||||
metadata = self._filter_metadata(metadata, item['id'])
|
||||
|
||||
vectors.append(
|
||||
{
|
||||
"key": item["id"],
|
||||
"data": {"float32": vector_data},
|
||||
"metadata": metadata,
|
||||
'key': item['id'],
|
||||
'data': {'float32': vector_data},
|
||||
'metadata': metadata,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -215,15 +201,11 @@ class S3VectorClient(VectorDBBase):
|
||||
indexName=collection_name,
|
||||
vectors=batch,
|
||||
)
|
||||
log.info(
|
||||
f"Inserted batch {i//batch_size + 1}: {len(batch)} vectors into index '{collection_name}'."
|
||||
)
|
||||
log.info(f"Inserted batch {i // batch_size + 1}: {len(batch)} vectors into index '{collection_name}'.")
|
||||
|
||||
log.info(
|
||||
f"Completed insertion of {len(vectors)} vectors into index '{collection_name}'."
|
||||
)
|
||||
log.info(f"Completed insertion of {len(vectors)} vectors into index '{collection_name}'.")
|
||||
except Exception as e:
|
||||
log.error(f"Error inserting vectors: {e}")
|
||||
log.error(f'Error inserting vectors: {e}')
|
||||
raise
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
@@ -231,49 +213,47 @@ class S3VectorClient(VectorDBBase):
|
||||
Insert or update vector items in the S3 Vector index. Create index if it does not exist.
|
||||
"""
|
||||
if not items:
|
||||
log.warning("No items to upsert")
|
||||
log.warning('No items to upsert')
|
||||
return
|
||||
|
||||
dimension = len(items[0]["vector"])
|
||||
log.info(f"Upsert dimension: {dimension}")
|
||||
dimension = len(items[0]['vector'])
|
||||
log.info(f'Upsert dimension: {dimension}')
|
||||
|
||||
try:
|
||||
if not self.has_collection(collection_name):
|
||||
log.info(
|
||||
f"Index '{collection_name}' does not exist. Creating index for upsert."
|
||||
)
|
||||
log.info(f"Index '{collection_name}' does not exist. Creating index for upsert.")
|
||||
self._create_index(
|
||||
index_name=collection_name,
|
||||
dimension=dimension,
|
||||
data_type="float32",
|
||||
distance_metric="cosine",
|
||||
data_type='float32',
|
||||
distance_metric='cosine',
|
||||
)
|
||||
|
||||
# Prepare vectors for upsert
|
||||
vectors = []
|
||||
for item in items:
|
||||
# Ensure vector data is in the correct format for S3 Vector API
|
||||
vector_data = item["vector"]
|
||||
vector_data = item['vector']
|
||||
if isinstance(vector_data, list):
|
||||
# Convert list to float32 values as required by S3 Vector API
|
||||
vector_data = [float(x) for x in vector_data]
|
||||
|
||||
# Prepare metadata, ensuring the text field is preserved
|
||||
metadata = item.get("metadata", {}).copy()
|
||||
metadata = item.get('metadata', {}).copy()
|
||||
# Add the text field to metadata so it's available for retrieval
|
||||
metadata["text"] = item["text"]
|
||||
metadata['text'] = item['text']
|
||||
|
||||
# Convert metadata to string format for consistency
|
||||
metadata = process_metadata(metadata)
|
||||
|
||||
# Filter metadata to comply with S3 Vector API limit of 10 keys
|
||||
metadata = self._filter_metadata(metadata, item["id"])
|
||||
metadata = self._filter_metadata(metadata, item['id'])
|
||||
|
||||
vectors.append(
|
||||
{
|
||||
"key": item["id"],
|
||||
"data": {"float32": vector_data},
|
||||
"metadata": metadata,
|
||||
'key': item['id'],
|
||||
'data': {'float32': vector_data},
|
||||
'metadata': metadata,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -283,12 +263,10 @@ class S3VectorClient(VectorDBBase):
|
||||
batch = vectors[i : i + batch_size]
|
||||
if i == 0: # Log sample info for first batch only
|
||||
log.info(
|
||||
f"Upserting batch 1: {len(batch)} vectors. First vector sample: key={batch[0]['key']}, data_type={type(batch[0]['data']['float32'])}, data_len={len(batch[0]['data']['float32'])}"
|
||||
f'Upserting batch 1: {len(batch)} vectors. First vector sample: key={batch[0]["key"]}, data_type={type(batch[0]["data"]["float32"])}, data_len={len(batch[0]["data"]["float32"])}'
|
||||
)
|
||||
else:
|
||||
log.info(
|
||||
f"Upserting batch {i//batch_size + 1}: {len(batch)} vectors."
|
||||
)
|
||||
log.info(f'Upserting batch {i // batch_size + 1}: {len(batch)} vectors.')
|
||||
|
||||
self.client.put_vectors(
|
||||
vectorBucketName=self.bucket_name,
|
||||
@@ -296,11 +274,9 @@ class S3VectorClient(VectorDBBase):
|
||||
vectors=batch,
|
||||
)
|
||||
|
||||
log.info(
|
||||
f"Completed upsert of {len(vectors)} vectors into index '{collection_name}'."
|
||||
)
|
||||
log.info(f"Completed upsert of {len(vectors)} vectors into index '{collection_name}'.")
|
||||
except Exception as e:
|
||||
log.error(f"Error upserting vectors: {e}")
|
||||
log.error(f'Error upserting vectors: {e}')
|
||||
raise
|
||||
|
||||
def search(
|
||||
@@ -319,13 +295,11 @@ class S3VectorClient(VectorDBBase):
|
||||
return None
|
||||
|
||||
if not vectors:
|
||||
log.warning("No query vectors provided")
|
||||
log.warning('No query vectors provided')
|
||||
return None
|
||||
|
||||
try:
|
||||
log.info(
|
||||
f"Searching collection '{collection_name}' with {len(vectors)} query vectors, limit={limit}"
|
||||
)
|
||||
log.info(f"Searching collection '{collection_name}' with {len(vectors)} query vectors, limit={limit}")
|
||||
|
||||
# Initialize result lists
|
||||
all_ids = []
|
||||
@@ -335,10 +309,10 @@ class S3VectorClient(VectorDBBase):
|
||||
|
||||
# Process each query vector
|
||||
for i, query_vector in enumerate(vectors):
|
||||
log.debug(f"Processing query vector {i+1}/{len(vectors)}")
|
||||
log.debug(f'Processing query vector {i + 1}/{len(vectors)}')
|
||||
|
||||
# Prepare the query vector in S3 Vector format
|
||||
query_vector_dict = {"float32": [float(x) for x in query_vector]}
|
||||
query_vector_dict = {'float32': [float(x) for x in query_vector]}
|
||||
|
||||
# Call S3 Vector query API
|
||||
response = self.client.query_vectors(
|
||||
@@ -356,24 +330,22 @@ class S3VectorClient(VectorDBBase):
|
||||
query_metadatas = []
|
||||
query_distances = []
|
||||
|
||||
result_vectors = response.get("vectors", [])
|
||||
result_vectors = response.get('vectors', [])
|
||||
|
||||
for vector in result_vectors:
|
||||
vector_id = vector.get("key")
|
||||
vector_metadata = vector.get("metadata", {})
|
||||
vector_distance = vector.get("distance", 0.0)
|
||||
vector_id = vector.get('key')
|
||||
vector_metadata = vector.get('metadata', {})
|
||||
vector_distance = vector.get('distance', 0.0)
|
||||
|
||||
# Extract document text from metadata
|
||||
document_text = ""
|
||||
document_text = ''
|
||||
if isinstance(vector_metadata, dict):
|
||||
# Get the text field first (highest priority)
|
||||
document_text = vector_metadata.get("text")
|
||||
document_text = vector_metadata.get('text')
|
||||
if not document_text:
|
||||
# Fallback to other possible text fields
|
||||
document_text = (
|
||||
vector_metadata.get("content")
|
||||
or vector_metadata.get("document")
|
||||
or vector_id
|
||||
vector_metadata.get('content') or vector_metadata.get('document') or vector_id
|
||||
)
|
||||
else:
|
||||
document_text = vector_id
|
||||
@@ -389,7 +361,7 @@ class S3VectorClient(VectorDBBase):
|
||||
all_metadatas.append(query_metadatas)
|
||||
all_distances.append(query_distances)
|
||||
|
||||
log.info(f"Search completed. Found results for {len(all_ids)} queries")
|
||||
log.info(f'Search completed. Found results for {len(all_ids)} queries')
|
||||
|
||||
# Return SearchResult format
|
||||
return SearchResult(
|
||||
@@ -402,24 +374,20 @@ class S3VectorClient(VectorDBBase):
|
||||
except Exception as e:
|
||||
log.error(f"Error searching collection '{collection_name}': {str(e)}")
|
||||
# Handle specific AWS exceptions
|
||||
if hasattr(e, "response") and "Error" in e.response:
|
||||
error_code = e.response["Error"]["Code"]
|
||||
if error_code == "NotFoundException":
|
||||
if hasattr(e, 'response') and 'Error' in e.response:
|
||||
error_code = e.response['Error']['Code']
|
||||
if error_code == 'NotFoundException':
|
||||
log.warning(f"Collection '{collection_name}' not found")
|
||||
return None
|
||||
elif error_code == "ValidationException":
|
||||
log.error(f"Invalid query vector dimensions or parameters")
|
||||
elif error_code == 'ValidationException':
|
||||
log.error(f'Invalid query vector dimensions or parameters')
|
||||
return None
|
||||
elif error_code == "AccessDeniedException":
|
||||
log.error(
|
||||
f"Access denied for collection '{collection_name}'. Check permissions."
|
||||
)
|
||||
elif error_code == 'AccessDeniedException':
|
||||
log.error(f"Access denied for collection '{collection_name}'. Check permissions.")
|
||||
return None
|
||||
raise
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: Dict, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
"""
|
||||
Query vectors from a collection using metadata filter.
|
||||
"""
|
||||
@@ -429,7 +397,7 @@ class S3VectorClient(VectorDBBase):
|
||||
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
|
||||
|
||||
if not filter:
|
||||
log.warning("No filter provided, returning all vectors")
|
||||
log.warning('No filter provided, returning all vectors')
|
||||
return self.get(collection_name)
|
||||
|
||||
try:
|
||||
@@ -443,17 +411,13 @@ class S3VectorClient(VectorDBBase):
|
||||
all_vectors_result = self.get(collection_name)
|
||||
|
||||
if not all_vectors_result or not all_vectors_result.ids:
|
||||
log.warning("No vectors found in collection")
|
||||
log.warning('No vectors found in collection')
|
||||
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
|
||||
|
||||
# Extract the lists from the result
|
||||
all_ids = all_vectors_result.ids[0] if all_vectors_result.ids else []
|
||||
all_documents = (
|
||||
all_vectors_result.documents[0] if all_vectors_result.documents else []
|
||||
)
|
||||
all_metadatas = (
|
||||
all_vectors_result.metadatas[0] if all_vectors_result.metadatas else []
|
||||
)
|
||||
all_documents = all_vectors_result.documents[0] if all_vectors_result.documents else []
|
||||
all_metadatas = all_vectors_result.metadatas[0] if all_vectors_result.metadatas else []
|
||||
|
||||
# Apply client-side filtering
|
||||
filtered_ids = []
|
||||
@@ -472,9 +436,7 @@ class S3VectorClient(VectorDBBase):
|
||||
if limit and len(filtered_ids) >= limit:
|
||||
break
|
||||
|
||||
log.info(
|
||||
f"Filter applied: {len(filtered_ids)} vectors match out of {len(all_ids)} total"
|
||||
)
|
||||
log.info(f'Filter applied: {len(filtered_ids)} vectors match out of {len(all_ids)} total')
|
||||
|
||||
# Return GetResult format
|
||||
if filtered_ids:
|
||||
@@ -489,15 +451,13 @@ class S3VectorClient(VectorDBBase):
|
||||
except Exception as e:
|
||||
log.error(f"Error querying collection '{collection_name}': {str(e)}")
|
||||
# Handle specific AWS exceptions
|
||||
if hasattr(e, "response") and "Error" in e.response:
|
||||
error_code = e.response["Error"]["Code"]
|
||||
if error_code == "NotFoundException":
|
||||
if hasattr(e, 'response') and 'Error' in e.response:
|
||||
error_code = e.response['Error']['Code']
|
||||
if error_code == 'NotFoundException':
|
||||
log.warning(f"Collection '{collection_name}' not found")
|
||||
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
|
||||
elif error_code == "AccessDeniedException":
|
||||
log.error(
|
||||
f"Access denied for collection '{collection_name}'. Check permissions."
|
||||
)
|
||||
elif error_code == 'AccessDeniedException':
|
||||
log.error(f"Access denied for collection '{collection_name}'. Check permissions.")
|
||||
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
|
||||
raise
|
||||
|
||||
@@ -524,47 +484,43 @@ class S3VectorClient(VectorDBBase):
|
||||
while True:
|
||||
# Prepare request parameters
|
||||
request_params = {
|
||||
"vectorBucketName": self.bucket_name,
|
||||
"indexName": collection_name,
|
||||
"returnData": False, # Don't include vector data (not needed for get)
|
||||
"returnMetadata": True, # Include metadata
|
||||
"maxResults": 500, # Use reasonable page size
|
||||
'vectorBucketName': self.bucket_name,
|
||||
'indexName': collection_name,
|
||||
'returnData': False, # Don't include vector data (not needed for get)
|
||||
'returnMetadata': True, # Include metadata
|
||||
'maxResults': 500, # Use reasonable page size
|
||||
}
|
||||
|
||||
if next_token:
|
||||
request_params["nextToken"] = next_token
|
||||
request_params['nextToken'] = next_token
|
||||
|
||||
# Call S3 Vector API
|
||||
response = self.client.list_vectors(**request_params)
|
||||
|
||||
# Process vectors in this page
|
||||
vectors = response.get("vectors", [])
|
||||
vectors = response.get('vectors', [])
|
||||
|
||||
for vector in vectors:
|
||||
vector_id = vector.get("key")
|
||||
vector_data = vector.get("data", {})
|
||||
vector_metadata = vector.get("metadata", {})
|
||||
vector_id = vector.get('key')
|
||||
vector_data = vector.get('data', {})
|
||||
vector_metadata = vector.get('metadata', {})
|
||||
|
||||
# Extract the actual vector array
|
||||
vector_array = vector_data.get("float32", [])
|
||||
vector_array = vector_data.get('float32', [])
|
||||
|
||||
# For documents, we try to extract text from metadata or use the vector ID
|
||||
document_text = ""
|
||||
document_text = ''
|
||||
if isinstance(vector_metadata, dict):
|
||||
# Get the text field first (highest priority)
|
||||
document_text = vector_metadata.get("text")
|
||||
document_text = vector_metadata.get('text')
|
||||
if not document_text:
|
||||
# Fallback to other possible text fields
|
||||
document_text = (
|
||||
vector_metadata.get("content")
|
||||
or vector_metadata.get("document")
|
||||
or vector_id
|
||||
vector_metadata.get('content') or vector_metadata.get('document') or vector_id
|
||||
)
|
||||
|
||||
# Log the actual content for debugging
|
||||
log.debug(
|
||||
f"Document text preview (first 200 chars): {str(document_text)[:200]}"
|
||||
)
|
||||
log.debug(f'Document text preview (first 200 chars): {str(document_text)[:200]}')
|
||||
else:
|
||||
document_text = vector_id
|
||||
|
||||
@@ -573,37 +529,29 @@ class S3VectorClient(VectorDBBase):
|
||||
all_metadatas.append(vector_metadata)
|
||||
|
||||
# Check if there are more pages
|
||||
next_token = response.get("nextToken")
|
||||
next_token = response.get('nextToken')
|
||||
if not next_token:
|
||||
break
|
||||
|
||||
log.info(
|
||||
f"Retrieved {len(all_ids)} vectors from collection '{collection_name}'"
|
||||
)
|
||||
log.info(f"Retrieved {len(all_ids)} vectors from collection '{collection_name}'")
|
||||
|
||||
# Return in GetResult format
|
||||
# The Open WebUI GetResult expects lists of lists, so we wrap each list
|
||||
if all_ids:
|
||||
return GetResult(
|
||||
ids=[all_ids], documents=[all_documents], metadatas=[all_metadatas]
|
||||
)
|
||||
return GetResult(ids=[all_ids], documents=[all_documents], metadatas=[all_metadatas])
|
||||
else:
|
||||
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
|
||||
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Error retrieving vectors from collection '{collection_name}': {str(e)}"
|
||||
)
|
||||
log.error(f"Error retrieving vectors from collection '{collection_name}': {str(e)}")
|
||||
# Handle specific AWS exceptions
|
||||
if hasattr(e, "response") and "Error" in e.response:
|
||||
error_code = e.response["Error"]["Code"]
|
||||
if error_code == "NotFoundException":
|
||||
if hasattr(e, 'response') and 'Error' in e.response:
|
||||
error_code = e.response['Error']['Code']
|
||||
if error_code == 'NotFoundException':
|
||||
log.warning(f"Collection '{collection_name}' not found")
|
||||
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
|
||||
elif error_code == "AccessDeniedException":
|
||||
log.error(
|
||||
f"Access denied for collection '{collection_name}'. Check permissions."
|
||||
)
|
||||
elif error_code == 'AccessDeniedException':
|
||||
log.error(f"Access denied for collection '{collection_name}'. Check permissions.")
|
||||
return GetResult(ids=[[]], documents=[[]], metadatas=[[]])
|
||||
raise
|
||||
|
||||
@@ -618,20 +566,16 @@ class S3VectorClient(VectorDBBase):
|
||||
"""
|
||||
|
||||
if not self.has_collection(collection_name):
|
||||
log.warning(
|
||||
f"Collection '{collection_name}' does not exist, nothing to delete"
|
||||
)
|
||||
log.warning(f"Collection '{collection_name}' does not exist, nothing to delete")
|
||||
return
|
||||
|
||||
# Check if this is a knowledge collection (not file-specific)
|
||||
is_knowledge_collection = not collection_name.startswith("file-")
|
||||
is_knowledge_collection = not collection_name.startswith('file-')
|
||||
|
||||
try:
|
||||
if ids:
|
||||
# Delete by specific vector IDs/keys
|
||||
log.info(
|
||||
f"Deleting {len(ids)} vectors by IDs from collection '{collection_name}'"
|
||||
)
|
||||
log.info(f"Deleting {len(ids)} vectors by IDs from collection '{collection_name}'")
|
||||
self.client.delete_vectors(
|
||||
vectorBucketName=self.bucket_name,
|
||||
indexName=collection_name,
|
||||
@@ -641,15 +585,13 @@ class S3VectorClient(VectorDBBase):
|
||||
|
||||
elif filter:
|
||||
# Handle filter-based deletion
|
||||
log.info(
|
||||
f"Deleting vectors by filter from collection '{collection_name}': {filter}"
|
||||
)
|
||||
log.info(f"Deleting vectors by filter from collection '{collection_name}': {filter}")
|
||||
|
||||
# If this is a knowledge collection and we have a file_id filter,
|
||||
# also clean up the corresponding file-specific collection
|
||||
if is_knowledge_collection and "file_id" in filter:
|
||||
file_id = filter["file_id"]
|
||||
file_collection_name = f"file-{file_id}"
|
||||
if is_knowledge_collection and 'file_id' in filter:
|
||||
file_id = filter['file_id']
|
||||
file_collection_name = f'file-{file_id}'
|
||||
if self.has_collection(file_collection_name):
|
||||
log.info(
|
||||
f"Found related file-specific collection '{file_collection_name}', deleting it to prevent duplicates"
|
||||
@@ -661,9 +603,7 @@ class S3VectorClient(VectorDBBase):
|
||||
query_result = self.query(collection_name, filter)
|
||||
if query_result and query_result.ids and query_result.ids[0]:
|
||||
matching_ids = query_result.ids[0]
|
||||
log.info(
|
||||
f"Found {len(matching_ids)} vectors matching filter, deleting them"
|
||||
)
|
||||
log.info(f'Found {len(matching_ids)} vectors matching filter, deleting them')
|
||||
|
||||
# Delete the matching vectors by ID
|
||||
self.client.delete_vectors(
|
||||
@@ -671,17 +611,13 @@ class S3VectorClient(VectorDBBase):
|
||||
indexName=collection_name,
|
||||
keys=matching_ids,
|
||||
)
|
||||
log.info(
|
||||
f"Deleted {len(matching_ids)} vectors from index '{collection_name}' using filter"
|
||||
)
|
||||
log.info(f"Deleted {len(matching_ids)} vectors from index '{collection_name}' using filter")
|
||||
else:
|
||||
log.warning("No vectors found matching the filter criteria")
|
||||
log.warning('No vectors found matching the filter criteria')
|
||||
else:
|
||||
log.warning("No IDs or filter provided for deletion")
|
||||
log.warning('No IDs or filter provided for deletion')
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Error deleting vectors from collection '{collection_name}': {e}"
|
||||
)
|
||||
log.error(f"Error deleting vectors from collection '{collection_name}': {e}")
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
@@ -690,36 +626,32 @@ class S3VectorClient(VectorDBBase):
|
||||
"""
|
||||
|
||||
try:
|
||||
log.warning(
|
||||
"Reset called - this will delete all vector indexes in the S3 bucket"
|
||||
)
|
||||
log.warning('Reset called - this will delete all vector indexes in the S3 bucket')
|
||||
|
||||
# List all indexes
|
||||
response = self.client.list_indexes(vectorBucketName=self.bucket_name)
|
||||
indexes = response.get("indexes", [])
|
||||
indexes = response.get('indexes', [])
|
||||
|
||||
if not indexes:
|
||||
log.warning("No indexes found to delete")
|
||||
log.warning('No indexes found to delete')
|
||||
return
|
||||
|
||||
# Delete all indexes
|
||||
deleted_count = 0
|
||||
for index in indexes:
|
||||
index_name = index.get("indexName")
|
||||
index_name = index.get('indexName')
|
||||
if index_name:
|
||||
try:
|
||||
self.client.delete_index(
|
||||
vectorBucketName=self.bucket_name, indexName=index_name
|
||||
)
|
||||
self.client.delete_index(vectorBucketName=self.bucket_name, indexName=index_name)
|
||||
deleted_count += 1
|
||||
log.info(f"Deleted index: {index_name}")
|
||||
log.info(f'Deleted index: {index_name}')
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting index '{index_name}': {e}")
|
||||
|
||||
log.info(f"Reset completed: deleted {deleted_count} indexes")
|
||||
log.info(f'Reset completed: deleted {deleted_count} indexes')
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error during reset: {e}")
|
||||
log.error(f'Error during reset: {e}')
|
||||
raise
|
||||
|
||||
def _matches_filter(self, metadata: Dict[str, Any], filter: Dict[str, Any]) -> bool:
|
||||
@@ -732,15 +664,15 @@ class S3VectorClient(VectorDBBase):
|
||||
# Check each filter condition
|
||||
for key, expected_value in filter.items():
|
||||
# Handle special operators
|
||||
if key.startswith("$"):
|
||||
if key == "$and":
|
||||
if key.startswith('$'):
|
||||
if key == '$and':
|
||||
# All conditions must match
|
||||
if not isinstance(expected_value, list):
|
||||
continue
|
||||
for condition in expected_value:
|
||||
if not self._matches_filter(metadata, condition):
|
||||
return False
|
||||
elif key == "$or":
|
||||
elif key == '$or':
|
||||
# At least one condition must match
|
||||
if not isinstance(expected_value, list):
|
||||
continue
|
||||
@@ -760,22 +692,19 @@ class S3VectorClient(VectorDBBase):
|
||||
if isinstance(expected_value, dict):
|
||||
# Handle comparison operators
|
||||
for op, op_value in expected_value.items():
|
||||
if op == "$eq":
|
||||
if op == '$eq':
|
||||
if actual_value != op_value:
|
||||
return False
|
||||
elif op == "$ne":
|
||||
elif op == '$ne':
|
||||
if actual_value == op_value:
|
||||
return False
|
||||
elif op == "$in":
|
||||
if (
|
||||
not isinstance(op_value, list)
|
||||
or actual_value not in op_value
|
||||
):
|
||||
elif op == '$in':
|
||||
if not isinstance(op_value, list) or actual_value not in op_value:
|
||||
return False
|
||||
elif op == "$nin":
|
||||
elif op == '$nin':
|
||||
if isinstance(op_value, list) and actual_value in op_value:
|
||||
return False
|
||||
elif op == "$exists":
|
||||
elif op == '$exists':
|
||||
if bool(op_value) != (key in metadata):
|
||||
return False
|
||||
# Add more operators as needed
|
||||
|
||||
@@ -60,47 +60,43 @@ class WeaviateClient(VectorDBBase):
|
||||
try:
|
||||
# Build connection parameters
|
||||
connection_params = {
|
||||
"http_host": WEAVIATE_HTTP_HOST,
|
||||
"http_port": WEAVIATE_HTTP_PORT,
|
||||
"http_secure": WEAVIATE_HTTP_SECURE,
|
||||
"grpc_host": WEAVIATE_GRPC_HOST,
|
||||
"grpc_port": WEAVIATE_GRPC_PORT,
|
||||
"grpc_secure": WEAVIATE_GRPC_SECURE,
|
||||
"skip_init_checks": WEAVIATE_SKIP_INIT_CHECKS,
|
||||
'http_host': WEAVIATE_HTTP_HOST,
|
||||
'http_port': WEAVIATE_HTTP_PORT,
|
||||
'http_secure': WEAVIATE_HTTP_SECURE,
|
||||
'grpc_host': WEAVIATE_GRPC_HOST,
|
||||
'grpc_port': WEAVIATE_GRPC_PORT,
|
||||
'grpc_secure': WEAVIATE_GRPC_SECURE,
|
||||
'skip_init_checks': WEAVIATE_SKIP_INIT_CHECKS,
|
||||
}
|
||||
|
||||
# Only add auth_credentials if WEAVIATE_API_KEY exists and is not empty
|
||||
if WEAVIATE_API_KEY:
|
||||
connection_params["auth_credentials"] = (
|
||||
weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY)
|
||||
)
|
||||
connection_params['auth_credentials'] = weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY)
|
||||
|
||||
self.client = weaviate.connect_to_custom(**connection_params)
|
||||
self.client.connect()
|
||||
except Exception as e:
|
||||
raise ConnectionError(f"Failed to connect to Weaviate: {e}") from e
|
||||
raise ConnectionError(f'Failed to connect to Weaviate: {e}') from e
|
||||
|
||||
def _sanitize_collection_name(self, collection_name: str) -> str:
|
||||
"""Sanitize collection name to be a valid Weaviate class name."""
|
||||
if not isinstance(collection_name, str) or not collection_name.strip():
|
||||
raise ValueError("Collection name must be a non-empty string")
|
||||
raise ValueError('Collection name must be a non-empty string')
|
||||
|
||||
# Requirements for a valid Weaviate class name:
|
||||
# The collection name must begin with a capital letter.
|
||||
# The name can only contain letters, numbers, and the underscore (_) character. Spaces are not allowed.
|
||||
|
||||
# Replace hyphens with underscores and keep only alphanumeric characters
|
||||
name = re.sub(r"[^a-zA-Z0-9_]", "", collection_name.replace("-", "_"))
|
||||
name = name.strip("_")
|
||||
name = re.sub(r'[^a-zA-Z0-9_]', '', collection_name.replace('-', '_'))
|
||||
name = name.strip('_')
|
||||
|
||||
if not name:
|
||||
raise ValueError(
|
||||
"Could not sanitize collection name to be a valid Weaviate class name"
|
||||
)
|
||||
raise ValueError('Could not sanitize collection name to be a valid Weaviate class name')
|
||||
|
||||
# Ensure it starts with a letter and is capitalized
|
||||
if not name[0].isalpha():
|
||||
name = "C" + name
|
||||
name = 'C' + name
|
||||
|
||||
return name[0].upper() + name[1:]
|
||||
|
||||
@@ -118,9 +114,7 @@ class WeaviateClient(VectorDBBase):
|
||||
name=collection_name,
|
||||
vector_config=weaviate.classes.config.Configure.Vectors.self_provided(),
|
||||
properties=[
|
||||
weaviate.classes.config.Property(
|
||||
name="text", data_type=weaviate.classes.config.DataType.TEXT
|
||||
),
|
||||
weaviate.classes.config.Property(name='text', data_type=weaviate.classes.config.DataType.TEXT),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -133,19 +127,15 @@ class WeaviateClient(VectorDBBase):
|
||||
|
||||
with collection.batch.fixed_size(batch_size=100) as batch:
|
||||
for item in items:
|
||||
item_uuid = str(uuid.uuid4()) if not item["id"] else str(item["id"])
|
||||
item_uuid = str(uuid.uuid4()) if not item['id'] else str(item['id'])
|
||||
|
||||
properties = {"text": item["text"]}
|
||||
if item["metadata"]:
|
||||
clean_metadata = _convert_uuids_to_strings(
|
||||
process_metadata(item["metadata"])
|
||||
)
|
||||
clean_metadata.pop("text", None)
|
||||
properties = {'text': item['text']}
|
||||
if item['metadata']:
|
||||
clean_metadata = _convert_uuids_to_strings(process_metadata(item['metadata']))
|
||||
clean_metadata.pop('text', None)
|
||||
properties.update(clean_metadata)
|
||||
|
||||
batch.add_object(
|
||||
properties=properties, uuid=item_uuid, vector=item["vector"]
|
||||
)
|
||||
batch.add_object(properties=properties, uuid=item_uuid, vector=item['vector'])
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
sane_collection_name = self._sanitize_collection_name(collection_name)
|
||||
@@ -156,19 +146,15 @@ class WeaviateClient(VectorDBBase):
|
||||
|
||||
with collection.batch.fixed_size(batch_size=100) as batch:
|
||||
for item in items:
|
||||
item_uuid = str(item["id"]) if item["id"] else None
|
||||
item_uuid = str(item['id']) if item['id'] else None
|
||||
|
||||
properties = {"text": item["text"]}
|
||||
if item["metadata"]:
|
||||
clean_metadata = _convert_uuids_to_strings(
|
||||
process_metadata(item["metadata"])
|
||||
)
|
||||
clean_metadata.pop("text", None)
|
||||
properties = {'text': item['text']}
|
||||
if item['metadata']:
|
||||
clean_metadata = _convert_uuids_to_strings(process_metadata(item['metadata']))
|
||||
clean_metadata.pop('text', None)
|
||||
properties.update(clean_metadata)
|
||||
|
||||
batch.add_object(
|
||||
properties=properties, uuid=item_uuid, vector=item["vector"]
|
||||
)
|
||||
batch.add_object(properties=properties, uuid=item_uuid, vector=item['vector'])
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -205,16 +191,12 @@ class WeaviateClient(VectorDBBase):
|
||||
|
||||
for obj in response.objects:
|
||||
properties = dict(obj.properties) if obj.properties else {}
|
||||
documents.append(properties.pop("text", ""))
|
||||
documents.append(properties.pop('text', ''))
|
||||
metadatas.append(_convert_uuids_to_strings(properties))
|
||||
|
||||
# Weaviate has cosine distance, 2 (worst) -> 0 (best). Re-ordering to 0 -> 1
|
||||
raw_distances = [
|
||||
(
|
||||
obj.metadata.distance
|
||||
if obj.metadata and obj.metadata.distance
|
||||
else 2.0
|
||||
)
|
||||
(obj.metadata.distance if obj.metadata and obj.metadata.distance else 2.0)
|
||||
for obj in response.objects
|
||||
]
|
||||
distances = [(2 - dist) / 2 for dist in raw_distances]
|
||||
@@ -231,16 +213,14 @@ class WeaviateClient(VectorDBBase):
|
||||
|
||||
return SearchResult(
|
||||
**{
|
||||
"ids": result_ids,
|
||||
"documents": result_documents,
|
||||
"metadatas": result_metadatas,
|
||||
"distances": result_distances,
|
||||
'ids': result_ids,
|
||||
'documents': result_documents,
|
||||
'metadatas': result_metadatas,
|
||||
'distances': result_distances,
|
||||
}
|
||||
)
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: Dict, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
sane_collection_name = self._sanitize_collection_name(collection_name)
|
||||
if not self.client.collections.exists(sane_collection_name):
|
||||
return None
|
||||
@@ -250,21 +230,15 @@ class WeaviateClient(VectorDBBase):
|
||||
weaviate_filter = None
|
||||
if filter:
|
||||
for key, value in filter.items():
|
||||
prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(
|
||||
value
|
||||
)
|
||||
prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(value)
|
||||
weaviate_filter = (
|
||||
prop_filter
|
||||
if weaviate_filter is None
|
||||
else weaviate.classes.query.Filter.all_of(
|
||||
[weaviate_filter, prop_filter]
|
||||
)
|
||||
else weaviate.classes.query.Filter.all_of([weaviate_filter, prop_filter])
|
||||
)
|
||||
|
||||
try:
|
||||
response = collection.query.fetch_objects(
|
||||
filters=weaviate_filter, limit=limit
|
||||
)
|
||||
response = collection.query.fetch_objects(filters=weaviate_filter, limit=limit)
|
||||
|
||||
ids = [str(obj.uuid) for obj in response.objects]
|
||||
documents = []
|
||||
@@ -272,14 +246,14 @@ class WeaviateClient(VectorDBBase):
|
||||
|
||||
for obj in response.objects:
|
||||
properties = dict(obj.properties) if obj.properties else {}
|
||||
documents.append(properties.pop("text", ""))
|
||||
documents.append(properties.pop('text', ''))
|
||||
metadatas.append(_convert_uuids_to_strings(properties))
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [ids],
|
||||
"documents": [documents],
|
||||
"metadatas": [metadatas],
|
||||
'ids': [ids],
|
||||
'documents': [documents],
|
||||
'metadatas': [metadatas],
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
@@ -297,7 +271,7 @@ class WeaviateClient(VectorDBBase):
|
||||
for item in collection.iterator():
|
||||
ids.append(str(item.uuid))
|
||||
properties = dict(item.properties) if item.properties else {}
|
||||
documents.append(properties.pop("text", ""))
|
||||
documents.append(properties.pop('text', ''))
|
||||
metadatas.append(_convert_uuids_to_strings(properties))
|
||||
|
||||
if not ids:
|
||||
@@ -305,9 +279,9 @@ class WeaviateClient(VectorDBBase):
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [ids],
|
||||
"documents": [documents],
|
||||
"metadatas": [metadatas],
|
||||
'ids': [ids],
|
||||
'documents': [documents],
|
||||
'metadatas': [metadatas],
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
@@ -332,15 +306,11 @@ class WeaviateClient(VectorDBBase):
|
||||
elif filter:
|
||||
weaviate_filter = None
|
||||
for key, value in filter.items():
|
||||
prop_filter = weaviate.classes.query.Filter.by_property(
|
||||
name=key
|
||||
).equal(value)
|
||||
prop_filter = weaviate.classes.query.Filter.by_property(name=key).equal(value)
|
||||
weaviate_filter = (
|
||||
prop_filter
|
||||
if weaviate_filter is None
|
||||
else weaviate.classes.query.Filter.all_of(
|
||||
[weaviate_filter, prop_filter]
|
||||
)
|
||||
else weaviate.classes.query.Filter.all_of([weaviate_filter, prop_filter])
|
||||
)
|
||||
|
||||
if weaviate_filter:
|
||||
|
||||
@@ -8,7 +8,6 @@ from open_webui.config import (
|
||||
|
||||
|
||||
class Vector:
|
||||
|
||||
@staticmethod
|
||||
def get_vector(vector_type: str) -> VectorDBBase:
|
||||
"""
|
||||
@@ -82,7 +81,7 @@ class Vector:
|
||||
|
||||
return WeaviateClient()
|
||||
case _:
|
||||
raise ValueError(f"Unsupported vector type: {vector_type}")
|
||||
raise ValueError(f'Unsupported vector type: {vector_type}')
|
||||
|
||||
|
||||
VECTOR_DB_CLIENT = Vector.get_vector(VECTOR_DB)
|
||||
|
||||
@@ -63,9 +63,7 @@ class VectorDBBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def query(
|
||||
self, collection_name: str, filter: Dict, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def query(self, collection_name: str, filter: Dict, limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
"""Query vectors from a collection using metadata filter."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -2,15 +2,15 @@ from enum import StrEnum
|
||||
|
||||
|
||||
class VectorType(StrEnum):
|
||||
MILVUS = "milvus"
|
||||
MARIADB_VECTOR = "mariadb-vector"
|
||||
QDRANT = "qdrant"
|
||||
CHROMA = "chroma"
|
||||
PINECONE = "pinecone"
|
||||
ELASTICSEARCH = "elasticsearch"
|
||||
OPENSEARCH = "opensearch"
|
||||
PGVECTOR = "pgvector"
|
||||
ORACLE23AI = "oracle23ai"
|
||||
S3VECTOR = "s3vector"
|
||||
WEAVIATE = "weaviate"
|
||||
OPENGAUSS = "opengauss"
|
||||
MILVUS = 'milvus'
|
||||
MARIADB_VECTOR = 'mariadb-vector'
|
||||
QDRANT = 'qdrant'
|
||||
CHROMA = 'chroma'
|
||||
PINECONE = 'pinecone'
|
||||
ELASTICSEARCH = 'elasticsearch'
|
||||
OPENSEARCH = 'opensearch'
|
||||
PGVECTOR = 'pgvector'
|
||||
ORACLE23AI = 'oracle23ai'
|
||||
S3VECTOR = 's3vector'
|
||||
WEAVIATE = 'weaviate'
|
||||
OPENGAUSS = 'opengauss'
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
from datetime import datetime
|
||||
|
||||
KEYS_TO_EXCLUDE = ["content", "pages", "tables", "paragraphs", "sections", "figures"]
|
||||
KEYS_TO_EXCLUDE = ['content', 'pages', 'tables', 'paragraphs', 'sections', 'figures']
|
||||
|
||||
|
||||
def filter_metadata(metadata: dict[str, any]) -> dict[str, any]:
|
||||
# Removes large/redundant fields from metadata dict.
|
||||
metadata = {
|
||||
key: value for key, value in metadata.items() if key not in KEYS_TO_EXCLUDE
|
||||
}
|
||||
metadata = {key: value for key, value in metadata.items() if key not in KEYS_TO_EXCLUDE}
|
||||
return metadata
|
||||
|
||||
|
||||
|
||||
@@ -40,20 +40,17 @@ def search_azure(
|
||||
from azure.search.documents import SearchClient
|
||||
except ImportError:
|
||||
log.error(
|
||||
"azure-search-documents package is not installed. "
|
||||
"Install it with: pip install azure-search-documents"
|
||||
'azure-search-documents package is not installed. Install it with: pip install azure-search-documents'
|
||||
)
|
||||
raise ImportError(
|
||||
"azure-search-documents is required for Azure AI Search. "
|
||||
"Install it with: pip install azure-search-documents"
|
||||
'azure-search-documents is required for Azure AI Search. '
|
||||
'Install it with: pip install azure-search-documents'
|
||||
)
|
||||
|
||||
try:
|
||||
# Create search client with API key authentication
|
||||
credential = AzureKeyCredential(api_key)
|
||||
search_client = SearchClient(
|
||||
endpoint=endpoint, index_name=index_name, credential=credential
|
||||
)
|
||||
search_client = SearchClient(endpoint=endpoint, index_name=index_name, credential=credential)
|
||||
|
||||
# Perform the search
|
||||
results = search_client.search(search_text=query, top=count)
|
||||
@@ -68,42 +65,42 @@ def search_azure(
|
||||
|
||||
# Try to find URL field (common names)
|
||||
link = (
|
||||
result_dict.get("url")
|
||||
or result_dict.get("link")
|
||||
or result_dict.get("uri")
|
||||
or result_dict.get("metadata_storage_path")
|
||||
or ""
|
||||
result_dict.get('url')
|
||||
or result_dict.get('link')
|
||||
or result_dict.get('uri')
|
||||
or result_dict.get('metadata_storage_path')
|
||||
or ''
|
||||
)
|
||||
|
||||
# Try to find title field (common names)
|
||||
title = (
|
||||
result_dict.get("title")
|
||||
or result_dict.get("name")
|
||||
or result_dict.get("metadata_title")
|
||||
or result_dict.get("metadata_storage_name")
|
||||
result_dict.get('title')
|
||||
or result_dict.get('name')
|
||||
or result_dict.get('metadata_title')
|
||||
or result_dict.get('metadata_storage_name')
|
||||
or None
|
||||
)
|
||||
|
||||
# Try to find content/snippet field (common names)
|
||||
snippet = (
|
||||
result_dict.get("content")
|
||||
or result_dict.get("snippet")
|
||||
or result_dict.get("description")
|
||||
or result_dict.get("summary")
|
||||
or result_dict.get("text")
|
||||
result_dict.get('content')
|
||||
or result_dict.get('snippet')
|
||||
or result_dict.get('description')
|
||||
or result_dict.get('summary')
|
||||
or result_dict.get('text')
|
||||
or None
|
||||
)
|
||||
|
||||
# Truncate snippet if too long
|
||||
if snippet and len(snippet) > 500:
|
||||
snippet = snippet[:497] + "..."
|
||||
snippet = snippet[:497] + '...'
|
||||
|
||||
if link: # Only add if we found a valid link
|
||||
search_results.append(
|
||||
{
|
||||
"link": link,
|
||||
"title": title,
|
||||
"snippet": snippet,
|
||||
'link': link,
|
||||
'title': title,
|
||||
'snippet': snippet,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -114,13 +111,13 @@ def search_azure(
|
||||
# Convert to SearchResult objects
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("snippet"),
|
||||
link=result['link'],
|
||||
title=result.get('title'),
|
||||
snippet=result.get('snippet'),
|
||||
)
|
||||
for result in search_results
|
||||
]
|
||||
|
||||
except Exception as ex:
|
||||
log.error(f"Azure AI Search error: {ex}")
|
||||
log.error(f'Azure AI Search error: {ex}')
|
||||
raise ex
|
||||
|
||||
@@ -21,48 +21,44 @@ def search_bing(
|
||||
filter_list: Optional[list[str]] = None,
|
||||
) -> list[SearchResult]:
|
||||
mkt = locale
|
||||
params = {"q": query, "mkt": mkt, "count": count}
|
||||
headers = {"Ocp-Apim-Subscription-Key": subscription_key}
|
||||
params = {'q': query, 'mkt': mkt, 'count': count}
|
||||
headers = {'Ocp-Apim-Subscription-Key': subscription_key}
|
||||
|
||||
try:
|
||||
response = requests.get(endpoint, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
json_response = response.json()
|
||||
results = json_response.get("webPages", {}).get("value", [])
|
||||
results = json_response.get('webPages', {}).get('value', [])
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["url"],
|
||||
title=result.get("name"),
|
||||
snippet=result.get("snippet"),
|
||||
link=result['url'],
|
||||
title=result.get('name'),
|
||||
snippet=result.get('snippet'),
|
||||
)
|
||||
for result in results
|
||||
]
|
||||
except Exception as ex:
|
||||
log.error(f"Error: {ex}")
|
||||
log.error(f'Error: {ex}')
|
||||
raise ex
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Search Bing from the command line.")
|
||||
parser = argparse.ArgumentParser(description='Search Bing from the command line.')
|
||||
parser.add_argument(
|
||||
"query",
|
||||
'query',
|
||||
type=str,
|
||||
default="Top 10 international news today",
|
||||
help="The search query.",
|
||||
default='Top 10 international news today',
|
||||
help='The search query.',
|
||||
)
|
||||
parser.add_argument('--count', type=int, default=10, help='Number of search results to return.')
|
||||
parser.add_argument('--filter', nargs='*', help='List of filters to apply to the search results.')
|
||||
parser.add_argument(
|
||||
"--count", type=int, default=10, help="Number of search results to return."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--filter", nargs="*", help="List of filters to apply to the search results."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--locale",
|
||||
'--locale',
|
||||
type=str,
|
||||
default="en-US",
|
||||
help="The locale to use for the search, maps to market in api",
|
||||
default='en-US',
|
||||
help='The locale to use for the search, maps to market in api',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -10,43 +10,38 @@ log = logging.getLogger(__name__)
|
||||
|
||||
def _parse_response(response):
|
||||
results = []
|
||||
if "data" in response:
|
||||
data = response["data"]
|
||||
if "webPages" in data:
|
||||
webPages = data["webPages"]
|
||||
if "value" in webPages:
|
||||
if 'data' in response:
|
||||
data = response['data']
|
||||
if 'webPages' in data:
|
||||
webPages = data['webPages']
|
||||
if 'value' in webPages:
|
||||
results = [
|
||||
{
|
||||
"id": item.get("id", ""),
|
||||
"name": item.get("name", ""),
|
||||
"url": item.get("url", ""),
|
||||
"snippet": item.get("snippet", ""),
|
||||
"summary": item.get("summary", ""),
|
||||
"siteName": item.get("siteName", ""),
|
||||
"siteIcon": item.get("siteIcon", ""),
|
||||
"datePublished": item.get("datePublished", "")
|
||||
or item.get("dateLastCrawled", ""),
|
||||
'id': item.get('id', ''),
|
||||
'name': item.get('name', ''),
|
||||
'url': item.get('url', ''),
|
||||
'snippet': item.get('snippet', ''),
|
||||
'summary': item.get('summary', ''),
|
||||
'siteName': item.get('siteName', ''),
|
||||
'siteIcon': item.get('siteIcon', ''),
|
||||
'datePublished': item.get('datePublished', '') or item.get('dateLastCrawled', ''),
|
||||
}
|
||||
for item in webPages["value"]
|
||||
for item in webPages['value']
|
||||
]
|
||||
return results
|
||||
|
||||
|
||||
def search_bocha(
|
||||
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
|
||||
) -> list[SearchResult]:
|
||||
def search_bocha(api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None) -> list[SearchResult]:
|
||||
"""Search using Bocha's Search API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
api_key (str): A Bocha Search API key
|
||||
query (str): The query to search for
|
||||
"""
|
||||
url = "https://api.bochaai.com/v1/web-search?utm_source=ollama"
|
||||
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
|
||||
url = 'https://api.bochaai.com/v1/web-search?utm_source=ollama'
|
||||
headers = {'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'}
|
||||
|
||||
payload = json.dumps(
|
||||
{"query": query, "summary": True, "freshness": "noLimit", "count": count}
|
||||
)
|
||||
payload = json.dumps({'query': query, 'summary': True, 'freshness': 'noLimit', 'count': count})
|
||||
|
||||
response = requests.post(url, headers=headers, data=payload, timeout=5)
|
||||
response.raise_for_status()
|
||||
@@ -56,8 +51,6 @@ def search_bocha(
|
||||
results = get_filtered_results(results, filter_list)
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["url"], title=result.get("name"), snippet=result.get("summary")
|
||||
)
|
||||
SearchResult(link=result['url'], title=result.get('name'), snippet=result.get('summary'))
|
||||
for result in results[:count]
|
||||
]
|
||||
|
||||
@@ -8,44 +8,42 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def search_brave(
|
||||
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
|
||||
) -> list[SearchResult]:
|
||||
def search_brave(api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None) -> list[SearchResult]:
|
||||
"""Search using Brave's Search API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
api_key (str): A Brave Search API key
|
||||
query (str): The query to search for
|
||||
"""
|
||||
url = "https://api.search.brave.com/res/v1/web/search"
|
||||
url = 'https://api.search.brave.com/res/v1/web/search'
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Accept-Encoding": "gzip",
|
||||
"X-Subscription-Token": api_key,
|
||||
'Accept': 'application/json',
|
||||
'Accept-Encoding': 'gzip',
|
||||
'X-Subscription-Token': api_key,
|
||||
}
|
||||
params = {"q": query, "count": count}
|
||||
params = {'q': query, 'count': count}
|
||||
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
|
||||
# Handle 429 rate limiting - Brave free tier allows 1 request/second
|
||||
# If rate limited, wait 1 second and retry once before failing
|
||||
if response.status_code == 429:
|
||||
log.info("Brave Search API rate limited (429), retrying after 1 second...")
|
||||
log.info('Brave Search API rate limited (429), retrying after 1 second...')
|
||||
time.sleep(1)
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
json_response = response.json()
|
||||
results = json_response.get("web", {}).get("results", [])
|
||||
results = json_response.get('web', {}).get('results', [])
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["url"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("description"),
|
||||
link=result['url'],
|
||||
title=result.get('title'),
|
||||
snippet=result.get('description'),
|
||||
)
|
||||
for result in results[:count]
|
||||
]
|
||||
|
||||
@@ -13,7 +13,7 @@ def search_duckduckgo(
|
||||
count: int,
|
||||
filter_list: Optional[list[str]] = None,
|
||||
concurrent_requests: Optional[int] = None,
|
||||
backend: Optional[str] = "auto",
|
||||
backend: Optional[str] = 'auto',
|
||||
) -> list[SearchResult]:
|
||||
"""
|
||||
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
|
||||
@@ -33,20 +33,18 @@ def search_duckduckgo(
|
||||
|
||||
# Use the ddgs.text() method to perform the search
|
||||
try:
|
||||
search_results = ddgs.text(
|
||||
query, safesearch="moderate", max_results=count, backend=backend
|
||||
)
|
||||
search_results = ddgs.text(query, safesearch='moderate', max_results=count, backend=backend)
|
||||
except RatelimitException as e:
|
||||
log.error(f"RatelimitException: {e}")
|
||||
log.error(f'RatelimitException: {e}')
|
||||
if filter_list:
|
||||
search_results = get_filtered_results(search_results, filter_list)
|
||||
|
||||
# Return the list of search results
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["href"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("body"),
|
||||
link=result['href'],
|
||||
title=result.get('title'),
|
||||
snippet=result.get('body'),
|
||||
)
|
||||
for result in search_results
|
||||
]
|
||||
|
||||
@@ -7,7 +7,7 @@ from open_webui.retrieval.web.main import SearchResult
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
EXA_API_BASE = "https://api.exa.ai"
|
||||
EXA_API_BASE = 'https://api.exa.ai'
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -31,36 +31,34 @@ def search_exa(
|
||||
count (int): Number of results to return
|
||||
filter_list (Optional[list[str]]): List of domains to filter results by
|
||||
"""
|
||||
log.info(f"Searching with Exa for query: {query}")
|
||||
log.info(f'Searching with Exa for query: {query}')
|
||||
|
||||
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
|
||||
headers = {'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'}
|
||||
|
||||
payload = {
|
||||
"query": query,
|
||||
"numResults": count or 5,
|
||||
"includeDomains": filter_list,
|
||||
"contents": {"text": True, "highlights": True},
|
||||
"type": "auto", # Use the auto search type (keyword or neural)
|
||||
'query': query,
|
||||
'numResults': count or 5,
|
||||
'includeDomains': filter_list,
|
||||
'contents': {'text': True, 'highlights': True},
|
||||
'type': 'auto', # Use the auto search type (keyword or neural)
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{EXA_API_BASE}/search", headers=headers, json=payload
|
||||
)
|
||||
response = requests.post(f'{EXA_API_BASE}/search', headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
results = []
|
||||
for result in data["results"]:
|
||||
for result in data['results']:
|
||||
results.append(
|
||||
ExaResult(
|
||||
url=result["url"],
|
||||
title=result["title"],
|
||||
text=result["text"],
|
||||
url=result['url'],
|
||||
title=result['title'],
|
||||
text=result['text'],
|
||||
)
|
||||
)
|
||||
|
||||
log.info(f"Found {len(results)} results")
|
||||
log.info(f'Found {len(results)} results')
|
||||
return [
|
||||
SearchResult(
|
||||
link=result.url,
|
||||
@@ -70,5 +68,5 @@ def search_exa(
|
||||
for result in results
|
||||
]
|
||||
except Exception as e:
|
||||
log.error(f"Error searching Exa: {e}")
|
||||
log.error(f'Error searching Exa: {e}')
|
||||
return []
|
||||
|
||||
@@ -24,12 +24,12 @@ def search_external(
|
||||
) -> List[SearchResult]:
|
||||
try:
|
||||
headers = {
|
||||
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
|
||||
"Authorization": f"Bearer {external_api_key}",
|
||||
'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) RAG Bot',
|
||||
'Authorization': f'Bearer {external_api_key}',
|
||||
}
|
||||
headers = include_user_info_headers(headers, user)
|
||||
|
||||
chat_id = getattr(request.state, "chat_id", None)
|
||||
chat_id = getattr(request.state, 'chat_id', None)
|
||||
if chat_id:
|
||||
headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = str(chat_id)
|
||||
|
||||
@@ -37,8 +37,8 @@ def search_external(
|
||||
external_url,
|
||||
headers=headers,
|
||||
json={
|
||||
"query": query,
|
||||
"count": count,
|
||||
'query': query,
|
||||
'count': count,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
@@ -47,14 +47,14 @@ def search_external(
|
||||
results = get_filtered_results(results, filter_list)
|
||||
results = [
|
||||
SearchResult(
|
||||
link=result.get("link"),
|
||||
title=result.get("title"),
|
||||
snippet=result.get("snippet"),
|
||||
link=result.get('link'),
|
||||
title=result.get('title'),
|
||||
snippet=result.get('snippet'),
|
||||
)
|
||||
for result in results[:count]
|
||||
]
|
||||
log.info(f"External search results: {results}")
|
||||
log.info(f'External search results: {results}')
|
||||
return results
|
||||
except Exception as e:
|
||||
log.error(f"Error in External search: {e}")
|
||||
log.error(f'Error in External search: {e}')
|
||||
return []
|
||||
|
||||
@@ -17,9 +17,7 @@ def search_firecrawl(
|
||||
from firecrawl import FirecrawlApp
|
||||
|
||||
firecrawl = FirecrawlApp(api_key=firecrawl_api_key, api_url=firecrawl_url)
|
||||
response = firecrawl.search(
|
||||
query=query, limit=count, ignore_invalid_urls=True, timeout=count * 3
|
||||
)
|
||||
response = firecrawl.search(query=query, limit=count, ignore_invalid_urls=True, timeout=count * 3)
|
||||
results = response.web
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
@@ -31,8 +29,8 @@ def search_firecrawl(
|
||||
)
|
||||
for result in results[:count]
|
||||
]
|
||||
log.info(f"External search results: {results}")
|
||||
log.info(f'External search results: {results}')
|
||||
return results
|
||||
except Exception as e:
|
||||
log.error(f"Error in External search: {e}")
|
||||
log.error(f'Error in External search: {e}')
|
||||
return []
|
||||
|
||||
@@ -28,11 +28,11 @@ def search_google_pse(
|
||||
Returns:
|
||||
list[SearchResult]: A list of SearchResult objects.
|
||||
"""
|
||||
url = "https://www.googleapis.com/customsearch/v1"
|
||||
url = 'https://www.googleapis.com/customsearch/v1'
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
if referer:
|
||||
headers["Referer"] = referer
|
||||
headers['Referer'] = referer
|
||||
|
||||
all_results = []
|
||||
start_index = 1 # Google PSE start parameter is 1-based
|
||||
@@ -40,21 +40,19 @@ def search_google_pse(
|
||||
while count > 0:
|
||||
num_results_this_page = min(count, 10) # Google PSE max results per page is 10
|
||||
params = {
|
||||
"cx": search_engine_id,
|
||||
"q": query,
|
||||
"key": api_key,
|
||||
"num": num_results_this_page,
|
||||
"start": start_index,
|
||||
'cx': search_engine_id,
|
||||
'q': query,
|
||||
'key': api_key,
|
||||
'num': num_results_this_page,
|
||||
'start': start_index,
|
||||
}
|
||||
response = requests.request("GET", url, headers=headers, params=params)
|
||||
response = requests.request('GET', url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
json_response = response.json()
|
||||
results = json_response.get("items", [])
|
||||
results = json_response.get('items', [])
|
||||
if results: # check if results are returned. If not, no more pages to fetch.
|
||||
all_results.extend(results)
|
||||
count -= len(
|
||||
results
|
||||
) # Decrement count by the number of results fetched in this page.
|
||||
count -= len(results) # Decrement count by the number of results fetched in this page.
|
||||
start_index += 10 # Increment start index for the next page
|
||||
else:
|
||||
break # No more results from Google PSE, break the loop
|
||||
@@ -64,9 +62,9 @@ def search_google_pse(
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("snippet"),
|
||||
link=result['link'],
|
||||
title=result.get('title'),
|
||||
snippet=result.get('snippet'),
|
||||
)
|
||||
for result in all_results
|
||||
]
|
||||
|
||||
@@ -7,9 +7,7 @@ from yarl import URL
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def search_jina(
|
||||
api_key: str, query: str, count: int, base_url: str = ""
|
||||
) -> list[SearchResult]:
|
||||
def search_jina(api_key: str, query: str, count: int, base_url: str = '') -> list[SearchResult]:
|
||||
"""
|
||||
Search using Jina's Search API and return the results as a list of SearchResult objects.
|
||||
Args:
|
||||
@@ -21,16 +19,16 @@ def search_jina(
|
||||
Returns:
|
||||
list[SearchResult]: A list of search results
|
||||
"""
|
||||
jina_search_endpoint = base_url if base_url else "https://s.jina.ai/"
|
||||
jina_search_endpoint = base_url if base_url else 'https://s.jina.ai/'
|
||||
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": api_key,
|
||||
"X-Retain-Images": "none",
|
||||
'Accept': 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': api_key,
|
||||
'X-Retain-Images': 'none',
|
||||
}
|
||||
|
||||
payload = {"q": query, "count": count if count <= 10 else 10}
|
||||
payload = {'q': query, 'count': count if count <= 10 else 10}
|
||||
|
||||
url = str(URL(jina_search_endpoint))
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
@@ -38,12 +36,12 @@ def search_jina(
|
||||
data = response.json()
|
||||
|
||||
results = []
|
||||
for result in data["data"]:
|
||||
for result in data['data']:
|
||||
results.append(
|
||||
SearchResult(
|
||||
link=result["url"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("content"),
|
||||
link=result['url'],
|
||||
title=result.get('title'),
|
||||
snippet=result.get('content'),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -7,9 +7,7 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def search_kagi(
|
||||
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
|
||||
) -> list[SearchResult]:
|
||||
def search_kagi(api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None) -> list[SearchResult]:
|
||||
"""Search using Kagi's Search API and return the results as a list of SearchResult objects.
|
||||
|
||||
The Search API will inherit the settings in your account, including results personalization and snippet length.
|
||||
@@ -19,23 +17,21 @@ def search_kagi(
|
||||
query (str): The query to search for
|
||||
count (int): The number of results to return
|
||||
"""
|
||||
url = "https://kagi.com/api/v0/search"
|
||||
url = 'https://kagi.com/api/v0/search'
|
||||
headers = {
|
||||
"Authorization": f"Bot {api_key}",
|
||||
'Authorization': f'Bot {api_key}',
|
||||
}
|
||||
params = {"q": query, "limit": count}
|
||||
params = {'q': query, 'limit': count}
|
||||
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
json_response = response.json()
|
||||
search_results = json_response.get("data", [])
|
||||
search_results = json_response.get('data', [])
|
||||
|
||||
results = [
|
||||
SearchResult(
|
||||
link=result["url"], title=result["title"], snippet=result.get("snippet")
|
||||
)
|
||||
SearchResult(link=result['url'], title=result['title'], snippet=result.get('snippet'))
|
||||
for result in search_results
|
||||
if result["t"] == 0
|
||||
if result['t'] == 0
|
||||
]
|
||||
|
||||
print(results)
|
||||
|
||||
@@ -16,7 +16,7 @@ def get_filtered_results(results, filter_list):
|
||||
filtered_results = []
|
||||
|
||||
for result in results:
|
||||
url = result.get("url") or result.get("link", "") or result.get("href", "")
|
||||
url = result.get('url') or result.get('link', '') or result.get('href', '')
|
||||
if not validators.url(url):
|
||||
continue
|
||||
|
||||
|
||||
@@ -7,32 +7,27 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def search_mojeek(
|
||||
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
|
||||
) -> list[SearchResult]:
|
||||
def search_mojeek(api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None) -> list[SearchResult]:
|
||||
"""Search using Mojeek's Search API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
api_key (str): A Mojeek Search API key
|
||||
query (str): The query to search for
|
||||
"""
|
||||
url = "https://api.mojeek.com/search"
|
||||
url = 'https://api.mojeek.com/search'
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
'Accept': 'application/json',
|
||||
}
|
||||
params = {"q": query, "api_key": api_key, "fmt": "json", "t": count}
|
||||
params = {'q': query, 'api_key': api_key, 'fmt': 'json', 't': count}
|
||||
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
json_response = response.json()
|
||||
results = json_response.get("response", {}).get("results", [])
|
||||
results = json_response.get('response', {}).get('results', [])
|
||||
print(results)
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["url"], title=result.get("title"), snippet=result.get("desc")
|
||||
)
|
||||
for result in results
|
||||
SearchResult(link=result['url'], title=result.get('title'), snippet=result.get('desc')) for result in results
|
||||
]
|
||||
|
||||
@@ -23,30 +23,30 @@ def search_ollama_cloud(
|
||||
count (int): Number of results to return
|
||||
filter_list (Optional[list[str]]): List of domains to filter results by
|
||||
"""
|
||||
log.info(f"Searching with Ollama for query: {query}")
|
||||
log.info(f'Searching with Ollama for query: {query}')
|
||||
|
||||
headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
|
||||
payload = {"query": query, "max_results": count}
|
||||
headers = {'Authorization': f'Bearer {api_key}', 'Content-Type': 'application/json'}
|
||||
payload = {'query': query, 'max_results': count}
|
||||
|
||||
try:
|
||||
response = requests.post(f"{url}/api/web_search", headers=headers, json=payload)
|
||||
response = requests.post(f'{url}/api/web_search', headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
results = data.get("results", [])
|
||||
log.info(f"Found {len(results)} results")
|
||||
results = data.get('results', [])
|
||||
log.info(f'Found {len(results)} results')
|
||||
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result.get("url", ""),
|
||||
title=result.get("title", ""),
|
||||
snippet=result.get("content", ""),
|
||||
link=result.get('url', ''),
|
||||
title=result.get('title', ''),
|
||||
snippet=result.get('content', ''),
|
||||
)
|
||||
for result in results
|
||||
]
|
||||
except Exception as e:
|
||||
log.error(f"Error searching Ollama: {e}")
|
||||
log.error(f'Error searching Ollama: {e}')
|
||||
return []
|
||||
|
||||
@@ -5,13 +5,13 @@ import requests
|
||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
|
||||
MODELS = Literal[
|
||||
"sonar",
|
||||
"sonar-pro",
|
||||
"sonar-reasoning",
|
||||
"sonar-reasoning-pro",
|
||||
"sonar-deep-research",
|
||||
'sonar',
|
||||
'sonar-pro',
|
||||
'sonar-reasoning',
|
||||
'sonar-reasoning-pro',
|
||||
'sonar-deep-research',
|
||||
]
|
||||
SEARCH_CONTEXT_USAGE_LEVELS = Literal["low", "medium", "high"]
|
||||
SEARCH_CONTEXT_USAGE_LEVELS = Literal['low', 'medium', 'high']
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -22,8 +22,8 @@ def search_perplexity(
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: Optional[list[str]] = None,
|
||||
model: MODELS = "sonar",
|
||||
search_context_usage: SEARCH_CONTEXT_USAGE_LEVELS = "medium",
|
||||
model: MODELS = 'sonar',
|
||||
search_context_usage: SEARCH_CONTEXT_USAGE_LEVELS = 'medium',
|
||||
) -> list[SearchResult]:
|
||||
"""Search using Perplexity API and return the results as a list of SearchResult objects.
|
||||
|
||||
@@ -38,66 +38,63 @@ def search_perplexity(
|
||||
"""
|
||||
|
||||
# Handle PersistentConfig object
|
||||
if hasattr(api_key, "__str__"):
|
||||
if hasattr(api_key, '__str__'):
|
||||
api_key = str(api_key)
|
||||
|
||||
try:
|
||||
url = "https://api.perplexity.ai/chat/completions"
|
||||
url = 'https://api.perplexity.ai/chat/completions'
|
||||
|
||||
# Create payload for the API call
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
'model': model,
|
||||
'messages': [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a search assistant. Provide factual information with citations.",
|
||||
'role': 'system',
|
||||
'content': 'You are a search assistant. Provide factual information with citations.',
|
||||
},
|
||||
{"role": "user", "content": query},
|
||||
{'role': 'user', 'content': query},
|
||||
],
|
||||
"temperature": 0.2, # Lower temperature for more factual responses
|
||||
"stream": False,
|
||||
"web_search_options": {
|
||||
"search_context_usage": search_context_usage,
|
||||
'temperature': 0.2, # Lower temperature for more factual responses
|
||||
'stream': False,
|
||||
'web_search_options': {
|
||||
'search_context_usage': search_context_usage,
|
||||
},
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
'Authorization': f'Bearer {api_key}',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
# Make the API request
|
||||
response = requests.request("POST", url, json=payload, headers=headers)
|
||||
response = requests.request('POST', url, json=payload, headers=headers)
|
||||
|
||||
# Parse the JSON response
|
||||
json_response = response.json()
|
||||
|
||||
# Extract citations from the response
|
||||
citations = json_response.get("citations", [])
|
||||
citations = json_response.get('citations', [])
|
||||
|
||||
# Create search results from citations
|
||||
results = []
|
||||
for i, citation in enumerate(citations[:count]):
|
||||
# Extract content from the response to use as snippet
|
||||
content = ""
|
||||
if "choices" in json_response and json_response["choices"]:
|
||||
content = ''
|
||||
if 'choices' in json_response and json_response['choices']:
|
||||
if i == 0:
|
||||
content = json_response["choices"][0]["message"]["content"]
|
||||
content = json_response['choices'][0]['message']['content']
|
||||
|
||||
result = {"link": citation, "title": f"Source {i+1}", "snippet": content}
|
||||
result = {'link': citation, 'title': f'Source {i + 1}', 'snippet': content}
|
||||
results.append(result)
|
||||
|
||||
if filter_list:
|
||||
|
||||
results = get_filtered_results(results, filter_list)
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"], title=result["title"], snippet=result["snippet"]
|
||||
)
|
||||
SearchResult(link=result['link'], title=result['title'], snippet=result['snippet'])
|
||||
for result in results[:count]
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error searching with Perplexity API: {e}")
|
||||
log.error(f'Error searching with Perplexity API: {e}')
|
||||
return []
|
||||
|
||||
@@ -13,7 +13,7 @@ def search_perplexity_search(
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: Optional[list[str]] = None,
|
||||
api_url: str = "https://api.perplexity.ai/search",
|
||||
api_url: str = 'https://api.perplexity.ai/search',
|
||||
user=None,
|
||||
) -> list[SearchResult]:
|
||||
"""Search using Perplexity API and return the results as a list of SearchResult objects.
|
||||
@@ -29,10 +29,10 @@ def search_perplexity_search(
|
||||
"""
|
||||
|
||||
# Handle PersistentConfig object
|
||||
if hasattr(api_key, "__str__"):
|
||||
if hasattr(api_key, '__str__'):
|
||||
api_key = str(api_key)
|
||||
|
||||
if hasattr(api_url, "__str__"):
|
||||
if hasattr(api_url, '__str__'):
|
||||
api_url = str(api_url)
|
||||
|
||||
try:
|
||||
@@ -40,13 +40,13 @@ def search_perplexity_search(
|
||||
|
||||
# Create payload for the API call
|
||||
payload = {
|
||||
"query": query,
|
||||
"max_results": count,
|
||||
'query': query,
|
||||
'max_results': count,
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
'Authorization': f'Bearer {api_key}',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
# Forward user info headers if user is provided
|
||||
@@ -54,20 +54,17 @@ def search_perplexity_search(
|
||||
headers = include_user_info_headers(headers, user)
|
||||
|
||||
# Make the API request
|
||||
response = requests.request("POST", url, json=payload, headers=headers)
|
||||
response = requests.request('POST', url, json=payload, headers=headers)
|
||||
# Parse the JSON response
|
||||
json_response = response.json()
|
||||
|
||||
# Extract citations from the response
|
||||
results = json_response.get("results", [])
|
||||
results = json_response.get('results', [])
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["url"], title=result["title"], snippet=result["snippet"]
|
||||
)
|
||||
for result in results
|
||||
SearchResult(link=result['url'], title=result['title'], snippet=result['snippet']) for result in results
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error searching with Perplexity Search API: {e}")
|
||||
log.error(f'Error searching with Perplexity Search API: {e}')
|
||||
return []
|
||||
|
||||
@@ -21,28 +21,26 @@ def search_searchapi(
|
||||
api_key (str): A searchapi.io API key
|
||||
query (str): The query to search for
|
||||
"""
|
||||
url = "https://www.searchapi.io/api/v1/search"
|
||||
url = 'https://www.searchapi.io/api/v1/search'
|
||||
|
||||
engine = engine or "google"
|
||||
engine = engine or 'google'
|
||||
|
||||
payload = {"engine": engine, "q": query, "api_key": api_key}
|
||||
payload = {'engine': engine, 'q': query, 'api_key': api_key}
|
||||
|
||||
url = f"{url}?{urlencode(payload)}"
|
||||
response = requests.request("GET", url)
|
||||
url = f'{url}?{urlencode(payload)}'
|
||||
response = requests.request('GET', url)
|
||||
|
||||
json_response = response.json()
|
||||
log.info(f"results from searchapi search: {json_response}")
|
||||
log.info(f'results from searchapi search: {json_response}')
|
||||
|
||||
results = sorted(
|
||||
json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
|
||||
)
|
||||
results = sorted(json_response.get('organic_results', []), key=lambda x: x.get('position', 0))
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("snippet"),
|
||||
link=result['link'],
|
||||
title=result.get('title'),
|
||||
snippet=result.get('snippet'),
|
||||
)
|
||||
for result in results[:count]
|
||||
]
|
||||
|
||||
@@ -38,38 +38,38 @@ def search_searxng(
|
||||
"""
|
||||
|
||||
# Default values for optional parameters are provided as empty strings or None when not specified.
|
||||
language = kwargs.get("language", "all")
|
||||
safesearch = kwargs.get("safesearch", "1")
|
||||
time_range = kwargs.get("time_range", "")
|
||||
categories = "".join(kwargs.get("categories", []))
|
||||
language = kwargs.get('language', 'all')
|
||||
safesearch = kwargs.get('safesearch', '1')
|
||||
time_range = kwargs.get('time_range', '')
|
||||
categories = ''.join(kwargs.get('categories', []))
|
||||
|
||||
params = {
|
||||
"q": query,
|
||||
"format": "json",
|
||||
"pageno": 1,
|
||||
"safesearch": safesearch,
|
||||
"language": language,
|
||||
"time_range": time_range,
|
||||
"categories": categories,
|
||||
"theme": "simple",
|
||||
"image_proxy": 0,
|
||||
'q': query,
|
||||
'format': 'json',
|
||||
'pageno': 1,
|
||||
'safesearch': safesearch,
|
||||
'language': language,
|
||||
'time_range': time_range,
|
||||
'categories': categories,
|
||||
'theme': 'simple',
|
||||
'image_proxy': 0,
|
||||
}
|
||||
|
||||
# Legacy query format
|
||||
if "<query>" in query_url:
|
||||
if '<query>' in query_url:
|
||||
# Strip all query parameters from the URL
|
||||
query_url = query_url.split("?")[0]
|
||||
query_url = query_url.split('?')[0]
|
||||
|
||||
log.debug(f"searching {query_url}")
|
||||
log.debug(f'searching {query_url}')
|
||||
|
||||
response = requests.get(
|
||||
query_url,
|
||||
headers={
|
||||
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
|
||||
"Accept": "text/html",
|
||||
"Accept-Encoding": "gzip, deflate",
|
||||
"Accept-Language": "en-US,en;q=0.5",
|
||||
"Connection": "keep-alive",
|
||||
'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) RAG Bot',
|
||||
'Accept': 'text/html',
|
||||
'Accept-Encoding': 'gzip, deflate',
|
||||
'Accept-Language': 'en-US,en;q=0.5',
|
||||
'Connection': 'keep-alive',
|
||||
},
|
||||
params=params,
|
||||
)
|
||||
@@ -77,13 +77,11 @@ def search_searxng(
|
||||
response.raise_for_status() # Raise an exception for HTTP errors.
|
||||
|
||||
json_response = response.json()
|
||||
results = json_response.get("results", [])
|
||||
sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
|
||||
results = json_response.get('results', [])
|
||||
sorted_results = sorted(results, key=lambda x: x.get('score', 0), reverse=True)
|
||||
if filter_list:
|
||||
sorted_results = get_filtered_results(sorted_results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["url"], title=result.get("title"), snippet=result.get("content")
|
||||
)
|
||||
SearchResult(link=result['url'], title=result.get('title'), snippet=result.get('content'))
|
||||
for result in sorted_results[:count]
|
||||
]
|
||||
|
||||
@@ -21,28 +21,26 @@ def search_serpapi(
|
||||
api_key (str): A serpapi.com API key
|
||||
query (str): The query to search for
|
||||
"""
|
||||
url = "https://serpapi.com/search"
|
||||
url = 'https://serpapi.com/search'
|
||||
|
||||
engine = engine or "google"
|
||||
engine = engine or 'google'
|
||||
|
||||
payload = {"engine": engine, "q": query, "api_key": api_key}
|
||||
payload = {'engine': engine, 'q': query, 'api_key': api_key}
|
||||
|
||||
url = f"{url}?{urlencode(payload)}"
|
||||
response = requests.request("GET", url)
|
||||
url = f'{url}?{urlencode(payload)}'
|
||||
response = requests.request('GET', url)
|
||||
|
||||
json_response = response.json()
|
||||
log.info(f"results from serpapi search: {json_response}")
|
||||
log.info(f'results from serpapi search: {json_response}')
|
||||
|
||||
results = sorted(
|
||||
json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
|
||||
)
|
||||
results = sorted(json_response.get('organic_results', []), key=lambda x: x.get('position', 0))
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("snippet"),
|
||||
link=result['link'],
|
||||
title=result.get('title'),
|
||||
snippet=result.get('snippet'),
|
||||
)
|
||||
for result in results[:count]
|
||||
]
|
||||
|
||||
@@ -8,34 +8,30 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def search_serper(
|
||||
api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None
|
||||
) -> list[SearchResult]:
|
||||
def search_serper(api_key: str, query: str, count: int, filter_list: Optional[list[str]] = None) -> list[SearchResult]:
|
||||
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
api_key (str): A serper.dev API key
|
||||
query (str): The query to search for
|
||||
"""
|
||||
url = "https://google.serper.dev/search"
|
||||
url = 'https://google.serper.dev/search'
|
||||
|
||||
payload = json.dumps({"q": query})
|
||||
headers = {"X-API-KEY": api_key, "Content-Type": "application/json"}
|
||||
payload = json.dumps({'q': query})
|
||||
headers = {'X-API-KEY': api_key, 'Content-Type': 'application/json'}
|
||||
|
||||
response = requests.request("POST", url, headers=headers, data=payload)
|
||||
response = requests.request('POST', url, headers=headers, data=payload)
|
||||
response.raise_for_status()
|
||||
|
||||
json_response = response.json()
|
||||
results = sorted(
|
||||
json_response.get("organic", []), key=lambda x: x.get("position", 0)
|
||||
)
|
||||
results = sorted(json_response.get('organic', []), key=lambda x: x.get('position', 0))
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("description"),
|
||||
link=result['link'],
|
||||
title=result.get('title'),
|
||||
snippet=result.get('description'),
|
||||
)
|
||||
for result in results[:count]
|
||||
]
|
||||
|
||||
@@ -12,10 +12,10 @@ def search_serply(
|
||||
api_key: str,
|
||||
query: str,
|
||||
count: int,
|
||||
hl: str = "us",
|
||||
hl: str = 'us',
|
||||
limit: int = 10,
|
||||
device_type: str = "desktop",
|
||||
proxy_location: str = "US",
|
||||
device_type: str = 'desktop',
|
||||
proxy_location: str = 'US',
|
||||
filter_list: Optional[list[str]] = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
|
||||
@@ -26,42 +26,40 @@ def search_serply(
|
||||
hl (str): Host Language code to display results in (reference https://developers.google.com/custom-search/docs/xml_results?hl=en#wsInterfaceLanguages)
|
||||
limit (int): The maximum number of results to return [10-100, defaults to 10]
|
||||
"""
|
||||
log.info("Searching with Serply")
|
||||
log.info('Searching with Serply')
|
||||
|
||||
url = "https://api.serply.io/v1/search/"
|
||||
url = 'https://api.serply.io/v1/search/'
|
||||
|
||||
query_payload = {
|
||||
"q": query,
|
||||
"language": "en",
|
||||
"num": limit,
|
||||
"gl": proxy_location.upper(),
|
||||
"hl": hl.lower(),
|
||||
'q': query,
|
||||
'language': 'en',
|
||||
'num': limit,
|
||||
'gl': proxy_location.upper(),
|
||||
'hl': hl.lower(),
|
||||
}
|
||||
|
||||
url = f"{url}{urlencode(query_payload)}"
|
||||
url = f'{url}{urlencode(query_payload)}'
|
||||
headers = {
|
||||
"X-API-KEY": api_key,
|
||||
"X-User-Agent": device_type,
|
||||
"User-Agent": "open-webui",
|
||||
"X-Proxy-Location": proxy_location,
|
||||
'X-API-KEY': api_key,
|
||||
'X-User-Agent': device_type,
|
||||
'User-Agent': 'open-webui',
|
||||
'X-Proxy-Location': proxy_location,
|
||||
}
|
||||
|
||||
response = requests.request("GET", url, headers=headers)
|
||||
response = requests.request('GET', url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
json_response = response.json()
|
||||
log.info(f"results from serply search: {json_response}")
|
||||
log.info(f'results from serply search: {json_response}')
|
||||
|
||||
results = sorted(
|
||||
json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
|
||||
)
|
||||
results = sorted(json_response.get('results', []), key=lambda x: x.get('realPosition', 0))
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("description"),
|
||||
link=result['link'],
|
||||
title=result.get('title'),
|
||||
snippet=result.get('description'),
|
||||
)
|
||||
for result in results[:count]
|
||||
]
|
||||
|
||||
@@ -21,26 +21,22 @@ def search_serpstack(
|
||||
query (str): The query to search for
|
||||
https_enabled (bool): Whether to use HTTPS or HTTP for the API request
|
||||
"""
|
||||
url = f"{'https' if https_enabled else 'http'}://api.serpstack.com/search"
|
||||
url = f'{"https" if https_enabled else "http"}://api.serpstack.com/search'
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
params = {
|
||||
"access_key": api_key,
|
||||
"query": query,
|
||||
'access_key': api_key,
|
||||
'query': query,
|
||||
}
|
||||
|
||||
response = requests.request("POST", url, headers=headers, params=params)
|
||||
response = requests.request('POST', url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
|
||||
json_response = response.json()
|
||||
results = sorted(
|
||||
json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
|
||||
)
|
||||
results = sorted(json_response.get('organic_results', []), key=lambda x: x.get('position', 0))
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
|
||||
)
|
||||
SearchResult(link=result['url'], title=result.get('title'), snippet=result.get('snippet'))
|
||||
for result in results[:count]
|
||||
]
|
||||
|
||||
@@ -26,33 +26,26 @@ def search_sougou(
|
||||
try:
|
||||
cred = credential.Credential(sougou_api_sid, sougou_api_sk)
|
||||
http_profile = HttpProfile()
|
||||
http_profile.endpoint = "tms.tencentcloudapi.com"
|
||||
http_profile.endpoint = 'tms.tencentcloudapi.com'
|
||||
client_profile = ClientProfile()
|
||||
client_profile.http_profile = http_profile
|
||||
params = json.dumps({"Query": query, "Cnt": 20})
|
||||
common_client = CommonClient(
|
||||
"tms", "2020-12-29", cred, "", profile=client_profile
|
||||
)
|
||||
params = json.dumps({'Query': query, 'Cnt': 20})
|
||||
common_client = CommonClient('tms', '2020-12-29', cred, '', profile=client_profile)
|
||||
results = [
|
||||
json.loads(page)
|
||||
for page in common_client.call_json("SearchPro", json.loads(params))[
|
||||
"Response"
|
||||
]["Pages"]
|
||||
json.loads(page) for page in common_client.call_json('SearchPro', json.loads(params))['Response']['Pages']
|
||||
]
|
||||
sorted_results = sorted(
|
||||
results, key=lambda x: x.get("scour", 0.0), reverse=True
|
||||
)
|
||||
sorted_results = sorted(results, key=lambda x: x.get('scour', 0.0), reverse=True)
|
||||
if filter_list:
|
||||
sorted_results = get_filtered_results(sorted_results, filter_list)
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result.get("url"),
|
||||
title=result.get("title"),
|
||||
snippet=result.get("passage"),
|
||||
link=result.get('url'),
|
||||
title=result.get('title'),
|
||||
snippet=result.get('passage'),
|
||||
)
|
||||
for result in sorted_results[:count]
|
||||
]
|
||||
except TencentCloudSDKException as err:
|
||||
log.error(f"Error in Sougou search: {err}")
|
||||
log.error(f'Error in Sougou search: {err}')
|
||||
return []
|
||||
|
||||
@@ -24,26 +24,26 @@ def search_tavily(
|
||||
Returns:
|
||||
list[SearchResult]: A list of search results
|
||||
"""
|
||||
url = "https://api.tavily.com/search"
|
||||
url = 'https://api.tavily.com/search'
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {api_key}',
|
||||
}
|
||||
data = {"query": query, "max_results": count}
|
||||
data = {'query': query, 'max_results': count}
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
|
||||
json_response = response.json()
|
||||
|
||||
results = json_response.get("results", [])
|
||||
results = json_response.get('results', [])
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["url"],
|
||||
title=result.get("title", ""),
|
||||
snippet=result.get("content"),
|
||||
link=result['url'],
|
||||
title=result.get('title', ''),
|
||||
snippet=result.get('content'),
|
||||
)
|
||||
for result in results
|
||||
]
|
||||
|
||||
@@ -67,16 +67,14 @@ def validate_url(url: Union[str, Sequence[str]]):
|
||||
parsed_url = urllib.parse.urlparse(url)
|
||||
|
||||
# Protocol validation - only allow http/https
|
||||
if parsed_url.scheme not in ["http", "https"]:
|
||||
log.warning(
|
||||
f"Blocked non-HTTP(S) protocol: {parsed_url.scheme} in URL: {url}"
|
||||
)
|
||||
if parsed_url.scheme not in ['http', 'https']:
|
||||
log.warning(f'Blocked non-HTTP(S) protocol: {parsed_url.scheme} in URL: {url}')
|
||||
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
||||
|
||||
# Blocklist check using unified filtering logic
|
||||
if WEB_FETCH_FILTER_LIST:
|
||||
if not is_string_allowed(url, WEB_FETCH_FILTER_LIST):
|
||||
log.warning(f"URL blocked by filter list: {url}")
|
||||
log.warning(f'URL blocked by filter list: {url}')
|
||||
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
||||
|
||||
if not ENABLE_RAG_LOCAL_WEB_FETCH:
|
||||
@@ -106,29 +104,29 @@ def safe_validate_urls(url: Sequence[str]) -> Sequence[str]:
|
||||
if validate_url(u):
|
||||
valid_urls.append(u)
|
||||
except Exception as e:
|
||||
log.debug(f"Invalid URL {u}: {str(e)}")
|
||||
log.debug(f'Invalid URL {u}: {str(e)}')
|
||||
continue
|
||||
return valid_urls
|
||||
|
||||
|
||||
def extract_metadata(soup, url):
|
||||
metadata = {"source": url}
|
||||
if title := soup.find("title"):
|
||||
metadata["title"] = title.get_text()
|
||||
if description := soup.find("meta", attrs={"name": "description"}):
|
||||
metadata["description"] = description.get("content", "No description found.")
|
||||
if html := soup.find("html"):
|
||||
metadata["language"] = html.get("lang", "No language found.")
|
||||
metadata = {'source': url}
|
||||
if title := soup.find('title'):
|
||||
metadata['title'] = title.get_text()
|
||||
if description := soup.find('meta', attrs={'name': 'description'}):
|
||||
metadata['description'] = description.get('content', 'No description found.')
|
||||
if html := soup.find('html'):
|
||||
metadata['language'] = html.get('lang', 'No language found.')
|
||||
return metadata
|
||||
|
||||
|
||||
def verify_ssl_cert(url: str) -> bool:
|
||||
"""Verify SSL certificate for the given URL."""
|
||||
if not url.startswith("https://"):
|
||||
if not url.startswith('https://'):
|
||||
return True
|
||||
|
||||
try:
|
||||
hostname = url.split("://")[-1].split("/")[0]
|
||||
hostname = url.split('://')[-1].split('/')[0]
|
||||
context = ssl.create_default_context(cafile=certifi.where())
|
||||
with context.wrap_socket(ssl.socket(), server_hostname=hostname) as s:
|
||||
s.connect((hostname, 443))
|
||||
@@ -136,7 +134,7 @@ def verify_ssl_cert(url: str) -> bool:
|
||||
except ssl.SSLError:
|
||||
return False
|
||||
except Exception as e:
|
||||
log.warning(f"SSL verification failed for {url}: {str(e)}")
|
||||
log.warning(f'SSL verification failed for {url}: {str(e)}')
|
||||
return False
|
||||
|
||||
|
||||
@@ -168,14 +166,14 @@ class URLProcessingMixin:
|
||||
async def _safe_process_url(self, url: str) -> bool:
|
||||
"""Perform safety checks before processing a URL."""
|
||||
if self.verify_ssl and not await self._verify_ssl_cert(url):
|
||||
raise ValueError(f"SSL certificate verification failed for {url}")
|
||||
raise ValueError(f'SSL certificate verification failed for {url}')
|
||||
await self._wait_for_rate_limit()
|
||||
return True
|
||||
|
||||
def _safe_process_url_sync(self, url: str) -> bool:
|
||||
"""Synchronous version of safety checks."""
|
||||
if self.verify_ssl and not verify_ssl_cert(url):
|
||||
raise ValueError(f"SSL certificate verification failed for {url}")
|
||||
raise ValueError(f'SSL certificate verification failed for {url}')
|
||||
self._sync_wait_for_rate_limit()
|
||||
return True
|
||||
|
||||
@@ -191,7 +189,7 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
|
||||
api_key: Optional[str] = None,
|
||||
api_url: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
mode: Literal["crawl", "scrape", "map"] = "scrape",
|
||||
mode: Literal['crawl', 'scrape', 'map'] = 'scrape',
|
||||
proxy: Optional[Dict[str, str]] = None,
|
||||
params: Optional[Dict] = None,
|
||||
):
|
||||
@@ -216,15 +214,15 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
|
||||
params: The parameters to pass to the Firecrawl API.
|
||||
For more details, visit: https://docs.firecrawl.dev/sdks/python#batch-scrape
|
||||
"""
|
||||
proxy_server = proxy.get("server") if proxy else None
|
||||
proxy_server = proxy.get('server') if proxy else None
|
||||
if trust_env and not proxy_server:
|
||||
env_proxies = urllib.request.getproxies()
|
||||
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
|
||||
env_proxy_server = env_proxies.get('https') or env_proxies.get('http')
|
||||
if env_proxy_server:
|
||||
if proxy:
|
||||
proxy["server"] = env_proxy_server
|
||||
proxy['server'] = env_proxy_server
|
||||
else:
|
||||
proxy = {"server": env_proxy_server}
|
||||
proxy = {'server': env_proxy_server}
|
||||
self.web_paths = web_paths
|
||||
self.verify_ssl = verify_ssl
|
||||
self.requests_per_second = requests_per_second
|
||||
@@ -240,7 +238,7 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
"""Load documents using FireCrawl batch_scrape."""
|
||||
log.debug(
|
||||
"Starting FireCrawl batch scrape for %d URLs, mode: %s, params: %s",
|
||||
'Starting FireCrawl batch scrape for %d URLs, mode: %s, params: %s',
|
||||
len(self.web_paths),
|
||||
self.mode,
|
||||
self.params,
|
||||
@@ -251,7 +249,7 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
|
||||
firecrawl = FirecrawlApp(api_key=self.api_key, api_url=self.api_url)
|
||||
result = firecrawl.batch_scrape(
|
||||
self.web_paths,
|
||||
formats=["markdown"],
|
||||
formats=['markdown'],
|
||||
skip_tls_verification=not self.verify_ssl,
|
||||
ignore_invalid_urls=True,
|
||||
remove_base64_images=True,
|
||||
@@ -260,28 +258,26 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
|
||||
**self.params,
|
||||
)
|
||||
|
||||
if result.status != "completed":
|
||||
raise RuntimeError(
|
||||
f"FireCrawl batch scrape did not complete successfully. result: {result}"
|
||||
)
|
||||
if result.status != 'completed':
|
||||
raise RuntimeError(f'FireCrawl batch scrape did not complete successfully. result: {result}')
|
||||
|
||||
for data in result.data:
|
||||
metadata = data.metadata or {}
|
||||
yield Document(
|
||||
page_content=data.markdown or "",
|
||||
metadata={"source": metadata.url or metadata.source_url or ""},
|
||||
page_content=data.markdown or '',
|
||||
metadata={'source': metadata.url or metadata.source_url or ''},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.exception(f"Error extracting content from URLs: {e}")
|
||||
log.exception(f'Error extracting content from URLs: {e}')
|
||||
else:
|
||||
raise e
|
||||
|
||||
async def alazy_load(self):
|
||||
"""Async version of lazy_load."""
|
||||
log.debug(
|
||||
"Starting FireCrawl batch scrape for %d URLs, mode: %s, params: %s",
|
||||
'Starting FireCrawl batch scrape for %d URLs, mode: %s, params: %s',
|
||||
len(self.web_paths),
|
||||
self.mode,
|
||||
self.params,
|
||||
@@ -292,7 +288,7 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
|
||||
firecrawl = FirecrawlApp(api_key=self.api_key, api_url=self.api_url)
|
||||
result = firecrawl.batch_scrape(
|
||||
self.web_paths,
|
||||
formats=["markdown"],
|
||||
formats=['markdown'],
|
||||
skip_tls_verification=not self.verify_ssl,
|
||||
ignore_invalid_urls=True,
|
||||
remove_base64_images=True,
|
||||
@@ -301,21 +297,19 @@ class SafeFireCrawlLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
|
||||
**self.params,
|
||||
)
|
||||
|
||||
if result.status != "completed":
|
||||
raise RuntimeError(
|
||||
f"FireCrawl batch scrape did not complete successfully. result: {result}"
|
||||
)
|
||||
if result.status != 'completed':
|
||||
raise RuntimeError(f'FireCrawl batch scrape did not complete successfully. result: {result}')
|
||||
|
||||
for data in result.data:
|
||||
metadata = data.metadata or {}
|
||||
yield Document(
|
||||
page_content=data.markdown or "",
|
||||
metadata={"source": metadata.url or metadata.source_url or ""},
|
||||
page_content=data.markdown or '',
|
||||
metadata={'source': metadata.url or metadata.source_url or ''},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.exception(f"Error extracting content from URLs: {e}")
|
||||
log.exception(f'Error extracting content from URLs: {e}')
|
||||
else:
|
||||
raise e
|
||||
|
||||
@@ -325,7 +319,7 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
|
||||
self,
|
||||
web_paths: Union[str, List[str]],
|
||||
api_key: str,
|
||||
extract_depth: Literal["basic", "advanced"] = "basic",
|
||||
extract_depth: Literal['basic', 'advanced'] = 'basic',
|
||||
continue_on_failure: bool = True,
|
||||
requests_per_second: Optional[float] = None,
|
||||
verify_ssl: bool = True,
|
||||
@@ -345,15 +339,15 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
|
||||
proxy: Optional proxy configuration.
|
||||
"""
|
||||
# Initialize proxy configuration if using environment variables
|
||||
proxy_server = proxy.get("server") if proxy else None
|
||||
proxy_server = proxy.get('server') if proxy else None
|
||||
if trust_env and not proxy_server:
|
||||
env_proxies = urllib.request.getproxies()
|
||||
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
|
||||
env_proxy_server = env_proxies.get('https') or env_proxies.get('http')
|
||||
if env_proxy_server:
|
||||
if proxy:
|
||||
proxy["server"] = env_proxy_server
|
||||
proxy['server'] = env_proxy_server
|
||||
else:
|
||||
proxy = {"server": env_proxy_server}
|
||||
proxy = {'server': env_proxy_server}
|
||||
|
||||
# Store parameters for creating TavilyLoader instances
|
||||
self.web_paths = web_paths if isinstance(web_paths, list) else [web_paths]
|
||||
@@ -376,14 +370,14 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
|
||||
self._safe_process_url_sync(url)
|
||||
valid_urls.append(url)
|
||||
except Exception as e:
|
||||
log.warning(f"SSL verification failed for {url}: {str(e)}")
|
||||
log.warning(f'SSL verification failed for {url}: {str(e)}')
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
if not valid_urls:
|
||||
if self.continue_on_failure:
|
||||
log.warning("No valid URLs to process after SSL verification")
|
||||
log.warning('No valid URLs to process after SSL verification')
|
||||
return
|
||||
raise ValueError("No valid URLs to process after SSL verification")
|
||||
raise ValueError('No valid URLs to process after SSL verification')
|
||||
try:
|
||||
loader = TavilyLoader(
|
||||
urls=valid_urls,
|
||||
@@ -394,7 +388,7 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
|
||||
yield from loader.lazy_load()
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.exception(f"Error extracting content from URLs: {e}")
|
||||
log.exception(f'Error extracting content from URLs: {e}')
|
||||
else:
|
||||
raise e
|
||||
|
||||
@@ -406,15 +400,15 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
|
||||
await self._safe_process_url(url)
|
||||
valid_urls.append(url)
|
||||
except Exception as e:
|
||||
log.warning(f"SSL verification failed for {url}: {str(e)}")
|
||||
log.warning(f'SSL verification failed for {url}: {str(e)}')
|
||||
if not self.continue_on_failure:
|
||||
raise e
|
||||
|
||||
if not valid_urls:
|
||||
if self.continue_on_failure:
|
||||
log.warning("No valid URLs to process after SSL verification")
|
||||
log.warning('No valid URLs to process after SSL verification')
|
||||
return
|
||||
raise ValueError("No valid URLs to process after SSL verification")
|
||||
raise ValueError('No valid URLs to process after SSL verification')
|
||||
|
||||
try:
|
||||
loader = TavilyLoader(
|
||||
@@ -427,7 +421,7 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
|
||||
yield document
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.exception(f"Error loading URLs: {e}")
|
||||
log.exception(f'Error loading URLs: {e}')
|
||||
else:
|
||||
raise e
|
||||
|
||||
@@ -462,15 +456,15 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessing
|
||||
):
|
||||
"""Initialize with additional safety parameters and remote browser support."""
|
||||
|
||||
proxy_server = proxy.get("server") if proxy else None
|
||||
proxy_server = proxy.get('server') if proxy else None
|
||||
if trust_env and not proxy_server:
|
||||
env_proxies = urllib.request.getproxies()
|
||||
env_proxy_server = env_proxies.get("https") or env_proxies.get("http")
|
||||
env_proxy_server = env_proxies.get('https') or env_proxies.get('http')
|
||||
if env_proxy_server:
|
||||
if proxy:
|
||||
proxy["server"] = env_proxy_server
|
||||
proxy['server'] = env_proxy_server
|
||||
else:
|
||||
proxy = {"server": env_proxy_server}
|
||||
proxy = {'server': env_proxy_server}
|
||||
|
||||
# We'll set headless to False if using playwright_ws_url since it's handled by the remote browser
|
||||
super().__init__(
|
||||
@@ -504,14 +498,14 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessing
|
||||
page = browser.new_page()
|
||||
response = page.goto(url, timeout=self.playwright_timeout)
|
||||
if response is None:
|
||||
raise ValueError(f"page.goto() returned None for url {url}")
|
||||
raise ValueError(f'page.goto() returned None for url {url}')
|
||||
|
||||
text = self.evaluator.evaluate(page, browser, response)
|
||||
metadata = {"source": url}
|
||||
metadata = {'source': url}
|
||||
yield Document(page_content=text, metadata=metadata)
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.exception(f"Error loading {url}: {e}")
|
||||
log.exception(f'Error loading {url}: {e}')
|
||||
continue
|
||||
raise e
|
||||
browser.close()
|
||||
@@ -525,9 +519,7 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessing
|
||||
if self.playwright_ws_url:
|
||||
browser = await p.chromium.connect(self.playwright_ws_url)
|
||||
else:
|
||||
browser = await p.chromium.launch(
|
||||
headless=self.headless, proxy=self.proxy
|
||||
)
|
||||
browser = await p.chromium.launch(headless=self.headless, proxy=self.proxy)
|
||||
|
||||
for url in self.urls:
|
||||
try:
|
||||
@@ -535,14 +527,14 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessing
|
||||
page = await browser.new_page()
|
||||
response = await page.goto(url, timeout=self.playwright_timeout)
|
||||
if response is None:
|
||||
raise ValueError(f"page.goto() returned None for url {url}")
|
||||
raise ValueError(f'page.goto() returned None for url {url}')
|
||||
|
||||
text = await self.evaluator.evaluate_async(page, browser, response)
|
||||
metadata = {"source": url}
|
||||
metadata = {'source': url}
|
||||
yield Document(page_content=text, metadata=metadata)
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.exception(f"Error loading {url}: {e}")
|
||||
log.exception(f'Error loading {url}: {e}')
|
||||
continue
|
||||
raise e
|
||||
await browser.close()
|
||||
@@ -560,9 +552,7 @@ class SafeWebBaseLoader(WebBaseLoader):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.trust_env = trust_env
|
||||
|
||||
async def _fetch(
|
||||
self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
|
||||
) -> str:
|
||||
async def _fetch(self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5) -> str:
|
||||
async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
|
||||
for i in range(retries):
|
||||
try:
|
||||
@@ -571,7 +561,7 @@ class SafeWebBaseLoader(WebBaseLoader):
|
||||
cookies=self.session.cookies.get_dict(),
|
||||
)
|
||||
if not self.session.verify:
|
||||
kwargs["ssl"] = False
|
||||
kwargs['ssl'] = False
|
||||
|
||||
async with session.get(
|
||||
url,
|
||||
@@ -585,16 +575,11 @@ class SafeWebBaseLoader(WebBaseLoader):
|
||||
if i == retries - 1:
|
||||
raise
|
||||
else:
|
||||
log.warning(
|
||||
f"Error fetching {url} with attempt "
|
||||
f"{i + 1}/{retries}: {e}. Retrying..."
|
||||
)
|
||||
log.warning(f'Error fetching {url} with attempt {i + 1}/{retries}: {e}. Retrying...')
|
||||
await asyncio.sleep(cooldown * backoff**i)
|
||||
raise ValueError("retry count exceeded")
|
||||
raise ValueError('retry count exceeded')
|
||||
|
||||
def _unpack_fetch_results(
|
||||
self, results: Any, urls: List[str], parser: Union[str, None] = None
|
||||
) -> List[Any]:
|
||||
def _unpack_fetch_results(self, results: Any, urls: List[str], parser: Union[str, None] = None) -> List[Any]:
|
||||
"""Unpack fetch results into BeautifulSoup objects."""
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
@@ -602,17 +587,15 @@ class SafeWebBaseLoader(WebBaseLoader):
|
||||
for i, result in enumerate(results):
|
||||
url = urls[i]
|
||||
if parser is None:
|
||||
if url.endswith(".xml"):
|
||||
parser = "xml"
|
||||
if url.endswith('.xml'):
|
||||
parser = 'xml'
|
||||
else:
|
||||
parser = self.default_parser
|
||||
self._check_parser(parser)
|
||||
final_results.append(BeautifulSoup(result, parser, **self.bs_kwargs))
|
||||
return final_results
|
||||
|
||||
async def ascrape_all(
|
||||
self, urls: List[str], parser: Union[str, None] = None
|
||||
) -> List[Any]:
|
||||
async def ascrape_all(self, urls: List[str], parser: Union[str, None] = None) -> List[Any]:
|
||||
"""Async fetch all urls, then return soups for all results."""
|
||||
results = await self.fetch_all(urls)
|
||||
return self._unpack_fetch_results(results, urls, parser=parser)
|
||||
@@ -630,22 +613,20 @@ class SafeWebBaseLoader(WebBaseLoader):
|
||||
yield Document(page_content=text, metadata=metadata)
|
||||
except Exception as e:
|
||||
# Log the error and continue with the next URL
|
||||
log.exception(f"Error loading {path}: {e}")
|
||||
log.exception(f'Error loading {path}: {e}')
|
||||
|
||||
async def alazy_load(self) -> AsyncIterator[Document]:
|
||||
"""Async lazy load text from the url(s) in web_path."""
|
||||
results = await self.ascrape_all(self.web_paths)
|
||||
for path, soup in zip(self.web_paths, results):
|
||||
text = soup.get_text(**self.bs_get_text_kwargs)
|
||||
metadata = {"source": path}
|
||||
if title := soup.find("title"):
|
||||
metadata["title"] = title.get_text()
|
||||
if description := soup.find("meta", attrs={"name": "description"}):
|
||||
metadata["description"] = description.get(
|
||||
"content", "No description found."
|
||||
)
|
||||
if html := soup.find("html"):
|
||||
metadata["language"] = html.get("lang", "No language found.")
|
||||
metadata = {'source': path}
|
||||
if title := soup.find('title'):
|
||||
metadata['title'] = title.get_text()
|
||||
if description := soup.find('meta', attrs={'name': 'description'}):
|
||||
metadata['description'] = description.get('content', 'No description found.')
|
||||
if html := soup.find('html'):
|
||||
metadata['language'] = html.get('lang', 'No language found.')
|
||||
yield Document(page_content=text, metadata=metadata)
|
||||
|
||||
async def aload(self) -> list[Document]:
|
||||
@@ -663,18 +644,18 @@ def get_web_loader(
|
||||
safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
|
||||
|
||||
if not safe_urls:
|
||||
log.warning(f"All provided URLs were blocked or invalid: {urls}")
|
||||
log.warning(f'All provided URLs were blocked or invalid: {urls}')
|
||||
raise ValueError(ERROR_MESSAGES.INVALID_URL)
|
||||
|
||||
web_loader_args = {
|
||||
"web_paths": safe_urls,
|
||||
"verify_ssl": verify_ssl,
|
||||
"requests_per_second": requests_per_second,
|
||||
"continue_on_failure": True,
|
||||
"trust_env": trust_env,
|
||||
'web_paths': safe_urls,
|
||||
'verify_ssl': verify_ssl,
|
||||
'requests_per_second': requests_per_second,
|
||||
'continue_on_failure': True,
|
||||
'trust_env': trust_env,
|
||||
}
|
||||
|
||||
if WEB_LOADER_ENGINE.value == "" or WEB_LOADER_ENGINE.value == "safe_web":
|
||||
if WEB_LOADER_ENGINE.value == '' or WEB_LOADER_ENGINE.value == 'safe_web':
|
||||
WebLoaderClass = SafeWebBaseLoader
|
||||
|
||||
request_kwargs = {}
|
||||
@@ -685,42 +666,42 @@ def get_web_loader(
|
||||
timeout_value = None
|
||||
|
||||
if timeout_value:
|
||||
request_kwargs["timeout"] = timeout_value
|
||||
request_kwargs['timeout'] = timeout_value
|
||||
|
||||
if request_kwargs:
|
||||
web_loader_args["requests_kwargs"] = request_kwargs
|
||||
web_loader_args['requests_kwargs'] = request_kwargs
|
||||
|
||||
if WEB_LOADER_ENGINE.value == "playwright":
|
||||
if WEB_LOADER_ENGINE.value == 'playwright':
|
||||
WebLoaderClass = SafePlaywrightURLLoader
|
||||
web_loader_args["playwright_timeout"] = PLAYWRIGHT_TIMEOUT.value
|
||||
web_loader_args['playwright_timeout'] = PLAYWRIGHT_TIMEOUT.value
|
||||
if PLAYWRIGHT_WS_URL.value:
|
||||
web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URL.value
|
||||
web_loader_args['playwright_ws_url'] = PLAYWRIGHT_WS_URL.value
|
||||
|
||||
if WEB_LOADER_ENGINE.value == "firecrawl":
|
||||
if WEB_LOADER_ENGINE.value == 'firecrawl':
|
||||
WebLoaderClass = SafeFireCrawlLoader
|
||||
web_loader_args["api_key"] = FIRECRAWL_API_KEY.value
|
||||
web_loader_args["api_url"] = FIRECRAWL_API_BASE_URL.value
|
||||
web_loader_args['api_key'] = FIRECRAWL_API_KEY.value
|
||||
web_loader_args['api_url'] = FIRECRAWL_API_BASE_URL.value
|
||||
if FIRECRAWL_TIMEOUT.value:
|
||||
try:
|
||||
web_loader_args["timeout"] = int(FIRECRAWL_TIMEOUT.value)
|
||||
web_loader_args['timeout'] = int(FIRECRAWL_TIMEOUT.value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if WEB_LOADER_ENGINE.value == "tavily":
|
||||
if WEB_LOADER_ENGINE.value == 'tavily':
|
||||
WebLoaderClass = SafeTavilyLoader
|
||||
web_loader_args["api_key"] = TAVILY_API_KEY.value
|
||||
web_loader_args["extract_depth"] = TAVILY_EXTRACT_DEPTH.value
|
||||
web_loader_args['api_key'] = TAVILY_API_KEY.value
|
||||
web_loader_args['extract_depth'] = TAVILY_EXTRACT_DEPTH.value
|
||||
|
||||
if WEB_LOADER_ENGINE.value == "external":
|
||||
if WEB_LOADER_ENGINE.value == 'external':
|
||||
WebLoaderClass = ExternalWebLoader
|
||||
web_loader_args["external_url"] = EXTERNAL_WEB_LOADER_URL.value
|
||||
web_loader_args["external_api_key"] = EXTERNAL_WEB_LOADER_API_KEY.value
|
||||
web_loader_args['external_url'] = EXTERNAL_WEB_LOADER_URL.value
|
||||
web_loader_args['external_api_key'] = EXTERNAL_WEB_LOADER_API_KEY.value
|
||||
|
||||
if WebLoaderClass:
|
||||
web_loader = WebLoaderClass(**web_loader_args)
|
||||
|
||||
log.debug(
|
||||
"Using WEB_LOADER_ENGINE %s for %s URLs",
|
||||
'Using WEB_LOADER_ENGINE %s for %s URLs',
|
||||
web_loader.__class__.__name__,
|
||||
len(safe_urls),
|
||||
)
|
||||
@@ -728,6 +709,6 @@ def get_web_loader(
|
||||
return web_loader
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid WEB_LOADER_ENGINE: {WEB_LOADER_ENGINE.value}. "
|
||||
f'Invalid WEB_LOADER_ENGINE: {WEB_LOADER_ENGINE.value}. '
|
||||
"Please set it to 'safe_web', 'playwright', 'firecrawl', or 'tavily'."
|
||||
)
|
||||
|
||||
@@ -41,29 +41,29 @@ def search_yacy(
|
||||
yacy_auth = HTTPDigestAuth(username, password)
|
||||
|
||||
params = {
|
||||
"query": query,
|
||||
"contentdom": "text",
|
||||
"resource": "global",
|
||||
"maximumRecords": count,
|
||||
"nav": "none",
|
||||
'query': query,
|
||||
'contentdom': 'text',
|
||||
'resource': 'global',
|
||||
'maximumRecords': count,
|
||||
'nav': 'none',
|
||||
}
|
||||
|
||||
# Check if provided a json API URL
|
||||
if not query_url.endswith("yacysearch.json"):
|
||||
if not query_url.endswith('yacysearch.json'):
|
||||
# Strip all query parameters from the URL
|
||||
query_url = query_url.rstrip("/") + "/yacysearch.json"
|
||||
query_url = query_url.rstrip('/') + '/yacysearch.json'
|
||||
|
||||
log.debug(f"searching {query_url}")
|
||||
log.debug(f'searching {query_url}')
|
||||
|
||||
response = requests.get(
|
||||
query_url,
|
||||
auth=yacy_auth,
|
||||
headers={
|
||||
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
|
||||
"Accept": "text/html",
|
||||
"Accept-Encoding": "gzip, deflate",
|
||||
"Accept-Language": "en-US,en;q=0.5",
|
||||
"Connection": "keep-alive",
|
||||
'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) RAG Bot',
|
||||
'Accept': 'text/html',
|
||||
'Accept-Encoding': 'gzip, deflate',
|
||||
'Accept-Language': 'en-US,en;q=0.5',
|
||||
'Connection': 'keep-alive',
|
||||
},
|
||||
params=params,
|
||||
)
|
||||
@@ -71,15 +71,15 @@ def search_yacy(
|
||||
response.raise_for_status() # Raise an exception for HTTP errors.
|
||||
|
||||
json_response = response.json()
|
||||
results = json_response.get("channels", [{}])[0].get("items", [])
|
||||
sorted_results = sorted(results, key=lambda x: x.get("ranking", 0), reverse=True)
|
||||
results = json_response.get('channels', [{}])[0].get('items', [])
|
||||
sorted_results = sorted(results, key=lambda x: x.get('ranking', 0), reverse=True)
|
||||
if filter_list:
|
||||
sorted_results = get_filtered_results(sorted_results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"],
|
||||
title=result.get("title"),
|
||||
snippet=result.get("description"),
|
||||
link=result['link'],
|
||||
title=result.get('title'),
|
||||
snippet=result.get('description'),
|
||||
)
|
||||
for result in sorted_results[:count]
|
||||
]
|
||||
|
||||
@@ -20,14 +20,14 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def xml_element_contents_to_string(element: Element) -> str:
|
||||
buffer = [element.text if element.text else ""]
|
||||
buffer = [element.text if element.text else '']
|
||||
|
||||
for child in element:
|
||||
buffer.append(xml_element_contents_to_string(child))
|
||||
|
||||
buffer.append(element.tail if element.tail else "")
|
||||
buffer.append(element.tail if element.tail else '')
|
||||
|
||||
return "".join(buffer)
|
||||
return ''.join(buffer)
|
||||
|
||||
|
||||
def search_yandex(
|
||||
@@ -42,42 +42,38 @@ def search_yandex(
|
||||
) -> List[SearchResult]:
|
||||
try:
|
||||
headers = {
|
||||
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
|
||||
"Authorization": f"Api-Key {yandex_search_api_key}",
|
||||
'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) RAG Bot',
|
||||
'Authorization': f'Api-Key {yandex_search_api_key}',
|
||||
}
|
||||
|
||||
if user is not None:
|
||||
headers = include_user_info_headers(headers, user)
|
||||
|
||||
chat_id = getattr(request.state, "chat_id", None)
|
||||
chat_id = getattr(request.state, 'chat_id', None)
|
||||
if chat_id:
|
||||
headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = str(chat_id)
|
||||
|
||||
payload = {} if yandex_search_config == "" else json.loads(yandex_search_config)
|
||||
payload = {} if yandex_search_config == '' else json.loads(yandex_search_config)
|
||||
|
||||
if type(payload.get("query", None)) != dict:
|
||||
payload["query"] = {}
|
||||
if type(payload.get('query', None)) != dict:
|
||||
payload['query'] = {}
|
||||
|
||||
if "searchType" not in payload["query"]:
|
||||
payload["query"]["searchType"] = "SEARCH_TYPE_RU"
|
||||
if 'searchType' not in payload['query']:
|
||||
payload['query']['searchType'] = 'SEARCH_TYPE_RU'
|
||||
|
||||
payload["query"]["queryText"] = query
|
||||
payload['query']['queryText'] = query
|
||||
|
||||
if type(payload.get("groupSpec", None)) != dict:
|
||||
payload["groupSpec"] = {}
|
||||
if type(payload.get('groupSpec', None)) != dict:
|
||||
payload['groupSpec'] = {}
|
||||
|
||||
if "groupMode" not in payload["groupSpec"]:
|
||||
payload["groupSpec"]["groupMode"] = "GROUP_MODE_DEEP"
|
||||
if 'groupMode' not in payload['groupSpec']:
|
||||
payload['groupSpec']['groupMode'] = 'GROUP_MODE_DEEP'
|
||||
|
||||
payload["groupSpec"]["groupsOnPage"] = count
|
||||
payload["groupSpec"]["docsInGroup"] = 1
|
||||
payload['groupSpec']['groupsOnPage'] = count
|
||||
payload['groupSpec']['docsInGroup'] = 1
|
||||
|
||||
response = requests.post(
|
||||
(
|
||||
"https://searchapi.api.cloud.yandex.net/v2/web/search"
|
||||
if yandex_search_url == ""
|
||||
else yandex_search_url
|
||||
),
|
||||
('https://searchapi.api.cloud.yandex.net/v2/web/search' if yandex_search_url == '' else yandex_search_url),
|
||||
headers=headers,
|
||||
json=payload,
|
||||
)
|
||||
@@ -85,29 +81,21 @@ def search_yandex(
|
||||
response.raise_for_status()
|
||||
|
||||
response_body = response.json()
|
||||
if "rawData" not in response_body:
|
||||
raise Exception(f"No `rawData` in response body: {response_body}")
|
||||
if 'rawData' not in response_body:
|
||||
raise Exception(f'No `rawData` in response body: {response_body}')
|
||||
|
||||
search_result_body_bytes = base64.decodebytes(
|
||||
bytes(response_body["rawData"], "utf-8")
|
||||
)
|
||||
search_result_body_bytes = base64.decodebytes(bytes(response_body['rawData'], 'utf-8'))
|
||||
|
||||
doc_root = ET.parse(io.BytesIO(search_result_body_bytes))
|
||||
|
||||
results = []
|
||||
|
||||
for group in doc_root.findall("response/results/grouping/group"):
|
||||
for group in doc_root.findall('response/results/grouping/group'):
|
||||
results.append(
|
||||
{
|
||||
"url": xml_element_contents_to_string(group.find("doc/url")).strip(
|
||||
"\n"
|
||||
),
|
||||
"title": xml_element_contents_to_string(
|
||||
group.find("doc/title")
|
||||
).strip("\n"),
|
||||
"snippet": xml_element_contents_to_string(
|
||||
group.find("doc/passages/passage")
|
||||
),
|
||||
'url': xml_element_contents_to_string(group.find('doc/url')).strip('\n'),
|
||||
'title': xml_element_contents_to_string(group.find('doc/title')).strip('\n'),
|
||||
'snippet': xml_element_contents_to_string(group.find('doc/passages/passage')),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -115,49 +103,47 @@ def search_yandex(
|
||||
|
||||
results = [
|
||||
SearchResult(
|
||||
link=result.get("url"),
|
||||
title=result.get("title"),
|
||||
snippet=result.get("snippet"),
|
||||
link=result.get('url'),
|
||||
title=result.get('title'),
|
||||
snippet=result.get('snippet'),
|
||||
)
|
||||
for result in results[:count]
|
||||
]
|
||||
|
||||
log.info(f"Yandex search results: {results}")
|
||||
log.info(f'Yandex search results: {results}')
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
log.error(f"Error in search: {e}")
|
||||
log.error(f'Error in search: {e}')
|
||||
|
||||
return []
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if __name__ == '__main__':
|
||||
from starlette.datastructures import Headers
|
||||
from fastapi import FastAPI
|
||||
|
||||
result = search_yandex(
|
||||
Request(
|
||||
{
|
||||
"type": "http",
|
||||
"asgi.version": "3.0",
|
||||
"asgi.spec_version": "2.0",
|
||||
"method": "GET",
|
||||
"path": "/internal",
|
||||
"query_string": b"",
|
||||
"headers": Headers({}).raw,
|
||||
"client": ("127.0.0.1", 12345),
|
||||
"server": ("127.0.0.1", 80),
|
||||
"scheme": "http",
|
||||
"app": FastAPI(),
|
||||
'type': 'http',
|
||||
'asgi.version': '3.0',
|
||||
'asgi.spec_version': '2.0',
|
||||
'method': 'GET',
|
||||
'path': '/internal',
|
||||
'query_string': b'',
|
||||
'headers': Headers({}).raw,
|
||||
'client': ('127.0.0.1', 12345),
|
||||
'server': ('127.0.0.1', 80),
|
||||
'scheme': 'http',
|
||||
'app': FastAPI(),
|
||||
},
|
||||
None,
|
||||
),
|
||||
os.environ.get("YANDEX_WEB_SEARCH_URL", ""),
|
||||
os.environ.get("YANDEX_WEB_SEARCH_API_KEY", ""),
|
||||
os.environ.get(
|
||||
"YANDEX_WEB_SEARCH_CONFIG", '{"query": {"searchType": "SEARCH_TYPE_COM"}}'
|
||||
),
|
||||
"TOP movies of the past year",
|
||||
os.environ.get('YANDEX_WEB_SEARCH_URL', ''),
|
||||
os.environ.get('YANDEX_WEB_SEARCH_API_KEY', ''),
|
||||
os.environ.get('YANDEX_WEB_SEARCH_CONFIG', '{"query": {"searchType": "SEARCH_TYPE_COM"}}'),
|
||||
'TOP movies of the past year',
|
||||
3,
|
||||
)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ def search_youcom(
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: Optional[List[str]] = None,
|
||||
language: str = "EN",
|
||||
language: str = 'EN',
|
||||
) -> List[SearchResult]:
|
||||
"""Search using You.com's YDC Index API and return the results as a list of SearchResult objects.
|
||||
|
||||
@@ -23,30 +23,30 @@ def search_youcom(
|
||||
filter_list (list[str], optional): Domain filter list
|
||||
language (str): Language code for search results (default: "EN")
|
||||
"""
|
||||
url = "https://ydc-index.io/v1/search"
|
||||
url = 'https://ydc-index.io/v1/search'
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"X-API-KEY": api_key,
|
||||
'Accept': 'application/json',
|
||||
'X-API-KEY': api_key,
|
||||
}
|
||||
params = {
|
||||
"query": query,
|
||||
"count": count,
|
||||
"language": language,
|
||||
'query': query,
|
||||
'count': count,
|
||||
'language': language,
|
||||
}
|
||||
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
|
||||
json_response = response.json()
|
||||
results = json_response.get("results", {}).get("web", [])
|
||||
results = json_response.get('results', {}).get('web', [])
|
||||
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["url"],
|
||||
title=result.get("title"),
|
||||
link=result['url'],
|
||||
title=result.get('title'),
|
||||
snippet=_build_snippet(result),
|
||||
)
|
||||
for result in results[:count]
|
||||
@@ -62,12 +62,12 @@ def _build_snippet(result: dict) -> str:
|
||||
"""
|
||||
parts: list[str] = []
|
||||
|
||||
description = result.get("description")
|
||||
description = result.get('description')
|
||||
if description:
|
||||
parts.append(description)
|
||||
|
||||
snippets = result.get("snippets")
|
||||
snippets = result.get('snippets')
|
||||
if snippets and isinstance(snippets, list):
|
||||
parts.extend(snippets)
|
||||
|
||||
return "\n\n".join(parts)
|
||||
return '\n\n'.join(parts)
|
||||
|
||||
Reference in New Issue
Block a user