mirror of
https://github.com/open-webui/open-webui.git
synced 2026-04-25 17:15:16 +02:00
refac
This commit is contained in:
@@ -10,94 +10,88 @@ from typing_extensions import Annotated
|
|||||||
|
|
||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
|
|
||||||
KEY_FILE = Path.cwd() / ".webui_secret_key"
|
KEY_FILE = Path.cwd() / '.webui_secret_key'
|
||||||
|
|
||||||
|
|
||||||
def version_callback(value: bool):
|
def version_callback(value: bool):
|
||||||
if value:
|
if value:
|
||||||
from open_webui.env import VERSION
|
from open_webui.env import VERSION
|
||||||
|
|
||||||
typer.echo(f"Open WebUI version: {VERSION}")
|
typer.echo(f'Open WebUI version: {VERSION}')
|
||||||
raise typer.Exit()
|
raise typer.Exit()
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def main(
|
def main(
|
||||||
version: Annotated[
|
version: Annotated[Optional[bool], typer.Option('--version', callback=version_callback)] = None,
|
||||||
Optional[bool], typer.Option("--version", callback=version_callback)
|
|
||||||
] = None,
|
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def serve(
|
def serve(
|
||||||
host: str = "0.0.0.0",
|
host: str = '0.0.0.0',
|
||||||
port: int = 8080,
|
port: int = 8080,
|
||||||
):
|
):
|
||||||
os.environ["FROM_INIT_PY"] = "true"
|
os.environ['FROM_INIT_PY'] = 'true'
|
||||||
if os.getenv("WEBUI_SECRET_KEY") is None:
|
if os.getenv('WEBUI_SECRET_KEY') is None:
|
||||||
typer.echo(
|
typer.echo('Loading WEBUI_SECRET_KEY from file, not provided as an environment variable.')
|
||||||
"Loading WEBUI_SECRET_KEY from file, not provided as an environment variable."
|
|
||||||
)
|
|
||||||
if not KEY_FILE.exists():
|
if not KEY_FILE.exists():
|
||||||
typer.echo(f"Generating a new secret key and saving it to {KEY_FILE}")
|
typer.echo(f'Generating a new secret key and saving it to {KEY_FILE}')
|
||||||
KEY_FILE.write_bytes(base64.b64encode(random.randbytes(12)))
|
KEY_FILE.write_bytes(base64.b64encode(random.randbytes(12)))
|
||||||
typer.echo(f"Loading WEBUI_SECRET_KEY from {KEY_FILE}")
|
typer.echo(f'Loading WEBUI_SECRET_KEY from {KEY_FILE}')
|
||||||
os.environ["WEBUI_SECRET_KEY"] = KEY_FILE.read_text()
|
os.environ['WEBUI_SECRET_KEY'] = KEY_FILE.read_text()
|
||||||
|
|
||||||
if os.getenv("USE_CUDA_DOCKER", "false") == "true":
|
if os.getenv('USE_CUDA_DOCKER', 'false') == 'true':
|
||||||
typer.echo(
|
typer.echo('CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries.')
|
||||||
"CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries."
|
LD_LIBRARY_PATH = os.getenv('LD_LIBRARY_PATH', '').split(':')
|
||||||
)
|
os.environ['LD_LIBRARY_PATH'] = ':'.join(
|
||||||
LD_LIBRARY_PATH = os.getenv("LD_LIBRARY_PATH", "").split(":")
|
|
||||||
os.environ["LD_LIBRARY_PATH"] = ":".join(
|
|
||||||
LD_LIBRARY_PATH
|
LD_LIBRARY_PATH
|
||||||
+ [
|
+ [
|
||||||
"/usr/local/lib/python3.11/site-packages/torch/lib",
|
'/usr/local/lib/python3.11/site-packages/torch/lib',
|
||||||
"/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib",
|
'/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib',
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
assert torch.cuda.is_available(), "CUDA not available"
|
assert torch.cuda.is_available(), 'CUDA not available'
|
||||||
typer.echo("CUDA seems to be working")
|
typer.echo('CUDA seems to be working')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
typer.echo(
|
typer.echo(
|
||||||
"Error when testing CUDA but USE_CUDA_DOCKER is true. "
|
'Error when testing CUDA but USE_CUDA_DOCKER is true. '
|
||||||
"Resetting USE_CUDA_DOCKER to false and removing "
|
'Resetting USE_CUDA_DOCKER to false and removing '
|
||||||
f"LD_LIBRARY_PATH modifications: {e}"
|
f'LD_LIBRARY_PATH modifications: {e}'
|
||||||
)
|
)
|
||||||
os.environ["USE_CUDA_DOCKER"] = "false"
|
os.environ['USE_CUDA_DOCKER'] = 'false'
|
||||||
os.environ["LD_LIBRARY_PATH"] = ":".join(LD_LIBRARY_PATH)
|
os.environ['LD_LIBRARY_PATH'] = ':'.join(LD_LIBRARY_PATH)
|
||||||
|
|
||||||
import open_webui.main # we need set environment variables before importing main
|
import open_webui.main # we need set environment variables before importing main
|
||||||
from open_webui.env import UVICORN_WORKERS # Import the workers setting
|
from open_webui.env import UVICORN_WORKERS # Import the workers setting
|
||||||
|
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
"open_webui.main:app",
|
'open_webui.main:app',
|
||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
forwarded_allow_ips="*",
|
forwarded_allow_ips='*',
|
||||||
workers=UVICORN_WORKERS,
|
workers=UVICORN_WORKERS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def dev(
|
def dev(
|
||||||
host: str = "0.0.0.0",
|
host: str = '0.0.0.0',
|
||||||
port: int = 8080,
|
port: int = 8080,
|
||||||
reload: bool = True,
|
reload: bool = True,
|
||||||
):
|
):
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
"open_webui.main:app",
|
'open_webui.main:app',
|
||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
reload=reload,
|
reload=reload,
|
||||||
forwarded_allow_ips="*",
|
forwarded_allow_ips='*',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == '__main__':
|
||||||
app()
|
app()
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -2,125 +2,107 @@ from enum import Enum
|
|||||||
|
|
||||||
|
|
||||||
class MESSAGES(str, Enum):
|
class MESSAGES(str, Enum):
|
||||||
DEFAULT = lambda msg="": f"{msg if msg else ''}"
|
DEFAULT = lambda msg='': f'{msg if msg else ""}'
|
||||||
MODEL_ADDED = lambda model="": f"The model '{model}' has been added successfully."
|
MODEL_ADDED = lambda model='': f"The model '{model}' has been added successfully."
|
||||||
MODEL_DELETED = (
|
MODEL_DELETED = lambda model='': f"The model '{model}' has been deleted successfully."
|
||||||
lambda model="": f"The model '{model}' has been deleted successfully."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class WEBHOOK_MESSAGES(str, Enum):
|
class WEBHOOK_MESSAGES(str, Enum):
|
||||||
DEFAULT = lambda msg="": f"{msg if msg else ''}"
|
DEFAULT = lambda msg='': f'{msg if msg else ""}'
|
||||||
USER_SIGNUP = lambda username="": (
|
USER_SIGNUP = lambda username='': (f'New user signed up: {username}' if username else 'New user signed up')
|
||||||
f"New user signed up: {username}" if username else "New user signed up"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ERROR_MESSAGES(str, Enum):
|
class ERROR_MESSAGES(str, Enum):
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return super().__str__()
|
return super().__str__()
|
||||||
|
|
||||||
DEFAULT = (
|
DEFAULT = lambda err='': f'{"Something went wrong :/" if err == "" else "[ERROR: " + str(err) + "]"}'
|
||||||
lambda err="": f'{"Something went wrong :/" if err == "" else "[ERROR: " + str(err) + "]"}'
|
ENV_VAR_NOT_FOUND = 'Required environment variable not found. Terminating now.'
|
||||||
|
CREATE_USER_ERROR = 'Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance.'
|
||||||
|
DELETE_USER_ERROR = 'Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot.'
|
||||||
|
EMAIL_MISMATCH = 'Uh-oh! This email does not match the email your provider is registered with. Please check your email and try again.'
|
||||||
|
EMAIL_TAKEN = 'Uh-oh! This email is already registered. Sign in with your existing account or choose another email to start anew.'
|
||||||
|
USERNAME_TAKEN = 'Uh-oh! This username is already registered. Please choose another username.'
|
||||||
|
PASSWORD_TOO_LONG = (
|
||||||
|
'Uh-oh! The password you entered is too long. Please make sure your password is less than 72 bytes long.'
|
||||||
)
|
)
|
||||||
ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now."
|
COMMAND_TAKEN = 'Uh-oh! This command is already registered. Please choose another command string.'
|
||||||
CREATE_USER_ERROR = "Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance."
|
FILE_EXISTS = 'Uh-oh! This file is already registered. Please choose another file.'
|
||||||
DELETE_USER_ERROR = "Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot."
|
|
||||||
EMAIL_MISMATCH = "Uh-oh! This email does not match the email your provider is registered with. Please check your email and try again."
|
|
||||||
EMAIL_TAKEN = "Uh-oh! This email is already registered. Sign in with your existing account or choose another email to start anew."
|
|
||||||
USERNAME_TAKEN = (
|
|
||||||
"Uh-oh! This username is already registered. Please choose another username."
|
|
||||||
)
|
|
||||||
PASSWORD_TOO_LONG = "Uh-oh! The password you entered is too long. Please make sure your password is less than 72 bytes long."
|
|
||||||
COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
|
|
||||||
FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file."
|
|
||||||
|
|
||||||
ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string."
|
ID_TAKEN = 'Uh-oh! This id is already registered. Please choose another id string.'
|
||||||
MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string."
|
MODEL_ID_TAKEN = 'Uh-oh! This model id is already registered. Please choose another model id string.'
|
||||||
NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string."
|
NAME_TAG_TAKEN = 'Uh-oh! This name tag is already registered. Please choose another name tag string.'
|
||||||
MODEL_ID_TOO_LONG = "The model id is too long. Please make sure your model id is less than 256 characters long."
|
MODEL_ID_TOO_LONG = 'The model id is too long. Please make sure your model id is less than 256 characters long.'
|
||||||
|
|
||||||
INVALID_TOKEN = (
|
INVALID_TOKEN = 'Your session has expired or the token is invalid. Please sign in again.'
|
||||||
"Your session has expired or the token is invalid. Please sign in again."
|
INVALID_CRED = 'The email or password provided is incorrect. Please check for typos and try logging in again.'
|
||||||
)
|
|
||||||
INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again."
|
|
||||||
INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)."
|
INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)."
|
||||||
INCORRECT_PASSWORD = (
|
INCORRECT_PASSWORD = 'The password provided is incorrect. Please check for typos and try again.'
|
||||||
"The password provided is incorrect. Please check for typos and try again."
|
INVALID_TRUSTED_HEADER = (
|
||||||
|
'Your provider has not provided a trusted header. Please contact your administrator for assistance.'
|
||||||
)
|
)
|
||||||
INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance."
|
|
||||||
|
|
||||||
EXISTING_USERS = "You can't turn off authentication because there are existing users. If you want to disable WEBUI_AUTH, make sure your web interface doesn't have any existing users and is a fresh installation."
|
EXISTING_USERS = "You can't turn off authentication because there are existing users. If you want to disable WEBUI_AUTH, make sure your web interface doesn't have any existing users and is a fresh installation."
|
||||||
|
|
||||||
UNAUTHORIZED = "401 Unauthorized"
|
UNAUTHORIZED = '401 Unauthorized'
|
||||||
ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance."
|
ACCESS_PROHIBITED = (
|
||||||
ACTION_PROHIBITED = (
|
'You do not have permission to access this resource. Please contact your administrator for assistance.'
|
||||||
"The requested action has been restricted as a security measure."
|
|
||||||
)
|
)
|
||||||
|
ACTION_PROHIBITED = 'The requested action has been restricted as a security measure.'
|
||||||
|
|
||||||
FILE_NOT_SENT = "FILE_NOT_SENT"
|
FILE_NOT_SENT = 'FILE_NOT_SENT'
|
||||||
FILE_NOT_SUPPORTED = "Oops! It seems like the file format you're trying to upload is not supported. Please upload a file with a supported format and try again."
|
FILE_NOT_SUPPORTED = "Oops! It seems like the file format you're trying to upload is not supported. Please upload a file with a supported format and try again."
|
||||||
|
|
||||||
NOT_FOUND = "We could not find what you're looking for :/"
|
NOT_FOUND = "We could not find what you're looking for :/"
|
||||||
USER_NOT_FOUND = "We could not find what you're looking for :/"
|
USER_NOT_FOUND = "We could not find what you're looking for :/"
|
||||||
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
|
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
|
||||||
API_KEY_NOT_ALLOWED = "Use of API key is not enabled in the environment."
|
API_KEY_NOT_ALLOWED = 'Use of API key is not enabled in the environment.'
|
||||||
|
|
||||||
MALICIOUS = "Unusual activities detected, please try again in a few minutes."
|
MALICIOUS = 'Unusual activities detected, please try again in a few minutes.'
|
||||||
|
|
||||||
PANDOC_NOT_INSTALLED = "Pandoc is not installed on the server. Please contact your administrator for assistance."
|
PANDOC_NOT_INSTALLED = 'Pandoc is not installed on the server. Please contact your administrator for assistance.'
|
||||||
INCORRECT_FORMAT = (
|
INCORRECT_FORMAT = lambda err='': f'Invalid format. Please use the correct format{err}'
|
||||||
lambda err="": f"Invalid format. Please use the correct format{err}"
|
RATE_LIMIT_EXCEEDED = 'API rate limit exceeded'
|
||||||
)
|
|
||||||
RATE_LIMIT_EXCEEDED = "API rate limit exceeded"
|
|
||||||
|
|
||||||
MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found"
|
MODEL_NOT_FOUND = lambda name='': f"Model '{name}' was not found"
|
||||||
OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found"
|
OPENAI_NOT_FOUND = lambda name='': 'OpenAI API was not found'
|
||||||
OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
|
OLLAMA_NOT_FOUND = 'WebUI could not connect to Ollama'
|
||||||
CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
|
CREATE_API_KEY_ERROR = 'Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance.'
|
||||||
API_KEY_CREATION_NOT_ALLOWED = "API key creation is not allowed in the environment."
|
API_KEY_CREATION_NOT_ALLOWED = 'API key creation is not allowed in the environment.'
|
||||||
|
|
||||||
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
|
EMPTY_CONTENT = 'The content provided is empty. Please ensure that there is text or data present before proceeding.'
|
||||||
|
|
||||||
DB_NOT_SQLITE = "This feature is only available when running with SQLite databases."
|
DB_NOT_SQLITE = 'This feature is only available when running with SQLite databases.'
|
||||||
|
|
||||||
INVALID_URL = (
|
INVALID_URL = 'Oops! The URL you provided is invalid. Please double-check and try again.'
|
||||||
"Oops! The URL you provided is invalid. Please double-check and try again."
|
|
||||||
)
|
|
||||||
|
|
||||||
WEB_SEARCH_ERROR = (
|
WEB_SEARCH_ERROR = lambda err='': f'{err if err else "Oops! Something went wrong while searching the web."}'
|
||||||
lambda err="": f"{err if err else 'Oops! Something went wrong while searching the web.'}"
|
|
||||||
)
|
|
||||||
|
|
||||||
OLLAMA_API_DISABLED = (
|
OLLAMA_API_DISABLED = 'The Ollama API is disabled. Please enable it to use this feature.'
|
||||||
"The Ollama API is disabled. Please enable it to use this feature."
|
|
||||||
)
|
|
||||||
|
|
||||||
FILE_TOO_LARGE = (
|
FILE_TOO_LARGE = (
|
||||||
lambda size="": f"Oops! The file you're trying to upload is too large. Please upload a file that is less than {size}."
|
lambda size='': f"Oops! The file you're trying to upload is too large. Please upload a file that is less than {size}."
|
||||||
)
|
)
|
||||||
|
|
||||||
DUPLICATE_CONTENT = (
|
DUPLICATE_CONTENT = 'Duplicate content detected. Please provide unique content to proceed.'
|
||||||
"Duplicate content detected. Please provide unique content to proceed."
|
FILE_NOT_PROCESSED = (
|
||||||
|
'Extracted content is not available for this file. Please ensure that the file is processed before proceeding.'
|
||||||
)
|
)
|
||||||
FILE_NOT_PROCESSED = "Extracted content is not available for this file. Please ensure that the file is processed before proceeding."
|
|
||||||
|
|
||||||
INVALID_PASSWORD = lambda err="": (
|
INVALID_PASSWORD = lambda err='': (err if err else 'The password does not meet the required validation criteria.')
|
||||||
err if err else "The password does not meet the required validation criteria."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TASKS(str, Enum):
|
class TASKS(str, Enum):
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return super().__str__()
|
return super().__str__()
|
||||||
|
|
||||||
DEFAULT = lambda task="": f"{task if task else 'generation'}"
|
DEFAULT = lambda task='': f'{task if task else "generation"}'
|
||||||
TITLE_GENERATION = "title_generation"
|
TITLE_GENERATION = 'title_generation'
|
||||||
FOLLOW_UP_GENERATION = "follow_up_generation"
|
FOLLOW_UP_GENERATION = 'follow_up_generation'
|
||||||
TAGS_GENERATION = "tags_generation"
|
TAGS_GENERATION = 'tags_generation'
|
||||||
EMOJI_GENERATION = "emoji_generation"
|
EMOJI_GENERATION = 'emoji_generation'
|
||||||
QUERY_GENERATION = "query_generation"
|
QUERY_GENERATION = 'query_generation'
|
||||||
IMAGE_PROMPT_GENERATION = "image_prompt_generation"
|
IMAGE_PROMPT_GENERATION = 'image_prompt_generation'
|
||||||
AUTOCOMPLETE_GENERATION = "autocomplete_generation"
|
AUTOCOMPLETE_GENERATION = 'autocomplete_generation'
|
||||||
FUNCTION_CALLING = "function_calling"
|
FUNCTION_CALLING = 'function_calling'
|
||||||
MOA_RESPONSE_GENERATION = "moa_response_generation"
|
MOA_RESPONSE_GENERATION = 'moa_response_generation'
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -57,17 +57,15 @@ log = logging.getLogger(__name__)
|
|||||||
def get_function_module_by_id(request: Request, pipe_id: str):
|
def get_function_module_by_id(request: Request, pipe_id: str):
|
||||||
function_module, _, _ = get_function_module_from_cache(request, pipe_id)
|
function_module, _, _ = get_function_module_from_cache(request, pipe_id)
|
||||||
|
|
||||||
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
|
if hasattr(function_module, 'valves') and hasattr(function_module, 'Valves'):
|
||||||
Valves = function_module.Valves
|
Valves = function_module.Valves
|
||||||
valves = Functions.get_function_valves_by_id(pipe_id)
|
valves = Functions.get_function_valves_by_id(pipe_id)
|
||||||
|
|
||||||
if valves:
|
if valves:
|
||||||
try:
|
try:
|
||||||
function_module.valves = Valves(
|
function_module.valves = Valves(**{k: v for k, v in valves.items() if v is not None})
|
||||||
**{k: v for k, v in valves.items() if v is not None}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error loading valves for function {pipe_id}: {e}")
|
log.exception(f'Error loading valves for function {pipe_id}: {e}')
|
||||||
raise e
|
raise e
|
||||||
else:
|
else:
|
||||||
function_module.valves = Valves()
|
function_module.valves = Valves()
|
||||||
@@ -76,7 +74,7 @@ def get_function_module_by_id(request: Request, pipe_id: str):
|
|||||||
|
|
||||||
|
|
||||||
async def get_function_models(request):
|
async def get_function_models(request):
|
||||||
pipes = Functions.get_functions_by_type("pipe", active_only=True)
|
pipes = Functions.get_functions_by_type('pipe', active_only=True)
|
||||||
pipe_models = []
|
pipe_models = []
|
||||||
|
|
||||||
for pipe in pipes:
|
for pipe in pipes:
|
||||||
@@ -84,11 +82,11 @@ async def get_function_models(request):
|
|||||||
function_module = get_function_module_by_id(request, pipe.id)
|
function_module = get_function_module_by_id(request, pipe.id)
|
||||||
|
|
||||||
has_user_valves = False
|
has_user_valves = False
|
||||||
if hasattr(function_module, "UserValves"):
|
if hasattr(function_module, 'UserValves'):
|
||||||
has_user_valves = True
|
has_user_valves = True
|
||||||
|
|
||||||
# Check if function is a manifold
|
# Check if function is a manifold
|
||||||
if hasattr(function_module, "pipes"):
|
if hasattr(function_module, 'pipes'):
|
||||||
sub_pipes = []
|
sub_pipes = []
|
||||||
|
|
||||||
# Handle pipes being a list, sync function, or async function
|
# Handle pipes being a list, sync function, or async function
|
||||||
@@ -104,32 +102,30 @@ async def get_function_models(request):
|
|||||||
log.exception(e)
|
log.exception(e)
|
||||||
sub_pipes = []
|
sub_pipes = []
|
||||||
|
|
||||||
log.debug(
|
log.debug(f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}")
|
||||||
f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}"
|
|
||||||
)
|
|
||||||
|
|
||||||
for p in sub_pipes:
|
for p in sub_pipes:
|
||||||
sub_pipe_id = f'{pipe.id}.{p["id"]}'
|
sub_pipe_id = f'{pipe.id}.{p["id"]}'
|
||||||
sub_pipe_name = p["name"]
|
sub_pipe_name = p['name']
|
||||||
|
|
||||||
if hasattr(function_module, "name"):
|
if hasattr(function_module, 'name'):
|
||||||
sub_pipe_name = f"{function_module.name}{sub_pipe_name}"
|
sub_pipe_name = f'{function_module.name}{sub_pipe_name}'
|
||||||
|
|
||||||
pipe_flag = {"type": pipe.type}
|
pipe_flag = {'type': pipe.type}
|
||||||
|
|
||||||
pipe_models.append(
|
pipe_models.append(
|
||||||
{
|
{
|
||||||
"id": sub_pipe_id,
|
'id': sub_pipe_id,
|
||||||
"name": sub_pipe_name,
|
'name': sub_pipe_name,
|
||||||
"object": "model",
|
'object': 'model',
|
||||||
"created": pipe.created_at,
|
'created': pipe.created_at,
|
||||||
"owned_by": "openai",
|
'owned_by': 'openai',
|
||||||
"pipe": pipe_flag,
|
'pipe': pipe_flag,
|
||||||
"has_user_valves": has_user_valves,
|
'has_user_valves': has_user_valves,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
pipe_flag = {"type": "pipe"}
|
pipe_flag = {'type': 'pipe'}
|
||||||
|
|
||||||
log.debug(
|
log.debug(
|
||||||
f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
|
f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}"
|
||||||
@@ -137,13 +133,13 @@ async def get_function_models(request):
|
|||||||
|
|
||||||
pipe_models.append(
|
pipe_models.append(
|
||||||
{
|
{
|
||||||
"id": pipe.id,
|
'id': pipe.id,
|
||||||
"name": pipe.name,
|
'name': pipe.name,
|
||||||
"object": "model",
|
'object': 'model',
|
||||||
"created": pipe.created_at,
|
'created': pipe.created_at,
|
||||||
"owned_by": "openai",
|
'owned_by': 'openai',
|
||||||
"pipe": pipe_flag,
|
'pipe': pipe_flag,
|
||||||
"has_user_valves": has_user_valves,
|
'has_user_valves': has_user_valves,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -153,9 +149,7 @@ async def get_function_models(request):
|
|||||||
return pipe_models
|
return pipe_models
|
||||||
|
|
||||||
|
|
||||||
async def generate_function_chat_completion(
|
async def generate_function_chat_completion(request, form_data, user, models: dict = {}):
|
||||||
request, form_data, user, models: dict = {}
|
|
||||||
):
|
|
||||||
async def execute_pipe(pipe, params):
|
async def execute_pipe(pipe, params):
|
||||||
if inspect.iscoroutinefunction(pipe):
|
if inspect.iscoroutinefunction(pipe):
|
||||||
return await pipe(**params)
|
return await pipe(**params)
|
||||||
@@ -166,32 +160,32 @@ async def generate_function_chat_completion(
|
|||||||
if isinstance(res, str):
|
if isinstance(res, str):
|
||||||
return res
|
return res
|
||||||
if isinstance(res, Generator):
|
if isinstance(res, Generator):
|
||||||
return "".join(map(str, res))
|
return ''.join(map(str, res))
|
||||||
if isinstance(res, AsyncGenerator):
|
if isinstance(res, AsyncGenerator):
|
||||||
return "".join([str(stream) async for stream in res])
|
return ''.join([str(stream) async for stream in res])
|
||||||
|
|
||||||
def process_line(form_data: dict, line):
|
def process_line(form_data: dict, line):
|
||||||
if isinstance(line, BaseModel):
|
if isinstance(line, BaseModel):
|
||||||
line = line.model_dump_json()
|
line = line.model_dump_json()
|
||||||
line = f"data: {line}"
|
line = f'data: {line}'
|
||||||
if isinstance(line, dict):
|
if isinstance(line, dict):
|
||||||
line = f"data: {json.dumps(line)}"
|
line = f'data: {json.dumps(line)}'
|
||||||
|
|
||||||
try:
|
try:
|
||||||
line = line.decode("utf-8")
|
line = line.decode('utf-8')
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if line.startswith("data:"):
|
if line.startswith('data:'):
|
||||||
return f"{line}\n\n"
|
return f'{line}\n\n'
|
||||||
else:
|
else:
|
||||||
line = openai_chat_chunk_message_template(form_data["model"], line)
|
line = openai_chat_chunk_message_template(form_data['model'], line)
|
||||||
return f"data: {json.dumps(line)}\n\n"
|
return f'data: {json.dumps(line)}\n\n'
|
||||||
|
|
||||||
def get_pipe_id(form_data: dict) -> str:
|
def get_pipe_id(form_data: dict) -> str:
|
||||||
pipe_id = form_data["model"]
|
pipe_id = form_data['model']
|
||||||
if "." in pipe_id:
|
if '.' in pipe_id:
|
||||||
pipe_id, _ = pipe_id.split(".", 1)
|
pipe_id, _ = pipe_id.split('.', 1)
|
||||||
return pipe_id
|
return pipe_id
|
||||||
|
|
||||||
def get_function_params(function_module, form_data, user, extra_params=None):
|
def get_function_params(function_module, form_data, user, extra_params=None):
|
||||||
@@ -202,27 +196,25 @@ async def generate_function_chat_completion(
|
|||||||
|
|
||||||
# Get the signature of the function
|
# Get the signature of the function
|
||||||
sig = inspect.signature(function_module.pipe)
|
sig = inspect.signature(function_module.pipe)
|
||||||
params = {"body": form_data} | {
|
params = {'body': form_data} | {k: v for k, v in extra_params.items() if k in sig.parameters}
|
||||||
k: v for k, v in extra_params.items() if k in sig.parameters
|
|
||||||
}
|
|
||||||
|
|
||||||
if "__user__" in params and hasattr(function_module, "UserValves"):
|
if '__user__' in params and hasattr(function_module, 'UserValves'):
|
||||||
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
|
user_valves = Functions.get_user_valves_by_id_and_user_id(pipe_id, user.id)
|
||||||
try:
|
try:
|
||||||
params["__user__"]["valves"] = function_module.UserValves(**user_valves)
|
params['__user__']['valves'] = function_module.UserValves(**user_valves)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
params["__user__"]["valves"] = function_module.UserValves()
|
params['__user__']['valves'] = function_module.UserValves()
|
||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
model_id = form_data.get("model")
|
model_id = form_data.get('model')
|
||||||
model_info = Models.get_model_by_id(model_id)
|
model_info = Models.get_model_by_id(model_id)
|
||||||
|
|
||||||
metadata = form_data.pop("metadata", {})
|
metadata = form_data.pop('metadata', {})
|
||||||
|
|
||||||
files = metadata.get("files", [])
|
files = metadata.get('files', [])
|
||||||
tool_ids = metadata.get("tool_ids", [])
|
tool_ids = metadata.get('tool_ids', [])
|
||||||
# Check if tool_ids is None
|
# Check if tool_ids is None
|
||||||
if tool_ids is None:
|
if tool_ids is None:
|
||||||
tool_ids = []
|
tool_ids = []
|
||||||
@@ -233,56 +225,56 @@ async def generate_function_chat_completion(
|
|||||||
__task_body__ = None
|
__task_body__ = None
|
||||||
|
|
||||||
if metadata:
|
if metadata:
|
||||||
if all(k in metadata for k in ("session_id", "chat_id", "message_id")):
|
if all(k in metadata for k in ('session_id', 'chat_id', 'message_id')):
|
||||||
__event_emitter__ = get_event_emitter(metadata)
|
__event_emitter__ = get_event_emitter(metadata)
|
||||||
__event_call__ = get_event_call(metadata)
|
__event_call__ = get_event_call(metadata)
|
||||||
__task__ = metadata.get("task", None)
|
__task__ = metadata.get('task', None)
|
||||||
__task_body__ = metadata.get("task_body", None)
|
__task_body__ = metadata.get('task_body', None)
|
||||||
|
|
||||||
oauth_token = None
|
oauth_token = None
|
||||||
try:
|
try:
|
||||||
if request.cookies.get("oauth_session_id", None):
|
if request.cookies.get('oauth_session_id', None):
|
||||||
oauth_token = await request.app.state.oauth_manager.get_oauth_token(
|
oauth_token = await request.app.state.oauth_manager.get_oauth_token(
|
||||||
user.id,
|
user.id,
|
||||||
request.cookies.get("oauth_session_id", None),
|
request.cookies.get('oauth_session_id', None),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error getting OAuth token: {e}")
|
log.error(f'Error getting OAuth token: {e}')
|
||||||
|
|
||||||
extra_params = {
|
extra_params = {
|
||||||
"__event_emitter__": __event_emitter__,
|
'__event_emitter__': __event_emitter__,
|
||||||
"__event_call__": __event_call__,
|
'__event_call__': __event_call__,
|
||||||
"__chat_id__": metadata.get("chat_id", None),
|
'__chat_id__': metadata.get('chat_id', None),
|
||||||
"__session_id__": metadata.get("session_id", None),
|
'__session_id__': metadata.get('session_id', None),
|
||||||
"__message_id__": metadata.get("message_id", None),
|
'__message_id__': metadata.get('message_id', None),
|
||||||
"__task__": __task__,
|
'__task__': __task__,
|
||||||
"__task_body__": __task_body__,
|
'__task_body__': __task_body__,
|
||||||
"__files__": files,
|
'__files__': files,
|
||||||
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
'__user__': user.model_dump() if isinstance(user, UserModel) else {},
|
||||||
"__metadata__": metadata,
|
'__metadata__': metadata,
|
||||||
"__oauth_token__": oauth_token,
|
'__oauth_token__': oauth_token,
|
||||||
"__request__": request,
|
'__request__': request,
|
||||||
}
|
}
|
||||||
extra_params["__tools__"] = await get_tools(
|
extra_params['__tools__'] = await get_tools(
|
||||||
request,
|
request,
|
||||||
tool_ids,
|
tool_ids,
|
||||||
user,
|
user,
|
||||||
{
|
{
|
||||||
**extra_params,
|
**extra_params,
|
||||||
"__model__": models.get(form_data["model"], None),
|
'__model__': models.get(form_data['model'], None),
|
||||||
"__messages__": form_data["messages"],
|
'__messages__': form_data['messages'],
|
||||||
"__files__": files,
|
'__files__': files,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if model_info:
|
if model_info:
|
||||||
if model_info.base_model_id:
|
if model_info.base_model_id:
|
||||||
form_data["model"] = model_info.base_model_id
|
form_data['model'] = model_info.base_model_id
|
||||||
|
|
||||||
params = model_info.params.model_dump()
|
params = model_info.params.model_dump()
|
||||||
|
|
||||||
if params:
|
if params:
|
||||||
system = params.pop("system", None)
|
system = params.pop('system', None)
|
||||||
form_data = apply_model_params_to_body_openai(params, form_data)
|
form_data = apply_model_params_to_body_openai(params, form_data)
|
||||||
form_data = apply_system_prompt_to_body(system, form_data, metadata, user)
|
form_data = apply_system_prompt_to_body(system, form_data, metadata, user)
|
||||||
|
|
||||||
@@ -292,7 +284,7 @@ async def generate_function_chat_completion(
|
|||||||
pipe = function_module.pipe
|
pipe = function_module.pipe
|
||||||
params = get_function_params(function_module, form_data, user, extra_params)
|
params = get_function_params(function_module, form_data, user, extra_params)
|
||||||
|
|
||||||
if form_data.get("stream", False):
|
if form_data.get('stream', False):
|
||||||
|
|
||||||
async def stream_content():
|
async def stream_content():
|
||||||
try:
|
try:
|
||||||
@@ -304,17 +296,17 @@ async def generate_function_chat_completion(
|
|||||||
yield data
|
yield data
|
||||||
return
|
return
|
||||||
if isinstance(res, dict):
|
if isinstance(res, dict):
|
||||||
yield f"data: {json.dumps(res)}\n\n"
|
yield f'data: {json.dumps(res)}\n\n'
|
||||||
return
|
return
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error: {e}")
|
log.error(f'Error: {e}')
|
||||||
yield f"data: {json.dumps({'error': {'detail':str(e)}})}\n\n"
|
yield f'data: {json.dumps({"error": {"detail": str(e)}})}\n\n'
|
||||||
return
|
return
|
||||||
|
|
||||||
if isinstance(res, str):
|
if isinstance(res, str):
|
||||||
message = openai_chat_chunk_message_template(form_data["model"], res)
|
message = openai_chat_chunk_message_template(form_data['model'], res)
|
||||||
yield f"data: {json.dumps(message)}\n\n"
|
yield f'data: {json.dumps(message)}\n\n'
|
||||||
|
|
||||||
if isinstance(res, Iterator):
|
if isinstance(res, Iterator):
|
||||||
for line in res:
|
for line in res:
|
||||||
@@ -325,21 +317,19 @@ async def generate_function_chat_completion(
|
|||||||
yield process_line(form_data, line)
|
yield process_line(form_data, line)
|
||||||
|
|
||||||
if isinstance(res, str) or isinstance(res, Generator):
|
if isinstance(res, str) or isinstance(res, Generator):
|
||||||
finish_message = openai_chat_chunk_message_template(
|
finish_message = openai_chat_chunk_message_template(form_data['model'], '')
|
||||||
form_data["model"], ""
|
finish_message['choices'][0]['finish_reason'] = 'stop'
|
||||||
)
|
yield f'data: {json.dumps(finish_message)}\n\n'
|
||||||
finish_message["choices"][0]["finish_reason"] = "stop"
|
yield 'data: [DONE]'
|
||||||
yield f"data: {json.dumps(finish_message)}\n\n"
|
|
||||||
yield "data: [DONE]"
|
|
||||||
|
|
||||||
return StreamingResponse(stream_content(), media_type="text/event-stream")
|
return StreamingResponse(stream_content(), media_type='text/event-stream')
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
res = await execute_pipe(pipe, params)
|
res = await execute_pipe(pipe, params)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error: {e}")
|
log.error(f'Error: {e}')
|
||||||
return {"error": {"detail": str(e)}}
|
return {'error': {'detail': str(e)}}
|
||||||
|
|
||||||
if isinstance(res, StreamingResponse) or isinstance(res, dict):
|
if isinstance(res, StreamingResponse) or isinstance(res, dict):
|
||||||
return res
|
return res
|
||||||
@@ -347,4 +337,4 @@ async def generate_function_chat_completion(
|
|||||||
return res.model_dump()
|
return res.model_dump()
|
||||||
|
|
||||||
message = await get_message_content(res)
|
message = await get_message_content(res)
|
||||||
return openai_chat_completion_message_template(form_data["model"], message)
|
return openai_chat_completion_message_template(form_data['model'], message)
|
||||||
|
|||||||
@@ -56,17 +56,15 @@ def handle_peewee_migration(DATABASE_URL):
|
|||||||
# db = None
|
# db = None
|
||||||
try:
|
try:
|
||||||
# Replace the postgresql:// with postgres:// to handle the peewee migration
|
# Replace the postgresql:// with postgres:// to handle the peewee migration
|
||||||
db = register_connection(DATABASE_URL.replace("postgresql://", "postgres://"))
|
db = register_connection(DATABASE_URL.replace('postgresql://', 'postgres://'))
|
||||||
migrate_dir = OPEN_WEBUI_DIR / "internal" / "migrations"
|
migrate_dir = OPEN_WEBUI_DIR / 'internal' / 'migrations'
|
||||||
router = Router(db, logger=log, migrate_dir=migrate_dir)
|
router = Router(db, logger=log, migrate_dir=migrate_dir)
|
||||||
router.run()
|
router.run()
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Failed to initialize the database connection: {e}")
|
log.error(f'Failed to initialize the database connection: {e}')
|
||||||
log.warning(
|
log.warning('Hint: If your database password contains special characters, you may need to URL-encode it.')
|
||||||
"Hint: If your database password contains special characters, you may need to URL-encode it."
|
|
||||||
)
|
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
# Properly closing the database connection
|
# Properly closing the database connection
|
||||||
@@ -74,7 +72,7 @@ def handle_peewee_migration(DATABASE_URL):
|
|||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
# Assert if db connection has been closed
|
# Assert if db connection has been closed
|
||||||
assert db.is_closed(), "Database connection is still open."
|
assert db.is_closed(), 'Database connection is still open.'
|
||||||
|
|
||||||
|
|
||||||
if ENABLE_DB_MIGRATIONS:
|
if ENABLE_DB_MIGRATIONS:
|
||||||
@@ -84,15 +82,13 @@ if ENABLE_DB_MIGRATIONS:
|
|||||||
SQLALCHEMY_DATABASE_URL = DATABASE_URL
|
SQLALCHEMY_DATABASE_URL = DATABASE_URL
|
||||||
|
|
||||||
# Handle SQLCipher URLs
|
# Handle SQLCipher URLs
|
||||||
if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
|
if SQLALCHEMY_DATABASE_URL.startswith('sqlite+sqlcipher://'):
|
||||||
database_password = os.environ.get("DATABASE_PASSWORD")
|
database_password = os.environ.get('DATABASE_PASSWORD')
|
||||||
if not database_password or database_password.strip() == "":
|
if not database_password or database_password.strip() == '':
|
||||||
raise ValueError(
|
raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs')
|
||||||
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract database path from SQLCipher URL
|
# Extract database path from SQLCipher URL
|
||||||
db_path = SQLALCHEMY_DATABASE_URL.replace("sqlite+sqlcipher://", "")
|
db_path = SQLALCHEMY_DATABASE_URL.replace('sqlite+sqlcipher://', '')
|
||||||
|
|
||||||
# Create a custom creator function that uses sqlcipher3
|
# Create a custom creator function that uses sqlcipher3
|
||||||
def create_sqlcipher_connection():
|
def create_sqlcipher_connection():
|
||||||
@@ -109,7 +105,7 @@ if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
|
|||||||
# or QueuePool if DATABASE_POOL_SIZE is explicitly configured.
|
# or QueuePool if DATABASE_POOL_SIZE is explicitly configured.
|
||||||
if isinstance(DATABASE_POOL_SIZE, int) and DATABASE_POOL_SIZE > 0:
|
if isinstance(DATABASE_POOL_SIZE, int) and DATABASE_POOL_SIZE > 0:
|
||||||
engine = create_engine(
|
engine = create_engine(
|
||||||
"sqlite://",
|
'sqlite://',
|
||||||
creator=create_sqlcipher_connection,
|
creator=create_sqlcipher_connection,
|
||||||
pool_size=DATABASE_POOL_SIZE,
|
pool_size=DATABASE_POOL_SIZE,
|
||||||
max_overflow=DATABASE_POOL_MAX_OVERFLOW,
|
max_overflow=DATABASE_POOL_MAX_OVERFLOW,
|
||||||
@@ -121,28 +117,26 @@ if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
engine = create_engine(
|
engine = create_engine(
|
||||||
"sqlite://",
|
'sqlite://',
|
||||||
creator=create_sqlcipher_connection,
|
creator=create_sqlcipher_connection,
|
||||||
poolclass=NullPool,
|
poolclass=NullPool,
|
||||||
echo=False,
|
echo=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info("Connected to encrypted SQLite database using SQLCipher")
|
log.info('Connected to encrypted SQLite database using SQLCipher')
|
||||||
|
|
||||||
elif "sqlite" in SQLALCHEMY_DATABASE_URL:
|
elif 'sqlite' in SQLALCHEMY_DATABASE_URL:
|
||||||
engine = create_engine(
|
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={'check_same_thread': False})
|
||||||
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_connect(dbapi_connection, connection_record):
|
def on_connect(dbapi_connection, connection_record):
|
||||||
cursor = dbapi_connection.cursor()
|
cursor = dbapi_connection.cursor()
|
||||||
if DATABASE_ENABLE_SQLITE_WAL:
|
if DATABASE_ENABLE_SQLITE_WAL:
|
||||||
cursor.execute("PRAGMA journal_mode=WAL")
|
cursor.execute('PRAGMA journal_mode=WAL')
|
||||||
else:
|
else:
|
||||||
cursor.execute("PRAGMA journal_mode=DELETE")
|
cursor.execute('PRAGMA journal_mode=DELETE')
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
|
||||||
event.listen(engine, "connect", on_connect)
|
event.listen(engine, 'connect', on_connect)
|
||||||
else:
|
else:
|
||||||
if isinstance(DATABASE_POOL_SIZE, int):
|
if isinstance(DATABASE_POOL_SIZE, int):
|
||||||
if DATABASE_POOL_SIZE > 0:
|
if DATABASE_POOL_SIZE > 0:
|
||||||
@@ -156,16 +150,12 @@ else:
|
|||||||
poolclass=QueuePool,
|
poolclass=QueuePool,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
engine = create_engine(
|
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool)
|
||||||
SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
|
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
|
||||||
|
|
||||||
|
|
||||||
SessionLocal = sessionmaker(
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
|
||||||
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
|
|
||||||
)
|
|
||||||
metadata_obj = MetaData(schema=DATABASE_SCHEMA)
|
metadata_obj = MetaData(schema=DATABASE_SCHEMA)
|
||||||
Base = declarative_base(metadata=metadata_obj)
|
Base = declarative_base(metadata=metadata_obj)
|
||||||
ScopedSession = scoped_session(SessionLocal)
|
ScopedSession = scoped_session(SessionLocal)
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
active = pw.BooleanField()
|
active = pw.BooleanField()
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "auth"
|
table_name = 'auth'
|
||||||
|
|
||||||
@migrator.create_model
|
@migrator.create_model
|
||||||
class Chat(pw.Model):
|
class Chat(pw.Model):
|
||||||
@@ -67,7 +67,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
timestamp = pw.BigIntegerField()
|
timestamp = pw.BigIntegerField()
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "chat"
|
table_name = 'chat'
|
||||||
|
|
||||||
@migrator.create_model
|
@migrator.create_model
|
||||||
class ChatIdTag(pw.Model):
|
class ChatIdTag(pw.Model):
|
||||||
@@ -78,7 +78,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
timestamp = pw.BigIntegerField()
|
timestamp = pw.BigIntegerField()
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "chatidtag"
|
table_name = 'chatidtag'
|
||||||
|
|
||||||
@migrator.create_model
|
@migrator.create_model
|
||||||
class Document(pw.Model):
|
class Document(pw.Model):
|
||||||
@@ -92,7 +92,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
timestamp = pw.BigIntegerField()
|
timestamp = pw.BigIntegerField()
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "document"
|
table_name = 'document'
|
||||||
|
|
||||||
@migrator.create_model
|
@migrator.create_model
|
||||||
class Modelfile(pw.Model):
|
class Modelfile(pw.Model):
|
||||||
@@ -103,7 +103,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
timestamp = pw.BigIntegerField()
|
timestamp = pw.BigIntegerField()
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "modelfile"
|
table_name = 'modelfile'
|
||||||
|
|
||||||
@migrator.create_model
|
@migrator.create_model
|
||||||
class Prompt(pw.Model):
|
class Prompt(pw.Model):
|
||||||
@@ -115,7 +115,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
timestamp = pw.BigIntegerField()
|
timestamp = pw.BigIntegerField()
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "prompt"
|
table_name = 'prompt'
|
||||||
|
|
||||||
@migrator.create_model
|
@migrator.create_model
|
||||||
class Tag(pw.Model):
|
class Tag(pw.Model):
|
||||||
@@ -125,7 +125,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
data = pw.TextField(null=True)
|
data = pw.TextField(null=True)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "tag"
|
table_name = 'tag'
|
||||||
|
|
||||||
@migrator.create_model
|
@migrator.create_model
|
||||||
class User(pw.Model):
|
class User(pw.Model):
|
||||||
@@ -137,7 +137,7 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
timestamp = pw.BigIntegerField()
|
timestamp = pw.BigIntegerField()
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "user"
|
table_name = 'user'
|
||||||
|
|
||||||
|
|
||||||
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
@@ -149,7 +149,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
active = pw.BooleanField()
|
active = pw.BooleanField()
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "auth"
|
table_name = 'auth'
|
||||||
|
|
||||||
@migrator.create_model
|
@migrator.create_model
|
||||||
class Chat(pw.Model):
|
class Chat(pw.Model):
|
||||||
@@ -160,7 +160,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
timestamp = pw.BigIntegerField()
|
timestamp = pw.BigIntegerField()
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "chat"
|
table_name = 'chat'
|
||||||
|
|
||||||
@migrator.create_model
|
@migrator.create_model
|
||||||
class ChatIdTag(pw.Model):
|
class ChatIdTag(pw.Model):
|
||||||
@@ -171,7 +171,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
timestamp = pw.BigIntegerField()
|
timestamp = pw.BigIntegerField()
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "chatidtag"
|
table_name = 'chatidtag'
|
||||||
|
|
||||||
@migrator.create_model
|
@migrator.create_model
|
||||||
class Document(pw.Model):
|
class Document(pw.Model):
|
||||||
@@ -185,7 +185,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
timestamp = pw.BigIntegerField()
|
timestamp = pw.BigIntegerField()
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "document"
|
table_name = 'document'
|
||||||
|
|
||||||
@migrator.create_model
|
@migrator.create_model
|
||||||
class Modelfile(pw.Model):
|
class Modelfile(pw.Model):
|
||||||
@@ -196,7 +196,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
timestamp = pw.BigIntegerField()
|
timestamp = pw.BigIntegerField()
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "modelfile"
|
table_name = 'modelfile'
|
||||||
|
|
||||||
@migrator.create_model
|
@migrator.create_model
|
||||||
class Prompt(pw.Model):
|
class Prompt(pw.Model):
|
||||||
@@ -208,7 +208,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
timestamp = pw.BigIntegerField()
|
timestamp = pw.BigIntegerField()
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "prompt"
|
table_name = 'prompt'
|
||||||
|
|
||||||
@migrator.create_model
|
@migrator.create_model
|
||||||
class Tag(pw.Model):
|
class Tag(pw.Model):
|
||||||
@@ -218,7 +218,7 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
data = pw.TextField(null=True)
|
data = pw.TextField(null=True)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "tag"
|
table_name = 'tag'
|
||||||
|
|
||||||
@migrator.create_model
|
@migrator.create_model
|
||||||
class User(pw.Model):
|
class User(pw.Model):
|
||||||
@@ -230,24 +230,24 @@ def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
timestamp = pw.BigIntegerField()
|
timestamp = pw.BigIntegerField()
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "user"
|
table_name = 'user'
|
||||||
|
|
||||||
|
|
||||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
"""Write your rollback migrations here."""
|
"""Write your rollback migrations here."""
|
||||||
|
|
||||||
migrator.remove_model("user")
|
migrator.remove_model('user')
|
||||||
|
|
||||||
migrator.remove_model("tag")
|
migrator.remove_model('tag')
|
||||||
|
|
||||||
migrator.remove_model("prompt")
|
migrator.remove_model('prompt')
|
||||||
|
|
||||||
migrator.remove_model("modelfile")
|
migrator.remove_model('modelfile')
|
||||||
|
|
||||||
migrator.remove_model("document")
|
migrator.remove_model('document')
|
||||||
|
|
||||||
migrator.remove_model("chatidtag")
|
migrator.remove_model('chatidtag')
|
||||||
|
|
||||||
migrator.remove_model("chat")
|
migrator.remove_model('chat')
|
||||||
|
|
||||||
migrator.remove_model("auth")
|
migrator.remove_model('auth')
|
||||||
|
|||||||
@@ -36,12 +36,10 @@ with suppress(ImportError):
|
|||||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
"""Write your migrations here."""
|
"""Write your migrations here."""
|
||||||
|
|
||||||
migrator.add_fields(
|
migrator.add_fields('chat', share_id=pw.CharField(max_length=255, null=True, unique=True))
|
||||||
"chat", share_id=pw.CharField(max_length=255, null=True, unique=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
"""Write your rollback migrations here."""
|
"""Write your rollback migrations here."""
|
||||||
|
|
||||||
migrator.remove_fields("chat", "share_id")
|
migrator.remove_fields('chat', 'share_id')
|
||||||
|
|||||||
@@ -36,12 +36,10 @@ with suppress(ImportError):
|
|||||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
"""Write your migrations here."""
|
"""Write your migrations here."""
|
||||||
|
|
||||||
migrator.add_fields(
|
migrator.add_fields('user', api_key=pw.CharField(max_length=255, null=True, unique=True))
|
||||||
"user", api_key=pw.CharField(max_length=255, null=True, unique=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
"""Write your rollback migrations here."""
|
"""Write your rollback migrations here."""
|
||||||
|
|
||||||
migrator.remove_fields("user", "api_key")
|
migrator.remove_fields('user', 'api_key')
|
||||||
|
|||||||
@@ -36,10 +36,10 @@ with suppress(ImportError):
|
|||||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
"""Write your migrations here."""
|
"""Write your migrations here."""
|
||||||
|
|
||||||
migrator.add_fields("chat", archived=pw.BooleanField(default=False))
|
migrator.add_fields('chat', archived=pw.BooleanField(default=False))
|
||||||
|
|
||||||
|
|
||||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
"""Write your rollback migrations here."""
|
"""Write your rollback migrations here."""
|
||||||
|
|
||||||
migrator.remove_fields("chat", "archived")
|
migrator.remove_fields('chat', 'archived')
|
||||||
|
|||||||
@@ -45,22 +45,20 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
# Adding fields created_at and updated_at to the 'chat' table
|
# Adding fields created_at and updated_at to the 'chat' table
|
||||||
migrator.add_fields(
|
migrator.add_fields(
|
||||||
"chat",
|
'chat',
|
||||||
created_at=pw.DateTimeField(null=True), # Allow null for transition
|
created_at=pw.DateTimeField(null=True), # Allow null for transition
|
||||||
updated_at=pw.DateTimeField(null=True), # Allow null for transition
|
updated_at=pw.DateTimeField(null=True), # Allow null for transition
|
||||||
)
|
)
|
||||||
|
|
||||||
# Populate the new fields from an existing 'timestamp' field
|
# Populate the new fields from an existing 'timestamp' field
|
||||||
migrator.sql(
|
migrator.sql('UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL')
|
||||||
"UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Now that the data has been copied, remove the original 'timestamp' field
|
# Now that the data has been copied, remove the original 'timestamp' field
|
||||||
migrator.remove_fields("chat", "timestamp")
|
migrator.remove_fields('chat', 'timestamp')
|
||||||
|
|
||||||
# Update the fields to be not null now that they are populated
|
# Update the fields to be not null now that they are populated
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"chat",
|
'chat',
|
||||||
created_at=pw.DateTimeField(null=False),
|
created_at=pw.DateTimeField(null=False),
|
||||||
updated_at=pw.DateTimeField(null=False),
|
updated_at=pw.DateTimeField(null=False),
|
||||||
)
|
)
|
||||||
@@ -69,22 +67,20 @@ def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
# Adding fields created_at and updated_at to the 'chat' table
|
# Adding fields created_at and updated_at to the 'chat' table
|
||||||
migrator.add_fields(
|
migrator.add_fields(
|
||||||
"chat",
|
'chat',
|
||||||
created_at=pw.BigIntegerField(null=True), # Allow null for transition
|
created_at=pw.BigIntegerField(null=True), # Allow null for transition
|
||||||
updated_at=pw.BigIntegerField(null=True), # Allow null for transition
|
updated_at=pw.BigIntegerField(null=True), # Allow null for transition
|
||||||
)
|
)
|
||||||
|
|
||||||
# Populate the new fields from an existing 'timestamp' field
|
# Populate the new fields from an existing 'timestamp' field
|
||||||
migrator.sql(
|
migrator.sql('UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL')
|
||||||
"UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Now that the data has been copied, remove the original 'timestamp' field
|
# Now that the data has been copied, remove the original 'timestamp' field
|
||||||
migrator.remove_fields("chat", "timestamp")
|
migrator.remove_fields('chat', 'timestamp')
|
||||||
|
|
||||||
# Update the fields to be not null now that they are populated
|
# Update the fields to be not null now that they are populated
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"chat",
|
'chat',
|
||||||
created_at=pw.BigIntegerField(null=False),
|
created_at=pw.BigIntegerField(null=False),
|
||||||
updated_at=pw.BigIntegerField(null=False),
|
updated_at=pw.BigIntegerField(null=False),
|
||||||
)
|
)
|
||||||
@@ -101,29 +97,29 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
|
|
||||||
def rollback_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
def rollback_sqlite(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
# Recreate the timestamp field initially allowing null values for safe transition
|
# Recreate the timestamp field initially allowing null values for safe transition
|
||||||
migrator.add_fields("chat", timestamp=pw.DateTimeField(null=True))
|
migrator.add_fields('chat', timestamp=pw.DateTimeField(null=True))
|
||||||
|
|
||||||
# Copy the earliest created_at date back into the new timestamp field
|
# Copy the earliest created_at date back into the new timestamp field
|
||||||
# This assumes created_at was originally a copy of timestamp
|
# This assumes created_at was originally a copy of timestamp
|
||||||
migrator.sql("UPDATE chat SET timestamp = created_at")
|
migrator.sql('UPDATE chat SET timestamp = created_at')
|
||||||
|
|
||||||
# Remove the created_at and updated_at fields
|
# Remove the created_at and updated_at fields
|
||||||
migrator.remove_fields("chat", "created_at", "updated_at")
|
migrator.remove_fields('chat', 'created_at', 'updated_at')
|
||||||
|
|
||||||
# Finally, alter the timestamp field to not allow nulls if that was the original setting
|
# Finally, alter the timestamp field to not allow nulls if that was the original setting
|
||||||
migrator.change_fields("chat", timestamp=pw.DateTimeField(null=False))
|
migrator.change_fields('chat', timestamp=pw.DateTimeField(null=False))
|
||||||
|
|
||||||
|
|
||||||
def rollback_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
def rollback_external(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
# Recreate the timestamp field initially allowing null values for safe transition
|
# Recreate the timestamp field initially allowing null values for safe transition
|
||||||
migrator.add_fields("chat", timestamp=pw.BigIntegerField(null=True))
|
migrator.add_fields('chat', timestamp=pw.BigIntegerField(null=True))
|
||||||
|
|
||||||
# Copy the earliest created_at date back into the new timestamp field
|
# Copy the earliest created_at date back into the new timestamp field
|
||||||
# This assumes created_at was originally a copy of timestamp
|
# This assumes created_at was originally a copy of timestamp
|
||||||
migrator.sql("UPDATE chat SET timestamp = created_at")
|
migrator.sql('UPDATE chat SET timestamp = created_at')
|
||||||
|
|
||||||
# Remove the created_at and updated_at fields
|
# Remove the created_at and updated_at fields
|
||||||
migrator.remove_fields("chat", "created_at", "updated_at")
|
migrator.remove_fields('chat', 'created_at', 'updated_at')
|
||||||
|
|
||||||
# Finally, alter the timestamp field to not allow nulls if that was the original setting
|
# Finally, alter the timestamp field to not allow nulls if that was the original setting
|
||||||
migrator.change_fields("chat", timestamp=pw.BigIntegerField(null=False))
|
migrator.change_fields('chat', timestamp=pw.BigIntegerField(null=False))
|
||||||
|
|||||||
@@ -38,45 +38,45 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
|
|
||||||
# Alter the tables with timestamps
|
# Alter the tables with timestamps
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"chatidtag",
|
'chatidtag',
|
||||||
timestamp=pw.BigIntegerField(),
|
timestamp=pw.BigIntegerField(),
|
||||||
)
|
)
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"document",
|
'document',
|
||||||
timestamp=pw.BigIntegerField(),
|
timestamp=pw.BigIntegerField(),
|
||||||
)
|
)
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"modelfile",
|
'modelfile',
|
||||||
timestamp=pw.BigIntegerField(),
|
timestamp=pw.BigIntegerField(),
|
||||||
)
|
)
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"prompt",
|
'prompt',
|
||||||
timestamp=pw.BigIntegerField(),
|
timestamp=pw.BigIntegerField(),
|
||||||
)
|
)
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"user",
|
'user',
|
||||||
timestamp=pw.BigIntegerField(),
|
timestamp=pw.BigIntegerField(),
|
||||||
)
|
)
|
||||||
# Alter the tables with varchar to text where necessary
|
# Alter the tables with varchar to text where necessary
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"auth",
|
'auth',
|
||||||
password=pw.TextField(),
|
password=pw.TextField(),
|
||||||
)
|
)
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"chat",
|
'chat',
|
||||||
title=pw.TextField(),
|
title=pw.TextField(),
|
||||||
)
|
)
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"document",
|
'document',
|
||||||
title=pw.TextField(),
|
title=pw.TextField(),
|
||||||
filename=pw.TextField(),
|
filename=pw.TextField(),
|
||||||
)
|
)
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"prompt",
|
'prompt',
|
||||||
title=pw.TextField(),
|
title=pw.TextField(),
|
||||||
)
|
)
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"user",
|
'user',
|
||||||
profile_image_url=pw.TextField(),
|
profile_image_url=pw.TextField(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -87,43 +87,43 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
if isinstance(database, pw.SqliteDatabase):
|
if isinstance(database, pw.SqliteDatabase):
|
||||||
# Alter the tables with timestamps
|
# Alter the tables with timestamps
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"chatidtag",
|
'chatidtag',
|
||||||
timestamp=pw.DateField(),
|
timestamp=pw.DateField(),
|
||||||
)
|
)
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"document",
|
'document',
|
||||||
timestamp=pw.DateField(),
|
timestamp=pw.DateField(),
|
||||||
)
|
)
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"modelfile",
|
'modelfile',
|
||||||
timestamp=pw.DateField(),
|
timestamp=pw.DateField(),
|
||||||
)
|
)
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"prompt",
|
'prompt',
|
||||||
timestamp=pw.DateField(),
|
timestamp=pw.DateField(),
|
||||||
)
|
)
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"user",
|
'user',
|
||||||
timestamp=pw.DateField(),
|
timestamp=pw.DateField(),
|
||||||
)
|
)
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"auth",
|
'auth',
|
||||||
password=pw.CharField(max_length=255),
|
password=pw.CharField(max_length=255),
|
||||||
)
|
)
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"chat",
|
'chat',
|
||||||
title=pw.CharField(),
|
title=pw.CharField(),
|
||||||
)
|
)
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"document",
|
'document',
|
||||||
title=pw.CharField(),
|
title=pw.CharField(),
|
||||||
filename=pw.CharField(),
|
filename=pw.CharField(),
|
||||||
)
|
)
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"prompt",
|
'prompt',
|
||||||
title=pw.CharField(),
|
title=pw.CharField(),
|
||||||
)
|
)
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"user",
|
'user',
|
||||||
profile_image_url=pw.CharField(),
|
profile_image_url=pw.CharField(),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
|
|
||||||
# Adding fields created_at and updated_at to the 'user' table
|
# Adding fields created_at and updated_at to the 'user' table
|
||||||
migrator.add_fields(
|
migrator.add_fields(
|
||||||
"user",
|
'user',
|
||||||
created_at=pw.BigIntegerField(null=True), # Allow null for transition
|
created_at=pw.BigIntegerField(null=True), # Allow null for transition
|
||||||
updated_at=pw.BigIntegerField(null=True), # Allow null for transition
|
updated_at=pw.BigIntegerField(null=True), # Allow null for transition
|
||||||
last_active_at=pw.BigIntegerField(null=True), # Allow null for transition
|
last_active_at=pw.BigIntegerField(null=True), # Allow null for transition
|
||||||
@@ -50,11 +50,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Now that the data has been copied, remove the original 'timestamp' field
|
# Now that the data has been copied, remove the original 'timestamp' field
|
||||||
migrator.remove_fields("user", "timestamp")
|
migrator.remove_fields('user', 'timestamp')
|
||||||
|
|
||||||
# Update the fields to be not null now that they are populated
|
# Update the fields to be not null now that they are populated
|
||||||
migrator.change_fields(
|
migrator.change_fields(
|
||||||
"user",
|
'user',
|
||||||
created_at=pw.BigIntegerField(null=False),
|
created_at=pw.BigIntegerField(null=False),
|
||||||
updated_at=pw.BigIntegerField(null=False),
|
updated_at=pw.BigIntegerField(null=False),
|
||||||
last_active_at=pw.BigIntegerField(null=False),
|
last_active_at=pw.BigIntegerField(null=False),
|
||||||
@@ -65,14 +65,14 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
"""Write your rollback migrations here."""
|
"""Write your rollback migrations here."""
|
||||||
|
|
||||||
# Recreate the timestamp field initially allowing null values for safe transition
|
# Recreate the timestamp field initially allowing null values for safe transition
|
||||||
migrator.add_fields("user", timestamp=pw.BigIntegerField(null=True))
|
migrator.add_fields('user', timestamp=pw.BigIntegerField(null=True))
|
||||||
|
|
||||||
# Copy the earliest created_at date back into the new timestamp field
|
# Copy the earliest created_at date back into the new timestamp field
|
||||||
# This assumes created_at was originally a copy of timestamp
|
# This assumes created_at was originally a copy of timestamp
|
||||||
migrator.sql('UPDATE "user" SET timestamp = created_at')
|
migrator.sql('UPDATE "user" SET timestamp = created_at')
|
||||||
|
|
||||||
# Remove the created_at and updated_at fields
|
# Remove the created_at and updated_at fields
|
||||||
migrator.remove_fields("user", "created_at", "updated_at", "last_active_at")
|
migrator.remove_fields('user', 'created_at', 'updated_at', 'last_active_at')
|
||||||
|
|
||||||
# Finally, alter the timestamp field to not allow nulls if that was the original setting
|
# Finally, alter the timestamp field to not allow nulls if that was the original setting
|
||||||
migrator.change_fields("user", timestamp=pw.BigIntegerField(null=False))
|
migrator.change_fields('user', timestamp=pw.BigIntegerField(null=False))
|
||||||
|
|||||||
@@ -43,10 +43,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
created_at = pw.BigIntegerField(null=False)
|
created_at = pw.BigIntegerField(null=False)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "memory"
|
table_name = 'memory'
|
||||||
|
|
||||||
|
|
||||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
"""Write your rollback migrations here."""
|
"""Write your rollback migrations here."""
|
||||||
|
|
||||||
migrator.remove_model("memory")
|
migrator.remove_model('memory')
|
||||||
|
|||||||
@@ -51,10 +51,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
updated_at = pw.BigIntegerField(null=False)
|
updated_at = pw.BigIntegerField(null=False)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "model"
|
table_name = 'model'
|
||||||
|
|
||||||
|
|
||||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
"""Write your rollback migrations here."""
|
"""Write your rollback migrations here."""
|
||||||
|
|
||||||
migrator.remove_model("model")
|
migrator.remove_model('model')
|
||||||
|
|||||||
@@ -42,12 +42,12 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
# Fetch data from 'modelfile' table and insert into 'model' table
|
# Fetch data from 'modelfile' table and insert into 'model' table
|
||||||
migrate_modelfile_to_model(migrator, database)
|
migrate_modelfile_to_model(migrator, database)
|
||||||
# Drop the 'modelfile' table
|
# Drop the 'modelfile' table
|
||||||
migrator.remove_model("modelfile")
|
migrator.remove_model('modelfile')
|
||||||
|
|
||||||
|
|
||||||
def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database):
|
def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database):
|
||||||
ModelFile = migrator.orm["modelfile"]
|
ModelFile = migrator.orm['modelfile']
|
||||||
Model = migrator.orm["model"]
|
Model = migrator.orm['model']
|
||||||
|
|
||||||
modelfiles = ModelFile.select()
|
modelfiles = ModelFile.select()
|
||||||
|
|
||||||
@@ -57,25 +57,25 @@ def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database):
|
|||||||
modelfile.modelfile = json.loads(modelfile.modelfile)
|
modelfile.modelfile = json.loads(modelfile.modelfile)
|
||||||
meta = json.dumps(
|
meta = json.dumps(
|
||||||
{
|
{
|
||||||
"description": modelfile.modelfile.get("desc"),
|
'description': modelfile.modelfile.get('desc'),
|
||||||
"profile_image_url": modelfile.modelfile.get("imageUrl"),
|
'profile_image_url': modelfile.modelfile.get('imageUrl'),
|
||||||
"ollama": {"modelfile": modelfile.modelfile.get("content")},
|
'ollama': {'modelfile': modelfile.modelfile.get('content')},
|
||||||
"suggestion_prompts": modelfile.modelfile.get("suggestionPrompts"),
|
'suggestion_prompts': modelfile.modelfile.get('suggestionPrompts'),
|
||||||
"categories": modelfile.modelfile.get("categories"),
|
'categories': modelfile.modelfile.get('categories'),
|
||||||
"user": {**modelfile.modelfile.get("user", {}), "community": True},
|
'user': {**modelfile.modelfile.get('user', {}), 'community': True},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
info = parse_ollama_modelfile(modelfile.modelfile.get("content"))
|
info = parse_ollama_modelfile(modelfile.modelfile.get('content'))
|
||||||
|
|
||||||
# Insert the processed data into the 'model' table
|
# Insert the processed data into the 'model' table
|
||||||
Model.create(
|
Model.create(
|
||||||
id=f"ollama-{modelfile.tag_name}",
|
id=f'ollama-{modelfile.tag_name}',
|
||||||
user_id=modelfile.user_id,
|
user_id=modelfile.user_id,
|
||||||
base_model_id=info.get("base_model_id"),
|
base_model_id=info.get('base_model_id'),
|
||||||
name=modelfile.modelfile.get("title"),
|
name=modelfile.modelfile.get('title'),
|
||||||
meta=meta,
|
meta=meta,
|
||||||
params=json.dumps(info.get("params", {})),
|
params=json.dumps(info.get('params', {})),
|
||||||
created_at=modelfile.timestamp,
|
created_at=modelfile.timestamp,
|
||||||
updated_at=modelfile.timestamp,
|
updated_at=modelfile.timestamp,
|
||||||
)
|
)
|
||||||
@@ -86,7 +86,7 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
|
|
||||||
recreate_modelfile_table(migrator, database)
|
recreate_modelfile_table(migrator, database)
|
||||||
move_data_back_to_modelfile(migrator, database)
|
move_data_back_to_modelfile(migrator, database)
|
||||||
migrator.remove_model("model")
|
migrator.remove_model('model')
|
||||||
|
|
||||||
|
|
||||||
def recreate_modelfile_table(migrator: Migrator, database: pw.Database):
|
def recreate_modelfile_table(migrator: Migrator, database: pw.Database):
|
||||||
@@ -102,8 +102,8 @@ def recreate_modelfile_table(migrator: Migrator, database: pw.Database):
|
|||||||
|
|
||||||
|
|
||||||
def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database):
|
def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database):
|
||||||
Model = migrator.orm["model"]
|
Model = migrator.orm['model']
|
||||||
Modelfile = migrator.orm["modelfile"]
|
Modelfile = migrator.orm['modelfile']
|
||||||
|
|
||||||
models = Model.select()
|
models = Model.select()
|
||||||
|
|
||||||
@@ -112,13 +112,13 @@ def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database):
|
|||||||
meta = json.loads(model.meta)
|
meta = json.loads(model.meta)
|
||||||
|
|
||||||
modelfile_data = {
|
modelfile_data = {
|
||||||
"title": model.name,
|
'title': model.name,
|
||||||
"desc": meta.get("description"),
|
'desc': meta.get('description'),
|
||||||
"imageUrl": meta.get("profile_image_url"),
|
'imageUrl': meta.get('profile_image_url'),
|
||||||
"content": meta.get("ollama", {}).get("modelfile"),
|
'content': meta.get('ollama', {}).get('modelfile'),
|
||||||
"suggestionPrompts": meta.get("suggestion_prompts"),
|
'suggestionPrompts': meta.get('suggestion_prompts'),
|
||||||
"categories": meta.get("categories"),
|
'categories': meta.get('categories'),
|
||||||
"user": {k: v for k, v in meta.get("user", {}).items() if k != "community"},
|
'user': {k: v for k, v in meta.get('user', {}).items() if k != 'community'},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Insert the processed data back into the 'modelfile' table
|
# Insert the processed data back into the 'modelfile' table
|
||||||
|
|||||||
@@ -37,11 +37,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
"""Write your migrations here."""
|
"""Write your migrations here."""
|
||||||
|
|
||||||
# Adding fields settings to the 'user' table
|
# Adding fields settings to the 'user' table
|
||||||
migrator.add_fields("user", settings=pw.TextField(null=True))
|
migrator.add_fields('user', settings=pw.TextField(null=True))
|
||||||
|
|
||||||
|
|
||||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
"""Write your rollback migrations here."""
|
"""Write your rollback migrations here."""
|
||||||
|
|
||||||
# Remove the settings field
|
# Remove the settings field
|
||||||
migrator.remove_fields("user", "settings")
|
migrator.remove_fields('user', 'settings')
|
||||||
|
|||||||
@@ -51,10 +51,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
updated_at = pw.BigIntegerField(null=False)
|
updated_at = pw.BigIntegerField(null=False)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "tool"
|
table_name = 'tool'
|
||||||
|
|
||||||
|
|
||||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
"""Write your rollback migrations here."""
|
"""Write your rollback migrations here."""
|
||||||
|
|
||||||
migrator.remove_model("tool")
|
migrator.remove_model('tool')
|
||||||
|
|||||||
@@ -37,11 +37,11 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
"""Write your migrations here."""
|
"""Write your migrations here."""
|
||||||
|
|
||||||
# Adding fields info to the 'user' table
|
# Adding fields info to the 'user' table
|
||||||
migrator.add_fields("user", info=pw.TextField(null=True))
|
migrator.add_fields('user', info=pw.TextField(null=True))
|
||||||
|
|
||||||
|
|
||||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
"""Write your rollback migrations here."""
|
"""Write your rollback migrations here."""
|
||||||
|
|
||||||
# Remove the settings field
|
# Remove the settings field
|
||||||
migrator.remove_fields("user", "info")
|
migrator.remove_fields('user', 'info')
|
||||||
|
|||||||
@@ -45,10 +45,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
created_at = pw.BigIntegerField(null=False)
|
created_at = pw.BigIntegerField(null=False)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "file"
|
table_name = 'file'
|
||||||
|
|
||||||
|
|
||||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
"""Write your rollback migrations here."""
|
"""Write your rollback migrations here."""
|
||||||
|
|
||||||
migrator.remove_model("file")
|
migrator.remove_model('file')
|
||||||
|
|||||||
@@ -51,10 +51,10 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
updated_at = pw.BigIntegerField(null=False)
|
updated_at = pw.BigIntegerField(null=False)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "function"
|
table_name = 'function'
|
||||||
|
|
||||||
|
|
||||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
"""Write your rollback migrations here."""
|
"""Write your rollback migrations here."""
|
||||||
|
|
||||||
migrator.remove_model("function")
|
migrator.remove_model('function')
|
||||||
|
|||||||
@@ -36,14 +36,14 @@ with suppress(ImportError):
|
|||||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
"""Write your migrations here."""
|
"""Write your migrations here."""
|
||||||
|
|
||||||
migrator.add_fields("tool", valves=pw.TextField(null=True))
|
migrator.add_fields('tool', valves=pw.TextField(null=True))
|
||||||
migrator.add_fields("function", valves=pw.TextField(null=True))
|
migrator.add_fields('function', valves=pw.TextField(null=True))
|
||||||
migrator.add_fields("function", is_active=pw.BooleanField(default=False))
|
migrator.add_fields('function', is_active=pw.BooleanField(default=False))
|
||||||
|
|
||||||
|
|
||||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
"""Write your rollback migrations here."""
|
"""Write your rollback migrations here."""
|
||||||
|
|
||||||
migrator.remove_fields("tool", "valves")
|
migrator.remove_fields('tool', 'valves')
|
||||||
migrator.remove_fields("function", "valves")
|
migrator.remove_fields('function', 'valves')
|
||||||
migrator.remove_fields("function", "is_active")
|
migrator.remove_fields('function', 'is_active')
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
"""Write your migrations here."""
|
"""Write your migrations here."""
|
||||||
|
|
||||||
migrator.add_fields(
|
migrator.add_fields(
|
||||||
"user",
|
'user',
|
||||||
oauth_sub=pw.TextField(null=True, unique=True),
|
oauth_sub=pw.TextField(null=True, unique=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -41,4 +41,4 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
"""Write your rollback migrations here."""
|
"""Write your rollback migrations here."""
|
||||||
|
|
||||||
migrator.remove_fields("user", "oauth_sub")
|
migrator.remove_fields('user', 'oauth_sub')
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
"""Write your migrations here."""
|
"""Write your migrations here."""
|
||||||
|
|
||||||
migrator.add_fields(
|
migrator.add_fields(
|
||||||
"function",
|
'function',
|
||||||
is_global=pw.BooleanField(default=False),
|
is_global=pw.BooleanField(default=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -45,4 +45,4 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
|||||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||||
"""Write your rollback migrations here."""
|
"""Write your rollback migrations here."""
|
||||||
|
|
||||||
migrator.remove_fields("function", "is_global")
|
migrator.remove_fields('function', 'is_global')
|
||||||
|
|||||||
@@ -10,13 +10,13 @@ from playhouse.shortcuts import ReconnectMixin
|
|||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
|
db_state_default = {'closed': None, 'conn': None, 'ctx': None, 'transactions': None}
|
||||||
db_state = ContextVar("db_state", default=db_state_default.copy())
|
db_state = ContextVar('db_state', default=db_state_default.copy())
|
||||||
|
|
||||||
|
|
||||||
class PeeweeConnectionState(object):
|
class PeeweeConnectionState(object):
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__setattr__("_state", db_state)
|
super().__setattr__('_state', db_state)
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def __setattr__(self, name, value):
|
def __setattr__(self, name, value):
|
||||||
@@ -30,10 +30,10 @@ class PeeweeConnectionState(object):
|
|||||||
class CustomReconnectMixin(ReconnectMixin):
|
class CustomReconnectMixin(ReconnectMixin):
|
||||||
reconnect_errors = (
|
reconnect_errors = (
|
||||||
# psycopg2
|
# psycopg2
|
||||||
(OperationalError, "termin"),
|
(OperationalError, 'termin'),
|
||||||
(InterfaceError, "closed"),
|
(InterfaceError, 'closed'),
|
||||||
# peewee
|
# peewee
|
||||||
(PeeWeeInterfaceError, "closed"),
|
(PeeWeeInterfaceError, 'closed'),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -43,23 +43,21 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
|
|||||||
|
|
||||||
def register_connection(db_url):
|
def register_connection(db_url):
|
||||||
# Check if using SQLCipher protocol
|
# Check if using SQLCipher protocol
|
||||||
if db_url.startswith("sqlite+sqlcipher://"):
|
if db_url.startswith('sqlite+sqlcipher://'):
|
||||||
database_password = os.environ.get("DATABASE_PASSWORD")
|
database_password = os.environ.get('DATABASE_PASSWORD')
|
||||||
if not database_password or database_password.strip() == "":
|
if not database_password or database_password.strip() == '':
|
||||||
raise ValueError(
|
raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs')
|
||||||
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
|
|
||||||
)
|
|
||||||
from playhouse.sqlcipher_ext import SqlCipherDatabase
|
from playhouse.sqlcipher_ext import SqlCipherDatabase
|
||||||
|
|
||||||
# Parse the database path from SQLCipher URL
|
# Parse the database path from SQLCipher URL
|
||||||
# Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite
|
# Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite
|
||||||
db_path = db_url.replace("sqlite+sqlcipher://", "")
|
db_path = db_url.replace('sqlite+sqlcipher://', '')
|
||||||
|
|
||||||
# Use Peewee's native SqlCipherDatabase with encryption
|
# Use Peewee's native SqlCipherDatabase with encryption
|
||||||
db = SqlCipherDatabase(db_path, passphrase=database_password)
|
db = SqlCipherDatabase(db_path, passphrase=database_password)
|
||||||
db.autoconnect = True
|
db.autoconnect = True
|
||||||
db.reuse_if_open = True
|
db.reuse_if_open = True
|
||||||
log.info("Connected to encrypted SQLite database using SQLCipher")
|
log.info('Connected to encrypted SQLite database using SQLCipher')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Standard database connection (existing logic)
|
# Standard database connection (existing logic)
|
||||||
@@ -68,7 +66,7 @@ def register_connection(db_url):
|
|||||||
# Enable autoconnect for SQLite databases, managed by Peewee
|
# Enable autoconnect for SQLite databases, managed by Peewee
|
||||||
db.autoconnect = True
|
db.autoconnect = True
|
||||||
db.reuse_if_open = True
|
db.reuse_if_open = True
|
||||||
log.info("Connected to PostgreSQL database")
|
log.info('Connected to PostgreSQL database')
|
||||||
|
|
||||||
# Get the connection details
|
# Get the connection details
|
||||||
connection = parse(db_url, unquote_user=True, unquote_password=True)
|
connection = parse(db_url, unquote_user=True, unquote_password=True)
|
||||||
@@ -80,7 +78,7 @@ def register_connection(db_url):
|
|||||||
# Enable autoconnect for SQLite databases, managed by Peewee
|
# Enable autoconnect for SQLite databases, managed by Peewee
|
||||||
db.autoconnect = True
|
db.autoconnect = True
|
||||||
db.reuse_if_open = True
|
db.reuse_if_open = True
|
||||||
log.info("Connected to SQLite database")
|
log.info('Connected to SQLite database')
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported database connection")
|
raise ValueError('Unsupported database connection')
|
||||||
return db
|
return db
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -16,7 +16,7 @@ if config.config_file_name is not None:
|
|||||||
fileConfig(config.config_file_name, disable_existing_loggers=False)
|
fileConfig(config.config_file_name, disable_existing_loggers=False)
|
||||||
|
|
||||||
# Re-apply JSON formatter after fileConfig replaces handlers.
|
# Re-apply JSON formatter after fileConfig replaces handlers.
|
||||||
if LOG_FORMAT == "json":
|
if LOG_FORMAT == 'json':
|
||||||
from open_webui.env import JSONFormatter
|
from open_webui.env import JSONFormatter
|
||||||
|
|
||||||
for handler in logging.root.handlers:
|
for handler in logging.root.handlers:
|
||||||
@@ -36,7 +36,7 @@ target_metadata = Auth.metadata
|
|||||||
DB_URL = DATABASE_URL
|
DB_URL = DATABASE_URL
|
||||||
|
|
||||||
if DB_URL:
|
if DB_URL:
|
||||||
config.set_main_option("sqlalchemy.url", DB_URL.replace("%", "%%"))
|
config.set_main_option('sqlalchemy.url', DB_URL.replace('%', '%%'))
|
||||||
|
|
||||||
|
|
||||||
def run_migrations_offline() -> None:
|
def run_migrations_offline() -> None:
|
||||||
@@ -51,12 +51,12 @@ def run_migrations_offline() -> None:
|
|||||||
script output.
|
script output.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
url = config.get_main_option("sqlalchemy.url")
|
url = config.get_main_option('sqlalchemy.url')
|
||||||
context.configure(
|
context.configure(
|
||||||
url=url,
|
url=url,
|
||||||
target_metadata=target_metadata,
|
target_metadata=target_metadata,
|
||||||
literal_binds=True,
|
literal_binds=True,
|
||||||
dialect_opts={"paramstyle": "named"},
|
dialect_opts={'paramstyle': 'named'},
|
||||||
)
|
)
|
||||||
|
|
||||||
with context.begin_transaction():
|
with context.begin_transaction():
|
||||||
@@ -71,15 +71,13 @@ def run_migrations_online() -> None:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
# Handle SQLCipher URLs
|
# Handle SQLCipher URLs
|
||||||
if DB_URL and DB_URL.startswith("sqlite+sqlcipher://"):
|
if DB_URL and DB_URL.startswith('sqlite+sqlcipher://'):
|
||||||
if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == "":
|
if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == '':
|
||||||
raise ValueError(
|
raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs')
|
||||||
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract database path from SQLCipher URL
|
# Extract database path from SQLCipher URL
|
||||||
db_path = DB_URL.replace("sqlite+sqlcipher://", "")
|
db_path = DB_URL.replace('sqlite+sqlcipher://', '')
|
||||||
if db_path.startswith("/"):
|
if db_path.startswith('/'):
|
||||||
db_path = db_path[1:] # Remove leading slash for relative paths
|
db_path = db_path[1:] # Remove leading slash for relative paths
|
||||||
|
|
||||||
# Create a custom creator function that uses sqlcipher3
|
# Create a custom creator function that uses sqlcipher3
|
||||||
@@ -91,7 +89,7 @@ def run_migrations_online() -> None:
|
|||||||
return conn
|
return conn
|
||||||
|
|
||||||
connectable = create_engine(
|
connectable = create_engine(
|
||||||
"sqlite://", # Dummy URL since we're using creator
|
'sqlite://', # Dummy URL since we're using creator
|
||||||
creator=create_sqlcipher_connection,
|
creator=create_sqlcipher_connection,
|
||||||
echo=False,
|
echo=False,
|
||||||
)
|
)
|
||||||
@@ -99,7 +97,7 @@ def run_migrations_online() -> None:
|
|||||||
# Standard database connection (existing logic)
|
# Standard database connection (existing logic)
|
||||||
connectable = engine_from_config(
|
connectable = engine_from_config(
|
||||||
config.get_section(config.config_ini_section, {}),
|
config.get_section(config.config_ini_section, {}),
|
||||||
prefix="sqlalchemy.",
|
prefix='sqlalchemy.',
|
||||||
poolclass=pool.NullPool,
|
poolclass=pool.NullPool,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -12,4 +12,4 @@ def get_existing_tables():
|
|||||||
def get_revision_id():
|
def get_revision_id():
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
return str(uuid.uuid4()).replace("-", "")[:12]
|
return str(uuid.uuid4()).replace('-', '')[:12]
|
||||||
|
|||||||
@@ -9,38 +9,38 @@ Create Date: 2025-08-13 03:00:00.000000
|
|||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
revision = "018012973d35"
|
revision = '018012973d35'
|
||||||
down_revision = "d31026856c01"
|
down_revision = 'd31026856c01'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
# Chat table indexes
|
# Chat table indexes
|
||||||
op.create_index("folder_id_idx", "chat", ["folder_id"])
|
op.create_index('folder_id_idx', 'chat', ['folder_id'])
|
||||||
op.create_index("user_id_pinned_idx", "chat", ["user_id", "pinned"])
|
op.create_index('user_id_pinned_idx', 'chat', ['user_id', 'pinned'])
|
||||||
op.create_index("user_id_archived_idx", "chat", ["user_id", "archived"])
|
op.create_index('user_id_archived_idx', 'chat', ['user_id', 'archived'])
|
||||||
op.create_index("updated_at_user_id_idx", "chat", ["updated_at", "user_id"])
|
op.create_index('updated_at_user_id_idx', 'chat', ['updated_at', 'user_id'])
|
||||||
op.create_index("folder_id_user_id_idx", "chat", ["folder_id", "user_id"])
|
op.create_index('folder_id_user_id_idx', 'chat', ['folder_id', 'user_id'])
|
||||||
|
|
||||||
# Tag table index
|
# Tag table index
|
||||||
op.create_index("user_id_idx", "tag", ["user_id"])
|
op.create_index('user_id_idx', 'tag', ['user_id'])
|
||||||
|
|
||||||
# Function table index
|
# Function table index
|
||||||
op.create_index("is_global_idx", "function", ["is_global"])
|
op.create_index('is_global_idx', 'function', ['is_global'])
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
# Chat table indexes
|
# Chat table indexes
|
||||||
op.drop_index("folder_id_idx", table_name="chat")
|
op.drop_index('folder_id_idx', table_name='chat')
|
||||||
op.drop_index("user_id_pinned_idx", table_name="chat")
|
op.drop_index('user_id_pinned_idx', table_name='chat')
|
||||||
op.drop_index("user_id_archived_idx", table_name="chat")
|
op.drop_index('user_id_archived_idx', table_name='chat')
|
||||||
op.drop_index("updated_at_user_id_idx", table_name="chat")
|
op.drop_index('updated_at_user_id_idx', table_name='chat')
|
||||||
op.drop_index("folder_id_user_id_idx", table_name="chat")
|
op.drop_index('folder_id_user_id_idx', table_name='chat')
|
||||||
|
|
||||||
# Tag table index
|
# Tag table index
|
||||||
op.drop_index("user_id_idx", table_name="tag")
|
op.drop_index('user_id_idx', table_name='tag')
|
||||||
|
|
||||||
# Function table index
|
# Function table index
|
||||||
|
|
||||||
op.drop_index("is_global_idx", table_name="function")
|
op.drop_index('is_global_idx', table_name='function')
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ from sqlalchemy.engine.reflection import Inspector
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
revision = "1af9b942657b"
|
revision = '1af9b942657b'
|
||||||
down_revision = "242a2047eae0"
|
down_revision = '242a2047eae0'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
@@ -25,43 +25,40 @@ def upgrade():
|
|||||||
inspector = Inspector.from_engine(conn)
|
inspector = Inspector.from_engine(conn)
|
||||||
|
|
||||||
# Clean up potential leftover temp table from previous failures
|
# Clean up potential leftover temp table from previous failures
|
||||||
conn.execute(sa.text("DROP TABLE IF EXISTS _alembic_tmp_tag"))
|
conn.execute(sa.text('DROP TABLE IF EXISTS _alembic_tmp_tag'))
|
||||||
|
|
||||||
# Check if the 'tag' table exists
|
# Check if the 'tag' table exists
|
||||||
tables = inspector.get_table_names()
|
tables = inspector.get_table_names()
|
||||||
|
|
||||||
# Step 1: Modify Tag table using batch mode for SQLite support
|
# Step 1: Modify Tag table using batch mode for SQLite support
|
||||||
if "tag" in tables:
|
if 'tag' in tables:
|
||||||
# Get the current columns in the 'tag' table
|
# Get the current columns in the 'tag' table
|
||||||
columns = [col["name"] for col in inspector.get_columns("tag")]
|
columns = [col['name'] for col in inspector.get_columns('tag')]
|
||||||
|
|
||||||
# Get any existing unique constraints on the 'tag' table
|
# Get any existing unique constraints on the 'tag' table
|
||||||
current_constraints = inspector.get_unique_constraints("tag")
|
current_constraints = inspector.get_unique_constraints('tag')
|
||||||
|
|
||||||
with op.batch_alter_table("tag", schema=None) as batch_op:
|
with op.batch_alter_table('tag', schema=None) as batch_op:
|
||||||
# Check if the unique constraint already exists
|
# Check if the unique constraint already exists
|
||||||
if not any(
|
if not any(constraint['name'] == 'uq_id_user_id' for constraint in current_constraints):
|
||||||
constraint["name"] == "uq_id_user_id"
|
|
||||||
for constraint in current_constraints
|
|
||||||
):
|
|
||||||
# Create unique constraint if it doesn't exist
|
# Create unique constraint if it doesn't exist
|
||||||
batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"])
|
batch_op.create_unique_constraint('uq_id_user_id', ['id', 'user_id'])
|
||||||
|
|
||||||
# Check if the 'data' column exists before trying to drop it
|
# Check if the 'data' column exists before trying to drop it
|
||||||
if "data" in columns:
|
if 'data' in columns:
|
||||||
batch_op.drop_column("data")
|
batch_op.drop_column('data')
|
||||||
|
|
||||||
# Check if the 'meta' column needs to be created
|
# Check if the 'meta' column needs to be created
|
||||||
if "meta" not in columns:
|
if 'meta' not in columns:
|
||||||
# Add the 'meta' column if it doesn't already exist
|
# Add the 'meta' column if it doesn't already exist
|
||||||
batch_op.add_column(sa.Column("meta", sa.JSON(), nullable=True))
|
batch_op.add_column(sa.Column('meta', sa.JSON(), nullable=True))
|
||||||
|
|
||||||
tag = table(
|
tag = table(
|
||||||
"tag",
|
'tag',
|
||||||
column("id", sa.String()),
|
column('id', sa.String()),
|
||||||
column("name", sa.String()),
|
column('name', sa.String()),
|
||||||
column("user_id", sa.String()),
|
column('user_id', sa.String()),
|
||||||
column("meta", sa.JSON()),
|
column('meta', sa.JSON()),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 2: Migrate tags
|
# Step 2: Migrate tags
|
||||||
@@ -70,12 +67,12 @@ def upgrade():
|
|||||||
|
|
||||||
tag_updates = {}
|
tag_updates = {}
|
||||||
for row in result:
|
for row in result:
|
||||||
new_id = row.name.replace(" ", "_").lower()
|
new_id = row.name.replace(' ', '_').lower()
|
||||||
tag_updates[row.id] = new_id
|
tag_updates[row.id] = new_id
|
||||||
|
|
||||||
for tag_id, new_tag_id in tag_updates.items():
|
for tag_id, new_tag_id in tag_updates.items():
|
||||||
print(f"Updating tag {tag_id} to {new_tag_id}")
|
print(f'Updating tag {tag_id} to {new_tag_id}')
|
||||||
if new_tag_id == "pinned":
|
if new_tag_id == 'pinned':
|
||||||
# delete tag
|
# delete tag
|
||||||
delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
|
delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
|
||||||
conn.execute(delete_stmt)
|
conn.execute(delete_stmt)
|
||||||
@@ -86,9 +83,7 @@ def upgrade():
|
|||||||
|
|
||||||
if existing_tag_result:
|
if existing_tag_result:
|
||||||
# Handle duplicate case: the new_tag_id already exists
|
# Handle duplicate case: the new_tag_id already exists
|
||||||
print(
|
print(f'Tag {new_tag_id} already exists. Removing current tag with ID {tag_id} to avoid duplicates.')
|
||||||
f"Tag {new_tag_id} already exists. Removing current tag with ID {tag_id} to avoid duplicates."
|
|
||||||
)
|
|
||||||
# Option 1: Delete the current tag if an update to new_tag_id would cause duplication
|
# Option 1: Delete the current tag if an update to new_tag_id would cause duplication
|
||||||
delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
|
delete_stmt = sa.delete(tag).where(tag.c.id == tag_id)
|
||||||
conn.execute(delete_stmt)
|
conn.execute(delete_stmt)
|
||||||
@@ -98,19 +93,15 @@ def upgrade():
|
|||||||
conn.execute(update_stmt)
|
conn.execute(update_stmt)
|
||||||
|
|
||||||
# Add columns `pinned` and `meta` to 'chat'
|
# Add columns `pinned` and `meta` to 'chat'
|
||||||
op.add_column("chat", sa.Column("pinned", sa.Boolean(), nullable=True))
|
op.add_column('chat', sa.Column('pinned', sa.Boolean(), nullable=True))
|
||||||
op.add_column(
|
op.add_column('chat', sa.Column('meta', sa.JSON(), nullable=False, server_default='{}'))
|
||||||
"chat", sa.Column("meta", sa.JSON(), nullable=False, server_default="{}")
|
|
||||||
)
|
|
||||||
|
|
||||||
chatidtag = table(
|
chatidtag = table('chatidtag', column('chat_id', sa.String()), column('tag_name', sa.String()))
|
||||||
"chatidtag", column("chat_id", sa.String()), column("tag_name", sa.String())
|
|
||||||
)
|
|
||||||
chat = table(
|
chat = table(
|
||||||
"chat",
|
'chat',
|
||||||
column("id", sa.String()),
|
column('id', sa.String()),
|
||||||
column("pinned", sa.Boolean()),
|
column('pinned', sa.Boolean()),
|
||||||
column("meta", sa.JSON()),
|
column('meta', sa.JSON()),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fetch existing tags
|
# Fetch existing tags
|
||||||
@@ -120,29 +111,27 @@ def upgrade():
|
|||||||
chat_updates = {}
|
chat_updates = {}
|
||||||
for row in result:
|
for row in result:
|
||||||
chat_id = row.chat_id
|
chat_id = row.chat_id
|
||||||
tag_name = row.tag_name.replace(" ", "_").lower()
|
tag_name = row.tag_name.replace(' ', '_').lower()
|
||||||
|
|
||||||
if tag_name == "pinned":
|
if tag_name == 'pinned':
|
||||||
# Specifically handle 'pinned' tag
|
# Specifically handle 'pinned' tag
|
||||||
if chat_id not in chat_updates:
|
if chat_id not in chat_updates:
|
||||||
chat_updates[chat_id] = {"pinned": True, "meta": {}}
|
chat_updates[chat_id] = {'pinned': True, 'meta': {}}
|
||||||
else:
|
else:
|
||||||
chat_updates[chat_id]["pinned"] = True
|
chat_updates[chat_id]['pinned'] = True
|
||||||
else:
|
else:
|
||||||
if chat_id not in chat_updates:
|
if chat_id not in chat_updates:
|
||||||
chat_updates[chat_id] = {"pinned": False, "meta": {"tags": [tag_name]}}
|
chat_updates[chat_id] = {'pinned': False, 'meta': {'tags': [tag_name]}}
|
||||||
else:
|
else:
|
||||||
tags = chat_updates[chat_id]["meta"].get("tags", [])
|
tags = chat_updates[chat_id]['meta'].get('tags', [])
|
||||||
tags.append(tag_name)
|
tags.append(tag_name)
|
||||||
|
|
||||||
chat_updates[chat_id]["meta"]["tags"] = list(set(tags))
|
chat_updates[chat_id]['meta']['tags'] = list(set(tags))
|
||||||
|
|
||||||
# Update chats based on accumulated changes
|
# Update chats based on accumulated changes
|
||||||
for chat_id, updates in chat_updates.items():
|
for chat_id, updates in chat_updates.items():
|
||||||
update_stmt = sa.update(chat).where(chat.c.id == chat_id)
|
update_stmt = sa.update(chat).where(chat.c.id == chat_id)
|
||||||
update_stmt = update_stmt.values(
|
update_stmt = update_stmt.values(meta=updates.get('meta', {}), pinned=updates.get('pinned', False))
|
||||||
meta=updates.get("meta", {}), pinned=updates.get("pinned", False)
|
|
||||||
)
|
|
||||||
conn.execute(update_stmt)
|
conn.execute(update_stmt)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ from sqlalchemy.sql import table, select, update
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
revision = "242a2047eae0"
|
revision = '242a2047eae0'
|
||||||
down_revision = "6a39f3d8e55c"
|
down_revision = '6a39f3d8e55c'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
@@ -22,39 +22,37 @@ def upgrade():
|
|||||||
conn = op.get_bind()
|
conn = op.get_bind()
|
||||||
inspector = sa.inspect(conn)
|
inspector = sa.inspect(conn)
|
||||||
|
|
||||||
columns = inspector.get_columns("chat")
|
columns = inspector.get_columns('chat')
|
||||||
column_dict = {col["name"]: col for col in columns}
|
column_dict = {col['name']: col for col in columns}
|
||||||
|
|
||||||
chat_column = column_dict.get("chat")
|
chat_column = column_dict.get('chat')
|
||||||
old_chat_exists = "old_chat" in column_dict
|
old_chat_exists = 'old_chat' in column_dict
|
||||||
|
|
||||||
if chat_column:
|
if chat_column:
|
||||||
if isinstance(chat_column["type"], sa.Text):
|
if isinstance(chat_column['type'], sa.Text):
|
||||||
print("Converting 'chat' column to JSON")
|
print("Converting 'chat' column to JSON")
|
||||||
|
|
||||||
if old_chat_exists:
|
if old_chat_exists:
|
||||||
print("Dropping old 'old_chat' column")
|
print("Dropping old 'old_chat' column")
|
||||||
op.drop_column("chat", "old_chat")
|
op.drop_column('chat', 'old_chat')
|
||||||
|
|
||||||
# Step 1: Rename current 'chat' column to 'old_chat'
|
# Step 1: Rename current 'chat' column to 'old_chat'
|
||||||
print("Renaming 'chat' column to 'old_chat'")
|
print("Renaming 'chat' column to 'old_chat'")
|
||||||
op.alter_column(
|
op.alter_column('chat', 'chat', new_column_name='old_chat', existing_type=sa.Text())
|
||||||
"chat", "chat", new_column_name="old_chat", existing_type=sa.Text()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 2: Add new 'chat' column of type JSON
|
# Step 2: Add new 'chat' column of type JSON
|
||||||
print("Adding new 'chat' column of type JSON")
|
print("Adding new 'chat' column of type JSON")
|
||||||
op.add_column("chat", sa.Column("chat", sa.JSON(), nullable=True))
|
op.add_column('chat', sa.Column('chat', sa.JSON(), nullable=True))
|
||||||
else:
|
else:
|
||||||
# If the column is already JSON, no need to do anything
|
# If the column is already JSON, no need to do anything
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Step 3: Migrate data from 'old_chat' to 'chat'
|
# Step 3: Migrate data from 'old_chat' to 'chat'
|
||||||
chat_table = table(
|
chat_table = table(
|
||||||
"chat",
|
'chat',
|
||||||
sa.Column("id", sa.String(), primary_key=True),
|
sa.Column('id', sa.String(), primary_key=True),
|
||||||
sa.Column("old_chat", sa.Text()),
|
sa.Column('old_chat', sa.Text()),
|
||||||
sa.Column("chat", sa.JSON()),
|
sa.Column('chat', sa.JSON()),
|
||||||
)
|
)
|
||||||
|
|
||||||
# - Selecting all data from the table
|
# - Selecting all data from the table
|
||||||
@@ -67,41 +65,33 @@ def upgrade():
|
|||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
json_data = None # Handle cases where the text cannot be converted to JSON
|
json_data = None # Handle cases where the text cannot be converted to JSON
|
||||||
|
|
||||||
connection.execute(
|
connection.execute(sa.update(chat_table).where(chat_table.c.id == row.id).values(chat=json_data))
|
||||||
sa.update(chat_table)
|
|
||||||
.where(chat_table.c.id == row.id)
|
|
||||||
.values(chat=json_data)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 4: Drop 'old_chat' column
|
# Step 4: Drop 'old_chat' column
|
||||||
print("Dropping 'old_chat' column")
|
print("Dropping 'old_chat' column")
|
||||||
op.drop_column("chat", "old_chat")
|
op.drop_column('chat', 'old_chat')
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
# Step 1: Add 'old_chat' column back as Text
|
# Step 1: Add 'old_chat' column back as Text
|
||||||
op.add_column("chat", sa.Column("old_chat", sa.Text(), nullable=True))
|
op.add_column('chat', sa.Column('old_chat', sa.Text(), nullable=True))
|
||||||
|
|
||||||
# Step 2: Convert 'chat' JSON data back to text and store in 'old_chat'
|
# Step 2: Convert 'chat' JSON data back to text and store in 'old_chat'
|
||||||
chat_table = table(
|
chat_table = table(
|
||||||
"chat",
|
'chat',
|
||||||
sa.Column("id", sa.String(), primary_key=True),
|
sa.Column('id', sa.String(), primary_key=True),
|
||||||
sa.Column("chat", sa.JSON()),
|
sa.Column('chat', sa.JSON()),
|
||||||
sa.Column("old_chat", sa.Text()),
|
sa.Column('old_chat', sa.Text()),
|
||||||
)
|
)
|
||||||
|
|
||||||
connection = op.get_bind()
|
connection = op.get_bind()
|
||||||
results = connection.execute(select(chat_table.c.id, chat_table.c.chat))
|
results = connection.execute(select(chat_table.c.id, chat_table.c.chat))
|
||||||
for row in results:
|
for row in results:
|
||||||
text_data = json.dumps(row.chat) if row.chat is not None else None
|
text_data = json.dumps(row.chat) if row.chat is not None else None
|
||||||
connection.execute(
|
connection.execute(sa.update(chat_table).where(chat_table.c.id == row.id).values(old_chat=text_data))
|
||||||
sa.update(chat_table)
|
|
||||||
.where(chat_table.c.id == row.id)
|
|
||||||
.values(old_chat=text_data)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 3: Remove the new 'chat' JSON column
|
# Step 3: Remove the new 'chat' JSON column
|
||||||
op.drop_column("chat", "chat")
|
op.drop_column('chat', 'chat')
|
||||||
|
|
||||||
# Step 4: Rename 'old_chat' back to 'chat'
|
# Step 4: Rename 'old_chat' back to 'chat'
|
||||||
op.alter_column("chat", "old_chat", new_column_name="chat", existing_type=sa.Text())
|
op.alter_column('chat', 'old_chat', new_column_name='chat', existing_type=sa.Text())
|
||||||
|
|||||||
@@ -13,19 +13,19 @@ import sqlalchemy as sa
|
|||||||
import open_webui.internal.db
|
import open_webui.internal.db
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "2f1211949ecc"
|
revision: str = '2f1211949ecc'
|
||||||
down_revision: Union[str, None] = "37f288994c47"
|
down_revision: Union[str, None] = '37f288994c47'
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# New columns to be added to channel_member table
|
# New columns to be added to channel_member table
|
||||||
op.add_column("channel_member", sa.Column("status", sa.Text(), nullable=True))
|
op.add_column('channel_member', sa.Column('status', sa.Text(), nullable=True))
|
||||||
op.add_column(
|
op.add_column(
|
||||||
"channel_member",
|
'channel_member',
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"is_active",
|
'is_active',
|
||||||
sa.Boolean(),
|
sa.Boolean(),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
default=True,
|
default=True,
|
||||||
@@ -34,9 +34,9 @@ def upgrade() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
op.add_column(
|
op.add_column(
|
||||||
"channel_member",
|
'channel_member',
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"is_channel_muted",
|
'is_channel_muted',
|
||||||
sa.Boolean(),
|
sa.Boolean(),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
default=False,
|
default=False,
|
||||||
@@ -44,9 +44,9 @@ def upgrade() -> None:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
op.add_column(
|
op.add_column(
|
||||||
"channel_member",
|
'channel_member',
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"is_channel_pinned",
|
'is_channel_pinned',
|
||||||
sa.Boolean(),
|
sa.Boolean(),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
default=False,
|
default=False,
|
||||||
@@ -54,49 +54,41 @@ def upgrade() -> None:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
op.add_column("channel_member", sa.Column("data", sa.JSON(), nullable=True))
|
op.add_column('channel_member', sa.Column('data', sa.JSON(), nullable=True))
|
||||||
op.add_column("channel_member", sa.Column("meta", sa.JSON(), nullable=True))
|
op.add_column('channel_member', sa.Column('meta', sa.JSON(), nullable=True))
|
||||||
|
|
||||||
op.add_column(
|
op.add_column('channel_member', sa.Column('joined_at', sa.BigInteger(), nullable=False))
|
||||||
"channel_member", sa.Column("joined_at", sa.BigInteger(), nullable=False)
|
op.add_column('channel_member', sa.Column('left_at', sa.BigInteger(), nullable=True))
|
||||||
)
|
|
||||||
op.add_column(
|
|
||||||
"channel_member", sa.Column("left_at", sa.BigInteger(), nullable=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
op.add_column(
|
op.add_column('channel_member', sa.Column('last_read_at', sa.BigInteger(), nullable=True))
|
||||||
"channel_member", sa.Column("last_read_at", sa.BigInteger(), nullable=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
op.add_column(
|
op.add_column('channel_member', sa.Column('updated_at', sa.BigInteger(), nullable=True))
|
||||||
"channel_member", sa.Column("updated_at", sa.BigInteger(), nullable=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
# New columns to be added to message table
|
# New columns to be added to message table
|
||||||
op.add_column(
|
op.add_column(
|
||||||
"message",
|
'message',
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"is_pinned",
|
'is_pinned',
|
||||||
sa.Boolean(),
|
sa.Boolean(),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
default=False,
|
default=False,
|
||||||
server_default=sa.sql.expression.false(),
|
server_default=sa.sql.expression.false(),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
op.add_column("message", sa.Column("pinned_at", sa.BigInteger(), nullable=True))
|
op.add_column('message', sa.Column('pinned_at', sa.BigInteger(), nullable=True))
|
||||||
op.add_column("message", sa.Column("pinned_by", sa.Text(), nullable=True))
|
op.add_column('message', sa.Column('pinned_by', sa.Text(), nullable=True))
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
op.drop_column("channel_member", "updated_at")
|
op.drop_column('channel_member', 'updated_at')
|
||||||
op.drop_column("channel_member", "last_read_at")
|
op.drop_column('channel_member', 'last_read_at')
|
||||||
|
|
||||||
op.drop_column("channel_member", "meta")
|
op.drop_column('channel_member', 'meta')
|
||||||
op.drop_column("channel_member", "data")
|
op.drop_column('channel_member', 'data')
|
||||||
|
|
||||||
op.drop_column("channel_member", "is_channel_pinned")
|
op.drop_column('channel_member', 'is_channel_pinned')
|
||||||
op.drop_column("channel_member", "is_channel_muted")
|
op.drop_column('channel_member', 'is_channel_muted')
|
||||||
|
|
||||||
op.drop_column("message", "pinned_by")
|
op.drop_column('message', 'pinned_by')
|
||||||
op.drop_column("message", "pinned_at")
|
op.drop_column('message', 'pinned_at')
|
||||||
op.drop_column("message", "is_pinned")
|
op.drop_column('message', 'is_pinned')
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ import uuid
|
|||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
revision: str = "374d2f66af06"
|
revision: str = '374d2f66af06'
|
||||||
down_revision: Union[str, None] = "c440947495f3"
|
down_revision: Union[str, None] = 'c440947495f3'
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
@@ -26,13 +26,13 @@ def upgrade() -> None:
|
|||||||
# We need to assume the OLD structure.
|
# We need to assume the OLD structure.
|
||||||
|
|
||||||
old_prompt_table = sa.table(
|
old_prompt_table = sa.table(
|
||||||
"prompt",
|
'prompt',
|
||||||
sa.column("command", sa.Text()),
|
sa.column('command', sa.Text()),
|
||||||
sa.column("user_id", sa.Text()),
|
sa.column('user_id', sa.Text()),
|
||||||
sa.column("title", sa.Text()),
|
sa.column('title', sa.Text()),
|
||||||
sa.column("content", sa.Text()),
|
sa.column('content', sa.Text()),
|
||||||
sa.column("timestamp", sa.BigInteger()),
|
sa.column('timestamp', sa.BigInteger()),
|
||||||
sa.column("access_control", sa.JSON()),
|
sa.column('access_control', sa.JSON()),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if table exists/read data
|
# Check if table exists/read data
|
||||||
@@ -53,61 +53,61 @@ def upgrade() -> None:
|
|||||||
|
|
||||||
# Step 2: Create new prompt table with 'id' as PRIMARY KEY
|
# Step 2: Create new prompt table with 'id' as PRIMARY KEY
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"prompt_new",
|
'prompt_new',
|
||||||
sa.Column("id", sa.Text(), primary_key=True),
|
sa.Column('id', sa.Text(), primary_key=True),
|
||||||
sa.Column("command", sa.String(), unique=True, index=True),
|
sa.Column('command', sa.String(), unique=True, index=True),
|
||||||
sa.Column("user_id", sa.String(), nullable=False),
|
sa.Column('user_id', sa.String(), nullable=False),
|
||||||
sa.Column("name", sa.Text(), nullable=False),
|
sa.Column('name', sa.Text(), nullable=False),
|
||||||
sa.Column("content", sa.Text(), nullable=False),
|
sa.Column('content', sa.Text(), nullable=False),
|
||||||
sa.Column("data", sa.JSON(), nullable=True),
|
sa.Column('data', sa.JSON(), nullable=True),
|
||||||
sa.Column("meta", sa.JSON(), nullable=True),
|
sa.Column('meta', sa.JSON(), nullable=True),
|
||||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
sa.Column('access_control', sa.JSON(), nullable=True),
|
||||||
sa.Column("is_active", sa.Boolean(), nullable=False, server_default="1"),
|
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='1'),
|
||||||
sa.Column("version_id", sa.Text(), nullable=True),
|
sa.Column('version_id', sa.Text(), nullable=True),
|
||||||
sa.Column("tags", sa.JSON(), nullable=True),
|
sa.Column('tags', sa.JSON(), nullable=True),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
sa.Column('updated_at', sa.BigInteger(), nullable=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 3: Create prompt_history table
|
# Step 3: Create prompt_history table
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"prompt_history",
|
'prompt_history',
|
||||||
sa.Column("id", sa.Text(), primary_key=True),
|
sa.Column('id', sa.Text(), primary_key=True),
|
||||||
sa.Column("prompt_id", sa.Text(), nullable=False, index=True),
|
sa.Column('prompt_id', sa.Text(), nullable=False, index=True),
|
||||||
sa.Column("parent_id", sa.Text(), nullable=True),
|
sa.Column('parent_id', sa.Text(), nullable=True),
|
||||||
sa.Column("snapshot", sa.JSON(), nullable=False),
|
sa.Column('snapshot', sa.JSON(), nullable=False),
|
||||||
sa.Column("user_id", sa.Text(), nullable=False),
|
sa.Column('user_id', sa.Text(), nullable=False),
|
||||||
sa.Column("commit_message", sa.Text(), nullable=True),
|
sa.Column('commit_message', sa.Text(), nullable=True),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 4: Migrate data
|
# Step 4: Migrate data
|
||||||
prompt_new_table = sa.table(
|
prompt_new_table = sa.table(
|
||||||
"prompt_new",
|
'prompt_new',
|
||||||
sa.column("id", sa.Text()),
|
sa.column('id', sa.Text()),
|
||||||
sa.column("command", sa.String()),
|
sa.column('command', sa.String()),
|
||||||
sa.column("user_id", sa.String()),
|
sa.column('user_id', sa.String()),
|
||||||
sa.column("name", sa.Text()),
|
sa.column('name', sa.Text()),
|
||||||
sa.column("content", sa.Text()),
|
sa.column('content', sa.Text()),
|
||||||
sa.column("data", sa.JSON()),
|
sa.column('data', sa.JSON()),
|
||||||
sa.column("meta", sa.JSON()),
|
sa.column('meta', sa.JSON()),
|
||||||
sa.column("access_control", sa.JSON()),
|
sa.column('access_control', sa.JSON()),
|
||||||
sa.column("is_active", sa.Boolean()),
|
sa.column('is_active', sa.Boolean()),
|
||||||
sa.column("version_id", sa.Text()),
|
sa.column('version_id', sa.Text()),
|
||||||
sa.column("tags", sa.JSON()),
|
sa.column('tags', sa.JSON()),
|
||||||
sa.column("created_at", sa.BigInteger()),
|
sa.column('created_at', sa.BigInteger()),
|
||||||
sa.column("updated_at", sa.BigInteger()),
|
sa.column('updated_at', sa.BigInteger()),
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_history_table = sa.table(
|
prompt_history_table = sa.table(
|
||||||
"prompt_history",
|
'prompt_history',
|
||||||
sa.column("id", sa.Text()),
|
sa.column('id', sa.Text()),
|
||||||
sa.column("prompt_id", sa.Text()),
|
sa.column('prompt_id', sa.Text()),
|
||||||
sa.column("parent_id", sa.Text()),
|
sa.column('parent_id', sa.Text()),
|
||||||
sa.column("snapshot", sa.JSON()),
|
sa.column('snapshot', sa.JSON()),
|
||||||
sa.column("user_id", sa.Text()),
|
sa.column('user_id', sa.Text()),
|
||||||
sa.column("commit_message", sa.Text()),
|
sa.column('commit_message', sa.Text()),
|
||||||
sa.column("created_at", sa.BigInteger()),
|
sa.column('created_at', sa.BigInteger()),
|
||||||
)
|
)
|
||||||
|
|
||||||
for row in existing_prompts:
|
for row in existing_prompts:
|
||||||
@@ -120,7 +120,7 @@ def upgrade() -> None:
|
|||||||
|
|
||||||
new_uuid = str(uuid.uuid4())
|
new_uuid = str(uuid.uuid4())
|
||||||
history_uuid = str(uuid.uuid4())
|
history_uuid = str(uuid.uuid4())
|
||||||
clean_command = command[1:] if command and command.startswith("/") else command
|
clean_command = command[1:] if command and command.startswith('/') else command
|
||||||
|
|
||||||
# Insert into prompt_new
|
# Insert into prompt_new
|
||||||
conn.execute(
|
conn.execute(
|
||||||
@@ -148,12 +148,12 @@ def upgrade() -> None:
|
|||||||
prompt_id=new_uuid,
|
prompt_id=new_uuid,
|
||||||
parent_id=None,
|
parent_id=None,
|
||||||
snapshot={
|
snapshot={
|
||||||
"name": title,
|
'name': title,
|
||||||
"content": content,
|
'content': content,
|
||||||
"command": clean_command,
|
'command': clean_command,
|
||||||
"data": {},
|
'data': {},
|
||||||
"meta": {},
|
'meta': {},
|
||||||
"access_control": access_control,
|
'access_control': access_control,
|
||||||
},
|
},
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
commit_message=None,
|
commit_message=None,
|
||||||
@@ -162,8 +162,8 @@ def upgrade() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Step 5: Replace old table with new one
|
# Step 5: Replace old table with new one
|
||||||
op.drop_table("prompt")
|
op.drop_table('prompt')
|
||||||
op.rename_table("prompt_new", "prompt")
|
op.rename_table('prompt_new', 'prompt')
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
@@ -171,13 +171,13 @@ def downgrade() -> None:
|
|||||||
|
|
||||||
# Step 1: Read new data
|
# Step 1: Read new data
|
||||||
prompt_table = sa.table(
|
prompt_table = sa.table(
|
||||||
"prompt",
|
'prompt',
|
||||||
sa.column("command", sa.String()),
|
sa.column('command', sa.String()),
|
||||||
sa.column("name", sa.Text()),
|
sa.column('name', sa.Text()),
|
||||||
sa.column("created_at", sa.BigInteger()),
|
sa.column('created_at', sa.BigInteger()),
|
||||||
sa.column("user_id", sa.Text()),
|
sa.column('user_id', sa.Text()),
|
||||||
sa.column("content", sa.Text()),
|
sa.column('content', sa.Text()),
|
||||||
sa.column("access_control", sa.JSON()),
|
sa.column('access_control', sa.JSON()),
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -195,31 +195,31 @@ def downgrade() -> None:
|
|||||||
current_data = []
|
current_data = []
|
||||||
|
|
||||||
# Step 2: Drop history and table
|
# Step 2: Drop history and table
|
||||||
op.drop_table("prompt_history")
|
op.drop_table('prompt_history')
|
||||||
op.drop_table("prompt")
|
op.drop_table('prompt')
|
||||||
|
|
||||||
# Step 3: Recreate old table (command as PK?)
|
# Step 3: Recreate old table (command as PK?)
|
||||||
# Assuming old schema:
|
# Assuming old schema:
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"prompt",
|
'prompt',
|
||||||
sa.Column("command", sa.String(), primary_key=True),
|
sa.Column('command', sa.String(), primary_key=True),
|
||||||
sa.Column("user_id", sa.String()),
|
sa.Column('user_id', sa.String()),
|
||||||
sa.Column("title", sa.Text()),
|
sa.Column('title', sa.Text()),
|
||||||
sa.Column("content", sa.Text()),
|
sa.Column('content', sa.Text()),
|
||||||
sa.Column("timestamp", sa.BigInteger()),
|
sa.Column('timestamp', sa.BigInteger()),
|
||||||
sa.Column("access_control", sa.JSON()),
|
sa.Column('access_control', sa.JSON()),
|
||||||
sa.Column("id", sa.Integer(), nullable=True),
|
sa.Column('id', sa.Integer(), nullable=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 4: Restore data
|
# Step 4: Restore data
|
||||||
old_prompt_table = sa.table(
|
old_prompt_table = sa.table(
|
||||||
"prompt",
|
'prompt',
|
||||||
sa.column("command", sa.String()),
|
sa.column('command', sa.String()),
|
||||||
sa.column("user_id", sa.String()),
|
sa.column('user_id', sa.String()),
|
||||||
sa.column("title", sa.Text()),
|
sa.column('title', sa.Text()),
|
||||||
sa.column("content", sa.Text()),
|
sa.column('content', sa.Text()),
|
||||||
sa.column("timestamp", sa.BigInteger()),
|
sa.column('timestamp', sa.BigInteger()),
|
||||||
sa.column("access_control", sa.JSON()),
|
sa.column('access_control', sa.JSON()),
|
||||||
)
|
)
|
||||||
|
|
||||||
for row in current_data:
|
for row in current_data:
|
||||||
@@ -231,9 +231,7 @@ def downgrade() -> None:
|
|||||||
access_control = row[5]
|
access_control = row[5]
|
||||||
|
|
||||||
# Restore leading /
|
# Restore leading /
|
||||||
old_command = (
|
old_command = '/' + command if command and not command.startswith('/') else command
|
||||||
"/" + command if command and not command.startswith("/") else command
|
|
||||||
)
|
|
||||||
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
sa.insert(old_prompt_table).values(
|
sa.insert(old_prompt_table).values(
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ Create Date: 2024-12-30 03:00:00.000000
|
|||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
revision = "3781e22d8b01"
|
revision = '3781e22d8b01'
|
||||||
down_revision = "7826ab40b532"
|
down_revision = '7826ab40b532'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
@@ -18,9 +18,9 @@ depends_on = None
|
|||||||
def upgrade():
|
def upgrade():
|
||||||
# Add 'type' column to the 'channel' table
|
# Add 'type' column to the 'channel' table
|
||||||
op.add_column(
|
op.add_column(
|
||||||
"channel",
|
'channel',
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"type",
|
'type',
|
||||||
sa.Text(),
|
sa.Text(),
|
||||||
nullable=True,
|
nullable=True,
|
||||||
),
|
),
|
||||||
@@ -28,43 +28,31 @@ def upgrade():
|
|||||||
|
|
||||||
# Add 'parent_id' column to the 'message' table for threads
|
# Add 'parent_id' column to the 'message' table for threads
|
||||||
op.add_column(
|
op.add_column(
|
||||||
"message",
|
'message',
|
||||||
sa.Column("parent_id", sa.Text(), nullable=True),
|
sa.Column('parent_id', sa.Text(), nullable=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"message_reaction",
|
'message_reaction',
|
||||||
sa.Column(
|
sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True), # Unique reaction ID
|
||||||
"id", sa.Text(), nullable=False, primary_key=True, unique=True
|
sa.Column('user_id', sa.Text(), nullable=False), # User who reacted
|
||||||
), # Unique reaction ID
|
sa.Column('message_id', sa.Text(), nullable=False), # Message that was reacted to
|
||||||
sa.Column("user_id", sa.Text(), nullable=False), # User who reacted
|
sa.Column('name', sa.Text(), nullable=False), # Reaction name (e.g. "thumbs_up")
|
||||||
sa.Column(
|
sa.Column('created_at', sa.BigInteger(), nullable=True), # Timestamp of when the reaction was added
|
||||||
"message_id", sa.Text(), nullable=False
|
|
||||||
), # Message that was reacted to
|
|
||||||
sa.Column(
|
|
||||||
"name", sa.Text(), nullable=False
|
|
||||||
), # Reaction name (e.g. "thumbs_up")
|
|
||||||
sa.Column(
|
|
||||||
"created_at", sa.BigInteger(), nullable=True
|
|
||||||
), # Timestamp of when the reaction was added
|
|
||||||
)
|
)
|
||||||
|
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"channel_member",
|
'channel_member',
|
||||||
sa.Column(
|
sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True), # Record ID for the membership row
|
||||||
"id", sa.Text(), nullable=False, primary_key=True, unique=True
|
sa.Column('channel_id', sa.Text(), nullable=False), # Associated channel
|
||||||
), # Record ID for the membership row
|
sa.Column('user_id', sa.Text(), nullable=False), # Associated user
|
||||||
sa.Column("channel_id", sa.Text(), nullable=False), # Associated channel
|
sa.Column('created_at', sa.BigInteger(), nullable=True), # Timestamp of when the user joined the channel
|
||||||
sa.Column("user_id", sa.Text(), nullable=False), # Associated user
|
|
||||||
sa.Column(
|
|
||||||
"created_at", sa.BigInteger(), nullable=True
|
|
||||||
), # Timestamp of when the user joined the channel
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
# Revert 'type' column addition to the 'channel' table
|
# Revert 'type' column addition to the 'channel' table
|
||||||
op.drop_column("channel", "type")
|
op.drop_column('channel', 'type')
|
||||||
op.drop_column("message", "parent_id")
|
op.drop_column('message', 'parent_id')
|
||||||
op.drop_table("message_reaction")
|
op.drop_table('message_reaction')
|
||||||
op.drop_table("channel_member")
|
op.drop_table('channel_member')
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ from alembic import op
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "37f288994c47"
|
revision: str = '37f288994c47'
|
||||||
down_revision: Union[str, None] = "a5c220713937"
|
down_revision: Union[str, None] = 'a5c220713937'
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
@@ -24,50 +24,48 @@ depends_on: Union[str, Sequence[str], None] = None
|
|||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# 1. Create new table
|
# 1. Create new table
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"group_member",
|
'group_member',
|
||||||
sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False),
|
sa.Column('id', sa.Text(), primary_key=True, unique=True, nullable=False),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"group_id",
|
'group_id',
|
||||||
sa.Text(),
|
sa.Text(),
|
||||||
sa.ForeignKey("group.id", ondelete="CASCADE"),
|
sa.ForeignKey('group.id', ondelete='CASCADE'),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
),
|
),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"user_id",
|
'user_id',
|
||||||
sa.Text(),
|
sa.Text(),
|
||||||
sa.ForeignKey("user.id", ondelete="CASCADE"),
|
sa.ForeignKey('user.id', ondelete='CASCADE'),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
),
|
),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||||
sa.UniqueConstraint("group_id", "user_id", name="uq_group_member_group_user"),
|
sa.UniqueConstraint('group_id', 'user_id', name='uq_group_member_group_user'),
|
||||||
)
|
)
|
||||||
|
|
||||||
connection = op.get_bind()
|
connection = op.get_bind()
|
||||||
|
|
||||||
# 2. Read existing group with user_ids JSON column
|
# 2. Read existing group with user_ids JSON column
|
||||||
group_table = sa.Table(
|
group_table = sa.Table(
|
||||||
"group",
|
'group',
|
||||||
sa.MetaData(),
|
sa.MetaData(),
|
||||||
sa.Column("id", sa.Text()),
|
sa.Column('id', sa.Text()),
|
||||||
sa.Column("user_ids", sa.JSON()), # JSON stored as text in SQLite + PG
|
sa.Column('user_ids', sa.JSON()), # JSON stored as text in SQLite + PG
|
||||||
)
|
)
|
||||||
|
|
||||||
results = connection.execute(
|
results = connection.execute(sa.select(group_table.c.id, group_table.c.user_ids)).fetchall()
|
||||||
sa.select(group_table.c.id, group_table.c.user_ids)
|
|
||||||
).fetchall()
|
|
||||||
|
|
||||||
print(results)
|
print(results)
|
||||||
|
|
||||||
# 3. Insert members into group_member table
|
# 3. Insert members into group_member table
|
||||||
gm_table = sa.Table(
|
gm_table = sa.Table(
|
||||||
"group_member",
|
'group_member',
|
||||||
sa.MetaData(),
|
sa.MetaData(),
|
||||||
sa.Column("id", sa.Text()),
|
sa.Column('id', sa.Text()),
|
||||||
sa.Column("group_id", sa.Text()),
|
sa.Column('group_id', sa.Text()),
|
||||||
sa.Column("user_id", sa.Text()),
|
sa.Column('user_id', sa.Text()),
|
||||||
sa.Column("created_at", sa.BigInteger()),
|
sa.Column('created_at', sa.BigInteger()),
|
||||||
sa.Column("updated_at", sa.BigInteger()),
|
sa.Column('updated_at', sa.BigInteger()),
|
||||||
)
|
)
|
||||||
|
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
@@ -86,11 +84,11 @@ def upgrade() -> None:
|
|||||||
|
|
||||||
rows = [
|
rows = [
|
||||||
{
|
{
|
||||||
"id": str(uuid.uuid4()),
|
'id': str(uuid.uuid4()),
|
||||||
"group_id": group_id,
|
'group_id': group_id,
|
||||||
"user_id": uid,
|
'user_id': uid,
|
||||||
"created_at": now,
|
'created_at': now,
|
||||||
"updated_at": now,
|
'updated_at': now,
|
||||||
}
|
}
|
||||||
for uid in user_ids
|
for uid in user_ids
|
||||||
]
|
]
|
||||||
@@ -99,47 +97,41 @@ def upgrade() -> None:
|
|||||||
connection.execute(gm_table.insert(), rows)
|
connection.execute(gm_table.insert(), rows)
|
||||||
|
|
||||||
# 4. Optionally drop the old column
|
# 4. Optionally drop the old column
|
||||||
with op.batch_alter_table("group") as batch:
|
with op.batch_alter_table('group') as batch:
|
||||||
batch.drop_column("user_ids")
|
batch.drop_column('user_ids')
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
# Reverse: restore user_ids column
|
# Reverse: restore user_ids column
|
||||||
with op.batch_alter_table("group") as batch:
|
with op.batch_alter_table('group') as batch:
|
||||||
batch.add_column(sa.Column("user_ids", sa.JSON()))
|
batch.add_column(sa.Column('user_ids', sa.JSON()))
|
||||||
|
|
||||||
connection = op.get_bind()
|
connection = op.get_bind()
|
||||||
gm_table = sa.Table(
|
gm_table = sa.Table(
|
||||||
"group_member",
|
'group_member',
|
||||||
sa.MetaData(),
|
sa.MetaData(),
|
||||||
sa.Column("group_id", sa.Text()),
|
sa.Column('group_id', sa.Text()),
|
||||||
sa.Column("user_id", sa.Text()),
|
sa.Column('user_id', sa.Text()),
|
||||||
sa.Column("created_at", sa.BigInteger()),
|
sa.Column('created_at', sa.BigInteger()),
|
||||||
sa.Column("updated_at", sa.BigInteger()),
|
sa.Column('updated_at', sa.BigInteger()),
|
||||||
)
|
)
|
||||||
|
|
||||||
group_table = sa.Table(
|
group_table = sa.Table(
|
||||||
"group",
|
'group',
|
||||||
sa.MetaData(),
|
sa.MetaData(),
|
||||||
sa.Column("id", sa.Text()),
|
sa.Column('id', sa.Text()),
|
||||||
sa.Column("user_ids", sa.JSON()),
|
sa.Column('user_ids', sa.JSON()),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build JSON arrays again
|
# Build JSON arrays again
|
||||||
results = connection.execute(sa.select(group_table.c.id)).fetchall()
|
results = connection.execute(sa.select(group_table.c.id)).fetchall()
|
||||||
|
|
||||||
for (group_id,) in results:
|
for (group_id,) in results:
|
||||||
members = connection.execute(
|
members = connection.execute(sa.select(gm_table.c.user_id).where(gm_table.c.group_id == group_id)).fetchall()
|
||||||
sa.select(gm_table.c.user_id).where(gm_table.c.group_id == group_id)
|
|
||||||
).fetchall()
|
|
||||||
|
|
||||||
member_ids = [m[0] for m in members]
|
member_ids = [m[0] for m in members]
|
||||||
|
|
||||||
connection.execute(
|
connection.execute(group_table.update().where(group_table.c.id == group_id).values(user_ids=member_ids))
|
||||||
group_table.update()
|
|
||||||
.where(group_table.c.id == group_id)
|
|
||||||
.values(user_ids=member_ids)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Drop the new table
|
# Drop the new table
|
||||||
op.drop_table("group_member")
|
op.drop_table('group_member')
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ from alembic import op
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "38d63c18f30f"
|
revision: str = '38d63c18f30f'
|
||||||
down_revision: Union[str, None] = "3af16a1c9fb6"
|
down_revision: Union[str, None] = '3af16a1c9fb6'
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
@@ -21,59 +21,55 @@ depends_on: Union[str, Sequence[str], None] = None
|
|||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# Ensure 'id' column in 'user' table is unique and primary key (ForeignKey constraint)
|
# Ensure 'id' column in 'user' table is unique and primary key (ForeignKey constraint)
|
||||||
inspector = sa.inspect(op.get_bind())
|
inspector = sa.inspect(op.get_bind())
|
||||||
columns = inspector.get_columns("user")
|
columns = inspector.get_columns('user')
|
||||||
|
|
||||||
pk_columns = inspector.get_pk_constraint("user")["constrained_columns"]
|
pk_columns = inspector.get_pk_constraint('user')['constrained_columns']
|
||||||
id_column = next((col for col in columns if col["name"] == "id"), None)
|
id_column = next((col for col in columns if col['name'] == 'id'), None)
|
||||||
|
|
||||||
if id_column and not id_column.get("unique", False):
|
if id_column and not id_column.get('unique', False):
|
||||||
unique_constraints = inspector.get_unique_constraints("user")
|
unique_constraints = inspector.get_unique_constraints('user')
|
||||||
unique_columns = {tuple(u["column_names"]) for u in unique_constraints}
|
unique_columns = {tuple(u['column_names']) for u in unique_constraints}
|
||||||
|
|
||||||
with op.batch_alter_table("user") as batch_op:
|
with op.batch_alter_table('user') as batch_op:
|
||||||
# If primary key is wrong, drop it
|
# If primary key is wrong, drop it
|
||||||
if pk_columns and pk_columns != ["id"]:
|
if pk_columns and pk_columns != ['id']:
|
||||||
batch_op.drop_constraint(
|
batch_op.drop_constraint(inspector.get_pk_constraint('user')['name'], type_='primary')
|
||||||
inspector.get_pk_constraint("user")["name"], type_="primary"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add unique constraint if missing
|
# Add unique constraint if missing
|
||||||
if ("id",) not in unique_columns:
|
if ('id',) not in unique_columns:
|
||||||
batch_op.create_unique_constraint("uq_user_id", ["id"])
|
batch_op.create_unique_constraint('uq_user_id', ['id'])
|
||||||
|
|
||||||
# Re-create correct primary key
|
# Re-create correct primary key
|
||||||
batch_op.create_primary_key("pk_user_id", ["id"])
|
batch_op.create_primary_key('pk_user_id', ['id'])
|
||||||
|
|
||||||
# Create oauth_session table
|
# Create oauth_session table
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"oauth_session",
|
'oauth_session',
|
||||||
sa.Column("id", sa.Text(), primary_key=True, nullable=False, unique=True),
|
sa.Column('id', sa.Text(), primary_key=True, nullable=False, unique=True),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"user_id",
|
'user_id',
|
||||||
sa.Text(),
|
sa.Text(),
|
||||||
sa.ForeignKey("user.id", ondelete="CASCADE"),
|
sa.ForeignKey('user.id', ondelete='CASCADE'),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
),
|
),
|
||||||
sa.Column("provider", sa.Text(), nullable=False),
|
sa.Column('provider', sa.Text(), nullable=False),
|
||||||
sa.Column("token", sa.Text(), nullable=False),
|
sa.Column('token', sa.Text(), nullable=False),
|
||||||
sa.Column("expires_at", sa.BigInteger(), nullable=False),
|
sa.Column('expires_at', sa.BigInteger(), nullable=False),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
sa.Column('updated_at', sa.BigInteger(), nullable=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create indexes for better performance
|
# Create indexes for better performance
|
||||||
op.create_index("idx_oauth_session_user_id", "oauth_session", ["user_id"])
|
op.create_index('idx_oauth_session_user_id', 'oauth_session', ['user_id'])
|
||||||
op.create_index("idx_oauth_session_expires_at", "oauth_session", ["expires_at"])
|
op.create_index('idx_oauth_session_expires_at', 'oauth_session', ['expires_at'])
|
||||||
op.create_index(
|
op.create_index('idx_oauth_session_user_provider', 'oauth_session', ['user_id', 'provider'])
|
||||||
"idx_oauth_session_user_provider", "oauth_session", ["user_id", "provider"]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# Drop indexes first
|
# Drop indexes first
|
||||||
op.drop_index("idx_oauth_session_user_provider", table_name="oauth_session")
|
op.drop_index('idx_oauth_session_user_provider', table_name='oauth_session')
|
||||||
op.drop_index("idx_oauth_session_expires_at", table_name="oauth_session")
|
op.drop_index('idx_oauth_session_expires_at', table_name='oauth_session')
|
||||||
op.drop_index("idx_oauth_session_user_id", table_name="oauth_session")
|
op.drop_index('idx_oauth_session_user_id', table_name='oauth_session')
|
||||||
|
|
||||||
# Drop the table
|
# Drop the table
|
||||||
op.drop_table("oauth_session")
|
op.drop_table('oauth_session')
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ from sqlalchemy.engine.reflection import Inspector
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
revision = "3ab32c4b8f59"
|
revision = '3ab32c4b8f59'
|
||||||
down_revision = "1af9b942657b"
|
down_revision = '1af9b942657b'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
@@ -24,58 +24,55 @@ def upgrade():
|
|||||||
inspector = Inspector.from_engine(conn)
|
inspector = Inspector.from_engine(conn)
|
||||||
|
|
||||||
# Inspecting the 'tag' table constraints and structure
|
# Inspecting the 'tag' table constraints and structure
|
||||||
existing_pk = inspector.get_pk_constraint("tag")
|
existing_pk = inspector.get_pk_constraint('tag')
|
||||||
unique_constraints = inspector.get_unique_constraints("tag")
|
unique_constraints = inspector.get_unique_constraints('tag')
|
||||||
existing_indexes = inspector.get_indexes("tag")
|
existing_indexes = inspector.get_indexes('tag')
|
||||||
|
|
||||||
print(f"Primary Key: {existing_pk}")
|
print(f'Primary Key: {existing_pk}')
|
||||||
print(f"Unique Constraints: {unique_constraints}")
|
print(f'Unique Constraints: {unique_constraints}')
|
||||||
print(f"Indexes: {existing_indexes}")
|
print(f'Indexes: {existing_indexes}')
|
||||||
|
|
||||||
with op.batch_alter_table("tag", schema=None) as batch_op:
|
with op.batch_alter_table('tag', schema=None) as batch_op:
|
||||||
# Drop existing primary key constraint if it exists
|
# Drop existing primary key constraint if it exists
|
||||||
if existing_pk and existing_pk.get("constrained_columns"):
|
if existing_pk and existing_pk.get('constrained_columns'):
|
||||||
pk_name = existing_pk.get("name")
|
pk_name = existing_pk.get('name')
|
||||||
if pk_name:
|
if pk_name:
|
||||||
print(f"Dropping primary key constraint: {pk_name}")
|
print(f'Dropping primary key constraint: {pk_name}')
|
||||||
batch_op.drop_constraint(pk_name, type_="primary")
|
batch_op.drop_constraint(pk_name, type_='primary')
|
||||||
|
|
||||||
# Now create the new primary key with the combination of 'id' and 'user_id'
|
# Now create the new primary key with the combination of 'id' and 'user_id'
|
||||||
print("Creating new primary key with 'id' and 'user_id'.")
|
print("Creating new primary key with 'id' and 'user_id'.")
|
||||||
batch_op.create_primary_key("pk_id_user_id", ["id", "user_id"])
|
batch_op.create_primary_key('pk_id_user_id', ['id', 'user_id'])
|
||||||
|
|
||||||
# Drop unique constraints that could conflict with the new primary key
|
# Drop unique constraints that could conflict with the new primary key
|
||||||
for constraint in unique_constraints:
|
for constraint in unique_constraints:
|
||||||
if (
|
if (
|
||||||
constraint["name"] == "uq_id_user_id"
|
constraint['name'] == 'uq_id_user_id'
|
||||||
): # Adjust this name according to what is actually returned by the inspector
|
): # Adjust this name according to what is actually returned by the inspector
|
||||||
print(f"Dropping unique constraint: {constraint['name']}")
|
print(f'Dropping unique constraint: {constraint["name"]}')
|
||||||
batch_op.drop_constraint(constraint["name"], type_="unique")
|
batch_op.drop_constraint(constraint['name'], type_='unique')
|
||||||
|
|
||||||
for index in existing_indexes:
|
for index in existing_indexes:
|
||||||
if index["unique"]:
|
if index['unique']:
|
||||||
if not any(
|
if not any(constraint['name'] == index['name'] for constraint in unique_constraints):
|
||||||
constraint["name"] == index["name"]
|
|
||||||
for constraint in unique_constraints
|
|
||||||
):
|
|
||||||
# You are attempting to drop unique indexes
|
# You are attempting to drop unique indexes
|
||||||
print(f"Dropping unique index: {index['name']}")
|
print(f'Dropping unique index: {index["name"]}')
|
||||||
batch_op.drop_index(index["name"])
|
batch_op.drop_index(index['name'])
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
conn = op.get_bind()
|
conn = op.get_bind()
|
||||||
inspector = Inspector.from_engine(conn)
|
inspector = Inspector.from_engine(conn)
|
||||||
|
|
||||||
current_pk = inspector.get_pk_constraint("tag")
|
current_pk = inspector.get_pk_constraint('tag')
|
||||||
|
|
||||||
with op.batch_alter_table("tag", schema=None) as batch_op:
|
with op.batch_alter_table('tag', schema=None) as batch_op:
|
||||||
# Drop the current primary key first, if it matches the one we know we added in upgrade
|
# Drop the current primary key first, if it matches the one we know we added in upgrade
|
||||||
if current_pk and "pk_id_user_id" == current_pk.get("name"):
|
if current_pk and 'pk_id_user_id' == current_pk.get('name'):
|
||||||
batch_op.drop_constraint("pk_id_user_id", type_="primary")
|
batch_op.drop_constraint('pk_id_user_id', type_='primary')
|
||||||
|
|
||||||
# Restore the original primary key
|
# Restore the original primary key
|
||||||
batch_op.create_primary_key("pk_id", ["id"])
|
batch_op.create_primary_key('pk_id', ['id'])
|
||||||
|
|
||||||
# Since primary key on just 'id' is restored, we now add back any unique constraints if necessary
|
# Since primary key on just 'id' is restored, we now add back any unique constraints if necessary
|
||||||
batch_op.create_unique_constraint("uq_id_user_id", ["id", "user_id"])
|
batch_op.create_unique_constraint('uq_id_user_id', ['id', 'user_id'])
|
||||||
|
|||||||
@@ -12,21 +12,21 @@ from alembic import op
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "3af16a1c9fb6"
|
revision: str = '3af16a1c9fb6'
|
||||||
down_revision: Union[str, None] = "018012973d35"
|
down_revision: Union[str, None] = '018012973d35'
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
op.add_column("user", sa.Column("username", sa.String(length=50), nullable=True))
|
op.add_column('user', sa.Column('username', sa.String(length=50), nullable=True))
|
||||||
op.add_column("user", sa.Column("bio", sa.Text(), nullable=True))
|
op.add_column('user', sa.Column('bio', sa.Text(), nullable=True))
|
||||||
op.add_column("user", sa.Column("gender", sa.Text(), nullable=True))
|
op.add_column('user', sa.Column('gender', sa.Text(), nullable=True))
|
||||||
op.add_column("user", sa.Column("date_of_birth", sa.Date(), nullable=True))
|
op.add_column('user', sa.Column('date_of_birth', sa.Date(), nullable=True))
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
op.drop_column("user", "username")
|
op.drop_column('user', 'username')
|
||||||
op.drop_column("user", "bio")
|
op.drop_column('user', 'bio')
|
||||||
op.drop_column("user", "gender")
|
op.drop_column('user', 'gender')
|
||||||
op.drop_column("user", "date_of_birth")
|
op.drop_column('user', 'date_of_birth')
|
||||||
|
|||||||
@@ -18,38 +18,38 @@ import json
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "3e0e00844bb0"
|
revision: str = '3e0e00844bb0'
|
||||||
down_revision: Union[str, None] = "90ef40d4714e"
|
down_revision: Union[str, None] = '90ef40d4714e'
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"knowledge_file",
|
'knowledge_file',
|
||||||
sa.Column("id", sa.Text(), primary_key=True),
|
sa.Column('id', sa.Text(), primary_key=True),
|
||||||
sa.Column("user_id", sa.Text(), nullable=False),
|
sa.Column('user_id', sa.Text(), nullable=False),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"knowledge_id",
|
'knowledge_id',
|
||||||
sa.Text(),
|
sa.Text(),
|
||||||
sa.ForeignKey("knowledge.id", ondelete="CASCADE"),
|
sa.ForeignKey('knowledge.id', ondelete='CASCADE'),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
),
|
),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"file_id",
|
'file_id',
|
||||||
sa.Text(),
|
sa.Text(),
|
||||||
sa.ForeignKey("file.id", ondelete="CASCADE"),
|
sa.ForeignKey('file.id', ondelete='CASCADE'),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
),
|
),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
sa.Column('updated_at', sa.BigInteger(), nullable=False),
|
||||||
# indexes
|
# indexes
|
||||||
sa.Index("ix_knowledge_file_knowledge_id", "knowledge_id"),
|
sa.Index('ix_knowledge_file_knowledge_id', 'knowledge_id'),
|
||||||
sa.Index("ix_knowledge_file_file_id", "file_id"),
|
sa.Index('ix_knowledge_file_file_id', 'file_id'),
|
||||||
sa.Index("ix_knowledge_file_user_id", "user_id"),
|
sa.Index('ix_knowledge_file_user_id', 'user_id'),
|
||||||
# unique constraints
|
# unique constraints
|
||||||
sa.UniqueConstraint(
|
sa.UniqueConstraint(
|
||||||
"knowledge_id", "file_id", name="uq_knowledge_file_knowledge_file"
|
'knowledge_id', 'file_id', name='uq_knowledge_file_knowledge_file'
|
||||||
), # prevent duplicate entries
|
), # prevent duplicate entries
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -57,35 +57,33 @@ def upgrade() -> None:
|
|||||||
|
|
||||||
# 2. Read existing group with user_ids JSON column
|
# 2. Read existing group with user_ids JSON column
|
||||||
knowledge_table = sa.Table(
|
knowledge_table = sa.Table(
|
||||||
"knowledge",
|
'knowledge',
|
||||||
sa.MetaData(),
|
sa.MetaData(),
|
||||||
sa.Column("id", sa.Text()),
|
sa.Column('id', sa.Text()),
|
||||||
sa.Column("user_id", sa.Text()),
|
sa.Column('user_id', sa.Text()),
|
||||||
sa.Column("data", sa.JSON()), # JSON stored as text in SQLite + PG
|
sa.Column('data', sa.JSON()), # JSON stored as text in SQLite + PG
|
||||||
)
|
)
|
||||||
|
|
||||||
results = connection.execute(
|
results = connection.execute(
|
||||||
sa.select(
|
sa.select(knowledge_table.c.id, knowledge_table.c.user_id, knowledge_table.c.data)
|
||||||
knowledge_table.c.id, knowledge_table.c.user_id, knowledge_table.c.data
|
|
||||||
)
|
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
# 3. Insert members into group_member table
|
# 3. Insert members into group_member table
|
||||||
kf_table = sa.Table(
|
kf_table = sa.Table(
|
||||||
"knowledge_file",
|
'knowledge_file',
|
||||||
sa.MetaData(),
|
sa.MetaData(),
|
||||||
sa.Column("id", sa.Text()),
|
sa.Column('id', sa.Text()),
|
||||||
sa.Column("user_id", sa.Text()),
|
sa.Column('user_id', sa.Text()),
|
||||||
sa.Column("knowledge_id", sa.Text()),
|
sa.Column('knowledge_id', sa.Text()),
|
||||||
sa.Column("file_id", sa.Text()),
|
sa.Column('file_id', sa.Text()),
|
||||||
sa.Column("created_at", sa.BigInteger()),
|
sa.Column('created_at', sa.BigInteger()),
|
||||||
sa.Column("updated_at", sa.BigInteger()),
|
sa.Column('updated_at', sa.BigInteger()),
|
||||||
)
|
)
|
||||||
|
|
||||||
file_table = sa.Table(
|
file_table = sa.Table(
|
||||||
"file",
|
'file',
|
||||||
sa.MetaData(),
|
sa.MetaData(),
|
||||||
sa.Column("id", sa.Text()),
|
sa.Column('id', sa.Text()),
|
||||||
)
|
)
|
||||||
|
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
@@ -102,50 +100,48 @@ def upgrade() -> None:
|
|||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
file_ids = data.get("file_ids", [])
|
file_ids = data.get('file_ids', [])
|
||||||
|
|
||||||
for file_id in file_ids:
|
for file_id in file_ids:
|
||||||
file_exists = connection.execute(
|
file_exists = connection.execute(sa.select(file_table.c.id).where(file_table.c.id == file_id)).fetchone()
|
||||||
sa.select(file_table.c.id).where(file_table.c.id == file_id)
|
|
||||||
).fetchone()
|
|
||||||
|
|
||||||
if not file_exists:
|
if not file_exists:
|
||||||
continue # skip non-existing files
|
continue # skip non-existing files
|
||||||
|
|
||||||
row = {
|
row = {
|
||||||
"id": str(uuid.uuid4()),
|
'id': str(uuid.uuid4()),
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"knowledge_id": knowledge_id,
|
'knowledge_id': knowledge_id,
|
||||||
"file_id": file_id,
|
'file_id': file_id,
|
||||||
"created_at": now,
|
'created_at': now,
|
||||||
"updated_at": now,
|
'updated_at': now,
|
||||||
}
|
}
|
||||||
connection.execute(kf_table.insert().values(**row))
|
connection.execute(kf_table.insert().values(**row))
|
||||||
|
|
||||||
with op.batch_alter_table("knowledge") as batch:
|
with op.batch_alter_table('knowledge') as batch:
|
||||||
batch.drop_column("data")
|
batch.drop_column('data')
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# 1. Add back the old data column
|
# 1. Add back the old data column
|
||||||
op.add_column("knowledge", sa.Column("data", sa.JSON(), nullable=True))
|
op.add_column('knowledge', sa.Column('data', sa.JSON(), nullable=True))
|
||||||
|
|
||||||
connection = op.get_bind()
|
connection = op.get_bind()
|
||||||
|
|
||||||
# 2. Read knowledge_file entries and reconstruct data JSON
|
# 2. Read knowledge_file entries and reconstruct data JSON
|
||||||
knowledge_table = sa.Table(
|
knowledge_table = sa.Table(
|
||||||
"knowledge",
|
'knowledge',
|
||||||
sa.MetaData(),
|
sa.MetaData(),
|
||||||
sa.Column("id", sa.Text()),
|
sa.Column('id', sa.Text()),
|
||||||
sa.Column("data", sa.JSON()),
|
sa.Column('data', sa.JSON()),
|
||||||
)
|
)
|
||||||
|
|
||||||
kf_table = sa.Table(
|
kf_table = sa.Table(
|
||||||
"knowledge_file",
|
'knowledge_file',
|
||||||
sa.MetaData(),
|
sa.MetaData(),
|
||||||
sa.Column("id", sa.Text()),
|
sa.Column('id', sa.Text()),
|
||||||
sa.Column("knowledge_id", sa.Text()),
|
sa.Column('knowledge_id', sa.Text()),
|
||||||
sa.Column("file_id", sa.Text()),
|
sa.Column('file_id', sa.Text()),
|
||||||
)
|
)
|
||||||
|
|
||||||
results = connection.execute(sa.select(knowledge_table.c.id)).fetchall()
|
results = connection.execute(sa.select(knowledge_table.c.id)).fetchall()
|
||||||
@@ -157,13 +153,9 @@ def downgrade() -> None:
|
|||||||
|
|
||||||
file_ids_list = [fid for (fid,) in file_ids]
|
file_ids_list = [fid for (fid,) in file_ids]
|
||||||
|
|
||||||
data_json = {"file_ids": file_ids_list}
|
data_json = {'file_ids': file_ids_list}
|
||||||
|
|
||||||
connection.execute(
|
connection.execute(knowledge_table.update().where(knowledge_table.c.id == knowledge_id).values(data=data_json))
|
||||||
knowledge_table.update()
|
|
||||||
.where(knowledge_table.c.id == knowledge_id)
|
|
||||||
.values(data=data_json)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Drop the knowledge_file table
|
# 3. Drop the knowledge_file table
|
||||||
op.drop_table("knowledge_file")
|
op.drop_table('knowledge_file')
|
||||||
|
|||||||
@@ -9,56 +9,56 @@ Create Date: 2024-10-23 03:00:00.000000
|
|||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
revision = "4ace53fd72c8"
|
revision = '4ace53fd72c8'
|
||||||
down_revision = "af906e964978"
|
down_revision = 'af906e964978'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
# Perform safe alterations using batch operation
|
# Perform safe alterations using batch operation
|
||||||
with op.batch_alter_table("folder", schema=None) as batch_op:
|
with op.batch_alter_table('folder', schema=None) as batch_op:
|
||||||
# Step 1: Remove server defaults for created_at and updated_at
|
# Step 1: Remove server defaults for created_at and updated_at
|
||||||
batch_op.alter_column(
|
batch_op.alter_column(
|
||||||
"created_at",
|
'created_at',
|
||||||
server_default=None, # Removing server default
|
server_default=None, # Removing server default
|
||||||
)
|
)
|
||||||
batch_op.alter_column(
|
batch_op.alter_column(
|
||||||
"updated_at",
|
'updated_at',
|
||||||
server_default=None, # Removing server default
|
server_default=None, # Removing server default
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 2: Change the column types to BigInteger for created_at
|
# Step 2: Change the column types to BigInteger for created_at
|
||||||
batch_op.alter_column(
|
batch_op.alter_column(
|
||||||
"created_at",
|
'created_at',
|
||||||
type_=sa.BigInteger(),
|
type_=sa.BigInteger(),
|
||||||
existing_type=sa.DateTime(),
|
existing_type=sa.DateTime(),
|
||||||
existing_nullable=False,
|
existing_nullable=False,
|
||||||
postgresql_using="extract(epoch from created_at)::bigint", # Conversion for PostgreSQL
|
postgresql_using='extract(epoch from created_at)::bigint', # Conversion for PostgreSQL
|
||||||
)
|
)
|
||||||
|
|
||||||
# Change the column types to BigInteger for updated_at
|
# Change the column types to BigInteger for updated_at
|
||||||
batch_op.alter_column(
|
batch_op.alter_column(
|
||||||
"updated_at",
|
'updated_at',
|
||||||
type_=sa.BigInteger(),
|
type_=sa.BigInteger(),
|
||||||
existing_type=sa.DateTime(),
|
existing_type=sa.DateTime(),
|
||||||
existing_nullable=False,
|
existing_nullable=False,
|
||||||
postgresql_using="extract(epoch from updated_at)::bigint", # Conversion for PostgreSQL
|
postgresql_using='extract(epoch from updated_at)::bigint', # Conversion for PostgreSQL
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
# Downgrade: Convert columns back to DateTime and restore defaults
|
# Downgrade: Convert columns back to DateTime and restore defaults
|
||||||
with op.batch_alter_table("folder", schema=None) as batch_op:
|
with op.batch_alter_table('folder', schema=None) as batch_op:
|
||||||
batch_op.alter_column(
|
batch_op.alter_column(
|
||||||
"created_at",
|
'created_at',
|
||||||
type_=sa.DateTime(),
|
type_=sa.DateTime(),
|
||||||
existing_type=sa.BigInteger(),
|
existing_type=sa.BigInteger(),
|
||||||
existing_nullable=False,
|
existing_nullable=False,
|
||||||
server_default=sa.func.now(), # Restoring server default on downgrade
|
server_default=sa.func.now(), # Restoring server default on downgrade
|
||||||
)
|
)
|
||||||
batch_op.alter_column(
|
batch_op.alter_column(
|
||||||
"updated_at",
|
'updated_at',
|
||||||
type_=sa.DateTime(),
|
type_=sa.DateTime(),
|
||||||
existing_type=sa.BigInteger(),
|
existing_type=sa.BigInteger(),
|
||||||
existing_nullable=False,
|
existing_nullable=False,
|
||||||
|
|||||||
@@ -9,40 +9,40 @@ Create Date: 2024-12-22 03:00:00.000000
|
|||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
revision = "57c599a3cb57"
|
revision = '57c599a3cb57'
|
||||||
down_revision = "922e7a387820"
|
down_revision = '922e7a387820'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"channel",
|
'channel',
|
||||||
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
|
sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True),
|
||||||
sa.Column("user_id", sa.Text()),
|
sa.Column('user_id', sa.Text()),
|
||||||
sa.Column("name", sa.Text()),
|
sa.Column('name', sa.Text()),
|
||||||
sa.Column("description", sa.Text(), nullable=True),
|
sa.Column('description', sa.Text(), nullable=True),
|
||||||
sa.Column("data", sa.JSON(), nullable=True),
|
sa.Column('data', sa.JSON(), nullable=True),
|
||||||
sa.Column("meta", sa.JSON(), nullable=True),
|
sa.Column('meta', sa.JSON(), nullable=True),
|
||||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
sa.Column('access_control', sa.JSON(), nullable=True),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"message",
|
'message',
|
||||||
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
|
sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True),
|
||||||
sa.Column("user_id", sa.Text()),
|
sa.Column('user_id', sa.Text()),
|
||||||
sa.Column("channel_id", sa.Text(), nullable=True),
|
sa.Column('channel_id', sa.Text(), nullable=True),
|
||||||
sa.Column("content", sa.Text()),
|
sa.Column('content', sa.Text()),
|
||||||
sa.Column("data", sa.JSON(), nullable=True),
|
sa.Column('data', sa.JSON(), nullable=True),
|
||||||
sa.Column("meta", sa.JSON(), nullable=True),
|
sa.Column('meta', sa.JSON(), nullable=True),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.drop_table("channel")
|
op.drop_table('channel')
|
||||||
|
|
||||||
op.drop_table("message")
|
op.drop_table('message')
|
||||||
|
|||||||
@@ -13,41 +13,39 @@ import sqlalchemy as sa
|
|||||||
import open_webui.internal.db
|
import open_webui.internal.db
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "6283dc0e4d8d"
|
revision: str = '6283dc0e4d8d'
|
||||||
down_revision: Union[str, None] = "3e0e00844bb0"
|
down_revision: Union[str, None] = '3e0e00844bb0'
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"channel_file",
|
'channel_file',
|
||||||
sa.Column("id", sa.Text(), primary_key=True),
|
sa.Column('id', sa.Text(), primary_key=True),
|
||||||
sa.Column("user_id", sa.Text(), nullable=False),
|
sa.Column('user_id', sa.Text(), nullable=False),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"channel_id",
|
'channel_id',
|
||||||
sa.Text(),
|
sa.Text(),
|
||||||
sa.ForeignKey("channel.id", ondelete="CASCADE"),
|
sa.ForeignKey('channel.id', ondelete='CASCADE'),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
),
|
),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"file_id",
|
'file_id',
|
||||||
sa.Text(),
|
sa.Text(),
|
||||||
sa.ForeignKey("file.id", ondelete="CASCADE"),
|
sa.ForeignKey('file.id', ondelete='CASCADE'),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
),
|
),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
sa.Column('updated_at', sa.BigInteger(), nullable=False),
|
||||||
# indexes
|
# indexes
|
||||||
sa.Index("ix_channel_file_channel_id", "channel_id"),
|
sa.Index('ix_channel_file_channel_id', 'channel_id'),
|
||||||
sa.Index("ix_channel_file_file_id", "file_id"),
|
sa.Index('ix_channel_file_file_id', 'file_id'),
|
||||||
sa.Index("ix_channel_file_user_id", "user_id"),
|
sa.Index('ix_channel_file_user_id', 'user_id'),
|
||||||
# unique constraints
|
# unique constraints
|
||||||
sa.UniqueConstraint(
|
sa.UniqueConstraint('channel_id', 'file_id', name='uq_channel_file_channel_file'), # prevent duplicate entries
|
||||||
"channel_id", "file_id", name="uq_channel_file_channel_file"
|
|
||||||
), # prevent duplicate entries
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
op.drop_table("channel_file")
|
op.drop_table('channel_file')
|
||||||
|
|||||||
@@ -11,37 +11,37 @@ import sqlalchemy as sa
|
|||||||
from sqlalchemy.sql import table, column, select
|
from sqlalchemy.sql import table, column, select
|
||||||
import json
|
import json
|
||||||
|
|
||||||
revision = "6a39f3d8e55c"
|
revision = '6a39f3d8e55c'
|
||||||
down_revision = "c0fbf31ca0db"
|
down_revision = 'c0fbf31ca0db'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
# Creating the 'knowledge' table
|
# Creating the 'knowledge' table
|
||||||
print("Creating knowledge table")
|
print('Creating knowledge table')
|
||||||
knowledge_table = op.create_table(
|
knowledge_table = op.create_table(
|
||||||
"knowledge",
|
'knowledge',
|
||||||
sa.Column("id", sa.Text(), primary_key=True),
|
sa.Column('id', sa.Text(), primary_key=True),
|
||||||
sa.Column("user_id", sa.Text(), nullable=False),
|
sa.Column('user_id', sa.Text(), nullable=False),
|
||||||
sa.Column("name", sa.Text(), nullable=False),
|
sa.Column('name', sa.Text(), nullable=False),
|
||||||
sa.Column("description", sa.Text(), nullable=True),
|
sa.Column('description', sa.Text(), nullable=True),
|
||||||
sa.Column("data", sa.JSON(), nullable=True),
|
sa.Column('data', sa.JSON(), nullable=True),
|
||||||
sa.Column("meta", sa.JSON(), nullable=True),
|
sa.Column('meta', sa.JSON(), nullable=True),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Migrating data from document table to knowledge table")
|
print('Migrating data from document table to knowledge table')
|
||||||
# Representation of the existing 'document' table
|
# Representation of the existing 'document' table
|
||||||
document_table = table(
|
document_table = table(
|
||||||
"document",
|
'document',
|
||||||
column("collection_name", sa.String()),
|
column('collection_name', sa.String()),
|
||||||
column("user_id", sa.String()),
|
column('user_id', sa.String()),
|
||||||
column("name", sa.String()),
|
column('name', sa.String()),
|
||||||
column("title", sa.Text()),
|
column('title', sa.Text()),
|
||||||
column("content", sa.Text()),
|
column('content', sa.Text()),
|
||||||
column("timestamp", sa.BigInteger()),
|
column('timestamp', sa.BigInteger()),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Select all from existing document table
|
# Select all from existing document table
|
||||||
@@ -64,9 +64,9 @@ def upgrade():
|
|||||||
user_id=doc.user_id,
|
user_id=doc.user_id,
|
||||||
description=doc.name,
|
description=doc.name,
|
||||||
meta={
|
meta={
|
||||||
"legacy": True,
|
'legacy': True,
|
||||||
"document": True,
|
'document': True,
|
||||||
"tags": json.loads(doc.content or "{}").get("tags", []),
|
'tags': json.loads(doc.content or '{}').get('tags', []),
|
||||||
},
|
},
|
||||||
name=doc.title,
|
name=doc.title,
|
||||||
created_at=doc.timestamp,
|
created_at=doc.timestamp,
|
||||||
@@ -76,4 +76,4 @@ def upgrade():
|
|||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.drop_table("knowledge")
|
op.drop_table('knowledge')
|
||||||
|
|||||||
@@ -9,18 +9,18 @@ Create Date: 2024-12-23 03:00:00.000000
|
|||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
revision = "7826ab40b532"
|
revision = '7826ab40b532'
|
||||||
down_revision = "57c599a3cb57"
|
down_revision = '57c599a3cb57'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
op.add_column(
|
op.add_column(
|
||||||
"file",
|
'file',
|
||||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
sa.Column('access_control', sa.JSON(), nullable=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.drop_column("file", "access_control")
|
op.drop_column('file', 'access_control')
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from open_webui.internal.db import JSONField
|
|||||||
from open_webui.migrations.util import get_existing_tables
|
from open_webui.migrations.util import get_existing_tables
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "7e5b5dc7342b"
|
revision: str = '7e5b5dc7342b'
|
||||||
down_revision: Union[str, None] = None
|
down_revision: Union[str, None] = None
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
@@ -26,179 +26,179 @@ def upgrade() -> None:
|
|||||||
existing_tables = set(get_existing_tables())
|
existing_tables = set(get_existing_tables())
|
||||||
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
if "auth" not in existing_tables:
|
if 'auth' not in existing_tables:
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"auth",
|
'auth',
|
||||||
sa.Column("id", sa.String(), nullable=False),
|
sa.Column('id', sa.String(), nullable=False),
|
||||||
sa.Column("email", sa.String(), nullable=True),
|
sa.Column('email', sa.String(), nullable=True),
|
||||||
sa.Column("password", sa.Text(), nullable=True),
|
sa.Column('password', sa.Text(), nullable=True),
|
||||||
sa.Column("active", sa.Boolean(), nullable=True),
|
sa.Column('active', sa.Boolean(), nullable=True),
|
||||||
sa.PrimaryKeyConstraint("id"),
|
sa.PrimaryKeyConstraint('id'),
|
||||||
)
|
)
|
||||||
|
|
||||||
if "chat" not in existing_tables:
|
if 'chat' not in existing_tables:
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"chat",
|
'chat',
|
||||||
sa.Column("id", sa.String(), nullable=False),
|
sa.Column('id', sa.String(), nullable=False),
|
||||||
sa.Column("user_id", sa.String(), nullable=True),
|
sa.Column('user_id', sa.String(), nullable=True),
|
||||||
sa.Column("title", sa.Text(), nullable=True),
|
sa.Column('title', sa.Text(), nullable=True),
|
||||||
sa.Column("chat", sa.Text(), nullable=True),
|
sa.Column('chat', sa.Text(), nullable=True),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||||
sa.Column("share_id", sa.Text(), nullable=True),
|
sa.Column('share_id', sa.Text(), nullable=True),
|
||||||
sa.Column("archived", sa.Boolean(), nullable=True),
|
sa.Column('archived', sa.Boolean(), nullable=True),
|
||||||
sa.PrimaryKeyConstraint("id"),
|
sa.PrimaryKeyConstraint('id'),
|
||||||
sa.UniqueConstraint("share_id"),
|
sa.UniqueConstraint('share_id'),
|
||||||
)
|
)
|
||||||
|
|
||||||
if "chatidtag" not in existing_tables:
|
if 'chatidtag' not in existing_tables:
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"chatidtag",
|
'chatidtag',
|
||||||
sa.Column("id", sa.String(), nullable=False),
|
sa.Column('id', sa.String(), nullable=False),
|
||||||
sa.Column("tag_name", sa.String(), nullable=True),
|
sa.Column('tag_name', sa.String(), nullable=True),
|
||||||
sa.Column("chat_id", sa.String(), nullable=True),
|
sa.Column('chat_id', sa.String(), nullable=True),
|
||||||
sa.Column("user_id", sa.String(), nullable=True),
|
sa.Column('user_id', sa.String(), nullable=True),
|
||||||
sa.Column("timestamp", sa.BigInteger(), nullable=True),
|
sa.Column('timestamp', sa.BigInteger(), nullable=True),
|
||||||
sa.PrimaryKeyConstraint("id"),
|
sa.PrimaryKeyConstraint('id'),
|
||||||
)
|
)
|
||||||
|
|
||||||
if "document" not in existing_tables:
|
if 'document' not in existing_tables:
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"document",
|
'document',
|
||||||
sa.Column("collection_name", sa.String(), nullable=False),
|
sa.Column('collection_name', sa.String(), nullable=False),
|
||||||
sa.Column("name", sa.String(), nullable=True),
|
sa.Column('name', sa.String(), nullable=True),
|
||||||
sa.Column("title", sa.Text(), nullable=True),
|
sa.Column('title', sa.Text(), nullable=True),
|
||||||
sa.Column("filename", sa.Text(), nullable=True),
|
sa.Column('filename', sa.Text(), nullable=True),
|
||||||
sa.Column("content", sa.Text(), nullable=True),
|
sa.Column('content', sa.Text(), nullable=True),
|
||||||
sa.Column("user_id", sa.String(), nullable=True),
|
sa.Column('user_id', sa.String(), nullable=True),
|
||||||
sa.Column("timestamp", sa.BigInteger(), nullable=True),
|
sa.Column('timestamp', sa.BigInteger(), nullable=True),
|
||||||
sa.PrimaryKeyConstraint("collection_name"),
|
sa.PrimaryKeyConstraint('collection_name'),
|
||||||
sa.UniqueConstraint("name"),
|
sa.UniqueConstraint('name'),
|
||||||
)
|
)
|
||||||
|
|
||||||
if "file" not in existing_tables:
|
if 'file' not in existing_tables:
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"file",
|
'file',
|
||||||
sa.Column("id", sa.String(), nullable=False),
|
sa.Column('id', sa.String(), nullable=False),
|
||||||
sa.Column("user_id", sa.String(), nullable=True),
|
sa.Column('user_id', sa.String(), nullable=True),
|
||||||
sa.Column("filename", sa.Text(), nullable=True),
|
sa.Column('filename', sa.Text(), nullable=True),
|
||||||
sa.Column("meta", JSONField(), nullable=True),
|
sa.Column('meta', JSONField(), nullable=True),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||||
sa.PrimaryKeyConstraint("id"),
|
sa.PrimaryKeyConstraint('id'),
|
||||||
)
|
)
|
||||||
|
|
||||||
if "function" not in existing_tables:
|
if 'function' not in existing_tables:
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"function",
|
'function',
|
||||||
sa.Column("id", sa.String(), nullable=False),
|
sa.Column('id', sa.String(), nullable=False),
|
||||||
sa.Column("user_id", sa.String(), nullable=True),
|
sa.Column('user_id', sa.String(), nullable=True),
|
||||||
sa.Column("name", sa.Text(), nullable=True),
|
sa.Column('name', sa.Text(), nullable=True),
|
||||||
sa.Column("type", sa.Text(), nullable=True),
|
sa.Column('type', sa.Text(), nullable=True),
|
||||||
sa.Column("content", sa.Text(), nullable=True),
|
sa.Column('content', sa.Text(), nullable=True),
|
||||||
sa.Column("meta", JSONField(), nullable=True),
|
sa.Column('meta', JSONField(), nullable=True),
|
||||||
sa.Column("valves", JSONField(), nullable=True),
|
sa.Column('valves', JSONField(), nullable=True),
|
||||||
sa.Column("is_active", sa.Boolean(), nullable=True),
|
sa.Column('is_active', sa.Boolean(), nullable=True),
|
||||||
sa.Column("is_global", sa.Boolean(), nullable=True),
|
sa.Column('is_global', sa.Boolean(), nullable=True),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||||
sa.PrimaryKeyConstraint("id"),
|
sa.PrimaryKeyConstraint('id'),
|
||||||
)
|
)
|
||||||
|
|
||||||
if "memory" not in existing_tables:
|
if 'memory' not in existing_tables:
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"memory",
|
'memory',
|
||||||
sa.Column("id", sa.String(), nullable=False),
|
sa.Column('id', sa.String(), nullable=False),
|
||||||
sa.Column("user_id", sa.String(), nullable=True),
|
sa.Column('user_id', sa.String(), nullable=True),
|
||||||
sa.Column("content", sa.Text(), nullable=True),
|
sa.Column('content', sa.Text(), nullable=True),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||||
sa.PrimaryKeyConstraint("id"),
|
sa.PrimaryKeyConstraint('id'),
|
||||||
)
|
)
|
||||||
|
|
||||||
if "model" not in existing_tables:
|
if 'model' not in existing_tables:
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"model",
|
'model',
|
||||||
sa.Column("id", sa.Text(), nullable=False),
|
sa.Column('id', sa.Text(), nullable=False),
|
||||||
sa.Column("user_id", sa.Text(), nullable=True),
|
sa.Column('user_id', sa.Text(), nullable=True),
|
||||||
sa.Column("base_model_id", sa.Text(), nullable=True),
|
sa.Column('base_model_id', sa.Text(), nullable=True),
|
||||||
sa.Column("name", sa.Text(), nullable=True),
|
sa.Column('name', sa.Text(), nullable=True),
|
||||||
sa.Column("params", JSONField(), nullable=True),
|
sa.Column('params', JSONField(), nullable=True),
|
||||||
sa.Column("meta", JSONField(), nullable=True),
|
sa.Column('meta', JSONField(), nullable=True),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||||
sa.PrimaryKeyConstraint("id"),
|
sa.PrimaryKeyConstraint('id'),
|
||||||
)
|
)
|
||||||
|
|
||||||
if "prompt" not in existing_tables:
|
if 'prompt' not in existing_tables:
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"prompt",
|
'prompt',
|
||||||
sa.Column("command", sa.String(), nullable=False),
|
sa.Column('command', sa.String(), nullable=False),
|
||||||
sa.Column("user_id", sa.String(), nullable=True),
|
sa.Column('user_id', sa.String(), nullable=True),
|
||||||
sa.Column("title", sa.Text(), nullable=True),
|
sa.Column('title', sa.Text(), nullable=True),
|
||||||
sa.Column("content", sa.Text(), nullable=True),
|
sa.Column('content', sa.Text(), nullable=True),
|
||||||
sa.Column("timestamp", sa.BigInteger(), nullable=True),
|
sa.Column('timestamp', sa.BigInteger(), nullable=True),
|
||||||
sa.PrimaryKeyConstraint("command"),
|
sa.PrimaryKeyConstraint('command'),
|
||||||
)
|
)
|
||||||
|
|
||||||
if "tag" not in existing_tables:
|
if 'tag' not in existing_tables:
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"tag",
|
'tag',
|
||||||
sa.Column("id", sa.String(), nullable=False),
|
sa.Column('id', sa.String(), nullable=False),
|
||||||
sa.Column("name", sa.String(), nullable=True),
|
sa.Column('name', sa.String(), nullable=True),
|
||||||
sa.Column("user_id", sa.String(), nullable=True),
|
sa.Column('user_id', sa.String(), nullable=True),
|
||||||
sa.Column("data", sa.Text(), nullable=True),
|
sa.Column('data', sa.Text(), nullable=True),
|
||||||
sa.PrimaryKeyConstraint("id"),
|
sa.PrimaryKeyConstraint('id'),
|
||||||
)
|
)
|
||||||
|
|
||||||
if "tool" not in existing_tables:
|
if 'tool' not in existing_tables:
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"tool",
|
'tool',
|
||||||
sa.Column("id", sa.String(), nullable=False),
|
sa.Column('id', sa.String(), nullable=False),
|
||||||
sa.Column("user_id", sa.String(), nullable=True),
|
sa.Column('user_id', sa.String(), nullable=True),
|
||||||
sa.Column("name", sa.Text(), nullable=True),
|
sa.Column('name', sa.Text(), nullable=True),
|
||||||
sa.Column("content", sa.Text(), nullable=True),
|
sa.Column('content', sa.Text(), nullable=True),
|
||||||
sa.Column("specs", JSONField(), nullable=True),
|
sa.Column('specs', JSONField(), nullable=True),
|
||||||
sa.Column("meta", JSONField(), nullable=True),
|
sa.Column('meta', JSONField(), nullable=True),
|
||||||
sa.Column("valves", JSONField(), nullable=True),
|
sa.Column('valves', JSONField(), nullable=True),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||||
sa.PrimaryKeyConstraint("id"),
|
sa.PrimaryKeyConstraint('id'),
|
||||||
)
|
)
|
||||||
|
|
||||||
if "user" not in existing_tables:
|
if 'user' not in existing_tables:
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"user",
|
'user',
|
||||||
sa.Column("id", sa.String(), nullable=False),
|
sa.Column('id', sa.String(), nullable=False),
|
||||||
sa.Column("name", sa.String(), nullable=True),
|
sa.Column('name', sa.String(), nullable=True),
|
||||||
sa.Column("email", sa.String(), nullable=True),
|
sa.Column('email', sa.String(), nullable=True),
|
||||||
sa.Column("role", sa.String(), nullable=True),
|
sa.Column('role', sa.String(), nullable=True),
|
||||||
sa.Column("profile_image_url", sa.Text(), nullable=True),
|
sa.Column('profile_image_url', sa.Text(), nullable=True),
|
||||||
sa.Column("last_active_at", sa.BigInteger(), nullable=True),
|
sa.Column('last_active_at', sa.BigInteger(), nullable=True),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||||
sa.Column("api_key", sa.String(), nullable=True),
|
sa.Column('api_key', sa.String(), nullable=True),
|
||||||
sa.Column("settings", JSONField(), nullable=True),
|
sa.Column('settings', JSONField(), nullable=True),
|
||||||
sa.Column("info", JSONField(), nullable=True),
|
sa.Column('info', JSONField(), nullable=True),
|
||||||
sa.Column("oauth_sub", sa.Text(), nullable=True),
|
sa.Column('oauth_sub', sa.Text(), nullable=True),
|
||||||
sa.PrimaryKeyConstraint("id"),
|
sa.PrimaryKeyConstraint('id'),
|
||||||
sa.UniqueConstraint("api_key"),
|
sa.UniqueConstraint('api_key'),
|
||||||
sa.UniqueConstraint("oauth_sub"),
|
sa.UniqueConstraint('oauth_sub'),
|
||||||
)
|
)
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.drop_table("user")
|
op.drop_table('user')
|
||||||
op.drop_table("tool")
|
op.drop_table('tool')
|
||||||
op.drop_table("tag")
|
op.drop_table('tag')
|
||||||
op.drop_table("prompt")
|
op.drop_table('prompt')
|
||||||
op.drop_table("model")
|
op.drop_table('model')
|
||||||
op.drop_table("memory")
|
op.drop_table('memory')
|
||||||
op.drop_table("function")
|
op.drop_table('function')
|
||||||
op.drop_table("file")
|
op.drop_table('file')
|
||||||
op.drop_table("document")
|
op.drop_table('document')
|
||||||
op.drop_table("chatidtag")
|
op.drop_table('chatidtag')
|
||||||
op.drop_table("chat")
|
op.drop_table('chat')
|
||||||
op.drop_table("auth")
|
op.drop_table('auth')
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|||||||
@@ -13,36 +13,34 @@ import sqlalchemy as sa
|
|||||||
import open_webui.internal.db
|
import open_webui.internal.db
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "81cc2ce44d79"
|
revision: str = '81cc2ce44d79'
|
||||||
down_revision: Union[str, None] = "6283dc0e4d8d"
|
down_revision: Union[str, None] = '6283dc0e4d8d'
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# Add message_id column to channel_file table
|
# Add message_id column to channel_file table
|
||||||
with op.batch_alter_table("channel_file", schema=None) as batch_op:
|
with op.batch_alter_table('channel_file', schema=None) as batch_op:
|
||||||
batch_op.add_column(
|
batch_op.add_column(
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"message_id",
|
'message_id',
|
||||||
sa.Text(),
|
sa.Text(),
|
||||||
sa.ForeignKey(
|
sa.ForeignKey('message.id', ondelete='CASCADE', name='fk_channel_file_message_id'),
|
||||||
"message.id", ondelete="CASCADE", name="fk_channel_file_message_id"
|
|
||||||
),
|
|
||||||
nullable=True,
|
nullable=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add data column to knowledge table
|
# Add data column to knowledge table
|
||||||
with op.batch_alter_table("knowledge", schema=None) as batch_op:
|
with op.batch_alter_table('knowledge', schema=None) as batch_op:
|
||||||
batch_op.add_column(sa.Column("data", sa.JSON(), nullable=True))
|
batch_op.add_column(sa.Column('data', sa.JSON(), nullable=True))
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# Remove message_id column from channel_file table
|
# Remove message_id column from channel_file table
|
||||||
with op.batch_alter_table("channel_file", schema=None) as batch_op:
|
with op.batch_alter_table('channel_file', schema=None) as batch_op:
|
||||||
batch_op.drop_column("message_id")
|
batch_op.drop_column('message_id')
|
||||||
|
|
||||||
# Remove data column from knowledge table
|
# Remove data column from knowledge table
|
||||||
with op.batch_alter_table("knowledge", schema=None) as batch_op:
|
with op.batch_alter_table('knowledge', schema=None) as batch_op:
|
||||||
batch_op.drop_column("data")
|
batch_op.drop_column('data')
|
||||||
|
|||||||
@@ -16,8 +16,8 @@ import sqlalchemy as sa
|
|||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
revision: str = "8452d01d26d7"
|
revision: str = '8452d01d26d7'
|
||||||
down_revision: Union[str, None] = "374d2f66af06"
|
down_revision: Union[str, None] = '374d2f66af06'
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
@@ -51,74 +51,68 @@ def _flush_batch(conn, table, batch):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
sp.rollback()
|
sp.rollback()
|
||||||
failed += 1
|
failed += 1
|
||||||
log.warning(f"Failed to insert message {msg['id']}: {e}")
|
log.warning(f'Failed to insert message {msg["id"]}: {e}')
|
||||||
return inserted, failed
|
return inserted, failed
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# Step 1: Create table
|
# Step 1: Create table
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"chat_message",
|
'chat_message',
|
||||||
sa.Column("id", sa.Text(), primary_key=True),
|
sa.Column('id', sa.Text(), primary_key=True),
|
||||||
sa.Column("chat_id", sa.Text(), nullable=False, index=True),
|
sa.Column('chat_id', sa.Text(), nullable=False, index=True),
|
||||||
sa.Column("user_id", sa.Text(), index=True),
|
sa.Column('user_id', sa.Text(), index=True),
|
||||||
sa.Column("role", sa.Text(), nullable=False),
|
sa.Column('role', sa.Text(), nullable=False),
|
||||||
sa.Column("parent_id", sa.Text(), nullable=True),
|
sa.Column('parent_id', sa.Text(), nullable=True),
|
||||||
sa.Column("content", sa.JSON(), nullable=True),
|
sa.Column('content', sa.JSON(), nullable=True),
|
||||||
sa.Column("output", sa.JSON(), nullable=True),
|
sa.Column('output', sa.JSON(), nullable=True),
|
||||||
sa.Column("model_id", sa.Text(), nullable=True, index=True),
|
sa.Column('model_id', sa.Text(), nullable=True, index=True),
|
||||||
sa.Column("files", sa.JSON(), nullable=True),
|
sa.Column('files', sa.JSON(), nullable=True),
|
||||||
sa.Column("sources", sa.JSON(), nullable=True),
|
sa.Column('sources', sa.JSON(), nullable=True),
|
||||||
sa.Column("embeds", sa.JSON(), nullable=True),
|
sa.Column('embeds', sa.JSON(), nullable=True),
|
||||||
sa.Column("done", sa.Boolean(), default=True),
|
sa.Column('done', sa.Boolean(), default=True),
|
||||||
sa.Column("status_history", sa.JSON(), nullable=True),
|
sa.Column('status_history', sa.JSON(), nullable=True),
|
||||||
sa.Column("error", sa.JSON(), nullable=True),
|
sa.Column('error', sa.JSON(), nullable=True),
|
||||||
sa.Column("usage", sa.JSON(), nullable=True),
|
sa.Column('usage', sa.JSON(), nullable=True),
|
||||||
sa.Column("created_at", sa.BigInteger(), index=True),
|
sa.Column('created_at', sa.BigInteger(), index=True),
|
||||||
sa.Column("updated_at", sa.BigInteger()),
|
sa.Column('updated_at', sa.BigInteger()),
|
||||||
sa.ForeignKeyConstraint(["chat_id"], ["chat.id"], ondelete="CASCADE"),
|
sa.ForeignKeyConstraint(['chat_id'], ['chat.id'], ondelete='CASCADE'),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create composite indexes
|
# Create composite indexes
|
||||||
op.create_index(
|
op.create_index('chat_message_chat_parent_idx', 'chat_message', ['chat_id', 'parent_id'])
|
||||||
"chat_message_chat_parent_idx", "chat_message", ["chat_id", "parent_id"]
|
op.create_index('chat_message_model_created_idx', 'chat_message', ['model_id', 'created_at'])
|
||||||
)
|
op.create_index('chat_message_user_created_idx', 'chat_message', ['user_id', 'created_at'])
|
||||||
op.create_index(
|
|
||||||
"chat_message_model_created_idx", "chat_message", ["model_id", "created_at"]
|
|
||||||
)
|
|
||||||
op.create_index(
|
|
||||||
"chat_message_user_created_idx", "chat_message", ["user_id", "created_at"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 2: Backfill from existing chats
|
# Step 2: Backfill from existing chats
|
||||||
conn = op.get_bind()
|
conn = op.get_bind()
|
||||||
|
|
||||||
chat_table = sa.table(
|
chat_table = sa.table(
|
||||||
"chat",
|
'chat',
|
||||||
sa.column("id", sa.Text()),
|
sa.column('id', sa.Text()),
|
||||||
sa.column("user_id", sa.Text()),
|
sa.column('user_id', sa.Text()),
|
||||||
sa.column("chat", sa.JSON()),
|
sa.column('chat', sa.JSON()),
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_message_table = sa.table(
|
chat_message_table = sa.table(
|
||||||
"chat_message",
|
'chat_message',
|
||||||
sa.column("id", sa.Text()),
|
sa.column('id', sa.Text()),
|
||||||
sa.column("chat_id", sa.Text()),
|
sa.column('chat_id', sa.Text()),
|
||||||
sa.column("user_id", sa.Text()),
|
sa.column('user_id', sa.Text()),
|
||||||
sa.column("role", sa.Text()),
|
sa.column('role', sa.Text()),
|
||||||
sa.column("parent_id", sa.Text()),
|
sa.column('parent_id', sa.Text()),
|
||||||
sa.column("content", sa.JSON()),
|
sa.column('content', sa.JSON()),
|
||||||
sa.column("output", sa.JSON()),
|
sa.column('output', sa.JSON()),
|
||||||
sa.column("model_id", sa.Text()),
|
sa.column('model_id', sa.Text()),
|
||||||
sa.column("files", sa.JSON()),
|
sa.column('files', sa.JSON()),
|
||||||
sa.column("sources", sa.JSON()),
|
sa.column('sources', sa.JSON()),
|
||||||
sa.column("embeds", sa.JSON()),
|
sa.column('embeds', sa.JSON()),
|
||||||
sa.column("done", sa.Boolean()),
|
sa.column('done', sa.Boolean()),
|
||||||
sa.column("status_history", sa.JSON()),
|
sa.column('status_history', sa.JSON()),
|
||||||
sa.column("error", sa.JSON()),
|
sa.column('error', sa.JSON()),
|
||||||
sa.column("usage", sa.JSON()),
|
sa.column('usage', sa.JSON()),
|
||||||
sa.column("created_at", sa.BigInteger()),
|
sa.column('created_at', sa.BigInteger()),
|
||||||
sa.column("updated_at", sa.BigInteger()),
|
sa.column('updated_at', sa.BigInteger()),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Stream rows instead of loading all into memory:
|
# Stream rows instead of loading all into memory:
|
||||||
@@ -126,7 +120,7 @@ def upgrade() -> None:
|
|||||||
# - stream_results: enables server-side cursors on PostgreSQL (no-op on SQLite)
|
# - stream_results: enables server-side cursors on PostgreSQL (no-op on SQLite)
|
||||||
result = conn.execute(
|
result = conn.execute(
|
||||||
sa.select(chat_table.c.id, chat_table.c.user_id, chat_table.c.chat)
|
sa.select(chat_table.c.id, chat_table.c.user_id, chat_table.c.chat)
|
||||||
.where(~chat_table.c.user_id.like("shared-%"))
|
.where(~chat_table.c.user_id.like('shared-%'))
|
||||||
.execution_options(yield_per=1000, stream_results=True)
|
.execution_options(yield_per=1000, stream_results=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -150,11 +144,11 @@ def upgrade() -> None:
|
|||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
history = chat_data.get("history", {})
|
history = chat_data.get('history', {})
|
||||||
if not isinstance(history, dict):
|
if not isinstance(history, dict):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
messages = history.get("messages", {})
|
messages = history.get('messages', {})
|
||||||
if not isinstance(messages, dict):
|
if not isinstance(messages, dict):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -162,11 +156,11 @@ def upgrade() -> None:
|
|||||||
if not isinstance(message, dict):
|
if not isinstance(message, dict):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
role = message.get("role")
|
role = message.get('role')
|
||||||
if not role:
|
if not role:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
timestamp = message.get("timestamp", now)
|
timestamp = message.get('timestamp', now)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
timestamp = int(float(timestamp))
|
timestamp = int(float(timestamp))
|
||||||
@@ -182,37 +176,33 @@ def upgrade() -> None:
|
|||||||
|
|
||||||
messages_batch.append(
|
messages_batch.append(
|
||||||
{
|
{
|
||||||
"id": f"{chat_id}-{message_id}",
|
'id': f'{chat_id}-{message_id}',
|
||||||
"chat_id": chat_id,
|
'chat_id': chat_id,
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"role": role,
|
'role': role,
|
||||||
"parent_id": message.get("parentId"),
|
'parent_id': message.get('parentId'),
|
||||||
"content": message.get("content"),
|
'content': message.get('content'),
|
||||||
"output": message.get("output"),
|
'output': message.get('output'),
|
||||||
"model_id": message.get("model"),
|
'model_id': message.get('model'),
|
||||||
"files": message.get("files"),
|
'files': message.get('files'),
|
||||||
"sources": message.get("sources"),
|
'sources': message.get('sources'),
|
||||||
"embeds": message.get("embeds"),
|
'embeds': message.get('embeds'),
|
||||||
"done": message.get("done", True),
|
'done': message.get('done', True),
|
||||||
"status_history": message.get("statusHistory"),
|
'status_history': message.get('statusHistory'),
|
||||||
"error": message.get("error"),
|
'error': message.get('error'),
|
||||||
"usage": message.get("usage"),
|
'usage': message.get('usage'),
|
||||||
"created_at": timestamp,
|
'created_at': timestamp,
|
||||||
"updated_at": timestamp,
|
'updated_at': timestamp,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Flush batch when full
|
# Flush batch when full
|
||||||
if len(messages_batch) >= BATCH_SIZE:
|
if len(messages_batch) >= BATCH_SIZE:
|
||||||
inserted, failed = _flush_batch(
|
inserted, failed = _flush_batch(conn, chat_message_table, messages_batch)
|
||||||
conn, chat_message_table, messages_batch
|
|
||||||
)
|
|
||||||
total_inserted += inserted
|
total_inserted += inserted
|
||||||
total_failed += failed
|
total_failed += failed
|
||||||
if total_inserted % 50000 < BATCH_SIZE:
|
if total_inserted % 50000 < BATCH_SIZE:
|
||||||
log.info(
|
log.info(f'Migration progress: {total_inserted} messages inserted...')
|
||||||
f"Migration progress: {total_inserted} messages inserted..."
|
|
||||||
)
|
|
||||||
messages_batch.clear()
|
messages_batch.clear()
|
||||||
|
|
||||||
# Flush remaining messages
|
# Flush remaining messages
|
||||||
@@ -221,13 +211,11 @@ def upgrade() -> None:
|
|||||||
total_inserted += inserted
|
total_inserted += inserted
|
||||||
total_failed += failed
|
total_failed += failed
|
||||||
|
|
||||||
log.info(
|
log.info(f'Backfilled {total_inserted} messages into chat_message table ({total_failed} failed)')
|
||||||
f"Backfilled {total_inserted} messages into chat_message table ({total_failed} failed)"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
op.drop_index("chat_message_user_created_idx", table_name="chat_message")
|
op.drop_index('chat_message_user_created_idx', table_name='chat_message')
|
||||||
op.drop_index("chat_message_model_created_idx", table_name="chat_message")
|
op.drop_index('chat_message_model_created_idx', table_name='chat_message')
|
||||||
op.drop_index("chat_message_chat_parent_idx", table_name="chat_message")
|
op.drop_index('chat_message_chat_parent_idx', table_name='chat_message')
|
||||||
op.drop_table("chat_message")
|
op.drop_table('chat_message')
|
||||||
|
|||||||
@@ -13,48 +13,46 @@ import sqlalchemy as sa
|
|||||||
import open_webui.internal.db
|
import open_webui.internal.db
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "90ef40d4714e"
|
revision: str = '90ef40d4714e'
|
||||||
down_revision: Union[str, None] = "b10670c03dd5"
|
down_revision: Union[str, None] = 'b10670c03dd5'
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# Update 'channel' table
|
# Update 'channel' table
|
||||||
op.add_column("channel", sa.Column("is_private", sa.Boolean(), nullable=True))
|
op.add_column('channel', sa.Column('is_private', sa.Boolean(), nullable=True))
|
||||||
|
|
||||||
op.add_column("channel", sa.Column("archived_at", sa.BigInteger(), nullable=True))
|
op.add_column('channel', sa.Column('archived_at', sa.BigInteger(), nullable=True))
|
||||||
op.add_column("channel", sa.Column("archived_by", sa.Text(), nullable=True))
|
op.add_column('channel', sa.Column('archived_by', sa.Text(), nullable=True))
|
||||||
|
|
||||||
op.add_column("channel", sa.Column("deleted_at", sa.BigInteger(), nullable=True))
|
op.add_column('channel', sa.Column('deleted_at', sa.BigInteger(), nullable=True))
|
||||||
op.add_column("channel", sa.Column("deleted_by", sa.Text(), nullable=True))
|
op.add_column('channel', sa.Column('deleted_by', sa.Text(), nullable=True))
|
||||||
|
|
||||||
op.add_column("channel", sa.Column("updated_by", sa.Text(), nullable=True))
|
op.add_column('channel', sa.Column('updated_by', sa.Text(), nullable=True))
|
||||||
|
|
||||||
# Update 'channel_member' table
|
# Update 'channel_member' table
|
||||||
op.add_column("channel_member", sa.Column("role", sa.Text(), nullable=True))
|
op.add_column('channel_member', sa.Column('role', sa.Text(), nullable=True))
|
||||||
op.add_column("channel_member", sa.Column("invited_by", sa.Text(), nullable=True))
|
op.add_column('channel_member', sa.Column('invited_by', sa.Text(), nullable=True))
|
||||||
op.add_column(
|
op.add_column('channel_member', sa.Column('invited_at', sa.BigInteger(), nullable=True))
|
||||||
"channel_member", sa.Column("invited_at", sa.BigInteger(), nullable=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create 'channel_webhook' table
|
# Create 'channel_webhook' table
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"channel_webhook",
|
'channel_webhook',
|
||||||
sa.Column("id", sa.Text(), primary_key=True, unique=True, nullable=False),
|
sa.Column('id', sa.Text(), primary_key=True, unique=True, nullable=False),
|
||||||
sa.Column("user_id", sa.Text(), nullable=False),
|
sa.Column('user_id', sa.Text(), nullable=False),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"channel_id",
|
'channel_id',
|
||||||
sa.Text(),
|
sa.Text(),
|
||||||
sa.ForeignKey("channel.id", ondelete="CASCADE"),
|
sa.ForeignKey('channel.id', ondelete='CASCADE'),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
),
|
),
|
||||||
sa.Column("name", sa.Text(), nullable=False),
|
sa.Column('name', sa.Text(), nullable=False),
|
||||||
sa.Column("profile_image_url", sa.Text(), nullable=True),
|
sa.Column('profile_image_url', sa.Text(), nullable=True),
|
||||||
sa.Column("token", sa.Text(), nullable=False),
|
sa.Column('token', sa.Text(), nullable=False),
|
||||||
sa.Column("last_used_at", sa.BigInteger(), nullable=True),
|
sa.Column('last_used_at', sa.BigInteger(), nullable=True),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
sa.Column('updated_at', sa.BigInteger(), nullable=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
pass
|
pass
|
||||||
@@ -62,19 +60,19 @@ def upgrade() -> None:
|
|||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# Downgrade 'channel' table
|
# Downgrade 'channel' table
|
||||||
op.drop_column("channel", "is_private")
|
op.drop_column('channel', 'is_private')
|
||||||
op.drop_column("channel", "archived_at")
|
op.drop_column('channel', 'archived_at')
|
||||||
op.drop_column("channel", "archived_by")
|
op.drop_column('channel', 'archived_by')
|
||||||
op.drop_column("channel", "deleted_at")
|
op.drop_column('channel', 'deleted_at')
|
||||||
op.drop_column("channel", "deleted_by")
|
op.drop_column('channel', 'deleted_by')
|
||||||
op.drop_column("channel", "updated_by")
|
op.drop_column('channel', 'updated_by')
|
||||||
|
|
||||||
# Downgrade 'channel_member' table
|
# Downgrade 'channel_member' table
|
||||||
op.drop_column("channel_member", "role")
|
op.drop_column('channel_member', 'role')
|
||||||
op.drop_column("channel_member", "invited_by")
|
op.drop_column('channel_member', 'invited_by')
|
||||||
op.drop_column("channel_member", "invited_at")
|
op.drop_column('channel_member', 'invited_at')
|
||||||
|
|
||||||
# Drop 'channel_webhook' table
|
# Drop 'channel_webhook' table
|
||||||
op.drop_table("channel_webhook")
|
op.drop_table('channel_webhook')
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -9,38 +9,38 @@ Create Date: 2024-11-14 03:00:00.000000
|
|||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
revision = "922e7a387820"
|
revision = '922e7a387820'
|
||||||
down_revision = "4ace53fd72c8"
|
down_revision = '4ace53fd72c8'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"group",
|
'group',
|
||||||
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
|
sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True),
|
||||||
sa.Column("user_id", sa.Text(), nullable=True),
|
sa.Column('user_id', sa.Text(), nullable=True),
|
||||||
sa.Column("name", sa.Text(), nullable=True),
|
sa.Column('name', sa.Text(), nullable=True),
|
||||||
sa.Column("description", sa.Text(), nullable=True),
|
sa.Column('description', sa.Text(), nullable=True),
|
||||||
sa.Column("data", sa.JSON(), nullable=True),
|
sa.Column('data', sa.JSON(), nullable=True),
|
||||||
sa.Column("meta", sa.JSON(), nullable=True),
|
sa.Column('meta', sa.JSON(), nullable=True),
|
||||||
sa.Column("permissions", sa.JSON(), nullable=True),
|
sa.Column('permissions', sa.JSON(), nullable=True),
|
||||||
sa.Column("user_ids", sa.JSON(), nullable=True),
|
sa.Column('user_ids', sa.JSON(), nullable=True),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add 'access_control' column to 'model' table
|
# Add 'access_control' column to 'model' table
|
||||||
op.add_column(
|
op.add_column(
|
||||||
"model",
|
'model',
|
||||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
sa.Column('access_control', sa.JSON(), nullable=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add 'is_active' column to 'model' table
|
# Add 'is_active' column to 'model' table
|
||||||
op.add_column(
|
op.add_column(
|
||||||
"model",
|
'model',
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"is_active",
|
'is_active',
|
||||||
sa.Boolean(),
|
sa.Boolean(),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
server_default=sa.sql.expression.true(),
|
server_default=sa.sql.expression.true(),
|
||||||
@@ -49,37 +49,37 @@ def upgrade():
|
|||||||
|
|
||||||
# Add 'access_control' column to 'knowledge' table
|
# Add 'access_control' column to 'knowledge' table
|
||||||
op.add_column(
|
op.add_column(
|
||||||
"knowledge",
|
'knowledge',
|
||||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
sa.Column('access_control', sa.JSON(), nullable=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add 'access_control' column to 'prompt' table
|
# Add 'access_control' column to 'prompt' table
|
||||||
op.add_column(
|
op.add_column(
|
||||||
"prompt",
|
'prompt',
|
||||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
sa.Column('access_control', sa.JSON(), nullable=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add 'access_control' column to 'tools' table
|
# Add 'access_control' column to 'tools' table
|
||||||
op.add_column(
|
op.add_column(
|
||||||
"tool",
|
'tool',
|
||||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
sa.Column('access_control', sa.JSON(), nullable=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.drop_table("group")
|
op.drop_table('group')
|
||||||
|
|
||||||
# Drop 'access_control' column from 'model' table
|
# Drop 'access_control' column from 'model' table
|
||||||
op.drop_column("model", "access_control")
|
op.drop_column('model', 'access_control')
|
||||||
|
|
||||||
# Drop 'is_active' column from 'model' table
|
# Drop 'is_active' column from 'model' table
|
||||||
op.drop_column("model", "is_active")
|
op.drop_column('model', 'is_active')
|
||||||
|
|
||||||
# Drop 'access_control' column from 'knowledge' table
|
# Drop 'access_control' column from 'knowledge' table
|
||||||
op.drop_column("knowledge", "access_control")
|
op.drop_column('knowledge', 'access_control')
|
||||||
|
|
||||||
# Drop 'access_control' column from 'prompt' table
|
# Drop 'access_control' column from 'prompt' table
|
||||||
op.drop_column("prompt", "access_control")
|
op.drop_column('prompt', 'access_control')
|
||||||
|
|
||||||
# Drop 'access_control' column from 'tools' table
|
# Drop 'access_control' column from 'tools' table
|
||||||
op.drop_column("tool", "access_control")
|
op.drop_column('tool', 'access_control')
|
||||||
|
|||||||
@@ -9,25 +9,25 @@ Create Date: 2025-05-03 03:00:00.000000
|
|||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
revision = "9f0c9cd09105"
|
revision = '9f0c9cd09105'
|
||||||
down_revision = "3781e22d8b01"
|
down_revision = '3781e22d8b01'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"note",
|
'note',
|
||||||
sa.Column("id", sa.Text(), nullable=False, primary_key=True, unique=True),
|
sa.Column('id', sa.Text(), nullable=False, primary_key=True, unique=True),
|
||||||
sa.Column("user_id", sa.Text(), nullable=True),
|
sa.Column('user_id', sa.Text(), nullable=True),
|
||||||
sa.Column("title", sa.Text(), nullable=True),
|
sa.Column('title', sa.Text(), nullable=True),
|
||||||
sa.Column("data", sa.JSON(), nullable=True),
|
sa.Column('data', sa.JSON(), nullable=True),
|
||||||
sa.Column("meta", sa.JSON(), nullable=True),
|
sa.Column('meta', sa.JSON(), nullable=True),
|
||||||
sa.Column("access_control", sa.JSON(), nullable=True),
|
sa.Column('access_control', sa.JSON(), nullable=True),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.drop_table("note")
|
op.drop_table('note')
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ import sqlalchemy as sa
|
|||||||
|
|
||||||
from open_webui.migrations.util import get_existing_tables
|
from open_webui.migrations.util import get_existing_tables
|
||||||
|
|
||||||
revision: str = "a1b2c3d4e5f6"
|
revision: str = 'a1b2c3d4e5f6'
|
||||||
down_revision: Union[str, None] = "f1e2d3c4b5a6"
|
down_revision: Union[str, None] = 'f1e2d3c4b5a6'
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
@@ -22,24 +22,24 @@ depends_on: Union[str, Sequence[str], None] = None
|
|||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
existing_tables = set(get_existing_tables())
|
existing_tables = set(get_existing_tables())
|
||||||
|
|
||||||
if "skill" not in existing_tables:
|
if 'skill' not in existing_tables:
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"skill",
|
'skill',
|
||||||
sa.Column("id", sa.String(), nullable=False, primary_key=True),
|
sa.Column('id', sa.String(), nullable=False, primary_key=True),
|
||||||
sa.Column("user_id", sa.String(), nullable=False),
|
sa.Column('user_id', sa.String(), nullable=False),
|
||||||
sa.Column("name", sa.Text(), nullable=False, unique=True),
|
sa.Column('name', sa.Text(), nullable=False, unique=True),
|
||||||
sa.Column("description", sa.Text(), nullable=True),
|
sa.Column('description', sa.Text(), nullable=True),
|
||||||
sa.Column("content", sa.Text(), nullable=False),
|
sa.Column('content', sa.Text(), nullable=False),
|
||||||
sa.Column("meta", sa.JSON(), nullable=True),
|
sa.Column('meta', sa.JSON(), nullable=True),
|
||||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
sa.Column('updated_at', sa.BigInteger(), nullable=False),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||||
)
|
)
|
||||||
op.create_index("idx_skill_user_id", "skill", ["user_id"])
|
op.create_index('idx_skill_user_id', 'skill', ['user_id'])
|
||||||
op.create_index("idx_skill_updated_at", "skill", ["updated_at"])
|
op.create_index('idx_skill_updated_at', 'skill', ['updated_at'])
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
op.drop_index("idx_skill_updated_at", table_name="skill")
|
op.drop_index('idx_skill_updated_at', table_name='skill')
|
||||||
op.drop_index("idx_skill_user_id", table_name="skill")
|
op.drop_index('idx_skill_user_id', table_name='skill')
|
||||||
op.drop_table("skill")
|
op.drop_table('skill')
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ from alembic import op
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "a5c220713937"
|
revision: str = 'a5c220713937'
|
||||||
down_revision: Union[str, None] = "38d63c18f30f"
|
down_revision: Union[str, None] = '38d63c18f30f'
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
@@ -21,14 +21,14 @@ depends_on: Union[str, Sequence[str], None] = None
|
|||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# Add 'reply_to_id' column to the 'message' table for replying to messages
|
# Add 'reply_to_id' column to the 'message' table for replying to messages
|
||||||
op.add_column(
|
op.add_column(
|
||||||
"message",
|
'message',
|
||||||
sa.Column("reply_to_id", sa.Text(), nullable=True),
|
sa.Column('reply_to_id', sa.Text(), nullable=True),
|
||||||
)
|
)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# Remove 'reply_to_id' column from the 'message' table
|
# Remove 'reply_to_id' column from the 'message' table
|
||||||
op.drop_column("message", "reply_to_id")
|
op.drop_column('message', 'reply_to_id')
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ from alembic import op
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
# Revision identifiers, used by Alembic.
|
# Revision identifiers, used by Alembic.
|
||||||
revision = "af906e964978"
|
revision = 'af906e964978'
|
||||||
down_revision = "c29facfe716b"
|
down_revision = 'c29facfe716b'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
@@ -19,33 +19,23 @@ depends_on = None
|
|||||||
def upgrade():
|
def upgrade():
|
||||||
# ### Create feedback table ###
|
# ### Create feedback table ###
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"feedback",
|
'feedback',
|
||||||
|
sa.Column('id', sa.Text(), primary_key=True), # Unique identifier for each feedback (TEXT type)
|
||||||
|
sa.Column('user_id', sa.Text(), nullable=True), # ID of the user providing the feedback (TEXT type)
|
||||||
|
sa.Column('version', sa.BigInteger(), default=0), # Version of feedback (BIGINT type)
|
||||||
|
sa.Column('type', sa.Text(), nullable=True), # Type of feedback (TEXT type)
|
||||||
|
sa.Column('data', sa.JSON(), nullable=True), # Feedback data (JSON type)
|
||||||
|
sa.Column('meta', sa.JSON(), nullable=True), # Metadata for feedback (JSON type)
|
||||||
|
sa.Column('snapshot', sa.JSON(), nullable=True), # snapshot data for feedback (JSON type)
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"id", sa.Text(), primary_key=True
|
'created_at', sa.BigInteger(), nullable=False
|
||||||
), # Unique identifier for each feedback (TEXT type)
|
|
||||||
sa.Column(
|
|
||||||
"user_id", sa.Text(), nullable=True
|
|
||||||
), # ID of the user providing the feedback (TEXT type)
|
|
||||||
sa.Column(
|
|
||||||
"version", sa.BigInteger(), default=0
|
|
||||||
), # Version of feedback (BIGINT type)
|
|
||||||
sa.Column("type", sa.Text(), nullable=True), # Type of feedback (TEXT type)
|
|
||||||
sa.Column("data", sa.JSON(), nullable=True), # Feedback data (JSON type)
|
|
||||||
sa.Column(
|
|
||||||
"meta", sa.JSON(), nullable=True
|
|
||||||
), # Metadata for feedback (JSON type)
|
|
||||||
sa.Column(
|
|
||||||
"snapshot", sa.JSON(), nullable=True
|
|
||||||
), # snapshot data for feedback (JSON type)
|
|
||||||
sa.Column(
|
|
||||||
"created_at", sa.BigInteger(), nullable=False
|
|
||||||
), # Feedback creation timestamp (BIGINT representing epoch)
|
), # Feedback creation timestamp (BIGINT representing epoch)
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"updated_at", sa.BigInteger(), nullable=False
|
'updated_at', sa.BigInteger(), nullable=False
|
||||||
), # Feedback update timestamp (BIGINT representing epoch)
|
), # Feedback update timestamp (BIGINT representing epoch)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
# ### Drop feedback table ###
|
# ### Drop feedback table ###
|
||||||
op.drop_table("feedback")
|
op.drop_table('feedback')
|
||||||
|
|||||||
@@ -17,8 +17,8 @@ import json
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "b10670c03dd5"
|
revision: str = 'b10670c03dd5'
|
||||||
down_revision: Union[str, None] = "2f1211949ecc"
|
down_revision: Union[str, None] = '2f1211949ecc'
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
@@ -33,13 +33,11 @@ def _drop_sqlite_indexes_for_column(table_name, column_name, conn):
|
|||||||
for idx in indexes:
|
for idx in indexes:
|
||||||
index_name = idx[1] # index name
|
index_name = idx[1] # index name
|
||||||
# Get indexed columns
|
# Get indexed columns
|
||||||
idx_info = conn.execute(
|
idx_info = conn.execute(sa.text(f"PRAGMA index_info('{index_name}')")).fetchall()
|
||||||
sa.text(f"PRAGMA index_info('{index_name}')")
|
|
||||||
).fetchall()
|
|
||||||
|
|
||||||
indexed_cols = [row[2] for row in idx_info] # col names
|
indexed_cols = [row[2] for row in idx_info] # col names
|
||||||
if column_name in indexed_cols:
|
if column_name in indexed_cols:
|
||||||
conn.execute(sa.text(f"DROP INDEX IF EXISTS {index_name}"))
|
conn.execute(sa.text(f'DROP INDEX IF EXISTS {index_name}'))
|
||||||
|
|
||||||
|
|
||||||
def _convert_column_to_json(table: str, column: str):
|
def _convert_column_to_json(table: str, column: str):
|
||||||
@@ -47,9 +45,9 @@ def _convert_column_to_json(table: str, column: str):
|
|||||||
dialect = conn.dialect.name
|
dialect = conn.dialect.name
|
||||||
|
|
||||||
# SQLite cannot ALTER COLUMN → must recreate column
|
# SQLite cannot ALTER COLUMN → must recreate column
|
||||||
if dialect == "sqlite":
|
if dialect == 'sqlite':
|
||||||
# 1. Add temporary column
|
# 1. Add temporary column
|
||||||
op.add_column(table, sa.Column(f"{column}_json", sa.JSON(), nullable=True))
|
op.add_column(table, sa.Column(f'{column}_json', sa.JSON(), nullable=True))
|
||||||
|
|
||||||
# 2. Load old data
|
# 2. Load old data
|
||||||
rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()
|
rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()
|
||||||
@@ -66,14 +64,14 @@ def _convert_column_to_json(table: str, column: str):
|
|||||||
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
sa.text(f'UPDATE "{table}" SET {column}_json = :val WHERE id = :id'),
|
sa.text(f'UPDATE "{table}" SET {column}_json = :val WHERE id = :id'),
|
||||||
{"val": json.dumps(parsed) if parsed else None, "id": uid},
|
{'val': json.dumps(parsed) if parsed else None, 'id': uid},
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. Drop old TEXT column
|
# 3. Drop old TEXT column
|
||||||
op.drop_column(table, column)
|
op.drop_column(table, column)
|
||||||
|
|
||||||
# 4. Rename new JSON column → original name
|
# 4. Rename new JSON column → original name
|
||||||
op.alter_column(table, f"{column}_json", new_column_name=column)
|
op.alter_column(table, f'{column}_json', new_column_name=column)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# PostgreSQL supports direct CAST
|
# PostgreSQL supports direct CAST
|
||||||
@@ -81,7 +79,7 @@ def _convert_column_to_json(table: str, column: str):
|
|||||||
table,
|
table,
|
||||||
column,
|
column,
|
||||||
type_=sa.JSON(),
|
type_=sa.JSON(),
|
||||||
postgresql_using=f"{column}::json",
|
postgresql_using=f'{column}::json',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -89,85 +87,77 @@ def _convert_column_to_text(table: str, column: str):
|
|||||||
conn = op.get_bind()
|
conn = op.get_bind()
|
||||||
dialect = conn.dialect.name
|
dialect = conn.dialect.name
|
||||||
|
|
||||||
if dialect == "sqlite":
|
if dialect == 'sqlite':
|
||||||
op.add_column(table, sa.Column(f"{column}_text", sa.Text(), nullable=True))
|
op.add_column(table, sa.Column(f'{column}_text', sa.Text(), nullable=True))
|
||||||
|
|
||||||
rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()
|
rows = conn.execute(sa.text(f'SELECT id, {column} FROM "{table}"')).fetchall()
|
||||||
|
|
||||||
for uid, raw in rows:
|
for uid, raw in rows:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
sa.text(f'UPDATE "{table}" SET {column}_text = :val WHERE id = :id'),
|
sa.text(f'UPDATE "{table}" SET {column}_text = :val WHERE id = :id'),
|
||||||
{"val": json.dumps(raw) if raw else None, "id": uid},
|
{'val': json.dumps(raw) if raw else None, 'id': uid},
|
||||||
)
|
)
|
||||||
|
|
||||||
op.drop_column(table, column)
|
op.drop_column(table, column)
|
||||||
op.alter_column(table, f"{column}_text", new_column_name=column)
|
op.alter_column(table, f'{column}_text', new_column_name=column)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
op.alter_column(
|
op.alter_column(
|
||||||
table,
|
table,
|
||||||
column,
|
column,
|
||||||
type_=sa.Text(),
|
type_=sa.Text(),
|
||||||
postgresql_using=f"to_json({column})::text",
|
postgresql_using=f'to_json({column})::text',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
op.add_column(
|
op.add_column('user', sa.Column('profile_banner_image_url', sa.Text(), nullable=True))
|
||||||
"user", sa.Column("profile_banner_image_url", sa.Text(), nullable=True)
|
op.add_column('user', sa.Column('timezone', sa.String(), nullable=True))
|
||||||
)
|
|
||||||
op.add_column("user", sa.Column("timezone", sa.String(), nullable=True))
|
|
||||||
|
|
||||||
op.add_column("user", sa.Column("presence_state", sa.String(), nullable=True))
|
op.add_column('user', sa.Column('presence_state', sa.String(), nullable=True))
|
||||||
op.add_column("user", sa.Column("status_emoji", sa.String(), nullable=True))
|
op.add_column('user', sa.Column('status_emoji', sa.String(), nullable=True))
|
||||||
op.add_column("user", sa.Column("status_message", sa.Text(), nullable=True))
|
op.add_column('user', sa.Column('status_message', sa.Text(), nullable=True))
|
||||||
op.add_column(
|
op.add_column('user', sa.Column('status_expires_at', sa.BigInteger(), nullable=True))
|
||||||
"user", sa.Column("status_expires_at", sa.BigInteger(), nullable=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
op.add_column("user", sa.Column("oauth", sa.JSON(), nullable=True))
|
op.add_column('user', sa.Column('oauth', sa.JSON(), nullable=True))
|
||||||
|
|
||||||
# Convert info (TEXT/JSONField) → JSON
|
# Convert info (TEXT/JSONField) → JSON
|
||||||
_convert_column_to_json("user", "info")
|
_convert_column_to_json('user', 'info')
|
||||||
# Convert settings (TEXT/JSONField) → JSON
|
# Convert settings (TEXT/JSONField) → JSON
|
||||||
_convert_column_to_json("user", "settings")
|
_convert_column_to_json('user', 'settings')
|
||||||
|
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"api_key",
|
'api_key',
|
||||||
sa.Column("id", sa.Text(), primary_key=True, unique=True),
|
sa.Column('id', sa.Text(), primary_key=True, unique=True),
|
||||||
sa.Column("user_id", sa.Text(), sa.ForeignKey("user.id", ondelete="CASCADE")),
|
sa.Column('user_id', sa.Text(), sa.ForeignKey('user.id', ondelete='CASCADE')),
|
||||||
sa.Column("key", sa.Text(), unique=True, nullable=False),
|
sa.Column('key', sa.Text(), unique=True, nullable=False),
|
||||||
sa.Column("data", sa.JSON(), nullable=True),
|
sa.Column('data', sa.JSON(), nullable=True),
|
||||||
sa.Column("expires_at", sa.BigInteger(), nullable=True),
|
sa.Column('expires_at', sa.BigInteger(), nullable=True),
|
||||||
sa.Column("last_used_at", sa.BigInteger(), nullable=True),
|
sa.Column('last_used_at', sa.BigInteger(), nullable=True),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
sa.Column('updated_at', sa.BigInteger(), nullable=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
conn = op.get_bind()
|
conn = op.get_bind()
|
||||||
users = conn.execute(
|
users = conn.execute(sa.text('SELECT id, oauth_sub FROM "user" WHERE oauth_sub IS NOT NULL')).fetchall()
|
||||||
sa.text('SELECT id, oauth_sub FROM "user" WHERE oauth_sub IS NOT NULL')
|
|
||||||
).fetchall()
|
|
||||||
|
|
||||||
for uid, oauth_sub in users:
|
for uid, oauth_sub in users:
|
||||||
if oauth_sub:
|
if oauth_sub:
|
||||||
# Example formats supported:
|
# Example formats supported:
|
||||||
# provider@sub
|
# provider@sub
|
||||||
# plain sub (stored as {"oidc": {"sub": sub}})
|
# plain sub (stored as {"oidc": {"sub": sub}})
|
||||||
if "@" in oauth_sub:
|
if '@' in oauth_sub:
|
||||||
provider, sub = oauth_sub.split("@", 1)
|
provider, sub = oauth_sub.split('@', 1)
|
||||||
else:
|
else:
|
||||||
provider, sub = "oidc", oauth_sub
|
provider, sub = 'oidc', oauth_sub
|
||||||
|
|
||||||
oauth_json = json.dumps({provider: {"sub": sub}})
|
oauth_json = json.dumps({provider: {'sub': sub}})
|
||||||
conn.execute(
|
conn.execute(
|
||||||
sa.text('UPDATE "user" SET oauth = :oauth WHERE id = :id'),
|
sa.text('UPDATE "user" SET oauth = :oauth WHERE id = :id'),
|
||||||
{"oauth": oauth_json, "id": uid},
|
{'oauth': oauth_json, 'id': uid},
|
||||||
)
|
)
|
||||||
|
|
||||||
users_with_keys = conn.execute(
|
users_with_keys = conn.execute(sa.text('SELECT id, api_key FROM "user" WHERE api_key IS NOT NULL')).fetchall()
|
||||||
sa.text('SELECT id, api_key FROM "user" WHERE api_key IS NOT NULL')
|
|
||||||
).fetchall()
|
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
|
|
||||||
for uid, api_key in users_with_keys:
|
for uid, api_key in users_with_keys:
|
||||||
@@ -178,72 +168,70 @@ def upgrade() -> None:
|
|||||||
VALUES (:id, :user_id, :key, :created_at, :updated_at)
|
VALUES (:id, :user_id, :key, :created_at, :updated_at)
|
||||||
"""),
|
"""),
|
||||||
{
|
{
|
||||||
"id": f"key_{uid}",
|
'id': f'key_{uid}',
|
||||||
"user_id": uid,
|
'user_id': uid,
|
||||||
"key": api_key,
|
'key': api_key,
|
||||||
"created_at": now,
|
'created_at': now,
|
||||||
"updated_at": now,
|
'updated_at': now,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if conn.dialect.name == "sqlite":
|
if conn.dialect.name == 'sqlite':
|
||||||
_drop_sqlite_indexes_for_column("user", "api_key", conn)
|
_drop_sqlite_indexes_for_column('user', 'api_key', conn)
|
||||||
_drop_sqlite_indexes_for_column("user", "oauth_sub", conn)
|
_drop_sqlite_indexes_for_column('user', 'oauth_sub', conn)
|
||||||
|
|
||||||
with op.batch_alter_table("user") as batch_op:
|
with op.batch_alter_table('user') as batch_op:
|
||||||
batch_op.drop_column("api_key")
|
batch_op.drop_column('api_key')
|
||||||
batch_op.drop_column("oauth_sub")
|
batch_op.drop_column('oauth_sub')
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# --- 1. Restore old oauth_sub column ---
|
# --- 1. Restore old oauth_sub column ---
|
||||||
op.add_column("user", sa.Column("oauth_sub", sa.Text(), nullable=True))
|
op.add_column('user', sa.Column('oauth_sub', sa.Text(), nullable=True))
|
||||||
|
|
||||||
conn = op.get_bind()
|
conn = op.get_bind()
|
||||||
users = conn.execute(
|
users = conn.execute(sa.text('SELECT id, oauth FROM "user" WHERE oauth IS NOT NULL')).fetchall()
|
||||||
sa.text('SELECT id, oauth FROM "user" WHERE oauth IS NOT NULL')
|
|
||||||
).fetchall()
|
|
||||||
|
|
||||||
for uid, oauth in users:
|
for uid, oauth in users:
|
||||||
try:
|
try:
|
||||||
data = json.loads(oauth)
|
data = json.loads(oauth)
|
||||||
provider = list(data.keys())[0]
|
provider = list(data.keys())[0]
|
||||||
sub = data[provider].get("sub")
|
sub = data[provider].get('sub')
|
||||||
oauth_sub = f"{provider}@{sub}"
|
oauth_sub = f'{provider}@{sub}'
|
||||||
except Exception:
|
except Exception:
|
||||||
oauth_sub = None
|
oauth_sub = None
|
||||||
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
sa.text('UPDATE "user" SET oauth_sub = :oauth_sub WHERE id = :id'),
|
sa.text('UPDATE "user" SET oauth_sub = :oauth_sub WHERE id = :id'),
|
||||||
{"oauth_sub": oauth_sub, "id": uid},
|
{'oauth_sub': oauth_sub, 'id': uid},
|
||||||
)
|
)
|
||||||
|
|
||||||
op.drop_column("user", "oauth")
|
op.drop_column('user', 'oauth')
|
||||||
|
|
||||||
# --- 2. Restore api_key field ---
|
# --- 2. Restore api_key field ---
|
||||||
op.add_column("user", sa.Column("api_key", sa.String(), nullable=True))
|
op.add_column('user', sa.Column('api_key', sa.String(), nullable=True))
|
||||||
|
|
||||||
# Restore values from api_key
|
# Restore values from api_key
|
||||||
keys = conn.execute(sa.text("SELECT user_id, key FROM api_key")).fetchall()
|
keys = conn.execute(sa.text('SELECT user_id, key FROM api_key')).fetchall()
|
||||||
for uid, key in keys:
|
for uid, key in keys:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
sa.text('UPDATE "user" SET api_key = :key WHERE id = :id'),
|
sa.text('UPDATE "user" SET api_key = :key WHERE id = :id'),
|
||||||
{"key": key, "id": uid},
|
{'key': key, 'id': uid},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Drop new table
|
# Drop new table
|
||||||
op.drop_table("api_key")
|
op.drop_table('api_key')
|
||||||
|
|
||||||
with op.batch_alter_table("user") as batch_op:
|
with op.batch_alter_table('user') as batch_op:
|
||||||
batch_op.drop_column("profile_banner_image_url")
|
batch_op.drop_column('profile_banner_image_url')
|
||||||
batch_op.drop_column("timezone")
|
batch_op.drop_column('timezone')
|
||||||
|
|
||||||
batch_op.drop_column("presence_state")
|
batch_op.drop_column('presence_state')
|
||||||
batch_op.drop_column("status_emoji")
|
batch_op.drop_column('status_emoji')
|
||||||
batch_op.drop_column("status_message")
|
batch_op.drop_column('status_message')
|
||||||
batch_op.drop_column("status_expires_at")
|
batch_op.drop_column('status_expires_at')
|
||||||
|
|
||||||
# Convert info (JSON) → TEXT
|
# Convert info (JSON) → TEXT
|
||||||
_convert_column_to_text("user", "info")
|
_convert_column_to_text('user', 'info')
|
||||||
# Convert settings (JSON) → TEXT
|
# Convert settings (JSON) → TEXT
|
||||||
_convert_column_to_text("user", "settings")
|
_convert_column_to_text('user', 'settings')
|
||||||
|
|||||||
@@ -12,15 +12,15 @@ from alembic import op
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "b2c3d4e5f6a7"
|
revision: str = 'b2c3d4e5f6a7'
|
||||||
down_revision: Union[str, None] = "a1b2c3d4e5f6"
|
down_revision: Union[str, None] = 'a1b2c3d4e5f6'
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
op.add_column("user", sa.Column("scim", sa.JSON(), nullable=True))
|
op.add_column('user', sa.Column('scim', sa.JSON(), nullable=True))
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
op.drop_column("user", "scim")
|
op.drop_column('user', 'scim')
|
||||||
|
|||||||
@@ -12,21 +12,21 @@ import sqlalchemy as sa
|
|||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "c0fbf31ca0db"
|
revision: str = 'c0fbf31ca0db'
|
||||||
down_revision: Union[str, None] = "ca81bd47c050"
|
down_revision: Union[str, None] = 'ca81bd47c050'
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.add_column("file", sa.Column("hash", sa.Text(), nullable=True))
|
op.add_column('file', sa.Column('hash', sa.Text(), nullable=True))
|
||||||
op.add_column("file", sa.Column("data", sa.JSON(), nullable=True))
|
op.add_column('file', sa.Column('data', sa.JSON(), nullable=True))
|
||||||
op.add_column("file", sa.Column("updated_at", sa.BigInteger(), nullable=True))
|
op.add_column('file', sa.Column('updated_at', sa.BigInteger(), nullable=True))
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.drop_column("file", "updated_at")
|
op.drop_column('file', 'updated_at')
|
||||||
op.drop_column("file", "data")
|
op.drop_column('file', 'data')
|
||||||
op.drop_column("file", "hash")
|
op.drop_column('file', 'hash')
|
||||||
|
|||||||
@@ -12,35 +12,33 @@ import json
|
|||||||
from sqlalchemy.sql import table, column
|
from sqlalchemy.sql import table, column
|
||||||
from sqlalchemy import String, Text, JSON, and_
|
from sqlalchemy import String, Text, JSON, and_
|
||||||
|
|
||||||
revision = "c29facfe716b"
|
revision = 'c29facfe716b'
|
||||||
down_revision = "c69f45358db4"
|
down_revision = 'c69f45358db4'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
# 1. Add the `path` column to the "file" table.
|
# 1. Add the `path` column to the "file" table.
|
||||||
op.add_column("file", sa.Column("path", sa.Text(), nullable=True))
|
op.add_column('file', sa.Column('path', sa.Text(), nullable=True))
|
||||||
|
|
||||||
# 2. Convert the `meta` column from Text/JSONField to `JSON()`
|
# 2. Convert the `meta` column from Text/JSONField to `JSON()`
|
||||||
# Use Alembic's default batch_op for dialect compatibility.
|
# Use Alembic's default batch_op for dialect compatibility.
|
||||||
with op.batch_alter_table("file", schema=None) as batch_op:
|
with op.batch_alter_table('file', schema=None) as batch_op:
|
||||||
batch_op.alter_column(
|
batch_op.alter_column(
|
||||||
"meta",
|
'meta',
|
||||||
type_=sa.JSON(),
|
type_=sa.JSON(),
|
||||||
existing_type=sa.Text(),
|
existing_type=sa.Text(),
|
||||||
existing_nullable=True,
|
existing_nullable=True,
|
||||||
nullable=True,
|
nullable=True,
|
||||||
postgresql_using="meta::json",
|
postgresql_using='meta::json',
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. Migrate legacy data from `meta` JSONField
|
# 3. Migrate legacy data from `meta` JSONField
|
||||||
# Fetch and process `meta` data from the table, add values to the new `path` column as necessary.
|
# Fetch and process `meta` data from the table, add values to the new `path` column as necessary.
|
||||||
# We will use SQLAlchemy core bindings to ensure safety across different databases.
|
# We will use SQLAlchemy core bindings to ensure safety across different databases.
|
||||||
|
|
||||||
file_table = table(
|
file_table = table('file', column('id', String), column('meta', JSON), column('path', Text))
|
||||||
"file", column("id", String), column("meta", JSON), column("path", Text)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create connection to the database
|
# Create connection to the database
|
||||||
connection = op.get_bind()
|
connection = op.get_bind()
|
||||||
@@ -55,24 +53,18 @@ def upgrade():
|
|||||||
|
|
||||||
# Iterate over each row to extract and update the `path` from `meta` column
|
# Iterate over each row to extract and update the `path` from `meta` column
|
||||||
for row in results:
|
for row in results:
|
||||||
if "path" in row.meta:
|
if 'path' in row.meta:
|
||||||
# Extract the `path` field from the `meta` JSON
|
# Extract the `path` field from the `meta` JSON
|
||||||
path = row.meta.get("path")
|
path = row.meta.get('path')
|
||||||
|
|
||||||
# Update the `file` table with the new `path` value
|
# Update the `file` table with the new `path` value
|
||||||
connection.execute(
|
connection.execute(file_table.update().where(file_table.c.id == row.id).values({'path': path}))
|
||||||
file_table.update()
|
|
||||||
.where(file_table.c.id == row.id)
|
|
||||||
.values({"path": path})
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
# 1. Remove the `path` column
|
# 1. Remove the `path` column
|
||||||
op.drop_column("file", "path")
|
op.drop_column('file', 'path')
|
||||||
|
|
||||||
# 2. Revert the `meta` column back to Text/JSONField
|
# 2. Revert the `meta` column back to Text/JSONField
|
||||||
with op.batch_alter_table("file", schema=None) as batch_op:
|
with op.batch_alter_table('file', schema=None) as batch_op:
|
||||||
batch_op.alter_column(
|
batch_op.alter_column('meta', type_=sa.Text(), existing_type=sa.JSON(), existing_nullable=True)
|
||||||
"meta", type_=sa.Text(), existing_type=sa.JSON(), existing_nullable=True
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -12,45 +12,43 @@ from alembic import op
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "c440947495f3"
|
revision: str = 'c440947495f3'
|
||||||
down_revision: Union[str, None] = "81cc2ce44d79"
|
down_revision: Union[str, None] = '81cc2ce44d79'
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"chat_file",
|
'chat_file',
|
||||||
sa.Column("id", sa.Text(), primary_key=True),
|
sa.Column('id', sa.Text(), primary_key=True),
|
||||||
sa.Column("user_id", sa.Text(), nullable=False),
|
sa.Column('user_id', sa.Text(), nullable=False),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"chat_id",
|
'chat_id',
|
||||||
sa.Text(),
|
sa.Text(),
|
||||||
sa.ForeignKey("chat.id", ondelete="CASCADE"),
|
sa.ForeignKey('chat.id', ondelete='CASCADE'),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
),
|
),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"file_id",
|
'file_id',
|
||||||
sa.Text(),
|
sa.Text(),
|
||||||
sa.ForeignKey("file.id", ondelete="CASCADE"),
|
sa.ForeignKey('file.id', ondelete='CASCADE'),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
),
|
),
|
||||||
sa.Column("message_id", sa.Text(), nullable=True),
|
sa.Column('message_id', sa.Text(), nullable=True),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||||
sa.Column("updated_at", sa.BigInteger(), nullable=False),
|
sa.Column('updated_at', sa.BigInteger(), nullable=False),
|
||||||
# indexes
|
# indexes
|
||||||
sa.Index("ix_chat_file_chat_id", "chat_id"),
|
sa.Index('ix_chat_file_chat_id', 'chat_id'),
|
||||||
sa.Index("ix_chat_file_file_id", "file_id"),
|
sa.Index('ix_chat_file_file_id', 'file_id'),
|
||||||
sa.Index("ix_chat_file_message_id", "message_id"),
|
sa.Index('ix_chat_file_message_id', 'message_id'),
|
||||||
sa.Index("ix_chat_file_user_id", "user_id"),
|
sa.Index('ix_chat_file_user_id', 'user_id'),
|
||||||
# unique constraints
|
# unique constraints
|
||||||
sa.UniqueConstraint(
|
sa.UniqueConstraint('chat_id', 'file_id', name='uq_chat_file_chat_file'), # prevent duplicate entries
|
||||||
"chat_id", "file_id", name="uq_chat_file_chat_file"
|
|
||||||
), # prevent duplicate entries
|
|
||||||
)
|
)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
op.drop_table("chat_file")
|
op.drop_table('chat_file')
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -9,42 +9,40 @@ Create Date: 2024-10-16 02:02:35.241684
|
|||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
revision = "c69f45358db4"
|
revision = 'c69f45358db4'
|
||||||
down_revision = "3ab32c4b8f59"
|
down_revision = '3ab32c4b8f59'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"folder",
|
'folder',
|
||||||
sa.Column("id", sa.Text(), nullable=False),
|
sa.Column('id', sa.Text(), nullable=False),
|
||||||
sa.Column("parent_id", sa.Text(), nullable=True),
|
sa.Column('parent_id', sa.Text(), nullable=True),
|
||||||
sa.Column("user_id", sa.Text(), nullable=False),
|
sa.Column('user_id', sa.Text(), nullable=False),
|
||||||
sa.Column("name", sa.Text(), nullable=False),
|
sa.Column('name', sa.Text(), nullable=False),
|
||||||
sa.Column("items", sa.JSON(), nullable=True),
|
sa.Column('items', sa.JSON(), nullable=True),
|
||||||
sa.Column("meta", sa.JSON(), nullable=True),
|
sa.Column('meta', sa.JSON(), nullable=True),
|
||||||
sa.Column("is_expanded", sa.Boolean(), default=False, nullable=False),
|
sa.Column('is_expanded', sa.Boolean(), default=False, nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.func.now(), nullable=False),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False
|
'updated_at',
|
||||||
),
|
|
||||||
sa.Column(
|
|
||||||
"updated_at",
|
|
||||||
sa.DateTime(),
|
sa.DateTime(),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
server_default=sa.func.now(),
|
server_default=sa.func.now(),
|
||||||
onupdate=sa.func.now(),
|
onupdate=sa.func.now(),
|
||||||
),
|
),
|
||||||
sa.PrimaryKeyConstraint("id", "user_id"),
|
sa.PrimaryKeyConstraint('id', 'user_id'),
|
||||||
)
|
)
|
||||||
|
|
||||||
op.add_column(
|
op.add_column(
|
||||||
"chat",
|
'chat',
|
||||||
sa.Column("folder_id", sa.Text(), nullable=True),
|
sa.Column('folder_id', sa.Text(), nullable=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.drop_column("chat", "folder_id")
|
op.drop_column('chat', 'folder_id')
|
||||||
|
|
||||||
op.drop_table("folder")
|
op.drop_table('folder')
|
||||||
|
|||||||
@@ -12,23 +12,21 @@ import sqlalchemy as sa
|
|||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = "ca81bd47c050"
|
revision: str = 'ca81bd47c050'
|
||||||
down_revision: Union[str, None] = "7e5b5dc7342b"
|
down_revision: Union[str, None] = '7e5b5dc7342b'
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"config",
|
'config',
|
||||||
sa.Column("id", sa.Integer, primary_key=True),
|
sa.Column('id', sa.Integer, primary_key=True),
|
||||||
sa.Column("data", sa.JSON(), nullable=False),
|
sa.Column('data', sa.JSON(), nullable=False),
|
||||||
sa.Column("version", sa.Integer, nullable=False),
|
sa.Column('version', sa.Integer, nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=False, server_default=sa.func.now()),
|
||||||
sa.Column(
|
sa.Column(
|
||||||
"created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()
|
'updated_at',
|
||||||
),
|
|
||||||
sa.Column(
|
|
||||||
"updated_at",
|
|
||||||
sa.DateTime(),
|
sa.DateTime(),
|
||||||
nullable=True,
|
nullable=True,
|
||||||
server_default=sa.func.now(),
|
server_default=sa.func.now(),
|
||||||
@@ -38,4 +36,4 @@ def upgrade():
|
|||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.drop_table("config")
|
op.drop_table('config')
|
||||||
|
|||||||
@@ -9,15 +9,15 @@ Create Date: 2025-07-13 03:00:00.000000
|
|||||||
from alembic import op
|
from alembic import op
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
revision = "d31026856c01"
|
revision = 'd31026856c01'
|
||||||
down_revision = "9f0c9cd09105"
|
down_revision = '9f0c9cd09105'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
op.add_column("folder", sa.Column("data", sa.JSON(), nullable=True))
|
op.add_column('folder', sa.Column('data', sa.JSON(), nullable=True))
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.drop_column("folder", "data")
|
op.drop_column('folder', 'data')
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ import sqlalchemy as sa
|
|||||||
|
|
||||||
from open_webui.migrations.util import get_existing_tables
|
from open_webui.migrations.util import get_existing_tables
|
||||||
|
|
||||||
revision: str = "f1e2d3c4b5a6"
|
revision: str = 'f1e2d3c4b5a6'
|
||||||
down_revision: Union[str, None] = "8452d01d26d7"
|
down_revision: Union[str, None] = '8452d01d26d7'
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
@@ -30,34 +30,34 @@ def upgrade() -> None:
|
|||||||
existing_tables = set(get_existing_tables())
|
existing_tables = set(get_existing_tables())
|
||||||
|
|
||||||
# Create access_grant table
|
# Create access_grant table
|
||||||
if "access_grant" not in existing_tables:
|
if 'access_grant' not in existing_tables:
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"access_grant",
|
'access_grant',
|
||||||
sa.Column("id", sa.Text(), nullable=False, primary_key=True),
|
sa.Column('id', sa.Text(), nullable=False, primary_key=True),
|
||||||
sa.Column("resource_type", sa.Text(), nullable=False),
|
sa.Column('resource_type', sa.Text(), nullable=False),
|
||||||
sa.Column("resource_id", sa.Text(), nullable=False),
|
sa.Column('resource_id', sa.Text(), nullable=False),
|
||||||
sa.Column("principal_type", sa.Text(), nullable=False),
|
sa.Column('principal_type', sa.Text(), nullable=False),
|
||||||
sa.Column("principal_id", sa.Text(), nullable=False),
|
sa.Column('principal_id', sa.Text(), nullable=False),
|
||||||
sa.Column("permission", sa.Text(), nullable=False),
|
sa.Column('permission', sa.Text(), nullable=False),
|
||||||
sa.Column("created_at", sa.BigInteger(), nullable=False),
|
sa.Column('created_at', sa.BigInteger(), nullable=False),
|
||||||
sa.UniqueConstraint(
|
sa.UniqueConstraint(
|
||||||
"resource_type",
|
'resource_type',
|
||||||
"resource_id",
|
'resource_id',
|
||||||
"principal_type",
|
'principal_type',
|
||||||
"principal_id",
|
'principal_id',
|
||||||
"permission",
|
'permission',
|
||||||
name="uq_access_grant_grant",
|
name='uq_access_grant_grant',
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
op.create_index(
|
op.create_index(
|
||||||
"idx_access_grant_resource",
|
'idx_access_grant_resource',
|
||||||
"access_grant",
|
'access_grant',
|
||||||
["resource_type", "resource_id"],
|
['resource_type', 'resource_id'],
|
||||||
)
|
)
|
||||||
op.create_index(
|
op.create_index(
|
||||||
"idx_access_grant_principal",
|
'idx_access_grant_principal',
|
||||||
"access_grant",
|
'access_grant',
|
||||||
["principal_type", "principal_id"],
|
['principal_type', 'principal_id'],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Backfill existing access_control JSON data
|
# Backfill existing access_control JSON data
|
||||||
@@ -65,13 +65,13 @@ def upgrade() -> None:
|
|||||||
|
|
||||||
# Tables with access_control JSON columns: (table_name, resource_type)
|
# Tables with access_control JSON columns: (table_name, resource_type)
|
||||||
resource_tables = [
|
resource_tables = [
|
||||||
("knowledge", "knowledge"),
|
('knowledge', 'knowledge'),
|
||||||
("prompt", "prompt"),
|
('prompt', 'prompt'),
|
||||||
("tool", "tool"),
|
('tool', 'tool'),
|
||||||
("model", "model"),
|
('model', 'model'),
|
||||||
("note", "note"),
|
('note', 'note'),
|
||||||
("channel", "channel"),
|
('channel', 'channel'),
|
||||||
("file", "file"),
|
('file', 'file'),
|
||||||
]
|
]
|
||||||
|
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
@@ -83,9 +83,7 @@ def upgrade() -> None:
|
|||||||
|
|
||||||
# Query all rows
|
# Query all rows
|
||||||
try:
|
try:
|
||||||
result = conn.execute(
|
result = conn.execute(sa.text(f'SELECT id, access_control FROM "{table_name}"'))
|
||||||
sa.text(f'SELECT id, access_control FROM "{table_name}"')
|
|
||||||
)
|
|
||||||
rows = result.fetchall()
|
rows = result.fetchall()
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
@@ -99,19 +97,16 @@ def upgrade() -> None:
|
|||||||
# EXCEPTION: files with NULL are PRIVATE (owner-only), not public
|
# EXCEPTION: files with NULL are PRIVATE (owner-only), not public
|
||||||
is_null = (
|
is_null = (
|
||||||
access_control_json is None
|
access_control_json is None
|
||||||
or access_control_json == "null"
|
or access_control_json == 'null'
|
||||||
or (
|
or (isinstance(access_control_json, str) and access_control_json.strip().lower() == 'null')
|
||||||
isinstance(access_control_json, str)
|
|
||||||
and access_control_json.strip().lower() == "null"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
if is_null:
|
if is_null:
|
||||||
# Files: NULL = private (no entry needed, owner has implicit access)
|
# Files: NULL = private (no entry needed, owner has implicit access)
|
||||||
# Other resources: NULL = public (insert user:* for read)
|
# Other resources: NULL = public (insert user:* for read)
|
||||||
if resource_type == "file":
|
if resource_type == 'file':
|
||||||
continue # Private - no entry needed
|
continue # Private - no entry needed
|
||||||
|
|
||||||
key = (resource_type, resource_id, "user", "*", "read")
|
key = (resource_type, resource_id, 'user', '*', 'read')
|
||||||
if key not in inserted:
|
if key not in inserted:
|
||||||
try:
|
try:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
@@ -120,13 +115,13 @@ def upgrade() -> None:
|
|||||||
VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at)
|
VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at)
|
||||||
"""),
|
"""),
|
||||||
{
|
{
|
||||||
"id": str(uuid.uuid4()),
|
'id': str(uuid.uuid4()),
|
||||||
"resource_type": resource_type,
|
'resource_type': resource_type,
|
||||||
"resource_id": resource_id,
|
'resource_id': resource_id,
|
||||||
"principal_type": "user",
|
'principal_type': 'user',
|
||||||
"principal_id": "*",
|
'principal_id': '*',
|
||||||
"permission": "read",
|
'permission': 'read',
|
||||||
"created_at": now,
|
'created_at': now,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
inserted.add(key)
|
inserted.add(key)
|
||||||
@@ -149,28 +144,24 @@ def upgrade() -> None:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# Check if it's effectively empty (no read/write keys with content)
|
# Check if it's effectively empty (no read/write keys with content)
|
||||||
read_data = access_control_json.get("read", {})
|
read_data = access_control_json.get('read', {})
|
||||||
write_data = access_control_json.get("write", {})
|
write_data = access_control_json.get('write', {})
|
||||||
|
|
||||||
has_read_grants = read_data.get("group_ids", []) or read_data.get(
|
has_read_grants = read_data.get('group_ids', []) or read_data.get('user_ids', [])
|
||||||
"user_ids", []
|
has_write_grants = write_data.get('group_ids', []) or write_data.get('user_ids', [])
|
||||||
)
|
|
||||||
has_write_grants = write_data.get("group_ids", []) or write_data.get(
|
|
||||||
"user_ids", []
|
|
||||||
)
|
|
||||||
|
|
||||||
if not has_read_grants and not has_write_grants:
|
if not has_read_grants and not has_write_grants:
|
||||||
# Empty permissions = private, no grants needed
|
# Empty permissions = private, no grants needed
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Extract permissions and insert into access_grant table
|
# Extract permissions and insert into access_grant table
|
||||||
for permission in ["read", "write"]:
|
for permission in ['read', 'write']:
|
||||||
perm_data = access_control_json.get(permission, {})
|
perm_data = access_control_json.get(permission, {})
|
||||||
if not perm_data:
|
if not perm_data:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for group_id in perm_data.get("group_ids", []):
|
for group_id in perm_data.get('group_ids', []):
|
||||||
key = (resource_type, resource_id, "group", group_id, permission)
|
key = (resource_type, resource_id, 'group', group_id, permission)
|
||||||
if key in inserted:
|
if key in inserted:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
@@ -180,21 +171,21 @@ def upgrade() -> None:
|
|||||||
VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at)
|
VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at)
|
||||||
"""),
|
"""),
|
||||||
{
|
{
|
||||||
"id": str(uuid.uuid4()),
|
'id': str(uuid.uuid4()),
|
||||||
"resource_type": resource_type,
|
'resource_type': resource_type,
|
||||||
"resource_id": resource_id,
|
'resource_id': resource_id,
|
||||||
"principal_type": "group",
|
'principal_type': 'group',
|
||||||
"principal_id": group_id,
|
'principal_id': group_id,
|
||||||
"permission": permission,
|
'permission': permission,
|
||||||
"created_at": now,
|
'created_at': now,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
inserted.add(key)
|
inserted.add(key)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
for user_id in perm_data.get("user_ids", []):
|
for user_id in perm_data.get('user_ids', []):
|
||||||
key = (resource_type, resource_id, "user", user_id, permission)
|
key = (resource_type, resource_id, 'user', user_id, permission)
|
||||||
if key in inserted:
|
if key in inserted:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
@@ -204,13 +195,13 @@ def upgrade() -> None:
|
|||||||
VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at)
|
VALUES (:id, :resource_type, :resource_id, :principal_type, :principal_id, :permission, :created_at)
|
||||||
"""),
|
"""),
|
||||||
{
|
{
|
||||||
"id": str(uuid.uuid4()),
|
'id': str(uuid.uuid4()),
|
||||||
"resource_type": resource_type,
|
'resource_type': resource_type,
|
||||||
"resource_id": resource_id,
|
'resource_id': resource_id,
|
||||||
"principal_type": "user",
|
'principal_type': 'user',
|
||||||
"principal_id": user_id,
|
'principal_id': user_id,
|
||||||
"permission": permission,
|
'permission': permission,
|
||||||
"created_at": now,
|
'created_at': now,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
inserted.add(key)
|
inserted.add(key)
|
||||||
@@ -223,7 +214,7 @@ def upgrade() -> None:
|
|||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
with op.batch_alter_table(table_name) as batch:
|
with op.batch_alter_table(table_name) as batch:
|
||||||
batch.drop_column("access_control")
|
batch.drop_column('access_control')
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -235,20 +226,20 @@ def downgrade() -> None:
|
|||||||
|
|
||||||
# Resource tables mapping: (table_name, resource_type)
|
# Resource tables mapping: (table_name, resource_type)
|
||||||
resource_tables = [
|
resource_tables = [
|
||||||
("knowledge", "knowledge"),
|
('knowledge', 'knowledge'),
|
||||||
("prompt", "prompt"),
|
('prompt', 'prompt'),
|
||||||
("tool", "tool"),
|
('tool', 'tool'),
|
||||||
("model", "model"),
|
('model', 'model'),
|
||||||
("note", "note"),
|
('note', 'note'),
|
||||||
("channel", "channel"),
|
('channel', 'channel'),
|
||||||
("file", "file"),
|
('file', 'file'),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Step 1: Re-add access_control columns to resource tables
|
# Step 1: Re-add access_control columns to resource tables
|
||||||
for table_name, _ in resource_tables:
|
for table_name, _ in resource_tables:
|
||||||
try:
|
try:
|
||||||
with op.batch_alter_table(table_name) as batch:
|
with op.batch_alter_table(table_name) as batch:
|
||||||
batch.add_column(sa.Column("access_control", sa.JSON(), nullable=True))
|
batch.add_column(sa.Column('access_control', sa.JSON(), nullable=True))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -262,7 +253,7 @@ def downgrade() -> None:
|
|||||||
FROM access_grant
|
FROM access_grant
|
||||||
WHERE resource_type = :resource_type
|
WHERE resource_type = :resource_type
|
||||||
"""),
|
"""),
|
||||||
{"resource_type": resource_type},
|
{'resource_type': resource_type},
|
||||||
)
|
)
|
||||||
rows = result.fetchall()
|
rows = result.fetchall()
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -278,49 +269,35 @@ def downgrade() -> None:
|
|||||||
|
|
||||||
if resource_id not in resource_grants:
|
if resource_id not in resource_grants:
|
||||||
resource_grants[resource_id] = {
|
resource_grants[resource_id] = {
|
||||||
"is_public": False,
|
'is_public': False,
|
||||||
"read": {"group_ids": [], "user_ids": []},
|
'read': {'group_ids': [], 'user_ids': []},
|
||||||
"write": {"group_ids": [], "user_ids": []},
|
'write': {'group_ids': [], 'user_ids': []},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Handle public access (user:* for read)
|
# Handle public access (user:* for read)
|
||||||
if (
|
if principal_type == 'user' and principal_id == '*' and permission == 'read':
|
||||||
principal_type == "user"
|
resource_grants[resource_id]['is_public'] = True
|
||||||
and principal_id == "*"
|
|
||||||
and permission == "read"
|
|
||||||
):
|
|
||||||
resource_grants[resource_id]["is_public"] = True
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Add to appropriate list
|
# Add to appropriate list
|
||||||
if permission in ["read", "write"]:
|
if permission in ['read', 'write']:
|
||||||
if principal_type == "group":
|
if principal_type == 'group':
|
||||||
if (
|
if principal_id not in resource_grants[resource_id][permission]['group_ids']:
|
||||||
principal_id
|
resource_grants[resource_id][permission]['group_ids'].append(principal_id)
|
||||||
not in resource_grants[resource_id][permission]["group_ids"]
|
elif principal_type == 'user':
|
||||||
):
|
if principal_id not in resource_grants[resource_id][permission]['user_ids']:
|
||||||
resource_grants[resource_id][permission]["group_ids"].append(
|
resource_grants[resource_id][permission]['user_ids'].append(principal_id)
|
||||||
principal_id
|
|
||||||
)
|
|
||||||
elif principal_type == "user":
|
|
||||||
if (
|
|
||||||
principal_id
|
|
||||||
not in resource_grants[resource_id][permission]["user_ids"]
|
|
||||||
):
|
|
||||||
resource_grants[resource_id][permission]["user_ids"].append(
|
|
||||||
principal_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 3: Update each resource with reconstructed JSON
|
# Step 3: Update each resource with reconstructed JSON
|
||||||
for resource_id, grants in resource_grants.items():
|
for resource_id, grants in resource_grants.items():
|
||||||
if grants["is_public"]:
|
if grants['is_public']:
|
||||||
# Public = NULL
|
# Public = NULL
|
||||||
access_control_value = None
|
access_control_value = None
|
||||||
elif (
|
elif (
|
||||||
not grants["read"]["group_ids"]
|
not grants['read']['group_ids']
|
||||||
and not grants["read"]["user_ids"]
|
and not grants['read']['user_ids']
|
||||||
and not grants["write"]["group_ids"]
|
and not grants['write']['group_ids']
|
||||||
and not grants["write"]["user_ids"]
|
and not grants['write']['user_ids']
|
||||||
):
|
):
|
||||||
# No grants = should not happen (would mean no entries), default to {}
|
# No grants = should not happen (would mean no entries), default to {}
|
||||||
access_control_value = json.dumps({})
|
access_control_value = json.dumps({})
|
||||||
@@ -328,17 +305,15 @@ def downgrade() -> None:
|
|||||||
# Custom permissions
|
# Custom permissions
|
||||||
access_control_value = json.dumps(
|
access_control_value = json.dumps(
|
||||||
{
|
{
|
||||||
"read": grants["read"],
|
'read': grants['read'],
|
||||||
"write": grants["write"],
|
'write': grants['write'],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
sa.text(
|
sa.text(f'UPDATE "{table_name}" SET access_control = :access_control WHERE id = :id'),
|
||||||
f'UPDATE "{table_name}" SET access_control = :access_control WHERE id = :id'
|
{'access_control': access_control_value, 'id': resource_id},
|
||||||
),
|
|
||||||
{"access_control": access_control_value, "id": resource_id},
|
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
@@ -346,7 +321,7 @@ def downgrade() -> None:
|
|||||||
# Step 4: Set all resources WITHOUT entries to private
|
# Step 4: Set all resources WITHOUT entries to private
|
||||||
# For files: NULL means private (owner-only), so leave as NULL
|
# For files: NULL means private (owner-only), so leave as NULL
|
||||||
# For other resources: {} means private, so update to {}
|
# For other resources: {} means private, so update to {}
|
||||||
if resource_type != "file":
|
if resource_type != 'file':
|
||||||
try:
|
try:
|
||||||
conn.execute(
|
conn.execute(
|
||||||
sa.text(f"""
|
sa.text(f"""
|
||||||
@@ -357,13 +332,13 @@ def downgrade() -> None:
|
|||||||
)
|
)
|
||||||
AND access_control IS NULL
|
AND access_control IS NULL
|
||||||
"""),
|
"""),
|
||||||
{"private_value": json.dumps({}), "resource_type": resource_type},
|
{'private_value': json.dumps({}), 'resource_type': resource_type},
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
# For files, NULL stays NULL - no action needed
|
# For files, NULL stays NULL - no action needed
|
||||||
|
|
||||||
# Step 5: Drop the access_grant table
|
# Step 5: Drop the access_grant table
|
||||||
op.drop_index("idx_access_grant_principal", table_name="access_grant")
|
op.drop_index('idx_access_grant_principal', table_name='access_grant')
|
||||||
op.drop_index("idx_access_grant_resource", table_name="access_grant")
|
op.drop_index('idx_access_grant_resource', table_name='access_grant')
|
||||||
op.drop_table("access_grant")
|
op.drop_table('access_grant')
|
||||||
|
|||||||
@@ -19,28 +19,24 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class AccessGrant(Base):
|
class AccessGrant(Base):
|
||||||
__tablename__ = "access_grant"
|
__tablename__ = 'access_grant'
|
||||||
|
|
||||||
id = Column(Text, primary_key=True)
|
id = Column(Text, primary_key=True)
|
||||||
resource_type = Column(
|
resource_type = Column(Text, nullable=False) # "knowledge", "model", "prompt", "tool", "note", "channel", "file"
|
||||||
Text, nullable=False
|
|
||||||
) # "knowledge", "model", "prompt", "tool", "note", "channel", "file"
|
|
||||||
resource_id = Column(Text, nullable=False)
|
resource_id = Column(Text, nullable=False)
|
||||||
principal_type = Column(Text, nullable=False) # "user" or "group"
|
principal_type = Column(Text, nullable=False) # "user" or "group"
|
||||||
principal_id = Column(
|
principal_id = Column(Text, nullable=False) # user_id, group_id, or "*" (wildcard for public)
|
||||||
Text, nullable=False
|
|
||||||
) # user_id, group_id, or "*" (wildcard for public)
|
|
||||||
permission = Column(Text, nullable=False) # "read" or "write"
|
permission = Column(Text, nullable=False) # "read" or "write"
|
||||||
created_at = Column(BigInteger, nullable=False)
|
created_at = Column(BigInteger, nullable=False)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
UniqueConstraint(
|
UniqueConstraint(
|
||||||
"resource_type",
|
'resource_type',
|
||||||
"resource_id",
|
'resource_id',
|
||||||
"principal_type",
|
'principal_type',
|
||||||
"principal_id",
|
'principal_id',
|
||||||
"permission",
|
'permission',
|
||||||
name="uq_access_grant_grant",
|
name='uq_access_grant_grant',
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -66,7 +62,7 @@ class AccessGrantResponse(BaseModel):
|
|||||||
permission: str
|
permission: str
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_grant(cls, grant: "AccessGrantModel") -> "AccessGrantResponse":
|
def from_grant(cls, grant: 'AccessGrantModel') -> 'AccessGrantResponse':
|
||||||
return cls(
|
return cls(
|
||||||
id=grant.id,
|
id=grant.id,
|
||||||
principal_type=grant.principal_type,
|
principal_type=grant.principal_type,
|
||||||
@@ -100,14 +96,14 @@ def access_control_to_grants(
|
|||||||
if access_control is None:
|
if access_control is None:
|
||||||
# NULL → public read (user:* for read)
|
# NULL → public read (user:* for read)
|
||||||
# Exception: files with NULL are private (owner-only), no grants needed
|
# Exception: files with NULL are private (owner-only), no grants needed
|
||||||
if resource_type != "file":
|
if resource_type != 'file':
|
||||||
grants.append(
|
grants.append(
|
||||||
{
|
{
|
||||||
"resource_type": resource_type,
|
'resource_type': resource_type,
|
||||||
"resource_id": resource_id,
|
'resource_id': resource_id,
|
||||||
"principal_type": "user",
|
'principal_type': 'user',
|
||||||
"principal_id": "*",
|
'principal_id': '*',
|
||||||
"permission": "read",
|
'permission': 'read',
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return grants
|
return grants
|
||||||
@@ -117,30 +113,30 @@ def access_control_to_grants(
|
|||||||
return grants
|
return grants
|
||||||
|
|
||||||
# Parse structured permissions
|
# Parse structured permissions
|
||||||
for permission in ["read", "write"]:
|
for permission in ['read', 'write']:
|
||||||
perm_data = access_control.get(permission, {})
|
perm_data = access_control.get(permission, {})
|
||||||
if not perm_data:
|
if not perm_data:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for group_id in perm_data.get("group_ids", []):
|
for group_id in perm_data.get('group_ids', []):
|
||||||
grants.append(
|
grants.append(
|
||||||
{
|
{
|
||||||
"resource_type": resource_type,
|
'resource_type': resource_type,
|
||||||
"resource_id": resource_id,
|
'resource_id': resource_id,
|
||||||
"principal_type": "group",
|
'principal_type': 'group',
|
||||||
"principal_id": group_id,
|
'principal_id': group_id,
|
||||||
"permission": permission,
|
'permission': permission,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
for user_id in perm_data.get("user_ids", []):
|
for user_id in perm_data.get('user_ids', []):
|
||||||
grants.append(
|
grants.append(
|
||||||
{
|
{
|
||||||
"resource_type": resource_type,
|
'resource_type': resource_type,
|
||||||
"resource_id": resource_id,
|
'resource_id': resource_id,
|
||||||
"principal_type": "user",
|
'principal_type': 'user',
|
||||||
"principal_id": user_id,
|
'principal_id': user_id,
|
||||||
"permission": permission,
|
'permission': permission,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -164,27 +160,23 @@ def normalize_access_grants(access_grants: Optional[list]) -> list[dict]:
|
|||||||
if not isinstance(grant, dict):
|
if not isinstance(grant, dict):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
principal_type = grant.get("principal_type")
|
principal_type = grant.get('principal_type')
|
||||||
principal_id = grant.get("principal_id")
|
principal_id = grant.get('principal_id')
|
||||||
permission = grant.get("permission")
|
permission = grant.get('permission')
|
||||||
|
|
||||||
if principal_type not in ("user", "group"):
|
if principal_type not in ('user', 'group'):
|
||||||
continue
|
continue
|
||||||
if permission not in ("read", "write"):
|
if permission not in ('read', 'write'):
|
||||||
continue
|
continue
|
||||||
if not isinstance(principal_id, str) or not principal_id:
|
if not isinstance(principal_id, str) or not principal_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
key = (principal_type, principal_id, permission)
|
key = (principal_type, principal_id, permission)
|
||||||
deduped[key] = {
|
deduped[key] = {
|
||||||
"id": (
|
'id': (grant.get('id') if isinstance(grant.get('id'), str) and grant.get('id') else str(uuid.uuid4())),
|
||||||
grant.get("id")
|
'principal_type': principal_type,
|
||||||
if isinstance(grant.get("id"), str) and grant.get("id")
|
'principal_id': principal_id,
|
||||||
else str(uuid.uuid4())
|
'permission': permission,
|
||||||
),
|
|
||||||
"principal_type": principal_type,
|
|
||||||
"principal_id": principal_id,
|
|
||||||
"permission": permission,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return list(deduped.values())
|
return list(deduped.values())
|
||||||
@@ -195,11 +187,7 @@ def has_public_read_access_grant(access_grants: Optional[list]) -> bool:
|
|||||||
Returns True when a direct grant list includes wildcard public-read.
|
Returns True when a direct grant list includes wildcard public-read.
|
||||||
"""
|
"""
|
||||||
for grant in normalize_access_grants(access_grants):
|
for grant in normalize_access_grants(access_grants):
|
||||||
if (
|
if grant['principal_type'] == 'user' and grant['principal_id'] == '*' and grant['permission'] == 'read':
|
||||||
grant["principal_type"] == "user"
|
|
||||||
and grant["principal_id"] == "*"
|
|
||||||
and grant["permission"] == "read"
|
|
||||||
):
|
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -209,7 +197,7 @@ def has_user_access_grant(access_grants: Optional[list]) -> bool:
|
|||||||
Returns True when a direct grant list includes any non-wildcard user grant.
|
Returns True when a direct grant list includes any non-wildcard user grant.
|
||||||
"""
|
"""
|
||||||
for grant in normalize_access_grants(access_grants):
|
for grant in normalize_access_grants(access_grants):
|
||||||
if grant["principal_type"] == "user" and grant["principal_id"] != "*":
|
if grant['principal_type'] == 'user' and grant['principal_id'] != '*':
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -225,18 +213,9 @@ def strip_user_access_grants(access_grants: Optional[list]) -> list:
|
|||||||
grant
|
grant
|
||||||
for grant in access_grants
|
for grant in access_grants
|
||||||
if not (
|
if not (
|
||||||
(
|
(grant.get('principal_type') if isinstance(grant, dict) else getattr(grant, 'principal_type', None))
|
||||||
grant.get("principal_type")
|
== 'user'
|
||||||
if isinstance(grant, dict)
|
and (grant.get('principal_id') if isinstance(grant, dict) else getattr(grant, 'principal_id', None)) != '*'
|
||||||
else getattr(grant, "principal_type", None)
|
|
||||||
)
|
|
||||||
== "user"
|
|
||||||
and (
|
|
||||||
grant.get("principal_id")
|
|
||||||
if isinstance(grant, dict)
|
|
||||||
else getattr(grant, "principal_id", None)
|
|
||||||
)
|
|
||||||
!= "*"
|
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -260,29 +239,25 @@ def grants_to_access_control(grants: list) -> Optional[dict]:
|
|||||||
return {} # No grants = private/owner-only
|
return {} # No grants = private/owner-only
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"read": {"group_ids": [], "user_ids": []},
|
'read': {'group_ids': [], 'user_ids': []},
|
||||||
"write": {"group_ids": [], "user_ids": []},
|
'write': {'group_ids': [], 'user_ids': []},
|
||||||
}
|
}
|
||||||
|
|
||||||
is_public = False
|
is_public = False
|
||||||
for grant in grants:
|
for grant in grants:
|
||||||
if (
|
if grant.principal_type == 'user' and grant.principal_id == '*' and grant.permission == 'read':
|
||||||
grant.principal_type == "user"
|
|
||||||
and grant.principal_id == "*"
|
|
||||||
and grant.permission == "read"
|
|
||||||
):
|
|
||||||
is_public = True
|
is_public = True
|
||||||
continue # Don't add wildcard to user_ids list
|
continue # Don't add wildcard to user_ids list
|
||||||
|
|
||||||
if grant.permission not in ("read", "write"):
|
if grant.permission not in ('read', 'write'):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if grant.principal_type == "group":
|
if grant.principal_type == 'group':
|
||||||
if grant.principal_id not in result[grant.permission]["group_ids"]:
|
if grant.principal_id not in result[grant.permission]['group_ids']:
|
||||||
result[grant.permission]["group_ids"].append(grant.principal_id)
|
result[grant.permission]['group_ids'].append(grant.principal_id)
|
||||||
elif grant.principal_type == "user":
|
elif grant.principal_type == 'user':
|
||||||
if grant.principal_id not in result[grant.permission]["user_ids"]:
|
if grant.principal_id not in result[grant.permission]['user_ids']:
|
||||||
result[grant.permission]["user_ids"].append(grant.principal_id)
|
result[grant.permission]['user_ids'].append(grant.principal_id)
|
||||||
|
|
||||||
if is_public:
|
if is_public:
|
||||||
return None # Public read access
|
return None # Public read access
|
||||||
@@ -399,9 +374,7 @@ class AccessGrantsTable:
|
|||||||
).delete()
|
).delete()
|
||||||
|
|
||||||
# Convert JSON to grant dicts
|
# Convert JSON to grant dicts
|
||||||
grant_dicts = access_control_to_grants(
|
grant_dicts = access_control_to_grants(resource_type, resource_id, access_control)
|
||||||
resource_type, resource_id, access_control
|
|
||||||
)
|
|
||||||
|
|
||||||
# Insert new grants
|
# Insert new grants
|
||||||
results = []
|
results = []
|
||||||
@@ -442,9 +415,9 @@ class AccessGrantsTable:
|
|||||||
id=str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
resource_type=resource_type,
|
resource_type=resource_type,
|
||||||
resource_id=resource_id,
|
resource_id=resource_id,
|
||||||
principal_type=grant_dict["principal_type"],
|
principal_type=grant_dict['principal_type'],
|
||||||
principal_id=grant_dict["principal_id"],
|
principal_id=grant_dict['principal_id'],
|
||||||
permission=grant_dict["permission"],
|
permission=grant_dict['permission'],
|
||||||
created_at=int(time.time()),
|
created_at=int(time.time()),
|
||||||
)
|
)
|
||||||
db.add(grant)
|
db.add(grant)
|
||||||
@@ -511,9 +484,7 @@ class AccessGrantsTable:
|
|||||||
)
|
)
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
result: dict[str, list[AccessGrantModel]] = {
|
result: dict[str, list[AccessGrantModel]] = {rid: [] for rid in resource_ids}
|
||||||
rid: [] for rid in resource_ids
|
|
||||||
}
|
|
||||||
for g in grants:
|
for g in grants:
|
||||||
result[g.resource_id].append(AccessGrantModel.model_validate(g))
|
result[g.resource_id].append(AccessGrantModel.model_validate(g))
|
||||||
return result
|
return result
|
||||||
@@ -523,7 +494,7 @@ class AccessGrantsTable:
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
resource_type: str,
|
resource_type: str,
|
||||||
resource_id: str,
|
resource_id: str,
|
||||||
permission: str = "read",
|
permission: str = 'read',
|
||||||
user_group_ids: Optional[set[str]] = None,
|
user_group_ids: Optional[set[str]] = None,
|
||||||
db: Optional[Session] = None,
|
db: Optional[Session] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
@@ -540,12 +511,12 @@ class AccessGrantsTable:
|
|||||||
conditions = [
|
conditions = [
|
||||||
# Public access
|
# Public access
|
||||||
and_(
|
and_(
|
||||||
AccessGrant.principal_type == "user",
|
AccessGrant.principal_type == 'user',
|
||||||
AccessGrant.principal_id == "*",
|
AccessGrant.principal_id == '*',
|
||||||
),
|
),
|
||||||
# Direct user access
|
# Direct user access
|
||||||
and_(
|
and_(
|
||||||
AccessGrant.principal_type == "user",
|
AccessGrant.principal_type == 'user',
|
||||||
AccessGrant.principal_id == user_id,
|
AccessGrant.principal_id == user_id,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
@@ -560,7 +531,7 @@ class AccessGrantsTable:
|
|||||||
if user_group_ids:
|
if user_group_ids:
|
||||||
conditions.append(
|
conditions.append(
|
||||||
and_(
|
and_(
|
||||||
AccessGrant.principal_type == "group",
|
AccessGrant.principal_type == 'group',
|
||||||
AccessGrant.principal_id.in_(user_group_ids),
|
AccessGrant.principal_id.in_(user_group_ids),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -582,7 +553,7 @@ class AccessGrantsTable:
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
resource_type: str,
|
resource_type: str,
|
||||||
resource_ids: list[str],
|
resource_ids: list[str],
|
||||||
permission: str = "read",
|
permission: str = 'read',
|
||||||
user_group_ids: Optional[set[str]] = None,
|
user_group_ids: Optional[set[str]] = None,
|
||||||
db: Optional[Session] = None,
|
db: Optional[Session] = None,
|
||||||
) -> set[str]:
|
) -> set[str]:
|
||||||
@@ -597,11 +568,11 @@ class AccessGrantsTable:
|
|||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
conditions = [
|
conditions = [
|
||||||
and_(
|
and_(
|
||||||
AccessGrant.principal_type == "user",
|
AccessGrant.principal_type == 'user',
|
||||||
AccessGrant.principal_id == "*",
|
AccessGrant.principal_id == '*',
|
||||||
),
|
),
|
||||||
and_(
|
and_(
|
||||||
AccessGrant.principal_type == "user",
|
AccessGrant.principal_type == 'user',
|
||||||
AccessGrant.principal_id == user_id,
|
AccessGrant.principal_id == user_id,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
@@ -615,7 +586,7 @@ class AccessGrantsTable:
|
|||||||
if user_group_ids:
|
if user_group_ids:
|
||||||
conditions.append(
|
conditions.append(
|
||||||
and_(
|
and_(
|
||||||
AccessGrant.principal_type == "group",
|
AccessGrant.principal_type == 'group',
|
||||||
AccessGrant.principal_id.in_(user_group_ids),
|
AccessGrant.principal_id.in_(user_group_ids),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -637,7 +608,7 @@ class AccessGrantsTable:
|
|||||||
self,
|
self,
|
||||||
resource_type: str,
|
resource_type: str,
|
||||||
resource_id: str,
|
resource_id: str,
|
||||||
permission: str = "read",
|
permission: str = 'read',
|
||||||
db: Optional[Session] = None,
|
db: Optional[Session] = None,
|
||||||
) -> list:
|
) -> list:
|
||||||
"""
|
"""
|
||||||
@@ -660,19 +631,17 @@ class AccessGrantsTable:
|
|||||||
|
|
||||||
# Check for public access
|
# Check for public access
|
||||||
for grant in grants:
|
for grant in grants:
|
||||||
if grant.principal_type == "user" and grant.principal_id == "*":
|
if grant.principal_type == 'user' and grant.principal_id == '*':
|
||||||
result = Users.get_users(filter={"roles": ["!pending"]}, db=db)
|
result = Users.get_users(filter={'roles': ['!pending']}, db=db)
|
||||||
return result.get("users", [])
|
return result.get('users', [])
|
||||||
|
|
||||||
user_ids_with_access = set()
|
user_ids_with_access = set()
|
||||||
|
|
||||||
for grant in grants:
|
for grant in grants:
|
||||||
if grant.principal_type == "user":
|
if grant.principal_type == 'user':
|
||||||
user_ids_with_access.add(grant.principal_id)
|
user_ids_with_access.add(grant.principal_id)
|
||||||
elif grant.principal_type == "group":
|
elif grant.principal_type == 'group':
|
||||||
group_user_ids = Groups.get_group_user_ids_by_id(
|
group_user_ids = Groups.get_group_user_ids_by_id(grant.principal_id, db=db)
|
||||||
grant.principal_id, db=db
|
|
||||||
)
|
|
||||||
if group_user_ids:
|
if group_user_ids:
|
||||||
user_ids_with_access.update(group_user_ids)
|
user_ids_with_access.update(group_user_ids)
|
||||||
|
|
||||||
@@ -688,20 +657,18 @@ class AccessGrantsTable:
|
|||||||
DocumentModel,
|
DocumentModel,
|
||||||
filter: dict,
|
filter: dict,
|
||||||
resource_type: str,
|
resource_type: str,
|
||||||
permission: str = "read",
|
permission: str = 'read',
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Apply access control filtering to a SQLAlchemy query by JOINing with access_grant.
|
Apply access control filtering to a SQLAlchemy query by JOINing with access_grant.
|
||||||
|
|
||||||
This replaces the old JSON-column-based filtering with a proper relational JOIN.
|
This replaces the old JSON-column-based filtering with a proper relational JOIN.
|
||||||
"""
|
"""
|
||||||
group_ids = filter.get("group_ids", [])
|
group_ids = filter.get('group_ids', [])
|
||||||
user_id = filter.get("user_id")
|
user_id = filter.get('user_id')
|
||||||
|
|
||||||
if permission == "read_only":
|
if permission == 'read_only':
|
||||||
return self._has_read_only_permission_filter(
|
return self._has_read_only_permission_filter(db, query, DocumentModel, filter, resource_type)
|
||||||
db, query, DocumentModel, filter, resource_type
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build principal conditions
|
# Build principal conditions
|
||||||
principal_conditions = []
|
principal_conditions = []
|
||||||
@@ -710,8 +677,8 @@ class AccessGrantsTable:
|
|||||||
# Public access: user:* read
|
# Public access: user:* read
|
||||||
principal_conditions.append(
|
principal_conditions.append(
|
||||||
and_(
|
and_(
|
||||||
AccessGrant.principal_type == "user",
|
AccessGrant.principal_type == 'user',
|
||||||
AccessGrant.principal_id == "*",
|
AccessGrant.principal_id == '*',
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -722,7 +689,7 @@ class AccessGrantsTable:
|
|||||||
# Direct user grant
|
# Direct user grant
|
||||||
principal_conditions.append(
|
principal_conditions.append(
|
||||||
and_(
|
and_(
|
||||||
AccessGrant.principal_type == "user",
|
AccessGrant.principal_type == 'user',
|
||||||
AccessGrant.principal_id == user_id,
|
AccessGrant.principal_id == user_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -731,7 +698,7 @@ class AccessGrantsTable:
|
|||||||
# Group grants
|
# Group grants
|
||||||
principal_conditions.append(
|
principal_conditions.append(
|
||||||
and_(
|
and_(
|
||||||
AccessGrant.principal_type == "group",
|
AccessGrant.principal_type == 'group',
|
||||||
AccessGrant.principal_id.in_(group_ids),
|
AccessGrant.principal_id.in_(group_ids),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -751,13 +718,13 @@ class AccessGrantsTable:
|
|||||||
AccessGrant.permission == permission,
|
AccessGrant.permission == permission,
|
||||||
or_(
|
or_(
|
||||||
and_(
|
and_(
|
||||||
AccessGrant.principal_type == "user",
|
AccessGrant.principal_type == 'user',
|
||||||
AccessGrant.principal_id == "*",
|
AccessGrant.principal_id == '*',
|
||||||
),
|
),
|
||||||
*(
|
*(
|
||||||
[
|
[
|
||||||
and_(
|
and_(
|
||||||
AccessGrant.principal_type == "user",
|
AccessGrant.principal_type == 'user',
|
||||||
AccessGrant.principal_id == user_id,
|
AccessGrant.principal_id == user_id,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@@ -767,7 +734,7 @@ class AccessGrantsTable:
|
|||||||
*(
|
*(
|
||||||
[
|
[
|
||||||
and_(
|
and_(
|
||||||
AccessGrant.principal_type == "group",
|
AccessGrant.principal_type == 'group',
|
||||||
AccessGrant.principal_id.in_(group_ids),
|
AccessGrant.principal_id.in_(group_ids),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@@ -800,8 +767,8 @@ class AccessGrantsTable:
|
|||||||
Filter for items where user has read BUT NOT write access.
|
Filter for items where user has read BUT NOT write access.
|
||||||
Public items are NOT considered read_only.
|
Public items are NOT considered read_only.
|
||||||
"""
|
"""
|
||||||
group_ids = filter.get("group_ids", [])
|
group_ids = filter.get('group_ids', [])
|
||||||
user_id = filter.get("user_id")
|
user_id = filter.get('user_id')
|
||||||
|
|
||||||
from sqlalchemy import exists as sa_exists, select
|
from sqlalchemy import exists as sa_exists, select
|
||||||
|
|
||||||
@@ -811,12 +778,12 @@ class AccessGrantsTable:
|
|||||||
.where(
|
.where(
|
||||||
AccessGrant.resource_type == resource_type,
|
AccessGrant.resource_type == resource_type,
|
||||||
AccessGrant.resource_id == DocumentModel.id,
|
AccessGrant.resource_id == DocumentModel.id,
|
||||||
AccessGrant.permission == "read",
|
AccessGrant.permission == 'read',
|
||||||
or_(
|
or_(
|
||||||
*(
|
*(
|
||||||
[
|
[
|
||||||
and_(
|
and_(
|
||||||
AccessGrant.principal_type == "user",
|
AccessGrant.principal_type == 'user',
|
||||||
AccessGrant.principal_id == user_id,
|
AccessGrant.principal_id == user_id,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@@ -826,7 +793,7 @@ class AccessGrantsTable:
|
|||||||
*(
|
*(
|
||||||
[
|
[
|
||||||
and_(
|
and_(
|
||||||
AccessGrant.principal_type == "group",
|
AccessGrant.principal_type == 'group',
|
||||||
AccessGrant.principal_id.in_(group_ids),
|
AccessGrant.principal_id.in_(group_ids),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@@ -845,12 +812,12 @@ class AccessGrantsTable:
|
|||||||
.where(
|
.where(
|
||||||
AccessGrant.resource_type == resource_type,
|
AccessGrant.resource_type == resource_type,
|
||||||
AccessGrant.resource_id == DocumentModel.id,
|
AccessGrant.resource_id == DocumentModel.id,
|
||||||
AccessGrant.permission == "write",
|
AccessGrant.permission == 'write',
|
||||||
or_(
|
or_(
|
||||||
*(
|
*(
|
||||||
[
|
[
|
||||||
and_(
|
and_(
|
||||||
AccessGrant.principal_type == "user",
|
AccessGrant.principal_type == 'user',
|
||||||
AccessGrant.principal_id == user_id,
|
AccessGrant.principal_id == user_id,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@@ -860,7 +827,7 @@ class AccessGrantsTable:
|
|||||||
*(
|
*(
|
||||||
[
|
[
|
||||||
and_(
|
and_(
|
||||||
AccessGrant.principal_type == "group",
|
AccessGrant.principal_type == 'group',
|
||||||
AccessGrant.principal_id.in_(group_ids),
|
AccessGrant.principal_id.in_(group_ids),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@@ -879,9 +846,9 @@ class AccessGrantsTable:
|
|||||||
.where(
|
.where(
|
||||||
AccessGrant.resource_type == resource_type,
|
AccessGrant.resource_type == resource_type,
|
||||||
AccessGrant.resource_id == DocumentModel.id,
|
AccessGrant.resource_id == DocumentModel.id,
|
||||||
AccessGrant.permission == "read",
|
AccessGrant.permission == 'read',
|
||||||
AccessGrant.principal_type == "user",
|
AccessGrant.principal_type == 'user',
|
||||||
AccessGrant.principal_id == "*",
|
AccessGrant.principal_id == '*',
|
||||||
)
|
)
|
||||||
.correlate(DocumentModel)
|
.correlate(DocumentModel)
|
||||||
.exists()
|
.exists()
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class Auth(Base):
|
class Auth(Base):
|
||||||
__tablename__ = "auth"
|
__tablename__ = 'auth'
|
||||||
|
|
||||||
id = Column(String, primary_key=True, unique=True)
|
id = Column(String, primary_key=True, unique=True)
|
||||||
email = Column(String)
|
email = Column(String)
|
||||||
@@ -73,9 +73,9 @@ class SignupForm(BaseModel):
|
|||||||
name: str
|
name: str
|
||||||
email: str
|
email: str
|
||||||
password: str
|
password: str
|
||||||
profile_image_url: Optional[str] = "/user.png"
|
profile_image_url: Optional[str] = '/user.png'
|
||||||
|
|
||||||
@field_validator("profile_image_url")
|
@field_validator('profile_image_url')
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_profile_image_url(cls, v: Optional[str]) -> Optional[str]:
|
def check_profile_image_url(cls, v: Optional[str]) -> Optional[str]:
|
||||||
if v is not None:
|
if v is not None:
|
||||||
@@ -84,7 +84,7 @@ class SignupForm(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class AddUserForm(SignupForm):
|
class AddUserForm(SignupForm):
|
||||||
role: Optional[str] = "pending"
|
role: Optional[str] = 'pending'
|
||||||
|
|
||||||
|
|
||||||
class AuthsTable:
|
class AuthsTable:
|
||||||
@@ -93,25 +93,21 @@ class AuthsTable:
|
|||||||
email: str,
|
email: str,
|
||||||
password: str,
|
password: str,
|
||||||
name: str,
|
name: str,
|
||||||
profile_image_url: str = "/user.png",
|
profile_image_url: str = '/user.png',
|
||||||
role: str = "pending",
|
role: str = 'pending',
|
||||||
oauth: Optional[dict] = None,
|
oauth: Optional[dict] = None,
|
||||||
db: Optional[Session] = None,
|
db: Optional[Session] = None,
|
||||||
) -> Optional[UserModel]:
|
) -> Optional[UserModel]:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
log.info("insert_new_auth")
|
log.info('insert_new_auth')
|
||||||
|
|
||||||
id = str(uuid.uuid4())
|
id = str(uuid.uuid4())
|
||||||
|
|
||||||
auth = AuthModel(
|
auth = AuthModel(**{'id': id, 'email': email, 'password': password, 'active': True})
|
||||||
**{"id": id, "email": email, "password": password, "active": True}
|
|
||||||
)
|
|
||||||
result = Auth(**auth.model_dump())
|
result = Auth(**auth.model_dump())
|
||||||
db.add(result)
|
db.add(result)
|
||||||
|
|
||||||
user = Users.insert_new_user(
|
user = Users.insert_new_user(id, name, email, profile_image_url, role, oauth=oauth, db=db)
|
||||||
id, name, email, profile_image_url, role, oauth=oauth, db=db
|
|
||||||
)
|
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(result)
|
db.refresh(result)
|
||||||
@@ -124,7 +120,7 @@ class AuthsTable:
|
|||||||
def authenticate_user(
|
def authenticate_user(
|
||||||
self, email: str, verify_password: callable, db: Optional[Session] = None
|
self, email: str, verify_password: callable, db: Optional[Session] = None
|
||||||
) -> Optional[UserModel]:
|
) -> Optional[UserModel]:
|
||||||
log.info(f"authenticate_user: {email}")
|
log.info(f'authenticate_user: {email}')
|
||||||
|
|
||||||
user = Users.get_user_by_email(email, db=db)
|
user = Users.get_user_by_email(email, db=db)
|
||||||
if not user:
|
if not user:
|
||||||
@@ -143,10 +139,8 @@ class AuthsTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def authenticate_user_by_api_key(
|
def authenticate_user_by_api_key(self, api_key: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||||
self, api_key: str, db: Optional[Session] = None
|
log.info(f'authenticate_user_by_api_key')
|
||||||
) -> Optional[UserModel]:
|
|
||||||
log.info(f"authenticate_user_by_api_key")
|
|
||||||
# if no api_key, return None
|
# if no api_key, return None
|
||||||
if not api_key:
|
if not api_key:
|
||||||
return None
|
return None
|
||||||
@@ -157,10 +151,8 @@ class AuthsTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def authenticate_user_by_email(
|
def authenticate_user_by_email(self, email: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||||
self, email: str, db: Optional[Session] = None
|
log.info(f'authenticate_user_by_email: {email}')
|
||||||
) -> Optional[UserModel]:
|
|
||||||
log.info(f"authenticate_user_by_email: {email}")
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
# Single JOIN query instead of two separate queries
|
# Single JOIN query instead of two separate queries
|
||||||
@@ -177,28 +169,22 @@ class AuthsTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_user_password_by_id(
|
def update_user_password_by_id(self, id: str, new_password: str, db: Optional[Session] = None) -> bool:
|
||||||
self, id: str, new_password: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
result = (
|
result = db.query(Auth).filter_by(id=id).update({'password': new_password})
|
||||||
db.query(Auth).filter_by(id=id).update({"password": new_password})
|
|
||||||
)
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return True if result == 1 else False
|
return True if result == 1 else False
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def update_email_by_id(
|
def update_email_by_id(self, id: str, email: str, db: Optional[Session] = None) -> bool:
|
||||||
self, id: str, email: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
result = db.query(Auth).filter_by(id=id).update({"email": email})
|
result = db.query(Auth).filter_by(id=id).update({'email': email})
|
||||||
db.commit()
|
db.commit()
|
||||||
if result == 1:
|
if result == 1:
|
||||||
Users.update_user_by_id(id, {"email": email}, db=db)
|
Users.update_user_by_id(id, {'email': email}, db=db)
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ from sqlalchemy.sql import exists
|
|||||||
|
|
||||||
|
|
||||||
class Channel(Base):
|
class Channel(Base):
|
||||||
__tablename__ = "channel"
|
__tablename__ = 'channel'
|
||||||
|
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
id = Column(Text, primary_key=True, unique=True)
|
||||||
user_id = Column(Text)
|
user_id = Column(Text)
|
||||||
@@ -94,7 +94,7 @@ class ChannelModel(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ChannelMember(Base):
|
class ChannelMember(Base):
|
||||||
__tablename__ = "channel_member"
|
__tablename__ = 'channel_member'
|
||||||
|
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
id = Column(Text, primary_key=True, unique=True)
|
||||||
channel_id = Column(Text, nullable=False)
|
channel_id = Column(Text, nullable=False)
|
||||||
@@ -154,25 +154,19 @@ class ChannelMemberModel(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ChannelFile(Base):
|
class ChannelFile(Base):
|
||||||
__tablename__ = "channel_file"
|
__tablename__ = 'channel_file'
|
||||||
|
|
||||||
id = Column(Text, unique=True, primary_key=True)
|
id = Column(Text, unique=True, primary_key=True)
|
||||||
user_id = Column(Text, nullable=False)
|
user_id = Column(Text, nullable=False)
|
||||||
|
|
||||||
channel_id = Column(
|
channel_id = Column(Text, ForeignKey('channel.id', ondelete='CASCADE'), nullable=False)
|
||||||
Text, ForeignKey("channel.id", ondelete="CASCADE"), nullable=False
|
message_id = Column(Text, ForeignKey('message.id', ondelete='CASCADE'), nullable=True)
|
||||||
)
|
file_id = Column(Text, ForeignKey('file.id', ondelete='CASCADE'), nullable=False)
|
||||||
message_id = Column(
|
|
||||||
Text, ForeignKey("message.id", ondelete="CASCADE"), nullable=True
|
|
||||||
)
|
|
||||||
file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False)
|
|
||||||
|
|
||||||
created_at = Column(BigInteger, nullable=False)
|
created_at = Column(BigInteger, nullable=False)
|
||||||
updated_at = Column(BigInteger, nullable=False)
|
updated_at = Column(BigInteger, nullable=False)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (UniqueConstraint('channel_id', 'file_id', name='uq_channel_file_channel_file'),)
|
||||||
UniqueConstraint("channel_id", "file_id", name="uq_channel_file_channel_file"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelFileModel(BaseModel):
|
class ChannelFileModel(BaseModel):
|
||||||
@@ -189,7 +183,7 @@ class ChannelFileModel(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ChannelWebhook(Base):
|
class ChannelWebhook(Base):
|
||||||
__tablename__ = "channel_webhook"
|
__tablename__ = 'channel_webhook'
|
||||||
|
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
id = Column(Text, primary_key=True, unique=True)
|
||||||
channel_id = Column(Text, nullable=False)
|
channel_id = Column(Text, nullable=False)
|
||||||
@@ -235,7 +229,7 @@ class ChannelResponse(ChannelModel):
|
|||||||
|
|
||||||
|
|
||||||
class ChannelForm(BaseModel):
|
class ChannelForm(BaseModel):
|
||||||
name: str = ""
|
name: str = ''
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
is_private: Optional[bool] = None
|
is_private: Optional[bool] = None
|
||||||
data: Optional[dict] = None
|
data: Optional[dict] = None
|
||||||
@@ -255,10 +249,8 @@ class ChannelWebhookForm(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ChannelTable:
|
class ChannelTable:
|
||||||
def _get_access_grants(
|
def _get_access_grants(self, channel_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
|
||||||
self, channel_id: str, db: Optional[Session] = None
|
return AccessGrants.get_grants_by_resource('channel', channel_id, db=db)
|
||||||
) -> list[AccessGrantModel]:
|
|
||||||
return AccessGrants.get_grants_by_resource("channel", channel_id, db=db)
|
|
||||||
|
|
||||||
def _to_channel_model(
|
def _to_channel_model(
|
||||||
self,
|
self,
|
||||||
@@ -266,13 +258,9 @@ class ChannelTable:
|
|||||||
access_grants: Optional[list[AccessGrantModel]] = None,
|
access_grants: Optional[list[AccessGrantModel]] = None,
|
||||||
db: Optional[Session] = None,
|
db: Optional[Session] = None,
|
||||||
) -> ChannelModel:
|
) -> ChannelModel:
|
||||||
channel_data = ChannelModel.model_validate(channel).model_dump(
|
channel_data = ChannelModel.model_validate(channel).model_dump(exclude={'access_grants'})
|
||||||
exclude={"access_grants"}
|
channel_data['access_grants'] = (
|
||||||
)
|
access_grants if access_grants is not None else self._get_access_grants(channel_data['id'], db=db)
|
||||||
channel_data["access_grants"] = (
|
|
||||||
access_grants
|
|
||||||
if access_grants is not None
|
|
||||||
else self._get_access_grants(channel_data["id"], db=db)
|
|
||||||
)
|
)
|
||||||
return ChannelModel.model_validate(channel_data)
|
return ChannelModel.model_validate(channel_data)
|
||||||
|
|
||||||
@@ -313,20 +301,20 @@ class ChannelTable:
|
|||||||
for uid in user_ids:
|
for uid in user_ids:
|
||||||
model = ChannelMemberModel(
|
model = ChannelMemberModel(
|
||||||
**{
|
**{
|
||||||
"id": str(uuid.uuid4()),
|
'id': str(uuid.uuid4()),
|
||||||
"channel_id": channel_id,
|
'channel_id': channel_id,
|
||||||
"user_id": uid,
|
'user_id': uid,
|
||||||
"status": "joined",
|
'status': 'joined',
|
||||||
"is_active": True,
|
'is_active': True,
|
||||||
"is_channel_muted": False,
|
'is_channel_muted': False,
|
||||||
"is_channel_pinned": False,
|
'is_channel_pinned': False,
|
||||||
"invited_at": now,
|
'invited_at': now,
|
||||||
"invited_by": invited_by,
|
'invited_by': invited_by,
|
||||||
"joined_at": now,
|
'joined_at': now,
|
||||||
"left_at": None,
|
'left_at': None,
|
||||||
"last_read_at": now,
|
'last_read_at': now,
|
||||||
"created_at": now,
|
'created_at': now,
|
||||||
"updated_at": now,
|
'updated_at': now,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
memberships.append(ChannelMember(**model.model_dump()))
|
memberships.append(ChannelMember(**model.model_dump()))
|
||||||
@@ -339,19 +327,19 @@ class ChannelTable:
|
|||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
channel = ChannelModel(
|
channel = ChannelModel(
|
||||||
**{
|
**{
|
||||||
**form_data.model_dump(exclude={"access_grants"}),
|
**form_data.model_dump(exclude={'access_grants'}),
|
||||||
"type": form_data.type if form_data.type else None,
|
'type': form_data.type if form_data.type else None,
|
||||||
"name": form_data.name.lower(),
|
'name': form_data.name.lower(),
|
||||||
"id": str(uuid.uuid4()),
|
'id': str(uuid.uuid4()),
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"created_at": int(time.time_ns()),
|
'created_at': int(time.time_ns()),
|
||||||
"updated_at": int(time.time_ns()),
|
'updated_at': int(time.time_ns()),
|
||||||
"access_grants": [],
|
'access_grants': [],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
new_channel = Channel(**channel.model_dump(exclude={"access_grants"}))
|
new_channel = Channel(**channel.model_dump(exclude={'access_grants'}))
|
||||||
|
|
||||||
if form_data.type in ["group", "dm"]:
|
if form_data.type in ['group', 'dm']:
|
||||||
users = self._collect_unique_user_ids(
|
users = self._collect_unique_user_ids(
|
||||||
invited_by=user_id,
|
invited_by=user_id,
|
||||||
user_ids=form_data.user_ids,
|
user_ids=form_data.user_ids,
|
||||||
@@ -366,18 +354,14 @@ class ChannelTable:
|
|||||||
db.add_all(memberships)
|
db.add_all(memberships)
|
||||||
db.add(new_channel)
|
db.add(new_channel)
|
||||||
db.commit()
|
db.commit()
|
||||||
AccessGrants.set_access_grants(
|
AccessGrants.set_access_grants('channel', new_channel.id, form_data.access_grants, db=db)
|
||||||
"channel", new_channel.id, form_data.access_grants, db=db
|
|
||||||
)
|
|
||||||
return self._to_channel_model(new_channel, db=db)
|
return self._to_channel_model(new_channel, db=db)
|
||||||
|
|
||||||
def get_channels(self, db: Optional[Session] = None) -> list[ChannelModel]:
|
def get_channels(self, db: Optional[Session] = None) -> list[ChannelModel]:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
channels = db.query(Channel).all()
|
channels = db.query(Channel).all()
|
||||||
channel_ids = [channel.id for channel in channels]
|
channel_ids = [channel.id for channel in channels]
|
||||||
grants_map = AccessGrants.get_grants_by_resources(
|
grants_map = AccessGrants.get_grants_by_resources('channel', channel_ids, db=db)
|
||||||
"channel", channel_ids, db=db
|
|
||||||
)
|
|
||||||
return [
|
return [
|
||||||
self._to_channel_model(
|
self._to_channel_model(
|
||||||
channel,
|
channel,
|
||||||
@@ -387,23 +371,19 @@ class ChannelTable:
|
|||||||
for channel in channels
|
for channel in channels
|
||||||
]
|
]
|
||||||
|
|
||||||
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
|
def _has_permission(self, db, query, filter: dict, permission: str = 'read'):
|
||||||
return AccessGrants.has_permission_filter(
|
return AccessGrants.has_permission_filter(
|
||||||
db=db,
|
db=db,
|
||||||
query=query,
|
query=query,
|
||||||
DocumentModel=Channel,
|
DocumentModel=Channel,
|
||||||
filter=filter,
|
filter=filter,
|
||||||
resource_type="channel",
|
resource_type='channel',
|
||||||
permission=permission,
|
permission=permission,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_channels_by_user_id(
|
def get_channels_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
|
||||||
self, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> list[ChannelModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
user_group_ids = [
|
user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id, db=db)]
|
||||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
|
||||||
]
|
|
||||||
|
|
||||||
membership_channels = (
|
membership_channels = (
|
||||||
db.query(Channel)
|
db.query(Channel)
|
||||||
@@ -411,7 +391,7 @@ class ChannelTable:
|
|||||||
.filter(
|
.filter(
|
||||||
Channel.deleted_at.is_(None),
|
Channel.deleted_at.is_(None),
|
||||||
Channel.archived_at.is_(None),
|
Channel.archived_at.is_(None),
|
||||||
Channel.type.in_(["group", "dm"]),
|
Channel.type.in_(['group', 'dm']),
|
||||||
ChannelMember.user_id == user_id,
|
ChannelMember.user_id == user_id,
|
||||||
ChannelMember.is_active.is_(True),
|
ChannelMember.is_active.is_(True),
|
||||||
)
|
)
|
||||||
@@ -423,29 +403,20 @@ class ChannelTable:
|
|||||||
Channel.archived_at.is_(None),
|
Channel.archived_at.is_(None),
|
||||||
or_(
|
or_(
|
||||||
Channel.type.is_(None), # True NULL/None
|
Channel.type.is_(None), # True NULL/None
|
||||||
Channel.type == "", # Empty string
|
Channel.type == '', # Empty string
|
||||||
and_(Channel.type != "group", Channel.type != "dm"),
|
and_(Channel.type != 'group', Channel.type != 'dm'),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
query = self._has_permission(
|
query = self._has_permission(db, query, {'user_id': user_id, 'group_ids': user_group_ids})
|
||||||
db, query, {"user_id": user_id, "group_ids": user_group_ids}
|
|
||||||
)
|
|
||||||
|
|
||||||
standard_channels = query.all()
|
standard_channels = query.all()
|
||||||
|
|
||||||
all_channels = membership_channels + standard_channels
|
all_channels = membership_channels + standard_channels
|
||||||
channel_ids = [c.id for c in all_channels]
|
channel_ids = [c.id for c in all_channels]
|
||||||
grants_map = AccessGrants.get_grants_by_resources(
|
grants_map = AccessGrants.get_grants_by_resources('channel', channel_ids, db=db)
|
||||||
"channel", channel_ids, db=db
|
return [self._to_channel_model(c, access_grants=grants_map.get(c.id, []), db=db) for c in all_channels]
|
||||||
)
|
|
||||||
return [
|
|
||||||
self._to_channel_model(c, access_grants=grants_map.get(c.id, []), db=db)
|
|
||||||
for c in all_channels
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_dm_channel_by_user_ids(
|
def get_dm_channel_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> Optional[ChannelModel]:
|
||||||
self, user_ids: list[str], db: Optional[Session] = None
|
|
||||||
) -> Optional[ChannelModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
# Ensure uniqueness in case a list with duplicates is passed
|
# Ensure uniqueness in case a list with duplicates is passed
|
||||||
unique_user_ids = list(set(user_ids))
|
unique_user_ids = list(set(user_ids))
|
||||||
@@ -471,7 +442,7 @@ class ChannelTable:
|
|||||||
db.query(Channel)
|
db.query(Channel)
|
||||||
.filter(
|
.filter(
|
||||||
Channel.id.in_(subquery),
|
Channel.id.in_(subquery),
|
||||||
Channel.type == "dm",
|
Channel.type == 'dm',
|
||||||
)
|
)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
@@ -488,32 +459,23 @@ class ChannelTable:
|
|||||||
) -> list[ChannelMemberModel]:
|
) -> list[ChannelMemberModel]:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
# 1. Collect all user_ids including groups + inviter
|
# 1. Collect all user_ids including groups + inviter
|
||||||
requested_users = self._collect_unique_user_ids(
|
requested_users = self._collect_unique_user_ids(invited_by, user_ids, group_ids)
|
||||||
invited_by, user_ids, group_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
existing_users = {
|
existing_users = {
|
||||||
row.user_id
|
row.user_id
|
||||||
for row in db.query(ChannelMember.user_id)
|
for row in db.query(ChannelMember.user_id).filter(ChannelMember.channel_id == channel_id).all()
|
||||||
.filter(ChannelMember.channel_id == channel_id)
|
|
||||||
.all()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
new_user_ids = requested_users - existing_users
|
new_user_ids = requested_users - existing_users
|
||||||
if not new_user_ids:
|
if not new_user_ids:
|
||||||
return [] # Nothing to add
|
return [] # Nothing to add
|
||||||
|
|
||||||
new_memberships = self._create_membership_models(
|
new_memberships = self._create_membership_models(channel_id, invited_by, new_user_ids)
|
||||||
channel_id, invited_by, new_user_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
db.add_all(new_memberships)
|
db.add_all(new_memberships)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
return [
|
return [ChannelMemberModel.model_validate(membership) for membership in new_memberships]
|
||||||
ChannelMemberModel.model_validate(membership)
|
|
||||||
for membership in new_memberships
|
|
||||||
]
|
|
||||||
|
|
||||||
def remove_members_from_channel(
|
def remove_members_from_channel(
|
||||||
self,
|
self,
|
||||||
@@ -533,9 +495,7 @@ class ChannelTable:
|
|||||||
db.commit()
|
db.commit()
|
||||||
return result # number of rows deleted
|
return result # number of rows deleted
|
||||||
|
|
||||||
def is_user_channel_manager(
|
def is_user_channel_manager(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||||
self, channel_id: str, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
# Check if the user is the creator of the channel
|
# Check if the user is the creator of the channel
|
||||||
# or has a 'manager' role in ChannelMember
|
# or has a 'manager' role in ChannelMember
|
||||||
@@ -548,15 +508,13 @@ class ChannelTable:
|
|||||||
.filter(
|
.filter(
|
||||||
ChannelMember.channel_id == channel_id,
|
ChannelMember.channel_id == channel_id,
|
||||||
ChannelMember.user_id == user_id,
|
ChannelMember.user_id == user_id,
|
||||||
ChannelMember.role == "manager",
|
ChannelMember.role == 'manager',
|
||||||
)
|
)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
return membership is not None
|
return membership is not None
|
||||||
|
|
||||||
def join_channel(
|
def join_channel(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> Optional[ChannelMemberModel]:
|
||||||
self, channel_id: str, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[ChannelMemberModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
# Check if the membership already exists
|
# Check if the membership already exists
|
||||||
existing_membership = (
|
existing_membership = (
|
||||||
@@ -573,18 +531,18 @@ class ChannelTable:
|
|||||||
# Create new membership
|
# Create new membership
|
||||||
channel_member = ChannelMemberModel(
|
channel_member = ChannelMemberModel(
|
||||||
**{
|
**{
|
||||||
"id": str(uuid.uuid4()),
|
'id': str(uuid.uuid4()),
|
||||||
"channel_id": channel_id,
|
'channel_id': channel_id,
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"status": "joined",
|
'status': 'joined',
|
||||||
"is_active": True,
|
'is_active': True,
|
||||||
"is_channel_muted": False,
|
'is_channel_muted': False,
|
||||||
"is_channel_pinned": False,
|
'is_channel_pinned': False,
|
||||||
"joined_at": int(time.time_ns()),
|
'joined_at': int(time.time_ns()),
|
||||||
"left_at": None,
|
'left_at': None,
|
||||||
"last_read_at": int(time.time_ns()),
|
'last_read_at': int(time.time_ns()),
|
||||||
"created_at": int(time.time_ns()),
|
'created_at': int(time.time_ns()),
|
||||||
"updated_at": int(time.time_ns()),
|
'updated_at': int(time.time_ns()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
new_membership = ChannelMember(**channel_member.model_dump())
|
new_membership = ChannelMember(**channel_member.model_dump())
|
||||||
@@ -593,9 +551,7 @@ class ChannelTable:
|
|||||||
db.commit()
|
db.commit()
|
||||||
return channel_member
|
return channel_member
|
||||||
|
|
||||||
def leave_channel(
|
def leave_channel(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||||
self, channel_id: str, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
membership = (
|
membership = (
|
||||||
db.query(ChannelMember)
|
db.query(ChannelMember)
|
||||||
@@ -608,7 +564,7 @@ class ChannelTable:
|
|||||||
if not membership:
|
if not membership:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
membership.status = "left"
|
membership.status = 'left'
|
||||||
membership.is_active = False
|
membership.is_active = False
|
||||||
membership.left_at = int(time.time_ns())
|
membership.left_at = int(time.time_ns())
|
||||||
membership.updated_at = int(time.time_ns())
|
membership.updated_at = int(time.time_ns())
|
||||||
@@ -630,19 +586,10 @@ class ChannelTable:
|
|||||||
)
|
)
|
||||||
return ChannelMemberModel.model_validate(membership) if membership else None
|
return ChannelMemberModel.model_validate(membership) if membership else None
|
||||||
|
|
||||||
def get_members_by_channel_id(
|
def get_members_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> list[ChannelMemberModel]:
|
||||||
self, channel_id: str, db: Optional[Session] = None
|
|
||||||
) -> list[ChannelMemberModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
memberships = (
|
memberships = db.query(ChannelMember).filter(ChannelMember.channel_id == channel_id).all()
|
||||||
db.query(ChannelMember)
|
return [ChannelMemberModel.model_validate(membership) for membership in memberships]
|
||||||
.filter(ChannelMember.channel_id == channel_id)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
ChannelMemberModel.model_validate(membership)
|
|
||||||
for membership in memberships
|
|
||||||
]
|
|
||||||
|
|
||||||
def pin_channel(
|
def pin_channel(
|
||||||
self,
|
self,
|
||||||
@@ -669,9 +616,7 @@ class ChannelTable:
|
|||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def update_member_last_read_at(
|
def update_member_last_read_at(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||||
self, channel_id: str, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
membership = (
|
membership = (
|
||||||
db.query(ChannelMember)
|
db.query(ChannelMember)
|
||||||
@@ -715,9 +660,7 @@ class ChannelTable:
|
|||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def is_user_channel_member(
|
def is_user_channel_member(self, channel_id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||||
self, channel_id: str, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
membership = (
|
membership = (
|
||||||
db.query(ChannelMember)
|
db.query(ChannelMember)
|
||||||
@@ -729,9 +672,7 @@ class ChannelTable:
|
|||||||
)
|
)
|
||||||
return membership is not None
|
return membership is not None
|
||||||
|
|
||||||
def get_channel_by_id(
|
def get_channel_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChannelModel]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[ChannelModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
channel = db.query(Channel).filter(Channel.id == id).first()
|
channel = db.query(Channel).filter(Channel.id == id).first()
|
||||||
@@ -739,18 +680,12 @@ class ChannelTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_channels_by_file_id(
|
def get_channels_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[ChannelModel]:
|
||||||
self, file_id: str, db: Optional[Session] = None
|
|
||||||
) -> list[ChannelModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
channel_files = (
|
channel_files = db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
|
||||||
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
|
|
||||||
)
|
|
||||||
channel_ids = [cf.channel_id for cf in channel_files]
|
channel_ids = [cf.channel_id for cf in channel_files]
|
||||||
channels = db.query(Channel).filter(Channel.id.in_(channel_ids)).all()
|
channels = db.query(Channel).filter(Channel.id.in_(channel_ids)).all()
|
||||||
grants_map = AccessGrants.get_grants_by_resources(
|
grants_map = AccessGrants.get_grants_by_resources('channel', channel_ids, db=db)
|
||||||
"channel", channel_ids, db=db
|
|
||||||
)
|
|
||||||
return [
|
return [
|
||||||
self._to_channel_model(
|
self._to_channel_model(
|
||||||
channel,
|
channel,
|
||||||
@@ -765,9 +700,7 @@ class ChannelTable:
|
|||||||
) -> list[ChannelModel]:
|
) -> list[ChannelModel]:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
# 1. Determine which channels have this file
|
# 1. Determine which channels have this file
|
||||||
channel_file_rows = (
|
channel_file_rows = db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
|
||||||
db.query(ChannelFile).filter(ChannelFile.file_id == file_id).all()
|
|
||||||
)
|
|
||||||
channel_ids = [row.channel_id for row in channel_file_rows]
|
channel_ids = [row.channel_id for row in channel_file_rows]
|
||||||
|
|
||||||
if not channel_ids:
|
if not channel_ids:
|
||||||
@@ -787,15 +720,13 @@ class ChannelTable:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
# Preload user's group membership
|
# Preload user's group membership
|
||||||
user_group_ids = [
|
user_group_ids = [g.id for g in Groups.get_groups_by_member_id(user_id, db=db)]
|
||||||
g.id for g in Groups.get_groups_by_member_id(user_id, db=db)
|
|
||||||
]
|
|
||||||
|
|
||||||
allowed_channels = []
|
allowed_channels = []
|
||||||
|
|
||||||
for channel in channels:
|
for channel in channels:
|
||||||
# --- Case A: group or dm => user must be an active member ---
|
# --- Case A: group or dm => user must be an active member ---
|
||||||
if channel.type in ["group", "dm"]:
|
if channel.type in ['group', 'dm']:
|
||||||
membership = (
|
membership = (
|
||||||
db.query(ChannelMember)
|
db.query(ChannelMember)
|
||||||
.filter(
|
.filter(
|
||||||
@@ -815,8 +746,8 @@ class ChannelTable:
|
|||||||
query = self._has_permission(
|
query = self._has_permission(
|
||||||
db,
|
db,
|
||||||
query,
|
query,
|
||||||
{"user_id": user_id, "group_ids": user_group_ids},
|
{'user_id': user_id, 'group_ids': user_group_ids},
|
||||||
permission="read",
|
permission='read',
|
||||||
)
|
)
|
||||||
|
|
||||||
allowed = query.first()
|
allowed = query.first()
|
||||||
@@ -844,7 +775,7 @@ class ChannelTable:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# If the channel is a group or dm, read access requires membership (active)
|
# If the channel is a group or dm, read access requires membership (active)
|
||||||
if channel.type in ["group", "dm"]:
|
if channel.type in ['group', 'dm']:
|
||||||
membership = (
|
membership = (
|
||||||
db.query(ChannelMember)
|
db.query(ChannelMember)
|
||||||
.filter(
|
.filter(
|
||||||
@@ -863,24 +794,18 @@ class ChannelTable:
|
|||||||
query = db.query(Channel).filter(Channel.id == id)
|
query = db.query(Channel).filter(Channel.id == id)
|
||||||
|
|
||||||
# Determine user groups
|
# Determine user groups
|
||||||
user_group_ids = [
|
user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id, db=db)]
|
||||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply ACL rules
|
# Apply ACL rules
|
||||||
query = self._has_permission(
|
query = self._has_permission(
|
||||||
db,
|
db,
|
||||||
query,
|
query,
|
||||||
{"user_id": user_id, "group_ids": user_group_ids},
|
{'user_id': user_id, 'group_ids': user_group_ids},
|
||||||
permission="read",
|
permission='read',
|
||||||
)
|
)
|
||||||
|
|
||||||
channel_allowed = query.first()
|
channel_allowed = query.first()
|
||||||
return (
|
return self._to_channel_model(channel_allowed, db=db) if channel_allowed else None
|
||||||
self._to_channel_model(channel_allowed, db=db)
|
|
||||||
if channel_allowed
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
def update_channel_by_id(
|
def update_channel_by_id(
|
||||||
self, id: str, form_data: ChannelForm, db: Optional[Session] = None
|
self, id: str, form_data: ChannelForm, db: Optional[Session] = None
|
||||||
@@ -898,9 +823,7 @@ class ChannelTable:
|
|||||||
channel.meta = form_data.meta
|
channel.meta = form_data.meta
|
||||||
|
|
||||||
if form_data.access_grants is not None:
|
if form_data.access_grants is not None:
|
||||||
AccessGrants.set_access_grants(
|
AccessGrants.set_access_grants('channel', id, form_data.access_grants, db=db)
|
||||||
"channel", id, form_data.access_grants, db=db
|
|
||||||
)
|
|
||||||
channel.updated_at = int(time.time_ns())
|
channel.updated_at = int(time.time_ns())
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -912,12 +835,12 @@ class ChannelTable:
|
|||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
channel_file = ChannelFileModel(
|
channel_file = ChannelFileModel(
|
||||||
**{
|
**{
|
||||||
"id": str(uuid.uuid4()),
|
'id': str(uuid.uuid4()),
|
||||||
"channel_id": channel_id,
|
'channel_id': channel_id,
|
||||||
"file_id": file_id,
|
'file_id': file_id,
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"created_at": int(time.time()),
|
'created_at': int(time.time()),
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -942,11 +865,7 @@ class ChannelTable:
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
channel_file = (
|
channel_file = db.query(ChannelFile).filter_by(channel_id=channel_id, file_id=file_id).first()
|
||||||
db.query(ChannelFile)
|
|
||||||
.filter_by(channel_id=channel_id, file_id=file_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if not channel_file:
|
if not channel_file:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -958,14 +877,10 @@ class ChannelTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def remove_file_from_channel_by_id(
|
def remove_file_from_channel_by_id(self, channel_id: str, file_id: str, db: Optional[Session] = None) -> bool:
|
||||||
self, channel_id: str, file_id: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
db.query(ChannelFile).filter_by(
|
db.query(ChannelFile).filter_by(channel_id=channel_id, file_id=file_id).delete()
|
||||||
channel_id=channel_id, file_id=file_id
|
|
||||||
).delete()
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -973,7 +888,7 @@ class ChannelTable:
|
|||||||
|
|
||||||
def delete_channel_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
def delete_channel_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
AccessGrants.revoke_all_access("channel", id, db=db)
|
AccessGrants.revoke_all_access('channel', id, db=db)
|
||||||
db.query(Channel).filter(Channel.id == id).delete()
|
db.query(Channel).filter(Channel.id == id).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
@@ -1005,24 +920,14 @@ class ChannelTable:
|
|||||||
db.commit()
|
db.commit()
|
||||||
return webhook
|
return webhook
|
||||||
|
|
||||||
def get_webhooks_by_channel_id(
|
def get_webhooks_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> list[ChannelWebhookModel]:
|
||||||
self, channel_id: str, db: Optional[Session] = None
|
|
||||||
) -> list[ChannelWebhookModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
webhooks = (
|
webhooks = db.query(ChannelWebhook).filter(ChannelWebhook.channel_id == channel_id).all()
|
||||||
db.query(ChannelWebhook)
|
|
||||||
.filter(ChannelWebhook.channel_id == channel_id)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [ChannelWebhookModel.model_validate(w) for w in webhooks]
|
return [ChannelWebhookModel.model_validate(w) for w in webhooks]
|
||||||
|
|
||||||
def get_webhook_by_id(
|
def get_webhook_by_id(self, webhook_id: str, db: Optional[Session] = None) -> Optional[ChannelWebhookModel]:
|
||||||
self, webhook_id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[ChannelWebhookModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
webhook = (
|
webhook = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
|
||||||
db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
|
|
||||||
)
|
|
||||||
return ChannelWebhookModel.model_validate(webhook) if webhook else None
|
return ChannelWebhookModel.model_validate(webhook) if webhook else None
|
||||||
|
|
||||||
def get_webhook_by_id_and_token(
|
def get_webhook_by_id_and_token(
|
||||||
@@ -1046,9 +951,7 @@ class ChannelTable:
|
|||||||
db: Optional[Session] = None,
|
db: Optional[Session] = None,
|
||||||
) -> Optional[ChannelWebhookModel]:
|
) -> Optional[ChannelWebhookModel]:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
webhook = (
|
webhook = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
|
||||||
db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
|
|
||||||
)
|
|
||||||
if not webhook:
|
if not webhook:
|
||||||
return None
|
return None
|
||||||
webhook.name = form_data.name
|
webhook.name = form_data.name
|
||||||
@@ -1057,28 +960,18 @@ class ChannelTable:
|
|||||||
db.commit()
|
db.commit()
|
||||||
return ChannelWebhookModel.model_validate(webhook)
|
return ChannelWebhookModel.model_validate(webhook)
|
||||||
|
|
||||||
def update_webhook_last_used_at(
|
def update_webhook_last_used_at(self, webhook_id: str, db: Optional[Session] = None) -> bool:
|
||||||
self, webhook_id: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
webhook = (
|
webhook = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
|
||||||
db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).first()
|
|
||||||
)
|
|
||||||
if not webhook:
|
if not webhook:
|
||||||
return False
|
return False
|
||||||
webhook.last_used_at = int(time.time_ns())
|
webhook.last_used_at = int(time.time_ns())
|
||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def delete_webhook_by_id(
|
def delete_webhook_by_id(self, webhook_id: str, db: Optional[Session] = None) -> bool:
|
||||||
self, webhook_id: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
result = (
|
result = db.query(ChannelWebhook).filter(ChannelWebhook.id == webhook_id).delete()
|
||||||
db.query(ChannelWebhook)
|
|
||||||
.filter(ChannelWebhook.id == webhook_id)
|
|
||||||
.delete()
|
|
||||||
)
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return result > 0
|
return result > 0
|
||||||
|
|
||||||
|
|||||||
@@ -47,13 +47,11 @@ def _normalize_timestamp(timestamp: int) -> float:
|
|||||||
|
|
||||||
|
|
||||||
class ChatMessage(Base):
|
class ChatMessage(Base):
|
||||||
__tablename__ = "chat_message"
|
__tablename__ = 'chat_message'
|
||||||
|
|
||||||
# Identity
|
# Identity
|
||||||
id = Column(Text, primary_key=True)
|
id = Column(Text, primary_key=True)
|
||||||
chat_id = Column(
|
chat_id = Column(Text, ForeignKey('chat.id', ondelete='CASCADE'), nullable=False, index=True)
|
||||||
Text, ForeignKey("chat.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
user_id = Column(Text, index=True)
|
user_id = Column(Text, index=True)
|
||||||
|
|
||||||
# Structure
|
# Structure
|
||||||
@@ -85,9 +83,9 @@ class ChatMessage(Base):
|
|||||||
updated_at = Column(BigInteger)
|
updated_at = Column(BigInteger)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("chat_message_chat_parent_idx", "chat_id", "parent_id"),
|
Index('chat_message_chat_parent_idx', 'chat_id', 'parent_id'),
|
||||||
Index("chat_message_model_created_idx", "model_id", "created_at"),
|
Index('chat_message_model_created_idx', 'model_id', 'created_at'),
|
||||||
Index("chat_message_user_created_idx", "user_id", "created_at"),
|
Index('chat_message_user_created_idx', 'user_id', 'created_at'),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -135,43 +133,41 @@ class ChatMessageTable:
|
|||||||
"""Insert or update a chat message."""
|
"""Insert or update a chat message."""
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
timestamp = data.get("timestamp", now)
|
timestamp = data.get('timestamp', now)
|
||||||
|
|
||||||
# Use composite ID: {chat_id}-{message_id}
|
# Use composite ID: {chat_id}-{message_id}
|
||||||
composite_id = f"{chat_id}-{message_id}"
|
composite_id = f'{chat_id}-{message_id}'
|
||||||
|
|
||||||
existing = db.get(ChatMessage, composite_id)
|
existing = db.get(ChatMessage, composite_id)
|
||||||
if existing:
|
if existing:
|
||||||
# Update existing
|
# Update existing
|
||||||
if "role" in data:
|
if 'role' in data:
|
||||||
existing.role = data["role"]
|
existing.role = data['role']
|
||||||
if "parent_id" in data:
|
if 'parent_id' in data:
|
||||||
existing.parent_id = data.get("parent_id") or data.get("parentId")
|
existing.parent_id = data.get('parent_id') or data.get('parentId')
|
||||||
if "content" in data:
|
if 'content' in data:
|
||||||
existing.content = data.get("content")
|
existing.content = data.get('content')
|
||||||
if "output" in data:
|
if 'output' in data:
|
||||||
existing.output = data.get("output")
|
existing.output = data.get('output')
|
||||||
if "model_id" in data or "model" in data:
|
if 'model_id' in data or 'model' in data:
|
||||||
existing.model_id = data.get("model_id") or data.get("model")
|
existing.model_id = data.get('model_id') or data.get('model')
|
||||||
if "files" in data:
|
if 'files' in data:
|
||||||
existing.files = data.get("files")
|
existing.files = data.get('files')
|
||||||
if "sources" in data:
|
if 'sources' in data:
|
||||||
existing.sources = data.get("sources")
|
existing.sources = data.get('sources')
|
||||||
if "embeds" in data:
|
if 'embeds' in data:
|
||||||
existing.embeds = data.get("embeds")
|
existing.embeds = data.get('embeds')
|
||||||
if "done" in data:
|
if 'done' in data:
|
||||||
existing.done = data.get("done", True)
|
existing.done = data.get('done', True)
|
||||||
if "status_history" in data or "statusHistory" in data:
|
if 'status_history' in data or 'statusHistory' in data:
|
||||||
existing.status_history = data.get("status_history") or data.get(
|
existing.status_history = data.get('status_history') or data.get('statusHistory')
|
||||||
"statusHistory"
|
if 'error' in data:
|
||||||
)
|
existing.error = data.get('error')
|
||||||
if "error" in data:
|
|
||||||
existing.error = data.get("error")
|
|
||||||
# Extract usage - check direct field first, then info.usage
|
# Extract usage - check direct field first, then info.usage
|
||||||
usage = data.get("usage")
|
usage = data.get('usage')
|
||||||
if not usage:
|
if not usage:
|
||||||
info = data.get("info", {})
|
info = data.get('info', {})
|
||||||
usage = info.get("usage") if info else None
|
usage = info.get('usage') if info else None
|
||||||
if usage:
|
if usage:
|
||||||
existing.usage = usage
|
existing.usage = usage
|
||||||
existing.updated_at = now
|
existing.updated_at = now
|
||||||
@@ -181,26 +177,25 @@ class ChatMessageTable:
|
|||||||
else:
|
else:
|
||||||
# Insert new
|
# Insert new
|
||||||
# Extract usage - check direct field first, then info.usage
|
# Extract usage - check direct field first, then info.usage
|
||||||
usage = data.get("usage")
|
usage = data.get('usage')
|
||||||
if not usage:
|
if not usage:
|
||||||
info = data.get("info", {})
|
info = data.get('info', {})
|
||||||
usage = info.get("usage") if info else None
|
usage = info.get('usage') if info else None
|
||||||
message = ChatMessage(
|
message = ChatMessage(
|
||||||
id=composite_id,
|
id=composite_id,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
role=data.get("role", "user"),
|
role=data.get('role', 'user'),
|
||||||
parent_id=data.get("parent_id") or data.get("parentId"),
|
parent_id=data.get('parent_id') or data.get('parentId'),
|
||||||
content=data.get("content"),
|
content=data.get('content'),
|
||||||
output=data.get("output"),
|
output=data.get('output'),
|
||||||
model_id=data.get("model_id") or data.get("model"),
|
model_id=data.get('model_id') or data.get('model'),
|
||||||
files=data.get("files"),
|
files=data.get('files'),
|
||||||
sources=data.get("sources"),
|
sources=data.get('sources'),
|
||||||
embeds=data.get("embeds"),
|
embeds=data.get('embeds'),
|
||||||
done=data.get("done", True),
|
done=data.get('done', True),
|
||||||
status_history=data.get("status_history")
|
status_history=data.get('status_history') or data.get('statusHistory'),
|
||||||
or data.get("statusHistory"),
|
error=data.get('error'),
|
||||||
error=data.get("error"),
|
|
||||||
usage=usage,
|
usage=usage,
|
||||||
created_at=timestamp,
|
created_at=timestamp,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
@@ -210,23 +205,14 @@ class ChatMessageTable:
|
|||||||
db.refresh(message)
|
db.refresh(message)
|
||||||
return ChatMessageModel.model_validate(message)
|
return ChatMessageModel.model_validate(message)
|
||||||
|
|
||||||
def get_message_by_id(
|
def get_message_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ChatMessageModel]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[ChatMessageModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
message = db.get(ChatMessage, id)
|
message = db.get(ChatMessage, id)
|
||||||
return ChatMessageModel.model_validate(message) if message else None
|
return ChatMessageModel.model_validate(message) if message else None
|
||||||
|
|
||||||
def get_messages_by_chat_id(
|
def get_messages_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> list[ChatMessageModel]:
|
||||||
self, chat_id: str, db: Optional[Session] = None
|
|
||||||
) -> list[ChatMessageModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
messages = (
|
messages = db.query(ChatMessage).filter_by(chat_id=chat_id).order_by(ChatMessage.created_at.asc()).all()
|
||||||
db.query(ChatMessage)
|
|
||||||
.filter_by(chat_id=chat_id)
|
|
||||||
.order_by(ChatMessage.created_at.asc())
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [ChatMessageModel.model_validate(message) for message in messages]
|
return [ChatMessageModel.model_validate(message) for message in messages]
|
||||||
|
|
||||||
def get_messages_by_user_id(
|
def get_messages_by_user_id(
|
||||||
@@ -262,12 +248,7 @@ class ChatMessageTable:
|
|||||||
query = query.filter(ChatMessage.created_at >= start_date)
|
query = query.filter(ChatMessage.created_at >= start_date)
|
||||||
if end_date:
|
if end_date:
|
||||||
query = query.filter(ChatMessage.created_at <= end_date)
|
query = query.filter(ChatMessage.created_at <= end_date)
|
||||||
messages = (
|
messages = query.order_by(ChatMessage.created_at.desc()).offset(skip).limit(limit).all()
|
||||||
query.order_by(ChatMessage.created_at.desc())
|
|
||||||
.offset(skip)
|
|
||||||
.limit(limit)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [ChatMessageModel.model_validate(message) for message in messages]
|
return [ChatMessageModel.model_validate(message) for message in messages]
|
||||||
|
|
||||||
def get_chat_ids_by_model_id(
|
def get_chat_ids_by_model_id(
|
||||||
@@ -284,7 +265,7 @@ class ChatMessageTable:
|
|||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
query = db.query(
|
query = db.query(
|
||||||
ChatMessage.chat_id,
|
ChatMessage.chat_id,
|
||||||
func.max(ChatMessage.created_at).label("last_message_at"),
|
func.max(ChatMessage.created_at).label('last_message_at'),
|
||||||
).filter(ChatMessage.model_id == model_id)
|
).filter(ChatMessage.model_id == model_id)
|
||||||
if start_date:
|
if start_date:
|
||||||
query = query.filter(ChatMessage.created_at >= start_date)
|
query = query.filter(ChatMessage.created_at >= start_date)
|
||||||
@@ -303,9 +284,7 @@ class ChatMessageTable:
|
|||||||
)
|
)
|
||||||
return [chat_id for chat_id, _ in chat_ids]
|
return [chat_id for chat_id, _ in chat_ids]
|
||||||
|
|
||||||
def delete_messages_by_chat_id(
|
def delete_messages_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> bool:
|
||||||
self, chat_id: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
db.query(ChatMessage).filter_by(chat_id=chat_id).delete()
|
db.query(ChatMessage).filter_by(chat_id=chat_id).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -323,12 +302,10 @@ class ChatMessageTable:
|
|||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from open_webui.models.groups import GroupMember
|
from open_webui.models.groups import GroupMember
|
||||||
|
|
||||||
query = db.query(
|
query = db.query(ChatMessage.model_id, func.count(ChatMessage.id).label('count')).filter(
|
||||||
ChatMessage.model_id, func.count(ChatMessage.id).label("count")
|
ChatMessage.role == 'assistant',
|
||||||
).filter(
|
|
||||||
ChatMessage.role == "assistant",
|
|
||||||
ChatMessage.model_id.isnot(None),
|
ChatMessage.model_id.isnot(None),
|
||||||
~ChatMessage.user_id.like("shared-%"),
|
~ChatMessage.user_id.like('shared-%'),
|
||||||
)
|
)
|
||||||
|
|
||||||
if start_date:
|
if start_date:
|
||||||
@@ -336,11 +313,7 @@ class ChatMessageTable:
|
|||||||
if end_date:
|
if end_date:
|
||||||
query = query.filter(ChatMessage.created_at <= end_date)
|
query = query.filter(ChatMessage.created_at <= end_date)
|
||||||
if group_id:
|
if group_id:
|
||||||
group_users = (
|
group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
|
||||||
db.query(GroupMember.user_id)
|
|
||||||
.filter(GroupMember.group_id == group_id)
|
|
||||||
.subquery()
|
|
||||||
)
|
|
||||||
query = query.filter(ChatMessage.user_id.in_(group_users))
|
query = query.filter(ChatMessage.user_id.in_(group_users))
|
||||||
|
|
||||||
results = query.group_by(ChatMessage.model_id).all()
|
results = query.group_by(ChatMessage.model_id).all()
|
||||||
@@ -360,36 +333,32 @@ class ChatMessageTable:
|
|||||||
|
|
||||||
dialect = db.bind.dialect.name
|
dialect = db.bind.dialect.name
|
||||||
|
|
||||||
if dialect == "sqlite":
|
if dialect == 'sqlite':
|
||||||
input_tokens = cast(
|
input_tokens = cast(func.json_extract(ChatMessage.usage, '$.input_tokens'), Integer)
|
||||||
func.json_extract(ChatMessage.usage, "$.input_tokens"), Integer
|
output_tokens = cast(func.json_extract(ChatMessage.usage, '$.output_tokens'), Integer)
|
||||||
)
|
elif dialect == 'postgresql':
|
||||||
output_tokens = cast(
|
|
||||||
func.json_extract(ChatMessage.usage, "$.output_tokens"), Integer
|
|
||||||
)
|
|
||||||
elif dialect == "postgresql":
|
|
||||||
# Use json_extract_path_text for PostgreSQL JSON columns
|
# Use json_extract_path_text for PostgreSQL JSON columns
|
||||||
input_tokens = cast(
|
input_tokens = cast(
|
||||||
func.json_extract_path_text(ChatMessage.usage, "input_tokens"),
|
func.json_extract_path_text(ChatMessage.usage, 'input_tokens'),
|
||||||
Integer,
|
Integer,
|
||||||
)
|
)
|
||||||
output_tokens = cast(
|
output_tokens = cast(
|
||||||
func.json_extract_path_text(ChatMessage.usage, "output_tokens"),
|
func.json_extract_path_text(ChatMessage.usage, 'output_tokens'),
|
||||||
Integer,
|
Integer,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unsupported dialect: {dialect}")
|
raise NotImplementedError(f'Unsupported dialect: {dialect}')
|
||||||
|
|
||||||
query = db.query(
|
query = db.query(
|
||||||
ChatMessage.model_id,
|
ChatMessage.model_id,
|
||||||
func.coalesce(func.sum(input_tokens), 0).label("input_tokens"),
|
func.coalesce(func.sum(input_tokens), 0).label('input_tokens'),
|
||||||
func.coalesce(func.sum(output_tokens), 0).label("output_tokens"),
|
func.coalesce(func.sum(output_tokens), 0).label('output_tokens'),
|
||||||
func.count(ChatMessage.id).label("message_count"),
|
func.count(ChatMessage.id).label('message_count'),
|
||||||
).filter(
|
).filter(
|
||||||
ChatMessage.role == "assistant",
|
ChatMessage.role == 'assistant',
|
||||||
ChatMessage.model_id.isnot(None),
|
ChatMessage.model_id.isnot(None),
|
||||||
ChatMessage.usage.isnot(None),
|
ChatMessage.usage.isnot(None),
|
||||||
~ChatMessage.user_id.like("shared-%"),
|
~ChatMessage.user_id.like('shared-%'),
|
||||||
)
|
)
|
||||||
|
|
||||||
if start_date:
|
if start_date:
|
||||||
@@ -397,21 +366,17 @@ class ChatMessageTable:
|
|||||||
if end_date:
|
if end_date:
|
||||||
query = query.filter(ChatMessage.created_at <= end_date)
|
query = query.filter(ChatMessage.created_at <= end_date)
|
||||||
if group_id:
|
if group_id:
|
||||||
group_users = (
|
group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
|
||||||
db.query(GroupMember.user_id)
|
|
||||||
.filter(GroupMember.group_id == group_id)
|
|
||||||
.subquery()
|
|
||||||
)
|
|
||||||
query = query.filter(ChatMessage.user_id.in_(group_users))
|
query = query.filter(ChatMessage.user_id.in_(group_users))
|
||||||
|
|
||||||
results = query.group_by(ChatMessage.model_id).all()
|
results = query.group_by(ChatMessage.model_id).all()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
row.model_id: {
|
row.model_id: {
|
||||||
"input_tokens": row.input_tokens,
|
'input_tokens': row.input_tokens,
|
||||||
"output_tokens": row.output_tokens,
|
'output_tokens': row.output_tokens,
|
||||||
"total_tokens": row.input_tokens + row.output_tokens,
|
'total_tokens': row.input_tokens + row.output_tokens,
|
||||||
"message_count": row.message_count,
|
'message_count': row.message_count,
|
||||||
}
|
}
|
||||||
for row in results
|
for row in results
|
||||||
}
|
}
|
||||||
@@ -430,36 +395,32 @@ class ChatMessageTable:
|
|||||||
|
|
||||||
dialect = db.bind.dialect.name
|
dialect = db.bind.dialect.name
|
||||||
|
|
||||||
if dialect == "sqlite":
|
if dialect == 'sqlite':
|
||||||
input_tokens = cast(
|
input_tokens = cast(func.json_extract(ChatMessage.usage, '$.input_tokens'), Integer)
|
||||||
func.json_extract(ChatMessage.usage, "$.input_tokens"), Integer
|
output_tokens = cast(func.json_extract(ChatMessage.usage, '$.output_tokens'), Integer)
|
||||||
)
|
elif dialect == 'postgresql':
|
||||||
output_tokens = cast(
|
|
||||||
func.json_extract(ChatMessage.usage, "$.output_tokens"), Integer
|
|
||||||
)
|
|
||||||
elif dialect == "postgresql":
|
|
||||||
# Use json_extract_path_text for PostgreSQL JSON columns
|
# Use json_extract_path_text for PostgreSQL JSON columns
|
||||||
input_tokens = cast(
|
input_tokens = cast(
|
||||||
func.json_extract_path_text(ChatMessage.usage, "input_tokens"),
|
func.json_extract_path_text(ChatMessage.usage, 'input_tokens'),
|
||||||
Integer,
|
Integer,
|
||||||
)
|
)
|
||||||
output_tokens = cast(
|
output_tokens = cast(
|
||||||
func.json_extract_path_text(ChatMessage.usage, "output_tokens"),
|
func.json_extract_path_text(ChatMessage.usage, 'output_tokens'),
|
||||||
Integer,
|
Integer,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unsupported dialect: {dialect}")
|
raise NotImplementedError(f'Unsupported dialect: {dialect}')
|
||||||
|
|
||||||
query = db.query(
|
query = db.query(
|
||||||
ChatMessage.user_id,
|
ChatMessage.user_id,
|
||||||
func.coalesce(func.sum(input_tokens), 0).label("input_tokens"),
|
func.coalesce(func.sum(input_tokens), 0).label('input_tokens'),
|
||||||
func.coalesce(func.sum(output_tokens), 0).label("output_tokens"),
|
func.coalesce(func.sum(output_tokens), 0).label('output_tokens'),
|
||||||
func.count(ChatMessage.id).label("message_count"),
|
func.count(ChatMessage.id).label('message_count'),
|
||||||
).filter(
|
).filter(
|
||||||
ChatMessage.role == "assistant",
|
ChatMessage.role == 'assistant',
|
||||||
ChatMessage.user_id.isnot(None),
|
ChatMessage.user_id.isnot(None),
|
||||||
ChatMessage.usage.isnot(None),
|
ChatMessage.usage.isnot(None),
|
||||||
~ChatMessage.user_id.like("shared-%"),
|
~ChatMessage.user_id.like('shared-%'),
|
||||||
)
|
)
|
||||||
|
|
||||||
if start_date:
|
if start_date:
|
||||||
@@ -467,21 +428,17 @@ class ChatMessageTable:
|
|||||||
if end_date:
|
if end_date:
|
||||||
query = query.filter(ChatMessage.created_at <= end_date)
|
query = query.filter(ChatMessage.created_at <= end_date)
|
||||||
if group_id:
|
if group_id:
|
||||||
group_users = (
|
group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
|
||||||
db.query(GroupMember.user_id)
|
|
||||||
.filter(GroupMember.group_id == group_id)
|
|
||||||
.subquery()
|
|
||||||
)
|
|
||||||
query = query.filter(ChatMessage.user_id.in_(group_users))
|
query = query.filter(ChatMessage.user_id.in_(group_users))
|
||||||
|
|
||||||
results = query.group_by(ChatMessage.user_id).all()
|
results = query.group_by(ChatMessage.user_id).all()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
row.user_id: {
|
row.user_id: {
|
||||||
"input_tokens": row.input_tokens,
|
'input_tokens': row.input_tokens,
|
||||||
"output_tokens": row.output_tokens,
|
'output_tokens': row.output_tokens,
|
||||||
"total_tokens": row.input_tokens + row.output_tokens,
|
'total_tokens': row.input_tokens + row.output_tokens,
|
||||||
"message_count": row.message_count,
|
'message_count': row.message_count,
|
||||||
}
|
}
|
||||||
for row in results
|
for row in results
|
||||||
}
|
}
|
||||||
@@ -497,20 +454,16 @@ class ChatMessageTable:
|
|||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from open_webui.models.groups import GroupMember
|
from open_webui.models.groups import GroupMember
|
||||||
|
|
||||||
query = db.query(
|
query = db.query(ChatMessage.user_id, func.count(ChatMessage.id).label('count')).filter(
|
||||||
ChatMessage.user_id, func.count(ChatMessage.id).label("count")
|
~ChatMessage.user_id.like('shared-%')
|
||||||
).filter(~ChatMessage.user_id.like("shared-%"))
|
)
|
||||||
|
|
||||||
if start_date:
|
if start_date:
|
||||||
query = query.filter(ChatMessage.created_at >= start_date)
|
query = query.filter(ChatMessage.created_at >= start_date)
|
||||||
if end_date:
|
if end_date:
|
||||||
query = query.filter(ChatMessage.created_at <= end_date)
|
query = query.filter(ChatMessage.created_at <= end_date)
|
||||||
if group_id:
|
if group_id:
|
||||||
group_users = (
|
group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
|
||||||
db.query(GroupMember.user_id)
|
|
||||||
.filter(GroupMember.group_id == group_id)
|
|
||||||
.subquery()
|
|
||||||
)
|
|
||||||
query = query.filter(ChatMessage.user_id.in_(group_users))
|
query = query.filter(ChatMessage.user_id.in_(group_users))
|
||||||
|
|
||||||
results = query.group_by(ChatMessage.user_id).all()
|
results = query.group_by(ChatMessage.user_id).all()
|
||||||
@@ -527,20 +480,16 @@ class ChatMessageTable:
|
|||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from open_webui.models.groups import GroupMember
|
from open_webui.models.groups import GroupMember
|
||||||
|
|
||||||
query = db.query(
|
query = db.query(ChatMessage.chat_id, func.count(ChatMessage.id).label('count')).filter(
|
||||||
ChatMessage.chat_id, func.count(ChatMessage.id).label("count")
|
~ChatMessage.user_id.like('shared-%')
|
||||||
).filter(~ChatMessage.user_id.like("shared-%"))
|
)
|
||||||
|
|
||||||
if start_date:
|
if start_date:
|
||||||
query = query.filter(ChatMessage.created_at >= start_date)
|
query = query.filter(ChatMessage.created_at >= start_date)
|
||||||
if end_date:
|
if end_date:
|
||||||
query = query.filter(ChatMessage.created_at <= end_date)
|
query = query.filter(ChatMessage.created_at <= end_date)
|
||||||
if group_id:
|
if group_id:
|
||||||
group_users = (
|
group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
|
||||||
db.query(GroupMember.user_id)
|
|
||||||
.filter(GroupMember.group_id == group_id)
|
|
||||||
.subquery()
|
|
||||||
)
|
|
||||||
query = query.filter(ChatMessage.user_id.in_(group_users))
|
query = query.filter(ChatMessage.user_id.in_(group_users))
|
||||||
|
|
||||||
results = query.group_by(ChatMessage.chat_id).all()
|
results = query.group_by(ChatMessage.chat_id).all()
|
||||||
@@ -559,9 +508,9 @@ class ChatMessageTable:
|
|||||||
from open_webui.models.groups import GroupMember
|
from open_webui.models.groups import GroupMember
|
||||||
|
|
||||||
query = db.query(ChatMessage.created_at, ChatMessage.model_id).filter(
|
query = db.query(ChatMessage.created_at, ChatMessage.model_id).filter(
|
||||||
ChatMessage.role == "assistant",
|
ChatMessage.role == 'assistant',
|
||||||
ChatMessage.model_id.isnot(None),
|
ChatMessage.model_id.isnot(None),
|
||||||
~ChatMessage.user_id.like("shared-%"),
|
~ChatMessage.user_id.like('shared-%'),
|
||||||
)
|
)
|
||||||
|
|
||||||
if start_date:
|
if start_date:
|
||||||
@@ -569,11 +518,7 @@ class ChatMessageTable:
|
|||||||
if end_date:
|
if end_date:
|
||||||
query = query.filter(ChatMessage.created_at <= end_date)
|
query = query.filter(ChatMessage.created_at <= end_date)
|
||||||
if group_id:
|
if group_id:
|
||||||
group_users = (
|
group_users = db.query(GroupMember.user_id).filter(GroupMember.group_id == group_id).subquery()
|
||||||
db.query(GroupMember.user_id)
|
|
||||||
.filter(GroupMember.group_id == group_id)
|
|
||||||
.subquery()
|
|
||||||
)
|
|
||||||
query = query.filter(ChatMessage.user_id.in_(group_users))
|
query = query.filter(ChatMessage.user_id.in_(group_users))
|
||||||
|
|
||||||
results = query.all()
|
results = query.all()
|
||||||
@@ -581,21 +526,17 @@ class ChatMessageTable:
|
|||||||
# Group by date -> model -> count
|
# Group by date -> model -> count
|
||||||
daily_counts: dict[str, dict[str, int]] = {}
|
daily_counts: dict[str, dict[str, int]] = {}
|
||||||
for timestamp, model_id in results:
|
for timestamp, model_id in results:
|
||||||
date_str = datetime.fromtimestamp(
|
date_str = datetime.fromtimestamp(_normalize_timestamp(timestamp)).strftime('%Y-%m-%d')
|
||||||
_normalize_timestamp(timestamp)
|
|
||||||
).strftime("%Y-%m-%d")
|
|
||||||
if date_str not in daily_counts:
|
if date_str not in daily_counts:
|
||||||
daily_counts[date_str] = {}
|
daily_counts[date_str] = {}
|
||||||
daily_counts[date_str][model_id] = (
|
daily_counts[date_str][model_id] = daily_counts[date_str].get(model_id, 0) + 1
|
||||||
daily_counts[date_str].get(model_id, 0) + 1
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fill in missing days
|
# Fill in missing days
|
||||||
if start_date and end_date:
|
if start_date and end_date:
|
||||||
current = datetime.fromtimestamp(_normalize_timestamp(start_date))
|
current = datetime.fromtimestamp(_normalize_timestamp(start_date))
|
||||||
end_dt = datetime.fromtimestamp(_normalize_timestamp(end_date))
|
end_dt = datetime.fromtimestamp(_normalize_timestamp(end_date))
|
||||||
while current <= end_dt:
|
while current <= end_dt:
|
||||||
date_str = current.strftime("%Y-%m-%d")
|
date_str = current.strftime('%Y-%m-%d')
|
||||||
if date_str not in daily_counts:
|
if date_str not in daily_counts:
|
||||||
daily_counts[date_str] = {}
|
daily_counts[date_str] = {}
|
||||||
current += timedelta(days=1)
|
current += timedelta(days=1)
|
||||||
@@ -613,9 +554,9 @@ class ChatMessageTable:
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
query = db.query(ChatMessage.created_at, ChatMessage.model_id).filter(
|
query = db.query(ChatMessage.created_at, ChatMessage.model_id).filter(
|
||||||
ChatMessage.role == "assistant",
|
ChatMessage.role == 'assistant',
|
||||||
ChatMessage.model_id.isnot(None),
|
ChatMessage.model_id.isnot(None),
|
||||||
~ChatMessage.user_id.like("shared-%"),
|
~ChatMessage.user_id.like('shared-%'),
|
||||||
)
|
)
|
||||||
|
|
||||||
if start_date:
|
if start_date:
|
||||||
@@ -628,23 +569,19 @@ class ChatMessageTable:
|
|||||||
# Group by hour -> model -> count
|
# Group by hour -> model -> count
|
||||||
hourly_counts: dict[str, dict[str, int]] = {}
|
hourly_counts: dict[str, dict[str, int]] = {}
|
||||||
for timestamp, model_id in results:
|
for timestamp, model_id in results:
|
||||||
hour_str = datetime.fromtimestamp(
|
hour_str = datetime.fromtimestamp(_normalize_timestamp(timestamp)).strftime('%Y-%m-%d %H:00')
|
||||||
_normalize_timestamp(timestamp)
|
|
||||||
).strftime("%Y-%m-%d %H:00")
|
|
||||||
if hour_str not in hourly_counts:
|
if hour_str not in hourly_counts:
|
||||||
hourly_counts[hour_str] = {}
|
hourly_counts[hour_str] = {}
|
||||||
hourly_counts[hour_str][model_id] = (
|
hourly_counts[hour_str][model_id] = hourly_counts[hour_str].get(model_id, 0) + 1
|
||||||
hourly_counts[hour_str].get(model_id, 0) + 1
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fill in missing hours
|
# Fill in missing hours
|
||||||
if start_date and end_date:
|
if start_date and end_date:
|
||||||
current = datetime.fromtimestamp(
|
current = datetime.fromtimestamp(_normalize_timestamp(start_date)).replace(
|
||||||
_normalize_timestamp(start_date)
|
minute=0, second=0, microsecond=0
|
||||||
).replace(minute=0, second=0, microsecond=0)
|
)
|
||||||
end_dt = datetime.fromtimestamp(_normalize_timestamp(end_date))
|
end_dt = datetime.fromtimestamp(_normalize_timestamp(end_date))
|
||||||
while current <= end_dt:
|
while current <= end_dt:
|
||||||
hour_str = current.strftime("%Y-%m-%d %H:00")
|
hour_str = current.strftime('%Y-%m-%d %H:00')
|
||||||
if hour_str not in hourly_counts:
|
if hour_str not in hourly_counts:
|
||||||
hourly_counts[hour_str] = {}
|
hourly_counts[hour_str] = {}
|
||||||
current += timedelta(hours=1)
|
current += timedelta(hours=1)
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -19,7 +19,7 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class Feedback(Base):
|
class Feedback(Base):
|
||||||
__tablename__ = "feedback"
|
__tablename__ = 'feedback'
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
id = Column(Text, primary_key=True, unique=True)
|
||||||
user_id = Column(Text)
|
user_id = Column(Text)
|
||||||
version = Column(BigInteger, default=0)
|
version = Column(BigInteger, default=0)
|
||||||
@@ -81,7 +81,7 @@ class RatingData(BaseModel):
|
|||||||
sibling_model_ids: Optional[list[str]] = None
|
sibling_model_ids: Optional[list[str]] = None
|
||||||
reason: Optional[str] = None
|
reason: Optional[str] = None
|
||||||
comment: Optional[str] = None
|
comment: Optional[str] = None
|
||||||
model_config = ConfigDict(extra="allow", protected_namespaces=())
|
model_config = ConfigDict(extra='allow', protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
class MetaData(BaseModel):
|
class MetaData(BaseModel):
|
||||||
@@ -89,12 +89,12 @@ class MetaData(BaseModel):
|
|||||||
chat_id: Optional[str] = None
|
chat_id: Optional[str] = None
|
||||||
message_id: Optional[str] = None
|
message_id: Optional[str] = None
|
||||||
tags: Optional[list[str]] = None
|
tags: Optional[list[str]] = None
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra='allow')
|
||||||
|
|
||||||
|
|
||||||
class SnapshotData(BaseModel):
|
class SnapshotData(BaseModel):
|
||||||
chat: Optional[dict] = None
|
chat: Optional[dict] = None
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra='allow')
|
||||||
|
|
||||||
|
|
||||||
class FeedbackForm(BaseModel):
|
class FeedbackForm(BaseModel):
|
||||||
@@ -102,14 +102,14 @@ class FeedbackForm(BaseModel):
|
|||||||
data: Optional[RatingData] = None
|
data: Optional[RatingData] = None
|
||||||
meta: Optional[dict] = None
|
meta: Optional[dict] = None
|
||||||
snapshot: Optional[SnapshotData] = None
|
snapshot: Optional[SnapshotData] = None
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra='allow')
|
||||||
|
|
||||||
|
|
||||||
class UserResponse(BaseModel):
|
class UserResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
email: str
|
email: str
|
||||||
role: str = "pending"
|
role: str = 'pending'
|
||||||
|
|
||||||
last_active_at: int # timestamp in epoch
|
last_active_at: int # timestamp in epoch
|
||||||
updated_at: int # timestamp in epoch
|
updated_at: int # timestamp in epoch
|
||||||
@@ -146,12 +146,12 @@ class FeedbackTable:
|
|||||||
id = str(uuid.uuid4())
|
id = str(uuid.uuid4())
|
||||||
feedback = FeedbackModel(
|
feedback = FeedbackModel(
|
||||||
**{
|
**{
|
||||||
"id": id,
|
'id': id,
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"version": 0,
|
'version': 0,
|
||||||
**form_data.model_dump(),
|
**form_data.model_dump(),
|
||||||
"created_at": int(time.time()),
|
'created_at': int(time.time()),
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
@@ -164,12 +164,10 @@ class FeedbackTable:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error creating a new feedback: {e}")
|
log.exception(f'Error creating a new feedback: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_feedback_by_id(
|
def get_feedback_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FeedbackModel]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[FeedbackModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
feedback = db.query(Feedback).filter_by(id=id).first()
|
feedback = db.query(Feedback).filter_by(id=id).first()
|
||||||
@@ -191,16 +189,14 @@ class FeedbackTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_feedbacks_by_chat_id(
|
def get_feedbacks_by_chat_id(self, chat_id: str, db: Optional[Session] = None) -> list[FeedbackModel]:
|
||||||
self, chat_id: str, db: Optional[Session] = None
|
|
||||||
) -> list[FeedbackModel]:
|
|
||||||
"""Get all feedbacks for a specific chat."""
|
"""Get all feedbacks for a specific chat."""
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
# meta.chat_id stores the chat reference
|
# meta.chat_id stores the chat reference
|
||||||
feedbacks = (
|
feedbacks = (
|
||||||
db.query(Feedback)
|
db.query(Feedback)
|
||||||
.filter(Feedback.meta["chat_id"].as_string() == chat_id)
|
.filter(Feedback.meta['chat_id'].as_string() == chat_id)
|
||||||
.order_by(Feedback.created_at.desc())
|
.order_by(Feedback.created_at.desc())
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
@@ -219,36 +215,28 @@ class FeedbackTable:
|
|||||||
query = db.query(Feedback, User).join(User, Feedback.user_id == User.id)
|
query = db.query(Feedback, User).join(User, Feedback.user_id == User.id)
|
||||||
|
|
||||||
if filter:
|
if filter:
|
||||||
order_by = filter.get("order_by")
|
order_by = filter.get('order_by')
|
||||||
direction = filter.get("direction")
|
direction = filter.get('direction')
|
||||||
|
|
||||||
if order_by == "username":
|
if order_by == 'username':
|
||||||
if direction == "asc":
|
if direction == 'asc':
|
||||||
query = query.order_by(User.name.asc())
|
query = query.order_by(User.name.asc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(User.name.desc())
|
query = query.order_by(User.name.desc())
|
||||||
elif order_by == "model_id":
|
elif order_by == 'model_id':
|
||||||
# it's stored in feedback.data['model_id']
|
# it's stored in feedback.data['model_id']
|
||||||
if direction == "asc":
|
if direction == 'asc':
|
||||||
query = query.order_by(
|
query = query.order_by(Feedback.data['model_id'].as_string().asc())
|
||||||
Feedback.data["model_id"].as_string().asc()
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
query = query.order_by(
|
query = query.order_by(Feedback.data['model_id'].as_string().desc())
|
||||||
Feedback.data["model_id"].as_string().desc()
|
elif order_by == 'rating':
|
||||||
)
|
|
||||||
elif order_by == "rating":
|
|
||||||
# it's stored in feedback.data['rating']
|
# it's stored in feedback.data['rating']
|
||||||
if direction == "asc":
|
if direction == 'asc':
|
||||||
query = query.order_by(
|
query = query.order_by(Feedback.data['rating'].as_string().asc())
|
||||||
Feedback.data["rating"].as_string().asc()
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
query = query.order_by(
|
query = query.order_by(Feedback.data['rating'].as_string().desc())
|
||||||
Feedback.data["rating"].as_string().desc()
|
elif order_by == 'updated_at':
|
||||||
)
|
if direction == 'asc':
|
||||||
elif order_by == "updated_at":
|
|
||||||
if direction == "asc":
|
|
||||||
query = query.order_by(Feedback.updated_at.asc())
|
query = query.order_by(Feedback.updated_at.asc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(Feedback.updated_at.desc())
|
query = query.order_by(Feedback.updated_at.desc())
|
||||||
@@ -270,9 +258,7 @@ class FeedbackTable:
|
|||||||
for feedback, user in items:
|
for feedback, user in items:
|
||||||
feedback_model = FeedbackModel.model_validate(feedback)
|
feedback_model = FeedbackModel.model_validate(feedback)
|
||||||
user_model = UserResponse.model_validate(user)
|
user_model = UserResponse.model_validate(user)
|
||||||
feedbacks.append(
|
feedbacks.append(FeedbackUserResponse(**feedback_model.model_dump(), user=user_model))
|
||||||
FeedbackUserResponse(**feedback_model.model_dump(), user=user_model)
|
|
||||||
)
|
|
||||||
|
|
||||||
return FeedbackListResponse(items=feedbacks, total=total)
|
return FeedbackListResponse(items=feedbacks, total=total)
|
||||||
|
|
||||||
@@ -280,14 +266,10 @@ class FeedbackTable:
|
|||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
return [
|
return [
|
||||||
FeedbackModel.model_validate(feedback)
|
FeedbackModel.model_validate(feedback)
|
||||||
for feedback in db.query(Feedback)
|
for feedback in db.query(Feedback).order_by(Feedback.updated_at.desc()).all()
|
||||||
.order_by(Feedback.updated_at.desc())
|
|
||||||
.all()
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_all_feedback_ids(
|
def get_all_feedback_ids(self, db: Optional[Session] = None) -> list[FeedbackIdResponse]:
|
||||||
self, db: Optional[Session] = None
|
|
||||||
) -> list[FeedbackIdResponse]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
return [
|
return [
|
||||||
FeedbackIdResponse(
|
FeedbackIdResponse(
|
||||||
@@ -306,14 +288,11 @@ class FeedbackTable:
|
|||||||
.all()
|
.all()
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_feedbacks_for_leaderboard(
|
def get_feedbacks_for_leaderboard(self, db: Optional[Session] = None) -> list[LeaderboardFeedbackData]:
|
||||||
self, db: Optional[Session] = None
|
|
||||||
) -> list[LeaderboardFeedbackData]:
|
|
||||||
"""Fetch only id and data for leaderboard computation (excludes snapshot/meta)."""
|
"""Fetch only id and data for leaderboard computation (excludes snapshot/meta)."""
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
return [
|
return [
|
||||||
LeaderboardFeedbackData(id=row.id, data=row.data)
|
LeaderboardFeedbackData(id=row.id, data=row.data) for row in db.query(Feedback.id, Feedback.data).all()
|
||||||
for row in db.query(Feedback.id, Feedback.data).all()
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_model_evaluation_history(
|
def get_model_evaluation_history(
|
||||||
@@ -333,30 +312,26 @@ class FeedbackTable:
|
|||||||
rows = db.query(Feedback.created_at, Feedback.data).all()
|
rows = db.query(Feedback.created_at, Feedback.data).all()
|
||||||
else:
|
else:
|
||||||
cutoff = int(time.time()) - (days * 86400)
|
cutoff = int(time.time()) - (days * 86400)
|
||||||
rows = (
|
rows = db.query(Feedback.created_at, Feedback.data).filter(Feedback.created_at >= cutoff).all()
|
||||||
db.query(Feedback.created_at, Feedback.data)
|
|
||||||
.filter(Feedback.created_at >= cutoff)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
daily_counts = defaultdict(lambda: {"won": 0, "lost": 0})
|
daily_counts = defaultdict(lambda: {'won': 0, 'lost': 0})
|
||||||
first_date = None
|
first_date = None
|
||||||
|
|
||||||
for created_at, data in rows:
|
for created_at, data in rows:
|
||||||
if not data:
|
if not data:
|
||||||
continue
|
continue
|
||||||
if data.get("model_id") != model_id:
|
if data.get('model_id') != model_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
rating_str = str(data.get("rating", ""))
|
rating_str = str(data.get('rating', ''))
|
||||||
if rating_str not in ("1", "-1"):
|
if rating_str not in ('1', '-1'):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
date_str = datetime.fromtimestamp(created_at).strftime("%Y-%m-%d")
|
date_str = datetime.fromtimestamp(created_at).strftime('%Y-%m-%d')
|
||||||
if rating_str == "1":
|
if rating_str == '1':
|
||||||
daily_counts[date_str]["won"] += 1
|
daily_counts[date_str]['won'] += 1
|
||||||
else:
|
else:
|
||||||
daily_counts[date_str]["lost"] += 1
|
daily_counts[date_str]['lost'] += 1
|
||||||
|
|
||||||
# Track first date for this model
|
# Track first date for this model
|
||||||
if first_date is None or date_str < first_date:
|
if first_date is None or date_str < first_date:
|
||||||
@@ -368,7 +343,7 @@ class FeedbackTable:
|
|||||||
|
|
||||||
if days == 0 and first_date:
|
if days == 0 and first_date:
|
||||||
# All time: start from first feedback date
|
# All time: start from first feedback date
|
||||||
start_date = datetime.strptime(first_date, "%Y-%m-%d").date()
|
start_date = datetime.strptime(first_date, '%Y-%m-%d').date()
|
||||||
num_days = (today - start_date).days + 1
|
num_days = (today - start_date).days + 1
|
||||||
else:
|
else:
|
||||||
# Fixed range
|
# Fixed range
|
||||||
@@ -377,36 +352,24 @@ class FeedbackTable:
|
|||||||
|
|
||||||
for i in range(num_days):
|
for i in range(num_days):
|
||||||
d = start_date + timedelta(days=i)
|
d = start_date + timedelta(days=i)
|
||||||
date_str = d.strftime("%Y-%m-%d")
|
date_str = d.strftime('%Y-%m-%d')
|
||||||
counts = daily_counts.get(date_str, {"won": 0, "lost": 0})
|
counts = daily_counts.get(date_str, {'won': 0, 'lost': 0})
|
||||||
result.append(
|
result.append(ModelHistoryEntry(date=date_str, won=counts['won'], lost=counts['lost']))
|
||||||
ModelHistoryEntry(date=date_str, won=counts["won"], lost=counts["lost"])
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_feedbacks_by_type(
|
def get_feedbacks_by_type(self, type: str, db: Optional[Session] = None) -> list[FeedbackModel]:
|
||||||
self, type: str, db: Optional[Session] = None
|
|
||||||
) -> list[FeedbackModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
return [
|
return [
|
||||||
FeedbackModel.model_validate(feedback)
|
FeedbackModel.model_validate(feedback)
|
||||||
for feedback in db.query(Feedback)
|
for feedback in db.query(Feedback).filter_by(type=type).order_by(Feedback.updated_at.desc()).all()
|
||||||
.filter_by(type=type)
|
|
||||||
.order_by(Feedback.updated_at.desc())
|
|
||||||
.all()
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_feedbacks_by_user_id(
|
def get_feedbacks_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FeedbackModel]:
|
||||||
self, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> list[FeedbackModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
return [
|
return [
|
||||||
FeedbackModel.model_validate(feedback)
|
FeedbackModel.model_validate(feedback)
|
||||||
for feedback in db.query(Feedback)
|
for feedback in db.query(Feedback).filter_by(user_id=user_id).order_by(Feedback.updated_at.desc()).all()
|
||||||
.filter_by(user_id=user_id)
|
|
||||||
.order_by(Feedback.updated_at.desc())
|
|
||||||
.all()
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def update_feedback_by_id(
|
def update_feedback_by_id(
|
||||||
@@ -462,9 +425,7 @@ class FeedbackTable:
|
|||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def delete_feedback_by_id_and_user_id(
|
def delete_feedback_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||||
self, id: str, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
|
feedback = db.query(Feedback).filter_by(id=id, user_id=user_id).first()
|
||||||
if not feedback:
|
if not feedback:
|
||||||
@@ -473,9 +434,7 @@ class FeedbackTable:
|
|||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def delete_feedbacks_by_user_id(
|
def delete_feedbacks_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
|
||||||
self, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
result = db.query(Feedback).filter_by(user_id=user_id).delete()
|
result = db.query(Feedback).filter_by(user_id=user_id).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class File(Base):
|
class File(Base):
|
||||||
__tablename__ = "file"
|
__tablename__ = 'file'
|
||||||
id = Column(String, primary_key=True, unique=True)
|
id = Column(String, primary_key=True, unique=True)
|
||||||
user_id = Column(String)
|
user_id = Column(String)
|
||||||
hash = Column(Text, nullable=True)
|
hash = Column(Text, nullable=True)
|
||||||
@@ -58,9 +58,9 @@ class FileMeta(BaseModel):
|
|||||||
content_type: Optional[str] = None
|
content_type: Optional[str] = None
|
||||||
size: Optional[int] = None
|
size: Optional[int] = None
|
||||||
|
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra='allow')
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode='before')
|
||||||
@classmethod
|
@classmethod
|
||||||
def sanitize_meta(cls, data):
|
def sanitize_meta(cls, data):
|
||||||
"""Sanitize metadata fields to handle malformed legacy data."""
|
"""Sanitize metadata fields to handle malformed legacy data."""
|
||||||
@@ -68,14 +68,12 @@ class FileMeta(BaseModel):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
# Handle content_type that may be a list like ['application/pdf', None]
|
# Handle content_type that may be a list like ['application/pdf', None]
|
||||||
content_type = data.get("content_type")
|
content_type = data.get('content_type')
|
||||||
if isinstance(content_type, list):
|
if isinstance(content_type, list):
|
||||||
# Extract first non-None string value
|
# Extract first non-None string value
|
||||||
data["content_type"] = next(
|
data['content_type'] = next((item for item in content_type if isinstance(item, str)), None)
|
||||||
(item for item in content_type if isinstance(item, str)), None
|
|
||||||
)
|
|
||||||
elif content_type is not None and not isinstance(content_type, str):
|
elif content_type is not None and not isinstance(content_type, str):
|
||||||
data["content_type"] = None
|
data['content_type'] = None
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@@ -92,7 +90,7 @@ class FileModelResponse(BaseModel):
|
|||||||
created_at: int # timestamp in epoch
|
created_at: int # timestamp in epoch
|
||||||
updated_at: Optional[int] = None # timestamp in epoch, optional for legacy files
|
updated_at: Optional[int] = None # timestamp in epoch, optional for legacy files
|
||||||
|
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra='allow')
|
||||||
|
|
||||||
|
|
||||||
class FileMetadataResponse(BaseModel):
|
class FileMetadataResponse(BaseModel):
|
||||||
@@ -123,25 +121,22 @@ class FileUpdateForm(BaseModel):
|
|||||||
meta: Optional[dict] = None
|
meta: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class FilesTable:
|
class FilesTable:
|
||||||
def insert_new_file(
|
def insert_new_file(self, user_id: str, form_data: FileForm, db: Optional[Session] = None) -> Optional[FileModel]:
|
||||||
self, user_id: str, form_data: FileForm, db: Optional[Session] = None
|
|
||||||
) -> Optional[FileModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
file_data = form_data.model_dump()
|
file_data = form_data.model_dump()
|
||||||
|
|
||||||
# Sanitize meta to remove non-JSON-serializable objects
|
# Sanitize meta to remove non-JSON-serializable objects
|
||||||
# (e.g. callable tool functions, MCP client instances from middleware)
|
# (e.g. callable tool functions, MCP client instances from middleware)
|
||||||
if file_data.get("meta"):
|
if file_data.get('meta'):
|
||||||
file_data["meta"] = sanitize_metadata(file_data["meta"])
|
file_data['meta'] = sanitize_metadata(file_data['meta'])
|
||||||
|
|
||||||
file = FileModel(
|
file = FileModel(
|
||||||
**{
|
**{
|
||||||
**file_data,
|
**file_data,
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"created_at": int(time.time()),
|
'created_at': int(time.time()),
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -155,12 +150,10 @@ class FilesTable:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error inserting a new file: {e}")
|
log.exception(f'Error inserting a new file: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_file_by_id(
|
def get_file_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FileModel]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[FileModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
@@ -171,9 +164,7 @@ class FilesTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_file_by_id_and_user_id(
|
def get_file_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[FileModel]:
|
||||||
self, id: str, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[FileModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
file = db.query(File).filter_by(id=id, user_id=user_id).first()
|
file = db.query(File).filter_by(id=id, user_id=user_id).first()
|
||||||
@@ -184,9 +175,7 @@ class FilesTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_file_metadata_by_id(
|
def get_file_metadata_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FileMetadataResponse]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[FileMetadataResponse]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
file = db.get(File, id)
|
file = db.get(File, id)
|
||||||
@@ -204,9 +193,7 @@ class FilesTable:
|
|||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
return [FileModel.model_validate(file) for file in db.query(File).all()]
|
return [FileModel.model_validate(file) for file in db.query(File).all()]
|
||||||
|
|
||||||
def check_access_by_user_id(
|
def check_access_by_user_id(self, id, user_id, permission='write', db: Optional[Session] = None) -> bool:
|
||||||
self, id, user_id, permission="write", db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
file = self.get_file_by_id(id, db=db)
|
file = self.get_file_by_id(id, db=db)
|
||||||
if not file:
|
if not file:
|
||||||
return False
|
return False
|
||||||
@@ -215,21 +202,14 @@ class FilesTable:
|
|||||||
# Implement additional access control logic here as needed
|
# Implement additional access control logic here as needed
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_files_by_ids(
|
def get_files_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[FileModel]:
|
||||||
self, ids: list[str], db: Optional[Session] = None
|
|
||||||
) -> list[FileModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
return [
|
return [
|
||||||
FileModel.model_validate(file)
|
FileModel.model_validate(file)
|
||||||
for file in db.query(File)
|
for file in db.query(File).filter(File.id.in_(ids)).order_by(File.updated_at.desc()).all()
|
||||||
.filter(File.id.in_(ids))
|
|
||||||
.order_by(File.updated_at.desc())
|
|
||||||
.all()
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_file_metadatas_by_ids(
|
def get_file_metadatas_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[FileMetadataResponse]:
|
||||||
self, ids: list[str], db: Optional[Session] = None
|
|
||||||
) -> list[FileMetadataResponse]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
return [
|
return [
|
||||||
FileMetadataResponse(
|
FileMetadataResponse(
|
||||||
@@ -239,22 +219,15 @@ class FilesTable:
|
|||||||
created_at=file.created_at,
|
created_at=file.created_at,
|
||||||
updated_at=file.updated_at,
|
updated_at=file.updated_at,
|
||||||
)
|
)
|
||||||
for file in db.query(
|
for file in db.query(File.id, File.hash, File.meta, File.created_at, File.updated_at)
|
||||||
File.id, File.hash, File.meta, File.created_at, File.updated_at
|
|
||||||
)
|
|
||||||
.filter(File.id.in_(ids))
|
.filter(File.id.in_(ids))
|
||||||
.order_by(File.updated_at.desc())
|
.order_by(File.updated_at.desc())
|
||||||
.all()
|
.all()
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_files_by_user_id(
|
def get_files_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FileModel]:
|
||||||
self, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> list[FileModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
return [
|
return [FileModel.model_validate(file) for file in db.query(File).filter_by(user_id=user_id).all()]
|
||||||
FileModel.model_validate(file)
|
|
||||||
for file in db.query(File).filter_by(user_id=user_id).all()
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_file_list(
|
def get_file_list(
|
||||||
self,
|
self,
|
||||||
@@ -262,7 +235,7 @@ class FilesTable:
|
|||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
db: Optional[Session] = None,
|
db: Optional[Session] = None,
|
||||||
) -> "FileListResponse":
|
) -> 'FileListResponse':
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
query = db.query(File)
|
query = db.query(File)
|
||||||
if user_id:
|
if user_id:
|
||||||
@@ -272,10 +245,7 @@ class FilesTable:
|
|||||||
|
|
||||||
items = [
|
items = [
|
||||||
FileModel.model_validate(file)
|
FileModel.model_validate(file)
|
||||||
for file in query.order_by(File.updated_at.desc(), File.id.desc())
|
for file in query.order_by(File.updated_at.desc(), File.id.desc()).offset(skip).limit(limit).all()
|
||||||
.offset(skip)
|
|
||||||
.limit(limit)
|
|
||||||
.all()
|
|
||||||
]
|
]
|
||||||
|
|
||||||
return FileListResponse(items=items, total=total)
|
return FileListResponse(items=items, total=total)
|
||||||
@@ -296,17 +266,17 @@ class FilesTable:
|
|||||||
A SQL LIKE compatible pattern with proper escaping.
|
A SQL LIKE compatible pattern with proper escaping.
|
||||||
"""
|
"""
|
||||||
# Escape SQL special characters first, then convert glob wildcards
|
# Escape SQL special characters first, then convert glob wildcards
|
||||||
pattern = glob.replace("\\", "\\\\")
|
pattern = glob.replace('\\', '\\\\')
|
||||||
pattern = pattern.replace("%", "\\%")
|
pattern = pattern.replace('%', '\\%')
|
||||||
pattern = pattern.replace("_", "\\_")
|
pattern = pattern.replace('_', '\\_')
|
||||||
pattern = pattern.replace("*", "%")
|
pattern = pattern.replace('*', '%')
|
||||||
pattern = pattern.replace("?", "_")
|
pattern = pattern.replace('?', '_')
|
||||||
return pattern
|
return pattern
|
||||||
|
|
||||||
def search_files(
|
def search_files(
|
||||||
self,
|
self,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
filename: str = "*",
|
filename: str = '*',
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
db: Optional[Session] = None,
|
db: Optional[Session] = None,
|
||||||
@@ -331,15 +301,12 @@ class FilesTable:
|
|||||||
query = query.filter_by(user_id=user_id)
|
query = query.filter_by(user_id=user_id)
|
||||||
|
|
||||||
pattern = self._glob_to_like_pattern(filename)
|
pattern = self._glob_to_like_pattern(filename)
|
||||||
if pattern != "%":
|
if pattern != '%':
|
||||||
query = query.filter(File.filename.ilike(pattern, escape="\\"))
|
query = query.filter(File.filename.ilike(pattern, escape='\\'))
|
||||||
|
|
||||||
return [
|
return [
|
||||||
FileModel.model_validate(file)
|
FileModel.model_validate(file)
|
||||||
for file in query.order_by(File.created_at.desc(), File.id.desc())
|
for file in query.order_by(File.created_at.desc(), File.id.desc()).offset(skip).limit(limit).all()
|
||||||
.offset(skip)
|
|
||||||
.limit(limit)
|
|
||||||
.all()
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def update_file_by_id(
|
def update_file_by_id(
|
||||||
@@ -362,12 +329,10 @@ class FilesTable:
|
|||||||
db.commit()
|
db.commit()
|
||||||
return FileModel.model_validate(file)
|
return FileModel.model_validate(file)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error updating file completely by id: {e}")
|
log.exception(f'Error updating file completely by id: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_file_hash_by_id(
|
def update_file_hash_by_id(self, id: str, hash: Optional[str], db: Optional[Session] = None) -> Optional[FileModel]:
|
||||||
self, id: str, hash: Optional[str], db: Optional[Session] = None
|
|
||||||
) -> Optional[FileModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
file = db.query(File).filter_by(id=id).first()
|
file = db.query(File).filter_by(id=id).first()
|
||||||
@@ -379,9 +344,7 @@ class FilesTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_file_data_by_id(
|
def update_file_data_by_id(self, id: str, data: dict, db: Optional[Session] = None) -> Optional[FileModel]:
|
||||||
self, id: str, data: dict, db: Optional[Session] = None
|
|
||||||
) -> Optional[FileModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
file = db.query(File).filter_by(id=id).first()
|
file = db.query(File).filter_by(id=id).first()
|
||||||
@@ -390,12 +353,9 @@ class FilesTable:
|
|||||||
db.commit()
|
db.commit()
|
||||||
return FileModel.model_validate(file)
|
return FileModel.model_validate(file)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_file_metadata_by_id(
|
def update_file_metadata_by_id(self, id: str, meta: dict, db: Optional[Session] = None) -> Optional[FileModel]:
|
||||||
self, id: str, meta: dict, db: Optional[Session] = None
|
|
||||||
) -> Optional[FileModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
file = db.query(File).filter_by(id=id).first()
|
file = db.query(File).filter_by(id=id).first()
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class Folder(Base):
|
class Folder(Base):
|
||||||
__tablename__ = "folder"
|
__tablename__ = 'folder'
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
id = Column(Text, primary_key=True, unique=True)
|
||||||
parent_id = Column(Text, nullable=True)
|
parent_id = Column(Text, nullable=True)
|
||||||
user_id = Column(Text)
|
user_id = Column(Text)
|
||||||
@@ -72,14 +72,14 @@ class FolderForm(BaseModel):
|
|||||||
data: Optional[dict] = None
|
data: Optional[dict] = None
|
||||||
meta: Optional[dict] = None
|
meta: Optional[dict] = None
|
||||||
parent_id: Optional[str] = None
|
parent_id: Optional[str] = None
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra='allow')
|
||||||
|
|
||||||
|
|
||||||
class FolderUpdateForm(BaseModel):
|
class FolderUpdateForm(BaseModel):
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
data: Optional[dict] = None
|
data: Optional[dict] = None
|
||||||
meta: Optional[dict] = None
|
meta: Optional[dict] = None
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra='allow')
|
||||||
|
|
||||||
|
|
||||||
class FolderTable:
|
class FolderTable:
|
||||||
@@ -94,12 +94,12 @@ class FolderTable:
|
|||||||
id = str(uuid.uuid4())
|
id = str(uuid.uuid4())
|
||||||
folder = FolderModel(
|
folder = FolderModel(
|
||||||
**{
|
**{
|
||||||
"id": id,
|
'id': id,
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
**(form_data.model_dump(exclude_unset=True) or {}),
|
**(form_data.model_dump(exclude_unset=True) or {}),
|
||||||
"parent_id": parent_id,
|
'parent_id': parent_id,
|
||||||
"created_at": int(time.time()),
|
'created_at': int(time.time()),
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
@@ -112,7 +112,7 @@ class FolderTable:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error inserting a new folder: {e}")
|
log.exception(f'Error inserting a new folder: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_folder_by_id_and_user_id(
|
def get_folder_by_id_and_user_id(
|
||||||
@@ -137,9 +137,7 @@ class FolderTable:
|
|||||||
folders = []
|
folders = []
|
||||||
|
|
||||||
def get_children(folder):
|
def get_children(folder):
|
||||||
children = self.get_folders_by_parent_id_and_user_id(
|
children = self.get_folders_by_parent_id_and_user_id(folder.id, user_id, db=db)
|
||||||
folder.id, user_id, db=db
|
|
||||||
)
|
|
||||||
for child in children:
|
for child in children:
|
||||||
get_children(child)
|
get_children(child)
|
||||||
folders.append(child)
|
folders.append(child)
|
||||||
@@ -153,14 +151,9 @@ class FolderTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_folders_by_user_id(
|
def get_folders_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[FolderModel]:
|
||||||
self, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> list[FolderModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
return [
|
return [FolderModel.model_validate(folder) for folder in db.query(Folder).filter_by(user_id=user_id).all()]
|
||||||
FolderModel.model_validate(folder)
|
|
||||||
for folder in db.query(Folder).filter_by(user_id=user_id).all()
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_folder_by_parent_id_and_user_id_and_name(
|
def get_folder_by_parent_id_and_user_id_and_name(
|
||||||
self,
|
self,
|
||||||
@@ -184,7 +177,7 @@ class FolderTable:
|
|||||||
|
|
||||||
return FolderModel.model_validate(folder)
|
return FolderModel.model_validate(folder)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"get_folder_by_parent_id_and_user_id_and_name: {e}")
|
log.error(f'get_folder_by_parent_id_and_user_id_and_name: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_folders_by_parent_id_and_user_id(
|
def get_folders_by_parent_id_and_user_id(
|
||||||
@@ -193,9 +186,7 @@ class FolderTable:
|
|||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
return [
|
return [
|
||||||
FolderModel.model_validate(folder)
|
FolderModel.model_validate(folder)
|
||||||
for folder in db.query(Folder)
|
for folder in db.query(Folder).filter_by(parent_id=parent_id, user_id=user_id).all()
|
||||||
.filter_by(parent_id=parent_id, user_id=user_id)
|
|
||||||
.all()
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def update_folder_parent_id_by_id_and_user_id(
|
def update_folder_parent_id_by_id_and_user_id(
|
||||||
@@ -219,7 +210,7 @@ class FolderTable:
|
|||||||
|
|
||||||
return FolderModel.model_validate(folder)
|
return FolderModel.model_validate(folder)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"update_folder: {e}")
|
log.error(f'update_folder: {e}')
|
||||||
return
|
return
|
||||||
|
|
||||||
def update_folder_by_id_and_user_id(
|
def update_folder_by_id_and_user_id(
|
||||||
@@ -241,7 +232,7 @@ class FolderTable:
|
|||||||
existing_folder = (
|
existing_folder = (
|
||||||
db.query(Folder)
|
db.query(Folder)
|
||||||
.filter_by(
|
.filter_by(
|
||||||
name=form_data.get("name"),
|
name=form_data.get('name'),
|
||||||
parent_id=folder.parent_id,
|
parent_id=folder.parent_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
@@ -251,17 +242,17 @@ class FolderTable:
|
|||||||
if existing_folder and existing_folder.id != id:
|
if existing_folder and existing_folder.id != id:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
folder.name = form_data.get("name", folder.name)
|
folder.name = form_data.get('name', folder.name)
|
||||||
if "data" in form_data:
|
if 'data' in form_data:
|
||||||
folder.data = {
|
folder.data = {
|
||||||
**(folder.data or {}),
|
**(folder.data or {}),
|
||||||
**form_data["data"],
|
**form_data['data'],
|
||||||
}
|
}
|
||||||
|
|
||||||
if "meta" in form_data:
|
if 'meta' in form_data:
|
||||||
folder.meta = {
|
folder.meta = {
|
||||||
**(folder.meta or {}),
|
**(folder.meta or {}),
|
||||||
**form_data["meta"],
|
**form_data['meta'],
|
||||||
}
|
}
|
||||||
|
|
||||||
folder.updated_at = int(time.time())
|
folder.updated_at = int(time.time())
|
||||||
@@ -269,7 +260,7 @@ class FolderTable:
|
|||||||
|
|
||||||
return FolderModel.model_validate(folder)
|
return FolderModel.model_validate(folder)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"update_folder: {e}")
|
log.error(f'update_folder: {e}')
|
||||||
return
|
return
|
||||||
|
|
||||||
def update_folder_is_expanded_by_id_and_user_id(
|
def update_folder_is_expanded_by_id_and_user_id(
|
||||||
@@ -289,12 +280,10 @@ class FolderTable:
|
|||||||
|
|
||||||
return FolderModel.model_validate(folder)
|
return FolderModel.model_validate(folder)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"update_folder: {e}")
|
log.error(f'update_folder: {e}')
|
||||||
return
|
return
|
||||||
|
|
||||||
def delete_folder_by_id_and_user_id(
|
def delete_folder_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> list[str]:
|
||||||
self, id: str, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> list[str]:
|
|
||||||
try:
|
try:
|
||||||
folder_ids = []
|
folder_ids = []
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
@@ -306,11 +295,8 @@ class FolderTable:
|
|||||||
|
|
||||||
# Delete all children folders
|
# Delete all children folders
|
||||||
def delete_children(folder):
|
def delete_children(folder):
|
||||||
folder_children = self.get_folders_by_parent_id_and_user_id(
|
folder_children = self.get_folders_by_parent_id_and_user_id(folder.id, user_id, db=db)
|
||||||
folder.id, user_id, db=db
|
|
||||||
)
|
|
||||||
for folder_child in folder_children:
|
for folder_child in folder_children:
|
||||||
|
|
||||||
delete_children(folder_child)
|
delete_children(folder_child)
|
||||||
folder_ids.append(folder_child.id)
|
folder_ids.append(folder_child.id)
|
||||||
|
|
||||||
@@ -323,12 +309,12 @@ class FolderTable:
|
|||||||
db.commit()
|
db.commit()
|
||||||
return folder_ids
|
return folder_ids
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"delete_folder: {e}")
|
log.error(f'delete_folder: {e}')
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def normalize_folder_name(self, name: str) -> str:
|
def normalize_folder_name(self, name: str) -> str:
|
||||||
# Replace _ and space with a single space, lower case, collapse multiple spaces
|
# Replace _ and space with a single space, lower case, collapse multiple spaces
|
||||||
name = re.sub(r"[\s_]+", " ", name)
|
name = re.sub(r'[\s_]+', ' ', name)
|
||||||
return name.strip().lower()
|
return name.strip().lower()
|
||||||
|
|
||||||
def search_folders_by_names(
|
def search_folders_by_names(
|
||||||
@@ -349,9 +335,7 @@ class FolderTable:
|
|||||||
results[folder.id] = FolderModel.model_validate(folder)
|
results[folder.id] = FolderModel.model_validate(folder)
|
||||||
|
|
||||||
# get children folders
|
# get children folders
|
||||||
children = self.get_children_folders_by_id_and_user_id(
|
children = self.get_children_folders_by_id_and_user_id(folder.id, user_id, db=db)
|
||||||
folder.id, user_id, db=db
|
|
||||||
)
|
|
||||||
for child in children:
|
for child in children:
|
||||||
results[child.id] = child
|
results[child.id] = child
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class Function(Base):
|
class Function(Base):
|
||||||
__tablename__ = "function"
|
__tablename__ = 'function'
|
||||||
|
|
||||||
id = Column(String, primary_key=True, unique=True)
|
id = Column(String, primary_key=True, unique=True)
|
||||||
user_id = Column(String)
|
user_id = Column(String)
|
||||||
@@ -30,13 +30,13 @@ class Function(Base):
|
|||||||
updated_at = Column(BigInteger)
|
updated_at = Column(BigInteger)
|
||||||
created_at = Column(BigInteger)
|
created_at = Column(BigInteger)
|
||||||
|
|
||||||
__table_args__ = (Index("is_global_idx", "is_global"),)
|
__table_args__ = (Index('is_global_idx', 'is_global'),)
|
||||||
|
|
||||||
|
|
||||||
class FunctionMeta(BaseModel):
|
class FunctionMeta(BaseModel):
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
manifest: Optional[dict] = {}
|
manifest: Optional[dict] = {}
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra='allow')
|
||||||
|
|
||||||
|
|
||||||
class FunctionModel(BaseModel):
|
class FunctionModel(BaseModel):
|
||||||
@@ -113,10 +113,10 @@ class FunctionsTable:
|
|||||||
function = FunctionModel(
|
function = FunctionModel(
|
||||||
**{
|
**{
|
||||||
**form_data.model_dump(),
|
**form_data.model_dump(),
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"type": type,
|
'type': type,
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
"created_at": int(time.time()),
|
'created_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -131,7 +131,7 @@ class FunctionsTable:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error creating a new function: {e}")
|
log.exception(f'Error creating a new function: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def sync_functions(
|
def sync_functions(
|
||||||
@@ -156,16 +156,16 @@ class FunctionsTable:
|
|||||||
db.query(Function).filter_by(id=func.id).update(
|
db.query(Function).filter_by(id=func.id).update(
|
||||||
{
|
{
|
||||||
**func.model_dump(),
|
**func.model_dump(),
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
new_func = Function(
|
new_func = Function(
|
||||||
**{
|
**{
|
||||||
**func.model_dump(),
|
**func.model_dump(),
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
db.add(new_func)
|
db.add(new_func)
|
||||||
@@ -177,17 +177,12 @@ class FunctionsTable:
|
|||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
return [
|
return [FunctionModel.model_validate(func) for func in db.query(Function).all()]
|
||||||
FunctionModel.model_validate(func)
|
|
||||||
for func in db.query(Function).all()
|
|
||||||
]
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error syncing functions for user {user_id}: {e}")
|
log.exception(f'Error syncing functions for user {user_id}: {e}')
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_function_by_id(
|
def get_function_by_id(self, id: str, db: Optional[Session] = None) -> Optional[FunctionModel]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[FunctionModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
function = db.get(Function, id)
|
function = db.get(Function, id)
|
||||||
@@ -195,9 +190,7 @@ class FunctionsTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_functions_by_ids(
|
def get_functions_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[FunctionModel]:
|
||||||
self, ids: list[str], db: Optional[Session] = None
|
|
||||||
) -> list[FunctionModel]:
|
|
||||||
"""
|
"""
|
||||||
Batch fetch multiple functions by their IDs in a single query.
|
Batch fetch multiple functions by their IDs in a single query.
|
||||||
Returns functions in the same order as the input IDs (None entries filtered out).
|
Returns functions in the same order as the input IDs (None entries filtered out).
|
||||||
@@ -225,18 +218,11 @@ class FunctionsTable:
|
|||||||
functions = db.query(Function).all()
|
functions = db.query(Function).all()
|
||||||
|
|
||||||
if include_valves:
|
if include_valves:
|
||||||
return [
|
return [FunctionWithValvesModel.model_validate(function) for function in functions]
|
||||||
FunctionWithValvesModel.model_validate(function)
|
|
||||||
for function in functions
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
return [
|
return [FunctionModel.model_validate(function) for function in functions]
|
||||||
FunctionModel.model_validate(function) for function in functions
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_function_list(
|
def get_function_list(self, db: Optional[Session] = None) -> list[FunctionUserResponse]:
|
||||||
self, db: Optional[Session] = None
|
|
||||||
) -> list[FunctionUserResponse]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
functions = db.query(Function).order_by(Function.updated_at.desc()).all()
|
functions = db.query(Function).order_by(Function.updated_at.desc()).all()
|
||||||
user_ids = list(set(func.user_id for func in functions))
|
user_ids = list(set(func.user_id for func in functions))
|
||||||
@@ -248,69 +234,48 @@ class FunctionsTable:
|
|||||||
FunctionUserResponse.model_validate(
|
FunctionUserResponse.model_validate(
|
||||||
{
|
{
|
||||||
**FunctionModel.model_validate(func).model_dump(),
|
**FunctionModel.model_validate(func).model_dump(),
|
||||||
"user": (
|
'user': (users_dict.get(func.user_id).model_dump() if func.user_id in users_dict else None),
|
||||||
users_dict.get(func.user_id).model_dump()
|
|
||||||
if func.user_id in users_dict
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
for func in functions
|
for func in functions
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_functions_by_type(
|
def get_functions_by_type(self, type: str, active_only=False, db: Optional[Session] = None) -> list[FunctionModel]:
|
||||||
self, type: str, active_only=False, db: Optional[Session] = None
|
|
||||||
) -> list[FunctionModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
if active_only:
|
if active_only:
|
||||||
return [
|
return [
|
||||||
FunctionModel.model_validate(function)
|
FunctionModel.model_validate(function)
|
||||||
for function in db.query(Function)
|
for function in db.query(Function).filter_by(type=type, is_active=True).all()
|
||||||
.filter_by(type=type, is_active=True)
|
|
||||||
.all()
|
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
return [
|
return [
|
||||||
FunctionModel.model_validate(function)
|
FunctionModel.model_validate(function) for function in db.query(Function).filter_by(type=type).all()
|
||||||
for function in db.query(Function).filter_by(type=type).all()
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_global_filter_functions(
|
def get_global_filter_functions(self, db: Optional[Session] = None) -> list[FunctionModel]:
|
||||||
self, db: Optional[Session] = None
|
|
||||||
) -> list[FunctionModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
return [
|
return [
|
||||||
FunctionModel.model_validate(function)
|
FunctionModel.model_validate(function)
|
||||||
for function in db.query(Function)
|
for function in db.query(Function).filter_by(type='filter', is_active=True, is_global=True).all()
|
||||||
.filter_by(type="filter", is_active=True, is_global=True)
|
|
||||||
.all()
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_global_action_functions(
|
def get_global_action_functions(self, db: Optional[Session] = None) -> list[FunctionModel]:
|
||||||
self, db: Optional[Session] = None
|
|
||||||
) -> list[FunctionModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
return [
|
return [
|
||||||
FunctionModel.model_validate(function)
|
FunctionModel.model_validate(function)
|
||||||
for function in db.query(Function)
|
for function in db.query(Function).filter_by(type='action', is_active=True, is_global=True).all()
|
||||||
.filter_by(type="action", is_active=True, is_global=True)
|
|
||||||
.all()
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_function_valves_by_id(
|
def get_function_valves_by_id(self, id: str, db: Optional[Session] = None) -> Optional[dict]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[dict]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
function = db.get(Function, id)
|
function = db.get(Function, id)
|
||||||
return function.valves if function.valves else {}
|
return function.valves if function.valves else {}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error getting function valves by id {id}: {e}")
|
log.exception(f'Error getting function valves by id {id}: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_function_valves_by_ids(
|
def get_function_valves_by_ids(self, ids: list[str], db: Optional[Session] = None) -> dict[str, dict]:
|
||||||
self, ids: list[str], db: Optional[Session] = None
|
|
||||||
) -> dict[str, dict]:
|
|
||||||
"""
|
"""
|
||||||
Batch fetch valves for multiple functions in a single query.
|
Batch fetch valves for multiple functions in a single query.
|
||||||
Returns a dict mapping function_id -> valves dict.
|
Returns a dict mapping function_id -> valves dict.
|
||||||
@@ -320,14 +285,10 @@ class FunctionsTable:
|
|||||||
return {}
|
return {}
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
functions = (
|
functions = db.query(Function.id, Function.valves).filter(Function.id.in_(ids)).all()
|
||||||
db.query(Function.id, Function.valves)
|
|
||||||
.filter(Function.id.in_(ids))
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return {f.id: (f.valves if f.valves else {}) for f in functions}
|
return {f.id: (f.valves if f.valves else {}) for f in functions}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error batch-fetching function valves: {e}")
|
log.exception(f'Error batch-fetching function valves: {e}')
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def update_function_valves_by_id(
|
def update_function_valves_by_id(
|
||||||
@@ -364,25 +325,23 @@ class FunctionsTable:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error updating function metadata by id {id}: {e}")
|
log.exception(f'Error updating function metadata by id {id}: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_user_valves_by_id_and_user_id(
|
def get_user_valves_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[dict]:
|
||||||
self, id: str, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[dict]:
|
|
||||||
try:
|
try:
|
||||||
user = Users.get_user_by_id(user_id, db=db)
|
user = Users.get_user_by_id(user_id, db=db)
|
||||||
user_settings = user.settings.model_dump() if user.settings else {}
|
user_settings = user.settings.model_dump() if user.settings else {}
|
||||||
|
|
||||||
# Check if user has "functions" and "valves" settings
|
# Check if user has "functions" and "valves" settings
|
||||||
if "functions" not in user_settings:
|
if 'functions' not in user_settings:
|
||||||
user_settings["functions"] = {}
|
user_settings['functions'] = {}
|
||||||
if "valves" not in user_settings["functions"]:
|
if 'valves' not in user_settings['functions']:
|
||||||
user_settings["functions"]["valves"] = {}
|
user_settings['functions']['valves'] = {}
|
||||||
|
|
||||||
return user_settings["functions"]["valves"].get(id, {})
|
return user_settings['functions']['valves'].get(id, {})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error getting user values by id {id} and user id {user_id}")
|
log.exception(f'Error getting user values by id {id} and user id {user_id}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_user_valves_by_id_and_user_id(
|
def update_user_valves_by_id_and_user_id(
|
||||||
@@ -393,32 +352,28 @@ class FunctionsTable:
|
|||||||
user_settings = user.settings.model_dump() if user.settings else {}
|
user_settings = user.settings.model_dump() if user.settings else {}
|
||||||
|
|
||||||
# Check if user has "functions" and "valves" settings
|
# Check if user has "functions" and "valves" settings
|
||||||
if "functions" not in user_settings:
|
if 'functions' not in user_settings:
|
||||||
user_settings["functions"] = {}
|
user_settings['functions'] = {}
|
||||||
if "valves" not in user_settings["functions"]:
|
if 'valves' not in user_settings['functions']:
|
||||||
user_settings["functions"]["valves"] = {}
|
user_settings['functions']['valves'] = {}
|
||||||
|
|
||||||
user_settings["functions"]["valves"][id] = valves
|
user_settings['functions']['valves'][id] = valves
|
||||||
|
|
||||||
# Update the user settings in the database
|
# Update the user settings in the database
|
||||||
Users.update_user_by_id(user_id, {"settings": user_settings}, db=db)
|
Users.update_user_by_id(user_id, {'settings': user_settings}, db=db)
|
||||||
|
|
||||||
return user_settings["functions"]["valves"][id]
|
return user_settings['functions']['valves'][id]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(
|
log.exception(f'Error updating user valves by id {id} and user_id {user_id}: {e}')
|
||||||
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_function_by_id(
|
def update_function_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[FunctionModel]:
|
||||||
self, id: str, updated: dict, db: Optional[Session] = None
|
|
||||||
) -> Optional[FunctionModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
db.query(Function).filter_by(id=id).update(
|
db.query(Function).filter_by(id=id).update(
|
||||||
{
|
{
|
||||||
**updated,
|
**updated,
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -432,8 +387,8 @@ class FunctionsTable:
|
|||||||
try:
|
try:
|
||||||
db.query(Function).update(
|
db.query(Function).update(
|
||||||
{
|
{
|
||||||
"is_active": False,
|
'is_active': False,
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class Group(Base):
|
class Group(Base):
|
||||||
__tablename__ = "group"
|
__tablename__ = 'group'
|
||||||
|
|
||||||
id = Column(Text, unique=True, primary_key=True)
|
id = Column(Text, unique=True, primary_key=True)
|
||||||
user_id = Column(Text)
|
user_id = Column(Text)
|
||||||
@@ -70,12 +70,12 @@ class GroupModel(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class GroupMember(Base):
|
class GroupMember(Base):
|
||||||
__tablename__ = "group_member"
|
__tablename__ = 'group_member'
|
||||||
|
|
||||||
id = Column(Text, unique=True, primary_key=True)
|
id = Column(Text, unique=True, primary_key=True)
|
||||||
group_id = Column(
|
group_id = Column(
|
||||||
Text,
|
Text,
|
||||||
ForeignKey("group.id", ondelete="CASCADE"),
|
ForeignKey('group.id', ondelete='CASCADE'),
|
||||||
nullable=False,
|
nullable=False,
|
||||||
)
|
)
|
||||||
user_id = Column(Text, nullable=False)
|
user_id = Column(Text, nullable=False)
|
||||||
@@ -133,28 +133,26 @@ class GroupListResponse(BaseModel):
|
|||||||
class GroupTable:
|
class GroupTable:
|
||||||
def _ensure_default_share_config(self, group_data: dict) -> dict:
|
def _ensure_default_share_config(self, group_data: dict) -> dict:
|
||||||
"""Ensure the group data dict has a default share config if not already set."""
|
"""Ensure the group data dict has a default share config if not already set."""
|
||||||
if "data" not in group_data or group_data["data"] is None:
|
if 'data' not in group_data or group_data['data'] is None:
|
||||||
group_data["data"] = {}
|
group_data['data'] = {}
|
||||||
if "config" not in group_data["data"]:
|
if 'config' not in group_data['data']:
|
||||||
group_data["data"]["config"] = {}
|
group_data['data']['config'] = {}
|
||||||
if "share" not in group_data["data"]["config"]:
|
if 'share' not in group_data['data']['config']:
|
||||||
group_data["data"]["config"]["share"] = DEFAULT_GROUP_SHARE_PERMISSION
|
group_data['data']['config']['share'] = DEFAULT_GROUP_SHARE_PERMISSION
|
||||||
return group_data
|
return group_data
|
||||||
|
|
||||||
def insert_new_group(
|
def insert_new_group(
|
||||||
self, user_id: str, form_data: GroupForm, db: Optional[Session] = None
|
self, user_id: str, form_data: GroupForm, db: Optional[Session] = None
|
||||||
) -> Optional[GroupModel]:
|
) -> Optional[GroupModel]:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
group_data = self._ensure_default_share_config(
|
group_data = self._ensure_default_share_config(form_data.model_dump(exclude_none=True))
|
||||||
form_data.model_dump(exclude_none=True)
|
|
||||||
)
|
|
||||||
group = GroupModel(
|
group = GroupModel(
|
||||||
**{
|
**{
|
||||||
**group_data,
|
**group_data,
|
||||||
"id": str(uuid.uuid4()),
|
'id': str(uuid.uuid4()),
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"created_at": int(time.time()),
|
'created_at': int(time.time()),
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -183,19 +181,19 @@ class GroupTable:
|
|||||||
.where(GroupMember.group_id == Group.id)
|
.where(GroupMember.group_id == Group.id)
|
||||||
.correlate(Group)
|
.correlate(Group)
|
||||||
.scalar_subquery()
|
.scalar_subquery()
|
||||||
.label("member_count")
|
.label('member_count')
|
||||||
)
|
)
|
||||||
query = db.query(Group, member_count)
|
query = db.query(Group, member_count)
|
||||||
|
|
||||||
if filter:
|
if filter:
|
||||||
if "query" in filter:
|
if 'query' in filter:
|
||||||
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
|
query = query.filter(Group.name.ilike(f'%{filter["query"]}%'))
|
||||||
|
|
||||||
# When share filter is present, member check is handled in the share logic
|
# When share filter is present, member check is handled in the share logic
|
||||||
if "share" in filter:
|
if 'share' in filter:
|
||||||
share_value = filter["share"]
|
share_value = filter['share']
|
||||||
member_id = filter.get("member_id")
|
member_id = filter.get('member_id')
|
||||||
json_share = Group.data["config"]["share"]
|
json_share = Group.data['config']['share']
|
||||||
json_share_str = json_share.as_string()
|
json_share_str = json_share.as_string()
|
||||||
json_share_lower = func.lower(json_share_str)
|
json_share_lower = func.lower(json_share_str)
|
||||||
|
|
||||||
@@ -203,37 +201,27 @@ class GroupTable:
|
|||||||
anyone_can_share = or_(
|
anyone_can_share = or_(
|
||||||
Group.data.is_(None),
|
Group.data.is_(None),
|
||||||
json_share_str.is_(None),
|
json_share_str.is_(None),
|
||||||
json_share_lower == "true",
|
json_share_lower == 'true',
|
||||||
json_share_lower == "1", # Handle SQLite boolean true
|
json_share_lower == '1', # Handle SQLite boolean true
|
||||||
)
|
)
|
||||||
|
|
||||||
if member_id:
|
if member_id:
|
||||||
member_groups_select = select(GroupMember.group_id).where(
|
member_groups_select = select(GroupMember.group_id).where(GroupMember.user_id == member_id)
|
||||||
GroupMember.user_id == member_id
|
|
||||||
)
|
|
||||||
members_only_and_is_member = and_(
|
members_only_and_is_member = and_(
|
||||||
json_share_lower == "members",
|
json_share_lower == 'members',
|
||||||
Group.id.in_(member_groups_select),
|
Group.id.in_(member_groups_select),
|
||||||
)
|
)
|
||||||
query = query.filter(
|
query = query.filter(or_(anyone_can_share, members_only_and_is_member))
|
||||||
or_(anyone_can_share, members_only_and_is_member)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
query = query.filter(anyone_can_share)
|
query = query.filter(anyone_can_share)
|
||||||
else:
|
else:
|
||||||
query = query.filter(
|
query = query.filter(and_(Group.data.isnot(None), json_share_lower == 'false'))
|
||||||
and_(Group.data.isnot(None), json_share_lower == "false")
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Only apply member_id filter when share filter is NOT present
|
# Only apply member_id filter when share filter is NOT present
|
||||||
if "member_id" in filter:
|
if 'member_id' in filter:
|
||||||
query = query.filter(
|
query = query.filter(
|
||||||
Group.id.in_(
|
Group.id.in_(select(GroupMember.group_id).where(GroupMember.user_id == filter['member_id']))
|
||||||
select(GroupMember.group_id).where(
|
|
||||||
GroupMember.user_id == filter["member_id"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
results = query.order_by(Group.updated_at.desc()).all()
|
results = query.order_by(Group.updated_at.desc()).all()
|
||||||
@@ -242,7 +230,7 @@ class GroupTable:
|
|||||||
GroupResponse.model_validate(
|
GroupResponse.model_validate(
|
||||||
{
|
{
|
||||||
**GroupModel.model_validate(group).model_dump(),
|
**GroupModel.model_validate(group).model_dump(),
|
||||||
"member_count": count or 0,
|
'member_count': count or 0,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
for group, count in results
|
for group, count in results
|
||||||
@@ -259,22 +247,16 @@ class GroupTable:
|
|||||||
query = db.query(Group)
|
query = db.query(Group)
|
||||||
|
|
||||||
if filter:
|
if filter:
|
||||||
if "query" in filter:
|
if 'query' in filter:
|
||||||
query = query.filter(Group.name.ilike(f"%{filter['query']}%"))
|
query = query.filter(Group.name.ilike(f'%{filter["query"]}%'))
|
||||||
if "member_id" in filter:
|
if 'member_id' in filter:
|
||||||
query = query.filter(
|
query = query.filter(
|
||||||
Group.id.in_(
|
Group.id.in_(select(GroupMember.group_id).where(GroupMember.user_id == filter['member_id']))
|
||||||
select(GroupMember.group_id).where(
|
|
||||||
GroupMember.user_id == filter["member_id"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if "share" in filter:
|
if 'share' in filter:
|
||||||
share_value = filter["share"]
|
share_value = filter['share']
|
||||||
query = query.filter(
|
query = query.filter(Group.data.op('->>')('share') == str(share_value))
|
||||||
Group.data.op("->>")("share") == str(share_value)
|
|
||||||
)
|
|
||||||
|
|
||||||
total = query.count()
|
total = query.count()
|
||||||
|
|
||||||
@@ -283,32 +265,24 @@ class GroupTable:
|
|||||||
.where(GroupMember.group_id == Group.id)
|
.where(GroupMember.group_id == Group.id)
|
||||||
.correlate(Group)
|
.correlate(Group)
|
||||||
.scalar_subquery()
|
.scalar_subquery()
|
||||||
.label("member_count")
|
.label('member_count')
|
||||||
)
|
|
||||||
results = (
|
|
||||||
query.add_columns(member_count)
|
|
||||||
.order_by(Group.updated_at.desc())
|
|
||||||
.offset(skip)
|
|
||||||
.limit(limit)
|
|
||||||
.all()
|
|
||||||
)
|
)
|
||||||
|
results = query.add_columns(member_count).order_by(Group.updated_at.desc()).offset(skip).limit(limit).all()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"items": [
|
'items': [
|
||||||
GroupResponse.model_validate(
|
GroupResponse.model_validate(
|
||||||
{
|
{
|
||||||
**GroupModel.model_validate(group).model_dump(),
|
**GroupModel.model_validate(group).model_dump(),
|
||||||
"member_count": count or 0,
|
'member_count': count or 0,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
for group, count in results
|
for group, count in results
|
||||||
],
|
],
|
||||||
"total": total,
|
'total': total,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_groups_by_member_id(
|
def get_groups_by_member_id(self, user_id: str, db: Optional[Session] = None) -> list[GroupModel]:
|
||||||
self, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> list[GroupModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
return [
|
return [
|
||||||
GroupModel.model_validate(group)
|
GroupModel.model_validate(group)
|
||||||
@@ -340,9 +314,7 @@ class GroupTable:
|
|||||||
|
|
||||||
return user_groups
|
return user_groups
|
||||||
|
|
||||||
def get_group_by_id(
|
def get_group_by_id(self, id: str, db: Optional[Session] = None) -> Optional[GroupModel]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[GroupModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
group = db.query(Group).filter_by(id=id).first()
|
group = db.query(Group).filter_by(id=id).first()
|
||||||
@@ -350,41 +322,29 @@ class GroupTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_group_user_ids_by_id(
|
def get_group_user_ids_by_id(self, id: str, db: Optional[Session] = None) -> list[str]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> list[str]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
members = (
|
members = db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all()
|
||||||
db.query(GroupMember.user_id).filter(GroupMember.group_id == id).all()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not members:
|
if not members:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return [m[0] for m in members]
|
return [m[0] for m in members]
|
||||||
|
|
||||||
def get_group_user_ids_by_ids(
|
def get_group_user_ids_by_ids(self, group_ids: list[str], db: Optional[Session] = None) -> dict[str, list[str]]:
|
||||||
self, group_ids: list[str], db: Optional[Session] = None
|
|
||||||
) -> dict[str, list[str]]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
members = (
|
members = (
|
||||||
db.query(GroupMember.group_id, GroupMember.user_id)
|
db.query(GroupMember.group_id, GroupMember.user_id).filter(GroupMember.group_id.in_(group_ids)).all()
|
||||||
.filter(GroupMember.group_id.in_(group_ids))
|
|
||||||
.all()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
group_user_ids: dict[str, list[str]] = {
|
group_user_ids: dict[str, list[str]] = {group_id: [] for group_id in group_ids}
|
||||||
group_id: [] for group_id in group_ids
|
|
||||||
}
|
|
||||||
|
|
||||||
for group_id, user_id in members:
|
for group_id, user_id in members:
|
||||||
group_user_ids[group_id].append(user_id)
|
group_user_ids[group_id].append(user_id)
|
||||||
|
|
||||||
return group_user_ids
|
return group_user_ids
|
||||||
|
|
||||||
def set_group_user_ids_by_id(
|
def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str], db: Optional[Session] = None) -> None:
|
||||||
self, group_id: str, user_ids: list[str], db: Optional[Session] = None
|
|
||||||
) -> None:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
# Delete existing members
|
# Delete existing members
|
||||||
db.query(GroupMember).filter(GroupMember.group_id == group_id).delete()
|
db.query(GroupMember).filter(GroupMember.group_id == group_id).delete()
|
||||||
@@ -405,20 +365,12 @@ class GroupTable:
|
|||||||
db.add_all(new_members)
|
db.add_all(new_members)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
def get_group_member_count_by_id(
|
def get_group_member_count_by_id(self, id: str, db: Optional[Session] = None) -> int:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> int:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
count = (
|
count = db.query(func.count(GroupMember.user_id)).filter(GroupMember.group_id == id).scalar()
|
||||||
db.query(func.count(GroupMember.user_id))
|
|
||||||
.filter(GroupMember.group_id == id)
|
|
||||||
.scalar()
|
|
||||||
)
|
|
||||||
return count if count else 0
|
return count if count else 0
|
||||||
|
|
||||||
def get_group_member_counts_by_ids(
|
def get_group_member_counts_by_ids(self, ids: list[str], db: Optional[Session] = None) -> dict[str, int]:
|
||||||
self, ids: list[str], db: Optional[Session] = None
|
|
||||||
) -> dict[str, int]:
|
|
||||||
if not ids:
|
if not ids:
|
||||||
return {}
|
return {}
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
@@ -442,7 +394,7 @@ class GroupTable:
|
|||||||
db.query(Group).filter_by(id=id).update(
|
db.query(Group).filter_by(id=id).update(
|
||||||
{
|
{
|
||||||
**form_data.model_dump(exclude_none=True),
|
**form_data.model_dump(exclude_none=True),
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -470,9 +422,7 @@ class GroupTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def remove_user_from_all_groups(
|
def remove_user_from_all_groups(self, user_id: str, db: Optional[Session] = None) -> bool:
|
||||||
self, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
# Find all groups the user belongs to
|
# Find all groups the user belongs to
|
||||||
@@ -489,9 +439,7 @@ class GroupTable:
|
|||||||
GroupMember.group_id == group.id, GroupMember.user_id == user_id
|
GroupMember.group_id == group.id, GroupMember.user_id == user_id
|
||||||
).delete()
|
).delete()
|
||||||
|
|
||||||
db.query(Group).filter_by(id=group.id).update(
|
db.query(Group).filter_by(id=group.id).update({'updated_at': int(time.time())})
|
||||||
{"updated_at": int(time.time())}
|
|
||||||
)
|
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
@@ -503,7 +451,6 @@ class GroupTable:
|
|||||||
def create_groups_by_group_names(
|
def create_groups_by_group_names(
|
||||||
self, user_id: str, group_names: list[str], db: Optional[Session] = None
|
self, user_id: str, group_names: list[str], db: Optional[Session] = None
|
||||||
) -> list[GroupModel]:
|
) -> list[GroupModel]:
|
||||||
|
|
||||||
# check for existing groups
|
# check for existing groups
|
||||||
existing_groups = self.get_all_groups(db=db)
|
existing_groups = self.get_all_groups(db=db)
|
||||||
existing_group_names = {group.name for group in existing_groups}
|
existing_group_names = {group.name for group in existing_groups}
|
||||||
@@ -517,10 +464,10 @@ class GroupTable:
|
|||||||
id=str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
name=group_name,
|
name=group_name,
|
||||||
description="",
|
description='',
|
||||||
data={
|
data={
|
||||||
"config": {
|
'config': {
|
||||||
"share": DEFAULT_GROUP_SHARE_PERMISSION,
|
'share': DEFAULT_GROUP_SHARE_PERMISSION,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
created_at=int(time.time()),
|
created_at=int(time.time()),
|
||||||
@@ -537,17 +484,13 @@ class GroupTable:
|
|||||||
continue
|
continue
|
||||||
return new_groups
|
return new_groups
|
||||||
|
|
||||||
def sync_groups_by_group_names(
|
def sync_groups_by_group_names(self, user_id: str, group_names: list[str], db: Optional[Session] = None) -> bool:
|
||||||
self, user_id: str, group_names: list[str], db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
|
|
||||||
# 1. Groups that SHOULD contain the user
|
# 1. Groups that SHOULD contain the user
|
||||||
target_groups = (
|
target_groups = db.query(Group).filter(Group.name.in_(group_names)).all()
|
||||||
db.query(Group).filter(Group.name.in_(group_names)).all()
|
|
||||||
)
|
|
||||||
target_group_ids = {g.id for g in target_groups}
|
target_group_ids = {g.id for g in target_groups}
|
||||||
|
|
||||||
# 2. Groups the user is CURRENTLY in
|
# 2. Groups the user is CURRENTLY in
|
||||||
@@ -571,7 +514,7 @@ class GroupTable:
|
|||||||
).delete(synchronize_session=False)
|
).delete(synchronize_session=False)
|
||||||
|
|
||||||
db.query(Group).filter(Group.id.in_(groups_to_remove)).update(
|
db.query(Group).filter(Group.id.in_(groups_to_remove)).update(
|
||||||
{"updated_at": now}, synchronize_session=False
|
{'updated_at': now}, synchronize_session=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. Bulk insert missing memberships
|
# 5. Bulk insert missing memberships
|
||||||
@@ -588,7 +531,7 @@ class GroupTable:
|
|||||||
|
|
||||||
if groups_to_add:
|
if groups_to_add:
|
||||||
db.query(Group).filter(Group.id.in_(groups_to_add)).update(
|
db.query(Group).filter(Group.id.in_(groups_to_add)).update(
|
||||||
{"updated_at": now}, synchronize_session=False
|
{'updated_at': now}, synchronize_session=False
|
||||||
)
|
)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -656,9 +599,9 @@ class GroupTable:
|
|||||||
return GroupModel.model_validate(group)
|
return GroupModel.model_validate(group)
|
||||||
|
|
||||||
# Remove users from group_member in batch
|
# Remove users from group_member in batch
|
||||||
db.query(GroupMember).filter(
|
db.query(GroupMember).filter(GroupMember.group_id == id, GroupMember.user_id.in_(user_ids)).delete(
|
||||||
GroupMember.group_id == id, GroupMember.user_id.in_(user_ids)
|
synchronize_session=False
|
||||||
).delete(synchronize_session=False)
|
)
|
||||||
|
|
||||||
# Update group timestamp
|
# Update group timestamp
|
||||||
group.updated_at = int(time.time())
|
group.updated_at = int(time.time())
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class Knowledge(Base):
|
class Knowledge(Base):
|
||||||
__tablename__ = "knowledge"
|
__tablename__ = 'knowledge'
|
||||||
|
|
||||||
id = Column(Text, unique=True, primary_key=True)
|
id = Column(Text, unique=True, primary_key=True)
|
||||||
user_id = Column(Text)
|
user_id = Column(Text)
|
||||||
@@ -70,24 +70,18 @@ class KnowledgeModel(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class KnowledgeFile(Base):
|
class KnowledgeFile(Base):
|
||||||
__tablename__ = "knowledge_file"
|
__tablename__ = 'knowledge_file'
|
||||||
|
|
||||||
id = Column(Text, unique=True, primary_key=True)
|
id = Column(Text, unique=True, primary_key=True)
|
||||||
|
|
||||||
knowledge_id = Column(
|
knowledge_id = Column(Text, ForeignKey('knowledge.id', ondelete='CASCADE'), nullable=False)
|
||||||
Text, ForeignKey("knowledge.id", ondelete="CASCADE"), nullable=False
|
file_id = Column(Text, ForeignKey('file.id', ondelete='CASCADE'), nullable=False)
|
||||||
)
|
|
||||||
file_id = Column(Text, ForeignKey("file.id", ondelete="CASCADE"), nullable=False)
|
|
||||||
user_id = Column(Text, nullable=False)
|
user_id = Column(Text, nullable=False)
|
||||||
|
|
||||||
created_at = Column(BigInteger, nullable=False)
|
created_at = Column(BigInteger, nullable=False)
|
||||||
updated_at = Column(BigInteger, nullable=False)
|
updated_at = Column(BigInteger, nullable=False)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (UniqueConstraint('knowledge_id', 'file_id', name='uq_knowledge_file_knowledge_file'),)
|
||||||
UniqueConstraint(
|
|
||||||
"knowledge_id", "file_id", name="uq_knowledge_file_knowledge_file"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeFileModel(BaseModel):
|
class KnowledgeFileModel(BaseModel):
|
||||||
@@ -138,10 +132,8 @@ class KnowledgeFileListResponse(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class KnowledgeTable:
|
class KnowledgeTable:
|
||||||
def _get_access_grants(
|
def _get_access_grants(self, knowledge_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
|
||||||
self, knowledge_id: str, db: Optional[Session] = None
|
return AccessGrants.get_grants_by_resource('knowledge', knowledge_id, db=db)
|
||||||
) -> list[AccessGrantModel]:
|
|
||||||
return AccessGrants.get_grants_by_resource("knowledge", knowledge_id, db=db)
|
|
||||||
|
|
||||||
def _to_knowledge_model(
|
def _to_knowledge_model(
|
||||||
self,
|
self,
|
||||||
@@ -149,13 +141,9 @@ class KnowledgeTable:
|
|||||||
access_grants: Optional[list[AccessGrantModel]] = None,
|
access_grants: Optional[list[AccessGrantModel]] = None,
|
||||||
db: Optional[Session] = None,
|
db: Optional[Session] = None,
|
||||||
) -> KnowledgeModel:
|
) -> KnowledgeModel:
|
||||||
knowledge_data = KnowledgeModel.model_validate(knowledge).model_dump(
|
knowledge_data = KnowledgeModel.model_validate(knowledge).model_dump(exclude={'access_grants'})
|
||||||
exclude={"access_grants"}
|
knowledge_data['access_grants'] = (
|
||||||
)
|
access_grants if access_grants is not None else self._get_access_grants(knowledge_data['id'], db=db)
|
||||||
knowledge_data["access_grants"] = (
|
|
||||||
access_grants
|
|
||||||
if access_grants is not None
|
|
||||||
else self._get_access_grants(knowledge_data["id"], db=db)
|
|
||||||
)
|
)
|
||||||
return KnowledgeModel.model_validate(knowledge_data)
|
return KnowledgeModel.model_validate(knowledge_data)
|
||||||
|
|
||||||
@@ -165,23 +153,21 @@ class KnowledgeTable:
|
|||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
knowledge = KnowledgeModel(
|
knowledge = KnowledgeModel(
|
||||||
**{
|
**{
|
||||||
**form_data.model_dump(exclude={"access_grants"}),
|
**form_data.model_dump(exclude={'access_grants'}),
|
||||||
"id": str(uuid.uuid4()),
|
'id': str(uuid.uuid4()),
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"created_at": int(time.time()),
|
'created_at': int(time.time()),
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
"access_grants": [],
|
'access_grants': [],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = Knowledge(**knowledge.model_dump(exclude={"access_grants"}))
|
result = Knowledge(**knowledge.model_dump(exclude={'access_grants'}))
|
||||||
db.add(result)
|
db.add(result)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(result)
|
db.refresh(result)
|
||||||
AccessGrants.set_access_grants(
|
AccessGrants.set_access_grants('knowledge', result.id, form_data.access_grants, db=db)
|
||||||
"knowledge", result.id, form_data.access_grants, db=db
|
|
||||||
)
|
|
||||||
if result:
|
if result:
|
||||||
return self._to_knowledge_model(result, db=db)
|
return self._to_knowledge_model(result, db=db)
|
||||||
else:
|
else:
|
||||||
@@ -193,17 +179,13 @@ class KnowledgeTable:
|
|||||||
self, skip: int = 0, limit: int = 30, db: Optional[Session] = None
|
self, skip: int = 0, limit: int = 30, db: Optional[Session] = None
|
||||||
) -> list[KnowledgeUserModel]:
|
) -> list[KnowledgeUserModel]:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
all_knowledge = (
|
all_knowledge = db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
|
||||||
db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all()
|
|
||||||
)
|
|
||||||
user_ids = list(set(knowledge.user_id for knowledge in all_knowledge))
|
user_ids = list(set(knowledge.user_id for knowledge in all_knowledge))
|
||||||
knowledge_ids = [knowledge.id for knowledge in all_knowledge]
|
knowledge_ids = [knowledge.id for knowledge in all_knowledge]
|
||||||
|
|
||||||
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
||||||
users_dict = {user.id: user for user in users}
|
users_dict = {user.id: user for user in users}
|
||||||
grants_map = AccessGrants.get_grants_by_resources(
|
grants_map = AccessGrants.get_grants_by_resources('knowledge', knowledge_ids, db=db)
|
||||||
"knowledge", knowledge_ids, db=db
|
|
||||||
)
|
|
||||||
|
|
||||||
knowledge_bases = []
|
knowledge_bases = []
|
||||||
for knowledge in all_knowledge:
|
for knowledge in all_knowledge:
|
||||||
@@ -216,7 +198,7 @@ class KnowledgeTable:
|
|||||||
access_grants=grants_map.get(knowledge.id, []),
|
access_grants=grants_map.get(knowledge.id, []),
|
||||||
db=db,
|
db=db,
|
||||||
).model_dump(),
|
).model_dump(),
|
||||||
"user": user.model_dump() if user else None,
|
'user': user.model_dump() if user else None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -232,27 +214,25 @@ class KnowledgeTable:
|
|||||||
) -> KnowledgeListResponse:
|
) -> KnowledgeListResponse:
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
query = db.query(Knowledge, User).outerjoin(
|
query = db.query(Knowledge, User).outerjoin(User, User.id == Knowledge.user_id)
|
||||||
User, User.id == Knowledge.user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if filter:
|
if filter:
|
||||||
query_key = filter.get("query")
|
query_key = filter.get('query')
|
||||||
if query_key:
|
if query_key:
|
||||||
query = query.filter(
|
query = query.filter(
|
||||||
or_(
|
or_(
|
||||||
Knowledge.name.ilike(f"%{query_key}%"),
|
Knowledge.name.ilike(f'%{query_key}%'),
|
||||||
Knowledge.description.ilike(f"%{query_key}%"),
|
Knowledge.description.ilike(f'%{query_key}%'),
|
||||||
User.name.ilike(f"%{query_key}%"),
|
User.name.ilike(f'%{query_key}%'),
|
||||||
User.email.ilike(f"%{query_key}%"),
|
User.email.ilike(f'%{query_key}%'),
|
||||||
User.username.ilike(f"%{query_key}%"),
|
User.username.ilike(f'%{query_key}%'),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
view_option = filter.get("view_option")
|
view_option = filter.get('view_option')
|
||||||
if view_option == "created":
|
if view_option == 'created':
|
||||||
query = query.filter(Knowledge.user_id == user_id)
|
query = query.filter(Knowledge.user_id == user_id)
|
||||||
elif view_option == "shared":
|
elif view_option == 'shared':
|
||||||
query = query.filter(Knowledge.user_id != user_id)
|
query = query.filter(Knowledge.user_id != user_id)
|
||||||
|
|
||||||
query = AccessGrants.has_permission_filter(
|
query = AccessGrants.has_permission_filter(
|
||||||
@@ -260,8 +240,8 @@ class KnowledgeTable:
|
|||||||
query=query,
|
query=query,
|
||||||
DocumentModel=Knowledge,
|
DocumentModel=Knowledge,
|
||||||
filter=filter,
|
filter=filter,
|
||||||
resource_type="knowledge",
|
resource_type='knowledge',
|
||||||
permission="read",
|
permission='read',
|
||||||
)
|
)
|
||||||
|
|
||||||
query = query.order_by(Knowledge.updated_at.desc(), Knowledge.id.asc())
|
query = query.order_by(Knowledge.updated_at.desc(), Knowledge.id.asc())
|
||||||
@@ -275,9 +255,7 @@ class KnowledgeTable:
|
|||||||
items = query.all()
|
items = query.all()
|
||||||
|
|
||||||
knowledge_ids = [kb.id for kb, _ in items]
|
knowledge_ids = [kb.id for kb, _ in items]
|
||||||
grants_map = AccessGrants.get_grants_by_resources(
|
grants_map = AccessGrants.get_grants_by_resources('knowledge', knowledge_ids, db=db)
|
||||||
"knowledge", knowledge_ids, db=db
|
|
||||||
)
|
|
||||||
|
|
||||||
knowledge_bases = []
|
knowledge_bases = []
|
||||||
for knowledge_base, user in items:
|
for knowledge_base, user in items:
|
||||||
@@ -289,11 +267,7 @@ class KnowledgeTable:
|
|||||||
access_grants=grants_map.get(knowledge_base.id, []),
|
access_grants=grants_map.get(knowledge_base.id, []),
|
||||||
db=db,
|
db=db,
|
||||||
).model_dump(),
|
).model_dump(),
|
||||||
"user": (
|
'user': (UserModel.model_validate(user).model_dump() if user else None),
|
||||||
UserModel.model_validate(user).model_dump()
|
|
||||||
if user
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -327,15 +301,15 @@ class KnowledgeTable:
|
|||||||
query=query,
|
query=query,
|
||||||
DocumentModel=Knowledge,
|
DocumentModel=Knowledge,
|
||||||
filter=filter,
|
filter=filter,
|
||||||
resource_type="knowledge",
|
resource_type='knowledge',
|
||||||
permission="read",
|
permission='read',
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply filename search
|
# Apply filename search
|
||||||
if filter:
|
if filter:
|
||||||
q = filter.get("query")
|
q = filter.get('query')
|
||||||
if q:
|
if q:
|
||||||
query = query.filter(File.filename.ilike(f"%{q}%"))
|
query = query.filter(File.filename.ilike(f'%{q}%'))
|
||||||
|
|
||||||
# Order by file changes
|
# Order by file changes
|
||||||
query = query.order_by(File.updated_at.desc(), File.id.asc())
|
query = query.order_by(File.updated_at.desc(), File.id.asc())
|
||||||
@@ -355,39 +329,27 @@ class KnowledgeTable:
|
|||||||
items.append(
|
items.append(
|
||||||
FileUserResponse(
|
FileUserResponse(
|
||||||
**FileModel.model_validate(file).model_dump(),
|
**FileModel.model_validate(file).model_dump(),
|
||||||
user=(
|
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
|
||||||
UserResponse(
|
collection=self._to_knowledge_model(knowledge, db=db).model_dump(),
|
||||||
**UserModel.model_validate(user).model_dump()
|
|
||||||
)
|
|
||||||
if user
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
collection=self._to_knowledge_model(
|
|
||||||
knowledge, db=db
|
|
||||||
).model_dump(),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return KnowledgeFileListResponse(items=items, total=total)
|
return KnowledgeFileListResponse(items=items, total=total)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("search_knowledge_files error:", e)
|
print('search_knowledge_files error:', e)
|
||||||
return KnowledgeFileListResponse(items=[], total=0)
|
return KnowledgeFileListResponse(items=[], total=0)
|
||||||
|
|
||||||
def check_access_by_user_id(
|
def check_access_by_user_id(self, id, user_id, permission='write', db: Optional[Session] = None) -> bool:
|
||||||
self, id, user_id, permission="write", db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
knowledge = self.get_knowledge_by_id(id, db=db)
|
knowledge = self.get_knowledge_by_id(id, db=db)
|
||||||
if not knowledge:
|
if not knowledge:
|
||||||
return False
|
return False
|
||||||
if knowledge.user_id == user_id:
|
if knowledge.user_id == user_id:
|
||||||
return True
|
return True
|
||||||
user_group_ids = {
|
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
|
||||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
|
||||||
}
|
|
||||||
return AccessGrants.has_access(
|
return AccessGrants.has_access(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
resource_type="knowledge",
|
resource_type='knowledge',
|
||||||
resource_id=knowledge.id,
|
resource_id=knowledge.id,
|
||||||
permission=permission,
|
permission=permission,
|
||||||
user_group_ids=user_group_ids,
|
user_group_ids=user_group_ids,
|
||||||
@@ -395,19 +357,17 @@ class KnowledgeTable:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def get_knowledge_bases_by_user_id(
|
def get_knowledge_bases_by_user_id(
|
||||||
self, user_id: str, permission: str = "write", db: Optional[Session] = None
|
self, user_id: str, permission: str = 'write', db: Optional[Session] = None
|
||||||
) -> list[KnowledgeUserModel]:
|
) -> list[KnowledgeUserModel]:
|
||||||
knowledge_bases = self.get_knowledge_bases(db=db)
|
knowledge_bases = self.get_knowledge_bases(db=db)
|
||||||
user_group_ids = {
|
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
|
||||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
|
||||||
}
|
|
||||||
return [
|
return [
|
||||||
knowledge_base
|
knowledge_base
|
||||||
for knowledge_base in knowledge_bases
|
for knowledge_base in knowledge_bases
|
||||||
if knowledge_base.user_id == user_id
|
if knowledge_base.user_id == user_id
|
||||||
or AccessGrants.has_access(
|
or AccessGrants.has_access(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
resource_type="knowledge",
|
resource_type='knowledge',
|
||||||
resource_id=knowledge_base.id,
|
resource_id=knowledge_base.id,
|
||||||
permission=permission,
|
permission=permission,
|
||||||
user_group_ids=user_group_ids,
|
user_group_ids=user_group_ids,
|
||||||
@@ -415,9 +375,7 @@ class KnowledgeTable:
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_knowledge_by_id(
|
def get_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> Optional[KnowledgeModel]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[KnowledgeModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
knowledge = db.query(Knowledge).filter_by(id=id).first()
|
knowledge = db.query(Knowledge).filter_by(id=id).first()
|
||||||
@@ -435,23 +393,19 @@ class KnowledgeTable:
|
|||||||
if knowledge.user_id == user_id:
|
if knowledge.user_id == user_id:
|
||||||
return knowledge
|
return knowledge
|
||||||
|
|
||||||
user_group_ids = {
|
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
|
||||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
|
||||||
}
|
|
||||||
if AccessGrants.has_access(
|
if AccessGrants.has_access(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
resource_type="knowledge",
|
resource_type='knowledge',
|
||||||
resource_id=knowledge.id,
|
resource_id=knowledge.id,
|
||||||
permission="write",
|
permission='write',
|
||||||
user_group_ids=user_group_ids,
|
user_group_ids=user_group_ids,
|
||||||
db=db,
|
db=db,
|
||||||
):
|
):
|
||||||
return knowledge
|
return knowledge
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_knowledges_by_file_id(
|
def get_knowledges_by_file_id(self, file_id: str, db: Optional[Session] = None) -> list[KnowledgeModel]:
|
||||||
self, file_id: str, db: Optional[Session] = None
|
|
||||||
) -> list[KnowledgeModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
knowledges = (
|
knowledges = (
|
||||||
@@ -461,9 +415,7 @@ class KnowledgeTable:
|
|||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
knowledge_ids = [k.id for k in knowledges]
|
knowledge_ids = [k.id for k in knowledges]
|
||||||
grants_map = AccessGrants.get_grants_by_resources(
|
grants_map = AccessGrants.get_grants_by_resources('knowledge', knowledge_ids, db=db)
|
||||||
"knowledge", knowledge_ids, db=db
|
|
||||||
)
|
|
||||||
return [
|
return [
|
||||||
self._to_knowledge_model(
|
self._to_knowledge_model(
|
||||||
knowledge,
|
knowledge,
|
||||||
@@ -497,32 +449,26 @@ class KnowledgeTable:
|
|||||||
primary_sort = File.updated_at.desc()
|
primary_sort = File.updated_at.desc()
|
||||||
|
|
||||||
if filter:
|
if filter:
|
||||||
query_key = filter.get("query")
|
query_key = filter.get('query')
|
||||||
if query_key:
|
if query_key:
|
||||||
query = query.filter(or_(File.filename.ilike(f"%{query_key}%")))
|
query = query.filter(or_(File.filename.ilike(f'%{query_key}%')))
|
||||||
|
|
||||||
view_option = filter.get("view_option")
|
view_option = filter.get('view_option')
|
||||||
if view_option == "created":
|
if view_option == 'created':
|
||||||
query = query.filter(KnowledgeFile.user_id == user_id)
|
query = query.filter(KnowledgeFile.user_id == user_id)
|
||||||
elif view_option == "shared":
|
elif view_option == 'shared':
|
||||||
query = query.filter(KnowledgeFile.user_id != user_id)
|
query = query.filter(KnowledgeFile.user_id != user_id)
|
||||||
|
|
||||||
order_by = filter.get("order_by")
|
order_by = filter.get('order_by')
|
||||||
direction = filter.get("direction")
|
direction = filter.get('direction')
|
||||||
is_asc = direction == "asc"
|
is_asc = direction == 'asc'
|
||||||
|
|
||||||
if order_by == "name":
|
if order_by == 'name':
|
||||||
primary_sort = (
|
primary_sort = File.filename.asc() if is_asc else File.filename.desc()
|
||||||
File.filename.asc() if is_asc else File.filename.desc()
|
elif order_by == 'created_at':
|
||||||
)
|
primary_sort = File.created_at.asc() if is_asc else File.created_at.desc()
|
||||||
elif order_by == "created_at":
|
elif order_by == 'updated_at':
|
||||||
primary_sort = (
|
primary_sort = File.updated_at.asc() if is_asc else File.updated_at.desc()
|
||||||
File.created_at.asc() if is_asc else File.created_at.desc()
|
|
||||||
)
|
|
||||||
elif order_by == "updated_at":
|
|
||||||
primary_sort = (
|
|
||||||
File.updated_at.asc() if is_asc else File.updated_at.desc()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply sort with secondary key for deterministic pagination
|
# Apply sort with secondary key for deterministic pagination
|
||||||
query = query.order_by(primary_sort, File.id.asc())
|
query = query.order_by(primary_sort, File.id.asc())
|
||||||
@@ -542,13 +488,7 @@ class KnowledgeTable:
|
|||||||
files.append(
|
files.append(
|
||||||
FileUserResponse(
|
FileUserResponse(
|
||||||
**FileModel.model_validate(file).model_dump(),
|
**FileModel.model_validate(file).model_dump(),
|
||||||
user=(
|
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
|
||||||
UserResponse(
|
|
||||||
**UserModel.model_validate(user).model_dump()
|
|
||||||
)
|
|
||||||
if user
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -557,9 +497,7 @@ class KnowledgeTable:
|
|||||||
print(e)
|
print(e)
|
||||||
return KnowledgeFileListResponse(items=[], total=0)
|
return KnowledgeFileListResponse(items=[], total=0)
|
||||||
|
|
||||||
def get_files_by_id(
|
def get_files_by_id(self, knowledge_id: str, db: Optional[Session] = None) -> list[FileModel]:
|
||||||
self, knowledge_id: str, db: Optional[Session] = None
|
|
||||||
) -> list[FileModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
files = (
|
files = (
|
||||||
@@ -572,9 +510,7 @@ class KnowledgeTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_file_metadatas_by_id(
|
def get_file_metadatas_by_id(self, knowledge_id: str, db: Optional[Session] = None) -> list[FileMetadataResponse]:
|
||||||
self, knowledge_id: str, db: Optional[Session] = None
|
|
||||||
) -> list[FileMetadataResponse]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
files = self.get_files_by_id(knowledge_id, db=db)
|
files = self.get_files_by_id(knowledge_id, db=db)
|
||||||
@@ -592,12 +528,12 @@ class KnowledgeTable:
|
|||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
knowledge_file = KnowledgeFileModel(
|
knowledge_file = KnowledgeFileModel(
|
||||||
**{
|
**{
|
||||||
"id": str(uuid.uuid4()),
|
'id': str(uuid.uuid4()),
|
||||||
"knowledge_id": knowledge_id,
|
'knowledge_id': knowledge_id,
|
||||||
"file_id": file_id,
|
'file_id': file_id,
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"created_at": int(time.time()),
|
'created_at': int(time.time()),
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -613,37 +549,24 @@ class KnowledgeTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def has_file(
|
def has_file(self, knowledge_id: str, file_id: str, db: Optional[Session] = None) -> bool:
|
||||||
self, knowledge_id: str, file_id: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
"""Check whether a file belongs to a knowledge base."""
|
"""Check whether a file belongs to a knowledge base."""
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
return (
|
return db.query(KnowledgeFile).filter_by(knowledge_id=knowledge_id, file_id=file_id).first() is not None
|
||||||
db.query(KnowledgeFile)
|
|
||||||
.filter_by(knowledge_id=knowledge_id, file_id=file_id)
|
|
||||||
.first()
|
|
||||||
is not None
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def remove_file_from_knowledge_by_id(
|
def remove_file_from_knowledge_by_id(self, knowledge_id: str, file_id: str, db: Optional[Session] = None) -> bool:
|
||||||
self, knowledge_id: str, file_id: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
db.query(KnowledgeFile).filter_by(
|
db.query(KnowledgeFile).filter_by(knowledge_id=knowledge_id, file_id=file_id).delete()
|
||||||
knowledge_id=knowledge_id, file_id=file_id
|
|
||||||
).delete()
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def reset_knowledge_by_id(
|
def reset_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> Optional[KnowledgeModel]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[KnowledgeModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
# Delete all knowledge_file entries for this knowledge_id
|
# Delete all knowledge_file entries for this knowledge_id
|
||||||
@@ -653,7 +576,7 @@ class KnowledgeTable:
|
|||||||
# Update the knowledge entry's updated_at timestamp
|
# Update the knowledge entry's updated_at timestamp
|
||||||
db.query(Knowledge).filter_by(id=id).update(
|
db.query(Knowledge).filter_by(id=id).update(
|
||||||
{
|
{
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -675,15 +598,13 @@ class KnowledgeTable:
|
|||||||
knowledge = self.get_knowledge_by_id(id=id, db=db)
|
knowledge = self.get_knowledge_by_id(id=id, db=db)
|
||||||
db.query(Knowledge).filter_by(id=id).update(
|
db.query(Knowledge).filter_by(id=id).update(
|
||||||
{
|
{
|
||||||
**form_data.model_dump(exclude={"access_grants"}),
|
**form_data.model_dump(exclude={'access_grants'}),
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
if form_data.access_grants is not None:
|
if form_data.access_grants is not None:
|
||||||
AccessGrants.set_access_grants(
|
AccessGrants.set_access_grants('knowledge', id, form_data.access_grants, db=db)
|
||||||
"knowledge", id, form_data.access_grants, db=db
|
|
||||||
)
|
|
||||||
return self.get_knowledge_by_id(id=id, db=db)
|
return self.get_knowledge_by_id(id=id, db=db)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
@@ -697,8 +618,8 @@ class KnowledgeTable:
|
|||||||
knowledge = self.get_knowledge_by_id(id=id, db=db)
|
knowledge = self.get_knowledge_by_id(id=id, db=db)
|
||||||
db.query(Knowledge).filter_by(id=id).update(
|
db.query(Knowledge).filter_by(id=id).update(
|
||||||
{
|
{
|
||||||
"data": data,
|
'data': data,
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -710,7 +631,7 @@ class KnowledgeTable:
|
|||||||
def delete_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
def delete_knowledge_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
AccessGrants.revoke_all_access("knowledge", id, db=db)
|
AccessGrants.revoke_all_access('knowledge', id, db=db)
|
||||||
db.query(Knowledge).filter_by(id=id).delete()
|
db.query(Knowledge).filter_by(id=id).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
@@ -722,7 +643,7 @@ class KnowledgeTable:
|
|||||||
try:
|
try:
|
||||||
knowledge_ids = [row[0] for row in db.query(Knowledge.id).all()]
|
knowledge_ids = [row[0] for row in db.query(Knowledge.id).all()]
|
||||||
for knowledge_id in knowledge_ids:
|
for knowledge_id in knowledge_ids:
|
||||||
AccessGrants.revoke_all_access("knowledge", knowledge_id, db=db)
|
AccessGrants.revoke_all_access('knowledge', knowledge_id, db=db)
|
||||||
db.query(Knowledge).delete()
|
db.query(Knowledge).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from sqlalchemy import BigInteger, Column, String, Text
|
|||||||
|
|
||||||
|
|
||||||
class Memory(Base):
|
class Memory(Base):
|
||||||
__tablename__ = "memory"
|
__tablename__ = 'memory'
|
||||||
|
|
||||||
id = Column(String, primary_key=True, unique=True)
|
id = Column(String, primary_key=True, unique=True)
|
||||||
user_id = Column(String)
|
user_id = Column(String)
|
||||||
@@ -49,11 +49,11 @@ class MemoriesTable:
|
|||||||
|
|
||||||
memory = MemoryModel(
|
memory = MemoryModel(
|
||||||
**{
|
**{
|
||||||
"id": id,
|
'id': id,
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"content": content,
|
'content': content,
|
||||||
"created_at": int(time.time()),
|
'created_at': int(time.time()),
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
result = Memory(**memory.model_dump())
|
result = Memory(**memory.model_dump())
|
||||||
@@ -95,9 +95,7 @@ class MemoriesTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_memories_by_user_id(
|
def get_memories_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[MemoryModel]:
|
||||||
self, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> list[MemoryModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
memories = db.query(Memory).filter_by(user_id=user_id).all()
|
memories = db.query(Memory).filter_by(user_id=user_id).all()
|
||||||
@@ -105,9 +103,7 @@ class MemoriesTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_memory_by_id(
|
def get_memory_by_id(self, id: str, db: Optional[Session] = None) -> Optional[MemoryModel]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[MemoryModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
memory = db.get(Memory, id)
|
memory = db.get(Memory, id)
|
||||||
@@ -126,9 +122,7 @@ class MemoriesTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def delete_memories_by_user_id(
|
def delete_memories_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
|
||||||
self, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
db.query(Memory).filter_by(user_id=user_id).delete()
|
db.query(Memory).filter_by(user_id=user_id).delete()
|
||||||
@@ -138,9 +132,7 @@ class MemoriesTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def delete_memory_by_id_and_user_id(
|
def delete_memory_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||||
self, id: str, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
memory = db.get(Memory, id)
|
memory = db.get(Memory, id)
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from sqlalchemy.sql import exists
|
|||||||
|
|
||||||
|
|
||||||
class MessageReaction(Base):
|
class MessageReaction(Base):
|
||||||
__tablename__ = "message_reaction"
|
__tablename__ = 'message_reaction'
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
id = Column(Text, primary_key=True, unique=True)
|
||||||
user_id = Column(Text)
|
user_id = Column(Text)
|
||||||
message_id = Column(Text)
|
message_id = Column(Text)
|
||||||
@@ -40,7 +40,7 @@ class MessageReactionModel(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Message(Base):
|
class Message(Base):
|
||||||
__tablename__ = "message"
|
__tablename__ = 'message'
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
id = Column(Text, primary_key=True, unique=True)
|
||||||
|
|
||||||
user_id = Column(Text)
|
user_id = Column(Text)
|
||||||
@@ -112,7 +112,7 @@ class MessageUserResponse(MessageModel):
|
|||||||
class MessageUserSlimResponse(MessageUserResponse):
|
class MessageUserSlimResponse(MessageUserResponse):
|
||||||
data: bool | None = None
|
data: bool | None = None
|
||||||
|
|
||||||
@field_validator("data", mode="before")
|
@field_validator('data', mode='before')
|
||||||
def convert_data_to_bool(cls, v):
|
def convert_data_to_bool(cls, v):
|
||||||
# No data or not a dict → False
|
# No data or not a dict → False
|
||||||
if not isinstance(v, dict):
|
if not isinstance(v, dict):
|
||||||
@@ -152,19 +152,19 @@ class MessageTable:
|
|||||||
|
|
||||||
message = MessageModel(
|
message = MessageModel(
|
||||||
**{
|
**{
|
||||||
"id": id,
|
'id': id,
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"channel_id": channel_id,
|
'channel_id': channel_id,
|
||||||
"reply_to_id": form_data.reply_to_id,
|
'reply_to_id': form_data.reply_to_id,
|
||||||
"parent_id": form_data.parent_id,
|
'parent_id': form_data.parent_id,
|
||||||
"is_pinned": False,
|
'is_pinned': False,
|
||||||
"pinned_at": None,
|
'pinned_at': None,
|
||||||
"pinned_by": None,
|
'pinned_by': None,
|
||||||
"content": form_data.content,
|
'content': form_data.content,
|
||||||
"data": form_data.data,
|
'data': form_data.data,
|
||||||
"meta": form_data.meta,
|
'meta': form_data.meta,
|
||||||
"created_at": ts,
|
'created_at': ts,
|
||||||
"updated_at": ts,
|
'updated_at': ts,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
result = Message(**message.model_dump())
|
result = Message(**message.model_dump())
|
||||||
@@ -186,9 +186,7 @@ class MessageTable:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
reply_to_message = (
|
reply_to_message = (
|
||||||
self.get_message_by_id(
|
self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db)
|
||||||
message.reply_to_id, include_thread_replies=False, db=db
|
|
||||||
)
|
|
||||||
if message.reply_to_id
|
if message.reply_to_id
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
@@ -200,22 +198,22 @@ class MessageTable:
|
|||||||
thread_replies = self.get_thread_replies_by_message_id(id, db=db)
|
thread_replies = self.get_thread_replies_by_message_id(id, db=db)
|
||||||
|
|
||||||
# Check if message was sent by webhook (webhook info in meta takes precedence)
|
# Check if message was sent by webhook (webhook info in meta takes precedence)
|
||||||
webhook_info = message.meta.get("webhook") if message.meta else None
|
webhook_info = message.meta.get('webhook') if message.meta else None
|
||||||
if webhook_info and webhook_info.get("id"):
|
if webhook_info and webhook_info.get('id'):
|
||||||
# Look up webhook by ID to get current name
|
# Look up webhook by ID to get current name
|
||||||
webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db)
|
webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db)
|
||||||
if webhook:
|
if webhook:
|
||||||
user_info = {
|
user_info = {
|
||||||
"id": webhook.id,
|
'id': webhook.id,
|
||||||
"name": webhook.name,
|
'name': webhook.name,
|
||||||
"role": "webhook",
|
'role': 'webhook',
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
# Webhook was deleted, use placeholder
|
# Webhook was deleted, use placeholder
|
||||||
user_info = {
|
user_info = {
|
||||||
"id": webhook_info.get("id"),
|
'id': webhook_info.get('id'),
|
||||||
"name": "Deleted Webhook",
|
'name': 'Deleted Webhook',
|
||||||
"role": "webhook",
|
'role': 'webhook',
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
user = Users.get_user_by_id(message.user_id, db=db)
|
user = Users.get_user_by_id(message.user_id, db=db)
|
||||||
@@ -224,79 +222,57 @@ class MessageTable:
|
|||||||
return MessageResponse.model_validate(
|
return MessageResponse.model_validate(
|
||||||
{
|
{
|
||||||
**MessageModel.model_validate(message).model_dump(),
|
**MessageModel.model_validate(message).model_dump(),
|
||||||
"user": user_info,
|
'user': user_info,
|
||||||
"reply_to_message": (
|
'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None),
|
||||||
reply_to_message.model_dump() if reply_to_message else None
|
'latest_reply_at': (thread_replies[0].created_at if thread_replies else None),
|
||||||
),
|
'reply_count': len(thread_replies),
|
||||||
"latest_reply_at": (
|
'reactions': reactions,
|
||||||
thread_replies[0].created_at if thread_replies else None
|
|
||||||
),
|
|
||||||
"reply_count": len(thread_replies),
|
|
||||||
"reactions": reactions,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_thread_replies_by_message_id(
|
def get_thread_replies_by_message_id(self, id: str, db: Optional[Session] = None) -> list[MessageReplyToResponse]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> list[MessageReplyToResponse]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
all_messages = (
|
all_messages = db.query(Message).filter_by(parent_id=id).order_by(Message.created_at.desc()).all()
|
||||||
db.query(Message)
|
|
||||||
.filter_by(parent_id=id)
|
|
||||||
.order_by(Message.created_at.desc())
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = []
|
messages = []
|
||||||
for message in all_messages:
|
for message in all_messages:
|
||||||
reply_to_message = (
|
reply_to_message = (
|
||||||
self.get_message_by_id(
|
self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db)
|
||||||
message.reply_to_id, include_thread_replies=False, db=db
|
|
||||||
)
|
|
||||||
if message.reply_to_id
|
if message.reply_to_id
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
webhook_info = message.meta.get("webhook") if message.meta else None
|
webhook_info = message.meta.get('webhook') if message.meta else None
|
||||||
user_info = None
|
user_info = None
|
||||||
if webhook_info and webhook_info.get("id"):
|
if webhook_info and webhook_info.get('id'):
|
||||||
webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db)
|
webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db)
|
||||||
if webhook:
|
if webhook:
|
||||||
user_info = {
|
user_info = {
|
||||||
"id": webhook.id,
|
'id': webhook.id,
|
||||||
"name": webhook.name,
|
'name': webhook.name,
|
||||||
"role": "webhook",
|
'role': 'webhook',
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
user_info = {
|
user_info = {
|
||||||
"id": webhook_info.get("id"),
|
'id': webhook_info.get('id'),
|
||||||
"name": "Deleted Webhook",
|
'name': 'Deleted Webhook',
|
||||||
"role": "webhook",
|
'role': 'webhook',
|
||||||
}
|
}
|
||||||
|
|
||||||
messages.append(
|
messages.append(
|
||||||
MessageReplyToResponse.model_validate(
|
MessageReplyToResponse.model_validate(
|
||||||
{
|
{
|
||||||
**MessageModel.model_validate(message).model_dump(),
|
**MessageModel.model_validate(message).model_dump(),
|
||||||
"user": user_info,
|
'user': user_info,
|
||||||
"reply_to_message": (
|
'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None),
|
||||||
reply_to_message.model_dump()
|
|
||||||
if reply_to_message
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def get_reply_user_ids_by_message_id(
|
def get_reply_user_ids_by_message_id(self, id: str, db: Optional[Session] = None) -> list[str]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> list[str]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
return [
|
return [message.user_id for message in db.query(Message).filter_by(parent_id=id).all()]
|
||||||
message.user_id
|
|
||||||
for message in db.query(Message).filter_by(parent_id=id).all()
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_messages_by_channel_id(
|
def get_messages_by_channel_id(
|
||||||
self,
|
self,
|
||||||
@@ -318,40 +294,34 @@ class MessageTable:
|
|||||||
messages = []
|
messages = []
|
||||||
for message in all_messages:
|
for message in all_messages:
|
||||||
reply_to_message = (
|
reply_to_message = (
|
||||||
self.get_message_by_id(
|
self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db)
|
||||||
message.reply_to_id, include_thread_replies=False, db=db
|
|
||||||
)
|
|
||||||
if message.reply_to_id
|
if message.reply_to_id
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
webhook_info = message.meta.get("webhook") if message.meta else None
|
webhook_info = message.meta.get('webhook') if message.meta else None
|
||||||
user_info = None
|
user_info = None
|
||||||
if webhook_info and webhook_info.get("id"):
|
if webhook_info and webhook_info.get('id'):
|
||||||
webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db)
|
webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db)
|
||||||
if webhook:
|
if webhook:
|
||||||
user_info = {
|
user_info = {
|
||||||
"id": webhook.id,
|
'id': webhook.id,
|
||||||
"name": webhook.name,
|
'name': webhook.name,
|
||||||
"role": "webhook",
|
'role': 'webhook',
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
user_info = {
|
user_info = {
|
||||||
"id": webhook_info.get("id"),
|
'id': webhook_info.get('id'),
|
||||||
"name": "Deleted Webhook",
|
'name': 'Deleted Webhook',
|
||||||
"role": "webhook",
|
'role': 'webhook',
|
||||||
}
|
}
|
||||||
|
|
||||||
messages.append(
|
messages.append(
|
||||||
MessageReplyToResponse.model_validate(
|
MessageReplyToResponse.model_validate(
|
||||||
{
|
{
|
||||||
**MessageModel.model_validate(message).model_dump(),
|
**MessageModel.model_validate(message).model_dump(),
|
||||||
"user": user_info,
|
'user': user_info,
|
||||||
"reply_to_message": (
|
'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None),
|
||||||
reply_to_message.model_dump()
|
|
||||||
if reply_to_message
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -387,55 +357,42 @@ class MessageTable:
|
|||||||
messages = []
|
messages = []
|
||||||
for message in all_messages:
|
for message in all_messages:
|
||||||
reply_to_message = (
|
reply_to_message = (
|
||||||
self.get_message_by_id(
|
self.get_message_by_id(message.reply_to_id, include_thread_replies=False, db=db)
|
||||||
message.reply_to_id, include_thread_replies=False, db=db
|
|
||||||
)
|
|
||||||
if message.reply_to_id
|
if message.reply_to_id
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
webhook_info = message.meta.get("webhook") if message.meta else None
|
webhook_info = message.meta.get('webhook') if message.meta else None
|
||||||
user_info = None
|
user_info = None
|
||||||
if webhook_info and webhook_info.get("id"):
|
if webhook_info and webhook_info.get('id'):
|
||||||
webhook = Channels.get_webhook_by_id(webhook_info.get("id"), db=db)
|
webhook = Channels.get_webhook_by_id(webhook_info.get('id'), db=db)
|
||||||
if webhook:
|
if webhook:
|
||||||
user_info = {
|
user_info = {
|
||||||
"id": webhook.id,
|
'id': webhook.id,
|
||||||
"name": webhook.name,
|
'name': webhook.name,
|
||||||
"role": "webhook",
|
'role': 'webhook',
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
user_info = {
|
user_info = {
|
||||||
"id": webhook_info.get("id"),
|
'id': webhook_info.get('id'),
|
||||||
"name": "Deleted Webhook",
|
'name': 'Deleted Webhook',
|
||||||
"role": "webhook",
|
'role': 'webhook',
|
||||||
}
|
}
|
||||||
|
|
||||||
messages.append(
|
messages.append(
|
||||||
MessageReplyToResponse.model_validate(
|
MessageReplyToResponse.model_validate(
|
||||||
{
|
{
|
||||||
**MessageModel.model_validate(message).model_dump(),
|
**MessageModel.model_validate(message).model_dump(),
|
||||||
"user": user_info,
|
'user': user_info,
|
||||||
"reply_to_message": (
|
'reply_to_message': (reply_to_message.model_dump() if reply_to_message else None),
|
||||||
reply_to_message.model_dump()
|
|
||||||
if reply_to_message
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def get_last_message_by_channel_id(
|
def get_last_message_by_channel_id(self, channel_id: str, db: Optional[Session] = None) -> Optional[MessageModel]:
|
||||||
self, channel_id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[MessageModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
message = (
|
message = db.query(Message).filter_by(channel_id=channel_id).order_by(Message.created_at.desc()).first()
|
||||||
db.query(Message)
|
|
||||||
.filter_by(channel_id=channel_id)
|
|
||||||
.order_by(Message.created_at.desc())
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
return MessageModel.model_validate(message) if message else None
|
return MessageModel.model_validate(message) if message else None
|
||||||
|
|
||||||
def get_pinned_messages_by_channel_id(
|
def get_pinned_messages_by_channel_id(
|
||||||
@@ -513,11 +470,7 @@ class MessageTable:
|
|||||||
) -> Optional[MessageReactionModel]:
|
) -> Optional[MessageReactionModel]:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
# check for existing reaction
|
# check for existing reaction
|
||||||
existing_reaction = (
|
existing_reaction = db.query(MessageReaction).filter_by(message_id=id, user_id=user_id, name=name).first()
|
||||||
db.query(MessageReaction)
|
|
||||||
.filter_by(message_id=id, user_id=user_id, name=name)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if existing_reaction:
|
if existing_reaction:
|
||||||
return MessageReactionModel.model_validate(existing_reaction)
|
return MessageReactionModel.model_validate(existing_reaction)
|
||||||
|
|
||||||
@@ -535,9 +488,7 @@ class MessageTable:
|
|||||||
db.refresh(result)
|
db.refresh(result)
|
||||||
return MessageReactionModel.model_validate(result) if result else None
|
return MessageReactionModel.model_validate(result) if result else None
|
||||||
|
|
||||||
def get_reactions_by_message_id(
|
def get_reactions_by_message_id(self, id: str, db: Optional[Session] = None) -> list[Reactions]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> list[Reactions]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
# JOIN User so all user info is fetched in one query
|
# JOIN User so all user info is fetched in one query
|
||||||
results = (
|
results = (
|
||||||
@@ -552,18 +503,18 @@ class MessageTable:
|
|||||||
for reaction, user in results:
|
for reaction, user in results:
|
||||||
if reaction.name not in reactions:
|
if reaction.name not in reactions:
|
||||||
reactions[reaction.name] = {
|
reactions[reaction.name] = {
|
||||||
"name": reaction.name,
|
'name': reaction.name,
|
||||||
"users": [],
|
'users': [],
|
||||||
"count": 0,
|
'count': 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
reactions[reaction.name]["users"].append(
|
reactions[reaction.name]['users'].append(
|
||||||
{
|
{
|
||||||
"id": user.id,
|
'id': user.id,
|
||||||
"name": user.name,
|
'name': user.name,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
reactions[reaction.name]["count"] += 1
|
reactions[reaction.name]['count'] += 1
|
||||||
|
|
||||||
return [Reactions(**reaction) for reaction in reactions.values()]
|
return [Reactions(**reaction) for reaction in reactions.values()]
|
||||||
|
|
||||||
@@ -571,9 +522,7 @@ class MessageTable:
|
|||||||
self, id: str, user_id: str, name: str, db: Optional[Session] = None
|
self, id: str, user_id: str, name: str, db: Optional[Session] = None
|
||||||
) -> bool:
|
) -> bool:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
db.query(MessageReaction).filter_by(
|
db.query(MessageReaction).filter_by(message_id=id, user_id=user_id, name=name).delete()
|
||||||
message_id=id, user_id=user_id, name=name
|
|
||||||
).delete()
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -612,21 +561,15 @@ class MessageTable:
|
|||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
query_builder = db.query(Message).filter(
|
query_builder = db.query(Message).filter(
|
||||||
Message.channel_id.in_(channel_ids),
|
Message.channel_id.in_(channel_ids),
|
||||||
Message.content.ilike(f"%{query}%"),
|
Message.content.ilike(f'%{query}%'),
|
||||||
)
|
)
|
||||||
|
|
||||||
if start_timestamp:
|
if start_timestamp:
|
||||||
query_builder = query_builder.filter(
|
query_builder = query_builder.filter(Message.created_at >= start_timestamp)
|
||||||
Message.created_at >= start_timestamp
|
|
||||||
)
|
|
||||||
if end_timestamp:
|
if end_timestamp:
|
||||||
query_builder = query_builder.filter(
|
query_builder = query_builder.filter(Message.created_at <= end_timestamp)
|
||||||
Message.created_at <= end_timestamp
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = (
|
messages = query_builder.order_by(Message.created_at.desc()).limit(limit).all()
|
||||||
query_builder.order_by(Message.created_at.desc()).limit(limit).all()
|
|
||||||
)
|
|
||||||
return [MessageModel.model_validate(msg) for msg in messages]
|
return [MessageModel.model_validate(msg) for msg in messages]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -28,13 +28,13 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# ModelParams is a model for the data stored in the params field of the Model table
|
# ModelParams is a model for the data stored in the params field of the Model table
|
||||||
class ModelParams(BaseModel):
|
class ModelParams(BaseModel):
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra='allow')
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# ModelMeta is a model for the data stored in the meta field of the Model table
|
# ModelMeta is a model for the data stored in the meta field of the Model table
|
||||||
class ModelMeta(BaseModel):
|
class ModelMeta(BaseModel):
|
||||||
profile_image_url: Optional[str] = "/static/favicon.png"
|
profile_image_url: Optional[str] = '/static/favicon.png'
|
||||||
|
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
"""
|
"""
|
||||||
@@ -43,13 +43,13 @@ class ModelMeta(BaseModel):
|
|||||||
|
|
||||||
capabilities: Optional[dict] = None
|
capabilities: Optional[dict] = None
|
||||||
|
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra='allow')
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Model(Base):
|
class Model(Base):
|
||||||
__tablename__ = "model"
|
__tablename__ = 'model'
|
||||||
|
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
id = Column(Text, primary_key=True, unique=True)
|
||||||
"""
|
"""
|
||||||
@@ -139,10 +139,8 @@ class ModelForm(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ModelsTable:
|
class ModelsTable:
|
||||||
def _get_access_grants(
|
def _get_access_grants(self, model_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
|
||||||
self, model_id: str, db: Optional[Session] = None
|
return AccessGrants.get_grants_by_resource('model', model_id, db=db)
|
||||||
) -> list[AccessGrantModel]:
|
|
||||||
return AccessGrants.get_grants_by_resource("model", model_id, db=db)
|
|
||||||
|
|
||||||
def _to_model_model(
|
def _to_model_model(
|
||||||
self,
|
self,
|
||||||
@@ -150,13 +148,9 @@ class ModelsTable:
|
|||||||
access_grants: Optional[list[AccessGrantModel]] = None,
|
access_grants: Optional[list[AccessGrantModel]] = None,
|
||||||
db: Optional[Session] = None,
|
db: Optional[Session] = None,
|
||||||
) -> ModelModel:
|
) -> ModelModel:
|
||||||
model_data = ModelModel.model_validate(model).model_dump(
|
model_data = ModelModel.model_validate(model).model_dump(exclude={'access_grants'})
|
||||||
exclude={"access_grants"}
|
model_data['access_grants'] = (
|
||||||
)
|
access_grants if access_grants is not None else self._get_access_grants(model_data['id'], db=db)
|
||||||
model_data["access_grants"] = (
|
|
||||||
access_grants
|
|
||||||
if access_grants is not None
|
|
||||||
else self._get_access_grants(model_data["id"], db=db)
|
|
||||||
)
|
)
|
||||||
return ModelModel.model_validate(model_data)
|
return ModelModel.model_validate(model_data)
|
||||||
|
|
||||||
@@ -167,37 +161,32 @@ class ModelsTable:
|
|||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
result = Model(
|
result = Model(
|
||||||
**{
|
**{
|
||||||
**form_data.model_dump(exclude={"access_grants"}),
|
**form_data.model_dump(exclude={'access_grants'}),
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"created_at": int(time.time()),
|
'created_at': int(time.time()),
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
db.add(result)
|
db.add(result)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(result)
|
db.refresh(result)
|
||||||
AccessGrants.set_access_grants(
|
AccessGrants.set_access_grants('model', result.id, form_data.access_grants, db=db)
|
||||||
"model", result.id, form_data.access_grants, db=db
|
|
||||||
)
|
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
return self._to_model_model(result, db=db)
|
return self._to_model_model(result, db=db)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Failed to insert a new model: {e}")
|
log.exception(f'Failed to insert a new model: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_all_models(self, db: Optional[Session] = None) -> list[ModelModel]:
|
def get_all_models(self, db: Optional[Session] = None) -> list[ModelModel]:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
all_models = db.query(Model).all()
|
all_models = db.query(Model).all()
|
||||||
model_ids = [model.id for model in all_models]
|
model_ids = [model.id for model in all_models]
|
||||||
grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db)
|
grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
|
||||||
return [
|
return [
|
||||||
self._to_model_model(
|
self._to_model_model(model, access_grants=grants_map.get(model.id, []), db=db) for model in all_models
|
||||||
model, access_grants=grants_map.get(model.id, []), db=db
|
|
||||||
)
|
|
||||||
for model in all_models
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_models(self, db: Optional[Session] = None) -> list[ModelUserResponse]:
|
def get_models(self, db: Optional[Session] = None) -> list[ModelUserResponse]:
|
||||||
@@ -209,7 +198,7 @@ class ModelsTable:
|
|||||||
|
|
||||||
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
||||||
users_dict = {user.id: user for user in users}
|
users_dict = {user.id: user for user in users}
|
||||||
grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db)
|
grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
|
||||||
|
|
||||||
models = []
|
models = []
|
||||||
for model in all_models:
|
for model in all_models:
|
||||||
@@ -222,7 +211,7 @@ class ModelsTable:
|
|||||||
access_grants=grants_map.get(model.id, []),
|
access_grants=grants_map.get(model.id, []),
|
||||||
db=db,
|
db=db,
|
||||||
).model_dump(),
|
).model_dump(),
|
||||||
"user": user.model_dump() if user else None,
|
'user': user.model_dump() if user else None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -232,28 +221,23 @@ class ModelsTable:
|
|||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
all_models = db.query(Model).filter(Model.base_model_id == None).all()
|
all_models = db.query(Model).filter(Model.base_model_id == None).all()
|
||||||
model_ids = [model.id for model in all_models]
|
model_ids = [model.id for model in all_models]
|
||||||
grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db)
|
grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
|
||||||
return [
|
return [
|
||||||
self._to_model_model(
|
self._to_model_model(model, access_grants=grants_map.get(model.id, []), db=db) for model in all_models
|
||||||
model, access_grants=grants_map.get(model.id, []), db=db
|
|
||||||
)
|
|
||||||
for model in all_models
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_models_by_user_id(
|
def get_models_by_user_id(
|
||||||
self, user_id: str, permission: str = "write", db: Optional[Session] = None
|
self, user_id: str, permission: str = 'write', db: Optional[Session] = None
|
||||||
) -> list[ModelUserResponse]:
|
) -> list[ModelUserResponse]:
|
||||||
models = self.get_models(db=db)
|
models = self.get_models(db=db)
|
||||||
user_group_ids = {
|
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
|
||||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
|
||||||
}
|
|
||||||
return [
|
return [
|
||||||
model
|
model
|
||||||
for model in models
|
for model in models
|
||||||
if model.user_id == user_id
|
if model.user_id == user_id
|
||||||
or AccessGrants.has_access(
|
or AccessGrants.has_access(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
resource_type="model",
|
resource_type='model',
|
||||||
resource_id=model.id,
|
resource_id=model.id,
|
||||||
permission=permission,
|
permission=permission,
|
||||||
user_group_ids=user_group_ids,
|
user_group_ids=user_group_ids,
|
||||||
@@ -261,13 +245,13 @@ class ModelsTable:
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
|
def _has_permission(self, db, query, filter: dict, permission: str = 'read'):
|
||||||
return AccessGrants.has_permission_filter(
|
return AccessGrants.has_permission_filter(
|
||||||
db=db,
|
db=db,
|
||||||
query=query,
|
query=query,
|
||||||
DocumentModel=Model,
|
DocumentModel=Model,
|
||||||
filter=filter,
|
filter=filter,
|
||||||
resource_type="model",
|
resource_type='model',
|
||||||
permission=permission,
|
permission=permission,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -285,22 +269,22 @@ class ModelsTable:
|
|||||||
query = query.filter(Model.base_model_id != None)
|
query = query.filter(Model.base_model_id != None)
|
||||||
|
|
||||||
if filter:
|
if filter:
|
||||||
query_key = filter.get("query")
|
query_key = filter.get('query')
|
||||||
if query_key:
|
if query_key:
|
||||||
query = query.filter(
|
query = query.filter(
|
||||||
or_(
|
or_(
|
||||||
Model.name.ilike(f"%{query_key}%"),
|
Model.name.ilike(f'%{query_key}%'),
|
||||||
Model.base_model_id.ilike(f"%{query_key}%"),
|
Model.base_model_id.ilike(f'%{query_key}%'),
|
||||||
User.name.ilike(f"%{query_key}%"),
|
User.name.ilike(f'%{query_key}%'),
|
||||||
User.email.ilike(f"%{query_key}%"),
|
User.email.ilike(f'%{query_key}%'),
|
||||||
User.username.ilike(f"%{query_key}%"),
|
User.username.ilike(f'%{query_key}%'),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
view_option = filter.get("view_option")
|
view_option = filter.get('view_option')
|
||||||
if view_option == "created":
|
if view_option == 'created':
|
||||||
query = query.filter(Model.user_id == user_id)
|
query = query.filter(Model.user_id == user_id)
|
||||||
elif view_option == "shared":
|
elif view_option == 'shared':
|
||||||
query = query.filter(Model.user_id != user_id)
|
query = query.filter(Model.user_id != user_id)
|
||||||
|
|
||||||
# Apply access control filtering
|
# Apply access control filtering
|
||||||
@@ -308,10 +292,10 @@ class ModelsTable:
|
|||||||
db,
|
db,
|
||||||
query,
|
query,
|
||||||
filter,
|
filter,
|
||||||
permission="read",
|
permission='read',
|
||||||
)
|
)
|
||||||
|
|
||||||
tag = filter.get("tag")
|
tag = filter.get('tag')
|
||||||
if tag:
|
if tag:
|
||||||
# TODO: This is a simple implementation and should be improved for performance
|
# TODO: This is a simple implementation and should be improved for performance
|
||||||
like_pattern = f'%"{tag.lower()}"%' # `"tag"` inside JSON array
|
like_pattern = f'%"{tag.lower()}"%' # `"tag"` inside JSON array
|
||||||
@@ -319,21 +303,21 @@ class ModelsTable:
|
|||||||
|
|
||||||
query = query.filter(meta_text.like(like_pattern))
|
query = query.filter(meta_text.like(like_pattern))
|
||||||
|
|
||||||
order_by = filter.get("order_by")
|
order_by = filter.get('order_by')
|
||||||
direction = filter.get("direction")
|
direction = filter.get('direction')
|
||||||
|
|
||||||
if order_by == "name":
|
if order_by == 'name':
|
||||||
if direction == "asc":
|
if direction == 'asc':
|
||||||
query = query.order_by(Model.name.asc())
|
query = query.order_by(Model.name.asc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(Model.name.desc())
|
query = query.order_by(Model.name.desc())
|
||||||
elif order_by == "created_at":
|
elif order_by == 'created_at':
|
||||||
if direction == "asc":
|
if direction == 'asc':
|
||||||
query = query.order_by(Model.created_at.asc())
|
query = query.order_by(Model.created_at.asc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(Model.created_at.desc())
|
query = query.order_by(Model.created_at.desc())
|
||||||
elif order_by == "updated_at":
|
elif order_by == 'updated_at':
|
||||||
if direction == "asc":
|
if direction == 'asc':
|
||||||
query = query.order_by(Model.updated_at.asc())
|
query = query.order_by(Model.updated_at.asc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(Model.updated_at.desc())
|
query = query.order_by(Model.updated_at.desc())
|
||||||
@@ -352,7 +336,7 @@ class ModelsTable:
|
|||||||
items = query.all()
|
items = query.all()
|
||||||
|
|
||||||
model_ids = [model.id for model, _ in items]
|
model_ids = [model.id for model, _ in items]
|
||||||
grants_map = AccessGrants.get_grants_by_resources("model", model_ids, db=db)
|
grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
|
||||||
|
|
||||||
models = []
|
models = []
|
||||||
for model, user in items:
|
for model, user in items:
|
||||||
@@ -363,19 +347,13 @@ class ModelsTable:
|
|||||||
access_grants=grants_map.get(model.id, []),
|
access_grants=grants_map.get(model.id, []),
|
||||||
db=db,
|
db=db,
|
||||||
).model_dump(),
|
).model_dump(),
|
||||||
user=(
|
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
|
||||||
UserResponse(**UserModel.model_validate(user).model_dump())
|
|
||||||
if user
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return ModelListResponse(items=models, total=total)
|
return ModelListResponse(items=models, total=total)
|
||||||
|
|
||||||
def get_model_by_id(
|
def get_model_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ModelModel]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[ModelModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
model = db.get(Model, id)
|
model = db.get(Model, id)
|
||||||
@@ -383,16 +361,12 @@ class ModelsTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_models_by_ids(
|
def get_models_by_ids(self, ids: list[str], db: Optional[Session] = None) -> list[ModelModel]:
|
||||||
self, ids: list[str], db: Optional[Session] = None
|
|
||||||
) -> list[ModelModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
models = db.query(Model).filter(Model.id.in_(ids)).all()
|
models = db.query(Model).filter(Model.id.in_(ids)).all()
|
||||||
model_ids = [model.id for model in models]
|
model_ids = [model.id for model in models]
|
||||||
grants_map = AccessGrants.get_grants_by_resources(
|
grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
|
||||||
"model", model_ids, db=db
|
|
||||||
)
|
|
||||||
return [
|
return [
|
||||||
self._to_model_model(
|
self._to_model_model(
|
||||||
model,
|
model,
|
||||||
@@ -404,9 +378,7 @@ class ModelsTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def toggle_model_by_id(
|
def toggle_model_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ModelModel]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[ModelModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
model = db.query(Model).filter_by(id=id).first()
|
model = db.query(Model).filter_by(id=id).first()
|
||||||
@@ -422,30 +394,26 @@ class ModelsTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_model_by_id(
|
def update_model_by_id(self, id: str, model: ModelForm, db: Optional[Session] = None) -> Optional[ModelModel]:
|
||||||
self, id: str, model: ModelForm, db: Optional[Session] = None
|
|
||||||
) -> Optional[ModelModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
# update only the fields that are present in the model
|
# update only the fields that are present in the model
|
||||||
data = model.model_dump(exclude={"id", "access_grants"})
|
data = model.model_dump(exclude={'id', 'access_grants'})
|
||||||
result = db.query(Model).filter_by(id=id).update(data)
|
result = db.query(Model).filter_by(id=id).update(data)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
if model.access_grants is not None:
|
if model.access_grants is not None:
|
||||||
AccessGrants.set_access_grants(
|
AccessGrants.set_access_grants('model', id, model.access_grants, db=db)
|
||||||
"model", id, model.access_grants, db=db
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.get_model_by_id(id, db=db)
|
return self.get_model_by_id(id, db=db)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Failed to update the model by id {id}: {e}")
|
log.exception(f'Failed to update the model by id {id}: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def delete_model_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
def delete_model_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
AccessGrants.revoke_all_access("model", id, db=db)
|
AccessGrants.revoke_all_access('model', id, db=db)
|
||||||
db.query(Model).filter_by(id=id).delete()
|
db.query(Model).filter_by(id=id).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
@@ -458,7 +426,7 @@ class ModelsTable:
|
|||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
model_ids = [row[0] for row in db.query(Model.id).all()]
|
model_ids = [row[0] for row in db.query(Model.id).all()]
|
||||||
for model_id in model_ids:
|
for model_id in model_ids:
|
||||||
AccessGrants.revoke_all_access("model", model_id, db=db)
|
AccessGrants.revoke_all_access('model', model_id, db=db)
|
||||||
db.query(Model).delete()
|
db.query(Model).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
@@ -466,9 +434,7 @@ class ModelsTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def sync_models(
|
def sync_models(self, user_id: str, models: list[ModelModel], db: Optional[Session] = None) -> list[ModelModel]:
|
||||||
self, user_id: str, models: list[ModelModel], db: Optional[Session] = None
|
|
||||||
) -> list[ModelModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
# Get existing models
|
# Get existing models
|
||||||
@@ -483,37 +449,33 @@ class ModelsTable:
|
|||||||
if model.id in existing_ids:
|
if model.id in existing_ids:
|
||||||
db.query(Model).filter_by(id=model.id).update(
|
db.query(Model).filter_by(id=model.id).update(
|
||||||
{
|
{
|
||||||
**model.model_dump(exclude={"access_grants"}),
|
**model.model_dump(exclude={'access_grants'}),
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
new_model = Model(
|
new_model = Model(
|
||||||
**{
|
**{
|
||||||
**model.model_dump(exclude={"access_grants"}),
|
**model.model_dump(exclude={'access_grants'}),
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
db.add(new_model)
|
db.add(new_model)
|
||||||
AccessGrants.set_access_grants(
|
AccessGrants.set_access_grants('model', model.id, model.access_grants, db=db)
|
||||||
"model", model.id, model.access_grants, db=db
|
|
||||||
)
|
|
||||||
|
|
||||||
# Remove models that are no longer present
|
# Remove models that are no longer present
|
||||||
for model in existing_models:
|
for model in existing_models:
|
||||||
if model.id not in new_model_ids:
|
if model.id not in new_model_ids:
|
||||||
AccessGrants.revoke_all_access("model", model.id, db=db)
|
AccessGrants.revoke_all_access('model', model.id, db=db)
|
||||||
db.delete(model)
|
db.delete(model)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
all_models = db.query(Model).all()
|
all_models = db.query(Model).all()
|
||||||
model_ids = [model.id for model in all_models]
|
model_ids = [model.id for model in all_models]
|
||||||
grants_map = AccessGrants.get_grants_by_resources(
|
grants_map = AccessGrants.get_grants_by_resources('model', model_ids, db=db)
|
||||||
"model", model_ids, db=db
|
|
||||||
)
|
|
||||||
return [
|
return [
|
||||||
self._to_model_model(
|
self._to_model_model(
|
||||||
model,
|
model,
|
||||||
@@ -523,7 +485,7 @@ class ModelsTable:
|
|||||||
for model in all_models
|
for model in all_models
|
||||||
]
|
]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error syncing models for user {user_id}: {e}")
|
log.exception(f'Error syncing models for user {user_id}: {e}')
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from sqlalchemy import or_, func, cast
|
|||||||
|
|
||||||
|
|
||||||
class Note(Base):
|
class Note(Base):
|
||||||
__tablename__ = "note"
|
__tablename__ = 'note'
|
||||||
|
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
id = Column(Text, primary_key=True, unique=True)
|
||||||
user_id = Column(Text)
|
user_id = Column(Text)
|
||||||
@@ -88,10 +88,8 @@ class NoteListResponse(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class NoteTable:
|
class NoteTable:
|
||||||
def _get_access_grants(
|
def _get_access_grants(self, note_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
|
||||||
self, note_id: str, db: Optional[Session] = None
|
return AccessGrants.get_grants_by_resource('note', note_id, db=db)
|
||||||
) -> list[AccessGrantModel]:
|
|
||||||
return AccessGrants.get_grants_by_resource("note", note_id, db=db)
|
|
||||||
|
|
||||||
def _to_note_model(
|
def _to_note_model(
|
||||||
self,
|
self,
|
||||||
@@ -99,51 +97,43 @@ class NoteTable:
|
|||||||
access_grants: Optional[list[AccessGrantModel]] = None,
|
access_grants: Optional[list[AccessGrantModel]] = None,
|
||||||
db: Optional[Session] = None,
|
db: Optional[Session] = None,
|
||||||
) -> NoteModel:
|
) -> NoteModel:
|
||||||
note_data = NoteModel.model_validate(note).model_dump(exclude={"access_grants"})
|
note_data = NoteModel.model_validate(note).model_dump(exclude={'access_grants'})
|
||||||
note_data["access_grants"] = (
|
note_data['access_grants'] = (
|
||||||
access_grants
|
access_grants if access_grants is not None else self._get_access_grants(note_data['id'], db=db)
|
||||||
if access_grants is not None
|
|
||||||
else self._get_access_grants(note_data["id"], db=db)
|
|
||||||
)
|
)
|
||||||
return NoteModel.model_validate(note_data)
|
return NoteModel.model_validate(note_data)
|
||||||
|
|
||||||
def _has_permission(self, db, query, filter: dict, permission: str = "read"):
|
def _has_permission(self, db, query, filter: dict, permission: str = 'read'):
|
||||||
return AccessGrants.has_permission_filter(
|
return AccessGrants.has_permission_filter(
|
||||||
db=db,
|
db=db,
|
||||||
query=query,
|
query=query,
|
||||||
DocumentModel=Note,
|
DocumentModel=Note,
|
||||||
filter=filter,
|
filter=filter,
|
||||||
resource_type="note",
|
resource_type='note',
|
||||||
permission=permission,
|
permission=permission,
|
||||||
)
|
)
|
||||||
|
|
||||||
def insert_new_note(
|
def insert_new_note(self, user_id: str, form_data: NoteForm, db: Optional[Session] = None) -> Optional[NoteModel]:
|
||||||
self, user_id: str, form_data: NoteForm, db: Optional[Session] = None
|
|
||||||
) -> Optional[NoteModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
note = NoteModel(
|
note = NoteModel(
|
||||||
**{
|
**{
|
||||||
"id": str(uuid.uuid4()),
|
'id': str(uuid.uuid4()),
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
**form_data.model_dump(exclude={"access_grants"}),
|
**form_data.model_dump(exclude={'access_grants'}),
|
||||||
"created_at": int(time.time_ns()),
|
'created_at': int(time.time_ns()),
|
||||||
"updated_at": int(time.time_ns()),
|
'updated_at': int(time.time_ns()),
|
||||||
"access_grants": [],
|
'access_grants': [],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
new_note = Note(**note.model_dump(exclude={"access_grants"}))
|
new_note = Note(**note.model_dump(exclude={'access_grants'}))
|
||||||
|
|
||||||
db.add(new_note)
|
db.add(new_note)
|
||||||
db.commit()
|
db.commit()
|
||||||
AccessGrants.set_access_grants(
|
AccessGrants.set_access_grants('note', note.id, form_data.access_grants, db=db)
|
||||||
"note", note.id, form_data.access_grants, db=db
|
|
||||||
)
|
|
||||||
return self._to_note_model(new_note, db=db)
|
return self._to_note_model(new_note, db=db)
|
||||||
|
|
||||||
def get_notes(
|
def get_notes(self, skip: int = 0, limit: int = 50, db: Optional[Session] = None) -> list[NoteModel]:
|
||||||
self, skip: int = 0, limit: int = 50, db: Optional[Session] = None
|
|
||||||
) -> list[NoteModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
query = db.query(Note).order_by(Note.updated_at.desc())
|
query = db.query(Note).order_by(Note.updated_at.desc())
|
||||||
if skip is not None:
|
if skip is not None:
|
||||||
@@ -152,13 +142,8 @@ class NoteTable:
|
|||||||
query = query.limit(limit)
|
query = query.limit(limit)
|
||||||
notes = query.all()
|
notes = query.all()
|
||||||
note_ids = [note.id for note in notes]
|
note_ids = [note.id for note in notes]
|
||||||
grants_map = AccessGrants.get_grants_by_resources("note", note_ids, db=db)
|
grants_map = AccessGrants.get_grants_by_resources('note', note_ids, db=db)
|
||||||
return [
|
return [self._to_note_model(note, access_grants=grants_map.get(note.id, []), db=db) for note in notes]
|
||||||
self._to_note_model(
|
|
||||||
note, access_grants=grants_map.get(note.id, []), db=db
|
|
||||||
)
|
|
||||||
for note in notes
|
|
||||||
]
|
|
||||||
|
|
||||||
def search_notes(
|
def search_notes(
|
||||||
self,
|
self,
|
||||||
@@ -171,36 +156,32 @@ class NoteTable:
|
|||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
query = db.query(Note, User).outerjoin(User, User.id == Note.user_id)
|
query = db.query(Note, User).outerjoin(User, User.id == Note.user_id)
|
||||||
if filter:
|
if filter:
|
||||||
query_key = filter.get("query")
|
query_key = filter.get('query')
|
||||||
if query_key:
|
if query_key:
|
||||||
# Normalize search by removing hyphens and spaces (e.g., "todo" matches "to-do" and "to do")
|
# Normalize search by removing hyphens and spaces (e.g., "todo" matches "to-do" and "to do")
|
||||||
normalized_query = query_key.replace("-", "").replace(" ", "")
|
normalized_query = query_key.replace('-', '').replace(' ', '')
|
||||||
query = query.filter(
|
query = query.filter(
|
||||||
or_(
|
or_(
|
||||||
|
func.replace(func.replace(Note.title, '-', ''), ' ', '').ilike(f'%{normalized_query}%'),
|
||||||
func.replace(
|
func.replace(
|
||||||
func.replace(Note.title, "-", ""), " ", ""
|
func.replace(cast(Note.data['content']['md'], Text), '-', ''),
|
||||||
).ilike(f"%{normalized_query}%"),
|
' ',
|
||||||
func.replace(
|
'',
|
||||||
func.replace(
|
).ilike(f'%{normalized_query}%'),
|
||||||
cast(Note.data["content"]["md"], Text), "-", ""
|
|
||||||
),
|
|
||||||
" ",
|
|
||||||
"",
|
|
||||||
).ilike(f"%{normalized_query}%"),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
view_option = filter.get("view_option")
|
view_option = filter.get('view_option')
|
||||||
if view_option == "created":
|
if view_option == 'created':
|
||||||
query = query.filter(Note.user_id == user_id)
|
query = query.filter(Note.user_id == user_id)
|
||||||
elif view_option == "shared":
|
elif view_option == 'shared':
|
||||||
query = query.filter(Note.user_id != user_id)
|
query = query.filter(Note.user_id != user_id)
|
||||||
|
|
||||||
# Apply access control filtering
|
# Apply access control filtering
|
||||||
if "permission" in filter:
|
if 'permission' in filter:
|
||||||
permission = filter["permission"]
|
permission = filter['permission']
|
||||||
else:
|
else:
|
||||||
permission = "write"
|
permission = 'write'
|
||||||
|
|
||||||
query = self._has_permission(
|
query = self._has_permission(
|
||||||
db,
|
db,
|
||||||
@@ -209,21 +190,21 @@ class NoteTable:
|
|||||||
permission=permission,
|
permission=permission,
|
||||||
)
|
)
|
||||||
|
|
||||||
order_by = filter.get("order_by")
|
order_by = filter.get('order_by')
|
||||||
direction = filter.get("direction")
|
direction = filter.get('direction')
|
||||||
|
|
||||||
if order_by == "name":
|
if order_by == 'name':
|
||||||
if direction == "asc":
|
if direction == 'asc':
|
||||||
query = query.order_by(Note.title.asc())
|
query = query.order_by(Note.title.asc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(Note.title.desc())
|
query = query.order_by(Note.title.desc())
|
||||||
elif order_by == "created_at":
|
elif order_by == 'created_at':
|
||||||
if direction == "asc":
|
if direction == 'asc':
|
||||||
query = query.order_by(Note.created_at.asc())
|
query = query.order_by(Note.created_at.asc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(Note.created_at.desc())
|
query = query.order_by(Note.created_at.desc())
|
||||||
elif order_by == "updated_at":
|
elif order_by == 'updated_at':
|
||||||
if direction == "asc":
|
if direction == 'asc':
|
||||||
query = query.order_by(Note.updated_at.asc())
|
query = query.order_by(Note.updated_at.asc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(Note.updated_at.desc())
|
query = query.order_by(Note.updated_at.desc())
|
||||||
@@ -244,7 +225,7 @@ class NoteTable:
|
|||||||
items = query.all()
|
items = query.all()
|
||||||
|
|
||||||
note_ids = [note.id for note, _ in items]
|
note_ids = [note.id for note, _ in items]
|
||||||
grants_map = AccessGrants.get_grants_by_resources("note", note_ids, db=db)
|
grants_map = AccessGrants.get_grants_by_resources('note', note_ids, db=db)
|
||||||
|
|
||||||
notes = []
|
notes = []
|
||||||
for note, user in items:
|
for note, user in items:
|
||||||
@@ -255,11 +236,7 @@ class NoteTable:
|
|||||||
access_grants=grants_map.get(note.id, []),
|
access_grants=grants_map.get(note.id, []),
|
||||||
db=db,
|
db=db,
|
||||||
).model_dump(),
|
).model_dump(),
|
||||||
user=(
|
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
|
||||||
UserResponse(**UserModel.model_validate(user).model_dump())
|
|
||||||
if user
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -268,20 +245,16 @@ class NoteTable:
|
|||||||
def get_notes_by_user_id(
|
def get_notes_by_user_id(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
permission: str = "read",
|
permission: str = 'read',
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
db: Optional[Session] = None,
|
db: Optional[Session] = None,
|
||||||
) -> list[NoteModel]:
|
) -> list[NoteModel]:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
user_group_ids = [
|
user_group_ids = [group.id for group in Groups.get_groups_by_member_id(user_id, db=db)]
|
||||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
|
||||||
]
|
|
||||||
|
|
||||||
query = db.query(Note).order_by(Note.updated_at.desc())
|
query = db.query(Note).order_by(Note.updated_at.desc())
|
||||||
query = self._has_permission(
|
query = self._has_permission(db, query, {'user_id': user_id, 'group_ids': user_group_ids}, permission)
|
||||||
db, query, {"user_id": user_id, "group_ids": user_group_ids}, permission
|
|
||||||
)
|
|
||||||
|
|
||||||
if skip is not None:
|
if skip is not None:
|
||||||
query = query.offset(skip)
|
query = query.offset(skip)
|
||||||
@@ -290,17 +263,10 @@ class NoteTable:
|
|||||||
|
|
||||||
notes = query.all()
|
notes = query.all()
|
||||||
note_ids = [note.id for note in notes]
|
note_ids = [note.id for note in notes]
|
||||||
grants_map = AccessGrants.get_grants_by_resources("note", note_ids, db=db)
|
grants_map = AccessGrants.get_grants_by_resources('note', note_ids, db=db)
|
||||||
return [
|
return [self._to_note_model(note, access_grants=grants_map.get(note.id, []), db=db) for note in notes]
|
||||||
self._to_note_model(
|
|
||||||
note, access_grants=grants_map.get(note.id, []), db=db
|
|
||||||
)
|
|
||||||
for note in notes
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_note_by_id(
|
def get_note_by_id(self, id: str, db: Optional[Session] = None) -> Optional[NoteModel]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[NoteModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
note = db.query(Note).filter(Note.id == id).first()
|
note = db.query(Note).filter(Note.id == id).first()
|
||||||
return self._to_note_model(note, db=db) if note else None
|
return self._to_note_model(note, db=db) if note else None
|
||||||
@@ -315,17 +281,15 @@ class NoteTable:
|
|||||||
|
|
||||||
form_data = form_data.model_dump(exclude_unset=True)
|
form_data = form_data.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
if "title" in form_data:
|
if 'title' in form_data:
|
||||||
note.title = form_data["title"]
|
note.title = form_data['title']
|
||||||
if "data" in form_data:
|
if 'data' in form_data:
|
||||||
note.data = {**note.data, **form_data["data"]}
|
note.data = {**note.data, **form_data['data']}
|
||||||
if "meta" in form_data:
|
if 'meta' in form_data:
|
||||||
note.meta = {**note.meta, **form_data["meta"]}
|
note.meta = {**note.meta, **form_data['meta']}
|
||||||
|
|
||||||
if "access_grants" in form_data:
|
if 'access_grants' in form_data:
|
||||||
AccessGrants.set_access_grants(
|
AccessGrants.set_access_grants('note', id, form_data['access_grants'], db=db)
|
||||||
"note", id, form_data["access_grants"], db=db
|
|
||||||
)
|
|
||||||
|
|
||||||
note.updated_at = int(time.time_ns())
|
note.updated_at = int(time.time_ns())
|
||||||
|
|
||||||
@@ -335,7 +299,7 @@ class NoteTable:
|
|||||||
def delete_note_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
def delete_note_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
AccessGrants.revoke_all_access("note", id, db=db)
|
AccessGrants.revoke_all_access('note', id, db=db)
|
||||||
db.query(Note).filter(Note.id == id).delete()
|
db.query(Note).filter(Note.id == id).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -23,23 +23,21 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class OAuthSession(Base):
|
class OAuthSession(Base):
|
||||||
__tablename__ = "oauth_session"
|
__tablename__ = 'oauth_session'
|
||||||
|
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
id = Column(Text, primary_key=True, unique=True)
|
||||||
user_id = Column(Text, nullable=False)
|
user_id = Column(Text, nullable=False)
|
||||||
provider = Column(Text, nullable=False)
|
provider = Column(Text, nullable=False)
|
||||||
token = Column(
|
token = Column(Text, nullable=False) # JSON with access_token, id_token, refresh_token
|
||||||
Text, nullable=False
|
|
||||||
) # JSON with access_token, id_token, refresh_token
|
|
||||||
expires_at = Column(BigInteger, nullable=False)
|
expires_at = Column(BigInteger, nullable=False)
|
||||||
created_at = Column(BigInteger, nullable=False)
|
created_at = Column(BigInteger, nullable=False)
|
||||||
updated_at = Column(BigInteger, nullable=False)
|
updated_at = Column(BigInteger, nullable=False)
|
||||||
|
|
||||||
# Add indexes for better performance
|
# Add indexes for better performance
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("idx_oauth_session_user_id", "user_id"),
|
Index('idx_oauth_session_user_id', 'user_id'),
|
||||||
Index("idx_oauth_session_expires_at", "expires_at"),
|
Index('idx_oauth_session_expires_at', 'expires_at'),
|
||||||
Index("idx_oauth_session_user_provider", "user_id", "provider"),
|
Index('idx_oauth_session_user_provider', 'user_id', 'provider'),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -71,7 +69,7 @@ class OAuthSessionTable:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
|
self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
|
||||||
if not self.encryption_key:
|
if not self.encryption_key:
|
||||||
raise Exception("OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set")
|
raise Exception('OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set')
|
||||||
|
|
||||||
# check if encryption key is in the right format for Fernet (32 url-safe base64-encoded bytes)
|
# check if encryption key is in the right format for Fernet (32 url-safe base64-encoded bytes)
|
||||||
if len(self.encryption_key) != 44:
|
if len(self.encryption_key) != 44:
|
||||||
@@ -83,7 +81,7 @@ class OAuthSessionTable:
|
|||||||
try:
|
try:
|
||||||
self.fernet = Fernet(self.encryption_key)
|
self.fernet = Fernet(self.encryption_key)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error initializing Fernet with provided key: {e}")
|
log.error(f'Error initializing Fernet with provided key: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _encrypt_token(self, token) -> str:
|
def _encrypt_token(self, token) -> str:
|
||||||
@@ -93,7 +91,7 @@ class OAuthSessionTable:
|
|||||||
encrypted = self.fernet.encrypt(token_json.encode()).decode()
|
encrypted = self.fernet.encrypt(token_json.encode()).decode()
|
||||||
return encrypted
|
return encrypted
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error encrypting tokens: {e}")
|
log.error(f'Error encrypting tokens: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _decrypt_token(self, token: str):
|
def _decrypt_token(self, token: str):
|
||||||
@@ -102,7 +100,7 @@ class OAuthSessionTable:
|
|||||||
decrypted = self.fernet.decrypt(token.encode()).decode()
|
decrypted = self.fernet.decrypt(token.encode()).decode()
|
||||||
return json.loads(decrypted)
|
return json.loads(decrypted)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error decrypting tokens: {type(e).__name__}: {e}")
|
log.error(f'Error decrypting tokens: {type(e).__name__}: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def create_session(
|
def create_session(
|
||||||
@@ -120,13 +118,13 @@ class OAuthSessionTable:
|
|||||||
|
|
||||||
result = OAuthSession(
|
result = OAuthSession(
|
||||||
**{
|
**{
|
||||||
"id": id,
|
'id': id,
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"provider": provider,
|
'provider': provider,
|
||||||
"token": self._encrypt_token(token),
|
'token': self._encrypt_token(token),
|
||||||
"expires_at": token.get("expires_at"),
|
'expires_at': token.get('expires_at'),
|
||||||
"created_at": current_time,
|
'created_at': current_time,
|
||||||
"updated_at": current_time,
|
'updated_at': current_time,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -141,12 +139,10 @@ class OAuthSessionTable:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error creating OAuth session: {e}")
|
log.error(f'Error creating OAuth session: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_session_by_id(
|
def get_session_by_id(self, session_id: str, db: Optional[Session] = None) -> Optional[OAuthSessionModel]:
|
||||||
self, session_id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[OAuthSessionModel]:
|
|
||||||
"""Get OAuth session by ID"""
|
"""Get OAuth session by ID"""
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
@@ -158,7 +154,7 @@ class OAuthSessionTable:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error getting OAuth session by ID: {e}")
|
log.error(f'Error getting OAuth session by ID: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_session_by_id_and_user_id(
|
def get_session_by_id_and_user_id(
|
||||||
@@ -167,11 +163,7 @@ class OAuthSessionTable:
|
|||||||
"""Get OAuth session by ID and user ID"""
|
"""Get OAuth session by ID and user ID"""
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
session = (
|
session = db.query(OAuthSession).filter_by(id=session_id, user_id=user_id).first()
|
||||||
db.query(OAuthSession)
|
|
||||||
.filter_by(id=session_id, user_id=user_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if session:
|
if session:
|
||||||
db.expunge(session)
|
db.expunge(session)
|
||||||
session.token = self._decrypt_token(session.token)
|
session.token = self._decrypt_token(session.token)
|
||||||
@@ -179,7 +171,7 @@ class OAuthSessionTable:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error getting OAuth session by ID: {e}")
|
log.error(f'Error getting OAuth session by ID: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_session_by_provider_and_user_id(
|
def get_session_by_provider_and_user_id(
|
||||||
@@ -201,12 +193,10 @@ class OAuthSessionTable:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error getting OAuth session by provider and user ID: {e}")
|
log.error(f'Error getting OAuth session by provider and user ID: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_sessions_by_user_id(
|
def get_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> List[OAuthSessionModel]:
|
||||||
self, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> List[OAuthSessionModel]:
|
|
||||||
"""Get all OAuth sessions for a user"""
|
"""Get all OAuth sessions for a user"""
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
@@ -220,7 +210,7 @@ class OAuthSessionTable:
|
|||||||
results.append(OAuthSessionModel.model_validate(session))
|
results.append(OAuthSessionModel.model_validate(session))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.warning(
|
log.warning(
|
||||||
f"Skipping OAuth session {session.id} due to decryption failure, deleting corrupted session: {type(e).__name__}: {e}"
|
f'Skipping OAuth session {session.id} due to decryption failure, deleting corrupted session: {type(e).__name__}: {e}'
|
||||||
)
|
)
|
||||||
db.query(OAuthSession).filter_by(id=session.id).delete()
|
db.query(OAuthSession).filter_by(id=session.id).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -228,7 +218,7 @@ class OAuthSessionTable:
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error getting OAuth sessions by user ID: {e}")
|
log.error(f'Error getting OAuth sessions by user ID: {e}')
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def update_session_by_id(
|
def update_session_by_id(
|
||||||
@@ -241,9 +231,9 @@ class OAuthSessionTable:
|
|||||||
|
|
||||||
db.query(OAuthSession).filter_by(id=session_id).update(
|
db.query(OAuthSession).filter_by(id=session_id).update(
|
||||||
{
|
{
|
||||||
"token": self._encrypt_token(token),
|
'token': self._encrypt_token(token),
|
||||||
"expires_at": token.get("expires_at"),
|
'expires_at': token.get('expires_at'),
|
||||||
"updated_at": current_time,
|
'updated_at': current_time,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -256,12 +246,10 @@ class OAuthSessionTable:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error updating OAuth session tokens: {e}")
|
log.error(f'Error updating OAuth session tokens: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def delete_session_by_id(
|
def delete_session_by_id(self, session_id: str, db: Optional[Session] = None) -> bool:
|
||||||
self, session_id: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
"""Delete an OAuth session"""
|
"""Delete an OAuth session"""
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
@@ -269,12 +257,10 @@ class OAuthSessionTable:
|
|||||||
db.commit()
|
db.commit()
|
||||||
return result > 0
|
return result > 0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error deleting OAuth session: {e}")
|
log.error(f'Error deleting OAuth session: {e}')
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def delete_sessions_by_user_id(
|
def delete_sessions_by_user_id(self, user_id: str, db: Optional[Session] = None) -> bool:
|
||||||
self, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
"""Delete all OAuth sessions for a user"""
|
"""Delete all OAuth sessions for a user"""
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
@@ -282,12 +268,10 @@ class OAuthSessionTable:
|
|||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error deleting OAuth sessions by user ID: {e}")
|
log.error(f'Error deleting OAuth sessions by user ID: {e}')
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def delete_sessions_by_provider(
|
def delete_sessions_by_provider(self, provider: str, db: Optional[Session] = None) -> bool:
|
||||||
self, provider: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
"""Delete all OAuth sessions for a provider"""
|
"""Delete all OAuth sessions for a provider"""
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
@@ -295,7 +279,7 @@ class OAuthSessionTable:
|
|||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error deleting OAuth sessions by provider {provider}: {e}")
|
log.error(f'Error deleting OAuth sessions by provider {provider}: {e}')
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from sqlalchemy import BigInteger, Column, Text, JSON, Index
|
|||||||
|
|
||||||
|
|
||||||
class PromptHistory(Base):
|
class PromptHistory(Base):
|
||||||
__tablename__ = "prompt_history"
|
__tablename__ = 'prompt_history'
|
||||||
|
|
||||||
id = Column(Text, primary_key=True)
|
id = Column(Text, primary_key=True)
|
||||||
prompt_id = Column(Text, nullable=False, index=True)
|
prompt_id = Column(Text, nullable=False, index=True)
|
||||||
@@ -100,11 +100,7 @@ class PromptHistoryTable:
|
|||||||
return [
|
return [
|
||||||
PromptHistoryResponse(
|
PromptHistoryResponse(
|
||||||
**PromptHistoryModel.model_validate(entry).model_dump(),
|
**PromptHistoryModel.model_validate(entry).model_dump(),
|
||||||
user=(
|
user=(users_dict.get(entry.user_id).model_dump() if users_dict.get(entry.user_id) else None),
|
||||||
users_dict.get(entry.user_id).model_dump()
|
|
||||||
if users_dict.get(entry.user_id)
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
for entry in entries
|
for entry in entries
|
||||||
]
|
]
|
||||||
@@ -116,9 +112,7 @@ class PromptHistoryTable:
|
|||||||
) -> Optional[PromptHistoryModel]:
|
) -> Optional[PromptHistoryModel]:
|
||||||
"""Get a specific history entry by ID."""
|
"""Get a specific history entry by ID."""
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
entry = (
|
entry = db.query(PromptHistory).filter(PromptHistory.id == history_id).first()
|
||||||
db.query(PromptHistory).filter(PromptHistory.id == history_id).first()
|
|
||||||
)
|
|
||||||
if entry:
|
if entry:
|
||||||
return PromptHistoryModel.model_validate(entry)
|
return PromptHistoryModel.model_validate(entry)
|
||||||
return None
|
return None
|
||||||
@@ -147,11 +141,7 @@ class PromptHistoryTable:
|
|||||||
) -> int:
|
) -> int:
|
||||||
"""Get the number of history entries for a prompt."""
|
"""Get the number of history entries for a prompt."""
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
return (
|
return db.query(PromptHistory).filter(PromptHistory.prompt_id == prompt_id).count()
|
||||||
db.query(PromptHistory)
|
|
||||||
.filter(PromptHistory.prompt_id == prompt_id)
|
|
||||||
.count()
|
|
||||||
)
|
|
||||||
|
|
||||||
def compute_diff(
|
def compute_diff(
|
||||||
self,
|
self,
|
||||||
@@ -161,9 +151,7 @@ class PromptHistoryTable:
|
|||||||
) -> Optional[dict]:
|
) -> Optional[dict]:
|
||||||
"""Compute diff between two history entries."""
|
"""Compute diff between two history entries."""
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
from_entry = (
|
from_entry = db.query(PromptHistory).filter(PromptHistory.id == from_id).first()
|
||||||
db.query(PromptHistory).filter(PromptHistory.id == from_id).first()
|
|
||||||
)
|
|
||||||
to_entry = db.query(PromptHistory).filter(PromptHistory.id == to_id).first()
|
to_entry = db.query(PromptHistory).filter(PromptHistory.id == to_id).first()
|
||||||
|
|
||||||
if not from_entry or not to_entry:
|
if not from_entry or not to_entry:
|
||||||
@@ -173,26 +161,26 @@ class PromptHistoryTable:
|
|||||||
to_snapshot = to_entry.snapshot
|
to_snapshot = to_entry.snapshot
|
||||||
|
|
||||||
# Compute diff for content field
|
# Compute diff for content field
|
||||||
from_content = from_snapshot.get("content", "")
|
from_content = from_snapshot.get('content', '')
|
||||||
to_content = to_snapshot.get("content", "")
|
to_content = to_snapshot.get('content', '')
|
||||||
|
|
||||||
diff_lines = list(
|
diff_lines = list(
|
||||||
difflib.unified_diff(
|
difflib.unified_diff(
|
||||||
from_content.splitlines(keepends=True),
|
from_content.splitlines(keepends=True),
|
||||||
to_content.splitlines(keepends=True),
|
to_content.splitlines(keepends=True),
|
||||||
fromfile=f"v{from_id[:8]}",
|
fromfile=f'v{from_id[:8]}',
|
||||||
tofile=f"v{to_id[:8]}",
|
tofile=f'v{to_id[:8]}',
|
||||||
lineterm="",
|
lineterm='',
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"from_id": from_id,
|
'from_id': from_id,
|
||||||
"to_id": to_id,
|
'to_id': to_id,
|
||||||
"from_snapshot": from_snapshot,
|
'from_snapshot': from_snapshot,
|
||||||
"to_snapshot": to_snapshot,
|
'to_snapshot': to_snapshot,
|
||||||
"content_diff": diff_lines,
|
'content_diff': diff_lines,
|
||||||
"name_changed": from_snapshot.get("name") != to_snapshot.get("name"),
|
'name_changed': from_snapshot.get('name') != to_snapshot.get('name'),
|
||||||
}
|
}
|
||||||
|
|
||||||
def delete_history_by_prompt_id(
|
def delete_history_by_prompt_id(
|
||||||
@@ -202,9 +190,7 @@ class PromptHistoryTable:
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
"""Delete all history entries for a prompt."""
|
"""Delete all history entries for a prompt."""
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
db.query(PromptHistory).filter(
|
db.query(PromptHistory).filter(PromptHistory.prompt_id == prompt_id).delete()
|
||||||
PromptHistory.prompt_id == prompt_id
|
|
||||||
).delete()
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, or_, fun
|
|||||||
|
|
||||||
|
|
||||||
class Prompt(Base):
|
class Prompt(Base):
|
||||||
__tablename__ = "prompt"
|
__tablename__ = 'prompt'
|
||||||
|
|
||||||
id = Column(Text, primary_key=True)
|
id = Column(Text, primary_key=True)
|
||||||
command = Column(String, unique=True, index=True)
|
command = Column(String, unique=True, index=True)
|
||||||
@@ -77,7 +77,6 @@ class PromptAccessListResponse(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class PromptForm(BaseModel):
|
class PromptForm(BaseModel):
|
||||||
|
|
||||||
command: str
|
command: str
|
||||||
name: str # Changed from title
|
name: str # Changed from title
|
||||||
content: str
|
content: str
|
||||||
@@ -91,10 +90,8 @@ class PromptForm(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class PromptsTable:
|
class PromptsTable:
|
||||||
def _get_access_grants(
|
def _get_access_grants(self, prompt_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
|
||||||
self, prompt_id: str, db: Optional[Session] = None
|
return AccessGrants.get_grants_by_resource('prompt', prompt_id, db=db)
|
||||||
) -> list[AccessGrantModel]:
|
|
||||||
return AccessGrants.get_grants_by_resource("prompt", prompt_id, db=db)
|
|
||||||
|
|
||||||
def _to_prompt_model(
|
def _to_prompt_model(
|
||||||
self,
|
self,
|
||||||
@@ -102,13 +99,9 @@ class PromptsTable:
|
|||||||
access_grants: Optional[list[AccessGrantModel]] = None,
|
access_grants: Optional[list[AccessGrantModel]] = None,
|
||||||
db: Optional[Session] = None,
|
db: Optional[Session] = None,
|
||||||
) -> PromptModel:
|
) -> PromptModel:
|
||||||
prompt_data = PromptModel.model_validate(prompt).model_dump(
|
prompt_data = PromptModel.model_validate(prompt).model_dump(exclude={'access_grants'})
|
||||||
exclude={"access_grants"}
|
prompt_data['access_grants'] = (
|
||||||
)
|
access_grants if access_grants is not None else self._get_access_grants(prompt_data['id'], db=db)
|
||||||
prompt_data["access_grants"] = (
|
|
||||||
access_grants
|
|
||||||
if access_grants is not None
|
|
||||||
else self._get_access_grants(prompt_data["id"], db=db)
|
|
||||||
)
|
)
|
||||||
return PromptModel.model_validate(prompt_data)
|
return PromptModel.model_validate(prompt_data)
|
||||||
|
|
||||||
@@ -135,26 +128,22 @@ class PromptsTable:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
result = Prompt(**prompt.model_dump(exclude={"access_grants"}))
|
result = Prompt(**prompt.model_dump(exclude={'access_grants'}))
|
||||||
db.add(result)
|
db.add(result)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(result)
|
db.refresh(result)
|
||||||
AccessGrants.set_access_grants(
|
AccessGrants.set_access_grants('prompt', prompt_id, form_data.access_grants, db=db)
|
||||||
"prompt", prompt_id, form_data.access_grants, db=db
|
|
||||||
)
|
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
current_access_grants = self._get_access_grants(prompt_id, db=db)
|
current_access_grants = self._get_access_grants(prompt_id, db=db)
|
||||||
snapshot = {
|
snapshot = {
|
||||||
"name": form_data.name,
|
'name': form_data.name,
|
||||||
"content": form_data.content,
|
'content': form_data.content,
|
||||||
"command": form_data.command,
|
'command': form_data.command,
|
||||||
"data": form_data.data or {},
|
'data': form_data.data or {},
|
||||||
"meta": form_data.meta or {},
|
'meta': form_data.meta or {},
|
||||||
"tags": form_data.tags or [],
|
'tags': form_data.tags or [],
|
||||||
"access_grants": [
|
'access_grants': [grant.model_dump() for grant in current_access_grants],
|
||||||
grant.model_dump() for grant in current_access_grants
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
history_entry = PromptHistories.create_history_entry(
|
history_entry = PromptHistories.create_history_entry(
|
||||||
@@ -162,7 +151,7 @@ class PromptsTable:
|
|||||||
snapshot=snapshot,
|
snapshot=snapshot,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
parent_id=None, # Initial commit has no parent
|
parent_id=None, # Initial commit has no parent
|
||||||
commit_message=form_data.commit_message or "Initial version",
|
commit_message=form_data.commit_message or 'Initial version',
|
||||||
db=db,
|
db=db,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -178,9 +167,7 @@ class PromptsTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_prompt_by_id(
|
def get_prompt_by_id(self, prompt_id: str, db: Optional[Session] = None) -> Optional[PromptModel]:
|
||||||
self, prompt_id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[PromptModel]:
|
|
||||||
"""Get prompt by UUID."""
|
"""Get prompt by UUID."""
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
@@ -191,9 +178,7 @@ class PromptsTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_prompt_by_command(
|
def get_prompt_by_command(self, command: str, db: Optional[Session] = None) -> Optional[PromptModel]:
|
||||||
self, command: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[PromptModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
prompt = db.query(Prompt).filter_by(command=command).first()
|
prompt = db.query(Prompt).filter_by(command=command).first()
|
||||||
@@ -205,21 +190,14 @@ class PromptsTable:
|
|||||||
|
|
||||||
def get_prompts(self, db: Optional[Session] = None) -> list[PromptUserResponse]:
|
def get_prompts(self, db: Optional[Session] = None) -> list[PromptUserResponse]:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
all_prompts = (
|
all_prompts = db.query(Prompt).filter(Prompt.is_active == True).order_by(Prompt.updated_at.desc()).all()
|
||||||
db.query(Prompt)
|
|
||||||
.filter(Prompt.is_active == True)
|
|
||||||
.order_by(Prompt.updated_at.desc())
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
user_ids = list(set(prompt.user_id for prompt in all_prompts))
|
user_ids = list(set(prompt.user_id for prompt in all_prompts))
|
||||||
prompt_ids = [prompt.id for prompt in all_prompts]
|
prompt_ids = [prompt.id for prompt in all_prompts]
|
||||||
|
|
||||||
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
||||||
users_dict = {user.id: user for user in users}
|
users_dict = {user.id: user for user in users}
|
||||||
grants_map = AccessGrants.get_grants_by_resources(
|
grants_map = AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db)
|
||||||
"prompt", prompt_ids, db=db
|
|
||||||
)
|
|
||||||
|
|
||||||
prompts = []
|
prompts = []
|
||||||
for prompt in all_prompts:
|
for prompt in all_prompts:
|
||||||
@@ -232,7 +210,7 @@ class PromptsTable:
|
|||||||
access_grants=grants_map.get(prompt.id, []),
|
access_grants=grants_map.get(prompt.id, []),
|
||||||
db=db,
|
db=db,
|
||||||
).model_dump(),
|
).model_dump(),
|
||||||
"user": user.model_dump() if user else None,
|
'user': user.model_dump() if user else None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -240,12 +218,10 @@ class PromptsTable:
|
|||||||
return prompts
|
return prompts
|
||||||
|
|
||||||
def get_prompts_by_user_id(
|
def get_prompts_by_user_id(
|
||||||
self, user_id: str, permission: str = "write", db: Optional[Session] = None
|
self, user_id: str, permission: str = 'write', db: Optional[Session] = None
|
||||||
) -> list[PromptUserResponse]:
|
) -> list[PromptUserResponse]:
|
||||||
prompts = self.get_prompts(db=db)
|
prompts = self.get_prompts(db=db)
|
||||||
user_group_ids = {
|
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
|
||||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
|
||||||
}
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
prompt
|
prompt
|
||||||
@@ -253,7 +229,7 @@ class PromptsTable:
|
|||||||
if prompt.user_id == user_id
|
if prompt.user_id == user_id
|
||||||
or AccessGrants.has_access(
|
or AccessGrants.has_access(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
resource_type="prompt",
|
resource_type='prompt',
|
||||||
resource_id=prompt.id,
|
resource_id=prompt.id,
|
||||||
permission=permission,
|
permission=permission,
|
||||||
user_group_ids=user_group_ids,
|
user_group_ids=user_group_ids,
|
||||||
@@ -276,22 +252,22 @@ class PromptsTable:
|
|||||||
query = db.query(Prompt, User).outerjoin(User, User.id == Prompt.user_id)
|
query = db.query(Prompt, User).outerjoin(User, User.id == Prompt.user_id)
|
||||||
|
|
||||||
if filter:
|
if filter:
|
||||||
query_key = filter.get("query")
|
query_key = filter.get('query')
|
||||||
if query_key:
|
if query_key:
|
||||||
query = query.filter(
|
query = query.filter(
|
||||||
or_(
|
or_(
|
||||||
Prompt.name.ilike(f"%{query_key}%"),
|
Prompt.name.ilike(f'%{query_key}%'),
|
||||||
Prompt.command.ilike(f"%{query_key}%"),
|
Prompt.command.ilike(f'%{query_key}%'),
|
||||||
Prompt.content.ilike(f"%{query_key}%"),
|
Prompt.content.ilike(f'%{query_key}%'),
|
||||||
User.name.ilike(f"%{query_key}%"),
|
User.name.ilike(f'%{query_key}%'),
|
||||||
User.email.ilike(f"%{query_key}%"),
|
User.email.ilike(f'%{query_key}%'),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
view_option = filter.get("view_option")
|
view_option = filter.get('view_option')
|
||||||
if view_option == "created":
|
if view_option == 'created':
|
||||||
query = query.filter(Prompt.user_id == user_id)
|
query = query.filter(Prompt.user_id == user_id)
|
||||||
elif view_option == "shared":
|
elif view_option == 'shared':
|
||||||
query = query.filter(Prompt.user_id != user_id)
|
query = query.filter(Prompt.user_id != user_id)
|
||||||
|
|
||||||
# Apply access grant filtering
|
# Apply access grant filtering
|
||||||
@@ -300,32 +276,32 @@ class PromptsTable:
|
|||||||
query=query,
|
query=query,
|
||||||
DocumentModel=Prompt,
|
DocumentModel=Prompt,
|
||||||
filter=filter,
|
filter=filter,
|
||||||
resource_type="prompt",
|
resource_type='prompt',
|
||||||
permission="read",
|
permission='read',
|
||||||
)
|
)
|
||||||
|
|
||||||
tag = filter.get("tag")
|
tag = filter.get('tag')
|
||||||
if tag:
|
if tag:
|
||||||
# Search for tag in JSON array field
|
# Search for tag in JSON array field
|
||||||
like_pattern = f'%"{tag.lower()}"%'
|
like_pattern = f'%"{tag.lower()}"%'
|
||||||
tags_text = func.lower(cast(Prompt.tags, String))
|
tags_text = func.lower(cast(Prompt.tags, String))
|
||||||
query = query.filter(tags_text.like(like_pattern))
|
query = query.filter(tags_text.like(like_pattern))
|
||||||
|
|
||||||
order_by = filter.get("order_by")
|
order_by = filter.get('order_by')
|
||||||
direction = filter.get("direction")
|
direction = filter.get('direction')
|
||||||
|
|
||||||
if order_by == "name":
|
if order_by == 'name':
|
||||||
if direction == "asc":
|
if direction == 'asc':
|
||||||
query = query.order_by(Prompt.name.asc())
|
query = query.order_by(Prompt.name.asc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(Prompt.name.desc())
|
query = query.order_by(Prompt.name.desc())
|
||||||
elif order_by == "created_at":
|
elif order_by == 'created_at':
|
||||||
if direction == "asc":
|
if direction == 'asc':
|
||||||
query = query.order_by(Prompt.created_at.asc())
|
query = query.order_by(Prompt.created_at.asc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(Prompt.created_at.desc())
|
query = query.order_by(Prompt.created_at.desc())
|
||||||
elif order_by == "updated_at":
|
elif order_by == 'updated_at':
|
||||||
if direction == "asc":
|
if direction == 'asc':
|
||||||
query = query.order_by(Prompt.updated_at.asc())
|
query = query.order_by(Prompt.updated_at.asc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(Prompt.updated_at.desc())
|
query = query.order_by(Prompt.updated_at.desc())
|
||||||
@@ -345,9 +321,7 @@ class PromptsTable:
|
|||||||
items = query.all()
|
items = query.all()
|
||||||
|
|
||||||
prompt_ids = [prompt.id for prompt, _ in items]
|
prompt_ids = [prompt.id for prompt, _ in items]
|
||||||
grants_map = AccessGrants.get_grants_by_resources(
|
grants_map = AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db)
|
||||||
"prompt", prompt_ids, db=db
|
|
||||||
)
|
|
||||||
|
|
||||||
prompts = []
|
prompts = []
|
||||||
for prompt, user in items:
|
for prompt, user in items:
|
||||||
@@ -358,11 +332,7 @@ class PromptsTable:
|
|||||||
access_grants=grants_map.get(prompt.id, []),
|
access_grants=grants_map.get(prompt.id, []),
|
||||||
db=db,
|
db=db,
|
||||||
).model_dump(),
|
).model_dump(),
|
||||||
user=(
|
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
|
||||||
UserResponse(**UserModel.model_validate(user).model_dump())
|
|
||||||
if user
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -381,9 +351,7 @@ class PromptsTable:
|
|||||||
if not prompt:
|
if not prompt:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
latest_history = PromptHistories.get_latest_history_entry(
|
latest_history = PromptHistories.get_latest_history_entry(prompt.id, db=db)
|
||||||
prompt.id, db=db
|
|
||||||
)
|
|
||||||
parent_id = latest_history.id if latest_history else None
|
parent_id = latest_history.id if latest_history else None
|
||||||
current_access_grants = self._get_access_grants(prompt.id, db=db)
|
current_access_grants = self._get_access_grants(prompt.id, db=db)
|
||||||
|
|
||||||
@@ -401,9 +369,7 @@ class PromptsTable:
|
|||||||
prompt.meta = form_data.meta or prompt.meta
|
prompt.meta = form_data.meta or prompt.meta
|
||||||
prompt.updated_at = int(time.time())
|
prompt.updated_at = int(time.time())
|
||||||
if form_data.access_grants is not None:
|
if form_data.access_grants is not None:
|
||||||
AccessGrants.set_access_grants(
|
AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=db)
|
||||||
"prompt", prompt.id, form_data.access_grants, db=db
|
|
||||||
)
|
|
||||||
current_access_grants = self._get_access_grants(prompt.id, db=db)
|
current_access_grants = self._get_access_grants(prompt.id, db=db)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -411,14 +377,12 @@ class PromptsTable:
|
|||||||
# Create history entry only if content changed
|
# Create history entry only if content changed
|
||||||
if content_changed:
|
if content_changed:
|
||||||
snapshot = {
|
snapshot = {
|
||||||
"name": form_data.name,
|
'name': form_data.name,
|
||||||
"content": form_data.content,
|
'content': form_data.content,
|
||||||
"command": command,
|
'command': command,
|
||||||
"data": form_data.data or {},
|
'data': form_data.data or {},
|
||||||
"meta": form_data.meta or {},
|
'meta': form_data.meta or {},
|
||||||
"access_grants": [
|
'access_grants': [grant.model_dump() for grant in current_access_grants],
|
||||||
grant.model_dump() for grant in current_access_grants
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
history_entry = PromptHistories.create_history_entry(
|
history_entry = PromptHistories.create_history_entry(
|
||||||
@@ -452,9 +416,7 @@ class PromptsTable:
|
|||||||
if not prompt:
|
if not prompt:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
latest_history = PromptHistories.get_latest_history_entry(
|
latest_history = PromptHistories.get_latest_history_entry(prompt.id, db=db)
|
||||||
prompt.id, db=db
|
|
||||||
)
|
|
||||||
parent_id = latest_history.id if latest_history else None
|
parent_id = latest_history.id if latest_history else None
|
||||||
current_access_grants = self._get_access_grants(prompt.id, db=db)
|
current_access_grants = self._get_access_grants(prompt.id, db=db)
|
||||||
|
|
||||||
@@ -478,9 +440,7 @@ class PromptsTable:
|
|||||||
prompt.tags = form_data.tags
|
prompt.tags = form_data.tags
|
||||||
|
|
||||||
if form_data.access_grants is not None:
|
if form_data.access_grants is not None:
|
||||||
AccessGrants.set_access_grants(
|
AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=db)
|
||||||
"prompt", prompt.id, form_data.access_grants, db=db
|
|
||||||
)
|
|
||||||
current_access_grants = self._get_access_grants(prompt.id, db=db)
|
current_access_grants = self._get_access_grants(prompt.id, db=db)
|
||||||
|
|
||||||
prompt.updated_at = int(time.time())
|
prompt.updated_at = int(time.time())
|
||||||
@@ -490,15 +450,13 @@ class PromptsTable:
|
|||||||
# Create history entry only if content changed
|
# Create history entry only if content changed
|
||||||
if content_changed:
|
if content_changed:
|
||||||
snapshot = {
|
snapshot = {
|
||||||
"name": form_data.name,
|
'name': form_data.name,
|
||||||
"content": form_data.content,
|
'content': form_data.content,
|
||||||
"command": prompt.command,
|
'command': prompt.command,
|
||||||
"data": form_data.data or {},
|
'data': form_data.data or {},
|
||||||
"meta": form_data.meta or {},
|
'meta': form_data.meta or {},
|
||||||
"tags": prompt.tags or [],
|
'tags': prompt.tags or [],
|
||||||
"access_grants": [
|
'access_grants': [grant.model_dump() for grant in current_access_grants],
|
||||||
grant.model_dump() for grant in current_access_grants
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
history_entry = PromptHistories.create_history_entry(
|
history_entry = PromptHistories.create_history_entry(
|
||||||
@@ -560,9 +518,7 @@ class PromptsTable:
|
|||||||
if not prompt:
|
if not prompt:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
history_entry = PromptHistories.get_history_entry_by_id(
|
history_entry = PromptHistories.get_history_entry_by_id(version_id, db=db)
|
||||||
version_id, db=db
|
|
||||||
)
|
|
||||||
|
|
||||||
if not history_entry:
|
if not history_entry:
|
||||||
return None
|
return None
|
||||||
@@ -570,11 +526,11 @@ class PromptsTable:
|
|||||||
# Restore prompt content from the snapshot
|
# Restore prompt content from the snapshot
|
||||||
snapshot = history_entry.snapshot
|
snapshot = history_entry.snapshot
|
||||||
if snapshot:
|
if snapshot:
|
||||||
prompt.name = snapshot.get("name", prompt.name)
|
prompt.name = snapshot.get('name', prompt.name)
|
||||||
prompt.content = snapshot.get("content", prompt.content)
|
prompt.content = snapshot.get('content', prompt.content)
|
||||||
prompt.data = snapshot.get("data", prompt.data)
|
prompt.data = snapshot.get('data', prompt.data)
|
||||||
prompt.meta = snapshot.get("meta", prompt.meta)
|
prompt.meta = snapshot.get('meta', prompt.meta)
|
||||||
prompt.tags = snapshot.get("tags", prompt.tags)
|
prompt.tags = snapshot.get('tags', prompt.tags)
|
||||||
# Note: command and access_grants are not restored from snapshot
|
# Note: command and access_grants are not restored from snapshot
|
||||||
|
|
||||||
prompt.version_id = version_id
|
prompt.version_id = version_id
|
||||||
@@ -585,9 +541,7 @@ class PromptsTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def toggle_prompt_active(
|
def toggle_prompt_active(self, prompt_id: str, db: Optional[Session] = None) -> Optional[PromptModel]:
|
||||||
self, prompt_id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[PromptModel]:
|
|
||||||
"""Toggle the is_active flag on a prompt."""
|
"""Toggle the is_active flag on a prompt."""
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
@@ -602,16 +556,14 @@ class PromptsTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def delete_prompt_by_command(
|
def delete_prompt_by_command(self, command: str, db: Optional[Session] = None) -> bool:
|
||||||
self, command: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
"""Permanently delete a prompt and its history."""
|
"""Permanently delete a prompt and its history."""
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
prompt = db.query(Prompt).filter_by(command=command).first()
|
prompt = db.query(Prompt).filter_by(command=command).first()
|
||||||
if prompt:
|
if prompt:
|
||||||
PromptHistories.delete_history_by_prompt_id(prompt.id, db=db)
|
PromptHistories.delete_history_by_prompt_id(prompt.id, db=db)
|
||||||
AccessGrants.revoke_all_access("prompt", prompt.id, db=db)
|
AccessGrants.revoke_all_access('prompt', prompt.id, db=db)
|
||||||
|
|
||||||
db.delete(prompt)
|
db.delete(prompt)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -627,7 +579,7 @@ class PromptsTable:
|
|||||||
prompt = db.query(Prompt).filter_by(id=prompt_id).first()
|
prompt = db.query(Prompt).filter_by(id=prompt_id).first()
|
||||||
if prompt:
|
if prompt:
|
||||||
PromptHistories.delete_history_by_prompt_id(prompt.id, db=db)
|
PromptHistories.delete_history_by_prompt_id(prompt.id, db=db)
|
||||||
AccessGrants.revoke_all_access("prompt", prompt.id, db=db)
|
AccessGrants.revoke_all_access('prompt', prompt.id, db=db)
|
||||||
|
|
||||||
db.delete(prompt)
|
db.delete(prompt)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class Skill(Base):
|
class Skill(Base):
|
||||||
__tablename__ = "skill"
|
__tablename__ = 'skill'
|
||||||
|
|
||||||
id = Column(String, primary_key=True, unique=True)
|
id = Column(String, primary_key=True, unique=True)
|
||||||
user_id = Column(String)
|
user_id = Column(String)
|
||||||
@@ -77,7 +77,7 @@ class SkillResponse(BaseModel):
|
|||||||
class SkillUserResponse(SkillResponse):
|
class SkillUserResponse(SkillResponse):
|
||||||
user: Optional[UserResponse] = None
|
user: Optional[UserResponse] = None
|
||||||
|
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra='allow')
|
||||||
|
|
||||||
|
|
||||||
class SkillAccessResponse(SkillUserResponse):
|
class SkillAccessResponse(SkillUserResponse):
|
||||||
@@ -105,10 +105,8 @@ class SkillAccessListResponse(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class SkillsTable:
|
class SkillsTable:
|
||||||
def _get_access_grants(
|
def _get_access_grants(self, skill_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
|
||||||
self, skill_id: str, db: Optional[Session] = None
|
return AccessGrants.get_grants_by_resource('skill', skill_id, db=db)
|
||||||
) -> list[AccessGrantModel]:
|
|
||||||
return AccessGrants.get_grants_by_resource("skill", skill_id, db=db)
|
|
||||||
|
|
||||||
def _to_skill_model(
|
def _to_skill_model(
|
||||||
self,
|
self,
|
||||||
@@ -116,13 +114,9 @@ class SkillsTable:
|
|||||||
access_grants: Optional[list[AccessGrantModel]] = None,
|
access_grants: Optional[list[AccessGrantModel]] = None,
|
||||||
db: Optional[Session] = None,
|
db: Optional[Session] = None,
|
||||||
) -> SkillModel:
|
) -> SkillModel:
|
||||||
skill_data = SkillModel.model_validate(skill).model_dump(
|
skill_data = SkillModel.model_validate(skill).model_dump(exclude={'access_grants'})
|
||||||
exclude={"access_grants"}
|
skill_data['access_grants'] = (
|
||||||
)
|
access_grants if access_grants is not None else self._get_access_grants(skill_data['id'], db=db)
|
||||||
skill_data["access_grants"] = (
|
|
||||||
access_grants
|
|
||||||
if access_grants is not None
|
|
||||||
else self._get_access_grants(skill_data["id"], db=db)
|
|
||||||
)
|
)
|
||||||
return SkillModel.model_validate(skill_data)
|
return SkillModel.model_validate(skill_data)
|
||||||
|
|
||||||
@@ -136,29 +130,25 @@ class SkillsTable:
|
|||||||
try:
|
try:
|
||||||
result = Skill(
|
result = Skill(
|
||||||
**{
|
**{
|
||||||
**form_data.model_dump(exclude={"access_grants"}),
|
**form_data.model_dump(exclude={'access_grants'}),
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
"created_at": int(time.time()),
|
'created_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
db.add(result)
|
db.add(result)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(result)
|
db.refresh(result)
|
||||||
AccessGrants.set_access_grants(
|
AccessGrants.set_access_grants('skill', result.id, form_data.access_grants, db=db)
|
||||||
"skill", result.id, form_data.access_grants, db=db
|
|
||||||
)
|
|
||||||
if result:
|
if result:
|
||||||
return self._to_skill_model(result, db=db)
|
return self._to_skill_model(result, db=db)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error creating a new skill: {e}")
|
log.exception(f'Error creating a new skill: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_skill_by_id(
|
def get_skill_by_id(self, id: str, db: Optional[Session] = None) -> Optional[SkillModel]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[SkillModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
skill = db.get(Skill, id)
|
skill = db.get(Skill, id)
|
||||||
@@ -166,9 +156,7 @@ class SkillsTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_skill_by_name(
|
def get_skill_by_name(self, name: str, db: Optional[Session] = None) -> Optional[SkillModel]:
|
||||||
self, name: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[SkillModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
skill = db.query(Skill).filter_by(name=name).first()
|
skill = db.query(Skill).filter_by(name=name).first()
|
||||||
@@ -185,7 +173,7 @@ class SkillsTable:
|
|||||||
|
|
||||||
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
||||||
users_dict = {user.id: user for user in users}
|
users_dict = {user.id: user for user in users}
|
||||||
grants_map = AccessGrants.get_grants_by_resources("skill", skill_ids, db=db)
|
grants_map = AccessGrants.get_grants_by_resources('skill', skill_ids, db=db)
|
||||||
|
|
||||||
skills = []
|
skills = []
|
||||||
for skill in all_skills:
|
for skill in all_skills:
|
||||||
@@ -198,19 +186,17 @@ class SkillsTable:
|
|||||||
access_grants=grants_map.get(skill.id, []),
|
access_grants=grants_map.get(skill.id, []),
|
||||||
db=db,
|
db=db,
|
||||||
).model_dump(),
|
).model_dump(),
|
||||||
"user": user.model_dump() if user else None,
|
'user': user.model_dump() if user else None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return skills
|
return skills
|
||||||
|
|
||||||
def get_skills_by_user_id(
|
def get_skills_by_user_id(
|
||||||
self, user_id: str, permission: str = "write", db: Optional[Session] = None
|
self, user_id: str, permission: str = 'write', db: Optional[Session] = None
|
||||||
) -> list[SkillUserModel]:
|
) -> list[SkillUserModel]:
|
||||||
skills = self.get_skills(db=db)
|
skills = self.get_skills(db=db)
|
||||||
user_group_ids = {
|
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
|
||||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
|
||||||
}
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
skill
|
skill
|
||||||
@@ -218,7 +204,7 @@ class SkillsTable:
|
|||||||
if skill.user_id == user_id
|
if skill.user_id == user_id
|
||||||
or AccessGrants.has_access(
|
or AccessGrants.has_access(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
resource_type="skill",
|
resource_type='skill',
|
||||||
resource_id=skill.id,
|
resource_id=skill.id,
|
||||||
permission=permission,
|
permission=permission,
|
||||||
user_group_ids=user_group_ids,
|
user_group_ids=user_group_ids,
|
||||||
@@ -242,22 +228,22 @@ class SkillsTable:
|
|||||||
query = db.query(Skill, User).outerjoin(User, User.id == Skill.user_id)
|
query = db.query(Skill, User).outerjoin(User, User.id == Skill.user_id)
|
||||||
|
|
||||||
if filter:
|
if filter:
|
||||||
query_key = filter.get("query")
|
query_key = filter.get('query')
|
||||||
if query_key:
|
if query_key:
|
||||||
query = query.filter(
|
query = query.filter(
|
||||||
or_(
|
or_(
|
||||||
Skill.name.ilike(f"%{query_key}%"),
|
Skill.name.ilike(f'%{query_key}%'),
|
||||||
Skill.description.ilike(f"%{query_key}%"),
|
Skill.description.ilike(f'%{query_key}%'),
|
||||||
Skill.id.ilike(f"%{query_key}%"),
|
Skill.id.ilike(f'%{query_key}%'),
|
||||||
User.name.ilike(f"%{query_key}%"),
|
User.name.ilike(f'%{query_key}%'),
|
||||||
User.email.ilike(f"%{query_key}%"),
|
User.email.ilike(f'%{query_key}%'),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
view_option = filter.get("view_option")
|
view_option = filter.get('view_option')
|
||||||
if view_option == "created":
|
if view_option == 'created':
|
||||||
query = query.filter(Skill.user_id == user_id)
|
query = query.filter(Skill.user_id == user_id)
|
||||||
elif view_option == "shared":
|
elif view_option == 'shared':
|
||||||
query = query.filter(Skill.user_id != user_id)
|
query = query.filter(Skill.user_id != user_id)
|
||||||
|
|
||||||
# Apply access grant filtering
|
# Apply access grant filtering
|
||||||
@@ -266,8 +252,8 @@ class SkillsTable:
|
|||||||
query=query,
|
query=query,
|
||||||
DocumentModel=Skill,
|
DocumentModel=Skill,
|
||||||
filter=filter,
|
filter=filter,
|
||||||
resource_type="skill",
|
resource_type='skill',
|
||||||
permission="read",
|
permission='read',
|
||||||
)
|
)
|
||||||
|
|
||||||
query = query.order_by(Skill.updated_at.desc())
|
query = query.order_by(Skill.updated_at.desc())
|
||||||
@@ -283,9 +269,7 @@ class SkillsTable:
|
|||||||
items = query.all()
|
items = query.all()
|
||||||
|
|
||||||
skill_ids = [skill.id for skill, _ in items]
|
skill_ids = [skill.id for skill, _ in items]
|
||||||
grants_map = AccessGrants.get_grants_by_resources(
|
grants_map = AccessGrants.get_grants_by_resources('skill', skill_ids, db=db)
|
||||||
"skill", skill_ids, db=db
|
|
||||||
)
|
|
||||||
|
|
||||||
skills = []
|
skills = []
|
||||||
for skill, user in items:
|
for skill, user in items:
|
||||||
@@ -296,33 +280,23 @@ class SkillsTable:
|
|||||||
access_grants=grants_map.get(skill.id, []),
|
access_grants=grants_map.get(skill.id, []),
|
||||||
db=db,
|
db=db,
|
||||||
).model_dump(),
|
).model_dump(),
|
||||||
user=(
|
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
|
||||||
UserResponse(
|
|
||||||
**UserModel.model_validate(user).model_dump()
|
|
||||||
)
|
|
||||||
if user
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return SkillListResponse(items=skills, total=total)
|
return SkillListResponse(items=skills, total=total)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error searching skills: {e}")
|
log.exception(f'Error searching skills: {e}')
|
||||||
return SkillListResponse(items=[], total=0)
|
return SkillListResponse(items=[], total=0)
|
||||||
|
|
||||||
def update_skill_by_id(
|
def update_skill_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[SkillModel]:
|
||||||
self, id: str, updated: dict, db: Optional[Session] = None
|
|
||||||
) -> Optional[SkillModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
access_grants = updated.pop("access_grants", None)
|
access_grants = updated.pop('access_grants', None)
|
||||||
db.query(Skill).filter_by(id=id).update(
|
db.query(Skill).filter_by(id=id).update({**updated, 'updated_at': int(time.time())})
|
||||||
{**updated, "updated_at": int(time.time())}
|
|
||||||
)
|
|
||||||
db.commit()
|
db.commit()
|
||||||
if access_grants is not None:
|
if access_grants is not None:
|
||||||
AccessGrants.set_access_grants("skill", id, access_grants, db=db)
|
AccessGrants.set_access_grants('skill', id, access_grants, db=db)
|
||||||
|
|
||||||
skill = db.query(Skill).get(id)
|
skill = db.query(Skill).get(id)
|
||||||
db.refresh(skill)
|
db.refresh(skill)
|
||||||
@@ -330,9 +304,7 @@ class SkillsTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def toggle_skill_by_id(
|
def toggle_skill_by_id(self, id: str, db: Optional[Session] = None) -> Optional[SkillModel]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[SkillModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
try:
|
try:
|
||||||
skill = db.query(Skill).filter_by(id=id).first()
|
skill = db.query(Skill).filter_by(id=id).first()
|
||||||
@@ -351,7 +323,7 @@ class SkillsTable:
|
|||||||
def delete_skill_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
def delete_skill_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
AccessGrants.revoke_all_access("skill", id, db=db)
|
AccessGrants.revoke_all_access('skill', id, db=db)
|
||||||
db.query(Skill).filter_by(id=id).delete()
|
db.query(Skill).filter_by(id=id).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
|||||||
@@ -17,19 +17,19 @@ log = logging.getLogger(__name__)
|
|||||||
# Tag DB Schema
|
# Tag DB Schema
|
||||||
####################
|
####################
|
||||||
class Tag(Base):
|
class Tag(Base):
|
||||||
__tablename__ = "tag"
|
__tablename__ = 'tag'
|
||||||
id = Column(String)
|
id = Column(String)
|
||||||
name = Column(String)
|
name = Column(String)
|
||||||
user_id = Column(String)
|
user_id = Column(String)
|
||||||
meta = Column(JSON, nullable=True)
|
meta = Column(JSON, nullable=True)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),
|
PrimaryKeyConstraint('id', 'user_id', name='pk_id_user_id'),
|
||||||
Index("user_id_idx", "user_id"),
|
Index('user_id_idx', 'user_id'),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Unique constraint ensuring (id, user_id) is unique, not just the `id` column
|
# Unique constraint ensuring (id, user_id) is unique, not just the `id` column
|
||||||
__table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),)
|
__table_args__ = (PrimaryKeyConstraint('id', 'user_id', name='pk_id_user_id'),)
|
||||||
|
|
||||||
|
|
||||||
class TagModel(BaseModel):
|
class TagModel(BaseModel):
|
||||||
@@ -51,12 +51,10 @@ class TagChatIdForm(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class TagTable:
|
class TagTable:
|
||||||
def insert_new_tag(
|
def insert_new_tag(self, name: str, user_id: str, db: Optional[Session] = None) -> Optional[TagModel]:
|
||||||
self, name: str, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[TagModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
id = name.replace(" ", "_").lower()
|
id = name.replace(' ', '_').lower()
|
||||||
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
|
tag = TagModel(**{'id': id, 'user_id': user_id, 'name': name})
|
||||||
try:
|
try:
|
||||||
result = Tag(**tag.model_dump())
|
result = Tag(**tag.model_dump())
|
||||||
db.add(result)
|
db.add(result)
|
||||||
@@ -67,89 +65,63 @@ class TagTable:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error inserting a new tag: {e}")
|
log.exception(f'Error inserting a new tag: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_tag_by_name_and_user_id(
|
def get_tag_by_name_and_user_id(self, name: str, user_id: str, db: Optional[Session] = None) -> Optional[TagModel]:
|
||||||
self, name: str, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[TagModel]:
|
|
||||||
try:
|
try:
|
||||||
id = name.replace(" ", "_").lower()
|
id = name.replace(' ', '_').lower()
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
|
tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
|
||||||
return TagModel.model_validate(tag)
|
return TagModel.model_validate(tag)
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_tags_by_user_id(
|
def get_tags_by_user_id(self, user_id: str, db: Optional[Session] = None) -> list[TagModel]:
|
||||||
self, user_id: str, db: Optional[Session] = None
|
with get_db_context(db) as db:
|
||||||
) -> list[TagModel]:
|
return [TagModel.model_validate(tag) for tag in (db.query(Tag).filter_by(user_id=user_id).all())]
|
||||||
|
|
||||||
|
def get_tags_by_ids_and_user_id(self, ids: list[str], user_id: str, db: Optional[Session] = None) -> list[TagModel]:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
return [
|
return [
|
||||||
TagModel.model_validate(tag)
|
TagModel.model_validate(tag)
|
||||||
for tag in (db.query(Tag).filter_by(user_id=user_id).all())
|
for tag in (db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all())
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_tags_by_ids_and_user_id(
|
def delete_tag_by_name_and_user_id(self, name: str, user_id: str, db: Optional[Session] = None) -> bool:
|
||||||
self, ids: list[str], user_id: str, db: Optional[Session] = None
|
|
||||||
) -> list[TagModel]:
|
|
||||||
with get_db_context(db) as db:
|
|
||||||
return [
|
|
||||||
TagModel.model_validate(tag)
|
|
||||||
for tag in (
|
|
||||||
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
def delete_tag_by_name_and_user_id(
|
|
||||||
self, name: str, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
id = name.replace(" ", "_").lower()
|
id = name.replace(' ', '_').lower()
|
||||||
res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
|
res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
|
||||||
log.debug(f"res: {res}")
|
log.debug(f'res: {res}')
|
||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"delete_tag: {e}")
|
log.error(f'delete_tag: {e}')
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def delete_tags_by_ids_and_user_id(
|
def delete_tags_by_ids_and_user_id(self, ids: list[str], user_id: str, db: Optional[Session] = None) -> bool:
|
||||||
self, ids: list[str], user_id: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
"""Delete all tags whose id is in *ids* for the given user, in one query."""
|
"""Delete all tags whose id is in *ids* for the given user, in one query."""
|
||||||
if not ids:
|
if not ids:
|
||||||
return True
|
return True
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).delete(
|
db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).delete(synchronize_session=False)
|
||||||
synchronize_session=False
|
|
||||||
)
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"delete_tags_by_ids: {e}")
|
log.error(f'delete_tags_by_ids: {e}')
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def ensure_tags_exist(
|
def ensure_tags_exist(self, names: list[str], user_id: str, db: Optional[Session] = None) -> None:
|
||||||
self, names: list[str], user_id: str, db: Optional[Session] = None
|
|
||||||
) -> None:
|
|
||||||
"""Create tag rows for any *names* that don't already exist for *user_id*."""
|
"""Create tag rows for any *names* that don't already exist for *user_id*."""
|
||||||
if not names:
|
if not names:
|
||||||
return
|
return
|
||||||
ids = [n.replace(" ", "_").lower() for n in names]
|
ids = [n.replace(' ', '_').lower() for n in names]
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
existing = {
|
existing = {t.id for t in db.query(Tag.id).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()}
|
||||||
t.id
|
|
||||||
for t in db.query(Tag.id)
|
|
||||||
.filter(Tag.id.in_(ids), Tag.user_id == user_id)
|
|
||||||
.all()
|
|
||||||
}
|
|
||||||
new_tags = [
|
new_tags = [
|
||||||
Tag(id=tag_id, name=name, user_id=user_id)
|
Tag(id=tag_id, name=name, user_id=user_id) for tag_id, name in zip(ids, names) if tag_id not in existing
|
||||||
for tag_id, name in zip(ids, names)
|
|
||||||
if tag_id not in existing
|
|
||||||
]
|
]
|
||||||
if new_tags:
|
if new_tags:
|
||||||
db.add_all(new_tags)
|
db.add_all(new_tags)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class Tool(Base):
|
class Tool(Base):
|
||||||
__tablename__ = "tool"
|
__tablename__ = 'tool'
|
||||||
|
|
||||||
id = Column(String, primary_key=True, unique=True)
|
id = Column(String, primary_key=True, unique=True)
|
||||||
user_id = Column(String)
|
user_id = Column(String)
|
||||||
@@ -75,7 +75,7 @@ class ToolResponse(BaseModel):
|
|||||||
class ToolUserResponse(ToolResponse):
|
class ToolUserResponse(ToolResponse):
|
||||||
user: Optional[UserResponse] = None
|
user: Optional[UserResponse] = None
|
||||||
|
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra='allow')
|
||||||
|
|
||||||
|
|
||||||
class ToolAccessResponse(ToolUserResponse):
|
class ToolAccessResponse(ToolUserResponse):
|
||||||
@@ -95,10 +95,8 @@ class ToolValves(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ToolsTable:
|
class ToolsTable:
|
||||||
def _get_access_grants(
|
def _get_access_grants(self, tool_id: str, db: Optional[Session] = None) -> list[AccessGrantModel]:
|
||||||
self, tool_id: str, db: Optional[Session] = None
|
return AccessGrants.get_grants_by_resource('tool', tool_id, db=db)
|
||||||
) -> list[AccessGrantModel]:
|
|
||||||
return AccessGrants.get_grants_by_resource("tool", tool_id, db=db)
|
|
||||||
|
|
||||||
def _to_tool_model(
|
def _to_tool_model(
|
||||||
self,
|
self,
|
||||||
@@ -106,11 +104,9 @@ class ToolsTable:
|
|||||||
access_grants: Optional[list[AccessGrantModel]] = None,
|
access_grants: Optional[list[AccessGrantModel]] = None,
|
||||||
db: Optional[Session] = None,
|
db: Optional[Session] = None,
|
||||||
) -> ToolModel:
|
) -> ToolModel:
|
||||||
tool_data = ToolModel.model_validate(tool).model_dump(exclude={"access_grants"})
|
tool_data = ToolModel.model_validate(tool).model_dump(exclude={'access_grants'})
|
||||||
tool_data["access_grants"] = (
|
tool_data['access_grants'] = (
|
||||||
access_grants
|
access_grants if access_grants is not None else self._get_access_grants(tool_data['id'], db=db)
|
||||||
if access_grants is not None
|
|
||||||
else self._get_access_grants(tool_data["id"], db=db)
|
|
||||||
)
|
)
|
||||||
return ToolModel.model_validate(tool_data)
|
return ToolModel.model_validate(tool_data)
|
||||||
|
|
||||||
@@ -125,30 +121,26 @@ class ToolsTable:
|
|||||||
try:
|
try:
|
||||||
result = Tool(
|
result = Tool(
|
||||||
**{
|
**{
|
||||||
**form_data.model_dump(exclude={"access_grants"}),
|
**form_data.model_dump(exclude={'access_grants'}),
|
||||||
"specs": specs,
|
'specs': specs,
|
||||||
"user_id": user_id,
|
'user_id': user_id,
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
"created_at": int(time.time()),
|
'created_at': int(time.time()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
db.add(result)
|
db.add(result)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(result)
|
db.refresh(result)
|
||||||
AccessGrants.set_access_grants(
|
AccessGrants.set_access_grants('tool', result.id, form_data.access_grants, db=db)
|
||||||
"tool", result.id, form_data.access_grants, db=db
|
|
||||||
)
|
|
||||||
if result:
|
if result:
|
||||||
return self._to_tool_model(result, db=db)
|
return self._to_tool_model(result, db=db)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error creating a new tool: {e}")
|
log.exception(f'Error creating a new tool: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_tool_by_id(
|
def get_tool_by_id(self, id: str, db: Optional[Session] = None) -> Optional[ToolModel]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[ToolModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
tool = db.get(Tool, id)
|
tool = db.get(Tool, id)
|
||||||
@@ -156,9 +148,7 @@ class ToolsTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_tools(
|
def get_tools(self, defer_content: bool = False, db: Optional[Session] = None) -> list[ToolUserModel]:
|
||||||
self, defer_content: bool = False, db: Optional[Session] = None
|
|
||||||
) -> list[ToolUserModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
query = db.query(Tool).order_by(Tool.updated_at.desc())
|
query = db.query(Tool).order_by(Tool.updated_at.desc())
|
||||||
if defer_content:
|
if defer_content:
|
||||||
@@ -170,7 +160,7 @@ class ToolsTable:
|
|||||||
|
|
||||||
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
users = Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
||||||
users_dict = {user.id: user for user in users}
|
users_dict = {user.id: user for user in users}
|
||||||
grants_map = AccessGrants.get_grants_by_resources("tool", tool_ids, db=db)
|
grants_map = AccessGrants.get_grants_by_resources('tool', tool_ids, db=db)
|
||||||
|
|
||||||
tools = []
|
tools = []
|
||||||
for tool in all_tools:
|
for tool in all_tools:
|
||||||
@@ -183,7 +173,7 @@ class ToolsTable:
|
|||||||
access_grants=grants_map.get(tool.id, []),
|
access_grants=grants_map.get(tool.id, []),
|
||||||
db=db,
|
db=db,
|
||||||
).model_dump(),
|
).model_dump(),
|
||||||
"user": user.model_dump() if user else None,
|
'user': user.model_dump() if user else None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -192,14 +182,12 @@ class ToolsTable:
|
|||||||
def get_tools_by_user_id(
|
def get_tools_by_user_id(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
permission: str = "write",
|
permission: str = 'write',
|
||||||
defer_content: bool = False,
|
defer_content: bool = False,
|
||||||
db: Optional[Session] = None,
|
db: Optional[Session] = None,
|
||||||
) -> list[ToolUserModel]:
|
) -> list[ToolUserModel]:
|
||||||
tools = self.get_tools(defer_content=defer_content, db=db)
|
tools = self.get_tools(defer_content=defer_content, db=db)
|
||||||
user_group_ids = {
|
user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id, db=db)}
|
||||||
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
|
|
||||||
}
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
tool
|
tool
|
||||||
@@ -207,7 +195,7 @@ class ToolsTable:
|
|||||||
if tool.user_id == user_id
|
if tool.user_id == user_id
|
||||||
or AccessGrants.has_access(
|
or AccessGrants.has_access(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
resource_type="tool",
|
resource_type='tool',
|
||||||
resource_id=tool.id,
|
resource_id=tool.id,
|
||||||
permission=permission,
|
permission=permission,
|
||||||
user_group_ids=user_group_ids,
|
user_group_ids=user_group_ids,
|
||||||
@@ -215,48 +203,38 @@ class ToolsTable:
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_tool_valves_by_id(
|
def get_tool_valves_by_id(self, id: str, db: Optional[Session] = None) -> Optional[dict]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[dict]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
tool = db.get(Tool, id)
|
tool = db.get(Tool, id)
|
||||||
return tool.valves if tool.valves else {}
|
return tool.valves if tool.valves else {}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error getting tool valves by id {id}")
|
log.exception(f'Error getting tool valves by id {id}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_tool_valves_by_id(
|
def update_tool_valves_by_id(self, id: str, valves: dict, db: Optional[Session] = None) -> Optional[ToolValves]:
|
||||||
self, id: str, valves: dict, db: Optional[Session] = None
|
|
||||||
) -> Optional[ToolValves]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
db.query(Tool).filter_by(id=id).update(
|
db.query(Tool).filter_by(id=id).update({'valves': valves, 'updated_at': int(time.time())})
|
||||||
{"valves": valves, "updated_at": int(time.time())}
|
|
||||||
)
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return self.get_tool_by_id(id, db=db)
|
return self.get_tool_by_id(id, db=db)
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_user_valves_by_id_and_user_id(
|
def get_user_valves_by_id_and_user_id(self, id: str, user_id: str, db: Optional[Session] = None) -> Optional[dict]:
|
||||||
self, id: str, user_id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[dict]:
|
|
||||||
try:
|
try:
|
||||||
user = Users.get_user_by_id(user_id, db=db)
|
user = Users.get_user_by_id(user_id, db=db)
|
||||||
user_settings = user.settings.model_dump() if user.settings else {}
|
user_settings = user.settings.model_dump() if user.settings else {}
|
||||||
|
|
||||||
# Check if user has "tools" and "valves" settings
|
# Check if user has "tools" and "valves" settings
|
||||||
if "tools" not in user_settings:
|
if 'tools' not in user_settings:
|
||||||
user_settings["tools"] = {}
|
user_settings['tools'] = {}
|
||||||
if "valves" not in user_settings["tools"]:
|
if 'valves' not in user_settings['tools']:
|
||||||
user_settings["tools"]["valves"] = {}
|
user_settings['tools']['valves'] = {}
|
||||||
|
|
||||||
return user_settings["tools"]["valves"].get(id, {})
|
return user_settings['tools']['valves'].get(id, {})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(
|
log.exception(f'Error getting user values by id {id} and user_id {user_id}: {e}')
|
||||||
f"Error getting user values by id {id} and user_id {user_id}: {e}"
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_user_valves_by_id_and_user_id(
|
def update_user_valves_by_id_and_user_id(
|
||||||
@@ -267,35 +245,29 @@ class ToolsTable:
|
|||||||
user_settings = user.settings.model_dump() if user.settings else {}
|
user_settings = user.settings.model_dump() if user.settings else {}
|
||||||
|
|
||||||
# Check if user has "tools" and "valves" settings
|
# Check if user has "tools" and "valves" settings
|
||||||
if "tools" not in user_settings:
|
if 'tools' not in user_settings:
|
||||||
user_settings["tools"] = {}
|
user_settings['tools'] = {}
|
||||||
if "valves" not in user_settings["tools"]:
|
if 'valves' not in user_settings['tools']:
|
||||||
user_settings["tools"]["valves"] = {}
|
user_settings['tools']['valves'] = {}
|
||||||
|
|
||||||
user_settings["tools"]["valves"][id] = valves
|
user_settings['tools']['valves'][id] = valves
|
||||||
|
|
||||||
# Update the user settings in the database
|
# Update the user settings in the database
|
||||||
Users.update_user_by_id(user_id, {"settings": user_settings}, db=db)
|
Users.update_user_by_id(user_id, {'settings': user_settings}, db=db)
|
||||||
|
|
||||||
return user_settings["tools"]["valves"][id]
|
return user_settings['tools']['valves'][id]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(
|
log.exception(f'Error updating user valves by id {id} and user_id {user_id}: {e}')
|
||||||
f"Error updating user valves by id {id} and user_id {user_id}: {e}"
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_tool_by_id(
|
def update_tool_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[ToolModel]:
|
||||||
self, id: str, updated: dict, db: Optional[Session] = None
|
|
||||||
) -> Optional[ToolModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
access_grants = updated.pop("access_grants", None)
|
access_grants = updated.pop('access_grants', None)
|
||||||
db.query(Tool).filter_by(id=id).update(
|
db.query(Tool).filter_by(id=id).update({**updated, 'updated_at': int(time.time())})
|
||||||
{**updated, "updated_at": int(time.time())}
|
|
||||||
)
|
|
||||||
db.commit()
|
db.commit()
|
||||||
if access_grants is not None:
|
if access_grants is not None:
|
||||||
AccessGrants.set_access_grants("tool", id, access_grants, db=db)
|
AccessGrants.set_access_grants('tool', id, access_grants, db=db)
|
||||||
|
|
||||||
tool = db.query(Tool).get(id)
|
tool = db.query(Tool).get(id)
|
||||||
db.refresh(tool)
|
db.refresh(tool)
|
||||||
@@ -306,7 +278,7 @@ class ToolsTable:
|
|||||||
def delete_tool_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
def delete_tool_by_id(self, id: str, db: Optional[Session] = None) -> bool:
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
AccessGrants.revoke_all_access("tool", id, db=db)
|
AccessGrants.revoke_all_access('tool', id, db=db)
|
||||||
db.query(Tool).filter_by(id=id).delete()
|
db.query(Tool).filter_by(id=id).delete()
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
|||||||
@@ -40,12 +40,12 @@ import datetime
|
|||||||
|
|
||||||
class UserSettings(BaseModel):
|
class UserSettings(BaseModel):
|
||||||
ui: Optional[dict] = {}
|
ui: Optional[dict] = {}
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra='allow')
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
__tablename__ = "user"
|
__tablename__ = 'user'
|
||||||
|
|
||||||
id = Column(String, primary_key=True, unique=True)
|
id = Column(String, primary_key=True, unique=True)
|
||||||
email = Column(String)
|
email = Column(String)
|
||||||
@@ -83,7 +83,7 @@ class UserModel(BaseModel):
|
|||||||
|
|
||||||
email: str
|
email: str
|
||||||
username: Optional[str] = None
|
username: Optional[str] = None
|
||||||
role: str = "pending"
|
role: str = 'pending'
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
@@ -112,10 +112,10 @@ class UserModel(BaseModel):
|
|||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode='after')
|
||||||
def set_profile_image_url(self):
|
def set_profile_image_url(self):
|
||||||
if not self.profile_image_url:
|
if not self.profile_image_url:
|
||||||
self.profile_image_url = f"/api/v1/users/{self.id}/profile/image"
|
self.profile_image_url = f'/api/v1/users/{self.id}/profile/image'
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
@@ -126,7 +126,7 @@ class UserStatusModel(UserModel):
|
|||||||
|
|
||||||
|
|
||||||
class ApiKey(Base):
|
class ApiKey(Base):
|
||||||
__tablename__ = "api_key"
|
__tablename__ = 'api_key'
|
||||||
|
|
||||||
id = Column(Text, primary_key=True, unique=True)
|
id = Column(Text, primary_key=True, unique=True)
|
||||||
user_id = Column(Text, nullable=False)
|
user_id = Column(Text, nullable=False)
|
||||||
@@ -163,7 +163,7 @@ class UpdateProfileForm(BaseModel):
|
|||||||
gender: Optional[str] = None
|
gender: Optional[str] = None
|
||||||
date_of_birth: Optional[datetime.date] = None
|
date_of_birth: Optional[datetime.date] = None
|
||||||
|
|
||||||
@field_validator("profile_image_url")
|
@field_validator('profile_image_url')
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_profile_image_url(cls, v: str) -> str:
|
def check_profile_image_url(cls, v: str) -> str:
|
||||||
return validate_profile_image_url(v)
|
return validate_profile_image_url(v)
|
||||||
@@ -174,7 +174,7 @@ class UserGroupIdsModel(UserModel):
|
|||||||
|
|
||||||
|
|
||||||
class UserModelResponse(UserModel):
|
class UserModelResponse(UserModel):
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra='allow')
|
||||||
|
|
||||||
|
|
||||||
class UserListResponse(BaseModel):
|
class UserListResponse(BaseModel):
|
||||||
@@ -251,7 +251,7 @@ class UserUpdateForm(BaseModel):
|
|||||||
profile_image_url: str
|
profile_image_url: str
|
||||||
password: Optional[str] = None
|
password: Optional[str] = None
|
||||||
|
|
||||||
@field_validator("profile_image_url")
|
@field_validator('profile_image_url')
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_profile_image_url(cls, v: str) -> str:
|
def check_profile_image_url(cls, v: str) -> str:
|
||||||
return validate_profile_image_url(v)
|
return validate_profile_image_url(v)
|
||||||
@@ -263,8 +263,8 @@ class UsersTable:
|
|||||||
id: str,
|
id: str,
|
||||||
name: str,
|
name: str,
|
||||||
email: str,
|
email: str,
|
||||||
profile_image_url: str = "/user.png",
|
profile_image_url: str = '/user.png',
|
||||||
role: str = "pending",
|
role: str = 'pending',
|
||||||
username: Optional[str] = None,
|
username: Optional[str] = None,
|
||||||
oauth: Optional[dict] = None,
|
oauth: Optional[dict] = None,
|
||||||
db: Optional[Session] = None,
|
db: Optional[Session] = None,
|
||||||
@@ -272,16 +272,16 @@ class UsersTable:
|
|||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
user = UserModel(
|
user = UserModel(
|
||||||
**{
|
**{
|
||||||
"id": id,
|
'id': id,
|
||||||
"email": email,
|
'email': email,
|
||||||
"name": name,
|
'name': name,
|
||||||
"role": role,
|
'role': role,
|
||||||
"profile_image_url": profile_image_url,
|
'profile_image_url': profile_image_url,
|
||||||
"last_active_at": int(time.time()),
|
'last_active_at': int(time.time()),
|
||||||
"created_at": int(time.time()),
|
'created_at': int(time.time()),
|
||||||
"updated_at": int(time.time()),
|
'updated_at': int(time.time()),
|
||||||
"username": username,
|
'username': username,
|
||||||
"oauth": oauth,
|
'oauth': oauth,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
result = User(**user.model_dump())
|
result = User(**user.model_dump())
|
||||||
@@ -293,9 +293,7 @@ class UsersTable:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_user_by_id(
|
def get_user_by_id(self, id: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[UserModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
user = db.query(User).filter_by(id=id).first()
|
user = db.query(User).filter_by(id=id).first()
|
||||||
@@ -303,49 +301,32 @@ class UsersTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_user_by_api_key(
|
def get_user_by_api_key(self, api_key: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||||
self, api_key: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[UserModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
user = (
|
user = db.query(User).join(ApiKey, User.id == ApiKey.user_id).filter(ApiKey.key == api_key).first()
|
||||||
db.query(User)
|
|
||||||
.join(ApiKey, User.id == ApiKey.user_id)
|
|
||||||
.filter(ApiKey.key == api_key)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
return UserModel.model_validate(user) if user else None
|
return UserModel.model_validate(user) if user else None
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_user_by_email(
|
def get_user_by_email(self, email: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||||
self, email: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[UserModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
user = (
|
user = db.query(User).filter(func.lower(User.email) == email.lower()).first()
|
||||||
db.query(User)
|
|
||||||
.filter(func.lower(User.email) == email.lower())
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
return UserModel.model_validate(user) if user else None
|
return UserModel.model_validate(user) if user else None
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_user_by_oauth_sub(
|
def get_user_by_oauth_sub(self, provider: str, sub: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||||
self, provider: str, sub: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[UserModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db: # type: Session
|
with get_db_context(db) as db: # type: Session
|
||||||
dialect_name = db.bind.dialect.name
|
dialect_name = db.bind.dialect.name
|
||||||
|
|
||||||
query = db.query(User)
|
query = db.query(User)
|
||||||
if dialect_name == "sqlite":
|
if dialect_name == 'sqlite':
|
||||||
query = query.filter(User.oauth.contains({provider: {"sub": sub}}))
|
query = query.filter(User.oauth.contains({provider: {'sub': sub}}))
|
||||||
elif dialect_name == "postgresql":
|
elif dialect_name == 'postgresql':
|
||||||
query = query.filter(
|
query = query.filter(User.oauth[provider].cast(JSONB)['sub'].astext == sub)
|
||||||
User.oauth[provider].cast(JSONB)["sub"].astext == sub
|
|
||||||
)
|
|
||||||
|
|
||||||
user = query.first()
|
user = query.first()
|
||||||
return UserModel.model_validate(user) if user else None
|
return UserModel.model_validate(user) if user else None
|
||||||
@@ -361,15 +342,10 @@ class UsersTable:
|
|||||||
dialect_name = db.bind.dialect.name
|
dialect_name = db.bind.dialect.name
|
||||||
|
|
||||||
query = db.query(User)
|
query = db.query(User)
|
||||||
if dialect_name == "sqlite":
|
if dialect_name == 'sqlite':
|
||||||
query = query.filter(
|
query = query.filter(User.scim.contains({provider: {'external_id': external_id}}))
|
||||||
User.scim.contains({provider: {"external_id": external_id}})
|
elif dialect_name == 'postgresql':
|
||||||
)
|
query = query.filter(User.scim[provider].cast(JSONB)['external_id'].astext == external_id)
|
||||||
elif dialect_name == "postgresql":
|
|
||||||
query = query.filter(
|
|
||||||
User.scim[provider].cast(JSONB)["external_id"].astext
|
|
||||||
== external_id
|
|
||||||
)
|
|
||||||
|
|
||||||
user = query.first()
|
user = query.first()
|
||||||
return UserModel.model_validate(user) if user else None
|
return UserModel.model_validate(user) if user else None
|
||||||
@@ -388,16 +364,16 @@ class UsersTable:
|
|||||||
query = db.query(User).options(defer(User.profile_image_url))
|
query = db.query(User).options(defer(User.profile_image_url))
|
||||||
|
|
||||||
if filter:
|
if filter:
|
||||||
query_key = filter.get("query")
|
query_key = filter.get('query')
|
||||||
if query_key:
|
if query_key:
|
||||||
query = query.filter(
|
query = query.filter(
|
||||||
or_(
|
or_(
|
||||||
User.name.ilike(f"%{query_key}%"),
|
User.name.ilike(f'%{query_key}%'),
|
||||||
User.email.ilike(f"%{query_key}%"),
|
User.email.ilike(f'%{query_key}%'),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
channel_id = filter.get("channel_id")
|
channel_id = filter.get('channel_id')
|
||||||
if channel_id:
|
if channel_id:
|
||||||
query = query.filter(
|
query = query.filter(
|
||||||
exists(
|
exists(
|
||||||
@@ -408,13 +384,13 @@ class UsersTable:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
user_ids = filter.get("user_ids")
|
user_ids = filter.get('user_ids')
|
||||||
group_ids = filter.get("group_ids")
|
group_ids = filter.get('group_ids')
|
||||||
|
|
||||||
if isinstance(user_ids, list) and isinstance(group_ids, list):
|
if isinstance(user_ids, list) and isinstance(group_ids, list):
|
||||||
# If both are empty lists, return no users
|
# If both are empty lists, return no users
|
||||||
if not user_ids and not group_ids:
|
if not user_ids and not group_ids:
|
||||||
return {"users": [], "total": 0}
|
return {'users': [], 'total': 0}
|
||||||
|
|
||||||
if user_ids:
|
if user_ids:
|
||||||
query = query.filter(User.id.in_(user_ids))
|
query = query.filter(User.id.in_(user_ids))
|
||||||
@@ -429,21 +405,21 @@ class UsersTable:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
roles = filter.get("roles")
|
roles = filter.get('roles')
|
||||||
if roles:
|
if roles:
|
||||||
include_roles = [role for role in roles if not role.startswith("!")]
|
include_roles = [role for role in roles if not role.startswith('!')]
|
||||||
exclude_roles = [role[1:] for role in roles if role.startswith("!")]
|
exclude_roles = [role[1:] for role in roles if role.startswith('!')]
|
||||||
|
|
||||||
if include_roles:
|
if include_roles:
|
||||||
query = query.filter(User.role.in_(include_roles))
|
query = query.filter(User.role.in_(include_roles))
|
||||||
if exclude_roles:
|
if exclude_roles:
|
||||||
query = query.filter(~User.role.in_(exclude_roles))
|
query = query.filter(~User.role.in_(exclude_roles))
|
||||||
|
|
||||||
order_by = filter.get("order_by")
|
order_by = filter.get('order_by')
|
||||||
direction = filter.get("direction")
|
direction = filter.get('direction')
|
||||||
|
|
||||||
if order_by and order_by.startswith("group_id:"):
|
if order_by and order_by.startswith('group_id:'):
|
||||||
group_id = order_by.split(":", 1)[1]
|
group_id = order_by.split(':', 1)[1]
|
||||||
|
|
||||||
# Subquery that checks if the user belongs to the group
|
# Subquery that checks if the user belongs to the group
|
||||||
membership_exists = exists(
|
membership_exists = exists(
|
||||||
@@ -456,42 +432,42 @@ class UsersTable:
|
|||||||
# CASE: user in group → 1, user not in group → 0
|
# CASE: user in group → 1, user not in group → 0
|
||||||
group_sort = case((membership_exists, 1), else_=0)
|
group_sort = case((membership_exists, 1), else_=0)
|
||||||
|
|
||||||
if direction == "asc":
|
if direction == 'asc':
|
||||||
query = query.order_by(group_sort.asc(), User.name.asc())
|
query = query.order_by(group_sort.asc(), User.name.asc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(group_sort.desc(), User.name.asc())
|
query = query.order_by(group_sort.desc(), User.name.asc())
|
||||||
|
|
||||||
elif order_by == "name":
|
elif order_by == 'name':
|
||||||
if direction == "asc":
|
if direction == 'asc':
|
||||||
query = query.order_by(User.name.asc())
|
query = query.order_by(User.name.asc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(User.name.desc())
|
query = query.order_by(User.name.desc())
|
||||||
|
|
||||||
elif order_by == "email":
|
elif order_by == 'email':
|
||||||
if direction == "asc":
|
if direction == 'asc':
|
||||||
query = query.order_by(User.email.asc())
|
query = query.order_by(User.email.asc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(User.email.desc())
|
query = query.order_by(User.email.desc())
|
||||||
|
|
||||||
elif order_by == "created_at":
|
elif order_by == 'created_at':
|
||||||
if direction == "asc":
|
if direction == 'asc':
|
||||||
query = query.order_by(User.created_at.asc())
|
query = query.order_by(User.created_at.asc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(User.created_at.desc())
|
query = query.order_by(User.created_at.desc())
|
||||||
|
|
||||||
elif order_by == "last_active_at":
|
elif order_by == 'last_active_at':
|
||||||
if direction == "asc":
|
if direction == 'asc':
|
||||||
query = query.order_by(User.last_active_at.asc())
|
query = query.order_by(User.last_active_at.asc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(User.last_active_at.desc())
|
query = query.order_by(User.last_active_at.desc())
|
||||||
|
|
||||||
elif order_by == "updated_at":
|
elif order_by == 'updated_at':
|
||||||
if direction == "asc":
|
if direction == 'asc':
|
||||||
query = query.order_by(User.updated_at.asc())
|
query = query.order_by(User.updated_at.asc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(User.updated_at.desc())
|
query = query.order_by(User.updated_at.desc())
|
||||||
elif order_by == "role":
|
elif order_by == 'role':
|
||||||
if direction == "asc":
|
if direction == 'asc':
|
||||||
query = query.order_by(User.role.asc())
|
query = query.order_by(User.role.asc())
|
||||||
else:
|
else:
|
||||||
query = query.order_by(User.role.desc())
|
query = query.order_by(User.role.desc())
|
||||||
@@ -510,13 +486,11 @@ class UsersTable:
|
|||||||
|
|
||||||
users = query.all()
|
users = query.all()
|
||||||
return {
|
return {
|
||||||
"users": [UserModel.model_validate(user) for user in users],
|
'users': [UserModel.model_validate(user) for user in users],
|
||||||
"total": total,
|
'total': total,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_users_by_group_id(
|
def get_users_by_group_id(self, group_id: str, db: Optional[Session] = None) -> list[UserModel]:
|
||||||
self, group_id: str, db: Optional[Session] = None
|
|
||||||
) -> list[UserModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
users = (
|
users = (
|
||||||
db.query(User)
|
db.query(User)
|
||||||
@@ -527,16 +501,9 @@ class UsersTable:
|
|||||||
)
|
)
|
||||||
return [UserModel.model_validate(user) for user in users]
|
return [UserModel.model_validate(user) for user in users]
|
||||||
|
|
||||||
def get_users_by_user_ids(
|
def get_users_by_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> list[UserStatusModel]:
|
||||||
self, user_ids: list[str], db: Optional[Session] = None
|
|
||||||
) -> list[UserStatusModel]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
users = (
|
users = db.query(User).options(defer(User.profile_image_url)).filter(User.id.in_(user_ids)).all()
|
||||||
db.query(User)
|
|
||||||
.options(defer(User.profile_image_url))
|
|
||||||
.filter(User.id.in_(user_ids))
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
return [UserModel.model_validate(user) for user in users]
|
return [UserModel.model_validate(user) for user in users]
|
||||||
|
|
||||||
def get_num_users(self, db: Optional[Session] = None) -> Optional[int]:
|
def get_num_users(self, db: Optional[Session] = None) -> Optional[int]:
|
||||||
@@ -555,9 +522,7 @@ class UsersTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_user_webhook_url_by_id(
|
def get_user_webhook_url_by_id(self, id: str, db: Optional[Session] = None) -> Optional[str]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[str]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
user = db.query(User).filter_by(id=id).first()
|
user = db.query(User).filter_by(id=id).first()
|
||||||
@@ -565,11 +530,7 @@ class UsersTable:
|
|||||||
if user.settings is None:
|
if user.settings is None:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return (
|
return user.settings.get('ui', {}).get('notifications', {}).get('webhook_url', None)
|
||||||
user.settings.get("ui", {})
|
|
||||||
.get("notifications", {})
|
|
||||||
.get("webhook_url", None)
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -577,14 +538,10 @@ class UsersTable:
|
|||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
current_timestamp = int(datetime.datetime.now().timestamp())
|
current_timestamp = int(datetime.datetime.now().timestamp())
|
||||||
today_midnight_timestamp = current_timestamp - (current_timestamp % 86400)
|
today_midnight_timestamp = current_timestamp - (current_timestamp % 86400)
|
||||||
query = db.query(User).filter(
|
query = db.query(User).filter(User.last_active_at > today_midnight_timestamp)
|
||||||
User.last_active_at > today_midnight_timestamp
|
|
||||||
)
|
|
||||||
return query.count()
|
return query.count()
|
||||||
|
|
||||||
def update_user_role_by_id(
|
def update_user_role_by_id(self, id: str, role: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||||
self, id: str, role: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[UserModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
user = db.query(User).filter_by(id=id).first()
|
user = db.query(User).filter_by(id=id).first()
|
||||||
@@ -629,9 +586,7 @@ class UsersTable:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
|
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
|
||||||
def update_last_active_by_id(
|
def update_last_active_by_id(self, id: str, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[UserModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
user = db.query(User).filter_by(id=id).first()
|
user = db.query(User).filter_by(id=id).first()
|
||||||
@@ -665,10 +620,10 @@ class UsersTable:
|
|||||||
oauth = user.oauth or {}
|
oauth = user.oauth or {}
|
||||||
|
|
||||||
# Update or insert provider entry
|
# Update or insert provider entry
|
||||||
oauth[provider] = {"sub": sub}
|
oauth[provider] = {'sub': sub}
|
||||||
|
|
||||||
# Persist updated JSON
|
# Persist updated JSON
|
||||||
db.query(User).filter_by(id=id).update({"oauth": oauth})
|
db.query(User).filter_by(id=id).update({'oauth': oauth})
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
return UserModel.model_validate(user)
|
return UserModel.model_validate(user)
|
||||||
@@ -698,9 +653,9 @@ class UsersTable:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
scim = user.scim or {}
|
scim = user.scim or {}
|
||||||
scim[provider] = {"external_id": external_id}
|
scim[provider] = {'external_id': external_id}
|
||||||
|
|
||||||
db.query(User).filter_by(id=id).update({"scim": scim})
|
db.query(User).filter_by(id=id).update({'scim': scim})
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
return UserModel.model_validate(user)
|
return UserModel.model_validate(user)
|
||||||
@@ -708,9 +663,7 @@ class UsersTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_user_by_id(
|
def update_user_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||||
self, id: str, updated: dict, db: Optional[Session] = None
|
|
||||||
) -> Optional[UserModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
user = db.query(User).filter_by(id=id).first()
|
user = db.query(User).filter_by(id=id).first()
|
||||||
@@ -725,9 +678,7 @@ class UsersTable:
|
|||||||
print(e)
|
print(e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_user_settings_by_id(
|
def update_user_settings_by_id(self, id: str, updated: dict, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||||
self, id: str, updated: dict, db: Optional[Session] = None
|
|
||||||
) -> Optional[UserModel]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
user = db.query(User).filter_by(id=id).first()
|
user = db.query(User).filter_by(id=id).first()
|
||||||
@@ -741,7 +692,7 @@ class UsersTable:
|
|||||||
|
|
||||||
user_settings.update(updated)
|
user_settings.update(updated)
|
||||||
|
|
||||||
db.query(User).filter_by(id=id).update({"settings": user_settings})
|
db.query(User).filter_by(id=id).update({'settings': user_settings})
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
user = db.query(User).filter_by(id=id).first()
|
user = db.query(User).filter_by(id=id).first()
|
||||||
@@ -768,9 +719,7 @@ class UsersTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_user_api_key_by_id(
|
def get_user_api_key_by_id(self, id: str, db: Optional[Session] = None) -> Optional[str]:
|
||||||
self, id: str, db: Optional[Session] = None
|
|
||||||
) -> Optional[str]:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
api_key = db.query(ApiKey).filter_by(user_id=id).first()
|
api_key = db.query(ApiKey).filter_by(user_id=id).first()
|
||||||
@@ -778,9 +727,7 @@ class UsersTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_user_api_key_by_id(
|
def update_user_api_key_by_id(self, id: str, api_key: str, db: Optional[Session] = None) -> bool:
|
||||||
self, id: str, api_key: str, db: Optional[Session] = None
|
|
||||||
) -> bool:
|
|
||||||
try:
|
try:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
db.query(ApiKey).filter_by(user_id=id).delete()
|
db.query(ApiKey).filter_by(user_id=id).delete()
|
||||||
@@ -788,7 +735,7 @@ class UsersTable:
|
|||||||
|
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
new_api_key = ApiKey(
|
new_api_key = ApiKey(
|
||||||
id=f"key_{id}",
|
id=f'key_{id}',
|
||||||
user_id=id,
|
user_id=id,
|
||||||
key=api_key,
|
key=api_key,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
@@ -811,16 +758,14 @@ class UsersTable:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_valid_user_ids(
|
def get_valid_user_ids(self, user_ids: list[str], db: Optional[Session] = None) -> list[str]:
|
||||||
self, user_ids: list[str], db: Optional[Session] = None
|
|
||||||
) -> list[str]:
|
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
users = db.query(User).filter(User.id.in_(user_ids)).all()
|
users = db.query(User).filter(User.id.in_(user_ids)).all()
|
||||||
return [user.id for user in users]
|
return [user.id for user in users]
|
||||||
|
|
||||||
def get_super_admin_user(self, db: Optional[Session] = None) -> Optional[UserModel]:
|
def get_super_admin_user(self, db: Optional[Session] = None) -> Optional[UserModel]:
|
||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
user = db.query(User).filter_by(role="admin").first()
|
user = db.query(User).filter_by(role='admin').first()
|
||||||
if user:
|
if user:
|
||||||
return UserModel.model_validate(user)
|
return UserModel.model_validate(user)
|
||||||
else:
|
else:
|
||||||
@@ -830,9 +775,7 @@ class UsersTable:
|
|||||||
with get_db_context(db) as db:
|
with get_db_context(db) as db:
|
||||||
# Consider user active if last_active_at within the last 3 minutes
|
# Consider user active if last_active_at within the last 3 minutes
|
||||||
three_minutes_ago = int(time.time()) - 180
|
three_minutes_ago = int(time.time()) - 180
|
||||||
count = (
|
count = db.query(User).filter(User.last_active_at >= three_minutes_ago).count()
|
||||||
db.query(User).filter(User.last_active_at >= three_minutes_ago).count()
|
|
||||||
)
|
|
||||||
return count
|
return count
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -40,78 +40,76 @@ class DatalabMarkerLoader:
|
|||||||
self.output_format = output_format
|
self.output_format = output_format
|
||||||
|
|
||||||
def _get_mime_type(self, filename: str) -> str:
|
def _get_mime_type(self, filename: str) -> str:
|
||||||
ext = filename.rsplit(".", 1)[-1].lower()
|
ext = filename.rsplit('.', 1)[-1].lower()
|
||||||
mime_map = {
|
mime_map = {
|
||||||
"pdf": "application/pdf",
|
'pdf': 'application/pdf',
|
||||||
"xls": "application/vnd.ms-excel",
|
'xls': 'application/vnd.ms-excel',
|
||||||
"xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
'xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||||
"ods": "application/vnd.oasis.opendocument.spreadsheet",
|
'ods': 'application/vnd.oasis.opendocument.spreadsheet',
|
||||||
"doc": "application/msword",
|
'doc': 'application/msword',
|
||||||
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
'docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||||
"odt": "application/vnd.oasis.opendocument.text",
|
'odt': 'application/vnd.oasis.opendocument.text',
|
||||||
"ppt": "application/vnd.ms-powerpoint",
|
'ppt': 'application/vnd.ms-powerpoint',
|
||||||
"pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
'pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||||
"odp": "application/vnd.oasis.opendocument.presentation",
|
'odp': 'application/vnd.oasis.opendocument.presentation',
|
||||||
"html": "text/html",
|
'html': 'text/html',
|
||||||
"epub": "application/epub+zip",
|
'epub': 'application/epub+zip',
|
||||||
"png": "image/png",
|
'png': 'image/png',
|
||||||
"jpeg": "image/jpeg",
|
'jpeg': 'image/jpeg',
|
||||||
"jpg": "image/jpeg",
|
'jpg': 'image/jpeg',
|
||||||
"webp": "image/webp",
|
'webp': 'image/webp',
|
||||||
"gif": "image/gif",
|
'gif': 'image/gif',
|
||||||
"tiff": "image/tiff",
|
'tiff': 'image/tiff',
|
||||||
}
|
}
|
||||||
return mime_map.get(ext, "application/octet-stream")
|
return mime_map.get(ext, 'application/octet-stream')
|
||||||
|
|
||||||
def check_marker_request_status(self, request_id: str) -> dict:
|
def check_marker_request_status(self, request_id: str) -> dict:
|
||||||
url = f"{self.api_base_url}/{request_id}"
|
url = f'{self.api_base_url}/{request_id}'
|
||||||
headers = {"X-Api-Key": self.api_key}
|
headers = {'X-Api-Key': self.api_key}
|
||||||
try:
|
try:
|
||||||
response = requests.get(url, headers=headers)
|
response = requests.get(url, headers=headers)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result = response.json()
|
result = response.json()
|
||||||
log.info(f"Marker API status check for request {request_id}: {result}")
|
log.info(f'Marker API status check for request {request_id}: {result}')
|
||||||
return result
|
return result
|
||||||
except requests.HTTPError as e:
|
except requests.HTTPError as e:
|
||||||
log.error(f"Error checking Marker request status: {e}")
|
log.error(f'Error checking Marker request status: {e}')
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_502_BAD_GATEWAY,
|
status.HTTP_502_BAD_GATEWAY,
|
||||||
detail=f"Failed to check Marker request: {e}",
|
detail=f'Failed to check Marker request: {e}',
|
||||||
)
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
log.error(f"Invalid JSON checking Marker request: {e}")
|
log.error(f'Invalid JSON checking Marker request: {e}')
|
||||||
raise HTTPException(
|
raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Invalid JSON: {e}')
|
||||||
status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def load(self) -> List[Document]:
|
def load(self) -> List[Document]:
|
||||||
filename = os.path.basename(self.file_path)
|
filename = os.path.basename(self.file_path)
|
||||||
mime_type = self._get_mime_type(filename)
|
mime_type = self._get_mime_type(filename)
|
||||||
headers = {"X-Api-Key": self.api_key}
|
headers = {'X-Api-Key': self.api_key}
|
||||||
|
|
||||||
form_data = {
|
form_data = {
|
||||||
"use_llm": str(self.use_llm).lower(),
|
'use_llm': str(self.use_llm).lower(),
|
||||||
"skip_cache": str(self.skip_cache).lower(),
|
'skip_cache': str(self.skip_cache).lower(),
|
||||||
"force_ocr": str(self.force_ocr).lower(),
|
'force_ocr': str(self.force_ocr).lower(),
|
||||||
"paginate": str(self.paginate).lower(),
|
'paginate': str(self.paginate).lower(),
|
||||||
"strip_existing_ocr": str(self.strip_existing_ocr).lower(),
|
'strip_existing_ocr': str(self.strip_existing_ocr).lower(),
|
||||||
"disable_image_extraction": str(self.disable_image_extraction).lower(),
|
'disable_image_extraction': str(self.disable_image_extraction).lower(),
|
||||||
"format_lines": str(self.format_lines).lower(),
|
'format_lines': str(self.format_lines).lower(),
|
||||||
"output_format": self.output_format,
|
'output_format': self.output_format,
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.additional_config and self.additional_config.strip():
|
if self.additional_config and self.additional_config.strip():
|
||||||
form_data["additional_config"] = self.additional_config
|
form_data['additional_config'] = self.additional_config
|
||||||
|
|
||||||
log.info(
|
log.info(
|
||||||
f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}"
|
f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(self.file_path, "rb") as f:
|
with open(self.file_path, 'rb') as f:
|
||||||
files = {"file": (filename, f, mime_type)}
|
files = {'file': (filename, f, mime_type)}
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{self.api_base_url}",
|
f'{self.api_base_url}',
|
||||||
data=form_data,
|
data=form_data,
|
||||||
files=files,
|
files=files,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
@@ -119,29 +117,25 @@ class DatalabMarkerLoader:
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result = response.json()
|
result = response.json()
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise HTTPException(
|
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}')
|
||||||
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
|
|
||||||
)
|
|
||||||
except requests.HTTPError as e:
|
except requests.HTTPError as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_400_BAD_REQUEST,
|
status.HTTP_400_BAD_REQUEST,
|
||||||
detail=f"Datalab Marker request failed: {e}",
|
detail=f'Datalab Marker request failed: {e}',
|
||||||
)
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(
|
raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Invalid JSON response: {e}')
|
||||||
status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON response: {e}"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
|
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
|
||||||
|
|
||||||
if not result.get("success"):
|
if not result.get('success'):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_400_BAD_REQUEST,
|
status.HTTP_400_BAD_REQUEST,
|
||||||
detail=f"Datalab Marker request failed: {result.get('error', 'Unknown error')}",
|
detail=f'Datalab Marker request failed: {result.get("error", "Unknown error")}',
|
||||||
)
|
)
|
||||||
|
|
||||||
check_url = result.get("request_check_url")
|
check_url = result.get('request_check_url')
|
||||||
request_id = result.get("request_id")
|
request_id = result.get('request_id')
|
||||||
|
|
||||||
# Check if this is a direct response (self-hosted) or polling response (DataLab)
|
# Check if this is a direct response (self-hosted) or polling response (DataLab)
|
||||||
if check_url:
|
if check_url:
|
||||||
@@ -154,54 +148,45 @@ class DatalabMarkerLoader:
|
|||||||
poll_result = poll_response.json()
|
poll_result = poll_response.json()
|
||||||
except (requests.HTTPError, ValueError) as e:
|
except (requests.HTTPError, ValueError) as e:
|
||||||
raw_body = poll_response.text
|
raw_body = poll_response.text
|
||||||
log.error(f"Polling error: {e}, response body: {raw_body}")
|
log.error(f'Polling error: {e}, response body: {raw_body}')
|
||||||
raise HTTPException(
|
raise HTTPException(status.HTTP_502_BAD_GATEWAY, detail=f'Polling failed: {e}')
|
||||||
status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
status_val = poll_result.get("status")
|
status_val = poll_result.get('status')
|
||||||
success_val = poll_result.get("success")
|
success_val = poll_result.get('success')
|
||||||
|
|
||||||
if status_val == "complete":
|
if status_val == 'complete':
|
||||||
summary = {
|
summary = {
|
||||||
k: poll_result.get(k)
|
k: poll_result.get(k)
|
||||||
for k in (
|
for k in (
|
||||||
"status",
|
'status',
|
||||||
"output_format",
|
'output_format',
|
||||||
"success",
|
'success',
|
||||||
"error",
|
'error',
|
||||||
"page_count",
|
'page_count',
|
||||||
"total_cost",
|
'total_cost',
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
log.info(
|
log.info(f'Marker processing completed successfully: {json.dumps(summary, indent=2)}')
|
||||||
f"Marker processing completed successfully: {json.dumps(summary, indent=2)}"
|
|
||||||
)
|
|
||||||
break
|
break
|
||||||
|
|
||||||
if status_val == "failed" or success_val is False:
|
if status_val == 'failed' or success_val is False:
|
||||||
log.error(
|
log.error(f'Marker poll failed full response: {json.dumps(poll_result, indent=2)}')
|
||||||
f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}"
|
error_msg = poll_result.get('error') or 'Marker returned failure without error message'
|
||||||
)
|
|
||||||
error_msg = (
|
|
||||||
poll_result.get("error")
|
|
||||||
or "Marker returned failure without error message"
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_400_BAD_REQUEST,
|
status.HTTP_400_BAD_REQUEST,
|
||||||
detail=f"Marker processing failed: {error_msg}",
|
detail=f'Marker processing failed: {error_msg}',
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_504_GATEWAY_TIMEOUT,
|
status.HTTP_504_GATEWAY_TIMEOUT,
|
||||||
detail="Marker processing timed out",
|
detail='Marker processing timed out',
|
||||||
)
|
)
|
||||||
|
|
||||||
if not poll_result.get("success", False):
|
if not poll_result.get('success', False):
|
||||||
error_msg = poll_result.get("error") or "Unknown processing error"
|
error_msg = poll_result.get('error') or 'Unknown processing error'
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_400_BAD_REQUEST,
|
status.HTTP_400_BAD_REQUEST,
|
||||||
detail=f"Final processing failed: {error_msg}",
|
detail=f'Final processing failed: {error_msg}',
|
||||||
)
|
)
|
||||||
|
|
||||||
# DataLab format - content in format-specific fields
|
# DataLab format - content in format-specific fields
|
||||||
@@ -210,69 +195,65 @@ class DatalabMarkerLoader:
|
|||||||
final_result = poll_result
|
final_result = poll_result
|
||||||
else:
|
else:
|
||||||
# Self-hosted direct response - content in "output" field
|
# Self-hosted direct response - content in "output" field
|
||||||
if "output" in result:
|
if 'output' in result:
|
||||||
log.info("Self-hosted Marker returned direct response without polling")
|
log.info('Self-hosted Marker returned direct response without polling')
|
||||||
raw_content = result.get("output")
|
raw_content = result.get('output')
|
||||||
final_result = result
|
final_result = result
|
||||||
else:
|
else:
|
||||||
available_fields = (
|
available_fields = list(result.keys()) if isinstance(result, dict) else 'non-dict response'
|
||||||
list(result.keys())
|
|
||||||
if isinstance(result, dict)
|
|
||||||
else "non-dict response"
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_502_BAD_GATEWAY,
|
status.HTTP_502_BAD_GATEWAY,
|
||||||
detail=f"Custom Marker endpoint returned success but no 'output' field found. Available fields: {available_fields}. Expected either 'request_check_url' for polling or 'output' field for direct response.",
|
detail=f"Custom Marker endpoint returned success but no 'output' field found. Available fields: {available_fields}. Expected either 'request_check_url' for polling or 'output' field for direct response.",
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.output_format.lower() == "json":
|
if self.output_format.lower() == 'json':
|
||||||
full_text = json.dumps(raw_content, indent=2)
|
full_text = json.dumps(raw_content, indent=2)
|
||||||
elif self.output_format.lower() in {"markdown", "html"}:
|
elif self.output_format.lower() in {'markdown', 'html'}:
|
||||||
full_text = str(raw_content).strip()
|
full_text = str(raw_content).strip()
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_400_BAD_REQUEST,
|
status.HTTP_400_BAD_REQUEST,
|
||||||
detail=f"Unsupported output format: {self.output_format}",
|
detail=f'Unsupported output format: {self.output_format}',
|
||||||
)
|
)
|
||||||
|
|
||||||
if not full_text:
|
if not full_text:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_400_BAD_REQUEST,
|
status.HTTP_400_BAD_REQUEST,
|
||||||
detail="Marker returned empty content",
|
detail='Marker returned empty content',
|
||||||
)
|
)
|
||||||
|
|
||||||
marker_output_dir = os.path.join("/app/backend/data/uploads", "marker_output")
|
marker_output_dir = os.path.join('/app/backend/data/uploads', 'marker_output')
|
||||||
os.makedirs(marker_output_dir, exist_ok=True)
|
os.makedirs(marker_output_dir, exist_ok=True)
|
||||||
|
|
||||||
file_ext_map = {"markdown": "md", "json": "json", "html": "html"}
|
file_ext_map = {'markdown': 'md', 'json': 'json', 'html': 'html'}
|
||||||
file_ext = file_ext_map.get(self.output_format.lower(), "txt")
|
file_ext = file_ext_map.get(self.output_format.lower(), 'txt')
|
||||||
output_filename = f"{os.path.splitext(filename)[0]}.{file_ext}"
|
output_filename = f'{os.path.splitext(filename)[0]}.{file_ext}'
|
||||||
output_path = os.path.join(marker_output_dir, output_filename)
|
output_path = os.path.join(marker_output_dir, output_filename)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(output_path, "w", encoding="utf-8") as f:
|
with open(output_path, 'w', encoding='utf-8') as f:
|
||||||
f.write(full_text)
|
f.write(full_text)
|
||||||
log.info(f"Saved Marker output to: {output_path}")
|
log.info(f'Saved Marker output to: {output_path}')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.warning(f"Failed to write marker output to disk: {e}")
|
log.warning(f'Failed to write marker output to disk: {e}')
|
||||||
|
|
||||||
metadata = {
|
metadata = {
|
||||||
"source": filename,
|
'source': filename,
|
||||||
"output_format": final_result.get("output_format", self.output_format),
|
'output_format': final_result.get('output_format', self.output_format),
|
||||||
"page_count": final_result.get("page_count", 0),
|
'page_count': final_result.get('page_count', 0),
|
||||||
"processed_with_llm": self.use_llm,
|
'processed_with_llm': self.use_llm,
|
||||||
"request_id": request_id or "",
|
'request_id': request_id or '',
|
||||||
}
|
}
|
||||||
|
|
||||||
images = final_result.get("images", {})
|
images = final_result.get('images', {})
|
||||||
if images:
|
if images:
|
||||||
metadata["image_count"] = len(images)
|
metadata['image_count'] = len(images)
|
||||||
metadata["images"] = json.dumps(list(images.keys()))
|
metadata['images'] = json.dumps(list(images.keys()))
|
||||||
|
|
||||||
for k, v in metadata.items():
|
for k, v in metadata.items():
|
||||||
if isinstance(v, (dict, list)):
|
if isinstance(v, (dict, list)):
|
||||||
metadata[k] = json.dumps(v)
|
metadata[k] = json.dumps(v)
|
||||||
elif v is None:
|
elif v is None:
|
||||||
metadata[k] = ""
|
metadata[k] = ''
|
||||||
|
|
||||||
return [Document(page_content=full_text, metadata=metadata)]
|
return [Document(page_content=full_text, metadata=metadata)]
|
||||||
|
|||||||
@@ -29,18 +29,18 @@ class ExternalDocumentLoader(BaseLoader):
|
|||||||
self.user = user
|
self.user = user
|
||||||
|
|
||||||
def load(self) -> List[Document]:
|
def load(self) -> List[Document]:
|
||||||
with open(self.file_path, "rb") as f:
|
with open(self.file_path, 'rb') as f:
|
||||||
data = f.read()
|
data = f.read()
|
||||||
|
|
||||||
headers = {}
|
headers = {}
|
||||||
if self.mime_type is not None:
|
if self.mime_type is not None:
|
||||||
headers["Content-Type"] = self.mime_type
|
headers['Content-Type'] = self.mime_type
|
||||||
|
|
||||||
if self.api_key is not None:
|
if self.api_key is not None:
|
||||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
headers['Authorization'] = f'Bearer {self.api_key}'
|
||||||
|
|
||||||
try:
|
try:
|
||||||
headers["X-Filename"] = quote(os.path.basename(self.file_path))
|
headers['X-Filename'] = quote(os.path.basename(self.file_path))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -48,24 +48,23 @@ class ExternalDocumentLoader(BaseLoader):
|
|||||||
headers = include_user_info_headers(headers, self.user)
|
headers = include_user_info_headers(headers, self.user)
|
||||||
|
|
||||||
url = self.url
|
url = self.url
|
||||||
if url.endswith("/"):
|
if url.endswith('/'):
|
||||||
url = url[:-1]
|
url = url[:-1]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.put(f"{url}/process", data=data, headers=headers)
|
response = requests.put(f'{url}/process', data=data, headers=headers)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error connecting to endpoint: {e}")
|
log.error(f'Error connecting to endpoint: {e}')
|
||||||
raise Exception(f"Error connecting to endpoint: {e}")
|
raise Exception(f'Error connecting to endpoint: {e}')
|
||||||
|
|
||||||
if response.ok:
|
if response.ok:
|
||||||
|
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
if response_data:
|
if response_data:
|
||||||
if isinstance(response_data, dict):
|
if isinstance(response_data, dict):
|
||||||
return [
|
return [
|
||||||
Document(
|
Document(
|
||||||
page_content=response_data.get("page_content"),
|
page_content=response_data.get('page_content'),
|
||||||
metadata=response_data.get("metadata"),
|
metadata=response_data.get('metadata'),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
elif isinstance(response_data, list):
|
elif isinstance(response_data, list):
|
||||||
@@ -73,17 +72,15 @@ class ExternalDocumentLoader(BaseLoader):
|
|||||||
for document in response_data:
|
for document in response_data:
|
||||||
documents.append(
|
documents.append(
|
||||||
Document(
|
Document(
|
||||||
page_content=document.get("page_content"),
|
page_content=document.get('page_content'),
|
||||||
metadata=document.get("metadata"),
|
metadata=document.get('metadata'),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return documents
|
return documents
|
||||||
else:
|
else:
|
||||||
raise Exception("Error loading document: Unable to parse content")
|
raise Exception('Error loading document: Unable to parse content')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise Exception("Error loading document: No content returned")
|
raise Exception('Error loading document: No content returned')
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(f'Error loading document: {response.status_code} {response.text}')
|
||||||
f"Error loading document: {response.status_code} {response.text}"
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -30,22 +30,22 @@ class ExternalWebLoader(BaseLoader):
|
|||||||
response = requests.post(
|
response = requests.post(
|
||||||
self.external_url,
|
self.external_url,
|
||||||
headers={
|
headers={
|
||||||
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) External Web Loader",
|
'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) External Web Loader',
|
||||||
"Authorization": f"Bearer {self.external_api_key}",
|
'Authorization': f'Bearer {self.external_api_key}',
|
||||||
},
|
},
|
||||||
json={
|
json={
|
||||||
"urls": urls,
|
'urls': urls,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
results = response.json()
|
results = response.json()
|
||||||
for result in results:
|
for result in results:
|
||||||
yield Document(
|
yield Document(
|
||||||
page_content=result.get("page_content", ""),
|
page_content=result.get('page_content', ''),
|
||||||
metadata=result.get("metadata", {}),
|
metadata=result.get('metadata', {}),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if self.continue_on_failure:
|
if self.continue_on_failure:
|
||||||
log.error(f"Error extracting content from batch {urls}: {e}")
|
log.error(f'Error extracting content from batch {urls}: {e}')
|
||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
|
|||||||
@@ -30,59 +30,59 @@ logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
|||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
known_source_ext = [
|
known_source_ext = [
|
||||||
"go",
|
'go',
|
||||||
"py",
|
'py',
|
||||||
"java",
|
'java',
|
||||||
"sh",
|
'sh',
|
||||||
"bat",
|
'bat',
|
||||||
"ps1",
|
'ps1',
|
||||||
"cmd",
|
'cmd',
|
||||||
"js",
|
'js',
|
||||||
"ts",
|
'ts',
|
||||||
"css",
|
'css',
|
||||||
"cpp",
|
'cpp',
|
||||||
"hpp",
|
'hpp',
|
||||||
"h",
|
'h',
|
||||||
"c",
|
'c',
|
||||||
"cs",
|
'cs',
|
||||||
"sql",
|
'sql',
|
||||||
"log",
|
'log',
|
||||||
"ini",
|
'ini',
|
||||||
"pl",
|
'pl',
|
||||||
"pm",
|
'pm',
|
||||||
"r",
|
'r',
|
||||||
"dart",
|
'dart',
|
||||||
"dockerfile",
|
'dockerfile',
|
||||||
"env",
|
'env',
|
||||||
"php",
|
'php',
|
||||||
"hs",
|
'hs',
|
||||||
"hsc",
|
'hsc',
|
||||||
"lua",
|
'lua',
|
||||||
"nginxconf",
|
'nginxconf',
|
||||||
"conf",
|
'conf',
|
||||||
"m",
|
'm',
|
||||||
"mm",
|
'mm',
|
||||||
"plsql",
|
'plsql',
|
||||||
"perl",
|
'perl',
|
||||||
"rb",
|
'rb',
|
||||||
"rs",
|
'rs',
|
||||||
"db2",
|
'db2',
|
||||||
"scala",
|
'scala',
|
||||||
"bash",
|
'bash',
|
||||||
"swift",
|
'swift',
|
||||||
"vue",
|
'vue',
|
||||||
"svelte",
|
'svelte',
|
||||||
"ex",
|
'ex',
|
||||||
"exs",
|
'exs',
|
||||||
"erl",
|
'erl',
|
||||||
"tsx",
|
'tsx',
|
||||||
"jsx",
|
'jsx',
|
||||||
"hs",
|
'hs',
|
||||||
"lhs",
|
'lhs',
|
||||||
"json",
|
'json',
|
||||||
"yaml",
|
'yaml',
|
||||||
"yml",
|
'yml',
|
||||||
"toml",
|
'toml',
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -99,11 +99,11 @@ class ExcelLoader:
|
|||||||
xls = pd.ExcelFile(self.file_path)
|
xls = pd.ExcelFile(self.file_path)
|
||||||
for sheet_name in xls.sheet_names:
|
for sheet_name in xls.sheet_names:
|
||||||
df = pd.read_excel(xls, sheet_name=sheet_name)
|
df = pd.read_excel(xls, sheet_name=sheet_name)
|
||||||
text_parts.append(f"Sheet: {sheet_name}\n{df.to_string(index=False)}")
|
text_parts.append(f'Sheet: {sheet_name}\n{df.to_string(index=False)}')
|
||||||
return [
|
return [
|
||||||
Document(
|
Document(
|
||||||
page_content="\n\n".join(text_parts),
|
page_content='\n\n'.join(text_parts),
|
||||||
metadata={"source": self.file_path},
|
metadata={'source': self.file_path},
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -125,11 +125,11 @@ class PptxLoader:
|
|||||||
if shape.has_text_frame:
|
if shape.has_text_frame:
|
||||||
slide_texts.append(shape.text_frame.text)
|
slide_texts.append(shape.text_frame.text)
|
||||||
if slide_texts:
|
if slide_texts:
|
||||||
text_parts.append(f"Slide {i}:\n" + "\n".join(slide_texts))
|
text_parts.append(f'Slide {i}:\n' + '\n'.join(slide_texts))
|
||||||
return [
|
return [
|
||||||
Document(
|
Document(
|
||||||
page_content="\n\n".join(text_parts),
|
page_content='\n\n'.join(text_parts),
|
||||||
metadata={"source": self.file_path},
|
metadata={'source': self.file_path},
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -143,41 +143,41 @@ class TikaLoader:
|
|||||||
self.extract_images = extract_images
|
self.extract_images = extract_images
|
||||||
|
|
||||||
def load(self) -> list[Document]:
|
def load(self) -> list[Document]:
|
||||||
with open(self.file_path, "rb") as f:
|
with open(self.file_path, 'rb') as f:
|
||||||
data = f.read()
|
data = f.read()
|
||||||
|
|
||||||
if self.mime_type is not None:
|
if self.mime_type is not None:
|
||||||
headers = {"Content-Type": self.mime_type}
|
headers = {'Content-Type': self.mime_type}
|
||||||
else:
|
else:
|
||||||
headers = {}
|
headers = {}
|
||||||
|
|
||||||
if self.extract_images == True:
|
if self.extract_images == True:
|
||||||
headers["X-Tika-PDFextractInlineImages"] = "true"
|
headers['X-Tika-PDFextractInlineImages'] = 'true'
|
||||||
|
|
||||||
endpoint = self.url
|
endpoint = self.url
|
||||||
if not endpoint.endswith("/"):
|
if not endpoint.endswith('/'):
|
||||||
endpoint += "/"
|
endpoint += '/'
|
||||||
endpoint += "tika/text"
|
endpoint += 'tika/text'
|
||||||
|
|
||||||
r = requests.put(endpoint, data=data, headers=headers, verify=REQUESTS_VERIFY)
|
r = requests.put(endpoint, data=data, headers=headers, verify=REQUESTS_VERIFY)
|
||||||
|
|
||||||
if r.ok:
|
if r.ok:
|
||||||
raw_metadata = r.json()
|
raw_metadata = r.json()
|
||||||
text = raw_metadata.get("X-TIKA:content", "<No text content found>").strip()
|
text = raw_metadata.get('X-TIKA:content', '<No text content found>').strip()
|
||||||
|
|
||||||
if "Content-Type" in raw_metadata:
|
if 'Content-Type' in raw_metadata:
|
||||||
headers["Content-Type"] = raw_metadata["Content-Type"]
|
headers['Content-Type'] = raw_metadata['Content-Type']
|
||||||
|
|
||||||
log.debug("Tika extracted text: %s", text)
|
log.debug('Tika extracted text: %s', text)
|
||||||
|
|
||||||
return [Document(page_content=text, metadata=headers)]
|
return [Document(page_content=text, metadata=headers)]
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Error calling Tika: {r.reason}")
|
raise Exception(f'Error calling Tika: {r.reason}')
|
||||||
|
|
||||||
|
|
||||||
class DoclingLoader:
|
class DoclingLoader:
|
||||||
def __init__(self, url, api_key=None, file_path=None, mime_type=None, params=None):
|
def __init__(self, url, api_key=None, file_path=None, mime_type=None, params=None):
|
||||||
self.url = url.rstrip("/")
|
self.url = url.rstrip('/')
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.file_path = file_path
|
self.file_path = file_path
|
||||||
self.mime_type = mime_type
|
self.mime_type = mime_type
|
||||||
@@ -185,199 +185,183 @@ class DoclingLoader:
|
|||||||
self.params = params or {}
|
self.params = params or {}
|
||||||
|
|
||||||
def load(self) -> list[Document]:
|
def load(self) -> list[Document]:
|
||||||
with open(self.file_path, "rb") as f:
|
with open(self.file_path, 'rb') as f:
|
||||||
headers = {}
|
headers = {}
|
||||||
if self.api_key:
|
if self.api_key:
|
||||||
headers["X-Api-Key"] = f"{self.api_key}"
|
headers['X-Api-Key'] = f'{self.api_key}'
|
||||||
|
|
||||||
r = requests.post(
|
r = requests.post(
|
||||||
f"{self.url}/v1/convert/file",
|
f'{self.url}/v1/convert/file',
|
||||||
files={
|
files={
|
||||||
"files": (
|
'files': (
|
||||||
self.file_path,
|
self.file_path,
|
||||||
f,
|
f,
|
||||||
self.mime_type or "application/octet-stream",
|
self.mime_type or 'application/octet-stream',
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
data={
|
data={
|
||||||
"image_export_mode": "placeholder",
|
'image_export_mode': 'placeholder',
|
||||||
**self.params,
|
**self.params,
|
||||||
},
|
},
|
||||||
headers=headers,
|
headers=headers,
|
||||||
)
|
)
|
||||||
if r.ok:
|
if r.ok:
|
||||||
result = r.json()
|
result = r.json()
|
||||||
document_data = result.get("document", {})
|
document_data = result.get('document', {})
|
||||||
text = document_data.get("md_content", "<No text content found>")
|
text = document_data.get('md_content', '<No text content found>')
|
||||||
|
|
||||||
metadata = {"Content-Type": self.mime_type} if self.mime_type else {}
|
metadata = {'Content-Type': self.mime_type} if self.mime_type else {}
|
||||||
|
|
||||||
log.debug("Docling extracted text: %s", text)
|
log.debug('Docling extracted text: %s', text)
|
||||||
return [Document(page_content=text, metadata=metadata)]
|
return [Document(page_content=text, metadata=metadata)]
|
||||||
else:
|
else:
|
||||||
error_msg = f"Error calling Docling API: {r.reason}"
|
error_msg = f'Error calling Docling API: {r.reason}'
|
||||||
if r.text:
|
if r.text:
|
||||||
try:
|
try:
|
||||||
error_data = r.json()
|
error_data = r.json()
|
||||||
if "detail" in error_data:
|
if 'detail' in error_data:
|
||||||
error_msg += f" - {error_data['detail']}"
|
error_msg += f' - {error_data["detail"]}'
|
||||||
except Exception:
|
except Exception:
|
||||||
error_msg += f" - {r.text}"
|
error_msg += f' - {r.text}'
|
||||||
raise Exception(f"Error calling Docling: {error_msg}")
|
raise Exception(f'Error calling Docling: {error_msg}')
|
||||||
|
|
||||||
|
|
||||||
class Loader:
|
class Loader:
|
||||||
def __init__(self, engine: str = "", **kwargs):
|
def __init__(self, engine: str = '', **kwargs):
|
||||||
self.engine = engine
|
self.engine = engine
|
||||||
self.user = kwargs.get("user", None)
|
self.user = kwargs.get('user', None)
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
|
||||||
def load(
|
def load(self, filename: str, file_content_type: str, file_path: str) -> list[Document]:
|
||||||
self, filename: str, file_content_type: str, file_path: str
|
|
||||||
) -> list[Document]:
|
|
||||||
loader = self._get_loader(filename, file_content_type, file_path)
|
loader = self._get_loader(filename, file_content_type, file_path)
|
||||||
docs = loader.load()
|
docs = loader.load()
|
||||||
|
|
||||||
return [
|
return [Document(page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata) for doc in docs]
|
||||||
Document(
|
|
||||||
page_content=ftfy.fix_text(doc.page_content), metadata=doc.metadata
|
|
||||||
)
|
|
||||||
for doc in docs
|
|
||||||
]
|
|
||||||
|
|
||||||
def _is_text_file(self, file_ext: str, file_content_type: str) -> bool:
|
def _is_text_file(self, file_ext: str, file_content_type: str) -> bool:
|
||||||
return file_ext in known_source_ext or (
|
return file_ext in known_source_ext or (
|
||||||
file_content_type
|
file_content_type
|
||||||
and file_content_type.find("text/") >= 0
|
and file_content_type.find('text/') >= 0
|
||||||
# Avoid text/html files being detected as text
|
# Avoid text/html files being detected as text
|
||||||
and not file_content_type.find("html") >= 0
|
and not file_content_type.find('html') >= 0
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_loader(self, filename: str, file_content_type: str, file_path: str):
|
def _get_loader(self, filename: str, file_content_type: str, file_path: str):
|
||||||
file_ext = filename.split(".")[-1].lower()
|
file_ext = filename.split('.')[-1].lower()
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.engine == "external"
|
self.engine == 'external'
|
||||||
and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL")
|
and self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_URL')
|
||||||
and self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY")
|
and self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_API_KEY')
|
||||||
):
|
):
|
||||||
loader = ExternalDocumentLoader(
|
loader = ExternalDocumentLoader(
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
url=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_URL"),
|
url=self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_URL'),
|
||||||
api_key=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY"),
|
api_key=self.kwargs.get('EXTERNAL_DOCUMENT_LOADER_API_KEY'),
|
||||||
mime_type=file_content_type,
|
mime_type=file_content_type,
|
||||||
user=self.user,
|
user=self.user,
|
||||||
)
|
)
|
||||||
elif self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"):
|
elif self.engine == 'tika' and self.kwargs.get('TIKA_SERVER_URL'):
|
||||||
if self._is_text_file(file_ext, file_content_type):
|
if self._is_text_file(file_ext, file_content_type):
|
||||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||||
else:
|
else:
|
||||||
loader = TikaLoader(
|
loader = TikaLoader(
|
||||||
url=self.kwargs.get("TIKA_SERVER_URL"),
|
url=self.kwargs.get('TIKA_SERVER_URL'),
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"),
|
extract_images=self.kwargs.get('PDF_EXTRACT_IMAGES'),
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
self.engine == "datalab_marker"
|
self.engine == 'datalab_marker'
|
||||||
and self.kwargs.get("DATALAB_MARKER_API_KEY")
|
and self.kwargs.get('DATALAB_MARKER_API_KEY')
|
||||||
and file_ext
|
and file_ext
|
||||||
in [
|
in [
|
||||||
"pdf",
|
'pdf',
|
||||||
"xls",
|
'xls',
|
||||||
"xlsx",
|
'xlsx',
|
||||||
"ods",
|
'ods',
|
||||||
"doc",
|
'doc',
|
||||||
"docx",
|
'docx',
|
||||||
"odt",
|
'odt',
|
||||||
"ppt",
|
'ppt',
|
||||||
"pptx",
|
'pptx',
|
||||||
"odp",
|
'odp',
|
||||||
"html",
|
'html',
|
||||||
"epub",
|
'epub',
|
||||||
"png",
|
'png',
|
||||||
"jpeg",
|
'jpeg',
|
||||||
"jpg",
|
'jpg',
|
||||||
"webp",
|
'webp',
|
||||||
"gif",
|
'gif',
|
||||||
"tiff",
|
'tiff',
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
api_base_url = self.kwargs.get("DATALAB_MARKER_API_BASE_URL", "")
|
api_base_url = self.kwargs.get('DATALAB_MARKER_API_BASE_URL', '')
|
||||||
if not api_base_url or api_base_url.strip() == "":
|
if not api_base_url or api_base_url.strip() == '':
|
||||||
api_base_url = "https://www.datalab.to/api/v1/marker" # https://github.com/open-webui/open-webui/pull/16867#issuecomment-3218424349
|
api_base_url = 'https://www.datalab.to/api/v1/marker' # https://github.com/open-webui/open-webui/pull/16867#issuecomment-3218424349
|
||||||
|
|
||||||
loader = DatalabMarkerLoader(
|
loader = DatalabMarkerLoader(
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
api_key=self.kwargs["DATALAB_MARKER_API_KEY"],
|
api_key=self.kwargs['DATALAB_MARKER_API_KEY'],
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
additional_config=self.kwargs.get("DATALAB_MARKER_ADDITIONAL_CONFIG"),
|
additional_config=self.kwargs.get('DATALAB_MARKER_ADDITIONAL_CONFIG'),
|
||||||
use_llm=self.kwargs.get("DATALAB_MARKER_USE_LLM", False),
|
use_llm=self.kwargs.get('DATALAB_MARKER_USE_LLM', False),
|
||||||
skip_cache=self.kwargs.get("DATALAB_MARKER_SKIP_CACHE", False),
|
skip_cache=self.kwargs.get('DATALAB_MARKER_SKIP_CACHE', False),
|
||||||
force_ocr=self.kwargs.get("DATALAB_MARKER_FORCE_OCR", False),
|
force_ocr=self.kwargs.get('DATALAB_MARKER_FORCE_OCR', False),
|
||||||
paginate=self.kwargs.get("DATALAB_MARKER_PAGINATE", False),
|
paginate=self.kwargs.get('DATALAB_MARKER_PAGINATE', False),
|
||||||
strip_existing_ocr=self.kwargs.get(
|
strip_existing_ocr=self.kwargs.get('DATALAB_MARKER_STRIP_EXISTING_OCR', False),
|
||||||
"DATALAB_MARKER_STRIP_EXISTING_OCR", False
|
disable_image_extraction=self.kwargs.get('DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION', False),
|
||||||
),
|
format_lines=self.kwargs.get('DATALAB_MARKER_FORMAT_LINES', False),
|
||||||
disable_image_extraction=self.kwargs.get(
|
output_format=self.kwargs.get('DATALAB_MARKER_OUTPUT_FORMAT', 'markdown'),
|
||||||
"DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", False
|
|
||||||
),
|
|
||||||
format_lines=self.kwargs.get("DATALAB_MARKER_FORMAT_LINES", False),
|
|
||||||
output_format=self.kwargs.get(
|
|
||||||
"DATALAB_MARKER_OUTPUT_FORMAT", "markdown"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"):
|
elif self.engine == 'docling' and self.kwargs.get('DOCLING_SERVER_URL'):
|
||||||
if self._is_text_file(file_ext, file_content_type):
|
if self._is_text_file(file_ext, file_content_type):
|
||||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||||
else:
|
else:
|
||||||
# Build params for DoclingLoader
|
# Build params for DoclingLoader
|
||||||
params = self.kwargs.get("DOCLING_PARAMS", {})
|
params = self.kwargs.get('DOCLING_PARAMS', {})
|
||||||
if not isinstance(params, dict):
|
if not isinstance(params, dict):
|
||||||
try:
|
try:
|
||||||
params = json.loads(params)
|
params = json.loads(params)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
log.error("Invalid DOCLING_PARAMS format, expected JSON object")
|
log.error('Invalid DOCLING_PARAMS format, expected JSON object')
|
||||||
params = {}
|
params = {}
|
||||||
|
|
||||||
loader = DoclingLoader(
|
loader = DoclingLoader(
|
||||||
url=self.kwargs.get("DOCLING_SERVER_URL"),
|
url=self.kwargs.get('DOCLING_SERVER_URL'),
|
||||||
api_key=self.kwargs.get("DOCLING_API_KEY", None),
|
api_key=self.kwargs.get('DOCLING_API_KEY', None),
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
mime_type=file_content_type,
|
mime_type=file_content_type,
|
||||||
params=params,
|
params=params,
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
self.engine == "document_intelligence"
|
self.engine == 'document_intelligence'
|
||||||
and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
|
and self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT') != ''
|
||||||
and (
|
and (
|
||||||
file_ext in ["pdf", "docx", "ppt", "pptx"]
|
file_ext in ['pdf', 'docx', 'ppt', 'pptx']
|
||||||
or file_content_type
|
or file_content_type
|
||||||
in [
|
in [
|
||||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||||
"application/vnd.ms-powerpoint",
|
'application/vnd.ms-powerpoint',
|
||||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
if self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != "":
|
if self.kwargs.get('DOCUMENT_INTELLIGENCE_KEY') != '':
|
||||||
loader = AzureAIDocumentIntelligenceLoader(
|
loader = AzureAIDocumentIntelligenceLoader(
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
|
api_endpoint=self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT'),
|
||||||
api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
|
api_key=self.kwargs.get('DOCUMENT_INTELLIGENCE_KEY'),
|
||||||
api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"),
|
api_model=self.kwargs.get('DOCUMENT_INTELLIGENCE_MODEL'),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
loader = AzureAIDocumentIntelligenceLoader(
|
loader = AzureAIDocumentIntelligenceLoader(
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
|
api_endpoint=self.kwargs.get('DOCUMENT_INTELLIGENCE_ENDPOINT'),
|
||||||
azure_credential=DefaultAzureCredential(),
|
azure_credential=DefaultAzureCredential(),
|
||||||
api_model=self.kwargs.get("DOCUMENT_INTELLIGENCE_MODEL"),
|
api_model=self.kwargs.get('DOCUMENT_INTELLIGENCE_MODEL'),
|
||||||
)
|
)
|
||||||
elif self.engine == "mineru" and file_ext in [
|
elif self.engine == 'mineru' and file_ext in ['pdf']: # MinerU currently only supports PDF
|
||||||
"pdf"
|
mineru_timeout = self.kwargs.get('MINERU_API_TIMEOUT', 300)
|
||||||
]: # MinerU currently only supports PDF
|
|
||||||
|
|
||||||
mineru_timeout = self.kwargs.get("MINERU_API_TIMEOUT", 300)
|
|
||||||
if mineru_timeout:
|
if mineru_timeout:
|
||||||
try:
|
try:
|
||||||
mineru_timeout = int(mineru_timeout)
|
mineru_timeout = int(mineru_timeout)
|
||||||
@@ -386,111 +370,115 @@ class Loader:
|
|||||||
|
|
||||||
loader = MinerULoader(
|
loader = MinerULoader(
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
api_mode=self.kwargs.get("MINERU_API_MODE", "local"),
|
api_mode=self.kwargs.get('MINERU_API_MODE', 'local'),
|
||||||
api_url=self.kwargs.get("MINERU_API_URL", "http://localhost:8000"),
|
api_url=self.kwargs.get('MINERU_API_URL', 'http://localhost:8000'),
|
||||||
api_key=self.kwargs.get("MINERU_API_KEY", ""),
|
api_key=self.kwargs.get('MINERU_API_KEY', ''),
|
||||||
params=self.kwargs.get("MINERU_PARAMS", {}),
|
params=self.kwargs.get('MINERU_PARAMS', {}),
|
||||||
timeout=mineru_timeout,
|
timeout=mineru_timeout,
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
self.engine == "mistral_ocr"
|
self.engine == 'mistral_ocr'
|
||||||
and self.kwargs.get("MISTRAL_OCR_API_KEY") != ""
|
and self.kwargs.get('MISTRAL_OCR_API_KEY') != ''
|
||||||
and file_ext
|
and file_ext in ['pdf'] # Mistral OCR currently only supports PDF and images
|
||||||
in ["pdf"] # Mistral OCR currently only supports PDF and images
|
|
||||||
):
|
):
|
||||||
loader = MistralLoader(
|
loader = MistralLoader(
|
||||||
base_url=self.kwargs.get("MISTRAL_OCR_API_BASE_URL"),
|
base_url=self.kwargs.get('MISTRAL_OCR_API_BASE_URL'),
|
||||||
api_key=self.kwargs.get("MISTRAL_OCR_API_KEY"),
|
api_key=self.kwargs.get('MISTRAL_OCR_API_KEY'),
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if file_ext == "pdf":
|
if file_ext == 'pdf':
|
||||||
loader = PyPDFLoader(
|
loader = PyPDFLoader(
|
||||||
file_path,
|
file_path,
|
||||||
extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"),
|
extract_images=self.kwargs.get('PDF_EXTRACT_IMAGES'),
|
||||||
mode=self.kwargs.get("PDF_LOADER_MODE", "page"),
|
mode=self.kwargs.get('PDF_LOADER_MODE', 'page'),
|
||||||
)
|
)
|
||||||
elif file_ext == "csv":
|
elif file_ext == 'csv':
|
||||||
loader = CSVLoader(file_path, autodetect_encoding=True)
|
loader = CSVLoader(file_path, autodetect_encoding=True)
|
||||||
elif file_ext == "rst":
|
elif file_ext == 'rst':
|
||||||
try:
|
try:
|
||||||
from langchain_community.document_loaders import UnstructuredRSTLoader
|
from langchain_community.document_loaders import UnstructuredRSTLoader
|
||||||
loader = UnstructuredRSTLoader(file_path, mode="elements")
|
|
||||||
|
loader = UnstructuredRSTLoader(file_path, mode='elements')
|
||||||
except ImportError:
|
except ImportError:
|
||||||
log.warning(
|
log.warning(
|
||||||
"The 'unstructured' package is not installed. "
|
"The 'unstructured' package is not installed. "
|
||||||
"Falling back to plain text loading for .rst file. "
|
'Falling back to plain text loading for .rst file. '
|
||||||
"Install it with: pip install unstructured"
|
'Install it with: pip install unstructured'
|
||||||
)
|
)
|
||||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||||
elif file_ext == "xml":
|
elif file_ext == 'xml':
|
||||||
try:
|
try:
|
||||||
from langchain_community.document_loaders import UnstructuredXMLLoader
|
from langchain_community.document_loaders import UnstructuredXMLLoader
|
||||||
|
|
||||||
loader = UnstructuredXMLLoader(file_path)
|
loader = UnstructuredXMLLoader(file_path)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
log.warning(
|
log.warning(
|
||||||
"The 'unstructured' package is not installed. "
|
"The 'unstructured' package is not installed. "
|
||||||
"Falling back to plain text loading for .xml file. "
|
'Falling back to plain text loading for .xml file. '
|
||||||
"Install it with: pip install unstructured"
|
'Install it with: pip install unstructured'
|
||||||
)
|
)
|
||||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||||
elif file_ext in ["htm", "html"]:
|
elif file_ext in ['htm', 'html']:
|
||||||
loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
|
loader = BSHTMLLoader(file_path, open_encoding='unicode_escape')
|
||||||
elif file_ext == "md":
|
elif file_ext == 'md':
|
||||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||||
elif file_content_type == "application/epub+zip":
|
elif file_content_type == 'application/epub+zip':
|
||||||
try:
|
try:
|
||||||
from langchain_community.document_loaders import UnstructuredEPubLoader
|
from langchain_community.document_loaders import UnstructuredEPubLoader
|
||||||
|
|
||||||
loader = UnstructuredEPubLoader(file_path)
|
loader = UnstructuredEPubLoader(file_path)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Processing .epub files requires the 'unstructured' package. "
|
"Processing .epub files requires the 'unstructured' package. "
|
||||||
"Install it with: pip install unstructured"
|
'Install it with: pip install unstructured'
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
file_content_type
|
file_content_type == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
|
||||||
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
or file_ext == 'docx'
|
||||||
or file_ext == "docx"
|
|
||||||
):
|
):
|
||||||
loader = Docx2txtLoader(file_path)
|
loader = Docx2txtLoader(file_path)
|
||||||
elif file_content_type in [
|
elif file_content_type in [
|
||||||
"application/vnd.ms-excel",
|
'application/vnd.ms-excel',
|
||||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||||
] or file_ext in ["xls", "xlsx"]:
|
] or file_ext in ['xls', 'xlsx']:
|
||||||
try:
|
try:
|
||||||
from langchain_community.document_loaders import UnstructuredExcelLoader
|
from langchain_community.document_loaders import UnstructuredExcelLoader
|
||||||
|
|
||||||
loader = UnstructuredExcelLoader(file_path)
|
loader = UnstructuredExcelLoader(file_path)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
log.warning(
|
log.warning(
|
||||||
"The 'unstructured' package is not installed. "
|
"The 'unstructured' package is not installed. "
|
||||||
"Falling back to pandas for Excel file loading. "
|
'Falling back to pandas for Excel file loading. '
|
||||||
"Install unstructured for better results: pip install unstructured"
|
'Install unstructured for better results: pip install unstructured'
|
||||||
)
|
)
|
||||||
loader = ExcelLoader(file_path)
|
loader = ExcelLoader(file_path)
|
||||||
elif file_content_type in [
|
elif file_content_type in [
|
||||||
"application/vnd.ms-powerpoint",
|
'application/vnd.ms-powerpoint',
|
||||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||||
] or file_ext in ["ppt", "pptx"]:
|
] or file_ext in ['ppt', 'pptx']:
|
||||||
try:
|
try:
|
||||||
from langchain_community.document_loaders import UnstructuredPowerPointLoader
|
from langchain_community.document_loaders import UnstructuredPowerPointLoader
|
||||||
|
|
||||||
loader = UnstructuredPowerPointLoader(file_path)
|
loader = UnstructuredPowerPointLoader(file_path)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
log.warning(
|
log.warning(
|
||||||
"The 'unstructured' package is not installed. "
|
"The 'unstructured' package is not installed. "
|
||||||
"Falling back to python-pptx for PowerPoint file loading. "
|
'Falling back to python-pptx for PowerPoint file loading. '
|
||||||
"Install unstructured for better results: pip install unstructured"
|
'Install unstructured for better results: pip install unstructured'
|
||||||
)
|
)
|
||||||
loader = PptxLoader(file_path)
|
loader = PptxLoader(file_path)
|
||||||
elif file_ext == "msg":
|
elif file_ext == 'msg':
|
||||||
loader = OutlookMessageLoader(file_path)
|
loader = OutlookMessageLoader(file_path)
|
||||||
elif file_ext == "odt":
|
elif file_ext == 'odt':
|
||||||
try:
|
try:
|
||||||
from langchain_community.document_loaders import UnstructuredODTLoader
|
from langchain_community.document_loaders import UnstructuredODTLoader
|
||||||
|
|
||||||
loader = UnstructuredODTLoader(file_path)
|
loader = UnstructuredODTLoader(file_path)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Processing .odt files requires the 'unstructured' package. "
|
"Processing .odt files requires the 'unstructured' package. "
|
||||||
"Install it with: pip install unstructured"
|
'Install it with: pip install unstructured'
|
||||||
)
|
)
|
||||||
elif self._is_text_file(file_ext, file_content_type):
|
elif self._is_text_file(file_ext, file_content_type):
|
||||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||||
@@ -498,4 +486,3 @@ class Loader:
|
|||||||
loader = TextLoader(file_path, autodetect_encoding=True)
|
loader = TextLoader(file_path, autodetect_encoding=True)
|
||||||
|
|
||||||
return loader
|
return loader
|
||||||
|
|
||||||
|
|||||||
@@ -22,37 +22,35 @@ class MinerULoader:
|
|||||||
def __init__(
    self,
    file_path: str,
    api_mode: str = 'local',
    api_url: str = 'http://localhost:8000',
    api_key: str = '',
    params: dict = None,
    timeout: Optional[int] = 300,
):
    """Configure a MinerU loader for a single file.

    Args:
        file_path: Path of the document to hand to MinerU.
        api_mode: 'local' or 'cloud'; compared case-insensitively.
        api_url: Base URL of the MinerU API. Trailing slashes are removed.
        api_key: API key; mandatory when ``api_mode`` is 'cloud'.
        params: Optional extra MinerU parameters. The mapping is copied, so
            the caller's object is never modified. ``None`` or an empty
            mapping means "no extra parameters".
        timeout: Stored on the instance as ``self.timeout``.

    Raises:
        ValueError: If ``api_mode`` is neither 'local' nor 'cloud', or if
            cloud mode is selected without an API key.
    """
    self.file_path = file_path
    self.api_mode = api_mode.lower()
    self.api_url = api_url.rstrip('/')
    self.api_key = api_key
    self.timeout = timeout

    # Keep a private copy: 'page_ranges' is popped below, and popping from
    # the caller's dict would silently strip it from a shared config object.
    self.params = dict(params) if params else {}

    # Read defaults from the copy rather than from `params`, so the declared
    # default of params=None does not raise AttributeError.
    self.enable_ocr = self.params.get('enable_ocr', False)
    self.enable_formula = self.params.get('enable_formula', True)
    self.enable_table = self.params.get('enable_table', True)
    self.language = self.params.get('language', 'en')
    self.model_version = self.params.get('model_version', 'pipeline')

    # page_ranges is kept as its own attribute and removed from self.params.
    self.page_ranges = self.params.pop('page_ranges', '')

    # Validate API mode
    if self.api_mode not in ['local', 'cloud']:
        raise ValueError(f"Invalid API mode: {self.api_mode}. Must be 'local' or 'cloud'")

    # Validate Cloud API requirements
    if self.api_mode == 'cloud' and not self.api_key:
        raise ValueError('API key is required for Cloud API mode')
|
||||||
|
|
||||||
def load(self) -> List[Document]:
|
def load(self) -> List[Document]:
|
||||||
"""
|
"""
|
||||||
@@ -60,12 +58,12 @@ class MinerULoader:
|
|||||||
Routes to Cloud or Local API based on api_mode.
|
Routes to Cloud or Local API based on api_mode.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if self.api_mode == "cloud":
|
if self.api_mode == 'cloud':
|
||||||
return self._load_cloud_api()
|
return self._load_cloud_api()
|
||||||
else:
|
else:
|
||||||
return self._load_local_api()
|
return self._load_local_api()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error loading document with MinerU: {e}")
|
log.error(f'Error loading document with MinerU: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _load_local_api(self) -> List[Document]:
|
def _load_local_api(self) -> List[Document]:
|
||||||
@@ -73,14 +71,14 @@ class MinerULoader:
|
|||||||
Load document using Local API (synchronous).
|
Load document using Local API (synchronous).
|
||||||
Posts file to /file_parse endpoint and gets immediate response.
|
Posts file to /file_parse endpoint and gets immediate response.
|
||||||
"""
|
"""
|
||||||
log.info(f"Using MinerU Local API at {self.api_url}")
|
log.info(f'Using MinerU Local API at {self.api_url}')
|
||||||
|
|
||||||
filename = os.path.basename(self.file_path)
|
filename = os.path.basename(self.file_path)
|
||||||
|
|
||||||
# Build form data for Local API
|
# Build form data for Local API
|
||||||
form_data = {
|
form_data = {
|
||||||
**self.params,
|
**self.params,
|
||||||
"return_md": "true",
|
'return_md': 'true',
|
||||||
}
|
}
|
||||||
|
|
||||||
# Page ranges (Local API uses start_page_id and end_page_id)
|
# Page ranges (Local API uses start_page_id and end_page_id)
|
||||||
@@ -89,18 +87,18 @@ class MinerULoader:
|
|||||||
# Full page range parsing would require parsing the string
|
# Full page range parsing would require parsing the string
|
||||||
log.warning(
|
log.warning(
|
||||||
f"Page ranges '{self.page_ranges}' specified but Local API uses different format. "
|
f"Page ranges '{self.page_ranges}' specified but Local API uses different format. "
|
||||||
"Consider using start_page_id/end_page_id parameters if needed."
|
'Consider using start_page_id/end_page_id parameters if needed.'
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(self.file_path, "rb") as f:
|
with open(self.file_path, 'rb') as f:
|
||||||
files = {"files": (filename, f, "application/octet-stream")}
|
files = {'files': (filename, f, 'application/octet-stream')}
|
||||||
|
|
||||||
log.info(f"Sending file to MinerU Local API: {filename}")
|
log.info(f'Sending file to MinerU Local API: {filename}')
|
||||||
log.debug(f"Local API parameters: {form_data}")
|
log.debug(f'Local API parameters: {form_data}')
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{self.api_url}/file_parse",
|
f'{self.api_url}/file_parse',
|
||||||
data=form_data,
|
data=form_data,
|
||||||
files=files,
|
files=files,
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
@@ -108,27 +106,25 @@ class MinerULoader:
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise HTTPException(
|
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}')
|
||||||
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
|
|
||||||
)
|
|
||||||
except requests.Timeout:
|
except requests.Timeout:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_504_GATEWAY_TIMEOUT,
|
status.HTTP_504_GATEWAY_TIMEOUT,
|
||||||
detail="MinerU Local API request timed out",
|
detail='MinerU Local API request timed out',
|
||||||
)
|
)
|
||||||
except requests.HTTPError as e:
|
except requests.HTTPError as e:
|
||||||
error_detail = f"MinerU Local API request failed: {e}"
|
error_detail = f'MinerU Local API request failed: {e}'
|
||||||
if e.response is not None:
|
if e.response is not None:
|
||||||
try:
|
try:
|
||||||
error_data = e.response.json()
|
error_data = e.response.json()
|
||||||
error_detail += f" - {error_data}"
|
error_detail += f' - {error_data}'
|
||||||
except Exception:
|
except Exception:
|
||||||
error_detail += f" - {e.response.text}"
|
error_detail += f' - {e.response.text}'
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=f"Error calling MinerU Local API: {str(e)}",
|
detail=f'Error calling MinerU Local API: {str(e)}',
|
||||||
)
|
)
|
||||||
|
|
||||||
# Parse response
|
# Parse response
|
||||||
@@ -137,41 +133,41 @@ class MinerULoader:
|
|||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_502_BAD_GATEWAY,
|
status.HTTP_502_BAD_GATEWAY,
|
||||||
detail=f"Invalid JSON response from MinerU Local API: {e}",
|
detail=f'Invalid JSON response from MinerU Local API: {e}',
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract markdown content from response
|
# Extract markdown content from response
|
||||||
if "results" not in result:
|
if 'results' not in result:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_502_BAD_GATEWAY,
|
status.HTTP_502_BAD_GATEWAY,
|
||||||
detail="MinerU Local API response missing 'results' field",
|
detail="MinerU Local API response missing 'results' field",
|
||||||
)
|
)
|
||||||
|
|
||||||
results = result["results"]
|
results = result['results']
|
||||||
if not results:
|
if not results:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_400_BAD_REQUEST,
|
status.HTTP_400_BAD_REQUEST,
|
||||||
detail="MinerU returned empty results",
|
detail='MinerU returned empty results',
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the first (and typically only) result
|
# Get the first (and typically only) result
|
||||||
file_result = list(results.values())[0]
|
file_result = list(results.values())[0]
|
||||||
markdown_content = file_result.get("md_content", "")
|
markdown_content = file_result.get('md_content', '')
|
||||||
|
|
||||||
if not markdown_content:
|
if not markdown_content:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_400_BAD_REQUEST,
|
status.HTTP_400_BAD_REQUEST,
|
||||||
detail="MinerU returned empty markdown content",
|
detail='MinerU returned empty markdown content',
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info(f"Successfully parsed document with MinerU Local API: {filename}")
|
log.info(f'Successfully parsed document with MinerU Local API: {filename}')
|
||||||
|
|
||||||
# Create metadata
|
# Create metadata
|
||||||
metadata = {
|
metadata = {
|
||||||
"source": filename,
|
'source': filename,
|
||||||
"api_mode": "local",
|
'api_mode': 'local',
|
||||||
"backend": result.get("backend", "unknown"),
|
'backend': result.get('backend', 'unknown'),
|
||||||
"version": result.get("version", "unknown"),
|
'version': result.get('version', 'unknown'),
|
||||||
}
|
}
|
||||||
|
|
||||||
return [Document(page_content=markdown_content, metadata=metadata)]
|
return [Document(page_content=markdown_content, metadata=metadata)]
|
||||||
@@ -181,7 +177,7 @@ class MinerULoader:
|
|||||||
Load document using Cloud API (asynchronous).
|
Load document using Cloud API (asynchronous).
|
||||||
Uses batch upload endpoint to avoid need for public file URLs.
|
Uses batch upload endpoint to avoid need for public file URLs.
|
||||||
"""
|
"""
|
||||||
log.info(f"Using MinerU Cloud API at {self.api_url}")
|
log.info(f'Using MinerU Cloud API at {self.api_url}')
|
||||||
|
|
||||||
filename = os.path.basename(self.file_path)
|
filename = os.path.basename(self.file_path)
|
||||||
|
|
||||||
@@ -195,17 +191,15 @@ class MinerULoader:
|
|||||||
result = self._poll_batch_status(batch_id, filename)
|
result = self._poll_batch_status(batch_id, filename)
|
||||||
|
|
||||||
# Step 4: Download and extract markdown from ZIP
|
# Step 4: Download and extract markdown from ZIP
|
||||||
markdown_content = self._download_and_extract_zip(
|
markdown_content = self._download_and_extract_zip(result['full_zip_url'], filename)
|
||||||
result["full_zip_url"], filename
|
|
||||||
)
|
|
||||||
|
|
||||||
log.info(f"Successfully parsed document with MinerU Cloud API: {filename}")
|
log.info(f'Successfully parsed document with MinerU Cloud API: {filename}')
|
||||||
|
|
||||||
# Create metadata
|
# Create metadata
|
||||||
metadata = {
|
metadata = {
|
||||||
"source": filename,
|
'source': filename,
|
||||||
"api_mode": "cloud",
|
'api_mode': 'cloud',
|
||||||
"batch_id": batch_id,
|
'batch_id': batch_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
return [Document(page_content=markdown_content, metadata=metadata)]
|
return [Document(page_content=markdown_content, metadata=metadata)]
|
||||||
@@ -216,49 +210,49 @@ class MinerULoader:
|
|||||||
Returns (batch_id, upload_url).
|
Returns (batch_id, upload_url).
|
||||||
"""
|
"""
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
'Authorization': f'Bearer {self.api_key}',
|
||||||
"Content-Type": "application/json",
|
'Content-Type': 'application/json',
|
||||||
}
|
}
|
||||||
|
|
||||||
# Build request body
|
# Build request body
|
||||||
request_body = {
|
request_body = {
|
||||||
**self.params,
|
**self.params,
|
||||||
"files": [
|
'files': [
|
||||||
{
|
{
|
||||||
"name": filename,
|
'name': filename,
|
||||||
"is_ocr": self.enable_ocr,
|
'is_ocr': self.enable_ocr,
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add page ranges if specified
|
# Add page ranges if specified
|
||||||
if self.page_ranges:
|
if self.page_ranges:
|
||||||
request_body["files"][0]["page_ranges"] = self.page_ranges
|
request_body['files'][0]['page_ranges'] = self.page_ranges
|
||||||
|
|
||||||
log.info(f"Requesting upload URL for: {filename}")
|
log.info(f'Requesting upload URL for: {filename}')
|
||||||
log.debug(f"Cloud API request body: {request_body}")
|
log.debug(f'Cloud API request body: {request_body}')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{self.api_url}/file-urls/batch",
|
f'{self.api_url}/file-urls/batch',
|
||||||
headers=headers,
|
headers=headers,
|
||||||
json=request_body,
|
json=request_body,
|
||||||
timeout=30,
|
timeout=30,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
except requests.HTTPError as e:
|
except requests.HTTPError as e:
|
||||||
error_detail = f"Failed to request upload URL: {e}"
|
error_detail = f'Failed to request upload URL: {e}'
|
||||||
if e.response is not None:
|
if e.response is not None:
|
||||||
try:
|
try:
|
||||||
error_data = e.response.json()
|
error_data = e.response.json()
|
||||||
error_detail += f" - {error_data.get('msg', error_data)}"
|
error_detail += f' - {error_data.get("msg", error_data)}'
|
||||||
except Exception:
|
except Exception:
|
||||||
error_detail += f" - {e.response.text}"
|
error_detail += f' - {e.response.text}'
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=f"Error requesting upload URL: {str(e)}",
|
detail=f'Error requesting upload URL: {str(e)}',
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -266,28 +260,28 @@ class MinerULoader:
|
|||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_502_BAD_GATEWAY,
|
status.HTTP_502_BAD_GATEWAY,
|
||||||
detail=f"Invalid JSON response: {e}",
|
detail=f'Invalid JSON response: {e}',
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check for API error response
|
# Check for API error response
|
||||||
if result.get("code") != 0:
|
if result.get('code') != 0:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_400_BAD_REQUEST,
|
status.HTTP_400_BAD_REQUEST,
|
||||||
detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}",
|
detail=f'MinerU Cloud API error: {result.get("msg", "Unknown error")}',
|
||||||
)
|
)
|
||||||
|
|
||||||
data = result.get("data", {})
|
data = result.get('data', {})
|
||||||
batch_id = data.get("batch_id")
|
batch_id = data.get('batch_id')
|
||||||
file_urls = data.get("file_urls", [])
|
file_urls = data.get('file_urls', [])
|
||||||
|
|
||||||
if not batch_id or not file_urls:
|
if not batch_id or not file_urls:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_502_BAD_GATEWAY,
|
status.HTTP_502_BAD_GATEWAY,
|
||||||
detail="MinerU Cloud API response missing batch_id or file_urls",
|
detail='MinerU Cloud API response missing batch_id or file_urls',
|
||||||
)
|
)
|
||||||
|
|
||||||
upload_url = file_urls[0]
|
upload_url = file_urls[0]
|
||||||
log.info(f"Received upload URL for batch: {batch_id}")
|
log.info(f'Received upload URL for batch: {batch_id}')
|
||||||
|
|
||||||
return batch_id, upload_url
|
return batch_id, upload_url
|
||||||
|
|
||||||
@@ -295,10 +289,10 @@ class MinerULoader:
|
|||||||
"""
|
"""
|
||||||
Upload file to presigned URL (no authentication needed).
|
Upload file to presigned URL (no authentication needed).
|
||||||
"""
|
"""
|
||||||
log.info(f"Uploading file to presigned URL")
|
log.info(f'Uploading file to presigned URL')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(self.file_path, "rb") as f:
|
with open(self.file_path, 'rb') as f:
|
||||||
response = requests.put(
|
response = requests.put(
|
||||||
upload_url,
|
upload_url,
|
||||||
data=f,
|
data=f,
|
||||||
@@ -306,26 +300,24 @@ class MinerULoader:
|
|||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
raise HTTPException(
|
raise HTTPException(status.HTTP_404_NOT_FOUND, detail=f'File not found: {self.file_path}')
|
||||||
status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
|
|
||||||
)
|
|
||||||
except requests.Timeout:
|
except requests.Timeout:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_504_GATEWAY_TIMEOUT,
|
status.HTTP_504_GATEWAY_TIMEOUT,
|
||||||
detail="File upload to presigned URL timed out",
|
detail='File upload to presigned URL timed out',
|
||||||
)
|
)
|
||||||
except requests.HTTPError as e:
|
except requests.HTTPError as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_400_BAD_REQUEST,
|
status.HTTP_400_BAD_REQUEST,
|
||||||
detail=f"Failed to upload file to presigned URL: {e}",
|
detail=f'Failed to upload file to presigned URL: {e}',
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=f"Error uploading file: {str(e)}",
|
detail=f'Error uploading file: {str(e)}',
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info("File uploaded successfully")
|
log.info('File uploaded successfully')
|
||||||
|
|
||||||
def _poll_batch_status(self, batch_id: str, filename: str) -> dict:
|
def _poll_batch_status(self, batch_id: str, filename: str) -> dict:
|
||||||
"""
|
"""
|
||||||
@@ -333,35 +325,35 @@ class MinerULoader:
|
|||||||
Returns the result dict for the file.
|
Returns the result dict for the file.
|
||||||
"""
|
"""
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
'Authorization': f'Bearer {self.api_key}',
|
||||||
}
|
}
|
||||||
|
|
||||||
max_iterations = 300 # 10 minutes max (2 seconds per iteration)
|
max_iterations = 300 # 10 minutes max (2 seconds per iteration)
|
||||||
poll_interval = 2 # seconds
|
poll_interval = 2 # seconds
|
||||||
|
|
||||||
log.info(f"Polling batch status: {batch_id}")
|
log.info(f'Polling batch status: {batch_id}')
|
||||||
|
|
||||||
for iteration in range(max_iterations):
|
for iteration in range(max_iterations):
|
||||||
try:
|
try:
|
||||||
response = requests.get(
|
response = requests.get(
|
||||||
f"{self.api_url}/extract-results/batch/{batch_id}",
|
f'{self.api_url}/extract-results/batch/{batch_id}',
|
||||||
headers=headers,
|
headers=headers,
|
||||||
timeout=30,
|
timeout=30,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
except requests.HTTPError as e:
|
except requests.HTTPError as e:
|
||||||
error_detail = f"Failed to poll batch status: {e}"
|
error_detail = f'Failed to poll batch status: {e}'
|
||||||
if e.response is not None:
|
if e.response is not None:
|
||||||
try:
|
try:
|
||||||
error_data = e.response.json()
|
error_data = e.response.json()
|
||||||
error_detail += f" - {error_data.get('msg', error_data)}"
|
error_detail += f' - {error_data.get("msg", error_data)}'
|
||||||
except Exception:
|
except Exception:
|
||||||
error_detail += f" - {e.response.text}"
|
error_detail += f' - {e.response.text}'
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=error_detail)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=f"Error polling batch status: {str(e)}",
|
detail=f'Error polling batch status: {str(e)}',
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -369,58 +361,56 @@ class MinerULoader:
|
|||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_502_BAD_GATEWAY,
|
status.HTTP_502_BAD_GATEWAY,
|
||||||
detail=f"Invalid JSON response while polling: {e}",
|
detail=f'Invalid JSON response while polling: {e}',
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check for API error response
|
# Check for API error response
|
||||||
if result.get("code") != 0:
|
if result.get('code') != 0:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_400_BAD_REQUEST,
|
status.HTTP_400_BAD_REQUEST,
|
||||||
detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}",
|
detail=f'MinerU Cloud API error: {result.get("msg", "Unknown error")}',
|
||||||
)
|
)
|
||||||
|
|
||||||
data = result.get("data", {})
|
data = result.get('data', {})
|
||||||
extract_result = data.get("extract_result", [])
|
extract_result = data.get('extract_result', [])
|
||||||
|
|
||||||
# Find our file in the batch results
|
# Find our file in the batch results
|
||||||
file_result = None
|
file_result = None
|
||||||
for item in extract_result:
|
for item in extract_result:
|
||||||
if item.get("file_name") == filename:
|
if item.get('file_name') == filename:
|
||||||
file_result = item
|
file_result = item
|
||||||
break
|
break
|
||||||
|
|
||||||
if not file_result:
|
if not file_result:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_502_BAD_GATEWAY,
|
status.HTTP_502_BAD_GATEWAY,
|
||||||
detail=f"File {filename} not found in batch results",
|
detail=f'File {filename} not found in batch results',
|
||||||
)
|
)
|
||||||
|
|
||||||
state = file_result.get("state")
|
state = file_result.get('state')
|
||||||
|
|
||||||
if state == "done":
|
if state == 'done':
|
||||||
log.info(f"Processing complete for {filename}")
|
log.info(f'Processing complete for {filename}')
|
||||||
return file_result
|
return file_result
|
||||||
elif state == "failed":
|
elif state == 'failed':
|
||||||
error_msg = file_result.get("err_msg", "Unknown error")
|
error_msg = file_result.get('err_msg', 'Unknown error')
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_400_BAD_REQUEST,
|
status.HTTP_400_BAD_REQUEST,
|
||||||
detail=f"MinerU processing failed: {error_msg}",
|
detail=f'MinerU processing failed: {error_msg}',
|
||||||
)
|
)
|
||||||
elif state in ["waiting-file", "pending", "running", "converting"]:
|
elif state in ['waiting-file', 'pending', 'running', 'converting']:
|
||||||
# Still processing
|
# Still processing
|
||||||
if iteration % 10 == 0: # Log every 20 seconds
|
if iteration % 10 == 0: # Log every 20 seconds
|
||||||
log.info(
|
log.info(f'Processing status: {state} (iteration {iteration + 1}/{max_iterations})')
|
||||||
f"Processing status: {state} (iteration {iteration + 1}/{max_iterations})"
|
|
||||||
)
|
|
||||||
time.sleep(poll_interval)
|
time.sleep(poll_interval)
|
||||||
else:
|
else:
|
||||||
log.warning(f"Unknown state: {state}")
|
log.warning(f'Unknown state: {state}')
|
||||||
time.sleep(poll_interval)
|
time.sleep(poll_interval)
|
||||||
|
|
||||||
# Timeout
|
# Timeout
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_504_GATEWAY_TIMEOUT,
|
status.HTTP_504_GATEWAY_TIMEOUT,
|
||||||
detail="MinerU processing timed out after 10 minutes",
|
detail='MinerU processing timed out after 10 minutes',
|
||||||
)
|
)
|
||||||
|
|
||||||
def _download_and_extract_zip(self, zip_url: str, filename: str) -> str:
|
def _download_and_extract_zip(self, zip_url: str, filename: str) -> str:
|
||||||
@@ -428,7 +418,7 @@ class MinerULoader:
|
|||||||
Download ZIP file from CDN and extract markdown content.
|
Download ZIP file from CDN and extract markdown content.
|
||||||
Returns the markdown content as a string.
|
Returns the markdown content as a string.
|
||||||
"""
|
"""
|
||||||
log.info(f"Downloading results from: {zip_url}")
|
log.info(f'Downloading results from: {zip_url}')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.get(zip_url, timeout=60)
|
response = requests.get(zip_url, timeout=60)
|
||||||
@@ -436,23 +426,23 @@ class MinerULoader:
|
|||||||
except requests.HTTPError as e:
|
except requests.HTTPError as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_400_BAD_REQUEST,
|
status.HTTP_400_BAD_REQUEST,
|
||||||
detail=f"Failed to download results ZIP: {e}",
|
detail=f'Failed to download results ZIP: {e}',
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=f"Error downloading results: {str(e)}",
|
detail=f'Error downloading results: {str(e)}',
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save ZIP to temporary file and extract
|
# Save ZIP to temporary file and extract
|
||||||
try:
|
try:
|
||||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
|
with tempfile.NamedTemporaryFile(delete=False, suffix='.zip') as tmp_zip:
|
||||||
tmp_zip.write(response.content)
|
tmp_zip.write(response.content)
|
||||||
tmp_zip_path = tmp_zip.name
|
tmp_zip_path = tmp_zip.name
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
# Extract ZIP
|
# Extract ZIP
|
||||||
with zipfile.ZipFile(tmp_zip_path, "r") as zip_ref:
|
with zipfile.ZipFile(tmp_zip_path, 'r') as zip_ref:
|
||||||
zip_ref.extractall(tmp_dir)
|
zip_ref.extractall(tmp_dir)
|
||||||
|
|
||||||
# Find markdown file - search recursively for any .md file
|
# Find markdown file - search recursively for any .md file
|
||||||
@@ -466,33 +456,27 @@ class MinerULoader:
|
|||||||
full_path = os.path.join(root, file)
|
full_path = os.path.join(root, file)
|
||||||
all_files.append(full_path)
|
all_files.append(full_path)
|
||||||
# Look for any .md file
|
# Look for any .md file
|
||||||
if file.endswith(".md"):
|
if file.endswith('.md'):
|
||||||
found_md_path = full_path
|
found_md_path = full_path
|
||||||
log.info(f"Found markdown file at: {full_path}")
|
log.info(f'Found markdown file at: {full_path}')
|
||||||
try:
|
try:
|
||||||
with open(full_path, "r", encoding="utf-8") as f:
|
with open(full_path, 'r', encoding='utf-8') as f:
|
||||||
markdown_content = f.read()
|
markdown_content = f.read()
|
||||||
if (
|
if markdown_content: # Use the first non-empty markdown file
|
||||||
markdown_content
|
|
||||||
): # Use the first non-empty markdown file
|
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.warning(f"Failed to read {full_path}: {e}")
|
log.warning(f'Failed to read {full_path}: {e}')
|
||||||
if markdown_content:
|
if markdown_content:
|
||||||
break
|
break
|
||||||
|
|
||||||
if markdown_content is None:
|
if markdown_content is None:
|
||||||
log.error(f"Available files in ZIP: {all_files}")
|
log.error(f'Available files in ZIP: {all_files}')
|
||||||
# Try to provide more helpful error message
|
# Try to provide more helpful error message
|
||||||
md_files = [f for f in all_files if f.endswith(".md")]
|
md_files = [f for f in all_files if f.endswith('.md')]
|
||||||
if md_files:
|
if md_files:
|
||||||
error_msg = (
|
error_msg = f"Found .md files but couldn't read them: {md_files}"
|
||||||
f"Found .md files but couldn't read them: {md_files}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
error_msg = (
|
error_msg = f'No .md files found in ZIP. Available files: {all_files}'
|
||||||
f"No .md files found in ZIP. Available files: {all_files}"
|
|
||||||
)
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_502_BAD_GATEWAY,
|
status.HTTP_502_BAD_GATEWAY,
|
||||||
detail=error_msg,
|
detail=error_msg,
|
||||||
@@ -504,21 +488,19 @@ class MinerULoader:
|
|||||||
except zipfile.BadZipFile as e:
|
except zipfile.BadZipFile as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_502_BAD_GATEWAY,
|
status.HTTP_502_BAD_GATEWAY,
|
||||||
detail=f"Invalid ZIP file received: {e}",
|
detail=f'Invalid ZIP file received: {e}',
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=f"Error extracting ZIP: {str(e)}",
|
detail=f'Error extracting ZIP: {str(e)}',
|
||||||
)
|
)
|
||||||
|
|
||||||
if not markdown_content:
|
if not markdown_content:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_400_BAD_REQUEST,
|
status.HTTP_400_BAD_REQUEST,
|
||||||
detail="Extracted markdown content is empty",
|
detail='Extracted markdown content is empty',
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info(
|
log.info(f'Successfully extracted markdown content ({len(markdown_content)} characters)')
|
||||||
f"Successfully extracted markdown content ({len(markdown_content)} characters)"
|
|
||||||
)
|
|
||||||
return markdown_content
|
return markdown_content
|
||||||
|
|||||||
@@ -49,13 +49,11 @@ class MistralLoader:
|
|||||||
enable_debug_logging: Enable detailed debug logs.
|
enable_debug_logging: Enable detailed debug logs.
|
||||||
"""
|
"""
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise ValueError("API key cannot be empty.")
|
raise ValueError('API key cannot be empty.')
|
||||||
if not os.path.exists(file_path):
|
if not os.path.exists(file_path):
|
||||||
raise FileNotFoundError(f"File not found at {file_path}")
|
raise FileNotFoundError(f'File not found at {file_path}')
|
||||||
|
|
||||||
self.base_url = (
|
self.base_url = base_url.rstrip('/') if base_url else 'https://api.mistral.ai/v1'
|
||||||
base_url.rstrip("/") if base_url else "https://api.mistral.ai/v1"
|
|
||||||
)
|
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.file_path = file_path
|
self.file_path = file_path
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
@@ -65,18 +63,10 @@ class MistralLoader:
|
|||||||
# PERFORMANCE OPTIMIZATION: Differentiated timeouts for different operations
|
# PERFORMANCE OPTIMIZATION: Differentiated timeouts for different operations
|
||||||
# This prevents long-running OCR operations from affecting quick operations
|
# This prevents long-running OCR operations from affecting quick operations
|
||||||
# and improves user experience by failing fast on operations that should be quick
|
# and improves user experience by failing fast on operations that should be quick
|
||||||
self.upload_timeout = min(
|
self.upload_timeout = min(timeout, 120) # Cap upload at 2 minutes - prevents hanging on large files
|
||||||
timeout, 120
|
self.url_timeout = 30 # URL requests should be fast - fail quickly if API is slow
|
||||||
) # Cap upload at 2 minutes - prevents hanging on large files
|
self.ocr_timeout = timeout # OCR can take the full timeout - this is the heavy operation
|
||||||
self.url_timeout = (
|
self.cleanup_timeout = 30 # Cleanup should be quick - don't hang on file deletion
|
||||||
30 # URL requests should be fast - fail quickly if API is slow
|
|
||||||
)
|
|
||||||
self.ocr_timeout = (
|
|
||||||
timeout # OCR can take the full timeout - this is the heavy operation
|
|
||||||
)
|
|
||||||
self.cleanup_timeout = (
|
|
||||||
30 # Cleanup should be quick - don't hang on file deletion
|
|
||||||
)
|
|
||||||
|
|
||||||
# PERFORMANCE OPTIMIZATION: Pre-compute file info to avoid repeated filesystem calls
|
# PERFORMANCE OPTIMIZATION: Pre-compute file info to avoid repeated filesystem calls
|
||||||
# This avoids multiple os.path.basename() and os.path.getsize() calls during processing
|
# This avoids multiple os.path.basename() and os.path.getsize() calls during processing
|
||||||
@@ -85,8 +75,8 @@ class MistralLoader:
|
|||||||
|
|
||||||
# ENHANCEMENT: Added User-Agent for better API tracking and debugging
|
# ENHANCEMENT: Added User-Agent for better API tracking and debugging
|
||||||
self.headers = {
|
self.headers = {
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
'Authorization': f'Bearer {self.api_key}',
|
||||||
"User-Agent": "OpenWebUI-MistralLoader/2.0", # Helps API provider track usage
|
'User-Agent': 'OpenWebUI-MistralLoader/2.0', # Helps API provider track usage
|
||||||
}
|
}
|
||||||
|
|
||||||
def _debug_log(self, message: str, *args) -> None:
|
def _debug_log(self, message: str, *args) -> None:
|
||||||
@@ -108,43 +98,39 @@ class MistralLoader:
|
|||||||
return {} # Return empty dict if no content
|
return {} # Return empty dict if no content
|
||||||
return response.json()
|
return response.json()
|
||||||
except requests.exceptions.HTTPError as http_err:
|
except requests.exceptions.HTTPError as http_err:
|
||||||
log.error(f"HTTP error occurred: {http_err} - Response: {response.text}")
|
log.error(f'HTTP error occurred: {http_err} - Response: {response.text}')
|
||||||
raise
|
raise
|
||||||
except requests.exceptions.RequestException as req_err:
|
except requests.exceptions.RequestException as req_err:
|
||||||
log.error(f"Request exception occurred: {req_err}")
|
log.error(f'Request exception occurred: {req_err}')
|
||||||
raise
|
raise
|
||||||
except ValueError as json_err: # Includes JSONDecodeError
|
except ValueError as json_err: # Includes JSONDecodeError
|
||||||
log.error(f"JSON decode error: {json_err} - Response: {response.text}")
|
log.error(f'JSON decode error: {json_err} - Response: {response.text}')
|
||||||
raise # Re-raise after logging
|
raise # Re-raise after logging
|
||||||
|
|
||||||
async def _handle_response_async(
|
async def _handle_response_async(self, response: aiohttp.ClientResponse) -> Dict[str, Any]:
|
||||||
self, response: aiohttp.ClientResponse
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Async version of response handling with better error info."""
|
"""Async version of response handling with better error info."""
|
||||||
try:
|
try:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
# Check content type
|
# Check content type
|
||||||
content_type = response.headers.get("content-type", "")
|
content_type = response.headers.get('content-type', '')
|
||||||
if "application/json" not in content_type:
|
if 'application/json' not in content_type:
|
||||||
if response.status == 204:
|
if response.status == 204:
|
||||||
return {}
|
return {}
|
||||||
text = await response.text()
|
text = await response.text()
|
||||||
raise ValueError(
|
raise ValueError(f'Unexpected content type: {content_type}, body: {text[:200]}...')
|
||||||
f"Unexpected content type: {content_type}, body: {text[:200]}..."
|
|
||||||
)
|
|
||||||
|
|
||||||
return await response.json()
|
return await response.json()
|
||||||
|
|
||||||
except aiohttp.ClientResponseError as e:
|
except aiohttp.ClientResponseError as e:
|
||||||
error_text = await response.text() if response else "No response"
|
error_text = await response.text() if response else 'No response'
|
||||||
log.error(f"HTTP {e.status}: {e.message} - Response: {error_text[:500]}")
|
log.error(f'HTTP {e.status}: {e.message} - Response: {error_text[:500]}')
|
||||||
raise
|
raise
|
||||||
except aiohttp.ClientError as e:
|
except aiohttp.ClientError as e:
|
||||||
log.error(f"Client error: {e}")
|
log.error(f'Client error: {e}')
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Unexpected error processing response: {e}")
|
log.error(f'Unexpected error processing response: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _is_retryable_error(self, error: Exception) -> bool:
|
def _is_retryable_error(self, error: Exception) -> bool:
|
||||||
@@ -172,13 +158,11 @@ class MistralLoader:
|
|||||||
return True # Timeouts might resolve on retry
|
return True # Timeouts might resolve on retry
|
||||||
if isinstance(error, requests.exceptions.HTTPError):
|
if isinstance(error, requests.exceptions.HTTPError):
|
||||||
# Only retry on server errors (5xx) or rate limits (429)
|
# Only retry on server errors (5xx) or rate limits (429)
|
||||||
if hasattr(error, "response") and error.response is not None:
|
if hasattr(error, 'response') and error.response is not None:
|
||||||
status_code = error.response.status_code
|
status_code = error.response.status_code
|
||||||
return status_code >= 500 or status_code == 429
|
return status_code >= 500 or status_code == 429
|
||||||
return False
|
return False
|
||||||
if isinstance(
|
if isinstance(error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)):
|
||||||
error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError)
|
|
||||||
):
|
|
||||||
return True # Async network/timeout errors are retryable
|
return True # Async network/timeout errors are retryable
|
||||||
if isinstance(error, aiohttp.ClientResponseError):
|
if isinstance(error, aiohttp.ClientResponseError):
|
||||||
return error.status >= 500 or error.status == 429
|
return error.status >= 500 or error.status == 429
|
||||||
@@ -204,8 +188,7 @@ class MistralLoader:
|
|||||||
# Prevents overwhelming the server while ensuring reasonable retry delays
|
# Prevents overwhelming the server while ensuring reasonable retry delays
|
||||||
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
|
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
|
||||||
log.warning(
|
log.warning(
|
||||||
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
|
f'Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...'
|
||||||
f"Retrying in {wait_time}s..."
|
|
||||||
)
|
)
|
||||||
time.sleep(wait_time)
|
time.sleep(wait_time)
|
||||||
|
|
||||||
@@ -226,8 +209,7 @@ class MistralLoader:
|
|||||||
# PERFORMANCE OPTIMIZATION: Non-blocking exponential backoff
|
# PERFORMANCE OPTIMIZATION: Non-blocking exponential backoff
|
||||||
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
|
wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds
|
||||||
log.warning(
|
log.warning(
|
||||||
f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. "
|
f'Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. Retrying in {wait_time}s...'
|
||||||
f"Retrying in {wait_time}s..."
|
|
||||||
)
|
)
|
||||||
await asyncio.sleep(wait_time) # Non-blocking wait
|
await asyncio.sleep(wait_time) # Non-blocking wait
|
||||||
|
|
||||||
@@ -240,15 +222,15 @@ class MistralLoader:
|
|||||||
Although streaming is not enabled for this endpoint, the file is opened
|
Although streaming is not enabled for this endpoint, the file is opened
|
||||||
in a context manager to minimize memory usage duration.
|
in a context manager to minimize memory usage duration.
|
||||||
"""
|
"""
|
||||||
log.info("Uploading file to Mistral API")
|
log.info('Uploading file to Mistral API')
|
||||||
url = f"{self.base_url}/files"
|
url = f'{self.base_url}/files'
|
||||||
|
|
||||||
def upload_request():
|
def upload_request():
|
||||||
# MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime
|
# MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime
|
||||||
# This ensures the file is closed immediately after reading, reducing memory usage
|
# This ensures the file is closed immediately after reading, reducing memory usage
|
||||||
with open(self.file_path, "rb") as f:
|
with open(self.file_path, 'rb') as f:
|
||||||
files = {"file": (self.file_name, f, "application/pdf")}
|
files = {'file': (self.file_name, f, 'application/pdf')}
|
||||||
data = {"purpose": "ocr"}
|
data = {'purpose': 'ocr'}
|
||||||
|
|
||||||
# NOTE: stream=False is required for this endpoint
|
# NOTE: stream=False is required for this endpoint
|
||||||
# The Mistral API doesn't support chunked uploads for this endpoint
|
# The Mistral API doesn't support chunked uploads for this endpoint
|
||||||
@@ -265,42 +247,38 @@ class MistralLoader:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
response_data = self._retry_request_sync(upload_request)
|
response_data = self._retry_request_sync(upload_request)
|
||||||
file_id = response_data.get("id")
|
file_id = response_data.get('id')
|
||||||
if not file_id:
|
if not file_id:
|
||||||
raise ValueError("File ID not found in upload response.")
|
raise ValueError('File ID not found in upload response.')
|
||||||
log.info(f"File uploaded successfully. File ID: {file_id}")
|
log.info(f'File uploaded successfully. File ID: {file_id}')
|
||||||
return file_id
|
return file_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Failed to upload file: {e}")
|
log.error(f'Failed to upload file: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
|
async def _upload_file_async(self, session: aiohttp.ClientSession) -> str:
|
||||||
"""Async file upload with streaming for better memory efficiency."""
|
"""Async file upload with streaming for better memory efficiency."""
|
||||||
url = f"{self.base_url}/files"
|
url = f'{self.base_url}/files'
|
||||||
|
|
||||||
async def upload_request():
|
async def upload_request():
|
||||||
# Create multipart writer for streaming upload
|
# Create multipart writer for streaming upload
|
||||||
writer = aiohttp.MultipartWriter("form-data")
|
writer = aiohttp.MultipartWriter('form-data')
|
||||||
|
|
||||||
# Add purpose field
|
# Add purpose field
|
||||||
purpose_part = writer.append("ocr")
|
purpose_part = writer.append('ocr')
|
||||||
purpose_part.set_content_disposition("form-data", name="purpose")
|
purpose_part.set_content_disposition('form-data', name='purpose')
|
||||||
|
|
||||||
# Add file part with streaming
|
# Add file part with streaming
|
||||||
file_part = writer.append_payload(
|
file_part = writer.append_payload(
|
||||||
aiohttp.streams.FilePayload(
|
aiohttp.streams.FilePayload(
|
||||||
self.file_path,
|
self.file_path,
|
||||||
filename=self.file_name,
|
filename=self.file_name,
|
||||||
content_type="application/pdf",
|
content_type='application/pdf',
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
file_part.set_content_disposition(
|
file_part.set_content_disposition('form-data', name='file', filename=self.file_name)
|
||||||
"form-data", name="file", filename=self.file_name
|
|
||||||
)
|
|
||||||
|
|
||||||
self._debug_log(
|
self._debug_log(f'Uploading file: {self.file_name} ({self.file_size:,} bytes)')
|
||||||
f"Uploading file: {self.file_name} ({self.file_size:,} bytes)"
|
|
||||||
)
|
|
||||||
|
|
||||||
async with session.post(
|
async with session.post(
|
||||||
url,
|
url,
|
||||||
@@ -312,48 +290,44 @@ class MistralLoader:
|
|||||||
|
|
||||||
response_data = await self._retry_request_async(upload_request)
|
response_data = await self._retry_request_async(upload_request)
|
||||||
|
|
||||||
file_id = response_data.get("id")
|
file_id = response_data.get('id')
|
||||||
if not file_id:
|
if not file_id:
|
||||||
raise ValueError("File ID not found in upload response.")
|
raise ValueError('File ID not found in upload response.')
|
||||||
|
|
||||||
log.info(f"File uploaded successfully. File ID: {file_id}")
|
log.info(f'File uploaded successfully. File ID: {file_id}')
|
||||||
return file_id
|
return file_id
|
||||||
|
|
||||||
def _get_signed_url(self, file_id: str) -> str:
|
def _get_signed_url(self, file_id: str) -> str:
|
||||||
"""Retrieves a temporary signed URL for the uploaded file (sync version)."""
|
"""Retrieves a temporary signed URL for the uploaded file (sync version)."""
|
||||||
log.info(f"Getting signed URL for file ID: {file_id}")
|
log.info(f'Getting signed URL for file ID: {file_id}')
|
||||||
url = f"{self.base_url}/files/{file_id}/url"
|
url = f'{self.base_url}/files/{file_id}/url'
|
||||||
params = {"expiry": 1}
|
params = {'expiry': 1}
|
||||||
signed_url_headers = {**self.headers, "Accept": "application/json"}
|
signed_url_headers = {**self.headers, 'Accept': 'application/json'}
|
||||||
|
|
||||||
def url_request():
|
def url_request():
|
||||||
response = requests.get(
|
response = requests.get(url, headers=signed_url_headers, params=params, timeout=self.url_timeout)
|
||||||
url, headers=signed_url_headers, params=params, timeout=self.url_timeout
|
|
||||||
)
|
|
||||||
return self._handle_response(response)
|
return self._handle_response(response)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response_data = self._retry_request_sync(url_request)
|
response_data = self._retry_request_sync(url_request)
|
||||||
signed_url = response_data.get("url")
|
signed_url = response_data.get('url')
|
||||||
if not signed_url:
|
if not signed_url:
|
||||||
raise ValueError("Signed URL not found in response.")
|
raise ValueError('Signed URL not found in response.')
|
||||||
log.info("Signed URL received.")
|
log.info('Signed URL received.')
|
||||||
return signed_url
|
return signed_url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Failed to get signed URL: {e}")
|
log.error(f'Failed to get signed URL: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _get_signed_url_async(
|
async def _get_signed_url_async(self, session: aiohttp.ClientSession, file_id: str) -> str:
|
||||||
self, session: aiohttp.ClientSession, file_id: str
|
|
||||||
) -> str:
|
|
||||||
"""Async signed URL retrieval."""
|
"""Async signed URL retrieval."""
|
||||||
url = f"{self.base_url}/files/{file_id}/url"
|
url = f'{self.base_url}/files/{file_id}/url'
|
||||||
params = {"expiry": 1}
|
params = {'expiry': 1}
|
||||||
|
|
||||||
headers = {**self.headers, "Accept": "application/json"}
|
headers = {**self.headers, 'Accept': 'application/json'}
|
||||||
|
|
||||||
async def url_request():
|
async def url_request():
|
||||||
self._debug_log(f"Getting signed URL for file ID: {file_id}")
|
self._debug_log(f'Getting signed URL for file ID: {file_id}')
|
||||||
async with session.get(
|
async with session.get(
|
||||||
url,
|
url,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
@@ -364,69 +338,65 @@ class MistralLoader:
|
|||||||
|
|
||||||
response_data = await self._retry_request_async(url_request)
|
response_data = await self._retry_request_async(url_request)
|
||||||
|
|
||||||
signed_url = response_data.get("url")
|
signed_url = response_data.get('url')
|
||||||
if not signed_url:
|
if not signed_url:
|
||||||
raise ValueError("Signed URL not found in response.")
|
raise ValueError('Signed URL not found in response.')
|
||||||
|
|
||||||
self._debug_log("Signed URL received successfully")
|
self._debug_log('Signed URL received successfully')
|
||||||
return signed_url
|
return signed_url
|
||||||
|
|
||||||
def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
|
def _process_ocr(self, signed_url: str) -> Dict[str, Any]:
|
||||||
"""Sends the signed URL to the OCR endpoint for processing (sync version)."""
|
"""Sends the signed URL to the OCR endpoint for processing (sync version)."""
|
||||||
log.info("Processing OCR via Mistral API")
|
log.info('Processing OCR via Mistral API')
|
||||||
url = f"{self.base_url}/ocr"
|
url = f'{self.base_url}/ocr'
|
||||||
ocr_headers = {
|
ocr_headers = {
|
||||||
**self.headers,
|
**self.headers,
|
||||||
"Content-Type": "application/json",
|
'Content-Type': 'application/json',
|
||||||
"Accept": "application/json",
|
'Accept': 'application/json',
|
||||||
}
|
}
|
||||||
payload = {
|
payload = {
|
||||||
"model": "mistral-ocr-latest",
|
'model': 'mistral-ocr-latest',
|
||||||
"document": {
|
'document': {
|
||||||
"type": "document_url",
|
'type': 'document_url',
|
||||||
"document_url": signed_url,
|
'document_url': signed_url,
|
||||||
},
|
},
|
||||||
"include_image_base64": False,
|
'include_image_base64': False,
|
||||||
}
|
}
|
||||||
|
|
||||||
def ocr_request():
|
def ocr_request():
|
||||||
response = requests.post(
|
response = requests.post(url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout)
|
||||||
url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout
|
|
||||||
)
|
|
||||||
return self._handle_response(response)
|
return self._handle_response(response)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ocr_response = self._retry_request_sync(ocr_request)
|
ocr_response = self._retry_request_sync(ocr_request)
|
||||||
log.info("OCR processing done.")
|
log.info('OCR processing done.')
|
||||||
self._debug_log("OCR response: %s", ocr_response)
|
self._debug_log('OCR response: %s', ocr_response)
|
||||||
return ocr_response
|
return ocr_response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Failed during OCR processing: {e}")
|
log.error(f'Failed during OCR processing: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _process_ocr_async(
|
async def _process_ocr_async(self, session: aiohttp.ClientSession, signed_url: str) -> Dict[str, Any]:
|
||||||
self, session: aiohttp.ClientSession, signed_url: str
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Async OCR processing with timing metrics."""
|
"""Async OCR processing with timing metrics."""
|
||||||
url = f"{self.base_url}/ocr"
|
url = f'{self.base_url}/ocr'
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
**self.headers,
|
**self.headers,
|
||||||
"Content-Type": "application/json",
|
'Content-Type': 'application/json',
|
||||||
"Accept": "application/json",
|
'Accept': 'application/json',
|
||||||
}
|
}
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": "mistral-ocr-latest",
|
'model': 'mistral-ocr-latest',
|
||||||
"document": {
|
'document': {
|
||||||
"type": "document_url",
|
'type': 'document_url',
|
||||||
"document_url": signed_url,
|
'document_url': signed_url,
|
||||||
},
|
},
|
||||||
"include_image_base64": False,
|
'include_image_base64': False,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def ocr_request():
|
async def ocr_request():
|
||||||
log.info("Starting OCR processing via Mistral API")
|
log.info('Starting OCR processing via Mistral API')
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
async with session.post(
|
async with session.post(
|
||||||
@@ -438,7 +408,7 @@ class MistralLoader:
|
|||||||
ocr_response = await self._handle_response_async(response)
|
ocr_response = await self._handle_response_async(response)
|
||||||
|
|
||||||
processing_time = time.time() - start_time
|
processing_time = time.time() - start_time
|
||||||
log.info(f"OCR processing completed in {processing_time:.2f}s")
|
log.info(f'OCR processing completed in {processing_time:.2f}s')
|
||||||
|
|
||||||
return ocr_response
|
return ocr_response
|
||||||
|
|
||||||
@@ -446,42 +416,36 @@ class MistralLoader:
|
|||||||
|
|
||||||
def _delete_file(self, file_id: str) -> None:
    """Delete an uploaded file from Mistral storage (sync version, best-effort).

    Sends one DELETE request to ``{self.base_url}/files/{file_id}`` using
    ``self.headers`` and ``self.cleanup_timeout``, then hands the response
    to ``self._handle_response``.

    Args:
        file_id: Identifier of the previously uploaded file to remove.

    Returns:
        None. Any exception raised while issuing the request or handling
        the response is logged at ERROR level and swallowed, so a failed
        cleanup never propagates to the caller.
    """
    log.info(f'Deleting uploaded file ID: {file_id}')
    url = f'{self.base_url}/files/{file_id}'

    try:
        response = requests.delete(url, headers=self.headers, timeout=self.cleanup_timeout)
        delete_response = self._handle_response(response)
        log.info(f'File deleted successfully: {delete_response}')
    except Exception as e:
        # Log error but don't necessarily halt execution if deletion fails
        log.error(f'Failed to delete file ID {file_id}: {e}')
|
||||||
|
|
||||||
async def _delete_file_async(self, session: aiohttp.ClientSession, file_id: str) -> None:
    """Delete an uploaded file asynchronously, tolerating failure.

    The DELETE call is wrapped in a local coroutine and executed through
    ``self._retry_request_async``; the response is processed by
    ``self._handle_response_async``.

    Args:
        session: aiohttp session used to issue the DELETE request.
        file_id: Identifier of the previously uploaded file to remove.

    Returns:
        None. Any exception (including one surfaced by the retry helper)
        is logged at WARNING level and swallowed so cleanup problems do
        not fail the surrounding workflow.
    """
    try:

        async def delete_request():
            # One DELETE attempt; invoked by the retry helper below.
            self._debug_log(f'Deleting file ID: {file_id}')
            async with session.delete(
                url=f'{self.base_url}/files/{file_id}',
                headers=self.headers,
                timeout=aiohttp.ClientTimeout(total=self.cleanup_timeout),  # Shorter timeout for cleanup
            ) as response:
                return await self._handle_response_async(response)

        await self._retry_request_async(delete_request)
        self._debug_log(f'File {file_id} deleted successfully')

    except Exception as e:
        # Don't fail the entire process if cleanup fails
        log.warning(f'Failed to delete file ID {file_id}: {e}')
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def _get_session(self):
|
async def _get_session(self):
|
||||||
@@ -506,7 +470,7 @@ class MistralLoader:
|
|||||||
async with aiohttp.ClientSession(
|
async with aiohttp.ClientSession(
|
||||||
connector=connector,
|
connector=connector,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"},
|
headers={'User-Agent': 'OpenWebUI-MistralLoader/2.0'},
|
||||||
raise_for_status=False, # We handle status codes manually
|
raise_for_status=False, # We handle status codes manually
|
||||||
trust_env=True,
|
trust_env=True,
|
||||||
) as session:
|
) as session:
|
||||||
@@ -514,13 +478,13 @@ class MistralLoader:
|
|||||||
|
|
||||||
def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
|
def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]:
|
||||||
"""Process OCR results into Document objects with enhanced metadata and memory efficiency."""
|
"""Process OCR results into Document objects with enhanced metadata and memory efficiency."""
|
||||||
pages_data = ocr_response.get("pages")
|
pages_data = ocr_response.get('pages')
|
||||||
if not pages_data:
|
if not pages_data:
|
||||||
log.warning("No pages found in OCR response.")
|
log.warning('No pages found in OCR response.')
|
||||||
return [
|
return [
|
||||||
Document(
|
Document(
|
||||||
page_content="No text content found",
|
page_content='No text content found',
|
||||||
metadata={"error": "no_pages", "file_name": self.file_name},
|
metadata={'error': 'no_pages', 'file_name': self.file_name},
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -530,8 +494,8 @@ class MistralLoader:
|
|||||||
|
|
||||||
# Process pages in a memory-efficient way
|
# Process pages in a memory-efficient way
|
||||||
for page_data in pages_data:
|
for page_data in pages_data:
|
||||||
page_content = page_data.get("markdown")
|
page_content = page_data.get('markdown')
|
||||||
page_index = page_data.get("index") # API uses 0-based index
|
page_index = page_data.get('index') # API uses 0-based index
|
||||||
|
|
||||||
if page_content is None or page_index is None:
|
if page_content is None or page_index is None:
|
||||||
skipped_pages += 1
|
skipped_pages += 1
|
||||||
@@ -548,7 +512,7 @@ class MistralLoader:
|
|||||||
|
|
||||||
if not cleaned_content:
|
if not cleaned_content:
|
||||||
skipped_pages += 1
|
skipped_pages += 1
|
||||||
self._debug_log(f"Skipping empty page {page_index}")
|
self._debug_log(f'Skipping empty page {page_index}')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Create document with optimized metadata
|
# Create document with optimized metadata
|
||||||
@@ -556,34 +520,30 @@ class MistralLoader:
|
|||||||
Document(
|
Document(
|
||||||
page_content=cleaned_content,
|
page_content=cleaned_content,
|
||||||
metadata={
|
metadata={
|
||||||
"page": page_index, # 0-based index from API
|
'page': page_index, # 0-based index from API
|
||||||
"page_label": page_index + 1, # 1-based label for convenience
|
'page_label': page_index + 1, # 1-based label for convenience
|
||||||
"total_pages": total_pages,
|
'total_pages': total_pages,
|
||||||
"file_name": self.file_name,
|
'file_name': self.file_name,
|
||||||
"file_size": self.file_size,
|
'file_size': self.file_size,
|
||||||
"processing_engine": "mistral-ocr",
|
'processing_engine': 'mistral-ocr',
|
||||||
"content_length": len(cleaned_content),
|
'content_length': len(cleaned_content),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if skipped_pages > 0:
|
if skipped_pages > 0:
|
||||||
log.info(
|
log.info(f'Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages')
|
||||||
f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not documents:
|
if not documents:
|
||||||
# Case where pages existed but none had valid markdown/index
|
# Case where pages existed but none had valid markdown/index
|
||||||
log.warning(
|
log.warning('OCR response contained pages, but none had valid content/index.')
|
||||||
"OCR response contained pages, but none had valid content/index."
|
|
||||||
)
|
|
||||||
return [
|
return [
|
||||||
Document(
|
Document(
|
||||||
page_content="No valid text content found in document",
|
page_content='No valid text content found in document',
|
||||||
metadata={
|
metadata={
|
||||||
"error": "no_valid_pages",
|
'error': 'no_valid_pages',
|
||||||
"total_pages": total_pages,
|
'total_pages': total_pages,
|
||||||
"file_name": self.file_name,
|
'file_name': self.file_name,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@@ -615,24 +575,20 @@ class MistralLoader:
|
|||||||
documents = self._process_results(ocr_response)
|
documents = self._process_results(ocr_response)
|
||||||
|
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
log.info(
|
log.info(f'Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents')
|
||||||
f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
|
|
||||||
)
|
|
||||||
|
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
log.error(
|
log.error(f'An error occurred during the loading process after {total_time:.2f}s: {e}')
|
||||||
f"An error occurred during the loading process after {total_time:.2f}s: {e}"
|
|
||||||
)
|
|
||||||
# Return an error document on failure
|
# Return an error document on failure
|
||||||
return [
|
return [
|
||||||
Document(
|
Document(
|
||||||
page_content=f"Error during processing: {e}",
|
page_content=f'Error during processing: {e}',
|
||||||
metadata={
|
metadata={
|
||||||
"error": "processing_failed",
|
'error': 'processing_failed',
|
||||||
"file_name": self.file_name,
|
'file_name': self.file_name,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@@ -643,9 +599,7 @@ class MistralLoader:
|
|||||||
self._delete_file(file_id)
|
self._delete_file(file_id)
|
||||||
except Exception as del_e:
|
except Exception as del_e:
|
||||||
# Log deletion error, but don't overwrite original error if one occurred
|
# Log deletion error, but don't overwrite original error if one occurred
|
||||||
log.error(
|
log.error(f'Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}')
|
||||||
f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def load_async(self) -> List[Document]:
|
async def load_async(self) -> List[Document]:
|
||||||
"""
|
"""
|
||||||
@@ -672,21 +626,19 @@ class MistralLoader:
|
|||||||
documents = self._process_results(ocr_response)
|
documents = self._process_results(ocr_response)
|
||||||
|
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
log.info(
|
log.info(f'Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents')
|
||||||
f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents"
|
|
||||||
)
|
|
||||||
|
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}")
|
log.error(f'Async OCR workflow failed after {total_time:.2f}s: {e}')
|
||||||
return [
|
return [
|
||||||
Document(
|
Document(
|
||||||
page_content=f"Error during OCR processing: {e}",
|
page_content=f'Error during OCR processing: {e}',
|
||||||
metadata={
|
metadata={
|
||||||
"error": "processing_failed",
|
'error': 'processing_failed',
|
||||||
"file_name": self.file_name,
|
'file_name': self.file_name,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@@ -697,11 +649,11 @@ class MistralLoader:
|
|||||||
async with self._get_session() as session:
|
async with self._get_session() as session:
|
||||||
await self._delete_file_async(session, file_id)
|
await self._delete_file_async(session, file_id)
|
||||||
except Exception as cleanup_error:
|
except Exception as cleanup_error:
|
||||||
log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}")
|
log.error(f'Cleanup failed for file ID {file_id}: {cleanup_error}')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def load_multiple_async(
|
async def load_multiple_async(
|
||||||
loaders: List["MistralLoader"],
|
loaders: List['MistralLoader'],
|
||||||
max_concurrent: int = 5, # Limit concurrent requests
|
max_concurrent: int = 5, # Limit concurrent requests
|
||||||
) -> List[List[Document]]:
|
) -> List[List[Document]]:
|
||||||
"""
|
"""
|
||||||
@@ -717,15 +669,13 @@ class MistralLoader:
|
|||||||
if not loaders:
|
if not loaders:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
log.info(
|
log.info(f'Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent')
|
||||||
f"Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent"
|
|
||||||
)
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# Use semaphore to control concurrency
|
# Use semaphore to control concurrency
|
||||||
semaphore = asyncio.Semaphore(max_concurrent)
|
semaphore = asyncio.Semaphore(max_concurrent)
|
||||||
|
|
||||||
async def process_with_semaphore(loader: "MistralLoader") -> List[Document]:
|
async def process_with_semaphore(loader: 'MistralLoader') -> List[Document]:
|
||||||
async with semaphore:
|
async with semaphore:
|
||||||
return await loader.load_async()
|
return await loader.load_async()
|
||||||
|
|
||||||
@@ -737,14 +687,14 @@ class MistralLoader:
|
|||||||
processed_results = []
|
processed_results = []
|
||||||
for i, result in enumerate(results):
|
for i, result in enumerate(results):
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
log.error(f"File {i} failed: {result}")
|
log.error(f'File {i} failed: {result}')
|
||||||
processed_results.append(
|
processed_results.append(
|
||||||
[
|
[
|
||||||
Document(
|
Document(
|
||||||
page_content=f"Error processing file: {result}",
|
page_content=f'Error processing file: {result}',
|
||||||
metadata={
|
metadata={
|
||||||
"error": "batch_processing_failed",
|
'error': 'batch_processing_failed',
|
||||||
"file_index": i,
|
'file_index': i,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@@ -755,15 +705,13 @@ class MistralLoader:
|
|||||||
# MONITORING: Log comprehensive batch processing statistics
|
# MONITORING: Log comprehensive batch processing statistics
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
total_docs = sum(len(docs) for docs in processed_results)
|
total_docs = sum(len(docs) for docs in processed_results)
|
||||||
success_count = sum(
|
success_count = sum(1 for result in results if not isinstance(result, Exception))
|
||||||
1 for result in results if not isinstance(result, Exception)
|
|
||||||
)
|
|
||||||
failure_count = len(results) - success_count
|
failure_count = len(results) - success_count
|
||||||
|
|
||||||
log.info(
|
log.info(
|
||||||
f"Batch processing completed in {total_time:.2f}s: "
|
f'Batch processing completed in {total_time:.2f}s: '
|
||||||
f"{success_count} files succeeded, {failure_count} files failed, "
|
f'{success_count} files succeeded, {failure_count} files failed, '
|
||||||
f"produced {total_docs} total documents"
|
f'produced {total_docs} total documents'
|
||||||
)
|
)
|
||||||
|
|
||||||
return processed_results
|
return processed_results
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ class TavilyLoader(BaseLoader):
|
|||||||
self,
|
self,
|
||||||
urls: Union[str, List[str]],
|
urls: Union[str, List[str]],
|
||||||
api_key: str,
|
api_key: str,
|
||||||
extract_depth: Literal["basic", "advanced"] = "basic",
|
extract_depth: Literal['basic', 'advanced'] = 'basic',
|
||||||
continue_on_failure: bool = True,
|
continue_on_failure: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize Tavily Extract client.
|
"""Initialize Tavily Extract client.
|
||||||
@@ -42,13 +42,13 @@ class TavilyLoader(BaseLoader):
|
|||||||
continue_on_failure: Whether to continue if extraction of a URL fails.
|
continue_on_failure: Whether to continue if extraction of a URL fails.
|
||||||
"""
|
"""
|
||||||
if not urls:
|
if not urls:
|
||||||
raise ValueError("At least one URL must be provided.")
|
raise ValueError('At least one URL must be provided.')
|
||||||
|
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.urls = urls if isinstance(urls, list) else [urls]
|
self.urls = urls if isinstance(urls, list) else [urls]
|
||||||
self.extract_depth = extract_depth
|
self.extract_depth = extract_depth
|
||||||
self.continue_on_failure = continue_on_failure
|
self.continue_on_failure = continue_on_failure
|
||||||
self.api_url = "https://api.tavily.com/extract"
|
self.api_url = 'https://api.tavily.com/extract'
|
||||||
|
|
||||||
def lazy_load(self) -> Iterator[Document]:
|
def lazy_load(self) -> Iterator[Document]:
|
||||||
"""Extract and yield documents from the URLs using Tavily Extract API."""
|
"""Extract and yield documents from the URLs using Tavily Extract API."""
|
||||||
@@ -57,35 +57,35 @@ class TavilyLoader(BaseLoader):
|
|||||||
batch_urls = self.urls[i : i + batch_size]
|
batch_urls = self.urls[i : i + batch_size]
|
||||||
try:
|
try:
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
'Content-Type': 'application/json',
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
'Authorization': f'Bearer {self.api_key}',
|
||||||
}
|
}
|
||||||
# Use string for single URL, array for multiple URLs
|
# Use string for single URL, array for multiple URLs
|
||||||
urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls
|
urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls
|
||||||
payload = {"urls": urls_param, "extract_depth": self.extract_depth}
|
payload = {'urls': urls_param, 'extract_depth': self.extract_depth}
|
||||||
# Make the API call
|
# Make the API call
|
||||||
response = requests.post(self.api_url, headers=headers, json=payload)
|
response = requests.post(self.api_url, headers=headers, json=payload)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
# Process successful results
|
# Process successful results
|
||||||
for result in response_data.get("results", []):
|
for result in response_data.get('results', []):
|
||||||
url = result.get("url", "")
|
url = result.get('url', '')
|
||||||
content = result.get("raw_content", "")
|
content = result.get('raw_content', '')
|
||||||
if not content:
|
if not content:
|
||||||
log.warning(f"No content extracted from {url}")
|
log.warning(f'No content extracted from {url}')
|
||||||
continue
|
continue
|
||||||
# Add URLs as metadata
|
# Add URLs as metadata
|
||||||
metadata = {"source": url}
|
metadata = {'source': url}
|
||||||
yield Document(
|
yield Document(
|
||||||
page_content=content,
|
page_content=content,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
for failed in response_data.get("failed_results", []):
|
for failed in response_data.get('failed_results', []):
|
||||||
url = failed.get("url", "")
|
url = failed.get('url', '')
|
||||||
error = failed.get("error", "Unknown error")
|
error = failed.get('error', 'Unknown error')
|
||||||
log.error(f"Failed to extract content from {url}: {error}")
|
log.error(f'Failed to extract content from {url}: {error}')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if self.continue_on_failure:
|
if self.continue_on_failure:
|
||||||
log.error(f"Error extracting content from batch {batch_urls}: {e}")
|
log.error(f'Error extracting content from batch {batch_urls}: {e}')
|
||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
|
|||||||
@@ -7,14 +7,14 @@ from langchain_core.documents import Document
|
|||||||
|
|
||||||
# Module-level logger named after this module.
log = logging.getLogger(__name__)

# URL schemes accepted for video links.
# NOTE(review): presumably consulted by `_parse_video_id` when validating a
# URL — the check itself is not visible in this excerpt; confirm.
ALLOWED_SCHEMES = {'http', 'https'}

# Host names (netloc part of the URL) accepted for video links.
ALLOWED_NETLOCS = {
    'youtu.be',
    'm.youtube.com',
    'youtube.com',
    'www.youtube.com',
    'www.youtube-nocookie.com',
    'vid.plus',
}
|
||||||
|
|
||||||
|
|
||||||
@@ -30,17 +30,17 @@ def _parse_video_id(url: str) -> Optional[str]:
|
|||||||
|
|
||||||
path = parsed_url.path
|
path = parsed_url.path
|
||||||
|
|
||||||
if path.endswith("/watch"):
|
if path.endswith('/watch'):
|
||||||
query = parsed_url.query
|
query = parsed_url.query
|
||||||
parsed_query = parse_qs(query)
|
parsed_query = parse_qs(query)
|
||||||
if "v" in parsed_query:
|
if 'v' in parsed_query:
|
||||||
ids = parsed_query["v"]
|
ids = parsed_query['v']
|
||||||
video_id = ids if isinstance(ids, str) else ids[0]
|
video_id = ids if isinstance(ids, str) else ids[0]
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
path = parsed_url.path.lstrip("/")
|
path = parsed_url.path.lstrip('/')
|
||||||
video_id = path.split("/")[-1]
|
video_id = path.split('/')[-1]
|
||||||
|
|
||||||
if len(video_id) != 11: # Video IDs are 11 characters long
|
if len(video_id) != 11: # Video IDs are 11 characters long
|
||||||
return None
|
return None
|
||||||
@@ -54,13 +54,13 @@ class YoutubeLoader:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
video_id: str,
|
video_id: str,
|
||||||
language: Union[str, Sequence[str]] = "en",
|
language: Union[str, Sequence[str]] = 'en',
|
||||||
proxy_url: Optional[str] = None,
|
proxy_url: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""Initialize with YouTube video ID."""
|
"""Initialize with YouTube video ID."""
|
||||||
_video_id = _parse_video_id(video_id)
|
_video_id = _parse_video_id(video_id)
|
||||||
self.video_id = _video_id if _video_id is not None else video_id
|
self.video_id = _video_id if _video_id is not None else video_id
|
||||||
self._metadata = {"source": video_id}
|
self._metadata = {'source': video_id}
|
||||||
self.proxy_url = proxy_url
|
self.proxy_url = proxy_url
|
||||||
|
|
||||||
# Ensure language is a list
|
# Ensure language is a list
|
||||||
@@ -70,8 +70,8 @@ class YoutubeLoader:
|
|||||||
self.language = list(language)
|
self.language = list(language)
|
||||||
|
|
||||||
# Add English as fallback if not already in the list
|
# Add English as fallback if not already in the list
|
||||||
if "en" not in self.language:
|
if 'en' not in self.language:
|
||||||
self.language.append("en")
|
self.language.append('en')
|
||||||
|
|
||||||
def load(self) -> List[Document]:
|
def load(self) -> List[Document]:
|
||||||
"""Load YouTube transcripts into `Document` objects."""
|
"""Load YouTube transcripts into `Document` objects."""
|
||||||
@@ -85,14 +85,12 @@ class YoutubeLoader:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
'Could not import "youtube_transcript_api" Python package. '
|
'Could not import "youtube_transcript_api" Python package. '
|
||||||
"Please install it with `pip install youtube-transcript-api`."
|
'Please install it with `pip install youtube-transcript-api`.'
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.proxy_url:
|
if self.proxy_url:
|
||||||
youtube_proxies = GenericProxyConfig(
|
youtube_proxies = GenericProxyConfig(http_url=self.proxy_url, https_url=self.proxy_url)
|
||||||
http_url=self.proxy_url, https_url=self.proxy_url
|
log.debug(f'Using proxy URL: {self.proxy_url[:14]}...')
|
||||||
)
|
|
||||||
log.debug(f"Using proxy URL: {self.proxy_url[:14]}...")
|
|
||||||
else:
|
else:
|
||||||
youtube_proxies = None
|
youtube_proxies = None
|
||||||
|
|
||||||
@@ -100,7 +98,7 @@ class YoutubeLoader:
|
|||||||
try:
|
try:
|
||||||
transcript_list = transcript_api.list(self.video_id)
|
transcript_list = transcript_api.list(self.video_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception("Loading YouTube transcript failed")
|
log.exception('Loading YouTube transcript failed')
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Try each language in order of priority
|
# Try each language in order of priority
|
||||||
@@ -110,14 +108,10 @@ class YoutubeLoader:
|
|||||||
if transcript.is_generated:
|
if transcript.is_generated:
|
||||||
log.debug(f"Found generated transcript for language '{lang}'")
|
log.debug(f"Found generated transcript for language '{lang}'")
|
||||||
try:
|
try:
|
||||||
transcript = transcript_list.find_manually_created_transcript(
|
transcript = transcript_list.find_manually_created_transcript([lang])
|
||||||
[lang]
|
|
||||||
)
|
|
||||||
log.debug(f"Found manual transcript for language '{lang}'")
|
log.debug(f"Found manual transcript for language '{lang}'")
|
||||||
except NoTranscriptFound:
|
except NoTranscriptFound:
|
||||||
log.debug(
|
log.debug(f"No manual transcript found for language '{lang}', using generated")
|
||||||
f"No manual transcript found for language '{lang}', using generated"
|
|
||||||
)
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
log.debug(f"Found transcript for language '{lang}'")
|
log.debug(f"Found transcript for language '{lang}'")
|
||||||
@@ -131,12 +125,10 @@ class YoutubeLoader:
|
|||||||
log.debug(f"Empty transcript for language '{lang}'")
|
log.debug(f"Empty transcript for language '{lang}'")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
transcript_text = " ".join(
|
transcript_text = ' '.join(
|
||||||
map(
|
map(
|
||||||
lambda transcript_piece: (
|
lambda transcript_piece: (
|
||||||
transcript_piece.text.strip(" ")
|
transcript_piece.text.strip(' ') if hasattr(transcript_piece, 'text') else ''
|
||||||
if hasattr(transcript_piece, "text")
|
|
||||||
else ""
|
|
||||||
),
|
),
|
||||||
transcript_pieces,
|
transcript_pieces,
|
||||||
)
|
)
|
||||||
@@ -150,9 +142,9 @@ class YoutubeLoader:
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
# If we get here, all languages failed
|
# If we get here, all languages failed
|
||||||
languages_tried = ", ".join(self.language)
|
languages_tried = ', '.join(self.language)
|
||||||
log.warning(
|
log.warning(
|
||||||
f"No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed."
|
f'No transcript found for any of the specified languages: {languages_tried}. Verify if the video has transcripts, add more languages if needed.'
|
||||||
)
|
)
|
||||||
raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))
|
raise NoTranscriptFound(self.video_id, self.language, list(transcript_list))
|
||||||
|
|
||||||
|
|||||||
@@ -13,19 +13,17 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class ColBERT(BaseReranker):
|
class ColBERT(BaseReranker):
|
||||||
def __init__(self, name, **kwargs) -> None:
|
def __init__(self, name, **kwargs) -> None:
|
||||||
log.info("ColBERT: Loading model", name)
|
log.info('ColBERT: Loading model', name)
|
||||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
|
||||||
DOCKER = kwargs.get("env") == "docker"
|
DOCKER = kwargs.get('env') == 'docker'
|
||||||
if DOCKER:
|
if DOCKER:
|
||||||
# This is a workaround for the issue with the docker container
|
# This is a workaround for the issue with the docker container
|
||||||
# where the torch extension is not loaded properly
|
# where the torch extension is not loaded properly
|
||||||
# and the following error is thrown:
|
# and the following error is thrown:
|
||||||
# /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory
|
# /root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/segmented_maxsim_cpp.so: cannot open shared object file: No such file or directory
|
||||||
|
|
||||||
lock_file = (
|
lock_file = '/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock'
|
||||||
"/root/.cache/torch_extensions/py311_cpu/segmented_maxsim_cpp/lock"
|
|
||||||
)
|
|
||||||
if os.path.exists(lock_file):
|
if os.path.exists(lock_file):
|
||||||
os.remove(lock_file)
|
os.remove(lock_file)
|
||||||
|
|
||||||
@@ -36,23 +34,16 @@ class ColBERT(BaseReranker):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def calculate_similarity_scores(self, query_embeddings, document_embeddings):
|
def calculate_similarity_scores(self, query_embeddings, document_embeddings):
|
||||||
|
|
||||||
query_embeddings = query_embeddings.to(self.device)
|
query_embeddings = query_embeddings.to(self.device)
|
||||||
document_embeddings = document_embeddings.to(self.device)
|
document_embeddings = document_embeddings.to(self.device)
|
||||||
|
|
||||||
# Validate dimensions to ensure compatibility
|
# Validate dimensions to ensure compatibility
|
||||||
if query_embeddings.dim() != 3:
|
if query_embeddings.dim() != 3:
|
||||||
raise ValueError(
|
raise ValueError(f'Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}.')
|
||||||
f"Expected query embeddings to have 3 dimensions, but got {query_embeddings.dim()}."
|
|
||||||
)
|
|
||||||
if document_embeddings.dim() != 3:
|
if document_embeddings.dim() != 3:
|
||||||
raise ValueError(
|
raise ValueError(f'Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}.')
|
||||||
f"Expected document embeddings to have 3 dimensions, but got {document_embeddings.dim()}."
|
|
||||||
)
|
|
||||||
if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
|
if query_embeddings.size(0) not in [1, document_embeddings.size(0)]:
|
||||||
raise ValueError(
|
raise ValueError('There should be either one query or queries equal to the number of documents.')
|
||||||
"There should be either one query or queries equal to the number of documents."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Transpose the query embeddings to align for matrix multiplication
|
# Transpose the query embeddings to align for matrix multiplication
|
||||||
transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
|
transposed_query_embeddings = query_embeddings.permute(0, 2, 1)
|
||||||
@@ -69,7 +60,6 @@ class ColBERT(BaseReranker):
|
|||||||
return normalized_scores.detach().cpu().numpy().astype(np.float32)
|
return normalized_scores.detach().cpu().numpy().astype(np.float32)
|
||||||
|
|
||||||
def predict(self, sentences):
|
def predict(self, sentences):
|
||||||
|
|
||||||
query = sentences[0][0]
|
query = sentences[0][0]
|
||||||
docs = [i[1] for i in sentences]
|
docs = [i[1] for i in sentences]
|
||||||
|
|
||||||
@@ -80,8 +70,6 @@ class ColBERT(BaseReranker):
|
|||||||
embedded_query = embedded_queries[0]
|
embedded_query = embedded_queries[0]
|
||||||
|
|
||||||
# Calculate retrieval scores for the query against all documents
|
# Calculate retrieval scores for the query against all documents
|
||||||
scores = self.calculate_similarity_scores(
|
scores = self.calculate_similarity_scores(embedded_query.unsqueeze(0), embedded_docs)
|
||||||
embedded_query.unsqueeze(0), embedded_docs
|
|
||||||
)
|
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ class ExternalReranker(BaseReranker):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
url: str = "http://localhost:8080/v1/rerank",
|
url: str = 'http://localhost:8080/v1/rerank',
|
||||||
model: str = "reranker",
|
model: str = 'reranker',
|
||||||
timeout: Optional[int] = None,
|
timeout: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
@@ -24,33 +24,31 @@ class ExternalReranker(BaseReranker):
|
|||||||
self.model = model
|
self.model = model
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
|
||||||
def predict(
|
def predict(self, sentences: List[Tuple[str, str]], user=None) -> Optional[List[float]]:
|
||||||
self, sentences: List[Tuple[str, str]], user=None
|
|
||||||
) -> Optional[List[float]]:
|
|
||||||
query = sentences[0][0]
|
query = sentences[0][0]
|
||||||
docs = [i[1] for i in sentences]
|
docs = [i[1] for i in sentences]
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": self.model,
|
'model': self.model,
|
||||||
"query": query,
|
'query': query,
|
||||||
"documents": docs,
|
'documents': docs,
|
||||||
"top_n": len(docs),
|
'top_n': len(docs),
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
log.info(f"ExternalReranker:predict:model {self.model}")
|
log.info(f'ExternalReranker:predict:model {self.model}')
|
||||||
log.info(f"ExternalReranker:predict:query {query}")
|
log.info(f'ExternalReranker:predict:query {query}')
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
'Content-Type': 'application/json',
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
'Authorization': f'Bearer {self.api_key}',
|
||||||
}
|
}
|
||||||
|
|
||||||
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
if ENABLE_FORWARD_USER_INFO_HEADERS and user:
|
||||||
headers = include_user_info_headers(headers, user)
|
headers = include_user_info_headers(headers, user)
|
||||||
|
|
||||||
r = requests.post(
|
r = requests.post(
|
||||||
f"{self.url}",
|
f'{self.url}',
|
||||||
headers=headers,
|
headers=headers,
|
||||||
json=payload,
|
json=payload,
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
@@ -60,13 +58,13 @@ class ExternalReranker(BaseReranker):
|
|||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
data = r.json()
|
data = r.json()
|
||||||
|
|
||||||
if "results" in data:
|
if 'results' in data:
|
||||||
sorted_results = sorted(data["results"], key=lambda x: x["index"])
|
sorted_results = sorted(data['results'], key=lambda x: x['index'])
|
||||||
return [result["relevance_score"] for result in sorted_results]
|
return [result['relevance_score'] for result in sorted_results]
|
||||||
else:
|
else:
|
||||||
log.error("No results found in external reranking response")
|
log.error('No results found in external reranking response')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error in external reranking: {e}")
|
log.exception(f'Error in external reranking: {e}')
|
||||||
return None
|
return None
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -31,17 +31,15 @@ log = logging.getLogger(__name__)
|
|||||||
class ChromaClient(VectorDBBase):
|
class ChromaClient(VectorDBBase):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
settings_dict = {
|
settings_dict = {
|
||||||
"allow_reset": True,
|
'allow_reset': True,
|
||||||
"anonymized_telemetry": False,
|
'anonymized_telemetry': False,
|
||||||
}
|
}
|
||||||
if CHROMA_CLIENT_AUTH_PROVIDER is not None:
|
if CHROMA_CLIENT_AUTH_PROVIDER is not None:
|
||||||
settings_dict["chroma_client_auth_provider"] = CHROMA_CLIENT_AUTH_PROVIDER
|
settings_dict['chroma_client_auth_provider'] = CHROMA_CLIENT_AUTH_PROVIDER
|
||||||
if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
|
if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
|
||||||
settings_dict["chroma_client_auth_credentials"] = (
|
settings_dict['chroma_client_auth_credentials'] = CHROMA_CLIENT_AUTH_CREDENTIALS
|
||||||
CHROMA_CLIENT_AUTH_CREDENTIALS
|
|
||||||
)
|
|
||||||
|
|
||||||
if CHROMA_HTTP_HOST != "":
|
if CHROMA_HTTP_HOST != '':
|
||||||
self.client = chromadb.HttpClient(
|
self.client = chromadb.HttpClient(
|
||||||
host=CHROMA_HTTP_HOST,
|
host=CHROMA_HTTP_HOST,
|
||||||
port=CHROMA_HTTP_PORT,
|
port=CHROMA_HTTP_PORT,
|
||||||
@@ -87,25 +85,23 @@ class ChromaClient(VectorDBBase):
|
|||||||
|
|
||||||
# chromadb has cosine distance, 2 (worst) -> 0 (best). Re-ordering to 0 -> 1
|
# chromadb has cosine distance, 2 (worst) -> 0 (best). Re-ordering to 0 -> 1
|
||||||
# https://docs.trychroma.com/docs/collections/configure cosine equation
|
# https://docs.trychroma.com/docs/collections/configure cosine equation
|
||||||
distances: list = result["distances"][0]
|
distances: list = result['distances'][0]
|
||||||
distances = [2 - dist for dist in distances]
|
distances = [2 - dist for dist in distances]
|
||||||
distances = [[dist / 2 for dist in distances]]
|
distances = [[dist / 2 for dist in distances]]
|
||||||
|
|
||||||
return SearchResult(
|
return SearchResult(
|
||||||
**{
|
**{
|
||||||
"ids": result["ids"],
|
'ids': result['ids'],
|
||||||
"distances": distances,
|
'distances': distances,
|
||||||
"documents": result["documents"],
|
'documents': result['documents'],
|
||||||
"metadatas": result["metadatas"],
|
'metadatas': result['metadatas'],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def query(
|
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) -> Optional[GetResult]:
|
||||||
self, collection_name: str, filter: dict, limit: Optional[int] = None
|
|
||||||
) -> Optional[GetResult]:
|
|
||||||
# Query the items from the collection based on the filter.
|
# Query the items from the collection based on the filter.
|
||||||
try:
|
try:
|
||||||
collection = self.client.get_collection(name=collection_name)
|
collection = self.client.get_collection(name=collection_name)
|
||||||
@@ -117,9 +113,9 @@ class ChromaClient(VectorDBBase):
|
|||||||
|
|
||||||
return GetResult(
|
return GetResult(
|
||||||
**{
|
**{
|
||||||
"ids": [result["ids"]],
|
'ids': [result['ids']],
|
||||||
"documents": [result["documents"]],
|
'documents': [result['documents']],
|
||||||
"metadatas": [result["metadatas"]],
|
'metadatas': [result['metadatas']],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
@@ -133,23 +129,21 @@ class ChromaClient(VectorDBBase):
|
|||||||
result = collection.get()
|
result = collection.get()
|
||||||
return GetResult(
|
return GetResult(
|
||||||
**{
|
**{
|
||||||
"ids": [result["ids"]],
|
'ids': [result['ids']],
|
||||||
"documents": [result["documents"]],
|
'documents': [result['documents']],
|
||||||
"metadatas": [result["metadatas"]],
|
'metadatas': [result['metadatas']],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||||
# Insert the items into the collection, if the collection does not exist, it will be created.
|
# Insert the items into the collection, if the collection does not exist, it will be created.
|
||||||
collection = self.client.get_or_create_collection(
|
collection = self.client.get_or_create_collection(name=collection_name, metadata={'hnsw:space': 'cosine'})
|
||||||
name=collection_name, metadata={"hnsw:space": "cosine"}
|
|
||||||
)
|
|
||||||
|
|
||||||
ids = [item["id"] for item in items]
|
ids = [item['id'] for item in items]
|
||||||
documents = [item["text"] for item in items]
|
documents = [item['text'] for item in items]
|
||||||
embeddings = [item["vector"] for item in items]
|
embeddings = [item['vector'] for item in items]
|
||||||
metadatas = [process_metadata(item["metadata"]) for item in items]
|
metadatas = [process_metadata(item['metadata']) for item in items]
|
||||||
|
|
||||||
for batch in create_batches(
|
for batch in create_batches(
|
||||||
api=self.client,
|
api=self.client,
|
||||||
@@ -162,18 +156,14 @@ class ChromaClient(VectorDBBase):
|
|||||||
|
|
||||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||||
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
|
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
|
||||||
collection = self.client.get_or_create_collection(
|
collection = self.client.get_or_create_collection(name=collection_name, metadata={'hnsw:space': 'cosine'})
|
||||||
name=collection_name, metadata={"hnsw:space": "cosine"}
|
|
||||||
)
|
|
||||||
|
|
||||||
ids = [item["id"] for item in items]
|
ids = [item['id'] for item in items]
|
||||||
documents = [item["text"] for item in items]
|
documents = [item['text'] for item in items]
|
||||||
embeddings = [item["vector"] for item in items]
|
embeddings = [item['vector'] for item in items]
|
||||||
metadatas = [process_metadata(item["metadata"]) for item in items]
|
metadatas = [process_metadata(item['metadata']) for item in items]
|
||||||
|
|
||||||
collection.upsert(
|
collection.upsert(ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas)
|
||||||
ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete(
|
def delete(
|
||||||
self,
|
self,
|
||||||
@@ -191,9 +181,7 @@ class ChromaClient(VectorDBBase):
|
|||||||
collection.delete(where=filter)
|
collection.delete(where=filter)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# If collection doesn't exist, that's fine - nothing to delete
|
# If collection doesn't exist, that's fine - nothing to delete
|
||||||
log.debug(
|
log.debug(f'Attempted to delete from non-existent collection {collection_name}. Ignoring.')
|
||||||
f"Attempted to delete from non-existent collection {collection_name}. Ignoring."
|
|
||||||
)
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class ElasticsearchClient(VectorDBBase):
|
|||||||
|
|
||||||
# Status: works
|
# Status: works
|
||||||
def _get_index_name(self, dimension: int) -> str:
|
def _get_index_name(self, dimension: int) -> str:
|
||||||
return f"{self.index_prefix}_d{str(dimension)}"
|
return f'{self.index_prefix}_d{str(dimension)}'
|
||||||
|
|
||||||
# Status: works
|
# Status: works
|
||||||
def _scan_result_to_get_result(self, result) -> GetResult:
|
def _scan_result_to_get_result(self, result) -> GetResult:
|
||||||
@@ -62,24 +62,24 @@ class ElasticsearchClient(VectorDBBase):
|
|||||||
metadatas = []
|
metadatas = []
|
||||||
|
|
||||||
for hit in result:
|
for hit in result:
|
||||||
ids.append(hit["_id"])
|
ids.append(hit['_id'])
|
||||||
documents.append(hit["_source"].get("text"))
|
documents.append(hit['_source'].get('text'))
|
||||||
metadatas.append(hit["_source"].get("metadata"))
|
metadatas.append(hit['_source'].get('metadata'))
|
||||||
|
|
||||||
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
||||||
|
|
||||||
# Status: works
|
# Status: works
|
||||||
def _result_to_get_result(self, result) -> GetResult:
|
def _result_to_get_result(self, result) -> GetResult:
|
||||||
if not result["hits"]["hits"]:
|
if not result['hits']['hits']:
|
||||||
return None
|
return None
|
||||||
ids = []
|
ids = []
|
||||||
documents = []
|
documents = []
|
||||||
metadatas = []
|
metadatas = []
|
||||||
|
|
||||||
for hit in result["hits"]["hits"]:
|
for hit in result['hits']['hits']:
|
||||||
ids.append(hit["_id"])
|
ids.append(hit['_id'])
|
||||||
documents.append(hit["_source"].get("text"))
|
documents.append(hit['_source'].get('text'))
|
||||||
metadatas.append(hit["_source"].get("metadata"))
|
metadatas.append(hit['_source'].get('metadata'))
|
||||||
|
|
||||||
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
||||||
|
|
||||||
@@ -90,11 +90,11 @@ class ElasticsearchClient(VectorDBBase):
|
|||||||
documents = []
|
documents = []
|
||||||
metadatas = []
|
metadatas = []
|
||||||
|
|
||||||
for hit in result["hits"]["hits"]:
|
for hit in result['hits']['hits']:
|
||||||
ids.append(hit["_id"])
|
ids.append(hit['_id'])
|
||||||
distances.append(hit["_score"])
|
distances.append(hit['_score'])
|
||||||
documents.append(hit["_source"].get("text"))
|
documents.append(hit['_source'].get('text'))
|
||||||
metadatas.append(hit["_source"].get("metadata"))
|
metadatas.append(hit['_source'].get('metadata'))
|
||||||
|
|
||||||
return SearchResult(
|
return SearchResult(
|
||||||
ids=[ids],
|
ids=[ids],
|
||||||
@@ -106,26 +106,26 @@ class ElasticsearchClient(VectorDBBase):
|
|||||||
# Status: works
|
# Status: works
|
||||||
def _create_index(self, dimension: int):
|
def _create_index(self, dimension: int):
|
||||||
body = {
|
body = {
|
||||||
"mappings": {
|
'mappings': {
|
||||||
"dynamic_templates": [
|
'dynamic_templates': [
|
||||||
{
|
{
|
||||||
"strings": {
|
'strings': {
|
||||||
"match_mapping_type": "string",
|
'match_mapping_type': 'string',
|
||||||
"mapping": {"type": "keyword"},
|
'mapping': {'type': 'keyword'},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"properties": {
|
'properties': {
|
||||||
"collection": {"type": "keyword"},
|
'collection': {'type': 'keyword'},
|
||||||
"id": {"type": "keyword"},
|
'id': {'type': 'keyword'},
|
||||||
"vector": {
|
'vector': {
|
||||||
"type": "dense_vector",
|
'type': 'dense_vector',
|
||||||
"dims": dimension, # Adjust based on your vector dimensions
|
'dims': dimension, # Adjust based on your vector dimensions
|
||||||
"index": True,
|
'index': True,
|
||||||
"similarity": "cosine",
|
'similarity': 'cosine',
|
||||||
},
|
},
|
||||||
"text": {"type": "text"},
|
'text': {'type': 'text'},
|
||||||
"metadata": {"type": "object"},
|
'metadata': {'type': 'object'},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -139,21 +139,19 @@ class ElasticsearchClient(VectorDBBase):
|
|||||||
|
|
||||||
# Status: works
|
# Status: works
|
||||||
def has_collection(self, collection_name) -> bool:
|
def has_collection(self, collection_name) -> bool:
|
||||||
query_body = {"query": {"bool": {"filter": []}}}
|
query_body = {'query': {'bool': {'filter': []}}}
|
||||||
query_body["query"]["bool"]["filter"].append(
|
query_body['query']['bool']['filter'].append({'term': {'collection': collection_name}})
|
||||||
{"term": {"collection": collection_name}}
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.client.count(index=f"{self.index_prefix}*", body=query_body)
|
result = self.client.count(index=f'{self.index_prefix}*', body=query_body)
|
||||||
|
|
||||||
return result.body["count"] > 0
|
return result.body['count'] > 0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def delete_collection(self, collection_name: str):
|
def delete_collection(self, collection_name: str):
|
||||||
query = {"query": {"term": {"collection": collection_name}}}
|
query = {'query': {'term': {'collection': collection_name}}}
|
||||||
self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
|
self.client.delete_by_query(index=f'{self.index_prefix}*', body=query)
|
||||||
|
|
||||||
# Status: works
|
# Status: works
|
||||||
def search(
|
def search(
|
||||||
@@ -164,51 +162,41 @@ class ElasticsearchClient(VectorDBBase):
|
|||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
) -> Optional[SearchResult]:
|
) -> Optional[SearchResult]:
|
||||||
query = {
|
query = {
|
||||||
"size": limit,
|
'size': limit,
|
||||||
"_source": ["text", "metadata"],
|
'_source': ['text', 'metadata'],
|
||||||
"query": {
|
'query': {
|
||||||
"script_score": {
|
'script_score': {
|
||||||
"query": {
|
'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}},
|
||||||
"bool": {"filter": [{"term": {"collection": collection_name}}]}
|
'script': {
|
||||||
},
|
'source': "cosineSimilarity(params.vector, 'vector') + 1.0",
|
||||||
"script": {
|
'params': {'vector': vectors[0]}, # Assuming single query vector
|
||||||
"source": "cosineSimilarity(params.vector, 'vector') + 1.0",
|
|
||||||
"params": {
|
|
||||||
"vector": vectors[0]
|
|
||||||
}, # Assuming single query vector
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
result = self.client.search(
|
result = self.client.search(index=self._get_index_name(len(vectors[0])), body=query)
|
||||||
index=self._get_index_name(len(vectors[0])), body=query
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._result_to_search_result(result)
|
return self._result_to_search_result(result)
|
||||||
|
|
||||||
# Status: only tested halfway
|
# Status: only tested halfway
|
||||||
def query(
|
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) -> Optional[GetResult]:
|
||||||
self, collection_name: str, filter: dict, limit: Optional[int] = None
|
|
||||||
) -> Optional[GetResult]:
|
|
||||||
if not self.has_collection(collection_name):
|
if not self.has_collection(collection_name):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
query_body = {
|
query_body = {
|
||||||
"query": {"bool": {"filter": []}},
|
'query': {'bool': {'filter': []}},
|
||||||
"_source": ["text", "metadata"],
|
'_source': ['text', 'metadata'],
|
||||||
}
|
}
|
||||||
|
|
||||||
for field, value in filter.items():
|
for field, value in filter.items():
|
||||||
query_body["query"]["bool"]["filter"].append({"term": {field: value}})
|
query_body['query']['bool']['filter'].append({'term': {field: value}})
|
||||||
query_body["query"]["bool"]["filter"].append(
|
query_body['query']['bool']['filter'].append({'term': {'collection': collection_name}})
|
||||||
{"term": {"collection": collection_name}}
|
|
||||||
)
|
|
||||||
size = limit if limit else 10
|
size = limit if limit else 10
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = self.client.search(
|
result = self.client.search(
|
||||||
index=f"{self.index_prefix}*",
|
index=f'{self.index_prefix}*',
|
||||||
body=query_body,
|
body=query_body,
|
||||||
size=size,
|
size=size,
|
||||||
)
|
)
|
||||||
@@ -220,9 +208,7 @@ class ElasticsearchClient(VectorDBBase):
|
|||||||
|
|
||||||
# Status: works
|
# Status: works
|
||||||
def _has_index(self, dimension: int):
|
def _has_index(self, dimension: int):
|
||||||
return self.client.indices.exists(
|
return self.client.indices.exists(index=self._get_index_name(dimension=dimension))
|
||||||
index=self._get_index_name(dimension=dimension)
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_or_create_index(self, dimension: int):
|
def get_or_create_index(self, dimension: int):
|
||||||
if not self._has_index(dimension=dimension):
|
if not self._has_index(dimension=dimension):
|
||||||
@@ -232,28 +218,28 @@ class ElasticsearchClient(VectorDBBase):
|
|||||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||||
# Get all the items in the collection.
|
# Get all the items in the collection.
|
||||||
query = {
|
query = {
|
||||||
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}},
|
'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}},
|
||||||
"_source": ["text", "metadata"],
|
'_source': ['text', 'metadata'],
|
||||||
}
|
}
|
||||||
results = list(scan(self.client, index=f"{self.index_prefix}*", query=query))
|
results = list(scan(self.client, index=f'{self.index_prefix}*', query=query))
|
||||||
|
|
||||||
return self._scan_result_to_get_result(results)
|
return self._scan_result_to_get_result(results)
|
||||||
|
|
||||||
# Status: works
|
# Status: works
|
||||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||||
if not self._has_index(dimension=len(items[0]["vector"])):
|
if not self._has_index(dimension=len(items[0]['vector'])):
|
||||||
self._create_index(dimension=len(items[0]["vector"]))
|
self._create_index(dimension=len(items[0]['vector']))
|
||||||
|
|
||||||
for batch in self._create_batches(items):
|
for batch in self._create_batches(items):
|
||||||
actions = [
|
actions = [
|
||||||
{
|
{
|
||||||
"_index": self._get_index_name(dimension=len(items[0]["vector"])),
|
'_index': self._get_index_name(dimension=len(items[0]['vector'])),
|
||||||
"_id": item["id"],
|
'_id': item['id'],
|
||||||
"_source": {
|
'_source': {
|
||||||
"collection": collection_name,
|
'collection': collection_name,
|
||||||
"vector": item["vector"],
|
'vector': item['vector'],
|
||||||
"text": item["text"],
|
'text': item['text'],
|
||||||
"metadata": process_metadata(item["metadata"]),
|
'metadata': process_metadata(item['metadata']),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for item in batch
|
for item in batch
|
||||||
@@ -262,21 +248,21 @@ class ElasticsearchClient(VectorDBBase):
|
|||||||
|
|
||||||
# Upsert documents using the update API with doc_as_upsert=True.
|
# Upsert documents using the update API with doc_as_upsert=True.
|
||||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||||
if not self._has_index(dimension=len(items[0]["vector"])):
|
if not self._has_index(dimension=len(items[0]['vector'])):
|
||||||
self._create_index(dimension=len(items[0]["vector"]))
|
self._create_index(dimension=len(items[0]['vector']))
|
||||||
for batch in self._create_batches(items):
|
for batch in self._create_batches(items):
|
||||||
actions = [
|
actions = [
|
||||||
{
|
{
|
||||||
"_op_type": "update",
|
'_op_type': 'update',
|
||||||
"_index": self._get_index_name(dimension=len(item["vector"])),
|
'_index': self._get_index_name(dimension=len(item['vector'])),
|
||||||
"_id": item["id"],
|
'_id': item['id'],
|
||||||
"doc": {
|
'doc': {
|
||||||
"collection": collection_name,
|
'collection': collection_name,
|
||||||
"vector": item["vector"],
|
'vector': item['vector'],
|
||||||
"text": item["text"],
|
'text': item['text'],
|
||||||
"metadata": process_metadata(item["metadata"]),
|
'metadata': process_metadata(item['metadata']),
|
||||||
},
|
},
|
||||||
"doc_as_upsert": True,
|
'doc_as_upsert': True,
|
||||||
}
|
}
|
||||||
for item in batch
|
for item in batch
|
||||||
]
|
]
|
||||||
@@ -289,22 +275,17 @@ class ElasticsearchClient(VectorDBBase):
|
|||||||
ids: Optional[list[str]] = None,
|
ids: Optional[list[str]] = None,
|
||||||
filter: Optional[dict] = None,
|
filter: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
|
query = {'query': {'bool': {'filter': [{'term': {'collection': collection_name}}]}}}
|
||||||
query = {
|
|
||||||
"query": {"bool": {"filter": [{"term": {"collection": collection_name}}]}}
|
|
||||||
}
|
|
||||||
# logic based on chromaDB
|
# logic based on chromaDB
|
||||||
if ids:
|
if ids:
|
||||||
query["query"]["bool"]["filter"].append({"terms": {"_id": ids}})
|
query['query']['bool']['filter'].append({'terms': {'_id': ids}})
|
||||||
elif filter:
|
elif filter:
|
||||||
for field, value in filter.items():
|
for field, value in filter.items():
|
||||||
query["query"]["bool"]["filter"].append(
|
query['query']['bool']['filter'].append({'term': {f'metadata.{field}': value}})
|
||||||
{"term": {f"metadata.{field}": value}}
|
|
||||||
)
|
|
||||||
|
|
||||||
self.client.delete_by_query(index=f"{self.index_prefix}*", body=query)
|
self.client.delete_by_query(index=f'{self.index_prefix}*', body=query)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
indices = self.client.indices.get(index=f"{self.index_prefix}*")
|
indices = self.client.indices.get(index=f'{self.index_prefix}*')
|
||||||
for index in indices:
|
for index in indices:
|
||||||
self.client.indices.delete(index=index)
|
self.client.indices.delete(index=index)
|
||||||
|
|||||||
@@ -47,8 +47,8 @@ def _embedding_to_f32_bytes(vec: List[float]) -> bytes:
|
|||||||
byte sequence. We use array('f') to avoid a numpy dependency and byteswap on
|
byte sequence. We use array('f') to avoid a numpy dependency and byteswap on
|
||||||
big-endian platforms for portability.
|
big-endian platforms for portability.
|
||||||
"""
|
"""
|
||||||
a = array.array("f", [float(x) for x in vec]) # float32
|
a = array.array('f', [float(x) for x in vec]) # float32
|
||||||
if sys.byteorder != "little":
|
if sys.byteorder != 'little':
|
||||||
a.byteswap()
|
a.byteswap()
|
||||||
return a.tobytes()
|
return a.tobytes()
|
||||||
|
|
||||||
@@ -68,7 +68,7 @@ def _safe_json(v: Any) -> Dict[str, Any]:
|
|||||||
return v
|
return v
|
||||||
if isinstance(v, (bytes, bytearray)):
|
if isinstance(v, (bytes, bytearray)):
|
||||||
try:
|
try:
|
||||||
v = v.decode("utf-8")
|
v = v.decode('utf-8')
|
||||||
except Exception:
|
except Exception:
|
||||||
return {}
|
return {}
|
||||||
if isinstance(v, str):
|
if isinstance(v, str):
|
||||||
@@ -105,16 +105,16 @@ class MariaDBVectorClient(VectorDBBase):
|
|||||||
"""
|
"""
|
||||||
self.db_url = (db_url or MARIADB_VECTOR_DB_URL).strip()
|
self.db_url = (db_url or MARIADB_VECTOR_DB_URL).strip()
|
||||||
self.vector_length = int(vector_length)
|
self.vector_length = int(vector_length)
|
||||||
self.distance_strategy = (distance_strategy or "cosine").strip().lower()
|
self.distance_strategy = (distance_strategy or 'cosine').strip().lower()
|
||||||
self.index_m = int(index_m)
|
self.index_m = int(index_m)
|
||||||
|
|
||||||
if self.distance_strategy not in {"cosine", "euclidean"}:
|
if self.distance_strategy not in {'cosine', 'euclidean'}:
|
||||||
raise ValueError("distance_strategy must be 'cosine' or 'euclidean'")
|
raise ValueError("distance_strategy must be 'cosine' or 'euclidean'")
|
||||||
|
|
||||||
if not self.db_url.lower().startswith("mariadb+mariadbconnector://"):
|
if not self.db_url.lower().startswith('mariadb+mariadbconnector://'):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"MariaDBVectorClient requires mariadb+mariadbconnector:// (official MariaDB driver) "
|
'MariaDBVectorClient requires mariadb+mariadbconnector:// (official MariaDB driver) '
|
||||||
"to ensure qmark paramstyle and correct VECTOR binding."
|
'to ensure qmark paramstyle and correct VECTOR binding.'
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(MARIADB_VECTOR_POOL_SIZE, int):
|
if isinstance(MARIADB_VECTOR_POOL_SIZE, int):
|
||||||
@@ -129,9 +129,7 @@ class MariaDBVectorClient(VectorDBBase):
|
|||||||
poolclass=QueuePool,
|
poolclass=QueuePool,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.engine = create_engine(
|
self.engine = create_engine(self.db_url, pool_pre_ping=True, poolclass=NullPool)
|
||||||
self.db_url, pool_pre_ping=True, poolclass=NullPool
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
self.engine = create_engine(self.db_url, pool_pre_ping=True)
|
self.engine = create_engine(self.db_url, pool_pre_ping=True)
|
||||||
self._init_schema()
|
self._init_schema()
|
||||||
@@ -185,7 +183,7 @@ class MariaDBVectorClient(VectorDBBase):
|
|||||||
conn.commit()
|
conn.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
log.exception(f"Error during database initialization: {e}")
|
log.exception(f'Error during database initialization: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _check_vector_length(self) -> None:
|
def _check_vector_length(self) -> None:
|
||||||
@@ -197,19 +195,19 @@ class MariaDBVectorClient(VectorDBBase):
|
|||||||
"""
|
"""
|
||||||
with self._connect() as conn:
|
with self._connect() as conn:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
cur.execute("SHOW CREATE TABLE document_chunk")
|
cur.execute('SHOW CREATE TABLE document_chunk')
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
if not row or len(row) < 2:
|
if not row or len(row) < 2:
|
||||||
return
|
return
|
||||||
ddl = row[1]
|
ddl = row[1]
|
||||||
m = re.search(r"vector\\((\\d+)\\)", ddl, flags=re.IGNORECASE)
|
m = re.search(r'vector\\((\\d+)\\)', ddl, flags=re.IGNORECASE)
|
||||||
if not m:
|
if not m:
|
||||||
return
|
return
|
||||||
existing = int(m.group(1))
|
existing = int(m.group(1))
|
||||||
if existing != int(self.vector_length):
|
if existing != int(self.vector_length):
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"VECTOR_LENGTH {self.vector_length} does not match existing vector column dimension {existing}. "
|
f'VECTOR_LENGTH {self.vector_length} does not match existing vector column dimension {existing}. '
|
||||||
"Cannot change vector size after initialization without migrating the data."
|
'Cannot change vector size after initialization without migrating the data.'
|
||||||
)
|
)
|
||||||
|
|
||||||
def adjust_vector_length(self, vector: List[float]) -> List[float]:
|
def adjust_vector_length(self, vector: List[float]) -> List[float]:
|
||||||
@@ -227,11 +225,7 @@ class MariaDBVectorClient(VectorDBBase):
|
|||||||
"""
|
"""
|
||||||
Return the MariaDB Vector distance function name for the configured strategy.
|
Return the MariaDB Vector distance function name for the configured strategy.
|
||||||
"""
|
"""
|
||||||
return (
|
return 'vec_distance_cosine' if self.distance_strategy == 'cosine' else 'vec_distance_euclidean'
|
||||||
"vec_distance_cosine"
|
|
||||||
if self.distance_strategy == "cosine"
|
|
||||||
else "vec_distance_euclidean"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _score_from_dist(self, dist: float) -> float:
|
def _score_from_dist(self, dist: float) -> float:
|
||||||
"""
|
"""
|
||||||
@@ -240,7 +234,7 @@ class MariaDBVectorClient(VectorDBBase):
|
|||||||
- cosine: score ~= 1 - cosine_distance, clamped to [0, 1]
|
- cosine: score ~= 1 - cosine_distance, clamped to [0, 1]
|
||||||
- euclidean: score = 1 / (1 + dist)
|
- euclidean: score = 1 / (1 + dist)
|
||||||
"""
|
"""
|
||||||
if self.distance_strategy == "cosine":
|
if self.distance_strategy == 'cosine':
|
||||||
score = 1.0 - dist
|
score = 1.0 - dist
|
||||||
if score < 0.0:
|
if score < 0.0:
|
||||||
score = 0.0
|
score = 0.0
|
||||||
@@ -260,48 +254,48 @@ class MariaDBVectorClient(VectorDBBase):
|
|||||||
- {"$or": [ ... ]}
|
- {"$or": [ ... ]}
|
||||||
"""
|
"""
|
||||||
if not expr or not isinstance(expr, dict):
|
if not expr or not isinstance(expr, dict):
|
||||||
return "", []
|
return '', []
|
||||||
|
|
||||||
if "$and" in expr:
|
if '$and' in expr:
|
||||||
parts: List[str] = []
|
parts: List[str] = []
|
||||||
params: List[Any] = []
|
params: List[Any] = []
|
||||||
for e in expr.get("$and") or []:
|
for e in expr.get('$and') or []:
|
||||||
s, p = self._build_filter_sql_qmark(e)
|
s, p = self._build_filter_sql_qmark(e)
|
||||||
if s:
|
if s:
|
||||||
parts.append(s)
|
parts.append(s)
|
||||||
params.extend(p)
|
params.extend(p)
|
||||||
return ("(" + " AND ".join(parts) + ")") if parts else "", params
|
return ('(' + ' AND '.join(parts) + ')') if parts else '', params
|
||||||
|
|
||||||
if "$or" in expr:
|
if '$or' in expr:
|
||||||
parts: List[str] = []
|
parts: List[str] = []
|
||||||
params: List[Any] = []
|
params: List[Any] = []
|
||||||
for e in expr.get("$or") or []:
|
for e in expr.get('$or') or []:
|
||||||
s, p = self._build_filter_sql_qmark(e)
|
s, p = self._build_filter_sql_qmark(e)
|
||||||
if s:
|
if s:
|
||||||
parts.append(s)
|
parts.append(s)
|
||||||
params.extend(p)
|
params.extend(p)
|
||||||
return ("(" + " OR ".join(parts) + ")") if parts else "", params
|
return ('(' + ' OR '.join(parts) + ')') if parts else '', params
|
||||||
|
|
||||||
clauses: List[str] = []
|
clauses: List[str] = []
|
||||||
params: List[Any] = []
|
params: List[Any] = []
|
||||||
for key, value in expr.items():
|
for key, value in expr.items():
|
||||||
if key.startswith("$"):
|
if key.startswith('$'):
|
||||||
continue
|
continue
|
||||||
json_expr = f"JSON_UNQUOTE(JSON_EXTRACT(vmetadata, '$.{key}'))"
|
json_expr = f"JSON_UNQUOTE(JSON_EXTRACT(vmetadata, '$.{key}'))"
|
||||||
if isinstance(value, dict) and "$in" in value:
|
if isinstance(value, dict) and '$in' in value:
|
||||||
vals = [str(v) for v in (value.get("$in") or [])]
|
vals = [str(v) for v in (value.get('$in') or [])]
|
||||||
if not vals:
|
if not vals:
|
||||||
clauses.append("0=1")
|
clauses.append('0=1')
|
||||||
continue
|
continue
|
||||||
ors = []
|
ors = []
|
||||||
for v in vals:
|
for v in vals:
|
||||||
ors.append(f"{json_expr} = ?")
|
ors.append(f'{json_expr} = ?')
|
||||||
params.append(v)
|
params.append(v)
|
||||||
clauses.append("(" + " OR ".join(ors) + ")")
|
clauses.append('(' + ' OR '.join(ors) + ')')
|
||||||
else:
|
else:
|
||||||
clauses.append(f"{json_expr} = ?")
|
clauses.append(f'{json_expr} = ?')
|
||||||
params.append(str(value))
|
params.append(str(value))
|
||||||
return ("(" + " AND ".join(clauses) + ")") if clauses else "", params
|
return ('(' + ' AND '.join(clauses) + ')') if clauses else '', params
|
||||||
|
|
||||||
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
|
def insert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -322,15 +316,15 @@ class MariaDBVectorClient(VectorDBBase):
|
|||||||
"""
|
"""
|
||||||
params: List[Tuple[Any, ...]] = []
|
params: List[Tuple[Any, ...]] = []
|
||||||
for item in items:
|
for item in items:
|
||||||
v = self.adjust_vector_length(item["vector"])
|
v = self.adjust_vector_length(item['vector'])
|
||||||
emb = _embedding_to_f32_bytes(v)
|
emb = _embedding_to_f32_bytes(v)
|
||||||
meta = process_metadata(item.get("metadata") or {})
|
meta = process_metadata(item.get('metadata') or {})
|
||||||
params.append(
|
params.append(
|
||||||
(
|
(
|
||||||
item["id"],
|
item['id'],
|
||||||
emb,
|
emb,
|
||||||
collection_name,
|
collection_name,
|
||||||
item.get("text"),
|
item.get('text'),
|
||||||
json.dumps(meta),
|
json.dumps(meta),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -338,7 +332,7 @@ class MariaDBVectorClient(VectorDBBase):
|
|||||||
conn.commit()
|
conn.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
log.exception(f"Error during insert: {e}")
|
log.exception(f'Error during insert: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
|
||||||
@@ -365,15 +359,15 @@ class MariaDBVectorClient(VectorDBBase):
|
|||||||
"""
|
"""
|
||||||
params: List[Tuple[Any, ...]] = []
|
params: List[Tuple[Any, ...]] = []
|
||||||
for item in items:
|
for item in items:
|
||||||
v = self.adjust_vector_length(item["vector"])
|
v = self.adjust_vector_length(item['vector'])
|
||||||
emb = _embedding_to_f32_bytes(v)
|
emb = _embedding_to_f32_bytes(v)
|
||||||
meta = process_metadata(item.get("metadata") or {})
|
meta = process_metadata(item.get('metadata') or {})
|
||||||
params.append(
|
params.append(
|
||||||
(
|
(
|
||||||
item["id"],
|
item['id'],
|
||||||
emb,
|
emb,
|
||||||
collection_name,
|
collection_name,
|
||||||
item.get("text"),
|
item.get('text'),
|
||||||
json.dumps(meta),
|
json.dumps(meta),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -381,7 +375,7 @@ class MariaDBVectorClient(VectorDBBase):
|
|||||||
conn.commit()
|
conn.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
log.exception(f"Error during upsert: {e}")
|
log.exception(f'Error during upsert: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
@@ -415,10 +409,10 @@ class MariaDBVectorClient(VectorDBBase):
|
|||||||
with self._connect() as conn:
|
with self._connect() as conn:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
fsql, fparams = self._build_filter_sql_qmark(filter or {})
|
fsql, fparams = self._build_filter_sql_qmark(filter or {})
|
||||||
where = "collection_name = ?"
|
where = 'collection_name = ?'
|
||||||
base_params: List[Any] = [collection_name]
|
base_params: List[Any] = [collection_name]
|
||||||
if fsql:
|
if fsql:
|
||||||
where = where + " AND " + fsql
|
where = where + ' AND ' + fsql
|
||||||
base_params.extend(fparams)
|
base_params.extend(fparams)
|
||||||
|
|
||||||
sql = f"""
|
sql = f"""
|
||||||
@@ -460,26 +454,24 @@ class MariaDBVectorClient(VectorDBBase):
|
|||||||
metadatas=metadatas,
|
metadatas=metadatas,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"[MARIADB_VECTOR] search() failed: {e}")
|
log.exception(f'[MARIADB_VECTOR] search() failed: {e}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def query(
|
def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
|
||||||
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
|
|
||||||
) -> Optional[GetResult]:
|
|
||||||
"""
|
"""
|
||||||
Retrieve documents by metadata filter (non-vector query).
|
Retrieve documents by metadata filter (non-vector query).
|
||||||
"""
|
"""
|
||||||
with self._connect() as conn:
|
with self._connect() as conn:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
fsql, fparams = self._build_filter_sql_qmark(filter or {})
|
fsql, fparams = self._build_filter_sql_qmark(filter or {})
|
||||||
where = "collection_name = ?"
|
where = 'collection_name = ?'
|
||||||
params: List[Any] = [collection_name]
|
params: List[Any] = [collection_name]
|
||||||
if fsql:
|
if fsql:
|
||||||
where = where + " AND " + fsql
|
where = where + ' AND ' + fsql
|
||||||
params.extend(fparams)
|
params.extend(fparams)
|
||||||
sql = f"SELECT id, text, vmetadata FROM document_chunk WHERE {where}"
|
sql = f'SELECT id, text, vmetadata FROM document_chunk WHERE {where}'
|
||||||
if limit is not None:
|
if limit is not None:
|
||||||
sql += " LIMIT ?"
|
sql += ' LIMIT ?'
|
||||||
params.append(int(limit))
|
params.append(int(limit))
|
||||||
cur.execute(sql, params)
|
cur.execute(sql, params)
|
||||||
rows = cur.fetchall()
|
rows = cur.fetchall()
|
||||||
@@ -490,18 +482,16 @@ class MariaDBVectorClient(VectorDBBase):
|
|||||||
metadatas = [[_safe_json(r[2]) for r in rows]]
|
metadatas = [[_safe_json(r[2]) for r in rows]]
|
||||||
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
return GetResult(ids=ids, documents=documents, metadatas=metadatas)
|
||||||
|
|
||||||
def get(
|
def get(self, collection_name: str, limit: Optional[int] = None) -> Optional[GetResult]:
|
||||||
self, collection_name: str, limit: Optional[int] = None
|
|
||||||
) -> Optional[GetResult]:
|
|
||||||
"""
|
"""
|
||||||
Retrieve documents in a collection without filtering (optionally limited).
|
Retrieve documents in a collection without filtering (optionally limited).
|
||||||
"""
|
"""
|
||||||
with self._connect() as conn:
|
with self._connect() as conn:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
sql = "SELECT id, text, vmetadata FROM document_chunk WHERE collection_name = ?"
|
sql = 'SELECT id, text, vmetadata FROM document_chunk WHERE collection_name = ?'
|
||||||
params: List[Any] = [collection_name]
|
params: List[Any] = [collection_name]
|
||||||
if limit is not None:
|
if limit is not None:
|
||||||
sql += " LIMIT ?"
|
sql += ' LIMIT ?'
|
||||||
params.append(int(limit))
|
params.append(int(limit))
|
||||||
cur.execute(sql, params)
|
cur.execute(sql, params)
|
||||||
rows = cur.fetchall()
|
rows = cur.fetchall()
|
||||||
@@ -526,12 +516,12 @@ class MariaDBVectorClient(VectorDBBase):
|
|||||||
with self._connect() as conn:
|
with self._connect() as conn:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
try:
|
try:
|
||||||
where = ["collection_name = ?"]
|
where = ['collection_name = ?']
|
||||||
params: List[Any] = [collection_name]
|
params: List[Any] = [collection_name]
|
||||||
|
|
||||||
if ids:
|
if ids:
|
||||||
ph = ", ".join(["?"] * len(ids))
|
ph = ', '.join(['?'] * len(ids))
|
||||||
where.append(f"id IN ({ph})")
|
where.append(f'id IN ({ph})')
|
||||||
params.extend(ids)
|
params.extend(ids)
|
||||||
|
|
||||||
if filter:
|
if filter:
|
||||||
@@ -540,12 +530,12 @@ class MariaDBVectorClient(VectorDBBase):
|
|||||||
where.append(fsql)
|
where.append(fsql)
|
||||||
params.extend(fparams)
|
params.extend(fparams)
|
||||||
|
|
||||||
sql = "DELETE FROM document_chunk WHERE " + " AND ".join(where)
|
sql = 'DELETE FROM document_chunk WHERE ' + ' AND '.join(where)
|
||||||
cur.execute(sql, params)
|
cur.execute(sql, params)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
log.exception(f"Error during delete: {e}")
|
log.exception(f'Error during delete: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
@@ -555,11 +545,11 @@ class MariaDBVectorClient(VectorDBBase):
|
|||||||
with self._connect() as conn:
|
with self._connect() as conn:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
try:
|
try:
|
||||||
cur.execute("TRUNCATE TABLE document_chunk")
|
cur.execute('TRUNCATE TABLE document_chunk')
|
||||||
conn.commit()
|
conn.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
log.exception(f"Error during reset: {e}")
|
log.exception(f'Error during reset: {e}')
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def has_collection(self, collection_name: str) -> bool:
|
def has_collection(self, collection_name: str) -> bool:
|
||||||
@@ -570,7 +560,7 @@ class MariaDBVectorClient(VectorDBBase):
|
|||||||
with self._connect() as conn:
|
with self._connect() as conn:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"SELECT 1 FROM document_chunk WHERE collection_name = ? LIMIT 1",
|
'SELECT 1 FROM document_chunk WHERE collection_name = ? LIMIT 1',
|
||||||
(collection_name,),
|
(collection_name,),
|
||||||
)
|
)
|
||||||
return cur.fetchone() is not None
|
return cur.fetchone() is not None
|
||||||
@@ -590,4 +580,4 @@ class MariaDBVectorClient(VectorDBBase):
|
|||||||
try:
|
try:
|
||||||
self.engine.dispose()
|
self.engine.dispose()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Error during dispose the underlying SQLAlchemy engine: {e}")
|
log.exception(f'Error during dispose the underlying SQLAlchemy engine: {e}')
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class MilvusClient(VectorDBBase):
|
class MilvusClient(VectorDBBase):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.collection_prefix = "open_webui"
|
self.collection_prefix = 'open_webui'
|
||||||
if MILVUS_TOKEN is None:
|
if MILVUS_TOKEN is None:
|
||||||
self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
|
self.client = Client(uri=MILVUS_URI, db_name=MILVUS_DB)
|
||||||
else:
|
else:
|
||||||
@@ -50,17 +50,17 @@ class MilvusClient(VectorDBBase):
|
|||||||
_documents = []
|
_documents = []
|
||||||
_metadatas = []
|
_metadatas = []
|
||||||
for item in match:
|
for item in match:
|
||||||
_ids.append(item.get("id"))
|
_ids.append(item.get('id'))
|
||||||
_documents.append(item.get("data", {}).get("text"))
|
_documents.append(item.get('data', {}).get('text'))
|
||||||
_metadatas.append(item.get("metadata"))
|
_metadatas.append(item.get('metadata'))
|
||||||
ids.append(_ids)
|
ids.append(_ids)
|
||||||
documents.append(_documents)
|
documents.append(_documents)
|
||||||
metadatas.append(_metadatas)
|
metadatas.append(_metadatas)
|
||||||
return GetResult(
|
return GetResult(
|
||||||
**{
|
**{
|
||||||
"ids": ids,
|
'ids': ids,
|
||||||
"documents": documents,
|
'documents': documents,
|
||||||
"metadatas": metadatas,
|
'metadatas': metadatas,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -75,23 +75,23 @@ class MilvusClient(VectorDBBase):
|
|||||||
_documents = []
|
_documents = []
|
||||||
_metadatas = []
|
_metadatas = []
|
||||||
for item in match:
|
for item in match:
|
||||||
_ids.append(item.get("id"))
|
_ids.append(item.get('id'))
|
||||||
# normalize milvus score from [-1, 1] to [0, 1] range
|
# normalize milvus score from [-1, 1] to [0, 1] range
|
||||||
# https://milvus.io/docs/de/metric.md
|
# https://milvus.io/docs/de/metric.md
|
||||||
_dist = (item.get("distance") + 1.0) / 2.0
|
_dist = (item.get('distance') + 1.0) / 2.0
|
||||||
_distances.append(_dist)
|
_distances.append(_dist)
|
||||||
_documents.append(item.get("entity", {}).get("data", {}).get("text"))
|
_documents.append(item.get('entity', {}).get('data', {}).get('text'))
|
||||||
_metadatas.append(item.get("entity", {}).get("metadata"))
|
_metadatas.append(item.get('entity', {}).get('metadata'))
|
||||||
ids.append(_ids)
|
ids.append(_ids)
|
||||||
distances.append(_distances)
|
distances.append(_distances)
|
||||||
documents.append(_documents)
|
documents.append(_documents)
|
||||||
metadatas.append(_metadatas)
|
metadatas.append(_metadatas)
|
||||||
return SearchResult(
|
return SearchResult(
|
||||||
**{
|
**{
|
||||||
"ids": ids,
|
'ids': ids,
|
||||||
"distances": distances,
|
'distances': distances,
|
||||||
"documents": documents,
|
'documents': documents,
|
||||||
"metadatas": metadatas,
|
'metadatas': metadatas,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -101,21 +101,19 @@ class MilvusClient(VectorDBBase):
|
|||||||
enable_dynamic_field=True,
|
enable_dynamic_field=True,
|
||||||
)
|
)
|
||||||
schema.add_field(
|
schema.add_field(
|
||||||
field_name="id",
|
field_name='id',
|
||||||
datatype=DataType.VARCHAR,
|
datatype=DataType.VARCHAR,
|
||||||
is_primary=True,
|
is_primary=True,
|
||||||
max_length=65535,
|
max_length=65535,
|
||||||
)
|
)
|
||||||
schema.add_field(
|
schema.add_field(
|
||||||
field_name="vector",
|
field_name='vector',
|
||||||
datatype=DataType.FLOAT_VECTOR,
|
datatype=DataType.FLOAT_VECTOR,
|
||||||
dim=dimension,
|
dim=dimension,
|
||||||
description="vector",
|
description='vector',
|
||||||
)
|
|
||||||
schema.add_field(field_name="data", datatype=DataType.JSON, description="data")
|
|
||||||
schema.add_field(
|
|
||||||
field_name="metadata", datatype=DataType.JSON, description="metadata"
|
|
||||||
)
|
)
|
||||||
|
schema.add_field(field_name='data', datatype=DataType.JSON, description='data')
|
||||||
|
schema.add_field(field_name='metadata', datatype=DataType.JSON, description='metadata')
|
||||||
|
|
||||||
index_params = self.client.prepare_index_params()
|
index_params = self.client.prepare_index_params()
|
||||||
|
|
||||||
@@ -123,44 +121,44 @@ class MilvusClient(VectorDBBase):
|
|||||||
index_type = MILVUS_INDEX_TYPE.upper()
|
index_type = MILVUS_INDEX_TYPE.upper()
|
||||||
metric_type = MILVUS_METRIC_TYPE.upper()
|
metric_type = MILVUS_METRIC_TYPE.upper()
|
||||||
|
|
||||||
log.info(f"Using Milvus index type: {index_type}, metric type: {metric_type}")
|
log.info(f'Using Milvus index type: {index_type}, metric type: {metric_type}')
|
||||||
|
|
||||||
index_creation_params = {}
|
index_creation_params = {}
|
||||||
if index_type == "HNSW":
|
if index_type == 'HNSW':
|
||||||
index_creation_params = {
|
index_creation_params = {
|
||||||
"M": MILVUS_HNSW_M,
|
'M': MILVUS_HNSW_M,
|
||||||
"efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
|
'efConstruction': MILVUS_HNSW_EFCONSTRUCTION,
|
||||||
}
|
}
|
||||||
log.info(f"HNSW params: {index_creation_params}")
|
log.info(f'HNSW params: {index_creation_params}')
|
||||||
elif index_type == "IVF_FLAT":
|
elif index_type == 'IVF_FLAT':
|
||||||
index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST}
|
index_creation_params = {'nlist': MILVUS_IVF_FLAT_NLIST}
|
||||||
log.info(f"IVF_FLAT params: {index_creation_params}")
|
log.info(f'IVF_FLAT params: {index_creation_params}')
|
||||||
elif index_type == "DISKANN":
|
elif index_type == 'DISKANN':
|
||||||
index_creation_params = {
|
index_creation_params = {
|
||||||
"max_degree": MILVUS_DISKANN_MAX_DEGREE,
|
'max_degree': MILVUS_DISKANN_MAX_DEGREE,
|
||||||
"search_list_size": MILVUS_DISKANN_SEARCH_LIST_SIZE,
|
'search_list_size': MILVUS_DISKANN_SEARCH_LIST_SIZE,
|
||||||
}
|
}
|
||||||
log.info(f"DISKANN params: {index_creation_params}")
|
log.info(f'DISKANN params: {index_creation_params}')
|
||||||
elif index_type in ["FLAT", "AUTOINDEX"]:
|
elif index_type in ['FLAT', 'AUTOINDEX']:
|
||||||
log.info(f"Using {index_type} index with no specific build-time params.")
|
log.info(f'Using {index_type} index with no specific build-time params.')
|
||||||
else:
|
else:
|
||||||
log.warning(
|
log.warning(
|
||||||
f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
|
f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. "
|
||||||
f"Supported types: HNSW, IVF_FLAT, DISKANN, FLAT, AUTOINDEX. "
|
f'Supported types: HNSW, IVF_FLAT, DISKANN, FLAT, AUTOINDEX. '
|
||||||
f"Milvus will use its default for the collection if this type is not directly supported for index creation."
|
f'Milvus will use its default for the collection if this type is not directly supported for index creation.'
|
||||||
)
|
)
|
||||||
# For unsupported types, pass the type directly to Milvus; it might handle it or use a default.
|
# For unsupported types, pass the type directly to Milvus; it might handle it or use a default.
|
||||||
# If Milvus errors out, the user needs to correct the MILVUS_INDEX_TYPE env var.
|
# If Milvus errors out, the user needs to correct the MILVUS_INDEX_TYPE env var.
|
||||||
|
|
||||||
index_params.add_index(
|
index_params.add_index(
|
||||||
field_name="vector",
|
field_name='vector',
|
||||||
index_type=index_type,
|
index_type=index_type,
|
||||||
metric_type=metric_type,
|
metric_type=metric_type,
|
||||||
params=index_creation_params,
|
params=index_creation_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.client.create_collection(
|
self.client.create_collection(
|
||||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||||
schema=schema,
|
schema=schema,
|
||||||
index_params=index_params,
|
index_params=index_params,
|
||||||
)
|
)
|
||||||
@@ -170,17 +168,13 @@ class MilvusClient(VectorDBBase):
|
|||||||
|
|
||||||
def has_collection(self, collection_name: str) -> bool:
|
def has_collection(self, collection_name: str) -> bool:
|
||||||
# Check if the collection exists based on the collection name.
|
# Check if the collection exists based on the collection name.
|
||||||
collection_name = collection_name.replace("-", "_")
|
collection_name = collection_name.replace('-', '_')
|
||||||
return self.client.has_collection(
|
return self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}')
|
||||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete_collection(self, collection_name: str):
|
def delete_collection(self, collection_name: str):
|
||||||
# Delete the collection based on the collection name.
|
# Delete the collection based on the collection name.
|
||||||
collection_name = collection_name.replace("-", "_")
|
collection_name = collection_name.replace('-', '_')
|
||||||
return self.client.drop_collection(
|
return self.client.drop_collection(collection_name=f'{self.collection_prefix}_{collection_name}')
|
||||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
@@ -190,15 +184,15 @@ class MilvusClient(VectorDBBase):
|
|||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
) -> Optional[SearchResult]:
|
) -> Optional[SearchResult]:
|
||||||
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
# Search for the nearest neighbor items based on the vectors and return 'limit' number of results.
|
||||||
collection_name = collection_name.replace("-", "_")
|
collection_name = collection_name.replace('-', '_')
|
||||||
# For some index types like IVF_FLAT, search params like nprobe can be set.
|
# For some index types like IVF_FLAT, search params like nprobe can be set.
|
||||||
# Example: search_params = {"nprobe": 10} if using IVF_FLAT
|
# Example: search_params = {"nprobe": 10} if using IVF_FLAT
|
||||||
# For simplicity, not adding configurable search_params here, but could be extended.
|
# For simplicity, not adding configurable search_params here, but could be extended.
|
||||||
result = self.client.search(
|
result = self.client.search(
|
||||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||||
data=vectors,
|
data=vectors,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
output_fields=["data", "metadata"],
|
output_fields=['data', 'metadata'],
|
||||||
# search_params=search_params # Potentially add later if needed
|
# search_params=search_params # Potentially add later if needed
|
||||||
)
|
)
|
||||||
return self._result_to_search_result(result)
|
return self._result_to_search_result(result)
|
||||||
@@ -206,11 +200,9 @@ class MilvusClient(VectorDBBase):
|
|||||||
def query(self, collection_name: str, filter: dict, limit: int = -1):
|
def query(self, collection_name: str, filter: dict, limit: int = -1):
|
||||||
connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB)
|
connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB)
|
||||||
|
|
||||||
collection_name = collection_name.replace("-", "_")
|
collection_name = collection_name.replace('-', '_')
|
||||||
if not self.has_collection(collection_name):
|
if not self.has_collection(collection_name):
|
||||||
log.warning(
|
log.warning(f'Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}')
|
||||||
f"Query attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
filter_expressions = []
|
filter_expressions = []
|
||||||
@@ -220,9 +212,9 @@ class MilvusClient(VectorDBBase):
|
|||||||
else:
|
else:
|
||||||
filter_expressions.append(f'metadata["{key}"] == {value}')
|
filter_expressions.append(f'metadata["{key}"] == {value}')
|
||||||
|
|
||||||
filter_string = " && ".join(filter_expressions)
|
filter_string = ' && '.join(filter_expressions)
|
||||||
|
|
||||||
collection = Collection(f"{self.collection_prefix}_{collection_name}")
|
collection = Collection(f'{self.collection_prefix}_{collection_name}')
|
||||||
collection.load()
|
collection.load()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -233,9 +225,9 @@ class MilvusClient(VectorDBBase):
|
|||||||
iterator = collection.query_iterator(
|
iterator = collection.query_iterator(
|
||||||
expr=filter_string,
|
expr=filter_string,
|
||||||
output_fields=[
|
output_fields=[
|
||||||
"id",
|
'id',
|
||||||
"data",
|
'data',
|
||||||
"metadata",
|
'metadata',
|
||||||
],
|
],
|
||||||
limit=limit if limit > 0 else -1,
|
limit=limit if limit > 0 else -1,
|
||||||
)
|
)
|
||||||
@@ -248,7 +240,7 @@ class MilvusClient(VectorDBBase):
|
|||||||
break
|
break
|
||||||
all_results.extend(batch)
|
all_results.extend(batch)
|
||||||
|
|
||||||
log.debug(f"Total results from query: {len(all_results)}")
|
log.debug(f'Total results from query: {len(all_results)}')
|
||||||
return self._result_to_get_result([all_results] if all_results else [[]])
|
return self._result_to_get_result([all_results] if all_results else [[]])
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -259,7 +251,7 @@ class MilvusClient(VectorDBBase):
|
|||||||
|
|
||||||
def get(self, collection_name: str) -> Optional[GetResult]:
|
def get(self, collection_name: str) -> Optional[GetResult]:
|
||||||
# Get all the items in the collection. This can be very resource-intensive for large collections.
|
# Get all the items in the collection. This can be very resource-intensive for large collections.
|
||||||
collection_name = collection_name.replace("-", "_")
|
collection_name = collection_name.replace('-', '_')
|
||||||
log.warning(
|
log.warning(
|
||||||
f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections."
|
f"Fetching ALL items from collection '{self.collection_prefix}_{collection_name}'. This might be slow for large collections."
|
||||||
)
|
)
|
||||||
@@ -269,35 +261,25 @@ class MilvusClient(VectorDBBase):
|
|||||||
|
|
||||||
def insert(self, collection_name: str, items: list[VectorItem]):
|
def insert(self, collection_name: str, items: list[VectorItem]):
|
||||||
# Insert the items into the collection, if the collection does not exist, it will be created.
|
# Insert the items into the collection, if the collection does not exist, it will be created.
|
||||||
collection_name = collection_name.replace("-", "_")
|
collection_name = collection_name.replace('-', '_')
|
||||||
if not self.client.has_collection(
|
if not self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}'):
|
||||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
log.info(f'Collection {self.collection_prefix}_{collection_name} does not exist. Creating now.')
|
||||||
):
|
|
||||||
log.info(
|
|
||||||
f"Collection {self.collection_prefix}_{collection_name} does not exist. Creating now."
|
|
||||||
)
|
|
||||||
if not items:
|
if not items:
|
||||||
log.error(
|
log.error(
|
||||||
f"Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension."
|
f'Cannot create collection {self.collection_prefix}_{collection_name} without items to determine dimension.'
|
||||||
)
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot create Milvus collection without items to determine vector dimension."
|
|
||||||
)
|
|
||||||
self._create_collection(
|
|
||||||
collection_name=collection_name, dimension=len(items[0]["vector"])
|
|
||||||
)
|
)
|
||||||
|
raise ValueError('Cannot create Milvus collection without items to determine vector dimension.')
|
||||||
|
self._create_collection(collection_name=collection_name, dimension=len(items[0]['vector']))
|
||||||
|
|
||||||
log.info(
|
log.info(f'Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.')
|
||||||
f"Inserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
|
|
||||||
)
|
|
||||||
return self.client.insert(
|
return self.client.insert(
|
||||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||||
data=[
|
data=[
|
||||||
{
|
{
|
||||||
"id": item["id"],
|
'id': item['id'],
|
||||||
"vector": item["vector"],
|
'vector': item['vector'],
|
||||||
"data": {"text": item["text"]},
|
'data': {'text': item['text']},
|
||||||
"metadata": process_metadata(item["metadata"]),
|
'metadata': process_metadata(item['metadata']),
|
||||||
}
|
}
|
||||||
for item in items
|
for item in items
|
||||||
],
|
],
|
||||||
@@ -305,35 +287,27 @@ class MilvusClient(VectorDBBase):
|
|||||||
|
|
||||||
def upsert(self, collection_name: str, items: list[VectorItem]):
|
def upsert(self, collection_name: str, items: list[VectorItem]):
|
||||||
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
|
# Update the items in the collection, if the items are not present, insert them. If the collection does not exist, it will be created.
|
||||||
collection_name = collection_name.replace("-", "_")
|
collection_name = collection_name.replace('-', '_')
|
||||||
if not self.client.has_collection(
|
if not self.client.has_collection(collection_name=f'{self.collection_prefix}_{collection_name}'):
|
||||||
collection_name=f"{self.collection_prefix}_{collection_name}"
|
log.info(f'Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now.')
|
||||||
):
|
|
||||||
log.info(
|
|
||||||
f"Collection {self.collection_prefix}_{collection_name} does not exist for upsert. Creating now."
|
|
||||||
)
|
|
||||||
if not items:
|
if not items:
|
||||||
log.error(
|
log.error(
|
||||||
f"Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension."
|
f'Cannot create collection {self.collection_prefix}_{collection_name} for upsert without items to determine dimension.'
|
||||||
)
|
)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cannot create Milvus collection for upsert without items to determine vector dimension."
|
'Cannot create Milvus collection for upsert without items to determine vector dimension.'
|
||||||
)
|
|
||||||
self._create_collection(
|
|
||||||
collection_name=collection_name, dimension=len(items[0]["vector"])
|
|
||||||
)
|
)
|
||||||
|
self._create_collection(collection_name=collection_name, dimension=len(items[0]['vector']))
|
||||||
|
|
||||||
log.info(
|
log.info(f'Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}.')
|
||||||
f"Upserting {len(items)} items into collection {self.collection_prefix}_{collection_name}."
|
|
||||||
)
|
|
||||||
return self.client.upsert(
|
return self.client.upsert(
|
||||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||||
data=[
|
data=[
|
||||||
{
|
{
|
||||||
"id": item["id"],
|
'id': item['id'],
|
||||||
"vector": item["vector"],
|
'vector': item['vector'],
|
||||||
"data": {"text": item["text"]},
|
'data': {'text': item['text']},
|
||||||
"metadata": process_metadata(item["metadata"]),
|
'metadata': process_metadata(item['metadata']),
|
||||||
}
|
}
|
||||||
for item in items
|
for item in items
|
||||||
],
|
],
|
||||||
@@ -346,46 +320,35 @@ class MilvusClient(VectorDBBase):
|
|||||||
filter: Optional[dict] = None,
|
filter: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
# Delete the items from the collection based on the ids or filter.
|
# Delete the items from the collection based on the ids or filter.
|
||||||
collection_name = collection_name.replace("-", "_")
|
collection_name = collection_name.replace('-', '_')
|
||||||
if not self.has_collection(collection_name):
|
if not self.has_collection(collection_name):
|
||||||
log.warning(
|
log.warning(f'Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}')
|
||||||
f"Delete attempted on non-existent collection: {self.collection_prefix}_{collection_name}"
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if ids:
|
if ids:
|
||||||
log.info(
|
log.info(f'Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}')
|
||||||
f"Deleting items by IDs from {self.collection_prefix}_{collection_name}. IDs: {ids}"
|
|
||||||
)
|
|
||||||
return self.client.delete(
|
return self.client.delete(
|
||||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||||
ids=ids,
|
ids=ids,
|
||||||
)
|
)
|
||||||
elif filter:
|
elif filter:
|
||||||
filter_string = " && ".join(
|
filter_string = ' && '.join([f'metadata["{key}"] == {json.dumps(value)}' for key, value in filter.items()])
|
||||||
[
|
|
||||||
f'metadata["{key}"] == {json.dumps(value)}'
|
|
||||||
for key, value in filter.items()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
log.info(
|
log.info(
|
||||||
f"Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}"
|
f'Deleting items by filter from {self.collection_prefix}_{collection_name}. Filter: {filter_string}'
|
||||||
)
|
)
|
||||||
return self.client.delete(
|
return self.client.delete(
|
||||||
collection_name=f"{self.collection_prefix}_{collection_name}",
|
collection_name=f'{self.collection_prefix}_{collection_name}',
|
||||||
filter=filter_string,
|
filter=filter_string,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
log.warning(
|
log.warning(
|
||||||
f"Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken."
|
f'Delete operation on {self.collection_prefix}_{collection_name} called without IDs or filter. No action taken.'
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
# Resets the database. This will delete all collections and item entries that match the prefix.
|
# Resets the database. This will delete all collections and item entries that match the prefix.
|
||||||
log.warning(
|
log.warning(f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'.")
|
||||||
f"Resetting Milvus: Deleting all collections with prefix '{self.collection_prefix}'."
|
|
||||||
)
|
|
||||||
collection_names = self.client.list_collections()
|
collection_names = self.client.list_collections()
|
||||||
deleted_collections = []
|
deleted_collections = []
|
||||||
for collection_name_full in collection_names:
|
for collection_name_full in collection_names:
|
||||||
@@ -393,7 +356,7 @@ class MilvusClient(VectorDBBase):
|
|||||||
try:
|
try:
|
||||||
self.client.drop_collection(collection_name=collection_name_full)
|
self.client.drop_collection(collection_name=collection_name_full)
|
||||||
deleted_collections.append(collection_name_full)
|
deleted_collections.append(collection_name_full)
|
||||||
log.info(f"Deleted collection: {collection_name_full}")
|
log.info(f'Deleted collection: {collection_name_full}')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"Error deleting collection {collection_name_full}: {e}")
|
log.error(f'Error deleting collection {collection_name_full}: {e}')
|
||||||
log.info(f"Milvus reset complete. Deleted collections: {deleted_collections}")
|
log.info(f'Milvus reset complete. Deleted collections: {deleted_collections}')
|
||||||
|
|||||||
@@ -33,26 +33,26 @@ from pymilvus import (
|
|||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
RESOURCE_ID_FIELD = "resource_id"
|
RESOURCE_ID_FIELD = 'resource_id'
|
||||||
|
|
||||||
|
|
||||||
class MilvusClient(VectorDBBase):
|
class MilvusClient(VectorDBBase):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Milvus collection names can only contain numbers, letters, and underscores.
|
# Milvus collection names can only contain numbers, letters, and underscores.
|
||||||
self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_")
|
self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace('-', '_')
|
||||||
connections.connect(
|
connections.connect(
|
||||||
alias="default",
|
alias='default',
|
||||||
uri=MILVUS_URI,
|
uri=MILVUS_URI,
|
||||||
token=MILVUS_TOKEN,
|
token=MILVUS_TOKEN,
|
||||||
db_name=MILVUS_DB,
|
db_name=MILVUS_DB,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Main collection types for multi-tenancy
|
# Main collection types for multi-tenancy
|
||||||
self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
|
self.MEMORY_COLLECTION = f'{self.collection_prefix}_memories'
|
||||||
self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
|
self.KNOWLEDGE_COLLECTION = f'{self.collection_prefix}_knowledge'
|
||||||
self.FILE_COLLECTION = f"{self.collection_prefix}_files"
|
self.FILE_COLLECTION = f'{self.collection_prefix}_files'
|
||||||
self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search"
|
self.WEB_SEARCH_COLLECTION = f'{self.collection_prefix}_web_search'
|
||||||
self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based"
|
self.HASH_BASED_COLLECTION = f'{self.collection_prefix}_hash_based'
|
||||||
self.shared_collections = [
|
self.shared_collections = [
|
||||||
self.MEMORY_COLLECTION,
|
self.MEMORY_COLLECTION,
|
||||||
self.KNOWLEDGE_COLLECTION,
|
self.KNOWLEDGE_COLLECTION,
|
||||||
@@ -74,15 +74,13 @@ class MilvusClient(VectorDBBase):
|
|||||||
"""
|
"""
|
||||||
resource_id = collection_name
|
resource_id = collection_name
|
||||||
|
|
||||||
if collection_name.startswith("user-memory-"):
|
if collection_name.startswith('user-memory-'):
|
||||||
return self.MEMORY_COLLECTION, resource_id
|
return self.MEMORY_COLLECTION, resource_id
|
||||||
elif collection_name.startswith("file-"):
|
elif collection_name.startswith('file-'):
|
||||||
return self.FILE_COLLECTION, resource_id
|
return self.FILE_COLLECTION, resource_id
|
||||||
elif collection_name.startswith("web-search-"):
|
elif collection_name.startswith('web-search-'):
|
||||||
return self.WEB_SEARCH_COLLECTION, resource_id
|
return self.WEB_SEARCH_COLLECTION, resource_id
|
||||||
elif len(collection_name) == 63 and all(
|
elif len(collection_name) == 63 and all(c in '0123456789abcdef' for c in collection_name):
|
||||||
c in "0123456789abcdef" for c in collection_name
|
|
||||||
):
|
|
||||||
return self.HASH_BASED_COLLECTION, resource_id
|
return self.HASH_BASED_COLLECTION, resource_id
|
||||||
else:
|
else:
|
||||||
return self.KNOWLEDGE_COLLECTION, resource_id
|
return self.KNOWLEDGE_COLLECTION, resource_id
|
||||||
@@ -90,36 +88,36 @@ class MilvusClient(VectorDBBase):
|
|||||||
def _create_shared_collection(self, mt_collection_name: str, dimension: int):
|
def _create_shared_collection(self, mt_collection_name: str, dimension: int):
|
||||||
fields = [
|
fields = [
|
||||||
FieldSchema(
|
FieldSchema(
|
||||||
name="id",
|
name='id',
|
||||||
dtype=DataType.VARCHAR,
|
dtype=DataType.VARCHAR,
|
||||||
is_primary=True,
|
is_primary=True,
|
||||||
auto_id=False,
|
auto_id=False,
|
||||||
max_length=36,
|
max_length=36,
|
||||||
),
|
),
|
||||||
FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
|
FieldSchema(name='vector', dtype=DataType.FLOAT_VECTOR, dim=dimension),
|
||||||
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
|
FieldSchema(name='text', dtype=DataType.VARCHAR, max_length=65535),
|
||||||
FieldSchema(name="metadata", dtype=DataType.JSON),
|
FieldSchema(name='metadata', dtype=DataType.JSON),
|
||||||
FieldSchema(name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255),
|
FieldSchema(name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255),
|
||||||
]
|
]
|
||||||
schema = CollectionSchema(fields, "Shared collection for multi-tenancy")
|
schema = CollectionSchema(fields, 'Shared collection for multi-tenancy')
|
||||||
collection = Collection(mt_collection_name, schema)
|
collection = Collection(mt_collection_name, schema)
|
||||||
|
|
||||||
index_params = {
|
index_params = {
|
||||||
"metric_type": MILVUS_METRIC_TYPE,
|
'metric_type': MILVUS_METRIC_TYPE,
|
||||||
"index_type": MILVUS_INDEX_TYPE,
|
'index_type': MILVUS_INDEX_TYPE,
|
||||||
"params": {},
|
'params': {},
|
||||||
}
|
}
|
||||||
if MILVUS_INDEX_TYPE == "HNSW":
|
if MILVUS_INDEX_TYPE == 'HNSW':
|
||||||
index_params["params"] = {
|
index_params['params'] = {
|
||||||
"M": MILVUS_HNSW_M,
|
'M': MILVUS_HNSW_M,
|
||||||
"efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
|
'efConstruction': MILVUS_HNSW_EFCONSTRUCTION,
|
||||||
}
|
}
|
||||||
elif MILVUS_INDEX_TYPE == "IVF_FLAT":
|
elif MILVUS_INDEX_TYPE == 'IVF_FLAT':
|
||||||
index_params["params"] = {"nlist": MILVUS_IVF_FLAT_NLIST}
|
index_params['params'] = {'nlist': MILVUS_IVF_FLAT_NLIST}
|
||||||
|
|
||||||
collection.create_index("vector", index_params)
|
collection.create_index('vector', index_params)
|
||||||
collection.create_index(RESOURCE_ID_FIELD)
|
collection.create_index(RESOURCE_ID_FIELD)
|
||||||
log.info(f"Created shared collection: {mt_collection_name}")
|
log.info(f'Created shared collection: {mt_collection_name}')
|
||||||
return collection
|
return collection
|
||||||
|
|
||||||
def _ensure_collection(self, mt_collection_name: str, dimension: int):
|
def _ensure_collection(self, mt_collection_name: str, dimension: int):
|
||||||
@@ -127,9 +125,7 @@ class MilvusClient(VectorDBBase):
|
|||||||
self._create_shared_collection(mt_collection_name, dimension)
|
self._create_shared_collection(mt_collection_name, dimension)
|
||||||
|
|
||||||
def has_collection(self, collection_name: str) -> bool:
|
def has_collection(self, collection_name: str) -> bool:
|
||||||
mt_collection, resource_id = self._get_collection_and_resource_id(
|
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
|
||||||
collection_name
|
|
||||||
)
|
|
||||||
if not utility.has_collection(mt_collection):
|
if not utility.has_collection(mt_collection):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -141,19 +137,17 @@ class MilvusClient(VectorDBBase):
|
|||||||
def upsert(self, collection_name: str, items: List[VectorItem]):
|
def upsert(self, collection_name: str, items: List[VectorItem]):
|
||||||
if not items:
|
if not items:
|
||||||
return
|
return
|
||||||
mt_collection, resource_id = self._get_collection_and_resource_id(
|
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
|
||||||
collection_name
|
dimension = len(items[0]['vector'])
|
||||||
)
|
|
||||||
dimension = len(items[0]["vector"])
|
|
||||||
self._ensure_collection(mt_collection, dimension)
|
self._ensure_collection(mt_collection, dimension)
|
||||||
collection = Collection(mt_collection)
|
collection = Collection(mt_collection)
|
||||||
|
|
||||||
entities = [
|
entities = [
|
||||||
{
|
{
|
||||||
"id": item["id"],
|
'id': item['id'],
|
||||||
"vector": item["vector"],
|
'vector': item['vector'],
|
||||||
"text": item["text"],
|
'text': item['text'],
|
||||||
"metadata": item["metadata"],
|
'metadata': item['metadata'],
|
||||||
RESOURCE_ID_FIELD: resource_id,
|
RESOURCE_ID_FIELD: resource_id,
|
||||||
}
|
}
|
||||||
for item in items
|
for item in items
|
||||||
@@ -170,41 +164,37 @@ class MilvusClient(VectorDBBase):
|
|||||||
if not vectors:
|
if not vectors:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
mt_collection, resource_id = self._get_collection_and_resource_id(
|
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
|
||||||
collection_name
|
|
||||||
)
|
|
||||||
if not utility.has_collection(mt_collection):
|
if not utility.has_collection(mt_collection):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
collection = Collection(mt_collection)
|
collection = Collection(mt_collection)
|
||||||
collection.load()
|
collection.load()
|
||||||
|
|
||||||
search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}}
|
search_params = {'metric_type': MILVUS_METRIC_TYPE, 'params': {}}
|
||||||
results = collection.search(
|
results = collection.search(
|
||||||
data=vectors,
|
data=vectors,
|
||||||
anns_field="vector",
|
anns_field='vector',
|
||||||
param=search_params,
|
param=search_params,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'",
|
expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'",
|
||||||
output_fields=["id", "text", "metadata"],
|
output_fields=['id', 'text', 'metadata'],
|
||||||
)
|
)
|
||||||
|
|
||||||
ids, documents, metadatas, distances = [], [], [], []
|
ids, documents, metadatas, distances = [], [], [], []
|
||||||
for hits in results:
|
for hits in results:
|
||||||
batch_ids, batch_docs, batch_metadatas, batch_dists = [], [], [], []
|
batch_ids, batch_docs, batch_metadatas, batch_dists = [], [], [], []
|
||||||
for hit in hits:
|
for hit in hits:
|
||||||
batch_ids.append(hit.entity.get("id"))
|
batch_ids.append(hit.entity.get('id'))
|
||||||
batch_docs.append(hit.entity.get("text"))
|
batch_docs.append(hit.entity.get('text'))
|
||||||
batch_metadatas.append(hit.entity.get("metadata"))
|
batch_metadatas.append(hit.entity.get('metadata'))
|
||||||
batch_dists.append(hit.distance)
|
batch_dists.append(hit.distance)
|
||||||
ids.append(batch_ids)
|
ids.append(batch_ids)
|
||||||
documents.append(batch_docs)
|
documents.append(batch_docs)
|
||||||
metadatas.append(batch_metadatas)
|
metadatas.append(batch_metadatas)
|
||||||
distances.append(batch_dists)
|
distances.append(batch_dists)
|
||||||
|
|
||||||
return SearchResult(
|
return SearchResult(ids=ids, documents=documents, metadatas=metadatas, distances=distances)
|
||||||
ids=ids, documents=documents, metadatas=metadatas, distances=distances
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete(
|
def delete(
|
||||||
self,
|
self,
|
||||||
@@ -212,9 +202,7 @@ class MilvusClient(VectorDBBase):
|
|||||||
ids: Optional[List[str]] = None,
|
ids: Optional[List[str]] = None,
|
||||||
filter: Optional[Dict[str, Any]] = None,
|
filter: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
mt_collection, resource_id = self._get_collection_and_resource_id(
|
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
|
||||||
collection_name
|
|
||||||
)
|
|
||||||
if not utility.has_collection(mt_collection):
|
if not utility.has_collection(mt_collection):
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -224,14 +212,14 @@ class MilvusClient(VectorDBBase):
|
|||||||
expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
|
expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
|
||||||
if ids:
|
if ids:
|
||||||
# Milvus expects a string list for 'in' operator
|
# Milvus expects a string list for 'in' operator
|
||||||
id_list_str = ", ".join([f"'{id_val}'" for id_val in ids])
|
id_list_str = ', '.join([f"'{id_val}'" for id_val in ids])
|
||||||
expr.append(f"id in [{id_list_str}]")
|
expr.append(f'id in [{id_list_str}]')
|
||||||
|
|
||||||
if filter:
|
if filter:
|
||||||
for key, value in filter.items():
|
for key, value in filter.items():
|
||||||
expr.append(f"metadata['{key}'] == '{value}'")
|
expr.append(f"metadata['{key}'] == '{value}'")
|
||||||
|
|
||||||
collection.delete(" and ".join(expr))
|
collection.delete(' and '.join(expr))
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
for collection_name in self.shared_collections:
|
for collection_name in self.shared_collections:
|
||||||
@@ -239,21 +227,15 @@ class MilvusClient(VectorDBBase):
|
|||||||
utility.drop_collection(collection_name)
|
utility.drop_collection(collection_name)
|
||||||
|
|
||||||
def delete_collection(self, collection_name: str):
|
def delete_collection(self, collection_name: str):
|
||||||
mt_collection, resource_id = self._get_collection_and_resource_id(
|
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
|
||||||
collection_name
|
|
||||||
)
|
|
||||||
if not utility.has_collection(mt_collection):
|
if not utility.has_collection(mt_collection):
|
||||||
return
|
return
|
||||||
|
|
||||||
collection = Collection(mt_collection)
|
collection = Collection(mt_collection)
|
||||||
collection.delete(f"{RESOURCE_ID_FIELD} == '{resource_id}'")
|
collection.delete(f"{RESOURCE_ID_FIELD} == '{resource_id}'")
|
||||||
|
|
||||||
def query(
|
def query(self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None) -> Optional[GetResult]:
|
||||||
self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
|
mt_collection, resource_id = self._get_collection_and_resource_id(collection_name)
|
||||||
) -> Optional[GetResult]:
|
|
||||||
mt_collection, resource_id = self._get_collection_and_resource_id(
|
|
||||||
collection_name
|
|
||||||
)
|
|
||||||
if not utility.has_collection(mt_collection):
|
if not utility.has_collection(mt_collection):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -269,8 +251,8 @@ class MilvusClient(VectorDBBase):
|
|||||||
expr.append(f"metadata['{key}'] == {value}")
|
expr.append(f"metadata['{key}'] == {value}")
|
||||||
|
|
||||||
iterator = collection.query_iterator(
|
iterator = collection.query_iterator(
|
||||||
expr=" and ".join(expr),
|
expr=' and '.join(expr),
|
||||||
output_fields=["id", "text", "metadata"],
|
output_fields=['id', 'text', 'metadata'],
|
||||||
limit=limit if limit else -1,
|
limit=limit if limit else -1,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -282,9 +264,9 @@ class MilvusClient(VectorDBBase):
|
|||||||
break
|
break
|
||||||
all_results.extend(batch)
|
all_results.extend(batch)
|
||||||
|
|
||||||
ids = [res["id"] for res in all_results]
|
ids = [res['id'] for res in all_results]
|
||||||
documents = [res["text"] for res in all_results]
|
documents = [res['text'] for res in all_results]
|
||||||
metadatas = [res["metadata"] for res in all_results]
|
metadatas = [res['metadata'] for res in all_results]
|
||||||
|
|
||||||
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user