This commit is contained in:
Timothy Jaeryang Baek
2026-03-17 17:58:01 -05:00
parent fcf7208352
commit de3317e26b
220 changed files with 17200 additions and 22836 deletions

View File

@@ -10,94 +10,88 @@ from typing_extensions import Annotated
app = typer.Typer()
KEY_FILE = Path.cwd() / ".webui_secret_key"
KEY_FILE = Path.cwd() / '.webui_secret_key'
def version_callback(value: bool):
if value:
from open_webui.env import VERSION
typer.echo(f"Open WebUI version: {VERSION}")
typer.echo(f'Open WebUI version: {VERSION}')
raise typer.Exit()
@app.command()
def main(
version: Annotated[
Optional[bool], typer.Option("--version", callback=version_callback)
] = None,
version: Annotated[Optional[bool], typer.Option('--version', callback=version_callback)] = None,
):
pass
@app.command()
def serve(
host: str = "0.0.0.0",
host: str = '0.0.0.0',
port: int = 8080,
):
os.environ["FROM_INIT_PY"] = "true"
if os.getenv("WEBUI_SECRET_KEY") is None:
typer.echo(
"Loading WEBUI_SECRET_KEY from file, not provided as an environment variable."
)
os.environ['FROM_INIT_PY'] = 'true'
if os.getenv('WEBUI_SECRET_KEY') is None:
typer.echo('Loading WEBUI_SECRET_KEY from file, not provided as an environment variable.')
if not KEY_FILE.exists():
typer.echo(f"Generating a new secret key and saving it to {KEY_FILE}")
typer.echo(f'Generating a new secret key and saving it to {KEY_FILE}')
KEY_FILE.write_bytes(base64.b64encode(random.randbytes(12)))
typer.echo(f"Loading WEBUI_SECRET_KEY from {KEY_FILE}")
os.environ["WEBUI_SECRET_KEY"] = KEY_FILE.read_text()
typer.echo(f'Loading WEBUI_SECRET_KEY from {KEY_FILE}')
os.environ['WEBUI_SECRET_KEY'] = KEY_FILE.read_text()
if os.getenv("USE_CUDA_DOCKER", "false") == "true":
typer.echo(
"CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries."
)
LD_LIBRARY_PATH = os.getenv("LD_LIBRARY_PATH", "").split(":")
os.environ["LD_LIBRARY_PATH"] = ":".join(
if os.getenv('USE_CUDA_DOCKER', 'false') == 'true':
typer.echo('CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries.')
LD_LIBRARY_PATH = os.getenv('LD_LIBRARY_PATH', '').split(':')
os.environ['LD_LIBRARY_PATH'] = ':'.join(
LD_LIBRARY_PATH
+ [
"/usr/local/lib/python3.11/site-packages/torch/lib",
"/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib",
'/usr/local/lib/python3.11/site-packages/torch/lib',
'/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib',
]
)
try:
import torch
assert torch.cuda.is_available(), "CUDA not available"
typer.echo("CUDA seems to be working")
assert torch.cuda.is_available(), 'CUDA not available'
typer.echo('CUDA seems to be working')
except Exception as e:
typer.echo(
"Error when testing CUDA but USE_CUDA_DOCKER is true. "
"Resetting USE_CUDA_DOCKER to false and removing "
f"LD_LIBRARY_PATH modifications: {e}"
'Error when testing CUDA but USE_CUDA_DOCKER is true. '
'Resetting USE_CUDA_DOCKER to false and removing '
f'LD_LIBRARY_PATH modifications: {e}'
)
os.environ["USE_CUDA_DOCKER"] = "false"
os.environ["LD_LIBRARY_PATH"] = ":".join(LD_LIBRARY_PATH)
os.environ['USE_CUDA_DOCKER'] = 'false'
os.environ['LD_LIBRARY_PATH'] = ':'.join(LD_LIBRARY_PATH)
import open_webui.main # we need set environment variables before importing main
from open_webui.env import UVICORN_WORKERS # Import the workers setting
uvicorn.run(
"open_webui.main:app",
'open_webui.main:app',
host=host,
port=port,
forwarded_allow_ips="*",
forwarded_allow_ips='*',
workers=UVICORN_WORKERS,
)
@app.command()
def dev(
host: str = "0.0.0.0",
host: str = '0.0.0.0',
port: int = 8080,
reload: bool = True,
):
uvicorn.run(
"open_webui.main:app",
'open_webui.main:app',
host=host,
port=port,
reload=reload,
forwarded_allow_ips="*",
forwarded_allow_ips='*',
)
if __name__ == "__main__":
if __name__ == '__main__':
app()

File diff suppressed because it is too large Load Diff

View File

@@ -2,125 +2,107 @@ from enum import Enum
class MESSAGES(str, Enum):
DEFAULT = lambda msg="": f"{msg if msg else ''}"
MODEL_ADDED = lambda model="": f"The model '{model}' has been added successfully."
MODEL_DELETED = (
lambda model="": f"The model '{model}' has been deleted successfully."
)
DEFAULT = lambda msg='': f'{msg if msg else ""}'
MODEL_ADDED = lambda model='': f"The model '{model}' has been added successfully."
MODEL_DELETED = lambda model='': f"The model '{model}' has been deleted successfully."
class WEBHOOK_MESSAGES(str, Enum):
DEFAULT = lambda msg="": f"{msg if msg else ''}"
USER_SIGNUP = lambda username="": (
f"New user signed up: {username}" if username else "New user signed up"
)
DEFAULT = lambda msg='': f'{msg if msg else ""}'
USER_SIGNUP = lambda username='': (f'New user signed up: {username}' if username else 'New user signed up')
class ERROR_MESSAGES(str, Enum):
def __str__(self) -> str:
return super().__str__()
DEFAULT = (
lambda err="": f'{"Something went wrong :/" if err == "" else "[ERROR: " + str(err) + "]"}'
DEFAULT = lambda err='': f'{"Something went wrong :/" if err == "" else "[ERROR: " + str(err) + "]"}'
ENV_VAR_NOT_FOUND = 'Required environment variable not found. Terminating now.'
CREATE_USER_ERROR = 'Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance.'
DELETE_USER_ERROR = 'Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot.'
EMAIL_MISMATCH = 'Uh-oh! This email does not match the email your provider is registered with. Please check your email and try again.'
EMAIL_TAKEN = 'Uh-oh! This email is already registered. Sign in with your existing account or choose another email to start anew.'
USERNAME_TAKEN = 'Uh-oh! This username is already registered. Please choose another username.'
PASSWORD_TOO_LONG = (
'Uh-oh! The password you entered is too long. Please make sure your password is less than 72 bytes long.'
)
ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now."
CREATE_USER_ERROR = "Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance."
DELETE_USER_ERROR = "Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot."
EMAIL_MISMATCH = "Uh-oh! This email does not match the email your provider is registered with. Please check your email and try again."
EMAIL_TAKEN = "Uh-oh! This email is already registered. Sign in with your existing account or choose another email to start anew."
USERNAME_TAKEN = (
"Uh-oh! This username is already registered. Please choose another username."
)
PASSWORD_TOO_LONG = "Uh-oh! The password you entered is too long. Please make sure your password is less than 72 bytes long."
COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file."
COMMAND_TAKEN = 'Uh-oh! This command is already registered. Please choose another command string.'
FILE_EXISTS = 'Uh-oh! This file is already registered. Please choose another file.'
ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string."
MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string."
NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string."
MODEL_ID_TOO_LONG = "The model id is too long. Please make sure your model id is less than 256 characters long."
ID_TAKEN = 'Uh-oh! This id is already registered. Please choose another id string.'
MODEL_ID_TAKEN = 'Uh-oh! This model id is already registered. Please choose another model id string.'
NAME_TAG_TAKEN = 'Uh-oh! This name tag is already registered. Please choose another name tag string.'
MODEL_ID_TOO_LONG = 'The model id is too long. Please make sure your model id is less than 256 characters long.'
INVALID_TOKEN = (
"Your session has expired or the token is invalid. Please sign in again."
)
INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again."
INVALID_TOKEN = 'Your session has expired or the token is invalid. Please sign in again.'
INVALID_CRED = 'The email or password provided is incorrect. Please check for typos and try logging in again.'
INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)."
INCORRECT_PASSWORD = (
"The password provided is incorrect. Please check for typos and try again."
INCORRECT_PASSWORD = 'The password provided is incorrect. Please check for typos and try again.'
INVALID_TRUSTED_HEADER = (
'Your provider has not provided a trusted header. Please contact your administrator for assistance.'
)
INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance."
EXISTING_USERS = "You can't turn off authentication because there are existing users. If you want to disable WEBUI_AUTH, make sure your web interface doesn't have any existing users and is a fresh installation."
UNAUTHORIZED = "401 Unauthorized"
ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance."
ACTION_PROHIBITED = (
"The requested action has been restricted as a security measure."
UNAUTHORIZED = '401 Unauthorized'
ACCESS_PROHIBITED = (
'You do not have permission to access this resource. Please contact your administrator for assistance.'
)
ACTION_PROHIBITED = 'The requested action has been restricted as a security measure.'
FILE_NOT_SENT = "FILE_NOT_SENT"
FILE_NOT_SENT = 'FILE_NOT_SENT'
FILE_NOT_SUPPORTED = "Oops! It seems like the file format you're trying to upload is not supported. Please upload a file with a supported format and try again."
NOT_FOUND = "We could not find what you're looking for :/"
USER_NOT_FOUND = "We could not find what you're looking for :/"
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
API_KEY_NOT_ALLOWED = "Use of API key is not enabled in the environment."
API_KEY_NOT_ALLOWED = 'Use of API key is not enabled in the environment.'
MALICIOUS = "Unusual activities detected, please try again in a few minutes."
MALICIOUS = 'Unusual activities detected, please try again in a few minutes.'
PANDOC_NOT_INSTALLED = "Pandoc is not installed on the server. Please contact your administrator for assistance."
INCORRECT_FORMAT = (
lambda err="": f"Invalid format. Please use the correct format{err}"
)
RATE_LIMIT_EXCEEDED = "API rate limit exceeded"
PANDOC_NOT_INSTALLED = 'Pandoc is not installed on the server. Please contact your administrator for assistance.'
INCORRECT_FORMAT = lambda err='': f'Invalid format. Please use the correct format{err}'
RATE_LIMIT_EXCEEDED = 'API rate limit exceeded'
MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found"
OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found"
OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
API_KEY_CREATION_NOT_ALLOWED = "API key creation is not allowed in the environment."
MODEL_NOT_FOUND = lambda name='': f"Model '{name}' was not found"
OPENAI_NOT_FOUND = lambda name='': 'OpenAI API was not found'
OLLAMA_NOT_FOUND = 'WebUI could not connect to Ollama'
CREATE_API_KEY_ERROR = 'Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance.'
API_KEY_CREATION_NOT_ALLOWED = 'API key creation is not allowed in the environment.'
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
EMPTY_CONTENT = 'The content provided is empty. Please ensure that there is text or data present before proceeding.'
DB_NOT_SQLITE = "This feature is only available when running with SQLite databases."
DB_NOT_SQLITE = 'This feature is only available when running with SQLite databases.'
INVALID_URL = (
"Oops! The URL you provided is invalid. Please double-check and try again."
)
INVALID_URL = 'Oops! The URL you provided is invalid. Please double-check and try again.'
WEB_SEARCH_ERROR = (
lambda err="": f"{err if err else 'Oops! Something went wrong while searching the web.'}"
)
WEB_SEARCH_ERROR = lambda err='': f'{err if err else "Oops! Something went wrong while searching the web."}'
OLLAMA_API_DISABLED = (
"The Ollama API is disabled. Please enable it to use this feature."
)
OLLAMA_API_DISABLED = 'The Ollama API is disabled. Please enable it to use this feature.'
FILE_TOO_LARGE = (
lambda size="": f"Oops! The file you're trying to upload is too large. Please upload a file that is less than {size}."
lambda size='': f"Oops! The file you're trying to upload is too large. Please upload a file that is less than {size}."
)
DUPLICATE_CONTENT = (
"Duplicate content detected. Please provide unique content to proceed."
DUPLICATE_CONTENT = 'Duplicate content detected. Please provide unique content to proceed.'
FILE_NOT_PROCESSED = (
'Extracted content is not available for this file. Please ensure that the file is processed before proceeding.'
)
FILE_NOT_PROCESSED = "Extracted content is not available for this file. Please ensure that the file is processed before proceeding."
INVALID_PASSWORD = lambda err="": (
err if err else "The password does not meet the required validation criteria."
)
INVALID_PASSWORD = lambda err='': (err if err else 'The password does not meet the required validation criteria.')
class TASKS(str, Enum):
def __str__(self) -> str:
return super().__str__()
DEFAULT = lambda task="": f"{task if task else 'generation'}"
TITLE_GENERATION = "title_generation"
FOLLOW_UP_GENERATION = "follow_up_generation"
TAGS_GENERATION = "tags_generation"
EMOJI_GENERATION = "emoji_generation"
QUERY_GENERATION = "query_generation"
IMAGE_PROMPT_GENERATION = "image_prompt_generation"
AUTOCOMPLETE_GENERATION = "autocomplete_generation"
FUNCTION_CALLING = "function_calling"
MOA_RESPONSE_GENERATION = "moa_response_generation"
DEFAULT = lambda task='': f'{task if task else "generation"}'
TITLE_GENERATION = 'title_generation'
FOLLOW_UP_GENERATION = 'follow_up_generation'
TAGS_GENERATION = 'tags_generation'
EMOJI_GENERATION = 'emoji_generation'
QUERY_GENERATION = 'query_generation'
IMAGE_PROMPT_GENERATION = 'image_prompt_generation'
AUTOCOMPLETE_GENERATION = 'autocomplete_generation'
FUNCTION_CALLING = 'function_calling'
MOA_RESPONSE_GENERATION = 'moa_response_generation'

File diff suppressed because it is too large Load Diff

View File

@@ -57,17 +57,15 @@ log = logging.getLogger(__name__)
def get_function_module_by_id(request: Request, pipe_id: str):
function_module, _, _ = get_function_module_from_cache(request, pipe_id)
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
if hasattr(function_module, 'valves') and hasattr(function_module, 'Valves'):
Valves = function_module.Valves
valves = Functions.get_function_valves_by_id(pipe_id)
if valves:
try:
function_module.valves = Valves(
**{k: v for k, v in valves.items() if v is not None}
)
function_module.valves = Valves(**{k: v for k, v in valves.items() if v is not None})
except Exception as e:
log.exception(f"Error loading valves for function {pipe_id}: {e}")
log.exception(f'Error loading valves for function {pipe_id}: {e}')
raise e
else:
function_module.valves = Valves()
@@ -76,7 +74,7 @@ def get_function_module_by_id(request: Request, pipe_id: str):
async def get_function_models(request):
pipes = Functions.get_functions_by_type("pipe", active_only=True)
pipes = Functions.get_functions_by_type('pipe', active_only=True)
pipe_models = []
for pipe in pipes:
@@ -84,11 +82,11 @@ async def get_function_models(request):
function_module = get_function_module_by_id(request, pipe.id)
has_user_valves = False
if hasattr(function_module, "UserValves"):
if hasattr(function_module, 'UserValves'):
has_user_valves = True
# Check if function is a manifold
if hasattr(function_module, "pipes"):
if hasattr(function_module, 'pipes'):
sub_pipes = []
# Handle pipes being a list, sync function, or async function
@@ -104,32 +102,30 @@ async def get_function_models(request):
log.exception(e)
sub_pipes = []
log.debug(
f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
)
log.debug(f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}")
for p in sub_pipes:
sub_pipe_id = f'{pipe.id}.{p["id"]}'
sub_pipe_name = p["name"]
sub_pipe_name = p['name']
if hasattr(function_module, "name"):
sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
if hasattr(function_module, 'name'):
sub_pipe_name = f'{function_module.name}{sub_pipe_name}'
pipe_flag = {"type": pipe.type}
pipe_flag = {'type': pipe.type}
pipe_models.append(
{
"id": sub_pipe_id,
"name": sub_pipe_name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
"has_user_valves": has_user_valves,
'id': sub_pipe_id,
'name': sub_pipe_name,
'object': 'model',
'created': pipe.created_at,
'owned_by': 'openai',
'pipe': pipe_flag,
'has_user_valves': has_user_valves,
}
)
else:
pipe_flag = {"type": "pipe"}
pipe_flag = {'type': 'pipe'}
log.debug(
f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
@@ -137,13 +133,13 @@ async def get_function_models(request):
pipe_models.append(
{
"id": pipe.id,
"name": pipe.name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": pipe_flag,
"has_user_valves": has_user_valves,
'id': pipe.id,
'name': pipe.name,
'object': 'model',
'created': pipe.created_at,
'owned_by': 'openai',
'pipe': pipe_flag,
'has_user_valves': has_user_valves,
}
)
except Exception as e:
@@ -153,9 +149,7 @@ async def get_function_models(request):
return pipe_models
async def generate_function_chat_completion(
request, form_data, user, models: dict = {}
):
async def generate_function_chat_completion(request, form_data, user, models: dict = {}):
async def execute_pipe(pipe, params):
if inspect.iscoroutinefunction(pipe):
return await pipe(**params)
@@ -166,32 +160,32 @@ async def generate_function_chat_completion(
if isinstance(res, str):
return res
if isinstance(res, Generator):
return "".join(map(str, res))
return ''.join(map(str, res))
if isinstance(res, AsyncGenerator):
return "".join([str(stream) async for stream in res])
return ''.join([str(stream) async for stream in res])
def process_line(form_data: dict, line):
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
line = f'data: {line}'
if isinstance(line, dict):
line = f"data: {json.dumps(line)}"
line = f'data: {json.dumps(line)}'
try:
line = line.decode("utf-8")
line = line.decode('utf-8')
except Exception:
pass
if line.startswith("data:"):
return f"{line}\n\n"
if line.startswith('data:'):
return f'{line}\n\n'
else:
line = openai_chat_chunk_message_template(form_data["model"], line)
return f"data: {json.dumps(line)}\n\n"
line = openai_chat_chunk_message_template(form_data['model'], line)
return f'data: {json.dumps(line)}\n\n'
def get_pipe_id(form_data: dict) -> str:
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, _ = pipe_id.split(".", 1)
pipe_id = form_data['model']
if '.' in pipe_id:
pipe_id, _ = pipe_id.split('.', 1)
return pipe_id
def get_function_params(function_module, form_data, user, extra_params=None):
@@ -202,27 +196,25 @@ async def generate_function_chat_completion(
# Get the signature of the function
sig = inspect.signature(function_module.pipe)
params = {"body": form_data} | {
k: v for k, v in extra_params.items() if k in sig.parameters
}
params = {'body': form_data} | {k: v for k, v in extra_params.items() if k in sig.parameters}
if "__user__" in params and hasattr(function_module, "UserValves"):
if '__user__' in params and hasattr(function_module, 'UserValves'):
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
try:
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
params['__user__']['valves'] = function_module.UserValves(**user_valves)
except Exception as e:
log.exception(e)
params["__user__"]["valves"] = function_module.UserValves()
params['__user__']['valves'] = function_module.UserValves()
return params
model_id = form_data.get("model")
model_id = form_data.get('model')
model_info = Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", {})
metadata = form_data.pop('metadata', {})
files = metadata.get("files", [])
tool_ids = metadata.get("tool_ids", [])
files = metadata.get('files', [])
tool_ids = metadata.get('tool_ids', [])
# Check if tool_ids is None
if tool_ids is None:
tool_ids = []
@@ -233,56 +225,56 @@ async def generate_function_chat_completion(
__task_body__ = None
if metadata:
if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
if all(k in metadata for k in ('session_id', 'chat_id', 'message_id')):
__event_emitter__ = get_event_emitter(metadata)
__event_call__ = get_event_call(metadata)
__task__ = metadata.get("task", None)
__task_body__ = metadata.get("task_body", None)
__task__ = metadata.get('task', None)
__task_body__ = metadata.get('task_body', None)
oauth_token = None
try:
if request.cookies.get("oauth_session_id", None):
if request.cookies.get('oauth_session_id', None):
oauth_token = await request.app.state.oauth_manager.get_oauth_token(
user.id,
request.cookies.get("oauth_session_id", None),
request.cookies.get('oauth_session_id', None),
)
except Exception as e:
log.error(f"Error getting OAuth token: {e}")
log.error(f'Error getting OAuth token: {e}')
extra_params = {
"__event_emitter__": __event_emitter__,
"__event_call__": __event_call__,
"__chat_id__": metadata.get("chat_id", None),
"__session_id__": metadata.get("session_id", None),
"__message_id__": metadata.get("message_id", None),
"__task__": __task__,
"__task_body__": __task_body__,
"__files__": files,
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
"__metadata__": metadata,
"__oauth_token__": oauth_token,
"__request__": request,
'__event_emitter__': __event_emitter__,
'__event_call__': __event_call__,
'__chat_id__': metadata.get('chat_id', None),
'__session_id__': metadata.get('session_id', None),
'__message_id__': metadata.get('message_id', None),
'__task__': __task__,
'__task_body__': __task_body__,
'__files__': files,
'__user__': user.model_dump() if isinstance(user, UserModel) else {},
'__metadata__': metadata,
'__oauth_token__': oauth_token,
'__request__': request,
}
extra_params["__tools__"] = await get_tools(
extra_params['__tools__'] = await get_tools(
request,
tool_ids,
user,
{
**extra_params,
"__model__": models.get(form_data["model"], None),
"__messages__": form_data["messages"],
"__files__": files,
'__model__': models.get(form_data['model'], None),
'__messages__': form_data['messages'],
'__files__': files,
},
)
if model_info:
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id
form_data['model'] = model_info.base_model_id
params = model_info.params.model_dump()
if params:
system = params.pop("system", None)
system = params.pop('system', None)
form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_system_prompt_to_body(system, form_data, metadata, user)
@@ -292,7 +284,7 @@ async def generate_function_chat_completion(
pipe = function_module.pipe
params = get_function_params(function_module, form_data, user, extra_params)
if form_data.get("stream", False):
if form_data.get('stream', False):
async def stream_content():
try:
@@ -304,17 +296,17 @@ async def generate_function_chat_completion(
yield data
return
if isinstance(res, dict):
yield f"data: {json.dumps(res)}\n\n"
yield f'data: {json.dumps(res)}\n\n'
return
except Exception as e:
log.error(f"Error: {e}")
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
log.error(f'Error: {e}')
yield f'data: {json.dumps({"error": {"detail": str(e)}})}\n\n'
return
if isinstance(res, str):
message = openai_chat_chunk_message_template(form_data["model"], res)
yield f"data: {json.dumps(message)}\n\n"
message = openai_chat_chunk_message_template(form_data['model'], res)
yield f'data: {json.dumps(message)}\n\n'
if isinstance(res, Iterator):
for line in res:
@@ -325,21 +317,19 @@ async def generate_function_chat_completion(
yield process_line(form_data, line)
if isinstance(res, str) or isinstance(res, Generator):
finish_message = openai_chat_chunk_message_template(
form_data["model"], ""
)
finish_message["choices"][0]["finish_reason"] = "stop"
yield f"data: {json.dumps(finish_message)}\n\n"
yield "data: [DONE]"
finish_message = openai_chat_chunk_message_template(form_data['model'], '')
finish_message['choices'][0]['finish_reason'] = 'stop'
yield f'data: {json.dumps(finish_message)}\n\n'
yield 'data: [DONE]'
return StreamingResponse(stream_content(), media_type="text/event-stream")
return StreamingResponse(stream_content(), media_type='text/event-stream')
else:
try:
res = await execute_pipe(pipe, params)
except Exception as e:
log.error(f"Error: {e}")
return {"error": {"detail": str(e)}}
log.error(f'Error: {e}')
return {'error': {'detail': str(e)}}
if isinstance(res, StreamingResponse) or isinstance(res, dict):
return res
@@ -347,4 +337,4 @@ async def generate_function_chat_completion(
return res.model_dump()
message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message)
return openai_chat_completion_message_template(form_data['model'], message)

View File

@@ -56,17 +56,15 @@ def handle_peewee_migration(DATABASE_URL):
# db = None
try:
# Replace the postgresql:// with postgres:// to handle the peewee migration
db = register_connection(DATABASE_URL.replace("postgresql://", "postgres://"))
migrate_dir = OPEN_WEBUI_DIR / "internal" / "migrations"
db = register_connection(DATABASE_URL.replace('postgresql://', 'postgres://'))
migrate_dir = OPEN_WEBUI_DIR / 'internal' / 'migrations'
router = Router(db, logger=log, migrate_dir=migrate_dir)
router.run()
db.close()
except Exception as e:
log.error(f"Failed to initialize the database connection: {e}")
log.warning(
"Hint: If your database password contains special characters, you may need to URL-encode it."
)
log.error(f'Failed to initialize the database connection: {e}')
log.warning('Hint: If your database password contains special characters, you may need to URL-encode it.')
raise
finally:
# Properly closing the database connection
@@ -74,7 +72,7 @@ def handle_peewee_migration(DATABASE_URL):
db.close()
# Assert if db connection has been closed
assert db.is_closed(), "Database connection is still open."
assert db.is_closed(), 'Database connection is still open.'
if ENABLE_DB_MIGRATIONS:
@@ -84,15 +82,13 @@ if ENABLE_DB_MIGRATIONS:
SQLALCHEMY_DATABASE_URL = DATABASE_URL
# Handle SQLCipher URLs
if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
database_password = os.environ.get("DATABASE_PASSWORD")
if not database_password or database_password.strip() == "":
raise ValueError(
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
)
if SQLALCHEMY_DATABASE_URL.startswith('sqlite+sqlcipher://'):
database_password = os.environ.get('DATABASE_PASSWORD')
if not database_password or database_password.strip() == '':
raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs')
# Extract database path from SQLCipher URL
db_path = SQLALCHEMY_DATABASE_URL.replace("sqlite+sqlcipher://", "")
db_path = SQLALCHEMY_DATABASE_URL.replace('sqlite+sqlcipher://', '')
# Create a custom creator function that uses sqlcipher3
def create_sqlcipher_connection():
@@ -109,7 +105,7 @@ if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
# or QueuePool if DATABASE_POOL_SIZE is explicitly configured.
if isinstance(DATABASE_POOL_SIZE, int) and DATABASE_POOL_SIZE > 0:
engine = create_engine(
"sqlite://",
'sqlite://',
creator=create_sqlcipher_connection,
pool_size=DATABASE_POOL_SIZE,
max_overflow=DATABASE_POOL_MAX_OVERFLOW,
@@ -121,28 +117,26 @@ if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
)
else:
engine = create_engine(
"sqlite://",
'sqlite://',
creator=create_sqlcipher_connection,
poolclass=NullPool,
echo=False,
)
log.info("Connected to encrypted SQLite database using SQLCipher")
log.info('Connected to encrypted SQLite database using SQLCipher')
elif "sqlite" in SQLALCHEMY_DATABASE_URL:
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
elif 'sqlite' in SQLALCHEMY_DATABASE_URL:
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={'check_same_thread': False})
def on_connect(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
if DATABASE_ENABLE_SQLITE_WAL:
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute('PRAGMA journal_mode=WAL')
else:
cursor.execute("PRAGMA journal_mode=DELETE")
cursor.execute('PRAGMA journal_mode=DELETE')
cursor.close()
event.listen(engine, "connect", on_connect)
event.listen(engine, 'connect', on_connect)
else:
if isinstance(DATABASE_POOL_SIZE, int):
if DATABASE_POOL_SIZE > 0:
@@ -156,16 +150,12 @@ else:
poolclass=QueuePool,
)
else:
engine = create_engine(
SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
)
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool)
else:
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
metadata_obj = MetaData(schema=DATABASE_SCHEMA)
Base = declarative_base(metadata=metadata_obj)
ScopedSession = scoped_session(SessionLocal)

View File

@@ -56,7 +56,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
active = pw.BooleanField()
class Meta:
table_name = "auth"
table_name = 'auth'
@migrator.create_model
class Chat(pw.Model):
@@ -67,7 +67,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField()
class Meta:
table_name = "chat"
table_name = 'chat'
@migrator.create_model
class ChatIdTag(pw.Model):
@@ -78,7 +78,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField()
class Meta:
table_name = "chatidtag"
table_name = 'chatidtag'
@migrator.create_model
class Document(pw.Model):
@@ -92,7 +92,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField()
class Meta:
table_name = "document"
table_name = 'document'
@migrator.create_model
class Modelfile(pw.Model):
@@ -103,7 +103,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField()
class Meta:
table_name = "modelfile"
table_name = 'modelfile'
@migrator.create_model
class Prompt(pw.Model):
@@ -115,7 +115,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField()
class Meta:
table_name = "prompt"
table_name = 'prompt'
@migrator.create_model
class Tag(pw.Model):
@@ -125,7 +125,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
data = pw.TextField(null=True)
class Meta:
table_name = "tag"
table_name = 'tag'
@migrator.create_model
class User(pw.Model):
@@ -137,7 +137,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField()
class Meta:
table_name = "user"
table_name = 'user'
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
@@ -149,7 +149,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
active = pw.BooleanField()
class Meta:
table_name = "auth"
table_name = 'auth'
@migrator.create_model
class Chat(pw.Model):
@@ -160,7 +160,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField()
class Meta:
table_name = "chat"
table_name = 'chat'
@migrator.create_model
class ChatIdTag(pw.Model):
@@ -171,7 +171,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField()
class Meta:
table_name = "chatidtag"
table_name = 'chatidtag'
@migrator.create_model
class Document(pw.Model):
@@ -185,7 +185,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField()
class Meta:
table_name = "document"
table_name = 'document'
@migrator.create_model
class Modelfile(pw.Model):
@@ -196,7 +196,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField()
class Meta:
table_name = "modelfile"
table_name = 'modelfile'
@migrator.create_model
class Prompt(pw.Model):
@@ -208,7 +208,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField()
class Meta:
table_name = "prompt"
table_name = 'prompt'
@migrator.create_model
class Tag(pw.Model):
@@ -218,7 +218,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
data = pw.TextField(null=True)
class Meta:
table_name = "tag"
table_name = 'tag'
@migrator.create_model
class User(pw.Model):
@@ -230,24 +230,24 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField()
class Meta:
table_name = "user"
table_name = 'user'
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("user")
migrator.remove_model('user')
migrator.remove_model("tag")
migrator.remove_model('tag')
migrator.remove_model("prompt")
migrator.remove_model('prompt')
migrator.remove_model("modelfile")
migrator.remove_model('modelfile')
migrator.remove_model("document")
migrator.remove_model('document')
migrator.remove_model("chatidtag")
migrator.remove_model('chatidtag')
migrator.remove_model("chat")
migrator.remove_model('chat')
migrator.remove_model("auth")
migrator.remove_model('auth')

View File

@@ -36,12 +36,10 @@ with suppress(ImportError):
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields(
"chat", share_id=pw.CharField(max_length=255, null=True, unique=True)
)
migrator.add_fields('chat', share_id=pw.CharField(max_length=255, null=True, unique=True))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("chat", "share_id")
migrator.remove_fields('chat', 'share_id')

View File

@@ -36,12 +36,10 @@ with suppress(ImportError):
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields(
"user", api_key=pw.CharField(max_length=255, null=True, unique=True)
)
migrator.add_fields('user', api_key=pw.CharField(max_length=255, null=True, unique=True))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("user", "api_key")
migrator.remove_fields('user', 'api_key')

View File

@@ -36,10 +36,10 @@ with suppress(ImportError):
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields("chat", archived=pw.BooleanField(default=False))
migrator.add_fields('chat', archived=pw.BooleanField(default=False))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("chat", "archived")
migrator.remove_fields('chat', 'archived')

View File

@@ -45,22 +45,20 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
# Adding fields created_at and updated_at to the 'chat' table
migrator.add_fields(
"chat",
'chat',
created_at=pw.DateTimeField(null=True), # Allow null for transition
updated_at=pw.DateTimeField(null=True), # Allow null for transition
)
# Populate the new fields from an existing 'timestamp' field
migrator.sql(
"UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
)
migrator.sql('UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL')
# Now that the data has been copied, remove the original 'timestamp' field
migrator.remove_fields("chat", "timestamp")
migrator.remove_fields('chat', 'timestamp')
# Update the fields to be not null now that they are populated
migrator.change_fields(
"chat",
'chat',
created_at=pw.DateTimeField(null=False),
updated_at=pw.DateTimeField(null=False),
)
@@ -69,22 +67,20 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
# Adding fields created_at and updated_at to the 'chat' table
migrator.add_fields(
"chat",
'chat',
created_at=pw.BigIntegerField(null=True), # Allow null for transition
updated_at=pw.BigIntegerField(null=True), # Allow null for transition
)
# Populate the new fields from an existing 'timestamp' field
migrator.sql(
"UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
)
migrator.sql('UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL')
# Now that the data has been copied, remove the original 'timestamp' field
migrator.remove_fields("chat", "timestamp")
migrator.remove_fields('chat', 'timestamp')
# Update the fields to be not null now that they are populated
migrator.change_fields(
"chat",
'chat',
created_at=pw.BigIntegerField(null=False),
updated_at=pw.BigIntegerField(null=False),
)
@@ -101,29 +97,29 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
def rollback_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
# Recreate the timestamp field initially allowing null values for safe transition
migrator.add_fields("chat", timestamp=pw.DateTimeField(null=True))
migrator.add_fields('chat', timestamp=pw.DateTimeField(null=True))
# Copy the earliest created_at date back into the new timestamp field
# This assumes created_at was originally a copy of timestamp
migrator.sql("UPDATE chat SET timestamp = created_at")
migrator.sql('UPDATE chat SET timestamp = created_at')
# Remove the created_at and updated_at fields
migrator.remove_fields("chat", "created_at", "updated_at")
migrator.remove_fields('chat', 'created_at', 'updated_at')
# Finally, alter the timestamp field to not allow nulls if that was the original setting
migrator.change_fields("chat", timestamp=pw.DateTimeField(null=False))
migrator.change_fields('chat', timestamp=pw.DateTimeField(null=False))
def rollback_external(migrator: Migrator, database: pw.Database, *, fake=False):
# Recreate the timestamp field initially allowing null values for safe transition
migrator.add_fields("chat", timestamp=pw.BigIntegerField(null=True))
migrator.add_fields('chat', timestamp=pw.BigIntegerField(null=True))
# Copy the earliest created_at date back into the new timestamp field
# This assumes created_at was originally a copy of timestamp
migrator.sql("UPDATE chat SET timestamp = created_at")
migrator.sql('UPDATE chat SET timestamp = created_at')
# Remove the created_at and updated_at fields
migrator.remove_fields("chat", "created_at", "updated_at")
migrator.remove_fields('chat', 'created_at', 'updated_at')
# Finally, alter the timestamp field to not allow nulls if that was the original setting
migrator.change_fields("chat", timestamp=pw.BigIntegerField(null=False))
migrator.change_fields('chat', timestamp=pw.BigIntegerField(null=False))

View File

@@ -38,45 +38,45 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
# Alter the tables with timestamps
migrator.change_fields(
"chatidtag",
'chatidtag',
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"document",
'document',
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"modelfile",
'modelfile',
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"prompt",
'prompt',
timestamp=pw.BigIntegerField(),
)
migrator.change_fields(
"user",
'user',
timestamp=pw.BigIntegerField(),
)
# Alter the tables with varchar to text where necessary
migrator.change_fields(
"auth",
'auth',
password=pw.TextField(),
)
migrator.change_fields(
"chat",
'chat',
title=pw.TextField(),
)
migrator.change_fields(
"document",
'document',
title=pw.TextField(),
filename=pw.TextField(),
)
migrator.change_fields(
"prompt",
'prompt',
title=pw.TextField(),
)
migrator.change_fields(
"user",
'user',
profile_image_url=pw.TextField(),
)
@@ -87,43 +87,43 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
if isinstance(database, pw.SqliteDatabase):
# Alter the tables with timestamps
migrator.change_fields(
"chatidtag",
'chatidtag',
timestamp=pw.DateField(),
)
migrator.change_fields(
"document",
'document',
timestamp=pw.DateField(),
)
migrator.change_fields(
"modelfile",
'modelfile',
timestamp=pw.DateField(),
)
migrator.change_fields(
"prompt",
'prompt',
timestamp=pw.DateField(),
)
migrator.change_fields(
"user",
'user',
timestamp=pw.DateField(),
)
migrator.change_fields(
"auth",
'auth',
password=pw.CharField(max_length=255),
)
migrator.change_fields(
"chat",
'chat',
title=pw.CharField(),
)
migrator.change_fields(
"document",
'document',
title=pw.CharField(),
filename=pw.CharField(),
)
migrator.change_fields(
"prompt",
'prompt',
title=pw.CharField(),
)
migrator.change_fields(
"user",
'user',
profile_image_url=pw.CharField(),
)

View File

@@ -38,7 +38,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
# Adding fields created_at and updated_at to the 'user' table
migrator.add_fields(
"user",
'user',
created_at=pw.BigIntegerField(null=True), # Allow null for transition
updated_at=pw.BigIntegerField(null=True), # Allow null for transition
last_active_at=pw.BigIntegerField(null=True), # Allow null for transition
@@ -50,11 +50,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
)
# Now that the data has been copied, remove the original 'timestamp' field
migrator.remove_fields("user", "timestamp")
migrator.remove_fields('user', 'timestamp')
# Update the fields to be not null now that they are populated
migrator.change_fields(
"user",
'user',
created_at=pw.BigIntegerField(null=False),
updated_at=pw.BigIntegerField(null=False),
last_active_at=pw.BigIntegerField(null=False),
@@ -65,14 +65,14 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
# Recreate the timestamp field initially allowing null values for safe transition
migrator.add_fields("user", timestamp=pw.BigIntegerField(null=True))
migrator.add_fields('user', timestamp=pw.BigIntegerField(null=True))
# Copy the earliest created_at date back into the new timestamp field
# This assumes created_at was originally a copy of timestamp
migrator.sql('UPDATE "user" SET timestamp = created_at')
# Remove the created_at and updated_at fields
migrator.remove_fields("user", "created_at", "updated_at", "last_active_at")
migrator.remove_fields('user', 'created_at', 'updated_at', 'last_active_at')
# Finally, alter the timestamp field to not allow nulls if that was the original setting
migrator.change_fields("user", timestamp=pw.BigIntegerField(null=False))
migrator.change_fields('user', timestamp=pw.BigIntegerField(null=False))

View File

@@ -43,10 +43,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
created_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "memory"
table_name = 'memory'
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("memory")
migrator.remove_model('memory')

View File

@@ -51,10 +51,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "model"
table_name = 'model'
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("model")
migrator.remove_model('model')

View File

@@ -42,12 +42,12 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
# Fetch data from 'modelfile' table and insert into 'model' table
migrate_modelfile_to_model(migrator, database)
# Drop the 'modelfile' table
migrator.remove_model("modelfile")
migrator.remove_model('modelfile')
def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database):
ModelFile = migrator.orm["modelfile"]
Model = migrator.orm["model"]
ModelFile = migrator.orm['modelfile']
Model = migrator.orm['model']
modelfiles = ModelFile.select()
@@ -57,25 +57,25 @@ def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database):
modelfile.modelfile = json.loads(modelfile.modelfile)
meta = json.dumps(
{
"description": modelfile.modelfile.get("desc"),
"profile_image_url": modelfile.modelfile.get("imageUrl"),
"ollama": {"modelfile": modelfile.modelfile.get("content")},
"suggestion_prompts": modelfile.modelfile.get("suggestionPrompts"),
"categories": modelfile.modelfile.get("categories"),
"user": {**modelfile.modelfile.get("user", {}), "community": True},
'description': modelfile.modelfile.get('desc'),
'profile_image_url': modelfile.modelfile.get('imageUrl'),
'ollama': {'modelfile': modelfile.modelfile.get('content')},
'suggestion_prompts': modelfile.modelfile.get('suggestionPrompts'),
'categories': modelfile.modelfile.get('categories'),
'user': {**modelfile.modelfile.get('user', {}), 'community': True},
}
)
info = parse_ollama_modelfile(modelfile.modelfile.get("content"))
info = parse_ollama_modelfile(modelfile.modelfile.get('content'))
# Insert the processed data into the 'model' table
Model.create(
id=f"ollama-{modelfile.tag_name}",
id=f'ollama-{modelfile.tag_name}',
user_id=modelfile.user_id,
base_model_id=info.get("base_model_id"),
name=modelfile.modelfile.get("title"),
base_model_id=info.get('base_model_id'),
name=modelfile.modelfile.get('title'),
meta=meta,
params=json.dumps(info.get("params", {})),
params=json.dumps(info.get('params', {})),
created_at=modelfile.timestamp,
updated_at=modelfile.timestamp,
)
@@ -86,7 +86,7 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
recreate_modelfile_table(migrator, database)
move_data_back_to_modelfile(migrator, database)
migrator.remove_model("model")
migrator.remove_model('model')
def recreate_modelfile_table(migrator: Migrator, database: pw.Database):
@@ -102,8 +102,8 @@ def recreate_modelfile_table(migrator: Migrator, database: pw.Database):
def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database):
Model = migrator.orm["model"]
Modelfile = migrator.orm["modelfile"]
Model = migrator.orm['model']
Modelfile = migrator.orm['modelfile']
models = Model.select()
@@ -112,13 +112,13 @@ def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database):
meta = json.loads(model.meta)
modelfile_data = {
"title": model.name,
"desc": meta.get("description"),
"imageUrl": meta.get("profile_image_url"),
"content": meta.get("ollama", {}).get("modelfile"),
"suggestionPrompts": meta.get("suggestion_prompts"),
"categories": meta.get("categories"),
"user": {k: v for k, v in meta.get("user", {}).items() if k != "community"},
'title': model.name,
'desc': meta.get('description'),
'imageUrl': meta.get('profile_image_url'),
'content': meta.get('ollama', {}).get('modelfile'),
'suggestionPrompts': meta.get('suggestion_prompts'),
'categories': meta.get('categories'),
'user': {k: v for k, v in meta.get('user', {}).items() if k != 'community'},
}
# Insert the processed data back into the 'modelfile' table

View File

@@ -37,11 +37,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Adding fields settings to the 'user' table
migrator.add_fields("user", settings=pw.TextField(null=True))
migrator.add_fields('user', settings=pw.TextField(null=True))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
# Remove the settings field
migrator.remove_fields("user", "settings")
migrator.remove_fields('user', 'settings')

View File

@@ -51,10 +51,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "tool"
table_name = 'tool'
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("tool")
migrator.remove_model('tool')

View File

@@ -37,11 +37,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Adding fields info to the 'user' table
migrator.add_fields("user", info=pw.TextField(null=True))
migrator.add_fields('user', info=pw.TextField(null=True))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
# Remove the settings field
migrator.remove_fields("user", "info")
migrator.remove_fields('user', 'info')

View File

@@ -45,10 +45,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
created_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "file"
table_name = 'file'
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("file")
migrator.remove_model('file')

View File

@@ -51,10 +51,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "function"
table_name = 'function'
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("function")
migrator.remove_model('function')

View File

@@ -36,14 +36,14 @@ with suppress(ImportError):
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields("tool", valves=pw.TextField(null=True))
migrator.add_fields("function", valves=pw.TextField(null=True))
migrator.add_fields("function", is_active=pw.BooleanField(default=False))
migrator.add_fields('tool', valves=pw.TextField(null=True))
migrator.add_fields('function', valves=pw.TextField(null=True))
migrator.add_fields('function', is_active=pw.BooleanField(default=False))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("tool", "valves")
migrator.remove_fields("function", "valves")
migrator.remove_fields("function", "is_active")
migrator.remove_fields('tool', 'valves')
migrator.remove_fields('function', 'valves')
migrator.remove_fields('function', 'is_active')

View File

@@ -33,7 +33,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields(
"user",
'user',
oauth_sub=pw.TextField(null=True, unique=True),
)
@@ -41,4 +41,4 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("user", "oauth_sub")
migrator.remove_fields('user', 'oauth_sub')

View File

@@ -37,7 +37,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields(
"function",
'function',
is_global=pw.BooleanField(default=False),
)
@@ -45,4 +45,4 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("function", "is_global")
migrator.remove_fields('function', 'is_global')

View File

@@ -10,13 +10,13 @@ from playhouse.shortcuts import ReconnectMixin
log = logging.getLogger(__name__)
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
db_state = ContextVar("db_state", default=db_state_default.copy())
db_state_default = {'closed': None, 'conn': None, 'ctx': None, 'transactions': None}
db_state = ContextVar('db_state', default=db_state_default.copy())
class PeeweeConnectionState(object):
def __init__(self, **kwargs):
super().__setattr__("_state", db_state)
super().__setattr__('_state', db_state)
super().__init__(**kwargs)
def __setattr__(self, name, value):
@@ -30,10 +30,10 @@ class PeeweeConnectionState(object):
class CustomReconnectMixin(ReconnectMixin):
reconnect_errors = (
# psycopg2
(OperationalError, "termin"),
(InterfaceError, "closed"),
(OperationalError, 'termin'),
(InterfaceError, 'closed'),
# peewee
(PeeWeeInterfaceError, "closed"),
(PeeWeeInterfaceError, 'closed'),
)
@@ -43,23 +43,21 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
def register_connection(db_url):
# Check if using SQLCipher protocol
if db_url.startswith("sqlite+sqlcipher://"):
database_password = os.environ.get("DATABASE_PASSWORD")
if not database_password or database_password.strip() == "":
raise ValueError(
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
)
if db_url.startswith('sqlite+sqlcipher://'):
database_password = os.environ.get('DATABASE_PASSWORD')
if not database_password or database_password.strip() == '':
raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs')
from playhouse.sqlcipher_ext import SqlCipherDatabase
# Parse the database path from SQLCipher URL
# Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite
db_path = db_url.replace("sqlite+sqlcipher://", "")
db_path = db_url.replace('sqlite+sqlcipher://', '')
# Use Peewee's native SqlCipherDatabase with encryption
db = SqlCipherDatabase(db_path, passphrase=database_password)
db.autoconnect = True
db.reuse_if_open = True
log.info("Connected to encrypted SQLite database using SQLCipher")
log.info('Connected to encrypted SQLite database using SQLCipher')
else:
# Standard database connection (existing logic)
@@ -68,7 +66,7 @@ def register_connection(db_url):
# Enable autoconnect for SQLite databases, managed by Peewee
db.autoconnect = True
db.reuse_if_open = True
log.info("Connected to PostgreSQL database")
log.info('Connected to PostgreSQL database')
# Get the connection details
connection = parse(db_url, unquote_user=True, unquote_password=True)
@@ -80,7 +78,7 @@ def register_connection(db_url):
# Enable autoconnect for SQLite databases, managed by Peewee
db.autoconnect = True
db.reuse_if_open = True
log.info("Connected to SQLite database")
log.info('Connected to SQLite database')
else:
raise ValueError("Unsupported database connection")
raise ValueError('Unsupported database connection')
return db

File diff suppressed because it is too large Load Diff

View File

@@ -16,7 +16,7 @@ if config.config_file_name is not None:
fileConfig(config.config_file_name, disable_existing_loggers=False)
# Re-apply JSON formatter after fileConfig replaces handlers.
if LOG_FORMAT == "json":
if LOG_FORMAT == 'json':
from open_webui.env import JSONFormatter
for handler in logging.root.handlers:
@@ -36,7 +36,7 @@ target_metadata = Auth.metadata
DB_URL = DATABASE_URL
if DB_URL:
config.set_main_option("sqlalchemy.url", DB_URL.replace("%", "%%"))
config.set_main_option('sqlalchemy.url', DB_URL.replace('%', '%%'))
def run_migrations_offline() -> None:
@@ -51,12 +51,12 @@ def run_migrations_offline() -> None:
script output.
"""
url = config.get_main_option("sqlalchemy.url")
url = config.get_main_option('sqlalchemy.url')
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
dialect_opts={'paramstyle': 'named'},
)
with context.begin_transaction():
@@ -71,15 +71,13 @@ def run_migrations_online() -> None:
"""
# Handle SQLCipher URLs
if DB_URL and DB_URL.startswith("sqlite+sqlcipher://"):
if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == "":
raise ValueError(
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
)
if DB_URL and DB_URL.startswith('sqlite+sqlcipher://'):
if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == '':
raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs')
# Extract database path from SQLCipher URL
db_path = DB_URL.replace("sqlite+sqlcipher://", "")
if db_path.startswith("/"):
db_path = DB_URL.replace('sqlite+sqlcipher://', '')
if db_path.startswith('/'):
db_path = db_path[1:] # Remove leading slash for relative paths
# Create a custom creator function that uses sqlcipher3
@@ -91,7 +89,7 @@ def run_migrations_online() -> None:
return conn
connectable = create_engine(
"sqlite://", # Dummy URL since we're using creator
'sqlite://', # Dummy URL since we're using creator
creator=create_sqlcipher_connection,
echo=False,
)
@@ -99,7 +97,7 @@ def run_migrations_online() -> None:
# Standard database connection (existing logic)
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
prefix='sqlalchemy.',
poolclass=pool.NullPool,
)

View File

@@ -12,4 +12,4 @@ def get_existing_tables():
def get_revision_id():
import uuid
return str(uuid.uuid4()).replace("-", "")[:12]
return str(uuid.uuid4()).replace('-', '')[:12]

View File

@@ -9,38 +9,38 @@ Create Date: 2025-08-13 03:00:00.000000
from alembic import op
import sqlalchemy as sa
revision = "018012973d35"
down_revision = "d31026856c01"
revision = '018012973d35'
down_revision = 'd31026856c01'
branch_labels = None
depends_on = None
def upgrade():
# Chat table indexes
op.create_index("folder_id_idx", "chat", ["folder_id"])
op.create_index("user_id_pinned_idx", "chat", ["user_id", "pinned"])
op.create_index("user_id_archived_idx", "chat", ["user_id", "archived"])
op.create_index("updated_at_user_id_idx", "chat", ["updated_at", "user_id"])
op.create_index("folder_id_user_id_idx", "chat", ["folder_id", "user_id"])
op.create_index('folder_id_idx', 'chat', ['folder_id'])
op.create_index('user_id_pinned_idx', 'chat', ['user_id', 'pinned'])
op.create_index('user_id_archived_idx', 'chat', ['user_id', 'archived'])
op.create_index('updated_at_user_id_idx', 'chat', ['updated_at', 'user_id'])
op.create_index('folder_id_user_id_idx', 'chat', ['folder_id', 'user_id'])
# Tag table index
op.create_index("user_id_idx", "tag", ["user_id"])
op.create_index('user_id_idx', 'tag', ['user_id'])
# Function table index
op.create_index("is_global_idx", "function", ["is_global"])
op.create_index('is_global_idx', 'function', ['is_global'])
def downgrade():
# Chat table indexes
op.drop_index("folder_id_idx", table_name="chat")
op.drop_index("user_id_pinned_idx", table_name="chat")
op.drop_index("user_id_archived_idx", table_name="chat")
op.drop_index("updated_at_user_id_idx", table_name="chat")
op.drop_index("folder_id_user_id_idx", table_name="chat")
op.drop_index('folder_id_idx', table_name='chat')
op.drop_index('user_id_pinned_idx', table_name='chat')
op.drop_index('user_id_archived_idx', table_name='chat')
op.drop_index('updated_at_user_id_idx', table_name='chat')
op.drop_index('folder_id_user_id_idx', table_name='chat')
# Tag table index
op.drop_index("user_id_idx", table_name="tag")
op.drop_index('user_id_idx', table_name='tag')
# Function table index
op.drop_index("is_global_idx", table_name="function")
op.drop_index('is_global_idx', table_name='function')

View File

@@ -13,8 +13,8 @@ from sqlalchemy.engine.reflection import Inspector
import json
revision = "1af9b942657b"
down_revision = "242a2047eae0"
revision = '1af9b942657b'
down_revision = '242a2047eae0'
branch_labels = None
depends_on = None
@@ -25,43 +25,40 @@ def upgrade():
inspector = Inspector.from_engine(conn)
# Clean up potential leftover temp table from previous failures
conn.execute(sa.text("DROP TABLE IF EXISTS _alembic_tmp_tag"))
conn.execute(sa.text('DROP TABLE IF EXISTS _alembic_tmp_tag'))
# Check if the 'tag' table exists
tables = inspector.get_table_names()
# Step 1: Modify Tag table using batch mode for SQLite support
if "tag" in tables:
if 'tag' in tables:
# Get the current columns in the 'tag' table
columns = [col["name"] for col in inspector.get_columns("tag")]
columns = [col['name'] for col in inspector.get_columns('tag')]
# Get any existing unique constraints on the 'tag' table
current_constraints = inspector.get_unique_constraints("tag")
current_constraints = inspector.get_unique_constraints('tag')
with op.batch_alter_table("tag", schema=None) as batch_op:
with op.batch_alter_table('tag', schema=None) as batch_op:
# Check if the unique constraint already exists
if not any(
constraint["name"] == "uq_id_user_id"
for constraint in current_constraints
):
if not any(constraint['name'] == 'uq_id_user_id' for constraint in current_constraints):
# Create unique constraint if it doesn't exist
batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"])
batch_op.create_unique_constraint('uq_id_user_id', ['id', 'user_id'])
# Check if the 'data' column exists before trying to drop it
if "data" in columns:
batch_op.drop_column("data")
if 'data' in columns:
batch_op.drop_column('data')
# Check if the 'meta' column needs to be created
if "meta" not in columns:
if 'meta' not in columns:
# Add the 'meta' column if it doesn't already exist
batch_op.add_column(sa.Column("meta", sa.JSON(), nullable=True))
batch_op.add_column(sa.Column('meta', sa.JSON(), nullable=True))
tag = table(
"tag",
column("id", sa.String()),
column("name", sa.String()),
column("user_id", sa.String()),
column("meta", sa.JSON()),
'tag',
column('id', sa.String()),
column('name', sa.String()),
column('user_id', sa.String()),
column('meta', sa.JSON()),
)
# Step 2: Migrate tags
@@ -70,12 +67,12 @@ def upgrade():
tag_updates = {}
for row in result:
new_id = row.name.replace(" ", "_").lower()
new_id = row.name.replace(' ', '_').lower()
tag_updates[row.id] = new_id
for tag_id, new_tag_id in tag_updates.items():
print(f"Updating tag {tag_id} to {new_tag_id}")
if new_tag_id == "pinned":
print(f'Updating tag {tag_id} to {new_tag_id}')
if new_tag_id == 'pinned':
# delete tag
delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
conn.execute(delete_stmt)
@@ -86,9 +83,7 @@ def upgrade():
if existing_tag_result:
# Handle duplicate case: the new_tag_id already exists
print(
f"Tag {new_tag_id} already exists. Removing current tag with ID {tag_id} to avoid duplicates."
)
print(f'Tag {new_tag_id} already exists. Removing current tag with ID {tag_id} to avoid duplicates.')
# Option 1: Delete the current tag if an update to new_tag_id would cause duplication
delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
conn.execute(delete_stmt)
@@ -98,19 +93,15 @@ def upgrade():
conn.execute(update_stmt)
# Add columns `pinned` and `meta` to 'chat'
op.add_column("chat", sa.Column("pinned", sa.Boolean(), nullable=True))
op.add_column(
"chat", sa.Column("meta", sa.JSON(), nullable=False, server_default="{}")
)
op.add_column('chat', sa.Column('pinned', sa.Boolean(), nullable=True))
op.add_column('chat', sa.Column('meta', sa.JSON(), nullable=False, server_default='{}'))
chatidtag = table(
"chatidtag", column("chat_id", sa.String()), column("tag_name", sa.String())
)
chatidtag = table('chatidtag', column('chat_id', sa.String()), column('tag_name', sa.String()))
chat = table(
"chat",
column("id", sa.String()),
column("pinned", sa.Boolean()),
column("meta", sa.JSON()),
'chat',
column('id', sa.String()),
column('pinned', sa.Boolean()),
column('meta', sa.JSON()),
)
# Fetch existing tags
@@ -120,29 +111,27 @@ def upgrade():
chat_updates = {}
for row in result:
chat_id = row.chat_id
tag_name = row.tag_name.replace(" ", "_").lower()
tag_name = row.tag_name.replace(' ', '_').lower()
if tag_name == "pinned":
if tag_name == 'pinned':
# Specifically handle 'pinned' tag
if chat_id not in chat_updates:
chat_updates[chat_id] = {"pinned": True, "meta": {}}
chat_updates[chat_id] = {'pinned': True, 'meta': {}}
else:
chat_updates[chat_id]["pinned"] = True
chat_updates[chat_id]['pinned'] = True
else:
if chat_id not in chat_updates:
chat_updates[chat_id] = {"pinned": False, "meta": {"tags": [tag_name]}}
chat_updates[chat_id] = {'pinned': False, 'meta': {'tags': [tag_name]}}
else:
tags = chat_updates[chat_id]["meta"].get("tags", [])
tags = chat_updates[chat_id]['meta'].get('tags', [])
tags.append(tag_name)
chat_updates[chat_id]["meta"]["tags"] = list(set(tags))
chat_updates[chat_id]['meta']['tags'] = list(set(tags))
# Update chats based on accumulated changes
for chat_id, updates in chat_updates.items():
update_stmt = sa.update(chat).where(chat.c.id == chat_id)
update_stmt = update_stmt.values(
meta=updates.get("meta", {}), pinned=updates.get("pinned", False)
)
update_stmt = update_stmt.values(meta=updates.get('meta', {}), pinned=updates.get('pinned', False))
conn.execute(update_stmt)
pass

View File

@@ -12,8 +12,8 @@ from sqlalchemy.sql import table, select, update
import json
revision = "242a2047eae0"
down_revision = "6a39f3d8e55c"
revision = '242a2047eae0'
down_revision = '6a39f3d8e55c'
branch_labels = None
depends_on = None
@@ -22,39 +22,37 @@ def upgrade():
conn = op.get_bind()
inspector = sa.inspect(conn)
columns = inspector.get_columns("chat")
column_dict = {col["name"]: col for col in columns}
columns = inspector.get_columns('chat')
column_dict = {col['name']: col for col in columns}
chat_column = column_dict.get("chat")
old_chat_exists = "old_chat" in column_dict
chat_column = column_dict.get('chat')
old_chat_exists = 'old_chat' in column_dict
if chat_column:
if isinstance(chat_column["type"], sa.Text):
if isinstance(chat_column['type'], sa.Text):
print("Converting 'chat' column to JSON")
if old_chat_exists:
print("Dropping old 'old_chat' column")
op.drop_column("chat", "old_chat")
op.drop_column('chat', 'old_chat')
# Step 1: Rename current 'chat' column to 'old_chat'
print("Renaming 'chat' column to 'old_chat'")
op.alter_column(
"chat", "chat", new_column_name="old_chat", existing_type=sa.Text()
)
op.alter_column('chat', 'chat', new_column_name='old_chat', existing_type=sa.Text())
# Step 2: Add new 'chat' column of type JSON
print("Adding new 'chat' column of type JSON")
op.add_column("chat", sa.Column("chat", sa.JSON(), nullable=True))
op.add_column('chat', sa.Column('chat', sa.JSON(), nullable=True))
else:
# If the column is already JSON, no need to do anything
pass
# Step 3: Migrate data from 'old_chat' to 'chat'
chat_table = table(
"chat",
sa.Column("id", sa.String(), primary_key=True),
sa.Column("old_chat", sa.Text()),
sa.Column("chat", sa.JSON()),
'chat',
sa.Column('id', sa.String(), primary_key=True),
sa.Column('old_chat', sa.Text()),
sa.Column('chat', sa.JSON()),
)
# - Selecting all data from the table
@@ -67,41 +65,33 @@ def upgrade():
except json.JSONDecodeError:
json_data = None # Handle cases where the text cannot be converted to JSON
connection.execute(
sa.update(chat_table)
.where(chat_table.c.id == row.id)
.values(chat=json_data)
)
connection.execute(sa.update(chat_table).where(chat_table.c.id == row.id).values(chat=json_data))
# Step 4: Drop 'old_chat' column
print("Dropping 'old_chat' column")
op.drop_column("chat", "old_chat")
op.drop_column('chat', 'old_chat')
def downgrade():
# Step 1: Add 'old_chat' column back as Text
op.add_column("chat", sa.Column("old_chat", sa.Text(), nullable=True))
op.add_column('chat', sa.Column('old_chat', sa.Text(), nullable=True))
# Step 2: Convert 'chat' JSON data back to text and store in 'old_chat'
chat_table = table(
"chat",
sa.Column("id", sa.String(), primary_key=True),
sa.Column("chat", sa.JSON()),
sa.Column("old_chat", sa.Text()),
'chat',
sa.Column('id', sa.String(), primary_key=True),
sa.Column('chat', sa.JSON()),
sa.Column('old_chat', sa.Text()),
)
connection = op.get_bind()
results = connection.execute(select(chat_table.c.id, chat_table.c.chat))
for row in results:
text_data = json.dumps(row.chat) if row.chat is not None else None
connection.execute(
sa.update(chat_table)
.where(chat_table.c.id == row.id)
.values(old_chat=text_data)
)
connection.execute(sa.update(chat_table).where(chat_table.c.id == row.id).values(old_chat=text_data))
# Step 3: Remove the new 'chat' JSON column
op.drop_column("chat", "chat")
op.drop_column('chat', 'chat')
# Step 4: Rename 'old_chat' back to 'chat'
op.alter_column("chat", "old_chat", new_column_name="chat", existing_type=sa.Text())
op.alter_column('chat', 'old_chat', new_column_name='chat', existing_type=sa.Text())

View File

@@ -13,19 +13,19 @@ import sqlalchemy as sa
import open_webui.internal.db
# revision identifiers, used by Alembic.
revision: str = "2f1211949ecc"
down_revision: Union[str, None] = "37f288994c47"
revision: str = '2f1211949ecc'
down_revision: Union[str, None] = '37f288994c47'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# New columns to be added to channel_member table
op.add_column("channel_member", sa.Column("status", sa.Text(), nullable=True))
op.add_column('channel_member', sa.Column('status', sa.Text(), nullable=True))
op.add_column(
"channel_member",
'channel_member',
sa.Column(
"is_active",
'is_active',
sa.Boolean(),
nullable=False,
default=True,
@@ -34,9 +34,9 @@ def upgrade() -> None:
)
op.add_column(
"channel_member",
'channel_member',
sa.Column(
"is_channel_muted",
'is_channel_muted',
sa.Boolean(),
nullable=False,
default=False,
@@ -44,9 +44,9 @@ def upgrade() -> None:
),
)
op.add_column(
"channel_member",
'channel_member',
sa.Column(
"is_channel_pinned",
'is_channel_pinned',
sa.Boolean(),
nullable=False,
default=False,
@@ -54,49 +54,41 @@ def upgrade() -> None:
),
)
op.add_column("channel_member", sa.Column("data", sa.JSON(), nullable=True))
op.add_column("channel_member", sa.Column("meta", sa.JSON(), nullable=True))
op.add_column('channel_member', sa.Column('data', sa.JSON(), nullable=True))
op.add_column('channel_member', sa.Column('meta', sa.JSON(), nullable=True))
op.add_column(
"channel_member", sa.Column("joined_at", sa.BigInteger(), nullable=False)
)
op.add_column(
"channel_member", sa.Column("left_at", sa.BigInteger(), nullable=True)
)
op.add_column('channel_member', sa.Column('joined_at', sa.BigInteger(), nullable=False))
op.add_column('channel_member', sa.Column('left_at', sa.BigInteger(), nullable=True))
op.add_column(
"channel_member", sa.Column("last_read_at", sa.BigInteger(), nullable=True)
)
op.add_column('channel_member', sa.Column('last_read_at', sa.BigInteger(), nullable=True))
op.add_column(
"channel_member", sa.Column("updated_at", sa.BigInteger(), nullable=True)
)
op.add_column('channel_member', sa.Column('updated_at', sa.BigInteger(), nullable=True))
# New columns to be added to message table
op.add_column(
"message",
'message',
sa.Column(
"is_pinned",
'is_pinned',
sa.Boolean(),
nullable=False,
default=False,
server_default=sa.sql.expression.false(),
),
)
op.add_column("message", sa.Column("pinned_at", sa.BigInteger(), nullable=True))
op.add_column("message", sa.Column("pinned_by", sa.Text(), nullable=True))
op.add_column('message', sa.Column('pinned_at', sa.BigInteger(), nullable=True))
op.add_column('message', sa.Column('pinned_by', sa.Text(), nullable=True))
def downgrade() -> None:
op.drop_column("channel_member", "updated_at")
op.drop_column("channel_member", "last_read_at")
op.drop_column('channel_member', 'updated_at')
op.drop_column('channel_member', 'last_read_at')
op.drop_column("channel_member", "meta")
op.drop_column("channel_member", "data")
op.drop_column('channel_member', 'meta')
op.drop_column('channel_member', 'data')
op.drop_column("channel_member", "is_channel_pinned")
op.drop_column("channel_member", "is_channel_muted")
op.drop_column('channel_member', 'is_channel_pinned')
op.drop_column('channel_member', 'is_channel_muted')
op.drop_column("message", "pinned_by")
op.drop_column("message", "pinned_at")
op.drop_column("message", "is_pinned")
op.drop_column('message', 'pinned_by')
op.drop_column('message', 'pinned_at')
op.drop_column('message', 'is_pinned')

View File

@@ -12,8 +12,8 @@ import uuid
from alembic import op
import sqlalchemy as sa
revision: str = "374d2f66af06"
down_revision: Union[str, None] = "c440947495f3"
revision: str = '374d2f66af06'
down_revision: Union[str, None] = 'c440947495f3'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -26,13 +26,13 @@ def upgrade() -> None:
# We need to assume the OLD structure.
old_prompt_table = sa.table(
"prompt",
sa.column("command", sa.Text()),
sa.column("user_id", sa.Text()),
sa.column("title", sa.Text()),
sa.column("content", sa.Text()),
sa.column("timestamp", sa.BigInteger()),
sa.column("access_control", sa.JSON()),
'prompt',
sa.column('command', sa.Text()),
sa.column('user_id', sa.Text()),
sa.column('title', sa.Text()),
sa.column('content', sa.Text()),
sa.column('timestamp', sa.BigInteger()),
sa.column('access_control', sa.JSON()),
)
# Check if table exists/read data
@@ -53,61 +53,61 @@ def upgrade() -> None:
# Step 2: Create new prompt table with 'id' as PRIMARY KEY
op.create_table(
"prompt_new",
sa.Column("id", sa.Text(), primary_key=True),
sa.Column("command", sa.String(), unique=True, index=True),
sa.Column("user_id", sa.String(), nullable=False),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("content", sa.Text(), nullable=False),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("access_control", sa.JSON(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="1"),
sa.Column("version_id", sa.Text(), nullable=True),
sa.Column("tags", sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
'prompt_new',
sa.Column('id', sa.Text(), primary_key=True),
sa.Column('command', sa.String(), unique=True, index=True),
sa.Column('user_id', sa.String(), nullable=False),
sa.Column('name', sa.Text(), nullable=False),
sa.Column('content', sa.Text(), nullable=False),
sa.Column('data', sa.JSON(), nullable=True),
sa.Column('meta', sa.JSON(), nullable=True),
sa.Column('access_control', sa.JSON(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='1'),
sa.Column('version_id', sa.Text(), nullable=True),
sa.Column('tags', sa.JSON(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=False),
sa.Column('updated_at', sa.BigInteger(), nullable=False),
)
# Step 3: Create prompt_history table
op.create_table(
"prompt_history",
sa.Column("id", sa.Text(), primary_key=True),
sa.Column("prompt_id", sa.Text(), nullable=False, index=True),
sa.Column("parent_id", sa.Text(), nullable=True),
sa.Column("snapshot", sa.JSON(), nullable=False),
sa.Column("user_id", sa.Text(), nullable=False),
sa.Column("commit_message", sa.Text(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False),
'prompt_history',
sa.Column('id', sa.Text(), primary_key=True),
sa.Column('prompt_id', sa.Text(), nullable=False, index=True),
sa.Column('parent_id', sa.Text(), nullable=True),
sa.Column('snapshot', sa.JSON(), nullable=False),
sa.Column('user_id', sa.Text(), nullable=False),
sa.Column('commit_message', sa.Text(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=False),
)
# Step 4: Migrate data
prompt_new_table = sa.table(
"prompt_new",
sa.column("id", sa.Text()),
sa.column("command", sa.String()),
sa.column("user_id", sa.String()),
sa.column("name", sa.Text()),
sa.column("content", sa.Text()),
sa.column("data", sa.JSON()),
sa.column("meta", sa.JSON()),
sa.column("access_control", sa.JSON()),
sa.column("is_active", sa.Boolean()),
sa.column("version_id", sa.Text()),
sa.column("tags", sa.JSON()),
sa.column("created_at", sa.BigInteger()),
sa.column("updated_at", sa.BigInteger()),
'prompt_new',
sa.column('id', sa.Text()),
sa.column('command', sa.String()),
sa.column('user_id', sa.String()),
sa.column('name', sa.Text()),
sa.column('content', sa.Text()),
sa.column('data', sa.JSON()),
sa.column('meta', sa.JSON()),
sa.column('access_control', sa.JSON()),
sa.column('is_active', sa.Boolean()),
sa.column('version_id', sa.Text()),
sa.column('tags', sa.JSON()),
sa.column('created_at', sa.BigInteger()),
sa.column('updated_at', sa.BigInteger()),
)
prompt_history_table = sa.table(
"prompt_history",
sa.column("id", sa.Text()),
sa.column("prompt_id", sa.Text()),
sa.column("parent_id", sa.Text()),
sa.column("snapshot", sa.JSON()),
sa.column("user_id", sa.Text()),
sa.column("commit_message", sa.Text()),
sa.column("created_at", sa.BigInteger()),
'prompt_history',
sa.column('id', sa.Text()),
sa.column('prompt_id', sa.Text()),
sa.column('parent_id', sa.Text()),
sa.column('snapshot', sa.JSON()),
sa.column('user_id', sa.Text()),
sa.column('commit_message', sa.Text()),
sa.column('created_at', sa.BigInteger()),
)
for row in existing_prompts:
@@ -120,7 +120,7 @@ def upgrade() -> None:
new_uuid = str(uuid.uuid4())
history_uuid = str(uuid.uuid4())
clean_command = command[1:] if command and command.startswith("/") else command
clean_command = command[1:] if command and command.startswith('/') else command
# Insert into prompt_new
conn.execute(
@@ -148,12 +148,12 @@ def upgrade() -> None:
prompt_id=new_uuid,
parent_id=None,
snapshot={
"name": title,
"content": content,
"command": clean_command,
"data": {},
"meta": {},
"access_control": access_control,
'name': title,
'content': content,
'command': clean_command,
'data': {},
'meta': {},
'access_control': access_control,
},
user_id=user_id,
commit_message=None,
@@ -162,8 +162,8 @@ def upgrade() -> None:
)
# Step 5: Replace old table with new one
op.drop_table("prompt")
op.rename_table("prompt_new", "prompt")
op.drop_table('prompt')
op.rename_table('prompt_new', 'prompt')
def downgrade() -> None:
@@ -171,13 +171,13 @@ def downgrade() -> None:
# Step 1: Read new data
prompt_table = sa.table(
"prompt",
sa.column("command", sa.String()),
sa.column("name", sa.Text()),
sa.column("created_at", sa.BigInteger()),
sa.column("user_id", sa.Text()),
sa.column("content", sa.Text()),
sa.column("access_control", sa.JSON()),
'prompt',
sa.column('command', sa.String()),
sa.column('name', sa.Text()),
sa.column('created_at', sa.BigInteger()),
sa.column('user_id', sa.Text()),
sa.column('content', sa.Text()),
sa.column('access_control', sa.JSON()),
)
try:
@@ -195,31 +195,31 @@ def downgrade() -> None:
current_data = []
# Step 2: Drop history and table
op.drop_table("prompt_history")
op.drop_table("prompt")
op.drop_table('prompt_history')
op.drop_table('prompt')
# Step 3: Recreate old table (command as PK?)
# Assuming old schema:
op.create_table(
"prompt",
sa.Column("command", sa.String(), primary_key=True),
sa.Column("user_id", sa.String()),
sa.Column("title", sa.Text()),
sa.Column("content", sa.Text()),
sa.Column("timestamp", sa.BigInteger()),
sa.Column("access_control", sa.JSON()),
sa.Column("id", sa.Integer(), nullable=True),
'prompt',
sa.Column('command', sa.String(), primary_key=True),
sa.Column('user_id', sa.String()),
sa.Column('title', sa.Text()),
sa.Column('content', sa.Text()),
sa.Column('timestamp', sa.BigInteger()),
sa.Column('access_control', sa.JSON()),
sa.Column('id', sa.Integer(), nullable=True),
)
# Step 4: Restore data
old_prompt_table = sa.table(
"prompt",
sa.column("command", sa.String()),
sa.column("user_id", sa.String()),
sa.column("title", sa.Text()),
sa.column("content", sa.Text()),
sa.column("timestamp", sa.BigInteger()),
sa.column("access_control", sa.JSON()),
'prompt',
sa.column('command', sa.String()),
sa.column('user_id', sa.String()),
sa.column('title', sa.Text()),
sa.column('content', sa.Text()),
sa.column('timestamp', sa.BigInteger()),
sa.column('access_control', sa.JSON()),
)
for row in current_data:
@@ -231,9 +231,7 @@ def downgrade() -> None:
access_control = row[5]
# Restore leading /
old_command = (
"/" + command if command and not command.startswith("/") else command
)
old_command = '/' + command if command and not command.startswith('/') else command
conn.execute(
sa.insert(old_prompt_table).values(

View File

@@ -9,8 +9,8 @@ Create Date: 2024-12-30 03:00:00.000000
from alembic import op
import sqlalchemy as sa
revision = "3781e22d8b01"
down_revision = "7826ab40b532"
revision = '3781e22d8b01'
down_revision = '7826ab40b532'
branch_labels = None
depends_on = None
@@ -18,9 +18,9 @@ depends_on = None
def upgrade():
# Add 'type' column to the 'channel' table
op.add_column(
"channel",
'channel',
sa.Column(
"type",
'type',
sa.Text(),
nullable=True,
),
@@ -28,43 +28,31 @@ def upgrade():
# Add 'parent_id' column to the 'message' table for threads
op.add_column(
"message",
sa.Column("parent_id", sa.Text(), nullable=True),
'message',
sa.Column('parent_id', sa.Text(), nullable=True),
)
op.create_table(
"message_reaction",
sa.Column(
"id", sa.Text(), nullable=False, primary_key=True, unique=True
), # Unique reaction ID
sa.Column("user_id", sa.Text(), nullable=False), # User who reacted
sa.Column(
"message_id", sa.Text(), nullable=False
), # Message that was reacted to
sa.Column(
"name", sa.Text(), nullable=False
), # Reaction name (e.g. "thumbs_up")
sa.Column(
"created_at", sa.BigInteger(), nullable=True
), # Timestamp of when the reaction was added
'message_reaction',
sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True), # Unique reaction ID
sa.Column('user_id', sa.Text(), nullable=False), # User who reacted
sa.Column('message_id', sa.Text(), nullable=False), # Message that was reacted to
sa.Column('name', sa.Text(), nullable=False), # Reaction name (e.g. "thumbs_up")
sa.Column('created_at', sa.BigInteger(), nullable=True), # Timestamp of when the reaction was added
)
op.create_table(
"channel_member",
sa.Column(
"id", sa.Text(), nullable=False, primary_key=True, unique=True
), # Record ID for the membership row
sa.Column("channel_id", sa.Text(), nullable=False), # Associated channel
sa.Column("user_id", sa.Text(), nullable=False), # Associated user
sa.Column(
"created_at", sa.BigInteger(), nullable=True
), # Timestamp of when the user joined the channel
'channel_member',
sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True), # Record ID for the membership row
sa.Column('channel_id', sa.Text(), nullable=False), # Associated channel
sa.Column('user_id', sa.Text(), nullable=False), # Associated user
sa.Column('created_at', sa.BigInteger(), nullable=True), # Timestamp of when the user joined the channel
)
def downgrade():
# Revert 'type' column addition to the 'channel' table
op.drop_column("channel", "type")
op.drop_column("message", "parent_id")
op.drop_table("message_reaction")
op.drop_table("channel_member")
op.drop_column('channel', 'type')
op.drop_column('message', 'parent_id')
op.drop_table('message_reaction')
op.drop_table('channel_member')

View File

@@ -15,8 +15,8 @@ from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "37f288994c47"
down_revision: Union[str, None] = "a5c220713937"
revision: str = '37f288994c47'
down_revision: Union[str, None] = 'a5c220713937'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -24,50 +24,48 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# 1. Create new table
op.create_table(
"group_member",
sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False),
'group_member',
sa.Column('id', sa.Text(), primary_key=True, unique=True, nullable=False),
sa.Column(
"group_id",
'group_id',
sa.Text(),
sa.ForeignKey("group.id", ondelete="CASCADE"),
sa.ForeignKey('group.id', ondelete='CASCADE'),
nullable=False,
),
sa.Column(
"user_id",
'user_id',
sa.Text(),
sa.ForeignKey("user.id", ondelete="CASCADE"),
sa.ForeignKey('user.id', ondelete='CASCADE'),
nullable=False,
),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.UniqueConstraint("group_id", "user_id", name="uq_group_member_group_user"),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.UniqueConstraint('group_id', 'user_id', name='uq_group_member_group_user'),
)
connection = op.get_bind()
# 2. Read existing group with user_ids JSON column
group_table = sa.Table(
"group",
'group',
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("user_ids", sa.JSON()), # JSON stored as text in SQLite + PG
sa.Column('id', sa.Text()),
sa.Column('user_ids', sa.JSON()), # JSON stored as text in SQLite + PG
)
results = connection.execute(
sa.select(group_table.c.id, group_table.c.user_ids)
).fetchall()
results = connection.execute(sa.select(group_table.c.id, group_table.c.user_ids)).fetchall()
print(results)
# 3. Insert members into group_member table
gm_table = sa.Table(
"group_member",
'group_member',
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("group_id", sa.Text()),
sa.Column("user_id", sa.Text()),
sa.Column("created_at", sa.BigInteger()),
sa.Column("updated_at", sa.BigInteger()),
sa.Column('id', sa.Text()),
sa.Column('group_id', sa.Text()),
sa.Column('user_id', sa.Text()),
sa.Column('created_at', sa.BigInteger()),
sa.Column('updated_at', sa.BigInteger()),
)
now = int(time.time())
@@ -86,11 +84,11 @@ def upgrade() -> None:
rows = [
{
"id": str(uuid.uuid4()),
"group_id": group_id,
"user_id": uid,
"created_at": now,
"updated_at": now,
'id': str(uuid.uuid4()),
'group_id': group_id,
'user_id': uid,
'created_at': now,
'updated_at': now,
}
for uid in user_ids
]
@@ -99,47 +97,41 @@ def upgrade() -> None:
connection.execute(gm_table.insert(), rows)
# 4. Optionally drop the old column
with op.batch_alter_table("group") as batch:
batch.drop_column("user_ids")
with op.batch_alter_table('group') as batch:
batch.drop_column('user_ids')
def downgrade():
# Reverse: restore user_ids column
with op.batch_alter_table("group") as batch:
batch.add_column(sa.Column("user_ids", sa.JSON()))
with op.batch_alter_table('group') as batch:
batch.add_column(sa.Column('user_ids', sa.JSON()))
connection = op.get_bind()
gm_table = sa.Table(
"group_member",
'group_member',
sa.MetaData(),
sa.Column("group_id", sa.Text()),
sa.Column("user_id", sa.Text()),
sa.Column("created_at", sa.BigInteger()),
sa.Column("updated_at", sa.BigInteger()),
sa.Column('group_id', sa.Text()),
sa.Column('user_id', sa.Text()),
sa.Column('created_at', sa.BigInteger()),
sa.Column('updated_at', sa.BigInteger()),
)
group_table = sa.Table(
"group",
'group',
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("user_ids", sa.JSON()),
sa.Column('id', sa.Text()),
sa.Column('user_ids', sa.JSON()),
)
# Build JSON arrays again
results = connection.execute(sa.select(group_table.c.id)).fetchall()
for (group_id,) in results:
members = connection.execute(
sa.select(gm_table.c.user_id).where(gm_table.c.group_id == group_id)
).fetchall()
members = connection.execute(sa.select(gm_table.c.user_id).where(gm_table.c.group_id == group_id)).fetchall()
member_ids = [m[0] for m in members]
connection.execute(
group_table.update()
.where(group_table.c.id == group_id)
.values(user_ids=member_ids)
)
connection.execute(group_table.update().where(group_table.c.id == group_id).values(user_ids=member_ids))
# Drop the new table
op.drop_table("group_member")
op.drop_table('group_member')

View File

@@ -12,8 +12,8 @@ from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "38d63c18f30f"
down_revision: Union[str, None] = "3af16a1c9fb6"
revision: str = '38d63c18f30f'
down_revision: Union[str, None] = '3af16a1c9fb6'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -21,59 +21,55 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Ensure 'id' column in 'user' table is unique and primary key (ForeignKey constraint)
inspector = sa.inspect(op.get_bind())
columns = inspector.get_columns("user")
columns = inspector.get_columns('user')
pk_columns = inspector.get_pk_constraint("user")["constrained_columns"]
id_column = next((col for col in columns if col["name"] == "id"), None)
pk_columns = inspector.get_pk_constraint('user')['constrained_columns']
id_column = next((col for col in columns if col['name'] == 'id'), None)
if id_column and not id_column.get("unique", False):
unique_constraints = inspector.get_unique_constraints("user")
unique_columns = {tuple(u["column_names"]) for u in unique_constraints}
if id_column and not id_column.get('unique', False):
unique_constraints = inspector.get_unique_constraints('user')
unique_columns = {tuple(u['column_names']) for u in unique_constraints}
with op.batch_alter_table("user") as batch_op:
with op.batch_alter_table('user') as batch_op:
# If primary key is wrong, drop it
if pk_columns and pk_columns != ["id"]:
batch_op.drop_constraint(
inspector.get_pk_constraint("user")["name"], type_="primary"
)
if pk_columns and pk_columns != ['id']:
batch_op.drop_constraint(inspector.get_pk_constraint('user')['name'], type_='primary')
# Add unique constraint if missing
if ("id",) not in unique_columns:
batch_op.create_unique_constraint("uq_user_id", ["id"])
if ('id',) not in unique_columns:
batch_op.create_unique_constraint('uq_user_id', ['id'])
# Re-create correct primary key
batch_op.create_primary_key("pk_user_id", ["id"])
batch_op.create_primary_key('pk_user_id', ['id'])
# Create oauth_session table
op.create_table(
"oauth_session",
sa.Column("id", sa.Text(), primary_key=True, nullable=False, unique=True),
'oauth_session',
sa.Column('id', sa.Text(), primary_key=True, nullable=False, unique=True),
sa.Column(
"user_id",
'user_id',
sa.Text(),
sa.ForeignKey("user.id", ondelete="CASCADE"),
sa.ForeignKey('user.id', ondelete='CASCADE'),
nullable=False,
),
sa.Column("provider", sa.Text(), nullable=False),
sa.Column("token", sa.Text(), nullable=False),
sa.Column("expires_at", sa.BigInteger(), nullable=False),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
sa.Column('provider', sa.Text(), nullable=False),
sa.Column('token', sa.Text(), nullable=False),
sa.Column('expires_at', sa.BigInteger(), nullable=False),
sa.Column('created_at', sa.BigInteger(), nullable=False),
sa.Column('updated_at', sa.BigInteger(), nullable=False),
)
# Create indexes for better performance
op.create_index("idx_oauth_session_user_id", "oauth_session", ["user_id"])
op.create_index("idx_oauth_session_expires_at", "oauth_session", ["expires_at"])
op.create_index(
"idx_oauth_session_user_provider", "oauth_session", ["user_id", "provider"]
)
op.create_index('idx_oauth_session_user_id', 'oauth_session', ['user_id'])
op.create_index('idx_oauth_session_expires_at', 'oauth_session', ['expires_at'])
op.create_index('idx_oauth_session_user_provider', 'oauth_session', ['user_id', 'provider'])
def downgrade() -> None:
# Drop indexes first
op.drop_index("idx_oauth_session_user_provider", table_name="oauth_session")
op.drop_index("idx_oauth_session_expires_at", table_name="oauth_session")
op.drop_index("idx_oauth_session_user_id", table_name="oauth_session")
op.drop_index('idx_oauth_session_user_provider', table_name='oauth_session')
op.drop_index('idx_oauth_session_expires_at', table_name='oauth_session')
op.drop_index('idx_oauth_session_user_id', table_name='oauth_session')
# Drop the table
op.drop_table("oauth_session")
op.drop_table('oauth_session')

View File

@@ -13,8 +13,8 @@ from sqlalchemy.engine.reflection import Inspector
import json
revision = "3ab32c4b8f59"
down_revision = "1af9b942657b"
revision = '3ab32c4b8f59'
down_revision = '1af9b942657b'
branch_labels = None
depends_on = None
@@ -24,58 +24,55 @@ def upgrade():
inspector = Inspector.from_engine(conn)
# Inspecting the 'tag' table constraints and structure
existing_pk = inspector.get_pk_constraint("tag")
unique_constraints = inspector.get_unique_constraints("tag")
existing_indexes = inspector.get_indexes("tag")
existing_pk = inspector.get_pk_constraint('tag')
unique_constraints = inspector.get_unique_constraints('tag')
existing_indexes = inspector.get_indexes('tag')
print(f"Primary Key: {existing_pk}")
print(f"Unique Constraints: {unique_constraints}")
print(f"Indexes: {existing_indexes}")
print(f'Primary Key: {existing_pk}')
print(f'Unique Constraints: {unique_constraints}')
print(f'Indexes: {existing_indexes}')
with op.batch_alter_table("tag", schema=None) as batch_op:
with op.batch_alter_table('tag', schema=None) as batch_op:
# Drop existing primary key constraint if it exists
if existing_pk and existing_pk.get("constrained_columns"):
pk_name = existing_pk.get("name")
if existing_pk and existing_pk.get('constrained_columns'):
pk_name = existing_pk.get('name')
if pk_name:
print(f"Dropping primary key constraint: {pk_name}")
batch_op.drop_constraint(pk_name, type_="primary")
print(f'Dropping primary key constraint: {pk_name}')
batch_op.drop_constraint(pk_name, type_='primary')
# Now create the new primary key with the combination of 'id' and 'user_id'
print("Creating new primary key with 'id' and 'user_id'.")
batch_op.create_primary_key("pk_id_user_id", ["id", "user_id"])
batch_op.create_primary_key('pk_id_user_id', ['id', 'user_id'])
# Drop unique constraints that could conflict with the new primary key
for constraint in unique_constraints:
if (
constraint["name"] == "uq_id_user_id"
constraint['name'] == 'uq_id_user_id'
): # Adjust this name according to what is actually returned by the inspector
print(f"Dropping unique constraint: {constraint['name']}")
batch_op.drop_constraint(constraint["name"], type_="unique")
print(f'Dropping unique constraint: {constraint["name"]}')
batch_op.drop_constraint(constraint['name'], type_='unique')
for index in existing_indexes:
if index["unique"]:
if not any(
constraint["name"] == index["name"]
for constraint in unique_constraints
):
if index['unique']:
if not any(constraint['name'] == index['name'] for constraint in unique_constraints):
# You are attempting to drop unique indexes
print(f"Dropping unique index: {index['name']}")
batch_op.drop_index(index["name"])
print(f'Dropping unique index: {index["name"]}')
batch_op.drop_index(index['name'])
def downgrade():
conn = op.get_bind()
inspector = Inspector.from_engine(conn)
current_pk = inspector.get_pk_constraint("tag")
current_pk = inspector.get_pk_constraint('tag')
with op.batch_alter_table("tag", schema=None) as batch_op:
with op.batch_alter_table('tag', schema=None) as batch_op:
# Drop the current primary key first, if it matches the one we know we added in upgrade
if current_pk and "pk_id_user_id" == current_pk.get("name"):
batch_op.drop_constraint("pk_id_user_id", type_="primary")
if current_pk and 'pk_id_user_id' == current_pk.get('name'):
batch_op.drop_constraint('pk_id_user_id', type_='primary')
# Restore the original primary key
batch_op.create_primary_key("pk_id", ["id"])
batch_op.create_primary_key('pk_id', ['id'])
# Since primary key on just 'id' is restored, we now add back any unique constraints if necessary
batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"])
batch_op.create_unique_constraint('uq_id_user_id', ['id', 'user_id'])

View File

@@ -12,21 +12,21 @@ from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "3af16a1c9fb6"
down_revision: Union[str, None] = "018012973d35"
revision: str = '3af16a1c9fb6'
down_revision: Union[str, None] = '018012973d35'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column("user", sa.Column("username", sa.String(length=50), nullable=True))
op.add_column("user", sa.Column("bio", sa.Text(), nullable=True))
op.add_column("user", sa.Column("gender", sa.Text(), nullable=True))
op.add_column("user", sa.Column("date_of_birth", sa.Date(), nullable=True))
op.add_column('user', sa.Column('username', sa.String(length=50), nullable=True))
op.add_column('user', sa.Column('bio', sa.Text(), nullable=True))
op.add_column('user', sa.Column('gender', sa.Text(), nullable=True))
op.add_column('user', sa.Column('date_of_birth', sa.Date(), nullable=True))
def downgrade() -> None:
op.drop_column("user", "username")
op.drop_column("user", "bio")
op.drop_column("user", "gender")
op.drop_column("user", "date_of_birth")
op.drop_column('user', 'username')
op.drop_column('user', 'bio')
op.drop_column('user', 'gender')
op.drop_column('user', 'date_of_birth')

View File

@@ -18,38 +18,38 @@ import json
import uuid
# revision identifiers, used by Alembic.
revision: str = "3e0e00844bb0"
down_revision: Union[str, None] = "90ef40d4714e"
revision: str = '3e0e00844bb0'
down_revision: Union[str, None] = '90ef40d4714e'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"knowledge_file",
sa.Column("id", sa.Text(), primary_key=True),
sa.Column("user_id", sa.Text(), nullable=False),
'knowledge_file',
sa.Column('id', sa.Text(), primary_key=True),
sa.Column('user_id', sa.Text(), nullable=False),
sa.Column(
"knowledge_id",
'knowledge_id',
sa.Text(),
sa.ForeignKey("knowledge.id", ondelete="CASCADE"),
sa.ForeignKey('knowledge.id', ondelete='CASCADE'),
nullable=False,
),
sa.Column(
"file_id",
'file_id',
sa.Text(),
sa.ForeignKey("file.id", ondelete="CASCADE"),
sa.ForeignKey('file.id', ondelete='CASCADE'),
nullable=False,
),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
sa.Column('created_at', sa.BigInteger(), nullable=False),
sa.Column('updated_at', sa.BigInteger(), nullable=False),
# indexes
sa.Index("ix_knowledge_file_knowledge_id", "knowledge_id"),
sa.Index("ix_knowledge_file_file_id", "file_id"),
sa.Index("ix_knowledge_file_user_id", "user_id"),
sa.Index('ix_knowledge_file_knowledge_id', 'knowledge_id'),
sa.Index('ix_knowledge_file_file_id', 'file_id'),
sa.Index('ix_knowledge_file_user_id', 'user_id'),
# unique constraints
sa.UniqueConstraint(
"knowledge_id", "file_id", name="uq_knowledge_file_knowledge_file"
'knowledge_id', 'file_id', name='uq_knowledge_file_knowledge_file'
), # prevent duplicate entries
)
@@ -57,35 +57,33 @@ def upgrade() -> None:
# 2. Read existing group with user_ids JSON column
knowledge_table = sa.Table(
"knowledge",
'knowledge',
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("user_id", sa.Text()),
sa.Column("data", sa.JSON()), # JSON stored as text in SQLite + PG
sa.Column('id', sa.Text()),
sa.Column('user_id', sa.Text()),
sa.Column('data', sa.JSON()), # JSON stored as text in SQLite + PG
)
results = connection.execute(
sa.select(
knowledge_table.c.id, knowledge_table.c.user_id, knowledge_table.c.data
)
sa.select(knowledge_table.c.id, knowledge_table.c.user_id, knowledge_table.c.data)
).fetchall()
# 3. Insert members into group_member table
kf_table = sa.Table(
"knowledge_file",
'knowledge_file',
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("user_id", sa.Text()),
sa.Column("knowledge_id", sa.Text()),
sa.Column("file_id", sa.Text()),
sa.Column("created_at", sa.BigInteger()),
sa.Column("updated_at", sa.BigInteger()),
sa.Column('id', sa.Text()),
sa.Column('user_id', sa.Text()),
sa.Column('knowledge_id', sa.Text()),
sa.Column('file_id', sa.Text()),
sa.Column('created_at', sa.BigInteger()),
sa.Column('updated_at', sa.BigInteger()),
)
file_table = sa.Table(
"file",
'file',
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column('id', sa.Text()),
)
now = int(time.time())
@@ -102,50 +100,48 @@ def upgrade() -> None:
if not isinstance(data, dict):
continue
file_ids = data.get("file_ids", [])
file_ids = data.get('file_ids', [])
for file_id in file_ids:
file_exists = connection.execute(
sa.select(file_table.c.id).where(file_table.c.id == file_id)
).fetchone()
file_exists = connection.execute(sa.select(file_table.c.id).where(file_table.c.id == file_id)).fetchone()
if not file_exists:
continue # skip non-existing files
row = {
"id": str(uuid.uuid4()),
"user_id": user_id,
"knowledge_id": knowledge_id,
"file_id": file_id,
"created_at": now,
"updated_at": now,
'id': str(uuid.uuid4()),
'user_id': user_id,
'knowledge_id': knowledge_id,
'file_id': file_id,
'created_at': now,
'updated_at': now,
}
connection.execute(kf_table.insert().values(**row))
with op.batch_alter_table("knowledge") as batch:
batch.drop_column("data")
with op.batch_alter_table('knowledge') as batch:
batch.drop_column('data')
def downgrade() -> None:
# 1. Add back the old data column
op.add_column("knowledge", sa.Column("data", sa.JSON(), nullable=True))
op.add_column('knowledge', sa.Column('data', sa.JSON(), nullable=True))
connection = op.get_bind()
# 2. Read knowledge_file entries and reconstruct data JSON
knowledge_table = sa.Table(
"knowledge",
'knowledge',
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("data", sa.JSON()),
sa.Column('id', sa.Text()),
sa.Column('data', sa.JSON()),
)
kf_table = sa.Table(
"knowledge_file",
'knowledge_file',
sa.MetaData(),
sa.Column("id", sa.Text()),
sa.Column("knowledge_id", sa.Text()),
sa.Column("file_id", sa.Text()),
sa.Column('id', sa.Text()),
sa.Column('knowledge_id', sa.Text()),
sa.Column('file_id', sa.Text()),
)
results = connection.execute(sa.select(knowledge_table.c.id)).fetchall()
@@ -157,13 +153,9 @@ def downgrade() -> None:
file_ids_list = [fid for (fid,) in file_ids]
data_json = {"file_ids": file_ids_list}
data_json = {'file_ids': file_ids_list}
connection.execute(
knowledge_table.update()
.where(knowledge_table.c.id == knowledge_id)
.values(data=data_json)
)
connection.execute(knowledge_table.update().where(knowledge_table.c.id == knowledge_id).values(data=data_json))
# 3. Drop the knowledge_file table
op.drop_table("knowledge_file")
op.drop_table('knowledge_file')

View File

@@ -9,56 +9,56 @@ Create Date: 2024-10-23 03:00:00.000000
from alembic import op
import sqlalchemy as sa
revision = "4ace53fd72c8"
down_revision = "af906e964978"
revision = '4ace53fd72c8'
down_revision = 'af906e964978'
branch_labels = None
depends_on = None
def upgrade():
# Perform safe alterations using batch operation
with op.batch_alter_table("folder", schema=None) as batch_op:
with op.batch_alter_table('folder', schema=None) as batch_op:
# Step 1: Remove server defaults for created_at and updated_at
batch_op.alter_column(
"created_at",
'created_at',
server_default=None, # Removing server default
)
batch_op.alter_column(
"updated_at",
'updated_at',
server_default=None, # Removing server default
)
# Step 2: Change the column types to BigInteger for created_at
batch_op.alter_column(
"created_at",
'created_at',
type_=sa.BigInteger(),
existing_type=sa.DateTime(),
existing_nullable=False,
postgresql_using="extract(epoch from created_at)::bigint", # Conversion for PostgreSQL
postgresql_using='extract(epoch from created_at)::bigint', # Conversion for PostgreSQL
)
# Change the column types to BigInteger for updated_at
batch_op.alter_column(
"updated_at",
'updated_at',
type_=sa.BigInteger(),
existing_type=sa.DateTime(),
existing_nullable=False,
postgresql_using="extract(epoch from updated_at)::bigint", # Conversion for PostgreSQL
postgresql_using='extract(epoch from updated_at)::bigint', # Conversion for PostgreSQL
)
def downgrade():
# Downgrade: Convert columns back to DateTime and restore defaults
with op.batch_alter_table("folder", schema=None) as batch_op:
with op.batch_alter_table('folder', schema=None) as batch_op:
batch_op.alter_column(
"created_at",
'created_at',
type_=sa.DateTime(),
existing_type=sa.BigInteger(),
existing_nullable=False,
server_default=sa.func.now(), # Restoring server default on downgrade
)
batch_op.alter_column(
"updated_at",
'updated_at',
type_=sa.DateTime(),
existing_type=sa.BigInteger(),
existing_nullable=False,

View File

@@ -9,40 +9,40 @@ Create Date: 2024-12-22 03:00:00.000000
from alembic import op
import sqlalchemy as sa
revision = "57c599a3cb57"
down_revision = "922e7a387820"
revision = '57c599a3cb57'
down_revision = '922e7a387820'
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"channel",
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column("user_id", sa.Text()),
sa.Column("name", sa.Text()),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("access_control", sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
'channel',
sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column('user_id', sa.Text()),
sa.Column('name', sa.Text()),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('data', sa.JSON(), nullable=True),
sa.Column('meta', sa.JSON(), nullable=True),
sa.Column('access_control', sa.JSON(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
)
op.create_table(
"message",
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column("user_id", sa.Text()),
sa.Column("channel_id", sa.Text(), nullable=True),
sa.Column("content", sa.Text()),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
'message',
sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column('user_id', sa.Text()),
sa.Column('channel_id', sa.Text(), nullable=True),
sa.Column('content', sa.Text()),
sa.Column('data', sa.JSON(), nullable=True),
sa.Column('meta', sa.JSON(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
)
def downgrade():
op.drop_table("channel")
op.drop_table('channel')
op.drop_table("message")
op.drop_table('message')

View File

@@ -13,41 +13,39 @@ import sqlalchemy as sa
import open_webui.internal.db
# revision identifiers, used by Alembic.
revision: str = "6283dc0e4d8d"
down_revision: Union[str, None] = "3e0e00844bb0"
revision: str = '6283dc0e4d8d'
down_revision: Union[str, None] = '3e0e00844bb0'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"channel_file",
sa.Column("id", sa.Text(), primary_key=True),
sa.Column("user_id", sa.Text(), nullable=False),
'channel_file',
sa.Column('id', sa.Text(), primary_key=True),
sa.Column('user_id', sa.Text(), nullable=False),
sa.Column(
"channel_id",
'channel_id',
sa.Text(),
sa.ForeignKey("channel.id", ondelete="CASCADE"),
sa.ForeignKey('channel.id', ondelete='CASCADE'),
nullable=False,
),
sa.Column(
"file_id",
'file_id',
sa.Text(),
sa.ForeignKey("file.id", ondelete="CASCADE"),
sa.ForeignKey('file.id', ondelete='CASCADE'),
nullable=False,
),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
sa.Column('created_at', sa.BigInteger(), nullable=False),
sa.Column('updated_at', sa.BigInteger(), nullable=False),
# indexes
sa.Index("ix_channel_file_channel_id", "channel_id"),
sa.Index("ix_channel_file_file_id", "file_id"),
sa.Index("ix_channel_file_user_id", "user_id"),
sa.Index('ix_channel_file_channel_id', 'channel_id'),
sa.Index('ix_channel_file_file_id', 'file_id'),
sa.Index('ix_channel_file_user_id', 'user_id'),
# unique constraints
sa.UniqueConstraint(
"channel_id", "file_id", name="uq_channel_file_channel_file"
), # prevent duplicate entries
sa.UniqueConstraint('channel_id', 'file_id', name='uq_channel_file_channel_file'), # prevent duplicate entries
)
def downgrade() -> None:
op.drop_table("channel_file")
op.drop_table('channel_file')

View File

@@ -11,37 +11,37 @@ import sqlalchemy as sa
from sqlalchemy.sql import table, column, select
import json
revision = "6a39f3d8e55c"
down_revision = "c0fbf31ca0db"
revision = '6a39f3d8e55c'
down_revision = 'c0fbf31ca0db'
branch_labels = None
depends_on = None
def upgrade():
# Creating the 'knowledge' table
print("Creating knowledge table")
print('Creating knowledge table')
knowledge_table = op.create_table(
"knowledge",
sa.Column("id", sa.Text(), primary_key=True),
sa.Column("user_id", sa.Text(), nullable=False),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
'knowledge',
sa.Column('id', sa.Text(), primary_key=True),
sa.Column('user_id', sa.Text(), nullable=False),
sa.Column('name', sa.Text(), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('data', sa.JSON(), nullable=True),
sa.Column('meta', sa.JSON(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=False),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
)
print("Migrating data from document table to knowledge table")
print('Migrating data from document table to knowledge table')
# Representation of the existing 'document' table
document_table = table(
"document",
column("collection_name", sa.String()),
column("user_id", sa.String()),
column("name", sa.String()),
column("title", sa.Text()),
column("content", sa.Text()),
column("timestamp", sa.BigInteger()),
'document',
column('collection_name', sa.String()),
column('user_id', sa.String()),
column('name', sa.String()),
column('title', sa.Text()),
column('content', sa.Text()),
column('timestamp', sa.BigInteger()),
)
# Select all from existing document table
@@ -64,9 +64,9 @@ def upgrade():
user_id=doc.user_id,
description=doc.name,
meta={
"legacy": True,
"document": True,
"tags": json.loads(doc.content or "{}").get("tags", []),
'legacy': True,
'document': True,
'tags': json.loads(doc.content or '{}').get('tags', []),
},
name=doc.title,
created_at=doc.timestamp,
@@ -76,4 +76,4 @@ def upgrade():
def downgrade():
op.drop_table("knowledge")
op.drop_table('knowledge')

View File

@@ -9,18 +9,18 @@ Create Date: 2024-12-23 03:00:00.000000
from alembic import op
import sqlalchemy as sa
revision = "7826ab40b532"
down_revision = "57c599a3cb57"
revision = '7826ab40b532'
down_revision = '57c599a3cb57'
branch_labels = None
depends_on = None
def upgrade():
op.add_column(
"file",
sa.Column("access_control", sa.JSON(), nullable=True),
'file',
sa.Column('access_control', sa.JSON(), nullable=True),
)
def downgrade():
op.drop_column("file", "access_control")
op.drop_column('file', 'access_control')

View File

@@ -16,7 +16,7 @@ from open_webui.internal.db import JSONField
from open_webui.migrations.util import get_existing_tables
# revision identifiers, used by Alembic.
revision: str = "7e5b5dc7342b"
revision: str = '7e5b5dc7342b'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -26,179 +26,179 @@ def upgrade() -> None:
existing_tables = set(get_existing_tables())
# ### commands auto generated by Alembic - please adjust! ###
if "auth" not in existing_tables:
if 'auth' not in existing_tables:
op.create_table(
"auth",
sa.Column("id", sa.String(), nullable=False),
sa.Column("email", sa.String(), nullable=True),
sa.Column("password", sa.Text(), nullable=True),
sa.Column("active", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
'auth',
sa.Column('id', sa.String(), nullable=False),
sa.Column('email', sa.String(), nullable=True),
sa.Column('password', sa.Text(), nullable=True),
sa.Column('active', sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint('id'),
)
if "chat" not in existing_tables:
if 'chat' not in existing_tables:
op.create_table(
"chat",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("chat", sa.Text(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("share_id", sa.Text(), nullable=True),
sa.Column("archived", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("share_id"),
'chat',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('title', sa.Text(), nullable=True),
sa.Column('chat', sa.Text(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('share_id', sa.Text(), nullable=True),
sa.Column('archived', sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('share_id'),
)
if "chatidtag" not in existing_tables:
if 'chatidtag' not in existing_tables:
op.create_table(
"chatidtag",
sa.Column("id", sa.String(), nullable=False),
sa.Column("tag_name", sa.String(), nullable=True),
sa.Column("chat_id", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
'chatidtag',
sa.Column('id', sa.String(), nullable=False),
sa.Column('tag_name', sa.String(), nullable=True),
sa.Column('chat_id', sa.String(), nullable=True),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id'),
)
if "document" not in existing_tables:
if 'document' not in existing_tables:
op.create_table(
"document",
sa.Column("collection_name", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("filename", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("collection_name"),
sa.UniqueConstraint("name"),
'document',
sa.Column('collection_name', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('title', sa.Text(), nullable=True),
sa.Column('filename', sa.Text(), nullable=True),
sa.Column('content', sa.Text(), nullable=True),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('collection_name'),
sa.UniqueConstraint('name'),
)
if "file" not in existing_tables:
if 'file' not in existing_tables:
op.create_table(
"file",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("filename", sa.Text(), nullable=True),
sa.Column("meta", JSONField(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
'file',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('filename', sa.Text(), nullable=True),
sa.Column('meta', JSONField(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id'),
)
if "function" not in existing_tables:
if 'function' not in existing_tables:
op.create_table(
"function",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("type", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("meta", JSONField(), nullable=True),
sa.Column("valves", JSONField(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=True),
sa.Column("is_global", sa.Boolean(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
'function',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('name', sa.Text(), nullable=True),
sa.Column('type', sa.Text(), nullable=True),
sa.Column('content', sa.Text(), nullable=True),
sa.Column('meta', JSONField(), nullable=True),
sa.Column('valves', JSONField(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True),
sa.Column('is_global', sa.Boolean(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id'),
)
if "memory" not in existing_tables:
if 'memory' not in existing_tables:
op.create_table(
"memory",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
'memory',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('content', sa.Text(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id'),
)
if "model" not in existing_tables:
if 'model' not in existing_tables:
op.create_table(
"model",
sa.Column("id", sa.Text(), nullable=False),
sa.Column("user_id", sa.Text(), nullable=True),
sa.Column("base_model_id", sa.Text(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("params", JSONField(), nullable=True),
sa.Column("meta", JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
'model',
sa.Column('id', sa.Text(), nullable=False),
sa.Column('user_id', sa.Text(), nullable=True),
sa.Column('base_model_id', sa.Text(), nullable=True),
sa.Column('name', sa.Text(), nullable=True),
sa.Column('params', JSONField(), nullable=True),
sa.Column('meta', JSONField(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id'),
)
if "prompt" not in existing_tables:
if 'prompt' not in existing_tables:
op.create_table(
"prompt",
sa.Column("command", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("command"),
'prompt',
sa.Column('command', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('title', sa.Text(), nullable=True),
sa.Column('content', sa.Text(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('command'),
)
if "tag" not in existing_tables:
if 'tag' not in existing_tables:
op.create_table(
"tag",
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("data", sa.Text(), nullable=True),
sa.PrimaryKeyConstraint("id"),
'tag',
sa.Column('id', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('data', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('id'),
)
if "tool" not in existing_tables:
if 'tool' not in existing_tables:
op.create_table(
"tool",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("specs", JSONField(), nullable=True),
sa.Column("meta", JSONField(), nullable=True),
sa.Column("valves", JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
'tool',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('name', sa.Text(), nullable=True),
sa.Column('content', sa.Text(), nullable=True),
sa.Column('specs', JSONField(), nullable=True),
sa.Column('meta', JSONField(), nullable=True),
sa.Column('valves', JSONField(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id'),
)
if "user" not in existing_tables:
if 'user' not in existing_tables:
op.create_table(
"user",
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("email", sa.String(), nullable=True),
sa.Column("role", sa.String(), nullable=True),
sa.Column("profile_image_url", sa.Text(), nullable=True),
sa.Column("last_active_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("api_key", sa.String(), nullable=True),
sa.Column("settings", JSONField(), nullable=True),
sa.Column("info", JSONField(), nullable=True),
sa.Column("oauth_sub", sa.Text(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("api_key"),
sa.UniqueConstraint("oauth_sub"),
'user',
sa.Column('id', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('email', sa.String(), nullable=True),
sa.Column('role', sa.String(), nullable=True),
sa.Column('profile_image_url', sa.Text(), nullable=True),
sa.Column('last_active_at', sa.BigInteger(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column('api_key', sa.String(), nullable=True),
sa.Column('settings', JSONField(), nullable=True),
sa.Column('info', JSONField(), nullable=True),
sa.Column('oauth_sub', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('api_key'),
sa.UniqueConstraint('oauth_sub'),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("user")
op.drop_table("tool")
op.drop_table("tag")
op.drop_table("prompt")
op.drop_table("model")
op.drop_table("memory")
op.drop_table("function")
op.drop_table("file")
op.drop_table("document")
op.drop_table("chatidtag")
op.drop_table("chat")
op.drop_table("auth")
op.drop_table('user')
op.drop_table('tool')
op.drop_table('tag')
op.drop_table('prompt')
op.drop_table('model')
op.drop_table('memory')
op.drop_table('function')
op.drop_table('file')
op.drop_table('document')
op.drop_table('chatidtag')
op.drop_table('chat')
op.drop_table('auth')
# ### end Alembic commands ###

View File

@@ -13,36 +13,34 @@ import sqlalchemy as sa
import open_webui.internal.db
# revision identifiers, used by Alembic.
revision: str = "81cc2ce44d79"
down_revision: Union[str, None] = "6283dc0e4d8d"
revision: str = '81cc2ce44d79'
down_revision: Union[str, None] = '6283dc0e4d8d'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add message_id column to channel_file table
with op.batch_alter_table("channel_file", schema=None) as batch_op:
with op.batch_alter_table('channel_file', schema=None) as batch_op:
batch_op.add_column(
sa.Column(
"message_id",
'message_id',
sa.Text(),
sa.ForeignKey(
"message.id", ondelete="CASCADE", name="fk_channel_file_message_id"
),
sa.ForeignKey('message.id', ondelete='CASCADE', name='fk_channel_file_message_id'),
nullable=True,
)
)
# Add data column to knowledge table
with op.batch_alter_table("knowledge", schema=None) as batch_op:
batch_op.add_column(sa.Column("data", sa.JSON(), nullable=True))
with op.batch_alter_table('knowledge', schema=None) as batch_op:
batch_op.add_column(sa.Column('data', sa.JSON(), nullable=True))
def downgrade() -> None:
# Remove message_id column from channel_file table
with op.batch_alter_table("channel_file", schema=None) as batch_op:
batch_op.drop_column("message_id")
with op.batch_alter_table('channel_file', schema=None) as batch_op:
batch_op.drop_column('message_id')
# Remove data column from knowledge table
with op.batch_alter_table("knowledge", schema=None) as batch_op:
batch_op.drop_column("data")
with op.batch_alter_table('knowledge', schema=None) as batch_op:
batch_op.drop_column('data')

View File

@@ -16,8 +16,8 @@ import sqlalchemy as sa
log = logging.getLogger(__name__)
revision: str = "8452d01d26d7"
down_revision: Union[str, None] = "374d2f66af06"
revision: str = '8452d01d26d7'
down_revision: Union[str, None] = '374d2f66af06'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -51,74 +51,68 @@ def _flush_batch(conn, table, batch):
except Exception as e:
sp.rollback()
failed += 1
log.warning(f"Failed to insert message {msg['id']}: {e}")
log.warning(f'Failed to insert message {msg["id"]}: {e}')
return inserted, failed
def upgrade() -> None:
# Step 1: Create table
op.create_table(
"chat_message",
sa.Column("id", sa.Text(), primary_key=True),
sa.Column("chat_id", sa.Text(), nullable=False, index=True),
sa.Column("user_id", sa.Text(), index=True),
sa.Column("role", sa.Text(), nullable=False),
sa.Column("parent_id", sa.Text(), nullable=True),
sa.Column("content", sa.JSON(), nullable=True),
sa.Column("output", sa.JSON(), nullable=True),
sa.Column("model_id", sa.Text(), nullable=True, index=True),
sa.Column("files", sa.JSON(), nullable=True),
sa.Column("sources", sa.JSON(), nullable=True),
sa.Column("embeds", sa.JSON(), nullable=True),
sa.Column("done", sa.Boolean(), default=True),
sa.Column("status_history", sa.JSON(), nullable=True),
sa.Column("error", sa.JSON(), nullable=True),
sa.Column("usage", sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), index=True),
sa.Column("updated_at", sa.BigInteger()),
sa.ForeignKeyConstraint(["chat_id"], ["chat.id"], ondelete="CASCADE"),
'chat_message',
sa.Column('id', sa.Text(), primary_key=True),
sa.Column('chat_id', sa.Text(), nullable=False, index=True),
sa.Column('user_id', sa.Text(), index=True),
sa.Column('role', sa.Text(), nullable=False),
sa.Column('parent_id', sa.Text(), nullable=True),
sa.Column('content', sa.JSON(), nullable=True),
sa.Column('output', sa.JSON(), nullable=True),
sa.Column('model_id', sa.Text(), nullable=True, index=True),
sa.Column('files', sa.JSON(), nullable=True),
sa.Column('sources', sa.JSON(), nullable=True),
sa.Column('embeds', sa.JSON(), nullable=True),
sa.Column('done', sa.Boolean(), default=True),
sa.Column('status_history', sa.JSON(), nullable=True),
sa.Column('error', sa.JSON(), nullable=True),
sa.Column('usage', sa.JSON(), nullable=True),
sa.Column('created_at', sa.BigInteger(), index=True),
sa.Column('updated_at', sa.BigInteger()),
sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ondelete='CASCADE'),
)
# Create composite indexes
op.create_index(
"chat_message_chat_parent_idx", "chat_message", ["chat_id", "parent_id"]
)
op.create_index(
"chat_message_model_created_idx", "chat_message", ["model_id", "created_at"]
)
op.create_index(
"chat_message_user_created_idx", "chat_message", ["user_id", "created_at"]
)
op.create_index('chat_message_chat_parent_idx', 'chat_message', ['chat_id', 'parent_id'])
op.create_index('chat_message_model_created_idx', 'chat_message', ['model_id', 'created_at'])
op.create_index('chat_message_user_created_idx', 'chat_message', ['user_id', 'created_at'])
# Step 2: Backfill from existing chats
conn = op.get_bind()
chat_table = sa.table(
"chat",
sa.column("id", sa.Text()),
sa.column("user_id", sa.Text()),
sa.column("chat", sa.JSON()),
'chat',
sa.column('id', sa.Text()),
sa.column('user_id', sa.Text()),
sa.column('chat', sa.JSON()),
)
chat_message_table = sa.table(
"chat_message",
sa.column("id", sa.Text()),
sa.column("chat_id", sa.Text()),
sa.column("user_id", sa.Text()),
sa.column("role", sa.Text()),
sa.column("parent_id", sa.Text()),
sa.column("content", sa.JSON()),
sa.column("output", sa.JSON()),
sa.column("model_id", sa.Text()),
sa.column("files", sa.JSON()),
sa.column("sources", sa.JSON()),
sa.column("embeds", sa.JSON()),
sa.column("done", sa.Boolean()),
sa.column("status_history", sa.JSON()),
sa.column("error", sa.JSON()),
sa.column("usage", sa.JSON()),
sa.column("created_at", sa.BigInteger()),
sa.column("updated_at", sa.BigInteger()),
'chat_message',
sa.column('id', sa.Text()),
sa.column('chat_id', sa.Text()),
sa.column('user_id', sa.Text()),
sa.column('role', sa.Text()),
sa.column('parent_id', sa.Text()),
sa.column('content', sa.JSON()),
sa.column('output', sa.JSON()),
sa.column('model_id', sa.Text()),
sa.column('files', sa.JSON()),
sa.column('sources', sa.JSON()),
sa.column('embeds', sa.JSON()),
sa.column('done', sa.Boolean()),
sa.column('status_history', sa.JSON()),
sa.column('error', sa.JSON()),
sa.column('usage', sa.JSON()),
sa.column('created_at', sa.BigInteger()),
sa.column('updated_at', sa.BigInteger()),
)
# Stream rows instead of loading all into memory:
@@ -126,7 +120,7 @@ def upgrade() -> None:
# - stream_results: enables server-side cursors on PostgreSQL (no-op on SQLite)
result = conn.execute(
sa.select(chat_table.c.id, chat_table.c.user_id, chat_table.c.chat)
.where(~chat_table.c.user_id.like("shared-%"))
.where(~chat_table.c.user_id.like('shared-%'))
.execution_options(yield_per=1000, stream_results=True)
)
@@ -150,11 +144,11 @@ def upgrade() -> None:
except Exception:
continue
history = chat_data.get("history", {})
history = chat_data.get('history', {})
if not isinstance(history, dict):
continue
messages = history.get("messages", {})
messages = history.get('messages', {})
if not isinstance(messages, dict):
continue
@@ -162,11 +156,11 @@ def upgrade() -> None:
if not isinstance(message, dict):
continue
role = message.get("role")
role = message.get('role')
if not role:
continue
timestamp = message.get("timestamp", now)
timestamp = message.get('timestamp', now)
try:
timestamp = int(float(timestamp))
@@ -182,37 +176,33 @@ def upgrade() -> None:
messages_batch.append(
{
"id": f"{chat_id}-{message_id}",
"chat_id": chat_id,
"user_id": user_id,
"role": role,
"parent_id": message.get("parentId"),
"content": message.get("content"),
"output": message.get("output"),
"model_id": message.get("model"),
"files": message.get("files"),
"sources": message.get("sources"),
"embeds": message.get("embeds"),
"done": message.get("done", True),
"status_history": message.get("statusHistory"),
"error": message.get("error"),
"usage": message.get("usage"),
"created_at": timestamp,
"updated_at": timestamp,
'id': f'{chat_id}-{message_id}',
'chat_id': chat_id,
'user_id': user_id,
'role': role,
'parent_id': message.get('parentId'),
'content': message.get('content'),
'output': message.get('output'),
'model_id': message.get('model'),
'files': message.get('files'),
'sources': message.get('sources'),
'embeds': message.get('embeds'),
'done': message.get('done', True),
'status_history': message.get('statusHistory'),
'error': message.get('error'),
'usage': message.get('usage'),
'created_at': timestamp,
'updated_at': timestamp,
}
)
# Flush batch when full
if len(messages_batch) >= BATCH_SIZE:
inserted, failed = _flush_batch(
conn, chat_message_table, messages_batch
)
inserted, failed = _flush_batch(conn, chat_message_table, messages_batch)
total_inserted += inserted
total_failed += failed
if total_inserted % 50000 < BATCH_SIZE:
log.info(
f"Migration progress: {total_inserted} messages inserted..."
)
log.info(f'Migration progress: {total_inserted} messages inserted...')
messages_batch.clear()
# Flush remaining messages
@@ -221,13 +211,11 @@ def upgrade() -> None:
total_inserted += inserted
total_failed += failed
log.info(
f"Backfilled {total_inserted} messages into chat_message table ({total_failed} failed)"
)
log.info(f'Backfilled {total_inserted} messages into chat_message table ({total_failed} failed)')
def downgrade() -> None:
op.drop_index("chat_message_user_created_idx", table_name="chat_message")
op.drop_index("chat_message_model_created_idx", table_name="chat_message")
op.drop_index("chat_message_chat_parent_idx", table_name="chat_message")
op.drop_table("chat_message")
op.drop_index('chat_message_user_created_idx', table_name='chat_message')
op.drop_index('chat_message_model_created_idx', table_name='chat_message')
op.drop_index('chat_message_chat_parent_idx', table_name='chat_message')
op.drop_table('chat_message')

View File

@@ -13,48 +13,46 @@ import sqlalchemy as sa
import open_webui.internal.db
# revision identifiers, used by Alembic.
revision: str = "90ef40d4714e"
down_revision: Union[str, None] = "b10670c03dd5"
revision: str = '90ef40d4714e'
down_revision: Union[str, None] = 'b10670c03dd5'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Update 'channel' table
op.add_column("channel", sa.Column("is_private", sa.Boolean(), nullable=True))
op.add_column('channel', sa.Column('is_private', sa.Boolean(), nullable=True))
op.add_column("channel", sa.Column("archived_at", sa.BigInteger(), nullable=True))
op.add_column("channel", sa.Column("archived_by", sa.Text(), nullable=True))
op.add_column('channel', sa.Column('archived_at', sa.BigInteger(), nullable=True))
op.add_column('channel', sa.Column('archived_by', sa.Text(), nullable=True))
op.add_column("channel", sa.Column("deleted_at", sa.BigInteger(), nullable=True))
op.add_column("channel", sa.Column("deleted_by", sa.Text(), nullable=True))
op.add_column('channel', sa.Column('deleted_at', sa.BigInteger(), nullable=True))
op.add_column('channel', sa.Column('deleted_by', sa.Text(), nullable=True))
op.add_column("channel", sa.Column("updated_by", sa.Text(), nullable=True))
op.add_column('channel', sa.Column('updated_by', sa.Text(), nullable=True))
# Update 'channel_member' table
op.add_column("channel_member", sa.Column("role", sa.Text(), nullable=True))
op.add_column("channel_member", sa.Column("invited_by", sa.Text(), nullable=True))
op.add_column(
"channel_member", sa.Column("invited_at", sa.BigInteger(), nullable=True)
)
op.add_column('channel_member', sa.Column('role', sa.Text(), nullable=True))
op.add_column('channel_member', sa.Column('invited_by', sa.Text(), nullable=True))
op.add_column('channel_member', sa.Column('invited_at', sa.BigInteger(), nullable=True))
# Create 'channel_webhook' table
op.create_table(
"channel_webhook",
sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False),
sa.Column("user_id", sa.Text(), nullable=False),
'channel_webhook',
sa.Column('id', sa.Text(), primary_key=True, unique=True, nullable=False),
sa.Column('user_id', sa.Text(), nullable=False),
sa.Column(
"channel_id",
'channel_id',
sa.Text(),
sa.ForeignKey("channel.id", ondelete="CASCADE"),
sa.ForeignKey('channel.id', ondelete='CASCADE'),
nullable=False,
),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("profile_image_url", sa.Text(), nullable=True),
sa.Column("token", sa.Text(), nullable=False),
sa.Column("last_used_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
sa.Column('name', sa.Text(), nullable=False),
sa.Column('profile_image_url', sa.Text(), nullable=True),
sa.Column('token', sa.Text(), nullable=False),
sa.Column('last_used_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=False),
sa.Column('updated_at', sa.BigInteger(), nullable=False),
)
pass
@@ -62,19 +60,19 @@ def upgrade() -> None:
def downgrade() -> None:
# Downgrade 'channel' table
op.drop_column("channel", "is_private")
op.drop_column("channel", "archived_at")
op.drop_column("channel", "archived_by")
op.drop_column("channel", "deleted_at")
op.drop_column("channel", "deleted_by")
op.drop_column("channel", "updated_by")
op.drop_column('channel', 'is_private')
op.drop_column('channel', 'archived_at')
op.drop_column('channel', 'archived_by')
op.drop_column('channel', 'deleted_at')
op.drop_column('channel', 'deleted_by')
op.drop_column('channel', 'updated_by')
# Downgrade 'channel_member' table
op.drop_column("channel_member", "role")
op.drop_column("channel_member", "invited_by")
op.drop_column("channel_member", "invited_at")
op.drop_column('channel_member', 'role')
op.drop_column('channel_member', 'invited_by')
op.drop_column('channel_member', 'invited_at')
# Drop 'channel_webhook' table
op.drop_table("channel_webhook")
op.drop_table('channel_webhook')
pass

View File

@@ -9,38 +9,38 @@ Create Date: 2024-11-14 03:00:00.000000
from alembic import op
import sqlalchemy as sa
revision = "922e7a387820"
down_revision = "4ace53fd72c8"
revision = '922e7a387820'
down_revision = '4ace53fd72c8'
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"group",
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column("user_id", sa.Text(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("permissions", sa.JSON(), nullable=True),
sa.Column("user_ids", sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
'group',
sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column('user_id', sa.Text(), nullable=True),
sa.Column('name', sa.Text(), nullable=True),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('data', sa.JSON(), nullable=True),
sa.Column('meta', sa.JSON(), nullable=True),
sa.Column('permissions', sa.JSON(), nullable=True),
sa.Column('user_ids', sa.JSON(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
)
# Add 'access_control' column to 'model' table
op.add_column(
"model",
sa.Column("access_control", sa.JSON(), nullable=True),
'model',
sa.Column('access_control', sa.JSON(), nullable=True),
)
# Add 'is_active' column to 'model' table
op.add_column(
"model",
'model',
sa.Column(
"is_active",
'is_active',
sa.Boolean(),
nullable=False,
server_default=sa.sql.expression.true(),
@@ -49,37 +49,37 @@ def upgrade():
# Add 'access_control' column to 'knowledge' table
op.add_column(
"knowledge",
sa.Column("access_control", sa.JSON(), nullable=True),
'knowledge',
sa.Column('access_control', sa.JSON(), nullable=True),
)
# Add 'access_control' column to 'prompt' table
op.add_column(
"prompt",
sa.Column("access_control", sa.JSON(), nullable=True),
'prompt',
sa.Column('access_control', sa.JSON(), nullable=True),
)
# Add 'access_control' column to 'tools' table
op.add_column(
"tool",
sa.Column("access_control", sa.JSON(), nullable=True),
'tool',
sa.Column('access_control', sa.JSON(), nullable=True),
)
def downgrade():
op.drop_table("group")
op.drop_table('group')
# Drop 'access_control' column from 'model' table
op.drop_column("model", "access_control")
op.drop_column('model', 'access_control')
# Drop 'is_active' column from 'model' table
op.drop_column("model", "is_active")
op.drop_column('model', 'is_active')
# Drop 'access_control' column from 'knowledge' table
op.drop_column("knowledge", "access_control")
op.drop_column('knowledge', 'access_control')
# Drop 'access_control' column from 'prompt' table
op.drop_column("prompt", "access_control")
op.drop_column('prompt', 'access_control')
# Drop 'access_control' column from 'tools' table
op.drop_column("tool", "access_control")
op.drop_column('tool', 'access_control')

View File

@@ -9,25 +9,25 @@ Create Date: 2025-05-03 03:00:00.000000
from alembic import op
import sqlalchemy as sa
revision = "9f0c9cd09105"
down_revision = "3781e22d8b01"
revision = '9f0c9cd09105'
down_revision = '3781e22d8b01'
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"note",
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column("user_id", sa.Text(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("access_control", sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
'note',
sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column('user_id', sa.Text(), nullable=True),
sa.Column('title', sa.Text(), nullable=True),
sa.Column('data', sa.JSON(), nullable=True),
sa.Column('meta', sa.JSON(), nullable=True),
sa.Column('access_control', sa.JSON(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
)
def downgrade():
op.drop_table("note")
op.drop_table('note')

View File

@@ -13,8 +13,8 @@ import sqlalchemy as sa
from open_webui.migrations.util import get_existing_tables
revision: str = "a1b2c3d4e5f6"
down_revision: Union[str, None] = "f1e2d3c4b5a6"
revision: str = 'a1b2c3d4e5f6'
down_revision: Union[str, None] = 'f1e2d3c4b5a6'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -22,24 +22,24 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
existing_tables = set(get_existing_tables())
if "skill" not in existing_tables:
if 'skill' not in existing_tables:
op.create_table(
"skill",
sa.Column("id", sa.String(), nullable=False, primary_key=True),
sa.Column("user_id", sa.String(), nullable=False),
sa.Column("name", sa.Text(), nullable=False, unique=True),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=False),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
sa.Column("created_at", sa.BigInteger(), nullable=False),
'skill',
sa.Column('id', sa.String(), nullable=False, primary_key=True),
sa.Column('user_id', sa.String(), nullable=False),
sa.Column('name', sa.Text(), nullable=False, unique=True),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('content', sa.Text(), nullable=False),
sa.Column('meta', sa.JSON(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('updated_at', sa.BigInteger(), nullable=False),
sa.Column('created_at', sa.BigInteger(), nullable=False),
)
op.create_index("idx_skill_user_id", "skill", ["user_id"])
op.create_index("idx_skill_updated_at", "skill", ["updated_at"])
op.create_index('idx_skill_user_id', 'skill', ['user_id'])
op.create_index('idx_skill_updated_at', 'skill', ['updated_at'])
def downgrade() -> None:
op.drop_index("idx_skill_updated_at", table_name="skill")
op.drop_index("idx_skill_user_id", table_name="skill")
op.drop_table("skill")
op.drop_index('idx_skill_updated_at', table_name='skill')
op.drop_index('idx_skill_user_id', table_name='skill')
op.drop_table('skill')

View File

@@ -12,8 +12,8 @@ from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "a5c220713937"
down_revision: Union[str, None] = "38d63c18f30f"
revision: str = 'a5c220713937'
down_revision: Union[str, None] = '38d63c18f30f'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -21,14 +21,14 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add 'reply_to_id' column to the 'message' table for replying to messages
op.add_column(
"message",
sa.Column("reply_to_id", sa.Text(), nullable=True),
'message',
sa.Column('reply_to_id', sa.Text(), nullable=True),
)
pass
def downgrade() -> None:
# Remove 'reply_to_id' column from the 'message' table
op.drop_column("message", "reply_to_id")
op.drop_column('message', 'reply_to_id')
pass

View File

@@ -10,8 +10,8 @@ from alembic import op
import sqlalchemy as sa
# Revision identifiers, used by Alembic.
revision = "af906e964978"
down_revision = "c29facfe716b"
revision = 'af906e964978'
down_revision = 'c29facfe716b'
branch_labels = None
depends_on = None
@@ -19,33 +19,23 @@ depends_on = None
def upgrade():
# ### Create feedback table ###
op.create_table(
"feedback",
'feedback',
sa.Column('id', sa.Text(), primary_key=True), # Unique identifier for each feedback (TEXT type)
sa.Column('user_id', sa.Text(), nullable=True), # ID of the user providing the feedback (TEXT type)
sa.Column('version', sa.BigInteger(), default=0), # Version of feedback (BIGINT type)
sa.Column('type', sa.Text(), nullable=True), # Type of feedback (TEXT type)
sa.Column('data', sa.JSON(), nullable=True), # Feedback data (JSON type)
sa.Column('meta', sa.JSON(), nullable=True), # Metadata for feedback (JSON type)
sa.Column('snapshot', sa.JSON(), nullable=True), # snapshot data for feedback (JSON type)
sa.Column(
"id", sa.Text(), primary_key=True
), # Unique identifier for each feedback (TEXT type)
sa.Column(
"user_id", sa.Text(), nullable=True
), # ID of the user providing the feedback (TEXT type)
sa.Column(
"version", sa.BigInteger(), default=0
), # Version of feedback (BIGINT type)
sa.Column("type", sa.Text(), nullable=True), # Type of feedback (TEXT type)
sa.Column("data", sa.JSON(), nullable=True), # Feedback data (JSON type)
sa.Column(
"meta", sa.JSON(), nullable=True
), # Metadata for feedback (JSON type)
sa.Column(
"snapshot", sa.JSON(), nullable=True
), # snapshot data for feedback (JSON type)
sa.Column(
"created_at", sa.BigInteger(), nullable=False
'created_at', sa.BigInteger(), nullable=False
), # Feedback creation timestamp (BIGINT representing epoch)
sa.Column(
"updated_at", sa.BigInteger(), nullable=False
'updated_at', sa.BigInteger(), nullable=False
), # Feedback update timestamp (BIGINT representing epoch)
)
def downgrade():
# ### Drop feedback table ###
op.drop_table("feedback")
op.drop_table('feedback')

View File

@@ -17,8 +17,8 @@ import json
import time
# revision identifiers, used by Alembic.
revision: str = "b10670c03dd5"
down_revision: Union[str, None] = "2f1211949ecc"
revision: str = 'b10670c03dd5'
down_revision: Union[str, None] = '2f1211949ecc'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -33,13 +33,11 @@ def _drop_sqlite_indexes_for_column(table_name, column_name, conn):
for idx in indexes:
index_name = idx[1] # index name
# Get indexed columns
idx_info = conn.execute(
sa.text(f"PRAGMA index_info('{index_name}')")
).fetchall()
idx_info = conn.execute(sa.text(f"PRAGMA index_info('{index_name}')")).fetchall()
indexed_cols = [row[2] for row in idx_info] # col names
if column_name in indexed_cols:
conn.execute(sa.text(f"DROP INDEX IF EXISTS {index_name}"))
conn.execute(sa.text(f'DROP INDEX IF EXISTS {index_name}'))
def _convert_column_to_json(table: str, column: str):
@@ -47,9 +45,9 @@ def _convert_column_to_json(table: str, column: str):
dialect = conn.dialect.name
# SQLite cannot ALTER COLUMN → must recreate column
if dialect == "sqlite":
if dialect == 'sqlite':
# 1. Add temporary column
op.add_column(table, sa.Column(f"{column}_json", sa.JSON(), nullable=True))
op.add_column(table, sa.Column(f'{column}_json', sa.JSON(), nullable=True))
# 2. Load old data
rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()
@@ -66,14 +64,14 @@ def _convert_column_to_json(table: str, column: str):
conn.execute(
sa.text(f'UPDATE "{table}" SET {column}_json = :val WHERE id = :id'),
{"val": json.dumps(parsed) if parsed else None, "id": uid},
{'val': json.dumps(parsed) if parsed else None, 'id': uid},
)
# 3. Drop old TEXT column
op.drop_column(table, column)
# 4. Rename new JSON column → original name
op.alter_column(table, f"{column}_json", new_column_name=column)
op.alter_column(table, f'{column}_json', new_column_name=column)
else:
# PostgreSQL supports direct CAST
@@ -81,7 +79,7 @@ def _convert_column_to_json(table: str, column: str):
table,
column,
type_=sa.JSON(),
postgresql_using=f"{column}::json",
postgresql_using=f'{column}::json',
)
@@ -89,85 +87,77 @@ def _convert_column_to_text(table: str, column: str):
conn = op.get_bind()
dialect = conn.dialect.name
if dialect == "sqlite":
op.add_column(table, sa.Column(f"{column}_text", sa.Text(), nullable=True))
if dialect == 'sqlite':
op.add_column(table, sa.Column(f'{column}_text', sa.Text(), nullable=True))
rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()
for uid, raw in rows:
conn.execute(
sa.text(f'UPDATE "{table}" SET {column}_text = :val WHERE id = :id'),
{"val": json.dumps(raw) if raw else None, "id": uid},
{'val': json.dumps(raw) if raw else None, 'id': uid},
)
op.drop_column(table, column)
op.alter_column(table, f"{column}_text", new_column_name=column)
op.alter_column(table, f'{column}_text', new_column_name=column)
else:
op.alter_column(
table,
column,
type_=sa.Text(),
postgresql_using=f"to_json({column})::text",
postgresql_using=f'to_json({column})::text',
)
def upgrade() -> None:
op.add_column(
"user", sa.Column("profile_banner_image_url", sa.Text(), nullable=True)
)
op.add_column("user", sa.Column("timezone", sa.String(), nullable=True))
op.add_column('user', sa.Column('profile_banner_image_url', sa.Text(), nullable=True))
op.add_column('user', sa.Column('timezone', sa.String(), nullable=True))
op.add_column("user", sa.Column("presence_state", sa.String(), nullable=True))
op.add_column("user", sa.Column("status_emoji", sa.String(), nullable=True))
op.add_column("user", sa.Column("status_message", sa.Text(), nullable=True))
op.add_column(
"user", sa.Column("status_expires_at", sa.BigInteger(), nullable=True)
)
op.add_column('user', sa.Column('presence_state', sa.String(), nullable=True))
op.add_column('user', sa.Column('status_emoji', sa.String(), nullable=True))
op.add_column('user', sa.Column('status_message', sa.Text(), nullable=True))
op.add_column('user', sa.Column('status_expires_at', sa.BigInteger(), nullable=True))
op.add_column("user", sa.Column("oauth", sa.JSON(), nullable=True))
op.add_column('user', sa.Column('oauth', sa.JSON(), nullable=True))
# Convert info (TEXT/JSONField) → JSON
_convert_column_to_json("user", "info")
_convert_column_to_json('user', 'info')
# Convert settings (TEXT/JSONField) → JSON
_convert_column_to_json("user", "settings")
_convert_column_to_json('user', 'settings')
op.create_table(
"api_key",
sa.Column("id", sa.Text(), primary_key=True, unique=True),
sa.Column("user_id", sa.Text(), sa.ForeignKey("user.id", ondelete="CASCADE")),
sa.Column("key", sa.Text(), unique=True, nullable=False),
sa.Column("data", sa.JSON(), nullable=True),
sa.Column("expires_at", sa.BigInteger(), nullable=True),
sa.Column("last_used_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
'api_key',
sa.Column('id', sa.Text(), primary_key=True, unique=True),
sa.Column('user_id', sa.Text(), sa.ForeignKey('user.id', ondelete='CASCADE')),
sa.Column('key', sa.Text(), unique=True, nullable=False),
sa.Column('data', sa.JSON(), nullable=True),
sa.Column('expires_at', sa.BigInteger(), nullable=True),
sa.Column('last_used_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=False),
sa.Column('updated_at', sa.BigInteger(), nullable=False),
)
conn = op.get_bind()
users = conn.execute(
sa.text('SELECT id, oauth_sub FROM "user" WHERE oauth_sub IS NOT NULL')
).fetchall()
users = conn.execute(sa.text('SELECT id, oauth_sub FROM "user" WHERE oauth_sub IS NOT NULL')).fetchall()
for uid, oauth_sub in users:
if oauth_sub:
# Example formats supported:
# provider@sub
# plain sub (stored as {"oidc": {"sub": sub}})
if "@" in oauth_sub:
provider, sub = oauth_sub.split("@", 1)
if '@' in oauth_sub:
provider, sub = oauth_sub.split('@', 1)
else:
provider, sub = "oidc", oauth_sub
provider, sub = 'oidc', oauth_sub
oauth_json = json.dumps({provider: {"sub": sub}})
oauth_json = json.dumps({provider: {'sub': sub}})
conn.execute(
sa.text('UPDATE "user" SET oauth = :oauth WHERE id = :id'),
{"oauth": oauth_json, "id": uid},
{'oauth': oauth_json, 'id': uid},
)
users_with_keys = conn.execute(
sa.text('SELECT id, api_key FROM "user" WHERE api_key IS NOT NULL')
).fetchall()
users_with_keys = conn.execute(sa.text('SELECT id, api_key FROM "user" WHERE api_key IS NOT NULL')).fetchall()
now = int(time.time())
for uid, api_key in users_with_keys:
@@ -178,72 +168,70 @@ def upgrade() -> None:
VALUES (:id, :user_id, :key, :created_at, :updated_at)
"""),
{
"id": f"key_{uid}",
"user_id": uid,
"key": api_key,
"created_at": now,
"updated_at": now,
'id': f'key_{uid}',
'user_id': uid,
'key': api_key,
'created_at': now,
'updated_at': now,
},
)
if conn.dialect.name == "sqlite":
_drop_sqlite_indexes_for_column("user", "api_key", conn)
_drop_sqlite_indexes_for_column("user", "oauth_sub", conn)
if conn.dialect.name == 'sqlite':
_drop_sqlite_indexes_for_column('user', 'api_key', conn)
_drop_sqlite_indexes_for_column('user', 'oauth_sub', conn)
with op.batch_alter_table("user") as batch_op:
batch_op.drop_column("api_key")
batch_op.drop_column("oauth_sub")
with op.batch_alter_table('user') as batch_op:
batch_op.drop_column('api_key')
batch_op.drop_column('oauth_sub')
def downgrade() -> None:
# --- 1. Restore old oauth_sub column ---
op.add_column("user", sa.Column("oauth_sub", sa.Text(), nullable=True))
op.add_column('user', sa.Column('oauth_sub', sa.Text(), nullable=True))
conn = op.get_bind()
users = conn.execute(
sa.text('SELECT id, oauth FROM "user" WHERE oauth IS NOT NULL')
).fetchall()
users = conn.execute(sa.text('SELECT id, oauth FROM "user" WHERE oauth IS NOT NULL')).fetchall()
for uid, oauth in users:
try:
data = json.loads(oauth)
provider = list(data.keys())[0]
sub = data[provider].get("sub")
oauth_sub = f"{provider}@{sub}"
sub = data[provider].get('sub')
oauth_sub = f'{provider}@{sub}'
except Exception:
oauth_sub = None
conn.execute(
sa.text('UPDATE "user" SET oauth_sub = :oauth_sub WHERE id = :id'),
{"oauth_sub": oauth_sub, "id": uid},
{'oauth_sub': oauth_sub, 'id': uid},
)
op.drop_column("user", "oauth")
op.drop_column('user', 'oauth')
# --- 2. Restore api_key field ---
op.add_column("user", sa.Column("api_key", sa.String(), nullable=True))
op.add_column('user', sa.Column('api_key', sa.String(), nullable=True))
# Restore values from api_key
keys = conn.execute(sa.text("SELECT user_id, key FROM api_key")).fetchall()
keys = conn.execute(sa.text('SELECT user_id, key FROM api_key')).fetchall()
for uid, key in keys:
conn.execute(
sa.text('UPDATE "user" SET api_key = :key WHERE id = :id'),
{"key": key, "id": uid},
{'key': key, 'id': uid},
)
# Drop new table
op.drop_table("api_key")
op.drop_table('api_key')
with op.batch_alter_table("user") as batch_op:
batch_op.drop_column("profile_banner_image_url")
batch_op.drop_column("timezone")
with op.batch_alter_table('user') as batch_op:
batch_op.drop_column('profile_banner_image_url')
batch_op.drop_column('timezone')
batch_op.drop_column("presence_state")
batch_op.drop_column("status_emoji")
batch_op.drop_column("status_message")
batch_op.drop_column("status_expires_at")
batch_op.drop_column('presence_state')
batch_op.drop_column('status_emoji')
batch_op.drop_column('status_message')
batch_op.drop_column('status_expires_at')
# Convert info (JSON) → TEXT
_convert_column_to_text("user", "info")
_convert_column_to_text('user', 'info')
# Convert settings (JSON) → TEXT
_convert_column_to_text("user", "settings")
_convert_column_to_text('user', 'settings')

View File

@@ -12,15 +12,15 @@ from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "b2c3d4e5f6a7"
down_revision: Union[str, None] = "a1b2c3d4e5f6"
revision: str = 'b2c3d4e5f6a7'
down_revision: Union[str, None] = 'a1b2c3d4e5f6'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column("user", sa.Column("scim", sa.JSON(), nullable=True))
op.add_column('user', sa.Column('scim', sa.JSON(), nullable=True))
def downgrade() -> None:
op.drop_column("user", "scim")
op.drop_column('user', 'scim')

View File

@@ -12,21 +12,21 @@ import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "c0fbf31ca0db"
down_revision: Union[str, None] = "ca81bd47c050"
revision: str = 'c0fbf31ca0db'
down_revision: Union[str, None] = 'ca81bd47c050'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("file", sa.Column("hash", sa.Text(), nullable=True))
op.add_column("file", sa.Column("data", sa.JSON(), nullable=True))
op.add_column("file", sa.Column("updated_at", sa.BigInteger(), nullable=True))
op.add_column('file', sa.Column('hash', sa.Text(), nullable=True))
op.add_column('file', sa.Column('data', sa.JSON(), nullable=True))
op.add_column('file', sa.Column('updated_at', sa.BigInteger(), nullable=True))
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("file", "updated_at")
op.drop_column("file", "data")
op.drop_column("file", "hash")
op.drop_column('file', 'updated_at')
op.drop_column('file', 'data')
op.drop_column('file', 'hash')

View File

@@ -12,35 +12,33 @@ import json
from sqlalchemy.sql import table, column
from sqlalchemy import String, Text, JSON, and_
revision = "c29facfe716b"
down_revision = "c69f45358db4"
revision = 'c29facfe716b'
down_revision = 'c69f45358db4'
branch_labels = None
depends_on = None
def upgrade():
# 1. Add the `path` column to the "file" table.
op.add_column("file", sa.Column("path", sa.Text(), nullable=True))
op.add_column('file', sa.Column('path', sa.Text(), nullable=True))
# 2. Convert the `meta` column from Text/JSONField to `JSON()`
# Use Alembic's default batch_op for dialect compatibility.
with op.batch_alter_table("file", schema=None) as batch_op:
with op.batch_alter_table('file', schema=None) as batch_op:
batch_op.alter_column(
"meta",
'meta',
type_=sa.JSON(),
existing_type=sa.Text(),
existing_nullable=True,
nullable=True,
postgresql_using="meta::json",
postgresql_using='meta::json',
)
# 3. Migrate legacy data from `meta` JSONField
# Fetch and process `meta` data from the table, add values to the new `path` column as necessary.
# We will use SQLAlchemy core bindings to ensure safety across different databases.
file_table = table(
"file", column("id", String), column("meta", JSON), column("path", Text)
)
file_table = table('file', column('id', String), column('meta', JSON), column('path', Text))
# Create connection to the database
connection = op.get_bind()
@@ -55,24 +53,18 @@ def upgrade():
# Iterate over each row to extract and update the `path` from `meta` column
for row in results:
if "path" in row.meta:
if 'path' in row.meta:
# Extract the `path` field from the `meta` JSON
path = row.meta.get("path")
path = row.meta.get('path')
# Update the `file` table with the new `path` value
connection.execute(
file_table.update()
.where(file_table.c.id == row.id)
.values({"path": path})
)
connection.execute(file_table.update().where(file_table.c.id == row.id).values({'path': path}))
def downgrade():
# 1. Remove the `path` column
op.drop_column("file", "path")
op.drop_column('file', 'path')
# 2. Revert the `meta` column back to Text/JSONField
with op.batch_alter_table("file", schema=None) as batch_op:
batch_op.alter_column(
"meta", type_=sa.Text(), existing_type=sa.JSON(), existing_nullable=True
)
with op.batch_alter_table('file', schema=None) as batch_op:
batch_op.alter_column('meta', type_=sa.Text(), existing_type=sa.JSON(), existing_nullable=True)

View File

@@ -12,45 +12,43 @@ from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "c440947495f3"
down_revision: Union[str, None] = "81cc2ce44d79"
revision: str = 'c440947495f3'
down_revision: Union[str, None] = '81cc2ce44d79'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"chat_file",
sa.Column("id", sa.Text(), primary_key=True),
sa.Column("user_id", sa.Text(), nullable=False),
'chat_file',
sa.Column('id', sa.Text(), primary_key=True),
sa.Column('user_id', sa.Text(), nullable=False),
sa.Column(
"chat_id",
'chat_id',
sa.Text(),
sa.ForeignKey("chat.id", ondelete="CASCADE"),
sa.ForeignKey('chat.id', ondelete='CASCADE'),
nullable=False,
),
sa.Column(
"file_id",
'file_id',
sa.Text(),
sa.ForeignKey("file.id", ondelete="CASCADE"),
sa.ForeignKey('file.id', ondelete='CASCADE'),
nullable=False,
),
sa.Column("message_id", sa.Text(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False),
sa.Column('message_id', sa.Text(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=False),
sa.Column('updated_at', sa.BigInteger(), nullable=False),
# indexes
sa.Index("ix_chat_file_chat_id", "chat_id"),
sa.Index("ix_chat_file_file_id", "file_id"),
sa.Index("ix_chat_file_message_id", "message_id"),
sa.Index("ix_chat_file_user_id", "user_id"),
sa.Index('ix_chat_file_chat_id', 'chat_id'),
sa.Index('ix_chat_file_file_id', 'file_id'),
sa.Index('ix_chat_file_message_id', 'message_id'),
sa.Index('ix_chat_file_user_id', 'user_id'),
# unique constraints
sa.UniqueConstraint(
"chat_id", "file_id", name="uq_chat_file_chat_file"
), # prevent duplicate entries
sa.UniqueConstraint('chat_id', 'file_id', name='uq_chat_file_chat_file'), # prevent duplicate entries
)
pass
def downgrade() -> None:
op.drop_table("chat_file")
op.drop_table('chat_file')
pass

View File

@@ -9,42 +9,40 @@ Create Date: 2024-10-16 02:02:35.241684
from alembic import op
import sqlalchemy as sa
revision = "c69f45358db4"
down_revision = "3ab32c4b8f59"
revision = 'c69f45358db4'
down_revision = '3ab32c4b8f59'
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"folder",
sa.Column("id", sa.Text(), nullable=False),
sa.Column("parent_id", sa.Text(), nullable=True),
sa.Column("user_id", sa.Text(), nullable=False),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("items", sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True),
sa.Column("is_expanded", sa.Boolean(), default=False, nullable=False),
'folder',
sa.Column('id', sa.Text(), nullable=False),
sa.Column('parent_id', sa.Text(), nullable=True),
sa.Column('user_id', sa.Text(), nullable=False),
sa.Column('name', sa.Text(), nullable=False),
sa.Column('items', sa.JSON(), nullable=True),
sa.Column('meta', sa.JSON(), nullable=True),
sa.Column('is_expanded', sa.Boolean(), default=False, nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.func.now(), nullable=False),
sa.Column(
"created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False
),
sa.Column(
"updated_at",
'updated_at',
sa.DateTime(),
nullable=False,
server_default=sa.func.now(),
onupdate=sa.func.now(),
),
sa.PrimaryKeyConstraint("id", "user_id"),
sa.PrimaryKeyConstraint('id', 'user_id'),
)
op.add_column(
"chat",
sa.Column("folder_id", sa.Text(), nullable=True),
'chat',
sa.Column('folder_id', sa.Text(), nullable=True),
)
def downgrade():
op.drop_column("chat", "folder_id")
op.drop_column('chat', 'folder_id')
op.drop_table("folder")
op.drop_table('folder')

View File

@@ -12,23 +12,21 @@ import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "ca81bd47c050"
down_revision: Union[str, None] = "7e5b5dc7342b"
revision: str = 'ca81bd47c050'
down_revision: Union[str, None] = '7e5b5dc7342b'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade():
op.create_table(
"config",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("data", sa.JSON(), nullable=False),
sa.Column("version", sa.Integer, nullable=False),
'config',
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('data', sa.JSON(), nullable=False),
sa.Column('version', sa.Integer, nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.func.now()),
sa.Column(
"created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()
),
sa.Column(
"updated_at",
'updated_at',
sa.DateTime(),
nullable=True,
server_default=sa.func.now(),
@@ -38,4 +36,4 @@ def upgrade():
def downgrade():
op.drop_table("config")
op.drop_table('config')

View File

@@ -9,15 +9,15 @@ Create Date: 2025-07-13 03:00:00.000000
from alembic import op
import sqlalchemy as sa
revision = "d31026856c01"
down_revision = "9f0c9cd09105"
revision = 'd31026856c01'
down_revision = '9f0c9cd09105'
branch_labels = None
depends_on = None
def upgrade():
op.add_column("folder", sa.Column("data", sa.JSON(), nullable=True))
op.add_column('folder', sa.Column('data', sa.JSON(), nullable=True))
def downgrade():
op.drop_column("folder", "data")
op.drop_column('folder', 'data')

View File

@@ -20,8 +20,8 @@ import sqlalchemy as sa
from open_webui.migrations.util import get_existing_tables
revision: str = "f1e2d3c4b5a6"
down_revision: Union[str, None] = "8452d01d26d7"
revision: str = 'f1e2d3c4b5a6'
down_revision: Union[str, None] = '8452d01d26d7'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -30,34 +30,34 @@ def upgrade() -> None:
existing_tables = set(get_existing_tables())
# Create access_grant table
if "access_grant" not in existing_tables:
if 'access_grant' not in existing_tables:
op.create_table(
"access_grant",
sa.Column("id", sa.Text(), nullable=False, primary_key=True),
sa.Column("resource_type", sa.Text(), nullable=False),
sa.Column("resource_id", sa.Text(), nullable=False),
sa.Column("principal_type", sa.Text(), nullable=False),
sa.Column("principal_id", sa.Text(), nullable=False),
sa.Column("permission", sa.Text(), nullable=False),
sa.Column("created_at", sa.BigInteger(), nullable=False),
'access_grant',
sa.Column('id', sa.Text(), nullable=False, primary_key=True),
sa.Column('resource_type', sa.Text(), nullable=False),
sa.Column('resource_id', sa.Text(), nullable=False),
sa.Column('principal_type', sa.Text(), nullable=False),
sa.Column('principal_id', sa.Text(), nullable=False),
sa.Column('permission', sa.Text(), nullable=False),
sa.Column('created_at', sa.BigInteger(), nullable=False),
sa.UniqueConstraint(
"resource_type",
"resource_id",
"principal_type",
"principal_id",
"permission",
name="uq_access_grant_grant",
'resource_type',
'resource_id',
'principal_type',
'principal_id',
'permission',
name='uq_access_grant_grant',
),
)
op.create_index(
"idx_access_grant_resource",
"access_grant",
["resource_type", "resource_id"],
'idx_access_grant_resource',
'access_grant',
['resource_type', 'resource_id'],
)
op.create_index(
"idx_access_grant_principal",
"access_grant",
["principal_type", "principal_id"],
'idx_access_grant_principal',
'access_grant',
['principal_type', 'principal_id'],
)
# Backfill existing access_control JSON data
@@ -65,13 +65,13 @@ def upgrade() -> None:
# Tables with access_control JSON columns: (table_name, resource_type)
resource_tables = [
("knowledge", "knowledge"),
("prompt", "prompt"),
("tool", "tool"),
("model", "model"),
("note", "note"),
("channel", "channel"),
("file", "file"),
('knowledge', 'knowledge'),
('prompt', 'prompt'),
('tool', 'tool'),
('model', 'model'),
('note', 'note'),
('channel', 'channel'),
('file', 'file'),
]
now = int(time.time())
@@ -83,9 +83,7 @@ def upgrade() -> None:
# Query all rows
try:
result = conn.execute(
sa.text(f'SELECT id, access_control FROM "{table_name}"')
)
result = conn.execute(sa.text(f'SELECT id, access_control FROM "{table_name}"'))
rows = result.fetchall()
except Exception:
continue
@@ -99,19 +97,16 @@ def upgrade() -> None:
# EXCEPTION: files with NULL are PRIVATE (owner-only), not public
is_null = (
access_control_json is None
or access_control_json == "null"
or (
isinstance(access_control_json, str)
and access_control_json.strip().lower() == "null"
)
or access_control_json == 'null'
or (isinstance(access_control_json, str) and access_control_json.strip().lower() == 'null')
)
if is_null:
# Files: NULL = private (no entry needed, owner has implicit access)
# Other resources: NULL = public (insert user:* for read)
if resource_type == "file":
if resource_type == 'file':
continue # Private - no entry needed
key = (resource_type, resource_id, "user", "*", "read")
key = (resource_type, resource_id, 'user', '*', 'read')
if key not in inserted:
try:
conn.execute(
@@ -120,13 +115,13 @@ def upgrade() -> None:
VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at)
"""),
{
"id": str(uuid.uuid4()),
"resource_type": resource_type,
"resource_id": resource_id,
"principal_type": "user",
"principal_id": "*",
"permission": "read",
"created_at": now,
'id': str(uuid.uuid4()),
'resource_type': resource_type,
'resource_id': resource_id,
'principal_type': 'user',
'principal_id': '*',
'permission': 'read',
'created_at': now,
},
)
inserted.add(key)
@@ -149,28 +144,24 @@ def upgrade() -> None:
continue
# Check if it's effectively empty (no read/write keys with content)
read_data = access_control_json.get("read", {})
write_data = access_control_json.get("write", {})
read_data = access_control_json.get('read', {})
write_data = access_control_json.get('write', {})
has_read_grants = read_data.get("group_ids", []) or read_data.get(
"user_ids", []
)
has_write_grants = write_data.get("group_ids", []) or write_data.get(
"user_ids", []
)
has_read_grants = read_data.get('group_ids', []) or read_data.get('user_ids', [])
has_write_grants = write_data.get('group_ids', []) or write_data.get('user_ids', [])
if not has_read_grants and not has_write_grants:
# Empty permissions = private, no grants needed
continue
# Extract permissions and insert into access_grant table
for permission in ["read", "write"]:
for permission in ['read', 'write']:
perm_data = access_control_json.get(permission, {})
if not perm_data:
continue
for group_id in perm_data.get("group_ids", []):
key = (resource_type, resource_id, "group", group_id, permission)
for group_id in perm_data.get('group_ids', []):
key = (resource_type, resource_id, 'group', group_id, permission)
if key in inserted:
continue
try:
@@ -180,21 +171,21 @@ def upgrade() -> None:
VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at)
"""),
{
"id": str(uuid.uuid4()),
"resource_type": resource_type,
"resource_id": resource_id,
"principal_type": "group",
"principal_id": group_id,
"permission": permission,
"created_at": now,
'id': str(uuid.uuid4()),
'resource_type': resource_type,
'resource_id': resource_id,
'principal_type': 'group',
'principal_id': group_id,
'permission': permission,
'created_at': now,
},
)
inserted.add(key)
except Exception:
pass
for user_id in perm_data.get("user_ids", []):
key = (resource_type, resource_id, "user", user_id, permission)
for user_id in perm_data.get('user_ids', []):
key = (resource_type, resource_id, 'user', user_id, permission)
if key in inserted:
continue
try:
@@ -204,13 +195,13 @@ def upgrade() -> None:
VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at)
"""),
{
"id": str(uuid.uuid4()),
"resource_type": resource_type,
"resource_id": resource_id,
"principal_type": "user",
"principal_id": user_id,
"permission": permission,
"created_at": now,
'id': str(uuid.uuid4()),
'resource_type': resource_type,
'resource_id': resource_id,
'principal_type': 'user',
'principal_id': user_id,
'permission': permission,
'created_at': now,
},
)
inserted.add(key)
@@ -223,7 +214,7 @@ def upgrade() -> None:
continue
try:
with op.batch_alter_table(table_name) as batch:
batch.drop_column("access_control")
batch.drop_column('access_control')
except Exception:
pass
@@ -235,20 +226,20 @@ def downgrade() -> None:
# Resource tables mapping: (table_name, resource_type)
resource_tables = [
("knowledge", "knowledge"),
("prompt", "prompt"),
("tool", "tool"),
("model", "model"),
("note", "note"),
("channel", "channel"),
("file", "file"),
('knowledge', 'knowledge'),
('prompt', 'prompt'),
('tool', 'tool'),
('model', 'model'),
('note', 'note'),
('channel', 'channel'),
('file', 'file'),
]
# Step 1: Re-add access_control columns to resource tables
for table_name, _ in resource_tables:
try:
with op.batch_alter_table(table_name) as batch:
batch.add_column(sa.Column("access_control", sa.JSON(), nullable=True))
batch.add_column(sa.Column('access_control', sa.JSON(), nullable=True))
except Exception:
pass
@@ -262,7 +253,7 @@ def downgrade() -> None:
FROM access_grant
WHERE resource_type = :resource_type
"""),
{"resource_type": resource_type},
{'resource_type': resource_type},
)
rows = result.fetchall()
except Exception:
@@ -278,49 +269,35 @@ def downgrade() -> None:
if resource_id not in resource_grants:
resource_grants[resource_id] = {
"is_public": False,
"read": {"group_ids": [], "user_ids": []},
"write": {"group_ids": [], "user_ids": []},
'is_public': False,
'read': {'group_ids': [], 'user_ids': []},
'write': {'group_ids': [], 'user_ids': []},
}
# Handle public access (user:* for read)
if (
principal_type == "user"
and principal_id == "*"
and permission == "read"
):
resource_grants[resource_id]["is_public"] = True
if principal_type == 'user' and principal_id == '*' and permission == 'read':
resource_grants[resource_id]['is_public'] = True
continue
# Add to appropriate list
if permission in ["read", "write"]:
if principal_type == "group":
if (
principal_id
not in resource_grants[resource_id][permission]["group_ids"]
):
resource_grants[resource_id][permission]["group_ids"].append(
principal_id
)
elif principal_type == "user":
if (
principal_id
not in resource_grants[resource_id][permission]["user_ids"]
):
resource_grants[resource_id][permission]["user_ids"].append(
principal_id
)
if permission in ['read', 'write']:
if principal_type == 'group':
if principal_id not in resource_grants[resource_id][permission]['group_ids']:
resource_grants[resource_id][permission]['group_ids'].append(principal_id)
elif principal_type == 'user':
if principal_id not in resource_grants[resource_id][permission]['user_ids']:
resource_grants[resource_id][permission]['user_ids'].append(principal_id)
# Step 3: Update each resource with reconstructed JSON
for resource_id, grants in resource_grants.items():
if grants["is_public"]:
if grants['is_public']:
# Public = NULL
access_control_value = None
elif (
not grants["read"]["group_ids"]
and not grants["read"]["user_ids"]
and not grants["write"]["group_ids"]
and not grants["write"]["user_ids"]
not grants['read']['group_ids']
and not grants['read']['user_ids']
and not grants['write']['group_ids']
and not grants['write']['user_ids']
):
# No grants = should not happen (would mean no entries), default to {}
access_control_value = json.dumps({})
@@ -328,17 +305,15 @@ def downgrade() -> None:
# Custom permissions
access_control_value = json.dumps(
{
"read": grants["read"],
"write": grants["write"],
'read': grants['read'],
'write': grants['write'],
}
)
try:
conn.execute(
sa.text(
f'UPDATE "{table_name}" SET access_control = :access_control WHERE id = :id'
),
{"access_control": access_control_value, "id": resource_id},
sa.text(f'UPDATE "{table_name}" SET access_control = :access_control WHERE id = :id'),
{'access_control': access_control_value, 'id': resource_id},
)
except Exception:
pass
@@ -346,7 +321,7 @@ def downgrade() -> None:
# Step 4: Set all resources WITHOUT entries to private
# For files: NULL means private (owner-only), so leave as NULL
# For other resources: {} means private, so update to {}
if resource_type != "file":
if resource_type != 'file':
try:
conn.execute(
sa.text(f"""
@@ -357,13 +332,13 @@ def downgrade() -> None:
)
AND access_control IS NULL
"""),
{"private_value": json.dumps({}), "resource_type": resource_type},
{'private_value': json.dumps({}), 'resource_type': resource_type},
)
except Exception:
pass
# For files, NULL stays NULL - no action needed
# Step 5: Drop the access_grant table
op.drop_index("idx_access_grant_principal", table_name="access_grant")
op.drop_index("idx_access_grant_resource", table_name="access_grant")
op.drop_table("access_grant")
op.drop_index('idx_access_grant_principal', table_name='access_grant')
op.drop_index('idx_access_grant_resource', table_name='access_grant')
op.drop_table('access_grant')

View File

@@ -19,28 +19,24 @@ log = logging.getLogger(__name__)
class AccessGrant(Base):
__tablename__ = "access_grant"
__tablename__ = 'access_grant'
id = Column(Text, primary_key=True)
resource_type = Column(
Text, nullable=False
) # "knowledge", "model", "prompt", "tool", "note", "channel", "file"
resource_type = Column(Text, nullable=False) # "knowledge", "model", "prompt", "tool", "note", "channel", "file"
resource_id = Column(Text, nullable=False)
principal_type = Column(Text, nullable=False) # "user" or "group"
principal_id = Column(
Text, nullable=False
) # user_id, group_id, or "*" (wildcard for public)
principal_id = Column(Text, nullable=False) # user_id, group_id, or "*" (wildcard for public)
permission = Column(Text, nullable=False) # "read" or "write"
created_at = Column(BigInteger, nullable=False)
__table_args__ = (
UniqueConstraint(
"resource_type",
"resource_id",
"principal_type",
"principal_id",
"permission",
name="uq_access_grant_grant",
'resource_type',
'resource_id',
'principal_type',
'principal_id',
'permission',
name='uq_access_grant_grant',
),
)
@@ -66,7 +62,7 @@ class AccessGrantResponse(BaseModel):
permission: str
@classmethod
def from_grant(cls, grant: "AccessGrantModel") -> "AccessGrantResponse":
def from_grant(cls, grant: 'AccessGrantModel') -> 'AccessGrantResponse':
return cls(
id=grant.id,
principal_type=grant.principal_type,
@@ -100,14 +96,14 @@ def access_control_to_grants(
if access_control is None:
# NULL → public read (user:* for read)
# Exception: files with NULL are private (owner-only), no grants needed
if resource_type != "file":
if resource_type != 'file':
grants.append(
{
"resource_type": resource_type,
"resource_id": resource_id,
"principal_type": "user",
"principal_id": "*",
"permission": "read",
'resource_type': resource_type,
'resource_id': resource_id,
'principal_type': 'user',
'principal_id': '*',
'permission': 'read',
}
)
return grants
@@ -117,30 +113,30 @@ def access_control_to_grants(
return grants
# Parse structured permissions
for permission in ["read", "write"]:
for permission in ['read', 'write']:
perm_data = access_control.get(permission, {})
if not perm_data:
continue
for group_id in perm_data.get("group_ids", []):
for group_id in perm_data.get('group_ids', []):
grants.append(
{
"resource_type": resource_type,
"resource_id": resource_id,
"principal_type": "group",
"principal_id": group_id,
"permission": permission,
'resource_type': resource_type,
'resource_id': resource_id,
'principal_type': 'group',
'principal_id': group_id,
'permission': permission,
}
)
for user_id in perm_data.get("user_ids", []):
for user_id in perm_data.get('user_ids', []):
grants.append(
{
"resource_type": resource_type,
"resource_id": resource_id,
"principal_type": "user",
"principal_id": user_id,
"permission": permission,
'resource_type': resource_type,
'resource_id': resource_id,
'principal_type': 'user',
'principal_id': user_id,
'permission': permission,
}
)
@@ -164,27 +160,23 @@ def normalize_access_grants(access_grants: Optional[list]) -> list[dict]:
if not isinstance(grant, dict):
continue
principal_type = grant.get("principal_type")
principal_id = grant.get("principal_id")
permission = grant.get("permission")
principal_type = grant.get('principal_type')
principal_id = grant.get('principal_id')
permission = grant.get('permission')
if principal_type not in ("user", "group"):
if principal_type not in ('user', 'group'):
continue
if permission not in ("read", "write"):
if permission not in ('read', 'write'):
continue
if not isinstance(principal_id, str) or not principal_id:
continue
key = (principal_type, principal_id, permission)
deduped[key] = {
"id": (
grant.get("id")
if isinstance(grant.get("id"), str) and grant.get("id")
else str(uuid.uuid4())
),
"principal_type": principal_type,
"principal_id": principal_id,
"permission": permission,
'id': (grant.get('id') if isinstance(grant.get('id'), str) and grant.get('id') else str(uuid.uuid4())),
'principal_type': principal_type,
'principal_id': principal_id,
'permission': permission,
}
return list(deduped.values())
@@ -195,11 +187,7 @@ def has_public_read_access_grant(access_grants: Optional[list]) -> bool:
Returns True when a direct grant list includes wildcard public-read.
"""
for grant in normalize_access_grants(access_grants):
if (
grant["principal_type"] == "user"
and grant["principal_id"] == "*"
and grant["permission"] == "read"
):
if grant['principal_type'] == 'user' and grant['principal_id'] == '*' and grant['permission'] == 'read':
return True
return False
@@ -209,7 +197,7 @@ def has_user_access_grant(access_grants: Optional[list]) -> bool:
Returns True when a direct grant list includes any non-wildcard user grant.
"""
for grant in normalize_access_grants(access_grants):
if grant["principal_type"] == "user" and grant["principal_id"] != "*":
if grant['principal_type'] == 'user' and grant['principal_id'] != '*':
return True
return False
@@ -225,18 +213,9 @@ def strip_user_access_grants(access_grants: Optional[list]) -> list:
grant
for grant in access_grants
if not (
(
grant.get("principal_type")
if isinstance(grant, dict)
else getattr(grant, "principal_type", None)
)
== "user"
and (
grant.get("principal_id")
if isinstance(grant, dict)
else getattr(grant, "principal_id", None)
)
!= "*"
(grant.get('principal_type') if isinstance(grant, dict) else getattr(grant, 'principal_type', None))
== 'user'
and (grant.get('principal_id') if isinstance(grant, dict) else getattr(grant, 'principal_id', None)) != '*'
)
]
@@ -260,29 +239,25 @@ def grants_to_access_control(grants: list) -> Optional[dict]:
return {} # No grants = private/owner-only
result = {
"read": {"group_ids": [], "user_ids": []},
"write": {"group_ids": [], "user_ids": []},
'read': {'group_ids': [], 'user_ids': []},
'write': {'group_ids': [], 'user_ids': []},
}
is_public = False
for grant in grants:
if (
grant.principal_type == "user"
and grant.principal_id == "*"
and grant.permission == "read"
):
if grant.principal_type == 'user' and grant.principal_id == '*' and grant.permission == 'read':
is_public = True
continue # Don't add wildcard to user_ids list
if grant.permission not in ("read", "write"):
if grant.permission not in ('read', 'write'):
continue
if grant.principal_type == "group":
if grant.principal_id not in result[grant.permission]["group_ids"]:
result[grant.permission]["group_ids"].append(grant.principal_id)
elif grant.principal_type == "user":
if grant.principal_id not in result[grant.permission]["user_ids"]:
result[grant.permission]["user_ids"].append(grant.principal_id)
if grant.principal_type == 'group':
if grant.principal_id not in result[grant.permission]['group_ids']:
result[grant.permission]['group_ids'].append(grant.principal_id)
elif grant.principal_type == 'user':
if grant.principal_id not in result[grant.permission]['user_ids']:
result[grant.permission]['user_ids'].append(grant.principal_id)
if is_public:
return None # Public read access
@@ -399,9 +374,7 @@ class AccessGrantsTable:
).delete()
# Convert JSON to grant dicts
grant_dicts = access_control_to_grants(
resource_type, resource_id, access_control
)
grant_dicts = access_control_to_grants(resource_type, resource_id, access_control)
# Insert new grants
results = []
@@ -442,9 +415,9 @@ class AccessGrantsTable:
id=str(uuid.uuid4()),
resource_type=resource_type,
resource_id=resource_id,
principal_type=grant_dict["principal_type"],
principal_id=grant_dict["principal_id"],
permission=grant_dict["permission"],
principal_type=grant_dict['principal_type'],
principal_id=grant_dict['principal_id'],
permission=grant_dict['permission'],
created_at=int(time.time()),
)
db.add(grant)
@@ -511,9 +484,7 @@ class AccessGrantsTable:
)
.all()
)
result: dict[str, list[AccessGrantModel]] = {
rid: [] for rid in resource_ids
}
result: dict[str, list[AccessGrantModel]] = {rid: [] for rid in resource_ids}
for g in grants:
result[g.resource_id].append(AccessGrantModel.model_validate(g))
return result
@@ -523,7 +494,7 @@ class AccessGrantsTable:
user_id: str,
resource_type: str,
resource_id: str,
permission: str = "read",
permission: str = 'read',
user_group_ids: Optional[set[str]] = None,
db: Optional[Session] = None,
) -> bool:
@@ -540,12 +511,12 @@ class AccessGrantsTable:
conditions = [
# Public access
and_(
AccessGrant.principal_type == "user",
AccessGrant.principal_id == "*",
AccessGrant.principal_type == 'user',
AccessGrant.principal_id == '*',
),
# Direct user access
and_(
AccessGrant.principal_type == "user",
AccessGrant.principal_type == 'user',
AccessGrant.principal_id == user_id,
),
]
@@ -560,7 +531,7 @@ class AccessGrantsTable:
if user_group_ids:
conditions.append(
and_(
AccessGrant.principal_type == "group",
AccessGrant.principal_type == 'group',
AccessGrant.principal_id.in_(user_group_ids),
)
)
@@ -582,7 +553,7 @@ class AccessGrantsTable:
user_id: str,
resource_type: str,
resource_ids: list[str],
permission: str = "read",
permission: str = 'read',
user_group_ids: Optional[set[str]] = None,
db: Optional[Session] = None,
) -> set[str]:
@@ -597,11 +568,11 @@ class AccessGrantsTable:
with get_db_context(db) as db:
conditions = [
and_(
AccessGrant.principal_type == "user",
AccessGrant.principal_id == "*",
AccessGrant.principal_type == 'user',
AccessGrant.principal_id == '*',
),
and_(
AccessGrant.principal_type == "user",
AccessGrant.principal_type == 'user',
AccessGrant.principal_id == user_id,
),
]
@@ -615,7 +586,7 @@ class AccessGrantsTable:
if user_group_ids:
conditions.append(
and_(
AccessGrant.principal_type == "group",
AccessGrant.principal_type == 'group',
AccessGrant.principal_id.in_(user_group_ids),
)
)
@@ -637,7 +608,7 @@ class AccessGrantsTable:
self,
resource_type: str,
resource_id: str,
permission: str = "read",
permission: str = 'read',
db: Optional[Session] = None,
) -> list:
"""
@@ -660,19 +631,17 @@ class AccessGrantsTable:
# Check for public access
for grant in grants:
if grant.principal_type == "user" and grant.principal_id == "*":
result = Users.get_users(filter={"roles": ["!pending"]}, db=db)
return result.get("users", [])
if grant.principal_type == 'user' and grant.principal_id == '*':
result = Users.get_users(filter={'roles': ['!pending']}, db=db)
return result.get('users', [])
user_ids_with_access = set()
for grant in grants:
if grant.principal_type == "user":
if grant.principal_type == 'user':
user_ids_with_access.add(grant.principal_id)
elif grant.principal_type == "group":
group_user_ids = Groups.get_group_user_ids_by_id(
grant.principal_id, db=db
)
elif grant.principal_type == 'group':
group_user_ids = Groups.get_group_user_ids_by_id(grant.principal_id, db=db)
if group_user_ids:
user_ids_with_access.update(group_user_ids)
@@ -688,20 +657,18 @@ class AccessGrantsTable:
DocumentModel,
filter: dict,
resource_type: str,
permission: str = "read",
permission: str = 'read',
):
"""
Apply access control filtering to a SQLAlchemy query by JOINing with access_grant.
This replaces the old JSON-column-based filtering with a proper relational JOIN.
"""
group_ids = filter.get("group_ids", [])
user_id = filter.get("user_id")
group_ids = filter.get('group_ids', [])
user_id = filter.get('user_id')
if permission == "read_only":
return self._has_read_only_permission_filter(
db, query, DocumentModel, filter, resource_type
)
if permission == 'read_only':
return self._has_read_only_permission_filter(db, query, DocumentModel, filter, resource_type)
# Build principal conditions
principal_conditions = []
@@ -710,8 +677,8 @@ class AccessGrantsTable:
# Public access: user:* read
principal_conditions.append(
and_(
AccessGrant.principal_type == "user",
AccessGrant.principal_id == "*",
AccessGrant.principal_type == 'user',
AccessGrant.principal_id == '*',
)
)
@@ -722,7 +689,7 @@ class AccessGrantsTable:
# Direct user grant
principal_conditions.append(
and_(
AccessGrant.principal_type == "user",
AccessGrant.principal_type == 'user',
AccessGrant.principal_id == user_id,
)
)
@@ -731,7 +698,7 @@ class AccessGrantsTable:
# Group grants
principal_conditions.append(
and_(
AccessGrant.principal_type == "group",
AccessGrant.principal_type == 'group',
AccessGrant.principal_id.in_(group_ids),
)
)
@@ -751,13 +718,13 @@ class AccessGrantsTable:
AccessGrant.permission == permission,
or_(
and_(
AccessGrant.principal_type == "user",
AccessGrant.principal_id == "*",
AccessGrant.principal_type == 'user',
AccessGrant.principal_id == '*',
),
*(
[
and_(
AccessGrant.principal_type == "user",
AccessGrant.principal_type == 'user',
AccessGrant.principal_id == user_id,
)
]
@@ -767,7 +734,7 @@ class AccessGrantsTable:
*(
[
and_(
AccessGrant.principal_type == "group",
AccessGrant.principal_type == 'group',
AccessGrant.principal_id.in_(group_ids),
)
]
@@ -800,8 +767,8 @@ class AccessGrantsTable:
Filter for items where user has read BUT NOT write access.
Public items are NOT considered read_only.
"""
group_ids = filter.get("group_ids", [])
user_id = filter.get("user_id")
group_ids = filter.get('group_ids', [])
user_id = filter.get('user_id')
from sqlalchemy import exists as sa_exists, select
@@ -811,12 +778,12 @@ class AccessGrantsTable:
.where(
AccessGrant.resource_type == resource_type,
AccessGrant.resource_id == DocumentModel.id,
AccessGrant.permission == "read",
AccessGrant.permission == 'read',
or_(
*(
[
and_(
AccessGrant.principal_type == "user",
AccessGrant.principal_type == 'user',
AccessGrant.principal_id == user_id,
)
]
@@ -826,7 +793,7 @@ class AccessGrantsTable:
*(
[
and_(
AccessGrant.principal_type == "group",
AccessGrant.principal_type == 'group',
AccessGrant.principal_id.in_(group_ids),
)
]
@@ -845,12 +812,12 @@ class AccessGrantsTable:
.where(
AccessGrant.resource_type == resource_type,
AccessGrant.resource_id == DocumentModel.id,
AccessGrant.permission == "write",
AccessGrant.permission == 'write',
or_(
*(
[
and_(
AccessGrant.principal_type == "user",
AccessGrant.principal_type == 'user',
AccessGrant.principal_id == user_id,
)
]
@@ -860,7 +827,7 @@ class AccessGrantsTable:
*(
[
and_(
AccessGrant.principal_type == "group",
AccessGrant.principal_type == 'group',
AccessGrant.principal_id.in_(group_ids),
)
]
@@ -879,9 +846,9 @@ class AccessGrantsTable:
.where(
AccessGrant.resource_type == resource_type,
AccessGrant.resource_id == DocumentModel.id,
AccessGrant.permission == "read",
AccessGrant.principal_type == "user",
AccessGrant.principal_id == "*",
AccessGrant.permission == 'read',
AccessGrant.principal_type == 'user',
AccessGrant.principal_id == '*',
)
.correlate(DocumentModel)
.exists()

View File

@@ -17,7 +17,7 @@ log = logging.getLogger(__name__)
class Auth(Base):
__tablename__ = "auth"
__tablename__ = 'auth'
id = Column(String, primary_key=True, unique=True)
email = Column(String)
@@ -73,9 +73,9 @@ class SignupForm(BaseModel):
name: str
email: str
password: str
profile_image_url: Optional[str] = "/user.png"
profile_image_url: Optional[str] = '/user.png'
@field_validator("profile_image_url")
@field_validator('profile_image_url')
@classmethod
def check_profile_image_url(cls, v: Optional[str]) -> Optional[str]:
if v is not None:
@@ -84,7 +84,7 @@ class SignupForm(BaseModel):
class AddUserForm(SignupForm):
role: Optional[str] = "pending"
role: Optional[str] = 'pending'
class AuthsTable:
@@ -93,25 +93,21 @@ class AuthsTable:
email: str,
password: str,
name: str,
profile_image_url: str = "/user.png",
role: str = "pending",
profile_image_url: str = '/user.png',
role: str = 'pending',
oauth: Optional[dict] = None,
db: Optional[Session] = None,
) -> Optional[UserModel]:
with get_db_context(db) as db:
log.info("insert_new_auth")
log.info('insert_new_auth')
id = str(uuid.uuid4())
auth = AuthModel(
**{"id": id, "email": email, "password": password, "active": True}
)
auth = AuthModel(**{'id': id, 'email': email, 'password': password, 'active': True})
result = Auth(**auth.model_dump())
db.add(result)
user = Users.insert_new_user(
id, name, email, profile_image_url, role, oauth=oauth, db=db
)
user = Users.insert_new_user(id, name, email, profile_image_url, role, oauth=oauth, db=db)
db.commit()
db.refresh(result)
@@ -124,7 +120,7 @@ class AuthsTable:
def authenticate_user(
self, email: str, verify_password: callable, db: Optional[Session] = None
) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}")
log.info(f'authenticate_user: {email}')
user = Users.get_user_by_email(email, db=db)
if not user:
@@ -143,10 +139,8 @@ class AuthsTable:
except Exception:
return None
def authenticate_user_by_api_key(
self, api_key: str, db: Optional[Session] = None
) -> Optional[UserModel]:
log.info(f"authenticate_user_by_api_key")
def authenticate_user_by_api_key(self, api_key: str, db: Optional[Session] = None) -> Optional[UserModel]:
log.info(f'authenticate_user_by_api_key')
# if no api_key, return None
if not api_key:
return None
@@ -157,10 +151,8 @@ class AuthsTable:
except Exception:
return False
def authenticate_user_by_email(
self, email: str, db: Optional[Session] = None
) -> Optional[UserModel]:
log.info(f"authenticate_user_by_email: {email}")
def authenticate_user_by_email(self, email: str, db: Optional[Session] = None) -> Optional[UserModel]:
log.info(f'authenticate_user_by_email: {email}')
try:
with get_db_context(db) as db:
# Single JOIN query instead of two separate queries
@@ -177,28 +169,22 @@ class AuthsTable:
except Exception:
return None
def update_user_password_by_id(
self, id: str, new_password: str, db: Optional[Session] = None
) -> bool:
def update_user_password_by_id(self, id: str, new_password: str, db: Optional[Session] = None) -> bool:
try:
with get_db_context(db) as db:
result = (
db.query(Auth).filter_by(id=id).update({"password": new_password})
)
result = db.query(Auth).filter_by(id=id).update({'password': new_password})
db.commit()
return True if result == 1 else False
except Exception:
return False
def update_email_by_id(
self, id: str, email: str, db: Optional[Session] = None
) -> bool:
def update_email_by_id(self, id: str, email: str, db: Optional[Session] = None) -> bool:
try:
with get_db_context(db) as db:
result = db.query(Auth).filter_by(id=id).update({"email": email})
result = db.query(Auth).filter_by(id=id).update({'email': email})
db.commit()
if result == 1:
Users.update_user_by_id(id, {"email": email}, db=db)
Users.update_user_by_id(id, {'email': email}, db=db)
return True
return False
except Exception:

View File

@@ -37,7 +37,7 @@ from sqlalchemy.sql import exists
class Channel(Base):
__tablename__ = "channel"
__tablename__ = 'channel'
id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text)
@@ -94,7 +94,7 @@ class ChannelModel(BaseModel):
class ChannelMember(Base):
__tablename__ = "channel_member"
__tablename__ = 'channel_member'
id = Column(Text, primary_key=True, unique=True)
channel_id = Column(Text, nullable=False)
@@ -154,25 +154,19 @@ class ChannelMemberModel(BaseModel):
class ChannelFile(Base):
__tablename__ = "channel_file"
__tablename__ = 'channel_file'
id = Column(Text, unique=True, primary_key=True)
user_id = Column(Text, nullable=False)
channel_id = Column(
Text, ForeignKey("channel.id", ondelete="CASCADE"), nullable=False
)
message_id = Column(
Text, ForeignKey("message.id", ondelete="CASCADE"), nullable=True
)
file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False)
channel_id = Column(Text, ForeignKey('channel.id', ondelete='CASCADE'), nullable=False)
message_id = Column(Text, ForeignKey('message.id', ondelete='CASCADE'), nullable=True)
file_id = Column(Text, ForeignKey('file.id', ondelete='CASCADE'), nullable=False)
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
__table_args__ = (
UniqueConstraint("channel_id", "file_id", name="uq_channel_file_channel_file"),
)
__table_args__ = (UniqueConstraint('channel_id', 'file_id', name='uq_channel_file_channel_file'),)
class ChannelFileModel(BaseModel):
@@ -189,7 +183,7 @@ class ChannelFileModel(BaseModel):
class ChannelWebhook(Base):
__tablename__ = "channel_webhook"
__tablename__ = 'channel_webhook'
id = Column(Text, primary_key=True, unique=True)
channel_id = Column(Text, nullable=False)
@@ -235,7 +229,7 @@ class ChannelResponse(ChannelModel):
class ChannelForm(BaseModel):
name: str = ""
name: str = ''
description: Optional[str] = None
is_private: Optional[bool] = None
data: Optional[dict] = None
@@ -255,10 +249,8 @@ class ChannelWebhookForm(BaseModel):
class ChannelTable:
def _get_access_grants(
self, channel_id: str, db: Optional[Session] = None
) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource("channel", channel_id, db=db)
def _get_access_grants(self, channel_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource('channel', channel_id, db=db)
def _to_channel_model(
self,
@@ -266,13 +258,9 @@ class ChannelTable:
access_grants: Optional[list[AccessGrantModel]] = None,
db: Optional[Session] = None,
) -> ChannelModel:
channel_data = ChannelModel.model_validate(channel).model_dump(
exclude={"access_grants"}
)
channel_data["access_grants"] = (
access_grants
if access_grants is not None
else self._get_access_grants(channel_data["id"], db=db)
channel_data = ChannelModel.model_validate(channel).model_dump(exclude={'access_grants'})
channel_data['access_grants'] = (
access_grants if access_grants is not None else self._get_access_grants(channel_data['id'], db=db)
)
return ChannelModel.model_validate(channel_data)
@@ -313,20 +301,20 @@ class ChannelTable:
for uid in user_ids:
model = ChannelMemberModel(
**{
"id": str(uuid.uuid4()),
"channel_id": channel_id,
"user_id": uid,
"status": "joined",
"is_active": True,
"is_channel_muted": False,
"is_channel_pinned": False,
"invited_at": now,
"invited_by": invited_by,
"joined_at": now,
"left_at": None,
"last_read_at": now,
"created_at": now,
"updated_at": now,
'id': str(uuid.uuid4()),
'channel_id': channel_id,
'user_id': uid,
'status': 'joined',
'is_active': True,
'is_channel_muted': False,
'is_channel_pinned': False,
'invited_at': now,
'invited_by': invited_by,
'joined_at': now,
'left_at': None,
'last_read_at': now,
'created_at': now,
'updated_at': now,
}
)
memberships.append(ChannelMember(**model.model_dump()))
@@ -339,19 +327,19 @@ class ChannelTable:
with get_db_context(db) as db:
channel = ChannelModel(
**{
**form_data.model_dump(exclude={"access_grants"}),
"type": form_data.type if form_data.type else None,
"name": form_data.name.lower(),
"id": str(uuid.uuid4()),
"user_id": user_id,
"created_at": int(time.time_ns()),
"updated_at": int(time.time_ns()),
"access_grants": [],
**form_data.model_dump(exclude={'access_grants'}),
'type': form_data.type if form_data.type else None,
'name': form_data.name.lower(),
'id': str(uuid.uuid4()),
'user_id': user_id,
'created_at': int(time.time_ns()),
'updated_at': int(time.time_ns()),
'access_grants': [],
}
)
new_channel = Channel(**channel.model_dump(exclude={"access_grants"}))
new_channel = Channel(**channel.model_dump(exclude={'access_grants'}))
if form_data.type in ["group", "dm"]:
if form_data.type in ['group', 'dm']:
users = self._collect_unique_user_ids(
invited_by=user_id,
user_ids=form_data.user_ids,
@@ -366,18 +354,14 @@ class ChannelTable:
db.add_all(memberships)
db.add(new_channel)
db.commit()
AccessGrants.set_access_grants(
"channel", new_channel.id, form_data.access_grants, db=db
)
AccessGrants.set_access_grants('channel', new_channel.id, form_data.access_grants, db=db)
return self._to_channel_model(new_channel, db=db)
def get_channels(self, db: Optional[Session] = None) -> list[ChannelModel]:
with get_db_context(db) as db:
channels = db.query(Channel).all()
channel_ids = [channel.id for channel in channels]
grants_map = AccessGrants.get_grants_by_resources(
"channel", channel_ids, db=db
)
grants_map = AccessGrants.get_grants_by_resources('channel', channel_ids, db=db)
return [
self._to_channel_model(
channel,
@@ -387,23 +371,19 @@ class ChannelTable:
for channel in channels
]
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
def _has_permission(self, db, query, filter: dict, permission: str = 'read'):
return AccessGrants.has_permission_filter(
db=db,
query=query,
DocumentModel=Channel,
filter=filter,
resource_type="channel",
resource_type='channel',
permission=permission,
)
def get_channels_by_user_id(
self, user_id: str, db: Optional[Session] = None
) -> list[ChannelModel]:
def get_channels_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
with get_db_context(db) as db:
user_group_ids = [
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
]
user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id, db=db)]
membership_channels = (
db.query(Channel)
@@ -411,7 +391,7 @@ class ChannelTable:
.filter(
Channel.deleted_at.is_(None),
Channel.archived_at.is_(None),
Channel.type.in_(["group", "dm"]),
Channel.type.in_(['group', 'dm']),
ChannelMember.user_id == user_id,
ChannelMember.is_active.is_(True),
)
@@ -423,29 +403,20 @@ class ChannelTable:
Channel.archived_at.is_(None),
or_(
Channel.type.is_(None), # True NULL/None
Channel.type == "", # Empty string
and_(Channel.type != "group", Channel.type != "dm"),
Channel.type == '', # Empty string
and_(Channel.type != 'group', Channel.type != 'dm'),
),
)
query = self._has_permission(
db, query, {"user_id": user_id, "group_ids": user_group_ids}
)
query = self._has_permission(db, query, {'user_id': user_id, 'group_ids': user_group_ids})
standard_channels = query.all()
all_channels = membership_channels + standard_channels
channel_ids = [c.id for c in all_channels]
grants_map = AccessGrants.get_grants_by_resources(
"channel", channel_ids, db=db
)
return [
self._to_channel_model(c, access_grants=grants_map.get(c.id, []), db=db)
for c in all_channels
]
grants_map = AccessGrants.get_grants_by_resources('channel', channel_ids, db=db)
return [self._to_channel_model(c, access_grants=grants_map.get(c.id, []), db=db) for c in all_channels]
def get_dm_channel_by_user_ids(
self, user_ids: list[str], db: Optional[Session] = None
) -> Optional[ChannelModel]:
def get_dm_channel_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> Optional[ChannelModel]:
with get_db_context(db) as db:
# Ensure uniqueness in case a list with duplicates is passed
unique_user_ids = list(set(user_ids))
@@ -471,7 +442,7 @@ class ChannelTable:
db.query(Channel)
.filter(
Channel.id.in_(subquery),
Channel.type == "dm",
Channel.type == 'dm',
)
.first()
)
@@ -488,32 +459,23 @@ class ChannelTable:
) -> list[ChannelMemberModel]:
with get_db_context(db) as db:
# 1. Collect all user_ids including groups + inviter
requested_users = self._collect_unique_user_ids(
invited_by, user_ids, group_ids
)
requested_users = self._collect_unique_user_ids(invited_by, user_ids, group_ids)
existing_users = {
row.user_id
for row in db.query(ChannelMember.user_id)
.filter(ChannelMember.channel_id == channel_id)
.all()
for row in db.query(ChannelMember.user_id).filter(ChannelMember.channel_id == channel_id).all()
}
new_user_ids = requested_users - existing_users
if not new_user_ids:
return [] # Nothing to add
new_memberships = self._create_membership_models(
channel_id, invited_by, new_user_ids
)
new_memberships = self._create_membership_models(channel_id, invited_by, new_user_ids)
db.add_all(new_memberships)
db.commit()
return [
ChannelMemberModel.model_validate(membership)
for membership in new_memberships
]
return [ChannelMemberModel.model_validate(membership) for membership in new_memberships]
def remove_members_from_channel(
self,
@@ -533,9 +495,7 @@ class ChannelTable:
db.commit()
return result # number of rows deleted
def is_user_channel_manager(
self, channel_id: str, user_id: str, db: Optional[Session] = None
) -> bool:
def is_user_channel_manager(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
# Check if the user is the creator of the channel
# or has a 'manager' role in ChannelMember
@@ -548,15 +508,13 @@ class ChannelTable:
.filter(
ChannelMember.channel_id == channel_id,
ChannelMember.user_id == user_id,
ChannelMember.role == "manager",
ChannelMember.role == 'manager',
)
.first()
)
return membership is not None
def join_channel(
self, channel_id: str, user_id: str, db: Optional[Session] = None
) -> Optional[ChannelMemberModel]:
def join_channel(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> Optional[ChannelMemberModel]:
with get_db_context(db) as db:
# Check if the membership already exists
existing_membership = (
@@ -573,18 +531,18 @@ class ChannelTable:
# Create new membership
channel_member = ChannelMemberModel(
**{
"id": str(uuid.uuid4()),
"channel_id": channel_id,
"user_id": user_id,
"status": "joined",
"is_active": True,
"is_channel_muted": False,
"is_channel_pinned": False,
"joined_at": int(time.time_ns()),
"left_at": None,
"last_read_at": int(time.time_ns()),
"created_at": int(time.time_ns()),
"updated_at": int(time.time_ns()),
'id': str(uuid.uuid4()),
'channel_id': channel_id,
'user_id': user_id,
'status': 'joined',
'is_active': True,
'is_channel_muted': False,
'is_channel_pinned': False,
'joined_at': int(time.time_ns()),
'left_at': None,
'last_read_at': int(time.time_ns()),
'created_at': int(time.time_ns()),
'updated_at': int(time.time_ns()),
}
)
new_membership = ChannelMember(**channel_member.model_dump())
@@ -593,9 +551,7 @@ class ChannelTable:
db.commit()
return channel_member
def leave_channel(
self, channel_id: str, user_id: str, db: Optional[Session] = None
) -> bool:
def leave_channel(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
membership = (
db.query(ChannelMember)
@@ -608,7 +564,7 @@ class ChannelTable:
if not membership:
return False
membership.status = "left"
membership.status = 'left'
membership.is_active = False
membership.left_at = int(time.time_ns())
membership.updated_at = int(time.time_ns())
@@ -630,19 +586,10 @@ class ChannelTable:
)
return ChannelMemberModel.model_validate(membership) if membership else None
def get_members_by_channel_id(
self, channel_id: str, db: Optional[Session] = None
) -> list[ChannelMemberModel]:
def get_members_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> list[ChannelMemberModel]:
with get_db_context(db) as db:
memberships = (
db.query(ChannelMember)
.filter(ChannelMember.channel_id == channel_id)
.all()
)
return [
ChannelMemberModel.model_validate(membership)
for membership in memberships
]
memberships = db.query(ChannelMember).filter(ChannelMember.channel_id == channel_id).all()
return [ChannelMemberModel.model_validate(membership) for membership in memberships]
def pin_channel(
self,
@@ -669,9 +616,7 @@ class ChannelTable:
db.commit()
return True
def update_member_last_read_at(
self, channel_id: str, user_id: str, db: Optional[Session] = None
) -> bool:
def update_member_last_read_at(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
membership = (
db.query(ChannelMember)
@@ -715,9 +660,7 @@ class ChannelTable:
db.commit()
return True
def is_user_channel_member(
self, channel_id: str, user_id: str, db: Optional[Session] = None
) -> bool:
def is_user_channel_member(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
membership = (
db.query(ChannelMember)
@@ -729,9 +672,7 @@ class ChannelTable:
)
return membership is not None
def get_channel_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[ChannelModel]:
def get_channel_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChannelModel]:
try:
with get_db_context(db) as db:
channel = db.query(Channel).filter(Channel.id == id).first()
@@ -739,18 +680,12 @@ class ChannelTable:
except Exception:
return None
def get_channels_by_file_id(
self, file_id: str, db: Optional[Session] = None
) -> list[ChannelModel]:
def get_channels_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
with get_db_context(db) as db:
channel_files = (
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
)
channel_files = db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
channel_ids = [cf.channel_id for cf in channel_files]
channels = db.query(Channel).filter(Channel.id.in_(channel_ids)).all()
grants_map = AccessGrants.get_grants_by_resources(
"channel", channel_ids, db=db
)
grants_map = AccessGrants.get_grants_by_resources('channel', channel_ids, db=db)
return [
self._to_channel_model(
channel,
@@ -765,9 +700,7 @@ class ChannelTable:
) -> list[ChannelModel]:
with get_db_context(db) as db:
# 1. Determine which channels have this file
channel_file_rows = (
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
)
channel_file_rows = db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
channel_ids = [row.channel_id for row in channel_file_rows]
if not channel_ids:
@@ -787,15 +720,13 @@ class ChannelTable:
return []
# Preload user's group membership
user_group_ids = [
g.id for g in Groups.get_groups_by_member_id(user_id, db=db)
]
user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id, db=db)]
allowed_channels = []
for channel in channels:
# --- Case A: group or dm => user must be an active member ---
if channel.type in ["group", "dm"]:
if channel.type in ['group', 'dm']:
membership = (
db.query(ChannelMember)
.filter(
@@ -815,8 +746,8 @@ class ChannelTable:
query = self._has_permission(
db,
query,
{"user_id": user_id, "group_ids": user_group_ids},
permission="read",
{'user_id': user_id, 'group_ids': user_group_ids},
permission='read',
)
allowed = query.first()
@@ -844,7 +775,7 @@ class ChannelTable:
return None
# If the channel is a group or dm, read access requires membership (active)
if channel.type in ["group", "dm"]:
if channel.type in ['group', 'dm']:
membership = (
db.query(ChannelMember)
.filter(
@@ -863,24 +794,18 @@ class ChannelTable:
query = db.query(Channel).filter(Channel.id == id)
# Determine user groups
user_group_ids = [
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
]
user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id, db=db)]
# Apply ACL rules
query = self._has_permission(
db,
query,
{"user_id": user_id, "group_ids": user_group_ids},
permission="read",
{'user_id': user_id, 'group_ids': user_group_ids},
permission='read',
)
channel_allowed = query.first()
return (
self._to_channel_model(channel_allowed, db=db)
if channel_allowed
else None
)
return self._to_channel_model(channel_allowed, db=db) if channel_allowed else None
def update_channel_by_id(
self, id: str, form_data: ChannelForm, db: Optional[Session] = None
@@ -898,9 +823,7 @@ class ChannelTable:
channel.meta = form_data.meta
if form_data.access_grants is not None:
AccessGrants.set_access_grants(
"channel", id, form_data.access_grants, db=db
)
AccessGrants.set_access_grants('channel', id, form_data.access_grants, db=db)
channel.updated_at = int(time.time_ns())
db.commit()
@@ -912,12 +835,12 @@ class ChannelTable:
with get_db_context(db) as db:
channel_file = ChannelFileModel(
**{
"id": str(uuid.uuid4()),
"channel_id": channel_id,
"file_id": file_id,
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
'id': str(uuid.uuid4()),
'channel_id': channel_id,
'file_id': file_id,
'user_id': user_id,
'created_at': int(time.time()),
'updated_at': int(time.time()),
}
)
@@ -942,11 +865,7 @@ class ChannelTable:
) -> bool:
try:
with get_db_context(db) as db:
channel_file = (
db.query(ChannelFile)
.filter_by(channel_id=channel_id, file_id=file_id)
.first()
)
channel_file = db.query(ChannelFile).filter_by(channel_id=channel_id, file_id=file_id).first()
if not channel_file:
return False
@@ -958,14 +877,10 @@ class ChannelTable:
except Exception:
return False
def remove_file_from_channel_by_id(
self, channel_id: str, file_id: str, db: Optional[Session] = None
) -> bool:
def remove_file_from_channel_by_id(self, channel_id: str, file_id: str, db: Optional[Session] = None) -> bool:
try:
with get_db_context(db) as db:
db.query(ChannelFile).filter_by(
channel_id=channel_id, file_id=file_id
).delete()
db.query(ChannelFile).filter_by(channel_id=channel_id, file_id=file_id).delete()
db.commit()
return True
except Exception:
@@ -973,7 +888,7 @@ class ChannelTable:
def delete_channel_by_id(self, id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
AccessGrants.revoke_all_access("channel", id, db=db)
AccessGrants.revoke_all_access('channel', id, db=db)
db.query(Channel).filter(Channel.id == id).delete()
db.commit()
return True
@@ -1005,24 +920,14 @@ class ChannelTable:
db.commit()
return webhook
def get_webhooks_by_channel_id(
self, channel_id: str, db: Optional[Session] = None
) -> list[ChannelWebhookModel]:
def get_webhooks_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> list[ChannelWebhookModel]:
with get_db_context(db) as db:
webhooks = (
db.query(ChannelWebhook)
.filter(ChannelWebhook.channel_id == channel_id)
.all()
)
webhooks = db.query(ChannelWebhook).filter(ChannelWebhook.channel_id == channel_id).all()
return [ChannelWebhookModel.model_validate(w) for w in webhooks]
def get_webhook_by_id(
self, webhook_id: str, db: Optional[Session] = None
) -> Optional[ChannelWebhookModel]:
def get_webhook_by_id(self, webhook_id: str, db: Optional[Session] = None) -> Optional[ChannelWebhookModel]:
with get_db_context(db) as db:
webhook = (
db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
)
webhook = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
return ChannelWebhookModel.model_validate(webhook) if webhook else None
def get_webhook_by_id_and_token(
@@ -1046,9 +951,7 @@ class ChannelTable:
db: Optional[Session] = None,
) -> Optional[ChannelWebhookModel]:
with get_db_context(db) as db:
webhook = (
db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
)
webhook = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
if not webhook:
return None
webhook.name = form_data.name
@@ -1057,28 +960,18 @@ class ChannelTable:
db.commit()
return ChannelWebhookModel.model_validate(webhook)
def update_webhook_last_used_at(
self, webhook_id: str, db: Optional[Session] = None
) -> bool:
def update_webhook_last_used_at(self, webhook_id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
webhook = (
db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
)
webhook = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
if not webhook:
return False
webhook.last_used_at = int(time.time_ns())
db.commit()
return True
def delete_webhook_by_id(
self, webhook_id: str, db: Optional[Session] = None
) -> bool:
def delete_webhook_by_id(self, webhook_id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
result = (
db.query(ChannelWebhook)
.filter(ChannelWebhook.id == webhook_id)
.delete()
)
result = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).delete()
db.commit()
return result > 0

View File

@@ -47,13 +47,11 @@ def _normalize_timestamp(timestamp: int) -> float:
class ChatMessage(Base):
__tablename__ = "chat_message"
__tablename__ = 'chat_message'
# Identity
id = Column(Text, primary_key=True)
chat_id = Column(
Text, ForeignKey("chat.id", ondelete="CASCADE"), nullable=False, index=True
)
chat_id = Column(Text, ForeignKey('chat.id', ondelete='CASCADE'), nullable=False, index=True)
user_id = Column(Text, index=True)
# Structure
@@ -85,9 +83,9 @@ class ChatMessage(Base):
updated_at = Column(BigInteger)
__table_args__ = (
Index("chat_message_chat_parent_idx", "chat_id", "parent_id"),
Index("chat_message_model_created_idx", "model_id", "created_at"),
Index("chat_message_user_created_idx", "user_id", "created_at"),
Index('chat_message_chat_parent_idx', 'chat_id', 'parent_id'),
Index('chat_message_model_created_idx', 'model_id', 'created_at'),
Index('chat_message_user_created_idx', 'user_id', 'created_at'),
)
@@ -135,43 +133,41 @@ class ChatMessageTable:
"""Insert or update a chat message."""
with get_db_context(db) as db:
now = int(time.time())
timestamp = data.get("timestamp", now)
timestamp = data.get('timestamp', now)
# Use composite ID: {chat_id}-{message_id}
composite_id = f"{chat_id}-{message_id}"
composite_id = f'{chat_id}-{message_id}'
existing = db.get(ChatMessage, composite_id)
if existing:
# Update existing
if "role" in data:
existing.role = data["role"]
if "parent_id" in data:
existing.parent_id = data.get("parent_id") or data.get("parentId")
if "content" in data:
existing.content = data.get("content")
if "output" in data:
existing.output = data.get("output")
if "model_id" in data or "model" in data:
existing.model_id = data.get("model_id") or data.get("model")
if "files" in data:
existing.files = data.get("files")
if "sources" in data:
existing.sources = data.get("sources")
if "embeds" in data:
existing.embeds = data.get("embeds")
if "done" in data:
existing.done = data.get("done", True)
if "status_history" in data or "statusHistory" in data:
existing.status_history = data.get("status_history") or data.get(
"statusHistory"
)
if "error" in data:
existing.error = data.get("error")
if 'role' in data:
existing.role = data['role']
if 'parent_id' in data:
existing.parent_id = data.get('parent_id') or data.get('parentId')
if 'content' in data:
existing.content = data.get('content')
if 'output' in data:
existing.output = data.get('output')
if 'model_id' in data or 'model' in data:
existing.model_id = data.get('model_id') or data.get('model')
if 'files' in data:
existing.files = data.get('files')
if 'sources' in data:
existing.sources = data.get('sources')
if 'embeds' in data:
existing.embeds = data.get('embeds')
if 'done' in data:
existing.done = data.get('done', True)
if 'status_history' in data or 'statusHistory' in data:
existing.status_history = data.get('status_history') or data.get('statusHistory')
if 'error' in data:
existing.error = data.get('error')
# Extract usage - check direct field first, then info.usage
usage = data.get("usage")
usage = data.get('usage')
if not usage:
info = data.get("info", {})
usage = info.get("usage") if info else None
info = data.get('info', {})
usage = info.get('usage') if info else None
if usage:
existing.usage = usage
existing.updated_at = now
@@ -181,26 +177,25 @@ class ChatMessageTable:
else:
# Insert new
# Extract usage - check direct field first, then info.usage
usage = data.get("usage")
usage = data.get('usage')
if not usage:
info = data.get("info", {})
usage = info.get("usage") if info else None
info = data.get('info', {})
usage = info.get('usage') if info else None
message = ChatMessage(
id=composite_id,
chat_id=chat_id,
user_id=user_id,
role=data.get("role", "user"),
parent_id=data.get("parent_id") or data.get("parentId"),
content=data.get("content"),
output=data.get("output"),
model_id=data.get("model_id") or data.get("model"),
files=data.get("files"),
sources=data.get("sources"),
embeds=data.get("embeds"),
done=data.get("done", True),
status_history=data.get("status_history")
or data.get("statusHistory"),
error=data.get("error"),
role=data.get('role', 'user'),
parent_id=data.get('parent_id') or data.get('parentId'),
content=data.get('content'),
output=data.get('output'),
model_id=data.get('model_id') or data.get('model'),
files=data.get('files'),
sources=data.get('sources'),
embeds=data.get('embeds'),
done=data.get('done', True),
status_history=data.get('status_history') or data.get('statusHistory'),
error=data.get('error'),
usage=usage,
created_at=timestamp,
updated_at=now,
@@ -210,23 +205,14 @@ class ChatMessageTable:
db.refresh(message)
return ChatMessageModel.model_validate(message)
def get_message_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[ChatMessageModel]:
def get_message_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatMessageModel]:
with get_db_context(db) as db:
message = db.get(ChatMessage, id)
return ChatMessageModel.model_validate(message) if message else None
def get_messages_by_chat_id(
self, chat_id: str, db: Optional[Session] = None
) -> list[ChatMessageModel]:
def get_messages_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> list[ChatMessageModel]:
with get_db_context(db) as db:
messages = (
db.query(ChatMessage)
.filter_by(chat_id=chat_id)
.order_by(ChatMessage.created_at.asc())
.all()
)
messages = db.query(ChatMessage).filter_by(chat_id=chat_id).order_by(ChatMessage.created_at.asc()).all()
return [ChatMessageModel.model_validate(message) for message in messages]
def get_messages_by_user_id(
@@ -262,12 +248,7 @@ class ChatMessageTable:
query = query.filter(ChatMessage.created_at >= start_date)
if end_date:
query = query.filter(ChatMessage.created_at <= end_date)
messages = (
query.order_by(ChatMessage.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
messages = query.order_by(ChatMessage.created_at.desc()).offset(skip).limit(limit).all()
return [ChatMessageModel.model_validate(message) for message in messages]
def get_chat_ids_by_model_id(
@@ -284,7 +265,7 @@ class ChatMessageTable:
with get_db_context(db) as db:
query = db.query(
ChatMessage.chat_id,
func.max(ChatMessage.created_at).label("last_message_at"),
func.max(ChatMessage.created_at).label('last_message_at'),
).filter(ChatMessage.model_id == model_id)
if start_date:
query = query.filter(ChatMessage.created_at >= start_date)
@@ -303,9 +284,7 @@ class ChatMessageTable:
)
return [chat_id for chat_id, _ in chat_ids]
def delete_messages_by_chat_id(
self, chat_id: str, db: Optional[Session] = None
) -> bool:
def delete_messages_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
db.query(ChatMessage).filter_by(chat_id=chat_id).delete()
db.commit()
@@ -323,12 +302,10 @@ class ChatMessageTable:
from sqlalchemy import func
from open_webui.models.groups import GroupMember
query = db.query(
ChatMessage.model_id, func.count(ChatMessage.id).label("count")
).filter(
ChatMessage.role == "assistant",
query = db.query(ChatMessage.model_id, func.count(ChatMessage.id).label('count')).filter(
ChatMessage.role == 'assistant',
ChatMessage.model_id.isnot(None),
~ChatMessage.user_id.like("shared-%"),
~ChatMessage.user_id.like('shared-%'),
)
if start_date:
@@ -336,11 +313,7 @@ class ChatMessageTable:
if end_date:
query = query.filter(ChatMessage.created_at <= end_date)
if group_id:
group_users = (
db.query(GroupMember.user_id)
.filter(GroupMember.group_id == group_id)
.subquery()
)
group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
query = query.filter(ChatMessage.user_id.in_(group_users))
results = query.group_by(ChatMessage.model_id).all()
@@ -360,36 +333,32 @@ class ChatMessageTable:
dialect = db.bind.dialect.name
if dialect == "sqlite":
input_tokens = cast(
func.json_extract(ChatMessage.usage, "$.input_tokens"), Integer
)
output_tokens = cast(
func.json_extract(ChatMessage.usage, "$.output_tokens"), Integer
)
elif dialect == "postgresql":
if dialect == 'sqlite':
input_tokens = cast(func.json_extract(ChatMessage.usage, '$.input_tokens'), Integer)
output_tokens = cast(func.json_extract(ChatMessage.usage, '$.output_tokens'), Integer)
elif dialect == 'postgresql':
# Use json_extract_path_text for PostgreSQL JSON columns
input_tokens = cast(
func.json_extract_path_text(ChatMessage.usage, "input_tokens"),
func.json_extract_path_text(ChatMessage.usage, 'input_tokens'),
Integer,
)
output_tokens = cast(
func.json_extract_path_text(ChatMessage.usage, "output_tokens"),
func.json_extract_path_text(ChatMessage.usage, 'output_tokens'),
Integer,
)
else:
raise NotImplementedError(f"Unsupported dialect: {dialect}")
raise NotImplementedError(f'Unsupported dialect: {dialect}')
query = db.query(
ChatMessage.model_id,
func.coalesce(func.sum(input_tokens), 0).label("input_tokens"),
func.coalesce(func.sum(output_tokens), 0).label("output_tokens"),
func.count(ChatMessage.id).label("message_count"),
func.coalesce(func.sum(input_tokens), 0).label('input_tokens'),
func.coalesce(func.sum(output_tokens), 0).label('output_tokens'),
func.count(ChatMessage.id).label('message_count'),
).filter(
ChatMessage.role == "assistant",
ChatMessage.role == 'assistant',
ChatMessage.model_id.isnot(None),
ChatMessage.usage.isnot(None),
~ChatMessage.user_id.like("shared-%"),
~ChatMessage.user_id.like('shared-%'),
)
if start_date:
@@ -397,21 +366,17 @@ class ChatMessageTable:
if end_date:
query = query.filter(ChatMessage.created_at <= end_date)
if group_id:
group_users = (
db.query(GroupMember.user_id)
.filter(GroupMember.group_id == group_id)
.subquery()
)
group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
query = query.filter(ChatMessage.user_id.in_(group_users))
results = query.group_by(ChatMessage.model_id).all()
return {
row.model_id: {
"input_tokens": row.input_tokens,
"output_tokens": row.output_tokens,
"total_tokens": row.input_tokens + row.output_tokens,
"message_count": row.message_count,
'input_tokens': row.input_tokens,
'output_tokens': row.output_tokens,
'total_tokens': row.input_tokens + row.output_tokens,
'message_count': row.message_count,
}
for row in results
}
@@ -430,36 +395,32 @@ class ChatMessageTable:
dialect = db.bind.dialect.name
if dialect == "sqlite":
input_tokens = cast(
func.json_extract(ChatMessage.usage, "$.input_tokens"), Integer
)
output_tokens = cast(
func.json_extract(ChatMessage.usage, "$.output_tokens"), Integer
)
elif dialect == "postgresql":
if dialect == 'sqlite':
input_tokens = cast(func.json_extract(ChatMessage.usage, '$.input_tokens'), Integer)
output_tokens = cast(func.json_extract(ChatMessage.usage, '$.output_tokens'), Integer)
elif dialect == 'postgresql':
# Use json_extract_path_text for PostgreSQL JSON columns
input_tokens = cast(
func.json_extract_path_text(ChatMessage.usage, "input_tokens"),
func.json_extract_path_text(ChatMessage.usage, 'input_tokens'),
Integer,
)
output_tokens = cast(
func.json_extract_path_text(ChatMessage.usage, "output_tokens"),
func.json_extract_path_text(ChatMessage.usage, 'output_tokens'),
Integer,
)
else:
raise NotImplementedError(f"Unsupported dialect: {dialect}")
raise NotImplementedError(f'Unsupported dialect: {dialect}')
query = db.query(
ChatMessage.user_id,
func.coalesce(func.sum(input_tokens), 0).label("input_tokens"),
func.coalesce(func.sum(output_tokens), 0).label("output_tokens"),
func.count(ChatMessage.id).label("message_count"),
func.coalesce(func.sum(input_tokens), 0).label('input_tokens'),
func.coalesce(func.sum(output_tokens), 0).label('output_tokens'),
func.count(ChatMessage.id).label('message_count'),
).filter(
ChatMessage.role == "assistant",
ChatMessage.role == 'assistant',
ChatMessage.user_id.isnot(None),
ChatMessage.usage.isnot(None),
~ChatMessage.user_id.like("shared-%"),
~ChatMessage.user_id.like('shared-%'),
)
if start_date:
@@ -467,21 +428,17 @@ class ChatMessageTable:
if end_date:
query = query.filter(ChatMessage.created_at <= end_date)
if group_id:
group_users = (
db.query(GroupMember.user_id)
.filter(GroupMember.group_id == group_id)
.subquery()
)
group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
query = query.filter(ChatMessage.user_id.in_(group_users))
results = query.group_by(ChatMessage.user_id).all()
return {
row.user_id: {
"input_tokens": row.input_tokens,
"output_tokens": row.output_tokens,
"total_tokens": row.input_tokens + row.output_tokens,
"message_count": row.message_count,
'input_tokens': row.input_tokens,
'output_tokens': row.output_tokens,
'total_tokens': row.input_tokens + row.output_tokens,
'message_count': row.message_count,
}
for row in results
}
@@ -497,20 +454,16 @@ class ChatMessageTable:
from sqlalchemy import func
from open_webui.models.groups import GroupMember
query = db.query(
ChatMessage.user_id, func.count(ChatMessage.id).label("count")
).filter(~ChatMessage.user_id.like("shared-%"))
query = db.query(ChatMessage.user_id, func.count(ChatMessage.id).label('count')).filter(
~ChatMessage.user_id.like('shared-%')
)
if start_date:
query = query.filter(ChatMessage.created_at >= start_date)
if end_date:
query = query.filter(ChatMessage.created_at <= end_date)
if group_id:
group_users = (
db.query(GroupMember.user_id)
.filter(GroupMember.group_id == group_id)
.subquery()
)
group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
query = query.filter(ChatMessage.user_id.in_(group_users))
results = query.group_by(ChatMessage.user_id).all()
@@ -527,20 +480,16 @@ class ChatMessageTable:
from sqlalchemy import func
from open_webui.models.groups import GroupMember
query = db.query(
ChatMessage.chat_id, func.count(ChatMessage.id).label("count")
).filter(~ChatMessage.user_id.like("shared-%"))
query = db.query(ChatMessage.chat_id, func.count(ChatMessage.id).label('count')).filter(
~ChatMessage.user_id.like('shared-%')
)
if start_date:
query = query.filter(ChatMessage.created_at >= start_date)
if end_date:
query = query.filter(ChatMessage.created_at <= end_date)
if group_id:
group_users = (
db.query(GroupMember.user_id)
.filter(GroupMember.group_id == group_id)
.subquery()
)
group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
query = query.filter(ChatMessage.user_id.in_(group_users))
results = query.group_by(ChatMessage.chat_id).all()
@@ -559,9 +508,9 @@ class ChatMessageTable:
from open_webui.models.groups import GroupMember
query = db.query(ChatMessage.created_at, ChatMessage.model_id).filter(
ChatMessage.role == "assistant",
ChatMessage.role == 'assistant',
ChatMessage.model_id.isnot(None),
~ChatMessage.user_id.like("shared-%"),
~ChatMessage.user_id.like('shared-%'),
)
if start_date:
@@ -569,11 +518,7 @@ class ChatMessageTable:
if end_date:
query = query.filter(ChatMessage.created_at <= end_date)
if group_id:
group_users = (
db.query(GroupMember.user_id)
.filter(GroupMember.group_id == group_id)
.subquery()
)
group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
query = query.filter(ChatMessage.user_id.in_(group_users))
results = query.all()
@@ -581,21 +526,17 @@ class ChatMessageTable:
# Group by date -> model -> count
daily_counts: dict[str, dict[str, int]] = {}
for timestamp, model_id in results:
date_str = datetime.fromtimestamp(
_normalize_timestamp(timestamp)
).strftime("%Y-%m-%d")
date_str = datetime.fromtimestamp(_normalize_timestamp(timestamp)).strftime('%Y-%m-%d')
if date_str not in daily_counts:
daily_counts[date_str] = {}
daily_counts[date_str][model_id] = (
daily_counts[date_str].get(model_id, 0) + 1
)
daily_counts[date_str][model_id] = daily_counts[date_str].get(model_id, 0) + 1
# Fill in missing days
if start_date and end_date:
current = datetime.fromtimestamp(_normalize_timestamp(start_date))
end_dt = datetime.fromtimestamp(_normalize_timestamp(end_date))
while current <= end_dt:
date_str = current.strftime("%Y-%m-%d")
date_str = current.strftime('%Y-%m-%d')
if date_str not in daily_counts:
daily_counts[date_str] = {}
current += timedelta(days=1)
@@ -613,9 +554,9 @@ class ChatMessageTable:
from datetime import datetime, timedelta
query = db.query(ChatMessage.created_at, ChatMessage.model_id).filter(
ChatMessage.role == "assistant",
ChatMessage.role == 'assistant',
ChatMessage.model_id.isnot(None),
~ChatMessage.user_id.like("shared-%"),
~ChatMessage.user_id.like('shared-%'),
)
if start_date:
@@ -628,23 +569,19 @@ class ChatMessageTable:
# Group by hour -> model -> count
hourly_counts: dict[str, dict[str, int]] = {}
for timestamp, model_id in results:
hour_str = datetime.fromtimestamp(
_normalize_timestamp(timestamp)
).strftime("%Y-%m-%d %H:00")
hour_str = datetime.fromtimestamp(_normalize_timestamp(timestamp)).strftime('%Y-%m-%d %H:00')
if hour_str not in hourly_counts:
hourly_counts[hour_str] = {}
hourly_counts[hour_str][model_id] = (
hourly_counts[hour_str].get(model_id, 0) + 1
)
hourly_counts[hour_str][model_id] = hourly_counts[hour_str].get(model_id, 0) + 1
# Fill in missing hours
if start_date and end_date:
current = datetime.fromtimestamp(
_normalize_timestamp(start_date)
).replace(minute=0, second=0, microsecond=0)
current = datetime.fromtimestamp(_normalize_timestamp(start_date)).replace(
minute=0, second=0, microsecond=0
)
end_dt = datetime.fromtimestamp(_normalize_timestamp(end_date))
while current <= end_dt:
hour_str = current.strftime("%Y-%m-%d %H:00")
hour_str = current.strftime('%Y-%m-%d %H:00')
if hour_str not in hourly_counts:
hourly_counts[hour_str] = {}
current += timedelta(hours=1)

File diff suppressed because it is too large Load Diff

View File

@@ -19,7 +19,7 @@ log = logging.getLogger(__name__)
class Feedback(Base):
__tablename__ = "feedback"
__tablename__ = 'feedback'
id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text)
version = Column(BigInteger, default=0)
@@ -81,7 +81,7 @@ class RatingData(BaseModel):
sibling_model_ids: Optional[list[str]] = None
reason: Optional[str] = None
comment: Optional[str] = None
model_config = ConfigDict(extra="allow", protected_namespaces=())
model_config = ConfigDict(extra='allow', protected_namespaces=())
class MetaData(BaseModel):
@@ -89,12 +89,12 @@ class MetaData(BaseModel):
chat_id: Optional[str] = None
message_id: Optional[str] = None
tags: Optional[list[str]] = None
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra='allow')
class SnapshotData(BaseModel):
chat: Optional[dict] = None
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra='allow')
class FeedbackForm(BaseModel):
@@ -102,14 +102,14 @@ class FeedbackForm(BaseModel):
data: Optional[RatingData] = None
meta: Optional[dict] = None
snapshot: Optional[SnapshotData] = None
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra='allow')
class UserResponse(BaseModel):
id: str
name: str
email: str
role: str = "pending"
role: str = 'pending'
last_active_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
@@ -146,12 +146,12 @@ class FeedbackTable:
id = str(uuid.uuid4())
feedback = FeedbackModel(
**{
"id": id,
"user_id": user_id,
"version": 0,
'id': id,
'user_id': user_id,
'version': 0,
**form_data.model_dump(),
"created_at": int(time.time()),
"updated_at": int(time.time()),
'created_at': int(time.time()),
'updated_at': int(time.time()),
}
)
try:
@@ -164,12 +164,10 @@ class FeedbackTable:
else:
return None
except Exception as e:
log.exception(f"Error creating a new feedback: {e}")
log.exception(f'Error creating a new feedback: {e}')
return None
def get_feedback_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[FeedbackModel]:
def get_feedback_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FeedbackModel]:
try:
with get_db_context(db) as db:
feedback = db.query(Feedback).filter_by(id=id).first()
@@ -191,16 +189,14 @@ class FeedbackTable:
except Exception:
return None
def get_feedbacks_by_chat_id(
self, chat_id: str, db: Optional[Session] = None
) -> list[FeedbackModel]:
def get_feedbacks_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> list[FeedbackModel]:
"""Get all feedbacks for a specific chat."""
try:
with get_db_context(db) as db:
# meta.chat_id stores the chat reference
feedbacks = (
db.query(Feedback)
.filter(Feedback.meta["chat_id"].as_string() == chat_id)
.filter(Feedback.meta['chat_id'].as_string() == chat_id)
.order_by(Feedback.created_at.desc())
.all()
)
@@ -219,36 +215,28 @@ class FeedbackTable:
query = db.query(Feedback, User).join(User, Feedback.user_id == User.id)
if filter:
order_by = filter.get("order_by")
direction = filter.get("direction")
order_by = filter.get('order_by')
direction = filter.get('direction')
if order_by == "username":
if direction == "asc":
if order_by == 'username':
if direction == 'asc':
query = query.order_by(User.name.asc())
else:
query = query.order_by(User.name.desc())
elif order_by == "model_id":
elif order_by == 'model_id':
# it's stored in feedback.data['model_id']
if direction == "asc":
query = query.order_by(
Feedback.data["model_id"].as_string().asc()
)
if direction == 'asc':
query = query.order_by(Feedback.data['model_id'].as_string().asc())
else:
query = query.order_by(
Feedback.data["model_id"].as_string().desc()
)
elif order_by == "rating":
query = query.order_by(Feedback.data['model_id'].as_string().desc())
elif order_by == 'rating':
# it's stored in feedback.data['rating']
if direction == "asc":
query = query.order_by(
Feedback.data["rating"].as_string().asc()
)
if direction == 'asc':
query = query.order_by(Feedback.data['rating'].as_string().asc())
else:
query = query.order_by(
Feedback.data["rating"].as_string().desc()
)
elif order_by == "updated_at":
if direction == "asc":
query = query.order_by(Feedback.data['rating'].as_string().desc())
elif order_by == 'updated_at':
if direction == 'asc':
query = query.order_by(Feedback.updated_at.asc())
else:
query = query.order_by(Feedback.updated_at.desc())
@@ -270,9 +258,7 @@ class FeedbackTable:
for feedback, user in items:
feedback_model = FeedbackModel.model_validate(feedback)
user_model = UserResponse.model_validate(user)
feedbacks.append(
FeedbackUserResponse(**feedback_model.model_dump(), user=user_model)
)
feedbacks.append(FeedbackUserResponse(**feedback_model.model_dump(), user=user_model))
return FeedbackListResponse(items=feedbacks, total=total)
@@ -280,14 +266,10 @@ class FeedbackTable:
with get_db_context(db) as db:
return [
FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback)
.order_by(Feedback.updated_at.desc())
.all()
for feedback in db.query(Feedback).order_by(Feedback.updated_at.desc()).all()
]
def get_all_feedback_ids(
self, db: Optional[Session] = None
) -> list[FeedbackIdResponse]:
def get_all_feedback_ids(self, db: Optional[Session] = None) -> list[FeedbackIdResponse]:
with get_db_context(db) as db:
return [
FeedbackIdResponse(
@@ -306,14 +288,11 @@ class FeedbackTable:
.all()
]
def get_feedbacks_for_leaderboard(
self, db: Optional[Session] = None
) -> list[LeaderboardFeedbackData]:
def get_feedbacks_for_leaderboard(self, db: Optional[Session] = None) -> list[LeaderboardFeedbackData]:
"""Fetch only id and data for leaderboard computation (excludes snapshot/meta)."""
with get_db_context(db) as db:
return [
LeaderboardFeedbackData(id=row.id, data=row.data)
for row in db.query(Feedback.id, Feedback.data).all()
LeaderboardFeedbackData(id=row.id, data=row.data) for row in db.query(Feedback.id, Feedback.data).all()
]
def get_model_evaluation_history(
@@ -333,30 +312,26 @@ class FeedbackTable:
rows = db.query(Feedback.created_at, Feedback.data).all()
else:
cutoff = int(time.time()) - (days * 86400)
rows = (
db.query(Feedback.created_at, Feedback.data)
.filter(Feedback.created_at >= cutoff)
.all()
)
rows = db.query(Feedback.created_at, Feedback.data).filter(Feedback.created_at >= cutoff).all()
daily_counts = defaultdict(lambda: {"won": 0, "lost": 0})
daily_counts = defaultdict(lambda: {'won': 0, 'lost': 0})
first_date = None
for created_at, data in rows:
if not data:
continue
if data.get("model_id") != model_id:
if data.get('model_id') != model_id:
continue
rating_str = str(data.get("rating", ""))
if rating_str not in ("1", "-1"):
rating_str = str(data.get('rating', ''))
if rating_str not in ('1', '-1'):
continue
date_str = datetime.fromtimestamp(created_at).strftime("%Y-%m-%d")
if rating_str == "1":
daily_counts[date_str]["won"] += 1
date_str = datetime.fromtimestamp(created_at).strftime('%Y-%m-%d')
if rating_str == '1':
daily_counts[date_str]['won'] += 1
else:
daily_counts[date_str]["lost"] += 1
daily_counts[date_str]['lost'] += 1
# Track first date for this model
if first_date is None or date_str < first_date:
@@ -368,7 +343,7 @@ class FeedbackTable:
if days == 0 and first_date:
# All time: start from first feedback date
start_date = datetime.strptime(first_date, "%Y-%m-%d").date()
start_date = datetime.strptime(first_date, '%Y-%m-%d').date()
num_days = (today - start_date).days + 1
else:
# Fixed range
@@ -377,36 +352,24 @@ class FeedbackTable:
for i in range(num_days):
d = start_date + timedelta(days=i)
date_str = d.strftime("%Y-%m-%d")
counts = daily_counts.get(date_str, {"won": 0, "lost": 0})
result.append(
ModelHistoryEntry(date=date_str, won=counts["won"], lost=counts["lost"])
)
date_str = d.strftime('%Y-%m-%d')
counts = daily_counts.get(date_str, {'won': 0, 'lost': 0})
result.append(ModelHistoryEntry(date=date_str, won=counts['won'], lost=counts['lost']))
return result
def get_feedbacks_by_type(
self, type: str, db: Optional[Session] = None
) -> list[FeedbackModel]:
def get_feedbacks_by_type(self, type: str, db: Optional[Session] = None) -> list[FeedbackModel]:
with get_db_context(db) as db:
return [
FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback)
.filter_by(type=type)
.order_by(Feedback.updated_at.desc())
.all()
for feedback in db.query(Feedback).filter_by(type=type).order_by(Feedback.updated_at.desc()).all()
]
def get_feedbacks_by_user_id(
self, user_id: str, db: Optional[Session] = None
) -> list[FeedbackModel]:
def get_feedbacks_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FeedbackModel]:
with get_db_context(db) as db:
return [
FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback)
.filter_by(user_id=user_id)
.order_by(Feedback.updated_at.desc())
.all()
for feedback in db.query(Feedback).filter_by(user_id=user_id).order_by(Feedback.updated_at.desc()).all()
]
def update_feedback_by_id(
@@ -462,9 +425,7 @@ class FeedbackTable:
db.commit()
return True
def delete_feedback_by_id_and_user_id(
self, id: str, user_id: str, db: Optional[Session] = None
) -> bool:
def delete_feedback_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
if not feedback:
@@ -473,9 +434,7 @@ class FeedbackTable:
db.commit()
return True
def delete_feedbacks_by_user_id(
self, user_id: str, db: Optional[Session] = None
) -> bool:
def delete_feedbacks_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
result = db.query(Feedback).filter_by(user_id=user_id).delete()
db.commit()

View File

@@ -16,7 +16,7 @@ log = logging.getLogger(__name__)
class File(Base):
__tablename__ = "file"
__tablename__ = 'file'
id = Column(String, primary_key=True, unique=True)
user_id = Column(String)
hash = Column(Text, nullable=True)
@@ -58,9 +58,9 @@ class FileMeta(BaseModel):
content_type: Optional[str] = None
size: Optional[int] = None
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra='allow')
@model_validator(mode="before")
@model_validator(mode='before')
@classmethod
def sanitize_meta(cls, data):
"""Sanitize metadata fields to handle malformed legacy data."""
@@ -68,14 +68,12 @@ class FileMeta(BaseModel):
return data
# Handle content_type that may be a list like ['application/pdf', None]
content_type = data.get("content_type")
content_type = data.get('content_type')
if isinstance(content_type, list):
# Extract first non-None string value
data["content_type"] = next(
(item for item in content_type if isinstance(item, str)), None
)
data['content_type'] = next((item for item in content_type if isinstance(item, str)), None)
elif content_type is not None and not isinstance(content_type, str):
data["content_type"] = None
data['content_type'] = None
return data
@@ -92,7 +90,7 @@ class FileModelResponse(BaseModel):
created_at: int # timestamp in epoch
updated_at: Optional[int] = None # timestamp in epoch, optional for legacy files
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra='allow')
class FileMetadataResponse(BaseModel):
@@ -123,25 +121,22 @@ class FileUpdateForm(BaseModel):
meta: Optional[dict] = None
class FilesTable:
def insert_new_file(
self, user_id: str, form_data: FileForm, db: Optional[Session] = None
) -> Optional[FileModel]:
def insert_new_file(self, user_id: str, form_data: FileForm, db: Optional[Session] = None) -> Optional[FileModel]:
with get_db_context(db) as db:
file_data = form_data.model_dump()
# Sanitize meta to remove non-JSON-serializable objects
# (e.g. callable tool functions, MCP client instances from middleware)
if file_data.get("meta"):
file_data["meta"] = sanitize_metadata(file_data["meta"])
if file_data.get('meta'):
file_data['meta'] = sanitize_metadata(file_data['meta'])
file = FileModel(
**{
**file_data,
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
'user_id': user_id,
'created_at': int(time.time()),
'updated_at': int(time.time()),
}
)
@@ -155,12 +150,10 @@ class FilesTable:
else:
return None
except Exception as e:
log.exception(f"Error inserting a new file: {e}")
log.exception(f'Error inserting a new file: {e}')
return None
def get_file_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[FileModel]:
def get_file_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FileModel]:
try:
with get_db_context(db) as db:
try:
@@ -171,9 +164,7 @@ class FilesTable:
except Exception:
return None
def get_file_by_id_and_user_id(
self, id: str, user_id: str, db: Optional[Session] = None
) -> Optional[FileModel]:
def get_file_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[FileModel]:
with get_db_context(db) as db:
try:
file = db.query(File).filter_by(id=id, user_id=user_id).first()
@@ -184,9 +175,7 @@ class FilesTable:
except Exception:
return None
def get_file_metadata_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[FileMetadataResponse]:
def get_file_metadata_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FileMetadataResponse]:
with get_db_context(db) as db:
try:
file = db.get(File, id)
@@ -204,9 +193,7 @@ class FilesTable:
with get_db_context(db) as db:
return [FileModel.model_validate(file) for file in db.query(File).all()]
def check_access_by_user_id(
self, id, user_id, permission="write", db: Optional[Session] = None
) -> bool:
def check_access_by_user_id(self, id, user_id, permission='write', db: Optional[Session] = None) -> bool:
file = self.get_file_by_id(id, db=db)
if not file:
return False
@@ -215,21 +202,14 @@ class FilesTable:
# Implement additional access control logic here as needed
return False
def get_files_by_ids(
self, ids: list[str], db: Optional[Session] = None
) -> list[FileModel]:
def get_files_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[FileModel]:
with get_db_context(db) as db:
return [
FileModel.model_validate(file)
for file in db.query(File)
.filter(File.id.in_(ids))
.order_by(File.updated_at.desc())
.all()
for file in db.query(File).filter(File.id.in_(ids)).order_by(File.updated_at.desc()).all()
]
def get_file_metadatas_by_ids(
self, ids: list[str], db: Optional[Session] = None
) -> list[FileMetadataResponse]:
def get_file_metadatas_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[FileMetadataResponse]:
with get_db_context(db) as db:
return [
FileMetadataResponse(
@@ -239,22 +219,15 @@ class FilesTable:
created_at=file.created_at,
updated_at=file.updated_at,
)
for file in db.query(
File.id, File.hash, File.meta, File.created_at, File.updated_at
)
for file in db.query(File.id, File.hash, File.meta, File.created_at, File.updated_at)
.filter(File.id.in_(ids))
.order_by(File.updated_at.desc())
.all()
]
def get_files_by_user_id(
self, user_id: str, db: Optional[Session] = None
) -> list[FileModel]:
def get_files_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FileModel]:
with get_db_context(db) as db:
return [
FileModel.model_validate(file)
for file in db.query(File).filter_by(user_id=user_id).all()
]
return [FileModel.model_validate(file) for file in db.query(File).filter_by(user_id=user_id).all()]
def get_file_list(
self,
@@ -262,7 +235,7 @@ class FilesTable:
skip: int = 0,
limit: int = 50,
db: Optional[Session] = None,
) -> "FileListResponse":
) -> 'FileListResponse':
with get_db_context(db) as db:
query = db.query(File)
if user_id:
@@ -272,10 +245,7 @@ class FilesTable:
items = [
FileModel.model_validate(file)
for file in query.order_by(File.updated_at.desc(), File.id.desc())
.offset(skip)
.limit(limit)
.all()
for file in query.order_by(File.updated_at.desc(), File.id.desc()).offset(skip).limit(limit).all()
]
return FileListResponse(items=items, total=total)
@@ -296,17 +266,17 @@ class FilesTable:
A SQL LIKE compatible pattern with proper escaping.
"""
# Escape SQL special characters first, then convert glob wildcards
pattern = glob.replace("\\", "\\\\")
pattern = pattern.replace("%", "\\%")
pattern = pattern.replace("_", "\\_")
pattern = pattern.replace("*", "%")
pattern = pattern.replace("?", "_")
pattern = glob.replace('\\', '\\\\')
pattern = pattern.replace('%', '\\%')
pattern = pattern.replace('_', '\\_')
pattern = pattern.replace('*', '%')
pattern = pattern.replace('?', '_')
return pattern
def search_files(
self,
user_id: Optional[str] = None,
filename: str = "*",
filename: str = '*',
skip: int = 0,
limit: int = 100,
db: Optional[Session] = None,
@@ -331,15 +301,12 @@ class FilesTable:
query = query.filter_by(user_id=user_id)
pattern = self._glob_to_like_pattern(filename)
if pattern != "%":
query = query.filter(File.filename.ilike(pattern, escape="\\"))
if pattern != '%':
query = query.filter(File.filename.ilike(pattern, escape='\\'))
return [
FileModel.model_validate(file)
for file in query.order_by(File.created_at.desc(), File.id.desc())
.offset(skip)
.limit(limit)
.all()
for file in query.order_by(File.created_at.desc(), File.id.desc()).offset(skip).limit(limit).all()
]
def update_file_by_id(
@@ -362,12 +329,10 @@ class FilesTable:
db.commit()
return FileModel.model_validate(file)
except Exception as e:
log.exception(f"Error updating file completely by id: {e}")
log.exception(f'Error updating file completely by id: {e}')
return None
def update_file_hash_by_id(
self, id: str, hash: Optional[str], db: Optional[Session] = None
) -> Optional[FileModel]:
def update_file_hash_by_id(self, id: str, hash: Optional[str], db: Optional[Session] = None) -> Optional[FileModel]:
with get_db_context(db) as db:
try:
file = db.query(File).filter_by(id=id).first()
@@ -379,9 +344,7 @@ class FilesTable:
except Exception:
return None
def update_file_data_by_id(
self, id: str, data: dict, db: Optional[Session] = None
) -> Optional[FileModel]:
def update_file_data_by_id(self, id: str, data: dict, db: Optional[Session] = None) -> Optional[FileModel]:
with get_db_context(db) as db:
try:
file = db.query(File).filter_by(id=id).first()
@@ -390,12 +353,9 @@ class FilesTable:
db.commit()
return FileModel.model_validate(file)
except Exception as e:
return None
def update_file_metadata_by_id(
self, id: str, meta: dict, db: Optional[Session] = None
) -> Optional[FileModel]:
def update_file_metadata_by_id(self, id: str, meta: dict, db: Optional[Session] = None) -> Optional[FileModel]:
with get_db_context(db) as db:
try:
file = db.query(File).filter_by(id=id).first()

View File

@@ -20,7 +20,7 @@ log = logging.getLogger(__name__)
class Folder(Base):
__tablename__ = "folder"
__tablename__ = 'folder'
id = Column(Text, primary_key=True, unique=True)
parent_id = Column(Text, nullable=True)
user_id = Column(Text)
@@ -72,14 +72,14 @@ class FolderForm(BaseModel):
data: Optional[dict] = None
meta: Optional[dict] = None
parent_id: Optional[str] = None
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra='allow')
class FolderUpdateForm(BaseModel):
name: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra='allow')
class FolderTable:
@@ -94,12 +94,12 @@ class FolderTable:
id = str(uuid.uuid4())
folder = FolderModel(
**{
"id": id,
"user_id": user_id,
'id': id,
'user_id': user_id,
**(form_data.model_dump(exclude_unset=True) or {}),
"parent_id": parent_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
'parent_id': parent_id,
'created_at': int(time.time()),
'updated_at': int(time.time()),
}
)
try:
@@ -112,7 +112,7 @@ class FolderTable:
else:
return None
except Exception as e:
log.exception(f"Error inserting a new folder: {e}")
log.exception(f'Error inserting a new folder: {e}')
return None
def get_folder_by_id_and_user_id(
@@ -137,9 +137,7 @@ class FolderTable:
folders = []
def get_children(folder):
children = self.get_folders_by_parent_id_and_user_id(
folder.id, user_id, db=db
)
children = self.get_folders_by_parent_id_and_user_id(folder.id, user_id, db=db)
for child in children:
get_children(child)
folders.append(child)
@@ -153,14 +151,9 @@ class FolderTable:
except Exception:
return None
def get_folders_by_user_id(
self, user_id: str, db: Optional[Session] = None
) -> list[FolderModel]:
def get_folders_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FolderModel]:
with get_db_context(db) as db:
return [
FolderModel.model_validate(folder)
for folder in db.query(Folder).filter_by(user_id=user_id).all()
]
return [FolderModel.model_validate(folder) for folder in db.query(Folder).filter_by(user_id=user_id).all()]
def get_folder_by_parent_id_and_user_id_and_name(
self,
@@ -184,7 +177,7 @@ class FolderTable:
return FolderModel.model_validate(folder)
except Exception as e:
log.error(f"get_folder_by_parent_id_and_user_id_and_name: {e}")
log.error(f'get_folder_by_parent_id_and_user_id_and_name: {e}')
return None
def get_folders_by_parent_id_and_user_id(
@@ -193,9 +186,7 @@ class FolderTable:
with get_db_context(db) as db:
return [
FolderModel.model_validate(folder)
for folder in db.query(Folder)
.filter_by(parent_id=parent_id, user_id=user_id)
.all()
for folder in db.query(Folder).filter_by(parent_id=parent_id, user_id=user_id).all()
]
def update_folder_parent_id_by_id_and_user_id(
@@ -219,7 +210,7 @@ class FolderTable:
return FolderModel.model_validate(folder)
except Exception as e:
log.error(f"update_folder: {e}")
log.error(f'update_folder: {e}')
return
def update_folder_by_id_and_user_id(
@@ -241,7 +232,7 @@ class FolderTable:
existing_folder = (
db.query(Folder)
.filter_by(
name=form_data.get("name"),
name=form_data.get('name'),
parent_id=folder.parent_id,
user_id=user_id,
)
@@ -251,17 +242,17 @@ class FolderTable:
if existing_folder and existing_folder.id != id:
return None
folder.name = form_data.get("name", folder.name)
if "data" in form_data:
folder.name = form_data.get('name', folder.name)
if 'data' in form_data:
folder.data = {
**(folder.data or {}),
**form_data["data"],
**form_data['data'],
}
if "meta" in form_data:
if 'meta' in form_data:
folder.meta = {
**(folder.meta or {}),
**form_data["meta"],
**form_data['meta'],
}
folder.updated_at = int(time.time())
@@ -269,7 +260,7 @@ class FolderTable:
return FolderModel.model_validate(folder)
except Exception as e:
log.error(f"update_folder: {e}")
log.error(f'update_folder: {e}')
return
def update_folder_is_expanded_by_id_and_user_id(
@@ -289,12 +280,10 @@ class FolderTable:
return FolderModel.model_validate(folder)
except Exception as e:
log.error(f"update_folder: {e}")
log.error(f'update_folder: {e}')
return
def delete_folder_by_id_and_user_id(
self, id: str, user_id: str, db: Optional[Session] = None
) -> list[str]:
def delete_folder_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> list[str]:
try:
folder_ids = []
with get_db_context(db) as db:
@@ -306,11 +295,8 @@ class FolderTable:
# Delete all children folders
def delete_children(folder):
folder_children = self.get_folders_by_parent_id_and_user_id(
folder.id, user_id, db=db
)
folder_children = self.get_folders_by_parent_id_and_user_id(folder.id, user_id, db=db)
for folder_child in folder_children:
delete_children(folder_child)
folder_ids.append(folder_child.id)
@@ -323,12 +309,12 @@ class FolderTable:
db.commit()
return folder_ids
except Exception as e:
log.error(f"delete_folder: {e}")
log.error(f'delete_folder: {e}')
return []
def normalize_folder_name(self, name: str) -> str:
# Replace _ and space with a single space, lower case, collapse multiple spaces
name = re.sub(r"[\s_]+", " ", name)
name = re.sub(r'[\s_]+', ' ', name)
return name.strip().lower()
def search_folders_by_names(
@@ -349,9 +335,7 @@ class FolderTable:
results[folder.id] = FolderModel.model_validate(folder)
# get children folders
children = self.get_children_folders_by_id_and_user_id(
folder.id, user_id, db=db
)
children = self.get_children_folders_by_id_and_user_id(folder.id, user_id, db=db)
for child in children:
results[child.id] = child

View File

@@ -16,7 +16,7 @@ log = logging.getLogger(__name__)
class Function(Base):
__tablename__ = "function"
__tablename__ = 'function'
id = Column(String, primary_key=True, unique=True)
user_id = Column(String)
@@ -30,13 +30,13 @@ class Function(Base):
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
__table_args__ = (Index("is_global_idx", "is_global"),)
__table_args__ = (Index('is_global_idx', 'is_global'),)
class FunctionMeta(BaseModel):
description: Optional[str] = None
manifest: Optional[dict] = {}
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra='allow')
class FunctionModel(BaseModel):
@@ -113,10 +113,10 @@ class FunctionsTable:
function = FunctionModel(
**{
**form_data.model_dump(),
"user_id": user_id,
"type": type,
"updated_at": int(time.time()),
"created_at": int(time.time()),
'user_id': user_id,
'type': type,
'updated_at': int(time.time()),
'created_at': int(time.time()),
}
)
@@ -131,7 +131,7 @@ class FunctionsTable:
else:
return None
except Exception as e:
log.exception(f"Error creating a new function: {e}")
log.exception(f'Error creating a new function: {e}')
return None
def sync_functions(
@@ -156,16 +156,16 @@ class FunctionsTable:
db.query(Function).filter_by(id=func.id).update(
{
**func.model_dump(),
"user_id": user_id,
"updated_at": int(time.time()),
'user_id': user_id,
'updated_at': int(time.time()),
}
)
else:
new_func = Function(
**{
**func.model_dump(),
"user_id": user_id,
"updated_at": int(time.time()),
'user_id': user_id,
'updated_at': int(time.time()),
}
)
db.add(new_func)
@@ -177,17 +177,12 @@ class FunctionsTable:
db.commit()
return [
FunctionModel.model_validate(func)
for func in db.query(Function).all()
]
return [FunctionModel.model_validate(func) for func in db.query(Function).all()]
except Exception as e:
log.exception(f"Error syncing functions for user {user_id}: {e}")
log.exception(f'Error syncing functions for user {user_id}: {e}')
return []
def get_function_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[FunctionModel]:
def get_function_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FunctionModel]:
try:
with get_db_context(db) as db:
function = db.get(Function, id)
@@ -195,9 +190,7 @@ class FunctionsTable:
except Exception:
return None
def get_functions_by_ids(
self, ids: list[str], db: Optional[Session] = None
) -> list[FunctionModel]:
def get_functions_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[FunctionModel]:
"""
Batch fetch multiple functions by their IDs in a single query.
Returns functions in the same order as the input IDs (None entries filtered out).
@@ -225,18 +218,11 @@ class FunctionsTable:
functions = db.query(Function).all()
if include_valves:
return [
FunctionWithValvesModel.model_validate(function)
for function in functions
]
return [FunctionWithValvesModel.model_validate(function) for function in functions]
else:
return [
FunctionModel.model_validate(function) for function in functions
]
return [FunctionModel.model_validate(function) for function in functions]
def get_function_list(
self, db: Optional[Session] = None
) -> list[FunctionUserResponse]:
def get_function_list(self, db: Optional[Session] = None) -> list[FunctionUserResponse]:
with get_db_context(db) as db:
functions = db.query(Function).order_by(Function.updated_at.desc()).all()
user_ids = list(set(func.user_id for func in functions))
@@ -248,69 +234,48 @@ class FunctionsTable:
FunctionUserResponse.model_validate(
{
**FunctionModel.model_validate(func).model_dump(),
"user": (
users_dict.get(func.user_id).model_dump()
if func.user_id in users_dict
else None
),
'user': (users_dict.get(func.user_id).model_dump() if func.user_id in users_dict else None),
}
)
for func in functions
]
def get_functions_by_type(
self, type: str, active_only=False, db: Optional[Session] = None
) -> list[FunctionModel]:
def get_functions_by_type(self, type: str, active_only=False, db: Optional[Session] = None) -> list[FunctionModel]:
with get_db_context(db) as db:
if active_only:
return [
FunctionModel.model_validate(function)
for function in db.query(Function)
.filter_by(type=type, is_active=True)
.all()
for function in db.query(Function).filter_by(type=type, is_active=True).all()
]
else:
return [
FunctionModel.model_validate(function)
for function in db.query(Function).filter_by(type=type).all()
FunctionModel.model_validate(function) for function in db.query(Function).filter_by(type=type).all()
]
def get_global_filter_functions(
self, db: Optional[Session] = None
) -> list[FunctionModel]:
def get_global_filter_functions(self, db: Optional[Session] = None) -> list[FunctionModel]:
with get_db_context(db) as db:
return [
FunctionModel.model_validate(function)
for function in db.query(Function)
.filter_by(type="filter", is_active=True, is_global=True)
.all()
for function in db.query(Function).filter_by(type='filter', is_active=True, is_global=True).all()
]
def get_global_action_functions(
self, db: Optional[Session] = None
) -> list[FunctionModel]:
def get_global_action_functions(self, db: Optional[Session] = None) -> list[FunctionModel]:
with get_db_context(db) as db:
return [
FunctionModel.model_validate(function)
for function in db.query(Function)
.filter_by(type="action", is_active=True, is_global=True)
.all()
for function in db.query(Function).filter_by(type='action', is_active=True, is_global=True).all()
]
def get_function_valves_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[dict]:
def get_function_valves_by_id(self, id: str, db: Optional[Session] = None) -> Optional[dict]:
with get_db_context(db) as db:
try:
function = db.get(Function, id)
return function.valves if function.valves else {}
except Exception as e:
log.exception(f"Error getting function valves by id {id}: {e}")
log.exception(f'Error getting function valves by id {id}: {e}')
return None
def get_function_valves_by_ids(
self, ids: list[str], db: Optional[Session] = None
) -> dict[str, dict]:
def get_function_valves_by_ids(self, ids: list[str], db: Optional[Session] = None) -> dict[str, dict]:
"""
Batch fetch valves for multiple functions in a single query.
Returns a dict mapping function_id -> valves dict.
@@ -320,14 +285,10 @@ class FunctionsTable:
return {}
try:
with get_db_context(db) as db:
functions = (
db.query(Function.id, Function.valves)
.filter(Function.id.in_(ids))
.all()
)
functions = db.query(Function.id, Function.valves).filter(Function.id.in_(ids)).all()
return {f.id: (f.valves if f.valves else {}) for f in functions}
except Exception as e:
log.exception(f"Error batch-fetching function valves: {e}")
log.exception(f'Error batch-fetching function valves: {e}')
return {}
def update_function_valves_by_id(
@@ -364,25 +325,23 @@ class FunctionsTable:
else:
return None
except Exception as e:
log.exception(f"Error updating function metadata by id {id}: {e}")
log.exception(f'Error updating function metadata by id {id}: {e}')
return None
def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str, db: Optional[Session] = None
) -> Optional[dict]:
def get_user_valves_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id, db=db)
user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "functions" and "valves" settings
if "functions" not in user_settings:
user_settings["functions"] = {}
if "valves" not in user_settings["functions"]:
user_settings["functions"]["valves"] = {}
if 'functions' not in user_settings:
user_settings['functions'] = {}
if 'valves' not in user_settings['functions']:
user_settings['functions']['valves'] = {}
return user_settings["functions"]["valves"].get(id, {})
return user_settings['functions']['valves'].get(id, {})
except Exception as e:
log.exception(f"Error getting user values by id {id} and user id {user_id}")
log.exception(f'Error getting user values by id {id} and user id {user_id}')
return None
def update_user_valves_by_id_and_user_id(
@@ -393,32 +352,28 @@ class FunctionsTable:
user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "functions" and "valves" settings
if "functions" not in user_settings:
user_settings["functions"] = {}
if "valves" not in user_settings["functions"]:
user_settings["functions"]["valves"] = {}
if 'functions' not in user_settings:
user_settings['functions'] = {}
if 'valves' not in user_settings['functions']:
user_settings['functions']['valves'] = {}
user_settings["functions"]["valves"][id] = valves
user_settings['functions']['valves'][id] = valves
# Update the user settings in the database
Users.update_user_by_id(user_id, {"settings": user_settings}, db=db)
Users.update_user_by_id(user_id, {'settings': user_settings}, db=db)
return user_settings["functions"]["valves"][id]
return user_settings['functions']['valves'][id]
except Exception as e:
log.exception(
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
)
log.exception(f'Error updating user valves by id {id} and user_id {user_id}: {e}')
return None
def update_function_by_id(
self, id: str, updated: dict, db: Optional[Session] = None
) -> Optional[FunctionModel]:
def update_function_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[FunctionModel]:
with get_db_context(db) as db:
try:
db.query(Function).filter_by(id=id).update(
{
**updated,
"updated_at": int(time.time()),
'updated_at': int(time.time()),
}
)
db.commit()
@@ -432,8 +387,8 @@ class FunctionsTable:
try:
db.query(Function).update(
{
"is_active": False,
"updated_at": int(time.time()),
'is_active': False,
'updated_at': int(time.time()),
}
)
db.commit()

View File

@@ -34,7 +34,7 @@ log = logging.getLogger(__name__)
class Group(Base):
__tablename__ = "group"
__tablename__ = 'group'
id = Column(Text, unique=True, primary_key=True)
user_id = Column(Text)
@@ -70,12 +70,12 @@ class GroupModel(BaseModel):
class GroupMember(Base):
__tablename__ = "group_member"
__tablename__ = 'group_member'
id = Column(Text, unique=True, primary_key=True)
group_id = Column(
Text,
ForeignKey("group.id", ondelete="CASCADE"),
ForeignKey('group.id', ondelete='CASCADE'),
nullable=False,
)
user_id = Column(Text, nullable=False)
@@ -133,28 +133,26 @@ class GroupListResponse(BaseModel):
class GroupTable:
def _ensure_default_share_config(self, group_data: dict) -> dict:
"""Ensure the group data dict has a default share config if not already set."""
if "data" not in group_data or group_data["data"] is None:
group_data["data"] = {}
if "config" not in group_data["data"]:
group_data["data"]["config"] = {}
if "share" not in group_data["data"]["config"]:
group_data["data"]["config"]["share"] = DEFAULT_GROUP_SHARE_PERMISSION
if 'data' not in group_data or group_data['data'] is None:
group_data['data'] = {}
if 'config' not in group_data['data']:
group_data['data']['config'] = {}
if 'share' not in group_data['data']['config']:
group_data['data']['config']['share'] = DEFAULT_GROUP_SHARE_PERMISSION
return group_data
def insert_new_group(
self, user_id: str, form_data: GroupForm, db: Optional[Session] = None
) -> Optional[GroupModel]:
with get_db_context(db) as db:
group_data = self._ensure_default_share_config(
form_data.model_dump(exclude_none=True)
)
group_data = self._ensure_default_share_config(form_data.model_dump(exclude_none=True))
group = GroupModel(
**{
**group_data,
"id": str(uuid.uuid4()),
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
'id': str(uuid.uuid4()),
'user_id': user_id,
'created_at': int(time.time()),
'updated_at': int(time.time()),
}
)
@@ -183,19 +181,19 @@ class GroupTable:
.where(GroupMember.group_id == Group.id)
.correlate(Group)
.scalar_subquery()
.label("member_count")
.label('member_count')
)
query = db.query(Group, member_count)
if filter:
if "query" in filter:
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
if 'query' in filter:
query = query.filter(Group.name.ilike(f'%{filter["query"]}%'))
# When share filter is present, member check is handled in the share logic
if "share" in filter:
share_value = filter["share"]
member_id = filter.get("member_id")
json_share = Group.data["config"]["share"]
if 'share' in filter:
share_value = filter['share']
member_id = filter.get('member_id')
json_share = Group.data['config']['share']
json_share_str = json_share.as_string()
json_share_lower = func.lower(json_share_str)
@@ -203,37 +201,27 @@ class GroupTable:
anyone_can_share = or_(
Group.data.is_(None),
json_share_str.is_(None),
json_share_lower == "true",
json_share_lower == "1", # Handle SQLite boolean true
json_share_lower == 'true',
json_share_lower == '1', # Handle SQLite boolean true
)
if member_id:
member_groups_select = select(GroupMember.group_id).where(
GroupMember.user_id == member_id
)
member_groups_select = select(GroupMember.group_id).where(GroupMember.user_id == member_id)
members_only_and_is_member = and_(
json_share_lower == "members",
json_share_lower == 'members',
Group.id.in_(member_groups_select),
)
query = query.filter(
or_(anyone_can_share, members_only_and_is_member)
)
query = query.filter(or_(anyone_can_share, members_only_and_is_member))
else:
query = query.filter(anyone_can_share)
else:
query = query.filter(
and_(Group.data.isnot(None), json_share_lower == "false")
)
query = query.filter(and_(Group.data.isnot(None), json_share_lower == 'false'))
else:
# Only apply member_id filter when share filter is NOT present
if "member_id" in filter:
if 'member_id' in filter:
query = query.filter(
Group.id.in_(
select(GroupMember.group_id).where(
GroupMember.user_id == filter["member_id"]
)
)
Group.id.in_(select(GroupMember.group_id).where(GroupMember.user_id == filter['member_id']))
)
results = query.order_by(Group.updated_at.desc()).all()
@@ -242,7 +230,7 @@ class GroupTable:
GroupResponse.model_validate(
{
**GroupModel.model_validate(group).model_dump(),
"member_count": count or 0,
'member_count': count or 0,
}
)
for group, count in results
@@ -259,22 +247,16 @@ class GroupTable:
query = db.query(Group)
if filter:
if "query" in filter:
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
if "member_id" in filter:
if 'query' in filter:
query = query.filter(Group.name.ilike(f'%{filter["query"]}%'))
if 'member_id' in filter:
query = query.filter(
Group.id.in_(
select(GroupMember.group_id).where(
GroupMember.user_id == filter["member_id"]
)
)
Group.id.in_(select(GroupMember.group_id).where(GroupMember.user_id == filter['member_id']))
)
if "share" in filter:
share_value = filter["share"]
query = query.filter(
Group.data.op("->>")("share") == str(share_value)
)
if 'share' in filter:
share_value = filter['share']
query = query.filter(Group.data.op('->>')('share') == str(share_value))
total = query.count()
@@ -283,32 +265,24 @@ class GroupTable:
.where(GroupMember.group_id == Group.id)
.correlate(Group)
.scalar_subquery()
.label("member_count")
)
results = (
query.add_columns(member_count)
.order_by(Group.updated_at.desc())
.offset(skip)
.limit(limit)
.all()
.label('member_count')
)
results = query.add_columns(member_count).order_by(Group.updated_at.desc()).offset(skip).limit(limit).all()
return {
"items": [
'items': [
GroupResponse.model_validate(
{
**GroupModel.model_validate(group).model_dump(),
"member_count": count or 0,
'member_count': count or 0,
}
)
for group, count in results
],
"total": total,
'total': total,
}
def get_groups_by_member_id(
self, user_id: str, db: Optional[Session] = None
) -> list[GroupModel]:
def get_groups_by_member_id(self, user_id: str, db: Optional[Session] = None) -> list[GroupModel]:
with get_db_context(db) as db:
return [
GroupModel.model_validate(group)
@@ -340,9 +314,7 @@ class GroupTable:
return user_groups
def get_group_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[GroupModel]:
def get_group_by_id(self, id: str, db: Optional[Session] = None) -> Optional[GroupModel]:
try:
with get_db_context(db) as db:
group = db.query(Group).filter_by(id=id).first()
@@ -350,41 +322,29 @@ class GroupTable:
except Exception:
return None
def get_group_user_ids_by_id(
self, id: str, db: Optional[Session] = None
) -> list[str]:
def get_group_user_ids_by_id(self, id: str, db: Optional[Session] = None) -> list[str]:
with get_db_context(db) as db:
members = (
db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all()
)
members = db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all()
if not members:
return []
return [m[0] for m in members]
def get_group_user_ids_by_ids(
self, group_ids: list[str], db: Optional[Session] = None
) -> dict[str, list[str]]:
def get_group_user_ids_by_ids(self, group_ids: list[str], db: Optional[Session] = None) -> dict[str, list[str]]:
with get_db_context(db) as db:
members = (
db.query(GroupMember.group_id, GroupMember.user_id)
.filter(GroupMember.group_id.in_(group_ids))
.all()
db.query(GroupMember.group_id, GroupMember.user_id).filter(GroupMember.group_id.in_(group_ids)).all()
)
group_user_ids: dict[str, list[str]] = {
group_id: [] for group_id in group_ids
}
group_user_ids: dict[str, list[str]] = {group_id: [] for group_id in group_ids}
for group_id, user_id in members:
group_user_ids[group_id].append(user_id)
return group_user_ids
def set_group_user_ids_by_id(
self, group_id: str, user_ids: list[str], db: Optional[Session] = None
) -> None:
def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str], db: Optional[Session] = None) -> None:
with get_db_context(db) as db:
# Delete existing members
db.query(GroupMember).filter(GroupMember.group_id == group_id).delete()
@@ -405,20 +365,12 @@ class GroupTable:
db.add_all(new_members)
db.commit()
def get_group_member_count_by_id(
self, id: str, db: Optional[Session] = None
) -> int:
def get_group_member_count_by_id(self, id: str, db: Optional[Session] = None) -> int:
with get_db_context(db) as db:
count = (
db.query(func.count(GroupMember.user_id))
.filter(GroupMember.group_id == id)
.scalar()
)
count = db.query(func.count(GroupMember.user_id)).filter(GroupMember.group_id == id).scalar()
return count if count else 0
def get_group_member_counts_by_ids(
self, ids: list[str], db: Optional[Session] = None
) -> dict[str, int]:
def get_group_member_counts_by_ids(self, ids: list[str], db: Optional[Session] = None) -> dict[str, int]:
if not ids:
return {}
with get_db_context(db) as db:
@@ -442,7 +394,7 @@ class GroupTable:
db.query(Group).filter_by(id=id).update(
{
**form_data.model_dump(exclude_none=True),
"updated_at": int(time.time()),
'updated_at': int(time.time()),
}
)
db.commit()
@@ -470,9 +422,7 @@ class GroupTable:
except Exception:
return False
def remove_user_from_all_groups(
self, user_id: str, db: Optional[Session] = None
) -> bool:
def remove_user_from_all_groups(self, user_id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
try:
# Find all groups the user belongs to
@@ -489,9 +439,7 @@ class GroupTable:
GroupMember.group_id == group.id, GroupMember.user_id == user_id
).delete()
db.query(Group).filter_by(id=group.id).update(
{"updated_at": int(time.time())}
)
db.query(Group).filter_by(id=group.id).update({'updated_at': int(time.time())})
db.commit()
return True
@@ -503,7 +451,6 @@ class GroupTable:
def create_groups_by_group_names(
self, user_id: str, group_names: list[str], db: Optional[Session] = None
) -> list[GroupModel]:
# check for existing groups
existing_groups = self.get_all_groups(db=db)
existing_group_names = {group.name for group in existing_groups}
@@ -517,10 +464,10 @@ class GroupTable:
id=str(uuid.uuid4()),
user_id=user_id,
name=group_name,
description="",
description='',
data={
"config": {
"share": DEFAULT_GROUP_SHARE_PERMISSION,
'config': {
'share': DEFAULT_GROUP_SHARE_PERMISSION,
}
},
created_at=int(time.time()),
@@ -537,17 +484,13 @@ class GroupTable:
continue
return new_groups
def sync_groups_by_group_names(
self, user_id: str, group_names: list[str], db: Optional[Session] = None
) -> bool:
def sync_groups_by_group_names(self, user_id: str, group_names: list[str], db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
try:
now = int(time.time())
# 1. Groups that SHOULD contain the user
target_groups = (
db.query(Group).filter(Group.name.in_(group_names)).all()
)
target_groups = db.query(Group).filter(Group.name.in_(group_names)).all()
target_group_ids = {g.id for g in target_groups}
# 2. Groups the user is CURRENTLY in
@@ -571,7 +514,7 @@ class GroupTable:
).delete(synchronize_session=False)
db.query(Group).filter(Group.id.in_(groups_to_remove)).update(
{"updated_at": now}, synchronize_session=False
{'updated_at': now}, synchronize_session=False
)
# 5. Bulk insert missing memberships
@@ -588,7 +531,7 @@ class GroupTable:
if groups_to_add:
db.query(Group).filter(Group.id.in_(groups_to_add)).update(
{"updated_at": now}, synchronize_session=False
{'updated_at': now}, synchronize_session=False
)
db.commit()
@@ -656,9 +599,9 @@ class GroupTable:
return GroupModel.model_validate(group)
# Remove users from group_member in batch
db.query(GroupMember).filter(
GroupMember.group_id == id, GroupMember.user_id.in_(user_ids)
).delete(synchronize_session=False)
db.query(GroupMember).filter(GroupMember.group_id == id, GroupMember.user_id.in_(user_ids)).delete(
synchronize_session=False
)
# Update group timestamp
group.updated_at = int(time.time())

View File

@@ -38,7 +38,7 @@ log = logging.getLogger(__name__)
class Knowledge(Base):
__tablename__ = "knowledge"
__tablename__ = 'knowledge'
id = Column(Text, unique=True, primary_key=True)
user_id = Column(Text)
@@ -70,24 +70,18 @@ class KnowledgeModel(BaseModel):
class KnowledgeFile(Base):
__tablename__ = "knowledge_file"
__tablename__ = 'knowledge_file'
id = Column(Text, unique=True, primary_key=True)
knowledge_id = Column(
Text, ForeignKey("knowledge.id", ondelete="CASCADE"), nullable=False
)
file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False)
knowledge_id = Column(Text, ForeignKey('knowledge.id', ondelete='CASCADE'), nullable=False)
file_id = Column(Text, ForeignKey('file.id', ondelete='CASCADE'), nullable=False)
user_id = Column(Text, nullable=False)
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
__table_args__ = (
UniqueConstraint(
"knowledge_id", "file_id", name="uq_knowledge_file_knowledge_file"
),
)
__table_args__ = (UniqueConstraint('knowledge_id', 'file_id', name='uq_knowledge_file_knowledge_file'),)
class KnowledgeFileModel(BaseModel):
@@ -138,10 +132,8 @@ class KnowledgeFileListResponse(BaseModel):
class KnowledgeTable:
def _get_access_grants(
self, knowledge_id: str, db: Optional[Session] = None
) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource("knowledge", knowledge_id, db=db)
def _get_access_grants(self, knowledge_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource('knowledge', knowledge_id, db=db)
def _to_knowledge_model(
self,
@@ -149,13 +141,9 @@ class KnowledgeTable:
access_grants: Optional[list[AccessGrantModel]] = None,
db: Optional[Session] = None,
) -> KnowledgeModel:
knowledge_data = KnowledgeModel.model_validate(knowledge).model_dump(
exclude={"access_grants"}
)
knowledge_data["access_grants"] = (
access_grants
if access_grants is not None
else self._get_access_grants(knowledge_data["id"], db=db)
knowledge_data = KnowledgeModel.model_validate(knowledge).model_dump(exclude={'access_grants'})
knowledge_data['access_grants'] = (
access_grants if access_grants is not None else self._get_access_grants(knowledge_data['id'], db=db)
)
return KnowledgeModel.model_validate(knowledge_data)
@@ -165,23 +153,21 @@ class KnowledgeTable:
with get_db_context(db) as db:
knowledge = KnowledgeModel(
**{
**form_data.model_dump(exclude={"access_grants"}),
"id": str(uuid.uuid4()),
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
"access_grants": [],
**form_data.model_dump(exclude={'access_grants'}),
'id': str(uuid.uuid4()),
'user_id': user_id,
'created_at': int(time.time()),
'updated_at': int(time.time()),
'access_grants': [],
}
)
try:
result = Knowledge(**knowledge.model_dump(exclude={"access_grants"}))
result = Knowledge(**knowledge.model_dump(exclude={'access_grants'}))
db.add(result)
db.commit()
db.refresh(result)
AccessGrants.set_access_grants(
"knowledge", result.id, form_data.access_grants, db=db
)
AccessGrants.set_access_grants('knowledge', result.id, form_data.access_grants, db=db)
if result:
return self._to_knowledge_model(result, db=db)
else:
@@ -193,17 +179,13 @@ class KnowledgeTable:
self, skip: int = 0, limit: int = 30, db: Optional[Session] = None
) -> list[KnowledgeUserModel]:
with get_db_context(db) as db:
all_knowledge = (
db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
)
all_knowledge = db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
user_ids = list(set(knowledge.user_id for knowledge in all_knowledge))
knowledge_ids = [knowledge.id for knowledge in all_knowledge]
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
users_dict = {user.id: user for user in users}
grants_map = AccessGrants.get_grants_by_resources(
"knowledge", knowledge_ids, db=db
)
grants_map = AccessGrants.get_grants_by_resources('knowledge', knowledge_ids, db=db)
knowledge_bases = []
for knowledge in all_knowledge:
@@ -216,7 +198,7 @@ class KnowledgeTable:
access_grants=grants_map.get(knowledge.id, []),
db=db,
).model_dump(),
"user": user.model_dump() if user else None,
'user': user.model_dump() if user else None,
}
)
)
@@ -232,27 +214,25 @@ class KnowledgeTable:
) -> KnowledgeListResponse:
try:
with get_db_context(db) as db:
query = db.query(Knowledge, User).outerjoin(
User, User.id == Knowledge.user_id
)
query = db.query(Knowledge, User).outerjoin(User, User.id == Knowledge.user_id)
if filter:
query_key = filter.get("query")
query_key = filter.get('query')
if query_key:
query = query.filter(
or_(
Knowledge.name.ilike(f"%{query_key}%"),
Knowledge.description.ilike(f"%{query_key}%"),
User.name.ilike(f"%{query_key}%"),
User.email.ilike(f"%{query_key}%"),
User.username.ilike(f"%{query_key}%"),
Knowledge.name.ilike(f'%{query_key}%'),
Knowledge.description.ilike(f'%{query_key}%'),
User.name.ilike(f'%{query_key}%'),
User.email.ilike(f'%{query_key}%'),
User.username.ilike(f'%{query_key}%'),
)
)
view_option = filter.get("view_option")
if view_option == "created":
view_option = filter.get('view_option')
if view_option == 'created':
query = query.filter(Knowledge.user_id == user_id)
elif view_option == "shared":
elif view_option == 'shared':
query = query.filter(Knowledge.user_id != user_id)
query = AccessGrants.has_permission_filter(
@@ -260,8 +240,8 @@ class KnowledgeTable:
query=query,
DocumentModel=Knowledge,
filter=filter,
resource_type="knowledge",
permission="read",
resource_type='knowledge',
permission='read',
)
query = query.order_by(Knowledge.updated_at.desc(), Knowledge.id.asc())
@@ -275,9 +255,7 @@ class KnowledgeTable:
items = query.all()
knowledge_ids = [kb.id for kb, _ in items]
grants_map = AccessGrants.get_grants_by_resources(
"knowledge", knowledge_ids, db=db
)
grants_map = AccessGrants.get_grants_by_resources('knowledge', knowledge_ids, db=db)
knowledge_bases = []
for knowledge_base, user in items:
@@ -289,11 +267,7 @@ class KnowledgeTable:
access_grants=grants_map.get(knowledge_base.id, []),
db=db,
).model_dump(),
"user": (
UserModel.model_validate(user).model_dump()
if user
else None
),
'user': (UserModel.model_validate(user).model_dump() if user else None),
}
)
)
@@ -327,15 +301,15 @@ class KnowledgeTable:
query=query,
DocumentModel=Knowledge,
filter=filter,
resource_type="knowledge",
permission="read",
resource_type='knowledge',
permission='read',
)
# Apply filename search
if filter:
q = filter.get("query")
q = filter.get('query')
if q:
query = query.filter(File.filename.ilike(f"%{q}%"))
query = query.filter(File.filename.ilike(f'%{q}%'))
# Order by file changes
query = query.order_by(File.updated_at.desc(), File.id.asc())
@@ -355,39 +329,27 @@ class KnowledgeTable:
items.append(
FileUserResponse(
**FileModel.model_validate(file).model_dump(),
user=(
UserResponse(
**UserModel.model_validate(user).model_dump()
)
if user
else None
),
collection=self._to_knowledge_model(
knowledge, db=db
).model_dump(),
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
collection=self._to_knowledge_model(knowledge, db=db).model_dump(),
)
)
return KnowledgeFileListResponse(items=items, total=total)
except Exception as e:
print("search_knowledge_files error:", e)
print('search_knowledge_files error:', e)
return KnowledgeFileListResponse(items=[], total=0)
def check_access_by_user_id(
self, id, user_id, permission="write", db: Optional[Session] = None
) -> bool:
def check_access_by_user_id(self, id, user_id, permission='write', db: Optional[Session] = None) -> bool:
knowledge = self.get_knowledge_by_id(id, db=db)
if not knowledge:
return False
if knowledge.user_id == user_id:
return True
user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
}
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
return AccessGrants.has_access(
user_id=user_id,
resource_type="knowledge",
resource_type='knowledge',
resource_id=knowledge.id,
permission=permission,
user_group_ids=user_group_ids,
@@ -395,19 +357,17 @@ class KnowledgeTable:
)
def get_knowledge_bases_by_user_id(
self, user_id: str, permission: str = "write", db: Optional[Session] = None
self, user_id: str, permission: str = 'write', db: Optional[Session] = None
) -> list[KnowledgeUserModel]:
knowledge_bases = self.get_knowledge_bases(db=db)
user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
}
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
return [
knowledge_base
for knowledge_base in knowledge_bases
if knowledge_base.user_id == user_id
or AccessGrants.has_access(
user_id=user_id,
resource_type="knowledge",
resource_type='knowledge',
resource_id=knowledge_base.id,
permission=permission,
user_group_ids=user_group_ids,
@@ -415,9 +375,7 @@ class KnowledgeTable:
)
]
def get_knowledge_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[KnowledgeModel]:
def get_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> Optional[KnowledgeModel]:
try:
with get_db_context(db) as db:
knowledge = db.query(Knowledge).filter_by(id=id).first()
@@ -435,23 +393,19 @@ class KnowledgeTable:
if knowledge.user_id == user_id:
return knowledge
user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
}
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
if AccessGrants.has_access(
user_id=user_id,
resource_type="knowledge",
resource_type='knowledge',
resource_id=knowledge.id,
permission="write",
permission='write',
user_group_ids=user_group_ids,
db=db,
):
return knowledge
return None
def get_knowledges_by_file_id(
self, file_id: str, db: Optional[Session] = None
) -> list[KnowledgeModel]:
def get_knowledges_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[KnowledgeModel]:
try:
with get_db_context(db) as db:
knowledges = (
@@ -461,9 +415,7 @@ class KnowledgeTable:
.all()
)
knowledge_ids = [k.id for k in knowledges]
grants_map = AccessGrants.get_grants_by_resources(
"knowledge", knowledge_ids, db=db
)
grants_map = AccessGrants.get_grants_by_resources('knowledge', knowledge_ids, db=db)
return [
self._to_knowledge_model(
knowledge,
@@ -497,32 +449,26 @@ class KnowledgeTable:
primary_sort = File.updated_at.desc()
if filter:
query_key = filter.get("query")
query_key = filter.get('query')
if query_key:
query = query.filter(or_(File.filename.ilike(f"%{query_key}%")))
query = query.filter(or_(File.filename.ilike(f'%{query_key}%')))
view_option = filter.get("view_option")
if view_option == "created":
view_option = filter.get('view_option')
if view_option == 'created':
query = query.filter(KnowledgeFile.user_id == user_id)
elif view_option == "shared":
elif view_option == 'shared':
query = query.filter(KnowledgeFile.user_id != user_id)
order_by = filter.get("order_by")
direction = filter.get("direction")
is_asc = direction == "asc"
order_by = filter.get('order_by')
direction = filter.get('direction')
is_asc = direction == 'asc'
if order_by == "name":
primary_sort = (
File.filename.asc() if is_asc else File.filename.desc()
)
elif order_by == "created_at":
primary_sort = (
File.created_at.asc() if is_asc else File.created_at.desc()
)
elif order_by == "updated_at":
primary_sort = (
File.updated_at.asc() if is_asc else File.updated_at.desc()
)
if order_by == 'name':
primary_sort = File.filename.asc() if is_asc else File.filename.desc()
elif order_by == 'created_at':
primary_sort = File.created_at.asc() if is_asc else File.created_at.desc()
elif order_by == 'updated_at':
primary_sort = File.updated_at.asc() if is_asc else File.updated_at.desc()
# Apply sort with secondary key for deterministic pagination
query = query.order_by(primary_sort, File.id.asc())
@@ -542,13 +488,7 @@ class KnowledgeTable:
files.append(
FileUserResponse(
**FileModel.model_validate(file).model_dump(),
user=(
UserResponse(
**UserModel.model_validate(user).model_dump()
)
if user
else None
),
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
)
)
@@ -557,9 +497,7 @@ class KnowledgeTable:
print(e)
return KnowledgeFileListResponse(items=[], total=0)
def get_files_by_id(
self, knowledge_id: str, db: Optional[Session] = None
) -> list[FileModel]:
def get_files_by_id(self, knowledge_id: str, db: Optional[Session] = None) -> list[FileModel]:
try:
with get_db_context(db) as db:
files = (
@@ -572,9 +510,7 @@ class KnowledgeTable:
except Exception:
return []
def get_file_metadatas_by_id(
self, knowledge_id: str, db: Optional[Session] = None
) -> list[FileMetadataResponse]:
def get_file_metadatas_by_id(self, knowledge_id: str, db: Optional[Session] = None) -> list[FileMetadataResponse]:
try:
with get_db_context(db) as db:
files = self.get_files_by_id(knowledge_id, db=db)
@@ -592,12 +528,12 @@ class KnowledgeTable:
with get_db_context(db) as db:
knowledge_file = KnowledgeFileModel(
**{
"id": str(uuid.uuid4()),
"knowledge_id": knowledge_id,
"file_id": file_id,
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
'id': str(uuid.uuid4()),
'knowledge_id': knowledge_id,
'file_id': file_id,
'user_id': user_id,
'created_at': int(time.time()),
'updated_at': int(time.time()),
}
)
@@ -613,37 +549,24 @@ class KnowledgeTable:
except Exception:
return None
def has_file(
self, knowledge_id: str, file_id: str, db: Optional[Session] = None
) -> bool:
def has_file(self, knowledge_id: str, file_id: str, db: Optional[Session] = None) -> bool:
"""Check whether a file belongs to a knowledge base."""
try:
with get_db_context(db) as db:
return (
db.query(KnowledgeFile)
.filter_by(knowledge_id=knowledge_id, file_id=file_id)
.first()
is not None
)
return db.query(KnowledgeFile).filter_by(knowledge_id=knowledge_id, file_id=file_id).first() is not None
except Exception:
return False
def remove_file_from_knowledge_by_id(
self, knowledge_id: str, file_id: str, db: Optional[Session] = None
) -> bool:
def remove_file_from_knowledge_by_id(self, knowledge_id: str, file_id: str, db: Optional[Session] = None) -> bool:
try:
with get_db_context(db) as db:
db.query(KnowledgeFile).filter_by(
knowledge_id=knowledge_id, file_id=file_id
).delete()
db.query(KnowledgeFile).filter_by(knowledge_id=knowledge_id, file_id=file_id).delete()
db.commit()
return True
except Exception:
return False
def reset_knowledge_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[KnowledgeModel]:
def reset_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> Optional[KnowledgeModel]:
try:
with get_db_context(db) as db:
# Delete all knowledge_file entries for this knowledge_id
@@ -653,7 +576,7 @@ class KnowledgeTable:
# Update the knowledge entry's updated_at timestamp
db.query(Knowledge).filter_by(id=id).update(
{
"updated_at": int(time.time()),
'updated_at': int(time.time()),
}
)
db.commit()
@@ -675,15 +598,13 @@ class KnowledgeTable:
knowledge = self.get_knowledge_by_id(id=id, db=db)
db.query(Knowledge).filter_by(id=id).update(
{
**form_data.model_dump(exclude={"access_grants"}),
"updated_at": int(time.time()),
**form_data.model_dump(exclude={'access_grants'}),
'updated_at': int(time.time()),
}
)
db.commit()
if form_data.access_grants is not None:
AccessGrants.set_access_grants(
"knowledge", id, form_data.access_grants, db=db
)
AccessGrants.set_access_grants('knowledge', id, form_data.access_grants, db=db)
return self.get_knowledge_by_id(id=id, db=db)
except Exception as e:
log.exception(e)
@@ -697,8 +618,8 @@ class KnowledgeTable:
knowledge = self.get_knowledge_by_id(id=id, db=db)
db.query(Knowledge).filter_by(id=id).update(
{
"data": data,
"updated_at": int(time.time()),
'data': data,
'updated_at': int(time.time()),
}
)
db.commit()
@@ -710,7 +631,7 @@ class KnowledgeTable:
def delete_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> bool:
try:
with get_db_context(db) as db:
AccessGrants.revoke_all_access("knowledge", id, db=db)
AccessGrants.revoke_all_access('knowledge', id, db=db)
db.query(Knowledge).filter_by(id=id).delete()
db.commit()
return True
@@ -722,7 +643,7 @@ class KnowledgeTable:
try:
knowledge_ids = [row[0] for row in db.query(Knowledge.id).all()]
for knowledge_id in knowledge_ids:
AccessGrants.revoke_all_access("knowledge", knowledge_id, db=db)
AccessGrants.revoke_all_access('knowledge', knowledge_id, db=db)
db.query(Knowledge).delete()
db.commit()

View File

@@ -13,7 +13,7 @@ from sqlalchemy import BigInteger, Column, String, Text
class Memory(Base):
__tablename__ = "memory"
__tablename__ = 'memory'
id = Column(String, primary_key=True, unique=True)
user_id = Column(String)
@@ -49,11 +49,11 @@ class MemoriesTable:
memory = MemoryModel(
**{
"id": id,
"user_id": user_id,
"content": content,
"created_at": int(time.time()),
"updated_at": int(time.time()),
'id': id,
'user_id': user_id,
'content': content,
'created_at': int(time.time()),
'updated_at': int(time.time()),
}
)
result = Memory(**memory.model_dump())
@@ -95,9 +95,7 @@ class MemoriesTable:
except Exception:
return None
def get_memories_by_user_id(
self, user_id: str, db: Optional[Session] = None
) -> list[MemoryModel]:
def get_memories_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[MemoryModel]:
with get_db_context(db) as db:
try:
memories = db.query(Memory).filter_by(user_id=user_id).all()
@@ -105,9 +103,7 @@ class MemoriesTable:
except Exception:
return None
def get_memory_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[MemoryModel]:
def get_memory_by_id(self, id: str, db: Optional[Session] = None) -> Optional[MemoryModel]:
with get_db_context(db) as db:
try:
memory = db.get(Memory, id)
@@ -126,9 +122,7 @@ class MemoriesTable:
except Exception:
return False
def delete_memories_by_user_id(
self, user_id: str, db: Optional[Session] = None
) -> bool:
def delete_memories_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
try:
db.query(Memory).filter_by(user_id=user_id).delete()
@@ -138,9 +132,7 @@ class MemoriesTable:
except Exception:
return False
def delete_memory_by_id_and_user_id(
self, id: str, user_id: str, db: Optional[Session] = None
) -> bool:
def delete_memory_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db:
try:
memory = db.get(Memory, id)

View File

@@ -21,7 +21,7 @@ from sqlalchemy.sql import exists
class MessageReaction(Base):
__tablename__ = "message_reaction"
__tablename__ = 'message_reaction'
id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text)
message_id = Column(Text)
@@ -40,7 +40,7 @@ class MessageReactionModel(BaseModel):
class Message(Base):
__tablename__ = "message"
__tablename__ = 'message'
id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text)
@@ -112,7 +112,7 @@ class MessageUserResponse(MessageModel):
class MessageUserSlimResponse(MessageUserResponse):
data: bool | None = None
@field_validator("data", mode="before")
@field_validator('data', mode='before')
def convert_data_to_bool(cls, v):
# No data or not a dict → False
if not isinstance(v, dict):
@@ -152,19 +152,19 @@ class MessageTable:
message = MessageModel(
**{
"id": id,
"user_id": user_id,
"channel_id": channel_id,
"reply_to_id": form_data.reply_to_id,
"parent_id": form_data.parent_id,
"is_pinned": False,
"pinned_at": None,
"pinned_by": None,
"content": form_data.content,
"data": form_data.data,
"meta": form_data.meta,
"created_at": ts,
"updated_at": ts,
'id': id,
'user_id': user_id,
'channel_id': channel_id,
'reply_to_id': form_data.reply_to_id,
'parent_id': form_data.parent_id,
'is_pinned': False,
'pinned_at': None,
'pinned_by': None,
'content': form_data.content,
'data': form_data.data,
'meta': form_data.meta,
'created_at': ts,
'updated_at': ts,
}
)
result = Message(**message.model_dump())
@@ -186,9 +186,7 @@ class MessageTable:
return None
reply_to_message = (
self.get_message_by_id(
message.reply_to_id, include_thread_replies=False, db=db
)
self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db)
if message.reply_to_id
else None
)
@@ -200,22 +198,22 @@ class MessageTable:
thread_replies = self.get_thread_replies_by_message_id(id, db=db)
# Check if message was sent by webhook (webhook info in meta takes precedence)
webhook_info = message.meta.get("webhook") if message.meta else None
if webhook_info and webhook_info.get("id"):
webhook_info = message.meta.get('webhook') if message.meta else None
if webhook_info and webhook_info.get('id'):
# Look up webhook by ID to get current name
webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db)
webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db)
if webhook:
user_info = {
"id": webhook.id,
"name": webhook.name,
"role": "webhook",
'id': webhook.id,
'name': webhook.name,
'role': 'webhook',
}
else:
# Webhook was deleted, use placeholder
user_info = {
"id": webhook_info.get("id"),
"name": "Deleted Webhook",
"role": "webhook",
'id': webhook_info.get('id'),
'name': 'Deleted Webhook',
'role': 'webhook',
}
else:
user = Users.get_user_by_id(message.user_id, db=db)
@@ -224,79 +222,57 @@ class MessageTable:
return MessageResponse.model_validate(
{
**MessageModel.model_validate(message).model_dump(),
"user": user_info,
"reply_to_message": (
reply_to_message.model_dump() if reply_to_message else None
),
"latest_reply_at": (
thread_replies[0].created_at if thread_replies else None
),
"reply_count": len(thread_replies),
"reactions": reactions,
'user': user_info,
'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None),
'latest_reply_at': (thread_replies[0].created_at if thread_replies else None),
'reply_count': len(thread_replies),
'reactions': reactions,
}
)
def get_thread_replies_by_message_id(
self, id: str, db: Optional[Session] = None
) -> list[MessageReplyToResponse]:
def get_thread_replies_by_message_id(self, id: str, db: Optional[Session] = None) -> list[MessageReplyToResponse]:
with get_db_context(db) as db:
all_messages = (
db.query(Message)
.filter_by(parent_id=id)
.order_by(Message.created_at.desc())
.all()
)
all_messages = db.query(Message).filter_by(parent_id=id).order_by(Message.created_at.desc()).all()
messages = []
for message in all_messages:
reply_to_message = (
self.get_message_by_id(
message.reply_to_id, include_thread_replies=False, db=db
)
self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db)
if message.reply_to_id
else None
)
webhook_info = message.meta.get("webhook") if message.meta else None
webhook_info = message.meta.get('webhook') if message.meta else None
user_info = None
if webhook_info and webhook_info.get("id"):
webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db)
if webhook_info and webhook_info.get('id'):
webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db)
if webhook:
user_info = {
"id": webhook.id,
"name": webhook.name,
"role": "webhook",
'id': webhook.id,
'name': webhook.name,
'role': 'webhook',
}
else:
user_info = {
"id": webhook_info.get("id"),
"name": "Deleted Webhook",
"role": "webhook",
'id': webhook_info.get('id'),
'name': 'Deleted Webhook',
'role': 'webhook',
}
messages.append(
MessageReplyToResponse.model_validate(
{
**MessageModel.model_validate(message).model_dump(),
"user": user_info,
"reply_to_message": (
reply_to_message.model_dump()
if reply_to_message
else None
),
'user': user_info,
'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None),
}
)
)
return messages
def get_reply_user_ids_by_message_id(
self, id: str, db: Optional[Session] = None
) -> list[str]:
def get_reply_user_ids_by_message_id(self, id: str, db: Optional[Session] = None) -> list[str]:
with get_db_context(db) as db:
return [
message.user_id
for message in db.query(Message).filter_by(parent_id=id).all()
]
return [message.user_id for message in db.query(Message).filter_by(parent_id=id).all()]
def get_messages_by_channel_id(
self,
@@ -318,40 +294,34 @@ class MessageTable:
messages = []
for message in all_messages:
reply_to_message = (
self.get_message_by_id(
message.reply_to_id, include_thread_replies=False, db=db
)
self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db)
if message.reply_to_id
else None
)
webhook_info = message.meta.get("webhook") if message.meta else None
webhook_info = message.meta.get('webhook') if message.meta else None
user_info = None
if webhook_info and webhook_info.get("id"):
webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db)
if webhook_info and webhook_info.get('id'):
webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db)
if webhook:
user_info = {
"id": webhook.id,
"name": webhook.name,
"role": "webhook",
'id': webhook.id,
'name': webhook.name,
'role': 'webhook',
}
else:
user_info = {
"id": webhook_info.get("id"),
"name": "Deleted Webhook",
"role": "webhook",
'id': webhook_info.get('id'),
'name': 'Deleted Webhook',
'role': 'webhook',
}
messages.append(
MessageReplyToResponse.model_validate(
{
**MessageModel.model_validate(message).model_dump(),
"user": user_info,
"reply_to_message": (
reply_to_message.model_dump()
if reply_to_message
else None
),
'user': user_info,
'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None),
}
)
)
@@ -387,55 +357,42 @@ class MessageTable:
messages = []
for message in all_messages:
reply_to_message = (
self.get_message_by_id(
message.reply_to_id, include_thread_replies=False, db=db
)
self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db)
if message.reply_to_id
else None
)
webhook_info = message.meta.get("webhook") if message.meta else None
webhook_info = message.meta.get('webhook') if message.meta else None
user_info = None
if webhook_info and webhook_info.get("id"):
webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db)
if webhook_info and webhook_info.get('id'):
webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db)
if webhook:
user_info = {
"id": webhook.id,
"name": webhook.name,
"role": "webhook",
'id': webhook.id,
'name': webhook.name,
'role': 'webhook',
}
else:
user_info = {
"id": webhook_info.get("id"),
"name": "Deleted Webhook",
"role": "webhook",
'id': webhook_info.get('id'),
'name': 'Deleted Webhook',
'role': 'webhook',
}
messages.append(
MessageReplyToResponse.model_validate(
{
**MessageModel.model_validate(message).model_dump(),
"user": user_info,
"reply_to_message": (
reply_to_message.model_dump()
if reply_to_message
else None
),
'user': user_info,
'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None),
}
)
)
return messages
def get_last_message_by_channel_id(
self, channel_id: str, db: Optional[Session] = None
) -> Optional[MessageModel]:
def get_last_message_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> Optional[MessageModel]:
with get_db_context(db) as db:
message = (
db.query(Message)
.filter_by(channel_id=channel_id)
.order_by(Message.created_at.desc())
.first()
)
message = db.query(Message).filter_by(channel_id=channel_id).order_by(Message.created_at.desc()).first()
return MessageModel.model_validate(message) if message else None
def get_pinned_messages_by_channel_id(
@@ -513,11 +470,7 @@ class MessageTable:
) -> Optional[MessageReactionModel]:
with get_db_context(db) as db:
# check for existing reaction
existing_reaction = (
db.query(MessageReaction)
.filter_by(message_id=id, user_id=user_id, name=name)
.first()
)
existing_reaction = db.query(MessageReaction).filter_by(message_id=id, user_id=user_id, name=name).first()
if existing_reaction:
return MessageReactionModel.model_validate(existing_reaction)
@@ -535,9 +488,7 @@ class MessageTable:
db.refresh(result)
return MessageReactionModel.model_validate(result) if result else None
def get_reactions_by_message_id(
self, id: str, db: Optional[Session] = None
) -> list[Reactions]:
def get_reactions_by_message_id(self, id: str, db: Optional[Session] = None) -> list[Reactions]:
with get_db_context(db) as db:
# JOIN User so all user info is fetched in one query
results = (
@@ -552,18 +503,18 @@ class MessageTable:
for reaction, user in results:
if reaction.name not in reactions:
reactions[reaction.name] = {
"name": reaction.name,
"users": [],
"count": 0,
'name': reaction.name,
'users': [],
'count': 0,
}
reactions[reaction.name]["users"].append(
reactions[reaction.name]['users'].append(
{
"id": user.id,
"name": user.name,
'id': user.id,
'name': user.name,
}
)
reactions[reaction.name]["count"] += 1
reactions[reaction.name]['count'] += 1
return [Reactions(**reaction) for reaction in reactions.values()]
@@ -571,9 +522,7 @@ class MessageTable:
self, id: str, user_id: str, name: str, db: Optional[Session] = None
) -> bool:
with get_db_context(db) as db:
db.query(MessageReaction).filter_by(
message_id=id, user_id=user_id, name=name
).delete()
db.query(MessageReaction).filter_by(message_id=id, user_id=user_id, name=name).delete()
db.commit()
return True
@@ -612,21 +561,15 @@ class MessageTable:
with get_db_context(db) as db:
query_builder = db.query(Message).filter(
Message.channel_id.in_(channel_ids),
Message.content.ilike(f"%{query}%"),
Message.content.ilike(f'%{query}%'),
)
if start_timestamp:
query_builder = query_builder.filter(
Message.created_at >= start_timestamp
)
query_builder = query_builder.filter(Message.created_at >= start_timestamp)
if end_timestamp:
query_builder = query_builder.filter(
Message.created_at <= end_timestamp
)
query_builder = query_builder.filter(Message.created_at <= end_timestamp)
messages = (
query_builder.order_by(Message.created_at.desc()).limit(limit).all()
)
messages = query_builder.order_by(Message.created_at.desc()).limit(limit).all()
return [MessageModel.model_validate(msg) for msg in messages]

View File

@@ -28,13 +28,13 @@ log = logging.getLogger(__name__)
# ModelParams is a model for the data stored in the params field of the Model table
class ModelParams(BaseModel):
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra='allow')
pass
# ModelMeta is a model for the data stored in the meta field of the Model table
class ModelMeta(BaseModel):
profile_image_url: Optional[str] = "/static/favicon.png"
profile_image_url: Optional[str] = '/static/favicon.png'
description: Optional[str] = None
"""
@@ -43,13 +43,13 @@ class ModelMeta(BaseModel):
capabilities: Optional[dict] = None
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra='allow')
pass
class Model(Base):
__tablename__ = "model"
__tablename__ = 'model'
id = Column(Text, primary_key=True, unique=True)
"""
@@ -139,10 +139,8 @@ class ModelForm(BaseModel):
class ModelsTable:
def _get_access_grants(
self, model_id: str, db: Optional[Session] = None
) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource("model", model_id, db=db)
def _get_access_grants(self, model_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource('model', model_id, db=db)
def _to_model_model(
self,
@@ -150,13 +148,9 @@ class ModelsTable:
access_grants: Optional[list[AccessGrantModel]] = None,
db: Optional[Session] = None,
) -> ModelModel:
model_data = ModelModel.model_validate(model).model_dump(
exclude={"access_grants"}
)
model_data["access_grants"] = (
access_grants
if access_grants is not None
else self._get_access_grants(model_data["id"], db=db)
model_data = ModelModel.model_validate(model).model_dump(exclude={'access_grants'})
model_data['access_grants'] = (
access_grants if access_grants is not None else self._get_access_grants(model_data['id'], db=db)
)
return ModelModel.model_validate(model_data)
@@ -167,37 +161,32 @@ class ModelsTable:
with get_db_context(db) as db:
result = Model(
**{
**form_data.model_dump(exclude={"access_grants"}),
"user_id": user_id,
"created_at": int(time.time()),
"updated_at": int(time.time()),
**form_data.model_dump(exclude={'access_grants'}),
'user_id': user_id,
'created_at': int(time.time()),
'updated_at': int(time.time()),
}
)
db.add(result)
db.commit()
db.refresh(result)
AccessGrants.set_access_grants(
"model", result.id, form_data.access_grants, db=db
)
AccessGrants.set_access_grants('model', result.id, form_data.access_grants, db=db)
if result:
return self._to_model_model(result, db=db)
else:
return None
except Exception as e:
log.exception(f"Failed to insert a new model: {e}")
log.exception(f'Failed to insert a new model: {e}')
return None
def get_all_models(self, db: Optional[Session] = None) -> list[ModelModel]:
with get_db_context(db) as db:
all_models = db.query(Model).all()
model_ids = [model.id for model in all_models]
grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db)
grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
return [
self._to_model_model(
model, access_grants=grants_map.get(model.id, []), db=db
)
for model in all_models
self._to_model_model(model, access_grants=grants_map.get(model.id, []), db=db) for model in all_models
]
def get_models(self, db: Optional[Session] = None) -> list[ModelUserResponse]:
@@ -209,7 +198,7 @@ class ModelsTable:
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
users_dict = {user.id: user for user in users}
grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db)
grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
models = []
for model in all_models:
@@ -222,7 +211,7 @@ class ModelsTable:
access_grants=grants_map.get(model.id, []),
db=db,
).model_dump(),
"user": user.model_dump() if user else None,
'user': user.model_dump() if user else None,
}
)
)
@@ -232,28 +221,23 @@ class ModelsTable:
with get_db_context(db) as db:
all_models = db.query(Model).filter(Model.base_model_id == None).all()
model_ids = [model.id for model in all_models]
grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db)
grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
return [
self._to_model_model(
model, access_grants=grants_map.get(model.id, []), db=db
)
for model in all_models
self._to_model_model(model, access_grants=grants_map.get(model.id, []), db=db) for model in all_models
]
def get_models_by_user_id(
self, user_id: str, permission: str = "write", db: Optional[Session] = None
self, user_id: str, permission: str = 'write', db: Optional[Session] = None
) -> list[ModelUserResponse]:
models = self.get_models(db=db)
user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
}
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
return [
model
for model in models
if model.user_id == user_id
or AccessGrants.has_access(
user_id=user_id,
resource_type="model",
resource_type='model',
resource_id=model.id,
permission=permission,
user_group_ids=user_group_ids,
@@ -261,13 +245,13 @@ class ModelsTable:
)
]
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
def _has_permission(self, db, query, filter: dict, permission: str = 'read'):
return AccessGrants.has_permission_filter(
db=db,
query=query,
DocumentModel=Model,
filter=filter,
resource_type="model",
resource_type='model',
permission=permission,
)
@@ -285,22 +269,22 @@ class ModelsTable:
query = query.filter(Model.base_model_id != None)
if filter:
query_key = filter.get("query")
query_key = filter.get('query')
if query_key:
query = query.filter(
or_(
Model.name.ilike(f"%{query_key}%"),
Model.base_model_id.ilike(f"%{query_key}%"),
User.name.ilike(f"%{query_key}%"),
User.email.ilike(f"%{query_key}%"),
User.username.ilike(f"%{query_key}%"),
Model.name.ilike(f'%{query_key}%'),
Model.base_model_id.ilike(f'%{query_key}%'),
User.name.ilike(f'%{query_key}%'),
User.email.ilike(f'%{query_key}%'),
User.username.ilike(f'%{query_key}%'),
)
)
view_option = filter.get("view_option")
if view_option == "created":
view_option = filter.get('view_option')
if view_option == 'created':
query = query.filter(Model.user_id == user_id)
elif view_option == "shared":
elif view_option == 'shared':
query = query.filter(Model.user_id != user_id)
# Apply access control filtering
@@ -308,10 +292,10 @@ class ModelsTable:
db,
query,
filter,
permission="read",
permission='read',
)
tag = filter.get("tag")
tag = filter.get('tag')
if tag:
# TODO: This is a simple implementation and should be improved for performance
like_pattern = f'%"{tag.lower()}"%' # `"tag"` inside JSON array
@@ -319,21 +303,21 @@ class ModelsTable:
query = query.filter(meta_text.like(like_pattern))
order_by = filter.get("order_by")
direction = filter.get("direction")
order_by = filter.get('order_by')
direction = filter.get('direction')
if order_by == "name":
if direction == "asc":
if order_by == 'name':
if direction == 'asc':
query = query.order_by(Model.name.asc())
else:
query = query.order_by(Model.name.desc())
elif order_by == "created_at":
if direction == "asc":
elif order_by == 'created_at':
if direction == 'asc':
query = query.order_by(Model.created_at.asc())
else:
query = query.order_by(Model.created_at.desc())
elif order_by == "updated_at":
if direction == "asc":
elif order_by == 'updated_at':
if direction == 'asc':
query = query.order_by(Model.updated_at.asc())
else:
query = query.order_by(Model.updated_at.desc())
@@ -352,7 +336,7 @@ class ModelsTable:
items = query.all()
model_ids = [model.id for model, _ in items]
grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db)
grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
models = []
for model, user in items:
@@ -363,19 +347,13 @@ class ModelsTable:
access_grants=grants_map.get(model.id, []),
db=db,
).model_dump(),
user=(
UserResponse(**UserModel.model_validate(user).model_dump())
if user
else None
),
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
)
)
return ModelListResponse(items=models, total=total)
def get_model_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[ModelModel]:
def get_model_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ModelModel]:
try:
with get_db_context(db) as db:
model = db.get(Model, id)
@@ -383,16 +361,12 @@ class ModelsTable:
except Exception:
return None
def get_models_by_ids(
self, ids: list[str], db: Optional[Session] = None
) -> list[ModelModel]:
def get_models_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[ModelModel]:
try:
with get_db_context(db) as db:
models = db.query(Model).filter(Model.id.in_(ids)).all()
model_ids = [model.id for model in models]
grants_map = AccessGrants.get_grants_by_resources(
"model", model_ids, db=db
)
grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
return [
self._to_model_model(
model,
@@ -404,9 +378,7 @@ class ModelsTable:
except Exception:
return []
def toggle_model_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[ModelModel]:
def toggle_model_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ModelModel]:
with get_db_context(db) as db:
try:
model = db.query(Model).filter_by(id=id).first()
@@ -422,30 +394,26 @@ class ModelsTable:
except Exception:
return None
def update_model_by_id(
self, id: str, model: ModelForm, db: Optional[Session] = None
) -> Optional[ModelModel]:
def update_model_by_id(self, id: str, model: ModelForm, db: Optional[Session] = None) -> Optional[ModelModel]:
try:
with get_db_context(db) as db:
# update only the fields that are present in the model
data = model.model_dump(exclude={"id", "access_grants"})
data = model.model_dump(exclude={'id', 'access_grants'})
result = db.query(Model).filter_by(id=id).update(data)
db.commit()
if model.access_grants is not None:
AccessGrants.set_access_grants(
"model", id, model.access_grants, db=db
)
AccessGrants.set_access_grants('model', id, model.access_grants, db=db)
return self.get_model_by_id(id, db=db)
except Exception as e:
log.exception(f"Failed to update the model by id {id}: {e}")
log.exception(f'Failed to update the model by id {id}: {e}')
return None
def delete_model_by_id(self, id: str, db: Optional[Session] = None) -> bool:
try:
with get_db_context(db) as db:
AccessGrants.revoke_all_access("model", id, db=db)
AccessGrants.revoke_all_access('model', id, db=db)
db.query(Model).filter_by(id=id).delete()
db.commit()
@@ -458,7 +426,7 @@ class ModelsTable:
with get_db_context(db) as db:
model_ids = [row[0] for row in db.query(Model.id).all()]
for model_id in model_ids:
AccessGrants.revoke_all_access("model", model_id, db=db)
AccessGrants.revoke_all_access('model', model_id, db=db)
db.query(Model).delete()
db.commit()
@@ -466,9 +434,7 @@ class ModelsTable:
except Exception:
return False
def sync_models(
self, user_id: str, models: list[ModelModel], db: Optional[Session] = None
) -> list[ModelModel]:
def sync_models(self, user_id: str, models: list[ModelModel], db: Optional[Session] = None) -> list[ModelModel]:
try:
with get_db_context(db) as db:
# Get existing models
@@ -483,37 +449,33 @@ class ModelsTable:
if model.id in existing_ids:
db.query(Model).filter_by(id=model.id).update(
{
**model.model_dump(exclude={"access_grants"}),
"user_id": user_id,
"updated_at": int(time.time()),
**model.model_dump(exclude={'access_grants'}),
'user_id': user_id,
'updated_at': int(time.time()),
}
)
else:
new_model = Model(
**{
**model.model_dump(exclude={"access_grants"}),
"user_id": user_id,
"updated_at": int(time.time()),
**model.model_dump(exclude={'access_grants'}),
'user_id': user_id,
'updated_at': int(time.time()),
}
)
db.add(new_model)
AccessGrants.set_access_grants(
"model", model.id, model.access_grants, db=db
)
AccessGrants.set_access_grants('model', model.id, model.access_grants, db=db)
# Remove models that are no longer present
for model in existing_models:
if model.id not in new_model_ids:
AccessGrants.revoke_all_access("model", model.id, db=db)
AccessGrants.revoke_all_access('model', model.id, db=db)
db.delete(model)
db.commit()
all_models = db.query(Model).all()
model_ids = [model.id for model in all_models]
grants_map = AccessGrants.get_grants_by_resources(
"model", model_ids, db=db
)
grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
return [
self._to_model_model(
model,
@@ -523,7 +485,7 @@ class ModelsTable:
for model in all_models
]
except Exception as e:
log.exception(f"Error syncing models for user {user_id}: {e}")
log.exception(f'Error syncing models for user {user_id}: {e}')
return []

View File

@@ -21,7 +21,7 @@ from sqlalchemy import or_, func, cast
class Note(Base):
__tablename__ = "note"
__tablename__ = 'note'
id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text)
@@ -88,10 +88,8 @@ class NoteListResponse(BaseModel):
class NoteTable:
def _get_access_grants(
self, note_id: str, db: Optional[Session] = None
) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource("note", note_id, db=db)
def _get_access_grants(self, note_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource('note', note_id, db=db)
def _to_note_model(
self,
@@ -99,51 +97,43 @@ class NoteTable:
access_grants: Optional[list[AccessGrantModel]] = None,
db: Optional[Session] = None,
) -> NoteModel:
note_data = NoteModel.model_validate(note).model_dump(exclude={"access_grants"})
note_data["access_grants"] = (
access_grants
if access_grants is not None
else self._get_access_grants(note_data["id"], db=db)
note_data = NoteModel.model_validate(note).model_dump(exclude={'access_grants'})
note_data['access_grants'] = (
access_grants if access_grants is not None else self._get_access_grants(note_data['id'], db=db)
)
return NoteModel.model_validate(note_data)
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
def _has_permission(self, db, query, filter: dict, permission: str = 'read'):
return AccessGrants.has_permission_filter(
db=db,
query=query,
DocumentModel=Note,
filter=filter,
resource_type="note",
resource_type='note',
permission=permission,
)
def insert_new_note(
self, user_id: str, form_data: NoteForm, db: Optional[Session] = None
) -> Optional[NoteModel]:
def insert_new_note(self, user_id: str, form_data: NoteForm, db: Optional[Session] = None) -> Optional[NoteModel]:
with get_db_context(db) as db:
note = NoteModel(
**{
"id": str(uuid.uuid4()),
"user_id": user_id,
**form_data.model_dump(exclude={"access_grants"}),
"created_at": int(time.time_ns()),
"updated_at": int(time.time_ns()),
"access_grants": [],
'id': str(uuid.uuid4()),
'user_id': user_id,
**form_data.model_dump(exclude={'access_grants'}),
'created_at': int(time.time_ns()),
'updated_at': int(time.time_ns()),
'access_grants': [],
}
)
new_note = Note(**note.model_dump(exclude={"access_grants"}))
new_note = Note(**note.model_dump(exclude={'access_grants'}))
db.add(new_note)
db.commit()
AccessGrants.set_access_grants(
"note", note.id, form_data.access_grants, db=db
)
AccessGrants.set_access_grants('note', note.id, form_data.access_grants, db=db)
return self._to_note_model(new_note, db=db)
def get_notes(
self, skip: int = 0, limit: int = 50, db: Optional[Session] = None
) -> list[NoteModel]:
def get_notes(self, skip: int = 0, limit: int = 50, db: Optional[Session] = None) -> list[NoteModel]:
with get_db_context(db) as db:
query = db.query(Note).order_by(Note.updated_at.desc())
if skip is not None:
@@ -152,13 +142,8 @@ class NoteTable:
query = query.limit(limit)
notes = query.all()
note_ids = [note.id for note in notes]
grants_map = AccessGrants.get_grants_by_resources("note", note_ids, db=db)
return [
self._to_note_model(
note, access_grants=grants_map.get(note.id, []), db=db
)
for note in notes
]
grants_map = AccessGrants.get_grants_by_resources('note', note_ids, db=db)
return [self._to_note_model(note, access_grants=grants_map.get(note.id, []), db=db) for note in notes]
def search_notes(
self,
@@ -171,36 +156,32 @@ class NoteTable:
with get_db_context(db) as db:
query = db.query(Note, User).outerjoin(User, User.id == Note.user_id)
if filter:
query_key = filter.get("query")
query_key = filter.get('query')
if query_key:
# Normalize search by removing hyphens and spaces (e.g., "todo" matches "to-do" and "to do")
normalized_query = query_key.replace("-", "").replace(" ", "")
normalized_query = query_key.replace('-', '').replace(' ', '')
query = query.filter(
or_(
func.replace(func.replace(Note.title, '-', ''), ' ', '').ilike(f'%{normalized_query}%'),
func.replace(
func.replace(Note.title, "-", ""), " ", ""
).ilike(f"%{normalized_query}%"),
func.replace(
func.replace(
cast(Note.data["content"]["md"], Text), "-", ""
),
" ",
"",
).ilike(f"%{normalized_query}%"),
func.replace(cast(Note.data['content']['md'], Text), '-', ''),
' ',
'',
).ilike(f'%{normalized_query}%'),
)
)
view_option = filter.get("view_option")
if view_option == "created":
view_option = filter.get('view_option')
if view_option == 'created':
query = query.filter(Note.user_id == user_id)
elif view_option == "shared":
elif view_option == 'shared':
query = query.filter(Note.user_id != user_id)
# Apply access control filtering
if "permission" in filter:
permission = filter["permission"]
if 'permission' in filter:
permission = filter['permission']
else:
permission = "write"
permission = 'write'
query = self._has_permission(
db,
@@ -209,21 +190,21 @@ class NoteTable:
permission=permission,
)
order_by = filter.get("order_by")
direction = filter.get("direction")
order_by = filter.get('order_by')
direction = filter.get('direction')
if order_by == "name":
if direction == "asc":
if order_by == 'name':
if direction == 'asc':
query = query.order_by(Note.title.asc())
else:
query = query.order_by(Note.title.desc())
elif order_by == "created_at":
if direction == "asc":
elif order_by == 'created_at':
if direction == 'asc':
query = query.order_by(Note.created_at.asc())
else:
query = query.order_by(Note.created_at.desc())
elif order_by == "updated_at":
if direction == "asc":
elif order_by == 'updated_at':
if direction == 'asc':
query = query.order_by(Note.updated_at.asc())
else:
query = query.order_by(Note.updated_at.desc())
@@ -244,7 +225,7 @@ class NoteTable:
items = query.all()
note_ids = [note.id for note, _ in items]
grants_map = AccessGrants.get_grants_by_resources("note", note_ids, db=db)
grants_map = AccessGrants.get_grants_by_resources('note', note_ids, db=db)
notes = []
for note, user in items:
@@ -255,11 +236,7 @@ class NoteTable:
access_grants=grants_map.get(note.id, []),
db=db,
).model_dump(),
user=(
UserResponse(**UserModel.model_validate(user).model_dump())
if user
else None
),
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
)
)
@@ -268,20 +245,16 @@ class NoteTable:
def get_notes_by_user_id(
self,
user_id: str,
permission: str = "read",
permission: str = 'read',
skip: int = 0,
limit: int = 50,
db: Optional[Session] = None,
) -> list[NoteModel]:
with get_db_context(db) as db:
user_group_ids = [
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
]
user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id, db=db)]
query = db.query(Note).order_by(Note.updated_at.desc())
query = self._has_permission(
db, query, {"user_id": user_id, "group_ids": user_group_ids}, permission
)
query = self._has_permission(db, query, {'user_id': user_id, 'group_ids': user_group_ids}, permission)
if skip is not None:
query = query.offset(skip)
@@ -290,17 +263,10 @@ class NoteTable:
notes = query.all()
note_ids = [note.id for note in notes]
grants_map = AccessGrants.get_grants_by_resources("note", note_ids, db=db)
return [
self._to_note_model(
note, access_grants=grants_map.get(note.id, []), db=db
)
for note in notes
]
grants_map = AccessGrants.get_grants_by_resources('note', note_ids, db=db)
return [self._to_note_model(note, access_grants=grants_map.get(note.id, []), db=db) for note in notes]
def get_note_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[NoteModel]:
def get_note_by_id(self, id: str, db: Optional[Session] = None) -> Optional[NoteModel]:
with get_db_context(db) as db:
note = db.query(Note).filter(Note.id == id).first()
return self._to_note_model(note, db=db) if note else None
@@ -315,17 +281,15 @@ class NoteTable:
form_data = form_data.model_dump(exclude_unset=True)
if "title" in form_data:
note.title = form_data["title"]
if "data" in form_data:
note.data = {**note.data, **form_data["data"]}
if "meta" in form_data:
note.meta = {**note.meta, **form_data["meta"]}
if 'title' in form_data:
note.title = form_data['title']
if 'data' in form_data:
note.data = {**note.data, **form_data['data']}
if 'meta' in form_data:
note.meta = {**note.meta, **form_data['meta']}
if "access_grants" in form_data:
AccessGrants.set_access_grants(
"note", id, form_data["access_grants"], db=db
)
if 'access_grants' in form_data:
AccessGrants.set_access_grants('note', id, form_data['access_grants'], db=db)
note.updated_at = int(time.time_ns())
@@ -335,7 +299,7 @@ class NoteTable:
def delete_note_by_id(self, id: str, db: Optional[Session] = None) -> bool:
try:
with get_db_context(db) as db:
AccessGrants.revoke_all_access("note", id, db=db)
AccessGrants.revoke_all_access('note', id, db=db)
db.query(Note).filter(Note.id == id).delete()
db.commit()
return True

View File

@@ -23,23 +23,21 @@ log = logging.getLogger(__name__)
class OAuthSession(Base):
__tablename__ = "oauth_session"
__tablename__ = 'oauth_session'
id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text, nullable=False)
provider = Column(Text, nullable=False)
token = Column(
Text, nullable=False
) # JSON with access_token, id_token, refresh_token
token = Column(Text, nullable=False) # JSON with access_token, id_token, refresh_token
expires_at = Column(BigInteger, nullable=False)
created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False)
# Add indexes for better performance
__table_args__ = (
Index("idx_oauth_session_user_id", "user_id"),
Index("idx_oauth_session_expires_at", "expires_at"),
Index("idx_oauth_session_user_provider", "user_id", "provider"),
Index('idx_oauth_session_user_id', 'user_id'),
Index('idx_oauth_session_expires_at', 'expires_at'),
Index('idx_oauth_session_user_provider', 'user_id', 'provider'),
)
@@ -71,7 +69,7 @@ class OAuthSessionTable:
def __init__(self):
self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
if not self.encryption_key:
raise Exception("OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set")
raise Exception('OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set')
# check if encryption key is in the right format for Fernet (32 url-safe base64-encoded bytes)
if len(self.encryption_key) != 44:
@@ -83,7 +81,7 @@ class OAuthSessionTable:
try:
self.fernet = Fernet(self.encryption_key)
except Exception as e:
log.error(f"Error initializing Fernet with provided key: {e}")
log.error(f'Error initializing Fernet with provided key: {e}')
raise
def _encrypt_token(self, token) -> str:
@@ -93,7 +91,7 @@ class OAuthSessionTable:
encrypted = self.fernet.encrypt(token_json.encode()).decode()
return encrypted
except Exception as e:
log.error(f"Error encrypting tokens: {e}")
log.error(f'Error encrypting tokens: {e}')
raise
def _decrypt_token(self, token: str):
@@ -102,7 +100,7 @@ class OAuthSessionTable:
decrypted = self.fernet.decrypt(token.encode()).decode()
return json.loads(decrypted)
except Exception as e:
log.error(f"Error decrypting tokens: {type(e).__name__}: {e}")
log.error(f'Error decrypting tokens: {type(e).__name__}: {e}')
raise
def create_session(
@@ -120,13 +118,13 @@ class OAuthSessionTable:
result = OAuthSession(
**{
"id": id,
"user_id": user_id,
"provider": provider,
"token": self._encrypt_token(token),
"expires_at": token.get("expires_at"),
"created_at": current_time,
"updated_at": current_time,
'id': id,
'user_id': user_id,
'provider': provider,
'token': self._encrypt_token(token),
'expires_at': token.get('expires_at'),
'created_at': current_time,
'updated_at': current_time,
}
)
@@ -141,12 +139,10 @@ class OAuthSessionTable:
else:
return None
except Exception as e:
log.error(f"Error creating OAuth session: {e}")
log.error(f'Error creating OAuth session: {e}')
return None
def get_session_by_id(
self, session_id: str, db: Optional[Session] = None
) -> Optional[OAuthSessionModel]:
def get_session_by_id(self, session_id: str, db: Optional[Session] = None) -> Optional[OAuthSessionModel]:
"""Get OAuth session by ID"""
try:
with get_db_context(db) as db:
@@ -158,7 +154,7 @@ class OAuthSessionTable:
return None
except Exception as e:
log.error(f"Error getting OAuth session by ID: {e}")
log.error(f'Error getting OAuth session by ID: {e}')
return None
def get_session_by_id_and_user_id(
@@ -167,11 +163,7 @@ class OAuthSessionTable:
"""Get OAuth session by ID and user ID"""
try:
with get_db_context(db) as db:
session = (
db.query(OAuthSession)
.filter_by(id=session_id, user_id=user_id)
.first()
)
session = db.query(OAuthSession).filter_by(id=session_id, user_id=user_id).first()
if session:
db.expunge(session)
session.token = self._decrypt_token(session.token)
@@ -179,7 +171,7 @@ class OAuthSessionTable:
return None
except Exception as e:
log.error(f"Error getting OAuth session by ID: {e}")
log.error(f'Error getting OAuth session by ID: {e}')
return None
def get_session_by_provider_and_user_id(
@@ -201,12 +193,10 @@ class OAuthSessionTable:
return None
except Exception as e:
log.error(f"Error getting OAuth session by provider and user ID: {e}")
log.error(f'Error getting OAuth session by provider and user ID: {e}')
return None
def get_sessions_by_user_id(
self, user_id: str, db: Optional[Session] = None
) -> List[OAuthSessionModel]:
def get_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> List[OAuthSessionModel]:
"""Get all OAuth sessions for a user"""
try:
with get_db_context(db) as db:
@@ -220,7 +210,7 @@ class OAuthSessionTable:
results.append(OAuthSessionModel.model_validate(session))
except Exception as e:
log.warning(
f"Skipping OAuth session {session.id} due to decryption failure, deleting corrupted session: {type(e).__name__}: {e}"
f'Skipping OAuth session {session.id} due to decryption failure, deleting corrupted session: {type(e).__name__}: {e}'
)
db.query(OAuthSession).filter_by(id=session.id).delete()
db.commit()
@@ -228,7 +218,7 @@ class OAuthSessionTable:
return results
except Exception as e:
log.error(f"Error getting OAuth sessions by user ID: {e}")
log.error(f'Error getting OAuth sessions by user ID: {e}')
return []
def update_session_by_id(
@@ -241,9 +231,9 @@ class OAuthSessionTable:
db.query(OAuthSession).filter_by(id=session_id).update(
{
"token": self._encrypt_token(token),
"expires_at": token.get("expires_at"),
"updated_at": current_time,
'token': self._encrypt_token(token),
'expires_at': token.get('expires_at'),
'updated_at': current_time,
}
)
db.commit()
@@ -256,12 +246,10 @@ class OAuthSessionTable:
return None
except Exception as e:
log.error(f"Error updating OAuth session tokens: {e}")
log.error(f'Error updating OAuth session tokens: {e}')
return None
def delete_session_by_id(
self, session_id: str, db: Optional[Session] = None
) -> bool:
def delete_session_by_id(self, session_id: str, db: Optional[Session] = None) -> bool:
"""Delete an OAuth session"""
try:
with get_db_context(db) as db:
@@ -269,12 +257,10 @@ class OAuthSessionTable:
db.commit()
return result > 0
except Exception as e:
log.error(f"Error deleting OAuth session: {e}")
log.error(f'Error deleting OAuth session: {e}')
return False
def delete_sessions_by_user_id(
self, user_id: str, db: Optional[Session] = None
) -> bool:
def delete_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
"""Delete all OAuth sessions for a user"""
try:
with get_db_context(db) as db:
@@ -282,12 +268,10 @@ class OAuthSessionTable:
db.commit()
return True
except Exception as e:
log.error(f"Error deleting OAuth sessions by user ID: {e}")
log.error(f'Error deleting OAuth sessions by user ID: {e}')
return False
def delete_sessions_by_provider(
self, provider: str, db: Optional[Session] = None
) -> bool:
def delete_sessions_by_provider(self, provider: str, db: Optional[Session] = None) -> bool:
"""Delete all OAuth sessions for a provider"""
try:
with get_db_context(db) as db:
@@ -295,7 +279,7 @@ class OAuthSessionTable:
db.commit()
return True
except Exception as e:
log.error(f"Error deleting OAuth sessions by provider {provider}: {e}")
log.error(f'Error deleting OAuth sessions by provider {provider}: {e}')
return False

View File

@@ -19,7 +19,7 @@ from sqlalchemy import BigInteger, Column, Text, JSON, Index
class PromptHistory(Base):
__tablename__ = "prompt_history"
__tablename__ = 'prompt_history'
id = Column(Text, primary_key=True)
prompt_id = Column(Text, nullable=False, index=True)
@@ -100,11 +100,7 @@ class PromptHistoryTable:
return [
PromptHistoryResponse(
**PromptHistoryModel.model_validate(entry).model_dump(),
user=(
users_dict.get(entry.user_id).model_dump()
if users_dict.get(entry.user_id)
else None
),
user=(users_dict.get(entry.user_id).model_dump() if users_dict.get(entry.user_id) else None),
)
for entry in entries
]
@@ -116,9 +112,7 @@ class PromptHistoryTable:
) -> Optional[PromptHistoryModel]:
"""Get a specific history entry by ID."""
with get_db_context(db) as db:
entry = (
db.query(PromptHistory).filter(PromptHistory.id == history_id).first()
)
entry = db.query(PromptHistory).filter(PromptHistory.id == history_id).first()
if entry:
return PromptHistoryModel.model_validate(entry)
return None
@@ -147,11 +141,7 @@ class PromptHistoryTable:
) -> int:
"""Get the number of history entries for a prompt."""
with get_db_context(db) as db:
return (
db.query(PromptHistory)
.filter(PromptHistory.prompt_id == prompt_id)
.count()
)
return db.query(PromptHistory).filter(PromptHistory.prompt_id == prompt_id).count()
def compute_diff(
self,
@@ -161,9 +151,7 @@ class PromptHistoryTable:
) -> Optional[dict]:
"""Compute diff between two history entries."""
with get_db_context(db) as db:
from_entry = (
db.query(PromptHistory).filter(PromptHistory.id == from_id).first()
)
from_entry = db.query(PromptHistory).filter(PromptHistory.id == from_id).first()
to_entry = db.query(PromptHistory).filter(PromptHistory.id == to_id).first()
if not from_entry or not to_entry:
@@ -173,26 +161,26 @@ class PromptHistoryTable:
to_snapshot = to_entry.snapshot
# Compute diff for content field
from_content = from_snapshot.get("content", "")
to_content = to_snapshot.get("content", "")
from_content = from_snapshot.get('content', '')
to_content = to_snapshot.get('content', '')
diff_lines = list(
difflib.unified_diff(
from_content.splitlines(keepends=True),
to_content.splitlines(keepends=True),
fromfile=f"v{from_id[:8]}",
tofile=f"v{to_id[:8]}",
lineterm="",
fromfile=f'v{from_id[:8]}',
tofile=f'v{to_id[:8]}',
lineterm='',
)
)
return {
"from_id": from_id,
"to_id": to_id,
"from_snapshot": from_snapshot,
"to_snapshot": to_snapshot,
"content_diff": diff_lines,
"name_changed": from_snapshot.get("name") != to_snapshot.get("name"),
'from_id': from_id,
'to_id': to_id,
'from_snapshot': from_snapshot,
'to_snapshot': to_snapshot,
'content_diff': diff_lines,
'name_changed': from_snapshot.get('name') != to_snapshot.get('name'),
}
def delete_history_by_prompt_id(
@@ -202,9 +190,7 @@ class PromptHistoryTable:
) -> bool:
"""Delete all history entries for a prompt."""
with get_db_context(db) as db:
db.query(PromptHistory).filter(
PromptHistory.prompt_id == prompt_id
).delete()
db.query(PromptHistory).filter(PromptHistory.prompt_id == prompt_id).delete()
db.commit()
return True

View File

@@ -19,7 +19,7 @@ from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, or_, fun
class Prompt(Base):
__tablename__ = "prompt"
__tablename__ = 'prompt'
id = Column(Text, primary_key=True)
command = Column(String, unique=True, index=True)
@@ -77,7 +77,6 @@ class PromptAccessListResponse(BaseModel):
class PromptForm(BaseModel):
command: str
name: str # Changed from title
content: str
@@ -91,10 +90,8 @@ class PromptForm(BaseModel):
class PromptsTable:
def _get_access_grants(
self, prompt_id: str, db: Optional[Session] = None
) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource("prompt", prompt_id, db=db)
def _get_access_grants(self, prompt_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource('prompt', prompt_id, db=db)
def _to_prompt_model(
self,
@@ -102,13 +99,9 @@ class PromptsTable:
access_grants: Optional[list[AccessGrantModel]] = None,
db: Optional[Session] = None,
) -> PromptModel:
prompt_data = PromptModel.model_validate(prompt).model_dump(
exclude={"access_grants"}
)
prompt_data["access_grants"] = (
access_grants
if access_grants is not None
else self._get_access_grants(prompt_data["id"], db=db)
prompt_data = PromptModel.model_validate(prompt).model_dump(exclude={'access_grants'})
prompt_data['access_grants'] = (
access_grants if access_grants is not None else self._get_access_grants(prompt_data['id'], db=db)
)
return PromptModel.model_validate(prompt_data)
@@ -135,26 +128,22 @@ class PromptsTable:
try:
with get_db_context(db) as db:
result = Prompt(**prompt.model_dump(exclude={"access_grants"}))
result = Prompt(**prompt.model_dump(exclude={'access_grants'}))
db.add(result)
db.commit()
db.refresh(result)
AccessGrants.set_access_grants(
"prompt", prompt_id, form_data.access_grants, db=db
)
AccessGrants.set_access_grants('prompt', prompt_id, form_data.access_grants, db=db)
if result:
current_access_grants = self._get_access_grants(prompt_id, db=db)
snapshot = {
"name": form_data.name,
"content": form_data.content,
"command": form_data.command,
"data": form_data.data or {},
"meta": form_data.meta or {},
"tags": form_data.tags or [],
"access_grants": [
grant.model_dump() for grant in current_access_grants
],
'name': form_data.name,
'content': form_data.content,
'command': form_data.command,
'data': form_data.data or {},
'meta': form_data.meta or {},
'tags': form_data.tags or [],
'access_grants': [grant.model_dump() for grant in current_access_grants],
}
history_entry = PromptHistories.create_history_entry(
@@ -162,7 +151,7 @@ class PromptsTable:
snapshot=snapshot,
user_id=user_id,
parent_id=None, # Initial commit has no parent
commit_message=form_data.commit_message or "Initial version",
commit_message=form_data.commit_message or 'Initial version',
db=db,
)
@@ -178,9 +167,7 @@ class PromptsTable:
except Exception:
return None
def get_prompt_by_id(
self, prompt_id: str, db: Optional[Session] = None
) -> Optional[PromptModel]:
def get_prompt_by_id(self, prompt_id: str, db: Optional[Session] = None) -> Optional[PromptModel]:
"""Get prompt by UUID."""
try:
with get_db_context(db) as db:
@@ -191,9 +178,7 @@ class PromptsTable:
except Exception:
return None
def get_prompt_by_command(
self, command: str, db: Optional[Session] = None
) -> Optional[PromptModel]:
def get_prompt_by_command(self, command: str, db: Optional[Session] = None) -> Optional[PromptModel]:
try:
with get_db_context(db) as db:
prompt = db.query(Prompt).filter_by(command=command).first()
@@ -205,21 +190,14 @@ class PromptsTable:
def get_prompts(self, db: Optional[Session] = None) -> list[PromptUserResponse]:
with get_db_context(db) as db:
all_prompts = (
db.query(Prompt)
.filter(Prompt.is_active == True)
.order_by(Prompt.updated_at.desc())
.all()
)
all_prompts = db.query(Prompt).filter(Prompt.is_active == True).order_by(Prompt.updated_at.desc()).all()
user_ids = list(set(prompt.user_id for prompt in all_prompts))
prompt_ids = [prompt.id for prompt in all_prompts]
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
users_dict = {user.id: user for user in users}
grants_map = AccessGrants.get_grants_by_resources(
"prompt", prompt_ids, db=db
)
grants_map = AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db)
prompts = []
for prompt in all_prompts:
@@ -232,7 +210,7 @@ class PromptsTable:
access_grants=grants_map.get(prompt.id, []),
db=db,
).model_dump(),
"user": user.model_dump() if user else None,
'user': user.model_dump() if user else None,
}
)
)
@@ -240,12 +218,10 @@ class PromptsTable:
return prompts
def get_prompts_by_user_id(
self, user_id: str, permission: str = "write", db: Optional[Session] = None
self, user_id: str, permission: str = 'write', db: Optional[Session] = None
) -> list[PromptUserResponse]:
prompts = self.get_prompts(db=db)
user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
}
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
return [
prompt
@@ -253,7 +229,7 @@ class PromptsTable:
if prompt.user_id == user_id
or AccessGrants.has_access(
user_id=user_id,
resource_type="prompt",
resource_type='prompt',
resource_id=prompt.id,
permission=permission,
user_group_ids=user_group_ids,
@@ -276,22 +252,22 @@ class PromptsTable:
query = db.query(Prompt, User).outerjoin(User, User.id == Prompt.user_id)
if filter:
query_key = filter.get("query")
query_key = filter.get('query')
if query_key:
query = query.filter(
or_(
Prompt.name.ilike(f"%{query_key}%"),
Prompt.command.ilike(f"%{query_key}%"),
Prompt.content.ilike(f"%{query_key}%"),
User.name.ilike(f"%{query_key}%"),
User.email.ilike(f"%{query_key}%"),
Prompt.name.ilike(f'%{query_key}%'),
Prompt.command.ilike(f'%{query_key}%'),
Prompt.content.ilike(f'%{query_key}%'),
User.name.ilike(f'%{query_key}%'),
User.email.ilike(f'%{query_key}%'),
)
)
view_option = filter.get("view_option")
if view_option == "created":
view_option = filter.get('view_option')
if view_option == 'created':
query = query.filter(Prompt.user_id == user_id)
elif view_option == "shared":
elif view_option == 'shared':
query = query.filter(Prompt.user_id != user_id)
# Apply access grant filtering
@@ -300,32 +276,32 @@ class PromptsTable:
query=query,
DocumentModel=Prompt,
filter=filter,
resource_type="prompt",
permission="read",
resource_type='prompt',
permission='read',
)
tag = filter.get("tag")
tag = filter.get('tag')
if tag:
# Search for tag in JSON array field
like_pattern = f'%"{tag.lower()}"%'
tags_text = func.lower(cast(Prompt.tags, String))
query = query.filter(tags_text.like(like_pattern))
order_by = filter.get("order_by")
direction = filter.get("direction")
order_by = filter.get('order_by')
direction = filter.get('direction')
if order_by == "name":
if direction == "asc":
if order_by == 'name':
if direction == 'asc':
query = query.order_by(Prompt.name.asc())
else:
query = query.order_by(Prompt.name.desc())
elif order_by == "created_at":
if direction == "asc":
elif order_by == 'created_at':
if direction == 'asc':
query = query.order_by(Prompt.created_at.asc())
else:
query = query.order_by(Prompt.created_at.desc())
elif order_by == "updated_at":
if direction == "asc":
elif order_by == 'updated_at':
if direction == 'asc':
query = query.order_by(Prompt.updated_at.asc())
else:
query = query.order_by(Prompt.updated_at.desc())
@@ -345,9 +321,7 @@ class PromptsTable:
items = query.all()
prompt_ids = [prompt.id for prompt, _ in items]
grants_map = AccessGrants.get_grants_by_resources(
"prompt", prompt_ids, db=db
)
grants_map = AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db)
prompts = []
for prompt, user in items:
@@ -358,11 +332,7 @@ class PromptsTable:
access_grants=grants_map.get(prompt.id, []),
db=db,
).model_dump(),
user=(
UserResponse(**UserModel.model_validate(user).model_dump())
if user
else None
),
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
)
)
@@ -381,9 +351,7 @@ class PromptsTable:
if not prompt:
return None
latest_history = PromptHistories.get_latest_history_entry(
prompt.id, db=db
)
latest_history = PromptHistories.get_latest_history_entry(prompt.id, db=db)
parent_id = latest_history.id if latest_history else None
current_access_grants = self._get_access_grants(prompt.id, db=db)
@@ -401,9 +369,7 @@ class PromptsTable:
prompt.meta = form_data.meta or prompt.meta
prompt.updated_at = int(time.time())
if form_data.access_grants is not None:
AccessGrants.set_access_grants(
"prompt", prompt.id, form_data.access_grants, db=db
)
AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=db)
current_access_grants = self._get_access_grants(prompt.id, db=db)
db.commit()
@@ -411,14 +377,12 @@ class PromptsTable:
# Create history entry only if content changed
if content_changed:
snapshot = {
"name": form_data.name,
"content": form_data.content,
"command": command,
"data": form_data.data or {},
"meta": form_data.meta or {},
"access_grants": [
grant.model_dump() for grant in current_access_grants
],
'name': form_data.name,
'content': form_data.content,
'command': command,
'data': form_data.data or {},
'meta': form_data.meta or {},
'access_grants': [grant.model_dump() for grant in current_access_grants],
}
history_entry = PromptHistories.create_history_entry(
@@ -452,9 +416,7 @@ class PromptsTable:
if not prompt:
return None
latest_history = PromptHistories.get_latest_history_entry(
prompt.id, db=db
)
latest_history = PromptHistories.get_latest_history_entry(prompt.id, db=db)
parent_id = latest_history.id if latest_history else None
current_access_grants = self._get_access_grants(prompt.id, db=db)
@@ -478,9 +440,7 @@ class PromptsTable:
prompt.tags = form_data.tags
if form_data.access_grants is not None:
AccessGrants.set_access_grants(
"prompt", prompt.id, form_data.access_grants, db=db
)
AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=db)
current_access_grants = self._get_access_grants(prompt.id, db=db)
prompt.updated_at = int(time.time())
@@ -490,15 +450,13 @@ class PromptsTable:
# Create history entry only if content changed
if content_changed:
snapshot = {
"name": form_data.name,
"content": form_data.content,
"command": prompt.command,
"data": form_data.data or {},
"meta": form_data.meta or {},
"tags": prompt.tags or [],
"access_grants": [
grant.model_dump() for grant in current_access_grants
],
'name': form_data.name,
'content': form_data.content,
'command': prompt.command,
'data': form_data.data or {},
'meta': form_data.meta or {},
'tags': prompt.tags or [],
'access_grants': [grant.model_dump() for grant in current_access_grants],
}
history_entry = PromptHistories.create_history_entry(
@@ -560,9 +518,7 @@ class PromptsTable:
if not prompt:
return None
history_entry = PromptHistories.get_history_entry_by_id(
version_id, db=db
)
history_entry = PromptHistories.get_history_entry_by_id(version_id, db=db)
if not history_entry:
return None
@@ -570,11 +526,11 @@ class PromptsTable:
# Restore prompt content from the snapshot
snapshot = history_entry.snapshot
if snapshot:
prompt.name = snapshot.get("name", prompt.name)
prompt.content = snapshot.get("content", prompt.content)
prompt.data = snapshot.get("data", prompt.data)
prompt.meta = snapshot.get("meta", prompt.meta)
prompt.tags = snapshot.get("tags", prompt.tags)
prompt.name = snapshot.get('name', prompt.name)
prompt.content = snapshot.get('content', prompt.content)
prompt.data = snapshot.get('data', prompt.data)
prompt.meta = snapshot.get('meta', prompt.meta)
prompt.tags = snapshot.get('tags', prompt.tags)
# Note: command and access_grants are not restored from snapshot
prompt.version_id = version_id
@@ -585,9 +541,7 @@ class PromptsTable:
except Exception:
return None
def toggle_prompt_active(
self, prompt_id: str, db: Optional[Session] = None
) -> Optional[PromptModel]:
def toggle_prompt_active(self, prompt_id: str, db: Optional[Session] = None) -> Optional[PromptModel]:
"""Toggle the is_active flag on a prompt."""
try:
with get_db_context(db) as db:
@@ -602,16 +556,14 @@ class PromptsTable:
except Exception:
return None
def delete_prompt_by_command(
self, command: str, db: Optional[Session] = None
) -> bool:
def delete_prompt_by_command(self, command: str, db: Optional[Session] = None) -> bool:
"""Permanently delete a prompt and its history."""
try:
with get_db_context(db) as db:
prompt = db.query(Prompt).filter_by(command=command).first()
if prompt:
PromptHistories.delete_history_by_prompt_id(prompt.id, db=db)
AccessGrants.revoke_all_access("prompt", prompt.id, db=db)
AccessGrants.revoke_all_access('prompt', prompt.id, db=db)
db.delete(prompt)
db.commit()
@@ -627,7 +579,7 @@ class PromptsTable:
prompt = db.query(Prompt).filter_by(id=prompt_id).first()
if prompt:
PromptHistories.delete_history_by_prompt_id(prompt.id, db=db)
AccessGrants.revoke_all_access("prompt", prompt.id, db=db)
AccessGrants.revoke_all_access('prompt', prompt.id, db=db)
db.delete(prompt)
db.commit()

View File

@@ -19,7 +19,7 @@ log = logging.getLogger(__name__)
class Skill(Base):
__tablename__ = "skill"
__tablename__ = 'skill'
id = Column(String, primary_key=True, unique=True)
user_id = Column(String)
@@ -77,7 +77,7 @@ class SkillResponse(BaseModel):
class SkillUserResponse(SkillResponse):
user: Optional[UserResponse] = None
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra='allow')
class SkillAccessResponse(SkillUserResponse):
@@ -105,10 +105,8 @@ class SkillAccessListResponse(BaseModel):
class SkillsTable:
def _get_access_grants(
self, skill_id: str, db: Optional[Session] = None
) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource("skill", skill_id, db=db)
def _get_access_grants(self, skill_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource('skill', skill_id, db=db)
def _to_skill_model(
self,
@@ -116,13 +114,9 @@ class SkillsTable:
access_grants: Optional[list[AccessGrantModel]] = None,
db: Optional[Session] = None,
) -> SkillModel:
skill_data = SkillModel.model_validate(skill).model_dump(
exclude={"access_grants"}
)
skill_data["access_grants"] = (
access_grants
if access_grants is not None
else self._get_access_grants(skill_data["id"], db=db)
skill_data = SkillModel.model_validate(skill).model_dump(exclude={'access_grants'})
skill_data['access_grants'] = (
access_grants if access_grants is not None else self._get_access_grants(skill_data['id'], db=db)
)
return SkillModel.model_validate(skill_data)
@@ -136,29 +130,25 @@ class SkillsTable:
try:
result = Skill(
**{
**form_data.model_dump(exclude={"access_grants"}),
"user_id": user_id,
"updated_at": int(time.time()),
"created_at": int(time.time()),
**form_data.model_dump(exclude={'access_grants'}),
'user_id': user_id,
'updated_at': int(time.time()),
'created_at': int(time.time()),
}
)
db.add(result)
db.commit()
db.refresh(result)
AccessGrants.set_access_grants(
"skill", result.id, form_data.access_grants, db=db
)
AccessGrants.set_access_grants('skill', result.id, form_data.access_grants, db=db)
if result:
return self._to_skill_model(result, db=db)
else:
return None
except Exception as e:
log.exception(f"Error creating a new skill: {e}")
log.exception(f'Error creating a new skill: {e}')
return None
def get_skill_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[SkillModel]:
def get_skill_by_id(self, id: str, db: Optional[Session] = None) -> Optional[SkillModel]:
try:
with get_db_context(db) as db:
skill = db.get(Skill, id)
@@ -166,9 +156,7 @@ class SkillsTable:
except Exception:
return None
def get_skill_by_name(
self, name: str, db: Optional[Session] = None
) -> Optional[SkillModel]:
def get_skill_by_name(self, name: str, db: Optional[Session] = None) -> Optional[SkillModel]:
try:
with get_db_context(db) as db:
skill = db.query(Skill).filter_by(name=name).first()
@@ -185,7 +173,7 @@ class SkillsTable:
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
users_dict = {user.id: user for user in users}
grants_map = AccessGrants.get_grants_by_resources("skill", skill_ids, db=db)
grants_map = AccessGrants.get_grants_by_resources('skill', skill_ids, db=db)
skills = []
for skill in all_skills:
@@ -198,19 +186,17 @@ class SkillsTable:
access_grants=grants_map.get(skill.id, []),
db=db,
).model_dump(),
"user": user.model_dump() if user else None,
'user': user.model_dump() if user else None,
}
)
)
return skills
def get_skills_by_user_id(
self, user_id: str, permission: str = "write", db: Optional[Session] = None
self, user_id: str, permission: str = 'write', db: Optional[Session] = None
) -> list[SkillUserModel]:
skills = self.get_skills(db=db)
user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
}
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
return [
skill
@@ -218,7 +204,7 @@ class SkillsTable:
if skill.user_id == user_id
or AccessGrants.has_access(
user_id=user_id,
resource_type="skill",
resource_type='skill',
resource_id=skill.id,
permission=permission,
user_group_ids=user_group_ids,
@@ -242,22 +228,22 @@ class SkillsTable:
query = db.query(Skill, User).outerjoin(User, User.id == Skill.user_id)
if filter:
query_key = filter.get("query")
query_key = filter.get('query')
if query_key:
query = query.filter(
or_(
Skill.name.ilike(f"%{query_key}%"),
Skill.description.ilike(f"%{query_key}%"),
Skill.id.ilike(f"%{query_key}%"),
User.name.ilike(f"%{query_key}%"),
User.email.ilike(f"%{query_key}%"),
Skill.name.ilike(f'%{query_key}%'),
Skill.description.ilike(f'%{query_key}%'),
Skill.id.ilike(f'%{query_key}%'),
User.name.ilike(f'%{query_key}%'),
User.email.ilike(f'%{query_key}%'),
)
)
view_option = filter.get("view_option")
if view_option == "created":
view_option = filter.get('view_option')
if view_option == 'created':
query = query.filter(Skill.user_id == user_id)
elif view_option == "shared":
elif view_option == 'shared':
query = query.filter(Skill.user_id != user_id)
# Apply access grant filtering
@@ -266,8 +252,8 @@ class SkillsTable:
query=query,
DocumentModel=Skill,
filter=filter,
resource_type="skill",
permission="read",
resource_type='skill',
permission='read',
)
query = query.order_by(Skill.updated_at.desc())
@@ -283,9 +269,7 @@ class SkillsTable:
items = query.all()
skill_ids = [skill.id for skill, _ in items]
grants_map = AccessGrants.get_grants_by_resources(
"skill", skill_ids, db=db
)
grants_map = AccessGrants.get_grants_by_resources('skill', skill_ids, db=db)
skills = []
for skill, user in items:
@@ -296,33 +280,23 @@ class SkillsTable:
access_grants=grants_map.get(skill.id, []),
db=db,
).model_dump(),
user=(
UserResponse(
**UserModel.model_validate(user).model_dump()
)
if user
else None
),
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
)
)
return SkillListResponse(items=skills, total=total)
except Exception as e:
log.exception(f"Error searching skills: {e}")
log.exception(f'Error searching skills: {e}')
return SkillListResponse(items=[], total=0)
def update_skill_by_id(
self, id: str, updated: dict, db: Optional[Session] = None
) -> Optional[SkillModel]:
def update_skill_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[SkillModel]:
try:
with get_db_context(db) as db:
access_grants = updated.pop("access_grants", None)
db.query(Skill).filter_by(id=id).update(
{**updated, "updated_at": int(time.time())}
)
access_grants = updated.pop('access_grants', None)
db.query(Skill).filter_by(id=id).update({**updated, 'updated_at': int(time.time())})
db.commit()
if access_grants is not None:
AccessGrants.set_access_grants("skill", id, access_grants, db=db)
AccessGrants.set_access_grants('skill', id, access_grants, db=db)
skill = db.query(Skill).get(id)
db.refresh(skill)
@@ -330,9 +304,7 @@ class SkillsTable:
except Exception:
return None
def toggle_skill_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[SkillModel]:
def toggle_skill_by_id(self, id: str, db: Optional[Session] = None) -> Optional[SkillModel]:
with get_db_context(db) as db:
try:
skill = db.query(Skill).filter_by(id=id).first()
@@ -351,7 +323,7 @@ class SkillsTable:
def delete_skill_by_id(self, id: str, db: Optional[Session] = None) -> bool:
try:
with get_db_context(db) as db:
AccessGrants.revoke_all_access("skill", id, db=db)
AccessGrants.revoke_all_access('skill', id, db=db)
db.query(Skill).filter_by(id=id).delete()
db.commit()

View File

@@ -17,19 +17,19 @@ log = logging.getLogger(__name__)
# Tag DB Schema
####################
class Tag(Base):
__tablename__ = "tag"
__tablename__ = 'tag'
id = Column(String)
name = Column(String)
user_id = Column(String)
meta = Column(JSON, nullable=True)
__table_args__ = (
PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),
Index("user_id_idx", "user_id"),
PrimaryKeyConstraint('id', 'user_id', name='pk_id_user_id'),
Index('user_id_idx', 'user_id'),
)
# Unique constraint ensuring (id, user_id) is unique, not just the `id` column
__table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),)
__table_args__ = (PrimaryKeyConstraint('id', 'user_id', name='pk_id_user_id'),)
class TagModel(BaseModel):
@@ -51,12 +51,10 @@ class TagChatIdForm(BaseModel):
class TagTable:
def insert_new_tag(
self, name: str, user_id: str, db: Optional[Session] = None
) -> Optional[TagModel]:
def insert_new_tag(self, name: str, user_id: str, db: Optional[Session] = None) -> Optional[TagModel]:
with get_db_context(db) as db:
id = name.replace(" ", "_").lower()
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
id = name.replace(' ', '_').lower()
tag = TagModel(**{'id': id, 'user_id': user_id, 'name': name})
try:
result = Tag(**tag.model_dump())
db.add(result)
@@ -67,89 +65,63 @@ class TagTable:
else:
return None
except Exception as e:
log.exception(f"Error inserting a new tag: {e}")
log.exception(f'Error inserting a new tag: {e}')
return None
def get_tag_by_name_and_user_id(
self, name: str, user_id: str, db: Optional[Session] = None
) -> Optional[TagModel]:
def get_tag_by_name_and_user_id(self, name: str, user_id: str, db: Optional[Session] = None) -> Optional[TagModel]:
try:
id = name.replace(" ", "_").lower()
id = name.replace(' ', '_').lower()
with get_db_context(db) as db:
tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
return TagModel.model_validate(tag)
except Exception:
return None
def get_tags_by_user_id(
self, user_id: str, db: Optional[Session] = None
) -> list[TagModel]:
def get_tags_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[TagModel]:
with get_db_context(db) as db:
return [TagModel.model_validate(tag) for tag in (db.query(Tag).filter_by(user_id=user_id).all())]
def get_tags_by_ids_and_user_id(self, ids: list[str], user_id: str, db: Optional[Session] = None) -> list[TagModel]:
with get_db_context(db) as db:
return [
TagModel.model_validate(tag)
for tag in (db.query(Tag).filter_by(user_id=user_id).all())
for tag in (db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all())
]
def get_tags_by_ids_and_user_id(
self, ids: list[str], user_id: str, db: Optional[Session] = None
) -> list[TagModel]:
with get_db_context(db) as db:
return [
TagModel.model_validate(tag)
for tag in (
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()
)
]
def delete_tag_by_name_and_user_id(
self, name: str, user_id: str, db: Optional[Session] = None
) -> bool:
def delete_tag_by_name_and_user_id(self, name: str, user_id: str, db: Optional[Session] = None) -> bool:
try:
with get_db_context(db) as db:
id = name.replace(" ", "_").lower()
id = name.replace(' ', '_').lower()
res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
log.debug(f"res: {res}")
log.debug(f'res: {res}')
db.commit()
return True
except Exception as e:
log.error(f"delete_tag: {e}")
log.error(f'delete_tag: {e}')
return False
def delete_tags_by_ids_and_user_id(
self, ids: list[str], user_id: str, db: Optional[Session] = None
) -> bool:
def delete_tags_by_ids_and_user_id(self, ids: list[str], user_id: str, db: Optional[Session] = None) -> bool:
"""Delete all tags whose id is in *ids* for the given user, in one query."""
if not ids:
return True
try:
with get_db_context(db) as db:
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).delete(
synchronize_session=False
)
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).delete(synchronize_session=False)
db.commit()
return True
except Exception as e:
log.error(f"delete_tags_by_ids: {e}")
log.error(f'delete_tags_by_ids: {e}')
return False
def ensure_tags_exist(
self, names: list[str], user_id: str, db: Optional[Session] = None
) -> None:
def ensure_tags_exist(self, names: list[str], user_id: str, db: Optional[Session] = None) -> None:
"""Create tag rows for any *names* that don't already exist for *user_id*."""
if not names:
return
ids = [n.replace(" ", "_").lower() for n in names]
ids = [n.replace(' ', '_').lower() for n in names]
with get_db_context(db) as db:
existing = {
t.id
for t in db.query(Tag.id)
.filter(Tag.id.in_(ids), Tag.user_id == user_id)
.all()
}
existing = {t.id for t in db.query(Tag.id).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()}
new_tags = [
Tag(id=tag_id, name=name, user_id=user_id)
for tag_id, name in zip(ids, names)
if tag_id not in existing
Tag(id=tag_id, name=name, user_id=user_id) for tag_id, name in zip(ids, names) if tag_id not in existing
]
if new_tags:
db.add_all(new_tags)

View File

@@ -19,7 +19,7 @@ log = logging.getLogger(__name__)
class Tool(Base):
__tablename__ = "tool"
__tablename__ = 'tool'
id = Column(String, primary_key=True, unique=True)
user_id = Column(String)
@@ -75,7 +75,7 @@ class ToolResponse(BaseModel):
class ToolUserResponse(ToolResponse):
user: Optional[UserResponse] = None
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra='allow')
class ToolAccessResponse(ToolUserResponse):
@@ -95,10 +95,8 @@ class ToolValves(BaseModel):
class ToolsTable:
def _get_access_grants(
self, tool_id: str, db: Optional[Session] = None
) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource("tool", tool_id, db=db)
def _get_access_grants(self, tool_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource('tool', tool_id, db=db)
def _to_tool_model(
self,
@@ -106,11 +104,9 @@ class ToolsTable:
access_grants: Optional[list[AccessGrantModel]] = None,
db: Optional[Session] = None,
) -> ToolModel:
tool_data = ToolModel.model_validate(tool).model_dump(exclude={"access_grants"})
tool_data["access_grants"] = (
access_grants
if access_grants is not None
else self._get_access_grants(tool_data["id"], db=db)
tool_data = ToolModel.model_validate(tool).model_dump(exclude={'access_grants'})
tool_data['access_grants'] = (
access_grants if access_grants is not None else self._get_access_grants(tool_data['id'], db=db)
)
return ToolModel.model_validate(tool_data)
@@ -125,30 +121,26 @@ class ToolsTable:
try:
result = Tool(
**{
**form_data.model_dump(exclude={"access_grants"}),
"specs": specs,
"user_id": user_id,
"updated_at": int(time.time()),
"created_at": int(time.time()),
**form_data.model_dump(exclude={'access_grants'}),
'specs': specs,
'user_id': user_id,
'updated_at': int(time.time()),
'created_at': int(time.time()),
}
)
db.add(result)
db.commit()
db.refresh(result)
AccessGrants.set_access_grants(
"tool", result.id, form_data.access_grants, db=db
)
AccessGrants.set_access_grants('tool', result.id, form_data.access_grants, db=db)
if result:
return self._to_tool_model(result, db=db)
else:
return None
except Exception as e:
log.exception(f"Error creating a new tool: {e}")
log.exception(f'Error creating a new tool: {e}')
return None
def get_tool_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[ToolModel]:
def get_tool_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ToolModel]:
try:
with get_db_context(db) as db:
tool = db.get(Tool, id)
@@ -156,9 +148,7 @@ class ToolsTable:
except Exception:
return None
def get_tools(
self, defer_content: bool = False, db: Optional[Session] = None
) -> list[ToolUserModel]:
def get_tools(self, defer_content: bool = False, db: Optional[Session] = None) -> list[ToolUserModel]:
with get_db_context(db) as db:
query = db.query(Tool).order_by(Tool.updated_at.desc())
if defer_content:
@@ -170,7 +160,7 @@ class ToolsTable:
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
users_dict = {user.id: user for user in users}
grants_map = AccessGrants.get_grants_by_resources("tool", tool_ids, db=db)
grants_map = AccessGrants.get_grants_by_resources('tool', tool_ids, db=db)
tools = []
for tool in all_tools:
@@ -183,7 +173,7 @@ class ToolsTable:
access_grants=grants_map.get(tool.id, []),
db=db,
).model_dump(),
"user": user.model_dump() if user else None,
'user': user.model_dump() if user else None,
}
)
)
@@ -192,14 +182,12 @@ class ToolsTable:
def get_tools_by_user_id(
self,
user_id: str,
permission: str = "write",
permission: str = 'write',
defer_content: bool = False,
db: Optional[Session] = None,
) -> list[ToolUserModel]:
tools = self.get_tools(defer_content=defer_content, db=db)
user_group_ids = {
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
}
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
return [
tool
@@ -207,7 +195,7 @@ class ToolsTable:
if tool.user_id == user_id
or AccessGrants.has_access(
user_id=user_id,
resource_type="tool",
resource_type='tool',
resource_id=tool.id,
permission=permission,
user_group_ids=user_group_ids,
@@ -215,48 +203,38 @@ class ToolsTable:
)
]
def get_tool_valves_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[dict]:
def get_tool_valves_by_id(self, id: str, db: Optional[Session] = None) -> Optional[dict]:
try:
with get_db_context(db) as db:
tool = db.get(Tool, id)
return tool.valves if tool.valves else {}
except Exception as e:
log.exception(f"Error getting tool valves by id {id}")
log.exception(f'Error getting tool valves by id {id}')
return None
def update_tool_valves_by_id(
self, id: str, valves: dict, db: Optional[Session] = None
) -> Optional[ToolValves]:
def update_tool_valves_by_id(self, id: str, valves: dict, db: Optional[Session] = None) -> Optional[ToolValves]:
try:
with get_db_context(db) as db:
db.query(Tool).filter_by(id=id).update(
{"valves": valves, "updated_at": int(time.time())}
)
db.query(Tool).filter_by(id=id).update({'valves': valves, 'updated_at': int(time.time())})
db.commit()
return self.get_tool_by_id(id, db=db)
except Exception:
return None
def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str, db: Optional[Session] = None
) -> Optional[dict]:
def get_user_valves_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id, db=db)
user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "tools" and "valves" settings
if "tools" not in user_settings:
user_settings["tools"] = {}
if "valves" not in user_settings["tools"]:
user_settings["tools"]["valves"] = {}
if 'tools' not in user_settings:
user_settings['tools'] = {}
if 'valves' not in user_settings['tools']:
user_settings['tools']['valves'] = {}
return user_settings["tools"]["valves"].get(id, {})
return user_settings['tools']['valves'].get(id, {})
except Exception as e:
log.exception(
f"Error getting user values by id {id} and user_id {user_id}: {e}"
)
log.exception(f'Error getting user values by id {id} and user_id {user_id}: {e}')
return None
def update_user_valves_by_id_and_user_id(
@@ -267,35 +245,29 @@ class ToolsTable:
user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "tools" and "valves" settings
if "tools" not in user_settings:
user_settings["tools"] = {}
if "valves" not in user_settings["tools"]:
user_settings["tools"]["valves"] = {}
if 'tools' not in user_settings:
user_settings['tools'] = {}
if 'valves' not in user_settings['tools']:
user_settings['tools']['valves'] = {}
user_settings["tools"]["valves"][id] = valves
user_settings['tools']['valves'][id] = valves
# Update the user settings in the database
Users.update_user_by_id(user_id, {"settings": user_settings}, db=db)
Users.update_user_by_id(user_id, {'settings': user_settings}, db=db)
return user_settings["tools"]["valves"][id]
return user_settings['tools']['valves'][id]
except Exception as e:
log.exception(
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
)
log.exception(f'Error updating user valves by id {id} and user_id {user_id}: {e}')
return None
def update_tool_by_id(
self, id: str, updated: dict, db: Optional[Session] = None
) -> Optional[ToolModel]:
def update_tool_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[ToolModel]:
try:
with get_db_context(db) as db:
access_grants = updated.pop("access_grants", None)
db.query(Tool).filter_by(id=id).update(
{**updated, "updated_at": int(time.time())}
)
access_grants = updated.pop('access_grants', None)
db.query(Tool).filter_by(id=id).update({**updated, 'updated_at': int(time.time())})
db.commit()
if access_grants is not None:
AccessGrants.set_access_grants("tool", id, access_grants, db=db)
AccessGrants.set_access_grants('tool', id, access_grants, db=db)
tool = db.query(Tool).get(id)
db.refresh(tool)
@@ -306,7 +278,7 @@ class ToolsTable:
def delete_tool_by_id(self, id: str, db: Optional[Session] = None) -> bool:
try:
with get_db_context(db) as db:
AccessGrants.revoke_all_access("tool", id, db=db)
AccessGrants.revoke_all_access('tool', id, db=db)
db.query(Tool).filter_by(id=id).delete()
db.commit()

View File

@@ -40,12 +40,12 @@ import datetime
class UserSettings(BaseModel):
ui: Optional[dict] = {}
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra='allow')
pass
class User(Base):
__tablename__ = "user"
__tablename__ = 'user'
id = Column(String, primary_key=True, unique=True)
email = Column(String)
@@ -83,7 +83,7 @@ class UserModel(BaseModel):
email: str
username: Optional[str] = None
role: str = "pending"
role: str = 'pending'
name: str
@@ -112,10 +112,10 @@ class UserModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
@model_validator(mode="after")
@model_validator(mode='after')
def set_profile_image_url(self):
if not self.profile_image_url:
self.profile_image_url = f"/api/v1/users/{self.id}/profile/image"
self.profile_image_url = f'/api/v1/users/{self.id}/profile/image'
return self
@@ -126,7 +126,7 @@ class UserStatusModel(UserModel):
class ApiKey(Base):
__tablename__ = "api_key"
__tablename__ = 'api_key'
id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text, nullable=False)
@@ -163,7 +163,7 @@ class UpdateProfileForm(BaseModel):
gender: Optional[str] = None
date_of_birth: Optional[datetime.date] = None
@field_validator("profile_image_url")
@field_validator('profile_image_url')
@classmethod
def check_profile_image_url(cls, v: str) -> str:
return validate_profile_image_url(v)
@@ -174,7 +174,7 @@ class UserGroupIdsModel(UserModel):
class UserModelResponse(UserModel):
model_config = ConfigDict(extra="allow")
model_config = ConfigDict(extra='allow')
class UserListResponse(BaseModel):
@@ -251,7 +251,7 @@ class UserUpdateForm(BaseModel):
profile_image_url: str
password: Optional[str] = None
@field_validator("profile_image_url")
@field_validator('profile_image_url')
@classmethod
def check_profile_image_url(cls, v: str) -> str:
return validate_profile_image_url(v)
@@ -263,8 +263,8 @@ class UsersTable:
id: str,
name: str,
email: str,
profile_image_url: str = "/user.png",
role: str = "pending",
profile_image_url: str = '/user.png',
role: str = 'pending',
username: Optional[str] = None,
oauth: Optional[dict] = None,
db: Optional[Session] = None,
@@ -272,16 +272,16 @@ class UsersTable:
with get_db_context(db) as db:
user = UserModel(
**{
"id": id,
"email": email,
"name": name,
"role": role,
"profile_image_url": profile_image_url,
"last_active_at": int(time.time()),
"created_at": int(time.time()),
"updated_at": int(time.time()),
"username": username,
"oauth": oauth,
'id': id,
'email': email,
'name': name,
'role': role,
'profile_image_url': profile_image_url,
'last_active_at': int(time.time()),
'created_at': int(time.time()),
'updated_at': int(time.time()),
'username': username,
'oauth': oauth,
}
)
result = User(**user.model_dump())
@@ -293,9 +293,7 @@ class UsersTable:
else:
return None
def get_user_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[UserModel]:
def get_user_by_id(self, id: str, db: Optional[Session] = None) -> Optional[UserModel]:
try:
with get_db_context(db) as db:
user = db.query(User).filter_by(id=id).first()
@@ -303,49 +301,32 @@ class UsersTable:
except Exception:
return None
def get_user_by_api_key(
self, api_key: str, db: Optional[Session] = None
) -> Optional[UserModel]:
def get_user_by_api_key(self, api_key: str, db: Optional[Session] = None) -> Optional[UserModel]:
try:
with get_db_context(db) as db:
user = (
db.query(User)
.join(ApiKey, User.id == ApiKey.user_id)
.filter(ApiKey.key == api_key)
.first()
)
user = db.query(User).join(ApiKey, User.id == ApiKey.user_id).filter(ApiKey.key == api_key).first()
return UserModel.model_validate(user) if user else None
except Exception:
return None
def get_user_by_email(
self, email: str, db: Optional[Session] = None
) -> Optional[UserModel]:
def get_user_by_email(self, email: str, db: Optional[Session] = None) -> Optional[UserModel]:
try:
with get_db_context(db) as db:
user = (
db.query(User)
.filter(func.lower(User.email) == email.lower())
.first()
)
user = db.query(User).filter(func.lower(User.email) == email.lower()).first()
return UserModel.model_validate(user) if user else None
except Exception:
return None
def get_user_by_oauth_sub(
self, provider: str, sub: str, db: Optional[Session] = None
) -> Optional[UserModel]:
def get_user_by_oauth_sub(self, provider: str, sub: str, db: Optional[Session] = None) -> Optional[UserModel]:
try:
with get_db_context(db) as db: # type: Session
dialect_name = db.bind.dialect.name
query = db.query(User)
if dialect_name == "sqlite":
query = query.filter(User.oauth.contains({provider: {"sub": sub}}))
elif dialect_name == "postgresql":
query = query.filter(
User.oauth[provider].cast(JSONB)["sub"].astext == sub
)
if dialect_name == 'sqlite':
query = query.filter(User.oauth.contains({provider: {'sub': sub}}))
elif dialect_name == 'postgresql':
query = query.filter(User.oauth[provider].cast(JSONB)['sub'].astext == sub)
user = query.first()
return UserModel.model_validate(user) if user else None
@@ -361,15 +342,10 @@ class UsersTable:
dialect_name = db.bind.dialect.name
query = db.query(User)
if dialect_name == "sqlite":
query = query.filter(
User.scim.contains({provider: {"external_id": external_id}})
)
elif dialect_name == "postgresql":
query = query.filter(
User.scim[provider].cast(JSONB)["external_id"].astext
== external_id
)
if dialect_name == 'sqlite':
query = query.filter(User.scim.contains({provider: {'external_id': external_id}}))
elif dialect_name == 'postgresql':
query = query.filter(User.scim[provider].cast(JSONB)['external_id'].astext == external_id)
user = query.first()
return UserModel.model_validate(user) if user else None
@@ -388,16 +364,16 @@ class UsersTable:
query = db.query(User).options(defer(User.profile_image_url))
if filter:
query_key = filter.get("query")
query_key = filter.get('query')
if query_key:
query = query.filter(
or_(
User.name.ilike(f"%{query_key}%"),
User.email.ilike(f"%{query_key}%"),
User.name.ilike(f'%{query_key}%'),
User.email.ilike(f'%{query_key}%'),
)
)
channel_id = filter.get("channel_id")
channel_id = filter.get('channel_id')
if channel_id:
query = query.filter(
exists(
@@ -408,13 +384,13 @@ class UsersTable:
)
)
user_ids = filter.get("user_ids")
group_ids = filter.get("group_ids")
user_ids = filter.get('user_ids')
group_ids = filter.get('group_ids')
if isinstance(user_ids, list) and isinstance(group_ids, list):
# If both are empty lists, return no users
if not user_ids and not group_ids:
return {"users": [], "total": 0}
return {'users': [], 'total': 0}
if user_ids:
query = query.filter(User.id.in_(user_ids))
@@ -429,21 +405,21 @@ class UsersTable:
)
)
roles = filter.get("roles")
roles = filter.get('roles')
if roles:
include_roles = [role for role in roles if not role.startswith("!")]
exclude_roles = [role[1:] for role in roles if role.startswith("!")]
include_roles = [role for role in roles if not role.startswith('!')]
exclude_roles = [role[1:] for role in roles if role.startswith('!')]
if include_roles:
query = query.filter(User.role.in_(include_roles))
if exclude_roles:
query = query.filter(~User.role.in_(exclude_roles))
order_by = filter.get("order_by")
direction = filter.get("direction")
order_by = filter.get('order_by')
direction = filter.get('direction')
if order_by and order_by.startswith("group_id:"):
group_id = order_by.split(":", 1)[1]
if order_by and order_by.startswith('group_id:'):
group_id = order_by.split(':', 1)[1]
# Subquery that checks if the user belongs to the group
membership_exists = exists(
@@ -456,42 +432,42 @@ class UsersTable:
# CASE: user in group → 1, user not in group → 0
group_sort = case((membership_exists, 1), else_=0)
if direction == "asc":
if direction == 'asc':
query = query.order_by(group_sort.asc(), User.name.asc())
else:
query = query.order_by(group_sort.desc(), User.name.asc())
elif order_by == "name":
if direction == "asc":
elif order_by == 'name':
if direction == 'asc':
query = query.order_by(User.name.asc())
else:
query = query.order_by(User.name.desc())
elif order_by == "email":
if direction == "asc":
elif order_by == 'email':
if direction == 'asc':
query = query.order_by(User.email.asc())
else:
query = query.order_by(User.email.desc())
elif order_by == "created_at":
if direction == "asc":
elif order_by == 'created_at':
if direction == 'asc':
query = query.order_by(User.created_at.asc())
else:
query = query.order_by(User.created_at.desc())
elif order_by == "last_active_at":
if direction == "asc":
elif order_by == 'last_active_at':
if direction == 'asc':
query = query.order_by(User.last_active_at.asc())
else:
query = query.order_by(User.last_active_at.desc())
elif order_by == "updated_at":
if direction == "asc":
elif order_by == 'updated_at':
if direction == 'asc':
query = query.order_by(User.updated_at.asc())
else:
query = query.order_by(User.updated_at.desc())
elif order_by == "role":
if direction == "asc":
elif order_by == 'role':
if direction == 'asc':
query = query.order_by(User.role.asc())
else:
query = query.order_by(User.role.desc())
@@ -510,13 +486,11 @@ class UsersTable:
users = query.all()
return {
"users": [UserModel.model_validate(user) for user in users],
"total": total,
'users': [UserModel.model_validate(user) for user in users],
'total': total,
}
def get_users_by_group_id(
self, group_id: str, db: Optional[Session] = None
) -> list[UserModel]:
def get_users_by_group_id(self, group_id: str, db: Optional[Session] = None) -> list[UserModel]:
with get_db_context(db) as db:
users = (
db.query(User)
@@ -527,16 +501,9 @@ class UsersTable:
)
return [UserModel.model_validate(user) for user in users]
def get_users_by_user_ids(
self, user_ids: list[str], db: Optional[Session] = None
) -> list[UserStatusModel]:
def get_users_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> list[UserStatusModel]:
with get_db_context(db) as db:
users = (
db.query(User)
.options(defer(User.profile_image_url))
.filter(User.id.in_(user_ids))
.all()
)
users = db.query(User).options(defer(User.profile_image_url)).filter(User.id.in_(user_ids)).all()
return [UserModel.model_validate(user) for user in users]
def get_num_users(self, db: Optional[Session] = None) -> Optional[int]:
@@ -555,9 +522,7 @@ class UsersTable:
except Exception:
return None
def get_user_webhook_url_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[str]:
def get_user_webhook_url_by_id(self, id: str, db: Optional[Session] = None) -> Optional[str]:
try:
with get_db_context(db) as db:
user = db.query(User).filter_by(id=id).first()
@@ -565,11 +530,7 @@ class UsersTable:
if user.settings is None:
return None
else:
return (
user.settings.get("ui", {})
.get("notifications", {})
.get("webhook_url", None)
)
return user.settings.get('ui', {}).get('notifications', {}).get('webhook_url', None)
except Exception:
return None
@@ -577,14 +538,10 @@ class UsersTable:
with get_db_context(db) as db:
current_timestamp = int(datetime.datetime.now().timestamp())
today_midnight_timestamp = current_timestamp - (current_timestamp % 86400)
query = db.query(User).filter(
User.last_active_at > today_midnight_timestamp
)
query = db.query(User).filter(User.last_active_at > today_midnight_timestamp)
return query.count()
def update_user_role_by_id(
self, id: str, role: str, db: Optional[Session] = None
) -> Optional[UserModel]:
def update_user_role_by_id(self, id: str, role: str, db: Optional[Session] = None) -> Optional[UserModel]:
try:
with get_db_context(db) as db:
user = db.query(User).filter_by(id=id).first()
@@ -629,9 +586,7 @@ class UsersTable:
return None
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
def update_last_active_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[UserModel]:
def update_last_active_by_id(self, id: str, db: Optional[Session] = None) -> Optional[UserModel]:
try:
with get_db_context(db) as db:
user = db.query(User).filter_by(id=id).first()
@@ -665,10 +620,10 @@ class UsersTable:
oauth = user.oauth or {}
# Update or insert provider entry
oauth[provider] = {"sub": sub}
oauth[provider] = {'sub': sub}
# Persist updated JSON
db.query(User).filter_by(id=id).update({"oauth": oauth})
db.query(User).filter_by(id=id).update({'oauth': oauth})
db.commit()
return UserModel.model_validate(user)
@@ -698,9 +653,9 @@ class UsersTable:
return None
scim = user.scim or {}
scim[provider] = {"external_id": external_id}
scim[provider] = {'external_id': external_id}
db.query(User).filter_by(id=id).update({"scim": scim})
db.query(User).filter_by(id=id).update({'scim': scim})
db.commit()
return UserModel.model_validate(user)
@@ -708,9 +663,7 @@ class UsersTable:
except Exception:
return None
def update_user_by_id(
self, id: str, updated: dict, db: Optional[Session] = None
) -> Optional[UserModel]:
def update_user_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[UserModel]:
try:
with get_db_context(db) as db:
user = db.query(User).filter_by(id=id).first()
@@ -725,9 +678,7 @@ class UsersTable:
print(e)
return None
def update_user_settings_by_id(
self, id: str, updated: dict, db: Optional[Session] = None
) -> Optional[UserModel]:
def update_user_settings_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[UserModel]:
try:
with get_db_context(db) as db:
user = db.query(User).filter_by(id=id).first()
@@ -741,7 +692,7 @@ class UsersTable:
user_settings.update(updated)
db.query(User).filter_by(id=id).update({"settings": user_settings})
db.query(User).filter_by(id=id).update({'settings': user_settings})
db.commit()
user = db.query(User).filter_by(id=id).first()
@@ -768,9 +719,7 @@ class UsersTable:
except Exception:
return False
def get_user_api_key_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[str]:
def get_user_api_key_by_id(self, id: str, db: Optional[Session] = None) -> Optional[str]:
try:
with get_db_context(db) as db:
api_key = db.query(ApiKey).filter_by(user_id=id).first()
@@ -778,9 +727,7 @@ class UsersTable:
except Exception:
return None
def update_user_api_key_by_id(
self, id: str, api_key: str, db: Optional[Session] = None
) -> bool:
def update_user_api_key_by_id(self, id: str, api_key: str, db: Optional[Session] = None) -> bool:
try:
with get_db_context(db) as db:
db.query(ApiKey).filter_by(user_id=id).delete()
@@ -788,7 +735,7 @@ class UsersTable:
now = int(time.time())
new_api_key = ApiKey(
id=f"key_{id}",
id=f'key_{id}',
user_id=id,
key=api_key,
created_at=now,
@@ -811,16 +758,14 @@ class UsersTable:
except Exception:
return False
def get_valid_user_ids(
self, user_ids: list[str], db: Optional[Session] = None
) -> list[str]:
def get_valid_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> list[str]:
with get_db_context(db) as db:
users = db.query(User).filter(User.id.in_(user_ids)).all()
return [user.id for user in users]
def get_super_admin_user(self, db: Optional[Session] = None) -> Optional[UserModel]:
with get_db_context(db) as db:
user = db.query(User).filter_by(role="admin").first()
user = db.query(User).filter_by(role='admin').first()
if user:
return UserModel.model_validate(user)
else:
@@ -830,9 +775,7 @@ class UsersTable:
with get_db_context(db) as db:
# Consider user active if last_active_at within the last 3 minutes
three_minutes_ago = int(time.time()) - 180
count = (
db.query(User).filter(User.last_active_at >= three_minutes_ago).count()
)
count = db.query(User).filter(User.last_active_at >= three_minutes_ago).count()
return count
@staticmethod

View File

@@ -40,78 +40,76 @@ class DatalabMarkerLoader:
self.output_format = output_format
def _get_mime_type(self, filename: str) -> str:
ext = filename.rsplit(".", 1)[-1].lower()
ext = filename.rsplit('.', 1)[-1].lower()
mime_map = {
"pdf": "application/pdf",
"xls": "application/vnd.ms-excel",
"xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"ods": "application/vnd.oasis.opendocument.spreadsheet",
"doc": "application/msword",
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"odt": "application/vnd.oasis.opendocument.text",
"ppt": "application/vnd.ms-powerpoint",
"pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
"odp": "application/vnd.oasis.opendocument.presentation",
"html": "text/html",
"epub": "application/epub+zip",
"png": "image/png",
"jpeg": "image/jpeg",
"jpg": "image/jpeg",
"webp": "image/webp",
"gif": "image/gif",
"tiff": "image/tiff",
'pdf': 'application/pdf',
'xls': 'application/vnd.ms-excel',
'xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
'ods': 'application/vnd.oasis.opendocument.spreadsheet',
'doc': 'application/msword',
'docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
'odt': 'application/vnd.oasis.opendocument.text',
'ppt': 'application/vnd.ms-powerpoint',
'pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
'odp': 'application/vnd.oasis.opendocument.presentation',
'html': 'text/html',
'epub': 'application/epub+zip',
'png': 'image/png',
'jpeg': 'image/jpeg',
'jpg': 'image/jpeg',
'webp': 'image/webp',
'gif': 'image/gif',
'tiff': 'image/tiff',
}
return mime_map.get(ext, "application/octet-stream")
return mime_map.get(ext, 'application/octet-stream')
def check_marker_request_status(self, request_id: str) -> dict:
url = f"{self.api_base_url}/{request_id}"
headers = {"X-Api-Key": self.api_key}
url = f'{self.api_base_url}/{request_id}'
headers = {'X-Api-Key': self.api_key}
try:
response = requests.get(url, headers=headers)
response.raise_for_status()
result = response.json()
log.info(f"Marker API status check for request {request_id}: {result}")
log.info(f'Marker API status check for request {request_id}: {result}')
return result
except requests.HTTPError as e:
log.error(f"Error checking Marker request status: {e}")
log.error(f'Error checking Marker request status: {e}')
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Failed to check Marker request: {e}",
detail=f'Failed to check Marker request: {e}',
)
except ValueError as e:
log.error(f"Invalid JSON checking Marker request: {e}")
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON: {e}"
)
log.error(f'Invalid JSON checking Marker request: {e}')
raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Invalid JSON: {e}')
def load(self) -> List[Document]:
filename = os.path.basename(self.file_path)
mime_type = self._get_mime_type(filename)
headers = {"X-Api-Key": self.api_key}
headers = {'X-Api-Key': self.api_key}
form_data = {
"use_llm": str(self.use_llm).lower(),
"skip_cache": str(self.skip_cache).lower(),
"force_ocr": str(self.force_ocr).lower(),
"paginate": str(self.paginate).lower(),
"strip_existing_ocr": str(self.strip_existing_ocr).lower(),
"disable_image_extraction": str(self.disable_image_extraction).lower(),
"format_lines": str(self.format_lines).lower(),
"output_format": self.output_format,
'use_llm': str(self.use_llm).lower(),
'skip_cache': str(self.skip_cache).lower(),
'force_ocr': str(self.force_ocr).lower(),
'paginate': str(self.paginate).lower(),
'strip_existing_ocr': str(self.strip_existing_ocr).lower(),
'disable_image_extraction': str(self.disable_image_extraction).lower(),
'format_lines': str(self.format_lines).lower(),
'output_format': self.output_format,
}
if self.additional_config and self.additional_config.strip():
form_data["additional_config"] = self.additional_config
form_data['additional_config'] = self.additional_config
log.info(
f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}"
)
try:
with open(self.file_path, "rb") as f:
files = {"file": (filename, f, mime_type)}
with open(self.file_path, 'rb') as f:
files = {'file': (filename, f, mime_type)}
response = requests.post(
f"{self.api_base_url}",
f'{self.api_base_url}',
data=form_data,
files=files,
headers=headers,
@@ -119,29 +117,25 @@ class DatalabMarkerLoader:
response.raise_for_status()
result = response.json()
except FileNotFoundError:
raise HTTPException(
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
)
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}')
except requests.HTTPError as e:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Datalab Marker request failed: {e}",
detail=f'Datalab Marker request failed: {e}',
)
except ValueError as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON response: {e}"
)
raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Invalid JSON response: {e}')
except Exception as e:
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
if not result.get("success"):
if not result.get('success'):
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Datalab Marker request failed: {result.get('error', 'Unknown error')}",
detail=f'Datalab Marker request failed: {result.get("error", "Unknown error")}',
)
check_url = result.get("request_check_url")
request_id = result.get("request_id")
check_url = result.get('request_check_url')
request_id = result.get('request_id')
# Check if this is a direct response (self-hosted) or polling response (DataLab)
if check_url:
@@ -154,54 +148,45 @@ class DatalabMarkerLoader:
poll_result = poll_response.json()
except (requests.HTTPError, ValueError) as e:
raw_body = poll_response.text
log.error(f"Polling error: {e}, response body: {raw_body}")
raise HTTPException(
status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}"
)
log.error(f'Polling error: {e}, response body: {raw_body}')
raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Polling failed: {e}')
status_val = poll_result.get("status")
success_val = poll_result.get("success")
status_val = poll_result.get('status')
success_val = poll_result.get('success')
if status_val == "complete":
if status_val == 'complete':
summary = {
k: poll_result.get(k)
for k in (
"status",
"output_format",
"success",
"error",
"page_count",
"total_cost",
'status',
'output_format',
'success',
'error',
'page_count',
'total_cost',
)
}
log.info(
f"Marker processing completed successfully: {json.dumps(summary, indent=2)}"
)
log.info(f'Marker processing completed successfully: {json.dumps(summary, indent=2)}')
break
if status_val == "failed" or success_val is False:
log.error(
f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}"
)
error_msg = (
poll_result.get("error")
or "Marker returned failure without error message"
)
if status_val == 'failed' or success_val is False:
log.error(f'Marker poll failed full response: {json.dumps(poll_result, indent=2)}')
error_msg = poll_result.get('error') or 'Marker returned failure without error message'
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Marker processing failed: {error_msg}",
detail=f'Marker processing failed: {error_msg}',
)
else:
raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT,
detail="Marker processing timed out",
detail='Marker processing timed out',
)
if not poll_result.get("success", False):
error_msg = poll_result.get("error") or "Unknown processing error"
if not poll_result.get('success', False):
error_msg = poll_result.get('error') or 'Unknown processing error'
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Final processing failed: {error_msg}",
detail=f'Final processing failed: {error_msg}',
)
# DataLab format - content in format-specific fields
@@ -210,69 +195,65 @@ class DatalabMarkerLoader:
final_result = poll_result
else:
# Self-hosted direct response - content in "output" field
if "output" in result:
log.info("Self-hosted Marker returned direct response without polling")
raw_content = result.get("output")
if 'output' in result:
log.info('Self-hosted Marker returned direct response without polling')
raw_content = result.get('output')
final_result = result
else:
available_fields = (
list(result.keys())
if isinstance(result, dict)
else "non-dict response"
)
available_fields = list(result.keys()) if isinstance(result, dict) else 'non-dict response'
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Custom Marker endpoint returned success but no 'output' field found. Available fields: {available_fields}. Expected either 'request_check_url' for polling or 'output' field for direct response.",
)
if self.output_format.lower() == "json":
if self.output_format.lower() == 'json':
full_text = json.dumps(raw_content, indent=2)
elif self.output_format.lower() in {"markdown", "html"}:
elif self.output_format.lower() in {'markdown', 'html'}:
full_text = str(raw_content).strip()
else:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported output format: {self.output_format}",
detail=f'Unsupported output format: {self.output_format}',
)
if not full_text:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail="Marker returned empty content",
detail='Marker returned empty content',
)
marker_output_dir = os.path.join("/app/backend/data/uploads", "marker_output")
marker_output_dir = os.path.join('/app/backend/data/uploads', 'marker_output')
os.makedirs(marker_output_dir, exist_ok=True)
file_ext_map = {"markdown": "md", "json": "json", "html": "html"}
file_ext = file_ext_map.get(self.output_format.lower(), "txt")
output_filename = f"{os.path.splitext(filename)[0]}.{file_ext}"
file_ext_map = {'markdown': 'md', 'json': 'json', 'html': 'html'}
file_ext = file_ext_map.get(self.output_format.lower(), 'txt')
output_filename = f'{os.path.splitext(filename)[0]}.{file_ext}'
output_path = os.path.join(marker_output_dir, output_filename)
try:
with open(output_path, "w", encoding="utf-8") as f:
with open(output_path, 'w', encoding='utf-8') as f:
f.write(full_text)
log.info(f"Saved Marker output to: {output_path}")
log.info(f'Saved Marker output to: {output_path}')
except Exception as e:
log.warning(f"Failed to write marker output to disk: {e}")
log.warning(f'Failed to write marker output to disk: {e}')
metadata = {
"source": filename,
"output_format": final_result.get("output_format", self.output_format),
"page_count": final_result.get("page_count", 0),
"processed_with_llm": self.use_llm,
"request_id": request_id or "",
'source': filename,
'output_format': final_result.get('output_format', self.output_format),
'page_count': final_result.get('page_count', 0),
'processed_with_llm': self.use_llm,
'request_id': request_id or '',
}
images = final_result.get("images", {})
images = final_result.get('images', {})
if images:
metadata["image_count"] = len(images)
metadata["images"] = json.dumps(list(images.keys()))
metadata['image_count'] = len(images)
metadata['images'] = json.dumps(list(images.keys()))
for k, v in metadata.items():
if isinstance(v, (dict, list)):
metadata[k] = json.dumps(v)
elif v is None:
metadata[k] = ""
metadata[k] = ''
return [Document(page_content=full_text, metadata=metadata)]

View File

@@ -29,18 +29,18 @@ class ExternalDocumentLoader(BaseLoader):
self.user = user
def load(self) -> List[Document]:
with open(self.file_path, "rb") as f:
with open(self.file_path, 'rb') as f:
data = f.read()
headers = {}
if self.mime_type is not None:
headers["Content-Type"] = self.mime_type
headers['Content-Type'] = self.mime_type
if self.api_key is not None:
headers["Authorization"] = f"Bearer {self.api_key}"
headers['Authorization'] = f'Bearer {self.api_key}'
try:
headers["X-Filename"] = quote(os.path.basename(self.file_path))
headers['X-Filename'] = quote(os.path.basename(self.file_path))
except Exception:
pass
@@ -48,24 +48,23 @@ class ExternalDocumentLoader(BaseLoader):
headers = include_user_info_headers(headers, self.user)
url = self.url
if url.endswith("/"):
if url.endswith('/'):
url = url[:-1]
try:
response = requests.put(f"{url}/process", data=data, headers=headers)
response = requests.put(f'{url}/process', data=data, headers=headers)
except Exception as e:
log.error(f"Error connecting to endpoint: {e}")
raise Exception(f"Error connecting to endpoint: {e}")
log.error(f'Error connecting to endpoint: {e}')
raise Exception(f'Error connecting to endpoint: {e}')
if response.ok:
response_data = response.json()
if response_data:
if isinstance(response_data, dict):
return [
Document(
page_content=response_data.get("page_content"),
metadata=response_data.get("metadata"),
page_content=response_data.get('page_content'),
metadata=response_data.get('metadata'),
)
]
elif isinstance(response_data, list):
@@ -73,17 +72,15 @@ class ExternalDocumentLoader(BaseLoader):
for document in response_data:
documents.append(
Document(
page_content=document.get("page_content"),
metadata=document.get("metadata"),
page_content=document.get('page_content'),
metadata=document.get('metadata'),
)
)
return documents
else:
raise Exception("Error loading document: Unable to parse content")
raise Exception('Error loading document: Unable to parse content')
else:
raise Exception("Error loading document: No content returned")
raise Exception('Error loading document: No content returned')
else:
raise Exception(
f"Error loading document: {response.status_code} {response.text}"
)
raise Exception(f'Error loading document: {response.status_code} {response.text}')

View File

@@ -30,22 +30,22 @@ class ExternalWebLoader(BaseLoader):
response = requests.post(
self.external_url,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) External Web Loader",
"Authorization": f"Bearer {self.external_api_key}",
'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) External Web Loader',
'Authorization': f'Bearer {self.external_api_key}',
},
json={
"urls": urls,
'urls': urls,
},
)
response.raise_for_status()
results = response.json()
for result in results:
yield Document(
page_content=result.get("page_content", ""),
metadata=result.get("metadata", {}),
page_content=result.get('page_content', ''),
metadata=result.get('metadata', {}),
)
except Exception as e:
if self.continue_on_failure:
log.error(f"Error extracting content from batch {urls}: {e}")
log.error(f'Error extracting content from batch {urls}: {e}')
else:
raise e

View File

@@ -30,59 +30,59 @@ logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
known_source_ext = [
"go",
"py",
"java",
"sh",
"bat",
"ps1",
"cmd",
"js",
"ts",
"css",
"cpp",
"hpp",
"h",
"c",
"cs",
"sql",
"log",
"ini",
"pl",
"pm",
"r",
"dart",
"dockerfile",
"env",
"php",
"hs",
"hsc",
"lua",
"nginxconf",
"conf",
"m",
"mm",
"plsql",
"perl",
"rb",
"rs",
"db2",
"scala",
"bash",
"swift",
"vue",
"svelte",
"ex",
"exs",
"erl",
"tsx",
"jsx",
"hs",
"lhs",
"json",
"yaml",
"yml",
"toml",
'go',
'py',
'java',
'sh',
'bat',
'ps1',
'cmd',
'js',
'ts',
'css',
'cpp',
'hpp',
'h',
'c',
'cs',
'sql',
'log',
'ini',
'pl',
'pm',
'r',
'dart',
'dockerfile',
'env',
'php',
'hs',
'hsc',
'lua',
'nginxconf',
'conf',
'm',
'mm',
'plsql',
'perl',
'rb',
'rs',
'db2',
'scala',
'bash',
'swift',
'vue',
'svelte',
'ex',
'exs',
'erl',
'tsx',
'jsx',
'hs',
'lhs',
'json',
'yaml',
'yml',
'toml',
]
@@ -99,11 +99,11 @@ class ExcelLoader:
xls = pd.ExcelFile(self.file_path)
for sheet_name in xls.sheet_names:
df = pd.read_excel(xls, sheet_name=sheet_name)
text_parts.append(f"Sheet: {sheet_name}\n{df.to_string(index=False)}")
text_parts.append(f'Sheet: {sheet_name}\n{df.to_string(index=False)}')
return [
Document(
page_content="\n\n".join(text_parts),
metadata={"source": self.file_path},
page_content='\n\n'.join(text_parts),
metadata={'source': self.file_path},
)
]
@@ -125,11 +125,11 @@ class PptxLoader:
if shape.has_text_frame:
slide_texts.append(shape.text_frame.text)
if slide_texts:
text_parts.append(f"Slide {i}:\n" + "\n".join(slide_texts))
text_parts.append(f'Slide {i}:\n' + '\n'.join(slide_texts))
return [
Document(
page_content="\n\n".join(text_parts),
metadata={"source": self.file_path},
page_content='\n\n'.join(text_parts),
metadata={'source': self.file_path},
)
]
@@ -143,41 +143,41 @@ class TikaLoader:
self.extract_images = extract_images
def load(self) -> list[Document]:
with open(self.file_path, "rb") as f:
with open(self.file_path, 'rb') as f:
data = f.read()
if self.mime_type is not None:
headers = {"Content-Type": self.mime_type}
headers = {'Content-Type': self.mime_type}
else:
headers = {}
if self.extract_images == True:
headers["X-Tika-PDFextractInlineImages"] = "true"
headers['X-Tika-PDFextractInlineImages'] = 'true'
endpoint = self.url
if not endpoint.endswith("/"):
endpoint += "/"
endpoint += "tika/text"
if not endpoint.endswith('/'):
endpoint += '/'
endpoint += 'tika/text'
r = requests.put(endpoint, data=data, headers=headers, verify=REQUESTS_VERIFY)
if r.ok:
raw_metadata = r.json()
text = raw_metadata.get("X-TIKA:content", "<No text content found>").strip()
text = raw_metadata.get('X-TIKA:content', '<No text content found>').strip()
if "Content-Type" in raw_metadata:
headers["Content-Type"] = raw_metadata["Content-Type"]
if 'Content-Type' in raw_metadata:
headers['Content-Type'] = raw_metadata['Content-Type']
log.debug("Tika extracted text: %s", text)
log.debug('Tika extracted text: %s', text)
return [Document(page_content=text, metadata=headers)]
else:
raise Exception(f"Error calling Tika: {r.reason}")
raise Exception(f'Error calling Tika: {r.reason}')
class DoclingLoader:
def __init__(self, url, api_key=None, file_path=None, mime_type=None, params=None):
self.url = url.rstrip("/")
self.url = url.rstrip('/')
self.api_key = api_key
self.file_path = file_path
self.mime_type = mime_type
@@ -185,199 +185,183 @@ class DoclingLoader:
self.params = params or {}
def load(self) -> list[Document]:
with open(self.file_path, "rb") as f:
with open(self.file_path, 'rb') as f:
headers = {}
if self.api_key:
headers["X-Api-Key"] = f"{self.api_key}"
headers['X-Api-Key'] = f'{self.api_key}'
r = requests.post(
f"{self.url}/v1/convert/file",
f'{self.url}/v1/convert/file',
files={
"files": (
'files': (
self.file_path,
f,
self.mime_type or "application/octet-stream",
self.mime_type or 'application/octet-stream',
)
},
data={
"image_export_mode": "placeholder",
'image_export_mode': 'placeholder',
**self.params,
},
headers=headers,
)
if r.ok:
result = r.json()
document_data = result.get("document", {})
text = document_data.get("md_content", "<No text content found>")
document_data = result.get('document', {})
text = document_data.get('md_content', '<No text content found>')
metadata = {"Content-Type": self.mime_type} if self.mime_type else {}
metadata = {'Content-Type': self.mime_type} if self.mime_type else {}
log.debug("Docling extracted text: %s", text)
log.debug('Docling extracted text: %s', text)
return [Document(page_content=text, metadata=metadata)]
else:
error_msg = f"Error calling Docling API: {r.reason}"
error_msg = f'Error calling Docling API: {r.reason}'
if r.text:
try:
error_data = r.json()
if "detail" in error_data:
error_msg += f" - {error_data['detail']}"
if 'detail' in error_data:
error_msg += f' - {error_data["detail"]}'
except Exception:
error_msg += f" - {r.text}"
raise Exception(f"Error calling Docling: {error_msg}")
error_msg += f' - {r.text}'
raise Exception(f'Error calling Docling: {error_msg}')
class Loader:
def __init__(self, engine: str = "", **kwargs):
def __init__(self, engine: str = '', **kwargs):
self.engine = engine
self.user = kwargs.get("user", None)
self.user = kwargs.get('user', None)
self.kwargs = kwargs
def load(
self, filename: str, file_content_type: str, file_path: str
) -> list[Document]:
def load(self, filename: str, file_content_type: str, file_path: str) -> list[Document]:
loader = self._get_loader(filename, file_content_type, file_path)
docs = loader.load()
return [
Document(
page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata
)
for doc in docs
]
return [Document(page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata) for doc in docs]
def _is_text_file(self, file_ext: str, file_content_type: str) -> bool:
return file_ext in known_source_ext or (
file_content_type
and file_content_type.find("text/") >= 0
and file_content_type.find('text/') >= 0
# Avoid text/html files being detected as text
and not file_content_type.find("html") >= 0
and not file_content_type.find('html') >= 0
)
def _get_loader(self, filename: str, file_content_type: str, file_path: str):
file_ext = filename.split(".")[-1].lower()
file_ext = filename.split('.')[-1].lower()
if (
self.engine == "external"
and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL")
and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY")
self.engine == 'external'
and self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_URL')
and self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_API_KEY')
):
loader = ExternalDocumentLoader(
file_path=file_path,
url=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL"),
api_key=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY"),
url=self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_URL'),
api_key=self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_API_KEY'),
mime_type=file_content_type,
user=self.user,
)
elif self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
elif self.engine == 'tika' and self.kwargs.get('TIKA_SERVER_URL'):
if self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
loader = TikaLoader(
url=self.kwargs.get("TIKA_SERVER_URL"),
url=self.kwargs.get('TIKA_SERVER_URL'),
file_path=file_path,
extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"),
extract_images=self.kwargs.get('PDF_EXTRACT_IMAGES'),
)
elif (
self.engine == "datalab_marker"
and self.kwargs.get("DATALAB_MARKER_API_KEY")
self.engine == 'datalab_marker'
and self.kwargs.get('DATALAB_MARKER_API_KEY')
and file_ext
in [
"pdf",
"xls",
"xlsx",
"ods",
"doc",
"docx",
"odt",
"ppt",
"pptx",
"odp",
"html",
"epub",
"png",
"jpeg",
"jpg",
"webp",
"gif",
"tiff",
'pdf',
'xls',
'xlsx',
'ods',
'doc',
'docx',
'odt',
'ppt',
'pptx',
'odp',
'html',
'epub',
'png',
'jpeg',
'jpg',
'webp',
'gif',
'tiff',
]
):
api_base_url = self.kwargs.get("DATALAB_MARKER_API_BASE_URL", "")
if not api_base_url or api_base_url.strip() == "":
api_base_url = "https://www.datalab.to/api/v1/marker" # https://github.com/open-webui/open-webui/pull/16867#issuecomment-3218424349
api_base_url = self.kwargs.get('DATALAB_MARKER_API_BASE_URL', '')
if not api_base_url or api_base_url.strip() == '':
api_base_url = 'https://www.datalab.to/api/v1/marker' # https://github.com/open-webui/open-webui/pull/16867#issuecomment-3218424349
loader = DatalabMarkerLoader(
file_path=file_path,
api_key=self.kwargs["DATALAB_MARKER_API_KEY"],
api_key=self.kwargs['DATALAB_MARKER_API_KEY'],
api_base_url=api_base_url,
additional_config=self.kwargs.get("DATALAB_MARKER_ADDITIONAL_CONFIG"),
use_llm=self.kwargs.get("DATALAB_MARKER_USE_LLM", False),
skip_cache=self.kwargs.get("DATALAB_MARKER_SKIP_CACHE", False),
force_ocr=self.kwargs.get("DATALAB_MARKER_FORCE_OCR", False),
paginate=self.kwargs.get("DATALAB_MARKER_PAGINATE", False),
strip_existing_ocr=self.kwargs.get(
"DATALAB_MARKER_STRIP_EXISTING_OCR", False
),
disable_image_extraction=self.kwargs.get(
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", False
),
format_lines=self.kwargs.get("DATALAB_MARKER_FORMAT_LINES", False),
output_format=self.kwargs.get(
"DATALAB_MARKER_OUTPUT_FORMAT", "markdown"
),
additional_config=self.kwargs.get('DATALAB_MARKER_ADDITIONAL_CONFIG'),
use_llm=self.kwargs.get('DATALAB_MARKER_USE_LLM', False),
skip_cache=self.kwargs.get('DATALAB_MARKER_SKIP_CACHE', False),
force_ocr=self.kwargs.get('DATALAB_MARKER_FORCE_OCR', False),
paginate=self.kwargs.get('DATALAB_MARKER_PAGINATE', False),
strip_existing_ocr=self.kwargs.get('DATALAB_MARKER_STRIP_EXISTING_OCR', False),
disable_image_extraction=self.kwargs.get('DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION', False),
format_lines=self.kwargs.get('DATALAB_MARKER_FORMAT_LINES', False),
output_format=self.kwargs.get('DATALAB_MARKER_OUTPUT_FORMAT', 'markdown'),
)
elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"):
elif self.engine == 'docling' and self.kwargs.get('DOCLING_SERVER_URL'):
if self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
# Build params for DoclingLoader
params = self.kwargs.get("DOCLING_PARAMS", {})
params = self.kwargs.get('DOCLING_PARAMS', {})
if not isinstance(params, dict):
try:
params = json.loads(params)
except json.JSONDecodeError:
log.error("Invalid DOCLING_PARAMS format, expected JSON object")
log.error('Invalid DOCLING_PARAMS format, expected JSON object')
params = {}
loader = DoclingLoader(
url=self.kwargs.get("DOCLING_SERVER_URL"),
api_key=self.kwargs.get("DOCLING_API_KEY", None),
url=self.kwargs.get('DOCLING_SERVER_URL'),
api_key=self.kwargs.get('DOCLING_API_KEY', None),
file_path=file_path,
mime_type=file_content_type,
params=params,
)
elif (
self.engine == "document_intelligence"
and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
self.engine == 'document_intelligence'
and self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT') != ''
and (
file_ext in ["pdf", "docx", "ppt", "pptx"]
file_ext in ['pdf', 'docx', 'ppt', 'pptx']
or file_content_type
in [
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
'application/vnd.ms-powerpoint',
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
]
)
):
if self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != "":
if self.kwargs.get('DOCUMENT_INTELLIGENCE_KEY') != '':
loader = AzureAIDocumentIntelligenceLoader(
file_path=file_path,
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"),
api_endpoint=self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT'),
api_key=self.kwargs.get('DOCUMENT_INTELLIGENCE_KEY'),
api_model=self.kwargs.get('DOCUMENT_INTELLIGENCE_MODEL'),
)
else:
loader = AzureAIDocumentIntelligenceLoader(
file_path=file_path,
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
api_endpoint=self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT'),
azure_credential=DefaultAzureCredential(),
api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"),
api_model=self.kwargs.get('DOCUMENT_INTELLIGENCE_MODEL'),
)
elif self.engine == "mineru" and file_ext in [
"pdf"
]: # MinerU currently only supports PDF
mineru_timeout = self.kwargs.get("MINERU_API_TIMEOUT", 300)
elif self.engine == 'mineru' and file_ext in ['pdf']: # MinerU currently only supports PDF
mineru_timeout = self.kwargs.get('MINERU_API_TIMEOUT', 300)
if mineru_timeout:
try:
mineru_timeout = int(mineru_timeout)
@@ -386,111 +370,115 @@ class Loader:
loader = MinerULoader(
file_path=file_path,
api_mode=self.kwargs.get("MINERU_API_MODE", "local"),
api_url=self.kwargs.get("MINERU_API_URL", "http://localhost:8000"),
api_key=self.kwargs.get("MINERU_API_KEY", ""),
params=self.kwargs.get("MINERU_PARAMS", {}),
api_mode=self.kwargs.get('MINERU_API_MODE', 'local'),
api_url=self.kwargs.get('MINERU_API_URL', 'http://localhost:8000'),
api_key=self.kwargs.get('MINERU_API_KEY', ''),
params=self.kwargs.get('MINERU_PARAMS', {}),
timeout=mineru_timeout,
)
elif (
self.engine == "mistral_ocr"
and self.kwargs.get("MISTRAL_OCR_API_KEY") != ""
and file_ext
in ["pdf"] # Mistral OCR currently only supports PDF and images
self.engine == 'mistral_ocr'
and self.kwargs.get('MISTRAL_OCR_API_KEY') != ''
and file_ext in ['pdf'] # Mistral OCR currently only supports PDF and images
):
loader = MistralLoader(
base_url=self.kwargs.get("MISTRAL_OCR_API_BASE_URL"),
api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"),
base_url=self.kwargs.get('MISTRAL_OCR_API_BASE_URL'),
api_key=self.kwargs.get('MISTRAL_OCR_API_KEY'),
file_path=file_path,
)
else:
if file_ext == "pdf":
if file_ext == 'pdf':
loader = PyPDFLoader(
file_path,
extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"),
mode=self.kwargs.get("PDF_LOADER_MODE", "page"),
extract_images=self.kwargs.get('PDF_EXTRACT_IMAGES'),
mode=self.kwargs.get('PDF_LOADER_MODE', 'page'),
)
elif file_ext == "csv":
elif file_ext == 'csv':
loader = CSVLoader(file_path, autodetect_encoding=True)
elif file_ext == "rst":
elif file_ext == 'rst':
try:
from langchain_community.document_loaders import UnstructuredRSTLoader
loader = UnstructuredRSTLoader(file_path, mode="elements")
loader = UnstructuredRSTLoader(file_path, mode='elements')
except ImportError:
log.warning(
"The 'unstructured' package is not installed. "
"Falling back to plain text loading for .rst file. "
"Install it with: pip install unstructured"
'Falling back to plain text loading for .rst file. '
'Install it with: pip install unstructured'
)
loader = TextLoader(file_path, autodetect_encoding=True)
elif file_ext == "xml":
elif file_ext == 'xml':
try:
from langchain_community.document_loaders import UnstructuredXMLLoader
loader = UnstructuredXMLLoader(file_path)
except ImportError:
log.warning(
"The 'unstructured' package is not installed. "
"Falling back to plain text loading for .xml file. "
"Install it with: pip install unstructured"
'Falling back to plain text loading for .xml file. '
'Install it with: pip install unstructured'
)
loader = TextLoader(file_path, autodetect_encoding=True)
elif file_ext in ["htm", "html"]:
loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
elif file_ext == "md":
elif file_ext in ['htm', 'html']:
loader = BSHTMLLoader(file_path, open_encoding='unicode_escape')
elif file_ext == 'md':
loader = TextLoader(file_path, autodetect_encoding=True)
elif file_content_type == "application/epub+zip":
elif file_content_type == 'application/epub+zip':
try:
from langchain_community.document_loaders import UnstructuredEPubLoader
loader = UnstructuredEPubLoader(file_path)
except ImportError:
raise ValueError(
"Processing .epub files requires the 'unstructured' package. "
"Install it with: pip install unstructured"
'Install it with: pip install unstructured'
)
elif (
file_content_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
or file_ext == "docx"
file_content_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
or file_ext == 'docx'
):
loader = Docx2txtLoader(file_path)
elif file_content_type in [
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
] or file_ext in ["xls", "xlsx"]:
'application/vnd.ms-excel',
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
] or file_ext in ['xls', 'xlsx']:
try:
from langchain_community.document_loaders import UnstructuredExcelLoader
loader = UnstructuredExcelLoader(file_path)
except ImportError:
log.warning(
"The 'unstructured' package is not installed. "
"Falling back to pandas for Excel file loading. "
"Install unstructured for better results: pip install unstructured"
'Falling back to pandas for Excel file loading. '
'Install unstructured for better results: pip install unstructured'
)
loader = ExcelLoader(file_path)
elif file_content_type in [
"application/vnd.ms-powerpoint",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
] or file_ext in ["ppt", "pptx"]:
'application/vnd.ms-powerpoint',
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
] or file_ext in ['ppt', 'pptx']:
try:
from langchain_community.document_loaders import UnstructuredPowerPointLoader
loader = UnstructuredPowerPointLoader(file_path)
except ImportError:
log.warning(
"The 'unstructured' package is not installed. "
"Falling back to python-pptx for PowerPoint file loading. "
"Install unstructured for better results: pip install unstructured"
'Falling back to python-pptx for PowerPoint file loading. '
'Install unstructured for better results: pip install unstructured'
)
loader = PptxLoader(file_path)
elif file_ext == "msg":
elif file_ext == 'msg':
loader = OutlookMessageLoader(file_path)
elif file_ext == "odt":
elif file_ext == 'odt':
try:
from langchain_community.document_loaders import UnstructuredODTLoader
loader = UnstructuredODTLoader(file_path)
except ImportError:
raise ValueError(
"Processing .odt files requires the 'unstructured' package. "
"Install it with: pip install unstructured"
'Install it with: pip install unstructured'
)
elif self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True)
@@ -498,4 +486,3 @@ class Loader:
loader = TextLoader(file_path, autodetect_encoding=True)
return loader

View File

@@ -22,37 +22,35 @@ class MinerULoader:
def __init__(
self,
file_path: str,
api_mode: str = "local",
api_url: str = "http://localhost:8000",
api_key: str = "",
api_mode: str = 'local',
api_url: str = 'http://localhost:8000',
api_key: str = '',
params: dict = None,
timeout: Optional[int] = 300,
):
self.file_path = file_path
self.api_mode = api_mode.lower()
self.api_url = api_url.rstrip("/")
self.api_url = api_url.rstrip('/')
self.api_key = api_key
self.timeout = timeout
# Parse params dict with defaults
self.params = params or {}
self.enable_ocr = params.get("enable_ocr", False)
self.enable_formula = params.get("enable_formula", True)
self.enable_table = params.get("enable_table", True)
self.language = params.get("language", "en")
self.model_version = params.get("model_version", "pipeline")
self.enable_ocr = params.get('enable_ocr', False)
self.enable_formula = params.get('enable_formula', True)
self.enable_table = params.get('enable_table', True)
self.language = params.get('language', 'en')
self.model_version = params.get('model_version', 'pipeline')
self.page_ranges = self.params.pop("page_ranges", "")
self.page_ranges = self.params.pop('page_ranges', '')
# Validate API mode
if self.api_mode not in ["local", "cloud"]:
raise ValueError(
f"Invalid API mode: {self.api_mode}. Must be 'local' or 'cloud'"
)
if self.api_mode not in ['local', 'cloud']:
raise ValueError(f"Invalid API mode: {self.api_mode}. Must be 'local' or 'cloud'")
# Validate Cloud API requirements
if self.api_mode == "cloud" and not self.api_key:
raise ValueError("API key is required for Cloud API mode")
if self.api_mode == 'cloud' and not self.api_key:
raise ValueError('API key is required for Cloud API mode')
def load(self) -> List[Document]:
"""
@@ -60,12 +58,12 @@ class MinerULoader:
Routes to Cloud or Local API based on api_mode.
"""
try:
if self.api_mode == "cloud":
if self.api_mode == 'cloud':
return self._load_cloud_api()
else:
return self._load_local_api()
except Exception as e:
log.error(f"Error loading document with MinerU: {e}")
log.error(f'Error loading document with MinerU: {e}')
raise
def _load_local_api(self) -> List[Document]:
@@ -73,14 +71,14 @@ class MinerULoader:
Load document using Local API (synchronous).
Posts file to /file_parse endpoint and gets immediate response.
"""
log.info(f"Using MinerU Local API at {self.api_url}")
log.info(f'Using MinerU Local API at {self.api_url}')
filename = os.path.basename(self.file_path)
# Build form data for Local API
form_data = {
**self.params,
"return_md": "true",
'return_md': 'true',
}
# Page ranges (Local API uses start_page_id and end_page_id)
@@ -89,18 +87,18 @@ class MinerULoader:
# Full page range parsing would require parsing the string
log.warning(
f"Page ranges '{self.page_ranges}' specified but Local API uses different format. "
"Consider using start_page_id/end_page_id parameters if needed."
'Consider using start_page_id/end_page_id parameters if needed.'
)
try:
with open(self.file_path, "rb") as f:
files = {"files": (filename, f, "application/octet-stream")}
with open(self.file_path, 'rb') as f:
files = {'files': (filename, f, 'application/octet-stream')}
log.info(f"Sending file to MinerU Local API: {filename}")
log.debug(f"Local API parameters: {form_data}")
log.info(f'Sending file to MinerU Local API: {filename}')
log.debug(f'Local API parameters: {form_data}')
response = requests.post(
f"{self.api_url}/file_parse",
f'{self.api_url}/file_parse',
data=form_data,
files=files,
timeout=self.timeout,
@@ -108,27 +106,25 @@ class MinerULoader:
response.raise_for_status()
except FileNotFoundError:
raise HTTPException(
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
)
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}')
except requests.Timeout:
raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT,
detail="MinerU Local API request timed out",
detail='MinerU Local API request timed out',
)
except requests.HTTPError as e:
error_detail = f"MinerU Local API request failed: {e}"
error_detail = f'MinerU Local API request failed: {e}'
if e.response is not None:
try:
error_data = e.response.json()
error_detail += f" - {error_data}"
error_detail += f' - {error_data}'
except Exception:
error_detail += f" - {e.response.text}"
error_detail += f' - {e.response.text}'
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error calling MinerU Local API: {str(e)}",
detail=f'Error calling MinerU Local API: {str(e)}',
)
# Parse response
@@ -137,41 +133,41 @@ class MinerULoader:
except ValueError as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Invalid JSON response from MinerU Local API: {e}",
detail=f'Invalid JSON response from MinerU Local API: {e}',
)
# Extract markdown content from response
if "results" not in result:
if 'results' not in result:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail="MinerU Local API response missing 'results' field",
)
results = result["results"]
results = result['results']
if not results:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail="MinerU returned empty results",
detail='MinerU returned empty results',
)
# Get the first (and typically only) result
file_result = list(results.values())[0]
markdown_content = file_result.get("md_content", "")
markdown_content = file_result.get('md_content', '')
if not markdown_content:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail="MinerU returned empty markdown content",
detail='MinerU returned empty markdown content',
)
log.info(f"Successfully parsed document with MinerU Local API: {filename}")
log.info(f'Successfully parsed document with MinerU Local API: {filename}')
# Create metadata
metadata = {
"source": filename,
"api_mode": "local",
"backend": result.get("backend", "unknown"),
"version": result.get("version", "unknown"),
'source': filename,
'api_mode': 'local',
'backend': result.get('backend', 'unknown'),
'version': result.get('version', 'unknown'),
}
return [Document(page_content=markdown_content, metadata=metadata)]
@@ -181,7 +177,7 @@ class MinerULoader:
Load document using Cloud API (asynchronous).
Uses batch upload endpoint to avoid need for public file URLs.
"""
log.info(f"Using MinerU Cloud API at {self.api_url}")
log.info(f'Using MinerU Cloud API at {self.api_url}')
filename = os.path.basename(self.file_path)
@@ -195,17 +191,15 @@ class MinerULoader:
result = self._poll_batch_status(batch_id, filename)
# Step 4: Download and extract markdown from ZIP
markdown_content = self._download_and_extract_zip(
result["full_zip_url"], filename
)
markdown_content = self._download_and_extract_zip(result['full_zip_url'], filename)
log.info(f"Successfully parsed document with MinerU Cloud API: {filename}")
log.info(f'Successfully parsed document with MinerU Cloud API: {filename}')
# Create metadata
metadata = {
"source": filename,
"api_mode": "cloud",
"batch_id": batch_id,
'source': filename,
'api_mode': 'cloud',
'batch_id': batch_id,
}
return [Document(page_content=markdown_content, metadata=metadata)]
@@ -216,49 +210,49 @@ class MinerULoader:
Returns (batch_id, upload_url).
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
'Authorization': f'Bearer {self.api_key}',
'Content-Type': 'application/json',
}
# Build request body
request_body = {
**self.params,
"files": [
'files': [
{
"name": filename,
"is_ocr": self.enable_ocr,
'name': filename,
'is_ocr': self.enable_ocr,
}
],
}
# Add page ranges if specified
if self.page_ranges:
request_body["files"][0]["page_ranges"] = self.page_ranges
request_body['files'][0]['page_ranges'] = self.page_ranges
log.info(f"Requesting upload URL for: {filename}")
log.debug(f"Cloud API request body: {request_body}")
log.info(f'Requesting upload URL for: {filename}')
log.debug(f'Cloud API request body: {request_body}')
try:
response = requests.post(
f"{self.api_url}/file-urls/batch",
f'{self.api_url}/file-urls/batch',
headers=headers,
json=request_body,
timeout=30,
)
response.raise_for_status()
except requests.HTTPError as e:
error_detail = f"Failed to request upload URL: {e}"
error_detail = f'Failed to request upload URL: {e}'
if e.response is not None:
try:
error_data = e.response.json()
error_detail += f" - {error_data.get('msg', error_data)}"
error_detail += f' - {error_data.get("msg", error_data)}'
except Exception:
error_detail += f" - {e.response.text}"
error_detail += f' - {e.response.text}'
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error requesting upload URL: {str(e)}",
detail=f'Error requesting upload URL: {str(e)}',
)
try:
@@ -266,28 +260,28 @@ class MinerULoader:
except ValueError as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Invalid JSON response: {e}",
detail=f'Invalid JSON response: {e}',
)
# Check for API error response
if result.get("code") != 0:
if result.get('code') != 0:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}",
detail=f'MinerU Cloud API error: {result.get("msg", "Unknown error")}',
)
data = result.get("data", {})
batch_id = data.get("batch_id")
file_urls = data.get("file_urls", [])
data = result.get('data', {})
batch_id = data.get('batch_id')
file_urls = data.get('file_urls', [])
if not batch_id or not file_urls:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail="MinerU Cloud API response missing batch_id or file_urls",
detail='MinerU Cloud API response missing batch_id or file_urls',
)
upload_url = file_urls[0]
log.info(f"Received upload URL for batch: {batch_id}")
log.info(f'Received upload URL for batch: {batch_id}')
return batch_id, upload_url
@@ -295,10 +289,10 @@ class MinerULoader:
"""
Upload file to presigned URL (no authentication needed).
"""
log.info(f"Uploading file to presigned URL")
log.info(f'Uploading file to presigned URL')
try:
with open(self.file_path, "rb") as f:
with open(self.file_path, 'rb') as f:
response = requests.put(
upload_url,
data=f,
@@ -306,26 +300,24 @@ class MinerULoader:
)
response.raise_for_status()
except FileNotFoundError:
raise HTTPException(
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
)
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}')
except requests.Timeout:
raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT,
detail="File upload to presigned URL timed out",
detail='File upload to presigned URL timed out',
)
except requests.HTTPError as e:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Failed to upload file to presigned URL: {e}",
detail=f'Failed to upload file to presigned URL: {e}',
)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error uploading file: {str(e)}",
detail=f'Error uploading file: {str(e)}',
)
log.info("File uploaded successfully")
log.info('File uploaded successfully')
def _poll_batch_status(self, batch_id: str, filename: str) -> dict:
"""
@@ -333,35 +325,35 @@ class MinerULoader:
Returns the result dict for the file.
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
'Authorization': f'Bearer {self.api_key}',
}
max_iterations = 300 # 10 minutes max (2 seconds per iteration)
poll_interval = 2 # seconds
log.info(f"Polling batch status: {batch_id}")
log.info(f'Polling batch status: {batch_id}')
for iteration in range(max_iterations):
try:
response = requests.get(
f"{self.api_url}/extract-results/batch/{batch_id}",
f'{self.api_url}/extract-results/batch/{batch_id}',
headers=headers,
timeout=30,
)
response.raise_for_status()
except requests.HTTPError as e:
error_detail = f"Failed to poll batch status: {e}"
error_detail = f'Failed to poll batch status: {e}'
if e.response is not None:
try:
error_data = e.response.json()
error_detail += f" - {error_data.get('msg', error_data)}"
error_detail += f' - {error_data.get("msg", error_data)}'
except Exception:
error_detail += f" - {e.response.text}"
error_detail += f' - {e.response.text}'
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error polling batch status: {str(e)}",
detail=f'Error polling batch status: {str(e)}',
)
try:
@@ -369,58 +361,56 @@ class MinerULoader:
except ValueError as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Invalid JSON response while polling: {e}",
detail=f'Invalid JSON response while polling: {e}',
)
# Check for API error response
if result.get("code") != 0:
if result.get('code') != 0:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}",
detail=f'MinerU Cloud API error: {result.get("msg", "Unknown error")}',
)
data = result.get("data", {})
extract_result = data.get("extract_result", [])
data = result.get('data', {})
extract_result = data.get('extract_result', [])
# Find our file in the batch results
file_result = None
for item in extract_result:
if item.get("file_name") == filename:
if item.get('file_name') == filename:
file_result = item
break
if not file_result:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"File {filename} not found in batch results",
detail=f'File {filename} not found in batch results',
)
state = file_result.get("state")
state = file_result.get('state')
if state == "done":
log.info(f"Processing complete for {filename}")
if state == 'done':
log.info(f'Processing complete for {filename}')
return file_result
elif state == "failed":
error_msg = file_result.get("err_msg", "Unknown error")
elif state == 'failed':
error_msg = file_result.get('err_msg', 'Unknown error')
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"MinerU processing failed: {error_msg}",
detail=f'MinerU processing failed: {error_msg}',
)
elif state in ["waiting-file", "pending", "running", "converting"]:
elif state in ['waiting-file', 'pending', 'running', 'converting']:
# Still processing
if iteration % 10 == 0: # Log every 20 seconds
log.info(
f"Processing status: {state} (iteration {iteration + 1}/{max_iterations})"
)
log.info(f'Processing status: {state} (iteration {iteration + 1}/{max_iterations})')
time.sleep(poll_interval)
else:
log.warning(f"Unknown state: {state}")
log.warning(f'Unknown state: {state}')
time.sleep(poll_interval)
# Timeout
raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT,
detail="MinerU processing timed out after 10 minutes",
detail='MinerU processing timed out after 10 minutes',
)
def _download_and_extract_zip(self, zip_url: str, filename: str) -> str:
@@ -428,7 +418,7 @@ class MinerULoader:
Download ZIP file from CDN and extract markdown content.
Returns the markdown content as a string.
"""
log.info(f"Downloading results from: {zip_url}")
log.info(f'Downloading results from: {zip_url}')
try:
response = requests.get(zip_url, timeout=60)
@@ -436,23 +426,23 @@ class MinerULoader:
except requests.HTTPError as e:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail=f"Failed to download results ZIP: {e}",
detail=f'Failed to download results ZIP: {e}',
)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error downloading results: {str(e)}",
detail=f'Error downloading results: {str(e)}',
)
# Save ZIP to temporary file and extract
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as tmp_zip:
tmp_zip.write(response.content)
tmp_zip_path = tmp_zip.name
with tempfile.TemporaryDirectory() as tmp_dir:
# Extract ZIP
with zipfile.ZipFile(tmp_zip_path, "r") as zip_ref:
with zipfile.ZipFile(tmp_zip_path, 'r') as zip_ref:
zip_ref.extractall(tmp_dir)
# Find markdown file - search recursively for any .md file
@@ -466,33 +456,27 @@ class MinerULoader:
full_path = os.path.join(root, file)
all_files.append(full_path)
# Look for any .md file
if file.endswith(".md"):
if file.endswith('.md'):
found_md_path = full_path
log.info(f"Found markdown file at: {full_path}")
log.info(f'Found markdown file at: {full_path}')
try:
with open(full_path, "r", encoding="utf-8") as f:
with open(full_path, 'r', encoding='utf-8') as f:
markdown_content = f.read()
if (
markdown_content
): # Use the first non-empty markdown file
if markdown_content: # Use the first non-empty markdown file
break
except Exception as e:
log.warning(f"Failed to read {full_path}: {e}")
log.warning(f'Failed to read {full_path}: {e}')
if markdown_content:
break
if markdown_content is None:
log.error(f"Available files in ZIP: {all_files}")
log.error(f'Available files in ZIP: {all_files}')
# Try to provide more helpful error message
md_files = [f for f in all_files if f.endswith(".md")]
md_files = [f for f in all_files if f.endswith('.md')]
if md_files:
error_msg = (
f"Found .md files but couldn't read them: {md_files}"
)
error_msg = f"Found .md files but couldn't read them: {md_files}"
else:
error_msg = (
f"No .md files found in ZIP. Available files: {all_files}"
)
error_msg = f'No .md files found in ZIP. Available files: {all_files}'
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=error_msg,
@@ -504,21 +488,19 @@ class MinerULoader:
except zipfile.BadZipFile as e:
raise HTTPException(
status.HTTP_502_BAD_GATEWAY,
detail=f"Invalid ZIP file received: {e}",
detail=f'Invalid ZIP file received: {e}',
)
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error extracting ZIP: {str(e)}",
detail=f'Error extracting ZIP: {str(e)}',
)
if not markdown_content:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
detail="Extracted markdown content is empty",
detail='Extracted markdown content is empty',
)
log.info(
f"Successfully extracted markdown content ({len(markdown_content)} characters)"
)
log.info(f'Successfully extracted markdown content ({len(markdown_content)} characters)')
return markdown_content

View File

@@ -49,13 +49,11 @@ class MistralLoader:
enable_debug_logging: Enable detailed debug logs.
"""
if not api_key:
raise ValueError("API key cannot be empty.")
raise ValueError('API key cannot be empty.')
if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found at {file_path}")
raise FileNotFoundError(f'File not found at {file_path}')
self.base_url = (
base_url.rstrip("/") if base_url else "https://api.mistral.ai/v1"
)
self.base_url = base_url.rstrip('/') if base_url else 'https://api.mistral.ai/v1'
self.api_key = api_key
self.file_path = file_path
self.timeout = timeout
@@ -65,18 +63,10 @@ class MistralLoader:
# PERFORMANCE OPTIMIZATION: Differentiated timeouts for different operations
# This prevents long-running OCR operations from affecting quick operations
# and improves user experience by failing fast on operations that should be quick
self.upload_timeout = min(
timeout, 120
) # Cap upload at 2 minutes - prevents hanging on large files
self.url_timeout = (
30 # URL requests should be fast - fail quickly if API is slow
)
self.ocr_timeout = (
timeout # OCR can take the full timeout - this is the heavy operation
)
self.cleanup_timeout = (
30 # Cleanup should be quick - don't hang on file deletion
)
self.upload_timeout = min(timeout, 120) # Cap upload at 2 minutes - prevents hanging on large files
self.url_timeout = 30 # URL requests should be fast - fail quickly if API is slow
self.ocr_timeout = timeout # OCR can take the full timeout - this is the heavy operation
self.cleanup_timeout = 30 # Cleanup should be quick - don't hang on file deletion
# PERFORMANCE OPTIMIZATION: Pre-compute file info to avoid repeated filesystem calls
# This avoids multiple os.path.basename() and os.path.getsize() calls during processing
@@ -85,8 +75,8 @@ class MistralLoader:
# ENHANCEMENT: Added User-Agent for better API tracking and debugging
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"User-Agent": "OpenWebUI-MistralLoader/2.0", # Helps API provider track usage
'Authorization': f'Bearer {self.api_key}',
'User-Agent': 'OpenWebUI-MistralLoader/2.0', # Helps API provider track usage
}
def _debug_log(self, message: str, *args) -> None:
@@ -108,43 +98,39 @@ class MistralLoader:
return {} # Return empty dict if no content
return response.json()
except requests.exceptions.HTTPError as http_err:
log.error(f"HTTP error occurred: {http_err} - Response: {response.text}")
log.error(f'HTTP error occurred: {http_err} - Response: {response.text}')
raise
except requests.exceptions.RequestException as req_err:
log.error(f"Request exception occurred: {req_err}")
log.error(f'Request exception occurred: {req_err}')
raise
except ValueError as json_err: # Includes JSONDecodeError
log.error(f"JSON decode error: {json_err} - Response: {response.text}")
log.error(f'JSON decode error: {json_err} - Response: {response.text}')
raise # Re-raise after logging
async def _handle_response_async(
self, response: aiohttp.ClientResponse
) -> Dict[str, Any]:
async def _handle_response_async(self, response: aiohttp.ClientResponse) -> Dict[str, Any]:
"""Async version of response handling with better error info."""
try:
response.raise_for_status()
# Check content type
content_type = response.headers.get("content-type", "")
if "application/json" not in content_type:
content_type = response.headers.get('content-type', '')
if 'application/json' not in content_type:
if response.status == 204:
return {}
text = await response.text()
raise ValueError(
f"Unexpected content type: {content_type}, body: {text[:200]}..."
)
raise ValueError(f'Unexpected content type: {content_type}, body: {text[:200]}...')
return await response.json()
except aiohttp.ClientResponseError as e:
error_text = await response.text() if response else "No response"
log.error(f"HTTP {e.status}: {e.message} - Response: {error_text[:500]}")
error_text = await response.text() if response else 'No response'
log.error(f'HTTP {e.status}: {e.message} - Response: {error_text[:500]}')
raise
except aiohttp.ClientError as e:
log.error(f"Client error: {e}")
log.error(f'Client error: {e}')
raise
except Exception as e:
log.error(f"Unexpected error processing response: {e}")
log.error(f'Unexpected error processing response: {e}')
raise
def _is_retryable_error(self, error: Exception) -> bool:
@@ -172,13 +158,11 @@ class MistralLoader:
return True # Timeouts might resolve on retry
if isinstance(error, requests.exceptions.HTTPError):
# Only retry on server errors (5xx) or rate limits (429)
if hasattr(error, "response") and error.response is not None:
if hasattr(error, 'response') and error.response is not None:
status_code = error.response.status_code
return status_code >= 500 or status_code == 429
return False
if isinstance(
error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)
):
if isinstance(error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)):
return True # Async network/timeout errors are retryable
if isinstance(error, aiohttp.ClientResponseError):
return error.status >= 500 or error.status == 429
@@ -204,8 +188,7 @@ class MistralLoader:
# Prevents overwhelming the server while ensuring reasonable retry delays
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
log.warning(
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
f"Retrying in {wait_time}s..."
f'Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...'
)
time.sleep(wait_time)
@@ -226,8 +209,7 @@ class MistralLoader:
# PERFORMANCE OPTIMIZATION: Non-blocking exponential backoff
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
log.warning(
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
f"Retrying in {wait_time}s..."
f'Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...'
)
await asyncio.sleep(wait_time) # Non-blocking wait
@@ -240,15 +222,15 @@ class MistralLoader:
Although streaming is not enabled for this endpoint, the file is opened
in a context manager to minimize memory usage duration.
"""
log.info("Uploading file to Mistral API")
url = f"{self.base_url}/files"
log.info('Uploading file to Mistral API')
url = f'{self.base_url}/files'
def upload_request():
# MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime
# This ensures the file is closed immediately after reading, reducing memory usage
with open(self.file_path, "rb") as f:
files = {"file": (self.file_name, f, "application/pdf")}
data = {"purpose": "ocr"}
with open(self.file_path, 'rb') as f:
files = {'file': (self.file_name, f, 'application/pdf')}
data = {'purpose': 'ocr'}
# NOTE: stream=False is required for this endpoint
# The Mistral API doesn't support chunked uploads for this endpoint
@@ -265,42 +247,38 @@ class MistralLoader:
try:
response_data = self._retry_request_sync(upload_request)
file_id = response_data.get("id")
file_id = response_data.get('id')
if not file_id:
raise ValueError("File ID not found in upload response.")
log.info(f"File uploaded successfully. File ID: {file_id}")
raise ValueError('File ID not found in upload response.')
log.info(f'File uploaded successfully. File ID: {file_id}')
return file_id
except Exception as e:
log.error(f"Failed to upload file: {e}")
log.error(f'Failed to upload file: {e}')
raise
async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
"""Async file upload with streaming for better memory efficiency."""
url = f"{self.base_url}/files"
url = f'{self.base_url}/files'
async def upload_request():
# Create multipart writer for streaming upload
writer = aiohttp.MultipartWriter("form-data")
writer = aiohttp.MultipartWriter('form-data')
# Add purpose field
purpose_part = writer.append("ocr")
purpose_part.set_content_disposition("form-data", name="purpose")
purpose_part = writer.append('ocr')
purpose_part.set_content_disposition('form-data', name='purpose')
# Add file part with streaming
file_part = writer.append_payload(
aiohttp.streams.FilePayload(
self.file_path,
filename=self.file_name,
content_type="application/pdf",
content_type='application/pdf',
)
)
file_part.set_content_disposition(
"form-data", name="file", filename=self.file_name
)
file_part.set_content_disposition('form-data', name='file', filename=self.file_name)
self._debug_log(
f"Uploading file: {self.file_name} ({self.file_size:,} bytes)"
)
self._debug_log(f'Uploading file: {self.file_name} ({self.file_size:,} bytes)')
async with session.post(
url,
@@ -312,48 +290,44 @@ class MistralLoader:
response_data = await self._retry_request_async(upload_request)
file_id = response_data.get("id")
file_id = response_data.get('id')
if not file_id:
raise ValueError("File ID not found in upload response.")
raise ValueError('File ID not found in upload response.')
log.info(f"File uploaded successfully. File ID: {file_id}")
log.info(f'File uploaded successfully. File ID: {file_id}')
return file_id
def _get_signed_url(self, file_id: str) -> str:
"""Retrieves a temporary signed URL for the uploaded file (sync version)."""
log.info(f"Getting signed URL for file ID: {file_id}")
url = f"{self.base_url}/files/{file_id}/url"
params = {"expiry": 1}
signed_url_headers = {**self.headers, "Accept": "application/json"}
log.info(f'Getting signed URL for file ID: {file_id}')
url = f'{self.base_url}/files/{file_id}/url'
params = {'expiry': 1}
signed_url_headers = {**self.headers, 'Accept': 'application/json'}
def url_request():
response = requests.get(
url, headers=signed_url_headers, params=params, timeout=self.url_timeout
)
response = requests.get(url, headers=signed_url_headers, params=params, timeout=self.url_timeout)
return self._handle_response(response)
try:
response_data = self._retry_request_sync(url_request)
signed_url = response_data.get("url")
signed_url = response_data.get('url')
if not signed_url:
raise ValueError("Signed URL not found in response.")
log.info("Signed URL received.")
raise ValueError('Signed URL not found in response.')
log.info('Signed URL received.')
return signed_url
except Exception as e:
log.error(f"Failed to get signed URL: {e}")
log.error(f'Failed to get signed URL: {e}')
raise
async def _get_signed_url_async(
self, session: aiohttp.ClientSession, file_id: str
) -> str:
async def _get_signed_url_async(self, session: aiohttp.ClientSession, file_id: str) -> str:
"""Async signed URL retrieval."""
url = f"{self.base_url}/files/{file_id}/url"
params = {"expiry": 1}
url = f'{self.base_url}/files/{file_id}/url'
params = {'expiry': 1}
headers = {**self.headers, "Accept": "application/json"}
headers = {**self.headers, 'Accept': 'application/json'}
async def url_request():
self._debug_log(f"Getting signed URL for file ID: {file_id}")
self._debug_log(f'Getting signed URL for file ID: {file_id}')
async with session.get(
url,
headers=headers,
@@ -364,69 +338,65 @@ class MistralLoader:
response_data = await self._retry_request_async(url_request)
signed_url = response_data.get("url")
signed_url = response_data.get('url')
if not signed_url:
raise ValueError("Signed URL not found in response.")
raise ValueError('Signed URL not found in response.')
self._debug_log("Signed URL received successfully")
self._debug_log('Signed URL received successfully')
return signed_url
def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
"""Sends the signed URL to the OCR endpoint for processing (sync version)."""
log.info("Processing OCR via Mistral API")
url = f"{self.base_url}/ocr"
log.info('Processing OCR via Mistral API')
url = f'{self.base_url}/ocr'
ocr_headers = {
**self.headers,
"Content-Type": "application/json",
"Accept": "application/json",
'Content-Type': 'application/json',
'Accept': 'application/json',
}
payload = {
"model": "mistral-ocr-latest",
"document": {
"type": "document_url",
"document_url": signed_url,
'model': 'mistral-ocr-latest',
'document': {
'type': 'document_url',
'document_url': signed_url,
},
"include_image_base64": False,
'include_image_base64': False,
}
def ocr_request():
response = requests.post(
url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout
)
response = requests.post(url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout)
return self._handle_response(response)
try:
ocr_response = self._retry_request_sync(ocr_request)
log.info("OCR processing done.")
self._debug_log("OCR response: %s", ocr_response)
log.info('OCR processing done.')
self._debug_log('OCR response: %s', ocr_response)
return ocr_response
except Exception as e:
log.error(f"Failed during OCR processing: {e}")
log.error(f'Failed during OCR processing: {e}')
raise
async def _process_ocr_async(
self, session: aiohttp.ClientSession, signed_url: str
) -> Dict[str, Any]:
async def _process_ocr_async(self, session: aiohttp.ClientSession, signed_url: str) -> Dict[str, Any]:
"""Async OCR processing with timing metrics."""
url = f"{self.base_url}/ocr"
url = f'{self.base_url}/ocr'
headers = {
**self.headers,
"Content-Type": "application/json",
"Accept": "application/json",
'Content-Type': 'application/json',
'Accept': 'application/json',
}
payload = {
"model": "mistral-ocr-latest",
"document": {
"type": "document_url",
"document_url": signed_url,
'model': 'mistral-ocr-latest',
'document': {
'type': 'document_url',
'document_url': signed_url,
},
"include_image_base64": False,
'include_image_base64': False,
}
async def ocr_request():
log.info("Starting OCR processing via Mistral API")
log.info('Starting OCR processing via Mistral API')
start_time = time.time()
async with session.post(
@@ -438,7 +408,7 @@ class MistralLoader:
ocr_response = await self._handle_response_async(response)
processing_time = time.time() - start_time
log.info(f"OCR processing completed in {processing_time:.2f}s")
log.info(f'OCR processing completed in {processing_time:.2f}s')
return ocr_response
@@ -446,42 +416,36 @@ class MistralLoader:
def _delete_file(self, file_id: str) -> None:
"""Deletes the file from Mistral storage (sync version)."""
log.info(f"Deleting uploaded file ID: {file_id}")
url = f"{self.base_url}/files/{file_id}"
log.info(f'Deleting uploaded file ID: {file_id}')
url = f'{self.base_url}/files/{file_id}'
try:
response = requests.delete(
url, headers=self.headers, timeout=self.cleanup_timeout
)
response = requests.delete(url, headers=self.headers, timeout=self.cleanup_timeout)
delete_response = self._handle_response(response)
log.info(f"File deleted successfully: {delete_response}")
log.info(f'File deleted successfully: {delete_response}')
except Exception as e:
# Log error but don't necessarily halt execution if deletion fails
log.error(f"Failed to delete file ID {file_id}: {e}")
log.error(f'Failed to delete file ID {file_id}: {e}')
async def _delete_file_async(
self, session: aiohttp.ClientSession, file_id: str
) -> None:
async def _delete_file_async(self, session: aiohttp.ClientSession, file_id: str) -> None:
"""Async file deletion with error tolerance."""
try:
async def delete_request():
self._debug_log(f"Deleting file ID: {file_id}")
self._debug_log(f'Deleting file ID: {file_id}')
async with session.delete(
url=f"{self.base_url}/files/{file_id}",
url=f'{self.base_url}/files/{file_id}',
headers=self.headers,
timeout=aiohttp.ClientTimeout(
total=self.cleanup_timeout
), # Shorter timeout for cleanup
timeout=aiohttp.ClientTimeout(total=self.cleanup_timeout), # Shorter timeout for cleanup
) as response:
return await self._handle_response_async(response)
await self._retry_request_async(delete_request)
self._debug_log(f"File {file_id} deleted successfully")
self._debug_log(f'File {file_id} deleted successfully')
except Exception as e:
# Don't fail the entire process if cleanup fails
log.warning(f"Failed to delete file ID {file_id}: {e}")
log.warning(f'Failed to delete file ID {file_id}: {e}')
@asynccontextmanager
async def _get_session(self):
@@ -506,7 +470,7 @@ class MistralLoader:
async with aiohttp.ClientSession(
connector=connector,
timeout=timeout,
headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
headers={'User-Agent': 'OpenWebUI-MistralLoader/2.0'},
raise_for_status=False, # We handle status codes manually
trust_env=True,
) as session:
@@ -514,13 +478,13 @@ class MistralLoader:
def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
"""Process OCR results into Document objects with enhanced metadata and memory efficiency."""
pages_data = ocr_response.get("pages")
pages_data = ocr_response.get('pages')
if not pages_data:
log.warning("No pages found in OCR response.")
log.warning('No pages found in OCR response.')
return [
Document(
page_content="No text content found",
metadata={"error": "no_pages", "file_name": self.file_name},
page_content='No text content found',
metadata={'error': 'no_pages', 'file_name': self.file_name},
)
]
@@ -530,8 +494,8 @@ class MistralLoader:
# Process pages in a memory-efficient way
for page_data in pages_data:
page_content = page_data.get("markdown")
page_index = page_data.get("index") # API uses 0-based index
page_content = page_data.get('markdown')
page_index = page_data.get('index') # API uses 0-based index
if page_content is None or page_index is None:
skipped_pages += 1
@@ -548,7 +512,7 @@ class MistralLoader:
if not cleaned_content:
skipped_pages += 1
self._debug_log(f"Skipping empty page {page_index}")
self._debug_log(f'Skipping empty page {page_index}')
continue
# Create document with optimized metadata
@@ -556,34 +520,30 @@ class MistralLoader:
Document(
page_content=cleaned_content,
metadata={
"page": page_index, # 0-based index from API
"page_label": page_index + 1, # 1-based label for convenience
"total_pages": total_pages,
"file_name": self.file_name,
"file_size": self.file_size,
"processing_engine": "mistral-ocr",
"content_length": len(cleaned_content),
'page': page_index, # 0-based index from API
'page_label': page_index + 1, # 1-based label for convenience
'total_pages': total_pages,
'file_name': self.file_name,
'file_size': self.file_size,
'processing_engine': 'mistral-ocr',
'content_length': len(cleaned_content),
},
)
)
if skipped_pages > 0:
log.info(
f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages"
)
log.info(f'Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages')
if not documents:
# Case where pages existed but none had valid markdown/index
log.warning(
"OCR response contained pages, but none had valid content/index."
)
log.warning('OCR response contained pages, but none had valid content/index.')
return [
Document(
page_content="No valid text content found in document",
page_content='No valid text content found in document',
metadata={
"error": "no_valid_pages",
"total_pages": total_pages,
"file_name": self.file_name,
'error': 'no_valid_pages',
'total_pages': total_pages,
'file_name': self.file_name,
},
)
]
@@ -615,24 +575,20 @@ class MistralLoader:
documents = self._process_results(ocr_response)
total_time = time.time() - start_time
log.info(
f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
)
log.info(f'Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents')
return documents
except Exception as e:
total_time = time.time() - start_time
log.error(
f"An error occurred during the loading process after {total_time:.2f}s: {e}"
)
log.error(f'An error occurred during the loading process after {total_time:.2f}s: {e}')
# Return an error document on failure
return [
Document(
page_content=f"Error during processing: {e}",
page_content=f'Error during processing: {e}',
metadata={
"error": "processing_failed",
"file_name": self.file_name,
'error': 'processing_failed',
'file_name': self.file_name,
},
)
]
@@ -643,9 +599,7 @@ class MistralLoader:
self._delete_file(file_id)
except Exception as del_e:
# Log deletion error, but don't overwrite original error if one occurred
log.error(
f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}"
)
log.error(f'Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}')
async def load_async(self) -> List[Document]:
"""
@@ -672,21 +626,19 @@ class MistralLoader:
documents = self._process_results(ocr_response)
total_time = time.time() - start_time
log.info(
f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
)
log.info(f'Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents')
return documents
except Exception as e:
total_time = time.time() - start_time
log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}")
log.error(f'Async OCR workflow failed after {total_time:.2f}s: {e}')
return [
Document(
page_content=f"Error during OCR processing: {e}",
page_content=f'Error during OCR processing: {e}',
metadata={
"error": "processing_failed",
"file_name": self.file_name,
'error': 'processing_failed',
'file_name': self.file_name,
},
)
]
@@ -697,11 +649,11 @@ class MistralLoader:
async with self._get_session() as session:
await self._delete_file_async(session, file_id)
except Exception as cleanup_error:
log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}")
log.error(f'Cleanup failed for file ID {file_id}: {cleanup_error}')
@staticmethod
async def load_multiple_async(
loaders: List["MistralLoader"],
loaders: List['MistralLoader'],
max_concurrent: int = 5, # Limit concurrent requests
) -> List[List[Document]]:
"""
@@ -717,15 +669,13 @@ class MistralLoader:
if not loaders:
return []
log.info(
f"Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent"
)
log.info(f'Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent')
start_time = time.time()
# Use semaphore to control concurrency
semaphore = asyncio.Semaphore(max_concurrent)
async def process_with_semaphore(loader: "MistralLoader") -> List[Document]:
async def process_with_semaphore(loader: 'MistralLoader') -> List[Document]:
async with semaphore:
return await loader.load_async()
@@ -737,14 +687,14 @@ class MistralLoader:
processed_results = []
for i, result in enumerate(results):
if isinstance(result, Exception):
log.error(f"File {i} failed: {result}")
log.error(f'File {i} failed: {result}')
processed_results.append(
[
Document(
page_content=f"Error processing file: {result}",
page_content=f'Error processing file: {result}',
metadata={
"error": "batch_processing_failed",
"file_index": i,
'error': 'batch_processing_failed',
'file_index': i,
},
)
]
@@ -755,15 +705,13 @@ class MistralLoader:
# MONITORING: Log comprehensive batch processing statistics
total_time = time.time() - start_time
total_docs = sum(len(docs) for docs in processed_results)
success_count = sum(
1 for result in results if not isinstance(result, Exception)
)
success_count = sum(1 for result in results if not isinstance(result, Exception))
failure_count = len(results) - success_count
log.info(
f"Batch processing completed in {total_time:.2f}s: "
f"{success_count} files succeeded, {failure_count} files failed, "
f"produced {total_docs} total documents"
f'Batch processing completed in {total_time:.2f}s: '
f'{success_count} files succeeded, {failure_count} files failed, '
f'produced {total_docs} total documents'
)
return processed_results

View File

@@ -25,7 +25,7 @@ class TavilyLoader(BaseLoader):
self,
urls: Union[str, List[str]],
api_key: str,
extract_depth: Literal["basic", "advanced"] = "basic",
extract_depth: Literal['basic', 'advanced'] = 'basic',
continue_on_failure: bool = True,
) -> None:
"""Initialize Tavily Extract client.
@@ -42,13 +42,13 @@ class TavilyLoader(BaseLoader):
continue_on_failure: Whether to continue if extraction of a URL fails.
"""
if not urls:
raise ValueError("At least one URL must be provided.")
raise ValueError('At least one URL must be provided.')
self.api_key = api_key
self.urls = urls if isinstance(urls, list) else [urls]
self.extract_depth = extract_depth
self.continue_on_failure = continue_on_failure
self.api_url = "https://api.tavily.com/extract"
self.api_url = 'https://api.tavily.com/extract'
def lazy_load(self) -> Iterator[Document]:
"""Extract and yield documents from the URLs using Tavily Extract API."""
@@ -57,35 +57,35 @@ class TavilyLoader(BaseLoader):
batch_urls = self.urls[i : i + batch_size]
try:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}',
}
# Use string for single URL, array for multiple URLs
urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls
payload = {"urls": urls_param, "extract_depth": self.extract_depth}
payload = {'urls': urls_param, 'extract_depth': self.extract_depth}
# Make the API call
response = requests.post(self.api_url, headers=headers, json=payload)
response.raise_for_status()
response_data = response.json()
# Process successful results
for result in response_data.get("results", []):
url = result.get("url", "")
content = result.get("raw_content", "")
for result in response_data.get('results', []):
url = result.get('url', '')
content = result.get('raw_content', '')
if not content:
log.warning(f"No content extracted from {url}")
log.warning(f'No content extracted from {url}')
continue
# Add URLs as metadata
metadata = {"source": url}
metadata = {'source': url}
yield Document(
page_content=content,
metadata=metadata,
)
for failed in response_data.get("failed_results", []):
url = failed.get("url", "")
error = failed.get("error", "Unknown error")
log.error(f"Failed to extract content from {url}: {error}")
for failed in response_data.get('failed_results', []):
url = failed.get('url', '')
error = failed.get('error', 'Unknown error')
log.error(f'Failed to extract content from {url}: {error}')
except Exception as e:
if self.continue_on_failure:
log.error(f"Error extracting content from batch {batch_urls}: {e}")
log.error(f'Error extracting content from batch {batch_urls}: {e}')
else:
raise e

View File

@@ -7,14 +7,14 @@ from langchain_core.documents import Document
log = logging.getLogger(__name__)
ALLOWED_SCHEMES = {"http", "https"}
ALLOWED_SCHEMES = {'http', 'https'}
ALLOWED_NETLOCS = {
"youtu.be",
"m.youtube.com",
"youtube.com",
"www.youtube.com",
"www.youtube-nocookie.com",
"vid.plus",
'youtu.be',
'm.youtube.com',
'youtube.com',
'www.youtube.com',
'www.youtube-nocookie.com',
'vid.plus',
}
@@ -30,17 +30,17 @@ def _parse_video_id(url: str) -> Optional[str]:
path = parsed_url.path
if path.endswith("/watch"):
if path.endswith('/watch'):
query = parsed_url.query
parsed_query = parse_qs(query)
if "v" in parsed_query:
ids = parsed_query["v"]
if 'v' in parsed_query:
ids = parsed_query['v']
video_id = ids if isinstance(ids, str) else ids[0]
else:
return None
else:
path = parsed_url.path.lstrip("/")
video_id = path.split("/")[-1]
path = parsed_url.path.lstrip('/')
video_id = path.split('/')[-1]
if len(video_id) != 11: # Video IDs are 11 characters long
return None
@@ -54,13 +54,13 @@ class YoutubeLoader:
def __init__(
self,
video_id: str,
language: Union[str, Sequence[str]] = "en",
language: Union[str, Sequence[str]] = 'en',
proxy_url: Optional[str] = None,
):
"""Initialize with YouTube video ID."""
_video_id = _parse_video_id(video_id)
self.video_id = _video_id if _video_id is not None else video_id
self._metadata = {"source": video_id}
self._metadata = {'source': video_id}
self.proxy_url = proxy_url
# Ensure language is a list
@@ -70,8 +70,8 @@ class YoutubeLoader:
self.language = list(language)
# Add English as fallback if not already in the list
if "en" not in self.language:
self.language.append("en")
if 'en' not in self.language:
self.language.append('en')
def load(self) -> List[Document]:
"""Load YouTube transcripts into `Document` objects."""
@@ -85,14 +85,12 @@ class YoutubeLoader:
except ImportError:
raise ImportError(
'Could not import "youtube_transcript_api" Python package. '
"Please install it with `pip install youtube-transcript-api`."
'Please install it with `pip install youtube-transcript-api`.'
)
if self.proxy_url:
youtube_proxies = GenericProxyConfig(
http_url=self.proxy_url, https_url=self.proxy_url
)
log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
youtube_proxies = GenericProxyConfig(http_url=self.proxy_url, https_url=self.proxy_url)
log.debug(f'Using proxy URL: {self.proxy_url[:14]}...')
else:
youtube_proxies = None
@@ -100,7 +98,7 @@ class YoutubeLoader:
try:
transcript_list = transcript_api.list(self.video_id)
except Exception as e:
log.exception("Loading YouTube transcript failed")
log.exception('Loading YouTube transcript failed')
return []
# Try each language in order of priority
@@ -110,14 +108,10 @@ class YoutubeLoader:
if transcript.is_generated:
log.debug(f"Found generated transcript for language '{lang}'")
try:
transcript = transcript_list.find_manually_created_transcript(
[lang]
)
transcript = transcript_list.find_manually_created_transcript([lang])
log.debug(f"Found manual transcript for language '{lang}'")
except NoTranscriptFound:
log.debug(
f"No manual transcript found for language '{lang}', using generated"
)
log.debug(f"No manual transcript found for language '{lang}', using generated")
pass
log.debug(f"Found transcript for language '{lang}'")
@@ -131,12 +125,10 @@ class YoutubeLoader:
log.debug(f"Empty transcript for language '{lang}'")
continue
transcript_text = " ".join(
transcript_text = ' '.join(
map(
lambda transcript_piece: (
transcript_piece.text.strip(" ")
if hasattr(transcript_piece, "text")
else ""
transcript_piece.text.strip(' ') if hasattr(transcript_piece, 'text') else ''
),
transcript_pieces,
)
@@ -150,9 +142,9 @@ class YoutubeLoader:
raise e
# If we get here, all languages failed
languages_tried = ", ".join(self.language)
languages_tried = ', '.join(self.language)
log.warning(
f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
f'No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed.'
)
raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))

View File

@@ -13,19 +13,17 @@ log = logging.getLogger(__name__)
class ColBERT(BaseReranker):
def __init__(self, name, **kwargs) -> None:
log.info("ColBERT: Loading model", name)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
log.info('ColBERT: Loading model', name)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
DOCKER = kwargs.get("env") == "docker"
DOCKER = kwargs.get('env') == 'docker'
if DOCKER:
# This is a workaround for the issue with the docker container
# where the torch extension is not loaded properly
# and the following error is thrown:
# /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory
lock_file = (
"/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock"
)
lock_file = '/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock'
if os.path.exists(lock_file):
os.remove(lock_file)
@@ -36,23 +34,16 @@ class ColBERT(BaseReranker):
pass
def calculate_similarity_scores(self, query_embeddings, document_embeddings):
query_embeddings = query_embeddings.to(self.device)
document_embeddings = document_embeddings.to(self.device)
# Validate dimensions to ensure compatibility
if query_embeddings.dim() != 3:
raise ValueError(
f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
)
raise ValueError(f'Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}.')
if document_embeddings.dim() != 3:
raise ValueError(
f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
)
raise ValueError(f'Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}.')
if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
raise ValueError(
"There should be either one query or queries equal to the number of documents."
)
raise ValueError('There should be either one query or queries equal to the number of documents.')
# Transpose the query embeddings to align for matrix multiplication
transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
@@ -69,7 +60,6 @@ class ColBERT(BaseReranker):
return normalized_scores.detach().cpu().numpy().astype(np.float32)
def predict(self, sentences):
query = sentences[0][0]
docs = [i[1] for i in sentences]
@@ -80,8 +70,6 @@ class ColBERT(BaseReranker):
embedded_query = embedded_queries[0]
# Calculate retrieval scores for the query against all documents
scores = self.calculate_similarity_scores(
embedded_query.unsqueeze(0), embedded_docs
)
scores = self.calculate_similarity_scores(embedded_query.unsqueeze(0), embedded_docs)
return scores

View File

@@ -15,8 +15,8 @@ class ExternalReranker(BaseReranker):
def __init__(
self,
api_key: str,
url: str = "http://localhost:8080/v1/rerank",
model: str = "reranker",
url: str = 'http://localhost:8080/v1/rerank',
model: str = 'reranker',
timeout: Optional[int] = None,
):
self.api_key = api_key
@@ -24,33 +24,31 @@ class ExternalReranker(BaseReranker):
self.model = model
self.timeout = timeout
def predict(
self, sentences: List[Tuple[str, str]], user=None
) -> Optional[List[float]]:
def predict(self, sentences: List[Tuple[str, str]], user=None) -> Optional[List[float]]:
query = sentences[0][0]
docs = [i[1] for i in sentences]
payload = {
"model": self.model,
"query": query,
"documents": docs,
"top_n": len(docs),
'model': self.model,
'query': query,
'documents': docs,
'top_n': len(docs),
}
try:
log.info(f"ExternalReranker:predict:model {self.model}")
log.info(f"ExternalReranker:predict:query {query}")
log.info(f'ExternalReranker:predict:model {self.model}')
log.info(f'ExternalReranker:predict:query {query}')
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}',
}
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user)
r = requests.post(
f"{self.url}",
f'{self.url}',
headers=headers,
json=payload,
timeout=self.timeout,
@@ -60,13 +58,13 @@ class ExternalReranker(BaseReranker):
r.raise_for_status()
data = r.json()
if "results" in data:
sorted_results = sorted(data["results"], key=lambda x: x["index"])
return [result["relevance_score"] for result in sorted_results]
if 'results' in data:
sorted_results = sorted(data['results'], key=lambda x: x['index'])
return [result['relevance_score'] for result in sorted_results]
else:
log.error("No results found in external reranking response")
log.error('No results found in external reranking response')
return None
except Exception as e:
log.exception(f"Error in external reranking: {e}")
log.exception(f'Error in external reranking: {e}')
return None

File diff suppressed because it is too large Load Diff

View File

@@ -31,17 +31,15 @@ log = logging.getLogger(__name__)
class ChromaClient(VectorDBBase):
def __init__(self):
settings_dict = {
"allow_reset": True,
"anonymized_telemetry": False,
'allow_reset': True,
'anonymized_telemetry': False,
}
if CHROMA_CLIENT_AUTH_PROVIDER is not None:
settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
settings_dict['chroma_client_auth_provider'] = CHROMA_CLIENT_AUTH_PROVIDER
if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
settings_dict["chroma_client_auth_credentials"] = (
CHROMA_CLIENT_AUTH_CREDENTIALS
)
settings_dict['chroma_client_auth_credentials'] = CHROMA_CLIENT_AUTH_CREDENTIALS
if CHROMA_HTTP_HOST != "":
if CHROMA_HTTP_HOST != '':
self.client = chromadb.HttpClient(
host=CHROMA_HTTP_HOST,
port=CHROMA_HTTP_PORT,
@@ -87,25 +85,23 @@ class ChromaClient(VectorDBBase):
# chromadb has cosine distance, 2 (worst) -> 0 (best). Re-odering to 0 -> 1
# https://docs.trychroma.com/docs/collections/configure cosine equation
distances: list = result["distances"][0]
distances: list = result['distances'][0]
distances = [2 - dist for dist in distances]
distances = [[dist / 2 for dist in distances]]
return SearchResult(
**{
"ids": result["ids"],
"distances": distances,
"documents": result["documents"],
"metadatas": result["metadatas"],
'ids': result['ids'],
'distances': distances,
'documents': result['documents'],
'metadatas': result['metadatas'],
}
)
return None
except Exception as e:
return None
def query(
self, collection_name: str, filter: dict, limit: Optional[int] = None
) -> Optional[GetResult]:
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) -> Optional[GetResult]:
# Query the items from the collection based on the filter.
try:
collection = self.client.get_collection(name=collection_name)
@@ -117,9 +113,9 @@ class ChromaClient(VectorDBBase):
return GetResult(
**{
"ids": [result["ids"]],
"documents": [result["documents"]],
"metadatas": [result["metadatas"]],
'ids': [result['ids']],
'documents': [result['documents']],
'metadatas': [result['metadatas']],
}
)
return None
@@ -133,23 +129,21 @@ class ChromaClient(VectorDBBase):
result = collection.get()
return GetResult(
**{
"ids": [result["ids"]],
"documents": [result["documents"]],
"metadatas": [result["metadatas"]],
'ids': [result['ids']],
'documents': [result['documents']],
'metadatas': [result['metadatas']],
}
)
return None
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created.
collection = self.client.get_or_create_collection(
name=collection_name, metadata={"hnsw:space": "cosine"}
)
collection = self.client.get_or_create_collection(name=collection_name, metadata={'hnsw:space': 'cosine'})
ids = [item["id"] for item in items]
documents = [item["text"] for item in items]
embeddings = [item["vector"] for item in items]
metadatas = [process_metadata(item["metadata"]) for item in items]
ids = [item['id'] for item in items]
documents = [item['text'] for item in items]
embeddings = [item['vector'] for item in items]
metadatas = [process_metadata(item['metadata']) for item in items]
for batch in create_batches(
api=self.client,
@@ -162,18 +156,14 @@ class ChromaClient(VectorDBBase):
def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
collection = self.client.get_or_create_collection(
name=collection_name, metadata={"hnsw:space": "cosine"}
)
collection = self.client.get_or_create_collection(name=collection_name, metadata={'hnsw:space': 'cosine'})
ids = [item["id"] for item in items]
documents = [item["text"] for item in items]
embeddings = [item["vector"] for item in items]
metadatas = [process_metadata(item["metadata"]) for item in items]
ids = [item['id'] for item in items]
documents = [item['text'] for item in items]
embeddings = [item['vector'] for item in items]
metadatas = [process_metadata(item['metadata']) for item in items]
collection.upsert(
ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
)
collection.upsert(ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas)
def delete(
self,
@@ -191,9 +181,7 @@ class ChromaClient(VectorDBBase):
collection.delete(where=filter)
except Exception as e:
# If collection doesn't exist, that's fine - nothing to delete
log.debug(
f"Attempted to delete from non-existent collection {collection_name}. Ignoring."
)
log.debug(f'Attempted to delete from non-existent collection {collection_name}. Ignoring.')
pass
def reset(self):

View File

@@ -51,7 +51,7 @@ class ElasticsearchClient(VectorDBBase):
# Status: works
def _get_index_name(self, dimension: int) -> str:
return f"{self.index_prefix}_d{str(dimension)}"
return f'{self.index_prefix}_d{str(dimension)}'
# Status: works
def _scan_result_to_get_result(self, result) -> GetResult:
@@ -62,24 +62,24 @@ class ElasticsearchClient(VectorDBBase):
metadatas = []
for hit in result:
ids.append(hit["_id"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
ids.append(hit['_id'])
documents.append(hit['_source'].get('text'))
metadatas.append(hit['_source'].get('metadata'))
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
# Status: works
def _result_to_get_result(self, result) -> GetResult:
if not result["hits"]["hits"]:
if not result['hits']['hits']:
return None
ids = []
documents = []
metadatas = []
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
for hit in result['hits']['hits']:
ids.append(hit['_id'])
documents.append(hit['_source'].get('text'))
metadatas.append(hit['_source'].get('metadata'))
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
@@ -90,11 +90,11 @@ class ElasticsearchClient(VectorDBBase):
documents = []
metadatas = []
for hit in result["hits"]["hits"]:
ids.append(hit["_id"])
distances.append(hit["_score"])
documents.append(hit["_source"].get("text"))
metadatas.append(hit["_source"].get("metadata"))
for hit in result['hits']['hits']:
ids.append(hit['_id'])
distances.append(hit['_score'])
documents.append(hit['_source'].get('text'))
metadatas.append(hit['_source'].get('metadata'))
return SearchResult(
ids=[ids],
@@ -106,26 +106,26 @@ class ElasticsearchClient(VectorDBBase):
# Status: works
def _create_index(self, dimension: int):
body = {
"mappings": {
"dynamic_templates": [
'mappings': {
'dynamic_templates': [
{
"strings": {
"match_mapping_type": "string",
"mapping": {"type": "keyword"},
'strings': {
'match_mapping_type': 'string',
'mapping': {'type': 'keyword'},
}
}
],
"properties": {
"collection": {"type": "keyword"},
"id": {"type": "keyword"},
"vector": {
"type": "dense_vector",
"dims": dimension, # Adjust based on your vector dimensions
"index": True,
"similarity": "cosine",
'properties': {
'collection': {'type': 'keyword'},
'id': {'type': 'keyword'},
'vector': {
'type': 'dense_vector',
'dims': dimension, # Adjust based on your vector dimensions
'index': True,
'similarity': 'cosine',
},
"text": {"type": "text"},
"metadata": {"type": "object"},
'text': {'type': 'text'},
'metadata': {'type': 'object'},
},
}
}
@@ -139,21 +139,19 @@ class ElasticsearchClient(VectorDBBase):
# Status: works
def has_collection(self, collection_name) -> bool:
query_body = {"query": {"bool": {"filter": []}}}
query_body["query"]["bool"]["filter"].append(
{"term": {"collection": collection_name}}
)
query_body = {'query': {'bool': {'filter': []}}}
query_body['query']['bool']['filter'].append({'term': {'collection': collection_name}})
try:
result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
result = self.client.count(index=f'{self.index_prefix}*', body=query_body)
return result.body["count"] > 0
return result.body['count'] > 0
except Exception as e:
return None
def delete_collection(self, collection_name: str):
query = {"query": {"term": {"collection": collection_name}}}
self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
query = {'query': {'term': {'collection': collection_name}}}
self.client.delete_by_query(index=f'{self.index_prefix}*', body=query)
# Status: works
def search(
@@ -164,51 +162,41 @@ class ElasticsearchClient(VectorDBBase):
limit: int = 10,
) -> Optional[SearchResult]:
query = {
"size": limit,
"_source": ["text", "metadata"],
"query": {
"script_score": {
"query": {
"bool": {"filter": [{"term": {"collection": collection_name}}]}
},
"script": {
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
"params": {
"vector": vectors[0]
}, # Assuming single query vector
'size': limit,
'_source': ['text', 'metadata'],
'query': {
'script_score': {
'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}},
'script': {
'source': "cosineSimilarity(params.vector, 'vector') + 1.0",
'params': {'vector': vectors[0]}, # Assuming single query vector
},
}
},
}
result = self.client.search(
index=self._get_index_name(len(vectors[0])), body=query
)
result = self.client.search(index=self._get_index_name(len(vectors[0])), body=query)
return self._result_to_search_result(result)
# Status: only tested halfwat
def query(
self, collection_name: str, filter: dict, limit: Optional[int] = None
) -> Optional[GetResult]:
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) -> Optional[GetResult]:
if not self.has_collection(collection_name):
return None
query_body = {
"query": {"bool": {"filter": []}},
"_source": ["text", "metadata"],
'query': {'bool': {'filter': []}},
'_source': ['text', 'metadata'],
}
for field, value in filter.items():
query_body["query"]["bool"]["filter"].append({"term": {field: value}})
query_body["query"]["bool"]["filter"].append(
{"term": {"collection": collection_name}}
)
query_body['query']['bool']['filter'].append({'term': {field: value}})
query_body['query']['bool']['filter'].append({'term': {'collection': collection_name}})
size = limit if limit else 10
try:
result = self.client.search(
index=f"{self.index_prefix}*",
index=f'{self.index_prefix}*',
body=query_body,
size=size,
)
@@ -220,9 +208,7 @@ class ElasticsearchClient(VectorDBBase):
# Status: works
def _has_index(self, dimension: int):
return self.client.indices.exists(
index=self._get_index_name(dimension=dimension)
)
return self.client.indices.exists(index=self._get_index_name(dimension=dimension))
def get_or_create_index(self, dimension: int):
if not self._has_index(dimension=dimension):
@@ -232,28 +218,28 @@ class ElasticsearchClient(VectorDBBase):
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection.
query = {
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
"_source": ["text", "metadata"],
'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}},
'_source': ['text', 'metadata'],
}
results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
results = list(scan(self.client, index=f'{self.index_prefix}*', query=query))
return self._scan_result_to_get_result(results)
# Status: works
def insert(self, collection_name: str, items: list[VectorItem]):
if not self._has_index(dimension=len(items[0]["vector"])):
self._create_index(dimension=len(items[0]["vector"]))
if not self._has_index(dimension=len(items[0]['vector'])):
self._create_index(dimension=len(items[0]['vector']))
for batch in self._create_batches(items):
actions = [
{
"_index": self._get_index_name(dimension=len(items[0]["vector"])),
"_id": item["id"],
"_source": {
"collection": collection_name,
"vector": item["vector"],
"text": item["text"],
"metadata": process_metadata(item["metadata"]),
'_index': self._get_index_name(dimension=len(items[0]['vector'])),
'_id': item['id'],
'_source': {
'collection': collection_name,
'vector': item['vector'],
'text': item['text'],
'metadata': process_metadata(item['metadata']),
},
}
for item in batch
@@ -262,21 +248,21 @@ class ElasticsearchClient(VectorDBBase):
# Upsert documents using the update API with doc_as_upsert=True.
def upsert(self, collection_name: str, items: list[VectorItem]):
if not self._has_index(dimension=len(items[0]["vector"])):
self._create_index(dimension=len(items[0]["vector"]))
if not self._has_index(dimension=len(items[0]['vector'])):
self._create_index(dimension=len(items[0]['vector']))
for batch in self._create_batches(items):
actions = [
{
"_op_type": "update",
"_index": self._get_index_name(dimension=len(item["vector"])),
"_id": item["id"],
"doc": {
"collection": collection_name,
"vector": item["vector"],
"text": item["text"],
"metadata": process_metadata(item["metadata"]),
'_op_type': 'update',
'_index': self._get_index_name(dimension=len(item['vector'])),
'_id': item['id'],
'doc': {
'collection': collection_name,
'vector': item['vector'],
'text': item['text'],
'metadata': process_metadata(item['metadata']),
},
"doc_as_upsert": True,
'doc_as_upsert': True,
}
for item in batch
]
@@ -289,22 +275,17 @@ class ElasticsearchClient(VectorDBBase):
ids: Optional[list[str]] = None,
filter: Optional[dict] = None,
):
query = {
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}
}
query = {'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}}}
# logic based on chromaDB
if ids:
query["query"]["bool"]["filter"].append({"terms": {"_id": ids}})
query['query']['bool']['filter'].append({'terms': {'_id': ids}})
elif filter:
for field, value in filter.items():
query["query"]["bool"]["filter"].append(
{"term": {f"metadata.{field}": value}}
)
query['query']['bool']['filter'].append({'term': {f'metadata.{field}': value}})
self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
self.client.delete_by_query(index=f'{self.index_prefix}*', body=query)
def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}*")
indices = self.client.indices.get(index=f'{self.index_prefix}*')
for index in indices:
self.client.indices.delete(index=index)

View File

@@ -47,8 +47,8 @@ def _embedding_to_f32_bytes(vec: List[float]) -> bytes:
byte sequence. We use array('f') to avoid a numpy dependency and byteswap on
big-endian platforms for portability.
"""
a = array.array("f", [float(x) for x in vec]) # float32
if sys.byteorder != "little":
a = array.array('f', [float(x) for x in vec]) # float32
if sys.byteorder != 'little':
a.byteswap()
return a.tobytes()
@@ -68,7 +68,7 @@ def _safe_json(v: Any) -> Dict[str, Any]:
return v
if isinstance(v, (bytes, bytearray)):
try:
v = v.decode("utf-8")
v = v.decode('utf-8')
except Exception:
return {}
if isinstance(v, str):
@@ -105,16 +105,16 @@ class MariaDBVectorClient(VectorDBBase):
"""
self.db_url = (db_url or MARIADB_VECTOR_DB_URL).strip()
self.vector_length = int(vector_length)
self.distance_strategy = (distance_strategy or "cosine").strip().lower()
self.distance_strategy = (distance_strategy or 'cosine').strip().lower()
self.index_m = int(index_m)
if self.distance_strategy not in {"cosine", "euclidean"}:
if self.distance_strategy not in {'cosine', 'euclidean'}:
raise ValueError("distance_strategy must be 'cosine' or 'euclidean'")
if not self.db_url.lower().startswith("mariadb+mariadbconnector://"):
if not self.db_url.lower().startswith('mariadb+mariadbconnector://'):
raise ValueError(
"MariaDBVectorClient requires mariadb+mariadbconnector:// (official MariaDB driver) "
"to ensure qmark paramstyle and correct VECTOR binding."
'MariaDBVectorClient requires mariadb+mariadbconnector:// (official MariaDB driver) '
'to ensure qmark paramstyle and correct VECTOR binding.'
)
if isinstance(MARIADB_VECTOR_POOL_SIZE, int):
@@ -129,9 +129,7 @@ class MariaDBVectorClient(VectorDBBase):
poolclass=QueuePool,
)
else:
self.engine = create_engine(
self.db_url, pool_pre_ping=True, poolclass=NullPool
)
self.engine = create_engine(self.db_url, pool_pre_ping=True, poolclass=NullPool)
else:
self.engine = create_engine(self.db_url, pool_pre_ping=True)
self._init_schema()
@@ -185,7 +183,7 @@ class MariaDBVectorClient(VectorDBBase):
conn.commit()
except Exception as e:
conn.rollback()
log.exception(f"Error during database initialization: {e}")
log.exception(f'Error during database initialization: {e}')
raise
def _check_vector_length(self) -> None:
@@ -197,19 +195,19 @@ class MariaDBVectorClient(VectorDBBase):
"""
with self._connect() as conn:
with conn.cursor() as cur:
cur.execute("SHOW CREATE TABLE document_chunk")
cur.execute('SHOW CREATE TABLE document_chunk')
row = cur.fetchone()
if not row or len(row) < 2:
return
ddl = row[1]
m = re.search(r"vector\\((\\d+)\\)", ddl, flags=re.IGNORECASE)
m = re.search(r'vector\\((\\d+)\\)', ddl, flags=re.IGNORECASE)
if not m:
return
existing = int(m.group(1))
if existing != int(self.vector_length):
raise Exception(
f"VECTOR_LENGTH {self.vector_length} does not match existing vector column dimension {existing}. "
"Cannot change vector size after initialization without migrating the data."
f'VECTOR_LENGTH {self.vector_length} does not match existing vector column dimension {existing}. '
'Cannot change vector size after initialization without migrating the data.'
)
def adjust_vector_length(self, vector: List[float]) -> List[float]:
@@ -227,11 +225,7 @@ class MariaDBVectorClient(VectorDBBase):
"""
Return the MariaDB Vector distance function name for the configured strategy.
"""
return (
"vec_distance_cosine"
if self.distance_strategy == "cosine"
else "vec_distance_euclidean"
)
return 'vec_distance_cosine' if self.distance_strategy == 'cosine' else 'vec_distance_euclidean'
def _score_from_dist(self, dist: float) -> float:
"""
@@ -240,7 +234,7 @@ class MariaDBVectorClient(VectorDBBase):
- cosine: score ~= 1 - cosine_distance, clamped to [0, 1]
- euclidean: score = 1 / (1 + dist)
"""
if self.distance_strategy == "cosine":
if self.distance_strategy == 'cosine':
score = 1.0 - dist
if score < 0.0:
score = 0.0
@@ -260,48 +254,48 @@ class MariaDBVectorClient(VectorDBBase):
- {"$or": [ ... ]}
"""
if not expr or not isinstance(expr, dict):
return "", []
return '', []
if "$and" in expr:
if '$and' in expr:
parts: List[str] = []
params: List[Any] = []
for e in expr.get("$and") or []:
for e in expr.get('$and') or []:
s, p = self._build_filter_sql_qmark(e)
if s:
parts.append(s)
params.extend(p)
return ("(" + " AND ".join(parts) + ")") if parts else "", params
return ('(' + ' AND '.join(parts) + ')') if parts else '', params
if "$or" in expr:
if '$or' in expr:
parts: List[str] = []
params: List[Any] = []
for e in expr.get("$or") or []:
for e in expr.get('$or') or []:
s, p = self._build_filter_sql_qmark(e)
if s:
parts.append(s)
params.extend(p)
return ("(" + " OR ".join(parts) + ")") if parts else "", params
return ('(' + ' OR '.join(parts) + ')') if parts else '', params
clauses: List[str] = []
params: List[Any] = []
for key, value in expr.items():
if key.startswith("$"):
if key.startswith('$'):
continue
json_expr = f"JSON_UNQUOTE(JSON_EXTRACT(vmetadata, '$.{key}'))"
if isinstance(value, dict) and "$in" in value:
vals = [str(v) for v in (value.get("$in") or [])]
if isinstance(value, dict) and '$in' in value:
vals = [str(v) for v in (value.get('$in') or [])]
if not vals:
clauses.append("0=1")
clauses.append('0=1')
continue
ors = []
for v in vals:
ors.append(f"{json_expr} = ?")
ors.append(f'{json_expr} = ?')
params.append(v)
clauses.append("(" + " OR ".join(ors) + ")")
clauses.append('(' + ' OR '.join(ors) + ')')
else:
clauses.append(f"{json_expr} = ?")
clauses.append(f'{json_expr} = ?')
params.append(str(value))
return ("(" + " AND ".join(clauses) + ")") if clauses else "", params
return ('(' + ' AND '.join(clauses) + ')') if clauses else '', params
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
"""
@@ -322,15 +316,15 @@ class MariaDBVectorClient(VectorDBBase):
"""
params: List[Tuple[Any, ...]] = []
for item in items:
v = self.adjust_vector_length(item["vector"])
v = self.adjust_vector_length(item['vector'])
emb = _embedding_to_f32_bytes(v)
meta = process_metadata(item.get("metadata") or {})
meta = process_metadata(item.get('metadata') or {})
params.append(
(
item["id"],
item['id'],
emb,
collection_name,
item.get("text"),
item.get('text'),
json.dumps(meta),
)
)
@@ -338,7 +332,7 @@ class MariaDBVectorClient(VectorDBBase):
conn.commit()
except Exception as e:
conn.rollback()
log.exception(f"Error during insert: {e}")
log.exception(f'Error during insert: {e}')
raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -365,15 +359,15 @@ class MariaDBVectorClient(VectorDBBase):
"""
params: List[Tuple[Any, ...]] = []
for item in items:
v = self.adjust_vector_length(item["vector"])
v = self.adjust_vector_length(item['vector'])
emb = _embedding_to_f32_bytes(v)
meta = process_metadata(item.get("metadata") or {})
meta = process_metadata(item.get('metadata') or {})
params.append(
(
item["id"],
item['id'],
emb,
collection_name,
item.get("text"),
item.get('text'),
json.dumps(meta),
)
)
@@ -381,7 +375,7 @@ class MariaDBVectorClient(VectorDBBase):
conn.commit()
except Exception as e:
conn.rollback()
log.exception(f"Error during upsert: {e}")
log.exception(f'Error during upsert: {e}')
raise
def search(
@@ -415,10 +409,10 @@ class MariaDBVectorClient(VectorDBBase):
with self._connect() as conn:
with conn.cursor() as cur:
fsql, fparams = self._build_filter_sql_qmark(filter or {})
where = "collection_name = ?"
where = 'collection_name = ?'
base_params: List[Any] = [collection_name]
if fsql:
where = where + " AND " + fsql
where = where + ' AND ' + fsql
base_params.extend(fparams)
sql = f"""
@@ -460,26 +454,24 @@ class MariaDBVectorClient(VectorDBBase):
metadatas=metadatas,
)
except Exception as e:
log.exception(f"[MARIADB_VECTOR] search() failed: {e}")
log.exception(f'[MARIADB_VECTOR] search() failed: {e}')
return None
def query(
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
) -> Optional[GetResult]:
def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
"""
Retrieve documents by metadata filter (non-vector query).
"""
with self._connect() as conn:
with conn.cursor() as cur:
fsql, fparams = self._build_filter_sql_qmark(filter or {})
where = "collection_name = ?"
where = 'collection_name = ?'
params: List[Any] = [collection_name]
if fsql:
where = where + " AND " + fsql
where = where + ' AND ' + fsql
params.extend(fparams)
sql = f"SELECT id, text, vmetadata FROM document_chunk WHERE {where}"
sql = f'SELECT id, text, vmetadata FROM document_chunk WHERE {where}'
if limit is not None:
sql += " LIMIT ?"
sql += ' LIMIT ?'
params.append(int(limit))
cur.execute(sql, params)
rows = cur.fetchall()
@@ -490,18 +482,16 @@ class MariaDBVectorClient(VectorDBBase):
metadatas = [[_safe_json(r[2]) for r in rows]]
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
def get(
self, collection_name: str, limit: Optional[int] = None
) -> Optional[GetResult]:
def get(self, collection_name: str, limit: Optional[int] = None) -> Optional[GetResult]:
"""
Retrieve documents in a collection without filtering (optionally limited).
"""
with self._connect() as conn:
with conn.cursor() as cur:
sql = "SELECT id, text, vmetadata FROM document_chunk WHERE collection_name = ?"
sql = 'SELECT id, text, vmetadata FROM document_chunk WHERE collection_name = ?'
params: List[Any] = [collection_name]
if limit is not None:
sql += " LIMIT ?"
sql += ' LIMIT ?'
params.append(int(limit))
cur.execute(sql, params)
rows = cur.fetchall()
@@ -526,12 +516,12 @@ class MariaDBVectorClient(VectorDBBase):
with self._connect() as conn:
with conn.cursor() as cur:
try:
where = ["collection_name = ?"]
where = ['collection_name = ?']
params: List[Any] = [collection_name]
if ids:
ph = ", ".join(["?"] * len(ids))
where.append(f"id IN ({ph})")
ph = ', '.join(['?'] * len(ids))
where.append(f'id IN ({ph})')
params.extend(ids)
if filter:
@@ -540,12 +530,12 @@ class MariaDBVectorClient(VectorDBBase):
where.append(fsql)
params.extend(fparams)
sql = "DELETE FROM document_chunk WHERE " + " AND ".join(where)
sql = 'DELETE FROM document_chunk WHERE ' + ' AND '.join(where)
cur.execute(sql, params)
conn.commit()
except Exception as e:
conn.rollback()
log.exception(f"Error during delete: {e}")
log.exception(f'Error during delete: {e}')
raise
def reset(self) -> None:
@@ -555,11 +545,11 @@ class MariaDBVectorClient(VectorDBBase):
with self._connect() as conn:
with conn.cursor() as cur:
try:
cur.execute("TRUNCATE TABLE document_chunk")
cur.execute('TRUNCATE TABLE document_chunk')
conn.commit()
except Exception as e:
conn.rollback()
log.exception(f"Error during reset: {e}")
log.exception(f'Error during reset: {e}')
raise
def has_collection(self, collection_name: str) -> bool:
@@ -570,7 +560,7 @@ class MariaDBVectorClient(VectorDBBase):
with self._connect() as conn:
with conn.cursor() as cur:
cur.execute(
"SELECT 1 FROM document_chunk WHERE collection_name = ? LIMIT 1",
'SELECT 1 FROM document_chunk WHERE collection_name = ? LIMIT 1',
(collection_name,),
)
return cur.fetchone() is not None
@@ -590,4 +580,4 @@ class MariaDBVectorClient(VectorDBBase):
try:
self.engine.dispose()
except Exception as e:
log.exception(f"Error during dispose the underlying SQLAlchemy engine: {e}")
log.exception(f'Error during dispose the underlying SQLAlchemy engine: {e}')

View File

@@ -35,7 +35,7 @@ log = logging.getLogger(__name__)
class MilvusClient(VectorDBBase):
def __init__(self):
self.collection_prefix = "open_webui"
self.collection_prefix = 'open_webui'
if MILVUS_TOKEN is None:
self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
else:
@@ -50,17 +50,17 @@ class MilvusClient(VectorDBBase):
_documents = []
_metadatas = []
for item in match:
_ids.append(item.get("id"))
_documents.append(item.get("data", {}).get("text"))
_metadatas.append(item.get("metadata"))
_ids.append(item.get('id'))
_documents.append(item.get('data', {}).get('text'))
_metadatas.append(item.get('metadata'))
ids.append(_ids)
documents.append(_documents)
metadatas.append(_metadatas)
return GetResult(
**{
"ids": ids,
"documents": documents,
"metadatas": metadatas,
'ids': ids,
'documents': documents,
'metadatas': metadatas,
}
)
@@ -75,23 +75,23 @@ class MilvusClient(VectorDBBase):
_documents = []
_metadatas = []
for item in match:
_ids.append(item.get("id"))
_ids.append(item.get('id'))
# normalize milvus score from [-1, 1] to [0, 1] range
# https://milvus.io/docs/de/metric.md
_dist = (item.get("distance") + 1.0) / 2.0
_dist = (item.get('distance') + 1.0) / 2.0
_distances.append(_dist)
_documents.append(item.get("entity", {}).get("data", {}).get("text"))
_metadatas.append(item.get("entity", {}).get("metadata"))
_documents.append(item.get('entity', {}).get('data', {}).get('text'))
_metadatas.append(item.get('entity', {}).get('metadata'))
ids.append(_ids)
distances.append(_distances)
documents.append(_documents)
metadatas.append(_metadatas)
return SearchResult(
**{
"ids": ids,
"distances": distances,
"documents": documents,
"metadatas": metadatas,
'ids': ids,
'distances': distances,
'documents': documents,
'metadatas': metadatas,
}
)
@@ -101,21 +101,19 @@ class MilvusClient(VectorDBBase):
enable_dynamic_field=True,
)
schema.add_field(
field_name="id",
field_name='id',
datatype=DataType.VARCHAR,
is_primary=True,
max_length=65535,
)
schema.add_field(
field_name="vector",
field_name='vector',
datatype=DataType.FLOAT_VECTOR,
dim=dimension,
description="vector",
)
schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
schema.add_field(
field_name="metadata", datatype=DataType.JSON, description="metadata"
description='vector',
)
schema.add_field(field_name='data', datatype=DataType.JSON, description='data')
schema.add_field(field_name='metadata', datatype=DataType.JSON, description='metadata')
index_params = self.client.prepare_index_params()
@@ -123,44 +121,44 @@ class MilvusClient(VectorDBBase):
index_type = MILVUS_INDEX_TYPE.upper()
metric_type = MILVUS_METRIC_TYPE.upper()
log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}")
log.info(f'Using Milvus index type: {index_type}, metric type: {metric_type}')
index_creation_params = {}
if index_type == "HNSW":
if index_type == 'HNSW':
index_creation_params = {
"M": MILVUS_HNSW_M,
"efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
'M': MILVUS_HNSW_M,
'efConstruction': MILVUS_HNSW_EFCONSTRUCTION,
}
log.info(f"HNSW params: {index_creation_params}")
elif index_type == "IVF_FLAT":
index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
log.info(f"IVF_FLAT params: {index_creation_params}")
elif index_type == "DISKANN":
log.info(f'HNSW params: {index_creation_params}')
elif index_type == 'IVF_FLAT':
index_creation_params = {'nlist': MILVUS_IVF_FLAT_NLIST}
log.info(f'IVF_FLAT params: {index_creation_params}')
elif index_type == 'DISKANN':
index_creation_params = {
"max_degree": MILVUS_DISKANN_MAX_DEGREE,
"search_list_size": MILVUS_DISKANN_SEARCH_LIST_SIZE,
'max_degree': MILVUS_DISKANN_MAX_DEGREE,
'search_list_size': MILVUS_DISKANN_SEARCH_LIST_SIZE,
}
log.info(f"DISKANN params: {index_creation_params}")
elif index_type in ["FLAT", "AUTOINDEX"]:
log.info(f"Using {index_type} index with no specific build-time params.")
log.info(f'DISKANN params: {index_creation_params}')
elif index_type in ['FLAT', 'AUTOINDEX']:
log.info(f'Using {index_type} index with no specific build-time params.')
else:
log.warning(
f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
f"Supported types: HNSW, IVF_FLAT, DISKANN, FLAT, AUTOINDEX. "
f"Milvus will use its default for the collection if this type is not directly supported for index creation."
f'Supported types: HNSW, IVF_FLAT, DISKANN, FLAT, AUTOINDEX. '
f'Milvus will use its default for the collection if this type is not directly supported for index creation.'
)
# For unsupported types, pass the type directly to Milvus; it might handle it or use a default.
# If Milvus errors out, the user needs to correct the MILVUS_INDEX_TYPE env var.
index_params.add_index(
field_name="vector",
field_name='vector',
index_type=index_type,
metric_type=metric_type,
params=index_creation_params,
)
self.client.create_collection(
collection_name=f"{self.collection_prefix}_{collection_name}",
collection_name=f'{self.collection_prefix}_{collection_name}',
schema=schema,
index_params=index_params,
)
@@ -170,17 +168,13 @@ class MilvusClient(VectorDBBase):
def has_collection(self, collection_name: str) -> bool:
# Check if the collection exists based on the collection name.
collection_name = collection_name.replace("-", "_")
return self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
)
collection_name = collection_name.replace('-', '_')
return self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}')
def delete_collection(self, collection_name: str):
# Delete the collection based on the collection name.
collection_name = collection_name.replace("-", "_")
return self.client.drop_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
)
collection_name = collection_name.replace('-', '_')
return self.client.drop_collection(collection_name=f'{self.collection_prefix}_{collection_name}')
def search(
self,
@@ -190,15 +184,15 @@ class MilvusClient(VectorDBBase):
limit: int = 10,
) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
collection_name = collection_name.replace("-", "_")
collection_name = collection_name.replace('-', '_')
# For some index types like IVF_FLAT, search params like nprobe can be set.
# Example: search_params = {"nprobe": 10} if using IVF_FLAT
# For simplicity, not adding configurable search_params here, but could be extended.
result = self.client.search(
collection_name=f"{self.collection_prefix}_{collection_name}",
collection_name=f'{self.collection_prefix}_{collection_name}',
data=vectors,
limit=limit,
output_fields=["data", "metadata"],
output_fields=['data', 'metadata'],
# search_params=search_params # Potentially add later if needed
)
return self._result_to_search_result(result)
@@ -206,11 +200,9 @@ class MilvusClient(VectorDBBase):
def query(self, collection_name: str, filter: dict, limit: int = -1):
connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB)
collection_name = collection_name.replace("-", "_")
collection_name = collection_name.replace('-', '_')
if not self.has_collection(collection_name):
log.warning(
f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
)
log.warning(f'Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}')
return None
filter_expressions = []
@@ -220,9 +212,9 @@ class MilvusClient(VectorDBBase):
else:
filter_expressions.append(f'metadata["{key}"] == {value}')
filter_string = " && ".join(filter_expressions)
filter_string = ' && '.join(filter_expressions)
collection = Collection(f"{self.collection_prefix}_{collection_name}")
collection = Collection(f'{self.collection_prefix}_{collection_name}')
collection.load()
try:
@@ -233,9 +225,9 @@ class MilvusClient(VectorDBBase):
iterator = collection.query_iterator(
expr=filter_string,
output_fields=[
"id",
"data",
"metadata",
'id',
'data',
'metadata',
],
limit=limit if limit > 0 else -1,
)
@@ -248,7 +240,7 @@ class MilvusClient(VectorDBBase):
break
all_results.extend(batch)
log.debug(f"Total results from query: {len(all_results)}")
log.debug(f'Total results from query: {len(all_results)}')
return self._result_to_get_result([all_results] if all_results else [[]])
except Exception as e:
@@ -259,7 +251,7 @@ class MilvusClient(VectorDBBase):
def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection. This can be very resource-intensive for large collections.
collection_name = collection_name.replace("-", "_")
collection_name = collection_name.replace('-', '_')
log.warning(
f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections."
)
@@ -269,35 +261,25 @@ class MilvusClient(VectorDBBase):
def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created.
collection_name = collection_name.replace("-", "_")
if not self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
):
log.info(
f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now."
)
collection_name = collection_name.replace('-', '_')
if not self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}'):
log.info(f'Collection {self.collection_prefix}_{collection_name} does not exist. Creating now.')
if not items:
log.error(
f"Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension."
f'Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension.'
)
raise ValueError(
"Cannot create Milvus collection without items to determine vector dimension."
)
self._create_collection(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
raise ValueError('Cannot create Milvus collection without items to determine vector dimension.')
self._create_collection(collection_name=collection_name, dimension=len(items[0]['vector']))
log.info(
f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
)
log.info(f'Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.')
return self.client.insert(
collection_name=f"{self.collection_prefix}_{collection_name}",
collection_name=f'{self.collection_prefix}_{collection_name}',
data=[
{
"id": item["id"],
"vector": item["vector"],
"data": {"text": item["text"]},
"metadata": process_metadata(item["metadata"]),
'id': item['id'],
'vector': item['vector'],
'data': {'text': item['text']},
'metadata': process_metadata(item['metadata']),
}
for item in items
],
@@ -305,35 +287,27 @@ class MilvusClient(VectorDBBase):
def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
collection_name = collection_name.replace("-", "_")
if not self.client.has_collection(
collection_name=f"{self.collection_prefix}_{collection_name}"
):
log.info(
f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now."
)
collection_name = collection_name.replace('-', '_')
if not self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}'):
log.info(f'Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now.')
if not items:
log.error(
f"Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension."
f'Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension.'
)
raise ValueError(
"Cannot create Milvus collection for upsert without items to determine vector dimension."
'Cannot create Milvus collection for upsert without items to determine vector dimension.'
)
self._create_collection(
collection_name=collection_name, dimension=len(items[0]["vector"])
)
self._create_collection(collection_name=collection_name, dimension=len(items[0]['vector']))
log.info(
f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
)
log.info(f'Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.')
return self.client.upsert(
collection_name=f"{self.collection_prefix}_{collection_name}",
collection_name=f'{self.collection_prefix}_{collection_name}',
data=[
{
"id": item["id"],
"vector": item["vector"],
"data": {"text": item["text"]},
"metadata": process_metadata(item["metadata"]),
'id': item['id'],
'vector': item['vector'],
'data': {'text': item['text']},
'metadata': process_metadata(item['metadata']),
}
for item in items
],
@@ -346,46 +320,35 @@ class MilvusClient(VectorDBBase):
filter: Optional[dict] = None,
):
# Delete the items from the collection based on the ids or filter.
collection_name = collection_name.replace("-", "_")
collection_name = collection_name.replace('-', '_')
if not self.has_collection(collection_name):
log.warning(
f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
)
log.warning(f'Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}')
return None
if ids:
log.info(
f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}"
)
log.info(f'Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}')
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
collection_name=f'{self.collection_prefix}_{collection_name}',
ids=ids,
)
elif filter:
filter_string = " && ".join(
[
f'metadata["{key}"] == {json.dumps(value)}'
for key, value in filter.items()
]
)
filter_string = ' && '.join([f'metadata["{key}"] == {json.dumps(value)}' for key, value in filter.items()])
log.info(
f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}"
f'Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}'
)
return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}",
collection_name=f'{self.collection_prefix}_{collection_name}',
filter=filter_string,
)
else:
log.warning(
f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken."
f'Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken.'
)
return None
def reset(self):
# Resets the database. This will delete all collections and item entries that match the prefix.
log.warning(
f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'."
)
log.warning(f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'.")
collection_names = self.client.list_collections()
deleted_collections = []
for collection_name_full in collection_names:
@@ -393,7 +356,7 @@ class MilvusClient(VectorDBBase):
try:
self.client.drop_collection(collection_name=collection_name_full)
deleted_collections.append(collection_name_full)
log.info(f"Deleted collection: {collection_name_full}")
log.info(f'Deleted collection: {collection_name_full}')
except Exception as e:
log.error(f"Error deleting collection {collection_name_full}: {e}")
log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")
log.error(f'Error deleting collection {collection_name_full}: {e}')
log.info(f'Milvus reset complete. Deleted collections: {deleted_collections}')

View File

@@ -33,26 +33,26 @@ from pymilvus import (
log = logging.getLogger(__name__)
RESOURCE_ID_FIELD = "resource_id"
RESOURCE_ID_FIELD = 'resource_id'
class MilvusClient(VectorDBBase):
def __init__(self):
# Milvus collection names can only contain numbers, letters, and underscores.
self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_")
self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace('-', '_')
connections.connect(
alias="default",
alias='default',
uri=MILVUS_URI,
token=MILVUS_TOKEN,
db_name=MILVUS_DB,
)
# Main collection types for multi-tenancy
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
self.FILE_COLLECTION = f"{self.collection_prefix}_files"
self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search"
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based"
self.MEMORY_COLLECTION = f'{self.collection_prefix}_memories'
self.KNOWLEDGE_COLLECTION = f'{self.collection_prefix}_knowledge'
self.FILE_COLLECTION = f'{self.collection_prefix}_files'
self.WEB_SEARCH_COLLECTION = f'{self.collection_prefix}_web_search'
self.HASH_BASED_COLLECTION = f'{self.collection_prefix}_hash_based'
self.shared_collections = [
self.MEMORY_COLLECTION,
self.KNOWLEDGE_COLLECTION,
@@ -74,15 +74,13 @@ class MilvusClient(VectorDBBase):
"""
resource_id = collection_name
if collection_name.startswith("user-memory-"):
if collection_name.startswith('user-memory-'):
return self.MEMORY_COLLECTION, resource_id
elif collection_name.startswith("file-"):
elif collection_name.startswith('file-'):
return self.FILE_COLLECTION, resource_id
elif collection_name.startswith("web-search-"):
elif collection_name.startswith('web-search-'):
return self.WEB_SEARCH_COLLECTION, resource_id
elif len(collection_name) == 63 and all(
c in "0123456789abcdef" for c in collection_name
):
elif len(collection_name) == 63 and all(c in '0123456789abcdef' for c in collection_name):
return self.HASH_BASED_COLLECTION, resource_id
else:
return self.KNOWLEDGE_COLLECTION, resource_id
@@ -90,36 +88,36 @@ class MilvusClient(VectorDBBase):
def _create_shared_collection(self, mt_collection_name: str, dimension: int):
fields = [
FieldSchema(
name="id",
name='id',
dtype=DataType.VARCHAR,
is_primary=True,
auto_id=False,
max_length=36,
),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="metadata", dtype=DataType.JSON),
FieldSchema(name='vector', dtype=DataType.FLOAT_VECTOR, dim=dimension),
FieldSchema(name='text', dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name='metadata', dtype=DataType.JSON),
FieldSchema(name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255),
]
schema = CollectionSchema(fields, "Shared collection for multi-tenancy")
schema = CollectionSchema(fields, 'Shared collection for multi-tenancy')
collection = Collection(mt_collection_name, schema)
index_params = {
"metric_type": MILVUS_METRIC_TYPE,
"index_type": MILVUS_INDEX_TYPE,
"params": {},
'metric_type': MILVUS_METRIC_TYPE,
'index_type': MILVUS_INDEX_TYPE,
'params': {},
}
if MILVUS_INDEX_TYPE == "HNSW":
index_params["params"] = {
"M": MILVUS_HNSW_M,
"efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
if MILVUS_INDEX_TYPE == 'HNSW':
index_params['params'] = {
'M': MILVUS_HNSW_M,
'efConstruction': MILVUS_HNSW_EFCONSTRUCTION,
}
elif MILVUS_INDEX_TYPE == "IVF_FLAT":
index_params["params"] = {"nlist": MILVUS_IVF_FLAT_NLIST}
elif MILVUS_INDEX_TYPE == 'IVF_FLAT':
index_params['params'] = {'nlist': MILVUS_IVF_FLAT_NLIST}
collection.create_index("vector", index_params)
collection.create_index('vector', index_params)
collection.create_index(RESOURCE_ID_FIELD)
log.info(f"Created shared collection: {mt_collection_name}")
log.info(f'Created shared collection: {mt_collection_name}')
return collection
def _ensure_collection(self, mt_collection_name: str, dimension: int):
@@ -127,9 +125,7 @@ class MilvusClient(VectorDBBase):
self._create_shared_collection(mt_collection_name, dimension)
def has_collection(self, collection_name: str) -> bool:
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
if not utility.has_collection(mt_collection):
return False
@@ -141,19 +137,17 @@ class MilvusClient(VectorDBBase):
def upsert(self, collection_name: str, items: List[VectorItem]):
if not items:
return
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
dimension = len(items[0]["vector"])
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
dimension = len(items[0]['vector'])
self._ensure_collection(mt_collection, dimension)
collection = Collection(mt_collection)
entities = [
{
"id": item["id"],
"vector": item["vector"],
"text": item["text"],
"metadata": item["metadata"],
'id': item['id'],
'vector': item['vector'],
'text': item['text'],
'metadata': item['metadata'],
RESOURCE_ID_FIELD: resource_id,
}
for item in items
@@ -170,41 +164,37 @@ class MilvusClient(VectorDBBase):
if not vectors:
return None
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
if not utility.has_collection(mt_collection):
return None
collection = Collection(mt_collection)
collection.load()
search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}}
search_params = {'metric_type': MILVUS_METRIC_TYPE, 'params': {}}
results = collection.search(
data=vectors,
anns_field="vector",
anns_field='vector',
param=search_params,
limit=limit,
expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'",
output_fields=["id", "text", "metadata"],
output_fields=['id', 'text', 'metadata'],
)
ids, documents, metadatas, distances = [], [], [], []
for hits in results:
batch_ids, batch_docs, batch_metadatas, batch_dists = [], [], [], []
for hit in hits:
batch_ids.append(hit.entity.get("id"))
batch_docs.append(hit.entity.get("text"))
batch_metadatas.append(hit.entity.get("metadata"))
batch_ids.append(hit.entity.get('id'))
batch_docs.append(hit.entity.get('text'))
batch_metadatas.append(hit.entity.get('metadata'))
batch_dists.append(hit.distance)
ids.append(batch_ids)
documents.append(batch_docs)
metadatas.append(batch_metadatas)
distances.append(batch_dists)
return SearchResult(
ids=ids, documents=documents, metadatas=metadatas, distances=distances
)
return SearchResult(ids=ids, documents=documents, metadatas=metadatas, distances=distances)
def delete(
self,
@@ -212,9 +202,7 @@ class MilvusClient(VectorDBBase):
ids: Optional[List[str]] = None,
filter: Optional[Dict[str, Any]] = None,
):
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
if not utility.has_collection(mt_collection):
return
@@ -224,14 +212,14 @@ class MilvusClient(VectorDBBase):
expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
if ids:
# Milvus expects a string list for 'in' operator
id_list_str = ", ".join([f"'{id_val}'" for id_val in ids])
expr.append(f"id in [{id_list_str}]")
id_list_str = ', '.join([f"'{id_val}'" for id_val in ids])
expr.append(f'id in [{id_list_str}]')
if filter:
for key, value in filter.items():
expr.append(f"metadata['{key}'] == '{value}'")
collection.delete(" and ".join(expr))
collection.delete(' and '.join(expr))
def reset(self):
for collection_name in self.shared_collections:
@@ -239,21 +227,15 @@ class MilvusClient(VectorDBBase):
utility.drop_collection(collection_name)
def delete_collection(self, collection_name: str):
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
if not utility.has_collection(mt_collection):
return
collection = Collection(mt_collection)
collection.delete(f"{RESOURCE_ID_FIELD} == '{resource_id}'")
def query(
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
) -> Optional[GetResult]:
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
if not utility.has_collection(mt_collection):
return None
@@ -269,8 +251,8 @@ class MilvusClient(VectorDBBase):
expr.append(f"metadata['{key}'] == {value}")
iterator = collection.query_iterator(
expr=" and ".join(expr),
output_fields=["id", "text", "metadata"],
expr=' and '.join(expr),
output_fields=['id', 'text', 'metadata'],
limit=limit if limit else -1,
)
@@ -282,9 +264,9 @@ class MilvusClient(VectorDBBase):
break
all_results.extend(batch)
ids = [res["id"] for res in all_results]
documents = [res["text"] for res in all_results]
metadatas = [res["metadata"] for res in all_results]
ids = [res['id'] for res in all_results]
documents = [res['text'] for res in all_results]
metadatas = [res['metadata'] for res in all_results]
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

Some files were not shown because too many files have changed in this diff Show More