This commit is contained in:
Timothy Jaeryang Baek
2026-03-17 17:58:01 -05:00
parent fcf7208352
commit de3317e26b
220 changed files with 17200 additions and 22836 deletions

View File

@@ -10,94 +10,88 @@ from typing_extensions import Annotated
# Typer application that the CLI commands below register themselves on.
app = typer.Typer()

# File in the current working directory that stores the generated
# WEBUI_SECRET_KEY when the environment variable is not set.
KEY_FILE = Path.cwd() / '.webui_secret_key'
def version_callback(value: bool):
    """Print the Open WebUI version and exit when ``--version`` is passed.

    Args:
        value: Truthy when the ``--version`` option was supplied.

    Raises:
        typer.Exit: After echoing the version, to stop further processing.
    """
    if not value:
        return

    # Imported lazily so that open_webui.env is only loaded when the
    # version is actually requested.
    from open_webui.env import VERSION

    typer.echo(f'Open WebUI version: {VERSION}')
    raise typer.Exit()
@app.command()
def main(
    version: Annotated[Optional[bool], typer.Option('--version', callback=version_callback)] = None,
):
    """Command that only exists to expose the ``--version`` option.

    The option's callback (``version_callback``) does all the work; the
    command body itself is intentionally empty.
    """
    pass
@app.command()
def serve(
    host: str = '0.0.0.0',
    port: int = 8080,
):
    """Prepare the process environment and run the app under uvicorn.

    Steps, in order:
      1. Set ``FROM_INIT_PY`` so later imports can see it.
      2. If ``WEBUI_SECRET_KEY`` is not in the environment, load it from
         ``KEY_FILE``, generating the file first when it does not exist.
      3. If ``USE_CUDA_DOCKER`` is ``'true'``, extend ``LD_LIBRARY_PATH``
         with the torch/cudnn library directories and check that CUDA is
         usable; on any failure, reset both variables.
      4. Import ``open_webui.main`` and start uvicorn.

    Args:
        host: Interface uvicorn binds to.
        port: TCP port uvicorn listens on.
    """
    # Function-scope import: the key must come from a CSPRNG, and this keeps
    # the change local to the block that needs it.
    import secrets

    os.environ['FROM_INIT_PY'] = 'true'

    if os.getenv('WEBUI_SECRET_KEY') is None:
        typer.echo('Loading WEBUI_SECRET_KEY from file, not provided as an environment variable.')
        if not KEY_FILE.exists():
            typer.echo(f'Generating a new secret key and saving it to {KEY_FILE}')
            # secrets.token_bytes draws from the OS CSPRNG; random.randbytes
            # is a predictable PRNG and unsuitable for secret material.
            KEY_FILE.write_bytes(base64.b64encode(secrets.token_bytes(12)))
        typer.echo(f'Loading WEBUI_SECRET_KEY from {KEY_FILE}')
        os.environ['WEBUI_SECRET_KEY'] = KEY_FILE.read_text()

    if os.getenv('USE_CUDA_DOCKER', 'false') == 'true':
        typer.echo('CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries.')
        # ''.split(':') yields [''], which would put an empty entry at the
        # front of the joined path. Treat an unset/empty variable as no
        # entries; a user-supplied value is kept exactly as given.
        current_value = os.getenv('LD_LIBRARY_PATH', '')
        LD_LIBRARY_PATH = current_value.split(':') if current_value else []
        os.environ['LD_LIBRARY_PATH'] = ':'.join(
            LD_LIBRARY_PATH
            + [
                '/usr/local/lib/python3.11/site-packages/torch/lib',
                '/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib',
            ]
        )
        try:
            import torch

            # Explicit raise instead of assert: asserts are removed when
            # Python runs with -O, which would skip the check entirely.
            if not torch.cuda.is_available():
                raise RuntimeError('CUDA not available')
            typer.echo('CUDA seems to be working')
        except Exception as e:
            # Best-effort fallback: any failure (missing torch, no GPU, ...)
            # reverts the CUDA-related environment changes.
            typer.echo(
                'Error when testing CUDA but USE_CUDA_DOCKER is true. '
                'Resetting USE_CUDA_DOCKER to false and removing '
                f'LD_LIBRARY_PATH modifications: {e}'
            )
            os.environ['USE_CUDA_DOCKER'] = 'false'
            os.environ['LD_LIBRARY_PATH'] = ':'.join(LD_LIBRARY_PATH)

    import open_webui.main  # we need set environment variables before importing main
    from open_webui.env import UVICORN_WORKERS  # Import the workers setting

    uvicorn.run(
        'open_webui.main:app',
        host=host,
        port=port,
        forwarded_allow_ips='*',
        workers=UVICORN_WORKERS,
    )
@app.command()
def dev(
    host: str = '0.0.0.0',
    port: int = 8080,
    reload: bool = True,
):
    """Run the app under uvicorn with reload support.

    Unlike ``serve``, this command does no environment preparation and
    passes no ``workers`` argument.

    Args:
        host: Interface uvicorn binds to.
        port: TCP port uvicorn listens on.
        reload: Forwarded to uvicorn's ``reload`` option.
    """
    uvicorn.run(
        'open_webui.main:app',
        host=host,
        port=port,
        reload=reload,
        forwarded_allow_ips='*',
    )
# Run the Typer CLI when this module is executed as a script.
if __name__ == '__main__':
    app()

File diff suppressed because it is too large Load Diff

View File

@@ -2,125 +2,107 @@ from enum import Enum
class MESSAGES(str, Enum):
    """Formatters for informational messages.

    Every attribute is a function, so none of them becomes an Enum member;
    they are called through the class, e.g. ``MESSAGES.MODEL_ADDED('x')``.
    """

    # Returns ``msg`` formatted as text, or '' when ``msg`` is falsy.
    DEFAULT = lambda msg='': f'{msg if msg else ""}'
    MODEL_ADDED = lambda model='': f"The model '{model}' has been added successfully."
    MODEL_DELETED = lambda model='': f"The model '{model}' has been deleted successfully."
class WEBHOOK_MESSAGES(str, Enum):
    """Formatters for webhook notification text.

    Every attribute is a function, so none of them becomes an Enum member;
    they are called through the class, e.g.
    ``WEBHOOK_MESSAGES.USER_SIGNUP('alice')``.
    """

    # Returns ``msg`` formatted as text, or '' when ``msg`` is falsy.
    DEFAULT = lambda msg='': f'{msg if msg else ""}'
    # Includes the username only when one is given.
    USER_SIGNUP = lambda username='': (f'New user signed up: {username}' if username else 'New user signed up')
class ERROR_MESSAGES(str, Enum):
    """Error texts, as string members and as formatter functions.

    Plain string attributes are Enum members (and ``str`` instances).
    Function-valued attributes are not members; they are called through the
    class, e.g. ``ERROR_MESSAGES.MODEL_NOT_FOUND('x')``.
    """

    def __str__(self) -> str:
        # Delegates to the parent implementation unchanged.
        return super().__str__()

    # Generic error formatter. A missing detail (None, '' or an object whose
    # str() is empty) yields the generic text instead of "[ERROR: None]" or
    # "[ERROR: ]".
    DEFAULT = lambda err='': (
        'Something went wrong :/' if err is None or str(err) == '' else '[ERROR: ' + str(err) + ']'
    )
    ENV_VAR_NOT_FOUND = 'Required environment variable not found. Terminating now.'
    CREATE_USER_ERROR = 'Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance.'
    DELETE_USER_ERROR = 'Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot.'
    EMAIL_MISMATCH = 'Uh-oh! This email does not match the email your provider is registered with. Please check your email and try again.'
    EMAIL_TAKEN = 'Uh-oh! This email is already registered. Sign in with your existing account or choose another email to start anew.'
    USERNAME_TAKEN = 'Uh-oh! This username is already registered. Please choose another username.'
    PASSWORD_TOO_LONG = (
        'Uh-oh! The password you entered is too long. Please make sure your password is less than 72 bytes long.'
    )
    COMMAND_TAKEN = 'Uh-oh! This command is already registered. Please choose another command string.'
    FILE_EXISTS = 'Uh-oh! This file is already registered. Please choose another file.'

    ID_TAKEN = 'Uh-oh! This id is already registered. Please choose another id string.'
    MODEL_ID_TAKEN = 'Uh-oh! This model id is already registered. Please choose another model id string.'
    NAME_TAG_TAKEN = 'Uh-oh! This name tag is already registered. Please choose another name tag string.'
    MODEL_ID_TOO_LONG = 'The model id is too long. Please make sure your model id is less than 256 characters long.'

    INVALID_TOKEN = 'Your session has expired or the token is invalid. Please sign in again.'
    INVALID_CRED = 'The email or password provided is incorrect. Please check for typos and try logging in again.'
    INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)."
    INCORRECT_PASSWORD = 'The password provided is incorrect. Please check for typos and try again.'
    INVALID_TRUSTED_HEADER = (
        'Your provider has not provided a trusted header. Please contact your administrator for assistance.'
    )

    EXISTING_USERS = "You can't turn off authentication because there are existing users. If you want to disable WEBUI_AUTH, make sure your web interface doesn't have any existing users and is a fresh installation."

    UNAUTHORIZED = '401 Unauthorized'
    ACCESS_PROHIBITED = (
        'You do not have permission to access this resource. Please contact your administrator for assistance.'
    )
    ACTION_PROHIBITED = 'The requested action has been restricted as a security measure.'

    FILE_NOT_SENT = 'FILE_NOT_SENT'
    FILE_NOT_SUPPORTED = "Oops! It seems like the file format you're trying to upload is not supported. Please upload a file with a supported format and try again."

    # USER_NOT_FOUND has the same value as NOT_FOUND, so Enum makes it an
    # alias of NOT_FOUND rather than a distinct member.
    NOT_FOUND = "We could not find what you're looking for :/"
    USER_NOT_FOUND = "We could not find what you're looking for :/"
    API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
    API_KEY_NOT_ALLOWED = 'Use of API key is not enabled in the environment.'

    MALICIOUS = 'Unusual activities detected, please try again in a few minutes.'

    PANDOC_NOT_INSTALLED = 'Pandoc is not installed on the server. Please contact your administrator for assistance.'
    # ``err`` is appended verbatim, with no separator added.
    INCORRECT_FORMAT = lambda err='': f'Invalid format. Please use the correct format{err}'
    RATE_LIMIT_EXCEEDED = 'API rate limit exceeded'

    MODEL_NOT_FOUND = lambda name='': f"Model '{name}' was not found"
    # ``name`` is accepted for signature parity but not used in the text.
    OPENAI_NOT_FOUND = lambda name='': 'OpenAI API was not found'
    OLLAMA_NOT_FOUND = 'WebUI could not connect to Ollama'
    CREATE_API_KEY_ERROR = 'Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance.'
    API_KEY_CREATION_NOT_ALLOWED = 'API key creation is not allowed in the environment.'

    EMPTY_CONTENT = 'The content provided is empty. Please ensure that there is text or data present before proceeding.'

    DB_NOT_SQLITE = 'This feature is only available when running with SQLite databases.'

    INVALID_URL = 'Oops! The URL you provided is invalid. Please double-check and try again.'

    WEB_SEARCH_ERROR = lambda err='': f'{err if err else "Oops! Something went wrong while searching the web."}'

    OLLAMA_API_DISABLED = 'The Ollama API is disabled. Please enable it to use this feature.'

    FILE_TOO_LARGE = (
        lambda size='': f"Oops! The file you're trying to upload is too large. Please upload a file that is less than {size}."
    )

    DUPLICATE_CONTENT = 'Duplicate content detected. Please provide unique content to proceed.'
    FILE_NOT_PROCESSED = (
        'Extracted content is not available for this file. Please ensure that the file is processed before proceeding.'
    )

    # Returns ``err`` itself (not a formatted copy) when it is truthy.
    INVALID_PASSWORD = lambda err='': (err if err else 'The password does not meet the required validation criteria.')
class TASKS(str, Enum):
    """Identifiers for task types, as string-valued Enum members.

    ``DEFAULT`` is a function rather than a member; it is called through
    the class, e.g. ``TASKS.DEFAULT()``.
    """

    def __str__(self) -> str:
        # Delegates to the parent implementation unchanged.
        return super().__str__()

    # Returns ``task`` formatted as text, or 'generation' when it is falsy.
    DEFAULT = lambda task='': f'{task if task else "generation"}'
    TITLE_GENERATION = 'title_generation'
    FOLLOW_UP_GENERATION = 'follow_up_generation'
    TAGS_GENERATION = 'tags_generation'
    EMOJI_GENERATION = 'emoji_generation'
    QUERY_GENERATION = 'query_generation'
    IMAGE_PROMPT_GENERATION = 'image_prompt_generation'
    AUTOCOMPLETE_GENERATION = 'autocomplete_generation'
    FUNCTION_CALLING = 'function_calling'
    MOA_RESPONSE_GENERATION = 'moa_response_generation'

File diff suppressed because it is too large Load Diff

View File

@@ -57,17 +57,15 @@ log = logging.getLogger(__name__)
def get_function_module_by_id(request: Request, pipe_id: str): def get_function_module_by_id(request: Request, pipe_id: str):
function_module, _, _ = get_function_module_from_cache(request, pipe_id) function_module, _, _ = get_function_module_from_cache(request, pipe_id)
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): if hasattr(function_module, 'valves') and hasattr(function_module, 'Valves'):
Valves = function_module.Valves Valves = function_module.Valves
valves = Functions.get_function_valves_by_id(pipe_id) valves = Functions.get_function_valves_by_id(pipe_id)
if valves: if valves:
try: try:
function_module.valves = Valves( function_module.valves = Valves(**{k: v for k, v in valves.items() if v is not None})
**{k: v for k, v in valves.items() if v is not None}
)
except Exception as e: except Exception as e:
log.exception(f"Error loading valves for function {pipe_id}: {e}") log.exception(f'Error loading valves for function {pipe_id}: {e}')
raise e raise e
else: else:
function_module.valves = Valves() function_module.valves = Valves()
@@ -76,7 +74,7 @@ def get_function_module_by_id(request: Request, pipe_id: str):
async def get_function_models(request): async def get_function_models(request):
pipes = Functions.get_functions_by_type("pipe", active_only=True) pipes = Functions.get_functions_by_type('pipe', active_only=True)
pipe_models = [] pipe_models = []
for pipe in pipes: for pipe in pipes:
@@ -84,11 +82,11 @@ async def get_function_models(request):
function_module = get_function_module_by_id(request, pipe.id) function_module = get_function_module_by_id(request, pipe.id)
has_user_valves = False has_user_valves = False
if hasattr(function_module, "UserValves"): if hasattr(function_module, 'UserValves'):
has_user_valves = True has_user_valves = True
# Check if function is a manifold # Check if function is a manifold
if hasattr(function_module, "pipes"): if hasattr(function_module, 'pipes'):
sub_pipes = [] sub_pipes = []
# Handle pipes being a list, sync function, or async function # Handle pipes being a list, sync function, or async function
@@ -104,32 +102,30 @@ async def get_function_models(request):
log.exception(e) log.exception(e)
sub_pipes = [] sub_pipes = []
log.debug( log.debug(f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}")
f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
)
for p in sub_pipes: for p in sub_pipes:
sub_pipe_id = f'{pipe.id}.{p["id"]}' sub_pipe_id = f'{pipe.id}.{p["id"]}'
sub_pipe_name = p["name"] sub_pipe_name = p['name']
if hasattr(function_module, "name"): if hasattr(function_module, 'name'):
sub_pipe_name = f"{function_module.name}{sub_pipe_name}" sub_pipe_name = f'{function_module.name}{sub_pipe_name}'
pipe_flag = {"type": pipe.type} pipe_flag = {'type': pipe.type}
pipe_models.append( pipe_models.append(
{ {
"id": sub_pipe_id, 'id': sub_pipe_id,
"name": sub_pipe_name, 'name': sub_pipe_name,
"object": "model", 'object': 'model',
"created": pipe.created_at, 'created': pipe.created_at,
"owned_by": "openai", 'owned_by': 'openai',
"pipe": pipe_flag, 'pipe': pipe_flag,
"has_user_valves": has_user_valves, 'has_user_valves': has_user_valves,
} }
) )
else: else:
pipe_flag = {"type": "pipe"} pipe_flag = {'type': 'pipe'}
log.debug( log.debug(
f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}" f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
@@ -137,13 +133,13 @@ async def get_function_models(request):
pipe_models.append( pipe_models.append(
{ {
"id": pipe.id, 'id': pipe.id,
"name": pipe.name, 'name': pipe.name,
"object": "model", 'object': 'model',
"created": pipe.created_at, 'created': pipe.created_at,
"owned_by": "openai", 'owned_by': 'openai',
"pipe": pipe_flag, 'pipe': pipe_flag,
"has_user_valves": has_user_valves, 'has_user_valves': has_user_valves,
} }
) )
except Exception as e: except Exception as e:
@@ -153,9 +149,7 @@ async def get_function_models(request):
return pipe_models return pipe_models
async def generate_function_chat_completion( async def generate_function_chat_completion(request, form_data, user, models: dict = {}):
request, form_data, user, models: dict = {}
):
async def execute_pipe(pipe, params): async def execute_pipe(pipe, params):
if inspect.iscoroutinefunction(pipe): if inspect.iscoroutinefunction(pipe):
return await pipe(**params) return await pipe(**params)
@@ -166,32 +160,32 @@ async def generate_function_chat_completion(
if isinstance(res, str): if isinstance(res, str):
return res return res
if isinstance(res, Generator): if isinstance(res, Generator):
return "".join(map(str, res)) return ''.join(map(str, res))
if isinstance(res, AsyncGenerator): if isinstance(res, AsyncGenerator):
return "".join([str(stream) async for stream in res]) return ''.join([str(stream) async for stream in res])
def process_line(form_data: dict, line):
    """Convert one item produced by a pipe into a server-sent-event frame.

    Args:
        form_data: Request body; ``form_data['model']`` is used when the
            item has to be wrapped in a chat-chunk message.
        line: A ``BaseModel``, ``dict``, ``bytes`` or ``str`` item.

    Returns:
        A string terminated by a blank line. Items that already start with
        ``'data:'`` are passed through; anything else is wrapped via
        ``openai_chat_chunk_message_template`` and JSON-encoded.
    """
    if isinstance(line, BaseModel):
        line = f'data: {line.model_dump_json()}'
    elif isinstance(line, dict):
        line = f'data: {json.dumps(line)}'
    elif isinstance(line, (bytes, bytearray)):
        # Explicit type check instead of a blanket try/except around
        # .decode(): a decode failure now surfaces as UnicodeDecodeError
        # here rather than as a TypeError from startswith() below.
        line = line.decode('utf-8')

    if line.startswith('data:'):
        return f'{line}\n\n'

    chunk = openai_chat_chunk_message_template(form_data['model'], line)
    return f'data: {json.dumps(chunk)}\n\n'
def get_pipe_id(form_data: dict) -> str:
    """Return the part of ``form_data['model']`` before the first ``'.'``.

    A model id without a dot is returned unchanged.

    Raises:
        KeyError: If ``form_data`` has no ``'model'`` key.
    """
    # partition() keeps the whole string as the head when there is no '.',
    # so no membership test is needed first.
    return form_data['model'].partition('.')[0]
def get_function_params(function_module, form_data, user, extra_params=None): def get_function_params(function_module, form_data, user, extra_params=None):
@@ -202,27 +196,25 @@ async def generate_function_chat_completion(
# Get the signature of the function # Get the signature of the function
sig = inspect.signature(function_module.pipe) sig = inspect.signature(function_module.pipe)
params = {"body": form_data} | { params = {'body': form_data} | {k: v for k, v in extra_params.items() if k in sig.parameters}
k: v for k, v in extra_params.items() if k in sig.parameters
}
if "__user__" in params and hasattr(function_module, "UserValves"): if '__user__' in params and hasattr(function_module, 'UserValves'):
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id) user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
try: try:
params["__user__"]["valves"] = function_module.UserValves(**user_valves) params['__user__']['valves'] = function_module.UserValves(**user_valves)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
params["__user__"]["valves"] = function_module.UserValves() params['__user__']['valves'] = function_module.UserValves()
return params return params
model_id = form_data.get("model") model_id = form_data.get('model')
model_info = Models.get_model_by_id(model_id) model_info = Models.get_model_by_id(model_id)
metadata = form_data.pop("metadata", {}) metadata = form_data.pop('metadata', {})
files = metadata.get("files", []) files = metadata.get('files', [])
tool_ids = metadata.get("tool_ids", []) tool_ids = metadata.get('tool_ids', [])
# Check if tool_ids is None # Check if tool_ids is None
if tool_ids is None: if tool_ids is None:
tool_ids = [] tool_ids = []
@@ -233,56 +225,56 @@ async def generate_function_chat_completion(
__task_body__ = None __task_body__ = None
if metadata: if metadata:
if all(k in metadata for k in ("session_id", "chat_id", "message_id")): if all(k in metadata for k in ('session_id', 'chat_id', 'message_id')):
__event_emitter__ = get_event_emitter(metadata) __event_emitter__ = get_event_emitter(metadata)
__event_call__ = get_event_call(metadata) __event_call__ = get_event_call(metadata)
__task__ = metadata.get("task", None) __task__ = metadata.get('task', None)
__task_body__ = metadata.get("task_body", None) __task_body__ = metadata.get('task_body', None)
oauth_token = None oauth_token = None
try: try:
if request.cookies.get("oauth_session_id", None): if request.cookies.get('oauth_session_id', None):
oauth_token = await request.app.state.oauth_manager.get_oauth_token( oauth_token = await request.app.state.oauth_manager.get_oauth_token(
user.id, user.id,
request.cookies.get("oauth_session_id", None), request.cookies.get('oauth_session_id', None),
) )
except Exception as e: except Exception as e:
log.error(f"Error getting OAuth token: {e}") log.error(f'Error getting OAuth token: {e}')
extra_params = { extra_params = {
"__event_emitter__": __event_emitter__, '__event_emitter__': __event_emitter__,
"__event_call__": __event_call__, '__event_call__': __event_call__,
"__chat_id__": metadata.get("chat_id", None), '__chat_id__': metadata.get('chat_id', None),
"__session_id__": metadata.get("session_id", None), '__session_id__': metadata.get('session_id', None),
"__message_id__": metadata.get("message_id", None), '__message_id__': metadata.get('message_id', None),
"__task__": __task__, '__task__': __task__,
"__task_body__": __task_body__, '__task_body__': __task_body__,
"__files__": files, '__files__': files,
"__user__": user.model_dump() if isinstance(user, UserModel) else {}, '__user__': user.model_dump() if isinstance(user, UserModel) else {},
"__metadata__": metadata, '__metadata__': metadata,
"__oauth_token__": oauth_token, '__oauth_token__': oauth_token,
"__request__": request, '__request__': request,
} }
extra_params["__tools__"] = await get_tools( extra_params['__tools__'] = await get_tools(
request, request,
tool_ids, tool_ids,
user, user,
{ {
**extra_params, **extra_params,
"__model__": models.get(form_data["model"], None), '__model__': models.get(form_data['model'], None),
"__messages__": form_data["messages"], '__messages__': form_data['messages'],
"__files__": files, '__files__': files,
}, },
) )
if model_info: if model_info:
if model_info.base_model_id: if model_info.base_model_id:
form_data["model"] = model_info.base_model_id form_data['model'] = model_info.base_model_id
params = model_info.params.model_dump() params = model_info.params.model_dump()
if params: if params:
system = params.pop("system", None) system = params.pop('system', None)
form_data = apply_model_params_to_body_openai(params, form_data) form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_system_prompt_to_body(system, form_data, metadata, user) form_data = apply_system_prompt_to_body(system, form_data, metadata, user)
@@ -292,7 +284,7 @@ async def generate_function_chat_completion(
pipe = function_module.pipe pipe = function_module.pipe
params = get_function_params(function_module, form_data, user, extra_params) params = get_function_params(function_module, form_data, user, extra_params)
if form_data.get("stream", False): if form_data.get('stream', False):
async def stream_content(): async def stream_content():
try: try:
@@ -304,17 +296,17 @@ async def generate_function_chat_completion(
yield data yield data
return return
if isinstance(res, dict): if isinstance(res, dict):
yield f"data: {json.dumps(res)}\n\n" yield f'data: {json.dumps(res)}\n\n'
return return
except Exception as e: except Exception as e:
log.error(f"Error: {e}") log.error(f'Error: {e}')
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n" yield f'data: {json.dumps({"error": {"detail": str(e)}})}\n\n'
return return
if isinstance(res, str): if isinstance(res, str):
message = openai_chat_chunk_message_template(form_data["model"], res) message = openai_chat_chunk_message_template(form_data['model'], res)
yield f"data: {json.dumps(message)}\n\n" yield f'data: {json.dumps(message)}\n\n'
if isinstance(res, Iterator): if isinstance(res, Iterator):
for line in res: for line in res:
@@ -325,21 +317,19 @@ async def generate_function_chat_completion(
yield process_line(form_data, line) yield process_line(form_data, line)
if isinstance(res, str) or isinstance(res, Generator): if isinstance(res, str) or isinstance(res, Generator):
finish_message = openai_chat_chunk_message_template( finish_message = openai_chat_chunk_message_template(form_data['model'], '')
form_data["model"], "" finish_message['choices'][0]['finish_reason'] = 'stop'
) yield f'data: {json.dumps(finish_message)}\n\n'
finish_message["choices"][0]["finish_reason"] = "stop" yield 'data: [DONE]'
yield f"data: {json.dumps(finish_message)}\n\n"
yield "data: [DONE]"
return StreamingResponse(stream_content(), media_type="text/event-stream") return StreamingResponse(stream_content(), media_type='text/event-stream')
else: else:
try: try:
res = await execute_pipe(pipe, params) res = await execute_pipe(pipe, params)
except Exception as e: except Exception as e:
log.error(f"Error: {e}") log.error(f'Error: {e}')
return {"error": {"detail": str(e)}} return {'error': {'detail': str(e)}}
if isinstance(res, StreamingResponse) or isinstance(res, dict): if isinstance(res, StreamingResponse) or isinstance(res, dict):
return res return res
@@ -347,4 +337,4 @@ async def generate_function_chat_completion(
return res.model_dump() return res.model_dump()
message = await get_message_content(res) message = await get_message_content(res)
return openai_chat_completion_message_template(form_data["model"], message) return openai_chat_completion_message_template(form_data['model'], message)

View File

@@ -56,17 +56,15 @@ def handle_peewee_migration(DATABASE_URL):
# db = None # db = None
try: try:
# Replace the postgresql:// with postgres:// to handle the peewee migration # Replace the postgresql:// with postgres:// to handle the peewee migration
db = register_connection(DATABASE_URL.replace("postgresql://", "postgres://")) db = register_connection(DATABASE_URL.replace('postgresql://', 'postgres://'))
migrate_dir = OPEN_WEBUI_DIR / "internal" / "migrations" migrate_dir = OPEN_WEBUI_DIR / 'internal' / 'migrations'
router = Router(db, logger=log, migrate_dir=migrate_dir) router = Router(db, logger=log, migrate_dir=migrate_dir)
router.run() router.run()
db.close() db.close()
except Exception as e: except Exception as e:
log.error(f"Failed to initialize the database connection: {e}") log.error(f'Failed to initialize the database connection: {e}')
log.warning( log.warning('Hint: If your database password contains special characters, you may need to URL-encode it.')
"Hint: If your database password contains special characters, you may need to URL-encode it."
)
raise raise
finally: finally:
# Properly closing the database connection # Properly closing the database connection
@@ -74,7 +72,7 @@ def handle_peewee_migration(DATABASE_URL):
db.close() db.close()
# Assert if db connection has been closed # Assert if db connection has been closed
assert db.is_closed(), "Database connection is still open." assert db.is_closed(), 'Database connection is still open.'
if ENABLE_DB_MIGRATIONS: if ENABLE_DB_MIGRATIONS:
@@ -84,15 +82,13 @@ if ENABLE_DB_MIGRATIONS:
SQLALCHEMY_DATABASE_URL = DATABASE_URL SQLALCHEMY_DATABASE_URL = DATABASE_URL
# Handle SQLCipher URLs # Handle SQLCipher URLs
if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"): if SQLALCHEMY_DATABASE_URL.startswith('sqlite+sqlcipher://'):
database_password = os.environ.get("DATABASE_PASSWORD") database_password = os.environ.get('DATABASE_PASSWORD')
if not database_password or database_password.strip() == "": if not database_password or database_password.strip() == '':
raise ValueError( raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs')
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
)
# Extract database path from SQLCipher URL # Extract database path from SQLCipher URL
db_path = SQLALCHEMY_DATABASE_URL.replace("sqlite+sqlcipher://", "") db_path = SQLALCHEMY_DATABASE_URL.replace('sqlite+sqlcipher://', '')
# Create a custom creator function that uses sqlcipher3 # Create a custom creator function that uses sqlcipher3
def create_sqlcipher_connection(): def create_sqlcipher_connection():
@@ -109,7 +105,7 @@ if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
# or QueuePool if DATABASE_POOL_SIZE is explicitly configured. # or QueuePool if DATABASE_POOL_SIZE is explicitly configured.
if isinstance(DATABASE_POOL_SIZE, int) and DATABASE_POOL_SIZE > 0: if isinstance(DATABASE_POOL_SIZE, int) and DATABASE_POOL_SIZE > 0:
engine = create_engine( engine = create_engine(
"sqlite://", 'sqlite://',
creator=create_sqlcipher_connection, creator=create_sqlcipher_connection,
pool_size=DATABASE_POOL_SIZE, pool_size=DATABASE_POOL_SIZE,
max_overflow=DATABASE_POOL_MAX_OVERFLOW, max_overflow=DATABASE_POOL_MAX_OVERFLOW,
@@ -121,28 +117,26 @@ if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
) )
else: else:
engine = create_engine( engine = create_engine(
"sqlite://", 'sqlite://',
creator=create_sqlcipher_connection, creator=create_sqlcipher_connection,
poolclass=NullPool, poolclass=NullPool,
echo=False, echo=False,
) )
log.info("Connected to encrypted SQLite database using SQLCipher") log.info('Connected to encrypted SQLite database using SQLCipher')
elif "sqlite" in SQLALCHEMY_DATABASE_URL: elif 'sqlite' in SQLALCHEMY_DATABASE_URL:
engine = create_engine( engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={'check_same_thread': False})
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
def on_connect(dbapi_connection, connection_record): def on_connect(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor() cursor = dbapi_connection.cursor()
if DATABASE_ENABLE_SQLITE_WAL: if DATABASE_ENABLE_SQLITE_WAL:
cursor.execute("PRAGMA journal_mode=WAL") cursor.execute('PRAGMA journal_mode=WAL')
else: else:
cursor.execute("PRAGMA journal_mode=DELETE") cursor.execute('PRAGMA journal_mode=DELETE')
cursor.close() cursor.close()
event.listen(engine, "connect", on_connect) event.listen(engine, 'connect', on_connect)
else: else:
if isinstance(DATABASE_POOL_SIZE, int): if isinstance(DATABASE_POOL_SIZE, int):
if DATABASE_POOL_SIZE > 0: if DATABASE_POOL_SIZE > 0:
@@ -156,16 +150,12 @@ else:
poolclass=QueuePool, poolclass=QueuePool,
) )
else: else:
engine = create_engine( engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool)
SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
)
else: else:
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker( SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
)
metadata_obj = MetaData(schema=DATABASE_SCHEMA) metadata_obj = MetaData(schema=DATABASE_SCHEMA)
Base = declarative_base(metadata=metadata_obj) Base = declarative_base(metadata=metadata_obj)
ScopedSession = scoped_session(SessionLocal) ScopedSession = scoped_session(SessionLocal)

View File

@@ -56,7 +56,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
active = pw.BooleanField() active = pw.BooleanField()
class Meta: class Meta:
table_name = "auth" table_name = 'auth'
@migrator.create_model @migrator.create_model
class Chat(pw.Model): class Chat(pw.Model):
@@ -67,7 +67,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField() timestamp = pw.BigIntegerField()
class Meta: class Meta:
table_name = "chat" table_name = 'chat'
@migrator.create_model @migrator.create_model
class ChatIdTag(pw.Model): class ChatIdTag(pw.Model):
@@ -78,7 +78,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField() timestamp = pw.BigIntegerField()
class Meta: class Meta:
table_name = "chatidtag" table_name = 'chatidtag'
@migrator.create_model @migrator.create_model
class Document(pw.Model): class Document(pw.Model):
@@ -92,7 +92,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField() timestamp = pw.BigIntegerField()
class Meta: class Meta:
table_name = "document" table_name = 'document'
@migrator.create_model @migrator.create_model
class Modelfile(pw.Model): class Modelfile(pw.Model):
@@ -103,7 +103,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField() timestamp = pw.BigIntegerField()
class Meta: class Meta:
table_name = "modelfile" table_name = 'modelfile'
@migrator.create_model @migrator.create_model
class Prompt(pw.Model): class Prompt(pw.Model):
@@ -115,7 +115,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField() timestamp = pw.BigIntegerField()
class Meta: class Meta:
table_name = "prompt" table_name = 'prompt'
@migrator.create_model @migrator.create_model
class Tag(pw.Model): class Tag(pw.Model):
@@ -125,7 +125,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
data = pw.TextField(null=True) data = pw.TextField(null=True)
class Meta: class Meta:
table_name = "tag" table_name = 'tag'
@migrator.create_model @migrator.create_model
class User(pw.Model): class User(pw.Model):
@@ -137,7 +137,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField() timestamp = pw.BigIntegerField()
class Meta: class Meta:
table_name = "user" table_name = 'user'
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False): def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
@@ -149,7 +149,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
active = pw.BooleanField() active = pw.BooleanField()
class Meta: class Meta:
table_name = "auth" table_name = 'auth'
@migrator.create_model @migrator.create_model
class Chat(pw.Model): class Chat(pw.Model):
@@ -160,7 +160,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField() timestamp = pw.BigIntegerField()
class Meta: class Meta:
table_name = "chat" table_name = 'chat'
@migrator.create_model @migrator.create_model
class ChatIdTag(pw.Model): class ChatIdTag(pw.Model):
@@ -171,7 +171,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField() timestamp = pw.BigIntegerField()
class Meta: class Meta:
table_name = "chatidtag" table_name = 'chatidtag'
@migrator.create_model @migrator.create_model
class Document(pw.Model): class Document(pw.Model):
@@ -185,7 +185,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField() timestamp = pw.BigIntegerField()
class Meta: class Meta:
table_name = "document" table_name = 'document'
@migrator.create_model @migrator.create_model
class Modelfile(pw.Model): class Modelfile(pw.Model):
@@ -196,7 +196,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField() timestamp = pw.BigIntegerField()
class Meta: class Meta:
table_name = "modelfile" table_name = 'modelfile'
@migrator.create_model @migrator.create_model
class Prompt(pw.Model): class Prompt(pw.Model):
@@ -208,7 +208,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField() timestamp = pw.BigIntegerField()
class Meta: class Meta:
table_name = "prompt" table_name = 'prompt'
@migrator.create_model @migrator.create_model
class Tag(pw.Model): class Tag(pw.Model):
@@ -218,7 +218,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
data = pw.TextField(null=True) data = pw.TextField(null=True)
class Meta: class Meta:
table_name = "tag" table_name = 'tag'
@migrator.create_model @migrator.create_model
class User(pw.Model): class User(pw.Model):
@@ -230,24 +230,24 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
timestamp = pw.BigIntegerField() timestamp = pw.BigIntegerField()
class Meta: class Meta:
table_name = "user" table_name = 'user'
def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""
migrator.remove_model("user") migrator.remove_model('user')
migrator.remove_model("tag") migrator.remove_model('tag')
migrator.remove_model("prompt") migrator.remove_model('prompt')
migrator.remove_model("modelfile") migrator.remove_model('modelfile')
migrator.remove_model("document") migrator.remove_model('document')
migrator.remove_model("chatidtag") migrator.remove_model('chatidtag')
migrator.remove_model("chat") migrator.remove_model('chat')
migrator.remove_model("auth") migrator.remove_model('auth')

View File

@@ -36,12 +36,10 @@ with suppress(ImportError):
def migrate(migrator: Migrator, database: pw.Database, *, fake=False): def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here.""" """Write your migrations here."""
migrator.add_fields( migrator.add_fields('chat', share_id=pw.CharField(max_length=255, null=True, unique=True))
"chat", share_id=pw.CharField(max_length=255, null=True, unique=True)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""
migrator.remove_fields("chat", "share_id") migrator.remove_fields('chat', 'share_id')

View File

@@ -36,12 +36,10 @@ with suppress(ImportError):
def migrate(migrator: Migrator, database: pw.Database, *, fake=False): def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here.""" """Write your migrations here."""
migrator.add_fields( migrator.add_fields('user', api_key=pw.CharField(max_length=255, null=True, unique=True))
"user", api_key=pw.CharField(max_length=255, null=True, unique=True)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""
migrator.remove_fields("user", "api_key") migrator.remove_fields('user', 'api_key')

View File

@@ -36,10 +36,10 @@ with suppress(ImportError):
def migrate(migrator: Migrator, database: pw.Database, *, fake=False): def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here.""" """Write your migrations here."""
migrator.add_fields("chat", archived=pw.BooleanField(default=False)) migrator.add_fields('chat', archived=pw.BooleanField(default=False))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""
migrator.remove_fields("chat", "archived") migrator.remove_fields('chat', 'archived')

View File

@@ -45,22 +45,20 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
# Adding fields created_at and updated_at to the 'chat' table # Adding fields created_at and updated_at to the 'chat' table
migrator.add_fields( migrator.add_fields(
"chat", 'chat',
created_at=pw.DateTimeField(null=True), # Allow null for transition created_at=pw.DateTimeField(null=True), # Allow null for transition
updated_at=pw.DateTimeField(null=True), # Allow null for transition updated_at=pw.DateTimeField(null=True), # Allow null for transition
) )
# Populate the new fields from an existing 'timestamp' field # Populate the new fields from an existing 'timestamp' field
migrator.sql( migrator.sql('UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL')
"UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
)
# Now that the data has been copied, remove the original 'timestamp' field # Now that the data has been copied, remove the original 'timestamp' field
migrator.remove_fields("chat", "timestamp") migrator.remove_fields('chat', 'timestamp')
# Update the fields to be not null now that they are populated # Update the fields to be not null now that they are populated
migrator.change_fields( migrator.change_fields(
"chat", 'chat',
created_at=pw.DateTimeField(null=False), created_at=pw.DateTimeField(null=False),
updated_at=pw.DateTimeField(null=False), updated_at=pw.DateTimeField(null=False),
) )
@@ -69,22 +67,20 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False): def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
# Adding fields created_at and updated_at to the 'chat' table # Adding fields created_at and updated_at to the 'chat' table
migrator.add_fields( migrator.add_fields(
"chat", 'chat',
created_at=pw.BigIntegerField(null=True), # Allow null for transition created_at=pw.BigIntegerField(null=True), # Allow null for transition
updated_at=pw.BigIntegerField(null=True), # Allow null for transition updated_at=pw.BigIntegerField(null=True), # Allow null for transition
) )
# Populate the new fields from an existing 'timestamp' field # Populate the new fields from an existing 'timestamp' field
migrator.sql( migrator.sql('UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL')
"UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
)
# Now that the data has been copied, remove the original 'timestamp' field # Now that the data has been copied, remove the original 'timestamp' field
migrator.remove_fields("chat", "timestamp") migrator.remove_fields('chat', 'timestamp')
# Update the fields to be not null now that they are populated # Update the fields to be not null now that they are populated
migrator.change_fields( migrator.change_fields(
"chat", 'chat',
created_at=pw.BigIntegerField(null=False), created_at=pw.BigIntegerField(null=False),
updated_at=pw.BigIntegerField(null=False), updated_at=pw.BigIntegerField(null=False),
) )
@@ -101,29 +97,29 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
def rollback_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): def rollback_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
# Recreate the timestamp field initially allowing null values for safe transition # Recreate the timestamp field initially allowing null values for safe transition
migrator.add_fields("chat", timestamp=pw.DateTimeField(null=True)) migrator.add_fields('chat', timestamp=pw.DateTimeField(null=True))
# Copy the earliest created_at date back into the new timestamp field # Copy the earliest created_at date back into the new timestamp field
# This assumes created_at was originally a copy of timestamp # This assumes created_at was originally a copy of timestamp
migrator.sql("UPDATE chat SET timestamp = created_at") migrator.sql('UPDATE chat SET timestamp = created_at')
# Remove the created_at and updated_at fields # Remove the created_at and updated_at fields
migrator.remove_fields("chat", "created_at", "updated_at") migrator.remove_fields('chat', 'created_at', 'updated_at')
# Finally, alter the timestamp field to not allow nulls if that was the original setting # Finally, alter the timestamp field to not allow nulls if that was the original setting
migrator.change_fields("chat", timestamp=pw.DateTimeField(null=False)) migrator.change_fields('chat', timestamp=pw.DateTimeField(null=False))
def rollback_external(migrator: Migrator, database: pw.Database, *, fake=False): def rollback_external(migrator: Migrator, database: pw.Database, *, fake=False):
# Recreate the timestamp field initially allowing null values for safe transition # Recreate the timestamp field initially allowing null values for safe transition
migrator.add_fields("chat", timestamp=pw.BigIntegerField(null=True)) migrator.add_fields('chat', timestamp=pw.BigIntegerField(null=True))
# Copy the earliest created_at date back into the new timestamp field # Copy the earliest created_at date back into the new timestamp field
# This assumes created_at was originally a copy of timestamp # This assumes created_at was originally a copy of timestamp
migrator.sql("UPDATE chat SET timestamp = created_at") migrator.sql('UPDATE chat SET timestamp = created_at')
# Remove the created_at and updated_at fields # Remove the created_at and updated_at fields
migrator.remove_fields("chat", "created_at", "updated_at") migrator.remove_fields('chat', 'created_at', 'updated_at')
# Finally, alter the timestamp field to not allow nulls if that was the original setting # Finally, alter the timestamp field to not allow nulls if that was the original setting
migrator.change_fields("chat", timestamp=pw.BigIntegerField(null=False)) migrator.change_fields('chat', timestamp=pw.BigIntegerField(null=False))

View File

@@ -38,45 +38,45 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
# Alter the tables with timestamps # Alter the tables with timestamps
migrator.change_fields( migrator.change_fields(
"chatidtag", 'chatidtag',
timestamp=pw.BigIntegerField(), timestamp=pw.BigIntegerField(),
) )
migrator.change_fields( migrator.change_fields(
"document", 'document',
timestamp=pw.BigIntegerField(), timestamp=pw.BigIntegerField(),
) )
migrator.change_fields( migrator.change_fields(
"modelfile", 'modelfile',
timestamp=pw.BigIntegerField(), timestamp=pw.BigIntegerField(),
) )
migrator.change_fields( migrator.change_fields(
"prompt", 'prompt',
timestamp=pw.BigIntegerField(), timestamp=pw.BigIntegerField(),
) )
migrator.change_fields( migrator.change_fields(
"user", 'user',
timestamp=pw.BigIntegerField(), timestamp=pw.BigIntegerField(),
) )
# Alter the tables with varchar to text where necessary # Alter the tables with varchar to text where necessary
migrator.change_fields( migrator.change_fields(
"auth", 'auth',
password=pw.TextField(), password=pw.TextField(),
) )
migrator.change_fields( migrator.change_fields(
"chat", 'chat',
title=pw.TextField(), title=pw.TextField(),
) )
migrator.change_fields( migrator.change_fields(
"document", 'document',
title=pw.TextField(), title=pw.TextField(),
filename=pw.TextField(), filename=pw.TextField(),
) )
migrator.change_fields( migrator.change_fields(
"prompt", 'prompt',
title=pw.TextField(), title=pw.TextField(),
) )
migrator.change_fields( migrator.change_fields(
"user", 'user',
profile_image_url=pw.TextField(), profile_image_url=pw.TextField(),
) )
@@ -87,43 +87,43 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
if isinstance(database, pw.SqliteDatabase): if isinstance(database, pw.SqliteDatabase):
# Alter the tables with timestamps # Alter the tables with timestamps
migrator.change_fields( migrator.change_fields(
"chatidtag", 'chatidtag',
timestamp=pw.DateField(), timestamp=pw.DateField(),
) )
migrator.change_fields( migrator.change_fields(
"document", 'document',
timestamp=pw.DateField(), timestamp=pw.DateField(),
) )
migrator.change_fields( migrator.change_fields(
"modelfile", 'modelfile',
timestamp=pw.DateField(), timestamp=pw.DateField(),
) )
migrator.change_fields( migrator.change_fields(
"prompt", 'prompt',
timestamp=pw.DateField(), timestamp=pw.DateField(),
) )
migrator.change_fields( migrator.change_fields(
"user", 'user',
timestamp=pw.DateField(), timestamp=pw.DateField(),
) )
migrator.change_fields( migrator.change_fields(
"auth", 'auth',
password=pw.CharField(max_length=255), password=pw.CharField(max_length=255),
) )
migrator.change_fields( migrator.change_fields(
"chat", 'chat',
title=pw.CharField(), title=pw.CharField(),
) )
migrator.change_fields( migrator.change_fields(
"document", 'document',
title=pw.CharField(), title=pw.CharField(),
filename=pw.CharField(), filename=pw.CharField(),
) )
migrator.change_fields( migrator.change_fields(
"prompt", 'prompt',
title=pw.CharField(), title=pw.CharField(),
) )
migrator.change_fields( migrator.change_fields(
"user", 'user',
profile_image_url=pw.CharField(), profile_image_url=pw.CharField(),
) )

View File

@@ -38,7 +38,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
# Adding fields created_at and updated_at to the 'user' table # Adding fields created_at and updated_at to the 'user' table
migrator.add_fields( migrator.add_fields(
"user", 'user',
created_at=pw.BigIntegerField(null=True), # Allow null for transition created_at=pw.BigIntegerField(null=True), # Allow null for transition
updated_at=pw.BigIntegerField(null=True), # Allow null for transition updated_at=pw.BigIntegerField(null=True), # Allow null for transition
last_active_at=pw.BigIntegerField(null=True), # Allow null for transition last_active_at=pw.BigIntegerField(null=True), # Allow null for transition
@@ -50,11 +50,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
) )
# Now that the data has been copied, remove the original 'timestamp' field # Now that the data has been copied, remove the original 'timestamp' field
migrator.remove_fields("user", "timestamp") migrator.remove_fields('user', 'timestamp')
# Update the fields to be not null now that they are populated # Update the fields to be not null now that they are populated
migrator.change_fields( migrator.change_fields(
"user", 'user',
created_at=pw.BigIntegerField(null=False), created_at=pw.BigIntegerField(null=False),
updated_at=pw.BigIntegerField(null=False), updated_at=pw.BigIntegerField(null=False),
last_active_at=pw.BigIntegerField(null=False), last_active_at=pw.BigIntegerField(null=False),
@@ -65,14 +65,14 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""
# Recreate the timestamp field initially allowing null values for safe transition # Recreate the timestamp field initially allowing null values for safe transition
migrator.add_fields("user", timestamp=pw.BigIntegerField(null=True)) migrator.add_fields('user', timestamp=pw.BigIntegerField(null=True))
# Copy the earliest created_at date back into the new timestamp field # Copy the earliest created_at date back into the new timestamp field
# This assumes created_at was originally a copy of timestamp # This assumes created_at was originally a copy of timestamp
migrator.sql('UPDATE "user" SET timestamp = created_at') migrator.sql('UPDATE "user" SET timestamp = created_at')
# Remove the created_at and updated_at fields # Remove the created_at and updated_at fields
migrator.remove_fields("user", "created_at", "updated_at", "last_active_at") migrator.remove_fields('user', 'created_at', 'updated_at', 'last_active_at')
# Finally, alter the timestamp field to not allow nulls if that was the original setting # Finally, alter the timestamp field to not allow nulls if that was the original setting
migrator.change_fields("user", timestamp=pw.BigIntegerField(null=False)) migrator.change_fields('user', timestamp=pw.BigIntegerField(null=False))

View File

@@ -43,10 +43,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
created_at = pw.BigIntegerField(null=False) created_at = pw.BigIntegerField(null=False)
class Meta: class Meta:
table_name = "memory" table_name = 'memory'
def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""
migrator.remove_model("memory") migrator.remove_model('memory')

View File

@@ -51,10 +51,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
updated_at = pw.BigIntegerField(null=False) updated_at = pw.BigIntegerField(null=False)
class Meta: class Meta:
table_name = "model" table_name = 'model'
def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""
migrator.remove_model("model") migrator.remove_model('model')

View File

@@ -42,12 +42,12 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
# Fetch data from 'modelfile' table and insert into 'model' table # Fetch data from 'modelfile' table and insert into 'model' table
migrate_modelfile_to_model(migrator, database) migrate_modelfile_to_model(migrator, database)
# Drop the 'modelfile' table # Drop the 'modelfile' table
migrator.remove_model("modelfile") migrator.remove_model('modelfile')
def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database): def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database):
ModelFile = migrator.orm["modelfile"] ModelFile = migrator.orm['modelfile']
Model = migrator.orm["model"] Model = migrator.orm['model']
modelfiles = ModelFile.select() modelfiles = ModelFile.select()
@@ -57,25 +57,25 @@ def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database):
modelfile.modelfile = json.loads(modelfile.modelfile) modelfile.modelfile = json.loads(modelfile.modelfile)
meta = json.dumps( meta = json.dumps(
{ {
"description": modelfile.modelfile.get("desc"), 'description': modelfile.modelfile.get('desc'),
"profile_image_url": modelfile.modelfile.get("imageUrl"), 'profile_image_url': modelfile.modelfile.get('imageUrl'),
"ollama": {"modelfile": modelfile.modelfile.get("content")}, 'ollama': {'modelfile': modelfile.modelfile.get('content')},
"suggestion_prompts": modelfile.modelfile.get("suggestionPrompts"), 'suggestion_prompts': modelfile.modelfile.get('suggestionPrompts'),
"categories": modelfile.modelfile.get("categories"), 'categories': modelfile.modelfile.get('categories'),
"user": {**modelfile.modelfile.get("user", {}), "community": True}, 'user': {**modelfile.modelfile.get('user', {}), 'community': True},
} }
) )
info = parse_ollama_modelfile(modelfile.modelfile.get("content")) info = parse_ollama_modelfile(modelfile.modelfile.get('content'))
# Insert the processed data into the 'model' table # Insert the processed data into the 'model' table
Model.create( Model.create(
id=f"ollama-{modelfile.tag_name}", id=f'ollama-{modelfile.tag_name}',
user_id=modelfile.user_id, user_id=modelfile.user_id,
base_model_id=info.get("base_model_id"), base_model_id=info.get('base_model_id'),
name=modelfile.modelfile.get("title"), name=modelfile.modelfile.get('title'),
meta=meta, meta=meta,
params=json.dumps(info.get("params", {})), params=json.dumps(info.get('params', {})),
created_at=modelfile.timestamp, created_at=modelfile.timestamp,
updated_at=modelfile.timestamp, updated_at=modelfile.timestamp,
) )
@@ -86,7 +86,7 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
recreate_modelfile_table(migrator, database) recreate_modelfile_table(migrator, database)
move_data_back_to_modelfile(migrator, database) move_data_back_to_modelfile(migrator, database)
migrator.remove_model("model") migrator.remove_model('model')
def recreate_modelfile_table(migrator: Migrator, database: pw.Database): def recreate_modelfile_table(migrator: Migrator, database: pw.Database):
@@ -102,8 +102,8 @@ def recreate_modelfile_table(migrator: Migrator, database: pw.Database):
def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database): def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database):
Model = migrator.orm["model"] Model = migrator.orm['model']
Modelfile = migrator.orm["modelfile"] Modelfile = migrator.orm['modelfile']
models = Model.select() models = Model.select()
@@ -112,13 +112,13 @@ def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database):
meta = json.loads(model.meta) meta = json.loads(model.meta)
modelfile_data = { modelfile_data = {
"title": model.name, 'title': model.name,
"desc": meta.get("description"), 'desc': meta.get('description'),
"imageUrl": meta.get("profile_image_url"), 'imageUrl': meta.get('profile_image_url'),
"content": meta.get("ollama", {}).get("modelfile"), 'content': meta.get('ollama', {}).get('modelfile'),
"suggestionPrompts": meta.get("suggestion_prompts"), 'suggestionPrompts': meta.get('suggestion_prompts'),
"categories": meta.get("categories"), 'categories': meta.get('categories'),
"user": {k: v for k, v in meta.get("user", {}).items() if k != "community"}, 'user': {k: v for k, v in meta.get('user', {}).items() if k != 'community'},
} }
# Insert the processed data back into the 'modelfile' table # Insert the processed data back into the 'modelfile' table

View File

@@ -37,11 +37,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here.""" """Write your migrations here."""
# Adding fields settings to the 'user' table # Adding fields settings to the 'user' table
migrator.add_fields("user", settings=pw.TextField(null=True)) migrator.add_fields('user', settings=pw.TextField(null=True))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""
# Remove the settings field # Remove the settings field
migrator.remove_fields("user", "settings") migrator.remove_fields('user', 'settings')

View File

@@ -51,10 +51,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
updated_at = pw.BigIntegerField(null=False) updated_at = pw.BigIntegerField(null=False)
class Meta: class Meta:
table_name = "tool" table_name = 'tool'
def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""
migrator.remove_model("tool") migrator.remove_model('tool')

View File

@@ -37,11 +37,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here.""" """Write your migrations here."""
# Adding fields info to the 'user' table # Adding fields info to the 'user' table
migrator.add_fields("user", info=pw.TextField(null=True)) migrator.add_fields('user', info=pw.TextField(null=True))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""
# Remove the settings field # Remove the settings field
migrator.remove_fields("user", "info") migrator.remove_fields('user', 'info')

View File

@@ -45,10 +45,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
created_at = pw.BigIntegerField(null=False) created_at = pw.BigIntegerField(null=False)
class Meta: class Meta:
table_name = "file" table_name = 'file'
def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""
migrator.remove_model("file") migrator.remove_model('file')

View File

@@ -51,10 +51,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
updated_at = pw.BigIntegerField(null=False) updated_at = pw.BigIntegerField(null=False)
class Meta: class Meta:
table_name = "function" table_name = 'function'
def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""
migrator.remove_model("function") migrator.remove_model('function')

View File

@@ -36,14 +36,14 @@ with suppress(ImportError):
def migrate(migrator: Migrator, database: pw.Database, *, fake=False): def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here.""" """Write your migrations here."""
migrator.add_fields("tool", valves=pw.TextField(null=True)) migrator.add_fields('tool', valves=pw.TextField(null=True))
migrator.add_fields("function", valves=pw.TextField(null=True)) migrator.add_fields('function', valves=pw.TextField(null=True))
migrator.add_fields("function", is_active=pw.BooleanField(default=False)) migrator.add_fields('function', is_active=pw.BooleanField(default=False))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""
migrator.remove_fields("tool", "valves") migrator.remove_fields('tool', 'valves')
migrator.remove_fields("function", "valves") migrator.remove_fields('function', 'valves')
migrator.remove_fields("function", "is_active") migrator.remove_fields('function', 'is_active')

View File

@@ -33,7 +33,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here.""" """Write your migrations here."""
migrator.add_fields( migrator.add_fields(
"user", 'user',
oauth_sub=pw.TextField(null=True, unique=True), oauth_sub=pw.TextField(null=True, unique=True),
) )
@@ -41,4 +41,4 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""
migrator.remove_fields("user", "oauth_sub") migrator.remove_fields('user', 'oauth_sub')

View File

@@ -37,7 +37,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here.""" """Write your migrations here."""
migrator.add_fields( migrator.add_fields(
"function", 'function',
is_global=pw.BooleanField(default=False), is_global=pw.BooleanField(default=False),
) )
@@ -45,4 +45,4 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
def rollback(migrator: Migrator, database: pw.Database, *, fake=False): def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here.""" """Write your rollback migrations here."""
migrator.remove_fields("function", "is_global") migrator.remove_fields('function', 'is_global')

View File

@@ -10,13 +10,13 @@ from playhouse.shortcuts import ReconnectMixin
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None} db_state_default = {'closed': None, 'conn': None, 'ctx': None, 'transactions': None}
db_state = ContextVar("db_state", default=db_state_default.copy()) db_state = ContextVar('db_state', default=db_state_default.copy())
class PeeweeConnectionState(object): class PeeweeConnectionState(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__setattr__("_state", db_state) super().__setattr__('_state', db_state)
super().__init__(**kwargs) super().__init__(**kwargs)
def __setattr__(self, name, value): def __setattr__(self, name, value):
@@ -30,10 +30,10 @@ class PeeweeConnectionState(object):
class CustomReconnectMixin(ReconnectMixin): class CustomReconnectMixin(ReconnectMixin):
reconnect_errors = ( reconnect_errors = (
# psycopg2 # psycopg2
(OperationalError, "termin"), (OperationalError, 'termin'),
(InterfaceError, "closed"), (InterfaceError, 'closed'),
# peewee # peewee
(PeeWeeInterfaceError, "closed"), (PeeWeeInterfaceError, 'closed'),
) )
@@ -43,23 +43,21 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
def register_connection(db_url): def register_connection(db_url):
# Check if using SQLCipher protocol # Check if using SQLCipher protocol
if db_url.startswith("sqlite+sqlcipher://"): if db_url.startswith('sqlite+sqlcipher://'):
database_password = os.environ.get("DATABASE_PASSWORD") database_password = os.environ.get('DATABASE_PASSWORD')
if not database_password or database_password.strip() == "": if not database_password or database_password.strip() == '':
raise ValueError( raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs')
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
)
from playhouse.sqlcipher_ext import SqlCipherDatabase from playhouse.sqlcipher_ext import SqlCipherDatabase
# Parse the database path from SQLCipher URL # Parse the database path from SQLCipher URL
# Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite # Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite
db_path = db_url.replace("sqlite+sqlcipher://", "") db_path = db_url.replace('sqlite+sqlcipher://', '')
# Use Peewee's native SqlCipherDatabase with encryption # Use Peewee's native SqlCipherDatabase with encryption
db = SqlCipherDatabase(db_path, passphrase=database_password) db = SqlCipherDatabase(db_path, passphrase=database_password)
db.autoconnect = True db.autoconnect = True
db.reuse_if_open = True db.reuse_if_open = True
log.info("Connected to encrypted SQLite database using SQLCipher") log.info('Connected to encrypted SQLite database using SQLCipher')
else: else:
# Standard database connection (existing logic) # Standard database connection (existing logic)
@@ -68,7 +66,7 @@ def register_connection(db_url):
# Enable autoconnect for SQLite databases, managed by Peewee # Enable autoconnect for SQLite databases, managed by Peewee
db.autoconnect = True db.autoconnect = True
db.reuse_if_open = True db.reuse_if_open = True
log.info("Connected to PostgreSQL database") log.info('Connected to PostgreSQL database')
# Get the connection details # Get the connection details
connection = parse(db_url, unquote_user=True, unquote_password=True) connection = parse(db_url, unquote_user=True, unquote_password=True)
@@ -80,7 +78,7 @@ def register_connection(db_url):
# Enable autoconnect for SQLite databases, managed by Peewee # Enable autoconnect for SQLite databases, managed by Peewee
db.autoconnect = True db.autoconnect = True
db.reuse_if_open = True db.reuse_if_open = True
log.info("Connected to SQLite database") log.info('Connected to SQLite database')
else: else:
raise ValueError("Unsupported database connection") raise ValueError('Unsupported database connection')
return db return db

File diff suppressed because it is too large Load Diff

View File

@@ -16,7 +16,7 @@ if config.config_file_name is not None:
fileConfig(config.config_file_name, disable_existing_loggers=False) fileConfig(config.config_file_name, disable_existing_loggers=False)
# Re-apply JSON formatter after fileConfig replaces handlers. # Re-apply JSON formatter after fileConfig replaces handlers.
if LOG_FORMAT == "json": if LOG_FORMAT == 'json':
from open_webui.env import JSONFormatter from open_webui.env import JSONFormatter
for handler in logging.root.handlers: for handler in logging.root.handlers:
@@ -36,7 +36,7 @@ target_metadata = Auth.metadata
DB_URL = DATABASE_URL DB_URL = DATABASE_URL
if DB_URL: if DB_URL:
config.set_main_option("sqlalchemy.url", DB_URL.replace("%", "%%")) config.set_main_option('sqlalchemy.url', DB_URL.replace('%', '%%'))
def run_migrations_offline() -> None: def run_migrations_offline() -> None:
@@ -51,12 +51,12 @@ def run_migrations_offline() -> None:
script output. script output.
""" """
url = config.get_main_option("sqlalchemy.url") url = config.get_main_option('sqlalchemy.url')
context.configure( context.configure(
url=url, url=url,
target_metadata=target_metadata, target_metadata=target_metadata,
literal_binds=True, literal_binds=True,
dialect_opts={"paramstyle": "named"}, dialect_opts={'paramstyle': 'named'},
) )
with context.begin_transaction(): with context.begin_transaction():
@@ -71,15 +71,13 @@ def run_migrations_online() -> None:
""" """
# Handle SQLCipher URLs # Handle SQLCipher URLs
if DB_URL and DB_URL.startswith("sqlite+sqlcipher://"): if DB_URL and DB_URL.startswith('sqlite+sqlcipher://'):
if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == "": if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == '':
raise ValueError( raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs')
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
)
# Extract database path from SQLCipher URL # Extract database path from SQLCipher URL
db_path = DB_URL.replace("sqlite+sqlcipher://", "") db_path = DB_URL.replace('sqlite+sqlcipher://', '')
if db_path.startswith("/"): if db_path.startswith('/'):
db_path = db_path[1:] # Remove leading slash for relative paths db_path = db_path[1:] # Remove leading slash for relative paths
# Create a custom creator function that uses sqlcipher3 # Create a custom creator function that uses sqlcipher3
@@ -91,7 +89,7 @@ def run_migrations_online() -> None:
return conn return conn
connectable = create_engine( connectable = create_engine(
"sqlite://", # Dummy URL since we're using creator 'sqlite://', # Dummy URL since we're using creator
creator=create_sqlcipher_connection, creator=create_sqlcipher_connection,
echo=False, echo=False,
) )
@@ -99,7 +97,7 @@ def run_migrations_online() -> None:
# Standard database connection (existing logic) # Standard database connection (existing logic)
connectable = engine_from_config( connectable = engine_from_config(
config.get_section(config.config_ini_section, {}), config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.", prefix='sqlalchemy.',
poolclass=pool.NullPool, poolclass=pool.NullPool,
) )

View File

@@ -12,4 +12,4 @@ def get_existing_tables():
def get_revision_id(): def get_revision_id():
import uuid import uuid
return str(uuid.uuid4()).replace("-", "")[:12] return str(uuid.uuid4()).replace('-', '')[:12]

View File

@@ -9,38 +9,38 @@ Create Date: 2025-08-13 03:00:00.000000
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
revision = "018012973d35" revision = '018012973d35'
down_revision = "d31026856c01" down_revision = 'd31026856c01'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
# Chat table indexes # Chat table indexes
op.create_index("folder_id_idx", "chat", ["folder_id"]) op.create_index('folder_id_idx', 'chat', ['folder_id'])
op.create_index("user_id_pinned_idx", "chat", ["user_id", "pinned"]) op.create_index('user_id_pinned_idx', 'chat', ['user_id', 'pinned'])
op.create_index("user_id_archived_idx", "chat", ["user_id", "archived"]) op.create_index('user_id_archived_idx', 'chat', ['user_id', 'archived'])
op.create_index("updated_at_user_id_idx", "chat", ["updated_at", "user_id"]) op.create_index('updated_at_user_id_idx', 'chat', ['updated_at', 'user_id'])
op.create_index("folder_id_user_id_idx", "chat", ["folder_id", "user_id"]) op.create_index('folder_id_user_id_idx', 'chat', ['folder_id', 'user_id'])
# Tag table index # Tag table index
op.create_index("user_id_idx", "tag", ["user_id"]) op.create_index('user_id_idx', 'tag', ['user_id'])
# Function table index # Function table index
op.create_index("is_global_idx", "function", ["is_global"]) op.create_index('is_global_idx', 'function', ['is_global'])
def downgrade(): def downgrade():
# Chat table indexes # Chat table indexes
op.drop_index("folder_id_idx", table_name="chat") op.drop_index('folder_id_idx', table_name='chat')
op.drop_index("user_id_pinned_idx", table_name="chat") op.drop_index('user_id_pinned_idx', table_name='chat')
op.drop_index("user_id_archived_idx", table_name="chat") op.drop_index('user_id_archived_idx', table_name='chat')
op.drop_index("updated_at_user_id_idx", table_name="chat") op.drop_index('updated_at_user_id_idx', table_name='chat')
op.drop_index("folder_id_user_id_idx", table_name="chat") op.drop_index('folder_id_user_id_idx', table_name='chat')
# Tag table index # Tag table index
op.drop_index("user_id_idx", table_name="tag") op.drop_index('user_id_idx', table_name='tag')
# Function table index # Function table index
op.drop_index("is_global_idx", table_name="function") op.drop_index('is_global_idx', table_name='function')

View File

@@ -13,8 +13,8 @@ from sqlalchemy.engine.reflection import Inspector
import json import json
revision = "1af9b942657b" revision = '1af9b942657b'
down_revision = "242a2047eae0" down_revision = '242a2047eae0'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@@ -25,43 +25,40 @@ def upgrade():
inspector = Inspector.from_engine(conn) inspector = Inspector.from_engine(conn)
# Clean up potential leftover temp table from previous failures # Clean up potential leftover temp table from previous failures
conn.execute(sa.text("DROP TABLE IF EXISTS _alembic_tmp_tag")) conn.execute(sa.text('DROP TABLE IF EXISTS _alembic_tmp_tag'))
# Check if the 'tag' table exists # Check if the 'tag' table exists
tables = inspector.get_table_names() tables = inspector.get_table_names()
# Step 1: Modify Tag table using batch mode for SQLite support # Step 1: Modify Tag table using batch mode for SQLite support
if "tag" in tables: if 'tag' in tables:
# Get the current columns in the 'tag' table # Get the current columns in the 'tag' table
columns = [col["name"] for col in inspector.get_columns("tag")] columns = [col['name'] for col in inspector.get_columns('tag')]
# Get any existing unique constraints on the 'tag' table # Get any existing unique constraints on the 'tag' table
current_constraints = inspector.get_unique_constraints("tag") current_constraints = inspector.get_unique_constraints('tag')
with op.batch_alter_table("tag", schema=None) as batch_op: with op.batch_alter_table('tag', schema=None) as batch_op:
# Check if the unique constraint already exists # Check if the unique constraint already exists
if not any( if not any(constraint['name'] == 'uq_id_user_id' for constraint in current_constraints):
constraint["name"] == "uq_id_user_id"
for constraint in current_constraints
):
# Create unique constraint if it doesn't exist # Create unique constraint if it doesn't exist
batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"]) batch_op.create_unique_constraint('uq_id_user_id', ['id', 'user_id'])
# Check if the 'data' column exists before trying to drop it # Check if the 'data' column exists before trying to drop it
if "data" in columns: if 'data' in columns:
batch_op.drop_column("data") batch_op.drop_column('data')
# Check if the 'meta' column needs to be created # Check if the 'meta' column needs to be created
if "meta" not in columns: if 'meta' not in columns:
# Add the 'meta' column if it doesn't already exist # Add the 'meta' column if it doesn't already exist
batch_op.add_column(sa.Column("meta", sa.JSON(), nullable=True)) batch_op.add_column(sa.Column('meta', sa.JSON(), nullable=True))
tag = table( tag = table(
"tag", 'tag',
column("id", sa.String()), column('id', sa.String()),
column("name", sa.String()), column('name', sa.String()),
column("user_id", sa.String()), column('user_id', sa.String()),
column("meta", sa.JSON()), column('meta', sa.JSON()),
) )
# Step 2: Migrate tags # Step 2: Migrate tags
@@ -70,12 +67,12 @@ def upgrade():
tag_updates = {} tag_updates = {}
for row in result: for row in result:
new_id = row.name.replace(" ", "_").lower() new_id = row.name.replace(' ', '_').lower()
tag_updates[row.id] = new_id tag_updates[row.id] = new_id
for tag_id, new_tag_id in tag_updates.items(): for tag_id, new_tag_id in tag_updates.items():
print(f"Updating tag {tag_id} to {new_tag_id}") print(f'Updating tag {tag_id} to {new_tag_id}')
if new_tag_id == "pinned": if new_tag_id == 'pinned':
# delete tag # delete tag
delete_stmt = sa.delete(tag).where(tag.c.id == tag_id) delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
conn.execute(delete_stmt) conn.execute(delete_stmt)
@@ -86,9 +83,7 @@ def upgrade():
if existing_tag_result: if existing_tag_result:
# Handle duplicate case: the new_tag_id already exists # Handle duplicate case: the new_tag_id already exists
print( print(f'Tag {new_tag_id} already exists. Removing current tag with ID {tag_id} to avoid duplicates.')
f"Tag {new_tag_id} already exists. Removing current tag with ID {tag_id} to avoid duplicates."
)
# Option 1: Delete the current tag if an update to new_tag_id would cause duplication # Option 1: Delete the current tag if an update to new_tag_id would cause duplication
delete_stmt = sa.delete(tag).where(tag.c.id == tag_id) delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
conn.execute(delete_stmt) conn.execute(delete_stmt)
@@ -98,19 +93,15 @@ def upgrade():
conn.execute(update_stmt) conn.execute(update_stmt)
# Add columns `pinned` and `meta` to 'chat' # Add columns `pinned` and `meta` to 'chat'
op.add_column("chat", sa.Column("pinned", sa.Boolean(), nullable=True)) op.add_column('chat', sa.Column('pinned', sa.Boolean(), nullable=True))
op.add_column( op.add_column('chat', sa.Column('meta', sa.JSON(), nullable=False, server_default='{}'))
"chat", sa.Column("meta", sa.JSON(), nullable=False, server_default="{}")
)
chatidtag = table( chatidtag = table('chatidtag', column('chat_id', sa.String()), column('tag_name', sa.String()))
"chatidtag", column("chat_id", sa.String()), column("tag_name", sa.String())
)
chat = table( chat = table(
"chat", 'chat',
column("id", sa.String()), column('id', sa.String()),
column("pinned", sa.Boolean()), column('pinned', sa.Boolean()),
column("meta", sa.JSON()), column('meta', sa.JSON()),
) )
# Fetch existing tags # Fetch existing tags
@@ -120,29 +111,27 @@ def upgrade():
chat_updates = {} chat_updates = {}
for row in result: for row in result:
chat_id = row.chat_id chat_id = row.chat_id
tag_name = row.tag_name.replace(" ", "_").lower() tag_name = row.tag_name.replace(' ', '_').lower()
if tag_name == "pinned": if tag_name == 'pinned':
# Specifically handle 'pinned' tag # Specifically handle 'pinned' tag
if chat_id not in chat_updates: if chat_id not in chat_updates:
chat_updates[chat_id] = {"pinned": True, "meta": {}} chat_updates[chat_id] = {'pinned': True, 'meta': {}}
else: else:
chat_updates[chat_id]["pinned"] = True chat_updates[chat_id]['pinned'] = True
else: else:
if chat_id not in chat_updates: if chat_id not in chat_updates:
chat_updates[chat_id] = {"pinned": False, "meta": {"tags": [tag_name]}} chat_updates[chat_id] = {'pinned': False, 'meta': {'tags': [tag_name]}}
else: else:
tags = chat_updates[chat_id]["meta"].get("tags", []) tags = chat_updates[chat_id]['meta'].get('tags', [])
tags.append(tag_name) tags.append(tag_name)
chat_updates[chat_id]["meta"]["tags"] = list(set(tags)) chat_updates[chat_id]['meta']['tags'] = list(set(tags))
# Update chats based on accumulated changes # Update chats based on accumulated changes
for chat_id, updates in chat_updates.items(): for chat_id, updates in chat_updates.items():
update_stmt = sa.update(chat).where(chat.c.id == chat_id) update_stmt = sa.update(chat).where(chat.c.id == chat_id)
update_stmt = update_stmt.values( update_stmt = update_stmt.values(meta=updates.get('meta', {}), pinned=updates.get('pinned', False))
meta=updates.get("meta", {}), pinned=updates.get("pinned", False)
)
conn.execute(update_stmt) conn.execute(update_stmt)
pass pass

View File

@@ -12,8 +12,8 @@ from sqlalchemy.sql import table, select, update
import json import json
revision = "242a2047eae0" revision = '242a2047eae0'
down_revision = "6a39f3d8e55c" down_revision = '6a39f3d8e55c'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@@ -22,39 +22,37 @@ def upgrade():
conn = op.get_bind() conn = op.get_bind()
inspector = sa.inspect(conn) inspector = sa.inspect(conn)
columns = inspector.get_columns("chat") columns = inspector.get_columns('chat')
column_dict = {col["name"]: col for col in columns} column_dict = {col['name']: col for col in columns}
chat_column = column_dict.get("chat") chat_column = column_dict.get('chat')
old_chat_exists = "old_chat" in column_dict old_chat_exists = 'old_chat' in column_dict
if chat_column: if chat_column:
if isinstance(chat_column["type"], sa.Text): if isinstance(chat_column['type'], sa.Text):
print("Converting 'chat' column to JSON") print("Converting 'chat' column to JSON")
if old_chat_exists: if old_chat_exists:
print("Dropping old 'old_chat' column") print("Dropping old 'old_chat' column")
op.drop_column("chat", "old_chat") op.drop_column('chat', 'old_chat')
# Step 1: Rename current 'chat' column to 'old_chat' # Step 1: Rename current 'chat' column to 'old_chat'
print("Renaming 'chat' column to 'old_chat'") print("Renaming 'chat' column to 'old_chat'")
op.alter_column( op.alter_column('chat', 'chat', new_column_name='old_chat', existing_type=sa.Text())
"chat", "chat", new_column_name="old_chat", existing_type=sa.Text()
)
# Step 2: Add new 'chat' column of type JSON # Step 2: Add new 'chat' column of type JSON
print("Adding new 'chat' column of type JSON") print("Adding new 'chat' column of type JSON")
op.add_column("chat", sa.Column("chat", sa.JSON(), nullable=True)) op.add_column('chat', sa.Column('chat', sa.JSON(), nullable=True))
else: else:
# If the column is already JSON, no need to do anything # If the column is already JSON, no need to do anything
pass pass
# Step 3: Migrate data from 'old_chat' to 'chat' # Step 3: Migrate data from 'old_chat' to 'chat'
chat_table = table( chat_table = table(
"chat", 'chat',
sa.Column("id", sa.String(), primary_key=True), sa.Column('id', sa.String(), primary_key=True),
sa.Column("old_chat", sa.Text()), sa.Column('old_chat', sa.Text()),
sa.Column("chat", sa.JSON()), sa.Column('chat', sa.JSON()),
) )
# - Selecting all data from the table # - Selecting all data from the table
@@ -67,41 +65,33 @@ def upgrade():
except json.JSONDecodeError: except json.JSONDecodeError:
json_data = None # Handle cases where the text cannot be converted to JSON json_data = None # Handle cases where the text cannot be converted to JSON
connection.execute( connection.execute(sa.update(chat_table).where(chat_table.c.id == row.id).values(chat=json_data))
sa.update(chat_table)
.where(chat_table.c.id == row.id)
.values(chat=json_data)
)
# Step 4: Drop 'old_chat' column # Step 4: Drop 'old_chat' column
print("Dropping 'old_chat' column") print("Dropping 'old_chat' column")
op.drop_column("chat", "old_chat") op.drop_column('chat', 'old_chat')
def downgrade(): def downgrade():
# Step 1: Add 'old_chat' column back as Text # Step 1: Add 'old_chat' column back as Text
op.add_column("chat", sa.Column("old_chat", sa.Text(), nullable=True)) op.add_column('chat', sa.Column('old_chat', sa.Text(), nullable=True))
# Step 2: Convert 'chat' JSON data back to text and store in 'old_chat' # Step 2: Convert 'chat' JSON data back to text and store in 'old_chat'
chat_table = table( chat_table = table(
"chat", 'chat',
sa.Column("id", sa.String(), primary_key=True), sa.Column('id', sa.String(), primary_key=True),
sa.Column("chat", sa.JSON()), sa.Column('chat', sa.JSON()),
sa.Column("old_chat", sa.Text()), sa.Column('old_chat', sa.Text()),
) )
connection = op.get_bind() connection = op.get_bind()
results = connection.execute(select(chat_table.c.id, chat_table.c.chat)) results = connection.execute(select(chat_table.c.id, chat_table.c.chat))
for row in results: for row in results:
text_data = json.dumps(row.chat) if row.chat is not None else None text_data = json.dumps(row.chat) if row.chat is not None else None
connection.execute( connection.execute(sa.update(chat_table).where(chat_table.c.id == row.id).values(old_chat=text_data))
sa.update(chat_table)
.where(chat_table.c.id == row.id)
.values(old_chat=text_data)
)
# Step 3: Remove the new 'chat' JSON column # Step 3: Remove the new 'chat' JSON column
op.drop_column("chat", "chat") op.drop_column('chat', 'chat')
# Step 4: Rename 'old_chat' back to 'chat' # Step 4: Rename 'old_chat' back to 'chat'
op.alter_column("chat", "old_chat", new_column_name="chat", existing_type=sa.Text()) op.alter_column('chat', 'old_chat', new_column_name='chat', existing_type=sa.Text())

View File

@@ -13,19 +13,19 @@ import sqlalchemy as sa
import open_webui.internal.db import open_webui.internal.db
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = "2f1211949ecc" revision: str = '2f1211949ecc'
down_revision: Union[str, None] = "37f288994c47" down_revision: Union[str, None] = '37f288994c47'
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# New columns to be added to channel_member table # New columns to be added to channel_member table
op.add_column("channel_member", sa.Column("status", sa.Text(), nullable=True)) op.add_column('channel_member', sa.Column('status', sa.Text(), nullable=True))
op.add_column( op.add_column(
"channel_member", 'channel_member',
sa.Column( sa.Column(
"is_active", 'is_active',
sa.Boolean(), sa.Boolean(),
nullable=False, nullable=False,
default=True, default=True,
@@ -34,9 +34,9 @@ def upgrade() -> None:
) )
op.add_column( op.add_column(
"channel_member", 'channel_member',
sa.Column( sa.Column(
"is_channel_muted", 'is_channel_muted',
sa.Boolean(), sa.Boolean(),
nullable=False, nullable=False,
default=False, default=False,
@@ -44,9 +44,9 @@ def upgrade() -> None:
), ),
) )
op.add_column( op.add_column(
"channel_member", 'channel_member',
sa.Column( sa.Column(
"is_channel_pinned", 'is_channel_pinned',
sa.Boolean(), sa.Boolean(),
nullable=False, nullable=False,
default=False, default=False,
@@ -54,49 +54,41 @@ def upgrade() -> None:
), ),
) )
op.add_column("channel_member", sa.Column("data", sa.JSON(), nullable=True)) op.add_column('channel_member', sa.Column('data', sa.JSON(), nullable=True))
op.add_column("channel_member", sa.Column("meta", sa.JSON(), nullable=True)) op.add_column('channel_member', sa.Column('meta', sa.JSON(), nullable=True))
op.add_column( op.add_column('channel_member', sa.Column('joined_at', sa.BigInteger(), nullable=False))
"channel_member", sa.Column("joined_at", sa.BigInteger(), nullable=False) op.add_column('channel_member', sa.Column('left_at', sa.BigInteger(), nullable=True))
)
op.add_column(
"channel_member", sa.Column("left_at", sa.BigInteger(), nullable=True)
)
op.add_column( op.add_column('channel_member', sa.Column('last_read_at', sa.BigInteger(), nullable=True))
"channel_member", sa.Column("last_read_at", sa.BigInteger(), nullable=True)
)
op.add_column( op.add_column('channel_member', sa.Column('updated_at', sa.BigInteger(), nullable=True))
"channel_member", sa.Column("updated_at", sa.BigInteger(), nullable=True)
)
# New columns to be added to message table # New columns to be added to message table
op.add_column( op.add_column(
"message", 'message',
sa.Column( sa.Column(
"is_pinned", 'is_pinned',
sa.Boolean(), sa.Boolean(),
nullable=False, nullable=False,
default=False, default=False,
server_default=sa.sql.expression.false(), server_default=sa.sql.expression.false(),
), ),
) )
op.add_column("message", sa.Column("pinned_at", sa.BigInteger(), nullable=True)) op.add_column('message', sa.Column('pinned_at', sa.BigInteger(), nullable=True))
op.add_column("message", sa.Column("pinned_by", sa.Text(), nullable=True)) op.add_column('message', sa.Column('pinned_by', sa.Text(), nullable=True))
def downgrade() -> None: def downgrade() -> None:
op.drop_column("channel_member", "updated_at") op.drop_column('channel_member', 'updated_at')
op.drop_column("channel_member", "last_read_at") op.drop_column('channel_member', 'last_read_at')
op.drop_column("channel_member", "meta") op.drop_column('channel_member', 'meta')
op.drop_column("channel_member", "data") op.drop_column('channel_member', 'data')
op.drop_column("channel_member", "is_channel_pinned") op.drop_column('channel_member', 'is_channel_pinned')
op.drop_column("channel_member", "is_channel_muted") op.drop_column('channel_member', 'is_channel_muted')
op.drop_column("message", "pinned_by") op.drop_column('message', 'pinned_by')
op.drop_column("message", "pinned_at") op.drop_column('message', 'pinned_at')
op.drop_column("message", "is_pinned") op.drop_column('message', 'is_pinned')

View File

@@ -12,8 +12,8 @@ import uuid
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
revision: str = "374d2f66af06" revision: str = '374d2f66af06'
down_revision: Union[str, None] = "c440947495f3" down_revision: Union[str, None] = 'c440947495f3'
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@@ -26,13 +26,13 @@ def upgrade() -> None:
# We need to assume the OLD structure. # We need to assume the OLD structure.
old_prompt_table = sa.table( old_prompt_table = sa.table(
"prompt", 'prompt',
sa.column("command", sa.Text()), sa.column('command', sa.Text()),
sa.column("user_id", sa.Text()), sa.column('user_id', sa.Text()),
sa.column("title", sa.Text()), sa.column('title', sa.Text()),
sa.column("content", sa.Text()), sa.column('content', sa.Text()),
sa.column("timestamp", sa.BigInteger()), sa.column('timestamp', sa.BigInteger()),
sa.column("access_control", sa.JSON()), sa.column('access_control', sa.JSON()),
) )
# Check if table exists/read data # Check if table exists/read data
@@ -53,61 +53,61 @@ def upgrade() -> None:
# Step 2: Create new prompt table with 'id' as PRIMARY KEY # Step 2: Create new prompt table with 'id' as PRIMARY KEY
op.create_table( op.create_table(
"prompt_new", 'prompt_new',
sa.Column("id", sa.Text(), primary_key=True), sa.Column('id', sa.Text(), primary_key=True),
sa.Column("command", sa.String(), unique=True, index=True), sa.Column('command', sa.String(), unique=True, index=True),
sa.Column("user_id", sa.String(), nullable=False), sa.Column('user_id', sa.String(), nullable=False),
sa.Column("name", sa.Text(), nullable=False), sa.Column('name', sa.Text(), nullable=False),
sa.Column("content", sa.Text(), nullable=False), sa.Column('content', sa.Text(), nullable=False),
sa.Column("data", sa.JSON(), nullable=True), sa.Column('data', sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True), sa.Column('meta', sa.JSON(), nullable=True),
sa.Column("access_control", sa.JSON(), nullable=True), sa.Column('access_control', sa.JSON(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="1"), sa.Column('is_active', sa.Boolean(), nullable=False, server_default='1'),
sa.Column("version_id", sa.Text(), nullable=True), sa.Column('version_id', sa.Text(), nullable=True),
sa.Column("tags", sa.JSON(), nullable=True), sa.Column('tags', sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False), sa.Column('created_at', sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False), sa.Column('updated_at', sa.BigInteger(), nullable=False),
) )
# Step 3: Create prompt_history table # Step 3: Create prompt_history table
op.create_table( op.create_table(
"prompt_history", 'prompt_history',
sa.Column("id", sa.Text(), primary_key=True), sa.Column('id', sa.Text(), primary_key=True),
sa.Column("prompt_id", sa.Text(), nullable=False, index=True), sa.Column('prompt_id', sa.Text(), nullable=False, index=True),
sa.Column("parent_id", sa.Text(), nullable=True), sa.Column('parent_id', sa.Text(), nullable=True),
sa.Column("snapshot", sa.JSON(), nullable=False), sa.Column('snapshot', sa.JSON(), nullable=False),
sa.Column("user_id", sa.Text(), nullable=False), sa.Column('user_id', sa.Text(), nullable=False),
sa.Column("commit_message", sa.Text(), nullable=True), sa.Column('commit_message', sa.Text(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False), sa.Column('created_at', sa.BigInteger(), nullable=False),
) )
# Step 4: Migrate data # Step 4: Migrate data
prompt_new_table = sa.table( prompt_new_table = sa.table(
"prompt_new", 'prompt_new',
sa.column("id", sa.Text()), sa.column('id', sa.Text()),
sa.column("command", sa.String()), sa.column('command', sa.String()),
sa.column("user_id", sa.String()), sa.column('user_id', sa.String()),
sa.column("name", sa.Text()), sa.column('name', sa.Text()),
sa.column("content", sa.Text()), sa.column('content', sa.Text()),
sa.column("data", sa.JSON()), sa.column('data', sa.JSON()),
sa.column("meta", sa.JSON()), sa.column('meta', sa.JSON()),
sa.column("access_control", sa.JSON()), sa.column('access_control', sa.JSON()),
sa.column("is_active", sa.Boolean()), sa.column('is_active', sa.Boolean()),
sa.column("version_id", sa.Text()), sa.column('version_id', sa.Text()),
sa.column("tags", sa.JSON()), sa.column('tags', sa.JSON()),
sa.column("created_at", sa.BigInteger()), sa.column('created_at', sa.BigInteger()),
sa.column("updated_at", sa.BigInteger()), sa.column('updated_at', sa.BigInteger()),
) )
prompt_history_table = sa.table( prompt_history_table = sa.table(
"prompt_history", 'prompt_history',
sa.column("id", sa.Text()), sa.column('id', sa.Text()),
sa.column("prompt_id", sa.Text()), sa.column('prompt_id', sa.Text()),
sa.column("parent_id", sa.Text()), sa.column('parent_id', sa.Text()),
sa.column("snapshot", sa.JSON()), sa.column('snapshot', sa.JSON()),
sa.column("user_id", sa.Text()), sa.column('user_id', sa.Text()),
sa.column("commit_message", sa.Text()), sa.column('commit_message', sa.Text()),
sa.column("created_at", sa.BigInteger()), sa.column('created_at', sa.BigInteger()),
) )
for row in existing_prompts: for row in existing_prompts:
@@ -120,7 +120,7 @@ def upgrade() -> None:
new_uuid = str(uuid.uuid4()) new_uuid = str(uuid.uuid4())
history_uuid = str(uuid.uuid4()) history_uuid = str(uuid.uuid4())
clean_command = command[1:] if command and command.startswith("/") else command clean_command = command[1:] if command and command.startswith('/') else command
# Insert into prompt_new # Insert into prompt_new
conn.execute( conn.execute(
@@ -148,12 +148,12 @@ def upgrade() -> None:
prompt_id=new_uuid, prompt_id=new_uuid,
parent_id=None, parent_id=None,
snapshot={ snapshot={
"name": title, 'name': title,
"content": content, 'content': content,
"command": clean_command, 'command': clean_command,
"data": {}, 'data': {},
"meta": {}, 'meta': {},
"access_control": access_control, 'access_control': access_control,
}, },
user_id=user_id, user_id=user_id,
commit_message=None, commit_message=None,
@@ -162,8 +162,8 @@ def upgrade() -> None:
) )
# Step 5: Replace old table with new one # Step 5: Replace old table with new one
op.drop_table("prompt") op.drop_table('prompt')
op.rename_table("prompt_new", "prompt") op.rename_table('prompt_new', 'prompt')
def downgrade() -> None: def downgrade() -> None:
@@ -171,13 +171,13 @@ def downgrade() -> None:
# Step 1: Read new data # Step 1: Read new data
prompt_table = sa.table( prompt_table = sa.table(
"prompt", 'prompt',
sa.column("command", sa.String()), sa.column('command', sa.String()),
sa.column("name", sa.Text()), sa.column('name', sa.Text()),
sa.column("created_at", sa.BigInteger()), sa.column('created_at', sa.BigInteger()),
sa.column("user_id", sa.Text()), sa.column('user_id', sa.Text()),
sa.column("content", sa.Text()), sa.column('content', sa.Text()),
sa.column("access_control", sa.JSON()), sa.column('access_control', sa.JSON()),
) )
try: try:
@@ -195,31 +195,31 @@ def downgrade() -> None:
current_data = [] current_data = []
# Step 2: Drop history and table # Step 2: Drop history and table
op.drop_table("prompt_history") op.drop_table('prompt_history')
op.drop_table("prompt") op.drop_table('prompt')
# Step 3: Recreate old table (command as PK?) # Step 3: Recreate old table (command as PK?)
# Assuming old schema: # Assuming old schema:
op.create_table( op.create_table(
"prompt", 'prompt',
sa.Column("command", sa.String(), primary_key=True), sa.Column('command', sa.String(), primary_key=True),
sa.Column("user_id", sa.String()), sa.Column('user_id', sa.String()),
sa.Column("title", sa.Text()), sa.Column('title', sa.Text()),
sa.Column("content", sa.Text()), sa.Column('content', sa.Text()),
sa.Column("timestamp", sa.BigInteger()), sa.Column('timestamp', sa.BigInteger()),
sa.Column("access_control", sa.JSON()), sa.Column('access_control', sa.JSON()),
sa.Column("id", sa.Integer(), nullable=True), sa.Column('id', sa.Integer(), nullable=True),
) )
# Step 4: Restore data # Step 4: Restore data
old_prompt_table = sa.table( old_prompt_table = sa.table(
"prompt", 'prompt',
sa.column("command", sa.String()), sa.column('command', sa.String()),
sa.column("user_id", sa.String()), sa.column('user_id', sa.String()),
sa.column("title", sa.Text()), sa.column('title', sa.Text()),
sa.column("content", sa.Text()), sa.column('content', sa.Text()),
sa.column("timestamp", sa.BigInteger()), sa.column('timestamp', sa.BigInteger()),
sa.column("access_control", sa.JSON()), sa.column('access_control', sa.JSON()),
) )
for row in current_data: for row in current_data:
@@ -231,9 +231,7 @@ def downgrade() -> None:
access_control = row[5] access_control = row[5]
# Restore leading / # Restore leading /
old_command = ( old_command = '/' + command if command and not command.startswith('/') else command
"/" + command if command and not command.startswith("/") else command
)
conn.execute( conn.execute(
sa.insert(old_prompt_table).values( sa.insert(old_prompt_table).values(

View File

@@ -9,8 +9,8 @@ Create Date: 2024-12-30 03:00:00.000000
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
revision = "3781e22d8b01" revision = '3781e22d8b01'
down_revision = "7826ab40b532" down_revision = '7826ab40b532'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@@ -18,9 +18,9 @@ depends_on = None
def upgrade(): def upgrade():
# Add 'type' column to the 'channel' table # Add 'type' column to the 'channel' table
op.add_column( op.add_column(
"channel", 'channel',
sa.Column( sa.Column(
"type", 'type',
sa.Text(), sa.Text(),
nullable=True, nullable=True,
), ),
@@ -28,43 +28,31 @@ def upgrade():
# Add 'parent_id' column to the 'message' table for threads # Add 'parent_id' column to the 'message' table for threads
op.add_column( op.add_column(
"message", 'message',
sa.Column("parent_id", sa.Text(), nullable=True), sa.Column('parent_id', sa.Text(), nullable=True),
) )
op.create_table( op.create_table(
"message_reaction", 'message_reaction',
sa.Column( sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True), # Unique reaction ID
"id", sa.Text(), nullable=False, primary_key=True, unique=True sa.Column('user_id', sa.Text(), nullable=False), # User who reacted
), # Unique reaction ID sa.Column('message_id', sa.Text(), nullable=False), # Message that was reacted to
sa.Column("user_id", sa.Text(), nullable=False), # User who reacted sa.Column('name', sa.Text(), nullable=False), # Reaction name (e.g. "thumbs_up")
sa.Column( sa.Column('created_at', sa.BigInteger(), nullable=True), # Timestamp of when the reaction was added
"message_id", sa.Text(), nullable=False
), # Message that was reacted to
sa.Column(
"name", sa.Text(), nullable=False
), # Reaction name (e.g. "thumbs_up")
sa.Column(
"created_at", sa.BigInteger(), nullable=True
), # Timestamp of when the reaction was added
) )
op.create_table( op.create_table(
"channel_member", 'channel_member',
sa.Column( sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True), # Record ID for the membership row
"id", sa.Text(), nullable=False, primary_key=True, unique=True sa.Column('channel_id', sa.Text(), nullable=False), # Associated channel
), # Record ID for the membership row sa.Column('user_id', sa.Text(), nullable=False), # Associated user
sa.Column("channel_id", sa.Text(), nullable=False), # Associated channel sa.Column('created_at', sa.BigInteger(), nullable=True), # Timestamp of when the user joined the channel
sa.Column("user_id", sa.Text(), nullable=False), # Associated user
sa.Column(
"created_at", sa.BigInteger(), nullable=True
), # Timestamp of when the user joined the channel
) )
def downgrade(): def downgrade():
# Revert 'type' column addition to the 'channel' table # Revert 'type' column addition to the 'channel' table
op.drop_column("channel", "type") op.drop_column('channel', 'type')
op.drop_column("message", "parent_id") op.drop_column('message', 'parent_id')
op.drop_table("message_reaction") op.drop_table('message_reaction')
op.drop_table("channel_member") op.drop_table('channel_member')

View File

@@ -15,8 +15,8 @@ from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = "37f288994c47" revision: str = '37f288994c47'
down_revision: Union[str, None] = "a5c220713937" down_revision: Union[str, None] = 'a5c220713937'
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@@ -24,50 +24,48 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# 1. Create new table # 1. Create new table
op.create_table( op.create_table(
"group_member", 'group_member',
sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False), sa.Column('id', sa.Text(), primary_key=True, unique=True, nullable=False),
sa.Column( sa.Column(
"group_id", 'group_id',
sa.Text(), sa.Text(),
sa.ForeignKey("group.id", ondelete="CASCADE"), sa.ForeignKey('group.id', ondelete='CASCADE'),
nullable=False, nullable=False,
), ),
sa.Column( sa.Column(
"user_id", 'user_id',
sa.Text(), sa.Text(),
sa.ForeignKey("user.id", ondelete="CASCADE"), sa.ForeignKey('user.id', ondelete='CASCADE'),
nullable=False, nullable=False,
), ),
sa.Column("created_at", sa.BigInteger(), nullable=True), sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True), sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.UniqueConstraint("group_id", "user_id", name="uq_group_member_group_user"), sa.UniqueConstraint('group_id', 'user_id', name='uq_group_member_group_user'),
) )
connection = op.get_bind() connection = op.get_bind()
# 2. Read existing group with user_ids JSON column # 2. Read existing group with user_ids JSON column
group_table = sa.Table( group_table = sa.Table(
"group", 'group',
sa.MetaData(), sa.MetaData(),
sa.Column("id", sa.Text()), sa.Column('id', sa.Text()),
sa.Column("user_ids", sa.JSON()), # JSON stored as text in SQLite + PG sa.Column('user_ids', sa.JSON()), # JSON stored as text in SQLite + PG
) )
results = connection.execute( results = connection.execute(sa.select(group_table.c.id, group_table.c.user_ids)).fetchall()
sa.select(group_table.c.id, group_table.c.user_ids)
).fetchall()
print(results) print(results)
# 3. Insert members into group_member table # 3. Insert members into group_member table
gm_table = sa.Table( gm_table = sa.Table(
"group_member", 'group_member',
sa.MetaData(), sa.MetaData(),
sa.Column("id", sa.Text()), sa.Column('id', sa.Text()),
sa.Column("group_id", sa.Text()), sa.Column('group_id', sa.Text()),
sa.Column("user_id", sa.Text()), sa.Column('user_id', sa.Text()),
sa.Column("created_at", sa.BigInteger()), sa.Column('created_at', sa.BigInteger()),
sa.Column("updated_at", sa.BigInteger()), sa.Column('updated_at', sa.BigInteger()),
) )
now = int(time.time()) now = int(time.time())
@@ -86,11 +84,11 @@ def upgrade() -> None:
rows = [ rows = [
{ {
"id": str(uuid.uuid4()), 'id': str(uuid.uuid4()),
"group_id": group_id, 'group_id': group_id,
"user_id": uid, 'user_id': uid,
"created_at": now, 'created_at': now,
"updated_at": now, 'updated_at': now,
} }
for uid in user_ids for uid in user_ids
] ]
@@ -99,47 +97,41 @@ def upgrade() -> None:
connection.execute(gm_table.insert(), rows) connection.execute(gm_table.insert(), rows)
# 4. Optionally drop the old column # 4. Optionally drop the old column
with op.batch_alter_table("group") as batch: with op.batch_alter_table('group') as batch:
batch.drop_column("user_ids") batch.drop_column('user_ids')
def downgrade(): def downgrade():
# Reverse: restore user_ids column # Reverse: restore user_ids column
with op.batch_alter_table("group") as batch: with op.batch_alter_table('group') as batch:
batch.add_column(sa.Column("user_ids", sa.JSON())) batch.add_column(sa.Column('user_ids', sa.JSON()))
connection = op.get_bind() connection = op.get_bind()
gm_table = sa.Table( gm_table = sa.Table(
"group_member", 'group_member',
sa.MetaData(), sa.MetaData(),
sa.Column("group_id", sa.Text()), sa.Column('group_id', sa.Text()),
sa.Column("user_id", sa.Text()), sa.Column('user_id', sa.Text()),
sa.Column("created_at", sa.BigInteger()), sa.Column('created_at', sa.BigInteger()),
sa.Column("updated_at", sa.BigInteger()), sa.Column('updated_at', sa.BigInteger()),
) )
group_table = sa.Table( group_table = sa.Table(
"group", 'group',
sa.MetaData(), sa.MetaData(),
sa.Column("id", sa.Text()), sa.Column('id', sa.Text()),
sa.Column("user_ids", sa.JSON()), sa.Column('user_ids', sa.JSON()),
) )
# Build JSON arrays again # Build JSON arrays again
results = connection.execute(sa.select(group_table.c.id)).fetchall() results = connection.execute(sa.select(group_table.c.id)).fetchall()
for (group_id,) in results: for (group_id,) in results:
members = connection.execute( members = connection.execute(sa.select(gm_table.c.user_id).where(gm_table.c.group_id == group_id)).fetchall()
sa.select(gm_table.c.user_id).where(gm_table.c.group_id == group_id)
).fetchall()
member_ids = [m[0] for m in members] member_ids = [m[0] for m in members]
connection.execute( connection.execute(group_table.update().where(group_table.c.id == group_id).values(user_ids=member_ids))
group_table.update()
.where(group_table.c.id == group_id)
.values(user_ids=member_ids)
)
# Drop the new table # Drop the new table
op.drop_table("group_member") op.drop_table('group_member')

View File

@@ -12,8 +12,8 @@ from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = "38d63c18f30f" revision: str = '38d63c18f30f'
down_revision: Union[str, None] = "3af16a1c9fb6" down_revision: Union[str, None] = '3af16a1c9fb6'
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@@ -21,59 +21,55 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# Ensure 'id' column in 'user' table is unique and primary key (ForeignKey constraint) # Ensure 'id' column in 'user' table is unique and primary key (ForeignKey constraint)
inspector = sa.inspect(op.get_bind()) inspector = sa.inspect(op.get_bind())
columns = inspector.get_columns("user") columns = inspector.get_columns('user')
pk_columns = inspector.get_pk_constraint("user")["constrained_columns"] pk_columns = inspector.get_pk_constraint('user')['constrained_columns']
id_column = next((col for col in columns if col["name"] == "id"), None) id_column = next((col for col in columns if col['name'] == 'id'), None)
if id_column and not id_column.get("unique", False): if id_column and not id_column.get('unique', False):
unique_constraints = inspector.get_unique_constraints("user") unique_constraints = inspector.get_unique_constraints('user')
unique_columns = {tuple(u["column_names"]) for u in unique_constraints} unique_columns = {tuple(u['column_names']) for u in unique_constraints}
with op.batch_alter_table("user") as batch_op: with op.batch_alter_table('user') as batch_op:
# If primary key is wrong, drop it # If primary key is wrong, drop it
if pk_columns and pk_columns != ["id"]: if pk_columns and pk_columns != ['id']:
batch_op.drop_constraint( batch_op.drop_constraint(inspector.get_pk_constraint('user')['name'], type_='primary')
inspector.get_pk_constraint("user")["name"], type_="primary"
)
# Add unique constraint if missing # Add unique constraint if missing
if ("id",) not in unique_columns: if ('id',) not in unique_columns:
batch_op.create_unique_constraint("uq_user_id", ["id"]) batch_op.create_unique_constraint('uq_user_id', ['id'])
# Re-create correct primary key # Re-create correct primary key
batch_op.create_primary_key("pk_user_id", ["id"]) batch_op.create_primary_key('pk_user_id', ['id'])
# Create oauth_session table # Create oauth_session table
op.create_table( op.create_table(
"oauth_session", 'oauth_session',
sa.Column("id", sa.Text(), primary_key=True, nullable=False, unique=True), sa.Column('id', sa.Text(), primary_key=True, nullable=False, unique=True),
sa.Column( sa.Column(
"user_id", 'user_id',
sa.Text(), sa.Text(),
sa.ForeignKey("user.id", ondelete="CASCADE"), sa.ForeignKey('user.id', ondelete='CASCADE'),
nullable=False, nullable=False,
), ),
sa.Column("provider", sa.Text(), nullable=False), sa.Column('provider', sa.Text(), nullable=False),
sa.Column("token", sa.Text(), nullable=False), sa.Column('token', sa.Text(), nullable=False),
sa.Column("expires_at", sa.BigInteger(), nullable=False), sa.Column('expires_at', sa.BigInteger(), nullable=False),
sa.Column("created_at", sa.BigInteger(), nullable=False), sa.Column('created_at', sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False), sa.Column('updated_at', sa.BigInteger(), nullable=False),
) )
# Create indexes for better performance # Create indexes for better performance
op.create_index("idx_oauth_session_user_id", "oauth_session", ["user_id"]) op.create_index('idx_oauth_session_user_id', 'oauth_session', ['user_id'])
op.create_index("idx_oauth_session_expires_at", "oauth_session", ["expires_at"]) op.create_index('idx_oauth_session_expires_at', 'oauth_session', ['expires_at'])
op.create_index( op.create_index('idx_oauth_session_user_provider', 'oauth_session', ['user_id', 'provider'])
"idx_oauth_session_user_provider", "oauth_session", ["user_id", "provider"]
)
def downgrade() -> None: def downgrade() -> None:
# Drop indexes first # Drop indexes first
op.drop_index("idx_oauth_session_user_provider", table_name="oauth_session") op.drop_index('idx_oauth_session_user_provider', table_name='oauth_session')
op.drop_index("idx_oauth_session_expires_at", table_name="oauth_session") op.drop_index('idx_oauth_session_expires_at', table_name='oauth_session')
op.drop_index("idx_oauth_session_user_id", table_name="oauth_session") op.drop_index('idx_oauth_session_user_id', table_name='oauth_session')
# Drop the table # Drop the table
op.drop_table("oauth_session") op.drop_table('oauth_session')

View File

@@ -13,8 +13,8 @@ from sqlalchemy.engine.reflection import Inspector
import json import json
revision = "3ab32c4b8f59" revision = '3ab32c4b8f59'
down_revision = "1af9b942657b" down_revision = '1af9b942657b'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@@ -24,58 +24,55 @@ def upgrade():
inspector = Inspector.from_engine(conn) inspector = Inspector.from_engine(conn)
# Inspecting the 'tag' table constraints and structure # Inspecting the 'tag' table constraints and structure
existing_pk = inspector.get_pk_constraint("tag") existing_pk = inspector.get_pk_constraint('tag')
unique_constraints = inspector.get_unique_constraints("tag") unique_constraints = inspector.get_unique_constraints('tag')
existing_indexes = inspector.get_indexes("tag") existing_indexes = inspector.get_indexes('tag')
print(f"Primary Key: {existing_pk}") print(f'Primary Key: {existing_pk}')
print(f"Unique Constraints: {unique_constraints}") print(f'Unique Constraints: {unique_constraints}')
print(f"Indexes: {existing_indexes}") print(f'Indexes: {existing_indexes}')
with op.batch_alter_table("tag", schema=None) as batch_op: with op.batch_alter_table('tag', schema=None) as batch_op:
# Drop existing primary key constraint if it exists # Drop existing primary key constraint if it exists
if existing_pk and existing_pk.get("constrained_columns"): if existing_pk and existing_pk.get('constrained_columns'):
pk_name = existing_pk.get("name") pk_name = existing_pk.get('name')
if pk_name: if pk_name:
print(f"Dropping primary key constraint: {pk_name}") print(f'Dropping primary key constraint: {pk_name}')
batch_op.drop_constraint(pk_name, type_="primary") batch_op.drop_constraint(pk_name, type_='primary')
# Now create the new primary key with the combination of 'id' and 'user_id' # Now create the new primary key with the combination of 'id' and 'user_id'
print("Creating new primary key with 'id' and 'user_id'.") print("Creating new primary key with 'id' and 'user_id'.")
batch_op.create_primary_key("pk_id_user_id", ["id", "user_id"]) batch_op.create_primary_key('pk_id_user_id', ['id', 'user_id'])
# Drop unique constraints that could conflict with the new primary key # Drop unique constraints that could conflict with the new primary key
for constraint in unique_constraints: for constraint in unique_constraints:
if ( if (
constraint["name"] == "uq_id_user_id" constraint['name'] == 'uq_id_user_id'
): # Adjust this name according to what is actually returned by the inspector ): # Adjust this name according to what is actually returned by the inspector
print(f"Dropping unique constraint: {constraint['name']}") print(f'Dropping unique constraint: {constraint["name"]}')
batch_op.drop_constraint(constraint["name"], type_="unique") batch_op.drop_constraint(constraint['name'], type_='unique')
for index in existing_indexes: for index in existing_indexes:
if index["unique"]: if index['unique']:
if not any( if not any(constraint['name'] == index['name'] for constraint in unique_constraints):
constraint["name"] == index["name"]
for constraint in unique_constraints
):
# You are attempting to drop unique indexes # You are attempting to drop unique indexes
print(f"Dropping unique index: {index['name']}") print(f'Dropping unique index: {index["name"]}')
batch_op.drop_index(index["name"]) batch_op.drop_index(index['name'])
def downgrade(): def downgrade():
conn = op.get_bind() conn = op.get_bind()
inspector = Inspector.from_engine(conn) inspector = Inspector.from_engine(conn)
current_pk = inspector.get_pk_constraint("tag") current_pk = inspector.get_pk_constraint('tag')
with op.batch_alter_table("tag", schema=None) as batch_op: with op.batch_alter_table('tag', schema=None) as batch_op:
# Drop the current primary key first, if it matches the one we know we added in upgrade # Drop the current primary key first, if it matches the one we know we added in upgrade
if current_pk and "pk_id_user_id" == current_pk.get("name"): if current_pk and 'pk_id_user_id' == current_pk.get('name'):
batch_op.drop_constraint("pk_id_user_id", type_="primary") batch_op.drop_constraint('pk_id_user_id', type_='primary')
# Restore the original primary key # Restore the original primary key
batch_op.create_primary_key("pk_id", ["id"]) batch_op.create_primary_key('pk_id', ['id'])
# Since primary key on just 'id' is restored, we now add back any unique constraints if necessary # Since primary key on just 'id' is restored, we now add back any unique constraints if necessary
batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"]) batch_op.create_unique_constraint('uq_id_user_id', ['id', 'user_id'])

View File

@@ -12,21 +12,21 @@ from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = "3af16a1c9fb6" revision: str = '3af16a1c9fb6'
down_revision: Union[str, None] = "018012973d35" down_revision: Union[str, None] = '018012973d35'
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
op.add_column("user", sa.Column("username", sa.String(length=50), nullable=True)) op.add_column('user', sa.Column('username', sa.String(length=50), nullable=True))
op.add_column("user", sa.Column("bio", sa.Text(), nullable=True)) op.add_column('user', sa.Column('bio', sa.Text(), nullable=True))
op.add_column("user", sa.Column("gender", sa.Text(), nullable=True)) op.add_column('user', sa.Column('gender', sa.Text(), nullable=True))
op.add_column("user", sa.Column("date_of_birth", sa.Date(), nullable=True)) op.add_column('user', sa.Column('date_of_birth', sa.Date(), nullable=True))
def downgrade() -> None: def downgrade() -> None:
op.drop_column("user", "username") op.drop_column('user', 'username')
op.drop_column("user", "bio") op.drop_column('user', 'bio')
op.drop_column("user", "gender") op.drop_column('user', 'gender')
op.drop_column("user", "date_of_birth") op.drop_column('user', 'date_of_birth')

View File

@@ -18,38 +18,38 @@ import json
import uuid import uuid
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = "3e0e00844bb0" revision: str = '3e0e00844bb0'
down_revision: Union[str, None] = "90ef40d4714e" down_revision: Union[str, None] = '90ef40d4714e'
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
op.create_table( op.create_table(
"knowledge_file", 'knowledge_file',
sa.Column("id", sa.Text(), primary_key=True), sa.Column('id', sa.Text(), primary_key=True),
sa.Column("user_id", sa.Text(), nullable=False), sa.Column('user_id', sa.Text(), nullable=False),
sa.Column( sa.Column(
"knowledge_id", 'knowledge_id',
sa.Text(), sa.Text(),
sa.ForeignKey("knowledge.id", ondelete="CASCADE"), sa.ForeignKey('knowledge.id', ondelete='CASCADE'),
nullable=False, nullable=False,
), ),
sa.Column( sa.Column(
"file_id", 'file_id',
sa.Text(), sa.Text(),
sa.ForeignKey("file.id", ondelete="CASCADE"), sa.ForeignKey('file.id', ondelete='CASCADE'),
nullable=False, nullable=False,
), ),
sa.Column("created_at", sa.BigInteger(), nullable=False), sa.Column('created_at', sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False), sa.Column('updated_at', sa.BigInteger(), nullable=False),
# indexes # indexes
sa.Index("ix_knowledge_file_knowledge_id", "knowledge_id"), sa.Index('ix_knowledge_file_knowledge_id', 'knowledge_id'),
sa.Index("ix_knowledge_file_file_id", "file_id"), sa.Index('ix_knowledge_file_file_id', 'file_id'),
sa.Index("ix_knowledge_file_user_id", "user_id"), sa.Index('ix_knowledge_file_user_id', 'user_id'),
# unique constraints # unique constraints
sa.UniqueConstraint( sa.UniqueConstraint(
"knowledge_id", "file_id", name="uq_knowledge_file_knowledge_file" 'knowledge_id', 'file_id', name='uq_knowledge_file_knowledge_file'
), # prevent duplicate entries ), # prevent duplicate entries
) )
@@ -57,35 +57,33 @@ def upgrade() -> None:
# 2. Read existing group with user_ids JSON column # 2. Read existing group with user_ids JSON column
knowledge_table = sa.Table( knowledge_table = sa.Table(
"knowledge", 'knowledge',
sa.MetaData(), sa.MetaData(),
sa.Column("id", sa.Text()), sa.Column('id', sa.Text()),
sa.Column("user_id", sa.Text()), sa.Column('user_id', sa.Text()),
sa.Column("data", sa.JSON()), # JSON stored as text in SQLite + PG sa.Column('data', sa.JSON()), # JSON stored as text in SQLite + PG
) )
results = connection.execute( results = connection.execute(
sa.select( sa.select(knowledge_table.c.id, knowledge_table.c.user_id, knowledge_table.c.data)
knowledge_table.c.id, knowledge_table.c.user_id, knowledge_table.c.data
)
).fetchall() ).fetchall()
# 3. Insert members into group_member table # 3. Insert members into group_member table
kf_table = sa.Table( kf_table = sa.Table(
"knowledge_file", 'knowledge_file',
sa.MetaData(), sa.MetaData(),
sa.Column("id", sa.Text()), sa.Column('id', sa.Text()),
sa.Column("user_id", sa.Text()), sa.Column('user_id', sa.Text()),
sa.Column("knowledge_id", sa.Text()), sa.Column('knowledge_id', sa.Text()),
sa.Column("file_id", sa.Text()), sa.Column('file_id', sa.Text()),
sa.Column("created_at", sa.BigInteger()), sa.Column('created_at', sa.BigInteger()),
sa.Column("updated_at", sa.BigInteger()), sa.Column('updated_at', sa.BigInteger()),
) )
file_table = sa.Table( file_table = sa.Table(
"file", 'file',
sa.MetaData(), sa.MetaData(),
sa.Column("id", sa.Text()), sa.Column('id', sa.Text()),
) )
now = int(time.time()) now = int(time.time())
@@ -102,50 +100,48 @@ def upgrade() -> None:
if not isinstance(data, dict): if not isinstance(data, dict):
continue continue
file_ids = data.get("file_ids", []) file_ids = data.get('file_ids', [])
for file_id in file_ids: for file_id in file_ids:
file_exists = connection.execute( file_exists = connection.execute(sa.select(file_table.c.id).where(file_table.c.id == file_id)).fetchone()
sa.select(file_table.c.id).where(file_table.c.id == file_id)
).fetchone()
if not file_exists: if not file_exists:
continue # skip non-existing files continue # skip non-existing files
row = { row = {
"id": str(uuid.uuid4()), 'id': str(uuid.uuid4()),
"user_id": user_id, 'user_id': user_id,
"knowledge_id": knowledge_id, 'knowledge_id': knowledge_id,
"file_id": file_id, 'file_id': file_id,
"created_at": now, 'created_at': now,
"updated_at": now, 'updated_at': now,
} }
connection.execute(kf_table.insert().values(**row)) connection.execute(kf_table.insert().values(**row))
with op.batch_alter_table("knowledge") as batch: with op.batch_alter_table('knowledge') as batch:
batch.drop_column("data") batch.drop_column('data')
def downgrade() -> None: def downgrade() -> None:
# 1. Add back the old data column # 1. Add back the old data column
op.add_column("knowledge", sa.Column("data", sa.JSON(), nullable=True)) op.add_column('knowledge', sa.Column('data', sa.JSON(), nullable=True))
connection = op.get_bind() connection = op.get_bind()
# 2. Read knowledge_file entries and reconstruct data JSON # 2. Read knowledge_file entries and reconstruct data JSON
knowledge_table = sa.Table( knowledge_table = sa.Table(
"knowledge", 'knowledge',
sa.MetaData(), sa.MetaData(),
sa.Column("id", sa.Text()), sa.Column('id', sa.Text()),
sa.Column("data", sa.JSON()), sa.Column('data', sa.JSON()),
) )
kf_table = sa.Table( kf_table = sa.Table(
"knowledge_file", 'knowledge_file',
sa.MetaData(), sa.MetaData(),
sa.Column("id", sa.Text()), sa.Column('id', sa.Text()),
sa.Column("knowledge_id", sa.Text()), sa.Column('knowledge_id', sa.Text()),
sa.Column("file_id", sa.Text()), sa.Column('file_id', sa.Text()),
) )
results = connection.execute(sa.select(knowledge_table.c.id)).fetchall() results = connection.execute(sa.select(knowledge_table.c.id)).fetchall()
@@ -157,13 +153,9 @@ def downgrade() -> None:
file_ids_list = [fid for (fid,) in file_ids] file_ids_list = [fid for (fid,) in file_ids]
data_json = {"file_ids": file_ids_list} data_json = {'file_ids': file_ids_list}
connection.execute( connection.execute(knowledge_table.update().where(knowledge_table.c.id == knowledge_id).values(data=data_json))
knowledge_table.update()
.where(knowledge_table.c.id == knowledge_id)
.values(data=data_json)
)
# 3. Drop the knowledge_file table # 3. Drop the knowledge_file table
op.drop_table("knowledge_file") op.drop_table('knowledge_file')

View File

@@ -9,56 +9,56 @@ Create Date: 2024-10-23 03:00:00.000000
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
revision = "4ace53fd72c8" revision = '4ace53fd72c8'
down_revision = "af906e964978" down_revision = 'af906e964978'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
# Perform safe alterations using batch operation # Perform safe alterations using batch operation
with op.batch_alter_table("folder", schema=None) as batch_op: with op.batch_alter_table('folder', schema=None) as batch_op:
# Step 1: Remove server defaults for created_at and updated_at # Step 1: Remove server defaults for created_at and updated_at
batch_op.alter_column( batch_op.alter_column(
"created_at", 'created_at',
server_default=None, # Removing server default server_default=None, # Removing server default
) )
batch_op.alter_column( batch_op.alter_column(
"updated_at", 'updated_at',
server_default=None, # Removing server default server_default=None, # Removing server default
) )
# Step 2: Change the column types to BigInteger for created_at # Step 2: Change the column types to BigInteger for created_at
batch_op.alter_column( batch_op.alter_column(
"created_at", 'created_at',
type_=sa.BigInteger(), type_=sa.BigInteger(),
existing_type=sa.DateTime(), existing_type=sa.DateTime(),
existing_nullable=False, existing_nullable=False,
postgresql_using="extract(epoch from created_at)::bigint", # Conversion for PostgreSQL postgresql_using='extract(epoch from created_at)::bigint', # Conversion for PostgreSQL
) )
# Change the column types to BigInteger for updated_at # Change the column types to BigInteger for updated_at
batch_op.alter_column( batch_op.alter_column(
"updated_at", 'updated_at',
type_=sa.BigInteger(), type_=sa.BigInteger(),
existing_type=sa.DateTime(), existing_type=sa.DateTime(),
existing_nullable=False, existing_nullable=False,
postgresql_using="extract(epoch from updated_at)::bigint", # Conversion for PostgreSQL postgresql_using='extract(epoch from updated_at)::bigint', # Conversion for PostgreSQL
) )
def downgrade(): def downgrade():
# Downgrade: Convert columns back to DateTime and restore defaults # Downgrade: Convert columns back to DateTime and restore defaults
with op.batch_alter_table("folder", schema=None) as batch_op: with op.batch_alter_table('folder', schema=None) as batch_op:
batch_op.alter_column( batch_op.alter_column(
"created_at", 'created_at',
type_=sa.DateTime(), type_=sa.DateTime(),
existing_type=sa.BigInteger(), existing_type=sa.BigInteger(),
existing_nullable=False, existing_nullable=False,
server_default=sa.func.now(), # Restoring server default on downgrade server_default=sa.func.now(), # Restoring server default on downgrade
) )
batch_op.alter_column( batch_op.alter_column(
"updated_at", 'updated_at',
type_=sa.DateTime(), type_=sa.DateTime(),
existing_type=sa.BigInteger(), existing_type=sa.BigInteger(),
existing_nullable=False, existing_nullable=False,

View File

@@ -9,40 +9,40 @@ Create Date: 2024-12-22 03:00:00.000000
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
revision = "57c599a3cb57" revision = '57c599a3cb57'
down_revision = "922e7a387820" down_revision = '922e7a387820'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.create_table( op.create_table(
"channel", 'channel',
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True), sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column("user_id", sa.Text()), sa.Column('user_id', sa.Text()),
sa.Column("name", sa.Text()), sa.Column('name', sa.Text()),
sa.Column("description", sa.Text(), nullable=True), sa.Column('description', sa.Text(), nullable=True),
sa.Column("data", sa.JSON(), nullable=True), sa.Column('data', sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True), sa.Column('meta', sa.JSON(), nullable=True),
sa.Column("access_control", sa.JSON(), nullable=True), sa.Column('access_control', sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True), sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True), sa.Column('updated_at', sa.BigInteger(), nullable=True),
) )
op.create_table( op.create_table(
"message", 'message',
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True), sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column("user_id", sa.Text()), sa.Column('user_id', sa.Text()),
sa.Column("channel_id", sa.Text(), nullable=True), sa.Column('channel_id', sa.Text(), nullable=True),
sa.Column("content", sa.Text()), sa.Column('content', sa.Text()),
sa.Column("data", sa.JSON(), nullable=True), sa.Column('data', sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True), sa.Column('meta', sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True), sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True), sa.Column('updated_at', sa.BigInteger(), nullable=True),
) )
def downgrade(): def downgrade():
op.drop_table("channel") op.drop_table('channel')
op.drop_table("message") op.drop_table('message')

View File

@@ -13,41 +13,39 @@ import sqlalchemy as sa
import open_webui.internal.db import open_webui.internal.db
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = "6283dc0e4d8d" revision: str = '6283dc0e4d8d'
down_revision: Union[str, None] = "3e0e00844bb0" down_revision: Union[str, None] = '3e0e00844bb0'
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
op.create_table( op.create_table(
"channel_file", 'channel_file',
sa.Column("id", sa.Text(), primary_key=True), sa.Column('id', sa.Text(), primary_key=True),
sa.Column("user_id", sa.Text(), nullable=False), sa.Column('user_id', sa.Text(), nullable=False),
sa.Column( sa.Column(
"channel_id", 'channel_id',
sa.Text(), sa.Text(),
sa.ForeignKey("channel.id", ondelete="CASCADE"), sa.ForeignKey('channel.id', ondelete='CASCADE'),
nullable=False, nullable=False,
), ),
sa.Column( sa.Column(
"file_id", 'file_id',
sa.Text(), sa.Text(),
sa.ForeignKey("file.id", ondelete="CASCADE"), sa.ForeignKey('file.id', ondelete='CASCADE'),
nullable=False, nullable=False,
), ),
sa.Column("created_at", sa.BigInteger(), nullable=False), sa.Column('created_at', sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False), sa.Column('updated_at', sa.BigInteger(), nullable=False),
# indexes # indexes
sa.Index("ix_channel_file_channel_id", "channel_id"), sa.Index('ix_channel_file_channel_id', 'channel_id'),
sa.Index("ix_channel_file_file_id", "file_id"), sa.Index('ix_channel_file_file_id', 'file_id'),
sa.Index("ix_channel_file_user_id", "user_id"), sa.Index('ix_channel_file_user_id', 'user_id'),
# unique constraints # unique constraints
sa.UniqueConstraint( sa.UniqueConstraint('channel_id', 'file_id', name='uq_channel_file_channel_file'), # prevent duplicate entries
"channel_id", "file_id", name="uq_channel_file_channel_file"
), # prevent duplicate entries
) )
def downgrade() -> None: def downgrade() -> None:
op.drop_table("channel_file") op.drop_table('channel_file')

View File

@@ -11,37 +11,37 @@ import sqlalchemy as sa
from sqlalchemy.sql import table, column, select from sqlalchemy.sql import table, column, select
import json import json
revision = "6a39f3d8e55c" revision = '6a39f3d8e55c'
down_revision = "c0fbf31ca0db" down_revision = 'c0fbf31ca0db'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
# Creating the 'knowledge' table # Creating the 'knowledge' table
print("Creating knowledge table") print('Creating knowledge table')
knowledge_table = op.create_table( knowledge_table = op.create_table(
"knowledge", 'knowledge',
sa.Column("id", sa.Text(), primary_key=True), sa.Column('id', sa.Text(), primary_key=True),
sa.Column("user_id", sa.Text(), nullable=False), sa.Column('user_id', sa.Text(), nullable=False),
sa.Column("name", sa.Text(), nullable=False), sa.Column('name', sa.Text(), nullable=False),
sa.Column("description", sa.Text(), nullable=True), sa.Column('description', sa.Text(), nullable=True),
sa.Column("data", sa.JSON(), nullable=True), sa.Column('data', sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True), sa.Column('meta', sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False), sa.Column('created_at', sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=True), sa.Column('updated_at', sa.BigInteger(), nullable=True),
) )
print("Migrating data from document table to knowledge table") print('Migrating data from document table to knowledge table')
# Representation of the existing 'document' table # Representation of the existing 'document' table
document_table = table( document_table = table(
"document", 'document',
column("collection_name", sa.String()), column('collection_name', sa.String()),
column("user_id", sa.String()), column('user_id', sa.String()),
column("name", sa.String()), column('name', sa.String()),
column("title", sa.Text()), column('title', sa.Text()),
column("content", sa.Text()), column('content', sa.Text()),
column("timestamp", sa.BigInteger()), column('timestamp', sa.BigInteger()),
) )
# Select all from existing document table # Select all from existing document table
@@ -64,9 +64,9 @@ def upgrade():
user_id=doc.user_id, user_id=doc.user_id,
description=doc.name, description=doc.name,
meta={ meta={
"legacy": True, 'legacy': True,
"document": True, 'document': True,
"tags": json.loads(doc.content or "{}").get("tags", []), 'tags': json.loads(doc.content or '{}').get('tags', []),
}, },
name=doc.title, name=doc.title,
created_at=doc.timestamp, created_at=doc.timestamp,
@@ -76,4 +76,4 @@ def upgrade():
def downgrade(): def downgrade():
op.drop_table("knowledge") op.drop_table('knowledge')

View File

@@ -9,18 +9,18 @@ Create Date: 2024-12-23 03:00:00.000000
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
revision = "7826ab40b532" revision = '7826ab40b532'
down_revision = "57c599a3cb57" down_revision = '57c599a3cb57'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.add_column( op.add_column(
"file", 'file',
sa.Column("access_control", sa.JSON(), nullable=True), sa.Column('access_control', sa.JSON(), nullable=True),
) )
def downgrade(): def downgrade():
op.drop_column("file", "access_control") op.drop_column('file', 'access_control')

View File

@@ -16,7 +16,7 @@ from open_webui.internal.db import JSONField
from open_webui.migrations.util import get_existing_tables from open_webui.migrations.util import get_existing_tables
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = "7e5b5dc7342b" revision: str = '7e5b5dc7342b'
down_revision: Union[str, None] = None down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@@ -26,179 +26,179 @@ def upgrade() -> None:
existing_tables = set(get_existing_tables()) existing_tables = set(get_existing_tables())
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
if "auth" not in existing_tables: if 'auth' not in existing_tables:
op.create_table( op.create_table(
"auth", 'auth',
sa.Column("id", sa.String(), nullable=False), sa.Column('id', sa.String(), nullable=False),
sa.Column("email", sa.String(), nullable=True), sa.Column('email', sa.String(), nullable=True),
sa.Column("password", sa.Text(), nullable=True), sa.Column('password', sa.Text(), nullable=True),
sa.Column("active", sa.Boolean(), nullable=True), sa.Column('active', sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint('id'),
) )
if "chat" not in existing_tables: if 'chat' not in existing_tables:
op.create_table( op.create_table(
"chat", 'chat',
sa.Column("id", sa.String(), nullable=False), sa.Column('id', sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True), sa.Column('user_id', sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True), sa.Column('title', sa.Text(), nullable=True),
sa.Column("chat", sa.Text(), nullable=True), sa.Column('chat', sa.Text(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True), sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True), sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column("share_id", sa.Text(), nullable=True), sa.Column('share_id', sa.Text(), nullable=True),
sa.Column("archived", sa.Boolean(), nullable=True), sa.Column('archived', sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint("share_id"), sa.UniqueConstraint('share_id'),
) )
if "chatidtag" not in existing_tables: if 'chatidtag' not in existing_tables:
op.create_table( op.create_table(
"chatidtag", 'chatidtag',
sa.Column("id", sa.String(), nullable=False), sa.Column('id', sa.String(), nullable=False),
sa.Column("tag_name", sa.String(), nullable=True), sa.Column('tag_name', sa.String(), nullable=True),
sa.Column("chat_id", sa.String(), nullable=True), sa.Column('chat_id', sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True), sa.Column('user_id', sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True), sa.Column('timestamp', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint('id'),
) )
if "document" not in existing_tables: if 'document' not in existing_tables:
op.create_table( op.create_table(
"document", 'document',
sa.Column("collection_name", sa.String(), nullable=False), sa.Column('collection_name', sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True), sa.Column('name', sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True), sa.Column('title', sa.Text(), nullable=True),
sa.Column("filename", sa.Text(), nullable=True), sa.Column('filename', sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True), sa.Column('content', sa.Text(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True), sa.Column('user_id', sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True), sa.Column('timestamp', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("collection_name"), sa.PrimaryKeyConstraint('collection_name'),
sa.UniqueConstraint("name"), sa.UniqueConstraint('name'),
) )
if "file" not in existing_tables: if 'file' not in existing_tables:
op.create_table( op.create_table(
"file", 'file',
sa.Column("id", sa.String(), nullable=False), sa.Column('id', sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True), sa.Column('user_id', sa.String(), nullable=True),
sa.Column("filename", sa.Text(), nullable=True), sa.Column('filename', sa.Text(), nullable=True),
sa.Column("meta", JSONField(), nullable=True), sa.Column('meta', JSONField(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True), sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint('id'),
) )
if "function" not in existing_tables: if 'function' not in existing_tables:
op.create_table( op.create_table(
"function", 'function',
sa.Column("id", sa.String(), nullable=False), sa.Column('id', sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True), sa.Column('user_id', sa.String(), nullable=True),
sa.Column("name", sa.Text(), nullable=True), sa.Column('name', sa.Text(), nullable=True),
sa.Column("type", sa.Text(), nullable=True), sa.Column('type', sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True), sa.Column('content', sa.Text(), nullable=True),
sa.Column("meta", JSONField(), nullable=True), sa.Column('meta', JSONField(), nullable=True),
sa.Column("valves", JSONField(), nullable=True), sa.Column('valves', JSONField(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=True), sa.Column('is_active', sa.Boolean(), nullable=True),
sa.Column("is_global", sa.Boolean(), nullable=True), sa.Column('is_global', sa.Boolean(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True), sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True), sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint('id'),
) )
if "memory" not in existing_tables: if 'memory' not in existing_tables:
op.create_table( op.create_table(
"memory", 'memory',
sa.Column("id", sa.String(), nullable=False), sa.Column('id', sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True), sa.Column('user_id', sa.String(), nullable=True),
sa.Column("content", sa.Text(), nullable=True), sa.Column('content', sa.Text(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True), sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True), sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint('id'),
) )
if "model" not in existing_tables: if 'model' not in existing_tables:
op.create_table( op.create_table(
"model", 'model',
sa.Column("id", sa.Text(), nullable=False), sa.Column('id', sa.Text(), nullable=False),
sa.Column("user_id", sa.Text(), nullable=True), sa.Column('user_id', sa.Text(), nullable=True),
sa.Column("base_model_id", sa.Text(), nullable=True), sa.Column('base_model_id', sa.Text(), nullable=True),
sa.Column("name", sa.Text(), nullable=True), sa.Column('name', sa.Text(), nullable=True),
sa.Column("params", JSONField(), nullable=True), sa.Column('params', JSONField(), nullable=True),
sa.Column("meta", JSONField(), nullable=True), sa.Column('meta', JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True), sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True), sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint('id'),
) )
if "prompt" not in existing_tables: if 'prompt' not in existing_tables:
op.create_table( op.create_table(
"prompt", 'prompt',
sa.Column("command", sa.String(), nullable=False), sa.Column('command', sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True), sa.Column('user_id', sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True), sa.Column('title', sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True), sa.Column('content', sa.Text(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True), sa.Column('timestamp', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("command"), sa.PrimaryKeyConstraint('command'),
) )
if "tag" not in existing_tables: if 'tag' not in existing_tables:
op.create_table( op.create_table(
"tag", 'tag',
sa.Column("id", sa.String(), nullable=False), sa.Column('id', sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True), sa.Column('name', sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True), sa.Column('user_id', sa.String(), nullable=True),
sa.Column("data", sa.Text(), nullable=True), sa.Column('data', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint('id'),
) )
if "tool" not in existing_tables: if 'tool' not in existing_tables:
op.create_table( op.create_table(
"tool", 'tool',
sa.Column("id", sa.String(), nullable=False), sa.Column('id', sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True), sa.Column('user_id', sa.String(), nullable=True),
sa.Column("name", sa.Text(), nullable=True), sa.Column('name', sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True), sa.Column('content', sa.Text(), nullable=True),
sa.Column("specs", JSONField(), nullable=True), sa.Column('specs', JSONField(), nullable=True),
sa.Column("meta", JSONField(), nullable=True), sa.Column('meta', JSONField(), nullable=True),
sa.Column("valves", JSONField(), nullable=True), sa.Column('valves', JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True), sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True), sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint('id'),
) )
if "user" not in existing_tables: if 'user' not in existing_tables:
op.create_table( op.create_table(
"user", 'user',
sa.Column("id", sa.String(), nullable=False), sa.Column('id', sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True), sa.Column('name', sa.String(), nullable=True),
sa.Column("email", sa.String(), nullable=True), sa.Column('email', sa.String(), nullable=True),
sa.Column("role", sa.String(), nullable=True), sa.Column('role', sa.String(), nullable=True),
sa.Column("profile_image_url", sa.Text(), nullable=True), sa.Column('profile_image_url', sa.Text(), nullable=True),
sa.Column("last_active_at", sa.BigInteger(), nullable=True), sa.Column('last_active_at', sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True), sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True), sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column("api_key", sa.String(), nullable=True), sa.Column('api_key', sa.String(), nullable=True),
sa.Column("settings", JSONField(), nullable=True), sa.Column('settings', JSONField(), nullable=True),
sa.Column("info", JSONField(), nullable=True), sa.Column('info', JSONField(), nullable=True),
sa.Column("oauth_sub", sa.Text(), nullable=True), sa.Column('oauth_sub', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint("id"), sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint("api_key"), sa.UniqueConstraint('api_key'),
sa.UniqueConstraint("oauth_sub"), sa.UniqueConstraint('oauth_sub'),
) )
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade() -> None: def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_table("user") op.drop_table('user')
op.drop_table("tool") op.drop_table('tool')
op.drop_table("tag") op.drop_table('tag')
op.drop_table("prompt") op.drop_table('prompt')
op.drop_table("model") op.drop_table('model')
op.drop_table("memory") op.drop_table('memory')
op.drop_table("function") op.drop_table('function')
op.drop_table("file") op.drop_table('file')
op.drop_table("document") op.drop_table('document')
op.drop_table("chatidtag") op.drop_table('chatidtag')
op.drop_table("chat") op.drop_table('chat')
op.drop_table("auth") op.drop_table('auth')
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@@ -13,36 +13,34 @@ import sqlalchemy as sa
import open_webui.internal.db import open_webui.internal.db
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = "81cc2ce44d79" revision: str = '81cc2ce44d79'
down_revision: Union[str, None] = "6283dc0e4d8d" down_revision: Union[str, None] = '6283dc0e4d8d'
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# Add message_id column to channel_file table # Add message_id column to channel_file table
with op.batch_alter_table("channel_file", schema=None) as batch_op: with op.batch_alter_table('channel_file', schema=None) as batch_op:
batch_op.add_column( batch_op.add_column(
sa.Column( sa.Column(
"message_id", 'message_id',
sa.Text(), sa.Text(),
sa.ForeignKey( sa.ForeignKey('message.id', ondelete='CASCADE', name='fk_channel_file_message_id'),
"message.id", ondelete="CASCADE", name="fk_channel_file_message_id"
),
nullable=True, nullable=True,
) )
) )
# Add data column to knowledge table # Add data column to knowledge table
with op.batch_alter_table("knowledge", schema=None) as batch_op: with op.batch_alter_table('knowledge', schema=None) as batch_op:
batch_op.add_column(sa.Column("data", sa.JSON(), nullable=True)) batch_op.add_column(sa.Column('data', sa.JSON(), nullable=True))
def downgrade() -> None: def downgrade() -> None:
# Remove message_id column from channel_file table # Remove message_id column from channel_file table
with op.batch_alter_table("channel_file", schema=None) as batch_op: with op.batch_alter_table('channel_file', schema=None) as batch_op:
batch_op.drop_column("message_id") batch_op.drop_column('message_id')
# Remove data column from knowledge table # Remove data column from knowledge table
with op.batch_alter_table("knowledge", schema=None) as batch_op: with op.batch_alter_table('knowledge', schema=None) as batch_op:
batch_op.drop_column("data") batch_op.drop_column('data')

View File

@@ -16,8 +16,8 @@ import sqlalchemy as sa
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
revision: str = "8452d01d26d7" revision: str = '8452d01d26d7'
down_revision: Union[str, None] = "374d2f66af06" down_revision: Union[str, None] = '374d2f66af06'
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@@ -51,74 +51,68 @@ def _flush_batch(conn, table, batch):
except Exception as e: except Exception as e:
sp.rollback() sp.rollback()
failed += 1 failed += 1
log.warning(f"Failed to insert message {msg['id']}: {e}") log.warning(f'Failed to insert message {msg["id"]}: {e}')
return inserted, failed return inserted, failed
def upgrade() -> None: def upgrade() -> None:
# Step 1: Create table # Step 1: Create table
op.create_table( op.create_table(
"chat_message", 'chat_message',
sa.Column("id", sa.Text(), primary_key=True), sa.Column('id', sa.Text(), primary_key=True),
sa.Column("chat_id", sa.Text(), nullable=False, index=True), sa.Column('chat_id', sa.Text(), nullable=False, index=True),
sa.Column("user_id", sa.Text(), index=True), sa.Column('user_id', sa.Text(), index=True),
sa.Column("role", sa.Text(), nullable=False), sa.Column('role', sa.Text(), nullable=False),
sa.Column("parent_id", sa.Text(), nullable=True), sa.Column('parent_id', sa.Text(), nullable=True),
sa.Column("content", sa.JSON(), nullable=True), sa.Column('content', sa.JSON(), nullable=True),
sa.Column("output", sa.JSON(), nullable=True), sa.Column('output', sa.JSON(), nullable=True),
sa.Column("model_id", sa.Text(), nullable=True, index=True), sa.Column('model_id', sa.Text(), nullable=True, index=True),
sa.Column("files", sa.JSON(), nullable=True), sa.Column('files', sa.JSON(), nullable=True),
sa.Column("sources", sa.JSON(), nullable=True), sa.Column('sources', sa.JSON(), nullable=True),
sa.Column("embeds", sa.JSON(), nullable=True), sa.Column('embeds', sa.JSON(), nullable=True),
sa.Column("done", sa.Boolean(), default=True), sa.Column('done', sa.Boolean(), default=True),
sa.Column("status_history", sa.JSON(), nullable=True), sa.Column('status_history', sa.JSON(), nullable=True),
sa.Column("error", sa.JSON(), nullable=True), sa.Column('error', sa.JSON(), nullable=True),
sa.Column("usage", sa.JSON(), nullable=True), sa.Column('usage', sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), index=True), sa.Column('created_at', sa.BigInteger(), index=True),
sa.Column("updated_at", sa.BigInteger()), sa.Column('updated_at', sa.BigInteger()),
sa.ForeignKeyConstraint(["chat_id"], ["chat.id"], ondelete="CASCADE"), sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ondelete='CASCADE'),
) )
# Create composite indexes # Create composite indexes
op.create_index( op.create_index('chat_message_chat_parent_idx', 'chat_message', ['chat_id', 'parent_id'])
"chat_message_chat_parent_idx", "chat_message", ["chat_id", "parent_id"] op.create_index('chat_message_model_created_idx', 'chat_message', ['model_id', 'created_at'])
) op.create_index('chat_message_user_created_idx', 'chat_message', ['user_id', 'created_at'])
op.create_index(
"chat_message_model_created_idx", "chat_message", ["model_id", "created_at"]
)
op.create_index(
"chat_message_user_created_idx", "chat_message", ["user_id", "created_at"]
)
# Step 2: Backfill from existing chats # Step 2: Backfill from existing chats
conn = op.get_bind() conn = op.get_bind()
chat_table = sa.table( chat_table = sa.table(
"chat", 'chat',
sa.column("id", sa.Text()), sa.column('id', sa.Text()),
sa.column("user_id", sa.Text()), sa.column('user_id', sa.Text()),
sa.column("chat", sa.JSON()), sa.column('chat', sa.JSON()),
) )
chat_message_table = sa.table( chat_message_table = sa.table(
"chat_message", 'chat_message',
sa.column("id", sa.Text()), sa.column('id', sa.Text()),
sa.column("chat_id", sa.Text()), sa.column('chat_id', sa.Text()),
sa.column("user_id", sa.Text()), sa.column('user_id', sa.Text()),
sa.column("role", sa.Text()), sa.column('role', sa.Text()),
sa.column("parent_id", sa.Text()), sa.column('parent_id', sa.Text()),
sa.column("content", sa.JSON()), sa.column('content', sa.JSON()),
sa.column("output", sa.JSON()), sa.column('output', sa.JSON()),
sa.column("model_id", sa.Text()), sa.column('model_id', sa.Text()),
sa.column("files", sa.JSON()), sa.column('files', sa.JSON()),
sa.column("sources", sa.JSON()), sa.column('sources', sa.JSON()),
sa.column("embeds", sa.JSON()), sa.column('embeds', sa.JSON()),
sa.column("done", sa.Boolean()), sa.column('done', sa.Boolean()),
sa.column("status_history", sa.JSON()), sa.column('status_history', sa.JSON()),
sa.column("error", sa.JSON()), sa.column('error', sa.JSON()),
sa.column("usage", sa.JSON()), sa.column('usage', sa.JSON()),
sa.column("created_at", sa.BigInteger()), sa.column('created_at', sa.BigInteger()),
sa.column("updated_at", sa.BigInteger()), sa.column('updated_at', sa.BigInteger()),
) )
# Stream rows instead of loading all into memory: # Stream rows instead of loading all into memory:
@@ -126,7 +120,7 @@ def upgrade() -> None:
# - stream_results: enables server-side cursors on PostgreSQL (no-op on SQLite) # - stream_results: enables server-side cursors on PostgreSQL (no-op on SQLite)
result = conn.execute( result = conn.execute(
sa.select(chat_table.c.id, chat_table.c.user_id, chat_table.c.chat) sa.select(chat_table.c.id, chat_table.c.user_id, chat_table.c.chat)
.where(~chat_table.c.user_id.like("shared-%")) .where(~chat_table.c.user_id.like('shared-%'))
.execution_options(yield_per=1000, stream_results=True) .execution_options(yield_per=1000, stream_results=True)
) )
@@ -150,11 +144,11 @@ def upgrade() -> None:
except Exception: except Exception:
continue continue
history = chat_data.get("history", {}) history = chat_data.get('history', {})
if not isinstance(history, dict): if not isinstance(history, dict):
continue continue
messages = history.get("messages", {}) messages = history.get('messages', {})
if not isinstance(messages, dict): if not isinstance(messages, dict):
continue continue
@@ -162,11 +156,11 @@ def upgrade() -> None:
if not isinstance(message, dict): if not isinstance(message, dict):
continue continue
role = message.get("role") role = message.get('role')
if not role: if not role:
continue continue
timestamp = message.get("timestamp", now) timestamp = message.get('timestamp', now)
try: try:
timestamp = int(float(timestamp)) timestamp = int(float(timestamp))
@@ -182,37 +176,33 @@ def upgrade() -> None:
messages_batch.append( messages_batch.append(
{ {
"id": f"{chat_id}-{message_id}", 'id': f'{chat_id}-{message_id}',
"chat_id": chat_id, 'chat_id': chat_id,
"user_id": user_id, 'user_id': user_id,
"role": role, 'role': role,
"parent_id": message.get("parentId"), 'parent_id': message.get('parentId'),
"content": message.get("content"), 'content': message.get('content'),
"output": message.get("output"), 'output': message.get('output'),
"model_id": message.get("model"), 'model_id': message.get('model'),
"files": message.get("files"), 'files': message.get('files'),
"sources": message.get("sources"), 'sources': message.get('sources'),
"embeds": message.get("embeds"), 'embeds': message.get('embeds'),
"done": message.get("done", True), 'done': message.get('done', True),
"status_history": message.get("statusHistory"), 'status_history': message.get('statusHistory'),
"error": message.get("error"), 'error': message.get('error'),
"usage": message.get("usage"), 'usage': message.get('usage'),
"created_at": timestamp, 'created_at': timestamp,
"updated_at": timestamp, 'updated_at': timestamp,
} }
) )
# Flush batch when full # Flush batch when full
if len(messages_batch) >= BATCH_SIZE: if len(messages_batch) >= BATCH_SIZE:
inserted, failed = _flush_batch( inserted, failed = _flush_batch(conn, chat_message_table, messages_batch)
conn, chat_message_table, messages_batch
)
total_inserted += inserted total_inserted += inserted
total_failed += failed total_failed += failed
if total_inserted % 50000 < BATCH_SIZE: if total_inserted % 50000 < BATCH_SIZE:
log.info( log.info(f'Migration progress: {total_inserted} messages inserted...')
f"Migration progress: {total_inserted} messages inserted..."
)
messages_batch.clear() messages_batch.clear()
# Flush remaining messages # Flush remaining messages
@@ -221,13 +211,11 @@ def upgrade() -> None:
total_inserted += inserted total_inserted += inserted
total_failed += failed total_failed += failed
log.info( log.info(f'Backfilled {total_inserted} messages into chat_message table ({total_failed} failed)')
f"Backfilled {total_inserted} messages into chat_message table ({total_failed} failed)"
)
def downgrade() -> None: def downgrade() -> None:
op.drop_index("chat_message_user_created_idx", table_name="chat_message") op.drop_index('chat_message_user_created_idx', table_name='chat_message')
op.drop_index("chat_message_model_created_idx", table_name="chat_message") op.drop_index('chat_message_model_created_idx', table_name='chat_message')
op.drop_index("chat_message_chat_parent_idx", table_name="chat_message") op.drop_index('chat_message_chat_parent_idx', table_name='chat_message')
op.drop_table("chat_message") op.drop_table('chat_message')

View File

@@ -13,48 +13,46 @@ import sqlalchemy as sa
import open_webui.internal.db import open_webui.internal.db
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = "90ef40d4714e" revision: str = '90ef40d4714e'
down_revision: Union[str, None] = "b10670c03dd5" down_revision: Union[str, None] = 'b10670c03dd5'
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# Update 'channel' table # Update 'channel' table
op.add_column("channel", sa.Column("is_private", sa.Boolean(), nullable=True)) op.add_column('channel', sa.Column('is_private', sa.Boolean(), nullable=True))
op.add_column("channel", sa.Column("archived_at", sa.BigInteger(), nullable=True)) op.add_column('channel', sa.Column('archived_at', sa.BigInteger(), nullable=True))
op.add_column("channel", sa.Column("archived_by", sa.Text(), nullable=True)) op.add_column('channel', sa.Column('archived_by', sa.Text(), nullable=True))
op.add_column("channel", sa.Column("deleted_at", sa.BigInteger(), nullable=True)) op.add_column('channel', sa.Column('deleted_at', sa.BigInteger(), nullable=True))
op.add_column("channel", sa.Column("deleted_by", sa.Text(), nullable=True)) op.add_column('channel', sa.Column('deleted_by', sa.Text(), nullable=True))
op.add_column("channel", sa.Column("updated_by", sa.Text(), nullable=True)) op.add_column('channel', sa.Column('updated_by', sa.Text(), nullable=True))
# Update 'channel_member' table # Update 'channel_member' table
op.add_column("channel_member", sa.Column("role", sa.Text(), nullable=True)) op.add_column('channel_member', sa.Column('role', sa.Text(), nullable=True))
op.add_column("channel_member", sa.Column("invited_by", sa.Text(), nullable=True)) op.add_column('channel_member', sa.Column('invited_by', sa.Text(), nullable=True))
op.add_column( op.add_column('channel_member', sa.Column('invited_at', sa.BigInteger(), nullable=True))
"channel_member", sa.Column("invited_at", sa.BigInteger(), nullable=True)
)
# Create 'channel_webhook' table # Create 'channel_webhook' table
op.create_table( op.create_table(
"channel_webhook", 'channel_webhook',
sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False), sa.Column('id', sa.Text(), primary_key=True, unique=True, nullable=False),
sa.Column("user_id", sa.Text(), nullable=False), sa.Column('user_id', sa.Text(), nullable=False),
sa.Column( sa.Column(
"channel_id", 'channel_id',
sa.Text(), sa.Text(),
sa.ForeignKey("channel.id", ondelete="CASCADE"), sa.ForeignKey('channel.id', ondelete='CASCADE'),
nullable=False, nullable=False,
), ),
sa.Column("name", sa.Text(), nullable=False), sa.Column('name', sa.Text(), nullable=False),
sa.Column("profile_image_url", sa.Text(), nullable=True), sa.Column('profile_image_url', sa.Text(), nullable=True),
sa.Column("token", sa.Text(), nullable=False), sa.Column('token', sa.Text(), nullable=False),
sa.Column("last_used_at", sa.BigInteger(), nullable=True), sa.Column('last_used_at', sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False), sa.Column('created_at', sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False), sa.Column('updated_at', sa.BigInteger(), nullable=False),
) )
pass pass
@@ -62,19 +60,19 @@ def upgrade() -> None:
def downgrade() -> None: def downgrade() -> None:
# Downgrade 'channel' table # Downgrade 'channel' table
op.drop_column("channel", "is_private") op.drop_column('channel', 'is_private')
op.drop_column("channel", "archived_at") op.drop_column('channel', 'archived_at')
op.drop_column("channel", "archived_by") op.drop_column('channel', 'archived_by')
op.drop_column("channel", "deleted_at") op.drop_column('channel', 'deleted_at')
op.drop_column("channel", "deleted_by") op.drop_column('channel', 'deleted_by')
op.drop_column("channel", "updated_by") op.drop_column('channel', 'updated_by')
# Downgrade 'channel_member' table # Downgrade 'channel_member' table
op.drop_column("channel_member", "role") op.drop_column('channel_member', 'role')
op.drop_column("channel_member", "invited_by") op.drop_column('channel_member', 'invited_by')
op.drop_column("channel_member", "invited_at") op.drop_column('channel_member', 'invited_at')
# Drop 'channel_webhook' table # Drop 'channel_webhook' table
op.drop_table("channel_webhook") op.drop_table('channel_webhook')
pass pass

View File

@@ -9,38 +9,38 @@ Create Date: 2024-11-14 03:00:00.000000
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
revision = "922e7a387820" revision = '922e7a387820'
down_revision = "4ace53fd72c8" down_revision = '4ace53fd72c8'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.create_table( op.create_table(
"group", 'group',
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True), sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column("user_id", sa.Text(), nullable=True), sa.Column('user_id', sa.Text(), nullable=True),
sa.Column("name", sa.Text(), nullable=True), sa.Column('name', sa.Text(), nullable=True),
sa.Column("description", sa.Text(), nullable=True), sa.Column('description', sa.Text(), nullable=True),
sa.Column("data", sa.JSON(), nullable=True), sa.Column('data', sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True), sa.Column('meta', sa.JSON(), nullable=True),
sa.Column("permissions", sa.JSON(), nullable=True), sa.Column('permissions', sa.JSON(), nullable=True),
sa.Column("user_ids", sa.JSON(), nullable=True), sa.Column('user_ids', sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True), sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True), sa.Column('updated_at', sa.BigInteger(), nullable=True),
) )
# Add 'access_control' column to 'model' table # Add 'access_control' column to 'model' table
op.add_column( op.add_column(
"model", 'model',
sa.Column("access_control", sa.JSON(), nullable=True), sa.Column('access_control', sa.JSON(), nullable=True),
) )
# Add 'is_active' column to 'model' table # Add 'is_active' column to 'model' table
op.add_column( op.add_column(
"model", 'model',
sa.Column( sa.Column(
"is_active", 'is_active',
sa.Boolean(), sa.Boolean(),
nullable=False, nullable=False,
server_default=sa.sql.expression.true(), server_default=sa.sql.expression.true(),
@@ -49,37 +49,37 @@ def upgrade():
# Add 'access_control' column to 'knowledge' table # Add 'access_control' column to 'knowledge' table
op.add_column( op.add_column(
"knowledge", 'knowledge',
sa.Column("access_control", sa.JSON(), nullable=True), sa.Column('access_control', sa.JSON(), nullable=True),
) )
# Add 'access_control' column to 'prompt' table # Add 'access_control' column to 'prompt' table
op.add_column( op.add_column(
"prompt", 'prompt',
sa.Column("access_control", sa.JSON(), nullable=True), sa.Column('access_control', sa.JSON(), nullable=True),
) )
# Add 'access_control' column to 'tools' table # Add 'access_control' column to 'tools' table
op.add_column( op.add_column(
"tool", 'tool',
sa.Column("access_control", sa.JSON(), nullable=True), sa.Column('access_control', sa.JSON(), nullable=True),
) )
def downgrade(): def downgrade():
op.drop_table("group") op.drop_table('group')
# Drop 'access_control' column from 'model' table # Drop 'access_control' column from 'model' table
op.drop_column("model", "access_control") op.drop_column('model', 'access_control')
# Drop 'is_active' column from 'model' table # Drop 'is_active' column from 'model' table
op.drop_column("model", "is_active") op.drop_column('model', 'is_active')
# Drop 'access_control' column from 'knowledge' table # Drop 'access_control' column from 'knowledge' table
op.drop_column("knowledge", "access_control") op.drop_column('knowledge', 'access_control')
# Drop 'access_control' column from 'prompt' table # Drop 'access_control' column from 'prompt' table
op.drop_column("prompt", "access_control") op.drop_column('prompt', 'access_control')
# Drop 'access_control' column from 'tools' table # Drop 'access_control' column from 'tools' table
op.drop_column("tool", "access_control") op.drop_column('tool', 'access_control')

View File

@@ -9,25 +9,25 @@ Create Date: 2025-05-03 03:00:00.000000
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
revision = "9f0c9cd09105" revision = '9f0c9cd09105'
down_revision = "3781e22d8b01" down_revision = '3781e22d8b01'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.create_table( op.create_table(
"note", 'note',
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True), sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True),
sa.Column("user_id", sa.Text(), nullable=True), sa.Column('user_id', sa.Text(), nullable=True),
sa.Column("title", sa.Text(), nullable=True), sa.Column('title', sa.Text(), nullable=True),
sa.Column("data", sa.JSON(), nullable=True), sa.Column('data', sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True), sa.Column('meta', sa.JSON(), nullable=True),
sa.Column("access_control", sa.JSON(), nullable=True), sa.Column('access_control', sa.JSON(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True), sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True), sa.Column('updated_at', sa.BigInteger(), nullable=True),
) )
def downgrade(): def downgrade():
op.drop_table("note") op.drop_table('note')

View File

@@ -13,8 +13,8 @@ import sqlalchemy as sa
from open_webui.migrations.util import get_existing_tables from open_webui.migrations.util import get_existing_tables
revision: str = "a1b2c3d4e5f6" revision: str = 'a1b2c3d4e5f6'
down_revision: Union[str, None] = "f1e2d3c4b5a6" down_revision: Union[str, None] = 'f1e2d3c4b5a6'
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@@ -22,24 +22,24 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
existing_tables = set(get_existing_tables()) existing_tables = set(get_existing_tables())
if "skill" not in existing_tables: if 'skill' not in existing_tables:
op.create_table( op.create_table(
"skill", 'skill',
sa.Column("id", sa.String(), nullable=False, primary_key=True), sa.Column('id', sa.String(), nullable=False, primary_key=True),
sa.Column("user_id", sa.String(), nullable=False), sa.Column('user_id', sa.String(), nullable=False),
sa.Column("name", sa.Text(), nullable=False, unique=True), sa.Column('name', sa.Text(), nullable=False, unique=True),
sa.Column("description", sa.Text(), nullable=True), sa.Column('description', sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=False), sa.Column('content', sa.Text(), nullable=False),
sa.Column("meta", sa.JSON(), nullable=True), sa.Column('meta', sa.JSON(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False), sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False), sa.Column('updated_at', sa.BigInteger(), nullable=False),
sa.Column("created_at", sa.BigInteger(), nullable=False), sa.Column('created_at', sa.BigInteger(), nullable=False),
) )
op.create_index("idx_skill_user_id", "skill", ["user_id"]) op.create_index('idx_skill_user_id', 'skill', ['user_id'])
op.create_index("idx_skill_updated_at", "skill", ["updated_at"]) op.create_index('idx_skill_updated_at', 'skill', ['updated_at'])
def downgrade() -> None: def downgrade() -> None:
op.drop_index("idx_skill_updated_at", table_name="skill") op.drop_index('idx_skill_updated_at', table_name='skill')
op.drop_index("idx_skill_user_id", table_name="skill") op.drop_index('idx_skill_user_id', table_name='skill')
op.drop_table("skill") op.drop_table('skill')

View File

@@ -12,8 +12,8 @@ from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = "a5c220713937" revision: str = 'a5c220713937'
down_revision: Union[str, None] = "38d63c18f30f" down_revision: Union[str, None] = '38d63c18f30f'
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@@ -21,14 +21,14 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
# Add 'reply_to_id' column to the 'message' table for replying to messages # Add 'reply_to_id' column to the 'message' table for replying to messages
op.add_column( op.add_column(
"message", 'message',
sa.Column("reply_to_id", sa.Text(), nullable=True), sa.Column('reply_to_id', sa.Text(), nullable=True),
) )
pass pass
def downgrade() -> None: def downgrade() -> None:
# Remove 'reply_to_id' column from the 'message' table # Remove 'reply_to_id' column from the 'message' table
op.drop_column("message", "reply_to_id") op.drop_column('message', 'reply_to_id')
pass pass

View File

@@ -10,8 +10,8 @@ from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
# Revision identifiers, used by Alembic. # Revision identifiers, used by Alembic.
revision = "af906e964978" revision = 'af906e964978'
down_revision = "c29facfe716b" down_revision = 'c29facfe716b'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@@ -19,33 +19,23 @@ depends_on = None
def upgrade(): def upgrade():
# ### Create feedback table ### # ### Create feedback table ###
op.create_table( op.create_table(
"feedback", 'feedback',
sa.Column('id', sa.Text(), primary_key=True), # Unique identifier for each feedback (TEXT type)
sa.Column('user_id', sa.Text(), nullable=True), # ID of the user providing the feedback (TEXT type)
sa.Column('version', sa.BigInteger(), default=0), # Version of feedback (BIGINT type)
sa.Column('type', sa.Text(), nullable=True), # Type of feedback (TEXT type)
sa.Column('data', sa.JSON(), nullable=True), # Feedback data (JSON type)
sa.Column('meta', sa.JSON(), nullable=True), # Metadata for feedback (JSON type)
sa.Column('snapshot', sa.JSON(), nullable=True), # snapshot data for feedback (JSON type)
sa.Column( sa.Column(
"id", sa.Text(), primary_key=True 'created_at', sa.BigInteger(), nullable=False
), # Unique identifier for each feedback (TEXT type)
sa.Column(
"user_id", sa.Text(), nullable=True
), # ID of the user providing the feedback (TEXT type)
sa.Column(
"version", sa.BigInteger(), default=0
), # Version of feedback (BIGINT type)
sa.Column("type", sa.Text(), nullable=True), # Type of feedback (TEXT type)
sa.Column("data", sa.JSON(), nullable=True), # Feedback data (JSON type)
sa.Column(
"meta", sa.JSON(), nullable=True
), # Metadata for feedback (JSON type)
sa.Column(
"snapshot", sa.JSON(), nullable=True
), # snapshot data for feedback (JSON type)
sa.Column(
"created_at", sa.BigInteger(), nullable=False
), # Feedback creation timestamp (BIGINT representing epoch) ), # Feedback creation timestamp (BIGINT representing epoch)
sa.Column( sa.Column(
"updated_at", sa.BigInteger(), nullable=False 'updated_at', sa.BigInteger(), nullable=False
), # Feedback update timestamp (BIGINT representing epoch) ), # Feedback update timestamp (BIGINT representing epoch)
) )
def downgrade(): def downgrade():
# ### Drop feedback table ### # ### Drop feedback table ###
op.drop_table("feedback") op.drop_table('feedback')

View File

@@ -17,8 +17,8 @@ import json
import time import time
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = "b10670c03dd5" revision: str = 'b10670c03dd5'
down_revision: Union[str, None] = "2f1211949ecc" down_revision: Union[str, None] = '2f1211949ecc'
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@@ -33,13 +33,11 @@ def _drop_sqlite_indexes_for_column(table_name, column_name, conn):
for idx in indexes: for idx in indexes:
index_name = idx[1] # index name index_name = idx[1] # index name
# Get indexed columns # Get indexed columns
idx_info = conn.execute( idx_info = conn.execute(sa.text(f"PRAGMA index_info('{index_name}')")).fetchall()
sa.text(f"PRAGMA index_info('{index_name}')")
).fetchall()
indexed_cols = [row[2] for row in idx_info] # col names indexed_cols = [row[2] for row in idx_info] # col names
if column_name in indexed_cols: if column_name in indexed_cols:
conn.execute(sa.text(f"DROP INDEX IF EXISTS {index_name}")) conn.execute(sa.text(f'DROP INDEX IF EXISTS {index_name}'))
def _convert_column_to_json(table: str, column: str): def _convert_column_to_json(table: str, column: str):
@@ -47,9 +45,9 @@ def _convert_column_to_json(table: str, column: str):
dialect = conn.dialect.name dialect = conn.dialect.name
# SQLite cannot ALTER COLUMN → must recreate column # SQLite cannot ALTER COLUMN → must recreate column
if dialect == "sqlite": if dialect == 'sqlite':
# 1. Add temporary column # 1. Add temporary column
op.add_column(table, sa.Column(f"{column}_json", sa.JSON(), nullable=True)) op.add_column(table, sa.Column(f'{column}_json', sa.JSON(), nullable=True))
# 2. Load old data # 2. Load old data
rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall() rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()
@@ -66,14 +64,14 @@ def _convert_column_to_json(table: str, column: str):
conn.execute( conn.execute(
sa.text(f'UPDATE "{table}" SET {column}_json = :val WHERE id = :id'), sa.text(f'UPDATE "{table}" SET {column}_json = :val WHERE id = :id'),
{"val": json.dumps(parsed) if parsed else None, "id": uid}, {'val': json.dumps(parsed) if parsed else None, 'id': uid},
) )
# 3. Drop old TEXT column # 3. Drop old TEXT column
op.drop_column(table, column) op.drop_column(table, column)
# 4. Rename new JSON column → original name # 4. Rename new JSON column → original name
op.alter_column(table, f"{column}_json", new_column_name=column) op.alter_column(table, f'{column}_json', new_column_name=column)
else: else:
# PostgreSQL supports direct CAST # PostgreSQL supports direct CAST
@@ -81,7 +79,7 @@ def _convert_column_to_json(table: str, column: str):
table, table,
column, column,
type_=sa.JSON(), type_=sa.JSON(),
postgresql_using=f"{column}::json", postgresql_using=f'{column}::json',
) )
@@ -89,85 +87,77 @@ def _convert_column_to_text(table: str, column: str):
conn = op.get_bind() conn = op.get_bind()
dialect = conn.dialect.name dialect = conn.dialect.name
if dialect == "sqlite": if dialect == 'sqlite':
op.add_column(table, sa.Column(f"{column}_text", sa.Text(), nullable=True)) op.add_column(table, sa.Column(f'{column}_text', sa.Text(), nullable=True))
rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall() rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()
for uid, raw in rows: for uid, raw in rows:
conn.execute( conn.execute(
sa.text(f'UPDATE "{table}" SET {column}_text = :val WHERE id = :id'), sa.text(f'UPDATE "{table}" SET {column}_text = :val WHERE id = :id'),
{"val": json.dumps(raw) if raw else None, "id": uid}, {'val': json.dumps(raw) if raw else None, 'id': uid},
) )
op.drop_column(table, column) op.drop_column(table, column)
op.alter_column(table, f"{column}_text", new_column_name=column) op.alter_column(table, f'{column}_text', new_column_name=column)
else: else:
op.alter_column( op.alter_column(
table, table,
column, column,
type_=sa.Text(), type_=sa.Text(),
postgresql_using=f"to_json({column})::text", postgresql_using=f'to_json({column})::text',
) )
def upgrade() -> None: def upgrade() -> None:
op.add_column( op.add_column('user', sa.Column('profile_banner_image_url', sa.Text(), nullable=True))
"user", sa.Column("profile_banner_image_url", sa.Text(), nullable=True) op.add_column('user', sa.Column('timezone', sa.String(), nullable=True))
)
op.add_column("user", sa.Column("timezone", sa.String(), nullable=True))
op.add_column("user", sa.Column("presence_state", sa.String(), nullable=True)) op.add_column('user', sa.Column('presence_state', sa.String(), nullable=True))
op.add_column("user", sa.Column("status_emoji", sa.String(), nullable=True)) op.add_column('user', sa.Column('status_emoji', sa.String(), nullable=True))
op.add_column("user", sa.Column("status_message", sa.Text(), nullable=True)) op.add_column('user', sa.Column('status_message', sa.Text(), nullable=True))
op.add_column( op.add_column('user', sa.Column('status_expires_at', sa.BigInteger(), nullable=True))
"user", sa.Column("status_expires_at", sa.BigInteger(), nullable=True)
)
op.add_column("user", sa.Column("oauth", sa.JSON(), nullable=True)) op.add_column('user', sa.Column('oauth', sa.JSON(), nullable=True))
# Convert info (TEXT/JSONField) → JSON # Convert info (TEXT/JSONField) → JSON
_convert_column_to_json("user", "info") _convert_column_to_json('user', 'info')
# Convert settings (TEXT/JSONField) → JSON # Convert settings (TEXT/JSONField) → JSON
_convert_column_to_json("user", "settings") _convert_column_to_json('user', 'settings')
op.create_table( op.create_table(
"api_key", 'api_key',
sa.Column("id", sa.Text(), primary_key=True, unique=True), sa.Column('id', sa.Text(), primary_key=True, unique=True),
sa.Column("user_id", sa.Text(), sa.ForeignKey("user.id", ondelete="CASCADE")), sa.Column('user_id', sa.Text(), sa.ForeignKey('user.id', ondelete='CASCADE')),
sa.Column("key", sa.Text(), unique=True, nullable=False), sa.Column('key', sa.Text(), unique=True, nullable=False),
sa.Column("data", sa.JSON(), nullable=True), sa.Column('data', sa.JSON(), nullable=True),
sa.Column("expires_at", sa.BigInteger(), nullable=True), sa.Column('expires_at', sa.BigInteger(), nullable=True),
sa.Column("last_used_at", sa.BigInteger(), nullable=True), sa.Column('last_used_at', sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False), sa.Column('created_at', sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False), sa.Column('updated_at', sa.BigInteger(), nullable=False),
) )
conn = op.get_bind() conn = op.get_bind()
users = conn.execute( users = conn.execute(sa.text('SELECT id, oauth_sub FROM "user" WHERE oauth_sub IS NOT NULL')).fetchall()
sa.text('SELECT id, oauth_sub FROM "user" WHERE oauth_sub IS NOT NULL')
).fetchall()
for uid, oauth_sub in users: for uid, oauth_sub in users:
if oauth_sub: if oauth_sub:
# Example formats supported: # Example formats supported:
# provider@sub # provider@sub
# plain sub (stored as {"oidc": {"sub": sub}}) # plain sub (stored as {"oidc": {"sub": sub}})
if "@" in oauth_sub: if '@' in oauth_sub:
provider, sub = oauth_sub.split("@", 1) provider, sub = oauth_sub.split('@', 1)
else: else:
provider, sub = "oidc", oauth_sub provider, sub = 'oidc', oauth_sub
oauth_json = json.dumps({provider: {"sub": sub}}) oauth_json = json.dumps({provider: {'sub': sub}})
conn.execute( conn.execute(
sa.text('UPDATE "user" SET oauth = :oauth WHERE id = :id'), sa.text('UPDATE "user" SET oauth = :oauth WHERE id = :id'),
{"oauth": oauth_json, "id": uid}, {'oauth': oauth_json, 'id': uid},
) )
users_with_keys = conn.execute( users_with_keys = conn.execute(sa.text('SELECT id, api_key FROM "user" WHERE api_key IS NOT NULL')).fetchall()
sa.text('SELECT id, api_key FROM "user" WHERE api_key IS NOT NULL')
).fetchall()
now = int(time.time()) now = int(time.time())
for uid, api_key in users_with_keys: for uid, api_key in users_with_keys:
@@ -178,72 +168,70 @@ def upgrade() -> None:
VALUES (:id, :user_id, :key, :created_at, :updated_at) VALUES (:id, :user_id, :key, :created_at, :updated_at)
"""), """),
{ {
"id": f"key_{uid}", 'id': f'key_{uid}',
"user_id": uid, 'user_id': uid,
"key": api_key, 'key': api_key,
"created_at": now, 'created_at': now,
"updated_at": now, 'updated_at': now,
}, },
) )
if conn.dialect.name == "sqlite": if conn.dialect.name == 'sqlite':
_drop_sqlite_indexes_for_column("user", "api_key", conn) _drop_sqlite_indexes_for_column('user', 'api_key', conn)
_drop_sqlite_indexes_for_column("user", "oauth_sub", conn) _drop_sqlite_indexes_for_column('user', 'oauth_sub', conn)
with op.batch_alter_table("user") as batch_op: with op.batch_alter_table('user') as batch_op:
batch_op.drop_column("api_key") batch_op.drop_column('api_key')
batch_op.drop_column("oauth_sub") batch_op.drop_column('oauth_sub')
def downgrade() -> None: def downgrade() -> None:
# --- 1. Restore old oauth_sub column --- # --- 1. Restore old oauth_sub column ---
op.add_column("user", sa.Column("oauth_sub", sa.Text(), nullable=True)) op.add_column('user', sa.Column('oauth_sub', sa.Text(), nullable=True))
conn = op.get_bind() conn = op.get_bind()
users = conn.execute( users = conn.execute(sa.text('SELECT id, oauth FROM "user" WHERE oauth IS NOT NULL')).fetchall()
sa.text('SELECT id, oauth FROM "user" WHERE oauth IS NOT NULL')
).fetchall()
for uid, oauth in users: for uid, oauth in users:
try: try:
data = json.loads(oauth) data = json.loads(oauth)
provider = list(data.keys())[0] provider = list(data.keys())[0]
sub = data[provider].get("sub") sub = data[provider].get('sub')
oauth_sub = f"{provider}@{sub}" oauth_sub = f'{provider}@{sub}'
except Exception: except Exception:
oauth_sub = None oauth_sub = None
conn.execute( conn.execute(
sa.text('UPDATE "user" SET oauth_sub = :oauth_sub WHERE id = :id'), sa.text('UPDATE "user" SET oauth_sub = :oauth_sub WHERE id = :id'),
{"oauth_sub": oauth_sub, "id": uid}, {'oauth_sub': oauth_sub, 'id': uid},
) )
op.drop_column("user", "oauth") op.drop_column('user', 'oauth')
# --- 2. Restore api_key field --- # --- 2. Restore api_key field ---
op.add_column("user", sa.Column("api_key", sa.String(), nullable=True)) op.add_column('user', sa.Column('api_key', sa.String(), nullable=True))
# Restore values from api_key # Restore values from api_key
keys = conn.execute(sa.text("SELECT user_id, key FROM api_key")).fetchall() keys = conn.execute(sa.text('SELECT user_id, key FROM api_key')).fetchall()
for uid, key in keys: for uid, key in keys:
conn.execute( conn.execute(
sa.text('UPDATE "user" SET api_key = :key WHERE id = :id'), sa.text('UPDATE "user" SET api_key = :key WHERE id = :id'),
{"key": key, "id": uid}, {'key': key, 'id': uid},
) )
# Drop new table # Drop new table
op.drop_table("api_key") op.drop_table('api_key')
with op.batch_alter_table("user") as batch_op: with op.batch_alter_table('user') as batch_op:
batch_op.drop_column("profile_banner_image_url") batch_op.drop_column('profile_banner_image_url')
batch_op.drop_column("timezone") batch_op.drop_column('timezone')
batch_op.drop_column("presence_state") batch_op.drop_column('presence_state')
batch_op.drop_column("status_emoji") batch_op.drop_column('status_emoji')
batch_op.drop_column("status_message") batch_op.drop_column('status_message')
batch_op.drop_column("status_expires_at") batch_op.drop_column('status_expires_at')
# Convert info (JSON) → TEXT # Convert info (JSON) → TEXT
_convert_column_to_text("user", "info") _convert_column_to_text('user', 'info')
# Convert settings (JSON) → TEXT # Convert settings (JSON) → TEXT
_convert_column_to_text("user", "settings") _convert_column_to_text('user', 'settings')

View File

@@ -12,15 +12,15 @@ from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = "b2c3d4e5f6a7" revision: str = 'b2c3d4e5f6a7'
down_revision: Union[str, None] = "a1b2c3d4e5f6" down_revision: Union[str, None] = 'a1b2c3d4e5f6'
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
op.add_column("user", sa.Column("scim", sa.JSON(), nullable=True)) op.add_column('user', sa.Column('scim', sa.JSON(), nullable=True))
def downgrade() -> None: def downgrade() -> None:
op.drop_column("user", "scim") op.drop_column('user', 'scim')

View File

@@ -12,21 +12,21 @@ import sqlalchemy as sa
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = "c0fbf31ca0db" revision: str = 'c0fbf31ca0db'
down_revision: Union[str, None] = "ca81bd47c050" down_revision: Union[str, None] = 'ca81bd47c050'
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
def upgrade(): def upgrade():
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.add_column("file", sa.Column("hash", sa.Text(), nullable=True)) op.add_column('file', sa.Column('hash', sa.Text(), nullable=True))
op.add_column("file", sa.Column("data", sa.JSON(), nullable=True)) op.add_column('file', sa.Column('data', sa.JSON(), nullable=True))
op.add_column("file", sa.Column("updated_at", sa.BigInteger(), nullable=True)) op.add_column('file', sa.Column('updated_at', sa.BigInteger(), nullable=True))
def downgrade(): def downgrade():
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_column("file", "updated_at") op.drop_column('file', 'updated_at')
op.drop_column("file", "data") op.drop_column('file', 'data')
op.drop_column("file", "hash") op.drop_column('file', 'hash')

View File

@@ -12,35 +12,33 @@ import json
from sqlalchemy.sql import table, column from sqlalchemy.sql import table, column
from sqlalchemy import String, Text, JSON, and_ from sqlalchemy import String, Text, JSON, and_
revision = "c29facfe716b" revision = 'c29facfe716b'
down_revision = "c69f45358db4" down_revision = 'c69f45358db4'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
# 1. Add the `path` column to the "file" table. # 1. Add the `path` column to the "file" table.
op.add_column("file", sa.Column("path", sa.Text(), nullable=True)) op.add_column('file', sa.Column('path', sa.Text(), nullable=True))
# 2. Convert the `meta` column from Text/JSONField to `JSON()` # 2. Convert the `meta` column from Text/JSONField to `JSON()`
# Use Alembic's default batch_op for dialect compatibility. # Use Alembic's default batch_op for dialect compatibility.
with op.batch_alter_table("file", schema=None) as batch_op: with op.batch_alter_table('file', schema=None) as batch_op:
batch_op.alter_column( batch_op.alter_column(
"meta", 'meta',
type_=sa.JSON(), type_=sa.JSON(),
existing_type=sa.Text(), existing_type=sa.Text(),
existing_nullable=True, existing_nullable=True,
nullable=True, nullable=True,
postgresql_using="meta::json", postgresql_using='meta::json',
) )
# 3. Migrate legacy data from `meta` JSONField # 3. Migrate legacy data from `meta` JSONField
# Fetch and process `meta` data from the table, add values to the new `path` column as necessary. # Fetch and process `meta` data from the table, add values to the new `path` column as necessary.
# We will use SQLAlchemy core bindings to ensure safety across different databases. # We will use SQLAlchemy core bindings to ensure safety across different databases.
file_table = table( file_table = table('file', column('id', String), column('meta', JSON), column('path', Text))
"file", column("id", String), column("meta", JSON), column("path", Text)
)
# Create connection to the database # Create connection to the database
connection = op.get_bind() connection = op.get_bind()
@@ -55,24 +53,18 @@ def upgrade():
# Iterate over each row to extract and update the `path` from `meta` column # Iterate over each row to extract and update the `path` from `meta` column
for row in results: for row in results:
if "path" in row.meta: if 'path' in row.meta:
# Extract the `path` field from the `meta` JSON # Extract the `path` field from the `meta` JSON
path = row.meta.get("path") path = row.meta.get('path')
# Update the `file` table with the new `path` value # Update the `file` table with the new `path` value
connection.execute( connection.execute(file_table.update().where(file_table.c.id == row.id).values({'path': path}))
file_table.update()
.where(file_table.c.id == row.id)
.values({"path": path})
)
def downgrade(): def downgrade():
# 1. Remove the `path` column # 1. Remove the `path` column
op.drop_column("file", "path") op.drop_column('file', 'path')
# 2. Revert the `meta` column back to Text/JSONField # 2. Revert the `meta` column back to Text/JSONField
with op.batch_alter_table("file", schema=None) as batch_op: with op.batch_alter_table('file', schema=None) as batch_op:
batch_op.alter_column( batch_op.alter_column('meta', type_=sa.Text(), existing_type=sa.JSON(), existing_nullable=True)
"meta", type_=sa.Text(), existing_type=sa.JSON(), existing_nullable=True
)

View File

@@ -12,45 +12,43 @@ from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = "c440947495f3" revision: str = 'c440947495f3'
down_revision: Union[str, None] = "81cc2ce44d79" down_revision: Union[str, None] = '81cc2ce44d79'
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
op.create_table( op.create_table(
"chat_file", 'chat_file',
sa.Column("id", sa.Text(), primary_key=True), sa.Column('id', sa.Text(), primary_key=True),
sa.Column("user_id", sa.Text(), nullable=False), sa.Column('user_id', sa.Text(), nullable=False),
sa.Column( sa.Column(
"chat_id", 'chat_id',
sa.Text(), sa.Text(),
sa.ForeignKey("chat.id", ondelete="CASCADE"), sa.ForeignKey('chat.id', ondelete='CASCADE'),
nullable=False, nullable=False,
), ),
sa.Column( sa.Column(
"file_id", 'file_id',
sa.Text(), sa.Text(),
sa.ForeignKey("file.id", ondelete="CASCADE"), sa.ForeignKey('file.id', ondelete='CASCADE'),
nullable=False, nullable=False,
), ),
sa.Column("message_id", sa.Text(), nullable=True), sa.Column('message_id', sa.Text(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=False), sa.Column('created_at', sa.BigInteger(), nullable=False),
sa.Column("updated_at", sa.BigInteger(), nullable=False), sa.Column('updated_at', sa.BigInteger(), nullable=False),
# indexes # indexes
sa.Index("ix_chat_file_chat_id", "chat_id"), sa.Index('ix_chat_file_chat_id', 'chat_id'),
sa.Index("ix_chat_file_file_id", "file_id"), sa.Index('ix_chat_file_file_id', 'file_id'),
sa.Index("ix_chat_file_message_id", "message_id"), sa.Index('ix_chat_file_message_id', 'message_id'),
sa.Index("ix_chat_file_user_id", "user_id"), sa.Index('ix_chat_file_user_id', 'user_id'),
# unique constraints # unique constraints
sa.UniqueConstraint( sa.UniqueConstraint('chat_id', 'file_id', name='uq_chat_file_chat_file'), # prevent duplicate entries
"chat_id", "file_id", name="uq_chat_file_chat_file"
), # prevent duplicate entries
) )
pass pass
def downgrade() -> None: def downgrade() -> None:
op.drop_table("chat_file") op.drop_table('chat_file')
pass pass

View File

@@ -9,42 +9,40 @@ Create Date: 2024-10-16 02:02:35.241684
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
revision = "c69f45358db4" revision = 'c69f45358db4'
down_revision = "3ab32c4b8f59" down_revision = '3ab32c4b8f59'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.create_table( op.create_table(
"folder", 'folder',
sa.Column("id", sa.Text(), nullable=False), sa.Column('id', sa.Text(), nullable=False),
sa.Column("parent_id", sa.Text(), nullable=True), sa.Column('parent_id', sa.Text(), nullable=True),
sa.Column("user_id", sa.Text(), nullable=False), sa.Column('user_id', sa.Text(), nullable=False),
sa.Column("name", sa.Text(), nullable=False), sa.Column('name', sa.Text(), nullable=False),
sa.Column("items", sa.JSON(), nullable=True), sa.Column('items', sa.JSON(), nullable=True),
sa.Column("meta", sa.JSON(), nullable=True), sa.Column('meta', sa.JSON(), nullable=True),
sa.Column("is_expanded", sa.Boolean(), default=False, nullable=False), sa.Column('is_expanded', sa.Boolean(), default=False, nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.func.now(), nullable=False),
sa.Column( sa.Column(
"created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False 'updated_at',
),
sa.Column(
"updated_at",
sa.DateTime(), sa.DateTime(),
nullable=False, nullable=False,
server_default=sa.func.now(), server_default=sa.func.now(),
onupdate=sa.func.now(), onupdate=sa.func.now(),
), ),
sa.PrimaryKeyConstraint("id", "user_id"), sa.PrimaryKeyConstraint('id', 'user_id'),
) )
op.add_column( op.add_column(
"chat", 'chat',
sa.Column("folder_id", sa.Text(), nullable=True), sa.Column('folder_id', sa.Text(), nullable=True),
) )
def downgrade(): def downgrade():
op.drop_column("chat", "folder_id") op.drop_column('chat', 'folder_id')
op.drop_table("folder") op.drop_table('folder')

View File

@@ -12,23 +12,21 @@ import sqlalchemy as sa
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision: str = "ca81bd47c050" revision: str = 'ca81bd47c050'
down_revision: Union[str, None] = "7e5b5dc7342b" down_revision: Union[str, None] = '7e5b5dc7342b'
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
def upgrade(): def upgrade():
op.create_table( op.create_table(
"config", 'config',
sa.Column("id", sa.Integer, primary_key=True), sa.Column('id', sa.Integer, primary_key=True),
sa.Column("data", sa.JSON(), nullable=False), sa.Column('data', sa.JSON(), nullable=False),
sa.Column("version", sa.Integer, nullable=False), sa.Column('version', sa.Integer, nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.func.now()),
sa.Column( sa.Column(
"created_at", sa.DateTime(), nullable=False, server_default=sa.func.now() 'updated_at',
),
sa.Column(
"updated_at",
sa.DateTime(), sa.DateTime(),
nullable=True, nullable=True,
server_default=sa.func.now(), server_default=sa.func.now(),
@@ -38,4 +36,4 @@ def upgrade():
def downgrade(): def downgrade():
op.drop_table("config") op.drop_table('config')

View File

@@ -9,15 +9,15 @@ Create Date: 2025-07-13 03:00:00.000000
from alembic import op from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
revision = "d31026856c01" revision = 'd31026856c01'
down_revision = "9f0c9cd09105" down_revision = '9f0c9cd09105'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.add_column("folder", sa.Column("data", sa.JSON(), nullable=True)) op.add_column('folder', sa.Column('data', sa.JSON(), nullable=True))
def downgrade(): def downgrade():
op.drop_column("folder", "data") op.drop_column('folder', 'data')

View File

@@ -20,8 +20,8 @@ import sqlalchemy as sa
from open_webui.migrations.util import get_existing_tables from open_webui.migrations.util import get_existing_tables
revision: str = "f1e2d3c4b5a6" revision: str = 'f1e2d3c4b5a6'
down_revision: Union[str, None] = "8452d01d26d7" down_revision: Union[str, None] = '8452d01d26d7'
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None
@@ -30,34 +30,34 @@ def upgrade() -> None:
existing_tables = set(get_existing_tables()) existing_tables = set(get_existing_tables())
# Create access_grant table # Create access_grant table
if "access_grant" not in existing_tables: if 'access_grant' not in existing_tables:
op.create_table( op.create_table(
"access_grant", 'access_grant',
sa.Column("id", sa.Text(), nullable=False, primary_key=True), sa.Column('id', sa.Text(), nullable=False, primary_key=True),
sa.Column("resource_type", sa.Text(), nullable=False), sa.Column('resource_type', sa.Text(), nullable=False),
sa.Column("resource_id", sa.Text(), nullable=False), sa.Column('resource_id', sa.Text(), nullable=False),
sa.Column("principal_type", sa.Text(), nullable=False), sa.Column('principal_type', sa.Text(), nullable=False),
sa.Column("principal_id", sa.Text(), nullable=False), sa.Column('principal_id', sa.Text(), nullable=False),
sa.Column("permission", sa.Text(), nullable=False), sa.Column('permission', sa.Text(), nullable=False),
sa.Column("created_at", sa.BigInteger(), nullable=False), sa.Column('created_at', sa.BigInteger(), nullable=False),
sa.UniqueConstraint( sa.UniqueConstraint(
"resource_type", 'resource_type',
"resource_id", 'resource_id',
"principal_type", 'principal_type',
"principal_id", 'principal_id',
"permission", 'permission',
name="uq_access_grant_grant", name='uq_access_grant_grant',
), ),
) )
op.create_index( op.create_index(
"idx_access_grant_resource", 'idx_access_grant_resource',
"access_grant", 'access_grant',
["resource_type", "resource_id"], ['resource_type', 'resource_id'],
) )
op.create_index( op.create_index(
"idx_access_grant_principal", 'idx_access_grant_principal',
"access_grant", 'access_grant',
["principal_type", "principal_id"], ['principal_type', 'principal_id'],
) )
# Backfill existing access_control JSON data # Backfill existing access_control JSON data
@@ -65,13 +65,13 @@ def upgrade() -> None:
# Tables with access_control JSON columns: (table_name, resource_type) # Tables with access_control JSON columns: (table_name, resource_type)
resource_tables = [ resource_tables = [
("knowledge", "knowledge"), ('knowledge', 'knowledge'),
("prompt", "prompt"), ('prompt', 'prompt'),
("tool", "tool"), ('tool', 'tool'),
("model", "model"), ('model', 'model'),
("note", "note"), ('note', 'note'),
("channel", "channel"), ('channel', 'channel'),
("file", "file"), ('file', 'file'),
] ]
now = int(time.time()) now = int(time.time())
@@ -83,9 +83,7 @@ def upgrade() -> None:
# Query all rows # Query all rows
try: try:
result = conn.execute( result = conn.execute(sa.text(f'SELECT id, access_control FROM "{table_name}"'))
sa.text(f'SELECT id, access_control FROM "{table_name}"')
)
rows = result.fetchall() rows = result.fetchall()
except Exception: except Exception:
continue continue
@@ -99,19 +97,16 @@ def upgrade() -> None:
# EXCEPTION: files with NULL are PRIVATE (owner-only), not public # EXCEPTION: files with NULL are PRIVATE (owner-only), not public
is_null = ( is_null = (
access_control_json is None access_control_json is None
or access_control_json == "null" or access_control_json == 'null'
or ( or (isinstance(access_control_json, str) and access_control_json.strip().lower() == 'null')
isinstance(access_control_json, str)
and access_control_json.strip().lower() == "null"
)
) )
if is_null: if is_null:
# Files: NULL = private (no entry needed, owner has implicit access) # Files: NULL = private (no entry needed, owner has implicit access)
# Other resources: NULL = public (insert user:* for read) # Other resources: NULL = public (insert user:* for read)
if resource_type == "file": if resource_type == 'file':
continue # Private - no entry needed continue # Private - no entry needed
key = (resource_type, resource_id, "user", "*", "read") key = (resource_type, resource_id, 'user', '*', 'read')
if key not in inserted: if key not in inserted:
try: try:
conn.execute( conn.execute(
@@ -120,13 +115,13 @@ def upgrade() -> None:
VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at) VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at)
"""), """),
{ {
"id": str(uuid.uuid4()), 'id': str(uuid.uuid4()),
"resource_type": resource_type, 'resource_type': resource_type,
"resource_id": resource_id, 'resource_id': resource_id,
"principal_type": "user", 'principal_type': 'user',
"principal_id": "*", 'principal_id': '*',
"permission": "read", 'permission': 'read',
"created_at": now, 'created_at': now,
}, },
) )
inserted.add(key) inserted.add(key)
@@ -149,28 +144,24 @@ def upgrade() -> None:
continue continue
# Check if it's effectively empty (no read/write keys with content) # Check if it's effectively empty (no read/write keys with content)
read_data = access_control_json.get("read", {}) read_data = access_control_json.get('read', {})
write_data = access_control_json.get("write", {}) write_data = access_control_json.get('write', {})
has_read_grants = read_data.get("group_ids", []) or read_data.get( has_read_grants = read_data.get('group_ids', []) or read_data.get('user_ids', [])
"user_ids", [] has_write_grants = write_data.get('group_ids', []) or write_data.get('user_ids', [])
)
has_write_grants = write_data.get("group_ids", []) or write_data.get(
"user_ids", []
)
if not has_read_grants and not has_write_grants: if not has_read_grants and not has_write_grants:
# Empty permissions = private, no grants needed # Empty permissions = private, no grants needed
continue continue
# Extract permissions and insert into access_grant table # Extract permissions and insert into access_grant table
for permission in ["read", "write"]: for permission in ['read', 'write']:
perm_data = access_control_json.get(permission, {}) perm_data = access_control_json.get(permission, {})
if not perm_data: if not perm_data:
continue continue
for group_id in perm_data.get("group_ids", []): for group_id in perm_data.get('group_ids', []):
key = (resource_type, resource_id, "group", group_id, permission) key = (resource_type, resource_id, 'group', group_id, permission)
if key in inserted: if key in inserted:
continue continue
try: try:
@@ -180,21 +171,21 @@ def upgrade() -> None:
VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at) VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at)
"""), """),
{ {
"id": str(uuid.uuid4()), 'id': str(uuid.uuid4()),
"resource_type": resource_type, 'resource_type': resource_type,
"resource_id": resource_id, 'resource_id': resource_id,
"principal_type": "group", 'principal_type': 'group',
"principal_id": group_id, 'principal_id': group_id,
"permission": permission, 'permission': permission,
"created_at": now, 'created_at': now,
}, },
) )
inserted.add(key) inserted.add(key)
except Exception: except Exception:
pass pass
for user_id in perm_data.get("user_ids", []): for user_id in perm_data.get('user_ids', []):
key = (resource_type, resource_id, "user", user_id, permission) key = (resource_type, resource_id, 'user', user_id, permission)
if key in inserted: if key in inserted:
continue continue
try: try:
@@ -204,13 +195,13 @@ def upgrade() -> None:
VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at) VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at)
"""), """),
{ {
"id": str(uuid.uuid4()), 'id': str(uuid.uuid4()),
"resource_type": resource_type, 'resource_type': resource_type,
"resource_id": resource_id, 'resource_id': resource_id,
"principal_type": "user", 'principal_type': 'user',
"principal_id": user_id, 'principal_id': user_id,
"permission": permission, 'permission': permission,
"created_at": now, 'created_at': now,
}, },
) )
inserted.add(key) inserted.add(key)
@@ -223,7 +214,7 @@ def upgrade() -> None:
continue continue
try: try:
with op.batch_alter_table(table_name) as batch: with op.batch_alter_table(table_name) as batch:
batch.drop_column("access_control") batch.drop_column('access_control')
except Exception: except Exception:
pass pass
@@ -235,20 +226,20 @@ def downgrade() -> None:
# Resource tables mapping: (table_name, resource_type) # Resource tables mapping: (table_name, resource_type)
resource_tables = [ resource_tables = [
("knowledge", "knowledge"), ('knowledge', 'knowledge'),
("prompt", "prompt"), ('prompt', 'prompt'),
("tool", "tool"), ('tool', 'tool'),
("model", "model"), ('model', 'model'),
("note", "note"), ('note', 'note'),
("channel", "channel"), ('channel', 'channel'),
("file", "file"), ('file', 'file'),
] ]
# Step 1: Re-add access_control columns to resource tables # Step 1: Re-add access_control columns to resource tables
for table_name, _ in resource_tables: for table_name, _ in resource_tables:
try: try:
with op.batch_alter_table(table_name) as batch: with op.batch_alter_table(table_name) as batch:
batch.add_column(sa.Column("access_control", sa.JSON(), nullable=True)) batch.add_column(sa.Column('access_control', sa.JSON(), nullable=True))
except Exception: except Exception:
pass pass
@@ -262,7 +253,7 @@ def downgrade() -> None:
FROM access_grant FROM access_grant
WHERE resource_type = :resource_type WHERE resource_type = :resource_type
"""), """),
{"resource_type": resource_type}, {'resource_type': resource_type},
) )
rows = result.fetchall() rows = result.fetchall()
except Exception: except Exception:
@@ -278,49 +269,35 @@ def downgrade() -> None:
if resource_id not in resource_grants: if resource_id not in resource_grants:
resource_grants[resource_id] = { resource_grants[resource_id] = {
"is_public": False, 'is_public': False,
"read": {"group_ids": [], "user_ids": []}, 'read': {'group_ids': [], 'user_ids': []},
"write": {"group_ids": [], "user_ids": []}, 'write': {'group_ids': [], 'user_ids': []},
} }
# Handle public access (user:* for read) # Handle public access (user:* for read)
if ( if principal_type == 'user' and principal_id == '*' and permission == 'read':
principal_type == "user" resource_grants[resource_id]['is_public'] = True
and principal_id == "*"
and permission == "read"
):
resource_grants[resource_id]["is_public"] = True
continue continue
# Add to appropriate list # Add to appropriate list
if permission in ["read", "write"]: if permission in ['read', 'write']:
if principal_type == "group": if principal_type == 'group':
if ( if principal_id not in resource_grants[resource_id][permission]['group_ids']:
principal_id resource_grants[resource_id][permission]['group_ids'].append(principal_id)
not in resource_grants[resource_id][permission]["group_ids"] elif principal_type == 'user':
): if principal_id not in resource_grants[resource_id][permission]['user_ids']:
resource_grants[resource_id][permission]["group_ids"].append( resource_grants[resource_id][permission]['user_ids'].append(principal_id)
principal_id
)
elif principal_type == "user":
if (
principal_id
not in resource_grants[resource_id][permission]["user_ids"]
):
resource_grants[resource_id][permission]["user_ids"].append(
principal_id
)
# Step 3: Update each resource with reconstructed JSON # Step 3: Update each resource with reconstructed JSON
for resource_id, grants in resource_grants.items(): for resource_id, grants in resource_grants.items():
if grants["is_public"]: if grants['is_public']:
# Public = NULL # Public = NULL
access_control_value = None access_control_value = None
elif ( elif (
not grants["read"]["group_ids"] not grants['read']['group_ids']
and not grants["read"]["user_ids"] and not grants['read']['user_ids']
and not grants["write"]["group_ids"] and not grants['write']['group_ids']
and not grants["write"]["user_ids"] and not grants['write']['user_ids']
): ):
# No grants = should not happen (would mean no entries), default to {} # No grants = should not happen (would mean no entries), default to {}
access_control_value = json.dumps({}) access_control_value = json.dumps({})
@@ -328,17 +305,15 @@ def downgrade() -> None:
# Custom permissions # Custom permissions
access_control_value = json.dumps( access_control_value = json.dumps(
{ {
"read": grants["read"], 'read': grants['read'],
"write": grants["write"], 'write': grants['write'],
} }
) )
try: try:
conn.execute( conn.execute(
sa.text( sa.text(f'UPDATE "{table_name}" SET access_control = :access_control WHERE id = :id'),
f'UPDATE "{table_name}" SET access_control = :access_control WHERE id = :id' {'access_control': access_control_value, 'id': resource_id},
),
{"access_control": access_control_value, "id": resource_id},
) )
except Exception: except Exception:
pass pass
@@ -346,7 +321,7 @@ def downgrade() -> None:
# Step 4: Set all resources WITHOUT entries to private # Step 4: Set all resources WITHOUT entries to private
# For files: NULL means private (owner-only), so leave as NULL # For files: NULL means private (owner-only), so leave as NULL
# For other resources: {} means private, so update to {} # For other resources: {} means private, so update to {}
if resource_type != "file": if resource_type != 'file':
try: try:
conn.execute( conn.execute(
sa.text(f""" sa.text(f"""
@@ -357,13 +332,13 @@ def downgrade() -> None:
) )
AND access_control IS NULL AND access_control IS NULL
"""), """),
{"private_value": json.dumps({}), "resource_type": resource_type}, {'private_value': json.dumps({}), 'resource_type': resource_type},
) )
except Exception: except Exception:
pass pass
# For files, NULL stays NULL - no action needed # For files, NULL stays NULL - no action needed
# Step 5: Drop the access_grant table # Step 5: Drop the access_grant table
op.drop_index("idx_access_grant_principal", table_name="access_grant") op.drop_index('idx_access_grant_principal', table_name='access_grant')
op.drop_index("idx_access_grant_resource", table_name="access_grant") op.drop_index('idx_access_grant_resource', table_name='access_grant')
op.drop_table("access_grant") op.drop_table('access_grant')

View File

@@ -19,28 +19,24 @@ log = logging.getLogger(__name__)
class AccessGrant(Base): class AccessGrant(Base):
__tablename__ = "access_grant" __tablename__ = 'access_grant'
id = Column(Text, primary_key=True) id = Column(Text, primary_key=True)
resource_type = Column( resource_type = Column(Text, nullable=False) # "knowledge", "model", "prompt", "tool", "note", "channel", "file"
Text, nullable=False
) # "knowledge", "model", "prompt", "tool", "note", "channel", "file"
resource_id = Column(Text, nullable=False) resource_id = Column(Text, nullable=False)
principal_type = Column(Text, nullable=False) # "user" or "group" principal_type = Column(Text, nullable=False) # "user" or "group"
principal_id = Column( principal_id = Column(Text, nullable=False) # user_id, group_id, or "*" (wildcard for public)
Text, nullable=False
) # user_id, group_id, or "*" (wildcard for public)
permission = Column(Text, nullable=False) # "read" or "write" permission = Column(Text, nullable=False) # "read" or "write"
created_at = Column(BigInteger, nullable=False) created_at = Column(BigInteger, nullable=False)
__table_args__ = ( __table_args__ = (
UniqueConstraint( UniqueConstraint(
"resource_type", 'resource_type',
"resource_id", 'resource_id',
"principal_type", 'principal_type',
"principal_id", 'principal_id',
"permission", 'permission',
name="uq_access_grant_grant", name='uq_access_grant_grant',
), ),
) )
@@ -66,7 +62,7 @@ class AccessGrantResponse(BaseModel):
permission: str permission: str
@classmethod @classmethod
def from_grant(cls, grant: "AccessGrantModel") -> "AccessGrantResponse": def from_grant(cls, grant: 'AccessGrantModel') -> 'AccessGrantResponse':
return cls( return cls(
id=grant.id, id=grant.id,
principal_type=grant.principal_type, principal_type=grant.principal_type,
@@ -100,14 +96,14 @@ def access_control_to_grants(
if access_control is None: if access_control is None:
# NULL → public read (user:* for read) # NULL → public read (user:* for read)
# Exception: files with NULL are private (owner-only), no grants needed # Exception: files with NULL are private (owner-only), no grants needed
if resource_type != "file": if resource_type != 'file':
grants.append( grants.append(
{ {
"resource_type": resource_type, 'resource_type': resource_type,
"resource_id": resource_id, 'resource_id': resource_id,
"principal_type": "user", 'principal_type': 'user',
"principal_id": "*", 'principal_id': '*',
"permission": "read", 'permission': 'read',
} }
) )
return grants return grants
@@ -117,30 +113,30 @@ def access_control_to_grants(
return grants return grants
# Parse structured permissions # Parse structured permissions
for permission in ["read", "write"]: for permission in ['read', 'write']:
perm_data = access_control.get(permission, {}) perm_data = access_control.get(permission, {})
if not perm_data: if not perm_data:
continue continue
for group_id in perm_data.get("group_ids", []): for group_id in perm_data.get('group_ids', []):
grants.append( grants.append(
{ {
"resource_type": resource_type, 'resource_type': resource_type,
"resource_id": resource_id, 'resource_id': resource_id,
"principal_type": "group", 'principal_type': 'group',
"principal_id": group_id, 'principal_id': group_id,
"permission": permission, 'permission': permission,
} }
) )
for user_id in perm_data.get("user_ids", []): for user_id in perm_data.get('user_ids', []):
grants.append( grants.append(
{ {
"resource_type": resource_type, 'resource_type': resource_type,
"resource_id": resource_id, 'resource_id': resource_id,
"principal_type": "user", 'principal_type': 'user',
"principal_id": user_id, 'principal_id': user_id,
"permission": permission, 'permission': permission,
} }
) )
@@ -164,27 +160,23 @@ def normalize_access_grants(access_grants: Optional[list]) -> list[dict]:
if not isinstance(grant, dict): if not isinstance(grant, dict):
continue continue
principal_type = grant.get("principal_type") principal_type = grant.get('principal_type')
principal_id = grant.get("principal_id") principal_id = grant.get('principal_id')
permission = grant.get("permission") permission = grant.get('permission')
if principal_type not in ("user", "group"): if principal_type not in ('user', 'group'):
continue continue
if permission not in ("read", "write"): if permission not in ('read', 'write'):
continue continue
if not isinstance(principal_id, str) or not principal_id: if not isinstance(principal_id, str) or not principal_id:
continue continue
key = (principal_type, principal_id, permission) key = (principal_type, principal_id, permission)
deduped[key] = { deduped[key] = {
"id": ( 'id': (grant.get('id') if isinstance(grant.get('id'), str) and grant.get('id') else str(uuid.uuid4())),
grant.get("id") 'principal_type': principal_type,
if isinstance(grant.get("id"), str) and grant.get("id") 'principal_id': principal_id,
else str(uuid.uuid4()) 'permission': permission,
),
"principal_type": principal_type,
"principal_id": principal_id,
"permission": permission,
} }
return list(deduped.values()) return list(deduped.values())
@@ -195,11 +187,7 @@ def has_public_read_access_grant(access_grants: Optional[list]) -> bool:
Returns True when a direct grant list includes wildcard public-read. Returns True when a direct grant list includes wildcard public-read.
""" """
for grant in normalize_access_grants(access_grants): for grant in normalize_access_grants(access_grants):
if ( if grant['principal_type'] == 'user' and grant['principal_id'] == '*' and grant['permission'] == 'read':
grant["principal_type"] == "user"
and grant["principal_id"] == "*"
and grant["permission"] == "read"
):
return True return True
return False return False
@@ -209,7 +197,7 @@ def has_user_access_grant(access_grants: Optional[list]) -> bool:
Returns True when a direct grant list includes any non-wildcard user grant. Returns True when a direct grant list includes any non-wildcard user grant.
""" """
for grant in normalize_access_grants(access_grants): for grant in normalize_access_grants(access_grants):
if grant["principal_type"] == "user" and grant["principal_id"] != "*": if grant['principal_type'] == 'user' and grant['principal_id'] != '*':
return True return True
return False return False
@@ -225,18 +213,9 @@ def strip_user_access_grants(access_grants: Optional[list]) -> list:
grant grant
for grant in access_grants for grant in access_grants
if not ( if not (
( (grant.get('principal_type') if isinstance(grant, dict) else getattr(grant, 'principal_type', None))
grant.get("principal_type") == 'user'
if isinstance(grant, dict) and (grant.get('principal_id') if isinstance(grant, dict) else getattr(grant, 'principal_id', None)) != '*'
else getattr(grant, "principal_type", None)
)
== "user"
and (
grant.get("principal_id")
if isinstance(grant, dict)
else getattr(grant, "principal_id", None)
)
!= "*"
) )
] ]
@@ -260,29 +239,25 @@ def grants_to_access_control(grants: list) -> Optional[dict]:
return {} # No grants = private/owner-only return {} # No grants = private/owner-only
result = { result = {
"read": {"group_ids": [], "user_ids": []}, 'read': {'group_ids': [], 'user_ids': []},
"write": {"group_ids": [], "user_ids": []}, 'write': {'group_ids': [], 'user_ids': []},
} }
is_public = False is_public = False
for grant in grants: for grant in grants:
if ( if grant.principal_type == 'user' and grant.principal_id == '*' and grant.permission == 'read':
grant.principal_type == "user"
and grant.principal_id == "*"
and grant.permission == "read"
):
is_public = True is_public = True
continue # Don't add wildcard to user_ids list continue # Don't add wildcard to user_ids list
if grant.permission not in ("read", "write"): if grant.permission not in ('read', 'write'):
continue continue
if grant.principal_type == "group": if grant.principal_type == 'group':
if grant.principal_id not in result[grant.permission]["group_ids"]: if grant.principal_id not in result[grant.permission]['group_ids']:
result[grant.permission]["group_ids"].append(grant.principal_id) result[grant.permission]['group_ids'].append(grant.principal_id)
elif grant.principal_type == "user": elif grant.principal_type == 'user':
if grant.principal_id not in result[grant.permission]["user_ids"]: if grant.principal_id not in result[grant.permission]['user_ids']:
result[grant.permission]["user_ids"].append(grant.principal_id) result[grant.permission]['user_ids'].append(grant.principal_id)
if is_public: if is_public:
return None # Public read access return None # Public read access
@@ -399,9 +374,7 @@ class AccessGrantsTable:
).delete() ).delete()
# Convert JSON to grant dicts # Convert JSON to grant dicts
grant_dicts = access_control_to_grants( grant_dicts = access_control_to_grants(resource_type, resource_id, access_control)
resource_type, resource_id, access_control
)
# Insert new grants # Insert new grants
results = [] results = []
@@ -442,9 +415,9 @@ class AccessGrantsTable:
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
resource_type=resource_type, resource_type=resource_type,
resource_id=resource_id, resource_id=resource_id,
principal_type=grant_dict["principal_type"], principal_type=grant_dict['principal_type'],
principal_id=grant_dict["principal_id"], principal_id=grant_dict['principal_id'],
permission=grant_dict["permission"], permission=grant_dict['permission'],
created_at=int(time.time()), created_at=int(time.time()),
) )
db.add(grant) db.add(grant)
@@ -511,9 +484,7 @@ class AccessGrantsTable:
) )
.all() .all()
) )
result: dict[str, list[AccessGrantModel]] = { result: dict[str, list[AccessGrantModel]] = {rid: [] for rid in resource_ids}
rid: [] for rid in resource_ids
}
for g in grants: for g in grants:
result[g.resource_id].append(AccessGrantModel.model_validate(g)) result[g.resource_id].append(AccessGrantModel.model_validate(g))
return result return result
@@ -523,7 +494,7 @@ class AccessGrantsTable:
user_id: str, user_id: str,
resource_type: str, resource_type: str,
resource_id: str, resource_id: str,
permission: str = "read", permission: str = 'read',
user_group_ids: Optional[set[str]] = None, user_group_ids: Optional[set[str]] = None,
db: Optional[Session] = None, db: Optional[Session] = None,
) -> bool: ) -> bool:
@@ -540,12 +511,12 @@ class AccessGrantsTable:
conditions = [ conditions = [
# Public access # Public access
and_( and_(
AccessGrant.principal_type == "user", AccessGrant.principal_type == 'user',
AccessGrant.principal_id == "*", AccessGrant.principal_id == '*',
), ),
# Direct user access # Direct user access
and_( and_(
AccessGrant.principal_type == "user", AccessGrant.principal_type == 'user',
AccessGrant.principal_id == user_id, AccessGrant.principal_id == user_id,
), ),
] ]
@@ -560,7 +531,7 @@ class AccessGrantsTable:
if user_group_ids: if user_group_ids:
conditions.append( conditions.append(
and_( and_(
AccessGrant.principal_type == "group", AccessGrant.principal_type == 'group',
AccessGrant.principal_id.in_(user_group_ids), AccessGrant.principal_id.in_(user_group_ids),
) )
) )
@@ -582,7 +553,7 @@ class AccessGrantsTable:
user_id: str, user_id: str,
resource_type: str, resource_type: str,
resource_ids: list[str], resource_ids: list[str],
permission: str = "read", permission: str = 'read',
user_group_ids: Optional[set[str]] = None, user_group_ids: Optional[set[str]] = None,
db: Optional[Session] = None, db: Optional[Session] = None,
) -> set[str]: ) -> set[str]:
@@ -597,11 +568,11 @@ class AccessGrantsTable:
with get_db_context(db) as db: with get_db_context(db) as db:
conditions = [ conditions = [
and_( and_(
AccessGrant.principal_type == "user", AccessGrant.principal_type == 'user',
AccessGrant.principal_id == "*", AccessGrant.principal_id == '*',
), ),
and_( and_(
AccessGrant.principal_type == "user", AccessGrant.principal_type == 'user',
AccessGrant.principal_id == user_id, AccessGrant.principal_id == user_id,
), ),
] ]
@@ -615,7 +586,7 @@ class AccessGrantsTable:
if user_group_ids: if user_group_ids:
conditions.append( conditions.append(
and_( and_(
AccessGrant.principal_type == "group", AccessGrant.principal_type == 'group',
AccessGrant.principal_id.in_(user_group_ids), AccessGrant.principal_id.in_(user_group_ids),
) )
) )
@@ -637,7 +608,7 @@ class AccessGrantsTable:
self, self,
resource_type: str, resource_type: str,
resource_id: str, resource_id: str,
permission: str = "read", permission: str = 'read',
db: Optional[Session] = None, db: Optional[Session] = None,
) -> list: ) -> list:
""" """
@@ -660,19 +631,17 @@ class AccessGrantsTable:
# Check for public access # Check for public access
for grant in grants: for grant in grants:
if grant.principal_type == "user" and grant.principal_id == "*": if grant.principal_type == 'user' and grant.principal_id == '*':
result = Users.get_users(filter={"roles": ["!pending"]}, db=db) result = Users.get_users(filter={'roles': ['!pending']}, db=db)
return result.get("users", []) return result.get('users', [])
user_ids_with_access = set() user_ids_with_access = set()
for grant in grants: for grant in grants:
if grant.principal_type == "user": if grant.principal_type == 'user':
user_ids_with_access.add(grant.principal_id) user_ids_with_access.add(grant.principal_id)
elif grant.principal_type == "group": elif grant.principal_type == 'group':
group_user_ids = Groups.get_group_user_ids_by_id( group_user_ids = Groups.get_group_user_ids_by_id(grant.principal_id, db=db)
grant.principal_id, db=db
)
if group_user_ids: if group_user_ids:
user_ids_with_access.update(group_user_ids) user_ids_with_access.update(group_user_ids)
@@ -688,20 +657,18 @@ class AccessGrantsTable:
DocumentModel, DocumentModel,
filter: dict, filter: dict,
resource_type: str, resource_type: str,
permission: str = "read", permission: str = 'read',
): ):
""" """
Apply access control filtering to a SQLAlchemy query by JOINing with access_grant. Apply access control filtering to a SQLAlchemy query by JOINing with access_grant.
This replaces the old JSON-column-based filtering with a proper relational JOIN. This replaces the old JSON-column-based filtering with a proper relational JOIN.
""" """
group_ids = filter.get("group_ids", []) group_ids = filter.get('group_ids', [])
user_id = filter.get("user_id") user_id = filter.get('user_id')
if permission == "read_only": if permission == 'read_only':
return self._has_read_only_permission_filter( return self._has_read_only_permission_filter(db, query, DocumentModel, filter, resource_type)
db, query, DocumentModel, filter, resource_type
)
# Build principal conditions # Build principal conditions
principal_conditions = [] principal_conditions = []
@@ -710,8 +677,8 @@ class AccessGrantsTable:
# Public access: user:* read # Public access: user:* read
principal_conditions.append( principal_conditions.append(
and_( and_(
AccessGrant.principal_type == "user", AccessGrant.principal_type == 'user',
AccessGrant.principal_id == "*", AccessGrant.principal_id == '*',
) )
) )
@@ -722,7 +689,7 @@ class AccessGrantsTable:
# Direct user grant # Direct user grant
principal_conditions.append( principal_conditions.append(
and_( and_(
AccessGrant.principal_type == "user", AccessGrant.principal_type == 'user',
AccessGrant.principal_id == user_id, AccessGrant.principal_id == user_id,
) )
) )
@@ -731,7 +698,7 @@ class AccessGrantsTable:
# Group grants # Group grants
principal_conditions.append( principal_conditions.append(
and_( and_(
AccessGrant.principal_type == "group", AccessGrant.principal_type == 'group',
AccessGrant.principal_id.in_(group_ids), AccessGrant.principal_id.in_(group_ids),
) )
) )
@@ -751,13 +718,13 @@ class AccessGrantsTable:
AccessGrant.permission == permission, AccessGrant.permission == permission,
or_( or_(
and_( and_(
AccessGrant.principal_type == "user", AccessGrant.principal_type == 'user',
AccessGrant.principal_id == "*", AccessGrant.principal_id == '*',
), ),
*( *(
[ [
and_( and_(
AccessGrant.principal_type == "user", AccessGrant.principal_type == 'user',
AccessGrant.principal_id == user_id, AccessGrant.principal_id == user_id,
) )
] ]
@@ -767,7 +734,7 @@ class AccessGrantsTable:
*( *(
[ [
and_( and_(
AccessGrant.principal_type == "group", AccessGrant.principal_type == 'group',
AccessGrant.principal_id.in_(group_ids), AccessGrant.principal_id.in_(group_ids),
) )
] ]
@@ -800,8 +767,8 @@ class AccessGrantsTable:
Filter for items where user has read BUT NOT write access. Filter for items where user has read BUT NOT write access.
Public items are NOT considered read_only. Public items are NOT considered read_only.
""" """
group_ids = filter.get("group_ids", []) group_ids = filter.get('group_ids', [])
user_id = filter.get("user_id") user_id = filter.get('user_id')
from sqlalchemy import exists as sa_exists, select from sqlalchemy import exists as sa_exists, select
@@ -811,12 +778,12 @@ class AccessGrantsTable:
.where( .where(
AccessGrant.resource_type == resource_type, AccessGrant.resource_type == resource_type,
AccessGrant.resource_id == DocumentModel.id, AccessGrant.resource_id == DocumentModel.id,
AccessGrant.permission == "read", AccessGrant.permission == 'read',
or_( or_(
*( *(
[ [
and_( and_(
AccessGrant.principal_type == "user", AccessGrant.principal_type == 'user',
AccessGrant.principal_id == user_id, AccessGrant.principal_id == user_id,
) )
] ]
@@ -826,7 +793,7 @@ class AccessGrantsTable:
*( *(
[ [
and_( and_(
AccessGrant.principal_type == "group", AccessGrant.principal_type == 'group',
AccessGrant.principal_id.in_(group_ids), AccessGrant.principal_id.in_(group_ids),
) )
] ]
@@ -845,12 +812,12 @@ class AccessGrantsTable:
.where( .where(
AccessGrant.resource_type == resource_type, AccessGrant.resource_type == resource_type,
AccessGrant.resource_id == DocumentModel.id, AccessGrant.resource_id == DocumentModel.id,
AccessGrant.permission == "write", AccessGrant.permission == 'write',
or_( or_(
*( *(
[ [
and_( and_(
AccessGrant.principal_type == "user", AccessGrant.principal_type == 'user',
AccessGrant.principal_id == user_id, AccessGrant.principal_id == user_id,
) )
] ]
@@ -860,7 +827,7 @@ class AccessGrantsTable:
*( *(
[ [
and_( and_(
AccessGrant.principal_type == "group", AccessGrant.principal_type == 'group',
AccessGrant.principal_id.in_(group_ids), AccessGrant.principal_id.in_(group_ids),
) )
] ]
@@ -879,9 +846,9 @@ class AccessGrantsTable:
.where( .where(
AccessGrant.resource_type == resource_type, AccessGrant.resource_type == resource_type,
AccessGrant.resource_id == DocumentModel.id, AccessGrant.resource_id == DocumentModel.id,
AccessGrant.permission == "read", AccessGrant.permission == 'read',
AccessGrant.principal_type == "user", AccessGrant.principal_type == 'user',
AccessGrant.principal_id == "*", AccessGrant.principal_id == '*',
) )
.correlate(DocumentModel) .correlate(DocumentModel)
.exists() .exists()

View File

@@ -17,7 +17,7 @@ log = logging.getLogger(__name__)
class Auth(Base): class Auth(Base):
__tablename__ = "auth" __tablename__ = 'auth'
id = Column(String, primary_key=True, unique=True) id = Column(String, primary_key=True, unique=True)
email = Column(String) email = Column(String)
@@ -73,9 +73,9 @@ class SignupForm(BaseModel):
name: str name: str
email: str email: str
password: str password: str
profile_image_url: Optional[str] = "/user.png" profile_image_url: Optional[str] = '/user.png'
@field_validator("profile_image_url") @field_validator('profile_image_url')
@classmethod @classmethod
def check_profile_image_url(cls, v: Optional[str]) -> Optional[str]: def check_profile_image_url(cls, v: Optional[str]) -> Optional[str]:
if v is not None: if v is not None:
@@ -84,7 +84,7 @@ class SignupForm(BaseModel):
class AddUserForm(SignupForm): class AddUserForm(SignupForm):
role: Optional[str] = "pending" role: Optional[str] = 'pending'
class AuthsTable: class AuthsTable:
@@ -93,25 +93,21 @@ class AuthsTable:
email: str, email: str,
password: str, password: str,
name: str, name: str,
profile_image_url: str = "/user.png", profile_image_url: str = '/user.png',
role: str = "pending", role: str = 'pending',
oauth: Optional[dict] = None, oauth: Optional[dict] = None,
db: Optional[Session] = None, db: Optional[Session] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
log.info("insert_new_auth") log.info('insert_new_auth')
id = str(uuid.uuid4()) id = str(uuid.uuid4())
auth = AuthModel( auth = AuthModel(**{'id': id, 'email': email, 'password': password, 'active': True})
**{"id": id, "email": email, "password": password, "active": True}
)
result = Auth(**auth.model_dump()) result = Auth(**auth.model_dump())
db.add(result) db.add(result)
user = Users.insert_new_user( user = Users.insert_new_user(id, name, email, profile_image_url, role, oauth=oauth, db=db)
id, name, email, profile_image_url, role, oauth=oauth, db=db
)
db.commit() db.commit()
db.refresh(result) db.refresh(result)
@@ -124,7 +120,7 @@ class AuthsTable:
def authenticate_user( def authenticate_user(
self, email: str, verify_password: callable, db: Optional[Session] = None self, email: str, verify_password: callable, db: Optional[Session] = None
) -> Optional[UserModel]: ) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}") log.info(f'authenticate_user: {email}')
user = Users.get_user_by_email(email, db=db) user = Users.get_user_by_email(email, db=db)
if not user: if not user:
@@ -143,10 +139,8 @@ class AuthsTable:
except Exception: except Exception:
return None return None
def authenticate_user_by_api_key( def authenticate_user_by_api_key(self, api_key: str, db: Optional[Session] = None) -> Optional[UserModel]:
self, api_key: str, db: Optional[Session] = None log.info(f'authenticate_user_by_api_key')
) -> Optional[UserModel]:
log.info(f"authenticate_user_by_api_key")
# if no api_key, return None # if no api_key, return None
if not api_key: if not api_key:
return None return None
@@ -157,10 +151,8 @@ class AuthsTable:
except Exception: except Exception:
return False return False
def authenticate_user_by_email( def authenticate_user_by_email(self, email: str, db: Optional[Session] = None) -> Optional[UserModel]:
self, email: str, db: Optional[Session] = None log.info(f'authenticate_user_by_email: {email}')
) -> Optional[UserModel]:
log.info(f"authenticate_user_by_email: {email}")
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
# Single JOIN query instead of two separate queries # Single JOIN query instead of two separate queries
@@ -177,28 +169,22 @@ class AuthsTable:
except Exception: except Exception:
return None return None
def update_user_password_by_id( def update_user_password_by_id(self, id: str, new_password: str, db: Optional[Session] = None) -> bool:
self, id: str, new_password: str, db: Optional[Session] = None
) -> bool:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
result = ( result = db.query(Auth).filter_by(id=id).update({'password': new_password})
db.query(Auth).filter_by(id=id).update({"password": new_password})
)
db.commit() db.commit()
return True if result == 1 else False return True if result == 1 else False
except Exception: except Exception:
return False return False
def update_email_by_id( def update_email_by_id(self, id: str, email: str, db: Optional[Session] = None) -> bool:
self, id: str, email: str, db: Optional[Session] = None
) -> bool:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
result = db.query(Auth).filter_by(id=id).update({"email": email}) result = db.query(Auth).filter_by(id=id).update({'email': email})
db.commit() db.commit()
if result == 1: if result == 1:
Users.update_user_by_id(id, {"email": email}, db=db) Users.update_user_by_id(id, {'email': email}, db=db)
return True return True
return False return False
except Exception: except Exception:

View File

@@ -37,7 +37,7 @@ from sqlalchemy.sql import exists
class Channel(Base): class Channel(Base):
__tablename__ = "channel" __tablename__ = 'channel'
id = Column(Text, primary_key=True, unique=True) id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text) user_id = Column(Text)
@@ -94,7 +94,7 @@ class ChannelModel(BaseModel):
class ChannelMember(Base): class ChannelMember(Base):
__tablename__ = "channel_member" __tablename__ = 'channel_member'
id = Column(Text, primary_key=True, unique=True) id = Column(Text, primary_key=True, unique=True)
channel_id = Column(Text, nullable=False) channel_id = Column(Text, nullable=False)
@@ -154,25 +154,19 @@ class ChannelMemberModel(BaseModel):
class ChannelFile(Base): class ChannelFile(Base):
__tablename__ = "channel_file" __tablename__ = 'channel_file'
id = Column(Text, unique=True, primary_key=True) id = Column(Text, unique=True, primary_key=True)
user_id = Column(Text, nullable=False) user_id = Column(Text, nullable=False)
channel_id = Column( channel_id = Column(Text, ForeignKey('channel.id', ondelete='CASCADE'), nullable=False)
Text, ForeignKey("channel.id", ondelete="CASCADE"), nullable=False message_id = Column(Text, ForeignKey('message.id', ondelete='CASCADE'), nullable=True)
) file_id = Column(Text, ForeignKey('file.id', ondelete='CASCADE'), nullable=False)
message_id = Column(
Text, ForeignKey("message.id", ondelete="CASCADE"), nullable=True
)
file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False)
created_at = Column(BigInteger, nullable=False) created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False) updated_at = Column(BigInteger, nullable=False)
__table_args__ = ( __table_args__ = (UniqueConstraint('channel_id', 'file_id', name='uq_channel_file_channel_file'),)
UniqueConstraint("channel_id", "file_id", name="uq_channel_file_channel_file"),
)
class ChannelFileModel(BaseModel): class ChannelFileModel(BaseModel):
@@ -189,7 +183,7 @@ class ChannelFileModel(BaseModel):
class ChannelWebhook(Base): class ChannelWebhook(Base):
__tablename__ = "channel_webhook" __tablename__ = 'channel_webhook'
id = Column(Text, primary_key=True, unique=True) id = Column(Text, primary_key=True, unique=True)
channel_id = Column(Text, nullable=False) channel_id = Column(Text, nullable=False)
@@ -235,7 +229,7 @@ class ChannelResponse(ChannelModel):
class ChannelForm(BaseModel): class ChannelForm(BaseModel):
name: str = "" name: str = ''
description: Optional[str] = None description: Optional[str] = None
is_private: Optional[bool] = None is_private: Optional[bool] = None
data: Optional[dict] = None data: Optional[dict] = None
@@ -255,10 +249,8 @@ class ChannelWebhookForm(BaseModel):
class ChannelTable: class ChannelTable:
def _get_access_grants( def _get_access_grants(self, channel_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
self, channel_id: str, db: Optional[Session] = None return AccessGrants.get_grants_by_resource('channel', channel_id, db=db)
) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource("channel", channel_id, db=db)
def _to_channel_model( def _to_channel_model(
self, self,
@@ -266,13 +258,9 @@ class ChannelTable:
access_grants: Optional[list[AccessGrantModel]] = None, access_grants: Optional[list[AccessGrantModel]] = None,
db: Optional[Session] = None, db: Optional[Session] = None,
) -> ChannelModel: ) -> ChannelModel:
channel_data = ChannelModel.model_validate(channel).model_dump( channel_data = ChannelModel.model_validate(channel).model_dump(exclude={'access_grants'})
exclude={"access_grants"} channel_data['access_grants'] = (
) access_grants if access_grants is not None else self._get_access_grants(channel_data['id'], db=db)
channel_data["access_grants"] = (
access_grants
if access_grants is not None
else self._get_access_grants(channel_data["id"], db=db)
) )
return ChannelModel.model_validate(channel_data) return ChannelModel.model_validate(channel_data)
@@ -313,20 +301,20 @@ class ChannelTable:
for uid in user_ids: for uid in user_ids:
model = ChannelMemberModel( model = ChannelMemberModel(
**{ **{
"id": str(uuid.uuid4()), 'id': str(uuid.uuid4()),
"channel_id": channel_id, 'channel_id': channel_id,
"user_id": uid, 'user_id': uid,
"status": "joined", 'status': 'joined',
"is_active": True, 'is_active': True,
"is_channel_muted": False, 'is_channel_muted': False,
"is_channel_pinned": False, 'is_channel_pinned': False,
"invited_at": now, 'invited_at': now,
"invited_by": invited_by, 'invited_by': invited_by,
"joined_at": now, 'joined_at': now,
"left_at": None, 'left_at': None,
"last_read_at": now, 'last_read_at': now,
"created_at": now, 'created_at': now,
"updated_at": now, 'updated_at': now,
} }
) )
memberships.append(ChannelMember(**model.model_dump())) memberships.append(ChannelMember(**model.model_dump()))
@@ -339,19 +327,19 @@ class ChannelTable:
with get_db_context(db) as db: with get_db_context(db) as db:
channel = ChannelModel( channel = ChannelModel(
**{ **{
**form_data.model_dump(exclude={"access_grants"}), **form_data.model_dump(exclude={'access_grants'}),
"type": form_data.type if form_data.type else None, 'type': form_data.type if form_data.type else None,
"name": form_data.name.lower(), 'name': form_data.name.lower(),
"id": str(uuid.uuid4()), 'id': str(uuid.uuid4()),
"user_id": user_id, 'user_id': user_id,
"created_at": int(time.time_ns()), 'created_at': int(time.time_ns()),
"updated_at": int(time.time_ns()), 'updated_at': int(time.time_ns()),
"access_grants": [], 'access_grants': [],
} }
) )
new_channel = Channel(**channel.model_dump(exclude={"access_grants"})) new_channel = Channel(**channel.model_dump(exclude={'access_grants'}))
if form_data.type in ["group", "dm"]: if form_data.type in ['group', 'dm']:
users = self._collect_unique_user_ids( users = self._collect_unique_user_ids(
invited_by=user_id, invited_by=user_id,
user_ids=form_data.user_ids, user_ids=form_data.user_ids,
@@ -366,18 +354,14 @@ class ChannelTable:
db.add_all(memberships) db.add_all(memberships)
db.add(new_channel) db.add(new_channel)
db.commit() db.commit()
AccessGrants.set_access_grants( AccessGrants.set_access_grants('channel', new_channel.id, form_data.access_grants, db=db)
"channel", new_channel.id, form_data.access_grants, db=db
)
return self._to_channel_model(new_channel, db=db) return self._to_channel_model(new_channel, db=db)
def get_channels(self, db: Optional[Session] = None) -> list[ChannelModel]: def get_channels(self, db: Optional[Session] = None) -> list[ChannelModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
channels = db.query(Channel).all() channels = db.query(Channel).all()
channel_ids = [channel.id for channel in channels] channel_ids = [channel.id for channel in channels]
grants_map = AccessGrants.get_grants_by_resources( grants_map = AccessGrants.get_grants_by_resources('channel', channel_ids, db=db)
"channel", channel_ids, db=db
)
return [ return [
self._to_channel_model( self._to_channel_model(
channel, channel,
@@ -387,23 +371,19 @@ class ChannelTable:
for channel in channels for channel in channels
] ]
def _has_permission(self, db, query, filter: dict, permission: str = "read"): def _has_permission(self, db, query, filter: dict, permission: str = 'read'):
return AccessGrants.has_permission_filter( return AccessGrants.has_permission_filter(
db=db, db=db,
query=query, query=query,
DocumentModel=Channel, DocumentModel=Channel,
filter=filter, filter=filter,
resource_type="channel", resource_type='channel',
permission=permission, permission=permission,
) )
def get_channels_by_user_id( def get_channels_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
self, user_id: str, db: Optional[Session] = None
) -> list[ChannelModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
user_group_ids = [ user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id, db=db)]
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
]
membership_channels = ( membership_channels = (
db.query(Channel) db.query(Channel)
@@ -411,7 +391,7 @@ class ChannelTable:
.filter( .filter(
Channel.deleted_at.is_(None), Channel.deleted_at.is_(None),
Channel.archived_at.is_(None), Channel.archived_at.is_(None),
Channel.type.in_(["group", "dm"]), Channel.type.in_(['group', 'dm']),
ChannelMember.user_id == user_id, ChannelMember.user_id == user_id,
ChannelMember.is_active.is_(True), ChannelMember.is_active.is_(True),
) )
@@ -423,29 +403,20 @@ class ChannelTable:
Channel.archived_at.is_(None), Channel.archived_at.is_(None),
or_( or_(
Channel.type.is_(None), # True NULL/None Channel.type.is_(None), # True NULL/None
Channel.type == "", # Empty string Channel.type == '', # Empty string
and_(Channel.type != "group", Channel.type != "dm"), and_(Channel.type != 'group', Channel.type != 'dm'),
), ),
) )
query = self._has_permission( query = self._has_permission(db, query, {'user_id': user_id, 'group_ids': user_group_ids})
db, query, {"user_id": user_id, "group_ids": user_group_ids}
)
standard_channels = query.all() standard_channels = query.all()
all_channels = membership_channels + standard_channels all_channels = membership_channels + standard_channels
channel_ids = [c.id for c in all_channels] channel_ids = [c.id for c in all_channels]
grants_map = AccessGrants.get_grants_by_resources( grants_map = AccessGrants.get_grants_by_resources('channel', channel_ids, db=db)
"channel", channel_ids, db=db return [self._to_channel_model(c, access_grants=grants_map.get(c.id, []), db=db) for c in all_channels]
)
return [
self._to_channel_model(c, access_grants=grants_map.get(c.id, []), db=db)
for c in all_channels
]
def get_dm_channel_by_user_ids( def get_dm_channel_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> Optional[ChannelModel]:
self, user_ids: list[str], db: Optional[Session] = None
) -> Optional[ChannelModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
# Ensure uniqueness in case a list with duplicates is passed # Ensure uniqueness in case a list with duplicates is passed
unique_user_ids = list(set(user_ids)) unique_user_ids = list(set(user_ids))
@@ -471,7 +442,7 @@ class ChannelTable:
db.query(Channel) db.query(Channel)
.filter( .filter(
Channel.id.in_(subquery), Channel.id.in_(subquery),
Channel.type == "dm", Channel.type == 'dm',
) )
.first() .first()
) )
@@ -488,32 +459,23 @@ class ChannelTable:
) -> list[ChannelMemberModel]: ) -> list[ChannelMemberModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
# 1. Collect all user_ids including groups + inviter # 1. Collect all user_ids including groups + inviter
requested_users = self._collect_unique_user_ids( requested_users = self._collect_unique_user_ids(invited_by, user_ids, group_ids)
invited_by, user_ids, group_ids
)
existing_users = { existing_users = {
row.user_id row.user_id
for row in db.query(ChannelMember.user_id) for row in db.query(ChannelMember.user_id).filter(ChannelMember.channel_id == channel_id).all()
.filter(ChannelMember.channel_id == channel_id)
.all()
} }
new_user_ids = requested_users - existing_users new_user_ids = requested_users - existing_users
if not new_user_ids: if not new_user_ids:
return [] # Nothing to add return [] # Nothing to add
new_memberships = self._create_membership_models( new_memberships = self._create_membership_models(channel_id, invited_by, new_user_ids)
channel_id, invited_by, new_user_ids
)
db.add_all(new_memberships) db.add_all(new_memberships)
db.commit() db.commit()
return [ return [ChannelMemberModel.model_validate(membership) for membership in new_memberships]
ChannelMemberModel.model_validate(membership)
for membership in new_memberships
]
def remove_members_from_channel( def remove_members_from_channel(
self, self,
@@ -533,9 +495,7 @@ class ChannelTable:
db.commit() db.commit()
return result # number of rows deleted return result # number of rows deleted
def is_user_channel_manager( def is_user_channel_manager(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
self, channel_id: str, user_id: str, db: Optional[Session] = None
) -> bool:
with get_db_context(db) as db: with get_db_context(db) as db:
# Check if the user is the creator of the channel # Check if the user is the creator of the channel
# or has a 'manager' role in ChannelMember # or has a 'manager' role in ChannelMember
@@ -548,15 +508,13 @@ class ChannelTable:
.filter( .filter(
ChannelMember.channel_id == channel_id, ChannelMember.channel_id == channel_id,
ChannelMember.user_id == user_id, ChannelMember.user_id == user_id,
ChannelMember.role == "manager", ChannelMember.role == 'manager',
) )
.first() .first()
) )
return membership is not None return membership is not None
def join_channel( def join_channel(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> Optional[ChannelMemberModel]:
self, channel_id: str, user_id: str, db: Optional[Session] = None
) -> Optional[ChannelMemberModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
# Check if the membership already exists # Check if the membership already exists
existing_membership = ( existing_membership = (
@@ -573,18 +531,18 @@ class ChannelTable:
# Create new membership # Create new membership
channel_member = ChannelMemberModel( channel_member = ChannelMemberModel(
**{ **{
"id": str(uuid.uuid4()), 'id': str(uuid.uuid4()),
"channel_id": channel_id, 'channel_id': channel_id,
"user_id": user_id, 'user_id': user_id,
"status": "joined", 'status': 'joined',
"is_active": True, 'is_active': True,
"is_channel_muted": False, 'is_channel_muted': False,
"is_channel_pinned": False, 'is_channel_pinned': False,
"joined_at": int(time.time_ns()), 'joined_at': int(time.time_ns()),
"left_at": None, 'left_at': None,
"last_read_at": int(time.time_ns()), 'last_read_at': int(time.time_ns()),
"created_at": int(time.time_ns()), 'created_at': int(time.time_ns()),
"updated_at": int(time.time_ns()), 'updated_at': int(time.time_ns()),
} }
) )
new_membership = ChannelMember(**channel_member.model_dump()) new_membership = ChannelMember(**channel_member.model_dump())
@@ -593,9 +551,7 @@ class ChannelTable:
db.commit() db.commit()
return channel_member return channel_member
def leave_channel( def leave_channel(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
self, channel_id: str, user_id: str, db: Optional[Session] = None
) -> bool:
with get_db_context(db) as db: with get_db_context(db) as db:
membership = ( membership = (
db.query(ChannelMember) db.query(ChannelMember)
@@ -608,7 +564,7 @@ class ChannelTable:
if not membership: if not membership:
return False return False
membership.status = "left" membership.status = 'left'
membership.is_active = False membership.is_active = False
membership.left_at = int(time.time_ns()) membership.left_at = int(time.time_ns())
membership.updated_at = int(time.time_ns()) membership.updated_at = int(time.time_ns())
@@ -630,19 +586,10 @@ class ChannelTable:
) )
return ChannelMemberModel.model_validate(membership) if membership else None return ChannelMemberModel.model_validate(membership) if membership else None
def get_members_by_channel_id( def get_members_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> list[ChannelMemberModel]:
self, channel_id: str, db: Optional[Session] = None
) -> list[ChannelMemberModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
memberships = ( memberships = db.query(ChannelMember).filter(ChannelMember.channel_id == channel_id).all()
db.query(ChannelMember) return [ChannelMemberModel.model_validate(membership) for membership in memberships]
.filter(ChannelMember.channel_id == channel_id)
.all()
)
return [
ChannelMemberModel.model_validate(membership)
for membership in memberships
]
def pin_channel( def pin_channel(
self, self,
@@ -669,9 +616,7 @@ class ChannelTable:
db.commit() db.commit()
return True return True
def update_member_last_read_at( def update_member_last_read_at(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
self, channel_id: str, user_id: str, db: Optional[Session] = None
) -> bool:
with get_db_context(db) as db: with get_db_context(db) as db:
membership = ( membership = (
db.query(ChannelMember) db.query(ChannelMember)
@@ -715,9 +660,7 @@ class ChannelTable:
db.commit() db.commit()
return True return True
def is_user_channel_member( def is_user_channel_member(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
self, channel_id: str, user_id: str, db: Optional[Session] = None
) -> bool:
with get_db_context(db) as db: with get_db_context(db) as db:
membership = ( membership = (
db.query(ChannelMember) db.query(ChannelMember)
@@ -729,9 +672,7 @@ class ChannelTable:
) )
return membership is not None return membership is not None
def get_channel_by_id( def get_channel_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChannelModel]:
self, id: str, db: Optional[Session] = None
) -> Optional[ChannelModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
channel = db.query(Channel).filter(Channel.id == id).first() channel = db.query(Channel).filter(Channel.id == id).first()
@@ -739,18 +680,12 @@ class ChannelTable:
except Exception: except Exception:
return None return None
def get_channels_by_file_id( def get_channels_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
self, file_id: str, db: Optional[Session] = None
) -> list[ChannelModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
channel_files = ( channel_files = db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
)
channel_ids = [cf.channel_id for cf in channel_files] channel_ids = [cf.channel_id for cf in channel_files]
channels = db.query(Channel).filter(Channel.id.in_(channel_ids)).all() channels = db.query(Channel).filter(Channel.id.in_(channel_ids)).all()
grants_map = AccessGrants.get_grants_by_resources( grants_map = AccessGrants.get_grants_by_resources('channel', channel_ids, db=db)
"channel", channel_ids, db=db
)
return [ return [
self._to_channel_model( self._to_channel_model(
channel, channel,
@@ -765,9 +700,7 @@ class ChannelTable:
) -> list[ChannelModel]: ) -> list[ChannelModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
# 1. Determine which channels have this file # 1. Determine which channels have this file
channel_file_rows = ( channel_file_rows = db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
)
channel_ids = [row.channel_id for row in channel_file_rows] channel_ids = [row.channel_id for row in channel_file_rows]
if not channel_ids: if not channel_ids:
@@ -787,15 +720,13 @@ class ChannelTable:
return [] return []
# Preload user's group membership # Preload user's group membership
user_group_ids = [ user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id, db=db)]
g.id for g in Groups.get_groups_by_member_id(user_id, db=db)
]
allowed_channels = [] allowed_channels = []
for channel in channels: for channel in channels:
# --- Case A: group or dm => user must be an active member --- # --- Case A: group or dm => user must be an active member ---
if channel.type in ["group", "dm"]: if channel.type in ['group', 'dm']:
membership = ( membership = (
db.query(ChannelMember) db.query(ChannelMember)
.filter( .filter(
@@ -815,8 +746,8 @@ class ChannelTable:
query = self._has_permission( query = self._has_permission(
db, db,
query, query,
{"user_id": user_id, "group_ids": user_group_ids}, {'user_id': user_id, 'group_ids': user_group_ids},
permission="read", permission='read',
) )
allowed = query.first() allowed = query.first()
@@ -844,7 +775,7 @@ class ChannelTable:
return None return None
# If the channel is a group or dm, read access requires membership (active) # If the channel is a group or dm, read access requires membership (active)
if channel.type in ["group", "dm"]: if channel.type in ['group', 'dm']:
membership = ( membership = (
db.query(ChannelMember) db.query(ChannelMember)
.filter( .filter(
@@ -863,24 +794,18 @@ class ChannelTable:
query = db.query(Channel).filter(Channel.id == id) query = db.query(Channel).filter(Channel.id == id)
# Determine user groups # Determine user groups
user_group_ids = [ user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id, db=db)]
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
]
# Apply ACL rules # Apply ACL rules
query = self._has_permission( query = self._has_permission(
db, db,
query, query,
{"user_id": user_id, "group_ids": user_group_ids}, {'user_id': user_id, 'group_ids': user_group_ids},
permission="read", permission='read',
) )
channel_allowed = query.first() channel_allowed = query.first()
return ( return self._to_channel_model(channel_allowed, db=db) if channel_allowed else None
self._to_channel_model(channel_allowed, db=db)
if channel_allowed
else None
)
def update_channel_by_id( def update_channel_by_id(
self, id: str, form_data: ChannelForm, db: Optional[Session] = None self, id: str, form_data: ChannelForm, db: Optional[Session] = None
@@ -898,9 +823,7 @@ class ChannelTable:
channel.meta = form_data.meta channel.meta = form_data.meta
if form_data.access_grants is not None: if form_data.access_grants is not None:
AccessGrants.set_access_grants( AccessGrants.set_access_grants('channel', id, form_data.access_grants, db=db)
"channel", id, form_data.access_grants, db=db
)
channel.updated_at = int(time.time_ns()) channel.updated_at = int(time.time_ns())
db.commit() db.commit()
@@ -912,12 +835,12 @@ class ChannelTable:
with get_db_context(db) as db: with get_db_context(db) as db:
channel_file = ChannelFileModel( channel_file = ChannelFileModel(
**{ **{
"id": str(uuid.uuid4()), 'id': str(uuid.uuid4()),
"channel_id": channel_id, 'channel_id': channel_id,
"file_id": file_id, 'file_id': file_id,
"user_id": user_id, 'user_id': user_id,
"created_at": int(time.time()), 'created_at': int(time.time()),
"updated_at": int(time.time()), 'updated_at': int(time.time()),
} }
) )
@@ -942,11 +865,7 @@ class ChannelTable:
) -> bool: ) -> bool:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
channel_file = ( channel_file = db.query(ChannelFile).filter_by(channel_id=channel_id, file_id=file_id).first()
db.query(ChannelFile)
.filter_by(channel_id=channel_id, file_id=file_id)
.first()
)
if not channel_file: if not channel_file:
return False return False
@@ -958,14 +877,10 @@ class ChannelTable:
except Exception: except Exception:
return False return False
def remove_file_from_channel_by_id( def remove_file_from_channel_by_id(self, channel_id: str, file_id: str, db: Optional[Session] = None) -> bool:
self, channel_id: str, file_id: str, db: Optional[Session] = None
) -> bool:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
db.query(ChannelFile).filter_by( db.query(ChannelFile).filter_by(channel_id=channel_id, file_id=file_id).delete()
channel_id=channel_id, file_id=file_id
).delete()
db.commit() db.commit()
return True return True
except Exception: except Exception:
@@ -973,7 +888,7 @@ class ChannelTable:
def delete_channel_by_id(self, id: str, db: Optional[Session] = None) -> bool: def delete_channel_by_id(self, id: str, db: Optional[Session] = None) -> bool:
with get_db_context(db) as db: with get_db_context(db) as db:
AccessGrants.revoke_all_access("channel", id, db=db) AccessGrants.revoke_all_access('channel', id, db=db)
db.query(Channel).filter(Channel.id == id).delete() db.query(Channel).filter(Channel.id == id).delete()
db.commit() db.commit()
return True return True
@@ -1005,24 +920,14 @@ class ChannelTable:
db.commit() db.commit()
return webhook return webhook
def get_webhooks_by_channel_id( def get_webhooks_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> list[ChannelWebhookModel]:
self, channel_id: str, db: Optional[Session] = None
) -> list[ChannelWebhookModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
webhooks = ( webhooks = db.query(ChannelWebhook).filter(ChannelWebhook.channel_id == channel_id).all()
db.query(ChannelWebhook)
.filter(ChannelWebhook.channel_id == channel_id)
.all()
)
return [ChannelWebhookModel.model_validate(w) for w in webhooks] return [ChannelWebhookModel.model_validate(w) for w in webhooks]
def get_webhook_by_id( def get_webhook_by_id(self, webhook_id: str, db: Optional[Session] = None) -> Optional[ChannelWebhookModel]:
self, webhook_id: str, db: Optional[Session] = None
) -> Optional[ChannelWebhookModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
webhook = ( webhook = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
)
return ChannelWebhookModel.model_validate(webhook) if webhook else None return ChannelWebhookModel.model_validate(webhook) if webhook else None
def get_webhook_by_id_and_token( def get_webhook_by_id_and_token(
@@ -1046,9 +951,7 @@ class ChannelTable:
db: Optional[Session] = None, db: Optional[Session] = None,
) -> Optional[ChannelWebhookModel]: ) -> Optional[ChannelWebhookModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
webhook = ( webhook = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
)
if not webhook: if not webhook:
return None return None
webhook.name = form_data.name webhook.name = form_data.name
@@ -1057,28 +960,18 @@ class ChannelTable:
db.commit() db.commit()
return ChannelWebhookModel.model_validate(webhook) return ChannelWebhookModel.model_validate(webhook)
def update_webhook_last_used_at( def update_webhook_last_used_at(self, webhook_id: str, db: Optional[Session] = None) -> bool:
self, webhook_id: str, db: Optional[Session] = None
) -> bool:
with get_db_context(db) as db: with get_db_context(db) as db:
webhook = ( webhook = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
)
if not webhook: if not webhook:
return False return False
webhook.last_used_at = int(time.time_ns()) webhook.last_used_at = int(time.time_ns())
db.commit() db.commit()
return True return True
def delete_webhook_by_id( def delete_webhook_by_id(self, webhook_id: str, db: Optional[Session] = None) -> bool:
self, webhook_id: str, db: Optional[Session] = None
) -> bool:
with get_db_context(db) as db: with get_db_context(db) as db:
result = ( result = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).delete()
db.query(ChannelWebhook)
.filter(ChannelWebhook.id == webhook_id)
.delete()
)
db.commit() db.commit()
return result > 0 return result > 0

View File

@@ -47,13 +47,11 @@ def _normalize_timestamp(timestamp: int) -> float:
class ChatMessage(Base): class ChatMessage(Base):
__tablename__ = "chat_message" __tablename__ = 'chat_message'
# Identity # Identity
id = Column(Text, primary_key=True) id = Column(Text, primary_key=True)
chat_id = Column( chat_id = Column(Text, ForeignKey('chat.id', ondelete='CASCADE'), nullable=False, index=True)
Text, ForeignKey("chat.id", ondelete="CASCADE"), nullable=False, index=True
)
user_id = Column(Text, index=True) user_id = Column(Text, index=True)
# Structure # Structure
@@ -85,9 +83,9 @@ class ChatMessage(Base):
updated_at = Column(BigInteger) updated_at = Column(BigInteger)
__table_args__ = ( __table_args__ = (
Index("chat_message_chat_parent_idx", "chat_id", "parent_id"), Index('chat_message_chat_parent_idx', 'chat_id', 'parent_id'),
Index("chat_message_model_created_idx", "model_id", "created_at"), Index('chat_message_model_created_idx', 'model_id', 'created_at'),
Index("chat_message_user_created_idx", "user_id", "created_at"), Index('chat_message_user_created_idx', 'user_id', 'created_at'),
) )
@@ -135,43 +133,41 @@ class ChatMessageTable:
"""Insert or update a chat message.""" """Insert or update a chat message."""
with get_db_context(db) as db: with get_db_context(db) as db:
now = int(time.time()) now = int(time.time())
timestamp = data.get("timestamp", now) timestamp = data.get('timestamp', now)
# Use composite ID: {chat_id}-{message_id} # Use composite ID: {chat_id}-{message_id}
composite_id = f"{chat_id}-{message_id}" composite_id = f'{chat_id}-{message_id}'
existing = db.get(ChatMessage, composite_id) existing = db.get(ChatMessage, composite_id)
if existing: if existing:
# Update existing # Update existing
if "role" in data: if 'role' in data:
existing.role = data["role"] existing.role = data['role']
if "parent_id" in data: if 'parent_id' in data:
existing.parent_id = data.get("parent_id") or data.get("parentId") existing.parent_id = data.get('parent_id') or data.get('parentId')
if "content" in data: if 'content' in data:
existing.content = data.get("content") existing.content = data.get('content')
if "output" in data: if 'output' in data:
existing.output = data.get("output") existing.output = data.get('output')
if "model_id" in data or "model" in data: if 'model_id' in data or 'model' in data:
existing.model_id = data.get("model_id") or data.get("model") existing.model_id = data.get('model_id') or data.get('model')
if "files" in data: if 'files' in data:
existing.files = data.get("files") existing.files = data.get('files')
if "sources" in data: if 'sources' in data:
existing.sources = data.get("sources") existing.sources = data.get('sources')
if "embeds" in data: if 'embeds' in data:
existing.embeds = data.get("embeds") existing.embeds = data.get('embeds')
if "done" in data: if 'done' in data:
existing.done = data.get("done", True) existing.done = data.get('done', True)
if "status_history" in data or "statusHistory" in data: if 'status_history' in data or 'statusHistory' in data:
existing.status_history = data.get("status_history") or data.get( existing.status_history = data.get('status_history') or data.get('statusHistory')
"statusHistory" if 'error' in data:
) existing.error = data.get('error')
if "error" in data:
existing.error = data.get("error")
# Extract usage - check direct field first, then info.usage # Extract usage - check direct field first, then info.usage
usage = data.get("usage") usage = data.get('usage')
if not usage: if not usage:
info = data.get("info", {}) info = data.get('info', {})
usage = info.get("usage") if info else None usage = info.get('usage') if info else None
if usage: if usage:
existing.usage = usage existing.usage = usage
existing.updated_at = now existing.updated_at = now
@@ -181,26 +177,25 @@ class ChatMessageTable:
else: else:
# Insert new # Insert new
# Extract usage - check direct field first, then info.usage # Extract usage - check direct field first, then info.usage
usage = data.get("usage") usage = data.get('usage')
if not usage: if not usage:
info = data.get("info", {}) info = data.get('info', {})
usage = info.get("usage") if info else None usage = info.get('usage') if info else None
message = ChatMessage( message = ChatMessage(
id=composite_id, id=composite_id,
chat_id=chat_id, chat_id=chat_id,
user_id=user_id, user_id=user_id,
role=data.get("role", "user"), role=data.get('role', 'user'),
parent_id=data.get("parent_id") or data.get("parentId"), parent_id=data.get('parent_id') or data.get('parentId'),
content=data.get("content"), content=data.get('content'),
output=data.get("output"), output=data.get('output'),
model_id=data.get("model_id") or data.get("model"), model_id=data.get('model_id') or data.get('model'),
files=data.get("files"), files=data.get('files'),
sources=data.get("sources"), sources=data.get('sources'),
embeds=data.get("embeds"), embeds=data.get('embeds'),
done=data.get("done", True), done=data.get('done', True),
status_history=data.get("status_history") status_history=data.get('status_history') or data.get('statusHistory'),
or data.get("statusHistory"), error=data.get('error'),
error=data.get("error"),
usage=usage, usage=usage,
created_at=timestamp, created_at=timestamp,
updated_at=now, updated_at=now,
@@ -210,23 +205,14 @@ class ChatMessageTable:
db.refresh(message) db.refresh(message)
return ChatMessageModel.model_validate(message) return ChatMessageModel.model_validate(message)
def get_message_by_id( def get_message_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatMessageModel]:
self, id: str, db: Optional[Session] = None
) -> Optional[ChatMessageModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
message = db.get(ChatMessage, id) message = db.get(ChatMessage, id)
return ChatMessageModel.model_validate(message) if message else None return ChatMessageModel.model_validate(message) if message else None
def get_messages_by_chat_id( def get_messages_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> list[ChatMessageModel]:
self, chat_id: str, db: Optional[Session] = None
) -> list[ChatMessageModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
messages = ( messages = db.query(ChatMessage).filter_by(chat_id=chat_id).order_by(ChatMessage.created_at.asc()).all()
db.query(ChatMessage)
.filter_by(chat_id=chat_id)
.order_by(ChatMessage.created_at.asc())
.all()
)
return [ChatMessageModel.model_validate(message) for message in messages] return [ChatMessageModel.model_validate(message) for message in messages]
def get_messages_by_user_id( def get_messages_by_user_id(
@@ -262,12 +248,7 @@ class ChatMessageTable:
query = query.filter(ChatMessage.created_at >= start_date) query = query.filter(ChatMessage.created_at >= start_date)
if end_date: if end_date:
query = query.filter(ChatMessage.created_at <= end_date) query = query.filter(ChatMessage.created_at <= end_date)
messages = ( messages = query.order_by(ChatMessage.created_at.desc()).offset(skip).limit(limit).all()
query.order_by(ChatMessage.created_at.desc())
.offset(skip)
.limit(limit)
.all()
)
return [ChatMessageModel.model_validate(message) for message in messages] return [ChatMessageModel.model_validate(message) for message in messages]
def get_chat_ids_by_model_id( def get_chat_ids_by_model_id(
@@ -284,7 +265,7 @@ class ChatMessageTable:
with get_db_context(db) as db: with get_db_context(db) as db:
query = db.query( query = db.query(
ChatMessage.chat_id, ChatMessage.chat_id,
func.max(ChatMessage.created_at).label("last_message_at"), func.max(ChatMessage.created_at).label('last_message_at'),
).filter(ChatMessage.model_id == model_id) ).filter(ChatMessage.model_id == model_id)
if start_date: if start_date:
query = query.filter(ChatMessage.created_at >= start_date) query = query.filter(ChatMessage.created_at >= start_date)
@@ -303,9 +284,7 @@ class ChatMessageTable:
) )
return [chat_id for chat_id, _ in chat_ids] return [chat_id for chat_id, _ in chat_ids]
def delete_messages_by_chat_id( def delete_messages_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> bool:
self, chat_id: str, db: Optional[Session] = None
) -> bool:
with get_db_context(db) as db: with get_db_context(db) as db:
db.query(ChatMessage).filter_by(chat_id=chat_id).delete() db.query(ChatMessage).filter_by(chat_id=chat_id).delete()
db.commit() db.commit()
@@ -323,12 +302,10 @@ class ChatMessageTable:
from sqlalchemy import func from sqlalchemy import func
from open_webui.models.groups import GroupMember from open_webui.models.groups import GroupMember
query = db.query( query = db.query(ChatMessage.model_id, func.count(ChatMessage.id).label('count')).filter(
ChatMessage.model_id, func.count(ChatMessage.id).label("count") ChatMessage.role == 'assistant',
).filter(
ChatMessage.role == "assistant",
ChatMessage.model_id.isnot(None), ChatMessage.model_id.isnot(None),
~ChatMessage.user_id.like("shared-%"), ~ChatMessage.user_id.like('shared-%'),
) )
if start_date: if start_date:
@@ -336,11 +313,7 @@ class ChatMessageTable:
if end_date: if end_date:
query = query.filter(ChatMessage.created_at <= end_date) query = query.filter(ChatMessage.created_at <= end_date)
if group_id: if group_id:
group_users = ( group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
db.query(GroupMember.user_id)
.filter(GroupMember.group_id == group_id)
.subquery()
)
query = query.filter(ChatMessage.user_id.in_(group_users)) query = query.filter(ChatMessage.user_id.in_(group_users))
results = query.group_by(ChatMessage.model_id).all() results = query.group_by(ChatMessage.model_id).all()
@@ -360,36 +333,32 @@ class ChatMessageTable:
dialect = db.bind.dialect.name dialect = db.bind.dialect.name
if dialect == "sqlite": if dialect == 'sqlite':
input_tokens = cast( input_tokens = cast(func.json_extract(ChatMessage.usage, '$.input_tokens'), Integer)
func.json_extract(ChatMessage.usage, "$.input_tokens"), Integer output_tokens = cast(func.json_extract(ChatMessage.usage, '$.output_tokens'), Integer)
) elif dialect == 'postgresql':
output_tokens = cast(
func.json_extract(ChatMessage.usage, "$.output_tokens"), Integer
)
elif dialect == "postgresql":
# Use json_extract_path_text for PostgreSQL JSON columns # Use json_extract_path_text for PostgreSQL JSON columns
input_tokens = cast( input_tokens = cast(
func.json_extract_path_text(ChatMessage.usage, "input_tokens"), func.json_extract_path_text(ChatMessage.usage, 'input_tokens'),
Integer, Integer,
) )
output_tokens = cast( output_tokens = cast(
func.json_extract_path_text(ChatMessage.usage, "output_tokens"), func.json_extract_path_text(ChatMessage.usage, 'output_tokens'),
Integer, Integer,
) )
else: else:
raise NotImplementedError(f"Unsupported dialect: {dialect}") raise NotImplementedError(f'Unsupported dialect: {dialect}')
query = db.query( query = db.query(
ChatMessage.model_id, ChatMessage.model_id,
func.coalesce(func.sum(input_tokens), 0).label("input_tokens"), func.coalesce(func.sum(input_tokens), 0).label('input_tokens'),
func.coalesce(func.sum(output_tokens), 0).label("output_tokens"), func.coalesce(func.sum(output_tokens), 0).label('output_tokens'),
func.count(ChatMessage.id).label("message_count"), func.count(ChatMessage.id).label('message_count'),
).filter( ).filter(
ChatMessage.role == "assistant", ChatMessage.role == 'assistant',
ChatMessage.model_id.isnot(None), ChatMessage.model_id.isnot(None),
ChatMessage.usage.isnot(None), ChatMessage.usage.isnot(None),
~ChatMessage.user_id.like("shared-%"), ~ChatMessage.user_id.like('shared-%'),
) )
if start_date: if start_date:
@@ -397,21 +366,17 @@ class ChatMessageTable:
if end_date: if end_date:
query = query.filter(ChatMessage.created_at <= end_date) query = query.filter(ChatMessage.created_at <= end_date)
if group_id: if group_id:
group_users = ( group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
db.query(GroupMember.user_id)
.filter(GroupMember.group_id == group_id)
.subquery()
)
query = query.filter(ChatMessage.user_id.in_(group_users)) query = query.filter(ChatMessage.user_id.in_(group_users))
results = query.group_by(ChatMessage.model_id).all() results = query.group_by(ChatMessage.model_id).all()
return { return {
row.model_id: { row.model_id: {
"input_tokens": row.input_tokens, 'input_tokens': row.input_tokens,
"output_tokens": row.output_tokens, 'output_tokens': row.output_tokens,
"total_tokens": row.input_tokens + row.output_tokens, 'total_tokens': row.input_tokens + row.output_tokens,
"message_count": row.message_count, 'message_count': row.message_count,
} }
for row in results for row in results
} }
@@ -430,36 +395,32 @@ class ChatMessageTable:
dialect = db.bind.dialect.name dialect = db.bind.dialect.name
if dialect == "sqlite": if dialect == 'sqlite':
input_tokens = cast( input_tokens = cast(func.json_extract(ChatMessage.usage, '$.input_tokens'), Integer)
func.json_extract(ChatMessage.usage, "$.input_tokens"), Integer output_tokens = cast(func.json_extract(ChatMessage.usage, '$.output_tokens'), Integer)
) elif dialect == 'postgresql':
output_tokens = cast(
func.json_extract(ChatMessage.usage, "$.output_tokens"), Integer
)
elif dialect == "postgresql":
# Use json_extract_path_text for PostgreSQL JSON columns # Use json_extract_path_text for PostgreSQL JSON columns
input_tokens = cast( input_tokens = cast(
func.json_extract_path_text(ChatMessage.usage, "input_tokens"), func.json_extract_path_text(ChatMessage.usage, 'input_tokens'),
Integer, Integer,
) )
output_tokens = cast( output_tokens = cast(
func.json_extract_path_text(ChatMessage.usage, "output_tokens"), func.json_extract_path_text(ChatMessage.usage, 'output_tokens'),
Integer, Integer,
) )
else: else:
raise NotImplementedError(f"Unsupported dialect: {dialect}") raise NotImplementedError(f'Unsupported dialect: {dialect}')
query = db.query( query = db.query(
ChatMessage.user_id, ChatMessage.user_id,
func.coalesce(func.sum(input_tokens), 0).label("input_tokens"), func.coalesce(func.sum(input_tokens), 0).label('input_tokens'),
func.coalesce(func.sum(output_tokens), 0).label("output_tokens"), func.coalesce(func.sum(output_tokens), 0).label('output_tokens'),
func.count(ChatMessage.id).label("message_count"), func.count(ChatMessage.id).label('message_count'),
).filter( ).filter(
ChatMessage.role == "assistant", ChatMessage.role == 'assistant',
ChatMessage.user_id.isnot(None), ChatMessage.user_id.isnot(None),
ChatMessage.usage.isnot(None), ChatMessage.usage.isnot(None),
~ChatMessage.user_id.like("shared-%"), ~ChatMessage.user_id.like('shared-%'),
) )
if start_date: if start_date:
@@ -467,21 +428,17 @@ class ChatMessageTable:
if end_date: if end_date:
query = query.filter(ChatMessage.created_at <= end_date) query = query.filter(ChatMessage.created_at <= end_date)
if group_id: if group_id:
group_users = ( group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
db.query(GroupMember.user_id)
.filter(GroupMember.group_id == group_id)
.subquery()
)
query = query.filter(ChatMessage.user_id.in_(group_users)) query = query.filter(ChatMessage.user_id.in_(group_users))
results = query.group_by(ChatMessage.user_id).all() results = query.group_by(ChatMessage.user_id).all()
return { return {
row.user_id: { row.user_id: {
"input_tokens": row.input_tokens, 'input_tokens': row.input_tokens,
"output_tokens": row.output_tokens, 'output_tokens': row.output_tokens,
"total_tokens": row.input_tokens + row.output_tokens, 'total_tokens': row.input_tokens + row.output_tokens,
"message_count": row.message_count, 'message_count': row.message_count,
} }
for row in results for row in results
} }
@@ -497,20 +454,16 @@ class ChatMessageTable:
from sqlalchemy import func from sqlalchemy import func
from open_webui.models.groups import GroupMember from open_webui.models.groups import GroupMember
query = db.query( query = db.query(ChatMessage.user_id, func.count(ChatMessage.id).label('count')).filter(
ChatMessage.user_id, func.count(ChatMessage.id).label("count") ~ChatMessage.user_id.like('shared-%')
).filter(~ChatMessage.user_id.like("shared-%")) )
if start_date: if start_date:
query = query.filter(ChatMessage.created_at >= start_date) query = query.filter(ChatMessage.created_at >= start_date)
if end_date: if end_date:
query = query.filter(ChatMessage.created_at <= end_date) query = query.filter(ChatMessage.created_at <= end_date)
if group_id: if group_id:
group_users = ( group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
db.query(GroupMember.user_id)
.filter(GroupMember.group_id == group_id)
.subquery()
)
query = query.filter(ChatMessage.user_id.in_(group_users)) query = query.filter(ChatMessage.user_id.in_(group_users))
results = query.group_by(ChatMessage.user_id).all() results = query.group_by(ChatMessage.user_id).all()
@@ -527,20 +480,16 @@ class ChatMessageTable:
from sqlalchemy import func from sqlalchemy import func
from open_webui.models.groups import GroupMember from open_webui.models.groups import GroupMember
query = db.query( query = db.query(ChatMessage.chat_id, func.count(ChatMessage.id).label('count')).filter(
ChatMessage.chat_id, func.count(ChatMessage.id).label("count") ~ChatMessage.user_id.like('shared-%')
).filter(~ChatMessage.user_id.like("shared-%")) )
if start_date: if start_date:
query = query.filter(ChatMessage.created_at >= start_date) query = query.filter(ChatMessage.created_at >= start_date)
if end_date: if end_date:
query = query.filter(ChatMessage.created_at <= end_date) query = query.filter(ChatMessage.created_at <= end_date)
if group_id: if group_id:
group_users = ( group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
db.query(GroupMember.user_id)
.filter(GroupMember.group_id == group_id)
.subquery()
)
query = query.filter(ChatMessage.user_id.in_(group_users)) query = query.filter(ChatMessage.user_id.in_(group_users))
results = query.group_by(ChatMessage.chat_id).all() results = query.group_by(ChatMessage.chat_id).all()
@@ -559,9 +508,9 @@ class ChatMessageTable:
from open_webui.models.groups import GroupMember from open_webui.models.groups import GroupMember
query = db.query(ChatMessage.created_at, ChatMessage.model_id).filter( query = db.query(ChatMessage.created_at, ChatMessage.model_id).filter(
ChatMessage.role == "assistant", ChatMessage.role == 'assistant',
ChatMessage.model_id.isnot(None), ChatMessage.model_id.isnot(None),
~ChatMessage.user_id.like("shared-%"), ~ChatMessage.user_id.like('shared-%'),
) )
if start_date: if start_date:
@@ -569,11 +518,7 @@ class ChatMessageTable:
if end_date: if end_date:
query = query.filter(ChatMessage.created_at <= end_date) query = query.filter(ChatMessage.created_at <= end_date)
if group_id: if group_id:
group_users = ( group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
db.query(GroupMember.user_id)
.filter(GroupMember.group_id == group_id)
.subquery()
)
query = query.filter(ChatMessage.user_id.in_(group_users)) query = query.filter(ChatMessage.user_id.in_(group_users))
results = query.all() results = query.all()
@@ -581,21 +526,17 @@ class ChatMessageTable:
# Group by date -> model -> count # Group by date -> model -> count
daily_counts: dict[str, dict[str, int]] = {} daily_counts: dict[str, dict[str, int]] = {}
for timestamp, model_id in results: for timestamp, model_id in results:
date_str = datetime.fromtimestamp( date_str = datetime.fromtimestamp(_normalize_timestamp(timestamp)).strftime('%Y-%m-%d')
_normalize_timestamp(timestamp)
).strftime("%Y-%m-%d")
if date_str not in daily_counts: if date_str not in daily_counts:
daily_counts[date_str] = {} daily_counts[date_str] = {}
daily_counts[date_str][model_id] = ( daily_counts[date_str][model_id] = daily_counts[date_str].get(model_id, 0) + 1
daily_counts[date_str].get(model_id, 0) + 1
)
# Fill in missing days # Fill in missing days
if start_date and end_date: if start_date and end_date:
current = datetime.fromtimestamp(_normalize_timestamp(start_date)) current = datetime.fromtimestamp(_normalize_timestamp(start_date))
end_dt = datetime.fromtimestamp(_normalize_timestamp(end_date)) end_dt = datetime.fromtimestamp(_normalize_timestamp(end_date))
while current <= end_dt: while current <= end_dt:
date_str = current.strftime("%Y-%m-%d") date_str = current.strftime('%Y-%m-%d')
if date_str not in daily_counts: if date_str not in daily_counts:
daily_counts[date_str] = {} daily_counts[date_str] = {}
current += timedelta(days=1) current += timedelta(days=1)
@@ -613,9 +554,9 @@ class ChatMessageTable:
from datetime import datetime, timedelta from datetime import datetime, timedelta
query = db.query(ChatMessage.created_at, ChatMessage.model_id).filter( query = db.query(ChatMessage.created_at, ChatMessage.model_id).filter(
ChatMessage.role == "assistant", ChatMessage.role == 'assistant',
ChatMessage.model_id.isnot(None), ChatMessage.model_id.isnot(None),
~ChatMessage.user_id.like("shared-%"), ~ChatMessage.user_id.like('shared-%'),
) )
if start_date: if start_date:
@@ -628,23 +569,19 @@ class ChatMessageTable:
# Group by hour -> model -> count # Group by hour -> model -> count
hourly_counts: dict[str, dict[str, int]] = {} hourly_counts: dict[str, dict[str, int]] = {}
for timestamp, model_id in results: for timestamp, model_id in results:
hour_str = datetime.fromtimestamp( hour_str = datetime.fromtimestamp(_normalize_timestamp(timestamp)).strftime('%Y-%m-%d %H:00')
_normalize_timestamp(timestamp)
).strftime("%Y-%m-%d %H:00")
if hour_str not in hourly_counts: if hour_str not in hourly_counts:
hourly_counts[hour_str] = {} hourly_counts[hour_str] = {}
hourly_counts[hour_str][model_id] = ( hourly_counts[hour_str][model_id] = hourly_counts[hour_str].get(model_id, 0) + 1
hourly_counts[hour_str].get(model_id, 0) + 1
)
# Fill in missing hours # Fill in missing hours
if start_date and end_date: if start_date and end_date:
current = datetime.fromtimestamp( current = datetime.fromtimestamp(_normalize_timestamp(start_date)).replace(
_normalize_timestamp(start_date) minute=0, second=0, microsecond=0
).replace(minute=0, second=0, microsecond=0) )
end_dt = datetime.fromtimestamp(_normalize_timestamp(end_date)) end_dt = datetime.fromtimestamp(_normalize_timestamp(end_date))
while current <= end_dt: while current <= end_dt:
hour_str = current.strftime("%Y-%m-%d %H:00") hour_str = current.strftime('%Y-%m-%d %H:00')
if hour_str not in hourly_counts: if hour_str not in hourly_counts:
hourly_counts[hour_str] = {} hourly_counts[hour_str] = {}
current += timedelta(hours=1) current += timedelta(hours=1)

File diff suppressed because it is too large Load Diff

View File

@@ -19,7 +19,7 @@ log = logging.getLogger(__name__)
class Feedback(Base): class Feedback(Base):
__tablename__ = "feedback" __tablename__ = 'feedback'
id = Column(Text, primary_key=True, unique=True) id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text) user_id = Column(Text)
version = Column(BigInteger, default=0) version = Column(BigInteger, default=0)
@@ -81,7 +81,7 @@ class RatingData(BaseModel):
sibling_model_ids: Optional[list[str]] = None sibling_model_ids: Optional[list[str]] = None
reason: Optional[str] = None reason: Optional[str] = None
comment: Optional[str] = None comment: Optional[str] = None
model_config = ConfigDict(extra="allow", protected_namespaces=()) model_config = ConfigDict(extra='allow', protected_namespaces=())
class MetaData(BaseModel): class MetaData(BaseModel):
@@ -89,12 +89,12 @@ class MetaData(BaseModel):
chat_id: Optional[str] = None chat_id: Optional[str] = None
message_id: Optional[str] = None message_id: Optional[str] = None
tags: Optional[list[str]] = None tags: Optional[list[str]] = None
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra='allow')
class SnapshotData(BaseModel): class SnapshotData(BaseModel):
chat: Optional[dict] = None chat: Optional[dict] = None
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra='allow')
class FeedbackForm(BaseModel): class FeedbackForm(BaseModel):
@@ -102,14 +102,14 @@ class FeedbackForm(BaseModel):
data: Optional[RatingData] = None data: Optional[RatingData] = None
meta: Optional[dict] = None meta: Optional[dict] = None
snapshot: Optional[SnapshotData] = None snapshot: Optional[SnapshotData] = None
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra='allow')
class UserResponse(BaseModel): class UserResponse(BaseModel):
id: str id: str
name: str name: str
email: str email: str
role: str = "pending" role: str = 'pending'
last_active_at: int # timestamp in epoch last_active_at: int # timestamp in epoch
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
@@ -146,12 +146,12 @@ class FeedbackTable:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
feedback = FeedbackModel( feedback = FeedbackModel(
**{ **{
"id": id, 'id': id,
"user_id": user_id, 'user_id': user_id,
"version": 0, 'version': 0,
**form_data.model_dump(), **form_data.model_dump(),
"created_at": int(time.time()), 'created_at': int(time.time()),
"updated_at": int(time.time()), 'updated_at': int(time.time()),
} }
) )
try: try:
@@ -164,12 +164,10 @@ class FeedbackTable:
else: else:
return None return None
except Exception as e: except Exception as e:
log.exception(f"Error creating a new feedback: {e}") log.exception(f'Error creating a new feedback: {e}')
return None return None
def get_feedback_by_id( def get_feedback_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FeedbackModel]:
self, id: str, db: Optional[Session] = None
) -> Optional[FeedbackModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
feedback = db.query(Feedback).filter_by(id=id).first() feedback = db.query(Feedback).filter_by(id=id).first()
@@ -191,16 +189,14 @@ class FeedbackTable:
except Exception: except Exception:
return None return None
def get_feedbacks_by_chat_id( def get_feedbacks_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> list[FeedbackModel]:
self, chat_id: str, db: Optional[Session] = None
) -> list[FeedbackModel]:
"""Get all feedbacks for a specific chat.""" """Get all feedbacks for a specific chat."""
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
# meta.chat_id stores the chat reference # meta.chat_id stores the chat reference
feedbacks = ( feedbacks = (
db.query(Feedback) db.query(Feedback)
.filter(Feedback.meta["chat_id"].as_string() == chat_id) .filter(Feedback.meta['chat_id'].as_string() == chat_id)
.order_by(Feedback.created_at.desc()) .order_by(Feedback.created_at.desc())
.all() .all()
) )
@@ -219,36 +215,28 @@ class FeedbackTable:
query = db.query(Feedback, User).join(User, Feedback.user_id == User.id) query = db.query(Feedback, User).join(User, Feedback.user_id == User.id)
if filter: if filter:
order_by = filter.get("order_by") order_by = filter.get('order_by')
direction = filter.get("direction") direction = filter.get('direction')
if order_by == "username": if order_by == 'username':
if direction == "asc": if direction == 'asc':
query = query.order_by(User.name.asc()) query = query.order_by(User.name.asc())
else: else:
query = query.order_by(User.name.desc()) query = query.order_by(User.name.desc())
elif order_by == "model_id": elif order_by == 'model_id':
# it's stored in feedback.data['model_id'] # it's stored in feedback.data['model_id']
if direction == "asc": if direction == 'asc':
query = query.order_by( query = query.order_by(Feedback.data['model_id'].as_string().asc())
Feedback.data["model_id"].as_string().asc()
)
else: else:
query = query.order_by( query = query.order_by(Feedback.data['model_id'].as_string().desc())
Feedback.data["model_id"].as_string().desc() elif order_by == 'rating':
)
elif order_by == "rating":
# it's stored in feedback.data['rating'] # it's stored in feedback.data['rating']
if direction == "asc": if direction == 'asc':
query = query.order_by( query = query.order_by(Feedback.data['rating'].as_string().asc())
Feedback.data["rating"].as_string().asc()
)
else: else:
query = query.order_by( query = query.order_by(Feedback.data['rating'].as_string().desc())
Feedback.data["rating"].as_string().desc() elif order_by == 'updated_at':
) if direction == 'asc':
elif order_by == "updated_at":
if direction == "asc":
query = query.order_by(Feedback.updated_at.asc()) query = query.order_by(Feedback.updated_at.asc())
else: else:
query = query.order_by(Feedback.updated_at.desc()) query = query.order_by(Feedback.updated_at.desc())
@@ -270,9 +258,7 @@ class FeedbackTable:
for feedback, user in items: for feedback, user in items:
feedback_model = FeedbackModel.model_validate(feedback) feedback_model = FeedbackModel.model_validate(feedback)
user_model = UserResponse.model_validate(user) user_model = UserResponse.model_validate(user)
feedbacks.append( feedbacks.append(FeedbackUserResponse(**feedback_model.model_dump(), user=user_model))
FeedbackUserResponse(**feedback_model.model_dump(), user=user_model)
)
return FeedbackListResponse(items=feedbacks, total=total) return FeedbackListResponse(items=feedbacks, total=total)
@@ -280,14 +266,10 @@ class FeedbackTable:
with get_db_context(db) as db: with get_db_context(db) as db:
return [ return [
FeedbackModel.model_validate(feedback) FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback) for feedback in db.query(Feedback).order_by(Feedback.updated_at.desc()).all()
.order_by(Feedback.updated_at.desc())
.all()
] ]
def get_all_feedback_ids( def get_all_feedback_ids(self, db: Optional[Session] = None) -> list[FeedbackIdResponse]:
self, db: Optional[Session] = None
) -> list[FeedbackIdResponse]:
with get_db_context(db) as db: with get_db_context(db) as db:
return [ return [
FeedbackIdResponse( FeedbackIdResponse(
@@ -306,14 +288,11 @@ class FeedbackTable:
.all() .all()
] ]
def get_feedbacks_for_leaderboard( def get_feedbacks_for_leaderboard(self, db: Optional[Session] = None) -> list[LeaderboardFeedbackData]:
self, db: Optional[Session] = None
) -> list[LeaderboardFeedbackData]:
"""Fetch only id and data for leaderboard computation (excludes snapshot/meta).""" """Fetch only id and data for leaderboard computation (excludes snapshot/meta)."""
with get_db_context(db) as db: with get_db_context(db) as db:
return [ return [
LeaderboardFeedbackData(id=row.id, data=row.data) LeaderboardFeedbackData(id=row.id, data=row.data) for row in db.query(Feedback.id, Feedback.data).all()
for row in db.query(Feedback.id, Feedback.data).all()
] ]
def get_model_evaluation_history( def get_model_evaluation_history(
@@ -333,30 +312,26 @@ class FeedbackTable:
rows = db.query(Feedback.created_at, Feedback.data).all() rows = db.query(Feedback.created_at, Feedback.data).all()
else: else:
cutoff = int(time.time()) - (days * 86400) cutoff = int(time.time()) - (days * 86400)
rows = ( rows = db.query(Feedback.created_at, Feedback.data).filter(Feedback.created_at >= cutoff).all()
db.query(Feedback.created_at, Feedback.data)
.filter(Feedback.created_at >= cutoff)
.all()
)
daily_counts = defaultdict(lambda: {"won": 0, "lost": 0}) daily_counts = defaultdict(lambda: {'won': 0, 'lost': 0})
first_date = None first_date = None
for created_at, data in rows: for created_at, data in rows:
if not data: if not data:
continue continue
if data.get("model_id") != model_id: if data.get('model_id') != model_id:
continue continue
rating_str = str(data.get("rating", "")) rating_str = str(data.get('rating', ''))
if rating_str not in ("1", "-1"): if rating_str not in ('1', '-1'):
continue continue
date_str = datetime.fromtimestamp(created_at).strftime("%Y-%m-%d") date_str = datetime.fromtimestamp(created_at).strftime('%Y-%m-%d')
if rating_str == "1": if rating_str == '1':
daily_counts[date_str]["won"] += 1 daily_counts[date_str]['won'] += 1
else: else:
daily_counts[date_str]["lost"] += 1 daily_counts[date_str]['lost'] += 1
# Track first date for this model # Track first date for this model
if first_date is None or date_str < first_date: if first_date is None or date_str < first_date:
@@ -368,7 +343,7 @@ class FeedbackTable:
if days == 0 and first_date: if days == 0 and first_date:
# All time: start from first feedback date # All time: start from first feedback date
start_date = datetime.strptime(first_date, "%Y-%m-%d").date() start_date = datetime.strptime(first_date, '%Y-%m-%d').date()
num_days = (today - start_date).days + 1 num_days = (today - start_date).days + 1
else: else:
# Fixed range # Fixed range
@@ -377,36 +352,24 @@ class FeedbackTable:
for i in range(num_days): for i in range(num_days):
d = start_date + timedelta(days=i) d = start_date + timedelta(days=i)
date_str = d.strftime("%Y-%m-%d") date_str = d.strftime('%Y-%m-%d')
counts = daily_counts.get(date_str, {"won": 0, "lost": 0}) counts = daily_counts.get(date_str, {'won': 0, 'lost': 0})
result.append( result.append(ModelHistoryEntry(date=date_str, won=counts['won'], lost=counts['lost']))
ModelHistoryEntry(date=date_str, won=counts["won"], lost=counts["lost"])
)
return result return result
def get_feedbacks_by_type( def get_feedbacks_by_type(self, type: str, db: Optional[Session] = None) -> list[FeedbackModel]:
self, type: str, db: Optional[Session] = None
) -> list[FeedbackModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
return [ return [
FeedbackModel.model_validate(feedback) FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback) for feedback in db.query(Feedback).filter_by(type=type).order_by(Feedback.updated_at.desc()).all()
.filter_by(type=type)
.order_by(Feedback.updated_at.desc())
.all()
] ]
def get_feedbacks_by_user_id( def get_feedbacks_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FeedbackModel]:
self, user_id: str, db: Optional[Session] = None
) -> list[FeedbackModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
return [ return [
FeedbackModel.model_validate(feedback) FeedbackModel.model_validate(feedback)
for feedback in db.query(Feedback) for feedback in db.query(Feedback).filter_by(user_id=user_id).order_by(Feedback.updated_at.desc()).all()
.filter_by(user_id=user_id)
.order_by(Feedback.updated_at.desc())
.all()
] ]
def update_feedback_by_id( def update_feedback_by_id(
@@ -462,9 +425,7 @@ class FeedbackTable:
db.commit() db.commit()
return True return True
def delete_feedback_by_id_and_user_id( def delete_feedback_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool:
self, id: str, user_id: str, db: Optional[Session] = None
) -> bool:
with get_db_context(db) as db: with get_db_context(db) as db:
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first() feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
if not feedback: if not feedback:
@@ -473,9 +434,7 @@ class FeedbackTable:
db.commit() db.commit()
return True return True
def delete_feedbacks_by_user_id( def delete_feedbacks_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
self, user_id: str, db: Optional[Session] = None
) -> bool:
with get_db_context(db) as db: with get_db_context(db) as db:
result = db.query(Feedback).filter_by(user_id=user_id).delete() result = db.query(Feedback).filter_by(user_id=user_id).delete()
db.commit() db.commit()

View File

@@ -16,7 +16,7 @@ log = logging.getLogger(__name__)
class File(Base): class File(Base):
__tablename__ = "file" __tablename__ = 'file'
id = Column(String, primary_key=True, unique=True) id = Column(String, primary_key=True, unique=True)
user_id = Column(String) user_id = Column(String)
hash = Column(Text, nullable=True) hash = Column(Text, nullable=True)
@@ -58,9 +58,9 @@ class FileMeta(BaseModel):
content_type: Optional[str] = None content_type: Optional[str] = None
size: Optional[int] = None size: Optional[int] = None
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra='allow')
@model_validator(mode="before") @model_validator(mode='before')
@classmethod @classmethod
def sanitize_meta(cls, data): def sanitize_meta(cls, data):
"""Sanitize metadata fields to handle malformed legacy data.""" """Sanitize metadata fields to handle malformed legacy data."""
@@ -68,14 +68,12 @@ class FileMeta(BaseModel):
return data return data
# Handle content_type that may be a list like ['application/pdf', None] # Handle content_type that may be a list like ['application/pdf', None]
content_type = data.get("content_type") content_type = data.get('content_type')
if isinstance(content_type, list): if isinstance(content_type, list):
# Extract first non-None string value # Extract first non-None string value
data["content_type"] = next( data['content_type'] = next((item for item in content_type if isinstance(item, str)), None)
(item for item in content_type if isinstance(item, str)), None
)
elif content_type is not None and not isinstance(content_type, str): elif content_type is not None and not isinstance(content_type, str):
data["content_type"] = None data['content_type'] = None
return data return data
@@ -92,7 +90,7 @@ class FileModelResponse(BaseModel):
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
updated_at: Optional[int] = None # timestamp in epoch, optional for legacy files updated_at: Optional[int] = None # timestamp in epoch, optional for legacy files
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra='allow')
class FileMetadataResponse(BaseModel): class FileMetadataResponse(BaseModel):
@@ -123,25 +121,22 @@ class FileUpdateForm(BaseModel):
meta: Optional[dict] = None meta: Optional[dict] = None
class FilesTable: class FilesTable:
def insert_new_file( def insert_new_file(self, user_id: str, form_data: FileForm, db: Optional[Session] = None) -> Optional[FileModel]:
self, user_id: str, form_data: FileForm, db: Optional[Session] = None
) -> Optional[FileModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
file_data = form_data.model_dump() file_data = form_data.model_dump()
# Sanitize meta to remove non-JSON-serializable objects # Sanitize meta to remove non-JSON-serializable objects
# (e.g. callable tool functions, MCP client instances from middleware) # (e.g. callable tool functions, MCP client instances from middleware)
if file_data.get("meta"): if file_data.get('meta'):
file_data["meta"] = sanitize_metadata(file_data["meta"]) file_data['meta'] = sanitize_metadata(file_data['meta'])
file = FileModel( file = FileModel(
**{ **{
**file_data, **file_data,
"user_id": user_id, 'user_id': user_id,
"created_at": int(time.time()), 'created_at': int(time.time()),
"updated_at": int(time.time()), 'updated_at': int(time.time()),
} }
) )
@@ -155,12 +150,10 @@ class FilesTable:
else: else:
return None return None
except Exception as e: except Exception as e:
log.exception(f"Error inserting a new file: {e}") log.exception(f'Error inserting a new file: {e}')
return None return None
def get_file_by_id( def get_file_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FileModel]:
self, id: str, db: Optional[Session] = None
) -> Optional[FileModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
try: try:
@@ -171,9 +164,7 @@ class FilesTable:
except Exception: except Exception:
return None return None
def get_file_by_id_and_user_id( def get_file_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[FileModel]:
self, id: str, user_id: str, db: Optional[Session] = None
) -> Optional[FileModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
try: try:
file = db.query(File).filter_by(id=id, user_id=user_id).first() file = db.query(File).filter_by(id=id, user_id=user_id).first()
@@ -184,9 +175,7 @@ class FilesTable:
except Exception: except Exception:
return None return None
def get_file_metadata_by_id( def get_file_metadata_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FileMetadataResponse]:
self, id: str, db: Optional[Session] = None
) -> Optional[FileMetadataResponse]:
with get_db_context(db) as db: with get_db_context(db) as db:
try: try:
file = db.get(File, id) file = db.get(File, id)
@@ -204,9 +193,7 @@ class FilesTable:
with get_db_context(db) as db: with get_db_context(db) as db:
return [FileModel.model_validate(file) for file in db.query(File).all()] return [FileModel.model_validate(file) for file in db.query(File).all()]
def check_access_by_user_id( def check_access_by_user_id(self, id, user_id, permission='write', db: Optional[Session] = None) -> bool:
self, id, user_id, permission="write", db: Optional[Session] = None
) -> bool:
file = self.get_file_by_id(id, db=db) file = self.get_file_by_id(id, db=db)
if not file: if not file:
return False return False
@@ -215,21 +202,14 @@ class FilesTable:
# Implement additional access control logic here as needed # Implement additional access control logic here as needed
return False return False
def get_files_by_ids( def get_files_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[FileModel]:
self, ids: list[str], db: Optional[Session] = None
) -> list[FileModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
return [ return [
FileModel.model_validate(file) FileModel.model_validate(file)
for file in db.query(File) for file in db.query(File).filter(File.id.in_(ids)).order_by(File.updated_at.desc()).all()
.filter(File.id.in_(ids))
.order_by(File.updated_at.desc())
.all()
] ]
def get_file_metadatas_by_ids( def get_file_metadatas_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[FileMetadataResponse]:
self, ids: list[str], db: Optional[Session] = None
) -> list[FileMetadataResponse]:
with get_db_context(db) as db: with get_db_context(db) as db:
return [ return [
FileMetadataResponse( FileMetadataResponse(
@@ -239,22 +219,15 @@ class FilesTable:
created_at=file.created_at, created_at=file.created_at,
updated_at=file.updated_at, updated_at=file.updated_at,
) )
for file in db.query( for file in db.query(File.id, File.hash, File.meta, File.created_at, File.updated_at)
File.id, File.hash, File.meta, File.created_at, File.updated_at
)
.filter(File.id.in_(ids)) .filter(File.id.in_(ids))
.order_by(File.updated_at.desc()) .order_by(File.updated_at.desc())
.all() .all()
] ]
def get_files_by_user_id( def get_files_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FileModel]:
self, user_id: str, db: Optional[Session] = None
) -> list[FileModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
return [ return [FileModel.model_validate(file) for file in db.query(File).filter_by(user_id=user_id).all()]
FileModel.model_validate(file)
for file in db.query(File).filter_by(user_id=user_id).all()
]
def get_file_list( def get_file_list(
self, self,
@@ -262,7 +235,7 @@ class FilesTable:
skip: int = 0, skip: int = 0,
limit: int = 50, limit: int = 50,
db: Optional[Session] = None, db: Optional[Session] = None,
) -> "FileListResponse": ) -> 'FileListResponse':
with get_db_context(db) as db: with get_db_context(db) as db:
query = db.query(File) query = db.query(File)
if user_id: if user_id:
@@ -272,10 +245,7 @@ class FilesTable:
items = [ items = [
FileModel.model_validate(file) FileModel.model_validate(file)
for file in query.order_by(File.updated_at.desc(), File.id.desc()) for file in query.order_by(File.updated_at.desc(), File.id.desc()).offset(skip).limit(limit).all()
.offset(skip)
.limit(limit)
.all()
] ]
return FileListResponse(items=items, total=total) return FileListResponse(items=items, total=total)
@@ -296,17 +266,17 @@ class FilesTable:
A SQL LIKE compatible pattern with proper escaping. A SQL LIKE compatible pattern with proper escaping.
""" """
# Escape SQL special characters first, then convert glob wildcards # Escape SQL special characters first, then convert glob wildcards
pattern = glob.replace("\\", "\\\\") pattern = glob.replace('\\', '\\\\')
pattern = pattern.replace("%", "\\%") pattern = pattern.replace('%', '\\%')
pattern = pattern.replace("_", "\\_") pattern = pattern.replace('_', '\\_')
pattern = pattern.replace("*", "%") pattern = pattern.replace('*', '%')
pattern = pattern.replace("?", "_") pattern = pattern.replace('?', '_')
return pattern return pattern
def search_files( def search_files(
self, self,
user_id: Optional[str] = None, user_id: Optional[str] = None,
filename: str = "*", filename: str = '*',
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
db: Optional[Session] = None, db: Optional[Session] = None,
@@ -331,15 +301,12 @@ class FilesTable:
query = query.filter_by(user_id=user_id) query = query.filter_by(user_id=user_id)
pattern = self._glob_to_like_pattern(filename) pattern = self._glob_to_like_pattern(filename)
if pattern != "%": if pattern != '%':
query = query.filter(File.filename.ilike(pattern, escape="\\")) query = query.filter(File.filename.ilike(pattern, escape='\\'))
return [ return [
FileModel.model_validate(file) FileModel.model_validate(file)
for file in query.order_by(File.created_at.desc(), File.id.desc()) for file in query.order_by(File.created_at.desc(), File.id.desc()).offset(skip).limit(limit).all()
.offset(skip)
.limit(limit)
.all()
] ]
def update_file_by_id( def update_file_by_id(
@@ -362,12 +329,10 @@ class FilesTable:
db.commit() db.commit()
return FileModel.model_validate(file) return FileModel.model_validate(file)
except Exception as e: except Exception as e:
log.exception(f"Error updating file completely by id: {e}") log.exception(f'Error updating file completely by id: {e}')
return None return None
def update_file_hash_by_id( def update_file_hash_by_id(self, id: str, hash: Optional[str], db: Optional[Session] = None) -> Optional[FileModel]:
self, id: str, hash: Optional[str], db: Optional[Session] = None
) -> Optional[FileModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
try: try:
file = db.query(File).filter_by(id=id).first() file = db.query(File).filter_by(id=id).first()
@@ -379,9 +344,7 @@ class FilesTable:
except Exception: except Exception:
return None return None
def update_file_data_by_id( def update_file_data_by_id(self, id: str, data: dict, db: Optional[Session] = None) -> Optional[FileModel]:
self, id: str, data: dict, db: Optional[Session] = None
) -> Optional[FileModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
try: try:
file = db.query(File).filter_by(id=id).first() file = db.query(File).filter_by(id=id).first()
@@ -390,12 +353,9 @@ class FilesTable:
db.commit() db.commit()
return FileModel.model_validate(file) return FileModel.model_validate(file)
except Exception as e: except Exception as e:
return None return None
def update_file_metadata_by_id( def update_file_metadata_by_id(self, id: str, meta: dict, db: Optional[Session] = None) -> Optional[FileModel]:
self, id: str, meta: dict, db: Optional[Session] = None
) -> Optional[FileModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
try: try:
file = db.query(File).filter_by(id=id).first() file = db.query(File).filter_by(id=id).first()

View File

@@ -20,7 +20,7 @@ log = logging.getLogger(__name__)
class Folder(Base): class Folder(Base):
__tablename__ = "folder" __tablename__ = 'folder'
id = Column(Text, primary_key=True, unique=True) id = Column(Text, primary_key=True, unique=True)
parent_id = Column(Text, nullable=True) parent_id = Column(Text, nullable=True)
user_id = Column(Text) user_id = Column(Text)
@@ -72,14 +72,14 @@ class FolderForm(BaseModel):
data: Optional[dict] = None data: Optional[dict] = None
meta: Optional[dict] = None meta: Optional[dict] = None
parent_id: Optional[str] = None parent_id: Optional[str] = None
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra='allow')
class FolderUpdateForm(BaseModel): class FolderUpdateForm(BaseModel):
name: Optional[str] = None name: Optional[str] = None
data: Optional[dict] = None data: Optional[dict] = None
meta: Optional[dict] = None meta: Optional[dict] = None
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra='allow')
class FolderTable: class FolderTable:
@@ -94,12 +94,12 @@ class FolderTable:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
folder = FolderModel( folder = FolderModel(
**{ **{
"id": id, 'id': id,
"user_id": user_id, 'user_id': user_id,
**(form_data.model_dump(exclude_unset=True) or {}), **(form_data.model_dump(exclude_unset=True) or {}),
"parent_id": parent_id, 'parent_id': parent_id,
"created_at": int(time.time()), 'created_at': int(time.time()),
"updated_at": int(time.time()), 'updated_at': int(time.time()),
} }
) )
try: try:
@@ -112,7 +112,7 @@ class FolderTable:
else: else:
return None return None
except Exception as e: except Exception as e:
log.exception(f"Error inserting a new folder: {e}") log.exception(f'Error inserting a new folder: {e}')
return None return None
def get_folder_by_id_and_user_id( def get_folder_by_id_and_user_id(
@@ -137,9 +137,7 @@ class FolderTable:
folders = [] folders = []
def get_children(folder): def get_children(folder):
children = self.get_folders_by_parent_id_and_user_id( children = self.get_folders_by_parent_id_and_user_id(folder.id, user_id, db=db)
folder.id, user_id, db=db
)
for child in children: for child in children:
get_children(child) get_children(child)
folders.append(child) folders.append(child)
@@ -153,14 +151,9 @@ class FolderTable:
except Exception: except Exception:
return None return None
def get_folders_by_user_id( def get_folders_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FolderModel]:
self, user_id: str, db: Optional[Session] = None
) -> list[FolderModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
return [ return [FolderModel.model_validate(folder) for folder in db.query(Folder).filter_by(user_id=user_id).all()]
FolderModel.model_validate(folder)
for folder in db.query(Folder).filter_by(user_id=user_id).all()
]
def get_folder_by_parent_id_and_user_id_and_name( def get_folder_by_parent_id_and_user_id_and_name(
self, self,
@@ -184,7 +177,7 @@ class FolderTable:
return FolderModel.model_validate(folder) return FolderModel.model_validate(folder)
except Exception as e: except Exception as e:
log.error(f"get_folder_by_parent_id_and_user_id_and_name: {e}") log.error(f'get_folder_by_parent_id_and_user_id_and_name: {e}')
return None return None
def get_folders_by_parent_id_and_user_id( def get_folders_by_parent_id_and_user_id(
@@ -193,9 +186,7 @@ class FolderTable:
with get_db_context(db) as db: with get_db_context(db) as db:
return [ return [
FolderModel.model_validate(folder) FolderModel.model_validate(folder)
for folder in db.query(Folder) for folder in db.query(Folder).filter_by(parent_id=parent_id, user_id=user_id).all()
.filter_by(parent_id=parent_id, user_id=user_id)
.all()
] ]
def update_folder_parent_id_by_id_and_user_id( def update_folder_parent_id_by_id_and_user_id(
@@ -219,7 +210,7 @@ class FolderTable:
return FolderModel.model_validate(folder) return FolderModel.model_validate(folder)
except Exception as e: except Exception as e:
log.error(f"update_folder: {e}") log.error(f'update_folder: {e}')
return return
def update_folder_by_id_and_user_id( def update_folder_by_id_and_user_id(
@@ -241,7 +232,7 @@ class FolderTable:
existing_folder = ( existing_folder = (
db.query(Folder) db.query(Folder)
.filter_by( .filter_by(
name=form_data.get("name"), name=form_data.get('name'),
parent_id=folder.parent_id, parent_id=folder.parent_id,
user_id=user_id, user_id=user_id,
) )
@@ -251,17 +242,17 @@ class FolderTable:
if existing_folder and existing_folder.id != id: if existing_folder and existing_folder.id != id:
return None return None
folder.name = form_data.get("name", folder.name) folder.name = form_data.get('name', folder.name)
if "data" in form_data: if 'data' in form_data:
folder.data = { folder.data = {
**(folder.data or {}), **(folder.data or {}),
**form_data["data"], **form_data['data'],
} }
if "meta" in form_data: if 'meta' in form_data:
folder.meta = { folder.meta = {
**(folder.meta or {}), **(folder.meta or {}),
**form_data["meta"], **form_data['meta'],
} }
folder.updated_at = int(time.time()) folder.updated_at = int(time.time())
@@ -269,7 +260,7 @@ class FolderTable:
return FolderModel.model_validate(folder) return FolderModel.model_validate(folder)
except Exception as e: except Exception as e:
log.error(f"update_folder: {e}") log.error(f'update_folder: {e}')
return return
def update_folder_is_expanded_by_id_and_user_id( def update_folder_is_expanded_by_id_and_user_id(
@@ -289,12 +280,10 @@ class FolderTable:
return FolderModel.model_validate(folder) return FolderModel.model_validate(folder)
except Exception as e: except Exception as e:
log.error(f"update_folder: {e}") log.error(f'update_folder: {e}')
return return
def delete_folder_by_id_and_user_id( def delete_folder_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> list[str]:
self, id: str, user_id: str, db: Optional[Session] = None
) -> list[str]:
try: try:
folder_ids = [] folder_ids = []
with get_db_context(db) as db: with get_db_context(db) as db:
@@ -306,11 +295,8 @@ class FolderTable:
# Delete all children folders # Delete all children folders
def delete_children(folder): def delete_children(folder):
folder_children = self.get_folders_by_parent_id_and_user_id( folder_children = self.get_folders_by_parent_id_and_user_id(folder.id, user_id, db=db)
folder.id, user_id, db=db
)
for folder_child in folder_children: for folder_child in folder_children:
delete_children(folder_child) delete_children(folder_child)
folder_ids.append(folder_child.id) folder_ids.append(folder_child.id)
@@ -323,12 +309,12 @@ class FolderTable:
db.commit() db.commit()
return folder_ids return folder_ids
except Exception as e: except Exception as e:
log.error(f"delete_folder: {e}") log.error(f'delete_folder: {e}')
return [] return []
def normalize_folder_name(self, name: str) -> str: def normalize_folder_name(self, name: str) -> str:
# Replace _ and space with a single space, lower case, collapse multiple spaces # Replace _ and space with a single space, lower case, collapse multiple spaces
name = re.sub(r"[\s_]+", " ", name) name = re.sub(r'[\s_]+', ' ', name)
return name.strip().lower() return name.strip().lower()
def search_folders_by_names( def search_folders_by_names(
@@ -349,9 +335,7 @@ class FolderTable:
results[folder.id] = FolderModel.model_validate(folder) results[folder.id] = FolderModel.model_validate(folder)
# get children folders # get children folders
children = self.get_children_folders_by_id_and_user_id( children = self.get_children_folders_by_id_and_user_id(folder.id, user_id, db=db)
folder.id, user_id, db=db
)
for child in children: for child in children:
results[child.id] = child results[child.id] = child

View File

@@ -16,7 +16,7 @@ log = logging.getLogger(__name__)
class Function(Base): class Function(Base):
__tablename__ = "function" __tablename__ = 'function'
id = Column(String, primary_key=True, unique=True) id = Column(String, primary_key=True, unique=True)
user_id = Column(String) user_id = Column(String)
@@ -30,13 +30,13 @@ class Function(Base):
updated_at = Column(BigInteger) updated_at = Column(BigInteger)
created_at = Column(BigInteger) created_at = Column(BigInteger)
__table_args__ = (Index("is_global_idx", "is_global"),) __table_args__ = (Index('is_global_idx', 'is_global'),)
class FunctionMeta(BaseModel): class FunctionMeta(BaseModel):
description: Optional[str] = None description: Optional[str] = None
manifest: Optional[dict] = {} manifest: Optional[dict] = {}
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra='allow')
class FunctionModel(BaseModel): class FunctionModel(BaseModel):
@@ -113,10 +113,10 @@ class FunctionsTable:
function = FunctionModel( function = FunctionModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
"user_id": user_id, 'user_id': user_id,
"type": type, 'type': type,
"updated_at": int(time.time()), 'updated_at': int(time.time()),
"created_at": int(time.time()), 'created_at': int(time.time()),
} }
) )
@@ -131,7 +131,7 @@ class FunctionsTable:
else: else:
return None return None
except Exception as e: except Exception as e:
log.exception(f"Error creating a new function: {e}") log.exception(f'Error creating a new function: {e}')
return None return None
def sync_functions( def sync_functions(
@@ -156,16 +156,16 @@ class FunctionsTable:
db.query(Function).filter_by(id=func.id).update( db.query(Function).filter_by(id=func.id).update(
{ {
**func.model_dump(), **func.model_dump(),
"user_id": user_id, 'user_id': user_id,
"updated_at": int(time.time()), 'updated_at': int(time.time()),
} }
) )
else: else:
new_func = Function( new_func = Function(
**{ **{
**func.model_dump(), **func.model_dump(),
"user_id": user_id, 'user_id': user_id,
"updated_at": int(time.time()), 'updated_at': int(time.time()),
} }
) )
db.add(new_func) db.add(new_func)
@@ -177,17 +177,12 @@ class FunctionsTable:
db.commit() db.commit()
return [ return [FunctionModel.model_validate(func) for func in db.query(Function).all()]
FunctionModel.model_validate(func)
for func in db.query(Function).all()
]
except Exception as e: except Exception as e:
log.exception(f"Error syncing functions for user {user_id}: {e}") log.exception(f'Error syncing functions for user {user_id}: {e}')
return [] return []
def get_function_by_id( def get_function_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FunctionModel]:
self, id: str, db: Optional[Session] = None
) -> Optional[FunctionModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
function = db.get(Function, id) function = db.get(Function, id)
@@ -195,9 +190,7 @@ class FunctionsTable:
except Exception: except Exception:
return None return None
def get_functions_by_ids( def get_functions_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[FunctionModel]:
self, ids: list[str], db: Optional[Session] = None
) -> list[FunctionModel]:
""" """
Batch fetch multiple functions by their IDs in a single query. Batch fetch multiple functions by their IDs in a single query.
Returns functions in the same order as the input IDs (None entries filtered out). Returns functions in the same order as the input IDs (None entries filtered out).
@@ -225,18 +218,11 @@ class FunctionsTable:
functions = db.query(Function).all() functions = db.query(Function).all()
if include_valves: if include_valves:
return [ return [FunctionWithValvesModel.model_validate(function) for function in functions]
FunctionWithValvesModel.model_validate(function)
for function in functions
]
else: else:
return [ return [FunctionModel.model_validate(function) for function in functions]
FunctionModel.model_validate(function) for function in functions
]
def get_function_list( def get_function_list(self, db: Optional[Session] = None) -> list[FunctionUserResponse]:
self, db: Optional[Session] = None
) -> list[FunctionUserResponse]:
with get_db_context(db) as db: with get_db_context(db) as db:
functions = db.query(Function).order_by(Function.updated_at.desc()).all() functions = db.query(Function).order_by(Function.updated_at.desc()).all()
user_ids = list(set(func.user_id for func in functions)) user_ids = list(set(func.user_id for func in functions))
@@ -248,69 +234,48 @@ class FunctionsTable:
FunctionUserResponse.model_validate( FunctionUserResponse.model_validate(
{ {
**FunctionModel.model_validate(func).model_dump(), **FunctionModel.model_validate(func).model_dump(),
"user": ( 'user': (users_dict.get(func.user_id).model_dump() if func.user_id in users_dict else None),
users_dict.get(func.user_id).model_dump()
if func.user_id in users_dict
else None
),
} }
) )
for func in functions for func in functions
] ]
def get_functions_by_type( def get_functions_by_type(self, type: str, active_only=False, db: Optional[Session] = None) -> list[FunctionModel]:
self, type: str, active_only=False, db: Optional[Session] = None
) -> list[FunctionModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
if active_only: if active_only:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
for function in db.query(Function) for function in db.query(Function).filter_by(type=type, is_active=True).all()
.filter_by(type=type, is_active=True)
.all()
] ]
else: else:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function) for function in db.query(Function).filter_by(type=type).all()
for function in db.query(Function).filter_by(type=type).all()
] ]
def get_global_filter_functions( def get_global_filter_functions(self, db: Optional[Session] = None) -> list[FunctionModel]:
self, db: Optional[Session] = None
) -> list[FunctionModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
for function in db.query(Function) for function in db.query(Function).filter_by(type='filter', is_active=True, is_global=True).all()
.filter_by(type="filter", is_active=True, is_global=True)
.all()
] ]
def get_global_action_functions( def get_global_action_functions(self, db: Optional[Session] = None) -> list[FunctionModel]:
self, db: Optional[Session] = None
) -> list[FunctionModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
return [ return [
FunctionModel.model_validate(function) FunctionModel.model_validate(function)
for function in db.query(Function) for function in db.query(Function).filter_by(type='action', is_active=True, is_global=True).all()
.filter_by(type="action", is_active=True, is_global=True)
.all()
] ]
def get_function_valves_by_id( def get_function_valves_by_id(self, id: str, db: Optional[Session] = None) -> Optional[dict]:
self, id: str, db: Optional[Session] = None
) -> Optional[dict]:
with get_db_context(db) as db: with get_db_context(db) as db:
try: try:
function = db.get(Function, id) function = db.get(Function, id)
return function.valves if function.valves else {} return function.valves if function.valves else {}
except Exception as e: except Exception as e:
log.exception(f"Error getting function valves by id {id}: {e}") log.exception(f'Error getting function valves by id {id}: {e}')
return None return None
def get_function_valves_by_ids( def get_function_valves_by_ids(self, ids: list[str], db: Optional[Session] = None) -> dict[str, dict]:
self, ids: list[str], db: Optional[Session] = None
) -> dict[str, dict]:
""" """
Batch fetch valves for multiple functions in a single query. Batch fetch valves for multiple functions in a single query.
Returns a dict mapping function_id -> valves dict. Returns a dict mapping function_id -> valves dict.
@@ -320,14 +285,10 @@ class FunctionsTable:
return {} return {}
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
functions = ( functions = db.query(Function.id, Function.valves).filter(Function.id.in_(ids)).all()
db.query(Function.id, Function.valves)
.filter(Function.id.in_(ids))
.all()
)
return {f.id: (f.valves if f.valves else {}) for f in functions} return {f.id: (f.valves if f.valves else {}) for f in functions}
except Exception as e: except Exception as e:
log.exception(f"Error batch-fetching function valves: {e}") log.exception(f'Error batch-fetching function valves: {e}')
return {} return {}
def update_function_valves_by_id( def update_function_valves_by_id(
@@ -364,25 +325,23 @@ class FunctionsTable:
else: else:
return None return None
except Exception as e: except Exception as e:
log.exception(f"Error updating function metadata by id {id}: {e}") log.exception(f'Error updating function metadata by id {id}: {e}')
return None return None
def get_user_valves_by_id_and_user_id( def get_user_valves_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[dict]:
self, id: str, user_id: str, db: Optional[Session] = None
) -> Optional[dict]:
try: try:
user = Users.get_user_by_id(user_id, db=db) user = Users.get_user_by_id(user_id, db=db)
user_settings = user.settings.model_dump() if user.settings else {} user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "functions" and "valves" settings # Check if user has "functions" and "valves" settings
if "functions" not in user_settings: if 'functions' not in user_settings:
user_settings["functions"] = {} user_settings['functions'] = {}
if "valves" not in user_settings["functions"]: if 'valves' not in user_settings['functions']:
user_settings["functions"]["valves"] = {} user_settings['functions']['valves'] = {}
return user_settings["functions"]["valves"].get(id, {}) return user_settings['functions']['valves'].get(id, {})
except Exception as e: except Exception as e:
log.exception(f"Error getting user values by id {id} and user id {user_id}") log.exception(f'Error getting user values by id {id} and user id {user_id}')
return None return None
def update_user_valves_by_id_and_user_id( def update_user_valves_by_id_and_user_id(
@@ -393,32 +352,28 @@ class FunctionsTable:
user_settings = user.settings.model_dump() if user.settings else {} user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "functions" and "valves" settings # Check if user has "functions" and "valves" settings
if "functions" not in user_settings: if 'functions' not in user_settings:
user_settings["functions"] = {} user_settings['functions'] = {}
if "valves" not in user_settings["functions"]: if 'valves' not in user_settings['functions']:
user_settings["functions"]["valves"] = {} user_settings['functions']['valves'] = {}
user_settings["functions"]["valves"][id] = valves user_settings['functions']['valves'][id] = valves
# Update the user settings in the database # Update the user settings in the database
Users.update_user_by_id(user_id, {"settings": user_settings}, db=db) Users.update_user_by_id(user_id, {'settings': user_settings}, db=db)
return user_settings["functions"]["valves"][id] return user_settings['functions']['valves'][id]
except Exception as e: except Exception as e:
log.exception( log.exception(f'Error updating user valves by id {id} and user_id {user_id}: {e}')
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
)
return None return None
def update_function_by_id( def update_function_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[FunctionModel]:
self, id: str, updated: dict, db: Optional[Session] = None
) -> Optional[FunctionModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
try: try:
db.query(Function).filter_by(id=id).update( db.query(Function).filter_by(id=id).update(
{ {
**updated, **updated,
"updated_at": int(time.time()), 'updated_at': int(time.time()),
} }
) )
db.commit() db.commit()
@@ -432,8 +387,8 @@ class FunctionsTable:
try: try:
db.query(Function).update( db.query(Function).update(
{ {
"is_active": False, 'is_active': False,
"updated_at": int(time.time()), 'updated_at': int(time.time()),
} }
) )
db.commit() db.commit()

View File

@@ -34,7 +34,7 @@ log = logging.getLogger(__name__)
class Group(Base): class Group(Base):
__tablename__ = "group" __tablename__ = 'group'
id = Column(Text, unique=True, primary_key=True) id = Column(Text, unique=True, primary_key=True)
user_id = Column(Text) user_id = Column(Text)
@@ -70,12 +70,12 @@ class GroupModel(BaseModel):
class GroupMember(Base): class GroupMember(Base):
__tablename__ = "group_member" __tablename__ = 'group_member'
id = Column(Text, unique=True, primary_key=True) id = Column(Text, unique=True, primary_key=True)
group_id = Column( group_id = Column(
Text, Text,
ForeignKey("group.id", ondelete="CASCADE"), ForeignKey('group.id', ondelete='CASCADE'),
nullable=False, nullable=False,
) )
user_id = Column(Text, nullable=False) user_id = Column(Text, nullable=False)
@@ -133,28 +133,26 @@ class GroupListResponse(BaseModel):
class GroupTable: class GroupTable:
def _ensure_default_share_config(self, group_data: dict) -> dict: def _ensure_default_share_config(self, group_data: dict) -> dict:
"""Ensure the group data dict has a default share config if not already set.""" """Ensure the group data dict has a default share config if not already set."""
if "data" not in group_data or group_data["data"] is None: if 'data' not in group_data or group_data['data'] is None:
group_data["data"] = {} group_data['data'] = {}
if "config" not in group_data["data"]: if 'config' not in group_data['data']:
group_data["data"]["config"] = {} group_data['data']['config'] = {}
if "share" not in group_data["data"]["config"]: if 'share' not in group_data['data']['config']:
group_data["data"]["config"]["share"] = DEFAULT_GROUP_SHARE_PERMISSION group_data['data']['config']['share'] = DEFAULT_GROUP_SHARE_PERMISSION
return group_data return group_data
def insert_new_group( def insert_new_group(
self, user_id: str, form_data: GroupForm, db: Optional[Session] = None self, user_id: str, form_data: GroupForm, db: Optional[Session] = None
) -> Optional[GroupModel]: ) -> Optional[GroupModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
group_data = self._ensure_default_share_config( group_data = self._ensure_default_share_config(form_data.model_dump(exclude_none=True))
form_data.model_dump(exclude_none=True)
)
group = GroupModel( group = GroupModel(
**{ **{
**group_data, **group_data,
"id": str(uuid.uuid4()), 'id': str(uuid.uuid4()),
"user_id": user_id, 'user_id': user_id,
"created_at": int(time.time()), 'created_at': int(time.time()),
"updated_at": int(time.time()), 'updated_at': int(time.time()),
} }
) )
@@ -183,19 +181,19 @@ class GroupTable:
.where(GroupMember.group_id == Group.id) .where(GroupMember.group_id == Group.id)
.correlate(Group) .correlate(Group)
.scalar_subquery() .scalar_subquery()
.label("member_count") .label('member_count')
) )
query = db.query(Group, member_count) query = db.query(Group, member_count)
if filter: if filter:
if "query" in filter: if 'query' in filter:
query = query.filter(Group.name.ilike(f"%{filter['query']}%")) query = query.filter(Group.name.ilike(f'%{filter["query"]}%'))
# When share filter is present, member check is handled in the share logic # When share filter is present, member check is handled in the share logic
if "share" in filter: if 'share' in filter:
share_value = filter["share"] share_value = filter['share']
member_id = filter.get("member_id") member_id = filter.get('member_id')
json_share = Group.data["config"]["share"] json_share = Group.data['config']['share']
json_share_str = json_share.as_string() json_share_str = json_share.as_string()
json_share_lower = func.lower(json_share_str) json_share_lower = func.lower(json_share_str)
@@ -203,37 +201,27 @@ class GroupTable:
anyone_can_share = or_( anyone_can_share = or_(
Group.data.is_(None), Group.data.is_(None),
json_share_str.is_(None), json_share_str.is_(None),
json_share_lower == "true", json_share_lower == 'true',
json_share_lower == "1", # Handle SQLite boolean true json_share_lower == '1', # Handle SQLite boolean true
) )
if member_id: if member_id:
member_groups_select = select(GroupMember.group_id).where( member_groups_select = select(GroupMember.group_id).where(GroupMember.user_id == member_id)
GroupMember.user_id == member_id
)
members_only_and_is_member = and_( members_only_and_is_member = and_(
json_share_lower == "members", json_share_lower == 'members',
Group.id.in_(member_groups_select), Group.id.in_(member_groups_select),
) )
query = query.filter( query = query.filter(or_(anyone_can_share, members_only_and_is_member))
or_(anyone_can_share, members_only_and_is_member)
)
else: else:
query = query.filter(anyone_can_share) query = query.filter(anyone_can_share)
else: else:
query = query.filter( query = query.filter(and_(Group.data.isnot(None), json_share_lower == 'false'))
and_(Group.data.isnot(None), json_share_lower == "false")
)
else: else:
# Only apply member_id filter when share filter is NOT present # Only apply member_id filter when share filter is NOT present
if "member_id" in filter: if 'member_id' in filter:
query = query.filter( query = query.filter(
Group.id.in_( Group.id.in_(select(GroupMember.group_id).where(GroupMember.user_id == filter['member_id']))
select(GroupMember.group_id).where(
GroupMember.user_id == filter["member_id"]
)
)
) )
results = query.order_by(Group.updated_at.desc()).all() results = query.order_by(Group.updated_at.desc()).all()
@@ -242,7 +230,7 @@ class GroupTable:
GroupResponse.model_validate( GroupResponse.model_validate(
{ {
**GroupModel.model_validate(group).model_dump(), **GroupModel.model_validate(group).model_dump(),
"member_count": count or 0, 'member_count': count or 0,
} }
) )
for group, count in results for group, count in results
@@ -259,22 +247,16 @@ class GroupTable:
query = db.query(Group) query = db.query(Group)
if filter: if filter:
if "query" in filter: if 'query' in filter:
query = query.filter(Group.name.ilike(f"%{filter['query']}%")) query = query.filter(Group.name.ilike(f'%{filter["query"]}%'))
if "member_id" in filter: if 'member_id' in filter:
query = query.filter( query = query.filter(
Group.id.in_( Group.id.in_(select(GroupMember.group_id).where(GroupMember.user_id == filter['member_id']))
select(GroupMember.group_id).where(
GroupMember.user_id == filter["member_id"]
)
)
) )
if "share" in filter: if 'share' in filter:
share_value = filter["share"] share_value = filter['share']
query = query.filter( query = query.filter(Group.data.op('->>')('share') == str(share_value))
Group.data.op("->>")("share") == str(share_value)
)
total = query.count() total = query.count()
@@ -283,32 +265,24 @@ class GroupTable:
.where(GroupMember.group_id == Group.id) .where(GroupMember.group_id == Group.id)
.correlate(Group) .correlate(Group)
.scalar_subquery() .scalar_subquery()
.label("member_count") .label('member_count')
)
results = (
query.add_columns(member_count)
.order_by(Group.updated_at.desc())
.offset(skip)
.limit(limit)
.all()
) )
results = query.add_columns(member_count).order_by(Group.updated_at.desc()).offset(skip).limit(limit).all()
return { return {
"items": [ 'items': [
GroupResponse.model_validate( GroupResponse.model_validate(
{ {
**GroupModel.model_validate(group).model_dump(), **GroupModel.model_validate(group).model_dump(),
"member_count": count or 0, 'member_count': count or 0,
} }
) )
for group, count in results for group, count in results
], ],
"total": total, 'total': total,
} }
def get_groups_by_member_id( def get_groups_by_member_id(self, user_id: str, db: Optional[Session] = None) -> list[GroupModel]:
self, user_id: str, db: Optional[Session] = None
) -> list[GroupModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
return [ return [
GroupModel.model_validate(group) GroupModel.model_validate(group)
@@ -340,9 +314,7 @@ class GroupTable:
return user_groups return user_groups
def get_group_by_id( def get_group_by_id(self, id: str, db: Optional[Session] = None) -> Optional[GroupModel]:
self, id: str, db: Optional[Session] = None
) -> Optional[GroupModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
group = db.query(Group).filter_by(id=id).first() group = db.query(Group).filter_by(id=id).first()
@@ -350,41 +322,29 @@ class GroupTable:
except Exception: except Exception:
return None return None
def get_group_user_ids_by_id( def get_group_user_ids_by_id(self, id: str, db: Optional[Session] = None) -> list[str]:
self, id: str, db: Optional[Session] = None
) -> list[str]:
with get_db_context(db) as db: with get_db_context(db) as db:
members = ( members = db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all()
db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all()
)
if not members: if not members:
return [] return []
return [m[0] for m in members] return [m[0] for m in members]
def get_group_user_ids_by_ids( def get_group_user_ids_by_ids(self, group_ids: list[str], db: Optional[Session] = None) -> dict[str, list[str]]:
self, group_ids: list[str], db: Optional[Session] = None
) -> dict[str, list[str]]:
with get_db_context(db) as db: with get_db_context(db) as db:
members = ( members = (
db.query(GroupMember.group_id, GroupMember.user_id) db.query(GroupMember.group_id, GroupMember.user_id).filter(GroupMember.group_id.in_(group_ids)).all()
.filter(GroupMember.group_id.in_(group_ids))
.all()
) )
group_user_ids: dict[str, list[str]] = { group_user_ids: dict[str, list[str]] = {group_id: [] for group_id in group_ids}
group_id: [] for group_id in group_ids
}
for group_id, user_id in members: for group_id, user_id in members:
group_user_ids[group_id].append(user_id) group_user_ids[group_id].append(user_id)
return group_user_ids return group_user_ids
def set_group_user_ids_by_id( def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str], db: Optional[Session] = None) -> None:
self, group_id: str, user_ids: list[str], db: Optional[Session] = None
) -> None:
with get_db_context(db) as db: with get_db_context(db) as db:
# Delete existing members # Delete existing members
db.query(GroupMember).filter(GroupMember.group_id == group_id).delete() db.query(GroupMember).filter(GroupMember.group_id == group_id).delete()
@@ -405,20 +365,12 @@ class GroupTable:
db.add_all(new_members) db.add_all(new_members)
db.commit() db.commit()
def get_group_member_count_by_id( def get_group_member_count_by_id(self, id: str, db: Optional[Session] = None) -> int:
self, id: str, db: Optional[Session] = None
) -> int:
with get_db_context(db) as db: with get_db_context(db) as db:
count = ( count = db.query(func.count(GroupMember.user_id)).filter(GroupMember.group_id == id).scalar()
db.query(func.count(GroupMember.user_id))
.filter(GroupMember.group_id == id)
.scalar()
)
return count if count else 0 return count if count else 0
def get_group_member_counts_by_ids( def get_group_member_counts_by_ids(self, ids: list[str], db: Optional[Session] = None) -> dict[str, int]:
self, ids: list[str], db: Optional[Session] = None
) -> dict[str, int]:
if not ids: if not ids:
return {} return {}
with get_db_context(db) as db: with get_db_context(db) as db:
@@ -442,7 +394,7 @@ class GroupTable:
db.query(Group).filter_by(id=id).update( db.query(Group).filter_by(id=id).update(
{ {
**form_data.model_dump(exclude_none=True), **form_data.model_dump(exclude_none=True),
"updated_at": int(time.time()), 'updated_at': int(time.time()),
} }
) )
db.commit() db.commit()
@@ -470,9 +422,7 @@ class GroupTable:
except Exception: except Exception:
return False return False
def remove_user_from_all_groups( def remove_user_from_all_groups(self, user_id: str, db: Optional[Session] = None) -> bool:
self, user_id: str, db: Optional[Session] = None
) -> bool:
with get_db_context(db) as db: with get_db_context(db) as db:
try: try:
# Find all groups the user belongs to # Find all groups the user belongs to
@@ -489,9 +439,7 @@ class GroupTable:
GroupMember.group_id == group.id, GroupMember.user_id == user_id GroupMember.group_id == group.id, GroupMember.user_id == user_id
).delete() ).delete()
db.query(Group).filter_by(id=group.id).update( db.query(Group).filter_by(id=group.id).update({'updated_at': int(time.time())})
{"updated_at": int(time.time())}
)
db.commit() db.commit()
return True return True
@@ -503,7 +451,6 @@ class GroupTable:
def create_groups_by_group_names( def create_groups_by_group_names(
self, user_id: str, group_names: list[str], db: Optional[Session] = None self, user_id: str, group_names: list[str], db: Optional[Session] = None
) -> list[GroupModel]: ) -> list[GroupModel]:
# check for existing groups # check for existing groups
existing_groups = self.get_all_groups(db=db) existing_groups = self.get_all_groups(db=db)
existing_group_names = {group.name for group in existing_groups} existing_group_names = {group.name for group in existing_groups}
@@ -517,10 +464,10 @@ class GroupTable:
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
user_id=user_id, user_id=user_id,
name=group_name, name=group_name,
description="", description='',
data={ data={
"config": { 'config': {
"share": DEFAULT_GROUP_SHARE_PERMISSION, 'share': DEFAULT_GROUP_SHARE_PERMISSION,
} }
}, },
created_at=int(time.time()), created_at=int(time.time()),
@@ -537,17 +484,13 @@ class GroupTable:
continue continue
return new_groups return new_groups
def sync_groups_by_group_names( def sync_groups_by_group_names(self, user_id: str, group_names: list[str], db: Optional[Session] = None) -> bool:
self, user_id: str, group_names: list[str], db: Optional[Session] = None
) -> bool:
with get_db_context(db) as db: with get_db_context(db) as db:
try: try:
now = int(time.time()) now = int(time.time())
# 1. Groups that SHOULD contain the user # 1. Groups that SHOULD contain the user
target_groups = ( target_groups = db.query(Group).filter(Group.name.in_(group_names)).all()
db.query(Group).filter(Group.name.in_(group_names)).all()
)
target_group_ids = {g.id for g in target_groups} target_group_ids = {g.id for g in target_groups}
# 2. Groups the user is CURRENTLY in # 2. Groups the user is CURRENTLY in
@@ -571,7 +514,7 @@ class GroupTable:
).delete(synchronize_session=False) ).delete(synchronize_session=False)
db.query(Group).filter(Group.id.in_(groups_to_remove)).update( db.query(Group).filter(Group.id.in_(groups_to_remove)).update(
{"updated_at": now}, synchronize_session=False {'updated_at': now}, synchronize_session=False
) )
# 5. Bulk insert missing memberships # 5. Bulk insert missing memberships
@@ -588,7 +531,7 @@ class GroupTable:
if groups_to_add: if groups_to_add:
db.query(Group).filter(Group.id.in_(groups_to_add)).update( db.query(Group).filter(Group.id.in_(groups_to_add)).update(
{"updated_at": now}, synchronize_session=False {'updated_at': now}, synchronize_session=False
) )
db.commit() db.commit()
@@ -656,9 +599,9 @@ class GroupTable:
return GroupModel.model_validate(group) return GroupModel.model_validate(group)
# Remove users from group_member in batch # Remove users from group_member in batch
db.query(GroupMember).filter( db.query(GroupMember).filter(GroupMember.group_id == id, GroupMember.user_id.in_(user_ids)).delete(
GroupMember.group_id == id, GroupMember.user_id.in_(user_ids) synchronize_session=False
).delete(synchronize_session=False) )
# Update group timestamp # Update group timestamp
group.updated_at = int(time.time()) group.updated_at = int(time.time())

View File

@@ -38,7 +38,7 @@ log = logging.getLogger(__name__)
class Knowledge(Base): class Knowledge(Base):
__tablename__ = "knowledge" __tablename__ = 'knowledge'
id = Column(Text, unique=True, primary_key=True) id = Column(Text, unique=True, primary_key=True)
user_id = Column(Text) user_id = Column(Text)
@@ -70,24 +70,18 @@ class KnowledgeModel(BaseModel):
class KnowledgeFile(Base): class KnowledgeFile(Base):
__tablename__ = "knowledge_file" __tablename__ = 'knowledge_file'
id = Column(Text, unique=True, primary_key=True) id = Column(Text, unique=True, primary_key=True)
knowledge_id = Column( knowledge_id = Column(Text, ForeignKey('knowledge.id', ondelete='CASCADE'), nullable=False)
Text, ForeignKey("knowledge.id", ondelete="CASCADE"), nullable=False file_id = Column(Text, ForeignKey('file.id', ondelete='CASCADE'), nullable=False)
)
file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False)
user_id = Column(Text, nullable=False) user_id = Column(Text, nullable=False)
created_at = Column(BigInteger, nullable=False) created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False) updated_at = Column(BigInteger, nullable=False)
__table_args__ = ( __table_args__ = (UniqueConstraint('knowledge_id', 'file_id', name='uq_knowledge_file_knowledge_file'),)
UniqueConstraint(
"knowledge_id", "file_id", name="uq_knowledge_file_knowledge_file"
),
)
class KnowledgeFileModel(BaseModel): class KnowledgeFileModel(BaseModel):
@@ -138,10 +132,8 @@ class KnowledgeFileListResponse(BaseModel):
class KnowledgeTable: class KnowledgeTable:
def _get_access_grants( def _get_access_grants(self, knowledge_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
self, knowledge_id: str, db: Optional[Session] = None return AccessGrants.get_grants_by_resource('knowledge', knowledge_id, db=db)
) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource("knowledge", knowledge_id, db=db)
def _to_knowledge_model( def _to_knowledge_model(
self, self,
@@ -149,13 +141,9 @@ class KnowledgeTable:
access_grants: Optional[list[AccessGrantModel]] = None, access_grants: Optional[list[AccessGrantModel]] = None,
db: Optional[Session] = None, db: Optional[Session] = None,
) -> KnowledgeModel: ) -> KnowledgeModel:
knowledge_data = KnowledgeModel.model_validate(knowledge).model_dump( knowledge_data = KnowledgeModel.model_validate(knowledge).model_dump(exclude={'access_grants'})
exclude={"access_grants"} knowledge_data['access_grants'] = (
) access_grants if access_grants is not None else self._get_access_grants(knowledge_data['id'], db=db)
knowledge_data["access_grants"] = (
access_grants
if access_grants is not None
else self._get_access_grants(knowledge_data["id"], db=db)
) )
return KnowledgeModel.model_validate(knowledge_data) return KnowledgeModel.model_validate(knowledge_data)
@@ -165,23 +153,21 @@ class KnowledgeTable:
with get_db_context(db) as db: with get_db_context(db) as db:
knowledge = KnowledgeModel( knowledge = KnowledgeModel(
**{ **{
**form_data.model_dump(exclude={"access_grants"}), **form_data.model_dump(exclude={'access_grants'}),
"id": str(uuid.uuid4()), 'id': str(uuid.uuid4()),
"user_id": user_id, 'user_id': user_id,
"created_at": int(time.time()), 'created_at': int(time.time()),
"updated_at": int(time.time()), 'updated_at': int(time.time()),
"access_grants": [], 'access_grants': [],
} }
) )
try: try:
result = Knowledge(**knowledge.model_dump(exclude={"access_grants"})) result = Knowledge(**knowledge.model_dump(exclude={'access_grants'}))
db.add(result) db.add(result)
db.commit() db.commit()
db.refresh(result) db.refresh(result)
AccessGrants.set_access_grants( AccessGrants.set_access_grants('knowledge', result.id, form_data.access_grants, db=db)
"knowledge", result.id, form_data.access_grants, db=db
)
if result: if result:
return self._to_knowledge_model(result, db=db) return self._to_knowledge_model(result, db=db)
else: else:
@@ -193,17 +179,13 @@ class KnowledgeTable:
self, skip: int = 0, limit: int = 30, db: Optional[Session] = None self, skip: int = 0, limit: int = 30, db: Optional[Session] = None
) -> list[KnowledgeUserModel]: ) -> list[KnowledgeUserModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
all_knowledge = ( all_knowledge = db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
)
user_ids = list(set(knowledge.user_id for knowledge in all_knowledge)) user_ids = list(set(knowledge.user_id for knowledge in all_knowledge))
knowledge_ids = [knowledge.id for knowledge in all_knowledge] knowledge_ids = [knowledge.id for knowledge in all_knowledge]
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
users_dict = {user.id: user for user in users} users_dict = {user.id: user for user in users}
grants_map = AccessGrants.get_grants_by_resources( grants_map = AccessGrants.get_grants_by_resources('knowledge', knowledge_ids, db=db)
"knowledge", knowledge_ids, db=db
)
knowledge_bases = [] knowledge_bases = []
for knowledge in all_knowledge: for knowledge in all_knowledge:
@@ -216,7 +198,7 @@ class KnowledgeTable:
access_grants=grants_map.get(knowledge.id, []), access_grants=grants_map.get(knowledge.id, []),
db=db, db=db,
).model_dump(), ).model_dump(),
"user": user.model_dump() if user else None, 'user': user.model_dump() if user else None,
} }
) )
) )
@@ -232,27 +214,25 @@ class KnowledgeTable:
) -> KnowledgeListResponse: ) -> KnowledgeListResponse:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
query = db.query(Knowledge, User).outerjoin( query = db.query(Knowledge, User).outerjoin(User, User.id == Knowledge.user_id)
User, User.id == Knowledge.user_id
)
if filter: if filter:
query_key = filter.get("query") query_key = filter.get('query')
if query_key: if query_key:
query = query.filter( query = query.filter(
or_( or_(
Knowledge.name.ilike(f"%{query_key}%"), Knowledge.name.ilike(f'%{query_key}%'),
Knowledge.description.ilike(f"%{query_key}%"), Knowledge.description.ilike(f'%{query_key}%'),
User.name.ilike(f"%{query_key}%"), User.name.ilike(f'%{query_key}%'),
User.email.ilike(f"%{query_key}%"), User.email.ilike(f'%{query_key}%'),
User.username.ilike(f"%{query_key}%"), User.username.ilike(f'%{query_key}%'),
) )
) )
view_option = filter.get("view_option") view_option = filter.get('view_option')
if view_option == "created": if view_option == 'created':
query = query.filter(Knowledge.user_id == user_id) query = query.filter(Knowledge.user_id == user_id)
elif view_option == "shared": elif view_option == 'shared':
query = query.filter(Knowledge.user_id != user_id) query = query.filter(Knowledge.user_id != user_id)
query = AccessGrants.has_permission_filter( query = AccessGrants.has_permission_filter(
@@ -260,8 +240,8 @@ class KnowledgeTable:
query=query, query=query,
DocumentModel=Knowledge, DocumentModel=Knowledge,
filter=filter, filter=filter,
resource_type="knowledge", resource_type='knowledge',
permission="read", permission='read',
) )
query = query.order_by(Knowledge.updated_at.desc(), Knowledge.id.asc()) query = query.order_by(Knowledge.updated_at.desc(), Knowledge.id.asc())
@@ -275,9 +255,7 @@ class KnowledgeTable:
items = query.all() items = query.all()
knowledge_ids = [kb.id for kb, _ in items] knowledge_ids = [kb.id for kb, _ in items]
grants_map = AccessGrants.get_grants_by_resources( grants_map = AccessGrants.get_grants_by_resources('knowledge', knowledge_ids, db=db)
"knowledge", knowledge_ids, db=db
)
knowledge_bases = [] knowledge_bases = []
for knowledge_base, user in items: for knowledge_base, user in items:
@@ -289,11 +267,7 @@ class KnowledgeTable:
access_grants=grants_map.get(knowledge_base.id, []), access_grants=grants_map.get(knowledge_base.id, []),
db=db, db=db,
).model_dump(), ).model_dump(),
"user": ( 'user': (UserModel.model_validate(user).model_dump() if user else None),
UserModel.model_validate(user).model_dump()
if user
else None
),
} }
) )
) )
@@ -327,15 +301,15 @@ class KnowledgeTable:
query=query, query=query,
DocumentModel=Knowledge, DocumentModel=Knowledge,
filter=filter, filter=filter,
resource_type="knowledge", resource_type='knowledge',
permission="read", permission='read',
) )
# Apply filename search # Apply filename search
if filter: if filter:
q = filter.get("query") q = filter.get('query')
if q: if q:
query = query.filter(File.filename.ilike(f"%{q}%")) query = query.filter(File.filename.ilike(f'%{q}%'))
# Order by file changes # Order by file changes
query = query.order_by(File.updated_at.desc(), File.id.asc()) query = query.order_by(File.updated_at.desc(), File.id.asc())
@@ -355,39 +329,27 @@ class KnowledgeTable:
items.append( items.append(
FileUserResponse( FileUserResponse(
**FileModel.model_validate(file).model_dump(), **FileModel.model_validate(file).model_dump(),
user=( user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
UserResponse( collection=self._to_knowledge_model(knowledge, db=db).model_dump(),
**UserModel.model_validate(user).model_dump()
)
if user
else None
),
collection=self._to_knowledge_model(
knowledge, db=db
).model_dump(),
) )
) )
return KnowledgeFileListResponse(items=items, total=total) return KnowledgeFileListResponse(items=items, total=total)
except Exception as e: except Exception as e:
print("search_knowledge_files error:", e) print('search_knowledge_files error:', e)
return KnowledgeFileListResponse(items=[], total=0) return KnowledgeFileListResponse(items=[], total=0)
def check_access_by_user_id( def check_access_by_user_id(self, id, user_id, permission='write', db: Optional[Session] = None) -> bool:
self, id, user_id, permission="write", db: Optional[Session] = None
) -> bool:
knowledge = self.get_knowledge_by_id(id, db=db) knowledge = self.get_knowledge_by_id(id, db=db)
if not knowledge: if not knowledge:
return False return False
if knowledge.user_id == user_id: if knowledge.user_id == user_id:
return True return True
user_group_ids = { user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
}
return AccessGrants.has_access( return AccessGrants.has_access(
user_id=user_id, user_id=user_id,
resource_type="knowledge", resource_type='knowledge',
resource_id=knowledge.id, resource_id=knowledge.id,
permission=permission, permission=permission,
user_group_ids=user_group_ids, user_group_ids=user_group_ids,
@@ -395,19 +357,17 @@ class KnowledgeTable:
) )
def get_knowledge_bases_by_user_id( def get_knowledge_bases_by_user_id(
self, user_id: str, permission: str = "write", db: Optional[Session] = None self, user_id: str, permission: str = 'write', db: Optional[Session] = None
) -> list[KnowledgeUserModel]: ) -> list[KnowledgeUserModel]:
knowledge_bases = self.get_knowledge_bases(db=db) knowledge_bases = self.get_knowledge_bases(db=db)
user_group_ids = { user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
}
return [ return [
knowledge_base knowledge_base
for knowledge_base in knowledge_bases for knowledge_base in knowledge_bases
if knowledge_base.user_id == user_id if knowledge_base.user_id == user_id
or AccessGrants.has_access( or AccessGrants.has_access(
user_id=user_id, user_id=user_id,
resource_type="knowledge", resource_type='knowledge',
resource_id=knowledge_base.id, resource_id=knowledge_base.id,
permission=permission, permission=permission,
user_group_ids=user_group_ids, user_group_ids=user_group_ids,
@@ -415,9 +375,7 @@ class KnowledgeTable:
) )
] ]
def get_knowledge_by_id( def get_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> Optional[KnowledgeModel]:
self, id: str, db: Optional[Session] = None
) -> Optional[KnowledgeModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
knowledge = db.query(Knowledge).filter_by(id=id).first() knowledge = db.query(Knowledge).filter_by(id=id).first()
@@ -435,23 +393,19 @@ class KnowledgeTable:
if knowledge.user_id == user_id: if knowledge.user_id == user_id:
return knowledge return knowledge
user_group_ids = { user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
}
if AccessGrants.has_access( if AccessGrants.has_access(
user_id=user_id, user_id=user_id,
resource_type="knowledge", resource_type='knowledge',
resource_id=knowledge.id, resource_id=knowledge.id,
permission="write", permission='write',
user_group_ids=user_group_ids, user_group_ids=user_group_ids,
db=db, db=db,
): ):
return knowledge return knowledge
return None return None
def get_knowledges_by_file_id( def get_knowledges_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[KnowledgeModel]:
self, file_id: str, db: Optional[Session] = None
) -> list[KnowledgeModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
knowledges = ( knowledges = (
@@ -461,9 +415,7 @@ class KnowledgeTable:
.all() .all()
) )
knowledge_ids = [k.id for k in knowledges] knowledge_ids = [k.id for k in knowledges]
grants_map = AccessGrants.get_grants_by_resources( grants_map = AccessGrants.get_grants_by_resources('knowledge', knowledge_ids, db=db)
"knowledge", knowledge_ids, db=db
)
return [ return [
self._to_knowledge_model( self._to_knowledge_model(
knowledge, knowledge,
@@ -497,32 +449,26 @@ class KnowledgeTable:
primary_sort = File.updated_at.desc() primary_sort = File.updated_at.desc()
if filter: if filter:
query_key = filter.get("query") query_key = filter.get('query')
if query_key: if query_key:
query = query.filter(or_(File.filename.ilike(f"%{query_key}%"))) query = query.filter(or_(File.filename.ilike(f'%{query_key}%')))
view_option = filter.get("view_option") view_option = filter.get('view_option')
if view_option == "created": if view_option == 'created':
query = query.filter(KnowledgeFile.user_id == user_id) query = query.filter(KnowledgeFile.user_id == user_id)
elif view_option == "shared": elif view_option == 'shared':
query = query.filter(KnowledgeFile.user_id != user_id) query = query.filter(KnowledgeFile.user_id != user_id)
order_by = filter.get("order_by") order_by = filter.get('order_by')
direction = filter.get("direction") direction = filter.get('direction')
is_asc = direction == "asc" is_asc = direction == 'asc'
if order_by == "name": if order_by == 'name':
primary_sort = ( primary_sort = File.filename.asc() if is_asc else File.filename.desc()
File.filename.asc() if is_asc else File.filename.desc() elif order_by == 'created_at':
) primary_sort = File.created_at.asc() if is_asc else File.created_at.desc()
elif order_by == "created_at": elif order_by == 'updated_at':
primary_sort = ( primary_sort = File.updated_at.asc() if is_asc else File.updated_at.desc()
File.created_at.asc() if is_asc else File.created_at.desc()
)
elif order_by == "updated_at":
primary_sort = (
File.updated_at.asc() if is_asc else File.updated_at.desc()
)
# Apply sort with secondary key for deterministic pagination # Apply sort with secondary key for deterministic pagination
query = query.order_by(primary_sort, File.id.asc()) query = query.order_by(primary_sort, File.id.asc())
@@ -542,13 +488,7 @@ class KnowledgeTable:
files.append( files.append(
FileUserResponse( FileUserResponse(
**FileModel.model_validate(file).model_dump(), **FileModel.model_validate(file).model_dump(),
user=( user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
UserResponse(
**UserModel.model_validate(user).model_dump()
)
if user
else None
),
) )
) )
@@ -557,9 +497,7 @@ class KnowledgeTable:
print(e) print(e)
return KnowledgeFileListResponse(items=[], total=0) return KnowledgeFileListResponse(items=[], total=0)
def get_files_by_id( def get_files_by_id(self, knowledge_id: str, db: Optional[Session] = None) -> list[FileModel]:
self, knowledge_id: str, db: Optional[Session] = None
) -> list[FileModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
files = ( files = (
@@ -572,9 +510,7 @@ class KnowledgeTable:
except Exception: except Exception:
return [] return []
def get_file_metadatas_by_id( def get_file_metadatas_by_id(self, knowledge_id: str, db: Optional[Session] = None) -> list[FileMetadataResponse]:
self, knowledge_id: str, db: Optional[Session] = None
) -> list[FileMetadataResponse]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
files = self.get_files_by_id(knowledge_id, db=db) files = self.get_files_by_id(knowledge_id, db=db)
@@ -592,12 +528,12 @@ class KnowledgeTable:
with get_db_context(db) as db: with get_db_context(db) as db:
knowledge_file = KnowledgeFileModel( knowledge_file = KnowledgeFileModel(
**{ **{
"id": str(uuid.uuid4()), 'id': str(uuid.uuid4()),
"knowledge_id": knowledge_id, 'knowledge_id': knowledge_id,
"file_id": file_id, 'file_id': file_id,
"user_id": user_id, 'user_id': user_id,
"created_at": int(time.time()), 'created_at': int(time.time()),
"updated_at": int(time.time()), 'updated_at': int(time.time()),
} }
) )
@@ -613,37 +549,24 @@ class KnowledgeTable:
except Exception: except Exception:
return None return None
def has_file( def has_file(self, knowledge_id: str, file_id: str, db: Optional[Session] = None) -> bool:
self, knowledge_id: str, file_id: str, db: Optional[Session] = None
) -> bool:
"""Check whether a file belongs to a knowledge base.""" """Check whether a file belongs to a knowledge base."""
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
return ( return db.query(KnowledgeFile).filter_by(knowledge_id=knowledge_id, file_id=file_id).first() is not None
db.query(KnowledgeFile)
.filter_by(knowledge_id=knowledge_id, file_id=file_id)
.first()
is not None
)
except Exception: except Exception:
return False return False
def remove_file_from_knowledge_by_id( def remove_file_from_knowledge_by_id(self, knowledge_id: str, file_id: str, db: Optional[Session] = None) -> bool:
self, knowledge_id: str, file_id: str, db: Optional[Session] = None
) -> bool:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
db.query(KnowledgeFile).filter_by( db.query(KnowledgeFile).filter_by(knowledge_id=knowledge_id, file_id=file_id).delete()
knowledge_id=knowledge_id, file_id=file_id
).delete()
db.commit() db.commit()
return True return True
except Exception: except Exception:
return False return False
def reset_knowledge_by_id( def reset_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> Optional[KnowledgeModel]:
self, id: str, db: Optional[Session] = None
) -> Optional[KnowledgeModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
# Delete all knowledge_file entries for this knowledge_id # Delete all knowledge_file entries for this knowledge_id
@@ -653,7 +576,7 @@ class KnowledgeTable:
# Update the knowledge entry's updated_at timestamp # Update the knowledge entry's updated_at timestamp
db.query(Knowledge).filter_by(id=id).update( db.query(Knowledge).filter_by(id=id).update(
{ {
"updated_at": int(time.time()), 'updated_at': int(time.time()),
} }
) )
db.commit() db.commit()
@@ -675,15 +598,13 @@ class KnowledgeTable:
knowledge = self.get_knowledge_by_id(id=id, db=db) knowledge = self.get_knowledge_by_id(id=id, db=db)
db.query(Knowledge).filter_by(id=id).update( db.query(Knowledge).filter_by(id=id).update(
{ {
**form_data.model_dump(exclude={"access_grants"}), **form_data.model_dump(exclude={'access_grants'}),
"updated_at": int(time.time()), 'updated_at': int(time.time()),
} }
) )
db.commit() db.commit()
if form_data.access_grants is not None: if form_data.access_grants is not None:
AccessGrants.set_access_grants( AccessGrants.set_access_grants('knowledge', id, form_data.access_grants, db=db)
"knowledge", id, form_data.access_grants, db=db
)
return self.get_knowledge_by_id(id=id, db=db) return self.get_knowledge_by_id(id=id, db=db)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
@@ -697,8 +618,8 @@ class KnowledgeTable:
knowledge = self.get_knowledge_by_id(id=id, db=db) knowledge = self.get_knowledge_by_id(id=id, db=db)
db.query(Knowledge).filter_by(id=id).update( db.query(Knowledge).filter_by(id=id).update(
{ {
"data": data, 'data': data,
"updated_at": int(time.time()), 'updated_at': int(time.time()),
} }
) )
db.commit() db.commit()
@@ -710,7 +631,7 @@ class KnowledgeTable:
def delete_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> bool: def delete_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
AccessGrants.revoke_all_access("knowledge", id, db=db) AccessGrants.revoke_all_access('knowledge', id, db=db)
db.query(Knowledge).filter_by(id=id).delete() db.query(Knowledge).filter_by(id=id).delete()
db.commit() db.commit()
return True return True
@@ -722,7 +643,7 @@ class KnowledgeTable:
try: try:
knowledge_ids = [row[0] for row in db.query(Knowledge.id).all()] knowledge_ids = [row[0] for row in db.query(Knowledge.id).all()]
for knowledge_id in knowledge_ids: for knowledge_id in knowledge_ids:
AccessGrants.revoke_all_access("knowledge", knowledge_id, db=db) AccessGrants.revoke_all_access('knowledge', knowledge_id, db=db)
db.query(Knowledge).delete() db.query(Knowledge).delete()
db.commit() db.commit()

View File

@@ -13,7 +13,7 @@ from sqlalchemy import BigInteger, Column, String, Text
class Memory(Base): class Memory(Base):
__tablename__ = "memory" __tablename__ = 'memory'
id = Column(String, primary_key=True, unique=True) id = Column(String, primary_key=True, unique=True)
user_id = Column(String) user_id = Column(String)
@@ -49,11 +49,11 @@ class MemoriesTable:
memory = MemoryModel( memory = MemoryModel(
**{ **{
"id": id, 'id': id,
"user_id": user_id, 'user_id': user_id,
"content": content, 'content': content,
"created_at": int(time.time()), 'created_at': int(time.time()),
"updated_at": int(time.time()), 'updated_at': int(time.time()),
} }
) )
result = Memory(**memory.model_dump()) result = Memory(**memory.model_dump())
@@ -95,9 +95,7 @@ class MemoriesTable:
except Exception: except Exception:
return None return None
def get_memories_by_user_id( def get_memories_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[MemoryModel]:
self, user_id: str, db: Optional[Session] = None
) -> list[MemoryModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
try: try:
memories = db.query(Memory).filter_by(user_id=user_id).all() memories = db.query(Memory).filter_by(user_id=user_id).all()
@@ -105,9 +103,7 @@ class MemoriesTable:
except Exception: except Exception:
return None return None
def get_memory_by_id( def get_memory_by_id(self, id: str, db: Optional[Session] = None) -> Optional[MemoryModel]:
self, id: str, db: Optional[Session] = None
) -> Optional[MemoryModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
try: try:
memory = db.get(Memory, id) memory = db.get(Memory, id)
@@ -126,9 +122,7 @@ class MemoriesTable:
except Exception: except Exception:
return False return False
def delete_memories_by_user_id( def delete_memories_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
self, user_id: str, db: Optional[Session] = None
) -> bool:
with get_db_context(db) as db: with get_db_context(db) as db:
try: try:
db.query(Memory).filter_by(user_id=user_id).delete() db.query(Memory).filter_by(user_id=user_id).delete()
@@ -138,9 +132,7 @@ class MemoriesTable:
except Exception: except Exception:
return False return False
def delete_memory_by_id_and_user_id( def delete_memory_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool:
self, id: str, user_id: str, db: Optional[Session] = None
) -> bool:
with get_db_context(db) as db: with get_db_context(db) as db:
try: try:
memory = db.get(Memory, id) memory = db.get(Memory, id)

View File

@@ -21,7 +21,7 @@ from sqlalchemy.sql import exists
class MessageReaction(Base): class MessageReaction(Base):
__tablename__ = "message_reaction" __tablename__ = 'message_reaction'
id = Column(Text, primary_key=True, unique=True) id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text) user_id = Column(Text)
message_id = Column(Text) message_id = Column(Text)
@@ -40,7 +40,7 @@ class MessageReactionModel(BaseModel):
class Message(Base): class Message(Base):
__tablename__ = "message" __tablename__ = 'message'
id = Column(Text, primary_key=True, unique=True) id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text) user_id = Column(Text)
@@ -112,7 +112,7 @@ class MessageUserResponse(MessageModel):
class MessageUserSlimResponse(MessageUserResponse): class MessageUserSlimResponse(MessageUserResponse):
data: bool | None = None data: bool | None = None
@field_validator("data", mode="before") @field_validator('data', mode='before')
def convert_data_to_bool(cls, v): def convert_data_to_bool(cls, v):
# No data or not a dict → False # No data or not a dict → False
if not isinstance(v, dict): if not isinstance(v, dict):
@@ -152,19 +152,19 @@ class MessageTable:
message = MessageModel( message = MessageModel(
**{ **{
"id": id, 'id': id,
"user_id": user_id, 'user_id': user_id,
"channel_id": channel_id, 'channel_id': channel_id,
"reply_to_id": form_data.reply_to_id, 'reply_to_id': form_data.reply_to_id,
"parent_id": form_data.parent_id, 'parent_id': form_data.parent_id,
"is_pinned": False, 'is_pinned': False,
"pinned_at": None, 'pinned_at': None,
"pinned_by": None, 'pinned_by': None,
"content": form_data.content, 'content': form_data.content,
"data": form_data.data, 'data': form_data.data,
"meta": form_data.meta, 'meta': form_data.meta,
"created_at": ts, 'created_at': ts,
"updated_at": ts, 'updated_at': ts,
} }
) )
result = Message(**message.model_dump()) result = Message(**message.model_dump())
@@ -186,9 +186,7 @@ class MessageTable:
return None return None
reply_to_message = ( reply_to_message = (
self.get_message_by_id( self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db)
message.reply_to_id, include_thread_replies=False, db=db
)
if message.reply_to_id if message.reply_to_id
else None else None
) )
@@ -200,22 +198,22 @@ class MessageTable:
thread_replies = self.get_thread_replies_by_message_id(id, db=db) thread_replies = self.get_thread_replies_by_message_id(id, db=db)
# Check if message was sent by webhook (webhook info in meta takes precedence) # Check if message was sent by webhook (webhook info in meta takes precedence)
webhook_info = message.meta.get("webhook") if message.meta else None webhook_info = message.meta.get('webhook') if message.meta else None
if webhook_info and webhook_info.get("id"): if webhook_info and webhook_info.get('id'):
# Look up webhook by ID to get current name # Look up webhook by ID to get current name
webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db) webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db)
if webhook: if webhook:
user_info = { user_info = {
"id": webhook.id, 'id': webhook.id,
"name": webhook.name, 'name': webhook.name,
"role": "webhook", 'role': 'webhook',
} }
else: else:
# Webhook was deleted, use placeholder # Webhook was deleted, use placeholder
user_info = { user_info = {
"id": webhook_info.get("id"), 'id': webhook_info.get('id'),
"name": "Deleted Webhook", 'name': 'Deleted Webhook',
"role": "webhook", 'role': 'webhook',
} }
else: else:
user = Users.get_user_by_id(message.user_id, db=db) user = Users.get_user_by_id(message.user_id, db=db)
@@ -224,79 +222,57 @@ class MessageTable:
return MessageResponse.model_validate( return MessageResponse.model_validate(
{ {
**MessageModel.model_validate(message).model_dump(), **MessageModel.model_validate(message).model_dump(),
"user": user_info, 'user': user_info,
"reply_to_message": ( 'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None),
reply_to_message.model_dump() if reply_to_message else None 'latest_reply_at': (thread_replies[0].created_at if thread_replies else None),
), 'reply_count': len(thread_replies),
"latest_reply_at": ( 'reactions': reactions,
thread_replies[0].created_at if thread_replies else None
),
"reply_count": len(thread_replies),
"reactions": reactions,
} }
) )
def get_thread_replies_by_message_id( def get_thread_replies_by_message_id(self, id: str, db: Optional[Session] = None) -> list[MessageReplyToResponse]:
self, id: str, db: Optional[Session] = None
) -> list[MessageReplyToResponse]:
with get_db_context(db) as db: with get_db_context(db) as db:
all_messages = ( all_messages = db.query(Message).filter_by(parent_id=id).order_by(Message.created_at.desc()).all()
db.query(Message)
.filter_by(parent_id=id)
.order_by(Message.created_at.desc())
.all()
)
messages = [] messages = []
for message in all_messages: for message in all_messages:
reply_to_message = ( reply_to_message = (
self.get_message_by_id( self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db)
message.reply_to_id, include_thread_replies=False, db=db
)
if message.reply_to_id if message.reply_to_id
else None else None
) )
webhook_info = message.meta.get("webhook") if message.meta else None webhook_info = message.meta.get('webhook') if message.meta else None
user_info = None user_info = None
if webhook_info and webhook_info.get("id"): if webhook_info and webhook_info.get('id'):
webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db) webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db)
if webhook: if webhook:
user_info = { user_info = {
"id": webhook.id, 'id': webhook.id,
"name": webhook.name, 'name': webhook.name,
"role": "webhook", 'role': 'webhook',
} }
else: else:
user_info = { user_info = {
"id": webhook_info.get("id"), 'id': webhook_info.get('id'),
"name": "Deleted Webhook", 'name': 'Deleted Webhook',
"role": "webhook", 'role': 'webhook',
} }
messages.append( messages.append(
MessageReplyToResponse.model_validate( MessageReplyToResponse.model_validate(
{ {
**MessageModel.model_validate(message).model_dump(), **MessageModel.model_validate(message).model_dump(),
"user": user_info, 'user': user_info,
"reply_to_message": ( 'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None),
reply_to_message.model_dump()
if reply_to_message
else None
),
} }
) )
) )
return messages return messages
def get_reply_user_ids_by_message_id( def get_reply_user_ids_by_message_id(self, id: str, db: Optional[Session] = None) -> list[str]:
self, id: str, db: Optional[Session] = None
) -> list[str]:
with get_db_context(db) as db: with get_db_context(db) as db:
return [ return [message.user_id for message in db.query(Message).filter_by(parent_id=id).all()]
message.user_id
for message in db.query(Message).filter_by(parent_id=id).all()
]
def get_messages_by_channel_id( def get_messages_by_channel_id(
self, self,
@@ -318,40 +294,34 @@ class MessageTable:
messages = [] messages = []
for message in all_messages: for message in all_messages:
reply_to_message = ( reply_to_message = (
self.get_message_by_id( self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db)
message.reply_to_id, include_thread_replies=False, db=db
)
if message.reply_to_id if message.reply_to_id
else None else None
) )
webhook_info = message.meta.get("webhook") if message.meta else None webhook_info = message.meta.get('webhook') if message.meta else None
user_info = None user_info = None
if webhook_info and webhook_info.get("id"): if webhook_info and webhook_info.get('id'):
webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db) webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db)
if webhook: if webhook:
user_info = { user_info = {
"id": webhook.id, 'id': webhook.id,
"name": webhook.name, 'name': webhook.name,
"role": "webhook", 'role': 'webhook',
} }
else: else:
user_info = { user_info = {
"id": webhook_info.get("id"), 'id': webhook_info.get('id'),
"name": "Deleted Webhook", 'name': 'Deleted Webhook',
"role": "webhook", 'role': 'webhook',
} }
messages.append( messages.append(
MessageReplyToResponse.model_validate( MessageReplyToResponse.model_validate(
{ {
**MessageModel.model_validate(message).model_dump(), **MessageModel.model_validate(message).model_dump(),
"user": user_info, 'user': user_info,
"reply_to_message": ( 'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None),
reply_to_message.model_dump()
if reply_to_message
else None
),
} }
) )
) )
@@ -387,55 +357,42 @@ class MessageTable:
messages = [] messages = []
for message in all_messages: for message in all_messages:
reply_to_message = ( reply_to_message = (
self.get_message_by_id( self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db)
message.reply_to_id, include_thread_replies=False, db=db
)
if message.reply_to_id if message.reply_to_id
else None else None
) )
webhook_info = message.meta.get("webhook") if message.meta else None webhook_info = message.meta.get('webhook') if message.meta else None
user_info = None user_info = None
if webhook_info and webhook_info.get("id"): if webhook_info and webhook_info.get('id'):
webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db) webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db)
if webhook: if webhook:
user_info = { user_info = {
"id": webhook.id, 'id': webhook.id,
"name": webhook.name, 'name': webhook.name,
"role": "webhook", 'role': 'webhook',
} }
else: else:
user_info = { user_info = {
"id": webhook_info.get("id"), 'id': webhook_info.get('id'),
"name": "Deleted Webhook", 'name': 'Deleted Webhook',
"role": "webhook", 'role': 'webhook',
} }
messages.append( messages.append(
MessageReplyToResponse.model_validate( MessageReplyToResponse.model_validate(
{ {
**MessageModel.model_validate(message).model_dump(), **MessageModel.model_validate(message).model_dump(),
"user": user_info, 'user': user_info,
"reply_to_message": ( 'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None),
reply_to_message.model_dump()
if reply_to_message
else None
),
} }
) )
) )
return messages return messages
def get_last_message_by_channel_id( def get_last_message_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> Optional[MessageModel]:
self, channel_id: str, db: Optional[Session] = None
) -> Optional[MessageModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
message = ( message = db.query(Message).filter_by(channel_id=channel_id).order_by(Message.created_at.desc()).first()
db.query(Message)
.filter_by(channel_id=channel_id)
.order_by(Message.created_at.desc())
.first()
)
return MessageModel.model_validate(message) if message else None return MessageModel.model_validate(message) if message else None
def get_pinned_messages_by_channel_id( def get_pinned_messages_by_channel_id(
@@ -513,11 +470,7 @@ class MessageTable:
) -> Optional[MessageReactionModel]: ) -> Optional[MessageReactionModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
# check for existing reaction # check for existing reaction
existing_reaction = ( existing_reaction = db.query(MessageReaction).filter_by(message_id=id, user_id=user_id, name=name).first()
db.query(MessageReaction)
.filter_by(message_id=id, user_id=user_id, name=name)
.first()
)
if existing_reaction: if existing_reaction:
return MessageReactionModel.model_validate(existing_reaction) return MessageReactionModel.model_validate(existing_reaction)
@@ -535,9 +488,7 @@ class MessageTable:
db.refresh(result) db.refresh(result)
return MessageReactionModel.model_validate(result) if result else None return MessageReactionModel.model_validate(result) if result else None
def get_reactions_by_message_id( def get_reactions_by_message_id(self, id: str, db: Optional[Session] = None) -> list[Reactions]:
self, id: str, db: Optional[Session] = None
) -> list[Reactions]:
with get_db_context(db) as db: with get_db_context(db) as db:
# JOIN User so all user info is fetched in one query # JOIN User so all user info is fetched in one query
results = ( results = (
@@ -552,18 +503,18 @@ class MessageTable:
for reaction, user in results: for reaction, user in results:
if reaction.name not in reactions: if reaction.name not in reactions:
reactions[reaction.name] = { reactions[reaction.name] = {
"name": reaction.name, 'name': reaction.name,
"users": [], 'users': [],
"count": 0, 'count': 0,
} }
reactions[reaction.name]["users"].append( reactions[reaction.name]['users'].append(
{ {
"id": user.id, 'id': user.id,
"name": user.name, 'name': user.name,
} }
) )
reactions[reaction.name]["count"] += 1 reactions[reaction.name]['count'] += 1
return [Reactions(**reaction) for reaction in reactions.values()] return [Reactions(**reaction) for reaction in reactions.values()]
@@ -571,9 +522,7 @@ class MessageTable:
self, id: str, user_id: str, name: str, db: Optional[Session] = None self, id: str, user_id: str, name: str, db: Optional[Session] = None
) -> bool: ) -> bool:
with get_db_context(db) as db: with get_db_context(db) as db:
db.query(MessageReaction).filter_by( db.query(MessageReaction).filter_by(message_id=id, user_id=user_id, name=name).delete()
message_id=id, user_id=user_id, name=name
).delete()
db.commit() db.commit()
return True return True
@@ -612,21 +561,15 @@ class MessageTable:
with get_db_context(db) as db: with get_db_context(db) as db:
query_builder = db.query(Message).filter( query_builder = db.query(Message).filter(
Message.channel_id.in_(channel_ids), Message.channel_id.in_(channel_ids),
Message.content.ilike(f"%{query}%"), Message.content.ilike(f'%{query}%'),
) )
if start_timestamp: if start_timestamp:
query_builder = query_builder.filter( query_builder = query_builder.filter(Message.created_at >= start_timestamp)
Message.created_at >= start_timestamp
)
if end_timestamp: if end_timestamp:
query_builder = query_builder.filter( query_builder = query_builder.filter(Message.created_at <= end_timestamp)
Message.created_at <= end_timestamp
)
messages = ( messages = query_builder.order_by(Message.created_at.desc()).limit(limit).all()
query_builder.order_by(Message.created_at.desc()).limit(limit).all()
)
return [MessageModel.model_validate(msg) for msg in messages] return [MessageModel.model_validate(msg) for msg in messages]

View File

@@ -28,13 +28,13 @@ log = logging.getLogger(__name__)
# ModelParams is a model for the data stored in the params field of the Model table # ModelParams is a model for the data stored in the params field of the Model table
class ModelParams(BaseModel): class ModelParams(BaseModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra='allow')
pass pass
# ModelMeta is a model for the data stored in the meta field of the Model table # ModelMeta is a model for the data stored in the meta field of the Model table
class ModelMeta(BaseModel): class ModelMeta(BaseModel):
profile_image_url: Optional[str] = "/static/favicon.png" profile_image_url: Optional[str] = '/static/favicon.png'
description: Optional[str] = None description: Optional[str] = None
""" """
@@ -43,13 +43,13 @@ class ModelMeta(BaseModel):
capabilities: Optional[dict] = None capabilities: Optional[dict] = None
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra='allow')
pass pass
class Model(Base): class Model(Base):
__tablename__ = "model" __tablename__ = 'model'
id = Column(Text, primary_key=True, unique=True) id = Column(Text, primary_key=True, unique=True)
""" """
@@ -139,10 +139,8 @@ class ModelForm(BaseModel):
class ModelsTable: class ModelsTable:
def _get_access_grants( def _get_access_grants(self, model_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
self, model_id: str, db: Optional[Session] = None return AccessGrants.get_grants_by_resource('model', model_id, db=db)
) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource("model", model_id, db=db)
def _to_model_model( def _to_model_model(
self, self,
@@ -150,13 +148,9 @@ class ModelsTable:
access_grants: Optional[list[AccessGrantModel]] = None, access_grants: Optional[list[AccessGrantModel]] = None,
db: Optional[Session] = None, db: Optional[Session] = None,
) -> ModelModel: ) -> ModelModel:
model_data = ModelModel.model_validate(model).model_dump( model_data = ModelModel.model_validate(model).model_dump(exclude={'access_grants'})
exclude={"access_grants"} model_data['access_grants'] = (
) access_grants if access_grants is not None else self._get_access_grants(model_data['id'], db=db)
model_data["access_grants"] = (
access_grants
if access_grants is not None
else self._get_access_grants(model_data["id"], db=db)
) )
return ModelModel.model_validate(model_data) return ModelModel.model_validate(model_data)
@@ -167,37 +161,32 @@ class ModelsTable:
with get_db_context(db) as db: with get_db_context(db) as db:
result = Model( result = Model(
**{ **{
**form_data.model_dump(exclude={"access_grants"}), **form_data.model_dump(exclude={'access_grants'}),
"user_id": user_id, 'user_id': user_id,
"created_at": int(time.time()), 'created_at': int(time.time()),
"updated_at": int(time.time()), 'updated_at': int(time.time()),
} }
) )
db.add(result) db.add(result)
db.commit() db.commit()
db.refresh(result) db.refresh(result)
AccessGrants.set_access_grants( AccessGrants.set_access_grants('model', result.id, form_data.access_grants, db=db)
"model", result.id, form_data.access_grants, db=db
)
if result: if result:
return self._to_model_model(result, db=db) return self._to_model_model(result, db=db)
else: else:
return None return None
except Exception as e: except Exception as e:
log.exception(f"Failed to insert a new model: {e}") log.exception(f'Failed to insert a new model: {e}')
return None return None
def get_all_models(self, db: Optional[Session] = None) -> list[ModelModel]: def get_all_models(self, db: Optional[Session] = None) -> list[ModelModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
all_models = db.query(Model).all() all_models = db.query(Model).all()
model_ids = [model.id for model in all_models] model_ids = [model.id for model in all_models]
grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db) grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
return [ return [
self._to_model_model( self._to_model_model(model, access_grants=grants_map.get(model.id, []), db=db) for model in all_models
model, access_grants=grants_map.get(model.id, []), db=db
)
for model in all_models
] ]
def get_models(self, db: Optional[Session] = None) -> list[ModelUserResponse]: def get_models(self, db: Optional[Session] = None) -> list[ModelUserResponse]:
@@ -209,7 +198,7 @@ class ModelsTable:
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
users_dict = {user.id: user for user in users} users_dict = {user.id: user for user in users}
grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db) grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
models = [] models = []
for model in all_models: for model in all_models:
@@ -222,7 +211,7 @@ class ModelsTable:
access_grants=grants_map.get(model.id, []), access_grants=grants_map.get(model.id, []),
db=db, db=db,
).model_dump(), ).model_dump(),
"user": user.model_dump() if user else None, 'user': user.model_dump() if user else None,
} }
) )
) )
@@ -232,28 +221,23 @@ class ModelsTable:
with get_db_context(db) as db: with get_db_context(db) as db:
all_models = db.query(Model).filter(Model.base_model_id == None).all() all_models = db.query(Model).filter(Model.base_model_id == None).all()
model_ids = [model.id for model in all_models] model_ids = [model.id for model in all_models]
grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db) grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
return [ return [
self._to_model_model( self._to_model_model(model, access_grants=grants_map.get(model.id, []), db=db) for model in all_models
model, access_grants=grants_map.get(model.id, []), db=db
)
for model in all_models
] ]
def get_models_by_user_id( def get_models_by_user_id(
self, user_id: str, permission: str = "write", db: Optional[Session] = None self, user_id: str, permission: str = 'write', db: Optional[Session] = None
) -> list[ModelUserResponse]: ) -> list[ModelUserResponse]:
models = self.get_models(db=db) models = self.get_models(db=db)
user_group_ids = { user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
}
return [ return [
model model
for model in models for model in models
if model.user_id == user_id if model.user_id == user_id
or AccessGrants.has_access( or AccessGrants.has_access(
user_id=user_id, user_id=user_id,
resource_type="model", resource_type='model',
resource_id=model.id, resource_id=model.id,
permission=permission, permission=permission,
user_group_ids=user_group_ids, user_group_ids=user_group_ids,
@@ -261,13 +245,13 @@ class ModelsTable:
) )
] ]
def _has_permission(self, db, query, filter: dict, permission: str = "read"): def _has_permission(self, db, query, filter: dict, permission: str = 'read'):
return AccessGrants.has_permission_filter( return AccessGrants.has_permission_filter(
db=db, db=db,
query=query, query=query,
DocumentModel=Model, DocumentModel=Model,
filter=filter, filter=filter,
resource_type="model", resource_type='model',
permission=permission, permission=permission,
) )
@@ -285,22 +269,22 @@ class ModelsTable:
query = query.filter(Model.base_model_id != None) query = query.filter(Model.base_model_id != None)
if filter: if filter:
query_key = filter.get("query") query_key = filter.get('query')
if query_key: if query_key:
query = query.filter( query = query.filter(
or_( or_(
Model.name.ilike(f"%{query_key}%"), Model.name.ilike(f'%{query_key}%'),
Model.base_model_id.ilike(f"%{query_key}%"), Model.base_model_id.ilike(f'%{query_key}%'),
User.name.ilike(f"%{query_key}%"), User.name.ilike(f'%{query_key}%'),
User.email.ilike(f"%{query_key}%"), User.email.ilike(f'%{query_key}%'),
User.username.ilike(f"%{query_key}%"), User.username.ilike(f'%{query_key}%'),
) )
) )
view_option = filter.get("view_option") view_option = filter.get('view_option')
if view_option == "created": if view_option == 'created':
query = query.filter(Model.user_id == user_id) query = query.filter(Model.user_id == user_id)
elif view_option == "shared": elif view_option == 'shared':
query = query.filter(Model.user_id != user_id) query = query.filter(Model.user_id != user_id)
# Apply access control filtering # Apply access control filtering
@@ -308,10 +292,10 @@ class ModelsTable:
db, db,
query, query,
filter, filter,
permission="read", permission='read',
) )
tag = filter.get("tag") tag = filter.get('tag')
if tag: if tag:
# TODO: This is a simple implementation and should be improved for performance # TODO: This is a simple implementation and should be improved for performance
like_pattern = f'%"{tag.lower()}"%' # `"tag"` inside JSON array like_pattern = f'%"{tag.lower()}"%' # `"tag"` inside JSON array
@@ -319,21 +303,21 @@ class ModelsTable:
query = query.filter(meta_text.like(like_pattern)) query = query.filter(meta_text.like(like_pattern))
order_by = filter.get("order_by") order_by = filter.get('order_by')
direction = filter.get("direction") direction = filter.get('direction')
if order_by == "name": if order_by == 'name':
if direction == "asc": if direction == 'asc':
query = query.order_by(Model.name.asc()) query = query.order_by(Model.name.asc())
else: else:
query = query.order_by(Model.name.desc()) query = query.order_by(Model.name.desc())
elif order_by == "created_at": elif order_by == 'created_at':
if direction == "asc": if direction == 'asc':
query = query.order_by(Model.created_at.asc()) query = query.order_by(Model.created_at.asc())
else: else:
query = query.order_by(Model.created_at.desc()) query = query.order_by(Model.created_at.desc())
elif order_by == "updated_at": elif order_by == 'updated_at':
if direction == "asc": if direction == 'asc':
query = query.order_by(Model.updated_at.asc()) query = query.order_by(Model.updated_at.asc())
else: else:
query = query.order_by(Model.updated_at.desc()) query = query.order_by(Model.updated_at.desc())
@@ -352,7 +336,7 @@ class ModelsTable:
items = query.all() items = query.all()
model_ids = [model.id for model, _ in items] model_ids = [model.id for model, _ in items]
grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db) grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
models = [] models = []
for model, user in items: for model, user in items:
@@ -363,19 +347,13 @@ class ModelsTable:
access_grants=grants_map.get(model.id, []), access_grants=grants_map.get(model.id, []),
db=db, db=db,
).model_dump(), ).model_dump(),
user=( user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
UserResponse(**UserModel.model_validate(user).model_dump())
if user
else None
),
) )
) )
return ModelListResponse(items=models, total=total) return ModelListResponse(items=models, total=total)
def get_model_by_id( def get_model_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ModelModel]:
self, id: str, db: Optional[Session] = None
) -> Optional[ModelModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
model = db.get(Model, id) model = db.get(Model, id)
@@ -383,16 +361,12 @@ class ModelsTable:
except Exception: except Exception:
return None return None
def get_models_by_ids( def get_models_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[ModelModel]:
self, ids: list[str], db: Optional[Session] = None
) -> list[ModelModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
models = db.query(Model).filter(Model.id.in_(ids)).all() models = db.query(Model).filter(Model.id.in_(ids)).all()
model_ids = [model.id for model in models] model_ids = [model.id for model in models]
grants_map = AccessGrants.get_grants_by_resources( grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
"model", model_ids, db=db
)
return [ return [
self._to_model_model( self._to_model_model(
model, model,
@@ -404,9 +378,7 @@ class ModelsTable:
except Exception: except Exception:
return [] return []
def toggle_model_by_id( def toggle_model_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ModelModel]:
self, id: str, db: Optional[Session] = None
) -> Optional[ModelModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
try: try:
model = db.query(Model).filter_by(id=id).first() model = db.query(Model).filter_by(id=id).first()
@@ -422,30 +394,26 @@ class ModelsTable:
except Exception: except Exception:
return None return None
def update_model_by_id( def update_model_by_id(self, id: str, model: ModelForm, db: Optional[Session] = None) -> Optional[ModelModel]:
self, id: str, model: ModelForm, db: Optional[Session] = None
) -> Optional[ModelModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
# update only the fields that are present in the model # update only the fields that are present in the model
data = model.model_dump(exclude={"id", "access_grants"}) data = model.model_dump(exclude={'id', 'access_grants'})
result = db.query(Model).filter_by(id=id).update(data) result = db.query(Model).filter_by(id=id).update(data)
db.commit() db.commit()
if model.access_grants is not None: if model.access_grants is not None:
AccessGrants.set_access_grants( AccessGrants.set_access_grants('model', id, model.access_grants, db=db)
"model", id, model.access_grants, db=db
)
return self.get_model_by_id(id, db=db) return self.get_model_by_id(id, db=db)
except Exception as e: except Exception as e:
log.exception(f"Failed to update the model by id {id}: {e}") log.exception(f'Failed to update the model by id {id}: {e}')
return None return None
def delete_model_by_id(self, id: str, db: Optional[Session] = None) -> bool: def delete_model_by_id(self, id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
AccessGrants.revoke_all_access("model", id, db=db) AccessGrants.revoke_all_access('model', id, db=db)
db.query(Model).filter_by(id=id).delete() db.query(Model).filter_by(id=id).delete()
db.commit() db.commit()
@@ -458,7 +426,7 @@ class ModelsTable:
with get_db_context(db) as db: with get_db_context(db) as db:
model_ids = [row[0] for row in db.query(Model.id).all()] model_ids = [row[0] for row in db.query(Model.id).all()]
for model_id in model_ids: for model_id in model_ids:
AccessGrants.revoke_all_access("model", model_id, db=db) AccessGrants.revoke_all_access('model', model_id, db=db)
db.query(Model).delete() db.query(Model).delete()
db.commit() db.commit()
@@ -466,9 +434,7 @@ class ModelsTable:
except Exception: except Exception:
return False return False
def sync_models( def sync_models(self, user_id: str, models: list[ModelModel], db: Optional[Session] = None) -> list[ModelModel]:
self, user_id: str, models: list[ModelModel], db: Optional[Session] = None
) -> list[ModelModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
# Get existing models # Get existing models
@@ -483,37 +449,33 @@ class ModelsTable:
if model.id in existing_ids: if model.id in existing_ids:
db.query(Model).filter_by(id=model.id).update( db.query(Model).filter_by(id=model.id).update(
{ {
**model.model_dump(exclude={"access_grants"}), **model.model_dump(exclude={'access_grants'}),
"user_id": user_id, 'user_id': user_id,
"updated_at": int(time.time()), 'updated_at': int(time.time()),
} }
) )
else: else:
new_model = Model( new_model = Model(
**{ **{
**model.model_dump(exclude={"access_grants"}), **model.model_dump(exclude={'access_grants'}),
"user_id": user_id, 'user_id': user_id,
"updated_at": int(time.time()), 'updated_at': int(time.time()),
} }
) )
db.add(new_model) db.add(new_model)
AccessGrants.set_access_grants( AccessGrants.set_access_grants('model', model.id, model.access_grants, db=db)
"model", model.id, model.access_grants, db=db
)
# Remove models that are no longer present # Remove models that are no longer present
for model in existing_models: for model in existing_models:
if model.id not in new_model_ids: if model.id not in new_model_ids:
AccessGrants.revoke_all_access("model", model.id, db=db) AccessGrants.revoke_all_access('model', model.id, db=db)
db.delete(model) db.delete(model)
db.commit() db.commit()
all_models = db.query(Model).all() all_models = db.query(Model).all()
model_ids = [model.id for model in all_models] model_ids = [model.id for model in all_models]
grants_map = AccessGrants.get_grants_by_resources( grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
"model", model_ids, db=db
)
return [ return [
self._to_model_model( self._to_model_model(
model, model,
@@ -523,7 +485,7 @@ class ModelsTable:
for model in all_models for model in all_models
] ]
except Exception as e: except Exception as e:
log.exception(f"Error syncing models for user {user_id}: {e}") log.exception(f'Error syncing models for user {user_id}: {e}')
return [] return []

View File

@@ -21,7 +21,7 @@ from sqlalchemy import or_, func, cast
class Note(Base): class Note(Base):
__tablename__ = "note" __tablename__ = 'note'
id = Column(Text, primary_key=True, unique=True) id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text) user_id = Column(Text)
@@ -88,10 +88,8 @@ class NoteListResponse(BaseModel):
class NoteTable: class NoteTable:
def _get_access_grants( def _get_access_grants(self, note_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
self, note_id: str, db: Optional[Session] = None return AccessGrants.get_grants_by_resource('note', note_id, db=db)
) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource("note", note_id, db=db)
def _to_note_model( def _to_note_model(
self, self,
@@ -99,51 +97,43 @@ class NoteTable:
access_grants: Optional[list[AccessGrantModel]] = None, access_grants: Optional[list[AccessGrantModel]] = None,
db: Optional[Session] = None, db: Optional[Session] = None,
) -> NoteModel: ) -> NoteModel:
note_data = NoteModel.model_validate(note).model_dump(exclude={"access_grants"}) note_data = NoteModel.model_validate(note).model_dump(exclude={'access_grants'})
note_data["access_grants"] = ( note_data['access_grants'] = (
access_grants access_grants if access_grants is not None else self._get_access_grants(note_data['id'], db=db)
if access_grants is not None
else self._get_access_grants(note_data["id"], db=db)
) )
return NoteModel.model_validate(note_data) return NoteModel.model_validate(note_data)
def _has_permission(self, db, query, filter: dict, permission: str = "read"): def _has_permission(self, db, query, filter: dict, permission: str = 'read'):
return AccessGrants.has_permission_filter( return AccessGrants.has_permission_filter(
db=db, db=db,
query=query, query=query,
DocumentModel=Note, DocumentModel=Note,
filter=filter, filter=filter,
resource_type="note", resource_type='note',
permission=permission, permission=permission,
) )
def insert_new_note( def insert_new_note(self, user_id: str, form_data: NoteForm, db: Optional[Session] = None) -> Optional[NoteModel]:
self, user_id: str, form_data: NoteForm, db: Optional[Session] = None
) -> Optional[NoteModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
note = NoteModel( note = NoteModel(
**{ **{
"id": str(uuid.uuid4()), 'id': str(uuid.uuid4()),
"user_id": user_id, 'user_id': user_id,
**form_data.model_dump(exclude={"access_grants"}), **form_data.model_dump(exclude={'access_grants'}),
"created_at": int(time.time_ns()), 'created_at': int(time.time_ns()),
"updated_at": int(time.time_ns()), 'updated_at': int(time.time_ns()),
"access_grants": [], 'access_grants': [],
} }
) )
new_note = Note(**note.model_dump(exclude={"access_grants"})) new_note = Note(**note.model_dump(exclude={'access_grants'}))
db.add(new_note) db.add(new_note)
db.commit() db.commit()
AccessGrants.set_access_grants( AccessGrants.set_access_grants('note', note.id, form_data.access_grants, db=db)
"note", note.id, form_data.access_grants, db=db
)
return self._to_note_model(new_note, db=db) return self._to_note_model(new_note, db=db)
def get_notes( def get_notes(self, skip: int = 0, limit: int = 50, db: Optional[Session] = None) -> list[NoteModel]:
self, skip: int = 0, limit: int = 50, db: Optional[Session] = None
) -> list[NoteModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
query = db.query(Note).order_by(Note.updated_at.desc()) query = db.query(Note).order_by(Note.updated_at.desc())
if skip is not None: if skip is not None:
@@ -152,13 +142,8 @@ class NoteTable:
query = query.limit(limit) query = query.limit(limit)
notes = query.all() notes = query.all()
note_ids = [note.id for note in notes] note_ids = [note.id for note in notes]
grants_map = AccessGrants.get_grants_by_resources("note", note_ids, db=db) grants_map = AccessGrants.get_grants_by_resources('note', note_ids, db=db)
return [ return [self._to_note_model(note, access_grants=grants_map.get(note.id, []), db=db) for note in notes]
self._to_note_model(
note, access_grants=grants_map.get(note.id, []), db=db
)
for note in notes
]
def search_notes( def search_notes(
self, self,
@@ -171,36 +156,32 @@ class NoteTable:
with get_db_context(db) as db: with get_db_context(db) as db:
query = db.query(Note, User).outerjoin(User, User.id == Note.user_id) query = db.query(Note, User).outerjoin(User, User.id == Note.user_id)
if filter: if filter:
query_key = filter.get("query") query_key = filter.get('query')
if query_key: if query_key:
# Normalize search by removing hyphens and spaces (e.g., "todo" matches "to-do" and "to do") # Normalize search by removing hyphens and spaces (e.g., "todo" matches "to-do" and "to do")
normalized_query = query_key.replace("-", "").replace(" ", "") normalized_query = query_key.replace('-', '').replace(' ', '')
query = query.filter( query = query.filter(
or_( or_(
func.replace(func.replace(Note.title, '-', ''), ' ', '').ilike(f'%{normalized_query}%'),
func.replace( func.replace(
func.replace(Note.title, "-", ""), " ", "" func.replace(cast(Note.data['content']['md'], Text), '-', ''),
).ilike(f"%{normalized_query}%"), ' ',
func.replace( '',
func.replace( ).ilike(f'%{normalized_query}%'),
cast(Note.data["content"]["md"], Text), "-", ""
),
" ",
"",
).ilike(f"%{normalized_query}%"),
) )
) )
view_option = filter.get("view_option") view_option = filter.get('view_option')
if view_option == "created": if view_option == 'created':
query = query.filter(Note.user_id == user_id) query = query.filter(Note.user_id == user_id)
elif view_option == "shared": elif view_option == 'shared':
query = query.filter(Note.user_id != user_id) query = query.filter(Note.user_id != user_id)
# Apply access control filtering # Apply access control filtering
if "permission" in filter: if 'permission' in filter:
permission = filter["permission"] permission = filter['permission']
else: else:
permission = "write" permission = 'write'
query = self._has_permission( query = self._has_permission(
db, db,
@@ -209,21 +190,21 @@ class NoteTable:
permission=permission, permission=permission,
) )
order_by = filter.get("order_by") order_by = filter.get('order_by')
direction = filter.get("direction") direction = filter.get('direction')
if order_by == "name": if order_by == 'name':
if direction == "asc": if direction == 'asc':
query = query.order_by(Note.title.asc()) query = query.order_by(Note.title.asc())
else: else:
query = query.order_by(Note.title.desc()) query = query.order_by(Note.title.desc())
elif order_by == "created_at": elif order_by == 'created_at':
if direction == "asc": if direction == 'asc':
query = query.order_by(Note.created_at.asc()) query = query.order_by(Note.created_at.asc())
else: else:
query = query.order_by(Note.created_at.desc()) query = query.order_by(Note.created_at.desc())
elif order_by == "updated_at": elif order_by == 'updated_at':
if direction == "asc": if direction == 'asc':
query = query.order_by(Note.updated_at.asc()) query = query.order_by(Note.updated_at.asc())
else: else:
query = query.order_by(Note.updated_at.desc()) query = query.order_by(Note.updated_at.desc())
@@ -244,7 +225,7 @@ class NoteTable:
items = query.all() items = query.all()
note_ids = [note.id for note, _ in items] note_ids = [note.id for note, _ in items]
grants_map = AccessGrants.get_grants_by_resources("note", note_ids, db=db) grants_map = AccessGrants.get_grants_by_resources('note', note_ids, db=db)
notes = [] notes = []
for note, user in items: for note, user in items:
@@ -255,11 +236,7 @@ class NoteTable:
access_grants=grants_map.get(note.id, []), access_grants=grants_map.get(note.id, []),
db=db, db=db,
).model_dump(), ).model_dump(),
user=( user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
UserResponse(**UserModel.model_validate(user).model_dump())
if user
else None
),
) )
) )
@@ -268,20 +245,16 @@ class NoteTable:
def get_notes_by_user_id( def get_notes_by_user_id(
self, self,
user_id: str, user_id: str,
permission: str = "read", permission: str = 'read',
skip: int = 0, skip: int = 0,
limit: int = 50, limit: int = 50,
db: Optional[Session] = None, db: Optional[Session] = None,
) -> list[NoteModel]: ) -> list[NoteModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
user_group_ids = [ user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id, db=db)]
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
]
query = db.query(Note).order_by(Note.updated_at.desc()) query = db.query(Note).order_by(Note.updated_at.desc())
query = self._has_permission( query = self._has_permission(db, query, {'user_id': user_id, 'group_ids': user_group_ids}, permission)
db, query, {"user_id": user_id, "group_ids": user_group_ids}, permission
)
if skip is not None: if skip is not None:
query = query.offset(skip) query = query.offset(skip)
@@ -290,17 +263,10 @@ class NoteTable:
notes = query.all() notes = query.all()
note_ids = [note.id for note in notes] note_ids = [note.id for note in notes]
grants_map = AccessGrants.get_grants_by_resources("note", note_ids, db=db) grants_map = AccessGrants.get_grants_by_resources('note', note_ids, db=db)
return [ return [self._to_note_model(note, access_grants=grants_map.get(note.id, []), db=db) for note in notes]
self._to_note_model(
note, access_grants=grants_map.get(note.id, []), db=db
)
for note in notes
]
def get_note_by_id( def get_note_by_id(self, id: str, db: Optional[Session] = None) -> Optional[NoteModel]:
self, id: str, db: Optional[Session] = None
) -> Optional[NoteModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
note = db.query(Note).filter(Note.id == id).first() note = db.query(Note).filter(Note.id == id).first()
return self._to_note_model(note, db=db) if note else None return self._to_note_model(note, db=db) if note else None
@@ -315,17 +281,15 @@ class NoteTable:
form_data = form_data.model_dump(exclude_unset=True) form_data = form_data.model_dump(exclude_unset=True)
if "title" in form_data: if 'title' in form_data:
note.title = form_data["title"] note.title = form_data['title']
if "data" in form_data: if 'data' in form_data:
note.data = {**note.data, **form_data["data"]} note.data = {**note.data, **form_data['data']}
if "meta" in form_data: if 'meta' in form_data:
note.meta = {**note.meta, **form_data["meta"]} note.meta = {**note.meta, **form_data['meta']}
if "access_grants" in form_data: if 'access_grants' in form_data:
AccessGrants.set_access_grants( AccessGrants.set_access_grants('note', id, form_data['access_grants'], db=db)
"note", id, form_data["access_grants"], db=db
)
note.updated_at = int(time.time_ns()) note.updated_at = int(time.time_ns())
@@ -335,7 +299,7 @@ class NoteTable:
def delete_note_by_id(self, id: str, db: Optional[Session] = None) -> bool: def delete_note_by_id(self, id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
AccessGrants.revoke_all_access("note", id, db=db) AccessGrants.revoke_all_access('note', id, db=db)
db.query(Note).filter(Note.id == id).delete() db.query(Note).filter(Note.id == id).delete()
db.commit() db.commit()
return True return True

View File

@@ -23,23 +23,21 @@ log = logging.getLogger(__name__)
class OAuthSession(Base): class OAuthSession(Base):
__tablename__ = "oauth_session" __tablename__ = 'oauth_session'
id = Column(Text, primary_key=True, unique=True) id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text, nullable=False) user_id = Column(Text, nullable=False)
provider = Column(Text, nullable=False) provider = Column(Text, nullable=False)
token = Column( token = Column(Text, nullable=False) # JSON with access_token, id_token, refresh_token
Text, nullable=False
) # JSON with access_token, id_token, refresh_token
expires_at = Column(BigInteger, nullable=False) expires_at = Column(BigInteger, nullable=False)
created_at = Column(BigInteger, nullable=False) created_at = Column(BigInteger, nullable=False)
updated_at = Column(BigInteger, nullable=False) updated_at = Column(BigInteger, nullable=False)
# Add indexes for better performance # Add indexes for better performance
__table_args__ = ( __table_args__ = (
Index("idx_oauth_session_user_id", "user_id"), Index('idx_oauth_session_user_id', 'user_id'),
Index("idx_oauth_session_expires_at", "expires_at"), Index('idx_oauth_session_expires_at', 'expires_at'),
Index("idx_oauth_session_user_provider", "user_id", "provider"), Index('idx_oauth_session_user_provider', 'user_id', 'provider'),
) )
@@ -71,7 +69,7 @@ class OAuthSessionTable:
def __init__(self): def __init__(self):
self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
if not self.encryption_key: if not self.encryption_key:
raise Exception("OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set") raise Exception('OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set')
# check if encryption key is in the right format for Fernet (32 url-safe base64-encoded bytes) # check if encryption key is in the right format for Fernet (32 url-safe base64-encoded bytes)
if len(self.encryption_key) != 44: if len(self.encryption_key) != 44:
@@ -83,7 +81,7 @@ class OAuthSessionTable:
try: try:
self.fernet = Fernet(self.encryption_key) self.fernet = Fernet(self.encryption_key)
except Exception as e: except Exception as e:
log.error(f"Error initializing Fernet with provided key: {e}") log.error(f'Error initializing Fernet with provided key: {e}')
raise raise
def _encrypt_token(self, token) -> str: def _encrypt_token(self, token) -> str:
@@ -93,7 +91,7 @@ class OAuthSessionTable:
encrypted = self.fernet.encrypt(token_json.encode()).decode() encrypted = self.fernet.encrypt(token_json.encode()).decode()
return encrypted return encrypted
except Exception as e: except Exception as e:
log.error(f"Error encrypting tokens: {e}") log.error(f'Error encrypting tokens: {e}')
raise raise
def _decrypt_token(self, token: str): def _decrypt_token(self, token: str):
@@ -102,7 +100,7 @@ class OAuthSessionTable:
decrypted = self.fernet.decrypt(token.encode()).decode() decrypted = self.fernet.decrypt(token.encode()).decode()
return json.loads(decrypted) return json.loads(decrypted)
except Exception as e: except Exception as e:
log.error(f"Error decrypting tokens: {type(e).__name__}: {e}") log.error(f'Error decrypting tokens: {type(e).__name__}: {e}')
raise raise
def create_session( def create_session(
@@ -120,13 +118,13 @@ class OAuthSessionTable:
result = OAuthSession( result = OAuthSession(
**{ **{
"id": id, 'id': id,
"user_id": user_id, 'user_id': user_id,
"provider": provider, 'provider': provider,
"token": self._encrypt_token(token), 'token': self._encrypt_token(token),
"expires_at": token.get("expires_at"), 'expires_at': token.get('expires_at'),
"created_at": current_time, 'created_at': current_time,
"updated_at": current_time, 'updated_at': current_time,
} }
) )
@@ -141,12 +139,10 @@ class OAuthSessionTable:
else: else:
return None return None
except Exception as e: except Exception as e:
log.error(f"Error creating OAuth session: {e}") log.error(f'Error creating OAuth session: {e}')
return None return None
def get_session_by_id( def get_session_by_id(self, session_id: str, db: Optional[Session] = None) -> Optional[OAuthSessionModel]:
self, session_id: str, db: Optional[Session] = None
) -> Optional[OAuthSessionModel]:
"""Get OAuth session by ID""" """Get OAuth session by ID"""
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
@@ -158,7 +154,7 @@ class OAuthSessionTable:
return None return None
except Exception as e: except Exception as e:
log.error(f"Error getting OAuth session by ID: {e}") log.error(f'Error getting OAuth session by ID: {e}')
return None return None
def get_session_by_id_and_user_id( def get_session_by_id_and_user_id(
@@ -167,11 +163,7 @@ class OAuthSessionTable:
"""Get OAuth session by ID and user ID""" """Get OAuth session by ID and user ID"""
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
session = ( session = db.query(OAuthSession).filter_by(id=session_id, user_id=user_id).first()
db.query(OAuthSession)
.filter_by(id=session_id, user_id=user_id)
.first()
)
if session: if session:
db.expunge(session) db.expunge(session)
session.token = self._decrypt_token(session.token) session.token = self._decrypt_token(session.token)
@@ -179,7 +171,7 @@ class OAuthSessionTable:
return None return None
except Exception as e: except Exception as e:
log.error(f"Error getting OAuth session by ID: {e}") log.error(f'Error getting OAuth session by ID: {e}')
return None return None
def get_session_by_provider_and_user_id( def get_session_by_provider_and_user_id(
@@ -201,12 +193,10 @@ class OAuthSessionTable:
return None return None
except Exception as e: except Exception as e:
log.error(f"Error getting OAuth session by provider and user ID: {e}") log.error(f'Error getting OAuth session by provider and user ID: {e}')
return None return None
def get_sessions_by_user_id( def get_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> List[OAuthSessionModel]:
self, user_id: str, db: Optional[Session] = None
) -> List[OAuthSessionModel]:
"""Get all OAuth sessions for a user""" """Get all OAuth sessions for a user"""
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
@@ -220,7 +210,7 @@ class OAuthSessionTable:
results.append(OAuthSessionModel.model_validate(session)) results.append(OAuthSessionModel.model_validate(session))
except Exception as e: except Exception as e:
log.warning( log.warning(
f"Skipping OAuth session {session.id} due to decryption failure, deleting corrupted session: {type(e).__name__}: {e}" f'Skipping OAuth session {session.id} due to decryption failure, deleting corrupted session: {type(e).__name__}: {e}'
) )
db.query(OAuthSession).filter_by(id=session.id).delete() db.query(OAuthSession).filter_by(id=session.id).delete()
db.commit() db.commit()
@@ -228,7 +218,7 @@ class OAuthSessionTable:
return results return results
except Exception as e: except Exception as e:
log.error(f"Error getting OAuth sessions by user ID: {e}") log.error(f'Error getting OAuth sessions by user ID: {e}')
return [] return []
def update_session_by_id( def update_session_by_id(
@@ -241,9 +231,9 @@ class OAuthSessionTable:
db.query(OAuthSession).filter_by(id=session_id).update( db.query(OAuthSession).filter_by(id=session_id).update(
{ {
"token": self._encrypt_token(token), 'token': self._encrypt_token(token),
"expires_at": token.get("expires_at"), 'expires_at': token.get('expires_at'),
"updated_at": current_time, 'updated_at': current_time,
} }
) )
db.commit() db.commit()
@@ -256,12 +246,10 @@ class OAuthSessionTable:
return None return None
except Exception as e: except Exception as e:
log.error(f"Error updating OAuth session tokens: {e}") log.error(f'Error updating OAuth session tokens: {e}')
return None return None
def delete_session_by_id( def delete_session_by_id(self, session_id: str, db: Optional[Session] = None) -> bool:
self, session_id: str, db: Optional[Session] = None
) -> bool:
"""Delete an OAuth session""" """Delete an OAuth session"""
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
@@ -269,12 +257,10 @@ class OAuthSessionTable:
db.commit() db.commit()
return result > 0 return result > 0
except Exception as e: except Exception as e:
log.error(f"Error deleting OAuth session: {e}") log.error(f'Error deleting OAuth session: {e}')
return False return False
def delete_sessions_by_user_id( def delete_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
self, user_id: str, db: Optional[Session] = None
) -> bool:
"""Delete all OAuth sessions for a user""" """Delete all OAuth sessions for a user"""
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
@@ -282,12 +268,10 @@ class OAuthSessionTable:
db.commit() db.commit()
return True return True
except Exception as e: except Exception as e:
log.error(f"Error deleting OAuth sessions by user ID: {e}") log.error(f'Error deleting OAuth sessions by user ID: {e}')
return False return False
def delete_sessions_by_provider( def delete_sessions_by_provider(self, provider: str, db: Optional[Session] = None) -> bool:
self, provider: str, db: Optional[Session] = None
) -> bool:
"""Delete all OAuth sessions for a provider""" """Delete all OAuth sessions for a provider"""
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
@@ -295,7 +279,7 @@ class OAuthSessionTable:
db.commit() db.commit()
return True return True
except Exception as e: except Exception as e:
log.error(f"Error deleting OAuth sessions by provider {provider}: {e}") log.error(f'Error deleting OAuth sessions by provider {provider}: {e}')
return False return False

View File

@@ -19,7 +19,7 @@ from sqlalchemy import BigInteger, Column, Text, JSON, Index
class PromptHistory(Base): class PromptHistory(Base):
__tablename__ = "prompt_history" __tablename__ = 'prompt_history'
id = Column(Text, primary_key=True) id = Column(Text, primary_key=True)
prompt_id = Column(Text, nullable=False, index=True) prompt_id = Column(Text, nullable=False, index=True)
@@ -100,11 +100,7 @@ class PromptHistoryTable:
return [ return [
PromptHistoryResponse( PromptHistoryResponse(
**PromptHistoryModel.model_validate(entry).model_dump(), **PromptHistoryModel.model_validate(entry).model_dump(),
user=( user=(users_dict.get(entry.user_id).model_dump() if users_dict.get(entry.user_id) else None),
users_dict.get(entry.user_id).model_dump()
if users_dict.get(entry.user_id)
else None
),
) )
for entry in entries for entry in entries
] ]
@@ -116,9 +112,7 @@ class PromptHistoryTable:
) -> Optional[PromptHistoryModel]: ) -> Optional[PromptHistoryModel]:
"""Get a specific history entry by ID.""" """Get a specific history entry by ID."""
with get_db_context(db) as db: with get_db_context(db) as db:
entry = ( entry = db.query(PromptHistory).filter(PromptHistory.id == history_id).first()
db.query(PromptHistory).filter(PromptHistory.id == history_id).first()
)
if entry: if entry:
return PromptHistoryModel.model_validate(entry) return PromptHistoryModel.model_validate(entry)
return None return None
@@ -147,11 +141,7 @@ class PromptHistoryTable:
) -> int: ) -> int:
"""Get the number of history entries for a prompt.""" """Get the number of history entries for a prompt."""
with get_db_context(db) as db: with get_db_context(db) as db:
return ( return db.query(PromptHistory).filter(PromptHistory.prompt_id == prompt_id).count()
db.query(PromptHistory)
.filter(PromptHistory.prompt_id == prompt_id)
.count()
)
def compute_diff( def compute_diff(
self, self,
@@ -161,9 +151,7 @@ class PromptHistoryTable:
) -> Optional[dict]: ) -> Optional[dict]:
"""Compute diff between two history entries.""" """Compute diff between two history entries."""
with get_db_context(db) as db: with get_db_context(db) as db:
from_entry = ( from_entry = db.query(PromptHistory).filter(PromptHistory.id == from_id).first()
db.query(PromptHistory).filter(PromptHistory.id == from_id).first()
)
to_entry = db.query(PromptHistory).filter(PromptHistory.id == to_id).first() to_entry = db.query(PromptHistory).filter(PromptHistory.id == to_id).first()
if not from_entry or not to_entry: if not from_entry or not to_entry:
@@ -173,26 +161,26 @@ class PromptHistoryTable:
to_snapshot = to_entry.snapshot to_snapshot = to_entry.snapshot
# Compute diff for content field # Compute diff for content field
from_content = from_snapshot.get("content", "") from_content = from_snapshot.get('content', '')
to_content = to_snapshot.get("content", "") to_content = to_snapshot.get('content', '')
diff_lines = list( diff_lines = list(
difflib.unified_diff( difflib.unified_diff(
from_content.splitlines(keepends=True), from_content.splitlines(keepends=True),
to_content.splitlines(keepends=True), to_content.splitlines(keepends=True),
fromfile=f"v{from_id[:8]}", fromfile=f'v{from_id[:8]}',
tofile=f"v{to_id[:8]}", tofile=f'v{to_id[:8]}',
lineterm="", lineterm='',
) )
) )
return { return {
"from_id": from_id, 'from_id': from_id,
"to_id": to_id, 'to_id': to_id,
"from_snapshot": from_snapshot, 'from_snapshot': from_snapshot,
"to_snapshot": to_snapshot, 'to_snapshot': to_snapshot,
"content_diff": diff_lines, 'content_diff': diff_lines,
"name_changed": from_snapshot.get("name") != to_snapshot.get("name"), 'name_changed': from_snapshot.get('name') != to_snapshot.get('name'),
} }
def delete_history_by_prompt_id( def delete_history_by_prompt_id(
@@ -202,9 +190,7 @@ class PromptHistoryTable:
) -> bool: ) -> bool:
"""Delete all history entries for a prompt.""" """Delete all history entries for a prompt."""
with get_db_context(db) as db: with get_db_context(db) as db:
db.query(PromptHistory).filter( db.query(PromptHistory).filter(PromptHistory.prompt_id == prompt_id).delete()
PromptHistory.prompt_id == prompt_id
).delete()
db.commit() db.commit()
return True return True

View File

@@ -19,7 +19,7 @@ from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, or_, fun
class Prompt(Base): class Prompt(Base):
__tablename__ = "prompt" __tablename__ = 'prompt'
id = Column(Text, primary_key=True) id = Column(Text, primary_key=True)
command = Column(String, unique=True, index=True) command = Column(String, unique=True, index=True)
@@ -77,7 +77,6 @@ class PromptAccessListResponse(BaseModel):
class PromptForm(BaseModel): class PromptForm(BaseModel):
command: str command: str
name: str # Changed from title name: str # Changed from title
content: str content: str
@@ -91,10 +90,8 @@ class PromptForm(BaseModel):
class PromptsTable: class PromptsTable:
def _get_access_grants( def _get_access_grants(self, prompt_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
self, prompt_id: str, db: Optional[Session] = None return AccessGrants.get_grants_by_resource('prompt', prompt_id, db=db)
) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource("prompt", prompt_id, db=db)
def _to_prompt_model( def _to_prompt_model(
self, self,
@@ -102,13 +99,9 @@ class PromptsTable:
access_grants: Optional[list[AccessGrantModel]] = None, access_grants: Optional[list[AccessGrantModel]] = None,
db: Optional[Session] = None, db: Optional[Session] = None,
) -> PromptModel: ) -> PromptModel:
prompt_data = PromptModel.model_validate(prompt).model_dump( prompt_data = PromptModel.model_validate(prompt).model_dump(exclude={'access_grants'})
exclude={"access_grants"} prompt_data['access_grants'] = (
) access_grants if access_grants is not None else self._get_access_grants(prompt_data['id'], db=db)
prompt_data["access_grants"] = (
access_grants
if access_grants is not None
else self._get_access_grants(prompt_data["id"], db=db)
) )
return PromptModel.model_validate(prompt_data) return PromptModel.model_validate(prompt_data)
@@ -135,26 +128,22 @@ class PromptsTable:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
result = Prompt(**prompt.model_dump(exclude={"access_grants"})) result = Prompt(**prompt.model_dump(exclude={'access_grants'}))
db.add(result) db.add(result)
db.commit() db.commit()
db.refresh(result) db.refresh(result)
AccessGrants.set_access_grants( AccessGrants.set_access_grants('prompt', prompt_id, form_data.access_grants, db=db)
"prompt", prompt_id, form_data.access_grants, db=db
)
if result: if result:
current_access_grants = self._get_access_grants(prompt_id, db=db) current_access_grants = self._get_access_grants(prompt_id, db=db)
snapshot = { snapshot = {
"name": form_data.name, 'name': form_data.name,
"content": form_data.content, 'content': form_data.content,
"command": form_data.command, 'command': form_data.command,
"data": form_data.data or {}, 'data': form_data.data or {},
"meta": form_data.meta or {}, 'meta': form_data.meta or {},
"tags": form_data.tags or [], 'tags': form_data.tags or [],
"access_grants": [ 'access_grants': [grant.model_dump() for grant in current_access_grants],
grant.model_dump() for grant in current_access_grants
],
} }
history_entry = PromptHistories.create_history_entry( history_entry = PromptHistories.create_history_entry(
@@ -162,7 +151,7 @@ class PromptsTable:
snapshot=snapshot, snapshot=snapshot,
user_id=user_id, user_id=user_id,
parent_id=None, # Initial commit has no parent parent_id=None, # Initial commit has no parent
commit_message=form_data.commit_message or "Initial version", commit_message=form_data.commit_message or 'Initial version',
db=db, db=db,
) )
@@ -178,9 +167,7 @@ class PromptsTable:
except Exception: except Exception:
return None return None
def get_prompt_by_id( def get_prompt_by_id(self, prompt_id: str, db: Optional[Session] = None) -> Optional[PromptModel]:
self, prompt_id: str, db: Optional[Session] = None
) -> Optional[PromptModel]:
"""Get prompt by UUID.""" """Get prompt by UUID."""
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
@@ -191,9 +178,7 @@ class PromptsTable:
except Exception: except Exception:
return None return None
def get_prompt_by_command( def get_prompt_by_command(self, command: str, db: Optional[Session] = None) -> Optional[PromptModel]:
self, command: str, db: Optional[Session] = None
) -> Optional[PromptModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
prompt = db.query(Prompt).filter_by(command=command).first() prompt = db.query(Prompt).filter_by(command=command).first()
@@ -205,21 +190,14 @@ class PromptsTable:
def get_prompts(self, db: Optional[Session] = None) -> list[PromptUserResponse]: def get_prompts(self, db: Optional[Session] = None) -> list[PromptUserResponse]:
with get_db_context(db) as db: with get_db_context(db) as db:
all_prompts = ( all_prompts = db.query(Prompt).filter(Prompt.is_active == True).order_by(Prompt.updated_at.desc()).all()
db.query(Prompt)
.filter(Prompt.is_active == True)
.order_by(Prompt.updated_at.desc())
.all()
)
user_ids = list(set(prompt.user_id for prompt in all_prompts)) user_ids = list(set(prompt.user_id for prompt in all_prompts))
prompt_ids = [prompt.id for prompt in all_prompts] prompt_ids = [prompt.id for prompt in all_prompts]
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
users_dict = {user.id: user for user in users} users_dict = {user.id: user for user in users}
grants_map = AccessGrants.get_grants_by_resources( grants_map = AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db)
"prompt", prompt_ids, db=db
)
prompts = [] prompts = []
for prompt in all_prompts: for prompt in all_prompts:
@@ -232,7 +210,7 @@ class PromptsTable:
access_grants=grants_map.get(prompt.id, []), access_grants=grants_map.get(prompt.id, []),
db=db, db=db,
).model_dump(), ).model_dump(),
"user": user.model_dump() if user else None, 'user': user.model_dump() if user else None,
} }
) )
) )
@@ -240,12 +218,10 @@ class PromptsTable:
return prompts return prompts
def get_prompts_by_user_id( def get_prompts_by_user_id(
self, user_id: str, permission: str = "write", db: Optional[Session] = None self, user_id: str, permission: str = 'write', db: Optional[Session] = None
) -> list[PromptUserResponse]: ) -> list[PromptUserResponse]:
prompts = self.get_prompts(db=db) prompts = self.get_prompts(db=db)
user_group_ids = { user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
}
return [ return [
prompt prompt
@@ -253,7 +229,7 @@ class PromptsTable:
if prompt.user_id == user_id if prompt.user_id == user_id
or AccessGrants.has_access( or AccessGrants.has_access(
user_id=user_id, user_id=user_id,
resource_type="prompt", resource_type='prompt',
resource_id=prompt.id, resource_id=prompt.id,
permission=permission, permission=permission,
user_group_ids=user_group_ids, user_group_ids=user_group_ids,
@@ -276,22 +252,22 @@ class PromptsTable:
query = db.query(Prompt, User).outerjoin(User, User.id == Prompt.user_id) query = db.query(Prompt, User).outerjoin(User, User.id == Prompt.user_id)
if filter: if filter:
query_key = filter.get("query") query_key = filter.get('query')
if query_key: if query_key:
query = query.filter( query = query.filter(
or_( or_(
Prompt.name.ilike(f"%{query_key}%"), Prompt.name.ilike(f'%{query_key}%'),
Prompt.command.ilike(f"%{query_key}%"), Prompt.command.ilike(f'%{query_key}%'),
Prompt.content.ilike(f"%{query_key}%"), Prompt.content.ilike(f'%{query_key}%'),
User.name.ilike(f"%{query_key}%"), User.name.ilike(f'%{query_key}%'),
User.email.ilike(f"%{query_key}%"), User.email.ilike(f'%{query_key}%'),
) )
) )
view_option = filter.get("view_option") view_option = filter.get('view_option')
if view_option == "created": if view_option == 'created':
query = query.filter(Prompt.user_id == user_id) query = query.filter(Prompt.user_id == user_id)
elif view_option == "shared": elif view_option == 'shared':
query = query.filter(Prompt.user_id != user_id) query = query.filter(Prompt.user_id != user_id)
# Apply access grant filtering # Apply access grant filtering
@@ -300,32 +276,32 @@ class PromptsTable:
query=query, query=query,
DocumentModel=Prompt, DocumentModel=Prompt,
filter=filter, filter=filter,
resource_type="prompt", resource_type='prompt',
permission="read", permission='read',
) )
tag = filter.get("tag") tag = filter.get('tag')
if tag: if tag:
# Search for tag in JSON array field # Search for tag in JSON array field
like_pattern = f'%"{tag.lower()}"%' like_pattern = f'%"{tag.lower()}"%'
tags_text = func.lower(cast(Prompt.tags, String)) tags_text = func.lower(cast(Prompt.tags, String))
query = query.filter(tags_text.like(like_pattern)) query = query.filter(tags_text.like(like_pattern))
order_by = filter.get("order_by") order_by = filter.get('order_by')
direction = filter.get("direction") direction = filter.get('direction')
if order_by == "name": if order_by == 'name':
if direction == "asc": if direction == 'asc':
query = query.order_by(Prompt.name.asc()) query = query.order_by(Prompt.name.asc())
else: else:
query = query.order_by(Prompt.name.desc()) query = query.order_by(Prompt.name.desc())
elif order_by == "created_at": elif order_by == 'created_at':
if direction == "asc": if direction == 'asc':
query = query.order_by(Prompt.created_at.asc()) query = query.order_by(Prompt.created_at.asc())
else: else:
query = query.order_by(Prompt.created_at.desc()) query = query.order_by(Prompt.created_at.desc())
elif order_by == "updated_at": elif order_by == 'updated_at':
if direction == "asc": if direction == 'asc':
query = query.order_by(Prompt.updated_at.asc()) query = query.order_by(Prompt.updated_at.asc())
else: else:
query = query.order_by(Prompt.updated_at.desc()) query = query.order_by(Prompt.updated_at.desc())
@@ -345,9 +321,7 @@ class PromptsTable:
items = query.all() items = query.all()
prompt_ids = [prompt.id for prompt, _ in items] prompt_ids = [prompt.id for prompt, _ in items]
grants_map = AccessGrants.get_grants_by_resources( grants_map = AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db)
"prompt", prompt_ids, db=db
)
prompts = [] prompts = []
for prompt, user in items: for prompt, user in items:
@@ -358,11 +332,7 @@ class PromptsTable:
access_grants=grants_map.get(prompt.id, []), access_grants=grants_map.get(prompt.id, []),
db=db, db=db,
).model_dump(), ).model_dump(),
user=( user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
UserResponse(**UserModel.model_validate(user).model_dump())
if user
else None
),
) )
) )
@@ -381,9 +351,7 @@ class PromptsTable:
if not prompt: if not prompt:
return None return None
latest_history = PromptHistories.get_latest_history_entry( latest_history = PromptHistories.get_latest_history_entry(prompt.id, db=db)
prompt.id, db=db
)
parent_id = latest_history.id if latest_history else None parent_id = latest_history.id if latest_history else None
current_access_grants = self._get_access_grants(prompt.id, db=db) current_access_grants = self._get_access_grants(prompt.id, db=db)
@@ -401,9 +369,7 @@ class PromptsTable:
prompt.meta = form_data.meta or prompt.meta prompt.meta = form_data.meta or prompt.meta
prompt.updated_at = int(time.time()) prompt.updated_at = int(time.time())
if form_data.access_grants is not None: if form_data.access_grants is not None:
AccessGrants.set_access_grants( AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=db)
"prompt", prompt.id, form_data.access_grants, db=db
)
current_access_grants = self._get_access_grants(prompt.id, db=db) current_access_grants = self._get_access_grants(prompt.id, db=db)
db.commit() db.commit()
@@ -411,14 +377,12 @@ class PromptsTable:
# Create history entry only if content changed # Create history entry only if content changed
if content_changed: if content_changed:
snapshot = { snapshot = {
"name": form_data.name, 'name': form_data.name,
"content": form_data.content, 'content': form_data.content,
"command": command, 'command': command,
"data": form_data.data or {}, 'data': form_data.data or {},
"meta": form_data.meta or {}, 'meta': form_data.meta or {},
"access_grants": [ 'access_grants': [grant.model_dump() for grant in current_access_grants],
grant.model_dump() for grant in current_access_grants
],
} }
history_entry = PromptHistories.create_history_entry( history_entry = PromptHistories.create_history_entry(
@@ -452,9 +416,7 @@ class PromptsTable:
if not prompt: if not prompt:
return None return None
latest_history = PromptHistories.get_latest_history_entry( latest_history = PromptHistories.get_latest_history_entry(prompt.id, db=db)
prompt.id, db=db
)
parent_id = latest_history.id if latest_history else None parent_id = latest_history.id if latest_history else None
current_access_grants = self._get_access_grants(prompt.id, db=db) current_access_grants = self._get_access_grants(prompt.id, db=db)
@@ -478,9 +440,7 @@ class PromptsTable:
prompt.tags = form_data.tags prompt.tags = form_data.tags
if form_data.access_grants is not None: if form_data.access_grants is not None:
AccessGrants.set_access_grants( AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=db)
"prompt", prompt.id, form_data.access_grants, db=db
)
current_access_grants = self._get_access_grants(prompt.id, db=db) current_access_grants = self._get_access_grants(prompt.id, db=db)
prompt.updated_at = int(time.time()) prompt.updated_at = int(time.time())
@@ -490,15 +450,13 @@ class PromptsTable:
# Create history entry only if content changed # Create history entry only if content changed
if content_changed: if content_changed:
snapshot = { snapshot = {
"name": form_data.name, 'name': form_data.name,
"content": form_data.content, 'content': form_data.content,
"command": prompt.command, 'command': prompt.command,
"data": form_data.data or {}, 'data': form_data.data or {},
"meta": form_data.meta or {}, 'meta': form_data.meta or {},
"tags": prompt.tags or [], 'tags': prompt.tags or [],
"access_grants": [ 'access_grants': [grant.model_dump() for grant in current_access_grants],
grant.model_dump() for grant in current_access_grants
],
} }
history_entry = PromptHistories.create_history_entry( history_entry = PromptHistories.create_history_entry(
@@ -560,9 +518,7 @@ class PromptsTable:
if not prompt: if not prompt:
return None return None
history_entry = PromptHistories.get_history_entry_by_id( history_entry = PromptHistories.get_history_entry_by_id(version_id, db=db)
version_id, db=db
)
if not history_entry: if not history_entry:
return None return None
@@ -570,11 +526,11 @@ class PromptsTable:
# Restore prompt content from the snapshot # Restore prompt content from the snapshot
snapshot = history_entry.snapshot snapshot = history_entry.snapshot
if snapshot: if snapshot:
prompt.name = snapshot.get("name", prompt.name) prompt.name = snapshot.get('name', prompt.name)
prompt.content = snapshot.get("content", prompt.content) prompt.content = snapshot.get('content', prompt.content)
prompt.data = snapshot.get("data", prompt.data) prompt.data = snapshot.get('data', prompt.data)
prompt.meta = snapshot.get("meta", prompt.meta) prompt.meta = snapshot.get('meta', prompt.meta)
prompt.tags = snapshot.get("tags", prompt.tags) prompt.tags = snapshot.get('tags', prompt.tags)
# Note: command and access_grants are not restored from snapshot # Note: command and access_grants are not restored from snapshot
prompt.version_id = version_id prompt.version_id = version_id
@@ -585,9 +541,7 @@ class PromptsTable:
except Exception: except Exception:
return None return None
def toggle_prompt_active( def toggle_prompt_active(self, prompt_id: str, db: Optional[Session] = None) -> Optional[PromptModel]:
self, prompt_id: str, db: Optional[Session] = None
) -> Optional[PromptModel]:
"""Toggle the is_active flag on a prompt.""" """Toggle the is_active flag on a prompt."""
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
@@ -602,16 +556,14 @@ class PromptsTable:
except Exception: except Exception:
return None return None
def delete_prompt_by_command( def delete_prompt_by_command(self, command: str, db: Optional[Session] = None) -> bool:
self, command: str, db: Optional[Session] = None
) -> bool:
"""Permanently delete a prompt and its history.""" """Permanently delete a prompt and its history."""
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
prompt = db.query(Prompt).filter_by(command=command).first() prompt = db.query(Prompt).filter_by(command=command).first()
if prompt: if prompt:
PromptHistories.delete_history_by_prompt_id(prompt.id, db=db) PromptHistories.delete_history_by_prompt_id(prompt.id, db=db)
AccessGrants.revoke_all_access("prompt", prompt.id, db=db) AccessGrants.revoke_all_access('prompt', prompt.id, db=db)
db.delete(prompt) db.delete(prompt)
db.commit() db.commit()
@@ -627,7 +579,7 @@ class PromptsTable:
prompt = db.query(Prompt).filter_by(id=prompt_id).first() prompt = db.query(Prompt).filter_by(id=prompt_id).first()
if prompt: if prompt:
PromptHistories.delete_history_by_prompt_id(prompt.id, db=db) PromptHistories.delete_history_by_prompt_id(prompt.id, db=db)
AccessGrants.revoke_all_access("prompt", prompt.id, db=db) AccessGrants.revoke_all_access('prompt', prompt.id, db=db)
db.delete(prompt) db.delete(prompt)
db.commit() db.commit()

View File

@@ -19,7 +19,7 @@ log = logging.getLogger(__name__)
class Skill(Base): class Skill(Base):
__tablename__ = "skill" __tablename__ = 'skill'
id = Column(String, primary_key=True, unique=True) id = Column(String, primary_key=True, unique=True)
user_id = Column(String) user_id = Column(String)
@@ -77,7 +77,7 @@ class SkillResponse(BaseModel):
class SkillUserResponse(SkillResponse): class SkillUserResponse(SkillResponse):
user: Optional[UserResponse] = None user: Optional[UserResponse] = None
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra='allow')
class SkillAccessResponse(SkillUserResponse): class SkillAccessResponse(SkillUserResponse):
@@ -105,10 +105,8 @@ class SkillAccessListResponse(BaseModel):
class SkillsTable: class SkillsTable:
def _get_access_grants( def _get_access_grants(self, skill_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
self, skill_id: str, db: Optional[Session] = None return AccessGrants.get_grants_by_resource('skill', skill_id, db=db)
) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource("skill", skill_id, db=db)
def _to_skill_model( def _to_skill_model(
self, self,
@@ -116,13 +114,9 @@ class SkillsTable:
access_grants: Optional[list[AccessGrantModel]] = None, access_grants: Optional[list[AccessGrantModel]] = None,
db: Optional[Session] = None, db: Optional[Session] = None,
) -> SkillModel: ) -> SkillModel:
skill_data = SkillModel.model_validate(skill).model_dump( skill_data = SkillModel.model_validate(skill).model_dump(exclude={'access_grants'})
exclude={"access_grants"} skill_data['access_grants'] = (
) access_grants if access_grants is not None else self._get_access_grants(skill_data['id'], db=db)
skill_data["access_grants"] = (
access_grants
if access_grants is not None
else self._get_access_grants(skill_data["id"], db=db)
) )
return SkillModel.model_validate(skill_data) return SkillModel.model_validate(skill_data)
@@ -136,29 +130,25 @@ class SkillsTable:
try: try:
result = Skill( result = Skill(
**{ **{
**form_data.model_dump(exclude={"access_grants"}), **form_data.model_dump(exclude={'access_grants'}),
"user_id": user_id, 'user_id': user_id,
"updated_at": int(time.time()), 'updated_at': int(time.time()),
"created_at": int(time.time()), 'created_at': int(time.time()),
} }
) )
db.add(result) db.add(result)
db.commit() db.commit()
db.refresh(result) db.refresh(result)
AccessGrants.set_access_grants( AccessGrants.set_access_grants('skill', result.id, form_data.access_grants, db=db)
"skill", result.id, form_data.access_grants, db=db
)
if result: if result:
return self._to_skill_model(result, db=db) return self._to_skill_model(result, db=db)
else: else:
return None return None
except Exception as e: except Exception as e:
log.exception(f"Error creating a new skill: {e}") log.exception(f'Error creating a new skill: {e}')
return None return None
def get_skill_by_id( def get_skill_by_id(self, id: str, db: Optional[Session] = None) -> Optional[SkillModel]:
self, id: str, db: Optional[Session] = None
) -> Optional[SkillModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
skill = db.get(Skill, id) skill = db.get(Skill, id)
@@ -166,9 +156,7 @@ class SkillsTable:
except Exception: except Exception:
return None return None
def get_skill_by_name( def get_skill_by_name(self, name: str, db: Optional[Session] = None) -> Optional[SkillModel]:
self, name: str, db: Optional[Session] = None
) -> Optional[SkillModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
skill = db.query(Skill).filter_by(name=name).first() skill = db.query(Skill).filter_by(name=name).first()
@@ -185,7 +173,7 @@ class SkillsTable:
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
users_dict = {user.id: user for user in users} users_dict = {user.id: user for user in users}
grants_map = AccessGrants.get_grants_by_resources("skill", skill_ids, db=db) grants_map = AccessGrants.get_grants_by_resources('skill', skill_ids, db=db)
skills = [] skills = []
for skill in all_skills: for skill in all_skills:
@@ -198,19 +186,17 @@ class SkillsTable:
access_grants=grants_map.get(skill.id, []), access_grants=grants_map.get(skill.id, []),
db=db, db=db,
).model_dump(), ).model_dump(),
"user": user.model_dump() if user else None, 'user': user.model_dump() if user else None,
} }
) )
) )
return skills return skills
def get_skills_by_user_id( def get_skills_by_user_id(
self, user_id: str, permission: str = "write", db: Optional[Session] = None self, user_id: str, permission: str = 'write', db: Optional[Session] = None
) -> list[SkillUserModel]: ) -> list[SkillUserModel]:
skills = self.get_skills(db=db) skills = self.get_skills(db=db)
user_group_ids = { user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
}
return [ return [
skill skill
@@ -218,7 +204,7 @@ class SkillsTable:
if skill.user_id == user_id if skill.user_id == user_id
or AccessGrants.has_access( or AccessGrants.has_access(
user_id=user_id, user_id=user_id,
resource_type="skill", resource_type='skill',
resource_id=skill.id, resource_id=skill.id,
permission=permission, permission=permission,
user_group_ids=user_group_ids, user_group_ids=user_group_ids,
@@ -242,22 +228,22 @@ class SkillsTable:
query = db.query(Skill, User).outerjoin(User, User.id == Skill.user_id) query = db.query(Skill, User).outerjoin(User, User.id == Skill.user_id)
if filter: if filter:
query_key = filter.get("query") query_key = filter.get('query')
if query_key: if query_key:
query = query.filter( query = query.filter(
or_( or_(
Skill.name.ilike(f"%{query_key}%"), Skill.name.ilike(f'%{query_key}%'),
Skill.description.ilike(f"%{query_key}%"), Skill.description.ilike(f'%{query_key}%'),
Skill.id.ilike(f"%{query_key}%"), Skill.id.ilike(f'%{query_key}%'),
User.name.ilike(f"%{query_key}%"), User.name.ilike(f'%{query_key}%'),
User.email.ilike(f"%{query_key}%"), User.email.ilike(f'%{query_key}%'),
) )
) )
view_option = filter.get("view_option") view_option = filter.get('view_option')
if view_option == "created": if view_option == 'created':
query = query.filter(Skill.user_id == user_id) query = query.filter(Skill.user_id == user_id)
elif view_option == "shared": elif view_option == 'shared':
query = query.filter(Skill.user_id != user_id) query = query.filter(Skill.user_id != user_id)
# Apply access grant filtering # Apply access grant filtering
@@ -266,8 +252,8 @@ class SkillsTable:
query=query, query=query,
DocumentModel=Skill, DocumentModel=Skill,
filter=filter, filter=filter,
resource_type="skill", resource_type='skill',
permission="read", permission='read',
) )
query = query.order_by(Skill.updated_at.desc()) query = query.order_by(Skill.updated_at.desc())
@@ -283,9 +269,7 @@ class SkillsTable:
items = query.all() items = query.all()
skill_ids = [skill.id for skill, _ in items] skill_ids = [skill.id for skill, _ in items]
grants_map = AccessGrants.get_grants_by_resources( grants_map = AccessGrants.get_grants_by_resources('skill', skill_ids, db=db)
"skill", skill_ids, db=db
)
skills = [] skills = []
for skill, user in items: for skill, user in items:
@@ -296,33 +280,23 @@ class SkillsTable:
access_grants=grants_map.get(skill.id, []), access_grants=grants_map.get(skill.id, []),
db=db, db=db,
).model_dump(), ).model_dump(),
user=( user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
UserResponse(
**UserModel.model_validate(user).model_dump()
)
if user
else None
),
) )
) )
return SkillListResponse(items=skills, total=total) return SkillListResponse(items=skills, total=total)
except Exception as e: except Exception as e:
log.exception(f"Error searching skills: {e}") log.exception(f'Error searching skills: {e}')
return SkillListResponse(items=[], total=0) return SkillListResponse(items=[], total=0)
def update_skill_by_id( def update_skill_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[SkillModel]:
self, id: str, updated: dict, db: Optional[Session] = None
) -> Optional[SkillModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
access_grants = updated.pop("access_grants", None) access_grants = updated.pop('access_grants', None)
db.query(Skill).filter_by(id=id).update( db.query(Skill).filter_by(id=id).update({**updated, 'updated_at': int(time.time())})
{**updated, "updated_at": int(time.time())}
)
db.commit() db.commit()
if access_grants is not None: if access_grants is not None:
AccessGrants.set_access_grants("skill", id, access_grants, db=db) AccessGrants.set_access_grants('skill', id, access_grants, db=db)
skill = db.query(Skill).get(id) skill = db.query(Skill).get(id)
db.refresh(skill) db.refresh(skill)
@@ -330,9 +304,7 @@ class SkillsTable:
except Exception: except Exception:
return None return None
def toggle_skill_by_id( def toggle_skill_by_id(self, id: str, db: Optional[Session] = None) -> Optional[SkillModel]:
self, id: str, db: Optional[Session] = None
) -> Optional[SkillModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
try: try:
skill = db.query(Skill).filter_by(id=id).first() skill = db.query(Skill).filter_by(id=id).first()
@@ -351,7 +323,7 @@ class SkillsTable:
def delete_skill_by_id(self, id: str, db: Optional[Session] = None) -> bool: def delete_skill_by_id(self, id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
AccessGrants.revoke_all_access("skill", id, db=db) AccessGrants.revoke_all_access('skill', id, db=db)
db.query(Skill).filter_by(id=id).delete() db.query(Skill).filter_by(id=id).delete()
db.commit() db.commit()

View File

@@ -17,19 +17,19 @@ log = logging.getLogger(__name__)
# Tag DB Schema # Tag DB Schema
#################### ####################
class Tag(Base): class Tag(Base):
__tablename__ = "tag" __tablename__ = 'tag'
id = Column(String) id = Column(String)
name = Column(String) name = Column(String)
user_id = Column(String) user_id = Column(String)
meta = Column(JSON, nullable=True) meta = Column(JSON, nullable=True)
__table_args__ = ( __table_args__ = (
PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"), PrimaryKeyConstraint('id', 'user_id', name='pk_id_user_id'),
Index("user_id_idx", "user_id"), Index('user_id_idx', 'user_id'),
) )
# Unique constraint ensuring (id, user_id) is unique, not just the `id` column # Unique constraint ensuring (id, user_id) is unique, not just the `id` column
__table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),) __table_args__ = (PrimaryKeyConstraint('id', 'user_id', name='pk_id_user_id'),)
class TagModel(BaseModel): class TagModel(BaseModel):
@@ -51,12 +51,10 @@ class TagChatIdForm(BaseModel):
class TagTable: class TagTable:
def insert_new_tag( def insert_new_tag(self, name: str, user_id: str, db: Optional[Session] = None) -> Optional[TagModel]:
self, name: str, user_id: str, db: Optional[Session] = None
) -> Optional[TagModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
id = name.replace(" ", "_").lower() id = name.replace(' ', '_').lower()
tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) tag = TagModel(**{'id': id, 'user_id': user_id, 'name': name})
try: try:
result = Tag(**tag.model_dump()) result = Tag(**tag.model_dump())
db.add(result) db.add(result)
@@ -67,89 +65,63 @@ class TagTable:
else: else:
return None return None
except Exception as e: except Exception as e:
log.exception(f"Error inserting a new tag: {e}") log.exception(f'Error inserting a new tag: {e}')
return None return None
def get_tag_by_name_and_user_id( def get_tag_by_name_and_user_id(self, name: str, user_id: str, db: Optional[Session] = None) -> Optional[TagModel]:
self, name: str, user_id: str, db: Optional[Session] = None
) -> Optional[TagModel]:
try: try:
id = name.replace(" ", "_").lower() id = name.replace(' ', '_').lower()
with get_db_context(db) as db: with get_db_context(db) as db:
tag = db.query(Tag).filter_by(id=id, user_id=user_id).first() tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
return TagModel.model_validate(tag) return TagModel.model_validate(tag)
except Exception: except Exception:
return None return None
def get_tags_by_user_id( def get_tags_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[TagModel]:
self, user_id: str, db: Optional[Session] = None with get_db_context(db) as db:
) -> list[TagModel]: return [TagModel.model_validate(tag) for tag in (db.query(Tag).filter_by(user_id=user_id).all())]
def get_tags_by_ids_and_user_id(self, ids: list[str], user_id: str, db: Optional[Session] = None) -> list[TagModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
return [ return [
TagModel.model_validate(tag) TagModel.model_validate(tag)
for tag in (db.query(Tag).filter_by(user_id=user_id).all()) for tag in (db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all())
] ]
def get_tags_by_ids_and_user_id( def delete_tag_by_name_and_user_id(self, name: str, user_id: str, db: Optional[Session] = None) -> bool:
self, ids: list[str], user_id: str, db: Optional[Session] = None
) -> list[TagModel]:
with get_db_context(db) as db:
return [
TagModel.model_validate(tag)
for tag in (
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()
)
]
def delete_tag_by_name_and_user_id(
self, name: str, user_id: str, db: Optional[Session] = None
) -> bool:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
id = name.replace(" ", "_").lower() id = name.replace(' ', '_').lower()
res = db.query(Tag).filter_by(id=id, user_id=user_id).delete() res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
log.debug(f"res: {res}") log.debug(f'res: {res}')
db.commit() db.commit()
return True return True
except Exception as e: except Exception as e:
log.error(f"delete_tag: {e}") log.error(f'delete_tag: {e}')
return False return False
def delete_tags_by_ids_and_user_id( def delete_tags_by_ids_and_user_id(self, ids: list[str], user_id: str, db: Optional[Session] = None) -> bool:
self, ids: list[str], user_id: str, db: Optional[Session] = None
) -> bool:
"""Delete all tags whose id is in *ids* for the given user, in one query.""" """Delete all tags whose id is in *ids* for the given user, in one query."""
if not ids: if not ids:
return True return True
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).delete( db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).delete(synchronize_session=False)
synchronize_session=False
)
db.commit() db.commit()
return True return True
except Exception as e: except Exception as e:
log.error(f"delete_tags_by_ids: {e}") log.error(f'delete_tags_by_ids: {e}')
return False return False
def ensure_tags_exist( def ensure_tags_exist(self, names: list[str], user_id: str, db: Optional[Session] = None) -> None:
self, names: list[str], user_id: str, db: Optional[Session] = None
) -> None:
"""Create tag rows for any *names* that don't already exist for *user_id*.""" """Create tag rows for any *names* that don't already exist for *user_id*."""
if not names: if not names:
return return
ids = [n.replace(" ", "_").lower() for n in names] ids = [n.replace(' ', '_').lower() for n in names]
with get_db_context(db) as db: with get_db_context(db) as db:
existing = { existing = {t.id for t in db.query(Tag.id).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()}
t.id
for t in db.query(Tag.id)
.filter(Tag.id.in_(ids), Tag.user_id == user_id)
.all()
}
new_tags = [ new_tags = [
Tag(id=tag_id, name=name, user_id=user_id) Tag(id=tag_id, name=name, user_id=user_id) for tag_id, name in zip(ids, names) if tag_id not in existing
for tag_id, name in zip(ids, names)
if tag_id not in existing
] ]
if new_tags: if new_tags:
db.add_all(new_tags) db.add_all(new_tags)

View File

@@ -19,7 +19,7 @@ log = logging.getLogger(__name__)
class Tool(Base): class Tool(Base):
__tablename__ = "tool" __tablename__ = 'tool'
id = Column(String, primary_key=True, unique=True) id = Column(String, primary_key=True, unique=True)
user_id = Column(String) user_id = Column(String)
@@ -75,7 +75,7 @@ class ToolResponse(BaseModel):
class ToolUserResponse(ToolResponse): class ToolUserResponse(ToolResponse):
user: Optional[UserResponse] = None user: Optional[UserResponse] = None
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra='allow')
class ToolAccessResponse(ToolUserResponse): class ToolAccessResponse(ToolUserResponse):
@@ -95,10 +95,8 @@ class ToolValves(BaseModel):
class ToolsTable: class ToolsTable:
def _get_access_grants( def _get_access_grants(self, tool_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
self, tool_id: str, db: Optional[Session] = None return AccessGrants.get_grants_by_resource('tool', tool_id, db=db)
) -> list[AccessGrantModel]:
return AccessGrants.get_grants_by_resource("tool", tool_id, db=db)
def _to_tool_model( def _to_tool_model(
self, self,
@@ -106,11 +104,9 @@ class ToolsTable:
access_grants: Optional[list[AccessGrantModel]] = None, access_grants: Optional[list[AccessGrantModel]] = None,
db: Optional[Session] = None, db: Optional[Session] = None,
) -> ToolModel: ) -> ToolModel:
tool_data = ToolModel.model_validate(tool).model_dump(exclude={"access_grants"}) tool_data = ToolModel.model_validate(tool).model_dump(exclude={'access_grants'})
tool_data["access_grants"] = ( tool_data['access_grants'] = (
access_grants access_grants if access_grants is not None else self._get_access_grants(tool_data['id'], db=db)
if access_grants is not None
else self._get_access_grants(tool_data["id"], db=db)
) )
return ToolModel.model_validate(tool_data) return ToolModel.model_validate(tool_data)
@@ -125,30 +121,26 @@ class ToolsTable:
try: try:
result = Tool( result = Tool(
**{ **{
**form_data.model_dump(exclude={"access_grants"}), **form_data.model_dump(exclude={'access_grants'}),
"specs": specs, 'specs': specs,
"user_id": user_id, 'user_id': user_id,
"updated_at": int(time.time()), 'updated_at': int(time.time()),
"created_at": int(time.time()), 'created_at': int(time.time()),
} }
) )
db.add(result) db.add(result)
db.commit() db.commit()
db.refresh(result) db.refresh(result)
AccessGrants.set_access_grants( AccessGrants.set_access_grants('tool', result.id, form_data.access_grants, db=db)
"tool", result.id, form_data.access_grants, db=db
)
if result: if result:
return self._to_tool_model(result, db=db) return self._to_tool_model(result, db=db)
else: else:
return None return None
except Exception as e: except Exception as e:
log.exception(f"Error creating a new tool: {e}") log.exception(f'Error creating a new tool: {e}')
return None return None
def get_tool_by_id( def get_tool_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ToolModel]:
self, id: str, db: Optional[Session] = None
) -> Optional[ToolModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
tool = db.get(Tool, id) tool = db.get(Tool, id)
@@ -156,9 +148,7 @@ class ToolsTable:
except Exception: except Exception:
return None return None
def get_tools( def get_tools(self, defer_content: bool = False, db: Optional[Session] = None) -> list[ToolUserModel]:
self, defer_content: bool = False, db: Optional[Session] = None
) -> list[ToolUserModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
query = db.query(Tool).order_by(Tool.updated_at.desc()) query = db.query(Tool).order_by(Tool.updated_at.desc())
if defer_content: if defer_content:
@@ -170,7 +160,7 @@ class ToolsTable:
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
users_dict = {user.id: user for user in users} users_dict = {user.id: user for user in users}
grants_map = AccessGrants.get_grants_by_resources("tool", tool_ids, db=db) grants_map = AccessGrants.get_grants_by_resources('tool', tool_ids, db=db)
tools = [] tools = []
for tool in all_tools: for tool in all_tools:
@@ -183,7 +173,7 @@ class ToolsTable:
access_grants=grants_map.get(tool.id, []), access_grants=grants_map.get(tool.id, []),
db=db, db=db,
).model_dump(), ).model_dump(),
"user": user.model_dump() if user else None, 'user': user.model_dump() if user else None,
} }
) )
) )
@@ -192,14 +182,12 @@ class ToolsTable:
def get_tools_by_user_id( def get_tools_by_user_id(
self, self,
user_id: str, user_id: str,
permission: str = "write", permission: str = 'write',
defer_content: bool = False, defer_content: bool = False,
db: Optional[Session] = None, db: Optional[Session] = None,
) -> list[ToolUserModel]: ) -> list[ToolUserModel]:
tools = self.get_tools(defer_content=defer_content, db=db) tools = self.get_tools(defer_content=defer_content, db=db)
user_group_ids = { user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
}
return [ return [
tool tool
@@ -207,7 +195,7 @@ class ToolsTable:
if tool.user_id == user_id if tool.user_id == user_id
or AccessGrants.has_access( or AccessGrants.has_access(
user_id=user_id, user_id=user_id,
resource_type="tool", resource_type='tool',
resource_id=tool.id, resource_id=tool.id,
permission=permission, permission=permission,
user_group_ids=user_group_ids, user_group_ids=user_group_ids,
@@ -215,48 +203,38 @@ class ToolsTable:
) )
] ]
def get_tool_valves_by_id( def get_tool_valves_by_id(self, id: str, db: Optional[Session] = None) -> Optional[dict]:
self, id: str, db: Optional[Session] = None
) -> Optional[dict]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
tool = db.get(Tool, id) tool = db.get(Tool, id)
return tool.valves if tool.valves else {} return tool.valves if tool.valves else {}
except Exception as e: except Exception as e:
log.exception(f"Error getting tool valves by id {id}") log.exception(f'Error getting tool valves by id {id}')
return None return None
def update_tool_valves_by_id( def update_tool_valves_by_id(self, id: str, valves: dict, db: Optional[Session] = None) -> Optional[ToolValves]:
self, id: str, valves: dict, db: Optional[Session] = None
) -> Optional[ToolValves]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
db.query(Tool).filter_by(id=id).update( db.query(Tool).filter_by(id=id).update({'valves': valves, 'updated_at': int(time.time())})
{"valves": valves, "updated_at": int(time.time())}
)
db.commit() db.commit()
return self.get_tool_by_id(id, db=db) return self.get_tool_by_id(id, db=db)
except Exception: except Exception:
return None return None
def get_user_valves_by_id_and_user_id( def get_user_valves_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[dict]:
self, id: str, user_id: str, db: Optional[Session] = None
) -> Optional[dict]:
try: try:
user = Users.get_user_by_id(user_id, db=db) user = Users.get_user_by_id(user_id, db=db)
user_settings = user.settings.model_dump() if user.settings else {} user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "tools" and "valves" settings # Check if user has "tools" and "valves" settings
if "tools" not in user_settings: if 'tools' not in user_settings:
user_settings["tools"] = {} user_settings['tools'] = {}
if "valves" not in user_settings["tools"]: if 'valves' not in user_settings['tools']:
user_settings["tools"]["valves"] = {} user_settings['tools']['valves'] = {}
return user_settings["tools"]["valves"].get(id, {}) return user_settings['tools']['valves'].get(id, {})
except Exception as e: except Exception as e:
log.exception( log.exception(f'Error getting user values by id {id} and user_id {user_id}: {e}')
f"Error getting user values by id {id} and user_id {user_id}: {e}"
)
return None return None
def update_user_valves_by_id_and_user_id( def update_user_valves_by_id_and_user_id(
@@ -267,35 +245,29 @@ class ToolsTable:
user_settings = user.settings.model_dump() if user.settings else {} user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "tools" and "valves" settings # Check if user has "tools" and "valves" settings
if "tools" not in user_settings: if 'tools' not in user_settings:
user_settings["tools"] = {} user_settings['tools'] = {}
if "valves" not in user_settings["tools"]: if 'valves' not in user_settings['tools']:
user_settings["tools"]["valves"] = {} user_settings['tools']['valves'] = {}
user_settings["tools"]["valves"][id] = valves user_settings['tools']['valves'][id] = valves
# Update the user settings in the database # Update the user settings in the database
Users.update_user_by_id(user_id, {"settings": user_settings}, db=db) Users.update_user_by_id(user_id, {'settings': user_settings}, db=db)
return user_settings["tools"]["valves"][id] return user_settings['tools']['valves'][id]
except Exception as e: except Exception as e:
log.exception( log.exception(f'Error updating user valves by id {id} and user_id {user_id}: {e}')
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
)
return None return None
def update_tool_by_id( def update_tool_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[ToolModel]:
self, id: str, updated: dict, db: Optional[Session] = None
) -> Optional[ToolModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
access_grants = updated.pop("access_grants", None) access_grants = updated.pop('access_grants', None)
db.query(Tool).filter_by(id=id).update( db.query(Tool).filter_by(id=id).update({**updated, 'updated_at': int(time.time())})
{**updated, "updated_at": int(time.time())}
)
db.commit() db.commit()
if access_grants is not None: if access_grants is not None:
AccessGrants.set_access_grants("tool", id, access_grants, db=db) AccessGrants.set_access_grants('tool', id, access_grants, db=db)
tool = db.query(Tool).get(id) tool = db.query(Tool).get(id)
db.refresh(tool) db.refresh(tool)
@@ -306,7 +278,7 @@ class ToolsTable:
def delete_tool_by_id(self, id: str, db: Optional[Session] = None) -> bool: def delete_tool_by_id(self, id: str, db: Optional[Session] = None) -> bool:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
AccessGrants.revoke_all_access("tool", id, db=db) AccessGrants.revoke_all_access('tool', id, db=db)
db.query(Tool).filter_by(id=id).delete() db.query(Tool).filter_by(id=id).delete()
db.commit() db.commit()

View File

@@ -40,12 +40,12 @@ import datetime
class UserSettings(BaseModel): class UserSettings(BaseModel):
ui: Optional[dict] = {} ui: Optional[dict] = {}
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra='allow')
pass pass
class User(Base): class User(Base):
__tablename__ = "user" __tablename__ = 'user'
id = Column(String, primary_key=True, unique=True) id = Column(String, primary_key=True, unique=True)
email = Column(String) email = Column(String)
@@ -83,7 +83,7 @@ class UserModel(BaseModel):
email: str email: str
username: Optional[str] = None username: Optional[str] = None
role: str = "pending" role: str = 'pending'
name: str name: str
@@ -112,10 +112,10 @@ class UserModel(BaseModel):
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@model_validator(mode="after") @model_validator(mode='after')
def set_profile_image_url(self): def set_profile_image_url(self):
if not self.profile_image_url: if not self.profile_image_url:
self.profile_image_url = f"/api/v1/users/{self.id}/profile/image" self.profile_image_url = f'/api/v1/users/{self.id}/profile/image'
return self return self
@@ -126,7 +126,7 @@ class UserStatusModel(UserModel):
class ApiKey(Base): class ApiKey(Base):
__tablename__ = "api_key" __tablename__ = 'api_key'
id = Column(Text, primary_key=True, unique=True) id = Column(Text, primary_key=True, unique=True)
user_id = Column(Text, nullable=False) user_id = Column(Text, nullable=False)
@@ -163,7 +163,7 @@ class UpdateProfileForm(BaseModel):
gender: Optional[str] = None gender: Optional[str] = None
date_of_birth: Optional[datetime.date] = None date_of_birth: Optional[datetime.date] = None
@field_validator("profile_image_url") @field_validator('profile_image_url')
@classmethod @classmethod
def check_profile_image_url(cls, v: str) -> str: def check_profile_image_url(cls, v: str) -> str:
return validate_profile_image_url(v) return validate_profile_image_url(v)
@@ -174,7 +174,7 @@ class UserGroupIdsModel(UserModel):
class UserModelResponse(UserModel): class UserModelResponse(UserModel):
model_config = ConfigDict(extra="allow") model_config = ConfigDict(extra='allow')
class UserListResponse(BaseModel): class UserListResponse(BaseModel):
@@ -251,7 +251,7 @@ class UserUpdateForm(BaseModel):
profile_image_url: str profile_image_url: str
password: Optional[str] = None password: Optional[str] = None
@field_validator("profile_image_url") @field_validator('profile_image_url')
@classmethod @classmethod
def check_profile_image_url(cls, v: str) -> str: def check_profile_image_url(cls, v: str) -> str:
return validate_profile_image_url(v) return validate_profile_image_url(v)
@@ -263,8 +263,8 @@ class UsersTable:
id: str, id: str,
name: str, name: str,
email: str, email: str,
profile_image_url: str = "/user.png", profile_image_url: str = '/user.png',
role: str = "pending", role: str = 'pending',
username: Optional[str] = None, username: Optional[str] = None,
oauth: Optional[dict] = None, oauth: Optional[dict] = None,
db: Optional[Session] = None, db: Optional[Session] = None,
@@ -272,16 +272,16 @@ class UsersTable:
with get_db_context(db) as db: with get_db_context(db) as db:
user = UserModel( user = UserModel(
**{ **{
"id": id, 'id': id,
"email": email, 'email': email,
"name": name, 'name': name,
"role": role, 'role': role,
"profile_image_url": profile_image_url, 'profile_image_url': profile_image_url,
"last_active_at": int(time.time()), 'last_active_at': int(time.time()),
"created_at": int(time.time()), 'created_at': int(time.time()),
"updated_at": int(time.time()), 'updated_at': int(time.time()),
"username": username, 'username': username,
"oauth": oauth, 'oauth': oauth,
} }
) )
result = User(**user.model_dump()) result = User(**user.model_dump())
@@ -293,9 +293,7 @@ class UsersTable:
else: else:
return None return None
def get_user_by_id( def get_user_by_id(self, id: str, db: Optional[Session] = None) -> Optional[UserModel]:
self, id: str, db: Optional[Session] = None
) -> Optional[UserModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
@@ -303,49 +301,32 @@ class UsersTable:
except Exception: except Exception:
return None return None
def get_user_by_api_key(self, api_key: str, db: Optional[Session] = None) -> Optional[UserModel]:
    """Look up the user that owns ``api_key``.

    ``User`` is joined to ``ApiKey`` on the user id and the first row whose
    key equals ``api_key`` is taken.

    Returns:
        The row validated as a ``UserModel``, or ``None`` when nothing
        matches or when any exception is raised during the lookup.
    """
    try:
        with get_db_context(db) as session:
            owners = session.query(User).join(ApiKey, User.id == ApiKey.user_id)
            match = owners.filter(ApiKey.key == api_key).first()
            if not match:
                return None
            return UserModel.model_validate(match)
    except Exception:
        # Every failure is reported to the caller as "no user".
        return None
def get_user_by_email(self, email: str, db: Optional[Session] = None) -> Optional[UserModel]:
    """Fetch a user by e-mail address, comparing case-insensitively.

    Both the stored column (via ``func.lower``) and the supplied ``email``
    are lower-cased before they are compared.

    Returns:
        The first matching row validated as a ``UserModel``, or ``None``
        when there is no match or any exception is raised.
    """
    try:
        with get_db_context(db) as session:
            same_email = func.lower(User.email) == email.lower()
            row = session.query(User).filter(same_email).first()
            if row:
                return UserModel.model_validate(row)
            return None
    except Exception:
        # Every failure is reported to the caller as "no user".
        return None
def get_user_by_oauth_sub( def get_user_by_oauth_sub(self, provider: str, sub: str, db: Optional[Session] = None) -> Optional[UserModel]:
self, provider: str, sub: str, db: Optional[Session] = None
) -> Optional[UserModel]:
try: try:
with get_db_context(db) as db: # type: Session with get_db_context(db) as db: # type: Session
dialect_name = db.bind.dialect.name dialect_name = db.bind.dialect.name
query = db.query(User) query = db.query(User)
if dialect_name == "sqlite": if dialect_name == 'sqlite':
query = query.filter(User.oauth.contains({provider: {"sub": sub}})) query = query.filter(User.oauth.contains({provider: {'sub': sub}}))
elif dialect_name == "postgresql": elif dialect_name == 'postgresql':
query = query.filter( query = query.filter(User.oauth[provider].cast(JSONB)['sub'].astext == sub)
User.oauth[provider].cast(JSONB)["sub"].astext == sub
)
user = query.first() user = query.first()
return UserModel.model_validate(user) if user else None return UserModel.model_validate(user) if user else None
@@ -361,15 +342,10 @@ class UsersTable:
dialect_name = db.bind.dialect.name dialect_name = db.bind.dialect.name
query = db.query(User) query = db.query(User)
if dialect_name == "sqlite": if dialect_name == 'sqlite':
query = query.filter( query = query.filter(User.scim.contains({provider: {'external_id': external_id}}))
User.scim.contains({provider: {"external_id": external_id}}) elif dialect_name == 'postgresql':
) query = query.filter(User.scim[provider].cast(JSONB)['external_id'].astext == external_id)
elif dialect_name == "postgresql":
query = query.filter(
User.scim[provider].cast(JSONB)["external_id"].astext
== external_id
)
user = query.first() user = query.first()
return UserModel.model_validate(user) if user else None return UserModel.model_validate(user) if user else None
@@ -388,16 +364,16 @@ class UsersTable:
query = db.query(User).options(defer(User.profile_image_url)) query = db.query(User).options(defer(User.profile_image_url))
if filter: if filter:
query_key = filter.get("query") query_key = filter.get('query')
if query_key: if query_key:
query = query.filter( query = query.filter(
or_( or_(
User.name.ilike(f"%{query_key}%"), User.name.ilike(f'%{query_key}%'),
User.email.ilike(f"%{query_key}%"), User.email.ilike(f'%{query_key}%'),
) )
) )
channel_id = filter.get("channel_id") channel_id = filter.get('channel_id')
if channel_id: if channel_id:
query = query.filter( query = query.filter(
exists( exists(
@@ -408,13 +384,13 @@ class UsersTable:
) )
) )
user_ids = filter.get("user_ids") user_ids = filter.get('user_ids')
group_ids = filter.get("group_ids") group_ids = filter.get('group_ids')
if isinstance(user_ids, list) and isinstance(group_ids, list): if isinstance(user_ids, list) and isinstance(group_ids, list):
# If both are empty lists, return no users # If both are empty lists, return no users
if not user_ids and not group_ids: if not user_ids and not group_ids:
return {"users": [], "total": 0} return {'users': [], 'total': 0}
if user_ids: if user_ids:
query = query.filter(User.id.in_(user_ids)) query = query.filter(User.id.in_(user_ids))
@@ -429,21 +405,21 @@ class UsersTable:
) )
) )
roles = filter.get("roles") roles = filter.get('roles')
if roles: if roles:
include_roles = [role for role in roles if not role.startswith("!")] include_roles = [role for role in roles if not role.startswith('!')]
exclude_roles = [role[1:] for role in roles if role.startswith("!")] exclude_roles = [role[1:] for role in roles if role.startswith('!')]
if include_roles: if include_roles:
query = query.filter(User.role.in_(include_roles)) query = query.filter(User.role.in_(include_roles))
if exclude_roles: if exclude_roles:
query = query.filter(~User.role.in_(exclude_roles)) query = query.filter(~User.role.in_(exclude_roles))
order_by = filter.get("order_by") order_by = filter.get('order_by')
direction = filter.get("direction") direction = filter.get('direction')
if order_by and order_by.startswith("group_id:"): if order_by and order_by.startswith('group_id:'):
group_id = order_by.split(":", 1)[1] group_id = order_by.split(':', 1)[1]
# Subquery that checks if the user belongs to the group # Subquery that checks if the user belongs to the group
membership_exists = exists( membership_exists = exists(
@@ -456,42 +432,42 @@ class UsersTable:
# CASE: user in group → 1, user not in group → 0 # CASE: user in group → 1, user not in group → 0
group_sort = case((membership_exists, 1), else_=0) group_sort = case((membership_exists, 1), else_=0)
if direction == "asc": if direction == 'asc':
query = query.order_by(group_sort.asc(), User.name.asc()) query = query.order_by(group_sort.asc(), User.name.asc())
else: else:
query = query.order_by(group_sort.desc(), User.name.asc()) query = query.order_by(group_sort.desc(), User.name.asc())
elif order_by == "name": elif order_by == 'name':
if direction == "asc": if direction == 'asc':
query = query.order_by(User.name.asc()) query = query.order_by(User.name.asc())
else: else:
query = query.order_by(User.name.desc()) query = query.order_by(User.name.desc())
elif order_by == "email": elif order_by == 'email':
if direction == "asc": if direction == 'asc':
query = query.order_by(User.email.asc()) query = query.order_by(User.email.asc())
else: else:
query = query.order_by(User.email.desc()) query = query.order_by(User.email.desc())
elif order_by == "created_at": elif order_by == 'created_at':
if direction == "asc": if direction == 'asc':
query = query.order_by(User.created_at.asc()) query = query.order_by(User.created_at.asc())
else: else:
query = query.order_by(User.created_at.desc()) query = query.order_by(User.created_at.desc())
elif order_by == "last_active_at": elif order_by == 'last_active_at':
if direction == "asc": if direction == 'asc':
query = query.order_by(User.last_active_at.asc()) query = query.order_by(User.last_active_at.asc())
else: else:
query = query.order_by(User.last_active_at.desc()) query = query.order_by(User.last_active_at.desc())
elif order_by == "updated_at": elif order_by == 'updated_at':
if direction == "asc": if direction == 'asc':
query = query.order_by(User.updated_at.asc()) query = query.order_by(User.updated_at.asc())
else: else:
query = query.order_by(User.updated_at.desc()) query = query.order_by(User.updated_at.desc())
elif order_by == "role": elif order_by == 'role':
if direction == "asc": if direction == 'asc':
query = query.order_by(User.role.asc()) query = query.order_by(User.role.asc())
else: else:
query = query.order_by(User.role.desc()) query = query.order_by(User.role.desc())
@@ -510,13 +486,11 @@ class UsersTable:
users = query.all() users = query.all()
return { return {
"users": [UserModel.model_validate(user) for user in users], 'users': [UserModel.model_validate(user) for user in users],
"total": total, 'total': total,
} }
def get_users_by_group_id( def get_users_by_group_id(self, group_id: str, db: Optional[Session] = None) -> list[UserModel]:
self, group_id: str, db: Optional[Session] = None
) -> list[UserModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
users = ( users = (
db.query(User) db.query(User)
@@ -527,16 +501,9 @@ class UsersTable:
) )
return [UserModel.model_validate(user) for user in users] return [UserModel.model_validate(user) for user in users]
def get_users_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> list[UserStatusModel]:
    """Return the users whose id appears in ``user_ids``.

    The query defers the ``profile_image_url`` column, and every returned
    row is validated with ``UserModel.model_validate``.
    """
    # NOTE(review): the annotation promises ``UserStatusModel`` while the rows
    # are validated as ``UserModel`` -- confirm which one is intended.
    with get_db_context(db) as session:
        base_query = session.query(User).options(defer(User.profile_image_url))
        rows = base_query.filter(User.id.in_(user_ids)).all()
        return [UserModel.model_validate(row) for row in rows]
def get_num_users(self, db: Optional[Session] = None) -> Optional[int]: def get_num_users(self, db: Optional[Session] = None) -> Optional[int]:
@@ -555,9 +522,7 @@ class UsersTable:
except Exception: except Exception:
return None return None
def get_user_webhook_url_by_id( def get_user_webhook_url_by_id(self, id: str, db: Optional[Session] = None) -> Optional[str]:
self, id: str, db: Optional[Session] = None
) -> Optional[str]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
@@ -565,11 +530,7 @@ class UsersTable:
if user.settings is None: if user.settings is None:
return None return None
else: else:
return ( return user.settings.get('ui', {}).get('notifications', {}).get('webhook_url', None)
user.settings.get("ui", {})
.get("notifications", {})
.get("webhook_url", None)
)
except Exception: except Exception:
return None return None
@@ -577,14 +538,10 @@ class UsersTable:
with get_db_context(db) as db: with get_db_context(db) as db:
current_timestamp = int(datetime.datetime.now().timestamp()) current_timestamp = int(datetime.datetime.now().timestamp())
today_midnight_timestamp = current_timestamp - (current_timestamp % 86400) today_midnight_timestamp = current_timestamp - (current_timestamp % 86400)
query = db.query(User).filter( query = db.query(User).filter(User.last_active_at > today_midnight_timestamp)
User.last_active_at > today_midnight_timestamp
)
return query.count() return query.count()
def update_user_role_by_id( def update_user_role_by_id(self, id: str, role: str, db: Optional[Session] = None) -> Optional[UserModel]:
self, id: str, role: str, db: Optional[Session] = None
) -> Optional[UserModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
@@ -629,9 +586,7 @@ class UsersTable:
return None return None
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL) @throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
def update_last_active_by_id( def update_last_active_by_id(self, id: str, db: Optional[Session] = None) -> Optional[UserModel]:
self, id: str, db: Optional[Session] = None
) -> Optional[UserModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
@@ -665,10 +620,10 @@ class UsersTable:
oauth = user.oauth or {} oauth = user.oauth or {}
# Update or insert provider entry # Update or insert provider entry
oauth[provider] = {"sub": sub} oauth[provider] = {'sub': sub}
# Persist updated JSON # Persist updated JSON
db.query(User).filter_by(id=id).update({"oauth": oauth}) db.query(User).filter_by(id=id).update({'oauth': oauth})
db.commit() db.commit()
return UserModel.model_validate(user) return UserModel.model_validate(user)
@@ -698,9 +653,9 @@ class UsersTable:
return None return None
scim = user.scim or {} scim = user.scim or {}
scim[provider] = {"external_id": external_id} scim[provider] = {'external_id': external_id}
db.query(User).filter_by(id=id).update({"scim": scim}) db.query(User).filter_by(id=id).update({'scim': scim})
db.commit() db.commit()
return UserModel.model_validate(user) return UserModel.model_validate(user)
@@ -708,9 +663,7 @@ class UsersTable:
except Exception: except Exception:
return None return None
def update_user_by_id( def update_user_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[UserModel]:
self, id: str, updated: dict, db: Optional[Session] = None
) -> Optional[UserModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
@@ -725,9 +678,7 @@ class UsersTable:
print(e) print(e)
return None return None
def update_user_settings_by_id( def update_user_settings_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[UserModel]:
self, id: str, updated: dict, db: Optional[Session] = None
) -> Optional[UserModel]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
@@ -741,7 +692,7 @@ class UsersTable:
user_settings.update(updated) user_settings.update(updated)
db.query(User).filter_by(id=id).update({"settings": user_settings}) db.query(User).filter_by(id=id).update({'settings': user_settings})
db.commit() db.commit()
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
@@ -768,9 +719,7 @@ class UsersTable:
except Exception: except Exception:
return False return False
def get_user_api_key_by_id( def get_user_api_key_by_id(self, id: str, db: Optional[Session] = None) -> Optional[str]:
self, id: str, db: Optional[Session] = None
) -> Optional[str]:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
api_key = db.query(ApiKey).filter_by(user_id=id).first() api_key = db.query(ApiKey).filter_by(user_id=id).first()
@@ -778,9 +727,7 @@ class UsersTable:
except Exception: except Exception:
return None return None
def update_user_api_key_by_id( def update_user_api_key_by_id(self, id: str, api_key: str, db: Optional[Session] = None) -> bool:
self, id: str, api_key: str, db: Optional[Session] = None
) -> bool:
try: try:
with get_db_context(db) as db: with get_db_context(db) as db:
db.query(ApiKey).filter_by(user_id=id).delete() db.query(ApiKey).filter_by(user_id=id).delete()
@@ -788,7 +735,7 @@ class UsersTable:
now = int(time.time()) now = int(time.time())
new_api_key = ApiKey( new_api_key = ApiKey(
id=f"key_{id}", id=f'key_{id}',
user_id=id, user_id=id,
key=api_key, key=api_key,
created_at=now, created_at=now,
@@ -811,16 +758,14 @@ class UsersTable:
except Exception: except Exception:
return False return False
def get_valid_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> list[str]:
    """Return the subset of ``user_ids`` that exist in the user table.

    Args:
        user_ids: Candidate user ids to check.
        db: Optional existing session, passed through to ``get_db_context``.

    Returns:
        The ids from ``user_ids`` for which a ``User`` row exists. An empty
        (or missing) input yields an empty list without querying.
    """
    # Nothing to look up: skip the database round-trip entirely.
    if not user_ids:
        return []

    with get_db_context(db) as session:
        # Select only the id column; the full entity (including every other
        # column on the row) is not needed just to confirm existence.
        rows = session.query(User.id).filter(User.id.in_(user_ids)).all()
        return [user_id for (user_id,) in rows]
def get_super_admin_user(self, db: Optional[Session] = None) -> Optional[UserModel]: def get_super_admin_user(self, db: Optional[Session] = None) -> Optional[UserModel]:
with get_db_context(db) as db: with get_db_context(db) as db:
user = db.query(User).filter_by(role="admin").first() user = db.query(User).filter_by(role='admin').first()
if user: if user:
return UserModel.model_validate(user) return UserModel.model_validate(user)
else: else:
@@ -830,9 +775,7 @@ class UsersTable:
with get_db_context(db) as db: with get_db_context(db) as db:
# Consider user active if last_active_at within the last 3 minutes # Consider user active if last_active_at within the last 3 minutes
three_minutes_ago = int(time.time()) - 180 three_minutes_ago = int(time.time()) - 180
count = ( count = db.query(User).filter(User.last_active_at >= three_minutes_ago).count()
db.query(User).filter(User.last_active_at >= three_minutes_ago).count()
)
return count return count
@staticmethod @staticmethod

View File

@@ -40,78 +40,76 @@ class DatalabMarkerLoader:
self.output_format = output_format self.output_format = output_format
def _get_mime_type(self, filename: str) -> str: def _get_mime_type(self, filename: str) -> str:
ext = filename.rsplit(".", 1)[-1].lower() ext = filename.rsplit('.', 1)[-1].lower()
mime_map = { mime_map = {
"pdf": "application/pdf", 'pdf': 'application/pdf',
"xls": "application/vnd.ms-excel", 'xls': 'application/vnd.ms-excel',
"xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", 'xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
"ods": "application/vnd.oasis.opendocument.spreadsheet", 'ods': 'application/vnd.oasis.opendocument.spreadsheet',
"doc": "application/msword", 'doc': 'application/msword',
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", 'docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
"odt": "application/vnd.oasis.opendocument.text", 'odt': 'application/vnd.oasis.opendocument.text',
"ppt": "application/vnd.ms-powerpoint", 'ppt': 'application/vnd.ms-powerpoint',
"pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation", 'pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
"odp": "application/vnd.oasis.opendocument.presentation", 'odp': 'application/vnd.oasis.opendocument.presentation',
"html": "text/html", 'html': 'text/html',
"epub": "application/epub+zip", 'epub': 'application/epub+zip',
"png": "image/png", 'png': 'image/png',
"jpeg": "image/jpeg", 'jpeg': 'image/jpeg',
"jpg": "image/jpeg", 'jpg': 'image/jpeg',
"webp": "image/webp", 'webp': 'image/webp',
"gif": "image/gif", 'gif': 'image/gif',
"tiff": "image/tiff", 'tiff': 'image/tiff',
} }
return mime_map.get(ext, "application/octet-stream") return mime_map.get(ext, 'application/octet-stream')
def check_marker_request_status(self, request_id: str) -> dict:
    """Fetch the status document for a previously submitted Marker request.

    Issues a GET to ``{api_base_url}/{request_id}`` with the API key in the
    ``X-Api-Key`` header and returns the decoded JSON body.

    Raises:
        HTTPException: 502 Bad Gateway when the response carries an error
            status (``requests.HTTPError``) or its body cannot be decoded
            as JSON (``ValueError``).
    """
    status_url = f'{self.api_base_url}/{request_id}'
    auth_headers = {'X-Api-Key': self.api_key}

    try:
        reply = requests.get(status_url, headers=auth_headers)
        reply.raise_for_status()
        payload = reply.json()
        log.info(f'Marker API status check for request {request_id}: {payload}')
        return payload
    except requests.HTTPError as http_err:
        log.error(f'Error checking Marker request status: {http_err}')
        detail = f'Failed to check Marker request: {http_err}'
        raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=detail)
    except ValueError as json_err:
        log.error(f'Invalid JSON checking Marker request: {json_err}')
        detail = f'Invalid JSON: {json_err}'
        raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=detail)
def load(self) -> List[Document]: def load(self) -> List[Document]:
filename = os.path.basename(self.file_path) filename = os.path.basename(self.file_path)
mime_type = self._get_mime_type(filename) mime_type = self._get_mime_type(filename)
headers = {"X-Api-Key": self.api_key} headers = {'X-Api-Key': self.api_key}
form_data = { form_data = {
"use_llm": str(self.use_llm).lower(), 'use_llm': str(self.use_llm).lower(),
"skip_cache": str(self.skip_cache).lower(), 'skip_cache': str(self.skip_cache).lower(),
"force_ocr": str(self.force_ocr).lower(), 'force_ocr': str(self.force_ocr).lower(),
"paginate": str(self.paginate).lower(), 'paginate': str(self.paginate).lower(),
"strip_existing_ocr": str(self.strip_existing_ocr).lower(), 'strip_existing_ocr': str(self.strip_existing_ocr).lower(),
"disable_image_extraction": str(self.disable_image_extraction).lower(), 'disable_image_extraction': str(self.disable_image_extraction).lower(),
"format_lines": str(self.format_lines).lower(), 'format_lines': str(self.format_lines).lower(),
"output_format": self.output_format, 'output_format': self.output_format,
} }
if self.additional_config and self.additional_config.strip(): if self.additional_config and self.additional_config.strip():
form_data["additional_config"] = self.additional_config form_data['additional_config'] = self.additional_config
log.info( log.info(
f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}" f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}"
) )
try: try:
with open(self.file_path, "rb") as f: with open(self.file_path, 'rb') as f:
files = {"file": (filename, f, mime_type)} files = {'file': (filename, f, mime_type)}
response = requests.post( response = requests.post(
f"{self.api_base_url}", f'{self.api_base_url}',
data=form_data, data=form_data,
files=files, files=files,
headers=headers, headers=headers,
@@ -119,29 +117,25 @@ class DatalabMarkerLoader:
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
except FileNotFoundError: except FileNotFoundError:
raise HTTPException( raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}')
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
)
except requests.HTTPError as e: except requests.HTTPError as e:
raise HTTPException( raise HTTPException(
status.HTTP_400_BAD_REQUEST, status.HTTP_400_BAD_REQUEST,
detail=f"Datalab Marker request failed: {e}", detail=f'Datalab Marker request failed: {e}',
) )
except ValueError as e: except ValueError as e:
raise HTTPException( raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Invalid JSON response: {e}')
status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON response: {e}"
)
except Exception as e: except Exception as e:
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
if not result.get("success"): if not result.get('success'):
raise HTTPException( raise HTTPException(
status.HTTP_400_BAD_REQUEST, status.HTTP_400_BAD_REQUEST,
detail=f"Datalab Marker request failed: {result.get('error', 'Unknown error')}", detail=f'Datalab Marker request failed: {result.get("error", "Unknown error")}',
) )
check_url = result.get("request_check_url") check_url = result.get('request_check_url')
request_id = result.get("request_id") request_id = result.get('request_id')
# Check if this is a direct response (self-hosted) or polling response (DataLab) # Check if this is a direct response (self-hosted) or polling response (DataLab)
if check_url: if check_url:
@@ -154,54 +148,45 @@ class DatalabMarkerLoader:
poll_result = poll_response.json() poll_result = poll_response.json()
except (requests.HTTPError, ValueError) as e: except (requests.HTTPError, ValueError) as e:
raw_body = poll_response.text raw_body = poll_response.text
log.error(f"Polling error: {e}, response body: {raw_body}") log.error(f'Polling error: {e}, response body: {raw_body}')
raise HTTPException( raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Polling failed: {e}')
status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}"
)
status_val = poll_result.get("status") status_val = poll_result.get('status')
success_val = poll_result.get("success") success_val = poll_result.get('success')
if status_val == "complete": if status_val == 'complete':
summary = { summary = {
k: poll_result.get(k) k: poll_result.get(k)
for k in ( for k in (
"status", 'status',
"output_format", 'output_format',
"success", 'success',
"error", 'error',
"page_count", 'page_count',
"total_cost", 'total_cost',
) )
} }
log.info( log.info(f'Marker processing completed successfully: {json.dumps(summary, indent=2)}')
f"Marker processing completed successfully: {json.dumps(summary, indent=2)}"
)
break break
if status_val == "failed" or success_val is False: if status_val == 'failed' or success_val is False:
log.error( log.error(f'Marker poll failed full response: {json.dumps(poll_result, indent=2)}')
f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}" error_msg = poll_result.get('error') or 'Marker returned failure without error message'
)
error_msg = (
poll_result.get("error")
or "Marker returned failure without error message"
)
raise HTTPException( raise HTTPException(
status.HTTP_400_BAD_REQUEST, status.HTTP_400_BAD_REQUEST,
detail=f"Marker processing failed: {error_msg}", detail=f'Marker processing failed: {error_msg}',
) )
else: else:
raise HTTPException( raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT, status.HTTP_504_GATEWAY_TIMEOUT,
detail="Marker processing timed out", detail='Marker processing timed out',
) )
if not poll_result.get("success", False): if not poll_result.get('success', False):
error_msg = poll_result.get("error") or "Unknown processing error" error_msg = poll_result.get('error') or 'Unknown processing error'
raise HTTPException( raise HTTPException(
status.HTTP_400_BAD_REQUEST, status.HTTP_400_BAD_REQUEST,
detail=f"Final processing failed: {error_msg}", detail=f'Final processing failed: {error_msg}',
) )
# DataLab format - content in format-specific fields # DataLab format - content in format-specific fields
@@ -210,69 +195,65 @@ class DatalabMarkerLoader:
final_result = poll_result final_result = poll_result
else: else:
# Self-hosted direct response - content in "output" field # Self-hosted direct response - content in "output" field
if "output" in result: if 'output' in result:
log.info("Self-hosted Marker returned direct response without polling") log.info('Self-hosted Marker returned direct response without polling')
raw_content = result.get("output") raw_content = result.get('output')
final_result = result final_result = result
else: else:
available_fields = ( available_fields = list(result.keys()) if isinstance(result, dict) else 'non-dict response'
list(result.keys())
if isinstance(result, dict)
else "non-dict response"
)
raise HTTPException( raise HTTPException(
status.HTTP_502_BAD_GATEWAY, status.HTTP_502_BAD_GATEWAY,
detail=f"Custom Marker endpoint returned success but no 'output' field found. Available fields: {available_fields}. Expected either 'request_check_url' for polling or 'output' field for direct response.", detail=f"Custom Marker endpoint returned success but no 'output' field found. Available fields: {available_fields}. Expected either 'request_check_url' for polling or 'output' field for direct response.",
) )
if self.output_format.lower() == "json": if self.output_format.lower() == 'json':
full_text = json.dumps(raw_content, indent=2) full_text = json.dumps(raw_content, indent=2)
elif self.output_format.lower() in {"markdown", "html"}: elif self.output_format.lower() in {'markdown', 'html'}:
full_text = str(raw_content).strip() full_text = str(raw_content).strip()
else: else:
raise HTTPException( raise HTTPException(
status.HTTP_400_BAD_REQUEST, status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported output format: {self.output_format}", detail=f'Unsupported output format: {self.output_format}',
) )
if not full_text: if not full_text:
raise HTTPException( raise HTTPException(
status.HTTP_400_BAD_REQUEST, status.HTTP_400_BAD_REQUEST,
detail="Marker returned empty content", detail='Marker returned empty content',
) )
marker_output_dir = os.path.join("/app/backend/data/uploads", "marker_output") marker_output_dir = os.path.join('/app/backend/data/uploads', 'marker_output')
os.makedirs(marker_output_dir, exist_ok=True) os.makedirs(marker_output_dir, exist_ok=True)
file_ext_map = {"markdown": "md", "json": "json", "html": "html"} file_ext_map = {'markdown': 'md', 'json': 'json', 'html': 'html'}
file_ext = file_ext_map.get(self.output_format.lower(), "txt") file_ext = file_ext_map.get(self.output_format.lower(), 'txt')
output_filename = f"{os.path.splitext(filename)[0]}.{file_ext}" output_filename = f'{os.path.splitext(filename)[0]}.{file_ext}'
output_path = os.path.join(marker_output_dir, output_filename) output_path = os.path.join(marker_output_dir, output_filename)
try: try:
with open(output_path, "w", encoding="utf-8") as f: with open(output_path, 'w', encoding='utf-8') as f:
f.write(full_text) f.write(full_text)
log.info(f"Saved Marker output to: {output_path}") log.info(f'Saved Marker output to: {output_path}')
except Exception as e: except Exception as e:
log.warning(f"Failed to write marker output to disk: {e}") log.warning(f'Failed to write marker output to disk: {e}')
metadata = { metadata = {
"source": filename, 'source': filename,
"output_format": final_result.get("output_format", self.output_format), 'output_format': final_result.get('output_format', self.output_format),
"page_count": final_result.get("page_count", 0), 'page_count': final_result.get('page_count', 0),
"processed_with_llm": self.use_llm, 'processed_with_llm': self.use_llm,
"request_id": request_id or "", 'request_id': request_id or '',
} }
images = final_result.get("images", {}) images = final_result.get('images', {})
if images: if images:
metadata["image_count"] = len(images) metadata['image_count'] = len(images)
metadata["images"] = json.dumps(list(images.keys())) metadata['images'] = json.dumps(list(images.keys()))
for k, v in metadata.items(): for k, v in metadata.items():
if isinstance(v, (dict, list)): if isinstance(v, (dict, list)):
metadata[k] = json.dumps(v) metadata[k] = json.dumps(v)
elif v is None: elif v is None:
metadata[k] = "" metadata[k] = ''
return [Document(page_content=full_text, metadata=metadata)] return [Document(page_content=full_text, metadata=metadata)]

View File

@@ -29,18 +29,18 @@ class ExternalDocumentLoader(BaseLoader):
self.user = user self.user = user
def load(self) -> List[Document]: def load(self) -> List[Document]:
with open(self.file_path, "rb") as f: with open(self.file_path, 'rb') as f:
data = f.read() data = f.read()
headers = {} headers = {}
if self.mime_type is not None: if self.mime_type is not None:
headers["Content-Type"] = self.mime_type headers['Content-Type'] = self.mime_type
if self.api_key is not None: if self.api_key is not None:
headers["Authorization"] = f"Bearer {self.api_key}" headers['Authorization'] = f'Bearer {self.api_key}'
try: try:
headers["X-Filename"] = quote(os.path.basename(self.file_path)) headers['X-Filename'] = quote(os.path.basename(self.file_path))
except Exception: except Exception:
pass pass
@@ -48,24 +48,23 @@ class ExternalDocumentLoader(BaseLoader):
headers = include_user_info_headers(headers, self.user) headers = include_user_info_headers(headers, self.user)
url = self.url url = self.url
if url.endswith("/"): if url.endswith('/'):
url = url[:-1] url = url[:-1]
try: try:
response = requests.put(f"{url}/process", data=data, headers=headers) response = requests.put(f'{url}/process', data=data, headers=headers)
except Exception as e: except Exception as e:
log.error(f"Error connecting to endpoint: {e}") log.error(f'Error connecting to endpoint: {e}')
raise Exception(f"Error connecting to endpoint: {e}") raise Exception(f'Error connecting to endpoint: {e}')
if response.ok: if response.ok:
response_data = response.json() response_data = response.json()
if response_data: if response_data:
if isinstance(response_data, dict): if isinstance(response_data, dict):
return [ return [
Document( Document(
page_content=response_data.get("page_content"), page_content=response_data.get('page_content'),
metadata=response_data.get("metadata"), metadata=response_data.get('metadata'),
) )
] ]
elif isinstance(response_data, list): elif isinstance(response_data, list):
@@ -73,17 +72,15 @@ class ExternalDocumentLoader(BaseLoader):
for document in response_data: for document in response_data:
documents.append( documents.append(
Document( Document(
page_content=document.get("page_content"), page_content=document.get('page_content'),
metadata=document.get("metadata"), metadata=document.get('metadata'),
) )
) )
return documents return documents
else: else:
raise Exception("Error loading document: Unable to parse content") raise Exception('Error loading document: Unable to parse content')
else: else:
raise Exception("Error loading document: No content returned") raise Exception('Error loading document: No content returned')
else: else:
raise Exception( raise Exception(f'Error loading document: {response.status_code} {response.text}')
f"Error loading document: {response.status_code} {response.text}"
)

View File

@@ -30,22 +30,22 @@ class ExternalWebLoader(BaseLoader):
response = requests.post( response = requests.post(
self.external_url, self.external_url,
headers={ headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) External Web Loader", 'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) External Web Loader',
"Authorization": f"Bearer {self.external_api_key}", 'Authorization': f'Bearer {self.external_api_key}',
}, },
json={ json={
"urls": urls, 'urls': urls,
}, },
) )
response.raise_for_status() response.raise_for_status()
results = response.json() results = response.json()
for result in results: for result in results:
yield Document( yield Document(
page_content=result.get("page_content", ""), page_content=result.get('page_content', ''),
metadata=result.get("metadata", {}), metadata=result.get('metadata', {}),
) )
except Exception as e: except Exception as e:
if self.continue_on_failure: if self.continue_on_failure:
log.error(f"Error extracting content from batch {urls}: {e}") log.error(f'Error extracting content from batch {urls}: {e}')
else: else:
raise e raise e

View File

@@ -30,59 +30,59 @@ logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
known_source_ext = [ known_source_ext = [
"go", 'go',
"py", 'py',
"java", 'java',
"sh", 'sh',
"bat", 'bat',
"ps1", 'ps1',
"cmd", 'cmd',
"js", 'js',
"ts", 'ts',
"css", 'css',
"cpp", 'cpp',
"hpp", 'hpp',
"h", 'h',
"c", 'c',
"cs", 'cs',
"sql", 'sql',
"log", 'log',
"ini", 'ini',
"pl", 'pl',
"pm", 'pm',
"r", 'r',
"dart", 'dart',
"dockerfile", 'dockerfile',
"env", 'env',
"php", 'php',
"hs", 'hs',
"hsc", 'hsc',
"lua", 'lua',
"nginxconf", 'nginxconf',
"conf", 'conf',
"m", 'm',
"mm", 'mm',
"plsql", 'plsql',
"perl", 'perl',
"rb", 'rb',
"rs", 'rs',
"db2", 'db2',
"scala", 'scala',
"bash", 'bash',
"swift", 'swift',
"vue", 'vue',
"svelte", 'svelte',
"ex", 'ex',
"exs", 'exs',
"erl", 'erl',
"tsx", 'tsx',
"jsx", 'jsx',
"hs", 'hs',
"lhs", 'lhs',
"json", 'json',
"yaml", 'yaml',
"yml", 'yml',
"toml", 'toml',
] ]
@@ -99,11 +99,11 @@ class ExcelLoader:
xls = pd.ExcelFile(self.file_path) xls = pd.ExcelFile(self.file_path)
for sheet_name in xls.sheet_names: for sheet_name in xls.sheet_names:
df = pd.read_excel(xls, sheet_name=sheet_name) df = pd.read_excel(xls, sheet_name=sheet_name)
text_parts.append(f"Sheet: {sheet_name}\n{df.to_string(index=False)}") text_parts.append(f'Sheet: {sheet_name}\n{df.to_string(index=False)}')
return [ return [
Document( Document(
page_content="\n\n".join(text_parts), page_content='\n\n'.join(text_parts),
metadata={"source": self.file_path}, metadata={'source': self.file_path},
) )
] ]
@@ -125,11 +125,11 @@ class PptxLoader:
if shape.has_text_frame: if shape.has_text_frame:
slide_texts.append(shape.text_frame.text) slide_texts.append(shape.text_frame.text)
if slide_texts: if slide_texts:
text_parts.append(f"Slide {i}:\n" + "\n".join(slide_texts)) text_parts.append(f'Slide {i}:\n' + '\n'.join(slide_texts))
return [ return [
Document( Document(
page_content="\n\n".join(text_parts), page_content='\n\n'.join(text_parts),
metadata={"source": self.file_path}, metadata={'source': self.file_path},
) )
] ]
@@ -143,41 +143,41 @@ class TikaLoader:
self.extract_images = extract_images self.extract_images = extract_images
def load(self) -> list[Document]: def load(self) -> list[Document]:
with open(self.file_path, "rb") as f: with open(self.file_path, 'rb') as f:
data = f.read() data = f.read()
if self.mime_type is not None: if self.mime_type is not None:
headers = {"Content-Type": self.mime_type} headers = {'Content-Type': self.mime_type}
else: else:
headers = {} headers = {}
if self.extract_images == True: if self.extract_images == True:
headers["X-Tika-PDFextractInlineImages"] = "true" headers['X-Tika-PDFextractInlineImages'] = 'true'
endpoint = self.url endpoint = self.url
if not endpoint.endswith("/"): if not endpoint.endswith('/'):
endpoint += "/" endpoint += '/'
endpoint += "tika/text" endpoint += 'tika/text'
r = requests.put(endpoint, data=data, headers=headers, verify=REQUESTS_VERIFY) r = requests.put(endpoint, data=data, headers=headers, verify=REQUESTS_VERIFY)
if r.ok: if r.ok:
raw_metadata = r.json() raw_metadata = r.json()
text = raw_metadata.get("X-TIKA:content", "<No text content found>").strip() text = raw_metadata.get('X-TIKA:content', '<No text content found>').strip()
if "Content-Type" in raw_metadata: if 'Content-Type' in raw_metadata:
headers["Content-Type"] = raw_metadata["Content-Type"] headers['Content-Type'] = raw_metadata['Content-Type']
log.debug("Tika extracted text: %s", text) log.debug('Tika extracted text: %s', text)
return [Document(page_content=text, metadata=headers)] return [Document(page_content=text, metadata=headers)]
else: else:
raise Exception(f"Error calling Tika: {r.reason}") raise Exception(f'Error calling Tika: {r.reason}')
class DoclingLoader: class DoclingLoader:
def __init__(self, url, api_key=None, file_path=None, mime_type=None, params=None): def __init__(self, url, api_key=None, file_path=None, mime_type=None, params=None):
self.url = url.rstrip("/") self.url = url.rstrip('/')
self.api_key = api_key self.api_key = api_key
self.file_path = file_path self.file_path = file_path
self.mime_type = mime_type self.mime_type = mime_type
@@ -185,199 +185,183 @@ class DoclingLoader:
self.params = params or {} self.params = params or {}
def load(self) -> list[Document]: def load(self) -> list[Document]:
with open(self.file_path, "rb") as f: with open(self.file_path, 'rb') as f:
headers = {} headers = {}
if self.api_key: if self.api_key:
headers["X-Api-Key"] = f"{self.api_key}" headers['X-Api-Key'] = f'{self.api_key}'
r = requests.post( r = requests.post(
f"{self.url}/v1/convert/file", f'{self.url}/v1/convert/file',
files={ files={
"files": ( 'files': (
self.file_path, self.file_path,
f, f,
self.mime_type or "application/octet-stream", self.mime_type or 'application/octet-stream',
) )
}, },
data={ data={
"image_export_mode": "placeholder", 'image_export_mode': 'placeholder',
**self.params, **self.params,
}, },
headers=headers, headers=headers,
) )
if r.ok: if r.ok:
result = r.json() result = r.json()
document_data = result.get("document", {}) document_data = result.get('document', {})
text = document_data.get("md_content", "<No text content found>") text = document_data.get('md_content', '<No text content found>')
metadata = {"Content-Type": self.mime_type} if self.mime_type else {} metadata = {'Content-Type': self.mime_type} if self.mime_type else {}
log.debug("Docling extracted text: %s", text) log.debug('Docling extracted text: %s', text)
return [Document(page_content=text, metadata=metadata)] return [Document(page_content=text, metadata=metadata)]
else: else:
error_msg = f"Error calling Docling API: {r.reason}" error_msg = f'Error calling Docling API: {r.reason}'
if r.text: if r.text:
try: try:
error_data = r.json() error_data = r.json()
if "detail" in error_data: if 'detail' in error_data:
error_msg += f" - {error_data['detail']}" error_msg += f' - {error_data["detail"]}'
except Exception: except Exception:
error_msg += f" - {r.text}" error_msg += f' - {r.text}'
raise Exception(f"Error calling Docling: {error_msg}") raise Exception(f'Error calling Docling: {error_msg}')
class Loader: class Loader:
def __init__(self, engine: str = "", **kwargs): def __init__(self, engine: str = '', **kwargs):
self.engine = engine self.engine = engine
self.user = kwargs.get("user", None) self.user = kwargs.get('user', None)
self.kwargs = kwargs self.kwargs = kwargs
def load( def load(self, filename: str, file_content_type: str, file_path: str) -> list[Document]:
self, filename: str, file_content_type: str, file_path: str
) -> list[Document]:
loader = self._get_loader(filename, file_content_type, file_path) loader = self._get_loader(filename, file_content_type, file_path)
docs = loader.load() docs = loader.load()
return [ return [Document(page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata) for doc in docs]
Document(
page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata
)
for doc in docs
]
def _is_text_file(self, file_ext: str, file_content_type: str) -> bool: def _is_text_file(self, file_ext: str, file_content_type: str) -> bool:
return file_ext in known_source_ext or ( return file_ext in known_source_ext or (
file_content_type file_content_type
and file_content_type.find("text/") >= 0 and file_content_type.find('text/') >= 0
# Avoid text/html files being detected as text # Avoid text/html files being detected as text
and not file_content_type.find("html") >= 0 and not file_content_type.find('html') >= 0
) )
def _get_loader(self, filename: str, file_content_type: str, file_path: str): def _get_loader(self, filename: str, file_content_type: str, file_path: str):
file_ext = filename.split(".")[-1].lower() file_ext = filename.split('.')[-1].lower()
if ( if (
self.engine == "external" self.engine == 'external'
and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL") and self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_URL')
and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY") and self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_API_KEY')
): ):
loader = ExternalDocumentLoader( loader = ExternalDocumentLoader(
file_path=file_path, file_path=file_path,
url=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL"), url=self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_URL'),
api_key=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY"), api_key=self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_API_KEY'),
mime_type=file_content_type, mime_type=file_content_type,
user=self.user, user=self.user,
) )
elif self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"): elif self.engine == 'tika' and self.kwargs.get('TIKA_SERVER_URL'):
if self._is_text_file(file_ext, file_content_type): if self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True) loader = TextLoader(file_path, autodetect_encoding=True)
else: else:
loader = TikaLoader( loader = TikaLoader(
url=self.kwargs.get("TIKA_SERVER_URL"), url=self.kwargs.get('TIKA_SERVER_URL'),
file_path=file_path, file_path=file_path,
extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"), extract_images=self.kwargs.get('PDF_EXTRACT_IMAGES'),
) )
elif ( elif (
self.engine == "datalab_marker" self.engine == 'datalab_marker'
and self.kwargs.get("DATALAB_MARKER_API_KEY") and self.kwargs.get('DATALAB_MARKER_API_KEY')
and file_ext and file_ext
in [ in [
"pdf", 'pdf',
"xls", 'xls',
"xlsx", 'xlsx',
"ods", 'ods',
"doc", 'doc',
"docx", 'docx',
"odt", 'odt',
"ppt", 'ppt',
"pptx", 'pptx',
"odp", 'odp',
"html", 'html',
"epub", 'epub',
"png", 'png',
"jpeg", 'jpeg',
"jpg", 'jpg',
"webp", 'webp',
"gif", 'gif',
"tiff", 'tiff',
] ]
): ):
api_base_url = self.kwargs.get("DATALAB_MARKER_API_BASE_URL", "") api_base_url = self.kwargs.get('DATALAB_MARKER_API_BASE_URL', '')
if not api_base_url or api_base_url.strip() == "": if not api_base_url or api_base_url.strip() == '':
api_base_url = "https://www.datalab.to/api/v1/marker" # https://github.com/open-webui/open-webui/pull/16867#issuecomment-3218424349 api_base_url = 'https://www.datalab.to/api/v1/marker' # https://github.com/open-webui/open-webui/pull/16867#issuecomment-3218424349
loader = DatalabMarkerLoader( loader = DatalabMarkerLoader(
file_path=file_path, file_path=file_path,
api_key=self.kwargs["DATALAB_MARKER_API_KEY"], api_key=self.kwargs['DATALAB_MARKER_API_KEY'],
api_base_url=api_base_url, api_base_url=api_base_url,
additional_config=self.kwargs.get("DATALAB_MARKER_ADDITIONAL_CONFIG"), additional_config=self.kwargs.get('DATALAB_MARKER_ADDITIONAL_CONFIG'),
use_llm=self.kwargs.get("DATALAB_MARKER_USE_LLM", False), use_llm=self.kwargs.get('DATALAB_MARKER_USE_LLM', False),
skip_cache=self.kwargs.get("DATALAB_MARKER_SKIP_CACHE", False), skip_cache=self.kwargs.get('DATALAB_MARKER_SKIP_CACHE', False),
force_ocr=self.kwargs.get("DATALAB_MARKER_FORCE_OCR", False), force_ocr=self.kwargs.get('DATALAB_MARKER_FORCE_OCR', False),
paginate=self.kwargs.get("DATALAB_MARKER_PAGINATE", False), paginate=self.kwargs.get('DATALAB_MARKER_PAGINATE', False),
strip_existing_ocr=self.kwargs.get( strip_existing_ocr=self.kwargs.get('DATALAB_MARKER_STRIP_EXISTING_OCR', False),
"DATALAB_MARKER_STRIP_EXISTING_OCR", False disable_image_extraction=self.kwargs.get('DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION', False),
), format_lines=self.kwargs.get('DATALAB_MARKER_FORMAT_LINES', False),
disable_image_extraction=self.kwargs.get( output_format=self.kwargs.get('DATALAB_MARKER_OUTPUT_FORMAT', 'markdown'),
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", False
),
format_lines=self.kwargs.get("DATALAB_MARKER_FORMAT_LINES", False),
output_format=self.kwargs.get(
"DATALAB_MARKER_OUTPUT_FORMAT", "markdown"
),
) )
elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"): elif self.engine == 'docling' and self.kwargs.get('DOCLING_SERVER_URL'):
if self._is_text_file(file_ext, file_content_type): if self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True) loader = TextLoader(file_path, autodetect_encoding=True)
else: else:
# Build params for DoclingLoader # Build params for DoclingLoader
params = self.kwargs.get("DOCLING_PARAMS", {}) params = self.kwargs.get('DOCLING_PARAMS', {})
if not isinstance(params, dict): if not isinstance(params, dict):
try: try:
params = json.loads(params) params = json.loads(params)
except json.JSONDecodeError: except json.JSONDecodeError:
log.error("Invalid DOCLING_PARAMS format, expected JSON object") log.error('Invalid DOCLING_PARAMS format, expected JSON object')
params = {} params = {}
loader = DoclingLoader( loader = DoclingLoader(
url=self.kwargs.get("DOCLING_SERVER_URL"), url=self.kwargs.get('DOCLING_SERVER_URL'),
api_key=self.kwargs.get("DOCLING_API_KEY", None), api_key=self.kwargs.get('DOCLING_API_KEY', None),
file_path=file_path, file_path=file_path,
mime_type=file_content_type, mime_type=file_content_type,
params=params, params=params,
) )
elif ( elif (
self.engine == "document_intelligence" self.engine == 'document_intelligence'
and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != "" and self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT') != ''
and ( and (
file_ext in ["pdf", "docx", "ppt", "pptx"] file_ext in ['pdf', 'docx', 'ppt', 'pptx']
or file_content_type or file_content_type
in [ in [
"application/vnd.openxmlformats-officedocument.wordprocessingml.document", 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
"application/vnd.ms-powerpoint", 'application/vnd.ms-powerpoint',
"application/vnd.openxmlformats-officedocument.presentationml.presentation", 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
] ]
) )
): ):
if self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != "": if self.kwargs.get('DOCUMENT_INTELLIGENCE_KEY') != '':
loader = AzureAIDocumentIntelligenceLoader( loader = AzureAIDocumentIntelligenceLoader(
file_path=file_path, file_path=file_path,
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"), api_endpoint=self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT'),
api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"), api_key=self.kwargs.get('DOCUMENT_INTELLIGENCE_KEY'),
api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"), api_model=self.kwargs.get('DOCUMENT_INTELLIGENCE_MODEL'),
) )
else: else:
loader = AzureAIDocumentIntelligenceLoader( loader = AzureAIDocumentIntelligenceLoader(
file_path=file_path, file_path=file_path,
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"), api_endpoint=self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT'),
azure_credential=DefaultAzureCredential(), azure_credential=DefaultAzureCredential(),
api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"), api_model=self.kwargs.get('DOCUMENT_INTELLIGENCE_MODEL'),
) )
elif self.engine == "mineru" and file_ext in [ elif self.engine == 'mineru' and file_ext in ['pdf']: # MinerU currently only supports PDF
"pdf" mineru_timeout = self.kwargs.get('MINERU_API_TIMEOUT', 300)
]: # MinerU currently only supports PDF
mineru_timeout = self.kwargs.get("MINERU_API_TIMEOUT", 300)
if mineru_timeout: if mineru_timeout:
try: try:
mineru_timeout = int(mineru_timeout) mineru_timeout = int(mineru_timeout)
@@ -386,111 +370,115 @@ class Loader:
loader = MinerULoader( loader = MinerULoader(
file_path=file_path, file_path=file_path,
api_mode=self.kwargs.get("MINERU_API_MODE", "local"), api_mode=self.kwargs.get('MINERU_API_MODE', 'local'),
api_url=self.kwargs.get("MINERU_API_URL", "http://localhost:8000"), api_url=self.kwargs.get('MINERU_API_URL', 'http://localhost:8000'),
api_key=self.kwargs.get("MINERU_API_KEY", ""), api_key=self.kwargs.get('MINERU_API_KEY', ''),
params=self.kwargs.get("MINERU_PARAMS", {}), params=self.kwargs.get('MINERU_PARAMS', {}),
timeout=mineru_timeout, timeout=mineru_timeout,
) )
elif ( elif (
self.engine == "mistral_ocr" self.engine == 'mistral_ocr'
and self.kwargs.get("MISTRAL_OCR_API_KEY") != "" and self.kwargs.get('MISTRAL_OCR_API_KEY') != ''
and file_ext and file_ext in ['pdf'] # Mistral OCR currently only supports PDF and images
in ["pdf"] # Mistral OCR currently only supports PDF and images
): ):
loader = MistralLoader( loader = MistralLoader(
base_url=self.kwargs.get("MISTRAL_OCR_API_BASE_URL"), base_url=self.kwargs.get('MISTRAL_OCR_API_BASE_URL'),
api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"), api_key=self.kwargs.get('MISTRAL_OCR_API_KEY'),
file_path=file_path, file_path=file_path,
) )
else: else:
if file_ext == "pdf": if file_ext == 'pdf':
loader = PyPDFLoader( loader = PyPDFLoader(
file_path, file_path,
extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"), extract_images=self.kwargs.get('PDF_EXTRACT_IMAGES'),
mode=self.kwargs.get("PDF_LOADER_MODE", "page"), mode=self.kwargs.get('PDF_LOADER_MODE', 'page'),
) )
elif file_ext == "csv": elif file_ext == 'csv':
loader = CSVLoader(file_path, autodetect_encoding=True) loader = CSVLoader(file_path, autodetect_encoding=True)
elif file_ext == "rst": elif file_ext == 'rst':
try: try:
from langchain_community.document_loaders import UnstructuredRSTLoader from langchain_community.document_loaders import UnstructuredRSTLoader
loader = UnstructuredRSTLoader(file_path, mode="elements")
loader = UnstructuredRSTLoader(file_path, mode='elements')
except ImportError: except ImportError:
log.warning( log.warning(
"The 'unstructured' package is not installed. " "The 'unstructured' package is not installed. "
"Falling back to plain text loading for .rst file. " 'Falling back to plain text loading for .rst file. '
"Install it with: pip install unstructured" 'Install it with: pip install unstructured'
) )
loader = TextLoader(file_path, autodetect_encoding=True) loader = TextLoader(file_path, autodetect_encoding=True)
elif file_ext == "xml": elif file_ext == 'xml':
try: try:
from langchain_community.document_loaders import UnstructuredXMLLoader from langchain_community.document_loaders import UnstructuredXMLLoader
loader = UnstructuredXMLLoader(file_path) loader = UnstructuredXMLLoader(file_path)
except ImportError: except ImportError:
log.warning( log.warning(
"The 'unstructured' package is not installed. " "The 'unstructured' package is not installed. "
"Falling back to plain text loading for .xml file. " 'Falling back to plain text loading for .xml file. '
"Install it with: pip install unstructured" 'Install it with: pip install unstructured'
) )
loader = TextLoader(file_path, autodetect_encoding=True) loader = TextLoader(file_path, autodetect_encoding=True)
elif file_ext in ["htm", "html"]: elif file_ext in ['htm', 'html']:
loader = BSHTMLLoader(file_path, open_encoding="unicode_escape") loader = BSHTMLLoader(file_path, open_encoding='unicode_escape')
elif file_ext == "md": elif file_ext == 'md':
loader = TextLoader(file_path, autodetect_encoding=True) loader = TextLoader(file_path, autodetect_encoding=True)
elif file_content_type == "application/epub+zip": elif file_content_type == 'application/epub+zip':
try: try:
from langchain_community.document_loaders import UnstructuredEPubLoader from langchain_community.document_loaders import UnstructuredEPubLoader
loader = UnstructuredEPubLoader(file_path) loader = UnstructuredEPubLoader(file_path)
except ImportError: except ImportError:
raise ValueError( raise ValueError(
"Processing .epub files requires the 'unstructured' package. " "Processing .epub files requires the 'unstructured' package. "
"Install it with: pip install unstructured" 'Install it with: pip install unstructured'
) )
elif ( elif (
file_content_type file_content_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document" or file_ext == 'docx'
or file_ext == "docx"
): ):
loader = Docx2txtLoader(file_path) loader = Docx2txtLoader(file_path)
elif file_content_type in [ elif file_content_type in [
"application/vnd.ms-excel", 'application/vnd.ms-excel',
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
] or file_ext in ["xls", "xlsx"]: ] or file_ext in ['xls', 'xlsx']:
try: try:
from langchain_community.document_loaders import UnstructuredExcelLoader from langchain_community.document_loaders import UnstructuredExcelLoader
loader = UnstructuredExcelLoader(file_path) loader = UnstructuredExcelLoader(file_path)
except ImportError: except ImportError:
log.warning( log.warning(
"The 'unstructured' package is not installed. " "The 'unstructured' package is not installed. "
"Falling back to pandas for Excel file loading. " 'Falling back to pandas for Excel file loading. '
"Install unstructured for better results: pip install unstructured" 'Install unstructured for better results: pip install unstructured'
) )
loader = ExcelLoader(file_path) loader = ExcelLoader(file_path)
elif file_content_type in [ elif file_content_type in [
"application/vnd.ms-powerpoint", 'application/vnd.ms-powerpoint',
"application/vnd.openxmlformats-officedocument.presentationml.presentation", 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
] or file_ext in ["ppt", "pptx"]: ] or file_ext in ['ppt', 'pptx']:
try: try:
from langchain_community.document_loaders import UnstructuredPowerPointLoader from langchain_community.document_loaders import UnstructuredPowerPointLoader
loader = UnstructuredPowerPointLoader(file_path) loader = UnstructuredPowerPointLoader(file_path)
except ImportError: except ImportError:
log.warning( log.warning(
"The 'unstructured' package is not installed. " "The 'unstructured' package is not installed. "
"Falling back to python-pptx for PowerPoint file loading. " 'Falling back to python-pptx for PowerPoint file loading. '
"Install unstructured for better results: pip install unstructured" 'Install unstructured for better results: pip install unstructured'
) )
loader = PptxLoader(file_path) loader = PptxLoader(file_path)
elif file_ext == "msg": elif file_ext == 'msg':
loader = OutlookMessageLoader(file_path) loader = OutlookMessageLoader(file_path)
elif file_ext == "odt": elif file_ext == 'odt':
try: try:
from langchain_community.document_loaders import UnstructuredODTLoader from langchain_community.document_loaders import UnstructuredODTLoader
loader = UnstructuredODTLoader(file_path) loader = UnstructuredODTLoader(file_path)
except ImportError: except ImportError:
raise ValueError( raise ValueError(
"Processing .odt files requires the 'unstructured' package. " "Processing .odt files requires the 'unstructured' package. "
"Install it with: pip install unstructured" 'Install it with: pip install unstructured'
) )
elif self._is_text_file(file_ext, file_content_type): elif self._is_text_file(file_ext, file_content_type):
loader = TextLoader(file_path, autodetect_encoding=True) loader = TextLoader(file_path, autodetect_encoding=True)
@@ -498,4 +486,3 @@ class Loader:
loader = TextLoader(file_path, autodetect_encoding=True) loader = TextLoader(file_path, autodetect_encoding=True)
return loader return loader

View File

@@ -22,37 +22,35 @@ class MinerULoader:
def __init__( def __init__(
self, self,
file_path: str, file_path: str,
api_mode: str = "local", api_mode: str = 'local',
api_url: str = "http://localhost:8000", api_url: str = 'http://localhost:8000',
api_key: str = "", api_key: str = '',
params: dict = None, params: dict = None,
timeout: Optional[int] = 300, timeout: Optional[int] = 300,
): ):
self.file_path = file_path self.file_path = file_path
self.api_mode = api_mode.lower() self.api_mode = api_mode.lower()
self.api_url = api_url.rstrip("/") self.api_url = api_url.rstrip('/')
self.api_key = api_key self.api_key = api_key
self.timeout = timeout self.timeout = timeout
# Parse params dict with defaults # Parse params dict with defaults
self.params = params or {} self.params = params or {}
self.enable_ocr = params.get("enable_ocr", False) self.enable_ocr = params.get('enable_ocr', False)
self.enable_formula = params.get("enable_formula", True) self.enable_formula = params.get('enable_formula', True)
self.enable_table = params.get("enable_table", True) self.enable_table = params.get('enable_table', True)
self.language = params.get("language", "en") self.language = params.get('language', 'en')
self.model_version = params.get("model_version", "pipeline") self.model_version = params.get('model_version', 'pipeline')
self.page_ranges = self.params.pop("page_ranges", "") self.page_ranges = self.params.pop('page_ranges', '')
# Validate API mode # Validate API mode
if self.api_mode not in ["local", "cloud"]: if self.api_mode not in ['local', 'cloud']:
raise ValueError( raise ValueError(f"Invalid API mode: {self.api_mode}. Must be 'local' or 'cloud'")
f"Invalid API mode: {self.api_mode}. Must be 'local' or 'cloud'"
)
# Validate Cloud API requirements # Validate Cloud API requirements
if self.api_mode == "cloud" and not self.api_key: if self.api_mode == 'cloud' and not self.api_key:
raise ValueError("API key is required for Cloud API mode") raise ValueError('API key is required for Cloud API mode')
def load(self) -> List[Document]: def load(self) -> List[Document]:
""" """
@@ -60,12 +58,12 @@ class MinerULoader:
Routes to Cloud or Local API based on api_mode. Routes to Cloud or Local API based on api_mode.
""" """
try: try:
if self.api_mode == "cloud": if self.api_mode == 'cloud':
return self._load_cloud_api() return self._load_cloud_api()
else: else:
return self._load_local_api() return self._load_local_api()
except Exception as e: except Exception as e:
log.error(f"Error loading document with MinerU: {e}") log.error(f'Error loading document with MinerU: {e}')
raise raise
def _load_local_api(self) -> List[Document]: def _load_local_api(self) -> List[Document]:
@@ -73,14 +71,14 @@ class MinerULoader:
Load document using Local API (synchronous). Load document using Local API (synchronous).
Posts file to /file_parse endpoint and gets immediate response. Posts file to /file_parse endpoint and gets immediate response.
""" """
log.info(f"Using MinerU Local API at {self.api_url}") log.info(f'Using MinerU Local API at {self.api_url}')
filename = os.path.basename(self.file_path) filename = os.path.basename(self.file_path)
# Build form data for Local API # Build form data for Local API
form_data = { form_data = {
**self.params, **self.params,
"return_md": "true", 'return_md': 'true',
} }
# Page ranges (Local API uses start_page_id and end_page_id) # Page ranges (Local API uses start_page_id and end_page_id)
@@ -89,18 +87,18 @@ class MinerULoader:
# Full page range parsing would require parsing the string # Full page range parsing would require parsing the string
log.warning( log.warning(
f"Page ranges '{self.page_ranges}' specified but Local API uses different format. " f"Page ranges '{self.page_ranges}' specified but Local API uses different format. "
"Consider using start_page_id/end_page_id parameters if needed." 'Consider using start_page_id/end_page_id parameters if needed.'
) )
try: try:
with open(self.file_path, "rb") as f: with open(self.file_path, 'rb') as f:
files = {"files": (filename, f, "application/octet-stream")} files = {'files': (filename, f, 'application/octet-stream')}
log.info(f"Sending file to MinerU Local API: {filename}") log.info(f'Sending file to MinerU Local API: {filename}')
log.debug(f"Local API parameters: {form_data}") log.debug(f'Local API parameters: {form_data}')
response = requests.post( response = requests.post(
f"{self.api_url}/file_parse", f'{self.api_url}/file_parse',
data=form_data, data=form_data,
files=files, files=files,
timeout=self.timeout, timeout=self.timeout,
@@ -108,27 +106,25 @@ class MinerULoader:
response.raise_for_status() response.raise_for_status()
except FileNotFoundError: except FileNotFoundError:
raise HTTPException( raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}')
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
)
except requests.Timeout: except requests.Timeout:
raise HTTPException( raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT, status.HTTP_504_GATEWAY_TIMEOUT,
detail="MinerU Local API request timed out", detail='MinerU Local API request timed out',
) )
except requests.HTTPError as e: except requests.HTTPError as e:
error_detail = f"MinerU Local API request failed: {e}" error_detail = f'MinerU Local API request failed: {e}'
if e.response is not None: if e.response is not None:
try: try:
error_data = e.response.json() error_data = e.response.json()
error_detail += f" - {error_data}" error_detail += f' - {error_data}'
except Exception: except Exception:
error_detail += f" - {e.response.text}" error_detail += f' - {e.response.text}'
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail) raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error calling MinerU Local API: {str(e)}", detail=f'Error calling MinerU Local API: {str(e)}',
) )
# Parse response # Parse response
@@ -137,41 +133,41 @@ class MinerULoader:
except ValueError as e: except ValueError as e:
raise HTTPException( raise HTTPException(
status.HTTP_502_BAD_GATEWAY, status.HTTP_502_BAD_GATEWAY,
detail=f"Invalid JSON response from MinerU Local API: {e}", detail=f'Invalid JSON response from MinerU Local API: {e}',
) )
# Extract markdown content from response # Extract markdown content from response
if "results" not in result: if 'results' not in result:
raise HTTPException( raise HTTPException(
status.HTTP_502_BAD_GATEWAY, status.HTTP_502_BAD_GATEWAY,
detail="MinerU Local API response missing 'results' field", detail="MinerU Local API response missing 'results' field",
) )
results = result["results"] results = result['results']
if not results: if not results:
raise HTTPException( raise HTTPException(
status.HTTP_400_BAD_REQUEST, status.HTTP_400_BAD_REQUEST,
detail="MinerU returned empty results", detail='MinerU returned empty results',
) )
# Get the first (and typically only) result # Get the first (and typically only) result
file_result = list(results.values())[0] file_result = list(results.values())[0]
markdown_content = file_result.get("md_content", "") markdown_content = file_result.get('md_content', '')
if not markdown_content: if not markdown_content:
raise HTTPException( raise HTTPException(
status.HTTP_400_BAD_REQUEST, status.HTTP_400_BAD_REQUEST,
detail="MinerU returned empty markdown content", detail='MinerU returned empty markdown content',
) )
log.info(f"Successfully parsed document with MinerU Local API: {filename}") log.info(f'Successfully parsed document with MinerU Local API: {filename}')
# Create metadata # Create metadata
metadata = { metadata = {
"source": filename, 'source': filename,
"api_mode": "local", 'api_mode': 'local',
"backend": result.get("backend", "unknown"), 'backend': result.get('backend', 'unknown'),
"version": result.get("version", "unknown"), 'version': result.get('version', 'unknown'),
} }
return [Document(page_content=markdown_content, metadata=metadata)] return [Document(page_content=markdown_content, metadata=metadata)]
@@ -181,7 +177,7 @@ class MinerULoader:
Load document using Cloud API (asynchronous). Load document using Cloud API (asynchronous).
Uses batch upload endpoint to avoid need for public file URLs. Uses batch upload endpoint to avoid need for public file URLs.
""" """
log.info(f"Using MinerU Cloud API at {self.api_url}") log.info(f'Using MinerU Cloud API at {self.api_url}')
filename = os.path.basename(self.file_path) filename = os.path.basename(self.file_path)
@@ -195,17 +191,15 @@ class MinerULoader:
result = self._poll_batch_status(batch_id, filename) result = self._poll_batch_status(batch_id, filename)
# Step 4: Download and extract markdown from ZIP # Step 4: Download and extract markdown from ZIP
markdown_content = self._download_and_extract_zip( markdown_content = self._download_and_extract_zip(result['full_zip_url'], filename)
result["full_zip_url"], filename
)
log.info(f"Successfully parsed document with MinerU Cloud API: {filename}") log.info(f'Successfully parsed document with MinerU Cloud API: {filename}')
# Create metadata # Create metadata
metadata = { metadata = {
"source": filename, 'source': filename,
"api_mode": "cloud", 'api_mode': 'cloud',
"batch_id": batch_id, 'batch_id': batch_id,
} }
return [Document(page_content=markdown_content, metadata=metadata)] return [Document(page_content=markdown_content, metadata=metadata)]
@@ -216,49 +210,49 @@ class MinerULoader:
Returns (batch_id, upload_url). Returns (batch_id, upload_url).
""" """
headers = { headers = {
"Authorization": f"Bearer {self.api_key}", 'Authorization': f'Bearer {self.api_key}',
"Content-Type": "application/json", 'Content-Type': 'application/json',
} }
# Build request body # Build request body
request_body = { request_body = {
**self.params, **self.params,
"files": [ 'files': [
{ {
"name": filename, 'name': filename,
"is_ocr": self.enable_ocr, 'is_ocr': self.enable_ocr,
} }
], ],
} }
# Add page ranges if specified # Add page ranges if specified
if self.page_ranges: if self.page_ranges:
request_body["files"][0]["page_ranges"] = self.page_ranges request_body['files'][0]['page_ranges'] = self.page_ranges
log.info(f"Requesting upload URL for: {filename}") log.info(f'Requesting upload URL for: {filename}')
log.debug(f"Cloud API request body: {request_body}") log.debug(f'Cloud API request body: {request_body}')
try: try:
response = requests.post( response = requests.post(
f"{self.api_url}/file-urls/batch", f'{self.api_url}/file-urls/batch',
headers=headers, headers=headers,
json=request_body, json=request_body,
timeout=30, timeout=30,
) )
response.raise_for_status() response.raise_for_status()
except requests.HTTPError as e: except requests.HTTPError as e:
error_detail = f"Failed to request upload URL: {e}" error_detail = f'Failed to request upload URL: {e}'
if e.response is not None: if e.response is not None:
try: try:
error_data = e.response.json() error_data = e.response.json()
error_detail += f" - {error_data.get('msg', error_data)}" error_detail += f' - {error_data.get("msg", error_data)}'
except Exception: except Exception:
error_detail += f" - {e.response.text}" error_detail += f' - {e.response.text}'
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail) raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error requesting upload URL: {str(e)}", detail=f'Error requesting upload URL: {str(e)}',
) )
try: try:
@@ -266,28 +260,28 @@ class MinerULoader:
except ValueError as e: except ValueError as e:
raise HTTPException( raise HTTPException(
status.HTTP_502_BAD_GATEWAY, status.HTTP_502_BAD_GATEWAY,
detail=f"Invalid JSON response: {e}", detail=f'Invalid JSON response: {e}',
) )
# Check for API error response # Check for API error response
if result.get("code") != 0: if result.get('code') != 0:
raise HTTPException( raise HTTPException(
status.HTTP_400_BAD_REQUEST, status.HTTP_400_BAD_REQUEST,
detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}", detail=f'MinerU Cloud API error: {result.get("msg", "Unknown error")}',
) )
data = result.get("data", {}) data = result.get('data', {})
batch_id = data.get("batch_id") batch_id = data.get('batch_id')
file_urls = data.get("file_urls", []) file_urls = data.get('file_urls', [])
if not batch_id or not file_urls: if not batch_id or not file_urls:
raise HTTPException( raise HTTPException(
status.HTTP_502_BAD_GATEWAY, status.HTTP_502_BAD_GATEWAY,
detail="MinerU Cloud API response missing batch_id or file_urls", detail='MinerU Cloud API response missing batch_id or file_urls',
) )
upload_url = file_urls[0] upload_url = file_urls[0]
log.info(f"Received upload URL for batch: {batch_id}") log.info(f'Received upload URL for batch: {batch_id}')
return batch_id, upload_url return batch_id, upload_url
@@ -295,10 +289,10 @@ class MinerULoader:
""" """
Upload file to presigned URL (no authentication needed). Upload file to presigned URL (no authentication needed).
""" """
log.info(f"Uploading file to presigned URL") log.info(f'Uploading file to presigned URL')
try: try:
with open(self.file_path, "rb") as f: with open(self.file_path, 'rb') as f:
response = requests.put( response = requests.put(
upload_url, upload_url,
data=f, data=f,
@@ -306,26 +300,24 @@ class MinerULoader:
) )
response.raise_for_status() response.raise_for_status()
except FileNotFoundError: except FileNotFoundError:
raise HTTPException( raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}')
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
)
except requests.Timeout: except requests.Timeout:
raise HTTPException( raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT, status.HTTP_504_GATEWAY_TIMEOUT,
detail="File upload to presigned URL timed out", detail='File upload to presigned URL timed out',
) )
except requests.HTTPError as e: except requests.HTTPError as e:
raise HTTPException( raise HTTPException(
status.HTTP_400_BAD_REQUEST, status.HTTP_400_BAD_REQUEST,
detail=f"Failed to upload file to presigned URL: {e}", detail=f'Failed to upload file to presigned URL: {e}',
) )
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error uploading file: {str(e)}", detail=f'Error uploading file: {str(e)}',
) )
log.info("File uploaded successfully") log.info('File uploaded successfully')
def _poll_batch_status(self, batch_id: str, filename: str) -> dict: def _poll_batch_status(self, batch_id: str, filename: str) -> dict:
""" """
@@ -333,35 +325,35 @@ class MinerULoader:
Returns the result dict for the file. Returns the result dict for the file.
""" """
headers = { headers = {
"Authorization": f"Bearer {self.api_key}", 'Authorization': f'Bearer {self.api_key}',
} }
max_iterations = 300 # 10 minutes max (2 seconds per iteration) max_iterations = 300 # 10 minutes max (2 seconds per iteration)
poll_interval = 2 # seconds poll_interval = 2 # seconds
log.info(f"Polling batch status: {batch_id}") log.info(f'Polling batch status: {batch_id}')
for iteration in range(max_iterations): for iteration in range(max_iterations):
try: try:
response = requests.get( response = requests.get(
f"{self.api_url}/extract-results/batch/{batch_id}", f'{self.api_url}/extract-results/batch/{batch_id}',
headers=headers, headers=headers,
timeout=30, timeout=30,
) )
response.raise_for_status() response.raise_for_status()
except requests.HTTPError as e: except requests.HTTPError as e:
error_detail = f"Failed to poll batch status: {e}" error_detail = f'Failed to poll batch status: {e}'
if e.response is not None: if e.response is not None:
try: try:
error_data = e.response.json() error_data = e.response.json()
error_detail += f" - {error_data.get('msg', error_data)}" error_detail += f' - {error_data.get("msg", error_data)}'
except Exception: except Exception:
error_detail += f" - {e.response.text}" error_detail += f' - {e.response.text}'
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail) raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error polling batch status: {str(e)}", detail=f'Error polling batch status: {str(e)}',
) )
try: try:
@@ -369,58 +361,56 @@ class MinerULoader:
except ValueError as e: except ValueError as e:
raise HTTPException( raise HTTPException(
status.HTTP_502_BAD_GATEWAY, status.HTTP_502_BAD_GATEWAY,
detail=f"Invalid JSON response while polling: {e}", detail=f'Invalid JSON response while polling: {e}',
) )
# Check for API error response # Check for API error response
if result.get("code") != 0: if result.get('code') != 0:
raise HTTPException( raise HTTPException(
status.HTTP_400_BAD_REQUEST, status.HTTP_400_BAD_REQUEST,
detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}", detail=f'MinerU Cloud API error: {result.get("msg", "Unknown error")}',
) )
data = result.get("data", {}) data = result.get('data', {})
extract_result = data.get("extract_result", []) extract_result = data.get('extract_result', [])
# Find our file in the batch results # Find our file in the batch results
file_result = None file_result = None
for item in extract_result: for item in extract_result:
if item.get("file_name") == filename: if item.get('file_name') == filename:
file_result = item file_result = item
break break
if not file_result: if not file_result:
raise HTTPException( raise HTTPException(
status.HTTP_502_BAD_GATEWAY, status.HTTP_502_BAD_GATEWAY,
detail=f"File {filename} not found in batch results", detail=f'File {filename} not found in batch results',
) )
state = file_result.get("state") state = file_result.get('state')
if state == "done": if state == 'done':
log.info(f"Processing complete for {filename}") log.info(f'Processing complete for {filename}')
return file_result return file_result
elif state == "failed": elif state == 'failed':
error_msg = file_result.get("err_msg", "Unknown error") error_msg = file_result.get('err_msg', 'Unknown error')
raise HTTPException( raise HTTPException(
status.HTTP_400_BAD_REQUEST, status.HTTP_400_BAD_REQUEST,
detail=f"MinerU processing failed: {error_msg}", detail=f'MinerU processing failed: {error_msg}',
) )
elif state in ["waiting-file", "pending", "running", "converting"]: elif state in ['waiting-file', 'pending', 'running', 'converting']:
# Still processing # Still processing
if iteration % 10 == 0: # Log every 20 seconds if iteration % 10 == 0: # Log every 20 seconds
log.info( log.info(f'Processing status: {state} (iteration {iteration + 1}/{max_iterations})')
f"Processing status: {state} (iteration {iteration + 1}/{max_iterations})"
)
time.sleep(poll_interval) time.sleep(poll_interval)
else: else:
log.warning(f"Unknown state: {state}") log.warning(f'Unknown state: {state}')
time.sleep(poll_interval) time.sleep(poll_interval)
# Timeout # Timeout
raise HTTPException( raise HTTPException(
status.HTTP_504_GATEWAY_TIMEOUT, status.HTTP_504_GATEWAY_TIMEOUT,
detail="MinerU processing timed out after 10 minutes", detail='MinerU processing timed out after 10 minutes',
) )
def _download_and_extract_zip(self, zip_url: str, filename: str) -> str: def _download_and_extract_zip(self, zip_url: str, filename: str) -> str:
@@ -428,7 +418,7 @@ class MinerULoader:
Download ZIP file from CDN and extract markdown content. Download ZIP file from CDN and extract markdown content.
Returns the markdown content as a string. Returns the markdown content as a string.
""" """
log.info(f"Downloading results from: {zip_url}") log.info(f'Downloading results from: {zip_url}')
try: try:
response = requests.get(zip_url, timeout=60) response = requests.get(zip_url, timeout=60)
@@ -436,23 +426,23 @@ class MinerULoader:
except requests.HTTPError as e: except requests.HTTPError as e:
raise HTTPException( raise HTTPException(
status.HTTP_400_BAD_REQUEST, status.HTTP_400_BAD_REQUEST,
detail=f"Failed to download results ZIP: {e}", detail=f'Failed to download results ZIP: {e}',
) )
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error downloading results: {str(e)}", detail=f'Error downloading results: {str(e)}',
) )
# Save ZIP to temporary file and extract # Save ZIP to temporary file and extract
try: try:
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip: with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as tmp_zip:
tmp_zip.write(response.content) tmp_zip.write(response.content)
tmp_zip_path = tmp_zip.name tmp_zip_path = tmp_zip.name
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
# Extract ZIP # Extract ZIP
with zipfile.ZipFile(tmp_zip_path, "r") as zip_ref: with zipfile.ZipFile(tmp_zip_path, 'r') as zip_ref:
zip_ref.extractall(tmp_dir) zip_ref.extractall(tmp_dir)
# Find markdown file - search recursively for any .md file # Find markdown file - search recursively for any .md file
@@ -466,33 +456,27 @@ class MinerULoader:
full_path = os.path.join(root, file) full_path = os.path.join(root, file)
all_files.append(full_path) all_files.append(full_path)
# Look for any .md file # Look for any .md file
if file.endswith(".md"): if file.endswith('.md'):
found_md_path = full_path found_md_path = full_path
log.info(f"Found markdown file at: {full_path}") log.info(f'Found markdown file at: {full_path}')
try: try:
with open(full_path, "r", encoding="utf-8") as f: with open(full_path, 'r', encoding='utf-8') as f:
markdown_content = f.read() markdown_content = f.read()
if ( if markdown_content: # Use the first non-empty markdown file
markdown_content
): # Use the first non-empty markdown file
break break
except Exception as e: except Exception as e:
log.warning(f"Failed to read {full_path}: {e}") log.warning(f'Failed to read {full_path}: {e}')
if markdown_content: if markdown_content:
break break
if markdown_content is None: if markdown_content is None:
log.error(f"Available files in ZIP: {all_files}") log.error(f'Available files in ZIP: {all_files}')
# Try to provide more helpful error message # Try to provide more helpful error message
md_files = [f for f in all_files if f.endswith(".md")] md_files = [f for f in all_files if f.endswith('.md')]
if md_files: if md_files:
error_msg = ( error_msg = f"Found .md files but couldn't read them: {md_files}"
f"Found .md files but couldn't read them: {md_files}"
)
else: else:
error_msg = ( error_msg = f'No .md files found in ZIP. Available files: {all_files}'
f"No .md files found in ZIP. Available files: {all_files}"
)
raise HTTPException( raise HTTPException(
status.HTTP_502_BAD_GATEWAY, status.HTTP_502_BAD_GATEWAY,
detail=error_msg, detail=error_msg,
@@ -504,21 +488,19 @@ class MinerULoader:
except zipfile.BadZipFile as e: except zipfile.BadZipFile as e:
raise HTTPException( raise HTTPException(
status.HTTP_502_BAD_GATEWAY, status.HTTP_502_BAD_GATEWAY,
detail=f"Invalid ZIP file received: {e}", detail=f'Invalid ZIP file received: {e}',
) )
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error extracting ZIP: {str(e)}", detail=f'Error extracting ZIP: {str(e)}',
) )
if not markdown_content: if not markdown_content:
raise HTTPException( raise HTTPException(
status.HTTP_400_BAD_REQUEST, status.HTTP_400_BAD_REQUEST,
detail="Extracted markdown content is empty", detail='Extracted markdown content is empty',
) )
log.info( log.info(f'Successfully extracted markdown content ({len(markdown_content)} characters)')
f"Successfully extracted markdown content ({len(markdown_content)} characters)"
)
return markdown_content return markdown_content

View File

@@ -49,13 +49,11 @@ class MistralLoader:
enable_debug_logging: Enable detailed debug logs. enable_debug_logging: Enable detailed debug logs.
""" """
if not api_key: if not api_key:
raise ValueError("API key cannot be empty.") raise ValueError('API key cannot be empty.')
if not os.path.exists(file_path): if not os.path.exists(file_path):
raise FileNotFoundError(f"File not found at {file_path}") raise FileNotFoundError(f'File not found at {file_path}')
self.base_url = ( self.base_url = base_url.rstrip('/') if base_url else 'https://api.mistral.ai/v1'
base_url.rstrip("/") if base_url else "https://api.mistral.ai/v1"
)
self.api_key = api_key self.api_key = api_key
self.file_path = file_path self.file_path = file_path
self.timeout = timeout self.timeout = timeout
@@ -65,18 +63,10 @@ class MistralLoader:
# PERFORMANCE OPTIMIZATION: Differentiated timeouts for different operations # PERFORMANCE OPTIMIZATION: Differentiated timeouts for different operations
# This prevents long-running OCR operations from affecting quick operations # This prevents long-running OCR operations from affecting quick operations
# and improves user experience by failing fast on operations that should be quick # and improves user experience by failing fast on operations that should be quick
self.upload_timeout = min( self.upload_timeout = min(timeout, 120) # Cap upload at 2 minutes - prevents hanging on large files
timeout, 120 self.url_timeout = 30 # URL requests should be fast - fail quickly if API is slow
) # Cap upload at 2 minutes - prevents hanging on large files self.ocr_timeout = timeout # OCR can take the full timeout - this is the heavy operation
self.url_timeout = ( self.cleanup_timeout = 30 # Cleanup should be quick - don't hang on file deletion
30 # URL requests should be fast - fail quickly if API is slow
)
self.ocr_timeout = (
timeout # OCR can take the full timeout - this is the heavy operation
)
self.cleanup_timeout = (
30 # Cleanup should be quick - don't hang on file deletion
)
# PERFORMANCE OPTIMIZATION: Pre-compute file info to avoid repeated filesystem calls # PERFORMANCE OPTIMIZATION: Pre-compute file info to avoid repeated filesystem calls
# This avoids multiple os.path.basename() and os.path.getsize() calls during processing # This avoids multiple os.path.basename() and os.path.getsize() calls during processing
@@ -85,8 +75,8 @@ class MistralLoader:
# ENHANCEMENT: Added User-Agent for better API tracking and debugging # ENHANCEMENT: Added User-Agent for better API tracking and debugging
self.headers = { self.headers = {
"Authorization": f"Bearer {self.api_key}", 'Authorization': f'Bearer {self.api_key}',
"User-Agent": "OpenWebUI-MistralLoader/2.0", # Helps API provider track usage 'User-Agent': 'OpenWebUI-MistralLoader/2.0', # Helps API provider track usage
} }
def _debug_log(self, message: str, *args) -> None: def _debug_log(self, message: str, *args) -> None:
@@ -108,43 +98,39 @@ class MistralLoader:
return {} # Return empty dict if no content return {} # Return empty dict if no content
return response.json() return response.json()
except requests.exceptions.HTTPError as http_err: except requests.exceptions.HTTPError as http_err:
log.error(f"HTTP error occurred: {http_err} - Response: {response.text}") log.error(f'HTTP error occurred: {http_err} - Response: {response.text}')
raise raise
except requests.exceptions.RequestException as req_err: except requests.exceptions.RequestException as req_err:
log.error(f"Request exception occurred: {req_err}") log.error(f'Request exception occurred: {req_err}')
raise raise
except ValueError as json_err: # Includes JSONDecodeError except ValueError as json_err: # Includes JSONDecodeError
log.error(f"JSON decode error: {json_err} - Response: {response.text}") log.error(f'JSON decode error: {json_err} - Response: {response.text}')
raise # Re-raise after logging raise # Re-raise after logging
async def _handle_response_async( async def _handle_response_async(self, response: aiohttp.ClientResponse) -> Dict[str, Any]:
self, response: aiohttp.ClientResponse
) -> Dict[str, Any]:
"""Async version of response handling with better error info.""" """Async version of response handling with better error info."""
try: try:
response.raise_for_status() response.raise_for_status()
# Check content type # Check content type
content_type = response.headers.get("content-type", "") content_type = response.headers.get('content-type', '')
if "application/json" not in content_type: if 'application/json' not in content_type:
if response.status == 204: if response.status == 204:
return {} return {}
text = await response.text() text = await response.text()
raise ValueError( raise ValueError(f'Unexpected content type: {content_type}, body: {text[:200]}...')
f"Unexpected content type: {content_type}, body: {text[:200]}..."
)
return await response.json() return await response.json()
except aiohttp.ClientResponseError as e: except aiohttp.ClientResponseError as e:
error_text = await response.text() if response else "No response" error_text = await response.text() if response else 'No response'
log.error(f"HTTP {e.status}: {e.message} - Response: {error_text[:500]}") log.error(f'HTTP {e.status}: {e.message} - Response: {error_text[:500]}')
raise raise
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
log.error(f"Client error: {e}") log.error(f'Client error: {e}')
raise raise
except Exception as e: except Exception as e:
log.error(f"Unexpected error processing response: {e}") log.error(f'Unexpected error processing response: {e}')
raise raise
def _is_retryable_error(self, error: Exception) -> bool: def _is_retryable_error(self, error: Exception) -> bool:
@@ -172,13 +158,11 @@ class MistralLoader:
return True # Timeouts might resolve on retry return True # Timeouts might resolve on retry
if isinstance(error, requests.exceptions.HTTPError): if isinstance(error, requests.exceptions.HTTPError):
# Only retry on server errors (5xx) or rate limits (429) # Only retry on server errors (5xx) or rate limits (429)
if hasattr(error, "response") and error.response is not None: if hasattr(error, 'response') and error.response is not None:
status_code = error.response.status_code status_code = error.response.status_code
return status_code >= 500 or status_code == 429 return status_code >= 500 or status_code == 429
return False return False
if isinstance( if isinstance(error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)):
error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)
):
return True # Async network/timeout errors are retryable return True # Async network/timeout errors are retryable
if isinstance(error, aiohttp.ClientResponseError): if isinstance(error, aiohttp.ClientResponseError):
return error.status >= 500 or error.status == 429 return error.status >= 500 or error.status == 429
@@ -204,8 +188,7 @@ class MistralLoader:
# Prevents overwhelming the server while ensuring reasonable retry delays # Prevents overwhelming the server while ensuring reasonable retry delays
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
log.warning( log.warning(
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. " f'Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...'
f"Retrying in {wait_time}s..."
) )
time.sleep(wait_time) time.sleep(wait_time)
@@ -226,8 +209,7 @@ class MistralLoader:
# PERFORMANCE OPTIMIZATION: Non-blocking exponential backoff # PERFORMANCE OPTIMIZATION: Non-blocking exponential backoff
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
log.warning( log.warning(
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. " f'Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...'
f"Retrying in {wait_time}s..."
) )
await asyncio.sleep(wait_time) # Non-blocking wait await asyncio.sleep(wait_time) # Non-blocking wait
@@ -240,15 +222,15 @@ class MistralLoader:
Although streaming is not enabled for this endpoint, the file is opened Although streaming is not enabled for this endpoint, the file is opened
in a context manager to minimize memory usage duration. in a context manager to minimize memory usage duration.
""" """
log.info("Uploading file to Mistral API") log.info('Uploading file to Mistral API')
url = f"{self.base_url}/files" url = f'{self.base_url}/files'
def upload_request(): def upload_request():
# MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime # MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime
# This ensures the file is closed immediately after reading, reducing memory usage # This ensures the file is closed immediately after reading, reducing memory usage
with open(self.file_path, "rb") as f: with open(self.file_path, 'rb') as f:
files = {"file": (self.file_name, f, "application/pdf")} files = {'file': (self.file_name, f, 'application/pdf')}
data = {"purpose": "ocr"} data = {'purpose': 'ocr'}
# NOTE: stream=False is required for this endpoint # NOTE: stream=False is required for this endpoint
# The Mistral API doesn't support chunked uploads for this endpoint # The Mistral API doesn't support chunked uploads for this endpoint
@@ -265,42 +247,38 @@ class MistralLoader:
try: try:
response_data = self._retry_request_sync(upload_request) response_data = self._retry_request_sync(upload_request)
file_id = response_data.get("id") file_id = response_data.get('id')
if not file_id: if not file_id:
raise ValueError("File ID not found in upload response.") raise ValueError('File ID not found in upload response.')
log.info(f"File uploaded successfully. File ID: {file_id}") log.info(f'File uploaded successfully. File ID: {file_id}')
return file_id return file_id
except Exception as e: except Exception as e:
log.error(f"Failed to upload file: {e}") log.error(f'Failed to upload file: {e}')
raise raise
async def _upload_file_async(self, session: aiohttp.ClientSession) -> str: async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
"""Async file upload with streaming for better memory efficiency.""" """Async file upload with streaming for better memory efficiency."""
url = f"{self.base_url}/files" url = f'{self.base_url}/files'
async def upload_request(): async def upload_request():
# Create multipart writer for streaming upload # Create multipart writer for streaming upload
writer = aiohttp.MultipartWriter("form-data") writer = aiohttp.MultipartWriter('form-data')
# Add purpose field # Add purpose field
purpose_part = writer.append("ocr") purpose_part = writer.append('ocr')
purpose_part.set_content_disposition("form-data", name="purpose") purpose_part.set_content_disposition('form-data', name='purpose')
# Add file part with streaming # Add file part with streaming
file_part = writer.append_payload( file_part = writer.append_payload(
aiohttp.streams.FilePayload( aiohttp.streams.FilePayload(
self.file_path, self.file_path,
filename=self.file_name, filename=self.file_name,
content_type="application/pdf", content_type='application/pdf',
) )
) )
file_part.set_content_disposition( file_part.set_content_disposition('form-data', name='file', filename=self.file_name)
"form-data", name="file", filename=self.file_name
)
self._debug_log( self._debug_log(f'Uploading file: {self.file_name} ({self.file_size:,} bytes)')
f"Uploading file: {self.file_name} ({self.file_size:,} bytes)"
)
async with session.post( async with session.post(
url, url,
@@ -312,48 +290,44 @@ class MistralLoader:
response_data = await self._retry_request_async(upload_request) response_data = await self._retry_request_async(upload_request)
file_id = response_data.get("id") file_id = response_data.get('id')
if not file_id: if not file_id:
raise ValueError("File ID not found in upload response.") raise ValueError('File ID not found in upload response.')
log.info(f"File uploaded successfully. File ID: {file_id}") log.info(f'File uploaded successfully. File ID: {file_id}')
return file_id return file_id
def _get_signed_url(self, file_id: str) -> str: def _get_signed_url(self, file_id: str) -> str:
"""Retrieves a temporary signed URL for the uploaded file (sync version).""" """Retrieves a temporary signed URL for the uploaded file (sync version)."""
log.info(f"Getting signed URL for file ID: {file_id}") log.info(f'Getting signed URL for file ID: {file_id}')
url = f"{self.base_url}/files/{file_id}/url" url = f'{self.base_url}/files/{file_id}/url'
params = {"expiry": 1} params = {'expiry': 1}
signed_url_headers = {**self.headers, "Accept": "application/json"} signed_url_headers = {**self.headers, 'Accept': 'application/json'}
def url_request(): def url_request():
response = requests.get( response = requests.get(url, headers=signed_url_headers, params=params, timeout=self.url_timeout)
url, headers=signed_url_headers, params=params, timeout=self.url_timeout
)
return self._handle_response(response) return self._handle_response(response)
try: try:
response_data = self._retry_request_sync(url_request) response_data = self._retry_request_sync(url_request)
signed_url = response_data.get("url") signed_url = response_data.get('url')
if not signed_url: if not signed_url:
raise ValueError("Signed URL not found in response.") raise ValueError('Signed URL not found in response.')
log.info("Signed URL received.") log.info('Signed URL received.')
return signed_url return signed_url
except Exception as e: except Exception as e:
log.error(f"Failed to get signed URL: {e}") log.error(f'Failed to get signed URL: {e}')
raise raise
async def _get_signed_url_async( async def _get_signed_url_async(self, session: aiohttp.ClientSession, file_id: str) -> str:
self, session: aiohttp.ClientSession, file_id: str
) -> str:
"""Async signed URL retrieval.""" """Async signed URL retrieval."""
url = f"{self.base_url}/files/{file_id}/url" url = f'{self.base_url}/files/{file_id}/url'
params = {"expiry": 1} params = {'expiry': 1}
headers = {**self.headers, "Accept": "application/json"} headers = {**self.headers, 'Accept': 'application/json'}
async def url_request(): async def url_request():
self._debug_log(f"Getting signed URL for file ID: {file_id}") self._debug_log(f'Getting signed URL for file ID: {file_id}')
async with session.get( async with session.get(
url, url,
headers=headers, headers=headers,
@@ -364,69 +338,65 @@ class MistralLoader:
response_data = await self._retry_request_async(url_request) response_data = await self._retry_request_async(url_request)
signed_url = response_data.get("url") signed_url = response_data.get('url')
if not signed_url: if not signed_url:
raise ValueError("Signed URL not found in response.") raise ValueError('Signed URL not found in response.')
self._debug_log("Signed URL received successfully") self._debug_log('Signed URL received successfully')
return signed_url return signed_url
def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
    """
    Submit a document URL to the Mistral OCR endpoint (blocking variant).

    Args:
        signed_url: URL placed in the payload's ``document_url`` field.

    Returns:
        The value produced by ``self._retry_request_sync`` for the OCR call,
        i.e. the result of ``self._handle_response`` on the POST response.

    Raises:
        Exception: anything raised while sending the request or logging the
            result is logged at error level and re-raised unchanged.
    """
    log.info('Processing OCR via Mistral API')

    endpoint = f'{self.base_url}/ocr'

    # Start from the loader's base headers, then force JSON in both directions.
    request_headers = dict(self.headers)
    request_headers.update(
        {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
        }
    )

    body = {
        'model': 'mistral-ocr-latest',
        'document': {'type': 'document_url', 'document_url': signed_url},
        'include_image_base64': False,
    }

    def ocr_request():
        # Handed to the retry helper as a callable; presumably re-invoked on
        # transient failures -- confirm against _retry_request_sync.
        http_response = requests.post(
            endpoint,
            headers=request_headers,
            json=body,
            timeout=self.ocr_timeout,
        )
        return self._handle_response(http_response)

    try:
        result = self._retry_request_sync(ocr_request)
        log.info('OCR processing done.')
        self._debug_log('OCR response: %s', result)
        return result
    except Exception as e:
        log.error(f'Failed during OCR processing: {e}')
        raise
async def _process_ocr_async( async def _process_ocr_async(self, session: aiohttp.ClientSession, signed_url: str) -> Dict[str, Any]:
self, session: aiohttp.ClientSession, signed_url: str
) -> Dict[str, Any]:
"""Async OCR processing with timing metrics.""" """Async OCR processing with timing metrics."""
url = f"{self.base_url}/ocr" url = f'{self.base_url}/ocr'
headers = { headers = {
**self.headers, **self.headers,
"Content-Type": "application/json", 'Content-Type': 'application/json',
"Accept": "application/json", 'Accept': 'application/json',
} }
payload = { payload = {
"model": "mistral-ocr-latest", 'model': 'mistral-ocr-latest',
"document": { 'document': {
"type": "document_url", 'type': 'document_url',
"document_url": signed_url, 'document_url': signed_url,
}, },
"include_image_base64": False, 'include_image_base64': False,
} }
async def ocr_request(): async def ocr_request():
log.info("Starting OCR processing via Mistral API") log.info('Starting OCR processing via Mistral API')
start_time = time.time() start_time = time.time()
async with session.post( async with session.post(
@@ -438,7 +408,7 @@ class MistralLoader:
ocr_response = await self._handle_response_async(response) ocr_response = await self._handle_response_async(response)
processing_time = time.time() - start_time processing_time = time.time() - start_time
log.info(f"OCR processing completed in {processing_time:.2f}s") log.info(f'OCR processing completed in {processing_time:.2f}s')
return ocr_response return ocr_response
@@ -446,42 +416,36 @@ class MistralLoader:
def _delete_file(self, file_id: str) -> None:
    """
    Remove an uploaded file from Mistral storage (blocking variant).

    Deletion is best-effort: any exception is logged at error level and
    swallowed, so this method never raises to its caller.

    Args:
        file_id: Identifier interpolated into the ``/files/{file_id}`` URL.
    """
    log.info(f'Deleting uploaded file ID: {file_id}')
    target = f'{self.base_url}/files/{file_id}'

    try:
        http_response = requests.delete(
            target,
            headers=self.headers,
            timeout=self.cleanup_timeout,
        )
        outcome = self._handle_response(http_response)
        log.info(f'File deleted successfully: {outcome}')
    except Exception as e:
        # Log error but don't necessarily halt execution if deletion fails
        log.error(f'Failed to delete file ID {file_id}: {e}')
async def _delete_file_async(self, session: aiohttp.ClientSession, file_id: str) -> None:
    """
    Remove an uploaded file from Mistral storage (async variant).

    Cleanup is best-effort: any exception from the request or the retry
    helper is logged as a warning and swallowed.

    Args:
        session: aiohttp session used to issue the DELETE request.
        file_id: Identifier interpolated into the ``/files/{file_id}`` URL.
    """

    async def delete_request():
        self._debug_log(f'Deleting file ID: {file_id}')
        target = f'{self.base_url}/files/{file_id}'
        # Shorter timeout for cleanup
        cleanup_limit = aiohttp.ClientTimeout(total=self.cleanup_timeout)
        async with session.delete(
            url=target,
            headers=self.headers,
            timeout=cleanup_limit,
        ) as resp:
            return await self._handle_response_async(resp)

    try:
        await self._retry_request_async(delete_request)
        self._debug_log(f'File {file_id} deleted successfully')
    except Exception as e:
        # Don't fail the entire process if cleanup fails
        log.warning(f'Failed to delete file ID {file_id}: {e}')
@asynccontextmanager @asynccontextmanager
async def _get_session(self): async def _get_session(self):
@@ -506,7 +470,7 @@ class MistralLoader:
async with aiohttp.ClientSession( async with aiohttp.ClientSession(
connector=connector, connector=connector,
timeout=timeout, timeout=timeout,
headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"}, headers={'User-Agent': 'OpenWebUI-MistralLoader/2.0'},
raise_for_status=False, # We handle status codes manually raise_for_status=False, # We handle status codes manually
trust_env=True, trust_env=True,
) as session: ) as session:
@@ -514,13 +478,13 @@ class MistralLoader:
def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]: def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
"""Process OCR results into Document objects with enhanced metadata and memory efficiency.""" """Process OCR results into Document objects with enhanced metadata and memory efficiency."""
pages_data = ocr_response.get("pages") pages_data = ocr_response.get('pages')
if not pages_data: if not pages_data:
log.warning("No pages found in OCR response.") log.warning('No pages found in OCR response.')
return [ return [
Document( Document(
page_content="No text content found", page_content='No text content found',
metadata={"error": "no_pages", "file_name": self.file_name}, metadata={'error': 'no_pages', 'file_name': self.file_name},
) )
] ]
@@ -530,8 +494,8 @@ class MistralLoader:
# Process pages in a memory-efficient way # Process pages in a memory-efficient way
for page_data in pages_data: for page_data in pages_data:
page_content = page_data.get("markdown") page_content = page_data.get('markdown')
page_index = page_data.get("index") # API uses 0-based index page_index = page_data.get('index') # API uses 0-based index
if page_content is None or page_index is None: if page_content is None or page_index is None:
skipped_pages += 1 skipped_pages += 1
@@ -548,7 +512,7 @@ class MistralLoader:
if not cleaned_content: if not cleaned_content:
skipped_pages += 1 skipped_pages += 1
self._debug_log(f"Skipping empty page {page_index}") self._debug_log(f'Skipping empty page {page_index}')
continue continue
# Create document with optimized metadata # Create document with optimized metadata
@@ -556,34 +520,30 @@ class MistralLoader:
Document( Document(
page_content=cleaned_content, page_content=cleaned_content,
metadata={ metadata={
"page": page_index, # 0-based index from API 'page': page_index, # 0-based index from API
"page_label": page_index + 1, # 1-based label for convenience 'page_label': page_index + 1, # 1-based label for convenience
"total_pages": total_pages, 'total_pages': total_pages,
"file_name": self.file_name, 'file_name': self.file_name,
"file_size": self.file_size, 'file_size': self.file_size,
"processing_engine": "mistral-ocr", 'processing_engine': 'mistral-ocr',
"content_length": len(cleaned_content), 'content_length': len(cleaned_content),
}, },
) )
) )
if skipped_pages > 0: if skipped_pages > 0:
log.info( log.info(f'Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages')
f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages"
)
if not documents: if not documents:
# Case where pages existed but none had valid markdown/index # Case where pages existed but none had valid markdown/index
log.warning( log.warning('OCR response contained pages, but none had valid content/index.')
"OCR response contained pages, but none had valid content/index."
)
return [ return [
Document( Document(
page_content="No valid text content found in document", page_content='No valid text content found in document',
metadata={ metadata={
"error": "no_valid_pages", 'error': 'no_valid_pages',
"total_pages": total_pages, 'total_pages': total_pages,
"file_name": self.file_name, 'file_name': self.file_name,
}, },
) )
] ]
@@ -615,24 +575,20 @@ class MistralLoader:
documents = self._process_results(ocr_response) documents = self._process_results(ocr_response)
total_time = time.time() - start_time total_time = time.time() - start_time
log.info( log.info(f'Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents')
f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
)
return documents return documents
except Exception as e: except Exception as e:
total_time = time.time() - start_time total_time = time.time() - start_time
log.error( log.error(f'An error occurred during the loading process after {total_time:.2f}s: {e}')
f"An error occurred during the loading process after {total_time:.2f}s: {e}"
)
# Return an error document on failure # Return an error document on failure
return [ return [
Document( Document(
page_content=f"Error during processing: {e}", page_content=f'Error during processing: {e}',
metadata={ metadata={
"error": "processing_failed", 'error': 'processing_failed',
"file_name": self.file_name, 'file_name': self.file_name,
}, },
) )
] ]
@@ -643,9 +599,7 @@ class MistralLoader:
self._delete_file(file_id) self._delete_file(file_id)
except Exception as del_e: except Exception as del_e:
# Log deletion error, but don't overwrite original error if one occurred # Log deletion error, but don't overwrite original error if one occurred
log.error( log.error(f'Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}')
f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}"
)
async def load_async(self) -> List[Document]: async def load_async(self) -> List[Document]:
""" """
@@ -672,21 +626,19 @@ class MistralLoader:
documents = self._process_results(ocr_response) documents = self._process_results(ocr_response)
total_time = time.time() - start_time total_time = time.time() - start_time
log.info( log.info(f'Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents')
f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
)
return documents return documents
except Exception as e: except Exception as e:
total_time = time.time() - start_time total_time = time.time() - start_time
log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}") log.error(f'Async OCR workflow failed after {total_time:.2f}s: {e}')
return [ return [
Document( Document(
page_content=f"Error during OCR processing: {e}", page_content=f'Error during OCR processing: {e}',
metadata={ metadata={
"error": "processing_failed", 'error': 'processing_failed',
"file_name": self.file_name, 'file_name': self.file_name,
}, },
) )
] ]
@@ -697,11 +649,11 @@ class MistralLoader:
async with self._get_session() as session: async with self._get_session() as session:
await self._delete_file_async(session, file_id) await self._delete_file_async(session, file_id)
except Exception as cleanup_error: except Exception as cleanup_error:
log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}") log.error(f'Cleanup failed for file ID {file_id}: {cleanup_error}')
@staticmethod @staticmethod
async def load_multiple_async( async def load_multiple_async(
loaders: List["MistralLoader"], loaders: List['MistralLoader'],
max_concurrent: int = 5, # Limit concurrent requests max_concurrent: int = 5, # Limit concurrent requests
) -> List[List[Document]]: ) -> List[List[Document]]:
""" """
@@ -717,15 +669,13 @@ class MistralLoader:
if not loaders: if not loaders:
return [] return []
log.info( log.info(f'Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent')
f"Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent"
)
start_time = time.time() start_time = time.time()
# Use semaphore to control concurrency # Use semaphore to control concurrency
semaphore = asyncio.Semaphore(max_concurrent) semaphore = asyncio.Semaphore(max_concurrent)
async def process_with_semaphore(loader: "MistralLoader") -> List[Document]: async def process_with_semaphore(loader: 'MistralLoader') -> List[Document]:
async with semaphore: async with semaphore:
return await loader.load_async() return await loader.load_async()
@@ -737,14 +687,14 @@ class MistralLoader:
processed_results = [] processed_results = []
for i, result in enumerate(results): for i, result in enumerate(results):
if isinstance(result, Exception): if isinstance(result, Exception):
log.error(f"File {i} failed: {result}") log.error(f'File {i} failed: {result}')
processed_results.append( processed_results.append(
[ [
Document( Document(
page_content=f"Error processing file: {result}", page_content=f'Error processing file: {result}',
metadata={ metadata={
"error": "batch_processing_failed", 'error': 'batch_processing_failed',
"file_index": i, 'file_index': i,
}, },
) )
] ]
@@ -755,15 +705,13 @@ class MistralLoader:
# MONITORING: Log comprehensive batch processing statistics # MONITORING: Log comprehensive batch processing statistics
total_time = time.time() - start_time total_time = time.time() - start_time
total_docs = sum(len(docs) for docs in processed_results) total_docs = sum(len(docs) for docs in processed_results)
success_count = sum( success_count = sum(1 for result in results if not isinstance(result, Exception))
1 for result in results if not isinstance(result, Exception)
)
failure_count = len(results) - success_count failure_count = len(results) - success_count
log.info( log.info(
f"Batch processing completed in {total_time:.2f}s: " f'Batch processing completed in {total_time:.2f}s: '
f"{success_count} files succeeded, {failure_count} files failed, " f'{success_count} files succeeded, {failure_count} files failed, '
f"produced {total_docs} total documents" f'produced {total_docs} total documents'
) )
return processed_results return processed_results

View File

@@ -25,7 +25,7 @@ class TavilyLoader(BaseLoader):
self, self,
urls: Union[str, List[str]], urls: Union[str, List[str]],
api_key: str, api_key: str,
extract_depth: Literal["basic", "advanced"] = "basic", extract_depth: Literal['basic', 'advanced'] = 'basic',
continue_on_failure: bool = True, continue_on_failure: bool = True,
) -> None: ) -> None:
"""Initialize Tavily Extract client. """Initialize Tavily Extract client.
@@ -42,13 +42,13 @@ class TavilyLoader(BaseLoader):
continue_on_failure: Whether to continue if extraction of a URL fails. continue_on_failure: Whether to continue if extraction of a URL fails.
""" """
if not urls: if not urls:
raise ValueError("At least one URL must be provided.") raise ValueError('At least one URL must be provided.')
self.api_key = api_key self.api_key = api_key
self.urls = urls if isinstance(urls, list) else [urls] self.urls = urls if isinstance(urls, list) else [urls]
self.extract_depth = extract_depth self.extract_depth = extract_depth
self.continue_on_failure = continue_on_failure self.continue_on_failure = continue_on_failure
self.api_url = "https://api.tavily.com/extract" self.api_url = 'https://api.tavily.com/extract'
def lazy_load(self) -> Iterator[Document]: def lazy_load(self) -> Iterator[Document]:
"""Extract and yield documents from the URLs using Tavily Extract API.""" """Extract and yield documents from the URLs using Tavily Extract API."""
@@ -57,35 +57,35 @@ class TavilyLoader(BaseLoader):
batch_urls = self.urls[i : i + batch_size] batch_urls = self.urls[i : i + batch_size]
try: try:
headers = { headers = {
"Content-Type": "application/json", 'Content-Type': 'application/json',
"Authorization": f"Bearer {self.api_key}", 'Authorization': f'Bearer {self.api_key}',
} }
# Use string for single URL, array for multiple URLs # Use string for single URL, array for multiple URLs
urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls
payload = {"urls": urls_param, "extract_depth": self.extract_depth} payload = {'urls': urls_param, 'extract_depth': self.extract_depth}
# Make the API call # Make the API call
response = requests.post(self.api_url, headers=headers, json=payload) response = requests.post(self.api_url, headers=headers, json=payload)
response.raise_for_status() response.raise_for_status()
response_data = response.json() response_data = response.json()
# Process successful results # Process successful results
for result in response_data.get("results", []): for result in response_data.get('results', []):
url = result.get("url", "") url = result.get('url', '')
content = result.get("raw_content", "") content = result.get('raw_content', '')
if not content: if not content:
log.warning(f"No content extracted from {url}") log.warning(f'No content extracted from {url}')
continue continue
# Add URLs as metadata # Add URLs as metadata
metadata = {"source": url} metadata = {'source': url}
yield Document( yield Document(
page_content=content, page_content=content,
metadata=metadata, metadata=metadata,
) )
for failed in response_data.get("failed_results", []): for failed in response_data.get('failed_results', []):
url = failed.get("url", "") url = failed.get('url', '')
error = failed.get("error", "Unknown error") error = failed.get('error', 'Unknown error')
log.error(f"Failed to extract content from {url}: {error}") log.error(f'Failed to extract content from {url}: {error}')
except Exception as e: except Exception as e:
if self.continue_on_failure: if self.continue_on_failure:
log.error(f"Error extracting content from batch {batch_urls}: {e}") log.error(f'Error extracting content from batch {batch_urls}: {e}')
else: else:
raise e raise e

View File

@@ -7,14 +7,14 @@ from langchain_core.documents import Document
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
ALLOWED_SCHEMES = {"http", "https"} ALLOWED_SCHEMES = {'http', 'https'}
ALLOWED_NETLOCS = { ALLOWED_NETLOCS = {
"youtu.be", 'youtu.be',
"m.youtube.com", 'm.youtube.com',
"youtube.com", 'youtube.com',
"www.youtube.com", 'www.youtube.com',
"www.youtube-nocookie.com", 'www.youtube-nocookie.com',
"vid.plus", 'vid.plus',
} }
@@ -30,17 +30,17 @@ def _parse_video_id(url: str) -> Optional[str]:
path = parsed_url.path path = parsed_url.path
if path.endswith("/watch"): if path.endswith('/watch'):
query = parsed_url.query query = parsed_url.query
parsed_query = parse_qs(query) parsed_query = parse_qs(query)
if "v" in parsed_query: if 'v' in parsed_query:
ids = parsed_query["v"] ids = parsed_query['v']
video_id = ids if isinstance(ids, str) else ids[0] video_id = ids if isinstance(ids, str) else ids[0]
else: else:
return None return None
else: else:
path = parsed_url.path.lstrip("/") path = parsed_url.path.lstrip('/')
video_id = path.split("/")[-1] video_id = path.split('/')[-1]
if len(video_id) != 11: # Video IDs are 11 characters long if len(video_id) != 11: # Video IDs are 11 characters long
return None return None
@@ -54,13 +54,13 @@ class YoutubeLoader:
def __init__( def __init__(
self, self,
video_id: str, video_id: str,
language: Union[str, Sequence[str]] = "en", language: Union[str, Sequence[str]] = 'en',
proxy_url: Optional[str] = None, proxy_url: Optional[str] = None,
): ):
"""Initialize with YouTube video ID.""" """Initialize with YouTube video ID."""
_video_id = _parse_video_id(video_id) _video_id = _parse_video_id(video_id)
self.video_id = _video_id if _video_id is not None else video_id self.video_id = _video_id if _video_id is not None else video_id
self._metadata = {"source": video_id} self._metadata = {'source': video_id}
self.proxy_url = proxy_url self.proxy_url = proxy_url
# Ensure language is a list # Ensure language is a list
@@ -70,8 +70,8 @@ class YoutubeLoader:
self.language = list(language) self.language = list(language)
# Add English as fallback if not already in the list # Add English as fallback if not already in the list
if "en" not in self.language: if 'en' not in self.language:
self.language.append("en") self.language.append('en')
def load(self) -> List[Document]: def load(self) -> List[Document]:
"""Load YouTube transcripts into `Document` objects.""" """Load YouTube transcripts into `Document` objects."""
@@ -85,14 +85,12 @@ class YoutubeLoader:
except ImportError: except ImportError:
raise ImportError( raise ImportError(
'Could not import "youtube_transcript_api" Python package. ' 'Could not import "youtube_transcript_api" Python package. '
"Please install it with `pip install youtube-transcript-api`." 'Please install it with `pip install youtube-transcript-api`.'
) )
if self.proxy_url: if self.proxy_url:
youtube_proxies = GenericProxyConfig( youtube_proxies = GenericProxyConfig(http_url=self.proxy_url, https_url=self.proxy_url)
http_url=self.proxy_url, https_url=self.proxy_url log.debug(f'Using proxy URL: {self.proxy_url[:14]}...')
)
log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
else: else:
youtube_proxies = None youtube_proxies = None
@@ -100,7 +98,7 @@ class YoutubeLoader:
try: try:
transcript_list = transcript_api.list(self.video_id) transcript_list = transcript_api.list(self.video_id)
except Exception as e: except Exception as e:
log.exception("Loading YouTube transcript failed") log.exception('Loading YouTube transcript failed')
return [] return []
# Try each language in order of priority # Try each language in order of priority
@@ -110,14 +108,10 @@ class YoutubeLoader:
if transcript.is_generated: if transcript.is_generated:
log.debug(f"Found generated transcript for language '{lang}'") log.debug(f"Found generated transcript for language '{lang}'")
try: try:
transcript = transcript_list.find_manually_created_transcript( transcript = transcript_list.find_manually_created_transcript([lang])
[lang]
)
log.debug(f"Found manual transcript for language '{lang}'") log.debug(f"Found manual transcript for language '{lang}'")
except NoTranscriptFound: except NoTranscriptFound:
log.debug( log.debug(f"No manual transcript found for language '{lang}', using generated")
f"No manual transcript found for language '{lang}', using generated"
)
pass pass
log.debug(f"Found transcript for language '{lang}'") log.debug(f"Found transcript for language '{lang}'")
@@ -131,12 +125,10 @@ class YoutubeLoader:
log.debug(f"Empty transcript for language '{lang}'") log.debug(f"Empty transcript for language '{lang}'")
continue continue
transcript_text = " ".join( transcript_text = ' '.join(
map( map(
lambda transcript_piece: ( lambda transcript_piece: (
transcript_piece.text.strip(" ") transcript_piece.text.strip(' ') if hasattr(transcript_piece, 'text') else ''
if hasattr(transcript_piece, "text")
else ""
), ),
transcript_pieces, transcript_pieces,
) )
@@ -150,9 +142,9 @@ class YoutubeLoader:
raise e raise e
# If we get here, all languages failed # If we get here, all languages failed
languages_tried = ", ".join(self.language) languages_tried = ', '.join(self.language)
log.warning( log.warning(
f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed." f'No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed.'
) )
raise NoTranscriptFound(self.video_id, self.language, list(transcript_list)) raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))

View File

@@ -13,19 +13,17 @@ log = logging.getLogger(__name__)
class ColBERT(BaseReranker): class ColBERT(BaseReranker):
def __init__(self, name, **kwargs) -> None: def __init__(self, name, **kwargs) -> None:
log.info("ColBERT: Loading model", name) log.info('ColBERT: Loading model', name)
self.device = "cuda" if torch.cuda.is_available() else "cpu" self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
DOCKER = kwargs.get("env") == "docker" DOCKER = kwargs.get('env') == 'docker'
if DOCKER: if DOCKER:
# This is a workaround for the issue with the docker container # This is a workaround for the issue with the docker container
# where the torch extension is not loaded properly # where the torch extension is not loaded properly
# and the following error is thrown: # and the following error is thrown:
# /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory # /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory
lock_file = ( lock_file = '/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock'
"/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock"
)
if os.path.exists(lock_file): if os.path.exists(lock_file):
os.remove(lock_file) os.remove(lock_file)
@@ -36,23 +34,16 @@ class ColBERT(BaseReranker):
pass pass
def calculate_similarity_scores(self, query_embeddings, document_embeddings): def calculate_similarity_scores(self, query_embeddings, document_embeddings):
query_embeddings = query_embeddings.to(self.device) query_embeddings = query_embeddings.to(self.device)
document_embeddings = document_embeddings.to(self.device) document_embeddings = document_embeddings.to(self.device)
# Validate dimensions to ensure compatibility # Validate dimensions to ensure compatibility
if query_embeddings.dim() != 3: if query_embeddings.dim() != 3:
raise ValueError( raise ValueError(f'Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}.')
f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
)
if document_embeddings.dim() != 3: if document_embeddings.dim() != 3:
raise ValueError( raise ValueError(f'Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}.')
f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
)
if query_embeddings.size(0) not in [1, document_embeddings.size(0)]: if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
raise ValueError( raise ValueError('There should be either one query or queries equal to the number of documents.')
"There should be either one query or queries equal to the number of documents."
)
# Transpose the query embeddings to align for matrix multiplication # Transpose the query embeddings to align for matrix multiplication
transposed_query_embeddings = query_embeddings.permute(0, 2, 1) transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
@@ -69,7 +60,6 @@ class ColBERT(BaseReranker):
return normalized_scores.detach().cpu().numpy().astype(np.float32) return normalized_scores.detach().cpu().numpy().astype(np.float32)
def predict(self, sentences): def predict(self, sentences):
query = sentences[0][0] query = sentences[0][0]
docs = [i[1] for i in sentences] docs = [i[1] for i in sentences]
@@ -80,8 +70,6 @@ class ColBERT(BaseReranker):
embedded_query = embedded_queries[0] embedded_query = embedded_queries[0]
# Calculate retrieval scores for the query against all documents # Calculate retrieval scores for the query against all documents
scores = self.calculate_similarity_scores( scores = self.calculate_similarity_scores(embedded_query.unsqueeze(0), embedded_docs)
embedded_query.unsqueeze(0), embedded_docs
)
return scores return scores

View File

@@ -15,8 +15,8 @@ class ExternalReranker(BaseReranker):
def __init__( def __init__(
self, self,
api_key: str, api_key: str,
url: str = "http://localhost:8080/v1/rerank", url: str = 'http://localhost:8080/v1/rerank',
model: str = "reranker", model: str = 'reranker',
timeout: Optional[int] = None, timeout: Optional[int] = None,
): ):
self.api_key = api_key self.api_key = api_key
@@ -24,33 +24,31 @@ class ExternalReranker(BaseReranker):
self.model = model self.model = model
self.timeout = timeout self.timeout = timeout
def predict( def predict(self, sentences: List[Tuple[str, str]], user=None) -> Optional[List[float]]:
self, sentences: List[Tuple[str, str]], user=None
) -> Optional[List[float]]:
query = sentences[0][0] query = sentences[0][0]
docs = [i[1] for i in sentences] docs = [i[1] for i in sentences]
payload = { payload = {
"model": self.model, 'model': self.model,
"query": query, 'query': query,
"documents": docs, 'documents': docs,
"top_n": len(docs), 'top_n': len(docs),
} }
try: try:
log.info(f"ExternalReranker:predict:model {self.model}") log.info(f'ExternalReranker:predict:model {self.model}')
log.info(f"ExternalReranker:predict:query {query}") log.info(f'ExternalReranker:predict:query {query}')
headers = { headers = {
"Content-Type": "application/json", 'Content-Type': 'application/json',
"Authorization": f"Bearer {self.api_key}", 'Authorization': f'Bearer {self.api_key}',
} }
if ENABLE_FORWARD_USER_INFO_HEADERS and user: if ENABLE_FORWARD_USER_INFO_HEADERS and user:
headers = include_user_info_headers(headers, user) headers = include_user_info_headers(headers, user)
r = requests.post( r = requests.post(
f"{self.url}", f'{self.url}',
headers=headers, headers=headers,
json=payload, json=payload,
timeout=self.timeout, timeout=self.timeout,
@@ -60,13 +58,13 @@ class ExternalReranker(BaseReranker):
r.raise_for_status() r.raise_for_status()
data = r.json() data = r.json()
if "results" in data: if 'results' in data:
sorted_results = sorted(data["results"], key=lambda x: x["index"]) sorted_results = sorted(data['results'], key=lambda x: x['index'])
return [result["relevance_score"] for result in sorted_results] return [result['relevance_score'] for result in sorted_results]
else: else:
log.error("No results found in external reranking response") log.error('No results found in external reranking response')
return None return None
except Exception as e: except Exception as e:
log.exception(f"Error in external reranking: {e}") log.exception(f'Error in external reranking: {e}')
return None return None

File diff suppressed because it is too large Load Diff

View File

@@ -31,17 +31,15 @@ log = logging.getLogger(__name__)
class ChromaClient(VectorDBBase): class ChromaClient(VectorDBBase):
def __init__(self): def __init__(self):
settings_dict = { settings_dict = {
"allow_reset": True, 'allow_reset': True,
"anonymized_telemetry": False, 'anonymized_telemetry': False,
} }
if CHROMA_CLIENT_AUTH_PROVIDER is not None: if CHROMA_CLIENT_AUTH_PROVIDER is not None:
settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER settings_dict['chroma_client_auth_provider'] = CHROMA_CLIENT_AUTH_PROVIDER
if CHROMA_CLIENT_AUTH_CREDENTIALS is not None: if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
settings_dict["chroma_client_auth_credentials"] = ( settings_dict['chroma_client_auth_credentials'] = CHROMA_CLIENT_AUTH_CREDENTIALS
CHROMA_CLIENT_AUTH_CREDENTIALS
)
if CHROMA_HTTP_HOST != "": if CHROMA_HTTP_HOST != '':
self.client = chromadb.HttpClient( self.client = chromadb.HttpClient(
host=CHROMA_HTTP_HOST, host=CHROMA_HTTP_HOST,
port=CHROMA_HTTP_PORT, port=CHROMA_HTTP_PORT,
@@ -87,25 +85,23 @@ class ChromaClient(VectorDBBase):
# chromadb has cosine distance, 2 (worst) -> 0 (best). Re-odering to 0 -> 1 # chromadb has cosine distance, 2 (worst) -> 0 (best). Re-odering to 0 -> 1
# https://docs.trychroma.com/docs/collections/configure cosine equation # https://docs.trychroma.com/docs/collections/configure cosine equation
distances: list = result["distances"][0] distances: list = result['distances'][0]
distances = [2 - dist for dist in distances] distances = [2 - dist for dist in distances]
distances = [[dist / 2 for dist in distances]] distances = [[dist / 2 for dist in distances]]
return SearchResult( return SearchResult(
**{ **{
"ids": result["ids"], 'ids': result['ids'],
"distances": distances, 'distances': distances,
"documents": result["documents"], 'documents': result['documents'],
"metadatas": result["metadatas"], 'metadatas': result['metadatas'],
} }
) )
return None return None
except Exception as e: except Exception as e:
return None return None
def query( def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) -> Optional[GetResult]:
self, collection_name: str, filter: dict, limit: Optional[int] = None
) -> Optional[GetResult]:
# Query the items from the collection based on the filter. # Query the items from the collection based on the filter.
try: try:
collection = self.client.get_collection(name=collection_name) collection = self.client.get_collection(name=collection_name)
@@ -117,9 +113,9 @@ class ChromaClient(VectorDBBase):
return GetResult( return GetResult(
**{ **{
"ids": [result["ids"]], 'ids': [result['ids']],
"documents": [result["documents"]], 'documents': [result['documents']],
"metadatas": [result["metadatas"]], 'metadatas': [result['metadatas']],
} }
) )
return None return None
@@ -133,23 +129,21 @@ class ChromaClient(VectorDBBase):
result = collection.get() result = collection.get()
return GetResult( return GetResult(
**{ **{
"ids": [result["ids"]], 'ids': [result['ids']],
"documents": [result["documents"]], 'documents': [result['documents']],
"metadatas": [result["metadatas"]], 'metadatas': [result['metadatas']],
} }
) )
return None return None
def insert(self, collection_name: str, items: list[VectorItem]): def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created. # Insert the items into the collection, if the collection does not exist, it will be created.
collection = self.client.get_or_create_collection( collection = self.client.get_or_create_collection(name=collection_name, metadata={'hnsw:space': 'cosine'})
name=collection_name, metadata={"hnsw:space": "cosine"}
)
ids = [item["id"] for item in items] ids = [item['id'] for item in items]
documents = [item["text"] for item in items] documents = [item['text'] for item in items]
embeddings = [item["vector"] for item in items] embeddings = [item['vector'] for item in items]
metadatas = [process_metadata(item["metadata"]) for item in items] metadatas = [process_metadata(item['metadata']) for item in items]
for batch in create_batches( for batch in create_batches(
api=self.client, api=self.client,
@@ -162,18 +156,14 @@ class ChromaClient(VectorDBBase):
def upsert(self, collection_name: str, items: list[VectorItem]): def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created. # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
collection = self.client.get_or_create_collection( collection = self.client.get_or_create_collection(name=collection_name, metadata={'hnsw:space': 'cosine'})
name=collection_name, metadata={"hnsw:space": "cosine"}
)
ids = [item["id"] for item in items] ids = [item['id'] for item in items]
documents = [item["text"] for item in items] documents = [item['text'] for item in items]
embeddings = [item["vector"] for item in items] embeddings = [item['vector'] for item in items]
metadatas = [process_metadata(item["metadata"]) for item in items] metadatas = [process_metadata(item['metadata']) for item in items]
collection.upsert( collection.upsert(ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas)
ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
)
def delete( def delete(
self, self,
@@ -191,9 +181,7 @@ class ChromaClient(VectorDBBase):
collection.delete(where=filter) collection.delete(where=filter)
except Exception as e: except Exception as e:
# If collection doesn't exist, that's fine - nothing to delete # If collection doesn't exist, that's fine - nothing to delete
log.debug( log.debug(f'Attempted to delete from non-existent collection {collection_name}. Ignoring.')
f"Attempted to delete from non-existent collection {collection_name}. Ignoring."
)
pass pass
def reset(self): def reset(self):

View File

@@ -51,7 +51,7 @@ class ElasticsearchClient(VectorDBBase):
# Status: works # Status: works
def _get_index_name(self, dimension: int) -> str: def _get_index_name(self, dimension: int) -> str:
return f"{self.index_prefix}_d{str(dimension)}" return f'{self.index_prefix}_d{str(dimension)}'
# Status: works # Status: works
def _scan_result_to_get_result(self, result) -> GetResult: def _scan_result_to_get_result(self, result) -> GetResult:
@@ -62,24 +62,24 @@ class ElasticsearchClient(VectorDBBase):
metadatas = [] metadatas = []
for hit in result: for hit in result:
ids.append(hit["_id"]) ids.append(hit['_id'])
documents.append(hit["_source"].get("text")) documents.append(hit['_source'].get('text'))
metadatas.append(hit["_source"].get("metadata")) metadatas.append(hit['_source'].get('metadata'))
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas]) return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
# Status: works # Status: works
def _result_to_get_result(self, result) -> GetResult: def _result_to_get_result(self, result) -> GetResult:
if not result["hits"]["hits"]: if not result['hits']['hits']:
return None return None
ids = [] ids = []
documents = [] documents = []
metadatas = [] metadatas = []
for hit in result["hits"]["hits"]: for hit in result['hits']['hits']:
ids.append(hit["_id"]) ids.append(hit['_id'])
documents.append(hit["_source"].get("text")) documents.append(hit['_source'].get('text'))
metadatas.append(hit["_source"].get("metadata")) metadatas.append(hit['_source'].get('metadata'))
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas]) return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
@@ -90,11 +90,11 @@ class ElasticsearchClient(VectorDBBase):
documents = [] documents = []
metadatas = [] metadatas = []
for hit in result["hits"]["hits"]: for hit in result['hits']['hits']:
ids.append(hit["_id"]) ids.append(hit['_id'])
distances.append(hit["_score"]) distances.append(hit['_score'])
documents.append(hit["_source"].get("text")) documents.append(hit['_source'].get('text'))
metadatas.append(hit["_source"].get("metadata")) metadatas.append(hit['_source'].get('metadata'))
return SearchResult( return SearchResult(
ids=[ids], ids=[ids],
@@ -106,26 +106,26 @@ class ElasticsearchClient(VectorDBBase):
# Status: works # Status: works
def _create_index(self, dimension: int): def _create_index(self, dimension: int):
body = { body = {
"mappings": { 'mappings': {
"dynamic_templates": [ 'dynamic_templates': [
{ {
"strings": { 'strings': {
"match_mapping_type": "string", 'match_mapping_type': 'string',
"mapping": {"type": "keyword"}, 'mapping': {'type': 'keyword'},
} }
} }
], ],
"properties": { 'properties': {
"collection": {"type": "keyword"}, 'collection': {'type': 'keyword'},
"id": {"type": "keyword"}, 'id': {'type': 'keyword'},
"vector": { 'vector': {
"type": "dense_vector", 'type': 'dense_vector',
"dims": dimension, # Adjust based on your vector dimensions 'dims': dimension, # Adjust based on your vector dimensions
"index": True, 'index': True,
"similarity": "cosine", 'similarity': 'cosine',
}, },
"text": {"type": "text"}, 'text': {'type': 'text'},
"metadata": {"type": "object"}, 'metadata': {'type': 'object'},
}, },
} }
} }
@@ -139,21 +139,19 @@ class ElasticsearchClient(VectorDBBase):
# Status: works # Status: works
def has_collection(self, collection_name) -> bool: def has_collection(self, collection_name) -> bool:
query_body = {"query": {"bool": {"filter": []}}} query_body = {'query': {'bool': {'filter': []}}}
query_body["query"]["bool"]["filter"].append( query_body['query']['bool']['filter'].append({'term': {'collection': collection_name}})
{"term": {"collection": collection_name}}
)
try: try:
result = self.client.count(index=f"{self.index_prefix}*", body=query_body) result = self.client.count(index=f'{self.index_prefix}*', body=query_body)
return result.body["count"] > 0 return result.body['count'] > 0
except Exception as e: except Exception as e:
return None return None
def delete_collection(self, collection_name: str): def delete_collection(self, collection_name: str):
query = {"query": {"term": {"collection": collection_name}}} query = {'query': {'term': {'collection': collection_name}}}
self.client.delete_by_query(index=f"{self.index_prefix}*", body=query) self.client.delete_by_query(index=f'{self.index_prefix}*', body=query)
# Status: works # Status: works
def search( def search(
@@ -164,51 +162,41 @@ class ElasticsearchClient(VectorDBBase):
limit: int = 10, limit: int = 10,
) -> Optional[SearchResult]: ) -> Optional[SearchResult]:
query = { query = {
"size": limit, 'size': limit,
"_source": ["text", "metadata"], '_source': ['text', 'metadata'],
"query": { 'query': {
"script_score": { 'script_score': {
"query": { 'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}},
"bool": {"filter": [{"term": {"collection": collection_name}}]} 'script': {
}, 'source': "cosineSimilarity(params.vector, 'vector') + 1.0",
"script": { 'params': {'vector': vectors[0]}, # Assuming single query vector
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
"params": {
"vector": vectors[0]
}, # Assuming single query vector
}, },
} }
}, },
} }
result = self.client.search( result = self.client.search(index=self._get_index_name(len(vectors[0])), body=query)
index=self._get_index_name(len(vectors[0])), body=query
)
return self._result_to_search_result(result) return self._result_to_search_result(result)
# Status: only tested halfwat # Status: only tested halfwat
def query( def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) -> Optional[GetResult]:
self, collection_name: str, filter: dict, limit: Optional[int] = None
) -> Optional[GetResult]:
if not self.has_collection(collection_name): if not self.has_collection(collection_name):
return None return None
query_body = { query_body = {
"query": {"bool": {"filter": []}}, 'query': {'bool': {'filter': []}},
"_source": ["text", "metadata"], '_source': ['text', 'metadata'],
} }
for field, value in filter.items(): for field, value in filter.items():
query_body["query"]["bool"]["filter"].append({"term": {field: value}}) query_body['query']['bool']['filter'].append({'term': {field: value}})
query_body["query"]["bool"]["filter"].append( query_body['query']['bool']['filter'].append({'term': {'collection': collection_name}})
{"term": {"collection": collection_name}}
)
size = limit if limit else 10 size = limit if limit else 10
try: try:
result = self.client.search( result = self.client.search(
index=f"{self.index_prefix}*", index=f'{self.index_prefix}*',
body=query_body, body=query_body,
size=size, size=size,
) )
@@ -220,9 +208,7 @@ class ElasticsearchClient(VectorDBBase):
# Status: works # Status: works
def _has_index(self, dimension: int): def _has_index(self, dimension: int):
return self.client.indices.exists( return self.client.indices.exists(index=self._get_index_name(dimension=dimension))
index=self._get_index_name(dimension=dimension)
)
def get_or_create_index(self, dimension: int): def get_or_create_index(self, dimension: int):
if not self._has_index(dimension=dimension): if not self._has_index(dimension=dimension):
@@ -232,28 +218,28 @@ class ElasticsearchClient(VectorDBBase):
def get(self, collection_name: str) -> Optional[GetResult]: def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection. # Get all the items in the collection.
query = { query = {
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}, 'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}},
"_source": ["text", "metadata"], '_source': ['text', 'metadata'],
} }
results = list(scan(self.client, index=f"{self.index_prefix}*", query=query)) results = list(scan(self.client, index=f'{self.index_prefix}*', query=query))
return self._scan_result_to_get_result(results) return self._scan_result_to_get_result(results)
# Status: works # Status: works
def insert(self, collection_name: str, items: list[VectorItem]): def insert(self, collection_name: str, items: list[VectorItem]):
if not self._has_index(dimension=len(items[0]["vector"])): if not self._has_index(dimension=len(items[0]['vector'])):
self._create_index(dimension=len(items[0]["vector"])) self._create_index(dimension=len(items[0]['vector']))
for batch in self._create_batches(items): for batch in self._create_batches(items):
actions = [ actions = [
{ {
"_index": self._get_index_name(dimension=len(items[0]["vector"])), '_index': self._get_index_name(dimension=len(items[0]['vector'])),
"_id": item["id"], '_id': item['id'],
"_source": { '_source': {
"collection": collection_name, 'collection': collection_name,
"vector": item["vector"], 'vector': item['vector'],
"text": item["text"], 'text': item['text'],
"metadata": process_metadata(item["metadata"]), 'metadata': process_metadata(item['metadata']),
}, },
} }
for item in batch for item in batch
@@ -262,21 +248,21 @@ class ElasticsearchClient(VectorDBBase):
# Upsert documents using the update API with doc_as_upsert=True. # Upsert documents using the update API with doc_as_upsert=True.
def upsert(self, collection_name: str, items: list[VectorItem]): def upsert(self, collection_name: str, items: list[VectorItem]):
if not self._has_index(dimension=len(items[0]["vector"])): if not self._has_index(dimension=len(items[0]['vector'])):
self._create_index(dimension=len(items[0]["vector"])) self._create_index(dimension=len(items[0]['vector']))
for batch in self._create_batches(items): for batch in self._create_batches(items):
actions = [ actions = [
{ {
"_op_type": "update", '_op_type': 'update',
"_index": self._get_index_name(dimension=len(item["vector"])), '_index': self._get_index_name(dimension=len(item['vector'])),
"_id": item["id"], '_id': item['id'],
"doc": { 'doc': {
"collection": collection_name, 'collection': collection_name,
"vector": item["vector"], 'vector': item['vector'],
"text": item["text"], 'text': item['text'],
"metadata": process_metadata(item["metadata"]), 'metadata': process_metadata(item['metadata']),
}, },
"doc_as_upsert": True, 'doc_as_upsert': True,
} }
for item in batch for item in batch
] ]
@@ -289,22 +275,17 @@ class ElasticsearchClient(VectorDBBase):
ids: Optional[list[str]] = None, ids: Optional[list[str]] = None,
filter: Optional[dict] = None, filter: Optional[dict] = None,
): ):
query = {'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}}}
query = {
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}
}
# logic based on chromaDB # logic based on chromaDB
if ids: if ids:
query["query"]["bool"]["filter"].append({"terms": {"_id": ids}}) query['query']['bool']['filter'].append({'terms': {'_id': ids}})
elif filter: elif filter:
for field, value in filter.items(): for field, value in filter.items():
query["query"]["bool"]["filter"].append( query['query']['bool']['filter'].append({'term': {f'metadata.{field}': value}})
{"term": {f"metadata.{field}": value}}
)
self.client.delete_by_query(index=f"{self.index_prefix}*", body=query) self.client.delete_by_query(index=f'{self.index_prefix}*', body=query)
def reset(self): def reset(self):
indices = self.client.indices.get(index=f"{self.index_prefix}*") indices = self.client.indices.get(index=f'{self.index_prefix}*')
for index in indices: for index in indices:
self.client.indices.delete(index=index) self.client.indices.delete(index=index)

View File

@@ -47,8 +47,8 @@ def _embedding_to_f32_bytes(vec: List[float]) -> bytes:
byte sequence. We use array('f') to avoid a numpy dependency and byteswap on byte sequence. We use array('f') to avoid a numpy dependency and byteswap on
big-endian platforms for portability. big-endian platforms for portability.
""" """
a = array.array("f", [float(x) for x in vec]) # float32 a = array.array('f', [float(x) for x in vec]) # float32
if sys.byteorder != "little": if sys.byteorder != 'little':
a.byteswap() a.byteswap()
return a.tobytes() return a.tobytes()
@@ -68,7 +68,7 @@ def _safe_json(v: Any) -> Dict[str, Any]:
return v return v
if isinstance(v, (bytes, bytearray)): if isinstance(v, (bytes, bytearray)):
try: try:
v = v.decode("utf-8") v = v.decode('utf-8')
except Exception: except Exception:
return {} return {}
if isinstance(v, str): if isinstance(v, str):
@@ -105,16 +105,16 @@ class MariaDBVectorClient(VectorDBBase):
""" """
self.db_url = (db_url or MARIADB_VECTOR_DB_URL).strip() self.db_url = (db_url or MARIADB_VECTOR_DB_URL).strip()
self.vector_length = int(vector_length) self.vector_length = int(vector_length)
self.distance_strategy = (distance_strategy or "cosine").strip().lower() self.distance_strategy = (distance_strategy or 'cosine').strip().lower()
self.index_m = int(index_m) self.index_m = int(index_m)
if self.distance_strategy not in {"cosine", "euclidean"}: if self.distance_strategy not in {'cosine', 'euclidean'}:
raise ValueError("distance_strategy must be 'cosine' or 'euclidean'") raise ValueError("distance_strategy must be 'cosine' or 'euclidean'")
if not self.db_url.lower().startswith("mariadb+mariadbconnector://"): if not self.db_url.lower().startswith('mariadb+mariadbconnector://'):
raise ValueError( raise ValueError(
"MariaDBVectorClient requires mariadb+mariadbconnector:// (official MariaDB driver) " 'MariaDBVectorClient requires mariadb+mariadbconnector:// (official MariaDB driver) '
"to ensure qmark paramstyle and correct VECTOR binding." 'to ensure qmark paramstyle and correct VECTOR binding.'
) )
if isinstance(MARIADB_VECTOR_POOL_SIZE, int): if isinstance(MARIADB_VECTOR_POOL_SIZE, int):
@@ -129,9 +129,7 @@ class MariaDBVectorClient(VectorDBBase):
poolclass=QueuePool, poolclass=QueuePool,
) )
else: else:
self.engine = create_engine( self.engine = create_engine(self.db_url, pool_pre_ping=True, poolclass=NullPool)
self.db_url, pool_pre_ping=True, poolclass=NullPool
)
else: else:
self.engine = create_engine(self.db_url, pool_pre_ping=True) self.engine = create_engine(self.db_url, pool_pre_ping=True)
self._init_schema() self._init_schema()
@@ -185,7 +183,7 @@ class MariaDBVectorClient(VectorDBBase):
conn.commit() conn.commit()
except Exception as e: except Exception as e:
conn.rollback() conn.rollback()
log.exception(f"Error during database initialization: {e}") log.exception(f'Error during database initialization: {e}')
raise raise
def _check_vector_length(self) -> None: def _check_vector_length(self) -> None:
@@ -197,19 +195,19 @@ class MariaDBVectorClient(VectorDBBase):
""" """
with self._connect() as conn: with self._connect() as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute("SHOW CREATE TABLE document_chunk") cur.execute('SHOW CREATE TABLE document_chunk')
row = cur.fetchone() row = cur.fetchone()
if not row or len(row) < 2: if not row or len(row) < 2:
return return
ddl = row[1] ddl = row[1]
m = re.search(r"vector\\((\\d+)\\)", ddl, flags=re.IGNORECASE) m = re.search(r'vector\\((\\d+)\\)', ddl, flags=re.IGNORECASE)
if not m: if not m:
return return
existing = int(m.group(1)) existing = int(m.group(1))
if existing != int(self.vector_length): if existing != int(self.vector_length):
raise Exception( raise Exception(
f"VECTOR_LENGTH {self.vector_length} does not match existing vector column dimension {existing}. " f'VECTOR_LENGTH {self.vector_length} does not match existing vector column dimension {existing}. '
"Cannot change vector size after initialization without migrating the data." 'Cannot change vector size after initialization without migrating the data.'
) )
def adjust_vector_length(self, vector: List[float]) -> List[float]: def adjust_vector_length(self, vector: List[float]) -> List[float]:
@@ -227,11 +225,7 @@ class MariaDBVectorClient(VectorDBBase):
""" """
Return the MariaDB Vector distance function name for the configured strategy. Return the MariaDB Vector distance function name for the configured strategy.
""" """
return ( return 'vec_distance_cosine' if self.distance_strategy == 'cosine' else 'vec_distance_euclidean'
"vec_distance_cosine"
if self.distance_strategy == "cosine"
else "vec_distance_euclidean"
)
def _score_from_dist(self, dist: float) -> float: def _score_from_dist(self, dist: float) -> float:
""" """
@@ -240,7 +234,7 @@ class MariaDBVectorClient(VectorDBBase):
- cosine: score ~= 1 - cosine_distance, clamped to [0, 1] - cosine: score ~= 1 - cosine_distance, clamped to [0, 1]
- euclidean: score = 1 / (1 + dist) - euclidean: score = 1 / (1 + dist)
""" """
if self.distance_strategy == "cosine": if self.distance_strategy == 'cosine':
score = 1.0 - dist score = 1.0 - dist
if score < 0.0: if score < 0.0:
score = 0.0 score = 0.0
@@ -260,48 +254,48 @@ class MariaDBVectorClient(VectorDBBase):
- {"$or": [ ... ]} - {"$or": [ ... ]}
""" """
if not expr or not isinstance(expr, dict): if not expr or not isinstance(expr, dict):
return "", [] return '', []
if "$and" in expr: if '$and' in expr:
parts: List[str] = [] parts: List[str] = []
params: List[Any] = [] params: List[Any] = []
for e in expr.get("$and") or []: for e in expr.get('$and') or []:
s, p = self._build_filter_sql_qmark(e) s, p = self._build_filter_sql_qmark(e)
if s: if s:
parts.append(s) parts.append(s)
params.extend(p) params.extend(p)
return ("(" + " AND ".join(parts) + ")") if parts else "", params return ('(' + ' AND '.join(parts) + ')') if parts else '', params
if "$or" in expr: if '$or' in expr:
parts: List[str] = [] parts: List[str] = []
params: List[Any] = [] params: List[Any] = []
for e in expr.get("$or") or []: for e in expr.get('$or') or []:
s, p = self._build_filter_sql_qmark(e) s, p = self._build_filter_sql_qmark(e)
if s: if s:
parts.append(s) parts.append(s)
params.extend(p) params.extend(p)
return ("(" + " OR ".join(parts) + ")") if parts else "", params return ('(' + ' OR '.join(parts) + ')') if parts else '', params
clauses: List[str] = [] clauses: List[str] = []
params: List[Any] = [] params: List[Any] = []
for key, value in expr.items(): for key, value in expr.items():
if key.startswith("$"): if key.startswith('$'):
continue continue
json_expr = f"JSON_UNQUOTE(JSON_EXTRACT(vmetadata, '$.{key}'))" json_expr = f"JSON_UNQUOTE(JSON_EXTRACT(vmetadata, '$.{key}'))"
if isinstance(value, dict) and "$in" in value: if isinstance(value, dict) and '$in' in value:
vals = [str(v) for v in (value.get("$in") or [])] vals = [str(v) for v in (value.get('$in') or [])]
if not vals: if not vals:
clauses.append("0=1") clauses.append('0=1')
continue continue
ors = [] ors = []
for v in vals: for v in vals:
ors.append(f"{json_expr} = ?") ors.append(f'{json_expr} = ?')
params.append(v) params.append(v)
clauses.append("(" + " OR ".join(ors) + ")") clauses.append('(' + ' OR '.join(ors) + ')')
else: else:
clauses.append(f"{json_expr} = ?") clauses.append(f'{json_expr} = ?')
params.append(str(value)) params.append(str(value))
return ("(" + " AND ".join(clauses) + ")") if clauses else "", params return ('(' + ' AND '.join(clauses) + ')') if clauses else '', params
def insert(self, collection_name: str, items: List[VectorItem]) -> None: def insert(self, collection_name: str, items: List[VectorItem]) -> None:
""" """
@@ -322,15 +316,15 @@ class MariaDBVectorClient(VectorDBBase):
""" """
params: List[Tuple[Any, ...]] = [] params: List[Tuple[Any, ...]] = []
for item in items: for item in items:
v = self.adjust_vector_length(item["vector"]) v = self.adjust_vector_length(item['vector'])
emb = _embedding_to_f32_bytes(v) emb = _embedding_to_f32_bytes(v)
meta = process_metadata(item.get("metadata") or {}) meta = process_metadata(item.get('metadata') or {})
params.append( params.append(
( (
item["id"], item['id'],
emb, emb,
collection_name, collection_name,
item.get("text"), item.get('text'),
json.dumps(meta), json.dumps(meta),
) )
) )
@@ -338,7 +332,7 @@ class MariaDBVectorClient(VectorDBBase):
conn.commit() conn.commit()
except Exception as e: except Exception as e:
conn.rollback() conn.rollback()
log.exception(f"Error during insert: {e}") log.exception(f'Error during insert: {e}')
raise raise
def upsert(self, collection_name: str, items: List[VectorItem]) -> None: def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
@@ -365,15 +359,15 @@ class MariaDBVectorClient(VectorDBBase):
""" """
params: List[Tuple[Any, ...]] = [] params: List[Tuple[Any, ...]] = []
for item in items: for item in items:
v = self.adjust_vector_length(item["vector"]) v = self.adjust_vector_length(item['vector'])
emb = _embedding_to_f32_bytes(v) emb = _embedding_to_f32_bytes(v)
meta = process_metadata(item.get("metadata") or {}) meta = process_metadata(item.get('metadata') or {})
params.append( params.append(
( (
item["id"], item['id'],
emb, emb,
collection_name, collection_name,
item.get("text"), item.get('text'),
json.dumps(meta), json.dumps(meta),
) )
) )
@@ -381,7 +375,7 @@ class MariaDBVectorClient(VectorDBBase):
conn.commit() conn.commit()
except Exception as e: except Exception as e:
conn.rollback() conn.rollback()
log.exception(f"Error during upsert: {e}") log.exception(f'Error during upsert: {e}')
raise raise
def search( def search(
@@ -415,10 +409,10 @@ class MariaDBVectorClient(VectorDBBase):
with self._connect() as conn: with self._connect() as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
fsql, fparams = self._build_filter_sql_qmark(filter or {}) fsql, fparams = self._build_filter_sql_qmark(filter or {})
where = "collection_name = ?" where = 'collection_name = ?'
base_params: List[Any] = [collection_name] base_params: List[Any] = [collection_name]
if fsql: if fsql:
where = where + " AND " + fsql where = where + ' AND ' + fsql
base_params.extend(fparams) base_params.extend(fparams)
sql = f""" sql = f"""
@@ -460,26 +454,24 @@ class MariaDBVectorClient(VectorDBBase):
metadatas=metadatas, metadatas=metadatas,
) )
except Exception as e: except Exception as e:
log.exception(f"[MARIADB_VECTOR] search() failed: {e}") log.exception(f'[MARIADB_VECTOR] search() failed: {e}')
return None return None
def query( def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
) -> Optional[GetResult]:
""" """
Retrieve documents by metadata filter (non-vector query). Retrieve documents by metadata filter (non-vector query).
""" """
with self._connect() as conn: with self._connect() as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
fsql, fparams = self._build_filter_sql_qmark(filter or {}) fsql, fparams = self._build_filter_sql_qmark(filter or {})
where = "collection_name = ?" where = 'collection_name = ?'
params: List[Any] = [collection_name] params: List[Any] = [collection_name]
if fsql: if fsql:
where = where + " AND " + fsql where = where + ' AND ' + fsql
params.extend(fparams) params.extend(fparams)
sql = f"SELECT id, text, vmetadata FROM document_chunk WHERE {where}" sql = f'SELECT id, text, vmetadata FROM document_chunk WHERE {where}'
if limit is not None: if limit is not None:
sql += " LIMIT ?" sql += ' LIMIT ?'
params.append(int(limit)) params.append(int(limit))
cur.execute(sql, params) cur.execute(sql, params)
rows = cur.fetchall() rows = cur.fetchall()
@@ -490,18 +482,16 @@ class MariaDBVectorClient(VectorDBBase):
metadatas = [[_safe_json(r[2]) for r in rows]] metadatas = [[_safe_json(r[2]) for r in rows]]
return GetResult(ids=ids, documents=documents, metadatas=metadatas) return GetResult(ids=ids, documents=documents, metadatas=metadatas)
def get( def get(self, collection_name: str, limit: Optional[int] = None) -> Optional[GetResult]:
self, collection_name: str, limit: Optional[int] = None
) -> Optional[GetResult]:
""" """
Retrieve documents in a collection without filtering (optionally limited). Retrieve documents in a collection without filtering (optionally limited).
""" """
with self._connect() as conn: with self._connect() as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
sql = "SELECT id, text, vmetadata FROM document_chunk WHERE collection_name = ?" sql = 'SELECT id, text, vmetadata FROM document_chunk WHERE collection_name = ?'
params: List[Any] = [collection_name] params: List[Any] = [collection_name]
if limit is not None: if limit is not None:
sql += " LIMIT ?" sql += ' LIMIT ?'
params.append(int(limit)) params.append(int(limit))
cur.execute(sql, params) cur.execute(sql, params)
rows = cur.fetchall() rows = cur.fetchall()
@@ -526,12 +516,12 @@ class MariaDBVectorClient(VectorDBBase):
with self._connect() as conn: with self._connect() as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
try: try:
where = ["collection_name = ?"] where = ['collection_name = ?']
params: List[Any] = [collection_name] params: List[Any] = [collection_name]
if ids: if ids:
ph = ", ".join(["?"] * len(ids)) ph = ', '.join(['?'] * len(ids))
where.append(f"id IN ({ph})") where.append(f'id IN ({ph})')
params.extend(ids) params.extend(ids)
if filter: if filter:
@@ -540,12 +530,12 @@ class MariaDBVectorClient(VectorDBBase):
where.append(fsql) where.append(fsql)
params.extend(fparams) params.extend(fparams)
sql = "DELETE FROM document_chunk WHERE " + " AND ".join(where) sql = 'DELETE FROM document_chunk WHERE ' + ' AND '.join(where)
cur.execute(sql, params) cur.execute(sql, params)
conn.commit() conn.commit()
except Exception as e: except Exception as e:
conn.rollback() conn.rollback()
log.exception(f"Error during delete: {e}") log.exception(f'Error during delete: {e}')
raise raise
def reset(self) -> None: def reset(self) -> None:
@@ -555,11 +545,11 @@ class MariaDBVectorClient(VectorDBBase):
with self._connect() as conn: with self._connect() as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
try: try:
cur.execute("TRUNCATE TABLE document_chunk") cur.execute('TRUNCATE TABLE document_chunk')
conn.commit() conn.commit()
except Exception as e: except Exception as e:
conn.rollback() conn.rollback()
log.exception(f"Error during reset: {e}") log.exception(f'Error during reset: {e}')
raise raise
def has_collection(self, collection_name: str) -> bool: def has_collection(self, collection_name: str) -> bool:
@@ -570,7 +560,7 @@ class MariaDBVectorClient(VectorDBBase):
with self._connect() as conn: with self._connect() as conn:
with conn.cursor() as cur: with conn.cursor() as cur:
cur.execute( cur.execute(
"SELECT 1 FROM document_chunk WHERE collection_name = ? LIMIT 1", 'SELECT 1 FROM document_chunk WHERE collection_name = ? LIMIT 1',
(collection_name,), (collection_name,),
) )
return cur.fetchone() is not None return cur.fetchone() is not None
@@ -590,4 +580,4 @@ class MariaDBVectorClient(VectorDBBase):
try: try:
self.engine.dispose() self.engine.dispose()
except Exception as e: except Exception as e:
log.exception(f"Error during dispose the underlying SQLAlchemy engine: {e}") log.exception(f'Error during dispose the underlying SQLAlchemy engine: {e}')

View File

@@ -35,7 +35,7 @@ log = logging.getLogger(__name__)
class MilvusClient(VectorDBBase): class MilvusClient(VectorDBBase):
def __init__(self): def __init__(self):
self.collection_prefix = "open_webui" self.collection_prefix = 'open_webui'
if MILVUS_TOKEN is None: if MILVUS_TOKEN is None:
self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB) self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
else: else:
@@ -50,17 +50,17 @@ class MilvusClient(VectorDBBase):
_documents = [] _documents = []
_metadatas = [] _metadatas = []
for item in match: for item in match:
_ids.append(item.get("id")) _ids.append(item.get('id'))
_documents.append(item.get("data", {}).get("text")) _documents.append(item.get('data', {}).get('text'))
_metadatas.append(item.get("metadata")) _metadatas.append(item.get('metadata'))
ids.append(_ids) ids.append(_ids)
documents.append(_documents) documents.append(_documents)
metadatas.append(_metadatas) metadatas.append(_metadatas)
return GetResult( return GetResult(
**{ **{
"ids": ids, 'ids': ids,
"documents": documents, 'documents': documents,
"metadatas": metadatas, 'metadatas': metadatas,
} }
) )
@@ -75,23 +75,23 @@ class MilvusClient(VectorDBBase):
_documents = [] _documents = []
_metadatas = [] _metadatas = []
for item in match: for item in match:
_ids.append(item.get("id")) _ids.append(item.get('id'))
# normalize milvus score from [-1, 1] to [0, 1] range # normalize milvus score from [-1, 1] to [0, 1] range
# https://milvus.io/docs/de/metric.md # https://milvus.io/docs/de/metric.md
_dist = (item.get("distance") + 1.0) / 2.0 _dist = (item.get('distance') + 1.0) / 2.0
_distances.append(_dist) _distances.append(_dist)
_documents.append(item.get("entity", {}).get("data", {}).get("text")) _documents.append(item.get('entity', {}).get('data', {}).get('text'))
_metadatas.append(item.get("entity", {}).get("metadata")) _metadatas.append(item.get('entity', {}).get('metadata'))
ids.append(_ids) ids.append(_ids)
distances.append(_distances) distances.append(_distances)
documents.append(_documents) documents.append(_documents)
metadatas.append(_metadatas) metadatas.append(_metadatas)
return SearchResult( return SearchResult(
**{ **{
"ids": ids, 'ids': ids,
"distances": distances, 'distances': distances,
"documents": documents, 'documents': documents,
"metadatas": metadatas, 'metadatas': metadatas,
} }
) )
@@ -101,21 +101,19 @@ class MilvusClient(VectorDBBase):
enable_dynamic_field=True, enable_dynamic_field=True,
) )
schema.add_field( schema.add_field(
field_name="id", field_name='id',
datatype=DataType.VARCHAR, datatype=DataType.VARCHAR,
is_primary=True, is_primary=True,
max_length=65535, max_length=65535,
) )
schema.add_field( schema.add_field(
field_name="vector", field_name='vector',
datatype=DataType.FLOAT_VECTOR, datatype=DataType.FLOAT_VECTOR,
dim=dimension, dim=dimension,
description="vector", description='vector',
)
schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
schema.add_field(
field_name="metadata", datatype=DataType.JSON, description="metadata"
) )
schema.add_field(field_name='data', datatype=DataType.JSON, description='data')
schema.add_field(field_name='metadata', datatype=DataType.JSON, description='metadata')
index_params = self.client.prepare_index_params() index_params = self.client.prepare_index_params()
@@ -123,44 +121,44 @@ class MilvusClient(VectorDBBase):
index_type = MILVUS_INDEX_TYPE.upper() index_type = MILVUS_INDEX_TYPE.upper()
metric_type = MILVUS_METRIC_TYPE.upper() metric_type = MILVUS_METRIC_TYPE.upper()
log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}") log.info(f'Using Milvus index type: {index_type}, metric type: {metric_type}')
index_creation_params = {} index_creation_params = {}
if index_type == "HNSW": if index_type == 'HNSW':
index_creation_params = { index_creation_params = {
"M": MILVUS_HNSW_M, 'M': MILVUS_HNSW_M,
"efConstruction": MILVUS_HNSW_EFCONSTRUCTION, 'efConstruction': MILVUS_HNSW_EFCONSTRUCTION,
} }
log.info(f"HNSW params: {index_creation_params}") log.info(f'HNSW params: {index_creation_params}')
elif index_type == "IVF_FLAT": elif index_type == 'IVF_FLAT':
index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST} index_creation_params = {'nlist': MILVUS_IVF_FLAT_NLIST}
log.info(f"IVF_FLAT params: {index_creation_params}") log.info(f'IVF_FLAT params: {index_creation_params}')
elif index_type == "DISKANN": elif index_type == 'DISKANN':
index_creation_params = { index_creation_params = {
"max_degree": MILVUS_DISKANN_MAX_DEGREE, 'max_degree': MILVUS_DISKANN_MAX_DEGREE,
"search_list_size": MILVUS_DISKANN_SEARCH_LIST_SIZE, 'search_list_size': MILVUS_DISKANN_SEARCH_LIST_SIZE,
} }
log.info(f"DISKANN params: {index_creation_params}") log.info(f'DISKANN params: {index_creation_params}')
elif index_type in ["FLAT", "AUTOINDEX"]: elif index_type in ['FLAT', 'AUTOINDEX']:
log.info(f"Using {index_type} index with no specific build-time params.") log.info(f'Using {index_type} index with no specific build-time params.')
else: else:
log.warning( log.warning(
f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. " f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
f"Supported types: HNSW, IVF_FLAT, DISKANN, FLAT, AUTOINDEX. " f'Supported types: HNSW, IVF_FLAT, DISKANN, FLAT, AUTOINDEX. '
f"Milvus will use its default for the collection if this type is not directly supported for index creation." f'Milvus will use its default for the collection if this type is not directly supported for index creation.'
) )
# For unsupported types, pass the type directly to Milvus; it might handle it or use a default. # For unsupported types, pass the type directly to Milvus; it might handle it or use a default.
# If Milvus errors out, the user needs to correct the MILVUS_INDEX_TYPE env var. # If Milvus errors out, the user needs to correct the MILVUS_INDEX_TYPE env var.
index_params.add_index( index_params.add_index(
field_name="vector", field_name='vector',
index_type=index_type, index_type=index_type,
metric_type=metric_type, metric_type=metric_type,
params=index_creation_params, params=index_creation_params,
) )
self.client.create_collection( self.client.create_collection(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f'{self.collection_prefix}_{collection_name}',
schema=schema, schema=schema,
index_params=index_params, index_params=index_params,
) )
@@ -170,17 +168,13 @@ class MilvusClient(VectorDBBase):
def has_collection(self, collection_name: str) -> bool: def has_collection(self, collection_name: str) -> bool:
# Check if the collection exists based on the collection name. # Check if the collection exists based on the collection name.
collection_name = collection_name.replace("-", "_") collection_name = collection_name.replace('-', '_')
return self.client.has_collection( return self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}')
collection_name=f"{self.collection_prefix}_{collection_name}"
)
def delete_collection(self, collection_name: str): def delete_collection(self, collection_name: str):
# Delete the collection based on the collection name. # Delete the collection based on the collection name.
collection_name = collection_name.replace("-", "_") collection_name = collection_name.replace('-', '_')
return self.client.drop_collection( return self.client.drop_collection(collection_name=f'{self.collection_prefix}_{collection_name}')
collection_name=f"{self.collection_prefix}_{collection_name}"
)
def search( def search(
self, self,
@@ -190,15 +184,15 @@ class MilvusClient(VectorDBBase):
limit: int = 10, limit: int = 10,
) -> Optional[SearchResult]: ) -> Optional[SearchResult]:
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results. # Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
collection_name = collection_name.replace("-", "_") collection_name = collection_name.replace('-', '_')
# For some index types like IVF_FLAT, search params like nprobe can be set. # For some index types like IVF_FLAT, search params like nprobe can be set.
# Example: search_params = {"nprobe": 10} if using IVF_FLAT # Example: search_params = {"nprobe": 10} if using IVF_FLAT
# For simplicity, not adding configurable search_params here, but could be extended. # For simplicity, not adding configurable search_params here, but could be extended.
result = self.client.search( result = self.client.search(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f'{self.collection_prefix}_{collection_name}',
data=vectors, data=vectors,
limit=limit, limit=limit,
output_fields=["data", "metadata"], output_fields=['data', 'metadata'],
# search_params=search_params # Potentially add later if needed # search_params=search_params # Potentially add later if needed
) )
return self._result_to_search_result(result) return self._result_to_search_result(result)
@@ -206,11 +200,9 @@ class MilvusClient(VectorDBBase):
def query(self, collection_name: str, filter: dict, limit: int = -1): def query(self, collection_name: str, filter: dict, limit: int = -1):
connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB) connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB)
collection_name = collection_name.replace("-", "_") collection_name = collection_name.replace('-', '_')
if not self.has_collection(collection_name): if not self.has_collection(collection_name):
log.warning( log.warning(f'Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}')
f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
)
return None return None
filter_expressions = [] filter_expressions = []
@@ -220,9 +212,9 @@ class MilvusClient(VectorDBBase):
else: else:
filter_expressions.append(f'metadata["{key}"] == {value}') filter_expressions.append(f'metadata["{key}"] == {value}')
filter_string = " && ".join(filter_expressions) filter_string = ' && '.join(filter_expressions)
collection = Collection(f"{self.collection_prefix}_{collection_name}") collection = Collection(f'{self.collection_prefix}_{collection_name}')
collection.load() collection.load()
try: try:
@@ -233,9 +225,9 @@ class MilvusClient(VectorDBBase):
iterator = collection.query_iterator( iterator = collection.query_iterator(
expr=filter_string, expr=filter_string,
output_fields=[ output_fields=[
"id", 'id',
"data", 'data',
"metadata", 'metadata',
], ],
limit=limit if limit > 0 else -1, limit=limit if limit > 0 else -1,
) )
@@ -248,7 +240,7 @@ class MilvusClient(VectorDBBase):
break break
all_results.extend(batch) all_results.extend(batch)
log.debug(f"Total results from query: {len(all_results)}") log.debug(f'Total results from query: {len(all_results)}')
return self._result_to_get_result([all_results] if all_results else [[]]) return self._result_to_get_result([all_results] if all_results else [[]])
except Exception as e: except Exception as e:
@@ -259,7 +251,7 @@ class MilvusClient(VectorDBBase):
def get(self, collection_name: str) -> Optional[GetResult]: def get(self, collection_name: str) -> Optional[GetResult]:
# Get all the items in the collection. This can be very resource-intensive for large collections. # Get all the items in the collection. This can be very resource-intensive for large collections.
collection_name = collection_name.replace("-", "_") collection_name = collection_name.replace('-', '_')
log.warning( log.warning(
f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections." f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections."
) )
@@ -269,35 +261,25 @@ class MilvusClient(VectorDBBase):
def insert(self, collection_name: str, items: list[VectorItem]): def insert(self, collection_name: str, items: list[VectorItem]):
# Insert the items into the collection, if the collection does not exist, it will be created. # Insert the items into the collection, if the collection does not exist, it will be created.
collection_name = collection_name.replace("-", "_") collection_name = collection_name.replace('-', '_')
if not self.client.has_collection( if not self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}'):
collection_name=f"{self.collection_prefix}_{collection_name}" log.info(f'Collection {self.collection_prefix}_{collection_name} does not exist. Creating now.')
):
log.info(
f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now."
)
if not items: if not items:
log.error( log.error(
f"Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension." f'Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension.'
)
raise ValueError(
"Cannot create Milvus collection without items to determine vector dimension."
)
self._create_collection(
collection_name=collection_name, dimension=len(items[0]["vector"])
) )
raise ValueError('Cannot create Milvus collection without items to determine vector dimension.')
self._create_collection(collection_name=collection_name, dimension=len(items[0]['vector']))
log.info( log.info(f'Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.')
f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
)
return self.client.insert( return self.client.insert(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f'{self.collection_prefix}_{collection_name}',
data=[ data=[
{ {
"id": item["id"], 'id': item['id'],
"vector": item["vector"], 'vector': item['vector'],
"data": {"text": item["text"]}, 'data': {'text': item['text']},
"metadata": process_metadata(item["metadata"]), 'metadata': process_metadata(item['metadata']),
} }
for item in items for item in items
], ],
@@ -305,35 +287,27 @@ class MilvusClient(VectorDBBase):
def upsert(self, collection_name: str, items: list[VectorItem]): def upsert(self, collection_name: str, items: list[VectorItem]):
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created. # Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
collection_name = collection_name.replace("-", "_") collection_name = collection_name.replace('-', '_')
if not self.client.has_collection( if not self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}'):
collection_name=f"{self.collection_prefix}_{collection_name}" log.info(f'Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now.')
):
log.info(
f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now."
)
if not items: if not items:
log.error( log.error(
f"Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension." f'Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension.'
) )
raise ValueError( raise ValueError(
"Cannot create Milvus collection for upsert without items to determine vector dimension." 'Cannot create Milvus collection for upsert without items to determine vector dimension.'
)
self._create_collection(
collection_name=collection_name, dimension=len(items[0]["vector"])
) )
self._create_collection(collection_name=collection_name, dimension=len(items[0]['vector']))
log.info( log.info(f'Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.')
f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
)
return self.client.upsert( return self.client.upsert(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f'{self.collection_prefix}_{collection_name}',
data=[ data=[
{ {
"id": item["id"], 'id': item['id'],
"vector": item["vector"], 'vector': item['vector'],
"data": {"text": item["text"]}, 'data': {'text': item['text']},
"metadata": process_metadata(item["metadata"]), 'metadata': process_metadata(item['metadata']),
} }
for item in items for item in items
], ],
@@ -346,46 +320,35 @@ class MilvusClient(VectorDBBase):
filter: Optional[dict] = None, filter: Optional[dict] = None,
): ):
# Delete the items from the collection based on the ids or filter. # Delete the items from the collection based on the ids or filter.
collection_name = collection_name.replace("-", "_") collection_name = collection_name.replace('-', '_')
if not self.has_collection(collection_name): if not self.has_collection(collection_name):
log.warning( log.warning(f'Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}')
f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
)
return None return None
if ids: if ids:
log.info( log.info(f'Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}')
f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}"
)
return self.client.delete( return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f'{self.collection_prefix}_{collection_name}',
ids=ids, ids=ids,
) )
elif filter: elif filter:
filter_string = " && ".join( filter_string = ' && '.join([f'metadata["{key}"] == {json.dumps(value)}' for key, value in filter.items()])
[
f'metadata["{key}"] == {json.dumps(value)}'
for key, value in filter.items()
]
)
log.info( log.info(
f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}" f'Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}'
) )
return self.client.delete( return self.client.delete(
collection_name=f"{self.collection_prefix}_{collection_name}", collection_name=f'{self.collection_prefix}_{collection_name}',
filter=filter_string, filter=filter_string,
) )
else: else:
log.warning( log.warning(
f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken." f'Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken.'
) )
return None return None
def reset(self): def reset(self):
# Resets the database. This will delete all collections and item entries that match the prefix. # Resets the database. This will delete all collections and item entries that match the prefix.
log.warning( log.warning(f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'.")
f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'."
)
collection_names = self.client.list_collections() collection_names = self.client.list_collections()
deleted_collections = [] deleted_collections = []
for collection_name_full in collection_names: for collection_name_full in collection_names:
@@ -393,7 +356,7 @@ class MilvusClient(VectorDBBase):
try: try:
self.client.drop_collection(collection_name=collection_name_full) self.client.drop_collection(collection_name=collection_name_full)
deleted_collections.append(collection_name_full) deleted_collections.append(collection_name_full)
log.info(f"Deleted collection: {collection_name_full}") log.info(f'Deleted collection: {collection_name_full}')
except Exception as e: except Exception as e:
log.error(f"Error deleting collection {collection_name_full}: {e}") log.error(f'Error deleting collection {collection_name_full}: {e}')
log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}") log.info(f'Milvus reset complete. Deleted collections: {deleted_collections}')

View File

@@ -33,26 +33,26 @@ from pymilvus import (
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
RESOURCE_ID_FIELD = "resource_id" RESOURCE_ID_FIELD = 'resource_id'
class MilvusClient(VectorDBBase): class MilvusClient(VectorDBBase):
def __init__(self): def __init__(self):
# Milvus collection names can only contain numbers, letters, and underscores. # Milvus collection names can only contain numbers, letters, and underscores.
self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_") self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace('-', '_')
connections.connect( connections.connect(
alias="default", alias='default',
uri=MILVUS_URI, uri=MILVUS_URI,
token=MILVUS_TOKEN, token=MILVUS_TOKEN,
db_name=MILVUS_DB, db_name=MILVUS_DB,
) )
# Main collection types for multi-tenancy # Main collection types for multi-tenancy
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories" self.MEMORY_COLLECTION = f'{self.collection_prefix}_memories'
self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge" self.KNOWLEDGE_COLLECTION = f'{self.collection_prefix}_knowledge'
self.FILE_COLLECTION = f"{self.collection_prefix}_files" self.FILE_COLLECTION = f'{self.collection_prefix}_files'
self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search" self.WEB_SEARCH_COLLECTION = f'{self.collection_prefix}_web_search'
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based" self.HASH_BASED_COLLECTION = f'{self.collection_prefix}_hash_based'
self.shared_collections = [ self.shared_collections = [
self.MEMORY_COLLECTION, self.MEMORY_COLLECTION,
self.KNOWLEDGE_COLLECTION, self.KNOWLEDGE_COLLECTION,
@@ -74,15 +74,13 @@ class MilvusClient(VectorDBBase):
""" """
resource_id = collection_name resource_id = collection_name
if collection_name.startswith("user-memory-"): if collection_name.startswith('user-memory-'):
return self.MEMORY_COLLECTION, resource_id return self.MEMORY_COLLECTION, resource_id
elif collection_name.startswith("file-"): elif collection_name.startswith('file-'):
return self.FILE_COLLECTION, resource_id return self.FILE_COLLECTION, resource_id
elif collection_name.startswith("web-search-"): elif collection_name.startswith('web-search-'):
return self.WEB_SEARCH_COLLECTION, resource_id return self.WEB_SEARCH_COLLECTION, resource_id
elif len(collection_name) == 63 and all( elif len(collection_name) == 63 and all(c in '0123456789abcdef' for c in collection_name):
c in "0123456789abcdef" for c in collection_name
):
return self.HASH_BASED_COLLECTION, resource_id return self.HASH_BASED_COLLECTION, resource_id
else: else:
return self.KNOWLEDGE_COLLECTION, resource_id return self.KNOWLEDGE_COLLECTION, resource_id
@@ -90,36 +88,36 @@ class MilvusClient(VectorDBBase):
def _create_shared_collection(self, mt_collection_name: str, dimension: int): def _create_shared_collection(self, mt_collection_name: str, dimension: int):
fields = [ fields = [
FieldSchema( FieldSchema(
name="id", name='id',
dtype=DataType.VARCHAR, dtype=DataType.VARCHAR,
is_primary=True, is_primary=True,
auto_id=False, auto_id=False,
max_length=36, max_length=36,
), ),
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension), FieldSchema(name='vector', dtype=DataType.FLOAT_VECTOR, dim=dimension),
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535), FieldSchema(name='text', dtype=DataType.VARCHAR, max_length=65535),
FieldSchema(name="metadata", dtype=DataType.JSON), FieldSchema(name='metadata', dtype=DataType.JSON),
FieldSchema(name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255), FieldSchema(name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255),
] ]
schema = CollectionSchema(fields, "Shared collection for multi-tenancy") schema = CollectionSchema(fields, 'Shared collection for multi-tenancy')
collection = Collection(mt_collection_name, schema) collection = Collection(mt_collection_name, schema)
index_params = { index_params = {
"metric_type": MILVUS_METRIC_TYPE, 'metric_type': MILVUS_METRIC_TYPE,
"index_type": MILVUS_INDEX_TYPE, 'index_type': MILVUS_INDEX_TYPE,
"params": {}, 'params': {},
} }
if MILVUS_INDEX_TYPE == "HNSW": if MILVUS_INDEX_TYPE == 'HNSW':
index_params["params"] = { index_params['params'] = {
"M": MILVUS_HNSW_M, 'M': MILVUS_HNSW_M,
"efConstruction": MILVUS_HNSW_EFCONSTRUCTION, 'efConstruction': MILVUS_HNSW_EFCONSTRUCTION,
} }
elif MILVUS_INDEX_TYPE == "IVF_FLAT": elif MILVUS_INDEX_TYPE == 'IVF_FLAT':
index_params["params"] = {"nlist": MILVUS_IVF_FLAT_NLIST} index_params['params'] = {'nlist': MILVUS_IVF_FLAT_NLIST}
collection.create_index("vector", index_params) collection.create_index('vector', index_params)
collection.create_index(RESOURCE_ID_FIELD) collection.create_index(RESOURCE_ID_FIELD)
log.info(f"Created shared collection: {mt_collection_name}") log.info(f'Created shared collection: {mt_collection_name}')
return collection return collection
def _ensure_collection(self, mt_collection_name: str, dimension: int): def _ensure_collection(self, mt_collection_name: str, dimension: int):
@@ -127,9 +125,7 @@ class MilvusClient(VectorDBBase):
self._create_shared_collection(mt_collection_name, dimension) self._create_shared_collection(mt_collection_name, dimension)
def has_collection(self, collection_name: str) -> bool: def has_collection(self, collection_name: str) -> bool:
mt_collection, resource_id = self._get_collection_and_resource_id( mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
collection_name
)
if not utility.has_collection(mt_collection): if not utility.has_collection(mt_collection):
return False return False
@@ -141,19 +137,17 @@ class MilvusClient(VectorDBBase):
def upsert(self, collection_name: str, items: List[VectorItem]): def upsert(self, collection_name: str, items: List[VectorItem]):
if not items: if not items:
return return
mt_collection, resource_id = self._get_collection_and_resource_id( mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
collection_name dimension = len(items[0]['vector'])
)
dimension = len(items[0]["vector"])
self._ensure_collection(mt_collection, dimension) self._ensure_collection(mt_collection, dimension)
collection = Collection(mt_collection) collection = Collection(mt_collection)
entities = [ entities = [
{ {
"id": item["id"], 'id': item['id'],
"vector": item["vector"], 'vector': item['vector'],
"text": item["text"], 'text': item['text'],
"metadata": item["metadata"], 'metadata': item['metadata'],
RESOURCE_ID_FIELD: resource_id, RESOURCE_ID_FIELD: resource_id,
} }
for item in items for item in items
@@ -170,41 +164,37 @@ class MilvusClient(VectorDBBase):
if not vectors: if not vectors:
return None return None
mt_collection, resource_id = self._get_collection_and_resource_id( mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
collection_name
)
if not utility.has_collection(mt_collection): if not utility.has_collection(mt_collection):
return None return None
collection = Collection(mt_collection) collection = Collection(mt_collection)
collection.load() collection.load()
search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}} search_params = {'metric_type': MILVUS_METRIC_TYPE, 'params': {}}
results = collection.search( results = collection.search(
data=vectors, data=vectors,
anns_field="vector", anns_field='vector',
param=search_params, param=search_params,
limit=limit, limit=limit,
expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'", expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'",
output_fields=["id", "text", "metadata"], output_fields=['id', 'text', 'metadata'],
) )
ids, documents, metadatas, distances = [], [], [], [] ids, documents, metadatas, distances = [], [], [], []
for hits in results: for hits in results:
batch_ids, batch_docs, batch_metadatas, batch_dists = [], [], [], [] batch_ids, batch_docs, batch_metadatas, batch_dists = [], [], [], []
for hit in hits: for hit in hits:
batch_ids.append(hit.entity.get("id")) batch_ids.append(hit.entity.get('id'))
batch_docs.append(hit.entity.get("text")) batch_docs.append(hit.entity.get('text'))
batch_metadatas.append(hit.entity.get("metadata")) batch_metadatas.append(hit.entity.get('metadata'))
batch_dists.append(hit.distance) batch_dists.append(hit.distance)
ids.append(batch_ids) ids.append(batch_ids)
documents.append(batch_docs) documents.append(batch_docs)
metadatas.append(batch_metadatas) metadatas.append(batch_metadatas)
distances.append(batch_dists) distances.append(batch_dists)
return SearchResult( return SearchResult(ids=ids, documents=documents, metadatas=metadatas, distances=distances)
ids=ids, documents=documents, metadatas=metadatas, distances=distances
)
def delete( def delete(
self, self,
@@ -212,9 +202,7 @@ class MilvusClient(VectorDBBase):
ids: Optional[List[str]] = None, ids: Optional[List[str]] = None,
filter: Optional[Dict[str, Any]] = None, filter: Optional[Dict[str, Any]] = None,
): ):
mt_collection, resource_id = self._get_collection_and_resource_id( mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
collection_name
)
if not utility.has_collection(mt_collection): if not utility.has_collection(mt_collection):
return return
@@ -224,14 +212,14 @@ class MilvusClient(VectorDBBase):
expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"] expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
if ids: if ids:
# Milvus expects a string list for 'in' operator # Milvus expects a string list for 'in' operator
id_list_str = ", ".join([f"'{id_val}'" for id_val in ids]) id_list_str = ', '.join([f"'{id_val}'" for id_val in ids])
expr.append(f"id in [{id_list_str}]") expr.append(f'id in [{id_list_str}]')
if filter: if filter:
for key, value in filter.items(): for key, value in filter.items():
expr.append(f"metadata['{key}'] == '{value}'") expr.append(f"metadata['{key}'] == '{value}'")
collection.delete(" and ".join(expr)) collection.delete(' and '.join(expr))
def reset(self): def reset(self):
for collection_name in self.shared_collections: for collection_name in self.shared_collections:
@@ -239,21 +227,15 @@ class MilvusClient(VectorDBBase):
utility.drop_collection(collection_name) utility.drop_collection(collection_name)
def delete_collection(self, collection_name: str): def delete_collection(self, collection_name: str):
mt_collection, resource_id = self._get_collection_and_resource_id( mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
collection_name
)
if not utility.has_collection(mt_collection): if not utility.has_collection(mt_collection):
return return
collection = Collection(mt_collection) collection = Collection(mt_collection)
collection.delete(f"{RESOURCE_ID_FIELD} == '{resource_id}'") collection.delete(f"{RESOURCE_ID_FIELD} == '{resource_id}'")
def query( def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
) -> Optional[GetResult]:
mt_collection, resource_id = self._get_collection_and_resource_id(
collection_name
)
if not utility.has_collection(mt_collection): if not utility.has_collection(mt_collection):
return None return None
@@ -269,8 +251,8 @@ class MilvusClient(VectorDBBase):
expr.append(f"metadata['{key}'] == {value}") expr.append(f"metadata['{key}'] == {value}")
iterator = collection.query_iterator( iterator = collection.query_iterator(
expr=" and ".join(expr), expr=' and '.join(expr),
output_fields=["id", "text", "metadata"], output_fields=['id', 'text', 'metadata'],
limit=limit if limit else -1, limit=limit if limit else -1,
) )
@@ -282,9 +264,9 @@ class MilvusClient(VectorDBBase):
break break
all_results.extend(batch) all_results.extend(batch)
ids = [res["id"] for res in all_results] ids = [res['id'] for res in all_results]
documents = [res["text"] for res in all_results] documents = [res['text'] for res in all_results]
metadatas = [res["metadata"] for res in all_results] metadatas = [res['metadata'] for res in all_results]
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas]) return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

Some files were not shown because too many files have changed in this diff Show More