mirror of
https://github.com/open-webui/open-webui.git
synced 2026-04-25 17:15:16 +02:00
refac
This commit is contained in:
@@ -10,94 +10,88 @@ from typing_extensions import Annotated
|
||||
|
||||
# Typer application object; the functions decorated with @app.command() in this
# module are registered on it.
app = typer.Typer()

# Location of the persisted secret-key file. Path.cwd() is evaluated once, when
# this module is imported, so the path is relative to the working directory at
# that moment.
KEY_FILE = Path.cwd() / '.webui_secret_key'
|
||||
|
||||
|
||||
def version_callback(value: bool):
    """Print the Open WebUI version and stop the CLI when ``--version`` was given.

    Args:
        value: The parsed ``--version`` option; any falsy value makes this a no-op.

    Raises:
        typer.Exit: Always raised after the version line has been echoed.
    """
    if not value:
        return

    # Imported only on demand so that a normal invocation does not load the
    # environment module just to evaluate this callback.
    from open_webui.env import VERSION

    typer.echo(f'Open WebUI version: {VERSION}')
    raise typer.Exit()
|
||||
|
||||
|
||||
@app.command()
def main(
    version: Annotated[
        Optional[bool],
        typer.Option('--version', callback=version_callback),
    ] = None,
):
    # Documented with comments rather than a docstring on purpose: Typer shows a
    # command's docstring as its --help text, which would change CLI output.
    #
    # This command has no work of its own; the only thing it declares is the
    # --version option, whose handling lives entirely in version_callback.
    return None
|
||||
|
||||
|
||||
@app.command()
def serve(
    host: str = '0.0.0.0',
    port: int = 8080,
):
    # Documented with comments rather than a docstring on purpose: Typer shows a
    # command's docstring as its --help text, which would change CLI output.
    #
    # Prepares the process environment and then runs the web application under
    # uvicorn on host:port:
    #   * marks the process with FROM_INIT_PY=true;
    #   * when WEBUI_SECRET_KEY is unset, loads it from KEY_FILE, generating the
    #     file first if it does not exist;
    #   * when USE_CUDA_DOCKER=true, extends LD_LIBRARY_PATH with the torch/cudnn
    #     library directories and probes torch for CUDA, reverting both settings
    #     if the probe fails.
    os.environ['FROM_INIT_PY'] = 'true'

    if os.getenv('WEBUI_SECRET_KEY') is None:
        typer.echo('Loading WEBUI_SECRET_KEY from file, not provided as an environment variable.')
        if not KEY_FILE.exists():
            # The key is a secret, so it must come from the OS CSPRNG; the
            # `random` module is a deterministic PRNG and is not suitable.
            import secrets

            typer.echo(f'Generating a new secret key and saving it to {KEY_FILE}')
            KEY_FILE.write_bytes(base64.b64encode(secrets.token_bytes(12)))
        typer.echo(f'Loading WEBUI_SECRET_KEY from {KEY_FILE}')
        os.environ['WEBUI_SECRET_KEY'] = KEY_FILE.read_text()

    if os.getenv('USE_CUDA_DOCKER', 'false') == 'true':
        typer.echo('CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries.')
        previous_value = os.getenv('LD_LIBRARY_PATH', '')
        # ''.split(':') is [''], which would put an empty entry in front of the
        # appended directories; the dynamic loader reads an empty entry as the
        # current working directory. Start from an empty list in that case.
        previous_entries = previous_value.split(':') if previous_value else []
        os.environ['LD_LIBRARY_PATH'] = ':'.join(
            previous_entries
            + [
                '/usr/local/lib/python3.11/site-packages/torch/lib',
                '/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib',
            ]
        )
        try:
            import torch

            # Explicit raise instead of `assert`: assertions are removed under
            # `python -O`, which would silently skip this check.
            if not torch.cuda.is_available():
                raise RuntimeError('CUDA not available')
            typer.echo('CUDA seems to be working')
        except Exception as e:
            # Best-effort probe: any failure (import error included) falls back
            # to the non-CUDA configuration instead of aborting startup.
            typer.echo(
                'Error when testing CUDA but USE_CUDA_DOCKER is true. '
                'Resetting USE_CUDA_DOCKER to false and removing '
                f'LD_LIBRARY_PATH modifications: {e}'
            )
            os.environ['USE_CUDA_DOCKER'] = 'false'
            os.environ['LD_LIBRARY_PATH'] = ':'.join(previous_entries)

    import open_webui.main  # we need set environment variables before importing main
    from open_webui.env import UVICORN_WORKERS  # Import the workers setting

    uvicorn.run(
        'open_webui.main:app',
        host=host,
        port=port,
        forwarded_allow_ips='*',
        workers=UVICORN_WORKERS,
    )
|
||||
|
||||
|
||||
@app.command()
def dev(
    host: str = '0.0.0.0',
    port: int = 8080,
    reload: bool = True,
):
    # Documented with comments rather than a docstring on purpose: Typer shows a
    # command's docstring as its --help text, which would change CLI output.
    #
    # Runs the application under uvicorn with the reload flag passed through
    # (on by default). Unlike `serve`, no environment preparation happens here.
    server_options = {
        'host': host,
        'port': port,
        'reload': reload,
        'forwarded_allow_ips': '*',
    }
    uvicorn.run('open_webui.main:app', **server_options)
|
||||
|
||||
|
||||
# Invoke the Typer application when this module is executed as a script.
if __name__ == '__main__':
    app()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -2,125 +2,107 @@ from enum import Enum
|
||||
|
||||
|
||||
class MESSAGES(str, Enum):
    """Builders for success messages.

    Every attribute is a lambda. Functions are descriptors, so Enum keeps them
    as plain callable class attributes instead of turning them into members;
    call them on the class, e.g. ``MESSAGES.MODEL_ADDED('llama')``.
    """

    # Echo the given message, or an empty string when nothing truthy is passed.
    DEFAULT = lambda msg='': f'{msg}' if msg else ''

    MODEL_ADDED = lambda model='': (
        f"The model '{model}' has been added successfully."
    )
    MODEL_DELETED = lambda model='': (
        f"The model '{model}' has been deleted successfully."
    )
|
||||
|
||||
|
||||
class WEBHOOK_MESSAGES(str, Enum):
    """Builders for webhook notification texts.

    Both attributes are lambdas, which Enum leaves as plain callable class
    attributes rather than members.
    """

    # Echo the given message, or an empty string when nothing truthy is passed.
    DEFAULT = lambda msg='': f'{msg}' if msg else ''

    # The username is appended only when one is supplied.
    USER_SIGNUP = lambda username='': f'New user signed up: {username}' if username else 'New user signed up'
|
||||
|
||||
|
||||
class ERROR_MESSAGES(str, Enum):
    """Catalogue of error texts.

    String attributes are enum members whose value is the message. Attributes
    assigned a lambda are not members (Enum skips descriptors such as
    functions); they are message builders to be called on the class, e.g.
    ``ERROR_MESSAGES.DEFAULT(err)``.
    """

    def __str__(self) -> str:
        # Delegate to str so that str(member) is the message text itself.
        return super().__str__()

    # Generic fallback: a fixed apology when err is the empty string, otherwise
    # the stringified error wrapped in an "[ERROR: ...]" tag.
    DEFAULT = lambda err='': 'Something went wrong :/' if err == '' else '[ERROR: ' + str(err) + ']'
    ENV_VAR_NOT_FOUND = 'Required environment variable not found. Terminating now.'
    CREATE_USER_ERROR = (
        'Oops! Something went wrong while creating your account. '
        'Please try again later. '
        'If the issue persists, contact support for assistance.'
    )
    DELETE_USER_ERROR = (
        'Oops! Something went wrong. '
        'We encountered an issue while trying to delete the user. '
        'Please give it another shot.'
    )
    EMAIL_MISMATCH = (
        'Uh-oh! This email does not match the email your provider is registered with. '
        'Please check your email and try again.'
    )
    EMAIL_TAKEN = (
        'Uh-oh! This email is already registered. '
        'Sign in with your existing account or choose another email to start anew.'
    )
    USERNAME_TAKEN = 'Uh-oh! This username is already registered. Please choose another username.'
    PASSWORD_TOO_LONG = (
        'Uh-oh! The password you entered is too long. '
        'Please make sure your password is less than 72 bytes long.'
    )
    COMMAND_TAKEN = 'Uh-oh! This command is already registered. Please choose another command string.'
    FILE_EXISTS = 'Uh-oh! This file is already registered. Please choose another file.'

    ID_TAKEN = 'Uh-oh! This id is already registered. Please choose another id string.'
    MODEL_ID_TAKEN = 'Uh-oh! This model id is already registered. Please choose another model id string.'
    NAME_TAG_TAKEN = 'Uh-oh! This name tag is already registered. Please choose another name tag string.'
    MODEL_ID_TOO_LONG = (
        'The model id is too long. '
        'Please make sure your model id is less than 256 characters long.'
    )

    INVALID_TOKEN = 'Your session has expired or the token is invalid. Please sign in again.'
    INVALID_CRED = (
        'The email or password provided is incorrect. '
        'Please check for typos and try logging in again.'
    )
    INVALID_EMAIL_FORMAT = (
        'The email format you entered is invalid. '
        "Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)."
    )
    INCORRECT_PASSWORD = 'The password provided is incorrect. Please check for typos and try again.'
    INVALID_TRUSTED_HEADER = (
        'Your provider has not provided a trusted header. '
        'Please contact your administrator for assistance.'
    )

    EXISTING_USERS = (
        "You can't turn off authentication because there are existing users. "
        "If you want to disable WEBUI_AUTH, make sure your web interface doesn't have any existing users "
        'and is a fresh installation.'
    )

    UNAUTHORIZED = '401 Unauthorized'
    ACCESS_PROHIBITED = (
        'You do not have permission to access this resource. '
        'Please contact your administrator for assistance.'
    )
    ACTION_PROHIBITED = 'The requested action has been restricted as a security measure.'

    FILE_NOT_SENT = 'FILE_NOT_SENT'
    FILE_NOT_SUPPORTED = (
        "Oops! It seems like the file format you're trying to upload is not supported. "
        'Please upload a file with a supported format and try again.'
    )

    NOT_FOUND = "We could not find what you're looking for :/"
    # Same value as NOT_FOUND, so Enum makes this name an alias of that member.
    USER_NOT_FOUND = "We could not find what you're looking for :/"
    API_KEY_NOT_FOUND = (
        "Oops! It looks like there's a hiccup. "
        'The API key is missing. '
        'Please make sure to provide a valid API key to access this feature.'
    )
    API_KEY_NOT_ALLOWED = 'Use of API key is not enabled in the environment.'

    MALICIOUS = 'Unusual activities detected, please try again in a few minutes.'

    PANDOC_NOT_INSTALLED = (
        'Pandoc is not installed on the server. '
        'Please contact your administrator for assistance.'
    )
    # err is appended verbatim, with no separator added here.
    INCORRECT_FORMAT = lambda err='': f'Invalid format. Please use the correct format{err}'
    RATE_LIMIT_EXCEEDED = 'API rate limit exceeded'

    MODEL_NOT_FOUND = lambda name='': f"Model '{name}' was not found"
    # The name argument is accepted but not used in the message.
    OPENAI_NOT_FOUND = lambda name='': 'OpenAI API was not found'
    OLLAMA_NOT_FOUND = 'WebUI could not connect to Ollama'
    CREATE_API_KEY_ERROR = (
        'Oops! Something went wrong while creating your API key. '
        'Please try again later. '
        'If the issue persists, contact support for assistance.'
    )
    API_KEY_CREATION_NOT_ALLOWED = 'API key creation is not allowed in the environment.'

    EMPTY_CONTENT = (
        'The content provided is empty. '
        'Please ensure that there is text or data present before proceeding.'
    )

    DB_NOT_SQLITE = 'This feature is only available when running with SQLite databases.'

    INVALID_URL = 'Oops! The URL you provided is invalid. Please double-check and try again.'

    # A truthy err replaces the stock message entirely.
    WEB_SEARCH_ERROR = lambda err='': f'{err}' if err else 'Oops! Something went wrong while searching the web.'

    OLLAMA_API_DISABLED = 'The Ollama API is disabled. Please enable it to use this feature.'

    FILE_TOO_LARGE = lambda size='': (
        f"Oops! The file you're trying to upload is too large. Please upload a file that is less than {size}."
    )

    DUPLICATE_CONTENT = 'Duplicate content detected. Please provide unique content to proceed.'
    FILE_NOT_PROCESSED = (
        'Extracted content is not available for this file. '
        'Please ensure that the file is processed before proceeding.'
    )

    # A truthy err is returned as-is (not stringified); otherwise the stock text.
    INVALID_PASSWORD = lambda err='': err or 'The password does not meet the required validation criteria.'
|
||||
|
||||
|
||||
class TASKS(str, Enum):
    """Identifiers for the background generation tasks.

    Each member's value is the task's string id. DEFAULT is a lambda, which
    Enum leaves as a plain callable class attribute rather than a member.
    """

    def __str__(self) -> str:
        # Delegate to str so that str(member) is the task id itself.
        return super().__str__()

    # Echo the given task name, falling back to 'generation' for a falsy one.
    DEFAULT = lambda task='': f'{task}' if task else 'generation'

    TITLE_GENERATION = 'title_generation'
    FOLLOW_UP_GENERATION = 'follow_up_generation'
    TAGS_GENERATION = 'tags_generation'
    EMOJI_GENERATION = 'emoji_generation'
    QUERY_GENERATION = 'query_generation'
    IMAGE_PROMPT_GENERATION = 'image_prompt_generation'
    AUTOCOMPLETE_GENERATION = 'autocomplete_generation'
    FUNCTION_CALLING = 'function_calling'
    MOA_RESPONSE_GENERATION = 'moa_response_generation'
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -57,17 +57,15 @@ log = logging.getLogger(__name__)
|
||||
def get_function_module_by_id(request: Request, pipe_id: str):
|
||||
function_module, _, _ = get_function_module_from_cache(request, pipe_id)
|
||||
|
||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
||||
if hasattr(function_module, 'valves') and hasattr(function_module, 'Valves'):
|
||||
Valves = function_module.Valves
|
||||
valves = Functions.get_function_valves_by_id(pipe_id)
|
||||
|
||||
if valves:
|
||||
try:
|
||||
function_module.valves = Valves(
|
||||
**{k: v for k, v in valves.items() if v is not None}
|
||||
)
|
||||
function_module.valves = Valves(**{k: v for k, v in valves.items() if v is not None})
|
||||
except Exception as e:
|
||||
log.exception(f"Error loading valves for function {pipe_id}: {e}")
|
||||
log.exception(f'Error loading valves for function {pipe_id}: {e}')
|
||||
raise e
|
||||
else:
|
||||
function_module.valves = Valves()
|
||||
@@ -76,7 +74,7 @@ def get_function_module_by_id(request: Request, pipe_id: str):
|
||||
|
||||
|
||||
async def get_function_models(request):
|
||||
pipes = Functions.get_functions_by_type("pipe", active_only=True)
|
||||
pipes = Functions.get_functions_by_type('pipe', active_only=True)
|
||||
pipe_models = []
|
||||
|
||||
for pipe in pipes:
|
||||
@@ -84,11 +82,11 @@ async def get_function_models(request):
|
||||
function_module = get_function_module_by_id(request, pipe.id)
|
||||
|
||||
has_user_valves = False
|
||||
if hasattr(function_module, "UserValves"):
|
||||
if hasattr(function_module, 'UserValves'):
|
||||
has_user_valves = True
|
||||
|
||||
# Check if function is a manifold
|
||||
if hasattr(function_module, "pipes"):
|
||||
if hasattr(function_module, 'pipes'):
|
||||
sub_pipes = []
|
||||
|
||||
# Handle pipes being a list, sync function, or async function
|
||||
@@ -104,32 +102,30 @@ async def get_function_models(request):
|
||||
log.exception(e)
|
||||
sub_pipes = []
|
||||
|
||||
log.debug(
|
||||
f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
|
||||
)
|
||||
log.debug(f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}")
|
||||
|
||||
for p in sub_pipes:
|
||||
sub_pipe_id = f'{pipe.id}.{p["id"]}'
|
||||
sub_pipe_name = p["name"]
|
||||
sub_pipe_name = p['name']
|
||||
|
||||
if hasattr(function_module, "name"):
|
||||
sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
|
||||
if hasattr(function_module, 'name'):
|
||||
sub_pipe_name = f'{function_module.name}{sub_pipe_name}'
|
||||
|
||||
pipe_flag = {"type": pipe.type}
|
||||
pipe_flag = {'type': pipe.type}
|
||||
|
||||
pipe_models.append(
|
||||
{
|
||||
"id": sub_pipe_id,
|
||||
"name": sub_pipe_name,
|
||||
"object": "model",
|
||||
"created": pipe.created_at,
|
||||
"owned_by": "openai",
|
||||
"pipe": pipe_flag,
|
||||
"has_user_valves": has_user_valves,
|
||||
'id': sub_pipe_id,
|
||||
'name': sub_pipe_name,
|
||||
'object': 'model',
|
||||
'created': pipe.created_at,
|
||||
'owned_by': 'openai',
|
||||
'pipe': pipe_flag,
|
||||
'has_user_valves': has_user_valves,
|
||||
}
|
||||
)
|
||||
else:
|
||||
pipe_flag = {"type": "pipe"}
|
||||
pipe_flag = {'type': 'pipe'}
|
||||
|
||||
log.debug(
|
||||
f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
|
||||
@@ -137,13 +133,13 @@ async def get_function_models(request):
|
||||
|
||||
pipe_models.append(
|
||||
{
|
||||
"id": pipe.id,
|
||||
"name": pipe.name,
|
||||
"object": "model",
|
||||
"created": pipe.created_at,
|
||||
"owned_by": "openai",
|
||||
"pipe": pipe_flag,
|
||||
"has_user_valves": has_user_valves,
|
||||
'id': pipe.id,
|
||||
'name': pipe.name,
|
||||
'object': 'model',
|
||||
'created': pipe.created_at,
|
||||
'owned_by': 'openai',
|
||||
'pipe': pipe_flag,
|
||||
'has_user_valves': has_user_valves,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -153,9 +149,7 @@ async def get_function_models(request):
|
||||
return pipe_models
|
||||
|
||||
|
||||
async def generate_function_chat_completion(
|
||||
request, form_data, user, models: dict = {}
|
||||
):
|
||||
async def generate_function_chat_completion(request, form_data, user, models: dict = {}):
|
||||
async def execute_pipe(pipe, params):
|
||||
if inspect.iscoroutinefunction(pipe):
|
||||
return await pipe(**params)
|
||||
@@ -166,32 +160,32 @@ async def generate_function_chat_completion(
|
||||
if isinstance(res, str):
|
||||
return res
|
||||
if isinstance(res, Generator):
|
||||
return "".join(map(str, res))
|
||||
return ''.join(map(str, res))
|
||||
if isinstance(res, AsyncGenerator):
|
||||
return "".join([str(stream) async for stream in res])
|
||||
return ''.join([str(stream) async for stream in res])
|
||||
|
||||
def process_line(form_data: dict, line):
    """Turn one item produced by a pipe into a server-sent-event frame.

    Args:
        form_data: Request body; only ``form_data['model']`` is read here.
        line: A pydantic model, dict, bytes, or str chunk.

    Returns:
        A string ending in a blank line: ``line`` itself when it already starts
        with ``data:``, otherwise a ``data:`` frame holding the JSON of an
        OpenAI-style chunk built from ``line``.
    """
    # Structured values are serialised to JSON and given the SSE prefix.
    if isinstance(line, BaseModel):
        line = f'data: {line.model_dump_json()}'
    if isinstance(line, dict):
        line = f'data: {json.dumps(line)}'

    # Decode when possible; anything that cannot be decoded (for example a str,
    # which has no decode()) is deliberately left untouched.
    try:
        line = line.decode('utf-8')
    except Exception:
        pass

    if not line.startswith('data:'):
        chunk = openai_chat_chunk_message_template(form_data['model'], line)
        return f'data: {json.dumps(chunk)}\n\n'
    return f'{line}\n\n'
|
||||
|
||||
def get_pipe_id(form_data: dict) -> str:
    """Return the pipe id encoded in ``form_data['model']``.

    The id is everything before the first ``.``; a model id without a dot is
    returned unchanged.
    """
    model_id = form_data['model']
    return model_id.split('.', 1)[0] if '.' in model_id else model_id
|
||||
|
||||
def get_function_params(function_module, form_data, user, extra_params=None):
|
||||
@@ -202,27 +196,25 @@ async def generate_function_chat_completion(
|
||||
|
||||
# Get the signature of the function
|
||||
sig = inspect.signature(function_module.pipe)
|
||||
params = {"body": form_data} | {
|
||||
k: v for k, v in extra_params.items() if k in sig.parameters
|
||||
}
|
||||
params = {'body': form_data} | {k: v for k, v in extra_params.items() if k in sig.parameters}
|
||||
|
||||
if "__user__" in params and hasattr(function_module, "UserValves"):
|
||||
if '__user__' in params and hasattr(function_module, 'UserValves'):
|
||||
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
|
||||
try:
|
||||
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
|
||||
params['__user__']['valves'] = function_module.UserValves(**user_valves)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
params["__user__"]["valves"] = function_module.UserValves()
|
||||
params['__user__']['valves'] = function_module.UserValves()
|
||||
|
||||
return params
|
||||
|
||||
model_id = form_data.get("model")
|
||||
model_id = form_data.get('model')
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
||||
metadata = form_data.pop("metadata", {})
|
||||
metadata = form_data.pop('metadata', {})
|
||||
|
||||
files = metadata.get("files", [])
|
||||
tool_ids = metadata.get("tool_ids", [])
|
||||
files = metadata.get('files', [])
|
||||
tool_ids = metadata.get('tool_ids', [])
|
||||
# Check if tool_ids is None
|
||||
if tool_ids is None:
|
||||
tool_ids = []
|
||||
@@ -233,56 +225,56 @@ async def generate_function_chat_completion(
|
||||
__task_body__ = None
|
||||
|
||||
if metadata:
|
||||
if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
|
||||
if all(k in metadata for k in ('session_id', 'chat_id', 'message_id')):
|
||||
__event_emitter__ = get_event_emitter(metadata)
|
||||
__event_call__ = get_event_call(metadata)
|
||||
__task__ = metadata.get("task", None)
|
||||
__task_body__ = metadata.get("task_body", None)
|
||||
__task__ = metadata.get('task', None)
|
||||
__task_body__ = metadata.get('task_body', None)
|
||||
|
||||
oauth_token = None
|
||||
try:
|
||||
if request.cookies.get("oauth_session_id", None):
|
||||
if request.cookies.get('oauth_session_id', None):
|
||||
oauth_token = await request.app.state.oauth_manager.get_oauth_token(
|
||||
user.id,
|
||||
request.cookies.get("oauth_session_id", None),
|
||||
request.cookies.get('oauth_session_id', None),
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth token: {e}")
|
||||
log.error(f'Error getting OAuth token: {e}')
|
||||
|
||||
extra_params = {
|
||||
"__event_emitter__": __event_emitter__,
|
||||
"__event_call__": __event_call__,
|
||||
"__chat_id__": metadata.get("chat_id", None),
|
||||
"__session_id__": metadata.get("session_id", None),
|
||||
"__message_id__": metadata.get("message_id", None),
|
||||
"__task__": __task__,
|
||||
"__task_body__": __task_body__,
|
||||
"__files__": files,
|
||||
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
||||
"__metadata__": metadata,
|
||||
"__oauth_token__": oauth_token,
|
||||
"__request__": request,
|
||||
'__event_emitter__': __event_emitter__,
|
||||
'__event_call__': __event_call__,
|
||||
'__chat_id__': metadata.get('chat_id', None),
|
||||
'__session_id__': metadata.get('session_id', None),
|
||||
'__message_id__': metadata.get('message_id', None),
|
||||
'__task__': __task__,
|
||||
'__task_body__': __task_body__,
|
||||
'__files__': files,
|
||||
'__user__': user.model_dump() if isinstance(user, UserModel) else {},
|
||||
'__metadata__': metadata,
|
||||
'__oauth_token__': oauth_token,
|
||||
'__request__': request,
|
||||
}
|
||||
extra_params["__tools__"] = await get_tools(
|
||||
extra_params['__tools__'] = await get_tools(
|
||||
request,
|
||||
tool_ids,
|
||||
user,
|
||||
{
|
||||
**extra_params,
|
||||
"__model__": models.get(form_data["model"], None),
|
||||
"__messages__": form_data["messages"],
|
||||
"__files__": files,
|
||||
'__model__': models.get(form_data['model'], None),
|
||||
'__messages__': form_data['messages'],
|
||||
'__files__': files,
|
||||
},
|
||||
)
|
||||
|
||||
if model_info:
|
||||
if model_info.base_model_id:
|
||||
form_data["model"] = model_info.base_model_id
|
||||
form_data['model'] = model_info.base_model_id
|
||||
|
||||
params = model_info.params.model_dump()
|
||||
|
||||
if params:
|
||||
system = params.pop("system", None)
|
||||
system = params.pop('system', None)
|
||||
form_data = apply_model_params_to_body_openai(params, form_data)
|
||||
form_data = apply_system_prompt_to_body(system, form_data, metadata, user)
|
||||
|
||||
@@ -292,7 +284,7 @@ async def generate_function_chat_completion(
|
||||
pipe = function_module.pipe
|
||||
params = get_function_params(function_module, form_data, user, extra_params)
|
||||
|
||||
if form_data.get("stream", False):
|
||||
if form_data.get('stream', False):
|
||||
|
||||
async def stream_content():
|
||||
try:
|
||||
@@ -304,17 +296,17 @@ async def generate_function_chat_completion(
|
||||
yield data
|
||||
return
|
||||
if isinstance(res, dict):
|
||||
yield f"data: {json.dumps(res)}\n\n"
|
||||
yield f'data: {json.dumps(res)}\n\n'
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error: {e}")
|
||||
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
|
||||
log.error(f'Error: {e}')
|
||||
yield f'data: {json.dumps({"error": {"detail": str(e)}})}\n\n'
|
||||
return
|
||||
|
||||
if isinstance(res, str):
|
||||
message = openai_chat_chunk_message_template(form_data["model"], res)
|
||||
yield f"data: {json.dumps(message)}\n\n"
|
||||
message = openai_chat_chunk_message_template(form_data['model'], res)
|
||||
yield f'data: {json.dumps(message)}\n\n'
|
||||
|
||||
if isinstance(res, Iterator):
|
||||
for line in res:
|
||||
@@ -325,21 +317,19 @@ async def generate_function_chat_completion(
|
||||
yield process_line(form_data, line)
|
||||
|
||||
if isinstance(res, str) or isinstance(res, Generator):
|
||||
finish_message = openai_chat_chunk_message_template(
|
||||
form_data["model"], ""
|
||||
)
|
||||
finish_message["choices"][0]["finish_reason"] = "stop"
|
||||
yield f"data: {json.dumps(finish_message)}\n\n"
|
||||
yield "data: [DONE]"
|
||||
finish_message = openai_chat_chunk_message_template(form_data['model'], '')
|
||||
finish_message['choices'][0]['finish_reason'] = 'stop'
|
||||
yield f'data: {json.dumps(finish_message)}\n\n'
|
||||
yield 'data: [DONE]'
|
||||
|
||||
return StreamingResponse(stream_content(), media_type="text/event-stream")
|
||||
return StreamingResponse(stream_content(), media_type='text/event-stream')
|
||||
else:
|
||||
try:
|
||||
res = await execute_pipe(pipe, params)
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error: {e}")
|
||||
return {"error": {"detail": str(e)}}
|
||||
log.error(f'Error: {e}')
|
||||
return {'error': {'detail': str(e)}}
|
||||
|
||||
if isinstance(res, StreamingResponse) or isinstance(res, dict):
|
||||
return res
|
||||
@@ -347,4 +337,4 @@ async def generate_function_chat_completion(
|
||||
return res.model_dump()
|
||||
|
||||
message = await get_message_content(res)
|
||||
return openai_chat_completion_message_template(form_data["model"], message)
|
||||
return openai_chat_completion_message_template(form_data['model'], message)
|
||||
|
||||
@@ -56,17 +56,15 @@ def handle_peewee_migration(DATABASE_URL):
|
||||
# db = None
|
||||
try:
|
||||
# Replace the postgresql:// with postgres:// to handle the peewee migration
|
||||
db = register_connection(DATABASE_URL.replace("postgresql://", "postgres://"))
|
||||
migrate_dir = OPEN_WEBUI_DIR / "internal" / "migrations"
|
||||
db = register_connection(DATABASE_URL.replace('postgresql://', 'postgres://'))
|
||||
migrate_dir = OPEN_WEBUI_DIR / 'internal' / 'migrations'
|
||||
router = Router(db, logger=log, migrate_dir=migrate_dir)
|
||||
router.run()
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Failed to initialize the database connection: {e}")
|
||||
log.warning(
|
||||
"Hint: If your database password contains special characters, you may need to URL-encode it."
|
||||
)
|
||||
log.error(f'Failed to initialize the database connection: {e}')
|
||||
log.warning('Hint: If your database password contains special characters, you may need to URL-encode it.')
|
||||
raise
|
||||
finally:
|
||||
# Properly closing the database connection
|
||||
@@ -74,7 +72,7 @@ def handle_peewee_migration(DATABASE_URL):
|
||||
db.close()
|
||||
|
||||
# Assert if db connection has been closed
|
||||
assert db.is_closed(), "Database connection is still open."
|
||||
assert db.is_closed(), 'Database connection is still open.'
|
||||
|
||||
|
||||
if ENABLE_DB_MIGRATIONS:
|
||||
@@ -84,15 +82,13 @@ if ENABLE_DB_MIGRATIONS:
|
||||
SQLALCHEMY_DATABASE_URL = DATABASE_URL
|
||||
|
||||
# Handle SQLCipher URLs
|
||||
if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
|
||||
database_password = os.environ.get("DATABASE_PASSWORD")
|
||||
if not database_password or database_password.strip() == "":
|
||||
raise ValueError(
|
||||
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
|
||||
)
|
||||
if SQLALCHEMY_DATABASE_URL.startswith('sqlite+sqlcipher://'):
|
||||
database_password = os.environ.get('DATABASE_PASSWORD')
|
||||
if not database_password or database_password.strip() == '':
|
||||
raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs')
|
||||
|
||||
# Extract database path from SQLCipher URL
|
||||
db_path = SQLALCHEMY_DATABASE_URL.replace("sqlite+sqlcipher://", "")
|
||||
db_path = SQLALCHEMY_DATABASE_URL.replace('sqlite+sqlcipher://', '')
|
||||
|
||||
# Create a custom creator function that uses sqlcipher3
|
||||
def create_sqlcipher_connection():
|
||||
@@ -109,7 +105,7 @@ if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
|
||||
# or QueuePool if DATABASE_POOL_SIZE is explicitly configured.
|
||||
if isinstance(DATABASE_POOL_SIZE, int) and DATABASE_POOL_SIZE > 0:
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
'sqlite://',
|
||||
creator=create_sqlcipher_connection,
|
||||
pool_size=DATABASE_POOL_SIZE,
|
||||
max_overflow=DATABASE_POOL_MAX_OVERFLOW,
|
||||
@@ -121,28 +117,26 @@ if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
|
||||
)
|
||||
else:
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
'sqlite://',
|
||||
creator=create_sqlcipher_connection,
|
||||
poolclass=NullPool,
|
||||
echo=False,
|
||||
)
|
||||
|
||||
log.info("Connected to encrypted SQLite database using SQLCipher")
|
||||
log.info('Connected to encrypted SQLite database using SQLCipher')
|
||||
|
||||
elif "sqlite" in SQLALCHEMY_DATABASE_URL:
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
||||
)
|
||||
elif 'sqlite' in SQLALCHEMY_DATABASE_URL:
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={'check_same_thread': False})
|
||||
|
||||
def on_connect(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
if DATABASE_ENABLE_SQLITE_WAL:
|
||||
cursor.execute("PRAGMA journal_mode=WAL")
|
||||
cursor.execute('PRAGMA journal_mode=WAL')
|
||||
else:
|
||||
cursor.execute("PRAGMA journal_mode=DELETE")
|
||||
cursor.execute('PRAGMA journal_mode=DELETE')
|
||||
cursor.close()
|
||||
|
||||
event.listen(engine, "connect", on_connect)
|
||||
event.listen(engine, 'connect', on_connect)
|
||||
else:
|
||||
if isinstance(DATABASE_POOL_SIZE, int):
|
||||
if DATABASE_POOL_SIZE > 0:
|
||||
@@ -156,16 +150,12 @@ else:
|
||||
poolclass=QueuePool,
|
||||
)
|
||||
else:
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
|
||||
)
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool)
|
||||
else:
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
|
||||
|
||||
|
||||
SessionLocal = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
|
||||
)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
|
||||
metadata_obj = MetaData(schema=DATABASE_SCHEMA)
|
||||
Base = declarative_base(metadata=metadata_obj)
|
||||
ScopedSession = scoped_session(SessionLocal)
|
||||
|
||||
@@ -56,7 +56,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
active = pw.BooleanField()
|
||||
|
||||
class Meta:
|
||||
table_name = "auth"
|
||||
table_name = 'auth'
|
||||
|
||||
@migrator.create_model
|
||||
class Chat(pw.Model):
|
||||
@@ -67,7 +67,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "chat"
|
||||
table_name = 'chat'
|
||||
|
||||
@migrator.create_model
|
||||
class ChatIdTag(pw.Model):
|
||||
@@ -78,7 +78,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "chatidtag"
|
||||
table_name = 'chatidtag'
|
||||
|
||||
@migrator.create_model
|
||||
class Document(pw.Model):
|
||||
@@ -92,7 +92,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "document"
|
||||
table_name = 'document'
|
||||
|
||||
@migrator.create_model
|
||||
class Modelfile(pw.Model):
|
||||
@@ -103,7 +103,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "modelfile"
|
||||
table_name = 'modelfile'
|
||||
|
||||
@migrator.create_model
|
||||
class Prompt(pw.Model):
|
||||
@@ -115,7 +115,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "prompt"
|
||||
table_name = 'prompt'
|
||||
|
||||
@migrator.create_model
|
||||
class Tag(pw.Model):
|
||||
@@ -125,7 +125,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
data = pw.TextField(null=True)
|
||||
|
||||
class Meta:
|
||||
table_name = "tag"
|
||||
table_name = 'tag'
|
||||
|
||||
@migrator.create_model
|
||||
class User(pw.Model):
|
||||
@@ -137,7 +137,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "user"
|
||||
table_name = 'user'
|
||||
|
||||
|
||||
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
@@ -149,7 +149,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
active = pw.BooleanField()
|
||||
|
||||
class Meta:
|
||||
table_name = "auth"
|
||||
table_name = 'auth'
|
||||
|
||||
@migrator.create_model
|
||||
class Chat(pw.Model):
|
||||
@@ -160,7 +160,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "chat"
|
||||
table_name = 'chat'
|
||||
|
||||
@migrator.create_model
|
||||
class ChatIdTag(pw.Model):
|
||||
@@ -171,7 +171,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "chatidtag"
|
||||
table_name = 'chatidtag'
|
||||
|
||||
@migrator.create_model
|
||||
class Document(pw.Model):
|
||||
@@ -185,7 +185,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "document"
|
||||
table_name = 'document'
|
||||
|
||||
@migrator.create_model
|
||||
class Modelfile(pw.Model):
|
||||
@@ -196,7 +196,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "modelfile"
|
||||
table_name = 'modelfile'
|
||||
|
||||
@migrator.create_model
|
||||
class Prompt(pw.Model):
|
||||
@@ -208,7 +208,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "prompt"
|
||||
table_name = 'prompt'
|
||||
|
||||
@migrator.create_model
|
||||
class Tag(pw.Model):
|
||||
@@ -218,7 +218,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
data = pw.TextField(null=True)
|
||||
|
||||
class Meta:
|
||||
table_name = "tag"
|
||||
table_name = 'tag'
|
||||
|
||||
@migrator.create_model
|
||||
class User(pw.Model):
|
||||
@@ -230,24 +230,24 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
timestamp = pw.BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
table_name = "user"
|
||||
table_name = 'user'
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_model("user")
|
||||
migrator.remove_model('user')
|
||||
|
||||
migrator.remove_model("tag")
|
||||
migrator.remove_model('tag')
|
||||
|
||||
migrator.remove_model("prompt")
|
||||
migrator.remove_model('prompt')
|
||||
|
||||
migrator.remove_model("modelfile")
|
||||
migrator.remove_model('modelfile')
|
||||
|
||||
migrator.remove_model("document")
|
||||
migrator.remove_model('document')
|
||||
|
||||
migrator.remove_model("chatidtag")
|
||||
migrator.remove_model('chatidtag')
|
||||
|
||||
migrator.remove_model("chat")
|
||||
migrator.remove_model('chat')
|
||||
|
||||
migrator.remove_model("auth")
|
||||
migrator.remove_model('auth')
|
||||
|
||||
@@ -36,12 +36,10 @@ with suppress(ImportError):
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
migrator.add_fields(
|
||||
"chat", share_id=pw.CharField(max_length=255, null=True, unique=True)
|
||||
)
|
||||
migrator.add_fields('chat', share_id=pw.CharField(max_length=255, null=True, unique=True))
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_fields("chat", "share_id")
|
||||
migrator.remove_fields('chat', 'share_id')
|
||||
|
||||
@@ -36,12 +36,10 @@ with suppress(ImportError):
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
migrator.add_fields(
|
||||
"user", api_key=pw.CharField(max_length=255, null=True, unique=True)
|
||||
)
|
||||
migrator.add_fields('user', api_key=pw.CharField(max_length=255, null=True, unique=True))
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_fields("user", "api_key")
|
||||
migrator.remove_fields('user', 'api_key')
|
||||
|
||||
@@ -36,10 +36,10 @@ with suppress(ImportError):
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
migrator.add_fields("chat", archived=pw.BooleanField(default=False))
|
||||
migrator.add_fields('chat', archived=pw.BooleanField(default=False))
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_fields("chat", "archived")
|
||||
migrator.remove_fields('chat', 'archived')
|
||||
|
||||
@@ -45,22 +45,20 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
# Adding fields created_at and updated_at to the 'chat' table
|
||||
migrator.add_fields(
|
||||
"chat",
|
||||
'chat',
|
||||
created_at=pw.DateTimeField(null=True), # Allow null for transition
|
||||
updated_at=pw.DateTimeField(null=True), # Allow null for transition
|
||||
)
|
||||
|
||||
# Populate the new fields from an existing 'timestamp' field
|
||||
migrator.sql(
|
||||
"UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
|
||||
)
|
||||
migrator.sql('UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL')
|
||||
|
||||
# Now that the data has been copied, remove the original 'timestamp' field
|
||||
migrator.remove_fields("chat", "timestamp")
|
||||
migrator.remove_fields('chat', 'timestamp')
|
||||
|
||||
# Update the fields to be not null now that they are populated
|
||||
migrator.change_fields(
|
||||
"chat",
|
||||
'chat',
|
||||
created_at=pw.DateTimeField(null=False),
|
||||
updated_at=pw.DateTimeField(null=False),
|
||||
)
|
||||
@@ -69,22 +67,20 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
# Adding fields created_at and updated_at to the 'chat' table
|
||||
migrator.add_fields(
|
||||
"chat",
|
||||
'chat',
|
||||
created_at=pw.BigIntegerField(null=True), # Allow null for transition
|
||||
updated_at=pw.BigIntegerField(null=True), # Allow null for transition
|
||||
)
|
||||
|
||||
# Populate the new fields from an existing 'timestamp' field
|
||||
migrator.sql(
|
||||
"UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
|
||||
)
|
||||
migrator.sql('UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL')
|
||||
|
||||
# Now that the data has been copied, remove the original 'timestamp' field
|
||||
migrator.remove_fields("chat", "timestamp")
|
||||
migrator.remove_fields('chat', 'timestamp')
|
||||
|
||||
# Update the fields to be not null now that they are populated
|
||||
migrator.change_fields(
|
||||
"chat",
|
||||
'chat',
|
||||
created_at=pw.BigIntegerField(null=False),
|
||||
updated_at=pw.BigIntegerField(null=False),
|
||||
)
|
||||
@@ -101,29 +97,29 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
|
||||
def rollback_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
# Recreate the timestamp field initially allowing null values for safe transition
|
||||
migrator.add_fields("chat", timestamp=pw.DateTimeField(null=True))
|
||||
migrator.add_fields('chat', timestamp=pw.DateTimeField(null=True))
|
||||
|
||||
# Copy the earliest created_at date back into the new timestamp field
|
||||
# This assumes created_at was originally a copy of timestamp
|
||||
migrator.sql("UPDATE chat SET timestamp = created_at")
|
||||
migrator.sql('UPDATE chat SET timestamp = created_at')
|
||||
|
||||
# Remove the created_at and updated_at fields
|
||||
migrator.remove_fields("chat", "created_at", "updated_at")
|
||||
migrator.remove_fields('chat', 'created_at', 'updated_at')
|
||||
|
||||
# Finally, alter the timestamp field to not allow nulls if that was the original setting
|
||||
migrator.change_fields("chat", timestamp=pw.DateTimeField(null=False))
|
||||
migrator.change_fields('chat', timestamp=pw.DateTimeField(null=False))
|
||||
|
||||
|
||||
def rollback_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
# Recreate the timestamp field initially allowing null values for safe transition
|
||||
migrator.add_fields("chat", timestamp=pw.BigIntegerField(null=True))
|
||||
migrator.add_fields('chat', timestamp=pw.BigIntegerField(null=True))
|
||||
|
||||
# Copy the earliest created_at date back into the new timestamp field
|
||||
# This assumes created_at was originally a copy of timestamp
|
||||
migrator.sql("UPDATE chat SET timestamp = created_at")
|
||||
migrator.sql('UPDATE chat SET timestamp = created_at')
|
||||
|
||||
# Remove the created_at and updated_at fields
|
||||
migrator.remove_fields("chat", "created_at", "updated_at")
|
||||
migrator.remove_fields('chat', 'created_at', 'updated_at')
|
||||
|
||||
# Finally, alter the timestamp field to not allow nulls if that was the original setting
|
||||
migrator.change_fields("chat", timestamp=pw.BigIntegerField(null=False))
|
||||
migrator.change_fields('chat', timestamp=pw.BigIntegerField(null=False))
|
||||
|
||||
@@ -38,45 +38,45 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
|
||||
# Alter the tables with timestamps
|
||||
migrator.change_fields(
|
||||
"chatidtag",
|
||||
'chatidtag',
|
||||
timestamp=pw.BigIntegerField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"document",
|
||||
'document',
|
||||
timestamp=pw.BigIntegerField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"modelfile",
|
||||
'modelfile',
|
||||
timestamp=pw.BigIntegerField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"prompt",
|
||||
'prompt',
|
||||
timestamp=pw.BigIntegerField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"user",
|
||||
'user',
|
||||
timestamp=pw.BigIntegerField(),
|
||||
)
|
||||
# Alter the tables with varchar to text where necessary
|
||||
migrator.change_fields(
|
||||
"auth",
|
||||
'auth',
|
||||
password=pw.TextField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"chat",
|
||||
'chat',
|
||||
title=pw.TextField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"document",
|
||||
'document',
|
||||
title=pw.TextField(),
|
||||
filename=pw.TextField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"prompt",
|
||||
'prompt',
|
||||
title=pw.TextField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"user",
|
||||
'user',
|
||||
profile_image_url=pw.TextField(),
|
||||
)
|
||||
|
||||
@@ -87,43 +87,43 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
if isinstance(database, pw.SqliteDatabase):
|
||||
# Alter the tables with timestamps
|
||||
migrator.change_fields(
|
||||
"chatidtag",
|
||||
'chatidtag',
|
||||
timestamp=pw.DateField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"document",
|
||||
'document',
|
||||
timestamp=pw.DateField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"modelfile",
|
||||
'modelfile',
|
||||
timestamp=pw.DateField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"prompt",
|
||||
'prompt',
|
||||
timestamp=pw.DateField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"user",
|
||||
'user',
|
||||
timestamp=pw.DateField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"auth",
|
||||
'auth',
|
||||
password=pw.CharField(max_length=255),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"chat",
|
||||
'chat',
|
||||
title=pw.CharField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"document",
|
||||
'document',
|
||||
title=pw.CharField(),
|
||||
filename=pw.CharField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"prompt",
|
||||
'prompt',
|
||||
title=pw.CharField(),
|
||||
)
|
||||
migrator.change_fields(
|
||||
"user",
|
||||
'user',
|
||||
profile_image_url=pw.CharField(),
|
||||
)
|
||||
|
||||
@@ -38,7 +38,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
|
||||
# Adding fields created_at and updated_at to the 'user' table
|
||||
migrator.add_fields(
|
||||
"user",
|
||||
'user',
|
||||
created_at=pw.BigIntegerField(null=True), # Allow null for transition
|
||||
updated_at=pw.BigIntegerField(null=True), # Allow null for transition
|
||||
last_active_at=pw.BigIntegerField(null=True), # Allow null for transition
|
||||
@@ -50,11 +50,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
)
|
||||
|
||||
# Now that the data has been copied, remove the original 'timestamp' field
|
||||
migrator.remove_fields("user", "timestamp")
|
||||
migrator.remove_fields('user', 'timestamp')
|
||||
|
||||
# Update the fields to be not null now that they are populated
|
||||
migrator.change_fields(
|
||||
"user",
|
||||
'user',
|
||||
created_at=pw.BigIntegerField(null=False),
|
||||
updated_at=pw.BigIntegerField(null=False),
|
||||
last_active_at=pw.BigIntegerField(null=False),
|
||||
@@ -65,14 +65,14 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
# Recreate the timestamp field initially allowing null values for safe transition
|
||||
migrator.add_fields("user", timestamp=pw.BigIntegerField(null=True))
|
||||
migrator.add_fields('user', timestamp=pw.BigIntegerField(null=True))
|
||||
|
||||
# Copy the earliest created_at date back into the new timestamp field
|
||||
# This assumes created_at was originally a copy of timestamp
|
||||
migrator.sql('UPDATE "user" SET timestamp = created_at')
|
||||
|
||||
# Remove the created_at and updated_at fields
|
||||
migrator.remove_fields("user", "created_at", "updated_at", "last_active_at")
|
||||
migrator.remove_fields('user', 'created_at', 'updated_at', 'last_active_at')
|
||||
|
||||
# Finally, alter the timestamp field to not allow nulls if that was the original setting
|
||||
migrator.change_fields("user", timestamp=pw.BigIntegerField(null=False))
|
||||
migrator.change_fields('user', timestamp=pw.BigIntegerField(null=False))
|
||||
|
||||
@@ -43,10 +43,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
created_at = pw.BigIntegerField(null=False)
|
||||
|
||||
class Meta:
|
||||
table_name = "memory"
|
||||
table_name = 'memory'
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_model("memory")
|
||||
migrator.remove_model('memory')
|
||||
|
||||
@@ -51,10 +51,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
updated_at = pw.BigIntegerField(null=False)
|
||||
|
||||
class Meta:
|
||||
table_name = "model"
|
||||
table_name = 'model'
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_model("model")
|
||||
migrator.remove_model('model')
|
||||
|
||||
@@ -42,12 +42,12 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
# Fetch data from 'modelfile' table and insert into 'model' table
|
||||
migrate_modelfile_to_model(migrator, database)
|
||||
# Drop the 'modelfile' table
|
||||
migrator.remove_model("modelfile")
|
||||
migrator.remove_model('modelfile')
|
||||
|
||||
|
||||
def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database):
|
||||
ModelFile = migrator.orm["modelfile"]
|
||||
Model = migrator.orm["model"]
|
||||
ModelFile = migrator.orm['modelfile']
|
||||
Model = migrator.orm['model']
|
||||
|
||||
modelfiles = ModelFile.select()
|
||||
|
||||
@@ -57,25 +57,25 @@ def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database):
|
||||
modelfile.modelfile = json.loads(modelfile.modelfile)
|
||||
meta = json.dumps(
|
||||
{
|
||||
"description": modelfile.modelfile.get("desc"),
|
||||
"profile_image_url": modelfile.modelfile.get("imageUrl"),
|
||||
"ollama": {"modelfile": modelfile.modelfile.get("content")},
|
||||
"suggestion_prompts": modelfile.modelfile.get("suggestionPrompts"),
|
||||
"categories": modelfile.modelfile.get("categories"),
|
||||
"user": {**modelfile.modelfile.get("user", {}), "community": True},
|
||||
'description': modelfile.modelfile.get('desc'),
|
||||
'profile_image_url': modelfile.modelfile.get('imageUrl'),
|
||||
'ollama': {'modelfile': modelfile.modelfile.get('content')},
|
||||
'suggestion_prompts': modelfile.modelfile.get('suggestionPrompts'),
|
||||
'categories': modelfile.modelfile.get('categories'),
|
||||
'user': {**modelfile.modelfile.get('user', {}), 'community': True},
|
||||
}
|
||||
)
|
||||
|
||||
info = parse_ollama_modelfile(modelfile.modelfile.get("content"))
|
||||
info = parse_ollama_modelfile(modelfile.modelfile.get('content'))
|
||||
|
||||
# Insert the processed data into the 'model' table
|
||||
Model.create(
|
||||
id=f"ollama-{modelfile.tag_name}",
|
||||
id=f'ollama-{modelfile.tag_name}',
|
||||
user_id=modelfile.user_id,
|
||||
base_model_id=info.get("base_model_id"),
|
||||
name=modelfile.modelfile.get("title"),
|
||||
base_model_id=info.get('base_model_id'),
|
||||
name=modelfile.modelfile.get('title'),
|
||||
meta=meta,
|
||||
params=json.dumps(info.get("params", {})),
|
||||
params=json.dumps(info.get('params', {})),
|
||||
created_at=modelfile.timestamp,
|
||||
updated_at=modelfile.timestamp,
|
||||
)
|
||||
@@ -86,7 +86,7 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
|
||||
recreate_modelfile_table(migrator, database)
|
||||
move_data_back_to_modelfile(migrator, database)
|
||||
migrator.remove_model("model")
|
||||
migrator.remove_model('model')
|
||||
|
||||
|
||||
def recreate_modelfile_table(migrator: Migrator, database: pw.Database):
|
||||
@@ -102,8 +102,8 @@ def recreate_modelfile_table(migrator: Migrator, database: pw.Database):
|
||||
|
||||
|
||||
def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database):
|
||||
Model = migrator.orm["model"]
|
||||
Modelfile = migrator.orm["modelfile"]
|
||||
Model = migrator.orm['model']
|
||||
Modelfile = migrator.orm['modelfile']
|
||||
|
||||
models = Model.select()
|
||||
|
||||
@@ -112,13 +112,13 @@ def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database):
|
||||
meta = json.loads(model.meta)
|
||||
|
||||
modelfile_data = {
|
||||
"title": model.name,
|
||||
"desc": meta.get("description"),
|
||||
"imageUrl": meta.get("profile_image_url"),
|
||||
"content": meta.get("ollama", {}).get("modelfile"),
|
||||
"suggestionPrompts": meta.get("suggestion_prompts"),
|
||||
"categories": meta.get("categories"),
|
||||
"user": {k: v for k, v in meta.get("user", {}).items() if k != "community"},
|
||||
'title': model.name,
|
||||
'desc': meta.get('description'),
|
||||
'imageUrl': meta.get('profile_image_url'),
|
||||
'content': meta.get('ollama', {}).get('modelfile'),
|
||||
'suggestionPrompts': meta.get('suggestion_prompts'),
|
||||
'categories': meta.get('categories'),
|
||||
'user': {k: v for k, v in meta.get('user', {}).items() if k != 'community'},
|
||||
}
|
||||
|
||||
# Insert the processed data back into the 'modelfile' table
|
||||
|
||||
@@ -37,11 +37,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
# Adding fields settings to the 'user' table
|
||||
migrator.add_fields("user", settings=pw.TextField(null=True))
|
||||
migrator.add_fields('user', settings=pw.TextField(null=True))
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
# Remove the settings field
|
||||
migrator.remove_fields("user", "settings")
|
||||
migrator.remove_fields('user', 'settings')
|
||||
|
||||
@@ -51,10 +51,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
updated_at = pw.BigIntegerField(null=False)
|
||||
|
||||
class Meta:
|
||||
table_name = "tool"
|
||||
table_name = 'tool'
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_model("tool")
|
||||
migrator.remove_model('tool')
|
||||
|
||||
@@ -37,11 +37,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
# Adding fields info to the 'user' table
|
||||
migrator.add_fields("user", info=pw.TextField(null=True))
|
||||
migrator.add_fields('user', info=pw.TextField(null=True))
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
# Remove the settings field
|
||||
migrator.remove_fields("user", "info")
|
||||
migrator.remove_fields('user', 'info')
|
||||
|
||||
@@ -45,10 +45,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
created_at = pw.BigIntegerField(null=False)
|
||||
|
||||
class Meta:
|
||||
table_name = "file"
|
||||
table_name = 'file'
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_model("file")
|
||||
migrator.remove_model('file')
|
||||
|
||||
@@ -51,10 +51,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
updated_at = pw.BigIntegerField(null=False)
|
||||
|
||||
class Meta:
|
||||
table_name = "function"
|
||||
table_name = 'function'
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_model("function")
|
||||
migrator.remove_model('function')
|
||||
|
||||
@@ -36,14 +36,14 @@ with suppress(ImportError):
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
migrator.add_fields("tool", valves=pw.TextField(null=True))
|
||||
migrator.add_fields("function", valves=pw.TextField(null=True))
|
||||
migrator.add_fields("function", is_active=pw.BooleanField(default=False))
|
||||
migrator.add_fields('tool', valves=pw.TextField(null=True))
|
||||
migrator.add_fields('function', valves=pw.TextField(null=True))
|
||||
migrator.add_fields('function', is_active=pw.BooleanField(default=False))
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_fields("tool", "valves")
|
||||
migrator.remove_fields("function", "valves")
|
||||
migrator.remove_fields("function", "is_active")
|
||||
migrator.remove_fields('tool', 'valves')
|
||||
migrator.remove_fields('function', 'valves')
|
||||
migrator.remove_fields('function', 'is_active')
|
||||
|
||||
@@ -33,7 +33,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
migrator.add_fields(
|
||||
"user",
|
||||
'user',
|
||||
oauth_sub=pw.TextField(null=True, unique=True),
|
||||
)
|
||||
|
||||
@@ -41,4 +41,4 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_fields("user", "oauth_sub")
|
||||
migrator.remove_fields('user', 'oauth_sub')
|
||||
|
||||
@@ -37,7 +37,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
migrator.add_fields(
|
||||
"function",
|
||||
'function',
|
||||
is_global=pw.BooleanField(default=False),
|
||||
)
|
||||
|
||||
@@ -45,4 +45,4 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_fields("function", "is_global")
|
||||
migrator.remove_fields('function', 'is_global')
|
||||
|
||||
@@ -10,13 +10,13 @@ from playhouse.shortcuts import ReconnectMixin
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
|
||||
db_state = ContextVar("db_state", default=db_state_default.copy())
|
||||
db_state_default = {'closed': None, 'conn': None, 'ctx': None, 'transactions': None}
|
||||
db_state = ContextVar('db_state', default=db_state_default.copy())
|
||||
|
||||
|
||||
class PeeweeConnectionState(object):
|
||||
def __init__(self, **kwargs):
|
||||
super().__setattr__("_state", db_state)
|
||||
super().__setattr__('_state', db_state)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
@@ -30,10 +30,10 @@ class PeeweeConnectionState(object):
|
||||
class CustomReconnectMixin(ReconnectMixin):
|
||||
reconnect_errors = (
|
||||
# psycopg2
|
||||
(OperationalError, "termin"),
|
||||
(InterfaceError, "closed"),
|
||||
(OperationalError, 'termin'),
|
||||
(InterfaceError, 'closed'),
|
||||
# peewee
|
||||
(PeeWeeInterfaceError, "closed"),
|
||||
(PeeWeeInterfaceError, 'closed'),
|
||||
)
|
||||
|
||||
|
||||
@@ -43,23 +43,21 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
|
||||
|
||||
def register_connection(db_url):
|
||||
# Check if using SQLCipher protocol
|
||||
if db_url.startswith("sqlite+sqlcipher://"):
|
||||
database_password = os.environ.get("DATABASE_PASSWORD")
|
||||
if not database_password or database_password.strip() == "":
|
||||
raise ValueError(
|
||||
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
|
||||
)
|
||||
if db_url.startswith('sqlite+sqlcipher://'):
|
||||
database_password = os.environ.get('DATABASE_PASSWORD')
|
||||
if not database_password or database_password.strip() == '':
|
||||
raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs')
|
||||
from playhouse.sqlcipher_ext import SqlCipherDatabase
|
||||
|
||||
# Parse the database path from SQLCipher URL
|
||||
# Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite
|
||||
db_path = db_url.replace("sqlite+sqlcipher://", "")
|
||||
db_path = db_url.replace('sqlite+sqlcipher://', '')
|
||||
|
||||
# Use Peewee's native SqlCipherDatabase with encryption
|
||||
db = SqlCipherDatabase(db_path, passphrase=database_password)
|
||||
db.autoconnect = True
|
||||
db.reuse_if_open = True
|
||||
log.info("Connected to encrypted SQLite database using SQLCipher")
|
||||
log.info('Connected to encrypted SQLite database using SQLCipher')
|
||||
|
||||
else:
|
||||
# Standard database connection (existing logic)
|
||||
@@ -68,7 +66,7 @@ def register_connection(db_url):
|
||||
# Enable autoconnect for SQLite databases, managed by Peewee
|
||||
db.autoconnect = True
|
||||
db.reuse_if_open = True
|
||||
log.info("Connected to PostgreSQL database")
|
||||
log.info('Connected to PostgreSQL database')
|
||||
|
||||
# Get the connection details
|
||||
connection = parse(db_url, unquote_user=True, unquote_password=True)
|
||||
@@ -80,7 +78,7 @@ def register_connection(db_url):
|
||||
# Enable autoconnect for SQLite databases, managed by Peewee
|
||||
db.autoconnect = True
|
||||
db.reuse_if_open = True
|
||||
log.info("Connected to SQLite database")
|
||||
log.info('Connected to SQLite database')
|
||||
else:
|
||||
raise ValueError("Unsupported database connection")
|
||||
raise ValueError('Unsupported database connection')
|
||||
return db
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -16,7 +16,7 @@ if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name, disable_existing_loggers=False)
|
||||
|
||||
# Re-apply JSON formatter after fileConfig replaces handlers.
|
||||
if LOG_FORMAT == "json":
|
||||
if LOG_FORMAT == 'json':
|
||||
from open_webui.env import JSONFormatter
|
||||
|
||||
for handler in logging.root.handlers:
|
||||
@@ -36,7 +36,7 @@ target_metadata = Auth.metadata
|
||||
DB_URL = DATABASE_URL
|
||||
|
||||
if DB_URL:
|
||||
config.set_main_option("sqlalchemy.url", DB_URL.replace("%", "%%"))
|
||||
config.set_main_option('sqlalchemy.url', DB_URL.replace('%', '%%'))
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
@@ -51,12 +51,12 @@ def run_migrations_offline() -> None:
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
url = config.get_main_option('sqlalchemy.url')
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
dialect_opts={'paramstyle': 'named'},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
@@ -71,15 +71,13 @@ def run_migrations_online() -> None:
|
||||
|
||||
"""
|
||||
# Handle SQLCipher URLs
|
||||
if DB_URL and DB_URL.startswith("sqlite+sqlcipher://"):
|
||||
if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == "":
|
||||
raise ValueError(
|
||||
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
|
||||
)
|
||||
if DB_URL and DB_URL.startswith('sqlite+sqlcipher://'):
|
||||
if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == '':
|
||||
raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs')
|
||||
|
||||
# Extract database path from SQLCipher URL
|
||||
db_path = DB_URL.replace("sqlite+sqlcipher://", "")
|
||||
if db_path.startswith("/"):
|
||||
db_path = DB_URL.replace('sqlite+sqlcipher://', '')
|
||||
if db_path.startswith('/'):
|
||||
db_path = db_path[1:] # Remove leading slash for relative paths
|
||||
|
||||
# Create a custom creator function that uses sqlcipher3
|
||||
@@ -91,7 +89,7 @@ def run_migrations_online() -> None:
|
||||
return conn
|
||||
|
||||
connectable = create_engine(
|
||||
"sqlite://", # Dummy URL since we're using creator
|
||||
'sqlite://', # Dummy URL since we're using creator
|
||||
creator=create_sqlcipher_connection,
|
||||
echo=False,
|
||||
)
|
||||
@@ -99,7 +97,7 @@ def run_migrations_online() -> None:
|
||||
# Standard database connection (existing logic)
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
prefix='sqlalchemy.',
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
|
||||
@@ -12,4 +12,4 @@ def get_existing_tables():
|
||||
def get_revision_id():
|
||||
import uuid
|
||||
|
||||
return str(uuid.uuid4()).replace("-", "")[:12]
|
||||
return str(uuid.uuid4()).replace('-', '')[:12]
|
||||
|
||||
@@ -9,38 +9,38 @@ Create Date: 2025-08-13 03:00:00.000000
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "018012973d35"
|
||||
down_revision = "d31026856c01"
|
||||
revision = '018012973d35'
|
||||
down_revision = 'd31026856c01'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Chat table indexes
|
||||
op.create_index("folder_id_idx", "chat", ["folder_id"])
|
||||
op.create_index("user_id_pinned_idx", "chat", ["user_id", "pinned"])
|
||||
op.create_index("user_id_archived_idx", "chat", ["user_id", "archived"])
|
||||
op.create_index("updated_at_user_id_idx", "chat", ["updated_at", "user_id"])
|
||||
op.create_index("folder_id_user_id_idx", "chat", ["folder_id", "user_id"])
|
||||
op.create_index('folder_id_idx', 'chat', ['folder_id'])
|
||||
op.create_index('user_id_pinned_idx', 'chat', ['user_id', 'pinned'])
|
||||
op.create_index('user_id_archived_idx', 'chat', ['user_id', 'archived'])
|
||||
op.create_index('updated_at_user_id_idx', 'chat', ['updated_at', 'user_id'])
|
||||
op.create_index('folder_id_user_id_idx', 'chat', ['folder_id', 'user_id'])
|
||||
|
||||
# Tag table index
|
||||
op.create_index("user_id_idx", "tag", ["user_id"])
|
||||
op.create_index('user_id_idx', 'tag', ['user_id'])
|
||||
|
||||
# Function table index
|
||||
op.create_index("is_global_idx", "function", ["is_global"])
|
||||
op.create_index('is_global_idx', 'function', ['is_global'])
|
||||
|
||||
|
||||
def downgrade():
|
||||
# Chat table indexes
|
||||
op.drop_index("folder_id_idx", table_name="chat")
|
||||
op.drop_index("user_id_pinned_idx", table_name="chat")
|
||||
op.drop_index("user_id_archived_idx", table_name="chat")
|
||||
op.drop_index("updated_at_user_id_idx", table_name="chat")
|
||||
op.drop_index("folder_id_user_id_idx", table_name="chat")
|
||||
op.drop_index('folder_id_idx', table_name='chat')
|
||||
op.drop_index('user_id_pinned_idx', table_name='chat')
|
||||
op.drop_index('user_id_archived_idx', table_name='chat')
|
||||
op.drop_index('updated_at_user_id_idx', table_name='chat')
|
||||
op.drop_index('folder_id_user_id_idx', table_name='chat')
|
||||
|
||||
# Tag table index
|
||||
op.drop_index("user_id_idx", table_name="tag")
|
||||
op.drop_index('user_id_idx', table_name='tag')
|
||||
|
||||
# Function table index
|
||||
|
||||
op.drop_index("is_global_idx", table_name="function")
|
||||
op.drop_index('is_global_idx', table_name='function')
|
||||
|
||||
@@ -13,8 +13,8 @@ from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
import json
|
||||
|
||||
revision = "1af9b942657b"
|
||||
down_revision = "242a2047eae0"
|
||||
revision = '1af9b942657b'
|
||||
down_revision = '242a2047eae0'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
@@ -25,43 +25,40 @@ def upgrade():
|
||||
inspector = Inspector.from_engine(conn)
|
||||
|
||||
# Clean up potential leftover temp table from previous failures
|
||||
conn.execute(sa.text("DROP TABLE IF EXISTS _alembic_tmp_tag"))
|
||||
conn.execute(sa.text('DROP TABLE IF EXISTS _alembic_tmp_tag'))
|
||||
|
||||
# Check if the 'tag' table exists
|
||||
tables = inspector.get_table_names()
|
||||
|
||||
# Step 1: Modify Tag table using batch mode for SQLite support
|
||||
if "tag" in tables:
|
||||
if 'tag' in tables:
|
||||
# Get the current columns in the 'tag' table
|
||||
columns = [col["name"] for col in inspector.get_columns("tag")]
|
||||
columns = [col['name'] for col in inspector.get_columns('tag')]
|
||||
|
||||
# Get any existing unique constraints on the 'tag' table
|
||||
current_constraints = inspector.get_unique_constraints("tag")
|
||||
current_constraints = inspector.get_unique_constraints('tag')
|
||||
|
||||
with op.batch_alter_table("tag", schema=None) as batch_op:
|
||||
with op.batch_alter_table('tag', schema=None) as batch_op:
|
||||
# Check if the unique constraint already exists
|
||||
if not any(
|
||||
constraint["name"] == "uq_id_user_id"
|
||||
for constraint in current_constraints
|
||||
):
|
||||
if not any(constraint['name'] == 'uq_id_user_id' for constraint in current_constraints):
|
||||
# Create unique constraint if it doesn't exist
|
||||
batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"])
|
||||
batch_op.create_unique_constraint('uq_id_user_id', ['id', 'user_id'])
|
||||
|
||||
# Check if the 'data' column exists before trying to drop it
|
||||
if "data" in columns:
|
||||
batch_op.drop_column("data")
|
||||
if 'data' in columns:
|
||||
batch_op.drop_column('data')
|
||||
|
||||
# Check if the 'meta' column needs to be created
|
||||
if "meta" not in columns:
|
||||
if 'meta' not in columns:
|
||||
# Add the 'meta' column if it doesn't already exist
|
||||
batch_op.add_column(sa.Column("meta", sa.JSON(), nullable=True))
|
||||
batch_op.add_column(sa.Column('meta', sa.JSON(), nullable=True))
|
||||
|
||||
tag = table(
|
||||
"tag",
|
||||
column("id", sa.String()),
|
||||
column("name", sa.String()),
|
||||
column("user_id", sa.String()),
|
||||
column("meta", sa.JSON()),
|
||||
'tag',
|
||||
column('id', sa.String()),
|
||||
column('name', sa.String()),
|
||||
column('user_id', sa.String()),
|
||||
column('meta', sa.JSON()),
|
||||
)
|
||||
|
||||
# Step 2: Migrate tags
|
||||
@@ -70,12 +67,12 @@ def upgrade():
|
||||
|
||||
tag_updates = {}
|
||||
for row in result:
|
||||
new_id = row.name.replace(" ", "_").lower()
|
||||
new_id = row.name.replace(' ', '_').lower()
|
||||
tag_updates[row.id] = new_id
|
||||
|
||||
for tag_id, new_tag_id in tag_updates.items():
|
||||
print(f"Updating tag {tag_id} to {new_tag_id}")
|
||||
if new_tag_id == "pinned":
|
||||
print(f'Updating tag {tag_id} to {new_tag_id}')
|
||||
if new_tag_id == 'pinned':
|
||||
# delete tag
|
||||
delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
|
||||
conn.execute(delete_stmt)
|
||||
@@ -86,9 +83,7 @@ def upgrade():
|
||||
|
||||
if existing_tag_result:
|
||||
# Handle duplicate case: the new_tag_id already exists
|
||||
print(
|
||||
f"Tag {new_tag_id} already exists. Removing current tag with ID {tag_id} to avoid duplicates."
|
||||
)
|
||||
print(f'Tag {new_tag_id} already exists. Removing current tag with ID {tag_id} to avoid duplicates.')
|
||||
# Option 1: Delete the current tag if an update to new_tag_id would cause duplication
|
||||
delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
|
||||
conn.execute(delete_stmt)
|
||||
@@ -98,19 +93,15 @@ def upgrade():
|
||||
conn.execute(update_stmt)
|
||||
|
||||
# Add columns `pinned` and `meta` to 'chat'
|
||||
op.add_column("chat", sa.Column("pinned", sa.Boolean(), nullable=True))
|
||||
op.add_column(
|
||||
"chat", sa.Column("meta", sa.JSON(), nullable=False, server_default="{}")
|
||||
)
|
||||
op.add_column('chat', sa.Column('pinned', sa.Boolean(), nullable=True))
|
||||
op.add_column('chat', sa.Column('meta', sa.JSON(), nullable=False, server_default='{}'))
|
||||
|
||||
chatidtag = table(
|
||||
"chatidtag", column("chat_id", sa.String()), column("tag_name", sa.String())
|
||||
)
|
||||
chatidtag = table('chatidtag', column('chat_id', sa.String()), column('tag_name', sa.String()))
|
||||
chat = table(
|
||||
"chat",
|
||||
column("id", sa.String()),
|
||||
column("pinned", sa.Boolean()),
|
||||
column("meta", sa.JSON()),
|
||||
'chat',
|
||||
column('id', sa.String()),
|
||||
column('pinned', sa.Boolean()),
|
||||
column('meta', sa.JSON()),
|
||||
)
|
||||
|
||||
# Fetch existing tags
|
||||
@@ -120,29 +111,27 @@ def upgrade():
|
||||
chat_updates = {}
|
||||
for row in result:
|
||||
chat_id = row.chat_id
|
||||
tag_name = row.tag_name.replace(" ", "_").lower()
|
||||
tag_name = row.tag_name.replace(' ', '_').lower()
|
||||
|
||||
if tag_name == "pinned":
|
||||
if tag_name == 'pinned':
|
||||
# Specifically handle 'pinned' tag
|
||||
if chat_id not in chat_updates:
|
||||
chat_updates[chat_id] = {"pinned": True, "meta": {}}
|
||||
chat_updates[chat_id] = {'pinned': True, 'meta': {}}
|
||||
else:
|
||||
chat_updates[chat_id]["pinned"] = True
|
||||
chat_updates[chat_id]['pinned'] = True
|
||||
else:
|
||||
if chat_id not in chat_updates:
|
||||
chat_updates[chat_id] = {"pinned": False, "meta": {"tags": [tag_name]}}
|
||||
chat_updates[chat_id] = {'pinned': False, 'meta': {'tags': [tag_name]}}
|
||||
else:
|
||||
tags = chat_updates[chat_id]["meta"].get("tags", [])
|
||||
tags = chat_updates[chat_id]['meta'].get('tags', [])
|
||||
tags.append(tag_name)
|
||||
|
||||
chat_updates[chat_id]["meta"]["tags"] = list(set(tags))
|
||||
chat_updates[chat_id]['meta']['tags'] = list(set(tags))
|
||||
|
||||
# Update chats based on accumulated changes
|
||||
for chat_id, updates in chat_updates.items():
|
||||
update_stmt = sa.update(chat).where(chat.c.id == chat_id)
|
||||
update_stmt = update_stmt.values(
|
||||
meta=updates.get("meta", {}), pinned=updates.get("pinned", False)
|
||||
)
|
||||
update_stmt = update_stmt.values(meta=updates.get('meta', {}), pinned=updates.get('pinned', False))
|
||||
conn.execute(update_stmt)
|
||||
pass
|
||||
|
||||
|
||||
@@ -12,8 +12,8 @@ from sqlalchemy.sql import table, select, update
|
||||
|
||||
import json
|
||||
|
||||
revision = "242a2047eae0"
|
||||
down_revision = "6a39f3d8e55c"
|
||||
revision = '242a2047eae0'
|
||||
down_revision = '6a39f3d8e55c'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
@@ -22,39 +22,37 @@ def upgrade():
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
|
||||
columns = inspector.get_columns("chat")
|
||||
column_dict = {col["name"]: col for col in columns}
|
||||
columns = inspector.get_columns('chat')
|
||||
column_dict = {col['name']: col for col in columns}
|
||||
|
||||
chat_column = column_dict.get("chat")
|
||||
old_chat_exists = "old_chat" in column_dict
|
||||
chat_column = column_dict.get('chat')
|
||||
old_chat_exists = 'old_chat' in column_dict
|
||||
|
||||
if chat_column:
|
||||
if isinstance(chat_column["type"], sa.Text):
|
||||
if isinstance(chat_column['type'], sa.Text):
|
||||
print("Converting 'chat' column to JSON")
|
||||
|
||||
if old_chat_exists:
|
||||
print("Dropping old 'old_chat' column")
|
||||
op.drop_column("chat", "old_chat")
|
||||
op.drop_column('chat', 'old_chat')
|
||||
|
||||
# Step 1: Rename current 'chat' column to 'old_chat'
|
||||
print("Renaming 'chat' column to 'old_chat'")
|
||||
op.alter_column(
|
||||
"chat", "chat", new_column_name="old_chat", existing_type=sa.Text()
|
||||
)
|
||||
op.alter_column('chat', 'chat', new_column_name='old_chat', existing_type=sa.Text())
|
||||
|
||||
# Step 2: Add new 'chat' column of type JSON
|
||||
print("Adding new 'chat' column of type JSON")
|
||||
op.add_column("chat", sa.Column("chat", sa.JSON(), nullable=True))
|
||||
op.add_column('chat', sa.Column('chat', sa.JSON(), nullable=True))
|
||||
else:
|
||||
# If the column is already JSON, no need to do anything
|
||||
pass
|
||||
|
||||
# Step 3: Migrate data from 'old_chat' to 'chat'
|
||||
chat_table = table(
|
||||
"chat",
|
||||
sa.Column("id", sa.String(), primary_key=True),
|
||||
sa.Column("old_chat", sa.Text()),
|
||||
sa.Column("chat", sa.JSON()),
|
||||
'chat',
|
||||
sa.Column('id', sa.String(), primary_key=True),
|
||||
sa.Column('old_chat', sa.Text()),
|
||||
sa.Column('chat', sa.JSON()),
|
||||
)
|
||||
|
||||
# - Selecting all data from the table
|
||||
@@ -67,41 +65,33 @@ def upgrade():
|
||||
except json.JSONDecodeError:
|
||||
json_data = None # Handle cases where the text cannot be converted to JSON
|
||||
|
||||
connection.execute(
|
||||
sa.update(chat_table)
|
||||
.where(chat_table.c.id == row.id)
|
||||
.values(chat=json_data)
|
||||
)
|
||||
connection.execute(sa.update(chat_table).where(chat_table.c.id == row.id).values(chat=json_data))
|
||||
|
||||
# Step 4: Drop 'old_chat' column
|
||||
print("Dropping 'old_chat' column")
|
||||
op.drop_column("chat", "old_chat")
|
||||
op.drop_column('chat', 'old_chat')
|
||||
|
||||
|
||||
def downgrade():
|
||||
# Step 1: Add 'old_chat' column back as Text
|
||||
op.add_column("chat", sa.Column("old_chat", sa.Text(), nullable=True))
|
||||
op.add_column('chat', sa.Column('old_chat', sa.Text(), nullable=True))
|
||||
|
||||
# Step 2: Convert 'chat' JSON data back to text and store in 'old_chat'
|
||||
chat_table = table(
|
||||
"chat",
|
||||
sa.Column("id", sa.String(), primary_key=True),
|
||||
sa.Column("chat", sa.JSON()),
|
||||
sa.Column("old_chat", sa.Text()),
|
||||
'chat',
|
||||
sa.Column('id', sa.String(), primary_key=True),
|
||||
sa.Column('chat', sa.JSON()),
|
||||
sa.Column('old_chat', sa.Text()),
|
||||
)
|
||||
|
||||
connection = op.get_bind()
|
||||
results = connection.execute(select(chat_table.c.id, chat_table.c.chat))
|
||||
for row in results:
|
||||
text_data = json.dumps(row.chat) if row.chat is not None else None
|
||||
connection.execute(
|
||||
sa.update(chat_table)
|
||||
.where(chat_table.c.id == row.id)
|
||||
.values(old_chat=text_data)
|
||||
)
|
||||
connection.execute(sa.update(chat_table).where(chat_table.c.id == row.id).values(old_chat=text_data))
|
||||
|
||||
# Step 3: Remove the new 'chat' JSON column
|
||||
op.drop_column("chat", "chat")
|
||||
op.drop_column('chat', 'chat')
|
||||
|
||||
# Step 4: Rename 'old_chat' back to 'chat'
|
||||
op.alter_column("chat", "old_chat", new_column_name="chat", existing_type=sa.Text())
|
||||
op.alter_column('chat', 'old_chat', new_column_name='chat', existing_type=sa.Text())
|
||||
|
||||
@@ -13,19 +13,19 @@ import sqlalchemy as sa
|
||||
import open_webui.internal.db
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "2f1211949ecc"
|
||||
down_revision: Union[str, None] = "37f288994c47"
|
||||
revision: str = '2f1211949ecc'
|
||||
down_revision: Union[str, None] = '37f288994c47'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# New columns to be added to channel_member table
|
||||
op.add_column("channel_member", sa.Column("status", sa.Text(), nullable=True))
|
||||
op.add_column('channel_member', sa.Column('status', sa.Text(), nullable=True))
|
||||
op.add_column(
|
||||
"channel_member",
|
||||
'channel_member',
|
||||
sa.Column(
|
||||
"is_active",
|
||||
'is_active',
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
default=True,
|
||||
@@ -34,9 +34,9 @@ def upgrade() -> None:
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"channel_member",
|
||||
'channel_member',
|
||||
sa.Column(
|
||||
"is_channel_muted",
|
||||
'is_channel_muted',
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
default=False,
|
||||
@@ -44,9 +44,9 @@ def upgrade() -> None:
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"channel_member",
|
||||
'channel_member',
|
||||
sa.Column(
|
||||
"is_channel_pinned",
|
||||
'is_channel_pinned',
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
default=False,
|
||||
@@ -54,49 +54,41 @@ def upgrade() -> None:
|
||||
),
|
||||
)
|
||||
|
||||
op.add_column("channel_member", sa.Column("data", sa.JSON(), nullable=True))
|
||||
op.add_column("channel_member", sa.Column("meta", sa.JSON(), nullable=True))
|
||||
op.add_column('channel_member', sa.Column('data', sa.JSON(), nullable=True))
|
||||
op.add_column('channel_member', sa.Column('meta', sa.JSON(), nullable=True))
|
||||
|
||||
op.add_column(
|
||||
"channel_member", sa.Column("joined_at", sa.BigInteger(), nullable=False)
|
||||
)
|
||||
op.add_column(
|
||||
"channel_member", sa.Column("left_at", sa.BigInteger(), nullable=True)
|
||||
)
|
||||
op.add_column('channel_member', sa.Column('joined_at', sa.BigInteger(), nullable=False))
|
||||
op.add_column('channel_member', sa.Column('left_at', sa.BigInteger(), nullable=True))
|
||||
|
||||
op.add_column(
|
||||
"channel_member", sa.Column("last_read_at", sa.BigInteger(), nullable=True)
|
||||
)
|
||||
op.add_column('channel_member', sa.Column('last_read_at', sa.BigInteger(), nullable=True))
|
||||
|
||||
op.add_column(
|
||||
"channel_member", sa.Column("updated_at", sa.BigInteger(), nullable=True)
|
||||
)
|
||||
op.add_column('channel_member', sa.Column('updated_at', sa.BigInteger(), nullable=True))
|
||||
|
||||
# New columns to be added to message table
|
||||
op.add_column(
|
||||
"message",
|
||||
'message',
|
||||
sa.Column(
|
||||
"is_pinned",
|
||||
'is_pinned',
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
default=False,
|
||||
server_default=sa.sql.expression.false(),
|
||||
),
|
||||
)
|
||||
op.add_column("message", sa.Column("pinned_at", sa.BigInteger(), nullable=True))
|
||||
op.add_column("message", sa.Column("pinned_by", sa.Text(), nullable=True))
|
||||
op.add_column('message', sa.Column('pinned_at', sa.BigInteger(), nullable=True))
|
||||
op.add_column('message', sa.Column('pinned_by', sa.Text(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("channel_member", "updated_at")
|
||||
op.drop_column("channel_member", "last_read_at")
|
||||
op.drop_column('channel_member', 'updated_at')
|
||||
op.drop_column('channel_member', 'last_read_at')
|
||||
|
||||
op.drop_column("channel_member", "meta")
|
||||
op.drop_column("channel_member", "data")
|
||||
op.drop_column('channel_member', 'meta')
|
||||
op.drop_column('channel_member', 'data')
|
||||
|
||||
op.drop_column("channel_member", "is_channel_pinned")
|
||||
op.drop_column("channel_member", "is_channel_muted")
|
||||
op.drop_column('channel_member', 'is_channel_pinned')
|
||||
op.drop_column('channel_member', 'is_channel_muted')
|
||||
|
||||
op.drop_column("message", "pinned_by")
|
||||
op.drop_column("message", "pinned_at")
|
||||
op.drop_column("message", "is_pinned")
|
||||
op.drop_column('message', 'pinned_by')
|
||||
op.drop_column('message', 'pinned_at')
|
||||
op.drop_column('message', 'is_pinned')
|
||||
|
||||
@@ -12,8 +12,8 @@ import uuid
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision: str = "374d2f66af06"
|
||||
down_revision: Union[str, None] = "c440947495f3"
|
||||
revision: str = '374d2f66af06'
|
||||
down_revision: Union[str, None] = 'c440947495f3'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
@@ -26,13 +26,13 @@ def upgrade() -> None:
|
||||
# We need to assume the OLD structure.
|
||||
|
||||
old_prompt_table = sa.table(
|
||||
"prompt",
|
||||
sa.column("command", sa.Text()),
|
||||
sa.column("user_id", sa.Text()),
|
||||
sa.column("title", sa.Text()),
|
||||
sa.column("content", sa.Text()),
|
||||
sa.column("timestamp", sa.BigInteger()),
|
||||
sa.column("access_control", sa.JSON()),
|
||||
'prompt',
|
||||
sa.column('command', sa.Text()),
|
||||
sa.column('user_id', sa.Text()),
|
||||
sa.column('title', sa.Text()),
|
||||
sa.column('content', sa.Text()),
|
||||
sa.column('timestamp', sa.BigInteger()),
|
||||
sa.column('access_control', sa.JSON()),
|
||||
)
|
||||
|
||||
# Check if table exists/read data
|
||||
@@ -53,61 +53,61 @@ def upgrade() -> None:
|
||||
|
||||
# Step 2: Create new prompt table with 'id' as PRIMARY KEY
|
||||
op.create_table(
|
||||
"prompt_new",
|
||||
sa.Column("id", sa.Text(), primary_key=True),
|
||||
sa.Column("command", sa.String(), unique=True, index=True),
|
||||
sa.Column("user_id", sa.String(), nullable=False),
|
||||
sa.Column("name", sa.Text(), nullable=False),
|
||||
sa.Column("content", sa.Text(), nullable=False),
|
||||
sa.Column("data", sa.JSON(), nullable=True),
|
||||
sa.Column("meta", sa.JSON(), nullable=True),
|
||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="1"),
|
||||
sa.Column("version_id", sa.Text(), nullable=True),
|
||||
sa.Column("tags", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
||||
'prompt_new',
|
||||
sa.Column('id', sa.Text(), primary_key=True),
|
||||
sa.Column('command', sa.String(), unique=True, index=True),
|
||||
sa.Column('user_id', sa.String(), nullable=False),
|
||||
sa.Column('name', sa.Text(), nullable=False),
|
||||
sa.Column('content', sa.Text(), nullable=False),
|
||||
sa.Column('data', sa.JSON(), nullable=True),
|
||||
sa.Column('meta', sa.JSON(), nullable=True),
|
||||
sa.Column('access_control', sa.JSON(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='1'),
|
||||
sa.Column('version_id', sa.Text(), nullable=True),
|
||||
sa.Column('tags', sa.JSON(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=False),
|
||||
)
|
||||
|
||||
# Step 3: Create prompt_history table
|
||||
op.create_table(
|
||||
"prompt_history",
|
||||
sa.Column("id", sa.Text(), primary_key=True),
|
||||
sa.Column("prompt_id", sa.Text(), nullable=False, index=True),
|
||||
sa.Column("parent_id", sa.Text(), nullable=True),
|
||||
sa.Column("snapshot", sa.JSON(), nullable=False),
|
||||
sa.Column("user_id", sa.Text(), nullable=False),
|
||||
sa.Column("commit_message", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||
'prompt_history',
|
||||
sa.Column('id', sa.Text(), primary_key=True),
|
||||
sa.Column('prompt_id', sa.Text(), nullable=False, index=True),
|
||||
sa.Column('parent_id', sa.Text(), nullable=True),
|
||||
sa.Column('snapshot', sa.JSON(), nullable=False),
|
||||
sa.Column('user_id', sa.Text(), nullable=False),
|
||||
sa.Column('commit_message', sa.Text(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||
)
|
||||
|
||||
# Step 4: Migrate data
|
||||
prompt_new_table = sa.table(
|
||||
"prompt_new",
|
||||
sa.column("id", sa.Text()),
|
||||
sa.column("command", sa.String()),
|
||||
sa.column("user_id", sa.String()),
|
||||
sa.column("name", sa.Text()),
|
||||
sa.column("content", sa.Text()),
|
||||
sa.column("data", sa.JSON()),
|
||||
sa.column("meta", sa.JSON()),
|
||||
sa.column("access_control", sa.JSON()),
|
||||
sa.column("is_active", sa.Boolean()),
|
||||
sa.column("version_id", sa.Text()),
|
||||
sa.column("tags", sa.JSON()),
|
||||
sa.column("created_at", sa.BigInteger()),
|
||||
sa.column("updated_at", sa.BigInteger()),
|
||||
'prompt_new',
|
||||
sa.column('id', sa.Text()),
|
||||
sa.column('command', sa.String()),
|
||||
sa.column('user_id', sa.String()),
|
||||
sa.column('name', sa.Text()),
|
||||
sa.column('content', sa.Text()),
|
||||
sa.column('data', sa.JSON()),
|
||||
sa.column('meta', sa.JSON()),
|
||||
sa.column('access_control', sa.JSON()),
|
||||
sa.column('is_active', sa.Boolean()),
|
||||
sa.column('version_id', sa.Text()),
|
||||
sa.column('tags', sa.JSON()),
|
||||
sa.column('created_at', sa.BigInteger()),
|
||||
sa.column('updated_at', sa.BigInteger()),
|
||||
)
|
||||
|
||||
prompt_history_table = sa.table(
|
||||
"prompt_history",
|
||||
sa.column("id", sa.Text()),
|
||||
sa.column("prompt_id", sa.Text()),
|
||||
sa.column("parent_id", sa.Text()),
|
||||
sa.column("snapshot", sa.JSON()),
|
||||
sa.column("user_id", sa.Text()),
|
||||
sa.column("commit_message", sa.Text()),
|
||||
sa.column("created_at", sa.BigInteger()),
|
||||
'prompt_history',
|
||||
sa.column('id', sa.Text()),
|
||||
sa.column('prompt_id', sa.Text()),
|
||||
sa.column('parent_id', sa.Text()),
|
||||
sa.column('snapshot', sa.JSON()),
|
||||
sa.column('user_id', sa.Text()),
|
||||
sa.column('commit_message', sa.Text()),
|
||||
sa.column('created_at', sa.BigInteger()),
|
||||
)
|
||||
|
||||
for row in existing_prompts:
|
||||
@@ -120,7 +120,7 @@ def upgrade() -> None:
|
||||
|
||||
new_uuid = str(uuid.uuid4())
|
||||
history_uuid = str(uuid.uuid4())
|
||||
clean_command = command[1:] if command and command.startswith("/") else command
|
||||
clean_command = command[1:] if command and command.startswith('/') else command
|
||||
|
||||
# Insert into prompt_new
|
||||
conn.execute(
|
||||
@@ -148,12 +148,12 @@ def upgrade() -> None:
|
||||
prompt_id=new_uuid,
|
||||
parent_id=None,
|
||||
snapshot={
|
||||
"name": title,
|
||||
"content": content,
|
||||
"command": clean_command,
|
||||
"data": {},
|
||||
"meta": {},
|
||||
"access_control": access_control,
|
||||
'name': title,
|
||||
'content': content,
|
||||
'command': clean_command,
|
||||
'data': {},
|
||||
'meta': {},
|
||||
'access_control': access_control,
|
||||
},
|
||||
user_id=user_id,
|
||||
commit_message=None,
|
||||
@@ -162,8 +162,8 @@ def upgrade() -> None:
|
||||
)
|
||||
|
||||
# Step 5: Replace old table with new one
|
||||
op.drop_table("prompt")
|
||||
op.rename_table("prompt_new", "prompt")
|
||||
op.drop_table('prompt')
|
||||
op.rename_table('prompt_new', 'prompt')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
@@ -171,13 +171,13 @@ def downgrade() -> None:
|
||||
|
||||
# Step 1: Read new data
|
||||
prompt_table = sa.table(
|
||||
"prompt",
|
||||
sa.column("command", sa.String()),
|
||||
sa.column("name", sa.Text()),
|
||||
sa.column("created_at", sa.BigInteger()),
|
||||
sa.column("user_id", sa.Text()),
|
||||
sa.column("content", sa.Text()),
|
||||
sa.column("access_control", sa.JSON()),
|
||||
'prompt',
|
||||
sa.column('command', sa.String()),
|
||||
sa.column('name', sa.Text()),
|
||||
sa.column('created_at', sa.BigInteger()),
|
||||
sa.column('user_id', sa.Text()),
|
||||
sa.column('content', sa.Text()),
|
||||
sa.column('access_control', sa.JSON()),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -195,31 +195,31 @@ def downgrade() -> None:
|
||||
current_data = []
|
||||
|
||||
# Step 2: Drop history and table
|
||||
op.drop_table("prompt_history")
|
||||
op.drop_table("prompt")
|
||||
op.drop_table('prompt_history')
|
||||
op.drop_table('prompt')
|
||||
|
||||
# Step 3: Recreate old table (command as PK?)
|
||||
# Assuming old schema:
|
||||
op.create_table(
|
||||
"prompt",
|
||||
sa.Column("command", sa.String(), primary_key=True),
|
||||
sa.Column("user_id", sa.String()),
|
||||
sa.Column("title", sa.Text()),
|
||||
sa.Column("content", sa.Text()),
|
||||
sa.Column("timestamp", sa.BigInteger()),
|
||||
sa.Column("access_control", sa.JSON()),
|
||||
sa.Column("id", sa.Integer(), nullable=True),
|
||||
'prompt',
|
||||
sa.Column('command', sa.String(), primary_key=True),
|
||||
sa.Column('user_id', sa.String()),
|
||||
sa.Column('title', sa.Text()),
|
||||
sa.Column('content', sa.Text()),
|
||||
sa.Column('timestamp', sa.BigInteger()),
|
||||
sa.Column('access_control', sa.JSON()),
|
||||
sa.Column('id', sa.Integer(), nullable=True),
|
||||
)
|
||||
|
||||
# Step 4: Restore data
|
||||
old_prompt_table = sa.table(
|
||||
"prompt",
|
||||
sa.column("command", sa.String()),
|
||||
sa.column("user_id", sa.String()),
|
||||
sa.column("title", sa.Text()),
|
||||
sa.column("content", sa.Text()),
|
||||
sa.column("timestamp", sa.BigInteger()),
|
||||
sa.column("access_control", sa.JSON()),
|
||||
'prompt',
|
||||
sa.column('command', sa.String()),
|
||||
sa.column('user_id', sa.String()),
|
||||
sa.column('title', sa.Text()),
|
||||
sa.column('content', sa.Text()),
|
||||
sa.column('timestamp', sa.BigInteger()),
|
||||
sa.column('access_control', sa.JSON()),
|
||||
)
|
||||
|
||||
for row in current_data:
|
||||
@@ -231,9 +231,7 @@ def downgrade() -> None:
|
||||
access_control = row[5]
|
||||
|
||||
# Restore leading /
|
||||
old_command = (
|
||||
"/" + command if command and not command.startswith("/") else command
|
||||
)
|
||||
old_command = '/' + command if command and not command.startswith('/') else command
|
||||
|
||||
conn.execute(
|
||||
sa.insert(old_prompt_table).values(
|
||||
|
||||
@@ -9,8 +9,8 @@ Create Date: 2024-12-30 03:00:00.000000
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "3781e22d8b01"
|
||||
down_revision = "7826ab40b532"
|
||||
revision = '3781e22d8b01'
|
||||
down_revision = '7826ab40b532'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
@@ -18,9 +18,9 @@ depends_on = None
|
||||
def upgrade():
|
||||
# Add 'type' column to the 'channel' table
|
||||
op.add_column(
|
||||
"channel",
|
||||
'channel',
|
||||
sa.Column(
|
||||
"type",
|
||||
'type',
|
||||
sa.Text(),
|
||||
nullable=True,
|
||||
),
|
||||
@@ -28,43 +28,31 @@ def upgrade():
|
||||
|
||||
# Add 'parent_id' column to the 'message' table for threads
|
||||
op.add_column(
|
||||
"message",
|
||||
sa.Column("parent_id", sa.Text(), nullable=True),
|
||||
'message',
|
||||
sa.Column('parent_id', sa.Text(), nullable=True),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"message_reaction",
|
||||
sa.Column(
|
||||
"id", sa.Text(), nullable=False, primary_key=True, unique=True
|
||||
), # Unique reaction ID
|
||||
sa.Column("user_id", sa.Text(), nullable=False), # User who reacted
|
||||
sa.Column(
|
||||
"message_id", sa.Text(), nullable=False
|
||||
), # Message that was reacted to
|
||||
sa.Column(
|
||||
"name", sa.Text(), nullable=False
|
||||
), # Reaction name (e.g. "thumbs_up")
|
||||
sa.Column(
|
||||
"created_at", sa.BigInteger(), nullable=True
|
||||
), # Timestamp of when the reaction was added
|
||||
'message_reaction',
|
||||
sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True), # Unique reaction ID
|
||||
sa.Column('user_id', sa.Text(), nullable=False), # User who reacted
|
||||
sa.Column('message_id', sa.Text(), nullable=False), # Message that was reacted to
|
||||
sa.Column('name', sa.Text(), nullable=False), # Reaction name (e.g. "thumbs_up")
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True), # Timestamp of when the reaction was added
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"channel_member",
|
||||
sa.Column(
|
||||
"id", sa.Text(), nullable=False, primary_key=True, unique=True
|
||||
), # Record ID for the membership row
|
||||
sa.Column("channel_id", sa.Text(), nullable=False), # Associated channel
|
||||
sa.Column("user_id", sa.Text(), nullable=False), # Associated user
|
||||
sa.Column(
|
||||
"created_at", sa.BigInteger(), nullable=True
|
||||
), # Timestamp of when the user joined the channel
|
||||
'channel_member',
|
||||
sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True), # Record ID for the membership row
|
||||
sa.Column('channel_id', sa.Text(), nullable=False), # Associated channel
|
||||
sa.Column('user_id', sa.Text(), nullable=False), # Associated user
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True), # Timestamp of when the user joined the channel
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
# Revert 'type' column addition to the 'channel' table
|
||||
op.drop_column("channel", "type")
|
||||
op.drop_column("message", "parent_id")
|
||||
op.drop_table("message_reaction")
|
||||
op.drop_table("channel_member")
|
||||
op.drop_column('channel', 'type')
|
||||
op.drop_column('message', 'parent_id')
|
||||
op.drop_table('message_reaction')
|
||||
op.drop_table('channel_member')
|
||||
|
||||
@@ -15,8 +15,8 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "37f288994c47"
|
||||
down_revision: Union[str, None] = "a5c220713937"
|
||||
revision: str = '37f288994c47'
|
||||
down_revision: Union[str, None] = 'a5c220713937'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
@@ -24,50 +24,48 @@ depends_on: Union[str, Sequence[str], None] = None
|
||||
def upgrade() -> None:
|
||||
# 1. Create new table
|
||||
op.create_table(
|
||||
"group_member",
|
||||
sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False),
|
||||
'group_member',
|
||||
sa.Column('id', sa.Text(), primary_key=True, unique=True, nullable=False),
|
||||
sa.Column(
|
||||
"group_id",
|
||||
'group_id',
|
||||
sa.Text(),
|
||||
sa.ForeignKey("group.id", ondelete="CASCADE"),
|
||||
sa.ForeignKey('group.id', ondelete='CASCADE'),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
'user_id',
|
||||
sa.Text(),
|
||||
sa.ForeignKey("user.id", ondelete="CASCADE"),
|
||||
sa.ForeignKey('user.id', ondelete='CASCADE'),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.UniqueConstraint("group_id", "user_id", name="uq_group_member_group_user"),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
sa.UniqueConstraint('group_id', 'user_id', name='uq_group_member_group_user'),
|
||||
)
|
||||
|
||||
connection = op.get_bind()
|
||||
|
||||
# 2. Read existing group with user_ids JSON column
|
||||
group_table = sa.Table(
|
||||
"group",
|
||||
'group',
|
||||
sa.MetaData(),
|
||||
sa.Column("id", sa.Text()),
|
||||
sa.Column("user_ids", sa.JSON()), # JSON stored as text in SQLite + PG
|
||||
sa.Column('id', sa.Text()),
|
||||
sa.Column('user_ids', sa.JSON()), # JSON stored as text in SQLite + PG
|
||||
)
|
||||
|
||||
results = connection.execute(
|
||||
sa.select(group_table.c.id, group_table.c.user_ids)
|
||||
).fetchall()
|
||||
results = connection.execute(sa.select(group_table.c.id, group_table.c.user_ids)).fetchall()
|
||||
|
||||
print(results)
|
||||
|
||||
# 3. Insert members into group_member table
|
||||
gm_table = sa.Table(
|
||||
"group_member",
|
||||
'group_member',
|
||||
sa.MetaData(),
|
||||
sa.Column("id", sa.Text()),
|
||||
sa.Column("group_id", sa.Text()),
|
||||
sa.Column("user_id", sa.Text()),
|
||||
sa.Column("created_at", sa.BigInteger()),
|
||||
sa.Column("updated_at", sa.BigInteger()),
|
||||
sa.Column('id', sa.Text()),
|
||||
sa.Column('group_id', sa.Text()),
|
||||
sa.Column('user_id', sa.Text()),
|
||||
sa.Column('created_at', sa.BigInteger()),
|
||||
sa.Column('updated_at', sa.BigInteger()),
|
||||
)
|
||||
|
||||
now = int(time.time())
|
||||
@@ -86,11 +84,11 @@ def upgrade() -> None:
|
||||
|
||||
rows = [
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"group_id": group_id,
|
||||
"user_id": uid,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
'id': str(uuid.uuid4()),
|
||||
'group_id': group_id,
|
||||
'user_id': uid,
|
||||
'created_at': now,
|
||||
'updated_at': now,
|
||||
}
|
||||
for uid in user_ids
|
||||
]
|
||||
@@ -99,47 +97,41 @@ def upgrade() -> None:
|
||||
connection.execute(gm_table.insert(), rows)
|
||||
|
||||
# 4. Optionally drop the old column
|
||||
with op.batch_alter_table("group") as batch:
|
||||
batch.drop_column("user_ids")
|
||||
with op.batch_alter_table('group') as batch:
|
||||
batch.drop_column('user_ids')
|
||||
|
||||
|
||||
def downgrade():
|
||||
# Reverse: restore user_ids column
|
||||
with op.batch_alter_table("group") as batch:
|
||||
batch.add_column(sa.Column("user_ids", sa.JSON()))
|
||||
with op.batch_alter_table('group') as batch:
|
||||
batch.add_column(sa.Column('user_ids', sa.JSON()))
|
||||
|
||||
connection = op.get_bind()
|
||||
gm_table = sa.Table(
|
||||
"group_member",
|
||||
'group_member',
|
||||
sa.MetaData(),
|
||||
sa.Column("group_id", sa.Text()),
|
||||
sa.Column("user_id", sa.Text()),
|
||||
sa.Column("created_at", sa.BigInteger()),
|
||||
sa.Column("updated_at", sa.BigInteger()),
|
||||
sa.Column('group_id', sa.Text()),
|
||||
sa.Column('user_id', sa.Text()),
|
||||
sa.Column('created_at', sa.BigInteger()),
|
||||
sa.Column('updated_at', sa.BigInteger()),
|
||||
)
|
||||
|
||||
group_table = sa.Table(
|
||||
"group",
|
||||
'group',
|
||||
sa.MetaData(),
|
||||
sa.Column("id", sa.Text()),
|
||||
sa.Column("user_ids", sa.JSON()),
|
||||
sa.Column('id', sa.Text()),
|
||||
sa.Column('user_ids', sa.JSON()),
|
||||
)
|
||||
|
||||
# Build JSON arrays again
|
||||
results = connection.execute(sa.select(group_table.c.id)).fetchall()
|
||||
|
||||
for (group_id,) in results:
|
||||
members = connection.execute(
|
||||
sa.select(gm_table.c.user_id).where(gm_table.c.group_id == group_id)
|
||||
).fetchall()
|
||||
members = connection.execute(sa.select(gm_table.c.user_id).where(gm_table.c.group_id == group_id)).fetchall()
|
||||
|
||||
member_ids = [m[0] for m in members]
|
||||
|
||||
connection.execute(
|
||||
group_table.update()
|
||||
.where(group_table.c.id == group_id)
|
||||
.values(user_ids=member_ids)
|
||||
)
|
||||
connection.execute(group_table.update().where(group_table.c.id == group_id).values(user_ids=member_ids))
|
||||
|
||||
# Drop the new table
|
||||
op.drop_table("group_member")
|
||||
op.drop_table('group_member')
|
||||
|
||||
@@ -12,8 +12,8 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "38d63c18f30f"
|
||||
down_revision: Union[str, None] = "3af16a1c9fb6"
|
||||
revision: str = '38d63c18f30f'
|
||||
down_revision: Union[str, None] = '3af16a1c9fb6'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
@@ -21,59 +21,55 @@ depends_on: Union[str, Sequence[str], None] = None
|
||||
def upgrade() -> None:
|
||||
# Ensure 'id' column in 'user' table is unique and primary key (ForeignKey constraint)
|
||||
inspector = sa.inspect(op.get_bind())
|
||||
columns = inspector.get_columns("user")
|
||||
columns = inspector.get_columns('user')
|
||||
|
||||
pk_columns = inspector.get_pk_constraint("user")["constrained_columns"]
|
||||
id_column = next((col for col in columns if col["name"] == "id"), None)
|
||||
pk_columns = inspector.get_pk_constraint('user')['constrained_columns']
|
||||
id_column = next((col for col in columns if col['name'] == 'id'), None)
|
||||
|
||||
if id_column and not id_column.get("unique", False):
|
||||
unique_constraints = inspector.get_unique_constraints("user")
|
||||
unique_columns = {tuple(u["column_names"]) for u in unique_constraints}
|
||||
if id_column and not id_column.get('unique', False):
|
||||
unique_constraints = inspector.get_unique_constraints('user')
|
||||
unique_columns = {tuple(u['column_names']) for u in unique_constraints}
|
||||
|
||||
with op.batch_alter_table("user") as batch_op:
|
||||
with op.batch_alter_table('user') as batch_op:
|
||||
# If primary key is wrong, drop it
|
||||
if pk_columns and pk_columns != ["id"]:
|
||||
batch_op.drop_constraint(
|
||||
inspector.get_pk_constraint("user")["name"], type_="primary"
|
||||
)
|
||||
if pk_columns and pk_columns != ['id']:
|
||||
batch_op.drop_constraint(inspector.get_pk_constraint('user')['name'], type_='primary')
|
||||
|
||||
# Add unique constraint if missing
|
||||
if ("id",) not in unique_columns:
|
||||
batch_op.create_unique_constraint("uq_user_id", ["id"])
|
||||
if ('id',) not in unique_columns:
|
||||
batch_op.create_unique_constraint('uq_user_id', ['id'])
|
||||
|
||||
# Re-create correct primary key
|
||||
batch_op.create_primary_key("pk_user_id", ["id"])
|
||||
batch_op.create_primary_key('pk_user_id', ['id'])
|
||||
|
||||
# Create oauth_session table
|
||||
op.create_table(
|
||||
"oauth_session",
|
||||
sa.Column("id", sa.Text(), primary_key=True, nullable=False, unique=True),
|
||||
'oauth_session',
|
||||
sa.Column('id', sa.Text(), primary_key=True, nullable=False, unique=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
'user_id',
|
||||
sa.Text(),
|
||||
sa.ForeignKey("user.id", ondelete="CASCADE"),
|
||||
sa.ForeignKey('user.id', ondelete='CASCADE'),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("provider", sa.Text(), nullable=False),
|
||||
sa.Column("token", sa.Text(), nullable=False),
|
||||
sa.Column("expires_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column('provider', sa.Text(), nullable=False),
|
||||
sa.Column('token', sa.Text(), nullable=False),
|
||||
sa.Column('expires_at', sa.BigInteger(), nullable=False),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=False),
|
||||
)
|
||||
|
||||
# Create indexes for better performance
|
||||
op.create_index("idx_oauth_session_user_id", "oauth_session", ["user_id"])
|
||||
op.create_index("idx_oauth_session_expires_at", "oauth_session", ["expires_at"])
|
||||
op.create_index(
|
||||
"idx_oauth_session_user_provider", "oauth_session", ["user_id", "provider"]
|
||||
)
|
||||
op.create_index('idx_oauth_session_user_id', 'oauth_session', ['user_id'])
|
||||
op.create_index('idx_oauth_session_expires_at', 'oauth_session', ['expires_at'])
|
||||
op.create_index('idx_oauth_session_user_provider', 'oauth_session', ['user_id', 'provider'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes first
|
||||
op.drop_index("idx_oauth_session_user_provider", table_name="oauth_session")
|
||||
op.drop_index("idx_oauth_session_expires_at", table_name="oauth_session")
|
||||
op.drop_index("idx_oauth_session_user_id", table_name="oauth_session")
|
||||
op.drop_index('idx_oauth_session_user_provider', table_name='oauth_session')
|
||||
op.drop_index('idx_oauth_session_expires_at', table_name='oauth_session')
|
||||
op.drop_index('idx_oauth_session_user_id', table_name='oauth_session')
|
||||
|
||||
# Drop the table
|
||||
op.drop_table("oauth_session")
|
||||
op.drop_table('oauth_session')
|
||||
|
||||
@@ -13,8 +13,8 @@ from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
import json
|
||||
|
||||
revision = "3ab32c4b8f59"
|
||||
down_revision = "1af9b942657b"
|
||||
revision = '3ab32c4b8f59'
|
||||
down_revision = '1af9b942657b'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
@@ -24,58 +24,55 @@ def upgrade():
|
||||
inspector = Inspector.from_engine(conn)
|
||||
|
||||
# Inspecting the 'tag' table constraints and structure
|
||||
existing_pk = inspector.get_pk_constraint("tag")
|
||||
unique_constraints = inspector.get_unique_constraints("tag")
|
||||
existing_indexes = inspector.get_indexes("tag")
|
||||
existing_pk = inspector.get_pk_constraint('tag')
|
||||
unique_constraints = inspector.get_unique_constraints('tag')
|
||||
existing_indexes = inspector.get_indexes('tag')
|
||||
|
||||
print(f"Primary Key: {existing_pk}")
|
||||
print(f"Unique Constraints: {unique_constraints}")
|
||||
print(f"Indexes: {existing_indexes}")
|
||||
print(f'Primary Key: {existing_pk}')
|
||||
print(f'Unique Constraints: {unique_constraints}')
|
||||
print(f'Indexes: {existing_indexes}')
|
||||
|
||||
with op.batch_alter_table("tag", schema=None) as batch_op:
|
||||
with op.batch_alter_table('tag', schema=None) as batch_op:
|
||||
# Drop existing primary key constraint if it exists
|
||||
if existing_pk and existing_pk.get("constrained_columns"):
|
||||
pk_name = existing_pk.get("name")
|
||||
if existing_pk and existing_pk.get('constrained_columns'):
|
||||
pk_name = existing_pk.get('name')
|
||||
if pk_name:
|
||||
print(f"Dropping primary key constraint: {pk_name}")
|
||||
batch_op.drop_constraint(pk_name, type_="primary")
|
||||
print(f'Dropping primary key constraint: {pk_name}')
|
||||
batch_op.drop_constraint(pk_name, type_='primary')
|
||||
|
||||
# Now create the new primary key with the combination of 'id' and 'user_id'
|
||||
print("Creating new primary key with 'id' and 'user_id'.")
|
||||
batch_op.create_primary_key("pk_id_user_id", ["id", "user_id"])
|
||||
batch_op.create_primary_key('pk_id_user_id', ['id', 'user_id'])
|
||||
|
||||
# Drop unique constraints that could conflict with the new primary key
|
||||
for constraint in unique_constraints:
|
||||
if (
|
||||
constraint["name"] == "uq_id_user_id"
|
||||
constraint['name'] == 'uq_id_user_id'
|
||||
): # Adjust this name according to what is actually returned by the inspector
|
||||
print(f"Dropping unique constraint: {constraint['name']}")
|
||||
batch_op.drop_constraint(constraint["name"], type_="unique")
|
||||
print(f'Dropping unique constraint: {constraint["name"]}')
|
||||
batch_op.drop_constraint(constraint['name'], type_='unique')
|
||||
|
||||
for index in existing_indexes:
|
||||
if index["unique"]:
|
||||
if not any(
|
||||
constraint["name"] == index["name"]
|
||||
for constraint in unique_constraints
|
||||
):
|
||||
if index['unique']:
|
||||
if not any(constraint['name'] == index['name'] for constraint in unique_constraints):
|
||||
# You are attempting to drop unique indexes
|
||||
print(f"Dropping unique index: {index['name']}")
|
||||
batch_op.drop_index(index["name"])
|
||||
print(f'Dropping unique index: {index["name"]}')
|
||||
batch_op.drop_index(index['name'])
|
||||
|
||||
|
||||
def downgrade():
|
||||
conn = op.get_bind()
|
||||
inspector = Inspector.from_engine(conn)
|
||||
|
||||
current_pk = inspector.get_pk_constraint("tag")
|
||||
current_pk = inspector.get_pk_constraint('tag')
|
||||
|
||||
with op.batch_alter_table("tag", schema=None) as batch_op:
|
||||
with op.batch_alter_table('tag', schema=None) as batch_op:
|
||||
# Drop the current primary key first, if it matches the one we know we added in upgrade
|
||||
if current_pk and "pk_id_user_id" == current_pk.get("name"):
|
||||
batch_op.drop_constraint("pk_id_user_id", type_="primary")
|
||||
if current_pk and 'pk_id_user_id' == current_pk.get('name'):
|
||||
batch_op.drop_constraint('pk_id_user_id', type_='primary')
|
||||
|
||||
# Restore the original primary key
|
||||
batch_op.create_primary_key("pk_id", ["id"])
|
||||
batch_op.create_primary_key('pk_id', ['id'])
|
||||
|
||||
# Since primary key on just 'id' is restored, we now add back any unique constraints if necessary
|
||||
batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"])
|
||||
batch_op.create_unique_constraint('uq_id_user_id', ['id', 'user_id'])
|
||||
|
||||
@@ -12,21 +12,21 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "3af16a1c9fb6"
|
||||
down_revision: Union[str, None] = "018012973d35"
|
||||
revision: str = '3af16a1c9fb6'
|
||||
down_revision: Union[str, None] = '018012973d35'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("user", sa.Column("username", sa.String(length=50), nullable=True))
|
||||
op.add_column("user", sa.Column("bio", sa.Text(), nullable=True))
|
||||
op.add_column("user", sa.Column("gender", sa.Text(), nullable=True))
|
||||
op.add_column("user", sa.Column("date_of_birth", sa.Date(), nullable=True))
|
||||
op.add_column('user', sa.Column('username', sa.String(length=50), nullable=True))
|
||||
op.add_column('user', sa.Column('bio', sa.Text(), nullable=True))
|
||||
op.add_column('user', sa.Column('gender', sa.Text(), nullable=True))
|
||||
op.add_column('user', sa.Column('date_of_birth', sa.Date(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "username")
|
||||
op.drop_column("user", "bio")
|
||||
op.drop_column("user", "gender")
|
||||
op.drop_column("user", "date_of_birth")
|
||||
op.drop_column('user', 'username')
|
||||
op.drop_column('user', 'bio')
|
||||
op.drop_column('user', 'gender')
|
||||
op.drop_column('user', 'date_of_birth')
|
||||
|
||||
@@ -18,38 +18,38 @@ import json
|
||||
import uuid
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "3e0e00844bb0"
|
||||
down_revision: Union[str, None] = "90ef40d4714e"
|
||||
revision: str = '3e0e00844bb0'
|
||||
down_revision: Union[str, None] = '90ef40d4714e'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"knowledge_file",
|
||||
sa.Column("id", sa.Text(), primary_key=True),
|
||||
sa.Column("user_id", sa.Text(), nullable=False),
|
||||
'knowledge_file',
|
||||
sa.Column('id', sa.Text(), primary_key=True),
|
||||
sa.Column('user_id', sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"knowledge_id",
|
||||
'knowledge_id',
|
||||
sa.Text(),
|
||||
sa.ForeignKey("knowledge.id", ondelete="CASCADE"),
|
||||
sa.ForeignKey('knowledge.id', ondelete='CASCADE'),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"file_id",
|
||||
'file_id',
|
||||
sa.Text(),
|
||||
sa.ForeignKey("file.id", ondelete="CASCADE"),
|
||||
sa.ForeignKey('file.id', ondelete='CASCADE'),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=False),
|
||||
# indexes
|
||||
sa.Index("ix_knowledge_file_knowledge_id", "knowledge_id"),
|
||||
sa.Index("ix_knowledge_file_file_id", "file_id"),
|
||||
sa.Index("ix_knowledge_file_user_id", "user_id"),
|
||||
sa.Index('ix_knowledge_file_knowledge_id', 'knowledge_id'),
|
||||
sa.Index('ix_knowledge_file_file_id', 'file_id'),
|
||||
sa.Index('ix_knowledge_file_user_id', 'user_id'),
|
||||
# unique constraints
|
||||
sa.UniqueConstraint(
|
||||
"knowledge_id", "file_id", name="uq_knowledge_file_knowledge_file"
|
||||
'knowledge_id', 'file_id', name='uq_knowledge_file_knowledge_file'
|
||||
), # prevent duplicate entries
|
||||
)
|
||||
|
||||
@@ -57,35 +57,33 @@ def upgrade() -> None:
|
||||
|
||||
# 2. Read existing group with user_ids JSON column
|
||||
knowledge_table = sa.Table(
|
||||
"knowledge",
|
||||
'knowledge',
|
||||
sa.MetaData(),
|
||||
sa.Column("id", sa.Text()),
|
||||
sa.Column("user_id", sa.Text()),
|
||||
sa.Column("data", sa.JSON()), # JSON stored as text in SQLite + PG
|
||||
sa.Column('id', sa.Text()),
|
||||
sa.Column('user_id', sa.Text()),
|
||||
sa.Column('data', sa.JSON()), # JSON stored as text in SQLite + PG
|
||||
)
|
||||
|
||||
results = connection.execute(
|
||||
sa.select(
|
||||
knowledge_table.c.id, knowledge_table.c.user_id, knowledge_table.c.data
|
||||
)
|
||||
sa.select(knowledge_table.c.id, knowledge_table.c.user_id, knowledge_table.c.data)
|
||||
).fetchall()
|
||||
|
||||
# 3. Insert members into group_member table
|
||||
kf_table = sa.Table(
|
||||
"knowledge_file",
|
||||
'knowledge_file',
|
||||
sa.MetaData(),
|
||||
sa.Column("id", sa.Text()),
|
||||
sa.Column("user_id", sa.Text()),
|
||||
sa.Column("knowledge_id", sa.Text()),
|
||||
sa.Column("file_id", sa.Text()),
|
||||
sa.Column("created_at", sa.BigInteger()),
|
||||
sa.Column("updated_at", sa.BigInteger()),
|
||||
sa.Column('id', sa.Text()),
|
||||
sa.Column('user_id', sa.Text()),
|
||||
sa.Column('knowledge_id', sa.Text()),
|
||||
sa.Column('file_id', sa.Text()),
|
||||
sa.Column('created_at', sa.BigInteger()),
|
||||
sa.Column('updated_at', sa.BigInteger()),
|
||||
)
|
||||
|
||||
file_table = sa.Table(
|
||||
"file",
|
||||
'file',
|
||||
sa.MetaData(),
|
||||
sa.Column("id", sa.Text()),
|
||||
sa.Column('id', sa.Text()),
|
||||
)
|
||||
|
||||
now = int(time.time())
|
||||
@@ -102,50 +100,48 @@ def upgrade() -> None:
|
||||
if not isinstance(data, dict):
|
||||
continue
|
||||
|
||||
file_ids = data.get("file_ids", [])
|
||||
file_ids = data.get('file_ids', [])
|
||||
|
||||
for file_id in file_ids:
|
||||
file_exists = connection.execute(
|
||||
sa.select(file_table.c.id).where(file_table.c.id == file_id)
|
||||
).fetchone()
|
||||
file_exists = connection.execute(sa.select(file_table.c.id).where(file_table.c.id == file_id)).fetchone()
|
||||
|
||||
if not file_exists:
|
||||
continue # skip non-existing files
|
||||
|
||||
row = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"user_id": user_id,
|
||||
"knowledge_id": knowledge_id,
|
||||
"file_id": file_id,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
'id': str(uuid.uuid4()),
|
||||
'user_id': user_id,
|
||||
'knowledge_id': knowledge_id,
|
||||
'file_id': file_id,
|
||||
'created_at': now,
|
||||
'updated_at': now,
|
||||
}
|
||||
connection.execute(kf_table.insert().values(**row))
|
||||
|
||||
with op.batch_alter_table("knowledge") as batch:
|
||||
batch.drop_column("data")
|
||||
with op.batch_alter_table('knowledge') as batch:
|
||||
batch.drop_column('data')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# 1. Add back the old data column
|
||||
op.add_column("knowledge", sa.Column("data", sa.JSON(), nullable=True))
|
||||
op.add_column('knowledge', sa.Column('data', sa.JSON(), nullable=True))
|
||||
|
||||
connection = op.get_bind()
|
||||
|
||||
# 2. Read knowledge_file entries and reconstruct data JSON
|
||||
knowledge_table = sa.Table(
|
||||
"knowledge",
|
||||
'knowledge',
|
||||
sa.MetaData(),
|
||||
sa.Column("id", sa.Text()),
|
||||
sa.Column("data", sa.JSON()),
|
||||
sa.Column('id', sa.Text()),
|
||||
sa.Column('data', sa.JSON()),
|
||||
)
|
||||
|
||||
kf_table = sa.Table(
|
||||
"knowledge_file",
|
||||
'knowledge_file',
|
||||
sa.MetaData(),
|
||||
sa.Column("id", sa.Text()),
|
||||
sa.Column("knowledge_id", sa.Text()),
|
||||
sa.Column("file_id", sa.Text()),
|
||||
sa.Column('id', sa.Text()),
|
||||
sa.Column('knowledge_id', sa.Text()),
|
||||
sa.Column('file_id', sa.Text()),
|
||||
)
|
||||
|
||||
results = connection.execute(sa.select(knowledge_table.c.id)).fetchall()
|
||||
@@ -157,13 +153,9 @@ def downgrade() -> None:
|
||||
|
||||
file_ids_list = [fid for (fid,) in file_ids]
|
||||
|
||||
data_json = {"file_ids": file_ids_list}
|
||||
data_json = {'file_ids': file_ids_list}
|
||||
|
||||
connection.execute(
|
||||
knowledge_table.update()
|
||||
.where(knowledge_table.c.id == knowledge_id)
|
||||
.values(data=data_json)
|
||||
)
|
||||
connection.execute(knowledge_table.update().where(knowledge_table.c.id == knowledge_id).values(data=data_json))
|
||||
|
||||
# 3. Drop the knowledge_file table
|
||||
op.drop_table("knowledge_file")
|
||||
op.drop_table('knowledge_file')
|
||||
|
||||
@@ -9,56 +9,56 @@ Create Date: 2024-10-23 03:00:00.000000
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "4ace53fd72c8"
|
||||
down_revision = "af906e964978"
|
||||
revision = '4ace53fd72c8'
|
||||
down_revision = 'af906e964978'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Perform safe alterations using batch operation
|
||||
with op.batch_alter_table("folder", schema=None) as batch_op:
|
||||
with op.batch_alter_table('folder', schema=None) as batch_op:
|
||||
# Step 1: Remove server defaults for created_at and updated_at
|
||||
batch_op.alter_column(
|
||||
"created_at",
|
||||
'created_at',
|
||||
server_default=None, # Removing server default
|
||||
)
|
||||
batch_op.alter_column(
|
||||
"updated_at",
|
||||
'updated_at',
|
||||
server_default=None, # Removing server default
|
||||
)
|
||||
|
||||
# Step 2: Change the column types to BigInteger for created_at
|
||||
batch_op.alter_column(
|
||||
"created_at",
|
||||
'created_at',
|
||||
type_=sa.BigInteger(),
|
||||
existing_type=sa.DateTime(),
|
||||
existing_nullable=False,
|
||||
postgresql_using="extract(epoch from created_at)::bigint", # Conversion for PostgreSQL
|
||||
postgresql_using='extract(epoch from created_at)::bigint', # Conversion for PostgreSQL
|
||||
)
|
||||
|
||||
# Change the column types to BigInteger for updated_at
|
||||
batch_op.alter_column(
|
||||
"updated_at",
|
||||
'updated_at',
|
||||
type_=sa.BigInteger(),
|
||||
existing_type=sa.DateTime(),
|
||||
existing_nullable=False,
|
||||
postgresql_using="extract(epoch from updated_at)::bigint", # Conversion for PostgreSQL
|
||||
postgresql_using='extract(epoch from updated_at)::bigint', # Conversion for PostgreSQL
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
# Downgrade: Convert columns back to DateTime and restore defaults
|
||||
with op.batch_alter_table("folder", schema=None) as batch_op:
|
||||
with op.batch_alter_table('folder', schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
"created_at",
|
||||
'created_at',
|
||||
type_=sa.DateTime(),
|
||||
existing_type=sa.BigInteger(),
|
||||
existing_nullable=False,
|
||||
server_default=sa.func.now(), # Restoring server default on downgrade
|
||||
)
|
||||
batch_op.alter_column(
|
||||
"updated_at",
|
||||
'updated_at',
|
||||
type_=sa.DateTime(),
|
||||
existing_type=sa.BigInteger(),
|
||||
existing_nullable=False,
|
||||
|
||||
@@ -9,40 +9,40 @@ Create Date: 2024-12-22 03:00:00.000000
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "57c599a3cb57"
|
||||
down_revision = "922e7a387820"
|
||||
revision = '57c599a3cb57'
|
||||
down_revision = '922e7a387820'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"channel",
|
||||
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
|
||||
sa.Column("user_id", sa.Text()),
|
||||
sa.Column("name", sa.Text()),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("data", sa.JSON(), nullable=True),
|
||||
sa.Column("meta", sa.JSON(), nullable=True),
|
||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
'channel',
|
||||
sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True),
|
||||
sa.Column('user_id', sa.Text()),
|
||||
sa.Column('name', sa.Text()),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('data', sa.JSON(), nullable=True),
|
||||
sa.Column('meta', sa.JSON(), nullable=True),
|
||||
sa.Column('access_control', sa.JSON(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"message",
|
||||
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
|
||||
sa.Column("user_id", sa.Text()),
|
||||
sa.Column("channel_id", sa.Text(), nullable=True),
|
||||
sa.Column("content", sa.Text()),
|
||||
sa.Column("data", sa.JSON(), nullable=True),
|
||||
sa.Column("meta", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
'message',
|
||||
sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True),
|
||||
sa.Column('user_id', sa.Text()),
|
||||
sa.Column('channel_id', sa.Text(), nullable=True),
|
||||
sa.Column('content', sa.Text()),
|
||||
sa.Column('data', sa.JSON(), nullable=True),
|
||||
sa.Column('meta', sa.JSON(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_table("channel")
|
||||
op.drop_table('channel')
|
||||
|
||||
op.drop_table("message")
|
||||
op.drop_table('message')
|
||||
|
||||
@@ -13,41 +13,39 @@ import sqlalchemy as sa
|
||||
import open_webui.internal.db
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "6283dc0e4d8d"
|
||||
down_revision: Union[str, None] = "3e0e00844bb0"
|
||||
revision: str = '6283dc0e4d8d'
|
||||
down_revision: Union[str, None] = '3e0e00844bb0'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"channel_file",
|
||||
sa.Column("id", sa.Text(), primary_key=True),
|
||||
sa.Column("user_id", sa.Text(), nullable=False),
|
||||
'channel_file',
|
||||
sa.Column('id', sa.Text(), primary_key=True),
|
||||
sa.Column('user_id', sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"channel_id",
|
||||
'channel_id',
|
||||
sa.Text(),
|
||||
sa.ForeignKey("channel.id", ondelete="CASCADE"),
|
||||
sa.ForeignKey('channel.id', ondelete='CASCADE'),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"file_id",
|
||||
'file_id',
|
||||
sa.Text(),
|
||||
sa.ForeignKey("file.id", ondelete="CASCADE"),
|
||||
sa.ForeignKey('file.id', ondelete='CASCADE'),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=False),
|
||||
# indexes
|
||||
sa.Index("ix_channel_file_channel_id", "channel_id"),
|
||||
sa.Index("ix_channel_file_file_id", "file_id"),
|
||||
sa.Index("ix_channel_file_user_id", "user_id"),
|
||||
sa.Index('ix_channel_file_channel_id', 'channel_id'),
|
||||
sa.Index('ix_channel_file_file_id', 'file_id'),
|
||||
sa.Index('ix_channel_file_user_id', 'user_id'),
|
||||
# unique constraints
|
||||
sa.UniqueConstraint(
|
||||
"channel_id", "file_id", name="uq_channel_file_channel_file"
|
||||
), # prevent duplicate entries
|
||||
sa.UniqueConstraint('channel_id', 'file_id', name='uq_channel_file_channel_file'), # prevent duplicate entries
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("channel_file")
|
||||
op.drop_table('channel_file')
|
||||
|
||||
@@ -11,37 +11,37 @@ import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, column, select
|
||||
import json
|
||||
|
||||
revision = "6a39f3d8e55c"
|
||||
down_revision = "c0fbf31ca0db"
|
||||
revision = '6a39f3d8e55c'
|
||||
down_revision = 'c0fbf31ca0db'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Creating the 'knowledge' table
|
||||
print("Creating knowledge table")
|
||||
print('Creating knowledge table')
|
||||
knowledge_table = op.create_table(
|
||||
"knowledge",
|
||||
sa.Column("id", sa.Text(), primary_key=True),
|
||||
sa.Column("user_id", sa.Text(), nullable=False),
|
||||
sa.Column("name", sa.Text(), nullable=False),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("data", sa.JSON(), nullable=True),
|
||||
sa.Column("meta", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
'knowledge',
|
||||
sa.Column('id', sa.Text(), primary_key=True),
|
||||
sa.Column('user_id', sa.Text(), nullable=False),
|
||||
sa.Column('name', sa.Text(), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('data', sa.JSON(), nullable=True),
|
||||
sa.Column('meta', sa.JSON(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
)
|
||||
|
||||
print("Migrating data from document table to knowledge table")
|
||||
print('Migrating data from document table to knowledge table')
|
||||
# Representation of the existing 'document' table
|
||||
document_table = table(
|
||||
"document",
|
||||
column("collection_name", sa.String()),
|
||||
column("user_id", sa.String()),
|
||||
column("name", sa.String()),
|
||||
column("title", sa.Text()),
|
||||
column("content", sa.Text()),
|
||||
column("timestamp", sa.BigInteger()),
|
||||
'document',
|
||||
column('collection_name', sa.String()),
|
||||
column('user_id', sa.String()),
|
||||
column('name', sa.String()),
|
||||
column('title', sa.Text()),
|
||||
column('content', sa.Text()),
|
||||
column('timestamp', sa.BigInteger()),
|
||||
)
|
||||
|
||||
# Select all from existing document table
|
||||
@@ -64,9 +64,9 @@ def upgrade():
|
||||
user_id=doc.user_id,
|
||||
description=doc.name,
|
||||
meta={
|
||||
"legacy": True,
|
||||
"document": True,
|
||||
"tags": json.loads(doc.content or "{}").get("tags", []),
|
||||
'legacy': True,
|
||||
'document': True,
|
||||
'tags': json.loads(doc.content or '{}').get('tags', []),
|
||||
},
|
||||
name=doc.title,
|
||||
created_at=doc.timestamp,
|
||||
@@ -76,4 +76,4 @@ def upgrade():
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_table("knowledge")
|
||||
op.drop_table('knowledge')
|
||||
|
||||
@@ -9,18 +9,18 @@ Create Date: 2024-12-23 03:00:00.000000
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "7826ab40b532"
|
||||
down_revision = "57c599a3cb57"
|
||||
revision = '7826ab40b532'
|
||||
down_revision = '57c599a3cb57'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.add_column(
|
||||
"file",
|
||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
||||
'file',
|
||||
sa.Column('access_control', sa.JSON(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_column("file", "access_control")
|
||||
op.drop_column('file', 'access_control')
|
||||
|
||||
@@ -16,7 +16,7 @@ from open_webui.internal.db import JSONField
|
||||
from open_webui.migrations.util import get_existing_tables
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "7e5b5dc7342b"
|
||||
revision: str = '7e5b5dc7342b'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
@@ -26,179 +26,179 @@ def upgrade() -> None:
|
||||
existing_tables = set(get_existing_tables())
|
||||
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
if "auth" not in existing_tables:
|
||||
if 'auth' not in existing_tables:
|
||||
op.create_table(
|
||||
"auth",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("email", sa.String(), nullable=True),
|
||||
sa.Column("password", sa.Text(), nullable=True),
|
||||
sa.Column("active", sa.Boolean(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
'auth',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('email', sa.String(), nullable=True),
|
||||
sa.Column('password', sa.Text(), nullable=True),
|
||||
sa.Column('active', sa.Boolean(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
|
||||
if "chat" not in existing_tables:
|
||||
if 'chat' not in existing_tables:
|
||||
op.create_table(
|
||||
"chat",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("title", sa.Text(), nullable=True),
|
||||
sa.Column("chat", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("share_id", sa.Text(), nullable=True),
|
||||
sa.Column("archived", sa.Boolean(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("share_id"),
|
||||
'chat',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('title', sa.Text(), nullable=True),
|
||||
sa.Column('chat', sa.Text(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('share_id', sa.Text(), nullable=True),
|
||||
sa.Column('archived', sa.Boolean(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('share_id'),
|
||||
)
|
||||
|
||||
if "chatidtag" not in existing_tables:
|
||||
if 'chatidtag' not in existing_tables:
|
||||
op.create_table(
|
||||
"chatidtag",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("tag_name", sa.String(), nullable=True),
|
||||
sa.Column("chat_id", sa.String(), nullable=True),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("timestamp", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
'chatidtag',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('tag_name', sa.String(), nullable=True),
|
||||
sa.Column('chat_id', sa.String(), nullable=True),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('timestamp', sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
|
||||
if "document" not in existing_tables:
|
||||
if 'document' not in existing_tables:
|
||||
op.create_table(
|
||||
"document",
|
||||
sa.Column("collection_name", sa.String(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=True),
|
||||
sa.Column("title", sa.Text(), nullable=True),
|
||||
sa.Column("filename", sa.Text(), nullable=True),
|
||||
sa.Column("content", sa.Text(), nullable=True),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("timestamp", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("collection_name"),
|
||||
sa.UniqueConstraint("name"),
|
||||
'document',
|
||||
sa.Column('collection_name', sa.String(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=True),
|
||||
sa.Column('title', sa.Text(), nullable=True),
|
||||
sa.Column('filename', sa.Text(), nullable=True),
|
||||
sa.Column('content', sa.Text(), nullable=True),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('timestamp', sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('collection_name'),
|
||||
sa.UniqueConstraint('name'),
|
||||
)
|
||||
|
||||
if "file" not in existing_tables:
|
||||
if 'file' not in existing_tables:
|
||||
op.create_table(
|
||||
"file",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("filename", sa.Text(), nullable=True),
|
||||
sa.Column("meta", JSONField(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
'file',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('filename', sa.Text(), nullable=True),
|
||||
sa.Column('meta', JSONField(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
|
||||
if "function" not in existing_tables:
|
||||
if 'function' not in existing_tables:
|
||||
op.create_table(
|
||||
"function",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("name", sa.Text(), nullable=True),
|
||||
sa.Column("type", sa.Text(), nullable=True),
|
||||
sa.Column("content", sa.Text(), nullable=True),
|
||||
sa.Column("meta", JSONField(), nullable=True),
|
||||
sa.Column("valves", JSONField(), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=True),
|
||||
sa.Column("is_global", sa.Boolean(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
'function',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('name', sa.Text(), nullable=True),
|
||||
sa.Column('type', sa.Text(), nullable=True),
|
||||
sa.Column('content', sa.Text(), nullable=True),
|
||||
sa.Column('meta', JSONField(), nullable=True),
|
||||
sa.Column('valves', JSONField(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True),
|
||||
sa.Column('is_global', sa.Boolean(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
|
||||
if "memory" not in existing_tables:
|
||||
if 'memory' not in existing_tables:
|
||||
op.create_table(
|
||||
"memory",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("content", sa.Text(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
'memory',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('content', sa.Text(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
|
||||
if "model" not in existing_tables:
|
||||
if 'model' not in existing_tables:
|
||||
op.create_table(
|
||||
"model",
|
||||
sa.Column("id", sa.Text(), nullable=False),
|
||||
sa.Column("user_id", sa.Text(), nullable=True),
|
||||
sa.Column("base_model_id", sa.Text(), nullable=True),
|
||||
sa.Column("name", sa.Text(), nullable=True),
|
||||
sa.Column("params", JSONField(), nullable=True),
|
||||
sa.Column("meta", JSONField(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
'model',
|
||||
sa.Column('id', sa.Text(), nullable=False),
|
||||
sa.Column('user_id', sa.Text(), nullable=True),
|
||||
sa.Column('base_model_id', sa.Text(), nullable=True),
|
||||
sa.Column('name', sa.Text(), nullable=True),
|
||||
sa.Column('params', JSONField(), nullable=True),
|
||||
sa.Column('meta', JSONField(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
|
||||
if "prompt" not in existing_tables:
|
||||
if 'prompt' not in existing_tables:
|
||||
op.create_table(
|
||||
"prompt",
|
||||
sa.Column("command", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("title", sa.Text(), nullable=True),
|
||||
sa.Column("content", sa.Text(), nullable=True),
|
||||
sa.Column("timestamp", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("command"),
|
||||
'prompt',
|
||||
sa.Column('command', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('title', sa.Text(), nullable=True),
|
||||
sa.Column('content', sa.Text(), nullable=True),
|
||||
sa.Column('timestamp', sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('command'),
|
||||
)
|
||||
|
||||
if "tag" not in existing_tables:
|
||||
if 'tag' not in existing_tables:
|
||||
op.create_table(
|
||||
"tag",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=True),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("data", sa.Text(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
'tag',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=True),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('data', sa.Text(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
|
||||
if "tool" not in existing_tables:
|
||||
if 'tool' not in existing_tables:
|
||||
op.create_table(
|
||||
"tool",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("name", sa.Text(), nullable=True),
|
||||
sa.Column("content", sa.Text(), nullable=True),
|
||||
sa.Column("specs", JSONField(), nullable=True),
|
||||
sa.Column("meta", JSONField(), nullable=True),
|
||||
sa.Column("valves", JSONField(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
'tool',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('name', sa.Text(), nullable=True),
|
||||
sa.Column('content', sa.Text(), nullable=True),
|
||||
sa.Column('specs', JSONField(), nullable=True),
|
||||
sa.Column('meta', JSONField(), nullable=True),
|
||||
sa.Column('valves', JSONField(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
|
||||
if "user" not in existing_tables:
|
||||
if 'user' not in existing_tables:
|
||||
op.create_table(
|
||||
"user",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=True),
|
||||
sa.Column("email", sa.String(), nullable=True),
|
||||
sa.Column("role", sa.String(), nullable=True),
|
||||
sa.Column("profile_image_url", sa.Text(), nullable=True),
|
||||
sa.Column("last_active_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("api_key", sa.String(), nullable=True),
|
||||
sa.Column("settings", JSONField(), nullable=True),
|
||||
sa.Column("info", JSONField(), nullable=True),
|
||||
sa.Column("oauth_sub", sa.Text(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("api_key"),
|
||||
sa.UniqueConstraint("oauth_sub"),
|
||||
'user',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=True),
|
||||
sa.Column('email', sa.String(), nullable=True),
|
||||
sa.Column('role', sa.String(), nullable=True),
|
||||
sa.Column('profile_image_url', sa.Text(), nullable=True),
|
||||
sa.Column('last_active_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('api_key', sa.String(), nullable=True),
|
||||
sa.Column('settings', JSONField(), nullable=True),
|
||||
sa.Column('info', JSONField(), nullable=True),
|
||||
sa.Column('oauth_sub', sa.Text(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('api_key'),
|
||||
sa.UniqueConstraint('oauth_sub'),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("user")
|
||||
op.drop_table("tool")
|
||||
op.drop_table("tag")
|
||||
op.drop_table("prompt")
|
||||
op.drop_table("model")
|
||||
op.drop_table("memory")
|
||||
op.drop_table("function")
|
||||
op.drop_table("file")
|
||||
op.drop_table("document")
|
||||
op.drop_table("chatidtag")
|
||||
op.drop_table("chat")
|
||||
op.drop_table("auth")
|
||||
op.drop_table('user')
|
||||
op.drop_table('tool')
|
||||
op.drop_table('tag')
|
||||
op.drop_table('prompt')
|
||||
op.drop_table('model')
|
||||
op.drop_table('memory')
|
||||
op.drop_table('function')
|
||||
op.drop_table('file')
|
||||
op.drop_table('document')
|
||||
op.drop_table('chatidtag')
|
||||
op.drop_table('chat')
|
||||
op.drop_table('auth')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -13,36 +13,34 @@ import sqlalchemy as sa
|
||||
import open_webui.internal.db
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "81cc2ce44d79"
|
||||
down_revision: Union[str, None] = "6283dc0e4d8d"
|
||||
revision: str = '81cc2ce44d79'
|
||||
down_revision: Union[str, None] = '6283dc0e4d8d'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add message_id column to channel_file table
|
||||
with op.batch_alter_table("channel_file", schema=None) as batch_op:
|
||||
with op.batch_alter_table('channel_file', schema=None) as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"message_id",
|
||||
'message_id',
|
||||
sa.Text(),
|
||||
sa.ForeignKey(
|
||||
"message.id", ondelete="CASCADE", name="fk_channel_file_message_id"
|
||||
),
|
||||
sa.ForeignKey('message.id', ondelete='CASCADE', name='fk_channel_file_message_id'),
|
||||
nullable=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Add data column to knowledge table
|
||||
with op.batch_alter_table("knowledge", schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column("data", sa.JSON(), nullable=True))
|
||||
with op.batch_alter_table('knowledge', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('data', sa.JSON(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove message_id column from channel_file table
|
||||
with op.batch_alter_table("channel_file", schema=None) as batch_op:
|
||||
batch_op.drop_column("message_id")
|
||||
with op.batch_alter_table('channel_file', schema=None) as batch_op:
|
||||
batch_op.drop_column('message_id')
|
||||
|
||||
# Remove data column from knowledge table
|
||||
with op.batch_alter_table("knowledge", schema=None) as batch_op:
|
||||
batch_op.drop_column("data")
|
||||
with op.batch_alter_table('knowledge', schema=None) as batch_op:
|
||||
batch_op.drop_column('data')
|
||||
|
||||
@@ -16,8 +16,8 @@ import sqlalchemy as sa
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
revision: str = "8452d01d26d7"
|
||||
down_revision: Union[str, None] = "374d2f66af06"
|
||||
revision: str = '8452d01d26d7'
|
||||
down_revision: Union[str, None] = '374d2f66af06'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
@@ -51,74 +51,68 @@ def _flush_batch(conn, table, batch):
|
||||
except Exception as e:
|
||||
sp.rollback()
|
||||
failed += 1
|
||||
log.warning(f"Failed to insert message {msg['id']}: {e}")
|
||||
log.warning(f'Failed to insert message {msg["id"]}: {e}')
|
||||
return inserted, failed
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Step 1: Create table
|
||||
op.create_table(
|
||||
"chat_message",
|
||||
sa.Column("id", sa.Text(), primary_key=True),
|
||||
sa.Column("chat_id", sa.Text(), nullable=False, index=True),
|
||||
sa.Column("user_id", sa.Text(), index=True),
|
||||
sa.Column("role", sa.Text(), nullable=False),
|
||||
sa.Column("parent_id", sa.Text(), nullable=True),
|
||||
sa.Column("content", sa.JSON(), nullable=True),
|
||||
sa.Column("output", sa.JSON(), nullable=True),
|
||||
sa.Column("model_id", sa.Text(), nullable=True, index=True),
|
||||
sa.Column("files", sa.JSON(), nullable=True),
|
||||
sa.Column("sources", sa.JSON(), nullable=True),
|
||||
sa.Column("embeds", sa.JSON(), nullable=True),
|
||||
sa.Column("done", sa.Boolean(), default=True),
|
||||
sa.Column("status_history", sa.JSON(), nullable=True),
|
||||
sa.Column("error", sa.JSON(), nullable=True),
|
||||
sa.Column("usage", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), index=True),
|
||||
sa.Column("updated_at", sa.BigInteger()),
|
||||
sa.ForeignKeyConstraint(["chat_id"], ["chat.id"], ondelete="CASCADE"),
|
||||
'chat_message',
|
||||
sa.Column('id', sa.Text(), primary_key=True),
|
||||
sa.Column('chat_id', sa.Text(), nullable=False, index=True),
|
||||
sa.Column('user_id', sa.Text(), index=True),
|
||||
sa.Column('role', sa.Text(), nullable=False),
|
||||
sa.Column('parent_id', sa.Text(), nullable=True),
|
||||
sa.Column('content', sa.JSON(), nullable=True),
|
||||
sa.Column('output', sa.JSON(), nullable=True),
|
||||
sa.Column('model_id', sa.Text(), nullable=True, index=True),
|
||||
sa.Column('files', sa.JSON(), nullable=True),
|
||||
sa.Column('sources', sa.JSON(), nullable=True),
|
||||
sa.Column('embeds', sa.JSON(), nullable=True),
|
||||
sa.Column('done', sa.Boolean(), default=True),
|
||||
sa.Column('status_history', sa.JSON(), nullable=True),
|
||||
sa.Column('error', sa.JSON(), nullable=True),
|
||||
sa.Column('usage', sa.JSON(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), index=True),
|
||||
sa.Column('updated_at', sa.BigInteger()),
|
||||
sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ondelete='CASCADE'),
|
||||
)
|
||||
|
||||
# Create composite indexes
|
||||
op.create_index(
|
||||
"chat_message_chat_parent_idx", "chat_message", ["chat_id", "parent_id"]
|
||||
)
|
||||
op.create_index(
|
||||
"chat_message_model_created_idx", "chat_message", ["model_id", "created_at"]
|
||||
)
|
||||
op.create_index(
|
||||
"chat_message_user_created_idx", "chat_message", ["user_id", "created_at"]
|
||||
)
|
||||
op.create_index('chat_message_chat_parent_idx', 'chat_message', ['chat_id', 'parent_id'])
|
||||
op.create_index('chat_message_model_created_idx', 'chat_message', ['model_id', 'created_at'])
|
||||
op.create_index('chat_message_user_created_idx', 'chat_message', ['user_id', 'created_at'])
|
||||
|
||||
# Step 2: Backfill from existing chats
|
||||
conn = op.get_bind()
|
||||
|
||||
chat_table = sa.table(
|
||||
"chat",
|
||||
sa.column("id", sa.Text()),
|
||||
sa.column("user_id", sa.Text()),
|
||||
sa.column("chat", sa.JSON()),
|
||||
'chat',
|
||||
sa.column('id', sa.Text()),
|
||||
sa.column('user_id', sa.Text()),
|
||||
sa.column('chat', sa.JSON()),
|
||||
)
|
||||
|
||||
chat_message_table = sa.table(
|
||||
"chat_message",
|
||||
sa.column("id", sa.Text()),
|
||||
sa.column("chat_id", sa.Text()),
|
||||
sa.column("user_id", sa.Text()),
|
||||
sa.column("role", sa.Text()),
|
||||
sa.column("parent_id", sa.Text()),
|
||||
sa.column("content", sa.JSON()),
|
||||
sa.column("output", sa.JSON()),
|
||||
sa.column("model_id", sa.Text()),
|
||||
sa.column("files", sa.JSON()),
|
||||
sa.column("sources", sa.JSON()),
|
||||
sa.column("embeds", sa.JSON()),
|
||||
sa.column("done", sa.Boolean()),
|
||||
sa.column("status_history", sa.JSON()),
|
||||
sa.column("error", sa.JSON()),
|
||||
sa.column("usage", sa.JSON()),
|
||||
sa.column("created_at", sa.BigInteger()),
|
||||
sa.column("updated_at", sa.BigInteger()),
|
||||
'chat_message',
|
||||
sa.column('id', sa.Text()),
|
||||
sa.column('chat_id', sa.Text()),
|
||||
sa.column('user_id', sa.Text()),
|
||||
sa.column('role', sa.Text()),
|
||||
sa.column('parent_id', sa.Text()),
|
||||
sa.column('content', sa.JSON()),
|
||||
sa.column('output', sa.JSON()),
|
||||
sa.column('model_id', sa.Text()),
|
||||
sa.column('files', sa.JSON()),
|
||||
sa.column('sources', sa.JSON()),
|
||||
sa.column('embeds', sa.JSON()),
|
||||
sa.column('done', sa.Boolean()),
|
||||
sa.column('status_history', sa.JSON()),
|
||||
sa.column('error', sa.JSON()),
|
||||
sa.column('usage', sa.JSON()),
|
||||
sa.column('created_at', sa.BigInteger()),
|
||||
sa.column('updated_at', sa.BigInteger()),
|
||||
)
|
||||
|
||||
# Stream rows instead of loading all into memory:
|
||||
@@ -126,7 +120,7 @@ def upgrade() -> None:
|
||||
# - stream_results: enables server-side cursors on PostgreSQL (no-op on SQLite)
|
||||
result = conn.execute(
|
||||
sa.select(chat_table.c.id, chat_table.c.user_id, chat_table.c.chat)
|
||||
.where(~chat_table.c.user_id.like("shared-%"))
|
||||
.where(~chat_table.c.user_id.like('shared-%'))
|
||||
.execution_options(yield_per=1000, stream_results=True)
|
||||
)
|
||||
|
||||
@@ -150,11 +144,11 @@ def upgrade() -> None:
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
history = chat_data.get("history", {})
|
||||
history = chat_data.get('history', {})
|
||||
if not isinstance(history, dict):
|
||||
continue
|
||||
|
||||
messages = history.get("messages", {})
|
||||
messages = history.get('messages', {})
|
||||
if not isinstance(messages, dict):
|
||||
continue
|
||||
|
||||
@@ -162,11 +156,11 @@ def upgrade() -> None:
|
||||
if not isinstance(message, dict):
|
||||
continue
|
||||
|
||||
role = message.get("role")
|
||||
role = message.get('role')
|
||||
if not role:
|
||||
continue
|
||||
|
||||
timestamp = message.get("timestamp", now)
|
||||
timestamp = message.get('timestamp', now)
|
||||
|
||||
try:
|
||||
timestamp = int(float(timestamp))
|
||||
@@ -182,37 +176,33 @@ def upgrade() -> None:
|
||||
|
||||
messages_batch.append(
|
||||
{
|
||||
"id": f"{chat_id}-{message_id}",
|
||||
"chat_id": chat_id,
|
||||
"user_id": user_id,
|
||||
"role": role,
|
||||
"parent_id": message.get("parentId"),
|
||||
"content": message.get("content"),
|
||||
"output": message.get("output"),
|
||||
"model_id": message.get("model"),
|
||||
"files": message.get("files"),
|
||||
"sources": message.get("sources"),
|
||||
"embeds": message.get("embeds"),
|
||||
"done": message.get("done", True),
|
||||
"status_history": message.get("statusHistory"),
|
||||
"error": message.get("error"),
|
||||
"usage": message.get("usage"),
|
||||
"created_at": timestamp,
|
||||
"updated_at": timestamp,
|
||||
'id': f'{chat_id}-{message_id}',
|
||||
'chat_id': chat_id,
|
||||
'user_id': user_id,
|
||||
'role': role,
|
||||
'parent_id': message.get('parentId'),
|
||||
'content': message.get('content'),
|
||||
'output': message.get('output'),
|
||||
'model_id': message.get('model'),
|
||||
'files': message.get('files'),
|
||||
'sources': message.get('sources'),
|
||||
'embeds': message.get('embeds'),
|
||||
'done': message.get('done', True),
|
||||
'status_history': message.get('statusHistory'),
|
||||
'error': message.get('error'),
|
||||
'usage': message.get('usage'),
|
||||
'created_at': timestamp,
|
||||
'updated_at': timestamp,
|
||||
}
|
||||
)
|
||||
|
||||
# Flush batch when full
|
||||
if len(messages_batch) >= BATCH_SIZE:
|
||||
inserted, failed = _flush_batch(
|
||||
conn, chat_message_table, messages_batch
|
||||
)
|
||||
inserted, failed = _flush_batch(conn, chat_message_table, messages_batch)
|
||||
total_inserted += inserted
|
||||
total_failed += failed
|
||||
if total_inserted % 50000 < BATCH_SIZE:
|
||||
log.info(
|
||||
f"Migration progress: {total_inserted} messages inserted..."
|
||||
)
|
||||
log.info(f'Migration progress: {total_inserted} messages inserted...')
|
||||
messages_batch.clear()
|
||||
|
||||
# Flush remaining messages
|
||||
@@ -221,13 +211,11 @@ def upgrade() -> None:
|
||||
total_inserted += inserted
|
||||
total_failed += failed
|
||||
|
||||
log.info(
|
||||
f"Backfilled {total_inserted} messages into chat_message table ({total_failed} failed)"
|
||||
)
|
||||
log.info(f'Backfilled {total_inserted} messages into chat_message table ({total_failed} failed)')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("chat_message_user_created_idx", table_name="chat_message")
|
||||
op.drop_index("chat_message_model_created_idx", table_name="chat_message")
|
||||
op.drop_index("chat_message_chat_parent_idx", table_name="chat_message")
|
||||
op.drop_table("chat_message")
|
||||
op.drop_index('chat_message_user_created_idx', table_name='chat_message')
|
||||
op.drop_index('chat_message_model_created_idx', table_name='chat_message')
|
||||
op.drop_index('chat_message_chat_parent_idx', table_name='chat_message')
|
||||
op.drop_table('chat_message')
|
||||
|
||||
@@ -13,48 +13,46 @@ import sqlalchemy as sa
|
||||
import open_webui.internal.db
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "90ef40d4714e"
|
||||
down_revision: Union[str, None] = "b10670c03dd5"
|
||||
revision: str = '90ef40d4714e'
|
||||
down_revision: Union[str, None] = 'b10670c03dd5'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Update 'channel' table
|
||||
op.add_column("channel", sa.Column("is_private", sa.Boolean(), nullable=True))
|
||||
op.add_column('channel', sa.Column('is_private', sa.Boolean(), nullable=True))
|
||||
|
||||
op.add_column("channel", sa.Column("archived_at", sa.BigInteger(), nullable=True))
|
||||
op.add_column("channel", sa.Column("archived_by", sa.Text(), nullable=True))
|
||||
op.add_column('channel', sa.Column('archived_at', sa.BigInteger(), nullable=True))
|
||||
op.add_column('channel', sa.Column('archived_by', sa.Text(), nullable=True))
|
||||
|
||||
op.add_column("channel", sa.Column("deleted_at", sa.BigInteger(), nullable=True))
|
||||
op.add_column("channel", sa.Column("deleted_by", sa.Text(), nullable=True))
|
||||
op.add_column('channel', sa.Column('deleted_at', sa.BigInteger(), nullable=True))
|
||||
op.add_column('channel', sa.Column('deleted_by', sa.Text(), nullable=True))
|
||||
|
||||
op.add_column("channel", sa.Column("updated_by", sa.Text(), nullable=True))
|
||||
op.add_column('channel', sa.Column('updated_by', sa.Text(), nullable=True))
|
||||
|
||||
# Update 'channel_member' table
|
||||
op.add_column("channel_member", sa.Column("role", sa.Text(), nullable=True))
|
||||
op.add_column("channel_member", sa.Column("invited_by", sa.Text(), nullable=True))
|
||||
op.add_column(
|
||||
"channel_member", sa.Column("invited_at", sa.BigInteger(), nullable=True)
|
||||
)
|
||||
op.add_column('channel_member', sa.Column('role', sa.Text(), nullable=True))
|
||||
op.add_column('channel_member', sa.Column('invited_by', sa.Text(), nullable=True))
|
||||
op.add_column('channel_member', sa.Column('invited_at', sa.BigInteger(), nullable=True))
|
||||
|
||||
# Create 'channel_webhook' table
|
||||
op.create_table(
|
||||
"channel_webhook",
|
||||
sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False),
|
||||
sa.Column("user_id", sa.Text(), nullable=False),
|
||||
'channel_webhook',
|
||||
sa.Column('id', sa.Text(), primary_key=True, unique=True, nullable=False),
|
||||
sa.Column('user_id', sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"channel_id",
|
||||
'channel_id',
|
||||
sa.Text(),
|
||||
sa.ForeignKey("channel.id", ondelete="CASCADE"),
|
||||
sa.ForeignKey('channel.id', ondelete='CASCADE'),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("name", sa.Text(), nullable=False),
|
||||
sa.Column("profile_image_url", sa.Text(), nullable=True),
|
||||
sa.Column("token", sa.Text(), nullable=False),
|
||||
sa.Column("last_used_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column('name', sa.Text(), nullable=False),
|
||||
sa.Column('profile_image_url', sa.Text(), nullable=True),
|
||||
sa.Column('token', sa.Text(), nullable=False),
|
||||
sa.Column('last_used_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=False),
|
||||
)
|
||||
|
||||
pass
|
||||
@@ -62,19 +60,19 @@ def upgrade() -> None:
|
||||
|
||||
def downgrade() -> None:
|
||||
# Downgrade 'channel' table
|
||||
op.drop_column("channel", "is_private")
|
||||
op.drop_column("channel", "archived_at")
|
||||
op.drop_column("channel", "archived_by")
|
||||
op.drop_column("channel", "deleted_at")
|
||||
op.drop_column("channel", "deleted_by")
|
||||
op.drop_column("channel", "updated_by")
|
||||
op.drop_column('channel', 'is_private')
|
||||
op.drop_column('channel', 'archived_at')
|
||||
op.drop_column('channel', 'archived_by')
|
||||
op.drop_column('channel', 'deleted_at')
|
||||
op.drop_column('channel', 'deleted_by')
|
||||
op.drop_column('channel', 'updated_by')
|
||||
|
||||
# Downgrade 'channel_member' table
|
||||
op.drop_column("channel_member", "role")
|
||||
op.drop_column("channel_member", "invited_by")
|
||||
op.drop_column("channel_member", "invited_at")
|
||||
op.drop_column('channel_member', 'role')
|
||||
op.drop_column('channel_member', 'invited_by')
|
||||
op.drop_column('channel_member', 'invited_at')
|
||||
|
||||
# Drop 'channel_webhook' table
|
||||
op.drop_table("channel_webhook")
|
||||
op.drop_table('channel_webhook')
|
||||
|
||||
pass
|
||||
|
||||
@@ -9,38 +9,38 @@ Create Date: 2024-11-14 03:00:00.000000
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "922e7a387820"
|
||||
down_revision = "4ace53fd72c8"
|
||||
revision = '922e7a387820'
|
||||
down_revision = '4ace53fd72c8'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"group",
|
||||
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
|
||||
sa.Column("user_id", sa.Text(), nullable=True),
|
||||
sa.Column("name", sa.Text(), nullable=True),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("data", sa.JSON(), nullable=True),
|
||||
sa.Column("meta", sa.JSON(), nullable=True),
|
||||
sa.Column("permissions", sa.JSON(), nullable=True),
|
||||
sa.Column("user_ids", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
'group',
|
||||
sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True),
|
||||
sa.Column('user_id', sa.Text(), nullable=True),
|
||||
sa.Column('name', sa.Text(), nullable=True),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('data', sa.JSON(), nullable=True),
|
||||
sa.Column('meta', sa.JSON(), nullable=True),
|
||||
sa.Column('permissions', sa.JSON(), nullable=True),
|
||||
sa.Column('user_ids', sa.JSON(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
)
|
||||
|
||||
# Add 'access_control' column to 'model' table
|
||||
op.add_column(
|
||||
"model",
|
||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
||||
'model',
|
||||
sa.Column('access_control', sa.JSON(), nullable=True),
|
||||
)
|
||||
|
||||
# Add 'is_active' column to 'model' table
|
||||
op.add_column(
|
||||
"model",
|
||||
'model',
|
||||
sa.Column(
|
||||
"is_active",
|
||||
'is_active',
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.sql.expression.true(),
|
||||
@@ -49,37 +49,37 @@ def upgrade():
|
||||
|
||||
# Add 'access_control' column to 'knowledge' table
|
||||
op.add_column(
|
||||
"knowledge",
|
||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
||||
'knowledge',
|
||||
sa.Column('access_control', sa.JSON(), nullable=True),
|
||||
)
|
||||
|
||||
# Add 'access_control' column to 'prompt' table
|
||||
op.add_column(
|
||||
"prompt",
|
||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
||||
'prompt',
|
||||
sa.Column('access_control', sa.JSON(), nullable=True),
|
||||
)
|
||||
|
||||
# Add 'access_control' column to 'tools' table
|
||||
op.add_column(
|
||||
"tool",
|
||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
||||
'tool',
|
||||
sa.Column('access_control', sa.JSON(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_table("group")
|
||||
op.drop_table('group')
|
||||
|
||||
# Drop 'access_control' column from 'model' table
|
||||
op.drop_column("model", "access_control")
|
||||
op.drop_column('model', 'access_control')
|
||||
|
||||
# Drop 'is_active' column from 'model' table
|
||||
op.drop_column("model", "is_active")
|
||||
op.drop_column('model', 'is_active')
|
||||
|
||||
# Drop 'access_control' column from 'knowledge' table
|
||||
op.drop_column("knowledge", "access_control")
|
||||
op.drop_column('knowledge', 'access_control')
|
||||
|
||||
# Drop 'access_control' column from 'prompt' table
|
||||
op.drop_column("prompt", "access_control")
|
||||
op.drop_column('prompt', 'access_control')
|
||||
|
||||
# Drop 'access_control' column from 'tools' table
|
||||
op.drop_column("tool", "access_control")
|
||||
op.drop_column('tool', 'access_control')
|
||||
|
||||
@@ -9,25 +9,25 @@ Create Date: 2025-05-03 03:00:00.000000
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "9f0c9cd09105"
|
||||
down_revision = "3781e22d8b01"
|
||||
revision = '9f0c9cd09105'
|
||||
down_revision = '3781e22d8b01'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"note",
|
||||
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
|
||||
sa.Column("user_id", sa.Text(), nullable=True),
|
||||
sa.Column("title", sa.Text(), nullable=True),
|
||||
sa.Column("data", sa.JSON(), nullable=True),
|
||||
sa.Column("meta", sa.JSON(), nullable=True),
|
||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
'note',
|
||||
sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True),
|
||||
sa.Column('user_id', sa.Text(), nullable=True),
|
||||
sa.Column('title', sa.Text(), nullable=True),
|
||||
sa.Column('data', sa.JSON(), nullable=True),
|
||||
sa.Column('meta', sa.JSON(), nullable=True),
|
||||
sa.Column('access_control', sa.JSON(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_table("note")
|
||||
op.drop_table('note')
|
||||
|
||||
@@ -13,8 +13,8 @@ import sqlalchemy as sa
|
||||
|
||||
from open_webui.migrations.util import get_existing_tables
|
||||
|
||||
revision: str = "a1b2c3d4e5f6"
|
||||
down_revision: Union[str, None] = "f1e2d3c4b5a6"
|
||||
revision: str = 'a1b2c3d4e5f6'
|
||||
down_revision: Union[str, None] = 'f1e2d3c4b5a6'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
@@ -22,24 +22,24 @@ depends_on: Union[str, Sequence[str], None] = None
|
||||
def upgrade() -> None:
|
||||
existing_tables = set(get_existing_tables())
|
||||
|
||||
if "skill" not in existing_tables:
|
||||
if 'skill' not in existing_tables:
|
||||
op.create_table(
|
||||
"skill",
|
||||
sa.Column("id", sa.String(), nullable=False, primary_key=True),
|
||||
sa.Column("user_id", sa.String(), nullable=False),
|
||||
sa.Column("name", sa.Text(), nullable=False, unique=True),
|
||||
sa.Column("description", sa.Text(), nullable=True),
|
||||
sa.Column("content", sa.Text(), nullable=False),
|
||||
sa.Column("meta", sa.JSON(), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||
'skill',
|
||||
sa.Column('id', sa.String(), nullable=False, primary_key=True),
|
||||
sa.Column('user_id', sa.String(), nullable=False),
|
||||
sa.Column('name', sa.Text(), nullable=False, unique=True),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('content', sa.Text(), nullable=False),
|
||||
sa.Column('meta', sa.JSON(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=False),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||
)
|
||||
op.create_index("idx_skill_user_id", "skill", ["user_id"])
|
||||
op.create_index("idx_skill_updated_at", "skill", ["updated_at"])
|
||||
op.create_index('idx_skill_user_id', 'skill', ['user_id'])
|
||||
op.create_index('idx_skill_updated_at', 'skill', ['updated_at'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("idx_skill_updated_at", table_name="skill")
|
||||
op.drop_index("idx_skill_user_id", table_name="skill")
|
||||
op.drop_table("skill")
|
||||
op.drop_index('idx_skill_updated_at', table_name='skill')
|
||||
op.drop_index('idx_skill_user_id', table_name='skill')
|
||||
op.drop_table('skill')
|
||||
|
||||
@@ -12,8 +12,8 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "a5c220713937"
|
||||
down_revision: Union[str, None] = "38d63c18f30f"
|
||||
revision: str = 'a5c220713937'
|
||||
down_revision: Union[str, None] = '38d63c18f30f'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
@@ -21,14 +21,14 @@ depends_on: Union[str, Sequence[str], None] = None
|
||||
def upgrade() -> None:
|
||||
# Add 'reply_to_id' column to the 'message' table for replying to messages
|
||||
op.add_column(
|
||||
"message",
|
||||
sa.Column("reply_to_id", sa.Text(), nullable=True),
|
||||
'message',
|
||||
sa.Column('reply_to_id', sa.Text(), nullable=True),
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove 'reply_to_id' column from the 'message' table
|
||||
op.drop_column("message", "reply_to_id")
|
||||
op.drop_column('message', 'reply_to_id')
|
||||
|
||||
pass
|
||||
|
||||
@@ -10,8 +10,8 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# Revision identifiers, used by Alembic.
|
||||
revision = "af906e964978"
|
||||
down_revision = "c29facfe716b"
|
||||
revision = 'af906e964978'
|
||||
down_revision = 'c29facfe716b'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
@@ -19,33 +19,23 @@ depends_on = None
|
||||
def upgrade():
|
||||
# ### Create feedback table ###
|
||||
op.create_table(
|
||||
"feedback",
|
||||
'feedback',
|
||||
sa.Column('id', sa.Text(), primary_key=True), # Unique identifier for each feedback (TEXT type)
|
||||
sa.Column('user_id', sa.Text(), nullable=True), # ID of the user providing the feedback (TEXT type)
|
||||
sa.Column('version', sa.BigInteger(), default=0), # Version of feedback (BIGINT type)
|
||||
sa.Column('type', sa.Text(), nullable=True), # Type of feedback (TEXT type)
|
||||
sa.Column('data', sa.JSON(), nullable=True), # Feedback data (JSON type)
|
||||
sa.Column('meta', sa.JSON(), nullable=True), # Metadata for feedback (JSON type)
|
||||
sa.Column('snapshot', sa.JSON(), nullable=True), # snapshot data for feedback (JSON type)
|
||||
sa.Column(
|
||||
"id", sa.Text(), primary_key=True
|
||||
), # Unique identifier for each feedback (TEXT type)
|
||||
sa.Column(
|
||||
"user_id", sa.Text(), nullable=True
|
||||
), # ID of the user providing the feedback (TEXT type)
|
||||
sa.Column(
|
||||
"version", sa.BigInteger(), default=0
|
||||
), # Version of feedback (BIGINT type)
|
||||
sa.Column("type", sa.Text(), nullable=True), # Type of feedback (TEXT type)
|
||||
sa.Column("data", sa.JSON(), nullable=True), # Feedback data (JSON type)
|
||||
sa.Column(
|
||||
"meta", sa.JSON(), nullable=True
|
||||
), # Metadata for feedback (JSON type)
|
||||
sa.Column(
|
||||
"snapshot", sa.JSON(), nullable=True
|
||||
), # snapshot data for feedback (JSON type)
|
||||
sa.Column(
|
||||
"created_at", sa.BigInteger(), nullable=False
|
||||
'created_at', sa.BigInteger(), nullable=False
|
||||
), # Feedback creation timestamp (BIGINT representing epoch)
|
||||
sa.Column(
|
||||
"updated_at", sa.BigInteger(), nullable=False
|
||||
'updated_at', sa.BigInteger(), nullable=False
|
||||
), # Feedback update timestamp (BIGINT representing epoch)
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### Drop feedback table ###
|
||||
op.drop_table("feedback")
|
||||
op.drop_table('feedback')
|
||||
|
||||
@@ -17,8 +17,8 @@ import json
|
||||
import time
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b10670c03dd5"
|
||||
down_revision: Union[str, None] = "2f1211949ecc"
|
||||
revision: str = 'b10670c03dd5'
|
||||
down_revision: Union[str, None] = '2f1211949ecc'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
@@ -33,13 +33,11 @@ def _drop_sqlite_indexes_for_column(table_name, column_name, conn):
|
||||
for idx in indexes:
|
||||
index_name = idx[1] # index name
|
||||
# Get indexed columns
|
||||
idx_info = conn.execute(
|
||||
sa.text(f"PRAGMA index_info('{index_name}')")
|
||||
).fetchall()
|
||||
idx_info = conn.execute(sa.text(f"PRAGMA index_info('{index_name}')")).fetchall()
|
||||
|
||||
indexed_cols = [row[2] for row in idx_info] # col names
|
||||
if column_name in indexed_cols:
|
||||
conn.execute(sa.text(f"DROP INDEX IF EXISTS {index_name}"))
|
||||
conn.execute(sa.text(f'DROP INDEX IF EXISTS {index_name}'))
|
||||
|
||||
|
||||
def _convert_column_to_json(table: str, column: str):
|
||||
@@ -47,9 +45,9 @@ def _convert_column_to_json(table: str, column: str):
|
||||
dialect = conn.dialect.name
|
||||
|
||||
# SQLite cannot ALTER COLUMN → must recreate column
|
||||
if dialect == "sqlite":
|
||||
if dialect == 'sqlite':
|
||||
# 1. Add temporary column
|
||||
op.add_column(table, sa.Column(f"{column}_json", sa.JSON(), nullable=True))
|
||||
op.add_column(table, sa.Column(f'{column}_json', sa.JSON(), nullable=True))
|
||||
|
||||
# 2. Load old data
|
||||
rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()
|
||||
@@ -66,14 +64,14 @@ def _convert_column_to_json(table: str, column: str):
|
||||
|
||||
conn.execute(
|
||||
sa.text(f'UPDATE "{table}" SET {column}_json = :val WHERE id = :id'),
|
||||
{"val": json.dumps(parsed) if parsed else None, "id": uid},
|
||||
{'val': json.dumps(parsed) if parsed else None, 'id': uid},
|
||||
)
|
||||
|
||||
# 3. Drop old TEXT column
|
||||
op.drop_column(table, column)
|
||||
|
||||
# 4. Rename new JSON column → original name
|
||||
op.alter_column(table, f"{column}_json", new_column_name=column)
|
||||
op.alter_column(table, f'{column}_json', new_column_name=column)
|
||||
|
||||
else:
|
||||
# PostgreSQL supports direct CAST
|
||||
@@ -81,7 +79,7 @@ def _convert_column_to_json(table: str, column: str):
|
||||
table,
|
||||
column,
|
||||
type_=sa.JSON(),
|
||||
postgresql_using=f"{column}::json",
|
||||
postgresql_using=f'{column}::json',
|
||||
)
|
||||
|
||||
|
||||
@@ -89,85 +87,77 @@ def _convert_column_to_text(table: str, column: str):
|
||||
conn = op.get_bind()
|
||||
dialect = conn.dialect.name
|
||||
|
||||
if dialect == "sqlite":
|
||||
op.add_column(table, sa.Column(f"{column}_text", sa.Text(), nullable=True))
|
||||
if dialect == 'sqlite':
|
||||
op.add_column(table, sa.Column(f'{column}_text', sa.Text(), nullable=True))
|
||||
|
||||
rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()
|
||||
|
||||
for uid, raw in rows:
|
||||
conn.execute(
|
||||
sa.text(f'UPDATE "{table}" SET {column}_text = :val WHERE id = :id'),
|
||||
{"val": json.dumps(raw) if raw else None, "id": uid},
|
||||
{'val': json.dumps(raw) if raw else None, 'id': uid},
|
||||
)
|
||||
|
||||
op.drop_column(table, column)
|
||||
op.alter_column(table, f"{column}_text", new_column_name=column)
|
||||
op.alter_column(table, f'{column}_text', new_column_name=column)
|
||||
|
||||
else:
|
||||
op.alter_column(
|
||||
table,
|
||||
column,
|
||||
type_=sa.Text(),
|
||||
postgresql_using=f"to_json({column})::text",
|
||||
postgresql_using=f'to_json({column})::text',
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user", sa.Column("profile_banner_image_url", sa.Text(), nullable=True)
|
||||
)
|
||||
op.add_column("user", sa.Column("timezone", sa.String(), nullable=True))
|
||||
op.add_column('user', sa.Column('profile_banner_image_url', sa.Text(), nullable=True))
|
||||
op.add_column('user', sa.Column('timezone', sa.String(), nullable=True))
|
||||
|
||||
op.add_column("user", sa.Column("presence_state", sa.String(), nullable=True))
|
||||
op.add_column("user", sa.Column("status_emoji", sa.String(), nullable=True))
|
||||
op.add_column("user", sa.Column("status_message", sa.Text(), nullable=True))
|
||||
op.add_column(
|
||||
"user", sa.Column("status_expires_at", sa.BigInteger(), nullable=True)
|
||||
)
|
||||
op.add_column('user', sa.Column('presence_state', sa.String(), nullable=True))
|
||||
op.add_column('user', sa.Column('status_emoji', sa.String(), nullable=True))
|
||||
op.add_column('user', sa.Column('status_message', sa.Text(), nullable=True))
|
||||
op.add_column('user', sa.Column('status_expires_at', sa.BigInteger(), nullable=True))
|
||||
|
||||
op.add_column("user", sa.Column("oauth", sa.JSON(), nullable=True))
|
||||
op.add_column('user', sa.Column('oauth', sa.JSON(), nullable=True))
|
||||
|
||||
# Convert info (TEXT/JSONField) → JSON
|
||||
_convert_column_to_json("user", "info")
|
||||
_convert_column_to_json('user', 'info')
|
||||
# Convert settings (TEXT/JSONField) → JSON
|
||||
_convert_column_to_json("user", "settings")
|
||||
_convert_column_to_json('user', 'settings')
|
||||
|
||||
op.create_table(
|
||||
"api_key",
|
||||
sa.Column("id", sa.Text(), primary_key=True, unique=True),
|
||||
sa.Column("user_id", sa.Text(), sa.ForeignKey("user.id", ondelete="CASCADE")),
|
||||
sa.Column("key", sa.Text(), unique=True, nullable=False),
|
||||
sa.Column("data", sa.JSON(), nullable=True),
|
||||
sa.Column("expires_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("last_used_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
||||
'api_key',
|
||||
sa.Column('id', sa.Text(), primary_key=True, unique=True),
|
||||
sa.Column('user_id', sa.Text(), sa.ForeignKey('user.id', ondelete='CASCADE')),
|
||||
sa.Column('key', sa.Text(), unique=True, nullable=False),
|
||||
sa.Column('data', sa.JSON(), nullable=True),
|
||||
sa.Column('expires_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('last_used_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=False),
|
||||
)
|
||||
|
||||
conn = op.get_bind()
|
||||
users = conn.execute(
|
||||
sa.text('SELECT id, oauth_sub FROM "user" WHERE oauth_sub IS NOT NULL')
|
||||
).fetchall()
|
||||
users = conn.execute(sa.text('SELECT id, oauth_sub FROM "user" WHERE oauth_sub IS NOT NULL')).fetchall()
|
||||
|
||||
for uid, oauth_sub in users:
|
||||
if oauth_sub:
|
||||
# Example formats supported:
|
||||
# provider@sub
|
||||
# plain sub (stored as {"oidc": {"sub": sub}})
|
||||
if "@" in oauth_sub:
|
||||
provider, sub = oauth_sub.split("@", 1)
|
||||
if '@' in oauth_sub:
|
||||
provider, sub = oauth_sub.split('@', 1)
|
||||
else:
|
||||
provider, sub = "oidc", oauth_sub
|
||||
provider, sub = 'oidc', oauth_sub
|
||||
|
||||
oauth_json = json.dumps({provider: {"sub": sub}})
|
||||
oauth_json = json.dumps({provider: {'sub': sub}})
|
||||
conn.execute(
|
||||
sa.text('UPDATE "user" SET oauth = :oauth WHERE id = :id'),
|
||||
{"oauth": oauth_json, "id": uid},
|
||||
{'oauth': oauth_json, 'id': uid},
|
||||
)
|
||||
|
||||
users_with_keys = conn.execute(
|
||||
sa.text('SELECT id, api_key FROM "user" WHERE api_key IS NOT NULL')
|
||||
).fetchall()
|
||||
users_with_keys = conn.execute(sa.text('SELECT id, api_key FROM "user" WHERE api_key IS NOT NULL')).fetchall()
|
||||
now = int(time.time())
|
||||
|
||||
for uid, api_key in users_with_keys:
|
||||
@@ -178,72 +168,70 @@ def upgrade() -> None:
|
||||
VALUES (:id, :user_id, :key, :created_at, :updated_at)
|
||||
"""),
|
||||
{
|
||||
"id": f"key_{uid}",
|
||||
"user_id": uid,
|
||||
"key": api_key,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
'id': f'key_{uid}',
|
||||
'user_id': uid,
|
||||
'key': api_key,
|
||||
'created_at': now,
|
||||
'updated_at': now,
|
||||
},
|
||||
)
|
||||
|
||||
if conn.dialect.name == "sqlite":
|
||||
_drop_sqlite_indexes_for_column("user", "api_key", conn)
|
||||
_drop_sqlite_indexes_for_column("user", "oauth_sub", conn)
|
||||
if conn.dialect.name == 'sqlite':
|
||||
_drop_sqlite_indexes_for_column('user', 'api_key', conn)
|
||||
_drop_sqlite_indexes_for_column('user', 'oauth_sub', conn)
|
||||
|
||||
with op.batch_alter_table("user") as batch_op:
|
||||
batch_op.drop_column("api_key")
|
||||
batch_op.drop_column("oauth_sub")
|
||||
with op.batch_alter_table('user') as batch_op:
|
||||
batch_op.drop_column('api_key')
|
||||
batch_op.drop_column('oauth_sub')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# --- 1. Restore old oauth_sub column ---
|
||||
op.add_column("user", sa.Column("oauth_sub", sa.Text(), nullable=True))
|
||||
op.add_column('user', sa.Column('oauth_sub', sa.Text(), nullable=True))
|
||||
|
||||
conn = op.get_bind()
|
||||
users = conn.execute(
|
||||
sa.text('SELECT id, oauth FROM "user" WHERE oauth IS NOT NULL')
|
||||
).fetchall()
|
||||
users = conn.execute(sa.text('SELECT id, oauth FROM "user" WHERE oauth IS NOT NULL')).fetchall()
|
||||
|
||||
for uid, oauth in users:
|
||||
try:
|
||||
data = json.loads(oauth)
|
||||
provider = list(data.keys())[0]
|
||||
sub = data[provider].get("sub")
|
||||
oauth_sub = f"{provider}@{sub}"
|
||||
sub = data[provider].get('sub')
|
||||
oauth_sub = f'{provider}@{sub}'
|
||||
except Exception:
|
||||
oauth_sub = None
|
||||
|
||||
conn.execute(
|
||||
sa.text('UPDATE "user" SET oauth_sub = :oauth_sub WHERE id = :id'),
|
||||
{"oauth_sub": oauth_sub, "id": uid},
|
||||
{'oauth_sub': oauth_sub, 'id': uid},
|
||||
)
|
||||
|
||||
op.drop_column("user", "oauth")
|
||||
op.drop_column('user', 'oauth')
|
||||
|
||||
# --- 2. Restore api_key field ---
|
||||
op.add_column("user", sa.Column("api_key", sa.String(), nullable=True))
|
||||
op.add_column('user', sa.Column('api_key', sa.String(), nullable=True))
|
||||
|
||||
# Restore values from api_key
|
||||
keys = conn.execute(sa.text("SELECT user_id, key FROM api_key")).fetchall()
|
||||
keys = conn.execute(sa.text('SELECT user_id, key FROM api_key')).fetchall()
|
||||
for uid, key in keys:
|
||||
conn.execute(
|
||||
sa.text('UPDATE "user" SET api_key = :key WHERE id = :id'),
|
||||
{"key": key, "id": uid},
|
||||
{'key': key, 'id': uid},
|
||||
)
|
||||
|
||||
# Drop new table
|
||||
op.drop_table("api_key")
|
||||
op.drop_table('api_key')
|
||||
|
||||
with op.batch_alter_table("user") as batch_op:
|
||||
batch_op.drop_column("profile_banner_image_url")
|
||||
batch_op.drop_column("timezone")
|
||||
with op.batch_alter_table('user') as batch_op:
|
||||
batch_op.drop_column('profile_banner_image_url')
|
||||
batch_op.drop_column('timezone')
|
||||
|
||||
batch_op.drop_column("presence_state")
|
||||
batch_op.drop_column("status_emoji")
|
||||
batch_op.drop_column("status_message")
|
||||
batch_op.drop_column("status_expires_at")
|
||||
batch_op.drop_column('presence_state')
|
||||
batch_op.drop_column('status_emoji')
|
||||
batch_op.drop_column('status_message')
|
||||
batch_op.drop_column('status_expires_at')
|
||||
|
||||
# Convert info (JSON) → TEXT
|
||||
_convert_column_to_text("user", "info")
|
||||
_convert_column_to_text('user', 'info')
|
||||
# Convert settings (JSON) → TEXT
|
||||
_convert_column_to_text("user", "settings")
|
||||
_convert_column_to_text('user', 'settings')
|
||||
|
||||
@@ -12,15 +12,15 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b2c3d4e5f6a7"
|
||||
down_revision: Union[str, None] = "a1b2c3d4e5f6"
|
||||
revision: str = 'b2c3d4e5f6a7'
|
||||
down_revision: Union[str, None] = 'a1b2c3d4e5f6'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("user", sa.Column("scim", sa.JSON(), nullable=True))
|
||||
op.add_column('user', sa.Column('scim', sa.JSON(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "scim")
|
||||
op.drop_column('user', 'scim')
|
||||
|
||||
@@ -12,21 +12,21 @@ import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c0fbf31ca0db"
|
||||
down_revision: Union[str, None] = "ca81bd47c050"
|
||||
revision: str = 'c0fbf31ca0db'
|
||||
down_revision: Union[str, None] = 'ca81bd47c050'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("file", sa.Column("hash", sa.Text(), nullable=True))
|
||||
op.add_column("file", sa.Column("data", sa.JSON(), nullable=True))
|
||||
op.add_column("file", sa.Column("updated_at", sa.BigInteger(), nullable=True))
|
||||
op.add_column('file', sa.Column('hash', sa.Text(), nullable=True))
|
||||
op.add_column('file', sa.Column('data', sa.JSON(), nullable=True))
|
||||
op.add_column('file', sa.Column('updated_at', sa.BigInteger(), nullable=True))
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("file", "updated_at")
|
||||
op.drop_column("file", "data")
|
||||
op.drop_column("file", "hash")
|
||||
op.drop_column('file', 'updated_at')
|
||||
op.drop_column('file', 'data')
|
||||
op.drop_column('file', 'hash')
|
||||
|
||||
@@ -12,35 +12,33 @@ import json
|
||||
from sqlalchemy.sql import table, column
|
||||
from sqlalchemy import String, Text, JSON, and_
|
||||
|
||||
revision = "c29facfe716b"
|
||||
down_revision = "c69f45358db4"
|
||||
revision = 'c29facfe716b'
|
||||
down_revision = 'c69f45358db4'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# 1. Add the `path` column to the "file" table.
|
||||
op.add_column("file", sa.Column("path", sa.Text(), nullable=True))
|
||||
op.add_column('file', sa.Column('path', sa.Text(), nullable=True))
|
||||
|
||||
# 2. Convert the `meta` column from Text/JSONField to `JSON()`
|
||||
# Use Alembic's default batch_op for dialect compatibility.
|
||||
with op.batch_alter_table("file", schema=None) as batch_op:
|
||||
with op.batch_alter_table('file', schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
"meta",
|
||||
'meta',
|
||||
type_=sa.JSON(),
|
||||
existing_type=sa.Text(),
|
||||
existing_nullable=True,
|
||||
nullable=True,
|
||||
postgresql_using="meta::json",
|
||||
postgresql_using='meta::json',
|
||||
)
|
||||
|
||||
# 3. Migrate legacy data from `meta` JSONField
|
||||
# Fetch and process `meta` data from the table, add values to the new `path` column as necessary.
|
||||
# We will use SQLAlchemy core bindings to ensure safety across different databases.
|
||||
|
||||
file_table = table(
|
||||
"file", column("id", String), column("meta", JSON), column("path", Text)
|
||||
)
|
||||
file_table = table('file', column('id', String), column('meta', JSON), column('path', Text))
|
||||
|
||||
# Create connection to the database
|
||||
connection = op.get_bind()
|
||||
@@ -55,24 +53,18 @@ def upgrade():
|
||||
|
||||
# Iterate over each row to extract and update the `path` from `meta` column
|
||||
for row in results:
|
||||
if "path" in row.meta:
|
||||
if 'path' in row.meta:
|
||||
# Extract the `path` field from the `meta` JSON
|
||||
path = row.meta.get("path")
|
||||
path = row.meta.get('path')
|
||||
|
||||
# Update the `file` table with the new `path` value
|
||||
connection.execute(
|
||||
file_table.update()
|
||||
.where(file_table.c.id == row.id)
|
||||
.values({"path": path})
|
||||
)
|
||||
connection.execute(file_table.update().where(file_table.c.id == row.id).values({'path': path}))
|
||||
|
||||
|
||||
def downgrade():
|
||||
# 1. Remove the `path` column
|
||||
op.drop_column("file", "path")
|
||||
op.drop_column('file', 'path')
|
||||
|
||||
# 2. Revert the `meta` column back to Text/JSONField
|
||||
with op.batch_alter_table("file", schema=None) as batch_op:
|
||||
batch_op.alter_column(
|
||||
"meta", type_=sa.Text(), existing_type=sa.JSON(), existing_nullable=True
|
||||
)
|
||||
with op.batch_alter_table('file', schema=None) as batch_op:
|
||||
batch_op.alter_column('meta', type_=sa.Text(), existing_type=sa.JSON(), existing_nullable=True)
|
||||
|
||||
@@ -12,45 +12,43 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c440947495f3"
|
||||
down_revision: Union[str, None] = "81cc2ce44d79"
|
||||
revision: str = 'c440947495f3'
|
||||
down_revision: Union[str, None] = '81cc2ce44d79'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"chat_file",
|
||||
sa.Column("id", sa.Text(), primary_key=True),
|
||||
sa.Column("user_id", sa.Text(), nullable=False),
|
||||
'chat_file',
|
||||
sa.Column('id', sa.Text(), primary_key=True),
|
||||
sa.Column('user_id', sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"chat_id",
|
||||
'chat_id',
|
||||
sa.Text(),
|
||||
sa.ForeignKey("chat.id", ondelete="CASCADE"),
|
||||
sa.ForeignKey('chat.id', ondelete='CASCADE'),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"file_id",
|
||||
'file_id',
|
||||
sa.Text(),
|
||||
sa.ForeignKey("file.id", ondelete="CASCADE"),
|
||||
sa.ForeignKey('file.id', ondelete='CASCADE'),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("message_id", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
||||
sa.Column('message_id', sa.Text(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=False),
|
||||
# indexes
|
||||
sa.Index("ix_chat_file_chat_id", "chat_id"),
|
||||
sa.Index("ix_chat_file_file_id", "file_id"),
|
||||
sa.Index("ix_chat_file_message_id", "message_id"),
|
||||
sa.Index("ix_chat_file_user_id", "user_id"),
|
||||
sa.Index('ix_chat_file_chat_id', 'chat_id'),
|
||||
sa.Index('ix_chat_file_file_id', 'file_id'),
|
||||
sa.Index('ix_chat_file_message_id', 'message_id'),
|
||||
sa.Index('ix_chat_file_user_id', 'user_id'),
|
||||
# unique constraints
|
||||
sa.UniqueConstraint(
|
||||
"chat_id", "file_id", name="uq_chat_file_chat_file"
|
||||
), # prevent duplicate entries
|
||||
sa.UniqueConstraint('chat_id', 'file_id', name='uq_chat_file_chat_file'), # prevent duplicate entries
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("chat_file")
|
||||
op.drop_table('chat_file')
|
||||
pass
|
||||
|
||||
@@ -9,42 +9,40 @@ Create Date: 2024-10-16 02:02:35.241684
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "c69f45358db4"
|
||||
down_revision = "3ab32c4b8f59"
|
||||
revision = 'c69f45358db4'
|
||||
down_revision = '3ab32c4b8f59'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"folder",
|
||||
sa.Column("id", sa.Text(), nullable=False),
|
||||
sa.Column("parent_id", sa.Text(), nullable=True),
|
||||
sa.Column("user_id", sa.Text(), nullable=False),
|
||||
sa.Column("name", sa.Text(), nullable=False),
|
||||
sa.Column("items", sa.JSON(), nullable=True),
|
||||
sa.Column("meta", sa.JSON(), nullable=True),
|
||||
sa.Column("is_expanded", sa.Boolean(), default=False, nullable=False),
|
||||
'folder',
|
||||
sa.Column('id', sa.Text(), nullable=False),
|
||||
sa.Column('parent_id', sa.Text(), nullable=True),
|
||||
sa.Column('user_id', sa.Text(), nullable=False),
|
||||
sa.Column('name', sa.Text(), nullable=False),
|
||||
sa.Column('items', sa.JSON(), nullable=True),
|
||||
sa.Column('meta', sa.JSON(), nullable=True),
|
||||
sa.Column('is_expanded', sa.Boolean(), default=False, nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
'updated_at',
|
||||
sa.DateTime(),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
onupdate=sa.func.now(),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", "user_id"),
|
||||
sa.PrimaryKeyConstraint('id', 'user_id'),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"chat",
|
||||
sa.Column("folder_id", sa.Text(), nullable=True),
|
||||
'chat',
|
||||
sa.Column('folder_id', sa.Text(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_column("chat", "folder_id")
|
||||
op.drop_column('chat', 'folder_id')
|
||||
|
||||
op.drop_table("folder")
|
||||
op.drop_table('folder')
|
||||
|
||||
@@ -12,23 +12,21 @@ import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "ca81bd47c050"
|
||||
down_revision: Union[str, None] = "7e5b5dc7342b"
|
||||
revision: str = 'ca81bd47c050'
|
||||
down_revision: Union[str, None] = '7e5b5dc7342b'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.create_table(
|
||||
"config",
|
||||
sa.Column("id", sa.Integer, primary_key=True),
|
||||
sa.Column("data", sa.JSON(), nullable=False),
|
||||
sa.Column("version", sa.Integer, nullable=False),
|
||||
'config',
|
||||
sa.Column('id', sa.Integer, primary_key=True),
|
||||
sa.Column('data', sa.JSON(), nullable=False),
|
||||
sa.Column('version', sa.Integer, nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||
sa.Column(
|
||||
"created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
'updated_at',
|
||||
sa.DateTime(),
|
||||
nullable=True,
|
||||
server_default=sa.func.now(),
|
||||
@@ -38,4 +36,4 @@ def upgrade():
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_table("config")
|
||||
op.drop_table('config')
|
||||
|
||||
@@ -9,15 +9,15 @@ Create Date: 2025-07-13 03:00:00.000000
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "d31026856c01"
|
||||
down_revision = "9f0c9cd09105"
|
||||
revision = 'd31026856c01'
|
||||
down_revision = '9f0c9cd09105'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
op.add_column("folder", sa.Column("data", sa.JSON(), nullable=True))
|
||||
op.add_column('folder', sa.Column('data', sa.JSON(), nullable=True))
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.drop_column("folder", "data")
|
||||
op.drop_column('folder', 'data')
|
||||
|
||||
@@ -20,8 +20,8 @@ import sqlalchemy as sa
|
||||
|
||||
from open_webui.migrations.util import get_existing_tables
|
||||
|
||||
revision: str = "f1e2d3c4b5a6"
|
||||
down_revision: Union[str, None] = "8452d01d26d7"
|
||||
revision: str = 'f1e2d3c4b5a6'
|
||||
down_revision: Union[str, None] = '8452d01d26d7'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
@@ -30,34 +30,34 @@ def upgrade() -> None:
|
||||
existing_tables = set(get_existing_tables())
|
||||
|
||||
# Create access_grant table
|
||||
if "access_grant" not in existing_tables:
|
||||
if 'access_grant' not in existing_tables:
|
||||
op.create_table(
|
||||
"access_grant",
|
||||
sa.Column("id", sa.Text(), nullable=False, primary_key=True),
|
||||
sa.Column("resource_type", sa.Text(), nullable=False),
|
||||
sa.Column("resource_id", sa.Text(), nullable=False),
|
||||
sa.Column("principal_type", sa.Text(), nullable=False),
|
||||
sa.Column("principal_id", sa.Text(), nullable=False),
|
||||
sa.Column("permission", sa.Text(), nullable=False),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
||||
'access_grant',
|
||||
sa.Column('id', sa.Text(), nullable=False, primary_key=True),
|
||||
sa.Column('resource_type', sa.Text(), nullable=False),
|
||||
sa.Column('resource_id', sa.Text(), nullable=False),
|
||||
sa.Column('principal_type', sa.Text(), nullable=False),
|
||||
sa.Column('principal_id', sa.Text(), nullable=False),
|
||||
sa.Column('permission', sa.Text(), nullable=False),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||
sa.UniqueConstraint(
|
||||
"resource_type",
|
||||
"resource_id",
|
||||
"principal_type",
|
||||
"principal_id",
|
||||
"permission",
|
||||
name="uq_access_grant_grant",
|
||||
'resource_type',
|
||||
'resource_id',
|
||||
'principal_type',
|
||||
'principal_id',
|
||||
'permission',
|
||||
name='uq_access_grant_grant',
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
"idx_access_grant_resource",
|
||||
"access_grant",
|
||||
["resource_type", "resource_id"],
|
||||
'idx_access_grant_resource',
|
||||
'access_grant',
|
||||
['resource_type', 'resource_id'],
|
||||
)
|
||||
op.create_index(
|
||||
"idx_access_grant_principal",
|
||||
"access_grant",
|
||||
["principal_type", "principal_id"],
|
||||
'idx_access_grant_principal',
|
||||
'access_grant',
|
||||
['principal_type', 'principal_id'],
|
||||
)
|
||||
|
||||
# Backfill existing access_control JSON data
|
||||
@@ -65,13 +65,13 @@ def upgrade() -> None:
|
||||
|
||||
# Tables with access_control JSON columns: (table_name, resource_type)
|
||||
resource_tables = [
|
||||
("knowledge", "knowledge"),
|
||||
("prompt", "prompt"),
|
||||
("tool", "tool"),
|
||||
("model", "model"),
|
||||
("note", "note"),
|
||||
("channel", "channel"),
|
||||
("file", "file"),
|
||||
('knowledge', 'knowledge'),
|
||||
('prompt', 'prompt'),
|
||||
('tool', 'tool'),
|
||||
('model', 'model'),
|
||||
('note', 'note'),
|
||||
('channel', 'channel'),
|
||||
('file', 'file'),
|
||||
]
|
||||
|
||||
now = int(time.time())
|
||||
@@ -83,9 +83,7 @@ def upgrade() -> None:
|
||||
|
||||
# Query all rows
|
||||
try:
|
||||
result = conn.execute(
|
||||
sa.text(f'SELECT id, access_control FROM "{table_name}"')
|
||||
)
|
||||
result = conn.execute(sa.text(f'SELECT id, access_control FROM "{table_name}"'))
|
||||
rows = result.fetchall()
|
||||
except Exception:
|
||||
continue
|
||||
@@ -99,19 +97,16 @@ def upgrade() -> None:
|
||||
# EXCEPTION: files with NULL are PRIVATE (owner-only), not public
|
||||
is_null = (
|
||||
access_control_json is None
|
||||
or access_control_json == "null"
|
||||
or (
|
||||
isinstance(access_control_json, str)
|
||||
and access_control_json.strip().lower() == "null"
|
||||
)
|
||||
or access_control_json == 'null'
|
||||
or (isinstance(access_control_json, str) and access_control_json.strip().lower() == 'null')
|
||||
)
|
||||
if is_null:
|
||||
# Files: NULL = private (no entry needed, owner has implicit access)
|
||||
# Other resources: NULL = public (insert user:* for read)
|
||||
if resource_type == "file":
|
||||
if resource_type == 'file':
|
||||
continue # Private - no entry needed
|
||||
|
||||
key = (resource_type, resource_id, "user", "*", "read")
|
||||
key = (resource_type, resource_id, 'user', '*', 'read')
|
||||
if key not in inserted:
|
||||
try:
|
||||
conn.execute(
|
||||
@@ -120,13 +115,13 @@ def upgrade() -> None:
|
||||
VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at)
|
||||
"""),
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"resource_type": resource_type,
|
||||
"resource_id": resource_id,
|
||||
"principal_type": "user",
|
||||
"principal_id": "*",
|
||||
"permission": "read",
|
||||
"created_at": now,
|
||||
'id': str(uuid.uuid4()),
|
||||
'resource_type': resource_type,
|
||||
'resource_id': resource_id,
|
||||
'principal_type': 'user',
|
||||
'principal_id': '*',
|
||||
'permission': 'read',
|
||||
'created_at': now,
|
||||
},
|
||||
)
|
||||
inserted.add(key)
|
||||
@@ -149,28 +144,24 @@ def upgrade() -> None:
|
||||
continue
|
||||
|
||||
# Check if it's effectively empty (no read/write keys with content)
|
||||
read_data = access_control_json.get("read", {})
|
||||
write_data = access_control_json.get("write", {})
|
||||
read_data = access_control_json.get('read', {})
|
||||
write_data = access_control_json.get('write', {})
|
||||
|
||||
has_read_grants = read_data.get("group_ids", []) or read_data.get(
|
||||
"user_ids", []
|
||||
)
|
||||
has_write_grants = write_data.get("group_ids", []) or write_data.get(
|
||||
"user_ids", []
|
||||
)
|
||||
has_read_grants = read_data.get('group_ids', []) or read_data.get('user_ids', [])
|
||||
has_write_grants = write_data.get('group_ids', []) or write_data.get('user_ids', [])
|
||||
|
||||
if not has_read_grants and not has_write_grants:
|
||||
# Empty permissions = private, no grants needed
|
||||
continue
|
||||
|
||||
# Extract permissions and insert into access_grant table
|
||||
for permission in ["read", "write"]:
|
||||
for permission in ['read', 'write']:
|
||||
perm_data = access_control_json.get(permission, {})
|
||||
if not perm_data:
|
||||
continue
|
||||
|
||||
for group_id in perm_data.get("group_ids", []):
|
||||
key = (resource_type, resource_id, "group", group_id, permission)
|
||||
for group_id in perm_data.get('group_ids', []):
|
||||
key = (resource_type, resource_id, 'group', group_id, permission)
|
||||
if key in inserted:
|
||||
continue
|
||||
try:
|
||||
@@ -180,21 +171,21 @@ def upgrade() -> None:
|
||||
VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at)
|
||||
"""),
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"resource_type": resource_type,
|
||||
"resource_id": resource_id,
|
||||
"principal_type": "group",
|
||||
"principal_id": group_id,
|
||||
"permission": permission,
|
||||
"created_at": now,
|
||||
'id': str(uuid.uuid4()),
|
||||
'resource_type': resource_type,
|
||||
'resource_id': resource_id,
|
||||
'principal_type': 'group',
|
||||
'principal_id': group_id,
|
||||
'permission': permission,
|
||||
'created_at': now,
|
||||
},
|
||||
)
|
||||
inserted.add(key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for user_id in perm_data.get("user_ids", []):
|
||||
key = (resource_type, resource_id, "user", user_id, permission)
|
||||
for user_id in perm_data.get('user_ids', []):
|
||||
key = (resource_type, resource_id, 'user', user_id, permission)
|
||||
if key in inserted:
|
||||
continue
|
||||
try:
|
||||
@@ -204,13 +195,13 @@ def upgrade() -> None:
|
||||
VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at)
|
||||
"""),
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"resource_type": resource_type,
|
||||
"resource_id": resource_id,
|
||||
"principal_type": "user",
|
||||
"principal_id": user_id,
|
||||
"permission": permission,
|
||||
"created_at": now,
|
||||
'id': str(uuid.uuid4()),
|
||||
'resource_type': resource_type,
|
||||
'resource_id': resource_id,
|
||||
'principal_type': 'user',
|
||||
'principal_id': user_id,
|
||||
'permission': permission,
|
||||
'created_at': now,
|
||||
},
|
||||
)
|
||||
inserted.add(key)
|
||||
@@ -223,7 +214,7 @@ def upgrade() -> None:
|
||||
continue
|
||||
try:
|
||||
with op.batch_alter_table(table_name) as batch:
|
||||
batch.drop_column("access_control")
|
||||
batch.drop_column('access_control')
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -235,20 +226,20 @@ def downgrade() -> None:
|
||||
|
||||
# Resource tables mapping: (table_name, resource_type)
|
||||
resource_tables = [
|
||||
("knowledge", "knowledge"),
|
||||
("prompt", "prompt"),
|
||||
("tool", "tool"),
|
||||
("model", "model"),
|
||||
("note", "note"),
|
||||
("channel", "channel"),
|
||||
("file", "file"),
|
||||
('knowledge', 'knowledge'),
|
||||
('prompt', 'prompt'),
|
||||
('tool', 'tool'),
|
||||
('model', 'model'),
|
||||
('note', 'note'),
|
||||
('channel', 'channel'),
|
||||
('file', 'file'),
|
||||
]
|
||||
|
||||
# Step 1: Re-add access_control columns to resource tables
|
||||
for table_name, _ in resource_tables:
|
||||
try:
|
||||
with op.batch_alter_table(table_name) as batch:
|
||||
batch.add_column(sa.Column("access_control", sa.JSON(), nullable=True))
|
||||
batch.add_column(sa.Column('access_control', sa.JSON(), nullable=True))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -262,7 +253,7 @@ def downgrade() -> None:
|
||||
FROM access_grant
|
||||
WHERE resource_type = :resource_type
|
||||
"""),
|
||||
{"resource_type": resource_type},
|
||||
{'resource_type': resource_type},
|
||||
)
|
||||
rows = result.fetchall()
|
||||
except Exception:
|
||||
@@ -278,49 +269,35 @@ def downgrade() -> None:
|
||||
|
||||
if resource_id not in resource_grants:
|
||||
resource_grants[resource_id] = {
|
||||
"is_public": False,
|
||||
"read": {"group_ids": [], "user_ids": []},
|
||||
"write": {"group_ids": [], "user_ids": []},
|
||||
'is_public': False,
|
||||
'read': {'group_ids': [], 'user_ids': []},
|
||||
'write': {'group_ids': [], 'user_ids': []},
|
||||
}
|
||||
|
||||
# Handle public access (user:* for read)
|
||||
if (
|
||||
principal_type == "user"
|
||||
and principal_id == "*"
|
||||
and permission == "read"
|
||||
):
|
||||
resource_grants[resource_id]["is_public"] = True
|
||||
if principal_type == 'user' and principal_id == '*' and permission == 'read':
|
||||
resource_grants[resource_id]['is_public'] = True
|
||||
continue
|
||||
|
||||
# Add to appropriate list
|
||||
if permission in ["read", "write"]:
|
||||
if principal_type == "group":
|
||||
if (
|
||||
principal_id
|
||||
not in resource_grants[resource_id][permission]["group_ids"]
|
||||
):
|
||||
resource_grants[resource_id][permission]["group_ids"].append(
|
||||
principal_id
|
||||
)
|
||||
elif principal_type == "user":
|
||||
if (
|
||||
principal_id
|
||||
not in resource_grants[resource_id][permission]["user_ids"]
|
||||
):
|
||||
resource_grants[resource_id][permission]["user_ids"].append(
|
||||
principal_id
|
||||
)
|
||||
if permission in ['read', 'write']:
|
||||
if principal_type == 'group':
|
||||
if principal_id not in resource_grants[resource_id][permission]['group_ids']:
|
||||
resource_grants[resource_id][permission]['group_ids'].append(principal_id)
|
||||
elif principal_type == 'user':
|
||||
if principal_id not in resource_grants[resource_id][permission]['user_ids']:
|
||||
resource_grants[resource_id][permission]['user_ids'].append(principal_id)
|
||||
|
||||
# Step 3: Update each resource with reconstructed JSON
|
||||
for resource_id, grants in resource_grants.items():
|
||||
if grants["is_public"]:
|
||||
if grants['is_public']:
|
||||
# Public = NULL
|
||||
access_control_value = None
|
||||
elif (
|
||||
not grants["read"]["group_ids"]
|
||||
and not grants["read"]["user_ids"]
|
||||
and not grants["write"]["group_ids"]
|
||||
and not grants["write"]["user_ids"]
|
||||
not grants['read']['group_ids']
|
||||
and not grants['read']['user_ids']
|
||||
and not grants['write']['group_ids']
|
||||
and not grants['write']['user_ids']
|
||||
):
|
||||
# No grants = should not happen (would mean no entries), default to {}
|
||||
access_control_value = json.dumps({})
|
||||
@@ -328,17 +305,15 @@ def downgrade() -> None:
|
||||
# Custom permissions
|
||||
access_control_value = json.dumps(
|
||||
{
|
||||
"read": grants["read"],
|
||||
"write": grants["write"],
|
||||
'read': grants['read'],
|
||||
'write': grants['write'],
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
f'UPDATE "{table_name}" SET access_control = :access_control WHERE id = :id'
|
||||
),
|
||||
{"access_control": access_control_value, "id": resource_id},
|
||||
sa.text(f'UPDATE "{table_name}" SET access_control = :access_control WHERE id = :id'),
|
||||
{'access_control': access_control_value, 'id': resource_id},
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -346,7 +321,7 @@ def downgrade() -> None:
|
||||
# Step 4: Set all resources WITHOUT entries to private
|
||||
# For files: NULL means private (owner-only), so leave as NULL
|
||||
# For other resources: {} means private, so update to {}
|
||||
if resource_type != "file":
|
||||
if resource_type != 'file':
|
||||
try:
|
||||
conn.execute(
|
||||
sa.text(f"""
|
||||
@@ -357,13 +332,13 @@ def downgrade() -> None:
|
||||
)
|
||||
AND access_control IS NULL
|
||||
"""),
|
||||
{"private_value": json.dumps({}), "resource_type": resource_type},
|
||||
{'private_value': json.dumps({}), 'resource_type': resource_type},
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
# For files, NULL stays NULL - no action needed
|
||||
|
||||
# Step 5: Drop the access_grant table
|
||||
op.drop_index("idx_access_grant_principal", table_name="access_grant")
|
||||
op.drop_index("idx_access_grant_resource", table_name="access_grant")
|
||||
op.drop_table("access_grant")
|
||||
op.drop_index('idx_access_grant_principal', table_name='access_grant')
|
||||
op.drop_index('idx_access_grant_resource', table_name='access_grant')
|
||||
op.drop_table('access_grant')
|
||||
|
||||
@@ -19,28 +19,24 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AccessGrant(Base):
|
||||
__tablename__ = "access_grant"
|
||||
__tablename__ = 'access_grant'
|
||||
|
||||
id = Column(Text, primary_key=True)
|
||||
resource_type = Column(
|
||||
Text, nullable=False
|
||||
) # "knowledge", "model", "prompt", "tool", "note", "channel", "file"
|
||||
resource_type = Column(Text, nullable=False) # "knowledge", "model", "prompt", "tool", "note", "channel", "file"
|
||||
resource_id = Column(Text, nullable=False)
|
||||
principal_type = Column(Text, nullable=False) # "user" or "group"
|
||||
principal_id = Column(
|
||||
Text, nullable=False
|
||||
) # user_id, group_id, or "*" (wildcard for public)
|
||||
principal_id = Column(Text, nullable=False) # user_id, group_id, or "*" (wildcard for public)
|
||||
permission = Column(Text, nullable=False) # "read" or "write"
|
||||
created_at = Column(BigInteger, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"resource_type",
|
||||
"resource_id",
|
||||
"principal_type",
|
||||
"principal_id",
|
||||
"permission",
|
||||
name="uq_access_grant_grant",
|
||||
'resource_type',
|
||||
'resource_id',
|
||||
'principal_type',
|
||||
'principal_id',
|
||||
'permission',
|
||||
name='uq_access_grant_grant',
|
||||
),
|
||||
)
|
||||
|
||||
@@ -66,7 +62,7 @@ class AccessGrantResponse(BaseModel):
|
||||
permission: str
|
||||
|
||||
@classmethod
|
||||
def from_grant(cls, grant: "AccessGrantModel") -> "AccessGrantResponse":
|
||||
def from_grant(cls, grant: 'AccessGrantModel') -> 'AccessGrantResponse':
|
||||
return cls(
|
||||
id=grant.id,
|
||||
principal_type=grant.principal_type,
|
||||
@@ -100,14 +96,14 @@ def access_control_to_grants(
|
||||
if access_control is None:
|
||||
# NULL → public read (user:* for read)
|
||||
# Exception: files with NULL are private (owner-only), no grants needed
|
||||
if resource_type != "file":
|
||||
if resource_type != 'file':
|
||||
grants.append(
|
||||
{
|
||||
"resource_type": resource_type,
|
||||
"resource_id": resource_id,
|
||||
"principal_type": "user",
|
||||
"principal_id": "*",
|
||||
"permission": "read",
|
||||
'resource_type': resource_type,
|
||||
'resource_id': resource_id,
|
||||
'principal_type': 'user',
|
||||
'principal_id': '*',
|
||||
'permission': 'read',
|
||||
}
|
||||
)
|
||||
return grants
|
||||
@@ -117,30 +113,30 @@ def access_control_to_grants(
|
||||
return grants
|
||||
|
||||
# Parse structured permissions
|
||||
for permission in ["read", "write"]:
|
||||
for permission in ['read', 'write']:
|
||||
perm_data = access_control.get(permission, {})
|
||||
if not perm_data:
|
||||
continue
|
||||
|
||||
for group_id in perm_data.get("group_ids", []):
|
||||
for group_id in perm_data.get('group_ids', []):
|
||||
grants.append(
|
||||
{
|
||||
"resource_type": resource_type,
|
||||
"resource_id": resource_id,
|
||||
"principal_type": "group",
|
||||
"principal_id": group_id,
|
||||
"permission": permission,
|
||||
'resource_type': resource_type,
|
||||
'resource_id': resource_id,
|
||||
'principal_type': 'group',
|
||||
'principal_id': group_id,
|
||||
'permission': permission,
|
||||
}
|
||||
)
|
||||
|
||||
for user_id in perm_data.get("user_ids", []):
|
||||
for user_id in perm_data.get('user_ids', []):
|
||||
grants.append(
|
||||
{
|
||||
"resource_type": resource_type,
|
||||
"resource_id": resource_id,
|
||||
"principal_type": "user",
|
||||
"principal_id": user_id,
|
||||
"permission": permission,
|
||||
'resource_type': resource_type,
|
||||
'resource_id': resource_id,
|
||||
'principal_type': 'user',
|
||||
'principal_id': user_id,
|
||||
'permission': permission,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -164,27 +160,23 @@ def normalize_access_grants(access_grants: Optional[list]) -> list[dict]:
|
||||
if not isinstance(grant, dict):
|
||||
continue
|
||||
|
||||
principal_type = grant.get("principal_type")
|
||||
principal_id = grant.get("principal_id")
|
||||
permission = grant.get("permission")
|
||||
principal_type = grant.get('principal_type')
|
||||
principal_id = grant.get('principal_id')
|
||||
permission = grant.get('permission')
|
||||
|
||||
if principal_type not in ("user", "group"):
|
||||
if principal_type not in ('user', 'group'):
|
||||
continue
|
||||
if permission not in ("read", "write"):
|
||||
if permission not in ('read', 'write'):
|
||||
continue
|
||||
if not isinstance(principal_id, str) or not principal_id:
|
||||
continue
|
||||
|
||||
key = (principal_type, principal_id, permission)
|
||||
deduped[key] = {
|
||||
"id": (
|
||||
grant.get("id")
|
||||
if isinstance(grant.get("id"), str) and grant.get("id")
|
||||
else str(uuid.uuid4())
|
||||
),
|
||||
"principal_type": principal_type,
|
||||
"principal_id": principal_id,
|
||||
"permission": permission,
|
||||
'id': (grant.get('id') if isinstance(grant.get('id'), str) and grant.get('id') else str(uuid.uuid4())),
|
||||
'principal_type': principal_type,
|
||||
'principal_id': principal_id,
|
||||
'permission': permission,
|
||||
}
|
||||
|
||||
return list(deduped.values())
|
||||
@@ -195,11 +187,7 @@ def has_public_read_access_grant(access_grants: Optional[list]) -> bool:
|
||||
Returns True when a direct grant list includes wildcard public-read.
|
||||
"""
|
||||
for grant in normalize_access_grants(access_grants):
|
||||
if (
|
||||
grant["principal_type"] == "user"
|
||||
and grant["principal_id"] == "*"
|
||||
and grant["permission"] == "read"
|
||||
):
|
||||
if grant['principal_type'] == 'user' and grant['principal_id'] == '*' and grant['permission'] == 'read':
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -209,7 +197,7 @@ def has_user_access_grant(access_grants: Optional[list]) -> bool:
|
||||
Returns True when a direct grant list includes any non-wildcard user grant.
|
||||
"""
|
||||
for grant in normalize_access_grants(access_grants):
|
||||
if grant["principal_type"] == "user" and grant["principal_id"] != "*":
|
||||
if grant['principal_type'] == 'user' and grant['principal_id'] != '*':
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -225,18 +213,9 @@ def strip_user_access_grants(access_grants: Optional[list]) -> list:
|
||||
grant
|
||||
for grant in access_grants
|
||||
if not (
|
||||
(
|
||||
grant.get("principal_type")
|
||||
if isinstance(grant, dict)
|
||||
else getattr(grant, "principal_type", None)
|
||||
)
|
||||
== "user"
|
||||
and (
|
||||
grant.get("principal_id")
|
||||
if isinstance(grant, dict)
|
||||
else getattr(grant, "principal_id", None)
|
||||
)
|
||||
!= "*"
|
||||
(grant.get('principal_type') if isinstance(grant, dict) else getattr(grant, 'principal_type', None))
|
||||
== 'user'
|
||||
and (grant.get('principal_id') if isinstance(grant, dict) else getattr(grant, 'principal_id', None)) != '*'
|
||||
)
|
||||
]
|
||||
|
||||
@@ -260,29 +239,25 @@ def grants_to_access_control(grants: list) -> Optional[dict]:
|
||||
return {} # No grants = private/owner-only
|
||||
|
||||
result = {
|
||||
"read": {"group_ids": [], "user_ids": []},
|
||||
"write": {"group_ids": [], "user_ids": []},
|
||||
'read': {'group_ids': [], 'user_ids': []},
|
||||
'write': {'group_ids': [], 'user_ids': []},
|
||||
}
|
||||
|
||||
is_public = False
|
||||
for grant in grants:
|
||||
if (
|
||||
grant.principal_type == "user"
|
||||
and grant.principal_id == "*"
|
||||
and grant.permission == "read"
|
||||
):
|
||||
if grant.principal_type == 'user' and grant.principal_id == '*' and grant.permission == 'read':
|
||||
is_public = True
|
||||
continue # Don't add wildcard to user_ids list
|
||||
|
||||
if grant.permission not in ("read", "write"):
|
||||
if grant.permission not in ('read', 'write'):
|
||||
continue
|
||||
|
||||
if grant.principal_type == "group":
|
||||
if grant.principal_id not in result[grant.permission]["group_ids"]:
|
||||
result[grant.permission]["group_ids"].append(grant.principal_id)
|
||||
elif grant.principal_type == "user":
|
||||
if grant.principal_id not in result[grant.permission]["user_ids"]:
|
||||
result[grant.permission]["user_ids"].append(grant.principal_id)
|
||||
if grant.principal_type == 'group':
|
||||
if grant.principal_id not in result[grant.permission]['group_ids']:
|
||||
result[grant.permission]['group_ids'].append(grant.principal_id)
|
||||
elif grant.principal_type == 'user':
|
||||
if grant.principal_id not in result[grant.permission]['user_ids']:
|
||||
result[grant.permission]['user_ids'].append(grant.principal_id)
|
||||
|
||||
if is_public:
|
||||
return None # Public read access
|
||||
@@ -399,9 +374,7 @@ class AccessGrantsTable:
|
||||
).delete()
|
||||
|
||||
# Convert JSON to grant dicts
|
||||
grant_dicts = access_control_to_grants(
|
||||
resource_type, resource_id, access_control
|
||||
)
|
||||
grant_dicts = access_control_to_grants(resource_type, resource_id, access_control)
|
||||
|
||||
# Insert new grants
|
||||
results = []
|
||||
@@ -442,9 +415,9 @@ class AccessGrantsTable:
|
||||
id=str(uuid.uuid4()),
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
principal_type=grant_dict["principal_type"],
|
||||
principal_id=grant_dict["principal_id"],
|
||||
permission=grant_dict["permission"],
|
||||
principal_type=grant_dict['principal_type'],
|
||||
principal_id=grant_dict['principal_id'],
|
||||
permission=grant_dict['permission'],
|
||||
created_at=int(time.time()),
|
||||
)
|
||||
db.add(grant)
|
||||
@@ -511,9 +484,7 @@ class AccessGrantsTable:
|
||||
)
|
||||
.all()
|
||||
)
|
||||
result: dict[str, list[AccessGrantModel]] = {
|
||||
rid: [] for rid in resource_ids
|
||||
}
|
||||
result: dict[str, list[AccessGrantModel]] = {rid: [] for rid in resource_ids}
|
||||
for g in grants:
|
||||
result[g.resource_id].append(AccessGrantModel.model_validate(g))
|
||||
return result
|
||||
@@ -523,7 +494,7 @@ class AccessGrantsTable:
|
||||
user_id: str,
|
||||
resource_type: str,
|
||||
resource_id: str,
|
||||
permission: str = "read",
|
||||
permission: str = 'read',
|
||||
user_group_ids: Optional[set[str]] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> bool:
|
||||
@@ -540,12 +511,12 @@ class AccessGrantsTable:
|
||||
conditions = [
|
||||
# Public access
|
||||
and_(
|
||||
AccessGrant.principal_type == "user",
|
||||
AccessGrant.principal_id == "*",
|
||||
AccessGrant.principal_type == 'user',
|
||||
AccessGrant.principal_id == '*',
|
||||
),
|
||||
# Direct user access
|
||||
and_(
|
||||
AccessGrant.principal_type == "user",
|
||||
AccessGrant.principal_type == 'user',
|
||||
AccessGrant.principal_id == user_id,
|
||||
),
|
||||
]
|
||||
@@ -560,7 +531,7 @@ class AccessGrantsTable:
|
||||
if user_group_ids:
|
||||
conditions.append(
|
||||
and_(
|
||||
AccessGrant.principal_type == "group",
|
||||
AccessGrant.principal_type == 'group',
|
||||
AccessGrant.principal_id.in_(user_group_ids),
|
||||
)
|
||||
)
|
||||
@@ -582,7 +553,7 @@ class AccessGrantsTable:
|
||||
user_id: str,
|
||||
resource_type: str,
|
||||
resource_ids: list[str],
|
||||
permission: str = "read",
|
||||
permission: str = 'read',
|
||||
user_group_ids: Optional[set[str]] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> set[str]:
|
||||
@@ -597,11 +568,11 @@ class AccessGrantsTable:
|
||||
with get_db_context(db) as db:
|
||||
conditions = [
|
||||
and_(
|
||||
AccessGrant.principal_type == "user",
|
||||
AccessGrant.principal_id == "*",
|
||||
AccessGrant.principal_type == 'user',
|
||||
AccessGrant.principal_id == '*',
|
||||
),
|
||||
and_(
|
||||
AccessGrant.principal_type == "user",
|
||||
AccessGrant.principal_type == 'user',
|
||||
AccessGrant.principal_id == user_id,
|
||||
),
|
||||
]
|
||||
@@ -615,7 +586,7 @@ class AccessGrantsTable:
|
||||
if user_group_ids:
|
||||
conditions.append(
|
||||
and_(
|
||||
AccessGrant.principal_type == "group",
|
||||
AccessGrant.principal_type == 'group',
|
||||
AccessGrant.principal_id.in_(user_group_ids),
|
||||
)
|
||||
)
|
||||
@@ -637,7 +608,7 @@ class AccessGrantsTable:
|
||||
self,
|
||||
resource_type: str,
|
||||
resource_id: str,
|
||||
permission: str = "read",
|
||||
permission: str = 'read',
|
||||
db: Optional[Session] = None,
|
||||
) -> list:
|
||||
"""
|
||||
@@ -660,19 +631,17 @@ class AccessGrantsTable:
|
||||
|
||||
# Check for public access
|
||||
for grant in grants:
|
||||
if grant.principal_type == "user" and grant.principal_id == "*":
|
||||
result = Users.get_users(filter={"roles": ["!pending"]}, db=db)
|
||||
return result.get("users", [])
|
||||
if grant.principal_type == 'user' and grant.principal_id == '*':
|
||||
result = Users.get_users(filter={'roles': ['!pending']}, db=db)
|
||||
return result.get('users', [])
|
||||
|
||||
user_ids_with_access = set()
|
||||
|
||||
for grant in grants:
|
||||
if grant.principal_type == "user":
|
||||
if grant.principal_type == 'user':
|
||||
user_ids_with_access.add(grant.principal_id)
|
||||
elif grant.principal_type == "group":
|
||||
group_user_ids = Groups.get_group_user_ids_by_id(
|
||||
grant.principal_id, db=db
|
||||
)
|
||||
elif grant.principal_type == 'group':
|
||||
group_user_ids = Groups.get_group_user_ids_by_id(grant.principal_id, db=db)
|
||||
if group_user_ids:
|
||||
user_ids_with_access.update(group_user_ids)
|
||||
|
||||
@@ -688,20 +657,18 @@ class AccessGrantsTable:
|
||||
DocumentModel,
|
||||
filter: dict,
|
||||
resource_type: str,
|
||||
permission: str = "read",
|
||||
permission: str = 'read',
|
||||
):
|
||||
"""
|
||||
Apply access control filtering to a SQLAlchemy query by JOINing with access_grant.
|
||||
|
||||
This replaces the old JSON-column-based filtering with a proper relational JOIN.
|
||||
"""
|
||||
group_ids = filter.get("group_ids", [])
|
||||
user_id = filter.get("user_id")
|
||||
group_ids = filter.get('group_ids', [])
|
||||
user_id = filter.get('user_id')
|
||||
|
||||
if permission == "read_only":
|
||||
return self._has_read_only_permission_filter(
|
||||
db, query, DocumentModel, filter, resource_type
|
||||
)
|
||||
if permission == 'read_only':
|
||||
return self._has_read_only_permission_filter(db, query, DocumentModel, filter, resource_type)
|
||||
|
||||
# Build principal conditions
|
||||
principal_conditions = []
|
||||
@@ -710,8 +677,8 @@ class AccessGrantsTable:
|
||||
# Public access: user:* read
|
||||
principal_conditions.append(
|
||||
and_(
|
||||
AccessGrant.principal_type == "user",
|
||||
AccessGrant.principal_id == "*",
|
||||
AccessGrant.principal_type == 'user',
|
||||
AccessGrant.principal_id == '*',
|
||||
)
|
||||
)
|
||||
|
||||
@@ -722,7 +689,7 @@ class AccessGrantsTable:
|
||||
# Direct user grant
|
||||
principal_conditions.append(
|
||||
and_(
|
||||
AccessGrant.principal_type == "user",
|
||||
AccessGrant.principal_type == 'user',
|
||||
AccessGrant.principal_id == user_id,
|
||||
)
|
||||
)
|
||||
@@ -731,7 +698,7 @@ class AccessGrantsTable:
|
||||
# Group grants
|
||||
principal_conditions.append(
|
||||
and_(
|
||||
AccessGrant.principal_type == "group",
|
||||
AccessGrant.principal_type == 'group',
|
||||
AccessGrant.principal_id.in_(group_ids),
|
||||
)
|
||||
)
|
||||
@@ -751,13 +718,13 @@ class AccessGrantsTable:
|
||||
AccessGrant.permission == permission,
|
||||
or_(
|
||||
and_(
|
||||
AccessGrant.principal_type == "user",
|
||||
AccessGrant.principal_id == "*",
|
||||
AccessGrant.principal_type == 'user',
|
||||
AccessGrant.principal_id == '*',
|
||||
),
|
||||
*(
|
||||
[
|
||||
and_(
|
||||
AccessGrant.principal_type == "user",
|
||||
AccessGrant.principal_type == 'user',
|
||||
AccessGrant.principal_id == user_id,
|
||||
)
|
||||
]
|
||||
@@ -767,7 +734,7 @@ class AccessGrantsTable:
|
||||
*(
|
||||
[
|
||||
and_(
|
||||
AccessGrant.principal_type == "group",
|
||||
AccessGrant.principal_type == 'group',
|
||||
AccessGrant.principal_id.in_(group_ids),
|
||||
)
|
||||
]
|
||||
@@ -800,8 +767,8 @@ class AccessGrantsTable:
|
||||
Filter for items where user has read BUT NOT write access.
|
||||
Public items are NOT considered read_only.
|
||||
"""
|
||||
group_ids = filter.get("group_ids", [])
|
||||
user_id = filter.get("user_id")
|
||||
group_ids = filter.get('group_ids', [])
|
||||
user_id = filter.get('user_id')
|
||||
|
||||
from sqlalchemy import exists as sa_exists, select
|
||||
|
||||
@@ -811,12 +778,12 @@ class AccessGrantsTable:
|
||||
.where(
|
||||
AccessGrant.resource_type == resource_type,
|
||||
AccessGrant.resource_id == DocumentModel.id,
|
||||
AccessGrant.permission == "read",
|
||||
AccessGrant.permission == 'read',
|
||||
or_(
|
||||
*(
|
||||
[
|
||||
and_(
|
||||
AccessGrant.principal_type == "user",
|
||||
AccessGrant.principal_type == 'user',
|
||||
AccessGrant.principal_id == user_id,
|
||||
)
|
||||
]
|
||||
@@ -826,7 +793,7 @@ class AccessGrantsTable:
|
||||
*(
|
||||
[
|
||||
and_(
|
||||
AccessGrant.principal_type == "group",
|
||||
AccessGrant.principal_type == 'group',
|
||||
AccessGrant.principal_id.in_(group_ids),
|
||||
)
|
||||
]
|
||||
@@ -845,12 +812,12 @@ class AccessGrantsTable:
|
||||
.where(
|
||||
AccessGrant.resource_type == resource_type,
|
||||
AccessGrant.resource_id == DocumentModel.id,
|
||||
AccessGrant.permission == "write",
|
||||
AccessGrant.permission == 'write',
|
||||
or_(
|
||||
*(
|
||||
[
|
||||
and_(
|
||||
AccessGrant.principal_type == "user",
|
||||
AccessGrant.principal_type == 'user',
|
||||
AccessGrant.principal_id == user_id,
|
||||
)
|
||||
]
|
||||
@@ -860,7 +827,7 @@ class AccessGrantsTable:
|
||||
*(
|
||||
[
|
||||
and_(
|
||||
AccessGrant.principal_type == "group",
|
||||
AccessGrant.principal_type == 'group',
|
||||
AccessGrant.principal_id.in_(group_ids),
|
||||
)
|
||||
]
|
||||
@@ -879,9 +846,9 @@ class AccessGrantsTable:
|
||||
.where(
|
||||
AccessGrant.resource_type == resource_type,
|
||||
AccessGrant.resource_id == DocumentModel.id,
|
||||
AccessGrant.permission == "read",
|
||||
AccessGrant.principal_type == "user",
|
||||
AccessGrant.principal_id == "*",
|
||||
AccessGrant.permission == 'read',
|
||||
AccessGrant.principal_type == 'user',
|
||||
AccessGrant.principal_id == '*',
|
||||
)
|
||||
.correlate(DocumentModel)
|
||||
.exists()
|
||||
|
||||
@@ -17,7 +17,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Auth(Base):
|
||||
__tablename__ = "auth"
|
||||
__tablename__ = 'auth'
|
||||
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
email = Column(String)
|
||||
@@ -73,9 +73,9 @@ class SignupForm(BaseModel):
|
||||
name: str
|
||||
email: str
|
||||
password: str
|
||||
profile_image_url: Optional[str] = "/user.png"
|
||||
profile_image_url: Optional[str] = '/user.png'
|
||||
|
||||
@field_validator("profile_image_url")
|
||||
@field_validator('profile_image_url')
|
||||
@classmethod
|
||||
def check_profile_image_url(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is not None:
|
||||
@@ -84,7 +84,7 @@ class SignupForm(BaseModel):
|
||||
|
||||
|
||||
class AddUserForm(SignupForm):
|
||||
role: Optional[str] = "pending"
|
||||
role: Optional[str] = 'pending'
|
||||
|
||||
|
||||
class AuthsTable:
|
||||
@@ -93,25 +93,21 @@ class AuthsTable:
|
||||
email: str,
|
||||
password: str,
|
||||
name: str,
|
||||
profile_image_url: str = "/user.png",
|
||||
role: str = "pending",
|
||||
profile_image_url: str = '/user.png',
|
||||
role: str = 'pending',
|
||||
oauth: Optional[dict] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> Optional[UserModel]:
|
||||
with get_db_context(db) as db:
|
||||
log.info("insert_new_auth")
|
||||
log.info('insert_new_auth')
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
auth = AuthModel(
|
||||
**{"id": id, "email": email, "password": password, "active": True}
|
||||
)
|
||||
auth = AuthModel(**{'id': id, 'email': email, 'password': password, 'active': True})
|
||||
result = Auth(**auth.model_dump())
|
||||
db.add(result)
|
||||
|
||||
user = Users.insert_new_user(
|
||||
id, name, email, profile_image_url, role, oauth=oauth, db=db
|
||||
)
|
||||
user = Users.insert_new_user(id, name, email, profile_image_url, role, oauth=oauth, db=db)
|
||||
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
@@ -124,7 +120,7 @@ class AuthsTable:
|
||||
def authenticate_user(
|
||||
self, email: str, verify_password: callable, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
log.info(f"authenticate_user: {email}")
|
||||
log.info(f'authenticate_user: {email}')
|
||||
|
||||
user = Users.get_user_by_email(email, db=db)
|
||||
if not user:
|
||||
@@ -143,10 +139,8 @@ class AuthsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def authenticate_user_by_api_key(
|
||||
self, api_key: str, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
log.info(f"authenticate_user_by_api_key")
|
||||
def authenticate_user_by_api_key(self, api_key: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
log.info(f'authenticate_user_by_api_key')
|
||||
# if no api_key, return None
|
||||
if not api_key:
|
||||
return None
|
||||
@@ -157,10 +151,8 @@ class AuthsTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def authenticate_user_by_email(
|
||||
self, email: str, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
log.info(f"authenticate_user_by_email: {email}")
|
||||
def authenticate_user_by_email(self, email: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
log.info(f'authenticate_user_by_email: {email}')
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
# Single JOIN query instead of two separate queries
|
||||
@@ -177,28 +169,22 @@ class AuthsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_user_password_by_id(
|
||||
self, id: str, new_password: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def update_user_password_by_id(self, id: str, new_password: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
result = (
|
||||
db.query(Auth).filter_by(id=id).update({"password": new_password})
|
||||
)
|
||||
result = db.query(Auth).filter_by(id=id).update({'password': new_password})
|
||||
db.commit()
|
||||
return True if result == 1 else False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def update_email_by_id(
|
||||
self, id: str, email: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def update_email_by_id(self, id: str, email: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
result = db.query(Auth).filter_by(id=id).update({"email": email})
|
||||
result = db.query(Auth).filter_by(id=id).update({'email': email})
|
||||
db.commit()
|
||||
if result == 1:
|
||||
Users.update_user_by_id(id, {"email": email}, db=db)
|
||||
Users.update_user_by_id(id, {'email': email}, db=db)
|
||||
return True
|
||||
return False
|
||||
except Exception:
|
||||
|
||||
@@ -37,7 +37,7 @@ from sqlalchemy.sql import exists
|
||||
|
||||
|
||||
class Channel(Base):
|
||||
__tablename__ = "channel"
|
||||
__tablename__ = 'channel'
|
||||
|
||||
id = Column(Text, primary_key=True, unique=True)
|
||||
user_id = Column(Text)
|
||||
@@ -94,7 +94,7 @@ class ChannelModel(BaseModel):
|
||||
|
||||
|
||||
class ChannelMember(Base):
|
||||
__tablename__ = "channel_member"
|
||||
__tablename__ = 'channel_member'
|
||||
|
||||
id = Column(Text, primary_key=True, unique=True)
|
||||
channel_id = Column(Text, nullable=False)
|
||||
@@ -154,25 +154,19 @@ class ChannelMemberModel(BaseModel):
|
||||
|
||||
|
||||
class ChannelFile(Base):
|
||||
__tablename__ = "channel_file"
|
||||
__tablename__ = 'channel_file'
|
||||
|
||||
id = Column(Text, unique=True, primary_key=True)
|
||||
user_id = Column(Text, nullable=False)
|
||||
|
||||
channel_id = Column(
|
||||
Text, ForeignKey("channel.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
message_id = Column(
|
||||
Text, ForeignKey("message.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False)
|
||||
channel_id = Column(Text, ForeignKey('channel.id', ondelete='CASCADE'), nullable=False)
|
||||
message_id = Column(Text, ForeignKey('message.id', ondelete='CASCADE'), nullable=True)
|
||||
file_id = Column(Text, ForeignKey('file.id', ondelete='CASCADE'), nullable=False)
|
||||
|
||||
created_at = Column(BigInteger, nullable=False)
|
||||
updated_at = Column(BigInteger, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("channel_id", "file_id", name="uq_channel_file_channel_file"),
|
||||
)
|
||||
__table_args__ = (UniqueConstraint('channel_id', 'file_id', name='uq_channel_file_channel_file'),)
|
||||
|
||||
|
||||
class ChannelFileModel(BaseModel):
|
||||
@@ -189,7 +183,7 @@ class ChannelFileModel(BaseModel):
|
||||
|
||||
|
||||
class ChannelWebhook(Base):
|
||||
__tablename__ = "channel_webhook"
|
||||
__tablename__ = 'channel_webhook'
|
||||
|
||||
id = Column(Text, primary_key=True, unique=True)
|
||||
channel_id = Column(Text, nullable=False)
|
||||
@@ -235,7 +229,7 @@ class ChannelResponse(ChannelModel):
|
||||
|
||||
|
||||
class ChannelForm(BaseModel):
|
||||
name: str = ""
|
||||
name: str = ''
|
||||
description: Optional[str] = None
|
||||
is_private: Optional[bool] = None
|
||||
data: Optional[dict] = None
|
||||
@@ -255,10 +249,8 @@ class ChannelWebhookForm(BaseModel):
|
||||
|
||||
|
||||
class ChannelTable:
|
||||
def _get_access_grants(
|
||||
self, channel_id: str, db: Optional[Session] = None
|
||||
) -> list[AccessGrantModel]:
|
||||
return AccessGrants.get_grants_by_resource("channel", channel_id, db=db)
|
||||
def _get_access_grants(self, channel_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
|
||||
return AccessGrants.get_grants_by_resource('channel', channel_id, db=db)
|
||||
|
||||
def _to_channel_model(
|
||||
self,
|
||||
@@ -266,13 +258,9 @@ class ChannelTable:
|
||||
access_grants: Optional[list[AccessGrantModel]] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> ChannelModel:
|
||||
channel_data = ChannelModel.model_validate(channel).model_dump(
|
||||
exclude={"access_grants"}
|
||||
)
|
||||
channel_data["access_grants"] = (
|
||||
access_grants
|
||||
if access_grants is not None
|
||||
else self._get_access_grants(channel_data["id"], db=db)
|
||||
channel_data = ChannelModel.model_validate(channel).model_dump(exclude={'access_grants'})
|
||||
channel_data['access_grants'] = (
|
||||
access_grants if access_grants is not None else self._get_access_grants(channel_data['id'], db=db)
|
||||
)
|
||||
return ChannelModel.model_validate(channel_data)
|
||||
|
||||
@@ -313,20 +301,20 @@ class ChannelTable:
|
||||
for uid in user_ids:
|
||||
model = ChannelMemberModel(
|
||||
**{
|
||||
"id": str(uuid.uuid4()),
|
||||
"channel_id": channel_id,
|
||||
"user_id": uid,
|
||||
"status": "joined",
|
||||
"is_active": True,
|
||||
"is_channel_muted": False,
|
||||
"is_channel_pinned": False,
|
||||
"invited_at": now,
|
||||
"invited_by": invited_by,
|
||||
"joined_at": now,
|
||||
"left_at": None,
|
||||
"last_read_at": now,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
'id': str(uuid.uuid4()),
|
||||
'channel_id': channel_id,
|
||||
'user_id': uid,
|
||||
'status': 'joined',
|
||||
'is_active': True,
|
||||
'is_channel_muted': False,
|
||||
'is_channel_pinned': False,
|
||||
'invited_at': now,
|
||||
'invited_by': invited_by,
|
||||
'joined_at': now,
|
||||
'left_at': None,
|
||||
'last_read_at': now,
|
||||
'created_at': now,
|
||||
'updated_at': now,
|
||||
}
|
||||
)
|
||||
memberships.append(ChannelMember(**model.model_dump()))
|
||||
@@ -339,19 +327,19 @@ class ChannelTable:
|
||||
with get_db_context(db) as db:
|
||||
channel = ChannelModel(
|
||||
**{
|
||||
**form_data.model_dump(exclude={"access_grants"}),
|
||||
"type": form_data.type if form_data.type else None,
|
||||
"name": form_data.name.lower(),
|
||||
"id": str(uuid.uuid4()),
|
||||
"user_id": user_id,
|
||||
"created_at": int(time.time_ns()),
|
||||
"updated_at": int(time.time_ns()),
|
||||
"access_grants": [],
|
||||
**form_data.model_dump(exclude={'access_grants'}),
|
||||
'type': form_data.type if form_data.type else None,
|
||||
'name': form_data.name.lower(),
|
||||
'id': str(uuid.uuid4()),
|
||||
'user_id': user_id,
|
||||
'created_at': int(time.time_ns()),
|
||||
'updated_at': int(time.time_ns()),
|
||||
'access_grants': [],
|
||||
}
|
||||
)
|
||||
new_channel = Channel(**channel.model_dump(exclude={"access_grants"}))
|
||||
new_channel = Channel(**channel.model_dump(exclude={'access_grants'}))
|
||||
|
||||
if form_data.type in ["group", "dm"]:
|
||||
if form_data.type in ['group', 'dm']:
|
||||
users = self._collect_unique_user_ids(
|
||||
invited_by=user_id,
|
||||
user_ids=form_data.user_ids,
|
||||
@@ -366,18 +354,14 @@ class ChannelTable:
|
||||
db.add_all(memberships)
|
||||
db.add(new_channel)
|
||||
db.commit()
|
||||
AccessGrants.set_access_grants(
|
||||
"channel", new_channel.id, form_data.access_grants, db=db
|
||||
)
|
||||
AccessGrants.set_access_grants('channel', new_channel.id, form_data.access_grants, db=db)
|
||||
return self._to_channel_model(new_channel, db=db)
|
||||
|
||||
def get_channels(self, db: Optional[Session] = None) -> list[ChannelModel]:
|
||||
with get_db_context(db) as db:
|
||||
channels = db.query(Channel).all()
|
||||
channel_ids = [channel.id for channel in channels]
|
||||
grants_map = AccessGrants.get_grants_by_resources(
|
||||
"channel", channel_ids, db=db
|
||||
)
|
||||
grants_map = AccessGrants.get_grants_by_resources('channel', channel_ids, db=db)
|
||||
return [
|
||||
self._to_channel_model(
|
||||
channel,
|
||||
@@ -387,23 +371,19 @@ class ChannelTable:
|
||||
for channel in channels
|
||||
]
|
||||
|
||||
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
|
||||
def _has_permission(self, db, query, filter: dict, permission: str = 'read'):
|
||||
return AccessGrants.has_permission_filter(
|
||||
db=db,
|
||||
query=query,
|
||||
DocumentModel=Channel,
|
||||
filter=filter,
|
||||
resource_type="channel",
|
||||
resource_type='channel',
|
||||
permission=permission,
|
||||
)
|
||||
|
||||
def get_channels_by_user_id(
|
||||
self, user_id: str, db: Optional[Session] = None
|
||||
) -> list[ChannelModel]:
|
||||
def get_channels_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
|
||||
with get_db_context(db) as db:
|
||||
user_group_ids = [
|
||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
||||
]
|
||||
user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id, db=db)]
|
||||
|
||||
membership_channels = (
|
||||
db.query(Channel)
|
||||
@@ -411,7 +391,7 @@ class ChannelTable:
|
||||
.filter(
|
||||
Channel.deleted_at.is_(None),
|
||||
Channel.archived_at.is_(None),
|
||||
Channel.type.in_(["group", "dm"]),
|
||||
Channel.type.in_(['group', 'dm']),
|
||||
ChannelMember.user_id == user_id,
|
||||
ChannelMember.is_active.is_(True),
|
||||
)
|
||||
@@ -423,29 +403,20 @@ class ChannelTable:
|
||||
Channel.archived_at.is_(None),
|
||||
or_(
|
||||
Channel.type.is_(None), # True NULL/None
|
||||
Channel.type == "", # Empty string
|
||||
and_(Channel.type != "group", Channel.type != "dm"),
|
||||
Channel.type == '', # Empty string
|
||||
and_(Channel.type != 'group', Channel.type != 'dm'),
|
||||
),
|
||||
)
|
||||
query = self._has_permission(
|
||||
db, query, {"user_id": user_id, "group_ids": user_group_ids}
|
||||
)
|
||||
query = self._has_permission(db, query, {'user_id': user_id, 'group_ids': user_group_ids})
|
||||
|
||||
standard_channels = query.all()
|
||||
|
||||
all_channels = membership_channels + standard_channels
|
||||
channel_ids = [c.id for c in all_channels]
|
||||
grants_map = AccessGrants.get_grants_by_resources(
|
||||
"channel", channel_ids, db=db
|
||||
)
|
||||
return [
|
||||
self._to_channel_model(c, access_grants=grants_map.get(c.id, []), db=db)
|
||||
for c in all_channels
|
||||
]
|
||||
grants_map = AccessGrants.get_grants_by_resources('channel', channel_ids, db=db)
|
||||
return [self._to_channel_model(c, access_grants=grants_map.get(c.id, []), db=db) for c in all_channels]
|
||||
|
||||
def get_dm_channel_by_user_ids(
|
||||
self, user_ids: list[str], db: Optional[Session] = None
|
||||
) -> Optional[ChannelModel]:
|
||||
def get_dm_channel_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> Optional[ChannelModel]:
|
||||
with get_db_context(db) as db:
|
||||
# Ensure uniqueness in case a list with duplicates is passed
|
||||
unique_user_ids = list(set(user_ids))
|
||||
@@ -471,7 +442,7 @@ class ChannelTable:
|
||||
db.query(Channel)
|
||||
.filter(
|
||||
Channel.id.in_(subquery),
|
||||
Channel.type == "dm",
|
||||
Channel.type == 'dm',
|
||||
)
|
||||
.first()
|
||||
)
|
||||
@@ -488,32 +459,23 @@ class ChannelTable:
|
||||
) -> list[ChannelMemberModel]:
|
||||
with get_db_context(db) as db:
|
||||
# 1. Collect all user_ids including groups + inviter
|
||||
requested_users = self._collect_unique_user_ids(
|
||||
invited_by, user_ids, group_ids
|
||||
)
|
||||
requested_users = self._collect_unique_user_ids(invited_by, user_ids, group_ids)
|
||||
|
||||
existing_users = {
|
||||
row.user_id
|
||||
for row in db.query(ChannelMember.user_id)
|
||||
.filter(ChannelMember.channel_id == channel_id)
|
||||
.all()
|
||||
for row in db.query(ChannelMember.user_id).filter(ChannelMember.channel_id == channel_id).all()
|
||||
}
|
||||
|
||||
new_user_ids = requested_users - existing_users
|
||||
if not new_user_ids:
|
||||
return [] # Nothing to add
|
||||
|
||||
new_memberships = self._create_membership_models(
|
||||
channel_id, invited_by, new_user_ids
|
||||
)
|
||||
new_memberships = self._create_membership_models(channel_id, invited_by, new_user_ids)
|
||||
|
||||
db.add_all(new_memberships)
|
||||
db.commit()
|
||||
|
||||
return [
|
||||
ChannelMemberModel.model_validate(membership)
|
||||
for membership in new_memberships
|
||||
]
|
||||
return [ChannelMemberModel.model_validate(membership) for membership in new_memberships]
|
||||
|
||||
def remove_members_from_channel(
|
||||
self,
|
||||
@@ -533,9 +495,7 @@ class ChannelTable:
|
||||
db.commit()
|
||||
return result # number of rows deleted
|
||||
|
||||
def is_user_channel_manager(
|
||||
self, channel_id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def is_user_channel_manager(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
# Check if the user is the creator of the channel
|
||||
# or has a 'manager' role in ChannelMember
|
||||
@@ -548,15 +508,13 @@ class ChannelTable:
|
||||
.filter(
|
||||
ChannelMember.channel_id == channel_id,
|
||||
ChannelMember.user_id == user_id,
|
||||
ChannelMember.role == "manager",
|
||||
ChannelMember.role == 'manager',
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return membership is not None
|
||||
|
||||
def join_channel(
|
||||
self, channel_id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[ChannelMemberModel]:
|
||||
def join_channel(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> Optional[ChannelMemberModel]:
|
||||
with get_db_context(db) as db:
|
||||
# Check if the membership already exists
|
||||
existing_membership = (
|
||||
@@ -573,18 +531,18 @@ class ChannelTable:
|
||||
# Create new membership
|
||||
channel_member = ChannelMemberModel(
|
||||
**{
|
||||
"id": str(uuid.uuid4()),
|
||||
"channel_id": channel_id,
|
||||
"user_id": user_id,
|
||||
"status": "joined",
|
||||
"is_active": True,
|
||||
"is_channel_muted": False,
|
||||
"is_channel_pinned": False,
|
||||
"joined_at": int(time.time_ns()),
|
||||
"left_at": None,
|
||||
"last_read_at": int(time.time_ns()),
|
||||
"created_at": int(time.time_ns()),
|
||||
"updated_at": int(time.time_ns()),
|
||||
'id': str(uuid.uuid4()),
|
||||
'channel_id': channel_id,
|
||||
'user_id': user_id,
|
||||
'status': 'joined',
|
||||
'is_active': True,
|
||||
'is_channel_muted': False,
|
||||
'is_channel_pinned': False,
|
||||
'joined_at': int(time.time_ns()),
|
||||
'left_at': None,
|
||||
'last_read_at': int(time.time_ns()),
|
||||
'created_at': int(time.time_ns()),
|
||||
'updated_at': int(time.time_ns()),
|
||||
}
|
||||
)
|
||||
new_membership = ChannelMember(**channel_member.model_dump())
|
||||
@@ -593,9 +551,7 @@ class ChannelTable:
|
||||
db.commit()
|
||||
return channel_member
|
||||
|
||||
def leave_channel(
|
||||
self, channel_id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def leave_channel(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
membership = (
|
||||
db.query(ChannelMember)
|
||||
@@ -608,7 +564,7 @@ class ChannelTable:
|
||||
if not membership:
|
||||
return False
|
||||
|
||||
membership.status = "left"
|
||||
membership.status = 'left'
|
||||
membership.is_active = False
|
||||
membership.left_at = int(time.time_ns())
|
||||
membership.updated_at = int(time.time_ns())
|
||||
@@ -630,19 +586,10 @@ class ChannelTable:
|
||||
)
|
||||
return ChannelMemberModel.model_validate(membership) if membership else None
|
||||
|
||||
def get_members_by_channel_id(
|
||||
self, channel_id: str, db: Optional[Session] = None
|
||||
) -> list[ChannelMemberModel]:
|
||||
def get_members_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> list[ChannelMemberModel]:
|
||||
with get_db_context(db) as db:
|
||||
memberships = (
|
||||
db.query(ChannelMember)
|
||||
.filter(ChannelMember.channel_id == channel_id)
|
||||
.all()
|
||||
)
|
||||
return [
|
||||
ChannelMemberModel.model_validate(membership)
|
||||
for membership in memberships
|
||||
]
|
||||
memberships = db.query(ChannelMember).filter(ChannelMember.channel_id == channel_id).all()
|
||||
return [ChannelMemberModel.model_validate(membership) for membership in memberships]
|
||||
|
||||
def pin_channel(
|
||||
self,
|
||||
@@ -669,9 +616,7 @@ class ChannelTable:
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def update_member_last_read_at(
|
||||
self, channel_id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def update_member_last_read_at(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
membership = (
|
||||
db.query(ChannelMember)
|
||||
@@ -715,9 +660,7 @@ class ChannelTable:
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def is_user_channel_member(
|
||||
self, channel_id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def is_user_channel_member(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
membership = (
|
||||
db.query(ChannelMember)
|
||||
@@ -729,9 +672,7 @@ class ChannelTable:
|
||||
)
|
||||
return membership is not None
|
||||
|
||||
def get_channel_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[ChannelModel]:
|
||||
def get_channel_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChannelModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
channel = db.query(Channel).filter(Channel.id == id).first()
|
||||
@@ -739,18 +680,12 @@ class ChannelTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_channels_by_file_id(
|
||||
self, file_id: str, db: Optional[Session] = None
|
||||
) -> list[ChannelModel]:
|
||||
def get_channels_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
|
||||
with get_db_context(db) as db:
|
||||
channel_files = (
|
||||
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
|
||||
)
|
||||
channel_files = db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
|
||||
channel_ids = [cf.channel_id for cf in channel_files]
|
||||
channels = db.query(Channel).filter(Channel.id.in_(channel_ids)).all()
|
||||
grants_map = AccessGrants.get_grants_by_resources(
|
||||
"channel", channel_ids, db=db
|
||||
)
|
||||
grants_map = AccessGrants.get_grants_by_resources('channel', channel_ids, db=db)
|
||||
return [
|
||||
self._to_channel_model(
|
||||
channel,
|
||||
@@ -765,9 +700,7 @@ class ChannelTable:
|
||||
) -> list[ChannelModel]:
|
||||
with get_db_context(db) as db:
|
||||
# 1. Determine which channels have this file
|
||||
channel_file_rows = (
|
||||
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
|
||||
)
|
||||
channel_file_rows = db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
|
||||
channel_ids = [row.channel_id for row in channel_file_rows]
|
||||
|
||||
if not channel_ids:
|
||||
@@ -787,15 +720,13 @@ class ChannelTable:
|
||||
return []
|
||||
|
||||
# Preload user's group membership
|
||||
user_group_ids = [
|
||||
g.id for g in Groups.get_groups_by_member_id(user_id, db=db)
|
||||
]
|
||||
user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id, db=db)]
|
||||
|
||||
allowed_channels = []
|
||||
|
||||
for channel in channels:
|
||||
# --- Case A: group or dm => user must be an active member ---
|
||||
if channel.type in ["group", "dm"]:
|
||||
if channel.type in ['group', 'dm']:
|
||||
membership = (
|
||||
db.query(ChannelMember)
|
||||
.filter(
|
||||
@@ -815,8 +746,8 @@ class ChannelTable:
|
||||
query = self._has_permission(
|
||||
db,
|
||||
query,
|
||||
{"user_id": user_id, "group_ids": user_group_ids},
|
||||
permission="read",
|
||||
{'user_id': user_id, 'group_ids': user_group_ids},
|
||||
permission='read',
|
||||
)
|
||||
|
||||
allowed = query.first()
|
||||
@@ -844,7 +775,7 @@ class ChannelTable:
|
||||
return None
|
||||
|
||||
# If the channel is a group or dm, read access requires membership (active)
|
||||
if channel.type in ["group", "dm"]:
|
||||
if channel.type in ['group', 'dm']:
|
||||
membership = (
|
||||
db.query(ChannelMember)
|
||||
.filter(
|
||||
@@ -863,24 +794,18 @@ class ChannelTable:
|
||||
query = db.query(Channel).filter(Channel.id == id)
|
||||
|
||||
# Determine user groups
|
||||
user_group_ids = [
|
||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
||||
]
|
||||
user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id, db=db)]
|
||||
|
||||
# Apply ACL rules
|
||||
query = self._has_permission(
|
||||
db,
|
||||
query,
|
||||
{"user_id": user_id, "group_ids": user_group_ids},
|
||||
permission="read",
|
||||
{'user_id': user_id, 'group_ids': user_group_ids},
|
||||
permission='read',
|
||||
)
|
||||
|
||||
channel_allowed = query.first()
|
||||
return (
|
||||
self._to_channel_model(channel_allowed, db=db)
|
||||
if channel_allowed
|
||||
else None
|
||||
)
|
||||
return self._to_channel_model(channel_allowed, db=db) if channel_allowed else None
|
||||
|
||||
def update_channel_by_id(
|
||||
self, id: str, form_data: ChannelForm, db: Optional[Session] = None
|
||||
@@ -898,9 +823,7 @@ class ChannelTable:
|
||||
channel.meta = form_data.meta
|
||||
|
||||
if form_data.access_grants is not None:
|
||||
AccessGrants.set_access_grants(
|
||||
"channel", id, form_data.access_grants, db=db
|
||||
)
|
||||
AccessGrants.set_access_grants('channel', id, form_data.access_grants, db=db)
|
||||
channel.updated_at = int(time.time_ns())
|
||||
|
||||
db.commit()
|
||||
@@ -912,12 +835,12 @@ class ChannelTable:
|
||||
with get_db_context(db) as db:
|
||||
channel_file = ChannelFileModel(
|
||||
**{
|
||||
"id": str(uuid.uuid4()),
|
||||
"channel_id": channel_id,
|
||||
"file_id": file_id,
|
||||
"user_id": user_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
'id': str(uuid.uuid4()),
|
||||
'channel_id': channel_id,
|
||||
'file_id': file_id,
|
||||
'user_id': user_id,
|
||||
'created_at': int(time.time()),
|
||||
'updated_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -942,11 +865,7 @@ class ChannelTable:
|
||||
) -> bool:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
channel_file = (
|
||||
db.query(ChannelFile)
|
||||
.filter_by(channel_id=channel_id, file_id=file_id)
|
||||
.first()
|
||||
)
|
||||
channel_file = db.query(ChannelFile).filter_by(channel_id=channel_id, file_id=file_id).first()
|
||||
if not channel_file:
|
||||
return False
|
||||
|
||||
@@ -958,14 +877,10 @@ class ChannelTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def remove_file_from_channel_by_id(
|
||||
self, channel_id: str, file_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def remove_file_from_channel_by_id(self, channel_id: str, file_id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
db.query(ChannelFile).filter_by(
|
||||
channel_id=channel_id, file_id=file_id
|
||||
).delete()
|
||||
db.query(ChannelFile).filter_by(channel_id=channel_id, file_id=file_id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
except Exception:
|
||||
@@ -973,7 +888,7 @@ class ChannelTable:
|
||||
|
||||
def delete_channel_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
AccessGrants.revoke_all_access("channel", id, db=db)
|
||||
AccessGrants.revoke_all_access('channel', id, db=db)
|
||||
db.query(Channel).filter(Channel.id == id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
@@ -1005,24 +920,14 @@ class ChannelTable:
|
||||
db.commit()
|
||||
return webhook
|
||||
|
||||
def get_webhooks_by_channel_id(
|
||||
self, channel_id: str, db: Optional[Session] = None
|
||||
) -> list[ChannelWebhookModel]:
|
||||
def get_webhooks_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> list[ChannelWebhookModel]:
|
||||
with get_db_context(db) as db:
|
||||
webhooks = (
|
||||
db.query(ChannelWebhook)
|
||||
.filter(ChannelWebhook.channel_id == channel_id)
|
||||
.all()
|
||||
)
|
||||
webhooks = db.query(ChannelWebhook).filter(ChannelWebhook.channel_id == channel_id).all()
|
||||
return [ChannelWebhookModel.model_validate(w) for w in webhooks]
|
||||
|
||||
def get_webhook_by_id(
|
||||
self, webhook_id: str, db: Optional[Session] = None
|
||||
) -> Optional[ChannelWebhookModel]:
|
||||
def get_webhook_by_id(self, webhook_id: str, db: Optional[Session] = None) -> Optional[ChannelWebhookModel]:
|
||||
with get_db_context(db) as db:
|
||||
webhook = (
|
||||
db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
|
||||
)
|
||||
webhook = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
|
||||
return ChannelWebhookModel.model_validate(webhook) if webhook else None
|
||||
|
||||
def get_webhook_by_id_and_token(
|
||||
@@ -1046,9 +951,7 @@ class ChannelTable:
|
||||
db: Optional[Session] = None,
|
||||
) -> Optional[ChannelWebhookModel]:
|
||||
with get_db_context(db) as db:
|
||||
webhook = (
|
||||
db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
|
||||
)
|
||||
webhook = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
|
||||
if not webhook:
|
||||
return None
|
||||
webhook.name = form_data.name
|
||||
@@ -1057,28 +960,18 @@ class ChannelTable:
|
||||
db.commit()
|
||||
return ChannelWebhookModel.model_validate(webhook)
|
||||
|
||||
def update_webhook_last_used_at(
|
||||
self, webhook_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def update_webhook_last_used_at(self, webhook_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
webhook = (
|
||||
db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
|
||||
)
|
||||
webhook = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
|
||||
if not webhook:
|
||||
return False
|
||||
webhook.last_used_at = int(time.time_ns())
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def delete_webhook_by_id(
|
||||
self, webhook_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def delete_webhook_by_id(self, webhook_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
result = (
|
||||
db.query(ChannelWebhook)
|
||||
.filter(ChannelWebhook.id == webhook_id)
|
||||
.delete()
|
||||
)
|
||||
result = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).delete()
|
||||
db.commit()
|
||||
return result > 0
|
||||
|
||||
|
||||
@@ -47,13 +47,11 @@ def _normalize_timestamp(timestamp: int) -> float:
|
||||
|
||||
|
||||
class ChatMessage(Base):
|
||||
__tablename__ = "chat_message"
|
||||
__tablename__ = 'chat_message'
|
||||
|
||||
# Identity
|
||||
id = Column(Text, primary_key=True)
|
||||
chat_id = Column(
|
||||
Text, ForeignKey("chat.id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
chat_id = Column(Text, ForeignKey('chat.id', ondelete='CASCADE'), nullable=False, index=True)
|
||||
user_id = Column(Text, index=True)
|
||||
|
||||
# Structure
|
||||
@@ -85,9 +83,9 @@ class ChatMessage(Base):
|
||||
updated_at = Column(BigInteger)
|
||||
|
||||
__table_args__ = (
|
||||
Index("chat_message_chat_parent_idx", "chat_id", "parent_id"),
|
||||
Index("chat_message_model_created_idx", "model_id", "created_at"),
|
||||
Index("chat_message_user_created_idx", "user_id", "created_at"),
|
||||
Index('chat_message_chat_parent_idx', 'chat_id', 'parent_id'),
|
||||
Index('chat_message_model_created_idx', 'model_id', 'created_at'),
|
||||
Index('chat_message_user_created_idx', 'user_id', 'created_at'),
|
||||
)
|
||||
|
||||
|
||||
@@ -135,43 +133,41 @@ class ChatMessageTable:
|
||||
"""Insert or update a chat message."""
|
||||
with get_db_context(db) as db:
|
||||
now = int(time.time())
|
||||
timestamp = data.get("timestamp", now)
|
||||
timestamp = data.get('timestamp', now)
|
||||
|
||||
# Use composite ID: {chat_id}-{message_id}
|
||||
composite_id = f"{chat_id}-{message_id}"
|
||||
composite_id = f'{chat_id}-{message_id}'
|
||||
|
||||
existing = db.get(ChatMessage, composite_id)
|
||||
if existing:
|
||||
# Update existing
|
||||
if "role" in data:
|
||||
existing.role = data["role"]
|
||||
if "parent_id" in data:
|
||||
existing.parent_id = data.get("parent_id") or data.get("parentId")
|
||||
if "content" in data:
|
||||
existing.content = data.get("content")
|
||||
if "output" in data:
|
||||
existing.output = data.get("output")
|
||||
if "model_id" in data or "model" in data:
|
||||
existing.model_id = data.get("model_id") or data.get("model")
|
||||
if "files" in data:
|
||||
existing.files = data.get("files")
|
||||
if "sources" in data:
|
||||
existing.sources = data.get("sources")
|
||||
if "embeds" in data:
|
||||
existing.embeds = data.get("embeds")
|
||||
if "done" in data:
|
||||
existing.done = data.get("done", True)
|
||||
if "status_history" in data or "statusHistory" in data:
|
||||
existing.status_history = data.get("status_history") or data.get(
|
||||
"statusHistory"
|
||||
)
|
||||
if "error" in data:
|
||||
existing.error = data.get("error")
|
||||
if 'role' in data:
|
||||
existing.role = data['role']
|
||||
if 'parent_id' in data:
|
||||
existing.parent_id = data.get('parent_id') or data.get('parentId')
|
||||
if 'content' in data:
|
||||
existing.content = data.get('content')
|
||||
if 'output' in data:
|
||||
existing.output = data.get('output')
|
||||
if 'model_id' in data or 'model' in data:
|
||||
existing.model_id = data.get('model_id') or data.get('model')
|
||||
if 'files' in data:
|
||||
existing.files = data.get('files')
|
||||
if 'sources' in data:
|
||||
existing.sources = data.get('sources')
|
||||
if 'embeds' in data:
|
||||
existing.embeds = data.get('embeds')
|
||||
if 'done' in data:
|
||||
existing.done = data.get('done', True)
|
||||
if 'status_history' in data or 'statusHistory' in data:
|
||||
existing.status_history = data.get('status_history') or data.get('statusHistory')
|
||||
if 'error' in data:
|
||||
existing.error = data.get('error')
|
||||
# Extract usage - check direct field first, then info.usage
|
||||
usage = data.get("usage")
|
||||
usage = data.get('usage')
|
||||
if not usage:
|
||||
info = data.get("info", {})
|
||||
usage = info.get("usage") if info else None
|
||||
info = data.get('info', {})
|
||||
usage = info.get('usage') if info else None
|
||||
if usage:
|
||||
existing.usage = usage
|
||||
existing.updated_at = now
|
||||
@@ -181,26 +177,25 @@ class ChatMessageTable:
|
||||
else:
|
||||
# Insert new
|
||||
# Extract usage - check direct field first, then info.usage
|
||||
usage = data.get("usage")
|
||||
usage = data.get('usage')
|
||||
if not usage:
|
||||
info = data.get("info", {})
|
||||
usage = info.get("usage") if info else None
|
||||
info = data.get('info', {})
|
||||
usage = info.get('usage') if info else None
|
||||
message = ChatMessage(
|
||||
id=composite_id,
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
role=data.get("role", "user"),
|
||||
parent_id=data.get("parent_id") or data.get("parentId"),
|
||||
content=data.get("content"),
|
||||
output=data.get("output"),
|
||||
model_id=data.get("model_id") or data.get("model"),
|
||||
files=data.get("files"),
|
||||
sources=data.get("sources"),
|
||||
embeds=data.get("embeds"),
|
||||
done=data.get("done", True),
|
||||
status_history=data.get("status_history")
|
||||
or data.get("statusHistory"),
|
||||
error=data.get("error"),
|
||||
role=data.get('role', 'user'),
|
||||
parent_id=data.get('parent_id') or data.get('parentId'),
|
||||
content=data.get('content'),
|
||||
output=data.get('output'),
|
||||
model_id=data.get('model_id') or data.get('model'),
|
||||
files=data.get('files'),
|
||||
sources=data.get('sources'),
|
||||
embeds=data.get('embeds'),
|
||||
done=data.get('done', True),
|
||||
status_history=data.get('status_history') or data.get('statusHistory'),
|
||||
error=data.get('error'),
|
||||
usage=usage,
|
||||
created_at=timestamp,
|
||||
updated_at=now,
|
||||
@@ -210,23 +205,14 @@ class ChatMessageTable:
|
||||
db.refresh(message)
|
||||
return ChatMessageModel.model_validate(message)
|
||||
|
||||
def get_message_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[ChatMessageModel]:
|
||||
def get_message_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatMessageModel]:
|
||||
with get_db_context(db) as db:
|
||||
message = db.get(ChatMessage, id)
|
||||
return ChatMessageModel.model_validate(message) if message else None
|
||||
|
||||
def get_messages_by_chat_id(
|
||||
self, chat_id: str, db: Optional[Session] = None
|
||||
) -> list[ChatMessageModel]:
|
||||
def get_messages_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> list[ChatMessageModel]:
|
||||
with get_db_context(db) as db:
|
||||
messages = (
|
||||
db.query(ChatMessage)
|
||||
.filter_by(chat_id=chat_id)
|
||||
.order_by(ChatMessage.created_at.asc())
|
||||
.all()
|
||||
)
|
||||
messages = db.query(ChatMessage).filter_by(chat_id=chat_id).order_by(ChatMessage.created_at.asc()).all()
|
||||
return [ChatMessageModel.model_validate(message) for message in messages]
|
||||
|
||||
def get_messages_by_user_id(
|
||||
@@ -262,12 +248,7 @@ class ChatMessageTable:
|
||||
query = query.filter(ChatMessage.created_at >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(ChatMessage.created_at <= end_date)
|
||||
messages = (
|
||||
query.order_by(ChatMessage.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
messages = query.order_by(ChatMessage.created_at.desc()).offset(skip).limit(limit).all()
|
||||
return [ChatMessageModel.model_validate(message) for message in messages]
|
||||
|
||||
def get_chat_ids_by_model_id(
|
||||
@@ -284,7 +265,7 @@ class ChatMessageTable:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(
|
||||
ChatMessage.chat_id,
|
||||
func.max(ChatMessage.created_at).label("last_message_at"),
|
||||
func.max(ChatMessage.created_at).label('last_message_at'),
|
||||
).filter(ChatMessage.model_id == model_id)
|
||||
if start_date:
|
||||
query = query.filter(ChatMessage.created_at >= start_date)
|
||||
@@ -303,9 +284,7 @@ class ChatMessageTable:
|
||||
)
|
||||
return [chat_id for chat_id, _ in chat_ids]
|
||||
|
||||
def delete_messages_by_chat_id(
|
||||
self, chat_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def delete_messages_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
db.query(ChatMessage).filter_by(chat_id=chat_id).delete()
|
||||
db.commit()
|
||||
@@ -323,12 +302,10 @@ class ChatMessageTable:
|
||||
from sqlalchemy import func
|
||||
from open_webui.models.groups import GroupMember
|
||||
|
||||
query = db.query(
|
||||
ChatMessage.model_id, func.count(ChatMessage.id).label("count")
|
||||
).filter(
|
||||
ChatMessage.role == "assistant",
|
||||
query = db.query(ChatMessage.model_id, func.count(ChatMessage.id).label('count')).filter(
|
||||
ChatMessage.role == 'assistant',
|
||||
ChatMessage.model_id.isnot(None),
|
||||
~ChatMessage.user_id.like("shared-%"),
|
||||
~ChatMessage.user_id.like('shared-%'),
|
||||
)
|
||||
|
||||
if start_date:
|
||||
@@ -336,11 +313,7 @@ class ChatMessageTable:
|
||||
if end_date:
|
||||
query = query.filter(ChatMessage.created_at <= end_date)
|
||||
if group_id:
|
||||
group_users = (
|
||||
db.query(GroupMember.user_id)
|
||||
.filter(GroupMember.group_id == group_id)
|
||||
.subquery()
|
||||
)
|
||||
group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
|
||||
query = query.filter(ChatMessage.user_id.in_(group_users))
|
||||
|
||||
results = query.group_by(ChatMessage.model_id).all()
|
||||
@@ -360,36 +333,32 @@ class ChatMessageTable:
|
||||
|
||||
dialect = db.bind.dialect.name
|
||||
|
||||
if dialect == "sqlite":
|
||||
input_tokens = cast(
|
||||
func.json_extract(ChatMessage.usage, "$.input_tokens"), Integer
|
||||
)
|
||||
output_tokens = cast(
|
||||
func.json_extract(ChatMessage.usage, "$.output_tokens"), Integer
|
||||
)
|
||||
elif dialect == "postgresql":
|
||||
if dialect == 'sqlite':
|
||||
input_tokens = cast(func.json_extract(ChatMessage.usage, '$.input_tokens'), Integer)
|
||||
output_tokens = cast(func.json_extract(ChatMessage.usage, '$.output_tokens'), Integer)
|
||||
elif dialect == 'postgresql':
|
||||
# Use json_extract_path_text for PostgreSQL JSON columns
|
||||
input_tokens = cast(
|
||||
func.json_extract_path_text(ChatMessage.usage, "input_tokens"),
|
||||
func.json_extract_path_text(ChatMessage.usage, 'input_tokens'),
|
||||
Integer,
|
||||
)
|
||||
output_tokens = cast(
|
||||
func.json_extract_path_text(ChatMessage.usage, "output_tokens"),
|
||||
func.json_extract_path_text(ChatMessage.usage, 'output_tokens'),
|
||||
Integer,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported dialect: {dialect}")
|
||||
raise NotImplementedError(f'Unsupported dialect: {dialect}')
|
||||
|
||||
query = db.query(
|
||||
ChatMessage.model_id,
|
||||
func.coalesce(func.sum(input_tokens), 0).label("input_tokens"),
|
||||
func.coalesce(func.sum(output_tokens), 0).label("output_tokens"),
|
||||
func.count(ChatMessage.id).label("message_count"),
|
||||
func.coalesce(func.sum(input_tokens), 0).label('input_tokens'),
|
||||
func.coalesce(func.sum(output_tokens), 0).label('output_tokens'),
|
||||
func.count(ChatMessage.id).label('message_count'),
|
||||
).filter(
|
||||
ChatMessage.role == "assistant",
|
||||
ChatMessage.role == 'assistant',
|
||||
ChatMessage.model_id.isnot(None),
|
||||
ChatMessage.usage.isnot(None),
|
||||
~ChatMessage.user_id.like("shared-%"),
|
||||
~ChatMessage.user_id.like('shared-%'),
|
||||
)
|
||||
|
||||
if start_date:
|
||||
@@ -397,21 +366,17 @@ class ChatMessageTable:
|
||||
if end_date:
|
||||
query = query.filter(ChatMessage.created_at <= end_date)
|
||||
if group_id:
|
||||
group_users = (
|
||||
db.query(GroupMember.user_id)
|
||||
.filter(GroupMember.group_id == group_id)
|
||||
.subquery()
|
||||
)
|
||||
group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
|
||||
query = query.filter(ChatMessage.user_id.in_(group_users))
|
||||
|
||||
results = query.group_by(ChatMessage.model_id).all()
|
||||
|
||||
return {
|
||||
row.model_id: {
|
||||
"input_tokens": row.input_tokens,
|
||||
"output_tokens": row.output_tokens,
|
||||
"total_tokens": row.input_tokens + row.output_tokens,
|
||||
"message_count": row.message_count,
|
||||
'input_tokens': row.input_tokens,
|
||||
'output_tokens': row.output_tokens,
|
||||
'total_tokens': row.input_tokens + row.output_tokens,
|
||||
'message_count': row.message_count,
|
||||
}
|
||||
for row in results
|
||||
}
|
||||
@@ -430,36 +395,32 @@ class ChatMessageTable:
|
||||
|
||||
dialect = db.bind.dialect.name
|
||||
|
||||
if dialect == "sqlite":
|
||||
input_tokens = cast(
|
||||
func.json_extract(ChatMessage.usage, "$.input_tokens"), Integer
|
||||
)
|
||||
output_tokens = cast(
|
||||
func.json_extract(ChatMessage.usage, "$.output_tokens"), Integer
|
||||
)
|
||||
elif dialect == "postgresql":
|
||||
if dialect == 'sqlite':
|
||||
input_tokens = cast(func.json_extract(ChatMessage.usage, '$.input_tokens'), Integer)
|
||||
output_tokens = cast(func.json_extract(ChatMessage.usage, '$.output_tokens'), Integer)
|
||||
elif dialect == 'postgresql':
|
||||
# Use json_extract_path_text for PostgreSQL JSON columns
|
||||
input_tokens = cast(
|
||||
func.json_extract_path_text(ChatMessage.usage, "input_tokens"),
|
||||
func.json_extract_path_text(ChatMessage.usage, 'input_tokens'),
|
||||
Integer,
|
||||
)
|
||||
output_tokens = cast(
|
||||
func.json_extract_path_text(ChatMessage.usage, "output_tokens"),
|
||||
func.json_extract_path_text(ChatMessage.usage, 'output_tokens'),
|
||||
Integer,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported dialect: {dialect}")
|
||||
raise NotImplementedError(f'Unsupported dialect: {dialect}')
|
||||
|
||||
query = db.query(
|
||||
ChatMessage.user_id,
|
||||
func.coalesce(func.sum(input_tokens), 0).label("input_tokens"),
|
||||
func.coalesce(func.sum(output_tokens), 0).label("output_tokens"),
|
||||
func.count(ChatMessage.id).label("message_count"),
|
||||
func.coalesce(func.sum(input_tokens), 0).label('input_tokens'),
|
||||
func.coalesce(func.sum(output_tokens), 0).label('output_tokens'),
|
||||
func.count(ChatMessage.id).label('message_count'),
|
||||
).filter(
|
||||
ChatMessage.role == "assistant",
|
||||
ChatMessage.role == 'assistant',
|
||||
ChatMessage.user_id.isnot(None),
|
||||
ChatMessage.usage.isnot(None),
|
||||
~ChatMessage.user_id.like("shared-%"),
|
||||
~ChatMessage.user_id.like('shared-%'),
|
||||
)
|
||||
|
||||
if start_date:
|
||||
@@ -467,21 +428,17 @@ class ChatMessageTable:
|
||||
if end_date:
|
||||
query = query.filter(ChatMessage.created_at <= end_date)
|
||||
if group_id:
|
||||
group_users = (
|
||||
db.query(GroupMember.user_id)
|
||||
.filter(GroupMember.group_id == group_id)
|
||||
.subquery()
|
||||
)
|
||||
group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
|
||||
query = query.filter(ChatMessage.user_id.in_(group_users))
|
||||
|
||||
results = query.group_by(ChatMessage.user_id).all()
|
||||
|
||||
return {
|
||||
row.user_id: {
|
||||
"input_tokens": row.input_tokens,
|
||||
"output_tokens": row.output_tokens,
|
||||
"total_tokens": row.input_tokens + row.output_tokens,
|
||||
"message_count": row.message_count,
|
||||
'input_tokens': row.input_tokens,
|
||||
'output_tokens': row.output_tokens,
|
||||
'total_tokens': row.input_tokens + row.output_tokens,
|
||||
'message_count': row.message_count,
|
||||
}
|
||||
for row in results
|
||||
}
|
||||
@@ -497,20 +454,16 @@ class ChatMessageTable:
|
||||
from sqlalchemy import func
|
||||
from open_webui.models.groups import GroupMember
|
||||
|
||||
query = db.query(
|
||||
ChatMessage.user_id, func.count(ChatMessage.id).label("count")
|
||||
).filter(~ChatMessage.user_id.like("shared-%"))
|
||||
query = db.query(ChatMessage.user_id, func.count(ChatMessage.id).label('count')).filter(
|
||||
~ChatMessage.user_id.like('shared-%')
|
||||
)
|
||||
|
||||
if start_date:
|
||||
query = query.filter(ChatMessage.created_at >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(ChatMessage.created_at <= end_date)
|
||||
if group_id:
|
||||
group_users = (
|
||||
db.query(GroupMember.user_id)
|
||||
.filter(GroupMember.group_id == group_id)
|
||||
.subquery()
|
||||
)
|
||||
group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
|
||||
query = query.filter(ChatMessage.user_id.in_(group_users))
|
||||
|
||||
results = query.group_by(ChatMessage.user_id).all()
|
||||
@@ -527,20 +480,16 @@ class ChatMessageTable:
|
||||
from sqlalchemy import func
|
||||
from open_webui.models.groups import GroupMember
|
||||
|
||||
query = db.query(
|
||||
ChatMessage.chat_id, func.count(ChatMessage.id).label("count")
|
||||
).filter(~ChatMessage.user_id.like("shared-%"))
|
||||
query = db.query(ChatMessage.chat_id, func.count(ChatMessage.id).label('count')).filter(
|
||||
~ChatMessage.user_id.like('shared-%')
|
||||
)
|
||||
|
||||
if start_date:
|
||||
query = query.filter(ChatMessage.created_at >= start_date)
|
||||
if end_date:
|
||||
query = query.filter(ChatMessage.created_at <= end_date)
|
||||
if group_id:
|
||||
group_users = (
|
||||
db.query(GroupMember.user_id)
|
||||
.filter(GroupMember.group_id == group_id)
|
||||
.subquery()
|
||||
)
|
||||
group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
|
||||
query = query.filter(ChatMessage.user_id.in_(group_users))
|
||||
|
||||
results = query.group_by(ChatMessage.chat_id).all()
|
||||
@@ -559,9 +508,9 @@ class ChatMessageTable:
|
||||
from open_webui.models.groups import GroupMember
|
||||
|
||||
query = db.query(ChatMessage.created_at, ChatMessage.model_id).filter(
|
||||
ChatMessage.role == "assistant",
|
||||
ChatMessage.role == 'assistant',
|
||||
ChatMessage.model_id.isnot(None),
|
||||
~ChatMessage.user_id.like("shared-%"),
|
||||
~ChatMessage.user_id.like('shared-%'),
|
||||
)
|
||||
|
||||
if start_date:
|
||||
@@ -569,11 +518,7 @@ class ChatMessageTable:
|
||||
if end_date:
|
||||
query = query.filter(ChatMessage.created_at <= end_date)
|
||||
if group_id:
|
||||
group_users = (
|
||||
db.query(GroupMember.user_id)
|
||||
.filter(GroupMember.group_id == group_id)
|
||||
.subquery()
|
||||
)
|
||||
group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
|
||||
query = query.filter(ChatMessage.user_id.in_(group_users))
|
||||
|
||||
results = query.all()
|
||||
@@ -581,21 +526,17 @@ class ChatMessageTable:
|
||||
# Group by date -> model -> count
|
||||
daily_counts: dict[str, dict[str, int]] = {}
|
||||
for timestamp, model_id in results:
|
||||
date_str = datetime.fromtimestamp(
|
||||
_normalize_timestamp(timestamp)
|
||||
).strftime("%Y-%m-%d")
|
||||
date_str = datetime.fromtimestamp(_normalize_timestamp(timestamp)).strftime('%Y-%m-%d')
|
||||
if date_str not in daily_counts:
|
||||
daily_counts[date_str] = {}
|
||||
daily_counts[date_str][model_id] = (
|
||||
daily_counts[date_str].get(model_id, 0) + 1
|
||||
)
|
||||
daily_counts[date_str][model_id] = daily_counts[date_str].get(model_id, 0) + 1
|
||||
|
||||
# Fill in missing days
|
||||
if start_date and end_date:
|
||||
current = datetime.fromtimestamp(_normalize_timestamp(start_date))
|
||||
end_dt = datetime.fromtimestamp(_normalize_timestamp(end_date))
|
||||
while current <= end_dt:
|
||||
date_str = current.strftime("%Y-%m-%d")
|
||||
date_str = current.strftime('%Y-%m-%d')
|
||||
if date_str not in daily_counts:
|
||||
daily_counts[date_str] = {}
|
||||
current += timedelta(days=1)
|
||||
@@ -613,9 +554,9 @@ class ChatMessageTable:
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
query = db.query(ChatMessage.created_at, ChatMessage.model_id).filter(
|
||||
ChatMessage.role == "assistant",
|
||||
ChatMessage.role == 'assistant',
|
||||
ChatMessage.model_id.isnot(None),
|
||||
~ChatMessage.user_id.like("shared-%"),
|
||||
~ChatMessage.user_id.like('shared-%'),
|
||||
)
|
||||
|
||||
if start_date:
|
||||
@@ -628,23 +569,19 @@ class ChatMessageTable:
|
||||
# Group by hour -> model -> count
|
||||
hourly_counts: dict[str, dict[str, int]] = {}
|
||||
for timestamp, model_id in results:
|
||||
hour_str = datetime.fromtimestamp(
|
||||
_normalize_timestamp(timestamp)
|
||||
).strftime("%Y-%m-%d %H:00")
|
||||
hour_str = datetime.fromtimestamp(_normalize_timestamp(timestamp)).strftime('%Y-%m-%d %H:00')
|
||||
if hour_str not in hourly_counts:
|
||||
hourly_counts[hour_str] = {}
|
||||
hourly_counts[hour_str][model_id] = (
|
||||
hourly_counts[hour_str].get(model_id, 0) + 1
|
||||
)
|
||||
hourly_counts[hour_str][model_id] = hourly_counts[hour_str].get(model_id, 0) + 1
|
||||
|
||||
# Fill in missing hours
|
||||
if start_date and end_date:
|
||||
current = datetime.fromtimestamp(
|
||||
_normalize_timestamp(start_date)
|
||||
).replace(minute=0, second=0, microsecond=0)
|
||||
current = datetime.fromtimestamp(_normalize_timestamp(start_date)).replace(
|
||||
minute=0, second=0, microsecond=0
|
||||
)
|
||||
end_dt = datetime.fromtimestamp(_normalize_timestamp(end_date))
|
||||
while current <= end_dt:
|
||||
hour_str = current.strftime("%Y-%m-%d %H:00")
|
||||
hour_str = current.strftime('%Y-%m-%d %H:00')
|
||||
if hour_str not in hourly_counts:
|
||||
hourly_counts[hour_str] = {}
|
||||
current += timedelta(hours=1)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -19,7 +19,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Feedback(Base):
|
||||
__tablename__ = "feedback"
|
||||
__tablename__ = 'feedback'
|
||||
id = Column(Text, primary_key=True, unique=True)
|
||||
user_id = Column(Text)
|
||||
version = Column(BigInteger, default=0)
|
||||
@@ -81,7 +81,7 @@ class RatingData(BaseModel):
|
||||
sibling_model_ids: Optional[list[str]] = None
|
||||
reason: Optional[str] = None
|
||||
comment: Optional[str] = None
|
||||
model_config = ConfigDict(extra="allow", protected_namespaces=())
|
||||
model_config = ConfigDict(extra='allow', protected_namespaces=())
|
||||
|
||||
|
||||
class MetaData(BaseModel):
|
||||
@@ -89,12 +89,12 @@ class MetaData(BaseModel):
|
||||
chat_id: Optional[str] = None
|
||||
message_id: Optional[str] = None
|
||||
tags: Optional[list[str]] = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
|
||||
class SnapshotData(BaseModel):
|
||||
chat: Optional[dict] = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
|
||||
class FeedbackForm(BaseModel):
|
||||
@@ -102,14 +102,14 @@ class FeedbackForm(BaseModel):
|
||||
data: Optional[RatingData] = None
|
||||
meta: Optional[dict] = None
|
||||
snapshot: Optional[SnapshotData] = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
role: str = "pending"
|
||||
role: str = 'pending'
|
||||
|
||||
last_active_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
@@ -146,12 +146,12 @@ class FeedbackTable:
|
||||
id = str(uuid.uuid4())
|
||||
feedback = FeedbackModel(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"version": 0,
|
||||
'id': id,
|
||||
'user_id': user_id,
|
||||
'version': 0,
|
||||
**form_data.model_dump(),
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
'created_at': int(time.time()),
|
||||
'updated_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
try:
|
||||
@@ -164,12 +164,10 @@ class FeedbackTable:
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.exception(f"Error creating a new feedback: {e}")
|
||||
log.exception(f'Error creating a new feedback: {e}')
|
||||
return None
|
||||
|
||||
def get_feedback_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[FeedbackModel]:
|
||||
def get_feedback_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FeedbackModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
feedback = db.query(Feedback).filter_by(id=id).first()
|
||||
@@ -191,16 +189,14 @@ class FeedbackTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_feedbacks_by_chat_id(
|
||||
self, chat_id: str, db: Optional[Session] = None
|
||||
) -> list[FeedbackModel]:
|
||||
def get_feedbacks_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> list[FeedbackModel]:
|
||||
"""Get all feedbacks for a specific chat."""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
# meta.chat_id stores the chat reference
|
||||
feedbacks = (
|
||||
db.query(Feedback)
|
||||
.filter(Feedback.meta["chat_id"].as_string() == chat_id)
|
||||
.filter(Feedback.meta['chat_id'].as_string() == chat_id)
|
||||
.order_by(Feedback.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
@@ -219,36 +215,28 @@ class FeedbackTable:
|
||||
query = db.query(Feedback, User).join(User, Feedback.user_id == User.id)
|
||||
|
||||
if filter:
|
||||
order_by = filter.get("order_by")
|
||||
direction = filter.get("direction")
|
||||
order_by = filter.get('order_by')
|
||||
direction = filter.get('direction')
|
||||
|
||||
if order_by == "username":
|
||||
if direction == "asc":
|
||||
if order_by == 'username':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(User.name.asc())
|
||||
else:
|
||||
query = query.order_by(User.name.desc())
|
||||
elif order_by == "model_id":
|
||||
elif order_by == 'model_id':
|
||||
# it's stored in feedback.data['model_id']
|
||||
if direction == "asc":
|
||||
query = query.order_by(
|
||||
Feedback.data["model_id"].as_string().asc()
|
||||
)
|
||||
if direction == 'asc':
|
||||
query = query.order_by(Feedback.data['model_id'].as_string().asc())
|
||||
else:
|
||||
query = query.order_by(
|
||||
Feedback.data["model_id"].as_string().desc()
|
||||
)
|
||||
elif order_by == "rating":
|
||||
query = query.order_by(Feedback.data['model_id'].as_string().desc())
|
||||
elif order_by == 'rating':
|
||||
# it's stored in feedback.data['rating']
|
||||
if direction == "asc":
|
||||
query = query.order_by(
|
||||
Feedback.data["rating"].as_string().asc()
|
||||
)
|
||||
if direction == 'asc':
|
||||
query = query.order_by(Feedback.data['rating'].as_string().asc())
|
||||
else:
|
||||
query = query.order_by(
|
||||
Feedback.data["rating"].as_string().desc()
|
||||
)
|
||||
elif order_by == "updated_at":
|
||||
if direction == "asc":
|
||||
query = query.order_by(Feedback.data['rating'].as_string().desc())
|
||||
elif order_by == 'updated_at':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(Feedback.updated_at.asc())
|
||||
else:
|
||||
query = query.order_by(Feedback.updated_at.desc())
|
||||
@@ -270,9 +258,7 @@ class FeedbackTable:
|
||||
for feedback, user in items:
|
||||
feedback_model = FeedbackModel.model_validate(feedback)
|
||||
user_model = UserResponse.model_validate(user)
|
||||
feedbacks.append(
|
||||
FeedbackUserResponse(**feedback_model.model_dump(), user=user_model)
|
||||
)
|
||||
feedbacks.append(FeedbackUserResponse(**feedback_model.model_dump(), user=user_model))
|
||||
|
||||
return FeedbackListResponse(items=feedbacks, total=total)
|
||||
|
||||
@@ -280,14 +266,10 @@ class FeedbackTable:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FeedbackModel.model_validate(feedback)
|
||||
for feedback in db.query(Feedback)
|
||||
.order_by(Feedback.updated_at.desc())
|
||||
.all()
|
||||
for feedback in db.query(Feedback).order_by(Feedback.updated_at.desc()).all()
|
||||
]
|
||||
|
||||
def get_all_feedback_ids(
|
||||
self, db: Optional[Session] = None
|
||||
) -> list[FeedbackIdResponse]:
|
||||
def get_all_feedback_ids(self, db: Optional[Session] = None) -> list[FeedbackIdResponse]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FeedbackIdResponse(
|
||||
@@ -306,14 +288,11 @@ class FeedbackTable:
|
||||
.all()
|
||||
]
|
||||
|
||||
def get_feedbacks_for_leaderboard(
|
||||
self, db: Optional[Session] = None
|
||||
) -> list[LeaderboardFeedbackData]:
|
||||
def get_feedbacks_for_leaderboard(self, db: Optional[Session] = None) -> list[LeaderboardFeedbackData]:
|
||||
"""Fetch only id and data for leaderboard computation (excludes snapshot/meta)."""
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
LeaderboardFeedbackData(id=row.id, data=row.data)
|
||||
for row in db.query(Feedback.id, Feedback.data).all()
|
||||
LeaderboardFeedbackData(id=row.id, data=row.data) for row in db.query(Feedback.id, Feedback.data).all()
|
||||
]
|
||||
|
||||
def get_model_evaluation_history(
|
||||
@@ -333,30 +312,26 @@ class FeedbackTable:
|
||||
rows = db.query(Feedback.created_at, Feedback.data).all()
|
||||
else:
|
||||
cutoff = int(time.time()) - (days * 86400)
|
||||
rows = (
|
||||
db.query(Feedback.created_at, Feedback.data)
|
||||
.filter(Feedback.created_at >= cutoff)
|
||||
.all()
|
||||
)
|
||||
rows = db.query(Feedback.created_at, Feedback.data).filter(Feedback.created_at >= cutoff).all()
|
||||
|
||||
daily_counts = defaultdict(lambda: {"won": 0, "lost": 0})
|
||||
daily_counts = defaultdict(lambda: {'won': 0, 'lost': 0})
|
||||
first_date = None
|
||||
|
||||
for created_at, data in rows:
|
||||
if not data:
|
||||
continue
|
||||
if data.get("model_id") != model_id:
|
||||
if data.get('model_id') != model_id:
|
||||
continue
|
||||
|
||||
rating_str = str(data.get("rating", ""))
|
||||
if rating_str not in ("1", "-1"):
|
||||
rating_str = str(data.get('rating', ''))
|
||||
if rating_str not in ('1', '-1'):
|
||||
continue
|
||||
|
||||
date_str = datetime.fromtimestamp(created_at).strftime("%Y-%m-%d")
|
||||
if rating_str == "1":
|
||||
daily_counts[date_str]["won"] += 1
|
||||
date_str = datetime.fromtimestamp(created_at).strftime('%Y-%m-%d')
|
||||
if rating_str == '1':
|
||||
daily_counts[date_str]['won'] += 1
|
||||
else:
|
||||
daily_counts[date_str]["lost"] += 1
|
||||
daily_counts[date_str]['lost'] += 1
|
||||
|
||||
# Track first date for this model
|
||||
if first_date is None or date_str < first_date:
|
||||
@@ -368,7 +343,7 @@ class FeedbackTable:
|
||||
|
||||
if days == 0 and first_date:
|
||||
# All time: start from first feedback date
|
||||
start_date = datetime.strptime(first_date, "%Y-%m-%d").date()
|
||||
start_date = datetime.strptime(first_date, '%Y-%m-%d').date()
|
||||
num_days = (today - start_date).days + 1
|
||||
else:
|
||||
# Fixed range
|
||||
@@ -377,36 +352,24 @@ class FeedbackTable:
|
||||
|
||||
for i in range(num_days):
|
||||
d = start_date + timedelta(days=i)
|
||||
date_str = d.strftime("%Y-%m-%d")
|
||||
counts = daily_counts.get(date_str, {"won": 0, "lost": 0})
|
||||
result.append(
|
||||
ModelHistoryEntry(date=date_str, won=counts["won"], lost=counts["lost"])
|
||||
)
|
||||
date_str = d.strftime('%Y-%m-%d')
|
||||
counts = daily_counts.get(date_str, {'won': 0, 'lost': 0})
|
||||
result.append(ModelHistoryEntry(date=date_str, won=counts['won'], lost=counts['lost']))
|
||||
|
||||
return result
|
||||
|
||||
def get_feedbacks_by_type(
|
||||
self, type: str, db: Optional[Session] = None
|
||||
) -> list[FeedbackModel]:
|
||||
def get_feedbacks_by_type(self, type: str, db: Optional[Session] = None) -> list[FeedbackModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FeedbackModel.model_validate(feedback)
|
||||
for feedback in db.query(Feedback)
|
||||
.filter_by(type=type)
|
||||
.order_by(Feedback.updated_at.desc())
|
||||
.all()
|
||||
for feedback in db.query(Feedback).filter_by(type=type).order_by(Feedback.updated_at.desc()).all()
|
||||
]
|
||||
|
||||
def get_feedbacks_by_user_id(
|
||||
self, user_id: str, db: Optional[Session] = None
|
||||
) -> list[FeedbackModel]:
|
||||
def get_feedbacks_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FeedbackModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FeedbackModel.model_validate(feedback)
|
||||
for feedback in db.query(Feedback)
|
||||
.filter_by(user_id=user_id)
|
||||
.order_by(Feedback.updated_at.desc())
|
||||
.all()
|
||||
for feedback in db.query(Feedback).filter_by(user_id=user_id).order_by(Feedback.updated_at.desc()).all()
|
||||
]
|
||||
|
||||
def update_feedback_by_id(
|
||||
@@ -462,9 +425,7 @@ class FeedbackTable:
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def delete_feedback_by_id_and_user_id(
|
||||
self, id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def delete_feedback_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
|
||||
if not feedback:
|
||||
@@ -473,9 +434,7 @@ class FeedbackTable:
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
def delete_feedbacks_by_user_id(
|
||||
self, user_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def delete_feedbacks_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
result = db.query(Feedback).filter_by(user_id=user_id).delete()
|
||||
db.commit()
|
||||
|
||||
@@ -16,7 +16,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class File(Base):
|
||||
__tablename__ = "file"
|
||||
__tablename__ = 'file'
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
user_id = Column(String)
|
||||
hash = Column(Text, nullable=True)
|
||||
@@ -58,9 +58,9 @@ class FileMeta(BaseModel):
|
||||
content_type: Optional[str] = None
|
||||
size: Optional[int] = None
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
@model_validator(mode="before")
|
||||
@model_validator(mode='before')
|
||||
@classmethod
|
||||
def sanitize_meta(cls, data):
|
||||
"""Sanitize metadata fields to handle malformed legacy data."""
|
||||
@@ -68,14 +68,12 @@ class FileMeta(BaseModel):
|
||||
return data
|
||||
|
||||
# Handle content_type that may be a list like ['application/pdf', None]
|
||||
content_type = data.get("content_type")
|
||||
content_type = data.get('content_type')
|
||||
if isinstance(content_type, list):
|
||||
# Extract first non-None string value
|
||||
data["content_type"] = next(
|
||||
(item for item in content_type if isinstance(item, str)), None
|
||||
)
|
||||
data['content_type'] = next((item for item in content_type if isinstance(item, str)), None)
|
||||
elif content_type is not None and not isinstance(content_type, str):
|
||||
data["content_type"] = None
|
||||
data['content_type'] = None
|
||||
|
||||
return data
|
||||
|
||||
@@ -92,7 +90,7 @@ class FileModelResponse(BaseModel):
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: Optional[int] = None # timestamp in epoch, optional for legacy files
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
|
||||
class FileMetadataResponse(BaseModel):
|
||||
@@ -123,25 +121,22 @@ class FileUpdateForm(BaseModel):
|
||||
meta: Optional[dict] = None
|
||||
|
||||
|
||||
|
||||
class FilesTable:
|
||||
def insert_new_file(
|
||||
self, user_id: str, form_data: FileForm, db: Optional[Session] = None
|
||||
) -> Optional[FileModel]:
|
||||
def insert_new_file(self, user_id: str, form_data: FileForm, db: Optional[Session] = None) -> Optional[FileModel]:
|
||||
with get_db_context(db) as db:
|
||||
file_data = form_data.model_dump()
|
||||
|
||||
# Sanitize meta to remove non-JSON-serializable objects
|
||||
# (e.g. callable tool functions, MCP client instances from middleware)
|
||||
if file_data.get("meta"):
|
||||
file_data["meta"] = sanitize_metadata(file_data["meta"])
|
||||
if file_data.get('meta'):
|
||||
file_data['meta'] = sanitize_metadata(file_data['meta'])
|
||||
|
||||
file = FileModel(
|
||||
**{
|
||||
**file_data,
|
||||
"user_id": user_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
'user_id': user_id,
|
||||
'created_at': int(time.time()),
|
||||
'updated_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -155,12 +150,10 @@ class FilesTable:
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.exception(f"Error inserting a new file: {e}")
|
||||
log.exception(f'Error inserting a new file: {e}')
|
||||
return None
|
||||
|
||||
def get_file_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[FileModel]:
|
||||
def get_file_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FileModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
@@ -171,9 +164,7 @@ class FilesTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_file_by_id_and_user_id(
|
||||
self, id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[FileModel]:
|
||||
def get_file_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[FileModel]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
file = db.query(File).filter_by(id=id, user_id=user_id).first()
|
||||
@@ -184,9 +175,7 @@ class FilesTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_file_metadata_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[FileMetadataResponse]:
|
||||
def get_file_metadata_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FileMetadataResponse]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
file = db.get(File, id)
|
||||
@@ -204,9 +193,7 @@ class FilesTable:
|
||||
with get_db_context(db) as db:
|
||||
return [FileModel.model_validate(file) for file in db.query(File).all()]
|
||||
|
||||
def check_access_by_user_id(
|
||||
self, id, user_id, permission="write", db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def check_access_by_user_id(self, id, user_id, permission='write', db: Optional[Session] = None) -> bool:
|
||||
file = self.get_file_by_id(id, db=db)
|
||||
if not file:
|
||||
return False
|
||||
@@ -215,21 +202,14 @@ class FilesTable:
|
||||
# Implement additional access control logic here as needed
|
||||
return False
|
||||
|
||||
def get_files_by_ids(
|
||||
self, ids: list[str], db: Optional[Session] = None
|
||||
) -> list[FileModel]:
|
||||
def get_files_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[FileModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FileModel.model_validate(file)
|
||||
for file in db.query(File)
|
||||
.filter(File.id.in_(ids))
|
||||
.order_by(File.updated_at.desc())
|
||||
.all()
|
||||
for file in db.query(File).filter(File.id.in_(ids)).order_by(File.updated_at.desc()).all()
|
||||
]
|
||||
|
||||
def get_file_metadatas_by_ids(
|
||||
self, ids: list[str], db: Optional[Session] = None
|
||||
) -> list[FileMetadataResponse]:
|
||||
def get_file_metadatas_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[FileMetadataResponse]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FileMetadataResponse(
|
||||
@@ -239,22 +219,15 @@ class FilesTable:
|
||||
created_at=file.created_at,
|
||||
updated_at=file.updated_at,
|
||||
)
|
||||
for file in db.query(
|
||||
File.id, File.hash, File.meta, File.created_at, File.updated_at
|
||||
)
|
||||
for file in db.query(File.id, File.hash, File.meta, File.created_at, File.updated_at)
|
||||
.filter(File.id.in_(ids))
|
||||
.order_by(File.updated_at.desc())
|
||||
.all()
|
||||
]
|
||||
|
||||
def get_files_by_user_id(
|
||||
self, user_id: str, db: Optional[Session] = None
|
||||
) -> list[FileModel]:
|
||||
def get_files_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FileModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FileModel.model_validate(file)
|
||||
for file in db.query(File).filter_by(user_id=user_id).all()
|
||||
]
|
||||
return [FileModel.model_validate(file) for file in db.query(File).filter_by(user_id=user_id).all()]
|
||||
|
||||
def get_file_list(
|
||||
self,
|
||||
@@ -262,7 +235,7 @@ class FilesTable:
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
db: Optional[Session] = None,
|
||||
) -> "FileListResponse":
|
||||
) -> 'FileListResponse':
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(File)
|
||||
if user_id:
|
||||
@@ -272,10 +245,7 @@ class FilesTable:
|
||||
|
||||
items = [
|
||||
FileModel.model_validate(file)
|
||||
for file in query.order_by(File.updated_at.desc(), File.id.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
for file in query.order_by(File.updated_at.desc(), File.id.desc()).offset(skip).limit(limit).all()
|
||||
]
|
||||
|
||||
return FileListResponse(items=items, total=total)
|
||||
@@ -296,17 +266,17 @@ class FilesTable:
|
||||
A SQL LIKE compatible pattern with proper escaping.
|
||||
"""
|
||||
# Escape SQL special characters first, then convert glob wildcards
|
||||
pattern = glob.replace("\\", "\\\\")
|
||||
pattern = pattern.replace("%", "\\%")
|
||||
pattern = pattern.replace("_", "\\_")
|
||||
pattern = pattern.replace("*", "%")
|
||||
pattern = pattern.replace("?", "_")
|
||||
pattern = glob.replace('\\', '\\\\')
|
||||
pattern = pattern.replace('%', '\\%')
|
||||
pattern = pattern.replace('_', '\\_')
|
||||
pattern = pattern.replace('*', '%')
|
||||
pattern = pattern.replace('?', '_')
|
||||
return pattern
|
||||
|
||||
def search_files(
|
||||
self,
|
||||
user_id: Optional[str] = None,
|
||||
filename: str = "*",
|
||||
filename: str = '*',
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Optional[Session] = None,
|
||||
@@ -331,15 +301,12 @@ class FilesTable:
|
||||
query = query.filter_by(user_id=user_id)
|
||||
|
||||
pattern = self._glob_to_like_pattern(filename)
|
||||
if pattern != "%":
|
||||
query = query.filter(File.filename.ilike(pattern, escape="\\"))
|
||||
if pattern != '%':
|
||||
query = query.filter(File.filename.ilike(pattern, escape='\\'))
|
||||
|
||||
return [
|
||||
FileModel.model_validate(file)
|
||||
for file in query.order_by(File.created_at.desc(), File.id.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
for file in query.order_by(File.created_at.desc(), File.id.desc()).offset(skip).limit(limit).all()
|
||||
]
|
||||
|
||||
def update_file_by_id(
|
||||
@@ -362,12 +329,10 @@ class FilesTable:
|
||||
db.commit()
|
||||
return FileModel.model_validate(file)
|
||||
except Exception as e:
|
||||
log.exception(f"Error updating file completely by id: {e}")
|
||||
log.exception(f'Error updating file completely by id: {e}')
|
||||
return None
|
||||
|
||||
def update_file_hash_by_id(
|
||||
self, id: str, hash: Optional[str], db: Optional[Session] = None
|
||||
) -> Optional[FileModel]:
|
||||
def update_file_hash_by_id(self, id: str, hash: Optional[str], db: Optional[Session] = None) -> Optional[FileModel]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
file = db.query(File).filter_by(id=id).first()
|
||||
@@ -379,9 +344,7 @@ class FilesTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_file_data_by_id(
|
||||
self, id: str, data: dict, db: Optional[Session] = None
|
||||
) -> Optional[FileModel]:
|
||||
def update_file_data_by_id(self, id: str, data: dict, db: Optional[Session] = None) -> Optional[FileModel]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
file = db.query(File).filter_by(id=id).first()
|
||||
@@ -390,12 +353,9 @@ class FilesTable:
|
||||
db.commit()
|
||||
return FileModel.model_validate(file)
|
||||
except Exception as e:
|
||||
|
||||
return None
|
||||
|
||||
def update_file_metadata_by_id(
|
||||
self, id: str, meta: dict, db: Optional[Session] = None
|
||||
) -> Optional[FileModel]:
|
||||
def update_file_metadata_by_id(self, id: str, meta: dict, db: Optional[Session] = None) -> Optional[FileModel]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
file = db.query(File).filter_by(id=id).first()
|
||||
|
||||
@@ -20,7 +20,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Folder(Base):
|
||||
__tablename__ = "folder"
|
||||
__tablename__ = 'folder'
|
||||
id = Column(Text, primary_key=True, unique=True)
|
||||
parent_id = Column(Text, nullable=True)
|
||||
user_id = Column(Text)
|
||||
@@ -72,14 +72,14 @@ class FolderForm(BaseModel):
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
parent_id: Optional[str] = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
|
||||
class FolderUpdateForm(BaseModel):
|
||||
name: Optional[str] = None
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
model_config = ConfigDict(extra="allow")
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
|
||||
class FolderTable:
|
||||
@@ -94,12 +94,12 @@ class FolderTable:
|
||||
id = str(uuid.uuid4())
|
||||
folder = FolderModel(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
'id': id,
|
||||
'user_id': user_id,
|
||||
**(form_data.model_dump(exclude_unset=True) or {}),
|
||||
"parent_id": parent_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
'parent_id': parent_id,
|
||||
'created_at': int(time.time()),
|
||||
'updated_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
try:
|
||||
@@ -112,7 +112,7 @@ class FolderTable:
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.exception(f"Error inserting a new folder: {e}")
|
||||
log.exception(f'Error inserting a new folder: {e}')
|
||||
return None
|
||||
|
||||
def get_folder_by_id_and_user_id(
|
||||
@@ -137,9 +137,7 @@ class FolderTable:
|
||||
folders = []
|
||||
|
||||
def get_children(folder):
|
||||
children = self.get_folders_by_parent_id_and_user_id(
|
||||
folder.id, user_id, db=db
|
||||
)
|
||||
children = self.get_folders_by_parent_id_and_user_id(folder.id, user_id, db=db)
|
||||
for child in children:
|
||||
get_children(child)
|
||||
folders.append(child)
|
||||
@@ -153,14 +151,9 @@ class FolderTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_folders_by_user_id(
|
||||
self, user_id: str, db: Optional[Session] = None
|
||||
) -> list[FolderModel]:
|
||||
def get_folders_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FolderModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FolderModel.model_validate(folder)
|
||||
for folder in db.query(Folder).filter_by(user_id=user_id).all()
|
||||
]
|
||||
return [FolderModel.model_validate(folder) for folder in db.query(Folder).filter_by(user_id=user_id).all()]
|
||||
|
||||
def get_folder_by_parent_id_and_user_id_and_name(
|
||||
self,
|
||||
@@ -184,7 +177,7 @@ class FolderTable:
|
||||
|
||||
return FolderModel.model_validate(folder)
|
||||
except Exception as e:
|
||||
log.error(f"get_folder_by_parent_id_and_user_id_and_name: {e}")
|
||||
log.error(f'get_folder_by_parent_id_and_user_id_and_name: {e}')
|
||||
return None
|
||||
|
||||
def get_folders_by_parent_id_and_user_id(
|
||||
@@ -193,9 +186,7 @@ class FolderTable:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FolderModel.model_validate(folder)
|
||||
for folder in db.query(Folder)
|
||||
.filter_by(parent_id=parent_id, user_id=user_id)
|
||||
.all()
|
||||
for folder in db.query(Folder).filter_by(parent_id=parent_id, user_id=user_id).all()
|
||||
]
|
||||
|
||||
def update_folder_parent_id_by_id_and_user_id(
|
||||
@@ -219,7 +210,7 @@ class FolderTable:
|
||||
|
||||
return FolderModel.model_validate(folder)
|
||||
except Exception as e:
|
||||
log.error(f"update_folder: {e}")
|
||||
log.error(f'update_folder: {e}')
|
||||
return
|
||||
|
||||
def update_folder_by_id_and_user_id(
|
||||
@@ -241,7 +232,7 @@ class FolderTable:
|
||||
existing_folder = (
|
||||
db.query(Folder)
|
||||
.filter_by(
|
||||
name=form_data.get("name"),
|
||||
name=form_data.get('name'),
|
||||
parent_id=folder.parent_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
@@ -251,17 +242,17 @@ class FolderTable:
|
||||
if existing_folder and existing_folder.id != id:
|
||||
return None
|
||||
|
||||
folder.name = form_data.get("name", folder.name)
|
||||
if "data" in form_data:
|
||||
folder.name = form_data.get('name', folder.name)
|
||||
if 'data' in form_data:
|
||||
folder.data = {
|
||||
**(folder.data or {}),
|
||||
**form_data["data"],
|
||||
**form_data['data'],
|
||||
}
|
||||
|
||||
if "meta" in form_data:
|
||||
if 'meta' in form_data:
|
||||
folder.meta = {
|
||||
**(folder.meta or {}),
|
||||
**form_data["meta"],
|
||||
**form_data['meta'],
|
||||
}
|
||||
|
||||
folder.updated_at = int(time.time())
|
||||
@@ -269,7 +260,7 @@ class FolderTable:
|
||||
|
||||
return FolderModel.model_validate(folder)
|
||||
except Exception as e:
|
||||
log.error(f"update_folder: {e}")
|
||||
log.error(f'update_folder: {e}')
|
||||
return
|
||||
|
||||
def update_folder_is_expanded_by_id_and_user_id(
|
||||
@@ -289,12 +280,10 @@ class FolderTable:
|
||||
|
||||
return FolderModel.model_validate(folder)
|
||||
except Exception as e:
|
||||
log.error(f"update_folder: {e}")
|
||||
log.error(f'update_folder: {e}')
|
||||
return
|
||||
|
||||
def delete_folder_by_id_and_user_id(
|
||||
self, id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> list[str]:
|
||||
def delete_folder_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> list[str]:
|
||||
try:
|
||||
folder_ids = []
|
||||
with get_db_context(db) as db:
|
||||
@@ -306,11 +295,8 @@ class FolderTable:
|
||||
|
||||
# Delete all children folders
|
||||
def delete_children(folder):
|
||||
folder_children = self.get_folders_by_parent_id_and_user_id(
|
||||
folder.id, user_id, db=db
|
||||
)
|
||||
folder_children = self.get_folders_by_parent_id_and_user_id(folder.id, user_id, db=db)
|
||||
for folder_child in folder_children:
|
||||
|
||||
delete_children(folder_child)
|
||||
folder_ids.append(folder_child.id)
|
||||
|
||||
@@ -323,12 +309,12 @@ class FolderTable:
|
||||
db.commit()
|
||||
return folder_ids
|
||||
except Exception as e:
|
||||
log.error(f"delete_folder: {e}")
|
||||
log.error(f'delete_folder: {e}')
|
||||
return []
|
||||
|
||||
def normalize_folder_name(self, name: str) -> str:
|
||||
# Replace _ and space with a single space, lower case, collapse multiple spaces
|
||||
name = re.sub(r"[\s_]+", " ", name)
|
||||
name = re.sub(r'[\s_]+', ' ', name)
|
||||
return name.strip().lower()
|
||||
|
||||
def search_folders_by_names(
|
||||
@@ -349,9 +335,7 @@ class FolderTable:
|
||||
results[folder.id] = FolderModel.model_validate(folder)
|
||||
|
||||
# get children folders
|
||||
children = self.get_children_folders_by_id_and_user_id(
|
||||
folder.id, user_id, db=db
|
||||
)
|
||||
children = self.get_children_folders_by_id_and_user_id(folder.id, user_id, db=db)
|
||||
for child in children:
|
||||
results[child.id] = child
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Function(Base):
|
||||
__tablename__ = "function"
|
||||
__tablename__ = 'function'
|
||||
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
user_id = Column(String)
|
||||
@@ -30,13 +30,13 @@ class Function(Base):
|
||||
updated_at = Column(BigInteger)
|
||||
created_at = Column(BigInteger)
|
||||
|
||||
__table_args__ = (Index("is_global_idx", "is_global"),)
|
||||
__table_args__ = (Index('is_global_idx', 'is_global'),)
|
||||
|
||||
|
||||
class FunctionMeta(BaseModel):
|
||||
description: Optional[str] = None
|
||||
manifest: Optional[dict] = {}
|
||||
model_config = ConfigDict(extra="allow")
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
|
||||
class FunctionModel(BaseModel):
|
||||
@@ -113,10 +113,10 @@ class FunctionsTable:
|
||||
function = FunctionModel(
|
||||
**{
|
||||
**form_data.model_dump(),
|
||||
"user_id": user_id,
|
||||
"type": type,
|
||||
"updated_at": int(time.time()),
|
||||
"created_at": int(time.time()),
|
||||
'user_id': user_id,
|
||||
'type': type,
|
||||
'updated_at': int(time.time()),
|
||||
'created_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -131,7 +131,7 @@ class FunctionsTable:
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.exception(f"Error creating a new function: {e}")
|
||||
log.exception(f'Error creating a new function: {e}')
|
||||
return None
|
||||
|
||||
def sync_functions(
|
||||
@@ -156,16 +156,16 @@ class FunctionsTable:
|
||||
db.query(Function).filter_by(id=func.id).update(
|
||||
{
|
||||
**func.model_dump(),
|
||||
"user_id": user_id,
|
||||
"updated_at": int(time.time()),
|
||||
'user_id': user_id,
|
||||
'updated_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
else:
|
||||
new_func = Function(
|
||||
**{
|
||||
**func.model_dump(),
|
||||
"user_id": user_id,
|
||||
"updated_at": int(time.time()),
|
||||
'user_id': user_id,
|
||||
'updated_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
db.add(new_func)
|
||||
@@ -177,17 +177,12 @@ class FunctionsTable:
|
||||
|
||||
db.commit()
|
||||
|
||||
return [
|
||||
FunctionModel.model_validate(func)
|
||||
for func in db.query(Function).all()
|
||||
]
|
||||
return [FunctionModel.model_validate(func) for func in db.query(Function).all()]
|
||||
except Exception as e:
|
||||
log.exception(f"Error syncing functions for user {user_id}: {e}")
|
||||
log.exception(f'Error syncing functions for user {user_id}: {e}')
|
||||
return []
|
||||
|
||||
def get_function_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[FunctionModel]:
|
||||
def get_function_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FunctionModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
function = db.get(Function, id)
|
||||
@@ -195,9 +190,7 @@ class FunctionsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_functions_by_ids(
|
||||
self, ids: list[str], db: Optional[Session] = None
|
||||
) -> list[FunctionModel]:
|
||||
def get_functions_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[FunctionModel]:
|
||||
"""
|
||||
Batch fetch multiple functions by their IDs in a single query.
|
||||
Returns functions in the same order as the input IDs (None entries filtered out).
|
||||
@@ -225,18 +218,11 @@ class FunctionsTable:
|
||||
functions = db.query(Function).all()
|
||||
|
||||
if include_valves:
|
||||
return [
|
||||
FunctionWithValvesModel.model_validate(function)
|
||||
for function in functions
|
||||
]
|
||||
return [FunctionWithValvesModel.model_validate(function) for function in functions]
|
||||
else:
|
||||
return [
|
||||
FunctionModel.model_validate(function) for function in functions
|
||||
]
|
||||
return [FunctionModel.model_validate(function) for function in functions]
|
||||
|
||||
def get_function_list(
|
||||
self, db: Optional[Session] = None
|
||||
) -> list[FunctionUserResponse]:
|
||||
def get_function_list(self, db: Optional[Session] = None) -> list[FunctionUserResponse]:
|
||||
with get_db_context(db) as db:
|
||||
functions = db.query(Function).order_by(Function.updated_at.desc()).all()
|
||||
user_ids = list(set(func.user_id for func in functions))
|
||||
@@ -248,69 +234,48 @@ class FunctionsTable:
|
||||
FunctionUserResponse.model_validate(
|
||||
{
|
||||
**FunctionModel.model_validate(func).model_dump(),
|
||||
"user": (
|
||||
users_dict.get(func.user_id).model_dump()
|
||||
if func.user_id in users_dict
|
||||
else None
|
||||
),
|
||||
'user': (users_dict.get(func.user_id).model_dump() if func.user_id in users_dict else None),
|
||||
}
|
||||
)
|
||||
for func in functions
|
||||
]
|
||||
|
||||
def get_functions_by_type(
|
||||
self, type: str, active_only=False, db: Optional[Session] = None
|
||||
) -> list[FunctionModel]:
|
||||
def get_functions_by_type(self, type: str, active_only=False, db: Optional[Session] = None) -> list[FunctionModel]:
|
||||
with get_db_context(db) as db:
|
||||
if active_only:
|
||||
return [
|
||||
FunctionModel.model_validate(function)
|
||||
for function in db.query(Function)
|
||||
.filter_by(type=type, is_active=True)
|
||||
.all()
|
||||
for function in db.query(Function).filter_by(type=type, is_active=True).all()
|
||||
]
|
||||
else:
|
||||
return [
|
||||
FunctionModel.model_validate(function)
|
||||
for function in db.query(Function).filter_by(type=type).all()
|
||||
FunctionModel.model_validate(function) for function in db.query(Function).filter_by(type=type).all()
|
||||
]
|
||||
|
||||
def get_global_filter_functions(
|
||||
self, db: Optional[Session] = None
|
||||
) -> list[FunctionModel]:
|
||||
def get_global_filter_functions(self, db: Optional[Session] = None) -> list[FunctionModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FunctionModel.model_validate(function)
|
||||
for function in db.query(Function)
|
||||
.filter_by(type="filter", is_active=True, is_global=True)
|
||||
.all()
|
||||
for function in db.query(Function).filter_by(type='filter', is_active=True, is_global=True).all()
|
||||
]
|
||||
|
||||
def get_global_action_functions(
|
||||
self, db: Optional[Session] = None
|
||||
) -> list[FunctionModel]:
|
||||
def get_global_action_functions(self, db: Optional[Session] = None) -> list[FunctionModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
FunctionModel.model_validate(function)
|
||||
for function in db.query(Function)
|
||||
.filter_by(type="action", is_active=True, is_global=True)
|
||||
.all()
|
||||
for function in db.query(Function).filter_by(type='action', is_active=True, is_global=True).all()
|
||||
]
|
||||
|
||||
def get_function_valves_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[dict]:
|
||||
def get_function_valves_by_id(self, id: str, db: Optional[Session] = None) -> Optional[dict]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
function = db.get(Function, id)
|
||||
return function.valves if function.valves else {}
|
||||
except Exception as e:
|
||||
log.exception(f"Error getting function valves by id {id}: {e}")
|
||||
log.exception(f'Error getting function valves by id {id}: {e}')
|
||||
return None
|
||||
|
||||
def get_function_valves_by_ids(
|
||||
self, ids: list[str], db: Optional[Session] = None
|
||||
) -> dict[str, dict]:
|
||||
def get_function_valves_by_ids(self, ids: list[str], db: Optional[Session] = None) -> dict[str, dict]:
|
||||
"""
|
||||
Batch fetch valves for multiple functions in a single query.
|
||||
Returns a dict mapping function_id -> valves dict.
|
||||
@@ -320,14 +285,10 @@ class FunctionsTable:
|
||||
return {}
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
functions = (
|
||||
db.query(Function.id, Function.valves)
|
||||
.filter(Function.id.in_(ids))
|
||||
.all()
|
||||
)
|
||||
functions = db.query(Function.id, Function.valves).filter(Function.id.in_(ids)).all()
|
||||
return {f.id: (f.valves if f.valves else {}) for f in functions}
|
||||
except Exception as e:
|
||||
log.exception(f"Error batch-fetching function valves: {e}")
|
||||
log.exception(f'Error batch-fetching function valves: {e}')
|
||||
return {}
|
||||
|
||||
def update_function_valves_by_id(
|
||||
@@ -364,25 +325,23 @@ class FunctionsTable:
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.exception(f"Error updating function metadata by id {id}: {e}")
|
||||
log.exception(f'Error updating function metadata by id {id}: {e}')
|
||||
return None
|
||||
|
||||
def get_user_valves_by_id_and_user_id(
|
||||
self, id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[dict]:
|
||||
def get_user_valves_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[dict]:
|
||||
try:
|
||||
user = Users.get_user_by_id(user_id, db=db)
|
||||
user_settings = user.settings.model_dump() if user.settings else {}
|
||||
|
||||
# Check if user has "functions" and "valves" settings
|
||||
if "functions" not in user_settings:
|
||||
user_settings["functions"] = {}
|
||||
if "valves" not in user_settings["functions"]:
|
||||
user_settings["functions"]["valves"] = {}
|
||||
if 'functions' not in user_settings:
|
||||
user_settings['functions'] = {}
|
||||
if 'valves' not in user_settings['functions']:
|
||||
user_settings['functions']['valves'] = {}
|
||||
|
||||
return user_settings["functions"]["valves"].get(id, {})
|
||||
return user_settings['functions']['valves'].get(id, {})
|
||||
except Exception as e:
|
||||
log.exception(f"Error getting user values by id {id} and user id {user_id}")
|
||||
log.exception(f'Error getting user values by id {id} and user id {user_id}')
|
||||
return None
|
||||
|
||||
def update_user_valves_by_id_and_user_id(
|
||||
@@ -393,32 +352,28 @@ class FunctionsTable:
|
||||
user_settings = user.settings.model_dump() if user.settings else {}
|
||||
|
||||
# Check if user has "functions" and "valves" settings
|
||||
if "functions" not in user_settings:
|
||||
user_settings["functions"] = {}
|
||||
if "valves" not in user_settings["functions"]:
|
||||
user_settings["functions"]["valves"] = {}
|
||||
if 'functions' not in user_settings:
|
||||
user_settings['functions'] = {}
|
||||
if 'valves' not in user_settings['functions']:
|
||||
user_settings['functions']['valves'] = {}
|
||||
|
||||
user_settings["functions"]["valves"][id] = valves
|
||||
user_settings['functions']['valves'][id] = valves
|
||||
|
||||
# Update the user settings in the database
|
||||
Users.update_user_by_id(user_id, {"settings": user_settings}, db=db)
|
||||
Users.update_user_by_id(user_id, {'settings': user_settings}, db=db)
|
||||
|
||||
return user_settings["functions"]["valves"][id]
|
||||
return user_settings['functions']['valves'][id]
|
||||
except Exception as e:
|
||||
log.exception(
|
||||
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
|
||||
)
|
||||
log.exception(f'Error updating user valves by id {id} and user_id {user_id}: {e}')
|
||||
return None
|
||||
|
||||
def update_function_by_id(
|
||||
self, id: str, updated: dict, db: Optional[Session] = None
|
||||
) -> Optional[FunctionModel]:
|
||||
def update_function_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[FunctionModel]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
db.query(Function).filter_by(id=id).update(
|
||||
{
|
||||
**updated,
|
||||
"updated_at": int(time.time()),
|
||||
'updated_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
@@ -432,8 +387,8 @@ class FunctionsTable:
|
||||
try:
|
||||
db.query(Function).update(
|
||||
{
|
||||
"is_active": False,
|
||||
"updated_at": int(time.time()),
|
||||
'is_active': False,
|
||||
'updated_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
|
||||
@@ -34,7 +34,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Group(Base):
|
||||
__tablename__ = "group"
|
||||
__tablename__ = 'group'
|
||||
|
||||
id = Column(Text, unique=True, primary_key=True)
|
||||
user_id = Column(Text)
|
||||
@@ -70,12 +70,12 @@ class GroupModel(BaseModel):
|
||||
|
||||
|
||||
class GroupMember(Base):
|
||||
__tablename__ = "group_member"
|
||||
__tablename__ = 'group_member'
|
||||
|
||||
id = Column(Text, unique=True, primary_key=True)
|
||||
group_id = Column(
|
||||
Text,
|
||||
ForeignKey("group.id", ondelete="CASCADE"),
|
||||
ForeignKey('group.id', ondelete='CASCADE'),
|
||||
nullable=False,
|
||||
)
|
||||
user_id = Column(Text, nullable=False)
|
||||
@@ -133,28 +133,26 @@ class GroupListResponse(BaseModel):
|
||||
class GroupTable:
|
||||
def _ensure_default_share_config(self, group_data: dict) -> dict:
|
||||
"""Ensure the group data dict has a default share config if not already set."""
|
||||
if "data" not in group_data or group_data["data"] is None:
|
||||
group_data["data"] = {}
|
||||
if "config" not in group_data["data"]:
|
||||
group_data["data"]["config"] = {}
|
||||
if "share" not in group_data["data"]["config"]:
|
||||
group_data["data"]["config"]["share"] = DEFAULT_GROUP_SHARE_PERMISSION
|
||||
if 'data' not in group_data or group_data['data'] is None:
|
||||
group_data['data'] = {}
|
||||
if 'config' not in group_data['data']:
|
||||
group_data['data']['config'] = {}
|
||||
if 'share' not in group_data['data']['config']:
|
||||
group_data['data']['config']['share'] = DEFAULT_GROUP_SHARE_PERMISSION
|
||||
return group_data
|
||||
|
||||
def insert_new_group(
|
||||
self, user_id: str, form_data: GroupForm, db: Optional[Session] = None
|
||||
) -> Optional[GroupModel]:
|
||||
with get_db_context(db) as db:
|
||||
group_data = self._ensure_default_share_config(
|
||||
form_data.model_dump(exclude_none=True)
|
||||
)
|
||||
group_data = self._ensure_default_share_config(form_data.model_dump(exclude_none=True))
|
||||
group = GroupModel(
|
||||
**{
|
||||
**group_data,
|
||||
"id": str(uuid.uuid4()),
|
||||
"user_id": user_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
'id': str(uuid.uuid4()),
|
||||
'user_id': user_id,
|
||||
'created_at': int(time.time()),
|
||||
'updated_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -183,19 +181,19 @@ class GroupTable:
|
||||
.where(GroupMember.group_id == Group.id)
|
||||
.correlate(Group)
|
||||
.scalar_subquery()
|
||||
.label("member_count")
|
||||
.label('member_count')
|
||||
)
|
||||
query = db.query(Group, member_count)
|
||||
|
||||
if filter:
|
||||
if "query" in filter:
|
||||
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
|
||||
if 'query' in filter:
|
||||
query = query.filter(Group.name.ilike(f'%{filter["query"]}%'))
|
||||
|
||||
# When share filter is present, member check is handled in the share logic
|
||||
if "share" in filter:
|
||||
share_value = filter["share"]
|
||||
member_id = filter.get("member_id")
|
||||
json_share = Group.data["config"]["share"]
|
||||
if 'share' in filter:
|
||||
share_value = filter['share']
|
||||
member_id = filter.get('member_id')
|
||||
json_share = Group.data['config']['share']
|
||||
json_share_str = json_share.as_string()
|
||||
json_share_lower = func.lower(json_share_str)
|
||||
|
||||
@@ -203,37 +201,27 @@ class GroupTable:
|
||||
anyone_can_share = or_(
|
||||
Group.data.is_(None),
|
||||
json_share_str.is_(None),
|
||||
json_share_lower == "true",
|
||||
json_share_lower == "1", # Handle SQLite boolean true
|
||||
json_share_lower == 'true',
|
||||
json_share_lower == '1', # Handle SQLite boolean true
|
||||
)
|
||||
|
||||
if member_id:
|
||||
member_groups_select = select(GroupMember.group_id).where(
|
||||
GroupMember.user_id == member_id
|
||||
)
|
||||
member_groups_select = select(GroupMember.group_id).where(GroupMember.user_id == member_id)
|
||||
members_only_and_is_member = and_(
|
||||
json_share_lower == "members",
|
||||
json_share_lower == 'members',
|
||||
Group.id.in_(member_groups_select),
|
||||
)
|
||||
query = query.filter(
|
||||
or_(anyone_can_share, members_only_and_is_member)
|
||||
)
|
||||
query = query.filter(or_(anyone_can_share, members_only_and_is_member))
|
||||
else:
|
||||
query = query.filter(anyone_can_share)
|
||||
else:
|
||||
query = query.filter(
|
||||
and_(Group.data.isnot(None), json_share_lower == "false")
|
||||
)
|
||||
query = query.filter(and_(Group.data.isnot(None), json_share_lower == 'false'))
|
||||
|
||||
else:
|
||||
# Only apply member_id filter when share filter is NOT present
|
||||
if "member_id" in filter:
|
||||
if 'member_id' in filter:
|
||||
query = query.filter(
|
||||
Group.id.in_(
|
||||
select(GroupMember.group_id).where(
|
||||
GroupMember.user_id == filter["member_id"]
|
||||
)
|
||||
)
|
||||
Group.id.in_(select(GroupMember.group_id).where(GroupMember.user_id == filter['member_id']))
|
||||
)
|
||||
|
||||
results = query.order_by(Group.updated_at.desc()).all()
|
||||
@@ -242,7 +230,7 @@ class GroupTable:
|
||||
GroupResponse.model_validate(
|
||||
{
|
||||
**GroupModel.model_validate(group).model_dump(),
|
||||
"member_count": count or 0,
|
||||
'member_count': count or 0,
|
||||
}
|
||||
)
|
||||
for group, count in results
|
||||
@@ -259,22 +247,16 @@ class GroupTable:
|
||||
query = db.query(Group)
|
||||
|
||||
if filter:
|
||||
if "query" in filter:
|
||||
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
|
||||
if "member_id" in filter:
|
||||
if 'query' in filter:
|
||||
query = query.filter(Group.name.ilike(f'%{filter["query"]}%'))
|
||||
if 'member_id' in filter:
|
||||
query = query.filter(
|
||||
Group.id.in_(
|
||||
select(GroupMember.group_id).where(
|
||||
GroupMember.user_id == filter["member_id"]
|
||||
)
|
||||
)
|
||||
Group.id.in_(select(GroupMember.group_id).where(GroupMember.user_id == filter['member_id']))
|
||||
)
|
||||
|
||||
if "share" in filter:
|
||||
share_value = filter["share"]
|
||||
query = query.filter(
|
||||
Group.data.op("->>")("share") == str(share_value)
|
||||
)
|
||||
if 'share' in filter:
|
||||
share_value = filter['share']
|
||||
query = query.filter(Group.data.op('->>')('share') == str(share_value))
|
||||
|
||||
total = query.count()
|
||||
|
||||
@@ -283,32 +265,24 @@ class GroupTable:
|
||||
.where(GroupMember.group_id == Group.id)
|
||||
.correlate(Group)
|
||||
.scalar_subquery()
|
||||
.label("member_count")
|
||||
)
|
||||
results = (
|
||||
query.add_columns(member_count)
|
||||
.order_by(Group.updated_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
.label('member_count')
|
||||
)
|
||||
results = query.add_columns(member_count).order_by(Group.updated_at.desc()).offset(skip).limit(limit).all()
|
||||
|
||||
return {
|
||||
"items": [
|
||||
'items': [
|
||||
GroupResponse.model_validate(
|
||||
{
|
||||
**GroupModel.model_validate(group).model_dump(),
|
||||
"member_count": count or 0,
|
||||
'member_count': count or 0,
|
||||
}
|
||||
)
|
||||
for group, count in results
|
||||
],
|
||||
"total": total,
|
||||
'total': total,
|
||||
}
|
||||
|
||||
def get_groups_by_member_id(
|
||||
self, user_id: str, db: Optional[Session] = None
|
||||
) -> list[GroupModel]:
|
||||
def get_groups_by_member_id(self, user_id: str, db: Optional[Session] = None) -> list[GroupModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
GroupModel.model_validate(group)
|
||||
@@ -340,9 +314,7 @@ class GroupTable:
|
||||
|
||||
return user_groups
|
||||
|
||||
def get_group_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[GroupModel]:
|
||||
def get_group_by_id(self, id: str, db: Optional[Session] = None) -> Optional[GroupModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
group = db.query(Group).filter_by(id=id).first()
|
||||
@@ -350,41 +322,29 @@ class GroupTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_group_user_ids_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> list[str]:
|
||||
def get_group_user_ids_by_id(self, id: str, db: Optional[Session] = None) -> list[str]:
|
||||
with get_db_context(db) as db:
|
||||
members = (
|
||||
db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all()
|
||||
)
|
||||
members = db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all()
|
||||
|
||||
if not members:
|
||||
return []
|
||||
|
||||
return [m[0] for m in members]
|
||||
|
||||
def get_group_user_ids_by_ids(
|
||||
self, group_ids: list[str], db: Optional[Session] = None
|
||||
) -> dict[str, list[str]]:
|
||||
def get_group_user_ids_by_ids(self, group_ids: list[str], db: Optional[Session] = None) -> dict[str, list[str]]:
|
||||
with get_db_context(db) as db:
|
||||
members = (
|
||||
db.query(GroupMember.group_id, GroupMember.user_id)
|
||||
.filter(GroupMember.group_id.in_(group_ids))
|
||||
.all()
|
||||
db.query(GroupMember.group_id, GroupMember.user_id).filter(GroupMember.group_id.in_(group_ids)).all()
|
||||
)
|
||||
|
||||
group_user_ids: dict[str, list[str]] = {
|
||||
group_id: [] for group_id in group_ids
|
||||
}
|
||||
group_user_ids: dict[str, list[str]] = {group_id: [] for group_id in group_ids}
|
||||
|
||||
for group_id, user_id in members:
|
||||
group_user_ids[group_id].append(user_id)
|
||||
|
||||
return group_user_ids
|
||||
|
||||
def set_group_user_ids_by_id(
|
||||
self, group_id: str, user_ids: list[str], db: Optional[Session] = None
|
||||
) -> None:
|
||||
def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str], db: Optional[Session] = None) -> None:
|
||||
with get_db_context(db) as db:
|
||||
# Delete existing members
|
||||
db.query(GroupMember).filter(GroupMember.group_id == group_id).delete()
|
||||
@@ -405,20 +365,12 @@ class GroupTable:
|
||||
db.add_all(new_members)
|
||||
db.commit()
|
||||
|
||||
def get_group_member_count_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> int:
|
||||
def get_group_member_count_by_id(self, id: str, db: Optional[Session] = None) -> int:
|
||||
with get_db_context(db) as db:
|
||||
count = (
|
||||
db.query(func.count(GroupMember.user_id))
|
||||
.filter(GroupMember.group_id == id)
|
||||
.scalar()
|
||||
)
|
||||
count = db.query(func.count(GroupMember.user_id)).filter(GroupMember.group_id == id).scalar()
|
||||
return count if count else 0
|
||||
|
||||
def get_group_member_counts_by_ids(
|
||||
self, ids: list[str], db: Optional[Session] = None
|
||||
) -> dict[str, int]:
|
||||
def get_group_member_counts_by_ids(self, ids: list[str], db: Optional[Session] = None) -> dict[str, int]:
|
||||
if not ids:
|
||||
return {}
|
||||
with get_db_context(db) as db:
|
||||
@@ -442,7 +394,7 @@ class GroupTable:
|
||||
db.query(Group).filter_by(id=id).update(
|
||||
{
|
||||
**form_data.model_dump(exclude_none=True),
|
||||
"updated_at": int(time.time()),
|
||||
'updated_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
@@ -470,9 +422,7 @@ class GroupTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def remove_user_from_all_groups(
|
||||
self, user_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def remove_user_from_all_groups(self, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
# Find all groups the user belongs to
|
||||
@@ -489,9 +439,7 @@ class GroupTable:
|
||||
GroupMember.group_id == group.id, GroupMember.user_id == user_id
|
||||
).delete()
|
||||
|
||||
db.query(Group).filter_by(id=group.id).update(
|
||||
{"updated_at": int(time.time())}
|
||||
)
|
||||
db.query(Group).filter_by(id=group.id).update({'updated_at': int(time.time())})
|
||||
|
||||
db.commit()
|
||||
return True
|
||||
@@ -503,7 +451,6 @@ class GroupTable:
|
||||
def create_groups_by_group_names(
|
||||
self, user_id: str, group_names: list[str], db: Optional[Session] = None
|
||||
) -> list[GroupModel]:
|
||||
|
||||
# check for existing groups
|
||||
existing_groups = self.get_all_groups(db=db)
|
||||
existing_group_names = {group.name for group in existing_groups}
|
||||
@@ -517,10 +464,10 @@ class GroupTable:
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
name=group_name,
|
||||
description="",
|
||||
description='',
|
||||
data={
|
||||
"config": {
|
||||
"share": DEFAULT_GROUP_SHARE_PERMISSION,
|
||||
'config': {
|
||||
'share': DEFAULT_GROUP_SHARE_PERMISSION,
|
||||
}
|
||||
},
|
||||
created_at=int(time.time()),
|
||||
@@ -537,17 +484,13 @@ class GroupTable:
|
||||
continue
|
||||
return new_groups
|
||||
|
||||
def sync_groups_by_group_names(
|
||||
self, user_id: str, group_names: list[str], db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def sync_groups_by_group_names(self, user_id: str, group_names: list[str], db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
now = int(time.time())
|
||||
|
||||
# 1. Groups that SHOULD contain the user
|
||||
target_groups = (
|
||||
db.query(Group).filter(Group.name.in_(group_names)).all()
|
||||
)
|
||||
target_groups = db.query(Group).filter(Group.name.in_(group_names)).all()
|
||||
target_group_ids = {g.id for g in target_groups}
|
||||
|
||||
# 2. Groups the user is CURRENTLY in
|
||||
@@ -571,7 +514,7 @@ class GroupTable:
|
||||
).delete(synchronize_session=False)
|
||||
|
||||
db.query(Group).filter(Group.id.in_(groups_to_remove)).update(
|
||||
{"updated_at": now}, synchronize_session=False
|
||||
{'updated_at': now}, synchronize_session=False
|
||||
)
|
||||
|
||||
# 5. Bulk insert missing memberships
|
||||
@@ -588,7 +531,7 @@ class GroupTable:
|
||||
|
||||
if groups_to_add:
|
||||
db.query(Group).filter(Group.id.in_(groups_to_add)).update(
|
||||
{"updated_at": now}, synchronize_session=False
|
||||
{'updated_at': now}, synchronize_session=False
|
||||
)
|
||||
|
||||
db.commit()
|
||||
@@ -656,9 +599,9 @@ class GroupTable:
|
||||
return GroupModel.model_validate(group)
|
||||
|
||||
# Remove users from group_member in batch
|
||||
db.query(GroupMember).filter(
|
||||
GroupMember.group_id == id, GroupMember.user_id.in_(user_ids)
|
||||
).delete(synchronize_session=False)
|
||||
db.query(GroupMember).filter(GroupMember.group_id == id, GroupMember.user_id.in_(user_ids)).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
# Update group timestamp
|
||||
group.updated_at = int(time.time())
|
||||
|
||||
@@ -38,7 +38,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Knowledge(Base):
|
||||
__tablename__ = "knowledge"
|
||||
__tablename__ = 'knowledge'
|
||||
|
||||
id = Column(Text, unique=True, primary_key=True)
|
||||
user_id = Column(Text)
|
||||
@@ -70,24 +70,18 @@ class KnowledgeModel(BaseModel):
|
||||
|
||||
|
||||
class KnowledgeFile(Base):
|
||||
__tablename__ = "knowledge_file"
|
||||
__tablename__ = 'knowledge_file'
|
||||
|
||||
id = Column(Text, unique=True, primary_key=True)
|
||||
|
||||
knowledge_id = Column(
|
||||
Text, ForeignKey("knowledge.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False)
|
||||
knowledge_id = Column(Text, ForeignKey('knowledge.id', ondelete='CASCADE'), nullable=False)
|
||||
file_id = Column(Text, ForeignKey('file.id', ondelete='CASCADE'), nullable=False)
|
||||
user_id = Column(Text, nullable=False)
|
||||
|
||||
created_at = Column(BigInteger, nullable=False)
|
||||
updated_at = Column(BigInteger, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"knowledge_id", "file_id", name="uq_knowledge_file_knowledge_file"
|
||||
),
|
||||
)
|
||||
__table_args__ = (UniqueConstraint('knowledge_id', 'file_id', name='uq_knowledge_file_knowledge_file'),)
|
||||
|
||||
|
||||
class KnowledgeFileModel(BaseModel):
|
||||
@@ -138,10 +132,8 @@ class KnowledgeFileListResponse(BaseModel):
|
||||
|
||||
|
||||
class KnowledgeTable:
|
||||
def _get_access_grants(
|
||||
self, knowledge_id: str, db: Optional[Session] = None
|
||||
) -> list[AccessGrantModel]:
|
||||
return AccessGrants.get_grants_by_resource("knowledge", knowledge_id, db=db)
|
||||
def _get_access_grants(self, knowledge_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
|
||||
return AccessGrants.get_grants_by_resource('knowledge', knowledge_id, db=db)
|
||||
|
||||
def _to_knowledge_model(
|
||||
self,
|
||||
@@ -149,13 +141,9 @@ class KnowledgeTable:
|
||||
access_grants: Optional[list[AccessGrantModel]] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> KnowledgeModel:
|
||||
knowledge_data = KnowledgeModel.model_validate(knowledge).model_dump(
|
||||
exclude={"access_grants"}
|
||||
)
|
||||
knowledge_data["access_grants"] = (
|
||||
access_grants
|
||||
if access_grants is not None
|
||||
else self._get_access_grants(knowledge_data["id"], db=db)
|
||||
knowledge_data = KnowledgeModel.model_validate(knowledge).model_dump(exclude={'access_grants'})
|
||||
knowledge_data['access_grants'] = (
|
||||
access_grants if access_grants is not None else self._get_access_grants(knowledge_data['id'], db=db)
|
||||
)
|
||||
return KnowledgeModel.model_validate(knowledge_data)
|
||||
|
||||
@@ -165,23 +153,21 @@ class KnowledgeTable:
|
||||
with get_db_context(db) as db:
|
||||
knowledge = KnowledgeModel(
|
||||
**{
|
||||
**form_data.model_dump(exclude={"access_grants"}),
|
||||
"id": str(uuid.uuid4()),
|
||||
"user_id": user_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
"access_grants": [],
|
||||
**form_data.model_dump(exclude={'access_grants'}),
|
||||
'id': str(uuid.uuid4()),
|
||||
'user_id': user_id,
|
||||
'created_at': int(time.time()),
|
||||
'updated_at': int(time.time()),
|
||||
'access_grants': [],
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
result = Knowledge(**knowledge.model_dump(exclude={"access_grants"}))
|
||||
result = Knowledge(**knowledge.model_dump(exclude={'access_grants'}))
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
AccessGrants.set_access_grants(
|
||||
"knowledge", result.id, form_data.access_grants, db=db
|
||||
)
|
||||
AccessGrants.set_access_grants('knowledge', result.id, form_data.access_grants, db=db)
|
||||
if result:
|
||||
return self._to_knowledge_model(result, db=db)
|
||||
else:
|
||||
@@ -193,17 +179,13 @@ class KnowledgeTable:
|
||||
self, skip: int = 0, limit: int = 30, db: Optional[Session] = None
|
||||
) -> list[KnowledgeUserModel]:
|
||||
with get_db_context(db) as db:
|
||||
all_knowledge = (
|
||||
db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
|
||||
)
|
||||
all_knowledge = db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
|
||||
user_ids = list(set(knowledge.user_id for knowledge in all_knowledge))
|
||||
knowledge_ids = [knowledge.id for knowledge in all_knowledge]
|
||||
|
||||
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
||||
users_dict = {user.id: user for user in users}
|
||||
grants_map = AccessGrants.get_grants_by_resources(
|
||||
"knowledge", knowledge_ids, db=db
|
||||
)
|
||||
grants_map = AccessGrants.get_grants_by_resources('knowledge', knowledge_ids, db=db)
|
||||
|
||||
knowledge_bases = []
|
||||
for knowledge in all_knowledge:
|
||||
@@ -216,7 +198,7 @@ class KnowledgeTable:
|
||||
access_grants=grants_map.get(knowledge.id, []),
|
||||
db=db,
|
||||
).model_dump(),
|
||||
"user": user.model_dump() if user else None,
|
||||
'user': user.model_dump() if user else None,
|
||||
}
|
||||
)
|
||||
)
|
||||
@@ -232,27 +214,25 @@ class KnowledgeTable:
|
||||
) -> KnowledgeListResponse:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Knowledge, User).outerjoin(
|
||||
User, User.id == Knowledge.user_id
|
||||
)
|
||||
query = db.query(Knowledge, User).outerjoin(User, User.id == Knowledge.user_id)
|
||||
|
||||
if filter:
|
||||
query_key = filter.get("query")
|
||||
query_key = filter.get('query')
|
||||
if query_key:
|
||||
query = query.filter(
|
||||
or_(
|
||||
Knowledge.name.ilike(f"%{query_key}%"),
|
||||
Knowledge.description.ilike(f"%{query_key}%"),
|
||||
User.name.ilike(f"%{query_key}%"),
|
||||
User.email.ilike(f"%{query_key}%"),
|
||||
User.username.ilike(f"%{query_key}%"),
|
||||
Knowledge.name.ilike(f'%{query_key}%'),
|
||||
Knowledge.description.ilike(f'%{query_key}%'),
|
||||
User.name.ilike(f'%{query_key}%'),
|
||||
User.email.ilike(f'%{query_key}%'),
|
||||
User.username.ilike(f'%{query_key}%'),
|
||||
)
|
||||
)
|
||||
|
||||
view_option = filter.get("view_option")
|
||||
if view_option == "created":
|
||||
view_option = filter.get('view_option')
|
||||
if view_option == 'created':
|
||||
query = query.filter(Knowledge.user_id == user_id)
|
||||
elif view_option == "shared":
|
||||
elif view_option == 'shared':
|
||||
query = query.filter(Knowledge.user_id != user_id)
|
||||
|
||||
query = AccessGrants.has_permission_filter(
|
||||
@@ -260,8 +240,8 @@ class KnowledgeTable:
|
||||
query=query,
|
||||
DocumentModel=Knowledge,
|
||||
filter=filter,
|
||||
resource_type="knowledge",
|
||||
permission="read",
|
||||
resource_type='knowledge',
|
||||
permission='read',
|
||||
)
|
||||
|
||||
query = query.order_by(Knowledge.updated_at.desc(), Knowledge.id.asc())
|
||||
@@ -275,9 +255,7 @@ class KnowledgeTable:
|
||||
items = query.all()
|
||||
|
||||
knowledge_ids = [kb.id for kb, _ in items]
|
||||
grants_map = AccessGrants.get_grants_by_resources(
|
||||
"knowledge", knowledge_ids, db=db
|
||||
)
|
||||
grants_map = AccessGrants.get_grants_by_resources('knowledge', knowledge_ids, db=db)
|
||||
|
||||
knowledge_bases = []
|
||||
for knowledge_base, user in items:
|
||||
@@ -289,11 +267,7 @@ class KnowledgeTable:
|
||||
access_grants=grants_map.get(knowledge_base.id, []),
|
||||
db=db,
|
||||
).model_dump(),
|
||||
"user": (
|
||||
UserModel.model_validate(user).model_dump()
|
||||
if user
|
||||
else None
|
||||
),
|
||||
'user': (UserModel.model_validate(user).model_dump() if user else None),
|
||||
}
|
||||
)
|
||||
)
|
||||
@@ -327,15 +301,15 @@ class KnowledgeTable:
|
||||
query=query,
|
||||
DocumentModel=Knowledge,
|
||||
filter=filter,
|
||||
resource_type="knowledge",
|
||||
permission="read",
|
||||
resource_type='knowledge',
|
||||
permission='read',
|
||||
)
|
||||
|
||||
# Apply filename search
|
||||
if filter:
|
||||
q = filter.get("query")
|
||||
q = filter.get('query')
|
||||
if q:
|
||||
query = query.filter(File.filename.ilike(f"%{q}%"))
|
||||
query = query.filter(File.filename.ilike(f'%{q}%'))
|
||||
|
||||
# Order by file changes
|
||||
query = query.order_by(File.updated_at.desc(), File.id.asc())
|
||||
@@ -355,39 +329,27 @@ class KnowledgeTable:
|
||||
items.append(
|
||||
FileUserResponse(
|
||||
**FileModel.model_validate(file).model_dump(),
|
||||
user=(
|
||||
UserResponse(
|
||||
**UserModel.model_validate(user).model_dump()
|
||||
)
|
||||
if user
|
||||
else None
|
||||
),
|
||||
collection=self._to_knowledge_model(
|
||||
knowledge, db=db
|
||||
).model_dump(),
|
||||
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
|
||||
collection=self._to_knowledge_model(knowledge, db=db).model_dump(),
|
||||
)
|
||||
)
|
||||
|
||||
return KnowledgeFileListResponse(items=items, total=total)
|
||||
|
||||
except Exception as e:
|
||||
print("search_knowledge_files error:", e)
|
||||
print('search_knowledge_files error:', e)
|
||||
return KnowledgeFileListResponse(items=[], total=0)
|
||||
|
||||
def check_access_by_user_id(
|
||||
self, id, user_id, permission="write", db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def check_access_by_user_id(self, id, user_id, permission='write', db: Optional[Session] = None) -> bool:
|
||||
knowledge = self.get_knowledge_by_id(id, db=db)
|
||||
if not knowledge:
|
||||
return False
|
||||
if knowledge.user_id == user_id:
|
||||
return True
|
||||
user_group_ids = {
|
||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
||||
}
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
|
||||
return AccessGrants.has_access(
|
||||
user_id=user_id,
|
||||
resource_type="knowledge",
|
||||
resource_type='knowledge',
|
||||
resource_id=knowledge.id,
|
||||
permission=permission,
|
||||
user_group_ids=user_group_ids,
|
||||
@@ -395,19 +357,17 @@ class KnowledgeTable:
|
||||
)
|
||||
|
||||
def get_knowledge_bases_by_user_id(
|
||||
self, user_id: str, permission: str = "write", db: Optional[Session] = None
|
||||
self, user_id: str, permission: str = 'write', db: Optional[Session] = None
|
||||
) -> list[KnowledgeUserModel]:
|
||||
knowledge_bases = self.get_knowledge_bases(db=db)
|
||||
user_group_ids = {
|
||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
||||
}
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
|
||||
return [
|
||||
knowledge_base
|
||||
for knowledge_base in knowledge_bases
|
||||
if knowledge_base.user_id == user_id
|
||||
or AccessGrants.has_access(
|
||||
user_id=user_id,
|
||||
resource_type="knowledge",
|
||||
resource_type='knowledge',
|
||||
resource_id=knowledge_base.id,
|
||||
permission=permission,
|
||||
user_group_ids=user_group_ids,
|
||||
@@ -415,9 +375,7 @@ class KnowledgeTable:
|
||||
)
|
||||
]
|
||||
|
||||
def get_knowledge_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[KnowledgeModel]:
|
||||
def get_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> Optional[KnowledgeModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
knowledge = db.query(Knowledge).filter_by(id=id).first()
|
||||
@@ -435,23 +393,19 @@ class KnowledgeTable:
|
||||
if knowledge.user_id == user_id:
|
||||
return knowledge
|
||||
|
||||
user_group_ids = {
|
||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
||||
}
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
|
||||
if AccessGrants.has_access(
|
||||
user_id=user_id,
|
||||
resource_type="knowledge",
|
||||
resource_type='knowledge',
|
||||
resource_id=knowledge.id,
|
||||
permission="write",
|
||||
permission='write',
|
||||
user_group_ids=user_group_ids,
|
||||
db=db,
|
||||
):
|
||||
return knowledge
|
||||
return None
|
||||
|
||||
def get_knowledges_by_file_id(
|
||||
self, file_id: str, db: Optional[Session] = None
|
||||
) -> list[KnowledgeModel]:
|
||||
def get_knowledges_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[KnowledgeModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
knowledges = (
|
||||
@@ -461,9 +415,7 @@ class KnowledgeTable:
|
||||
.all()
|
||||
)
|
||||
knowledge_ids = [k.id for k in knowledges]
|
||||
grants_map = AccessGrants.get_grants_by_resources(
|
||||
"knowledge", knowledge_ids, db=db
|
||||
)
|
||||
grants_map = AccessGrants.get_grants_by_resources('knowledge', knowledge_ids, db=db)
|
||||
return [
|
||||
self._to_knowledge_model(
|
||||
knowledge,
|
||||
@@ -497,32 +449,26 @@ class KnowledgeTable:
|
||||
primary_sort = File.updated_at.desc()
|
||||
|
||||
if filter:
|
||||
query_key = filter.get("query")
|
||||
query_key = filter.get('query')
|
||||
if query_key:
|
||||
query = query.filter(or_(File.filename.ilike(f"%{query_key}%")))
|
||||
query = query.filter(or_(File.filename.ilike(f'%{query_key}%')))
|
||||
|
||||
view_option = filter.get("view_option")
|
||||
if view_option == "created":
|
||||
view_option = filter.get('view_option')
|
||||
if view_option == 'created':
|
||||
query = query.filter(KnowledgeFile.user_id == user_id)
|
||||
elif view_option == "shared":
|
||||
elif view_option == 'shared':
|
||||
query = query.filter(KnowledgeFile.user_id != user_id)
|
||||
|
||||
order_by = filter.get("order_by")
|
||||
direction = filter.get("direction")
|
||||
is_asc = direction == "asc"
|
||||
order_by = filter.get('order_by')
|
||||
direction = filter.get('direction')
|
||||
is_asc = direction == 'asc'
|
||||
|
||||
if order_by == "name":
|
||||
primary_sort = (
|
||||
File.filename.asc() if is_asc else File.filename.desc()
|
||||
)
|
||||
elif order_by == "created_at":
|
||||
primary_sort = (
|
||||
File.created_at.asc() if is_asc else File.created_at.desc()
|
||||
)
|
||||
elif order_by == "updated_at":
|
||||
primary_sort = (
|
||||
File.updated_at.asc() if is_asc else File.updated_at.desc()
|
||||
)
|
||||
if order_by == 'name':
|
||||
primary_sort = File.filename.asc() if is_asc else File.filename.desc()
|
||||
elif order_by == 'created_at':
|
||||
primary_sort = File.created_at.asc() if is_asc else File.created_at.desc()
|
||||
elif order_by == 'updated_at':
|
||||
primary_sort = File.updated_at.asc() if is_asc else File.updated_at.desc()
|
||||
|
||||
# Apply sort with secondary key for deterministic pagination
|
||||
query = query.order_by(primary_sort, File.id.asc())
|
||||
@@ -542,13 +488,7 @@ class KnowledgeTable:
|
||||
files.append(
|
||||
FileUserResponse(
|
||||
**FileModel.model_validate(file).model_dump(),
|
||||
user=(
|
||||
UserResponse(
|
||||
**UserModel.model_validate(user).model_dump()
|
||||
)
|
||||
if user
|
||||
else None
|
||||
),
|
||||
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -557,9 +497,7 @@ class KnowledgeTable:
|
||||
print(e)
|
||||
return KnowledgeFileListResponse(items=[], total=0)
|
||||
|
||||
def get_files_by_id(
|
||||
self, knowledge_id: str, db: Optional[Session] = None
|
||||
) -> list[FileModel]:
|
||||
def get_files_by_id(self, knowledge_id: str, db: Optional[Session] = None) -> list[FileModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
files = (
|
||||
@@ -572,9 +510,7 @@ class KnowledgeTable:
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def get_file_metadatas_by_id(
|
||||
self, knowledge_id: str, db: Optional[Session] = None
|
||||
) -> list[FileMetadataResponse]:
|
||||
def get_file_metadatas_by_id(self, knowledge_id: str, db: Optional[Session] = None) -> list[FileMetadataResponse]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
files = self.get_files_by_id(knowledge_id, db=db)
|
||||
@@ -592,12 +528,12 @@ class KnowledgeTable:
|
||||
with get_db_context(db) as db:
|
||||
knowledge_file = KnowledgeFileModel(
|
||||
**{
|
||||
"id": str(uuid.uuid4()),
|
||||
"knowledge_id": knowledge_id,
|
||||
"file_id": file_id,
|
||||
"user_id": user_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
'id': str(uuid.uuid4()),
|
||||
'knowledge_id': knowledge_id,
|
||||
'file_id': file_id,
|
||||
'user_id': user_id,
|
||||
'created_at': int(time.time()),
|
||||
'updated_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -613,37 +549,24 @@ class KnowledgeTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def has_file(
|
||||
self, knowledge_id: str, file_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def has_file(self, knowledge_id: str, file_id: str, db: Optional[Session] = None) -> bool:
|
||||
"""Check whether a file belongs to a knowledge base."""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
return (
|
||||
db.query(KnowledgeFile)
|
||||
.filter_by(knowledge_id=knowledge_id, file_id=file_id)
|
||||
.first()
|
||||
is not None
|
||||
)
|
||||
return db.query(KnowledgeFile).filter_by(knowledge_id=knowledge_id, file_id=file_id).first() is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def remove_file_from_knowledge_by_id(
|
||||
self, knowledge_id: str, file_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def remove_file_from_knowledge_by_id(self, knowledge_id: str, file_id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
db.query(KnowledgeFile).filter_by(
|
||||
knowledge_id=knowledge_id, file_id=file_id
|
||||
).delete()
|
||||
db.query(KnowledgeFile).filter_by(knowledge_id=knowledge_id, file_id=file_id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def reset_knowledge_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[KnowledgeModel]:
|
||||
def reset_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> Optional[KnowledgeModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
# Delete all knowledge_file entries for this knowledge_id
|
||||
@@ -653,7 +576,7 @@ class KnowledgeTable:
|
||||
# Update the knowledge entry's updated_at timestamp
|
||||
db.query(Knowledge).filter_by(id=id).update(
|
||||
{
|
||||
"updated_at": int(time.time()),
|
||||
'updated_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
@@ -675,15 +598,13 @@ class KnowledgeTable:
|
||||
knowledge = self.get_knowledge_by_id(id=id, db=db)
|
||||
db.query(Knowledge).filter_by(id=id).update(
|
||||
{
|
||||
**form_data.model_dump(exclude={"access_grants"}),
|
||||
"updated_at": int(time.time()),
|
||||
**form_data.model_dump(exclude={'access_grants'}),
|
||||
'updated_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
if form_data.access_grants is not None:
|
||||
AccessGrants.set_access_grants(
|
||||
"knowledge", id, form_data.access_grants, db=db
|
||||
)
|
||||
AccessGrants.set_access_grants('knowledge', id, form_data.access_grants, db=db)
|
||||
return self.get_knowledge_by_id(id=id, db=db)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
@@ -697,8 +618,8 @@ class KnowledgeTable:
|
||||
knowledge = self.get_knowledge_by_id(id=id, db=db)
|
||||
db.query(Knowledge).filter_by(id=id).update(
|
||||
{
|
||||
"data": data,
|
||||
"updated_at": int(time.time()),
|
||||
'data': data,
|
||||
'updated_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
@@ -710,7 +631,7 @@ class KnowledgeTable:
|
||||
def delete_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
AccessGrants.revoke_all_access("knowledge", id, db=db)
|
||||
AccessGrants.revoke_all_access('knowledge', id, db=db)
|
||||
db.query(Knowledge).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
@@ -722,7 +643,7 @@ class KnowledgeTable:
|
||||
try:
|
||||
knowledge_ids = [row[0] for row in db.query(Knowledge.id).all()]
|
||||
for knowledge_id in knowledge_ids:
|
||||
AccessGrants.revoke_all_access("knowledge", knowledge_id, db=db)
|
||||
AccessGrants.revoke_all_access('knowledge', knowledge_id, db=db)
|
||||
db.query(Knowledge).delete()
|
||||
db.commit()
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from sqlalchemy import BigInteger, Column, String, Text
|
||||
|
||||
|
||||
class Memory(Base):
|
||||
__tablename__ = "memory"
|
||||
__tablename__ = 'memory'
|
||||
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
user_id = Column(String)
|
||||
@@ -49,11 +49,11 @@ class MemoriesTable:
|
||||
|
||||
memory = MemoryModel(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"content": content,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
'id': id,
|
||||
'user_id': user_id,
|
||||
'content': content,
|
||||
'created_at': int(time.time()),
|
||||
'updated_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
result = Memory(**memory.model_dump())
|
||||
@@ -95,9 +95,7 @@ class MemoriesTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_memories_by_user_id(
|
||||
self, user_id: str, db: Optional[Session] = None
|
||||
) -> list[MemoryModel]:
|
||||
def get_memories_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[MemoryModel]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
memories = db.query(Memory).filter_by(user_id=user_id).all()
|
||||
@@ -105,9 +103,7 @@ class MemoriesTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_memory_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[MemoryModel]:
|
||||
def get_memory_by_id(self, id: str, db: Optional[Session] = None) -> Optional[MemoryModel]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
memory = db.get(Memory, id)
|
||||
@@ -126,9 +122,7 @@ class MemoriesTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_memories_by_user_id(
|
||||
self, user_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def delete_memories_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
db.query(Memory).filter_by(user_id=user_id).delete()
|
||||
@@ -138,9 +132,7 @@ class MemoriesTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def delete_memory_by_id_and_user_id(
|
||||
self, id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def delete_memory_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
memory = db.get(Memory, id)
|
||||
|
||||
@@ -21,7 +21,7 @@ from sqlalchemy.sql import exists
|
||||
|
||||
|
||||
class MessageReaction(Base):
|
||||
__tablename__ = "message_reaction"
|
||||
__tablename__ = 'message_reaction'
|
||||
id = Column(Text, primary_key=True, unique=True)
|
||||
user_id = Column(Text)
|
||||
message_id = Column(Text)
|
||||
@@ -40,7 +40,7 @@ class MessageReactionModel(BaseModel):
|
||||
|
||||
|
||||
class Message(Base):
|
||||
__tablename__ = "message"
|
||||
__tablename__ = 'message'
|
||||
id = Column(Text, primary_key=True, unique=True)
|
||||
|
||||
user_id = Column(Text)
|
||||
@@ -112,7 +112,7 @@ class MessageUserResponse(MessageModel):
|
||||
class MessageUserSlimResponse(MessageUserResponse):
|
||||
data: bool | None = None
|
||||
|
||||
@field_validator("data", mode="before")
|
||||
@field_validator('data', mode='before')
|
||||
def convert_data_to_bool(cls, v):
|
||||
# No data or not a dict → False
|
||||
if not isinstance(v, dict):
|
||||
@@ -152,19 +152,19 @@ class MessageTable:
|
||||
|
||||
message = MessageModel(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"channel_id": channel_id,
|
||||
"reply_to_id": form_data.reply_to_id,
|
||||
"parent_id": form_data.parent_id,
|
||||
"is_pinned": False,
|
||||
"pinned_at": None,
|
||||
"pinned_by": None,
|
||||
"content": form_data.content,
|
||||
"data": form_data.data,
|
||||
"meta": form_data.meta,
|
||||
"created_at": ts,
|
||||
"updated_at": ts,
|
||||
'id': id,
|
||||
'user_id': user_id,
|
||||
'channel_id': channel_id,
|
||||
'reply_to_id': form_data.reply_to_id,
|
||||
'parent_id': form_data.parent_id,
|
||||
'is_pinned': False,
|
||||
'pinned_at': None,
|
||||
'pinned_by': None,
|
||||
'content': form_data.content,
|
||||
'data': form_data.data,
|
||||
'meta': form_data.meta,
|
||||
'created_at': ts,
|
||||
'updated_at': ts,
|
||||
}
|
||||
)
|
||||
result = Message(**message.model_dump())
|
||||
@@ -186,9 +186,7 @@ class MessageTable:
|
||||
return None
|
||||
|
||||
reply_to_message = (
|
||||
self.get_message_by_id(
|
||||
message.reply_to_id, include_thread_replies=False, db=db
|
||||
)
|
||||
self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db)
|
||||
if message.reply_to_id
|
||||
else None
|
||||
)
|
||||
@@ -200,22 +198,22 @@ class MessageTable:
|
||||
thread_replies = self.get_thread_replies_by_message_id(id, db=db)
|
||||
|
||||
# Check if message was sent by webhook (webhook info in meta takes precedence)
|
||||
webhook_info = message.meta.get("webhook") if message.meta else None
|
||||
if webhook_info and webhook_info.get("id"):
|
||||
webhook_info = message.meta.get('webhook') if message.meta else None
|
||||
if webhook_info and webhook_info.get('id'):
|
||||
# Look up webhook by ID to get current name
|
||||
webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db)
|
||||
webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db)
|
||||
if webhook:
|
||||
user_info = {
|
||||
"id": webhook.id,
|
||||
"name": webhook.name,
|
||||
"role": "webhook",
|
||||
'id': webhook.id,
|
||||
'name': webhook.name,
|
||||
'role': 'webhook',
|
||||
}
|
||||
else:
|
||||
# Webhook was deleted, use placeholder
|
||||
user_info = {
|
||||
"id": webhook_info.get("id"),
|
||||
"name": "Deleted Webhook",
|
||||
"role": "webhook",
|
||||
'id': webhook_info.get('id'),
|
||||
'name': 'Deleted Webhook',
|
||||
'role': 'webhook',
|
||||
}
|
||||
else:
|
||||
user = Users.get_user_by_id(message.user_id, db=db)
|
||||
@@ -224,79 +222,57 @@ class MessageTable:
|
||||
return MessageResponse.model_validate(
|
||||
{
|
||||
**MessageModel.model_validate(message).model_dump(),
|
||||
"user": user_info,
|
||||
"reply_to_message": (
|
||||
reply_to_message.model_dump() if reply_to_message else None
|
||||
),
|
||||
"latest_reply_at": (
|
||||
thread_replies[0].created_at if thread_replies else None
|
||||
),
|
||||
"reply_count": len(thread_replies),
|
||||
"reactions": reactions,
|
||||
'user': user_info,
|
||||
'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None),
|
||||
'latest_reply_at': (thread_replies[0].created_at if thread_replies else None),
|
||||
'reply_count': len(thread_replies),
|
||||
'reactions': reactions,
|
||||
}
|
||||
)
|
||||
|
||||
def get_thread_replies_by_message_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> list[MessageReplyToResponse]:
|
||||
def get_thread_replies_by_message_id(self, id: str, db: Optional[Session] = None) -> list[MessageReplyToResponse]:
|
||||
with get_db_context(db) as db:
|
||||
all_messages = (
|
||||
db.query(Message)
|
||||
.filter_by(parent_id=id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
all_messages = db.query(Message).filter_by(parent_id=id).order_by(Message.created_at.desc()).all()
|
||||
|
||||
messages = []
|
||||
for message in all_messages:
|
||||
reply_to_message = (
|
||||
self.get_message_by_id(
|
||||
message.reply_to_id, include_thread_replies=False, db=db
|
||||
)
|
||||
self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db)
|
||||
if message.reply_to_id
|
||||
else None
|
||||
)
|
||||
|
||||
webhook_info = message.meta.get("webhook") if message.meta else None
|
||||
webhook_info = message.meta.get('webhook') if message.meta else None
|
||||
user_info = None
|
||||
if webhook_info and webhook_info.get("id"):
|
||||
webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db)
|
||||
if webhook_info and webhook_info.get('id'):
|
||||
webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db)
|
||||
if webhook:
|
||||
user_info = {
|
||||
"id": webhook.id,
|
||||
"name": webhook.name,
|
||||
"role": "webhook",
|
||||
'id': webhook.id,
|
||||
'name': webhook.name,
|
||||
'role': 'webhook',
|
||||
}
|
||||
else:
|
||||
user_info = {
|
||||
"id": webhook_info.get("id"),
|
||||
"name": "Deleted Webhook",
|
||||
"role": "webhook",
|
||||
'id': webhook_info.get('id'),
|
||||
'name': 'Deleted Webhook',
|
||||
'role': 'webhook',
|
||||
}
|
||||
|
||||
messages.append(
|
||||
MessageReplyToResponse.model_validate(
|
||||
{
|
||||
**MessageModel.model_validate(message).model_dump(),
|
||||
"user": user_info,
|
||||
"reply_to_message": (
|
||||
reply_to_message.model_dump()
|
||||
if reply_to_message
|
||||
else None
|
||||
),
|
||||
'user': user_info,
|
||||
'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None),
|
||||
}
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
def get_reply_user_ids_by_message_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> list[str]:
|
||||
def get_reply_user_ids_by_message_id(self, id: str, db: Optional[Session] = None) -> list[str]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
message.user_id
|
||||
for message in db.query(Message).filter_by(parent_id=id).all()
|
||||
]
|
||||
return [message.user_id for message in db.query(Message).filter_by(parent_id=id).all()]
|
||||
|
||||
def get_messages_by_channel_id(
|
||||
self,
|
||||
@@ -318,40 +294,34 @@ class MessageTable:
|
||||
messages = []
|
||||
for message in all_messages:
|
||||
reply_to_message = (
|
||||
self.get_message_by_id(
|
||||
message.reply_to_id, include_thread_replies=False, db=db
|
||||
)
|
||||
self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db)
|
||||
if message.reply_to_id
|
||||
else None
|
||||
)
|
||||
|
||||
webhook_info = message.meta.get("webhook") if message.meta else None
|
||||
webhook_info = message.meta.get('webhook') if message.meta else None
|
||||
user_info = None
|
||||
if webhook_info and webhook_info.get("id"):
|
||||
webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db)
|
||||
if webhook_info and webhook_info.get('id'):
|
||||
webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db)
|
||||
if webhook:
|
||||
user_info = {
|
||||
"id": webhook.id,
|
||||
"name": webhook.name,
|
||||
"role": "webhook",
|
||||
'id': webhook.id,
|
||||
'name': webhook.name,
|
||||
'role': 'webhook',
|
||||
}
|
||||
else:
|
||||
user_info = {
|
||||
"id": webhook_info.get("id"),
|
||||
"name": "Deleted Webhook",
|
||||
"role": "webhook",
|
||||
'id': webhook_info.get('id'),
|
||||
'name': 'Deleted Webhook',
|
||||
'role': 'webhook',
|
||||
}
|
||||
|
||||
messages.append(
|
||||
MessageReplyToResponse.model_validate(
|
||||
{
|
||||
**MessageModel.model_validate(message).model_dump(),
|
||||
"user": user_info,
|
||||
"reply_to_message": (
|
||||
reply_to_message.model_dump()
|
||||
if reply_to_message
|
||||
else None
|
||||
),
|
||||
'user': user_info,
|
||||
'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None),
|
||||
}
|
||||
)
|
||||
)
|
||||
@@ -387,55 +357,42 @@ class MessageTable:
|
||||
messages = []
|
||||
for message in all_messages:
|
||||
reply_to_message = (
|
||||
self.get_message_by_id(
|
||||
message.reply_to_id, include_thread_replies=False, db=db
|
||||
)
|
||||
self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db)
|
||||
if message.reply_to_id
|
||||
else None
|
||||
)
|
||||
|
||||
webhook_info = message.meta.get("webhook") if message.meta else None
|
||||
webhook_info = message.meta.get('webhook') if message.meta else None
|
||||
user_info = None
|
||||
if webhook_info and webhook_info.get("id"):
|
||||
webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db)
|
||||
if webhook_info and webhook_info.get('id'):
|
||||
webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db)
|
||||
if webhook:
|
||||
user_info = {
|
||||
"id": webhook.id,
|
||||
"name": webhook.name,
|
||||
"role": "webhook",
|
||||
'id': webhook.id,
|
||||
'name': webhook.name,
|
||||
'role': 'webhook',
|
||||
}
|
||||
else:
|
||||
user_info = {
|
||||
"id": webhook_info.get("id"),
|
||||
"name": "Deleted Webhook",
|
||||
"role": "webhook",
|
||||
'id': webhook_info.get('id'),
|
||||
'name': 'Deleted Webhook',
|
||||
'role': 'webhook',
|
||||
}
|
||||
|
||||
messages.append(
|
||||
MessageReplyToResponse.model_validate(
|
||||
{
|
||||
**MessageModel.model_validate(message).model_dump(),
|
||||
"user": user_info,
|
||||
"reply_to_message": (
|
||||
reply_to_message.model_dump()
|
||||
if reply_to_message
|
||||
else None
|
||||
),
|
||||
'user': user_info,
|
||||
'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None),
|
||||
}
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
def get_last_message_by_channel_id(
|
||||
self, channel_id: str, db: Optional[Session] = None
|
||||
) -> Optional[MessageModel]:
|
||||
def get_last_message_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> Optional[MessageModel]:
|
||||
with get_db_context(db) as db:
|
||||
message = (
|
||||
db.query(Message)
|
||||
.filter_by(channel_id=channel_id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
message = db.query(Message).filter_by(channel_id=channel_id).order_by(Message.created_at.desc()).first()
|
||||
return MessageModel.model_validate(message) if message else None
|
||||
|
||||
def get_pinned_messages_by_channel_id(
|
||||
@@ -513,11 +470,7 @@ class MessageTable:
|
||||
) -> Optional[MessageReactionModel]:
|
||||
with get_db_context(db) as db:
|
||||
# check for existing reaction
|
||||
existing_reaction = (
|
||||
db.query(MessageReaction)
|
||||
.filter_by(message_id=id, user_id=user_id, name=name)
|
||||
.first()
|
||||
)
|
||||
existing_reaction = db.query(MessageReaction).filter_by(message_id=id, user_id=user_id, name=name).first()
|
||||
if existing_reaction:
|
||||
return MessageReactionModel.model_validate(existing_reaction)
|
||||
|
||||
@@ -535,9 +488,7 @@ class MessageTable:
|
||||
db.refresh(result)
|
||||
return MessageReactionModel.model_validate(result) if result else None
|
||||
|
||||
def get_reactions_by_message_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> list[Reactions]:
|
||||
def get_reactions_by_message_id(self, id: str, db: Optional[Session] = None) -> list[Reactions]:
|
||||
with get_db_context(db) as db:
|
||||
# JOIN User so all user info is fetched in one query
|
||||
results = (
|
||||
@@ -552,18 +503,18 @@ class MessageTable:
|
||||
for reaction, user in results:
|
||||
if reaction.name not in reactions:
|
||||
reactions[reaction.name] = {
|
||||
"name": reaction.name,
|
||||
"users": [],
|
||||
"count": 0,
|
||||
'name': reaction.name,
|
||||
'users': [],
|
||||
'count': 0,
|
||||
}
|
||||
|
||||
reactions[reaction.name]["users"].append(
|
||||
reactions[reaction.name]['users'].append(
|
||||
{
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
'id': user.id,
|
||||
'name': user.name,
|
||||
}
|
||||
)
|
||||
reactions[reaction.name]["count"] += 1
|
||||
reactions[reaction.name]['count'] += 1
|
||||
|
||||
return [Reactions(**reaction) for reaction in reactions.values()]
|
||||
|
||||
@@ -571,9 +522,7 @@ class MessageTable:
|
||||
self, id: str, user_id: str, name: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
with get_db_context(db) as db:
|
||||
db.query(MessageReaction).filter_by(
|
||||
message_id=id, user_id=user_id, name=name
|
||||
).delete()
|
||||
db.query(MessageReaction).filter_by(message_id=id, user_id=user_id, name=name).delete()
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
@@ -612,21 +561,15 @@ class MessageTable:
|
||||
with get_db_context(db) as db:
|
||||
query_builder = db.query(Message).filter(
|
||||
Message.channel_id.in_(channel_ids),
|
||||
Message.content.ilike(f"%{query}%"),
|
||||
Message.content.ilike(f'%{query}%'),
|
||||
)
|
||||
|
||||
if start_timestamp:
|
||||
query_builder = query_builder.filter(
|
||||
Message.created_at >= start_timestamp
|
||||
)
|
||||
query_builder = query_builder.filter(Message.created_at >= start_timestamp)
|
||||
if end_timestamp:
|
||||
query_builder = query_builder.filter(
|
||||
Message.created_at <= end_timestamp
|
||||
)
|
||||
query_builder = query_builder.filter(Message.created_at <= end_timestamp)
|
||||
|
||||
messages = (
|
||||
query_builder.order_by(Message.created_at.desc()).limit(limit).all()
|
||||
)
|
||||
messages = query_builder.order_by(Message.created_at.desc()).limit(limit).all()
|
||||
return [MessageModel.model_validate(msg) for msg in messages]
|
||||
|
||||
|
||||
|
||||
@@ -28,13 +28,13 @@ log = logging.getLogger(__name__)
|
||||
|
||||
# ModelParams is a model for the data stored in the params field of the Model table
|
||||
class ModelParams(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
model_config = ConfigDict(extra='allow')
|
||||
pass
|
||||
|
||||
|
||||
# ModelMeta is a model for the data stored in the meta field of the Model table
|
||||
class ModelMeta(BaseModel):
|
||||
profile_image_url: Optional[str] = "/static/favicon.png"
|
||||
profile_image_url: Optional[str] = '/static/favicon.png'
|
||||
|
||||
description: Optional[str] = None
|
||||
"""
|
||||
@@ -43,13 +43,13 @@ class ModelMeta(BaseModel):
|
||||
|
||||
capabilities: Optional[dict] = None
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Model(Base):
|
||||
__tablename__ = "model"
|
||||
__tablename__ = 'model'
|
||||
|
||||
id = Column(Text, primary_key=True, unique=True)
|
||||
"""
|
||||
@@ -139,10 +139,8 @@ class ModelForm(BaseModel):
|
||||
|
||||
|
||||
class ModelsTable:
|
||||
def _get_access_grants(
|
||||
self, model_id: str, db: Optional[Session] = None
|
||||
) -> list[AccessGrantModel]:
|
||||
return AccessGrants.get_grants_by_resource("model", model_id, db=db)
|
||||
def _get_access_grants(self, model_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
|
||||
return AccessGrants.get_grants_by_resource('model', model_id, db=db)
|
||||
|
||||
def _to_model_model(
|
||||
self,
|
||||
@@ -150,13 +148,9 @@ class ModelsTable:
|
||||
access_grants: Optional[list[AccessGrantModel]] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> ModelModel:
|
||||
model_data = ModelModel.model_validate(model).model_dump(
|
||||
exclude={"access_grants"}
|
||||
)
|
||||
model_data["access_grants"] = (
|
||||
access_grants
|
||||
if access_grants is not None
|
||||
else self._get_access_grants(model_data["id"], db=db)
|
||||
model_data = ModelModel.model_validate(model).model_dump(exclude={'access_grants'})
|
||||
model_data['access_grants'] = (
|
||||
access_grants if access_grants is not None else self._get_access_grants(model_data['id'], db=db)
|
||||
)
|
||||
return ModelModel.model_validate(model_data)
|
||||
|
||||
@@ -167,37 +161,32 @@ class ModelsTable:
|
||||
with get_db_context(db) as db:
|
||||
result = Model(
|
||||
**{
|
||||
**form_data.model_dump(exclude={"access_grants"}),
|
||||
"user_id": user_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
**form_data.model_dump(exclude={'access_grants'}),
|
||||
'user_id': user_id,
|
||||
'created_at': int(time.time()),
|
||||
'updated_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
AccessGrants.set_access_grants(
|
||||
"model", result.id, form_data.access_grants, db=db
|
||||
)
|
||||
AccessGrants.set_access_grants('model', result.id, form_data.access_grants, db=db)
|
||||
|
||||
if result:
|
||||
return self._to_model_model(result, db=db)
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to insert a new model: {e}")
|
||||
log.exception(f'Failed to insert a new model: {e}')
|
||||
return None
|
||||
|
||||
def get_all_models(self, db: Optional[Session] = None) -> list[ModelModel]:
|
||||
with get_db_context(db) as db:
|
||||
all_models = db.query(Model).all()
|
||||
model_ids = [model.id for model in all_models]
|
||||
grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db)
|
||||
grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
|
||||
return [
|
||||
self._to_model_model(
|
||||
model, access_grants=grants_map.get(model.id, []), db=db
|
||||
)
|
||||
for model in all_models
|
||||
self._to_model_model(model, access_grants=grants_map.get(model.id, []), db=db) for model in all_models
|
||||
]
|
||||
|
||||
def get_models(self, db: Optional[Session] = None) -> list[ModelUserResponse]:
|
||||
@@ -209,7 +198,7 @@ class ModelsTable:
|
||||
|
||||
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
||||
users_dict = {user.id: user for user in users}
|
||||
grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db)
|
||||
grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
|
||||
|
||||
models = []
|
||||
for model in all_models:
|
||||
@@ -222,7 +211,7 @@ class ModelsTable:
|
||||
access_grants=grants_map.get(model.id, []),
|
||||
db=db,
|
||||
).model_dump(),
|
||||
"user": user.model_dump() if user else None,
|
||||
'user': user.model_dump() if user else None,
|
||||
}
|
||||
)
|
||||
)
|
||||
@@ -232,28 +221,23 @@ class ModelsTable:
|
||||
with get_db_context(db) as db:
|
||||
all_models = db.query(Model).filter(Model.base_model_id == None).all()
|
||||
model_ids = [model.id for model in all_models]
|
||||
grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db)
|
||||
grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
|
||||
return [
|
||||
self._to_model_model(
|
||||
model, access_grants=grants_map.get(model.id, []), db=db
|
||||
)
|
||||
for model in all_models
|
||||
self._to_model_model(model, access_grants=grants_map.get(model.id, []), db=db) for model in all_models
|
||||
]
|
||||
|
||||
def get_models_by_user_id(
|
||||
self, user_id: str, permission: str = "write", db: Optional[Session] = None
|
||||
self, user_id: str, permission: str = 'write', db: Optional[Session] = None
|
||||
) -> list[ModelUserResponse]:
|
||||
models = self.get_models(db=db)
|
||||
user_group_ids = {
|
||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
||||
}
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
|
||||
return [
|
||||
model
|
||||
for model in models
|
||||
if model.user_id == user_id
|
||||
or AccessGrants.has_access(
|
||||
user_id=user_id,
|
||||
resource_type="model",
|
||||
resource_type='model',
|
||||
resource_id=model.id,
|
||||
permission=permission,
|
||||
user_group_ids=user_group_ids,
|
||||
@@ -261,13 +245,13 @@ class ModelsTable:
|
||||
)
|
||||
]
|
||||
|
||||
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
|
||||
def _has_permission(self, db, query, filter: dict, permission: str = 'read'):
|
||||
return AccessGrants.has_permission_filter(
|
||||
db=db,
|
||||
query=query,
|
||||
DocumentModel=Model,
|
||||
filter=filter,
|
||||
resource_type="model",
|
||||
resource_type='model',
|
||||
permission=permission,
|
||||
)
|
||||
|
||||
@@ -285,22 +269,22 @@ class ModelsTable:
|
||||
query = query.filter(Model.base_model_id != None)
|
||||
|
||||
if filter:
|
||||
query_key = filter.get("query")
|
||||
query_key = filter.get('query')
|
||||
if query_key:
|
||||
query = query.filter(
|
||||
or_(
|
||||
Model.name.ilike(f"%{query_key}%"),
|
||||
Model.base_model_id.ilike(f"%{query_key}%"),
|
||||
User.name.ilike(f"%{query_key}%"),
|
||||
User.email.ilike(f"%{query_key}%"),
|
||||
User.username.ilike(f"%{query_key}%"),
|
||||
Model.name.ilike(f'%{query_key}%'),
|
||||
Model.base_model_id.ilike(f'%{query_key}%'),
|
||||
User.name.ilike(f'%{query_key}%'),
|
||||
User.email.ilike(f'%{query_key}%'),
|
||||
User.username.ilike(f'%{query_key}%'),
|
||||
)
|
||||
)
|
||||
|
||||
view_option = filter.get("view_option")
|
||||
if view_option == "created":
|
||||
view_option = filter.get('view_option')
|
||||
if view_option == 'created':
|
||||
query = query.filter(Model.user_id == user_id)
|
||||
elif view_option == "shared":
|
||||
elif view_option == 'shared':
|
||||
query = query.filter(Model.user_id != user_id)
|
||||
|
||||
# Apply access control filtering
|
||||
@@ -308,10 +292,10 @@ class ModelsTable:
|
||||
db,
|
||||
query,
|
||||
filter,
|
||||
permission="read",
|
||||
permission='read',
|
||||
)
|
||||
|
||||
tag = filter.get("tag")
|
||||
tag = filter.get('tag')
|
||||
if tag:
|
||||
# TODO: This is a simple implementation and should be improved for performance
|
||||
like_pattern = f'%"{tag.lower()}"%' # `"tag"` inside JSON array
|
||||
@@ -319,21 +303,21 @@ class ModelsTable:
|
||||
|
||||
query = query.filter(meta_text.like(like_pattern))
|
||||
|
||||
order_by = filter.get("order_by")
|
||||
direction = filter.get("direction")
|
||||
order_by = filter.get('order_by')
|
||||
direction = filter.get('direction')
|
||||
|
||||
if order_by == "name":
|
||||
if direction == "asc":
|
||||
if order_by == 'name':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(Model.name.asc())
|
||||
else:
|
||||
query = query.order_by(Model.name.desc())
|
||||
elif order_by == "created_at":
|
||||
if direction == "asc":
|
||||
elif order_by == 'created_at':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(Model.created_at.asc())
|
||||
else:
|
||||
query = query.order_by(Model.created_at.desc())
|
||||
elif order_by == "updated_at":
|
||||
if direction == "asc":
|
||||
elif order_by == 'updated_at':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(Model.updated_at.asc())
|
||||
else:
|
||||
query = query.order_by(Model.updated_at.desc())
|
||||
@@ -352,7 +336,7 @@ class ModelsTable:
|
||||
items = query.all()
|
||||
|
||||
model_ids = [model.id for model, _ in items]
|
||||
grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db)
|
||||
grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
|
||||
|
||||
models = []
|
||||
for model, user in items:
|
||||
@@ -363,19 +347,13 @@ class ModelsTable:
|
||||
access_grants=grants_map.get(model.id, []),
|
||||
db=db,
|
||||
).model_dump(),
|
||||
user=(
|
||||
UserResponse(**UserModel.model_validate(user).model_dump())
|
||||
if user
|
||||
else None
|
||||
),
|
||||
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
|
||||
)
|
||||
)
|
||||
|
||||
return ModelListResponse(items=models, total=total)
|
||||
|
||||
def get_model_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[ModelModel]:
|
||||
def get_model_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ModelModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
model = db.get(Model, id)
|
||||
@@ -383,16 +361,12 @@ class ModelsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_models_by_ids(
|
||||
self, ids: list[str], db: Optional[Session] = None
|
||||
) -> list[ModelModel]:
|
||||
def get_models_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[ModelModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
models = db.query(Model).filter(Model.id.in_(ids)).all()
|
||||
model_ids = [model.id for model in models]
|
||||
grants_map = AccessGrants.get_grants_by_resources(
|
||||
"model", model_ids, db=db
|
||||
)
|
||||
grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
|
||||
return [
|
||||
self._to_model_model(
|
||||
model,
|
||||
@@ -404,9 +378,7 @@ class ModelsTable:
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def toggle_model_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[ModelModel]:
|
||||
def toggle_model_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ModelModel]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
model = db.query(Model).filter_by(id=id).first()
|
||||
@@ -422,30 +394,26 @@ class ModelsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_model_by_id(
|
||||
self, id: str, model: ModelForm, db: Optional[Session] = None
|
||||
) -> Optional[ModelModel]:
|
||||
def update_model_by_id(self, id: str, model: ModelForm, db: Optional[Session] = None) -> Optional[ModelModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
# update only the fields that are present in the model
|
||||
data = model.model_dump(exclude={"id", "access_grants"})
|
||||
data = model.model_dump(exclude={'id', 'access_grants'})
|
||||
result = db.query(Model).filter_by(id=id).update(data)
|
||||
|
||||
db.commit()
|
||||
if model.access_grants is not None:
|
||||
AccessGrants.set_access_grants(
|
||||
"model", id, model.access_grants, db=db
|
||||
)
|
||||
AccessGrants.set_access_grants('model', id, model.access_grants, db=db)
|
||||
|
||||
return self.get_model_by_id(id, db=db)
|
||||
except Exception as e:
|
||||
log.exception(f"Failed to update the model by id {id}: {e}")
|
||||
log.exception(f'Failed to update the model by id {id}: {e}')
|
||||
return None
|
||||
|
||||
def delete_model_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
AccessGrants.revoke_all_access("model", id, db=db)
|
||||
AccessGrants.revoke_all_access('model', id, db=db)
|
||||
db.query(Model).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
|
||||
@@ -458,7 +426,7 @@ class ModelsTable:
|
||||
with get_db_context(db) as db:
|
||||
model_ids = [row[0] for row in db.query(Model.id).all()]
|
||||
for model_id in model_ids:
|
||||
AccessGrants.revoke_all_access("model", model_id, db=db)
|
||||
AccessGrants.revoke_all_access('model', model_id, db=db)
|
||||
db.query(Model).delete()
|
||||
db.commit()
|
||||
|
||||
@@ -466,9 +434,7 @@ class ModelsTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def sync_models(
|
||||
self, user_id: str, models: list[ModelModel], db: Optional[Session] = None
|
||||
) -> list[ModelModel]:
|
||||
def sync_models(self, user_id: str, models: list[ModelModel], db: Optional[Session] = None) -> list[ModelModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
# Get existing models
|
||||
@@ -483,37 +449,33 @@ class ModelsTable:
|
||||
if model.id in existing_ids:
|
||||
db.query(Model).filter_by(id=model.id).update(
|
||||
{
|
||||
**model.model_dump(exclude={"access_grants"}),
|
||||
"user_id": user_id,
|
||||
"updated_at": int(time.time()),
|
||||
**model.model_dump(exclude={'access_grants'}),
|
||||
'user_id': user_id,
|
||||
'updated_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
else:
|
||||
new_model = Model(
|
||||
**{
|
||||
**model.model_dump(exclude={"access_grants"}),
|
||||
"user_id": user_id,
|
||||
"updated_at": int(time.time()),
|
||||
**model.model_dump(exclude={'access_grants'}),
|
||||
'user_id': user_id,
|
||||
'updated_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
db.add(new_model)
|
||||
AccessGrants.set_access_grants(
|
||||
"model", model.id, model.access_grants, db=db
|
||||
)
|
||||
AccessGrants.set_access_grants('model', model.id, model.access_grants, db=db)
|
||||
|
||||
# Remove models that are no longer present
|
||||
for model in existing_models:
|
||||
if model.id not in new_model_ids:
|
||||
AccessGrants.revoke_all_access("model", model.id, db=db)
|
||||
AccessGrants.revoke_all_access('model', model.id, db=db)
|
||||
db.delete(model)
|
||||
|
||||
db.commit()
|
||||
|
||||
all_models = db.query(Model).all()
|
||||
model_ids = [model.id for model in all_models]
|
||||
grants_map = AccessGrants.get_grants_by_resources(
|
||||
"model", model_ids, db=db
|
||||
)
|
||||
grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
|
||||
return [
|
||||
self._to_model_model(
|
||||
model,
|
||||
@@ -523,7 +485,7 @@ class ModelsTable:
|
||||
for model in all_models
|
||||
]
|
||||
except Exception as e:
|
||||
log.exception(f"Error syncing models for user {user_id}: {e}")
|
||||
log.exception(f'Error syncing models for user {user_id}: {e}')
|
||||
return []
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ from sqlalchemy import or_, func, cast
|
||||
|
||||
|
||||
class Note(Base):
|
||||
__tablename__ = "note"
|
||||
__tablename__ = 'note'
|
||||
|
||||
id = Column(Text, primary_key=True, unique=True)
|
||||
user_id = Column(Text)
|
||||
@@ -88,10 +88,8 @@ class NoteListResponse(BaseModel):
|
||||
|
||||
|
||||
class NoteTable:
|
||||
def _get_access_grants(
|
||||
self, note_id: str, db: Optional[Session] = None
|
||||
) -> list[AccessGrantModel]:
|
||||
return AccessGrants.get_grants_by_resource("note", note_id, db=db)
|
||||
def _get_access_grants(self, note_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
|
||||
return AccessGrants.get_grants_by_resource('note', note_id, db=db)
|
||||
|
||||
def _to_note_model(
|
||||
self,
|
||||
@@ -99,51 +97,43 @@ class NoteTable:
|
||||
access_grants: Optional[list[AccessGrantModel]] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> NoteModel:
|
||||
note_data = NoteModel.model_validate(note).model_dump(exclude={"access_grants"})
|
||||
note_data["access_grants"] = (
|
||||
access_grants
|
||||
if access_grants is not None
|
||||
else self._get_access_grants(note_data["id"], db=db)
|
||||
note_data = NoteModel.model_validate(note).model_dump(exclude={'access_grants'})
|
||||
note_data['access_grants'] = (
|
||||
access_grants if access_grants is not None else self._get_access_grants(note_data['id'], db=db)
|
||||
)
|
||||
return NoteModel.model_validate(note_data)
|
||||
|
||||
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
|
||||
def _has_permission(self, db, query, filter: dict, permission: str = 'read'):
|
||||
return AccessGrants.has_permission_filter(
|
||||
db=db,
|
||||
query=query,
|
||||
DocumentModel=Note,
|
||||
filter=filter,
|
||||
resource_type="note",
|
||||
resource_type='note',
|
||||
permission=permission,
|
||||
)
|
||||
|
||||
def insert_new_note(
|
||||
self, user_id: str, form_data: NoteForm, db: Optional[Session] = None
|
||||
) -> Optional[NoteModel]:
|
||||
def insert_new_note(self, user_id: str, form_data: NoteForm, db: Optional[Session] = None) -> Optional[NoteModel]:
|
||||
with get_db_context(db) as db:
|
||||
note = NoteModel(
|
||||
**{
|
||||
"id": str(uuid.uuid4()),
|
||||
"user_id": user_id,
|
||||
**form_data.model_dump(exclude={"access_grants"}),
|
||||
"created_at": int(time.time_ns()),
|
||||
"updated_at": int(time.time_ns()),
|
||||
"access_grants": [],
|
||||
'id': str(uuid.uuid4()),
|
||||
'user_id': user_id,
|
||||
**form_data.model_dump(exclude={'access_grants'}),
|
||||
'created_at': int(time.time_ns()),
|
||||
'updated_at': int(time.time_ns()),
|
||||
'access_grants': [],
|
||||
}
|
||||
)
|
||||
|
||||
new_note = Note(**note.model_dump(exclude={"access_grants"}))
|
||||
new_note = Note(**note.model_dump(exclude={'access_grants'}))
|
||||
|
||||
db.add(new_note)
|
||||
db.commit()
|
||||
AccessGrants.set_access_grants(
|
||||
"note", note.id, form_data.access_grants, db=db
|
||||
)
|
||||
AccessGrants.set_access_grants('note', note.id, form_data.access_grants, db=db)
|
||||
return self._to_note_model(new_note, db=db)
|
||||
|
||||
def get_notes(
|
||||
self, skip: int = 0, limit: int = 50, db: Optional[Session] = None
|
||||
) -> list[NoteModel]:
|
||||
def get_notes(self, skip: int = 0, limit: int = 50, db: Optional[Session] = None) -> list[NoteModel]:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Note).order_by(Note.updated_at.desc())
|
||||
if skip is not None:
|
||||
@@ -152,13 +142,8 @@ class NoteTable:
|
||||
query = query.limit(limit)
|
||||
notes = query.all()
|
||||
note_ids = [note.id for note in notes]
|
||||
grants_map = AccessGrants.get_grants_by_resources("note", note_ids, db=db)
|
||||
return [
|
||||
self._to_note_model(
|
||||
note, access_grants=grants_map.get(note.id, []), db=db
|
||||
)
|
||||
for note in notes
|
||||
]
|
||||
grants_map = AccessGrants.get_grants_by_resources('note', note_ids, db=db)
|
||||
return [self._to_note_model(note, access_grants=grants_map.get(note.id, []), db=db) for note in notes]
|
||||
|
||||
def search_notes(
|
||||
self,
|
||||
@@ -171,36 +156,32 @@ class NoteTable:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Note, User).outerjoin(User, User.id == Note.user_id)
|
||||
if filter:
|
||||
query_key = filter.get("query")
|
||||
query_key = filter.get('query')
|
||||
if query_key:
|
||||
# Normalize search by removing hyphens and spaces (e.g., "todo" matches "to-do" and "to do")
|
||||
normalized_query = query_key.replace("-", "").replace(" ", "")
|
||||
normalized_query = query_key.replace('-', '').replace(' ', '')
|
||||
query = query.filter(
|
||||
or_(
|
||||
func.replace(func.replace(Note.title, '-', ''), ' ', '').ilike(f'%{normalized_query}%'),
|
||||
func.replace(
|
||||
func.replace(Note.title, "-", ""), " ", ""
|
||||
).ilike(f"%{normalized_query}%"),
|
||||
func.replace(
|
||||
func.replace(
|
||||
cast(Note.data["content"]["md"], Text), "-", ""
|
||||
),
|
||||
" ",
|
||||
"",
|
||||
).ilike(f"%{normalized_query}%"),
|
||||
func.replace(cast(Note.data['content']['md'], Text), '-', ''),
|
||||
' ',
|
||||
'',
|
||||
).ilike(f'%{normalized_query}%'),
|
||||
)
|
||||
)
|
||||
|
||||
view_option = filter.get("view_option")
|
||||
if view_option == "created":
|
||||
view_option = filter.get('view_option')
|
||||
if view_option == 'created':
|
||||
query = query.filter(Note.user_id == user_id)
|
||||
elif view_option == "shared":
|
||||
elif view_option == 'shared':
|
||||
query = query.filter(Note.user_id != user_id)
|
||||
|
||||
# Apply access control filtering
|
||||
if "permission" in filter:
|
||||
permission = filter["permission"]
|
||||
if 'permission' in filter:
|
||||
permission = filter['permission']
|
||||
else:
|
||||
permission = "write"
|
||||
permission = 'write'
|
||||
|
||||
query = self._has_permission(
|
||||
db,
|
||||
@@ -209,21 +190,21 @@ class NoteTable:
|
||||
permission=permission,
|
||||
)
|
||||
|
||||
order_by = filter.get("order_by")
|
||||
direction = filter.get("direction")
|
||||
order_by = filter.get('order_by')
|
||||
direction = filter.get('direction')
|
||||
|
||||
if order_by == "name":
|
||||
if direction == "asc":
|
||||
if order_by == 'name':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(Note.title.asc())
|
||||
else:
|
||||
query = query.order_by(Note.title.desc())
|
||||
elif order_by == "created_at":
|
||||
if direction == "asc":
|
||||
elif order_by == 'created_at':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(Note.created_at.asc())
|
||||
else:
|
||||
query = query.order_by(Note.created_at.desc())
|
||||
elif order_by == "updated_at":
|
||||
if direction == "asc":
|
||||
elif order_by == 'updated_at':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(Note.updated_at.asc())
|
||||
else:
|
||||
query = query.order_by(Note.updated_at.desc())
|
||||
@@ -244,7 +225,7 @@ class NoteTable:
|
||||
items = query.all()
|
||||
|
||||
note_ids = [note.id for note, _ in items]
|
||||
grants_map = AccessGrants.get_grants_by_resources("note", note_ids, db=db)
|
||||
grants_map = AccessGrants.get_grants_by_resources('note', note_ids, db=db)
|
||||
|
||||
notes = []
|
||||
for note, user in items:
|
||||
@@ -255,11 +236,7 @@ class NoteTable:
|
||||
access_grants=grants_map.get(note.id, []),
|
||||
db=db,
|
||||
).model_dump(),
|
||||
user=(
|
||||
UserResponse(**UserModel.model_validate(user).model_dump())
|
||||
if user
|
||||
else None
|
||||
),
|
||||
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -268,20 +245,16 @@ class NoteTable:
|
||||
def get_notes_by_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
permission: str = "read",
|
||||
permission: str = 'read',
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
db: Optional[Session] = None,
|
||||
) -> list[NoteModel]:
|
||||
with get_db_context(db) as db:
|
||||
user_group_ids = [
|
||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
||||
]
|
||||
user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id, db=db)]
|
||||
|
||||
query = db.query(Note).order_by(Note.updated_at.desc())
|
||||
query = self._has_permission(
|
||||
db, query, {"user_id": user_id, "group_ids": user_group_ids}, permission
|
||||
)
|
||||
query = self._has_permission(db, query, {'user_id': user_id, 'group_ids': user_group_ids}, permission)
|
||||
|
||||
if skip is not None:
|
||||
query = query.offset(skip)
|
||||
@@ -290,17 +263,10 @@ class NoteTable:
|
||||
|
||||
notes = query.all()
|
||||
note_ids = [note.id for note in notes]
|
||||
grants_map = AccessGrants.get_grants_by_resources("note", note_ids, db=db)
|
||||
return [
|
||||
self._to_note_model(
|
||||
note, access_grants=grants_map.get(note.id, []), db=db
|
||||
)
|
||||
for note in notes
|
||||
]
|
||||
grants_map = AccessGrants.get_grants_by_resources('note', note_ids, db=db)
|
||||
return [self._to_note_model(note, access_grants=grants_map.get(note.id, []), db=db) for note in notes]
|
||||
|
||||
def get_note_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[NoteModel]:
|
||||
def get_note_by_id(self, id: str, db: Optional[Session] = None) -> Optional[NoteModel]:
|
||||
with get_db_context(db) as db:
|
||||
note = db.query(Note).filter(Note.id == id).first()
|
||||
return self._to_note_model(note, db=db) if note else None
|
||||
@@ -315,17 +281,15 @@ class NoteTable:
|
||||
|
||||
form_data = form_data.model_dump(exclude_unset=True)
|
||||
|
||||
if "title" in form_data:
|
||||
note.title = form_data["title"]
|
||||
if "data" in form_data:
|
||||
note.data = {**note.data, **form_data["data"]}
|
||||
if "meta" in form_data:
|
||||
note.meta = {**note.meta, **form_data["meta"]}
|
||||
if 'title' in form_data:
|
||||
note.title = form_data['title']
|
||||
if 'data' in form_data:
|
||||
note.data = {**note.data, **form_data['data']}
|
||||
if 'meta' in form_data:
|
||||
note.meta = {**note.meta, **form_data['meta']}
|
||||
|
||||
if "access_grants" in form_data:
|
||||
AccessGrants.set_access_grants(
|
||||
"note", id, form_data["access_grants"], db=db
|
||||
)
|
||||
if 'access_grants' in form_data:
|
||||
AccessGrants.set_access_grants('note', id, form_data['access_grants'], db=db)
|
||||
|
||||
note.updated_at = int(time.time_ns())
|
||||
|
||||
@@ -335,7 +299,7 @@ class NoteTable:
|
||||
def delete_note_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
AccessGrants.revoke_all_access("note", id, db=db)
|
||||
AccessGrants.revoke_all_access('note', id, db=db)
|
||||
db.query(Note).filter(Note.id == id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
@@ -23,23 +23,21 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthSession(Base):
|
||||
__tablename__ = "oauth_session"
|
||||
__tablename__ = 'oauth_session'
|
||||
|
||||
id = Column(Text, primary_key=True, unique=True)
|
||||
user_id = Column(Text, nullable=False)
|
||||
provider = Column(Text, nullable=False)
|
||||
token = Column(
|
||||
Text, nullable=False
|
||||
) # JSON with access_token, id_token, refresh_token
|
||||
token = Column(Text, nullable=False) # JSON with access_token, id_token, refresh_token
|
||||
expires_at = Column(BigInteger, nullable=False)
|
||||
created_at = Column(BigInteger, nullable=False)
|
||||
updated_at = Column(BigInteger, nullable=False)
|
||||
|
||||
# Add indexes for better performance
|
||||
__table_args__ = (
|
||||
Index("idx_oauth_session_user_id", "user_id"),
|
||||
Index("idx_oauth_session_expires_at", "expires_at"),
|
||||
Index("idx_oauth_session_user_provider", "user_id", "provider"),
|
||||
Index('idx_oauth_session_user_id', 'user_id'),
|
||||
Index('idx_oauth_session_expires_at', 'expires_at'),
|
||||
Index('idx_oauth_session_user_provider', 'user_id', 'provider'),
|
||||
)
|
||||
|
||||
|
||||
@@ -71,7 +69,7 @@ class OAuthSessionTable:
|
||||
def __init__(self):
|
||||
self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
|
||||
if not self.encryption_key:
|
||||
raise Exception("OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set")
|
||||
raise Exception('OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set')
|
||||
|
||||
# check if encryption key is in the right format for Fernet (32 url-safe base64-encoded bytes)
|
||||
if len(self.encryption_key) != 44:
|
||||
@@ -83,7 +81,7 @@ class OAuthSessionTable:
|
||||
try:
|
||||
self.fernet = Fernet(self.encryption_key)
|
||||
except Exception as e:
|
||||
log.error(f"Error initializing Fernet with provided key: {e}")
|
||||
log.error(f'Error initializing Fernet with provided key: {e}')
|
||||
raise
|
||||
|
||||
def _encrypt_token(self, token) -> str:
|
||||
@@ -93,7 +91,7 @@ class OAuthSessionTable:
|
||||
encrypted = self.fernet.encrypt(token_json.encode()).decode()
|
||||
return encrypted
|
||||
except Exception as e:
|
||||
log.error(f"Error encrypting tokens: {e}")
|
||||
log.error(f'Error encrypting tokens: {e}')
|
||||
raise
|
||||
|
||||
def _decrypt_token(self, token: str):
|
||||
@@ -102,7 +100,7 @@ class OAuthSessionTable:
|
||||
decrypted = self.fernet.decrypt(token.encode()).decode()
|
||||
return json.loads(decrypted)
|
||||
except Exception as e:
|
||||
log.error(f"Error decrypting tokens: {type(e).__name__}: {e}")
|
||||
log.error(f'Error decrypting tokens: {type(e).__name__}: {e}')
|
||||
raise
|
||||
|
||||
def create_session(
|
||||
@@ -120,13 +118,13 @@ class OAuthSessionTable:
|
||||
|
||||
result = OAuthSession(
|
||||
**{
|
||||
"id": id,
|
||||
"user_id": user_id,
|
||||
"provider": provider,
|
||||
"token": self._encrypt_token(token),
|
||||
"expires_at": token.get("expires_at"),
|
||||
"created_at": current_time,
|
||||
"updated_at": current_time,
|
||||
'id': id,
|
||||
'user_id': user_id,
|
||||
'provider': provider,
|
||||
'token': self._encrypt_token(token),
|
||||
'expires_at': token.get('expires_at'),
|
||||
'created_at': current_time,
|
||||
'updated_at': current_time,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -141,12 +139,10 @@ class OAuthSessionTable:
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.error(f"Error creating OAuth session: {e}")
|
||||
log.error(f'Error creating OAuth session: {e}')
|
||||
return None
|
||||
|
||||
def get_session_by_id(
|
||||
self, session_id: str, db: Optional[Session] = None
|
||||
) -> Optional[OAuthSessionModel]:
|
||||
def get_session_by_id(self, session_id: str, db: Optional[Session] = None) -> Optional[OAuthSessionModel]:
|
||||
"""Get OAuth session by ID"""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
@@ -158,7 +154,7 @@ class OAuthSessionTable:
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth session by ID: {e}")
|
||||
log.error(f'Error getting OAuth session by ID: {e}')
|
||||
return None
|
||||
|
||||
def get_session_by_id_and_user_id(
|
||||
@@ -167,11 +163,7 @@ class OAuthSessionTable:
|
||||
"""Get OAuth session by ID and user ID"""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
session = (
|
||||
db.query(OAuthSession)
|
||||
.filter_by(id=session_id, user_id=user_id)
|
||||
.first()
|
||||
)
|
||||
session = db.query(OAuthSession).filter_by(id=session_id, user_id=user_id).first()
|
||||
if session:
|
||||
db.expunge(session)
|
||||
session.token = self._decrypt_token(session.token)
|
||||
@@ -179,7 +171,7 @@ class OAuthSessionTable:
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth session by ID: {e}")
|
||||
log.error(f'Error getting OAuth session by ID: {e}')
|
||||
return None
|
||||
|
||||
def get_session_by_provider_and_user_id(
|
||||
@@ -201,12 +193,10 @@ class OAuthSessionTable:
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth session by provider and user ID: {e}")
|
||||
log.error(f'Error getting OAuth session by provider and user ID: {e}')
|
||||
return None
|
||||
|
||||
def get_sessions_by_user_id(
|
||||
self, user_id: str, db: Optional[Session] = None
|
||||
) -> List[OAuthSessionModel]:
|
||||
def get_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> List[OAuthSessionModel]:
|
||||
"""Get all OAuth sessions for a user"""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
@@ -220,7 +210,7 @@ class OAuthSessionTable:
|
||||
results.append(OAuthSessionModel.model_validate(session))
|
||||
except Exception as e:
|
||||
log.warning(
|
||||
f"Skipping OAuth session {session.id} due to decryption failure, deleting corrupted session: {type(e).__name__}: {e}"
|
||||
f'Skipping OAuth session {session.id} due to decryption failure, deleting corrupted session: {type(e).__name__}: {e}'
|
||||
)
|
||||
db.query(OAuthSession).filter_by(id=session.id).delete()
|
||||
db.commit()
|
||||
@@ -228,7 +218,7 @@ class OAuthSessionTable:
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
log.error(f"Error getting OAuth sessions by user ID: {e}")
|
||||
log.error(f'Error getting OAuth sessions by user ID: {e}')
|
||||
return []
|
||||
|
||||
def update_session_by_id(
|
||||
@@ -241,9 +231,9 @@ class OAuthSessionTable:
|
||||
|
||||
db.query(OAuthSession).filter_by(id=session_id).update(
|
||||
{
|
||||
"token": self._encrypt_token(token),
|
||||
"expires_at": token.get("expires_at"),
|
||||
"updated_at": current_time,
|
||||
'token': self._encrypt_token(token),
|
||||
'expires_at': token.get('expires_at'),
|
||||
'updated_at': current_time,
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
@@ -256,12 +246,10 @@ class OAuthSessionTable:
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
log.error(f"Error updating OAuth session tokens: {e}")
|
||||
log.error(f'Error updating OAuth session tokens: {e}')
|
||||
return None
|
||||
|
||||
def delete_session_by_id(
|
||||
self, session_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def delete_session_by_id(self, session_id: str, db: Optional[Session] = None) -> bool:
|
||||
"""Delete an OAuth session"""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
@@ -269,12 +257,10 @@ class OAuthSessionTable:
|
||||
db.commit()
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting OAuth session: {e}")
|
||||
log.error(f'Error deleting OAuth session: {e}')
|
||||
return False
|
||||
|
||||
def delete_sessions_by_user_id(
|
||||
self, user_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def delete_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
"""Delete all OAuth sessions for a user"""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
@@ -282,12 +268,10 @@ class OAuthSessionTable:
|
||||
db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting OAuth sessions by user ID: {e}")
|
||||
log.error(f'Error deleting OAuth sessions by user ID: {e}')
|
||||
return False
|
||||
|
||||
def delete_sessions_by_provider(
|
||||
self, provider: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def delete_sessions_by_provider(self, provider: str, db: Optional[Session] = None) -> bool:
|
||||
"""Delete all OAuth sessions for a provider"""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
@@ -295,7 +279,7 @@ class OAuthSessionTable:
|
||||
db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting OAuth sessions by provider {provider}: {e}")
|
||||
log.error(f'Error deleting OAuth sessions by provider {provider}: {e}')
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from sqlalchemy import BigInteger, Column, Text, JSON, Index
|
||||
|
||||
|
||||
class PromptHistory(Base):
|
||||
__tablename__ = "prompt_history"
|
||||
__tablename__ = 'prompt_history'
|
||||
|
||||
id = Column(Text, primary_key=True)
|
||||
prompt_id = Column(Text, nullable=False, index=True)
|
||||
@@ -100,11 +100,7 @@ class PromptHistoryTable:
|
||||
return [
|
||||
PromptHistoryResponse(
|
||||
**PromptHistoryModel.model_validate(entry).model_dump(),
|
||||
user=(
|
||||
users_dict.get(entry.user_id).model_dump()
|
||||
if users_dict.get(entry.user_id)
|
||||
else None
|
||||
),
|
||||
user=(users_dict.get(entry.user_id).model_dump() if users_dict.get(entry.user_id) else None),
|
||||
)
|
||||
for entry in entries
|
||||
]
|
||||
@@ -116,9 +112,7 @@ class PromptHistoryTable:
|
||||
) -> Optional[PromptHistoryModel]:
|
||||
"""Get a specific history entry by ID."""
|
||||
with get_db_context(db) as db:
|
||||
entry = (
|
||||
db.query(PromptHistory).filter(PromptHistory.id == history_id).first()
|
||||
)
|
||||
entry = db.query(PromptHistory).filter(PromptHistory.id == history_id).first()
|
||||
if entry:
|
||||
return PromptHistoryModel.model_validate(entry)
|
||||
return None
|
||||
@@ -147,11 +141,7 @@ class PromptHistoryTable:
|
||||
) -> int:
|
||||
"""Get the number of history entries for a prompt."""
|
||||
with get_db_context(db) as db:
|
||||
return (
|
||||
db.query(PromptHistory)
|
||||
.filter(PromptHistory.prompt_id == prompt_id)
|
||||
.count()
|
||||
)
|
||||
return db.query(PromptHistory).filter(PromptHistory.prompt_id == prompt_id).count()
|
||||
|
||||
def compute_diff(
|
||||
self,
|
||||
@@ -161,9 +151,7 @@ class PromptHistoryTable:
|
||||
) -> Optional[dict]:
|
||||
"""Compute diff between two history entries."""
|
||||
with get_db_context(db) as db:
|
||||
from_entry = (
|
||||
db.query(PromptHistory).filter(PromptHistory.id == from_id).first()
|
||||
)
|
||||
from_entry = db.query(PromptHistory).filter(PromptHistory.id == from_id).first()
|
||||
to_entry = db.query(PromptHistory).filter(PromptHistory.id == to_id).first()
|
||||
|
||||
if not from_entry or not to_entry:
|
||||
@@ -173,26 +161,26 @@ class PromptHistoryTable:
|
||||
to_snapshot = to_entry.snapshot
|
||||
|
||||
# Compute diff for content field
|
||||
from_content = from_snapshot.get("content", "")
|
||||
to_content = to_snapshot.get("content", "")
|
||||
from_content = from_snapshot.get('content', '')
|
||||
to_content = to_snapshot.get('content', '')
|
||||
|
||||
diff_lines = list(
|
||||
difflib.unified_diff(
|
||||
from_content.splitlines(keepends=True),
|
||||
to_content.splitlines(keepends=True),
|
||||
fromfile=f"v{from_id[:8]}",
|
||||
tofile=f"v{to_id[:8]}",
|
||||
lineterm="",
|
||||
fromfile=f'v{from_id[:8]}',
|
||||
tofile=f'v{to_id[:8]}',
|
||||
lineterm='',
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"from_id": from_id,
|
||||
"to_id": to_id,
|
||||
"from_snapshot": from_snapshot,
|
||||
"to_snapshot": to_snapshot,
|
||||
"content_diff": diff_lines,
|
||||
"name_changed": from_snapshot.get("name") != to_snapshot.get("name"),
|
||||
'from_id': from_id,
|
||||
'to_id': to_id,
|
||||
'from_snapshot': from_snapshot,
|
||||
'to_snapshot': to_snapshot,
|
||||
'content_diff': diff_lines,
|
||||
'name_changed': from_snapshot.get('name') != to_snapshot.get('name'),
|
||||
}
|
||||
|
||||
def delete_history_by_prompt_id(
|
||||
@@ -202,9 +190,7 @@ class PromptHistoryTable:
|
||||
) -> bool:
|
||||
"""Delete all history entries for a prompt."""
|
||||
with get_db_context(db) as db:
|
||||
db.query(PromptHistory).filter(
|
||||
PromptHistory.prompt_id == prompt_id
|
||||
).delete()
|
||||
db.query(PromptHistory).filter(PromptHistory.prompt_id == prompt_id).delete()
|
||||
db.commit()
|
||||
return True
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, or_, fun
|
||||
|
||||
|
||||
class Prompt(Base):
|
||||
__tablename__ = "prompt"
|
||||
__tablename__ = 'prompt'
|
||||
|
||||
id = Column(Text, primary_key=True)
|
||||
command = Column(String, unique=True, index=True)
|
||||
@@ -77,7 +77,6 @@ class PromptAccessListResponse(BaseModel):
|
||||
|
||||
|
||||
class PromptForm(BaseModel):
|
||||
|
||||
command: str
|
||||
name: str # Changed from title
|
||||
content: str
|
||||
@@ -91,10 +90,8 @@ class PromptForm(BaseModel):
|
||||
|
||||
|
||||
class PromptsTable:
|
||||
def _get_access_grants(
|
||||
self, prompt_id: str, db: Optional[Session] = None
|
||||
) -> list[AccessGrantModel]:
|
||||
return AccessGrants.get_grants_by_resource("prompt", prompt_id, db=db)
|
||||
def _get_access_grants(self, prompt_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
|
||||
return AccessGrants.get_grants_by_resource('prompt', prompt_id, db=db)
|
||||
|
||||
def _to_prompt_model(
|
||||
self,
|
||||
@@ -102,13 +99,9 @@ class PromptsTable:
|
||||
access_grants: Optional[list[AccessGrantModel]] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> PromptModel:
|
||||
prompt_data = PromptModel.model_validate(prompt).model_dump(
|
||||
exclude={"access_grants"}
|
||||
)
|
||||
prompt_data["access_grants"] = (
|
||||
access_grants
|
||||
if access_grants is not None
|
||||
else self._get_access_grants(prompt_data["id"], db=db)
|
||||
prompt_data = PromptModel.model_validate(prompt).model_dump(exclude={'access_grants'})
|
||||
prompt_data['access_grants'] = (
|
||||
access_grants if access_grants is not None else self._get_access_grants(prompt_data['id'], db=db)
|
||||
)
|
||||
return PromptModel.model_validate(prompt_data)
|
||||
|
||||
@@ -135,26 +128,22 @@ class PromptsTable:
|
||||
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
result = Prompt(**prompt.model_dump(exclude={"access_grants"}))
|
||||
result = Prompt(**prompt.model_dump(exclude={'access_grants'}))
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
AccessGrants.set_access_grants(
|
||||
"prompt", prompt_id, form_data.access_grants, db=db
|
||||
)
|
||||
AccessGrants.set_access_grants('prompt', prompt_id, form_data.access_grants, db=db)
|
||||
|
||||
if result:
|
||||
current_access_grants = self._get_access_grants(prompt_id, db=db)
|
||||
snapshot = {
|
||||
"name": form_data.name,
|
||||
"content": form_data.content,
|
||||
"command": form_data.command,
|
||||
"data": form_data.data or {},
|
||||
"meta": form_data.meta or {},
|
||||
"tags": form_data.tags or [],
|
||||
"access_grants": [
|
||||
grant.model_dump() for grant in current_access_grants
|
||||
],
|
||||
'name': form_data.name,
|
||||
'content': form_data.content,
|
||||
'command': form_data.command,
|
||||
'data': form_data.data or {},
|
||||
'meta': form_data.meta or {},
|
||||
'tags': form_data.tags or [],
|
||||
'access_grants': [grant.model_dump() for grant in current_access_grants],
|
||||
}
|
||||
|
||||
history_entry = PromptHistories.create_history_entry(
|
||||
@@ -162,7 +151,7 @@ class PromptsTable:
|
||||
snapshot=snapshot,
|
||||
user_id=user_id,
|
||||
parent_id=None, # Initial commit has no parent
|
||||
commit_message=form_data.commit_message or "Initial version",
|
||||
commit_message=form_data.commit_message or 'Initial version',
|
||||
db=db,
|
||||
)
|
||||
|
||||
@@ -178,9 +167,7 @@ class PromptsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_prompt_by_id(
|
||||
self, prompt_id: str, db: Optional[Session] = None
|
||||
) -> Optional[PromptModel]:
|
||||
def get_prompt_by_id(self, prompt_id: str, db: Optional[Session] = None) -> Optional[PromptModel]:
|
||||
"""Get prompt by UUID."""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
@@ -191,9 +178,7 @@ class PromptsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_prompt_by_command(
|
||||
self, command: str, db: Optional[Session] = None
|
||||
) -> Optional[PromptModel]:
|
||||
def get_prompt_by_command(self, command: str, db: Optional[Session] = None) -> Optional[PromptModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
prompt = db.query(Prompt).filter_by(command=command).first()
|
||||
@@ -205,21 +190,14 @@ class PromptsTable:
|
||||
|
||||
def get_prompts(self, db: Optional[Session] = None) -> list[PromptUserResponse]:
|
||||
with get_db_context(db) as db:
|
||||
all_prompts = (
|
||||
db.query(Prompt)
|
||||
.filter(Prompt.is_active == True)
|
||||
.order_by(Prompt.updated_at.desc())
|
||||
.all()
|
||||
)
|
||||
all_prompts = db.query(Prompt).filter(Prompt.is_active == True).order_by(Prompt.updated_at.desc()).all()
|
||||
|
||||
user_ids = list(set(prompt.user_id for prompt in all_prompts))
|
||||
prompt_ids = [prompt.id for prompt in all_prompts]
|
||||
|
||||
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
||||
users_dict = {user.id: user for user in users}
|
||||
grants_map = AccessGrants.get_grants_by_resources(
|
||||
"prompt", prompt_ids, db=db
|
||||
)
|
||||
grants_map = AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db)
|
||||
|
||||
prompts = []
|
||||
for prompt in all_prompts:
|
||||
@@ -232,7 +210,7 @@ class PromptsTable:
|
||||
access_grants=grants_map.get(prompt.id, []),
|
||||
db=db,
|
||||
).model_dump(),
|
||||
"user": user.model_dump() if user else None,
|
||||
'user': user.model_dump() if user else None,
|
||||
}
|
||||
)
|
||||
)
|
||||
@@ -240,12 +218,10 @@ class PromptsTable:
|
||||
return prompts
|
||||
|
||||
def get_prompts_by_user_id(
|
||||
self, user_id: str, permission: str = "write", db: Optional[Session] = None
|
||||
self, user_id: str, permission: str = 'write', db: Optional[Session] = None
|
||||
) -> list[PromptUserResponse]:
|
||||
prompts = self.get_prompts(db=db)
|
||||
user_group_ids = {
|
||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
||||
}
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
|
||||
|
||||
return [
|
||||
prompt
|
||||
@@ -253,7 +229,7 @@ class PromptsTable:
|
||||
if prompt.user_id == user_id
|
||||
or AccessGrants.has_access(
|
||||
user_id=user_id,
|
||||
resource_type="prompt",
|
||||
resource_type='prompt',
|
||||
resource_id=prompt.id,
|
||||
permission=permission,
|
||||
user_group_ids=user_group_ids,
|
||||
@@ -276,22 +252,22 @@ class PromptsTable:
|
||||
query = db.query(Prompt, User).outerjoin(User, User.id == Prompt.user_id)
|
||||
|
||||
if filter:
|
||||
query_key = filter.get("query")
|
||||
query_key = filter.get('query')
|
||||
if query_key:
|
||||
query = query.filter(
|
||||
or_(
|
||||
Prompt.name.ilike(f"%{query_key}%"),
|
||||
Prompt.command.ilike(f"%{query_key}%"),
|
||||
Prompt.content.ilike(f"%{query_key}%"),
|
||||
User.name.ilike(f"%{query_key}%"),
|
||||
User.email.ilike(f"%{query_key}%"),
|
||||
Prompt.name.ilike(f'%{query_key}%'),
|
||||
Prompt.command.ilike(f'%{query_key}%'),
|
||||
Prompt.content.ilike(f'%{query_key}%'),
|
||||
User.name.ilike(f'%{query_key}%'),
|
||||
User.email.ilike(f'%{query_key}%'),
|
||||
)
|
||||
)
|
||||
|
||||
view_option = filter.get("view_option")
|
||||
if view_option == "created":
|
||||
view_option = filter.get('view_option')
|
||||
if view_option == 'created':
|
||||
query = query.filter(Prompt.user_id == user_id)
|
||||
elif view_option == "shared":
|
||||
elif view_option == 'shared':
|
||||
query = query.filter(Prompt.user_id != user_id)
|
||||
|
||||
# Apply access grant filtering
|
||||
@@ -300,32 +276,32 @@ class PromptsTable:
|
||||
query=query,
|
||||
DocumentModel=Prompt,
|
||||
filter=filter,
|
||||
resource_type="prompt",
|
||||
permission="read",
|
||||
resource_type='prompt',
|
||||
permission='read',
|
||||
)
|
||||
|
||||
tag = filter.get("tag")
|
||||
tag = filter.get('tag')
|
||||
if tag:
|
||||
# Search for tag in JSON array field
|
||||
like_pattern = f'%"{tag.lower()}"%'
|
||||
tags_text = func.lower(cast(Prompt.tags, String))
|
||||
query = query.filter(tags_text.like(like_pattern))
|
||||
|
||||
order_by = filter.get("order_by")
|
||||
direction = filter.get("direction")
|
||||
order_by = filter.get('order_by')
|
||||
direction = filter.get('direction')
|
||||
|
||||
if order_by == "name":
|
||||
if direction == "asc":
|
||||
if order_by == 'name':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(Prompt.name.asc())
|
||||
else:
|
||||
query = query.order_by(Prompt.name.desc())
|
||||
elif order_by == "created_at":
|
||||
if direction == "asc":
|
||||
elif order_by == 'created_at':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(Prompt.created_at.asc())
|
||||
else:
|
||||
query = query.order_by(Prompt.created_at.desc())
|
||||
elif order_by == "updated_at":
|
||||
if direction == "asc":
|
||||
elif order_by == 'updated_at':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(Prompt.updated_at.asc())
|
||||
else:
|
||||
query = query.order_by(Prompt.updated_at.desc())
|
||||
@@ -345,9 +321,7 @@ class PromptsTable:
|
||||
items = query.all()
|
||||
|
||||
prompt_ids = [prompt.id for prompt, _ in items]
|
||||
grants_map = AccessGrants.get_grants_by_resources(
|
||||
"prompt", prompt_ids, db=db
|
||||
)
|
||||
grants_map = AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db)
|
||||
|
||||
prompts = []
|
||||
for prompt, user in items:
|
||||
@@ -358,11 +332,7 @@ class PromptsTable:
|
||||
access_grants=grants_map.get(prompt.id, []),
|
||||
db=db,
|
||||
).model_dump(),
|
||||
user=(
|
||||
UserResponse(**UserModel.model_validate(user).model_dump())
|
||||
if user
|
||||
else None
|
||||
),
|
||||
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -381,9 +351,7 @@ class PromptsTable:
|
||||
if not prompt:
|
||||
return None
|
||||
|
||||
latest_history = PromptHistories.get_latest_history_entry(
|
||||
prompt.id, db=db
|
||||
)
|
||||
latest_history = PromptHistories.get_latest_history_entry(prompt.id, db=db)
|
||||
parent_id = latest_history.id if latest_history else None
|
||||
current_access_grants = self._get_access_grants(prompt.id, db=db)
|
||||
|
||||
@@ -401,9 +369,7 @@ class PromptsTable:
|
||||
prompt.meta = form_data.meta or prompt.meta
|
||||
prompt.updated_at = int(time.time())
|
||||
if form_data.access_grants is not None:
|
||||
AccessGrants.set_access_grants(
|
||||
"prompt", prompt.id, form_data.access_grants, db=db
|
||||
)
|
||||
AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=db)
|
||||
current_access_grants = self._get_access_grants(prompt.id, db=db)
|
||||
|
||||
db.commit()
|
||||
@@ -411,14 +377,12 @@ class PromptsTable:
|
||||
# Create history entry only if content changed
|
||||
if content_changed:
|
||||
snapshot = {
|
||||
"name": form_data.name,
|
||||
"content": form_data.content,
|
||||
"command": command,
|
||||
"data": form_data.data or {},
|
||||
"meta": form_data.meta or {},
|
||||
"access_grants": [
|
||||
grant.model_dump() for grant in current_access_grants
|
||||
],
|
||||
'name': form_data.name,
|
||||
'content': form_data.content,
|
||||
'command': command,
|
||||
'data': form_data.data or {},
|
||||
'meta': form_data.meta or {},
|
||||
'access_grants': [grant.model_dump() for grant in current_access_grants],
|
||||
}
|
||||
|
||||
history_entry = PromptHistories.create_history_entry(
|
||||
@@ -452,9 +416,7 @@ class PromptsTable:
|
||||
if not prompt:
|
||||
return None
|
||||
|
||||
latest_history = PromptHistories.get_latest_history_entry(
|
||||
prompt.id, db=db
|
||||
)
|
||||
latest_history = PromptHistories.get_latest_history_entry(prompt.id, db=db)
|
||||
parent_id = latest_history.id if latest_history else None
|
||||
current_access_grants = self._get_access_grants(prompt.id, db=db)
|
||||
|
||||
@@ -478,9 +440,7 @@ class PromptsTable:
|
||||
prompt.tags = form_data.tags
|
||||
|
||||
if form_data.access_grants is not None:
|
||||
AccessGrants.set_access_grants(
|
||||
"prompt", prompt.id, form_data.access_grants, db=db
|
||||
)
|
||||
AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=db)
|
||||
current_access_grants = self._get_access_grants(prompt.id, db=db)
|
||||
|
||||
prompt.updated_at = int(time.time())
|
||||
@@ -490,15 +450,13 @@ class PromptsTable:
|
||||
# Create history entry only if content changed
|
||||
if content_changed:
|
||||
snapshot = {
|
||||
"name": form_data.name,
|
||||
"content": form_data.content,
|
||||
"command": prompt.command,
|
||||
"data": form_data.data or {},
|
||||
"meta": form_data.meta or {},
|
||||
"tags": prompt.tags or [],
|
||||
"access_grants": [
|
||||
grant.model_dump() for grant in current_access_grants
|
||||
],
|
||||
'name': form_data.name,
|
||||
'content': form_data.content,
|
||||
'command': prompt.command,
|
||||
'data': form_data.data or {},
|
||||
'meta': form_data.meta or {},
|
||||
'tags': prompt.tags or [],
|
||||
'access_grants': [grant.model_dump() for grant in current_access_grants],
|
||||
}
|
||||
|
||||
history_entry = PromptHistories.create_history_entry(
|
||||
@@ -560,9 +518,7 @@ class PromptsTable:
|
||||
if not prompt:
|
||||
return None
|
||||
|
||||
history_entry = PromptHistories.get_history_entry_by_id(
|
||||
version_id, db=db
|
||||
)
|
||||
history_entry = PromptHistories.get_history_entry_by_id(version_id, db=db)
|
||||
|
||||
if not history_entry:
|
||||
return None
|
||||
@@ -570,11 +526,11 @@ class PromptsTable:
|
||||
# Restore prompt content from the snapshot
|
||||
snapshot = history_entry.snapshot
|
||||
if snapshot:
|
||||
prompt.name = snapshot.get("name", prompt.name)
|
||||
prompt.content = snapshot.get("content", prompt.content)
|
||||
prompt.data = snapshot.get("data", prompt.data)
|
||||
prompt.meta = snapshot.get("meta", prompt.meta)
|
||||
prompt.tags = snapshot.get("tags", prompt.tags)
|
||||
prompt.name = snapshot.get('name', prompt.name)
|
||||
prompt.content = snapshot.get('content', prompt.content)
|
||||
prompt.data = snapshot.get('data', prompt.data)
|
||||
prompt.meta = snapshot.get('meta', prompt.meta)
|
||||
prompt.tags = snapshot.get('tags', prompt.tags)
|
||||
# Note: command and access_grants are not restored from snapshot
|
||||
|
||||
prompt.version_id = version_id
|
||||
@@ -585,9 +541,7 @@ class PromptsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def toggle_prompt_active(
|
||||
self, prompt_id: str, db: Optional[Session] = None
|
||||
) -> Optional[PromptModel]:
|
||||
def toggle_prompt_active(self, prompt_id: str, db: Optional[Session] = None) -> Optional[PromptModel]:
|
||||
"""Toggle the is_active flag on a prompt."""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
@@ -602,16 +556,14 @@ class PromptsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def delete_prompt_by_command(
|
||||
self, command: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def delete_prompt_by_command(self, command: str, db: Optional[Session] = None) -> bool:
|
||||
"""Permanently delete a prompt and its history."""
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
prompt = db.query(Prompt).filter_by(command=command).first()
|
||||
if prompt:
|
||||
PromptHistories.delete_history_by_prompt_id(prompt.id, db=db)
|
||||
AccessGrants.revoke_all_access("prompt", prompt.id, db=db)
|
||||
AccessGrants.revoke_all_access('prompt', prompt.id, db=db)
|
||||
|
||||
db.delete(prompt)
|
||||
db.commit()
|
||||
@@ -627,7 +579,7 @@ class PromptsTable:
|
||||
prompt = db.query(Prompt).filter_by(id=prompt_id).first()
|
||||
if prompt:
|
||||
PromptHistories.delete_history_by_prompt_id(prompt.id, db=db)
|
||||
AccessGrants.revoke_all_access("prompt", prompt.id, db=db)
|
||||
AccessGrants.revoke_all_access('prompt', prompt.id, db=db)
|
||||
|
||||
db.delete(prompt)
|
||||
db.commit()
|
||||
|
||||
@@ -19,7 +19,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Skill(Base):
|
||||
__tablename__ = "skill"
|
||||
__tablename__ = 'skill'
|
||||
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
user_id = Column(String)
|
||||
@@ -77,7 +77,7 @@ class SkillResponse(BaseModel):
|
||||
class SkillUserResponse(SkillResponse):
|
||||
user: Optional[UserResponse] = None
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
|
||||
class SkillAccessResponse(SkillUserResponse):
|
||||
@@ -105,10 +105,8 @@ class SkillAccessListResponse(BaseModel):
|
||||
|
||||
|
||||
class SkillsTable:
|
||||
def _get_access_grants(
|
||||
self, skill_id: str, db: Optional[Session] = None
|
||||
) -> list[AccessGrantModel]:
|
||||
return AccessGrants.get_grants_by_resource("skill", skill_id, db=db)
|
||||
def _get_access_grants(self, skill_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
|
||||
return AccessGrants.get_grants_by_resource('skill', skill_id, db=db)
|
||||
|
||||
def _to_skill_model(
|
||||
self,
|
||||
@@ -116,13 +114,9 @@ class SkillsTable:
|
||||
access_grants: Optional[list[AccessGrantModel]] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> SkillModel:
|
||||
skill_data = SkillModel.model_validate(skill).model_dump(
|
||||
exclude={"access_grants"}
|
||||
)
|
||||
skill_data["access_grants"] = (
|
||||
access_grants
|
||||
if access_grants is not None
|
||||
else self._get_access_grants(skill_data["id"], db=db)
|
||||
skill_data = SkillModel.model_validate(skill).model_dump(exclude={'access_grants'})
|
||||
skill_data['access_grants'] = (
|
||||
access_grants if access_grants is not None else self._get_access_grants(skill_data['id'], db=db)
|
||||
)
|
||||
return SkillModel.model_validate(skill_data)
|
||||
|
||||
@@ -136,29 +130,25 @@ class SkillsTable:
|
||||
try:
|
||||
result = Skill(
|
||||
**{
|
||||
**form_data.model_dump(exclude={"access_grants"}),
|
||||
"user_id": user_id,
|
||||
"updated_at": int(time.time()),
|
||||
"created_at": int(time.time()),
|
||||
**form_data.model_dump(exclude={'access_grants'}),
|
||||
'user_id': user_id,
|
||||
'updated_at': int(time.time()),
|
||||
'created_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
AccessGrants.set_access_grants(
|
||||
"skill", result.id, form_data.access_grants, db=db
|
||||
)
|
||||
AccessGrants.set_access_grants('skill', result.id, form_data.access_grants, db=db)
|
||||
if result:
|
||||
return self._to_skill_model(result, db=db)
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.exception(f"Error creating a new skill: {e}")
|
||||
log.exception(f'Error creating a new skill: {e}')
|
||||
return None
|
||||
|
||||
def get_skill_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[SkillModel]:
|
||||
def get_skill_by_id(self, id: str, db: Optional[Session] = None) -> Optional[SkillModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
skill = db.get(Skill, id)
|
||||
@@ -166,9 +156,7 @@ class SkillsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_skill_by_name(
|
||||
self, name: str, db: Optional[Session] = None
|
||||
) -> Optional[SkillModel]:
|
||||
def get_skill_by_name(self, name: str, db: Optional[Session] = None) -> Optional[SkillModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
skill = db.query(Skill).filter_by(name=name).first()
|
||||
@@ -185,7 +173,7 @@ class SkillsTable:
|
||||
|
||||
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
||||
users_dict = {user.id: user for user in users}
|
||||
grants_map = AccessGrants.get_grants_by_resources("skill", skill_ids, db=db)
|
||||
grants_map = AccessGrants.get_grants_by_resources('skill', skill_ids, db=db)
|
||||
|
||||
skills = []
|
||||
for skill in all_skills:
|
||||
@@ -198,19 +186,17 @@ class SkillsTable:
|
||||
access_grants=grants_map.get(skill.id, []),
|
||||
db=db,
|
||||
).model_dump(),
|
||||
"user": user.model_dump() if user else None,
|
||||
'user': user.model_dump() if user else None,
|
||||
}
|
||||
)
|
||||
)
|
||||
return skills
|
||||
|
||||
def get_skills_by_user_id(
|
||||
self, user_id: str, permission: str = "write", db: Optional[Session] = None
|
||||
self, user_id: str, permission: str = 'write', db: Optional[Session] = None
|
||||
) -> list[SkillUserModel]:
|
||||
skills = self.get_skills(db=db)
|
||||
user_group_ids = {
|
||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
||||
}
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
|
||||
|
||||
return [
|
||||
skill
|
||||
@@ -218,7 +204,7 @@ class SkillsTable:
|
||||
if skill.user_id == user_id
|
||||
or AccessGrants.has_access(
|
||||
user_id=user_id,
|
||||
resource_type="skill",
|
||||
resource_type='skill',
|
||||
resource_id=skill.id,
|
||||
permission=permission,
|
||||
user_group_ids=user_group_ids,
|
||||
@@ -242,22 +228,22 @@ class SkillsTable:
|
||||
query = db.query(Skill, User).outerjoin(User, User.id == Skill.user_id)
|
||||
|
||||
if filter:
|
||||
query_key = filter.get("query")
|
||||
query_key = filter.get('query')
|
||||
if query_key:
|
||||
query = query.filter(
|
||||
or_(
|
||||
Skill.name.ilike(f"%{query_key}%"),
|
||||
Skill.description.ilike(f"%{query_key}%"),
|
||||
Skill.id.ilike(f"%{query_key}%"),
|
||||
User.name.ilike(f"%{query_key}%"),
|
||||
User.email.ilike(f"%{query_key}%"),
|
||||
Skill.name.ilike(f'%{query_key}%'),
|
||||
Skill.description.ilike(f'%{query_key}%'),
|
||||
Skill.id.ilike(f'%{query_key}%'),
|
||||
User.name.ilike(f'%{query_key}%'),
|
||||
User.email.ilike(f'%{query_key}%'),
|
||||
)
|
||||
)
|
||||
|
||||
view_option = filter.get("view_option")
|
||||
if view_option == "created":
|
||||
view_option = filter.get('view_option')
|
||||
if view_option == 'created':
|
||||
query = query.filter(Skill.user_id == user_id)
|
||||
elif view_option == "shared":
|
||||
elif view_option == 'shared':
|
||||
query = query.filter(Skill.user_id != user_id)
|
||||
|
||||
# Apply access grant filtering
|
||||
@@ -266,8 +252,8 @@ class SkillsTable:
|
||||
query=query,
|
||||
DocumentModel=Skill,
|
||||
filter=filter,
|
||||
resource_type="skill",
|
||||
permission="read",
|
||||
resource_type='skill',
|
||||
permission='read',
|
||||
)
|
||||
|
||||
query = query.order_by(Skill.updated_at.desc())
|
||||
@@ -283,9 +269,7 @@ class SkillsTable:
|
||||
items = query.all()
|
||||
|
||||
skill_ids = [skill.id for skill, _ in items]
|
||||
grants_map = AccessGrants.get_grants_by_resources(
|
||||
"skill", skill_ids, db=db
|
||||
)
|
||||
grants_map = AccessGrants.get_grants_by_resources('skill', skill_ids, db=db)
|
||||
|
||||
skills = []
|
||||
for skill, user in items:
|
||||
@@ -296,33 +280,23 @@ class SkillsTable:
|
||||
access_grants=grants_map.get(skill.id, []),
|
||||
db=db,
|
||||
).model_dump(),
|
||||
user=(
|
||||
UserResponse(
|
||||
**UserModel.model_validate(user).model_dump()
|
||||
)
|
||||
if user
|
||||
else None
|
||||
),
|
||||
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
|
||||
)
|
||||
)
|
||||
|
||||
return SkillListResponse(items=skills, total=total)
|
||||
except Exception as e:
|
||||
log.exception(f"Error searching skills: {e}")
|
||||
log.exception(f'Error searching skills: {e}')
|
||||
return SkillListResponse(items=[], total=0)
|
||||
|
||||
def update_skill_by_id(
|
||||
self, id: str, updated: dict, db: Optional[Session] = None
|
||||
) -> Optional[SkillModel]:
|
||||
def update_skill_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[SkillModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
access_grants = updated.pop("access_grants", None)
|
||||
db.query(Skill).filter_by(id=id).update(
|
||||
{**updated, "updated_at": int(time.time())}
|
||||
)
|
||||
access_grants = updated.pop('access_grants', None)
|
||||
db.query(Skill).filter_by(id=id).update({**updated, 'updated_at': int(time.time())})
|
||||
db.commit()
|
||||
if access_grants is not None:
|
||||
AccessGrants.set_access_grants("skill", id, access_grants, db=db)
|
||||
AccessGrants.set_access_grants('skill', id, access_grants, db=db)
|
||||
|
||||
skill = db.query(Skill).get(id)
|
||||
db.refresh(skill)
|
||||
@@ -330,9 +304,7 @@ class SkillsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def toggle_skill_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[SkillModel]:
|
||||
def toggle_skill_by_id(self, id: str, db: Optional[Session] = None) -> Optional[SkillModel]:
|
||||
with get_db_context(db) as db:
|
||||
try:
|
||||
skill = db.query(Skill).filter_by(id=id).first()
|
||||
@@ -351,7 +323,7 @@ class SkillsTable:
|
||||
def delete_skill_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
AccessGrants.revoke_all_access("skill", id, db=db)
|
||||
AccessGrants.revoke_all_access('skill', id, db=db)
|
||||
db.query(Skill).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
|
||||
|
||||
@@ -17,19 +17,19 @@ log = logging.getLogger(__name__)
|
||||
# Tag DB Schema
|
||||
####################
|
||||
class Tag(Base):
|
||||
__tablename__ = "tag"
|
||||
__tablename__ = 'tag'
|
||||
id = Column(String)
|
||||
name = Column(String)
|
||||
user_id = Column(String)
|
||||
meta = Column(JSON, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),
|
||||
Index("user_id_idx", "user_id"),
|
||||
PrimaryKeyConstraint('id', 'user_id', name='pk_id_user_id'),
|
||||
Index('user_id_idx', 'user_id'),
|
||||
)
|
||||
|
||||
# Unique constraint ensuring (id, user_id) is unique, not just the `id` column
|
||||
__table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),)
|
||||
__table_args__ = (PrimaryKeyConstraint('id', 'user_id', name='pk_id_user_id'),)
|
||||
|
||||
|
||||
class TagModel(BaseModel):
|
||||
@@ -51,12 +51,10 @@ class TagChatIdForm(BaseModel):
|
||||
|
||||
|
||||
class TagTable:
|
||||
def insert_new_tag(
|
||||
self, name: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[TagModel]:
|
||||
def insert_new_tag(self, name: str, user_id: str, db: Optional[Session] = None) -> Optional[TagModel]:
|
||||
with get_db_context(db) as db:
|
||||
id = name.replace(" ", "_").lower()
|
||||
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
|
||||
id = name.replace(' ', '_').lower()
|
||||
tag = TagModel(**{'id': id, 'user_id': user_id, 'name': name})
|
||||
try:
|
||||
result = Tag(**tag.model_dump())
|
||||
db.add(result)
|
||||
@@ -67,89 +65,63 @@ class TagTable:
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.exception(f"Error inserting a new tag: {e}")
|
||||
log.exception(f'Error inserting a new tag: {e}')
|
||||
return None
|
||||
|
||||
def get_tag_by_name_and_user_id(
|
||||
self, name: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[TagModel]:
|
||||
def get_tag_by_name_and_user_id(self, name: str, user_id: str, db: Optional[Session] = None) -> Optional[TagModel]:
|
||||
try:
|
||||
id = name.replace(" ", "_").lower()
|
||||
id = name.replace(' ', '_').lower()
|
||||
with get_db_context(db) as db:
|
||||
tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
|
||||
return TagModel.model_validate(tag)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_tags_by_user_id(
|
||||
self, user_id: str, db: Optional[Session] = None
|
||||
) -> list[TagModel]:
|
||||
def get_tags_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[TagModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [TagModel.model_validate(tag) for tag in (db.query(Tag).filter_by(user_id=user_id).all())]
|
||||
|
||||
def get_tags_by_ids_and_user_id(self, ids: list[str], user_id: str, db: Optional[Session] = None) -> list[TagModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
TagModel.model_validate(tag)
|
||||
for tag in (db.query(Tag).filter_by(user_id=user_id).all())
|
||||
for tag in (db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all())
|
||||
]
|
||||
|
||||
def get_tags_by_ids_and_user_id(
|
||||
self, ids: list[str], user_id: str, db: Optional[Session] = None
|
||||
) -> list[TagModel]:
|
||||
with get_db_context(db) as db:
|
||||
return [
|
||||
TagModel.model_validate(tag)
|
||||
for tag in (
|
||||
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()
|
||||
)
|
||||
]
|
||||
|
||||
def delete_tag_by_name_and_user_id(
|
||||
self, name: str, user_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def delete_tag_by_name_and_user_id(self, name: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
id = name.replace(" ", "_").lower()
|
||||
id = name.replace(' ', '_').lower()
|
||||
res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
|
||||
log.debug(f"res: {res}")
|
||||
log.debug(f'res: {res}')
|
||||
db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"delete_tag: {e}")
|
||||
log.error(f'delete_tag: {e}')
|
||||
return False
|
||||
|
||||
def delete_tags_by_ids_and_user_id(
|
||||
self, ids: list[str], user_id: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def delete_tags_by_ids_and_user_id(self, ids: list[str], user_id: str, db: Optional[Session] = None) -> bool:
|
||||
"""Delete all tags whose id is in *ids* for the given user, in one query."""
|
||||
if not ids:
|
||||
return True
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).delete(
|
||||
synchronize_session=False
|
||||
)
|
||||
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).delete(synchronize_session=False)
|
||||
db.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
log.error(f"delete_tags_by_ids: {e}")
|
||||
log.error(f'delete_tags_by_ids: {e}')
|
||||
return False
|
||||
|
||||
def ensure_tags_exist(
|
||||
self, names: list[str], user_id: str, db: Optional[Session] = None
|
||||
) -> None:
|
||||
def ensure_tags_exist(self, names: list[str], user_id: str, db: Optional[Session] = None) -> None:
|
||||
"""Create tag rows for any *names* that don't already exist for *user_id*."""
|
||||
if not names:
|
||||
return
|
||||
ids = [n.replace(" ", "_").lower() for n in names]
|
||||
ids = [n.replace(' ', '_').lower() for n in names]
|
||||
with get_db_context(db) as db:
|
||||
existing = {
|
||||
t.id
|
||||
for t in db.query(Tag.id)
|
||||
.filter(Tag.id.in_(ids), Tag.user_id == user_id)
|
||||
.all()
|
||||
}
|
||||
existing = {t.id for t in db.query(Tag.id).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()}
|
||||
new_tags = [
|
||||
Tag(id=tag_id, name=name, user_id=user_id)
|
||||
for tag_id, name in zip(ids, names)
|
||||
if tag_id not in existing
|
||||
Tag(id=tag_id, name=name, user_id=user_id) for tag_id, name in zip(ids, names) if tag_id not in existing
|
||||
]
|
||||
if new_tags:
|
||||
db.add_all(new_tags)
|
||||
|
||||
@@ -19,7 +19,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Tool(Base):
|
||||
__tablename__ = "tool"
|
||||
__tablename__ = 'tool'
|
||||
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
user_id = Column(String)
|
||||
@@ -75,7 +75,7 @@ class ToolResponse(BaseModel):
|
||||
class ToolUserResponse(ToolResponse):
|
||||
user: Optional[UserResponse] = None
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
|
||||
class ToolAccessResponse(ToolUserResponse):
|
||||
@@ -95,10 +95,8 @@ class ToolValves(BaseModel):
|
||||
|
||||
|
||||
class ToolsTable:
|
||||
def _get_access_grants(
|
||||
self, tool_id: str, db: Optional[Session] = None
|
||||
) -> list[AccessGrantModel]:
|
||||
return AccessGrants.get_grants_by_resource("tool", tool_id, db=db)
|
||||
def _get_access_grants(self, tool_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
|
||||
return AccessGrants.get_grants_by_resource('tool', tool_id, db=db)
|
||||
|
||||
def _to_tool_model(
|
||||
self,
|
||||
@@ -106,11 +104,9 @@ class ToolsTable:
|
||||
access_grants: Optional[list[AccessGrantModel]] = None,
|
||||
db: Optional[Session] = None,
|
||||
) -> ToolModel:
|
||||
tool_data = ToolModel.model_validate(tool).model_dump(exclude={"access_grants"})
|
||||
tool_data["access_grants"] = (
|
||||
access_grants
|
||||
if access_grants is not None
|
||||
else self._get_access_grants(tool_data["id"], db=db)
|
||||
tool_data = ToolModel.model_validate(tool).model_dump(exclude={'access_grants'})
|
||||
tool_data['access_grants'] = (
|
||||
access_grants if access_grants is not None else self._get_access_grants(tool_data['id'], db=db)
|
||||
)
|
||||
return ToolModel.model_validate(tool_data)
|
||||
|
||||
@@ -125,30 +121,26 @@ class ToolsTable:
|
||||
try:
|
||||
result = Tool(
|
||||
**{
|
||||
**form_data.model_dump(exclude={"access_grants"}),
|
||||
"specs": specs,
|
||||
"user_id": user_id,
|
||||
"updated_at": int(time.time()),
|
||||
"created_at": int(time.time()),
|
||||
**form_data.model_dump(exclude={'access_grants'}),
|
||||
'specs': specs,
|
||||
'user_id': user_id,
|
||||
'updated_at': int(time.time()),
|
||||
'created_at': int(time.time()),
|
||||
}
|
||||
)
|
||||
db.add(result)
|
||||
db.commit()
|
||||
db.refresh(result)
|
||||
AccessGrants.set_access_grants(
|
||||
"tool", result.id, form_data.access_grants, db=db
|
||||
)
|
||||
AccessGrants.set_access_grants('tool', result.id, form_data.access_grants, db=db)
|
||||
if result:
|
||||
return self._to_tool_model(result, db=db)
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.exception(f"Error creating a new tool: {e}")
|
||||
log.exception(f'Error creating a new tool: {e}')
|
||||
return None
|
||||
|
||||
def get_tool_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[ToolModel]:
|
||||
def get_tool_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ToolModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
tool = db.get(Tool, id)
|
||||
@@ -156,9 +148,7 @@ class ToolsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_tools(
|
||||
self, defer_content: bool = False, db: Optional[Session] = None
|
||||
) -> list[ToolUserModel]:
|
||||
def get_tools(self, defer_content: bool = False, db: Optional[Session] = None) -> list[ToolUserModel]:
|
||||
with get_db_context(db) as db:
|
||||
query = db.query(Tool).order_by(Tool.updated_at.desc())
|
||||
if defer_content:
|
||||
@@ -170,7 +160,7 @@ class ToolsTable:
|
||||
|
||||
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
||||
users_dict = {user.id: user for user in users}
|
||||
grants_map = AccessGrants.get_grants_by_resources("tool", tool_ids, db=db)
|
||||
grants_map = AccessGrants.get_grants_by_resources('tool', tool_ids, db=db)
|
||||
|
||||
tools = []
|
||||
for tool in all_tools:
|
||||
@@ -183,7 +173,7 @@ class ToolsTable:
|
||||
access_grants=grants_map.get(tool.id, []),
|
||||
db=db,
|
||||
).model_dump(),
|
||||
"user": user.model_dump() if user else None,
|
||||
'user': user.model_dump() if user else None,
|
||||
}
|
||||
)
|
||||
)
|
||||
@@ -192,14 +182,12 @@ class ToolsTable:
|
||||
def get_tools_by_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
permission: str = "write",
|
||||
permission: str = 'write',
|
||||
defer_content: bool = False,
|
||||
db: Optional[Session] = None,
|
||||
) -> list[ToolUserModel]:
|
||||
tools = self.get_tools(defer_content=defer_content, db=db)
|
||||
user_group_ids = {
|
||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
||||
}
|
||||
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
|
||||
|
||||
return [
|
||||
tool
|
||||
@@ -207,7 +195,7 @@ class ToolsTable:
|
||||
if tool.user_id == user_id
|
||||
or AccessGrants.has_access(
|
||||
user_id=user_id,
|
||||
resource_type="tool",
|
||||
resource_type='tool',
|
||||
resource_id=tool.id,
|
||||
permission=permission,
|
||||
user_group_ids=user_group_ids,
|
||||
@@ -215,48 +203,38 @@ class ToolsTable:
|
||||
)
|
||||
]
|
||||
|
||||
def get_tool_valves_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[dict]:
|
||||
def get_tool_valves_by_id(self, id: str, db: Optional[Session] = None) -> Optional[dict]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
tool = db.get(Tool, id)
|
||||
return tool.valves if tool.valves else {}
|
||||
except Exception as e:
|
||||
log.exception(f"Error getting tool valves by id {id}")
|
||||
log.exception(f'Error getting tool valves by id {id}')
|
||||
return None
|
||||
|
||||
def update_tool_valves_by_id(
|
||||
self, id: str, valves: dict, db: Optional[Session] = None
|
||||
) -> Optional[ToolValves]:
|
||||
def update_tool_valves_by_id(self, id: str, valves: dict, db: Optional[Session] = None) -> Optional[ToolValves]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
db.query(Tool).filter_by(id=id).update(
|
||||
{"valves": valves, "updated_at": int(time.time())}
|
||||
)
|
||||
db.query(Tool).filter_by(id=id).update({'valves': valves, 'updated_at': int(time.time())})
|
||||
db.commit()
|
||||
return self.get_tool_by_id(id, db=db)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_user_valves_by_id_and_user_id(
|
||||
self, id: str, user_id: str, db: Optional[Session] = None
|
||||
) -> Optional[dict]:
|
||||
def get_user_valves_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[dict]:
|
||||
try:
|
||||
user = Users.get_user_by_id(user_id, db=db)
|
||||
user_settings = user.settings.model_dump() if user.settings else {}
|
||||
|
||||
# Check if user has "tools" and "valves" settings
|
||||
if "tools" not in user_settings:
|
||||
user_settings["tools"] = {}
|
||||
if "valves" not in user_settings["tools"]:
|
||||
user_settings["tools"]["valves"] = {}
|
||||
if 'tools' not in user_settings:
|
||||
user_settings['tools'] = {}
|
||||
if 'valves' not in user_settings['tools']:
|
||||
user_settings['tools']['valves'] = {}
|
||||
|
||||
return user_settings["tools"]["valves"].get(id, {})
|
||||
return user_settings['tools']['valves'].get(id, {})
|
||||
except Exception as e:
|
||||
log.exception(
|
||||
f"Error getting user values by id {id} and user_id {user_id}: {e}"
|
||||
)
|
||||
log.exception(f'Error getting user values by id {id} and user_id {user_id}: {e}')
|
||||
return None
|
||||
|
||||
def update_user_valves_by_id_and_user_id(
|
||||
@@ -267,35 +245,29 @@ class ToolsTable:
|
||||
user_settings = user.settings.model_dump() if user.settings else {}
|
||||
|
||||
# Check if user has "tools" and "valves" settings
|
||||
if "tools" not in user_settings:
|
||||
user_settings["tools"] = {}
|
||||
if "valves" not in user_settings["tools"]:
|
||||
user_settings["tools"]["valves"] = {}
|
||||
if 'tools' not in user_settings:
|
||||
user_settings['tools'] = {}
|
||||
if 'valves' not in user_settings['tools']:
|
||||
user_settings['tools']['valves'] = {}
|
||||
|
||||
user_settings["tools"]["valves"][id] = valves
|
||||
user_settings['tools']['valves'][id] = valves
|
||||
|
||||
# Update the user settings in the database
|
||||
Users.update_user_by_id(user_id, {"settings": user_settings}, db=db)
|
||||
Users.update_user_by_id(user_id, {'settings': user_settings}, db=db)
|
||||
|
||||
return user_settings["tools"]["valves"][id]
|
||||
return user_settings['tools']['valves'][id]
|
||||
except Exception as e:
|
||||
log.exception(
|
||||
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
|
||||
)
|
||||
log.exception(f'Error updating user valves by id {id} and user_id {user_id}: {e}')
|
||||
return None
|
||||
|
||||
def update_tool_by_id(
|
||||
self, id: str, updated: dict, db: Optional[Session] = None
|
||||
) -> Optional[ToolModel]:
|
||||
def update_tool_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[ToolModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
access_grants = updated.pop("access_grants", None)
|
||||
db.query(Tool).filter_by(id=id).update(
|
||||
{**updated, "updated_at": int(time.time())}
|
||||
)
|
||||
access_grants = updated.pop('access_grants', None)
|
||||
db.query(Tool).filter_by(id=id).update({**updated, 'updated_at': int(time.time())})
|
||||
db.commit()
|
||||
if access_grants is not None:
|
||||
AccessGrants.set_access_grants("tool", id, access_grants, db=db)
|
||||
AccessGrants.set_access_grants('tool', id, access_grants, db=db)
|
||||
|
||||
tool = db.query(Tool).get(id)
|
||||
db.refresh(tool)
|
||||
@@ -306,7 +278,7 @@ class ToolsTable:
|
||||
def delete_tool_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
AccessGrants.revoke_all_access("tool", id, db=db)
|
||||
AccessGrants.revoke_all_access('tool', id, db=db)
|
||||
db.query(Tool).filter_by(id=id).delete()
|
||||
db.commit()
|
||||
|
||||
|
||||
@@ -40,12 +40,12 @@ import datetime
|
||||
|
||||
class UserSettings(BaseModel):
|
||||
ui: Optional[dict] = {}
|
||||
model_config = ConfigDict(extra="allow")
|
||||
model_config = ConfigDict(extra='allow')
|
||||
pass
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "user"
|
||||
__tablename__ = 'user'
|
||||
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
email = Column(String)
|
||||
@@ -83,7 +83,7 @@ class UserModel(BaseModel):
|
||||
|
||||
email: str
|
||||
username: Optional[str] = None
|
||||
role: str = "pending"
|
||||
role: str = 'pending'
|
||||
|
||||
name: str
|
||||
|
||||
@@ -112,10 +112,10 @@ class UserModel(BaseModel):
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
@model_validator(mode='after')
|
||||
def set_profile_image_url(self):
|
||||
if not self.profile_image_url:
|
||||
self.profile_image_url = f"/api/v1/users/{self.id}/profile/image"
|
||||
self.profile_image_url = f'/api/v1/users/{self.id}/profile/image'
|
||||
return self
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ class UserStatusModel(UserModel):
|
||||
|
||||
|
||||
class ApiKey(Base):
|
||||
__tablename__ = "api_key"
|
||||
__tablename__ = 'api_key'
|
||||
|
||||
id = Column(Text, primary_key=True, unique=True)
|
||||
user_id = Column(Text, nullable=False)
|
||||
@@ -163,7 +163,7 @@ class UpdateProfileForm(BaseModel):
|
||||
gender: Optional[str] = None
|
||||
date_of_birth: Optional[datetime.date] = None
|
||||
|
||||
@field_validator("profile_image_url")
|
||||
@field_validator('profile_image_url')
|
||||
@classmethod
|
||||
def check_profile_image_url(cls, v: str) -> str:
|
||||
return validate_profile_image_url(v)
|
||||
@@ -174,7 +174,7 @@ class UserGroupIdsModel(UserModel):
|
||||
|
||||
|
||||
class UserModelResponse(UserModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
|
||||
class UserListResponse(BaseModel):
|
||||
@@ -251,7 +251,7 @@ class UserUpdateForm(BaseModel):
|
||||
profile_image_url: str
|
||||
password: Optional[str] = None
|
||||
|
||||
@field_validator("profile_image_url")
|
||||
@field_validator('profile_image_url')
|
||||
@classmethod
|
||||
def check_profile_image_url(cls, v: str) -> str:
|
||||
return validate_profile_image_url(v)
|
||||
@@ -263,8 +263,8 @@ class UsersTable:
|
||||
id: str,
|
||||
name: str,
|
||||
email: str,
|
||||
profile_image_url: str = "/user.png",
|
||||
role: str = "pending",
|
||||
profile_image_url: str = '/user.png',
|
||||
role: str = 'pending',
|
||||
username: Optional[str] = None,
|
||||
oauth: Optional[dict] = None,
|
||||
db: Optional[Session] = None,
|
||||
@@ -272,16 +272,16 @@ class UsersTable:
|
||||
with get_db_context(db) as db:
|
||||
user = UserModel(
|
||||
**{
|
||||
"id": id,
|
||||
"email": email,
|
||||
"name": name,
|
||||
"role": role,
|
||||
"profile_image_url": profile_image_url,
|
||||
"last_active_at": int(time.time()),
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
"username": username,
|
||||
"oauth": oauth,
|
||||
'id': id,
|
||||
'email': email,
|
||||
'name': name,
|
||||
'role': role,
|
||||
'profile_image_url': profile_image_url,
|
||||
'last_active_at': int(time.time()),
|
||||
'created_at': int(time.time()),
|
||||
'updated_at': int(time.time()),
|
||||
'username': username,
|
||||
'oauth': oauth,
|
||||
}
|
||||
)
|
||||
result = User(**user.model_dump())
|
||||
@@ -293,9 +293,7 @@ class UsersTable:
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_user_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
def get_user_by_id(self, id: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
@@ -303,49 +301,32 @@ class UsersTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_user_by_api_key(
|
||||
self, api_key: str, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
def get_user_by_api_key(self, api_key: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
user = (
|
||||
db.query(User)
|
||||
.join(ApiKey, User.id == ApiKey.user_id)
|
||||
.filter(ApiKey.key == api_key)
|
||||
.first()
|
||||
)
|
||||
user = db.query(User).join(ApiKey, User.id == ApiKey.user_id).filter(ApiKey.key == api_key).first()
|
||||
return UserModel.model_validate(user) if user else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_user_by_email(
|
||||
self, email: str, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
def get_user_by_email(self, email: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
user = (
|
||||
db.query(User)
|
||||
.filter(func.lower(User.email) == email.lower())
|
||||
.first()
|
||||
)
|
||||
user = db.query(User).filter(func.lower(User.email) == email.lower()).first()
|
||||
return UserModel.model_validate(user) if user else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_user_by_oauth_sub(
|
||||
self, provider: str, sub: str, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
def get_user_by_oauth_sub(self, provider: str, sub: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db_context(db) as db: # type: Session
|
||||
dialect_name = db.bind.dialect.name
|
||||
|
||||
query = db.query(User)
|
||||
if dialect_name == "sqlite":
|
||||
query = query.filter(User.oauth.contains({provider: {"sub": sub}}))
|
||||
elif dialect_name == "postgresql":
|
||||
query = query.filter(
|
||||
User.oauth[provider].cast(JSONB)["sub"].astext == sub
|
||||
)
|
||||
if dialect_name == 'sqlite':
|
||||
query = query.filter(User.oauth.contains({provider: {'sub': sub}}))
|
||||
elif dialect_name == 'postgresql':
|
||||
query = query.filter(User.oauth[provider].cast(JSONB)['sub'].astext == sub)
|
||||
|
||||
user = query.first()
|
||||
return UserModel.model_validate(user) if user else None
|
||||
@@ -361,15 +342,10 @@ class UsersTable:
|
||||
dialect_name = db.bind.dialect.name
|
||||
|
||||
query = db.query(User)
|
||||
if dialect_name == "sqlite":
|
||||
query = query.filter(
|
||||
User.scim.contains({provider: {"external_id": external_id}})
|
||||
)
|
||||
elif dialect_name == "postgresql":
|
||||
query = query.filter(
|
||||
User.scim[provider].cast(JSONB)["external_id"].astext
|
||||
== external_id
|
||||
)
|
||||
if dialect_name == 'sqlite':
|
||||
query = query.filter(User.scim.contains({provider: {'external_id': external_id}}))
|
||||
elif dialect_name == 'postgresql':
|
||||
query = query.filter(User.scim[provider].cast(JSONB)['external_id'].astext == external_id)
|
||||
|
||||
user = query.first()
|
||||
return UserModel.model_validate(user) if user else None
|
||||
@@ -388,16 +364,16 @@ class UsersTable:
|
||||
query = db.query(User).options(defer(User.profile_image_url))
|
||||
|
||||
if filter:
|
||||
query_key = filter.get("query")
|
||||
query_key = filter.get('query')
|
||||
if query_key:
|
||||
query = query.filter(
|
||||
or_(
|
||||
User.name.ilike(f"%{query_key}%"),
|
||||
User.email.ilike(f"%{query_key}%"),
|
||||
User.name.ilike(f'%{query_key}%'),
|
||||
User.email.ilike(f'%{query_key}%'),
|
||||
)
|
||||
)
|
||||
|
||||
channel_id = filter.get("channel_id")
|
||||
channel_id = filter.get('channel_id')
|
||||
if channel_id:
|
||||
query = query.filter(
|
||||
exists(
|
||||
@@ -408,13 +384,13 @@ class UsersTable:
|
||||
)
|
||||
)
|
||||
|
||||
user_ids = filter.get("user_ids")
|
||||
group_ids = filter.get("group_ids")
|
||||
user_ids = filter.get('user_ids')
|
||||
group_ids = filter.get('group_ids')
|
||||
|
||||
if isinstance(user_ids, list) and isinstance(group_ids, list):
|
||||
# If both are empty lists, return no users
|
||||
if not user_ids and not group_ids:
|
||||
return {"users": [], "total": 0}
|
||||
return {'users': [], 'total': 0}
|
||||
|
||||
if user_ids:
|
||||
query = query.filter(User.id.in_(user_ids))
|
||||
@@ -429,21 +405,21 @@ class UsersTable:
|
||||
)
|
||||
)
|
||||
|
||||
roles = filter.get("roles")
|
||||
roles = filter.get('roles')
|
||||
if roles:
|
||||
include_roles = [role for role in roles if not role.startswith("!")]
|
||||
exclude_roles = [role[1:] for role in roles if role.startswith("!")]
|
||||
include_roles = [role for role in roles if not role.startswith('!')]
|
||||
exclude_roles = [role[1:] for role in roles if role.startswith('!')]
|
||||
|
||||
if include_roles:
|
||||
query = query.filter(User.role.in_(include_roles))
|
||||
if exclude_roles:
|
||||
query = query.filter(~User.role.in_(exclude_roles))
|
||||
|
||||
order_by = filter.get("order_by")
|
||||
direction = filter.get("direction")
|
||||
order_by = filter.get('order_by')
|
||||
direction = filter.get('direction')
|
||||
|
||||
if order_by and order_by.startswith("group_id:"):
|
||||
group_id = order_by.split(":", 1)[1]
|
||||
if order_by and order_by.startswith('group_id:'):
|
||||
group_id = order_by.split(':', 1)[1]
|
||||
|
||||
# Subquery that checks if the user belongs to the group
|
||||
membership_exists = exists(
|
||||
@@ -456,42 +432,42 @@ class UsersTable:
|
||||
# CASE: user in group → 1, user not in group → 0
|
||||
group_sort = case((membership_exists, 1), else_=0)
|
||||
|
||||
if direction == "asc":
|
||||
if direction == 'asc':
|
||||
query = query.order_by(group_sort.asc(), User.name.asc())
|
||||
else:
|
||||
query = query.order_by(group_sort.desc(), User.name.asc())
|
||||
|
||||
elif order_by == "name":
|
||||
if direction == "asc":
|
||||
elif order_by == 'name':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(User.name.asc())
|
||||
else:
|
||||
query = query.order_by(User.name.desc())
|
||||
|
||||
elif order_by == "email":
|
||||
if direction == "asc":
|
||||
elif order_by == 'email':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(User.email.asc())
|
||||
else:
|
||||
query = query.order_by(User.email.desc())
|
||||
|
||||
elif order_by == "created_at":
|
||||
if direction == "asc":
|
||||
elif order_by == 'created_at':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(User.created_at.asc())
|
||||
else:
|
||||
query = query.order_by(User.created_at.desc())
|
||||
|
||||
elif order_by == "last_active_at":
|
||||
if direction == "asc":
|
||||
elif order_by == 'last_active_at':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(User.last_active_at.asc())
|
||||
else:
|
||||
query = query.order_by(User.last_active_at.desc())
|
||||
|
||||
elif order_by == "updated_at":
|
||||
if direction == "asc":
|
||||
elif order_by == 'updated_at':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(User.updated_at.asc())
|
||||
else:
|
||||
query = query.order_by(User.updated_at.desc())
|
||||
elif order_by == "role":
|
||||
if direction == "asc":
|
||||
elif order_by == 'role':
|
||||
if direction == 'asc':
|
||||
query = query.order_by(User.role.asc())
|
||||
else:
|
||||
query = query.order_by(User.role.desc())
|
||||
@@ -510,13 +486,11 @@ class UsersTable:
|
||||
|
||||
users = query.all()
|
||||
return {
|
||||
"users": [UserModel.model_validate(user) for user in users],
|
||||
"total": total,
|
||||
'users': [UserModel.model_validate(user) for user in users],
|
||||
'total': total,
|
||||
}
|
||||
|
||||
def get_users_by_group_id(
|
||||
self, group_id: str, db: Optional[Session] = None
|
||||
) -> list[UserModel]:
|
||||
def get_users_by_group_id(self, group_id: str, db: Optional[Session] = None) -> list[UserModel]:
|
||||
with get_db_context(db) as db:
|
||||
users = (
|
||||
db.query(User)
|
||||
@@ -527,16 +501,9 @@ class UsersTable:
|
||||
)
|
||||
return [UserModel.model_validate(user) for user in users]
|
||||
|
||||
def get_users_by_user_ids(
|
||||
self, user_ids: list[str], db: Optional[Session] = None
|
||||
) -> list[UserStatusModel]:
|
||||
def get_users_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> list[UserStatusModel]:
|
||||
with get_db_context(db) as db:
|
||||
users = (
|
||||
db.query(User)
|
||||
.options(defer(User.profile_image_url))
|
||||
.filter(User.id.in_(user_ids))
|
||||
.all()
|
||||
)
|
||||
users = db.query(User).options(defer(User.profile_image_url)).filter(User.id.in_(user_ids)).all()
|
||||
return [UserModel.model_validate(user) for user in users]
|
||||
|
||||
def get_num_users(self, db: Optional[Session] = None) -> Optional[int]:
|
||||
@@ -555,9 +522,7 @@ class UsersTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_user_webhook_url_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[str]:
|
||||
def get_user_webhook_url_by_id(self, id: str, db: Optional[Session] = None) -> Optional[str]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
@@ -565,11 +530,7 @@ class UsersTable:
|
||||
if user.settings is None:
|
||||
return None
|
||||
else:
|
||||
return (
|
||||
user.settings.get("ui", {})
|
||||
.get("notifications", {})
|
||||
.get("webhook_url", None)
|
||||
)
|
||||
return user.settings.get('ui', {}).get('notifications', {}).get('webhook_url', None)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@@ -577,14 +538,10 @@ class UsersTable:
|
||||
with get_db_context(db) as db:
|
||||
current_timestamp = int(datetime.datetime.now().timestamp())
|
||||
today_midnight_timestamp = current_timestamp - (current_timestamp % 86400)
|
||||
query = db.query(User).filter(
|
||||
User.last_active_at > today_midnight_timestamp
|
||||
)
|
||||
query = db.query(User).filter(User.last_active_at > today_midnight_timestamp)
|
||||
return query.count()
|
||||
|
||||
def update_user_role_by_id(
|
||||
self, id: str, role: str, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
def update_user_role_by_id(self, id: str, role: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
@@ -629,9 +586,7 @@ class UsersTable:
|
||||
return None
|
||||
|
||||
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
|
||||
def update_last_active_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
def update_last_active_by_id(self, id: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
@@ -665,10 +620,10 @@ class UsersTable:
|
||||
oauth = user.oauth or {}
|
||||
|
||||
# Update or insert provider entry
|
||||
oauth[provider] = {"sub": sub}
|
||||
oauth[provider] = {'sub': sub}
|
||||
|
||||
# Persist updated JSON
|
||||
db.query(User).filter_by(id=id).update({"oauth": oauth})
|
||||
db.query(User).filter_by(id=id).update({'oauth': oauth})
|
||||
db.commit()
|
||||
|
||||
return UserModel.model_validate(user)
|
||||
@@ -698,9 +653,9 @@ class UsersTable:
|
||||
return None
|
||||
|
||||
scim = user.scim or {}
|
||||
scim[provider] = {"external_id": external_id}
|
||||
scim[provider] = {'external_id': external_id}
|
||||
|
||||
db.query(User).filter_by(id=id).update({"scim": scim})
|
||||
db.query(User).filter_by(id=id).update({'scim': scim})
|
||||
db.commit()
|
||||
|
||||
return UserModel.model_validate(user)
|
||||
@@ -708,9 +663,7 @@ class UsersTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_user_by_id(
|
||||
self, id: str, updated: dict, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
def update_user_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
@@ -725,9 +678,7 @@ class UsersTable:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
def update_user_settings_by_id(
|
||||
self, id: str, updated: dict, db: Optional[Session] = None
|
||||
) -> Optional[UserModel]:
|
||||
def update_user_settings_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
@@ -741,7 +692,7 @@ class UsersTable:
|
||||
|
||||
user_settings.update(updated)
|
||||
|
||||
db.query(User).filter_by(id=id).update({"settings": user_settings})
|
||||
db.query(User).filter_by(id=id).update({'settings': user_settings})
|
||||
db.commit()
|
||||
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
@@ -768,9 +719,7 @@ class UsersTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_user_api_key_by_id(
|
||||
self, id: str, db: Optional[Session] = None
|
||||
) -> Optional[str]:
|
||||
def get_user_api_key_by_id(self, id: str, db: Optional[Session] = None) -> Optional[str]:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
api_key = db.query(ApiKey).filter_by(user_id=id).first()
|
||||
@@ -778,9 +727,7 @@ class UsersTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def update_user_api_key_by_id(
|
||||
self, id: str, api_key: str, db: Optional[Session] = None
|
||||
) -> bool:
|
||||
def update_user_api_key_by_id(self, id: str, api_key: str, db: Optional[Session] = None) -> bool:
|
||||
try:
|
||||
with get_db_context(db) as db:
|
||||
db.query(ApiKey).filter_by(user_id=id).delete()
|
||||
@@ -788,7 +735,7 @@ class UsersTable:
|
||||
|
||||
now = int(time.time())
|
||||
new_api_key = ApiKey(
|
||||
id=f"key_{id}",
|
||||
id=f'key_{id}',
|
||||
user_id=id,
|
||||
key=api_key,
|
||||
created_at=now,
|
||||
@@ -811,16 +758,14 @@ class UsersTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_valid_user_ids(
|
||||
self, user_ids: list[str], db: Optional[Session] = None
|
||||
) -> list[str]:
|
||||
def get_valid_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> list[str]:
|
||||
with get_db_context(db) as db:
|
||||
users = db.query(User).filter(User.id.in_(user_ids)).all()
|
||||
return [user.id for user in users]
|
||||
|
||||
def get_super_admin_user(self, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||
with get_db_context(db) as db:
|
||||
user = db.query(User).filter_by(role="admin").first()
|
||||
user = db.query(User).filter_by(role='admin').first()
|
||||
if user:
|
||||
return UserModel.model_validate(user)
|
||||
else:
|
||||
@@ -830,9 +775,7 @@ class UsersTable:
|
||||
with get_db_context(db) as db:
|
||||
# Consider user active if last_active_at within the last 3 minutes
|
||||
three_minutes_ago = int(time.time()) - 180
|
||||
count = (
|
||||
db.query(User).filter(User.last_active_at >= three_minutes_ago).count()
|
||||
)
|
||||
count = db.query(User).filter(User.last_active_at >= three_minutes_ago).count()
|
||||
return count
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -40,78 +40,76 @@ class DatalabMarkerLoader:
|
||||
self.output_format = output_format
|
||||
|
||||
def _get_mime_type(self, filename: str) -> str:
|
||||
ext = filename.rsplit(".", 1)[-1].lower()
|
||||
ext = filename.rsplit('.', 1)[-1].lower()
|
||||
mime_map = {
|
||||
"pdf": "application/pdf",
|
||||
"xls": "application/vnd.ms-excel",
|
||||
"xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
"ods": "application/vnd.oasis.opendocument.spreadsheet",
|
||||
"doc": "application/msword",
|
||||
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"odt": "application/vnd.oasis.opendocument.text",
|
||||
"ppt": "application/vnd.ms-powerpoint",
|
||||
"pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
"odp": "application/vnd.oasis.opendocument.presentation",
|
||||
"html": "text/html",
|
||||
"epub": "application/epub+zip",
|
||||
"png": "image/png",
|
||||
"jpeg": "image/jpeg",
|
||||
"jpg": "image/jpeg",
|
||||
"webp": "image/webp",
|
||||
"gif": "image/gif",
|
||||
"tiff": "image/tiff",
|
||||
'pdf': 'application/pdf',
|
||||
'xls': 'application/vnd.ms-excel',
|
||||
'xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||
'ods': 'application/vnd.oasis.opendocument.spreadsheet',
|
||||
'doc': 'application/msword',
|
||||
'docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'odt': 'application/vnd.oasis.opendocument.text',
|
||||
'ppt': 'application/vnd.ms-powerpoint',
|
||||
'pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||
'odp': 'application/vnd.oasis.opendocument.presentation',
|
||||
'html': 'text/html',
|
||||
'epub': 'application/epub+zip',
|
||||
'png': 'image/png',
|
||||
'jpeg': 'image/jpeg',
|
||||
'jpg': 'image/jpeg',
|
||||
'webp': 'image/webp',
|
||||
'gif': 'image/gif',
|
||||
'tiff': 'image/tiff',
|
||||
}
|
||||
return mime_map.get(ext, "application/octet-stream")
|
||||
return mime_map.get(ext, 'application/octet-stream')
|
||||
|
||||
def check_marker_request_status(self, request_id: str) -> dict:
|
||||
url = f"{self.api_base_url}/{request_id}"
|
||||
headers = {"X-Api-Key": self.api_key}
|
||||
url = f'{self.api_base_url}/{request_id}'
|
||||
headers = {'X-Api-Key': self.api_key}
|
||||
try:
|
||||
response = requests.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
log.info(f"Marker API status check for request {request_id}: {result}")
|
||||
log.info(f'Marker API status check for request {request_id}: {result}')
|
||||
return result
|
||||
except requests.HTTPError as e:
|
||||
log.error(f"Error checking Marker request status: {e}")
|
||||
log.error(f'Error checking Marker request status: {e}')
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Failed to check Marker request: {e}",
|
||||
detail=f'Failed to check Marker request: {e}',
|
||||
)
|
||||
except ValueError as e:
|
||||
log.error(f"Invalid JSON checking Marker request: {e}")
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON: {e}"
|
||||
)
|
||||
log.error(f'Invalid JSON checking Marker request: {e}')
|
||||
raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Invalid JSON: {e}')
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
filename = os.path.basename(self.file_path)
|
||||
mime_type = self._get_mime_type(filename)
|
||||
headers = {"X-Api-Key": self.api_key}
|
||||
headers = {'X-Api-Key': self.api_key}
|
||||
|
||||
form_data = {
|
||||
"use_llm": str(self.use_llm).lower(),
|
||||
"skip_cache": str(self.skip_cache).lower(),
|
||||
"force_ocr": str(self.force_ocr).lower(),
|
||||
"paginate": str(self.paginate).lower(),
|
||||
"strip_existing_ocr": str(self.strip_existing_ocr).lower(),
|
||||
"disable_image_extraction": str(self.disable_image_extraction).lower(),
|
||||
"format_lines": str(self.format_lines).lower(),
|
||||
"output_format": self.output_format,
|
||||
'use_llm': str(self.use_llm).lower(),
|
||||
'skip_cache': str(self.skip_cache).lower(),
|
||||
'force_ocr': str(self.force_ocr).lower(),
|
||||
'paginate': str(self.paginate).lower(),
|
||||
'strip_existing_ocr': str(self.strip_existing_ocr).lower(),
|
||||
'disable_image_extraction': str(self.disable_image_extraction).lower(),
|
||||
'format_lines': str(self.format_lines).lower(),
|
||||
'output_format': self.output_format,
|
||||
}
|
||||
|
||||
if self.additional_config and self.additional_config.strip():
|
||||
form_data["additional_config"] = self.additional_config
|
||||
form_data['additional_config'] = self.additional_config
|
||||
|
||||
log.info(
|
||||
f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}"
|
||||
)
|
||||
|
||||
try:
|
||||
with open(self.file_path, "rb") as f:
|
||||
files = {"file": (filename, f, mime_type)}
|
||||
with open(self.file_path, 'rb') as f:
|
||||
files = {'file': (filename, f, mime_type)}
|
||||
response = requests.post(
|
||||
f"{self.api_base_url}",
|
||||
f'{self.api_base_url}',
|
||||
data=form_data,
|
||||
files=files,
|
||||
headers=headers,
|
||||
@@ -119,29 +117,25 @@ class DatalabMarkerLoader:
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(
|
||||
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
|
||||
)
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}')
|
||||
except requests.HTTPError as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Datalab Marker request failed: {e}",
|
||||
detail=f'Datalab Marker request failed: {e}',
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON response: {e}"
|
||||
)
|
||||
raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Invalid JSON response: {e}')
|
||||
except Exception as e:
|
||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
|
||||
|
||||
if not result.get("success"):
|
||||
if not result.get('success'):
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Datalab Marker request failed: {result.get('error', 'Unknown error')}",
|
||||
detail=f'Datalab Marker request failed: {result.get("error", "Unknown error")}',
|
||||
)
|
||||
|
||||
check_url = result.get("request_check_url")
|
||||
request_id = result.get("request_id")
|
||||
check_url = result.get('request_check_url')
|
||||
request_id = result.get('request_id')
|
||||
|
||||
# Check if this is a direct response (self-hosted) or polling response (DataLab)
|
||||
if check_url:
|
||||
@@ -154,54 +148,45 @@ class DatalabMarkerLoader:
|
||||
poll_result = poll_response.json()
|
||||
except (requests.HTTPError, ValueError) as e:
|
||||
raw_body = poll_response.text
|
||||
log.error(f"Polling error: {e}, response body: {raw_body}")
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}"
|
||||
)
|
||||
log.error(f'Polling error: {e}, response body: {raw_body}')
|
||||
raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Polling failed: {e}')
|
||||
|
||||
status_val = poll_result.get("status")
|
||||
success_val = poll_result.get("success")
|
||||
status_val = poll_result.get('status')
|
||||
success_val = poll_result.get('success')
|
||||
|
||||
if status_val == "complete":
|
||||
if status_val == 'complete':
|
||||
summary = {
|
||||
k: poll_result.get(k)
|
||||
for k in (
|
||||
"status",
|
||||
"output_format",
|
||||
"success",
|
||||
"error",
|
||||
"page_count",
|
||||
"total_cost",
|
||||
'status',
|
||||
'output_format',
|
||||
'success',
|
||||
'error',
|
||||
'page_count',
|
||||
'total_cost',
|
||||
)
|
||||
}
|
||||
log.info(
|
||||
f"Marker processing completed successfully: {json.dumps(summary, indent=2)}"
|
||||
)
|
||||
log.info(f'Marker processing completed successfully: {json.dumps(summary, indent=2)}')
|
||||
break
|
||||
|
||||
if status_val == "failed" or success_val is False:
|
||||
log.error(
|
||||
f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}"
|
||||
)
|
||||
error_msg = (
|
||||
poll_result.get("error")
|
||||
or "Marker returned failure without error message"
|
||||
)
|
||||
if status_val == 'failed' or success_val is False:
|
||||
log.error(f'Marker poll failed full response: {json.dumps(poll_result, indent=2)}')
|
||||
error_msg = poll_result.get('error') or 'Marker returned failure without error message'
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Marker processing failed: {error_msg}",
|
||||
detail=f'Marker processing failed: {error_msg}',
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
detail="Marker processing timed out",
|
||||
detail='Marker processing timed out',
|
||||
)
|
||||
|
||||
if not poll_result.get("success", False):
|
||||
error_msg = poll_result.get("error") or "Unknown processing error"
|
||||
if not poll_result.get('success', False):
|
||||
error_msg = poll_result.get('error') or 'Unknown processing error'
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Final processing failed: {error_msg}",
|
||||
detail=f'Final processing failed: {error_msg}',
|
||||
)
|
||||
|
||||
# DataLab format - content in format-specific fields
|
||||
@@ -210,69 +195,65 @@ class DatalabMarkerLoader:
|
||||
final_result = poll_result
|
||||
else:
|
||||
# Self-hosted direct response - content in "output" field
|
||||
if "output" in result:
|
||||
log.info("Self-hosted Marker returned direct response without polling")
|
||||
raw_content = result.get("output")
|
||||
if 'output' in result:
|
||||
log.info('Self-hosted Marker returned direct response without polling')
|
||||
raw_content = result.get('output')
|
||||
final_result = result
|
||||
else:
|
||||
available_fields = (
|
||||
list(result.keys())
|
||||
if isinstance(result, dict)
|
||||
else "non-dict response"
|
||||
)
|
||||
available_fields = list(result.keys()) if isinstance(result, dict) else 'non-dict response'
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Custom Marker endpoint returned success but no 'output' field found. Available fields: {available_fields}. Expected either 'request_check_url' for polling or 'output' field for direct response.",
|
||||
)
|
||||
|
||||
if self.output_format.lower() == "json":
|
||||
if self.output_format.lower() == 'json':
|
||||
full_text = json.dumps(raw_content, indent=2)
|
||||
elif self.output_format.lower() in {"markdown", "html"}:
|
||||
elif self.output_format.lower() in {'markdown', 'html'}:
|
||||
full_text = str(raw_content).strip()
|
||||
else:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported output format: {self.output_format}",
|
||||
detail=f'Unsupported output format: {self.output_format}',
|
||||
)
|
||||
|
||||
if not full_text:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail="Marker returned empty content",
|
||||
detail='Marker returned empty content',
|
||||
)
|
||||
|
||||
marker_output_dir = os.path.join("/app/backend/data/uploads", "marker_output")
|
||||
marker_output_dir = os.path.join('/app/backend/data/uploads', 'marker_output')
|
||||
os.makedirs(marker_output_dir, exist_ok=True)
|
||||
|
||||
file_ext_map = {"markdown": "md", "json": "json", "html": "html"}
|
||||
file_ext = file_ext_map.get(self.output_format.lower(), "txt")
|
||||
output_filename = f"{os.path.splitext(filename)[0]}.{file_ext}"
|
||||
file_ext_map = {'markdown': 'md', 'json': 'json', 'html': 'html'}
|
||||
file_ext = file_ext_map.get(self.output_format.lower(), 'txt')
|
||||
output_filename = f'{os.path.splitext(filename)[0]}.{file_ext}'
|
||||
output_path = os.path.join(marker_output_dir, output_filename)
|
||||
|
||||
try:
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(full_text)
|
||||
log.info(f"Saved Marker output to: {output_path}")
|
||||
log.info(f'Saved Marker output to: {output_path}')
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to write marker output to disk: {e}")
|
||||
log.warning(f'Failed to write marker output to disk: {e}')
|
||||
|
||||
metadata = {
|
||||
"source": filename,
|
||||
"output_format": final_result.get("output_format", self.output_format),
|
||||
"page_count": final_result.get("page_count", 0),
|
||||
"processed_with_llm": self.use_llm,
|
||||
"request_id": request_id or "",
|
||||
'source': filename,
|
||||
'output_format': final_result.get('output_format', self.output_format),
|
||||
'page_count': final_result.get('page_count', 0),
|
||||
'processed_with_llm': self.use_llm,
|
||||
'request_id': request_id or '',
|
||||
}
|
||||
|
||||
images = final_result.get("images", {})
|
||||
images = final_result.get('images', {})
|
||||
if images:
|
||||
metadata["image_count"] = len(images)
|
||||
metadata["images"] = json.dumps(list(images.keys()))
|
||||
metadata['image_count'] = len(images)
|
||||
metadata['images'] = json.dumps(list(images.keys()))
|
||||
|
||||
for k, v in metadata.items():
|
||||
if isinstance(v, (dict, list)):
|
||||
metadata[k] = json.dumps(v)
|
||||
elif v is None:
|
||||
metadata[k] = ""
|
||||
metadata[k] = ''
|
||||
|
||||
return [Document(page_content=full_text, metadata=metadata)]
|
||||
|
||||
@@ -29,18 +29,18 @@ class ExternalDocumentLoader(BaseLoader):
|
||||
self.user = user
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
with open(self.file_path, "rb") as f:
|
||||
with open(self.file_path, 'rb') as f:
|
||||
data = f.read()
|
||||
|
||||
headers = {}
|
||||
if self.mime_type is not None:
|
||||
headers["Content-Type"] = self.mime_type
|
||||
headers['Content-Type'] = self.mime_type
|
||||
|
||||
if self.api_key is not None:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
headers['Authorization'] = f'Bearer {self.api_key}'
|
||||
|
||||
try:
|
||||
headers["X-Filename"] = quote(os.path.basename(self.file_path))
|
||||
headers['X-Filename'] = quote(os.path.basename(self.file_path))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -48,24 +48,23 @@ class ExternalDocumentLoader(BaseLoader):
|
||||
headers = include_user_info_headers(headers, self.user)
|
||||
|
||||
url = self.url
|
||||
if url.endswith("/"):
|
||||
if url.endswith('/'):
|
||||
url = url[:-1]
|
||||
|
||||
try:
|
||||
response = requests.put(f"{url}/process", data=data, headers=headers)
|
||||
response = requests.put(f'{url}/process', data=data, headers=headers)
|
||||
except Exception as e:
|
||||
log.error(f"Error connecting to endpoint: {e}")
|
||||
raise Exception(f"Error connecting to endpoint: {e}")
|
||||
log.error(f'Error connecting to endpoint: {e}')
|
||||
raise Exception(f'Error connecting to endpoint: {e}')
|
||||
|
||||
if response.ok:
|
||||
|
||||
response_data = response.json()
|
||||
if response_data:
|
||||
if isinstance(response_data, dict):
|
||||
return [
|
||||
Document(
|
||||
page_content=response_data.get("page_content"),
|
||||
metadata=response_data.get("metadata"),
|
||||
page_content=response_data.get('page_content'),
|
||||
metadata=response_data.get('metadata'),
|
||||
)
|
||||
]
|
||||
elif isinstance(response_data, list):
|
||||
@@ -73,17 +72,15 @@ class ExternalDocumentLoader(BaseLoader):
|
||||
for document in response_data:
|
||||
documents.append(
|
||||
Document(
|
||||
page_content=document.get("page_content"),
|
||||
metadata=document.get("metadata"),
|
||||
page_content=document.get('page_content'),
|
||||
metadata=document.get('metadata'),
|
||||
)
|
||||
)
|
||||
return documents
|
||||
else:
|
||||
raise Exception("Error loading document: Unable to parse content")
|
||||
raise Exception('Error loading document: Unable to parse content')
|
||||
|
||||
else:
|
||||
raise Exception("Error loading document: No content returned")
|
||||
raise Exception('Error loading document: No content returned')
|
||||
else:
|
||||
raise Exception(
|
||||
f"Error loading document: {response.status_code} {response.text}"
|
||||
)
|
||||
raise Exception(f'Error loading document: {response.status_code} {response.text}')
|
||||
|
||||
@@ -30,22 +30,22 @@ class ExternalWebLoader(BaseLoader):
|
||||
response = requests.post(
|
||||
self.external_url,
|
||||
headers={
|
||||
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) External Web Loader",
|
||||
"Authorization": f"Bearer {self.external_api_key}",
|
||||
'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) External Web Loader',
|
||||
'Authorization': f'Bearer {self.external_api_key}',
|
||||
},
|
||||
json={
|
||||
"urls": urls,
|
||||
'urls': urls,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
results = response.json()
|
||||
for result in results:
|
||||
yield Document(
|
||||
page_content=result.get("page_content", ""),
|
||||
metadata=result.get("metadata", {}),
|
||||
page_content=result.get('page_content', ''),
|
||||
metadata=result.get('metadata', {}),
|
||||
)
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.error(f"Error extracting content from batch {urls}: {e}")
|
||||
log.error(f'Error extracting content from batch {urls}: {e}')
|
||||
else:
|
||||
raise e
|
||||
|
||||
@@ -30,59 +30,59 @@ logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
known_source_ext = [
|
||||
"go",
|
||||
"py",
|
||||
"java",
|
||||
"sh",
|
||||
"bat",
|
||||
"ps1",
|
||||
"cmd",
|
||||
"js",
|
||||
"ts",
|
||||
"css",
|
||||
"cpp",
|
||||
"hpp",
|
||||
"h",
|
||||
"c",
|
||||
"cs",
|
||||
"sql",
|
||||
"log",
|
||||
"ini",
|
||||
"pl",
|
||||
"pm",
|
||||
"r",
|
||||
"dart",
|
||||
"dockerfile",
|
||||
"env",
|
||||
"php",
|
||||
"hs",
|
||||
"hsc",
|
||||
"lua",
|
||||
"nginxconf",
|
||||
"conf",
|
||||
"m",
|
||||
"mm",
|
||||
"plsql",
|
||||
"perl",
|
||||
"rb",
|
||||
"rs",
|
||||
"db2",
|
||||
"scala",
|
||||
"bash",
|
||||
"swift",
|
||||
"vue",
|
||||
"svelte",
|
||||
"ex",
|
||||
"exs",
|
||||
"erl",
|
||||
"tsx",
|
||||
"jsx",
|
||||
"hs",
|
||||
"lhs",
|
||||
"json",
|
||||
"yaml",
|
||||
"yml",
|
||||
"toml",
|
||||
'go',
|
||||
'py',
|
||||
'java',
|
||||
'sh',
|
||||
'bat',
|
||||
'ps1',
|
||||
'cmd',
|
||||
'js',
|
||||
'ts',
|
||||
'css',
|
||||
'cpp',
|
||||
'hpp',
|
||||
'h',
|
||||
'c',
|
||||
'cs',
|
||||
'sql',
|
||||
'log',
|
||||
'ini',
|
||||
'pl',
|
||||
'pm',
|
||||
'r',
|
||||
'dart',
|
||||
'dockerfile',
|
||||
'env',
|
||||
'php',
|
||||
'hs',
|
||||
'hsc',
|
||||
'lua',
|
||||
'nginxconf',
|
||||
'conf',
|
||||
'm',
|
||||
'mm',
|
||||
'plsql',
|
||||
'perl',
|
||||
'rb',
|
||||
'rs',
|
||||
'db2',
|
||||
'scala',
|
||||
'bash',
|
||||
'swift',
|
||||
'vue',
|
||||
'svelte',
|
||||
'ex',
|
||||
'exs',
|
||||
'erl',
|
||||
'tsx',
|
||||
'jsx',
|
||||
'hs',
|
||||
'lhs',
|
||||
'json',
|
||||
'yaml',
|
||||
'yml',
|
||||
'toml',
|
||||
]
|
||||
|
||||
|
||||
@@ -99,11 +99,11 @@ class ExcelLoader:
|
||||
xls = pd.ExcelFile(self.file_path)
|
||||
for sheet_name in xls.sheet_names:
|
||||
df = pd.read_excel(xls, sheet_name=sheet_name)
|
||||
text_parts.append(f"Sheet: {sheet_name}\n{df.to_string(index=False)}")
|
||||
text_parts.append(f'Sheet: {sheet_name}\n{df.to_string(index=False)}')
|
||||
return [
|
||||
Document(
|
||||
page_content="\n\n".join(text_parts),
|
||||
metadata={"source": self.file_path},
|
||||
page_content='\n\n'.join(text_parts),
|
||||
metadata={'source': self.file_path},
|
||||
)
|
||||
]
|
||||
|
||||
@@ -125,11 +125,11 @@ class PptxLoader:
|
||||
if shape.has_text_frame:
|
||||
slide_texts.append(shape.text_frame.text)
|
||||
if slide_texts:
|
||||
text_parts.append(f"Slide {i}:\n" + "\n".join(slide_texts))
|
||||
text_parts.append(f'Slide {i}:\n' + '\n'.join(slide_texts))
|
||||
return [
|
||||
Document(
|
||||
page_content="\n\n".join(text_parts),
|
||||
metadata={"source": self.file_path},
|
||||
page_content='\n\n'.join(text_parts),
|
||||
metadata={'source': self.file_path},
|
||||
)
|
||||
]
|
||||
|
||||
@@ -143,41 +143,41 @@ class TikaLoader:
|
||||
self.extract_images = extract_images
|
||||
|
||||
def load(self) -> list[Document]:
|
||||
with open(self.file_path, "rb") as f:
|
||||
with open(self.file_path, 'rb') as f:
|
||||
data = f.read()
|
||||
|
||||
if self.mime_type is not None:
|
||||
headers = {"Content-Type": self.mime_type}
|
||||
headers = {'Content-Type': self.mime_type}
|
||||
else:
|
||||
headers = {}
|
||||
|
||||
if self.extract_images == True:
|
||||
headers["X-Tika-PDFextractInlineImages"] = "true"
|
||||
headers['X-Tika-PDFextractInlineImages'] = 'true'
|
||||
|
||||
endpoint = self.url
|
||||
if not endpoint.endswith("/"):
|
||||
endpoint += "/"
|
||||
endpoint += "tika/text"
|
||||
if not endpoint.endswith('/'):
|
||||
endpoint += '/'
|
||||
endpoint += 'tika/text'
|
||||
|
||||
r = requests.put(endpoint, data=data, headers=headers, verify=REQUESTS_VERIFY)
|
||||
|
||||
if r.ok:
|
||||
raw_metadata = r.json()
|
||||
text = raw_metadata.get("X-TIKA:content", "<No text content found>").strip()
|
||||
text = raw_metadata.get('X-TIKA:content', '<No text content found>').strip()
|
||||
|
||||
if "Content-Type" in raw_metadata:
|
||||
headers["Content-Type"] = raw_metadata["Content-Type"]
|
||||
if 'Content-Type' in raw_metadata:
|
||||
headers['Content-Type'] = raw_metadata['Content-Type']
|
||||
|
||||
log.debug("Tika extracted text: %s", text)
|
||||
log.debug('Tika extracted text: %s', text)
|
||||
|
||||
return [Document(page_content=text, metadata=headers)]
|
||||
else:
|
||||
raise Exception(f"Error calling Tika: {r.reason}")
|
||||
raise Exception(f'Error calling Tika: {r.reason}')
|
||||
|
||||
|
||||
class DoclingLoader:
|
||||
def __init__(self, url, api_key=None, file_path=None, mime_type=None, params=None):
|
||||
self.url = url.rstrip("/")
|
||||
self.url = url.rstrip('/')
|
||||
self.api_key = api_key
|
||||
self.file_path = file_path
|
||||
self.mime_type = mime_type
|
||||
@@ -185,199 +185,183 @@ class DoclingLoader:
|
||||
self.params = params or {}
|
||||
|
||||
def load(self) -> list[Document]:
|
||||
with open(self.file_path, "rb") as f:
|
||||
with open(self.file_path, 'rb') as f:
|
||||
headers = {}
|
||||
if self.api_key:
|
||||
headers["X-Api-Key"] = f"{self.api_key}"
|
||||
headers['X-Api-Key'] = f'{self.api_key}'
|
||||
|
||||
r = requests.post(
|
||||
f"{self.url}/v1/convert/file",
|
||||
f'{self.url}/v1/convert/file',
|
||||
files={
|
||||
"files": (
|
||||
'files': (
|
||||
self.file_path,
|
||||
f,
|
||||
self.mime_type or "application/octet-stream",
|
||||
self.mime_type or 'application/octet-stream',
|
||||
)
|
||||
},
|
||||
data={
|
||||
"image_export_mode": "placeholder",
|
||||
'image_export_mode': 'placeholder',
|
||||
**self.params,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
if r.ok:
|
||||
result = r.json()
|
||||
document_data = result.get("document", {})
|
||||
text = document_data.get("md_content", "<No text content found>")
|
||||
document_data = result.get('document', {})
|
||||
text = document_data.get('md_content', '<No text content found>')
|
||||
|
||||
metadata = {"Content-Type": self.mime_type} if self.mime_type else {}
|
||||
metadata = {'Content-Type': self.mime_type} if self.mime_type else {}
|
||||
|
||||
log.debug("Docling extracted text: %s", text)
|
||||
log.debug('Docling extracted text: %s', text)
|
||||
return [Document(page_content=text, metadata=metadata)]
|
||||
else:
|
||||
error_msg = f"Error calling Docling API: {r.reason}"
|
||||
error_msg = f'Error calling Docling API: {r.reason}'
|
||||
if r.text:
|
||||
try:
|
||||
error_data = r.json()
|
||||
if "detail" in error_data:
|
||||
error_msg += f" - {error_data['detail']}"
|
||||
if 'detail' in error_data:
|
||||
error_msg += f' - {error_data["detail"]}'
|
||||
except Exception:
|
||||
error_msg += f" - {r.text}"
|
||||
raise Exception(f"Error calling Docling: {error_msg}")
|
||||
error_msg += f' - {r.text}'
|
||||
raise Exception(f'Error calling Docling: {error_msg}')
|
||||
|
||||
|
||||
class Loader:
|
||||
def __init__(self, engine: str = "", **kwargs):
|
||||
def __init__(self, engine: str = '', **kwargs):
|
||||
self.engine = engine
|
||||
self.user = kwargs.get("user", None)
|
||||
self.user = kwargs.get('user', None)
|
||||
self.kwargs = kwargs
|
||||
|
||||
def load(
|
||||
self, filename: str, file_content_type: str, file_path: str
|
||||
) -> list[Document]:
|
||||
def load(self, filename: str, file_content_type: str, file_path: str) -> list[Document]:
|
||||
loader = self._get_loader(filename, file_content_type, file_path)
|
||||
docs = loader.load()
|
||||
|
||||
return [
|
||||
Document(
|
||||
page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata
|
||||
)
|
||||
for doc in docs
|
||||
]
|
||||
return [Document(page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata) for doc in docs]
|
||||
|
||||
def _is_text_file(self, file_ext: str, file_content_type: str) -> bool:
|
||||
return file_ext in known_source_ext or (
|
||||
file_content_type
|
||||
and file_content_type.find("text/") >= 0
|
||||
and file_content_type.find('text/') >= 0
|
||||
# Avoid text/html files being detected as text
|
||||
and not file_content_type.find("html") >= 0
|
||||
and not file_content_type.find('html') >= 0
|
||||
)
|
||||
|
||||
def _get_loader(self, filename: str, file_content_type: str, file_path: str):
|
||||
file_ext = filename.split(".")[-1].lower()
|
||||
file_ext = filename.split('.')[-1].lower()
|
||||
|
||||
if (
|
||||
self.engine == "external"
|
||||
and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL")
|
||||
and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY")
|
||||
self.engine == 'external'
|
||||
and self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_URL')
|
||||
and self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_API_KEY')
|
||||
):
|
||||
loader = ExternalDocumentLoader(
|
||||
file_path=file_path,
|
||||
url=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL"),
|
||||
api_key=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY"),
|
||||
url=self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_URL'),
|
||||
api_key=self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_API_KEY'),
|
||||
mime_type=file_content_type,
|
||||
user=self.user,
|
||||
)
|
||||
elif self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
|
||||
elif self.engine == 'tika' and self.kwargs.get('TIKA_SERVER_URL'):
|
||||
if self._is_text_file(file_ext, file_content_type):
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
else:
|
||||
loader = TikaLoader(
|
||||
url=self.kwargs.get("TIKA_SERVER_URL"),
|
||||
url=self.kwargs.get('TIKA_SERVER_URL'),
|
||||
file_path=file_path,
|
||||
extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"),
|
||||
extract_images=self.kwargs.get('PDF_EXTRACT_IMAGES'),
|
||||
)
|
||||
elif (
|
||||
self.engine == "datalab_marker"
|
||||
and self.kwargs.get("DATALAB_MARKER_API_KEY")
|
||||
self.engine == 'datalab_marker'
|
||||
and self.kwargs.get('DATALAB_MARKER_API_KEY')
|
||||
and file_ext
|
||||
in [
|
||||
"pdf",
|
||||
"xls",
|
||||
"xlsx",
|
||||
"ods",
|
||||
"doc",
|
||||
"docx",
|
||||
"odt",
|
||||
"ppt",
|
||||
"pptx",
|
||||
"odp",
|
||||
"html",
|
||||
"epub",
|
||||
"png",
|
||||
"jpeg",
|
||||
"jpg",
|
||||
"webp",
|
||||
"gif",
|
||||
"tiff",
|
||||
'pdf',
|
||||
'xls',
|
||||
'xlsx',
|
||||
'ods',
|
||||
'doc',
|
||||
'docx',
|
||||
'odt',
|
||||
'ppt',
|
||||
'pptx',
|
||||
'odp',
|
||||
'html',
|
||||
'epub',
|
||||
'png',
|
||||
'jpeg',
|
||||
'jpg',
|
||||
'webp',
|
||||
'gif',
|
||||
'tiff',
|
||||
]
|
||||
):
|
||||
api_base_url = self.kwargs.get("DATALAB_MARKER_API_BASE_URL", "")
|
||||
if not api_base_url or api_base_url.strip() == "":
|
||||
api_base_url = "https://www.datalab.to/api/v1/marker" # https://github.com/open-webui/open-webui/pull/16867#issuecomment-3218424349
|
||||
api_base_url = self.kwargs.get('DATALAB_MARKER_API_BASE_URL', '')
|
||||
if not api_base_url or api_base_url.strip() == '':
|
||||
api_base_url = 'https://www.datalab.to/api/v1/marker' # https://github.com/open-webui/open-webui/pull/16867#issuecomment-3218424349
|
||||
|
||||
loader = DatalabMarkerLoader(
|
||||
file_path=file_path,
|
||||
api_key=self.kwargs["DATALAB_MARKER_API_KEY"],
|
||||
api_key=self.kwargs['DATALAB_MARKER_API_KEY'],
|
||||
api_base_url=api_base_url,
|
||||
additional_config=self.kwargs.get("DATALAB_MARKER_ADDITIONAL_CONFIG"),
|
||||
use_llm=self.kwargs.get("DATALAB_MARKER_USE_LLM", False),
|
||||
skip_cache=self.kwargs.get("DATALAB_MARKER_SKIP_CACHE", False),
|
||||
force_ocr=self.kwargs.get("DATALAB_MARKER_FORCE_OCR", False),
|
||||
paginate=self.kwargs.get("DATALAB_MARKER_PAGINATE", False),
|
||||
strip_existing_ocr=self.kwargs.get(
|
||||
"DATALAB_MARKER_STRIP_EXISTING_OCR", False
|
||||
),
|
||||
disable_image_extraction=self.kwargs.get(
|
||||
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", False
|
||||
),
|
||||
format_lines=self.kwargs.get("DATALAB_MARKER_FORMAT_LINES", False),
|
||||
output_format=self.kwargs.get(
|
||||
"DATALAB_MARKER_OUTPUT_FORMAT", "markdown"
|
||||
),
|
||||
additional_config=self.kwargs.get('DATALAB_MARKER_ADDITIONAL_CONFIG'),
|
||||
use_llm=self.kwargs.get('DATALAB_MARKER_USE_LLM', False),
|
||||
skip_cache=self.kwargs.get('DATALAB_MARKER_SKIP_CACHE', False),
|
||||
force_ocr=self.kwargs.get('DATALAB_MARKER_FORCE_OCR', False),
|
||||
paginate=self.kwargs.get('DATALAB_MARKER_PAGINATE', False),
|
||||
strip_existing_ocr=self.kwargs.get('DATALAB_MARKER_STRIP_EXISTING_OCR', False),
|
||||
disable_image_extraction=self.kwargs.get('DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION', False),
|
||||
format_lines=self.kwargs.get('DATALAB_MARKER_FORMAT_LINES', False),
|
||||
output_format=self.kwargs.get('DATALAB_MARKER_OUTPUT_FORMAT', 'markdown'),
|
||||
)
|
||||
elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"):
|
||||
elif self.engine == 'docling' and self.kwargs.get('DOCLING_SERVER_URL'):
|
||||
if self._is_text_file(file_ext, file_content_type):
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
else:
|
||||
# Build params for DoclingLoader
|
||||
params = self.kwargs.get("DOCLING_PARAMS", {})
|
||||
params = self.kwargs.get('DOCLING_PARAMS', {})
|
||||
if not isinstance(params, dict):
|
||||
try:
|
||||
params = json.loads(params)
|
||||
except json.JSONDecodeError:
|
||||
log.error("Invalid DOCLING_PARAMS format, expected JSON object")
|
||||
log.error('Invalid DOCLING_PARAMS format, expected JSON object')
|
||||
params = {}
|
||||
|
||||
loader = DoclingLoader(
|
||||
url=self.kwargs.get("DOCLING_SERVER_URL"),
|
||||
api_key=self.kwargs.get("DOCLING_API_KEY", None),
|
||||
url=self.kwargs.get('DOCLING_SERVER_URL'),
|
||||
api_key=self.kwargs.get('DOCLING_API_KEY', None),
|
||||
file_path=file_path,
|
||||
mime_type=file_content_type,
|
||||
params=params,
|
||||
)
|
||||
elif (
|
||||
self.engine == "document_intelligence"
|
||||
and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
|
||||
self.engine == 'document_intelligence'
|
||||
and self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT') != ''
|
||||
and (
|
||||
file_ext in ["pdf", "docx", "ppt", "pptx"]
|
||||
file_ext in ['pdf', 'docx', 'ppt', 'pptx']
|
||||
or file_content_type
|
||||
in [
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"application/vnd.ms-powerpoint",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'application/vnd.ms-powerpoint',
|
||||
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||
]
|
||||
)
|
||||
):
|
||||
if self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != "":
|
||||
if self.kwargs.get('DOCUMENT_INTELLIGENCE_KEY') != '':
|
||||
loader = AzureAIDocumentIntelligenceLoader(
|
||||
file_path=file_path,
|
||||
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
|
||||
api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
|
||||
api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"),
|
||||
api_endpoint=self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT'),
|
||||
api_key=self.kwargs.get('DOCUMENT_INTELLIGENCE_KEY'),
|
||||
api_model=self.kwargs.get('DOCUMENT_INTELLIGENCE_MODEL'),
|
||||
)
|
||||
else:
|
||||
loader = AzureAIDocumentIntelligenceLoader(
|
||||
file_path=file_path,
|
||||
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
|
||||
api_endpoint=self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT'),
|
||||
azure_credential=DefaultAzureCredential(),
|
||||
api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"),
|
||||
api_model=self.kwargs.get('DOCUMENT_INTELLIGENCE_MODEL'),
|
||||
)
|
||||
elif self.engine == "mineru" and file_ext in [
|
||||
"pdf"
|
||||
]: # MinerU currently only supports PDF
|
||||
|
||||
mineru_timeout = self.kwargs.get("MINERU_API_TIMEOUT", 300)
|
||||
elif self.engine == 'mineru' and file_ext in ['pdf']: # MinerU currently only supports PDF
|
||||
mineru_timeout = self.kwargs.get('MINERU_API_TIMEOUT', 300)
|
||||
if mineru_timeout:
|
||||
try:
|
||||
mineru_timeout = int(mineru_timeout)
|
||||
@@ -386,111 +370,115 @@ class Loader:
|
||||
|
||||
loader = MinerULoader(
|
||||
file_path=file_path,
|
||||
api_mode=self.kwargs.get("MINERU_API_MODE", "local"),
|
||||
api_url=self.kwargs.get("MINERU_API_URL", "http://localhost:8000"),
|
||||
api_key=self.kwargs.get("MINERU_API_KEY", ""),
|
||||
params=self.kwargs.get("MINERU_PARAMS", {}),
|
||||
api_mode=self.kwargs.get('MINERU_API_MODE', 'local'),
|
||||
api_url=self.kwargs.get('MINERU_API_URL', 'http://localhost:8000'),
|
||||
api_key=self.kwargs.get('MINERU_API_KEY', ''),
|
||||
params=self.kwargs.get('MINERU_PARAMS', {}),
|
||||
timeout=mineru_timeout,
|
||||
)
|
||||
elif (
|
||||
self.engine == "mistral_ocr"
|
||||
and self.kwargs.get("MISTRAL_OCR_API_KEY") != ""
|
||||
and file_ext
|
||||
in ["pdf"] # Mistral OCR currently only supports PDF and images
|
||||
self.engine == 'mistral_ocr'
|
||||
and self.kwargs.get('MISTRAL_OCR_API_KEY') != ''
|
||||
and file_ext in ['pdf'] # Mistral OCR currently only supports PDF and images
|
||||
):
|
||||
loader = MistralLoader(
|
||||
base_url=self.kwargs.get("MISTRAL_OCR_API_BASE_URL"),
|
||||
api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"),
|
||||
base_url=self.kwargs.get('MISTRAL_OCR_API_BASE_URL'),
|
||||
api_key=self.kwargs.get('MISTRAL_OCR_API_KEY'),
|
||||
file_path=file_path,
|
||||
)
|
||||
else:
|
||||
if file_ext == "pdf":
|
||||
if file_ext == 'pdf':
|
||||
loader = PyPDFLoader(
|
||||
file_path,
|
||||
extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"),
|
||||
mode=self.kwargs.get("PDF_LOADER_MODE", "page"),
|
||||
extract_images=self.kwargs.get('PDF_EXTRACT_IMAGES'),
|
||||
mode=self.kwargs.get('PDF_LOADER_MODE', 'page'),
|
||||
)
|
||||
elif file_ext == "csv":
|
||||
elif file_ext == 'csv':
|
||||
loader = CSVLoader(file_path, autodetect_encoding=True)
|
||||
elif file_ext == "rst":
|
||||
elif file_ext == 'rst':
|
||||
try:
|
||||
from langchain_community.document_loaders import UnstructuredRSTLoader
|
||||
loader = UnstructuredRSTLoader(file_path, mode="elements")
|
||||
|
||||
loader = UnstructuredRSTLoader(file_path, mode='elements')
|
||||
except ImportError:
|
||||
log.warning(
|
||||
"The 'unstructured' package is not installed. "
|
||||
"Falling back to plain text loading for .rst file. "
|
||||
"Install it with: pip install unstructured"
|
||||
'Falling back to plain text loading for .rst file. '
|
||||
'Install it with: pip install unstructured'
|
||||
)
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
elif file_ext == "xml":
|
||||
elif file_ext == 'xml':
|
||||
try:
|
||||
from langchain_community.document_loaders import UnstructuredXMLLoader
|
||||
|
||||
loader = UnstructuredXMLLoader(file_path)
|
||||
except ImportError:
|
||||
log.warning(
|
||||
"The 'unstructured' package is not installed. "
|
||||
"Falling back to plain text loading for .xml file. "
|
||||
"Install it with: pip install unstructured"
|
||||
'Falling back to plain text loading for .xml file. '
|
||||
'Install it with: pip install unstructured'
|
||||
)
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
elif file_ext in ["htm", "html"]:
|
||||
loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
|
||||
elif file_ext == "md":
|
||||
elif file_ext in ['htm', 'html']:
|
||||
loader = BSHTMLLoader(file_path, open_encoding='unicode_escape')
|
||||
elif file_ext == 'md':
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
elif file_content_type == "application/epub+zip":
|
||||
elif file_content_type == 'application/epub+zip':
|
||||
try:
|
||||
from langchain_community.document_loaders import UnstructuredEPubLoader
|
||||
|
||||
loader = UnstructuredEPubLoader(file_path)
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Processing .epub files requires the 'unstructured' package. "
|
||||
"Install it with: pip install unstructured"
|
||||
'Install it with: pip install unstructured'
|
||||
)
|
||||
elif (
|
||||
file_content_type
|
||||
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
or file_ext == "docx"
|
||||
file_content_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
|
||||
or file_ext == 'docx'
|
||||
):
|
||||
loader = Docx2txtLoader(file_path)
|
||||
elif file_content_type in [
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
] or file_ext in ["xls", "xlsx"]:
|
||||
'application/vnd.ms-excel',
|
||||
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||
] or file_ext in ['xls', 'xlsx']:
|
||||
try:
|
||||
from langchain_community.document_loaders import UnstructuredExcelLoader
|
||||
|
||||
loader = UnstructuredExcelLoader(file_path)
|
||||
except ImportError:
|
||||
log.warning(
|
||||
"The 'unstructured' package is not installed. "
|
||||
"Falling back to pandas for Excel file loading. "
|
||||
"Install unstructured for better results: pip install unstructured"
|
||||
'Falling back to pandas for Excel file loading. '
|
||||
'Install unstructured for better results: pip install unstructured'
|
||||
)
|
||||
loader = ExcelLoader(file_path)
|
||||
elif file_content_type in [
|
||||
"application/vnd.ms-powerpoint",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
] or file_ext in ["ppt", "pptx"]:
|
||||
'application/vnd.ms-powerpoint',
|
||||
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||
] or file_ext in ['ppt', 'pptx']:
|
||||
try:
|
||||
from langchain_community.document_loaders import UnstructuredPowerPointLoader
|
||||
|
||||
loader = UnstructuredPowerPointLoader(file_path)
|
||||
except ImportError:
|
||||
log.warning(
|
||||
"The 'unstructured' package is not installed. "
|
||||
"Falling back to python-pptx for PowerPoint file loading. "
|
||||
"Install unstructured for better results: pip install unstructured"
|
||||
'Falling back to python-pptx for PowerPoint file loading. '
|
||||
'Install unstructured for better results: pip install unstructured'
|
||||
)
|
||||
loader = PptxLoader(file_path)
|
||||
elif file_ext == "msg":
|
||||
elif file_ext == 'msg':
|
||||
loader = OutlookMessageLoader(file_path)
|
||||
elif file_ext == "odt":
|
||||
elif file_ext == 'odt':
|
||||
try:
|
||||
from langchain_community.document_loaders import UnstructuredODTLoader
|
||||
|
||||
loader = UnstructuredODTLoader(file_path)
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Processing .odt files requires the 'unstructured' package. "
|
||||
"Install it with: pip install unstructured"
|
||||
'Install it with: pip install unstructured'
|
||||
)
|
||||
elif self._is_text_file(file_ext, file_content_type):
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
@@ -498,4 +486,3 @@ class Loader:
|
||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||
|
||||
return loader
|
||||
|
||||
|
||||
@@ -22,37 +22,35 @@ class MinerULoader:
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
api_mode: str = "local",
|
||||
api_url: str = "http://localhost:8000",
|
||||
api_key: str = "",
|
||||
api_mode: str = 'local',
|
||||
api_url: str = 'http://localhost:8000',
|
||||
api_key: str = '',
|
||||
params: dict = None,
|
||||
timeout: Optional[int] = 300,
|
||||
):
|
||||
self.file_path = file_path
|
||||
self.api_mode = api_mode.lower()
|
||||
self.api_url = api_url.rstrip("/")
|
||||
self.api_url = api_url.rstrip('/')
|
||||
self.api_key = api_key
|
||||
self.timeout = timeout
|
||||
|
||||
# Parse params dict with defaults
|
||||
self.params = params or {}
|
||||
self.enable_ocr = params.get("enable_ocr", False)
|
||||
self.enable_formula = params.get("enable_formula", True)
|
||||
self.enable_table = params.get("enable_table", True)
|
||||
self.language = params.get("language", "en")
|
||||
self.model_version = params.get("model_version", "pipeline")
|
||||
self.enable_ocr = params.get('enable_ocr', False)
|
||||
self.enable_formula = params.get('enable_formula', True)
|
||||
self.enable_table = params.get('enable_table', True)
|
||||
self.language = params.get('language', 'en')
|
||||
self.model_version = params.get('model_version', 'pipeline')
|
||||
|
||||
self.page_ranges = self.params.pop("page_ranges", "")
|
||||
self.page_ranges = self.params.pop('page_ranges', '')
|
||||
|
||||
# Validate API mode
|
||||
if self.api_mode not in ["local", "cloud"]:
|
||||
raise ValueError(
|
||||
f"Invalid API mode: {self.api_mode}. Must be 'local' or 'cloud'"
|
||||
)
|
||||
if self.api_mode not in ['local', 'cloud']:
|
||||
raise ValueError(f"Invalid API mode: {self.api_mode}. Must be 'local' or 'cloud'")
|
||||
|
||||
# Validate Cloud API requirements
|
||||
if self.api_mode == "cloud" and not self.api_key:
|
||||
raise ValueError("API key is required for Cloud API mode")
|
||||
if self.api_mode == 'cloud' and not self.api_key:
|
||||
raise ValueError('API key is required for Cloud API mode')
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""
|
||||
@@ -60,12 +58,12 @@ class MinerULoader:
|
||||
Routes to Cloud or Local API based on api_mode.
|
||||
"""
|
||||
try:
|
||||
if self.api_mode == "cloud":
|
||||
if self.api_mode == 'cloud':
|
||||
return self._load_cloud_api()
|
||||
else:
|
||||
return self._load_local_api()
|
||||
except Exception as e:
|
||||
log.error(f"Error loading document with MinerU: {e}")
|
||||
log.error(f'Error loading document with MinerU: {e}')
|
||||
raise
|
||||
|
||||
def _load_local_api(self) -> List[Document]:
|
||||
@@ -73,14 +71,14 @@ class MinerULoader:
|
||||
Load document using Local API (synchronous).
|
||||
Posts file to /file_parse endpoint and gets immediate response.
|
||||
"""
|
||||
log.info(f"Using MinerU Local API at {self.api_url}")
|
||||
log.info(f'Using MinerU Local API at {self.api_url}')
|
||||
|
||||
filename = os.path.basename(self.file_path)
|
||||
|
||||
# Build form data for Local API
|
||||
form_data = {
|
||||
**self.params,
|
||||
"return_md": "true",
|
||||
'return_md': 'true',
|
||||
}
|
||||
|
||||
# Page ranges (Local API uses start_page_id and end_page_id)
|
||||
@@ -89,18 +87,18 @@ class MinerULoader:
|
||||
# Full page range parsing would require parsing the string
|
||||
log.warning(
|
||||
f"Page ranges '{self.page_ranges}' specified but Local API uses different format. "
|
||||
"Consider using start_page_id/end_page_id parameters if needed."
|
||||
'Consider using start_page_id/end_page_id parameters if needed.'
|
||||
)
|
||||
|
||||
try:
|
||||
with open(self.file_path, "rb") as f:
|
||||
files = {"files": (filename, f, "application/octet-stream")}
|
||||
with open(self.file_path, 'rb') as f:
|
||||
files = {'files': (filename, f, 'application/octet-stream')}
|
||||
|
||||
log.info(f"Sending file to MinerU Local API: {filename}")
|
||||
log.debug(f"Local API parameters: {form_data}")
|
||||
log.info(f'Sending file to MinerU Local API: {filename}')
|
||||
log.debug(f'Local API parameters: {form_data}')
|
||||
|
||||
response = requests.post(
|
||||
f"{self.api_url}/file_parse",
|
||||
f'{self.api_url}/file_parse',
|
||||
data=form_data,
|
||||
files=files,
|
||||
timeout=self.timeout,
|
||||
@@ -108,27 +106,25 @@ class MinerULoader:
|
||||
response.raise_for_status()
|
||||
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(
|
||||
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
|
||||
)
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}')
|
||||
except requests.Timeout:
|
||||
raise HTTPException(
|
||||
status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
detail="MinerU Local API request timed out",
|
||||
detail='MinerU Local API request timed out',
|
||||
)
|
||||
except requests.HTTPError as e:
|
||||
error_detail = f"MinerU Local API request failed: {e}"
|
||||
error_detail = f'MinerU Local API request failed: {e}'
|
||||
if e.response is not None:
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
error_detail += f" - {error_data}"
|
||||
error_detail += f' - {error_data}'
|
||||
except Exception:
|
||||
error_detail += f" - {e.response.text}"
|
||||
error_detail += f' - {e.response.text}'
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error calling MinerU Local API: {str(e)}",
|
||||
detail=f'Error calling MinerU Local API: {str(e)}',
|
||||
)
|
||||
|
||||
# Parse response
|
||||
@@ -137,41 +133,41 @@ class MinerULoader:
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Invalid JSON response from MinerU Local API: {e}",
|
||||
detail=f'Invalid JSON response from MinerU Local API: {e}',
|
||||
)
|
||||
|
||||
# Extract markdown content from response
|
||||
if "results" not in result:
|
||||
if 'results' not in result:
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail="MinerU Local API response missing 'results' field",
|
||||
)
|
||||
|
||||
results = result["results"]
|
||||
results = result['results']
|
||||
if not results:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail="MinerU returned empty results",
|
||||
detail='MinerU returned empty results',
|
||||
)
|
||||
|
||||
# Get the first (and typically only) result
|
||||
file_result = list(results.values())[0]
|
||||
markdown_content = file_result.get("md_content", "")
|
||||
markdown_content = file_result.get('md_content', '')
|
||||
|
||||
if not markdown_content:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail="MinerU returned empty markdown content",
|
||||
detail='MinerU returned empty markdown content',
|
||||
)
|
||||
|
||||
log.info(f"Successfully parsed document with MinerU Local API: {filename}")
|
||||
log.info(f'Successfully parsed document with MinerU Local API: {filename}')
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"source": filename,
|
||||
"api_mode": "local",
|
||||
"backend": result.get("backend", "unknown"),
|
||||
"version": result.get("version", "unknown"),
|
||||
'source': filename,
|
||||
'api_mode': 'local',
|
||||
'backend': result.get('backend', 'unknown'),
|
||||
'version': result.get('version', 'unknown'),
|
||||
}
|
||||
|
||||
return [Document(page_content=markdown_content, metadata=metadata)]
|
||||
@@ -181,7 +177,7 @@ class MinerULoader:
|
||||
Load document using Cloud API (asynchronous).
|
||||
Uses batch upload endpoint to avoid need for public file URLs.
|
||||
"""
|
||||
log.info(f"Using MinerU Cloud API at {self.api_url}")
|
||||
log.info(f'Using MinerU Cloud API at {self.api_url}')
|
||||
|
||||
filename = os.path.basename(self.file_path)
|
||||
|
||||
@@ -195,17 +191,15 @@ class MinerULoader:
|
||||
result = self._poll_batch_status(batch_id, filename)
|
||||
|
||||
# Step 4: Download and extract markdown from ZIP
|
||||
markdown_content = self._download_and_extract_zip(
|
||||
result["full_zip_url"], filename
|
||||
)
|
||||
markdown_content = self._download_and_extract_zip(result['full_zip_url'], filename)
|
||||
|
||||
log.info(f"Successfully parsed document with MinerU Cloud API: {filename}")
|
||||
log.info(f'Successfully parsed document with MinerU Cloud API: {filename}')
|
||||
|
||||
# Create metadata
|
||||
metadata = {
|
||||
"source": filename,
|
||||
"api_mode": "cloud",
|
||||
"batch_id": batch_id,
|
||||
'source': filename,
|
||||
'api_mode': 'cloud',
|
||||
'batch_id': batch_id,
|
||||
}
|
||||
|
||||
return [Document(page_content=markdown_content, metadata=metadata)]
|
||||
@@ -216,49 +210,49 @@ class MinerULoader:
|
||||
Returns (batch_id, upload_url).
|
||||
"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
'Authorization': f'Bearer {self.api_key}',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
# Build request body
|
||||
request_body = {
|
||||
**self.params,
|
||||
"files": [
|
||||
'files': [
|
||||
{
|
||||
"name": filename,
|
||||
"is_ocr": self.enable_ocr,
|
||||
'name': filename,
|
||||
'is_ocr': self.enable_ocr,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# Add page ranges if specified
|
||||
if self.page_ranges:
|
||||
request_body["files"][0]["page_ranges"] = self.page_ranges
|
||||
request_body['files'][0]['page_ranges'] = self.page_ranges
|
||||
|
||||
log.info(f"Requesting upload URL for: {filename}")
|
||||
log.debug(f"Cloud API request body: {request_body}")
|
||||
log.info(f'Requesting upload URL for: {filename}')
|
||||
log.debug(f'Cloud API request body: {request_body}')
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self.api_url}/file-urls/batch",
|
||||
f'{self.api_url}/file-urls/batch',
|
||||
headers=headers,
|
||||
json=request_body,
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
error_detail = f"Failed to request upload URL: {e}"
|
||||
error_detail = f'Failed to request upload URL: {e}'
|
||||
if e.response is not None:
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
error_detail += f" - {error_data.get('msg', error_data)}"
|
||||
error_detail += f' - {error_data.get("msg", error_data)}'
|
||||
except Exception:
|
||||
error_detail += f" - {e.response.text}"
|
||||
error_detail += f' - {e.response.text}'
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error requesting upload URL: {str(e)}",
|
||||
detail=f'Error requesting upload URL: {str(e)}',
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -266,28 +260,28 @@ class MinerULoader:
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Invalid JSON response: {e}",
|
||||
detail=f'Invalid JSON response: {e}',
|
||||
)
|
||||
|
||||
# Check for API error response
|
||||
if result.get("code") != 0:
|
||||
if result.get('code') != 0:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}",
|
||||
detail=f'MinerU Cloud API error: {result.get("msg", "Unknown error")}',
|
||||
)
|
||||
|
||||
data = result.get("data", {})
|
||||
batch_id = data.get("batch_id")
|
||||
file_urls = data.get("file_urls", [])
|
||||
data = result.get('data', {})
|
||||
batch_id = data.get('batch_id')
|
||||
file_urls = data.get('file_urls', [])
|
||||
|
||||
if not batch_id or not file_urls:
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail="MinerU Cloud API response missing batch_id or file_urls",
|
||||
detail='MinerU Cloud API response missing batch_id or file_urls',
|
||||
)
|
||||
|
||||
upload_url = file_urls[0]
|
||||
log.info(f"Received upload URL for batch: {batch_id}")
|
||||
log.info(f'Received upload URL for batch: {batch_id}')
|
||||
|
||||
return batch_id, upload_url
|
||||
|
||||
@@ -295,10 +289,10 @@ class MinerULoader:
|
||||
"""
|
||||
Upload file to presigned URL (no authentication needed).
|
||||
"""
|
||||
log.info(f"Uploading file to presigned URL")
|
||||
log.info(f'Uploading file to presigned URL')
|
||||
|
||||
try:
|
||||
with open(self.file_path, "rb") as f:
|
||||
with open(self.file_path, 'rb') as f:
|
||||
response = requests.put(
|
||||
upload_url,
|
||||
data=f,
|
||||
@@ -306,26 +300,24 @@ class MinerULoader:
|
||||
)
|
||||
response.raise_for_status()
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(
|
||||
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
|
||||
)
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}')
|
||||
except requests.Timeout:
|
||||
raise HTTPException(
|
||||
status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
detail="File upload to presigned URL timed out",
|
||||
detail='File upload to presigned URL timed out',
|
||||
)
|
||||
except requests.HTTPError as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Failed to upload file to presigned URL: {e}",
|
||||
detail=f'Failed to upload file to presigned URL: {e}',
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error uploading file: {str(e)}",
|
||||
detail=f'Error uploading file: {str(e)}',
|
||||
)
|
||||
|
||||
log.info("File uploaded successfully")
|
||||
log.info('File uploaded successfully')
|
||||
|
||||
def _poll_batch_status(self, batch_id: str, filename: str) -> dict:
|
||||
"""
|
||||
@@ -333,35 +325,35 @@ class MinerULoader:
|
||||
Returns the result dict for the file.
|
||||
"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
'Authorization': f'Bearer {self.api_key}',
|
||||
}
|
||||
|
||||
max_iterations = 300 # 10 minutes max (2 seconds per iteration)
|
||||
poll_interval = 2 # seconds
|
||||
|
||||
log.info(f"Polling batch status: {batch_id}")
|
||||
log.info(f'Polling batch status: {batch_id}')
|
||||
|
||||
for iteration in range(max_iterations):
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{self.api_url}/extract-results/batch/{batch_id}",
|
||||
f'{self.api_url}/extract-results/batch/{batch_id}',
|
||||
headers=headers,
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
error_detail = f"Failed to poll batch status: {e}"
|
||||
error_detail = f'Failed to poll batch status: {e}'
|
||||
if e.response is not None:
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
error_detail += f" - {error_data.get('msg', error_data)}"
|
||||
error_detail += f' - {error_data.get("msg", error_data)}'
|
||||
except Exception:
|
||||
error_detail += f" - {e.response.text}"
|
||||
error_detail += f' - {e.response.text}'
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error polling batch status: {str(e)}",
|
||||
detail=f'Error polling batch status: {str(e)}',
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -369,58 +361,56 @@ class MinerULoader:
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Invalid JSON response while polling: {e}",
|
||||
detail=f'Invalid JSON response while polling: {e}',
|
||||
)
|
||||
|
||||
# Check for API error response
|
||||
if result.get("code") != 0:
|
||||
if result.get('code') != 0:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}",
|
||||
detail=f'MinerU Cloud API error: {result.get("msg", "Unknown error")}',
|
||||
)
|
||||
|
||||
data = result.get("data", {})
|
||||
extract_result = data.get("extract_result", [])
|
||||
data = result.get('data', {})
|
||||
extract_result = data.get('extract_result', [])
|
||||
|
||||
# Find our file in the batch results
|
||||
file_result = None
|
||||
for item in extract_result:
|
||||
if item.get("file_name") == filename:
|
||||
if item.get('file_name') == filename:
|
||||
file_result = item
|
||||
break
|
||||
|
||||
if not file_result:
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"File {filename} not found in batch results",
|
||||
detail=f'File {filename} not found in batch results',
|
||||
)
|
||||
|
||||
state = file_result.get("state")
|
||||
state = file_result.get('state')
|
||||
|
||||
if state == "done":
|
||||
log.info(f"Processing complete for {filename}")
|
||||
if state == 'done':
|
||||
log.info(f'Processing complete for {filename}')
|
||||
return file_result
|
||||
elif state == "failed":
|
||||
error_msg = file_result.get("err_msg", "Unknown error")
|
||||
elif state == 'failed':
|
||||
error_msg = file_result.get('err_msg', 'Unknown error')
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"MinerU processing failed: {error_msg}",
|
||||
detail=f'MinerU processing failed: {error_msg}',
|
||||
)
|
||||
elif state in ["waiting-file", "pending", "running", "converting"]:
|
||||
elif state in ['waiting-file', 'pending', 'running', 'converting']:
|
||||
# Still processing
|
||||
if iteration % 10 == 0: # Log every 20 seconds
|
||||
log.info(
|
||||
f"Processing status: {state} (iteration {iteration + 1}/{max_iterations})"
|
||||
)
|
||||
log.info(f'Processing status: {state} (iteration {iteration + 1}/{max_iterations})')
|
||||
time.sleep(poll_interval)
|
||||
else:
|
||||
log.warning(f"Unknown state: {state}")
|
||||
log.warning(f'Unknown state: {state}')
|
||||
time.sleep(poll_interval)
|
||||
|
||||
# Timeout
|
||||
raise HTTPException(
|
||||
status.HTTP_504_GATEWAY_TIMEOUT,
|
||||
detail="MinerU processing timed out after 10 minutes",
|
||||
detail='MinerU processing timed out after 10 minutes',
|
||||
)
|
||||
|
||||
def _download_and_extract_zip(self, zip_url: str, filename: str) -> str:
|
||||
@@ -428,7 +418,7 @@ class MinerULoader:
|
||||
Download ZIP file from CDN and extract markdown content.
|
||||
Returns the markdown content as a string.
|
||||
"""
|
||||
log.info(f"Downloading results from: {zip_url}")
|
||||
log.info(f'Downloading results from: {zip_url}')
|
||||
|
||||
try:
|
||||
response = requests.get(zip_url, timeout=60)
|
||||
@@ -436,23 +426,23 @@ class MinerULoader:
|
||||
except requests.HTTPError as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Failed to download results ZIP: {e}",
|
||||
detail=f'Failed to download results ZIP: {e}',
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error downloading results: {str(e)}",
|
||||
detail=f'Error downloading results: {str(e)}',
|
||||
)
|
||||
|
||||
# Save ZIP to temporary file and extract
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as tmp_zip:
|
||||
tmp_zip.write(response.content)
|
||||
tmp_zip_path = tmp_zip.name
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Extract ZIP
|
||||
with zipfile.ZipFile(tmp_zip_path, "r") as zip_ref:
|
||||
with zipfile.ZipFile(tmp_zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(tmp_dir)
|
||||
|
||||
# Find markdown file - search recursively for any .md file
|
||||
@@ -466,33 +456,27 @@ class MinerULoader:
|
||||
full_path = os.path.join(root, file)
|
||||
all_files.append(full_path)
|
||||
# Look for any .md file
|
||||
if file.endswith(".md"):
|
||||
if file.endswith('.md'):
|
||||
found_md_path = full_path
|
||||
log.info(f"Found markdown file at: {full_path}")
|
||||
log.info(f'Found markdown file at: {full_path}')
|
||||
try:
|
||||
with open(full_path, "r", encoding="utf-8") as f:
|
||||
with open(full_path, 'r', encoding='utf-8') as f:
|
||||
markdown_content = f.read()
|
||||
if (
|
||||
markdown_content
|
||||
): # Use the first non-empty markdown file
|
||||
if markdown_content: # Use the first non-empty markdown file
|
||||
break
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to read {full_path}: {e}")
|
||||
log.warning(f'Failed to read {full_path}: {e}')
|
||||
if markdown_content:
|
||||
break
|
||||
|
||||
if markdown_content is None:
|
||||
log.error(f"Available files in ZIP: {all_files}")
|
||||
log.error(f'Available files in ZIP: {all_files}')
|
||||
# Try to provide more helpful error message
|
||||
md_files = [f for f in all_files if f.endswith(".md")]
|
||||
md_files = [f for f in all_files if f.endswith('.md')]
|
||||
if md_files:
|
||||
error_msg = (
|
||||
f"Found .md files but couldn't read them: {md_files}"
|
||||
)
|
||||
error_msg = f"Found .md files but couldn't read them: {md_files}"
|
||||
else:
|
||||
error_msg = (
|
||||
f"No .md files found in ZIP. Available files: {all_files}"
|
||||
)
|
||||
error_msg = f'No .md files found in ZIP. Available files: {all_files}'
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail=error_msg,
|
||||
@@ -504,21 +488,19 @@ class MinerULoader:
|
||||
except zipfile.BadZipFile as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Invalid ZIP file received: {e}",
|
||||
detail=f'Invalid ZIP file received: {e}',
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Error extracting ZIP: {str(e)}",
|
||||
detail=f'Error extracting ZIP: {str(e)}',
|
||||
)
|
||||
|
||||
if not markdown_content:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
detail="Extracted markdown content is empty",
|
||||
detail='Extracted markdown content is empty',
|
||||
)
|
||||
|
||||
log.info(
|
||||
f"Successfully extracted markdown content ({len(markdown_content)} characters)"
|
||||
)
|
||||
log.info(f'Successfully extracted markdown content ({len(markdown_content)} characters)')
|
||||
return markdown_content
|
||||
|
||||
@@ -49,13 +49,11 @@ class MistralLoader:
|
||||
enable_debug_logging: Enable detailed debug logs.
|
||||
"""
|
||||
if not api_key:
|
||||
raise ValueError("API key cannot be empty.")
|
||||
raise ValueError('API key cannot be empty.')
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found at {file_path}")
|
||||
raise FileNotFoundError(f'File not found at {file_path}')
|
||||
|
||||
self.base_url = (
|
||||
base_url.rstrip("/") if base_url else "https://api.mistral.ai/v1"
|
||||
)
|
||||
self.base_url = base_url.rstrip('/') if base_url else 'https://api.mistral.ai/v1'
|
||||
self.api_key = api_key
|
||||
self.file_path = file_path
|
||||
self.timeout = timeout
|
||||
@@ -65,18 +63,10 @@ class MistralLoader:
|
||||
# PERFORMANCE OPTIMIZATION: Differentiated timeouts for different operations
|
||||
# This prevents long-running OCR operations from affecting quick operations
|
||||
# and improves user experience by failing fast on operations that should be quick
|
||||
self.upload_timeout = min(
|
||||
timeout, 120
|
||||
) # Cap upload at 2 minutes - prevents hanging on large files
|
||||
self.url_timeout = (
|
||||
30 # URL requests should be fast - fail quickly if API is slow
|
||||
)
|
||||
self.ocr_timeout = (
|
||||
timeout # OCR can take the full timeout - this is the heavy operation
|
||||
)
|
||||
self.cleanup_timeout = (
|
||||
30 # Cleanup should be quick - don't hang on file deletion
|
||||
)
|
||||
self.upload_timeout = min(timeout, 120) # Cap upload at 2 minutes - prevents hanging on large files
|
||||
self.url_timeout = 30 # URL requests should be fast - fail quickly if API is slow
|
||||
self.ocr_timeout = timeout # OCR can take the full timeout - this is the heavy operation
|
||||
self.cleanup_timeout = 30 # Cleanup should be quick - don't hang on file deletion
|
||||
|
||||
# PERFORMANCE OPTIMIZATION: Pre-compute file info to avoid repeated filesystem calls
|
||||
# This avoids multiple os.path.basename() and os.path.getsize() calls during processing
|
||||
@@ -85,8 +75,8 @@ class MistralLoader:
|
||||
|
||||
# ENHANCEMENT: Added User-Agent for better API tracking and debugging
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"User-Agent": "OpenWebUI-MistralLoader/2.0", # Helps API provider track usage
|
||||
'Authorization': f'Bearer {self.api_key}',
|
||||
'User-Agent': 'OpenWebUI-MistralLoader/2.0', # Helps API provider track usage
|
||||
}
|
||||
|
||||
def _debug_log(self, message: str, *args) -> None:
|
||||
@@ -108,43 +98,39 @@ class MistralLoader:
|
||||
return {} # Return empty dict if no content
|
||||
return response.json()
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
log.error(f"HTTP error occurred: {http_err} - Response: {response.text}")
|
||||
log.error(f'HTTP error occurred: {http_err} - Response: {response.text}')
|
||||
raise
|
||||
except requests.exceptions.RequestException as req_err:
|
||||
log.error(f"Request exception occurred: {req_err}")
|
||||
log.error(f'Request exception occurred: {req_err}')
|
||||
raise
|
||||
except ValueError as json_err: # Includes JSONDecodeError
|
||||
log.error(f"JSON decode error: {json_err} - Response: {response.text}")
|
||||
log.error(f'JSON decode error: {json_err} - Response: {response.text}')
|
||||
raise # Re-raise after logging
|
||||
|
||||
async def _handle_response_async(
|
||||
self, response: aiohttp.ClientResponse
|
||||
) -> Dict[str, Any]:
|
||||
async def _handle_response_async(self, response: aiohttp.ClientResponse) -> Dict[str, Any]:
|
||||
"""Async version of response handling with better error info."""
|
||||
try:
|
||||
response.raise_for_status()
|
||||
|
||||
# Check content type
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "application/json" not in content_type:
|
||||
content_type = response.headers.get('content-type', '')
|
||||
if 'application/json' not in content_type:
|
||||
if response.status == 204:
|
||||
return {}
|
||||
text = await response.text()
|
||||
raise ValueError(
|
||||
f"Unexpected content type: {content_type}, body: {text[:200]}..."
|
||||
)
|
||||
raise ValueError(f'Unexpected content type: {content_type}, body: {text[:200]}...')
|
||||
|
||||
return await response.json()
|
||||
|
||||
except aiohttp.ClientResponseError as e:
|
||||
error_text = await response.text() if response else "No response"
|
||||
log.error(f"HTTP {e.status}: {e.message} - Response: {error_text[:500]}")
|
||||
error_text = await response.text() if response else 'No response'
|
||||
log.error(f'HTTP {e.status}: {e.message} - Response: {error_text[:500]}')
|
||||
raise
|
||||
except aiohttp.ClientError as e:
|
||||
log.error(f"Client error: {e}")
|
||||
log.error(f'Client error: {e}')
|
||||
raise
|
||||
except Exception as e:
|
||||
log.error(f"Unexpected error processing response: {e}")
|
||||
log.error(f'Unexpected error processing response: {e}')
|
||||
raise
|
||||
|
||||
def _is_retryable_error(self, error: Exception) -> bool:
|
||||
@@ -172,13 +158,11 @@ class MistralLoader:
|
||||
return True # Timeouts might resolve on retry
|
||||
if isinstance(error, requests.exceptions.HTTPError):
|
||||
# Only retry on server errors (5xx) or rate limits (429)
|
||||
if hasattr(error, "response") and error.response is not None:
|
||||
if hasattr(error, 'response') and error.response is not None:
|
||||
status_code = error.response.status_code
|
||||
return status_code >= 500 or status_code == 429
|
||||
return False
|
||||
if isinstance(
|
||||
error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)
|
||||
):
|
||||
if isinstance(error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)):
|
||||
return True # Async network/timeout errors are retryable
|
||||
if isinstance(error, aiohttp.ClientResponseError):
|
||||
return error.status >= 500 or error.status == 429
|
||||
@@ -204,8 +188,7 @@ class MistralLoader:
|
||||
# Prevents overwhelming the server while ensuring reasonable retry delays
|
||||
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
|
||||
log.warning(
|
||||
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
|
||||
f"Retrying in {wait_time}s..."
|
||||
f'Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...'
|
||||
)
|
||||
time.sleep(wait_time)
|
||||
|
||||
@@ -226,8 +209,7 @@ class MistralLoader:
|
||||
# PERFORMANCE OPTIMIZATION: Non-blocking exponential backoff
|
||||
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
|
||||
log.warning(
|
||||
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
|
||||
f"Retrying in {wait_time}s..."
|
||||
f'Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...'
|
||||
)
|
||||
await asyncio.sleep(wait_time) # Non-blocking wait
|
||||
|
||||
@@ -240,15 +222,15 @@ class MistralLoader:
|
||||
Although streaming is not enabled for this endpoint, the file is opened
|
||||
in a context manager to minimize memory usage duration.
|
||||
"""
|
||||
log.info("Uploading file to Mistral API")
|
||||
url = f"{self.base_url}/files"
|
||||
log.info('Uploading file to Mistral API')
|
||||
url = f'{self.base_url}/files'
|
||||
|
||||
def upload_request():
|
||||
# MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime
|
||||
# This ensures the file is closed immediately after reading, reducing memory usage
|
||||
with open(self.file_path, "rb") as f:
|
||||
files = {"file": (self.file_name, f, "application/pdf")}
|
||||
data = {"purpose": "ocr"}
|
||||
with open(self.file_path, 'rb') as f:
|
||||
files = {'file': (self.file_name, f, 'application/pdf')}
|
||||
data = {'purpose': 'ocr'}
|
||||
|
||||
# NOTE: stream=False is required for this endpoint
|
||||
# The Mistral API doesn't support chunked uploads for this endpoint
|
||||
@@ -265,42 +247,38 @@ class MistralLoader:
|
||||
|
||||
try:
|
||||
response_data = self._retry_request_sync(upload_request)
|
||||
file_id = response_data.get("id")
|
||||
file_id = response_data.get('id')
|
||||
if not file_id:
|
||||
raise ValueError("File ID not found in upload response.")
|
||||
log.info(f"File uploaded successfully. File ID: {file_id}")
|
||||
raise ValueError('File ID not found in upload response.')
|
||||
log.info(f'File uploaded successfully. File ID: {file_id}')
|
||||
return file_id
|
||||
except Exception as e:
|
||||
log.error(f"Failed to upload file: {e}")
|
||||
log.error(f'Failed to upload file: {e}')
|
||||
raise
|
||||
|
||||
async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
|
||||
"""Async file upload with streaming for better memory efficiency."""
|
||||
url = f"{self.base_url}/files"
|
||||
url = f'{self.base_url}/files'
|
||||
|
||||
async def upload_request():
|
||||
# Create multipart writer for streaming upload
|
||||
writer = aiohttp.MultipartWriter("form-data")
|
||||
writer = aiohttp.MultipartWriter('form-data')
|
||||
|
||||
# Add purpose field
|
||||
purpose_part = writer.append("ocr")
|
||||
purpose_part.set_content_disposition("form-data", name="purpose")
|
||||
purpose_part = writer.append('ocr')
|
||||
purpose_part.set_content_disposition('form-data', name='purpose')
|
||||
|
||||
# Add file part with streaming
|
||||
file_part = writer.append_payload(
|
||||
aiohttp.streams.FilePayload(
|
||||
self.file_path,
|
||||
filename=self.file_name,
|
||||
content_type="application/pdf",
|
||||
content_type='application/pdf',
|
||||
)
|
||||
)
|
||||
file_part.set_content_disposition(
|
||||
"form-data", name="file", filename=self.file_name
|
||||
)
|
||||
file_part.set_content_disposition('form-data', name='file', filename=self.file_name)
|
||||
|
||||
self._debug_log(
|
||||
f"Uploading file: {self.file_name} ({self.file_size:,} bytes)"
|
||||
)
|
||||
self._debug_log(f'Uploading file: {self.file_name} ({self.file_size:,} bytes)')
|
||||
|
||||
async with session.post(
|
||||
url,
|
||||
@@ -312,48 +290,44 @@ class MistralLoader:
|
||||
|
||||
response_data = await self._retry_request_async(upload_request)
|
||||
|
||||
file_id = response_data.get("id")
|
||||
file_id = response_data.get('id')
|
||||
if not file_id:
|
||||
raise ValueError("File ID not found in upload response.")
|
||||
raise ValueError('File ID not found in upload response.')
|
||||
|
||||
log.info(f"File uploaded successfully. File ID: {file_id}")
|
||||
log.info(f'File uploaded successfully. File ID: {file_id}')
|
||||
return file_id
|
||||
|
||||
def _get_signed_url(self, file_id: str) -> str:
|
||||
"""Retrieves a temporary signed URL for the uploaded file (sync version)."""
|
||||
log.info(f"Getting signed URL for file ID: {file_id}")
|
||||
url = f"{self.base_url}/files/{file_id}/url"
|
||||
params = {"expiry": 1}
|
||||
signed_url_headers = {**self.headers, "Accept": "application/json"}
|
||||
log.info(f'Getting signed URL for file ID: {file_id}')
|
||||
url = f'{self.base_url}/files/{file_id}/url'
|
||||
params = {'expiry': 1}
|
||||
signed_url_headers = {**self.headers, 'Accept': 'application/json'}
|
||||
|
||||
def url_request():
|
||||
response = requests.get(
|
||||
url, headers=signed_url_headers, params=params, timeout=self.url_timeout
|
||||
)
|
||||
response = requests.get(url, headers=signed_url_headers, params=params, timeout=self.url_timeout)
|
||||
return self._handle_response(response)
|
||||
|
||||
try:
|
||||
response_data = self._retry_request_sync(url_request)
|
||||
signed_url = response_data.get("url")
|
||||
signed_url = response_data.get('url')
|
||||
if not signed_url:
|
||||
raise ValueError("Signed URL not found in response.")
|
||||
log.info("Signed URL received.")
|
||||
raise ValueError('Signed URL not found in response.')
|
||||
log.info('Signed URL received.')
|
||||
return signed_url
|
||||
except Exception as e:
|
||||
log.error(f"Failed to get signed URL: {e}")
|
||||
log.error(f'Failed to get signed URL: {e}')
|
||||
raise
|
||||
|
||||
async def _get_signed_url_async(
|
||||
self, session: aiohttp.ClientSession, file_id: str
|
||||
) -> str:
|
||||
async def _get_signed_url_async(self, session: aiohttp.ClientSession, file_id: str) -> str:
|
||||
"""Async signed URL retrieval."""
|
||||
url = f"{self.base_url}/files/{file_id}/url"
|
||||
params = {"expiry": 1}
|
||||
url = f'{self.base_url}/files/{file_id}/url'
|
||||
params = {'expiry': 1}
|
||||
|
||||
headers = {**self.headers, "Accept": "application/json"}
|
||||
headers = {**self.headers, 'Accept': 'application/json'}
|
||||
|
||||
async def url_request():
|
||||
self._debug_log(f"Getting signed URL for file ID: {file_id}")
|
||||
self._debug_log(f'Getting signed URL for file ID: {file_id}')
|
||||
async with session.get(
|
||||
url,
|
||||
headers=headers,
|
||||
@@ -364,69 +338,65 @@ class MistralLoader:
|
||||
|
||||
response_data = await self._retry_request_async(url_request)
|
||||
|
||||
signed_url = response_data.get("url")
|
||||
signed_url = response_data.get('url')
|
||||
if not signed_url:
|
||||
raise ValueError("Signed URL not found in response.")
|
||||
raise ValueError('Signed URL not found in response.')
|
||||
|
||||
self._debug_log("Signed URL received successfully")
|
||||
self._debug_log('Signed URL received successfully')
|
||||
return signed_url
|
||||
|
||||
def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
|
||||
"""Sends the signed URL to the OCR endpoint for processing (sync version)."""
|
||||
log.info("Processing OCR via Mistral API")
|
||||
url = f"{self.base_url}/ocr"
|
||||
log.info('Processing OCR via Mistral API')
|
||||
url = f'{self.base_url}/ocr'
|
||||
ocr_headers = {
|
||||
**self.headers,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
}
|
||||
payload = {
|
||||
"model": "mistral-ocr-latest",
|
||||
"document": {
|
||||
"type": "document_url",
|
||||
"document_url": signed_url,
|
||||
'model': 'mistral-ocr-latest',
|
||||
'document': {
|
||||
'type': 'document_url',
|
||||
'document_url': signed_url,
|
||||
},
|
||||
"include_image_base64": False,
|
||||
'include_image_base64': False,
|
||||
}
|
||||
|
||||
def ocr_request():
|
||||
response = requests.post(
|
||||
url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout
|
||||
)
|
||||
response = requests.post(url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout)
|
||||
return self._handle_response(response)
|
||||
|
||||
try:
|
||||
ocr_response = self._retry_request_sync(ocr_request)
|
||||
log.info("OCR processing done.")
|
||||
self._debug_log("OCR response: %s", ocr_response)
|
||||
log.info('OCR processing done.')
|
||||
self._debug_log('OCR response: %s', ocr_response)
|
||||
return ocr_response
|
||||
except Exception as e:
|
||||
log.error(f"Failed during OCR processing: {e}")
|
||||
log.error(f'Failed during OCR processing: {e}')
|
||||
raise
|
||||
|
||||
async def _process_ocr_async(
|
||||
self, session: aiohttp.ClientSession, signed_url: str
|
||||
) -> Dict[str, Any]:
|
||||
async def _process_ocr_async(self, session: aiohttp.ClientSession, signed_url: str) -> Dict[str, Any]:
|
||||
"""Async OCR processing with timing metrics."""
|
||||
url = f"{self.base_url}/ocr"
|
||||
url = f'{self.base_url}/ocr'
|
||||
|
||||
headers = {
|
||||
**self.headers,
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json',
|
||||
}
|
||||
|
||||
payload = {
|
||||
"model": "mistral-ocr-latest",
|
||||
"document": {
|
||||
"type": "document_url",
|
||||
"document_url": signed_url,
|
||||
'model': 'mistral-ocr-latest',
|
||||
'document': {
|
||||
'type': 'document_url',
|
||||
'document_url': signed_url,
|
||||
},
|
||||
"include_image_base64": False,
|
||||
'include_image_base64': False,
|
||||
}
|
||||
|
||||
async def ocr_request():
|
||||
log.info("Starting OCR processing via Mistral API")
|
||||
log.info('Starting OCR processing via Mistral API')
|
||||
start_time = time.time()
|
||||
|
||||
async with session.post(
|
||||
@@ -438,7 +408,7 @@ class MistralLoader:
|
||||
ocr_response = await self._handle_response_async(response)
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
log.info(f"OCR processing completed in {processing_time:.2f}s")
|
||||
log.info(f'OCR processing completed in {processing_time:.2f}s')
|
||||
|
||||
return ocr_response
|
||||
|
||||
@@ -446,42 +416,36 @@ class MistralLoader:
|
||||
|
||||
def _delete_file(self, file_id: str) -> None:
|
||||
"""Deletes the file from Mistral storage (sync version)."""
|
||||
log.info(f"Deleting uploaded file ID: {file_id}")
|
||||
url = f"{self.base_url}/files/{file_id}"
|
||||
log.info(f'Deleting uploaded file ID: {file_id}')
|
||||
url = f'{self.base_url}/files/{file_id}'
|
||||
|
||||
try:
|
||||
response = requests.delete(
|
||||
url, headers=self.headers, timeout=self.cleanup_timeout
|
||||
)
|
||||
response = requests.delete(url, headers=self.headers, timeout=self.cleanup_timeout)
|
||||
delete_response = self._handle_response(response)
|
||||
log.info(f"File deleted successfully: {delete_response}")
|
||||
log.info(f'File deleted successfully: {delete_response}')
|
||||
except Exception as e:
|
||||
# Log error but don't necessarily halt execution if deletion fails
|
||||
log.error(f"Failed to delete file ID {file_id}: {e}")
|
||||
log.error(f'Failed to delete file ID {file_id}: {e}')
|
||||
|
||||
async def _delete_file_async(
|
||||
self, session: aiohttp.ClientSession, file_id: str
|
||||
) -> None:
|
||||
async def _delete_file_async(self, session: aiohttp.ClientSession, file_id: str) -> None:
|
||||
"""Async file deletion with error tolerance."""
|
||||
try:
|
||||
|
||||
async def delete_request():
|
||||
self._debug_log(f"Deleting file ID: {file_id}")
|
||||
self._debug_log(f'Deleting file ID: {file_id}')
|
||||
async with session.delete(
|
||||
url=f"{self.base_url}/files/{file_id}",
|
||||
url=f'{self.base_url}/files/{file_id}',
|
||||
headers=self.headers,
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=self.cleanup_timeout
|
||||
), # Shorter timeout for cleanup
|
||||
timeout=aiohttp.ClientTimeout(total=self.cleanup_timeout), # Shorter timeout for cleanup
|
||||
) as response:
|
||||
return await self._handle_response_async(response)
|
||||
|
||||
await self._retry_request_async(delete_request)
|
||||
self._debug_log(f"File {file_id} deleted successfully")
|
||||
self._debug_log(f'File {file_id} deleted successfully')
|
||||
|
||||
except Exception as e:
|
||||
# Don't fail the entire process if cleanup fails
|
||||
log.warning(f"Failed to delete file ID {file_id}: {e}")
|
||||
log.warning(f'Failed to delete file ID {file_id}: {e}')
|
||||
|
||||
@asynccontextmanager
|
||||
async def _get_session(self):
|
||||
@@ -506,7 +470,7 @@ class MistralLoader:
|
||||
async with aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
timeout=timeout,
|
||||
headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
|
||||
headers={'User-Agent': 'OpenWebUI-MistralLoader/2.0'},
|
||||
raise_for_status=False, # We handle status codes manually
|
||||
trust_env=True,
|
||||
) as session:
|
||||
@@ -514,13 +478,13 @@ class MistralLoader:
|
||||
|
||||
def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
|
||||
"""Process OCR results into Document objects with enhanced metadata and memory efficiency."""
|
||||
pages_data = ocr_response.get("pages")
|
||||
pages_data = ocr_response.get('pages')
|
||||
if not pages_data:
|
||||
log.warning("No pages found in OCR response.")
|
||||
log.warning('No pages found in OCR response.')
|
||||
return [
|
||||
Document(
|
||||
page_content="No text content found",
|
||||
metadata={"error": "no_pages", "file_name": self.file_name},
|
||||
page_content='No text content found',
|
||||
metadata={'error': 'no_pages', 'file_name': self.file_name},
|
||||
)
|
||||
]
|
||||
|
||||
@@ -530,8 +494,8 @@ class MistralLoader:
|
||||
|
||||
# Process pages in a memory-efficient way
|
||||
for page_data in pages_data:
|
||||
page_content = page_data.get("markdown")
|
||||
page_index = page_data.get("index") # API uses 0-based index
|
||||
page_content = page_data.get('markdown')
|
||||
page_index = page_data.get('index') # API uses 0-based index
|
||||
|
||||
if page_content is None or page_index is None:
|
||||
skipped_pages += 1
|
||||
@@ -548,7 +512,7 @@ class MistralLoader:
|
||||
|
||||
if not cleaned_content:
|
||||
skipped_pages += 1
|
||||
self._debug_log(f"Skipping empty page {page_index}")
|
||||
self._debug_log(f'Skipping empty page {page_index}')
|
||||
continue
|
||||
|
||||
# Create document with optimized metadata
|
||||
@@ -556,34 +520,30 @@ class MistralLoader:
|
||||
Document(
|
||||
page_content=cleaned_content,
|
||||
metadata={
|
||||
"page": page_index, # 0-based index from API
|
||||
"page_label": page_index + 1, # 1-based label for convenience
|
||||
"total_pages": total_pages,
|
||||
"file_name": self.file_name,
|
||||
"file_size": self.file_size,
|
||||
"processing_engine": "mistral-ocr",
|
||||
"content_length": len(cleaned_content),
|
||||
'page': page_index, # 0-based index from API
|
||||
'page_label': page_index + 1, # 1-based label for convenience
|
||||
'total_pages': total_pages,
|
||||
'file_name': self.file_name,
|
||||
'file_size': self.file_size,
|
||||
'processing_engine': 'mistral-ocr',
|
||||
'content_length': len(cleaned_content),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
if skipped_pages > 0:
|
||||
log.info(
|
||||
f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages"
|
||||
)
|
||||
log.info(f'Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages')
|
||||
|
||||
if not documents:
|
||||
# Case where pages existed but none had valid markdown/index
|
||||
log.warning(
|
||||
"OCR response contained pages, but none had valid content/index."
|
||||
)
|
||||
log.warning('OCR response contained pages, but none had valid content/index.')
|
||||
return [
|
||||
Document(
|
||||
page_content="No valid text content found in document",
|
||||
page_content='No valid text content found in document',
|
||||
metadata={
|
||||
"error": "no_valid_pages",
|
||||
"total_pages": total_pages,
|
||||
"file_name": self.file_name,
|
||||
'error': 'no_valid_pages',
|
||||
'total_pages': total_pages,
|
||||
'file_name': self.file_name,
|
||||
},
|
||||
)
|
||||
]
|
||||
@@ -615,24 +575,20 @@ class MistralLoader:
|
||||
documents = self._process_results(ocr_response)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
log.info(
|
||||
f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
|
||||
)
|
||||
log.info(f'Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents')
|
||||
|
||||
return documents
|
||||
|
||||
except Exception as e:
|
||||
total_time = time.time() - start_time
|
||||
log.error(
|
||||
f"An error occurred during the loading process after {total_time:.2f}s: {e}"
|
||||
)
|
||||
log.error(f'An error occurred during the loading process after {total_time:.2f}s: {e}')
|
||||
# Return an error document on failure
|
||||
return [
|
||||
Document(
|
||||
page_content=f"Error during processing: {e}",
|
||||
page_content=f'Error during processing: {e}',
|
||||
metadata={
|
||||
"error": "processing_failed",
|
||||
"file_name": self.file_name,
|
||||
'error': 'processing_failed',
|
||||
'file_name': self.file_name,
|
||||
},
|
||||
)
|
||||
]
|
||||
@@ -643,9 +599,7 @@ class MistralLoader:
|
||||
self._delete_file(file_id)
|
||||
except Exception as del_e:
|
||||
# Log deletion error, but don't overwrite original error if one occurred
|
||||
log.error(
|
||||
f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}"
|
||||
)
|
||||
log.error(f'Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}')
|
||||
|
||||
async def load_async(self) -> List[Document]:
|
||||
"""
|
||||
@@ -672,21 +626,19 @@ class MistralLoader:
|
||||
documents = self._process_results(ocr_response)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
log.info(
|
||||
f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
|
||||
)
|
||||
log.info(f'Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents')
|
||||
|
||||
return documents
|
||||
|
||||
except Exception as e:
|
||||
total_time = time.time() - start_time
|
||||
log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}")
|
||||
log.error(f'Async OCR workflow failed after {total_time:.2f}s: {e}')
|
||||
return [
|
||||
Document(
|
||||
page_content=f"Error during OCR processing: {e}",
|
||||
page_content=f'Error during OCR processing: {e}',
|
||||
metadata={
|
||||
"error": "processing_failed",
|
||||
"file_name": self.file_name,
|
||||
'error': 'processing_failed',
|
||||
'file_name': self.file_name,
|
||||
},
|
||||
)
|
||||
]
|
||||
@@ -697,11 +649,11 @@ class MistralLoader:
|
||||
async with self._get_session() as session:
|
||||
await self._delete_file_async(session, file_id)
|
||||
except Exception as cleanup_error:
|
||||
log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}")
|
||||
log.error(f'Cleanup failed for file ID {file_id}: {cleanup_error}')
|
||||
|
||||
@staticmethod
|
||||
async def load_multiple_async(
|
||||
loaders: List["MistralLoader"],
|
||||
loaders: List['MistralLoader'],
|
||||
max_concurrent: int = 5, # Limit concurrent requests
|
||||
) -> List[List[Document]]:
|
||||
"""
|
||||
@@ -717,15 +669,13 @@ class MistralLoader:
|
||||
if not loaders:
|
||||
return []
|
||||
|
||||
log.info(
|
||||
f"Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent"
|
||||
)
|
||||
log.info(f'Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent')
|
||||
start_time = time.time()
|
||||
|
||||
# Use semaphore to control concurrency
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
|
||||
async def process_with_semaphore(loader: "MistralLoader") -> List[Document]:
|
||||
async def process_with_semaphore(loader: 'MistralLoader') -> List[Document]:
|
||||
async with semaphore:
|
||||
return await loader.load_async()
|
||||
|
||||
@@ -737,14 +687,14 @@ class MistralLoader:
|
||||
processed_results = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
log.error(f"File {i} failed: {result}")
|
||||
log.error(f'File {i} failed: {result}')
|
||||
processed_results.append(
|
||||
[
|
||||
Document(
|
||||
page_content=f"Error processing file: {result}",
|
||||
page_content=f'Error processing file: {result}',
|
||||
metadata={
|
||||
"error": "batch_processing_failed",
|
||||
"file_index": i,
|
||||
'error': 'batch_processing_failed',
|
||||
'file_index': i,
|
||||
},
|
||||
)
|
||||
]
|
||||
@@ -755,15 +705,13 @@ class MistralLoader:
|
||||
# MONITORING: Log comprehensive batch processing statistics
|
||||
total_time = time.time() - start_time
|
||||
total_docs = sum(len(docs) for docs in processed_results)
|
||||
success_count = sum(
|
||||
1 for result in results if not isinstance(result, Exception)
|
||||
)
|
||||
success_count = sum(1 for result in results if not isinstance(result, Exception))
|
||||
failure_count = len(results) - success_count
|
||||
|
||||
log.info(
|
||||
f"Batch processing completed in {total_time:.2f}s: "
|
||||
f"{success_count} files succeeded, {failure_count} files failed, "
|
||||
f"produced {total_docs} total documents"
|
||||
f'Batch processing completed in {total_time:.2f}s: '
|
||||
f'{success_count} files succeeded, {failure_count} files failed, '
|
||||
f'produced {total_docs} total documents'
|
||||
)
|
||||
|
||||
return processed_results
|
||||
|
||||
@@ -25,7 +25,7 @@ class TavilyLoader(BaseLoader):
|
||||
self,
|
||||
urls: Union[str, List[str]],
|
||||
api_key: str,
|
||||
extract_depth: Literal["basic", "advanced"] = "basic",
|
||||
extract_depth: Literal['basic', 'advanced'] = 'basic',
|
||||
continue_on_failure: bool = True,
|
||||
) -> None:
|
||||
"""Initialize Tavily Extract client.
|
||||
@@ -42,13 +42,13 @@ class TavilyLoader(BaseLoader):
|
||||
continue_on_failure: Whether to continue if extraction of a URL fails.
|
||||
"""
|
||||
if not urls:
|
||||
raise ValueError("At least one URL must be provided.")
|
||||
raise ValueError('At least one URL must be provided.')
|
||||
|
||||
self.api_key = api_key
|
||||
self.urls = urls if isinstance(urls, list) else [urls]
|
||||
self.extract_depth = extract_depth
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self.api_url = "https://api.tavily.com/extract"
|
||||
self.api_url = 'https://api.tavily.com/extract'
|
||||
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
"""Extract and yield documents from the URLs using Tavily Extract API."""
|
||||
@@ -57,35 +57,35 @@ class TavilyLoader(BaseLoader):
|
||||
batch_urls = self.urls[i : i + batch_size]
|
||||
try:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}',
|
||||
}
|
||||
# Use string for single URL, array for multiple URLs
|
||||
urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls
|
||||
payload = {"urls": urls_param, "extract_depth": self.extract_depth}
|
||||
payload = {'urls': urls_param, 'extract_depth': self.extract_depth}
|
||||
# Make the API call
|
||||
response = requests.post(self.api_url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
response_data = response.json()
|
||||
# Process successful results
|
||||
for result in response_data.get("results", []):
|
||||
url = result.get("url", "")
|
||||
content = result.get("raw_content", "")
|
||||
for result in response_data.get('results', []):
|
||||
url = result.get('url', '')
|
||||
content = result.get('raw_content', '')
|
||||
if not content:
|
||||
log.warning(f"No content extracted from {url}")
|
||||
log.warning(f'No content extracted from {url}')
|
||||
continue
|
||||
# Add URLs as metadata
|
||||
metadata = {"source": url}
|
||||
metadata = {'source': url}
|
||||
yield Document(
|
||||
page_content=content,
|
||||
metadata=metadata,
|
||||
)
|
||||
for failed in response_data.get("failed_results", []):
|
||||
url = failed.get("url", "")
|
||||
error = failed.get("error", "Unknown error")
|
||||
log.error(f"Failed to extract content from {url}: {error}")
|
||||
for failed in response_data.get('failed_results', []):
|
||||
url = failed.get('url', '')
|
||||
error = failed.get('error', 'Unknown error')
|
||||
log.error(f'Failed to extract content from {url}: {error}')
|
||||
except Exception as e:
|
||||
if self.continue_on_failure:
|
||||
log.error(f"Error extracting content from batch {batch_urls}: {e}")
|
||||
log.error(f'Error extracting content from batch {batch_urls}: {e}')
|
||||
else:
|
||||
raise e
|
||||
|
||||
@@ -7,14 +7,14 @@ from langchain_core.documents import Document
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
ALLOWED_SCHEMES = {"http", "https"}
|
||||
ALLOWED_SCHEMES = {'http', 'https'}
|
||||
ALLOWED_NETLOCS = {
|
||||
"youtu.be",
|
||||
"m.youtube.com",
|
||||
"youtube.com",
|
||||
"www.youtube.com",
|
||||
"www.youtube-nocookie.com",
|
||||
"vid.plus",
|
||||
'youtu.be',
|
||||
'm.youtube.com',
|
||||
'youtube.com',
|
||||
'www.youtube.com',
|
||||
'www.youtube-nocookie.com',
|
||||
'vid.plus',
|
||||
}
|
||||
|
||||
|
||||
@@ -30,17 +30,17 @@ def _parse_video_id(url: str) -> Optional[str]:
|
||||
|
||||
path = parsed_url.path
|
||||
|
||||
if path.endswith("/watch"):
|
||||
if path.endswith('/watch'):
|
||||
query = parsed_url.query
|
||||
parsed_query = parse_qs(query)
|
||||
if "v" in parsed_query:
|
||||
ids = parsed_query["v"]
|
||||
if 'v' in parsed_query:
|
||||
ids = parsed_query['v']
|
||||
video_id = ids if isinstance(ids, str) else ids[0]
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
path = parsed_url.path.lstrip("/")
|
||||
video_id = path.split("/")[-1]
|
||||
path = parsed_url.path.lstrip('/')
|
||||
video_id = path.split('/')[-1]
|
||||
|
||||
if len(video_id) != 11: # Video IDs are 11 characters long
|
||||
return None
|
||||
@@ -54,13 +54,13 @@ class YoutubeLoader:
|
||||
def __init__(
|
||||
self,
|
||||
video_id: str,
|
||||
language: Union[str, Sequence[str]] = "en",
|
||||
language: Union[str, Sequence[str]] = 'en',
|
||||
proxy_url: Optional[str] = None,
|
||||
):
|
||||
"""Initialize with YouTube video ID."""
|
||||
_video_id = _parse_video_id(video_id)
|
||||
self.video_id = _video_id if _video_id is not None else video_id
|
||||
self._metadata = {"source": video_id}
|
||||
self._metadata = {'source': video_id}
|
||||
self.proxy_url = proxy_url
|
||||
|
||||
# Ensure language is a list
|
||||
@@ -70,8 +70,8 @@ class YoutubeLoader:
|
||||
self.language = list(language)
|
||||
|
||||
# Add English as fallback if not already in the list
|
||||
if "en" not in self.language:
|
||||
self.language.append("en")
|
||||
if 'en' not in self.language:
|
||||
self.language.append('en')
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
"""Load YouTube transcripts into `Document` objects."""
|
||||
@@ -85,14 +85,12 @@ class YoutubeLoader:
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
'Could not import "youtube_transcript_api" Python package. '
|
||||
"Please install it with `pip install youtube-transcript-api`."
|
||||
'Please install it with `pip install youtube-transcript-api`.'
|
||||
)
|
||||
|
||||
if self.proxy_url:
|
||||
youtube_proxies = GenericProxyConfig(
|
||||
http_url=self.proxy_url, https_url=self.proxy_url
|
||||
)
|
||||
log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
|
||||
youtube_proxies = GenericProxyConfig(http_url=self.proxy_url, https_url=self.proxy_url)
|
||||
log.debug(f'Using proxy URL: {self.proxy_url[:14]}...')
|
||||
else:
|
||||
youtube_proxies = None
|
||||
|
||||
@@ -100,7 +98,7 @@ class YoutubeLoader:
|
||||
try:
|
||||
transcript_list = transcript_api.list(self.video_id)
|
||||
except Exception as e:
|
||||
log.exception("Loading YouTube transcript failed")
|
||||
log.exception('Loading YouTube transcript failed')
|
||||
return []
|
||||
|
||||
# Try each language in order of priority
|
||||
@@ -110,14 +108,10 @@ class YoutubeLoader:
|
||||
if transcript.is_generated:
|
||||
log.debug(f"Found generated transcript for language '{lang}'")
|
||||
try:
|
||||
transcript = transcript_list.find_manually_created_transcript(
|
||||
[lang]
|
||||
)
|
||||
transcript = transcript_list.find_manually_created_transcript([lang])
|
||||
log.debug(f"Found manual transcript for language '{lang}'")
|
||||
except NoTranscriptFound:
|
||||
log.debug(
|
||||
f"No manual transcript found for language '{lang}', using generated"
|
||||
)
|
||||
log.debug(f"No manual transcript found for language '{lang}', using generated")
|
||||
pass
|
||||
|
||||
log.debug(f"Found transcript for language '{lang}'")
|
||||
@@ -131,12 +125,10 @@ class YoutubeLoader:
|
||||
log.debug(f"Empty transcript for language '{lang}'")
|
||||
continue
|
||||
|
||||
transcript_text = " ".join(
|
||||
transcript_text = ' '.join(
|
||||
map(
|
||||
lambda transcript_piece: (
|
||||
transcript_piece.text.strip(" ")
|
||||
if hasattr(transcript_piece, "text")
|
||||
else ""
|
||||
transcript_piece.text.strip(' ') if hasattr(transcript_piece, 'text') else ''
|
||||
),
|
||||
transcript_pieces,
|
||||
)
|
||||
@@ -150,9 +142,9 @@ class YoutubeLoader:
|
||||
raise e
|
||||
|
||||
# If we get here, all languages failed
|
||||
languages_tried = ", ".join(self.language)
|
||||
languages_tried = ', '.join(self.language)
|
||||
log.warning(
|
||||
f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
|
||||
f'No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed.'
|
||||
)
|
||||
raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))
|
||||
|
||||
|
||||
@@ -13,19 +13,17 @@ log = logging.getLogger(__name__)
|
||||
|
||||
class ColBERT(BaseReranker):
|
||||
def __init__(self, name, **kwargs) -> None:
|
||||
log.info("ColBERT: Loading model", name)
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
log.info('ColBERT: Loading model', name)
|
||||
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
DOCKER = kwargs.get("env") == "docker"
|
||||
DOCKER = kwargs.get('env') == 'docker'
|
||||
if DOCKER:
|
||||
# This is a workaround for the issue with the docker container
|
||||
# where the torch extension is not loaded properly
|
||||
# and the following error is thrown:
|
||||
# /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory
|
||||
|
||||
lock_file = (
|
||||
"/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock"
|
||||
)
|
||||
lock_file = '/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock'
|
||||
if os.path.exists(lock_file):
|
||||
os.remove(lock_file)
|
||||
|
||||
@@ -36,23 +34,16 @@ class ColBERT(BaseReranker):
|
||||
pass
|
||||
|
||||
def calculate_similarity_scores(self, query_embeddings, document_embeddings):
|
||||
|
||||
query_embeddings = query_embeddings.to(self.device)
|
||||
document_embeddings = document_embeddings.to(self.device)
|
||||
|
||||
# Validate dimensions to ensure compatibility
|
||||
if query_embeddings.dim() != 3:
|
||||
raise ValueError(
|
||||
f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
|
||||
)
|
||||
raise ValueError(f'Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}.')
|
||||
if document_embeddings.dim() != 3:
|
||||
raise ValueError(
|
||||
f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
|
||||
)
|
||||
raise ValueError(f'Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}.')
|
||||
if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
|
||||
raise ValueError(
|
||||
"There should be either one query or queries equal to the number of documents."
|
||||
)
|
||||
raise ValueError('There should be either one query or queries equal to the number of documents.')
|
||||
|
||||
# Transpose the query embeddings to align for matrix multiplication
|
||||
transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
|
||||
@@ -69,7 +60,6 @@ class ColBERT(BaseReranker):
|
||||
return normalized_scores.detach().cpu().numpy().astype(np.float32)
|
||||
|
||||
def predict(self, sentences):
|
||||
|
||||
query = sentences[0][0]
|
||||
docs = [i[1] for i in sentences]
|
||||
|
||||
@@ -80,8 +70,6 @@ class ColBERT(BaseReranker):
|
||||
embedded_query = embedded_queries[0]
|
||||
|
||||
# Calculate retrieval scores for the query against all documents
|
||||
scores = self.calculate_similarity_scores(
|
||||
embedded_query.unsqueeze(0), embedded_docs
|
||||
)
|
||||
scores = self.calculate_similarity_scores(embedded_query.unsqueeze(0), embedded_docs)
|
||||
|
||||
return scores
|
||||
|
||||
@@ -15,8 +15,8 @@ class ExternalReranker(BaseReranker):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
url: str = "http://localhost:8080/v1/rerank",
|
||||
model: str = "reranker",
|
||||
url: str = 'http://localhost:8080/v1/rerank',
|
||||
model: str = 'reranker',
|
||||
timeout: Optional[int] = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
@@ -24,33 +24,31 @@ class ExternalReranker(BaseReranker):
|
||||
self.model = model
|
||||
self.timeout = timeout
|
||||
|
||||
def predict(
|
||||
self, sentences: List[Tuple[str, str]], user=None
|
||||
) -> Optional[List[float]]:
|
||||
def predict(self, sentences: List[Tuple[str, str]], user=None) -> Optional[List[float]]:
|
||||
query = sentences[0][0]
|
||||
docs = [i[1] for i in sentences]
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"query": query,
|
||||
"documents": docs,
|
||||
"top_n": len(docs),
|
||||
'model': self.model,
|
||||
'query': query,
|
||||
'documents': docs,
|
||||
'top_n': len(docs),
|
||||
}
|
||||
|
||||
try:
|
||||
log.info(f"ExternalReranker:predict:model {self.model}")
|
||||
log.info(f"ExternalReranker:predict:query {query}")
|
||||
log.info(f'ExternalReranker:predict:model {self.model}')
|
||||
log.info(f'ExternalReranker:predict:query {query}')
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}',
|
||||
}
|
||||
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
||||
headers = include_user_info_headers(headers, user)
|
||||
|
||||
r = requests.post(
|
||||
f"{self.url}",
|
||||
f'{self.url}',
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=self.timeout,
|
||||
@@ -60,13 +58,13 @@ class ExternalReranker(BaseReranker):
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
|
||||
if "results" in data:
|
||||
sorted_results = sorted(data["results"], key=lambda x: x["index"])
|
||||
return [result["relevance_score"] for result in sorted_results]
|
||||
if 'results' in data:
|
||||
sorted_results = sorted(data['results'], key=lambda x: x['index'])
|
||||
return [result['relevance_score'] for result in sorted_results]
|
||||
else:
|
||||
log.error("No results found in external reranking response")
|
||||
log.error('No results found in external reranking response')
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
log.exception(f"Error in external reranking: {e}")
|
||||
log.exception(f'Error in external reranking: {e}')
|
||||
return None
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -31,17 +31,15 @@ log = logging.getLogger(__name__)
|
||||
class ChromaClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
settings_dict = {
|
||||
"allow_reset": True,
|
||||
"anonymized_telemetry": False,
|
||||
'allow_reset': True,
|
||||
'anonymized_telemetry': False,
|
||||
}
|
||||
if CHROMA_CLIENT_AUTH_PROVIDER is not None:
|
||||
settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
|
||||
settings_dict['chroma_client_auth_provider'] = CHROMA_CLIENT_AUTH_PROVIDER
|
||||
if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
|
||||
settings_dict["chroma_client_auth_credentials"] = (
|
||||
CHROMA_CLIENT_AUTH_CREDENTIALS
|
||||
)
|
||||
settings_dict['chroma_client_auth_credentials'] = CHROMA_CLIENT_AUTH_CREDENTIALS
|
||||
|
||||
if CHROMA_HTTP_HOST != "":
|
||||
if CHROMA_HTTP_HOST != '':
|
||||
self.client = chromadb.HttpClient(
|
||||
host=CHROMA_HTTP_HOST,
|
||||
port=CHROMA_HTTP_PORT,
|
||||
@@ -87,25 +85,23 @@ class ChromaClient(VectorDBBase):
|
||||
|
||||
# chromadb has cosine distance, 2 (worst) -> 0 (best). Re-ordering to 0 -> 1
|
||||
# https://docs.trychroma.com/docs/collections/configure cosine equation
|
||||
distances: list = result["distances"][0]
|
||||
distances: list = result['distances'][0]
|
||||
distances = [2 - dist for dist in distances]
|
||||
distances = [[dist / 2 for dist in distances]]
|
||||
|
||||
return SearchResult(
|
||||
**{
|
||||
"ids": result["ids"],
|
||||
"distances": distances,
|
||||
"documents": result["documents"],
|
||||
"metadatas": result["metadatas"],
|
||||
'ids': result['ids'],
|
||||
'distances': distances,
|
||||
'documents': result['documents'],
|
||||
'metadatas': result['metadatas'],
|
||||
}
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: dict, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
# Query the items from the collection based on the filter.
|
||||
try:
|
||||
collection = self.client.get_collection(name=collection_name)
|
||||
@@ -117,9 +113,9 @@ class ChromaClient(VectorDBBase):
|
||||
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [result["ids"]],
|
||||
"documents": [result["documents"]],
|
||||
"metadatas": [result["metadatas"]],
|
||||
'ids': [result['ids']],
|
||||
'documents': [result['documents']],
|
||||
'metadatas': [result['metadatas']],
|
||||
}
|
||||
)
|
||||
return None
|
||||
@@ -133,23 +129,21 @@ class ChromaClient(VectorDBBase):
|
||||
result = collection.get()
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": [result["ids"]],
|
||||
"documents": [result["documents"]],
|
||||
"metadatas": [result["metadatas"]],
|
||||
'ids': [result['ids']],
|
||||
'documents': [result['documents']],
|
||||
'metadatas': [result['metadatas']],
|
||||
}
|
||||
)
|
||||
return None
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Insert the items into the collection, if the collection does not exist, it will be created.
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=collection_name, metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
collection = self.client.get_or_create_collection(name=collection_name, metadata={'hnsw:space': 'cosine'})
|
||||
|
||||
ids = [item["id"] for item in items]
|
||||
documents = [item["text"] for item in items]
|
||||
embeddings = [item["vector"] for item in items]
|
||||
metadatas = [process_metadata(item["metadata"]) for item in items]
|
||||
ids = [item['id'] for item in items]
|
||||
documents = [item['text'] for item in items]
|
||||
embeddings = [item['vector'] for item in items]
|
||||
metadatas = [process_metadata(item['metadata']) for item in items]
|
||||
|
||||
for batch in create_batches(
|
||||
api=self.client,
|
||||
@@ -162,18 +156,14 @@ class ChromaClient(VectorDBBase):
|
||||
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=collection_name, metadata={"hnsw:space": "cosine"}
|
||||
)
|
||||
collection = self.client.get_or_create_collection(name=collection_name, metadata={'hnsw:space': 'cosine'})
|
||||
|
||||
ids = [item["id"] for item in items]
|
||||
documents = [item["text"] for item in items]
|
||||
embeddings = [item["vector"] for item in items]
|
||||
metadatas = [process_metadata(item["metadata"]) for item in items]
|
||||
ids = [item['id'] for item in items]
|
||||
documents = [item['text'] for item in items]
|
||||
embeddings = [item['vector'] for item in items]
|
||||
metadatas = [process_metadata(item['metadata']) for item in items]
|
||||
|
||||
collection.upsert(
|
||||
ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
|
||||
)
|
||||
collection.upsert(ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
@@ -191,9 +181,7 @@ class ChromaClient(VectorDBBase):
|
||||
collection.delete(where=filter)
|
||||
except Exception as e:
|
||||
# If collection doesn't exist, that's fine - nothing to delete
|
||||
log.debug(
|
||||
f"Attempted to delete from non-existent collection {collection_name}. Ignoring."
|
||||
)
|
||||
log.debug(f'Attempted to delete from non-existent collection {collection_name}. Ignoring.')
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
|
||||
@@ -51,7 +51,7 @@ class ElasticsearchClient(VectorDBBase):
|
||||
|
||||
# Status: works
|
||||
def _get_index_name(self, dimension: int) -> str:
|
||||
return f"{self.index_prefix}_d{str(dimension)}"
|
||||
return f'{self.index_prefix}_d{str(dimension)}'
|
||||
|
||||
# Status: works
|
||||
def _scan_result_to_get_result(self, result) -> GetResult:
|
||||
@@ -62,24 +62,24 @@ class ElasticsearchClient(VectorDBBase):
|
||||
metadatas = []
|
||||
|
||||
for hit in result:
|
||||
ids.append(hit["_id"])
|
||||
documents.append(hit["_source"].get("text"))
|
||||
metadatas.append(hit["_source"].get("metadata"))
|
||||
ids.append(hit['_id'])
|
||||
documents.append(hit['_source'].get('text'))
|
||||
metadatas.append(hit['_source'].get('metadata'))
|
||||
|
||||
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
||||
|
||||
# Status: works
|
||||
def _result_to_get_result(self, result) -> GetResult:
|
||||
if not result["hits"]["hits"]:
|
||||
if not result['hits']['hits']:
|
||||
return None
|
||||
ids = []
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for hit in result["hits"]["hits"]:
|
||||
ids.append(hit["_id"])
|
||||
documents.append(hit["_source"].get("text"))
|
||||
metadatas.append(hit["_source"].get("metadata"))
|
||||
for hit in result['hits']['hits']:
|
||||
ids.append(hit['_id'])
|
||||
documents.append(hit['_source'].get('text'))
|
||||
metadatas.append(hit['_source'].get('metadata'))
|
||||
|
||||
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
||||
|
||||
@@ -90,11 +90,11 @@ class ElasticsearchClient(VectorDBBase):
|
||||
documents = []
|
||||
metadatas = []
|
||||
|
||||
for hit in result["hits"]["hits"]:
|
||||
ids.append(hit["_id"])
|
||||
distances.append(hit["_score"])
|
||||
documents.append(hit["_source"].get("text"))
|
||||
metadatas.append(hit["_source"].get("metadata"))
|
||||
for hit in result['hits']['hits']:
|
||||
ids.append(hit['_id'])
|
||||
distances.append(hit['_score'])
|
||||
documents.append(hit['_source'].get('text'))
|
||||
metadatas.append(hit['_source'].get('metadata'))
|
||||
|
||||
return SearchResult(
|
||||
ids=[ids],
|
||||
@@ -106,26 +106,26 @@ class ElasticsearchClient(VectorDBBase):
|
||||
# Status: works
|
||||
def _create_index(self, dimension: int):
|
||||
body = {
|
||||
"mappings": {
|
||||
"dynamic_templates": [
|
||||
'mappings': {
|
||||
'dynamic_templates': [
|
||||
{
|
||||
"strings": {
|
||||
"match_mapping_type": "string",
|
||||
"mapping": {"type": "keyword"},
|
||||
'strings': {
|
||||
'match_mapping_type': 'string',
|
||||
'mapping': {'type': 'keyword'},
|
||||
}
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"collection": {"type": "keyword"},
|
||||
"id": {"type": "keyword"},
|
||||
"vector": {
|
||||
"type": "dense_vector",
|
||||
"dims": dimension, # Adjust based on your vector dimensions
|
||||
"index": True,
|
||||
"similarity": "cosine",
|
||||
'properties': {
|
||||
'collection': {'type': 'keyword'},
|
||||
'id': {'type': 'keyword'},
|
||||
'vector': {
|
||||
'type': 'dense_vector',
|
||||
'dims': dimension, # Adjust based on your vector dimensions
|
||||
'index': True,
|
||||
'similarity': 'cosine',
|
||||
},
|
||||
"text": {"type": "text"},
|
||||
"metadata": {"type": "object"},
|
||||
'text': {'type': 'text'},
|
||||
'metadata': {'type': 'object'},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -139,21 +139,19 @@ class ElasticsearchClient(VectorDBBase):
|
||||
|
||||
# Status: works
|
||||
def has_collection(self, collection_name) -> bool:
|
||||
query_body = {"query": {"bool": {"filter": []}}}
|
||||
query_body["query"]["bool"]["filter"].append(
|
||||
{"term": {"collection": collection_name}}
|
||||
)
|
||||
query_body = {'query': {'bool': {'filter': []}}}
|
||||
query_body['query']['bool']['filter'].append({'term': {'collection': collection_name}})
|
||||
|
||||
try:
|
||||
result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
|
||||
result = self.client.count(index=f'{self.index_prefix}*', body=query_body)
|
||||
|
||||
return result.body["count"] > 0
|
||||
return result.body['count'] > 0
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def delete_collection(self, collection_name: str):
|
||||
query = {"query": {"term": {"collection": collection_name}}}
|
||||
self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
|
||||
query = {'query': {'term': {'collection': collection_name}}}
|
||||
self.client.delete_by_query(index=f'{self.index_prefix}*', body=query)
|
||||
|
||||
# Status: works
|
||||
def search(
|
||||
@@ -164,51 +162,41 @@ class ElasticsearchClient(VectorDBBase):
|
||||
limit: int = 10,
|
||||
) -> Optional[SearchResult]:
|
||||
query = {
|
||||
"size": limit,
|
||||
"_source": ["text", "metadata"],
|
||||
"query": {
|
||||
"script_score": {
|
||||
"query": {
|
||||
"bool": {"filter": [{"term": {"collection": collection_name}}]}
|
||||
},
|
||||
"script": {
|
||||
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
|
||||
"params": {
|
||||
"vector": vectors[0]
|
||||
}, # Assuming single query vector
|
||||
'size': limit,
|
||||
'_source': ['text', 'metadata'],
|
||||
'query': {
|
||||
'script_score': {
|
||||
'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}},
|
||||
'script': {
|
||||
'source': "cosineSimilarity(params.vector, 'vector') + 1.0",
|
||||
'params': {'vector': vectors[0]}, # Assuming single query vector
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result = self.client.search(
|
||||
index=self._get_index_name(len(vectors[0])), body=query
|
||||
)
|
||||
result = self.client.search(index=self._get_index_name(len(vectors[0])), body=query)
|
||||
|
||||
return self._result_to_search_result(result)
|
||||
|
||||
# Status: only tested halfway
|
||||
def query(
|
||||
self, collection_name: str, filter: dict, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
if not self.has_collection(collection_name):
|
||||
return None
|
||||
|
||||
query_body = {
|
||||
"query": {"bool": {"filter": []}},
|
||||
"_source": ["text", "metadata"],
|
||||
'query': {'bool': {'filter': []}},
|
||||
'_source': ['text', 'metadata'],
|
||||
}
|
||||
|
||||
for field, value in filter.items():
|
||||
query_body["query"]["bool"]["filter"].append({"term": {field: value}})
|
||||
query_body["query"]["bool"]["filter"].append(
|
||||
{"term": {"collection": collection_name}}
|
||||
)
|
||||
query_body['query']['bool']['filter'].append({'term': {field: value}})
|
||||
query_body['query']['bool']['filter'].append({'term': {'collection': collection_name}})
|
||||
size = limit if limit else 10
|
||||
|
||||
try:
|
||||
result = self.client.search(
|
||||
index=f"{self.index_prefix}*",
|
||||
index=f'{self.index_prefix}*',
|
||||
body=query_body,
|
||||
size=size,
|
||||
)
|
||||
@@ -220,9 +208,7 @@ class ElasticsearchClient(VectorDBBase):
|
||||
|
||||
# Status: works
|
||||
def _has_index(self, dimension: int):
|
||||
return self.client.indices.exists(
|
||||
index=self._get_index_name(dimension=dimension)
|
||||
)
|
||||
return self.client.indices.exists(index=self._get_index_name(dimension=dimension))
|
||||
|
||||
def get_or_create_index(self, dimension: int):
|
||||
if not self._has_index(dimension=dimension):
|
||||
@@ -232,28 +218,28 @@ class ElasticsearchClient(VectorDBBase):
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
# Get all the items in the collection.
|
||||
query = {
|
||||
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
|
||||
"_source": ["text", "metadata"],
|
||||
'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}},
|
||||
'_source': ['text', 'metadata'],
|
||||
}
|
||||
results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
|
||||
results = list(scan(self.client, index=f'{self.index_prefix}*', query=query))
|
||||
|
||||
return self._scan_result_to_get_result(results)
|
||||
|
||||
# Status: works
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
if not self._has_index(dimension=len(items[0]["vector"])):
|
||||
self._create_index(dimension=len(items[0]["vector"]))
|
||||
if not self._has_index(dimension=len(items[0]['vector'])):
|
||||
self._create_index(dimension=len(items[0]['vector']))
|
||||
|
||||
for batch in self._create_batches(items):
|
||||
actions = [
|
||||
{
|
||||
"_index": self._get_index_name(dimension=len(items[0]["vector"])),
|
||||
"_id": item["id"],
|
||||
"_source": {
|
||||
"collection": collection_name,
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
"metadata": process_metadata(item["metadata"]),
|
||||
'_index': self._get_index_name(dimension=len(items[0]['vector'])),
|
||||
'_id': item['id'],
|
||||
'_source': {
|
||||
'collection': collection_name,
|
||||
'vector': item['vector'],
|
||||
'text': item['text'],
|
||||
'metadata': process_metadata(item['metadata']),
|
||||
},
|
||||
}
|
||||
for item in batch
|
||||
@@ -262,21 +248,21 @@ class ElasticsearchClient(VectorDBBase):
|
||||
|
||||
# Upsert documents using the update API with doc_as_upsert=True.
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
if not self._has_index(dimension=len(items[0]["vector"])):
|
||||
self._create_index(dimension=len(items[0]["vector"]))
|
||||
if not self._has_index(dimension=len(items[0]['vector'])):
|
||||
self._create_index(dimension=len(items[0]['vector']))
|
||||
for batch in self._create_batches(items):
|
||||
actions = [
|
||||
{
|
||||
"_op_type": "update",
|
||||
"_index": self._get_index_name(dimension=len(item["vector"])),
|
||||
"_id": item["id"],
|
||||
"doc": {
|
||||
"collection": collection_name,
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
"metadata": process_metadata(item["metadata"]),
|
||||
'_op_type': 'update',
|
||||
'_index': self._get_index_name(dimension=len(item['vector'])),
|
||||
'_id': item['id'],
|
||||
'doc': {
|
||||
'collection': collection_name,
|
||||
'vector': item['vector'],
|
||||
'text': item['text'],
|
||||
'metadata': process_metadata(item['metadata']),
|
||||
},
|
||||
"doc_as_upsert": True,
|
||||
'doc_as_upsert': True,
|
||||
}
|
||||
for item in batch
|
||||
]
|
||||
@@ -289,22 +275,17 @@ class ElasticsearchClient(VectorDBBase):
|
||||
ids: Optional[list[str]] = None,
|
||||
filter: Optional[dict] = None,
|
||||
):
|
||||
|
||||
query = {
|
||||
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}
|
||||
}
|
||||
query = {'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}}}
|
||||
# logic based on chromaDB
|
||||
if ids:
|
||||
query["query"]["bool"]["filter"].append({"terms": {"_id": ids}})
|
||||
query['query']['bool']['filter'].append({'terms': {'_id': ids}})
|
||||
elif filter:
|
||||
for field, value in filter.items():
|
||||
query["query"]["bool"]["filter"].append(
|
||||
{"term": {f"metadata.{field}": value}}
|
||||
)
|
||||
query['query']['bool']['filter'].append({'term': {f'metadata.{field}': value}})
|
||||
|
||||
self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
|
||||
self.client.delete_by_query(index=f'{self.index_prefix}*', body=query)
|
||||
|
||||
def reset(self):
|
||||
indices = self.client.indices.get(index=f"{self.index_prefix}*")
|
||||
indices = self.client.indices.get(index=f'{self.index_prefix}*')
|
||||
for index in indices:
|
||||
self.client.indices.delete(index=index)
|
||||
|
||||
@@ -47,8 +47,8 @@ def _embedding_to_f32_bytes(vec: List[float]) -> bytes:
|
||||
byte sequence. We use array('f') to avoid a numpy dependency and byteswap on
|
||||
big-endian platforms for portability.
|
||||
"""
|
||||
a = array.array("f", [float(x) for x in vec]) # float32
|
||||
if sys.byteorder != "little":
|
||||
a = array.array('f', [float(x) for x in vec]) # float32
|
||||
if sys.byteorder != 'little':
|
||||
a.byteswap()
|
||||
return a.tobytes()
|
||||
|
||||
@@ -68,7 +68,7 @@ def _safe_json(v: Any) -> Dict[str, Any]:
|
||||
return v
|
||||
if isinstance(v, (bytes, bytearray)):
|
||||
try:
|
||||
v = v.decode("utf-8")
|
||||
v = v.decode('utf-8')
|
||||
except Exception:
|
||||
return {}
|
||||
if isinstance(v, str):
|
||||
@@ -105,16 +105,16 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
"""
|
||||
self.db_url = (db_url or MARIADB_VECTOR_DB_URL).strip()
|
||||
self.vector_length = int(vector_length)
|
||||
self.distance_strategy = (distance_strategy or "cosine").strip().lower()
|
||||
self.distance_strategy = (distance_strategy or 'cosine').strip().lower()
|
||||
self.index_m = int(index_m)
|
||||
|
||||
if self.distance_strategy not in {"cosine", "euclidean"}:
|
||||
if self.distance_strategy not in {'cosine', 'euclidean'}:
|
||||
raise ValueError("distance_strategy must be 'cosine' or 'euclidean'")
|
||||
|
||||
if not self.db_url.lower().startswith("mariadb+mariadbconnector://"):
|
||||
if not self.db_url.lower().startswith('mariadb+mariadbconnector://'):
|
||||
raise ValueError(
|
||||
"MariaDBVectorClient requires mariadb+mariadbconnector:// (official MariaDB driver) "
|
||||
"to ensure qmark paramstyle and correct VECTOR binding."
|
||||
'MariaDBVectorClient requires mariadb+mariadbconnector:// (official MariaDB driver) '
|
||||
'to ensure qmark paramstyle and correct VECTOR binding.'
|
||||
)
|
||||
|
||||
if isinstance(MARIADB_VECTOR_POOL_SIZE, int):
|
||||
@@ -129,9 +129,7 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
poolclass=QueuePool,
|
||||
)
|
||||
else:
|
||||
self.engine = create_engine(
|
||||
self.db_url, pool_pre_ping=True, poolclass=NullPool
|
||||
)
|
||||
self.engine = create_engine(self.db_url, pool_pre_ping=True, poolclass=NullPool)
|
||||
else:
|
||||
self.engine = create_engine(self.db_url, pool_pre_ping=True)
|
||||
self._init_schema()
|
||||
@@ -185,7 +183,7 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
log.exception(f"Error during database initialization: {e}")
|
||||
log.exception(f'Error during database initialization: {e}')
|
||||
raise
|
||||
|
||||
def _check_vector_length(self) -> None:
|
||||
@@ -197,19 +195,19 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
"""
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SHOW CREATE TABLE document_chunk")
|
||||
cur.execute('SHOW CREATE TABLE document_chunk')
|
||||
row = cur.fetchone()
|
||||
if not row or len(row) < 2:
|
||||
return
|
||||
ddl = row[1]
|
||||
m = re.search(r"vector\\((\\d+)\\)", ddl, flags=re.IGNORECASE)
|
||||
m = re.search(r'vector\\((\\d+)\\)', ddl, flags=re.IGNORECASE)
|
||||
if not m:
|
||||
return
|
||||
existing = int(m.group(1))
|
||||
if existing != int(self.vector_length):
|
||||
raise Exception(
|
||||
f"VECTOR_LENGTH {self.vector_length} does not match existing vector column dimension {existing}. "
|
||||
"Cannot change vector size after initialization without migrating the data."
|
||||
f'VECTOR_LENGTH {self.vector_length} does not match existing vector column dimension {existing}. '
|
||||
'Cannot change vector size after initialization without migrating the data.'
|
||||
)
|
||||
|
||||
def adjust_vector_length(self, vector: List[float]) -> List[float]:
|
||||
@@ -227,11 +225,7 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
"""
|
||||
Return the MariaDB Vector distance function name for the configured strategy.
|
||||
"""
|
||||
return (
|
||||
"vec_distance_cosine"
|
||||
if self.distance_strategy == "cosine"
|
||||
else "vec_distance_euclidean"
|
||||
)
|
||||
return 'vec_distance_cosine' if self.distance_strategy == 'cosine' else 'vec_distance_euclidean'
|
||||
|
||||
def _score_from_dist(self, dist: float) -> float:
|
||||
"""
|
||||
@@ -240,7 +234,7 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
- cosine: score ~= 1 - cosine_distance, clamped to [0, 1]
|
||||
- euclidean: score = 1 / (1 + dist)
|
||||
"""
|
||||
if self.distance_strategy == "cosine":
|
||||
if self.distance_strategy == 'cosine':
|
||||
score = 1.0 - dist
|
||||
if score < 0.0:
|
||||
score = 0.0
|
||||
@@ -260,48 +254,48 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
- {"$or": [ ... ]}
|
||||
"""
|
||||
if not expr or not isinstance(expr, dict):
|
||||
return "", []
|
||||
return '', []
|
||||
|
||||
if "$and" in expr:
|
||||
if '$and' in expr:
|
||||
parts: List[str] = []
|
||||
params: List[Any] = []
|
||||
for e in expr.get("$and") or []:
|
||||
for e in expr.get('$and') or []:
|
||||
s, p = self._build_filter_sql_qmark(e)
|
||||
if s:
|
||||
parts.append(s)
|
||||
params.extend(p)
|
||||
return ("(" + " AND ".join(parts) + ")") if parts else "", params
|
||||
return ('(' + ' AND '.join(parts) + ')') if parts else '', params
|
||||
|
||||
if "$or" in expr:
|
||||
if '$or' in expr:
|
||||
parts: List[str] = []
|
||||
params: List[Any] = []
|
||||
for e in expr.get("$or") or []:
|
||||
for e in expr.get('$or') or []:
|
||||
s, p = self._build_filter_sql_qmark(e)
|
||||
if s:
|
||||
parts.append(s)
|
||||
params.extend(p)
|
||||
return ("(" + " OR ".join(parts) + ")") if parts else "", params
|
||||
return ('(' + ' OR '.join(parts) + ')') if parts else '', params
|
||||
|
||||
clauses: List[str] = []
|
||||
params: List[Any] = []
|
||||
for key, value in expr.items():
|
||||
if key.startswith("$"):
|
||||
if key.startswith('$'):
|
||||
continue
|
||||
json_expr = f"JSON_UNQUOTE(JSON_EXTRACT(vmetadata, '$.{key}'))"
|
||||
if isinstance(value, dict) and "$in" in value:
|
||||
vals = [str(v) for v in (value.get("$in") or [])]
|
||||
if isinstance(value, dict) and '$in' in value:
|
||||
vals = [str(v) for v in (value.get('$in') or [])]
|
||||
if not vals:
|
||||
clauses.append("0=1")
|
||||
clauses.append('0=1')
|
||||
continue
|
||||
ors = []
|
||||
for v in vals:
|
||||
ors.append(f"{json_expr} = ?")
|
||||
ors.append(f'{json_expr} = ?')
|
||||
params.append(v)
|
||||
clauses.append("(" + " OR ".join(ors) + ")")
|
||||
clauses.append('(' + ' OR '.join(ors) + ')')
|
||||
else:
|
||||
clauses.append(f"{json_expr} = ?")
|
||||
clauses.append(f'{json_expr} = ?')
|
||||
params.append(str(value))
|
||||
return ("(" + " AND ".join(clauses) + ")") if clauses else "", params
|
||||
return ('(' + ' AND '.join(clauses) + ')') if clauses else '', params
|
||||
|
||||
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
"""
|
||||
@@ -322,15 +316,15 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
"""
|
||||
params: List[Tuple[Any, ...]] = []
|
||||
for item in items:
|
||||
v = self.adjust_vector_length(item["vector"])
|
||||
v = self.adjust_vector_length(item['vector'])
|
||||
emb = _embedding_to_f32_bytes(v)
|
||||
meta = process_metadata(item.get("metadata") or {})
|
||||
meta = process_metadata(item.get('metadata') or {})
|
||||
params.append(
|
||||
(
|
||||
item["id"],
|
||||
item['id'],
|
||||
emb,
|
||||
collection_name,
|
||||
item.get("text"),
|
||||
item.get('text'),
|
||||
json.dumps(meta),
|
||||
)
|
||||
)
|
||||
@@ -338,7 +332,7 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
log.exception(f"Error during insert: {e}")
|
||||
log.exception(f'Error during insert: {e}')
|
||||
raise
|
||||
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||
@@ -365,15 +359,15 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
"""
|
||||
params: List[Tuple[Any, ...]] = []
|
||||
for item in items:
|
||||
v = self.adjust_vector_length(item["vector"])
|
||||
v = self.adjust_vector_length(item['vector'])
|
||||
emb = _embedding_to_f32_bytes(v)
|
||||
meta = process_metadata(item.get("metadata") or {})
|
||||
meta = process_metadata(item.get('metadata') or {})
|
||||
params.append(
|
||||
(
|
||||
item["id"],
|
||||
item['id'],
|
||||
emb,
|
||||
collection_name,
|
||||
item.get("text"),
|
||||
item.get('text'),
|
||||
json.dumps(meta),
|
||||
)
|
||||
)
|
||||
@@ -381,7 +375,7 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
log.exception(f"Error during upsert: {e}")
|
||||
log.exception(f'Error during upsert: {e}')
|
||||
raise
|
||||
|
||||
def search(
|
||||
@@ -415,10 +409,10 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
fsql, fparams = self._build_filter_sql_qmark(filter or {})
|
||||
where = "collection_name = ?"
|
||||
where = 'collection_name = ?'
|
||||
base_params: List[Any] = [collection_name]
|
||||
if fsql:
|
||||
where = where + " AND " + fsql
|
||||
where = where + ' AND ' + fsql
|
||||
base_params.extend(fparams)
|
||||
|
||||
sql = f"""
|
||||
@@ -460,26 +454,24 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
metadatas=metadatas,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"[MARIADB_VECTOR] search() failed: {e}")
|
||||
log.exception(f'[MARIADB_VECTOR] search() failed: {e}')
|
||||
return None
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
"""
|
||||
Retrieve documents by metadata filter (non-vector query).
|
||||
"""
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
fsql, fparams = self._build_filter_sql_qmark(filter or {})
|
||||
where = "collection_name = ?"
|
||||
where = 'collection_name = ?'
|
||||
params: List[Any] = [collection_name]
|
||||
if fsql:
|
||||
where = where + " AND " + fsql
|
||||
where = where + ' AND ' + fsql
|
||||
params.extend(fparams)
|
||||
sql = f"SELECT id, text, vmetadata FROM document_chunk WHERE {where}"
|
||||
sql = f'SELECT id, text, vmetadata FROM document_chunk WHERE {where}'
|
||||
if limit is not None:
|
||||
sql += " LIMIT ?"
|
||||
sql += ' LIMIT ?'
|
||||
params.append(int(limit))
|
||||
cur.execute(sql, params)
|
||||
rows = cur.fetchall()
|
||||
@@ -490,18 +482,16 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
metadatas = [[_safe_json(r[2]) for r in rows]]
|
||||
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
||||
|
||||
def get(
|
||||
self, collection_name: str, limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
def get(self, collection_name: str, limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
"""
|
||||
Retrieve documents in a collection without filtering (optionally limited).
|
||||
"""
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
sql = "SELECT id, text, vmetadata FROM document_chunk WHERE collection_name = ?"
|
||||
sql = 'SELECT id, text, vmetadata FROM document_chunk WHERE collection_name = ?'
|
||||
params: List[Any] = [collection_name]
|
||||
if limit is not None:
|
||||
sql += " LIMIT ?"
|
||||
sql += ' LIMIT ?'
|
||||
params.append(int(limit))
|
||||
cur.execute(sql, params)
|
||||
rows = cur.fetchall()
|
||||
@@ -526,12 +516,12 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
try:
|
||||
where = ["collection_name = ?"]
|
||||
where = ['collection_name = ?']
|
||||
params: List[Any] = [collection_name]
|
||||
|
||||
if ids:
|
||||
ph = ", ".join(["?"] * len(ids))
|
||||
where.append(f"id IN ({ph})")
|
||||
ph = ', '.join(['?'] * len(ids))
|
||||
where.append(f'id IN ({ph})')
|
||||
params.extend(ids)
|
||||
|
||||
if filter:
|
||||
@@ -540,12 +530,12 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
where.append(fsql)
|
||||
params.extend(fparams)
|
||||
|
||||
sql = "DELETE FROM document_chunk WHERE " + " AND ".join(where)
|
||||
sql = 'DELETE FROM document_chunk WHERE ' + ' AND '.join(where)
|
||||
cur.execute(sql, params)
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
log.exception(f"Error during delete: {e}")
|
||||
log.exception(f'Error during delete: {e}')
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
@@ -555,11 +545,11 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
try:
|
||||
cur.execute("TRUNCATE TABLE document_chunk")
|
||||
cur.execute('TRUNCATE TABLE document_chunk')
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
log.exception(f"Error during reset: {e}")
|
||||
log.exception(f'Error during reset: {e}')
|
||||
raise
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
@@ -570,7 +560,7 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
with self._connect() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT 1 FROM document_chunk WHERE collection_name = ? LIMIT 1",
|
||||
'SELECT 1 FROM document_chunk WHERE collection_name = ? LIMIT 1',
|
||||
(collection_name,),
|
||||
)
|
||||
return cur.fetchone() is not None
|
||||
@@ -590,4 +580,4 @@ class MariaDBVectorClient(VectorDBBase):
|
||||
try:
|
||||
self.engine.dispose()
|
||||
except Exception as e:
|
||||
log.exception(f"Error during dispose the underlying SQLAlchemy engine: {e}")
|
||||
log.exception(f'Error during dispose the underlying SQLAlchemy engine: {e}')
|
||||
|
||||
@@ -35,7 +35,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
class MilvusClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
self.collection_prefix = "open_webui"
|
||||
self.collection_prefix = 'open_webui'
|
||||
if MILVUS_TOKEN is None:
|
||||
self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
|
||||
else:
|
||||
@@ -50,17 +50,17 @@ class MilvusClient(VectorDBBase):
|
||||
_documents = []
|
||||
_metadatas = []
|
||||
for item in match:
|
||||
_ids.append(item.get("id"))
|
||||
_documents.append(item.get("data", {}).get("text"))
|
||||
_metadatas.append(item.get("metadata"))
|
||||
_ids.append(item.get('id'))
|
||||
_documents.append(item.get('data', {}).get('text'))
|
||||
_metadatas.append(item.get('metadata'))
|
||||
ids.append(_ids)
|
||||
documents.append(_documents)
|
||||
metadatas.append(_metadatas)
|
||||
return GetResult(
|
||||
**{
|
||||
"ids": ids,
|
||||
"documents": documents,
|
||||
"metadatas": metadatas,
|
||||
'ids': ids,
|
||||
'documents': documents,
|
||||
'metadatas': metadatas,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -75,23 +75,23 @@ class MilvusClient(VectorDBBase):
|
||||
_documents = []
|
||||
_metadatas = []
|
||||
for item in match:
|
||||
_ids.append(item.get("id"))
|
||||
_ids.append(item.get('id'))
|
||||
# normalize milvus score from [-1, 1] to [0, 1] range
|
||||
# https://milvus.io/docs/de/metric.md
|
||||
_dist = (item.get("distance") + 1.0) / 2.0
|
||||
_dist = (item.get('distance') + 1.0) / 2.0
|
||||
_distances.append(_dist)
|
||||
_documents.append(item.get("entity", {}).get("data", {}).get("text"))
|
||||
_metadatas.append(item.get("entity", {}).get("metadata"))
|
||||
_documents.append(item.get('entity', {}).get('data', {}).get('text'))
|
||||
_metadatas.append(item.get('entity', {}).get('metadata'))
|
||||
ids.append(_ids)
|
||||
distances.append(_distances)
|
||||
documents.append(_documents)
|
||||
metadatas.append(_metadatas)
|
||||
return SearchResult(
|
||||
**{
|
||||
"ids": ids,
|
||||
"distances": distances,
|
||||
"documents": documents,
|
||||
"metadatas": metadatas,
|
||||
'ids': ids,
|
||||
'distances': distances,
|
||||
'documents': documents,
|
||||
'metadatas': metadatas,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -101,21 +101,19 @@ class MilvusClient(VectorDBBase):
|
||||
enable_dynamic_field=True,
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="id",
|
||||
field_name='id',
|
||||
datatype=DataType.VARCHAR,
|
||||
is_primary=True,
|
||||
max_length=65535,
|
||||
)
|
||||
schema.add_field(
|
||||
field_name="vector",
|
||||
field_name='vector',
|
||||
datatype=DataType.FLOAT_VECTOR,
|
||||
dim=dimension,
|
||||
description="vector",
|
||||
)
|
||||
schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
|
||||
schema.add_field(
|
||||
field_name="metadata", datatype=DataType.JSON, description="metadata"
|
||||
description='vector',
|
||||
)
|
||||
schema.add_field(field_name='data', datatype=DataType.JSON, description='data')
|
||||
schema.add_field(field_name='metadata', datatype=DataType.JSON, description='metadata')
|
||||
|
||||
index_params = self.client.prepare_index_params()
|
||||
|
||||
@@ -123,44 +121,44 @@ class MilvusClient(VectorDBBase):
|
||||
index_type = MILVUS_INDEX_TYPE.upper()
|
||||
metric_type = MILVUS_METRIC_TYPE.upper()
|
||||
|
||||
log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}")
|
||||
log.info(f'Using Milvus index type: {index_type}, metric type: {metric_type}')
|
||||
|
||||
index_creation_params = {}
|
||||
if index_type == "HNSW":
|
||||
if index_type == 'HNSW':
|
||||
index_creation_params = {
|
||||
"M": MILVUS_HNSW_M,
|
||||
"efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
|
||||
'M': MILVUS_HNSW_M,
|
||||
'efConstruction': MILVUS_HNSW_EFCONSTRUCTION,
|
||||
}
|
||||
log.info(f"HNSW params: {index_creation_params}")
|
||||
elif index_type == "IVF_FLAT":
|
||||
index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
|
||||
log.info(f"IVF_FLAT params: {index_creation_params}")
|
||||
elif index_type == "DISKANN":
|
||||
log.info(f'HNSW params: {index_creation_params}')
|
||||
elif index_type == 'IVF_FLAT':
|
||||
index_creation_params = {'nlist': MILVUS_IVF_FLAT_NLIST}
|
||||
log.info(f'IVF_FLAT params: {index_creation_params}')
|
||||
elif index_type == 'DISKANN':
|
||||
index_creation_params = {
|
||||
"max_degree": MILVUS_DISKANN_MAX_DEGREE,
|
||||
"search_list_size": MILVUS_DISKANN_SEARCH_LIST_SIZE,
|
||||
'max_degree': MILVUS_DISKANN_MAX_DEGREE,
|
||||
'search_list_size': MILVUS_DISKANN_SEARCH_LIST_SIZE,
|
||||
}
|
||||
log.info(f"DISKANN params: {index_creation_params}")
|
||||
elif index_type in ["FLAT", "AUTOINDEX"]:
|
||||
log.info(f"Using {index_type} index with no specific build-time params.")
|
||||
log.info(f'DISKANN params: {index_creation_params}')
|
||||
elif index_type in ['FLAT', 'AUTOINDEX']:
|
||||
log.info(f'Using {index_type} index with no specific build-time params.')
|
||||
else:
|
||||
log.warning(
|
||||
f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
|
||||
f"Supported types: HNSW, IVF_FLAT, DISKANN, FLAT, AUTOINDEX. "
|
||||
f"Milvus will use its default for the collection if this type is not directly supported for index creation."
|
||||
f'Supported types: HNSW, IVF_FLAT, DISKANN, FLAT, AUTOINDEX. '
|
||||
f'Milvus will use its default for the collection if this type is not directly supported for index creation.'
|
||||
)
|
||||
# For unsupported types, pass the type directly to Milvus; it might handle it or use a default.
|
||||
# If Milvus errors out, the user needs to correct the MILVUS_INDEX_TYPE env var.
|
||||
|
||||
index_params.add_index(
|
||||
field_name="vector",
|
||||
field_name='vector',
|
||||
index_type=index_type,
|
||||
metric_type=metric_type,
|
||||
params=index_creation_params,
|
||||
)
|
||||
|
||||
self.client.create_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||
schema=schema,
|
||||
index_params=index_params,
|
||||
)
|
||||
@@ -170,17 +168,13 @@ class MilvusClient(VectorDBBase):
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
# Check if the collection exists based on the collection name.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
return self.client.has_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
)
|
||||
collection_name = collection_name.replace('-', '_')
|
||||
return self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}')
|
||||
|
||||
def delete_collection(self, collection_name: str):
|
||||
# Delete the collection based on the collection name.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
return self.client.drop_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
)
|
||||
collection_name = collection_name.replace('-', '_')
|
||||
return self.client.drop_collection(collection_name=f'{self.collection_prefix}_{collection_name}')
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -190,15 +184,15 @@ class MilvusClient(VectorDBBase):
|
||||
limit: int = 10,
|
||||
) -> Optional[SearchResult]:
|
||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
collection_name = collection_name.replace('-', '_')
|
||||
# For some index types like IVF_FLAT, search params like nprobe can be set.
|
||||
# Example: search_params = {"nprobe": 10} if using IVF_FLAT
|
||||
# For simplicity, not adding configurable search_params here, but could be extended.
|
||||
result = self.client.search(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||
data=vectors,
|
||||
limit=limit,
|
||||
output_fields=["data", "metadata"],
|
||||
output_fields=['data', 'metadata'],
|
||||
# search_params=search_params # Potentially add later if needed
|
||||
)
|
||||
return self._result_to_search_result(result)
|
||||
@@ -206,11 +200,9 @@ class MilvusClient(VectorDBBase):
|
||||
def query(self, collection_name: str, filter: dict, limit: int = -1):
|
||||
connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB)
|
||||
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
collection_name = collection_name.replace('-', '_')
|
||||
if not self.has_collection(collection_name):
|
||||
log.warning(
|
||||
f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
|
||||
)
|
||||
log.warning(f'Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}')
|
||||
return None
|
||||
|
||||
filter_expressions = []
|
||||
@@ -220,9 +212,9 @@ class MilvusClient(VectorDBBase):
|
||||
else:
|
||||
filter_expressions.append(f'metadata["{key}"] == {value}')
|
||||
|
||||
filter_string = " && ".join(filter_expressions)
|
||||
filter_string = ' && '.join(filter_expressions)
|
||||
|
||||
collection = Collection(f"{self.collection_prefix}_{collection_name}")
|
||||
collection = Collection(f'{self.collection_prefix}_{collection_name}')
|
||||
collection.load()
|
||||
|
||||
try:
|
||||
@@ -233,9 +225,9 @@ class MilvusClient(VectorDBBase):
|
||||
iterator = collection.query_iterator(
|
||||
expr=filter_string,
|
||||
output_fields=[
|
||||
"id",
|
||||
"data",
|
||||
"metadata",
|
||||
'id',
|
||||
'data',
|
||||
'metadata',
|
||||
],
|
||||
limit=limit if limit > 0 else -1,
|
||||
)
|
||||
@@ -248,7 +240,7 @@ class MilvusClient(VectorDBBase):
|
||||
break
|
||||
all_results.extend(batch)
|
||||
|
||||
log.debug(f"Total results from query: {len(all_results)}")
|
||||
log.debug(f'Total results from query: {len(all_results)}')
|
||||
return self._result_to_get_result([all_results] if all_results else [[]])
|
||||
|
||||
except Exception as e:
|
||||
@@ -259,7 +251,7 @@ class MilvusClient(VectorDBBase):
|
||||
|
||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||
# Get all the items in the collection. This can be very resource-intensive for large collections.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
collection_name = collection_name.replace('-', '_')
|
||||
log.warning(
|
||||
f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections."
|
||||
)
|
||||
@@ -269,35 +261,25 @@ class MilvusClient(VectorDBBase):
|
||||
|
||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Insert the items into the collection, if the collection does not exist, it will be created.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
if not self.client.has_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
):
|
||||
log.info(
|
||||
f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now."
|
||||
)
|
||||
collection_name = collection_name.replace('-', '_')
|
||||
if not self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}'):
|
||||
log.info(f'Collection {self.collection_prefix}_{collection_name} does not exist. Creating now.')
|
||||
if not items:
|
||||
log.error(
|
||||
f"Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension."
|
||||
)
|
||||
raise ValueError(
|
||||
"Cannot create Milvus collection without items to determine vector dimension."
|
||||
)
|
||||
self._create_collection(
|
||||
collection_name=collection_name, dimension=len(items[0]["vector"])
|
||||
f'Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension.'
|
||||
)
|
||||
raise ValueError('Cannot create Milvus collection without items to determine vector dimension.')
|
||||
self._create_collection(collection_name=collection_name, dimension=len(items[0]['vector']))
|
||||
|
||||
log.info(
|
||||
f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
|
||||
)
|
||||
log.info(f'Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.')
|
||||
return self.client.insert(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||
data=[
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": item["vector"],
|
||||
"data": {"text": item["text"]},
|
||||
"metadata": process_metadata(item["metadata"]),
|
||||
'id': item['id'],
|
||||
'vector': item['vector'],
|
||||
'data': {'text': item['text']},
|
||||
'metadata': process_metadata(item['metadata']),
|
||||
}
|
||||
for item in items
|
||||
],
|
||||
@@ -305,35 +287,27 @@ class MilvusClient(VectorDBBase):
|
||||
|
||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
if not self.client.has_collection(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
||||
):
|
||||
log.info(
|
||||
f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now."
|
||||
)
|
||||
collection_name = collection_name.replace('-', '_')
|
||||
if not self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}'):
|
||||
log.info(f'Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now.')
|
||||
if not items:
|
||||
log.error(
|
||||
f"Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension."
|
||||
f'Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension.'
|
||||
)
|
||||
raise ValueError(
|
||||
"Cannot create Milvus collection for upsert without items to determine vector dimension."
|
||||
)
|
||||
self._create_collection(
|
||||
collection_name=collection_name, dimension=len(items[0]["vector"])
|
||||
'Cannot create Milvus collection for upsert without items to determine vector dimension.'
|
||||
)
|
||||
self._create_collection(collection_name=collection_name, dimension=len(items[0]['vector']))
|
||||
|
||||
log.info(
|
||||
f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
|
||||
)
|
||||
log.info(f'Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.')
|
||||
return self.client.upsert(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||
data=[
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": item["vector"],
|
||||
"data": {"text": item["text"]},
|
||||
"metadata": process_metadata(item["metadata"]),
|
||||
'id': item['id'],
|
||||
'vector': item['vector'],
|
||||
'data': {'text': item['text']},
|
||||
'metadata': process_metadata(item['metadata']),
|
||||
}
|
||||
for item in items
|
||||
],
|
||||
@@ -346,46 +320,35 @@ class MilvusClient(VectorDBBase):
|
||||
filter: Optional[dict] = None,
|
||||
):
|
||||
# Delete the items from the collection based on the ids or filter.
|
||||
collection_name = collection_name.replace("-", "_")
|
||||
collection_name = collection_name.replace('-', '_')
|
||||
if not self.has_collection(collection_name):
|
||||
log.warning(
|
||||
f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
|
||||
)
|
||||
log.warning(f'Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}')
|
||||
return None
|
||||
|
||||
if ids:
|
||||
log.info(
|
||||
f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}"
|
||||
)
|
||||
log.info(f'Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}')
|
||||
return self.client.delete(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||
ids=ids,
|
||||
)
|
||||
elif filter:
|
||||
filter_string = " && ".join(
|
||||
[
|
||||
f'metadata["{key}"] == {json.dumps(value)}'
|
||||
for key, value in filter.items()
|
||||
]
|
||||
)
|
||||
filter_string = ' && '.join([f'metadata["{key}"] == {json.dumps(value)}' for key, value in filter.items()])
|
||||
log.info(
|
||||
f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}"
|
||||
f'Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}'
|
||||
)
|
||||
return self.client.delete(
|
||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
||||
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||
filter=filter_string,
|
||||
)
|
||||
else:
|
||||
log.warning(
|
||||
f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken."
|
||||
f'Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken.'
|
||||
)
|
||||
return None
|
||||
|
||||
def reset(self):
|
||||
# Resets the database. This will delete all collections and item entries that match the prefix.
|
||||
log.warning(
|
||||
f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'."
|
||||
)
|
||||
log.warning(f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'.")
|
||||
collection_names = self.client.list_collections()
|
||||
deleted_collections = []
|
||||
for collection_name_full in collection_names:
|
||||
@@ -393,7 +356,7 @@ class MilvusClient(VectorDBBase):
|
||||
try:
|
||||
self.client.drop_collection(collection_name=collection_name_full)
|
||||
deleted_collections.append(collection_name_full)
|
||||
log.info(f"Deleted collection: {collection_name_full}")
|
||||
log.info(f'Deleted collection: {collection_name_full}')
|
||||
except Exception as e:
|
||||
log.error(f"Error deleting collection {collection_name_full}: {e}")
|
||||
log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")
|
||||
log.error(f'Error deleting collection {collection_name_full}: {e}')
|
||||
log.info(f'Milvus reset complete. Deleted collections: {deleted_collections}')
|
||||
|
||||
@@ -33,26 +33,26 @@ from pymilvus import (
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
RESOURCE_ID_FIELD = "resource_id"
|
||||
RESOURCE_ID_FIELD = 'resource_id'
|
||||
|
||||
|
||||
class MilvusClient(VectorDBBase):
|
||||
def __init__(self):
|
||||
# Milvus collection names can only contain numbers, letters, and underscores.
|
||||
self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_")
|
||||
self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace('-', '_')
|
||||
connections.connect(
|
||||
alias="default",
|
||||
alias='default',
|
||||
uri=MILVUS_URI,
|
||||
token=MILVUS_TOKEN,
|
||||
db_name=MILVUS_DB,
|
||||
)
|
||||
|
||||
# Main collection types for multi-tenancy
|
||||
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
|
||||
self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
|
||||
self.FILE_COLLECTION = f"{self.collection_prefix}_files"
|
||||
self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search"
|
||||
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based"
|
||||
self.MEMORY_COLLECTION = f'{self.collection_prefix}_memories'
|
||||
self.KNOWLEDGE_COLLECTION = f'{self.collection_prefix}_knowledge'
|
||||
self.FILE_COLLECTION = f'{self.collection_prefix}_files'
|
||||
self.WEB_SEARCH_COLLECTION = f'{self.collection_prefix}_web_search'
|
||||
self.HASH_BASED_COLLECTION = f'{self.collection_prefix}_hash_based'
|
||||
self.shared_collections = [
|
||||
self.MEMORY_COLLECTION,
|
||||
self.KNOWLEDGE_COLLECTION,
|
||||
@@ -74,15 +74,13 @@ class MilvusClient(VectorDBBase):
|
||||
"""
|
||||
resource_id = collection_name
|
||||
|
||||
if collection_name.startswith("user-memory-"):
|
||||
if collection_name.startswith('user-memory-'):
|
||||
return self.MEMORY_COLLECTION, resource_id
|
||||
elif collection_name.startswith("file-"):
|
||||
elif collection_name.startswith('file-'):
|
||||
return self.FILE_COLLECTION, resource_id
|
||||
elif collection_name.startswith("web-search-"):
|
||||
elif collection_name.startswith('web-search-'):
|
||||
return self.WEB_SEARCH_COLLECTION, resource_id
|
||||
elif len(collection_name) == 63 and all(
|
||||
c in "0123456789abcdef" for c in collection_name
|
||||
):
|
||||
elif len(collection_name) == 63 and all(c in '0123456789abcdef' for c in collection_name):
|
||||
return self.HASH_BASED_COLLECTION, resource_id
|
||||
else:
|
||||
return self.KNOWLEDGE_COLLECTION, resource_id
|
||||
@@ -90,36 +88,36 @@ class MilvusClient(VectorDBBase):
|
||||
def _create_shared_collection(self, mt_collection_name: str, dimension: int):
|
||||
fields = [
|
||||
FieldSchema(
|
||||
name="id",
|
||||
name='id',
|
||||
dtype=DataType.VARCHAR,
|
||||
is_primary=True,
|
||||
auto_id=False,
|
||||
max_length=36,
|
||||
),
|
||||
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
|
||||
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name="metadata", dtype=DataType.JSON),
|
||||
FieldSchema(name='vector', dtype=DataType.FLOAT_VECTOR, dim=dimension),
|
||||
FieldSchema(name='text', dtype=DataType.VARCHAR, max_length=65535),
|
||||
FieldSchema(name='metadata', dtype=DataType.JSON),
|
||||
FieldSchema(name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255),
|
||||
]
|
||||
schema = CollectionSchema(fields, "Shared collection for multi-tenancy")
|
||||
schema = CollectionSchema(fields, 'Shared collection for multi-tenancy')
|
||||
collection = Collection(mt_collection_name, schema)
|
||||
|
||||
index_params = {
|
||||
"metric_type": MILVUS_METRIC_TYPE,
|
||||
"index_type": MILVUS_INDEX_TYPE,
|
||||
"params": {},
|
||||
'metric_type': MILVUS_METRIC_TYPE,
|
||||
'index_type': MILVUS_INDEX_TYPE,
|
||||
'params': {},
|
||||
}
|
||||
if MILVUS_INDEX_TYPE == "HNSW":
|
||||
index_params["params"] = {
|
||||
"M": MILVUS_HNSW_M,
|
||||
"efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
|
||||
if MILVUS_INDEX_TYPE == 'HNSW':
|
||||
index_params['params'] = {
|
||||
'M': MILVUS_HNSW_M,
|
||||
'efConstruction': MILVUS_HNSW_EFCONSTRUCTION,
|
||||
}
|
||||
elif MILVUS_INDEX_TYPE == "IVF_FLAT":
|
||||
index_params["params"] = {"nlist": MILVUS_IVF_FLAT_NLIST}
|
||||
elif MILVUS_INDEX_TYPE == 'IVF_FLAT':
|
||||
index_params['params'] = {'nlist': MILVUS_IVF_FLAT_NLIST}
|
||||
|
||||
collection.create_index("vector", index_params)
|
||||
collection.create_index('vector', index_params)
|
||||
collection.create_index(RESOURCE_ID_FIELD)
|
||||
log.info(f"Created shared collection: {mt_collection_name}")
|
||||
log.info(f'Created shared collection: {mt_collection_name}')
|
||||
return collection
|
||||
|
||||
def _ensure_collection(self, mt_collection_name: str, dimension: int):
|
||||
@@ -127,9 +125,7 @@ class MilvusClient(VectorDBBase):
|
||||
self._create_shared_collection(mt_collection_name, dimension)
|
||||
|
||||
def has_collection(self, collection_name: str) -> bool:
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(
|
||||
collection_name
|
||||
)
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
|
||||
if not utility.has_collection(mt_collection):
|
||||
return False
|
||||
|
||||
@@ -141,19 +137,17 @@ class MilvusClient(VectorDBBase):
|
||||
def upsert(self, collection_name: str, items: List[VectorItem]):
|
||||
if not items:
|
||||
return
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(
|
||||
collection_name
|
||||
)
|
||||
dimension = len(items[0]["vector"])
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
|
||||
dimension = len(items[0]['vector'])
|
||||
self._ensure_collection(mt_collection, dimension)
|
||||
collection = Collection(mt_collection)
|
||||
|
||||
entities = [
|
||||
{
|
||||
"id": item["id"],
|
||||
"vector": item["vector"],
|
||||
"text": item["text"],
|
||||
"metadata": item["metadata"],
|
||||
'id': item['id'],
|
||||
'vector': item['vector'],
|
||||
'text': item['text'],
|
||||
'metadata': item['metadata'],
|
||||
RESOURCE_ID_FIELD: resource_id,
|
||||
}
|
||||
for item in items
|
||||
@@ -170,41 +164,37 @@ class MilvusClient(VectorDBBase):
|
||||
if not vectors:
|
||||
return None
|
||||
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(
|
||||
collection_name
|
||||
)
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
|
||||
if not utility.has_collection(mt_collection):
|
||||
return None
|
||||
|
||||
collection = Collection(mt_collection)
|
||||
collection.load()
|
||||
|
||||
search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}}
|
||||
search_params = {'metric_type': MILVUS_METRIC_TYPE, 'params': {}}
|
||||
results = collection.search(
|
||||
data=vectors,
|
||||
anns_field="vector",
|
||||
anns_field='vector',
|
||||
param=search_params,
|
||||
limit=limit,
|
||||
expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'",
|
||||
output_fields=["id", "text", "metadata"],
|
||||
output_fields=['id', 'text', 'metadata'],
|
||||
)
|
||||
|
||||
ids, documents, metadatas, distances = [], [], [], []
|
||||
for hits in results:
|
||||
batch_ids, batch_docs, batch_metadatas, batch_dists = [], [], [], []
|
||||
for hit in hits:
|
||||
batch_ids.append(hit.entity.get("id"))
|
||||
batch_docs.append(hit.entity.get("text"))
|
||||
batch_metadatas.append(hit.entity.get("metadata"))
|
||||
batch_ids.append(hit.entity.get('id'))
|
||||
batch_docs.append(hit.entity.get('text'))
|
||||
batch_metadatas.append(hit.entity.get('metadata'))
|
||||
batch_dists.append(hit.distance)
|
||||
ids.append(batch_ids)
|
||||
documents.append(batch_docs)
|
||||
metadatas.append(batch_metadatas)
|
||||
distances.append(batch_dists)
|
||||
|
||||
return SearchResult(
|
||||
ids=ids, documents=documents, metadatas=metadatas, distances=distances
|
||||
)
|
||||
return SearchResult(ids=ids, documents=documents, metadatas=metadatas, distances=distances)
|
||||
|
||||
def delete(
|
||||
self,
|
||||
@@ -212,9 +202,7 @@ class MilvusClient(VectorDBBase):
|
||||
ids: Optional[List[str]] = None,
|
||||
filter: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(
|
||||
collection_name
|
||||
)
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
|
||||
if not utility.has_collection(mt_collection):
|
||||
return
|
||||
|
||||
@@ -224,14 +212,14 @@ class MilvusClient(VectorDBBase):
|
||||
expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
|
||||
if ids:
|
||||
# Milvus expects a string list for 'in' operator
|
||||
id_list_str = ", ".join([f"'{id_val}'" for id_val in ids])
|
||||
expr.append(f"id in [{id_list_str}]")
|
||||
id_list_str = ', '.join([f"'{id_val}'" for id_val in ids])
|
||||
expr.append(f'id in [{id_list_str}]')
|
||||
|
||||
if filter:
|
||||
for key, value in filter.items():
|
||||
expr.append(f"metadata['{key}'] == '{value}'")
|
||||
|
||||
collection.delete(" and ".join(expr))
|
||||
collection.delete(' and '.join(expr))
|
||||
|
||||
def reset(self):
|
||||
for collection_name in self.shared_collections:
|
||||
@@ -239,21 +227,15 @@ class MilvusClient(VectorDBBase):
|
||||
utility.drop_collection(collection_name)
|
||||
|
||||
def delete_collection(self, collection_name: str):
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(
|
||||
collection_name
|
||||
)
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
|
||||
if not utility.has_collection(mt_collection):
|
||||
return
|
||||
|
||||
collection = Collection(mt_collection)
|
||||
collection.delete(f"{RESOURCE_ID_FIELD} == '{resource_id}'")
|
||||
|
||||
def query(
|
||||
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
|
||||
) -> Optional[GetResult]:
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(
|
||||
collection_name
|
||||
)
|
||||
def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
|
||||
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
|
||||
if not utility.has_collection(mt_collection):
|
||||
return None
|
||||
|
||||
@@ -269,8 +251,8 @@ class MilvusClient(VectorDBBase):
|
||||
expr.append(f"metadata['{key}'] == {value}")
|
||||
|
||||
iterator = collection.query_iterator(
|
||||
expr=" and ".join(expr),
|
||||
output_fields=["id", "text", "metadata"],
|
||||
expr=' and '.join(expr),
|
||||
output_fields=['id', 'text', 'metadata'],
|
||||
limit=limit if limit else -1,
|
||||
)
|
||||
|
||||
@@ -282,9 +264,9 @@ class MilvusClient(VectorDBBase):
|
||||
break
|
||||
all_results.extend(batch)
|
||||
|
||||
ids = [res["id"] for res in all_results]
|
||||
documents = [res["text"] for res in all_results]
|
||||
metadatas = [res["metadata"] for res in all_results]
|
||||
ids = [res['id'] for res in all_results]
|
||||
documents = [res['text'] for res in all_results]
|
||||
metadatas = [res['metadata'] for res in all_results]
|
||||
|
||||
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user